diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 1227b03dc..83436efc0 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -21,9 +21,9 @@ name: build on: # Triggers the workflow on push or pull request events but only for the master branch push: - branches: [ master ] + branches: [ "master" ] pull_request: - branches: [ master ] + branches: [ "master" ] # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -33,59 +33,57 @@ jobs: # This workflow contains a single job called "build" build: # The type of runner that the job will run on - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: matrix: - python-version: [3.7] - # tf-nightly has some pip version conflicts, so can't be installed. - # Use only numbered TF as of now. - # tf-version: ["2.4.*", "tf-nightly"] - tf-version: ["2.4.*"] + python-version: [ '3.10' ] + # Which TF version to run. + tf-version: [ '2.17.0' ] # Which set of tests to run. - trax-test: ["lib", "research"] + trax-test: [ 'lib', 'research' ] # Steps represent a sequence of tasks that will be executed as part of the job steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -q -U setuptools numpy - python -m pip install flake8 pytest - if [[ ${{matrix.tf-version}} == "tf-nightly" ]]; then python -m pip install tf-nightly; else python -m pip install -q "tensorflow=="${{matrix.tf-version}}; fi - pip install -e .[tests,t5] - # # Lint with flake8 - # - name: Lint with flake8 - # run: | - # # stop the build if there are Python syntax errors or undefined names - # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - # Test out right now with only testing one directory. - - name: Test with pytest - run: | - TRAX_TEST=" ${{matrix.trax-test}}" ./oss_scripts/oss_tests.sh - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v3 + - name: Set up Python ${{matrix.python-version}} + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + cache: 'pip' + - name: Install dependencies + env: + PIP_DISABLE_PIP_VERSION_CHECK: '1' + run: | + python -m pip install -U pip + # Install TensorFlow matching matrix version.
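+        # Note: an exact pin such as 2.17.0 keeps CI runs reproducible; a wildcard like "2.17.*" would float to new patch releases, and the quotes on the line below keep the shell from expanding the version spec.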
+ python -m pip install "tensorflow==${{ matrix.tf-version }}" + # Install package in editable mode with test and T5 extras (tests use T5 preprocessors). + python -m pip install -e .[tests,t5,rl] + # Test out right now with only testing one directory. + - name: Install trax package + run: | + python -m pip install -e . + - name: Test with pytest + working-directory: . + run: | + TRAX_TEST="${{matrix.trax-test}}" ./oss_scripts/oss_tests.sh + # The below step just reports the success or failure of tests as a "commit status". + # This is needed for copy bara integration. + - name: Report success or failure as github status + if: always() + shell: bash + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{github.repository}}/statuses/${{github.sha}} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{github.repository}}/actions/runs/${{github.run_id}}", + "description": "'$status'", + "context": "github-actions/build" + }' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..5bc0ab766 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,82 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ "1.5.1" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "1.5.1" ] + schedule: + - cron: '31 4 * * 1' + +jobs: + analyze: + name: Analyze + # Runner size impacts CodeQL analysis time. To learn more, please see: + # - https://gh.io/recommended-hardware-resources-for-running-codeql + # - https://gh.io/supported-runners-and-hardware-resources + # - https://gh.io/using-larger-runners + # Consider using larger runners for possible analysis time improvements. + runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} + timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' ] + # Use only 'java-kotlin' to analyze code written in Java, Kotlin or both + # Use only 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both + # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. 
+ + # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + + # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + + # If the Autobuild fails above, remove it and uncomment the following three lines; + # modify them (or add more) to build your code if your project uses a compiled language. + + # - run: | + # echo "Run, Build Application using script" + # ./location_of_script_within_repo/buildscript.sh + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d9de2c1be..f6b5ff3c5 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -28,6 +28,6 @@ formats: all # Optionally set the version of Python and requirements required to build your docs python: - version: 3.7 + version: "3.10" install: - requirements: docs/requirements.txt diff --git a/.travis.yml b/.travis.yml index 0251cb069..50cfa6499 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,10 @@ git: depth: 3 quiet: true python: - - "3.6" + - "3.10" env: global: - - TF_VERSION="2.4.*" + - TF_VERSION="2.11.0" matrix: - TRAX_TEST="lib" - TRAX_TEST="research" diff --git a/README.md b/README.md index 33884979a..f91c5699d 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ version](https://badge.fury.io/py/trax.svg)](https://badge.fury.io/py/trax) [![GitHub Issues](https://img.shields.io/github/issues/google/trax.svg)](https://github.com/google/trax/issues) -![GitHub Build](https://github.com/google/trax/actions/workflows/build.yaml/badge.svg) +![GitHub Build](https://github.com/mmarcinmichal/trax/actions/workflows/build.yaml/badge.svg) [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) @@ -26,6 +26,15 @@ Here are a few example notebooks:- * [**trax.data API explained**](https://github.com/google/trax/blob/master/trax/examples/trax_data_Explained.ipynb) : Explains some of the major functions in the `trax.data` API * [**Named Entity Recognition using Reformer**](https://github.com/google/trax/blob/master/trax/examples/NER_using_Reformer.ipynb) : Uses a [Kaggle dataset](https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus) for implementing Named Entity Recognition using the [Reformer](https://arxiv.org/abs/2001.04451) architecture. * [**Deep N-Gram models**](https://github.com/google/trax/blob/master/trax/examples/Deep_N_Gram_Models.ipynb) : Implementation of deep n-gram models trained on Shakespeares works +* **Graph neural networks**: baseline models are available via + `trax.models.GraphConvNet`, + `trax.models.GraphEdgeNet` for node and edge updates, or the + attention-based `trax.models.GraphAttentionNet`.
+* Example Python scripts using these GNNs for MNIST and IMDB classification are + in + [`resources/examples/python/gnn_mnist/train.py`](resources/examples/python/gnn_mnist/train.py) + and + [`resources/examples/python/gnn_imdb/train.py`](resources/examples/python/gnn_imdb/train.py). diff --git a/docs/.readthedocs.yaml b/docs/.readthedocs.yaml index d9de2c1be..f6b5ff3c5 100644 --- a/docs/.readthedocs.yaml +++ b/docs/.readthedocs.yaml @@ -28,6 +28,6 @@ formats: all # Optionally set the version of Python and requirements required to build your docs python: - version: 3.7 + version: "3.10" install: - requirements: docs/requirements.txt diff --git a/docs/source/conf.py b/docs/source/conf.py index 9901cffd3..d7cb356a1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,12 +15,12 @@ # -*- coding: utf-8 -*- # -# Configuration file for the Sphinx documentation builder. +# Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config -"""Configuration file for Sphinx autodoc API documentation builder.""" +"""Configuration file for Sphinx autodoc API documentation builder.""" # -- Path setup -------------------------------------------------------------- @@ -31,19 +31,20 @@ import os import sys -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) # -- Project information ----------------------------------------------------- -project = 'Trax' -copyright = '2020, Google LLC.' # pylint: disable=redefined-builtin -author = 'The Trax authors' +project = "Trax" +copyright = "2020, Google LLC." # pylint: disable=redefined-builtin +author = "The Trax authors" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = '' +release = "" # -- General configuration --------------------------------------------------- @@ -56,23 +57,23 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'nbsphinx', - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', + "nbsphinx", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -95,7 +96,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -106,7 +107,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names.
@@ -121,8 +122,8 @@ # -- Options for HTMLHelp output --------------------------------------------- -# Output file base name for HTML help builder. -htmlhelp_basename = 'Traxdoc' +# Output file base name for HTML help builder. +htmlhelp_basename = "Traxdoc" # -- Options for LaTeX output ------------------------------------------------ @@ -131,15 +132,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -149,8 +147,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Trax.tex', 'Trax Documentation', - 'Trax authors', 'manual'), + (master_doc, "Trax.tex", "Trax Documentation", "Trax authors", "manual"), ] @@ -158,10 +155,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'trax', 'Trax Documentation', - [author], 1) -] +man_pages = [(master_doc, "trax", "Trax Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -170,9 +164,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Trax', 'Trax Documentation', - author, 'Trax', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "Trax", + "Trax Documentation", + author, + "Trax", + "One line description of project.", + "Miscellaneous", + ), ] @@ -191,37 +191,36 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" autodoc_default_options = { - 'members': None, # Include all public members. - 'undoc-members': True, # Include members that lack docstrings. - 'show-inheritance': True, - 'special-members': '__call__, __init__', + "members": None, # Include all public members. + "undoc-members": True, # Include members that lack docstrings.
+ "show-inheritance": True, + "special-members": "__call__, __init__", } autodoc_mock_imports = [ - 'gin', - 'jax', - 'numpy', - 'scipy', - 'tensorflow', - 'tensorflow_datasets', - 'funcsigs', - 'trax.tf_numpy', - 'absl', - 'gym', - 'tensor2tensor', - 'tensorflow_text', - 'matplotlib', - 'cloudpickle', - 't5', - 'psutil', + "gin", + "jax", + "numpy", + "scipy", + "tensorflow", + "tensorflow_datasets", + "funcsigs", + "trax.tf", + "absl", + "gym", + "tensor2tensor", + "tensorflow_text", + "matplotlib", + "cloudpickle", + "t5", + "psutil", # 'setup', ] - diff --git a/oss_scripts/oss_pip_install.sh b/oss_scripts/oss_pip_install.sh index 839d3d9fc..652fa447d 100755 --- a/oss_scripts/oss_pip_install.sh +++ b/oss_scripts/oss_pip_install.sh @@ -15,7 +15,7 @@ #!/bin/bash set -v # print commands as they're executed -set -e # fail and exit on any command erroring +set -e # fail and exit on any command error : "${TF_VERSION:?}" diff --git a/oss_scripts/oss_release.sh b/oss_scripts/oss_release.sh index 9d913ba8f..7b2944fb9 100755 --- a/oss_scripts/oss_release.sh +++ b/oss_scripts/oss_release.sh @@ -15,18 +15,18 @@ #!/bin/bash set -v # print commands as they're executed -set -e # fail and exit on any command erroring +set -e # fail and exit on any command error GIT_COMMIT_ID=${1:-""} [[ -z $GIT_COMMIT_ID ]] && echo "Must provide a commit" && exit 1 TMP_DIR=$(mktemp -d) -pushd $TMP_DIR +pushd "$TMP_DIR" echo "Cloning trax and checking out commit $GIT_COMMIT_ID" git clone https://github.com/google/trax.git cd trax -git checkout $GIT_COMMIT_ID +git checkout "$GIT_COMMIT_ID" python3 -m pip install wheel twine pyopenssl @@ -42,4 +42,4 @@ python3 -m twine upload dist/* # Cleanup rm -rf build/ dist/ trax.egg-info/ popd -rm -rf $TMP_DIR +rm -rf "$TMP_DIR" diff --git a/oss_scripts/oss_tests.sh b/oss_scripts/oss_tests.sh index ee3bf428f..5db095be5 100755 --- a/oss_scripts/oss_tests.sh +++ b/oss_scripts/oss_tests.sh @@ -1,3 +1,5 @@ +#!/bin/bash + # Copyright 2022 The Trax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -#!/bin/bash - set -v # print commands as they're executed # aliases aren't expanded in non-interactive shells by default. @@ -46,98 +46,53 @@ set_status # # Run pytest with coverage. # alias pytest='coverage run -m pytest' -# Check tests, separate out directories for easy triage. - +# Check tests, check each directory of tests separately. if [[ "${TRAX_TEST}" == "lib" ]] then - ## Core Trax and Supervised Learning + echo "Testing all framework packages..." - # Disabled the decoding test for now, since it OOMs. - # TODO(afrozm): Add the decoding_test.py back again. - - # training_test and trainer_lib_test parse flags, so can't use with --ignore - pytest \ - --ignore=trax/supervised/callbacks_test.py \ - --ignore=trax/supervised/decoding_test.py \ - --ignore=trax/supervised/decoding_timing_test.py \ - --ignore=trax/supervised/trainer_lib_test.py \ - --ignore=trax/supervised/training_test.py \ - trax/supervised + ## Core Trax and Supervised Learning + pytest tests/data set_status - # Testing these separately here. - pytest \ - trax/supervised/callbacks_test.py \ - trax/supervised/trainer_lib_test.py \ - trax/supervised/training_test.py + pytest tests/fastmath set_status - pytest trax/data + pytest tests/layers set_status - # Ignoring acceleration_test's test_chunk_grad_memory since it is taking a - # lot of time on OSS. 
- pytest \ - --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_grad_memory \ - --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_memory \ - --ignore=trax/layers/initializers_test.py \ - --ignore=trax/layers/test_utils.py \ - trax/layers + pytest tests/learning set_status - pytest trax/layers/initializers_test.py + pytest tests/models set_status - pytest trax/fastmath + pytest tests/optimizers set_status - pytest trax/optimizers + pytest tests/tf set_status - # Catch-all for futureproofing. - pytest \ - --ignore=trax/trax2keras_test.py \ - --ignore=trax/data \ - --ignore=trax/fastmath \ - --ignore=trax/layers \ - --ignore=trax/models \ - --ignore=trax/optimizers \ - --ignore=trax/rl \ - --ignore=trax/supervised \ - --ignore=trax/tf_numpy + pytest tests/trainers set_status -else - # Models, RL and misc right now. - ## Models - # Disabled tests are quasi integration tests. - pytest \ - --ignore=trax/models/reformer/reformer_e2e_test.py \ - --ignore=trax/models/reformer/reformer_memory_test.py \ - --ignore=trax/models/research/terraformer_e2e_test.py \ - --ignore=trax/models/research/terraformer_memory_test.py \ - --ignore=trax/models/research/terraformer_oom_test.py \ - trax/models + pytest tests/utils/import_test.py set_status - ## RL Trax - pytest trax/rl + pytest tests/utils/shapes_test.py set_status - ## Trax2Keras - # TODO(afrozm): Make public again after TF 2.5 releases. - # pytest trax/trax2keras_test.py - # set_status +else + echo "No tests to run for this configuration..." + # Models, RL and misc tests are currently disabled. # Check notebooks. # TODO(afrozm): Add more. - jupyter nbconvert --ExecutePreprocessor.kernel_name=python3 \ - --ExecutePreprocessor.timeout=600 --to notebook --execute \ - trax/intro.ipynb; - set_status + # jupyter nbconvert --ExecutePreprocessor.kernel_name=python3 \ + # --ExecutePreprocessor.timeout=600 --to notebook --execute \ + # trax/intro.ipynb; + # set_status fi -# TODO(traxers): Test tf-numpy separately. - exit $STATUS diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..f284563f6 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . 
trax \ No newline at end of file diff --git a/trax/data/testdata/bert_uncased_vocab.txt b/resources/data/testdata/bert_uncased_vocab.txt similarity index 100% rename from trax/data/testdata/bert_uncased_vocab.txt rename to resources/data/testdata/bert_uncased_vocab.txt diff --git a/trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 b/resources/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 rename to resources/data/testdata/c4/en/2.3.0/c4-train.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 b/resources/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 rename to resources/data/testdata/c4/en/2.3.0/c4-validation.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/c4/en/2.3.0/dataset_info.json b/resources/data/testdata/c4/en/2.3.0/dataset_info.json similarity index 100% rename from trax/data/testdata/c4/en/2.3.0/dataset_info.json rename to resources/data/testdata/c4/en/2.3.0/dataset_info.json diff --git a/trax/data/testdata/corpus-1.txt b/resources/data/testdata/corpus-1.txt similarity index 100% rename from trax/data/testdata/corpus-1.txt rename to resources/data/testdata/corpus-1.txt diff --git a/trax/data/testdata/corpus-2.txt b/resources/data/testdata/corpus-2.txt similarity index 100% rename from trax/data/testdata/corpus-2.txt rename to resources/data/testdata/corpus-2.txt diff --git a/trax/data/testdata/en_8k.subword b/resources/data/testdata/en_8k.subword similarity index 100% rename from trax/data/testdata/en_8k.subword rename to resources/data/testdata/en_8k.subword diff --git a/trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json b/resources/data/testdata/para_crawl/ende/1.2.0/dataset_info.json similarity index 100% rename from trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json rename to resources/data/testdata/para_crawl/ende/1.2.0/dataset_info.json diff --git a/trax/data/testdata/para_crawl/ende/1.2.0/features.json b/resources/data/testdata/para_crawl/ende/1.2.0/features.json similarity index 100% rename from trax/data/testdata/para_crawl/ende/1.2.0/features.json rename to resources/data/testdata/para_crawl/ende/1.2.0/features.json diff --git a/trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 b/resources/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 rename to resources/data/testdata/para_crawl/ende/1.2.0/para_crawl-train.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/sentencepiece.model b/resources/data/testdata/sentencepiece.model similarity index 100% rename from trax/data/testdata/sentencepiece.model rename to resources/data/testdata/sentencepiece.model diff --git a/trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json b/resources/data/testdata/squad/v1.1/3.0.0/dataset_info.json similarity index 74% rename from trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json rename to resources/data/testdata/squad/v1.1/3.0.0/dataset_info.json index 3298dab0a..31c889289 100644 --- a/trax/data/testdata/squad/v1.1/3.0.0/dataset_info.json +++ b/resources/data/testdata/squad/v1.1/3.0.0/dataset_info.json @@ -1,51 +1,51 @@ { - "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, 
Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", - "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", + "citation": "@article{2016arXiv160605250R,\n author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n Konstantin and {Liang}, Percy},\n title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n journal = {arXiv e-prints},\n year = 2016,\n eid = {arXiv:1606.05250},\n pages = {arXiv:1606.05250},\narchivePrefix = {arXiv},\n eprint = {1606.05250},\n}\n", + "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n", "location": { "urls": [ "https://rajpurkar.github.io/SQuAD-explorer/" ] - }, - "name": "squad", + }, + "name": "squad", "schema": { "feature": [ { "name": "answers" - }, + }, { - "name": "context", + "name": "context", "type": "BYTES" - }, + }, { - "name": "id", + "name": "id", "type": "BYTES" - }, + }, { - "name": "question", + "name": "question", "type": "BYTES" - }, + }, { - "name": "title", + "name": "title", "type": "BYTES" } ] - }, - "sizeInBytes": "35142551", + }, + "sizeInBytes": "35142551", "splits": [ { - "name": "train", - "numShards": "1", + "name": "train", + "numShards": "1", "shardLengths": [ "10" ] }, { - "name": "validation", - "numShards": "1", + "name": "validation", + "numShards": "1", "shardLengths": [ "10" ] } - ], + ], "version": "3.0.0" -} +} \ No newline at end of file diff --git a/trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 b/resources/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 rename to resources/data/testdata/squad/v1.1/3.0.0/squad-train.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 b/resources/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 similarity index 100% rename from trax/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 rename to resources/data/testdata/squad/v1.1/3.0.0/squad-validation.tfrecord-00000-of-00001 diff --git a/trax/data/testdata/vocab-1.txt b/resources/data/testdata/vocab-1.txt similarity index 100% rename from trax/data/testdata/vocab-1.txt rename to resources/data/testdata/vocab-1.txt diff --git a/trax/data/testdata/vocab-2.txt b/resources/data/testdata/vocab-2.txt similarity index 100% rename from trax/data/testdata/vocab-2.txt rename to resources/data/testdata/vocab-2.txt diff --git a/trax/models/reformer/testdata/vocab.translate_ende_wmt32k.32768.subwords b/resources/data/testdata/vocab.translate_ende_wmt32k.32768.subwords similarity index 100% rename from trax/models/reformer/testdata/vocab.translate_ende_wmt32k.32768.subwords rename to 
resources/data/testdata/vocab.translate_ende_wmt32k.32768.subwords diff --git a/resources/examples/ipynb/Example-0-Introduction.ipynb b/resources/examples/ipynb/Example-0-Introduction.ipynb new file mode 100644 index 000000000..36a211232 --- /dev/null +++ b/resources/examples/ipynb/Example-0-Introduction.ipynb @@ -0,0 +1,664 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7yuytuIllsv1" + }, + "source": [ + "# Trax Quick Intro\n", + "\n", + "[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google.com/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information.\n", + "\n", + " 1. **Run a pre-trained Transformer**: create a translator in a few lines of code\n", + " 1. **Features and resources**: [API docs](https://trax-ml.readthedocs.io/en/latest/trax.html), where to [talk to us](https://gitter.im/trax-ml/community), how to [open an issue](https://github.com/google/trax/issues) and more\n", + " 1. **Walkthrough**: how Trax works, how to make new models and train on your own data\n", + "\n", + "We welcome **contributions** to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love **notebooks** that explain how models work and show how to use them to solve problems!\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BIl27504La0G" + }, + "source": [ + "**General Setup**\n", + "\n", + "Execute the following few cells (once) before running any of the code samples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 36794, + "status": "ok", + "timestamp": 1607149386661, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 463, + "status": "ok", + "timestamp": 1607149387132, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "vlGjGoGMTt-D", + "outputId": "3076e638-695d-4017-e757-98d929630e17" + }, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import sys" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "from trax.data.encoder import encoder\n", + "from trax.learning.supervised import decoding as decoding\n", + "from trax import models as models\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-LQ89rFFsEdk" + }, + "source": [ + "## 1. Run a pre-trained Transformer\n", + "\n", + "Here is how you create an English-German translator in a few lines of code:\n", + "\n", + "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n", + "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n", + "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n", + "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n", + "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 46373, + "status": "ok", + "timestamp": 1607149433512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "djTiSLcaNFGa", + "outputId": "a7917337-0a77-4064-8a6e-4e44e4a9c7c7" + }, + "outputs": [], + "source": [ + "# Create a Transformer model.\n", + "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", + "model = models.Transformer(\n", + " input_vocab_size=33300,\n", + " d_model=512, d_ff=2048,\n", + " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", + " max_len=2048, mode='predict')\n", + "\n", + "# Initialize using pre-trained weights.\n", + "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", + " weights_only=True)\n", + "\n", + "# Tokenize a sentence.\n", + "sentence = 'It is nice to learn new things today!'\n", + "tokenized = list(encoder.tokenize(iter([sentence]), # Operates on streams.\n", + " vocab_dir='gs://trax-ml/vocabs/',\n", + " vocab_file='ende_32k.subword'))[0]\n", + "\n", + "# Decode from the Transformer.\n", + "tokenized = tokenized[None, :] # Add batch dimension.\n", + "tokenized_translation = 
decoding.autoregressive_sample(\n", + " model, tokenized, temperature=0.0) # Higher temperature: more diverse results.\n", + "\n", + "# De-tokenize.\n", + "tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.\n", + "translation = encoder.detokenize(tokenized_translation,\n", + " vocab_dir='gs://trax-ml/vocabs/',\n", + " vocab_file='ende_32k.subword')\n", + "print(translation)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QMo3OnsGgLNK" + }, + "source": [ + "## 2. Features and resources\n", + "\n", + "Trax includes basic models (like [ResNet](https://github.com/google/trax/blob/master/trax/models/resnet.py#L70), [LSTM](https://github.com/google/trax/blob/master/trax/models/rnn.py#L100), [Transformer](https://github.com/google/trax/blob/master/trax/models/transformer.py#L189)) and RL algorithms\n", + "(like [REINFORCE](https://github.com/google/trax/blob/master/trax/rl/training.py#L244), [A2C](https://github.com/google/trax/blob/master/trax/rl/actor_critic_joint.py#L458), [PPO](https://github.com/google/trax/blob/master/trax/rl/actor_critic_joint.py#L209)). It is also actively used for research and includes\n", + "new models like the [Reformer](https://github.com/google/trax/tree/master/trax/models/reformer) and new RL algorithms like [AWR](https://arxiv.org/abs/1910.00177). Trax has bindings to a large number of deep learning datasets, including\n", + "[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [TensorFlow datasets](https://www.tensorflow.org/datasets/catalog/overview).\n", + "\n", + "\n", + "You can use Trax either as a library from your own Python scripts and notebooks\n", + "or as a binary from the shell, which can be more convenient for training large models.\n", + "It runs without any changes on CPUs, GPUs and TPUs.\n", + "\n", + "* [API docs](https://trax-ml.readthedocs.io/en/latest/)\n", + "* [chat with us](https://gitter.im/trax-ml/community)\n", + "* [open an issue](https://github.com/google/trax/issues)\n", + "* subscribe to [trax-discuss](https://groups.google.com/u/1/g/trax-discuss) for news\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8wgfJyhdihfR" + }, + "source": [ + "## 3. Walkthrough\n", + "\n", + "You can learn here how Trax works, how to create new models and how to train them on your own data." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yM12hgQnp4qo" + }, + "source": [ + "### Tensors and Fast Math\n", + "\n", + "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n", + "\n", + "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 667, + "status": "ok", + "timestamp": 1607149434186, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "kSauPt0NUl_o", + "outputId": "c7288312-767d-4344-91ae-95ebf386ce57" + }, + "outputs": [], + "source": [ + "from trax.fastmath import numpy as fastnp\n", + "\n", + "trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.\n", + "\n", + "matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n", + "print(f'matrix =\\n{matrix}')\n", + "vector = fastnp.ones(3)\n", + "print(f'vector = {vector}')\n", + "product = fastnp.dot(vector, matrix)\n", + "print(f'product = {product}')\n", + "tanh = fastnp.tanh(product)\n", + "print(f'tanh(product) = {tanh}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "snLYtU6OsKU2" + }, + "source": [ + "Gradients can be calculated using `trax.fastmath.grad`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 545, + "status": "ok", + "timestamp": 1607149434742, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "cqjYoxPEu8PG", + "outputId": "04739509-9d3a-446d-d088-84882b8917bc" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " return 2.0 * x * x\n", + "\n", + "\n", + "grad_f = trax.fastmath.grad(f)\n", + "\n", + "print(f'grad(2x^2) at 1 = {grad_f(1.0)}')\n", + "print(f'grad(2x^2) at -2 = {grad_f(-2.0)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p-wtgiWNseWw" + }, + "source": [ + "### Layers\n", + "\n", + "Layers are basic building blocks of Trax models. You will learn all about them in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) but for now, just take a look at the implementation of one core Trax layer, `Embedding`:\n", + "\n", + "```\n", + "class Embedding(base.Layer):\n", + " \"\"\"Trainable layer that maps discrete tokens/IDs to vectors.\"\"\"\n", + "\n", + " def __init__(self,\n", + " vocab_size,\n", + " d_feature,\n", + " kernel_initializer=init.RandomNormalInitializer(1.0)):\n", + " \"\"\"Returns an embedding layer with given vocabulary size and vector size.\n", + "\n", + " Args:\n", + " vocab_size: Size of the input vocabulary. 
The layer will assign a unique\n", + " vector to each id in `range(vocab_size)`.\n", + " d_feature: Dimensionality/depth of the output vectors.\n", + " kernel_initializer: Function that creates (random) initial vectors for\n", + " the embedding.\n", + " \"\"\"\n", + " super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')\n", + " self._d_feature = d_feature # feature dimensionality\n", + " self._vocab_size = vocab_size\n", + " self._kernel_initializer = kernel_initializer\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Returns embedding vectors corresponding to input token IDs.\n", + "\n", + " Args:\n", + " x: Tensor of token IDs.\n", + "\n", + " Returns:\n", + " Tensor of embedding vectors.\n", + " \"\"\"\n", + " return jnp.take(self.weights, x, axis=0, mode='clip')\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " del input_signature\n", + " shape_w = (self._vocab_size, self._d_feature)\n", + " w = self._kernel_initializer(shape_w, self.rng)\n", + " self.weights = w\n", + "```\n", + "\n", + "Layers with trainable weights like `Embedding` need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 598, + "status": "ok", + "timestamp": 1607149436202, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "4MLSQsIiw9Aw", + "outputId": "394efc9d-9e3c-4f8c-80c2-ce3b5a935e38" + }, + "outputs": [], + "source": [ + "from trax import layers as tl\n", + "from trax.utils import shapes\n", + "\n", + "# Create an input tensor x.\n", + "x = np.arange(15)\n", + "print(f'x = {x}')\n", + "\n", + "# Create the embedding layer.\n", + "embedding = tl.Embedding(vocab_size=20, d_feature=32)\n", + "embedding.init(trax.utils.shapes.signature(x))\n", + "\n", + "# Run the layer -- y = embedding(x).\n", + "y = embedding(x)\n", + "print(f'shape of y = {y.shape}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MgCPl9ZOyCJw" + }, + "source": [ + "### Models\n", + "\n", + "Models in Trax are built from layers most often using the `Serial` and `Branch` combinators. You can read more about those combinators in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) and\n", + "see the code for many models in `trax/models/`, e.g., this is how the [Transformer Language Model](https://github.com/google/trax/blob/master/trax/models/transformer.py#L167) is implemented. Below is an example of how to build a sentiment classification model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 473, + "status": "ok", + "timestamp": 1607149436685, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "WoSz5plIyXOU", + "outputId": "f94c84c4-3224-4231-8879-4a68f328b89e" + }, + "outputs": [], + "source": [ + "model = tl.Serial(\n", + " tl.Embedding(vocab_size=8192, d_feature=256),\n", + " tl.Mean(axis=1), # Average on axis 1 (length of sentence).\n", + " tl.Dense(2), # Classify 2 classes.\n", + ")\n", + "\n", + "# You can print model structure.\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FcnIjFLD0Ju1" + }, + "source": [ + "### Data\n", + "\n", + "To train your model, you need data. 
In Trax, data streams are represented as Python iterators, so you can call `next(data_stream)` and get a tuple, e.g., `(inputs, targets)`. Trax allows you to use [TensorFlow Datasets](https://www.tensorflow.org/datasets) easily and you can also get an iterator from your own text file using the standard `open('my_file.txt')`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 19863, + "status": "ok", + "timestamp": 1607149456555, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "pKITF1jR0_Of", + "outputId": "44a73b25-668d-4f85-9133-ebb0f5edd191" + }, + "outputs": [], + "source": [ + "from trax.data.loader.tf import base as dataset\n", + "\n", + "train_stream = dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()\n", + "eval_stream = dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()\n", + "print(next(train_stream)) # See one example." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fRGj4Skm1kL4" + }, + "source": [ + "Using the `trax.data` module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using `trax.data.Serial` and they are functions that you apply to streams to create processed streams." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 1746, + "status": "ok", + "timestamp": 1607149458319, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "AV5wrgjZ10yU", + "outputId": "82b8e3bc-7812-4cd3-a669-401fef29f1c0" + }, + "outputs": [], + "source": [ + "from trax.data.preprocessing import inputs as preprocessing\n", + "from trax.data.encoder import encoder\n", + "\n", + "data_pipeline = preprocessing.Serial(\n", + " encoder.Tokenize(vocab_file='en_8k.subword', keys=[0]),\n", + " preprocessing.Shuffle(),\n", + " preprocessing.FilterByLength(max_length=2048, length_keys=[0]),\n", + " preprocessing.BucketByLength(boundaries=[32, 128, 512, 2048],\n", + " batch_sizes=[512, 128, 32, 8, 1],\n", + " length_keys=[0]),\n", + " preprocessing.AddLossWeights()\n", + ")\n", + "train_batches_stream = data_pipeline(train_stream)\n", + "eval_batches_stream = data_pipeline(eval_stream)\n", + "example_batch = next(train_batches_stream)\n", + "print(example_batch)\n", + "#print(f'shapes = {[x.shape for x in example_batch]}') # Check the shapes." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l25krioP2twf" + }, + "source": [ + "### Supervised training\n", + "\n", + "When you have the model and the data, use `trax.learning.supervised.training` to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 43631, + "status": "ok", + "timestamp": 1607149504226, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "d6bIKUO-3Cw8", + "outputId": "038e6ad5-0d2f-442b-ffa1-ed431dc1d2e0" + }, + "outputs": [], + "source": [ + "from trax.learning.supervised import training\n", + "\n", + "# Training task.\n", + "train_task = training.TrainTask(\n", + " labeled_data=train_batches_stream,\n", + " loss_layer=tl.WeightedCategoryCrossEntropy(),\n", + " optimizer=trax.optimizers.Adam(0.01),\n", + " n_steps_per_checkpoint=500,\n", + ")\n", + "\n", + "# Evaluation task.\n", + "eval_task = training.EvalTask(\n", + " labeled_data=eval_batches_stream,\n", + " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n", + " n_eval_batches=20 # For less variance in eval numbers.\n", + ")\n", + "\n", + "# Training loop saves checkpoints to output_dir.\n", + "output_dir = os.path.expanduser('~/output_dir/')\n", + "!rm -rf {output_dir}\n", + "training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir)\n", + "\n", + "# Run 2000 steps (batches).\n", + "training_loop.run(2000)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-aCkIu3x686C" + }, + "source": [ + "After training the model, run it like any layer to get results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 1683, + "status": "ok", + "timestamp": 1607149514303, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "yuPu37Lp7GST", + "outputId": "fdc4d832-2f1d-4aee-87b5-9c9dc1238503" + }, + "outputs": [], + "source": [ + "example_input = next(eval_batches_stream)[0][0]\n", + "example_input_str = encoder.detokenize(example_input, vocab_file='en_8k.subword')\n", + "print(f'example input_str: {example_input_str}')\n", + "sentiment_log_probs = model(example_input[None, :]) # Add batch dimension.\n", + "print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", + "kind": "private" + }, + "name": "Trax Quick Intro", + "provenance": [ + { + "file_id": "trax/intro.ipynb", + "timestamp": 1595931762204 + }, + { + "file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV", + "timestamp": 1578964243645 + }, + { + "file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07", + "timestamp": 1572044421118 + }, + { + "file_id": "intro.ipynb", + "timestamp": 1571858674399 + }, + { + "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", + "timestamp": 1569980697572 + }, + { + "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", + "timestamp": 1563927451951 + } + ] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-1-Trax-Data.ipynb b/resources/examples/ipynb/Example-1-Trax-Data.ipynb new file mode 100644 index 000000000..3344c3829 --- /dev/null +++ b/resources/examples/ipynb/Example-1-Trax-Data.ipynb @@ -0,0 +1,732 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": { + "id": "6NWA5uxOmBVz" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zOPgYEe2i7Cg" + }, + "source": [ + "Notebook Author: [@SauravMaheshkar](https://github.com/SauravMaheshkar)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jtMr8yxvM2m3" + }, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yD3A2vRGSDwy" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v5VsWct1QjPz" + }, + "source": [ + "# Serial Fn" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gEa5pT6FQuta" + }, + "source": [ + "In Trax, we use combinators to build input pipelines, much like building deep learning models. The `Serial` combinator applies layers serially using function composition and uses stack semantics to manage data.\n", + "\n", + "Trax has the following definition for a `Serial` combinator.\n", + "\n", + "> ```\n", + "def Serial(*fns):\n", + " def composed_fns(generator=None):\n", + " for f in fastmath.tree_flatten(fns):\n", + " generator = f(generator)\n", + " return generator\n", + " return composed_fns\n", + " ```\n", + "\n", + "The `Serial` function has the following structure:\n", + "\n", + "* It takes an arbitrary number of functions as **input**\n", + "* Flattens that structure into a list\n", + "* Iterates through the list and applies the functions serially\n", + "\n", + "---\n", + "\n", + "The [`fastmath.tree_flatten()`](https://github.com/google/trax/blob/c38a5b1e4c5cfe13d156b3fc0bfdb83554c8f799/trax/fastmath/numpy.py#L195) function takes a tree as input and returns a flattened list. This way we can use various generator functions like Tokenize and Shuffle, and apply them serially by '*iterating*' through the list.\n", + "\n", + "Initially, `generator` is set to `None`, so the first iteration simply executes the first function in our tree structure. 
On each following iteration, the `generator` variable is rebound to the output of applying the next function in the list to the previous generator.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1rkCvxscXtvk" + }, + "source": [ + "# Log Function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oodQFyHDYJHF" + }, + "source": [ + "> ```\n", + "def Log(n_steps_per_example=1, only_shapes=True):\n", + " def log(stream):\n", + " counter = 0\n", + " for example in stream:\n", + " item_to_log = example\n", + " if only_shapes:\n", + " item_to_log = fastmath.nested_map(shapes.signature, example)\n", + " if counter % n_steps_per_example == 0:\n", + " logging.info(str(item_to_log))\n", + " print(item_to_log)\n", + " counter += 1\n", + " yield example\n", + " return log\n", + " ```\n", + "\n", + "Every deep learning framework needs a logging component for efficient debugging.\n", + "\n", + "The `trax.data.Log` generator uses the `absl` package for logging. It uses the [`fastmath.nested_map`](https://github.com/google/trax/blob/c38a5b1e4c5cfe13d156b3fc0bfdb83554c8f799/trax/fastmath/numpy.py#L80) function, which maps a given function recursively over a nested object. In the case depicted below, it maps `shapes.signature` recursively over the input stream, giving us the shapes of the various objects in our stream.\n", + "\n", + "---\n", + "\n", + "The following two cells show the difference when we set the `only_shapes` variable to `False`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trax.data.preprocessing import inputs as preprocessing\n", + "from trax.data.encoder import encoder\n", + "from trax.data.loader.tf import base as dataset\n", + "\n", + "data_pipeline = preprocessing.Serial(\n", + " dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n", + " encoder.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n", + " preprocessing.Log(only_shapes=False)\n", + ")\n", + "example = data_pipeline()\n", + "print(next(example))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Wy8L-e9qcRY4" + }, + "source": [ + "# Shuffling our datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-cfg48KgcrlM" + }, + "source": [ + "Trax offers two generator functions to add shuffle functionality to our input pipelines.\n", + "\n", + "1. The `shuffle` function shuffles a given stream\n", + "2. The `Shuffle` function returns a shuffle function instead" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4iD21oiycWf4" + }, + "source": [ + "## `shuffle`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bVgN1yYAcaKM" + }, + "source": [ + "> ```\n", + "def shuffle(samples, queue_size):\n", + " if queue_size < 1:\n", + " raise ValueError(f'Arg queue_size ({queue_size}) is less than 1.')\n", + " if queue_size == 1:\n", + " logging.warning('Queue size of 1 results in no shuffling.')\n", + " queue = []\n", + " try:\n", + " # Prep: fill the queue.\n", + " for _ in range(queue_size):\n", + " queue.append(next(samples))\n", + " # Core streaming shuffle: yield sample from random location in queue,\n", + " # then fill that location with a new sample from the input stream.\n", + " for sample in samples:\n", + " i = np.random.randint(queue_size)\n", + " yield queue[i]\n", + " queue[i] = sample\n", + " except StopIteration:\n", + " logging.warning(\n", + " 'Not enough samples (%d) to fill initial queue (size %d).',\n", + " len(queue), queue_size)\n", + " # No new samples coming in; shuffle and drain the queue.\n", + " np.random.shuffle(queue)\n", + " for sample in queue:\n", + " yield sample\n", + " ```\n", + "\n", + "\n", + "The `shuffle` function takes two inputs, the data stream and the queue size (minimum number of samples within which the shuffling takes place). 
Apart from the usual warnings for non-positive and unity queue sizes, this generator function shuffles the given stream using [`np.random.randint()`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.randint.html), which picks a random index within the `queue_size` range for each incoming sample, and finally shuffles the remaining queue using [`np.random.shuffle()`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.shuffle.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sentence = [\n",
+ "  'Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem. Ut enim ad minima veniam, quis nostrum exercitationem ullam corporis suscipit laboriosam, nisi ut aliquid ex ea commodi consequatur? Quis autem vel eum iure reprehenderit qui in ea voluptate velit esse quam nihil molestiae consequatur, vel illum qui dolorem eum fugiat quo voluptas nulla pariatur?',\n",
+ "  'But I must explain to you how all this mistaken idea of denouncing pleasure and praising pain was born and I will give you a complete account of the system, and expound the actual teachings of the great explorer of the truth, the master-builder of human happiness. No one rejects, dislikes, or avoids pleasure itself, because it is pleasure, but because those who do not know how to pursue pleasure rationally encounter consequences that are extremely painful. Nor again is there anyone who loves or pursues or desires to obtain pain of itself, because it is pain, but because occasionally circumstances occur in which toil and pain can procure him some great pleasure. To take a trivial example, which of us ever undertakes laborious physical exercise, except to obtain some advantage from it? But who has any right to find fault with a man who chooses to enjoy a pleasure that has no annoying consequences, or one who avoids a pain that produces no resultant pleasure?',\n",
+ "  'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum',\n",
+ "  'At vero eos et accusamus et iusto odio dignissimos ducimus qui blanditiis praesentium voluptatum deleniti atque corrupti quos dolores et quas molestias excepturi sint occaecati cupiditate non provident, similique sunt in culpa qui officia deserunt mollitia animi, id est laborum et dolorum fuga. Et harum quidem rerum facilis est et expedita distinctio. Nam libero tempore, cum soluta nobis est eligendi optio cumque nihil impedit quo minus id quod maxime placeat facere possimus, omnis voluptas assumenda est, omnis dolor repellendus. Temporibus autem quibusdam et aut officiis debitis aut rerum necessitatibus saepe eveniet ut et voluptates repudiandae sint et molestiae non recusandae. Itaque earum rerum hic tenetur a sapiente delectus, ut aut reiciendis voluptatibus maiores alias consequatur aut perferendis doloribus asperiores repellat.']\n",
+ "\n",
+ "\n",
+ "def sample_generator(x):\n",
+ "    for i in x:\n",
+ "        yield i\n",
+ "\n",
+ "\n",
+ "example_shuffle = list(preprocessing.shuffle(sample_generator(sentence), queue_size=2))\n",
+ "example_shuffle"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "k-kTDkF-e7Vn"
+ },
+ "source": [
+ "## `Shuffle`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I5Djvqw2e9Jg"
+ },
+ "source": [
+ "> ```\n",
+ "def Shuffle(queue_size=1024):\n",
+ "  return lambda g: shuffle(g, queue_size)\n",
+ "\n",
+ "This function returns the aforementioned `shuffle` function and is mostly used in input pipelines.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AA-Z4Sipkq98"
+ },
+ "source": [
+ "# Batch Generators"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yzwONDulksbd"
+ },
+ "source": [
+ "## `batch`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-DCABkndkudF"
+ },
+ "source": [
+ "This function creates batches from the input generator.\n",
+ "\n",
+ "> ```\n",
+ "def batch(generator, batch_size):\n",
+ "  if batch_size <= 0:\n",
+ "    raise ValueError(f'Batch size must be positive, but is {batch_size}.')\n",
+ "  buf = []\n",
+ "  for example in generator:\n",
+ "    buf.append(example)\n",
+ "    if len(buf) == batch_size:\n",
+ "      batched_example = tuple(np.stack(x) for x in zip(*buf))\n",
+ "      yield batched_example\n",
+ "      buf = []\n",
+ "\n",
+ "It keeps adding examples from the generator into a list until its size equals `batch_size`, and then creates a batch using the `np.stack()` function.\n",
+ "\n",
+ "It also raises an error for non-positive values of `batch_size`.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BZMKY6VUpD3M"
+ },
+ "source": [
+ "## `Batch`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g6pYJHgOpIG4"
+ },
+ "source": [
+ "> ```\n",
+ "def Batch(batch_size):\n",
+ "  return lambda g: batch(g, batch_size)\n",
+ "\n",
+ "This function returns the aforementioned `batch` function with the given batch size.",
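+ "\n",
+ "As a quick sketch of the mechanics (a hypothetical toy generator of (x, y) pairs, not from the original notebook, and assuming the lowercase `batch` helper is exposed in the same module as `Batch`):\n",
+ "\n",
+ "> ```\n",
+ "pairs = ((np.array([i]), np.array([i % 2])) for i in range(6))  # toy stream\n",
+ "batched = preprocessing.batch(pairs, batch_size=3)\n",
+ "x_batch, y_batch = next(batched)\n",
+ "# x_batch.shape == (3, 1): three x examples stacked along a new leading axis.\n",
+ "# Pipeline form: preprocessing.Serial(..., preprocessing.Batch(3))"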
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cmQzaXw9vrbW"
+ },
+ "source": [
+ "# Pad to Maximum Dimensions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iL3MuKQIvt-Q"
+ },
+ "source": [
+ "This function is used to pad a tuple of tensors to a joint dimension and return their batch.\n",
+ "\n",
+ "For example, in the case below, a pair of tensors (1, 2) and ((3, 4), (5, 6)) is zero-padded to (1, 2, 0) and ((3, 4), (5, 6), 0), so that both fit the joint `boundary` dimension."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "id": "lvbBDuq4p4qW",
+ "outputId": "ed69c541-3219-4a23-cf73-4568e3e2882f"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from trax.data.preprocessing import inputs as preprocessing\n",
+ "\n",
+ "tensors = (np.array([(1., 2.)]), np.array([(3., 4.), (5., 6.)]))\n",
+ "print(type(tensors[0]))\n",
+ "padded_tensors = preprocessing.pad_to_max_dims(tensors=tensors, boundary=3)\n",
+ "padded_tensors"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PDQQYCdLOkl1"
+ },
+ "source": [
+ "# Creating Buckets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RjGD3YKJWj58"
+ },
+ "source": [
+ "For training recurrent neural networks with a large vocabulary, a method called bucketing is usually applied.\n",
+ "\n",
+ "The usual technique of padding ensures that all occurrences within a mini-batch are of the same length. But this reduces inter-batch variability and intuitively puts similar sentences into the same batch, therefore reducing the overall robustness of the system.\n",
+ "\n",
+ "Thus, we use bucketing, where multiple buckets are created depending on the lengths of the sentences, and each occurrence is assigned to the bucket that corresponds to its length. We need to ensure that the buckets are large enough to add some variability to the system."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "17z3ASA-OrSF"
+ },
+ "source": [
+ "## `bucket_by_length`\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rf5trhANYpy5"
+ },
+ "source": [
+ "> ```\n",
+ "def bucket_by_length(generator, length_fn, boundaries, batch_sizes, strict_pad_on_len=False):\n",
+ "  buckets = [[] for _ in range(len(batch_sizes))]\n",
+ "  boundaries = boundaries + [math.inf]\n",
+ "  for example in generator:\n",
+ "    length = length_fn(example)\n",
+ "    bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b])\n",
+ "    buckets[bucket_idx].append(example)\n",
+ "    if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]:\n",
+ "      batched = zip(*buckets[bucket_idx])\n",
+ "      boundary = boundaries[bucket_idx]\n",
+ "      boundary = None if boundary == math.inf else boundary\n",
+ "      padded_batch = tuple(\n",
+ "          pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched)\n",
+ "      yield padded_batch\n",
+ "      buckets[bucket_idx] = []\n",
+ "\n",
+ "---\n",
+ "\n",
+ "This function can be summarised as:\n",
+ "\n",
+ "* Create one bucket per boundary given in the `boundaries` list (plus a final bucket for anything longer)\n",
+ "\n",
+ "* Assign each example to the first bucket whose boundary is at least the example's length; once a bucket collects its `batch_sizes` quota of examples, emit it as a batch\n",
+ "\n",
+ "* If padding is required, we use the `pad_to_max_dims` function\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### Parameters\n",
+ "\n",
+ "1. **generator:** The input generator function\n",
+ "2. **length_fn:** A custom length function for determining the length of examples, not necessarily `len()`\n",
+ "3. 
**boundaries:** A python list containing the corresponding bucket boundaries\n",
+ "4. **batch_sizes:** A python list containing batch sizes\n",
+ "5. **strict_pad_on_len:** A python boolean (`True` or `False`). If set to `True`, the function pads on the length dimension (dim[0]) so that it is strictly a multiple of `boundary`.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c0uQZaaPVyF_"
+ },
+ "source": [
+ "## `BucketByLength`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Qhh21q71aX3l"
+ },
+ "source": [
+ "> ```\n",
+ "def BucketByLength(boundaries, batch_sizes, length_keys=None, length_axis=0, strict_pad_on_len=False):\n",
+ "  length_keys = length_keys or [0, 1]\n",
+ "  length_fn = lambda x: _length_fn(x, length_axis, length_keys)\n",
+ "  return lambda g: bucket_by_length(g, length_fn, boundaries, batch_sizes, strict_pad_on_len)\n",
+ "\n",
+ "---\n",
+ "\n",
+ "This function is usually used inside input pipelines (*combinators*) and uses the aforementioned `bucket_by_length`. It applies a predefined `length_fn`, which chooses the maximum shape on `length_axis` over `length_keys`.\n",
+ "\n",
+ "Its use is illustrated below"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 153
+ },
+ "id": "PFeqDQNsV0PV",
+ "outputId": "ab9139c1-de56-4570-bcb6-731c1b475b12"
+ },
+ "outputs": [],
+ "source": [
+ "from trax.fastmath import numpy as jnp\n",
+ "\n",
+ "data_pipeline = preprocessing.Serial(\n",
+ "    dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n",
+ "    encoder.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n",
+ "    # Make sure that all elements are arrays or vectors\n",
+ "    lambda g: map(lambda x: tuple(jnp.asarray(elem) for elem in x), g),\n",
+ "    preprocessing.BucketByLength(boundaries=[32, 128, 512, 2048],\n",
+ "                                 batch_sizes=[512, 128, 32, 8, 1],\n",
+ "                                 length_keys=[0]),\n",
+ "    preprocessing.Log(only_shapes=True)\n",
+ ")\n",
+ "example = data_pipeline()\n",
+ "print(next(example))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9D0YdAT_ceSN"
+ },
+ "source": [
+ "# Filter by Length"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YLvi4Wu-eFAF"
+ },
+ "source": [
+ "> ```\n",
+ "def FilterByLength(max_length, length_keys=None, length_axis=0):\n",
+ "  length_keys = length_keys or [0, 1]\n",
+ "  length_fn = lambda x: _length_fn(x, length_axis, length_keys)\n",
+ "  def filtered(gen):\n",
+ "    for example in gen:\n",
+ "      if length_fn(example) <= max_length:\n",
+ "        yield example\n",
+ "  return filtered\n",
+ "\n",
+ "---\n",
+ "\n",
+ "This function uses the same predefined `length_fn` to include only those examples whose length is at most the given `max_length`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 153
+ },
+ "id": "qyueQ1z-cg2p",
+ "outputId": "da007ab0-e719-4044-e6a4-6bba5f43131e"
+ },
+ "outputs": [],
+ "source": [
+ "Filtered = preprocessing.Serial(\n",
+ "    dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n",
+ "    encoder.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n",
+ "    # Make sure that all elements are arrays or vectors\n",
+ "    lambda g: map(lambda x: tuple(jnp.asarray(elem) for elem in x), g),\n",
+ "    preprocessing.BucketByLength(boundaries=[32, 128, 512, 2048],\n",
+ "                                 batch_sizes=[512, 128, 32, 8, 1],\n",
+ "                                 length_keys=[0]),\n",
+ "    preprocessing.FilterByLength(max_length=2048, length_keys=[0]),\n",
+ "    preprocessing.Log(only_shapes=True)\n",
+ ")\n",
+ "filtered_example = Filtered()\n",
+ "print(next(filtered_example))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1XRrJSsUeZX-"
+ },
+ "source": [
+ "# Adding Loss Weights"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "P3ySYhnpejy4"
+ },
+ "source": [
+ "## `add_loss_weights`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QgaXAlhgeuQv"
+ },
+ "source": [
+ "> ```\n",
+ "def add_loss_weights(generator, id_to_mask=None):\n",
+ "  for example in generator:\n",
+ "    if len(example) > 3 or len(example) < 2:\n",
+ "      assert id_to_mask is None, 'Cannot automatically mask this stream.'\n",
+ "      yield example\n",
+ "    else:\n",
+ "      if len(example) == 2:\n",
+ "        weights = np.ones_like(example[1]).astype(np.float32)\n",
+ "      else:\n",
+ "        weights = example[2].astype(np.float32)\n",
+ "      mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32)\n",
+ "      weights *= mask\n",
+ "      yield (example[0], example[1], weights)\n",
+ "\n",
+ "---\n",
+ "\n",
+ "This function essentially adds a loss mask (a tensor of ones of the same shape as the targets) to the input stream.\n",
+ "\n",
+ "**Masking** is essentially a way to tell sequence-processing layers that certain timesteps in an input are missing, and thus should be skipped when processing the data.\n",
+ "\n",
+ "Thus, it adds 'weights' to the system.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### Parameters\n",
+ "\n",
+ "1. **generator:** The input data generator\n",
+ "2. **id_to_mask:** The value to mask out. Commonly used to mask the padding token in NLP."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hZPWc6a9hk_u"
+ },
+ "source": [
+ "```\n",
+ "\n",
+ "train_generator = trax.data.inputs.add_loss_weights(\n",
+ "    data_generator(batch_size, x_train, y_train, vocab['<PAD>'], True),\n",
+ "    id_to_mask=vocab['<PAD>'])\n",
+ "\n",
+ "\n",
+ "```\n",
+ "\n",
+ "For example, in this case I used the `add_loss_weights` function to mask the padding tokens while implementing Named Entity Recognition using the Reformer architecture. You can read more about the project [here](https://www.kaggle.com/sauravmaheshkar/trax-ner-using-reformer).",
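+ "\n",
+ "A minimal sketch of the masking behaviour (toy arrays, not from the original notebook; it assumes the lowercase `add_loss_weights` helper is exposed in the same module as `AddLossWeights`):\n",
+ "\n",
+ "> ```\n",
+ "stream = iter([(np.array([5, 6]), np.array([7, 0]))])  # one (x, y) example\n",
+ "weighted = preprocessing.add_loss_weights(stream, id_to_mask=0)\n",
+ "x, y, w = next(weighted)\n",
+ "# w == array([1., 0.], dtype=float32): the masked id contributes no loss."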
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GL31NErOgL3u"
+ },
+ "source": [
+ "## `AddLossWeights`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mBLf6iuXgPp2"
+ },
+ "source": [
+ "This function applies the aforementioned `add_loss_weights` to the data stream.\n",
+ "\n",
+ "> ```\n",
+ "def AddLossWeights(id_to_mask=None):\n",
+ "  return lambda g: add_loss_weights(g, id_to_mask=id_to_mask)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 173
+ },
+ "id": "Jwtt-k_2iHEy",
+ "outputId": "52295b0e-ff9c-415e-9ba6-1d5c1359b508"
+ },
+ "outputs": [],
+ "source": [
+ "data_pipeline = preprocessing.Serial(\n",
+ "    dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n",
+ "    encoder.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n",
+ "    # Make sure that all elements are arrays or vectors\n",
+ "    lambda g: map(lambda x: tuple(jnp.asarray(elem) for elem in x), g),\n",
+ "    preprocessing.Shuffle(),\n",
+ "    preprocessing.FilterByLength(max_length=2048, length_keys=[0]),\n",
+ "    preprocessing.BucketByLength(boundaries=[32, 128, 512, 2048],\n",
+ "                                 batch_sizes=[512, 128, 32, 8, 1],\n",
+ "                                 length_keys=[0]),\n",
+ "    preprocessing.AddLossWeights(),\n",
+ "    preprocessing.Log(only_shapes=True)\n",
+ ")\n",
+ "\n",
+ "example = data_pipeline()\n",
+ "print(next(example))"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "authorship_tag": "ABX9TyMN9H/craeNOTmFImALz3Uk",
+ "collapsed_sections": [],
+ "include_colab_link": true,
+ "name": "trax.data Explained",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/resources/examples/ipynb/Example-2-Layers-Introduction.ipynb b/resources/examples/ipynb/Example-2-Layers-Introduction.ipynb
new file mode 100644
index 000000000..cee98bb7f
--- /dev/null
+++ b/resources/examples/ipynb/Example-2-Layers-Introduction.ipynb
@@ -0,0 +1,1229 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7yuytuIllsv1"
+ },
+ "source": [
+ "# Trax Layers Intro\n",
+ "\n",
+ "This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:\n",
+ "\n",
+ "  1. **Layers**: the basic building blocks and how to combine them\n",
+ "  1. **Inputs and Outputs**: how data streams flow through layers\n",
+ "  1. **Defining New Layer Classes** (if combining existing layers isn't enough)\n",
+ "  1. **Testing and Debugging Layer Classes**\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BIl27504La0G"
+ },
+ "source": [
+ "**General Setup**\n",
+ "\n",
+ "Execute the following few cells (once) before running any of the code samples in this notebook."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "# Copyright 2018 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": { + "height": 51 + }, + "colab_type": "code", + "id": "vlGjGoGMTt-D", + "outputId": "76b95a37-3f1b-4748-bef0-646858f33e25" + }, + "outputs": [], + "source": [ + "# Import Trax\n", + "from trax import layers as tl\n", + "from trax.utils import shapes\n", + "from trax.fastmath import numpy as jnp # For use in defining new layer types." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bYWNWL9MJHv9" + }, + "outputs": [], + "source": [ + "# Settings and utilities for handling inputs, outputs, and object properties.\n", + "np.set_printoptions(precision=3) # Reduce visual noise from extra digits.\n", + "\n", + "\n", + "def show_layer_properties(layer_obj, layer_name):\n", + " template = ('{}.n_in: {}\\n'\n", + " '{}.n_out: {}\\n'\n", + " '{}.sublayers: {}\\n'\n", + " '{}.weights: {}\\n')\n", + " print(template.format(layer_name, layer_obj.n_in,\n", + " layer_name, layer_obj.n_out,\n", + " layer_name, layer_obj.sublayers,\n", + " layer_name, layer_obj.weights))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-LQ89rFFsEdk" + }, + "source": [ + "## 1. Layers\n", + "\n", + "The Layer class represents Trax's basic building blocks:\n", + "```\n", + "class Layer:\n", + " \"\"\"Base class for composable layers in a deep learning network.\n", + "\n", + " Layers are the basic building blocks for deep learning models. A Trax layer\n", + " computes a function from zero or more inputs to zero or more outputs,\n", + " optionally using trainable weights (common) and non-parameter state (not\n", + " common). ...\n", + "\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LyLVtdxorDPO" + }, + "source": [ + "### Layers compute functions." 
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "ntZ4_eNQldzL"
+ },
+ "source": [
+ "A layer computes a function from zero or more inputs to zero or more outputs.\n",
+ "The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.\n",
+ "\n",
+ "The simplest layers, those with no weights or sublayers, can be used without\n",
+ "initialization. You can think of them as (pure) mathematical functions that can\n",
+ "be plugged into neural networks.\n",
+ "\n",
+ "For ease of testing and interactive exploration, layer objects implement the\n",
+ "`__call__` method, so you can call them directly on input data:\n",
+ "```\n",
+ "y = my_layer(x)\n",
+ "```\n",
+ "\n",
+ "Layers are also objects, so you can inspect their properties. For example:\n",
+ "```\n",
+ "print(f'Number of inputs expected by this layer: {my_layer.n_in}')\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "hCoapc5le8B7"
+ },
+ "source": [
+ "**Example 1.** tl.Relu $[n_{in} = 1, n_{out} = 1]$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "height": 224
+ },
+ "colab_type": "code",
+ "id": "V09viOSEQvQe",
+ "outputId": "a0134cee-0db8-4396-825e-93e695a42ca5"
+ },
+ "outputs": [],
+ "source": [
+ "relu = tl.Relu()\n",
+ "\n",
+ "x = np.array([[-2, -1, 0, 1, 2],\n",
+ "              [-20, -10, 0, 10, 20]])\n",
+ "y = relu(x)\n",
+ "\n",
+ "# Show input, output, and two layer properties.\n",
+ "print(f'x:\\n{x}\\n\\n'\n",
+ "      f'relu(x):\\n{y}\\n\\n'\n",
+ "      f'Number of inputs expected by this layer: {relu.n_in}\\n'\n",
+ "      f'Number of outputs promised by this layer: {relu.n_out}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7sYxIT8crFVE"
+ },
+ "source": [
+ "**Example 2.** tl.Concatenate $[n_{in} = 2, n_{out} = 1]$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "height": 255
+ },
+ "colab_type": "code",
+ "id": "LMPPNWXLoOZI",
+ "outputId": "42f595b1-4014-429a-a0b3-2c12d630cd32"
+ },
+ "outputs": [],
+ "source": [
+ "concat = tl.Concatenate()\n",
+ "\n",
+ "x0 = np.array([[1, 2, 3],\n",
+ "               [4, 5, 6]])\n",
+ "x1 = np.array([[10, 20, 30],\n",
+ "               [40, 50, 60]])\n",
+ "y = concat([x0, x1])\n",
+ "\n",
+ "print(f'x0:\\n{x0}\\n\\n'\n",
+ "      f'x1:\\n{x1}\\n\\n'\n",
+ "      f'concat([x0, x1]):\\n{y}\\n\\n'\n",
+ "      f'Number of inputs expected by this layer: {concat.n_in}\\n'\n",
+ "      f'Number of outputs promised by this layer: {concat.n_out}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "z7N1qe91eYyM"
+ },
+ "source": [
+ "### Layers are configurable.\n",
+ "\n",
+ "Many layer types have creation-time parameters for flexibility. The\n",
+ "`Concatenate` layer type, for instance, has two optional parameters:\n",
+ "\n",
+ "* `axis`: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.\n",
+ "* `n_items`: number of tensors to join into one by concatenation; default value is 2.\n",
+ "\n",
+ "The following example shows `Concatenate` configured for **3** input tensors,\n",
+ "and concatenation along the initial $(0^{th})$ axis."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "l53Jw23pZ4s6"
+ },
+ "source": [
+ "**Example 3.** tl.Concatenate(n_items=3, axis=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "height": 340
+ },
+ "colab_type": "code",
+ "id": "bhhWlVLffZtf",
+ "outputId": "5a8afaa1-66c8-47fe-abcc-e7cfa33bb28c"
+ },
+ "outputs": [],
+ "source": [
+ "concat3 = tl.Concatenate(n_items=3, axis=0)\n",
+ "\n",
+ "x0 = np.array([[1, 2, 3],\n",
+ "               [4, 5, 6]])\n",
+ "x1 = np.array([[10, 20, 30],\n",
+ "               [40, 50, 60]])\n",
+ "x2 = np.array([[100, 200, 300],\n",
+ "               [400, 500, 600]])\n",
+ "\n",
+ "y = concat3([x0, x1, x2])\n",
+ "\n",
+ "print(f'x0:\\n{x0}\\n\\n'\n",
+ "      f'x1:\\n{x1}\\n\\n'\n",
+ "      f'x2:\\n{x2}\\n\\n'\n",
+ "      f'concat3([x0, x1, x2]):\\n{y}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "1oZv3R8bRMvF"
+ },
+ "source": [
+ "### Layers are trainable.\n",
+ "\n",
+ "Many layer types include weights that affect the computation of outputs from\n",
+ "inputs, and they use back-propagated gradients to update those weights.\n",
+ "\n",
+ "🚧🚧 *A very small subset of layer types, such as `BatchNorm`, also include\n",
+ "modifiable weights (called `state`) that are updated based on forward-pass\n",
+ "inputs/computation rather than back-propagated gradients.*"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "3d64M7wLryji"
+ },
+ "source": [
+ "**Initialization**\n",
+ "\n",
+ "Trainable layers must be initialized before use. Trax can take care of this\n",
+ "as part of the overall training process. In other settings (e.g., in tests or\n",
+ "interactively in a Colab notebook), you need to initialize the\n",
+ "*outermost/topmost* layer explicitly. For this, use `init`:\n",
+ "\n",
+ "```\n",
+ "  def init(self, input_signature, rng=None, use_cache=False):\n",
+ "    \"\"\"Initializes weights/state of this layer and its sublayers recursively.\n",
+ "\n",
+ "    Initialization creates layer weights and state, for layers that use them.\n",
+ "    It derives the necessary array shapes and data types from the layer's input\n",
+ "    signature, which is itself just shape and data type information.\n",
+ "\n",
+ "    For layers without weights or state, this method safely does nothing.\n",
+ "\n",
+ "    This method is designed to create weights/state only once for each layer\n",
+ "    instance, even if the same layer instance occurs in multiple places in the\n",
+ "    network. This enables weight sharing to be implemented as layer sharing.\n",
+ "\n",
+ "    Args:\n",
+ "      input_signature: `ShapeDtype` instance (if this layer takes one input)\n",
+ "          or list/tuple of `ShapeDtype` instances.\n",
+ "      rng: Single-use random number generator (JAX PRNG key), or `None`;\n",
+ "          if `None`, use a default computed from an integer 0 seed.\n",
+ "      use_cache: If `True`, and if this layer instance has already been\n",
+ "          initialized elsewhere in the network, then return special marker\n",
+ "          values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.\n",
+ "          Else return this layer's newly initialized weights and state.\n",
+ "\n",
+ "    Returns:\n",
+ "      A `(weights, state)` tuple.\n",
+ "    \"\"\"\n",
+ "```\n",
+ "\n",
+ "Input signatures can be built from scratch using `ShapeDtype` objects, or can\n",
+ "be derived from data via the `signature` function (in module `shapes`):\n",
+ "```\n",
+ "def signature(obj):\n",
+ "  \"\"\"Returns a `ShapeDtype` signature for the given `obj`.\n",
+ "\n",
+ "  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`\n",
+ "  instances. Note that this function is permissive with respect to its inputs\n",
+ "  (accepts lists or tuples or dicts, and underlying objects can be any type\n",
+ "  as long as they have shape and dtype attributes) and returns the corresponding\n",
+ "  nested structure of `ShapeDtype`.\n",
+ "\n",
+ "  Args:\n",
+ "    obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict\n",
+ "        of such objects.\n",
+ "\n",
+ "  Returns:\n",
+ "    A corresponding nested structure of `ShapeDtype` instances.\n",
+ "  \"\"\"\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "yL8HAj6GEAp1"
+ },
+ "source": [
+ "**Example 4.** tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "height": 221
+ },
+ "colab_type": "code",
+ "id": "Ie7iyX91qAx2",
+ "outputId": "0efecdf5-c0a4-4304-f442-d12fc1a51253"
+ },
+ "outputs": [],
+ "source": [
+ "layer_norm = tl.LayerNorm()\n",
+ "\n",
+ "x = np.array([[-2, -1, 0, 1, 2],\n",
+ "              [1, 2, 3, 4, 5],\n",
+ "              [10, 20, 30, 40, 50]]).astype(np.float32)\n",
+ "layer_norm.init(shapes.signature(x))\n",
+ "\n",
+ "y = layer_norm(x)\n",
+ "\n",
+ "print(f'x:\\n{x}\\n\\n'\n",
+ "      f'layer_norm(x):\\n{y}\\n')\n",
+ "print(f'layer_norm.weights:\\n{layer_norm.weights}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "d47gVdGV1vWw"
+ },
+ "source": [
+ "### Layers combine into layers.\n",
+ "\n",
+ "The Trax library authors encourage users to build networks and network\n",
+ "components as combinations of existing layers, by means of a small set of\n",
+ "_combinator_ layers. A combinator makes a list of layers behave as a single\n",
+ "layer -- by combining the sublayer computations yet looking from the outside\n",
+ "like any other layer. The combined layer, like other layers, can:\n",
+ "\n",
+ "* compute outputs from inputs,\n",
+ "* update parameters from gradients, and\n",
+ "* combine with yet more layers."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vC1ymG2j0iyp" + }, + "source": [ + "**Combine with `Serial`**\n", + "\n", + "The most common way to combine layers is with the `Serial` combinator:\n", + "```\n", + "class Serial(base.Layer):\n", + " \"\"\"Combinator that applies layers serially (by function composition).\n", + "\n", + " This combinator is commonly used to construct deep networks, e.g., like this::\n", + "\n", + " mlp = tl.Serial(\n", + " tl.Dense(128),\n", + " tl.Relu(),\n", + " tl.Dense(10),\n", + " )\n", + "\n", + " A Serial combinator uses stack semantics to manage data for its sublayers.\n", + " Each sublayer sees only the inputs it needs and returns only the outputs it\n", + " has generated. The sublayers interact via the data stack. For instance, a\n", + " sublayer k, following sublayer j, gets called with the data stack in the\n", + " state left after layer j has applied. The Serial combinator then:\n", + "\n", + " - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n", + " layer k, passing those items as arguments; and\n", + "\n", + " - takes layer k's n_out return values (n_out = k.n_out) and pushes\n", + " them onto the data stack.\n", + "\n", + " A Serial instance with no sublayers acts as a special-case (but useful)\n", + " 1-input 1-output no-op.\n", + " \"\"\"\n", + "```\n", + "If one layer has the same number of outputs as the next layer has inputs (which\n", + "is the usual case), the successive layers behave like function composition:\n", + "\n", + "```\n", + "# h(.) = g(f(.))\n", + "layer_h = Serial(\n", + " layer_f,\n", + " layer_g,\n", + ")\n", + "```\n", + "Note how, inside `Serial`, function composition is expressed naturally as a\n", + "succession of operations, so that no nested parentheses are needed.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uPOnrDa9ViPi" + }, + "source": [ + "**Example 5.** y = layer_norm(relu(x)) $[n_{in} = 1, n_{out} = 1]$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 136 + }, + "colab_type": "code", + "id": "dW5fpusjvjmh", + "outputId": "acdcffe7-23d5-4ecd-df9b-32f48ae77959" + }, + "outputs": [], + "source": [ + "layer_block = tl.Serial(\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + ")\n", + "\n", + "x = np.array([[-2, -1, 0, 1, 2],\n", + " [-20, -10, 0, 10, 20]]).astype(np.float32)\n", + "layer_block.init(shapes.signature(x))\n", + "y = layer_block(x)\n", + "\n", + "print(f'x:\\n{x}\\n\\n'\n", + " f'layer_block(x):\\n{y}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "bRtmN6ckQO1q" + }, + "source": [ + "And we can inspect the block as a whole, as if it were just another layer:\n", + "\n", + "**Example 5'.** Inspecting a `Serial` layer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 68 + }, + "colab_type": "code", + "id": "D6BpYddZQ1eu", + "outputId": "1a00c9f2-63a0-450c-d902-c9baf06dc917" + }, + "outputs": [], + "source": [ + "print(f'layer_block: {layer_block}\\n\\n'\n", + " f'layer_block.weights: {layer_block.weights}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kJ8bpYZtE66x" + }, + "source": [ + "**Combine with `Branch`**\n", + "\n", + "The `Branch` combinator arranges layers into parallel computational channels:\n", + "```\n", + "def Branch(*layers, name='Branch'):\n", + " \"\"\"Combinator that applies a list of layers in parallel to copies of inputs.\n", + "\n", + " Each layer in the input list is applied to as many inputs from the stack\n", + " as it needs, and their outputs are successively combined on stack.\n", + "\n", + " For example, suppose one has three layers:\n", + "\n", + " - F: 1 input, 1 output\n", + " - G: 3 inputs, 1 output\n", + " - H: 2 inputs, 2 outputs (h1, h2)\n", + "\n", + " Then Branch(F, G, H) will take 3 inputs and give 4 outputs:\n", + "\n", + " - inputs: a, b, c\n", + " - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)\n", + "\n", + " As an important special case, a None argument to Branch acts as if it takes\n", + " one argument, which it leaves unchanged. (It acts as a one-arg no-op.)\n", + "\n", + " Args:\n", + " *layers: List of layers.\n", + " name: Descriptive name for this layer.\n", + "\n", + " Returns:\n", + " A branch layer built from the given sublayers.\n", + " \"\"\"\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RlPcnRtdIVgq" + }, + "source": [ + "Residual blocks, for example, are implemented using `Branch`:\n", + "```\n", + "def Residual(*layers, shortcut=None):\n", + " \"\"\"Wraps a series of layers with a residual connection.\n", + "\n", + " Args:\n", + " *layers: One or more layers, to be applied in series.\n", + " shortcut: If None (the usual case), the Residual layer computes the\n", + " element-wise sum of the stack-top input with the output of the layer\n", + " series. If specified, the `shortcut` layer applies to a copy of the\n", + " inputs and (elementwise) adds its output to the output from the main\n", + " layer series.\n", + "\n", + " Returns:\n", + " A layer representing a residual connection paired with a layer series.\n", + " \"\"\"\n", + " layers = _ensure_flat(layers)\n", + " layer = layers[0] if len(layers) == 1 else Serial(layers)\n", + " return Serial(\n", + " Branch(shortcut, layer),\n", + " Add(),\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ruX4aFMdUOwS" + }, + "source": [ + "Here's a simple code example to highlight the mechanics." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JGGnKjg4ESIg" + }, + "source": [ + "**Example 6.** `Branch`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 204 + }, + "colab_type": "code", + "id": "lw6A2YwuW-Ul", + "outputId": "a07ef350-bafa-4fa7-a083-19e6f725b3ce" + }, + "outputs": [], + "source": [ + "relu = tl.Relu()\n", + "times_100 = tl.Fn(\"Times100\", lambda x: x * 100.0)\n", + "branch_relu_t100 = tl.Branch(relu, times_100)\n", + "\n", + "x = np.array([[-2, -1, 0, 1, 2],\n", + " [-20, -10, 0, 10, 20]])\n", + "branch_relu_t100.init(shapes.signature(x))\n", + "\n", + "y0, y1 = branch_relu_t100(x)\n", + "\n", + "print(f'x:\\n{x}\\n\\n'\n", + " f'y0:\\n{y0}\\n\\n'\n", + " f'y1:\\n{y1}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zr2ZZ1vO8T8V" + }, + "source": [ + "## 2. Inputs and Outputs\n", + "\n", + "Trax allows layers to have multiple input streams and output streams. When\n", + "designing a network, you have the flexibility to use layers that:\n", + "\n", + " - process a single data stream ($n_{in} = n_{out} = 1$),\n", + " - process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),\n", + " - split or inject data streams ($n_{in} < n_{out}$), or\n", + " - merge or remove data streams ($n_{in} > n_{out}$).\n", + "\n", + "We saw in section 1 the example of `Residual`, which involves both a split and a merge:\n", + "```\n", + " ...\n", + " return Serial(\n", + " Branch(shortcut, layer),\n", + " Add(),\n", + " )\n", + "```\n", + "In other words, layer by layer:\n", + "\n", + " - `Branch(shortcut, layers)`: makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [$n_{in} = 1$, $n_{out} = 2$]\n", + " - `Add()`: combines the two streams back into one by adding two tensors elementwise. [$n_{in} = 2$, $n_{out} = 1$]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1FEttSCVVM3T" + }, + "source": [ + "### Data Stack\n", + "\n", + "Trax supports flexible data flows through a network via a data stack, which is\n", + "managed by the `Serial` combinator:\n", + "```\n", + "class Serial(base.Layer):\n", + " \"\"\"Combinator that applies layers serially (by function composition).\n", + "\n", + " ...\n", + "\n", + " A Serial combinator uses stack semantics to manage data for its sublayers.\n", + " Each sublayer sees only the inputs it needs and returns only the outputs it\n", + " has generated. The sublayers interact via the data stack. For instance, a\n", + " sublayer k, following sublayer j, gets called with the data stack in the\n", + " state left after layer j has applied. The Serial combinator then:\n", + "\n", + " - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n", + " layer k, passing those items as arguments; and\n", + "\n", + " - takes layer k's n_out return values (n_out = k.n_out) and pushes\n", + " them onto the data stack.\n", + "\n", + " ...\n", + "\n", + " \"\"\"\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "5DAiajI-Gzk4" + }, + "source": [ + "**Simple Case 1 -- Each layer takes one input and has one output.**\n", + "\n", + "This is in effect a single data stream pipeline, and the successive layers\n", + "behave like function composition:\n", + "\n", + "```\n", + "# s(.) 
= h(g(f(.)))\n", + "layer_s = Serial(\n", + " layer_f,\n", + " layer_g,\n", + " layer_h,\n", + ")\n", + "```\n", + "Note how, inside `Serial`, function composition is expressed naturally as a\n", + "succession of operations, so that no nested parentheses are needed and the\n", + "order of operations matches the textual order of layers.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WR8bh64tIzIY" + }, + "source": [ + "**Simple Case 2 -- Each layer consumes all outputs of the preceding layer.**\n", + "\n", + "This is still a single pipeline, but data streams internal to it can split and\n", + "merge. The `Residual` example above illustrates this kind.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ACG88RdtLbvG" + }, + "source": [ + "**General Case -- Successive layers interact via the data stack.**\n", + "\n", + "As described in the `Serial` class docstring, each layer gets its inputs from\n", + "the data stack after the preceding layer has put its outputs onto the stack.\n", + "This covers the simple cases above, but also allows for more flexible data\n", + "interactions between non-adjacent layers. The following example is schematic:\n", + "```\n", + "x, y_target = get_batch_of_labeled_data()\n", + "\n", + "model_plus_eval = Serial(\n", + " my_fancy_deep_model(), # Takes one arg (x) and has one output (y_hat)\n", + " my_eval(), # Takes two args (y_hat, y_target) and has one output (score)\n", + ")\n", + "\n", + "eval_score = model_plus_eval((x, y_target))\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "66hUOOYRQqej" + }, + "source": [ + "Here is the corresponding progression of stack states:\n", + "\n", + "0. At start: _--empty--_\n", + "0. After `get_batch_of_labeled_data()`: *x*, *y_target*\n", + "0. After `my_fancy_deep_model()`: *y_hat*, *y_target*\n", + "0. After `my_eval()`: *score*\n", + "\n", + "Note in particular how the application of the model (between stack states 1\n", + "and 2) only uses and affects the top element on the stack: `x` --> `y_hat`.\n", + "The rest of the data stack (`y_target`) comes in use only later, for the\n", + "eval function." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "65ite-671cTT" + }, + "source": [ + "## 3. Defining New Layer Classes\n", + "\n", + "If you need a layer type that is not easily defined as a combination of\n", + "existing layer types, you can define your own layer classes in a couple\n", + "different ways." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hHSaD9H6hDTf" + }, + "source": [ + "### With the `Fn` layer-creating function.\n", + "\n", + "Many layer types needed in deep learning compute pure functions from inputs to\n", + "outputs, using neither weights nor randomness. You can use Trax's `Fn` function\n", + "to define your own pure layer types:\n", + "```\n", + "def Fn(name, f, n_out=1): # pylint: disable=invalid-name\n", + " \"\"\"Returns a layer with no weights that applies the function `f`.\n", + "\n", + " `f` can take and return any number of arguments, and takes only positional\n", + " arguments -- no default or keyword arguments. 
It often uses JAX-numpy (`jnp`).\n",
+ "  The following, for example, would create a layer that takes two inputs and\n",
+ "  returns two outputs -- element-wise sums and maxima:\n",
+ "\n",
+ "      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`\n",
+ "\n",
+ "  The layer's number of inputs (`n_in`) is automatically set to number of\n",
+ "  positional arguments in `f`, but you must explicitly set the number of\n",
+ "  outputs (`n_out`) whenever it's not the default value 1.\n",
+ "\n",
+ "  Args:\n",
+ "    name: Class-like name for the resulting layer; for use in debugging.\n",
+ "    f: Pure function from input tensors to output tensors, where each input\n",
+ "        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.\n",
+ "        Output tensors must be packaged as specified in the `Layer` class\n",
+ "        docstring.\n",
+ "    n_out: Number of outputs promised by the layer; default value 1.\n",
+ "\n",
+ "  Returns:\n",
+ "    Layer executing the function `f`.\n",
+ "  \"\"\"\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "TX30lGLXcjB1"
+ },
+ "source": [
+ "**Example 7.** Use `Fn` to define a new layer type:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "height": 153
+ },
+ "colab_type": "code",
+ "id": "vKrc6XMV9ErS",
+ "outputId": "13f74094-e43e-4267-9055-f3d55d58ae53"
+ },
+ "outputs": [],
+ "source": [
+ "# Define new layer type.\n",
+ "def Gcd():\n",
+ "  \"\"\"Returns a layer to compute the greatest common divisor, elementwise.\"\"\"\n",
+ "  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))\n",
+ "\n",
+ "\n",
+ "# Use it.\n",
+ "gcd = Gcd()\n",
+ "\n",
+ "x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n",
+ "x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])\n",
+ "\n",
+ "y = gcd((x0, x1))\n",
+ "\n",
+ "print(f'x0:\\n{x0}\\n\\n'\n",
+ "      f'x1:\\n{x1}\\n\\n'\n",
+ "      f'gcd((x0, x1)):\\n{y}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "W74Eehgp5A57"
+ },
+ "source": [
+ "The `Fn` function infers `n_in` (number of inputs) as the length of `f`'s arg\n",
+ "list. `Fn` does not infer `n_out` (number of outputs) though. If your `f` has\n",
+ "more than one output, you need to give an explicit value using the `n_out`\n",
+ "keyword arg.",
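+ "\n",
+ "A quick check (a toy layer for illustration, not from the original notebook) confirming the inferred `n_in`:\n",
+ "\n",
+ "```\n",
+ "add2 = tl.Fn('Add2', lambda a, b: a + b)  # f has 2 positional args\n",
+ "print(add2.n_in, add2.n_out)  # 2 1\n",
+ "```"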
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2lCjml7SCR-u" + }, + "source": [ + "**Example 8.** `Fn` with multiple outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 204 + }, + "colab_type": "code", + "id": "rfnA2B9ZczWK", + "outputId": "9ffd7648-ffda-453e-b88b-4aa4ba8ea482" + }, + "outputs": [], + "source": [ + "# Define new layer type.\n", + "def SumAndMax():\n", + " \"\"\"Returns a layer to compute sums and maxima of two input tensors.\"\"\"\n", + " return tl.Fn('SumAndMax',\n", + " lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),\n", + " n_out=2)\n", + "\n", + "\n", + "# Use it.\n", + "sum_and_max = SumAndMax()\n", + "\n", + "x0 = np.array([1, 2, 3, 4, 5])\n", + "x1 = np.array([10, -20, 30, -40, 50])\n", + "\n", + "y0, y1 = sum_and_max([x0, x1])\n", + "\n", + "print(f'x0:\\n{x0}\\n\\n'\n", + " f'x1:\\n{x1}\\n\\n'\n", + " f'y0:\\n{y0}\\n\\n'\n", + " f'y1:\\n{y1}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GrXQUSbKDs41" + }, + "source": [ + "**Example 9.** Use `Fn` to define a configurable layer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 374 + }, + "colab_type": "code", + "id": "h1KwpmFpEIK3", + "outputId": "9f6e7009-04a0-46c9-b005-35c091f720eb" + }, + "outputs": [], + "source": [ + "# Function defined in trax/layers/core.py:\n", + "def Flatten(n_axes_to_keep=1):\n", + " \"\"\"Returns a layer that combines one or more trailing axes of a tensor.\n", + "\n", + " Flattening keeps all the values of the input tensor, but reshapes it by\n", + " collapsing one or more trailing axes into a single axis. For example, a\n", + " `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape\n", + " `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.\n", + "\n", + " Args:\n", + " n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;\n", + " collapse only the axes after these.\n", + " \"\"\"\n", + " layer_name = f'Flatten_keep{n_axes_to_keep}'\n", + "\n", + " def f(x):\n", + " in_rank = len(x.shape)\n", + " if in_rank <= n_axes_to_keep:\n", + " raise ValueError(f'Input rank ({in_rank}) must exceed the number of '\n", + " f'axes to keep ({n_axes_to_keep}) after flattening.')\n", + " return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))\n", + "\n", + " return tl.Fn(layer_name, f)\n", + "\n", + "\n", + "flatten_keep_1_axis = Flatten(n_axes_to_keep=1)\n", + "flatten_keep_2_axes = Flatten(n_axes_to_keep=2)\n", + "\n", + "x = np.array([[[1, 2, 3],\n", + " [10, 20, 30],\n", + " [100, 200, 300]],\n", + " [[4, 5, 6],\n", + " [40, 50, 60],\n", + " [400, 500, 600]]])\n", + "\n", + "y1 = flatten_keep_1_axis(x)\n", + "y2 = flatten_keep_2_axes(x)\n", + "\n", + "print(f'x:\\n{x}\\n\\n'\n", + " f'flatten_keep_1_axis(x):\\n{y1}\\n\\n'\n", + " f'flatten_keep_2_axes(x):\\n{y2}')\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "cqM6WJwNhoHI" + }, + "source": [ + "### By defining a `Layer` subclass\n", + "\n", + "If you need a layer type that uses trainable weights (or state), you can extend\n", + "the base `Layer` class:\n", + "```\n", + "class Layer:\n", + " \"\"\"Base class for composable layers in a deep learning network.\n", + "\n", + " ...\n", + "\n", + " Authors of new layer subclasses typically override at most two methods of\n", + " the base `Layer` class:\n", + "\n", + " `forward(inputs)`:\n", + " Computes this 
layer's output as part of a forward pass through the model.\n", + "\n", + " `init_weights_and_state(self, input_signature)`:\n", + " Initializes weights and state for inputs with the given signature.\n", + "\n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tZlzxNUigD_4" + }, + "source": [ + "The `forward` method uses *weights stored in the layer object* (`self.weights`)\n", + "to compute outputs from inputs. For example, here is the definition of\n", + "`forward` for Trax's `Dense` layer:\n", + "```\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input, except the final dimension\n", + " is the layer's `n_units` value.\n", + " \"\"\"\n", + " if self._use_bias:\n", + " if not isinstance(self.weights, (tuple, list)):\n", + " raise ValueError(f'Weights should be a (w, b) tuple or list; '\n", + " f'instead got: {self.weights}')\n", + " w, b = self.weights\n", + " return jnp.dot(x, w) + b # Affine map.\n", + " else:\n", + " w = self.weights\n", + " return jnp.dot(x, w) # Linear map.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PJEEyX9_iPbk" + }, + "source": [ + "Layer weights must be initialized before the layer can be used; the\n", + "`init_weights_and_state` method specifies how. Continuing the `Dense` example,\n", + "here is the corresponding initialization code:\n", + "```\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\n", + "\n", + " Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the\n", + " default case), or a `w` tensor for layers created with `use_bias=False`.\n", + "\n", + " Args:\n", + " input_signature: `ShapeDtype` instance characterizing the input this layer\n", + " should compute on.\n", + " \"\"\"\n", + " shape_w = (input_signature.shape[-1], self._n_units)\n", + " shape_b = (self._n_units,)\n", + " rng_w, rng_b = fastmath.random.split(self.rng, 2)\n", + " w = self._kernel_initializer(shape_w, rng_w)\n", + "\n", + " if self._use_bias:\n", + " b = self._bias_initializer(shape_b, rng_b)\n", + " self.weights = (w, b)\n", + " else:\n", + " self.weights = w\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "D77mYZZD41QO" + }, + "source": [ + "### By defining a `Combinator` subclass\n", + "\n", + "*TBD*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PgdQvZ5G6Aei" + }, + "source": [ + "## 4. 
Testing and Debugging Layer Classes\n",
+ "\n",
+ "*TBD*"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "last_runtime": {
+ "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
+ "kind": "private"
+ },
+ "name": "Trax Layers Intro",
+ "provenance": [
+ {
+ "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
+ "timestamp": 1569980697572
+ },
+ {
+ "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
+ "timestamp": 1563927451951
+ }
+ ]
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/resources/examples/ipynb/Example-4-Early-Stopping.ipynb b/resources/examples/ipynb/Example-4-Early-Stopping.ipynb
new file mode 100644
index 000000000..f85dd834e
--- /dev/null
+++ b/resources/examples/ipynb/Example-4-Early-Stopping.ipynb
@@ -0,0 +1,613 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6NWA5uxOmBVz"
+ },
+ "outputs": [],
+ "source": [
+ "#@title\n",
+ "# Copyright 2020 Google LLC.\n",
+ "\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OLUMD0tPP6Hd"
+ },
+ "outputs": [],
+ "source": [
+ "import collections\n",
+ "import functools\n",
+ "import os\n",
+ "import sys\n",
+ "import time\n",
+ "\n",
+ "import numpy as np\n",
+ "import psutil\n",
+ "from absl import logging"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For example, if trax is inside a 'src' directory\n",
+ "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n",
+ "sys.path.insert(0, project_root)\n",
+ "\n",
+ "# Option to verify the import path\n",
+ "print(f\"Python will look for packages in: {sys.path[0]}\")\n",
+ "\n",
+ "# Import trax\n",
+ "import trax\n",
+ "from trax import fastmath\n",
+ "from trax import layers as tl\n",
+ "from trax.fastmath import numpy as jnp\n",
+ "from trax.learning.supervised import training\n",
+ "\n",
+ "# Verify the source of the imported package\n",
+ "print(f\"Imported trax from: {trax.__file__}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nG4CK5NsP6He"
+ },
+ "outputs": [],
+ "source": [
+ "class MyLoop(training.Loop):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        *args, **kwargs\n",
+ "    ):\n",
+ "        super().__init__(\n",
+ "            *args, **kwargs\n",
+ "        )\n",
+ "        self._stop_training = False\n",
+ "\n",
+ "    def run(self, n_steps=1):\n",
+ "        \"\"\"Extends ``training.Loop.run`` with logic that breaks the loop\n",
+ "        when the early stopping condition is satisfied.\n",
+ "        \"\"\"\n",
+ "\n",
+ "        with self._open_summary_writers() as (\n",
+ "            train_summary_writers,\n",
+ "            eval_summary_writers,\n",
+ "        ):\n",
+ "            process = psutil.Process(os.getpid())\n",
+ "            loss_acc, step_acc = 0.0, 0\n",
+ "            start_time = time.time()\n",
+ "            optimizer_metrics_acc = collections.defaultdict(float)\n",
+ "            for i in range(n_steps):\n",
+ "                prev_task_index = self._which_task(self._step)\n",
+ "                self._step += 1\n",
+ "                task_index = self._which_task(self._step)\n",
+ "                task_changed = task_index != prev_task_index\n",
+ "\n",
+ "                if task_changed:\n",
+ "                    loss_acc, step_acc = 0.0, 0\n",
+ "\n",
+ "                loss, optimizer_metrics = self._run_one_step(task_index, task_changed)\n",
+ "\n",
+ "                optimizer_metrics, loss = fastmath.nested_map(\n",
+ "                    functools.partial(tl.mean, self._n_devices),\n",
+ "                    (optimizer_metrics, loss),\n",
+ "                )\n",
+ "\n",
+ "                loss_acc += loss\n",
+ "                # Log loss every 50 steps, every step in memory-efficient trainers.\n",
+ "                if self._step % 50 == 0 or self._use_memory_efficient_trainer:\n",
+ "                    self._log_step(\"Loss: %.4f\" % loss, stdout=False)\n",
+ "                step_acc += 1\n",
+ "                for metric_name, value in optimizer_metrics.items():\n",
+ "                    optimizer_metrics_acc[metric_name] += value\n",
+ "\n",
+ "                if self._checkpoint_at(self.step):\n",
+ "                    self.save_checkpoint(\"model\")\n",
+ "                if self._permanent_checkpoint_at(self.step):\n",
+ "                    self.save_checkpoint(f\"model_{self.step}\")\n",
+ "                if self._eval_at(self.step):\n",
+ "                    logging.info(\n",
+ "                        \"cpu memory use (MB): %.2f\",\n",
+ "                        process.memory_info().rss / float(1024 * 1024),\n",
+ "                    )\n",
+ "                    elapsed_time = time.time() - start_time\n",
+ "                    self._log_training_progress(\n",
+ "                        task=self._tasks[task_index],\n",
+ "                        total_loss=loss_acc,\n",
+ "                        n_steps=step_acc,\n",
+ "                        elapsed_time=elapsed_time,\n",
+ "                        optimizer_metrics=optimizer_metrics_acc,\n",
+ "                        summary_writer=train_summary_writers[task_index],\n",
+ "                    )\n",
+ "                    self.run_evals(eval_summary_writers)\n",
+ "                    loss_acc, step_acc = 0.0, 0\n",
+ "                    start_time = time.time()\n",
+ "                    optimizer_metrics_acc = collections.defaultdict(float)\n",
+ "\n",
+ "                if self._checkpoint_at(self.step):\n",
+ "                    if self._checkpoint_low_metric is not None and self._at_lowest():\n",
+ "                        self.save_checkpoint(f\"lowest_{self._checkpoint_low_metric}\")\n",
+ "                    if self._checkpoint_high_metric is not None and self._at_highest():\n",
+ "                        self.save_checkpoint(f\"highest_{self._checkpoint_high_metric}\")\n",
+ "\n",
+ "                for callback in self._callbacks:\n",
+ "                    if callback.call_at(self.step):\n",
+ "                        if callback.__class__.__name__ == 'EarlyStopping':\n",
+ "                            # Added to check the EarlyStopping callback after\n",
+ "                            # history has been updated; callback.on_step_end\n",
+ "                            # executes before the history update.\n",
+ "                            best_step = callback.on_step_begin_with_history(self.step)\n",
+ "\n",
+ "                if not self._stop_training and self.step == n_steps:\n",
+ "                    self._log_step(\"Did not meet early stopping condition.\")\n",
+ "\n",
+ "                if self._stop_training:\n",
+ "                    # Added to stop the training.\n",
+ "                    self._log_step(f\"Early stopping... the best step was {best_step}.\")\n",
+ "                    break\n",
+ "\n",
+ "        self._eval_model.weights = self._model.weights"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rfncVhM7P6Hg"
+ },
+ "outputs": [],
+ "source": [
+ "def callback_earlystopper(\n",
+ "    monitor=None,\n",
+ "    min_delta=0,\n",
+ "    patience=0,\n",
+ "    mode=\"auto\",\n",
+ "    restore_best_checkpoint=True\n",
+ "):\n",
+ "    \"\"\"Wrap the EarlyStopping class into a callable.\n",
+ "\n",
+ "    Returns an early stopping callback class.\n",
+ "\n",
+ "    Args:\n",
+ "      monitor: Quantity to be monitored.\n",
+ "\n",
+ "      min_delta: Minimum change in the monitored quantity\n",
+ "        to qualify as an improvement, i.e. 
an absolute\n", + " change of less than ``min_delta`` will count as no\n", + " improvement.\n", + "\n", + " patience: ``patience`` times ``n_steps_per_checkpoint`` will be\n", + " the total number of steps without improvement\n", + " after which training will be stopped.\n", + "\n", + " mode: One of ``{\"auto\", \"min\", \"max\"}``. In ``min`` (``max``) mode,\n", + " training will stop when the monitored quantity has stopped\n", + " decreasing (increasing) for the number of steps assigned\n", + " in ``patience``; in ``\"auto\"``\n", + " mode, the direction is automatically inferred\n", + " from the name of the monitored quantity.\n", + "\n", + " restore_best_checkpoint: Whether to restore the model from\n", + " the checkpoint with the best value of the monitored quantity.\n", + " If False, the model weights obtained at the last step of\n", + " training are used. If True and training stops early,\n", + " the best checkpoint will be restored.\n", + " \"\"\"\n", + "\n", + " if mode not in [\"auto\", \"max\", \"min\"]:\n", + " logging.warning(\n", + " f\"Early stopping mode='{mode}' is unknown, \" \"falling back to 'auto' mode.\"\n", + " )\n", + " mode = \"auto\"\n", + "\n", + " class EarlyStopping:\n", + " \"\"\"Callback that activates early stopping.\n", + "\n", + " Monitors an evaluation metric and stops training when it no longer improves.\n", + " \"\"\"\n", + "\n", + " def __init__(self, loop):\n", + " \"\"\"Configures early stopping.\n", + " This is inspired by keras.callbacks.EarlyStopping.\n", + "\n", + " Args:\n", + " loop: training ``Loop`` from the current training.\n", + "\n", + " \"\"\"\n", + "\n", + " self._loop = loop\n", + " self.monitor = monitor\n", + " self.min_delta = jnp.abs(min_delta)\n", + " self.patience = jnp.maximum(patience, 1)\n", + "\n", + " self.restore_best_checkpoint = restore_best_checkpoint\n", + "\n", + " if mode == \"min\":\n", + " self.monitor_op = jnp.less\n", + " elif mode == \"max\":\n", + " self.monitor_op = jnp.greater\n", + " else:\n", + " if self.monitor.endswith(\"Accuracy\"):\n", + " self.monitor_op = jnp.greater\n", + " else:\n", + " self.monitor_op = jnp.less\n", + "\n", + " # Flip the sign of min_delta in \"min\" mode so that\n", + " # _is_improvement can use a single comparison.\n", + " if self.monitor_op == jnp.greater:\n", + " self.min_delta *= 1\n", + " else:\n", + " self.min_delta *= -1\n", + "\n", + " self.wait = 0\n", + " self.stopped_step = 1\n", + " self.best = jnp.inf if self.monitor_op == jnp.less else -jnp.inf\n", + " self.best_step = 1\n", + " self.best_checkpoint_path = None\n", + "\n", + " def _eval_metric_names(self):\n", + " return [\n", + " name\n", + " for eval_task in self._loop._eval_tasks\n", + " for name in eval_task.metric_names\n", + " ]\n", + "\n", + " def _metric_exists(self):\n", + " return self.monitor in self._eval_metric_names()\n", + "\n", + " def call_at(self, step):\n", + " return self._loop._eval_at(step)\n", + "\n", + " def on_step_begin(self, step):\n", + " if not self._metric_exists():\n", + " # Raise an error if the monitored name is not in any evaluation task.\n", + " self._loop._log_step(\n", + " f\"Early stopping metric '{self.monitor}' \" \"is not in eval_tasks.\"\n", + " )\n", + " self._loop._log_step(\n", + " \"Select one of \" f\"them from here: {self._eval_metric_names()}.\"\n", + " )\n", + "\n", + " raise SystemExit(\"Monitoring metric not found.\")\n", + "\n", + " def on_step_end(self, step):\n", + " pass\n", + "\n", + " def on_step_begin_with_history(self, step):\n", + " if self.restore_best_checkpoint and self.best_checkpoint_path is None:\n", + " self._loop.save_checkpoint(\"best_checkpoint\")\n", + " self.best_checkpoint_path = os.path.join(\n", + " self._loop._output_dir, \"best_checkpoint.pkl.gz\"\n", + " 
)\n", + "\n", + " self.wait += 1\n", + " current_step, current = self._get_monitor_value()\n", + "\n", + " if current is None:\n", + " return\n", + "\n", + " if self._is_improvement(current, self.best):\n", + " self.best = current\n", + " self.best_step = current_step\n", + " self._loop.save_checkpoint(\"best_checkpoint\")\n", + "\n", + " # reset wait\n", + " self.wait = 0\n", + "\n", + " if self.wait >= self.patience and step > 1:\n", + " self.stopped_step = current_step\n", + " self._loop._stop_training = True\n", + "\n", + " if (\n", + " self.restore_best_checkpoint\n", + " and self.best_checkpoint_path is not None\n", + " ):\n", + " self._loop.load_checkpoint(self.best_checkpoint_path)\n", + " self._loop._log_step(\n", + " f\"Best checkpoint was restored from Step {self.best_step}.\"\n", + " )\n", + "\n", + " return self.best_step\n", + "\n", + " def _is_improvement(self, monitor_value, reference_value):\n", + " return self.monitor_op(monitor_value - self.min_delta, reference_value)\n", + "\n", + " def _get_monitor_value(self):\n", + " step, monitor_value = self._loop.history.get(\n", + " \"eval\", \"metrics/\" + self.monitor\n", + " )[-1]\n", + " return step, monitor_value\n", + "\n", + " return EarlyStopping" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sJHUx_nSP6Hh" + }, + "source": [ + "## Linear Regression\n", + "## Generate data for linear model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dKYZQY-pP6Hi" + }, + "outputs": [], + "source": [ + "def get_data_linear():\n", + " while True:\n", + " x = np.random.randint(low=1, high=10) * 1.0\n", + " y = x * 2.0 - 1\n", + " yield (np.array([x]), np.array([y]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SCTZW1pBP6Hj" + }, + "outputs": [], + "source": [ + "data_linear = get_data_linear()\n", + "print(next(data_linear))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4pcAhWJMP6Hk" + }, + "outputs": [], + "source": [ + "from trax.data.preprocessing import inputs as preprocessing\n", + "\n", + "data_pipeline = preprocessing.Serial(preprocessing.Batch(50), preprocessing.AddLossWeights(), )\n", + "data_stream = data_pipeline(data_linear)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2vK15-1oP6Hl" + }, + "source": [ + "## Build a simple linear model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xzN0oZBCP6Hl" + }, + "outputs": [], + "source": [ + "model_linear = tl.Serial(tl.Dense(1))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qi0bM41PP6Hl" + }, + "source": [ + "## Train a linear model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d0_9qZHVP6Hm" + }, + "outputs": [], + "source": [ + "from trax import optimizers as optimizers\n", + "\n", + "# Use the same data_stream for both training and evaluation\n", + "train_task = training.TrainTask(\n", + " labeled_data=data_stream,\n", + " loss_layer=tl.L2Loss(),\n", + " optimizer=optimizers.SGD(0.01),\n", + " n_steps_per_checkpoint=10,\n", + ")\n", + "\n", + "eval_task = training.EvalTask(\n", + " labeled_data=data_stream, metrics=[tl.L2Loss()], n_eval_batches=15,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R5ngyoYSP6Hm" + }, + "source": [ + "## Add early stopping function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SKetNF4LP6Hm" + }, + "outputs": [], + 
"source": [ + "earlystopping = callback_earlystopper(monitor='L2Loss', min_delta=1e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D2XjQO80P6Hn" + }, + "outputs": [], + "source": [ + "# Delete the training folder\n", + "!rm -r linear_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mCrc_bXZP6Hn" + }, + "outputs": [], + "source": [ + "model_linear = tl.Serial(tl.Dense(1))\n", + "training_loop = MyLoop(\n", + " model=model_linear, tasks=train_task, eval_tasks=[eval_task], output_dir=\"./linear_model\",\n", + " callbacks=[earlystopping]\n", + ")\n", + "# training_loop.save_checkpoint(f'step_{training_loop.step}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kFURD6T4P6Hn" + }, + "outputs": [], + "source": [ + "training_loop.run(1500)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lg_ONworP6Hn" + }, + "source": [ + "## Change patience\n", + "patience = 10 means it will wait for 10 x 10 = 100 steps (patience * n_steps_per_checkpoint ) to before making a decision to stop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IStFKG7GP6Hn" + }, + "outputs": [], + "source": [ + "earlystopping = callback_earlystopper(monitor='L2Loss', patience=10, min_delta=1e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pihrcvTtP6Ho" + }, + "outputs": [], + "source": [ + "# Delete the training folder\n", + "!rm -r linear_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UvjDLZd3P6Ho" + }, + "outputs": [], + "source": [ + "model_linear = tl.Serial(tl.Dense(1))\n", + "training_loop = MyLoop(\n", + " model=model_linear, tasks=train_task, eval_tasks=[eval_task], output_dir=\"./linear_model\",\n", + " callbacks=[earlystopping]\n", + ")\n", + "# training_loop.save_checkpoint(f'step_{training_loop.step}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bAsft27BP6Ho" + }, + "outputs": [], + "source": [ + "training_loop.run(1500)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6HyIjZWBP6Ho" + }, + "source": [ + "## Make a prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d7bVzat7P6Ho" + }, + "outputs": [], + "source": [ + "test_data = np.array([[2.0], [3.0], [10.0], [44.0]])\n", + "model_linear(test_data)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "earlystopping.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-5-TF-Numpy-And-Keras.ipynb b/resources/examples/ipynb/Example-5-TF-Numpy-And-Keras.ipynb new file mode 100644 index 000000000..62e4dccb4 --- /dev/null +++ b/resources/examples/ipynb/Example-5-TF-Numpy-And-Keras.ipynb @@ -0,0 +1,475 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7yuytuIllsv1" + }, + "source": [ + "# Using Trax with TensorFlow NumPy and Keras\n", + "\n", + "This notebook ([run it in 
colab](https://colab.research.google.com/github/google/trax/blob/master/trax/tf_numpy_and_keras.ipynb)) shows how you can run [Trax](https://trax-ml.readthedocs.io/en/latest/) directly with [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy). You will also see how to use Trax layers and models inside [Keras](https://keras.io/) so you can use Trax in production, e.g., with [TensorFlow.js](https://www.tensorflow.org/js/) or [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving).\n", + "\n", + " 1. **Trax with TensorFlow NumPy**: use Trax with [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) without any code changes\n", + " 1. **Convert Trax to Keras**: how to get a [Keras](https://keras.io/) layer for your Trax model and use it\n", + " 1. **Exporting Trax Models for Deployment**: how to export Trax models to [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model)\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-LQ89rFFsEdk" + }, + "source": [ + "## 1. Trax with TensorFlow NumPy\n", + "\n", + "In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy), which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n", + "\n", + "The backend can be set using a call to `trax.fastmath.set_backend`, as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) to execute `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax), which lowers computations to XLA.\n", + "\n", + "The `tensorflow-numpy` and `jax` backends can show different speed and memory characteristics. You may also see different error messages when debugging, since debugging can expose the internals of each backend. For the most part, however, you can choose a backend and not worry about its internal details.\n", + "\n", + "Let's train the sentiment analysis model from the [Trax intro](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb) using TensorFlow NumPy to see how it works." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BIl27504La0G" + }, + "source": [ + "**General Setup**\n", + "\n", + "Execute the following few cells (once) before running any of the code samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "executionInfo": { + "elapsed": 38104, + "status": "ok", + "timestamp": 1607390269924, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "executionInfo": { + "elapsed": 309, + "status": "ok", + "timestamp": 1607390270242, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "vlGjGoGMTt-D", + "outputId": "279a980e-1e71-4080-9587-d89aeb17ebc6" + }, + "outputs": [], + "source": [ + "# Install and import Trax\n", + "!pip install -q -U git+https://github.com/google/trax@master\n", + "\n", + "import os\n", + "import numpy as np\n", + "import trax" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O_3JcfZaT5oP" + }, + "source": [ + "Here is how you can set the fastmath backend to `tensorflow-numpy` and verify that it's been set." 
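+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As an aside (an illustrative sketch, not part of the original flow): `fastmath` calls return backend-native array types, which is one reason speed, memory use, and error messages can differ between backends. A quick way to see which backend produced an array:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Illustrative sketch: inspect the array type produced by the active backend.\n", + "from trax import fastmath\n", + "from trax.fastmath import numpy as jnp\n", + "\n", + "x = jnp.arange(6).reshape(2, 3)\n", + "print(fastmath.backend_name(), type(x))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With that in mind, set the backend for the rest of this notebook:"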
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 286, + "status": "ok", + "timestamp": 1607390270535, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "djTiSLcaNFGa", + "outputId": "bac38e28-d1e5-41bd-9054-d85913fc2900" + }, + "outputs": [], + "source": [ + "# Use the tensorflow-numpy backend.\n", + "trax.fastmath.set_backend('tensorflow-numpy')\n", + "print(trax.fastmath.backend_name())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 15126, + "status": "ok", + "timestamp": 1607390285667, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "AV5wrgjZ10yU", + "outputId": "6385fbe2-5a8e-415c-8851-b5bef099e02f" + }, + "outputs": [], + "source": [ + "# Create data streams.\n", + "train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()\n", + "eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()\n", + "\n", + "data_pipeline = trax.data.Serial(\n", + " trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),\n", + " trax.data.Shuffle(),\n", + " trax.data.FilterByLength(max_length=2048, length_keys=[0]),\n", + " trax.data.BucketByLength(boundaries=[ 32, 128, 512, 2048],\n", + " batch_sizes=[512, 128, 32, 8, 1],\n", + " length_keys=[0]),\n", + " trax.data.AddLossWeights()\n", + " )\n", + "train_batches_stream = data_pipeline(train_stream)\n", + "eval_batches_stream = data_pipeline(eval_stream)\n", + "\n", + "# Print example shapes.\n", + "example_batch = next(train_batches_stream)\n", + "print(f'batch shapes = {[x.shape for x in example_batch]}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 409, + "status": "ok", + "timestamp": 1607390286085, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "WoSz5plIyXOU", + "outputId": "aa1db911-96fb-430b-8360-1a6e3f764cee" + }, + "outputs": [], + "source": [ + "# Create the model.\n", + "from trax import layers as tl\n", + "\n", + "model = tl.Serial(\n", + " tl.Embedding(vocab_size=8192, d_feature=256),\n", + " tl.Mean(axis=1), # Average on axis 1 (length of sentence).\n", + " tl.Dense(2), # Classify 2 classes.\n", + ")\n", + "\n", + "# You can print model structure.\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 79139, + "status": "ok", + "timestamp": 1607390365232, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "d6bIKUO-3Cw8", + "outputId": "ba4199f4-cc31-459e-b46c-d14ec2f4ef68" + }, + "outputs": [], + "source": [ + "# Train the model.\n", + "from trax.supervised import training\n", + "\n", + "# Training task.\n", + "train_task = training.TrainTask(\n", + " labeled_data=train_batches_stream,\n", + " loss_layer=tl.WeightedCategoryCrossEntropy(),\n", + " optimizer=trax.optimizers.Adam(0.01),\n", + " n_steps_per_checkpoint=500,\n", + ")\n", + "\n", + "# Evaluation task.\n", + "eval_task = training.EvalTask(\n", + " labeled_data=eval_batches_stream,\n", + " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n", + " n_eval_batches=20 # For less variance in eval numbers.\n", + ")\n", + "\n", + "# Training loop saves checkpoints to output_dir.\n", + "output_dir = os.path.expanduser('~/output_dir/')\n", 
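+ "\n", + "# The Loop below trains on train_task and reports eval_task metrics at every checkpoint.\n",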
+ "training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir)\n", + "\n", + "# Run 2000 steps (batches).\n", + "training_loop.run(2000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 832, + "status": "ok", + "timestamp": 1607390366089, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "yuPu37Lp7GST", + "outputId": "b95f944d-b5e8-44c6-829c-25c0b0b08f38" + }, + "outputs": [], + "source": [ + "# Run on an example.\n", + "example_input = next(eval_batches_stream)[0][0]\n", + "example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')\n", + "print(f'example input_str: {example_input_str}')\n", + "sentiment_activations = model(example_input[None, :]) # Add batch dimension.\n", + "print(f'Model returned sentiment activations: {np.asarray(sentiment_activations)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8wgfJyhdihfR" + }, + "source": [ + "## 2. Convert Trax to Keras\n", + "\n", + "Thanks to [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) you can convert the model you just trained into a [Keras](https://keras.io/) layer using `trax.AsKeras`. This allows you to:\n", + "\n", + "* use Trax layers inside Keras models\n", + "* run Trax models with existing Keras input pipelines\n", + "* export Trax models to [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model)\n", + "\n", + "When creating a Keras layer from a Trax one, the Keras layer weights will get initialized to the ones the Trax layer had at the moment of creation. In this way, you can create Keras layers from pre-trained Trax models and save them as SavedModel as shown below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 322, + "status": "ok", + "timestamp": 1607390366418, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "bxSLRyjftuxH", + "outputId": "6ec7180b-ff85-47e4-bba2-3634df913ad4" + }, + "outputs": [], + "source": [ + "# Convert the model into a Keras layer, reusing its trained weights.\n", + "keras_layer = trax.AsKeras(model)\n", + "print(keras_layer)\n", + "\n", + "# Run the Keras layer to verify it returns the same result.\n", + "sentiment_activations = keras_layer(example_input[None, :])\n", + "print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 3983, + "status": "ok", + "timestamp": 1607390370412, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "r8C-FoFGxGE1", + "outputId": "0edfd1fa-2677-494a-f03f-2cc87324e88c" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "# Create a full Keras model using the layer from Trax.\n", + "inputs = tf.keras.Input(shape=(None,), dtype='int32')\n", + "hidden = keras_layer(inputs)\n", + "# You can add other Keras layers here operating on hidden.\n", + "outputs = hidden\n", + "keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", + "print(keras_model)\n", + "\n", + "# Run the Keras model to verify it returns the same result.\n", + "sentiment_activations = keras_model(example_input[None, :])\n", + "print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EQH1bvXwy5fE" + }, + "source": [ + "## 3. Exporting Trax Models for Deployment\n", + "\n", + "You can export the Keras model to disk as a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model). It's as simple as calling `keras_model.save`, and it lets you use the model with TF tools such as [TensorFlow.js](https://www.tensorflow.org/js/), [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) and [TensorFlow Lite](https://www.tensorflow.org/lite)." 
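+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For instance (an illustrative sketch; the output file name is an assumption), the same Keras model can also be converted to TensorFlow Lite bytes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Illustrative sketch: convert the Keras model to a TFLite flatbuffer.\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)\n", + "tflite_bytes = converter.convert()\n", + "with open('model.tflite', 'wb') as f:\n", + " f.write(tflite_bytes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, save and reload the model as a SavedModel:"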
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 1355, + "status": "ok", + "timestamp": 1607390371776, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "nQIJrOUgxRfK", + "outputId": "62c028a5-da9e-40b1-d223-aa5f45b6a2aa" + }, + "outputs": [], + "source": [ + "# Save the Keras model to output_dir.\n", + "model_file = os.path.join(output_dir, \"model_checkpoint\")\n", + "keras_model.save(model_file)\n", + "\n", + "# Load the model from SavedModel.\n", + "loaded_model = tf.keras.models.load_model(model_file)\n", + "\n", + "# Run the loaded model to verify it returns the same result.\n", + "sentiment_activations = loaded_model(example_input[None, :])\n", + "print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", + "kind": "private" + }, + "name": "Using Trax with Keras", + "provenance": [ + { + "file_id": "1RNbQoOuzKsp_FoDqOFQX4mA--Wzt5ofq", + "timestamp": 1596181556972 + }, + { + "file_id": "https://github.com/google/trax/blob/master/trax/intro.ipynb", + "timestamp": 1596178511100 + }, + { + "file_id": "trax/intro.ipynb", + "timestamp": 1595931762204 + }, + { + "file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV", + "timestamp": 1578964243645 + }, + { + "file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07", + "timestamp": 1572044421118 + }, + { + "file_id": "intro.ipynb", + "timestamp": 1571858674399 + }, + { + "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", + "timestamp": 1569980697572 + }, + { + "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", + "timestamp": 1563927451951 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-6-5-Image-Semantic-Segmentation.ipynb b/resources/examples/ipynb/Example-6-5-Image-Semantic-Segmentation.ipynb new file mode 100644 index 000000000..a4bef0f2d --- /dev/null +++ b/resources/examples/ipynb/Example-6-5-Image-Semantic-Segmentation.ipynb @@ -0,0 +1,642 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Author- [@yashkhasbage25](https://github.com/yashkhasbage25 \"Yash Khasbage\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AZWS_qfMw1B3" + }, + "source": [ + "# Semantic Segmentation \n", + "Semantic segmentation is a computer vision task that divides an image into segments, identifying which parts of the image belong to which object. 
\n", + "\n", + "In this tutorial, we will train a Convolutional neural network to segment images. \n", + "\n", + "Briefly, we will discuss\n", + "1. downloading an image segmentation dataset from kaggle\n", + "2. processing the dataset according to our need\n", + "3. Create a dataloader\n", + "4. Creating a Custom loss function\n", + "5. Creating TrainTask and EvalTask \n", + "6. Create a Neural Network and train it\n", + "\n", + "(You need to have a kaggle account for downloading the dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0AjBi0zHE4pv" + }, + "source": [ + "Assuming that you already have a kaggle account, we will first begin by creating a kaggle API token. \n", + "If you don't have API token, follow these steps to create a new one:\n", + "1. Go to the Account section of kaggle website, after you login. \n", + "2. Click \"Expire API Token\" and then \"Create New API Token\". A file \"kaggle.json\" will be downloaded. \n", + "3. Using \"Choose files\" button, upload the kaggle.json file. The API token present in this file will help us download the dataset directly from kaggle. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 72, + "resources": { + "http://localhost:8080/nbextensions/google.colab/files.js": { + "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4
gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQ
ogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", + "headers": [ + [ + "content-type", + "application/javascript" + ] + ], + "ok": true, + "status": 200, + "status_text": "" + } + } + }, + "id": "dzXwMVFPf2qR", + "outputId": "0d776c36-8dbd-4242-e933-0b73abe243b0" + }, + "outputs": [], + "source": [ + "! pip install -q kaggle\n", + "from google.colab import files\n", + "files.upload() # upload kaggle.json\n", + "! mkdir ~/.kaggle" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sskHBrFsM4Yl" + }, + "source": [ + "We need to place kaggle.json at ~/.kaggle and also change its file permissions. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TT61H-y8gg4E", + "outputId": "333b528f-e768-496d-9593-09e4036703c0" + }, + "outputs": [], + "source": [ + "! cp kaggle.json ~/.kaggle/\n", + "! chmod 600 ~/.kaggle/kaggle.json\n", + "! kaggle datasets list" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S867tktZNHqD" + }, + "source": [ + "Now with this command, we actually download the dataset. This may take some time, depending on internet speed. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZQ_96E2ngwvJ", + "outputId": "0aba3d27-0698-4abd-8057-c9615518e7f2" + }, + "outputs": [], + "source": [ + "! kaggle datasets download -d dansbecker/cityscapes-image-pairs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OKF2tpAAPHBN" + }, + "source": [ + "The download has to be uncompressed. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DQtpMD67hbAO" + }, + "outputs": [], + "source": [ + "! unzip -q cityscapes-image-pairs.zip" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zKH-76ZJPMeR" + }, + "source": [ + "Intall trax\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mh05_t3Phy2h", + "outputId": "21c0cd27-8c13-49d5-b30f-65533d9a8084" + }, + "outputs": [], + "source": [ + "! 
pip install -q -U trax" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6HGMKVu1kfYh" + }, + "outputs": [], + "source": [ + "# several imports from trax\n", + "\n", + "import trax\n", + "import numpy as np\n", + "import trax.layers as tl\n", + "from trax.fastmath import numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o-3g4wi1leJy" + }, + "outputs": [], + "source": [ + "# several imports out of trax\n", + "\n", + "import os\n", + "import os.path as osp\n", + "from PIL import Image\n", + "from itertools import cycle\n", + "from sklearn.cluster import KMeans\n", + "import matplotlib.pyplot as plt\n", + "% matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZTt5oh_QjdcI" + }, + "outputs": [], + "source": [ + "# let's fix batch size\n", + "batch_size = 32" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KmqJJqs8PnEN" + }, + "source": [ + "Some details of the dataset in its original form: \n", + "The original images are of the shape 256x512x3. The left half and the right half of images belong to input and label respectively. In a typical segmentation label, the label should be a 2D matrix consisting of the class label of objects, such that each pixel is alloted a class. In the label images given, we are not directly provided with the class labels. However, each class label is represented with a specific color. We need to map colors to class labels, to convert them into usable format. \n", + "\n", + "We know that there are total 13 classes in the dataset. Hence, we will be given 13 different colors in labels. For processing the label images, according to the procedure mentioned above, we will use K-Means utility of sklearn.\n", + "\n", + "We do the processing in the following manner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tIBSJ3gpkmf9" + }, + "outputs": [], + "source": [ + "def color_kmean(root):\n", + " \"\"\" creates a k-means objects that recognizes all 13 colors of dataset. \"\"\"\n", + " \n", + " # take 10 first images\n", + " files = os.listdir(root)[:10] \n", + " colors = list()\n", + " for f in files:\n", + " img = load_image(osp.join(root, f))\n", + " # total width\n", + " w = img.shape[2]\n", + " # get the right half of image, which is the label image\n", + " img = img[:, w:, :]\n", + " # collect all the colors present in label image\n", + " colors.append(img.reshape(-1, 3))\n", + "\n", + " colors = np.array(colors)\n", + " colors = colors.reshape(-1, 3)\n", + "\n", + " # finally, fit all the colors into the KMeans\n", + " kmeans = KMeans(13)\n", + " kmeans.fit(colors)\n", + "\n", + " return kmeans\n", + "\n", + "def load_image(path):\n", + " \"\"\" loading an image. \"\"\"\n", + " \n", + " assert osp.exists(path), path + \" not found\"\n", + " image = Image.open(path)\n", + " image = np.asarray(image)\n", + " return image\n", + "\n", + "def color2class(segs, km):\n", + " \"\"\" \n", + " given an label image, convert it to class matrix, \n", + " which is a 2D matrix of class labels (scalars).\n", + " \"\"\"\n", + " \n", + " h, w, c = segs.shape\n", + " segs = segs.reshape((-1, 3))\n", + " segs = km.predict(segs)\n", + " segs = segs.reshape((h, w, 1))\n", + " return segs\n", + "\n", + "def load_dataset(root, km):\n", + " \"\"\" load dataset. 
\"\"\"\n", + " index = 0\n", + " imgs_path = [osp.join(root, f) for f in os.listdir(root)]\n", + "\n", + " # load images one by one, finally, and image and \n", + " # its label matrix is returned\n", + " while True:\n", + " img = load_image(imgs_path[index])\n", + " w = img.shape[1] // 2\n", + " img, seg = img[:, :w, :], img[:, w:, :]\n", + "\n", + " seg = color2class(seg, km)\n", + "\n", + " seg = seg.reshape(-1)\n", + " assert img.shape == (256, 256, 3), img.shape\n", + " assert seg.shape == (256 * 256,), seg.shape\n", + " yield img, seg\n", + "\n", + " index = (index + 1) % len(imgs_path)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "udqueyxmA6Pc" + }, + "source": [ + "Uncomment to try other backend. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DJq1biuLxeFa", + "outputId": "f95918ee-413a-4ecb-9982-a34c3d3e6177" + }, + "outputs": [], + "source": [ + "# trax.fastmath.set_backend('tensorflow-numpy')\n", + "print(trax.fastmath.backend_name())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ce_KBGtTBB50" + }, + "source": [ + "Set path to dataset, and get kmeans color setter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HLysKwN0Xy5t" + }, + "outputs": [], + "source": [ + "root = 'cityscapes_data'\n", + "\n", + "trainset_path = osp.join(root, 'train')\n", + "valset_path = osp.join(root, 'val')\n", + "\n", + "km = color_kmean(trainset_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lex2Tm72BrFf" + }, + "source": [ + "Create dataset loaders and data transforms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ngHMyZBbjfft" + }, + "outputs": [], + "source": [ + "train_dataset = load_dataset(trainset_path, km)\n", + "val_dataset = load_dataset(valset_path, km)\n", + "\n", + "train_transforms = trax.data.Serial(\n", + " trax.data.Shuffle(),\n", + " trax.data.Batch(batch_size),\n", + " lambda g: map(lambda p: (p[0].astype(np.float32), p[1]), g),\n", + ")\n", + "val_transforms = trax.data.Serial(\n", + " trax.data.Batch(batch_size),\n", + " lambda g: map(lambda p: (p[0].astype(np.float32), p[1]), g),\n", + ")\n", + "\n", + "train_dataset = train_transforms(train_dataset)\n", + "val_dataset = val_transforms(val_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HURVJcElB9et" + }, + "source": [ + "Create a custom loss. In semantic segmentation we need to apply cross entropy for every pixel of image. Hence, we decrease the number of dimensions of the matrices so that we can use CrossEntropy2d, while maintaining the order of elements of matrices. \n", + "\n", + "Here, we convert the 3D Neural Network to 2D array and 2D label matrix to 1D array." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZEdJXM9g8rif", + "outputId": "6b78ca76-db43-44c6-b618-435cbd8c8f3e" + }, + "outputs": [], + "source": [ + "def CrossEntropy3d(criterion_2d):\n", + " \"\"\" Returns a 3D cross-entropy loss function. \"\"\"\n", + " def _loss_fn(output, target):\n", + " output = output.reshape(-1, 13)\n", + " target = target.reshape(-1,)\n", + " loss = criterion_2d((output, target))\n", + " return loss\n", + " return _loss_fn\n", + "\n", + "# check the dataset\n", + "x, y = next(train_dataset) \n", + "print(x.shape, y.shape)\n", + "print(x.dtype, y.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VWmhQZDElSo6" + }, + "outputs": [], + "source": [ + "# set the learning rate\n", + "lr = 1e-2\n", + "\n", + "# Create a trax Fn layer for the new loss and give it a name\n", + "criterion = trax.layers.base.Fn(\"CrossEntropy3d\", \n", + " CrossEntropy3d(tl.CategoryCrossEntropy())\n", + " )\n", + "\n", + "# create TrainTask\n", + "train_task = trax.supervised.training.TrainTask(\n", + " labeled_data=train_dataset,\n", + " loss_layer=criterion,\n", + " optimizer=trax.optimizers.Momentum(lr),\n", + " n_steps_per_checkpoint=50\n", + ")\n", + "\n", + "# create EvalTask\n", + "eval_task = trax.supervised.training.EvalTask(\n", + " labeled_data=val_dataset,\n", + " metrics=[criterion]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mieHBPnpExJo" + }, + "source": [ + "Now create a simple Serial model. You can create a more complex one according to your needs. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LgWQmYCVoBXU" + }, + "outputs": [], + "source": [ + "model = tl.Serial(\n", + " tl.Conv(13, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(64, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(128, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(64, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + " tl.Conv(13, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Z5SsOVNE6KJ" + }, + "source": [ + "Create a training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TE2Rfdafv5xl", + "outputId": "3cc3fc96-f812-470b-d058-b07b7d67f339" + }, + "outputs": [], + "source": [ + "training_loop = trax.supervised.training.Loop(\n", + " model, \n", + " train_task, 
\n", + " eval_tasks=[eval_task],\n", + " output_dir=None\n", + ")\n", + "\n", + "training_loop.run(500)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "F_eQXlgAJQd8" + }, + "source": [ + "Lets see some example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CcR_gzqsJUom", + "outputId": "ea1e1457-b4d1-4499-f7da-c791163eb740" + }, + "outputs": [], + "source": [ + "x, y = next(val_dataset)\n", + "\n", + "fig, axs = plt.subplots(nrows=1, ncols=3)\n", + "\n", + "x = x[0]\n", + "y = y[0]\n", + "\n", + "y = np.reshape(y, (256, 256))\n", + "axs[0].imshow(x.astype(np.int32))\n", + "axs[1].imshow(y)\n", + "fig.show()\n", + "\n", + "x = np.expand_dims(x, 0)\n", + "y_hat = model(x)\n", + "y_hat = y_hat[0]\n", + "\n", + "y_hat = np.argmax(y_hat, 2)\n", + "y_hat = np.reshape(y_hat, (-1,))\n", + "y_hat = km.cluster_centers_[y_hat]\n", + "y_hat = np.reshape(y_hat, (256, 256, 3))\n", + "y_hat = np.round_(y_hat).astype(np.int32)\n", + "\n", + "axs[2].imshow(y_hat)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q-TYBBWHk1v6" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "semantic_segmentation.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/resources/examples/ipynb/Example-6-6-MathQA-Python-Generation.ipynb b/resources/examples/ipynb/Example-6-6-MathQA-Python-Generation.ipynb new file mode 100644 index 000000000..98e14a9c3 --- /dev/null +++ b/resources/examples/ipynb/Example-6-6-MathQA-Python-Generation.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title License\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lpqiZgTy4DqT" + }, + "source": [ + "How to generate the MathQA-Python dataset?\n", + "\n", + "\n", + "\n", + "---\n", + "\n", + "\n", + "\n", + "1. Download the dataset from the MathQA project webpage: https://math-qa.github.io/\n", + "2. Create the mathqa directory in the local colab drive.\n", + "3. Unpack the json files (train.json, dev.json, test.json, challenge_test.json) and place them in the mathqa directory.\n", + "4. Run the cells below - they will generate the MathQA-Python dataset for the test split. \n", + "5. 
Repeat the process for other splits if needed.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "B8nqRq0Qhcf8" + }, + "outputs": [], + "source": [ + "!pip install -U git+https://github.com/google/trax.git@220a62303ebf4ad18871aa5607b4dda2f064f2d2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "v4RKdd18hqRH" + }, + "outputs": [], + "source": [ + "from trax import data\n", + "import json\n", + "import numpy as np\n", + "import os\n", + "import tensorflow as tf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TAyU75naIFW5" + }, + "outputs": [], + "source": [ + "dataset_path = '/content/mathqa/'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L-RZ9MeajaWC" + }, + "outputs": [], + "source": [ + "mathqa_test_gen = data.CreateMathQAInputs(dataset_path=dataset_path, cumulative=False, python_code=True, full_dict=True, train=False, test=True)()\n", + "def read_all_problems(mathqa_gen):\n", + " problems = []\n", + " questions = set()\n", + " index = 0\n", + " while True:\n", + " problem = next(mathqa_gen)\n", + " problem_dict = {}\n", + " if problem[0] in questions:\n", + " break\n", + " else:\n", + " problem_dict['text'] = problem[0]\n", + " problem_dict['code'] = problem[1]\n", + " problem_dict['dsl_code'] = problem[2]\n", + " problem_dict['reasoning'] = problem[3].strip('\\\"').strip(\"\\'\")\n", + " problem_dict['answer'] = data.tf_inputs.execute_mathqa_program(problem[0], problem[1].split('\\n'))\n", + " problem_dict['task_id'] = index\n", + " np.testing.assert_almost_equal(problem_dict['answer'], data.tf_inputs.execute_mathqa_dsl_program(problem[0], [problem[2]]))\n", + " problems.append(problem_dict)\n", + " questions.add(problem[0])\n", + " index += 1\n", + " return problems" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "K96xIQDQjyrS" + }, + "outputs": [], + "source": [ + "test_problems = read_all_problems(mathqa_test_gen)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "K5y7244_j3mB" + }, + "outputs": [], + "source": [ + "len(test_problems)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "emEvo5iAucGl" + }, + "outputs": [], + "source": [ + "test_problems[0]" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "MathQA_Python_generation_notebook.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1pdlfcJ8F4-QhBWe3KRKJW_iSov7zl6Ve", + "timestamp": 1626376876263 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-6.1-Fashion-MNIST.ipynb b/resources/examples/ipynb/Example-6.1-Fashion-MNIST.ipynb new file mode 100644 index 000000000..0c1ac72ad --- /dev/null +++ b/resources/examples/ipynb/Example-6.1-Fashion-MNIST.ipynb @@ -0,0 +1,341 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 436, + "status": "ok", + "timestamp": 1607381103381, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "1ecEWLK0nsyg" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not 
use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 447, + "status": "ok", + "timestamp": 1607381103836, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "vxLvhYV5XrvS", + "outputId": "f399419a-f30c-462d-b66e-61fa55c1a466" + }, + "outputs": [], + "source": [ + "import os\n", + "#!pip install -q -U trax\n", + "import sys\n", + "\n", + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 34658, + "status": "ok", + "timestamp": 1607381138504, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "ssFKSDd3X9Xj", + "outputId": "9eba95c4-ba52-461f-ea42-6a7b1d671a3f" + }, + "outputs": [], + "source": [ + "from trax import fastmath\n", + "from trax.fastmath.jax import jax\n", + "\n", + "# Use the JAX backend.\n", + "fastmath.set_backend(fastmath.Backend.JAX.value)\n", + "print(trax.fastmath.backend_name())\n", + "print(jax.devices())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 18987, + "status": "ok", + "timestamp": 1607381157508, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "OHKt1_SaYGZW" + }, + "outputs": [], + "source": [ + "# https://www.tensorflow.org/datasets/catalog/fashion_mnist\n", + "from trax.data.preprocessing import inputs as preprocessing\n", + "from trax.data.loader.tf import base as dataset\n", + "\n", + "train_stream = dataset.TFDS('fashion_mnist', keys=('image', 'label'), train=True)()\n", + "eval_stream = dataset.TFDS('fashion_mnist', keys=('image', 'label'), train=False)()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 470, + "status": "ok", + "timestamp": 1607381157985, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "AfGtZHo4YYf6" + }, + "outputs": [], + "source": [ + "train_data_pipeline = preprocessing.Serial(\n", + " preprocessing.Shuffle(),\n", + " preprocessing.Batch(8),\n", + ")\n", + "\n", + "train_batches_stream = train_data_pipeline(train_stream)\n", + "\n", + "eval_data_pipeline = preprocessing.Batch(8)\n", + "eval_batches_stream = eval_data_pipeline(eval_stream)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 907, + "status": "ok", + "timestamp": 1607381158899, + "user": { + "displayName": 
"", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "T75v8i91ZKcp", + "outputId": "5711f41d-2bf6-498d-fe44-247e16fadb07" + }, + "outputs": [], + "source": [ + "example_batch = next(train_batches_stream)\n", + "print(f'batch shape (image, label) = {[x.shape for x in example_batch]}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 430, + "status": "ok", + "timestamp": 1607381159334, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "SbRlJX9_ZRLj" + }, + "outputs": [], + "source": [ + "from trax import layers as tl\n", + "\n", + "\n", + "def get_model(n_output_classes=10):\n", + " model = tl.Serial(\n", + " tl.ToFloat(),\n", + "\n", + " tl.Conv(32, (3, 3), (1, 1), 'SAME'),\n", + " tl.LayerNorm(),\n", + " tl.Relu(),\n", + " tl.MaxPool(),\n", + "\n", + " tl.Conv(64, (3, 3), (1, 1), 'SAME'),\n", + " tl.LayerNorm(),\n", + " tl.Relu(),\n", + " tl.MaxPool(),\n", + "\n", + " tl.Flatten(),\n", + " tl.Dense(n_output_classes),\n", + " )\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 944, + "status": "ok", + "timestamp": 1607381160283, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "zv6LSQZdaV6z" + }, + "outputs": [], + "source": [ + "from trax.learning.supervised import training\n", + "from trax import optimizers as optimizers\n", + "\n", + "train_task = training.TrainTask(\n", + " labeled_data=train_batches_stream,\n", + " loss_layer=tl.CategoryCrossEntropy(),\n", + " optimizer=optimizers.Adam(0.01),\n", + " n_steps_per_checkpoint=100,\n", + ")\n", + "\n", + "eval_task = training.EvalTask(\n", + " labeled_data=eval_batches_stream,\n", + " metrics=[tl.CategoryCrossEntropy(), tl.CategoryAccuracy()],\n", + " n_eval_batches=20,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 14526, + "status": "ok", + "timestamp": 1607381174829, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "Rcz3ngZCa_9i", + "outputId": "3ece3594-8835-416d-d968-205e804f4bcc" + }, + "outputs": [], + "source": [ + "model = get_model()\n", + "\n", + "training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir='./cnn_model')\n", + "\n", + "training_loop.run(100)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "shutil.rmtree(training_loop.output_dir, ignore_errors=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 530, + "status": "ok", + "timestamp": 1607381175378, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 480 + }, + "id": "AMhqFx6HbOs_" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Fashion MNIST with Trax.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-6.2-NER-Based-Reformer.ipynb b/resources/examples/ipynb/Example-6.2-NER-Based-Reformer.ipynb new file mode 100644 index 000000000..a712a23d7 --- /dev/null +++ 
b/resources/examples/ipynb/Example-6.2-NER-Based-Reformer.ipynb @@ -0,0 +1,1167 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open In Colab\" badge (HTML link elided)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eGCe1pjznIQS" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zpj2rPQdm8nb" + }, + "source": [ + "Author - [@SauravMaheshkar](https://github.com/SauravMaheshkar)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8LURHZ84v9-i", + "papermill": { + "duration": 0.034262, + "end_time": "2020-10-20T14:06:51.973823", + "exception": false, + "start_time": "2020-10-20T14:06:51.939561", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Install Dependencies\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yEuPYcg3BoAb", + "papermill": { + "duration": 0.031347, + "end_time": "2020-10-20T14:06:52.037011", + "exception": false, + "start_time": "2020-10-20T14:06:52.005664", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Install the latest version of the [Trax](https://github.com/google/trax) library, or point Python at a local Trax checkout via the `TRAX_PROJECT_ROOT` environment variable, as done below."
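+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a sketch of the two setup options, the next cell shows both the PyPI route and the local-checkout route. Everything in it is commented out, and the local path is a made-up example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Two ways to make trax importable (sketch only; the path is a made-up example):\n", + "# 1) Install from PyPI:\n", + "#    !pip install -q -U trax\n", + "# 2) Use a local checkout, setting this before the import cell below runs:\n", + "#    import os; os.environ['TRAX_PROJECT_ROOT'] = '/path/to/trax'"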
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "_kg_hide-output": true, + "id": "u4GfFPtWv0eb", + "outputId": "59aaef48-c9fc-4af9-9043-a1f7c7745749", + "papermill": { + "duration": 53.749037, + "end_time": "2020-10-20T14:07:45.817478", + "exception": false, + "start_time": "2020-10-20T14:06:52.068441", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Optionally, verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "25jBohslvAaM", + "papermill": { + "duration": 0.031968, + "end_time": "2020-10-20T14:07:45.882676", + "exception": false, + "start_time": "2020-10-20T14:07:45.850708", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Introduction\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "drjA2GYE4g_F", + "papermill": { + "duration": 0.031988, + "end_time": "2020-10-20T14:07:45.947830", + "exception": false, + "start_time": "2020-10-20T14:07:45.915842", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "---\n", + "\n", + "**Named-entity recognition** (NER) is a subtask of *information extraction* that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc.\n", + "\n", + "To evaluate the quality of a NER system's output, several measures have been defined. The usual measures are **precision**, **recall**, and **F1 score**, although there is still some debate over exactly how these values should be calculated. State-of-the-art NER systems for English produce near-human performance: for example, the best system entering MUC-7 achieved an F-measure of 93.39%, while human annotators scored 97.60% and 96.95%. A minimal code sketch of these metrics follows."
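+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a quick aside that is not part of this notebook's pipeline, token-level precision, recall, and F1 can be sketched in a few lines. The `ner_metrics` helper and the BIO tags below are made up for illustration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Made-up example of token-level NER metrics ('O' marks non-entity tokens).\n", + "def ner_metrics(y_true, y_pred, outside_tag='O'):\n", + "    tp = sum(1 for t, p in zip(y_true, y_pred) if p == t and t != outside_tag)\n", + "    pred_pos = sum(1 for p in y_pred if p != outside_tag)\n", + "    true_pos = sum(1 for t in y_true if t != outside_tag)\n", + "    precision = tp / pred_pos if pred_pos else 0.0\n", + "    recall = tp / true_pos if true_pos else 0.0\n", + "    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n", + "    return precision, recall, f1\n", + "\n", + "y_true = ['B-geo', 'O', 'O', 'B-per', 'I-per', 'O']\n", + "y_pred = ['B-geo', 'O', 'B-org', 'B-per', 'O', 'O']\n", + "print(ner_metrics(y_true, y_pred))"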
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SbzpsLnB6Rt_", + "papermill": { + "duration": 0.031674, + "end_time": "2020-10-20T14:07:46.011670", + "exception": false, + "start_time": "2020-10-20T14:07:45.979996", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Importing Packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2pGNHjR46RFs", + "papermill": { + "duration": 11.822159, + "end_time": "2020-10-20T14:07:57.865897", + "exception": false, + "start_time": "2020-10-20T14:07:46.043738", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from trax import layers as tl\n", + "import numpy as np # For scientific computing\n", + "import pandas as pd # For basic data analysis\n", + "import random as rnd # For using random functions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qQFItoSJGeti", + "papermill": { + "duration": 0.032906, + "end_time": "2020-10-20T14:07:57.931601", + "exception": false, + "start_time": "2020-10-20T14:07:57.898695", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Pre-Processing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jdC7V8KspbHb", + "papermill": { + "duration": 0.032062, + "end_time": "2020-10-20T14:07:57.996789", + "exception": false, + "start_time": "2020-10-20T14:07:57.964727", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Loading the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CcLS-1P0IePt", + "papermill": { + "duration": 0.032255, + "end_time": "2020-10-20T14:07:58.061951", + "exception": false, + "start_time": "2020-10-20T14:07:58.029696", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Let's load the `ner_dataset.csv` file into a dataframe and see what it looks like" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import zipfile\n", + "\n", + "\n", + "def download_and_extract_kaggle_dataset():\n", + " # Create directory structure if it doesn't exist\n", + " os.makedirs(\"tmp/kaggle\", exist_ok=True)\n", + "\n", + " # Download the zip file using curl\n", + " download_path = os.path.expanduser(\"~/Downloads/entity-annotated-corpus.zip\")\n", + " download_cmd = [\n", + " \"curl\", \"-L\", \"-o\", download_path,\n", + " \"https://www.kaggle.com/api/v1/datasets/download/abhinavwalia95/entity-annotated-corpus\"\n", + " ]\n", + "\n", + " print(\"Downloading dataset...\")\n", + " result = subprocess.run(download_cmd, capture_output=True, text=True)\n", + "\n", + " if result.returncode != 0:\n", + " print(f\"Download failed with error: {result.stderr}\")\n", + " return False\n", + "\n", + " # Unzip the file to tmp/kaggle directory\n", + " extract_path = \"tmp/kaggle\"\n", + " print(f\"Extracting to {extract_path}...\")\n", + "\n", + " try:\n", + " with zipfile.ZipFile(download_path, 'r') as zip_ref:\n", + " zip_ref.extractall(extract_path)\n", + " print(f\"Dataset extracted successfully to {extract_path}\")\n", + " return True\n", + " except Exception as e:\n", + " print(f\"Extraction failed: {str(e)}\")\n", + " return False\n", + "\n", + "\n", + "# Call the function to download and extract\n", + "download_and_extract_kaggle_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q83GgD2JWlTz", + "outputId": "3f67e377-6450-41d2-ca19-b1f8f79f4766", + "papermill": { + "duration": 1.430809, + 
"end_time": "2020-10-20T14:07:59.524871", + "exception": false, + "start_time": "2020-10-20T14:07:58.094062", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "data = pd.read_csv(f\"{os.getcwd()}/tmp/kaggle/ner_dataset.csv\", encoding='ISO-8859-1')\n", + "data = data.fillna(method='ffill')\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DCss4fg8IQwN", + "papermill": { + "duration": 0.032562, + "end_time": "2020-10-20T14:07:59.590814", + "exception": false, + "start_time": "2020-10-20T14:07:59.558252", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Creating a Vocabulary File" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "etNMRldEImgg", + "papermill": { + "duration": 0.032586, + "end_time": "2020-10-20T14:07:59.656501", + "exception": false, + "start_time": "2020-10-20T14:07:59.623915", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We can see there's a column for the words in each sentence. Thus, we can extract this column using the [`.loc()`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.loc.html) and store it into a `.txt` file using the [`.savetext()`](https://numpy.org/doc/stable/reference/generated/numpy.savetxt.html) function from numpy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tw9ewglyIa_0", + "papermill": { + "duration": 2.336183, + "end_time": "2020-10-20T14:08:02.025687", + "exception": false, + "start_time": "2020-10-20T14:07:59.689504", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "## Extract the 'Word' column from the dataframe\n", + "words = data.loc[:, \"Word\"]\n", + "\n", + "## Convert into a text file using the .savetxt() function\n", + "np.savetxt(r'words.txt', words.values, fmt=\"%s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "skW3Wz9YKULq", + "papermill": { + "duration": 0.032752, + "end_time": "2020-10-20T14:08:02.092503", + "exception": false, + "start_time": "2020-10-20T14:08:02.059751", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Creating a Dictionary for Vocabulary" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LWMeXo8LCkkG", + "papermill": { + "duration": 0.032646, + "end_time": "2020-10-20T14:08:02.158153", + "exception": false, + "start_time": "2020-10-20T14:08:02.125507", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we create a Dictionary for our vocabulary by reading through all the sentences in the dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C9TxwknFKStf", + "outputId": "b7574311-badd-4623-da2a-5ef101db0b00", + "papermill": { + "duration": 0.675227, + "end_time": "2020-10-20T14:08:02.866282", + "exception": false, + "start_time": "2020-10-20T14:08:02.191055", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "vocab = {}\n", + "with open('words.txt') as f:\n", + " for i, l in enumerate(f.read().splitlines()):\n", + " vocab[l] = i\n", + " print(\"Number of words:\", len(vocab))\n", + " vocab['<PAD>'] = len(vocab)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zj-hlBvzpl5x", + "papermill": { + "duration": 0.035449, + "end_time": "2020-10-20T14:08:02.936000", + "exception": false, + "start_time": "2020-10-20T14:08:02.900551", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Extracting Sentences from the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wYUmK0skDFU7", + "papermill": { + "duration": 0.033405, + "end_time": "2020-10-20T14:08:03.003298", + "exception": false, + "start_time": "2020-10-20T14:08:02.969893", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Next, we extract the sentences from the dataset and create (X, y) pairs for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "J_iN8EMIWyNM", + "papermill": { + "duration": 0.047324, + "end_time": "2020-10-20T14:08:03.084165", + "exception": false, + "start_time": "2020-10-20T14:08:03.036841", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "class Get_sentence(object):\n", + " def __init__(self, data):\n", + " self.n_sent = 1\n", + " self.data = data\n", + " agg_func = lambda s: [(w, p, t) for w, p, t in zip(s[\"Word\"].values.tolist(),\n", + " s[\"POS\"].values.tolist(),\n", + " s[\"Tag\"].values.tolist())]\n", + " self.grouped = self.data.groupby(\"Sentence #\").apply(agg_func)\n", + " self.sentences = [s for s in self.grouped]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OXZjM3UeW3ur", + "papermill": { + "duration": 7.236033, + "end_time": "2020-10-20T14:08:10.354445", + "exception": false, + "start_time": "2020-10-20T14:08:03.118412", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "getter = Get_sentence(data)\n", + "sentence = getter.sentences" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_ZKrFo7cW5RX", + "papermill": { + "duration": 0.196933, + "end_time": "2020-10-20T14:08:10.588222", + "exception": false, + "start_time": "2020-10-20T14:08:10.391289", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "words = list(set(data[\"Word\"].values))\n", + "words_tag = list(set(data[\"Tag\"].values))\n", + "\n", + "word_idx = {w: i + 1 for i, w in enumerate(words)}\n", + "tag_idx = {t: i for i, t in enumerate(words_tag)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yxWy9E-gXPRJ", + "papermill": { + "duration": 0.669432, + "end_time": "2020-10-20T14:08:11.292061", + "exception": false, + "start_time": "2020-10-20T14:08:10.622629", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "X = [[word_idx[w[0]] for w in s] for s in sentence]\n", + "y = [[tag_idx[w[2]] for w in s] for s in sentence]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UM6bvNKkpyYP", + "papermill": { +
"duration": 0.034986, + "end_time": "2020-10-20T14:08:11.365543", + "exception": false, + "start_time": "2020-10-20T14:08:11.330557", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Making a Batch Generator" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jO-C08uzDqDf", + "papermill": { + "duration": 0.034216, + "end_time": "2020-10-20T14:08:11.434628", + "exception": false, + "start_time": "2020-10-20T14:08:11.400412", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we create a batch generator for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kLaPXRDtXe6E", + "papermill": { + "duration": 0.056187, + "end_time": "2020-10-20T14:08:11.525386", + "exception": false, + "start_time": "2020-10-20T14:08:11.469199", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def data_generator(batch_size, x, y, pad, shuffle=False, verbose=False):\n", + " num_lines = len(x)\n", + " lines_index = [*range(num_lines)]\n", + " if shuffle:\n", + " rnd.shuffle(lines_index)\n", + "\n", + " index = 0\n", + " while True:\n", + " buffer_x = [0] * batch_size\n", + " buffer_y = [0] * batch_size\n", + "\n", + " max_len = 0\n", + " for i in range(batch_size):\n", + " if index >= num_lines:\n", + " index = 0\n", + " if shuffle:\n", + " rnd.shuffle(lines_index)\n", + "\n", + " buffer_x[i] = x[lines_index[index]]\n", + " buffer_y[i] = y[lines_index[index]]\n", + "\n", + " lenx = len(x[lines_index[index]])\n", + " if lenx > max_len:\n", + " max_len = lenx\n", + "\n", + " index += 1\n", + "\n", + " X = np.full((batch_size, max_len), pad)\n", + " Y = np.full((batch_size, max_len), pad)\n", + "\n", + " for i in range(batch_size):\n", + " x_i = buffer_x[i]\n", + " y_i = buffer_y[i]\n", + "\n", + " for j in range(len(x_i)):\n", + " X[i, j] = x_i[j]\n", + " Y[i, j] = y_i[j]\n", + "\n", + " if verbose: print(\"index=\", index)\n", + " yield ((X, Y))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_xtaMdPWp8NW", + "papermill": { + "duration": 0.034404, + "end_time": "2020-10-20T14:08:11.594978", + "exception": false, + "start_time": "2020-10-20T14:08:11.560574", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Splitting into Test and Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RWYE1ndgX2up", + "papermill": { + "duration": 0.089107, + "end_time": "2020-10-20T14:08:11.718853", + "exception": false, + "start_time": "2020-10-20T14:08:11.629746", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MAeHnfnjx-Am", + "papermill": { + "duration": 0.034597, + "end_time": "2020-10-20T14:08:11.788761", + "exception": false, + "start_time": "2020-10-20T14:08:11.754164", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Building the Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "30w3W3IGzsP-", + "papermill": { + "duration": 0.038502, + "end_time": "2020-10-20T14:08:11.869814", + "exception": false, + "start_time": "2020-10-20T14:08:11.831312", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## The Reformer Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ISjwwZJLx_5j", + "papermill": { + 
"duration": 0.035572, + "end_time": "2020-10-20T14:08:11.940351", + "exception": false, + "start_time": "2020-10-20T14:08:11.904779", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "In this notebook, we use the Reformer, which is a more efficient of Transformer that uses reversible layers and locality-sensitive hashing. You can read the original paper [here](https://arxiv.org/abs/2001.04451).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cLrrFjeuzxVn", + "papermill": { + "duration": 0.034724, + "end_time": "2020-10-20T14:08:12.010232", + "exception": false, + "start_time": "2020-10-20T14:08:11.975508", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Locality-Sensitive Hashing\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fjo8QzSw2PbN", + "papermill": { + "duration": 0.034683, + "end_time": "2020-10-20T14:08:12.079753", + "exception": false, + "start_time": "2020-10-20T14:08:12.045070", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "---\n", + "The biggest problem that one might encounter while using Transformers, for huge corpora is the handling of the attention layer. Reformer introduces Locality Sensitive Hashing to solve this problem, by computing a hash function that groups similar vectors together. Thus, a input sequence is rearranged to bring elements with the same hash together and then divide into segments(or *chunks*, *buckets*) to enable parallel processing. Thus, we can apply Attention to these chunks (rather than the whole input sequence) to reduce the computational load." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u0YsTmPq13el", + "papermill": { + "duration": 0.03446, + "end_time": "2020-10-20T14:08:12.150541", + "exception": false, + "start_time": "2020-10-20T14:08:12.116081", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "![Reformer 
LSH.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAoAAAAGbCAMAAABj6tHUAAAAA3NCSVQICAjb4U/gAAAA4VBMVEX///90dHTKysr///+ioqIiIiL///////8PDxBCPyn///87nOPq3mamQ2IAAADCwsK7u7vy8vLd3d38/Pzj4+M8oOnV1dX4+Pjr6+uWlpazs7PMzMzv42jFxcWRkZKamZqqqqrQ0NB0dHQWb6egoKBpam3JycmtpENpZ2ReXl9EREmLi4uGhoZ+fn0AKlp0Hj2Ojo5SVFM5AAEAAC84ltrz5mo7Mw8yick7OzvNw1hQTEaMNlFlXgBaUwAvLy4TEhIiIiIASXmKj5SXTWCOhSkeCwBnCjGGfiLd0WAeJDteVwBQ5b7EAAAACnRSTlPBTEpMSko3SvfnLm/tgQAAAF96VFh0UmF3IHByb2ZpbGUgdHlwZSBBUFAxAAAImeNKT81LLcpMVigoyk/LzEnlUgADYxMuE0sTS6NEAwMDCwMIMDQwMDYEkkZAtjlUKNEABZiYm6UBoblZspkpiM8FAE+6FWgbLdiMAAAgAElEQVR4nO2dC5erNrbnz8zcO2ctpa9AQJCA8HIwsct2t5100pXq7qrq1a879/t/oNlb+IXNFmC7yo/wT06Vy0JCwI+t15b05f9+GTToivo/bNCg6+nLV8byLMuvnY9Bv1EBgFO+WvEXee2cDPpN6svXjCvGBJ9fOyeksoSUL5Zk2JKVdESb0WFLz6MDAxbSgZFryI706YgZK+hAR9GpzphFR7TYjM6OcuiIhfGeS8NFuhEdMWx8ul++BlP9lLnLmLNMfPzDXwaqzJkoBKt+sHiRwS+VSXseoq2Utv4CAuaF+mAAV+O0WQW3BQ+IwCVnxWtBBK4WjC/IVB2fUxEXnC2mRFj6mrl8SYQFXNhkqpMpex+R2Sk9Tp1xzlnwRKRaPCWMz8lUvYjMzoiz6YSMGNL3POEye6dSnY4IALUFZBWEoykHTkP+tOA8Yx73mP7hrvj8HX77/HW14BMgkfP5K8Zb4W/ngwG0qJCpJTgVBiHFjArMAEAy1wggFeZwNmp+k0ELAFCQqQp7QoVZU/ZG1sJfAUAqTAGAKRWYAoCkcQAA36mwHAC0qcAxvPQuESa5C3eWUEgByMZ8FqFVc3iO8AnBU7y0PQALDuHwmjqIZww3efEMx6+mgCzc8NGUOudlRAO4suibAc+mIKsVJwMIIaOMCjQBKE0A2gBgSQWaAISQIKACEUCPCjQBGJkAnACAVHtBnAYgi7S9A5o8z8vx9mO2X/YA5HPlqYwDgA4+A+EiiMwJ2cuoCvhQDQA+BIAZDSBjLtRPGGCI8n28i24NwEpiB2CkI8vXKoCC4DJ6GQCkwu4JQMoC/kemn0TBWbpSIOH6WHOoA5gKDNmzgDp/LnsNMOCDWyGDBXxsAN/1Y0o4s3SV1YOrgrO7rwigXwE4xlqe3AHIVngasKm6Luh+MICDBXwIAMkiOOejKEqgEug+cStfYBODh/kEvpDvT7nFOSI5jeLn6Q7AnC/zAjhVfFqWzyvqnJfRYAEfAkC6EeItXl9X2Avgpq/vIzRn6etrtoL77E3fn/x3uAg1en9NXea8ethYhpvsr95fsB6oRq+v6QcPogwW8MEBbNYLeZ8/W4MFfAgAja3gY7lPNwPgYAEfAsCeFtCd3gyAgwX8LQJ4Qxos4EMA2LMIviENFvAhABws4J4GAG8TQFGe3acsSHeQkzVYwIcAsEMRHPNObqnOxIDp8qlDCv00WMCHALDdAsq3OeviVuDTlwVXfXnfrMECPjyA1UOUPN5+ffxYj75xGwP2ADyIcbLLzGABHwJAughWs6dJCPd2xKejKsNq/jSKFw7LlvCHNYdnHI2f5lDwhqn1miscroNDxnjPvNnzZD+zCKAzApLL8dNMMXuBnvwjhznz53HclIF2DRbwIQAkLaDgUzvjS+YVfJHp2p3k73bIuc/mbwz9VCWzeGFP4eYmfGx76B8j+MQKeAm/x1aBztOvzyA4AgBUfAz3FGKs3qWDnoMZh+8WVsJJkowaLOBjAxi86XN6cI/WTzrE2+gggK9MT1eS6DDNeMgSdL9CALUf1nIFZEJmihWTcQmKJUsnio8wK2EV430GnCyZ9ppekJkzarCADwEgVQT/x7O+oYCY2AA4wTvl7QHo8DC2yvcFS6Y6RLnP89iKZ1yi+ay1idPV8wtDfDM44DXA2AJS8sBSntrFM1jAhwCQsoD/qW0V4+kOwCkUofsAQnmqNWfJSocosXbSl0wFnK8UU1PUCxTB8K2DJfA7hqfAgW9jhr0Z55PT+ggHC/jQAP6HdnvBaUZbAEeVndu3gGvrtQHQfV3PTcTHr1ZvTNhaUAS/ycmbrLz5tcbpJKiO897JJ2DUimy8PMq0TPtDpmUWdzItU0+txMnBWwBjfDgxABgi6wvuurqDGsrQNYAeW2IkL6rmZRZ7NwgaIQIqgYLPqhhQS8TUZpjpxWmu06tJ0KwUJ6YnROAMJ6anREQ9MZ1MFQAkIgY4MX1FpfoOAM6IiAlOTKci6onpZHZwYjoRqCemU6nqielkql5EpooT06dUqjgxnb7nMnunIpIT091nPhthOSz4BvsVn4+wFaz4y2yFBW3OV8sJgDR7WgMon/l8oYvtd4hc7NJDRH1ICGNMET3Bq4bL03LMydfKqIy4XJAvqHsBj5xFdESb0WGJ8ujAlNl0YOQasiN9OmLICjrQMV5kTEe0mCGicOiIBQvpQF8aUnVzOiK1NAdz7XmB5ZEMN8WmGy5CX89ISueWE7q4KIfuosnRRgpclAMipVhqSHuR+nvp6SPiEArhbBHqAuBF4ynCRfHBKygMukeR3jB6StwFpEwDd4N+8/pgAGU+fbpAMoMeViSAanIJy6X4ZFh4cJBBd+CQOuiRNQA46KoaABx0VQ0ADrqqBgAHXVUDgIOuqgHAQVfVAOCgq2oAcNBVNQA46KoaABx0VQ0ADrqqBgAHXVUDgIOuqgHAQVfVAOCgq+rmAZRZSCmTDhkWWsymAz1hSJXldMScGSIKj45oM4sOdEwX6fp0xJjRYaFSposs6Yi+6SJN99w23XNiRtDtA8h/TwlnOv6F0C9T9vYTGTHyyIh/5Sx5GxF6C5ghO17JqYirNzals2MLQ6puRqb6OmL8r9SFcN/hPxL6J2dzMpAXjE+oU3Jl0RFX7PkXMjvE8gJ3AOAfv23WNwDg//z8XaN+/mHMfvkTEfHbNwCQivjfACC9BS8ASKbKveiVipi/scn3VMTfA4DfEGF/bJnrzX+lLgQB/C9CfwYA/0UF/pgx0/x6659UxD9M2dMPVHZ+uWMAv2nWHwHAv373u0Z9hwB+/21zxG/RAhIRf/erCcAAASSyYwQw0gASERFA+iLN6/3wX4kLOQ9AakUFtxVA6oEMAA4ADgCeIhLAb00A/u5UAL/7GADLDwEwHAD8eA0W8OEt4LbC6XqRwoWP1lstrH+5C2p5tGjaNOfXv+z+6acBOFjA+wGw3Kxt5eGSfs8eLkel
/x69apTcZ2pBPPtwoThH4QJmEXW/TtJgAR/dAkZ8k/ZYumIELXC9QCVji6c1gNSCeIcAunqZrPiySyEMdcBHt4AbAHN9Vo/nxwBKK4KPwge0JP5gnv5CA+jgEjKeDT9Fzme+lL5gymF+tXKvG1tSnbUq1mABfysW0KuW7/OOLODLmD/zd8VKrqq1y9kMvnjzNIA2xkrh7yfgD+QhwRYfwUcLl4aBEGJlwo4a6oC/FQsIWC0rWzVd+Xme+9OXCsAVV67gCxZtALThn1yt4Lcb4cqWNneY4An8K6ReVSvmoesuIN3Jq4Avuuz/RWqwgL8VCwiF8IrzDKCbrpcgX+3VAW0utgCu0KT5CGCp92+Y4hZfIZcu/lUBKNe7g+DyzqPTdmdYa6gD/mYsIEjZHMzeusysNULiHYBys9QuLp4PX4m3ilexB6BYA4gm9YMAHCzgg1lAR5e/2IvS1AreA5Btlsi3efGywlWg0zwCuTcE4GAB7w/AQneqKDBvBwA+YedKxqWP2UK+JrjavbPARoiCqt/GYrIjAKv9kiZDHfD6FvDvJgDrgT0ArEc8GUChQFLwlSe8EZz7AMAVj6CFWwAKc+FNgS+fZ8J51wBC0ziC1kmh1AIoXC2Euw8gS7gv7A9qhAx1wAexgGVVg7OYt4BfU+doJOQ1gMYJbplpQfgMu2Gw8jeTVT9gApmNMaKnd7xed8OsAYSmMB9PBwt4fQtYt1g3BSCTLkgPX7iy+lWN5q5/4S5H1fcQ7rK9A6tI8N0mov5GrgPcTVovQx3woS3gthy+QW+YIJAyOnF/mrUGC/hZreANSI/kjuU8Q+FctB9n0FAHvHULeMsAMlf5p22RudVgAR+9H5C6uhvRUAccLOBVNVjAwQJeVadNy/zOOC2TG6ZlGgE8eVpm9HHTMokbMEzLvJAk/54SAPjPHwj9ZcLe/0ZGBACpiP8AAKfU7P4pAEin6kGDn4g4e2cTOiIASF+kmz1R2RkBgP+mLgQB/AOhfwGAP1KBHAAsqFPixHQq4o8r9kRn524BrE3SX+z/MRHleDQaj8aVRisI3XwepWxOze4fjX1VizjdxoOvJywc7+tpskt2HLJadmp/jJVfi/iCoZs/5izYv4xpLWIkxmSqE9eqp7pLczwu2Gh3HSO4ju1lwH+OV0u1fgeg+CZls0ntlKP9P0S+l+qifoaEzXRu8H84x2ovq6MxsSX8zQNYV2kKLCy6xW2bpkrZ1LQrkEwzyyHStYw9nEVmec3TE4Rl6JryMuOkmiIkN5EsqbkTKNuw+anITNMoIvqmq4yqLErPyjr1v90XgHJh6tTxsllGPJ04MSXrFQnxfjKciWCnyyQrj2DKk4K6/euIYbJMbecwniGbzI2CZWzsuFJ2kjQy6Ccpvb+pmwV0XtUyNL2e3pIAW4Qzq4lc6Vhwx8JuHXB3BuDcfFHKXhZOw80Us5Z5KS0wCS/Ognlh+WKbuDJBu8uQH6bzIt7B6zrF0qZOJeJZmrfP6XLCZRAfpKGypcFwyqKgb5wzMxQAOu2gCVA3nzfYP+nFxTwIfeOrua/HAhCf4TKNjp6hTc0s3cWz57b50UuVA0xpGCkc8IbDO/awSxWFy2VlQWWU0gZO2bPGt6cpt342z/JdQq41Cw3ZEUVBX5s/M9ZrdPz0uIyGUuOwUHe9MoMr1Teosx4NQNykPUkOKoNq3mHz7YY7eiwBpcssyKzUXPoeCh5NsUxDO03JafumgrkxJ2WwzNa8OkFgiqoCw/sXzduvGixoWodKhrOD91VFWTArSq/vsgSPByDIKebhPiCZqXa+U7TshlWUBEWS2r7qdbO9NClSKJxEU3FmKphJKTuY254rsllsPPHM0Fay2iona9XrkPls/065UNudLfcrGj30kADCTQ/nO3vizTrGEuG8sVZdz0M8s4WrfBvrOqXXNUPZDJoOYECLOaBbjyTzNDG3PMi8OHZQBIb6HcMqngFPe9nV6obJ7n7ulxXYRoPaMdVR0KoHBRDNw7xY36aiuz+YlwQtFmGvuJNOWcyTwm6/+X4x3xo44djLWbqLJKwk6dDyoDIcpFlSGOL7c0PjJEu6m11rXY+R9ixcnw74T2fLDtdv0MMCCAdbyRKfjD/r83jLZWaqzh81NrGdu5yFOV35kTk0WutZgDIrmAWWI/V7cvrKEVj6ulAdTOcZ0XwpDVW8o5qdWVVt0V+XvtKJ4RI69rUY9MAAYk8BNjrTDt0lexLhjFreBkrfxsYmlKzZcpbFTaZAWMvmrhVsVCfFMujQBiDkRrNNR4hnBUu74R2wDM0vkRraxk3yZ7nKsPSVXpxB/SPvXWtt0EMDqGv3ibGF2CQnSBuNkpMaymfpReFsVhz0gBm7VrDlESZBFvVuOmpBVWwvO64TztPD3sG9ituRRJL1Pa2P/QvQ3E3mPftaDHpwACEKWJle/RuoeHZcDrc1NlHQNEmTJRBVPR1T1wr2FmlL65UFROltTbAj5AAg6Rc1a+tmKX27vGXvmRJ+mgYFWO247+006eEBtApthnoumymyeVmL4ZZzU1fvftT1qInjFAE5OI395TtUcOB0ljT30DTLzWeNw7ACmkXZevDvvOGPozOqbAnWujAwfZIeHUA1w9e1/rw7yUn2GwdOrbhrlRRxspwtC7+5oGoqmLGLe552bFF6xZKqOrrKWia21zx8sZU/67WMKLTci5nuaxGF0XGhvx4dwHA9CODmwdLgLNOgveEtHHXvZUFlHKQ5dvotocA69EagC2Z4zkGT+8Jh6kfDEAdywlkap4be907DH5uzAXxBsizWWTba1RP04AB6811BdTg+0iqVaTuBjc1ep1XhtmsFmiZZqrue18TgmIcxF0fuC0fyZ+0DNq69TPYHi+vqOvxRuRaktpUGewbT5FjTXw8OYL0PujY+0kV+Unidxoj3T5LN6+5SUuhRk6z0hMyLDmMe6PYA7ZJmCL2MLH335MBpymIWNhrTbsMf2NeyTMJcyCipV39NbeveemwAneXB8br21ScFO0l6tRadYtY4qIumMAmCjqPNWJWLoF2SRQeHu7thCJPWwx9wtbo6WE8iPLwpDdJ9LYXuH3LSo36szha0gx4aQDc97jiRcZIc+2tRypdZ1sXvb512tO5aaZKwgaewWC7Drn1oLhSAySzbc3rwt1Uxo/aqeA7Y0lrvYOvwhwuVgBn2teizNnc+mcZXeuqhAcyXTU2H7mP/VemL5XCXo6HlSze1ofSvCmYcNZktm0dNmhL1rK2Dp+r4LtQNFPoOFtvqYMvwx9q1YOtiURLVX3/eb3iJ1iMDKKnnBS2BWXt7RG58VGWrsyrDloehcN93RmBVeTxbpl39hqVjJbPCtpP2XKDs2WGyIkLfQf0poIc/BLoWJPsdQR7d+eTNW3vlu+mRATRN/VGt/p/5Xn1NFUvzG9825tHg5YXGBgrYqJsXncyWQZJabT00OPyRNN0jZQcz21PU8AcwHszSumsBdj7Rp1HUTJGeemAAWyaCmMdHDou7PCEngDHXJ1oeOhemPnAcNUm0N11LNyNkJ6rmBJBT7SrR3XQuVAc
bO3CaXQvcaElfsc6SwZj2kAZQGB+radIiw16Ghi8F2dqLz1msrReArRNBaDag0D1sbIqQKIdlft48D93pMk9NDl27PnFX5dg4Jj3fjcMfzszarw5W6ZUZmOH8qFnUofPJeK7OAgAjXCC1YWLK5hpH4+OwZLMIhfMKkdEY8+pxv1bLWuA6qc2aT4nTdFEfANVRVehYxPhIc7NDNT0TYS2D/Px5HjhqkiyPR020nKA+4ZJ2X1CJYYKlbrqKMp1tfAcB5mVSlA05bPB0aJC8xMDwl68+LzwvO145IeLrnC0adjqabZbW5zPHs3EJwPUygHwDIPX+zA5wVrxHr2YfADtOBPGzw+KTnuF4VCope55R5bxLdQlSQocuPWpSN0eNHSG6l3gZ5vWi2+jhsm0bKytNbM+n+7qx86lTxt2wh0c1oS9f02f8vXjDn56GAQpVpaTF/Sp3AKD+Xo8myeqMs7cqeqaX1i+4aABQv826gK5KaaG/0ADqMl/gDZEO7+FU2wNAr/Oh9fER1zK4veCkkd0jPxrz2JOxYKa1GzXZDN2V5IRL3UOz33DoOvtDQN5my6QkClCDp8ORzh8UAQAr9CDv6gmKUx8XI1/wxRIXHteHLCZT+CjYO1rCtemrA6hieQTgrCqa9QERrrY0r75AADM8y1KvjI4Lo68657YHgEUPdyNo723mj/iBYX0BhnPbNi7M0PIgSTUWzK3CEdhloj1jnLTZN3Z76J77gtHDZTP7A6dxLANI2iemkrhtng51nT0o8uWrx0frTm/4oBIoES2eKKFCvnY8XHBbOHzEbK6YXFf1NgAqvlpXW3gqUFsAV47KoCK4BXDOc2HzGAG0sbzHHRwy7qiSW91fou4A+kmvCrKIgyCSDVM+jhXhpBEZBQHd8gjPmedRCccjgMJl2u4tLSr3hXhmmv2B3c9ybTLXxea+7+BWXTwdaurjWNMkaIQ4E/6MbfdQrwrH59UuC+vtWxlaQKb3ihPAjc/rRTDu7fCmvSPW+8vx/UbIa7IFUOilyucLNhtZ+FHpZs+q+KA6oBv03TZb5kVqJ12cTkWY2Ck9EW075nG24mUYY5u3fc6jK6A1lVBVOnTiEwqTCvPaXGZXWUltsLi1v7NBxml37dLdMGDtuMOWuigMpgAgElVvhDjw3XjB0nVxuQUQxzifcE9CXrgSVGuEpNMtgL7etct12fKVV+lNx+Px23i9nVdHdQYwCvoXgE4C1fpOAC6X5AhBb58vUl6qT9Jx1MSeKRVl8yQrj48TaVIEOI2jcUp8ONssNNPR0+FQzrync3VNX75W24GMn9et3WJFAhhzl6+blhsAKze3ZWMjhCGudQD1sQvk1eHLAhR+DIBi2bsI1I1NL2ztN6lmmDdNGmlxRuiZnZoPLDZNlvOGFbrWcrOgyg720MyzPTuHRfkypTsOmZ5KMktz0dXT4Vjk+EoXffnPiqmM44ar8OF10Qgg7r4l+HJTLG8A5EumA1UjgKs5pOtiaS4UR0PtR2y2Yikcvu2n+ZAi2DQI16htY9Pcc7zrWlHhrDw4zOiM0FMNHSF61GRWW6Fro9rwh55gUk3YxenHy2DZ/k6IqEiK5cmju8buxxZ9+ZogCup9CigEWNfL1wDmWwDhg3zGFvGcb3qRNwAWaM3E9P24FbxALi3A2GFyAvROnnHv6gwbIe7TM8D5DE2aRQ5n7VHr6AigmPV8lZ10V6biGmlR4/10o3S2a3k49fX4eqxt1SpFdYToUZPZ4aiJSIvD+XFeHCRFkQSxV7Y3q1D2MoN296l11+McdBbUAbF/hE8Vq4ZEMrYGkG36R0bYVaK/8rebHs3XWwSDNQO9eruO6GolSGha435y8NAm2M+C5hP7eObVSAhury5f4O8nxdwpf+mc244AduyD3ginotfq5nk6Ox4fERYUVLXD9qap913byqS2KR961GRWbB26xJFbAK4QmRS2XQRF2Ik/Rw/9OCHUIE6rwMosONH2YyNEOdVewXBpuqorqjdZ+NUt9YT0qz5pZ7t6OkYBuXjweqPhKufrX67PPL9K1PE9qRN0dHq6T9vD9JzqAOlfuhvG6zAIt9Pe+gI7Odm8PorROOaBk0bcFmeE3upUFdsbNfGWtTaA0NM41m0WhSXrnntfs0S2XI/ESv9wsLirTh4UqbxhOtrPTRv44NwnnfgkdQMw61MnpnzeajaNNHBoOvyTxjwIqcLQnVcXjpqksyJIt00T6R24FtiJ50KpPaN7aOqGnNUHi3vJXp5EYA93rGjVp8H6IeoEoDPvXhqYRt2FPUs1DH46I5dDFcHy9Mr7cXZ6d4RESztLEnToUjl+2J9EslsYGqd44IINDQk4wdF6I8pOlid0ZVrzU/DoAaDfZZ3Rj1UXAN20dcnZ7aHE+gIbCZzem6eGMQ97lvlZcvqoW139O0JKbS89IGY+WwZlLZ9118CN+8Kxp0Nz7+DxMtStOmlQ5PEcUvPO1eEOo+4qnc1S6kFsnBFO7kA7OFc26zsMYSUe02N2Rew5AFiQ7ep7DbM/Dsbi6Ckf7HAqSTd1WG76SI8HYNc6FBR3baPuKlwWDjWTfM8ZoUNSrUKn075phIltadeCjeuMF4GVqwijHJYr9wU9YdhLzeuN6KkkvYxa7xVnHhDArn3Q7aPuW2+rhvnsMq87I/RoOzTL6W1FpYfupMcrRCrHxrmfgWH2m3DCdFlks/buYwXNnKZ1B8njew+KPBqAHQfhVNYy6l7vWjmYP9I05tHViZPIzqy575sQVugSXPuIiCTtJa5Zb9NzTfCIAhvHbTdU6oVmOl9a854iBj0agHYXA9g66u5G6bxeoxflfDM+AgVzk5NeNzf2Rpl3+TgU9rXMMysx7v0RVaPAsyDMmyyrLn1xgsm82Se/JhwsNi1DXZMI+s0UeTAA1bzDuwqlr/Ge6zGP46+r8RHDmId3WjnsmPbYqqtCJgNkRNBp8Ss9ajLHuSa1sbs9x+5qwdO8ZcsJETX4DhJ57DdT5MEA7DAI17a+AHatEB2xTpZkxv08TiiHcRiik+HEHuVgtu5R7jb7oxKOmgTzYjfX5DCXWKAf+Qoen55Yhvr4yLDP8ll3BaBg88ZJoGsp0b4jSMumVtjyoHcSzIsgM3fR9l5JsNmv6zhd9HoOtvPSHXJdArdpZQS9l4ydztEtS1brjB9KOnE6T82L3utlqK0Ocw31oEjH23BXANpZYDJxTlYYByWUaUtJfb/8dNk45oGBa2cEYUMN0GXkDYbala9P1SqF25KYp3xoCcfWu25uMhYaVmYRFmteGQHlRtB0SZYJ1bun54skxkVaXe072FoUgw2WHZvDdwWgWKamAcc8CUy74ua2YcqHiI+6VvZiqn1nBJw/As+AHB7BSSN2e5eyEx564dQzVP10Dma+aecXuv88SjPDAqaFcIIs11s8bVfocus5qCaYbCYBNK45UBbzqhvRAGq0DDvuQXJXALLMtO0eQx84+qaooKCnfMgii5PAb75nzrKse0pLP0jKjGyFijApWqtBIk0Dw2ofUNA3z/1lVpCQvb1uYdq2EBrOuvrraqw3c0
2OqgCuyMNltfWb19ynomxchtoxNTbCoOOqdvcFoGNsZtrGXlB4NmTxjQ+OnMgG5WRyVDA7EINeGyWE5AxZQWWB4WVSRWJly+P1KRl2cwRJQI14+UlAb3yM17G7Rhw1SZZB6BeNm59U7gsWmZoXJkFADlLiWxJ081G9LwCZ8aIy0ztnz4IiIwpgFx72kuoZ9JZQbzriXobwLUVglBWBaYtKho0PyA7J0SxIqZ1P44D2W3bTIqIn1c+SNDsYVMQtFwOo9zZeOiBapAH1yrsiDQJ6rV8XCO3UJ3VnABpbjKZST/qKXoFK+b7vUHVvCPO9oy4w5TjwPV2ECkWU5+twyA4ZruysWFJVKNPaWMJ0Bzyv8Yz5LClC4s2VxbLIqG6nOLSsklpdQcfu1Lt5ZwD+ZuSqSy5FT0v6BqJVvx2RT9IA4KCr6uYBdB1arqDDPGaIKKQhcDPhpUnKlKo0ZMdhniE7pos0ZeeMizRkx3iR5ntuusjm53vzAApOS4R02JTRYbz0DIEsocMSU6peaUp1SofZpot0MzpsZMqO75iys6DDClOqyqbDXtgzHUj0Hd08gJL/8dtmfcOF/Xsi7NvvJ+yXP1GBP0Uep8L+xFnwNyrwbynj//0zIe5F/6TCfv1/bPoDFfg/AODviLD/BgB//C9Cf1gw/mcqEAGkwv4MAJKdQDMAkGzscWEdrjC6VQwAkpNAX+4YwG+a9UcNYHOYBvB7KpAjgERYBSARiAD++rtmfYcAfkeE/YAAUoF/RQAJGQH814cAOAcAqRaQCwCSq+lZACA5Fvr02wLwm1MB/OZDAPx5APA3BuDJFvBjABws4AMC+O2HWMCPKYIHC/iAALT03kkAACAASURBVA4WcADwM/TJdcDBAg4A1jVYwAHAq2qoAw4AXlWDBRwAvKqGOuAA4FU1WMABwKtqqAMOAF5VgwUcALyqhjrgAOBVNVjAHYB/3/t5BOBhYA3Ag8ABwO4a6oCDBbyqBgs4AHhVDXXAAcCrarCAA4BX1VAHHAC8qgYLOAB4VQ11wAHAqwqnZTbCsJ6WSaDygQBSHOlpmUSYCcCfNYDNgd9dYVrm/NRpmQgguVPNHQP4J0oA4E9U2N8m7PV7MiIASIV9jwCSqQKAP/xKCADkVNgPr2zyDyrwFwCQCvsVAfwzIbSAf6ACEUAi6O9/4GyeEKsYeCMAMKeWOAAAn6iI2Yo9Z1Tg670CeOrKCKvTV0ZY0mHL01dGWNFh4cesjOCbsvMRKyM8sXc68F4BHPTYGgAcdFUNAA66qgYAB11VNw+g8j5nqdBKrudJ5qrWLfwuKH2BwvuExUg3usJFwgUSmyxVAFbLJzvkyu5R+w5YF4p0nIpFrvj1ARK2rZgb26dsPn+icquEW29b52443F36IqX1mRfJPOoCEUBs54/hNbQ5FX1Ghhi0OCXSkfLPBdCyEEDrUwGEC3Ss+BMBxIuUsXWJbd67yqMuEACcc9vxp1wYAEzeWs/AZ4ffzN/75JDS5wIoBwA/RgYAPY5PWPACAZS6pNbr+eMPAfUFrCksEUCxqTOodc5lVYuA8tuT8m2udwQV3vaY2R0COFjAD5IBwKKye0UCANqcv/lgzXCccJUyh2PpHFYW0Fr3ZYsJ5+8IBYbNXOa+L5657h73mZxDGO5PAr+nk9dL5HwA8OK6MQBn28Flm088b/XkVgA+BQDgk69mUDgDgBFfj15PVx58p1jIbZXzOWNvPFLCeVs4ko3eHZXA8QGPhcUHALvoNw/geLz5rOuANvCjWXtBAH0snH2WPPt800RG5mTgybeUoVmEo3XIWwJn4fjguF19NdQBO+k3D+DiEEDVAOA737OTT7jJreLoeSOhArkGcIlBi9FowQOhQbxHAAcL+EEy1QGrFq4fmQDkC77dfM8LV4CjwjAA0NoDMONpAIoVx2sbAOyk3zyAceX++r7YAxCRWtUABLiq/ErcWtHhscTGCfM3RzOe4F/rhrJurwyt4E76zQPI3lYQVECpuQVwNYH7wtM9AMFKTt70cIlbkRYDlHAVzyu2BvBlgfZwIplYOGz8LJh6vsNGyBXqgP4awM8birs1AMX72l9wC2AOf78DaFsAcSRE8IWOgYF8BM9qAr9f4XFxbI1APDhOYRjwLHDLpucLjYTEfu6f8v8pkfI1gOVJsfVZ+0aI1wB+4FU1XSQAGH3ijY0MAOotcXXXM9bq9D63Cr5wPCZ9gVvbSuahRVB5lYTw/cpAOL7ekni9a66T64OrMPgtxEXMSGSdqDA+LZ4G8MS4p0kD2FXhRU6pATxVJ2bBAODH6DJlih+dpDxLs/KUiCUAmJ92ziiK08LqHwteY69jXq0iPeEERyoFkyfHhTt70g36dACvKmV9at1xd9aP9XGC6hs57+yzVF70Ih8UQChFP9G/aXtWKEk//gyf6UXVJHwJLtdielAAfetT23gblR9vdi9rf04SvAT+xRJ7TAChKMyvcd5PME+XtT+nKcJWzIX0kABKss31sRIXfDCkLmp/ThPe30u9BA8JYP4ZIDTI+xTuL2l/TpR3uZfgEQG84O3pJ/9TxjPAzn7isEmzLveKPyCA8moPqPycmufVXrCdZHypwesHBPBqRZT8rDHka1Ux9uRdqpl3MQDl6Br9Hg26nn1Qn8XF9Wz8TvmFOrouByC/drFQCWtIVzq182lt74vZn9N1qa5+ANCfTwN8c2U2XWA/qptN507gszyFE3iJwB+rAHCPQ2cx0b5X+WyaYhQVrJLqPbCWfAFBXjqdX+3WaKMQXaejFs/tf94w2dr+XNMMbvpaz8zDl68+T7IVOqU+P2dL9Hse8SLj8CFE91L0OPX4KJvA74K/ZykPcCbIPJyuA8aVo2qsAfT4NJvzi6yIcIIcoa3QVYaqRO5+VhsEtW4E+Nfo7dyoGm1SZxbEX74GK/QyzVjGtV8qIKedALcACvYyggOfUgiEL+bPjCGE7mvA9Hym17lOSOI8kPELY1VC11AeXW+wXlqx0g/kk8yvtj/CvuaoXFUIl2fWvL58tXnhIDFPUz/3Qy6recJ7ACqeQcjimemQ9Am+QyPjOIKnEDCrFq0W2h/V0p+uVAmLLe96I6Vw5tjyc/uzasJof8qrDHhvBW975Jz7wkMd0J5ynrnseb0E6xGAThWAALoawHw9PWS9sOxKVn/4bLPIwnXKYKndSK/1SHwrjuPPtL94tiv3B1ZutOeloVvB0gZmnhdMStdlRwACVu46pALQ4VU1C8tpHVD9gRYwZtezgEoj8KmLju3JA/zh9J91dqmQ+Os6BwpP3/HzLvnLf/Al00sLFdgQETngFeGsX5vp6XJI4TNW9Ry1BVDXAdnrjE1xsrCzMYdQAx9BBREOu04d0KkQsK/TSYv4f5799W0LH35sX68drOBOx+dfM07LnBdjsGnymacpX+GyavAbjRt/KhZQJgOSq2IJNlIDGLyC9YcW73QdkKyX7JDQAIYyeFXMdzOIP1d5dT/86zwTiQB+nhO2Kj/X4jZI+lUWzqsGYD9gMsWlDpgMRwt8o9xwsvSQIi+ZZH6B/YDFZAn2LULUS
vzhBxP8nqliMtt0PTijQH+RXKteovGLrzZGFX+uEzb6RsefO7XySCo+vxpAjYRcy4ydLt0GuaK7+qXGpjpLoBG88vCTj2/BWa/d4wCI05Cu2S3mfD4MzvWGHTcS53Z8UQA+XWFO2Xly7Ot6Q6grOGHL6Cqe3/tyz7zvj+OOddVhKZC4jhP2tWcowYWfVe95HAB/o7q2W9a5GgAcdFUNAA66qgYAB11VA4CDrqpWAOczxp723FsWCWPEQkeDBvVWK4DTyWYN1EovI+Zczed50MOpHcBxHcDViDFnsICDLqQvX5Ujo0h3osoc10pkIpeRw5gX6ZGtLYCiOgoAdHPB/ByFX/jRbUyHG3SX+vI1febvHDf9cHBh6Cn6oD7zBQvw22gHYInOz6UGEHdomGpn6Iy5YzjuZbCIg04U7hXnMHeEfqcjhK+EfzZ+yKsNV9cA4maG2tV0DSBq8S71RCXByY2GBw0ya71RDXrZCxzP5Fa124cErgRuxLUGMERHffR63gGYoRerbo9kq6tew6A71ma3TB4zsai2a6i2myn15g1bAJNqAlK0A9BH9mT19bUmYg66e60BdAHA1+fc8TcA+jx1vHAHYPqM+0P4YgugwE0LAcBCf3/tyxh0r1oDWHJP768lNwBm1bdbAG1tFuW2EcLGTxjdfQ+umftBdy8EsPB8bP3ysecsNgBa3FaWLoKrjmgXtw4OcRuvCsCCh7gBDoCZec50fu3LGHSvAgBfcXtzXIfoifNslazrgAXnK9ygcDyqumHEiPP3GFrFC4a7sT6/YtXvFUiFuuJoqAIOOlG6Fbxx5lWCScncyskWvxQufoG/2Mbld/23FCgEz1XXd8oddLfatIJvR8kLpafQfSUDn0X8RAbO2ZiOWJ/94j2TR76y7IlUwV7JsGfPfyYDp/XLn9GnsOTzitKrtJ+osJcZI+OtnusNSOeZ0tMrK57I0AwunwxUEZ3q7QG4+Nv3hH7KXE6Ffc+F/RMV9rcJeyVT5fV5rZ7hFCz9nx8I/SVhnAr7gXs5GfiPn+qXP/43deQvoeR/oMRlSAX++ccR4/8iI9aXlHN4FDcLmqWzUdkcVk5SxkMiYsxVzKmIyZev8trLDR9o8f23hP4GAP6RCPsGAPw9FfH7CfuFTPWnQwCpA/8EAP775++++7npv38AgL82B333MwJIRfz1EMAfqCP/AgD+FyUA8Ecq7A8A4J/JiIcAUs/FAwDJEa8EACQ5AgCfqLD49hxSAcBvGlUB2Bz2zR81gETECkAi8BhA4kAN4He/a9R3GsDmsN99hwBSEX84ApA68soAOh8CoPUoAH7z4QB+MwD4Gwfwm0e0gD8PAFJB19JgAU8A8O8DgBfT5S3gN4MFHADsrsECDgBeVYMFHAC8qgYLeAKAQx3wchos4AkADhbwcrrVjuibtoADgJfTSRbw28ECDgBeSIMFPAHAoQ54OQ0W8AQABwt4OQ0WsAuAfzcCWA+tA3gQcwDwUIMF7ALgIU19LeAu4gDggQYL2APALU9NAG4gG4rgXhos4AkADnXAy2mwgAOAV9VgAQcAr6rBAp4A4NAPeDkNFvAEAAcLeDkNFnAA8Kr6oGmZfzp7WuYfAcB//Pxdo6ppmVTgelpmY1jjtMzGI295WmbwQNMyR/wnQhwApMJ+AgDpiFP2TgfWdzz1DKdgAf+FEF8yMuwX7kV0xIMnPqaPDDerMTZJhnTYmBki1heGcPgLtbwDAEgFvnAAkFr84QUnplOpvt8egHoPZKvpP8thFogIlJ4VN0eMrYjFdKr1N1dQx8GRuEEvJZ/ZdKBQdNjBjr85faTnWtTiA5blenRgzmI68PDySdn4aChVj+aEy789AAf9pjQAOOiqGgAcdFUNAA66qgYAB11VA4CDrqoBwEFX1QDgoKtqAHDQVTUAOOiqGgAcdFUNAA66qgYAB11V9w+gBAn4T7rXzkkH3VVm+8jdXFjfXdvuGUApPD+KLcvWsqwy8m9327CGzHo3m9k+cvWFxdsLi/HCur9fdwugcOBplr6nRPXSSSGUk8N30Q0+V53Z+E4y20fSieDCcmd7YVIoz8cn43S8sPsEUPj62TW8Z1I5pR07t7R9511lto/gtbIjRzVcmCsQTL8Lg30AdE0J9rC6Z8r1YjtXhqcmvdKObmQDsrvKbC+pyI4804Wp3C69Vio0gMLuRI/iztF3mb8J86oPDjlx5TJyHcsyXXcl4dux97EZ6aJumZW3kdle8roYOAmX77SgpQFMuaWPL/A2FD51bAOAkmeM2fEegAX/yBLFdeyym7mAZx9f2bDcVWZ7yWsHa3NkbJvfLQTQ5c8r/Kz0DCkeUsc2WUCkbTLfA5B9ZK1axT2ek3Ts6JpV/LvKbB+JqCt+KM/8biGAES8Rn3zOR4Wf8HEKVJWjBVhCdybscYDxZTbJvArAPIEftg2hRcSKXATPL4EAOO1JCrfQT6Fczp3FgjSkJ0vmdo8LxwiRffzKfJL6Zza/XmZ7ybHzfoWcb/v0nUAAJwv2DthYU75aRBP+MhZswWcLKF1dzqfpG3eZ+8yTEa8A1Bzyd8kE912eqcn760Qp/jYJcI5rBv+e+Hs6PZjxfL6UFfUu3ZUVX6eNqazyfjLbR6KPXV9LlhYZBwAUgEqINTelmcEiOEKLGHLBeKo30YbP8IVVAei+2/AdfLa562IdcDzDuBkeiSkxtnoC4t8v3BpxWioTzQK7co3KlXOSNXOvk9k+8kzWjBZ9PwBAREYiewp3r3cRwBn3HAcpxFqhBz8mWEnc1AEXU5YtZgVL5lUjZLypAyoe69ReFnDU9KIAupF1YhXJ+/yS7a4y20v+SVaAoXUnCkQA8OnJDu23xT6Ai2rVhpzhNwjgdMx2AObcXZXlO4O2cx1AD77ZAri6JIAyjk7uZySv/aMk4/J+MttHbn7qi4U3pbl+8eWrxxdjEJS3ewA+SwGCEnYD4ITtAJQcCl+oAHL1WQCKzo/FxVFx6dYAEHH+mWP/d5XZPnKjznVUfWF1jwsi9pevhV4dRwJ3uhtGA2jrpY4ctgUwwzqiv+mGeXke4yJC03U/4HjxwQAKq0uTGof7y3i94EptRPwc+9lbd5XZPnK7Gfa104W+sDJ39kaBGu3nl/+FCIHmAM772JLsaWpLd8VDawXQbQCU/NnO+AZAGyG1MKIGcM5tsQVQt4JH7JIAyg6PVHrR1jcBh/r1iHi0GQmCh3qpzLRm5J4y20du1IE/lW+dLiq/hBIudDsWlFvHNvBLMK4aXs7UY/l0IpkzWQnmFtNpAr9WcDfVBH8spiM8pDrNFFAWE/jLHVnwafSkxAobIdOcxVBY63XklmSHdk91eKQqt2PncDTaVX5sr8eLZPxJVau7ymwvRXHbEdK3LX9j8TbX5wpnN9idx0cMf/m/F8vhB8kt2x6HKvWja3o/4WFXAwzyc5qXbtlmvW4os73kN1ivmoRv502eMWxvZL7h9ty+O1be8uaJyDY5NMm86rgXn9LF5h+/4jW1Zda3o8/LbB95trn96wJ+hiNcx44x+LiAuHkAnZY3z7H9tlezGpBr
u4WX0AUym39aZvuo7Y3wWkey3WpAThx2JN46gC1XLgyDPDt5ehDPL9uPPE93ldk+cmNjzdbNu/RPixiNoHfwjt46gKXxypXdrdNM4tCEa310zaols9ZNZbaPfGM1SGqyuiSDnOb1V+vGAXQsU2iPgSsfjJP64HLtYpl1ILPihgph840TdudGu7J9qAbWrOVtA2gu05w+VXUP3r6PLdfaMttjFNX58Mz2Umx6dVSfJruCRki9EL5tAHPTu+X1ayrC4W6XOtjJumRm1Udnto88y1B1EP26jPDwWoXypgE0FkM9H6luWXrGQvI8GTOrbiyzfXRQZtYl+tZVgcDaneoDoJhTd9HKqChnvcWRoVJ/1J5vl2OJD5z8c1eZ7SPH0AJxOw181wTvor9XVvQBsGlOSKX0pf633HijT86ZoqRsOnJLx0Cz8shr6Sg+XcLQBXhzme0jY1UgP8HxzLHUnlfCZQAsVvW/fb4+g3fOuHpueGx5lyq6UJ7jeLvlOmTsX96qSM+PLMsqMo98UB+bWVFl4KOW+vAMBtAzmIitDi+MRdFeaQEAut56rFhFaLhc5YocPoioMmPS3wweKe4pfajc/nC9EhPWACr8wsMecRHz9c0Q8EtJmfebn7POeL1S5Uq8ELgSdDNr71Fxvdy24ijPo9jaDlIqK7ps2xLnnaNbi/KyPLLs5qmyH5lZnHobYQYqz5MPWGWhqgFWN99RNScKY+2wkvItGy7Mj0prO49dWtGuuPjyVU4553OGfvicPysm+Bg+jCz4sYKzxfD7rbJ8io/gj4BV3oHodyWe4IukAjDlPtNpZQx/VnvwzZ8Z4+hf/dz/zjj71lP4MVyI7/sR/PaVsWOAVTPCfbE+p4S7sJ5ICB8uaSYcK17fVMwsrgXQMCDqflxmXb+2OoH0otaxvr5S0ATGNTgsvPl5jP4u27C8rXyD19NXa79UiQvkVLnzrB256JAqWcRtZkMBK16XAOBIAnbPAgrSWE82cscvOg3FFy6zcIoS4oSep5NXPMpGADOcVTKCvyNIJ+frtziZAoCvHhzUv0m393qp0s63NxqecxaYGzfqcAtCdz07VVoXdDQRezv9rTMrGuZWtrVmycy6bZlV1tEQhDjcffBc+b7cX4MD3o/tMg5tfeWQl4OST5XVfYrCLboIICSTe1haSpZwABBvGC/gZO82W6Ilc6oqXVUHBGu5AdDTM9mLjBWTrJq+hM2b12BXB1xOd0n1lNqaaZkfvNeuHRunaPpND86xseLhhBfr3PDsXU1ml1koOOs3va2cOj2zzVPNTpuQR0la8WHFAgy2rne1jDw2T6Dz9O0RdrhJ88tXKEanBf5pYcn5CgAiK5qaN5uN9E6v1RzfCsDRYgugs1uPg2M72Oe4BSyf7QBMTgdw21aHR3rwpkG9WETkm+5GzUOTeq6Ga2UXKoNrQxvOrmNBlvXGq6kSb8ys25JZnyigu80J6CgvKY/PIvWgrjC3QIhxKqmnhkTZ5i3BVrCTvQJJFk8jZ/F2COCzXnlwzwI2A1hghJyHeLB/EQu4MfQNQz3appBjWxHVe6E9jZ3iMgai3g++3151y1rbwWwAT88sxd9FCRRFs187Dmz7xiFgOns5EKiyzWv55X9rYw/l6gonvgX8AMBgb0PvdRE8q6axA4BK7zYeR1gHjKEMVpvtty9gATf9at4xZ+vyriEERT7Syh1SZhcpg+tDG/VOwFq3nzK6CLZmlrxtnqGBckK/N5VQRpQzwqYJQzmG+iE699vF+sK/fF28ecrnIdT2HBEfFcEenysVvuubqPhKiUK3MWzhvHKcjuSIEuqM2ApeclzRw1feMzZd1uc/3QKuS66me7npH2z02jz0N6sJU4uLC7QTD7ylDtoZ+5k2dWaekVmz58OF/H6k1TSNaJ2BzDTnx5g9fEGddH3Al68Se11m0NZeQUVujBYQU8Y1OXCqJhgzUAWD4iF8xtA5/MbmiTuqvgheEddn18UelznkeczfdBTdDROsk+ol39/k9Sj/29qfczxKLs2Drp7leskFmolO3XL5B7ncDd8bxxHOyGxkdoFq7SDppChXdE+kbXqRzUM/0HyWwfoIrAPK9YRA1Xyxta83LfDNK+4eGCh3nVb7sowtKnXCTZ6QaofdcTus7c6XvhucX0M6JKc8NNPbrj/j2p+nZ9Y0SolyLzGpBF6Awzdr7wyWRefe6D/D9GMN10XirXrDSF3DaOxp2uugPgpvdeNUtnuBSqBfv/fyqDq0rfk5BsbOyGzZ1pJyzh/xcaEKcPRmbaUsQRv3thESeD+i9bTxWwWwshyNDa1o7+YfHmBumaFKr6RcdzrKPSpXG8zc5hFEBlJMcFYqvagxsy09IKwD3O1ChmlD6kU05Kr1FXdyL6ku4TYBlAKuQTYWJMqN1W41ldqTcBVrHwbwYic4z8nEzaHtsPd415mtq4Krntl6KqdnVjhOJNoINJHfTZYnXJvYeEa4uUNBrjCsRdBIWVZXf5sAshjkCK88fnoe1D32zFytjPBjy+SVpyUtZ35m/TSy4tqtx8wesOQKXQs6yGxNvmUZW8goaXkNmXVx05G2l8jc/90qqeCt8q2YGO0QEBITLXzPiqMOtj2pntyNAohPzrXi+HiZUWnF+10DtWJMdXgw0D6YnVk64RLdB38fFUauVc3DjsnesjMy6x+OHjdItLQD2gQ3P3Lq93pfJfoTNEPuGi56K99JK3pvFECJtzi3mvy9o1o/f626Ia12/yCW58GZ3bTCqtdxZAMPbmWaI3pQ4ozMeh0WgDtuGPUTQO4DZlRh6uA7J5urooaL3srLs+qlvVEA9YIcXqOJqDtB1u9zu+lHm3n2YNwhb02rh1T1IJPH5umZJR58XWd6xYCBloaVHnRPs9u8XHTb+hA6+TisbtqtAqg8qjtd1l4vtwZglw3KPCs9F8BDo6QarJSvTyINtsDptExCY2a7XMCZG49INH6G7Ou3J2q04apD9qS97mDqAaAg131XrVP+DT3qzdKtj2YTUT9X7UXv0rqQqus+ep1FNHR1kCEjHVI+I7PnzifRz5TOow4hmtodTu0qr3cruOTUO2W3Tj2yeMsBjeq0IOctTNx5RHW5r713Bz5WDwAjGsC3tozEJwE46PH1SQCWA4CDGgUAOk+cpy4wArWNAFd35jZ7tuacb0aBvBfOCxcAjN/5K5T60xS+HAdM8HjC0XkLAfRwchLK5hyd+h2eP3OOldic81XWHcAInVqb/zkspILsUDp0YMzINO2wXpEWhlMwP9S/m37kjA4LhaIDD/zUyur8DYeGjhvalELXIQPDkpHx7LBuVVQYwpma/7GcDvNZSMcUHh2G2zTMnZiPcOMt3JoL2t7cYc88c4q1e6niIz/iMwCQ5/4KN0JC1KYJxOC2M+ceAOgKPqquIIPvcIKSw3nkj7kLH2aOxbsDuOC/J8Qzlwz7PRc2GfjTlL3/REas96F4hlOwgP/PXxv1P3zJeHMQBnoRFfGvvxzcmfE/yVOEkv/zx2b9k8uQE2E/wqPhdGC9pefw8aRZ8ORnVOCYp4xPiYgTrmI
y1acvX5MnVu3A8Joy//UtAhPG2CsyxiuHhRRvkcWF3r8LNwfZARginj6zX7338foKcO9WBtZQu+sjyyNM3+4B4PffEvobAPhHIuwbAPD3VMTvJ+yXP1GBP9Ub6B6nDvwTZ+k/fv6uUT//I2H8VyqQezmnwn796QDAH6gj/wIA/hclAPBHKuwPAOCfyYj1Pm2HfFAerxafb1QAAJLdPgDgExUWf/lf77qkhfI2W7E0KEZstAQAC9wxpCqDJ0v86VZ1QAd42wGIS+QjgG9Pu3yrcJaAPdAz6bQxDVivRggA+E2zNIBE2B81gM1hGkAq1QYAiQMRwH9/97tGfacBbA773XcIIBXxhyMAqSOvDKBjAjA5FUDry39WvspgQz3uch+4QaO8B6D7tK4LNgOoLSB/4puqjMVfFnMNoFoDiLm+BIDfXhfAbwYAPwTA/73CJoUEgORbglvThFBvq1nAxRQPcPYAnMEXkxqAIluvG+M+ob183wPwCVdd6NEKPskCfmsC8JsbtoA//+YB/Iq7skI9T+BMD2hIJLo1sQ+g3rer4HIL4PgZiQ32AHyXbPVarSSDMdW+BSzww2ywgETYZQD8+x0DyCZ8Uk07yrHw9bdzktx1IwTapdN3CCo3APr8fcrfltX8ucoC4u7VVTMk5E8vnC90a6XaXe6Jj/nrYAGbwwYL+BXAK6rZnzLE/mS9aIKeDL6dEV4doDBYYLBTFMrJ4XgPIynmoUu6s15twSlCEcVw4OZoaQe56D4pbrCAvzkAb0uDBRwAvKoGC3gCgHddB7wxDRbwBAAHC3g5DRZwAPCqGizgAOBVNVjADgD+fftj/XcNwL2AIwD/fnDEAOChBgvYAcAuFnAL2WABe2mwgBcCsNkCPiyA9PSAnhMHBgs4ALh/QNfpbNmKmsUymndMotJgAU8A8GH7Ad13MuqBsicKwPGiYxKVBgt4AoCPawE7T3rMXqiQ0YUAHCzgYwKYW16SYC6yMl8o5tqzQjGVomdBaDMb56976QwXg/BwqTCFi/t6xSzc1OxENsNdHrKVCpYx1PgKD7DNfOYUMpuFbgWgKjrPTB8sYB8AN03dOwaw4LxYoQ8W56+Fcp948c49dNHX7laTBToETgv+m7xpOgAADVNJREFUIlmJrqroYeXgF68VgQqiv0GEjPNgDBG0Q7WE6BbnixnuAQYAyqf3zi2RwQL2AXDD4V0DiNMxIbqeVxQgZNMRS1foR+9iBU6gX6DgCYsqAD32MkLwKif8CRa9TyMA0MOtuuDAvAIwxg/o7QqNkKen7i3hwQKeAOBdW0Bczx7dSt9wqsdzIJQquIsOzYu5bkHEeuGNgu8AXPJIrXtXqq1CfBunNDE2f9sHkCHDCizg4r3HAieDBWwH8O/HHO0BeBhaB/Ag9AYA1KcEjhBAxStJ9hwynFYOAOppmbi6yxZAGXA+jdbPa30BGW5oOOdNAHL+3GMBlw+alkmm2mta5r+N0zKbw4zTMr97jGmZJwMYry2g5KUGUGx3tcyecrR8AGCo85TtAYgZTasZzYqvn58JQGcza72LFr//ntBPACAV9j1OTKfC/jZhr38jIx4CSJ+CpX/9gdBfAcB/UIE4MZ0K+/chgH+hjvwFAPwDJZyYToX9CAD+i4x4CKBFLKEAhmo2ptZXmAKAGRWIE9OpsNnaAoYAii6Cx8gq7kWt+Agn9EILwtOl7GoEHxycOAItFHSwf6+87PWkt2C0BZBVNcZ9AOdwWd1d8gtqgv1kYrvk7PvJVJR0YMIWdMT6E1B0KisWTsaEJhlbUWHjqfKnZODBuxmQp5jGkk5lKi06bykb0RHrSwB6U0pw+RkZOA3ZakIGqpxOFQEc+yGfATlLbQAmefyGPXdTPdESe5FnPPTHWPLyFRwJXy955sx4lfOIL7U11B3RCOCIx9HLBkALAMQkNrM2719u439nB3Y+xaMJAHwGqHBu8FSX8B5U2TJsX8Q4HZgt0AzCASO9QeUTX5Vv8Cl8AxbXKeQv/Ak+h9iGTqFJLBec2yMbyIQvolfBZgj24uVhb+Ggc7RphAwadBUNAA66qm4UQNNGQF36FB3P+dAi/+r5M52kSwZ8z1gp/+Dk93R7DqlaTi3/taftxvWH37xVRR6bK+5dRmaE8qkkHPr+HuavUTp/56ikMZBdNmO0zTt1xobdsDtsweS2JL+nGwWwtkOIrHWUiINdWvKG5+2WPiRg2gehy/2JwzHFkmGTo8P8NanK3xmSpr1S7XbrIy3fdAdM10BsVH+YfPtNqHSbAIraHkn1jU+8+r5Esul+SAt3cozpm9BpU/vSGlH7lTfu4dSYv0ZV+TtDppP4VvtGMSKmdwKukidNYG61m3hVGpPf120C6FjW3lZncW3rp7y+x5rTZIyE5UVICb3ZaIeHJGNrQRxUz19deYc94HT+zlHUsD3dWriVYav5cSLfZCZLejdq17Lad+Jy8vYNM9e6SQDdyC+drYkStS3z3NiPvT2wyqYt15wc3kCHfkhISetdhPPOuuTvIOwgf43S+TtDsowiagcbFUW5076Za/MmR5UEJnFO8qVTdt2n6TYBdPf3evTjeG8TVunKfeMFlETHteJYlR7uokrdBtyIM29rzTmRHXTJX10H+WuWzt8ZkobdWCVTrc0EaSnTXoYu88hX13TqjsnXdJMAgvY25BXyYNvk/f3qBW7KdggDbicvpIAHQWAihGrfzTTyUnpvddOG5GVb8aPz13Z6sxo2ad9KtlZCndgUn5k3e+2QfNmS/J5uFcDaLmQHT1vVTIxzXOEpfX0DjZtKtu6VKmw5sslQ04bkqs0ErvN3jlwTBU6LjYLILXVQ1/SC+C0NKCDU6VzFvVUAa1dwuPFjbRvv4wIHAKgamaYNI1ubobkvJnRFzXiHiW3GN9rm7xyZ3gDZ8nYBoKboKJMVbzOBAGhrIbDVrQIo9q3XYYW5to/rkbXHjUarKpappt92F4UtnAn9HIXJujbvM3ucv3NkrIl5RhsM2XNNVQidgukNMW8IDHdOtiW/060CWHtERyDt34Gj0iIqN3gought3VY59pm9MFX0TAgZ097l7xyZETJthi0tv70hYUaoNCbvmPGt62YB3K/HHFdY8r3O0IPSxIeHv+7mN1aUWGTqLfMtV6apoaJmrmfldGftXv7OUm4q5yXdCe8iPe1FpLGMlnQnlsTk4+7enzcLoLtXjIkjS+aWuydcr4758Opu45ofEj2awTxIxZmZ3LhdYzG7n7+6avk7R8poRcnBNMiZ28UCm7sJBEWgxOTNWavrZgHctxKyoU2Wbx/ifiVQllj4ba2T+UYLm+LTg8Rda2YcUGixYnkjZAf5O0tmKybjRocCYUVwt/IOI7VmKyatsunWChs3Wo969LLfLoBy7wk2Vbgc269uwV5B69nomuDuGoHmuwhmouk1dx0cKfayxGinZIsZ2+ZvT0f5O0fK3Nfm5vaRy5fra0cF0aWN4Jl7iiD5o6uQVfJ9DOANA7jf3dToWyFKyxd7odKLqye7Z2BauuRE3NBTA+lisnFWmO9jW3fYNn9rNeXvLLX09s
DFx7VBM+lYlVVsi1iprSIHyXv7yQvHKnskv9YNA7jXm6WaXdxUaceOkkCZVE5kx9XxtU6Qtpvh2FHdGgjf1l6Anp2Z2tCsvbdtkz/hMpfM3zlqN2RebJeekpABoZzYXr9t5k6arVRb8i4kH+0lX/ZKfq0bBhAY2FxJUyWwCvDy2LJnmRXvBufLfXPZehcBuNJZOzhL4ZVrIKUVZ20vstde1FT5s20yf2epzQYzNEvRYQbaX5y18j7J594m+S6ebjvdMoBs51ZrqHC7UvqW3NVXDvrg2h+S9CJ8PlFUWmCwxCZanLUaqk5uv5A/EJm/c+S2u/Q0ZMDUR1iTqZvgAslXumkA3e0dMPds7rcH1EEB53a5i1CEeJ6HheXmfJYTto+nd0r7QIf5O0unlOY9XoBT8ur3HOW+aQB3jsstYzu7LhFx5I9+yl2EYiQKO1iXTo7VBzEuOkPf6z7mtY3R43acknzPGLcN4O52mQfPt3w2dZD2L/Pc2Fd2pwfV63EyQwfuqfJ7+nWpfq/MByfPbh5AaKVumlbGw9Y1PdnY/28YF2uUG5WsbPerr+Wvm5rzd5byXoiovgY47/X29k7+9gHcTMIyD+syVx9FjT9FvQh0y5h5cdfhzC6TxDbqMmGut/wehV6XCXOHyfcA/ITkbx/AzVvVMvKF/fYeObTWh0AcxoXmZec3v/tbT+fvLDmdX4HuR+7J7xyp+5F7un0AcfhStve6lr7h+t3uBZWIS8nyqIe/SpW/dp30fLqoI9gy6t2k6Jf8SdMM7gBAJvWwo9GzhaHviun6u778ys7RnHafVLPNX1vClmG1gTMl4/apyMBRt/fkWKJb8o3j6q26BwCRiFIZ5/GI3LLNHVCqi52S2kpJy+s+p2GbP+MBIrfJZT4uIcdumeSnyjM8IFzf9j8q+fsAEMnI6ZJA+DiA1jLEBUVEW2XNq6xUmbd5ujTlL6KjVPn7UEH5amBERQ2uOX0kPiz5OwEQH3G4bLxKdzN+a3DTraQsyzOYIbV2VvHjUzyWpX/ge3KUvw8WMBI1Xp704jPx2ySvGpN3YuuM5O8GQOxqKazIUXvDjug9kNvx5t1sHZdwPcsiZvxLB4J0yo4lOvnLHSeBpJny99ESTmzn3v7KceiFU9pl60oGnSR9InnvnOTvCEAcHIfnacWR76D8vLTQk2X3wNtH3bQ58g6BEB76SlW3Ecc2+nj01hNyjPn7cLnCiWyrzNcZQE8VcgmPUySAt/3kbZtcwqOr7gnAypNTKM/PUb7jHT7cLiNjWGRYke8pgVLQ3rCsrROM7lfu4GZlkCl/nyFXeM4mA+rytncvee8Syd8VgDhX0qxuvS2ucvwSMATFZQ6l5jYE+evp0DboLN0XgO3uRz1GxlyQrJsoHbunQ9ugs3RfAEILoa1QO2E8fCsf+T6vAB7UU3cGYLs/u1tNPDxBssRewP4ufoPO0b0B2MFNXEYdRo6OVQ2VuK3VzEEX1b0B2MnB2envdrKeMcuictjR6VN1dwACXe2N/07D5/taD8JdcsbQoE66PwCZ38W572i+r0liM0zcfw7EoDN1hwB2KybRfaEbTbuh9P4zGgadq3sEkMWdKmoitzus1e7tFtAY+LuC7hJAvcRYB6GHgW8YDnOVb+9cWAb+rqG7BJDJsmNj1VW5beVH3gdMj2nmlu3vPIz6zC4adDHdJ4CGBSCPD1W+ZVklDp3rRSQkeiD4pWXF/r5/28DfdXSnAOICkH1mbSjtm2TZ+j90mPJUHWD/kitmDOquuwXwhDmGcq3j0rvHrLlBl9X9AnjBabZ6Kuagq+iOAcT1TS9it04YuRt0Kd0zgDiAe37L4UTfhUGX0V0DiC4shr3rO8npuK7BoI/RnQPouv5ZM75FfJkV6wedqjsHkFWT8k9EUDTsZDDoc3X/AKIvlXUKR6Kzu8Kgj9MjAFgh2LMmp/JPWa9gUIseA0B0amlbP2dfuGfMZ61XMMioRwFQl6hxp2n6EpfLOGs5iUGX0+MAiOtuRLblKyNawouswfjdkB4JQKZdX2KA0FPHA764nRQunHLJtVIGna0HA5Dhkk2eH1nVGjqepxRuQePrdXri2iocg25CjweglivR6c/PoxKk1wlS11goaFCr/j/2lR9rF+C33gAAAABJRU5ErkJggg==)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WyRZCUtO2Dbm", + "papermill": { + "duration": 0.035247, + "end_time": "2020-10-20T14:08:12.220409", + "exception": false, + "start_time": "2020-10-20T14:08:12.185162", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Reversible Layers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xAdYG2122Jt7", + "papermill": { + "duration": 0.03461, + "end_time": "2020-10-20T14:08:12.289666", + "exception": false, + "start_time": "2020-10-20T14:08:12.255056", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "---\n", + "\n", + "Using Locality Sensitive Hashing, we were able to solve the problem of computation but still we have a memory issue. Reformer implements a novel approach to solve this problem, by recomputing the input of each layer on-demand during back-propagation, rather than storing it in memory. This is accomplished by using Reversible Layers (*activations from last layers are used to recover activations from any intermediate layer*).\n", + "\n", + "Reversible layers store two sets of activations for each layer.\n", + "\n", + "- One follows the standard procedure in which the activations are added as they pass through the network\n", + "\n", + "- The other set only captures the changes. 
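+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make the bucketing idea concrete, the next cell gives a minimal sketch of angular LSH with random projections. It is illustrative only: `lsh_buckets` is a made-up helper, not the trax implementation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# Minimal sketch of angular LSH (illustrative, not the trax code).\n", + "# Vectors whose random projections share an argmax land in the same bucket,\n", + "# so similar vectors tend to hash together and attention can run per bucket.\n", + "def lsh_buckets(vectors, n_buckets, rng):\n", + "    proj = rng.normal(size=(vectors.shape[-1], n_buckets // 2))\n", + "    rotated = vectors @ proj\n", + "    return np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)\n", + "\n", + "rng = np.random.default_rng(0)\n", + "x = rng.normal(size=(8, 16))  # 8 vectors of dimension 16\n", + "print(lsh_buckets(x, n_buckets=4, rng=rng))  # one bucket id per vector" + ] + },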
+ { + "cell_type": "markdown", + "metadata": { + "id": "WyRZCUtO2Dbm", + "papermill": { + "duration": 0.035247, + "end_time": "2020-10-20T14:08:12.220409", + "exception": false, + "start_time": "2020-10-20T14:08:12.185162", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Reversible Layers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xAdYG2122Jt7", + "papermill": { + "duration": 0.03461, + "end_time": "2020-10-20T14:08:12.289666", + "exception": false, + "start_time": "2020-10-20T14:08:12.255056", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "---\n", + "\n", + "Locality-sensitive hashing solves the computation problem, but a memory issue remains. Reformer takes a novel approach here: it recomputes the input of each layer on demand during back-propagation rather than storing it in memory. This is accomplished with reversible layers (*activations from the last layer are used to recover the activations of any intermediate layer*).\n", + "\n", + "Reversible layers store two sets of activations for each layer:\n", + "\n", + "- One follows the standard procedure, in which activations are added as they pass through the network.\n", + "\n", + "- The other set only captures the changes. Thus, if we run the network in reverse, we simply subtract the activations applied at each layer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cgA4DL7g30bG", + "papermill": { + "duration": 0.038825, + "end_time": "2020-10-20T14:08:12.363527", + "exception": false, + "start_time": "2020-10-20T14:08:12.324702", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "![Reformer Reversible.png](data:image/png;base64,...) (base64 figure data elided: diagram of reversible residual layers)" + ] + },
pUtybuU/exZ9Yd/W1dy5tqQ5diUJANLS0q5evWpgYFDs0TIMExMTQ18fPHiwevXq+SeA9PT0evXqnTx58qN54n1lYxbQmTNntm3bduTIkZMnT3br1k2zPD4+3srKqiCzgNLT04cPH66tY4/juNOnT4ui2L17d82ANtIRenp6x48fj4iIAIDFixd/88031tbWH1o5KytLT0/v0aNHzs7OiYmJpqammo/o1IsycUDJ0MyZM62trQcNGqSZwyMIAsdxQ4IeWNatIwkiIYSwBABYgEvLV1ydM3u6JBEACUASQRJFIIQQspolERERVatWLdFovb29MzIyhgwZYmtr+9GVR44cybLs+7N9jhw5curUqR07duSzbRk4A0hOTu7SpUtERMTixYvr1auXlJSkmXJXcBUqVDhw4EBJhFdA27dvF0Vx9OjRWowB5UmtVp88efLUqVPu7u4fbUmoVKoVK1Y4OztPnjw5Z+2PdNz8+fNztdZZlh33zTf3bwZa16sj0M50XgIAwjGSIACAyItAiGZ8mDAk4ckTAMinfVBcRo8eXcCp7b6+vjt37ixys6MMJIA2bdp4eXnRlDtlypRevXpdvHiRfkR/dploc/E8j3OWdBPDMKGhoeTdhL+P+uKLL2bNmvX+aBv+fXVZnn01XwwdutHVtfmQoTkregAQ6ISuHNNAAYAlcHnv3nmzZ6tUqpKOtoC1/8uXL/v16+fp6blgwYLs7Gy1Wt2uXTtPT0/6qVqtTkpKyn8Pup4A0tLSfHx8nJyc6NslS5aEhYWlpaXRP6epqWlwcHAp/D1QOcaybKHWf/78OQA0atQo1/KqVasGBwcXW1io5LVp2xYA7vy6r+mQwTz/TwKQAOzbulb4aSMAaLICwzHxj5+ELlv6XXS0loLNg5GR0ePHjwkhdHBRkqSKZcl0cQAAIABJREFUFStqPu3QoUPz5s3z30PZGAMoBzZv3iwIwvjx47UdCPok9K5hd+/edXFx0XYsqBg8efLEycnpiyv+Vdu0/vdiYI4B+PdWEAzHpMW/2mxluWf37iFDh2ot1hJQBq4DQEhHXL9+vUaNGgcOHMDav9yoVavWhQsX/mjbJuiwL8sxDMcQhhEFSRQkQgjDMgqOiQ9+sNnKct4PP5Sz2h90vwsIId1Ru3btUpgBgkpZ+/btg4KC6tevfwqg15mzNg0aVDC3AIDs9PT4x4/+2LYNdv7i4+MzYMAAbUda/DABIFRQ9Gp7bUeBil+9evXUavWxP//cMn3qlvtBOT/6ydt7YHy8hYWFtmIrUZgAEEIIOI7r269f3379MjMzX758yfO8paWloaFhwaeHlUWYABBC6F96enry6eXDQWCEEJIpTAAIIZTbkydPtB1CacAEgBBC/yGKopOTU3p6urYDKXGYABBC6D/oQ0fOnj2r7UBKHCYAhBD6jxUrVgBAnz59tB1IicMEgBBC/woJCTlx4gR9ffXqVe0GU9JwGihCCP2rcuXKWVlZKpUqPj6+gHflLLvwDAAhhP5VsWJFenNNQRCMjIy0HU7JwgSAEEL/IZ97JGMCQAghmcIEgBBCMoWDwGWMjp+clu87ZyFKxwshYDksMDwDKEtu3LjB6LYDBw5o+x8Jlazg4GBtl7KP2LJli7b/kcoMPAMoS7KysgDgypUravrQal3CcdzGjRt1MDBUvGjj+n//+x/D6FzzkeO43377TdtRlCWYAMoSeuzVq1ePZgKdolAobG1tdb9zABULZ2dnlmW1HUVuCoXCzs4OC2HBYQIoewRBEARB21HkxjAMnT2N5EAURUKIrlW1WAgLK++TuH79+kVFRUVERBBC3j+pz8jIIIRkZ2eXfHiooFiWLbWBLxxhQ+8jhJRaISSEYCEsFnkkgIsXL9aqVcve3v5D/bn6+vq+vr4//fRTCceGCkoUxdu3b79586akjwqGYTiOU6vVutb0Q9pFCMnKyrp58yZtHZboFykUClEUeZ6XJInjsA/jk+SRANzc3L777jvIt6HXrVu3adOm6WBPtAwRQnie79q1a3R0dImOyzEMExcXt2DBgipVqjx48EAHu4CRFqWlpXXv3j05ObnkEgDtkNi0aZOtra29vb2tre358+cVCkUJfZ0c5K4vXr58CQDVqlX7dw2GuXDhwu7du8+dO6dZqFKpRowYcf/+/VIJEn0EPeTer/1ZllUoFAqFguM4ug7HcTkbTXSFXCtr9kNbW/T/LMtyHDds2LA2bdqwLIu1P8rlQ4WQYRjFO3SdnKVOs4Jmc7qmpoBpPqUlk3ZLPHnyJDo6+uDBgwMHDnz+/LkOzkcqK3L/w4WGhn7xxRc5lwwYMCA8PNzOzq5Tp06fffaZZnmnTp0CAgJKI0ZZKpZh3lu3bq1fv97Ly+vIkSOCICgUioMHDx49epQefizLBgYGbt68mR5yt2/fXr169Zo1a0JDQ+kh9/r16w0bNmRkZGzduvXevXuCIJw9e9bDw0MQBOwCkglRFD9lWJVhmCdPnmzbts3Ly2vPnj1paWkKheLq1aubNm2iDRGGYZ4/f75y5crs7GyFQvHs2TNvb28vL6/AwECO4xiGefbs2c6dO9PS0tauXRseHm5oaDhmzBgjIyOWZd3c3OrXr3/r1i1MAEWW+x8uNja2Q4cOOZeMHTt2xIgRHh4e6enpZ8+epc/KAQBLS8ukpKRSCrNcKNS5ampqqq2t7dmzZ4s2s56O0h87dqxVq1YeHh6//vrrkiVLWJa1sbH59ttv09LS6DDarFmzLCwsOI47fPhw165dO3bs2LRp07Zt24aEhLAsm52dvXjx4nHjxkVFRSmVSkmSdHMCEio5DMNUq1bt8OHDqampRdvDn3/+6ezs/Nlnn4WHh9esWVOSJHt7+wULFtAeS47jDh06lJSUVLFixTt37rRq1app06YdO3bs0aPH33//rVAo0tLStm7d+vnnn6enpysUCkmSaO8/LY1BQUFmZmbF3hxRKpXFu0OdlXsIJSsrK9ctsF1dXekLfX3977//3t/fv1WrVnQJzrgqlO+///7KlSsFrEBVKlVMTAw95VqxYgVt9RT8uyRJUqlUS5YsoXMzfvrpp4YNG06fPr1ly5YAcO/evdatW8fGxgYHB3t4eCQkJIwePfrJkycWFhYsy/7888/r16/ftm0bPSvv3r37V199lZWVVfA/tyAIX3/9tSiKhW2aMQyTkZHh6+tbqK1Qwd2+fXvlypUF/7solUoTE5P+/fsDwJw5c0aNGlWo3j9RFGfNmkUv0HV2dt68eXNoaKiTk1P79u39/PyGDRuWlZW1atWq8+fPC4LQuXNnPz+/xo0bA8D58+fd3d1fv35NTwJWrVrVsWPHnIWQ47igoCAAaNy4cT4l08vLKy0trbDlkLa65HCamzsBGBkZ3bt3L+eSnH9vW1tbzexPQRDkkyc/XcuWLX18fAo4PkZX27NnDwAMHTrU3d3dwMCgsN/IMExqauq1a9cePnxIm2+0F2jt2rV79uxp37799evXP//8c3Nz8xs3bgCAv78/z/MMw8TGxh48eHDDhg00jPbt2xeq9gcAlmUHDhy
oVqsLNR7IMExiYuKoUaMK+UNRIURHR/v4+Bw9erSAZ5YKhSIuLi4oKMjNzc3Dw8POzu7JkycF/zraax8YGHj79m36jPXs7GxJkqZNm+bp6TlkyJBHjx4BQJ06dWJiYgDg8ePH4eHh8K4K1gwpt2jRIjMzU1MjsyybmprasWNHX19fIyOjfH6LtbW1JElFGJfeuXNnhQoVCrtVmZM7AdSsWXPOnDnz5s3TLImPj69cuTJ9PXfu3L///pu+vnXrlrOzc+lEWQ40btyYNm0K6PXr10uXLh02bJiNjU0Rvo5hmJiYmEaNGnl5eXXq1EmhUKxfvx4ABEHo0KHDlClTVq9evXTp0s2bN0uSRI8fpVJJk72Dg0POXEU7fwobQO/evYsQNiaAkkYI6d69e69evQq4viAIly9f3rRpU82aNemSghcGQkhmZma1atUmTpzYrVu3ypUrr1q1ihAiCELDhg0BIDQ01M/Pz9vbm+M4nucBQE9Pj/bzqFSqffv2qVQquis69ZO+ZhgmLS2tRo0aXl5erq6u+WeyIUOGFDBaecqdABwdHR88eJCRkaGvr0//0gMGDPDx8TEwMNi8eTMAtGvXjq45e/bsyMjIUg5XPszNzefMmVOoTRiG0Zyu0ZuizJw5c/To0Wq1Oj4+ni4XRdHW1rZr16779+9//vx5gwYNeJ6vVKkSALi5uenr69PDjHaw0gJQmifCtBZAJapQ/8gsy65evbrg69NCqJntc/bs2Tp16ixcuJDn+Zw1tb6+/qJFi3x8fLy9vYOCgniepz2cDRs2dHBw0HSTvl8IaSdh9erVly9fPmrUqOzsbLz691PkTgAKhWLu3LmXLl367LPPRFHs1KnTzz//PHz48LNnz86YMePNmze02yciIkKpVNrb22sjZpS3kJAQnudFUZQkydjYuFq1amPHju3bt68kSfQYpoelKIoTJ07s0qXLypUr9fX1s7Oza9WqNXjwYE9Pz3Xr1hkaGiYkJBgbGzs4OOT5LXLoGEVF9vDhw/j4eDpIa2ZmZmVl9fDhwxs3bpiZmR08eFCzGs/zffr0qV+//oABAypXrpydnW1qarp69erGjRv/9ddflpaWKSkpoig2aNAgZ3kjhKSkpDg5OXl6erq6ut68eVOSJH19fUdHR8wBRZPHdXQTJ060sLDIzs52cnI6e/YsAJw5cybXOsOGDdNMB0K64Ouvvz537typU6cAQK1WN2nSZPz48RkZGT/++GO1atV+/PHHjIyMN2/eGBkZ0WkYANClSxfaGBQEwcvL6+zZs+vXr1er1TY2NiNGjKCXWbZu3fr9AbTBgwerVCrMBCgnlmX79+9/6NAhWqgyMzM9PT2//PLLffv2bdy4UV9ff/LkyUql8u3btwAgSZKlpSUAjBgxgrb3s7OzhwwZ4uTktG/fvuTkZDMzsy+//BIAVCpV69at6VcwDHPz5s0BAwbo6emtWbOGzgiqV6/e1KlTMQEUTd63c3r06JG5ubmFhUWe22RkZAQHBzdr1qyEY0O5Xb58uV27dgkJCbmuwaaXyeRcQnv26XR+2p9Du1klSVIqlT/++KNarV62bJlmP3SyEK3r6fp0Dk+eN35QKpX0VCPnQoVCsXTp0saNGw8ePLjIPzA+Pt7KygpTS8k5fvz4pk2b/vrrryLv4cGDB/Xq1YuOjmZZNlfzPFchpDdsoNP5AUAQBNpdI4qinp7eH3/8MX369NDQ0JwFKec1hnTOMS2Emrkn71+EqBnEAgCFQrF161ZTU9OxY8cW+QfKSt530sh/dFdfXx9rf50iSVKe9+bL2euanZ1NCElOTl69evUvv/zy7NmznJvQxlSuzUVRzHO3eB9A9L4PFcKc5UoQBDoyvGvXrjlz5ty+fTvX/Jz3LzTJVQjxSpTihVfQyQghJCEhITU1NSQkpEKFCtjQRlqhVqt9fX0DAwPt7OywNtcuvJeejIiiWKNGjY0bN4qiiAce0gpJkoyMjE6dOkU7iLQdjtxhAih7OI77xOqbXplZXPFQHMfhHeLk49NvCIiFUBdgAihLaKfN33//rYOP3uU47s6dOy4uLtoOBJUsWgj9/Px08BZsLMvevn3bzc1N24GUGZgAyhJjY+O+ffseOHBA24HkzdTU9EMzx1C5oVKpevfu/eeff+rmGBIhBAthwWECKEtcXFwOHz6s7SiQrDk6Oh45ckTbUaDioXMncUjO8EGvCJUmPANAuuLmzZv6+voAEB0dnZGR4ejoqO2IECrnMAEgXWFsbFyrVi0AsLe3T0tL03Y4CJV/2AWEdIWjo2PXrl0BwNvbWw63YkdI6zABIB0yc+ZMAKDPn0IIlTRMAEiHtGnTpnfv3ra2ttoOBCFZwASAdAjHcX/88Ye2o0BILjABIN2ieQogQqikYQJACCGZwmmgSCdERkaGPnmS+uaNIAimZmaOtWrZ2dlpOyiEyjlMAEjLLl28ONrN7REAVK1aZdS3BOD5j1+BwNcF2BUQgI8eQqjkYAJAWiOK4qhhw3bu29fj2PHPOnqoKujRu4uRH+ZkpqU/+euv5s2bT5040WvtWh288SRC5QAmAKQdoij27979SFr6hJRUvYqGoiDx/L/PhlXo6dfv17dGQtKavn3Cu3f3PX1ai6EiVF6V1YZVSkrKqlWrCCGEEG9v71wPKEe6b+Xy5Udu3p7y9zmloaHAi/TewizHsNy7B9Pzop6pyZSzfx/566/169ZpO978xMfHz5kzh5bGvXv3ajscJHePHj0aOnQoIcTS0tLPzy+fNctqAjh69KiFhUVUVNTjx49Xrlw5bdo0bUeECiEsLGzO3Lnf37/HKBXSu+RNWGZvh477u3SlOQAAJEFkFdy3z6MmTZkSFRWlvXg/YuvWrS1btoyOjg4KCho6dOjWrVu1HRGStXnz5o0ePTo6Ovr333/v2LHj9evXP7Qm0c2nOmio1erjx4+3aNGCXh366tWrS5cu9ejRI+ds8Xv37jVs2DA9PZ3eSxLpvnmzZv1pbNJt9qyc3T6EkD3tXV/5+8+U/tMdxHHM/ukzJluYT5kxQxvB/uvcuXOWlpYNGjSgb0NCQqKiojw8PHKuc/bs2c8++0wURd28tfXx48c3bdr0119/aTsQVDyOHDlCK0N6Ampvb1+3bt2cKyxcuDApKcnb2zvPzXX9DEChUMTFxWlmBHbr1i0+Pj7XtUK03sdxQh2kVqvv3r2ba6EoiktWrmzcq5fwz5gvYViGYRnCElah4MzNCQBdAoQAgChB+y+/mjpzZgkF+fbt27CwsIKs+fLly5zPvBwwYEB0dHSudQwNDUGHH2xAqwltR4E+4v79+wVcMyoqKiYm5uXLl48ePeratWtCQkKuFczMzJRK5Yc2LwODwGPGjFm4cOHBgwfNzMwSExO///77XCv4+PhMnDgx/ytIJUnKzs7Gol/KFArFmjVrLl26tHPnznbt2ikUCgBITEwEgIo2tiBKQIioVscFBTEcBwSy3ryR1Oq4Bw+F7GxJECrVqqUwMJREyaxmDQBITU01MjISBEEQhOKKkBCiUqlq1qzZt2/fGTNmNG/ePJ
9C4unpCQBPnz51dHSMi4sLCgrq27dvrnXWrl27ZcuWD+0hKytLi4WQ4zjN06Szs7O1FQbKn0KhuHfvnouLi4+PT6dOnUxMTPJZecKECfTFV199tWTJknbt2uX8VBCE8ePH59MFBFJZoOn/jYmJyfXRnTt3ACApKSn/PcTHx3/iXwV9uqtXr0qSFBcXBwDTJWmqWpgmSoNuBH5o/REhT6cJ0lRenJyeCQCvXr2SJOl0Cc8ICgkJyacgTZ061cvLS5KkgwcP9u7dO9enZ86cAYDMzMw8t9VUvggV3KZNm/Kv3CRJ2rVrFwDwPJ9r+ZIlS/r165fPhro+BkDxPE8bj9nZ2fQFFR4eXqNGjTt37jRs2PCjO4mOjsZuolKmUCjmzJnz888/79y5s2fPnmZmZgDw6tUrS0vLSemZrFIJACLPvwoOYlgOCDk+cvib0NDBV6+L2dmiIFRydFQYGACAkJ3tXUHv1atX5ubmarX65cuXxfinZBimcuXKAODr69upUyfah/MhwcHB9evXlySpadOmP/30U6tWrTQf3bx5s1mzZpGRkfb29h/aPCoqimXZ4oq8CARBUKvVenp6WowB5Y/juOPHj3/zzTeLFi36+uuvHRwc8l8/JCSkdu3aUVFRuS6e37lz58iRI/MfHC0bCWD+/PkWFhahoaEVK1ZcvHgxXUhr/8DAwKZNm2o3PPQhgiD873//69y5c86FaWlphoaGo549N65SRRJFIIRhCAAAgX3u7q+Dg6e9ekWHB0RRAkkiLJPw5Mmu2k4lNLL65s2b+/fvt23btoDrE0J8fX379u3L87ymNg8ICGjRokVYWFj16tWLPUIkN+fPn2/WrFn+bREqMzNTX1//2LFjtH9SY9euXSNGjEhLS8v/2UplYAzgzp07ixcvzsjIyMrKMjEx6d+/v4uLy5MnT5ycnI4fP+7s7JyYmCiKYsWKFfMZ60BawbJsrtofAAwMDD7v1vXZtWuNq1URRABJEgUJAAghkiiK2WoJQBT+nQXEEAj1vzpu5MgS6j2vWLFiwWt/AKC1/5YtWzS1f2BgYIsWLfz9/S0tLWlpNDU11W5LH5Vp7u7uBVxz4sSJY8aM6dGjh1qtliSJZVmWZffu3TtixIiIiAi1Wk2HhenJ9/t0/Qzg7du3RkZGfn5+9F/kyJEjffv2TU9Pfz+t3bx5s0mTJtqIERXa5cuX27VrNyUrm2E5TQkkhKzW50At5pwGSggR1Op1+qpbt241btxYeyH/Kz4+3srKSnPGTU9ocq0TERFRtWpVbUSHZOTp06f0Mdoaa9eu7dWrV40aNXKtmZWVlWf7WNcTACqv3Bq6JPbp1/XH+YIggSQBAGGYt3FxrFKpb2r6T7EkhGPJ4RkzXB4+PHTihJYjfmfVqlWRkZE//fSTtgNB6FNhAkDaQdvRbtt3NB/1jSgCvR6YMAydlwYAhGUYAje277j43bfJycnGxsbaDhkyMzP379//zTff6Eg8CH0iTABIa2JiYmxtbSsMHzFkyVJjG2sRAP65NAwYgOQXMdsmToTDh96f3qAt/v7+bdu2jY6OxqcWo/IBp0UirbGxsUlLSxtnarLVtvLKAZ/f+uNAxI0bz28EBO7/bWXvPtvsbBfXr5eRkaEjtT8AtG7dWpIkrP1RuYFnAEj7YmNjr1y6dP3Spbhnz7IyMqo3qN/StV0bV1crKytth4ZQeYYJACGEZAq7gBBCSKYwASCEkExhAkAIIZnCBIAQQjKFCQAhhGQKEwBCCMkUJgCEEJIpTAAIISRTmAAQQkimMAEghJBMYQJACCGZwgSAEEIyhQkAIYRkChMAQgjJFCYAhBCSKUwACCEkU5gAEEJIpjABIISQTGECQAghmcIEgBBCMoUJACGEZAoTAEIIyRQmAIQQkilMAAghJFOYABBCSKYwASCEkExhAkAIIZnCBIAQQjKFCQAhhGQKEwBCCMkUJgCEEJIpTAAIISRTmAAQQkimMAEghJBMYQJACCGZwgSAEEIyhQkAIYRkChMAQgjJFCYAhBCSKUwACCEkU5gAEEJIpjABIISQTGECQAghmcIEgBBCMoUJACGEZAoTAEIIyRQmAIQQkilMAAghJFOYABBCSKYwASCEkExhAkAIIZnCBIAQQjKFCQAhhGQKEwBCCMkUp+0AEEIlThTF0NDQl7GxaWlvVSo9M3PzWk5Oenp62o4LaRmeASBUngmCsG/vXpZlnZyc2q1d63n+Yof9v7k0bKivr79qxYrk5GRtB4i0iUiSpO0YEEIlIjY21sbGBuzsvzh0yLZJU4ZjCAAAiBIkPHm6a+kS2Lf3+vXrLVq00HKgSEswASBUPkVHR9vb27fduLnV2O8BQBQkyHGwE5ZhCYT+z++wR8fz58+7ublpLVCkPWW4CygzM3Pjxo0uLi5ZWVnajgXJXUpKyrJly+rVq6ftQP6hVqvt7e3bbdnaZuz3oiCJvAj/bepJgsjzYs2OHb4KCHR3d4+KitJWqKgkPH/+/Ntvv12+fHn+q5XVBJCWlqavr//27dv79+/jSQzSrri4OBMTE0mSHjx4oO1Y/rF961Zw79Bi9Hf8fxv+ufC8aN+sabPVayZ/921phodK1OXLl6tVq2ZiYvL8+fP819T1LiC1Wn3hwgUXFxdLS0sASEpKCgwMbNeunZ6eXlpamlqtNjU1zcjIwPkMqBT4+/tbWFjUqlWLvo2IiIiJiWndurUkSVlZWTExMTVq1NCFAyo7O1ulUg0LfmhRx1kUxPxXJoRkp6dtqGgUFhZWvXr10okQFZdz584plUoAIIQQQqytrWvWrJmRkaGvr3/8+PFjx47t2LEjn811/QxAoVDcuXOna9eu9O2QIUPu3r1Lq3sDA4Ps7Oyi7fbVq1fFFiKSjZCQECcnJ83br7/+Ojg4GAAIIXp6emq1uuC7ev36dfHH905wUBAAmBeg9gcASZL0jAyNJky8fOHC+58mJCQUe3ioGF2+fDkgIODmzZt+fn6urq60ya+vrw8APM9/dHNdTwAAMH369Nu3b588efLatWsnTpyYMWPGp+wtICCgW7duly9fLq7wkHz06dMHAJ49ewYAr1+/9vf379+/f2F3EhgY6OnpefHixeKP752IZ8/g2+/+mfEDQFiGMHn/R1eQAFq5d3h4965mD5IkXb16tXHjxvfu3Su5ONGnW7hw4bRp06ZMmfL69esZM2Z07NixUJuXgQvBCCFhYWE1atQAgI92aeVj7969Q4cOpa/VavWyZctE8ePtI1RqCCE8z9+/f7958+aFak0Xr4SEhIkTJ06ZMuX9j0xNTceMGXP8+PEJEyZcuXLFw8OjUqVKBd/zvn37hgwZQl+npaUtX768eEvgnTt3aAdURtrbxo2b0IUsx1xZsfLq7Fl5bjJTknheBIAKxhVfRkcDgCAIu3btGjVqFF1h0aJFM2bMwCNFWyRJunv37kf7FQ8cOLBx48Yi9IiUgQQAAA4ODvSFnZ1dkXfSo0ePQ4cO0Sbbl19+6ezsjMVapxBCEhISevbs6e3trcU/jSRJFStW/NCno0aNatSo0YQJE7y9v
RctWlSoPffo0ePw4cP9+vUDgK+//rpOnTqCIHxquDlosianUCRFx9DXEkDlxk2a/7iAVSj+s7YkAcNo6hVJlBiWBQCWZfv06VOhQoVBgwYBwJgxY+zt7fFI0aKPNoaePXv2xRdfhIaGKnL9iQtA1weBqVWrVgmCEBIS4uzsnLMLKD4+3srKqrCDwDdu3EhJSencuXMJRIo+SUJCgrm5uY6XSULIqVOnunXrlpWVRcffqJCQkNq1axck+ICAgKdPn9IatiT8ffZs56++nvk6nrbuNdd/vU/gRQBgOebMmrWDsjKnz5mT89NLly5lZGR89tlnJRQn+nQ8zysUit9+++3LL7/M9dGRI0dOnTqV/yBwGTgDePjw4YwZM96+fZuVlWVmZubp6ens7PwpO8TrHnVW8baIS8gff/zRrVu3DRs25Kz9C6V58+bNmzcv3qhycmnYEBJepb5KqGBmJomiyH+k/S5KcG/a1C3+/rmWt2vXrsRiRMVj9uzZw4cP//LLL+lZGp0LVPDNdX0QOD09vW7duqdPnzYwMKhUqdL+/fvr1KmjufKL/mYdbzCicoZeNNuzZ89cywsy6aJ0WFpa9uzQ4d5BH7YAxzfDMVHXrgNA02bNSjwyVKzCwsJWr169a9cuQgjLsizLrlu3TvNpZmZmTExM/nvQ9S4gSZLeT2h5LkTlAO3T0/EyuWXLlhs3buzevVvbgeQnLCysZs2aIx6FmNWulc8ZAGEZPj3D27DC8WPHenh6lmaESBfoehdQnhU91v5IK3ie9/X1HTNmTInO4i8WNWrU2P/rr4Ocnb599tykWhUhr+uBGY7JTktf37XruNGjsfaXJ13vAkJId1y7dm3gwIGhoaFmZmbajuXjvho06Jeff97uUPXxiZOEEJZjGI4hLMOwDMsxLMe8CLy53tBgVHWH9Zs2aTtYpB263gWEZEXHu4AEQWBZVttRFE7AjRstWrYEgKbrN1Rt1FjfxJjPzHwRHHx5xHAAOHTwYL/CX8uGyg1MAEiH6HgCKKMkSQoICPjr+PG7ly8lhYcZWFk7NW/RpVev9m5uKpVK29EhbcIEgHQIJgCEShOOASCEkExhAkAIIZnCBICQjKSkpPi/d8Uvki1MAAjJyF9//TV9+nRtR4F0BQ4CIx2Cg8AlzczMLDExMTw8XHOHXSRneAaAkFwEBgYmJiYCwO8c7xVKAAAVzUlEQVS//67tWJBOwASAkFz4+fn98MMPHTp04Hk+OTlZ2+Eg7cMuIKRDsAuopJ0+fXrjxo0nT57UdiBIJ+AZAEIywvN8mXjoAiodmAAQQkimdP120Cin2NjYkJAQhtHRtC1Jkp2dXY0aNbQdCCpBiYmJ9+/f1+VCaG1t7eTkpO1AygZMAGVJSEiIu7u7q6trdna2tmPJTaFQXLlyZffu3ZgAyrcXL164u7u3bNkSdO9hfBzH+fv7r1+/HhNAAWECKEvovYhPnDihmwlgyZIlOtswRMWF/omPHj2qg3fG5jhu69atOhiYzsIEUPbwPK9Wq7UdRW6EEN15KC4qaTzPS5Kka2cAAIBD3IWSd3vt1atXKSkp7y+XJCksLKyEQ0Jlkg7WBUiGsBwWSh4JID093dLSMs/TKELIzp07T506VfKBoUIghCiVytLpfiGEsCyLj2VGudBCWAoFg5ZAjuOwq+fT5VFl7N69+9ChQ4aGhnluMGPGjO7du2Oa1R2EELVaPWXKlGfPnpV0DlAqlampqRcuXEhISMAcgDQIIW/evBk9enRJFwyGYdRq9c2bNy9cuHD37l2GYbAcfoo86ouxY8d6eHh8aANjY+NRo0Zdv369JKNChSNJ0t69ezMyMvI8GIrrCOE4bv/+/bVq1Ro4cGBsbCyO95ZRJVRjCoJw6NAhtVpdhEKY69MPrUwISUpKqlKlyt69ewMDAz/77LPRo0d/Sswo9zEcFhbGMIyxsTF9m56evnbtWkIIIcTCwoIu7N27t5+fX6mGiQrg/aNIoVAAgCiKDMPQ82VRFHOuSQjRLOE4jg7rcdy/UwMkScp5om1sbBwaGvr+dyFEvV8wNIWQljHIqxDSkducJZbjOM0KmjIpSZKxsXFgYODmzZtnzZoVEhJy9OjRW7duYV9QkeVOAE+fPp02bRp9zfO8gYFBUlJSUlJSenr6zZs36XInJ6f169eXapiokAghmZmZ48aNs7GxsbW1tba2joiI0NPTW758eZ8+fehhplQqN2/e3K9fP5VKlZmZOX36dFtbW1tb2/Xr1zMMwzBMdHS0ra3t/fv3bWxszpw5AwDdunUzMDAoeBiJiYlBQUGFCruwvxQVVmn+IzMM4+XlRQth5cqVb9y4oaend+bMGRsbG0mSaEq4dOmSnZ0dzRCbNm2iK0+cODEzM5PjuHv37lWpUiUgIMDGxubOnTssy9rb22dlZWVlZZmbm3t4eERGRn70F2VlZV24cKFUfnEZk3saaFJSUr169ehrOti7ePFi+rZq1ar0hUKhePXqVWlFWE4kJCSkp6cXYUNJkipVqvShIZl8ttLT0/vuu+/WrFnDsuzhw4dbtGjx8uXL4cOHb9q0KSYmxtLSMisra/HixSdOnACAiRMnNm3aNC4uLisrq2fPng4ODr1796bNrvHjx1+4cMHW1lYQhIKP/YSGhv7xxx/z5s37888/bWxsCjI5j2GY+Pj4Qv1MVARpaWnJycmFvZREkiRDQ8NCpX+qR48eY8eOValUt2/f7t69e3h4eMeOHQHg3r17jRs35jhu27Zt69atMzIy2rhxY3BwcFRUFMdxy5cv9/b2nj9/vkqlEgTBy8vL39/fxsZGFEVaCAkhb9++PXfu3Jw5c/IpljExMb6+vuPHj//pp58aNGhQ8JnKoihaW1sX9seWOXlcB0BP0ADg8ePH8+fPL914yq3ffvttwoQJRdvWz8/P3d29sFtJktSgQQNJknieb9euHQCkpKQ4ODjUrVv3ypUrAwcOfPz4MQA0adLk2bNnx44d27FjByHE0NBw2bJlM2fO7Nu3L93Phg0b6tWrp1arC177T5o0SXOO2KtXr8JGjkoOIeTSpUumpqZF2Hbjxo1jx44t1CaiKNapU4fOU6hfvz4AREZGuri4zJ4928fHp0WLFvHx8X5+fuvXr09JSfnhhx+Cg4P19PQA4JtvvmnQoMEPP/xA97N27dqqVavSiw/oEqVS+euvv9apU6devXofqta3b9/+3Xff0dfjx48fP358oYJ/8eKFjY1NoTYpc3InAEtLy99++23o0KEAwLLsmzdv3t8mNTW1adOmpRFdOcKy7Nq1aydOnFiEa7hydsoX6hsvX768ZMmSW7duubq60oWiKM6fP3/gwIFffPHFuXPnli5dqlKpYmNjAaB27dpJSUmazTVDytWqVStU7Q8A3t7eU6dOPXLkyMSJE48fP96+ffsCXp6jm9cWlSfdu3dPTEwsQi+QJEn6+vqF3YrjuIcPH65bt+7o0aO0EUMvGOzZs2erVq1+/PHHO3fuuLm52dnZRUREAMBnn3324sULeHfRe0JCAp1rYGdnl/PAUalUV69enT59
+oMHDzQN1vd9++23/fr1O378+PDhw3/66adhw4YV5AyAEJKenm5jYyOHaQ65axZHR8edO3f+8ssvANCmTZtWrVrNnj3b0tIy5zqPHz8eNmxYqYVYbigUCoZhVCpVKXwXy7JXr17t06fPvXv3rKysMjIy6CMABUFo0aIFAISFhS1ZsiQgIEBTNd+9e5d28tAxfzomDEW9ssbe3n7ChAkjRoyIiooyMjIqvl+GPgkhpGjN/yJgGOb58+ft27f38/PbvHmzUqk0NzeXJEkURUdHR3Nz8/v37x85cmTs2LGaevnMmTMGBgZ0PAAADAwMaOskZyFUqVTXr1/39PS8du2ahYVF/nW6mZnZsGHDhgwZEhAQUPB+VHoWIge5U1yVKlUAICYmBgBatmw5b948Kyurn3/+2dfXd+LEiXSdlStXFqFHApV021ahUCgUCqVSqVQqVSrVyZMnZ86c6eDgoFKpwsPDNTEYGhouWLBg8ODB7du3r1atmiAI9PZtkZGRpqamxsbGFStW1NfXL5ZoDQ0NnZ2dP30/qKzIWQiVSmVgYKCLi0vjxo319PRevnwJ74agRVHcuHHjsGHDfHx8WrZsyfO8lZUVANy9e9fExIQWwjzHG5RK5bVr17p3737nzh1nZ2eGYeiMhvwxDEPvXldA+ZxVlDN59C0cPXr0wIEDkydPBoBFixYNHjz44cOHADB8+HAAiIyMDAwMrFOnTikHivJBK+sVK1bY2tqKosjzfP369YcMGdKuXTuO43iep4P2tK3E83yPHj0WLFiwcOFCABBF0cLCwsfHx9XVdcKECTY2No8fP+7Ro0f79u3zPAOgb+VzhKACogVj8eLFFStWlCQpOzvb3d3d3d19zJgxM2fOtLGxuXfvHgDQwWdBEFq3bp2YmLho0SI9PT21Wk17LF1dXb/66isXF5fIyMjq1asPGzYsZ0ljWTYgIKBHjx4tWrTYvXt3enq6IAg1a9YcNWqUDt4dq0zIIwH06tWLENK/f397e3sAcHR0dHR01Hzas2fP4ODg0gsQfQydJX3ixAlNBzpt5tetW/f27dvPnz83NjauV69ez549aeOLYRg6H6l169a0/0etVru7u4eEhDx58kQUxebNmzs6OvI8b2Ji4uvrm6vnlxBy4sQJOm9PGz8X6SgDA4Njx44RQjSF0MzMzNzc/PHjxyEhIRz3//buPaiJOw8A+G83SxJSBBKIqBn6OF5Tr1jgeBb0qK1ArSDMWewU7EwrXqzYK1arOTg8dKAORVt1zgeK5ZyL2laLepRSbxRpe5USa0VSFCwIOiACCjbhcUn2cX+sl3LhUU1JNnG/n/EPdrML34Gfv+/+XvsjlErlqlWrPD092W5Gg8GAEEpOTmYLIUVRgYGB169fv3z5ssFgCAkJCQgIMJlMs2bNqqioYH8EwzDe3t6VlZXm9SsMw0gkEiiKVht/dHGS5n9DQ4PNggFWwnE8IiLC4qTBYGCnVCOESJKMiYmhaVogEAwMDMybN0+tVrNPXuzFJpPJ3d09MjKSPWQHAyQSSWxsrHninVlERMTYk4DPGIYRiUTs8NLokyaTycPDg+1+oWk6ODiYLTYURS1fvjw3N9fX19c8IZWiKKFQGBYWxh7SNE3TtFQqjY2NZZMETdOPPvqoeT66GbwB1GoP/zA3T1BjIIRomjZ/TVEUhmHNzc1BQUEqlSoxMdGi1cwwzOh7zWfGVvQPtCAA8MTo8sMyP6Szh+bi1N/fr1AoXF1ds7OzJymE5pbE6PrdXKQtijqwDuwHwCNsC7qqqioiIgL6TAFX3Nzc2DEnBG9v5hokAKdk9Wp+mUzm5eXFtgamPCR4kQPfWPcXd3FxiY+PN7//xxFC4i1IAM6HnUnNdRSWYFSAV5j/se52GxVgWEj4oCABOBN2Hqe3tzfXgUyovLyc6xCAbbF1NztF0DFt376d6xCcBiQAZxIaGtrc3OywK9QZhvH09OQ6CmBbTzzxhIMXQnd3d66jcBqQAJyJp6cn1LCAW25ubkFBQVxHAaaGg6ZxAAAAtgYJAAAAeAq6gOzEYftMAQC8BbWSzVEUVVhY2NLS8uOPP7LvXwMAAEcALQCbEwgE/v7++fn5CKGPP/6Y63AAAOAeDNZN2IFer2enpun1+gfd3RcAYGcjIyMSiaS7u/uh3xYYWgD2MG3aNJVKJRaLofYHADgOSAB28uqrr7LbnAIAgIOABGAnsDMiAMDRwCwgAADgKUgAAADAU9AFZFsURZ3XaM588UXT+fOa6urfJSyYHR7x/MKFkZGRLi4uXEcHAOA1aAHY0LlvviEIIuaZZyqkMmNuXkT9eXrjpqqZs+Li4oRC4dmaGq4DBADwGqwDsJXiLVtUubmpX/wrIHEBhhCDEGIQwhC7X9G1s18emx+fp1IVbtnCbZwAAAuwDsA53L17l2EYqVTKdSCWtn/wgSo3d2XnTXfFTIocZ/Oj3zz7++y+O0VyL5FIlF9QYPcAwRRjGObu3bsYhsH7uoEjMBqNQ0NDOI57eHhMcpmzdgF9//33GIZJpVKZTObi4tLT08N1RD/TarVr3n5beaNz2gS1P0KIImmJt+yN7p6NmzZpNBo7RwimVk1NDY7jMplMKpUmJCTo9XquIwK8tm/fPpFIJJPJPD0933nnnUmudNYEYDQaW1pa2C1A16xZs2DBAq4j+lnWwheeP3TYw1dBT1D7s2iSnjZj+qIT/4yKirJbbMAWSJLs6upi92oWCoUbN27kOiLAa1KpVK/XMwzT39+/devWioqKia509DEAiqJef/31rKysuXPnIoQuXLiwY8eOsrIyoVBovqatrc3f3/+nn35yhK3g2GByhkYIkeiXf7cYxtD0+0LiB632t089ZZcAgfWKi4sVCkVmZiZ7eOLEiatXr65fv370NRqNJioqiiTJSRZ+X7p0qb29PTU11bbhAmuZTCahUNjb2yuXy7mO5RckJiYqFAqEkFgsxnE8MTExOTl59AWlpaVff/21Wq0e93ZHHwMQCARKpTI2NnZwcFAkEoWHh587d2507Y8QGhgYQAjdf+1/4cKFtra29PT0qQ8XIe2lS2jFH4USMdv5gxM4Nt5lDEI0SSOGIVwEsrXrGhsaLBJATU2Nu7t7eHi4LYIE1omMjJw/f/7LL79MEARCKC0t7fTp0xbX9PX1RUZGjlv7Dw8P19TUqFSqpqam8vJyjUZDUZQ94gYPAsMwg8GAEKqvr5fL5TQ9WTveDkiSZB9/x7V161aCIAiCaGxsXLJkydgOn46OjtDQ0Ilud/QWAGvDhg0ikcjHx6enp2fz5s0Wny5atCg1NTUrK2vyb6LX68+ePbt48WKE0EcfffTkk09O4X8/gUAwZ84chFBRQcFxmddzf3qTImmcwLXqQ0M9Pdj/7wbD0PQj06cHZ2bSFI0TuObgP0K/PffBnj0IoTt37lRXVy9btgwhdPr0aW9vb87LH98wDOPm5hYYGDj2I4qiCIKoq6uLjo5ubW0NCAgY+6SPYVh1dXVSUtLY23ft2rV69Wr2a7lc3tfXZ4v4wZSYM2dOY2Mj11Hc84u1NEm
SLi4uYwve7du35XJ5e3v7448/Pu6Njt4CYBUWFrJP/SRJWnx08ODBqqqqzz77bPLvQNP0kSNHlEole1hWVjY4ODhVdavJZPLz8zt69ChCyGQ0ih9xY/9cGEIdZ840/7187C1+6enBy+71JBAi4aBOhxAaGRkpLS3Ny8tjzxcXF+t0OqfI0A8TnU732muvWXTssAQCwe7du/fv3x8dHX3q1KlNmzZZ1P6bN29OT08ft/ZHCGVnZ69ataquri4/Pz8nJ8eiqQ6A1ZRK5erVqy0KHsMw8fHxe/funaj2R87SAujp6WEn5HZ2drIdXqxPPvlk6dKlAwMD9zn3jqbp+vr6t956Ky8vj20KTLn333vvQ5p5UbWBImmEYQxFMeOlGQzDMIJADCMg8H+XfZjQ3vbXoiL2I4PBUFtbm5SU1NDQ8PTTT9siSGC1zs5OX1/f4eFhiUTS0tIyuqGwY8eOnJwco9F4P2u8u7q6RpdkAKxWWVmZkpJiMBgs+saTkpICAwN37tw5yb3OkQDkcvmRI0du3rxZXFzc1NTEnjx69Gh6erp1W6zcunXLRks8vqytjX8je/2VpnsTQLFxhwAQQggxDEKIIPDiuLmf/1n1wosv2i1I8GukpKSEhISo1epr166ZT27btu3AgQNarRZe+g3s6caNG4899lhra6ufn9/o86mpqbNnz3733Xcnv90JEkB5eXlRUVFrayvDMDiOq9XqjIyMw4cPZ2RkHDp0yNfX12QykSQZFhbm7e3NdbCov7/fy8tLeaNzmkIx7rP/aBiOD/f17Z4x/VZ3tw/U9U6ivr4+Ojr6+PHj5mk827ZtW7du3cmTJz08PCiKMplMsbGxsPkPsDV2ttLatWtXrlxpNBppmpZKpT4+PkuXLq2oqPjqq69IkmRr+Pj4eBwfZ9K/o48B6HS6pqamM2fOIIQwDLty5cqBAwfS0tKuX79eUlLS0dHR2tqKEBoZGfH393eEBCCTyf6iUv2tqFC5d8+kywAQQgjHUe3uXcszM6H2dyIzZ85ECJmXnoyMjAwODpaUlGi1WnZawdDQUFhYGCQAYGu9vb0FBQVisZid5ck+eYSGhgYHB8fExNTW1pqf7+fNmzduAnCCFoDT0el0Hh4eySdOzl6cQk6cBAgCb6s5e+y5+bdv3/by8rJnhODXSEhIeOmll1asWMF1IAD8Ws66EtiRubu7X7x4sTJ1sfbTCpzAcYHlLxkX4AICb6k+dey5+d/W1UHt7yw6Ozvj4uL6+vqg9gcPB0gANhESEqLVaj9f8oeS1LSeH37ACVww6l/v5SvvvZJxYmHSxYsXo6KjuQ4W3C+9Xp+Wlvbdd99xHQgAUwO6gGzIaDTu27PnzZwchBBKTlGER3Q1XkKfHkMIlZWWvrJsmaurK8chAgB4DBKAzRmNxqstLd3d3f8ZHha7uk6fMSMoKEgsFnMdFwCA7yABAAAAT8EYAAAA8BQkAAAA4ClIAAAAwFOQAAAAgKcgAQAAAE9BAgAAAJ6CBAAAADwFCQAAAHgKEgAAAPDUfwGa6zjnNUkFIQAAAABJRU5ErkJggg==)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5IGhItKo6kIr", + "papermill": { + "duration": 0.035579, + "end_time": "2020-10-20T14:08:12.433667", + "exception": false, + "start_time": "2020-10-20T14:08:12.398088", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Model Architecture" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BTv1SBEn9-Wa", + "papermill": { + "duration": 0.034786, + "end_time": "2020-10-20T14:08:12.503419", + "exception": false, + "start_time": "2020-10-20T14:08:12.468633", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We will perform the following steps:\n", + "\n", + "* Use input tensors from our data generator\n", + "\n", + "* Produce Semantic entries from an Embedding Layer\n", + "\n", + "* Feed these into our Reformer Language model\n", + "\n", + "* Run the Output through a Linear Layer\n", + "\n", + "* Run these through a log softmax layer to get predicted classes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4s0vDUd--pY4", + "papermill": { + "duration": 0.034523, + "end_time": "2020-10-20T14:08:12.572892", + "exception": false, + "start_time": "2020-10-20T14:08:12.538369", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We use the:\n", + "\n", + "\n", + "\n", + "1. [`tl.Serial()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Serial): Combinator that applies layers serially(by function composition). It's commonly used to construct deep networks. It uses stack semantics to manage data for its sublayers\n", + "2. [`tl.Embedding()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Embedding): Initializes a trainable embedding layer that maps discrete tokens/ids to vectors\n", + "\n", + "3. 
[`trax.models.reformer.Reformer()`](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.reformer.reformer.Reformer): Creates a Reversible Transformer encoder-decoder model.\n", + "\n", + "4. [`tl.Dense()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Dense): Creates a Dense(*fully-connected, affine*) layer\n", + "\n", + "5. [`tl.LogSoftmax()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.LogSoftmax): Creates a layer that applies log softmax along one tensor axis.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gDqWqFKT6a6r", + "papermill": { + "duration": 0.046731, + "end_time": "2020-10-20T14:08:12.656598", + "exception": false, + "start_time": "2020-10-20T14:08:12.609867", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from trax.models import reformer\n", + "\n", + "\n", + "def NERmodel(tags, vocab_size=35181, d_model=50):\n", + " model = tl.Serial(\n", + " reformer.Reformer(vocab_size, d_model, ff_activation=tl.LogSoftmax),\n", + " tl.Dense(tags),\n", + " tl.LogSoftmax()\n", + " )\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "_kg_hide-output": true, + "id": "NsCct_PV8kEi", + "outputId": "fc664cfd-87a1-4f98-cadc-8fdcafc9929f", + "papermill": { + "duration": 0.062424, + "end_time": "2020-10-20T14:08:12.754804", + "exception": false, + "start_time": "2020-10-20T14:08:12.692380", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model = NERmodel(tags=17)\n", + "\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1GsNxS4JYETt", + "papermill": { + "duration": 0.041676, + "end_time": "2020-10-20T14:08:12.833227", + "exception": false, + "start_time": "2020-10-20T14:08:12.791551", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Train the Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9nhKmsUkYFgD", + "papermill": { + "duration": 0.051837, + "end_time": "2020-10-20T14:08:12.924577", + "exception": false, + "start_time": "2020-10-20T14:08:12.872740", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from trax.learning.supervised import training\n", + "from trax.data.preprocessing import inputs\n", + "\n", + "rnd.seed(33)\n", + "batch_size = 64\n", + "\n", + "train_generator = inputs.add_loss_weights(\n", + " data_generator(batch_size, x_train, y_train, vocab[''], True),\n", + " id_to_mask=vocab[''])\n", + "\n", + "eval_generator = inputs.add_loss_weights(\n", + " data_generator(batch_size, x_test, y_test, vocab[''], True),\n", + " id_to_mask=vocab[''])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3CZWK9HgY_lj", + "papermill": { + "duration": 0.05051, + "end_time": "2020-10-20T14:08:13.013644", + "exception": false, + "start_time": "2020-10-20T14:08:12.963134", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from trax import optimizers as optimizers\n", + "\n", + "\n", + "def train_model(model, train_generator, eval_generator, train_steps=1, output_dir='model'):\n", + " train_task = training.TrainTask(\n", + " train_generator,\n", + " loss_layer=tl.CrossEntropyLoss(),\n", + " optimizer=optimizers.Adam(0.01),\n", + " n_steps_per_checkpoint=10\n", + " )\n", + "\n", + " eval_task = training.EvalTask(\n", + " labeled_data=eval_generator,\n", 
+ " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n", + " n_eval_batches=10\n", + " )\n", + "\n", + " training_loop = training.Loop(\n", + " model,\n", + " train_task,\n", + " eval_tasks=eval_task,\n", + " output_dir=output_dir)\n", + "\n", + " training_loop.run(n_steps=train_steps)\n", + " return training_loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Y8kYOG9xZNF7", + "jupyter": { + "is_executing": true + }, + "outputId": "29557238-28fb-4d50-b22d-fd50a203cc52", + "papermill": { + "duration": 29506.536646, + "end_time": "2020-10-20T22:19:59.586493", + "exception": false, + "start_time": "2020-10-20T14:08:13.049847", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "train_steps = 100\n", + "training_loop = train_model(model, train_generator, eval_generator, train_steps)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dRwN9mp74kZG", + "papermill": { + "duration": 0.058348, + "end_time": "2020-10-20T22:19:59.703317", + "exception": false, + "start_time": "2020-10-20T22:19:59.644969", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# References" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "--6G7L9w4mNg", + "papermill": { + "duration": 0.058998, + "end_time": "2020-10-20T22:19:59.820862", + "exception": false, + "start_time": "2020-10-20T22:19:59.761864", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "---\n", + "\n", + "* [Google AI Blog- Reformer: The Efficient Transformer](https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html)\n", + "\n", + "* [Google AI Blog- Transformer: A Novel Neural Network Architecture for Language Understanding](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html)\n", + "\n", + "* [Trax: Deep Learning with Clear Code and Speed](https://github.com/google/trax)\n", + "\n", + "* [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)\n", + "\n", + "* [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n", + "\n", + "* [Illustrating the Reformer](https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)" + ] + } + ], + "metadata": { + "colab": { + "include_colab_link": true, + "name": "NER using Reformer", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "papermill": { + "duration": 29594.274789, + "end_time": "2020-10-20T22:20:01.092204", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2020-10-20T14:06:46.817415", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-6.3-Deep-N-Gram-Models.ipynb b/resources/examples/ipynb/Example-6.3-Deep-N-Gram-Models.ipynb new file mode 100644 index 000000000..88de24d90 --- /dev/null +++ b/resources/examples/ipynb/Example-6.3-Deep-N-Gram-Models.ipynb @@ -0,0 +1,971 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": 
"lAAzPCP8n05S" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CcV2B-3LnvBk" + }, + "source": [ + "Author - [@SauravMaheshkar](https://github.com/SauravMaheshkar)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uEg7rw6fnr0q", + "papermill": { + "duration": 0.024472, + "end_time": "2020-10-19T05:23:45.163806", + "exception": false, + "start_time": "2020-10-19T05:23:45.139334", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Downloading the Trax Package" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7iVotT-qnr0q", + "papermill": { + "duration": 0.024546, + "end_time": "2020-10-19T05:23:45.211638", + "exception": false, + "start_time": "2020-10-19T05:23:45.187092", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "import os\n", + "import sys\n", + "\n", + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s4e-X6Ranr0s", + "papermill": { + "duration": 0.121469, + "end_time": "2020-10-19T05:24:41.120599", + "exception": false, + "start_time": "2020-10-19T05:24:40.999130", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Importing Packages" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zaoHVZj0nr0s", + "papermill": { + "duration": 0.117117, + "end_time": "2020-10-19T05:24:41.355694", + "exception": false, + "start_time": "2020-10-19T05:24:41.238577", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "In this notebook we will use the following packages:\n", + "\n", + "* [**Pandas**](https://pandas.pydata.org/) is a fast, powerful, flexible and easy to use open-source data analysis and manipulation tool, built on top of the Python programming language. 
It offers a fast and efficient DataFrame object for data manipulation with integrated indexing.\n", + "* [**os**](https://docs.python.org/3/library/os.html) module provides a portable way of using operating system dependent functionality.\n", + "* [**trax**](https://trax-ml.readthedocs.io/en/latest/trax.html) is an end-to-end library for deep learning that focuses on clear code and speed.\n", + "* [**random**](https://docs.python.org/3/library/random.html) module implements pseudo-random number generators for various distributions.\n", + "* [**itertools**](https://docs.python.org/3/library/itertools.html) module implements a number of iterator building blocks inspired by constructs from APL, Haskell, and SML. Each has been recast in a form suitable for Python." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "import shutil\n", + "import trax.fastmath.numpy as np\n", + "import random as rnd\n", + "from trax import layers as tl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZZaUGa2Lnr0s", + "papermill": { + "duration": 0.118759, + "end_time": "2020-10-19T05:24:54.899617", + "exception": false, + "start_time": "2020-10-19T05:24:54.780858", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Loading the Data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WbwaTxIFnr0s", + "papermill": { + "duration": 0.122704, + "end_time": "2020-10-19T05:24:55.144895", + "exception": false, + "start_time": "2020-10-19T05:24:55.022191", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "For this project, I've used the [gothic-literature](https://www.kaggle.com/charlesaverill/gothic-literature), [shakespeare-plays](https://www.kaggle.com/kingburrito666/shakespeare-plays) and [shakespeareonline](https://www.kaggle.com/kewagbln/shakespeareonline) datasets from the Kaggle library.\n", + "\n", + "We perform the following steps for loading in the data:\n", + "\n", + "* Iterate over all the directories in the `/kaggle/input/` directory\n", + "* Filter out `.txt` files\n", + "* Make a `lines` list containing the individual lines from all the datasets combined" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import subprocess\n", + "import zipfile\n", + "\n", + "\n", + "def download_datasets(download_dir):\n", + " os.makedirs(download_dir, exist_ok=True)\n", + "\n", + " # Define the datasets with output filename and download URL\n", + " datasets = [\n", + " {\n", + " \"filename\": \"gothic-literature.zip\",\n", + " \"url\": \"https://www.kaggle.com/api/v1/datasets/download/charlesaverill/gothic-literature\"\n", + " },\n", + " {\n", + " \"filename\": \"shakespeare-plays.zip\",\n", + " \"url\": \"https://www.kaggle.com/api/v1/datasets/download/kingburrito666/shakespeare-plays\"\n", + " },\n", + " {\n", + " \"filename\": \"shakespeareonline.zip\",\n", + " \"url\": \"https://www.kaggle.com/api/v1/datasets/download/kewagbln/shakespeareonline\"\n", + " }\n", + " ]\n", + "\n", + " # Download each dataset using curl\n", + " for dataset in datasets:\n", + " output_path = os.path.join(download_dir, dataset[\"filename\"])\n", + " # Build the curl command (using -L for following redirects)\n", + " cmd = [\n", + " \"curl\",\n", + " \"-L\",\n", + " \"-o\", output_path,\n", + " dataset[\"url\"]\n", + " ]\n", + " print(f\"Downloading {dataset['filename']}...\")\n", + " subprocess.run(cmd, check=True)\n", + " print(f\"Downloaded to {output_path}\")\n", + "\n", + 
"\n", + "def extract_zip_files(download_dir, extract_dir):\n", + " os.makedirs(extract_dir, exist_ok=True)\n", + "\n", + " # Iterate through the zip files in the download directory\n", + " for file in os.listdir(download_dir):\n", + " if file.lower().endswith(\".zip\"):\n", + " zip_path = os.path.join(download_dir, file)\n", + " # Create a subdirectory for each zip file (optional)\n", + " extract_subdir = os.path.join(extract_dir, os.path.splitext(file)[0])\n", + " os.makedirs(extract_subdir, exist_ok=True)\n", + " print(f\"Extracting {zip_path} to {extract_subdir}...\")\n", + " with zipfile.ZipFile(zip_path, 'r') as z:\n", + " z.extractall(extract_subdir)\n", + " print(\"Extraction completed.\")\n", + "\n", + "\n", + "def read_text_files(extracted_dir):\n", + " lines = []\n", + "\n", + " # Walk through the unzipped directories and process each .txt file\n", + " for root, _, files in os.walk(extracted_dir):\n", + " for filename in files:\n", + " if filename.lower().endswith(\".txt\"):\n", + " file_path = os.path.join(root, filename)\n", + " print(f\"Reading {file_path}...\")\n", + " with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:\n", + " for line in f:\n", + " processed_line = line.strip()\n", + " if processed_line:\n", + " lines.append(processed_line)\n", + " return lines" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set download and extraction directories\n", + "download_dir = os.path.expanduser(\"~/Downloads\")\n", + "extract_dir = os.path.join(download_dir, \"extracted_datasets\")\n", + "\n", + "# Download datasets using curl\n", + "download_datasets(download_dir)\n", + "\n", + "# Extract downloaded zip files\n", + "extract_zip_files(download_dir, extract_dir)\n", + "\n", + "# Read text files from extracted data\n", + "all_lines = read_text_files(extract_dir)\n", + "\n", + "print(f\"Total non-empty lines read: {len(all_lines)}\")\n", + "# For example purposes, printing first 10 lines\n", + "print(\"\\nFirst 10 lines:\")\n", + "for line in all_lines[:10]:\n", + " print(line)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EPifypFdnr0s", + "papermill": { + "duration": 0.113664, + "end_time": "2020-10-19T05:24:55.951966", + "exception": false, + "start_time": "2020-10-19T05:24:55.838302", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Pre-Processing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eU58tWP3nr0s", + "papermill": { + "duration": 0.119888, + "end_time": "2020-10-19T05:24:56.194726", + "exception": false, + "start_time": "2020-10-19T05:24:56.074838", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Converting to Lowercase\n", + "\n", + "Converting all the characters in the `lines` list to **lowercase**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QAxU3uzunr0s", + "papermill": { + "duration": 0.253923, + "end_time": "2020-10-19T05:24:56.569875", + "exception": false, + "start_time": "2020-10-19T05:24:56.315952", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "for i, line in enumerate(all_lines):\n", + " all_lines[i] = line.lower()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "voNUJBrRnr0s", + "papermill": { + "duration": 0.11122, + "end_time": "2020-10-19T05:24:56.795120", + "exception": false, + "start_time": "2020-10-19T05:24:56.683900", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Converting into Tensors\n", + "\n", + "Creating a function to convert each line into a tensor by converting each character into it's ASCII value. And adding a optional `EOS`(**End of statement**) character." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "J0F2sUJfnr0s", + "papermill": { + "duration": 0.131432, + "end_time": "2020-10-19T05:24:57.037392", + "exception": false, + "start_time": "2020-10-19T05:24:56.905960", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def line_to_tensor(line, EOS_int=1):\n", + " tensor = []\n", + " for c in line:\n", + " c_int = ord(c)\n", + " tensor.append(c_int)\n", + "\n", + " tensor.append(EOS_int)\n", + "\n", + " return tensor" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zYT5__Danr0s", + "papermill": { + "duration": 0.109763, + "end_time": "2020-10-19T05:24:57.259043", + "exception": false, + "start_time": "2020-10-19T05:24:57.149280", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Creating a Batch Generator\n", + "\n", + "Here, we create a `batch_generator()` function to yield a batch and mask generator. 
We perform the following steps:\n", + "\n", + "* Shuffle the lines if not shuffled\n", + "* Convert the lines into a Tensor\n", + "* Pad the lines if it's less than the maximum length\n", + "* Generate a mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V-D_5L_snr0s", + "papermill": { + "duration": 0.134497, + "end_time": "2020-10-19T05:24:57.503870", + "exception": false, + "start_time": "2020-10-19T05:24:57.369373", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def data_generator(batch_size, max_length, data_lines, line_to_tensor=line_to_tensor, shuffle=True):\n", + " index = 0\n", + " cur_batch = []\n", + " num_lines = len(data_lines)\n", + " lines_index = [*range(num_lines)]\n", + "\n", + " if shuffle:\n", + " rnd.shuffle(lines_index)\n", + "\n", + " while True:\n", + "\n", + " if index >= num_lines:\n", + " index = 0\n", + " if shuffle:\n", + " rnd.shuffle(lines_index)\n", + "\n", + " line = data_lines[lines_index[index]]\n", + "\n", + " if len(line) < max_length:\n", + " cur_batch.append(line)\n", + "\n", + " index += 1\n", + "\n", + " if len(cur_batch) == batch_size:\n", + "\n", + " batch = []\n", + " mask = []\n", + "\n", + " for li in cur_batch:\n", + " tensor = line_to_tensor(li)\n", + "\n", + " pad = [0] * (max_length - len(tensor))\n", + " tensor_pad = tensor + pad\n", + " batch.append(tensor_pad)\n", + "\n", + " example_mask = [0 if t == 0 else 1 for t in tensor_pad]\n", + " mask.append(example_mask)\n", + "\n", + " batch_np_arr = np.array(batch)\n", + " mask_np_arr = np.array(mask)\n", + "\n", + " yield batch_np_arr, batch_np_arr, mask_np_arr\n", + "\n", + " cur_batch = []\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generator = data_generator(2, 10, all_lines, line_to_tensor=line_to_tensor, shuffle=True)\n", + "print(next(generator))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "biglhqPjnr0s", + "papermill": { + "duration": 0.113922, + "end_time": "2020-10-19T05:24:57.728762", + "exception": false, + "start_time": "2020-10-19T05:24:57.614840", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Defining the Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6JgMdnTonr0s", + "papermill": { + "duration": 0.110544, + "end_time": "2020-10-19T05:24:57.950897", + "exception": false, + "start_time": "2020-10-19T05:24:57.840353", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Gated Recurrent Unit\n", + "\n", + "This function generates a GRU Language Model, consisting of the following layers:\n", + "\n", + "* ShiftRight()\n", + "* Embedding()\n", + "* GRU Units(Number specified by the `n_layers` parameter)\n", + "* Dense() Layer\n", + "* LogSoftmax() Activation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MSA3bpCHnr0s", + "papermill": { + "duration": 0.124594, + "end_time": "2020-10-19T05:24:58.186525", + "exception": false, + "start_time": "2020-10-19T05:24:58.061931", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def GRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", + " model = tl.Serial(\n", + " tl.ShiftRight(mode=mode),\n", + " tl.Embedding(vocab_size=vocab_size, d_feature=d_model),\n", + " [tl.GRU(n_units=d_model) for _ in range(n_layers)],\n", + " tl.Dense(n_units=vocab_size),\n", + " tl.LogSoftmax()\n", + " )\n", + " return model" + ] + }, + { + "cell_type": 
"markdown", + "metadata": { + "id": "9A0JtfgCnr0s", + "papermill": { + "duration": 0.150132, + "end_time": "2020-10-19T05:24:58.463252", + "exception": false, + "start_time": "2020-10-19T05:24:58.313120", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Long Short Term Memory\n", + "\n", + "This function generates a LSTM Language Model, consisting of the following layers:\n", + "\n", + "* ShiftRight()\n", + "* Embedding()\n", + "* LSTM Units(Number specified by the `n_layers` parameter)\n", + "* Dense() Layer\n", + "* LogSoftmax() Activation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ScuXPmvLnr0s", + "papermill": { + "duration": 0.129976, + "end_time": "2020-10-19T05:24:58.717410", + "exception": false, + "start_time": "2020-10-19T05:24:58.587434", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def LSTMLM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", + " model = tl.Serial(\n", + " tl.ShiftRight(mode=mode),\n", + " tl.Embedding(vocab_size=vocab_size, d_feature=d_model),\n", + " [tl.LSTM(n_units=d_model) for _ in range(n_layers)],\n", + " tl.Dense(n_units=vocab_size),\n", + " tl.LogSoftmax()\n", + " )\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zWVaUwG1nr0s", + "papermill": { + "duration": 0.130305, + "end_time": "2020-10-19T05:24:58.971978", + "exception": false, + "start_time": "2020-10-19T05:24:58.841673", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Simple Recurrent Unit\n", + "\n", + "This function generates a SRU Language Model, consisting of the following layers:\n", + "\n", + "* ShiftRight()\n", + "* Embedding()\n", + "* SRU Units(Number specified by the `n_layers` parameter)\n", + "* Dense() Layer\n", + "* LogSoftmax() Activation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ECzZRknPnr0s", + "papermill": { + "duration": 0.12795, + "end_time": "2020-10-19T05:24:59.221979", + "exception": false, + "start_time": "2020-10-19T05:24:59.094029", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def SRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", + " model = tl.Serial(\n", + " tl.ShiftRight(mode=mode),\n", + " tl.Embedding(vocab_size=vocab_size, d_feature=d_model),\n", + " [tl.SRU(n_units=d_model) for _ in range(n_layers)],\n", + " tl.Dense(n_units=vocab_size),\n", + " tl.LogSoftmax()\n", + " )\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1i8UlSvhnr0s", + "outputId": "f4894449-5399-48c8-e22d-a8fa05be3615", + "papermill": { + "duration": 0.132413, + "end_time": "2020-10-19T05:24:59.466681", + "exception": false, + "start_time": "2020-10-19T05:24:59.334268", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "GRUmodel = GRULM(n_layers=5)\n", + "LSTMmodel = LSTMLM(n_layers=5)\n", + "SRUmodel = SRULM(n_layers=5)\n", + "print(GRUmodel)\n", + "print(LSTMmodel)\n", + "print(SRUmodel)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "As2O2Zj8nr0t", + "papermill": { + "duration": 0.117255, + "end_time": "2020-10-19T05:24:59.712882", + "exception": false, + "start_time": "2020-10-19T05:24:59.595627", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Hyperparameters" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cxIs1y_Gnr0t", + "papermill": { + "duration": 0.113458, + "end_time": 
"2020-10-19T05:24:59.939569", + "exception": false, + "start_time": "2020-10-19T05:24:59.826111", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we declare `the batch_size` and the `max_length` hyperparameters for the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BLKz_gfKnr0t", + "papermill": { + "duration": 0.121757, + "end_time": "2020-10-19T05:25:00.176474", + "exception": false, + "start_time": "2020-10-19T05:25:00.054717", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "batch_size = 32\n", + "max_length = 64" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zUKNlXAmnr0t", + "papermill": { + "duration": 0.111425, + "end_time": "2020-10-19T05:25:00.399880", + "exception": false, + "start_time": "2020-10-19T05:25:00.288455", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Creating Evaluation and Training Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TYJepc9Knr0t", + "papermill": { + "duration": 0.130539, + "end_time": "2020-10-19T05:25:00.641885", + "exception": false, + "start_time": "2020-10-19T05:25:00.511346", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "eval_lines = all_lines[-1000:] # Create a holdout validation set\n", + "lines = all_lines[:-1000] # Leave the rest for training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1DbI1fFSnr0t", + "papermill": { + "duration": 0.112994, + "end_time": "2020-10-19T05:25:00.871007", + "exception": false, + "start_time": "2020-10-19T05:25:00.758013", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Training the Models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8LKJoIzenr0t", + "papermill": { + "duration": 0.112218, + "end_time": "2020-10-19T05:25:01.096544", + "exception": false, + "start_time": "2020-10-19T05:25:00.984326", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we create a function to train the models. 
This function does the following:\n", + "\n", + "* Creating a Train and Evaluation Generator that cycles infinetely using the `itertools` module\n", + "* Train the Model using Adam Optimizer\n", + "* Use the Accuracy Metric for Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "i4-fSW3Tnr0t", + "papermill": { + "duration": 0.130503, + "end_time": "2020-10-19T05:25:01.339549", + "exception": false, + "start_time": "2020-10-19T05:25:01.209046", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from trax.learning.supervised import training\n", + "from trax import optimizers as optimizers\n", + "import itertools\n", + "\n", + "\n", + "def train_model(model, data_generator, batch_size=32, max_length=64, lines=lines, eval_lines=eval_lines, n_steps=10,\n", + " output_dir='model/'):\n", + " bare_train_generator = data_generator(batch_size, max_length, data_lines=lines)\n", + " infinite_train_generator = itertools.cycle(bare_train_generator)\n", + "\n", + " bare_eval_generator = data_generator(batch_size, max_length, data_lines=eval_lines)\n", + " infinite_eval_generator = itertools.cycle(bare_eval_generator)\n", + "\n", + " train_task = training.TrainTask(\n", + " labeled_data=infinite_train_generator,\n", + " loss_layer=tl.CrossEntropyLoss(),\n", + " optimizer=optimizers.Adam(0.0005),\n", + " n_steps_per_checkpoint=1\n", + " )\n", + "\n", + " eval_task = training.EvalTask(\n", + " labeled_data=infinite_eval_generator,\n", + " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n", + " n_eval_batches=1\n", + " )\n", + "\n", + " training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir\n", + " )\n", + "\n", + " training_loop.run(n_steps=n_steps)\n", + "\n", + " return training_loop\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dykzx2t1nr0t", + "papermill": { + "duration": 79.597768, + "end_time": "2020-10-19T05:26:21.064134", + "exception": false, + "start_time": "2020-10-19T05:25:01.466366", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "shutil.rmtree(os.path.expanduser('model/GRU'), ignore_errors=True)\n", + "GRU_training_loop = train_model(GRUmodel, data_generator, n_steps=10, output_dir='model/GRU')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4w9jvGYDnr0t", + "papermill": { + "duration": 93.801876, + "end_time": "2020-10-19T05:27:55.049974", + "exception": false, + "start_time": "2020-10-19T05:26:21.248098", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "shutil.rmtree(os.path.expanduser('model/LSTM'), ignore_errors=True)\n", + "LSTM_training_loop = train_model(LSTMmodel, data_generator, n_steps=10, output_dir='model/LSTM')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PWePFGVKnr0t", + "papermill": { + "duration": 41.004194, + "end_time": "2020-10-19T05:28:36.239938", + "exception": false, + "start_time": "2020-10-19T05:27:55.235744", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "shutil.rmtree(os.path.expanduser('model/SRU'), ignore_errors=True)\n", + "SRU_training_loop = train_model(SRUmodel, data_generator, n_steps=50_000, output_dir='model/SRU')" + ] + } + ], + "metadata": { + "colab": { + "include_colab_link": true, + "name": "Deep N-Gram Models", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + 
"language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "papermill": { + "duration": 297.094983, + "end_time": "2020-10-19T05:28:36.576660", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2020-10-19T05:23:39.481677", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-6.4-NMT-with-Transformers-Reformers.ipynb b/resources/examples/ipynb/Example-6.4-NMT-with-Transformers-Reformers.ipynb new file mode 100644 index 000000000..79dc86a23 --- /dev/null +++ b/resources/examples/ipynb/Example-6.4-NMT-with-Transformers-Reformers.ipynb @@ -0,0 +1,1727 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lAAzPCP8n05S" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2021 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hqqdEx7xtHuH" + }, + "source": [ + "# **NMT with Transformers/Reformers using Trax**\n", + "\n", + "A guide to Neural Machine Translation using ***Transformers/Reformers***. Includes a detailed tutorial using ***Trax*** in Google Colaboratory.\n", + "\n", + "Machine translation is an important task in natural language processing and could be useful not only for translating one language to another but also for word sense disambiguation.\n", + "\n", + "In this Notebook you will:\n", + "* Learn how to preprocess your training and evaluation data.\n", + "* implement an encoder-decoder system with attention.\n", + "* understand how attention works.\n", + "* build the NMT model from scratch using Trax.\n", + "* learn how to preprocess your training and evaluation data.\n", + "* generate translations using greedy and Minimum Bayes Risk (MBR) decoding.\n", + "\n", + "This notebook contains a lot of cells taken from [Natural Language Processing Specialization](https://www.coursera.org/specializations/natural-language-processing)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R8u7YU2uqOXH" + }, + "source": [ + "# Part (-1): Run on CPU/GPU/TPU" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pO10zU6I87dc" + }, + "source": [ + "This notebook was designed to run on TPU.\n", + "\n", + "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. 
Set \"TPU\" as the hardware accelerator.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8QCsYnkLv59s", + "outputId": "29c114d1-c940-4411-fcf1-984a34b7f9fa" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QVw5457jqlOm" + }, + "source": [ + "# Part (0): Important Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nA7u_MqG9dmQ", + "outputId": "741a1e11-319a-4742-e38e-6217da1295e9" + }, + "outputs": [], + "source": [ + "from trax import layers as tl\n", + "from trax.learning.supervised import training\n", + "from trax.data.preprocessing import inputs as preprocessing\n", + "from trax.data.encoder import encoder\n", + "from trax.data.loader.tf import base as dataset\n", + "from trax import models\n", + "from trax import optimizers\n", + "from trax.learning.supervised import lr_schedules as learning_schedule\n", + "\n", + "import numpy as np\n", + "\n", + "from termcolor import colored\n", + "import random\n", + "import shutil" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aByNRLKr9dmG" + }, + "source": [ + "# Part (1): Data Preparation\n", + "\n", + "**You Can jump directly to Trax Data Pipeline (optional) Section and skip 1.1 to 1.5 sections.**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_WpKodqa9dmJ" + }, + "source": [ + "## 1.1 Importing the Data\n", + "We will be using [ParaCrawl](https://paracrawl.eu/), a large multi-lingual translation dataset created by the European Union. All of these datasets are available via [TFDS para_crawl](https://www.tensorflow.org/datasets/catalog/para_crawl). We used English to French dataset. You can try the other avaliable languages by changing the `dataset_name` and `keys`. Or even try another datasets available at TFDS.\n", + "\n", + "Notice: It will take a while in the first time to download the dataset. So, it is prefered to specify `data_dir` on Google Drive not in Colab runtime. Try other than para_crawl dataset. since, the para_crawl is a large dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jEJaYJ5C9dmb" + }, + "outputs": [], + "source": [ + "# This will download the train dataset if no data_dir is specified.\n", + "train_stream_fn = dataset.TFDS('para_crawl/enfr',\n", + " data_dir='~/tensorflow_datasets/',\n", + " keys=('en', 'fr'),\n", + " eval_holdout_size=0.01, # 1% for eval\n", + " train=True)\n", + "\n", + "# Get generator function for the eval set\n", + "eval_stream_fn = dataset.TFDS('para_crawl/enfr',\n", + " data_dir='~/tensorflow_datasets/',\n", + " keys=('en', 'fr'),\n", + " eval_holdout_size=0.01, # 1% for eval\n", + " train=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kk-x0gW9-qsD" + }, + "source": [ + "You can work with your own datasets instead of loading your dataset with TFDS. 
Opening a file as shown below creates that generator for you. Don't forget to make another function for eval.\n", + "\n", + "```python\n", + "def train_stream_fn():\n", + "    # provide an infinite generator\n", + "    while True:\n", + "        # open the first language file (e.g. English sentences)\n", + "        with open('lang1.csv','r') as f1:\n", + "            # open the second language file (e.g. French sentences)\n", + "            with open('lang2.csv','r') as f2:\n", + "                # loop over the two files to combine the two translations together and yield them\n", + "                for l1, l2 in zip(f1,f2):\n", + "                    yield (l1, l2)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tPlcZf3RLNAg" + }, + "source": [ + "Notice that TFDS returns a generator *function*.\n", + "\n", + "Let's print a sample pair from our train and eval data. Notice that the raw output is represented in bytes (denoted by the `b'` prefix) and these will be converted to strings internally in the next steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "16UrIf259dml", + "outputId": "9a860216-c9fa-4f29-e28a-e9c535feefd4" + }, + "outputs": [], + "source": [ + "train_stream = train_stream_fn()\n", + "print(colored('train data (en, fr) tuple:', 'red'), next(train_stream))\n", + "print()\n", + "\n", + "eval_stream = eval_stream_fn()\n", + "print(colored('eval data (en, fr) tuple:', 'red'), next(eval_stream))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kWUH9_PNIe5g" + }, + "source": [ + "Now that we have imported our corpus, we will be preprocessing the sentences into a format that our model can accept. This will be composed of several steps:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VY6_SnLM9dms" + }, + "source": [ + "## 1.2 Tokenization and Formatting\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PWP3GAoXHiwo" + }, + "source": [ + "**Tokenizing the sentences using subword representations:** We want to represent each sentence as an array of integers instead of strings. For our application, we will use *subword* representations to tokenize our sentences. This is a common technique to avoid out-of-vocabulary words by allowing parts of words to be represented separately. For example, instead of having separate entries in your vocabulary for \"fear\", \"fearless\", \"fearsome\", \"some\", and \"less\", you can simply store \"fear\", \"some\", and \"less\", then allow your tokenizer to combine these subwords when needed. This allows it to be more flexible so you won't have to save uncommon words explicitly in your vocabulary (e.g. *stylebender*, *nonce*, etc). Tokenizing is done with the `trax.data.Tokenize()` command. The combined subword vocabulary for English, German and French (i.e. `endefr_32k.subword`) is provided by trax. Feel free to open this file to see what the subwords look like."
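, + "\n", + "To build intuition for how a subword vocabulary is applied, here is a toy sketch of greedy longest-match tokenization over a made-up five-entry vocabulary. This is only an illustration of the idea, not the actual algorithm behind `trax.data.Tokenize()`.\n", + "\n", + "```python\n", + "# Toy illustration (hypothetical vocabulary; not Trax's real tokenizer).\n", + "vocab = ['fear', 'some', 'less', 'hell', 'o']\n", + "\n", + "def toy_subword_tokenize(word):\n", + "    pieces = []\n", + "    while word:\n", + "        # take the longest vocabulary entry that prefixes the remaining text\n", + "        match = max((p for p in vocab if word.startswith(p)), key=len, default=None)\n", + "        if match is None:\n", + "            raise ValueError('cannot tokenize: ' + word)\n", + "        pieces.append(match)\n", + "        word = word[len(match):]\n", + "    return pieces\n", + "\n", + "print(toy_subword_tokenize('fearless'))  # ['fear', 'less']\n", + "print(toy_subword_tokenize('hello'))     # ['hell', 'o']\n", + "```"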
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# global variables that state the filename and directory of the vocabulary file\n", + "VOCAB_FILE = 'endefr_32k.subword'\n", + "VOCAB_DIR = 'gs://trax-ml/vocabs/'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q8R2RxvK9dmt" + }, + "outputs": [], + "source": [ + "# Tokenize the dataset.\n", + "tokenized_train_stream = encoder.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)\n", + "tokenized_eval_stream = encoder.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yrmCi915HTKA" + }, + "source": [ + "**Append an end-of-sentence token to each sentence:** We will assign a token (in this case `1`) to mark the end of a sentence. This will be useful in inference/prediction so we'll know that the model has completed the translation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RuolzODV9dm0" + }, + "outputs": [], + "source": [ + "# Append EOS at the end of each sentence.\n", + "\n", + "# Integer assigned as end-of-sentence (EOS)\n", + "EOS = 1\n", + "\n", + "\n", + "# generator helper function to append EOS to each sentence\n", + "def append_eos(stream):\n", + " for (inputs, targets) in stream:\n", + " inputs_with_eos = list(inputs) + [EOS]\n", + " targets_with_eos = list(targets) + [EOS]\n", + " yield np.array(inputs_with_eos), np.array(targets_with_eos)\n", + "\n", + "\n", + "# append EOS to the train data\n", + "tokenized_train_stream = append_eos(tokenized_train_stream)\n", + "\n", + "# append EOS to the eval data\n", + "tokenized_eval_stream = append_eos(tokenized_eval_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rbaYhKr99dm8" + }, + "source": [ + "**Filter long sentences:** We will place a limit on the number of tokens per sentence to ensure we won't run out of memory. This is done with the `trax.data.FilterByLength()` method and you can see its syntax below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Miw7Uu849dm9", + "outputId": "3a6b35a5-f257-42fb-b914-59534d3f2a76" + }, + "outputs": [], + "source": [ + "# Filter out too-long sentences so we don't run out of memory.\n", + "# length_keys=[0, 1] means we filter both English and French sentences, so\n", + "# both must be no longer than 512 tokens for training / 1024 for eval.\n", + "filtered_train_stream = preprocessing.FilterByLength(\n", + " max_length=512, length_keys=[0, 1])(tokenized_train_stream)\n", + "filtered_eval_stream = preprocessing.FilterByLength(\n", + " max_length=1024, length_keys=[0, 1])(tokenized_eval_stream)\n", + "\n", + "# print a sample input-target pair of tokenized sentences\n", + "train_input, train_target = next(filtered_train_stream)\n", + "print(colored(f'Single tokenized example input:', 'red'), train_input)\n", + "print(colored(f'Single tokenized example target:', 'red'), train_target)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WD0ZqedYIpr3" + }, + "source": [ + "## 1.3 Tokenize & detokenize helper functions\n", + "\n", + "- tokenize(): converts a text sentence to its corresponding token list (i.e. list of indices). Also converts words to subwords (parts of words).\n", + "- detokenize(): converts a token list to its corresponding sentence (i.e. string)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OyO5I2e_9dnD" + }, + "outputs": [], + "source": [ + "# Setup helper functions for tokenizing and retokenizing sentences\n", + "def tokenize(input_str, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Encodes a string to an array of integers\n", + " Args:\n", + " input_str (str): human-readable string to encode\n", + " vocab_file (str): filename of the vocabulary text file\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " numpy.ndarray: tokenized version of the input string\n", + " \"\"\"\n", + " # Set the encoding of the \"end of sentence\" as 1\n", + " EOS = 1\n", + " # Use the trax.data.tokenize method. It takes streams and returns streams,\n", + " # we get around it by making a 1-element stream with `iter`.\n", + " inputs = next(encoder.tokenize(iter([input_str]),\n", + " vocab_file=vocab_file, vocab_dir=vocab_dir))\n", + " # Mark the end of the sentence with EOS\n", + " inputs = list(inputs) + [EOS]\n", + " # Adding the batch dimension to the front of the shape\n", + " batch_inputs = np.reshape(np.array(inputs), [1, -1])\n", + " return batch_inputs\n", + "\n", + "\n", + "def detokenize(integers, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Decodes an array of integers to a human-readable string\n", + " Args:\n", + " integers (numpy.ndarray): array of integers to decode\n", + " vocab_file (str): filename of the vocabulary text file\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " str: the decoded sentence.\n", + " \"\"\"\n", + " # Remove the dimensions of size 1\n", + " integers = list(np.squeeze(integers))\n", + " # Set the encoding of the \"end of sentence\" as 1\n", + " EOS = 1\n", + " # Remove the EOS to decode only the original tokens\n", + " if EOS in integers:\n", + " integers = integers[:integers.index(EOS)]\n", + " return encoder.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NKfYr4SA9dnH" + }, + "source": [ + "Let's see how we might use these functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xb7UEVAS9dnI", + "outputId": "dc1cc233-77ef-4ee2-93cc-34f7ddc586c2" + }, + "outputs": [], + "source": [ + "# Detokenize an input-target pair of tokenized sentences\n", + "print(colored(f'Single retokenized example input:', 'red'),\n", + " detokenize(train_input, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))\n", + "print(colored(f'Single retokenized example target:', 'red'),\n", + " detokenize(train_target, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))\n", + "print()\n", + "\n", + "# Tokenize and detokenize a word that is not explicitly saved in the vocabulary file.\n", + "# See how it combines the subwords 'hell' and 'o' to form the word 'hello'.\n", + "print(colored(f\"tokenize('hello'): \", 'green'), tokenize('hello', vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r_8UOdZ_9dnO" + }, + "source": [ + "## 1.4 Bucketing\n", + "\n", + "Bucketing the tokenized sentences is an important technique used to speed up training in NLP.\n", + "Here is a\n", + "[nice article describing it in detail](https://medium.com/@rashmi.margani/how-to-speed-up-the-training-of-the-sequence-model-using-bucketing-techniques-9e302b0fd976)\n", + "but the gist is very simple. 
Our inputs have variable lengths and we want to make these the same when batching groups of sentences together. One way to do that is to pad each sentence to the length of the longest sentence in the dataset. This might lead to some wasted computation though. For example, if there are multiple short sentences with just two tokens, do we want to pad these when the longest sentence is composed of 100 tokens? Instead of padding with 0s to the maximum length of a sentence each time, we can group our tokenized sentences by length into buckets, as in this image (from the article above):\n", + "\n", + "![alt text](https://miro.medium.com/max/700/1*hcGuja_d5Z_rFcgwe9dPow.png)\n", + "\n", + "We batch the sentences with similar length together (e.g. the blue sentences in the image above) and only add minimal padding to make them have equal length (usually up to the nearest power of two). This lets us waste less computation when processing padded sequences.\n", + "In Trax, it is implemented in the [bucket_by_length](https://github.com/google/trax/blob/5fb8aa8c5cb86dabb2338938c745996d5d87d996/trax/supervised/inputs.py#L378) function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MUlfg9kX9dnP" + }, + "outputs": [], + "source": [ + "# Bucketing to create streams of batches.\n", + "\n", + "# Buckets are defined in terms of boundaries and batch sizes.\n", + "# Batch_sizes[i] determines the batch size for items with length < boundaries[i]\n", + "# So below, we'll take a batch of 128 sentences of length < 8, 128 if length is\n", + "# between 8 and 16, and so on. A batch of 128 is also used if the length is over 256.\n", + "boundaries = [8, 16, 32, 64, 128, 256]\n", + "batch_sizes = [128, 128, 128, 128, 128, 128, 128]\n", + "# Notice they are all 128. As we are using TPUs, we need the same batch_size to run in parallel.\n", + "# You can use different batch_sizes if you are using a GPU or CPU.\n", + "\n", + "# Create the generators.\n", + "train_batch_stream = preprocessing.BucketByLength(\n", + " boundaries, batch_sizes,\n", + " length_keys=[0, 1] # As before: count inputs and targets to length.\n", + ")(filtered_train_stream)\n", + "\n", + "eval_batch_stream = preprocessing.BucketByLength(\n", + " boundaries, batch_sizes,\n", + " length_keys=[0, 1] # As before: count inputs and targets to length.\n", + ")(filtered_eval_stream)\n", + "\n", + "# Add masking for the padding (0s).\n", + "train_batch_stream = preprocessing.AddLossWeights(id_to_mask=0)(train_batch_stream)\n", + "eval_batch_stream = preprocessing.AddLossWeights(id_to_mask=0)(eval_batch_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v5IDVjXl9dnU" + }, + "source": [ + "## 1.5 Exploring the data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vX-ukU52No8Q" + }, + "source": [ + "We will now be displaying some of our data. You will see that the functions defined above (i.e. 
`tokenize()` and `detokenize()`) are put to use here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zI_Rea2Q9dnV", + "outputId": "2db581ca-3b3a-450f-9b91-1d924420fa51" + }, + "outputs": [], + "source": [ + "input_batch, target_batch, mask_batch = next(train_batch_stream)\n", + "\n", + "# let's see the data type of batch\n", + "print(\"input_batch data type: \", type(input_batch))\n", + "print(\"target_batch data type: \", type(target_batch))\n", + "print(\"mask_batch data type: \", type(mask_batch))\n", + "\n", + "# let's see the shape of this particular batch (batch length, sentence length)\n", + "print(\"input_batch shape: \", input_batch.shape)\n", + "print(\"target_batch shape: \", target_batch.shape)\n", + "print(\"mask_batch shape: \", mask_batch.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wE_ilByVN8zT" + }, + "source": [ + "The `input_batch` and `target_batch` are Numpy arrays consisting of tokenized English sentences and French sentences respectively. These tokens will later be used to produce embedding vectors for each word in the sentence (so the embedding for a sentence will be a matrix).\n", + "\n", + "We can now visually inspect some of the data. You can run the cell below several times to shuffle through the sentences." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vd_71uRi9dnb", + "outputId": "c25d1a92-8953-4d1b-c916-9bfde6882414" + }, + "outputs": [], + "source": [ + "# pick a random index less than the batch size.\n", + "index = random.randrange(len(input_batch))\n", + "\n", + "# use the index to grab an entry from the input and target batch\n", + "print(colored('THIS IS THE ENGLISH SENTENCE: \\n', 'red'),\n", + " detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", + "print(colored('THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \\n ', 'red'), input_batch[index], '\\n')\n", + "print(colored('THIS IS THE FRENCH TRANSLATION: \\n', 'red'),\n", + " detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", + "print(colored('THIS IS THE TOKENIZED VERSION OF THE FRENCH TRANSLATION: \\n', 'red'), target_batch[index], '\\n')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UDSPHBZaeRAW" + }, + "source": [ + "## Trax Data Pipeline (optional)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WP2RACXYeTse" + }, + "source": [ + "Those were the steps needed to prepare the data (sections 1.1 to 1.5), but you could simply use the [Trax data pipeline](https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html#Data) `trax.data.Serial` in the next cell. **If you run that cell, you should skip sections 1.1 to 1.5.**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BUPhstH70Xzu" + }, + "source": [ + "You can work with your own datasets instead of loading your dataset with TFDS: simply replace the TFDS call with `lambda _: train_stream_fn()`.\n", + "Everything in the `Serial` pipeline is a generator, and opening a file as shown below creates that generator for you.\n", + "\n", + "```python\n", + "def train_stream_fn():\n", + "    # provide an infinite generator\n", + "    while True:\n", + "        # open the first language file (e.g. English sentences)\n", + "        with open('lang1.csv','r') as f1:\n", + "            # open the second language file (e.g. 
French sentences)\n", + "            with open('lang2.csv','r') as f2:\n", + "                # loop over the two files to combine the two translations together and yield them\n", + "                for l1, l2 in zip(f1,f2):\n", + "                    yield (l1, l2)\n", + "```\n", + "\n", + "and then add\n", + "```python\n", + "lambda _: train_stream_fn()\n", + "```\n", + "to `trax.data.Serial()` instead of\n", + "```python\n", + "trax.data.TFDS('para_crawl/enfr',\n", + " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", + " keys=('en', 'fr'),\n", + " eval_holdout_size=0.01, # 1% for eval\n", + " train=True)\n", + "```\n", + "for both the training and eval streams." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If you run this cell, you should skip sections 1.1 to 1.5.\n", + "\n", + "# global variables that state the filename and directory of the vocabulary file\n", + "VOCAB_FILE = 'endefr_32k.subword'\n", + "VOCAB_DIR = 'gs://trax-ml/vocabs/'\n", + "\n", + "EOS = 1\n", + "\n", + "\n", + "# generator helper function to append EOS to each sentence\n", + "def append_eos(stream):\n", + " for (inputs, targets) in stream:\n", + " inputs_with_eos = list(inputs) + [EOS]\n", + " targets_with_eos = list(targets) + [EOS]\n", + " yield np.array(inputs_with_eos), np.array(targets_with_eos)\n", + "\n", + "\n", + "train_batches_stream = preprocessing.Serial(\n", + " dataset.TFDS('para_crawl/enfr',\n", + " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", + " keys=('en', 'fr'),\n", + " eval_holdout_size=0.01, # 1% for eval\n", + " train=True), # replace TFDS with lambda _: train_stream_fn() if you want to run with your own data\n", + " encoder.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),\n", + " lambda _: append_eos(_),\n", + " preprocessing.Shuffle(),\n", + " preprocessing.FilterByLength(max_length=512, length_keys=[0, 1]),\n", + " preprocessing.BucketByLength(boundaries=[8, 16, 32, 64, 128, 256],\n", + " batch_sizes=[128, 128, 128, 128, 128, 128, 128],\n", + " length_keys=[0, 1]),\n", + " preprocessing.AddLossWeights(id_to_mask=0)\n", + ")\n", + "\n", + "eval_batches_stream = preprocessing.Serial(\n", + " dataset.TFDS('para_crawl/enfr',\n", + " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", + " keys=('en', 'fr'),\n", + " eval_holdout_size=0.01, # 1% for eval\n", + " train=False),\n", + " encoder.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),\n", + " lambda _: append_eos(_),\n", + " preprocessing.Shuffle(),\n", + " preprocessing.FilterByLength(max_length=1024, length_keys=[0, 1]),\n", + " preprocessing.BucketByLength(boundaries=[8, 16, 32, 64, 128, 256],\n", + " batch_sizes=[128, 128, 128, 128, 128, 128, 128],\n", + " length_keys=[0, 1]),\n", + " preprocessing.AddLossWeights(id_to_mask=0)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wWvu5PraqBQx", + "outputId": "34480b60-6a40-4be6-f51c-4e481c2e1bc3", + "tags": [] + }, + "outputs": [], + "source": [ + "# Exploring the data\n", + "train_batch_stream = train_batches_stream()\n", + "eval_batch_stream = eval_batches_stream()\n", + "input_batch, target_batch, mask_batch = next(train_batch_stream)\n", + "\n", + "# let's see the data type of batch\n", + "print(\"input_batch data type: \", type(input_batch))\n", + "print(\"target_batch data type: \", type(target_batch))\n", + "# let's see the shape of this particular batch (batch length, sentence length)\n", + "print(\"input_batch shape: \", input_batch.shape)\n", + 
"print(\"target_batch shape: \", target_batch.shape)\n", + "\n", + "# pick a random index less than the batch size.\n", + "index = random.randrange(len(input_batch))\n", + "# use the index to grab an entry from the input and target batch\n", + "print(colored('ENGLISH SENTENCE: \\n', 'red'),\n", + " encoder.detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", + "print(colored('THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \\n ', 'red'), input_batch[index], '\\n')\n", + "print(colored('THE FRENCH TRANSLATION: \\n', 'red'),\n", + " encoder.detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", + "print(colored('THE TOKENIZED VERSION OF THE FRENCH TRANSLATION: \\n', 'red'), target_batch[index], '\\n')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M-69Mr-9_VEk" + }, + "source": [ + "# Part (2): Model\n", + "\n", + "Now that we’ve seen preprocessing, it’s time to move into Modeling itself. Trax allows the use of Predefined Models, such as:\n", + " - Seq2Seq with Attention\n", + " - BERT\n", + " - Transformer\n", + " - Reformer\n", + "\n", + "We will be using Transformer in this Notebook As Trax provided a pretrained Transformer NMT Model which is traind on English to German dataset and We now are going to train it on English to French dataset and get a very close results to the one provide by Google Brain Team.\n", + "\n", + "You can simply change `trax.models.Transformer` in the next cell to `trax.models.Reformer` to use the Reformer model.\n", + "\n", + "```python\n", + "# you could check the available pretrained models and vocab files provided by trax by running:\n", + "!gsutil ls gs://trax-ml/\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SkAuvdOErAlP" + }, + "outputs": [], + "source": [ + "# Create a Transformer model.\n", + "model = models.Transformer(\n", + " input_vocab_size=33600,\n", + " d_model=512, d_ff=2048, dropout=0.1,\n", + " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", + " max_len=2048, mode='train')\n", + "\n", + "is_remote = True\n", + "\n", + "if is_remote:\n", + " # Pre-trained Transformer model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", + " # Initialize Transformer using pre-trained weights.\n", + " model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", + " weights_only=True)\n", + "else:\n", + " # You also, could initiate the model from an output checkpoint.\n", + " # simply change 'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz' to 'output_dir/ + last_checkpoint'\n", + " # for example:\n", + " model.init_from_file(os.path.expanduser('~/Transformer_FR_pretrained_336/model.pkl.gz'), weights_only=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2p0AGzlKQusn" + }, + "source": [ + "You could have a peek at the model layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CvmdtOfeZ9Ff" + }, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E8FfOMp59doX" + }, + "source": [ + "# Part (3): Training\n", + "We will now be training our model in this section. Doing supervised training in Trax is pretty straightforward (short example [here](https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html#Supervised-training)). We will be instantiating three classes for this: `TrainTask`, `EvalTask`, and `Loop`. 
Let's take a closer look at each of these in the sections below." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "re1ZHUac9doY" + }, + "source": [ + "## 3.1 TrainTask\n", + "\n", + "The [TrainTask](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.training.TrainTask) class allows us to define the labeled data to use for training and the feedback mechanisms to compute the loss and update the weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gFP83q7S9doZ" + }, + "outputs": [], + "source": [ + "train_task = training.TrainTask(\n", + " # use the train batch stream as labeled data\n", + " labeled_data=train_batch_stream,\n", + " # use the cross entropy loss with LogSoftmax\n", + " loss_layer=tl.CrossEntropyLossWithLogSoftmax(),\n", + " # use the Adafactor optimizer with a learning rate of 0.001\n", + " optimizer=optimizers.Adafactor(learning_rate=0.001, epsilon1=1e-30),\n", + " # have 500 warmup steps\n", + " lr_schedule=learning_schedule.multifactor(constant=1.0, warmup_steps=500),\n", + " # have a checkpoint every 100 steps\n", + " n_steps_per_checkpoint=100,\n", + " # save a permanent checkpoint every 1000 steps in the output_dir\n", + " n_steps_per_permanent_checkpoint=1000\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7EQI-c999doi" + }, + "source": [ + "## 3.2 EvalTask\n", + "\n", + "The [EvalTask](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.training.EvalTask) on the other hand allows us to see how the model is doing while training. For our application, we want it to report the cross entropy loss with LogSoftmax and accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u5hVQ0Qd9doj" + }, + "outputs": [], + "source": [ + "eval_task = training.EvalTask(\n", + " # use the eval batch stream as labeled data\n", + " labeled_data=eval_batch_stream,\n", + " # use the cross entropy loss with LogSoftmax and accuracy as metrics\n", + " metrics=[tl.CrossEntropyLossWithLogSoftmax(), tl.WeightedCategoryAccuracy()],\n", + " # you could specify the number of eval batches with n_eval_batches = 64 or any other number,\n", + " # but it is not specified here as we want to evaluate on the whole eval data\n", + " # n_eval_batches = 64\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "14pSLHEw9dol" + }, + "source": [ + "## 3.3 Loop\n", + "\n", + "The [Loop](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.training.Loop) class defines the model we will train as well as the train and eval tasks to execute. Its `run()` method allows us to execute the training for a specified number of steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QdnRbEAz9dom" + }, + "outputs": [], + "source": [ + "# define the output directory\n", + "output_dir = '~/Transformer_FR_pretrained_336'\n", + "\n", + "# # remove old model if it exists. 
restarts training.\n", + "# !rm -rf output_dir\n", + "shutil.rmtree(os.path.expanduser(output_dir), ignore_errors=True)\n", + "\n", + "# define the training loop\n", + "training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bRk-1Wsu9doo", + "outputId": "1459c595-b218-4be8-d6ea-805147ca20c5" + }, + "outputs": [], + "source": [ + "# Start Training!\n", + "training_loop.run(5_000)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SftZvkIl_4ko" + }, + "source": [ + "## More Steps (optional)\n", + "\n", + "Since we specified `n_steps_per_permanent_checkpoint` in `training.TrainTask`, a checkpoint is saved in `output_dir` after the specified number of steps. So, if you face a runtime disconnection or want to train the model for more steps to improve the result, you can load the last saved checkpoint using `training_loop.load_checkpoint`.\n", + "\n", + "This is optional; you could instead use `model.init_from_file` as in the Part (2): Model cells, changing 'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz' to 'output_dir/ + last_checkpoint'." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = '~/Transformer_FR_pretrained_336'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LBq6EZy6_4Lo" + }, + "outputs": [], + "source": [ + "# This loads a checkpoint:\n", + "training_loop.load_checkpoint(directory=output_dir, filename=\"model.pkl.gz\")\n", + "# Continue training:\n", + "training_loop.run(5_000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tensorboard (optional)\n", + "The Trax training loop optimizes training and creates TensorBoard logs and model checkpoints for you. You can visualize them using the following:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sfCy3oZAuron" + }, + "outputs": [], + "source": [ + "# Load the TensorBoard notebook extension\n", + "%load_ext tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd {os.path.expanduser(output_dir)}\n", + "%ls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kdcy6NP1uxOI" + }, + "outputs": [], + "source": [ + "%tensorboard --logdir {os.path.expanduser(output_dir)}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bIy5wc90m0ZW" + }, + "source": [ + "If it is not loading properly and, for example, your `output_dir` is:\n", + "\n", + "```python\n", + "output_dir = '/content/drive/MyDrive/Colab Notebooks/Transformer_FR_pretrained_336'\n", + "```\n", + "add:\n", + "```\n", + "%cd '/content/drive/MyDrive/Colab Notebooks/'\n", + "```\n", + "before:\n", + "```\n", + "%tensorboard --logdir output_dir\n", + "```\n", + "and change it to:\n", + "```\n", + "%tensorboard --logdir Transformer_FR_pretrained_336\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0WXTnjBJ9dov" + }, + "source": [ + "# Part (4): Testing\n", + "\n", + "We will now be using the model you just trained to translate English sentences to French. We will implement this with two functions: the first allows you to identify the next symbol (i.e. output token). 
The second one takes care of combining the entire translated string.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5g9O_h-R9do0" + }, + "source": [ + "## 4.1 Decoding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xH-imC6U-jBn" + }, + "outputs": [], + "source": [ + "# Setup helper functions for tokenizing and detokenizing sentences\n", + "def tokenize(input_str, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Encodes a string to an array of integers\n", + " Args:\n", + " input_str (str): human-readable string to encode\n", + " vocab_file (str): filename of the vocabulary text file\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " numpy.ndarray: tokenized version of the input string\n", + " \"\"\"\n", + " # Set the encoding of the \"end of sentence\" as 1\n", + " EOS = 1\n", + " # Use the trax.data.tokenize method. It takes streams and returns streams,\n", + " # we get around it by making a 1-element stream with `iter`.\n", + " inputs = next(encoder.tokenize(iter([input_str]),\n", + " vocab_file=vocab_file, vocab_dir=vocab_dir))\n", + " # Mark the end of the sentence with EOS\n", + " inputs = list(inputs) + [EOS]\n", + " # Adding the batch dimension to the front of the shape\n", + " batch_inputs = np.reshape(np.array(inputs), [1, -1])\n", + " return batch_inputs\n", + "\n", + "\n", + "def detokenize(integers, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Decodes an array of integers to a human readable string\n", + " Args:\n", + " integers (numpy.ndarray): array of integers to decode\n", + " vocab_file (str): filename of the vocabulary text file\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " str: the decoded sentence.\n", + " \"\"\"\n", + " # Remove the dimensions of size 1\n", + " integers = list(np.squeeze(integers))\n", + " # Set the encoding of the \"end of sentence\" as 1\n", + " EOS = 1\n", + " # Remove the EOS to decode only the original tokens\n", + " if EOS in integers:\n", + " integers = integers[:integers.index(EOS)]\n", + " return encoder.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R3ud8xnDGL-5" + }, + "source": [ + "There are several ways to get the next token when translating a sentence. For instance, we can just get the most probable token at each step (i.e. greedy decoding) or get a sample from a distribution. We can generalize the implementation of these two approaches by using the `tl.logsoftmax_sample()` method." 
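, + "\n", + "As a rough sketch of what temperature does (the internals of `tl.logsoftmax_sample()` may differ), sampling can be thought of as rescaling the log probabilities before drawing from the resulting distribution; the function and variable names below are ours, for illustration only:\n", + "\n", + "```python\n", + "import numpy as np\n", + "\n", + "rng = np.random.default_rng(0)\n", + "\n", + "def sample_with_temperature(log_probs, temperature):\n", + "    # temperature 0 degenerates to argmax; higher values flatten the distribution\n", + "    if temperature == 0.0:\n", + "        return int(np.argmax(log_probs))\n", + "    scaled = log_probs / temperature\n", + "    probs = np.exp(scaled - scaled.max())\n", + "    probs /= probs.sum()\n", + "    return int(rng.choice(len(log_probs), p=probs))\n", + "\n", + "log_probs = np.log(np.array([0.05, 0.15, 0.8]))\n", + "print(sample_with_temperature(log_probs, 0.0))  # always 2\n", + "print(sample_with_temperature(log_probs, 1.0))  # usually 2, occasionally 0 or 1\n", + "```"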
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cD8F14b49do1" + }, + "outputs": [], + "source": [ + "def next_symbol(model, input_tokens, cur_output_tokens, temperature):\n", + " \"\"\"Returns the index of the next token.\n", + " Args:\n", + " model: the NMT model.\n", + " input_tokens (np.ndarray 1 x n_tokens): tokenized representation of the input sentence\n", + " cur_output_tokens (list): tokenized representation of previously translated words\n", + " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", + " 0.0: same as argmax, always pick the most probable token\n", + " 1.0: sampling from the distribution (can sometimes say random things)\n", + " Returns:\n", + " int: index of the next token in the translated sentence\n", + " float: log probability of the next symbol\n", + " \"\"\"\n", + " # set the length of the current output tokens\n", + " token_length = len(cur_output_tokens)\n", + " # calculate next power of 2 for padding length\n", + " padded_length = np.power(2, int(np.ceil(np.log2(token_length + 1))))\n", + " # pad cur_output_tokens up to the padded_length\n", + " padded = cur_output_tokens + [0] * (padded_length - token_length)\n", + " # model expects the output to have an axis for the batch size in front so\n", + " # convert `padded` list to a numpy array with shape (x, ) where the\n", + " # x position is the batch axis.\n", + " padded_with_batch = np.expand_dims(padded, axis=0)\n", + " # the model prediction.\n", + " output, _ = model((input_tokens, padded_with_batch))\n", + " # get log probabilities from the last token output\n", + " log_probs = output[0, token_length, :]\n", + " # get the next symbol by getting a logsoftmax sample\n", + " symbol = int(tl.logsoftmax_sample(log_probs, temperature))\n", + " return symbol, float(log_probs[symbol])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R0KlObsa9dpE" + }, + "source": [ + "The `sampling_decode()` function will call the `next_symbol()` function above several times until the next output is the end-of-sentence token (i.e. `EOS`). 
It takes in an input string and returns the translated version of that string.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bwIB-MQl9dpF" + }, + "outputs": [], + "source": [ + "def sampling_decode(input_sentence, model=None, temperature=0.0, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Returns the translated sentence.\n", + " Args:\n", + " input_sentence (str): sentence to translate.\n", + " model: the NMT model.\n", + " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", + " 0.0: same as argmax, always pick the most probable token\n", + " 1.0: sampling from the distribution (can sometimes say random things)\n", + " vocab_file (str): filename of the vocabulary\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " tuple: (list, str, float)\n", + " list of int: tokenized version of the translated sentence\n", + " float: log probability of the translated sentence\n", + " str: the translated sentence\n", + " \"\"\"\n", + " # encode the input sentence\n", + " input_tokens = tokenize(input_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", + " # initialize the list of output tokens\n", + " cur_output_tokens = []\n", + " # initialize an integer that represents the current output index\n", + " cur_output = 0\n", + " # Set the encoding of the \"end of sentence\" as 1\n", + " EOS = 1\n", + " # check that the current output is not the end of sentence token\n", + " while cur_output != EOS:\n", + " # update the current output token by getting the index of the next word\n", + " cur_output, log_prob = next_symbol(model, input_tokens, cur_output_tokens, temperature)\n", + " # append the current output token to the list of output tokens\n", + " cur_output_tokens.append(cur_output)\n", + " # detokenize the output tokens\n", + " sentence = detokenize(cur_output_tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", + " return cur_output_tokens, log_prob, sentence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "diQYEDgF9dpG", + "outputId": "40d8f201-3aa6-42fc-ee96-fa09f2c6959f" + }, + "outputs": [], + "source": [ + "# Test the function above. Try varying the temperature setting with values from 0 to 1.\n", + "# Run it several times with each setting and see how often the output changes.\n", + "sampling_decode(\"Hello.\", model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uRbgTBWt9dpO" + }, + "source": [ + "We have set a default value of `0` to the temperature setting in our implementation of `sampling_decode()` above. As you may have noticed in the `logsoftmax_sample()` method, this setting will ultimately result in greedy decoding. This algorithm generates the translation by getting the most probable word at each step. It gets the argmax of the output array of your model and then returns that index. See the testing function and sample inputs below. You'll notice that the output will remain the same each time you run it." 
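, + "\n", + "For instance, if the model's next-token log probabilities were `[-2.3, -0.1, -4.6]`, greedy decoding (temperature `0.0`) would always pick index `1`, which is why repeated runs produce identical translations."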
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "g1txjY-x9dpP" + }, + "outputs": [], + "source": [ + "def greedy_decode_test(sentence, model=None, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Prints the input and output of our NMT model using greedy decode\n", + " Args:\n", + " sentence (str): a custom string.\n", + " model: the NMT model.\n", + " vocab_file (str): filename of the vocabulary\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " str: the translated sentence\n", + " \"\"\"\n", + " _, _, translated_sentence = sampling_decode(sentence, model, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", + " print(\"English: \", sentence)\n", + " print(\"French: \", translated_sentence)\n", + " return translated_sentence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "i7XKz-9I9dpS", + "outputId": "89ebe68b-1522-40e7-bc40-ee525c509235" + }, + "outputs": [], + "source": [ + "# put a custom string here\n", + "your_sentence = 'I love languages.'\n", + "greedy_decode_test(your_sentence, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "M8UlR7LS9dpU", + "outputId": "40223d3f-77b5-43c5-cdef-42192673211f" + }, + "outputs": [], + "source": [ + "greedy_decode_test('You are almost done with the assignment!', model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sf80_9T29dpX" + }, + "source": [ + "## 4.2 Minimum Bayes-Risk Decoding\n", + "\n", + "Getting the most probable token at each step may not necessarily produce the best results. Another approach is to do Minimum Bayes Risk Decoding or MBR. The general steps to implement this are:\n", + "\n", + "1. take several random samples\n", + "2. score each sample against all other samples\n", + "3. select the one with the highest score" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hp_qzJ8u9dpX" + }, + "source": [ + "\n", + "### 4.2.1 Generating samples\n", + "\n", + "First, let's build a function to generate several samples. You can use the `sampling_decode()` function you developed earlier to do this easily. We want to record the token list and log probability for each sample as these will be needed in the next step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4iSRPOrI9dpX" + }, + "outputs": [], + "source": [ + "def generate_samples(sentence, n_samples, model=None, temperature=0.6, vocab_file=None, vocab_dir=None):\n", + " \"\"\"Generates samples using sampling_decode()\n", + " Args:\n", + " sentence (str): sentence to translate.\n", + " n_samples (int): number of samples to generate\n", + " model: the NMT model.\n", + " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", + " 0.0: same as argmax, always pick the most probable token\n", + " 1.0: sampling from the distribution (can sometimes say random things)\n", + " vocab_file (str): filename of the vocabulary\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " tuple: (list, list)\n", + " list of lists: token list per sample\n", + " list of floats: log probability per sample\n", + " \"\"\"\n", + " # define lists to contain samples and probabilities\n", + " samples, log_probs = [], []\n", + " # run a for loop to generate n samples\n", + " for _ in range(n_samples):\n", + " # get a sample using the sampling_decode() function\n", + " sample, logp, _ = sampling_decode(sentence, model, temperature, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", + " # append the token list to the samples list\n", + " samples.append(sample)\n", + " # append the log probability to the log_probs list\n", + " log_probs.append(logp)\n", + " return samples, log_probs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LlYC8y8H9dpZ", + "outputId": "241edf8f-3921-46be-930b-00058cd6efb5" + }, + "outputs": [], + "source": [ + "# generate 4 samples with the default temperature (0.6)\n", + "generate_samples('I love languages.', 4, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 36 + }, + "id": "VR6FLNdcILll", + "outputId": "b228a403-306f-4c74-b5ea-8c0f129906a4" + }, + "outputs": [], + "source": [ + "detokenize([769, 31, 31720, 21, 15267, 3, 1], VOCAB_FILE, VOCAB_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HonzLcOP9dpb" + }, + "source": [ + "### 4.2.2 Comparing overlaps\n", + "\n", + "Let us now build our functions to compare a sample against another. There are several metrics available and you can try experimenting with any one of these. We will be calculating scores for unigram overlaps. One of the more simple metrics is the [Jaccard similarity](https://en.wikipedia.org/wiki/Jaccard_index) which gets the intersection over union of two sets." 
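, + "\n", + "For example, the token lists `[1, 2, 3]` and `[2, 3, 4]` share the tokens `{2, 3}` out of the union `{1, 2, 3, 4}`, giving a Jaccard similarity of 2/4 = 0.5."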
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IB7ipzoZ9dpc" + }, + "outputs": [], + "source": [ + "def jaccard_similarity(candidate, reference):\n", + " \"\"\"Returns the Jaccard similarity between two token lists\n", + " Args:\n", + " candidate (list of int): tokenized version of the candidate translation\n", + " reference (list of int): tokenized version of the reference translation\n", + " Returns:\n", + " float: overlap between the two token lists\n", + " \"\"\"\n", + " # convert the lists to a set to get the unique tokens\n", + " can_unigram_set, ref_unigram_set = set(candidate), set(reference)\n", + " # get the set of tokens common to both candidate and reference\n", + " joint_elems = can_unigram_set.intersection(ref_unigram_set)\n", + " # get the set of all tokens found in either candidate or reference\n", + " all_elems = can_unigram_set.union(ref_unigram_set)\n", + " # divide the number of joint elements by the number of all elements\n", + " overlap = len(joint_elems) / len(all_elems)\n", + " return overlap" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CZRis5hp9dph" + }, + "source": [ + "One of the more commonly used metrics in machine translation is the ROUGE score. For unigrams, this is called ROUGE-1 and you can output the scores for both precision and recall when comparing two samples. To get the final score, you will want to compute the F1-score as given by:\n", + "\n", + "$$score = 2* \\frac{(precision * recall)}{(precision + recall)}$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WRhPTgv09dpi" + }, + "outputs": [], + "source": [ + "# for making a frequency table easily\n", + "from collections import Counter\n", + "\n", + "\n", + "def rouge1_similarity(system, reference):\n", + " \"\"\"Returns the ROUGE-1 score between two token lists\n", + " Args:\n", + " system (list of int): tokenized version of the system translation\n", + " reference (list of int): tokenized version of the reference translation\n", + " Returns:\n", + " float: overlap between the two token lists\n", + " \"\"\"\n", + " # make a frequency table of the system tokens\n", + " sys_counter = Counter(system)\n", + " # make a frequency table of the reference tokens\n", + " ref_counter = Counter(reference)\n", + " # initialize overlap to 0\n", + " overlap = 0\n", + " # run a for loop over the sys_counter object\n", + " for token in sys_counter:\n", + " # lookup the value of the token in the sys_counter dictionary\n", + " token_count_sys = sys_counter.get(token, 0)\n", + " # lookup the value of the token in the ref_counter dictionary\n", + " token_count_ref = ref_counter.get(token, 0)\n", + " # update the overlap by getting the smaller number between the two token counts above\n", + " overlap += min(token_count_sys, token_count_ref)\n", + " # get the precision (i.e. number of overlapping tokens / number of system tokens)\n", + " precision = overlap / sum(sys_counter.values())\n", + " # get the recall (i.e. 
number of overlapping tokens / number of reference tokens)\n", + " recall = overlap / sum(ref_counter.values())\n", + " if precision + recall != 0:\n", + " # compute the f1-score\n", + " rouge1_score = 2 * ((precision * recall) / (precision + recall))\n", + " else:\n", + " rouge1_score = 0\n", + " return rouge1_score" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qn3wLqSb9dpp" + }, + "source": [ + "### 4.2.3 Overall score\n", + "\n", + "We will now build a function to generate the overall score for a particular sample. As mentioned earlier, we need to compare each sample with all other samples. For instance, if we generated 30 sentences, we will need to compare sentence 1 to sentences 2 to 30. Then, we compare sentence 2 to sentences 1 and 3 to 30, and so forth. At each step, we get the average score of all comparisons to get the overall score for a particular sample. To illustrate, these will be the steps to generate the scores of a 4-sample list.\n", + "\n", + "1. Get similarity score between sample 1 and sample 2\n", + "2. Get similarity score between sample 1 and sample 3\n", + "3. Get similarity score between sample 1 and sample 4\n", + "4. Get average score of the first 3 steps. This will be the overall score of sample 1.\n", + "5. Iterate and repeat until samples 1 to 4 have overall scores.\n", + "\n", + "We will be storing the results in a dictionary for easy lookups." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Umtj0NLX9dpp" + }, + "outputs": [], + "source": [ + "def average_overlap(similarity_fn, samples, *ignore_params):\n", + " \"\"\"Returns the arithmetic mean of each candidate sentence in the samples\n", + " Args:\n", + " similarity_fn (function): similarity function used to compute the overlap\n", + " samples (list of lists): tokenized version of the translated sentences\n", + " *ignore_params: additional parameters will be ignored\n", + " Returns:\n", + " dict: scores of each sample\n", + " key: index of the sample\n", + " value: score of the sample\n", + " \"\"\"\n", + " # initialize dictionary\n", + " scores = {}\n", + " # run a for loop for each sample\n", + " for index_candidate, candidate in enumerate(samples):\n", + " # initialize overlap to 0.0\n", + " overlap = 0.0\n", + " # run a for loop for each sample\n", + " for index_sample, sample in enumerate(samples):\n", + " # skip if the candidate index is the same as the sample index\n", + " if index_candidate == index_sample:\n", + " continue\n", + " # get the overlap between candidate and sample using the similarity function\n", + " sample_overlap = similarity_fn(candidate, sample)\n", + " # add the sample overlap to the total overlap\n", + " overlap += sample_overlap\n", + " # get the score for the candidate by computing the average\n", + " score = overlap / index_sample\n", + " # save the score in the dictionary. use index as the key.\n", + " scores[index_candidate] = score\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-w7LL7lm9dpx" + }, + "source": [ + "It is also common to see the weighted mean being used to calculate the overall score instead of just the arithmetic mean." 
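, + "\n", + "Concretely, writing $p_j$ for the linear-scale probability of sample $j$ (i.e. the exponential of its log probability), the weighted score of candidate $i$ computed in the next cell is:\n", + "\n", + "$$score_i = \frac{\sum_{j \neq i} p_j \cdot similarity(i, j)}{\sum_{j \neq i} p_j}$$\n", + "\n", + "so samples that the model itself considers likely contribute more to each candidate's score."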
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o70TS8PG9dpy" + }, + "outputs": [], + "source": [ + "def weighted_avg_overlap(similarity_fn, samples, log_probs):\n", + " \"\"\"Returns the weighted mean of each candidate sentence in the samples\n", + " Args:\n", + " samples (list of lists): tokenized version of the translated sentences\n", + " log_probs (list of float): log probability of the translated sentences\n", + " Returns:\n", + " dict: scores of each sample\n", + " key: index of the sample\n", + " value: score of the sample\n", + " \"\"\"\n", + " # initialize dictionary\n", + " scores = {}\n", + " # run a for loop for each sample\n", + " for index_candidate, candidate in enumerate(samples):\n", + " # initialize overlap and weighted sum\n", + " overlap, weight_sum = 0.0, 0.0\n", + " # run a for loop for each sample\n", + " for index_sample, (sample, logp) in enumerate(zip(samples, log_probs)):\n", + " # skip if the candidate index is the same as the sample index\n", + " if index_candidate == index_sample:\n", + " continue\n", + " # convert log probability to linear scale\n", + " sample_p = float(np.exp(logp))\n", + " # update the weighted sum\n", + " weight_sum += sample_p\n", + " # get the unigram overlap between candidate and sample\n", + " sample_overlap = similarity_fn(candidate, sample)\n", + " # update the overlap\n", + " overlap += sample_p * sample_overlap\n", + " # get the score for the candidate\n", + " score = overlap / weight_sum\n", + " # save the score in the dictionary. use index as the key.\n", + " scores[index_candidate] = score\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l5jgBrPu9dp4" + }, + "source": [ + "### 4.2.4 Putting it all together\n", + "\n", + "We will now put everything together and develop the `mbr_decode()` function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "S58nPXgY9dp5" + }, + "outputs": [], + "source": [ + "def mbr_decode(sentence, n_samples=4, score_fn=weighted_avg_overlap, similarity_fn=rouge1_similarity, model=model,\n", + " temperature=0.6, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR):\n", + " \"\"\"Returns the translated sentence using Minimum Bayes Risk decoding\n", + " Args:\n", + " sentence (str): sentence to translate.\n", + " n_samples (int): number of samples to generate\n", + " score_fn (function): function that generates the score for each sample\n", + " similarity_fn (function): function used to compute the overlap between a\n", + " pair of samples\n", + " model: the NMT model.\n", + " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", + " 0.0: same as argmax, always pick the most probable token\n", + " 1.0: sampling from the distribution (can sometimes say random things)\n", + " vocab_file (str): filename of the vocabulary\n", + " vocab_dir (str): path to the vocabulary file\n", + " Returns:\n", + " str: the translated sentence\n", + " \"\"\"\n", + " # generate samples\n", + " samples, log_probs = generate_samples(sentence, n_samples,\n", + " model, temperature,\n", + " vocab_file, vocab_dir)\n", + " # use the scoring function to get a dictionary of scores\n", + " scores = score_fn(similarity_fn, samples, log_probs)\n", + " # find the key with the highest score\n", + " max_index = max(scores, key=scores.get)\n", + " # detokenize the token list associated with the max_index\n", + " translated_sentence = detokenize(samples[max_index], vocab_file, vocab_dir)\n", + " return (translated_sentence, max_index, scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ab1LHo-59dp8" + }, + "outputs": [], + "source": [ + "# put a custom string here\n", + "your_sentence = 'She speaks English, French and German.'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BhgGWv7c9dp_", + "outputId": "ae5e00cb-0935-45f7-9fbe-c96f0e12dfc1" + }, + "outputs": [], + "source": [ + "mbr_decode(your_sentence)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 36 + }, + "id": "QqyR1Ym6A_Ah", + "outputId": "eb7db397-28a6-41c7-c44e-0c314574d147" + }, + "outputs": [], + "source": [ + "mbr_decode('You have completed the tutorial.')[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RPbqDGUY8Vp_" + }, + "source": [ + "# **Resources**\n", + "\n", + "- [Natural Language Processing Specialization](https://www.coursera.org/specializations/natural-language-processing)\n", + "\n", + "- [Trax documentation](https://trax-ml.readthedocs.io/en/latest/index.html)\n", + "\n", + "- [Trax community](https://gitter.im/trax-ml/community)" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "collapsed_sections": [ + "_WpKodqa9dmJ", + "VY6_SnLM9dms", + "WD0ZqedYIpr3", + "r_8UOdZ_9dnO", + "v5IDVjXl9dnU", + "4U_V6nNQ_37u" + ], + "include_colab_link": true, + "machine_shape": "hm", + "name": "NMT with Transformers/Reformers using Trax.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git 
a/resources/examples/ipynb/Example-7-1-Wide-Residual-Networks.ipynb b/resources/examples/ipynb/Example-7-1-Wide-Residual-Networks.ipynb new file mode 100644 index 000000000..28f4e4468 --- /dev/null +++ b/resources/examples/ipynb/Example-7-1-Wide-Residual-Networks.ipynb @@ -0,0 +1,772 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "A00Q5PP0j8ZH" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zTfrmAx5kBwR" + }, + "source": [ + "# Author\n", + "\n", + "SauravMaheshkar- [@MaheshkarSaurav](https://twitter.com/MaheshkarSaurav)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_kg_hide-input": true, + "_kg_hide-output": true, + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "id": "pgp28DB-j6ev", + "papermill": { + "duration": 29.585875, + "end_time": "2020-12-01T01:33:26.950713", + "exception": false, + "start_time": "2020-12-01T01:32:57.364838", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install --upgrade trax" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uUPujeMDj6ew", + "papermill": { + "duration": 0.035652, + "end_time": "2020-12-01T01:33:27.018778", + "exception": false, + "start_time": "2020-12-01T01:33:26.983126", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x8qmALfWj6ew", + "papermill": { + "duration": 0.035743, + "end_time": "2020-12-01T01:33:27.097892", + "exception": false, + "start_time": "2020-12-01T01:33:27.062149", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Prior to the introduction of [Wide Residual Networks](https://arxiv.org/pdf/1605.07146.pdf) (WRNs) by Sergey Zagoruyko and Nikos Komodakis, deep residual networks were shown to have a fractional increase in performance but at the cost of **doubling** the number of layers. This led to the problem of diminishing feature reuse and overall made the models slow to train. WRNs showed that having a wider residual network leads to better performance and increased the then SOTA results on CIFAR, SVHN and COCO.\n", + "\n", + "In this notebook we run through a simple demonstration of training a WideResnet on the `cifar10` dataset using the [Trax](https://github.com/google/trax) framework. Trax is an end-to-end library for deep learning that focuses on **clear code and speed**. It is actively used and maintained in the *Google Brain team*." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SAGHOSLHj6ew", + "papermill": { + "duration": 0.031704, + "end_time": "2020-12-01T01:33:27.164144", + "exception": false, + "start_time": "2020-12-01T01:33:27.132440", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Issues with Traditional Residual Networks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l644xzpolDH6" + }, + "source": [ + "![Screenshot 2020-12-01 at 10.04.11 AM.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAACNwAAAM4CAYAAAAdpG3NAAAMbGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkJCEEghFSuhNEOlFSggtgoBUwUZIAgklxoSgYkcXFVy7iGJFV0VcdHUFZFERda2LYnctiwUVZV0sKIrKmxTQdV/53vm+ufPfM2f+U+7MvXcA0OnjSaX5qC4ABZJCWWJUGGtcegaL9AhoAyOgA2gA4fHlUnZCQiyAMtj/Xd5eB4iyv+Kq5Prn+H8VfYFQzgcAmQBxlkDOL4C4BQB8I18qKwSAqNTbTCuUKvE8iA1kMECI1yhxjhrvVuIsNW5W2SQnciC+BIAWlceT5QBAvwP1rCJ+DuShf4TYXSIQSwDQGQ5xMF/EE0CsjH14QcEUJa6E2BHaSyGG8QC/rK84c/7GnzXEz+PlDGF1XirRChfLpfm8Gf9naf63FOQrBn3Yw0YVyaITlfnDGt7MmxKjxFSIuyVZcfHKWkPcJxao6w4AShEpolPU9qgZX86B9QNMiN0FvPAYiM0gjpTkx8Vq9FnZ4kguxHC1oNPFhdxkiI0hXiyURyRpbLbKpiRqfKH12TIOW6M/w5Op/Cp93VPkpbA1/K9EQq6GH6MXi5LTIKZAbFskTo2DmA6xmzwvKUZjM6pYxIkbtJEpEpXx20KcKJREhan5saJsWWSixr6sQD6YL7ZVJObGafCBQlFytLo+2Ek+TxU/zAW7JJSwUwZ5hPJxsYO5CIThEercsadCSUqShqdPWhiWqJ6LU6T5CRp73FqYH6XUW0PsJS9K0szFUwvh4lTz49nSwoRkdZx4cS5vdII6HnwFiAUcEA5YQAFbFpgCcoG4rbuhG96pRyIBD8hADhACV41mcEaaakQCr0mgGPwJkRDIh+aFqUaFoAjqPw1p1VdXkK0aLVLNyAOPIS4AMSAf3itUsyRD3lLBI6gR/8M7DzY+jDcfNuX4v9cPar9o2FATq9EoBj2ydAYtiRHEcGI0MZLohJviwXggHguvobB54H64/2AeX+wJjwnthAeEa4QOwq3J4hLZN1GOAR2QP1JTi6yva4HbQ05vPAwPguyQGWfipsAV94J+2HgI9OwNtRxN3MqqsL7h/lsGXz0NjR3ZnYySjcihZMdvZ9Kd6d5DLMpaf10fdaxZQ/XmDI1865/zVfUFsI/51hJbjB3ETmPHsbNYM9YAWNgxrBG7gB1R4qHV9Ui1uga9JariyYM84n/442l8Kispd69173L/qB4rFE4vVG48zhTpDJk4R1TIYsOvg5DFlfDdhrM83D3cAVB+a9Svr9dM1TcEYZ77oispAiDIaWBgoPmLLtYfgJ8b4fbv+qJzhO8+uiUAZxbzFbIitQ5XXgjwLaEDd5oJsAA2wBHm4wF8QCAIBRFgNIgHySAdTIJVFsF1LgPTwCwwH5SCcrACrAUbwBawHewGP4IDoAE0g+PgV3AeXALXwG24ejrBc9AD3oJ+BEFICA1hICaIJWKHuCAeiB8SjEQgsUgiko5kIjmIBFEgs5AFSDmyCtmAbENqkJ+Qw8hx5CzSjtxC7iNdyCvkA4qhVNQANUft0RGoH8pGY9BkdCKag05Fi9GF6DK0Eq1G96L16HH0PHoN7UCfo70YwLQxJmaFuWJ+GAeLxzKwbEyGzcHKsAqsGqvDmuBzvoJ1YN3Ye5yIM3AW7gpXcDSegvPxqfgcfCm+Ad+N1+Mn8Sv4fbwH/0ygEcwILoQAApcwjpBDmEYoJVQQdhIOEU7BvdRJeEskEplEB6Iv3IvpxFziTOJS4ibiPmILsZ34kNhLIpFMSC6kIFI8iUcqJJWS1pP2ko6RLpM6SX1a2lqWWh5akVoZWhKtEq0KrT1aR7Uuaz3R6ifrku3IAeR4soA8g7ycvIPcRL5I7iT3U/QoDpQgSjIllzKfUkmpo5yi3KG81tbWttb21x6rLdaep12pvV/7jPZ97fdUfaozlUOdQFVQl1F3UVuot6ivaTSaPS2UlkErpC2j1dBO0O7R+ugMuhudSxfQ59Kr6PX0y/QXOmQdOx22ziSdYp0KnYM6F3W6dcm69rocXZ7uHN0q3cO6N3R79Rh6I/Xi9Qr0lurt0Tur91SfpG+vH6Ev0F+ov13/hP5DBsawYXAYfMYCxg7GKUanAdHAwYBrkGtQbvCjQZtBj6G+oZdhquF0wyrDI4YdTIxpz+Qy85nLmQeY15kfjMyN2EZCoyVGdUaXjd4ZDzMONRYalxnvM75m/MGEZRJhkmey0qTB5K4pbupsOtZ0mulm01Om3cMMhgUO4w8rG3Zg2O9mqJmzWaLZTLPtZhfMes0tzKPMpebrzU+Yd1swLUItci3WWBy16LJkWAZbii3XWB6zfMYyZLFZ+axK1klWj5WZVbSVwmqbVZtVv7WDdYp1ifU+67s2FBs/m2ybNTatNj22lrZjbGfZ1tr+bke287MT2a2zO233zt7BPs1+kX2D/VMHYweuQ7FDrcMdR5pjiONUx2rHq05EJz+nPKdNTpecUWdvZ5FzlfNFF9TFx0XsssmlfThhuP9wyfDq4Tdcqa5s1yLXWtf7bky3WLcStwa3FyNsR2SMWDni9IjP7t7u+e473G+P1B85emTJyKaRrzycPfgeVR5XPWmekZ5zPRs9X3q5eAm9Nnvd9GZ4j/Fe5N3q/cnH10fmU+fT5Wvrm+m70feGn4Ffgt9SvzP+BP8w/7n+zf7vA3wCCgMOBPwV6BqYF7gn8Okoh1HCUTtGPQyyDuIFbQvqCGYFZwZvDe4IsQrhhVSHPAi1CRWE7gx9wnZi57L3sl+EuYfJwg6FveMEcGZzWsKx8KjwsvC2CP2IlIgNEfcirSNzImsje6K8o2ZGtUQTomOiV0bf4Jpz+dwabs9o39GzR5+MocYkxWyIeRDrHCuLbRqDjhk9ZvWYO3F2cZK4hngQz41fHX83wSFhasIvY4ljE8ZWjX2cODJxVuLpJEbS5KQ9SW+Tw5KXJ99OcUxRpLSm6qROSK1JfZcWnrYqrWPciHGzx51PN00XpzdmkDJSM3Zm9I6PGL
92fOcE7wmlE65PdJg4feLZSaaT8icdmawzmTf5YCYhMy1zT+ZHXjyvmtebxc3amNXD5/DX8Z8LQgVrBF3CIOEq4ZPsoOxV2U9zgnJW53SJQkQVom4xR7xB/DI3OndL7ru8+LxdeQP5afn7CrQKMgsOS/QleZKTUyymTJ/SLnWRlko7pgZMXTu1RxYj2ylH5BPljYUG8Kf+gsJR8Z3iflFwUVVR37TUaQen602XTL8ww3nGkhlPiiOLf5iJz+TPbJ1lNWv+rPuz2bO3zUHmZM1pnWszd+HcznlR83bPp8zPm/9biXvJqpI3C9IWNC00Xzhv4cPvor6rLaWXykpvLApctGUxvli8uG2J55L1Sz6XCcrOlbuXV5R/XMpfeu77kd9Xfj+wLHtZ23Kf5ZtXEFdIVlxfGbJy9yq9VcWrHq4es7p+DWtN2Zo3ayevPVvhVbFlHWWdYl1HZWxl43rb9SvWf9wg2nCtKqxq30azjUs2vtsk2HR5c+jmui3mW8q3fNgq3npzW9S2+mr76ortxO1F2x/vSN1x+ge/H2p2mu4s3/lpl2RXx+7E3SdrfGtq9pjtWV6L1ipqu/ZO2Hvpx/AfG+tc67btY+4r3w/2K/Y/+ynzp+sHYg60HvQ7WPez3c8bDzEOldUj9TPqexpEDR2N6Y3th0cfbm0KbDr0i9svu5qtmquOGB5ZfpRydOHRgWPFx3pbpC3dx3OOP2yd3Hr7xLgTV0+OPdl2KubUmV8jfz1xmn362JmgM81nA84ePud3ruG8z/n6C94XDv3m/duhNp+2+ou+Fxsv+V9qah/VfvRyyOXjV8Kv/HqVe/X8tbhr7ddTrt+8MeFGx03Bzae38m+9/L3o9/7b8+4Q7pTd1b1bcc/sXvUfTn/s6/DpOHI//P6FB0kPbj/kP3z+SP7oY+fCx7THFU8sn9Q89Xja3BXZdenZ+Gedz6XP+7tL/9T7c+MLxxc//xX614WecT2dL2UvB14tfW3yetcbrzetvQm9994WvO1/V9Zn0rf7vd/70x/SPjzpn/aR9LHyk9Onps8xn+8MFAwMSHkynupXAIMNzc4G4NUuAGjpADDguY0yXn0WVAmiPr+qEPhPWH1eVIkPAHWwU/7Gc1oA2A+bPWy0UACUv/DJoQD19BxqGpFne3qouajwJEToGxh4bQ4AqQmAT7KBgf5NAwOfdsBgbwHQMlV9BlUKEZ4Ztqo4LjOLBv0Pifp8+lWO3/ZAGYEX+Lb/F1J2jnCfSe2gAAAAimVYSWZNTQAqAAAACAAEARoABQAAAAEAAAA+ARsABQAAAAEAAABGASgAAwAAAAEAAgAAh2kABAAAAAEAAABOAAAAAAAAAJAAAAABAAAAkAAAAAEAA5KGAAcAAAASAAAAeKACAAQAAAABAAAI3KADAAQAAAABAAADOAAAAABBU0NJSQAAAFNjcmVlbnNob3QNAoBcAAAACXBIWXMAABYlAAAWJQFJUiTwAAAB12lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNi4wLjAiPgogICA8cmRmOlJERiB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiPgogICAgICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgICAgICAgICB4bWxuczpleGlmPSJodHRwOi8vbnMuYWRvYmUuY29tL2V4aWYvMS4wLyI+CiAgICAgICAgIDxleGlmOlBpeGVsWURpbWVuc2lvbj44MjQ8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+MjI2ODwvZXhpZjpQaXhlbFhEaW1lbnNpb24+CiAgICAgICAgIDxleGlmOlVzZXJDb21tZW50PlNjcmVlbnNob3Q8L2V4aWY6VXNlckNvbW1lbnQ+CiAgICAgIDwvcmRmOkRlc2NyaXB0aW9uPgogICA8L3JkZjpSREY+CjwveDp4bXBtZXRhPgrMt/KNAAAAHGlET1QAAAACAAAAAAAAAZwAAAAoAAABnAAAAZwAAWVEmKqIzAAAQABJREFUeAHs3Qm8jdX++PGvecoQkUITMiVKdSkRXWOhulcayFSGhMqU+SIiJdItyphIooHMmpFfoRTJzUWmZEzm+f//rv//eX7nHGdz9j57WGvvz/N61X72M+3nea9V+3vW/q61Mpz7v4uwIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAQJoEMpBwkyYnDkIAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAwAiQcENFQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAgCAESboLA4lAEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBEi4oQ4ggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIBCEAAk3QWBxKAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACJNxQBxBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQCEKAhJsgsDgUAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAESbqgDCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggEIUDCTRBYHIoAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAk31AEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBIIQIOEmCCwORQABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEECAhBvqAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACQQiQcBMEFocigAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIEDCDXUAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAIEgBEi4CQKLQxFAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQIOGGOoAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAQBACJNwEgcWhCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAgiQcEMdQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAgCAESboLA4lAEEEAAAQQQQAABBBBAA
AEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBEi4oQ4ggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIBCEAAk3QWBxKAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACJNxQBxBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQCEKAhJsgsDgUAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAESbqgDCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggEIUDCTRBYHIoAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAk31AEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBIIQIOEmCCwORQABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEECAhBvqAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACQQiQcBMEFocigAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIEDCDXUAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAIEgBEi4CQKLQxFAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQIOGGOoAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAQBACJNwEgcWhCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAgiQcEMdQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAgCAESboLA4lAEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBEi4oQ4ggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIBCEAAk3QWBxKAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACJNxQBxBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQCEKAhJsgsDgUAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAESbqgDCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggEIUDCTRBYHIoAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAk31AEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBIIQIOEmCCwORQABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEECAhBvqAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACQQiQcBMEFocigAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIEDCDXUAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAIEgBBI+4eZ//ud/guDi0EgJ/O1vf4vUpbkuAgggkG4BvivSTRiWC/BdERZGLoIAAggggICzAsRkdhQdMZkd5cBdIIAAAgggkB4B4qr06IXvXOKq8FlyJQQQQAABBGIlQMINCTexqnvJPpfAMhkHbxBAwDIBGiHsKBC+K+woB+4CAQQQQACBWAkQk8VKPvnnEpMl9+AdAggggAACLgoQV9lRasRVdpQDd4EAAggggEB6BEi4IeEmPfUnbOcSWIaNkgshgEAEBGiEiABqCJfkuyIENE5BAAEEEEAgjgSIyewoTGIyO8qBu0AAAQQQQCA9AsRV6dEL37nEVeGz5EoIIIAAAgjESoCEGxJuYlX3kn0ugWUyDt4ggIBlAjRC2FEgfFfYUQ7cBQIIIIAAArESICaLlXzyzyUmS+7BOwQQQAABBFwUIK6yo9SIq+woB+4CAQQQQACB9AiQcEPCTXrqT9jOJbAMGyUXQgCBCAjQCBEB1BAuyXdFCGicggACCCCAQBwJEJPZUZjEZHaUA3eBAAIIIIBAegSIq9KjF75ziavCZ8mVEEAAAQQQiJUACTeWJ9wcPXpUNm3aJP/973/NP7revn17KVeu3AXrzM6dO2Xq1Kny888/S65cuaRly5ZSqVKlC54Ty50ElrHU57MRQOBiAq42Quh3yKFDh+Tyyy+/2COa/X/88YccOHBAChcuLPny5UvTOdE8iO+KaGrzWQgggAACCNgn4GpMZp9k+u6ImCx9fpyNAAIIIICADQIuxlWnT58W/d2jWLFikiFDhosyHjlyRLZu3Sq5c+eWokWLXvT4WBxAXBULdT4TAQQQQACB8AqQcGNhws3u3btl+PDhJsFGA8iUy6WXXipTpkyRyy67LOUu837BggUyZMgQOXnypL8/S5YsMnPmzDT/6OqfGKUVAssoQfMxCCAQkoDtjRDa4LBt2zbZuHGjn6Cp67///rtky5ZNxo4dK6VLlw747Nu3b5cRI0bI8uXLzTF6zqBBg6RatWoBz4nFDr4rYqHOZyKAAAIIIGCPgO0xmT1Skb0TYrLI+nJ1BBBAAAEEoiFgc1x17tw52bNnT7I2Lu2QvGXLFjl16pS0adNGWrVqFZBJfxcZN26cvPfee3LixAlz3P333y/du3dPU6JOwAtHYAdxVQRQuSQCCCCAAAJRFiDhxsKEm4MHD5ofPTNmzCg62sCKFStk9erVyarGvffeK3369Em2TQPRSZMmmR9Wk+34/2/GjBkjFStWTG1XzLcRWMa8CLgBBBC4gIDNjRB626+99pp8+OGHook3XkNC0sfREW60kSF79uxJN5v17777Tnr27CmHDx9Otu+aa66R6dOnJ9sW6zd8V8S6BPh8BBBAAAEEYitge0wWW53ofToxWfSs+SQEEEAAAQQiJWBzXLV27Vrp1KmTeXQdvTm15fXXX5ebb775vF379+83iTV6jZSLzghQvHjxlJtj+p64Kqb8fDgCCCCAAAJhESDhxsKEm9RKVn/0HDlypL9Lk3Fmz56dbJSbyZMnyxtvvGFGsbnuuutEg0qdSkSXK6+80iTj5MmTx7+GTSsEljaVBveCAAIpBWxuhEh6r5p4qaPVzJs3TyZOnJh0l3Ts2FEeffTRZNs+//xzk7x55syZZNv1jX6PTJs27bztsdzAd0Us9flsBBBAAAEEYi/gSkyWUkp7Yp89e9aMPJhyX2rv//zzT9H4TEe31b/9bVuIyWwrEe4HAQQQQACB4AVciau0c7Le66hRo2Tfvn3+g5YpU0YmTJiQbMQa7bz81FNPmVGg/QOTrLz77rty7bXXJtkS
+1XiqtiXAXeAAAIIIIBAegVIuHEk4UZ/RNVgcdWqVX6Z6/umTZua97q9Q4cOUqNGDenXr5/kyJHDjHTw008/iWaB68g2uXLl8s+1bYXA0rYS4X4QQCCpgCuNEEnvefz48fLWW2/5mzTxUqcW9H60Wb9+vRmCV6cnbNeunej3zLJly+SHH36QvHnzSteuXa0bFY3vCr84WUEAAQQQQCAhBWyPyTSe2rVrl5n+wJvqU1+3bt1q4irtRKPTPQda9O93HblwzZo15pAbbrhBXn75ZRObBTonFtuJyWKhzmcigAACCCAQXgHb46qUT6sxVfPmzU1SsrdPO5tp4o0ux48fN+1cOvWUTjd1/fXXy/fff2/aunQ06Pvuu09atGhhjrXpX8RVNpUG94IAAggggEBoAiTcOJJwo8X71VdfmeEQvaLWIRN16MQjR46YUQuuuuoqGTFihGTOnNk7xJlXAktnioobRSAhBVxrhNBC0umldH5qnfPaW7TnT9myZeXYsWPme0MTMV999VXTe9o7xuZXvitsLh3uDQEEEEAAgcgL2B6TPfLII7Jp06aAEDraoI46mNrywQcfmOSalCMP6g9Gbdq0Se2UmG0jJosZPR+MAAIIIIBA2ARsj6tSe1BNRH7//ff9XRp7eVNP6T6dEeCll16SW2+91T/G9hXiKttLiPtDAAEEEEDg4gIk3DiUcKNZ2rVq1RIdjlqXTJkyyeLFi2X06NHy5ZdfyjvvvCP58+e/eKlbeASBpYWFwi0hgIAv4GIjhN68Drerw+V6i/5Yoz/aaOPDwoULZcqUKVK4cGFvt/WvfFdYX0TcIAIIIIAAAhEVsD0m09Fl9e/0w4cPi45WM2nSJPnll198E92nPxLpyIPeoqPi6KiEmhid2qKj2urotjYtxGQ2lQb3ggACCCCAQGgCtsdVqT3V2rVr5fHHH/d3FS9eXKZOnSqrV6+WJ598Unr37i0NGjTw97uwQlzlQilxjwgggAACCFxYgIQbhxJutCg1cNQA0luaNGkiM2bMMCMUuJS57d2/90pg6UnwigACNgq42AihjnrfnTt39kkrVKhgekjrFITDhg2T6tWr+/tcWOG7woVS4h4RQAABBBCInIBrMZlOX6BTH2zZssVHeeyxx8zf9d4GHdnmxRdflHLlypkfiHbs2CFffPGFbN++3UyRMGTIEOsSpInJvNLjFQEEEEAAAXcFXIurVFpHAqxXr5789ddfPrz+NvL000+LTsU5YMAAyZAhg7/PhRXiKhdKiXtEAAEEEEDgwgIk3DiWcDN+/HjT+y1psTZq1Eh69uyZdFPQ6zpqzubNm2XDhg2i86FqdnjDhg2Dvk6oJxBYhirHeQggEA0BFxsh1EVHRrv77rv9+a0zZswohQoVksqVK8tzzz0XDbqwfgbfFWHl5GIIIIAAAgg4J+BiTLZy5cpkI9RcccUVokk2+mPQr7/+Ki1atJAaNWpI//79JUuWLH6Z6Mg3tv5gREzmFxMrCCCAAAIIOCvgYlyl2D169DCj/XvwRYsWNTMCTJs2TXTqdNcW4irXSoz7RQABBBBA4HwBEm4cS7j54YcfpF27dn5JZsuWTT7++GPJly+fvy2tK6dPn5b33ntPli9fbhJtdNhrXbSRr0+fPlKnTp20XirdxxFYppuQCyCAQAQFXG2EUJK2bdvKmjVrfJ28efOaqQzy5MnjbwtlRXsT6Y9EmqipPbB1yN4yZcqEcqk0n8N3RZqpOBABBBBAAIG4FHAxJtPEGe3MsmfPHr9MtCe2TivVunVrMwXV2LFjkyXb+AdaukJMZmnBcFsIIIAAAggEIeBiXKWPp3HUiBEjkj3p0KFD5a677kq2Ldg3seqQTFwVbElxPAIIIIAAAvYJkHDjWMKNBn5///vfRYem1kV7vC1ZsiTd2dvfffedSbI5ePCgXH755TJy5Ei59tpro1ZjCSyjRs0HIYBACAKuNkLoo7755psyYcIE/6kbN24sXbp08d8Hs7J7925zrfXr15vR0HQoX12KFStmvjeKFCkSzOWCPpbviqDJOAEBBBBAAIG4EnA1JtNpoWbPnu2XxTPPPCNHjhyRd955R6ZMmSKRjqH8Dw7TCjFZmCC5DAIIIIAAAjEUcDWu2rRpkzzyyCO+XP78+WXu3LkhjQxoQ4dk4iq/KFlBAAEEEEDAWQESbhxLuNGa1qlTJ/n222/9Sjdq1ChJb2CmoxRoo9/evXulSpUqog2COXLk8D8j0ivpvf9I3x/XRwCBxBZwtRFCS23VqlXSoUMHvwArVKgg2os6PYs2SGgSz6RJk+Ts2bOmF9HAgQMla9as6bnsRc/lu+KiRByAAAIIIIBAXAu4GpMtWrRI+vXr55eNdm7ZunWr2RbNkWX9G0jnCjFZOgE5HQEEEEAAAQsEXI2rdPTA+vXry4EDB3xFTbgpUKCA/z6UlVh1SCauCqW0OAcBBBBAAAG7BEi4cTDhRhvqtMHOW1q2bGmmDPHeh/Kq1xs0aJDoaAVt2rQx88iHcp1QzyGwDFWO8xBAIBoCrjZCqI2OiKYjo+kIabpkzpxZPv30U9EpCdOz6BC+OhqajrTWsWNHeeihh9JzuTSdy3dFmpg4CAEEEEAAgbgVcDUm044t9957b7JyufXWW+XVV18NqTe2dyFNfP7jjz/kl19+MVN9alL0k08+6e2O2CsxWcRouTACCCCAAAJRE3A1rlKg3r17m7YtD2vw4MFy9913e29Deo1Vh2TiqpCKi5MQQAABBBCwSoCEG8cSbrZv3y4PPvigGVHAq0k33XSTvPHGG97boF+1kU7nPZ05c6ZccsklMmzYMKlUqVLQ10nPCQSW6dHjXAQQiLSAy40Q+gNMixYtkhGNGTNGKlasmGxbMG/0x5znn39eFixYIPny5ZPhw4dL+fLlg7lESMfyXRESGychgAACCCAQNwIux2SanLxlyxa/LNIzUq32wNa/3zds2GASbrSnty5169aVf/3rX2Y9kv8iJoukLtdGAAEEEEAgOgKuxlXaJqVTSulogd6SnunTvWvEqkMycZVXArwigAACCCDgrgAJN44l3PTs2VN27dplerB5jWo6hceSJUtCnspDh1/s2rWrrFu3TkqWLCmvvPKKXHbZZVGt1QSWUeXmwxBAIEgBVxsh9DGfffZZWb58ebInbteu3XlJOMkOuMgb7aX99NNPy8aNG6VcuXImaTNv3rwXOSv9u/muSL8hV0AAAQQQQMBlAZdjMk1QnjVrls+vUzo3adLEfx/Kyr59+6RXr16yZs0ayZQpk3Tv3l0aNWoUyqWCOoeYLCguDkYAAQQQQMBKAVfjKp0+SkfqT7robxpTpkxJuimo9Vh2SCauCqqoOBgBBBBAAAErBUi4cSjhZunSpdKtWzeZPHmyCSp1mENvGTt2rFSoUMF7G9Trjz/+aBJu/vrrL2nQoIE899xzprEuqIuk82ACy3QCcjoCCERUwNVGiK+//lp69Oghjz/+uLz55pu+UeXKlc10UP6GIFe0V7UmgB4+fFi0F5H+YJQxY8YgrxL84XxXBG/GGQgggAACCMS
TgKsxmZbBZ599ZpJjvPKoUaOGvPDCC97bkF6PHz9urqnJ1QUKFDCdZ66//vqQrhXMScRkwWhxLAIIIIAAAnYKuBhXHTp0yCQs33fffTJ16lQzjbqnu3jxYsmdO7f3NqjXWHZIJq4Kqqg4GAEEEEAAASsFSLhxJOHm2LFjokNQV61a1STdpOwd1759e2nevHlIlUyD03//+9/mx1LtHVe/fv2QrpOekwgs06PHuQggEGkBFxshjh49ar439Mecxx57TO655x6fKWfOnKINEdoTOpRl/PjxMm7cODOyWv/+/aVmzZqhXCboc/iuCJqMExBAAAEEEIgrARdjMq8A9O/u0aNHe28lf/78oj20M2TI4G8LdmXnzp3SuXNn2bZtm9x8883y4osvmmmig71OsMcTkwUrxvEIIIAAAgjYJ+BiXDVs2DDR+3733XfNyMvff/+9D/vyyy/LHXfc4b8PZiWWHZKJq4IpKY5FAAEEEEDATgESbhxJuNFpnhYsWCAzZswQnbZD5xTt16+fX6tuv/12M6WHvyGNKydOnJA+ffqIjoJQsGBBM+JB8eLF03h2+A4jsAyfJVdCAIHwC7jYCDFy5EhZuHChvP/+++aHl3/+85+yfft2H0dHSytVqpT/Pq0rmsijI6F9++23cuWVV8qoUaOkWLFiaT09XcfxXZEuPk5GAAEEEEDAeQEXYzJF1/jpgQcekD///DNZGejf91dddVWybcG80VFw+/btKzrSjSZYt23bllEHgwHkWAQQQAABBBJYwLW4SpNi2rRpI9oR+c4775Q33njDzATgFaHGQk8++aT3NqjXWHZIpq0rqKLiYAQQQAABBKwUIOHGgYSbVatWSYcOHeSll14yI9xoTdq1a5fo0InekitXLpOEk3S0Am100/OGDh1qkmm8Y5O+bt261WSDa8+42267zRyrIx9EeyGwjLY4n4cAAsEIuNYIsWbNGvODi85pXatWLfOoAwcOlHnz5vmPrdNANWnSxH+vK0eOHJHdu3fLtddem2x70jf//e9/zffGnj17TAPH888/L9myZUt6SMTW+a6IGC0XRgABBBBAwAkB12IyD/XVV1+Vzz//3CQ7f/HFF95m6d27t5nW2d8QxMrZs2dlzJgxMmXKFMmePbtoTBZqr+4gPtYcSkwWrBjHI4AAAgggYJ+AS3GV/s7RrFkzKVGihAwZMsSMELhs2TLp0qWLD3vjjTcmm07d27Fu3TopU6ZMwKTkWHdIJq7ySopXBBBAAAEE3BUg4cbyhBv98bNp06aiI9h069YtWU1r0KCB6A+e3jJp0iQpXbq099bMB68/nOroOIGWJUuWiP4Ie+rUKWnVqpW0bt06YPAZ6Brh2E5gGQ5FroEAApEScKkRQqcg1EaI8uXLi0735C0fffSRSar03utUUy+88IL31rzqj0G///77eduTHqRTH+h5+iOPJnU++uijSXdHdJ3viojycnEEEEAAAQSsF3ApJvMwN2zYIC1btjRxmP59rh1pvEWn/NQRakJZDh06JN27dxedSuHqq682o9VeccUVoVwq6HOIyYIm4wQEEEAAAQSsE3AprtLfNz777DN55513zOj/ivnXX39J7dq1fdfMmTOL/tahicjesnbtWnn22WfN6M86a0BqS6w7JBNXpVYqbEMAAQQQQMAtARJuLEm40eDvm2++kbJly/q90k6ePCk9e/aUHTt2iCbTJA0WtZppw9zixYv9Gte+fXtp3ry5eT9z5kyZMGGCTJw4US6//HL/mKQr+mOpzm06a9Ys0RFyNDs8VgFerD43qQfrCCCAQCABVxohzp07Z4bWXbFihentrP9v95bNmzfLww8/7L2VPHnymKkKM2bMaLbp1IKaoKONFzpVVGrLmTNnTLLNJ598Ys5/8cUXpWLFiqkdGpFtfFdEhJWLIoAAAggg4IyAKzGZB6qx0+OPP27iJp3uc+PGjSYx2ttfpEgR8/e49z6Y1/Xr15te3fv37zcjGuqU01myZAnmEiEfS0wWMh0nIoAAAgggYI2AK3HVypUrpVOnTjJ69GipVKlSMj9t59L2Lm/RY2699Vbz9sCBA+a3Ep1qSqdZD7TEukMycVWgkmE7AggggAAC7giQcGNBws2vv/4qLVq0EG2M00WnitJRbUaMGCHffvutSZopWbLkebVq/vz5MmDAAH/7pZdean4s1cQdTaL597//LRUqVPD3p1zRhjkddlEb6ooXL25GwilUqFDKw6LynsAyKsx8CAIIhChgWyOE9r7p06eP6P/H77//ftNrOkOGDPLWW2+ZBM0333xTbrjhhmRPq0mW9erVk4MHD/rbvYYI/a7RBE9NuKlWrZq/P+WKjqqmU1Hpj0U6opombRYoUCDlYRF7z3dFxGi5MAIIIIAAAk4I2BaTXQxNp3saN26cSYS+6qqrzN/8Ot3n0aNH/VM1kfmyyy7z3+uKjoSjf99fKIHmgw8+MLGYHq89t//xj3/oalQWYrKoMPMhCCCAAAIIRFTAprjq9OnTZnpMvSdtz3ruuedMe9OPP/5opjXXOEdHWU65aEcwjYm8RX9X0XP37dsnXbt2NaMAaluXtpmlttjQIZm4KrWSYRsCCCCAAAJuCZBwY0HCzYcffijDhg07r+ZoIKi91PQH0tQWnW5K9+lIOCkXHf1Gh6e+0PLDDz+YIah1+EW9js4fr0MvxmIhsIyFOp+JAAJpFbCpEULv+emnnxYdxcZbtIdPvnz55NNPP5V27dqZJE5vX9LXoUOHik4t5S358+eXcuXKiY5u88QTT5hpBb19qb2qQ69evUS/fx544AGTtJkpU6bUDo3INr4rIsLKRRFAAAEEEHBGwLaY7EJwmzZtEu1RrdM+NWzY0D9Ue2hrsrO3DBo0yIxQ473XH5zatm0rjRo1Snaet19fdUpo7XyjPbI1ntMk6DJlyiQ9JKLrxGQR5eXiCCCAAAIIREXAprhqzpw5MnjwYP+5CxYsKDVr1pSPP/5YrrnmGtPBLLVE5O+++046duzon6cr1atXF03U0VH/x4wZc96sAUkPtqFDMnFV0hJhHQEEEEAAATcFSLixIOFGp5PSHzp1KhBv0emjNPu6Ro0a3qZUX4cPH55sCOocOXKYhrcLjVDgXUinDXn99ddNhnePHj0CNuZ5x0fylcAykrpcGwEE0itgUyOEPov+CLNmzZrzHuuRRx4xDQ2Beu58//33otMPplyaNGkinTt3Fm96qZT79b32+hk/fryZrlAbOXSEnaRzZad2Tri38V0RblGuhwACCCCAgFsCtsVkgfS0U4z+ja/TdOrUzUljMx3xRv/xFq8ntvdeR6pdvny5TJ48OWCHmJ07d5oEbB31UEe11XYBnS40WgsxWbSk+RwEEEAAAQQiJ2BTXDVjxgwz2n/Kpy1VqpSMGjXKdDJLuU/f64wBmtisI9okXa6//npzvZSjCCY9Rtdt6JBMXJWyVHiPAAIIIICAewIk3FiQcKPVZtmyZTJ16lQz1UflypWlcePGUrhw4YvWKO39po11X331lZQtW9ZMK6LzwF9sOX78uBnRRj9XA8
9XXnlFUpu26mLXCdd+AstwSXIdBBCIhIBNjRD6fPr/fJ0CypuKMHfu3CYJR4fYTfqDTkoLTex8/vnnZe7cuWaXJnfqaDn6Q8/FFh3VRofl1d5D+v00cuRI08voYueFcz/fFeHU5FoIIIAAAgi4J2BbTKb388Ybb0iuXLlER67RH4U02UZHBNywYYNoJ5e8efMmg165cqU89dRT/jbtwa1TQmfNmlVmzpwpY8eONUnOOgVVoEVHJ9TRcPXv+ocffthc70KJ04GuE+p2YrJQ5TgPAQQQQAABewRsiqs0YaZNmzayY8cOA6RtW3Xq1JFu3bqZOOtCap9//rmJi3QEQF30PO1cnDNnzgudZvbZ0CGZuOqixcQBCCCAAAIIWC9Awo0lCTfRrim//fab+ZH1999/l5tvvll0vtNLLrkk2rfhfx6BpU/BCgIIWChgUyOEx6P//9bez5pso6OaafJMWhZNutHz9uzZY0ZRS/kjUKBr/Prrr/Lss8+a826//XbTWzutnxnomsFu57siWDGORwABBBBAIL4EbIrJNNlFf9A5ceKEQdYRAHUKKR1RcPXq1aIj1ei0nykX7TRTv3590amdvaVKlSomafqbb74xSc0Xinl01EG99rRp0yRbtmwycOBAEwt614rG64XuLxqfz2cggAACCCCAQPoFbIqr9Gk0ttIOZocPHxbtkKwjBaZ12bx5s6xatUpuvPFG0dFt0rLY0iGZuCotpcUxCCCAAAII2C1Awk2CJtwsWrTIjHKgmd/Nmzc3GeTR7BGX8j8LAsuUIrxHAAGbBGxrhIiFjc6nPXToUDP9oU5ppd8d0V74roi2OJ+HAAIIIICAXQI2xWQHDx40CTcphfTvah19pm7duil3+e+1w8sHH3zgv9cVPU9Hxrn33nuTbU/5RhN1tLe3Ti9arFgxM81CMD9IpbxeKO+JyUJR4xwEEEAAAQTsErAproqFjC0dkomrYlH6fCYCCCCAAALhFSDhJgETbnQKkpdeekk+/PBDM7TioEGD5I477ghvzQryagSWQYJxOAIIRFUg0RshtCf2kCFDZN68eWZEHU28Sa3HdqQLhe+KSAtzfQQQQAABBOwWsCkm01EDNUFGpzHwlvz580vfvn1FR6y50LJr1y5p1aqV7N+/3xx26aWXmpFqbr311gudZvatW7fOJNzouTVq1JABAwaY6aguemIYDyAmCyMml0IAAQQQQCBGAjbFVbEgsKVDMnFVLEqfz0QAAQQQQCC8AiTcJGDCjc6J2qVLF/nll1/kmmuuMUNWFy5cOLw1K8irEVgGCcbhCCAQVYFEb4TYvXu3mYZw06ZNUqpUKRkxYoQUKFAgqmWgH8Z3RdTJ+UAEEEAAAQSsErAtJtOkZJ364D//+Y9ce+21JgEma9asaTL7888/TScYnR5Up5jKmTNnms6bOXOmvPLKK+bYzp07y4MPPpim88J5EDFZODW5FgIIIIAAArERsC2uiqaCTR2SiauiWfJ8FgIIIIAAApERIOEmARNudD75Hj16yKFDh6R27dpmuOvMmTNHpoal8aoElmmE4jAEEIiJQCI3Qij4ihUrTA/uo0ePSsOGDc13SKZMmaJeFnxXRJ2cD0QAAQQQQMAqgUSPyXRK6P79+8tnn30m+fLlMyPX3nDDDVEvI2KyqJPzgQgggAACCIRdIJHjKps6JBNXhb1qc0EEEEAAAQSiLkDCTQIm3EyePFnGjh0rGTJkkK5du8r9998vJ0+elGPHjknevHmjXgn1AwksY8LOhyKAQBoFErkR4uzZszJu3DiZOHGiaJJN7969pV69emmUC+9hfFeE15OrIYAAAggg4JpAIsdkWlY7d+4UHdVm27Ztook2L7/8ckz+hicmc+2/HO4XAQQQQACB8wUSOa6yqUMycdX5dZMtCCCAAAIIuCZAwk2CJdwcP37cjFKwfPly0XniX3rpJSlXrpzMmjVLduzYIR06dDA/qEa7IhNYRlucz0MAgWAEErkR4vDhw2ZEm1WrVkmhQoXMNITXXXedGSVNR0fLkSNHMJTpOpbvinTxcTICCCCAAALOCyRyTKaF9+WXX5oRbk6cOGGmktLkm4wZM0a9XInJok7OByKAAAIIIBB2gUSOq2zqkExcFfaqzQURQAABBBCIugAJNwmWcLN9+3bTI06Ta0qUKGHmftf55j/66CN59tln5Yorroh6JdQPJLCMCTsfigACaRRI5EaIX3/91Xw/7NmzR2655RYZNmyYURsxYoRUr15dqlWrlkbF9B/Gd0X6DbkCAggggAACLgskckymow6+9tpr8u6770rWrFlN4k3NmjVjUpzEZDFh50MRQAABBBAIq0CixlW2dUgmrgprteZiCCCAAAIIxESAhJsES7jREQp69OghOmJBrly55KabbpL9+/ebqaXKli0bk0qoH0pgGTN6PhgBBNIgkKiNEEozf/58GTx4sJw+fVr++c9/ylNPPSXaE+jUqVPSpk0b84NPGgjDcgjfFWFh5CIIIIAAAgg4K5DIMdnBgwfN3+0//fSTFClSREaNGiVFixaNSVkSk8WEnQ9FAAEEEEAgrAKJGlfZ1iGZuCqs1ZqLIYAAAgggEBMBEm4SLOEm6UgFWuOuv/56ee655ySWyTZ6HwSWqsCCAAK2CiRqI4T2pH7zzTdNgs25c+ekWLFiUrhwYcmXL59069ZN8uTJE9Ui47siqtx8GAIIIIAAAtYJJGpMpgWxdu1aE38dOHDAjDA4aNAgyZYtm+zbt8/EZpkyZYpaeRGTRY2aD0IAAQQQQCBiAokaV9nWIZm4KmJVnAsjgAACCCAQNQESbhIs4UZ/PP3uu+9kxYoVUrx4cdEhqHPmzBm1ChfogwgsA8mwHQEEbBBI1EYItX///fdl5MiRcubMGdEfcurWrSudOnWSvHnzRr1o+K6IOjkfiAACCCCAgFUCiRyTzZgxw4xqo0nQ7du3l2bNmolODz1+/Hjp0qWLFCpUKGplRUwWNWo+CAEEEEAAgYgJJGpcZVuHZOKqiFVxLowAAggggEDUBEi4SbCEm6jVrCA/iMAySDAORwCBqAokaiOEIh87dkzmzp0ru3btkipVqpipCDNmzBhVf+/D+K7wJHhFAAEEEEAgMQUSNSbTqT0HDBggixcvlhw5coiOblOiRAl55ZVX5J577jEj3kSzRhCTRVObz0IAAQQQQCAyAokaV9nWIZm4KjL1m6sigAACCCAQTQESbki4iWZ9C/hZBJYBadiBAAIWCCRqI4QF9Mluge+KZBy8QQABBBBAIOEEEjUmO3jwoBnFRqeV0sTnihUrytGjR6VevXrSuHFjsy2alYGYLJrafBYCCCCAAAKREUjUuCoymqFflbgqdDvORAABBBBAwBYBEm5IuLGiLhJYWlEM3AQCCAQQoBEiAEyUN/NdEWVwPg4BBBBAAAHLBBI1JtMRB/v16ydff/21KZE8efKYaaUaNmxopvyMdjERk0VbnM9DAAEEEEAg/AKJGleFXzJ9VySuSp8fZyOAAAIIIGCDAAk3JNzYUA+FwNKKYuAmEEAggACNEAFgo
ryZ74oog/NxCCCAAAIIWCaQyDHZjh075JNPPpFs2bJJrVq1pEiRIjErHWKymNHzwQgggAACCIRNIJHjqrAhhuFCxFVhQOQSCCCAAAIIxFiAhBvHE25++eUXeeGFF0w1Gj9+vGTOnDnGVSq0jyewDM2NsxBAIDoC8dAI0bZtWzl+/Lh06tRJKlWqFB24MH8K3xVhBuVyCCCAAAIIOCbgekw2ffp0mT9/vpQvX166du3qmP7/3i4x2f9asIYAAggggICrAq7HVePGjTOj/2lc8uSTT7paDHREdrbkuHEEEEAAAQT+V4CEG8cTblavXu0HlEuXLiXh5n/rNmsIIIBA2ARcb4RQiBo1aohORzBs2DCpXr162GyieSF+3ImmNp+FAAIIIICAfQKux2SvvvqqTJs2TSpXriwjR460DziNd0RMlkYoDkMAAQQQQMBiAdfjqueff96M/qcj/w0aNMhi6QvfGnHVhX3YiwACCCCAgAsCJNyQcGNFPSWwtKIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJe
wPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTcOJ9zs3r1bhg4dKsuXLzcVrm3bttK0aVPJkiWLMxXQu1ESbjwJXhFAwEYBlxshTp8+LdOmTZPXX3/d0FaqVEn69esnl19+uY3UF7wnvisuyMNOBBBAAAEE4l7A5Zjshx9+kL59+8qePXskZ86cMnDgQKlataqTZUZM5mSxcdMIIIAAAggkE3A5rtJ779+/v/z555+SN29e0eSbW2+9NdnzufKGuMqVkuI+EUAAAQQQCCxAwo2jCTd79+6VZs2amaDy3LlzpoSzZcsmFSpUkFGjRkmGDBkCl7qFewgsLSwUbgkBBHwBVxsh9Pvh2WefldWrV8uJEyfM8+j3gzZGvP3221KoUCH/GV1Y4bvChVLiHhFAAAEEEIicgKsx2bJly6RXr15+PKZC+vd7586d5YEHHogcWISuTEwWIVguiwACCCCAQBQFXI2rFi5cKEOGDDkvrurTp4/UqlUrioLh+SjiqvA4chUEEEAAAQRiKUDCjaMJNyNGjJBZs2bJmTNnktUfbbR7+eWX5ZZbbkm23fY3BJa2lxD3h0BiC7jaCPH999/LM888I8ePH09WgJkyZZJGjRpJ9+7dk223/Q3fFbaXEPeHAAIIIIBAZAVcjMk0Abphw4ZmZJuUOtmzZ5cFCxaIvrq0EJO5VFrcKwIIIIAAAqkLuBhX6SjO9erVk0OHDp33UHny5JH58+eLtnm5tBBXuVRa3CsCCCCAAAKpC5Bw42jCTePGjWXbtm2plurdd98tgwcPTnWfrRsJLG0tGe4LAQRUwMVGCL3vcePGmX90PeVSuHBh+eijj1Jutvo93xVWFw83hwACCBciyFoAAEAASURBVCCAQMQFXIzJ9u3bJ/fdd5+cOnUqVZ8nnnhCWrZsKRkzZkx1v40biclsLBXuCQEEEEAAgeAEXIyrtmzZYuKmY8eOnfewWbNmlXfeeUeuuuqq8/bZvIG4yubS4d4QQAABBBBImwAJN44m3LRq1Up+/vnngKWsPeh0eOpcuXIFPMamHQSWNpUG94IAAikFXGyE0Gd477335N///recPHky5SNJzpw55cYbbzQ9rqtWrSraMGH7wneF7SXE/SGAAAIIIBBZARdjsqNHj5rpDVKOTptUqlSpUtKxY0dnRqolJktaeqwjgAACCCDgpoCLcdWePXvMdJypJTLryDZXXHGFSXTWUXAKFCjgRMEQVzlRTNwkAggggAACFxQg4cbRhBsdHnHo0KHJ5ipNWdIaYPbt21duvvnmlLuse09gaV2RcEMIIJBEwMVGCL197VF9//33n5dwo8k1SZNwdNhdbYy49957pWTJkkme3K5VvivsKg/uBgEEEEAAgWgLuBqTdenSRVasWHHelNCZM2cWnRrBW6pUqSIdOnSQEiVKeJusfCUms7JYuCkEEEAAAQSCEnA1rtLRAdetWydnz571n1eTbXSKziNHjphtOnLgHXfcIQ0aNJDbb79dNOaydSGusrVkuC8EEEAAAQTSLkDCjaMJNzoP/PDhw2X27Nl+A53+gNqjRw/5+uuv5YsvvvBrwUMPPSTt2rWzel54Aku/uFhBAAELBVxthFBK/U7o3bu3n2CjjQz33HOPSbD55JNPZMmSJXL8+HFfvXTp0qZBonbt2pI7d25/uw0rfFfYUArcAwIIIIAAArETcDUmO3jwoLRt29Z
MC60j3eiPQNrr+vXXX5fly5fL+PHj5a+//vJhNQm6TZs2UqhQIX+bTSvEZDaVBveCAAIIIIBAaAKuxlV79+4VTbrZvXu3SWbWZBudNn3IkCGybNkymTNnjvz+++8+Sv78+f1OZtdee62/3ZYV4ipbSoL7QAABBBBAIHQBEm4cTbjxinzevHkycOBA83bu3Lmm0U6TcRYuXCgvvfSSHD582Oy75pprpH///lKmTBnvVKteCSytKg5uBgEEUgi42gjhPYb+yKM/3OiQu88++6w8+OCD3i7RaQ4+/fRT0yDx448/+ts1ifOuu+4yyTeVKlUyPwz5O2O0wndFjOD5WAQQQAABBCwRcDkm017YvXr1Mp1jrr/+epkwYYLf21r/bn/77bdl+vTpfpJ0tmzZRDvPNGvWTC655BJLSuD/3QYxmVXFwc0ggAACCCAQkoDLcZUmMHfu3FlWrlwpFStWNNOpa+KNLhpzff/996ad67PPPvNjK913ww03mHauv//975IrVy7dFPOFuCrmRcANIIAAAgggkG4BEm4cT7hZvXq1PPnkk6YiLF261G+w0w2a5f3888/Lt99+a/ZrL7oWLVpIq1atkh1ndsb4XwSWMS4APh4BBC4o4HIjhPdgNWrUkGPHjsmwYcOkevXq3uZkr7/99pvoqDeawLl//35/n05RqAk7OjKO9hqK1cJ3Razk+VwEEEAAAQTsEHA9Jnv11Vdl2rRpUrlyZRk5cuR5qPo3/NixY00s5u3MmzevtG7d2kwTmiVLFm9zTF+JyWLKz4cjgAACCCAQFgHX4yr93UPbsGrVqiWDBg1K1eTQoUOyePFik3yzfv16/xhNbNakG51yqkKFCpIhQwZ/X7RXiKuiLc7nIYAAAgggEH4BEm7iOOFGq4uOdvPBBx+INuydOHHC1KBSpUqZ0W6uu+668NeoEK9IYBkiHKchgEBUBFxvhFCktCTceJinT5+Wb775xjRI6HC82nPIW2677TbTIFGtWjXRBopoLnxXRFObz0IAAQQQQMA+Addjsosl3Hjiv/76q+mpvWLFCm+TFClSxHS2qVmzZkx/FNIbIibzi4UVBBBAAAEEnBVwPa5KS8JN0sLZuHGjSdCZP3++6EjQ3lKsWDHTyax+/fpSsGBBb3PUXomrokbNByGAAAIIIBAxARJu4jzhxqs527ZtM5ne3nQh2jOuXbt2Zohqb7hF79hYvBJYxkKdz0QAgbQKuN4Ioc8ZTMJNUpd9+/bJggULTPLNli1b/F25c+eWOnXqmOQbTeSMxsJ3RTSU+QwEEEAAAQTsFXA9Jktrwo1XAjpa7WuvvSb/+c9/vE1Srlw56dixo5k+wd8Y5RVisiiD83EIIIAAAghEQMD1uCrYhBuPUKdb//rrr03yjXY20w7LuugoN1WqVDHtXFWrVpVojSxIXOWVDK8IIIAAAgi4K0DCTYIk3GgV1REKdPhqHaJaRy/QRYdM7Nevn+ktZzbE6F8EljGC52MRQCBNAq43QuhDhppw4wFpA8S6detM4o0Ox3v06FFvl5QsWVIaNmwotWvXFp32IFIL3xWRkuW6CCCAAAIIuCHgekwWbMKNlsrZs2dl4cKFMmbMGPnjjz/8grrzzjulQ4cOcs011/jborVCTBYtaT4HAQQQQACByAm4HleFmnCTVFSn89QRb+bMmSPbt2/3d+XLl0/q1atnkm8iPUsAcZXPzgoCCCCAAALOCpBwk0AJN14t1eETBwwYIDpMtS7Zs2eXp59+Who1ahSzoakJLL3S4RUBBGwUcL0RQk3Tm3CTtFyOHTsmn3/+ucyePVt++OEHf1fmzJnlrrvuMg0St9xyi4R7BDW+K3xqVhBAAAEEEEhIAddjslASbryC1imiZ86cKRMnTpTDhw+bzRkzZjR/xz/++ONSoEAB79CIvxKTRZyYD0AAAQQQQCDiAq7HVeFIuPGQtZOZtm9p4s2nn34qGnd5S9myZU07V61ateSSSy7xNoftlbgqbJRcCAEEEEAAgZgJkHCTgAk3Wtt06MQJEybI5MmTTY853aZDJvbq1Yu5ShWDBQEEEEgi4HojhD5KOBNuktDI1q1bZe7cueafvXv3+rsuv/xyueeee8w82FdeeaW/PT0rNEKkR49zEUAAAQQQcF/A9ZgsPQk3XukdPHhQJk2aJO+//74/cq12omnatKk88sgjkjNnTu/QiL0Sk0WMlgsjgAACCCAQNQHX46pwJtwkRT9y5IgsWbLEJN+sXbvW35UtWzbTttagQQO56aabRBOfw7EQV4VDkWsggAACCCAQWwESbhI04cardjo9iI52oz+Y6pI7d27p1q2baMa2zlsarYXAMlrSfA4CCIQi4HojhD5zpBJuPE+dtlCdtDfQV199ZaYx9PZVqlTJTDlVvXp1M6qatz3YV74rghXjeAQQQAABBOJLwPWYLBwJN16J7ty500wztWjRIm+T5M+fX5544gnTC1tHHozUQkwWKVmuiwACCCCAQPQEXI+rIpVwk7QENm/ebNq5dNqpAwcO+Lu0Y5km3mhHs0KFCvnbQ1khrgpFjXMQQAABBBCwS4CEmwRPuNHqePz4cXnjjTfkvffe82tnzZo1pXv37qLzlUZjIbCMhjKfgQACoQq43gihzx3phJukttoIsXDhQjPl1KZNm/xduXLlkjp16phGidKlSwed2Ml3hU/JCgIIIIAAAgkp4HpMFs6EG68C/PLLL6LXXb16tbdJrr76aunQoYPceeedQcdb/kUusEJMdgEcdiGAAAIIIOCIgOtxVTQSbryiPH36tCxfvty0c+nr2bNnzS7tsKxxkSbfaNyVNWtW75Q0vxJXpZmKAxFAAAEEELBWgIQbEm78yrlq1SoZOHCg/PHHH2ab9o7TKaaqVq3qHxOpFQLLSMlyXQQQCIeA640QahDNhBvPXOfAXr9+vekNpAk4R48e9XZJ8eLFTYNE3bp105zcyXeFz8cKAggggAACCSngekwWiYQbrQgac+mPP6+99ppoT2xvqVChgnTs2FFuuOEGb1NYXonJwsLIRRBAAAEEEIipgOtxVTQTbpIWlE6nriPezJ49W7Zt2+bvypMnj2gblybflCxZ0t9+sRXiqosJsR8BBBBAAAH7BUi4IeEmWS3VOUpfeeUV+eSTT/zt9957rzzzzDOiIxNEaiGwjJQs10UAgXAIuN4IoQaxSLhJaq+jqX3xxRemQSJpD+xMmTJJtWrVTIOEfhfo+0AL3xWBZNiOAAIIIIBAYgi4HpNFKuHGK32d4nPevHkyduxY0R+DvEVHsG3fvr0UK1bM25SuV2KydPFxMgIIIIAAAlYIuB5XxSrhxis8TXj+8ccfTSezJUuWmFkEvH06qrP+plK7dm3RRJwLLcRVF9JhHwIIIIAAAm4IkHBDwk2qNfXrr7+WF154Qfbv32/2Fy5cWPr27SuVKlVK9fj0biSwTK8g5yOAQCQFXG+EUJtYJ9wkLZ8dO3bI3LlzTXLn7t27/V0FCxY0819ro0TRokX97d4K3xWeBK8IIIAAAggkpoDrMVmkE268WqGJztOnT5e333
7bH2FQk5r/8Y9/SKtWrdI8uqB3vZSvxGQpRXiPAAIIIICAewKux1WxTrhJWuI6ovOnn35q2rnWrFnj79Ippu666y7TyUx/V8mYMaO/z1shrvIkeEUAAQQQQMBdARJuSLgJWHv//PNPGTZsmHz++ef+MU2aNDE947Jnz+5vC8cKgWU4FLkGAghESsD1Rgh1sSnhxisn7YW9cuVKM+rNl19+KTontrfcdNNNpkFCe2R73zl8V3g6vCKAAAIIIJCYAq7HZNFKuPFqh3agmTBhgnz44YeicZcuOXPmlObNm4v+be/FWN7xaX0lJkurFMchgAACCCBgr4DrcZVNCTdJS/m3334ziTfa0czrzKz7tUOzdjC755575IorrvBPIa7yKVhBAAEEEEDAWQESbki4uWDl1aERFy1aJMOHD5fDhw+bY6+66irp37+/lCtX7oLnBrOTwDIYLY5FAIFoC7jeCKFeNibcJC3HgwcPysKFC03yzcaNG/1d+qOQDsGrjRLaIztDhgz+PlYQQAABBBBAILEEXI/Jop1w49WOrVu3yhtvvJGsM42OLNimTRupX7/+Baf09K6R9JW/35NqsI4AAggggICbAq7HVbYm3Hi1QTuVrVixwrRzLVu2zE9+1v233XabaeeqXr26mWbdO4dXBBBAAAEEEHBTgIQbEm7SVHN1yo8hQ4aYIFFP0OEPW7RoIS1btpQsWbKk6RoXOogGuwvpsA8BBGIt4HojhPrZnnCTtIx/+eUXMwe2JuB4yZ66v2zZsibpplmzZlKoUKGkp7COAAIIIIAAAgkg4HpMFquEG69q/PTTTzJ69Gj58ccfvU1SvHhxeeqpp6Ry5cppTmzm73efjxUEEEAAAQScFXA9rrI94SZpxdi3b58sWLDAtHVt2bLF33XJJZeYkQdbt24tOtIzCwIIIIAAAgi4KUDCDQk3aa65OtrNRx99JKNGjRKdE16X66+/3ox2o4106VlosEuPHucigECkBVxvhFAflxJuvPI8ceKE6FRTc+bMke+++87bLJkzZzbTTemIN3Xr1jXv/Z2sIIAAAggggEDcCrgek8U64UYrhv5d/9VXX8lrr70m27Zt8+vKLbfcYhJvSpcu7W8LtMLf74Fk2I4AAggggIA7Aq7HVS4l3Hi1QuOwdevWmXauxYsXy9GjR71dUrFiRdPJ7JFHHpECBQr421lBAAEEEEAAAfsFSLhxPOFm1apV0qFDB1PTli5dGpUfHbdv3y6DBg2SNWvWmM/VHz7btWsnDz/8cNBDUXv/idBg50nwigACNgq43gihpnfddZdJlhw2bJjokLWuLTt37jS9sSdOnCg6H7a36LzXzZs3NyOuaRIoCwIIIIAAAgjEr4DrMZl2Xnn33XfNaDIjR46MaUHpNAezZ8+Wt956Sw4cOODfiyYzt23bVjTGCrTw93sgGbYjgAACCCDgjoDrcZWLCTdJa8exY8fMdJ/a0Uz/8ZasWbPK/fffb5Jv7r777pB/b/GuxysCCCCAAAIIRF6AhBsSbkKqZWfOnJHp06ebeeC1oU6XChUqSN++faVo0aJBX5MGu6DJOAEBBKIo4HojhFK5nnCjz6DfFWfPnjUNEhMmTJBZs2aJjoLjLVWrVjUNEo0bNxYdlpcFAQQQQAABBOJLwPWYzKaEG69mHDlyRKZOnWr+8eIq7VTz4IMPmmmk8+TJ4x3qv/L3u0/BCgIIIIAAAs4KuB5XeQk3tWvXloEDBzpbDhpXbdy4UbSD2aRJk0Q7nHlLsWLFTDzWokULue6667zNvCKAAAIIIICAZQIk3JBwk64quWnTJhkwYIBs2LDBXCd79uzSqVMnk4WdIUOGNF+bBrs0U3EgAgjEQMD1Rggli5eEm6TFr72xtZe4Jt/oiG/eosk2TZo0Mck3VapUkWC+j7xr8IoAAggggAAC9gm4HpPZmHDjlfLevXtl3Lhx8vHHH5tpp3R77ty5zY88msysva29hb/fPQleEUAAAQQQcFfA9bgqnhJuvFqknZwXLVok48ePNyMRnjp1ytslNWvWNO1cDzzwgOTIkcPfzgoCCCCAAAIIxF6AhBsSbtJdC3WEG/2xUzOwdeQBXbQBrnfv3lKoUKE0XZ8GuzQxcRACCMRIwPVGCGWLx4SbpNVBpznU76J33nlH9u/f7+8qVaqUaZB47LHHpHDhwv52VhBAAAEEEEDAPQHXYzKbE2682rB582Z5/fXX5euvv/Y2memldBrpWrVqScaMGc3f+/5OVhBAAAEEEEDASQHX46p4TLhJWpH27NljRiDU5Ju1a9f6u/LmzSsPP/ywtG7dWipVqkQnM1+GFQQQQAABBGInQMINCTdhq30///yzGe3mt99+M9fUEQa6du0qderUuWjgR8JN2IqBCyGAQAQEXG+EUJJ4T7jxil2nQpg9e7ZJvlm4cKHfQztTpkxyzz33mOSb+vXrS5YsWbxTeEUAAQQQQAABRwRcj8lcSLjxqsLq1atl9OjRsn79em+TaCJzx44dpUOHDv42VhBAAAEEEEDATQHX46p4T7jxatW5c+dk5cqVpp1r2rRp8tdff3m7pHz58qadq2nTpnLZZZf521lBAAEEEEAAgegKkHBDwk1Ya9zx48dlzJgxMn36dP+6NWrUkO7du8ull17qb0u5QsJNShHeI4CATQKuN0KoZaIk3CStN9u2bZPJkyebebB1CkRvufzyy0VHvGnZsqWUKVPG28wrAggggAACCFgu4HpM5lLCjVYF/YHn008/NSPe7Ny5068d9erVk2HDhpkfefyNrCCAAAIIIICAUwKux1WJknCTtFIdPXpUPvzwQ5N889lnn/m7tFNZo0aNTPJN7dq1RTudsSCAAAIIIIBA9ARIuCHhJiK1bdWqVTJo0CDZtWuXub4m2/Ts2VOqVauW6ueRcJMqCxsRQMASAdcbIZQxERNuvOqj0x1+9dVXpkFi5syZcuzYMW+XVKlSxTRIPPjgg5InTx5/OysIIIAAAgggYJ+A6zGZawk3Xg04efKkfPDBByaW8npV69RSLVq0kIEDB0qRIkW8Q3lFAAEEEEAAAUcEXI+rEjHhJmnV0mlAJ06cKJMmTRLtcOYtGpc1b97cdDIrUaKEt5lXBBBAAAEEEIigAAk3JNxErHodOXJERo4cKXPmzPE/Q6fzeOaZZ0Snm0q6kHCTVIN1BBCwTcD1Rgj1TOSEm6T16eDBg2YUtgkTJsi3337r78qZM6c0btzYzIFdtWrVi06F6J/ICgIIIIAAAghETcD1mMzVhBuvgA8dOiRTpkyRGTNmiI5uq0uOHDnM3/g9evQgedmD4hUBBBBAAAEHBFyPqxI94carYmfOnDEjEo4fP14++ugj0URpb9HOz61bt5Z//OMfkitXLm8zrwgggAACCCAQZgESbki4CXOVOv9yy5Ytk8GDB8v+/fvNTp3Ko2/fvnLLLbf4B5Nw41OwggACFgq43gihpCTcnF+x1q5da3pq6w9He/fu9Q/QHkCtWrUy007RY9tnYQUBBBBAAIGYC7gek7mecONVgCuvvFL69etnpu7Uaad0ueyyy
6R///7Spk0byZo1q3corwgggAACCCBgqYDrcRUJN+dXrH379sm0adNEk2/WrFnjH5A7d2556KGHTFuX/g6TIUMGfx8rCCCAAAIIIJB+ARJuSLhJfy1KwxV0RIEXX3zRZFt7h+v0HU8++aRkz55dSLjxVHhFAAEbBVxvhFBTEm4C1yzt/fPJJ5+Y5Jv58+eLTkGli06VUK9ePdMgce+99/LjUWBC9iCAAAIIIBAVAddjsnhJuPH+ftcfcnRkm4ULF/rlr4nLL7zwgulJzY85PgsrCCCAAAIIWCfgelxFws2Fq9T3339v2rmmTp0qBw4c8A8uW7asaedq1qyZFCpUyN/OCgIIIIAAAgiELkDCDQk3odeeEM5ctGiRDB8+XHQoal2KFStmesHp0IYsCCCAgK0CrjdCqCsJN2mrXTt37jS9tXXKqY0bN/onFSxYULQxQke+KVeunL+dFQQQQAABBBCInoDrMVm8Jdx4Jb948WLp3r27/PDDD94mqVy5svnbX6fqZEEAAQQQQAAB+wRcj6tIuElbndJpQHWqKW3nWrJkiXijE2bOnFkaNGhg2rnq1q0r+p4FAQQQQAABBEITIOGGhJvQak46ztqzZ48MGTJEvvnmG3MV7fXWq1cvMyQ1Q0+nA5ZTEUAgYgKuN0IoDAk3wVUPbYBYunSpaZCYMWOGHD161L/AbbfdZhokdDjevHnz+ttZQQABBBBAAIHICrgek8Vrwo2Wuo4QqFMY9O7dW7Zu3epXhEaNGsnQoUOldOnS/jZWEEAAAQQQQCD2Aq7HVSTcBF+HfvvtN5k0aZJMnDhRdN1bChcuLM2bN5eWLVtKqVKlvM28IoAAAggggEAaBUi4IeEmjVUlvIfpD5kff/yxaIPjsWPHzMUrVqwob7/9tpQvXz68H8bVEEAAgXQKuN4IoY9Pwk3oleCvv/4STbrR3kBesqheLUeOHGa6BB2lrVq1amYKqtA/hTMRQAABBBBA4GICrsdk8Zxw45Wd9qIePXq0DB48WHRqaV0yZcokTzzxhBndVn/QYUEAAQQQQACB2Au4HleRcBN6HdJE6c8//9y0c82aNUtOnDjhX+yOO+4Qbedq3LixXHLJJf52VhBAAAEEEEAgsAAJNyTcBK4dUdijU3cMHDjQH3paR7jR9127djWNclG4BT4CAQQQuKiA640Q+oAk3Fy0mNN0wPr1602DhCaI7t692z/nuuuuMz2BtEeQTpfIggACCCCAAALhF3A9JkuEhBuv1Pft22dGtn3ttdfk5MmTZnOuXLmkW7du0qVLF37A8aB4RQABBBBAIEYCrsdVJNyEp+IcOHBA3n33XdPWtWrVKv+iGrc1adLEjPB8++23i85SwIIAAggggAACqQuQcEPCTeo1I4pbNaNaRwzQaaW8bGoN4iZPniwlSpSI4p3wUQgggEDqAq43QuhTkXCTetmGuvXUqVMyb9480yAxd+5cOXPmjLmUNkDUqVPHNEg0bNhQsmXLFupHcB4CCCCAAAIIpBBwPSZLpIQbr+g2b95sppnSH3K8RUe5GTBggImXMmfO7G3mFQEEEEAAAQSiKOB6XEXCTfgry5o1a8x0U1OmTJH9+/f7H6DTTLVq1Uoee+wxYbRCn4UVBBBAAAEEfAESbki48StDLFf+9re/yc8//2yCNi+TOmfOnDJ8+HBp3749GdSxLBw+GwEExPVGCC1CEm4iV5F37dol2hgxfvx42bBhg/9B+fPnl6ZNm5pGiQoVKvjbWUEAAQQQQACB0ARcj8kSMeHGK+mVK1ea0W2++OILb5OULl1ahg0bJg0aNOBvfl+FFQQQQAABBKIj4HpcNWjQINEOULVr1zYj5kdHLfyfor+L2LZop+jZs2ebTmYLFy6Uc+fOmVvUaULr169vppzS1yxZsth269wPAggggAACMREg4SaOEm6WLVvm7DRMXmCpIwYMGTJENGD2RguoVauWCe6KFi0ak/9I+FAEEEDA9UYILUEv4ebFF1+UatWqOVmo3neFrTevDRA6YtuECRPkvffek8OHD/u3WqlSJZN48/DDD8ull17qb2cFAQQQQAABBNIu4HpM5iXcVKlSRV555ZW0P7hlR4Yak2msNH/+fOnevbusW7fOfyqNTbWzzW233eZvYwUBBBBAAAEEIivgelxFwk1k64d39W3btpmZCCZOnCibNm3yNkuhQoVM52kd+aZMmTL+dlYQQAABBBBIRAESbki4saLep2yw095vOkTh+vXrzf3lzZtXRo8ebUYKYL5QK4qMm0AgoQRcb4TQwiLhJrpVVpNt3n//fZN8s3TpUv/DdYqpBx54wPQGqlGjhmTMmNHfxwoCCCCAAAIIXFjA9Zgs0RNuvNI9ffq0+eGmb9++8vvvv3ub5cEHHzQdcIoXL+5vYwUBBBBAAAEEIiPgelxFwk1k6kWgq549e1a++uor0841c+ZMOXbsmH9o5cqVTTuXxnJ58uTxt7OCAAIIIIBAogiQcEPCjRV1PWXCjd6UBm19+vQxPf+8YQv1R8oxY8ZIwYIFrbhvbgIBBBJDwPVGCC0lEm5iV1f/85//mDmwJ0+enOxHpauvvlpatmwpLVq0EF1nQQABBBBAAIELC7gek5Fwk7x8jxw5IiNHjjTTSh06dMjs1KkJdFppTca57LLLkp/AOwQQQAABBBAIm4DrcRUJN2GrCkFf6ODBgzJ9+nSTfPPtt9/65+fMmVMaN25sRni+8847mTLUl2EFAQQQQCDeBUi4IeHGijqeWsKNd2OaOd28eXPZsmWL2aTJNm+++abcd9993iG8IoAAAhEVcL0RQnFIuIloFUnTxbU394IFC0yDxJw5c0Tf66Ijt/397383DRL63ZY9e/Y0XY+DEEAAAQQQSDQB12MyEm5Sr7G7d++WgQMHytixY/34SHtH9+zZUzp37iw5cuRI/US2IoAAAggggEDIAq7HVSTchFz0YT1x7dq1ppPZ22+/LXv37vWvXaJECdPOpbMYFClSxN/OCgIIIIAAAvEoQMINCTdW1OsLJdzoDWpvty5dushbb73l368m4WiDpU43xYIAAghEUsD1Rgi1IeEmkjUk+GvrD0vvvPOOjB8/Xn7++Wf/Avny5ZNHH33UNErcfPPN/nZWEEAAAQQQQEDE9ZiMhJsL12IdFbBXr14ya9Ys/8CiRYuK/qDWrFkzyZQpk7+dFQQQQAABBBBIn4DrcRUJN+kr/3CfffLkSfnkk09MJ7P58+eLTkGli06lXrduXdPO1aBBA8maNWu4P5rrIYAAAgggEHMBEm5IuIl5JdQbuFjCjXeT8+bNk8cff9yfkqNYsWImg/ruu+/2DuEVAQQQCLuA640QCkLCTdirRVguqFMm6vC7EyZMkHfffdckmHoXrlixommQeOSRR6RAgQLeZl4RQAABBBBIWAHXYzISbtJWdb/55hvp2rWrLF++3D+hfPny8uKLL0qdOnWYnsBXYQUBBBBAAIHQBVyPq0i4Cb3sI33mzp07RadV17aujRs3+h+n04VqEnWrVq3khhtu8LezggACCCCAgOsCJNyQcGNFHU5r
wo3e7P79+6VDhw5mnlDv5p966ikz77vOE8qCAAIIhFvA9UYI9SDhJty1IvzXO3LkiOnRrQ0SX375pf8B2vtHp5pq3bq1aIIpvbt9GlYQQAABBBJMwPWYjISbtFdYTUr+6KOP5LnnnhMd+cZbNBYaPny43HTTTd4mXhFAAAEEEEAgBAHX4yoSbkIo9CifovHc0qVLTeLNjBkz5OjRo/4d3Hrrraad66GHHmIGA1+FFQQQQAABVwVIuCHhxoq6G0zCjXfDGqS1b9/eJODotpIlS4rOFVq5cmXvEF4RQACBsAi43gihCCTchKUqRO0i2gNo0qRJ5p8dO3b4n6sju7Vo0cL8c9111/nbWUEAAQQQQCARBFyPyUi4Cb6Wnjp1SsaNGyf/+te/RKfk9JamTZvK888/L1dffbW3iVcEEEAAAQQQCELA9biKhJsgCtuCQw8dOiTvvfeeSb7R0Qy9JXv27PLPf/7TjHpTvXp1MwWVt49XBBBAAAEEXBEg4YaEGyvqaigJN3rjv//+uzzxxBMyd+5c8xw6J2iPHj2kf//+ki1bNiuejZtAAAH3BVxvhNASIOHGzXp45swZWbRokWmQ+Pjjj0V/dPKWmjVrmgaJBx54QHLkyOFt5hUBBBBAAIG4FXA9JiPhJvSqqT/S6Mg2L7/8st87Wv/m79Spk/Ts2VMuvfTS0C/OmQgggAACCCSggOtxFQk37lba9evXy8SJE820U0kTqrVjWcuWLaV58+aiHc5YEEAAAQQQcEWAhBsSbqyoq6Em3OjN69CEOv3G008/LYcPHzbPc+ONN8qUKVNEX1kQQACB9Aq43gihz0/CTXprQezP37t3r0ydOlXGjx8vP/30k39DefPmlYcfftgk39xyyy2SIUMGfx8rCCCAAAIIxJOA6zEZCTfpr407d+40o91oPHT27FlzQU226dOnj5l6mo436TfmCggggAACiSHgelxFwo379VQ7lc2bN8/8tqMdqrXTmS7arlW7dm3TztWoUSM6Vrtf1DwBAgggEPcCJNyQcGNFJU9Pwo33AJs3bzYZ0F9++aXZlCVLFhkwYIB069ZNMmfO7B3GKwIIIBC0gOuNEPrAJNwEXezWnqCJpqtWrTINEtOmTZODBw/691q+fHnTIPHoo49KwYIF/e2sIIAAAgggEA8CrsdkJNyErxb+/PPP8txzz8mcOXP8i15zzTUyePBgeeihh5iOwFdhBQEEEEAAgdQFXI+rSLhJvVxd3bpr1y7TgVqTqjds2OA/Rv78+UWnEm3VqpVUqFDB384KAggggAACNgmQcEPCjRX1MRwJN/og2sNNGzF1SOkTJ06YZ6tcubK8/fbbUrJkSSuelZtAAAH3BFxvhFBxEm7cq3dpueNjx47JBx98YJJvPvvsM/8UTTpt2LChaZCoU6eOZMqUyd/HCgIIIIAAAq4KuB6TkXAT/pqnHW60k813333nX7xSpUpm+qkaNWr42/4Pe+cBHlXxtfGXEEooofeOdBAEAVEEFEG6CoogLRQp0luoAQx/emjSi0BoUkQ6KiBFKSIICKj03otAgCSk8t0zfntNQjbZTbLJvZv3Po/u3Lkzc8/8ZnQns++cwwQJkAAJkAAJkEBkAmZfV1FwE3k8neVODpn9+uuvap9rzZo1ekQD6V+lSpXUPlerVq0YTtRZBpz9IAESIAEnIUDBDQU3hpjKCSW4sXRG4oC2a9cOv//+u8pyc3PDpEmT0L17d550s0DiJwmQgM0EzL4JIR2l4Mbm4TZtQfH05uvrq+JgX79+Xe9Hvnz5VPxriYNdrFgxPZ8JEiABEiABEjAbAbOvySi4ccyMkx9m1q5dqw7eyHrIcjVs2BATJ05EuXLlLFn8JAESIAESIAES+H8CZl9XUXDj/FP52bNn+Pbbb5X4Zv/+/XqHJYRos2bNlPimdu3a/L1HJ8MECZAACZBAUhGg4IaCm6Sae5Hem9CCG2lcYoBOmDABo0ePRmhoqHpfnTp11AKtQIECkd7PGxIgARKIiYDZNyGkbxTcxDTCzvVMYl6Ltxtxw7thwwYEBwfrHaxZsyY6deqEjz/+GOnTp9fzmSABEiABEiABMxAw+5qMghvHzjLxcjtv3jy1B/Dw4UP1MhcXF7Rv317liQiZFwmQAAmQAAmQwL8EzL6uouAmec3kc+fOqQNmS5cuxe3bt/XOFypUCHLATNZ7kuZFAiRAAiRAAklBgIIbCm6SYt699E5HCG4sLzl27Bjatm0LifEul7u7O2bMmKE84KRIkcJSjJ8kQAIkYJWA2TchpGMU3FgdXqd+ID82ffPNN0p888cff+h9zZgxI1q2bKlOA8l3ML8PdTRMkAAJkAAJGJiA2ddkFNwkzuR6/PixOnwzffp0PdS0eL3t378/Bg0apPYEEscSvoUESIAESIAEjEvA7OsqCm6MO7ccaZkcrP7xxx/VoeotW7boB61lX+u9995T+1xNmzZF2rRpHWkG2yYBEiABEiCBSAQouKHgJtKESKobRwpupE/Pnz/HiBEjMGXKFIi7abk++ugjzJ8/Hzlz5lT3/BcJkAAJWCNg9k0I6RcFN9ZGN/nkHz9+XG1IrFy5Eo8ePdI7XqZMGbUh0aZNG+TKlUvPZ4IESIAESIAEjEbA7GsyCm4Sd0Zdu3YNI0eOxLJly/R9gOzZs2PUqFHo0qULUqdOnbgG8W0kQAIkQAIkYCACZl9XUXBjoMmURKbcu3cPK1asUIfMLIetxZTMmTOjdevWaq+rYsWKPGSWROPD15IACZBAciJAwQ0FN4aY744W3Fg6uW/fPuVe8NKlSyorR44cSnQjqmdeJEACJGCNgNk3IaRfFNxYG93kly8i1I0bNyrxzU8//aT/AOXq6orGjRurDYkGDRpA7nmRAAmQAAmQgJEImH1NRsFN0symEydOYPDgwdi+fbtuQLFixZQXnGbNmvFHGJ0KEyRAAiRAAsmJgNnXVRTcJKfZGnNf5YD14cOH1T7XqlWr8PTpU71ChQoV1D6XCHCyZcum5zNBAiRAAiRAAglJgIIbCm4Scj7Fua3EEtyIgc+ePcPAgQOV0MZisISckjBTon7mRQIkQAJRCZh9E0L6Q8FN1FHlvRC4evUqJP71kiVLcOXKFR1K7ty54eHhoeJglyxZUs9nggRIgARIgASSkoDZ12QU3CTl7AF27twJT09PiADHclWrVg0+Pj54++23LVn8JAESIAESIIFkQcDs6yoKbpLFNLW7k/7+/li/fr3yevPzzz/r9cWzoUQ86NixI+rUqYOUKVPqz5ggARIgARIggfgSoOCGgpv4zqEEqZ+YghuLwRLrs1OnTrh165bKyp8/v1JB161b11KEnyRAAiSgCJh9E0I6QcENJ3NMBMLDw7F37161IfHdd98hKChIL169enW1IfHpp58iQ4YMej4TJEACJEACJJDYBMy+JqPgJrFnzMvvkzWPhNccPnw4rl+/rheQH2AmTJgACo11JEyQAAmQAAk4OQGzr6souHHyCZoA3btw4QJ8fX3VPzdv3tRbLFCggIqC0L59exQtWlTPZ4IESIAESIA
E4kqAghsKbuI6dxK0XlIIbqQDDx8+RK9evfDNN9/o/enevTsmTZqE9OnT63lMkAAJJG8CZt+EkNGj4CZ5z2F7ev/o0SOsXr1aiVB///13vap8L7Zo0UKJb9566y2GX9DJMEECJEACJJBYBMy+JqPgJrFmSuzvCQwMxMyZMzFu3Dj4+fmpCnLSuUuXLhg1ahRy5coVeyMsQQIkQAIkQAImJmD2dRUFNyaefIlselhYGHbs2KH2uTZt2oSQkBDdgnfffVftc0mY0XTp0un5TJAACZAACZCAPQQouKHgxp754rCySSW4sXRo3bp16NatG/755x+VJfHcJcSG/KDIiwRIgATMvgkhI0jBDedxXAhIyAUJN7VixQr9O1LakdPf4oZXQjLmyZMnLk2zDgmQAAmQAAnYTcDsazIKbuwecodXkD2AsWPHYtasWfqPLyIyHjRoEPr370/vfg4fAb6ABEiABEggqQiYfV1FwU1SzRxzv/fBgwfK2+GiRYtw6tQpvTPu7u5o1aqV2uuqXLkyD5npZJggARIgARKwhQAFNxTc2DJPHF4mqQU30sE7d+6gc+fO2Lp1q+qvi4uLiu/u7e2NNGnSOJwBX0ACJGBcAmbfhBCyFNwYd36ZwTIJMbVlyxZ1Gmj79u2QcAxyyUnwhg0bqg2JRo0aIVWqVGboDm0kARIgARIwKQGzr8kouDHuxLt06RK8vLywatUq3cjcuXND9gNEZOzq6qrnM0ECJEACJEACzkDA7OsqCm6cYRYmXR9evHiBo0ePqn0uiX5g8XgoFpUrV06t/9q0aYMcOXIknZF8MwmQAAmQgGkIUHBDwY0hJqsRBDcCQhZacpK/b9++ePr0qWLz6quvYtmyZXjttdcMwYpGkAAJJD4Bs29CCDEKbhJ/3jjrG2/cuKG8wC1evBjy45TlypkzJ9q1a6c2JUqXLm3J5icJkAAJkAAJJBgBs6/JKLhJsKngsIaOHDmiDt78/PPP+jtkXTNx4kQ0btyYp511KkyQAAmQAAmYnYDZ11UU3Jh9BhrHfgk1umHDBiW+2bVrl26YHCr74IMP1D7X+++/TwG2ToYJEiABEiCBqAQouKHgJuqcSJJ7owhuLJ2/evUq2rdvj71796osWVxJHPfBgwdzYWWBxE8SSEYEzL4JIUNFwU0ymrCJ1FXxcrNv3z6IG14JzSgbFJarWrVqakOiRYsWELe8vEiABEiABEggIQiYfU1GwU1CzALHtyEHcb7//nsVVurvv//WX1izZk34+PigatWqeh4TJEACJEACJGBWAmZfV1FwY9aZZ2y7L1++DF9fX3Uo+/r167qxefPmhYeHh9rrKlasmJ7PBAmQAAmQAAkIAQpuKLgxxH8JRhPcCBT5IXHmzJkYMmQInj9/rjiJnUuXLkXJkiUNwY1GkAAJJA4Bs29CCCUKbhJnriTXt4jr3TVr1qjTQBH/e0mXLh2aN2+uNiRq1KjBU+HJdYKw3yRAAiSQQAQifsckUJOJ2gwFN4mKO94vCw0NVT+4jBw5Erdv39bbE0Hx2LFj8corr+h5TJAACZAACZCA2QiYfV1FwY3ZZpy57A0LC8Pu3bvVITPxfhMcHKx3QETYEnL0k08+Qfr06fV8JkiABEiABJIvAQpuKLgxxOw3ouDGAubMmTNKvXz48GGV5ebmhgkTJqBnz55wcXGxFOMnCZCAExMw+yaEDA0FN048QQ3WtT///FOdBFq+fDnu37+vWycngDp06KC+U/Ply6fnM0ECJEACJEACthIw+5qMghtbR9pY5fz9/TFt2jQVVurZs2fKOPGC2717d3h5eSF79uzGMpjWkAAJkAAJkIANBMy+rqLgxoZBZpEEIfDw4UN88803Snzzxx9/6G1mzJgRIsTu1KkT5PetFClS6M+YIAESIAESSF4EKLih4MYQM97IghsBJCfbRGTj7e2t0pJXu3Zt9YNiwYIF5ZYXCZCAExMw+yaEDA0FN048QQ3aNTn9s23bNuX1RsIyiOc4uUSsWr9+fXUaqEmTJkidOrVBe0CzSIAESIAEjEbA7GsyCm6MNqPss+fevXsYPXo05s+fr+8LSOjMoUOHok+fPpDDObxIgARIgARIwCwEzL6uouDGLDPNuew8fvy42udauXIlHj16pHeudOnSap+rbdu2yJUrl57PBAmQAAmQQPIgQMENBTeGmOlGF9xYIMmCql27dpDT+3LJ5ppsmkr8TiqYLZT4SQLOR8DsmxAyIhTcON+8NFOPbt26hWXLlqlNifPnz+umy4lw2YwQV7zlypXT85kgARIgARIggegImH1NRsFNdKNqvrxz584pkc369et14/Pnzw/54U/WNSlTptTzmSABEiABEiABoxIw+7qKghujzqzkYdfz58+xadMmtc+1c+dOvHjxQnXc1dUVjRs3VvtcDRo0gNzzIgESIAEScH4CFNxQcGOIWW4WwY3ACgoKgsRw9/Hx0RdSH3zwARYsWED1siFmE40ggYQnYPZNCCFCwU3Czwu2aD8B2YA4cOCAcsO7du1aBAQE6I1UqVJFbUi0bNkSmTNn1vOZIAESIAESIAELAbOvySi4sYykc3wePHgQnp6ekE/LVb58eUyaNAnvv/8+D+VYoPCTBEiABEjAkATMvq6i4MaQ0ypZGnX16lUsXbpURUO4cuWKziB37tzq8LYcMitZsqSezwQJkAAJkIDzEaDghoIbQ8xqMwluLMDkB0PxbHPx4kWVJaf0582bh48//thShJ8kQAJOQsDsmxAyDBTcOMlkdKJuPH36FCK6Wbx4caQfqtKmTYtPPvlEiW9q1aqlQlA5UbfZFRIgARIggXgQMPuajIKbeAy+QauKmHjjxo0YPHgwInrxq1OnjhLeVKxY0aCW0ywSIAESIIHkTsDs6yoKbpL7DDZe/yWU+t69e9Uhs++++04d3LZYWb16dbXP1bx5c2TMmNGSzU8SIAESIAEnIUDBDQU3hpjKZhTcCLhnz55h0KBBmDt3rs6xdevWmDlzJrJkyaLnMUECJGBuAmbfhBD6FNyYew46u/WnT59WJ4Ek7NTdu3f17hYpUgQdOnRA+/btUaBAAT2fCRIgARIggeRJwOxrMgpunHfehoSEYOHChfjyyy9x//591VEJO92mTRsVaqpQoULO23n2jARIgARIwJQEzL6uouDGlNMu2Rj96NEjrF69Wh0y+/333/V+p0+fHp9++ik6deqEt956ix4RdTJMkAAJkIC5CVBwQ8GNIWawWQU3Fng7duxQCuWbN2+qrHz58iklc7169SxF+EkCJGBiAmbfhBD0FNyYeAImI9Plx6offvhBbUhs3boVYWFhqvfyg5WEZhA3vB9++CHSpEmTjKiwqyRAAiRAAhYCZl+TUXBjGUnn/Xzy5AkmT56MKVOm6KEzZd3Su3dvDB06lAdznHfo2TMSIAESMB0Bs6+rKLgx3ZRLtgafOHFCHTJbsWIF/vnnH51DiRIl1D5Xu3btkCdPHj2fCRIgARIgAfMRoOCGghtDzFqzC24EoqiWZRNNFk6Wq1u3bvDx8UGGDBksWfwkARIwIQ
Gzb0IIcgpuTDjxkrnJd+7cwfLly5X45syZMzqNrFmzqtPiIr6pUKGCns8ECZAACZCA8xMw+5qMghvnn6OWHt66dQujRo1S6xgJLyCXeMH18vJCjx49KB62gOInCZAACZBAkhEw+7qKgpskmzp8cRwJBAUFYcuWLWp9uH37dljWiClTpkTDhg2V+KZRo0ZIlSpVHN/AaiRAAiRAAklFgIIbCm6Sau5Feq8zCG4sHZL4nCK0efDggcoqWrQoli5dirfffttShJ8kQAImI2D2TQjBTcGNySYdzdUJvHjxAocOHVIbEuKOV8I5Wq5KlSqpDYlWrVrxxLgFCj9JgARIwIkJmH1NRsGNE09OK13766+/MGTIEIjnPstVuHBhjBs3Di1atICLi4slm58kQAIkQAIkkKgEzL6uouAmUacLX5bABG7cuKF+M1q8eDEuXbqkt54zZ060bdtW7XWVKVNGz2eCBEiABEjA2AQouKHgxhAz1JkENwL07t276Nq1KzZt2qT4SiiMgQMHYvTo0UibNq0hmNMIEiAB2wmYfRNCekrBje3jzZLGJSBim3Xr1inxzb59+3RDJVRDs2bN1IZE7dq1+eOVToYJEiABEnAuAmZfk1Fw41zz0Z7e7N27F56envj999/1aq+//rryiPvuu+/qeUyQAAmQAAmQQGIRMPu6ioKbxJopfI8jCYiXG9nfWrRokdrvCgwM1F9XrVo1tc8lIm13d3c9nwkSIAESIAHjEaDghoIbQ8xKZxPcCFQ5kS+ebfr06QOJ4y5X2bJlVXiMihUrqnv+iwRIwBwEzL4JIZQpuDHHXKOVthM4d+4cfH191T+3b9/WKxYqVAjt27dHhw4dIGleJEACJEACzkPA7GsyCm6cZy7GpSfyg8ratWsxbNgwXL58WW9CQgdMmDAB5cqV0/OYIAESIAESIAFHEzD7uoqCG0fPELaf2AT8/PywZs0adcgs4n+fbm5uaN68OTp16oQaNWpADnfzIgESIAESMBYBCm4ouDHEjHRGwY0F7LVr19SPfrt371ZZrq6uGDlyJIYOHQpJ8yIBEjA+gYh/5Bjf2ugtpOAmei7MNT+B0NBQSOxrccO7efNmyL1csgHx3nvvqdNATZs2pYc58w81e0ACJEACMPuajIIbTmIhEBQUhLlz50J+KHz48KGCIqGlRCzs7e2NfPnyERQJkAAJkAAJOJyA2ddVFNw4fIrwBUlIQMKSLlmyBMuWLcP9+/d1S4oVK6bWjB4eHlwz6lSYIAESIIGkJ0DBDQU3ST8LNQucWXAjgOUk2+zZszF48GBY3AJWqVJFLZhKlSpliDGgESRAAtYJmH0TQnpGwY318eUT5yFw7949rFixQrni/fvvv/WOZc6cGa1bt1biG/Eyx9NAOhomSIAESMBUBMy+JqPgxlTTzeHGPn78GOPHj4fMCxHhyCUnmPv3749BgwYxdIDDR4AvIAESIIHkTcDs6yoKbpL3/E0uvQ8ODsa2bdvUIbPvv/9e/c4kfRexdv369dU+V5MmTZA6derkgoT9JAESIAFDEqDghoIbQ0xMZxfcWCBL+It27drpJzPTpk2rXEf36tVLLZIs5fhJAiRgLAJm34QQmhTcGGtO0RrHEpCwjkeOHFEbEqtWrdJDO8pbK1SooDYkRICTLVs2xxrC1kmABEiABBKUgNnXZBTcJOh0cJrGxCvuiBEjVPhpWcPIlSNHDowaNQpdunRBqlSpnKav7AgJkAAJkIBxCJh9XUXBjXHmEi1JHAK3bt1SB7jFw/P58+f1l2bPnh1t2rRRe12vvvqqns8ECZAACZBA4hGg4IaCm8SbbTG8KbkIbgSBhLqYNGkSvvzyS4SEhCgq8kO4uAgsXLiwuue/SIAEjEXA7JsQQpOCG2PNKVqTeAQCAgLw3XffKfHN3r179RfL6Z8PP/xQxcCuU6cOUqZMqT9jggRIgARIwJgEzL4mo+DGmPPKKFb98ccfyrPNzp07dZOKFy+uvOA0a9aMHvp0KkyQAAmQAAkkBAGzr6souEmIWcA2zEhABNoHDhxQ3p3Xrl0L2feyXJUrV1b7XC1btoR4e+ZFAiRAAiSQOAQouKHgJnFmWixvSU6CGwuKEydOoG3btjh16pTKypgxI6ZPn65icDLUhYUSP0nAGATMvgkhFCm4McZcohVJS+DixYvw9fVV/9y4cUM3Jn/+/Gjfvr36Di5atKiezwQJkAAJkICxCJh9TUbBjbHmk1Gt2bFjhxLeyJ6B5XrzzTfh4+OD6tWrW7L4SQIkQAIkQALxImD2dRUFN/EaflZ2EgJPnz6FiG7E683Bgwf1XklkhY8//liJb2rVqsXoCjoZJkiABEjAMQQouKHgxjEzy85Wk6PgRhBJnHbxdCMeb8LDwxW1xo0bY+HChcidO7edFFmcBEjAUQTMvgkhXCi4cdTsYLtmJBAWFgY5PS4bEhs3btQ9zklf3n33XeWGV06Sp0uXzozdo80kQAIk4LQEzL4mo+DGaadmgndM1iorV66El5cXrl+/rrfftGlT5fGmZMmSeh4TJEACJEACJBAXAmZfV1FwE5dRZx1nJnDmzBkVRWHp0qW4e/eu3tUiRYqoA2YeHh4oWLCgns8ECZAACZBAwhGg4IaCm4SbTfFoKbkKbizIRH0sC54LFy6orGzZsmHu3Llo3ry5pQg/SYAEkpCA2TchBB0FN0k4gfhqQxN48OCB+kFr0aJFutc5Mdjd3R2fffaZOg0kLnnpfc7Qw0jjSIAEkgkBs6/JKLhJJhM1AbsZGBiIGTNmYNy4cXjy5IlqWcJgdunSBaNGjUKuXLkS8G1sigRIgARIIDkRMPu6ioKb5DRb2Vd7CISEhOCHH35Qh8y2bt0KEXLLJftadevWVftcEmI9TZo09jTLsiRAAiRAAjEQoOCGgpsYpkfiPUrughsh7e/vj8GDB2P27Nk6ePmhb9asWciaNauexwQJkEDiEzD7JoQQo+Am8ecN32guAhID+9ixY2pDQk6U+/n56R0oV66c8nrTpk0b5MiRQ89nggRIgARIIHEJmH1NRsFN4s4XZ3qbCITHjh2r9gvkRxS5MmTIAE9PTwwYMADp06d3pu6yLyRAAiRAAolAwOzrKgpuEmGS8BWmJ3Dnzh0sX75c7XWJBxzLJb83tW7dWu11vfbaa5ZsfpIACZAACcSRAAU3FNzEceokbDUKbv7jKSEuOnTogJs3b6rMPHnyQE7dN2jQ4L9CTJEACSQI2ZFoAABAAElEQVQqAbNvQggsCm4SdcrwZSYnIKfJN2zYoDYkdu3apfcmVapUaNKkiToN9P7778PV1VV/FlMiNDRUnSSSE+m8SIAESIAE4k7A7GsyCm7iPvas+S+BS5cuYfjw4Vi9erWORMJRjx49Wu0j2Lo20SszQQIkQAIkkGwJmH1dRcFNsp267HgcCMghs0OHDql9LllHPnv2TG+lYsWKap+rVatWyJIli54fWyI4OBipU6eOrRifkwAJkECyIEDBDQU3hpjoFNxEHobHjx+jd+/eSn1seSIuo6dMmaJOsVny+EkCJJA4BMy+C
SGUKLhJnLnCtzgfgcuXL0PiXy9ZsgTXrl3TO5g3b14VDlJEssWLF9fzo0vIxka/fv0wbNgw5MyZM7oizCMBEiABErCBgNnXZBTc2DDILGITgSNHjijvNj///LNevnTp0pg4cSIaN27MUJg6FSZIgARIgASsETD7uoqCG2sjy3wSiJmAiG3WrVunxDf79u3TC0uIqaZNmyqvN++99x5cXFz0Z9El5s2bB1l/1qpVK7rHzCMBEiCBZEWAghsKbgwx4Sm4iX4Y5HR9165dcf/+fVWgSJEi6ke/GjVqRF+BuSRAAg4hYPZNCIFCwY1DpgYbTUYEwsPDId5uFi9erLzfBAUF6b2vWbOm2pD45JNPrIZ0mDNnDnr06IG9e/dyM0InxwQJkAAJ2EfA7GsyCm7sG2+WjpmACHq3bduGQYMG4fTp03ph+dHDx8cHVapU0fOYIAESIAESIIGoBMy+rqLgJuqI8p4E7Cdw7tw5+Pr6qn9u376tN1CwYEHlPbF9+/YoXLiwnh8xIeGqJDpDu3bt1CG12AQ6EesyTQIkQALORoCCGwpuDDGnKbixPgz37t1TopuNGzeqQilSpED//v0xZswYpE2b1npFPiEBEkgwAmbfhBAQFNwk2HRgQySAhw8fYtWqVSrk4/Hjx3UiGTJkQMuWLZX4plq1apFOl0uYqnTp0qmy3t7e8PLyivW0kN4wEyRAAiRAAoqA2ddkFNxwIjuCgISuFE98I0eOhPzwYblatGiBsWPH4pVXXrFk8ZMESIAESIAEdAJmX1dRcKMPJRMkEG8Csp7cvn27OmS2efNmyL3lEm83HTt2VN5v3NzcLNnqs1KlSpB9MTkg/t133yFHjhyRnvOGBEiABJILAQpuKLgxxFyn4CbmYZCTa8uXL0evXr3w5MkTVbhMmTJYtmwZXn/99Zgr8ykJkEC8CZh9E0IAUHAT72nABkggWgKysSA/cq1YsQKPHj3Sy4hbXdmQaNu2LXLlyqXyhw8frkI9pEqVClWrVlUufLkZoSNjggRIgARiJWD2NRkFN7EOMQvEg4C/vz+mTp2KSZMmQUIFyCVrju7du2PEiBHIli1bPFpnVRIgARIgAWcjYPZ1FQU3zjYj2R+jEJAD4LLHtWjRIvz999+6WZkzZ0arVq3UXpcIbeRg+MmTJ1GhQgW15syYMSPk0DijM+jImCABEkhGBCi4oeDGENOdghvbhuH69etqQfPTTz+pCq6urmrjbOjQoWpRY1srLEUCJGAvAbNvQkh/Kbixd9RZngTsI/D8+XNs2rRJnQbauXMnRCwrl3xXN2rUCJ06dULFihUh4SHlpFDq1KkhHnG4GWEfZ5YmARJI3gTMviaj4CZ5z9/E6v3du3cxevRozJ8/H2FhYeq1mTJlguwb9O7dG1FPJieWXXwPCZAACZCAsQiYfV1FwY2x5hOtcT4Csq915MgRtc8lXp4tB8Glp+XLl1f7XK1bt0aTJk1w6NAhtQ8mYm/xujhs2DB6dXa+KcEekQAJxECAghsKbmKYHon3iIIb21mHh4dj7ty58PT0hISnkKty5crK242cpudFAiSQ8ATMvgkhRCi4Sfh5wRZJwBqBa9euYenSpWpT4sqVK3qx3LlzI0uWLJAY2ZYfwGQzYtSoUepHMMa71lExQQIkQALREjD7moyCm2iHlZkOInD27Fn1Y8f69ev1N+TPn1+Fp27Tpg1Spkyp5zNBAiRAAiSQ/AiYfV1FwU3ym7PscdIRCAgIUCGjFi9ejL179+qGyGEy8eB8+PBhBAcHq/y0adNCfu9bt24dsmfPrpdlggRIgAScmQAFNxTcGGJ+U3Bj/zCcP38eHh4e+PXXX1XlNGnSYPz48ejTpw/Vw/bjZA0SiJGA2TchpHMU3MQ4xHxIAg4hICJZ2YiQDQmJZS1ecKK7ZDOiWrVq+Pbbb7kZER0g5pEACZDA/xMw+5qMghtO5aQgcODAAXVgx7J3IDbIqWQJPVWvXr2kMInvJAESIAESMAABs6+rKLgxwCSiCcmSwMWLF+Hr66v+uXHjRrQMRIhjCTH19ttvR1uGmSRAAiTgTAQouKHgxhDzmYKbuA2DnI738fFRbvpCQkJUI7Vq1cKSJUtUyIq4tcpaJEACUQmYfRNC+kPBTdRR5T0JJC6BU6dOoW/fvtizZ48ebiqiBeLpxt3dXYWlql69esRHTJMACZAACfw/AbOvySi44VROKgISEmDDhg0YMmQI5PCO5apbt64S3rz22muWLH6SAAmQAAkkEwJmX1dRcJNMJiq7aVgCz549w/Dhw/H1119DPOBEd8le15dffqnWoPTqHB0h5pEACTgLAQpuKLgxxFym4CZ+w3Dy5Em0bdsW8ilXhgwZMG3aNBVHM0WKFPFrnLVJgARg9k0IGUIKbjiRSSBpCIhb3XHjxmHr1q1wdXVFUFBQjIZIeIdBgwapcA/cjIgRFR+SAAkkQwJmX5NRcJMMJ63BuiwHdRYsWABvb2/cv39fWSd7BhJiasyYMShYsKDBLKY5JEACJEACjiJg9nUVBTeOmhlslwRiJnDnzh3MnDkTM2bMgHh2tia2sbQi+1wi7v7xxx/p1dkChZ8kQAJOR4CCGwpuDDGpKbiJ/zBIjEzZNJswYYJa6EiLjRo1wsKFC5EnT574v4AtkEAyJmD2TQgZOgpu7J/At27dws8//2x/RdZI9gTkFPnRo0exZcsW/RS5iG1EQCM/dMnz2K5SpUqhX79+ygVvbGX5nARIgASsEWjWrBkk9KyzXGZfk1Fw4ywz0fz9ePLkifKWO2XKFAQGBqoOyf8rJET10KFDkTlzZvN3kj0gARIgARKIkYDZ11UU3MQ4vJEeiqfd5cuXgwdzI2HhjZ0E/Pz81IHviN4SZa9LojDYss8lZSWcaa5cuex8M4uTAAkkZwLy/xc3NzfMnj3b0BgouKHgxhATlIKbhBsGicvu4eGh/8CXNWtWzJkzBy1atEi4l7AlEkhmBMy+CSHDRcGN/ZNWTl40aNAAJSu/ZX9l1kj2BOSUT3hYKEI1gY38Ex6qfco/WjpE83ITFhKspYO1T+1ZeJhVXvlLlkX6jJmsPucDEiABEoiOgPz/5/zxw3jw4AGyZcsWXRFT5pl9TUbBjSmnnVMbffPmTXypuflfvHixfnBH9hC8vLzQvXt3pxLsOfVAsnMkQAIkEAcCZl9XUXBj+6DPnTsXvXv1wKBOH9leiSVJIBoC4Zq4Jlj2uUL/3e8KCdE+w0Igh8FDgkO1Z9pel3oWCmtHzXLnzoV8efMBDMwQDWFmkQAJRCVw+cZdrPr+oE3Cvqh1E/OeghsKbhJzvll9FwU3VtHE6YG/v7+Kizlr1iy9vghuRAHoTBvueueYIAEHEzD7JoTgoeDG/kkigpsufQfif6t22F+ZNUjATgKBAf548vAB/B7cx/2b1xAU6I8LJ39HBvesePeTNshT+BU7W2RxEiCB5EwgNDQMn1ctSMGNwSYBBTcGGxCaoxP4888/1R7Ctm3b9LzChQursJiyl8AwlzoWJkiABEjAaQiYfa+Lghvbp6IIbr5fvRBLxve0
gTubNmQsoUSXFPCHL+CrmI1dsO4I/g83gpdp6k6OYzcZpr4oBWLo90o3X/7B1lCm7sJefaehTcuJYvrZMACdhOoFChQggJCcHKlSvRsGFD2w2wBgmQgM8RoOBGG0Nu6/d3Cm60MW70wncJGAQtExZuQraA11DunUIoVSwf8ufKgtfSpkSkiAZ8RqTg3rrnKH787TCeRj5TsPJkz4Tvx3RDgdxZXApP6/65tPM0rnkCFNxod4gouNHu2NAzEvBVAlJsI0U39erVw5o1a3wVA/tNAj5JgIIbCm40MfFt3bBz1GkKbhwlqI/64RFR2Lb3KArny4F8Ihe5pUumdho9ex3mrN6OFy9fIU3KZFgwqjMqlHTtSS6t+2eJV3zPKbiJj5Bn3lNw4xnubJUESMAygbx58+LcuXP44YcfULduXcsF+YYESIAE/iVAwY02poKt398puNHGuNEL3yaw70gIXrx4gTLvFETiRIkswvhp9xH0HLNAHFB6qg4kdW1cDV90qoeECRNarOOMF1r3zxl9pA19EqDgRrvjRsGNdseGnpGArxKQ6aRkWqlGjRph+fLlvoqB/SYBnyRAwQ0FN5qY+LZu2DnqNAU3jhL0vvpXb95B496TcfLCVbWp1L9NbQS1quXyTSVrSWrdP+N+UHBjTEM7nym40c5Y0BMSIIF/COTKlQuhoaHYuHEjPvnkE2IhARIggXgJUHATLyK3FLD1+zsFN24ZFjZCAk4h8PLFS/SfsAQLN+xS9t4XkXCWfN0NadOkdIp9R41o3T9H+8f62iNAwY32xsTgEQU3BhL8SQIkoBUCffv2xfjx49GsWTMsXrxYK27RDxIgATcQoOCGghs3TLP4m7B1wy5+i3GX8BbBzb0Hj3Di3FXcuvtAiEQSIHOGdCicNztSpUweNwCjt1HPnuO0EJlcuBqGiMjnSJcmBQrlyYpsmTMYlbL9o/Tt2OlLuHXvEZIl80MBEWEmT/YAzQhYYvbouTjp1f6Lmfhx91H1qnXdD/B1r6aa8Vfr/hnzpODGmIZ2PlNwo52xoCckQAL/EMiaNSuuXbuGLVu2oFq1asRCAiRAAvESoOAmXkRuKWDr93cKbtwyLD7diIxae/rCdYRevwW5x5EudQq1N5I5Y3qbuFwS9UPOX8WDR+FIljQJ8mQLQL5cr8MvSWKb7BgXlumaTpy7jEvXbuGVSGOdLcAfRfLnEPb9jItp6vPyzXsR9PVCFQG4sNgfWvltEAJsZOnKDmndP1f2nbbdT4CCG/czt7ZFCm6sJcVyJEAC7iIQFBSESZMmoXXr1pg/f767mmU7JEACGiBAwQ0FNxqYhoCtG3aOOq1nwY0UXmz+7Qjmrt2Bw8fP4dnzlyY4UohNoYoiFVKHhpVRpnhBi4KRm7fvY+aKn7F62wGE3X1oYiOhEO8UzZcdgZ9XRe3KJS1uLl0UIp1uo+YLUc1D5BVimokDWuFJeCTGzdtgkvdbGpc23xKbSgM71MWH7xc1aU/e7D8agqFTVuJxRCTy58iMCf1bIuNraWOVi/lgx4G/MWLmWkQ8e4bCb2THvJGdYhax6l5ybTt4Jn4SOcuFq+jWpDoGd/wsmp/MKT5n9S9YvGE3XopNspxZMmJsn6bI8XpGi/ZlSOQvp6/Go6fhSJUsKQYFfma27xYNGL2Izz+joh7/SMGNx4fArAMU3JjFwockQAIeJBAQEICwsDBs374dH330kQc9YdMkQAJ6IUDBjTZGytbv7xTcaGPcvNGLYHHIZ+aKbdi27y88eBxu0sVECROgiDiQ1PSTCvi8ehmkSJ7U5L3hRgpifth+EHNW/SKEMVfU933DO/lTHmxqVL0sAsUeSwYLexRS5DN40jLsEXsAyf388FW3Bnj3zTcwf92v+E7YvS72X4yvAP+0as+mff0qSC4OKBlfd8T+Sq9vvsep0GtIKfYRvuraAOXefdO4iNnPN2/dQ4+vF+HitTCkTp4MY4KaoETRvGbLxvdw6aY96D12kRLcFC+UC8vGdjfp+x/HzuKLKStw//FTpEiaFP3afopq5YtbNCsPifUfv0REFL6m9lvqfFgSPVvWsrjXZNHQvy/i8y+++nxPArYQoODGFlruLUvBjXt5szUSIIH4CXTt2hXTp09H+/btMWfOnPgrsAQJkIDXEKDghoIbTUxmWzfsHHVar4KbsDv3MXDicmzedVhtfMTFIZXYTJozPBBVyhSLVezw8fNCKDMPZ0JvxHpn/EBuUNWv+j5G9WhsNnzwKbFZUr/nBNwQm0cFxKmv/u3qYPiMNWKD55axGZPPaVImw4R+LVDno1Imz2XfmvSZjL/EhpkUDc38sgNqVHjHpEzMGylC6SVOXS3/ab/atOnetIbKLR6znDX3oUI81Ei0f/bSDSTzS4xpX7RTYiPjutLH9kNnY/+fp1V7DQSbcX1bmN24u3LjNloPmoE/T4UqsVHLOhUxontjJBW27bms8c8eu66oQ8GNK6g6bpOCG8cZ0gIJkIBzCfj7++Pu3bv47bffULFiRecapzUSIAGvJEDBjTaG1dbv7xTcaGPcvMkLKXCR3zvHz98gDrhExtk1eaCmXb3KGC0EKDEvKW6Reywbfz0U7x6LjPQybUhbEZ0mZ0wzkKKdtl9MF8Kfv4XgJjG+6dMcW/f8iS17j4qoNrGKqwdyv6W5EAMN797IRHQjD/sMm7YKs1dtV3Vb1a6Ib3o3Q8JECc0b+vfphh2/o+vIeYgQvpQonBtLvukO//Rp4qxj7qVk20eIbeQ+i7ya1CyrDkQlTpQourhM6/SNOGg1ZclmPH/xCm/kCMCi0V1RIHeW6DKGD5LN8Bmr1XjJw0tF8+XAwtGdkUMcYrLnssY/e+yyDglYIkDBjSUynn9OwY3nx4AekAAJmBLo2LEjZs+ejc6dOyvhjelb3pEACXgzAQpuKLjRxPy2dcPOUaf1KLiRKZo6j5iPX0REF3nJiDHFCuZETSFKKSg2fpILkcrNOw9V1BtZ5paIWmNOtHL64jUlBDkdel3ZSZ0iKT6tVBKVShURopoUuHA5DGvF6a5Dx86pk12ynRafVsCIHo1ihTw2FtykF3WTipNcYXceqEg2dauUUiIcubl05MR5LP1xL66JE1fyyi/SSy0b1x05s2ZS9/Ifuak0Zs56TBYbNrJOI3ECbeLAVjDe1Iku/O+Hs6IPjYVIRgp8ZPuLRV7xUsXyxywW7/2Dh0/U6axVWw+oPpd9pwAWiEg56dOmjlVXipXaDJ6h+uKXOBGGdK6PDg0+io6EIyuER0Sh3/jFWLl1v+rLe0XFybaRnREgTsbZc9ninz32nV2HghtnE3WOPQpunMORVkiABJxHIE2aNHj06BH27t2LsmXLOs8wLZEACXgtAQputDG0tn5/p+BGG+PmLV5Isce05VvwzXcbEPX8heqWjEJTu1IJlBTfvdOnTalEOCFnr+CXg8fwZ8gFNK1ZHuNFFF3j66mIztt/wpLo7+1SAFP67QKoI6L85syaEfcfPFb7Lz+KCMNPxHd8ecmIOQuEWCSX0V6GfG4suEmSOKESk8i9lYyvpUE9cdio5
Ft5kSZVMly8ehsrt+zDH8Hn1d6DPGw0eVDrWAeS9h4+iZYDp+HhkwglZlk1IShOgYpsv5sQ26wTohu5h9OvzacIalXLZJ9C+hnfJdku/2kfhkxZrhj6p0ul9kZKi+jJMS+5T9Fx+NzoPapq5YqJg0ttkTZ1SpOisr8yuo1kKPdtZg3rYHfkX1v8M3GCNyTgAAEKbhyA5+KqFNy4GDDNkwAJ2Eygbdu2KpVU9+7dMXnyZJvrswIJkIB+CVBwQ8GNJmavrRt2jjqtN8GNFKOMnbsBkxZvVqeuZPSavm1ro139ymZD8MrNlq17juANkebJ+PSVfC5PKq3Y8s9JpdfFptS0IW1QoURhE6Qy//n4eRsxfcVWdVpJtjdlcBt8IjawjC9jwY18LjeWAhtUQT8R6SZmWGQZ6rm1EKpIcUziRAkwUkR7aSv8N75+//sMmvWfgnsPnyKXOO20fHwP5BXiHEvXovW/YcC3S5SPlUQarfmjOiFVyuSWiquc6SEXrioRjMydLkNOnxQho7eINFLnr4SpesVE2qupYpOo0BvZLNpZvHEXBk9chnDBM4PYgPpueEcR4rmQKi/HSgpOZKSfSHEyTG78zRfinZJWhHKWOd2d4Z9Fx930goIbN4G2sRkKbmwExuIkQAJOJXDx4kUkTpwY2bL99/s1RYoUCA8Px8GDB6PTi6rfzw8eIF06+0SqTnWaxkiABDRHgIIbbQyJrd/fKbjRxrh5ixdSjNLmixlq30CKS2pWKI4xvZpaPOByRByauXzjTqwItmt/PoigMQvU93p5mKaXSHHUrXmNWHss2/b+iSARWffWvUcq0m3LTytilIiW45fkv+i1xoIbA+f3i+XDNLGPYnzQSL57JNIw9RSpnzb++ocqKoUqc77qaLKH8uDREyG4mY59R09B+vbtgJYiLZZlcXLI+ato1HsirobdQ8b0qbF0XA8UL5Tb4Eqsn9KHwycvIDLymXoXIX7KlOG/HjqO//11Wu2xpEudHMO7NVLpuBImNB9dJ+T8FbQcNB3nhbhI7vP0bF4TfdvUjo7G83dIKFoNnq74y/2i/m3roFuz6vEKgZzlX6yO8wEJ2EGAghs7oLmpCgU3bgLNZkiABMwSiIiIQHBwMEqU+O9vZi1btsT333+PXr16YcKECdH17t+/j7Rp04q1pAi9yIsESMArCVBwQ8GNJia2rRt2jjqtN8GNjErTpO9khF67rTYx+rSqjaCWNaM3MazlcfSEONnVb7LaKJJhjseJ1E6WNm2ePI1AZ3Fa6SchRpHXR6XfwrwRHU3SJ8UU3NQQObtnDG2HlCmSxXJJClEmLvwRY+dvVCe56lZ+DzOGtTOJYPP4SbgIwzwTO38/rvo5JqgpWtWtFMuWfCBPo7UdMkudprIk4IlZcdbKnzFs6irVfsx3ebJnQpMa5SBTP6VLkyrma5N7KUgaMnkFvt+4W9l6u0BOdcotW+YMOHA0BO1E2qkwEWFIpqYaIcJDW+qDiVFx4yz/Ytp19z0FN+4mbl17FNxYx4mlSIAEXEMgMDBQ5a8eOXIkmjdvjhw5ciBJkiR4/vw5Dh8+jCJFimDPnj2oUaMGoqKihDjWQv4F17hHqyRAAjohQMGNNgbK1u/vFNxoY9y8wQv5XVymTdqw8x+xSvl3C6p9CnPRaePqrxR0yANBu/44qYo1/Ph9lTYpWVI/s9XmiUM1Q6euVBF15KGa5ULQUkQc1jFcMQU32QJeUxF4jcsYysqfUtTSfMBUJRrKmSUD1kzqHStqzozl21QqphcvX0HutcwSabdjHmwy2JyzSux1iDRUMr2TOQGPoZzhZ/DpUDQI+ha37z82PIr+mTZVctSq+C4CP68S50EkQ4UNOw4JQdICFRFHphCXB5hkenCZrqv9sNnYczhEFa39YQlMGtAqzkNSBpvO9M9gkz9JwF4CFNzYS8719Si4cT1jtkACJGCZwOnTp1GgQAHUq1dPCWzef/99NGvWDMuXL0f//v0xZswYhISEYOrUqZg5cyauX7+OzJkzWzbINyRAAromQMENBTeamMC2btg56rTeBDfz1+7E4MnL1OaJzHe9XKRjCsiY3mYM05ZuwYiZa5VIpGSRPCKtU484xSW/7P8bHYbNUhsnAf5psWJ8T5NNJWPBjRTwzBChgWt98K5Fv3aKcM7thsxU9kqJkMpLRU7xtGlMww0b97VyqaKYN7KjWQGPPKXWRIiH7ogNopyv+2OZ8C1/rtj5wo2diUvQIkM5Vyj5Jro3q2FVNJqbIj1Wa3Gq7pAIBS2FyZ9XK4OeQgTVRYiUDgthk3zWtGY5ddLO0qadsW/yszP9i2nbnfcU3LiTtvVtUXBjPSuWJAEScD6BEydOoHDhwkiWLJkS0+TOnVttPMiWSpYsiWPHjikBjkwxNXr0aAwcOND5TtAiCZCA7glQcKONIbT1+zsFN9oYN2/wQkaobdhrIm7cvg8ZiVeKUD4u97bNXZOHkZr0naQEJzKSy6IxXVHGTNokg+GwO/fRRKSz/ktE7pWpp4Z3+xwdGlYxvDZJKSUftvikAsb1bW7xkNStuw+UvT9PhUK2v1jsjbwfIz32ibOXRdSaSbgu+vrPfowU+eSMbtPwQYqHZJQZKWyRUWTG922BJrXKG16b/RmXoEX2r1CebCqicv2PS4vU4f9F8jFn7PmLF/h69jqR5murisicL2dmIYLqhBWb92H26u3qWcHcWfD9112RO1uAOROxnjnTv1jG+YAEbCRAwY2NwNxYnIIbN8JmUyRAAmYJZMqUCbdv30bKlCnh5+enIjuHhYUhX758ePjwofovMjISspwU3PAiARLwXgIU3FBwo4nZbeuGnaNO60lwI3NUdxs1D6u2HVQijg71P8Lw7p/HG4I3JiO5CdLpy++wfuchZadbk+oY3PGzOO1IUcnnYjPruAhPLMMYy+g1tUVkGsNlLLjJIgRAayb3Rr44UkDJTRO5OSZDMb8lToPJlFGZ/E1TRpwJvY7GfSapaD4yX/jSsT3wbuE8hibVTxktZ7yIlDNBRMx5KU7Af17tfRFiubVJSGeTCv/ehIoQycfFBp08M/9S8JBinb/FBtfPQlh0884DVSp1iqQY2rkBWtSuGCcbWVimwJJRduRmn+STXZxMO3/5pkpZ9e6buYVYqBOyBvgru9b842z/rGnTFWUouHEFVcdtUnDjOENaIAEScIxAnjx5cOHChTiNJE2aFJcuXVKbEXEW5EsSIAGfJEDBjTaG3dbv7xTcaGPcvMGLH7b/D91EhJuo5y8s7ilY08/VW/ej++j56lBT8UK5sGxsd2R4La3FqnIPYqBIKz3/h19VGbkHMW1I++jyxhFupGBlQr+WaPqJZdGLjNjbWohkZHRfKRySaao/KvNWtD35QdrsOmKu2sORqbMGtq+rDvmYFBI3B0W0nBb/RsuRkXtXTQiKlcYqZh0p0vn92FlEiDTYr0QEnSfhETgbehM7/3cMwWcuq30W2Y86Yv/nG5GuK+ZBqZj27j14hI5iv0n2R165s2ZEmNhjeRIRpQRF04e0Q9Wy1gujnO1fTH95TwK2
EKDgxhZa7i1LwY17ebM1EiCB2ATk/w/JSDZSVGPpkmKcb7/9Fh06dLBUhM9JgAS8gAAFNxTcaGIa27ph56jTehLcyI2G5gOnYd+RUyrN0rf9W6GxiJxi6yVzgLcYMA37/zyt7Ewa2NpiOimDbRmuufXgmSptk4zYMqxTfXRpWt3wGsaCG3mKafXEXnEKTIzLF86TFSu/DYoVqUcKg3qOXoCVWw9Abir1ElFj+rYVOcCNcobflqfB+k3B0ZMXISPTTB/aPs7IOtEOW/ggT5cNn7EGq7cdUKev0qdJgTliw+uDkoUt1Pjv8QKx4TZ06gpEiM0ww5XptTRiwywwzhNyhrLW/HTEP2vsO7sMBTfOJuocexTcOIcjrZAACdhPQEauGT58eJwbEVWrVsW2bdvsb4Q1SYAEvJoABTfaGF5bv79TcKONcfMGLyYu+hFjvlunDrnUqlAc34lDLokTJbK5axPmb8I389crO59WKoHZX3WI146MxitFN/LQT4UShbB2cp/odo0FNzK19Mx4ov8al0+ZzE+0H2g2Uo8UGHUfNR+RQhjzXtE3xIGk7iZRiqUQaNSsHzB12RbVl/gi60Q7bOFDlGhn6aY9GDV7LR48DlfRfLqKw1qDOtS1GK3HYEqKEloJEdHFa7cMj9TeU59Wn4qU6LXirR9dKY4PjvgXh1m+IoE4CVBwEycej76k4Maj+Nk4CZCAIHDu3DmVIj0iIsIiDxnp+ebNm0iTJo3FMnxBAiSgfwIU3FBwo4lZbOuGnaNO60lwYxxqWKZtmiU2YmQ+bFsve+zI6DqdxYmqtWKTRwpuglrUwkCx0WK4jAU0MkSwFNxkjiPVlXF5S4IbaXvz7iPo/OUcPI18hrcL5lIptIxPm8lUV+2HzsJjcSrMUqQcg4/W/pSCpEBxImuHSHslL7l5J1NkWcqRbrAbLk5sdf5qDn7cfVQ9kiGchwhhUqAIL20sEjKUt/envf7Z254j9Si4cYSe6+pScOM6trRMAiRgHYHQ0FCV39rSyZ/UqVNj5cqVqF79P3GvdZZZigRIwFcIUHCjjZG29fs7BTfaGDe9eyHFJfKgzIwV2/4Tl/QTaZuMDudY08eYdlp+WgFjZfqneOxI8UvXkXPx7PlLlBBReLfMGRzdnLGAxpp9G+Py8hCRPPBjLjXW1Zt3VFqpkAvXIKPxzh/dxeRg0E0RbbdJ38n4W6S6ku/nfBU7Uk60k1Z+kHxmrdwu0pGvVhGAsgW8JqIT90RBcWgqvmvR+t+EKGmpYiTLVhVRe2Sk5LSpTVOJx2cnrveO+BeXXb4jAUsEKLixRMbzzym48fwY0AMSIAGgUKFC0SnTY/JIIP6o1rBhQ6xYsSLmK96TAAl4GQEKbii40cSUtnXDzlGn9Sa4aSzydstc4XLjZraInFK9vH2CG4MduaEzS5zgis9OzHRWQc1rYmDgZ9H4jQU0zhTcGOdHj3naS/rUd9xifL9ptxIBdW9aA4MCxWmreDbHop2O48P6X+QG2j+nx3KK9FBSQBRfjvFDIgxzmy9mqrRS0rQUJjWpUQ5f926KZEn94mjN9lf2+Gd7K47XoODGcYausEDBjSuo0iYJkICtBIoWLYrg4GCz1dKnT49bt24hkR0n1c0a5EMSIAGvI0DBjTaG1Nbv7xTcaGPc9O6FFFp8NW01Zq76WQlurBXKxOx3TDutRDrpb/o0i3dPYcOO39F5+FyVzspdghvp67BpqzB71XbV57afVcLonk2io8X8+NthdBn+nTqsVKJwbiz5pjv80zt+evqSiFLToNe3ImV22L8RktuICMllYqI0ub//8DE6DhOHmH7/b50n94kWjumCN7JnNinr6I09/jnaJuv7LgEKbrQ79hTcaHds6BkJ+BIB+V1n8ODBCA8Pj9VtmU5KRnEuW7ZsrHd8QAIk4F0EKLih4EYTM9rWDTtHndaT4ObBwydoJvJxH/zrjNromDKoDRpUi3ujwxwfaadJv8kiT/e5f+20FXZKmysa/UyG6233xQxs2fuXEpIM7vAZerSoGf3eVYIbuak0Zs56TF6yWW0qNatVDuP6tRB+J0Lo1TA07D1RbfzI1E+Lv+6GUsXyR/vkyIejJy6gUZ+JuPvgCV5LmxLLxvXEu+LkmqVLniZrI/hIpsaXDCE9onsjtKpbyfixw59t9c/hBu00QMGNneBcXI2CGxcDpnkSIAGrCEyZMgUDBgyItRHh5+eHoKAgfP3111bZYSESIAHfJEDBjTbG3dbv7xTcaGPc9O6F3CeYuPBHkQpqg9onqP1hCcz6Mv5UUDH7Le2Mn78R4xduUnbqVn4Ps4a1jxaxxCxvuF+8cZc6/PPi5SuUfacA1k/tZ3gF44g1zoxwIxvYe/gkWoo04w+fRCB/ztdFau6eyJY5g4g+Y5qOu18bkbqplUjd5ITDSI+fhKO5SEm+90iI2gsa1P4z9BTpvi1d0pevxR7ONJHaSvIxvuQ4TRrQCqlSJjd+7NBnW/1zqDFW9nkCFNxodwpQcKPdsaFnJOBLBK5fv47cuXObTZ+eJUsWXLlyRaynxCltXiRAAl5NgIIbCm40McFt3bBz1Gk9CW7kxkUnkepo/c5DaqPD3oguUjzTfsgs/LTnqLLTt/Wn6N36kzg3Y4zTUCVOlABTBpmKdFwluJHjKyPHNOs/RYlf8mTLhBUTeqpoM8s370XvsYtUiOJKJQtj/qhOTtu4kYKWxn0n4c79x5BinqVje6Bk0bxmp5vcUBsyZTkWrd+lcrgXy59D5HF/U508i3r+ApkzpMN8kU/eUn2zRuN5aIt/8Zhy6WsKblyK127jFNzYjY4VSYAEnEhA5q3OkSMHoqKiTKwmTZoUJ0+eVJsUJi94QwIkQAJGBCi4MYLhwY+2fn+n4MaDg+VlTa/csh89x8xXqY4cieiydNMe9Bm3SNkp9VZeLBWRYdKmsZz2SIp0jNNZSZHOHBF92HC5UnAj00u3HDgd+46eQtIkiTFRiFfk4akzodfRuM8khF67jUyvpcGSsd1RvFBug0sO/ZSClhZC5LPn8D+Cm4Ht6ioxjyWjm349hB6jF+DR00ikTZUcnT6visUiKvHVsHuQabf7t62Dbs2qx7n/ZMm2uee2+mfOBp+RgLUEKLixlpT7y1Fw437mbJEESMA8gZIlS+KPP/4weZk8eXIMGzYM/fv3N3nOGxIgAe8k8H8AAAD//wP0kZUAAEAASURBVOydCXxM1/vGH0EkQmKNndqqqqW21q5FS1UtbflZY18rtoqd1L7EWluILfa1VaooRYtSey21L0HsBBESkvjfc/wznSQzydxkZnLv5Lmfj85dznnPe77n6FxnnvO+aV4rB1Lx8ffff+u694cPH4a3t7fsw4EDB5AmTRpd9uejjz6yq981K76L5p9XRdN6le3ablIbm7NqO0bOXodo5a9rhVJFsNKvF7J6ZFZlLjo6GlMWb4bf4k0Qf+s/qVgKi8Z2RyY3V7N2Dv5zAW0G/oDHoS+QzcMNq/z6oJzSfsxx/uotfNNnCu48eIx3CufFumn9kDtn1pjH8T6Ny5cqkg9rpvZ
FLjPln4W9QMdhc7Hr0BmkT+eESd+1wdefVUJX33nYuu8E0qVNg7G9W6LD17XitZPUG5t3H8G3owLw4mUk8nlmxfrp36FYoTwmza38ZS8GT12B5xGvkCNLJgSM6oZy7xaBj99SrPvtoGT84ftFsWhMD+TKkcWkDbU31fin1rY1yy9Y/zuO3XyFtWvXWtNsitrS+3eFgFejRg28fPkSfn5+qF69eoryTGrj9v6uSKqfrEcCJJAwgcqVK+PgwYOxClWsWBGHDh2KdY8XJEACJBCXgN7fyaZNm4Y1a9agSpUqmDp1atzu6eZa7TuZ6PfeLasxf1RX3fSRjmqTwD9nr6GFz3TcDwmFRyZXBI7viarl3lHtrPFaR86smbHCrzfKlixs1s7jp8/QasAPOHTqMpyUdbchXZqgt9cXhvIRyhpCx2GzsX3/Sbg6p4P/yK6oX6Oc4XncE+PyGTOkV/5udEPdah/ELWa4FmtCo+asQ1T0a3xZsxzm+HbGqi37MWT6CkRGvcbn1cpg3shucHVxNtRJzknwnYdo1m8aLgTdlmsvUwa0RcsGpv8NKdZ5vAbPxJUb92TZPm2+gE+HRliz7S8MmrJcrplkdc8If98uqFXp/eS4Zairxj9DJZ6QQBIJnLl4A037TsG9R0+TaEG71fT+XjVy5Ehs3boV9erVw/fff69d0Il4pva9KhFzfEwCJJACBAICAtC3b1+EhYUZWs+QIQOuX78OT09Pwz2ekAAJOC6BNBTcUHCjhelt7xdLvQluTp4LQitF+CKELS7K4s34vq3QumEN1UO3/9g5tB08C0+evYC7mwvmfd8VdaqUNmnn5atIDJ6yAks3/ymfVy1bQlnM+hYemd0M5Y0FNNYW3IhGFm3YhaEzVsoFpLpVS8O71efoONwfdx8+QaE82bFych+8/VZegz/JOXkR/hK9xy3GT7+/+bGvmrJot1Tpb+ZMGeOZPX72KtoPmY3geyFwTpcWw7p/ja7NPoWTkxOu37qP9kPn4OSF63Ihrl2TjzHKuzkyKOOWnEONf8lpxxp1KbixBkXr26DgxvpMaZEESCBpBBYvXoyePXvi+fPn0kCmTJkwb948tGzZMmkGWYsESCDVEND7D0MU3FBwk2r+stqoo2HPw+WawO8HT8kWmtT+ENMHt0NG1wyqWgx5EoqWPj/gyJkrysY1oFuzzzDi228UwUhak3Z+3nUYvccuQpiybiDEI8sn9sKHpYsbyhoLaGwhuDl7+Sb+pwhgbitrQrmVDT2LxnTHlEWb8fuh03KD0mQfL7OCGIOTKk7EBqMBk5cjQlkXyq5sMFo5qXeszVcxpp6EhqHnmIXYtu8feUus28we1gke7m4QTEbMXI0lP+2Rm8dKv10Qi8f2QMG8OWOqJ/nTUv+S3AArkoARAQpujGBo7JSCG40NCN0hgVRMICQkBLlz55abXWMw1KlTBzt27Ii55CcJkICDE6DghhFuNDHFKbhJeBgio6IwdPoqLP5pt4yc4pnNHX4+bRLcMXX6QhDcMrqgcP5cBuNicar7yPlKhJg3iyHvFSsgdxkWjxPFRUTDWaEssIz4YQ2evYiQopIJ37VCm4Y1DbbEia0FNxeV3VQt+k9H0K0HMopM9fLvQix0iUg//6tXCVMHtYdzevNClkMnLyK7slutaIHcsfyOeyEWgmYt/xXTlm6RC0oioo6IntP+q0/iFsWDR0/QSYmys//Yebkw91WdjzBlgJdkHVN4z+EzMhLPoydhELvVJnzXGi2+qBbz2PBpC/8MxlPwhIKbFISfQNMU3CQAh49IgATsSuDJkyfImTMnXr16JdvNmDEjHj58CBcXF7v6wcZIgAT0R4CCG22Mmdp/vzPCjTbGzVG82PLnMXiPXoDQ5xFSbCLEMj4dG5mN7nJXEalcC76Hj8q8HQtBwLqd8J21Bq8io+WGpEn9WytRdeNHQRY/tnca4Y9L1+/I+g0/qYBZwzrGas/Wghthv6fS543KekhapzRorAiNditiG7HmUKSAJ9ZO6YtC+czvnhYbg27dD8GH7xWDU1qnWBziXoj1jN5jF8vy4ln96mUxV4moE1fUFB0VjclKBOVpS3+Rm6SEH4HjeuIdJZpxzHHv4WN0UCIX/33yklw/aapELfZTxEFxbdnCvxgf+EkCySVAwU1yCdquPgU3tmNLyyRAAuoJfPLJJ9izZ4+sKNa3NmzYgPr166s3xBokQAK6JEDBDQU3mpi4ahfskuu03iLciP5ev31fCVHsjxPnrsnui11TX3xcHl99+hHeLZofGZXQwSLyy5HTV/DL7qM48u9l/DC0QzxRzuFTl+SCh4iWI44CubMr0VnqoLYS2jersgvp8o27WPnLPvy446BMrSTKfK6ENp45rEOs6Dbivq0FN0Jo1EeJOrNm2wHRnFygEemwhIhl9ojOaKD039whREOj5qzH4h93ocL7xVDrw1Io885beCtfTiX0dEYlFHM0biphkg8qCz8bfjuAY/9ek0IeYU+EY56pLKAZR/MR90XUn1FKaq8AJWWSEP0IwdKScT3iLWyJtkXI53Hzf5SLd/lzZcOSsd+iTMm3hBl52MK/GNsp/UnBTUqPgOn2KbgxzSWhux06dICIxJHexXzqvYTq8xkJkIB5Aq/CXxgeipSo6TJQbGMAwhMSsAIB8Xes7Cf1cGzXVitY044JCm60MRZq//1OwY02xs1RvBDiE5FeaeGG32WKJZHiqdy7hdH6yxpK1JliECminipRfc9cvoGd+09hy95jaKCkd5o8sG0sBCJNVDffABklRjxwU9ZURNrx5vWr4a28OfDoaRi2K+ms56/dKSPLiDKm/m0v7ttacCPa+HHH3+ilRNkRUWdEVB6xNiIOL6XfYkNWQkIaIaLppAhf8igpvT+tXBrllfTXbysbr0R6bBHV58HjUBmld8ueY0qfj8s0UMK26O8iJSqNqXRbWxWu3mMWyQjKmTNmwLTB7dGoVkVRLdZxVIkiJCIEi+g8IkLwiB7foHPTOjJCcExBW/gXY5ufJJBcAhTcJJeg7epTcGM52xMnTjDShuW4WJIEkkTg+PHjWLt2LaKU37REOinx/yiREYEHCZBA8gk4Ozujd+/eyTdkQwsU3FBwY8PpZblptQt2lls2XVKPghvRk3NXgpUFjYU4cT7IdMeM7iYUxnjz7sPwUcIDP3z8zKiG6VORSuqHoe1RME/8sL+2FtwIj8TutR7fzzcs+Ih7IhTxqsm94Zk9i7g0ecQIWuas3m5YiDJZ0OimWKirW7UM/JSdbbmUhai4xzpF+DNg8jIZ9SexHOQimlCf8UvkDjRhR6SoChjZBTmyeUiztvAvrr8pdU3BTUqRT7hdCm4S5mPqaYsWLXAv2hV1W3c29Zj3SIAEkkHg1L5dWDNjDF4rItVv/QKQt8h/aRmSYZZVSYAE/p/A/k1rcenUUZw5+CY9rKOAoeBGGyOp9t/vFNxoY9wcyYtnYS8wUtkMs/yXP2V0lcT61lZJyR1XcCPqXL15Fz1GLZCppRKzISINT1VEO3WVDUlxD3sIboLvPkTz76bj3NVbhu
aF0GX+yG5mU4XHFBSClg6K6EVEBbL0KJLfU0bzrVa+ZLwql5RoxO0Ue+ev3ZYRd75tUQ+DuzYxm5IrcOMeDJuxCuGKWEoIogJGdUNVZY0k5rC2fzF2+UkC1iBAwY01KNrGBgU3lnOdO3cuevToISPGW16LJUmABNQQiIp+jSuXr8gqWbJ4KJGdc6ipzrIkQAJmCNy48wh/nbig/M77/zsOzJRL6dsU3FBwk9JzULavdsEuuU7rVXAj+i1yjfuv3oGlm/5QdiHFF8yInU7FCuZGp69rKzuzqsYL1RvD7uS5IIxfsBF/HD4to7DE3I/5zObhhnaNPkaPFnVl/u2Y+8af9hDciBDELfvPwD8XrsumRf96taqPIcpiTmIKYZF+yn/1bzil1BW7wMwd6dKmwXvFC6HTN7VlaOYMSvSguMdpxUa7obNleiuRcmpAh8bo1ebzBH0QYavbDZqFM4pQSoR97tL0U4zq9T+DaWv6ZzCqgRMKbjQwCCZcoODGBJREbrVq1QrRuYujXusuiZTkYxIgAbUEXkaEo0vlonDzyIbZu0+prc7yJEACiRDYqwhu/t6xCULc5kgHBTfaGE21/36n4EYb4+ZoXoiIuJt2HcHM5Vvx7+Wbhoi1xv0UG2VEGugeLeua3EQkyj4MeYpZK7ZJ8c7j0P8i8MXYEVFZRETgQZ0b410lyq2pwx6CG7Fpx3fWWsxbu8OwqahCqcJYPrGXkkrb3ZRbhnsXrt3CxICN2HvsLEKePjfcN3WSV9l81KxeFSUKTS2Tm5yE2Km3Eol4056jsrqIJuz/fWdk9chsypy8Fx7xEkOnrcIyRSAl1snLKxGJFo7pjny5ssvn1vTPrBN8QAJJJEDBTRLB2aEaBTeWQxaCm19XB2Dx+J6WV2JJEiAB1QS2bd2GW7du45tvvoa7R8LvZ6qNswIJpFICYpNErXYjlc0D4ZomQMENBTeamKBqF+yS67SeBTcxfX/+IkJJH3UJpy/dxCNFeCNEILlyZEG5kkUU8UiBBMMJx9gQn0GKKOTgyYu4dvMeXr2KQuZMrihZNB8qK/nNMyuplxzhENFmTl+8ruRcv4s79x/jRbiyq0tR7rhndkGhPJ4oryxS5c+dcopjrfundg5QcKOWmH3KU3CjnjMFN+qZsQYJqCEwqVszlK5Wh6I2NdBYlgQsJEDBjYWg7FxMCE/WrFmDKlWqYOrUqXZu3XrNqf33OwU31mNPS/EJREdF46yyyeXImcu4fS9EppnKqmwgeqdwXnyopJfO5GZZetgnoWE4oOycPHflFoSoxFnZiFMonycqf1DcrFgnvjfaviNESpeC7ih9vImbdx8hVEm9FanwEynKcytCm/eLF8S7xfObjVRj695p3T9b95/2tUmAghttjovwioIby8eGghvLWbEkCSSHwLWr13Dk8GF806xpcsywLgmQgBEBCm6MYGj5VO875A4r//P29vaWiA8cOKBoCJTwHzo81C7YJbeLjiC4SS4D1icBWxGg4MZWZJNnl4Ib9fwouFHPjDVIgARIgAS0QYCCG22MQ1wvKLjpGhcJr0mABEiABEhA0wQouNHu8FBwY/nYUHBjOSuWJAESIAES0BYBCm60NR5mvaHgxiwauz6g4MauuNkYCdiUAAU3NsWbZOMU3KhHR8GNemasQQIkQAIkoA0CFNxoYxziekHBDQU3cecEr0mABEiABLRNgIIb7Y4PBTeWjw0FN5azYkkSIAESIAFtEaDgRlvjYdYbCm7MorHrAwpu7IqbjZGATQlQcGNTvEk2TsGNenQU3KhnxhokQAIkQALaIEDBjTbGIa4XFNxQcBN3TvCaBEiABEhA2wQouNHu+FBwY/nYUHBjOSuWJAESIAES0BYBCm60NR5mvaHgxiwauz6g4MauuNkYCdiUAAU3NsWbZOMU3KhHR8GNemasQQIkQAIkoA0CFNxoYxziekHBDQU3cecEr0mABEiABLRNgIIb7Y4PBTeWjw0FN5azYkkSIAESIAFtEaDgRlvjYdYbCm7MorHrAwpu7IqbjZGATQlQcGNTvEk2TsGNenQU3KhnxhokQAIkQALaIEDBjTbGIa4XFNxQcBN3TvCaBEiABEhA2wQouNHu+FBwY/nYUHBjOSuWJAESIAES0BYBCm60NR5mvaHgxiwauz6g4MauuNkYCdiUAAU3NsWbZOMU3KhHR8GNemasQQIkQAIkoA0CFNxoYxziekHBDQU3cecEr0mABEiABLRNgIIb7Y4PBTeWjw0FN5azYkkSIAESIAFtEaDgRlvjYdYbCm7MorHrAwpu7IqbjZGATQlQcGNTvEk2TsGNenQU3KhnxhokYG0C0dHReHT3Fq6fP4PrF/5F8OVzCL5yEY/u3EJU5CukcXJCfa/uaNi5D5yUc3sfWvfP3jzYnnYIUHCjnbEw9oSCGwpujOcDz0mABEiABLRPgIIb7Y4RBTeWjw0FN5azYkkSsDUBsY504/YDnL92C7fuhSD0ebhsMkvmjHi7UB6ULvEWXF2cbe2GWfta98+s43zgsAQouNHJ0FJwo42BouBGG+NAL0jAGgQouLEGRevboOBGPVMKbtQzYw0SsDaBX5f6Y92MsXj9Otqk6TRp0qBBh15o0r1/ighutO6fSWi8mSoIUHCjzWGm4IaCG23OTHpFAiRAAiRgjgAFN+bIpPx9Cm4sHwMKbixnxZIkYCsCew6fQcDanThy5jIePQkz20xW94xoUucjeLeqh/y5c5gtZ+0HWvfP2v2lPf0QoOBGJ2NFwY02BoqCG22MA70gAWsQoODGGhStb4OCG/VMKbhRz4w1SMDaBLYEzsH6H8YpgpvXEOKajB5ZkfetorgXfANP7t+R91JScKN1/6w9HrSnHwIU3GhzrCi4oeBGmzOTXpkjEB0VDae09o+gZ84f3icBErA/AQpu7M/c0hYpuLGUFEDBjeWsWJIEbEFARI2Zsngz/BZvUta3LGuhUJ7smDqoLWpUKGVZhWSU0rp/yegaqzoAAQpudDKIFNxoY6AouNHGONALErAGAQpurEHR+jYouFHPlIIb9cxYgwSsTeDM3/tw5fRx5C/2DgqWKIWsnrkRGRmJWT6dcXLvzhQX3GjdP2uPB+3phwAFN9ocq9QsuOnXrx/qVi2tzYGhVyQQh0BkZDSePn2CiIgI5M6dW3nfiFOAlyRAAqmGwJWb9xAUfB8RryIdrs96/12EghvLpyQFN5azYkkSsAWBGEHLlCWbkT9XNlQrVxIflSmOt9/Ki2weboh4GYmLQbexbe9x/LLnKJ5HvJJuFCngiaXjvVGicF5buGWwqXX/DI7yJFUSoOBGJ8Ou9xfLw4cPw9vbW9I+cOCA/NFBJ+hjuWlvwU3VsiXw8tUrNKhZPpYfvCABEkg+gTVbD+DdcpWwfv365BvTiAW9f1cIjBTcqJ9MFNyoZ8YaJGAPAi9fvtSM4MZUf7Xunymfec/xCFBwo80xTa2Cm1OnTsER3qe1OavolTUJhIeH47fffpN/hNhGHJ06dYK916ys2SfaIgESSD4BV1dXiPUBRzv0/t1MwY3lM5KCG
8tZsSQJ2IrA/mPnEBUVhSrl3kG6tGnNNvPrn8fQZ/xihDx9LkXfPVvUw7DuX9s8jbrW/TMLjA8cngAFNzoZYr2/WFJwk7SJNqLH1/j3UjDSOnGbUtII2rfWs+fhOHnhOsq9WwQuzuns2zhbU00gKvo1Pm7YGj179lRdV6sV9P5dIbhScKN+dlFwo54Za6gj8OxJCG5cPIcnD+5B/Cs6q2ceJYrLu3DN6GaxochXL3Hz8kXcv3kNLyPC4eaeRYkGUwI58uS32IapgkI0cvPiWdy7GQQR7zZH3vyKb+/BOUMGU8Xtek/rghat+2fXwWJjKUaAgpsUQ59gw6lVcJMgFD4kAQ0QEN/d8+bNw+jRo3H//n3pUcGCBTFq1Ci0bt0aaRP4UUQD7tMFEiABEkgSAb2vdVFwY/mwU3BjOSuWJIGUJiBSmg6cshxLfv5DulJJiYSzfII3PNwtXyu0ZR+07p8t+07bKUOAgpuU4a66Vb2/WFJwo3rIWUFHBF4rP/AtWbIEvXv3RmhoKJo1a4Y1a9boqAd01VEI6P27QowDBTfqZyMFN+qZsUbiBCIjo3B01xb8vnYJLp88hqjIN2FiY2qmd3HFe5Vq4NPmHfFOhcpmd7CE3L+Lbcvn48CvP+LpQ0WwY3SkSeOEAu+8i7otu+Cjz75EuvTORk//OxVinRV+vvj38D5FSOOC5n19UeT9sti1NhA7Vi/C43u3/yusnHnkyKX41QGftuyIDIqfxsfTRw+xeOwA3LpyARlcM6J5nxF498OqxkVMnj+6dweLx/jg3o1rcHFzQ5sBY1GsdMIRCJMqaLl08qjs7/NnT+Cs+N+ka3+U+/gzk36Jm08e3sfSCUMRfOmcFER99FlDfNnR2yzPGENJ9S+mPj9JwBoEKLixBkXr26DgxvpMaZEEkkNAhM9ftWoVhg8fjqtXr0pT2bNnx9ChQ9G9e3e4uLgkxzzrkgAJkICmCeh9rYuCG8unFwU3lrNiyaQRCI94iQtXbyPo9n0ls0QksmTOiFLFCiB3zqyqDF5X6p+7EownoS/gkiE9iuTPheJv5YFz+qRvwhbpmv69fAPXb91X9pO9VlI6Zcd7bxdU7JteK1PlsI0Kr9qyD30nLIHY2FyqSD6smdoXuVSytJFr0qzW/bNl32nb/gQouLE/8yS1qPcXSwpukjTsrKQDAnfv3kWXLl2wadMm6W3x4sURGBiIypUr68B7uuhoBPT+XSHGg4Ib9bOSghv1zFgjYQKPlUg2KyYNVwQ3WxEdHZVgYWdXN/SYMBcfVK8dr9wlRaiz4Pu+uHPtUrxnxjecnNKi8hdfoeV338vIN8bPxLmxMCS9IrhpM2gsjv+5Ayf2bJeLEHHLi2ths0aTlmjxnW8s0Y34wWr11NHYsSpA1v346zbSXmI7wg9u34QFvn0Q+TICRd4rhz7Tl8A9W3ZTTRvuGfudRokM1KBDLzTp3t+sOCmmogjdu9F/CrYsmY3oqEjkKlQEvaYsRL4ib8cUMXyKNtbNGIudivDo9etoFChRCt6TF8AzX0FDGXMnSfXPnD3eJ4GkEKDgJinUbF+HghvbM2YLJGAJAfFjy9atWzF48GCcPHlSVnFThL/9+vVD//794e7ubokZliEBEiABXRPQ+1oXBTeWTz8KbixnxZLqCJxWshLMXb0d2/f/gyfPXsSqLLJLvKeIblp9WQP/+7wKMrqajpgsBDE/7jiI+Wt3KsKYm4hW3tOMj9w5sqD551XRtVlt5MjmYfzIcC5EPkOnr8ReJW2Tq7MzRno3Rfl3i2LRT7sRoNi9/eCxoaw4yZXdA10Ue52/+RSuLrGFNw9DnqLfxKU4H3QLbi4ZMLJnU1Qr/26s+qYu7t4PQe8Jgbh26x4yu7pgfN+WqPB+MVNFE723YvNefDcpUApuypZ8Cysn9YrV9yOnLmHYD6vx+NlzZFQiUQ/o2BD1qpc1a/f+oycYOHk5zl69JdNUNa5VEX3aNkiykCkx/8w6wgckkAQCFNwkAVpKVNH7iyUFNykxa9imrQls2LAB3bp1w4MHD2RTIjXRhAkTIBbAeJBAShDQ+3eFYEbBjfqZQ8GNemasYZ6ASB81b3hvnNr3uywkotAUerc0KnzyOfIVLaFEXHGBEORcPnUc/yhlnj66j25jZqF8rXqxjN66chEzfTrh9tU3YhsXt8yo+GkDGRVHpJO6e/0aDm7/CZdOHJFCEdFOza9aoWX/kfHSQRkLQ9KmS48c+QrgnlLfPXtOVKrXGMXLVIBrJncZfeavX9fj0j9HpU0RhafjiClKmUaxfDt7+C/88F1HvHj2VIpZvpu1IkGBimh/oSIc+nvbRuUf/E5o1KUvGnbuk6hwxthvNYIb4WzY08fwH9bLMA4f1PwMnUdNh1vm2Is2QqywbOJQvHzxHG4e2dB17EyUrvJxrP6au0iOf+Zs8j4JqCVAwY1aYvYpT8GNfTizFRJIiMDBgwcxcOBA/Pnnn7JYunTp0LVrVxnlJleuXAlV5TMSIAEScCgCel/rouDG8ulIwY3lrFjSMgJC4LJg/e+YvOhnhD6PSLCSsk8Knb6ujXGKACXuIcQtg6etwqbdh6W4JO5z42sR6WXW8I5KdJpCxrfluRDtdBw2WxH+nFQEN+kwsX8bbNt7Alv3HVfWseIVlzeEIKiNIgYa1at5LNGN2FDmO2st5q3dIeu2a1QTE79rDae0TqYN/f/dn38/hJ5jFiJc8aVCqcJYPrEXsmdVL+IWbPsrYptVv/4lLbf8oiqmDGyLdEYpTkVap4kLf8YPy7cgMuo1ihbMhcBxPVGicN54Pgo2o+ask+MlxEzvFy+IJeN6oGDenPHKWnLDEv8sscMyJGApAQpuLCWVwuX0/mJJwU0KTyA2b1UCISEh8Pb2xooVK6Td/PnzY/HixahTp45V26ExElBLQO/fFaK/FNyoHXWAghv1zFjDNAHxj3URWeWXRTNlZBsRvaZx1374rEUHk+mJhGDj2O5tyF2oKN56p5TBqLgfOG4g9itiEHFk8cyDziOno9RH1QxlxMnLiAhsnDcV25b5y0guor0uiqikQu36scsp9mb5dMbJvTsN998u9xE6fj8NufLHXsB4rohoFo/yweGdv8iyQqjSffycWFFuwkKfSMHN+SMHkFZJY9Vu6ERUb9jMYDvuSfDlC5jcszVC7gZLkU+f6YEoUqpM3GLxrpMraLlx6bzS7064G3RFWTBJp0TI8ZZin5hoPFfPnpJcHt66ASFEatKtP+q365GoECjG0eT6F2OHnySQHAIU3CSHnu3qUnBjO7a0TAKJETh79iyGDBmCjRs3Goq2aNECo0ePRtGiRQ33eEICJEACqYWA3te6KLixfKZScGM5K5ZMnIAQe8xatRUTA37GSyVtujhEFJpGn1RAxfeLIquHmxThnLt0EzsPnsKJc1fR6ovqmKyIRoyP5y8iMHDKcqzZ9pcUtggBTOUPSqBx7YoolC8nHj95hp0HTuKXPccQFv5SVhURcxYrYpG38nkam4Kx4CZ9OicpJrl64x5yZnPH13U+QsXSxeCeyQXXgh9gzdb9
OHL6ioykk1FJWzVjSHs0VsoYH/uOnkXbwbPwNCxcilnWTumboEBFtO+tiG1+UkQ3TorCaECHhujbroHF60gxbQu2q37dj+E/rJIMs2fJpKRh747KZd+JKWL4fPI0DN1GLZCMxM161cpg1rCO8Mgce9O66K+IbiMYZnXPCH/fLqhV6X2DHTUnavxTY5dlSSAhAhTcJERHQ8/0/mJJwY2GJhNdSRaB7du3o2PHjggODpZ22rZti+nTpyNLlizJssvKJGANAnr/rhAMKLhRPxMouFHPjDVMExBRaab2bosHwUFS4CEiuQiRR4zAw3St+Hcvnz6BGX3b4enD+xApoLwGjzcraHnxPAzzh/fCcUW4I473q9XGtxP94eKa0WDYWBgibmbNnR+9py6KJfIxFFZOzh8/jB/6dUDYk0dKNJxC8JmzErkKvGVcBL8u9cf6H8ZJYVHZT+qh29hZsUQ5xoW3r1iANdNHS1GQKQGPcVnjc2O/1Ua4ibHz92+bsHj0AISHhcooPp0UkZGIJvT00UPMHdwdZw/vl0VF9KAOvlPhmjH2gkWMHVOf1vDPlF3eIwE1BCi4UUPLfmUpuLEfa7ZEAjEEbty4ge+//x5LlixR3k+i5e26deti/PjxKFvWfOj9mPr8JAESIAFHJaD3tS4KbiyfmRTcWM6KJRMnIMQoHYbNQcjT51Jc8kWNshjfrxVyKaIbU8exM1dw485DNFKENMbHht8Oou/4xXihiFWc06VFPyXFkXeb+vHSHG3fdwJ9JyzB/ZBQmQ6pbcOaGKtEy3FOn85gzlhwE3OzUpnimDW0gyLeiS3OCVXSMPVRUj9t2n1EFhVClfkju8WKcvMkNEwR3MzG/uPnpW9TB7VV0mJVjTEd7/PclWA0/24agu+FIGfWzFjh1xtlSxaOVy7mhvDh6NmriIh4JW+FK5/Xgu9h9+Ez+PufCzJqTZbMrhjl3Vym43JyMh1d59yVm2g7ZDauKOKidGnToE+bL+DToZEhGs/Jc0HKZrjZkr8QIg3s2BjerT9PVAhkLf9i+stPEkgOAQpukkPPjnX1/mJJwY0dJwubsgmBZ8+ewcfHB/7+/tJ+zpw5MX/+fDRu3Ngm7dEoCSSFgN6/K0SfKbhRP/IU3KhnxhqmCfy+ZglWTPaVwpICJUqh74ylyOaZ23ThBO5uCZyjiFnGy7ROxZR0TyIiTCYP0wsawsyJvb8r4pEeiHj+DB45cqHvD8tiiWmMhSGifM0mreA1ZLxZIdATRegztXcbBP17ChmV9FVCnFNCiYhjfFy/cBZTe7XB43u3TbYZU1ZEzJn5XScpbBFRZLyGTkDNRs1jHif4aex3UgU3kcourB/nTMK2pXOlOCj3W8UUQdI87Nu8FjtWLpD38iqpvnpPWYhcBc0vkphy1Br+mbLLeySghgAFN2po2a8sBTf2Y82WSODhw4cyNfXMmTOVHzPepDr48MMP5b1PPvmEgEiABEgg1RPQ+1oXBTeWT2EKbixnxZIJEwiPeCnTJv28641YpXr5d7BwdDclqk3mhCvGeSoEHe2HzsEfR87KJ83qVpJpk1wyOMcp+eZyoZK+asTMNTKijoims0oRtLz3dkFD2biCm/y5smHZBO9YZQyFlRMhamkzaKYUDRXKmwPrp38XL2rOnFXbZSqmqOjXqF+9LPy/7xJLlGNsb/7a32QaKpHeyZSAx7isOD99IQhN+07Fg8fP4j6CRyZXNKhZHl3/9ylKFs0f73ncGz//flgRJC2WEXHc3VwwU4lyU79GOYh0XZ1952Hv0XOySqNaFTB9UDtkcnONayLetTX9i2ecN0hAJQEKblQCS6nien+xpOAmpWYO27UGgf3790NEsrl8+bI016RJEym88fSMrTq2Rlu0QQLJIaD37wrRdwpu1M8ACm7UM2ON+ASioqKw4Pu+OLBlg7ITJw3qNO+EFt+NSHQ3SVxLQiAyf9i3OPTbZmnnc68e+MZ7UIJ2Ht27g6nerXHz4lmZ4qnzqBmoVLehwbSxMMTJKa0iepmIj5u0MDyPexL+4jlm9u+EMwf+gEhT1WPCXHxQvXasYsLmghG9/t9PJ3zVYwC+7Ogdq4y4OH/sb8zo1xHPn4YgV6Ei6Ddzebw0VvEq/f8NY7+TKrgRpp49CcHcId/K/ohrTyVaz+MH9/BS6acQFAleZWuoT2tpLf+ETzxIIKkEKLhJKjnb1qPgxrZ8aZ0EBIGwsDDMmDEDkyZNwpMnTySUEiVKYNy4cRBrDuLdgQcJkAAJkIDyg/Pff+saAwU3lg8fBTeWs2LJhAmcvXwTzfpNw50Hj5HJNYMUodSt9kHClUw8Pf7vVbT0mS4FJyKSS+D4nqhiIm1STNV7Dx+jZf8Z+OfCdYjUU6O8/4cuzT6NeRwrpZS46fVlDfj5tDFEejEU/P+T+4+eSHsnzgdBtL9sYi9UKvN2rGL/XrqhRK2ZjttKX3Nl98DqyULkEzv9uqggxEMiyowQtogoMpN9vNCyQfVYtuJeJCRoEf0rWSQ/On1TG9/UrYwMzv9F8olrR1xHKuuOE+b9pKT52gYhDipeKLciguqO1Vv2Y966HfLeO4XzYumEniicP5cpE/HuWdO/eMZ5gwRUEqDgRiWwlCqu9xdLCm5Sauaw3eQQELvLRowYAT8/P2WX/mt4eHhA7Dpr3bo1F7+SA5Z1bUZA798VAgwFN+qnBwU36pmxRnwCIpKLEJacP/KXTCfVbthE1LAwkouxtbDQJ4qdDrhw9KC00364n9l0UjH1XirftzN9OuPUvt/l92vTXsNQv223mMcwFoakc84g0z9VqF3f8DzuiXF5ZyU1Vfdxc1C25n8LHDHlD2zdiIUj+yHyZQSKf1ARvactiRWJR6RzWD9zArYunSPfAxKLrBNjN+bT2I/kCG6EvesX/pUiovs3g2LMS74NO/XGl8oftWm/hBFr+mdwiickoJIABTcqgdmpOAU3dgLNZlIlgVevXmHhwoUQP8DeuXNHMsiXL59MJ9WuXTukS5fwjxWpEho7TQIkkKoJ6H2ti4Iby6cvBTeWs2LJhAn8uONveI9ZKCPNlFYizKxSRCie2c1HXjZnbd22v9Br3CKZOqlsybewclIv5MjmYa64TAs6eNpKLPpxtyzzv3qVMGt4Z0N54wg3QrAyZUBbtPrSvOjl+YsItFdEMrsOnZHCoYBR3VCnSmmDPXEibPYcvQAbdx2WqbMGd26CPm2/iFVGXBxUouV4/X+0nCIFPLF2St94aaziVhIinUOnLiH8VSReKyKZsBfhuBR0F7v+PoXTF28gWvnNTPSjce0PMVFJ1+XhnnCa85Anoej2fYDsj2ircL6cuPfwCcLCX0pB0ezhnfBZVcuFUdb2L27/eU0CaghQcKOGVgqW1fuLJQU3KTh52HSSCJw4cQJt2rTB6dOnZf06depg0aJFKFCgQJLssRIJ2IOA3r8rBCMKbtTPFApu1DNjjfgEjNMwpc/goohaZqN8rXrxCyZyJyl2RHSdgBG9cXDrT1Jw06BDL3z97QBDS8bCEEt8i1XexRU9xs81Kbh5eOcWpiiRdW5
dPg8Xt8zo6Tcf71WqYWg35P5dTOvthevnTiNDxkzoPn5OvEg5hsImToz9SK7gRpjftX4ZVkwajqjIN7mzS1evg65jfoBbZvOLPSbcMtyytn8GwzwhARUEKLhRAcuORSm4sSNsNpVqCEgh7/r1GDZsGC5evCj7nSVLFgwePBje3t5wdU08bH6qgcWOkgAJkIARAb2vdVFwYzSYiZxScJMIID62mMC0wF8wPuAnZfMU0KBGWQSM6Y50adNaXD+m4JRFmzFx0UZpp+EnFTBvZJdE7SzasAtCdCPEKDUqlMSGGf1jzMWKcOOiRISZ69sFDT4ub3ge98RYoOPm4qy03xWmIvUIgVGvsYsQoQhjPny/qLJ21AtZ3DMZzIn30LH+P2Lmyq2yL4lF1jFUNHPyUmlnxea9GDtvA548eyFFNz1bfo4hXZqYjdYTY+qMItRpp4iIrt26H3NLYZoG/ds1RN+2DRKtb6iUwEly/EvALB+RQIIEKLhJEI92Hur9xZKCG+3MJXqSMIHIyEhMnDhR7i4T52LRS0S46d69e4LpMBK2yqckYB8Cev+uEJQouFE/Vyi4Uc+MNeITEEKZKb1a4/rZ0xCiFiEuKfdx3fgFE7kTy44iduk+bnaiduKms/qivTe+6TnQ0JKxMMSaghux4LB66mjsWBUgI9jUatYerXxGGqLFHPn9V8wb3guvwl+gyHvl0Gf6Erhny27wK7ETY7+TK7h59uQx/Id+i9N/7TE0m7doCXj7BSDPW0UN99ScWNM/Ne2yLAkYE6DgxpiGds4puNHOWNATxyCwc+dODBo0CEePHpUdcnFxQe/evTFw4EBkzZrVMTrJXpAACZCAjQjofa2LghvLJwYFN5azYknzBMRaz6g56zFn9fb/xCUDlLRNTk7mK5l4EtdO24Y1MEmkf0rEjhC/9ByzAK8io1GhVBFsnT/UYN1YQOOqCG78FQFN/RrlDM/jnhiXz5ghPeYrEW5MCW6C7z6UaaXOXb2FzBkzYNG4b/FxxVIGc3eVdFMtfWbgpJLqSjyfPzJ+pBxDYQtPBB//NTsweu46GQEof65sSiShPninSL5ELQRu3KOIklZIRqLwZ0rUnjkjOsEjc8IRchI1bFQgOf4ZmeEpCVhMgIIbi1GlbEG9v1hScJOy84etW0bg/Pnz8PLywqFDh2SFypUrIzAwEMWLF7fMAEuRQAoT0Pt3hcBHwY36SUTBjXpmrBGfQNjTx29SQR37W6Yq6uA7BdUafBO/YCJ3hJ3pfdrh4gkllG3adOj4/VRU/eLrBGtFvnqJWQO64cQf22WEm6+/HYQGHXoa6hgLQ6wpuBENnD38F374riNeKCm18hQuhu9mrUCOPPkRGRmFxaO/w/7N6xSfnNCoS1807Nwn0YUVg9PKibHfyRHcCF9+muuHrYFzlPDEUcZNoOKnDdDBdypcM6pflLCWf7Ec4gUJqCRAwY1KYHYqTsGNnUCzGYcncOTIERnBRghuxCFSQHbs2FGmrhZppHiQAAmQAAkkTkDva10U3CQ+xjElKLiJIcHP5BAQQouRs9Zh7trfpODGUqFM3Dbj2mnXqCYm9m+d6LrQz78fQo9RC2Q6K3sJboSvvrPWYt7aHbLPHb/6BOP6tDREi/llz1F8OyoAzyNeKSKgwlg+sReyZ3WP22XV19eVKDVN+03FlRv3ZJSa6YM74H+fV0nQzuOnz9DNNwC/H3qTWUIUfqdwXiwZ/y2KFsidYF21D5Pin9o2WJ4EYghQcBNDQuOfen+xpOBG4xMslbsnXkhmzZold5eFh4cjffr0GDVqFHx8fAy73FM5InZfJwT0/l0hMFNwo36yUXCjnhlrxCcgRB3zh32LQ79tlqKX+u16yrROie3ciWtJiGdmD+qO47u3STsNO/eVYpWE7BinoTIl0jEWhlhbcBMW+kQKbs4fOYB0zhnQfrifFAjdunoJU3t54UFwENyze8roNkVKlYnb3QSvjf1OjuDm8M4tWDjyO4SHhcJVSR9Vt1Vn/PHTKoTcDUbadOnRpFt/1G/XI9FFn7jOWsu/uHZ5TQJqCFBwo4aW/cpScGM/1mzJMQlcuHBBpo5at26doYPffPMNxowZgxIlShju8YQESIAESCBxAnpf66LgJvExjilBwU0MCX4mh4D4rWfakl+UVFA/S/FJo1oV4P994qmg4rYp7ExetAmTl2yWdprU/hD+vp0NIpa45WOul236Az5+yxAV/RpVy5XAxpn/pUw3jlhjzQg3ou19R8+i7eBZeBoWjrcL5cGaqX2QP3cOJfpMFPqMW4w12w7AKU0aDOigpG5qp6RuSiRST0x/Evp8FvYCbQbNwr5j55Q1QGBI56/Qp+0XZqsIXybM34hZSmorwcf4EOM0fVA7ZHKzXppVtf4Z+8NzElBLgIIbtcRSqLzeXywpuEmhicNmEyUQFBSE9u3bY/fu3bJs6dKlsWzZMohPHiSgNwJ6/64QvCm4UT/rKLhRz4w1TBP4dak/1s0YqywkRKNo6fLoOyMQmTzUpTkQCxI/z5+GTQHTFDuvUapSDfScvCDBCCznlag6M/p1wHMlOo5blmzoN2Mpir5f1uCksTDE2oIb0Yjo9/ofxsnoMeVrf4GuY2Zi789rsMJvOKKjIvFBzbpKiq3ZyKCkyFJzGPudVMFN8JULUhB0N+iKjBjUoIO3FDDt/2Udlk0cJtNduXlkQ9exM1G6ysdq3LNaBB5VjbIwCcQhQMFNHCAauaTgRiMDQTd0R+D27dsQP6wuWLAAImWmOGrVqoUJEyagYsWKuusPHSYBEiABLRDQ+1rX999/j23btqFevXoQ53o9PvroI5u7TsGNzRGnmgbWbP0LfcYvkqmOkhPRZcXmvejvFyjtfFS6GFYokWE83M1HGBZrYsbprIRIZ/6orgbuthTcPAkNUwQ3s7H/+HlkSJ8O0xTxStN6lXEx6DZa9J+OoFsP4JnNHcsn9ULZkoUNPiXnRAhavBSRz96jbwQ3gzs1kWIeczY37z6M3or4J/R5BDwyKWno//cZlm3+E8H3QpA+nRMGdmwM79afW0UMJHxQ6585v3mfBCwhQMGNJZQ0UEbvL5bGgpuDBw9qgGjSXLDHi2XSPGMttQTEj4BLliyRedNDQ0Pll7jIqe7r6wtnZ2e15lieBDRBQO/fFQJijOBm8uTJqFatmia4qnXC3t8VFNyoHSGWN0fg6tlTMh3Uk/t3ZLSX1gPG4OOvWporbvb+WSVajEzTpESPcc3krohBZuGD6rVNlhcRcZZNGKZEbFkhn5eoUBm9piyEmxLJJeYwFq7YQnBz4+I5TPFujcf3bsMjZ270nDQPPyuCodN/7ZERZLyGTkDNRs1j3LH409jvpAhuRPSdgBF9lFRbv8k2S1evg66jZ8DNPYsUy6yeOgq71wVKgVTBd95DT78AeOYraDf/LG6IBUkgAQIU3CQAJwUfxQhuqlatiilTpqSgJ8lr2t7vZMnzlrX1TODx48eYNGkSpk+fjhcvXsiulC1bVgptPv30Uxn1T8/9o+8kQAIkkJIE9L7WRcGN5bOHghvLWbFkwg
T+OXsNLXym435IqBR2BI7vqUSbeSfhSiaeHvznAtoM/AGPQ18gZ9bMysas3gmKVUS6pFYDfsChU5dlNJkhXZqgt9d/EV9sKbgR7s9ZtV0R/KyT0WO+rFkOc5SIPKu27MeQ6SukaOjzamUwb2Q3uLpY5/ev4DsP0azfNFxQRD3p0qbBlAFt0bJBdRMkgfNXbyninJmG9FN92nwBnw6NlMg7f2HQlOUy3VVW94xKFKEuqFXpfZM21N5U459a2yxPAnEJUHATl4hGr/X+YknBjUYnVip16+7du+jSpQs2bdokCRQvXhyBgYGoXLlyKiXCbjsKAb1/V4hxoOBG/Wyk4EY9M9YwTUCklVo52VcRcSyR0WlEKqW2g8ejfK16pisod6+dO6NEr8mIXAX/2x3z4nkY5g31VoQi22W9Am+/q0SImYu8hYvFsiN2/vyxcRVWTRmJly/CkDa9M1oPHINPvmoVq5yxcMUWghthf8GIXjKdlpNTWnxYtyFOHfgDYY8fIVehIug3czly5S8UyydLLoz9Viu4EbviNwVMx+aFP8goO8KPXpMXIl/Rtw1NP35wD7MHdsXF44fkD3mV6n+FtkMmwMU1o6FMQifJ8S8hu3xGAmoIUHCjhpb9ylJwYz/WbEnfBIS4Zvbs2Rg3bhxCQkJkZ4oWLSpTRzVr1sxqu3P1TYnekwAJkEDyCOh9rYuCG8vHn4Iby1mxZMIEwp6Ho+Nwf/x+8JQsKCLNTB/cDhldMyRcMc7TkCehaOnzA46cuSJTJnVr9hlGfPuNIi5JG6fkm8ufdykRXMYuQlj4SwjxyHIlIs6HpYsbytpacHP28k38TxHA3H7wGLlzZMGiMd0xZdFm/H7otIwgM9nHy6wgxuCkipOVv+zFgMnLEfEqEtmzZMLKSb1RrlSReBZE9J2eYxZi275/5LO6VUtj9rBOMlqQYDJi5mos+WkPopVN8qXfLojFY3ugYN6c8eyovWGpf2rtsjwJmCJAwY0pKhq8p/cXSwpuNDipUqlLGzZsQLdu3fDgwQNJoGfPnnLXmZub+VCAqRQVu61DAnr/rhDIKbhRP/EouFHPjDXME7gffAOzB3XFtTNv/hEsBC7la9dHpbqNUaB4SWRwdUXI/Xu4fPIojuzagsunjqOT79R4opyL/xzBrAFdIaLliCN73gL4tHlHlKlWS0ZouXP9CvZuWoODW3/Cq4hwWabsx3XRaeS0WNFtxANjYYgtBDeijQNbN2LhyH6IfBkhxSsiEp44ajZpBa8h45HWzGKKLKT859mTx9ixaiFCH7/5sU3cfx0dhdMH/8D9m0GyWOFSZVG41AfyPOY/+RQRUk0lilA6RWxkfBzbsx0Bvn3xQoly4+KWGR1G+OHDT780LiLPL508pnDuIqPzCMFSU+8h+Kxlx3g/8Fnbv3iO8AYJJJEABTdJBGfjahTc2BgwzeueQGRkpNy0I35EvXnzpuxPrly5ZMTcTp06IX369LrvIztAAiRAAlohoPe1LgpuLJ9JFNxYzoolEyew5c9j8B69QKYvEumKhFjGp2Mjs9Fd7ioilWvB9/BRmf82OolWAtbthO+sNXgVGQ13NxdM6t8aX38Wf+P2mYs30GmEPy5df7MO1vCTCpg1rGOs9mwtuBH2eyp93qgIf9I6pUFjRWi0WxHbPHoShiIFPLF2Sl8UyudpFt71W/dx634IPnyvmJLW3MlsOfFgz+EzirhosSwvrutXL4u5SkSduKKm6KhoTF68CdOW/iKj7Ag/Asf1xDtF8olq8rj38DE6DJuLv09eksKmpp9Vgp8iDopryxb+xfjATxJILgEKbpJL0E719f5iScGNnSYKmzFLQOw28/b2xooVb1JW5M+fH4sXL0adOnXM1uEDEtAbAb1/VwjeFNyon3UU3KhnxhoJEwi+fAHzfXsh6N83O4ESKp2QAObwzi0IHD8Yz0IeJmRCPhOppDr5TkPOfAXilbWH4ObhnVsyrdSty+cN7WfImEmJzDPHbDosQ0HlRAiV/L5tgXvXrxrfTvS8VOWa8J68IFZUmtvXLsuUXLevXlSEM2lRr003fPXtQKRLZ3oH1a71y2RkIiEWcs+eU0YTKqnwND6s6Z+xXZ6TQHIJUHCTXIK2qU/BjW240qr+CQhB7saNGzF06FCcPXtWdsjd3R0DBgxAnz59wI08+h9j9oAESEB7BPS+1kXBjeVzioIby1mxZOIEhPhEpFdauOF3mWLJKU0alHu3MFp/WUOJOlNMpoh6+uwFzly+gZ37T2HL3mNoUKMcJg9sG8u4SBPVzTdARokRD9yUdExN61VG8/rV8FbeHHj0NAzb953A/LU7ZWQZUSZ/rmxYMvZblCn5lrg0HLYW3IiGftzxN3opUXZE1Bmly0oE6zfNeyn99vNpk6CQRohoOinClzw5s+LTyqVR/v2ieLtQHuRQoteIqD4PHofi5IXr2LLnmNLn4zINlLAu+rtIiUpTtuR/0a/ftApsVbh6j1mEJwrrzBkzYNrg9mhUq2LMY8PnUSWKUPshsyVDZ2X9a0SPb9C5aZ1YG8ps4Z/BAZ6QQDIJUHCTTID2qq73F0sKbuw1U9iOKQLbt29Hx44dERwcLB97eXlhxowZyJIli6nivEcCuiWg9+8KAZ6CG/XTj4Ib9cxYI3ECz56EYNuy+UrKp5UIffQmKpxxLZEiKVehovj0f+1RtWGzWIIR43JXz57Cj3P98O/BPxEV+cr4kTx3y5JNSSHVGp97dZWRb+IVUG7YQ3Aj0lutnjpaiVITINNpCT+KvFcOfaYvgXu27KbcinXPWoIWkY5roW8/HPn9F2lfCHK6j5uNTB5ZY7VnfPEyIgIrJg3Hn8pYiR8CRSSdnn7zkT13XkMxa/lnMMgTErASAQpurATSymYouLEyUJpzCAJ//PEHBg0ahIMHD8r+ODs7Q0TMHTJkCLJnT/xdwSEgsBMkQAIkkAIE9L7WRcGN5ZOGghvLWbGkZQSehb3AyNnrsPyXP2V0lcRqtW1YI57gRtQRP6T3GLVAppZKzIZnNndMVUQ7davFjnAs6tlDcBN89yGafzcd567eMrgqhC7zR3ZDnSqlDfdMnQhBSwdF9BL6PMLUY5P3iuT3xJQBXqhWvmS855eCbqOdYu/8tdsy4s63LephcNcmZlNyBW7cg2EzViFcEUvlzJoZAaO6oWq5dwx2re2fwTBPSMAKBCi4sQJEe5jQ+4slBTf2mCVsIy6BZ8+ewcfHB/7+/vJRzpw5MX/+fDRu3DhuUV6TgEMQ0Pt3hRgECm7UT0UKbtQzYw3LCYS/eC7TRwWdP4MwJW1S2nTpkCWHJwq/VxYF33430VRLMS3dVdIqXTj2txIJ5joiX71CxkyZka/YOyhR7kPl3D2mGD9JgARSGQEKbrQ54BTcaHNc6FXKEPjnn38wePBgbN26VTrg5OQEsYln5MiRKFiwYMo4xVZJgARIIBUR0PtaFwU3lk9WCm4sZ8WSlhOIjIrCpl1HMHP5Vvx7+SaiY0K+GJnI6p4RX9X5CD1a1kXBPDmNnvx3+jDkKWat2CbFO49DX/z34P/PR
FSW2pXex6DOjfFusfjRm0UxewhuxIYy31lrMW/tDkN0mwqlCmP5xF7InjXh9bcL125hYsBG7D12FiFPn8fro/GNvEoUnGb1qihRaGrBM3v8je1C7NR73GJs2nNUVqv1YSn4f98ZWT0yG5uJdR4e8RJDp63CMkUgJYapvBKRaOGY7shMC9G/AABAAElEQVSX64243Zr+xWqYFyRgBQIU3FgBoj1M6P3FkoIbe8wStmFMYP/+/Wjbti0uX74sbzdp0kQKbzw9zeeoNK7PcxLQIwG9f1cI5hTcqJ95FNyoZ8YaJEACJEAC2iBAwY02xiGuFxTcxCXC69RI4MqVKxgxYgRWrnwTQU4waNSoEcaOHYtSpUqlRiTsMwmQAAmkCAG9r3VRcGP5tKHgxnJWLKmeQHRUNM5eCVai1FzG7XshMs1UVg83vFM4Lz58vxgyublaZPRJaBgOnLiAc1duQYhKnJ3ToVA+T1T+oLhZsY5FhjVUSIiULgXdUfp4EzfvPkKokg4qUuGXUUmnlVsR2rxfvCDeLZ7fbKQaW3dF6/7Zuv+0r00CFNxoc1zieaX3F0sKbuINKW/YiECEklZBLIr5+fnJtAoeHh6YOXMmWrdureSsVJJW8iABByag9+8KMTQU3KifoBTcqGfGGiRAAiRAAtogQMGNNsYhrhcU3MQlwuvURODu3bsYM2YM5s2bh1dKVD5xVK9eHRMmTECVKlVSEwr2lQRIgAQ0QUDva10U3Fg+jSi4sZwVS5IACZAACWiLAAU32hoPs97o/cWSghuzQ8sHViRw4sQJtGnTBqdPn5ZW69Spg0WLFqFAAdMh/KzYNE2RgCYI6P27QkCk4Eb9VKLgRj0z1iABEiABEtAGAQputDEOcb2g4CYuEV6nBgJPnz7FlClT5J+wsDDZ5ffffx/jx49H/fr1uYEnNUwC9pEESECTBPS+1kXBjeXTioIby1mxJAmQAAmQgLYIUHCjrfEw643eXywpuDE7tHxgBQKRkZGYOHEixD9gxLmrq6uMcNO9e3eI/Oo8SCC1END7d4UYJwpu1M9WCm7UM2MNEiABEiABbRCg4EYb4xDXCwpu4hLhtSMTEFFy/f39ZVSbBw8eyK4WKlQIo0ePRsuWLZE2bVpH7j77RgIkQAKaJ6D3tS4KbiyfYhTcWM6KJUmABEiABLRFgIIbbY2HWW/0/mJJwY3ZoeWDZBI4f/48vLy8cOjQIWmpcuXKCAwMRPHixZNpmdVJQH8E9P5dIYhTcKN+3lFwo54Za5AACZAACWiDAAU32hiHuF5QcBOXCK8dkUBUVBRWrlwpU1Jfu3ZNdjFHjhwYNmwYunXrhgwZMjhit9knEiABEtAdAb2vdVFwY/mUo+DGclYsSQIkQAIkoC0CFNxoazzMeqP3F0sKbswOLR8kkUB0dDRmzZqFgQMHIjw8HOnTp8eoUaPg4+PDHWhJZMpq+ieg9+8KMQIU3KifhxTcqGfGGiRAAiRAAtogQMGNNsYhrhcU3MQlwmtHIvD69Wts2bIFQ4YMwalTp2TXMmXKhO+++07+yZw5syN1l30hARIgAd0T0PtaFwU3lk9BCm4sZ8WSJEACJEAC2iJAwY22xsOsN3p/saTgxuzQ8kESCAQFBaF9+/bYvXu3rF26dGksW7YM4pMHCaRmAnr/rhBjR8GN+hlMwY16ZqxBAiRAAiSgDQIU3GhjHOJ6QcFNXCK8dhQCf/31l9y0s2/fPtklsXFHRLMRUW08PT0dpZvsBwmQAAk4FAG9r3VRcGP5dKTgxnJWLEkCJEACJKAtAhTcaGs8zHqj9xdLCm7MDi0fqCAgdqItWbIEvXv3RmhoKJycnDBo0CD4+vrC2dlZhSUWJQHHJKD37woxKhTcqJ+bFNyoZ8YaJEACJEAC2iBAwY02xiGuFxTcxCXCa70TOHPmjIxos2nTJtmVNGnSQLxDjxw5EkWKFNF79+g/CZAACTg0Ab2vdVFwY/n0pODGclYsSQIkQAIkoC0CFNxoazzMeqP3F8tDhw6hV69esn8HDx4020+tP/joo4+07qLD+nf37l106dIFMQtkxYsXR2BgICpXruywfWbHSEAtAb1/V4j+Vq9eHa9evcLkyZNRrVo1tQg0Ud7e3xUU3Ghi2OkECZAACZBAEghQcJMEaHaoMnXqVKxduxZVq1bFlClT7NCibZqw9zuZbXpBq8khcP36dblBZ+nSpRBpqcXx+eefY/z48ShTpkxyTLMuCZAACZCAnQjofa2LghvLJwoFN5azYkkSIAESIAFtEaDgRlvjYdYbvb9YUnBjdmj5wAICGzZskGGeHzx4IEv37NkTEyZMgJubmwW1WYQEUg8BvX9XiJGi4Eb9fG3RogX+PnUetZp6qa/MGiSQTALRr6Nx99oV/HtkPyrVbQQ39yzJtMjqJEACqYnAgV83AEqkiTMH/3Sobuv9nYyCG4eajqmyM2LtQIhqZs+ejYiICMmgUqVKch2hZs2aqZIJO00CJEACeiWg9/eqGMGNEHyKKO16PewhZKbgRq+zw/H8Dn0aigsXzsPd3QPF3y7ueB1kj0iABKxOgIIbqyO1jUG9v1hScGObeeHoVkNCQuDt7Y0VK1bIrubPnx+LFy9GnTp1HL3r7B8JJImA3r8rRKcpuFE/9JMmTcLOnTuRLl069ZVZgwSSSEBEogoODsaNGzfw7NkzaaVYsWIQEeh4kAAJkIClBCIjIyF+BB81apSlVXRRTu/vZBTc6GKa0UkTBMLCwiBSovn5+eHp06eyRMmSJTFu3Dg0atRI0felMVGLt0iABEiABLRMQO/vVRTcWD67hOCmR48eyOeZ1fJKLEkCViQQGRWFVy9fIer/IyOKV0e3jBmt2AJNkQAJOCqBByGhiHgVidevX2u6i2kUB7XtoY3x6f3FkoIbG08QBzS/fft2dOzYUf6YJ7rn5eWFGTNmIEsW7px3wOFml6xEQO/fFQIDBTdWmgw0QwI2InD8+HGIRTAhhn3+/LlsxcnJCQ0aNEDfvn3x8ccf26hlmiUBEiAB/RDQ+zsZBTf6mWv09A0BIQQOCAiQ4j2RjlocYsPOyJEj5VoChelvOPG/JEACJKBHAnp/r6LgxvJZJ77Dg4KCLK/AkiRgBQJi0/fmzZvx008/4datWwaLhQsXxldffSX/8F3SgIUnJEACCRAQa+QVKlRIoETKP6Lg5u+/U34UkuEBBTfJgJfKqopd8j4+PvD395c9z5kzJ+bPn4/GjRunMhLsLgmoJ6D3RQjRYwpu1I87a5CArQmEh4dj7dq1Umhz8OBBQ3O5cuVCp06d0KVLFxQsWNBwnyckQAIkkNoJ6P2djIKb1D6D9dP/aGX3sXhHGTZsGC5fviwdz5YtG4YMGSJ3yLu6uuqnM/SUBEiABEjAJAG9v1dRcGNyWHmTBFKUgIjvcODAAcyZMwfr1q3Dy5cvpT/p06fH119/je7du8s1akZHTNFhYuMkQAI2IEDBDQU3NphW6k3aI1epeq8cp8b+/fvRtm1bw0JZkyZNpPDG09PTcTrJnpCADQnofRFCoKHgxoYThKZJQCUB
8cOVEMCKdI4PHz401K5Zs6ZcfBDf087Ozob7PCEBEiABEnhDQO/vZBTccCZrnYD4kWTHjh0YNGgQRPQ9cQhxjYi2JzbwMDKu1keQ/pEACZCA5QT0/l5FwY3lY82SJGBrAqGhoTJis4jcfPLkSUNzYhNZ165dZcYFsbmMBwmQAAk4KgEKbii40cTcpuDGNsMQERGBESNGyDzrYuHMw8MDM2fOROvWrZlj3TbIadVBCeh9EUIMCwU3Djo52S3dEIiMjMSWLVtkNBuR3jHmcHd3lykZunXrhlKlSsXc5icJkAAJkIAJAnp/J6PgxsSg8pZmCIgIykJos3v3bulT2rRp0blzZ7mmkCdPHs34SUdIgARIgASsQ0Dv71UU3FhnHtAKCSSHwOnTp+U617JlyyBEN+IQ0Wvq1asnN5TVr18f4p2SBwmQAAk4OgEKbii40cQcp+DG+sNw4sQJtGnTBuKlRxx16tTBokWLUKBAAes3Rosk4OAE9L4IIYaHghsHn6TsnmYJ3LlzBwsWLMC8efNw8+ZNg58ffPCBTMnQokULZMqUyXCfJyRAAiRAAuYJ6P2djIIb82PLJylH4Pz58xg6dCg2bNhgcKJZs2YYM2YMihcvbrjHExIgARIgAccioPf3KgpuHGs+sjf6ISA2ef/4449SaLN3716D4zly5ECHDh1kRJsiRYoY7vOEBEiABFIDAQpuKLjRxDyn4MZ6wyB20E+YMAEjR46EOBfhn/38/KSi2MnJyXoN0RIJpCICel+EEENFwU0qmrDsaooTEFHl9uzZIxcffvrpJ/l9LJzKkCEDxA9YPXr0gHj3Yc7qFB8qOkACJKAzAnp/J6PgRmcTzsHdDQ4OlusGYmNOVFSU7O2nn36K8ePHo3z58g7ee3aPBEiABEhA7+9VFNxwDpOAfQlcu3YN8+fPl5vK7t+/b2i8SpUq8renb775Bi4uLob7PCEBEiCB1ESAghsKbjQx3ym4sc4wiJ1pXl5eEKGgxVG5cmUEBgZyV5p18NJKKiag90UIMXQU3KTiCcyu243A48ePsXTpUvj7++Ps2bOGdsXOnu7du6Ndu3YQO354kAAJkAAJJI2A3t/JKLhJ2rizlnUJhISEYOLEiZgxYwbCw8Ol8QoVKsiNO7Vr17ZuY7RGAiRAAiSgWQJ6f6+i4EazU4uOORABIcoWadHnzJmDX3/9FWKDmTjc3NzQunVrudZVpkwZB+oxu0ICJEACSSNAwQ0FN0mbOVauRcFN8oBGR0dj1qxZGDhwoFwwS58+PUaNGgUfHx/myEweWtYmAUlA74sQohMU3HAyk4DtCBw7dkxGs1m5ciWeP38uGxJR5b788ku5+CB2izPKnO340zIJkEDqIaD3dzIKblLPXNViT1+8eIGZM2fKCDZCJCwOkTJq7NixEDuSGXlPi6NGn0iABEjAdgT0/l5FwY3t5gYtk8C9e/cgoiCK9Ogisk3MUapUKRm1WYht3N3dY27zkwRIgARSPQEKbii40cRfAgpukj4MQUFBaN++PXbv3i2NlC5dWu6up7I46UxZkwTiEtD7IoToDwU3cUeV1ySQPALiR6u1a9dKoY3x/yNy586NTp06oXPnzihYsGDyGmFtEiABEiCBWASM/38b64FOLii40clAOZibItX04sWLIX6YvHXrluxdnjx54Ovriw4dOkBs2OFBAiRAAiSQ+gjo/b2KgpvUN2fZY9sSENFr9u/fL9e51q9fj5cvX8oGxbuiEGeLyM3VqlWjSNu2w0DrJEACOiVAwQ0FN5qYuhTcqB8G8QK0ZMkS9O7dG6GhoXLnvIhwIxbNMmTIoN4ga5AACZgloPdFCNExCm7MDi8fkIAqApcuXZIpo8QPV48ePTLU/fjjj+Uun8aNG/OHKwMVnpAACZCAdQno/Z2MghvrzgdaS5iAWDP48ccfMXToUIj00+Lw8PCQkXHFOkLGjBkTNsCnJEACJEACDk1A7+9VFNw49PRk5+xIQPy2tHz5cim0OXXqlKHlQoUKoWvXrlKgnStXLsN9npAACZAACcQnQMENBTfxZ0UK3KHgRh30u3fvokuXLti0aZOsKMJABwYGonLlyuoMsTQJkIBFBPS+CCE6ScGNRUPNQiRgkoDYGf7LL7/InNU7duwwlBHhc9u2bYtu3brh3XffNdznCQmQAAmQgG0I6P2djIIb28wLWo1PQETAHTRoEA4dOiQfik05vXr1kveyZcsWvwLvkAAJkAAJpDoCen+vouAm1U1ZdtjKBIS4Zu7cuVi2bBmePXsmrYsUo59//rmMZiM+06ZNa+VWaY4ESIAEHJMABTcU3GhiZlNwY/kwbNiwQf6w9+DBA1mpZ8+emDBhAtzc3Cw3wpIkQAKqCOh9EUJ0loIbVUPOwiQgCdy+fRsBAQHyz82bNw1UypYtK6PZtGjRgt+/Bio8IQESIAHbE9D7OxkFN7afI6m9hePHj2Pw4MHYvn27ROHk5CRTUIsfJfPnz5/a8bD/JEACJEACRgT0/l5FwY3RYPKUBCwkEBERAfH7khDa7Nu3z1ArR44c6Nixo4xoU7hwYcN9npAACZAACVhGgIIbCm4smyk2LkXBTeKAQ0JC4O3tjRUrVsjCYrFMpLOoU6dO4pVZggRIIFkE9L4IITpPwU2ypgArpyICIv3Cnj17ZDSbjRs3QkS3EYfYGd68eXO5y+fDDz9kzupUNCfYVRIgAe0Q0Ps7GQU32plLjubJ5cuXMWzYMKxevdrQtSZNmmDs2LEoWbKk4R5PSIAESIAESCCGgN7fqyi4iRlJfpJA4gSuXbuGefPmYeHChbh//76hQtWqVeWGsq+//lquexke8IQESIAESEAVAQpuKLhRNWFsVZiCm4TJit1pQmEcHBwsC3p5eWHGjBnIkiVLwhX5lARIwCoE9L4IISBQcGOVqUAjDkzg8ePHMj2jv78/zp07Z+hpsWLFZGS5du3aIXv27Ib7PCEBEiABErA/Ab2/k1FwY/854+gt3rlzB6NHj8b8+fMNIuGaNWvKKLiVKlVy9O6zfyRAAiRAAskgoPf3KgpukjH4rJoqCERFRWHr1q0ymo34FBvMxJEpUya0adNGrnWVLl06VbBgJ0mABEjA1gQouKHgxtZzzCL7FNyYxiRyZ/r4+ED8+CeOnDlzyoW0xo0bm67AuyRAAjYhoPdFCAGFghubTA0adQACR48elYsPK1euxIsXL2SPRPqFhg0bymg2IpKcuOZBAiRAAiSQ8gT0/k5GwU3KzyFH8eDp06fw8/ODmFPPnz+X3SpTpgzGjx+PevXqMRKfoww0+0ECJEACNiSg9/cqCm5sODloWtcE7t27JyPZiIg2QUFBhr689957MppN69atkTlzZsN9npAACZAACSSfAAU3FNwkfxZZwQIFN/Eh7t+/H23btoUIDS0OEQ5aCG88PT3jF+YdEiABmxLQ+yKEgEPBjU2nCI3rjIAQ1qxZs0amjTp8+LDB+9y5c6Nz587yT4ECBQz3eUICJEACJKANAnp/J6PgRhvzSM9ehIeHS6GwSBX18OFD2ZXChQtjzJgxMvUlRcJ
6Hl36TgIkQAL2JaD39yoKbuw7X9iatgmI6DX79u2T74nr16/Hq1evpMPp06dH06ZN5YYykT4qTZo02u4IvSMBEiABnRKg4IaCG01MXQpu/huGiIgIjBgxQu5WEy9KHh4emDlzJoTymC9E/3HiGQnYk4DeFyEEKwpu7Dlj2JZWCVy4cEGKV5csWYKQkBCDm5988olcfBAR5MRiBA8SIAESIAFtEtD7OxkFN9qcV3rwSqQEWLZsGXx9fXH9+nXpstiMM3z4cHTp0gXOzs566AZ9JAESIAES0BABvb9XUXCjoclEV1KMgIh6uHz5cim0OX36tMGPt956C127dkWHDh24gdtAhSckQAIkYDsCFNxQcGO72aXCMgU3b2AdP34cXl5eiHk5EmksFi1aBO6yVzGZWJQEbEBA74sQAgkFNzaYGDSpCwKRkZHYvHmzjGazc+dOg89C0CoiyXXr1g0lS5Y03OcJCZAACZCAdgno/Z2Mghvtzi2teiY24Yj3mCFDhuDMmTPSzUyZMsnU0/369YM450ECJEACJEACSSGg9/cqCm6SMuqs4ygETp48KUU2Qmzz7Nkz2S2xWbt+/foybVTdunWRNm1aR+ku+0ECJEACmidAwQ0FN5qYpKldcCN+DJwwYQJGjhwJce7q6ioj3HTv3h0MCa2JKUonUjkBvS9CiOGj4CaVT+JU2P1bt24hICBA/gkODjYQKFeunFx8aN68Odzc3Az3eUICJEACJKB9Anp/J6PgRvtzTEseirQAgwYNgkg3LQ4RxUasEQwdOhQ5c+bUkqv0hQRIgARIQIcE9P5eRcGNDicdXU4WAZEZQaSLmjNnDv766y+DLfFe2KlTJxn1UES24UECJEACJGB/AhTcUHBj/1lnosXULLg5f/68jGpz6NAhSaZy5coIDAxE8eLFTZDiLRIggZQgoPdFCMGMgpuUmDls094ExC7wXbt2yV0+GzduhEi/IA4XFxcIgY34kapixYpM0WjvgWF7JEACJGAlAnp/J6PgxkoTwcHNnDp1Ska0+eWXX2RPxW7lNm3ayA06/BHFwQef3SMBEiABOxLQ+3sVBTd2nCxsKkUJXLlyBfPmzZOZEB48eGDwRaz1inWur776ChkyZDDc5wkJkAAJkID9CVBwQ8GN/WediRZTo+AmOjoas2bNwsCBAxEeHo706dNj1KhRMjQ0w/2ZmCS8RQIpSEDvixACHQU3KTiB2LTNCYSEhEixqr+/P4SQNeYQ4lWRMqpdu3bIli1bzG1+kgAJkAAJ6JSA3t/JKLjR6cSzk9tBQUEYMWIEli1bBiEiFkeDBg0wbtw4vP/++3bygs2QAAmQAAmkFgJ6f6+i4Ca1zNTU2U+xgezXX3+VG8q2bdtmeDcU6US9vLzkWhffD1Pn3GCvSYAEtEmAghsKbjQxM1Ob4EYspLVv3x67d++W/EuXLo2lS5eiTJkymhgPOkECJBCbgN4XIURvKLiJPaa8cgwCR44ckaF0V69ejRcvXshOCdFqw4YN5S6f2rVrMzWjYww1e0ECJEACkoDe38kouOFENkXg/v37GDt2rPxB5eXLl7JIlSpVMHHiRFSrVs1UFd4jARIgARIggWQT0Pt7FQU3yZ4CNKBBAnfv3sXChQtlRJvr168bPBS/H4loNq1atULmzJkN93lCAiRAAiSgDQIU3FBwo4mZmFoEN2KX2pIlS9C7d2+EhobKHwFFhBtfX1+G/dPETKQTJGCagN4XIUSvKLgxPba8qz8Cz58/hxDYzJ07F0JwE3PkyZMHnTt3ln/y588fc5ufJEACJEACDkRA7+9kFNw40GS0QleePXsGMScmT54s1weEyVKlSsmINl9++SVTYFqBMU2QAAmQAAmYJ6D39yoKbsyPLZ/oi4D4zWjv3r1yQ9mPP/6IV69eyQ44OzujadOmUmgjxNgizSgPEiABEiABbRKg4IaCG03MzNQguLlz5w66du2KTZs2SeYizUVgYCAqV66siTGgEyRAAuYJ6H0RQvSMghvz48sn+iAgUkWJlFFCuPr48WOD07Vq1UKPHj1kVBuRnpEHCZAACZCA4xLQ+zsZBTeOOzfV9ExEsZk/fz5Gjx6Ne/fuyaoFChSQKabbtGkDpphWQ5NlSYAESIAEkkpA7+9VFNwkdeRZTysEnj59KlOJig1lZ86cMbhVuHBh+TtShw4dkDNnTsN9npAACZAACWiXAAU3FNxoYnY6uuBm/fr1Mq/mw4cPJe+ePXtiwoQJcHNz0wR/OkECJJAwAb0vQojeUXCT8BjzqTYJiF09QqgqFh9+//13g5MeHh4yNWO3bt1QokQJw32ekAAJkAAJODYBvb+TUXDj2PMzsd5FR0fLKH3Dhw/HlStXZPHs2bNj6NChcueyi4tLYib4nARIgARIgASsRkDv71UU3FhtKtCQnQmcOHFCrnOtWLECYWFhsnURveaLL76QG8rq1q3L9Oh2HhM2RwIkQALJJUDBDQU3yZ1DVqnvqIKbkJAQeHt7Q7w8iUOkuFi8eDHq1KljFW40QgIkYB8Cel+EEJQouLHPXGEr1iEQHByMgIAA+efWrVsGo+XLl5eLD82bN0fGjBkN93lCAiRAAiSQOgjo/Z2MgpvUMU/j9lKkCdi2bRsGDx6Mf/75Rz4W7zH9+vVD//79IYTEPEiABEiABEjA3gT0/l5FwY29ZwzbSw6B8PBwrFu3TgptDhw4YDDl6emJTp06oUuXLihUqJDhPk9IgARIgAT0RYCCGwpuNDFjHVFws337dnTs2BHiR0NxeHl5YcaMGciSJYsmmNMJEiABywnofRFC9JSCG8vHmyVThoDY9b1r1y65+PDzzz8jKipKOiJ2e7do0ULu/K5YsWLKOMdWSYAESIAENEFA7+9kFNxoYhrZ1YmDBw9i0KBB+OOPP2S76dKlkykChg0bhty5c9vVFzZGAiRAAiRAAsYE9P5eRcGN8WjyXKsELl++jHnz5mHRokWIyX4gfK1Ro4Zc5/rqq6/g7OysVffpFwmQAAmQgIUEKLih4MbCqWLbYo4kuHn27Bl8fHzg7+8voYk8myI/e+PGjW0LkdZJgARsRkDvixACDAU3NpseNJxMAiIa3JIlS6TQ5uLFiwZrb7/9tkzH2LZtW2TLls1wnyckQAIkQAKpl4De38kouEk9c/fs2bMyVdRPP/1k6LSI0Dd69GgUK1bMcI8nJEACJEACJJBSBPT+XkXBTUrNHLabGAGxgWzLli1ynUtEOYw5MmfOLDdli/To7733XsxtfpIACZAACTgAAQpuKLjRxDR2FMHN/v37IX4YFMplcTRp0kQKb0RoQB4kQAL6JaD3RQhBnoIb/c4/R/X88OHDmDNnDlavXg0RWlccadOmRaNGjWTaqFq1akHksOZBAiRAAiRAAjEE9P5ORsFNzEg67ufNmzchfgAUqaRF9D5x1K1bF+PHj0fZsmUdt+PsGQmQAAmQgO4I6P29ioIb3U05h3f4zp07WLhwoYxoc+PGDUN/y5QpI6PZtGrVCpkyZTLc5wkJkAAJkIDjEKDghoIbTcxmvQtuxA+Fvr6+8P
Pzg8jP7u7ujpkzZ6JNmzb8sVATM4xOkEDyCOh9EUL0/v/YOw9oqanu7W8R6b13RWmCUqQsEaWJICoIiFKk946KdARFOiIgSJemgBQpUqQKIgiLrhQBQUGqdOkC8n7vs79/8g7D3Htn7iSZ5PKcta6TyWTOOfklkj37PHtvCm7Cuwf4bWsIXL9+XWbNmqVRPtu3bzc7zZIli9arRt3qrFmzmvu5QQIkQAIkQAK+BLxuk1Fw43s149b2hQsXZNCgQeoHMITEJUqU0H3lypWLWyfLsyEBEiABEogTBLxuV1FwEyduQ8+fBNaC1q9frwFl8+fPlzt37ug5oUzUW2+9pQFlzz77LNeIPH+leQIkQAIkED0BCm4ouIn+DnHoUy8Lbnbu3KnCmr179yqtChUqaE3O7NmzO0SPw5AACdhNwOtOCPCh4Mbuu4T9R0dg//79mvFt2rRpcunSJfNQPDNbt24tVapUkUceecTczw0SIAESIAESCETA6zYZBTeBrqq390FMPHLkSBk8eLD8/fffejJ58+aV/v37S40aNbi44u3Ly9mTAAmQQJwm4HW7ioKbOH17uv7kYPdNnz5dfV379u0z55szZ04tj964cWNJnz69uZ8bJEACJEACcZsABTcU3LjiDvei4AZqZUSwffTRR6pcTpw4sWa4wcJhvHjxXMGVkyABErCGgNedEKBAwY019wJ7CZ7A7du3ZdGiRZrN5vvvvze/mCpVKoHjATWr8+TJY+7nBgmQAAmQAAnERMDrNhkFNzFdYe98Djtn8uTJ6g84deqUThxZ+rD416hRI4kfP753ToYzJQESIAESeCAJeN2uouDmgbxtI37SCL4eO3aszJgxQyC8RsNa0KuvvqrZbCpWrMi1oYhfJU6ABEiABJwnQMENBTfO33UBRvSa4ObAgQPSoEED2bJli55NyZIlBVH7uXPnDnB23EUCJOB1Al53QoA/BTdevwu9M//jx4/LxIkT9c9YgMLsixcvrtlsatWqJUmSJPHOCXGmJEACJEACriHgdZuMghvX3EqxngjKBsybN0969uwpv/32m/YDMXH37t2lffv2gkAcNhIgARIgARLwAgGv21UU3HjhLosbc0S50Dlz5qjQZvPmzeZJZciQQZo3b64l0nPkyGHu5wYJkAAJkMCDR4CCGwpuXHHXe0Vwc/fuXRk9erR07dpVYGih/EXfvn2lc+fO8vDDD7uCJSdBAiRgPQGvOyFAhIIb6+8L9vg/Ang+rlmzRmtWL168WP7991/9EItOderUUaFNsWLF/vcFbpEACZAACZBALAh43Saj4CYWF91FX4Gt061bN9m2bZvOKlGiRNKxY0f1D6ROndpFM+VUSIAESIAESCBmAl63qyi4ifka84jwCBw+fFhLRk2ZMkXOnz9vdlamTBn1c1WvXl0SJEhg7ucGCZAACZDAg0uAghsKblxx93tBcHP06FEtgbF27VplVrBgQa3TWahQIVcw5CRIgATsI+B1JwTIUHBj3/3xIPd84cIFmTp1qkb5HDp0yESBUlEosdiwYUPhApSJhRskQAIkQAJhEvC6TUbBTZg3QIS+vn37ds1gs2rVKp0Bgm2aNGkiffr0EZSRYiMBEiABEiABLxLwul1FwY0X7zr3z/nOnTuydOlS9XOtWLHCnHCKFCm04gHKoxcoUMDczw0SIAESIAESAAEKbii4ccX/CW4W3CBlNBYTEbl25coVrcGJDDdwriVMmNAV/DgJEiABewl43QkBOhTc2HuPPEi947mIkoqoWT179mzN+Ibzx+JTtWrVtGZ1uXLl5KGHHnqQsPBcSYAESIAEHCDgdZuMghsHbhILh0DJqF69emkJAaPbmjVrSr9+/SRv3rzGLr6SAAmQAAmQgCcJeN2uouDGk7edayeNkuiTJk2SCRMmCEqlG61w4cLq50L25mTJkhm7+UoCJEACJEAC9xCg4IaCm3tuiEi9cavg5vTp09KyZUv59ttvFU2uXLk0q03JkiUjhYrjkgAJRICA150QQEbBTQRunDg25LVr12TWrFkqtNmxY4d5dojsbtGihTRr1kyyZMli7ucGCZAACZAACVhNwOs2GQU3Vt8R9vSHBReUjsaiC6Kc0cqXLy+DBg2S4sWL2zMoeyUBEiABEiABhwl43a6i4MbhGyYODoeAsnXr1qmfa8GCBabdhyDrt956S4U2WLdiQFkcvPg8JRIgARKwmAAFNxTcWHxLxa47Nwpu5s2bJ0gRaNTnbNu2rQwePFiSJk0au5Pkt0iABDxLwOtOCICn4Mazt1/EJ/7rr79qzepp06bJ33//bc7npZde0rJRVapUkfjx45v7uUECJEACJEACdhHwuk1GwY1dd4Y1/cLOGTJkiIwYMUKuX7+unRYpUkSFNrB7uNhiDWf2QgIkQAIk4A4CXrerKLhxx33kxVlcunRJg6rHjRsn8HkZ7fHHH1c/V6NGjSRdunTGbr6SAAmQAAmQQIwEKLih4CbGm8SJA9wkuLl48aK0b99eZsyYoaeeLVs2mTJlilSoUMEJFByDBEjAhQS87oQAUgpuXHhjuXhKt2/floULF8qYMWM02seYaurUqaVx48YqSM2dO7exm68kQAIkQAIk4AgBr9tkFNw4cpuEPMjNmzfl888/lwEDBsiFCxf0+0888YSWjkJ0c7x48ULuk18gARIgARIgAbcT8LpdRcGN2+8w980P2ZpRHn3mzJmmuBp2HgLJWrduLRBY0+5z33XjjEiABEjACwQouKHgxhX3qVsENytWrJCmTZvKiRMnlEuDBg1k5MiRkipVKldw4iRIgAQiQ8DrTghQo+AmMveO10Y9duyYTJw4Uf9QVtFoJUqUUOdDrVq1JHHixMZuvpIACZAACZCAowS8bpNRcOPo7RLjYCgXNX36dOnTp48cP35cj8+YMaP07t1bS2UmSJAgxj54AAmQAAmQAAl4lYDX7SoKbrx65zk77xs3bsicOXNUaON7z8Pma968uf7lyJHD2UlxNBIgARIggThHgIIbCm5ccVNHWnBz9epV6dy5s5bMAJD06dPLhAkTpFq1aq7gw0mQAAlEloDvD7LIziT2o1NwE3t2cf2bd+/eldWrV2s2m8WLFwveo0FYU7duXRXaFC1aNK5j4PmRAAmQAAl4gIDXbTIKbtxxk/3nP/+RRYsWSY8ePcwyAilSpJAuXbpIx44dJVmyZO6YKGdBAiRAAiRAAjYS8LpdRcGNjTdHHOj60KFDutaDygVGBkOcVtmyZdXPhXUfiqvjwIXmKZAACZCASwhQcEPBjStuxUgKbjZu3CgNGzaUw4cPK4vq1aurMZYhQwZXsOEkSIAEIk/A604IEKTgJvL3kdtmcP78eS2ZiJrVxjMQc8ybN686H/BsZIY3t101zocESIAEHmwCXrfJKLiJ/P27fv166datm2zatEkng4WWdu3aSffu3SVdunSRnyBnQAIkQAIkQAIOEfC6XUXBjUM3ioeGQfbCJUuWaEDZqlWrzJlDWA0fV6tWrSR//vzmfm6QAAmQAAmQgFUEKLih4MaqeymsfiIhuEGdd
qSOHjp0qCDCDYbXqFGjpH79+vLQQw+FdT78MgmQQNwi4HUnBK4GBTdx656M7dngeYf7GTWrZ8+eLf/88492FT9+fM3q1qZNG4324XMwtoT5PRIgARIgATsJeN0mo+DGzrsj+r5/+eUXFdUsW7ZMD4wXL56ghPRHH30kLCMQPTt+SgIkQAIkEDcJeN2uouAmbt6XsTmrkydPyqRJk7Q8ulEmFP0UKVJE4OeqU6eOJE2aNDZd8zskQAIkQAIkEBQBCm4ouAnqRrH7IKcFNzt37lRhzd69e/XUKlSoIJMnT5bs2bPbfarsnwRIwIMEvO6EAHIKbjx441k45WvXrsnMmTM1ymfXrl1mz1mzZpWWLVtKs2bNJHPmzOZ+bpAACZAACZCAGwl43Saj4Mb5u+qPP/6Q3r17y4wZMzTQBjOoWrWqDBgwQAoUKOD8hDgiCZAACZAACbiEgNftKgpuXHIjRWgaCChbu3atBpQtXLhQkN0GLWHChFK7dm3N3FyiRAkGVkfo+nBYEiABEnjQCFBwQ8GNK+55pwQ3MLwGDRqkUWzYTpw4sWa4ad26tSDCjY0ESIAEAhHwuhMC50TBTaArG/f37du3T50P06dPl8uXL5snXLFiRXU+vPbaa4LsNmwkQAIkQAIk4AUCXrfJKLhx7i47c+aM9OvXT8tF3759Wwd+/vnn1R9QqlQp5ybCkUiABEiABEjApQS8bldRcOPSG8vmaV26dEmmTZumNt7+/fvN0XLlyqUloxo1aiRp06Y193ODBEiABEiABJwgQMENBTdO3GcxjuGE4ObAgQOaMnrLli06n2effVaNszx58sQ4Px5AAiTwYBPwuhMCV4+CmwfnHr5165YgumfMmDHyww8/mCeeJk0aady4sWa0yZ07t7mfGyRAAiRAAiTgFQJet8kouLH/Trty5YoMGzZM/65evaoDPv300zJw4EB55ZVXGOVs/yXgCCRAAiRAAh4h4HW7ioIbj9xoFk1z+/btGlCG7M03btzQXhFAjcyFCKZGBQMGVFsEm92QAAmQAAmETICCGwpuQr5p7PiCnYKbu3fvyujRo6Vr165y8+ZNeeSRR6Rv377SuXNnefjhh+04HfZJAiQQxwh43QmBy0HBTRy7KQOczp9//ikTJkzQutV//fWXeQSesahZ/eabb2pmN/MDbpAACZAACZCAxwh43Saj4Ma+G+6ff/6R8ePHa1abs2fP6kCPPvqofPzxx1K3bl3+9rcPPXsmARIgARLwKAGv21UU3Hj0xgth2hDWzJ49WwPKtm7dan4zU6ZM0rx5c/3Lnj27uZ8bJEACJEACJBApAhTcUHATqXvvnnHtEtwcPXpUo/lRzxOtYMGCgrIahQoVumd8viEBEiCB6Ah43QmBc6PgJror7N3PICpduXKlRvksWbJE8B4tSZIkuriEKJ9nnnnGuyfImZMACZAACZCADwGv22QU3PhcTIs2//33X0Gkc+/eveXIkSPaa7p06aRXr15aViBhwoQWjcRuSIAESIAESCBuEfC6XUXBTdy6H33P5uDBg1oyaurUqXLx4kXzo3Llymk2m2rVqmlQtfkBN0iABEiABEggwgQouKHgJsK34P8f3mrBzX/+8x+BQdaxY0dBSmmkE0SGmz59+ggdbq645JwECXiKgNedEIBNwY2nbrkYJ3vu3DmZMmWKRnIfPnzYPD5fvnyazaZ+/fqSKlUqcz83SIAESIAESCAuEPC6TUbBjXV3IX7zL1u2THr06CG//PKLdpw0aVLp1KmT/qVIkcK6wdgTCZAACZAACcRBAl63qyi4iVs35Z07d2Tx4sWazWb16tXmyaVMmVIaNmyoQuonn3zS3M8NEiABEiABEnATAQpuKLhxxf1opeDm9OnT0qJFCzXQcHK5cuXSrDYlS5Z0xblyEiRAAt4j4HUnBIhTcOO9+85/xlhY2rx5s2azmTNnjqB0Alr8+PGlRo0aGuVTpkwZeeihh/y/yvckQAIkQAIkECcIeN0mo+DGmttw06ZNGlDz448/aocoG92qVSvp2bOnZMyY0ZpB2AsJkAAJkAAJxHECXrerKLiJGzfoyZMnZeLEifp34sQJ86SQrRnl0WvXri0QVbORAAmQAAmQgJsJUHBDwY0r7k+rBDfz5s1TR9v58+f1vNq2bSuDBw+mUeaKq8xJkIB3CXjdCQHyFNx49/67evWqlkoYM2aM/Pzzz+aJoE41BKZNmzaVzJkzm/u5QQIkQAIkQAJxlYDXbTIKbsK7M/ft26cZbRYtWqQdQWRct25d6du3rzz++OPhdc5vkwAJkAAJkMADRsDrdhUFN969YRFQ9v3332tA2cKFCwUlQtESJUqkAhuURy9evDgDyrx7iTlzEiABEnjgCFBwQ8GNK276cAU3qOXZvn17mTFjhp5PtmzZtNRGhQoVXHF+nAQJkIC3CXjdCQH6FNx47x7cu3evOh+mT5+u5RGNM6hUqZJms3n11Vc1u42xn68kQAIkQAIkENcJeN0mo+AmdnfosWPHtDz0tGnT5O7du9pJ5cqVZeDAgVKoUKHYdcpvkQAJkAAJkMADTsDrdhUFN967gbGGA3tu7NixcvDgQfMEcufOrUHUjRo1kjRp0pj7uUECJEACJEACXiFAwQ0FN664V8MR3KxYsUKj+42Ugw0aNJCRI0dKqlSpXHFunAQJkID3CXjdCYErQMGNN+7DW7duyfz589X5sH79enPScDg0adJEWrZsqaUSzQ+4QQIkQAIkQAIPEAGv22QU3IR2syJzLUQ1o0ePNktpwneALLYoo8lGAiRAAiRAAiQQewJet6souIn9tXf6m9u2bRNkbf7666/lxo0bOvzDDz8sVatW1YCyF198UeLFi+f0tDgeCZAACZAACVhGgIIbCm4su5nC6Sg2ghuU2OjcubOMGzdOh06fPr1MmDBBqlWrFs5U+F0SIAESuI+A150QOCEKbu67rK7acfToUX2GTZo0Sc6cOWPO7dlnn1Xnw5tvvimJEyc293ODBEiABEiABB5EAl63ySi4Ce6uvXbtmowYMUKGDBkily9f1i/ly5dPBgwYoL/3UUqKjQRIgARIgARIIDwCXrerKLgJ7/rb/e3r16+rwAbZbCC4MRpKojdv3lz/UKWAjQRIgARIgATiAgEKbii4ccV9HKrgZsOGDYIUg4cPH9b5V69eXYU3GTJkcMX5cBIkQAJxi4DXnRC4GhTcuO+eREkEZGmD82Hp0qVmiYQkSZLI22+/rUKbIkWKuG/inBEJkAAJkAAJRIiA120yCm6iv3Fu374tEB/37dtXTp8+rQdjIeajjz4SZLKNHz9+9B3wUxIgARIgARIggaAJeN2uouAm6Evt6IEHDhzQdZqpU6fKpUuXzLHLly8vbdq00aw2jzzyiLmfGyRAAiRAAiQQFwhQcEPBjSvu42AFNzdv3pTevXvLJ598Iv/5z38kRYoUMmrUKKlfv74wys0Vl5KTIIE4ScDrTghcFApu3HNrnjt3TiZPnizjx4+X33//3ZzYk08+qc4HPNNSpkxp7ucGCZAACZAACZDA/yfgdZuM
xYaNsyk4HINm9q2/RScvCbrSMlB9RVy7fqT8nB6Ik//elPiWTwkLdLzo9Bl+v000/3rWuq7nnkkUdWy1syKKrGfbTfpptuWm0/bQi6fTbrSdiIQA4CBNzkgBSnt8yaNStRr169Kh/Kl112WUFF1AfZL3/5yyrHSs6AklBjUXKUeOKUU05JbLzxxlX+roCNYgIiDjjggCrHU4fSvHnzCsp/sTv5NYrpCyo5g1CVfHoNN5mPyv9vf/vbnLJSam8FtKR/YavSkRzBnvj973+fOOOMMxLeF316GZIj57PmPTmjT+L666+3DUHp79fzfANu3nrrrcSWW25ZxVMdoNqmxqD0PHvHP/nkkxPJEfZZ88ZGBBCoHIGgvvPK/Xmf7QpV8vefPNQodMUVV9gfX+nfK/rBrE4HBVwquDf9b3qenF0vcfPNNyeSM85lY00MHjy42j6Zx8j2WvdWMftmy0y2wOVcA26WLFmSOOeccxL6fsyW38xtOldyZE8iOUKrWlZ0rL/+9a82UFpBIZn76rWuR3JUTSLzvsz2XtU7H3vssWrn8dsQZFn8zlGqukWuATeqJ91xxx2JgQMHZvXV9dlll11sXVpTfqshioQAAghETSCoelZmuUr1GZ15Hr3O/B4Lq32h3HVN2hZ+6ugKs21B1+E3v/lN1qBrr07VvXt328bmtXstW7bM1kGTMzpmrUN4+ylYRQHT48aNq3brf/zxx7ajQ/Vk7/3eY+ZyqkG27dxzzz3VBtEpuGm77baz/9LbKv0CchQkVlsqR12ytjzwdwQQiJ9AKeo95fpczLwaYdd/+P4ziUr//ivVva9grgcffNAOrMtsH/MCbt577z07AMqr+2Q+aoBUZipVm6TOo0AaBQYlZ6mpVi/zAm7k5TewLjP/22yzTSI5801mEbK+LlW5Hn300YSCbjInT/DyOmDAgGr5+eqrr2y9WIMIM/v7vP1U/82Wgm6fzXYOtiGQiwABN7koxeg9f/zjH6t9cD/zzDN5l1AVo8ygEv2YTy4/VOVYmsFFHWHeh6Ie1SGm0SqFpGHDhlU5lo73l7/8pZBDFb1Ptkaxo446KjWTS3qZa3t+55131pifUnur0SK9EpKtE06j+TNnv0kuD1Ul3wo22mmnnapdo/Ty5xpwo1HzZ511VpXGJ91LyeUsEkuXLk2dV1/GLVq0qHZOBeOo0kBCAIHKFQjqO6+cn/d+V6tSv//koc/8zAYhBfC+8MILCc3A4iW9T07pjfXe948ChPXDOzNptLDqEfp36KGHVvsuUdCD9/f0xxUrViSK2TczH3pdaMCN8uHNIqfyasTHI488YkfGKGBH35t+s/ho/fnM70rVSTw3v0fNiKcR2X5/z9yu7+n07+5s5de2oMuSeZ5S1y1yDbjx8qVR6pmNJgceeGBRM0J6x+YRAQQQKLVAUPUsL5+l/oz2zpP+GJX6VTnrmrQt/C/YJkptCxqslFl30mvVEfw6TDSrsgLOs+2n2aWzBVWn3/t6rrbA9P3VWeGlINt2dMxLL720yrk0OE8B3un1QwXKnHrqqVXel54/Pa8t4KbUdUnPh0cEEKg8gaDrPeX6XMx2paJS/+H7739Xp9K+/0p172uGIL9ADdUhFHCjvsrmzZvXWNfQ/+vpqZRtkmr/yazrpL9WwI2CcdK35fJcbZlqt6wplbJc3nmnT5+edVbDbAE33j56nDlzZkJB55ll9Qu4Cbp9Nj0vPEcgHwECbvLRisF7NV1X5geVplfLN+kHfPpxNDJm9erVWQ+j6WrT36vnCugoJGm0W+axMkfgFHLcQvbJ1iiWnjd9iauDSyOWtt1226wdad77FdX8yiuv+GajlN7qaNPMRF5edtxxR9+p8zSaKv29b775ZpU8a6RUr169ElqmbOedd04d0zu2HnMNuNHMOun7aSmr5NryVc7nvfCreNQWyOTtzyMCCMRTIKjvvHJ+3vtdiUr9/pOHfjSnfx8ooGTixIl+VImxY8cmGjVqVGUf7a/vEb+ZbnQwzeqWfh49z7WOUcy+XkEKCbjRj9f07+Vjjjkm649qdWJqia3M8un1fffd52XBPuqHrUYEPfnkk9WCq739vdmEFKyrjpE33njDLt/17LPPZp2WV/vddNNNVc6T+aIUZck8R6nrFvkE3KxatSqx3377pa6JRnLLnYQAAgi4IhBUPcsrb6k/o73zpD9GpX5VzrombQv/qxNGqW1BA9eyLZ+tJRBqSlOmTMnasXT77bfXtFvqbwrM9up2ar9KnyU4yLadzI5dzaBY05KjWq7Uy1fmY00BN+WoS6bweIIAAhUnEGS9p1yfi34XKSr1H77/ql6hSvj+K+W9369fPztY3G8wuNqvtt9+e1vHUF1E/XbpA9C9OoeWNkpPpWyTVJ9W//797Ux/3vn9HtVmqDJed911Ca1Yoj68bPn39j/iiCPSi1HteSnLlX4yLSXq5cl7rC3gRvvfdddd1fbzC7hJP18Q7bPpx+M5AvkIEHCTj5bj71VnS+vWrat8UGkUeC4jX9KLrgCNzCUG1LFVU8qcPkzT2xaSFNSTGamqUd1hpJoaxbQE05w5c6pkS40XNU1v1rZt26xBS6X2VkOK92Wnx/POO69KvjNf/O1vf0u9/9VXX838c5XXZ599duq93jlyCbjR1H7pMxTUNltNekORdx49HnTQQVXywwsEEKgcgaC+8yRWrs/7mq5OJX7/yUOz2KR/ruv5u+++WxOV/ZumL83cT68vvPBC332L+VFWzL5ehgoJuEmfuaZDhw4JBcb6JY1eybY8pEbMaKaVbClbQ5znqrpcZuCtjqFRNNnWwlZAbk2p1GUpR90i14Abzba06667pu5RrbWtGftICCCAgCsCQdazVOZyfEZns41K/apcdU3aFoxdlloOfimstgUt8e3VsdIfFURSU9Isy+nv13ONls4lPfTQQ6l9Ne1/TanQth3lP3M2v9dff72mU9m/qZMos1x6XVPATanrkrVmmjcggEBsBYKs95Tzc9HvgkSl/qP88f1X9SrF+fuvnPf+KaeckrUeobrEFltskdAScUrq20qfMbpHjx5VZoEuZ5uk6mLZ6j7aptmpJ0yYUPVmSb7S0p9dunTx3U+zU2dL5SxXtoHyuQTcPPfcc9XKRcBNtqvJtigJEHATpatR4rxoaq3MD20tB5Bveuedd6odR8fV1MR+KXP5Kb0/ffSM337ZtquDJ7McKlu5k1+j2NVXX11jVi6++OJq+ffK89hjj1Xbt9Te999/f5X8aAaAmpJ+ZOjLTXl+/vnna3prQhHBXtm8x9oCbjRlcuYyFRo5X1vKts6ltpEQQKAyBYL6zpNeuT7va7tSlfb9t3LlymrfB/vss09tTKm/77nnntW+g/Rd9Mknn6Tek/6kmKCZYvb18pBvwM2oUaOqlO+BBx7wDuX7qBkGve/j9Mfhw4dn3UfBM+nv854roMZroMi24+OPP15tv5qCvEtdlnLVLXIJuNESDel1FjU0pC+nkM2TbQgggEDUBIKsZ5XrM9rPMAr1q3LVNWlbMHZWPr97wdue/j3t1X1K3bagY
NzMwBSd+8Ybb/SylfXxX//6V7U6l2ZQnjt3btb3p29MX7J1/Pjx6X+q9ryQth0dRDM+e4Z61EzIuSTNgJO+n/fcL+Cm1HXJXPLMexBAIL4CQdZ7yvW5WNvViEL9R3nk+6/qlYrz91857/1sASWqS2h5+sw6kn6LPP3004nRo0dXmRW73G2S2QJMlGfNvF3TbN1aJssv6Eazy2Smcpfrtttuq1anyyXgJlvdk4CbzKvJ66gJEHATtStSwvxkG+2tpX/yTZ9++mm1D0nNOlPTuoCaoeT/2jsXqCuqso/PpyB5AW9YJp8mlnhFwVtq3peYgBcMM1MDYSlmAUuRQpdlluCFZRaihSSmpmVm6CpDSLNYZUqIrjRzYXgjNbxUYgRe0vnmf/r2cS575sy8Z2bOnHN+ey3Yc9u33553P8955tl7mx/IJo7zsGxUH9uycDfeeGOjZLnftxnFpKymCRKUhoM/PuCAAyLJi+Y9bdq0QF20Pdjy5csj9fBfmDJlSi2NlJGkYKt7I4ebM888M1AfzVJKE+66665AOnHVSkMECECgOwnkJfNEr6zxvlFPdZv8szmHLFy4sBGm+n1tieSXseZYcsYWmnGaaSatqUtWhxvtyWzapDi8sp7J1x8vWLAgkMak1weXuBBeWVBpHnnkkbjHa9flVG3y9sdx25gW3ZaydItGDjdy9tJKRIaJVhXMutJkInhuQgACECiJQJ56VlljdByaKuhXZematt/nedpysC3EvWXprp966ql1HcHoCppprUlPcUF6hH97UZNOW7QlhRdffLG+erS2QG8UbO9OI9uO9MWwfnvrrbc2Kqp2X1tvyjZl2mPiOIebonXJVJXmIQhAoGMJ5KX3lDkuNuqMKug/po7IP0PCdTtV/pX97i9ZsiSiQ0iXCG+p/j756FHZNsn7778/Umfp6WlsfXHtlTN3eNGDstuFw0303eJK5xLA4aZz+zbSMu3jbH6kmrjRXn6RTLwL+rGvVVBMHorPOecc26P1azaHm0YzaOqJQwejRo0KlK3yNXCXHWxGsTFjxqSqhpyN/Pz8x1oKzh+K5q1Z7f7ydaztrZI+pmk7DxlONJsqKdiWY04yyshpq2/fvoH63H777UlFBO7Nnj3bHTRoUG3m+PTp0xO31ggk5AQCEOg4AnnJPIEpa7xv1AndJP/efvttd+ONNw7IA8mnV155pRGm+n3tB77VVltF8tA2mJrBEg7NOM00k9bUI/xBQu2VHLUFOa6En5e8bPRvzpw5ER4qZ7PNNrMVU7tmm22tjy5JQbpLnz59ImVpi6pwKLotZeoWSQ430pn69etXY6K+u+KKK8IoOIcABCDQNgTy0rPKHKPj4FZBvypL18S2UG3bgrZakl4W/tdoCybbdk+777573Ctfu66Vc0w5aT4+ZbXtqBCtTGnKULz55pu70vHThgEDBgTSKw+bw03RumTa+vIcBCDQuQTy0nvKGhfT9EQV9B9TT+SfIfHfuBPlX9nvviY7+XUQHcsO8/LLLwdhx5y1wiapLdvDddYW8GnDvvvuG0mv/Pw7VLSiXTjcpO1BnusEAjjcdEIvpmyDzXtRHsQ9CVrGTLPML7vsMnfRokWJWWjp/J122iky4Gs5456E008/PZKXZlKVHWxGMRk60gatLhQWojrXknfhUCTvuNnum2yySc2RKYtBJFzvrEYZ22oEjT7qhcvkHAIQgIAI5Cnzyhzvk3qvm+SfthIMy0jJpaxh6tSpkXyUr82ZsxmnmWbSmjaFHWhUzziHm6uvvtrarjCzLOerVq0yVQnEPXG4UQa22da2jyRFt6VM3SLO4Uazuc0sbc1Qsm0hGoDOCQQgAIGKE8hLzypzjI5DWgX9qkxdE9tC3JtQjeuaQBTW38aOHZtYuREjRkTSKI+HH37Ymk7vgNl6QA7ub7zxhvU5/8Wsth1tPRpuhz4EZQm27U5aoUtmqTPPQgACnUkgD72nzHExTS9UQf/x1xP59z6NTpN/rXj3bQ43G2644fuQGxy1wibZrMPNDTfcENG9pItpgroJrWgXDjeGPnE3EMDhpht6+f/bGF6qWQOubR+/PJDoB7z2PdQP//XWW8862PfU4ebzn/98JL+i2pHEolmjmD62hA0QOr/mmmuSirXea4b3O++842qZYltddG3nnXd2tWdiT0JWo4yN6Zo1a3pSNGkgAIEuJ5CnzLONTVkcLPMa77tJ/l1yySURubTLLrtkfqttS7JKtl1++eWRvJpxmmkmralIFoebM844I8InTo6nvf7QQw+ZqgTinjrcpDUSFd0W299vH8EiPgAAJ0RJREFUUbqFzeHm8MMPD6xG1AqdNdChnEAAAhDIgUBeelaZY3Rcs6ugX9k4tELXFCNsC3FvSjnXZ86cGdHxkpxinnjiiYCe4df7Jk6caK20bHXmubR6SVbbzi233FIvw5Q1evRoa33iLlZFl4yrH9chAIHuIZCH3lPmuJimZ6qg//jrifx7n0anyb9WvPvNOty0wibZrMPN2rVr6xO9jO6leMqUKfWXqxXtwuGmjp+DLiCAw00XdLJporaP8g+2Ov7iF79obucSa6uHGTNmuNttt129LNuWECq7pw43NmPU8ccfn0v9s2Riq0cWo5i2jgr3h841Gz9tyIv3HXfcEWukMXXU+6PVirKErEaZ8PKCW265ZZbieBYCEIBAnUCeMq8K470aZqtHp8o/m/Fnxx13rPdv2oMVK1ZYZa3yD4dmnGaaSWvqkcXh5sADDwy0a9y4cSab3OOiHW6KbkuZuoXN4cboUSbW1plPPvlk7v1EhhCAAATKJJCXnlXmGB3Hpwr6la0O2BaCPVY124JWAr7ooovcCy+8MNM/zT5OCtrqoHfv3gE9TzrEvHnzrMnGjx8fedboHLKnaIvVcBg5cmQ9zZIlS8K3redZ+Z999tn1Mkx9Gm1FHy447QfHonXJcL04hwAEuo9AHnpPmeNimh6y6R6tsC+ZuiL/DAnXrYr8y0vXacW736zDTStsks063OgNGjhwYET/OuGEE+ovVyvahcNNHT8HXUAAh5su6GTTRNvenFmMOCYfW6wf3/qh36dPn/qgrqUAtRe0lqc9+OCD69fNj+2eOtzYtoc46qijbNUq9JpNMc3CU16nhoU//tznPtew3kXwvvLKK6318ddNS+9Nnz7d1ao4aUJWo4zfUUvlqjztM0+AAAQgkJVAnjKvleO9v93dJP80A9Yvf4xM8PNIc7xu3bpIPsrr6KOPjiRvxmmmmbSmIlkcbsIGmCINY0U73BTdljJ1izQON3r/tHpgmu0bzLtBDAEIQKBqBPLSs8oco+MYVkG/aqWuiW0h7s1Ivm77kBPWXW3nRx55ZHLG3t0TTzwxor9qS/JweOmll+ozmWU7+eAHPxhJp8lV/qBJVGYV6iFDhvhvJR5nte0cd9xxkbp8/etfTywjfDOsI4qnbUup8HNF6sXhOnIOAQh0B4E89J4yx8U0vVIF/SdcT+Tff4mE5Vqr5F9euk4r3n1b3bNsKdUKm2QeDjeHHHJIRP/SqssmtKJdONwY+sTdQACH
m27o5f9vo5aKDf/gT7t8bBym1157zT3rrLPqP9iV/1ZbbeVqqbh33323nixPh5vJkydH2nHSSSfVyyrroFmjmOqpGUfhPkladaho3t///vfdjTbaKFKncB33228/d/ny5Q1RZzXKbL755pGyV65c2bAcHoAABCAQJpCnzGvFeB9uj867Sf7JkTYse3QuOZg1SC8J5zVs2LBINs04zTST1lQki8NNWF7uscceJpvc46IdbopuSzh/vQtF6RY2h5uPf/zjkfdPddAsSQIEIACBdiWQl55V5hgdx7oK+lUrdE1sC3FvRLrrjz76qFW+h3XO8Hkah5tFixZZ837qqacClTv//PPrz8mOdNVVV9XPTbnHHHNMII1W5DH3rr322sC9pJOstp3wqjMqM8nWZSs77QfH8DhSpF5sqyfXIACBzieQh95T5riYpkeqoP+E64n8+y+Rqsi/vHSdVrz7zTrctMImmYfDzSmnnFLX84y+57c9taJdONyERzrOO5kADjed3Luhtvl/jJsB99RTTw09lf70vvvuiziMjB071v373/8eySRPhxvbvqlZVpaJVK6HF/IwiskQYPrCxHGzfsrirT3Ahw4dGqmXqZ+JN910U3fZsmWJ9LIaZbbYYotIubYZTImFchMCEICARyBPmVf2eB/Xgd0k/7Tam5E3/lg/mrOGPffcM5KXfoSGQzNOM82kNfXI4nATdtjVTB2t5lNEKNrhpui2lKlb2Bxufv7zn7vhLVPMO33ZZZcV0WXkCQEIQKBwAnnpWWWO0XFQqqBfla1rYluIexvSX8/rI5StRE1e23777SP66wUXXFB/XCvlbbbZZrVnevXq5T777LOuthwPb0ele6tWraql00rFW2+9dS2NJlqtXr26nl+jg6y2nV133TVS/2OPPbZRMYH7aT84Fq1LBirFCQQg0JUE8tB7yhwX03RSFfSfcD2Rf/8lUhX5l5eu04p3v1mHm1bYJPNwuNE3UmNvMvGECRPqf2qtaBcON3X8HHQBARxuuqCTTRNtWwZpScSeBK1gs/766wcG8KT9mPN0uLEJBn3oKjvkYRQbPHhwgKEE4fe+971IU8rm/Z///MedPXu2G54pZAS1ifv37+++8MILkfqaC1mNMraPollmXZlyiSEAAQjkKfPKHO+Teq6b5J/NgUWyR84LWYNN1tp0FluZp512WqrimklrCsjicLPXXntF9If58+ebrHKNi3a4KbotZeoWNoebBQsW1D50mQ9cRodSrC0dfvnLX+baX2QGAQhAoAwCeelZZY7RcVyqoF+VqWtiW4h7E7JdlxOLJktl/feDH/wgVUGXXHJJRNcbMGBAfSVp/2o2fkdy21YB+ntV0PZSRg/RlvBZQlbbjs0GmHXlmbQfHIvWJbNw4lkIQKAzCeSh95Q5LqbphSroP7Z6Iv9ctyryLy9dpxXvfrMONzYbX9E2yTwcbj796U/XdT2j81188cX1P7VWtAuHmzp+DrqAAA43XdDJpok33nhjZMDdf//9ze3UsZax3XjjjQN5aT9pzZaJCzbB+vvf/z7u8cTrtn0fZ82alZimiJt5GMXMjCQjABXLkOEPreSt1YomTZrkalaUv47+44kTJ/qrGzjOapSx7Ykro8x7770XyDftSVGz/dOWz3MQgEDrCOQl89SCssb7RrS6Sf7NmzfPKncuuuiiRpgi922z56+//vrIc7YfnlV1uLH9iP7MZz4TaVPaCzJGXHfdde6///3vSJKiHW6KbkuZukWcw42gasU+Odj4dSgda1b2c889F+HOBQhAAAJVJpCXnlXmGB3Hswr6VVm6JraF9rEtaGJTeJKb9IZ77rmnZnvbbrvt6jrFH//4x/rrLUffsK4h53MFbWdl7j300EP1NGkOstp2xowZUy/LlNmvXz9Xk7vShrQfHIvWJdPWl+cgAIHOJZCH3lPmuJimJ6qg/9jqifxL73DTLvKvFe9+sw43rbBJ5uFwY9vS/N57763/qbWiXTjc1PFz0AUEcLjpgk42TbQN2vqRnjVoD2jzg9nEN9xwQ2I2eTrc2Gav9GTGe2KFU9xs1iimJYANPxNr2eBwKJr3nDlz3N122826FZipy5NPPunatr9SveU0FBeyGmXOPffcCBOVcdttt8UVEXt96dKlbp8+fdyZM2fGPsMNCECgcwnkJfNEqKzxvlFvdJP8k1OukY3+eIcddmiEKXB/7dq11nxWrlwZeE4n7eRwM23atEi75Axt29Yz0tDQhddff90dOHBgLb9HHnkkdNd1i3a4KbotZeoWSQ43AquZRf732Rzrbxsn4cirxwUIQKDCBPLSs8oco+NwVkG/KkvXxLbQXrYFbcFkdAUTn3TSSa5WKTLnw4cPD7zacmixOarIpmJWU8y60owKyGrb+drXvlavo6mr4p/+9KeB+iad2Nph23K8aF0yqY7cgwAEuoNAHnpPmeNiml6pgv4TV0/k3/9GZGg7y79WvPvNOty0wiZpG2c+9KEPxf2ZWK9vs802gXdHW436J9a1ol043Fi7iosdSgCHmw7tWFuz3n77bXeTTTYJDLobbLBBptVDlIf2evb/YNaxhFhS0Ao44TQ9XeFmq622CuSl2cL6WFR2aNYopllIYSZf/vKXA80og7dZqvHOO+8MlB0+kYPQiBEjInVWG+I+8GU1ysydO9ea/6BBgzLNhFJ9PvKRj9TySrtkc7i9nEMAAu1NIA+ZZwiUMd6bspLibpJ/WtlMY39YTur8d7/7XRKmwD2bHDKzfAMPeic2h5u0q8Y0k9bUw3wE8bdZ9bcFrUbjf84ca1norEFtVHrNeraFoh1uim5LmbpFI4cb7Ul/xBFHWPvu9NNPt+HnGgQgAIFKEshLzypzjI4DWQX9qgxdE9uCU5O/7WRb+NnPfhbRGWTD869us3jx4sirfeGFF0bSGV1R8TXXXBNJ0+iCTaeW7hoXbrrpJmsdNBkvbUjrcFO0Lpm2vjwHAQh0LoE89J4yx8U0PVEF/Seunsi/dA437SL/WvHuN+tw0wqbZLMON1odKryqctim2Yp22RxuNAmgUZCTtl9/1XGaxSPysM82qhv3IRBHAIebODIden3kyJGRgeqxxx5L3do4L8g//OEPiXkMHTo0Uu4DDzyQmMZ2U7N/wwPt3nvvbXu08GvNGsXOO++8QFvkyPTyyy8H6l0Gb+Nwc9ZZZwXKtp1ottSOO+4YqLf649lnn7U97mplnHB/JRllVq9e7fbt2zeSRnnoI+Kbb75pLcd/UY5Bhx12WC0PzdzSxy0CBCDQnQSalXmGWhnjvSkrLu42+ScO3/rWt6zyII28MhyNjPPLIv3YtwVtV+V/TsfDhg2zPRq51kxak1kWh5vHH3+8Pks5XOc77rjDZNkwvuCCC+ptXrJkifX5oh1uim5LmbpFI4cbAf7b3/7mapZSuN90/p3vfMfaB1yEAAQgUEUCeehZZY7RNoZV0a/K0DWxLfzX4Ubytl1sC7K/DBgwwKozqB1xW8Q//fTTsXqi7E49mbCW1bbz4osvuppVbdN3li1bZvtzjFyztX3hwoWR54r
WJSMFcgECEOhKAs3qPWWOi406qCr6T1w9kX9R2d/O8q8V736zDjd6N8u2STbrcBN2NJGTtnTCcCi7Xbbydtppp3C1Iuc2R51tt9028lz4Qh722XCenEMgLQEcbtKS6pDnZs2aFfnBm2V2yy9+8YtIev2AnjFjRiyhv/zlL64G+PAPbf/+gbGJQze0z3Q4n6lTp4aeKue0GaPYP//5z4hjiZbBDYcyeJuPkfJsT+PQIt7+PtCWUvKOtQVtTeF/1hy/9dZbtsdr1yZNmmRNo7T77bef+9JLL8WmlQI3ZMiQenrbcouxibkBAQh0HIFmZZ4BUsZ4b8qKi7tN/onDP/7xD1fbJBnZYeLNN988lbyyLamvLaneeecdK+bZs2dHytpll12sz4YvNpPW5JXF4UZpRo0aFamvGPXq1cuNcyoyZWmG3vjx4+vpjz/+eHMrEtt0OK3S1yhsueWW9fxN38XJ5aLbUpZukcbhRtzuu+++yMwjMRLrLCs4NeoD7kMAAhAokkBeelZZY7SNRVX0qzJ0TWwL7zvcSOa2i23hK1/5SkSfMnpV0irFhx9+uDVdT1fU64ltZ8yYMdY6aNLca6+9ZvuTrF976qmnrE5DP/zhD+vP+A+K1iX9ZXEMAQh0J4E89J4yx8WkXqqK/pNUR+RfUG9pd/lX9rtvc7iRvSVLKNsmaXO4kf0z7rubvy361hae2DV58mT/I/XjstulSYFGdzWxVuKRM3dckJ5oW+lQNuJGIQ/7bKMyuA+BOAI43MSR6dDrtmVgtQd02vDwww9HBkgNlJtuumnEY1LCQEsAhgd7M7Bef/319WK1KkmaYFNue+K4k6asRs/YjGLjxo1rlKx2P7x35b777uvanFDK4G0cbtQv3/jGNxrWf86cOYF34Oijj45Nc//99weeNX2fZFyRg1Z4+TuTTvHWW29dq6dmMWkPSikJ4iQlwv9htqeGpNjGcAMCEGg7As3KPNPgMsZ7U1Zc3G3yz3C49dZbrXLk4osvNo/Exj/5yU8CaeWIkuTMYFuuVHJHcskffv3rX9dWvvFfbyat8tZWiH5ZZ46TVhB89NFHrWlMWs3A0zYDfgcjlXP11VcHtiKQ4+zy5cv9Tawfr1271lpGo61EpQPaZLn0Qlsoui1l6RY2g4A+cNqCbdaN+k4GFek4BAhAAAJVJ5CXnlXWGG3jWRX9qgxdE9tC8MOVZG472Ba0mrDNKVtO4UkfYLS1ttEJ/XFPt3bviW3Hto26qctuu+0WO5lKqwHKKcc8649POOEE6yrGReuStvGDaxCAQHcRyEPvKXNcTOqdqug/SXVE/gX1lnaXf2W/+7L9+fUHc6xJ8FlCmTZJm8ON6m1b3SjchiuvvDLQ3n79+rmvvvpq+LH6eZntinPaPvDAA13bN+Hbb789YLM0fWfiRislNmufrUPiAAI9IIDDTQ+gtXuS4cOHBwZgGRnSBi35bJvprAFPM5m/+c1vuj/+8Y9rDhFaGswMhLZ48ODB7qJFi9xLL73U1bK28q5uFE455ZRAnvqRnmRkaJRfM/dtRjF9JHnuuedis1Vdwx9YPvzhD7tSIm2hDN5+h5sNN9zQXbFiha0q9WvhPrjlllvq98IH+qhn6/vf/va34UcD57al5mz5xF074IADXH0kJEAAAhBoRuYZemWM96asuDg89na6/PNzOPfccyOyRM4zCxYs8D8WONaPazkD++XEFVdcEXgmfLJ06dLA8yatfgTK8K/l9+fOnev26dOn9tw999xTz6KZtMokbpsH/VBMClqZxtQzLhYrLbsqx5rwM9K/krb4fOKJJyJplEejH/xxxpSkLZOKbkvRuoUcm/yOv4b1vHnzrF2oFZjiZp9vs802rj4cESAAAQhUnUAeepbaWPQYHcexKvpVGbomtoXghysjpxvFVbAtHHXUURF97IYbboh7rWvXZQ8J68K77757Ypqkmz217WgiVBzj7bff3pWepHdTQasZy1FIqy/HpdH1j33sY66YnHzyyYEqF61LBgrjBAIQ6EoCeeg9ZY6LcZ1UFf0nrn7mOvIvqLu0u/wr893X9yqbLvGb3/zGvF6p47JsknEON/3790+cPPjd73434pytVV4ahbLaJduXVhu39Yeuq6633XabqwUKZH81zw0cOLB+bK4p1kQ3fU/Wqk+aVBgOzdpnw/lxDoEsBHC4yUKrQ561LRuo5VrThtNOO8062PkHPv+xPHDPPvvshmkefPDBxCpo+4Pw9gSNjAyJGTZ58+abb7a2SR+0LrvssoDjjbxn5Vz0qU99KpBm5513Djxnq1LRvP0ON+o3GTfiZqBriwp/32rv87ig7Z0GDRoUeN6kPe644xo6xJx33nnWtCaPuHjHHXdsuDxxXJ25DgEIdB6BZmWeiJQ13sfR71b5Z3jIkUEG9fC4L0cSyds1a9aYR2tbTWllG7/jg2YFa+/fRkFOsY2chU0dpNf4QzNptVLboYceGmmfytKPTX/7/GXq+JVXXnE/8YlPWNOausbFvXv3TnSceffdd12t3GdLf8wxx7h6L21BK/adeOKJ1nR77LFH/aNKOG2RbTFlFaVbqM1hh2rD7aCDDoplJUcumyOU0soJ+tprr7WugGjaQwwBCECg1QTy0LNMG4oao03+4bhK+lVZuia2heCHKyOr4+Kq2BbCqzbqQ0OcHuZ/z8N2ODnN9CQ0Y9uRHh+2g9l4G4d2c+8DH/iAVZc09xXvv//+geaUoUsGCuQEAhDoOgJ56D1ljou2DqqS/mOrn/8a8s+ut7Sr/Cvr3ddqVLvuuqtVj9hnn33cVatW+V+zhsdl2STjHG6k88h+N3bsWFerTcv+KCcWjUdTp06NONuk2cFCjS6rXSpLti2/DtfoWM42zz//fMM0tkn9zdhnVVcCBJohgMNNM/TaOG3YI1szttMGeQ6m+bijj11XXXVVTQhoJnTSQHrOOec0LP7uu+8O5KGVYWzbMDXMKKcHNHhrRZ+4FX/UXn1EGTBgQETwaZuFM844o7YdUqPqFM077HBj+mnYsGGu9kvVDCMtNXnkkUcG+Os8/BFQs7HPP/98V9uUada8ycsWy0gkz+bp06dbEYivlAYpFLb04Wv6oCqmNs9WawFchAAEuoZAMzJPkMoa7+M6pFvlX5iHZqjYnBO22GILV9sbyrEhbJyXrJk/f344q9jzND8C9dHgzTffjOSRJa2cbL761a/W5KWWeg3LNP+5nI31w1qzPWyGAelCuu9P0+h46NChte2mIo3wLmhrD8102XPPPRPzFNsJEya4WopW4emnn65t76gtDpLK11ajkte/+tWvaun8/+XdFn/eOs5bt9CqSZpN3WgmtmbtfOELX3BnzpzpGif3O++805UO5ncOs3GTvjt+/PhYZ+hwGzmHAAQgUDaBZvUsU9+8x2iTb1xcJf2qLF0T20J72hb0cVT6gNETZGdLE/zbiMk+k2UbhTxtO9Lv0qw+Y9on3Uh6knRNc83E0iO1nbg+MtlC0bqkrUyuQQAC3UUgD72nzHEx3DtV0n/CdQufI//ed7jpFPlX1LsvB5Rp06bVvl
muv/76Ef3B6BGKZTccPXq0+6UvfcnVlvFpQ9E2yTiHm/C3R+l0NjuSVjaMW2E5qY1Ft0tla0JfeGUtf5/4j7Wy1QsvvFCrsv+6OZZDvL4l33vvvbEO6Fnss0lsuAeBrARwuMlKrEOeD3tkawZ1lqDlaUeNGmUVXpqZoplTK1euDGSp2dBmYDSxnv32t78deC7uRE4cJp3itOni8svrumYnaxkzLc9r21vbX2cJvlNPPdXV3oVZQpG8tb2D6ihhJiOI35Djr7s51hZk5uNauA3qE/Nc2lgKY1LQBzwJ5CS2e+21V+KWGEn5cw8CEOh8As3KPEOojPHelOWPu1n++Tno+K9//aurbRcki5LkjH6AaVZH2DE0nJ/t/OKLL7bmrW0jZ8yYUXPasKXTtbRp//znP1vLSGqT5OAzzzwTV7R73XXXuUcccUR9yytbXlq6X460+rAXF+64445MddP7qbBkyZJM6ZKcrfNqS1wb89IttLWbjXPSNeNoFLcKUFxa6Y8ECEAAAlUkkJeeZdqW1xht8ouLq6hflaFrYltoT9vC5ZdfXtM5ZK9RH6YNspVIt5CTdJZQhG1n8eLFruyCcbYdXdfEruXLl9eqahxu+vbt644ZM6a2MqM+qKUJReuSaerAMxCAQGcSyFPvKXNcNL1RRf3H1M0WI/86U/7l/e5rUlycLSXpur7pZQlF2iRtDjf6bqZdQYYMGRLbPn1vlJ5nnFSytMc8W2S7TBnS4TS5P7zlqekf2Vyvv/5683gtNvdky1TaJ598MnA/6SStfTYpD+5BICuB/1EC78UldCGBiRMnOp63X73ly5Ytc7wf4/XzNAeeMczxPrA4f/rTn5xNNtnE8WbwOp/85Ccdb4CMJPc+7jh33XWX8/jjjzveCi+Otxd27Z/nkRl5NnzB+8HteB81HM8bsnbr4IMPdrw9F2v5hJ9t5fnrr79e4+E5Gzmvvfaao3PPOOB4q9w4nmBwvJWBHM8rtcdVLIK3hgBvOTrHm+1er5dn6HMee+wxx5vl7rz66qs17t6MdccT7o633YbjbeNRf7asA9XJ+0DpeDPDHW9JOcf72FqrjzcD3/q+lVUvyoEABNqDQB4yz9/Sosd7Uxbyz5AIxtIHpH94DiiO96PS8WbsOt6+xjV5O3jwYEf/mgni7hnSHG8p2pqM+ehHP+p4q+g4aXWWnqZtps4m7bp16xxvWVXH+1HueB9knG233daRDPeW1HW8jzTmsbaIi24LukVbvAZUEgIQaAMCeetZanKRY3Q76FdF65rYFtrLtuA5kTveineOt6qhM3LkyNSjgvdBy/FWFXa87dhrOmHqhAU+KDuTNwnNMXYz2Q+32WYb5/DDDw/U0VsZsGZHO/bYYx1vq80e1ahoXbJHlSIRBCDQ9gTy1nvKGhfbQf8JvxzIv86Wf2W9++H3qtnzImySsuMdcsghgap5DjeOt9K1Fs2ofafTN9gnnnjC8VYKcjzHZMfbesk59NBDHW/VnkC6np4U0a5wXd544w1n0aJFzrPPPluz5epb47777uvsvPPOjrc6UeBxz+Gu9j3VW9Hc8RyzA/fSnDRj202TP89AIEwAh5swkS4618Aspxdvqdhaq08++WTnRz/6USUJfPazn3Vuu+22Wt3k2CNnEAkUAgQgAAEIQCANgXaSef72IP/8NDiGAAQgAAEIQKCKBNpNz0K/quJbRJ0gAAEIQAAC7UGg3fQeQxX9x5AghkD1CCQ53FSvttQIAhCwEcDhxkali67Ji3Tvvfd2/vWvf9VaLe9Cb2uhShFYuHChM2LEiJonpyo2d+5c58wzz6xUHakMBCAAAQhUn0A7yDw/ReSfnwbHEIAABCAAAQhUmUC76FnoV1V+i6gbBCAAAQhAoD0ItIveY2ii/xgSxBCoJgEcbqrZL9QKAlkI4HCThVaHPqtVbU455ZRa67TtkbYX8vbSq0RrtS2TtoXQ0mkK3r7Nzk033VSJulEJCEAAAhBoPwJVlnl+msg/Pw2OIQABCEAAAhBoBwJV17PQr9rhLaKOEIAABCAAgfYgUHW9x1BE/zEkiCFQXQI43FS3b6gZBNISwOEmLakOf27atGmO9kdWOPjgg2v76PV0j+S8UGmPTq228+CDD9ayPOKIIxx5Y/fu3TuvIsgHAhCAAAS6kEAVZZ6/G5B/fhocQwACEIAABCDQTgSqqmehX7XTW0RdIQABCEAAAu1BoKp6j6GH/mNIEEOg2gRwuKl2/1A7CKQhgMNNGkpd8Izrus64cePqq8doC6c777zT2WCDDVrS+jfffNMZOXKkc//999fKl+ON6rPRRhu1pD4UCgEIQAACnUOgajLPTxb556fBMQQgAAEIQAAC7UaginoW+lW7vUXUFwIQgAAEINAeBKqo9xhy6D+GBDEEqk8Ah5vq9xE1hEAjAjjcNCLURfffe+89Z8KECc68efNqrT7ooINqTi79+/cvlYKWORw1apTzwAMP1ModO3asM3fu3JY5/5TaeAqDAAQgAIFSCFRF5vkbi/zz0+AYAhCAAAQgAIF2JVAlPQv9ql3fIuoNAQhAAAIQaA8CVdJ7DDH0H0OCGALtQWDx4sXOYYcdFqhsv379nNWrVweucQIBCFSXAA431e2bltRMXtmXXnqpc9FFFzlSFnfYYQfn7rvvdnbZZZdS6vPUU0/VVrZZsWKF06dPH2fWrFnOWWedVUrZFAIBCEAAAt1FoNUyz08b+eenwTEEIAABCEAAAu1OoAp6FvpVu79F1B8CEIAABCDQHgSqoPcYUug/hgQxBNqHwPz5853Ro0dHKvzWW2+xEEGEChcgUE0CONxUs19aXitt5TR+/Hjn+eefd7Sd06JFi0qp07HHHltz8Nlnn31q21vtuuuupZRLIRCAAAQg0L0EWiXz/MSRf34aHEMAAhCAAAQg0CkEWqlnoV91yltEOyAAAQhAAALtQaCVeo8hhP5jSBBDoH0ITJs2zZk5c2akwvfee69z5JFHRq5zAQIQqB4BHG6q1yeVqdGaNWucK6+80jn++OOdoUOHllIv7VW4dOlSZ/LkyU6vXr1KKZNCIAABCEAAAq2QeX7qyD8/DY4hAAEIQAACEOgkAq3Ss9CvOuktoi0QgAAEIACB9iDQKr3H0EH/MSSIIVB9AuvWrXNuvvlmZ8qUKc7atWsjFR40aJBz1VVXOcOHD3fWW2+9yH0uQAAC1SGAw011+oKaQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIBABxJ45513HO3u8cwzzzjvvfdewxb27t3b2WKLLZyTTjrJufrqqxs+zwMQgED5BHC4KZ85JUIAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAQBcRkMPNsGHDMrd40qRJzujRozOnIwEEIFA8ARxuimdMCRCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACHUQAh5sO6kyaAgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAAB
CBQPAEcbopnTAkQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAh1EAIebDupMmgIBCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgUDwBHG6KZ0wJEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIdRACHmw7qTJoCAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIFA8ARxuimdMCRCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACHUQAh5sO6kyaAgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCBQPAEcbopnTAkQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAh1EAIebDupMmgIBCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgUDwBHG6KZ0wJEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIdRACHmw7qTJoCAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIFA8ARxuimdMCRCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACHUQAh5sO6kyaAgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCBQPAEcbopnTAkQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAh1EAIebDupMmgIBCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgUDwBHG6KZ0wJEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIdRACHmw7qTJoCAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIFA8ARxuimdMCRCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACHUQAh5sO6kyaAgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCBQPAEcbopnTAkQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAh1EAIebDupMmgIBCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgUDwBHG6KZ0wJEIAABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIdROD/AAnH0uD0mBN9AAAAAElFTkSuQmCC)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Br6sY3Skj6ew", + "papermill": { + "duration": 0.031242, + "end_time": "2020-12-01T01:33:27.288600", + "exception": false, + "start_time": "2020-12-01T01:33:27.257358", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Figure 1: *Various ResNet Blocks*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1Aa3qJM8j6ew", + "papermill": { + "duration": 0.030469, + "end_time": "2020-12-01T01:33:27.350166", + "exception": false, + "start_time": "2020-12-01T01:33:27.319697", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Diminishing Feature Reuse" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aByhtJHsj6ew", + "papermill": { + "duration": 0.029283, + "end_time": "2020-12-01T01:33:27.409595", + "exception": false, + "start_time": "2020-12-01T01:33:27.380312", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "A **Residual block with a identity mapping**, which allows us to train very deep networks is a **weakness**. As the gradient flows through the network there is nothing to force it to go through the residual block weights and thus it can avoid learning during training. This only a few blocks can run valuable representations or many blocks could share very little information with small contributions to the final goal. This problem was tried to be addressed using a special case of dropout applied to residual blocks in which an identity scalar weight is added to each residual block on which dropout is applied.\n", + "\n", + "As we are widening our residual blocks, this results in an increase in the number of parameters, and the authors decided to study the effects of dropout to regularize training and prevent overfitting. They argued that the dropout should be inserted between convolutional layers instead of being inserted in the identity part of the block and showed that this results in consistent gains, yielding new SOTA results." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "reNP-uCgj6ew", + "papermill": { + "duration": 0.029729, + "end_time": "2020-12-01T01:33:27.469022", + "exception": false, + "start_time": "2020-12-01T01:33:27.439293", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "The paper [Wide Residual Networks](https://arxiv.org/pdf/1605.07146.pdf) attemptsto answer the question of how wide deep residual networks should be and address the problem of training." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RKfIYWqoj6ew", + "papermill": { + "duration": 0.029851, + "end_time": "2020-12-01T01:33:27.529228", + "exception": false, + "start_time": "2020-12-01T01:33:27.499377", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Residual Networks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c7uSNYYuj6ew", + "papermill": { + "duration": 0.030099, + "end_time": "2020-12-01T01:33:27.588818", + "exception": false, + "start_time": "2020-12-01T01:33:27.558719", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "$\\large\n", + "x_{l+1} = x_l + \\mathbb{F}(x_l, W_l)\n", + "$\n", + "\n", + "\n", + "This is the representation of a Residual block with an identity mapping.\n", + "\n", + "* $x_{l+1}$ and $x_l$ represent the input and output of the $l$-th unit in the network\n", + "\n", + "* $\\mathbb{F}$ is a residual function\n", + "\n", + "* $W_l$ are the parameters" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E-3X8obdj6ew", + "papermill": { + "duration": 0.029364, + "end_time": "2020-12-01T01:33:27.647782", + "exception": false, + "start_time": "2020-12-01T01:33:27.618418", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Figure 1(a) and 1(c) represent the fundamental difference between the *basic* and the *basic-wide* blocks used." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mgf_2paVj6ew", + "papermill": { + "duration": 0.032155, + "end_time": "2020-12-01T01:33:27.709374", + "exception": false, + "start_time": "2020-12-01T01:33:27.677219", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Architecture" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rc5ECr36lQSU" + }, + "source": [ + "![Screenshot 2020-12-01 at 10.04.48 AM.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABeYAAAM+CAYAAAB/uNEmAAAMbGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkJCEEghFSuhNEOlFSggtgoBUwUZIAgklxoSgYkcXFVy7iGJFV0VcdHUFZFERda2LYnctiwUVZV0sKIrKmxTQdV/53vm+ufPfM2f+U+7MvXcA0OnjSaX5qC4ABZJCWWJUGGtcegaL9AhoAyOgA2gA4fHlUnZCQiyAMtj/Xd5eB4iyv+Kq5Prn+H8VfYFQzgcAmQBxlkDOL4C4BQB8I18qKwSAqNTbTCuUKvE8iA1kMECI1yhxjhrvVuIsNW5W2SQnciC+BIAWlceT5QBAvwP1rCJ+DuShf4TYXSIQSwDQGQ5xMF/EE0CsjH14QcEUJa6E2BHaSyGG8QC/rK84c/7GnzXEz+PlDGF1XirRChfLpfm8Gf9naf63FOQrBn3Yw0YVyaITlfnDGt7MmxKjxFSIuyVZcfHKWkPcJxao6w4AShEpolPU9qgZX86B9QNMiN0FvPAYiM0gjpTkx8Vq9FnZ4kguxHC1oNPFhdxkiI0hXiyURyRpbLbKpiRqfKH12TIOW6M/w5Op/Cp93VPkpbA1/K9EQq6GH6MXi5LTIKZAbFskTo2DmA6xmzwvKUZjM6pYxIkbtJEpEpXx20KcKJREhan5saJsWWSixr6sQD6YL7ZVJObGafCBQlFytLo+2Ek+TxU/zAW7JJSwUwZ5hPJxsYO5CIThEercsadCSUqShqdPWhiWqJ6LU6T5CRp73FqYH6XUW0PsJS9K0szFUwvh4lTz49nSwoRkdZx4cS5vdII6HnwFiAUcEA5YQAFbFpgCcoG4rbuhG96pRyIBD8hADhACV41mcEaaakQCr0mgGPwJkRDIh+aFqUaFoAjqPw1p1VdXkK0aLVLNyAOPIS4AMSAf3itUsyRD3lLBI6gR/8M7DzY+jDcfNuX4v9cPar9o2FATq9EoBj2ydAYtiRHEcGI0MZLohJviwXggHguvobB54H64/2AeX+wJjwnthAeEa4QOwq3J4hLZN1GOAR2QP1JTi6yva4HbQ05vPAwPguyQGWfipsAV94J+2HgI9OwNtRxN3MqqsL7h/lsGXz0NjR3ZnYySjcihZMdvZ9Kd6d5DLMpaf10fdaxZQ/XmDI1865/zVfUFsI/51hJbjB3ETmPHsbNYM9YAWNgxrBG7gB1R4qHV9Ui1uga9JariyYM84n/442l8Kispd69173L/qB4rFE4vVG48zhTpDJk4R1TIYsOvg5DFlfDdhrM83D3cAVB+a9Svr9dM1TcEYZ77oispAiDIaWBgoPmLLtYfgJ8b4fbv+qJzhO8+uiUAZxbzFbIitQ5XXgjwLaEDd5oJsAA2wBHm4wF8QCAIBRFgNIgHySAdTIJVFsF1LgPTwCwwH5SCcrACrAUbwBawHewGP4IDoAE0g+PgV3AeXALXwG24ejrBc9AD3oJ+BEFICA1hICaIJWKHuCAeiB8SjEQgsUgiko5kIjmIBFEgs5AFSDmyCtmAbENqkJ+Qw8hx5CzSjtxC7iNdyCvkA4qhVNQANUft0RGoH8pGY9BkdCKag05Fi9GF6DK0Eq1G96L16HH0PHoN7UCfo70YwLQxJmaFuWJ+GAeLxzKwbEyGzcHKsAqsGqvDmuBzvoJ1YN3Ye5yIM3AW7gpXcDSegvPxqfgcfCm+Ad+N1+Mn8Sv4fbwH/0ygEcwILoQAApcwjpBDmEYoJVQQdhIOEU7BvdRJeEskEplEB6Iv3IvpxFziTOJS4ibiPmILsZ34kNhLIpFMSC6kIFI8iUcqJJWS1pP2ko6RLpM6SX1a2lqWWh5akVoZWhKtEq0KrT1aR7Uuaz3R6ifrku3IAeR4soA8g7ycvIPcRL5I7iT3U/QoDpQgSjIllzKfUkmpo5yi3KG81tbWttb21x6rLdaep12pvV/7jPZ97fdUfaozlUOdQFVQl1F3UVuot6ivaTSaPS2UlkErpC2j1dBO0O7R+ugMuhudSxfQ59Kr6PX0y/QXOmQdOx22ziSdYp0KnYM6F3W6dcm69rocXZ7uHN0q3cO6N3R79Rh6I/Xi9Qr0lurt0Tur91SfpG+vH6Ev0F+ov13/hP5DBsawYXAYfMYCxg7GKUanAdHAwYBrkGtQbvCjQZtBj6G+oZdhquF0wyrDI4YdTIxpz+Qy85nLmQeY15kfjMyN2EZCoyVGdUaXjd4ZDzMONRYalxnvM75m/MGEZRJhkmey0qTB5K4pbupsOtZ0mulm01Om3cMMhgUO4w8rG3Zg2O9mqJmzWaLZTLPtZhfMes0tzKPMpebrzU+Yd1swLUItci3WWBy16LJkWAZbii3XWB6zfMYyZLFZ+axK1klWj5WZVbSVwmqbVZtVv7WDdYp1ifU+67s2FBs/m2ybNTatNj22lrZjbGfZ1tr+bke287MT2a2zO233zt7BPs1+kX2D/VMHYweuQ7FDrcMdR5pjiONUx2rHq05EJz+nPKdNTpecUWdvZ5FzlfNFF9TFx0XsssmlfThhuP9wyfDq4Tdcqa5s1yLXWtf7bky3WLcStwa3FyNsR2SMWDni9IjP7t7u+e473G+P1B85emTJyKaRrzycPfgeVR5XPWmekZ5zPRs9X3q5eAm9Nnvd9GZ4j/Fe5N3q/cnH10fmU+fT5Wvrm+m70feGn4Ffgt9SvzP+BP8w/7n+zf7vA3wCCgMOBPwV6BqYF7gn8Okoh1HCUTtGPQyyDuIFbQvqCGYFZwZvDe4IsQrhhVSHPAi1CRWE7gx9wnZi57L3sl+EuYfJwg6FveMEcGZzWsKx8KjwsvC2CP2IlIgNEfcirSNzImsje6K8o2ZGtUQTomOiV0bf4Jpz+dwabs9o39GzR5+MocYkxWyIeRDrHCuLbRqDjhk9ZvWYO3F2cZK4hngQz41fHX83wSFhasIvY4ljE8ZWjX2cODJxVuLpJEbS5KQ9SW+Tw5KXJ99OcUxRpLSm6qROSK1JfZcWnrYqrWPciHGzx51PN00XpzdmkDJSM3Zm9I6PGL92fOcE7wmlE65PdJg4feLZSaaT8ic
dmawzmTf5YCYhMy1zT+ZHXjyvmtebxc3amNXD5/DX8Z8LQgVrBF3CIOEq4ZPsoOxV2U9zgnJW53SJQkQVom4xR7xB/DI3OndL7ru8+LxdeQP5afn7CrQKMgsOS/QleZKTUyymTJ/SLnWRlko7pgZMXTu1RxYj2ylH5BPljYUG8Kf+gsJR8Z3iflFwUVVR37TUaQen602XTL8ww3nGkhlPiiOLf5iJz+TPbJ1lNWv+rPuz2bO3zUHmZM1pnWszd+HcznlR83bPp8zPm/9biXvJqpI3C9IWNC00Xzhv4cPvor6rLaWXykpvLApctGUxvli8uG2J55L1Sz6XCcrOlbuXV5R/XMpfeu77kd9Xfj+wLHtZ23Kf5ZtXEFdIVlxfGbJy9yq9VcWrHq4es7p+DWtN2Zo3ayevPVvhVbFlHWWdYl1HZWxl43rb9SvWf9wg2nCtKqxq30azjUs2vtsk2HR5c+jmui3mW8q3fNgq3npzW9S2+mr76ortxO1F2x/vSN1x+ge/H2p2mu4s3/lpl2RXx+7E3SdrfGtq9pjtWV6L1ipqu/ZO2Hvpx/AfG+tc67btY+4r3w/2K/Y/+ynzp+sHYg60HvQ7WPez3c8bDzEOldUj9TPqexpEDR2N6Y3th0cfbm0KbDr0i9svu5qtmquOGB5ZfpRydOHRgWPFx3pbpC3dx3OOP2yd3Hr7xLgTV0+OPdl2KubUmV8jfz1xmn362JmgM81nA84ePud3ruG8z/n6C94XDv3m/duhNp+2+ou+Fxsv+V9qah/VfvRyyOXjV8Kv/HqVe/X8tbhr7ddTrt+8MeFGx03Bzae38m+9/L3o9/7b8+4Q7pTd1b1bcc/sXvUfTn/s6/DpOHI//P6FB0kPbj/kP3z+SP7oY+fCx7THFU8sn9Q89Xja3BXZdenZ+Gedz6XP+7tL/9T7c+MLxxc//xX614WecT2dL2UvB14tfW3yetcbrzetvQm9994WvO1/V9Zn0rf7vd/70x/SPjzpn/aR9LHyk9Onps8xn+8MFAwMSHkynupXAIMNzc4G4NUuAGjpADDguY0yXn0WVAmiPr+qEPhPWH1eVIkPAHWwU/7Gc1oA2A+bPWy0UACUv/DJoQD19BxqGpFne3qouajwJEToGxh4bQ4AqQmAT7KBgf5NAwOfdsBgbwHQMlV9BlUKEZ4Ztqo4LjOLBv0Pifp8+lWO3/ZAGYEX+Lb/F1J2jnCfSe2gAAAAimVYSWZNTQAqAAAACAAEARoABQAAAAEAAAA+ARsABQAAAAEAAABGASgAAwAAAAEAAgAAh2kABAAAAAEAAABOAAAAAAAAAJAAAAABAAAAkAAAAAEAA5KGAAcAAAASAAAAeKACAAQAAAABAAAF5qADAAQAAAABAAADPgAAAABBU0NJSQAAAFNjcmVlbnNob3RjTosDAAAACXBIWXMAABYlAAAWJQFJUiTwAAAB12lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNi4wLjAiPgogICA8cmRmOlJERiB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiPgogICAgICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgICAgICAgICB4bWxuczpleGlmPSJodHRwOi8vbnMuYWRvYmUuY29tL2V4aWYvMS4wLyI+CiAgICAgICAgIDxleGlmOlBpeGVsWURpbWVuc2lvbj44MzA8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+MTUxMDwvZXhpZjpQaXhlbFhEaW1lbnNpb24+CiAgICAgICAgIDxleGlmOlVzZXJDb21tZW50PlNjcmVlbnNob3Q8L2V4aWY6VXNlckNvbW1lbnQ+CiAgICAgIDwvcmRmOkRlc2NyaXB0aW9uPgogICA8L3JkZjpSREY+CjwveDp4bXBtZXRhPgrRoeJlAAAAHGlET1QAAAACAAAAAAAAAZ8AAAAoAAABnwAAAZ8AARYSMwdLqQAAQABJREFUeAHsnQXcVUX6x0dFQQUUQVFQxNo1wUCwdu3CXBtdc+21O0BcDOxYsRUM1lZ0VbBRwcRuEUm7E9vzn9/xP3fnnjNz4p6+9zefz/uenPrO3DkzzzzzzCyOdIKOBEiABEiABEiABEiABEiABEiABEiABEiABEiABEiABEggFwKzUDCfC2dGQgIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIuAQrmWRFIgARIgARIgARIgARIgARIgARIgARIgARIgARIgARIIEcCFMznCJtRkQAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJkAAF86wDJEACJEACJEACJEACJEACJEACJEACJEACJEACJEACJJAjAQrmc4TNqEiABEiABEiABEiABEiABEiABEiABEiABEiABEiABEiAgnnWARIgARIgARIgARIgARIgARIgARIgARIgARIgARIgARLIkQAF8znCZlQkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQME86wAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJkAAJ5EiAgvkcYTMqEiABEiABEiABEiABEiABEiABEiABEiABEiABEiABEqBgnnWABEiABEiABEiABEiABEiABEiABEiABEiABEiABEiABHIkQMF8jrAZFQmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAlQMM86QAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAIkQAI5EqBgPkfYjIoESIAESIAESIAESIAESIAESIAESIAESIAESIAESIAEKJhnHSABEiABEiABEiABEiABEiABEiABEiABEiABEiABEiCBHAlQMJ8jbEZFAiRAAiRAAiRAAiRAAiRAAiRAAiRAAiRAAiRAAiRAApURzH/00Ufi/fffZ4mRAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQQGQCc889t1h66aUjv5/Hi5URzJ966qli0KBBeTBhHCRAAiRAAiRAAiRAAiRAAiRAAiRAAiRAAiRAAiRAAk1CYOWVVxbPP/98qXJDwXypioOJIQESIAESIAESIAESIAESIAESIAESIAESIAESIAESSJMABfMJaD7++OPi4YcfThACvZIACZAACbQ6gQkTJoiZM2eKOeecU/Tt27fVcTD/JEACJNCUBF566SXx9ddfi9lmm02stdZaTZlHZooESIAESIAESIAESCAegW7duon99tsvnqeM366MxnzGHBg8CZAACZBACxCAPbm3335bLLHEEmLSpEktkGNmkQRIgARaj8Bf//pXMW7cONGhQwfxzTfftB4A5pgESIAESIAESIAESKASBCiYr0QxMZEkQAIkQA
JpEKBgPg2KDIMESIAEyk2Agvlylw9TRwIkQAIkQAIkQAIk8AcBCuZZE0iABEiABFqGAAXzLVPUzCgJkEALE6BgvoULn1knARIgARIgARIggQoRoGC+QoXFpJIACZAACSQjQMF8Mn70TQIkQAJVIEDBfBVKiWkkARIgARIgARIgARKgYJ51gARIgARIoGUIUDDfMkXNjJIACbQwAQrmW7jwmXUSIAESIAESIAESqBABCuYrVFhMKgmQAAmQQDICFMwn40ffJEACJFAFAhTMV6GUmEYSIAESIIEqEfjpp5/E3XffLfr16ycWWWSRKiW96dP61VdfidGjR4v1119fdO3atenz22wZpGC+2UqU+SEBEiABErASoGDeioYPSIAESKBpCFAw3zRFyYyQAAmQAAmUgMA777wjtt12W/Hqq6+K7bffXtxyyy0lSBWToAgMGzZMHHzwwaJTp07immuuEVtuuaV6xGMFCFAwX4FCYhJJgARIgATSIUDBfDocGQoJkAAJlJkABfNlLh2mjQRIgARIoEoEnnjiCVfQ+8UXX4hZZplF3HDDDWKnnXaqUhaaPq0TJ04Uq622mvjyyy/FrLPOKs4//3xxyCGHNH2+myWDFMw3S0kyHyRAAiRAAqEEKJgPRcQXSIAESKDyBCiYr3wRMgMkQAIkQAIlIPD444+LzTbbTHz33Xduas477zxx+OGHlyBlTIKXwAsvvCD+8pe/iJkzZ7qPzjjjDHHsscd6X+N1CQlQMF/CQmGSSIAESIAEsiFAwXw2XBkqCZAACZSJAAXzZSoNpoUESIAESKCKBF588UWB76kSyh944IHi4osvrmJWWibNI0eOFLvuumstv9ddd13dde0BT0pFgIL5UhUHE0MCJEACJJAlAQrms6TLsEmg+gQwCD3rrLPE66+/LtBebL311mKHHXYQbdq0qX7mWigHrSiY/+CDD8Ttt98u7rjjDrH66quL008/vdASf+utt9xNAu+9916x7rrrisGDBxeanqpG/vLLL7sc77vvPrHjjju6NoSrmhemmwRIoDoEpk2b5n5LPvzwQzfR/fv3F//973/FbLPNllom8N1CPJ988on4/vvvxQILLCB69Ojhbiw755xzphZPWQLKK78HHHCAuOyyy9xst23bVjz22GPuhr1l4cB0+AlQMO9nwjskQAIkQAJNSoCC+SYtWGaLBFIgMH78eLHeeuuJX375pS60PfbYQ4wYMaLuHi/KTaBVBPPYjA+CbwjjYQP4999/dwsGtn9vvPHGXAvpxx9/FPgNjR492hUkT5o0qRb/nnvuKYYPH1675omdwLfffiseeeQRAUE8ynbGjBm1l4855hhx5pln1q55QgIkQAJZEPjpp5/EmmuuKZ5//nk3+KWWWkrATEr79u0TRzdlyhSBjUrHjBkj3nzzTWt4MMmy1157uRvNzj333Nb3yv6giPyi/NZaay3x3HPPuXgWWmghMWHCBNG9e/ey42rZ9FEw37JFz4yTAAmQQOsRoGC+9cqcOSaBqATWWGMN8dRTTxlfx4CmT58+xme8WT4CzSqYh0bho48+6gptIbjVhd96KeQhmP/555/Fs88+62riQZD85JNPCgjnTY6CeROVP+5hMgXCrwcffFDcf//9bhvknRxUvimYVyR4JIHyE7j88svFFVdckTih2Gx13nnnFfPNN5/717VrV1eTHYLXNATlpgQedNBBNZM1iB925hFfEodNSY866igB0yq//vpr5KA6d+4srr76arHVVltF9lOGF4vO79SpU8XKK6/sbgYLHujDjhs3TrRr164MeJgGDwEK5j1AeEkCJEACJNC8BCiYb96yZc5IIAkBCMewbBrCRpO74IILxKGHHmp6xHslJNCMgvlPP/1ULLzwwtY6qhdD1oL53377TXTs2LG2wZwet+mcgnkTlT/u9evXz53gsL/xvycUzP+PBc9IoOwEzjnnHHHccccJtJdZuNlnn12suuqqAnbfd955ZwEBehoOk4QbbbRRLSgI6S+66KLadSMnUHrAd2n69OmNeHf9gOXQoUMb9p+nx7LkF6uutthiC+E4jpv9QYMGiSFDhuSJgnFFJEDBfERQfI0ESIAESKD6BCiYr34ZMgckkAUBDFo6deokvv76a2PwGJRicEpXDQLNKJjHpNH222/vCjZgjxe2am0uD8H8iiuuKOaaay7RoUMH8d5774m3337blhxBwbwVjYAt4FdffVXADjA0HCdPnmx9mYJ5Kxo+IIFSEoBQ/v3333dXxdx0003illtuiZROmG7p0qWLu+nq559/HuqnV69eAhMBG264Yei7QS9Akx1tO/bZgevZs6d47bXXRBJTMlgJBG13mFdJ6q666irxj3/8I2kwmfovW3733ntvd8UBMg1teZTt4osvnikDBh6fAAXz8ZnRBwmQAAmQQEUJUDBf0YJjskkgBwJbbrmlaxvbFBU2hcVgla4aBJpRMO8lP3LkSLHrrrt6b7vXWQvmTZGed9554sgjjzQ9omDeSMV/ExOEAwcOtG7cS8G8nxnvkECVCOy3336B5m2wcu/aa68Vf/vb32qbzn/33XfuBqkwQ3LKKadYJ2WxKSu+C2j/G3UXX3xxnRLCDTfcIAYMGNBocO6EBGzF//DDD3Vh4B4mE5Zcckkx66yzurbmX3nlFdeUl9LurvPw/xeYwJw4caK7QazpedH3YJasbPnFpPkyyyxT05qHBj028aUrGQFZ8elIgARIgARIoCUI/PnPf8ZaPmeJJZZoifwykyRAAtEJSHvdjrTd6rYRaCfU39FHHx09EL5ZCgJyYOyWn9TmLkV6skqEnCyq1VNVX3GUgpmsorSGK+2iO/PPP78xPVJj3uqPD+oJyM1fHSmcM3KUgvn6l3lFAiRQKQJyI3njb1u133IFTWB+5D4jjpy8s4YhhfPOqFGjAsOwPfziiy8cace+FjbGTFLj3/Z66P2ZM2c6UiGqFh7y2L9/f0fuTWL1K23Z+/woNup4/PHHW/0X+aDM+ZUrFurKQZq4KRIV4zYQoMa8/IXTkQAJkAAJtAYBasw3dzkPHjxYQGPopJNOau6M5py7VuL60UcfCTmodTfIkkJGIQeRYuONN86ZOKNLSqAVNObBCFqVd955pw9XERrzSITi7k0QTdl4iQRfQ7vxrbfe8r1UdY35VvqW+AqPN0hAErjwwgvFYYcdZmUB++4bbLCB9bl6ELTCr0ePHu7G4LBBH8cdcsghdbbkr7/+evH3v/89ThB1755xxhlCCtHde+ibYwPX3Xffve4d0wVM3uAbZvq24f2FFlrIumrAFF5e98qc3/Hjx7ua/IoFVirARBFWINCVhIBBWM9bJEACJEACJNCUBKgx35TF6mbqvvvuc+TGV47csKp5M1lAzsi1AOiMMjGBVtGY32233eq04OTw0r0uQmMehebVylPpocZ8vCq9+uqrG8u1yhrz/JbEqwN8uzkJ7LvvvsbfNtpKuWeH8+OPP0bK+IwZM5z27dtbwxo+fHikcNRLb7zxhtOmTZtaeFJw60h78+px7OM333zjdO7c2Q1PT
hA40rZ+rDDkPirWFVhgJfcDihVe1i9XIb+rrbZarXzB8IILLsgaC8OPQQC2huhIgARIgARIoCUIUDDfnMUstZxrJkgomE+vjMk1PZYMKV8CFMznb8oGJSw3p60b+FMw31i9X3vttY0cqyqY57eksXpAX81HYLnlljP+ttFWbrbZZrEyjP6uamO9x/XWWy9WWPvvv39dWGeeeWYs/96Xhw0bVgtP2sz3Po50fd1119XC8ObvpZdeihRGXi9VIb+33XZbHU+YGaIrDwEK5stTFkwJCZAACZBAxgQomM8YcAHB//77766WvOq0UzCfTiGQazocGUoxBCiYp2C+mJqXTqzNJJjntySdOsFQqk/gs88+c1d2qv6q9/jvf/87Vib/+c9/1gla9fAWXXTRyGHJjVmdeeedtxaW3IzVee+99yL7N73Yu3dvN7ytt97a9DjSvY8//riWJj1vOH/66acjhZHXS1XIL/YLwEoInaU0cZMXIsYTQoCC+RBAfEwCJEACJNA8BCiYb56yVDkZOnRoXSeTgnlFJtmRXJPxo+9iCVAwT8F8sTUwWezNJJjntyRZXaDv5iFw44031vVXdQEpzidOnBgrs+edd541PGwC+/PPP0cK7z//+U9dOOuvv34kf7aXXn75ZTc8mLLBapkkbsEFF6xLm2KGjWrL4qqUX2+dkTb/y4Kx5dPBzV/lr5uOBEiABEigNQhw89fmKufRo0cLbIAltUBqGZOCeXH//ffXrnkSnwC5xmdGH+UioDYh7dChg5C2X8uVuBRTg4305HJ/X4hFbf66ww47iFtvvdWXHm7+6kMSeGOdddYRjz32mO+dqm3+ym+Jrwh5o4UJoB285pprjAQWX3xx8e677xqf2W5KIas48sgjbY/Fhx9+KKRg2/pcPcBmsw8//LC6dNMYZZPWmgfPibT/Lm6++WYhlaGEnGT0PI13iTDkhEWdp+7duwup0V93r8iLKuX3lVdeEVK7v4ZL7mvg1pOOHTvW7vGkGAIUzBfDnbGSAAmQAAkUQICC+QKgZxTl66+/LtZYYw2f0I2C+WTAyTUZP/ouBwEK5ncSUjsz98KgYD4d5M0gmOe3JJ26wFCah0C3bt1cIagpRwcccIC45JJLTI+s9yCUh3De5iAwDhO4Tps2TSy22GKwolELBkJvCL+LdtLEjsDkuq58gzRJW/zinnvuKTp5qcefR36laTEhVzKIr776qpZ+1DvUP7piCVAwXyx/xk4CJEACJJAjAQrmc4SdYVTQKoLgwqQxQ8F84+DJtXF29FkuAhTMUzBfrhoZLzVVF8zzWxKvvPl28xPwaip7c3zXXXe5K0C994OuvZru+rtLLbWUT9Ncf67OzzjjDHH88cerS4HJg/fff792XeTJU0895SrgeNPwwAMPiA033NB7u/LXeeV38803F/fee2+NF743Y8eOrV3zpBgCFMwXw138+uuvYty4ceKRRx4RM2bMEHJzCzHLLLOIrl27ukuOVl55ZYHGtlOnTqmmUG46IkaMGCEeeughIe2JiS5dutSFj48Clk5iyRDSs/zyy7uNddhsa10g/3/x9ttviwcffFBMnz5dSPti4pNPPhFt2rRxZ2AxC7vIIou4ecQxjvv+++9dfmF+2rdvL6R9NeNr3333nW/21fiivDnPPPPYHgXenzRpkhg+fLh44YUXxKhRo8Scc85Z9z5mse+44w7x1ltvuYw+/fRTMd9887l8Vl99dXc2eO65567zU8TFa6+95mpdPffcc+KKK64QcjMZXzLw7M4776yVNerxqquuKvr06SNQl9u1a+fz0+gNcAWzKVOmiMmTJ7vHmTNnuhoB888/v4BQbuONNxZYmpWGQznJXczd/EETbNdddzUGKzehEfj9oL6jLPGbkRusuMvFtt9+e7fuGz1qN9EuwAQH6gzCgdARdaBfv34CdWKVVVbx1SPNe0OnSCvifPHFF93fKdoIOJQhfv99+/Z1lyGiPWgG18yC+VZpc1999VXRv39/o1AedXS99dZz21ZTfY3SnufVdpftW5Y1V1N5NHLvp59+EtDEfOONN9wjvgP4vqL9R38CA4wVVljB7cMgfLTh+B7MPvvssaP79ttv3e83vuH777+/+20JC+THH38USGNWLkodNsXdam09BfN+wTy0DiHQQL8O7dzUqVPdcUavXr0E/lZaaaXEWpJ5aMxn9a0z/W6C7qEfOmbMGIH0wGSEtHlc64v27NnTbYvQF8bYJ65DOxbXlA20LaVN6bhRiaDxUuzApIc0viW//PKLQN++EReUH/Sz8e0NcujvRh335tVfCEqv91mrtfXe/Jf1+uyzzxYwRWVy6J+g/UDdjeowRsSYHBrQJnf44YcHatMrPxgz47ugnNys1e33qOsij6eddpoYOHBgXRK23XZbd1xed7NJLvLK75lnnimOO+64GjX0oaFBP8ccc9Tu8aQAAi1vZb8AAFdddZUjBbBYLxT4h0075DJ9Bzt0f/DBBw2nFDswS8GbI+1NOvIHV4tTDmZrYcpBrhuXKU3Y1RubWkRxX375pXPiiSf6dnw2hYt7svPjYIMuuYTGkcLyKFE4W221VS0PtnBxf+edd7aGB65BfvVnciLDGo73gRQAOChftemYCkc2drVX5ZIxZ4899nCkoCAwDbKRdI499lhH2kat+c3rRGq6OPLj4EjBbF0an3/++bokyEGDI4XFde+oPKujHJw4Uuhb5y/uBeqwFHo7ts2wVFzqCLbbbLONI2d/40blvo+6eNNNN7lhyEmFWv5kh8oXnrSl6vzpT3+qvaPSoB+XWGIJN/0+z/9/A79vuRzRkcLwwHCkQMZ5/PHHbcHEui8nzRw58HN/g3paTedoA84//3xHdv5ixVHGl5tt89dWa3PPPfdcRw5gAn8npjqs7g0ePNhYLYtou8v0LcuKqxF2gzfR/lx77bWObSMwVcY44h25LLfWX5A2uCPHis3S/vvf/zo77rijg++wCvfiiy+OFIbUPKv5UX7TPF544YWR0qFeatW2XvXD5DJ4haIpj7vttpuxvqHPrxz6keecc46Db3lQXUSf/B//+IcjBXvKa+yjVEQwxiFtK8cOS/eQx7dOj892LgXGbj9/ueWWM+bTy1cqWDlSABJ7HGfr75r6oSqtYOyNP8r1LrvsooJIfEzrW+LdjDJKPtQ7QWMOafM+EqMrr7zSyqKI/oI1MdqDVm3rNQSlPsWGqqqOeo/rrrtu7LSjP+sNR7+WNuNDw5QTVY6cDKgL5/TTTw/1l8cLkFNJ5bS6tEmlC0cqTOQRfe5x5JnfJ554oo4r6g3u0RVLAPak6HIiIDUrHKkF7/shQPAnN/xwdAGg3rCqc7lRgyM34nCOPvpo56STTvL9oaOgO6l167678MIL++JEmEowL5dWOVJz3viOiltqiulB+87RUUWn3zbhgIFRjx49HEw2qDC9R6ldHKlRkJumuAJTudTKGhbClloqvnSqG4ceeqgrdIag05sO7zUG2UFOaqg4t99+u4PBiD6I18NRgnm5EYoz77zzhsap+8VASmo1BSUh8TMIPCZM
mODWKak1ZU2fEsxLjUC3bs0666zWd/U8gAs62Y04TCrJ5Xi+eKTmsyO115399tvPQWfH27FA/BDgQXAexWHC5LLLLnOkJq71t6gPiOSqEkdq5/vSpedbP5caU47UrPIlBRM/GLjp7wadg6W0q+cLJ+oNqYnvCuSD4rA9k1oUjrediRpvWd5rFsF8K7a5cuVK5N+JrQ4fdthhtapYdNtdlm9Z2lxrgFM8eeeddxy5aqiu/DFg+9vf/uZ+iyCEhyC2bdu2de+oeiBXewWmBt9ATHrie2Lrx0QVzGOSVcWbxXHQoEGBeVEPW72tp2D+D8H8m2++GVlZRtVX/AbQH4JSRFyXtmA+z29dWF7vvvtud7ymOOlHMMOEoG2cgz4qJhajukYE83Jlj3PEEUe47WKY8o9cYeTI1ZjOgAEDXMWLqOkKei/NbwnGsKuttlrktlRquDtS89fB5ITUJLYmU660cvbaay9Hrga1KqaAjTTzUBdG0f2FusR4Llq9rffgKOUlFL5s/RO0I9KcTKx0y5XqdcqWeluEcyinRVGmwtjf61daVYiVlixeRtoxWaGnDb9Zuao7i+gKDzPv/MpVnT5Zx9ChQwvn0OoJoGA+pxqAzpJ31k+ac3HQ+KHTCSeX1jkXXHBBYEOrN1Dec2kSx9VsP+GEE0I1eOEXgnl0UMKE8ngXWs82B0GdSUCJ/EETHtr2qnOPjs3VV1/tQKjqTT+uIejFxAMEv2EOYUHIYgoH94IE8ypsNIT33XefTzNcD9MmmIfAGAMQb7nqftU5BPPQflfXcY8QTEtzQCrZqR4xaFtooYUipQ2CeWmSyO3Mx80Dyhb1LY5Dx9jbkUGH2SSYRsdU2pvz5QODpOuvv94aLermMsss4/Nnyp8SzEOTJmzQY/KPgZmaZMGMv9y8JlK83rAg5H/ppZesebI9gKaQ3PClLk6seIBQCytz0P7885//NE6EqDRAsI0VNlV1zSCYb9U2d/z48e5vBr8b/MnNqurqsqqjOKp3vEdoQpet7S76W5YW16zaBPCBlpRevhh0moQv0oSC8Xt+0UUXGZOH96HNumiIJjHijiqYP+igg+rSCr/SpJurvb/vvvs6hxxyiPUPbbGeT+85FC3URL8xQ/9/k229U1u52Moa89CYjKKA4q1n6lr1eYLqmvdZmoL5or513jxBiGEab6y44oqONFnpQJtfOfSRoUSlGHqPp556qno18NiIYF4PEFqv3rhxjX6yNKOqv5raeRbfEmkqJrBPijxhdRNWOsV1GI96GQ0ZMqRuDFq2/oI3j2zrvUTKeY3xrLeu6ddxxnQY70KZUfevn2NMG3W1PVa46H6xaqoMCliYqNDThRVgUS0rlLMGBKeqiPxKc391jKGYSFcsAQrmc+APgblXSxqDQJug9cYbb7TO4uuNlPcc8XhNj3jf0a+ljUln2WWXrftR6s/1cwjzTA4mOEzLOTEIlRuHmLy499DJDeq4brLJJjVhvjUQ+QAdMaw20NOqzqMI5lXY0MzAx0j51Y82wXxUMwQICxosepiNnEOTXe/8q7QnPWLFBITMEFhC4yQobdD493YGUJehSY3JmTABf5zlsiahPMooaIkq6oM3fcgPTDh9/vnnRlQQzCPfmEjCBJStHiAcCFW8g05MHKC+op6gQw+BUdBkjbT556YFmko6a5i7gVAcJoSwIgadEK8QXX8fApw4TtoPrDMBgkmCyy+/3BiEEhTaWGCgj45hFV3VBfOt3OZ665tJUILfCNqiIFfGtrvIb5mXVaNcveGkdX3ggQfWtZXQmgoSwqD98gq1pG1XY3IgtNXb1aDzqIJ5fCdUOBggQ1gV1aH9V35NR9OqK2/YbOv/INLqGvNoBxdYYIHA+mSqY957ccxAgby3j6TCi2vKpuhvnfpdoa1BH0/lQx3Rn0VbY3JQuJL2+n1+4BdKKhD2hjlvG6bijTJZgv6u18QiNPoxvszTpfUtgSJZ0MrSsBVRtjzD3KXiiiNWtXtdGfsLKo1s6xWJ8h8PPvjgurqm1zuMnaM600SMHhZ+91hhGNVh/K77R5+lSIcxOdo4PU1QTizDZEEWXIrML0xP65wxtkd66IojQMF8xuyhkW0yX2MTiKnkYBCp/1i85zDd4f2DTUiECy0JaAFAyOf1p1/rM2XbbbedM3LkSGNa4efvf/+7SlrtiI7nmmuu6YsD2hhRZzWhOaanST+H5nwUZxPwxxHMIx6baRybYB6z2zBFAO1pLH+PatYFHz10IqGphwkKdKBhP11uWGhlobhA6JuFUw0xyhR1QcUXdIQAHB0E1HHdDRs2rE4ArIcB7XW5SZb+uvEcXEzlgZUWYQ6TB3qc6jxIqAIbe8phwBSlLOUmgw7yarJ1JzeCDRxE6Jq+6ERBq83LEemBeR2bDX98QLHKJorDSgddAAUTP1H2jcByX8XPe8SS4Sq6KgvmW73N9da3Rgf9ZW27i/qWpcXVG04a12invG1PlGXWWAmm76lj01JFHfrXv/7l9pnkZlhWMzZIQ9A3RM+rXo5PPvmk/ijwHCuRvCvE9LxHEWyyrf8f4lYXzOt1B+cQ1GNVHIR5mCyCWRX0120T8Mo/6iT6NFFdGoL5snzr0C+DqRfFAkfwimKD+dJLL63zp4cBTfsw16hgHgo8XhMwmFiQm9OGRZn680a/0aaEQGFFZ6ifw6xlIw6KMno4ppUEZe0vVK2tR3rR3hT9h3Ywrb264tQ570SZXu/QZwhymBzEpDzSrvvTz9EuIZyomvIqPm+6wtKi/GVxRB/Ia7JQ5RFjXphasU2GZpGerMMsOr+YqFV81RFKvnTFEaBgPmP2WD6tKrs6onEJ0vZSSTLZ1UYY0L6H4DKKwyZOKl7bUbdrBgEtBrAwlaHeh5AS5l68zqQFjnejCF5VWOh8B23EGsUuOZahq7Tqx7iCeZt2i00wr/KgjrqWnJ4OdQ7BKBpBkwBWhYGBf9DGhuCLZZ1ZOpMgROVBHSGwDRIKY7JCves9QggS5mDH0+sP1xBohzloWZn8QssyqrN1DFS42HMhrHNgmyBQYeAI+8gmwb6eTmim2wbOmBgKc5h08K6kiTrphbiDBEVxhE5h6czreZUF82xz62tJWoP+srTdZfmWpcW1vrQauzIJt2AqIoqDWT/V3g4cODCKF+eOO+6o+VF+1TGqYF4J8iCAiOrQ9/IK01S8OGJFV5gJG7b19bQpmBduXYYZqMcee6wejnYFIb1JEUKvf1GUIlSQaQjmy/KtO+WUU3ztAdIWxX388cfWvhv68mErYBsRzKPPpq9iRt8RWpFK+SZKutN8J81vCdp92z5eWKUa1j568wUm+ipfjLmjuDL0F6rY1mNCUG9Tijw/77zzohR1au9gxUdQfjEme+qpp9x2GsphML961llnueazoERpq/cIE89gyimKspU3Q/gNeMd42C8wbwcLDphc02VPNl6Q1ZjMGOad5iTxlSW/pt/kuHHjkmSNfhMSoGA+IcAg72jwunbt6muM+/TpE+St9gwCclvDBDu5Udydd95pDQNh2wR0b731lmtzGpr7mOX2Onw
ATMJCmN+I67DkytYYw8RK2EYfEPSaOMUVzK+11lrGcKIK5rHJqCkduId6EGSCRWcGYastHNzHKoOsna08EH8UTSFMHOkdXj0/UWbj0SHR/ahzpEvXbrdxMJmSgd+ogxOTjWCVBqQtigODIFM0WG0SdYLNuwGOSstuu+0WmhRs4KbexxE2+sMmA/RAg7TmIYCqmquqYJ5trr+mpTXoL0vbXZZvWVpc/SUW/45JuQArjKI4CMbURoy2vo43HEy4Kj96u4nzqIJ5rGzD+zBtFtV57bx64/ZuRGgKl219PRUK5oVrziTKRBYUGrAK0Fvv9OsoK1VQAkkF82X51k2ZMsUnEEPfEu1KVOc1WajzfOKJJwKDiSuYh1lKTOCpOKAMhInGIl3a35KgleBY8RTH4TuiWOFoW1XlDbMM/YUqtvUmIaDOP8/zvAXz3vJKmlfIFHbeeWcHZsbiasjr9RkCbm9aYHEhDwflPqQf41uTPMmbLv0aE8lRZSp55CVKHGXMLzYQ1rniPK7puih55zvRCVAwH51V7DfR6fJWeFxvu+22kcKCGQtbYxXFxiAieeaZZ4xpQDogoArT+rUldIcddvCFi7SiY9iI22OPPXzhKXZhgtCyCDMwaFFp9h6xsWYcB+G1Nwx1DdvgcTVD4sSNd20C5ahahwjDJtzGYCPMYbmayq9+hJZRlI2BbRv5RZ1l99pdU2nAkr84DoJz5Vc/QkMhymBZxQUNBt2/OseyxiAHwb/XxmwcO/8IG8vdVXzeI2yWRp3sCEpnns+qKphnm+uvJWkN+svSdpflW5YWV3+Jxb+D/Wq87c6RRx4ZOSBlIg72XaM6tGveOHEdVTC/+eabu/6jKlBAOSFIIy7KZDbben/pUjAvXJOJfjLmOxMmTKgz/+T9DUQ1pZhUMF+Wb53X9jJ4RB17KcJBwsiwPmAcwfzYsWPrNvmF4KoRDVqV7rSOaX9LoHlsmzjFpESU1egqb/qkL8YWUU04FN1fqGpbH/Rb8LY1WV/nLZjH6ug08wSt8bAV+KqeBx2hHe1N12233RbkJbVnNuU9b3ps12jjwtrQ1BKbQkBlzK9pJUcchZIUsDAIDwEK5j1A0ry0aUBF7dwiLWpQ6W2YogrX3njjDV+jq8JC+hpxEydONNrgxuaVjTpo6Kt0eY8wXxIk/CuLMANmPbxpV9e2jX5tvGB/Xvk1HbG8P0tn05wK2xtBTxO0V0xp79Gjh/6a8dw2qQUTM1GcTTCP30MUZ9uAD52hOM4m4IcmUxyHGWwTS2wIHOQgGPL6g5mhOA6/d28Y+jU2Tq6Sq6Jgnm2uuYalNegvS9tdlm9ZWlzNpRbvLmwx6+0NzqG1ikF+FKdM4cD8QFRn+35EFcxj1RtW2mHfnzAH03Y2ARzyisFnmMkLxMG23k+61QXz2AcoroMWpvf3pq6x6jCKckMSwXxZvnXvvvuujwNW8Eb5TXuZq4k6xRHHKCtxbe2Cd3LgpptuqjNHseyyyzowaVMGl8W3BCuzdZb6eRTzjuACBR+YhVV+45gdK7q/UNW2vlUF8zDZ27Fjx1pdU3VOHfE7h3Ii9vLDZOA666zj4DccNFmv/Pbu3dtJYnrENLaMujIqafuCVYyQh2255Zau1jzyAsVDlbcoxy222CJpMnLzX8b8QsnTyznKXka5QWvBiGZBnmWh0GVA4KijjhJS+O0LWZpYEc8++6zvvumG3OxGSBuHvkdy6Y+Qm9T47ntvSHvvQtqG9N52r6VATUhBo/FZ0E2pvStkA+N7RW5yKx588EHf/ag35ABcSA0P4+vS5riQnVvjMzm7JwYPHux7FoczPMtBnJC2Nn3hyA60kOZbfPe9N6R9OCHt5Xtvu9ey8RNybwHjM9tNaVdOyA+u8bHURhdy/wLjszRuSuG5kNojvqCkYF5IAYfvvumG3B9AyI6G75HUdhHSHI3vvvcGylSulhCyA+0+khM/4p577rHWZ91/z549hVxxot9yz+XSN4F6FuZQ3lKo7nsNvxf8bqI6lJE0PeR7XQrmhVx+6LtvuyHNGBjrv9SGF3JZtc2byx/loDupjSakLU39VuA5PhFB9X/48OFCfsgDwyjTQ7SHaBdRn+R+DWVKmjUtbHPNaOTGbUJuZuh7KFeSCGmn03ffdqMsbXdZvmVpcbXxjnNfbkYubr/9dp8XubGrkBpjQg7sfM/0G1KjSkhBjpBa5+6f/sx2vvjiiwtpxsL3WArmxYEHHui7n+TGJZdcIqR5BmsQQX0f3RO+tWzrdSJCqD5U3O9tfSjlv0LdlgIWX0LxrZd73fjuB91AnxPcbE5OdAm5x47tsXsf8UpzH7530E9AfyHIleVbh3EXxl+6k4ohQgpl9VuRzqWGsxg2bJg77uvSpYvbl9t0001D/UoBnZB7A/jek4J5IRVf3Pvnn3++kCuIoGTnXmMcc9ddd4lOnTr5/BVxI4tvCfrxciWVMTtyPwUhV24bn+k3pZlXIbWYa7duuOEGIU0z1q6DToruL1S1rcdYxVSfg1hn8UyujhCQV8iJmSyC94UZ1qaivqLeeh1+0xiroA3H711aOPC+4l5jTC33ChSQOcV1aD/knhl13uTKKSFNLtfdy+sCeQYPuVJboG+EtjPMFZnesLSFPS86v4hf7msopPJrLambbbaZK2up3eBJvgRkodBlRMBmjmTBBReMHOOIESPQ2/L9YcPUKC5IEx1LWBpxtl3B42ikmeI9+eSTfflUecdmmzZXFi3DIC2KRkzPyI6ilYdsOG04UrmfhsY8bOKq8vMeoyYS3LCsDst0o5pdmjp1qmtH3RsnruWEWKSo09KYx+bFpnTE1Zi31S0sv4U2hslhlQniMcWf5r1jjz3WFH1p71VRY55trrk6paWNZ/t94XeSZ9tdlm9ZWlzNpRbv7qBBg6xtGNq//fbbL3QfmngxOu4m46Y2MqrGfNT48K0K0hDDZmhRHNt6M6VW15hvZM8nkNQ3D/X+DqLs1ZBEY74s3zqYLfTmPerv0Vwb498N0pjHShvvBrlbbbVVJFOP8VPSuI+sviXQcPeWj7q+7777QhOs11FozkcdXyDgIvsLbOtDi7Z0L9hWTqO+Ys8v/JbDnFQiclZZZRVrnUdYF154YVgwvuemvQxhWq8MDv0jtLk2k87q917Fvc5MfIvKr5wsrqtXsPlPVxwBmrLJkL3UFKir7KoRQSMDG3FRnJxdNoaBzmsUl7ZgHum2La+CXfAkTq4AMOYV3DbeeGNr0GURZqTdWQNrqWVvZIKBU5YuDcH86NGjjWlHeabtsCGs1BJy+vfvbzSzpH57eQvmscRYxa0fsawxjjNt0KLCsy2tRkdOvZPlsWrL3qommGeba/+lpDXoL0vbXZZvWVpc7SUX/Qns/krt+MC2DDbhpXZVoMm76DHmJ5gPEi7BHukXX3wRKdls682YKJjfyQwm5K5c6Wf9vUGgGeZ0oafe9wjrK5TlWwe7xXq61TmUh/J0NsE82mevORepceugH1w2l9W3JMjO+/rrrx+IAZtl6uNYTO7GcUX2F9jWxympcrwrtc+N7QnalR133D
FyIt9//30Hip2qPfIeUaflSr/I4eHFE044wRde2ey2K3OE3vyqa8hJytj2xSoI7eW88+udhMaG5XTFEUhfQlZcXkoXs1xi72vwVEMS1SYYGlnlRz96bQzaMp+2YB4b0urp0M+x63gSh1laPTz9PGjTzbIIM9LurIGlaeM7cIGtyyxdVQTzH374oSOXHDu29Op1COfNKJi3daJMdijB4NFHH031rxGN4izrbljYVRPMs821l2hag/6ytN1l+ZalxdVecvGeYMWcty03XWMPEOxPktRhXxtT+GlqzF999dXGOFS8mGiO6tjWm0lRMN+YYD5oXxlpItIMW7vbqGC+LN86m5LQyJEjtVxmf2oTzKs2Qj9iTPDSSy9ln6iYMWT5LQnSIJbmbqwplWYy6trep59+2vqu6UGR/QW29aYSKe89jM+CNL6vuOKKWIkPWomO9gD7WcRxpt9nnA2U48SV5F3bnnWqDYTyWjO5PPO72mqr1bWHmOSlK44ABfMZsrd9QNGQYIOlKE7a2qr7wahGCB3HKC5twby05WVMD9IlbUNGSZL1HSwlVPnzHtu2bWtd7lUWYUYWnTVsiuJloa5tmtJWwDEe2ATdcTZ/zVJjXtp4dDWGpG20Gh+YNcDKiltuucWx7X7eSoJ504cdvFrdVU0wzzbXXmNNgwq0j1FXlKmQy9J2l+VblhZXxTfpcebMmc6aa65Za+vVN9B0xCB4t912czBp26jLWjCPDTRtq+GQp1122SVW0tnWm3FRMN+YYB6/N9NvC/ew7D3MNSqYL8u3zrZiIK9NERXfOIJ5lI3cGyPyKhsVR9bHLL8l0va2tZ4GtaEYJ6j6vcwyy8RGUGR/gW197OIq1IPNnKmqf3E13GH2Br9z5d90jCOkxmoRPQzIWsrqsLpAT6t+DrPPzebyyu8mm2xSxxV9bbriCFAwnyF7CJrl5ox1FV41JFiabdN01ZM0atQon3/YJI06o5m2YD5I2Co3StKT3tC519aV4oXjRx99ZAyzLMKMLDprclM4X/krJjYeRkgxb5ZVMA8tQu/sbrt27RxoVELLS7lFF13UyK2VBPNyQzAjg6+//lphaslj1QTzbHPt1TStQX9Z2u6yfMvS4movufhPYNYFmjzq+xd2hLkwmLdpxGUtmN9iiy2s+cBS9c8//zxWstnWm3FRMN+YYB40YR7K9BtDfyvMNSqYL8u3Tm6iaMy73DA0LOupPo8rmEd5yU1lUzPplUZmsvyWwHzFEkssYSyrNm3aONOnT/dlQW5A6sjNMmt+IOiO64rsL7Ctj1taxb4PJQFTO4p7Sy21VEOJC9qPD+FeddVVkcOVm9nXpQ+/m7I6KKraVh+cffbZZU12w+nKK79ek4pLL710w2mmx+QEKJhPzjAwBNPGGqqRHjhwYKBfPJS7r9c1mvCLhjSqS1swf+utt/rSo/ITJ1229GMpugrPe7RpiJdFmJFFZ03utG7kgY+TbdNPG9s498smmH/hhRd8GpPQkIfN0g8++MCXNQrmHXeywvsbwvXzzz/v49VKN6ommGeba6+daQ36y9J2l+VblhZXe8k19uS7775zteFN7ZrtHgRVcbXnsxTMh2nQQRkjrrOZ+mn1tp6C+cYF87169TL2Pbt37x5aPRsVzJflW3fAAQcY8w7zU3k6m2B+n332cXr06GFMI9pBbJhdFpf1twSTr7a2H5vjep2+GgICetP4wevHe11kf4Ftvbc0ynsN7fYgm/CHHHJIQ4kPaifxW8D+hlGdaRIS/ayyOptSw/HHH1/WJCdKVx757du3b10bChPKdMURoGA+Y/bYZMZmBw/C1SD7Yq+99ppvI0toC//444+RU522YB628W2doDR2crY1QuhAYTd6kyuLMCOLzprNDu28885rQpHavbII5qERg44DhPB6vYPgHbxtjoJ5x9l3333rmCl+EAy1squaYJ5trr22pjXoL0vbXZZvWVpc7SWX7AmWLWPloGrTwo5Yufj6669HjjQrwTy0NTt37mxN94ABAyKnUX+Rbb1O43/nFMw3Lpjv37+/sZ5GGbQ3Kpgvy7fOa95BtS9oF/N0NsE89hiDOUfdlKNKI44YW8KUahlc1t8SmF2yrUzv0KGD490DafXVV6/Va9TxRlyR/QW29Y2UWDF+sOeD/rv0nmOFUCMOv31vWPo1TDVFdbB0oPvFeSOTVVHjS/oe9vnwphfX2MS2GV0e+VVjYsV1rbXWakaUlckTBfM5FBWWJdu0T2zCeWxG4xWOduvWLXaDmbZgfsaMGcZGET9odI6SOptNra5du1qDLoswI4vO2rnnnmvk3egSOCtEzwNv3VMNdp425vG7WX/99X35X3LJJUOX+lMw77gTGqrc9GPU/S08VaJpLlUnBEugq+DY5tpLKa1Bf1na7rJ8y9Liai+55E9gp32HHXbwfR/0tk4/R//p3XffjRRxVoL57bbbzppe9HE+++yzSOnzvmTSekPeW72tp2C+ccG8rS+OFShhrlHBfFm+dUOGDDH+Tvv16xeW9VSfBwnmERH29dLbOP0ce1jo5h1TTViMwPL4lpxyyilWDrqpmsmTJ9e9B83jRlyR/YWqt/Wok9iYt+i/KKaEG6kbuh/TfgDqNwqTYJhUasQ988wzdfVYhamOcQSrplX5b775ZiPJysWPLe9Dhw7NJf68I8kjv+h7qrqDY9w9uvJm0uzxUTCfUwljwGUyS6N+DL1793ZOPPFE92+zzTZzYINePcNxq622cqZOnRo7tWkL5qG17k2bns6kHztbRzRIS6cswowsOmu2Ti0GPlm6ogXzqGdYgaHXLZzPNddcDuyuhTkK5h3ntNNO8/EDQ2iaYv+LVnVVE8yzzbXXVFv7GLdjWZa2uyzfsrS42ksuvSePPvqoE2QCT/+GRB2wZiGYv+2224ztsUrf7bff3jAUtvVmdBTMNy6Yt/XFjz76aDNs7W6jgvmyfOuuv/56428VGupxzWJpWGKf2soAGvNwMJVhW2WMdmX55Zd3ijZLkce3BEo8thVUML2k9mQ7/fTTa+WKPRR++umn2GUCD0X2F6rc1sO0mnf1s/r+5X0877zzGir7OJ7WW2+9Wn3z5i+OVrs3zrvvvtsaLuLBpGpUd/HFF/vCgmJoWR32GvKyxPUtt9xS1iQnSlce+cUkkc70b3/7W6I003MyAhTMJ+MX2/cjjzxS9wPQfwymc2h13nvvvbHjUR7SFswjXGgsm9KKezfffLOKuqGjLewgm2llEWZk0VnbZpttjKyHDRvWEN+onooWzNs0DaJu5kfBvONceumlxrqD32kSIZBeh2BWC5sb5dHJ1eNNcl41wTzyamsXW73NTWvQX5a2uyzfsrS4JvmdxvELgR7MvgXZc1V9FvTBwlzagnkoZni1klR6cIwzkDalnW29iYrjUDDfuGDe9ht46KGHzLC1u40K5hFEGb5148ePt/adDj74YC2njZ9C4PL1118HBhAmmIdnCKWD7M0nbVsCExjhYV7fEls8aF+vvfZaN6WYqFDt7kEHHRQh9eZXiuwvVLmtf+CBB2r8VTkUdcx6z
PLtt98GKjGef/755soV4S42Og3ihlUVUR1MXnnDuv/++6N6z/09mIf2phfX77zzTu5pySPCrPOLMbyXZ6N7H+TBoxXioGA+51KGpoP6EXTs2NHZcsstHXS+ICzCoBImStCpxcz+fffdl1izNQvBPGbTVB68Rwjpkrg555zTGPaYMWOswZZFmJFFZ61Pnz5GHth/IEtXpGAem/zaVmVEXTVCwbzjBNlrxQqcNNyhhx7q1k+YiaiKq6Jgnm2uuXbZBuNl0JhvpO0uy7csLa7mUot3F30kCAmjrPLBYBib0Hv7Jfo17EeHOZtQEtpljbig1ZIwAWjb2D5qXGzrzaQomG9MMI+JLpP98rnnnjuSlnESwXwZvnUwk6W3Gfo5+qbTpk0zV7iId6HFvtxyyznzzz+/M336dKuvKIJ5eMbYo02bNtY0w+RNUS6vbwnGBjZt7BVWWMF5+eWX6/g899xzDSMpcqxX5ba+lQTzYVrtcfa88VZU02pyvY1SE1Fef6brV199te53gXDKrH1usq/fqVMnd/WQKX9Vv5d1frHvkV53cJ5k0qjqvMuQfgrmcyoFbGK5zz771H4A6LhCYyJrl4Vg/p577qnlw/uDxiATnfpGHLTKvOHhGsL677//3hpkWsIM2ya9UXf7TruzhuWX+OB4mfzpT3/K/CNUpGDetuM8BjFRHQXzjvubwabJ3vqjrhvdeEiVgd7xhBZPVVwVBfNsc821yzbo32CDDcweLHfL0naX5VuWFlcL7li3seE92qyxY8dG9nfHHXdYBVUbbrhhaDhpCub1dlK1vfoxrp1jrEpEnwjKG8qhf8S2XtH435GC+cYE8++//76x37DTTtHCSyKYL8O3DiZiMAmh/07183/84x//q2QNnGHvB4S30EILBY5togrmkYSzzjrLml60DVFWCjWQlVAveX5LUD/1ctLPV1xxxdozCOqTuCL7C1Vu61tJMI8VGXr908+TKDJBkGqaNFXhw9wrtKyjOigzKL/qeNlll0X1nvt7V155pS+9//znP3NPR14RZp1f7Pmgyl0d77zzzryyx3gMBCiYN0BJ+xYE1aojhoqPXdXR8cvDZSGYR3569uzp+zGrH3WjpndGjRplDDNsWY1NmAHhchzXuXNnY/xFCeZvuukmY3qyNmMDZkUK5g844ABjvuNs1okBj6qP+vHZZ5+NVCVsdhxhyziOs5UhVsvEcdDu0fOhnwft67Dqqqta/aGM43Tg9PROmjTJUb+XxRdfvGbDU3+nrOdVFMyzzTXXJtugH5OscVzaA23b7z6s7S7LtywtrnHKwPauEsyfdNJJtleM921L/pdeemnj+/rNtATzX331lQMbx3p7rZ9DgBnXwawfwvD2S9jW+0lSMB9NkO4lB41LvZ7ifJZZZom0vw/CSiKYL8u3bv/99/cxUEygnQ572Y24f//737VwwxQa4gjmMabE/mQqjd4jFFuwuW7eLs9vCbSQbVrzOo9zzz03EYai+wtVbeuh7Q+lqaL/8H2HCZcsHawf6HVOP991110bjjpsRSBW58V1Xbp0qUtrEjNPceOO+z5Miekscf7SSy/FDaYy72ed34cfftjHE6uL6IojQMF8xuzRWYJ2hWpIsMlroxvONJLULATzSIdpJ2+Vx7jaiipf0GRTYagjlo2GdSZhd1y97z1CCz+Kmzx5sjWMIPv2ethpd9awUZ03P/POO28umzmlIZjHBI03/epa5+Y9V4IY9a46QusH9tDCHN5p27atMW4sC4vishbMd+jQIUoyau80Kpi3CQgVU2xAFHdzMAxIdXvJ2CitSq6KgnnwZZvrr2XYMF3VZf2IgVccV5a2uyzfsrS4xikD27vqewD7wHEcVimaBsdRJm0gONDrkzq/6KKL4iTB2XvvvY3hIDwIyoImVW0Rrbzyym6Y3kketvV+YhTMxxfMY8yi2+JWdR/7HUV1SQTziKMM37o33njD+tsFEygmvPLKK1GRuO/BrjUmOOAfK19/+eWXQP9xBPMICOMdW98dcfbt2zdSHzowUTEf5v0tgTlVVWdNR2gbQ+s4iSu6v8C2PknpZe/33XffDayD2A+nEffiiy9aVwKirkNbHkpTcZ13omf11VePG0Qu7+N3O88889SxxWRkXIeJZ+y9ARkLJg6jygXixpP0/bTyG5SOoUOH1vFEPWpUYS8oHj6LToCC+eisGnrTK+CLsylHQxF6PL355pu+H53qrEBo36iDTVSTmRUVdtwZt7fffrvWYVVh4BimLY/0w/687kc/P/XUU0OzCEEuPkS6P/086gx0UGdtypQpoenQX8AMsJ4GdZ5U00OPI+jctone5ZdfHuSt7tmNN95ozAPyEuSCymL48OFBXt1npqVfih+WUkZxp5xyijHtcTXmb7jhBmM40LiKs2oGmv4qD95jkHAH2mfLLrus1S/C6tevX2Qbx5hswaSCSgPspCKOKrmqCubZ5vprGYSTqi7qR0zMYYluVFeWtrss37K0uEblH/SeEsyjfONqRulKEap+RBEw2vo2EBhGdQ8++KCxbqp0wCRNXAfb10qwB3M9umNbr9P445yC+fiCeZMpmfbt2zvoo0d1+I2peq4f99hjj0hBlOVbh71K9PR7z6EoE0XIhjGGV3M8yia6cQXzgPvEE08ECu+ilkGkgorwUt7fEtiatynmoEoFr0kAAEAASURBVPywZ0lSV3R/gW190hLM1v+FF14Y2G40IntBvYYJHG8bpF97J+uj5hKmYPRwIOCHYkPZnLc/h+9S3P0+dJPSKs+YrIPJwbK5NPIblifvt3qZZZYJ88LnGRMIlpBlHHmzBw9N1Pnmm6+uwYMGOJamQMCFQRbsiGfpggR6SW0OYnNa27LBNdZYI5bATm0iqRpKHHv16hVpwzcMGHR/+nm7du2cF154wYoYg5All1zS6h9h9ejRI9Iqh6DOWtxl+CaNozXXXDMWU2umIzywdWzjCOaxUZ5eFvp5kNAMM9n6u/o5zALYBNHYs2HgwIGBgxJv+m1CZWUuQI8b53EF80EasJ9//nmEkvjjFQzivGlR12ErSoImSFQYGGCecMIJzkcffeRLEyYQ8DtRQg7lB3aOMQismquqYB6c2ebW1zYIJ1V99B6hVeZ10FCEqSxoMusr18rSdpflW5YWVy//Rq51wXxc0y+6CUFVP6644orAZKC9s9lrRxsZxeH7FmTub9ttt40SjO8dXfv0mWee8T1nW1+PRH2z4q5Qqw+l/Fc2LWFspBrHoe6bVmr+5z//iROMr6+gfntbb7115HDK8K0LWvWp8oQjlEmgEKL36cAStvoxmec1rbj77rtH4gClCT0edX7MMccE+j/zzDON/pT/4447LtB/mg+L+JYE2feGydSkrgz9Bbb1SUsxO/9YmaJ+a94j+hZoG+I42PwO6k8gDrQpccNVaTCtwMCmsGm4sWPHOrvssouDVaywigBzjY1sdj9hwgSfvCnKpKieB6z29paHusZkxNdff62/3tB5mfIbJQMLL7xwHROY2qYrlgAF8xnyv+222+oqvGoAvEfYm8aPA7ZPscwa
WhJYngMBJWbMILQ++eSTneuuu855+umnYwlnTdovKv4o2sdheEzLYFT4aICjOJjC8A6EMRMadVYZHyPT0luVDgjnscs0lgXh3Xfeecfddbx///515RO02RMG0hDww2aszQV11pCGqDudI60q7eqIOoJNOvJwyKOK13uEJnlU59UQ0sPCUj+bwySG/q73HGYGMKmESS1MfmEZGjriQeWnwoCQB9pLqAfo2GLAZFqOPGDAAGMa8DuN44YMGWIMB+kJmjDyxoG6o/LgPaLDEuQw+YBZcK8/0zU0+WH+ASZuoFmEDbO8SwfhD++hbamiq7JgHrzZ5v6v1gVNPKOc9YkmbJqmbPCi/uqaNmVpu8vyLUuL6/9KqvEzXTCPtidOu6N+66qtw0qwL7/8MjAxmOhU73uPEIBGcUGCIdhzbcSUAgaV6EeoNJkmZNnW15dOqwvm0SfShcX1dOqv0JeC4FzVL3VsZGM92x4NUNiJ48rwrQuz6aw4qSNW20CIpv9W1TMcsYIxSDFF52Prt+255576a75zfEe84xs9DTgH2zxcEd+S6dOnGzfIhPmwNJThytBfYFufR+2NHwf2OfD+1rzXet8zKAZ8822rZvQwYVkAv/lGnWmzb5h7SerwTTGNHyHfGTRoUGSTKY8++qhvchOTj3GdbfWOYhm250dYfGXLb1h6P/jgA19dHTlyZJg3Ps+YAAXzGQIOWpauGoJGjljOBI2HKHagIFC2xYEOYpiNwyh4dthhB2scRx99dOAHAw2hWpqt0olVBnE1cU2bVanw9KNtN/Njjz3Wgf1Y/V3b+aabbmrEEtRZQ1jIJzRdbLPF6GhBI87LA0J5TMjk5aBRaMs7lvZGcej8LrDAAtZwMPlgc5g48TIwpQeTOab31ltvPWfddde1xg2e6BioML1sMWvuXemi3sUKkQ8//NCWdN99aNgrv97jfvvt53vfdgMaB17/6hrmssIcJjJsqyBUOFGPYB5Xgy4sfXk+V8K6OJsJ55m+KHGxzf2DEpbbYkWTre5iIg1L97FcUzdP4jUJVqa2uwzfsrS4RqnLYe94BfNom6MsO/ZOcKPtxkZXYS5oYhia1xh8BbnHHnvM+F1SdRQTwnEcJrEHDx7s2o9VYSAvtr4b2/r/0W0VwTzGA6pueI/43kHAE+QgyMQkvNfvkUceGdh/N4UZpGWOPsiUKVNM3qz3iv7WoV+OlQdeNo1cY8VnVBvQUEyyrUjG2A3pCnKYkAmyN4/0Y6Jx5syZQcEkflbUt8S0WgrKQmm4svQX2NanUZrphYHfpGnFkbetgAImzAybHBQH0IZig1jT+FYPC31a7ypwU5hR7mE8pIcdxYxwWLgmTXw9DijGQRBsm1TA/dNPP71OcRNWJ84555ywqI3PL7jggro86mnBeaP7I6rIypZflS7bESsxvAyiThrZwuT95AQomE/O0BoClsrrwgDvDyDpNTSHTcuZ8XGAloJJ+8UbJ7Qq8G4Sh/iwXBMNpjd8XEM4CY1fvQMIW7FYMuN9H0IWbLoU10EQjI+dN7wo12p5ehTBPITNNrv1YZ01lRaYANl///0drBQAe/iDEAGbQal31BFCiDw3Jhk9erRvZlqlRR2xT0LQprowNaM0U5Uf7xEbZwWtIDB1qr1hmK4xEYSBQJCmle4PvyFdexHmpYIms+AXQv+wwRXM6gStGEA4GKRiM7AgDR5o+WKVAjR89XTr56gjEAaFuf/+979GLSI9rLBzpCOpVkFYOrN+3gyCeba5/6sl6LiH1Vv9ucmMSJna7jJ8y0A3Da7/K6XGz7yCeVWWMEmE9tHrsCLq7LPP9vVHINwOcj/88IPrT99DQ8WlH7Ec2zaoRh/HtOGs7h/CsKA/rNbaYost3AE+vpO6X3UOzf8gx7b+DzqtIphHv2XxxRc31hVVZ2BuBf2N+++/34GZAuwDhQku9MO94xR852EnOY5DuwVzLl6zLSp+dYTWvFcZIiieMnzr0DbAlKTKQyPH3r17u6t1g/KKZxDMgWOYiU2syglTzkJ5h6UV7RWUcdBuZuWK+JZgPylv3uPufWbjUab+Att6Wynldx9tHyb90bZ561zQNcz1QgajvveQAYQJ4xEelNJ2l6Zr9LFr0txiFY6eVrR3SR3Mkelh2s4xJsM4F4JirGAfP3682xfzThaDV9x9hvQ8YHW5LQ24j7YwiStbfsPyoptGRP4hD6ErngAF8xmXAZbgBGkOBzUSUZ5hWfTkyZNrudh8880daARH8au/Ay18DIAhMG7UodODhlMPVz/HB6dr1651ml/qOTTZoZ2TxMYXNES8DbkK33TEJAAGKcrZBPMrr7yyu+wqzIxQUGcNH15TGoLuQcszjna2ykfcIzr3m2yyiWtOKSg9+jMIlVHWen2BpifuBQmR9TBwjt8GZqm9qwiglRinLGEGCtojyuE3YZsoUmnAcmNsqAONRJhtQTnblh8rP+qIurzCCis4WD2ha2FCsx0MMPmi3g07YnktOnT6xtDQCED4thUepjCxfBzpCTI3hDpsWx5tClO/B62xOBvAqbIo27EZBPOKaau3ueCAtiNMM1DVY9iWN02Ela3tLvpblhZXVU+THG2CeZQp2ni02yjX448/3sGErrcuYHUU2lMI+LwOE6yYRMYm1lFMoal6hPYfqzHWWWedSH0I5S+tI5QQwhzbeqdm67zZbcyjLqAfh99BWL8nrA6iLxhHOQaCffRV4vR5kAb8fqBdetZZZ4VVZfd50d868IUgI047gXyi/YGAHAojQQ4aqlEFc6oMkRa0BZhs1u0tY2UuVhrYVn4q//oRNpZXXXVVJw0b7N58pvGN9oYZdu01SRZ3f6ig8MvWX2BbH1Ra2T7baaedYrd9+u8uzjm02rFCGqZn0nZYBa2nBW2Lqc8UJ14oTng18fU4op5jkhLa7vq+UHHSod5FfoKUBpHnJK5s+Q3LC5RM9DKA2Ta64glQMJ9DGUADAgIz/QeQ5jkaGuXido696YCwOolDw4kO4vrrr29dhqnHCc0vNAY2DbS4aYFAAwL+IK03LA/FYMKrbaIL5vExgV3/MM1oPX1BnTXUgWuuuSZ0V3UItVGeUZbq63EnOUenWS+TOOfQ0lIubMOpoHBNy5tRPhD8B2kQQLiOgZ3pgw2teZvfrbbaqqb1D8G87b2gNKtn2NxKuTBtSeXHdMRHUjnYdTW9E+XeX//6VxWM8QjtLwiqogrooTkR17SUMeKS3GwmwTyQtnKbq6oUbCUGrZhCOwEtRJsrY9td5LdMcUrKVYWT5KgE85j0xLcd+19gs+oobSEmE2Gmw+aStv1IA/oNyqHvEyVdSd/BRH8U1+ptfatozOt1AVqU6IuhPYyqJIH+MlZq6IoqephB50EbHUap51CKiOqK/tYhnVCWwZglTGkC44ztttvOwWqGKC5on6woHNXKX8Rl2zQ2SjhXXXVVlOTGfifvb4l3f7V///vfsdNs81DG/kKrt/W2ssr6fp8+fVL95kM4jElLTHZi4h8TdjD3AhOvWTqs9vP2q8aNG5c4SoztMdkbd4yNbxK+DTDrYzN
100ji8Ds54ogj3PEv2mjIMNTqdqwcC5tADYuzbPm1pRf59JZ32H51trB4P10CsyA4+bGmy5DArbfeKuQmrkJu+CPk0mRxySWXCGkKREhtXSE3p3P/cC21xd0/2UAK2QEVsjGKlCrZ4AkpQBaygYn0fl4vyQ6skCY2hOyYCqm94OZTaloLaWfR/ZNLO4UU+gnZOKaeJCkIF1LzQ8iBt8sZ1VwusRXyYyek4FtILWVfnPIjJORSKiE700Iu+/U9D7shTc4IqflsfE1uqCrkJihCLhcVY8eOdblIky9CCp+F/BgIuYTI/ZM23IVcVWAMo1Vv4rchTf4IabZJyAkcIT8mQgra3DKSHReB+m9zUqAs5F4Pbj2Qg1QhJ56EXIkg5KDV5qWl7kvbs0IuDRRSQ05ILQy3zQFbacNUSMG9+ydX4DQVE7m6QkjNfyEn39x2s5ky12ptrl52ssMt5BJYITuXQppqEHKSWshVJEKuvBFyubCQy3/11+vOy9x2F/Et0+Ek4aqH0+g52n6p6VT3Tcb3HP0nuexZyM3WBOo9vhNSW9j9xqPM0b8wfecbTUfV/bViWy8nqQX6dXKQ7/a1ql6GcdMvNfjc9lBqmwupgOH+RqSZPaH64XKlrJBa2kKa53PvxQ2/yPeL/NYh3+CJ7wz6EvhD/wn1TCobCSlkF3IC0R3vFcmobHHn+S3BmHv48OEuAnwX5MRAauVR5v4CMtyKbX3Z6noV0yMVwlz5lEo7ruWGqeoy0RGymBEjRoiJEye6v0X8HqXyiZAritx+GvpqcgW9OzbHd1uucAnssydKjMcz+pHoL2KMIPfuCZQpeLxaL8ucXyRamt0RUmG4ln58s/A9oyueAAXzGZeB3EhMSNMSbixyqaCQM+1uAxAlWgxG0bFWfxigv/DCC0IuWxPSrpxAB1s5uaGYkEuq1CWPBRCI0lkrIFmMkgRIQCPQzIJ5LZs8jUGAbXcMWHyVBCpCoNUF8xUpJiaTBFIlAMU2KGJhzAy3/fbbC7mnVWpxsL+QGkoGVCICUNKCQFw5KAliwjFIqUW9W+Wj1MoX0gy0OzEgV5xVOSuR0/73v/9dSPNFtfeluTVx+OGH1655UhwBCuYzYg9t9/3220/IZYG1GKQNbiE3/apdJznBbCME8dAGgpObr4pjjz02SZD0m5AAO2sJAdI7CeRAgIL5HCBXLAq23RUrMCaXBCIQoGA+AiS+QgJNRgACJwielHvooYeENK+qLhMf2V9IjJABlJSANM3jrrhQyZN7qIn11ltPXTbl8dJLLxXSPJm72hLa883u5B5+7souKP3CSbNs7gQMV5iWo+QpmM+oHKQNK3H++efXQpe2ssRdd91Vu07jBMv2pQ1NN6hTTz1VyI2J0giWYTRIgJ21BsHRGwnkSICC+RxhVyQqtt0VKSgmkwRiEKBgPgYsvkoCTUJAbiosYMYSTm4c6ZrPCDJ5GTfb7C/EJcb3q0Lg8ssvF3Jvt1pyoQAKiwzN7JDHm2++Wey+++5C7gXYzFl18wYl4T322KOWT7l5uLjttttq1zwplgAF8xnwx5K5HXfcsS5k2EmVO97X3Ut6MXnyZNdOMsKRm46JvfbaK2mQ9J+AADtrCeDRKwnkRICC+ZxAVygatt0VKiwmlQQiEqBgPiIovkYCTULg8ccfF2uvvXYtN2effbY46qijatdpnLC/kAZFhlFGAtCmxr4j2MMDDmZsYBe+bHsYpsUOezrCtj20x1tBjob9mbDXIvZmgsPee9iDBvvK0ZWDAAXzKZcD7MJjwzlsdqo7/OhhYz5Nd9FFFwm5Y7e74Rk2WOUylDTpxg+LnbX4zOiDBPImQMF83sTLHx/b7vKXEVNIAnEJUDAflxjfJ4HqEsD4G0J5pS0/zzzziOnTp4uOHTummin2F1LFycBKRsBrCuqAAw6o2xS2ZMlNlJxbb71V7LDDDq6AeurUqaJ79+6Jwiu759tvv11st912tWQedNBBArJEuvIQoGA+5bLArsa9evXyhYofv/5j8L0Q88a0adNcDfxPP/3UtTXf7EuNYuIp5HV21grBzkhJIBYBCuZj4WqJl9l2t0QxM5MtRoCC+RYrcGa3pQkcf/zx7n5rCsJxxx0nhg4dqi5TO7K/kBpKBlRSAurbieS1a9dOQGiNzWCbyUF7fI011hBPP/20u/8jTLw0s0N+sbkvNOTh5ptvPvHOO++4x2bOd9XyRsF8yiV23XXXuXaqvMH269dPjB49OpUfADan2GqrrdyGEps2wNZ83759vVHyOmcCQZ21L7/8Usw777w5p4jRkQAJeAlQMO8lwmu23awDJNB8BJRwoUOHDuKbb75pvgwyRyTQAgSgCQ+tTmziivOFF17YVUzbfPPNRbdu3VwC0PLdddddBYRPcHPPPbd49913MxEmsr/gIua/JiYA4e0qq6zi/t6QzX333VfA/nwzOSi07rzzzgL7T7z22mtNb87Fqy2PNhUa83TlIkDBfMrlcccddwhspGByPXv2dDdYQGPXiHvppZcEZvSGDRsmfv31VzHrrLO6m3JgGQ5d8QQee+wxsc466xgTghUOPXr0MD7jTRIggfwIUDCfH+uqxMS2uyolxXSSQHQCFMxHZ8U3SaCsBCBwHzlypDF5vXv3dpWe8A3X3amnnipOPPFE/VZq5+wvpIaSAZWYwMEHH+zKm5BECK+xfwM2Vm4G9+2334rll1/eNXW15ZZbirvuuqsZsmXNw9dff+1OPHzwwQfuO7Apj8kX2JinKxcBCuZTLg+YlsFyHzVr7w0ejRu057fZZhux9dZbuzvGm3aL//333wXCwvIhNBgwhTNp0qS64K688kqx9957193jRXEEvLORekomTJgg+vTpo9/iOQmQQAEEKJgvAHrJo2TbXfICYvJIoAECFMw3AI1eSKBEBGAjftFFF42VIijBvfnmm64JjlgeI77M/kJEUHyt0gS++uorscwyy4iPPvrIzQfOoSA6xxxzVDpfkK/B6sQ999zjbm4LUzbNLp+BrBCb28JBqffhhx+2KpJWunCbIPEUzGdQiEcffbQ455xzIoWM2aouXbq4u0JjoxrMan388ceuUB6Nh8lhB+mLL744VZv1pnh4Lx6B0047TQwcONDoafjw4WLPPfc0PuNNEiCB/AhQMJ8f66rExLa7KiXFdJJAdAIUzEdnxTdJoIwEHnnkEbH++utHTtpss80mxo4dK/7yl79E9hP3RfYX4hLj+1UlgI2U11tvPfHzzz+7WRg8eLA4+eSTq5odN93HHHOMOPvss93zQYMGiSFDhlQ6P2GJhwmwDTfcsPbaSSedJP71r3/VrnlSLgIUzGdQHjAzg41es1gas8suu4gLL7xQdO7cOYOUM8hGCXz22WdiueWWE5988okxCGh8vPHGG2KuueYyPudNEiCBfAhQMJ8P56rEwra7KiXFdJJAPAIUzMfjxbdJoGwEPvzww5od+ShpO/300wU2gc3Ksb+QFVmGW1YCI0aMEHvttZebPGhbQ7aF/R2q6O677z6x6aabuknfYostxKhRo1yt+SrmJUqaYboGVj
ree+899/UNNthA3H///a7WfBT/fCd/AhTMZ8gcGyugg/D9998nigXLhrbffnsBe1/4gdGVgwA2Fnr99dfdHb2xgiFsc7H555/f/bitvvrq7ibAmFyBnS86EiCB/AhQMJ8f67LGxLa7rCXDdJFAegQomE+PJUMigaIIQGMXWvBh7rjjjhNDhw4Ney32c/YXYiOjhyYjcMQRR4jzzz/fzVX79u3FuHHjxIorrli5XCrtcZiTxr4Vc845Z+XyEDXBkD2iD/TCCy+4XqA8ihUQsM5BV14CFMxnXDZffvmla9cJu1l7bcQHRd2uXTtXCL/xxhu7wlzYracrFwFov//www+JEvXLL79w841EBOmZBOIRoGA+Hq9mfJttdzOWKvNEAvUEKJiv58ErEqgigS+++EJgLPzcc88Zk9+pUyd3JTk2ic3Csb+QBVWGWSUCv/32m4CG+ZgxY9xkL7zwwuKZZ56JtZqlLPnFxq8dOnQoS3IySQdMYWMfy7vvvtsNf5FFFnGF8jjSlZsABfM5ls+MGTPE+PHjxYsvviiwHA5/0LKee+65XRvz0KiGAH7VVVd1hfJt27bNMXWMKi6BHXfcUfz0009xvdW9j02EYBORjgRIIB8CFMznw7nMsbDtLnPpMG0kkA4BCubT4chQSKBoAo7juCY0sGkhlNxgMrZ79+6uRihMx3bs2DGzJLK/kBlaBlwhAj/++KOAprkSzvfu3dtdyYKJMbpyETjooIPcvSiRKphSxoqjxRZbrFyJZGqMBCiYN2LhTRIgARIggWYkQMF8M5Yq80QCJEAC9QQomK/nwSsSIAESIAESaJQANoHdZ599xHXXXecGARM35557bqPB0V8GBG688Uax8847uyGvtNJKrtY8JjHpqkGAgvlqlBNTSQIkQAIkkAIBCuZTgMggSIAESKDkBCiYL3kBMXkkQAIkQAKVI3DZZZcJbLSMTWFPPvnkyqW/mRP8yCOPiAEDBgis9DnjjDMETHHRVYcABfPVKSumlARIgARIICEBCuYTAqR3EiABEqgAAQrmK1BITCIJkAAJkAAJkAAJkICgYJ6VgARIgARIoGUIUDDfMkXNjJIACbQwAQrmW7jwmXUSIAESIAESIAESqBABCuYrVFhMKgmQAAmQQDICFMwn40ffJEACJFAFAhTMV6GUmEYSIAESIAESIAESIAEK5lkHSIAESIAEWoYABfMtU9TMKAmQQAsToGC+hQufWScBEiABEiABEiCBChGgYL5ChcWkkgAJkAAJJCNAwXwyfvRNAiRAAlUgQMF8FUqJaSQBEiABEiABEiABEqBgnnWABEiABEigZQhQMN8yRc2MkgAJtDABCuZbuPCZdRIgARIgARIgARKoEIHKCOanTp0qJk2aVCG0TCoJkAAJkEDZCJxzzjnik08+EZ07dxbHHnts2ZLH9JAACZAACaRA4NJLLxVTpkwR7dq1E0OGDEkhRAZBAiRAAiRAAiRAAiRQdQIdO3YUffv2LVU2KiOYP/XUU8WgQYNKBY+JIQESIAESIAESIAESIAESIAESIAESIAESIAESIAESKDeBlVdeWTz//POlSiQF86UqDiaGBEiABEiABEiABEiABEiABEiABEiABEiABEiABEggTQIUzCegSY35BPDolQRIgARIgARIgARIgARIgARIgARIgARIgARIgARalAAF8wkKnoL5BPDolQRIgARIgARIgARIgARIgARIgARIgARIgARIgARalAAF8wkKnoL5BPDolQRIgARIgARIgARIgARIgARIgARIgARIgARIgARalAAF8wkK/osvvhCfffZZghDolQRIgARIoNUJ7L333mLGjBmiW7duYsSIEa2Og/knARIggaYkcNRRR4lXX31VzDXXXGLUqFFNmUdmigRIgARIgARIgARIIB6Bdu3aiR49esTzlPHbldn8NWMODJ4ESIAESKAFCCy99NLi7bffFksssYSYNGlSC+SYWSQBEiCB1iPw17/+VYwbN0506NBBfPPNN60HgDkmARIgARIgARIgARKoBAEK5itRTEwkCZAACZBAGgQomE+DIsMgARIggXIToGC+3OXD1JEACZAACZAACZAACfxBgIJ51gQSIAESIIGWIUDBfMsUNTNKAiTQwgQomG/hwmfWSYAESIAESIAESKBCBCiYr1BhMakkQAIkQALJCFAwn4wffZMACZBAFQhQMF+FUmIaSYAESIAESIAESIAEKJhnHSABEiABEmgZAhTMt0xRM6MkQAItTICC+RYufGadBEiABEiABEiABCpEgIL5ChUWk0oCJEACJJCMAAXzyfjRNwmQAAlUgQAF81UoJaaRBEiABEiABEiABEiAgnnWARIgARIggZYhQMF8yxQ1M0oCJNDCBCiYb+HCZ9ZJgARIgARIgARIoEIEKJivUGExqSRAAiRAAskIUDCfjB99kwAJkEAVCFAwX4VSYhpJgARIgARIgARIgAQomGcdIAESIAESaBkCFMy3TFEzoyRAAi1MgIL5Fi58Zp0ESIAESIAESIAEKkSAgvkKFRaTSgIkQAIkkIwABfPJ+NE3CZAACVSBAAXzVSglppEESIAESIAESIAESICCedYBEiCBpiLw888/i3vuuUcsu+yyAkJYOhLQCVAwr9Mo/vyDDz4Q06ZNE5988on4/vvvxQILLCB69OghFllkETHnnHPmlsDp06cLpAV/v/zyi+jZs6dYaqmlxHzzzZdbGhhRMAHWlWA+fFpPgIL5eh68IoEiCZSl/S6SQRZxk2s41TKPC3/77TfxyiuviDfffFNMnTq17m/XXXcVgwYNCs8g3yABEmgKAhTMN0UxMhMkQAIQpg0fPlycfvrpAkK2Y445Rpx55pkEQwJ1BCiYr8NRyMWUKVPEsGHDxJgxY9zBiC0Rf/nLX8Ree+0ltt9+ezH33HPbXmv4/pNPPiluvPFGcdddd4kZM2YYw+nUqZNYbbXVxBFHHCE22GAD4zu8mR0B1pXs2DZ7yBTMN3sJM39xCUyaNEk8+OCDcb2576+00krutzCO57K033HSXIV3yTVaKZV1XPjiiy+KBx54QDz++ONi/Pjx4ptvvvFlaNZZZ3XHsUOHDvU9u/TSS333otxo3769gLCfjgRIoKQEHDoSIAESqDAB2fFyrrzySkdquDqyma39ScF8hXPFpGdF4M9//rNbR5ZYYomsomC4FgJffPGFIwXtTps2bWq/U/03azvv3Lmzc+edd1pCjX/7rbfecrbccstYaUDapGDCuf322+NHSB+xCbCuxEZGDx4CcmLP/Y136NDB84SXJNCaBOREdOzvnvoux+lTl6X9brZSJtdoJVrGceGPP/7oXHvttc6qq65q/Q327t3bOf74453Ro0c7X331lTWz6jcZ9yhXo1rD5AMSIIHiCYjik8AUkAAJkEB8Auh4SQ15Z/HFFzd2cuIMIuLHTh9VJUDBfDElJ7XTHQwK4g4k9PePO+64xIm/7rrrnNlnnz1ROg499FDn119/TZwWBmAmwLpi5sK78QhQMB+PF99ufgJhgnlMmnfv3t39VuN7r
f4WXXRRR5qIjASoLO13pMRW6CVyDS+sMo4LkaZzzjnH6dKli7HfKVeDOnvvvbfz7LPPhmfw/9/A71H9NvVj27ZtjXGofjTepSMBEigvAQrmy1s2TBkJkICBAARi0DpYcsklAzsgFMwb4PGWQ8F8/pXgvvvuc8IGDGrgEHa86qqrGs7AKaecEthmhMWtP99ss80cDLjo0iXAupIuz1YOjYL5Vi595t1EwCaYn2222dx+ddIJ57K036a8V/keuQaXXlnHhU888YTTq1cvY78TCiJHHnlkoGZ8cK79T6W9emfs2LG+FeSq70rBvJ8Z75BAmQjQxrxsrehIgATKT+D333937UEPGTJETJw4MTTBUjBPG/OhlFrvBdqYz7fMn3/+eQFb8T/88ENdxLgnByxCTrAJ2NLExlfYAOupp56CwkDdu/qFFPC7v385wNBvh57DJueBBx4Y+l6cF+SSY3dPizh++K6dAOuKnQ2fxCdAG/PxmdFHcxO46aabxIABA3yZ3G233YRUePHdj3OjLO13nDRX4V1ytZdSWceF0myNOOyww8QVV1xh7M9utNFG4sILLxQYj2ThHnroIbHhhhv6gka/edq0ab77vEECJFASAmWaJWBaSIAESMBEYNSoUY7swNS0DqDd065du9q1bE5959SYN5HkPWrM51cHZs6cWfe7xe+0f//+gUt25WZYPj/e3zdscMZxcnMto137ddZZx5GbujpSaO9cf/31zqBBg5ytt97aah7Lmw45oeA8/PDDcZLCdy0EWFcsYHi7YQLUmG8YHT02KQGbxvy5556bKMdlab8TZaKEnsnVXihlHRd+9NFHzmqrreYbk6L/iD4jzNpk7aA5b1qlSo35rMkzfBJIRoCmbJLxo28SIIEcCPTt29e1C73ddts5UuPHXfonNSWcww8/3Nj5QQeIgvkcCqaCUVAwn1+hDR06tPb7xGTaNddcEylybJIFAblXEK6uF1pooUjh4CUMULChlvKLIyb5sDQ8yGHQ161btzp/ehjqvF+/fkHB8FlEAqwrEUHxtcgEKJiPjIovtggBm2D+8ssvT0SgDO13ogyU1DO52gumjONCuerTgf131T/Uj/PMM48zZswYe4ZSftKzZ09fOiiYTxkygyOBlAlQMJ8yUAZHAiSQPoEHH3zQgRaC133++efOLLPM4ut8oDNEwbyXFq9BgIL5fOrBN99843Tu3Nn9bcKW5i233BIr4k8++cSZf/75jb9t/L6//vrrSOGNGDGiLoxNN93U+fnnnyP5/eqrr9xNufTBlekcgzG6xgmwrjTOjj7tBCiYt7Phk9YkkIVgviztd7OVKLkGl2jZxoVYPdmhQ4e6/qbqL84333zOa6+9FpyhlJ8uscQSvrRQMJ8yZAZHAikToGA+ZaAMjgRIIF8C6PCozo9+pGA+33KoSmwUzOdTUsOGDav9LrFZcyPuuuuuq4Wh/7Zx/tJLL0UKElpVyi8GKl988UUkf/pLWKmjwjAdsXKHrnECrCuNs6NPOwEK5u1s+KQ1CWQhmC9L+91sJUqujZdo3uPC119/3YFGvKl/CLOr48aNazwzDfqkYL5BcPRGAgUSoGC+QPiMmgRIIDmBRRZZxNgZomA+OdtmDIGC+XxKVZmPgUmaRt3HH39s/G1j8PP000+HBgsNJTVQmmuuuZxGNduxWqdTp061sFSY6rjBBhuEpoUv2AmwrtjZ8EnjBCiYb5wdfTYngSwE82Vov5uxtMi18VLNc1yI1Z2LLbaYsX+IFd233XZb4xlJ4JOC+QTw6JUECiJAwXxB4BktCZBAOgTy7IClk2KGUiQBCuazp//yyy+7gxSYsjGZoIqTggUXXNA44Imi+f6vf/2r5vekk06KE63vXa9JHCWUx3HJJZf0vc8b0QiwrkTjxLfiE6BgPj4z+mhuAmkL5svSfjdbqZFrshLNa1yI/ZDWWGONWj9T7xfi/KCDDkqWkQS+KZhPAI9eSaAgArMgXtl40JEACZBAJQlIm3lixowZvrRLjXlx5pln+u7zRmsTkBt/irffflvITquYNGlSa8PIKPfS/ru4+eabhZwEEWuvvXaiWBDGxIkT68Lo3r27eO+99+rumS6kYE6MHz9edOnSRUyePFlI+5+m1yLdk3bpxdxzzy1+/fVX3/tzzDGH+Omnn3z3eSOcAOtKOCO+0RiBv/71r0KaEHB/99Jec2OB0BcJNBGBm266SQwYMMCXI7n5q9h3331998NulKX9Dktn1Z6Ta7ISy2tcOHDgQHHaaacZE7vwwguLN954I1G/0xhwxJtSYUS8++67dW+Dy7Rp0+ru8YIESKBEBAqaEGC0JEACJJAKgbw0I1JJLAMpnAA15gsvgsgJmDlzpjPbbLP5tJE222yz0DDkwNZp06aN63fw4MGh70d5YbnllvOlRXbnHJjJydL98ssvzp577un06tXLufjii7OMqi7sL7/80tlpp53ceEeNGlX3rGwXrCtlK5Hi00ON+eLLgCkoF4G0NebTyl2S9jssDa38/cySaxj3Ip/nMS6UQm+nbdu2xj4h+oV33XVXkQgcaswXip+Rk0BDBGjKpiFs9EQCJFAWAnl0wMqSV6YjOQEK5pMzzCuEJ5980jjoeeCBB0KTILXsnUUXXdQdOCU1p6Mig5AaAy7v37LLLqteyeR477331sV5yimnZBKPHuhnn33mKBu3yO/iiy+uPy7dOetK6Yqk8ARRMF94ETABJSNQVsF8kvY7DHErfz+z5BrGvcjneYwLsX+Sty+orldfffUis+/GTcF84UXABJBAbAI0ZSNbUbrqEZgyZYoYM2aMa5Liww8/FNLesOjYsaOYf/75Rc+ePcU666wjVl11VSE1JlPNHExgPPjgg2L69OlCCnuE3PTFjQOmFfAnOwNCbgToHtOMWG5iKGSHWjz33HPiiiuuEFLg5Asez+68885a2rp27eoy6NOnj1h55ZWF3Bne50fdwLLJuA5mHYL4wvTDDz/8EBqs3Mk+9J2gF/JashiUBj6rDgGasqlOWWGJMJYK627bbbcVcjMt/VbgOdqgOeecM/CdqA9POOEEMXToUN/rUoNf3HPPPb77ad145plnhBzoQZGiFuSQIUPEoEGDatdpnnz++edi/fXXF9LObS3YFVZYQcjNc2vXZTthXSlbiRSfHpqyKb4MmIJyEUjblE1auUuj/balpZW/n1lytfEuw/2sx4UPPfSQ2HDDDa1ZxXhdKnJYn+fxgKZs8qDMOEggZQKxRfn0QAIFEcByxKuuusqxmROQP4262etOnTo5xx13nPPBBx8kSjGW85944onuBn/eOEzX2IUdmlqXXHKJ89133zUcN5bJyU6Vs/zyy9fl6/nnn68L87HHHnNWWWWVune86ZKTFc6LL75Y509dSJvJgX69Yanr/v37qyCMx4suuihSuCijJC4PzYgk6aPfchGgxny5ysOWGmkT3pGTf3VtiBQOO99++63NS+b3995777r0qLYwjw2+LrjgAl/c2Nw2befVlEcepY1+Rwrl044qtfBYV1JD2VQBUWO+qYqTmUmBQBk1
5vNov1vx+5kH1xSqZCZBZD0u3HTTTX39MdUf7NatmyMV0zLJV5xAqTEfhxbfJYFyEKApm3KUA1MRQuDuu+92l9KrD59+nG+++ZwFF1zQaIsY77Vv39659tprQ2LwP8ZEwDnnnOMgfD0+dS43EnTkrLw1XrwnZ6ydJ554wh+44c7vv//uTJgwwTnppJNcm74qHu9RCeaxG/zRRx/tzDrrrMb0ef1JrVHnP//5jy/m3377zdl5552N9ui8YeAaNp8x4JXao76w9Btjx4510HnxCtf0MKUmvzNy5EjdW+zzrDtgsRNED6UmQMF8qYvHTRzawnXXXbeuXevbt68DoXGRzpsm1ZZdf/31uSTr3HPPrWOC+E8++eTU4q6iUJ51JbXib7qAKJhvuiJlhhISKJtgPs/2u5W+n3lyTVglM/Ge5bhQrtIPHPfrY2OMr+XKQ1dRD2P7ww47zDnyyCOdM88807nmmmucSZMmZZJ/BErBfGZoGTAJZEaAgvnM0DLgNAj89NNP7odMCUDUccUVV3SGDx/uQJtdOWlWxtl99919ggvl59RTT1Wvhh6xceBGG23kC0uaq3E/sPjQ4oMLJ00lOFdffbUjTWT43kfcEJxDgA5Bus29+eabzkILLWT0r9KvjhDMI6/9+vWL9L7yp9Iid4k3JgMTEQcccEBgmJiMQFrjOJSRNC3kC1eaYogTjPXdLDtg1kj5oLIEKJgvf9GdccYZde0FbLsnWX2UVo5NbY00xeVgg7W83Nlnn13HBu364BQ2t62iUB7MWVfyqnnVi4eC+eqVGVOcLYGyCebzbr9b5fuZN9dsa2380E19NfSVjjnmmPiBeXxAYU8fV3vPH330UXdlJ1ZpLLbYYoHvYoX92muv7Qrpv//+e09MyS4pmE/Gj75JoAgCFMwXQZ1xRiKApWCbbLKJ76O2yy67uMJwUyAQLq+00ko+P/hwQkB+//33m7zV3YPpG5O5HGh3v//++3Xv6heYRAiaGEBelDBf94dzmAmYa665HAgNpa18Y/rVx//mm2/2mdXBJofYiAaTCWECfvALcqYJCRU30teIGzFiRF2e1lhjDQcaHWm4LDtgaaSPYZSLAAXz5SoPPTVoHzFwUu0NjksttZSDidKi3ccff2xcnYTJzLzdWWedVccInKCN1airolCedaXR0m4dfxTMt05ZM6fRCJRFMF9k+93M388iuUargfm8leW4sFevXr7+l+qzzjHHHM6wYcMcuceb9R31rvcIxbeLL744NUAUzKeGkgGRQG4EKJjPDTUjikMAQtsBAwbUfdgws3z66aeHBnPppZfW+dM/ftC0D3IQ7K+55po+/3KTl8gam4cccojPv0oDNOdtDh0qOKRhu+22s4ahwsIRpnJGjx7tE3KjYzD77LMbw4ApGrmJrS0ZrkY83tHjUee4/+mnn1r92h7ARr8KA8fx48fbXo19P8sOWOzE0EPpCVAwX84iev311x25wWldO6HaDGily01XrROyeeTIpiWlTIvlkQY9DiyFVnzUUW4Gq78S6byKQnnWlUhF2/IvUTDf8lWAADwEyiCYL0P73YzfzzJw9VS3wi6zGhdOnz7d1+9S/a+0jnFW9wcBpmA+iA6fkUA5CVAwX85yaflUnXLKKb6P3xFHHBGJCzQbIcQ3fSShNa+bv/EGiDi8/uAnSJDtDQOCdWiEe8NR1yY7794wYCpHvW877rXXXk7Q0rcrr7zSGkbYpoFBmv/Y1Dau083u4DxNl1UHLM00MqzyEKBgvjxlgZS89tprzq677uq0adPG2l6pNhCrod57771CMrDMMsv40ocJ2yIdJisUG3XEJGhUVzWhPOtK1JLleyBAwTzrAQnUEyhSMF+29rtZvp9l41pf44q5ympcOGrUKF+fS/W99CNWvu+///7OhRde6GCPvDFjxjj33HOPuycQFOr0d03nRx11VGJwFMwnRsgASCB3AhTM546cEYYRmDJlioONSvWPFTYQhcA9qtMFwXo4OLdtxgphuEmgD/vGcd0777xjFTTBZA0EImEuSFAVZeUATOvYzNpA8B7kYEceExJedriGSZ84DlocejgwxZOmy6oDlmYaGVZ5CFAwX3xZYELxuuuuczd4NbW5envhPe/WrZvz4osv5pqJJ598sq4NQ5rQjr/77ru5psMUGb4FXkb65mMmP7hXFaE864qtBHk/jAAF82GE+LzVCOQtmC97+13V72fZuRb9u8pqXDhw4EBff0vvf0Egj/3TghQAwebyyy93INfQ/XrPIdRP4iiYT0KPfkmgGAIUzBfDnbEGEICtdO8HKu6GLQ888IAvDBUmNk41uR122MHnB0Ij2H9vxO2xxx6+8FQaYOMwzHXu3NnoHx2DqO6ggw4yhoHNZsKcqRxU+idMmBDmvfYcO9Arf5gowIqCNF1WHbA008iwykOAgvniy8I2YajaibAjhPO2djyL3O255561NkylDZvIlcWddtppvvQdf/zx1uRVRSiPDLCuWIuRD0IIUDAfAoiPW45A3oL5KrTfVfx+VoFrkT+urMaF/fv39/W1VJ+wXbt2zlNPPRU52y+99FLgnnLdu3d3oGDXqKNgvlFy9EcCxRGgYL449ozZQAAaiOojp47QTGzErvnmm2/uC8smrJg4caJRQxwftkbdW2+95Ytf5Qk7tSub8rbwbR0LzLRHdSY7ikhDjx49QoOAHXiVXu9xn332CfWPFyCEX2CBBWrhRNHkjBSw9pKNU9zJHC1InjYxAQrmiy9c7LWxzTbbOFtuuaWrNd+7d2+nffv2tXbC296YrrfYYotcMoLBk3fPjVVWWcX59ddfc4k/aiQm82/HHXecz3uVhPJIPOuKrwh5IyIBCuYjguJrLUMgb8F8Vdrvqn0/q8K1qB9WVuNCmFM09Udxb999942dXZi4sYWH+1dccUXsMJUHCuYVCR5JoDoEKJivTlm1REqxBMz7kcKGgI24H3/80cGGfdCEP/DAA91NUm3hQPvRGy+uN9hgA5uXSPchcDKFi3uwOxfkbB2LOIL5kSNHGuOHoCmK69u3r9E/hGjffvttaBB33nlnzT9WH0yePDnUT9wXbJwomI9LsjXep2C+nOWMDb8hBD/88MOdtm3b1toNW/uJ+3FW7jSSa6TJu18Ilh/D7FkZnen7eeyxx9aSahPKlzU/tYR7TlhXPEB4aSRAwbwRC2+2MIG8BfMm1GVsv5HOqn8/y8rVVAeyvpfVuBBKdbY+6UMPPdRQttZaay1rmBCuN6oEQsF8Q8VBTyRQKAEK5gvFz8i9BP70pz/5PlDYGDBrt9FGG/nixce3kRlwPa0nn3yyMVyEjY1hgpytYxFHMH/vvfda4w+KWz279tprrf6jpAMasaoTA8ZZOBsnCuazoF39MCmYL38ZTp061d0QNsz+/IABAzLNzFVXXVVrv9COYd+Nu+66K9M4kwaOjb1Vm6uOaAubRSjv5cO64iXCa0WAgnlFgkcS+INAGQTzelmUpf1WaWqW72fZuCq+eR2zGhfOO++8vv6V6mehj9WIu+WWW6xhImxo1TfiKJhvhBr9kEC
xBCiYL5Y/Y9cIwGaw+sDpRwi3s3Sw4ebdbFbFP3To0ERRP/LII8Y8IfyNN944MGxbxyKKQFwFPHr0aGv86p2gI1YddOnSxRhGnz59grw6H374Yd0GuLfddlvg+40+tHGiYL5Ros3tj4L56pTvpZdeamx7VPs8zzzzNKxNFEbh888/97V9F1xwQZi3UjwfPHhwIDfwQ7teNU35ILisK0F0WvMZBfOtWe7MtZ1A2QTzKqVFtt8qDerYTN/PMnFVfPM4ZjUuhHKG6n/qxznmmKPhbMHk68ILL2wMF3E02u+kYL7hIqFHEiiMAAXzhaFnxF4CNiE2zLFk6aZNm2b9IF522WWJon7nnXesYWN1QJCzdSzyFMwjfTCFoHdA9PMXXnjBmgVscKve7dq1q/Pzzz9b303ywMaJgvkkVJvXLwXz1Spb2z4Zqm157rnnMsmQdzNwbKRdJXfSSSfV2l/FSh2bTSivyoV1RZHgEQQomGc9IIF6AmUVzCOVRbXf9YT+uGqm72eZuJpYZ3Evq3Ghzcxip06dEmXDtAGx6q8dcsghDYVNwXxD2OiJBAolQMF8ofgZuU7goosuMgoSGrXbpocddA47xeoD6D3CRn0S98MPP1jDxgceNgFtztaxyFswP2XKFOPGuGAVZI5nmWWWqeXdtAmhLd9x79s4UTAfl2RrvE/BfPXKeccdd6y1Jd42esSIEalnyGtrFhuJN2rnM/XExQhwr732MnJrJk15Lw7WFS+R1r2mYL51y545NxMos2AeKc67/TZT+uNuM30/y8Q1iHlaz7IaF3bv3t3Yp8LqzSTu4YcfNoaL/i76n404CuYboUY/JFAsAQrmi+XP2DUCRx11lPHDhA1Es3RB5l6gNZHU2UzB4IP70UcfWYO3dSzyFswjgVtssYWxbDp2/D/2zgT+h2r//8dWsoWQVApxSyFKqxLSoktJC+2pRIsUJVvZha603VaE+hcpuihJJEWSLJUiS9nSYv1aS+Z/3nN/87mznJnP7J8z83mdx4OZOTNne77P98x8XnPmfcoou3fvtrRhwYIFmevJT/SaNWss14QVYccJwnxYhNOVD4T55Nlz+fLlip2/eVq4O8xALrf0ZZG4t2fPnjCLiCWvTZs2KTVr1syMw/oXGmkeG9FXYuleiSgEwnwizIRKxkhAdmE+zvHbCXva7p+ycHViHua5qH4X1q1bV/hMVbRo0UDVX7lypTBfem6rXbu2r7whzPvChkQgkFMCEOZzih+F6wl06tRJeGMaNWqU/rLQ999++21huXRDvOeeewKXV79+fdv8f//9d9v87R4sciHMz5gxw7YNtECiOdx1112Z65s3b24+HeqxHac0i0+hAsyzzCDMJ9Pgdi8He/ToEVqDyDVXiRIlMmNXw4YNlZ07d4aWf1wZbd68WREtpK4X56P8iimudtqVg75iRya/4iHM55e90drsBGQX5qkFcYzfTqTSev/MNVcn5mGfi+p3YZMmTTLPh/rnKdqn9er8BprgZs5POy5ZsqSvbCHM+8KGRCCQUwIQ5nOKH4XrCdx9993CG1OXLl30l4W+P2/ePGG5dFOkm3DQYPcwVKRIEeXvv/+2zd7uwSIXwjy53LGbfXnWWWcZ2kCzS2kmvfZQQS8+ogx2nCDMR0k9uXlDmE+m7WitEW1M0W979uwZSoNosWr9WFKnTh2FFoBNWiBRQevjGif6/Lpt27YWfmkV59FXktZro6kvhPlouCLX5BJIgjAf9fjtZL003z9zydWJeRTn9M9y2nMQbYP+LmzXrp3lOUrLn54hgwTyU6/lpd+WLl3aV7YQ5n1hQyIQyCkBCPM5xY/C9QTMfn21G9PZZ5+tvyz0/Q0bNghvhlR+pUqVApdn59uPFkR1CnYPFrkQ5qmeI0aMsOW0dOnSTFPGjh2buY74RbXoq1agHaegD2Ba/timi4AmWtJDK0JyCCxcuDAzrmj3BtoOGTIkcCO2bdum6D9Rpj7i5GYscIERZUA/DE8++WQDJxLlaRFyCn369DGcI35pFOfRVyLqYAnLFsJ8wgyG6kZOIAnCfJTjtxPgtN8/c8XViXlU56L6XfjUU09ZnqG059FPP/00UHNOO+00Yd7VqlXzlS+EeV/YkAgEckoAwnxO8aNwPYHx48cLb0rFihVTgr6J1pdj3qdZ64cddpiwbLrh/vbbb+Ykno4bN24szLtBgwaO+dg9WORKmN++fbvBzYP2MEJbvcsffXu7d+/u2MYwTtpxgjAfBt305QFhPpk2JfFcP+Zo+xMnTgzUoIKCAuWcc87J5H3iiScq9LI2aYFeJOgX3CY+elFea49InI9jnNbKj2OLvhIHZfnLgDAvv41Qw3gJJEGYj2r8diKdD/fPXHB1Yh7luah+F9q93KDnLZFbVy9tJNeJ2nOtfkvxfgKEeT/UkAYEcksAwnxu+aN0HYHPPvtMeFOiG9T999+vu9L/Lj2YiHwGn3TSSbZlT5gwwX+BPKVd3g8++KBjvnYPFrkS5qmyd955p5ATrUhPLmxWr16dWTiRFlCk46iDHScI81GTT2b+EOaTabddu3YJxx5tNrifVu3fv19p2rRpJt8qVapEulC1nzq6SeNWVNDyEonzaRov0Vc0S+f3FsJ8ftsfrbcSSIIwH8X4bSXxv5h8uX/GzfV/hOPfi+p3IX0BfsQRR2SeGfUCerdu3QI19IQTThDm26pVK1/5Qpj3hQ2JQCCnBCDM5xQ/CtcT2Lhxo/CmRDc+mtH+888/6y/3vE+Lq5x66qlKxYoVlfXr1xvSt27d2rbsW265xXCt1wO7m/gHH3zgmJXdg0UuhfklS5bYcho9erTSu3fvzPlmzZo5ti+sk3ac0iQ0hcUK+SgZ/9twZZOs3rBgwYLM2KL9GCKfnLT+hZ/w119/KfSDR8uL7gvff/+9n6wyaehH2/LlyzPHcez8+uuvSu3atTPtoPaIZsqb66IfqzUGDz/8sPmyRB6jryTSbKFXGsJ86EiRYcIJJEGYD3v8djJZPt0/4+TqxDyOc1H+LtR/Fa49O9H2zDPPDNS0ww8/3PAcp+U9dOhQX/lCmPeFDYlAIKcEIMznFD8K1xMggYVWH9duRubtHXfcob/c8/4NN9yg5n3MMceos7v1GUybNs22XPKT7rRIqz4f8/4ff/whzJfEepph7hTsHixyKcxTfc877zxhm2gRWH2dg35p4MRGf05fpr7PQJjCaw2mAABAAElEQVTXU8K+RgAz5jUSydq+8sorlnHn3nvv9dUIGs+1+wGNGSTw69fJ8JUpT0T+2im/559/3m8WntKRqEAvm/XjnhtRXitEJM4HnfWl5Z3LLfpKLunLUzaEeXlsgZrIQSAJwnyY47cT9Xy7f8bF1Yl5XOei/F04atQowzOX9vxFX4n7dX27detWYZ6U95dffukLG4R5X9iQCARySgDCfE7xo3AzgY4dO9renIoWLaosXrzYnMTV8TPPPJPJ94UXXrCkIaGGfAtrN1jzdvr06ZY0biImT54szLNz585Zk9s9WORamH/jjTeEbdIzo9mnBw4cyNrGMC6w4wRhPg
y66csDwnwybUruzPRjDO37FdP195lSpUopX3zxRWAoc+fOVQoXLqy68griXsdtRegHoHmxMC+ivFaOSJzv2rWrdjqRW/SVRJot9EpDmA8dKTJMOIEkCPNhjt925srH+2ccXO14xx0f5e9CmlRH7lvNz6N0/Nprr/lq6jfffCPM76ijjlIOHjzoK08I876wIREI5JQAhPmc4kfhZgIrVqwQ3py0GyDdpLy6ChgxYkTG73mtWrUUcmEgCk888YRt2RdffLEoSda45s2bW/IktzxuFhe0e7DwIszTCwWNnXmbtfI2F5DgTl8RmPPTH8fpEsGOE4R5GwPmeTSE+eR1AJrZZv4hdMUVV/hqCI0L2lhFXy598sknvvLRJyIXOPQlFuUbhwsvEhXq1KmTaQeV60eU19ogEucfeugh7XSitugriTJXpJWFMB8pXmSeQAKyC/Nhjt925snH+2cYXOkL8IkTJyoDBgxQ6HfoypUr7RDnPD7q34X0tab2HKnfNmjQwFfb7fQHWg/Ib4Aw75cc0oFA7ghAmM8de5RsQ+CSSy4R3vC0m1/ZsmUV+pQsW6CF/bp06WLIa9asWbbJfv/9d9WlgVaOebts2TLbtKIT9NBCn7aZ83EzW57yq1y5siUt5eVFmLd7CKd8goRevXoJ60b5UptXrVoVJHtPac2CncYbwrwnjHlzMYT55Jma3Jhpf9e0pVnuftYcGTRoUCYfekGabZ2PbKToJS+57Dr66KMz+dIP1ygD3afCFOW1uorE+WwLlGtpZdqir8hkjdzWBcJ8bvmjdPkI2P0m8PK7IspWhTV+29UxX++fQbnSJC+aGKd/DqN9Wt/G629jO9uEGR/170L6fW/nE/7jjz/21BRy43vSSSdZ2NLEEXqJ5DdAmPdLDulAIHcEgqlzuas3Sk4xAadZ3vqHgnPPPVchn3nkm00LdIPbtGmTQm+ftRmMWppbb71Vu8x2O2PGDNUdgZZGvyXf6l58zT/wwAOWG23dunWVffv22ZavP2F30/fyAE2+jvVt0O8XFBToi/O0T4vnFilSRJh3kyZNPOUV5GJ6aNG3Sb9PL2UQQMBMAMK8mUg0x3PmzFFuvPFGpVq1agp9OdSvXz+FfhR7DYsWLbKMyW5ezJrLefbZZy1jBY1Vfv7R4l8kjpNvev2YQ18S0QKwUQWasUb3EH2ZQWbKm+spEufjGEfRV8yWwHEYBCDMh0EReaSJQBTCvGzjt529knb/lIUrzbY3P+von0HKlCmjxOG+z86u5vi4fhfqJ3roedCzodvf+VR3srM+vbZPz2NBAoT5IPSQFgRyQwDCfG64o9QsBO655x7hjUq7YZm39NBAPuKLFy8uTEdv9d0K0UOGDBHmQWWSuOQmjB8/3iJc0yzPH374wU1yZceOHbZ1oM8I3QbzFwN6bmvWrHGbjfC61q1bC+tID/5xhc8++0xYB2rn1VdfHVc1UE6CCECYj95Yu3fvtrieob9JGgPp09xdu3a5qgS5mTG/YB06dKirtPqLxo4dK/x6ST8ehrHfvXt3fbGh7pOoUK9ePcN4F6Yor1U2bnEefUUjj23YBCDMh00U+SWdQNjCvGzjt519knb/lImrm3XFaLKCLCGu34X0xSS5rhE9O953332ucFAel156qSUPmnho53bXVcb8IgjzbknhOhCQhwCEeXlsgZroCNDMdDvhV3QTdIoj8WL16tW63LPvXnfddZYbpVYG+U+nmfl2gRaXNbuwKV++vPL555/bJbHEv/zyy7blk6sfN4Fmbjr5gn/qqafcZGN7DbkF0pho2woVKsS26Cv1EZEPf60u5GPQrQBo20icSB0BCPPRm/Stt96yjA3a3yVtSWx//fXXbcdRGl8HDx5seLlJrmeefPJJz5V/9913Dfno6xHmPo35Xu8zbhtDL5XNPwCjEOW1+ojEeXJfFkVAX4mCKvIkAhDm0Q9AwEggbGFepvHb2NL/HSXx/ikT1759+zo+z2nPUbRGXK5D3L8Lqc16V4YaC9qSK0Cnr+xpvbarrrrKwpaej9etWxcYJYT5wAiRAQjETgDCfOzIUaBbAvQp2Pnnn2+5aelvfNn2aYahn0/s6GZK7nBIDBKVUb9+fXURnL1792aas3TpUqVDhw6W66tWrap4eWB5//33LbNEzXXo1q2bQjNA7AJ9ykeLI5rT6Y/JX2BQf8gnn3yyoQyqVxxh7dq1yp133mkoW982bb9FixYK2QUBBDQCEOY1EtFtySWY9jfotCVb0BdAU6ZMURf1pplOw4cPV04//XRDenLf4ufveObMmbZjuFO9/Jyjl4RRhUmTJhl4RCnKa20wi/MVK1bUToW6RV8JFScy0xGAMK+DgV0Q4ATCFuZlGb+djJvE+6dMXGkShZtnIvoyMZchV78LSWM44YQThIyaNm2qzJ0714CFJp588803yuWXX25JQ8/EYYjyVCCEeQN2HIBAIghAmE+EmfK3kjTjmWbqlSxZ0nIDc3pQIJcJI0aMUA4ePBgIHi1qY/bpqy+XZknS2/ISJUpY6lesWDGla9euys6dO7PWgdp52WWXKccdd5wlH315+n3yQU9169ixYyZ/ejCiuKJFi7rOh2bVX3zxxb78P5OIptWpcOHCSlD3OJmGmHYWL16stGvXTn3ZcNppp1m+SNDqYLel2fONGjVS0992222m3HGYTwQgzEdv7T179gh/FNj9fdrF04JYI0eO9PUVDi3OJRqX7coKGv/2229HBpZm4mtretAPQD8vm/1UTi/Ok8gZRUBfiYIq8iQCEObRD0DASCBsYV6G8dvYQutREu+fMnGliV7a84fTc9LAgQOt8COMkel3Ia1t16pVq8zvYTOnWrVqKVdeeaVCk8XoC3rzeTpu06aN44Q7ryghzHslhutBIPcEIMzn3gaogQsCv/zyi0J+50nsFt3QtDgSh6+55hpl48aNLnJ1dwl9bkaLDTZr1syyCKFWrn5buXJlta7ff/+9uwL4VbQooj4PL/vVq1fPlEP+l72k1V/r5y09PbBpfv3J9VBUYf78+b7bpW8j7VN9EfKXAIT5eGxP4wm9bDS79TL/PZqPS5curf7AoUXAnVyGZWsFzZY35x3VMb2cjXLRV2rr119/rdCLV1qILc5AXzHQjDla9ySqgL4SFdn8zhfCfH7bH623EghbmKcScj1+W1tpjUni/VMmruS6lFyVOj1DPfPMM1bwEcbI+Lvwww8/VM444wxHTmaG5E/ePKs+DGwQ5sOgiDxAIF4Chag4PkgggEAiCHABm/FPwBifDan+42+pGRdyGBfDGZ9JzfgiKoy7aImsLfwFAeM3UMaFf7Zhwwa2ZcsWxmeuM+5aQP3HXecw7n6H8RcEkdVBtoz5LAn23HPPMb6yPDvllFNkqx7qAwIGAtz9kjp28IdWxmdSGc7hIHwC/CsaNmbMGLZq1Sq2efNm9d/WrVsZ/6qJcfco6j/+1Q7jP2bYhRdeyLibMMZnZ4VfEeQoPQH0FelNlKgK0ngyb9489RmRf5WYqLqjsiAQBQHuu5zxr08tWb/00kuMu+K0xHuJwPjth
Zb7a2Xhyn31s+XLl6vPz9zNK+NfJLKbbrqJcbezamPotzGNuQiM8ZcqbPLkyYxPbFCfeUk74JMbWLly5Rj3Ic+qVKnCLrroIsZn0TPu7jYSZPyLU0Z9Rx+orJ9//lkfhX0QAAGJCECYl8gYqAoIgAAIgEC0BCDMR8sXuYMACICADAQgzMtgBdRBJgJRCvMytRN1iZ4A/6KRlS1blpFgT2H79u3qcfQlowQ3BCDMu6GEa0BALgIQ5uWyB2oDAiAAAiAQIQEI8xHCRdYgAAIgIAkBCPOSGALVkIYAhHlpTJH4ikybNo21bNlSbQd3G8Y+/fTTxLcpTQ2AMJ8ma6It+UIAwny+WBrtBAEQAAEQYBDm0QlAAARAIP0EIMyn38ZooTcCEOa98cLVYgJ87TXVfazmDpL7e2fcV7r4YsTmhACE+ZxgR6EgEIgAhPlA+JAYBEAABEAgSQQgzCfJWqgrCIAACPgjAGHeHzekSi8BO2H+xRdfZHfffXd6G46WhUaAXNjccMMNbMKECWqebdq0YZMmTQotf2QUDgEI8+FwRC4gECcBCPNx0kZZIAACIAACOSUAYT6n+FE4CIAACMRCAMJ8LJhRSIII2Anzffv2ZY8//niCWoKq5oIAifK0SPCoUaPU4mvUqKEusE0LmiLIReDYY49VF57V1wqLv+ppYB8E5CMAYV4+m6BGIAACIAACERGAMB8RWGQLAiAAAhIRgDAvkTFQFSkI2AnzzZo1Y7NmzZKijqiEnAS2bt2qzpSfOXOmWsHjjz9eFeVPOOEEOSucx7XauXMnK1euHFMUxUABwrwBBw5AQDoCEOalMwkqBAIgAAIgEBUBCPNRkUW+IAACICAPAQjz8tgCNZGDgJ0wT7UbPnw4e+ihh1jhwoXlqCxqIQ2BtWvXsiZNmrD169erdapVqxajxV9r1qwpTR1Rkf8S2Lt3L+vUqRMbN26cBQmEeQsSRICAVAQgzEtlDlQGBEAABEAgSgIQ5qOki7xBAARAQA4CEOblsANqIQ8BJ2GealmhQgW2f/9+ds4551gq3blzZ9ayZUtLPCLST2DFihXqYq9FihRhDz/8MHvsscdY8eLF099wSVvYvHlzYc0KCgrYypUr2Y4dO4TnIcwLsSASBKQhAGFeGlOgIiAAAiAAAlETgDAfNWHkDwIgAAK5JwBhPvc2QA3kIkALdrZt29ZXpR555BE2dOhQX2mRKPkERo4cyS666CJ2+umnJ78xCW9BoUKFfLUAwrwvbEgEArERgDAfG2oUBAIgAAIgkGsCEOZzbQGUDwIgAALRE4AwHz1jlJAsAnv27GGDBg1i+/bt81zxpk2bYsa8Z2pIAALhExgzZgxbvny554zJ7zx97YAAAiAgJwEI83LaBbUCARAAARCIgACE+QigIksQAAEQkIwAhHnJDILqgAAIgAAIgAAIgAAICAlAmBdiQSQIgAAIgEAaCUCYT6NV0SYQAAEQMBKAMG/kgSMQAAEQAAEQAAEQAAE5CUCYl9MuqBUIgAAIgEAEBCDMRwAVWYIACICAZAQgzEtmEFQHBEAABEAABEAABEBASADCvBALIkEABEAABNJIAMJ8Gq2KNoEACICAkQCEeSMPHIEACIAACIAACIAACMhJAMK8nHZBrUAABEAABCIgAGE+AqjIEgRAAAQkIwBhXjKDoDogAAIgAAIgAAIgAAJCAhDmhVgQCQIgAAIgkEYCEObTaFW0CQRAAASMBCDMG3ngCARAAARAAARAAARAQE4CEObltAtqBQIgAAIgEAEBCPMRQEWWIAACICAZAQjzkhkE1QEBEAABEAABEAABEBASgDAvxIJIEAABEACBNBKAMJ9Gq6JNIAACIGAkAGHeyANHIAACIAACIAACIAACchKAMC+nXVArEAABEACBCAhAmI8AKrIEARAAAckIQJiXzCCoDgiAAAiAAAiAAAiAgJAAhHkhFkSCAAiAAAikkQCE+TRaFW0CARAAASMBCPNGHjgCARAAARAAARAAARCQkwCEeTntglqBAAiAAAhEQADCfARQkSUIgAAISEYAwrxkBkF1QAAEQAAEQAAEQAAEhAQgzAuxIBIEQAAEQCCNBCDMp9GqaBMIgAAIGAlAmDfywBEIgAAIgAAIgAAIgICcBCDMy2kX1AoEQAAEQCACAm3atGG//vorq1ixIps8eXIEJSBLEAABEACBXBO4/fbb2Y8//siKFy/OZs2alevqRFr+hg0b2JIlS0Ipo2XLlqxQoUKh5IVMQAAEQAAE5CKwadMmtnjx4lAqdcUVV7AiRYqEkhcyAYF8JwBhPt97ANoPAiAAAnlEADPm88jYaCoIgEDeEsinGfNjxoxh7du3D8XWf/75JytWrFgoeSETEAABEAABuQj8v//3/9iNN94YSqUKCgpYqVKlQskLmYBAvhOAMJ/vPQDtBwEQAIE8IgBhPo+MjaaCAAjkLQEI8/5MD2HeHzekAgEQAIEkEIAwnwQroY75SADCfD5aHW0GARAAgTwlAGE+Tw2PZoMACOQVAQjz/swNYd4fN6QCARAAgSQQgDCfBCuhjvlIAMJ8PlodbQYBEACBPCUAYT5PDY9mgwAI5BUBCPP+zA1h3h83pAIBEACBJBCAMJ8EK6GO+UgAwnw+Wh1tBgEQAIE8JQBhPk8Nj2aDAAjkFYF8FubLlCnDtm3b5sveWMjPFzYkAgEQAIFEEFAUhR06dMhzXT/88ENGi73qA3zM62lgHwSCEYAwH4wfUoMACIAACCSIAIT5BBkLVQUBEAABnwTyXZjfuXOnT3JIBgIgAAIgAAJGAjNnzmSXXnqpIRLCvAEHDkAgEAEI84HwITEIgAAIgECSCECYT5K1UFcQAAEQ8EcAwjyEeX89B6lAAARAAATMBCDMm4ngGATCJQBhPlyeyA0EQAAEQEBiAhDmJTYOqgYCIAACIRGAMA9hPqSuhGxAAARAIO8JQJjP+y4AABETgDAfMWBkDwIgAAIgIA8BCPPy2AI1AQEQAIGoCECYhzAfVd9CviAAAiCQbwQgzOebxdHeuAlAmI+bOMoDARAAARDIGQEI8zlDj4JBAARAIDYCEOYhzMfW2VAQCIAACKScAIT5lBsYzcs5AQjzOTcBKgACIAACIBAXAQjzcZFGOSAAAiCQOwIQ5iHM5673oWQQAAEQSBcBCPPpsidaIx8BCPPy2QQ1AgEQAAEQiIgAhPmIwCJbEAABEJCIAIR5CPMSdUdUBQRAAAQSTQDCfKLNh8ongACE+QQYCVUEARAAARAIhwCE+XA4IhcQAAEQkJkAhHkI8zL3T9QNBEAABJJEAMJ8kqyFuiaRAIT5JFoNdQYBEAABEPBFAMK8L2xIBAIgAAKJIgBhPj+F+c2bN7Off/6Z/fbbb2zPnj2sUqVKrGrVquz4449nRxxxRKL6sEyVBdfs1vjzzz/ZtGnTWO3atRk9a8oU/v77b7Z8+XL2/fffs59++snw7+abb2Z9+vSRqbqx1mXZsmVs6tSpbMaMGez6669n999/f6zlJ6UwCPNJsRTqmVQCEOaTajnUGwRAAARAwDMBCPOe
kSEBCIAACCSOAIT5/BHm161bx5577jn2wQcfqMKjXWe94IILWPv27dm1117LSpYsaXcZ4v+PALi66wp//fUXGz16NBs8eDBbv349e+SRR9jQoUPdJY7wqiVLljASUz/99FP22WefsV27dllKK1y4sFrfIUOGWM6lNaKgoIDNnj1bFeKnT5/ONmzYkGmqLLbLVEiiHQjzEhkDVUklAQjzqTQrGgUCIAACICAiAGFeRAVxIAACIJAuAhDm0y/Mb9++nXXr1o2NGzeOHTx40HUHPuqoo9ioUaPYlVde6TpNPl0Iru6sTX3utddeY4MGDVJnoGupcinuHjhwgE2YMEF9UbVo0SKtSoZtvXr1WIsWLRi9qDrvvPPYkUceaTiftoNDhw6xxYsXs48++oh9+OGHbMGCBYxepohCLm0nqo9McRDmZbIG6pJGAhDm02hVtAkEQAAEQEBIAMK8EAsiQQAEQCBVBCDMp1uYJ3Gtbdu26gxlvx330UcfZfk0U9gNJ3DNTokE+fHjx7OBAweytWvXWhLkQtylOj399NPsiSeeYH/88YelTvSFSLt27ViHDh1Yw4YNLefTHHH22WezL7/80lUTc2E7VxWT4CII8xIYAVVINQEI86k2LxoHAiAAAiCgJwBhXk8D+yAAAiCQTgIQ5tMrzNOsV5rtTrODg4ZXX32V3XHHHUGzSUV6cHU2I/lpf+ONN9iAAQPY6tWrbS+OW9ydP38+69Spk+pD3lypYsWKsc6dO6s+5NM+M97cdu2Y2HzzzTfs8MMPV79sEL1M0a6N23ZauUnYQphPgpVQxyQTgDCfZOuh7iAAAiAAAp4IQJj3hAsXgwAIgEAiCUCYT6cwTy4pyAXHvn37DP2S4urWrctOOukkRn6zaZFLWuySZoArimK4Vn9AYt2qVavUBWL18fm2D672FidXKG+++Sbr37+/2lfsr/zvmbjE3f3797MuXbqwl19+WdjHL7nkEnUWPT33IvyXAI0FvXv3VtcDEDGJy3aismWPgzAvu4VQv6QTgDCfdAui/iAAAiAAAq4JQJh3jQoXggAIgEBiCUCYT58wT2J8gwYN2A8//JDpl+Qru2/fvrbuOebNm6e679CnyST+v50ePXrYCnXma9N4DK72Vp0yZQqj/qH1nyJFijCahU6iuF2IQ9z99ddf2VVXXcW++OILSzXoxdSwYcNY165dLecQwdju3btZpUqVLC/3iE0ctkuqDSDMJ9VyqHdSCECYT4qlUE8QAAEQAIHABCDMB0aIDEAABEBAegIQ5tMnzJP/bBJJKZBASgu43nrrrVn7Irm8IX/0JLKKwjHHHMM2b94sOpUXceBqb2byT75kyRLVddI111zDLrvsMlamTBlV9H7qqaeECaMWd8ktS8uWLdnPP/9sKZ/c1bz11ltqPS0nEZEhcMopp2RetmQi+U7UttOXlbR9CPNJsxjqmzQCEOaTZjHUFwRAAARAwDcBCPO+0SEhCIAACCSGAIT5dAnzBQUFrFq1amzr1q3qjGXy9X3ttde67o+///47O/XUUxltRWHnzp2q4Co6l+Y4cHW27qxZs1idOnXY0Ucfbbhw27ZtrEKFCkIXMlGKu7Nnz1ZnypPdzKF8+fLs008/Vfu5+RyOjQTOO+881c2VMRbCvJmH/hjCvJ4G9kEgfAIQ5sNnihxBAARAAAQkJQBhXlLDoFogAAIgECIBCPPpEuaff/55dt9996k9ZOzYseyWW27x3FvGjx9vm27p0qWsXr16nvNMegJw9W/Bo446ipFAbw5RCfMrVqxgJCjTSyRzKF68OPvoo49Yo0aNzKdwLCBw0UUXsblz51rORGU7S0EJjIAwn0CjocqJIgBhPlHmQmVBAARAAASCEIAwH4Qe0oIACIBAMghAmLeKd8mwnLiWp59+Olu2bJk6W3jy5Mnii7LE/vbbb5aZz1oS8tVNbkvyLYCrf4tXrVqVbdiwwZJBFOIufelB/XPdunWW8goVKsTefvtt1qZNG8s5RIgJQJgXc3GKhTDvRAfnQCA4AQjzwRkiBxAAARAAgYQQgDCfEEOhmiAAAiAQgACE+fQI88uXL1dns9MM5e+++85WXHfTXcif/JYtWyyX0szncuXKWeLTHAGuwawblzBPayQ0bdqUzZ8/X1hh+pLk2WefFZ5DpJgAhHkxF6dYCPNOdHAOBIITgDAfnCFyAAEQAAEQSAgBCPMJMRSqCQIgAAIBCECYT48wT647JkyYwP7xj3+wxo0bB+gVTM1j1apVhjyOPfZYtnHjRkNcPhyAazArxyXM9+7dmw0aNEhY2eOOO46Ri5vSpUsLzyNSTADCvJiLUyyEeSc6OAcCwQlAmA/OEDmAAAiAAAgkhACE+YQYCtUEARAAgQAEIMynR5gP0A0MSfft26cKmH///bch/oorrmDTpk0zxHk5OHjwIOvQoQNbvHgxu/vuu9k999zjJbnva3fs2ME6deqkCrP9+vVT3fz4zixAwqi4BqhSLEnjEObXrl3LateuzWjWvCi89957rFWrVqJTkcWlob9DmPfePSDMe2eGFCDghQCEeS+0cC0IgAAIgECiCUCYT7T5UHkQAAEQcEUAwjyEeXNHWbBggbp4pjmeBKfmzZubo10fv//++4zEfS0MGDCA0SznKMPWrVtZs2bNVL/7VE716tXZmjVroizSNu+ouNoWKMmJOIT51q1bsylTpghbfO6559q6txEmCCkyDf0dwrz3zgBh3jszpAABLwQgzHuhhWtBAARAAAQSTQDCfKLNh8qDAAiAgCsCEOYhzJs7CrkDMQvmtGDmpEmTzJd6Ol64cCEjkVRRlEy6/v37sz59+mSOw9wxi/KUd506dRj5jM9FiIprLtripcyohflZs2Y5vjB68803Wdu2bb1UOZRr09Df/Qjz9GXIn3/+6ZlhqVKlWJEiRTynky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK/vzOvWrVPF6z179mSiScymBTVJOAsann76adalSxdDNuRe5rHHHjPEBT0QifIVKlRgs2fPVtsXNH+v6aPm6rU+cV4ftTDfokUL9sEHHwibVKVKFfbTTz+xYsWKCc9HHZn0/u5HmG/fvj0bM2aMZ7Q33ngje/311z2nky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK91ZprJTm5f5syZo0Wxs846i5FLjqOOOioTF3RnxIgRrGvXroZs+vbtyx5//HFDnN8D2UT5uLj65RV1uiiF+S1btjBa2NW8HoLWpp49e2YWhD106BD79ttv2eeff84o3a5du9QZ2pUqVWJHH300a9SoEatRo4aWNLRtkvu7H2GeXArNmzeP0cuoDz/8kO3du9eWZcWKFVX3UuRiisYa80s724QSn4AwL7FxULV0EOA3VQQQAAEQAAEQyAsC//jHP+hbc4X/SMmL9qKRIAACIJCPBC644AJ1rC9dunTqmz969Gi1rXRvo39lypRJfZu9NPCJJ54w8OHuP5Tdu3d7ycL1tcOHDze
URfbgwrzr9HYX/vHHH0q9evUMefOZ8gp3X2OXJPL4OLlG3hgfBRx//PEGe2h/f4888oiP3IxJnnzySWHeWhmffPKJUlBQoIwcOVKpVq2a47WFChVSGjdurLz22msK/2LEWFDAo6T2d+KhsdRv3dpu8ODBwvR8rQqFf70SkKqcyfnLCEubqQ8igAAIhEOA/OEhgAAIgAAIgEBeEIAwnxdmRiNBAATynACE+TzvALz5fLaxQkKbXnirWbOmsnPnzkjhDBs2zFAmlc9d2vguUzZRPldcfQOMKGGUwnzdunUtfUjrx4cddpjy3HPPKXw2vO012rXmLb2ofP7550MlksT+HkSY51+uKLVq1TKwL1++vMJ9/ofKVbbMIMzLZhHUJ20EIMynzaJoDwiAAAiAgC0BCPO2aHACBEAABFJDAMJ8akzpqyHfffedwhdkNYhnmkh55JFHKkOGDFH4Yo6+8naTaOjQoZay+WKwbpIarpFNlM81VwOcHB9EJcyvX7/e0ne0vhvWduDAgaHSS1p/9yvMb9++XTnnnHMM9rnsssuUX375JVSeMmYGYV5Gq6BOaSIAYT5N1kRbQAAEQAAEHAlAmHfEg5MgAAIgkAoCEOZTYUbPjeC+tpWbb75ZKVq0qEE8Ewma9evXVzZu3Oi5DLcJSPw3l9urVy+3yRWZRHmZuLoGGPGFUQnzkydPtvQbcz+iY3JZ1bFjR4UvxKpMnTpV4QvFKtOmTVP4ugbKSSedlDWPbt26hUooSf3djzBPL0xOPfXUDFdyEUR/z/QFST4ECPP5YGW0MZcEIMznkj7KBgEQAAEQiJUAhPlYcaMwEAABEMgJAQjzOcGek0LJb/a4ceOUJk2aKCSWiURMu7gqVaooS5YsiazeIl/UfOHOrOXJIMrLzDUrwBguiEqY7927t2MfJkG+f//+Cs3edgovvfSSUrJkSce8SNQPMySlv3sV5mkth2OPPTbDklwCvfvuu2Gikz4vCPPSmwgVTDgBCPMJNyCqDwIgAAIg4J4AhHn3rHAlCIAACCSVAIT5pFrOe72POeaYjGBmJ8A7xZM4/9tvv3kv2GWKQYMGWerXo0cP29QyiPJUOdm52gKM6URUwnyLFi0s/UXrv8WLF1cWLFjguoVLly5VZ9Zr6c1bEpsPHDjgOj83Fyahv3sR5ufMmaOQ+yuNHY0Xy5Ytc4MiVddAmE+VOdEYCQkUojrxgQYBBEAABEAABFJP4OSTT2YrV65kNWrUYKtXr059e9FAEAABEMhHAhdeeCGbN28e4zMb2a5du1KNYMyYMax9+/aZNvIZtYwvcJo5TvsOX+CVrVmzhh08eJAVFBSwbdu2qce7d+923fSWLVuy//znP66v93oh9+nNuI95Q7JHH32Ucfcfhji+sCRr1qwZ48JfJr5ChQps9uzZrE6dOpm4OHaSwDUODnZlVK1alW3YsMFymrhxn+uWeLcRDRo0YPwrDuHlHTp0YHwmvPCcXeSMGTPY5Zdfbneavfzyy+yuu+6yPe/nhOz9/aKLLmJz5861NM1suwkTJrBbb72V8ZcX6rW1a9dmxJO/lLGkTXvEzJkz2aWXXmpoJo23pUqVMsThAARAwCcBCV8WoEogAAIgAAIgEAkBzJiPBCsyBQEQAAGpCGDGvFTmiL0yhw4dUmi28IMPPqgcfvjhmdmu/Oey7f6iRYsirSe5HzGX371790yZdjPlZZqdKyPXDMCYd6KaMV+tWjVLP9H6zaxZs3y1slGjRrZ58okqCn+p5Stfp0Qy93c3M+ZHjBhhcI1F9xT+0s+pyak+hxnzqTYvGicBAbiykcAIqAIIgAAIgEA8BCDMx8MZpYAACIBALglAmM8lfbnK/umnn9QFYbP5n2/Xrl3kFe/Xr59FIOWzdG0XepVJlDfDkYmruW5xHEclzJctW9bSRzRhnl7e+AkTJ060zZPypoVjowiy9ncnYZ5ePj300EMGXldeeaWyf//+KBAlJk8I84kxFSqaUAJwZcPvRgggAAIgAAL5QQCubPLDzmglCIBAfhOAK5v8cWXjtqe/+OKLrFOnTraXcz/SjFzJFClSxPaaME707duXccHSMStyX/Pxxx+zunXrOl4nw0lZuMbNIipXNtT/uDhsac5hhx2WcaliOZklgtw88Zn4bOPGjcIrR44cyR544AHhuaCRMvZ3O1c2Xbp0YVu2bGFvvfVWptnkQopcC0U9LmQKlHQHrmwkNQyqlRoCEOZTY0o0BARAAARAIBsBCPPZCOE8CIAACCSfAIR5CPOiXjxs2DDG3ceITqlxX331FTvjjDNsz4d14vHHH2fc1YcwuySJ8loDZOGq1SeObVTCPF/gVSjAlytXTl0/wW/bBg8ezHr16iVM3rlzZ/b0008Lz4URKVt/txPmRW0tUaIEmz9/PqtXr57odN7EQZjPG1OjoTkiAGE+R+BRLAiAAAiAQPwEIMzHzxwlggAIgEDcBCDMQ5i363Nt27ZltKijKNBCurfddpvoVOhxd9xxBxs9erQlX1r4NQkz5c0Vl4WruV5RHUclzB933HFs06ZNlmrTFx07duywxLuNoAWEaWFhUfjnP//Jpk6dKjoVWpxM/d2LME8AqlevzuilHb0cydcAYT5fLY92x0UAwnxcpFEOCIAACIBAzglAmM+5CVABEAABEIicAIR5CPN2neybb75RZ79yN7SWS4YPH866detmiQ87YvPmzYzEwR9//NGSNfc5z4YOHWqJlz1CBq5xMopKmKeZ2cuXL7c0pWjRouyvv/6yxLuNWLVqFePrLAkvr127Nvvuu++E58KIlK2/exXmicHll1/Opk2bxgoXLhwGksTlAWE+cSZDhRNGAMJ8wgyG6oIACIAACPgnAGHePzukBAEQAIGkEIAwD2Heqa+2atVKOEO4R48ejFx+RBl++eUXVZQnodQuPProo2zIkCF2p6WNzyXXuKFEJcw3bdqUzZkzR9icAwcOMPI17yfs2bOHlSpVSpi0ZMmSbPfu3cJzQSNl7O92wvxdd93F+CKnbP369cJm9+nTx9YFlTBBiiIhzKfImGiKlAQgzEtpFlQKBEAABEAgCgIQ5qOgijxBAARAQC4CEOYhzDv1yDfeeIPddNNNlkt69uzJBg0aZIkPK4JEyiZNmrCVK1dmsjz22GPZBRdcYFhwkk4mUZzPFdcMzBh3ohLmb7jhBvbmm28KW0L9p3LlysJzbiLLly/Ptm/fbrm0dOnSbNeuXZb4oBGy9nc7YZ6+VmndujWj+4fo64RChQqx9957j7Vs2TIomsSlhzCfOJOhwkkjwD/jQwABEAABEACBvCDAP+Olb9eVGjVq5EV70UgQAAEQyEcCXOhUx3ouOKW++dxPudpWurfRvzJlyqS+zUEbuHDhQgMzjR2fpR40a9v0XKRU+OQAQ7lclFe4Oxs1DZ+NazhHdeLivG1+Mp7IBddccTj++OMt9iKbcXE3UJWeeuopYb6U96effhoo79NOO02Yd7Vq1QLlK0osc39v3LixkINmuyeffFJ4nmzAff0r/GsXUZNTHce/JLAwKSgoSHWb0TgQiJ
MAi7MwlAUCIAACIAACuSQAYT6X9FE2CIAACMRDAMJ8PJyTWsq2bdssIhOJbhMnToykSVu2bFFOOeUUQ5l6UV4rVCTOd+/eXTst/TZurrkEEpUwb/dyg/rnq6++GqjJDRs2NPRBypP+UXyYQfb+nk2YP3TokMJnxQtZES96wcFd/4SJTPq8IMxLbyJUMOEE4MqGj64IIAACIAAC+UEArmzyw85oJQiAQH4TgCsbuLJx+gvgMz0Z/7LAcgktxnrSSSdZ4oNE/Prrr6r7mu+//z6TDbmv+eSTT4RlPfbYY2zAgAGZa2knKQvCxsnVACgHB1G5siEXKnxWNtu3b5+lVbQwMS1Q7DeceOKJ7Oeff7Ykp7UByEVLGCEJ/d3JlY228DJ/ycTq169v62/++uuvt7ifCoOfrHnAlY2slkG9UkMg4S8WUH0QAAEQAAEQcE0AM+Zdo8KFIAACIJBYApgxn1jTxVLxBQsWWGbDlitXTqGZsmEGLlIqtWvXNpQlmilvLrN3796GNFx4UB5++GHzZdIdx8VVhoZHNWOe2mY3o/vMM88M1PTDDz/c0q+ob3ExOlC+WuKk9Hc7vporG6098+fPV4oWLSpkRtzI5U2+BMyYzxdLo525IgBXNrkij3JBAARAAARiJwBhPnbkKBAEQAAEYicAYT525Ikq8JVXXrGIbffee2+obSCR8tRTTzWU40aU1yohEuf5jGnttJTbOLjK0vAohflRo0YZ+g2JwPSPLz6q/Pbbb74QbN26VZgn5fvll1/6ylOfKEn93a0wT+0bNmyYLbciRYoos2fP1mNI7T6E+dSaFg2ThABc2fC7EQIIgEAyCaxfv55t3rxZ/UefftInmjVr1mTly5dPZoNQ68gJwJVN5IgjKYD+zunza/6DlO3Zs4dVqlSJ0Wfk/IcxO+KIIyIp05wpxhszkfCOZbAvtQY2Ds+muc4JrmzgysapD3bu3Jk9++yzhkuWLl3K6tWrZ4jze/D777+zpk2bsm+//TaThZP7msxFph3uc54NHDjQENu1a1fGZ+oa4mQ5iJqrLO2kekTlyoby3rt3L6tSpQrbudP6d/zaa6+xW2+9lS7zFKgv1qlTx5LmqKOOYlxUZ1xktpxzG5G0/u7GlY3Wdq7ZMe5vnk2fPl2LMmwrVqzIvv76a3bccccZ4tN2AFc2abMo2iMbAQjzslkE9QEBEHAkwD8rZG+++abqC3HDhg3Ca/nnyOycc85hDz30ELv44ouF1yAyPwlAmE+O3detW8eee+459sEHHzC9b15zC/jMWNa+fXt27bXXspIlS5pPBzrGeBMIn2NiGexLFYSNHc2U2JMQ5q2CXmKNGXLF6QVvrVq1DKLnFVdcwaZNmxZKSSRSNmvWjH3zzTeZ/PyI8lpikThPz7f/+te/tEuk2IbBlc/qZnwGMlu5cqX6Ap4EVLKVjCFKYZ7ae99997Hnn3/e0vQGDRqwxYsXW+KzRZDv9EcffdRyGfWv/v37W+LdRiSxv3sR5okD9UvyN2/3u/Oss85in376KeOugtxiS9x1EObdmYz+NvkXKO4u5leVLVuWtWvXzvX1ThfSGimzZs1yuiTrOaoL1QkhBwQkmbmPaoAACICAI4EffvhB4YsT2X5OyIdP4Tn+IKW88847jnnjZP4QgCsb+W3NF9xSuNDu6NdT9PfOZ30pU6ZMCaWBGG9CwSjMRAb7UsVgY6F5UhMJVzapMWXoDbnjjjsMz4ulSpVS+BdZoZTDRUqFz0o25O/FfY1dJURubR588EG7y3MSH5Qrn5Gs0H3cfH8nH/3Lli3LSZucCuULtFrqSnU3+yl3ysPpHH85odj5hP/444+dklrO0doJfFFjS335F4e+XeNQIUnt715c2WgwP//8c8fn0ttuu027NJVbuLJxZ9ZBgwZZ/s7MY5r5mL+Ac5d5lqveeOMNz2Wb60LjDkJuCMDHfG64o1QQAAEPBMaNG6cUK1Ys0M3mgQceUA4ePOihVFyaRgIQ5uW2Ki20xWehBfpb5zPCAjUS400gfI6JZbAvVRA2djRTKk5CmE+FGTONmDNnjnLjjTcq1apVU5o3b67069dPFQUzF7jcWbRokVK4cGHDPYb8eYcR/vjjD6Vu3bqGvMMQ5bW6icT5Ll26aKd9bWXhSv7JafFds0ikHZcpU0bhs0F9tTGKROTnXaubeRvUJvr62ol89PJn3759+ksd98nO5nrSMfUpvyGJ/V1r69lnny3kke2lCi2SK+KoxQV9/tTqJ+MWwrw7q9DfBfWja665RilRooRjf9H6DWkcn332mbsCHK7iX2kpNP60bt1asVsDQytTv+Vus9TJj/Syl9aiQMgNAQjzueGOUkEABFwSGDBggKubmv4GY7fPP1VWuC96lyXjsjQSgDAvr1VnzJhhOzvM7m/aLv7VV1/11VCMN76wuUokg32porCxK3Ml/iII84k3YaYBu3fvVkSzk2mmO3fDoezatStzrdPOJ598ohxzzDGGZ0oS2sIIJMZw//SGvMMU5bU6hinOy8TVzUxPmuUsSyARze754+qrrw6tmvSbhbuuEZbFXd24KofyuPTSSy15nHvuub5/EyWxv+thnXLKKRYeZM/bb79df5lln748aNGihTCt1h+GDBliSZeGCAjz3q1I96ZbbrnFsb9o/YbuTZs2bfJeiE0K6qtks9NOO822fPo7oC+N6VqE3BOAMJ97G6AGIAACNgT+/e9/295MtBuZ122PHj1sSkN0PhCAMC+nlb/66iuFPqk2/z2TuHbvvfcqTz31lPL0008rHTt2VM477zylUKFClmv1aenzb6+uCTDeRNc3ZLAvtQ42js7GsuUMYV42i/ivz1tvveU43pOg8frrr9uKCyQ6DB48WOGLW2byOeywwxS+gKr/SulSFhQUWMTTKER5rUiRON+rVy/ttOutTFz79u2bsY3+Xm7eX7Fihev2RXXh33//rX61Ya6bdkwzVd2+LHJTR2rz0UcfLeRDM1ypPnbhwIEDylVXXWVJS38z69ats0vmGJ/U/q41itzYmb+a0WxHbpOceFIeNKM422xkEmP5Ar5akanYQpj3b0a3E0LoZRn9zYYZSOwXfY1UuXJl5ZdffgmzKOQVkACE+YAAkRwEQCAaAjQbpWjRopaHSb5gj8IXvVJeeOEFZfz48epsKXrorF69uuVa7UFLv6WHMa++GaNpIXLNBQEI87mg7lwm/Xjhi/Ia/n5pRhJfPMk2IV9ky5JG/3dO+15ewmG8sUUd+IQM9qVGwMaBTZmoDCDMJ8pcjpWlr23M47vomO7vJIDQDMDly5erf/PDhw9XTj/9dEN6cjezdOlSxzK9nJw0aZIh/yhFea1eZnG+YsWK2inXW5m40osVkU3NcWPHjnXdviguXLt2rXLnnXdmrSs9w4TZx8iNzwknnCAst2nTpsrcuXMNzaWXUeTW4vLLL7ekob8Tv6I8FZLU/r59+3bllVdeEfra1/cz+hIh24uVESNGWLjq86D9mjVrKi+//LKyf/9+g22SegBhPpjlyMWMuY+Iju++++5gBQlS33///ZayH
3vsMcGViMolAQjzuaSPskEABIQEaLaC+ZNgEu7oR4RTmDx5skJ+0kQ3On0c+RZEyE8CEOblszt99qv9fdKMxtdee81VJenHjmgmmJYXzQhzEzDeuKHk/5pc25dqDhv7t19SU0KYT6rlrPXes2ePUqNGjcx9QhvjvW5p8cuRI0eGPiNx9erVmdn4JJ7G5QtdL85Tf/caZOJKPtv1XzTY2XbgwIFemxno+sWLFyvt2rVTyBUmuYTI9rWeud40s7pRo0Zq+qCLg9LM11atWtn+HdSqVUu58sorVVcr5cuXF17Xpk0bhdzQBAlJ6++dO3dWiI0X25UsWVI544wzFOKlX4OCfIdfd911ih1fs/3pmPyMN2zYUKHfqEkOEOaDWY9elpUtW1b4d2nuN37dcdrVcObMmZZy6eU1glwEIMzLZQ/UBgRAgBMYM2aM4QZCMz7+/PNPV2x27NjhajYLbkiucKbuIgjzcpmUZiUdddRR6t87LX40ceJETxWkH/M0U9D8UKsd79y5M2t+GG+yIvJ9gQz2pcrDxr5NmNiEEOYTazphxWmG72WXXeZJXKP7QOnSpVUxc/r06bauboQFeoz8+uuvFZrNTYuYxhnoSyCabU7Pvn6CTFxnzZqlVKhQwfZ+TvZ85pln/DTTdxpasFx7ngi6LV68uO966BOSQEqisZf6kIsM86x6fZ5e95PU3518bLth2LNnzwweu0Vj3eQTttiaqVRMOxDmg4PWfu9k6y/kjnPhwoXBC/y/HDZu3GgYL+j3lltdJbRKIKOsBArRFbxzIIAACICANAT4gw/jbizU+vBZUmzRokWM+0fzVL9rr72W8c8tbdNwv4yMf4poex4n0kmAf3nBVq5cyahf8Vk/6Wxkglr1/PPPM/7ZsFpjLmow7pfTc+25SyvbdPxTcsa/vnHME+ONI55AJ2WwLzUANg5kxkQmvvDCC9m8efMYF2YZf0GUyDa4rTR/8cTat2+fubxMmTKMv5TMHKdpZ82aNYzau2rVKrZ582b1H/f5zPhCsIy/pFX/VapUiXHhklEfqF+/PuMzsdOEIJK2yMKV+y9nfOKM+pzG1wFgfLYxu+mmm9i+ffvUdnNxWbVrJBASlil/qcL4LGzGX86ofwfcXzTjL2jU30v8i0HGvyBm3P0n47PoWdWqVRPWOlRXNgJ81jXjiwgbqkV/rzT2IrgjwF88MrpfuQnHHXcc41/sMLqfhRH4i0HG/derWXE3S+o9NIx8kUd4BCDMh8cSOYEACIRA4LvvvmN8doOaEz2Qf/HFF6xOnTqec+azlhhfbZxxn4LCtBdffDH76KOPhOcQmV4CEOblsi33/cuWLVvGuEsa9Qemn9rxWfOML4wmTErjB4mydgHjjR2ZcOJzbV9qBWwcji2TlguE+XQK80nrh6hvMALc/QPj7h8YCYAU6JmejhFAAATiJQBhPjhvL8I8lUbPMXxdPMbX3AtcuL7sM888U530GDhTZBAugaxz6nEBCIAACMRIoF+/fpnPrYIuTGJ2X8BHz0ze5GsUIf8IwJWNPDbngrz690ifdm7ZsiVQxSpXrpz529b/nW/bts0xX4w3jngCnZTBvtQA2DiQGRObGK5sEms6VBwEMgSmTp2aubf78aOfyQg7IAACgQjAlU0gfGpivSsb+j1K6w/of7OI9mnh1jACrXeh5d+4ceMwskQeIRPAjHneQxFAAATkIcAfvNXPMunN7tq1a9XP0P3WjvtPY3wBH3bw4EFLFvSJrPZJl+UkIlJLADPm5TEtuVqYMGEC4w+njD8kBqoY5UGuDfTh2GOPZdyvoj7Kso/xxoIktAgZ7EuNgY1DM2miMsKMecyYT1SHRWUtBOgZnb6g1dwOcn/vjPtKt1yHCBAAgegJYMZ8cMb6Wett27ZlfDFndsMNN2TNeNy4cezmm2/Oep3TBeS+RhtLL7nkEsZftDhdjnO5IBCy0I/sQAAEQMA3AVqokX+upb7Rffzxx33no0946qmnZt4Q8zE2s09vqRHyjwBmzKfP5nv37lW4D+HM37b2d37FFVc4NjYp481ff/2l3H777UrdunUV7rPdsU1hnuQuAxT+w0Etl/uxDTNrT3n5tS8VkhQbewKCi10RoNm1NBbQ4p9pD6NHjzaMf9zHfNqbjPalnMDff/+tXH/99Zl+3aZNm5S3GM0DAbkJ5GrGfJqegfUz5un5mkK3bt0y45z2+8W8PeKIIxTubz5QByFPAVq+fK2AQHkhcTQEWDTZIlcQAAEQ8E6Az3hVTjjhBIVWIw/q2kIrnW582o1Iv61du7Z2CbZ5RADCfPqMzWfRCf/G+ewex8YmZbyZPn26oX0DBgxwbFcYJ//44w+FL5qbKbd69ephZOsrD7/2pcKSYmNfYJDIkQCEeUc8OAkC0hIgUf6OO+7I3H9q1Kih8IV+pa0vKgYC+UAgV8J8mp6BRcI8/6pfad68eWa802sV+n3SR37//XffXQ3CvG90sSWEKxve4xGSR4BWov/ggw/YypUrGa1Cz/0IMz5DiFWsWJGdeOKJ6ir0DRs2DGWxDD0dKo8WDF2/fj3jwjGjRQdpQQ5ymUD/uP8uRouK0jbM8O2337I333yTffXVV+zll19mfHC2ZE/npkyZkqkbLYZIDGiBjwYNGjBajdsukMsBr4FcxDgtRkJuZPbt25c12yOPPNJyDaXjb4ct8X4ievbsyYYMGWJJymfTsmnTplniEZFuAnBlkz77Dho0iPXu3dvQMD67jk2aNMkQZ3cg+3izcOFC9fN9/mSYaUL//v1Znz59Msdh7mzdupU1a9ZMXZRXy5cW4F6+fLl2GOs2qH2psrLbOFageVIYXNl4f67Lk66BZkpMgO4/5NqB3GZQoN9T8+bNE/7ukbgZqBoIpI5ArlzZpOkZ2OzKhrQdCqRjkWZDLnydAj2bkwsa/pWw02XCc3pXNnzGPJsxY4bwOkTmkEBsrwBQEAgEJECfMr366quKnWsS/mdkeNtYrlw55dFHHw08y4I+5+/Vq5eif9NoLkt/XKhQIYVmav373/9Wdu/e7bvVa9asUbggoXD/ioZ2mT9lmjt3rnLGGWcYrtHXh/b5ywplyZIlwrrQm1rz9W6OW7RoIcxPi3z22Wdd5Us2ijLceeedwnrcd999URaLvCUlgBnzkhrGZ7X4Q6zCXxIa/sa5iKwUFBT4zDFYsqjGm5EjRxraSGM0LWoadjDPlKdy+A8JhYvyYRflKj/Z7EuVjsrGroDgItcEMGPeNSpcCAJSEKDfPVWrVs3c62rVqqV+9SRF5VAJEMhzArmaMU/Y0/IMLJoxr3Ures42/54R6TFdu3bVknja6nUsuLLxhC62i+HKJjbUKCgIgalTpyr0Kb1ogCpfvrxSuXJloY9hur5UqVLK2LFjPRdPLwKefPJJhfIXlUt+S+kBUuTbWLueBsHPP//cVdmHDh1SFi1apDz22GOqT18tD/NWE+b379+vPPzww0rhwoWF9TOnI/9kb7zxhqUu9Mkon52i0Kei5jSiY2ov/eDlM9Eteekj5syZo1x++eWONxk+k195/fXX9clC
32/SpImwXePHjw+9LGQoPwEI8/LbyG0Nacw0/32fddZZConLuQrm+mhjaBjjzb/+9S/LWNa3b9/QmiqbKC+jfQl2lDYOzZjISH1Oob8/+JhHZwCBZBD47rvvFJrcRGtN9ejRQ+FfOiWj4qglCOQBgVwK84Q3Dc/ATsI8tXHixImW53ztd4R+y2faT5ZzuAAAQABJREFU0+WeAoR5T7hycjGE+ZxgR6FuCRw4cEDp0qWLZZA6/fTTFVrsimaza4G7lVFuvfVWy7XaQDZw4EDt0qxbWjCOr1htyYu7q1Fnwi9btkwhQZsCPTiOGjVK4S4yLNdT2SSck4BOQrpd+P7775VjjjlGmF6rv7YlYZ7aevbZZ7u6Xkun1WXFihXCatCLiE6dOjnmST9wqa5eAtnooosusuTLXTF4ycb3tfwzWEvZ3H2OQgsKIuQfAQjz6bH5E088YfjbpvUkgnylFAaZqMeb4cOHG9pM4/rjISyULZsoT7aQ0b5Ur6htTGUgBCeAGfPBGSIHEIibwFNPPWX7hW/cdUF5IAAC/yOQa2GeapL0Z+Bswjy1kV5K6rUb0T7NrCctykuAMO+FVm6uhTCfG+4o1QUB7qNcueyyyyyD04033mg7i4LE5fr161vSaKI03VSyBVpgSOQuh2Z3b9q0yTY5vURwejFAbdHEfHMm9PlSiRIlFBINua98Yf21gXnChAkWtzq0IMhVV12lvkzIJvATP6cgeiGhlU318xPGjBljaNN5552n0GzIqMOvv/4q/KKAXkAg5CcBCPPJtzuNo4888ohhTOG+ExV6oZrLENd4M2zYMEPbaXymL638BtlEeVntS3zjsrFfWyLd/whAmP8fC+yBAAiAAAiAQBACMgjzVP8kPwO7EebpGZg8Dmjai92WPB1w3/SuTQph3jWqnF0IYT5n6FGwEwESbdu1a2cYlOjzxsGDBzslU8+98MILhnT6AY1m2jsFEvbPP/98S3paLdvtTMzOnTtb0mt1oJnzdkET7akO11xzjW0eWl60pUH2/ffft4jczz33nFKsWDFhHuSKhi9ia1cNdUa8nXseivezIjj56NfX+7PPPrMtP8wT5IpIX662r7kDCrMs5JUMAhDmk2Enu1rSp+7nnnuu8O+avoThCz3bvri1yzOs+DjHm6FDh1oY8MVgPTdFNlFeZvsS3Dht7NmYSGAgAGHegAMHIAACIAACIOCbgCzCPDUgqc/AboR5ah95G6AJR5puYbd1mvRJ+egDhHk9DTn3IczLaZe8r9WAAQMsg9FDDz3kigvNaCMRXzSIkVsZvfsbc4ZUhjkdpXESss15kLBOM8LN+WjHIj/v5jzo8yTtertt+/btlT179piTZo5feeUV2zyyLRroNPOfFrX1GvRud2g/rnDKKadYGNBLFoT8JQBhPpm2//bbb5Wbb75Z9T1rNyZq8fTV1MaNG2NvaNzjDb2E0NqsbeklqNsgkyifBPsS17ht7NaWuM5KAMK8lQliQAAEQAAEQMAPAZmEeap/Ep+B3Qrz1D6aqELrJGrP93bbbGv+UV4UIMz/l4PM/0OYl9k6eVq3devWKbRQqX4AIl9aJLi7DXohWJ8P7dstxkpiuEjQJ7/FXsOPP/5oKyCRyxoSRLIFWvzIXHft2M2XA+Rax86tDQnvToH8yNstKksufbwEurFo9aYtueKJI8yfP99QLpVN7NesWRNH8ShDUgIQ5iU1jKBa9OJx3Lhx6mKborFZP66Y96tUqRKrn9pcjTd0LzC33c1DugyifJLsS90zVzYW/GkgygUBCPMuIOESEAABEAABEHBBQDZhnqqctGdgL8I8te/dd98ValP65376ffTOO+/Q5Y4BwrwjHilOQpiXwgyohJ4A+UrXDzi0T/6EvYSZM2da8tDypIVTReG6666zpKHBjvy/+wm33XabJT+tDuQfLVvQD95aOtr27t07W9LM+fvuu09Yh8aNG2eusdsR2UGrx6JFi+ySWeK7du2aqQO9KKAvCuIIt99+e6Zcrd60aAxCfhOAMJ8c+9u9WNT+nrNtSZy3G+/DppDL8WbQoEGWsY4Wj7ILMojyVLck2Zfqm0sbU/kI3ghAmPfGC1eDAAiAAAiAgB0BGYV5qmuSnoH12o7biZ/kpjLb753SpUsrK1assDOdGg9h3hGPFCchzEthBlRCI0Czmc2DD81y9uPX/J///KclLzuxYtWqVcIZ4rSwht/www8/WMrX2latWjXbhWC18o4//nhh+pdeekm7JOtW5ION6lC1atWsackPvFZf8/auu+7Kmp4uIBG+UqVKmXzczOR0lXGWi5YuXaqY/eSfccYZysGDB7OkxOm0E4AwnxwL05ocV199tdKqVSt11ny9evVcfdapH69atmwZeYNlGG9E7t8effRRS9tlEeWpYkmxL9VVBhtTPRDcE4Aw754VrgQBEAABEAABJwKyCvNU56Q8A/sR5mndRfoto/9tI9qvVauWsmPHDlsTQpi3RSPNCQjz0pgCFSEC/fv3tww8tNCfn7B//351oTaaCX/PPfeoi6Ta5UMzqUWD3MUXX2yXxFU8CUmifClu6tSpjnmEIcy//vrrwvJJtHYTzjrrLGF68nlWUFCQNYspU6Zk0tPXB2vXrs2aJugFdAMz+/gnV0jkqggBBCDMJ7sP0N83iaQPPvigcvjhh2fGF7txluK9fOHjlY5M443o/tm9e/dMk+xEeZnGRtnsS/BksnHGmNjJSgDCfFZEuAAEQAAEQAAEXBGQWZinBiThGdiPME9t27lzp6L9fnX6vUMTmeiZVRQgzIuoyBUHYV4ue+R9behtn3nAoQX/og6XXHKJpVyqR4cOHQIV3bdvX2G+lHfHjh0d8w5DmJ8+fbpt+Y6F/9/JsWPH2qZ3M3OfbhCaPYlxHOHVV1/NlEllk6/89957L46iUUYCCGgPNkG+hklAM/Oiij/99JO6IGw2//Pt2rWLjIds4w0t7K2NudqWXMElQZQ3G0kG+1KdZLOxmROOxQQgzIu5IBYEQAAEQAAEvBKQXZin9sj+DOxXmKe20fp/ZcqUsTzja8/62pa0J1GAMC+iIlcchHm57JHXtSFfwNqgot/aDTBhwaJFUs2LzWrl04rfQcLs2bOFbaL8L730UseswxDm33//fdvyHQv/v5P01UGFChWEeZx55pmOWfzyyy+GBXAnTZrkeH0YJ7du3Wqp78iRI8PIGnmkhACE+ZQYUteMF154QThGaeP4kUceGYkbK1nHm8cff9yRB3GhcV2mmfI6c1p2c2VfqoisNrZAQoSFAIR5CxJEgAAIgAAIgIAvAkkQ5qlhMj8DBxHmqW3/+c9/XC0GK/LKAGGeCModIMzLbZ+8qp2diE3uWKIMP//8s62I8eKLLwYq+scff7TNm74OcAoyCPNUP3KFoAlc5u3XX39t2wRa4Fa7/uijj1b+/PNP22vDOmFewJcWv0UAAT0BCPN6GunZt1tPQxuDvvrqq9AbK/N489hjj2XGX42Btk2SKK8ZLRf2pbJltrHGBlsxAQjzYi6IBQEQAAEQAAGvBJIizFO7ZH0GDir
MU9tEXwVoz/faliYkrVy5ki7PBAjzGRTS7kCYl9Y0+VexZ599VigkzJo1K1IY5H9YG8jM2yeffDJQ2fv27bPNm/wj2/kBo0JlEebXrVsnXBiXWDm54znllFMybRctQhgIrCCx2bccLf6LxV4FoPI8CsJ8ejvA9ddfnxlzzGP5mDFjQm14Esab9u3bC3kkZaa82WBx2pfKToKNzYxw/D8CEOb/xwJ7IAACIAACIBCEQJKEeWqnjM/AYQjzpB21bt1a+Hyv/+1Tu3Ztw3qAEOaD9P540kKYj4czSnFBoFu3bsJBhhYQjTI4uXuhN65Bg50rGBo8t2zZYpu9LMI8VdBuNXDydbZ7925LGxYsWJCxJfl/XrNmjeWaMCPITY7ezzT9IN+zZ0+YRSCvlBCAMJ8SQwqasXz5csM4oH9ApQW+wwpJGG82bdqk1KxZMzMO61mQz/kkhrjsS2ySYOMk2jDOOkOYj5M2ygIBEAABEEgzgSQJ87I+A4chzFMf27Vrl0LCu/7ZXrTfpk2bTJeEMJ9BIe0OhHlpTZN/FevUqZNwgBk1alSkMN5++21huTTA3XPPPYHLrl+/vm3+v//+u23+MgnzM2bMsG0DLYxnDnfddVfm+ubNm5tPh3pM7nRKlCiRKa9hw4bq6uWhFoLMUkMAwnxqTClsiN1LxB49egiv9xqZhPFm8+bNimghdf1DexxfMXll6+b6qO1LdUiCjd2wyvdrIMznew9A+0EABEAABMIikBRhXuZn4LCEebLpqlWrlLJly2b0D/0zvn5/8ODBaheAMB/WX0J0+UCYj44tcvZI4O677xYOLl26dPGYk7fL582bJyyXBrUmTZp4y0xwtZ2QUKRIEeXvv/8WpPhvlEzCPH02ZTf78qyzzjK0gWaq61cNpxcfUQVaYFbPqU6dOupifVGVh3yTTwDCfPJt6NQCWpNE/0Cq7ffs2dMpmatzSRhv6AeJ1se1th977LFK27ZtLVySKM5HaV/qBEmwsavOiosUCPPoBCAAAiAAAiAQDoEkCPOyPwOHKcyTVcnrQ+HChS3P99rzP23pPE2whDAfzt9BlLlAmI+SLvL2RMDsz1UbVM4++2xP+Xi9eMOGDbYDWqVKlbxmZ7nezi8uLYjqFPSCs8aCti+99JJTMsM5Jzc9hgtdHIwYMcKW09KlSzM5jB07NnMd8Ytq0ddt27YpdevWzZRFYpSTa6BMBbGT1wQ00bJGjRp5zSGtjV+4cGFmTNCPm0OGDAnU5CSMNyQqn3zyyYb2kyhPi5BT6NOnj+Ec8UmaOB+VfYlPEmxM9URwRwDCvDtOuAoEQAAEQAAEshGQXZhPwjNw2MI82YxmxOt/74j2y5cvbxDwL7300mzmxvkcEIAwnwPoKFJMYPz48cKBpVixYuosNnGq4LE0a/2www4Tlk2D22+//RaokMaNGwvzbtCggWO+sgnz27dvN7iM0Q/8epc/+vZ2797dsY1+TxYUFCjnnHNOhuuJJ56o0AsWBBDIRgDCfDZCyT5P4qp+bNL2J06c6LthSRhv6KWkfsFtardelNcaLxLnoxqntTLD3EZhX6pfEmwcJsd8yAvCfD5YGW0EARAAARCIg4DMwnxSnoGjEObJ9tdcc43wt4/2G8i8hTAfx1+M9zIgzHtnhhQREfjss89sB5X7778/lFLpR/3OnTsteek/7zEPXhMmTLBc7yXCLu8HH3zQMRvZhHmq7J133im00ZFHHqkutrp69erM4ou0GCsdhx3279+vNG3aNFOPKlWqRL64bNhtQH65IwBhPnfs4yiZFkQyj+F0rM0a91qHJIw3bn+QaG0XifNJWRA2bPsSkyTYWLMdtu4JQJh3zwpXggAIgAAIgIATAVmF+SQ9A0clzO/evVshd76i3z+iOAjzTj09d+cgzOeOPUo2Edi4caPtgEIz2n/++WdTCm+HNGideuqpSsWKFZX169cbErdu3dq27FtuucVwrdeDI444Qpj3Bx984JiVjML8kiVLhG2hQX/06NFK7969M+ebNWvm2D4/J//66y+lVatWmTLIlt9//72frDJpyNXO8uXLM8fYSTcBCPPptu+CBQsy44P2MFquXDmF1snwGpIw3vz6669K7dq1DW0WzZQ3t10/VmucHn74YfNl0h2HaV9qXBJsLJ0RElIhCPMJMRSqCQIgAAIgID0BGYX5pD0Dk0sZ7ZmbXB2HGdasWaPo89fKEW0hzIdJPry8IMyHxxI5BSRAwknJkiUzA5Z5ILnjjjsClXDDDTeoeR9zzDHq7G59ZtOmTbMtl/ykOy3Sqs/HvP/HH38I8yWxnhZJdQoyCvNU3/POO0/YJloEVl/noF8amNmQDTQbUt8gsU3v2958vdtj8rFM+T3//PNuk+C6BBOAMJ9g47mo+iuvvGIZn+69914XKY2XJGG8oR8k9LJZf690I8prLRWJ8926ddNOS7kNy77UuCTYWEojJKRSEOYTYqiQq7lp0yZl/vz5ypQpU5Q33nhD+eijj5SVK1cqe/fuDbmk/MoOXLPb+8CBA8o777wTeMJQ9pK8X3Hw4EHl66+/Vv8mBg0apNx1111K8+bNlZo1ayq0xls+B/otOWDAAOX8889XnnnmmXxG4dh22YT5JD4Dly5dOvPMTpNCww5koyJFimTK0P8+0O9DmA+bfDj5QZgPhyNyCYlAx44dbQeTokWLKosXL/ZVEt1otQHphRdesORBP9DJT7l2jXk7ffp0Sxo3EZMnTxbm2blz56zJ9SK3vj65WvxVqzD90NHXR7RPM9npATXMoO8bpUqVUr744ovA2c+dO1ddDIXc7vh1dRG4EsggVgIQ5mPFHXth5PbMPCb5eYEn+3hDa5+cdtpphrZ6EeU1w4jE+a5du2qnpduGZV9qmOw2lg5+wioEYT5hBgtQ3bVr1yoPPfSQZZ0N872A+sSYMWMU+oIWITsBcM3OiK6gL29ffPFFpWrVquo9WRbXcCTEP/HEE0qLFi2UMmXKGJ4XtL+NwoULJ24ReHdWsb+KXOLRizt6BjD/1pbFdva1z90ZmYT5pD4DH3744Zm/Q/q7jCIMGzYsU4b2d27eQpiPgnzwPCHMB2eIHEIksGLFCsfBhHxzeXU7MmLEiIzf81q1aqmfrouqTA8v5oFLO7744otFSbLG0WwELQ9tS2553CxUan5Y0NJ7EebphYKWzrzNWnmbC0hwp68IzPnpj8N2iUAPSlr+9LXBJ598YlM799HkAoe+nqB8o3C7474muDJOAhDm46Qdb1k0e4bWu9DGCtpeccUVnish+3hDP0jMviT9iPIaGJE4TyKXbCEs+1K7ZLexbOyTWB8I80m0mrc607pR7du3V2jijn7cz7ZPvyVImEMQEwBXMRdzLLlCo6+4zBO7cinu0popY8eOVRo2bGj7N1GvXj2lR48eyvvvv6/s2LHD3KzUHdPkuy+//FKhLwUuvPBCpVixYrZscmk72cHLIswn9RmYXgjr701nnnlmZCZv166doSx9ubQPYT4y9IEyhjAfCB8SR0HgkksucRxMypYtq4waNSpr0fRw0qVLF0Nes2
bNsk33+++/q+5RzIOXdrxs2TLbtKIT9OkszcTW0mtbN7PlKb/KlStb0lIeXoT5N998U5gH5RMk9OrVyzZfavOqVauCZG9ISw9SGjt6qZHNN78hseCAHqTJzc7RRx+dyXfixImCKxGVRgIQ5tNo1f+2idydaWMFbenLGq9rk8g+3tB9KkxRXusNInE+2wLlWtq4tmHYl+oqu43j4pn2ciDMp9vC5K5Gm6GsH/e97JMrQwQjAXA18hAd0e8IWlerevXqhmcOre/lQtylOj355JNKhQoVhHUiV7F33nmnKlCL2pTmOHK1qtkm2zYXtksKexmE+SQ/A5M2ou9/pENEFchd8umnn24oT182hPmoyAfLtxAl54ZCAAFpCPA3+IzPcsxan3PPPZfxmTLs6quvZnyxC/V66s6//PILGz9+PHv66afVfS2jW2+9lb322mvaoXDLbzqMf1rEuL97y3nuW53NmzeP8c/+LOdEEfylgFoH/bm6deuyhQsXsuLFi+ujhft0DZ+dbjnHhXnWoUMHS7wo4t///jfj/pVFp1hBQQHjwpXwXLZIPuOfVatWjfFZCJZLmzRpwmbPnm2J9xPx3HPPMe66wJCU8vcTyKZ8FhDjiwyz7du3Z7Lgs//VOD6DIhOHnfQSOPnkkxl/acZq1KjBVq9end6GJqRl/OsX9uqrrzIuBrCTTjqJNWrUiN1zzz2M/7j01IKvvvqKnX322Yaxm7/AVe8RbjOSfbzZunUra9q0KeNfjWWaxGfKM2JI7IKGPn36sIEDBxqyofvYU089ZYjzciCTfanestvYC1tc60yAz4xUn9m4T1fGXRc4X5zws9w9i2Gs424r2M6dOxPeKvvq07P6lVdeKXxGtk8lPkP3H/7ST3wyz2LB1dng9JuHu/Nk3B+54/MjF3fZ0KFDnTML8Sw9P3Xq1MnwbKBlT79t+IQwRvd3/kWhFp1XW2LzzTffMO5GhP3000+Mu2iybX/ctrOtiIQnZs6cybiga6hZEC3BkJGLgyQ+A+ubRWPHTTfdpI9ifMFWxl/wGeLCOqC+zr+cYXy9Q0uWZMcZM2ZY4hGRYwLBdH2kBoFoCHBhxvYtH/+TsZyjhUDpU0IuZlvO0fW1a9dW+M3DVWWHDBkizIPy6devn6s8+IsBy+IbNHvzhx9+cJWePi0UtZPiaIEat8H8xYA+T1q9O0igRUv0+Wn7NEs/jECfYoq+ONDKCWvbvXv3MKqLPBJCADPm5TEUfdZpdj1Df9c0VvIfkQr5AXUTuPibcUuljQv8R7GbpJlrZB9vaCFx+vxcax9tg7ivyTTctCOaOU/3ET9BJvtS/WW3sR/GSGNPADPm7dkk+Qx/CauQS0P9WEj7ZG9a6Ju/SFT4xBzVfzSfUJP1OZJ8/nr9sirJ/OzqDq52ZP67UPjrr7+ukDtUc78THcc163rfvn3K3XffbdvH6Qt0ctuJ8D8CfJKW0rNnT1s7xmW7/9UoOXu5nDGfxGdgs2XJp7x5vHCrK5nzcnv88ccfW/QoqgNmzLslGO91wfxZxFtXlJZHBMgfnJ3wax7Ush2TeMFnxnqid91111kGT60c8p9ON3a7QIvLmgVlPqNf+fzzz+2SWOJffvll2/LpQctNoMWInHzB04+XIIHcAmlMtC19QhnGoq/vvvuu8EailRPWluzktW8EYYa0uScAYT73NtBq8NZbb1nGEP3fNq0BQT+G7cZbih88eLBhrCB3V/Q5t5cg+3hDL5UbNGhgYBWFKK8xE4nz5L7Ma5DFvlRv2W3slS2uz04Awnx2Rkm7Yu/evQr/6s0wFpLYQf6j7cKnn35qSaO/z9A++dvO5wCu9tafPHmyof8UKVLEdhKY1q/iEHe3bNminHPOOYa/Ba18WtDV63OQPYH0naFnKtHLPeIXh+2SSjRXwnxSn4E1O5ObKRLgtb9P/ZYmIr3zzjvapZFsSe/Rl0n7EOYjQR04UwjzgREig6gI0EyA888/3zKYmAcXp2OaYfjjjz96riK9GKDFYEnkEeVfv359hfyS08OsFpYuXapwFzOW68kHJi1q6zbQYjzaoqSisimuW7duCr09tgu0MAotemiXnuJp8augvtXNP5CoXkED/1TOlrtTe/yco8V5EfKLAIR5eezNP6N0HKO0v2myGX0pRIv10eLfn332mTJ8+HCL/0TuKkyhcdhLSMJ4M2nSJAOnKEV5jZ1ZnK9YsaJ2yvVWBvtSZZNgY9dQcaFrAhDmXaNKzIX6L1pJIOXuKV3VndacuuqqqwzjqHZ/oS09c+dzAFd765N/clos9JprrlHoZTN90UyTAmgNFn0f0u9HLe7Sc9AJJ5wgLJ++Qgy6Fpc9jfScMf9+1ewXte2STDBXwnwSn4FJR6IvsWjMIM1I619221tuuUVdjHndunUKpQ073HzzzYY6QJgPm3A4+UGYD4cjcomIALkyoJl6tGiN3WAmiqc3kCNGjFAOHjwYqGa04CuJPaIyKI5mXNPiHSVKlLBcQw9yXbt2Vbifz6x1oHZedtllynHHHWfJx65s+vyW6taxY8dM/vSpPsUVLVrUdT40q/7iiy9WaEEVr4HEMa1+NEMjqHscWjBXxFIrI+zt22+/7bXJuD7hBCDMy2NAWpyI+/rPjCF+/765f3Vl5MiRnr/WScp4Q1/1kAhFfOjHuJ+XzX6srhfnSeT0GnJtX6pvUmzslS2uz04Awnx2Rkm6gp6TaUIJjYP0fO11YglNWKEXjHb3GTfP6kni5bau4OpM6qOPPlJodro5cH/Xlq+jtb4VpbhLrin4uhnCfkxfZ3/77bfmquJYQICvUydkGKXtBNVIVFSuhPmkPQP/61//8qTDaOOGtqWvOWiWfZiBJrueccYZmT4PYT5MuuHlBWE+PJbIKUICfEFXhfzO08O4NnCJtiQO06wGvsBnaLUh1yx8EUGlWbNmCuUvKlcfV7lyZbWuXvz6kSiuz8PLPl80JNNW8qvsJa3+WnpL6zXQDx3Nrz+5HgoaaGajvk5R7tMLFXL3g5BfBCDMy2VvGnfopaTZ/Ve2v336YdqqVStl+vTptq5usrU0SePN119/rfpI//XXX7M1K9Tz9HUCuROiWYJ+Qi7tS/VNko398EUaewIQ5u3ZJPEMX7g583xIk1D8hHHjxmXyMN9jvH5t5ad8GdOAq3+rkBBu7kd0HJW4+9133wnX5aEy6bfYvHnz/Dcmz1I2btw4VtulAW+uhHlil6RnYCd3xKLxwhxHboHtXHgG6Ufr16/PuDiGMB+EZHRpC1HWvEMggEAiCHABW11Znc+CY/Rv06ZNjAs0jIvh7LTTTlNXC+czaiJrC39BwObOncu48M82bNjA+CwKdZV37lqA0T/uOodx9zuMC/iR1UG2jAcOHMj4gz2bM2cOO+WUU2SrHuoDAgYC/PNVdezgM7UZn4VhOIeD3BHgX9uwMWPGsFWrVrHNmzer//iMNMa/fmJ8lqP6j3/dw/iMD3bhhRcy/mko47PIc1dhlOyJAOzrCRcuDoEAjRNcqFKfEfms4BBylDcLGjvbt2+fqWCZMmUYnwGeOU7Dzumnn
874V6yMu6Rh3O+3rybxySSMT8oQpv3iiy/Y2WefLTyX5khw9W9d7qpU/S1ozoEL84xPlDJHBzqm37/UP/nLbks+fGID418AszZt2ljOIUJM4KKLLlJ/z5vPRmE7cxlJPeYTHVSdRV9/7v9dfU7Xx2FfXgJ8zRXWt29f1rBhw9DHKHlbnZyaQZhPjq1QUxAAARAAgYAEIMwHBIjkIAACIJAAAhDm0yPMc5/a6sQXmnjDZw3biutuuiX3J69OqjFfu23bNlauXDlzdKqPwTWYeeMS5vmX26xp06Zs/vz5wgrfd9997NlnnxWeQ6SYAIR5MRenWAjzTnRwDgSCE4AwH5whcgABEAABEEgIAQjzCTEUqgkCIAACAQhAmE+PME+z/ydMmMC4KzrGXVAE6BVMzYO+zNIH+uKVvoTNtwCuwSwelzDP13thgwYNElaWr03GVqxYoX4ZJLwAkUICEOaFWBwjIcw74sFJEAhMAMJ8YITIAARAAARAICkEIMwnxVKoJwiAAAj4JwBhPj3CvP9eYEzJF8BTBcy///7bcOKKK65g06ZNM8R5OTh48CDr0KEDW7x4Mbv77rsZXxPLS3Lf1/K1P1inTp1UYbZfv36qmx/fmQVIGBXXAFWKJWkcwvzatWtZ7dq1Gc2aF4X33nuP8fV2RKcii0tDf4cw7717QJj3zgwpQMALAQjzXmjhWhAAARAAgUQTgDCfaPOh8iAAAiDgigCEeQjz5o6yYMECdt5555mjGQlOzZs3t8S7jXj//fcZiftaGDBgAKNZzlEGWoOlWbNmqt99Kqd69eqM1vLIRYiKay7a4qXMOIT51q1bsylTpgirde6559q6txEmCCkyDf0dwrz3zgBh3jszpAABLwQgzHuhhWtBAARAAAQSTQDCfKLNh8qDAAiAgCsCEOYhzJs7CrkDMQvmtGDmpEmTzJd6Ol64cCEjkVRRlEy6/v37sz59+mSOw9wxi/KUd506dRj5jM9FiIprLtripcyohflZs2Y5vjB68803Wdu2bb1UOZRr09Df/Qjz9GXIn3/+6ZlhqVKlWJEiRTynky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK/vzOvWrVPF6z179mSiScymBTVJOAsann76adalSxdDNuRe5rHHHjPEBT0QifIVKlRgs2fPVtsXNH+v6aPm6rU+cV4ftTDfokUL9sEHHwibVKVKFfbTTz+xYsWKCc9HHZn0/u5HmG/fvj0bM2aMZ7Q33ngje/311z2nky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK91ZprJTm5f5syZo0Wxs846i5FLjqOOOioTF3RnxIgRrGvXroZs+vbtyx5//HFDnN8D2UT5uLj65RV1uiiF+S1btjBa2NW8HoLWpp49e2YWhD106BD79ttv2eeff84o3a5du9QZ2pUqVWJHH300a9SoEatRo4aWNLRtkvu7H2GeXArNmzeP0cuoDz/8kO3du9eWZcWKFVX3UuRiisYa80s724QSn4AwL7FxULV0EOA3VQQQAAEQAAEQyAsC//jHP+hbc4X/SMmL9qKRIAACIJCPBC644AJ1rC9dunTqmz969Gi1rXRvo39lypRJfZu9NPCJJ54w8OHuP5Tdu3d7ycL1tcOHDzeURfbgwrzr9HYX/vHHH0q9evUMefOZ8gp3X2OXJPL4OLlG3hgfBRx//PEGe2h/f4888oiP3IxJnnzySWHeWhmffPKJUlBQoIwcOVKpVq2a47WFChVSGjdurLz22msK/2LEWFDAo6T2d+KhsdRv3dpu8ODBwvR8rQqFf70SkKqcyfnLCEubqQ8igAAIhEOA/OEhgAAIgAAIgEBeEIAwnxdmRiNBAATynACE+TzvALz5fLaxQkKbXnirWbOmsnPnzkjhDBs2zFAmlc9d2vguUzZRPldcfQOMKGGUwnzdunUtfUjrx4cddpjy3HPPKXw2vO012rXmLb2ofP7550MlksT+HkSY51+uKLVq1TKwL1++vMJ9/ofKVbbMIMzLZhHUJ20EIMynzaJoDwiAAAiAgC0BCPO2aHACBEAABFJDAMJ8akzpqyHfffedwhdkNYhnmkh55JFHKkOGDFH4Yo6+8naTaOjQoZay+WKwbpIarpFNlM81VwOcHB9EJcyvX7/e0ne0vhvWduDAgaHSS1p/9yvMb9++XTnnnHMM9rnsssuUX375JVSeMmYGYV5Gq6BOaSIAYT5N1kRbQAAEQAAEHAlAmHfEg5MgAAIgkAoCEOZTYUbPjeC+tpWbb75ZKVq0qEE8Ewma9evXVzZu3Oi5DLcJSPw3l9urVy+3yRWZRHmZuLoGGPGFUQnzkydPtvQbcz+iY3JZ1bFjR4UvxKpMnTpV4QvFKtOmTVP4ugbKSSedlDWPbt26hUooSf3djzBPL0xOPfXUDFdyEUR/z/QFST4ECPP5YGW0MZcEIMznkj7KBgEQAAEQiJUAhPlYcaMwEAABEMgJAQjzOcGek0LJb/a4ceOUJk2aKCSWiURMu7gqVaooS5YsiazeIl/UfOHOrOXJIMrLzDUrwBguiEqY7927t2MfJkG+f//+Cs3edgovvfSSUrJkSce8SNQPMySlv3sV5mkth2OPPTbDklwCvfvuu2Gikz4vCPPSmwgVTDgBCPMJNyCqDwIgAAIg4J4AhHn3rHAlCIAACCSVAIT5pFrOe72POeaYjGBmJ8A7xZM4/9tvv3kv2GWKQYMGWerXo0cP29QyiPJUOdm52gKM6URUwnyLFi0s/UXrv8WLF1cWLFjguoVLly5VZ9Zr6c1bEpsPHDjgOj83Fyahv3sR5ufMmaOQ+yuNHY0Xy5Ytc4MiVddAmE+VOdEYCQkUojrxgQYBBEAABEAABFJP4OSTT2YrV65kNWrUYKtXr059e9FAEAABEMhHAhdeeCGbN28e4zMb2a5du1KNYMyYMax9+/aZNvIZtYwvcJo5TvsOX+CVrVmzhh08eJAVFBSwbdu2qce7d+923fSWLVuy//znP66v93oh9+nNuI95Q7JHH32Ucfcfhji+sCRr1qwZ48JfJr5ChQps9uzZrE6dOpm4OHaSwDUODnZlVK1alW3YsMFymrhxn+uWeLcRDRo0YPwrDuHlHTp0YHwmvPCcXeSMGTPY5Zdfbneavfzyy+yuu+6yPe/nhOz9/aKLLmJz5861NM1suwkTJrBbb72V8ZcX6rW1a9dmxJO/lLGkTXvEzJkz2aWXXmpoJo23pUqVMsThAARAwCcBCV8WoEogAAIgAAIgEAkBzJiPBCsyBQEQAAGpCGDGvFTmiL0yhw4dUmi28IMPPqgcfvjhmdmu/Oey7f6iRYsirSe5HzGX371790yZdjPlZZqdKyPXDMCYd6KaMV+tWjVLP9H6zaxZs3y1slGjRrZ58okqCn+p5Stfp0Qy93c3M+ZHjBhhcI1F9xT+0s+pyak+hxnzqTYvGicBAbiykcAIqAIIgAAIgEA8BCDMx8MZpYAACIBALglAmM8lfbnK/umnn9QF
YbP5n2/Xrl3kFe/Xr59FIOWzdG0XepVJlDfDkYmruW5xHEclzJctW9bSRzRhnl7e+AkTJ060zZPypoVjowiy9ncnYZ5ePj300EMGXldeeaWyf//+KBAlJk8I84kxFSqaUAJwZcPvRgggAAIgAAL5QQCubPLDzmglCIBAfhOAK5v8cWXjtqe/+OKLrFOnTraXcz/SjFzJFClSxPaaME707duXccHSMStyX/Pxxx+zunXrOl4nw0lZuMbNIipXNtT/uDhsac5hhx2WcaliOZklgtw88Zn4bOPGjcIrR44cyR544AHhuaCRMvZ3O1c2Xbp0YVu2bGFvvfVWptnkQopcC0U9LmQKlHQHrmwkNQyqlRoCEOZTY0o0BARAAARAIBsBCPPZCOE8CIAACCSfAIR5CPOiXjxs2DDG3ceITqlxX331FTvjjDNsz4d14vHHH2fc1YcwuySJ8loDZOGq1SeObVTCPF/gVSjAlytXTl0/wW/bBg8ezHr16iVM3rlzZ/b0008Lz4URKVt/txPmRW0tUaIEmz9/PqtXr57odN7EQZjPG1OjoTkiAGE+R+BRLAiAAAiAQPwEIMzHzxwlggAIgEDcBCDMQ5i363Nt27ZltKijKNBCurfddpvoVOhxd9xxBxs9erQlX1r4NQkz5c0Vl4WruV5RHUclzB933HFs06ZNlmrTFx07duywxLuNoAWEaWFhUfjnP//Jpk6dKjoVWpxM/d2LME8AqlevzuilHb0cydcAYT5fLY92x0UAwnxcpFEOCIAACIBAzglAmM+5CVABEAABEIicAIR5CPN2neybb75RZ79yN7SWS4YPH866detmiQ87YvPmzYzEwR9//NGSNfc5z4YOHWqJlz1CBq5xMopKmKeZ2cuXL7c0pWjRouyvv/6yxLuN+P/s3Q2QFMX5x/EHDuTtuAQQjFIo5IDIq6BRkSoUUAuREAmUJRGiAgYLhVIU5UUgoLyphFCppFJGCCpBIEUiEIxAwruAQF14E5R3BQOCHggHIgjMf3qT3f/O7e7d7uzs7vT0d6uo252d6en+PH1n8tvZnr1794p9n6W4uzdv3lx27doV9z0vNvptvqcazCuDrl27ypIlS6RixYpekGjXBsG8diWjw5oJEMxrVjC6iwACCCDgXoBg3r0dRyKAAAK6CBDME8yXNVd/+tOfxr1CeOTIkaKW/Mjk49ixY6FQXgWliR4jRoyQyZMnJ3rbt9tz6ZptlEwF8507d5ZVq1bFHc6FCxdErTXv5nHu3DnJz8+Pe2iNGjXk7Nmzcd9Ld6Mf53uiYP6Xv/yl2Dc5lcOHD8cd9pgxYxIuQRX3gABtJJgPUDEZii8FCOZ9WRY6hQACCCCQCQGC+Uyo0iYCCCDgLwGCeYL5smbknDlzpG/fvjG7jBo1SiZOnBiz3asNKqTs1KmT7NmzJ9Jk/fr1pUOHDo4bTqo3dQznc+Uawczik0wF8w8//LDMnTs37kjU/PnBD34Q971kNtauXVtOnToVs2vNmjXlzJkzMdvT3eDX+Z4omFffVvnZz34m6r8f8b6dUKFCBVm0aJF07949XRrtjieY165kdFg3AftrfDwQQAABBBAwQsD+Gq/67rpVWFhoxHgZJAIIIGCigB10hv7W24FT4Idvr1MeGqv6b5v6V1BQEPgxpzvATZs2OczCdvZV6uk2nfB4O6S07IsDHOe1Q3nLXs4mdIx9Na7jPdUnO5xP2J4f38iFa64cGjRoEFMvVTM73E2rS7/5zW/itqvaXrt2bVptt2zZMm7bjRo1SqvdeAf7eb7fddddcR3CtZs6dWrc91UN7LX+LfvbLvGGHOht9jcJYkxKSkoCPWYGh0A2BSSbJ+NcCCCAAAII5FKAYD6X+pwbAQQQyI4AwXx2nHU9y8mTJ2NCJhW6/eUvf8nIkL744gurWbNmjnNGh/Lhk8YL54cPHx5+2/c/s+2aS5BMBfOJPtxQ83PGjBlpDfnWW291zEHVpvqntnv58Pt8Ly+Yv3LlimVfFR/XSnmpDzjspX+8JPN9WwTzvi8RHdRcgKVs7L+uPBBAAAEEzBBgKRsz6swoEUDAbAGWsmEpm7J+A+wrPcX+ZkHMLupmrI0bN47Zns6G48ePh5av+fjjjyPNqOVrVq9eHfdcY8eOlZdffjmyr3qiyw1hs+nqAMrBi0wtZaOWULGvypbz58/HjErdmFjdoNjto2HDhvLZZ5/FHK7uDaCWaPHiocN8L2spm/CNl+0PmaRt27YJ15t/6KGHYpaf8sLPr22wlI1fK0O/AiOg+QcLdB8BBBBAAIGkBbhiPmkqdkQAAQS0FeCKeW1Ll5WOb9y4MeZq2Fq1alnqSlkvH3ZIaTVv3txxrnhXypc+5+jRox3H2MGD9fzzz5fezXevs+Xqh4Fn6op5NbZEV3T/+Mc/TmvoVapUiZlXam7ZYXRa7YYP1mW+J/INL2UTHs+GDRusSpUqxTVTbmrJG1MeXDFvSqUZZ64EWMomV/KcFwEEEEAg6wIE81kn54QIIIBA1gUI5rNOrtUJ33jjjZiw7amnnvJ0DCqkbNGiheM8yYTy4U7EC+ftK6bDb/vyZzZc/TLwTAbzM2fOdMwbFQKrf/bNR60TJ064IiguLo7bpmp38+bNrtqMPkin+Z5sMK/G9+qrryZ0y8vLs1auXBnNENjnBPOBLS0D84kAwbxPCkE3EEAAAQQyL0Awn3ljzoAAAgjkWoBgPtcV8Pf5hwwZEhO2bdu2zbNOq/C09I02Uwnlwx2JF84/99xz4bd99zPTrn4acCaD+XPnzoVuMhoO5KN/vvnmm64Ydu7cGTPnVbt16tSxLl265KrN8EG6zfdUgnn1LZpu3brFtVN+devWtY4cORKmCOxPgvnAlpaB+USAYN4nhaAbCCCAAAKZFyCYz7wxZ0AAAQRyLUAwn+sK+Pf86speew1vR9CmgjevHiqkbNWqlaN9N6F8uD/xwvlnn302/LZvfnrh+tVXX4VuwGuvsW+9/vrr1p49e3wzvtIdyWQwr86lvsERHciHn998882lu5LU6ylTpsRtT91wOJ2HjvM9lWBe2ah5majeqi633Xab9e2336bD6PtjCeZ9XyI6qLkAwbzmBaT7CCCAAALJCxDMJ2/FnggggICuAgTzulYu8/0eMGCAI6DMz8+37BtienLiL7/80tNQPtypeOH80KFDw2/74me6ru+9917o6u1wAB3+qdbo3759uy/GGN2J0h/uhPtbep3y6GNSea4+lEi0JvyKFStSaSp07wT7psaOea/6W61aNddL46gO6DrfUw3m1VjXr19f5nrzjz32mNotsA+C+cCWloH5RIBg3ieFoBsIIIAAApkXIJjPvDFnQAABBHItQDCf6wp4e/5Vq1ZZffr0sRo1amTde++91vjx40OhYKpn2bJli1WxYkVHQKnW8/bioa6qbd26taPtdK6UL92neOH8M888U3q3lF77xVVdba9uvhsOt0v/LCgosPbt25fS2DK5s7pKvHQfw6/TrUl0vydOnBj3POobGefPn4/etcznqs7h/kX/VHPK7UPH+R4e6+2
33x7Xo7wPVdRNcqP9Sj8fMWJE+BSB+0kwH7iSMiCfCRDM+6wgdAcBBBBAIHMCBPOZs6VlBBBAwC8CBPN+qUT6/Th79mzM0jMqEFNXuqtlOM6cOZPUSVavXm1de+21jmBNBW1ePFRIedNNNzna9jKUD/fRy3DeT65z5sxx2JUOPNVrdZWzXx4ffPBBwv727NnTs25+9913llq6Jp7H4MGDkzqPaqNLly4xbdxxxx2Wes/NQ8f5Hj3OZs2axXgo4379+kXvFvNcrTd///33xz02XKPJkyfHHBeEDQTzQagiY/CzAMG8n6tD3xBAAAEEPBUgmPeUk8YQQAABXwoQzPuyLK46NW/evDKDMBW2//nPfw4t1xHvBCpMmzRpkpWXlxdp56qrrrKmTp0ab/eUt5WUlMSEp5kI5cMdixfOv/jii+G3k/7pJ9dx48ZFahMOOOP93L17d9Ljy9SOly9fDn1rI17/1Da1FnmyHxYl00c15muuuSauj1rOSPUn0ePChQtWjx49Yo5VvzOHDh1KdFiZ23Wd7+FBffLJJzHfmgnXUi2bVJanaqO4uLjM9eZVW4888oj1zTffhE8ZiJ8E84EoI4PwsQDBvI+LQ9cQQAABBLwVIJj31pPWEEAAAT8KEMz7sSru+rR06dKYYDEcpEX/VP99VzcNXbhwobVjxw5LXdX82muvWW3atHEcr5ab2bZtm7vOxDlqwYIFjvYzGcqHT186nK9bt274raR/+slVfbASXctEz996662kx5eJHQ8ePGg9/vjj5fZVXVXt5RxTy/jccMMNcc/buXNna82aNY7hqg+jdu7caXXt2jXmGPV74jaUVyfRdb6fOnXKeuONN6x4a+1Hzzf1TYTyPliZNm1ajGt0G+p5kyZNrD/+8Y+BuSkswbzjV4wXCHguUEG1aP/x4IEAAggggEDgBW688Uaxb6glhYWFsn///sCPlwEigAACJgrceeedsm7dOqlZs6bYIUugCWbNmiX9+/ePjNFej1tOnz4dea37E/vKU7HDdDlw4EBaQ7EDObFDNxk0aJDYV8yn1Vb0wapfdtgp9pW2Yoen8q9//UvUuTL9sJfxkQkTJoROY38QJWvXrk3plH5ytW8iKvZV3CHDsgahxmt/O6CsXTx979///rfY36wI/Q2xbxAsu3btUhc1Jn0O++r50JywbxQr9ocnon5X3T6OHj0amruLFy+O20TTpk3FXqJF7OVp5MMPP5STJ0/G7NerVy95/fXXpU6dOjHvJbtBt/n+9NNPi/0hlNgfbiRduxo1aoj6/wsNGzYU+0OWyN/X4cOHy6effhr6HY/nG8+wevXq0qJFCxk1apTY316It4sW25YvXy72kkiOvtrfnhB7STHHNl4ggIA7AYJ5d24chQACCCCgoQDBvIZFo8sIIIBAigIE88EJ5lXpVRimAnX7qs2kwzV1nPpgplOnTvLEE0+IffWwVKhQQW32/LF161axr1CW++67T+rVq+d5+4kaXL9+fcjmJz/5iajwN9WHn1xXrFghvXv3Fnv98oTD+O1vfytDhgxJ+L7Xb2zcuFHat2/vSbNVq1YV+4atabelAlIV8hYVFSXdlr2evEyZMkXU30UvHjrNd/tGufLRRx+5Hraytm/CGzq+Xbt2smnTJldtzZgxQwYMGODqWD8cRDDvhyrQhyALEMwHubqMDQEEEEDAIUAw7+DgBQIIIBBIAYL5YAXz4UmqrtZVVx3v3btX1BXE6p+95nPoqk11RbL6p4LxW265JRRCtm3bVuy15cOH8zOBgF9c1RW49jJEoW82qm81qKuN+/btGwm07SVbPAuXE1Bos9lejkbeffddsZdsCv0eHDt2TL7++mupVatW6NsH1113nXTs2FEeeOABuf7667UZFx31pwDBvD/rQq+CI0AwH5xaMhIEEEAAgXIECObLAeJtBBBAIAACBPPBDOYDMDUZQgoC9lrp8v3vf19UYK8e9jrhodcpNMGuCCDggQDBvAeINIFAGQIE82Xg8BYCCCCAQLAECOaDVU9GgwACCMQTIJgnmI83L9iml8CSJUuke/fuoU67WUdfr9HSWwT8K0Aw79/a0LNgCBDMB6OOjAIBBBBAIAkBgvkkkNgFAQQQ0FyAYJ5gXvMpbHz3L1y4IC1btpT9+/eHLDZs2CBqrXQeCCCQfQGC+eybc0azBAjmzao3o0UAAQSMFiCYN7r8DB4BBAwRIJgnmDdkqgdymGoJm4cffljmz58fGl+vXr1kwYIFgRwrg0JABwGCeR2qRB91FiCY17l69B0BBBBAICUBgvmUuNgZAQQQ0FKAYJ5gXsuJS6dFhfIDBw6UmTNnhjQKCwtl3bp1oRuawoMAArkRIJjPjTtnNUeAYN6cWjNSBBBAwHgBgnnjpwAACCBggADBPMG8AdM8cEMsLi4OXSmvQkD1aNCgQSiUv+GGGwI3VgaEgE4CBPM6VYu+6ihAMK9j1egzAggggIArAYJ5V2wchAACCGglQDBPMK/VhKWzcvDgQenUqZMcPnw4pNG0aVNRN39t0qQJOgggkGMBgvkcF4DTB16AYD7wJWaACCCAAAJhAYL5sAQ/EUAAgeAKEMwTzAd3dgdzZLt37w7d7DUvL0+ef/55GTt2rFStWjWYg2VUCGgmQDCvWcHornYCBPPalYwOI4AAAgi4FSCYdyvHcQgggIA+AgTzBPP6zFZ6GhaYPn26dOzYUdq0aRPexE8EEPCBAMG8D4pAFwItQDAf6PIyOAQQQACBaAGC+WgNniOAAALBFCCYJ5gP5sxmVAgggED2BQjms2/OGc0SIJg3q96MFgEEEDBagGDe6PIzeAQQMESAYJ5g3pCpzjARQACBjAsQzGecmBMYLkAwb/gEYPgIIICASQIE8yZVm7EigICpAgTzBPOmzn3GjQACCHgtQDDvtSjtIeAUIJh3evAKAQQQQCDAAgTzAS4uQ0MAAQT+J0AwTzDPLwMCCCCAgDcCBPPeONIKAokECOYTybAdAQQQQCBwAgTzgSspA0IAAQRiBAjmCeZjJgUbEEAAAQRcCRDMu2LjIASSFiCYT5qKHRFAAAEEdBcgmNe9gvQfAQQQKF+AYJ5gvvxZwh4IIIAAAskIEMwno8Q+CLgXIJh3b8eRCCCAAAKaCRDMa1YwuosAAgi4ECCYJ5h3MW04BAEEEEAgjgDBfBwUNiHgoQDBvIeYNIUAAggg4G8Bgnl/14feIYAAAl4IEMwTzHsxj2gDAQQQQECEYJ5ZgEBmBQjmM+tL6wgggAACPhIgmPdRMegKAgggkCEBgnmC+QxNLZpFAAEEjBMgmDeu5Aw4ywIE81kG53QIIIAAArkTIJjPnT1nRgABBLIlQDBPMJ+tucZ5EEAAgaALEMwHvcKML9cCBPO5rgDnRwABBBDImgDBfNaoORECCCCQMwGCeYL5nE0+TowAAggETIBgPmAFZTi+EyCY911J6BACCCCAQKYECOYzJUu7CCCAgH8ECOYJ5v0zG+kJAgggoLcAwbze9aP3/hcgmPd/jeghAggggIBHAgTzHkHSDAIIIOBjAYJ5gnkfT0
+6hgACCGglQDCvVbnorIYCBPMaFo0uI4AAAgi4EyCYd+fGUQgggIBOAiYH85UrV5bevXu7KtesWbMkLy/P1bEchAACCCDgb4Hi4mIZOXKkfPvttyl1dPfu3VJUVOQ4pqSkRPLz8x3beIEAAu4ECObduXEUAggggICGAgTzGhaNLiOAAAIpCpgczKdI5dj94sWLooJ9HggggAACwRN45513pE+fPp4MjGDeE0YaQSAkQDDPREAAAQQQMEaAYN6YUjNQBBAwWIBg3l3xCebduXEUAgggoIMAwbwOVaKPJgoQzJtYdcaMAAIIGCpAMG9o4Rk2AggYJUAw767cBPPu3DgKAQQQ0EGAYF6HKtFHEwUI5k2sOmNGAAEEDBUgmDe08AwbAQSMEjApmN+xY4eMHz9eLl++nHaN//rXv7LGfNqKNIAAAgj4U2Dt2rUybdo0Tzo3b948qVq1qidt0QgCpgsQzJs+Axg/AgggYJAAwbxBxWaoCCBgrIBJwbyxRWbgCCCAAAIIIIBAAAQI5gNQRIaAAAIIIJCcAMF8ck7shQACCOgsQDCvc/XoOwIIIIAAAgggYI4Awbw5tWakCCCAgPECBPPGTwEAEEDAAAGCeQOKzBARQAABBBBAAIEACBDMB6CIDAEBBBBAIDkBgvnknNgLAQQQ0FmAYF7n6tF3BBBAAAEEEEDAHAGCeXNqzUgRQAAB4wUI5o2fAoRIdIkAABWuSURBVAAggIABAgTzBhSZISKAAAIIIIAAAgEQIJgPQBEZAgIIIIBAcgIE88k5sRcCCCCgswDBvM7Vo+8IIIAAAggggIA5AgTz5tSakSKAAALGCxDMGz8FAEAAAQMECOYNKDJDRAABBBBAAAEEAiBAMB+AIjIEBBBAAIHkBAjmk3NiLwQQQEBnAYJ5natH3xFAAAEEEEAAAXMECObNqTUjRQABBIwXIJg3fgoAgAACBggQzBtQZIaIAAIIIIAAAggEQIBgPgBFZAgIIIAAAskJEMwn58ReCCCAgM4CBPM6V4++I4AAAggggAAC5ggQzJtTa0aKAAIIGC9AMG/8FAAAAQQMECCYN6DIDBEBBBBAAAEEEAiAAMF8AIrIEBBAAAEEkhMgmE/Oib0QQAABnQUI5nWuHn1HAAEEEEAAAQTMESCYN6fWjBQBBBAwXoBg3vgpAAACCBggQDBvQJEZIgIIIIAAAgggEAABgvkAFJEhIIAAAggkJ0Awn5wTeyGAAAI6CxDM61w9+o4AAggggAACCJgjQDBvTq0ZKQIIIGC8AMG88VMAAAQQMECAYN6AIjNEBBBAAAEEEEAgAAIE8wEoIkNAAAEEEEhOgGA+OSf2QgABBHQWIJjXuXr0HQEEEEAAAQQQMEeAYN6cWjNSBBBAwHgBgnnjpwAACCBggADBvAFFZogIIIAAAggggEAABAjmA1BEhoAAAgggkJwAwXxyTuyFAAII6CxAMK9z9eg7AggggAACCCBgjgDBvDm1ZqQIIICA8QIE88ZPAQAQQMAAAYJ5A4rMEBFAAAEEEEAAgQAIEMwHoIgMAQEEEEAgOQGC+eSc2AsBBBDQWYBgXufq0XcEEEAAAQQQQMAcAYJ5c2rNSBFAAAHjBQjmjZ8CACCAgAECBPMGFJkhIoAAAggggAACARAgmA9AERkCAggggEByAgTzyTmxFwIIIKCzAMG8ztWj7wgggAACCCCAgDkCBPPm1JqRIoAAAsYLEMwbPwUAQAABAwQI5g0oMkNEAAEEEEAAAQQCIEAwH4AiMgQEEEAAgeQECOaTc2IvBBBAQGcBgnmdq0ffEUAAAQQQQAABcwQI5s2pNSNFAAEEjBfo2bOnHD9+XOrWrSsLFy403gMABBBAIIgCjz32mOzbt0+qVq0qK1asCOIQI2M6cuSIbN26NfI6nSfdu3eXChUqpNMExyKAAAII+FTgP//5jxQVFXnSu27dukleXp4nbdEIAqYLEMybPgMYPwIIIGCQAFfMG1RshooAAsYKmHTF/KxZs6R///6e1PrixYtSuXJlT9qiEQQQQAABfwm888470qdPH086VVJSIvn5+Z60RSMImC5AMG/6DGD8CCCAgEECBPMGFZuhIoCAsQIE8+5KTzDvzo2jEEAAAR0ECOZ1qBJ9NFGAYN7EqjNmBBBAwFABgnlDC8+wEUDAKAGCeXflJph358ZRCCCAgA4CBPM6VIk+mihAMG9i1RkzAgggYKgAwbyhhWfYCCBglADBvLtyE8y7c+MoBBBAQAcBgnkdqkQfTRQgmDex6owZAQQQMFSAYN7QwjNsBBAwSsDkYL6goEBOnjzpqt7cyM8VGwchgAACWghYliVXrlxJua/Lli0TdbPX6AdrzEdr8ByB9AQI5tPz42gEEEAAAY0ECOY1KhZdRQABBFwKmB7Mnz592qUchyGAAAIIIOAUWL58uXTp0sWxkWDewcELBNISIJhPi4+DEUAAAQR0EiCY16la9BUBBBBwJ0AwTzDvbuZwFAIIIIBAaQGC+dIivEbAWwGCeW89aQ0BBBBAwMcCBPM+Lg5dQwABBDwSIJgnmPdoKtEMAgggYLwAwbzxUwCADAsQzGcYmOYRQAABBPwjQDDvn1rQEwQQQCBTAgTzBPOZmlu0iwACCJgmQDBvWsUZb7YFCOazLc75EEAAAQRyJkAwnzN6TowAAghkTYBgnmA+a5ONEyGAAAIBFyCYD3iBGV7OBQjmc14COoAAAgggkC0BgvlsSXMeBBBAIHcCBPME87mbfZwZAQQQCJYAwXyw6slo/CdAMO+/mtAjBBBAAIEMCRDMZwiWZhFAAAEfCRDME8z7aDrSFQQQQEBrAYJ5rctH5zUQIJjXoEh0EQEEEEDAGwGCeW8caQUBBBDwswDBPMG8n+cnfUMAAQR0EiCY16la9FVHAYJ5HatGnxFAAAEEXAkQzLti4yAEEEBAKwGCeTOD+aNHj8pnn30mJ06ckHPnzkm9evXk+uuvlwYNGki1atW0msN+6iyuzmr4xePw4cOi+qL+fffdd9KwYUNp0qSJ1K5d29lhXiUU2L59u/z973+XpUuXykMPPSRDhgxJuK/JbxDMm1x9xp4NAYL5bChzDgQQQAABXwgQzPuiDHQCAQQQyKgAwbw5wfyhQ4fkd7/7nbz//vvy8ccfJ5xXHTp0kP79+8uDDz4oNWrUSLgfb/xXAFfnTPCLx4YNG2Tu3LmyaNEiOXLkiLOT/3tVq1YtadeunTz77LNyzz33xN3H1I0lJSWycuXKUBD/3nvvOQxfeOEFeeWVV0ylKXPcBPNl8vAmAmkLEMynTUgDCCCAAAK6CBDM61Ip+okAAgi4FyCYD34wf+rUKRk2bJi8/fbbcunSpaQnS506dWTmzJnywAMPJH2MSTvi6qy2Xzz27NkjKjhevHixs4PlvGrbtq2MHj1aevbsWc6ewXz7ypUrUlRUJP/85z9l2bJlsnHjxtC3C+KNlmA+nsp/txHMJ7bhHQS8ECCY90KRNhBAAAEEtBAgmNeiTHQSAQQQSEuAYD7YwbwK13r37i1qKQ+3jxEjRsjkyZPdHh7I43B1l
tUvHrNnz5YBAwYkDJSdvY7/6umnn5Zf//rXkpeXF3+HgG69/fbbZfPmzUmNjmA+MRPBfGIb3kHACwGCeS8UaQMBBBBAQAsBgnktykQnEUAAgbQECOaDG8yrq17V1e4XLlxIa46og2fMmBEKPNNuKAAN4Oosol88JkyYIGPGjHF2zuWrbt26ycKFC6VSpUouW9DvsEGDBsnOnTulSpUq8umnn8rBgwcTDoJgPiGNEMwntuEdBLwQIJj3QpE2EEAAAQS0ECCY16JMdBIBBBBIS4BgPpjBvFqSQq0Vf/78ecf8UNtat24tjRs3looVK4bWmt+xY0do2QrLshz7Rr9QYd3evXtDN4iN3m7ac1ydFfeLxx/+8Ad58sknnZ1L89XIkSNl0qRJabai5+Hqb4Fa1ifR+AnmE9eVYD6xDe8g4IUAwbwXirSBAAIIIKCFAMG8FmWikwgggEBaAgTzwQvmVRh/8803yyeffBKZG/fff7+MGzdObr311si26Cfr1q2TgQMHOo6Jfl89NzmoVOPHVSn8/8MvHuvXr5eOHTvG3D9BbVO/B02aNJH8/PzQB0vqinD1QVRZV4OHR6g+uFLrrXfu3Dm8yaifZ8+elXr16sV8uKcQCOYTTwWC+cQ2vIOAFwIE814o0gYCCCCAgBYCBPNalIlOIoAAAmkJEMwHL5ifMmVKKERXE0Otk61u4Proo4+WO0/UkjdqPXq1hEe8x7XXXitHjx6N95YR23B1ltkPHuqGpSp83759e6Rz6n+/Tp8+Xbp06RLZVvqJmuNPPfVUufNZrbv+4Ycflj7cmNfNmjWL+2EdwXziKUAwn9iGdxDwQoBg3gtF2kAAAQQQ0EKAYF6LMtFJBBBAIC0BgvlgBfMlJSXSqFEjKS4ulsqVK8ucOXPkwQcfTHqOfPnll9KiRQtRP+M9Tp8+LQUFBfHeCvQ2XJ3l9YvHm2++Kf369Yt0rmvXrrJo0aLQ3I9sTPBEzeVhw4aF7p+QYJfQZnWFfatWrcraJbDvtW/fPrTMVekBEsyXFvn/1wTz/2/BMwQyIUAwnwlV2kQAAQQQ8KUAwbwvy0KnEEAAAU8FCOaDFcz//ve/l8GDB4fmyFtvvSWPPPJIyvNl9uzZCY/btm2b3HTTTSm3qfsBuDor6BcPdUX75s2bQ50rLCyULVu2SK1atZydLeeV+uBqwYIFCfcaOnSoTJs2LeH7QX5DLQe0Zs2amCESzMeQRDYQzEcoeIJARgQI5jPCSqMIIIAAAn4UIJj3Y1XoEwIIIOCtAMF8sIL5Nm3ahJb16NGjh7z77ruuJsuJEyfkmmuuiXusWtZDhaGmPXB1VtwPHrt27ZKWLVuGOla9evXQkjNurmw/fvy4qCVbTp065Rzk/17dc889obXm474Z8I0E86kXmGA+dTOOQCAVAYL5VLTYFwEEEEBAawGCea3LR+cRQACBpAQI5oMTzKslN9TV7HXq1BEVWiYK15OZGGo9+S+++CJm15MnT6Z8RXJMI5ptwNVZML94vPTSS/KrX/0q1LmxY8fK+PHjnR1N4VXpJXGiD23cuLHs27cvepMxzwnmUy81wXzqZhyBQCoCBPOpaLEvAggggIDWAgTzWpePziOAAAJJCRDMByeYV2tmz58/X370ox/JXXfdlVT9E+2k2ti7d6/j7fr168vnn3/u2GbCC1ydVfaLR4cOHeSDDz6Qq6++Wg4ePCg1a9Z0djSFVxcvXpQaNWrIpUuXYo666qqrRN0Y2cQHwXzqVSeYT92MIxBIRYBgPhUt9kUAAQQQ0FqAYF7r8tF5BBBAICkBgvngBPNJFTyJnc6fPx8KOS9fvuzYu1u3brJkyRLHtlReqNBz4MCBUlRUJE888YQ8+eSTqRzuet+vv/5aBg0aJLt37w5dVa2W+cnFI1OuuRiLF+dMx+PMmTOhb4aoOaWumh83blzaXVLL4qhvmpR+qGVyzp07V3pzua+DMN8J5sstc8wOBPMxJGxAwFMBgnlPOWkMAQQQQMDPAgTzfq4OfUMAAQS8ESCYJ5gvPZM2btwo7du3L71ZVOB07733xmxPdsM//vEPUeF++PHyyy/L6NGjwy8z8rO4uFjuvvvu0Lr76gQ//OEP5cCBAxk5V3mNZsq1vPP69f10PNTSMmouquWWPvvss7SWbQr7/PznP5d58+aFX0Z+Nm/ePG5gH9khwZMgzHeC+QTFLWMzwXwZOLyFgAcCBPMeINIEAggggIAeAgTzetSJXiKAAALpCBDME8yXnj8TJ06MCcx79eolCxYsKL1rSq83bdokd9xxh1iWFTlOrRM+ZsyYyGsvn5QO5VXb6uagao30XDwy5ZqLsXhxTi881FX31apV86I7MmrUKJk8eXJMW26/KRKE+e4mmFc1UUsDpfrIz8+XvLy8VA/z3f4E874rCR0KmoD9PyJ4IIAAAgggYISAvb6s+n/OVmFhoRHjZZAIIICAiQL2Os2hv/X2+syBH/6f/vSn0FjVf9vUv4KCgsCPOdUB2mt1W/Za2w4nO8y2SkpKUm0q7v7Tp093tK3qYN+0M+6+6Wz86quvLPtGuI5z2WuRW3Yon06zro/NtKvrjuXoQD96PP744475Ev47MXjwYNdKus93+14VcU1eeOGFhCb9+vWLe0zYM9HPPn36JGxTpzeWLVsWM36v/n7q5EBfEciUAFfM239FeSCAAAIImCHAFfNm1JlRIoCA2QJcMc8V8+HfAPv/RIeWfVm1alV4k9x2222iluSoU6dOZFu6T6ZNmybPPfecoxm1RrhaK9yLR7wr5dUNQleuXBm6Yt6Lc6TSRrZcU+lTLvf1q0fnzp0leu6HjWbPni19+/YNv0z5p87z3c0V8wsXLpR169bJoUOHxA6p5ZtvvkloVrdu3dDyUmqJKfW35plnnkm4ry5vcMW8LpWin9oKZCrxp10EEEAAAQT8JsAV836rCP1BAAEEvBfginnvTXVtccqUKY4rPXv37m2dPXs2I8N57bXXHOeyAwLLDubTPpffrpRXA8qma9qAWWjArx4NGjSImZPf+973LDtYTltF1/nu5or5aKxJkybFmKrfdfv+AJb9QVn0roF5zhXzgSklA/GpgFoPjwcCCCCAAAJGCBDMG1FmBokAAoYLEMwbPgHs4V++fNlSS1OowCz8r0mTJtbp06czivPqq69Gzhc+79ixY12f02+hfK5cXQNm+EA/exw/ftyqWLFizHwcNGiQZyo6zvd0gnn7mytW06ZNHaa1a9e25s6d65mpHxsimPdjVehTkAQI5oNUTcaCAAIIIFCmAMF8mTy8iQACCARCgGA+EGV0PYhdu3ZZ9g1ZHeFZOCRXVwvbN8O07Js5um6/vANfeeWVmHPbN4Mt77CY9/0WyufaNQYoxxv87jF16tSYeah+D4qKijyV022+uw3mT506ZbVr185het9991nHjh3z1NOPjRHM+7Eq9ClIAgTzQaomY0EAAQQQKFOAYL5MHt5EAAEEAiFAMB+IMqY8iI8++sj6xS9+YVWqVMkRnoVD+eifbdu2tT7//POUz5HsASr8jz6fev7iiy8m
e7jlp1DeT65JA2ZwR108mjVrFjMH1XIrmXjoNN/dBPOHDx+2WrRoEfGsUKFC6PdZfWPChAfBvAlVZoy5FCCYz6U+50YAAQQQyKoAwXxWuTkZAgggkBMBgvmcsOfkpOfOnbPefvttq1OnTpYKy0qH4WW9vu6666ytW7dmrN/x1qIeNWpUuefzQyjvZ9dyATOwg24eGzZsiPldqF69unXgwIEM6Py3SV3me6rB/I4dO6z69etHPGvWrGn97W9/y5ijHxsmmPdjVehTkAQqqMHY/4OFBwIIIIAAAoEXuPHGG2XPnj1SWFgo+/fvD/x4GSACCCBgosCdd94p69atEztAkTNnzgSaYNasWdK/f//IGAsKCsReRz3yOuhP7HBd7KUkXA9THb9t2zapW7eu6zbKOtAOK8W+Ut6xy8iRI0Vtj/ew17CWu+++W7Zv3x55++qrrxb7ppLSqlWryLZMP/G7a6bHX7p93TzU3wT1tyH6Yd+sVYYNGxa9yfPnOsz3jh07ypo1a2LGbt+TQuxleRzbV69eLT169Ij8TVXz4P3335fWrVs79gv6i+XLl0uXLl0cwywpKZH8/HzHNl4ggIA7AYJ5d24chQACCCCgoQDBvIZFo8sIIIBAigIE8+YE8ypMs68ClkuXLokKik6ePBl6ffbs2aRnTffu3WXx4sVJ75/qjhMmTBB7jXnHYSNGjBB7+Q/HNr+E8qpTOrg68DL8QicP9aHOLbfcIvYyKxEV9XrTpk2Sl5cX2ZapJ36f78kG8/Pnz5dHH31ULly4EKJq3ry5LF26VBo0aJApOt+2SzDv29LQsaAIBOnyf8aCAAIIIIBAWQIsZVOWDu8hgAACwRBgKZtg1NHtKK5cuWLZV8FbQ4cOtapUqRJZgsL+/+8Jn2/ZssXt6ZI67qWXXoo59/DhwyPHJlq+xg5ZI/vk+okfXXNp4kcP1af27ds75lqNGjWsbM8jP8/3ZJaymTZtmmNpLPXfFPtDv1xOt5yem6VscsrPyQ0Q+D8AAAD//14ccRYAAEAASURBVOydC9wV0/rHV70l3UkXSqIIRUmnyOlIEopyCZVEioh0ougiiW66SA6d41Ai+qsUoRSKLipJuhGF0kW5dFFvpajWfz1zzJjZM7P37LnsWTPzW5/P+5nZM+vyrO+z3tmzf7PmWYwjgQAIgAAIgEBCCJx55pmcMcarVauWkB6jmyAAAiCQPAL/+Mc/lGt9yZIlY9/5F198UekrfbfRX6lSpWLf52w6+P333/P27dvzAgUKGDipvNRt27Zts6nWVd7HHnvMZMNDDz3Ed+zYwWvXrm04V7ZsWb5q1SpX7eSikExcc9HfTG3IwmPs2LGGcVSwYEH+1ltvZTI/kPOyjvdGjRoZGKnXAPpfPHr0KH/ggQcM56+55hp+8ODBQBhFpdL33nvPwISY5efnR8V82AkC0hMoQBaKfywkEAABEAABEIg9gbPOOoutW7eOCWGeffvtt7HvLzoIAiAAAkkkcPHFF7OFCxcyIcyzvXv3xhrB+PHjWceOHbU+CmGe7dmzR/uMnf8ReO6551iXLl1scZQuXZrt3LmT5eXl2ebx48SAAQOYECzTViVEeTZ37lxWq1attPlkOCkLVxlYkA1h8ti1axcTE1CYeNCj4Rg9ejT75z//qX3O9Y6M4/2SSy5h8+fPN6Ho3r07+/HHH9mkSZO0c+eeey5bsWJF4NcFrUFJd95//312xRVXGKwTwjwrUaKE4Rg+gAAIuCMAYd4dN5QCARAAARCIIAEI8xF0GkwGARAAgSwJQJiHMG81ZIYPH8569epldUo59tlnn7G6devanvfrxKOPPsoef/xxy+qiJMqrHZCFq2pP2NuweLRu3ZpNmTJF637Xrl3ZM888o30Oa0e28W4nzFvxKVasGFu8eDETb7RYnU7MMQjziXE1OhoSAQjzIYFHsyAAAiAAArknAGE+98zRIgiAAAjkmgCEeQjzdmOuTZs2bPLkyZan6e2DDh06WJ7z+2CnTp2YCENkqlaEr4nETPlUw2XhmmpXWJ9zzWPgwIGsf//+WnevvvpqNn36dGlmess03rMR5glo1apVGT20O/744zW+SduBMJ80j6O/uSYAYT7XxNEeCIAACIBAaAQgzIeGHg2DAAiAQM4IQJiHMG832NasWaPMfrWK5jpixAjWs2dPu6K+Hd+2bRsjcfCbb74x1SniXLNhw4aZjst+QAauMjHKJY9p06axG2+8kaljWqyxwWbPns1otrcMSbbxnq0wTwybNWvGZsyYwUTMfhmQ5twGCPM5R44GE0YAwnzCHI7uggAIgECSCUCYT7L30XcQAIGkEIAwD2E+3Vhv2bIle+edd0xZ+vTpw4YMGWI67ueB7du3K6L8+vXrbavt3bs3Gzp0qO15WU+EyVVGJrngQfHPGzZsyA4cOKAgqFevHpszZw6jtSZkSDKOdzth/s4772RikVO2efNmS3SPPPKIbQgqywIxOghhPkbORFekJABhXkq3wCgQAAEQAIEgCECYD4Iq6gQBEAABuQhAmIcwn25ETpw4kd1yyy2mLH379mWDBw82HffrAImUjRs3VhahV+usVKkSoxnO+gUn6VwUxfmwuKosZdsGzYMWKq1fvz7bsmWL0nVaqHTevHmsTJkyUqCQdbzbCfP0tsp1113H6Pvjjz/+MDEsUKAAe+utt1iLFi1M5+J+AMJ83D2M/oVOQLzyhAQCIAACIAACiSBw5plncvHFy6tVq5aI/qKTIAACIJBEAkLoVK71JUuWjH33RZxypa/03UZ/YqZs7PvstYNLly41MFPZiVnqXqu2LS9ESi4mBxjaFaI8F+FslDJiNq7hHNkkxHnb+mQ8EQZXGTmoNgXJY9euXbxWrVramKH7WyHUq02HvpV5vDdq1Ejjpv7v01YI8wq3kSNHWp6nPKVLl+bibZfQ+ebaAPEmgYlJfn5+rs1AeyAQWwKYMS+usEggAAIgAALJIIAZ88nwM3oJAiCQbAKYMY8Z8+n+A3bv3m05q3jKlClKrO50Zd2c++mnn5SZ8l999ZVWnGbK0+zm008/XTtGi3fSIp761KtXL/bEE0/oD0m7n2uu0oL407CgeOzbt481bdqUffLJJ0pLp556Klu4cCE7+eSTpUAi+3hPN2Oe1ncQyh+75pprLMNdEeBzzjlHYV+8eHEpeOfCCMyYzwVltJFkAhDmk+x99B0EQAAEEkYAwnzCHI7uggAIJJIAhHkI8+kGvpjpaRmDmxZj1Qvl6epwes6pSKnWZyXOR2VB2FxyVXnJvA2Cx6FDh1jz5s3Zhx9+qHS9YsWKiihftWpVKVBEYbxnEuYJpHgjgdWpU8c23nzr1q1N4aekcEBARkCYDwgsqgWBPwlAmMdQAAEQAAEQSAwBCPOJcTU6CgIgkGACEOYhzKcb/jTTuEGDBoYsxx9/PNu5cyejONJ+pZ9//lmZKb927VqtSquZ8trJP3dokclBgwYZDj/44INs+PDhhmOyfcgVV9n6bWeP3zwOHz7MWrVqxd5++22lyXLlyrEFCxYwurd1myiW+tdff80
oPr3XFJXx7kSYJxZLlixR4s0Td6skQt6wHj16WJ2K3TEI87FzKTokG4HYBulBx0AABEAABEAghQBizKcAwUcQAAEQiCEBxJiPoVN97NILL7xgipd87733+tgC52LmMK9Zs6ahHX1M+UyN9evXz1BWaAi8Z8+emYqFej4XXEPtYJaN+8njyJEj/Oabb9bGhHiQxFeuXJmlRebstI4Bja0xY8aYT2ZxJErjPVOMeX23xcMwjTlx0v/l5eVx8eaCPnts9xFjPrauRcckIYAZ8+LqigQCIBBNAtu2bWObNm1iNENj//79rHz58uyUU05hlStXZkWLFo1mp2B1oAQwYz5QvDmt/Pfff2czZsxgNWrU8DRbzG+jxY9ntnr1akaxhL///nvDX/v27RnNhETKTEBW/5Ll8HFm/4WdAzPmMWM+3Rjs1q0be+aZZwxZhMjJateubTjm9sMvv/zCLr30UvbFF19oVTiZKa9l/nPHauY8zdClmboypqC5ytjndDb5yaNLly7sueeeU5orUaIEmzNnDrvgggvSNZ/xHM22b9y4sRJTXSxo6jqMU9TGu9MZ8wRQaHasRYsWbObMmZY86a2Fzz//XJr4/pZG+nAQM+Z9gIgqQCANAQjzaeDgFAiAgHwENm7cyJ599lk2a9YsRfiys1DMlmMdO3ZUFvFK0uI8djxw/H8EIMxHfyTQa9cvvvgiGzJkiBL7U4bYuytWrGD0o4V+5H788cds7969JtAFCxZkZOvQoUNN53DgLwIy+pesg4//8lEU9iDMQ5i3G6c0maN69epsz56/GF111VXKg167MtkcJ5GySZMmbM2aNVoxN6K8WthKnH/ggQfYk08+qWaRYusHVwolRLHT161bp0y2IQGVfBXF5AcPtd+0ALAaxogmHtFvIDHrWz3takvha+jh0fbt25XxSkK/mxTF8Z6NME9MaFxSvPktW7ZYIqpfv75y/1ekSBHL83E4CGHemReXL1/OPv30U2eZRa7jjjuOtW3b1nH+dBlpjRS3/8dqvWQL2YQUAgFJZu7DDBAAARBIS0AswsOF0M4LFSpkeI1QXDbTfj7hhBP49OnT09aNk8khgFA20fW1EGw5vRZ+6qmnGv7nhdgdSqcOHjzIX375ZV6vXj2DPfprkph9yfv06cPfffdd/uuvv4ZiZ1Qalc2/xA0+jsroMduJUDZmJjjyPwKdOnUyXLPF7GMu3r70BY8QKbmI1W2oP5vwNXZGWIW1uf/+++2yh3LcK1cxI5nTPbv+O5T2xVtxfNWqVaH0yUujXnmobQ8ePFhjcswxx3AhyqunXG3pu3by5Mm8QoUKWr1TpkxxVVdUx3s2oWxUMIsWLUr7G7RDhw5q1lhuEcrGmVv1/6+p1zK7z17DSKmWTZw4Ufuftmsr03HxUFStDtscE6DXc5BAAARAQGoCixcv5iJEjacvG4qhiAQCEOajNwboR6SYIc+rVq1qeQ3ItTBP9ogwArxs2bKW9og3dPgdd9zBxYyZ6MEOwWLZ/EsI4OMQBoLPTUKY9xloyNV99NFHvF27dvy0007jTZs25Y899hgnUTDbtGzZMi7eXjJcu8eNG5dtNZb5d+zYwWvVqmWo2w9RXm3MSpzv3r27etrVVhauFJ+cYqbbiUalSpXiYjaoqz5mU0gWHqrNItySiYkIPcPd/JEYTQ+NUjmLMKBchI5Tm3S8jeJ4VzsnQgCZuNLYy3Q/OWzYMMty6riN829NCPPq6Em/pf8LGkc33HADL1asWNrxoo6bwoULc/G2bfqKHZwVb2lx+k647rrruAjr66htsqFixYq8ZcuWnB72irdDHLSELEEQgDAfBFXUCQIg4BuB2bNnc/FqoOMvF/VLzmo7duxY3+xCRdEkAGE+On47fPiwMiP99NNPT/v/n+mHlJ89phlTqcKLeq2hG2sR+xcz4x0Cl9G/ZDp87NCBkmeDMC+5g7Iwb9++fbx06dKm7wGa6S7CvHAROsxRbfPmzeMnnXSSoR4S2vxIJMbQG1Lq9wFt/RTlVRv9FOdl4upkpicJy0EmmXhQP+mNvAIFChjGlH58+bUvwuRkjTWK413fybPPPtuS6+23367PZto/evQob968uWVZ1R8iXKGpXBwOQJjP3ov03XTrrbemHS/quKHvph9++CH7RmxK0Fgln51zzjm27dP/AUUVoLxI4ROAMB++D2ABCICADYHPPvuMi1iKpi8U+sF977338qeeeoo//fTT/O677+YXXXRRxhtYEvj9el3ZxmQclpwAhHnJHSTMEwtr8ldffZWLuLKm/331Bla/zYUw/9tvv/G77rrL9hpz+eWXc7HYq/xwJbBQRv8SFvhYgsHhowkQ5n2EGXJVkyZNSvtdQIIGfWfYiQt0XKxJwvPy8rR6KCQIvfnkR8rPz+fnn3++Vjd9PwUhyqu2WonzDz/8sHra8VYmrgMGDDDw03/H6/fXrl3ruH/ZZpSJxxtvvGEYr3oGfu6T8P/tt99mhSqq413tpIitb3prRmVKYZPoHiVdohnFmWYjkxh74MCBdNVE7hyEefcuGzhwoKPrW4MGDfihQ4fcN2RRksT+1LdkaLyfeOKJXKwvYVECh8IiAGE+LPJoFwRAIC0BuqERC3UavsholkK68BBi4UVTGfVmS91SvGek5BKAMC+37998803D/zAJKccee6zhOqD+L6vboIX5H3/8kV944YWWNlBIBL/EHbk94491MvqXegYf++NfmWqBMC+TN7zZQm9Oqtf7dFv6ficBhGYArl69WgkNMGLECH7eeecZytNbTytXrvRmlK701KlTDfUHKcqrzaaK8+XKlVNPOd7KxJUerKTzrXqOZpEHlWThIRbZ5PTgSO1zkFsKC5Vtiup43717t7JOUaa3MLt27ZrxLZxRo0Zl9M8ZZ5zBn3/+eWWtmmwZy5gfwrw3r1CIGSf/yzQJyO903333mdru37+/382gPo8EIMx7BIjiIAACwRCgVwHVLzAS51566SVHDdFifddee61WVq1D3dLMKqTkEoAwL7fv69evzykkDMVmpNlrtGAqzXakuIfq/3DqNkhhnsSdKlWqWLZNoRW8LsImtzf8t042/1IP4WP//SxDjRDmZfCCPzbs37+fV6tWzfI6nPp9kO4zCXKjR4/2fUYizThWZ+PT90UuYqETWb04T+M92yQT159//lljmM6HgwYNyrabjvPLwIMWXnQalzodJ6fnXn/9dcd81IxRG+/dunVT3sDMJiwQrRVUt25d3qpVK65fg4LuN2+66SZepkwZx9cj8me9evU4TUyIcoIw78179FvmuOOOczRu/A69Sw/7Uq8JdO+LJBcBCPNy+QPWgAAICAIUk+2EE05QvkRIpJsyZUpWXOgGn2YPpX4JqZ/37NmTVX3IHB8CEObl9uUHH3ygzF5OtZJeHbb7URWUMD937lxesmRJy+sI/Sj74osvUs3E5wwEZPIvmQofZ3BYhE9DmI+w8yxM37hxI7/yyittvwfU+7vULV3DaVG7mTNn2oa6sWgu60Off/65EhOcFjHNZaIFA2m2OT3EdpNk4jpnzh
zbRdVVv/7rX/9y003HZcLmYSWgqX33e1uhQgVXi74SzCiN93Qxtp0w7du3rzZ+7BaNdVKP32KrZlSOdiDMewetahuZxguF3l26dKn3Bv+sYevWrYbfMqStuFnw2TeDUJElgQJ0VAwOJBAAARCQhsCYMWOYeJVQsUe8tspErL6sbXvllVdsy4lXmJlYpCvrOlEg+gREeCQmZiQxMfuOiVk/0e9QgnogbmjZrl27TD0WwjwTC/iZjns5IOLYMrFuBRMP8UzViNA6TAjMrGHDhqZzOOCeQC79S1bCx+59FYWSF198MVu4cCETwiwTD/ujYLJrG8ePH886duyolS9VqpTltUvLEOGd7777jlF/169fz7Zt26b8iQe3TCwEy8SEDOWvfPnyTMx2ZTQG6tSpw8Rs9gj3ODemy8JVxC9nYiancp8mwrkwMduY3XLLLUysAaKAmD9/vuLXoKnIwiPofqJ+EHBKQDw0YldccYUhO/2/0rUXyRmBsmXLMvq+cpJOPvlktnz5ckbfZ34k+u0i4tcrVYkwS8p3qB/1og7/CECY948lagIBEPCJgIgHylatWsVESBomXv1zVauYNc/EbBDLsp988gkTsx4sz+FgvAlAmI+uf0855RS2ZcsWUwf8FuZ/+eUX5fogZs6Z2hKz9pl49ZuJ15tN53DAG4Fc+ZeshI+9+SoKpSHMmx8qRsFvsBEE9ARE+Acmwj8wEgApiTjhymd9HuyDAAgETwDCvHfG2Qjz1Brdx4g3O1mhQoU8N65v+29/+xtbtmyZ5zpRgb8EIMz7yxO1gQAIeCRAM2VoNjvNnvzyyy9txXUnzYh48kws6mfKSrNuxQrlpuM4EH8CEOaj6+NcCLc0m+TSSy9lixcvtgRFb/I888wzludw0BuBXPiXLISPvfkpKqUhzEOYj8pYhZ32BGbMmMFatGihZBDhqdiCBQvsM+MMCIBAYAQgzHtHqxfHRWhVZbLRgQMH0lYsFm5lIoRX2jxOTurvsRs1asTmzZvnpBjy5JAAhPkcwkZTIAACmQlQ6IjJkycz+sKiLw4vieqg1531qVKlSkzEWtMfwn6CCECYj66z9TeV+l74OWNeLKbHBg8erK9e26fXSin8CYXGQPKfQC78S1bDx/77TsYaIcxDmJdxXMIm5wToIaqID66FHaQH5g0aNHBeAXKCAAj4RgDCvHeUemG+TZs2TKx/wm6++eaMFU+YMIG1b98+Y750GSh8jRrC9fLLL2dizYB02XEuDAKWkedxEARAAAQiTkA8geYirqhhsRNxjeVXXXVVxHsG870QwOKvXuiFW7Zy5cqm/2f6n/Zr8VcRU5bTgktUp9XfW2+9FQqAP/74g99+++28Vq1aXKy/kTMbRMgALn44KO2KkGKBtxu0f6kDsvo4cLgJbACLvybQ6ehybAgcOXKEt27dWvsuFuHjYtM3dAQEokggrMVf43QPrF/8le6vKfXs2VO7zln99qBjRYsW5SLevKdhc/rpp2vtiLUCPNWFwsEQYMFUi1pBAARAIFwCYmaN9gWk/6ITT/zDNQyth0oAwnyo+D01HrRwK9a0sLxm0PVDzNLzZLuXwjNnzjTYNXDgQC/VOSq7Y8cOLkKKae1WrVrVUTkvmYL2L9kmq4+9cENZawIQ5q254CgIyE6ARPlOnTpp3z/VqlXjYqFf2c2GfSAQawJhCfNxuge2EuYPHz7MmzZtql3v9JqFfr9KlSpcrI/keoxBmHeNLmcFEcpGjHik6BGgRflmzZrF1q1bx7Zv384oZnipUqVYuXLl2KmnnsouueQSVq9ePV8Wy9DTofY++OADtnnzZiV2OS0wSgtyUHgU+hPCArvsssuUrb6c1/0vvviCvfbaa+yzzz5jzz//PBMXZ1OVdG769OmabbTwKTGgBT7OP/98Rqtx2yUKH5NtKl68eFq+v//+O/vtt98yVlu6dOmMedxkoHAUFLJAn2jBxqlTp+oPYT9hBBDKJroODzLUyZw5c5i4MbaFQ9dfeu00jLR06VLl9X1xZ6g1//jjj7NHHnlE++znzs6dO1mTJk2UBbjVes8991xG638EmYL0L9kts4+D5JrUuhHKJvv7uqSOFfRbHgL0/UOhHShsBiX6XbVw4ULL3z3yWA1LQCD+BMIKZROne+DUUDb024IS6Vik2WzYsCHtQKJ7cwpBIyICpM1ndVIfykbMmGezZ8+2yoZjYRLI2SMANAQCHgnQq0xjx47lNWvWzPhUUfxPcbG4J+/du7fnWRb0Ov/DDz/M9U8aqX67vwIFCnCaqfXvf/+b79u3z3Wv6ZV7IS5zEV/R0Fbqq0zz58/ndevWNeRJtU08rOArVqywtIWe1Kbmd/K5efPmlvWpB8UCiY7qJR/5ncQXGxcPDgztC2GJ5+fn+90U6osYAcyYj5jDdOYGOaO6WbNmhuuF/hpYsWJFLh406izJ/e7o0aNN9j322GO+G5I6U544iB8SXIjyvreVWmGQ/qW2ZPdxKg989kYAM+a98UNpEMg1AfrdIx7Qat911atX52KdqFybgfZAAAQsCIQ1Y55Mics9sNWMeRU13Wenahf63yLqfo8ePdQiWW31OhZC2WSFLmeZEcomZ6jRkBcC77zzDqdX6dWLkn5bpkwZfuKJJ1rGE6d8JUqU4C+//HLWzdODgJEjR3KqX9+eui8WAFRuIK3imKt56CK4aNEiR20fPXqUL1u2jPfv31+J6avWkbpVhfmDBw/yBx98kBcsWNDSvtRyFJ9s4sSJJlvolVExO4XTq6KpZaw+U3/pB2/fvn1NdekPfPTRR4oQku5LRszk56+++qq+mOd94ti4cWNDX+rXr89JcEICAQjz0R0DQQm34q0r2+8Pugbqr3V0vVy1apXy4JWu1d27d+d0kzxs2DD+0ksvcbGwUmCAn3zyScN1jWwbMGCAb+2FKcpTJ4LyL9UdFR+TrUj+EIAw7w9H1AICuSLw5ZdfcprcJN5E5n369OHirdtcNY12QAAEMhAIU5gn0+JwD5xOmKc+TpkyxXSfb6XFiJn2lD2rBGE+K1yhZIYwHwp2NOqUwKFDhxThI/WidN555/EXX3yR02x2NYmwMvy2226zvaANGjRIzZpxK0K7cLFitakuEa5GEWRImCGBhhLdOI4bN46LEBmm/GQ3CeckoJOQbpe++uorftJJJ1mWT+07CfPU1wsuuMBRfn15smXt2rWWZtCDiC5duqStkx5GkK3ZJPKRCC1kqleEYsimGsd5n3jiCUNbtLiKlzcXHDeMjJEgAGE+Em6yNDIo4ZYewOqvk6n78+bNU962oRk7p512Wtq8JCo0atRIEen3799v2Q8vB0eMGGFq/9FHH/VSpVI2bFGejAjKv1R3lHxM9iJ5JwBh3jtD1AACuSbw1FNP2b7hm2tb0B4IgMBfBMIW5smSqN8DZxLmqY/0UDL1d0jqZ5r0SFpUNgnCfDa0wskLYT4c7mjVAQEKHXDllVeaLk7t2rWznUVB4nKdOnVMZeiCRqI0falkSrTAkFW4HJrd/cMPP9gWp4cI6R4MUF9UMT+1Enp9qVixYpxEQ
xEr39J+9aI8efJkU1gdWhCEFrWjhwmZBH7ily5ZPZBQ2yb73KTx48cb+nTRRRdxmtnuZyK2Dz30kKEdEU+N00MWJBBQCUCYV0lEbxuUcFurVi3DdUO93tH2mGOO4c8++ywXa3bY5tHn1+/Tg8wxY8b4Dnr48OEmW2j2vtskgyhPtgflX6o7aj4mm5G8EYAw740fSoMACIAACICASkAGYZ5sifI9sBNhnvSMdKEX1d8ZFOlAxKZX3ZNxC2E+I6LQM0CYD90FMMCKAIm2bdu2NYgPNBNxyJAhVtkNx/7zn/8YyqkXMNrSTPt0iYT9v//976bytFq201nX3bp1M5VXbaCZ83ZJFe3JhhtuuMG2DrUu2tJF9t133zWJ3CQkFS5c2LIOCkUjFrG1M0OZEW8XnoeOu1kRnGL06+3++OOPbdt3c4Jef23QoIGhDbU9sbgsHzp0qO3DHDftoUx0CUCYj67vghBuxULeltcN9frhxzabt7WceodC56TaJhaDdVpcyyeLKE8GBeFfqjeqPibbkdwTgDDvnh1KggAIgAAIgICegCzCPNkU1XtgJ8I89Y+iDdDkwtT7/NTP6SZ9Uj36BGFeT0POfQjzcvol8VYNHDjQdDF64IEHHHH56aeflBiFqRcv+kyz5vXhb1IrpDZSy1GZdEJ2ah0krNOM8NR61M9Wcd5T66DXk9T8dtuOHTvydKESXnjhBds6Mi0amG7mPy1qm23Sh92hfb/SF198wdu3b6/Eo7TjpB6nNym2bt3qV9OoJ6IEIMxH1HHC7CCE2zfffNP2OqleO2hLbzLdfffd/Omnn+a05smsWbP4jBkzlBjv+ptdfRn9fs+ePX0HTw8c9W3QPj0EdZpkEuXJ5iD8S/VG2cdkP5I7AhDm3XFDKRAAARAAARBIJSCTME+2RfEe2KkwT/2jSYe0TmLqfX7qZ/06WFTOLul/q2DxVztK4R6HMB8uf7RuQWDjxo2cFirVX3golhYJ7k6TXgjW10P7douxkhhOs/JT81OM8mzTN998YysWU8gaEkQyJVr8KNUW9bOTNwcotI5dWBsS3tMliiNPDyTU9vRbCumTTaIvFn15CsXjJdHDiAkTJigLvFr5S99W6n7FihURu9IL/BiUhTAfXScGIdz269fPcH1KvWaQIE/rYaR7oEtE//vf//J0C11TvSTq+53ouyDVZic36bKJ8sQlCP9SvVH3MfUBKXsCEOazZ4YSIAACIAACIGBFQDZhnmyM2j1wNsI89e+NN96w1Kb09/2khUybNo2yp00Q5tPikeIkhHkp3AAj9AQoVrr+gkP7FDs8m/T++++b6lDrpIVTrdJNN91kKkMXO4r/7iZ16NDBVJ9qA8VHy5T0F2+1HG1JZHCaunbtamkDLU6YKVn5QbVj2bJlmYpr53v06KHZQA8K6I0CL8nuYYNqW6YtifN2Y8CLXSgbDQIQ5qPhJysrgxBumzdvrl2fUq8dxx57LF+yZImVKZbHVq5cmXaNEFo8nB6Y+p0GDx5s6gMtHmWXZBTlydYg/Ev1xsHH1A+k7AhAmM+OF3KDAAiAAAiAgB0BGYV5sjVK98B6bcfpxE8KU5n6+yT1M61ptXbtWjvXKcchzKfFI8VJCPNSuAFGqAS+++4708WHZpi7iWt+9dVXm+qyEyvWr19vOUOcFtZwm77++mtT++qF9LTTTrNdCFZtz06koJmZTpNVDDay4ZRTTslYBcWBV+1N3d55550Zy1MGEuHLly+v1eNkJmemiilO//XXX89btmypzJqvXbu2o1e99H1o0aJFpmZwPqYEIMxH17F218RsH9zqCdgtFk7Xi86dO+uzOtqnEDf6a03q/vPPP++onmwzWYV/6927t6kaWUV5MjQI/1K9cfEx9QXJOQEI885ZIScIgAAIgAAIpCMgqzBPNkflHtiNME/rLpJukfp7IvVz9erV+a+//mrrQgjztmikOQFhXhpXwBAiQCEDUi80tKinm3Tw4EE+cuRITjPh77nnHmWRVLt6RowYYWqX7Ljsssvsijg6TqJxan/UzxSnOF2yEymyEeZfffVVy/ZpEVcnqX79+pblKeZZfn5+xiqmT5+ulae3DzZs2JCxjJsM9KVFs1Xvv/9+XqRIEa1NlbXVNptZ/25sQhk5CUCYl9MvTqyyuyZ6EebpIanV9YGOzZkzx4lZpjwNGza0rZMe9h4+fNhUxo8DVt+fvXr10qq2E+UpjJsMKQj/Ur/i5GMZ/BQVGyDMR8VTsBMEQAAEQEB2AjIL88QuCvfAboR56tuePXu4+vvV7jcLHadJi6SJWCUI81ZU5DoGYV4ufyTeGnral3rBocU9g06XX365qV2yw82MSb2tAwYMsKyX6qaFBNMlO5EiG2F+5syZtu2na1s99/LLL9uWd2IHfUGo/iTGuUjff/+9siBspvjzbdu2zYU5aEMyAuqNjZe3YSTrUmLMsbsmehHmjzvuOO0apV6r1C0J2W7SlClTbOukumlWfVCJFvZW7Ve3xEd2UZ54BOFfqjduPqY+IWUmAGE+MyPkAAEQAAEQAAEnBGQX5qkPst8DuxXmqW+0/h+te6Xe29ttSXuyShDmrajIdQzCvFz+SLQ1FPfb6iJjd4HxCxbF/E1dbFa1g1b89pI+/PBDyz5R/ZlWxLYTKZwI4qrN7777rm37ap50W3rroGzZspZ1/O1vf0tXlG/fvt2wAO7UqVPT5vf75H/+8x9Lu1Xfli5dOrCZq373BfX5RwDCvH8sc12T3TXRizBvt8j1Mccc47p7FMLr5JNPtr3+jB492nXdTgo++uijtm2r1z+6rssyU17tUxD+pbrj6GOVGbb2BCDM27PBGRAAARAAARDIhkAUhHnqj8z3wF6Eeerb22+/7WgxWKuoDBDmiaDcCcK83P5JlHV2IjaFYwkybdq0yVbEeO655zw1/c0339jWTW8HpEt2IkUuhXmyj0IhqGJO6vbzzz+37QItcKvmr1ChAv/9999t8wZ1wi7GvmrXZ599FlTTqFdSAhDmJXWMA7PsrolehHm70FfHH3+8A4vss1gtRqVed7p162Zf0Kcz/fv3166/arvqVkZRnrodhH+p3rj6mPqGZE8Awrw9G5wBARAAARAAgWwIREWYpz7Jeg/sVZinvlm9FaDe36tbmny4bt06yq4lCPMaCml3IMxL65rkGfbMM89YCglu4/w6JUixxtULWeqWYtR7Sb/99ptt3SQW2MUBozbtRIpcC/MbN260nXGYLhzP2WefrfXdahFCL1yzKdu6dWvNjlT/jh8/PpuqkDcGBCDMR9eJdtdEL8J8pUqVLK8PdFPrJc2dO9eyXroG0cLkuUgdO3a0tEG2mfIqiyD8S3XH2ccqO2zNBCDMm5ngCAiAAAiAAAi4IRAlYZ76J+M9sB/CPGlH1113neX9vV7nqFGjhmE9QAjzbkZ9bstAmM8tb7SWhkDPnj0tLzK0gGiQKV24F3ri6jXZhYKhi+ePP/5oW72dSJFrYZ4MtFsNnGKd7du3z9SHJUuWaL6kWO/fffedKU+uDqxevdr2tS9a9BcpWQQgzEfX
33bXRC/CfK1atbRrlf6GtlChQp5A0UwVfX36fbpZDjr98MMP/IwzzrC0wQuvIO0Owr9kb1x9HKQv4lA3hPk4eBF9AAEQAAEQkIFAlIR5We+B/RDmaSzs3buX028J/W8Lq/1WrVppQwfCvIZC2h0I89K6JnmGdenSxfICM27cuEBhvP7665bt0gXunnvu8dx2nTp1bOv/5ZdfbOu3EynCEOZnz55t24exY8ea+nDnnXdq+Zs2bWo6n+sDdg8W+vTpk2tT0F7IBCDMh+wAD83bXRO9CM2NGzfWrlWpN7W0/ojbRA8sU+tTPxcvXtxttY7Kbdu2jVstpK62T9sw32Ky60QQ/qW24uhjO4Y4/hcBCPN/scAeCIAACIAACHghEBVhXuZ7YL+EefLj+vXr+XHHHWf7W0O95x8yZIjidgjzXkZ/bspCmM8NZ7TigMBdd91leXHp3r27g9LusyxcuNCyXbqg0Q96r8lOFM7Ly+NHjhyxrd5OpAhDmKfXpuxmX9avX9/Qh/379xtWDacHH2EnWqdA/YLSb/v27Ru2aWg/xwQgzOcYuI/N2V0TvQjzbdu2tbw20HWCFrD2kihOvf56o+6XLFnSS7Vpy9IPEnWMq+1RKJc2bdqYbJFNnA/CvwQrbj5OOwBwUiMAYV5DgR0QAAEQAAEQ8EQgCsK87PfAfgrz5EyK+lCwYEHT/b16/09bOk8TLCHMexr+OSkMYT4nmNGIEwKPP/645YXlggsucFLcdZ4tW7ZYtksXs/Lly7uuVy1oF+OcFkRNl+xEijCEebJz1KhRtpxWrlypdeXll1/W8hG/MBZ91Yz5c2fp0qWaTfovq6FDh6ZmxeeYE1BFy2rVqsW8p/Hrnt010Ysw/9RTT1leG+g6sWDBAk8QzznnHMu6TzvtNE/12hWmBwlnnXWWoU0S5WkRckqPPPKI4Rz1USZxPgj/Ur/j5GPqD5IzAhDmnXFCLhAAARAAARDIREB2YT4K98B+C/PkM5oRr9c2rPbLlCljEPCvuOKKTO7G+RAIQJgPATqatCbwyiuvWF5YChcu7HnmonWL/ztKs9aPOeYYy7bp4vbzzz+nK57xXKNGjSzrPv/889OWtRMpwhLmd+/ezYsVK2bZF33IH31/e/XqlbaPuTq5a9cuS7unTJmSKxPQjiQEIMxL4ggXZthdE70I83YP7ejabxWmKxuz69WrZ3ndoeN+J1qvRL/gNtmvF+XV9qzEeVmu00H4l/odFx+rPsTWGQEI8844IRcIgAAIgAAIZCIgszAflXvgIIR58tsNN9xg+XvDSqSnYxDmM432cM5DmA+HO1q1IPDxxx/bXlTuu+8+ixLZHyKBds+ePaaC+td7Ui9ikydPNuXP5oBd3ffff3/aauxEirCEeTL2jjvusPRR6dKlOYWw+fbbb7WFVmnRV/osQ6JFUlL9Sp/VmaQy2AgbckMAwnxuOAfRit010YswT2/0FC1a1PL6QAuSe0lVqlSxrLdly5ZeqjWVdfqDRC1oJc57YajW63UbhH/Jpjj42CvbJJaHMJ9Er6PPIAACIAACQRCQVZiP0j1wUMI8rWt17rnnWv7msNI/IMwH8R/ivU4I894ZogafCGzdutX2gkIz2jdt2uSpJbpo1axZk5crV45v3rzZUNd1111n2/att95qyJvtBzvRZ9asWWmrshMpwhTmV6xYYcvpxRdf5P369dPON2nSJG3/cnlyyZIlml3qFxTFf6bY+UjJIgBhPrr+trsmehWV9W/5qNcH2v7tb3/zBKtIkSKm6w7VO2zYME/16gv/9NNPvEaNGoZ2rGbK68vQvv5arfb5wQcfTM2W089B+Zc6EWUf59QJMWoMwnyMnImugAAIgAAIhEpARmE+avfAFFJGveemUMd+pu+++47r61fbsdpCmPeTvH91QZj3jyVq8kiARNLixYtrF6zUC0mnTp08tXDzzTcrdZ900knK7G59ZTNmzLBtl+Kkp1ukVV9P6v6OHTss6yWxnmaYp0t2IkWYwjzZe9FFF1n2iRaB1dvs9U2DdGyyPffCCy+YbL733nuzrQb5Y0AAwnx0nai/vui/H7wK8+PGjTNdH6h+euvHbSiznTt3WtZJ9X766ae+OIF+kNDDZj0LJ6K82riVOO/1LQG1bjfboPxLtkTVx244osz/CECYT+ZI+OGHH/jixYv59OnT+cSJE/kHH3zA161bxw8cOJBMID71GlyNIGXhQZPWaPLRtGnT+KRJk/gnn3zC6f4DyTkBWidt4MCB/O9//zv/17/+5bxgwnLKJsxH8R64ZMmS2j07TQr1O5GP8vLytDb0vw/0+xDm/SbvT30Q5v3hiFp8InD33XfbXkwKFSrEly9f7qol+qJVL0j/+c9/THWQ8H7qqadqedS86nbmzJmmMk4OvPnmm5Z1duvWLWNxO5EibGGefuioXOy29FbCoUOHMvYxVxkoFFKqrfoFa3NlB9oJnwCE+fB94NYCu2uiV2GeHpJSOK7UawR9fumll1yZu2bNGsv66DXWw4cPu6pTX4geGKQuLpuNKK/WZSXO9+jRQz2d021Q/qVORNHHOYUfw8YgzMfQqTZd2rBhA3/ggQdM62ykXtNpTIwfP57TG7RImQmAq5GRLDwWLVrEu3btapgMlTrW6a3gZs2aKQ+mjL3AJwpvSg/uSHNIve/wej8ZZ7oyCfNRvQfWv0nbvHnzQIbL8OHDLX9/6K8REOYDQe+5UgjznhGiAj8JrF27Nu3FhESN1atXZ9XkqFGjtLjn1atX53/88Ydl+SeeeMK27csuu8yyTKaDTZs2NdVJYXm2bNmSqajpZkG9oGYjzNMDBbVc6jajATYZSHCntwhS69N/Djskgt50eqKeKrpdddVV+izYTxABCPPRdXbqDyj1muPHDyl6g0atT7/NtEi3HU277xOK7+410Q+S1FiSbkR51Q4rcZ5ErlynIP1LfYmSj3PNPo7tQZiPo1eNfaJ1ozp27Mhp4o7+up1pn35LkDCHZE0AXI1cZOHx9ddfc1qjJtP4Tj1fp04dZUa9sVfJ+UST7+hNxcGDB/OLL76YFy5c2JahH/eTcSUrizAf1XtgeiCs/9/0Gi4z3Thr27atoS19u7QPYT4dvfDOQZgPjz1atiFw+eWXp72YHHfcccpr6TbFtcMHDx7k3bt3N9Q1Z84c7Xzqzi+//MJphkHqxUv9vGrVqtQiaT/Tq7MUCkEtr26dzJanik888URTWaojG2H+tddes6yD6vGSHn74Ydt6qc/r16/3Ur2vZSkEksqetiVKlPC8XoGvBqKynBKAMJ9T3L42lvqATf2/9uOHFF2v9TNZ1LppO3fu3Kz6QWHZrBb9phBm9IPCS6LvKT9FedUWK3E+0wLlalm/tkH6l2yMio/94pn0eiDMx3sEULiaU045xXB/p79uO9nv3bt3vCG56B24GqHJwmPChAlpBWUn4/2f//ynL2/sGQnJ/4lCrTrhQ3n8uJ+Un4g7C2UQ5qN8D0zaiH4cVqhQwZ0jHJSit0TPO+88Q3v6tiHMO4AYQhZv6lwIBqPJ+BNIN8tbf1Fp0KA
Bp9jh+lh6JIhQ3D+arUix5PX5b7vttozwZs+ezQsWLGgop9ZBsdWziTVPN0BqWXVbq1Yt/ttvv2W0gzLYiUTZCPNjxowx2aDakp+f78gOq0y0eK5dDLPGjRtbFcnq2EcffcTbtWvHTzvtNE5vHTz22GOcvoyzTcuWLTP5k2INIyWXAIT5aPqeBG312pW6pQewfiSaTZVaN30mIdzpdZvsoOuXVT0kfntJtGYJfYfo6/YyUz7VFitx3i+2qW2lfs6Ff6lN2X2cygWf3ROAMO+enewl6V7d7h5Zf310sj927FjZu5sz+8DViFoWHhT/3MlYdpKH3hi2e3Pc2Pv4fKKQNRQ//tJLL+VVq1ZNyxLCvL3fwxbmo34P/Oqrr5rGHi3YGlTauHEjL1u2rKlNuk5AmA+Kurd6Icx744fSARG45557LC8kdjcdNNOdYsQfe+yxluVq1KjBnQrRQ4cOtayD2iaB2El65ZVXTMI1zdSm1xCdpF9//dXWBrpBc5pS3xjQ8/P6ZUCLlujrU/dplr6XRK96Wc2cJH4UBoJiAzpJ8+bNMz2cGTZsmJOiyBNjAhDmo+ncjz/+2PJ6Q9ed66+/3pdO0Y9VCl2jXsv0W4rn6iRRHXTDqy9L+/Qg2cuPYfpBUrt2bUO9foryat/CEudz4V/qo8w+Vn2ArT8EIMz7w1G2Wj777DNObx+lXmPJ3xSu6qmnnuJPP/20Ej+aJtRYvbmqL0sCPy2gmfQErsYRIAuPf//736axrh+/bvb79Olj7GyCPtEEvr59+9oyhTBvPxjCFObjcA9MMeVT/1+d6kr2Xkl/ht74tZpICWE+PbewzkKYD4s82k1LgGam2wm/qRe1TJ9JvPj222/Ttpd68qabbjJdPNV2KH46fbHbJVpcNvWHQJkyZTgt1uM0Pf/887btU6gfJ+n3339PGwuefrx4SRQWSGWibunJrNdFXydNmmSqV62ftvQmBD11tvMBHR8yZIjhi4ji+o8cOdJLd1E2JgQgzEfPkfR9YLVeh3pdoNjkTh/YZeo9rXNCr5eqdeu3FNol3VtTdO279tprTWXpmrVRzFxxm+ihcuoDgyBEedU+K3GewpcFlXLpX+qDjD4Oim2S64UwHz/vHzhwgJ911lmGayyJHRQ/2i4tWLDAVEZ/Xaf9JAuVxA1cjaNHFh70wNpq/YRLLrlEWeyYfm/SRDCatET3Hplmg6vjnt4MzzZEn5FQtD/RPZXVwz3iA2He3rdhCfNRvwemCSEkwKv/f/otTTqcNm2aPXQfzpDeo2+T9iHM+wA2gCogzAcAFVX6Q4BCB9CrZ6kXk2w+0wzDb775JmuDSCigcDgk6Fq1RwvpTJkyRbmZVStfuXIl79y5syk/xcAkIcBpevfdd00zvVNt6NmzJ6enx3aJwgLQ64qp5fSfafEr6oOXlPoDiezymujVUb2ddvsksNLbA7SAFy0ITDewI0aMMMVUo9AP5BskECACEOajNQ42bNjA77jjjozXBBJn/Po/p++MKlWqWLZJr0LPnz/fAJEeBq5Zs4Y3a9bMVIbGmxdRnhqaOnWqod4gRXm1Y6nifLly5dRTvm7D8C91QDYf+woVlSkEIMzHbyDo32ilWYAvvfSSo07SmlNWD03V+0t6eJrkBK5G78vAg36Hpr4lR7+56DdSuvTmm2/yihUrGu4Z1HGu315wwQXpqon9udTfryobCPP2rg9LmI/iPTD9/9KbWDTZkDQjdXzZbW+99VZO+g/9XqCyfqf27dsbbIAw7zdhf+qDMO8PR9QSEAGaBUkz9YoXL264oNhd2NTj9ARy1KhRnhe5oQVfU2P6qm3QlmbG0+zKYsWKmeyjVd979OjB9+zZk5EO9fPKK6/kJ598sqkefXv6fXr9lmyj2Hlqevnll5VjVjMs9GX1++XLl+eXXXaZqxjuJISrddEMDK/hcagftGBJtWrVtHrV+rPd0uKLo0eP9jyDX2WLbTwIQJiX24/Lly/nbdu2VR4snnPOOaa3jzJdB2j2fMOGDZXyHTp0cN1ZWqukZcuWtteh6tWr82uuuYbTAwF6I8rKrlatWqV9gOrUOHrjS30VlR4YuHnY7LQtfT69OE8ipx9JFv9SX2TysR9sUYeRAIR5I4+of6L7ZJpQQtdaur/OdmIJTVihB4xW12o65uRePeoMrewHVyMVWXiMHz/eMFbpwT+9Ce0kUThUJxMaaFJTUhOFF7S6FkCYtx8RYQnzUbsHfvLJJy3fdLEab1bH6G0OL6EvrTxIk13r1q2rjXkI81aUwj8GYT58H8ACBwS2b9/OKe483YxbXcTUYyQO33DDDXzr1q0OanWWhcIT0IKhTZo0MS0kqrar35544omKrV999ZWzBkQuWthUX0c2+/Tqopoohno2ZfV53czqpB86alx/Cj3kVyJb6EFFakggvb1W+yVLllTENFpA2C7UjV82op5oEoAwL7ffFi9e7PoalnpNoGuT10Q/RPQ3s6ltWH2mH3yps+q92vH5559zevD6008/ea0qq/L0JhKFDqMf+n4k2fxLfZLFx37wRR1/EYAw/xeLOOw9++yz2ncDXQvdpAkTJmh1pF67/Xrjyo1dYZYBVyN9WXjUr19fG6s0WWnXrl1GQx18ot/DqeNc/5nC8yU1NWrUyJINhHn7ERGWME8WRekeOF04Yv3/n90+hQUOQsPYvHmzFuIYwrz9OA/zTAFqXAwMJBCIBAEhYDMRMoCtW7dO+RMz3pgQY5kQw5mYXcnEhYaJGTWB9UU8IGBCcGFC+GdbtmxhP/74IxMz15kILaD8idcOmQi/w8QDgsBskK3iQYMGMXEjyz766CN29tln+2qemIHPxKwRtn79erZt2zblb+fOnUy8EcHEzCflT8z4Z0I4YxdffDETr4sxMbPUVxtQWbwIiNdXlWuH+KHDxCyMeHUOvQmMgHhYyMQr4kwI1cp1iL4LhFjNxMLjTIRBYOLVcSbivjIxi56J8GWB2YGKgyMAHwfHNoya6Z5g4cKFyj2imAUbhgk5a5Pukzp27Ki1V6pUKSZmgGuf47Bz3nnnMfEWKxMhaZRrsZs+ickkTLzlaln0k08+YSK8h+W5OB8EV6N3ZeDx5ZdfKr9pyTLxRjajsXnuuecaDXXwSTzIV36X7d692zK3eFuaffDBB5bn4n6Q7tfo93xqEsI8E5PcUg/jsyDw/vvvKzqLHoaI/678Jtcfw768BMSaK2zAgAGsXr16GOcSugnCvIROgUkgAAIgAALBEIAwHwxX1AoCIAACMhGAMB8fYV6E3GA08YUm3pBoaSeuOxl/9CCVJtWkJjEjWXnQmno8zp/B1ehdWXg8/vjj7NFHH1WM69+/PxMLRxoNzeKTWIeB3X777ZYlRMhPJkLjWZ6L+0EI89l7GMJ89sxQAgSyIQBhPhtayAsCIAACIBBpAhDmI+0+GA8CIAACjghAmI+PME+z/ydPnsxEKDomQlA48r9dJqqD3sLUJ3rrld6ETVoCV6PHZeEhwnApb+eJkBZMLJCuvPVjtNT5JxGXnol12tjhw4dNhY455hgmwrWajifhAI
T57L0MYT57ZigBAtkQgDCfDS3kBQEQAAEQiDQBCPORdh+MBwEQAAFHBCDMx0eYd+RwB5nEAniKyHnkyBFD7quuuorNmDHDcCybDyR6du7cmYnFrdldd93FxJpY2RR3nZfCqXXp0oWtXbtWmVVNYX7CSEFxDaMvfrTphQeF3aI3Q2hM0ax5CjvhNVGoV3rTJDVRmJz9+/enHs74OQ7jHcJ8RjebMkCYNyHBARDwlQCEeV9xojIQAAEQAAGZCUCYl9k7sA0EQAAE/CEAYR7CfOpIWrJkCbvoootSDyuxk5s2bWo67vTAu+++y0jcV9PAgQNZv3791I+BbGm9pSZNmihx96mBqlWrMlqXKYwUFNcw+uJHm154UGgZGosUbmnTpk2ewjapfWnbti2bNGmS+lHb1qhRw1Kw1zLY7MRhvEOYt3FumsMQ5tPAwSkQ8IEAhHkfIKIKEAABEACBaBCAMB8NP8FKEAABEPBCAMI8hPnU8TN48GCTYN6qVSs2derU1KxZfV66dClr0KAB45xr5ShO+COPPKJ99nMnVZSnumlxUIqRHkYKimsYffGjTT940Kz7okWL+mEO69u3Lxs6dKipLrdvisRhvLsR5sknFBoo21SiRAmWl5eXbTHp8kOYl84lMChuBMRNBBIIgAAIgAAIJIKAiC9Lv5x5tWrVEtFfdBIEQAAEkkhAxGlWrvUlS5aMffdffPFFpa/03UZ/pUqVin2fs+2giNXNRaxtAychZvP8/Pxsq7LMP3r0aEPd5AexaKdlXi8Hd+zYwcVCuIa2RCxyLkR5L9W6Lhs0V9eGhVRQRh533HGHYbyo14muXbu6phT18S7WqrBk8tBDD9kyEYvoWpZRedpt27VrZ1tnlE689957pv77df2MEgfYCgJBEcCMeXEVRQIBEAABEEgGAcyYT4af0UsQAIFkE8CMecyYV/8DxI9oJezLRx99pB5i9evXZxSSg+J5+5VGjRrFevToYaiOYoRTrHA/ktVMeVog9MMPP1RmzPvRRjZ15IprNjaFmVdWHpdeeinTj32V0SuvvMJuueUW9WPW2yiPdzcz5qdPn84WLlzINm7cyIRIzQ4cOGDLrFy5ckp4KQoxRdea7t272+aNygnMmI+Kp2BnZAkEpfijXhAAARAAARCQjQBmzMvmEdgDAiAAAv4TwIx5/5lGtcYnnnjCMNOzTZs2fN++fYF0Z8SIEYa2hEDAhTDvuS3ZZspTh3LJ1TPAHFQgK4/KlSubxmTp0qW5EJY9U4nqeHczY14Pa8iQISam9L8u1gfg4kGZPmts9jFjPjauREckJUDx8JBAAARAAARAIBEEIMwnws3oJAiAQMIJQJhP+AAQ3T9y5Ain0BQkmKl/Z5xxBt+zZ0+gcIYPH661p7bbv39/123KJsqHxdU1wIALyszjp59+4gULFjSNxy5duvhGJYrj3YswL95c4dWrVzcwLVOmDH/ttdd8YypjRRDmZfQKbIoTAQjzcfIm+gICIAACIJCWAIT5tHhwEgRAAARiQQDCfCzc6LoTX375JRcLshrEM1Ukp9lMgqecAABAAElEQVTCYjFMLhZzdF1/poLDhg0ztS0Wg81UzHReNlE+bK4mQCEfkJ3HyJEjTeOQ/g+WL1/uK7mojXe3wvzu3bv5hRdeaGB65ZVX8u3bt/vKU8bKIMzL6BXYFCcCEObj5E30BQRAAARAIC0BCPNp8eAkCIAACMSCAIT5WLgx60588cUXvH379rxQoUIG8UwV5fXbOnXq8K1bt2bdhtMCJP7r26P9hx9+2GlxLpMoLxNXxwADzBgVHmeffbZpDFK4lSBSlMa7G2F+8+bNvGbNmhrPAgUKKP/P9MZEEhKE+SR4GX0MkwCE+TDpo20QAAEQAIGcEoAwn1PcaAwEQAAEQiEAYT4U7KE0un//fj5hwgTeuHFjTmJZqhie7nPFihX5ihUrArPbKhZ13759M7YngygvM9eMAAPIEDUeixcvNv0vFCtWjH/33XcB0PlflVEZ79kK86tXr+aVKlXSeJYsWZK/8cYbgXGUsWII8zJ6BTbFiUAB6oy4YUECARAAARAAgdgTOOuss9i6detYtWrV2Lfffhv7/qKDIAACIJBEAhdffDFbuHAhEwIK27t3b6wRjB8/nnXs2FHrY6lSpZiIo659jvuOENeZCCXhuptUfuXKlaxcuXKu60hXUIiVTMyUN2Tp06cPo+NWScSwZk2aNGGrVq3STpctW5aJRSXZueeeqx0Lekd2rkH3P7X+qPGgawJdG/RJLNbKevbsqT/k+34Uxvsll1zC5s+fb+q7WJOCibA8huPz5s1j1157rXZNpXEwa9YsVqtWLUO+uH94//332RVXXGHoZn5+PitRooThGD6AAAi4IwBh3h03lAIBEAABEIggAQjzEXQaTAYBEACBLAlAmE+OME9impgFzA4fPsxIKNq1a5fyed++fY5HTYsWLdjbb7/tOH+2GQcNGsREjHlDsd69ezMR/sNwTBZRnoyKAlcDvIA/RIkHPdSpW7cuE2FWNCr0eenSpSwvL087FtSO7OPdqTA/efJkdtttt7FDhw4pqGrUqMFmz57NKleuHBQ6aeuFMC+ta2BYXAjEafo/+gICIAACIAAC6QgglE06OjgHAiAAAvEggFA28fCj214cPXqUi1nw/P777+dFihTRQlCI3++2+8uWLXPbnKNyjz/+uKntXr16aWXtwtcIkVXLE/aOjFzDZCIjD7LpoosuMoy14sWL81yPI5nHu5NQNqNGjTKExqLvFPHQL8zhFmrbCGUTKn40ngACiDGfACejiyAAAiAAAv8jAGEeIwEEQAAE4k8Awnz8fey0h99//72yIGym+PNt27Z1WqXrfI899phBMKUHBWImtu1Cr7kWU7PpmExcs7E7qLyy8Bg7dqxhjBUsWJC/9dZbQXU7bb2yjvd0wjw92HjggQcMDK+55hp+8ODBtH2N+0kI83H3MPoXNgGEsonLqw/oBwiAAAiAQEYCCGWTEREygAAIgEDkCSCUTXJC2TgdrM899xzr0qWLbfbSpUszCiUTdKiPAQMGMCFY2tpBJyim/Ny5cyMRx1oWrmmB5vBkmDwojJOYgMLE2xdaj0ePHs3++c9/ap9zvSPjeLcLZdO9e3f2448/skmTJmmYaF0HsUB04NcFrUFJdxDKRlLHwKzYEIAwHxtXoiMgAAIgAAKZCECYz0QI50EABEAg+gQgzEOYtxrFw4cPZyJ8jNUp5dhnn32mxOa2zeDTiUcffZSJUB+WtUVJlFc7IAtX1Z6wt2HxaN26NZsyZYrW/a5du7JnnnlG+xzWjmzj3U6Yt+JTrFgxtnjxYla7dm2r04k5BmE+Ma5GR0MiAGE+JPBoFgRAAARAIPcEIMznnjlaBAEQAIFcE4AwD2Hebsy1adOG0aKOVmn8+PGsQ4cOVqd8P9apUyf24osvmuqlhTtr1aplOi77AVm4ysIp1zwGDhzI+vfvr3X/6quvZtOnT5dmprdM4z0bYZ6AVq1aldFDu+OPP17jm7QdCPNJ8zj6m2sCEOZzTRztgQAIg
AAIhEYAwnxo6NEwCIAACOSMAIR5CPN2g23NmjXK7FcRT9aUZcSIEaxnz56m434f2LZtGyNx8JtvvjFVLWLOs2HDhpmOy35ABq4yMcolj2nTprEbb7yR1g5UEIg1Ntjs2bMZzfaWIck23rMV5olhs2bN2IwZM5iI2S8D0pzbAGE+58jRYMIIQJhPmMPRXRAAARBIMgEI80n2PvoOAiCQFAIQ5iHMpxvrLVu2ZO+8844pS58+fdiQIUNMx/08sH37dkWUX79+vW21vXv3ZkOHDrU9L+uJMLnKyCQXPCj+ecOGDdmBAwcUBPXq1WNz5sxhpUqVkgKJjOPdTpi/8847mVjklG3evNmS3SOPPGIbgsqyQIwOQpiPkTPRFSkJQJiX0i0wCgRAAARAIAgCEOaDoIo6QQAEQEAuAhDmIcynG5ETJ05kt9xyiylL37592eDBg03H/TpAImXjxo3ZunXrtCorVarEaIazfsFJOhlFcT4srhpMyXaC5kELldavX59t2bJF6TktVDpv3jxWpkwZKUjIOt7thHl6W+W6665j9P3xxx9/mBgWKFCAvfXWW6xFixamc3E/AGE+7h5G/0InIF55QgIBEAABEACBRBA488wz6T1fXq1atUT0F50EARAAgSQSEEKncq0vWbJk7Lsv4pQrfaXvNvoTM2Vj32evHVy6dKmBmcpOzFL3WrVteSFScjE5wNCuEOW5CGejlBGzcQ3nyCYhztvWJ+OJMLjKyEG1KUgeu3bt4mItAm3M0P2tEOrVpkPfyjzeGzVqpHFT//dpK4R5hdvIkSMtz1Oe0qVLc/G2S+h8c22AeJPAxCQ/Pz/XZqA9EIgtAcyYF1dYJBAAARAAgWQQwIz5ZPgZvQQBEEg2AcyYx4z5dP8Bu3fvtpxVPGXKFCVWd7qybs799NNPykz5r776SitOM+VpdvPpp5+uHaPFO2kRT33q1asXe+KJJ/SHpN3PNVdpQfxpWFA89u3bx5o2bco++eQTpaVTTz2VLVy4kJ188slSIJF9vKebMU/rOwjlj11zzTWW4a4I8DnnnKOwL168uBS8c2EEZszngjLaSDIBCPNJ9j76DgIgAAIJIwBhPmEOR3dBAAQSSQDCPIT5dANfzPS0jMFNi7HqhfJ0dTg951SkVOuzEuejsiBsLrmqvGTeBsHj0KFDrHnz5uzDDz9Uul6xYkVFlK9ataoUKKIw3jMJ8wRSvJHA6tSpYxtvvnXr1qbwU1I4ICAjIMwHBBbVgsCfBCDMYyiAAAiAAAgkhgCE+cS4Gh0FARBIMAEI8xDm0w1/mmncoEEDQ5bjjz+e7dy5k1Ecab/Szz//rMyUX7t2rVal1Ux57eSfO7TI5KBBgwyHH3zwQTZ8+HDDMdk+5IqrbP22s8dvHocPH2atWrVib7/9ttJkuXLl2IIFCxjd27pNFEv966+/ZhSf3muKynh3IswTiyVLlijx5om7VRIhb1iPHj2sTsXuGIT52LkUHZKNQGyD9KBjIAACIAACIJBCADHmU4DgIwiAAAjEkABizMfQqT526YUXXjDFS7733nt9bIFzMXOY16xZ09COPqZ8psb69etnKCs0BN6zZ89MxUI9nwuuoXYwy8b95HHkyBF+8803a2NCPEjiK1euzNIic3Zax4DG1pgxY8wnszgSpfGeKca8vtviYZjGnDjp//Ly8rh4c0GfPbb7iDEfW9eiY5IQoBhaSCAAAiAAAiCQCAIQ5hPhZnQSBEAg4QQgzCd8AGTo/n333WcQ2Ehs80PkVJsVM4e5iENtaCMbUV6tx0qcFzN01dPSbYPmKl2HMxjkJ4+7775bG08lSpTgYjZ+htYzn54/fz4vWLAgF2+JaIsQZy5lzhG18Z6NMH/06FF+1VVXaez1wjzti7cW+JYtW8xQYnYEwnzMHIruSEcAoWzEFRUJBEAABEAgGQQQyiYZfkYvQQAEkk0AoWwQysbuP4DCbVSvXp3t2fMXIyG8sRkzZtgVyer4L7/8wpo0acLWrFmjlXMSvkbLnLJjFdbmgQceYE8++WRKznA/+sGVQglR7PR169ax8uXLMwo5Qr6KYvKDh9pvWgBYDWNUtGhRNmvWLCbEZfW0qy2Fr7n00kvZ9u3blfE6Z84cV/VEcbw7DWWjAqFxSfHmhQCvHjJs69evr4QUKlKkiOF4nD4glE2cvIm+SElAukcFMAgEQAAEQAAEAiKAGfMBgUW1IAACICARAcyYl8gZkpnSqVMnw+xXmn28adMmX6wUIiUXsboN9buZKZ9qjNXM+fvvvz81W6ifvXKdOXMmP+GEEwzshHjCa9SowVetWhVq39w07pWH2ubgwYM1JscccwwXorx6ytVWxJTnkydP5hUqVNDqnTJliqu6ojres5kxr4JZtGgRL1SokMaMxqb+r0OHDmrWWG4xYz6WbkWnJCKAGfNSPi6BUSAAAiAAAkEQwIz5IKiiThAAARCQiwBmzP81G1wuz7izZt68eWzs2LFs8eLF7PTTT2cNGzZk99xzDytbtmxWFX722WfsggsuYCI8hVZu3LhxrGPHjtpntzs0q5ZmIK9evVqrwstMea2SP3esZs53796dPfXUU6lZHX+WhSvNLqf7s927d1vaXqpUKbZ8+XLF95YZfDooCw+1O88++ywT4XDUj8q2cePGhs9OP9CY37VrF9u6dauBM72ZQMcKFy7stColXxTHu9rBCy+8kC1dulT9qG0feughNmzYMO1z6g69tUBvL9glEa+fDR061O50pI9jxnyk3Qfjo0BAoocEMAUEQAAEQAAEAiWAGfOB4kXlIAACICAFAcyYl8INvhixb98+Xrp0acPsVPEbm9NMdyFW87179zpqR4iu/KSTTjLUI0Q4R2UzZdqxYwevXbu2oW4/Zsqntms1c16I86nZHH2WievEiRMN7Mi/qX80yznIJBMP6ufLL7+sxH5P5eD3ZyE0Z401iuNd38mzzz7bNL6I6+23367PZtqnePPNmze3LKv6RQjzpnJxOIAZ83HwIvogMwEs/iqzd2AbCIAACICArwQgzPuKE5WBAAiAgJQEIMxL6RZXRk2aNCmtEEZi+6uvvspJNLNKdHzIkCE8Ly9Pq4dCgowcOdIqe9bH8vPz+fnnn6/VTQJdEKK8apiVOP/www+rpx1vZeI6YMAAAz9V5Ezdrl271nH/ss0oE4833njDMF5TOfj1mRZ9/fbbb7NCFdXxrnZSxNZXFry1Ykhhk44cOaJmtdyKNwV45cqV047XW2+9lR84cMCyfFQPQpiPqudgd1QIQJiPiqdgJwiAAAiAgGcCEOY9I0QFIAACICA9AQjz0rvIsYGzZ89OK4KpAht9vw8cOJBPnz6di3Ay/OOPP+YjRozg5513nqF8rVq1+MqVKx23nynj1KlTDfUHKcqrtqSK8+XKlVNPOd7KxJUerKh+TLelWeRBJVl4iJAhnB4cpePg17mmTZtmjTOq412ESeIvvPACF6Gw0rLt2rVrxrdwRo0albYO8s8ZZ5zBn3/+eX7w4MGsGctYAMK8jF6BTXEiAGE+Tt5EX0AABEAABNISgDCf
Fg9OggAIgEAsCECYj4UblU7s37+fV6tWLaMQlkmsJEFu9OjR/NChQ77CoRnH6mz8KlWq8G+++cbX+u0q04vzNN6zTTJxFTHmNYbp/Dho0KBsu+k4vww81q1bx4sVK+Z5rKdjqD/3+uuvO+ajZozaeO/WrRuvXr16VmGBihcvzuvWrctbtWrFxRoUate5iEHPb7rpJl6mTBnHPiJ/1qtXj7/55ptaPVHcgTAfRa/B5igRwOKv4tsJCQRAAARAIBkEsPhrMvyMXoIACCSbABZ/jdfir99//z3r0qULE+IQTSpzPLhLlizJaLHMu+66izVr1oyJ0B2Oy2aTccWKFWzNmjXsyiuvZLSYZq7SokWLGLG5+uqrmYjDn3WzMnGdO3cua9OmDRPxy2378a9//cu0GKptZhcnwubxwQcfsMsvv9yF5dkXqVChAtuyZUvWi75SS1Ea7+eeey774osvsgf0Z4m+ffuywYMHK5/sFo11UjktXt2pUycnWaXMg8VfpXQLjIoRAQjzMXImugICIAACIJCeAIT59HxwFgRAAATiQADCfLyEeXVMfvfdd2z8+PFs/fr1bNu2bcqfiPnMxEKwTIRzUf5IGBezXRmNgTp16jAxm10tjq0NAVm4ivjlTIQhYmLmOBPhXJiYbcxuueUW9ttvvymWz58/X/GrTTd8OywLD986hIpAwCMBCPMeAaI4CGQgAGE+AyCcBgEQAAEQiA8BCPPx8SV6AgIgAAJ2BCDMx1OYt/M3jseTgFi4lx133HGMBHtKIk648jmevUWvQEBeAhDm5fUNLIsHAQjz8fAjegECIAACIOCAAIR5B5CQBQRAAAQiTgDCPIT5iA9hmC8IzJgxg7Vo0UJhIeLoswULFoALCIBACAQgzIcAHU0migCE+US5G50FARAAgWQTgDCfbP+j9yAAAskgAGEewnwyRnp8eykW6WXnnHMOE4uNKp1cvHgxa9CgQXw7jJ6BgMQEIMxL7ByYFgsCEOZj4UZ0AgRAAARAwAkBCPNOKCEPCIAACESbAIR5CPPRHsHJtp5C2Nx8881s8uTJCohWrVqxqVOnJhsKeg8CIRKAMB8ifDSdCAIQ5hPhZnQSBEAABECACECYxzgAARAAgfgTgDAPYT7+ozyePSRRvnPnzmzcuHFKB6tVq8YWLlzITjrppHh2GL0CgQgQgDAfASfBxEgTgDAfaffBeBAAARAAgWwIQJjPhhbyggAIgEA0CUCYhzAfzZGbbKt37typzJQnEZBS5cqVFVG+SpUqyQaD3oNAyAQgzIfsADQfewIQ5mPvYnQQBEAABEBAJQBhXiWBLQiAAAjElwCEeQjz8R3d8ezZhg0bWOPGjdnmzZuVDlavXl1Z/PWMM86IZ4fRKxCIEAEI8xFyFkyNJAEI85F0G4wGARAAARBwQwDCvBtqKAMCIAAC0SIAYR7CfLRGLKxdu3atsthrXl4ee/DBB1n//v3ZscceCzAgAAISEIAwL4ETYEKsCUCYj7V70TkQAAEQAAE9AQjzehrYBwEQAIF4EoAwD2E+niMbvQIBEACB3BOAMJ975mgxWQQgzCfL3+gtCIAACCSaAIT5RLsfnQcBEEgIAQjzEOYTMtTRTRAAARAInACE+cARo4GEE4Awn/ABgO6DAAiAQJIIQJhPkrfRVxAAgaQSgDAPYT6pYx/9BgEQAAG/CUCY95so6gMBIwEI80Ye+AQCIAACIBBjAhDmY+xcdA0EQAAE/iQAYR7CPP4ZQAAEQAAE/CEAYd4fjqgFBOwIQJi3I4PjIAACIAACsSMAYT52LkWHQAAEQMBEAMI8hHnToMABEAABEAABVwQgzLvChkIg4JgAhHnHqJARBEAABEAg6gQgzEfdg7AfBEAABDITgDAPYT7zKEEOEAABEAABJwQgzDuhhDwg4J4AhHn37FASBEAABEAgYgQgzEfMYTAXBEAABFwQgDAPYd7FsEEREAABEAABCwIQ5i2g4BAI+EgAwryPMFEVCIAACICA3AQgzMvtH1gHAiAAAn4QgDAPYd6PcYQ6QAAEQAAEGIMwj1EAAsESgDAfLF/UDgIgAAIgIBEBCPMSOQOmgAAIgEBABCDMQ5gPaGihWhAAARBIHAEI84lzOTqcYwIQ5nMMHM2BAAiAAAiERwDCfHjs0TIIgAAI5IoAhHkI87kaa2gHBEAABOJOAMJ83D2M/oVNAMJ82B5A+yAAAiAAAjkjAGE+Z6jREAiAAAiERgDCPIT50AYfGgYBEACBmBGAMB8zh6I70hGAMC+dS2AQCIAACIBAUAQgzAdFFvWCAAiAgDwEIMxDmJdnNMISEAABEIg2AQjz0fYfrJefAIR5+X0EC0EABEAABHwiAGHeJ5CoBgRAAAQkJgBhHsK8xMMTpoEACIBApAhAmI+Uu2BsBAlAmI+g02AyCIAACICAOwIQ5t1xQykQAAEQiBIBCPMQ5qM0XmErCIAACMhMAMK8zN6BbXEgAGE+Dl5EH0AABEAABBwRgDDvCBMygQAIgECkCSRZmC9cuDBr06aNK/+NHz+e5eXluSqLQiAAAiAAAnIT2LlzJ+vTpw87ePBgVoauXbuWLV++3FAmPz+flShRwnAMH0AABNwRgDDvjhtKgQAIgAAIRJAAhPkIOg0mgwAIgECWBJIszGeJypD9999/ZyTsI4EACIAACMSPwP/93/+xdu3a+dIxCPO+YEQlIKAQgDCPgQACIAACIJAYAhDmE+NqdBQEQCDBBCDMu3M+hHl33FAKBEAABKJAAMJ8FLwEG5NIAMJ8Er2OPoMACIBAQglAmE+o49FtEACBRBGAMO/O3RDm3XFDKRAAARCIAgEI81HwEmxMIgEI80n0OvoMAiAAAgklAGE+oY5Ht0EABBJFIEnC/OrVq9ljjz3Gjhw54tnH06ZNQ4x5zxRRAQiAAAjISWDBggVs1KhRvhg3adIkduyxx/pSFyoBgaQTgDCf9BGA/oMACIBAgghAmE+Qs9FVEACBxBJIkjCfWCej4yAAAiAAAiAAAiAQAwIQ5mPgRHQBBEAABEDAGQEI8844IRcIgAAIRJkAhPkoew+2gwAIgAAIgAAIgEByCECYT46v0VMQAAEQSDwBCPOJHwIAAAIgkAACEOYT4GR0EQRAAARAAARAAARiQADCfAyciC6AAAiAAAg4IwBh3hkn5AIBEACBKBOAMB9l78F2EAABEAABEAABEEgOAQjzyfE1egoCIAACiScAYT7xQwAAQAAEEkAAwnwCnIwuggAIgAAIgAAIgEAMCECYj4ET0QUQAAEQAAFnBCDMO+OEXCAAAiAQZQIQ5qPsPdgOAiAAAiAAAiAAAskhAGE+Ob5GT0EABEAg8QQgzCd+CAAACIBAAghAmE+Ak9FFEAABEAABEAABEIgBAQjzMXAiugACIAACIOCMAIR5Z5yQCwRAAASiTADCfJS9B9tBAARAAARAAARAIDkEIMwnx9foKQiAAAgkngCE+cQPAQAAARBIAAEI8wlwMroIAiAAAiAAAiAAAjEgAGE+Bk5EF0AABEAABJwRgDDvjBNygQAIgECUCUCYj7L
3YDsIgAAIgAAIgAAIJIcAhPnk+Bo9BQEQAIHEE4Awn/ghAAAgAAIJIABhPgFORhdBAARAAARAAARAIAYEIMzHwInoAgiAAAiAgDMCEOadcUIuEAABEIgyAQjzUfYebAcBEAABEAABEACB5BCAMJ8cX6OnIAACIJB4AhDmEz8EAAAEQCABBCDMJ8DJ6CIIgAAIgAAIgAAIxIAAhPkYOBFdAAEQAAEQcEYAwrwzTsgFAiAAAlEmAGE+yt6D7SAAAiAAAiAAAiCQHAIQ5pPja/QUBEAABBJPAMJ84ocAAIAACCSAAIT5BDgZXQQBEAABEAABEACBGBCAMB8DJ6ILIAACIAACzghAmHfGCblAAARAIMoEIMxH2XuwHQRAAARAAARAAASSQwDCfHJ8jZ6CAAiAQOIJQJhP/BAAABAAgQQQgDCfACejiyAAAiAAAiAAAiAQAwIQ5mPgRHQBBEAABEDAGQEI8844IRcIgAAIRJkAhPkoew+2gwAIgAAIgAAIgEByCECYT46v0VMQAAEQSDwBCPOJHwIAAAIgkAACEOYT4GR0EQRAAARAAARAAARiQADCfAyciC6AAAiAAAg4IwBh3hkn5AIBEACBKBOAMB9l78F2EAABEAABEAABEEgOAQjzyfE1egoCIAACiScAYT7xQwAAQAAEEkAAwnwCnIwuggAIgAAIgAAIgEAMCECYj4ET0QUQAAEQAAFnBCDMO+OEXCAAAiAQZQIQ5qPsPdgOAiAAAiAAAiAAAskhAGE+Ob5GT0EABEAg8QQgzCd+CAAACIBAAghAmE+Ak9FFEAABEAABEAABEIgBAQjzMXAiugACIAACIOCMAIR5Z5yQCwRAAASiTCBJwvyWLVvYihUrfHFXixYtWIECBXypC5WAAAiAAAjIReCHH35gy5cv98Woq666iuXl5flSFyoBgaQTgDCf9BGA/oMACIBAgghcf/317Mcff2Tly5dn06dPT1DP0VUQAAEQSA6BDh06sPXr17OiRYuyuXPnxrrj48ePZx07dvSlj7///jsrXLiwL3WhEhAAARAAAbkI/N///R9r166dL0bl5+ezEiVK+FIXKgGBpBOAMJ/0EYD+gwAIgECCCGDGfIKcja6CAAgklkCSZsxDmE/sMEfHQQAEQCArAhDms8KFzCCQMwIQ5nOGGg2BAAiAAAiETQDCfNgeQPsgAAIgEDwBCPPuGGPGvDtuKAUCIAACUSAAYT4KXoKNSSQAYT6JXkefQQAEQCChBCDMJ9Tx6DYIgECiCECYd+duCPPuuKEUCIAACESBAIT5KHgJNiaRAIT5JHodfQYBEACBhBKAMJ9Qx6PbIAACiSKQZGG+VKlSbNeuXa78jYX8XGFDIRAAARCIBAHOOTt69GjWtr733nuMFnvVJ8SY19PAPgh4IwBh3hs/lAYBEAABEIgQAQjzEXIWTAUBEAABlwSSLszv2bPHJTkUAwEQAAEQAAEjgffff59dccUVhoMQ5g048AEEPBGAMO8JHwqDAAiAAAhEiQCE+Sh5C7aCAAiAgDsCEOYhzLsbOSgFAiAAAiCQSgDCfCoRfAYBfwlAmPeXJ2oDARAAARCQmACEeYmdA9NAAARAwCcCEOYhzPs0lFANCIAACCSeAIT5xA8BAAiYAIT5gAGjehAAARAAAXkIQJiXxxewBARAAASCIgBhHsJ8UGML9YIACIBA0ghAmE+ax9HfXBOAMJ9r4mgPBEAABEAgNAIQ5kNDj4ZBAARAIGcEIMxDmM/ZYENDIAACIBBzAhDmY+5gdC90AhDmQ3cBDAABEAABEMgVAQjzuSKNdkAABEAgPAIQ5iHMhzf60DIIgAAIxIsAhPl4+RO9kY8AhHn5fAKLQAAEQAAEAiIAYT4gsKgWBEAABCQiAGEewrxEwxGmgAAIgECkCUCYj7T7YHwECECYj4CTYCIIgAAIgIA/BCDM+8MRtYAACICAzAQgzEOYl3l8wjYQAAEQiBIBCPNR8hZsjSIBCPNR9BpsBgEQAAEQcEUAwrwrbCgEAiAAApEiAGEewnykBiyMBQEQAAGJCUCYl9g5MC0WBCDMx8KN6AQIgAAIgIATAhDmnVBCHhAAARCINgEI88kU5rdt28Y2bdrEfv75Z7Z//35Wvnx5dsopp7DKlSuzokWLRntQh2g9uLqDf+TIEbZ69Wr21Vdfse+//97w1759e/bII4+4q9hhqY0bN7JFixZpuf/xj3+wKlWqaJ+xw9iqVavYO++8w2bPns1at27N7rvvPmCxIABh3gIKDoGAjwQgzPsIE1WBAAiAAAjITQDCvNz+gXUgAAIg4AcBCPPJEeZJfHz22WfZrFmzFAHUbvyQKNmxY0d24403suLFi9tlw/E/CYCru6GwYsUKRiLmggUL2Mcff8z27t1rqqhgwYLsoYceYkOHDjWd8+vAoUOHWP369ZUHA2qdkyZNUsRn9XMSt/n5+ezDDz9UhPiZM2eyLVu2aBjIJ8OGDdM+Y+cvAhDm/2KBPRAIggCE+SCook4QAAEQAAEpCUCYl9ItMAoEQAAEfCUAYT7+wvzu3btZz5492YQJE9jhw4cdj58TTjiBjRs3jl1zzTWOyyQpI7hm720SwSdPnqw8IFq2bJllBbVr12bNmzdn9IDooosuYqVLl7bM59dB+t948sknDdUlUZg/evQoW758Ofvggw/Ye++9x5YsWcL++OMPAxf1A4R5lYR5C2HezARHQMBPAhDm/aSJukAABEAABKQmAGFeavfAOBAAARDwhQCE+XgL8ySutWnThm3evNn1eOndu3egM5ZdGxZiQXDNDj49EHr66afZE088wXbs2GEqTG9mtG3blnXu3JnVq1fPdD6oA3PnzmVNmzZlnHNDE0kU5i+44AL26aefGjjYfYAwb0eGKW+BXHHFFYYM9PZBiRIlDMfwAQRAwB0BCPPuuKEUCIAACIBABAlAmI+g02AyCIAACGRJAMJ8fIV5mvVKs91plrLXNHbsWNapUyev1cSiPLhm58bFixezLl26GELFqDUULlyYdevWTYkhH/TMeLVNdbtr1y5Wq1Yt9sMPP6iHtG0ShXny0Zo1a1iRIkWUGP8bNmzQeKTuQJhPJfLXZ8yY/4sF9kAgCAIQ5oOgijpBAARAAASkJABhXkq3wCgQAAEQ8JUAhPl4CvMUkoJCgfz222+G8ULHSIw8/fTTGcXvpsU2adFNmgGeOmtYX5DEuvXr1ysLxOqPJ20fXJ17/ODBg6x79+7s+eeftxxbl19+uTKLnu43w0g33XQTe/311y2bTqIwrwdB14J+/fqxIUOG6A9r+xDmNRSmHQjzJiQ4AAK+EoAw7ytOVAYCIAACICAzAQjzMnsHtoEACICAPwQgzMdPmCcx/vzzz2dff/21NkgoZveAAQNsw4QsXLhQCSOiL6MV/nOnT58+tkJdat44fgZX51796aef2LXXXss++eQTUyF6IDR8+HDWo0cP07lcHXj55ZdZhw4dbJtLujBPYPbt28fKly9verhH5yDMEwXrBGHemguOgoBfBCDM+0US9YAACIAACEhPAMK89C6CgS
AAAiDgmQCE+fgJ8xTHm0R0Snl5ecoCrrfddlvGsUIhbyge/fTp0y3znnTSSWzbtm2W55JwEFydeZnCobRo0YJt2rTJVIDC1ZDofeWVV5rO5eoAhWg577zzGMX9tksQ5v9H5uyzzzY84FN5QZhXSZi3EObNTHAEBPwkAGHeT5qoCwRAAARAQGoCEOaldg+MAwEQAAFfCECYj5cwT2Ljaaedxnbu3MkofvfEiRPZjTfe6His/PLLL6xmzZqMtlZpz549rFSpUlanYn0MXJ2598MPP1RmyluJ3mXKlGELFixQxpez2vzPdeTIEdaoUSO2aNGitJVDmP8fnosuukgJc5UKC8J8KpG/PkOY/4sF9kAgCAIQ5oOgijpBAARAAASkJABhXkq3wCgQAAEQ8JUAhPl4CfNjxoxhXbt2VcYIheu49dZbsx4vr7zyim25lStXstq1a2ddZ9QLgGtmD65du5aRkEsPb1LTscceyz744APWsGHD1FM5/Txo0CBloVlqtFixYqxp06bsrbfeMtkAYf5/SC655BI2f/58Ex8I8yYk2gEI8xoK7IBAIAQgzAeCFZWCAAiAAAjISADCvIxegU0gAAIg4C8BCPNmEdFfwrmtjUJ0rFq1Spm1/Oabb7pq/Oeff2YVKlSwLEsxwy+44ALLc3E+CK7pvUtvWNC42LhxoyljgQIFlEVWW7VqZTqXywOffvop+/vf/84OHz6sNPvvf/+bff/990q8+1Q7IMz/jwiE+dSRkfkzhPnMjJADBLwQgDDvhR7KggAIgAAIRIoAhPlIuQvGggAIgIArAhDm4yPMr169WpnNfsIJJ7Avv/zSVlx3MlAonvyPP/5oyrpr1y52/PHHm47H+QC4pvcurU1w6aWXssWLF1tmpDc4nnnmGctzuTq4f/9+VqdOHfbNN98oTdJiyDNnzmS9evWCMJ/GCRDm08CxOQVh3gYMDoOATwQgzPsEEtWAAAiAAAjITwDCvPw+goUgAAIg4JUAhPn4CPMUQmTy5MnszDPPVOJoexkbVMf69esNVVSqVIlt3brVcCwJH8A1vZf79evHBg8ebJnp5JNPZhTipmTJkpbnc3Wwc+fO7IUXXlCaK1u2LKMFak888UQI8xkcAGE+AyCL0xDmLaDgEAj4SADCvI8wURUIgAAIgIDcBCDMy+0fWAcCIAACfhCAMB8fYd6P8UB1/Pbbb4qQSgtl6tNVV13FZsyYoT+U1T6FECGBdPny5eyuu+5i99xzT1bl3Wb+9ddfWZcuXRSB+LHHHlPC/Lity0u5oLh6sclr2Q0bNrAaNWowmjVvlSh+e8uWLa1O5ezY9OnT2XXXXae1RyGerr32WuVzkDPm4zDeIcxrw8bxDoR5x6iQEQRcEYAw7wobCoEACIAACESRAIT5KHoNNoMACIBAdgQgzEOYTx0xS5YsURbxTD1OghMtluk2vfvuu4zEfTUNHDiQ0WzrINPOnTtZkyZNlLj71E7VqlXZd999F2STtnUHxdW2wRycIMGbhG+r1KBBA9vwNlb5gzhG4ZjOPfdctmPHDqX6jh07snHjxmlNBSnMx2G8Q5jXhorjHQjzjlEhIwi4IgBh3hU2FAIBEAABEIgiAQjzUfQabAYBEACB7AhAmIcwnzpiKCxJqmBOC3dOnTo1NWtWn5cuXcpIrOWca+Uef/xx9sgjj2if/dxJFeWpbhJpKWZ8GCkormH0hdqcM2dO2gc1r732GmvTpk1Y5injjGLJz549W7GBHsrQwsglSpTQbApSmI/DeHcjzNObIb///rvG2OkO+SUvL89pdmnzQZiX1jUwLC4ExE0EEgiAAAiAAAgkgoCIL0u/nHm1atUS0V90EgRAAASSSOAf//iHcq0XMaBj3/0XX3xR6St9t9FfqVKlYt/nbDsoQpPw4sWLGzgJMZvn5+dnW5Vl/tGjRxvqJj+I8DKWeb0cFDOkee3atQ1tidjiXIjyXqp1XTZorq4N81CwWbNmBr7q/xVtK1asyIU466F270XFgrOafULw5YsWLTJV+tBDD2l59PZPmjTJlNfNgaiP90aNGlnyIW526fbbb7cso+drtd+uXTu7KiN1/L333jP136/rZ6RAwFgQCIgAZsyLKygSCIAACIBAMghgxnwy/IxeggAIJJsAZsxjxrz6HyB+QythXz766CP1EKtfvz6jkBwnnHCCdszrzqhRo1iPHj0M1QwYMIA9+uijhmNuP1jNlKcFPz/88ENlxrzbet2WyxVXt/a5KUchYmhh19R1CNS6+vbtqy0Ie/ToUfbFF18wIYwzKrd3715lZnT58uVZhQoVWMOGDZmYBKIW9WVLC87WrVuXHTx4UKnv4YcfZoMGDTLVHeSMebWxKI93NzPmKbTRwoUL2caNG5kQqdmBAwdUFKZtuXLllPBS9DYDXWu6d+9uyhO1A5gxHzWPwd7IEQhI8Ee1IAACIAACICAdAcyYl84lMAgEQAAEfCeAGfO+I41shU888YRhpqcIQ8L37dsXSH9GjBhhaEsIA1wI857bkm2mPHUol1w9A3RYwciRI03+Ix+qf/PmzVPesqAZ46eddpp2XD2v3xYoUIDTzOyXXnqJ79+/36EF9tnEQrT8vPPO09oUAr3t7P2gZ8yrVkZ1vLuZMa/2mbZDhgzR/KD3uVirgosHZfqssdnHjPnYuBIdkZQAxSlDAgEQAAEQAIFEEIAwnwg3o5MgAAIJJwBhPuEDQHRfzHrmqQLlGWecwffs2RMonOHDh5tEu/79+7tuUzZRPiyurgFmUbBWrVom36nC6zHHHMOfffZZLmbD2+ZR86ZuKaTWmDFjsrDEnPXBBx/U2i1atCj/6quvzJn+PJI67lV7/Aplo284iuPdizAv3lzh1atX13xBbMuUKcPF2gN6LLHbhzAfO5eiQ5IRgDAvmUNgDgiAAAiAQHAEIMwHxxY1gwAIgIAsBCDMy+KJcOz48ssvuViQ1SCeqeJk6dKl+dChQ7lYzDEw44YNG2ZqWywGm3V7sonyYXPNGmAWBTZv3mzymTpm/NqKsDNZWPRXVpqFXbBgQc0+ekCQLuVSmCc7ojbe3Qrzu3fv5hdeeKHmBxoXV155Jd++fXs6d8TiHIT5WLgRnZCYAIR5iZ0D00AABEAABPwlAGHeX56oDQRAAARkJABhXkavBG+TiPnN27dvzwsVKmQQz6yE1Tp16vCtW7cGZhSJ/6ntipjgjtuTSZSXiatjgFlmfPPNN03+SvUffabFle+++27+9NNP83feeYfPmjWLz5gxg4v1BPjpp5+esY6ePXtmZdmuXbt45cqVtXpJCM6Uci3Mkz1RGu9uhHl6cFOzZk3NDxSqiP6f6Q2SJCQI80nwMvoYJgEI82HSR9sgAAIgAAI5JQBhPqe40RgIgAAIhEIAwnwo2ENplOJ3T5gwgTdu3JiTWGYlptodq1ixIl+xYkVgdlvFohYLiGZsTwZRXmauGQG6yNCvX7+0Y4cE+ccff5zTrOl06b///S8vXrx42rpI1HeaWrdurdUlFivm27Zty1g0DGGejIrKeM9WmF+9ejWvVKmS5gcKT
fTGG29k9EOcMkCYj5M30RcZCUCYl9ErsAkEQAAEQCAQAhDmA8GKSkEABEBAKgIQ5qVyR6DGnHTSSZpgZifApztO4vzPP/8cmI2DBw822denTx/b9mQQ5ck42bnaAnR5onnz5iY/qePm2GOP5UuWLHFc88qVK5WZ9Wr51C2JvLSYa6b0yiuvGGyaNm1apiLK+bCEeWo8CuM9G2H+o48+4hT+SvUhXS9WrVrlyA9xygRhPk7eRF9kJABhXkavwCYQAAEQAIFACECYDwQrKgUBEAABqQhAmJfKHYEaQ4tiXn/99bxly5bKrPnatWvzEiVKaEKaKqil27Zo0SJQGwcOHGiyp3fv3qY2ZRHlybAocDUB9HCAQhvZjZHOnTtnXTOFuLGrj44///zzaevcuHGjQdzv0KFD2vz6k2EK82SH7OPdqTBPi+UWKVJE82ONGjU4hbRJYoIwn0Svo8+5JABhPpe00RYIgAAIgECoBCDMh4ofjYMACIBATghAmM8JZmkbOXr0KKdZy/fff79BWEsnlC77f/bOA3xqYmvjQ5XelCKKCCifoHRQ4VooooJX0YsXsWDBTxQVLoIoRZCuFBG7CAgWxC4qCCpSBQuiCIKASG+iiHSkON+c9U6+1N1sNtmdJO88DyQ7mTkz8zv5Z7NvJmcWLw50PBQGxdz+gw8+qLXpJMqrNDtXRa4awAx3qlSpYvGP9NesWbM8Wb/gggscbVarVo0fO3bM1i7l6+tS3/bu3Wtb1i4z18I89Unl892NMD969GhDaCz6TqF4/3FNEObj6nmMO1sEIMxnizTaAQEQAAEQyDkBCPM5dwE6AAIgAAKBE4AwHzji0DSwYcOGxIKwqeLPX3/99YGPaeDAgRahlkTUMIjyZjgqcTX3zcvnUqVKWXwjhXnyj5f05ptvOtok2zSr3i7pw8HkzZuXf/7553bFHPNUEOapc6qe78mEeXr41L17d4Pf2rRpww8fPuzIOw4HIMzHwcsYYy4J5KHGxRcDEgiAAAiAAAhEnsBZZ53FVq9ezcRMJbZ27drIjxcDBAEQAIE4ErjooovYggULmFikj4mZppFGMHHiRNaxY0dtjGKRSrZnzx7tM3b+JvD888+zzp07O+IQcaTZrl27WL58+RzL+HFgwIABTAiWSU2ddNJJ7LPPPmO1a9dOWk6Fg6pwzZQF+V2IshYzBQsWZCIevCXfTYaY+c7EbHe2ZcsW2+Jjxoxh//nPfwzHvvnmG9akSRN29OjRRL5YLJgJod5QJtUH8SYGGzFihKWYCM3CxGKylvwgM1Q835s2bcrmzZtnGXa3bt3Yjh07GHGSqVatWkwsEB34dUG2p+r2k08+YZdddpmhe/v27WMibJghDx9AAAS8EYAw740baoEACIAACISQAIT5EDoNXQYBEACBNAlAmIcwb3fKkFhJoqVTIlG0QYMGTod9y3/44YeZCPVhay9MorwcgCpcZX+8bMUCr7YCfOnSpZkIYeLFZKLOsGHDWN++fW3rd+3alT3xxBPasYMHD7L69esnJpBQJu1/+eWXrECBAloZNzsqCfPUX9XOdydh3o5tkSJF2KJFi5hYu8LucGzyIMzHxtUYaI4IQJjPEXg0CwIgAAIgkH0CEOazzxwtggAIgEC2CUCYhzDvdM61b9+evfHGG7aH6e0Dscim7TG/M2+//Xb24osvWsyKmPKhmClv7rgqXM39cvv51FNPZVu3brUUpzcp/vjjD0u+24zZs2ezFi1a2Bb/5z//yT788EPt2F133cXGjh2b+Fy4cGG2ZMkSVqNGDe242x3VhHnqt0rnezrCPPW9atWqjB7a0UOauCYI83H1PMadLQIQ5rNFGu2AAAiAAAjknACE+Zy7AB0AARAAgcAJQJiHMO90ki1fvjwx+9UumuvIkSPZ/fff71TVt/xt27YxEgd/+ukni00RH5wNHz7ckq96hgpcM2FEM6KXLVtmMZE/f34trIzloIuMNWvWMLG+kW3JmjVrshUrViSOzZw5k7Vq1Uor165dO9alSxftczo7zz33HHvttdcsVSisjN1DAnpboGHDhpbyfmWodr6nK8wTB/LNtGnTmIj57xeWUNmBMB8qd6GzISQAYT6ETkOXQQAEQAAEvBGAMO+NG2qBAAiAQJgIQJiHMJ/sfL3qqqsMM5Vl2d69ezMKPRJk2r59e0KUJ8HWKfXq1Ys98sgjToeVzc8l10yhNG/enM2ZM8fWDMWYp1jzXtKBAwcc43AXLVqU7d+/P2F2yJAhrF+/fl6a8KUOxVyn66bfScXz3UmYv+OOO5hY5JRt2rTJFgP5xykElW2FCGVCmI+QMzEUJQlAmFfSLegUCIAACIBAEAQgzAdBFTZBAARAQC0CEOYhzCc7IydPnsxuuukmSxEvC21ajCTJIJGyWbNmWgxxKnrKKaewCy+80LDgJOWHUZzPFVfilWm64YYb2JQpU2zNkN8qVKhge8xNZpkyZdju3bstRfWLU+damJ87dy67+OKLLX3MJEPV891JmKe3Va655prEAwq5+K5+/Hny5GHvv/8+u/LKK/XZsdiHMB8LN2OQuSQgXuNDAgEQAAEQAIFYEBCvE3PxncurVasWi/FikCAAAiAQRwJC6Exc64XwFfnhizjlibHSdxv9K1GiROTHnOkAv/rqKwMzyU7MUs/UtGN9IVJyMTnA0K4Q5bkIZ5OoI2bjGo5Rn4Q472hPxQO54OoXh8cff9zCX54X8+fPz6iZc845x9Z2lSpVNLuDBw+2LSP7EPRWCPNaX/zYUfl8Fw8gbFkLYT4x9FGjRtkeJx+INQe4eNvFD0ShsiHeJLAw2bdvX6jGgM6CgMoEmMqdQ99AAARAAARAwE8CEOb9pAlbIAACIKAmAQjzavpFlV79/vvvFpGJRLc333wzkC7u2LGDi0U8DW3qRXnZqJ04LxbylIeV32abq59AnB4q0Hkxfvz4jJpq1KiRwfdSZKd8maIkzKt+vqcS5v/66y8uZsXb+ox8Rw9aRAgi6bpYbCHMx8LNGGQOCSCUjbi6IoEACIAACMSDAELZxMPPGCUIgEC8CSCUDULZJPsLEDM9mXizwFKEFmM944wzLPmZZPzyyy+J8DU//vijZobC11DoELu2+vfvz4RIq5WlnbAsCJtNrgZAPnyg0CViNjQ7dOiQxRotCEwLA3tNp59+Otu4caOlOsXkp9AolFauXMkolJLQhSzl0s2g81h/vsn6DRo0SIROkp/1W1pb4eyzz9ZnedoPw/meLJSNXHhZPGRi9erVc4w3f91111nCT3kCFpJKCGUTEkehm+ElkMOHAmgaBEAABEAABLJKADPms4objYEACIBATghgxnxOsIem0S+++MIyG7Z06dKcZsr6mYRIyWvWrGloy26mvLnNhx56yFBHKA28Z8+e5mLKfc4W16AG7jSTumHDhhk1ecIJJ1j8ST4VInBGdp0qU0gmsm/+9/rrrztV8SU/LOe7k59lKBsJY9GiRTx//vwWjpIrhbyJS8KM+bh4GuPMFQGEsskVebQLAiAAAiCQdQIQ5rOOHA2CAAiAQNYJQJjP
<base64 PNG data elided>)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CQvqUPVWj6ew",
+ "papermill": {
+ "duration": 0.029146,
+ "end_time": "2020-12-01T01:33:27.827298",
+ "exception": false,
+ "start_time": "2020-12-01T01:33:27.798152",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "This is the basic structure of Wide Residual Networks. In the paper the size of `conv1` was fixed across all experiments, while the widening factor `k` was varied in the next three groups; `k` multiplies the number of feature maps in the convolutional layers.\n",
+ "\n",
+ "Let B(M) denote a residual block structure, where M is the list of kernel sizes of the convolutional layers in the block.\n",
+ "The following block structures were used in the experiments (a code sketch of the basic block follows below):\n",
+ "\n",
+ "* B(3,3) - The original \"basic\" block (Figure 1(a))\n",
+ "* B(3,1,3) - Same as basic, but with an extra 1x1 layer in between\n",
+ "* B(1,3,1) - The bottleneck block (Figure 1(b))\n",
+ "* B(1,3) - Alternating 1x1-3x3 convolutions\n",
+ "* B(3,1) - Alternating 3x3-1x1 convolutions\n",
+ "* B(3,1,1) - A Network-in-Network style block"
+ ]
+ },
(Figure 1(a))\n", + "* B(3,1,3) - Same as basic but with a extra 1x1 layer in between\n", + "* B(1,3,1) - For Bottleneck (Figure 1(b))\n", + "* B(1,3) - Having Alternative 1x1-3x3 convolutions\n", + "* B(3,1) - Having Alternative 3x3-1x1 convolutions\n", + "* B(3,1,1) - A Network-in-Network style block" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DyOIi4wmj6ew", + "papermill": { + "duration": 0.029591, + "end_time": "2020-12-01T01:33:27.886716", + "exception": false, + "start_time": "2020-12-01T01:33:27.857125", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Experimental Results" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RA8b4sDzlbJk" + }, + "source": [ + "![Screenshot 2020-12-01 at 10.05.33 AM.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABkAAAAJcCAYAAAChV2hUAAAMbGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkJCEEghFSuhNEOlFSggtgoBUwUZIAgklxoSgYkcXFVy7iGJFV0VcdHUFZFERda2LYnctiwUVZV0sKIrKmxTQdV/53vm+ufPfM2f+U+7MvXcA0OnjSaX5qC4ABZJCWWJUGGtcegaL9AhoAyOgA2gA4fHlUnZCQiyAMtj/Xd5eB4iyv+Kq5Prn+H8VfYFQzgcAmQBxlkDOL4C4BQB8I18qKwSAqNTbTCuUKvE8iA1kMECI1yhxjhrvVuIsNW5W2SQnciC+BIAWlceT5QBAvwP1rCJ+DuShf4TYXSIQSwDQGQ5xMF/EE0CsjH14QcEUJa6E2BHaSyGG8QC/rK84c/7GnzXEz+PlDGF1XirRChfLpfm8Gf9naf63FOQrBn3Yw0YVyaITlfnDGt7MmxKjxFSIuyVZcfHKWkPcJxao6w4AShEpolPU9qgZX86B9QNMiN0FvPAYiM0gjpTkx8Vq9FnZ4kguxHC1oNPFhdxkiI0hXiyURyRpbLbKpiRqfKH12TIOW6M/w5Op/Cp93VPkpbA1/K9EQq6GH6MXi5LTIKZAbFskTo2DmA6xmzwvKUZjM6pYxIkbtJEpEpXx20KcKJREhan5saJsWWSixr6sQD6YL7ZVJObGafCBQlFytLo+2Ek+TxU/zAW7JJSwUwZ5hPJxsYO5CIThEercsadCSUqShqdPWhiWqJ6LU6T5CRp73FqYH6XUW0PsJS9K0szFUwvh4lTz49nSwoRkdZx4cS5vdII6HnwFiAUcEA5YQAFbFpgCcoG4rbuhG96pRyIBD8hADhACV41mcEaaakQCr0mgGPwJkRDIh+aFqUaFoAjqPw1p1VdXkK0aLVLNyAOPIS4AMSAf3itUsyRD3lLBI6gR/8M7DzY+jDcfNuX4v9cPar9o2FATq9EoBj2ydAYtiRHEcGI0MZLohJviwXggHguvobB54H64/2AeX+wJjwnthAeEa4QOwq3J4hLZN1GOAR2QP1JTi6yva4HbQ05vPAwPguyQGWfipsAV94J+2HgI9OwNtRxN3MqqsL7h/lsGXz0NjR3ZnYySjcihZMdvZ9Kd6d5DLMpaf10fdaxZQ/XmDI1865/zVfUFsI/51hJbjB3ETmPHsbNYM9YAWNgxrBG7gB1R4qHV9Ui1uga9JariyYM84n/442l8Kispd69173L/qB4rFE4vVG48zhTpDJk4R1TIYsOvg5DFlfDdhrM83D3cAVB+a9Svr9dM1TcEYZ77oispAiDIaWBgoPmLLtYfgJ8b4fbv+qJzhO8+uiUAZxbzFbIitQ5XXgjwLaEDd5oJsAA2wBHm4wF8QCAIBRFgNIgHySAdTIJVFsF1LgPTwCwwH5SCcrACrAUbwBawHewGP4IDoAE0g+PgV3AeXALXwG24ejrBc9AD3oJ+BEFICA1hICaIJWKHuCAeiB8SjEQgsUgiko5kIjmIBFEgs5AFSDmyCtmAbENqkJ+Qw8hx5CzSjtxC7iNdyCvkA4qhVNQANUft0RGoH8pGY9BkdCKag05Fi9GF6DK0Eq1G96L16HH0PHoN7UCfo70YwLQxJmaFuWJ+GAeLxzKwbEyGzcHKsAqsGqvDmuBzvoJ1YN3Ye5yIM3AW7gpXcDSegvPxqfgcfCm+Ad+N1+Mn8Sv4fbwH/0ygEcwILoQAApcwjpBDmEYoJVQQdhIOEU7BvdRJeEskEplEB6Iv3IvpxFziTOJS4ibiPmILsZ34kNhLIpFMSC6kIFI8iUcqJJWS1pP2ko6RLpM6SX1a2lqWWh5akVoZWhKtEq0KrT1aR7Uuaz3R6ifrku3IAeR4soA8g7ycvIPcRL5I7iT3U/QoDpQgSjIllzKfUkmpo5yi3KG81tbWttb21x6rLdaep12pvV/7jPZ97fdUfaozlUOdQFVQl1F3UVuot6ivaTSaPS2UlkErpC2j1dBO0O7R+ugMuhudSxfQ59Kr6PX0y/QXOmQdOx22ziSdYp0KnYM6F3W6dcm69rocXZ7uHN0q3cO6N3R79Rh6I/Xi9Qr0lurt0Tur91SfpG+vH6Ev0F+ov13/hP5DBsawYXAYfMYCxg7GKUanAdHAwYBrkGtQbvCjQZtBj6G+oZdhquF0wyrDI4YdTIxpz+Qy85nLmQeY15kfjMyN2EZCoyVGdUaXjd4ZDzMONRYalxnvM75m/MGEZRJhkmey0qTB5K4pbupsOtZ0mulm01Om3cMMhgUO4w8rG3Zg2O9mqJmzWaLZTLPtZhfMes0tzKPMpebrzU+Yd1swLUItci3WWBy16LJkWAZbii3XWB6zfMYyZLFZ+axK1klWj5WZVbSVwmqbVZtVv7WDdYp1ifU+67s2FBs/m2ybNTatNj22lrZjbGfZ1tr+bke287MT2a2zO233zt7BPs1+kX2D/VMHYweuQ7FDrcMdR5pjiONUx2rHq05EJz+nPKdNTpecUWdvZ5FzlfNFF9TFx0XsssmlfThhuP9wyfDq4Tdcqa5s1yLXWtf7bky3WLcStwa3FyNsR2SMWDni9IjP7t7u+e473G+P1B85emTJyKaRrzycPfgeVR5XPWmekZ5zPRs9X3q5eAm9Nnvd9GZ4j/Fe5N3q/cnH10fmU+fT5Wvrm+m70feGn4Ffgt9SvzP+BP8w/7n+zf7vA3wCCgMOBPwV6BqYF7gn8Okoh1HCUTtGPQyyDuIFbQvqCG
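+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the B(M) notation above concrete, here is a minimal, illustrative PyTorch sketch of the pre-activation B(3,3) block (this is our own sketch, not code from the paper; the class and argument names are made up, and `width` is assumed to already include the widening factor `k`, e.g. `16 * k`):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class BasicWideBlock(nn.Module):\n",
+ "    # B(3,3): two 3x3 convolutions, each preceded by BN and ReLU\n",
+ "    # (pre-activation ordering), plus an identity or 1x1 shortcut.\n",
+ "    def __init__(self, in_planes, width, stride=1):\n",
+ "        super().__init__()\n",
+ "        self.bn1 = nn.BatchNorm2d(in_planes)\n",
+ "        self.conv1 = nn.Conv2d(in_planes, width, 3, stride=stride, padding=1, bias=False)\n",
+ "        self.bn2 = nn.BatchNorm2d(width)\n",
+ "        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)\n",
+ "\n",
+ "        # A 1x1 convolution on the shortcut is needed only when the shape changes.\n",
+ "        if in_planes == width and stride == 1:\n",
+ "            self.shortcut = nn.Identity()\n",
+ "        else:\n",
+ "            self.shortcut = nn.Conv2d(in_planes, width, 1, stride=stride, bias=False)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        out = self.conv1(torch.relu(self.bn1(x)))\n",
+ "        out = self.conv2(torch.relu(self.bn2(out)))\n",
+ "        return out + self.shortcut(x)\n",
+ "```\n",
+ "\n",
+ "Changing the kernel-size list (e.g. to 1,3,1 for the bottleneck) yields the other block variants listed above."
+ ]
+ },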
YFZwZvDe4IsQrhhVSHPAi1CRWE7gx9wnZi57L3sl+EuYfJwg6FveMEcGZzWsKx8KjwsvC2CP2IlIgNEfcirSNzImsje6K8o2ZGtUQTomOiV0bf4Jpz+dwabs9o39GzR5+MocYkxWyIeRDrHCuLbRqDjhk9ZvWYO3F2cZK4hngQz41fHX83wSFhasIvY4ljE8ZWjX2cODJxVuLpJEbS5KQ9SW+Tw5KXJ99OcUxRpLSm6qROSK1JfZcWnrYqrWPciHGzx51PN00XpzdmkDJSM3Zm9I6PGL92fOcE7wmlE65PdJg4feLZSaaT8icdmawzmTf5YCYhMy1zT+ZHXjyvmtebxc3amNXD5/DX8Z8LQgVrBF3CIOEq4ZPsoOxV2U9zgnJW53SJQkQVom4xR7xB/DI3OndL7ru8+LxdeQP5afn7CrQKMgsOS/QleZKTUyymTJ/SLnWRlko7pgZMXTu1RxYj2ylH5BPljYUG8Kf+gsJR8Z3iflFwUVVR37TUaQen602XTL8ww3nGkhlPiiOLf5iJz+TPbJ1lNWv+rPuz2bO3zUHmZM1pnWszd+HcznlR83bPp8zPm/9biXvJqpI3C9IWNC00Xzhv4cPvor6rLaWXykpvLApctGUxvli8uG2J55L1Sz6XCcrOlbuXV5R/XMpfeu77kd9Xfj+wLHtZ23Kf5ZtXEFdIVlxfGbJy9yq9VcWrHq4es7p+DWtN2Zo3ayevPVvhVbFlHWWdYl1HZWxl43rb9SvWf9wg2nCtKqxq30azjUs2vtsk2HR5c+jmui3mW8q3fNgq3npzW9S2+mr76ortxO1F2x/vSN1x+ge/H2p2mu4s3/lpl2RXx+7E3SdrfGtq9pjtWV6L1ipqu/ZO2Hvpx/AfG+tc67btY+4r3w/2K/Y/+ynzp+sHYg60HvQ7WPez3c8bDzEOldUj9TPqexpEDR2N6Y3th0cfbm0KbDr0i9svu5qtmquOGB5ZfpRydOHRgWPFx3pbpC3dx3OOP2yd3Hr7xLgTV0+OPdl2KubUmV8jfz1xmn362JmgM81nA84ePud3ruG8z/n6C94XDv3m/duhNp+2+ou+Fxsv+V9qah/VfvRyyOXjV8Kv/HqVe/X8tbhr7ddTrt+8MeFGx03Bzae38m+9/L3o9/7b8+4Q7pTd1b1bcc/sXvUfTn/s6/DpOHI//P6FB0kPbj/kP3z+SP7oY+fCx7THFU8sn9Q89Xja3BXZdenZ+Gedz6XP+7tL/9T7c+MLxxc//xX614WecT2dL2UvB14tfW3yetcbrzetvQm9994WvO1/V9Zn0rf7vd/70x/SPjzpn/aR9LHyk9Onps8xn+8MFAwMSHkynupXAIMNzc4G4NUuAGjpADDguY0yXn0WVAmiPr+qEPhPWH1eVIkPAHWwU/7Gc1oA2A+bPWy0UACUv/DJoQD19BxqGpFne3qouajwJEToGxh4bQ4AqQmAT7KBgf5NAwOfdsBgbwHQMlV9BlUKEZ4Ztqo4LjOLBv0Pifp8+lWO3/ZAGYEX+Lb/F1J2jnCfSe2gAAAAimVYSWZNTQAqAAAACAAEARoABQAAAAEAAAA+ARsABQAAAAEAAABGASgAAwAAAAEAAgAAh2kABAAAAAEAAABOAAAAAAAAAJAAAAABAAAAkAAAAAEAA5KGAAcAAAASAAAAeKACAAQAAAABAAAGQKADAAQAAAABAAACXAAAAABBU0NJSQAAAFNjcmVlbnNob3QtUInlAAAACXBIWXMAABYlAAAWJQFJUiTwAAAB12lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNi4wLjAiPgogICA8cmRmOlJERiB4bWxuczpyZGY9Imh0dHA6Ly93d3cudzMub3JnLzE5OTkvMDIvMjItcmRmLXN5bnRheC1ucyMiPgogICAgICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgICAgICAgICB4bWxuczpleGlmPSJodHRwOi8vbnMuYWRvYmUuY29tL2V4aWYvMS4wLyI+CiAgICAgICAgIDxleGlmOlBpeGVsWURpbWVuc2lvbj42MDQ8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+MTYwMDwvZXhpZjpQaXhlbFhEaW1lbnNpb24+CiAgICAgICAgIDxleGlmOlVzZXJDb21tZW50PlNjcmVlbnNob3Q8L2V4aWY6VXNlckNvbW1lbnQ+CiAgICAgIDwvcmRmOkRlc2NyaXB0aW9uPgogICA8L3JkZjpSREY+CjwveDp4bXBtZXRhPgoYxRDtAAAAHGlET1QAAAACAAAAAAAAAS4AAAAoAAABLgAAAS4AARIhntux7QAAQABJREFUeAHsXQW4XcXVHVokSIAQIGigCUECJRRLoQQIKVIsuBVaXAKhDYQW9xICRUOQ4q4JWkKR4G7BggcrLoHg0p5/1vk7982dO3PsHr9rvu+9I/eMrfFZe++ZzpNO0BEBIkAEiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgAkSgRghMRwKkRqXJrBABIkAEiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgAj4CJEBYEYgAESACRIAIEAEiQASIABEgAkSACBABIkAEiAARIAJEgAgQgdohQAKkdkXKDBEBIkAEiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgAiRAWAeIABEgAkSACBABIkAEiAARIAJEgAgQASJABIgAESACRIAIEIHaIUACpHZFygwRASJABIgAESACRIAIEAEiQASIABEgAkSACBABIkAEiAARIAIkQFgHiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgAkSACBCB2iFAAqR2RcoMEQEiQASIABEgAkSACBABIkAEiAARIAJEgAgQASJABIgAESACJEBYB4gAESACRIAIEAEiQASIABEgAkSACBABIkAEiAARIAJEgAgQgdohQAKkdkXKDBEBIkAEiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgAiRAWAeIABEgAkSACBABIkAEiAARIAJEgAgQASJABIgAESACRIAIEIHaIUACpHZFygwRASJABIgAESACRIAIEAEiQASIABEgAkSACBABIkAEiAARIAIkQFgHiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgAkSACBCB2iFAAqR2RcoMEQEiQASIABEgAkSACBABIkAEiAARIAJEgAgQASJABIgAESACJEBYB4gAESACRIAIEAEiQASIABEgAkSACBABIkAEiAARIAJEgAgQgdohQAKkdkXKDBEBIkAEiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEgArkRIC+88IL47rvviDgRIAJEgAgQASJABIgAESACRIAIEAEiQASIA
BEgAkSACBABItBhCPTo0UP06dMn11znRoAsscQS4pVXXsk1c4yMCBABIkAEiAARIAJEgAgQASJABIgAESACRIAIEAEiQASIABEoHoFNN91UjB8/PteEkADJFW5GRgSIABEgAkSACBABIkAEiAARIAJEgAgQASJABIgAESACRKDzEKg1ATJmzBjxySefdF6pMsdEgAgQgQoiMGnSJPHFF1+I6aefXvzmN7+pYA6YZCJQfwTeeecdMWXKFD+jAwYMEHPOOWf9M80cEgEiQARiIPDZZ5+J5557zvfRr18/scACC8TwzU+JABEoMwJPP/20mDZtmphhhhnEqquuWuakMm1EgAgQASKgIdC/f3+x9dZba2+yv81NAyT7rDAGIkAEiAARSAuB1VZbTTz44INi9tln94mQtMJlOESACKSHwEknnSRGjhzpBzhx4kQxePDg9AJnSESACBCBGiBw6623ig022MDPydixY8WwYcNqkCtmgQgQASCwyiqriEceeUTMNddc4tNPPyUoRIAIEAEiQAScCJAAcULDH4gAESACnYsACZDOLXvmvDoIkACpTlkxpUSACBSDAAmQYnBnrEQgDwRIgOSBMuMgAkSACNQDARIg9ShH5oIIEAEikCoCJEBShZOBEYFMECABkgmsDJQIEIEaIUACpEaFyawQAQMBEiAGIHwkAkSACBABJwIkQJzQ8AciQASIQOciQAKkc8ueOa8OAiRAqlNWTCkRIALFIEACpBjcGSsRyAMBEiB5oMw4iAARIAL1QIAESD3KkbkgAkSACKSKAAmQVOFkYEQgEwRIgGQCKwMlAkSgRgiQAKlRYTIrRMBAgASIAQgfiQARIAJEwIkACRAnNPyBCBABItC5CJAA6dyyZ86rgwAJkOqUFVNKBIhAMQiQACkGd8ZKBPJAgARIHigzDiJABIhAPRAgAVKPcmQuiAARIAKpIkACJFU4GRgRyAQBEiCZwMpAiQARqBECJEBqVJjMChEwECABYgDCRyJABIgAEXAiQALECQ1/IAJEgAh0LgIkQDq37Jnz6iBAAqQ6ZcWUEgEiUAwCJECKwZ2xEoE8ECABkgfKjIMIEAEiUA8ESIDUoxyZCyJABIhAqgiQAEkVTgZGBDJBgARIJrAyUCJABGqEAAmQGhUms0IEDARIgBiA8JEIEAEiQAScCJAAcULDH4gAESACnYsACZDOLXvmvDoIkACpTlkxpUSACBSDAAmQYnBnrEQgDwRIgOSBMuMgAkSACNQDARIgEcrxvffeE+PGjRPjx48XGGSPO+64CL6y++Sll14SN998s/jnP/8pBg8eLI444ojsIqtxyM8884yP42233Sa23nprMXz48BrnllkjAvEQIAHSjNc333wjJkyYIK677jqBPvjpp59u/oBPuSGAsrjlllvEtddeKz777DNx11135RZ32SIiAVK2EmF6iAARKBsCJECaSwTzlxNOOEG88MILYskllxSbbLKJ2GqrrcT000/f/CGfiEAFEOhkAuSnn34SH330kfjggw/Ehx9+6M+Ju3fvLuaZZ57G3xxzzBGpFJ9//nkxadIkscgii4hBgwZF8sOP3Aign/3uu+/EsssuK2addVb3h/yllAj88MMP/lqzf//+/jiZRyL/+9//ildeecVv0x9//LFA++7du7dYbLHF/PacRxo6IQ4SII5SfvXVV32CAaTHgw8+KFAh4bbZZhtx5ZVXOnxl8xqd5wMPPCAwgQfx8dprrzUi2mmnncQFF1zQeOaNG4Evv/xSTJw4UYDwAHn0zjvvND7+y1/+IkaPHt145g0R6HQESIAIfyGB/gLjAMgPbLwr53meuuU1BwQ+/fRTfwy88cYbm8pirrnmEvitU11dCRDMey688MKmYh04cKBYfvnlm97pD1j8o63qbp111hF9+/bVX/GeCBCBDkOABEhXgWM9udZaa4kff/yx66W823HHHVv63KYP+EAESopAJxEg2BB96KGHxL/+9S9/PwOEhdqjchXPTDPN5M+d1lhjDbH66quL3/zmN2L22Wdv+vyrr74SyyyzjHjrrbfE0KFDxQ033ND0O9ZCb7zxRtO7dh5AvEKI98knnxSPPfZYO0Gl7neXXXYRM844Y1vhvv7662LxxRf3ywZk8wEHHNBWeEGev/jiC3HFFVcEfRLrt27dujXIs3nnnde/n2222WKFUeWPMTZibxUC72+//bbIY48Q7eviiy8Wd9xxh3NN26NHD/GHP/xB7LvvvqJPnz5Vhrj4tMtNHDqJgOz4PSlR6u2zzz6eZNmws2X9kwRI5nh9//333v333+8de+yxnpykerIjsqYFaZQESObpqWoE//nPfzw5qHp/+9vfPDngezPMMIMTR9m5VTWbTDcRyAQBOUH224ucJGcSfhkDlZMeT24OeFKrzpMLKu9nP/uZs88oY/rrliYpPeWPg0FlIQmQumU7Vn7+/ve/N+qoJPhj+S3zxyh7cx42ZsyYwCRL7awWP3LjM9APfyQCRKD+CEihp0bfMHbs2PpnOCCHGE/NvlU9P/744wE++RMRKCcCv/71r/06Xef5oBTa9OTGp4c1mWqvSa8///nPvRVXXNFf62DeKDde/X0SFd7aa6/dVNBS09qbbrrp2o5XhY+r2r865phjUg1XjyPpvRSqasp/kof999+/kS8pwe9J4ipJMJH8vPzyy424kuY5zJ8kQLz111/fO/XUU73JkydHSlfVPsIewLnnnustuuiiTXhmuUcoCUBPCpw2xRdWFtibwH701KlTqwZxadIrSpOSAhMiVQc9yfRGqnxZEyDoIGeZZZZIaUEDUQNIgfCVNuqVV145Mo5Zdm6lBYgJIwIBCHQiAdKvX7/IfUYAdPwpBQT22muvSGVR5wVvFBjrSoBITduW8pcSj4GQSGmtFj9SmzfQD38kAkSg/giQAPn/MoZgWNB6F5tbdESgagjUmQCRmgTe7rvvHthuwzZM4/6O9Z/uvv76a++QQw7x0wFyRJrXaplrhcUBAkVq8Hrbbbedt99++3lSm9uP4p577vGk9pm36qqrenPOOWfscMPiTfK7NCemZz/2PfCS0vpNebnmmmtihxPVAwibESNGeBtvvLG3wAILNMWbJP9R/Cy88MJ+nPKYgKjJLO13ID6kxocntSqs2GW1R4g4bULuEFKQ2uwe1i8gt6666ipPame1pA17FtJsXWlxLXPCSIDI0oHGBTqN5ZZbLrTjyIMAQSXH5v2QIUO8JZZYoqXC6x0TCRB389pzzz09DOLQonF1agrLrDo3d+r4CxEoNwKdSIBgArnCCiv4/W6Qxhj6DbpsEcBiC2MyNDKDNHFIgNRTA+TII49smftI0wyBlQ6LaDWm44o2nKXUXWBi+CMRIAKlQYAEyP8XhTSV48nzAJr6Sb3PDNOyK02BMiFEQEOgrgTIKaecEkp8gNDcaKONPGky1JOmsTzMk+TZBZ402etNmTLFu/POO72RI0d6Sy21lLPd630A7n/1q19p6LbeYt8M8ZmS8mY46nmHHXbw5HlDrQEZb6TpU+/kk09uIQ9UOPoV+UY6
HzVqVC2KoHWeJEC0rp6aKlfEtqejTfQa/1tzzTVNUly6v/KsY2CP70qMgzptH1YQSlfOnul4kwhkjQAJkAQIYyaJs0Fbr+XG3waW9Fs/+vCBhtmx6MAELR/DahPpazmBdv9OKv3z2XS87777YsuU+08YKJO1jDiHGwC42MKyNhi6AQMGmKtZvGYoW9PCbQEIpSIEGHEw29tuu62r/NYyqfMsCRC4o7Bie8ghhxQBQupYMARIgOhZYXCfqOyM85iEAMHKPq+Z3iNGjKgKxPHHH+/SCa4e5D4BVdMyQjIESIAkwy+r1Dr1GeKU8ZZbbnG1aae9iXodd5UG+jhh8kIfNI8Ad1sNGzYMpdOgQYPyUIl5VEGABEgVgHwew3UM3DVbvzmsbXH//fc35KbnPqnD3YZdsMrM6nzttdeu2YBbWu5s1llnnUSTGMPVSPFikQApXp3lpXER255uNnHevHkuzwAYpwQxHiZg3BNjfk7bfsYZZ4RJzjhEIHUESIAkgHTMmDGuxmxt3Ji5FhR++ukn4/zzz/eVAeMyfvz4IBGBz7BPB/xhKp222247Ax/lcYLc7MjYfvvtK7IgE36Mg5Y6YwawM43SRR3jfhDHKUPcNKgDazlQL14uYlSZcMySAEE54MfWmh8+7hiIQJoIkABJE810ZMF1irXdO8/jEiBff/210bx5c5ds7FcVJmD2z1ZbbeVKLzdpN+AvnCE7BEiAZIdtXMk69hmilEVuZuy5v1uXLl0Mubm5OYHn7rvvNuCC87DDDvN1Beq0T9gnaPr06VFUMeOCnHXK8rrGSmvrhKPIGYVMgHJ75e91DzM5GWqPAAmQ6HWwaNEiQw0sO//baMt49yQN6DtssskmoduTU48o13CtV6vQoUOHVMoI+8vgRkD9T7kHiBuber9TtLanm02EW341HmC1t8OGDYv010J8a3qc4z3yySefRJLDyEQgDQRWghD5J2SIgYBc7SAke+mbctq0aWK//fbzfa4eyM0lxcSJE9Wl7Sj32hAffPCBkDNzbffDXJx++ulCui6pRJUfrEJ2ACvXUU6uuOIKIckKM4kkAITcAEn069evqgi5vE3IjXGF3PjcM+6mm24qFi5c6PlMl5vy5Slef/11IUkHccQRRwjpu1DIjaGFXPIt5IxDTzUlASKuvPJKz2dp3JQDFEKuvKmIki65hGTYhWTYK/d4QgSSICA3wRZydof5X1+6dGkSUUybAgJyIE3ssssuQq4i9JUmCRBx9NFH+z73e+D3Dpo1a5aA/QsTpGtA0b9/f1dUOVgoLr74Ytd93kgHgWuuuUZIFzumMLkJuth7773TEUwpsRHQsc8QtjDyY1fsuuuu4o033qgkkRNAxPXXXy+kT+rKPecJ+nhyJXDV/hywgV2JEqTbPyH3HQiV5N577xW9e/cOFTdOJEnoCvTLJWlcNblcJSLQB2aoPQLS9ZGQbthMReQm6EKuzKm9UhprIPfjEXLllZCr9F1aypUUAn0NfAslDXJ2tvldZZWzzz77mP0Y6SVBbLjhhrG+f4cOHSrkJDyrWCGJaSEJW9u9PC6gx1577VXJSk4sEe3bt69chz2BDZQeFoR0MRo2Sd3Ek3vUme8VSYAIvC8YiAAQKGLb080mPvHEE0JOeLb9oeSeZgLjAnKfNtv9oAv0haQXFbFgwQJbtPPOO09ITwO2e7wgApkjkAaLUq8yglxgYcPDsDNf4SYEPtNlZXv+sClt1PD222/bZvBhzwr4oo4Tvv/++8r+JXCRArdeUcJXX31lyE6sZ9lQZmlEo4jLPa4ksgzMhHIG+OL3Wxae9QoQ6KJmvKj/jRygcKrIayIQGwE140OSfbFlMGE6CMB2q/pQ7d3rGGcFiN9mvtiAL0qAjltuuaXLzuPdhncAQzYIcAVINrgmkaprnyFMmZwrmw866CDDz7+zUx72AMAmyF62yXpv7ty5zqSB11i5jJmCVhl+523btg2UlfShnPwTSg/oJyf4JM2O6VNCgCtAwgOJVVrw7e7VxjDDHu5M0gpdu3at5IN9fLD6LGnAjGXnilb4oIcng1qE7t27V8oINzCff/55LdQodZ7qe5grQEpdzZELV8S2p5tNxMo557ugY8eOkesCCeSEaJcsSV7GksVERCAJAnSBlQA96+biTuOADR+jBPhRdcpQ13I2TBRRZlzsw6HS4yhXI0SWoRJY9xHBBuBxwtixY236WHWbM2dOHJFapEFny1oWdZ4HATJu3Dhb3nDTxUAE0kJADbiTAEkL0fhyrJsRg1x3usBTdicOASJnS9vsiJIFd1tRg9eG6JBHtw1RkQwfnwRIeKx0iFnLPkOY8sNtnbIBcGsHV6pRg1wpW5GhZFmPcnZ2VJEGfN9bZQSdg4DKKvhtBO2lzw477JCVGpQbEQESIOEAg6937JXh9X+G69+ZM2eGExQi1ocffliZRAa7+Nprr4VIVT2Kl8s82KRaBJBF1oly2C+TIX0ESICkj2nRJRax7eloE732K5YeYGL9PazfsuodAxeIDEQgbwRIgMRE/JtvvrF1alRDVscbb7wxkmR0ilRa5xF+jaMEbFpn/VjEzLnPPvssighb3DZt2pi6wddz3PDll1/6lk+6Q4grtubpNttsM89y5UGAYFYkVvZY/y9pzJ6qOahUQAsESIBoUQ3G7Nmzbav5sDkx7Iu13avzqASIdG9hk63kyGXNBla4RQ2w5UqG9YiBE7wDGNJHgARI+phmKbGWfYZq5cKAgWq3IFqjrtRQ8rFiFrO5lSznUbqGVVFDH9dff31feU75mFCURcDMeGte6JsHbdC+0047ZaEGZcZAgARIddCwUrNly5a2/7j6v2MQH5Ou0gzWCRPShV5qokGwKr3V8f77709NfhRB0i1oRRfpEs+AxweG9BEgAZI+pkWXWMS2p5tNxOpfr9W3cftYmEStbLI64t3CQATyRoAESEzE0ZlSjdfr+P7770eSfO211/rKw6bbYV0QIFPpA9kmK+xGtl4KSz/Qpix8fHq5gfJK43fPb6O7ODMM/fLI+36tBzOc/5u4rHzeuDE//REgAVL7Ovrxxx+NbbbZpmLPpR9WU6m0CBC5j1FFtvU9JveRiFV4uJ7ws/MgbhjSR4AESPqYZimx1n2GoLINHz68Yg8uvPDCoKhVnzldaVntCyZuRA1WAgSrMKyzqq2y1XkWK4udZIfcf8f405/+VMFM5a2OJECi1nJ28UmABGMLl81wa6L+u87j4MGDgwXEeKrcVB1//PExUnsnQR/EaWMxAQOunPMOmHgI0kNhCc8MDNkgQAIkG1yLKrWobU83m/jFF19U7JeyYzhiVUic4ByzgqyNN944jiimIQKJECABEhM+K7NsNQo4hx/0qAEfUk451msYobABhIc1bRxXJiov+HS+7bbbjGeffVbdin2Umx/Z9IKOzZo1iy1Ph4TOjrbCPY8VICi/IqhUvpi1qfueKjrUG3WojgAJkOoYZR3j5JNPrthMuQGmod4DaREgfmTF2WefHbtockP1is7KLuEYx5VjbCXqKCEJkGJVdq37DEFo7bnnnmb
bha1JOmAoN7z0XF0GW4BBwajBSoDAj7SfnVE2B36r0wzz58+3zYSEa0j0tUiApIlydrJIgARjKzei9Xxvoz1hUC6pPfDKHStWjz322FRl//Of/3SVA8RlLcJf/vIXmy5PPfVULdSoizxJgNRFNYcuZFHbnm42EZPwVJ/KeZwyZUro+lARQaQ75cTdT0TJ5JEIxEGABEgc1GQabG7obMTqeuDAgZGlwke6Su91DDuo/cknn7hmxiVxfxW5ID4JsPkcVrI4yxZ1rxQf8TW7XevBDLjBsro7A76caV2zv0OpMiYBUtvqHD9+vM1e4lqFNAiQt99+2ybfapuTuIvAzHGrLHUO+8/N0FUNpnckAZIelnlIqnWfwa+M6GPC9R3a60UXXeQXLdJ9v33yMFEjanASIDNmzPC0M8reoCwLFiyImo1v/CFDhtjyU/sakQDxhUyrByRA/KsDfuexObdqO87jhAkT/BNr9sTrWzrJJMC4xcPkQZCklP+F8QAAQABJREFUTiyx18kBBxxgXHDBBcbEiRM5YS0uwI50JEAcgNTxJduevfKT2kS/PaGwx9nvv/9uzyzgCuNVmCDutIn0XBIAGh9lhgAJkBjQOmfdOxtznM6ic9WGVSZcoIQNl19+uc24xF2mFja/sPG8NqVDGZ988smwIrSMp8NgBkgk6/+lS5cuWmJFpYqFAAmQ2tUXVnpgFrZq1yeccIJNmTQIkJtvvrkiX+Wjju+8844tvygXjz76qK/cO+64I4ooxg2BAAmQECBpFEWHPoMXHHDbij0tMBCa1N2pko+VGsqmWI9xNgd3EiDIo0OHDp7yVV5xNltXuluPGFBp0qRJJS+QK5hshEACxIqUvuckQPzrBvs7qjbjPO6xxx7+CTV7AvdXm2++ua0sq666qlELN8tXXXWVTQ8nruoaq+EOOuggY/To0dwnLcH/iQRIAvBKlpRtb0WFpmET/SaywIbh+yNswD5Pyu5Zj9y7NiyCjJcmAiRAYqAZZFzR2frhhx8iScWmaF6bDCkDEeUjDhsTqXQ4Jtm4PFIhqkQeMWKETS/o1rNnzyqp9H+sw2DGFVdcYcN29dVXN+B+goEIJEGABEgS9OKnRYf1wAMPrLRpzJhxvlPSIEBgf63vCnWOD3LM1IkbMJCqZDmPcVZHxtWjXtKRAClWTevQZwhCDKt10wrnnHOOpy2Is/LXiwDBpsxOG2O9XnPNNY3FixcnLo7TRe1RRx1VkUkCpAKF1ickQLyrZ9q0aYFtKMlqUO8cs7s7a9YsV1nibtabREt8f2HyodUWhTnHOABsC1bnMkRDgARINLzKGpttz16zadhErFbzs1+wWdjvrVrAPsa77babSw6+dRmIQC0QIAESA/Wg1RpxNo+FuwE/44L706dPD6UllqJZZ6kh7WWXXRYqbZaRPvroI6Nx48a2MmJzSOegXpY6ZCVbh8GMF154wYYt6h33GIhAEgRIgCRBL37am266qdKe4TbKqy0nJUBAslgHFK3vH8ygTBJ++umniv5WuTinr9ckyHqnJQHijYuud3XoM+SFzUknneRpC+JsqGy1V1hZgoA+71ZbbeWZh7I9WBWdJCCPLbbYwpbH7NmzKyJJgFSg0PqEBIh39WD1gWorziMG8TFwVZTw5z//2VUW7GGZd7jzzjtdejixDbrGoCL2MPr888/zVr2w+ZEAKWzVpao4254dzjRs4pdffmk0atTI16bBXo0dO9aeseNq0KBBrvQYF5w7d64jJi+JQD4IkACJiDM2BArylYrZ+FHCK6+8Ym4I6dcZOvzwww0MVoUJL7/8ssvA1HrTNegOUshavt1339345ptvwhRJ+zg6DGZgxoPz5ZT0o1974Klg5giQAMkcYlcG8+bNs7VlbEzqFZISIB988IHNJjvts1eeUe55+b5GHpiRHfZ9FiW/eo5LAqRYta9DnyEvxJx9P2Vn7r777sgqeBEgEDJy5EhfW4b8sF9fkhWxjzzyiE0+3ovWQALEioa+5yRA3HUDV5teezOqdnruuedWEmFVKNw/Y49B7PN1xhlnGBhcu/LKKw3ssYE+RS0D+hVw4ad0xxFlw+Bd3gET/Kx6xD2H7cKm7gzVESABUh2jeojBtreiltO0ibD3QXYMJAi843h933m5W4angSeeeGKFsjwjAjkjQAIkIuDWTrSXMZgzZ05oidigceutt/Y1Kq1atTK+//770PKcy/RXWmmlmm+w5nTPhJl7IJHKEnQZzOjcubPtf3TwwQeXBWKWo0YIkADJF3gM0u2yyy6VdtyuXTvf2ZdJCZCgDYQPPfTQxAXfdtttK+VwviexIpAhPQRIgKSHZR6SdOkz1Kqs2FAzjpstPwIEK86s+yU57Q2uk+w91KlTJ5stAyFiDSRArGjoe279dgNpxmCY/tu92ou69+yzz5or9a+//nqjZcuWtnag4qgjvjf32msvkwxBm8w7YFWW0kUdoU8twiabbOLSRekU9YgJl/fcc08tilGoPEmAFKq6MlOWbW8FtGnaxF9++cX2fepnx+ByUO0h93//938mUe6Mu+GGGxrc92NFPfGsNgiQAImI+5AhQ3w7NpitETZMnjzZ1wUJjAUGkObPnx9WnBnPuZEdCJRaBcwWcg7SYTP3pUuX1kqlTPLVZTADM8WtLxkMMiTx458JWBRaKARIgORbXX/5y18qbRj7+ARtRO60rartP/DAA6GUfvDBByt5qbTqeOqpp4aSERQJAw9KnvMY1qVjkHw+W4EACZAVWBThTJc+Q9ZYYeY1ZgU623/cfYD8CBCUo5obWWy67jUzsRoGr776qk1/uNty9qtIgFRDUY/nJEDc9bDzzjvb/t/WtooZupi5u/HGG/vGsca3nmOlZ94k01lnneXS88Ybb3QXOoc7mLj45ptvGk8++aTpGuaSSy4xDjjgAJeLaitmQecNGjQwpkyZkoPmxc2CBEhx6y5Nzdn2VqCZtk3EfsVwkRxkq/AM74wxcl8QrxXAXbt2NTD5m4EI1BoBEiARayBoZmu/fv0CpcGXKpZ8OTcqtxoTzKKBnCgrP1SmTt2q6aPSpX2EG5c99tjD00hiYB7umcAmlyHoMpiBjQqt/yOc42XFQATiIkACJC5y0dM9/fTTtsFCDDwEhaQECGZ0Ou2FukanOWno3r27r/zx48cnFc/0FgRIgFjAKMCpLn2GrKGy/i+VbcERpEKcEESAfPXVVwZIY2s+zvOJEydGzva4446zyfQaUCUBEhnWmiQgAWKHHYNQzjaS9vWIESPsmWZ45dynB9/Sun0DYUY0NiW++OKLjebNm0fCH25FQaoweCNAAsQbF979NwL12PaysInvvfeeEWeVDfpvt99+O/+OREAbBEiARKgKuO4I6iBiBi/8dT733HPG1KlTDfg5hk88+M6Di6KgDzQ8O/roo00fqxFUqkTFrDTn3iT4AM0zvPXWWwY+GDFbJQgnPGvbtq3x2Wef5aleJnnpMpiBmUZOzGfOnJlJmSm0PhAgAZJPPX/33XeG1Y4ceOCBVTNOSoCcffbZLnuh7IffviNVlbJE6Nmzp6
98+AtnSA8B60AziDQGvRGwtnXV5nBEmy5TwApka/lwjtl/cUMQAQKZAwYMcOVnzR998CgB+yNgFrySse6663q6byUBEgXV2sUlAWLHHhMR1H876IiBd7StG264wQCJiEl8kyZNMoYNGxbowlnJTGNChV1z99VLL73kKgsGxHUOGJDFil2/yYIKP+sRkxyLtCl9nviTAMkT7WLnVQ9tL0ub+P7779u+Wa02yut8++23NxYvXlzsPw21Lx0CJEAiVOmtt97q6mR5Nfaw97BMrHfv3uYS2TgrPqyqg0xw5pvE77FVdtA5fL2OHTvWXOqGGTdOHYKumzZtarz++utB4rV/pstgxiuvvOLCHvXCQATiIkACJC5y0dKB+FZ2EoN8CxcurCogKQHSt2/fSp4qb3XE7MSkAe81Jc959JpFnTS/ek5PAqRYta9LnyFL1F588UVX+19jjTWMDz/8MHa21QgQfJR7udyy2h/4xA4bLrjgAlsZQBp7BRIgXqjod48EiL1Ozj//fNv/29pOcA7iA32BagNXt912m9G4ceNAWSBPsgxW96GqHFdffXWWWaYqG3YJe68p3YOO7D95Q08CxBsX3g1GoKxtL2ub+MknnxjbbbddKJuFfhnclDrdhwbXDJ8SgWwRIAESAd+gD52gDovfM6yCgOuiOL6JnWpjtr8zn3HjxjmjpX6NfU+c+Ua5BgkC9wVFDboMZnitTkpjILOo9UK9kyNAAiQ5htUkYJWg1V46N9j1S5+UAAlaoYFVi0nD8ccfbyuXtYzwh82QHgIkQNLDMg9JuvQZsiyrV/tPOiBZjQBBeQ4//HBfuwMbdMQRR4QqNly0YqNOZbdWXXVV3xXLQd8FO+20U6j8GCl7BEiA2DE++OCDK/9v9T9Xx0aNGpneDOwp/K/mzJljEiYqvfPYrFkzY9myZf4CEj7x2qA9CdmaUJ3YyWEjq5G4sINJJ0zGVlDjhCRANK6cAqhWtraXh0385ptvDNXunDbf6xreDZYsWVKAfwNVrAcESICErGUsmcOMGK9GjXvY9LV///5Gnz59DGxG3qVLFwMbLwa5vVKy2rRpYyR1V4TZ/kqeOj711FMhSxc/GlhmfHT26NHDXAWCsjRp0sSli9LJ63jIIYfEV6DGKXUZzMBLxYktBiEYiEBcBEiAxEUuXLqPP/7Y9k7B+yNsSEqAHHPMMS57oezHddddF1YN33innHKKr/w0XGz5ZlyHD0iAFKvSdekzZIUaBkNXWWUVW/tv166d8fvvvyfKMgwB8sILL9jyVTZNHTG4+MEHH1TVA6unVRocjz32WN80JEB8odHqAQkQe3VgAp71P249x/s7aoBrLKsM53lW/t9ffvllV7677LJLVPW1iT958mQDe2U68bNeP/jgg9roq4siaiB2vfXW00Ul6lEwBMrS9vK0iZhoa7VN1c4xLopvXwYiUGsESICErIEZM2YENvK5c+d6SsLqjnfeecf0lxpEhuCDMckMuWuvvdalH4xgLQLKjI/goUOHuvYl8TOOtdI1KT66DGYAc+egQ7du3ZIWj+nrGAESINlVPgYD99xzz4rNxmydKLP6khIgAwcOrOTttMlJ3kMKMa8Z4CofLIVmSA8BEiDpYZmHJF36DFmUFf2Qjh072mwL3OO88cYbibMLQ4AgE/XeUvbGeRw0aFBVXXbeeWdbGYI2bicBUhVOLSKQALFXg9cMYdVW4k6es/ZplCx13GqrrRKToPYS/PvKqy9U9FWmGDPYaKONbDZI4YgjJ7e5/wkkQNyY8E50BMrQ9vKwiXB/HzSRzmqvnOewbUV3fx/9n8UUuiFAAiRkjWDWqrMRq2sslceHX7WAmWeYCafSeR3j+kodMWKES+78+fOrqZT5c/gJxMbo1fYH6dWrV+a6ZJGBToMZG2ywge0/sPfee2dRZMqsEwTUQBJWvjGki8Cll15aaauYlfz8889HysCrg4v3CTbVDBOC/H+nMXgQtAfI9ddfH0ZFxgmJAAmQkEBpEk2nPkPakPz973+v2DXYI9i2CRMmpJJNWALkscces+ng7GdjItLXX3/tq9P06dNt6bGaOyiQAAlCR59nJEDsdbHOOuvY/ufWdgLXJnHCQw895CsT8rFKJO2w5ZZbuvJ8++23084md3mjR492lUvVUfPmzXPXR/cMSYDoXkPF0a/obS9rmwiX685JIiC47733XqNFixa+dkvZLxyxUitoYklx/i3UtKgIkAAJWXPt27f3bdTYxDZs+Pzzz41NNtnEVxY+zj6OsTzs3HPPdcnUaW+NUaNGufSzGkMs+U3qIiFsHaQZT6fBjG233daGcYcOHdIsKmXVGQIkQLKpcKx2g095Zf9gu6OGpASI14pBpQ/IkaQhaI+Rhx9+OKl4prcgQALEAkYBTnXqM6QJ17fffms4J2GkSXaGJUCw0Wa1zTmHDRvmW3S4ZFW2EMfHH3/cNy4ekAAJhEebhyRA7FXht9dEw4YN7REjXMFVNAbnre3Hep6mPYBar7zyiisvuFgpQ4Adcw4yKiwxobCI38tZ1gsJkCzRrS/ZRW57WdvEt956y3Du/du6dWtj4cKF5p9k8eLFVfdiU3YMJMj7779fX38ullYbBEiAhKgKEAlBKxii+ja1dsSVIbAeu3fvHkIre5QzzjjD1RH87bff7JFqfHXllVe6dLSWG4a7aEGnwQzVAVSYcgPOov2b9NKXBEj69YFlw9bBuV133dWIY6eTEiB33XWXry3Gvk5JA95hyg45jy+++GJS8UxvQYAEiAWMApzq1GdIE66jjjrK1uYHDx6cpngjLAGCTNEnd9od6zWImp9//tmlH1ZNWweGYaurre4mAeKCUcsb1u+ukSNHaqljnkqtttpqnm1k3XXXTaSGdXWrtc3h/PTTT08k25n47LPPdpUhjQkcznxqdT1t2jRX+RSmOk1wrBU+1nzV9y/3ALGiwvO4CBS17WVpE1966SVz5YayQTiC8PZaMXjLLbcYjRo18rVfSkarVq2MpUuXxq0mpiMCsREgARICOizrUo3V6xh1xQY+qLyWqFllRyUDTj31VJuO6NzqGLBaxlpO6/mYMWN0VDlQJ50GMw488EAbthjAZiACcREgARIXOf90VjuN1X5xXTUkJUAmTpxosxVWOzxgwAD/AoR8stdee/nKh1tEhvQQIAGSHpZ5SNKpz5BWeZ0bYYIATXuGchQC5Jdffgn0oQ97hw90ZxgyZIjNbmHlcrVAAqQaQno8JwFir4dmzZrZ/uuqD4DV+EmC04WckotjnMl9Qbp4fUeXzbe832b1cfuOQXgW+RkJkCLXnp66F7HtZWUT4cIf7warPcfE8KD9ojCOufnmm9vSWNOrc5A2DEQgbwRIgIRAvG/fvr4NeJtttgkhwR0FS/BV4/c6wpdylICNHa1yGjRoECV5bnGxWbzfapo0Nt/NrSD/yUinwYyuXbva/gPbb7993nAwvxIhQAIk3cqE/2urjcaM6ZkzZ8b6+e2xgfeKl0y43bKGDz/80KaLVa/DDjvMGjXWudMdn5IP2x9nxUssJeokEQmQYlW0T
n2GNJAbN26crU/XqVMnAyvd0g5RCBDkjb2MlN3xOsJnNVxdqLBkyRKjSZMmlTTIz2uViIqvjiRAFBJ6H0mA2OvHz71S0m/H9957r9KGnO0uTfdU8B/vlI82XbbQp08fVzlR7qiTJMuGi7M8JECciPA6KQJFa3tZ2URMKPEig0477bSqEGOlGvajddpq63Xjxo0D92WrmgkjEIEYCJAAqQIaVmsE7dkRd0kv/KBbDYDzfOjQoVU0sz8+66yzXPJ+/PFHeyRNrpw+llXZzznnHE00DK+GToMZu+++u+0/APc6DEQgLgIkQOIi552u2oCcsoNZHZ977jmbYtYBRWueaewdtNZaa9lskZLftGlTmw68SI4ACZDkGOYpQac+Q9Jyv/baa8Yaa6xRaeu77bZbZu4MrPbqmGOOqao63DLgw1rZHq+jdT+ia665xhb3vPPOq5oHIpAACQVTzSORALFXQdCg1LJly+yRI1zhu9OrreEe2mNa4W9/+5srnzTcd6alX1pynKvrFLYgmhhWIEACZAUWPEsHgaK1vaxsop9bw3fffTcU0Ngbyo9MUvYMXg0YiECeCJAAqYL2nDlzXJ0s1WBxnDx5chUJ3o//+c9/Bso94IADvBP63L3wwgtd8tSmRD5Janb7nnvucekKLONsBlyzQvwnY50GM6x7CwDPPffcs9bwMP8CI0ACJN3KqzUB8uyzz9oK5HSZp95rLVq0sMWLeoHZ30qW83jkkUdGFcf4VRAgAVIFIM0e69RnSALNF198YVjLgj3HsBF6ViEqAQI9nC6tnPZIkb1w17XFFltU7BZcyKJ8YQIJkDAo1T4OCRB7HfTq1avyf3e2i7D/fbvEFVfYR8QpE9drrrnmikgJz7Daw5nHrFmzEkrVL/kDDzzgKifKjc2GGVYgQAJkBRY8SweBorW9LGwi+kbWfp6yue3bt48EMlb+77fffp62DDIxaY6BCOSJAAmQKmgHbdyNDX7CLJH3ymL27Nm+hgDGIOrg9RVXXOGS984773hlXfN7fmW//PLLa65bVAW8Xgyov1qw2RtvvLHtP7D//vtHLQ7jE4EKAiRAKlCkcqIbAeJFmsN2NWzY0OYaJmrh33//fZsdgkz1u/XWW6OKY/wqCJAAqQKQZo916jPEhea7774zrC50MPli0aJFccWFSheHAIGrv1VWWaVif5Qdsh6xMu6RRx6xxenfv38onRCJBEhoqGoakQSIHf7rrrvO9p+3tokZM2bYI0e8at26tafsli1bRpTkHR0rz6z64hx2FR4byhbg0tRZVvTRGOwIkACx48Gr5AgUqe1lZRNBKjvtD64vuOCCyABjs3O4ZveSh3tJiffICjFBXSNAAqRK9e+zzz6+jTXqKg1rVkGb0MIQYLPwKGHkyJEuPXWdDYOPZy8D+NBDD0UpshZxdRrMACFnxRUf5gxEIC4CJEDiIuedTjcCJOgdlGSDTewJYLVD1nOQIwzpIkACJF08s5amU58hTll/+OEHQw02oW1j5cSnn34aR1SkNHEIEGSAvZasNsh5jo2ZsW+J9T72qgsbSICERaq28UiA2PH3m4iGdhB1D0q7ZMOAKzxre1LnuJ9GgLtkJVMd47qjTkOfLGXgu1iVUR2x0TGDHQH1TlpvvfXsD3hFBGIiUKS2l5VN9HPXH3cy23333eeyZ8quvfDCCzFrismIQHQEVkIS+edj8EBA+jIV8qNLyKVbHk+FkDNoxBlnnOH5rNpNOWghpL9S32hyTw8hNwX3fe58IAezRI8ePWy3p06dKuQqANs9HS7kB7SQy91cqsyfP19svfXWrvs639h8882F/Ph3qShXgAi5esh1P6sb0mevkASITbz8IBA33HCD7R4viEBYBOQqNCE7JGZblTM3wiZjPB8EJKkgpJs/TDrwiRH+NmylXOHnStCuXTvRrFkz133cuOyyy8SOO+5YeYY6lftbiV9//bVyT53ITqqQLjLUZaSjnBkkRowY4UojB37FggULXPd5IxkCcu8Cgf4CwtNPPy2kb/dkApk6UwR06TPEKST6GQcffLD5P0N6uaePkLMkhRyQiyMuUpoNNthASBdbZhq5B4i4//77Q6V/+eWXhdwfLVRcROratat48sknQ8c//PDDxfjx4z3jS7dgQpIpns94M18EpLti0a1bNzNTOWFMDBo0KF8FNMtN+mUXa6+9tpAb3Lo0i/r96RQgSVHxr3/9y3nb/EadMGGC637UG9tss4344IMPbMnkSi7RuXNn270yXMh9CMRFF11kK8qJJ54oJEllu1fvF3vssYeQkz6FJEAq74l6x4TlT4ZAkdpeVjYRY0he45yPPvqokJM/IgO8fPlyse2224qPPvrIlXbs2LHiuOOOc93nDSKQCQLROZP6SRE0Q1ZWhjFv3rzYYARtQAfZ//jHPyLJfvPNN12sqq4rKrz2P4HP2CIuX9ZlNueXX37pqn8scWcgAnER4AqQuMhlnw7uAvGecP7gszZK8JsdncSFH2ZUO/XC9fHHHx9FNcYNiQBXgIQESpNouvQZosKBjSzlJJtK295www2NpG5W4Rc67GoLzOxVdiXqCukuXbpU0ioZfscnnngiEjRcARIJrppF5goQN/R77bWXZ7uI6t/dKRl76Hi1L7iUThpef/11l2y4/5UDa0lFa5lekr2u8j744INa6lpLpbgCpJbolzPvorS9LG2in6vkqGOU1n/IgAEDXDYN74tRo0ZZo/GcCGSKAF1gBcA7ePBgz0aKhipnvgWkDH6EwepVV13VV/Yaa6xhfP/998FCHE/hlsDZ4Yy7RM0hOvXL0aNHu3Q97bTTUs8nD4G6DGZ4+d1/7LHH8oCAeZQUARIg+lZsWgTIpEmTXLYY75GOHTvGKjw2zJMztT1lgvhmSB8BEiDpY5qlRF36DFHKiMHF3r17V9o1JqzMmTMnigjPuH/7299MmXDhWi1gA2XVx43q3tPPzil56ihXyEWeiEMCpFrN6fGcBIi7Hu64445Km1JtAMeVVlrJ+Oqrr9wJQtyRq7Q8ZULuSy+9FEJCcBS5itYl/9RTTw1OVOCnchWZrbwgl4Axgx0BEiB2PHiVHIGitL0sbaKXe33YcuQZN3jtWQyZ2IuNgQjkhQAJkACk5ZIyW8cDDVT95DKtgJTBj+TS64ocJc967NOnT7AAn6fOgScQODqGIUOGuMqfxsd0Lcqqy2DG9OnTXZi+8cYbtYCEeZYEARIg+lZkWgQIZnVvtNFGLtuBAZA4G9I988wzLll4t0l3avqCWXDNSIAUqwJ16TNEQc06Y69JkyZGGvvLYfPxlVde2RxslS79qqpjnVUu3XBVjW+NgNXFO+ywg6dtsva94+x9cOihh/rKxQAKgx4IkABx18NPP/1kSDdYnv/fu+66y50gxB0vbwRoY9jDBxMkkgav73Lpsi6pWC3Te23CXNa9TpJWAAmQpAgyvRWBIrW9LG2i9b1p7Sv17NnTClekc3gnscpS59wDJBKMjJwQARIgPgB++OGHng1UNVTMnIkTsFStQYMGvrKx+kP6No0j2rXxnPSJGUtOlomw
+sXZ4ZZ+eWNliSV4cIWAwTXpo9CoxQxjXQYzvAZEo64iilUJTFRaBEiA6Fu1Xu0d76aoLrBQwqFDh3q+j2655ZbIAHiR29CLq9EiQxk6AQmQ0FBpETGvPoPc28fAB+Xdd99tzJgxwwDZGSfAHZ7q966++urGs88+G0eMLQ1cZ2266aam3H333df2zOtC7sdX0QG6xHHRc+edd9pkqDKpI4hgYBY1HHTQQb5y5Z52UcUxfkYIWAdywqw4ykgN7cRi9b1qA9bjrrvuGktXv9m9cm+wWPKsiTBRzqojzuEaL65ts8p2nqdlP51yo1zjG95aXowPLFq0KIqIuolLAqRuqjqXgqbV9rK2I1nbRC/vIrBJIF3iBqzYs9o1nGOCC8es4iLKdHEQIAHig5rc+MfVQK0N9t133/VJ6X/7k08+MV1nWeU4z2+++WZ/AVWeODuy6CylMeOmSraRHsvN22y4Yjah3CwvkgxEPvnkk21ygCPcimHfljyDk8xR9ZnEh34c/eVGnDY8WrVqFUcM0xCBCgIkQCpQaHeSJgECUhoubZTtUseog4wYpMS+ACq9OsoN7yK7ldEOcI0VIgGiceV4qJZHn+HVV181MPiu2iCOaIf4mI0SLr300oqMhg0bGlH3x3DmhYFK+K+Hz36lW5i96pwf4UgfNWAgQpEuKm/rcdiwYVFFmvExAccqx3q+zjrrxJLJROkjQALEG9P33nvPHHyy/m/VOVaWRwlYaeW0O5AF4jSuSy1r/uedd56rrfXr188aJZXzJPZTbipvTJs2zZz0EWcVrSoA3MGoelDHq6++Wj3m0YEACRAHIHV4qVvbS2JHwlZfHjZxt912c9ki2KSHH344rJq2eE5yCbKiruq1CeQFEYiBAAkQH9B23313zwaPhrrKKqtEHtDBDNgtttjCVybkoiOXZCNwzP6FHOsPy5GTBrg1OfbYY42WLVsaXbt2NYYPH258/fXXkcW+/PLLpssDq35xVtLgpWKVYT0H6bN06dLIusVJgA69NW/rOVak5BmaN29u0+WUU07JM3vmVUIESIDoW6lpEiAo5W233WazH8qWRfHJOmLECJcMuNKaPHmyvkCWQDMSIMWpxDz6DJhF57fKBCQI9osLE2666SZXe957772NOD9stgx3UE6iFasusBF6tXDPPfe4dMEq7ajBz242atQo9gCts++lbCeOsH9ZzE6PWm7GNwwSIP7/AivRaf3/os1iUDFs8HOBef7554cVERgP9suqH84ff/zxwDRRHyaxn9dff71J9lh13H777Y1x48ZFUgPEiZNISuJ2O1LmBY1MAqSgFZeS2rq1vSR2JAokedjEMWPGuOwubBzGBKOumn3ttdc8veA8+uijUYrNuEQgMQIkQDwgnDdvnmdjt3Zqwq5awGA9Pv6sab3O4dczCfmBYnz++eeufOAmKknArF6vGYtYuYElzWGXrMFtgnP23ZVXXhlLNayS8cJQ3Rs1alQsuVETPf/88756YEVGXmHhwoUuPTBgwEAEkiBAAiQJetmm9RvIi+MCC5ri3eM1K2e77bYz8A6oFj799FPP90Qabi+q5V3vz0mAFOcfkEefYezYsa7+gOob4ThhwoSqgKHfiMF7a7oszs8+++yquiACZgc688dEnKhh8eLFBvquTllYURwn+A34WuXT/V8cZNNPQwLEH1OQdHB5Zf3fqvOwe0lCxgEHHOCSgX5FGiQg9jRUOqnjmmuuGXkAzh+Ffz+Jaz+XL19utGjRwqWj0hWTB8N4jsCYgZNU7dy5c+rlrIZD0Z6TAClajaWnr45tL64diYJKXjYRJMcuu+ziadsuueSS0CrjPeAlB5NqGIhA3giQAHEgDkMatKRddWbatWtnwI+xV8BHFjrbmLFR7SMSM+Iw+zatsNVWW9mMVNIN07xWlSgMcASpgcF2P/IG9y+77DJz1YxKB1cKGLSJG8D0K1lex/322y+u6NDp8D9Bh9Yrf9zDDMyw5FDoTH0i4gPbqUdYgs5HJG8TAYMEiL5/grQJEJQU7zMvF1bwcR80gIH3XevWrV02CPYRdpIhWwRIgGSLb1rS8+ozWPfscPYLcH3uuecGFgkz8bDK2SttmvfQN6623x3sDogOr3xBZERZoaYK7dzzCHq8/fbb6nHoI2bGew34OnXFtwI2m2aoLQIkQILxRxuwuqaz/o/RZoLe5cuWLTMOO+wwVzvF9+HHH38cnHHIp1hFYtUJ58ccc0zI1OGjxbWfcHHt1M95jW9f2N+ff/7ZUyF8b8NdmDXdCSecYABfhmAESIAE41Pmpzq2vbh2JEo95WUToRNW3DpX8Co7NXDgwKorBdEHOumkk2y2DenxzqnWD4yCCeMSgbAIkAD5D1JYhg9/px07dnQ1UNXIvY4777yzOTvtkEMOMYkTLEerRnpADj4w4fIK/tfTDMcff7xNfwxiJglTpkyxyfPCAPcwUxhMMAbj586da2CmI/yVOtle4IVNm5IEuNLy0wP3k2zOFEavjz76yNOQO3XCrMWkZQ2jj9MHJGYhMRCBpAiQAEmKYHbpsyBAoC0GQZwr9WDX9t9/f8/Zi9hk2WnjEb9t27ax3CRmh1h5JZMA0b9u8+wzYGWtsy9ivb7qqqt8AXvyyScNDNJZ42d1DoLUK2CgFRM4MBgIO1It/759+5pu9jDQGjRIq/LCYEmDBg0qckHwhg3YUw/7kWDSkp+bMS99N9lkEwP7CsL9Q14uWsOWqV7ikQCpXtPz58/3XcWwzz77GM8995xNCCa4wc0y2pDzf49vwrTID2QKec484vqgtxXCcRHXfsL2YEN2p45e13CHDZeh+E9iXyW8w522DjbqxhtvdGjHSz8ESID4IVP++zq2vbh2JEpt5WUTlU7w5LLBBht42jiM7eH94HSZiHcE+pVbbrmlK936669vYBULAxGoBQIkQCTqmEXinHXh1WlJ4x5WaMDfKtxVZRHuvfdem5Fp3LhxqI9CP13A2jpXlcTBAf5MsXIjjZkseNl169bNVk6rTihzmgFLknv16mXmiZnOYQguqz74UMaqIujcv3//NFUzZTlXogwaNCj1PCiw/hAgAaJvnWdFgKDEmOnj5bZx1VVXNbp3726cddZZBvY4wrJlq51T53ifcsZzfv8dEiD5YR02p1r2GTDI7tdHWXnllX0nZWAzZOyhptpx1kevwctrrrnGRk5E1QH9+KDVaqr+evfuXSnnU089pW4HHjHABvyi6uQV/8gjjwzMiw/TR4AESDhM8W3ao0cP3/85Jvkdeuih5sQ/vwH/nj17Gt988024DEPEwqQ6ZzuCrQrjnjOEeFuUuPYTQqzvYqe+Ya8xORKkLsgohvAIkAAJj1UZY+rW9pLYkTD1k6dNtOoDd8dBE8VB3LZp08Y44ogjTLfKXi5HV1ttNePMM880vv32W6tonhOBXBEgASLhbt++vatzFbaz4hUPA/Dw4YkN5Lp06WLADRXcROXRocHS2nXWWcdWnpkzZyb6U2EWz4EHHuj7Ue2FAe7BPys60vjw8HORFVcxsMwwoK1atTI/SsEuY1AO+WK
ZHmbqpRVefPFFG55+5Q1zH5ttphlQTmd9Y4UMAxFIigAJkKQIZpfejwB55plnUssU+wBgNVkYu4Y4WAkCW8+QLwLWD7+nn34638yZmycCte4zXHjhha52C1IE7kj9AmbphW3rSePB7YHX5ue33357Ih0wOzFMXxODEygDPtTDBmxmnLTcKj1WBzPkiwAJkGh4T5061YD7NvWfDXPEfh/OVSLRcvWO7WXP/vSnP3lHTuGuV37V7KfKdvTo0Z6rVarhh4lyp556qrnCTMniMTwCJEDCY1XWmLq1vSR2pFodecnO0iZa9cG4E7y9YOVflEkhsKGYfJLmykCrXjwnAlEQWAmR5YuZoUQInHbaaeKWW26plAjXcuPwynXcEzkzWIwZM0ZIFwBCbrxt/iSDKyTDK6TvePO30UYbCdlpFnLTNiGX9Ao5myVudpHTSUZcyA9aM085C1BIYxtZRtESSBdlQr6EKmrLFSpCLkmvXPOECMRFQK5aEtLFkVhrrbWEdNsRVwzTFRgBdA9mz54txo0bZ9oVOTtUfPHFF0J2ek17L91lCfxPYIM6dOhQFzZXt+qUs+aFXJVjqiUJECFX5uimIvWpAQKPP/64kKssxKJFi4SctS3kyizRqVOnGmjCLIlA7RGYPHmykKuwTUVGjhwp5Erp2itVAA3kYJUYP368kG6NzW8+vP+XLFki5EQzgfd/06ZNhZzoJ+SqELH55ptnUiJ809x999022chPTtKx3UvzIon9RL9JriwTkyZNEugz4YdvZuC29tprV76V8b0sJ2AKuYpfSII1TfXrTpYk38SsWbOEXJUkMC7BUJ8I6Nb2ktiRoBqshU300ke6FDW/DzE+uGDBAiHdl5r2Tq4EMdsi2iNsG2w13hPSG4yXGN4jArkjQAIkd8izz1DuO2GSDyonOdvONEh5khEq7zyPcoaXkC5aBDqVcm+VPLOuWV59+vQR0u1ZJf9rr71WyA0LK9c8IQJxESABEhc5piMC+SFAAiQ/rJkTESACxUSABEgx641aE4EwCJAACYMS4xABIkAEiAAQIAFS0v8BZpVIP9SV0skN3oXcxK5yXcaTUaNGmbO6pOsxgdUgZQ/S/62Qm2sK6W/fLKr00W8SXViNw0AEkiJAAiQpgkxPBLJHgARI9hgzByJABIqNAAmQYtcftScCQQiQAAlCh8+IABEgAkTAigAJECsaJTq/7bbbxIABAyolgvuD+++/v3JdxhOU8cEHHxT9+vUTd911VxmLaCuT9NEv5KbqlXty40FzKWLlBk+IQAIESIAkAI9JiUBOCJAAyQloZkMEiEBhESABUtiqo+JEoCoCJECqQsQIRIAIEAEi8B8ESICU9K+A1QHw+wx/rQhwf4W9O+Rm4aUssdz83XR9hdUQd9xxhzjhhBNKWU5VKPi5xEqXefPmmbfgb/GNN94QO+ywg4rCIxFIhAAJkETwMTERyAUBEiC5wMxMiAARKDACJEAKXHlUnQhUQYAESBWA+JgIEAEiQAQqCJAAqUBRvhPsDYE9IlQYOHCgbXN0db8MR2z2edRRRwkQAdiUqVmzZmUolm8ZHnnkEXHEEUdUng8ePFjcdNNNlWueEIGkCJAASYog0xOB7BEgAZI9xsyBCBCBYiNAAqTY9UftiUAQAiRAgtDhMyJABIgAEbAiQALEikYJzzt37ixmzpxplqxRo0YmOYBN0csUsBqiY8eOYtasWaJv374CrqHKHFDetm3bmis+UM711ltPzJ8/3zyWudwsW74IkADJF2/mRgTiIEACJA5qTEMEiEA9IUACpJ5qm2WtNwRIgNRbjbO8RIAIEIH4CJAAiY9dIVLCLVK7du3E8uXLTX1POeUUgf1ByhSwt0nv3r3FSiutJN56663Su4Fyrv7Ayg+sAGEgAmkiQAIkTTQpiwhkgwAJkGxwpVQiQATKgwAJkPLUJUtCBJwIkABxIsJrIkAEiAAR8EOABIgfMiW6P2TIEHHzzTebJQJJMGPGDIHBzTKEH374QbRu3VosWLBA9OjRQ0yYMKEMxfItw9KlS02CZ+HChWYc7PkBkguuvxiIQJoIkABJE03KIgLZIEACJBtcKZUIEIHyIEACpDx1yZIQAScCJECciPCaCBABIkAE/BAgAeKHTInuL1myRLRq1UosWrTILBXO58yZIxo2bFjoUv7xxx/i0EMPFZMmTTI3eYcLrPbt2xe6TNWUP+mkk8xN3hFv5ZVXFtOnTxddunSplozPiUBkBEiARIaMCYhA7giQAMkdcmZIBIhAwRAgAVKwCqO6RCACAiRAIoDFqESACBCBOkeABEid/AFeeOEFsc8++4jffvvNLPFFF10khg0bVujS//WvfxVXX321WYYLLrhAXHzxxYUuTzXln3rqKdG1a9dKtAsvvFAMHz68cs0TIpAmAiRA0kSTsohANgiQAMkGV0olAkSgPAiQAClPXbIkRMCJAAkQJyK8JgJEgAgQAT8ESID4IVPC+2PGjBEnnHCCWTKsHoC7qO7duxeypFOmTBEHHXSQqfshhxwixo8fb64CKWRhQigNl1cdOnQQn332mRl7v/32E1OnTjVXgYRIzihEIDICJEAiQ8YERCB3BEiA5A45MyQCRKBgCJAAKViFUV0iEAEBEiARwGJUIkAEiECdI0ACpM7+AGeeeaa47rrrzFI3adJEzJw5U+yyyy6FQ0Gthjj88MPFPffcI1ZfffXClSGswj/99JPo3LmzeO2118wkO+64o8CKnrXXXjusCMYjApERIAESGTImIAK5I0ACJHfImSERIAIFQ4AESMEqjOoSgQgIkACJABajEgEiQATqHAESIHX2B1i+fLnAioknnnjCLHnz5s3F7NmzRdOmTQuHBDZAX3PNNQundxSFsc/JYYcdJiZOnGgm22yzzUzyA0cGIpAlAiRAskSXsolAOgiQAEkHR0ohAkSgvAiQAClv3bJkRIAECP8DRIAIEAEiEBYBEiBhkSpRvF9//VVg5YQiQdq0aSOeeeYZse6665aolOUoyuDBg8XIkSPNwrRo0cKsp5YtW5ajcCyF1giQANG6eqgcETARIAHCPwIRIAJEIBgBEiDB+PApESgyAiRAilx71J0IEAEikC8CJEDyxVub3LAZ+sknnyzGjh1r6gTXWBhIYdAHgfvvv1/07t3bVKht27bmKpBmzZrpoyA1KTUCJEBKXb0sXEkQIAFSkopkMYgAEcgMARIgmUFLwUSg5giQAKl5FVABIkAEiEBhECABUpiqykbRW2+9VVx22WXm5ujDhg3LJhNKjYXA008/LXr16iWOPvpoccUVV4g11lgjlhwmIgJxECABEgc1piEC+SJAAiRfvJkbESACxUOABEjx6owaE4GwCJAACYsU4xEBIkAEiAAJEP4HiAARIAJEwIUACRAXJLxBBLRDgASIdlVChYgAEdAMARIgmlUI1SECKSJAAiRFMCmKCBABIlByBEiAlLyCWTwiQASIQBwESIDEQY1piEC+CJAAyRdv5kYEiEDxECABUrw6o8ZEICwCJEDCIsV4RIAIEAEiQAKE/wEiQASIAB
FwIUACxAUJbxAB7RAgAaJdlVAhIkAENEOABIhmFUJ1iECKCJAASRFMiiICRIAIlBwBEiAlr2AWjwgQASIQBwESIHFQYxoikC8CJEDyxZu5EQEiUDwESIAUr86oMREIiwAJkLBIMR4RIAJEgAiQAOF/gAgQASJABFwIkABxQcIbREA7BEiAaFclVIgIEAHNECABolmFUB0ikCICJEBSBJOiiAARIAIlR4AESMkrmMUjAkSACMRBgARIHNSYhgjkiwAJkHzxZm5EgAgUDwESIMWrM2pMBMIiQAIkLFKMRwSIABEgAiRA+B8gAkSACBABFwL9+/cXH3zwgVh99dXFtGnTXM95gwgQgdojcPvtt4uxY8eailx++eWiU6dOtVeKGhABIkAENEIAfZjhw4ebGp188smiX79+GmlHVYgAEUiCQN++fcVHH30kGjduLKZOnZpEFNMSASJABIhAyREgAVLyCmbxiAARIAJxEOAKkDioMQ0RyBcBrgDJF2/mRgSIQPEQ4AqQ4tUZNSYCYRHgCpCwSDEeESACRIAIkADhf4AIEAEiQARcCJAAcUHCG0RAOwRIgGhXJVSICBABzRAgAaJZhVAdIpAiAiRAUgSToogAESACJUeABEjJK5jFIwJEgAjEQYAESBzUmIYI5IsACZB88WZuRIAIFA8BEiDFqzNqTATCIkACJCxSjEcEiAARIAIkQPgfIAJEgAgQARcCJEBckPAGEdAOARIg2lUJFSICREAzBEiAaFYhVIcIpIgACZAUwaQoIkAEiEDJESABUvIKZvGIABEgAnEQIAESBzWmIQL5IkACJF+8mRsRIALFQ4AESPHqjBoTgbAIkAAJixTjEQEiQASIAAmQOv8PLFu2TEycOFF06PD/2bsSqC2m/z+lfRUtKtEuIRKlKFFRoZCUtCihIqRIVLZOK+1Rh0ihdLRQRNtp0b4XoleW6qBFpVXJ7/7n8/zd570zz8wz2515Z+b93nPed+aZufd77/3cud+5dz73fr91lHLlymVzNMJV/aNHjyqYtDVq1EgpVapUuApHpYk9AkSAxL6JqYIxQIAIkBg0IlWBECAEfEWACBBf4SXhhECWIkAESJbCT5kTAoQAIRApBIgAiVRzyS1sRkaG0qpVK2XHjh1K69atlZkzZ8rNgKR5QmD8+PFKz549lWLFiilTpkxRWrRo4UkeJSYEnCBABIgTtCguIZA1CBABkjW4U66EACEQHQSIAIlOW1FJCQGnCBAB4hQxik8IEAKEQPZFgAiQbNr2q1atSnxQP3z4sJIjRw7lo48+Utq2bZtN0QhntXft2qXccMMNypEjR5ScOXMqo0aNUp588slwFpZKFTsEiACJXZNShWKIABEgMWxUqhIhQAhIRYAIEKlwkjBCIFQIEAESquagwhAChAAhEGoEiAAJdfP4U7gVK1Yod9xxh3LixIlEBiNHjlR69erlT2Yk1RMCmzdvVurXr6+cOnUqIWfo0KFK3759PcmkxISAHQSIALGDEsUhBLIWASJAshZ/yp0QIATCjwARIOFvIyohIeAWASJA3CJH6QgBQoAQyH4IEAGSzdp8y5YtSoMGDZLkR48ePZQJEyZkMxSiVd0PPvhA6dChQ7LQU6dO1fxO3qATQkAiAkSASASTRBECPiFABIhPwJJYQoAQiA0CRIDEpimpIoRACgJEgKRAQhcIAUKAECAETBAgAsQEmDhe/vXXXxUMEn7//fdE9Zo3b6589tlnynnnnedbdc+ePavMnz9fqV69ulKtWjXf8skKwb/99psCTA8cOKCcPHlSKVmypHLJJZcknMnnz59fapG6d++uTJw4MSEzb968yvLlyxOO66VmQsIIAQEBIkAEMEJ6+vPPPyswZ8gDdotdeuml/Ke04//+9z8FJvmg6w4ePKicO3cuoesqV66slChRQlo+JMg5AkSAOMcsTCmiOkYKcvwTpvaiskQTASJAotluYqn37NmjQO/g759//lHKly+vVKlSRbngggvEaJE4j1NdwgA4ESBhaIV4leHff/9Vtm/fruzcuVP55ZdfNH9YlDpgwIDQVzioOSLps9A/ClRAHQJEgOgAievPM2fOKDfeeKOyadOmRBUxaIR5pUKFCvlSZQxO3333XWXw4MEKFONzzz2nDBs2zJe8ghSKlwmcky9YsCDxUjTLGx8iu3TpknAuX7BgQbNotq+j/fBBeuPGjYk0pUuXVjZs2KCULVvWtgyKSAg4QYAIECdoBR8XOqF27dqJATrPfcaMGUqbNm34T8/HL7/8Unn//feVRYsWKX/++aehvGLFiikdO3ZM+CeqWLGiYRy66B8CRID4h62fkqM4RvJ7/IMPDtjh+vfffzuCHuOgFi1aOEojI/K2bduU1atXOxaFhTIwQ0shOASIADHGGosZ3n77beObEq9WqFBBadq0qWOJ6F/Tp09XPv30U2Xv3r2G6TEGgb/EZ555RmncuLFhnDBcjFNdwoCnWAYiQEQ06NwtArCSsnDhQgWm4r/++mvl2LFjKaLgkxXftIYMGZJyL0wX/J4jkj4LU2tTWRwjwCjYQkBdfc+uvfZaz3+1atVijRo1Yq1bt2aPPfYYGzhwIFM/prPjx4/bKofbSI8//jhTH47En+r0nK1cudKtqLTp1Ek9UwfTTF2Zk8wP+aovi7Tpwn5TdRbPVEKD5cqVS1MvjqnZ8cILL2Rz586VUj314wNTB/rJ/K+77jp2+vRpKbJJCCGgR0AlTBPPWpEiRfS36HcIEOjdu3dSF3D9oxIgUkqmEuVMJcBS5PN8jI7qpIC1bduWHTlyREoZSIg9BF5//fVkOy1dutReIoqVZQhEcYwU1Phn5syZyWfZSMeYXVN3xTJ1Z3PgbaoudHFVXnX3deBlze4Zfv7558m2Us3+Znc4kvWfNWtWEhez/iXjesuWLZN52jn5/vvvmUpqOi5bzZo1GeoUphCnuoQJV7EsKgGWeFbU3UDiZTonBCwRUBdcMHWhF7v++utN9c3VV1/N+vXrx1QinR09etRSZhgi+DVHJH0WhtalMnhFQPEqILukHzFiBFNNRZkqR68DxNy5c7N69eox1d8DU82NSIVVZbM15X7iiSekyocwTOrVHR9MXQGsyYvjEmUCRGW5mbpiz7BevH5Wx+eff14K5qo5MQYCi+enbsGUIpeEEAJ6BIgA0SMSnt+LFy/W6AGuD2QQINDj+fLlS+oYLltdYcdmz57NMjIy2A8//MCQ15VXXpkST91dyL755pvwgBXzkhABEo0GjuoYKcjxz44dO1ifPn1Y+/btEx8jnCw4wceJIANIYq4b7RwLFCjAsADqgQceYFOmTAmyqJSXigARIMaPwe233+7oObbzrBvFefTRR40LYHBV3QXGMCc2kmP32lNPPcXU3S0G0oO9FKe6BIucs9yIAHGGF8X+/+9WGD8XL17cUNeo1jtY165d2fr16yMHl19zRNJnkXsUqMAmCBABYgKM0WUMplSfD4mPQPfff7+hwjQanEGJqnbZGXYDGN3XX6tRowYDaSEjYNJ9xRVXJPPFzowTJ07IEJ2QA
UzAnKu24JN56OuD31ElQFQTMAyrC43q5PTaO++8IwX3hx9+OFkefKjcvXu3FLkkhBAQESACREQjPOeqKSqmmnxJ6gBRD3klQMaMGWMod9y4cYbEPN4vIHfFMuBcNa3I1C3k4QEtxiUhAiTcjRvlMVJWj3+w61X1f5aiX/T6Br+xO9bvndTik6baALdVLnzIBalz6NAhMTmdB4wAESCpgP/000+GCymM+pfXa1g5bSe89tprtvqVnfKoZuYSi/Ps5OtHnDjVxQ98ZMokAkQmmvGXpfpOZPjWZqRH8M7G7omo7PTQt5Zfc0TSZ3qk6XeUESACxEPrYUWLkfLk11RH2Axb+vGRiAdM0LA69q233mJlypQxTY/dJqrdU57M9VH1V6HJ46OPPnItS0yo2mpO7FapWrWqRj6vu/4YRQJE9bfB0Ib6usDsAUyKjRo1iuGDYbdu3RK7d8SdGfo0+A0iBQSa14Dth2Jed911l1eRlJ4QSEGACJAUSEJxAeYTjfQLrnkhQLDS22jV5aBBgyzr3blz55QylSxZkqk2uy3TUgRvCBAB4g0/v1JHfYwUpvEPxjhmOk+8jjFZEAHmtvLkyWOrTD169AiiSJSHBQJEgKQChF1TYv/x67xo0aJMtUefWgDdlTfffFN6eYLeGcarFKe68DqF+UgESJhbJzxlg9lwmJ8Xv6GIeu+2225jqtPz8BTYRUn8mCOSPnPREJQk1AgQAeKhed577720gzWsXEsXTp48yfr3728qAyTInDlz0olIew92m2EPkyv3yy67jGFS7jWgTNWqVUvKRTmNTKbwfHGMGgFy6tQpTR1RB9htTrcVEiueRVzE+vNzWYNx2NPlMnHE5I4CISATASJAZKIpRxZMp4j9Xn/ulgA5ePAgu/jii1Nkw1+VnXD27FlWqVKllPSqk3YG+7oU/EOACBD/sHUrOepjpLCNf0DO6nWd0W/stBYXHLltP6t0MD1qlL/RNayyp5D1CBABom0DvLMvuugi28+x0bNt9xpM2lkF1eGwod5zWz8AABG9SURBVI/Fhg0bMtW5eWLR4LRp0xj63t13321qbllfJvgmW7JkiVX2Uu/HqS5SgfFRGBEgPoIbE9F//PEH48+JkZ7AWDrqwY85IumzqD8VVH4jBIgAMULF5rXRo0enHTwuWrTIlqR0q9vgewIDVTehZ8+emvJh8Cgj4KMWVgrfd999iRXH2CYIvyW9evXS5Ce+YKJGgAwZMiRZFxA8dm0242MfBudi3cXz0qVLy2iChBN7US5MkNGHRinQkpD/ECACJFyPAkzdFS5c2FS3QB+4JUDM3kFr1661DYLZwJv8FNmG0FVEIkBcweZroqiPkcI2/oH5KHG8k+78ww8/9LVtsYK0RIkStsqDXSIUwoEAESDadvjkk09SnuFbb72VTZo0ic2bN4/h3Q8/N07/GjRokCLXaiEfFubBybDYr7GYDCb40gXITWdJgcurU6dOOjFS78WpLlKB8VkY/7BNTtB9Bjqi4rdv386wQILrBPGIHWoLFiyIaM0yi+3HHJH0WSa+dBYvBIgA8dCe6UxgweGh3Q/SMBMCm+miQhbP4ZTWafjuu+80q2nwgVyWQzgQO2DS9QF2B822FUaJADl27FjSXwuIHpgxcxIOHDiQdoL8119/ORFnGpcP+PizAkKOAiEgCwEiQGQh6V0OdDdvD97fjY5uCBAzZ75wjuokoIwVK1ZMeY/h3QadSMEfBIgA8QdXL1KjPEYK4/gHk3Cs5DbSefprNWvW9NJ0lmknT55sqxwol6wFL5aFogiWCBABooWoSZMmyecY/nOwytdrwEI4/U5S+MDEjrJ0QW9NoVmzZrYX/mEBHhwV6/WA/jc+gAYR4lSXIPCSlQefDxMBIgvR+MjBDjCzxWN4XmCWPurBrzki6bOoPxlUfjMEiAAxQ8bGddG5uH6wBedrTgLsDupl8N9YleM0wC8FT4/jsGHDnIpwFV80uSXmHyUCRPSbAgfvbsLUqVM1+ItYbN261Y3IlDT6FVxYMUWBEJCFAP/gXqRIEVkiSY5LBETncyDX9SbwuH5xQ4C0a9fOUFfZ3fUmVsnIITrKBhMWFPxBgAgQf3D1S2rYx0hhHf+cf/75hnqK6z7xaHf3tZs2NnOcKubPz6tXr+4mC0rjAwJEgGSCipXCfLEa9NHmzZszb3o4MzJVB0sBVgE75nifgSlNmG92GpAPl2F0hIWCIEKc6hIEXrLyIAJEFpLxkvPtt98y7PAw0gkw3b5y5cpYVNivOSLps1g8HlQJAwSIADEAxc4lbMnnA0gjxTp27Fg7YpJx4FTbSA6uYduek4At+uJkESvn9u3b50SE67jlypUzrEeUCBC+FRumrNyG/fv3G+KA9nRiViZd/lgViZ094nMjYxVXujzpXvZBgAiQcLT1unXrNLv54IwO+lTs9/zcKQHy66+/amRzObly5WLY0ec0QLdxGeIREw3oRAryESACRD6mfkoM+xgprOOfCy+80FC3iHqGn2NBkR8BK0l5HjhibG5mPhD3r7rqKj+KQTJdIEAESCZo4kKFuXPnZt7weGZkBnn69OlppWL1Ne9TWNzhdqcGrBJgJwuXpT82btw4bTlk3IxTXWTgEaQMIkCCRDsaeWHneYUKFQx1Ar7fYRFpHIJfc0TSZ3F4OqgOZggQAWKGjMV1DOr0Ayzx965duywkaG+PHDnSVB58UDjxAwIbyGJZ7Dqy1ZbI3a+wT+6tarVt27YEdphsG5n5skov3jdzMOhmdZMoVzzXPzedOnUSb9M5IeAaASJAXEMnLeGJEydYlSpVkvq8efPmCdmyCJDevXsnZYvvjFtuucVVHWACw0zvgbihIB8BIkDkY+qnxDCPkcI8/hEJEOzCSLcACbpM1k5b8VnQkx1vvPEGu+eeewx1KMpABIiIXtaeEwGSiT83U9W5c+fMix7P8O7X6zYsfIBJvXThlVdeSfafgQMHpotqeU9vrkUc02CxmN8hTnXxGyvZ8okAkY1otOXBBH29evWSukXUBTh/4oknol3B/0rv5xyR9FksHhGqhAkCRICYAGN1+aGHHjJVrLCD7jRgIqVX0OLv33//3bZIEB5iWjemTGxnpouoHwDzckRlBwjsycIJ4LJly3Q1c/6zatWqmnYAFmXLlnUuKE0K/sGC44wVVLJ8jKTJlm5lAwSIAMn6Rn7kkUeSOqR48eKMvwdkESBmZEXfvn1dV75FixbJMnO9hKMbU46uC5GNEhIBEq3GDvMYKczjH5EAadu2LTPTM1zntG/fXuqDkZGRofFDAtOQGGsRASIVZt+EEQGSCS12ij744IOW5ERmCuuzNWvWpLz3QRhahZtuuimRDuMbK7LEStaZM2cMd7RCJ+TJk8cquef7caqLZzACFkAESMCAhzy7F198MUUf8bEBCGCvuiYs1fdzjkj6LCytTOXwAwEiQFyiCueGXJnqj927d3csFTbS9XLE33Y/av/yyy8pK+OCMn+FSod5cu+4UTwkgNM/7NwR2xDnTn3DWBUBZrBEc2fIg1ZaW6FG9+0gQASIHZT8
izNnzhyN/sBvHmQQIN99951GvqirrMxW8HIYHbGKU5TFz6EPyRm6EWLerhEB4g2/oFNnhzGSH+MfPQGyYsUKQz3D9Q3M+O3Zs0da8/bs2VOTH/drRASINIh9FUQEiK/wJvx88b7Hj1aL7zCvRT9F/JdeeklKAc18c2JxmJ8hTnXxEye/ZBMB4hey0ZMLH0d58+bVvK+5TsLx008/jV6lDErs5xyR9JkB4HQpVggQAeKiOfWr7kXF6la56ndtiDJhAsVuGDJkiEbplylTxm5SKfGyw+TeDlBGzgDRpgsXLrST3FEckCri89KwYUNH6SkyIWCEABEgRqgEcw07PbAikvfrLl26aDKWQYCIzo55Pvy4c+dOTX5OfsyePTtZbi6PHydPnuxEFMW1gQARIDZAClGU7DBG8mP8oydA0KR16tQx1TXQObIcH2NnTKFChZJ54aMtFhshEAGSgCH0/4gA8a+JYP7qkksuSfYP9L3cuXNbOjOHqWj40cHHSq8mh3ntsDuMjzfEY/Xq1XkUX45xqosvAPkslAgQnwGOkHj4bxX7vnhet27dCNXEvKh+zxFJn5ljT3figQARIC7acfjw4abKFYO+48ePO5K6d+9ezdZ6UVnj3MkkDs4fxfReHHk7qsR/kbPD5N4OLoMGDdK0A9qkVatWdpI6jjN06FBNXvnz52fYCk6BEPCCABEgXtBznxYfE5o2bZrs0zCpqH+nyCBAoI/EdwU/h6kI7CxzGzBw5rL0Rze7I92WI7ukIwIkWi2dHcZIfox/jAgQODHV6xjxd+HChdmRI0c8PyB6E7X3339/UiYRIEkoQn1CBIh/zbN27dqUfoi5qN2AHWOyQr9+/VLKAp0ge/e9WXnjVBezOobxOhEgYWyV4Mu0aNEiw/7PxwVedrcHXxvjHIOaIyJ30mfGbUBXo48AESAu2jDdbg03zmOx9ZcrZ6PjkiVLbJXy3LlzmlVqkDV48GBbaWVFyg6TeyusfvrpJ1awYEFNm8IZpv4jppUcu/dXrVqlyQvtjmsUCAEvCBAB4gU992nHjRuX7M8wG2XUl70SIBhAix8UxfcOVnJ6CSdPnkyWX5SLczglpCAXASJA5OLpt7S4j5H8Gv+I+gqrvBEw5q1UqZKpvoHOwa5oLwF5lC9fXpPHunXrkiKJAElCEeoTIkD8a57evXtr+gf6HXwpZkXo2rVrSllQnig6PY5TXfx+FogA8RvhaMhv1qyZYf+HDoBFlLNnz0ajImlKGcQcMU32rm+RPnMNHSX0AQEiQByCeuLEibS2BbEa30nYuHFjwjkblLPR37333svwscpO2LBhQ4qMxYsX20kqLU7cJ/dWQKGtQIKJbVm7dm126NAhq6Su72O3R758+TR5ep30uy4MJYwNAkSABN+U3377raYvw5GfUfBKgPz4448afaHXV0Z5OrkGB8GiTH6OFdl232dO8svOcYkAiVbrx3mM5Of4x4gAQctPmDDBUNdwnQN/fV52xM6aNUsjH+9FMRABIqIR3nMiQPxpG/R5mLHi/Q1HLNzYv3+/PxlaSNXPv3i5pk2bZpEyfLfjVBe/0SUCxG+Ewy8fZqGMfK9yHfDCCy8kK4Fd7jBnD5+p8Fv49NNPMxC5w4YNY/BdhDlSGENQc0Q/6k76zA9USaZbBIgAcYicOIjmSlU8bt261bZEOGisXLmyZuAoyrr88svZsWPHbMvTb9PPkSMHs+s83XYmFhHjPLm3qHritt4cFVYqgjTzOzRo0EDzHDVv3tzvLEl+zBEgAiTYBsZHumuuuSbZj2vVqmW6WskrAZLOgXDLli09V7xq1arJeojvNJxjhTgFeQgQASIPyyAkxXmM5Of4x4wAwY4z0V+SXt/gtxffQ/Xr19foMhAiYiACREQjvOfi3A2kGQU5CGA3lL7P3XzzzXKEu5BipF+LFi0q1ZSLi2K5ShKnurgCwEEiIkAcgBXTqOJYWK+T8HvZsmUJSxyjR49mFSpUSNFbYhp8P4MeAxmCMUYYQpBzRD/qS/rMD1RJplsEiABxiFzPnj1NlSZWmtkNX3zxhakJEihhfEDKyMiwKy4RT+/4CQRK0MFIwaE++GAX54DVBPqPknBeHxQBhZXi4ssbA34vdvzj3FZUN3sIEAFiDydZsZ599tlkH4Yfn3SOyPW6hvf9GTNm2CrOxx9/nMyLp+XHxx57zJaMdJEwceDy9Ee7Jh3Tyad7mQiIk76lS5dm3qCzUCIQxzFSEOMfMwIEjWxlRhYOkN3sPNu0aZNGj8Hcln5cRQRIKLtZSqGIAEmBRMqFPn36aPoI3vdjx46VItupEOw6yZkzZ0p5ouh7LE51cdqObuITAeIGtXilqVGjRkrf5/MP+DYcP348K1WqlGkcHld/xM71MJDmQc4RZT8ZpM9kI0ryvCJABIhDBNOtbO3UqVNaabA9uGDBAqZ3VC4qW7DOkONk5wfPVF82q/LwdDKPcZzcW+GDLYl169Y1fKmCiIA5qtOnT1uJ8XQfjr3E5wjne/fu9SSTEmdvBIgACa798eFanLhjoJ4ueCVAsAJKry/4b3zQ8BruvPNOU/lz5szxKp7SCwgQASKAEYHTuI2Rghr/pCNADhw4wEAacx1mdJw3b57jp6NDhw4amUYfdokAcQxrliQgAsQf2PX+cTCHzaq5h/guFHUAiMyohTjVJQjsiQAJAuXw5gGLKmKf9+N80KBBWQZA0HNE2RUlfSYbUZLnFQEiQBwgCNMd6ZQq2Nk1a9aw5cuXs6+++orB5ujw4cMTtgVhoijdBA332rRpk7BJ6KBIyahYlZY3b15N+aBwgg5xm9ynw++bb75hmCDnypVLg7vRM1KzZk22b9++dOI83Vu4cGFKGVauXOlJJiXO3ggQARJM+x8+fJiJerNp06aWGXslQPr27ZuiL7jeMvM7YlkoIUKrVq1M5WNLOQV5CIgTC9oBIg9XvySJfZ33ORzRp6MUgh7/pCNAgFu3bt1MdQ7wxRjcSYA9cawa5W1UrFgxQ3OmRIA4QTXr4hIBIh/79evXJ/sH7yf4EJ1VAVYPeDn4sUmTJllVHE/5xqkunoCwmZgIEJtAxTQaFlbxPp/uCP+EGCuMGTOGYVEEFiXPnz+fvfzyy2lN0nOZMhaIOW2CrJgjOi2jVXzSZ1YI0f2gESACxAHiEydOtKVguaK0OmIrXrt27djUqVNd7fgQi46P6/r8vNg9FmU7OY/L5N6szrAFifaCMyesdNJjnu53mTJl2JYtW8xEe7q+cePGlLKgnBQIAbcIEAHiFjln6UB8c72Bj3y//fabpQCvBEjHjh2TefK8+fHVV1+1zN8qAt5rXJ7+aLSK2koe3TdHgAgQc2zCeCfKY6SsHP9YESC7du3S7KLT6x38hr8Cu2HAgAEaHQbS2CgQAWKESviuEQEiv01Ekyy8v40YMUJ+RjYkrl69WtNfUZ4CBQqw3bt320gdrihxqktQyBIBEhTS4cynf//+Kf2f6yQcQXxgbnPkyJG0FZg0aRIrWLBgWlkgT4IMWTFHlFk/0mcy0SRZshD4PwAAAP//qOnQrQAAQAB
JREFU7J0J/A/V/v8PSdaKSoUSpZCUKDe3UFFCuEkhbSoh+qXbTRtpT92oW9JyxZXSQqJuJEtIkRKKFi3a0CIUSaX5n9f873yaz2yf2T8z83mdx+P7nfnMnHmfc55n5j1nzvuc9xEKg2sCf/vb3xQhRGh/TZs2VSZNmqT88ccfrvNgF3HhwoWmfE2ePNkuemTHDzjgAFM+wOzqq6+OLM04Be+///6W5XN7X9SsWVP59ttvQ8/yp59+asrXzTffHHo6FFg6BP7617+q99Tuu+9eOoWOuaSPP/543nM7ZcoUVzmAPrXSOU899ZSr67t162Z5PWTeddddrmQ4Rbrwwgtt5d9yyy1Ol/KcRwL//Oc/c6znzp3r8WpGj5tAmttIxWz/7LXXXrn7vEePHpbVdsYZZ+TiWOnHM8880/I648Ht27cr++yzT07Wrrvuqnz11VfGaOpvp++CI444wvIaHoyfwH//+99cfY4ePTr+DGQwxbp16+aYas/bJ598UpSSWrU57r777qLkJWiiWSpLUBZur//LX/6i3ovVq1d3ewnjZYhAhw4dTLpI00kVKlRQ3njjDdelXb58uYLvXu1647ZWrVrKjh07XMsLErFY34hB8my8lvrMSIS/k0BAJCETacjDb7/95qgQW7durVxwwQVK7969la5duypt2rRRGjVqpFSsWNFWiWpK9cgjj1RgwAgSJkyYYEpn9uzZQUT6ujbNH/duCvyPf/xDwUd2586dlRNPPFFB3VWpUsXEXqtbq+3pp5/uJilPcTZv3mzKA146DCTglwANIH7Jubvus88+y3un4P3hNgQ1gKAD0Uo34dioUaPcZsM2Xt++fW3lX3/99bbX8YR3AjSAeGdWzCvS3EYqZvvHjQFk0aJFtnoHuq1s2bLKxx9/XLD6x44dmyfnnHPOsb2GBhBbNIk6QQNIuNWxdOnSvGcEz9dRRx0VbiIupaHDcpdddsnLT7NmzZTff//dpYTkRMtSWeKkSgNInLSTlxYGFNt90+B7xGuYMWOGrTyk88gjj3gV6Tl+Mb8RPWfW5gLqMxswPFx0AjSAuKyCBQsWOCrDlStXWkrC7I73339fGT58uKMxBI23IKNVRo4cacofGqhxhzR/3PtlhTqGkh88eLCy2267merB6qUcdt0gD8YPgI4dO/otEq8jAYUGkOhuAnyYH3/88TldgZGUP/74o+sEgxpA+vfvn0vbqJ+CvIe0AliN+NHSufHGG7Vo3IZAgAaQECDGKCJrbaS42j9uDCCoRu29pekb43bAgAEFa7tJkyZ5+vHtt9+2vYYGEFs0iTpBA0i41WHVBinG7E7on5YtW+Y9r5UrV1ZWrFgRboFjkJalssSAKy8JGkDycJTcD6vZaNq73+9gYP03miZL2x588MGRGliL/Y0Yxg1EfRYGRcqIigANIC7JYtSqpviMW0yVx4NeKGDkGUalGK/X/77vvvsKibE8f+utt5rkrlmzxjJulAez9nHvldXatWuVc889VylTpoypPvT13LNnT6+iC8bfe++989LEDBUGEvBLQOtIogssvwTtr7vttttyzypGJb/22mv2kS3OWHU+QL+4dYF1ww035NLX6yXsh9GJ0atXL1v59957r0WJeMgvARpA/JIrznVZbiNF2f5xawB5/vnnbXUP9BtmZX/33Xe2lT9nzpy86zGb2ynQAOJEJznnaAAJty7q1auX95zg2Vq9enW4ibiQ9u9//zsvH2hPTZs2zcWVyYuSpbLETZcGkLiJJyu9PffcM08P6L9rvv/+e1+ZfeaZZ2xlQj5miUQViv2NGEa5qM/CoEgZURGgAcQl2ebNm9sqwrPPPtulFEX5+uuvlf32289WFj7OPpOuUbyG6667ziQzirUmCuUryx/3hcquPz9mzBhTfehfyHvssUfoowcOPfTQvDRbtGihzxL3ScATARpAPOFyHRmzv+BTXtMH0N1eQ1ADiNWMQS0/MI4EDU5rjDz77LNBxfN6HQEaQHQwUrBbCm2kKNo/bg0gO3fuVA477LCcftX0mn6LGdl2AS5K9XGnT59uF1U9TgOII57EnKQBJLyqeOutt/KeETwvcPkcd9i4caNiHPiV1gEWWSpL3PcB0qMBpBjUk5MmDJ/697a2X758ed+ZhOv72rVrW8qF/Kh0TRK+EX1D+9+F1GdBCfL6qAnQAOKCMAwJTiP6vfoC1DfENSWt33bq1MlFrvKjXHHFFSYl/euvv+ZHiuFXKXzcu8U4YsQIU53o6xkfEWEGrQGopcEFOMOkW3qyaAAJv863bduW1zl39NFHK370dFADyPjx4211E/z8Bw14h2l6yLh9/fXXg4rn9ToCNIDoYKRgt1TaSGG3f9waQHALoE1u1Dv63+g0/fnnn013C2ZN6ztSYEgpNLubBhATxkQe0H93cRH0YFU0ZMgQ0/MVxsAJr7k666yz8vIxcOBAryISEz9LZSkGVO37l4ugF4N+8dO0cz9erVq1QJnTz8TQtyGwf/nllweSbXVxUr4RrfLm5Rj1mRdajFsMAjSAuKD+xBNP5DWyjErQ64wNfFBZTR/Wy/XaOX7ppZfm5REvg2KEUvm4d8sWs4P09arfHzdunFsxruK1b98+Ly10YDOQgF8CNID4JWd/nV5PY7afX5cRQQ0gL7zwQp6u0Oulfv362RfA5ZnWrVvbyoebHIbwCNAAEh7LOCSVUhspzPaPFwPI9u3blRo1atjqIOi7Bx980FTdgwYNyrsGM1kKBRpAChFKxnkaQMKrB6vv13feeSe8BFxIuvnmm/OeVQy6SOOi5yhqlsriouoiiUIDSCRYUyO0Vq1aefpA+6aBt40gwegSU5OLrZ/ByoXykpRvxEL5dDpPfeZEh+eSQqAMMiIfZAYHAueff76YMGGCZYz69euLjz76yPKc08GbbrpJyGn4tlGk7zxx0UUX2Z43nrjsssuE/KDLHS5XrpyQ0/dyv+PaOfDAA8WXX35pSk522Ak5ItB0POsH3n33XXHkkUfC0GgqqlxsWFx11VWm434PnHLKKeKVV17JXd6gQQPx/vvv535zhwS8EJALwIlFixYJuQaI2LJli5dLGdeCwMyZM8Vpp52WOyNHyAjZ4Zb77WVHdsyJJ5980nQJ3iknn3yy6XiFChWEdOOYO/7pp58KuYhf7rd+p2vXrmLq1Kn6Q5735chpy/einEkpduzYIaQLMM8yeYE1gXvuuSf3Hpk7d66Qaz9ZR+TRRBAopTZSmO0fOWtDSLcKah326NFDTJo0ybE+5bp4YujQobZxoP/QdpczPtQ4eMdJdxdi69at6m9pcFHbstJQbSsDJ8444wxbfSln4YqVK1c6Xs+T8RB46aWXRMeOHdXE5AwQMWDAgHgSzlgqy5YtE3Ity7xS4VmSa1zmHYvyx5QpU0T37t1z31UnnHCCQPuqUqVKUSYbiewslSUSQC6FHnfccWLx4sVCzgDJvSdcXspoGSCAfhard23QvjC0EfA9YxWk2z+xatUqq1O+jiXpG9FXAeRF1Gd+yfG62AkkxRKT1HxgtobTmh1+p8DBD7qsbNu/wY
MHe0IiO9JNsuSHnCcZYUQupdGNbnkZfUpr9X7ttde6FeEq3rHHHpt3D8C9DgMJ+CXAGSB+yVlfh8XFtWe/GNv58+fnZUw/olqfnzDWDpJGM8uy1qxZMy8P/BGcAGeABGcYp4RSayOF1f7R6ytpAClYZVj4tHLlypZ6SNN3+vWIpCExL+71119fMA1E4AwQV5iKHokzQMKpgmuuuSbvOcGzFIbbTLe5kwYYRRo6cnk45phjFGm8dHt5ouJlqSzFBssZIMWugeKmLwf+5HSC9n7XtnLQle/MoR9Nk2Pcon0RZkjaN6LXslGfeSXG+MUkQBdYBegvX77cVvlBGcpRRQUkWJ9+4403HOWeeuqp1hfaHB02bJhJ3rp162xiR3e41D7u3ZCcOHGiqW5w7/hZ/NgpPePCn3IEv1N0niMBRwI0gDji8Xyy2I3bV199NS/PRpd5WuO+Tp06efG8/oAPW02WcStHbXoVx/gFCNAAUgBQwk6XWhsprPaPVwMIqt3o0sqojzRjL1znHHTQQTm9BRey69evd3Xn0ADiClPRI9EAEk4VyNkeuedEe57kyPtwhBeQgmdSrz+xziEW201jyFJZksCfBpAk1ELx8tCzZ0+TXtL0k9t3uV3usY6IJku/rVq1qt0lvo4n7RvRSyGoz7zQYtwkEKABpEAtOC3kKN2KWC6kWECkenrJkiWWClVTrl47r++8806TPOn+yE1WQo2jb5xqZcEWPutLNdjV9R133BEqkn333TfvHpAusUKVT2GlRYAGkHDrO2mNWyujOXR1+fLllZ07d/ouvJwynqeH9O+Bhx56yLdcXmhNgAYQay5JPVpqbaSw2j9+DCCffPKJsssuu9jqI+gmzIyTbhvy4lxwwQWubx8aQFyjKmpEGkCC48cIX/37HPvQZ/CUEHX44YcflCZNmuTSx4CvDRs2RJ1sJPKzVJZIAPkQSgOID2gZumTUqFE53WDUUQsWLAhU0saNG1vKrlu3biC5xouT9o1ozJ/db+ozOzI8nmQCNIAUqJ2TTjrJUvFBwXqdpaFPymkRWsjG4pFegvRpa8pnXKNy9PkstY97fdnt9vFyML6Q8fuZZ56xu8TXcRjk9Ongw5yBBPwSoAHELznr65LWuHV6B/ldnB0lnzx5cp4e0uskGEcYwiVAA0i4PKOWVmptpLDaP34MIKhLudaSrT6CbsJCpnINgbw40pe469uABhDXqIoakQaQ4Pjhtlf/Pse+XzfQXnLz008/KVoHN9LEbC251qQXEYmJm6WyJAaqzIh2f8g1QJKULeYlJgJ2Ay2gL+SauoFyATd7Rr2H3zgeZkjaN6KbslGfuaHEOEkkwEXQpRazC1gMUX50iV9//dUyirQ4iyuuuMLyXKGDstNCSL+pttGwODYWyXYbZGeW6Ny5c170l19+WWBh7DhDKS3w6ZarfEGoC0kb469Zs0YccsghxsO+fmNhYSx0rA/yw0Tcd999+kPcJwHXBLgIumtUriJKo4KQbu8w6MBVfKdI0B1yhp8pChYnrVWrluk4Dtx+++3i8MMPz53Dor9yfSvxyy+/5I5pO1hgXU4p13562mLhYSxAbAyy41d88cUXxsP8HZAAF0EPCDDmy0utjRRW+8frIuhatS5dulTI9dG0nwW37dq1E7NmzSoYT4vARdA1EsnechH04PVTv35902LncgaVaNWqVXDhNhLwbdOhQwcxd+5cNYZcR0wsXLhQ1KtXz+aK5B7OUlmSRpmLoCetRuLNz2+//Sb22GMPsX37dlPCXvvTjAKkwVV8/vnnxsNqn9u0adNMx/0eSNo3YqFyUJ8VIsTziSaQRKtMUvLkNEJWVqqyatUq31l1WrAJsv/zn/94kv3uu++aLNRhzzBwk6FSG93ohonVei/wKRnmtPFvvvnGVP+YEspAAn4JcAaIX3LRXwf3eXhPGP+eeuopT4nbjY4O4rIQI6qN+cLvCy+80FPeGNkdAc4AcccpKbFKrY0UVvsHI3s1veJ1hnSbNm1y12oy7LYzZszwdKtwBognXEWLzBkgwdC/8847pmcIbneDuMsslCPZqanIgX25dPfZZx8lqGtnOaBR8TLDq1Ae3Z7PUlncljnOeJwBEiftZKbVunXrnK7Qv9+bN28eKMNYE0wvT9uHi/ykhrC+Ee3KR31mR4bH00KALrAcamrgwIGWSg/KT45CcbjS+RQ6q3fddVdb2ZUqVVJ+/PFHZyGGs5iGpillbVsMf+ul9nFvqAbLn48++qipbi677DLLuH4PWvndf/755/2K43UkoNAAktybIKzG7YsvvmjSTXh/tGzZ0lfhsZiwHKltKRMdoQzhE6ABJHymUUostTZSWO0fLDiqtW29uve003OaPG0rZ8h5HphCA0iUT0t4smkACcZSzl7NPX/a83LppZcGE+pwNQwrvXr1yqWJQWPLly93uMLdqWuuuUaVCbfRcYUslSUuZl7ToQHEK7HsxR87dmxOX2g6CtsyZcoo3377ra8Cb9y40VIm5L755pu+ZMZxUVjfiFZ5pT6zosJjaSNAA4hDjcnpvraK79xzz3W40vnUgAEDbOVCqfbu3dtZgM1ZY8cTDDhxh1L7uHfDd9CgQab6DqMhr097zpw5pjRWrFihj8J9EvBEgAYQT7hijRxW4xajeGrUqGHSHfhgWL9+vecyzZs3zyQL7zTpTs2zLF7gjgANIO44JSVWqbWRwmr/6EdhSpc4nqoTs20bNWpkqZugn7Q/P77Cu3Tpkrtek6NtjzjiCE/5ZOToCNAAEoyt1fewdBUXTKjD1f369cs9V1WqVFHCWNNSuutSypYtq3aISjeiDqmHeypLZQmXTHjSaAAJj2VaJW3btk2RbrByekN7D2M7fvx4X8Wy8q4CeViTDAO+khrC+ka0Kh/1mRUVHksbARpAbGrsk08+sVSimkKFpdlPwDTicuXK2crG7I+PP/7Yj2h1QSYtf9hKn5i+5AS5KM6Pe7gJgysEdK7JtViUJI4wxmwf4wu5Y8eOQRBbXmv1svM6i8hSMA+WLAEaQJJb9VbPO3S+VxdYKOHgwYMt30cPPvigZwBWnZ3IF2ejeUbp+gIaQFyjSkTEuNpIcm0fZdGiRcrjjz+uLFiwQIGxM+4QVvtHrseXp6P8uLR47LHH8mTo28rYhyEYzLyG0047zVauXOPNqzjGj4gADSD+wWLAlvF5gUu6qHQKXHBq6VWsWFF59dVX/Wf+f1fCddb++++vyj355JMLygtLfyahLAULm4EINIBkoBJDKAK8a2i6Q789+uijfUm/8847LeXJtQ5dyQtLj7hKTBcpzG9EnViF+kxPg/tpJkADiE3tycWjLZWeplA/+OADmyvtD69du1Z1naXJsNo+8MAD9gIKnDEqfhhT4rZQGzv7tTIG8SlvVexLLrnEVD9wK4Z1W5IULrroorx8YiSTXEwr9CzKhTjz0mnYsGHoaVBgaRGgASS59R1m4xadlHAvoelqbeu1kxGdlPDRrV2vbQ899FDPbmWSSz55OaMBJHl14pSjONpIb7/9toLOd+0ZxBbPIVxlxhnCav8YX
Xxi7QGvAR0RWgeonou2P3z4cK8i1fgYgKPJMG733HNPXzJ5UfgEaADxz/T666833ePnn3++f4EOV9522225tMqXL694XZPHKBpGmqefflqBztCez0LrY4alP5NQFiOPrP6mASSrNeutXB9++KGiny2qPfPYwlOGl4CZo8Z2FOTAKOvGpVZYesRLnrW4YX4jajKpzzQS3GaBQBkUQj7QDAYCLVq0ENK/n+Ho//+5yy67CNmoEtJNiOV5q4PTpk0TcpaCkEYQq9PqMdmgFOPGjfMkVy9MNvJEjx499IeEnL4nGjdunHcsqh/fffedkKPoLMWj7HJRbstzXg8uW7ZMNGvWzPIyafQR0nWL2H333S3PuzkoRxsJ6QpBvP7660K+/IT8wBXSbZmQLsbcXJ6L89ZbbwncR/IlmjsmZw6JPn365H6HtSNHlYqvvvoqJ65v377i4Ycfzv3mDgl4JYD7Xo4gVp+lLVu2eL2c8SMkIEcliWuvvdaUgpwBIuSsONPxQgceeeQRIf15m6JNmTJFSOOq6bjVAdk4FjfccEPeKbwjZceTkKOk847zR3gE7rnnHnHVVVepAufOnStOPPHE8IRTUqgE4mgjyfXghFzLQnz55ZemvEsjiJAf5UIOxDCd0w4ksf3zxBNPCOkaVsuiupWztEW9evXyjhX6Yac3K1SoIL744gshDbiFRJjOG9te+gjQf3LRZSFnfesPc78IBF566SUhZ1+rKcv1H9Q2fRGykcokDzvsMCGNkHl5nz59ujj99NPzjgX9IQcACjmLNE+M3/cZvrt++OEH9bto06ZNOZn4RsW3khwwlzum3wmqPzVZSSiLlpdS2EqPF0K6SRNyZpKQ6zaUQpFZRhsCt99+u5BGW9NZ6ZJS7dfD+95NQFvISv/gO+eWW25xFBGWHnFMxOGkXVvH7zci9ZkDbJ5KJ4EsWHHCLsOqVatyI0VkrVruux3FDwtw69atLWXoZV9++eWBR8l+/fXXpnTgJiqu8Nprr5nS18qIGQphBcyS0eRabceMGeM7KYxithqhiZkbmPLo1q0UpmwbRxuOGDHCd76cLly3bp2Jx8SJE50u4TkSKEiAM0AKIipahLBH92CkE1wmGvWp7PhQoBMLBdnZaqk33U4TLySf5+0JcAaIPZuknYmjjTRhwgTTc6x/ruVgHFssSW3/YM0PfRmwf9NNN9mWw+6E7AhV0JYzysKMYj/Bbs0jvXy6//NDNvxrOAPEH1OsJai/n7FftWpVX+7inHKAb1WsPWZMK+zfQ4YMccqGEkR/aoKTUhYtP6Ww5QyQUqhld2XErC+4vLLSHW7XxoWMU0891SQD30k4VyiEoUcKpeF0PsxvROozJ9I8l1YCdIFlqLmdO3eqa0pYKU79MTkDQYFPUauAjyw0trFQeqEGHVyPyJH6VmJ8HTv44IPzFDYMK3EEcGvXrl1e2npe8Hvt1nhQKL/33nuvbTpIs23btoVE2J6HD319vo37MGrAuIAOQ6uA43L0gSJnCeXkYBo3OqmiCvjANubTrYEuqjxRbvoJ0ACS3DoMs3GrlRLvMysXVvBx79Tgx/tOzjI06SC8D/BeYIiWAA0g0fINS3pcbSS9j2ZjuwC/r7vuOtsiJa39A70DQ4dVOWDIkDPUbMtid8K45hHa6KtXr7aLbnt8+/btlh0kxrziWwGLszIUlwANIP74y9HOpudPehrwJ8zmqueeey7vm8n4DIX1G896oTU2g+hPFC9JZbHBncnDNIBkslp9FwrvdL3bO70OQRvA6dtkx44dSteuXU16D/0/n332mas8BdUjrhJxiBTWNyL1mQNknko1ARpA/ld9cpq66h+wZcuWJqWnV5zG/SZNmigYnSanAquGE/hZLmT0gAx0kMOHKvyvhxkuvPDCvPyjEzPq8OmnnyoXX3xxXrpGTvgNTlhML2hYunSpY1r169f3ncTMmTMdZWvlwshoOQVSXdx35cqVCkZ23n333cpRRx2Vdz3ujzDK7FQgo3/eOnXqOEXnORJwRYAGEFeYihIprMatMfP4aDDOXIPOO+WUUxSrda+wyLJR5yF+06ZNFenuxyievyMgQANIBFBDFhlnGwkzTbV2itX2rrvusi1dEto/6JjAAA4YY6BHrMqgP3beeecp0r2R2jHh1KmhFRpr8UmXVDm5MPC6DVhTD+uRYNCS3WL2+rxp+/vtt5+CdQWl+1ZFupN0mxzjhUiABhB/MPGto93H2vbZZ5/1J8ziqlmzZikYJKbJjnKLQRmFQhD9mbSyFCprls7TAJKl2gynLGvWrFHQH2KlU0466SRl/vz5eQlhAKt0G6+gTWC8BnrQrfEDQoPokbxM+fwRxjci9ZlP+LwsFQRoAJHVhNEsWNTIqPCi+I0ZGlhICO6qogjSV3JeOSpXruxo6faTB7j16tmzpyL96aojf90YfPQs8eGIhSNx/QUXXOA5C/jIxbV6mfp9lNlvwEg94ywavWy3+1g4CzNVMJIg6mCceSPXK4k6ScovAQI0gCS3ksNo3NqVTvrWt3TbKH1mK506dVLkehOKXNNJkb5xLXUw3qcc8WxHN/zjNICEzzSoxGK2kdDJbtcmK1u2rOOAjGK3f+R6NnnGCbftLS0e2vFOs9W0eu3Vq1dOd82ePVs77LhFBxv4aWkF2Xbv3t0xLZ4MnwANIN6ZYnCX8T6X6yy6covpJjUsWAx5xjSi+u3GcONXfyaxLG7qICtxaADJSk2GWw70tXXu3NlWx2DQcpcuXdQBunL9GMt43bp1U77//ntPGfOrRzwl4hA56Dci9ZkDXJ7KBAEaQGQ1Nm/e3FLp+W2EoQO+du3ailxwSWnTpo0CN1RwmwRrdNTh559/Vvbcc8+88ixcuDDUZOXi4Hny/XLCdXIxKl95g/uBK6+8UmnYsKH6USoXw1Q75SATbsUwUs9vgJW/ffv2tp0IduWFX1y8aPGhZeciy2+e7K5DOY31jRkyDCQQlAANIEEJRne9XeN23rx5oSUKv692o6esdCBmgkD3McRLgAaQeHm7Sa3YbaRhw4aZ2mgwisA9Z6FQzPbPI488Ysq3la6xO7b33nu7anuhcwIyjjzyyEI4cucbNGgQKG/6PGM2NEO8BGgA8c7bSo/87W9/8y7I5gqMMNY/F1Huwx0OPD24CVblLqQ/k1oWN+XNQhwaQLJQi9GV4eWXX1bgjtKLjsF6H8ZZIl5y6EePeJHvFDfoNyL1mRNdnssCgTIohFQIDBkicNlll4kHH3wwVyL8lguH535ndUeOVhLyg1ZI92JCjgIUssEaqKhyJLQYN26ckC4PhFxoXP3buHGjkL6nhfSVr/7VqFFDyJeqaNWqlZDuGtS0AyXq8WLpskLI6Zq5q6QvfiGncOZ+c4cE/BKQs7SEdHEkdt99dyHddvgVw+tSTADNgyVLlojJkyerekWOphLr168XciS0qv+kuyyB+wQ6qEWLFoF1bopRFS3rctS8kLNy1PTnzp0r5MycouWFCSeHwPTp04Uc8Sw2bNgg5ChHIWdmiRNOOMF1BtPQ/nFdGEYseQLSTZqQM8dVDqNH
jxZypnTJMykEAN8Sjz/+eF40OVpayMExecey+COo/swikySXSXZWi8WLFws5il/gO52BBKwIyAEeYurUqUK6LVf7dPA9s3nzZiEHzgp8z9SsWVPIgcsCeu7AAw+0EuHpGPWIJ1yMTAKxEaABJDbU8SUk15xQO+O1FOXIF4GOKxgGshzkCC8hXbQIGCXk2ipZLmqubL179xbS7Vnu98iRI4Vc4Cv3mzsk4JcADSB+yfE6EoiPAA0g8bFmSiRAAukkQANIOuuNuSYBNwRoAHFDiXFIgARIgARAgAaQjN4H0q2XkH6oc6WbM2eOkIs+5X5ncWfMmDHqqC7pekxgNkjWw9atW4VcXFNIv91qUaWPftXQhdkpDCQQlAANIEEJ8noSiJ4ADSDRM2YKJEAC6SZAA0i664+5JwEnAjSAONHhORIgARIgAT0BGkD0NDK0//DDD4t+/frlSgT3B5MmTcr9zuIOyvj000+L888/X4wfPz6LRcwrk/TRL+Qi8rljcqEu1VVN7gB3SCAAARpAAsDjpSQQEwEaQGICzWRIgARSS4AGkNRWHTNOAgUJ0ABSEBEjkAAJkAAJ/I8ADSAZvRUwOwB+n+HfEAHur7CWhVwsPJMllou/q66vMBti7Nixok+fPpksp1Yo+ObHTJdVq1aph8qVKydWrFghGjVqpEXhlgQCEaABJBA+XkwCsRCgASQWzEyEBEggxQRoAElx5THrJFCAAA0gBQDxNAmQAAmQQI4ADSA5FNnbwdoQWCNCC/37989bHF07noUtFvs866yzBAwBa9euFbVq1cpCsWzLMGXKFHHmmWfmzg8cOFDcf//9ud/cIYGgBGgACUqQ15NA9ARoAImeMVMgARJINwEaQNJdf8w9CTgRoAHEiQ7PkQAJkAAJ6AnQAKKnkcH9Vq1aiYULF6olq1ChgmocwKLoWQqYDdGyZUuxePFicd555wm4hspyQHmbNm2qzvhAOatXry7WrFmjbrNcbpYtXgI0gMTLm6mRgB8CNID4ocZrSIAESokADSClVNssa6kRoAGk1Gqc5SUBEiAB/wRoAPHPLhVXwi1Ss2bNxM6dO9X89u3bV2B9kCwFrG3Sq1cvUaZMGfHee+9l3g2UcfYHZn5gBggDCYRJgAaQMGlSFglEQ4AGkGi4UioJkEB2CNAAkp26ZElIwEiABhAjEf4mARIgARKwI0ADiB2ZDB0fNGiQeOCBB9QSwUiwYMECgc7NLISffvpJNG7cWHzxxReic+fOYtq0aVkolm0ZtmzZohp41q1bp8bBmh8wcsH1FwMJhEmABpAwaVIWCURDgAaQaLhSKgmQQHYI0ACSnbpkSUjASIAGECMR/iYBEiABErAjQAOIHZkMHd+8ebNo2LCh2LBhg1oq7C9fvlyUL18+1aX8448/RJcuXcSLL76oLvIOF1jNmzdPdZkKZf7iiy9WF3lHvLJly4o5c+aINm3aFLqM50nAMwEaQDwj4wUkEDsBGkBiR84ESYAEUkaABpCUVRizSwIeCNAA4gEWo5IACZBAiROgAaREboBFixaJk046Sfz6669qiW+88UYxfPjwVJf+6quvFnfffbdahqFDh4qbb7451eUplPnZs2eLdu3a5aINGzZM3HTTTbnf3CGBMAnQABImTcoigWgI0AASDVdKJQESyA4BGkCyU5csCQkYCdAAYiTC3yRAAiRAAnYEaACxI5PB4+PGjRN9+vRRS4bZA3AX1alTp1SWdObMmeK0005T83766aeLqVOnqrNAUlkYF5mGy6sWLVqIr776So3dtm1b8fLLL6uzQFxczigk4JkADSCekfECEoidAA0gsSNngiRAAikjQANIyiqM2SUBDwRoAPEAi1FJgARIoMQJ0ABSYjfAlVdeKUaNGqWWukqVKmLhwoXiqKOOSh0FbTbEGWecISZOnCgqVqyYujK4zfC2bdtEq1atxLJly9RLDj/8cIEZPXvssYdbEYxHAp4J0ADiGRkvIIHYCdAAEjtyJkgCJJAyAjSApKzCmF0S8ECABhAPsBiVBEiABEqcAA0gJXYD7Ny5U2DGxIwZM9SS165dWyxZskTUrFkzdSSwAHrVqlVTl28vGcY6J127dhUvvPCCetkBBxygGj+wZSCBKAnQABIlXcomgXAI0AASDkdKIQESyC4BGkCyW7csGQnQAMJ7gARIgARIwC0BGkDckspQvF9++UVg5oRmBDnyyCPFvHnzRLVq1TJUymwUZeDAgWL06NFqYerUqaPWU926dbNROJYi0QRoAEl09TBzJKASoAGENwIJkAAJOBOgAcSZD8+SQJoJ0ACS5tpj3kmABEggXgI0gMTLOzGpYTH0Sy65REyYMEHNE1xjoSOFITkEJk2aJHr16qVmqGnTpuoskFq1aiUng8xJpgnQAJLp6mXhMkKABpCMVCSLQQIkEBkBGkAiQ0vBJFB0AjSAFL0KmAESIAESSA0BGkBSU1XRZPShhx4St99+u7o4+vDhw6NJhFJ9EZg7d67o2bOnOPvss8Wdd94pKlWq5EsOLyIBPwRoAPFDjdeQQLwEaACJlzdTIwESSB8BGkDSV2fMMQm4JUADiFtSjEcCJEACJEADCO8BEiABEiABEwEaQExIeIAEEkeABpDEVQkzRAIkkDACNIAkrEKYHRIIkQANICHCpCgSIAESyDgBGkAyXsEsHgmQAAn4IUADiB9qvIYE4iVAA0i8vJkaCZBA+gjQAJK+OmOOScAtARpA3JJiPBIgARIgARpAeA+QAAmQAAmYCNAAYkLCAySQOAI0gCSuSpghEiCBhBGgASRhFcLskECIBGgACREmRZEACZBAxgnQAJLxCmbxSIAESMAPARpA/FDjNSQQLwEaQOLlzdRIgATSR4AGkPTVGXNMAm4J0ADilhTjkQAJkAAJ0ADCe4AESIAESMBEgAYQExIeIIHEEaABJHFVwgyRAAkkjAANIAmrEGaHBEIkQANIiDApigRIgAQyToAGkIxXMItHAiRAAn4I0ADihxqvIYF4CdAAEi9vpkYCJJA+AjSApK/OmGMScEuABhC3pBiPBEiABEiABhDeAyRAAiRAAiYCXbt2FStXrhSVK1cW7777ruk8D5AACRSfwAMPPCBGjhypZmT8+PGiVatWxc8Uc0ACJEACCSIwa9Ys0a9fPzVH119/vbjooosSlDtmhQRIIAiBzp07i/fee0/sscce4p133gkiiteSAAmQAAlknAANIBmvYBaPBEiABPwQ4AwQP9R4DQnES4AzQOLlzdRIgATSR4AzQNJXZ8wxCbglwBkgbkkxHgmQAAmQAA0gvAdIgARIgARMBGgAMSHhARJIHAEaQBJXJcwQCZBAwgjQAJKwCmF2SCBEAjSAhAiTokiABEgg4wRoAMl4BbN4JEACJOCHAA0gfqjxGhKIlwANIPHyZmokQALpI0ADSPrqjDkmAbcEaABxS4rxSIAESIAEaADhPUACJEACJGAiQAOICQkPkEDiCNAAkrgqYYZIgAQSRoAGkIRVCLNDAiESoAEkRJgURQIkQAIZJ0ADSMYrmMUjARIgAT8EaAD
xQ43XkEC8BGgAiZc3UyMBEkgfARpA0ldnzDEJuCVAA4hbUoxHAiRAAiRAA0iJ3wM7duwQL7zwgmjRooU44IADSpwGi19MAps3bxb4SD355JPFvvvuW8ysMG1JgAYQ3gYkkHwCNIAkv46YQxIggeISoAGkuPyZOglESYAGkCjpUjYJkAAJZIsADSDZqk9PpVmzZo3o1q2bePfdd0X37t3FM8884+l6RiaBMAk88MADYtCgQaJatWpi/PjxonPnzmGKpyyPBGgA8QiM0UmgCARoACkCdCZJAiSQKgI0gKSquphZEvBEgAYQT7gYmQRIgARKmgANICVa/YsWLVI7mH/44QdRpkwZ8eSTT4oePXqUKA0WOwkEPvroI/GXv/xFbNq0SZQtW1aMGjVKXH755UnIWknmgQaQkqx2FjplBGgASVmFMbskQAKxE6ABJHbkTJAEYiNAA0hsqJkQCZAACaSeAA0gqa9C7wVYsGCB6Nixo9i6dat68ciRI8XgwYO9C+IVJBAygWXLlokTTjhB/Pzzz6rkO++8UwwZMiTkVCjODQEaQNxQYhwSKC4BGkCKy5+pkwAJJJ8ADSDJryPmkAT8EqABxC85XkcCJEACpUeABpASq/N33nlHtGrVKmf8GDBggBg9enSJUWBxk0xg4sSJ4txzz81lccKECXm/cye4EykBGkAixUvhJBAKARpAQsFIISRAAhkmQANIhiuXRSt5AjSAlPwtEAuAH3/8Ufz2229ir732iiU9JkICJBANARpAouGaSKmff/65QCNh/fr1av46dOggpk+fLnbZZZdQ8rtu3TqBNL799luxbds2UaNGDXHggQeqi6tXrFgxlDSSLuTXX38VL774omjUqJFo0KBB0rMbOH9Rlbd///7ioYceUvO32267ifnz54sWLVoEzi8FuCdAA4h7VsWK+dlnnwm4M9QCZk/VqVNH+xna9o8//hBwUQfd/t1334nff/9d1e2HHHKI2GeffUJLh4K8E6ABxDuzJF0R1Ts06jKyvRc1YcoPkwANIP5pbt++XX3/f//992Ljxo2q2+SDDjpI1KtXL9aOwC+++EJA7+APnZDIQ/369UX16tX9F65IV2apLEVCmJcsDSB5OPgjAgI//fSTaNeunap/8PwmISRFN7tlsXPnTrFy5Urx/vvvi7Vr1+b9YeDr0KFD3YpiPBIIRIAGkED40nPxjh07xF//+lfx9ttvq5lGoxHuhqpUqRKoEOiAw+LVM2bMUBWanTB0zPXp00ddbL1y5cp20VJ7HI3xxx57TNx+++0CL8arr75ajBgxIrXlKZTxqMuL+xUd8G+99Zaalf33318sXbpU1KpVq1DWeD4kAjSAhAQyIjF4Ro499li1Makl8dRTT4mzzz5b+xl4O3PmTPGf//xHvPLKK2rHh5XAatWqifPOO09drwcdIgzxEqABJF7eYaUW9Ts0rHzq5UTd3sPHMWZ8/vLLL/pkC+6jXdC5c+eC8cKOsGLFCvH66697FouBQXBDyxAfARpAvLGGG1p822GQHNreMNRahd133100adJEDBw4UHTr1k2UK1fOKprvY3i+Jk2aJKZNmya+/PJLSzlog2D9wCuvvFK0bdvWMk4SDmapLEngqc8DDSB6Gty3IjBv3jzxwQcfWJ1yPIYBX6tXrxb4HkKnfZcuXcTzzz/veE2UJ5Oim92WEZ5nZs2aJeB+/7XXXhOYRWMMWPcV/WZ33HGH8RR/k0A0BBQGVwTkaHTl6KOPDvzXrFkz5eSTT1a6d++uXHrppcqwYcMUaTxQpGXZVT78RrrssssUeQepf3LRc2XhwoV+RanXycXTFWnQUGRjNydXk++0ldMGFfniCJR2ki6WnRjKo48+qsiRSHkcpCJPUjZDy0uc5ZWdLYr8sMlxbd68uSJHO4RWFgpyJiANpip7+YHrHJFni0Lg73//e+7Z0HSuNICEkhdpKFekAcwkX0vHaisbsEqPHj2UTZs2hZIHCnFH4J///GeunubOnevuIsYqGoE436FhFTKu9t4zzzyTu5etdIzdMTlLVJEzm8Mqrms5cmCPr/zK2deu02DEcAj897//zdWVdPsbjtAMSpGdfcrYsWOVmjVr5njZPXfG43L2qTJq1CgFMoIG2VGpSKOm5zw0bdpUmTJlStDkQ70+S2UJFUyIwqQBTL1X5GygEKVSVFYIoO9AeiXxrE+MOg6/Z8+eXRQsSdHNbgovB7EocvCccswxx9gyP/LII5Vrr71WkYMTlM2bN7sRyzgkEBoBEZqkjAu6++67FekqyvZBtlKSXo7tuuuuSsuWLRW5/oEi3Y2ESlNaXvPyLUfqBJIvR7EocgRbnkwvZUXca665JlAein0xOjHkjA9Fjni25JA1A0ixyivdiSkw2Gn3l5weWeyqL5n0aQBJblWjAa5/LrTnIwwDCPRahQoVcs+cJluOsFOee+45Zc2aNcqHH36oIK3GjRub4snZhcp7772XXHgZyxkNIOmo0GK9Q4PSibO99+677ypXXXWV0rt3b/XD2csAG3xIxxlgJNZ0o5ttpUqVFAyA6tmzpzJ+/Pg4s8q0JAEaQArfBq+++qrlO71hw4ZqRxW+T6U7WvU7Fd9wuJ+t7v327dsH6tCSs8AUfBNbyXZ77P/+7/9CMcQUpuYcI0tlcS5pcc/SAFJc/klPHYOY3eoOp3jSvXlRipoU3Vyo8Gjn4ptk7733tuQtPcAoF198sfLmm28WEsXzJBApARpAPOCF9VWucaF2Ap111lmWD7eV4sQDj5ExmP1gdd54TE4nVmC0CCNAGR1++OG5dDFTYevWrb5FyymACkbbGfPs5/e///1v3/ko1oW4B2DVlr7vHRlkxQCShPJedNFFOdbomP3kk0+KVf0llS4NIMmsbumDW5EuX3LPhF73BjWA3HfffZZy77//fkvDPN4v6AjR5wH70rWiIqc7JxNgxnJFA0iyKzQJ71C/hIrd3sMsULkemEm/GPUNfmO2aNQzqfUcpb9qV/lCRy6MOnL9BP3l3I+ZAA0gzsCtnvWqVauqnVnS/ZXlxdJlnYI2A75xjc8kRv5KN52W1zkdvOWWW0yyjLLd/pZu5hS0UYoVslSWYjF0my4NIG5JlV486CkMzHKrN5zijRkzJnaASdHNhQou16NU0H9pxQ/tIHgt4EyPQhR5Pi4CNIAEIN23b1/LB117+OXC3wqm9OsbYPhAw+hYKFGnKcaYbSL9ngbI3f+/VPpwzcvjk08+6VumXI9BQZm08mlbuAGAiy1MfUZjuF+/fupsFqsRyto12MKQAoNSGgJeoBj9dOihh5rKry+Ttp92A0iSyovp4/p76fTTT0/DLZP6PNIAkswqhPtETc8Yt0EMIBjpbTXq8tZbby0I4sILLzTlCdPNpc/ugtcyQjACNIAE4xfV1Ul6h/opY5Lae3jnG3Wd1W+0QeMIcLdVvnx5V3kaMGBAHFliGgUI0ABiD2jVqlWmWZ/4BkWHlpuAb1qrgXFXXHGFm8tzcR588EFXz5TVs293LO6ZYVphslQWrUxJ3tIAkuTaKW7e4BLPTj94Ob
7nnnsGGkDsh0JSdLNT3uFeDC799f00eq6nnHKKIhc9dxLBcyQQOwEaQAIgHzdunKNSxcg1p7Bt2zblhhtusJWBBujUqVOdRDieg99m+MPUFNFhhx2m4KPcT5CLLikNGjTIyYJM+DF2msaGEcDGa7S8aNtiNU69MEAd6MuBerFyEaOVCds0G0CSWF656FjevYePWYZoCdAAEi1fP9LhOkWvZ4z7fg0g3333nVK7dm2TbKxX5SZghOjBBx9sul4u0q7AFyxDdARoAImOrV/JSXyHeilL0tp7MM4adZ3Vb8y01g848lJmL3HhitMqfatjn376qRfRjBsRARpArMHC5bLW1tPfv8OHD7e+wOYo4uuvxz7WBZOLBttckX9YLo5ruaZkmzZtFLm4uTpo8PHHH1fw7HXt2tXW/bBVHubMmZOfWMS/slSWiFGFJp4GkNBQZk5QixYtTLrJqCfc/IYeijMkRTc7lXnDhg2K9uwZGUL/4/uEgQSSSIAGkAC1cu+99zoq1VdeecWVdKfRbVhrw276cSHhgwYNyssfGo9+wx133JGTBQOAWx/G6PxCY9WoGLXf+++/v98sxXYdOvEwMvrMM89U/d5jCh9eTIMHD7YtV5oNIEks78KFC/NYwwUZO1ajfQS0j2Iugh4tZ7fS4foNLik03Wm19WsAsXsHLV682G321HeCVZ64bo9rhL4i0gDiC1ukFyXxHeqlwElr78F9lJVusTr2xBNPeCmq57gY7bjPPvu4yg9miTAkgwANINb1gAVojc8R1qvxakhEe9xqbcjrr7/eOmHdUQzMw4K4+nxg0BlcvzgFGJqdPClo8tABGlfIUlniYhZGOlonLBdBD4NmdmRgzSJND2CLQV1Dhgzx/Id1gON0sYkaSIJudroTVq5cqWDQiZ6vtr/HHnsoM2bMcLqc50igqARoAAmA38kFFhqQbjto4SYEPtM1xWHcYlFar2H16tV5o2nQYQxf1H7Cjz/+mFu/BIYAuPXyEr799lvHD8YtW7Z4ERd7XBiyYOU2Bvjit5vyl2YDSFLLqzVwtecDBkiG6AjQABIdW6+Sobu1+tDuf6utHwOI3WK+p556qqdsIo/16tUzvcfwbsM7gCEaAjSARMM1iNSkvkPdlCmJ7T10KmI0oZXOMx5r2rSpm2L6jjN27FhX+UC+0jDAxzeIlF1IA4h1hfXu3dt0P7ds2dI6coGjPXr0MMk67rjjClylKOMM3hROO+001wP/MCANi+oa9YDxNzrr4ghZKkscvMJKQ/s+pAEkLKLZkNOpU6ecboCbvq+//jo1BUuCbraDhVl1dgPy8AzCLSIDCSSZAA0gAWpHv7i4sbGFxde8BPjIM8rQfp900kleRKlxsQ6Hdj22I0aM8CxDu0C/jggWAPcTJkyYkJcffd6WL1/uR2QirtG7GNOXKc0GECewxSzv5MmT8+4hjBBjiI6A1uHOGSDRMXYrWb+YJozrRpdwmu7xYwDp1atX3nOlyXI7y09fBqsF0SEv7qnj+jxlfZ8GkHTVcDHfoW5IJbW9B//bmm4qtHU7+9oND2Mcu0U+rfLUqFEj4+X8XSQCNIBYg7eaQXH++edbRy5wVN9O0Z6H/fbbr8BVioIZc1p8uNKE+2avATP0NRlWW8zYjyNkqSxx8AorDRpAwiKZHTnohNcPUsVatWkKSdDNVrywLglmeFjpWbiHh8cOBhJIOgEaQHzWEKbk6xWrURH861//8iQZitkoQ/uNKWZeAqbo6z8WMXLuq6++8iIiL642NRmurPyGb775xrZ8Xtys+E0/qusOOOAAy3Jl1QBSzPJiFChmMmnPBbbwtcsQDQEaQKLh6lXqkiVL8mbzYXFN6Bf9c6DtezWAfP7553myNTnlypVTMMPNa4Au12Tot2gU4x3AED4BGkDCZxqlxGK+Q92UK6ntvb322stSt+j1jLaPAUVRBIx61NLAFm1zO/eBOH/EEUdEkQ3K9EGABhAzNLhXtppZ5ff5wQA5/fOBfXwnOwV0UmrXYHCH35kamKVfrVq1nCxNprZt27atUzZCOZelsoQCJEYhNIDECDslSV1wwQU5fQB3lPC2kpaQBN1sxQqz+evWrZvjqulXbKHrMVCVgQTSQIAGEJ+1NGnSJEsFoCmDjz76yJPkkSNH2srDmhte1gGBD2QtH9i6XcjWKsMrVqxQZeHj08oNlNU1dscwEkifL23fz2gfuzTiPp70zoyweRS7vMbnxO9ItbC5ZFEeDSDFr9WtW7cq9evXz+nNDh06qJkKywDy97//PSdb08fYnnjiib4Kj7WR7PQ8DDcM4ROgASR8plFKLPY71KlsSW7v6Q0gmIXhNAAJOiyKmcVGY8c999yj/O1vf7PUocgDDSBOd1u852gAMfNev3695b2Lkcd+grF9jmdg3333dRR100035fIwbNgwx7iFThrdTyF97Q+Dp6IOWSpL1KzClk8DSNhE0y0Pg35h9NCef3hFSVNIgm428oJbf7hH1JgatwMHDjRewt8kkFgCNID4rBq9ZdmoBOAH3WvAh5RRjv43lKHbAIOH/lo/rky0tOBf9eGHH1ZeffVV7ZDv7aGHHpqXL+SxVq1avuUl4cIkd2ZEwafY5dU6aLT7GyPGkr6GTBT1EIdMGkDioOycxiWXXJLTmXvvvbeivQfCMoDYGSuwSKDf0Llz51yetecUWz+uHP3moZSuowEkXbVd7HeoE60kt/f0BhCsNWCnZzSdA//ZYYY1a9bkjZaHa0i0PWgACZNydLJoADGzxQAL7XkxbgstQG6WpijoADPKKbSeyPHHH69eg/YN1h8KEnbs2GE5oxV5Qmdo1CFLZYmaVdjyaQAJm2i65f3jH//I00WzZ89OVYGSoJuNwK6//vo8pnpdX7t27cD625gef5NAlARoAPFJF4sb6h9+/X7//v09S4WPdL0M477bTt61a9eaRsYFcX/luSA2F/z8888KZrIYy+V1rRQb8UU7nOTOjCigFLu8cIOld++G+4kjy6OoaSW36DbXAImGbyGpU6dOzdOX+K2FMAwgq1evzpOv182Y4eg3YBSnXpa2D/3PxdD9UrW/jgYQezZJPFPsd2gcTKJo7xkNIAsWLLDUM5q+gRu/L774IrTiDho0KC89bV0jGkBCQxypIBpArPHa+XLH+jW///679UUWR9E2x+A/7fnTtk6ztPFdi+cUcW+88UYLqd4P2a3NicFSUYYslSVKTlHJpgEkKrLpk4uBHPhu1XSQtsX6a6eeeqoydOhQ5YUXXkj84Mli6mZjrX/yyScKFpHXWBq306ZNM17C3ySQaAI0gPioHuMo9DAUgXHWhl4mXKC4DXfccUeegvI7ldltem7jvf7663n50so3a9YstyISGa8UOjP04JNQXhjNtPsH2zZt2uizyP2QCHAGSEggfYjBTA+MiNTu8z59+uRJCcMAol/sWEtH277//vt56Xn58dxzz+XyrcnTtmPHjvUiinFdEKABxAWkBEVJwjs0ahxRtPeMBhCUoUWLFra6BjonrIWP0aFSpUqVXFrotMVgIwQaQFQMif9HA4h1F
dkZDPD84N3iNjz//PO550N732PrtE4fXEVjHR10rAV1sazlE7PD9Olr+zDoRBmyVJYoOUUlmwaQqMimT+5dd91lqQM0XaBtMSvstNNOUx599NFErlFYTN1srHWsAaxxM26PO+44Y3T+JoHEE6ABxEcVOSnXXXfdVfnpp588ScXCTFYL0WlKxstHHBav067DNsjC5Z4KUSDyrbfempcv5K1bt24Frkr+6VLozNDXQhLKe+edd+bdSxUrVlQw9Z0hXAI0gITL0600rKPRvn373D2OUZXGd0oYBhDoX/27QtvHRwFGc/oN6AjQZBm3fmZH+s1HqVxHA0i6ajoJ79CoiUXR3rMygGDBTaOO0f+uWrWqsmnTpsDFNbqoPeuss3IyaQDJoUj0Dg0g1tWD0dD6Z0a/j+9SrKtRKGCNymOOOcYkB+0YNwEzxsIK1157rSkfKFNc3gayVJaw6iQOOTSAxEE5+WmgLwADf/V6zM0+dB3e65gZn5SQBN0MFq+88oojzyAeA5LCmrATl7wAAEAASURBVPkoPQI0gPioc6fZGn4Wj8XUXycFPWfOHFe5xHRl/Sg1yLz99ttdXRtlpE8//VSpXLlyXhmxOKSxUy/KPEQluxQ6M/TsklDeRYsW5d1LuM9xjCFcAjSAhMvTrbT7778/d3/DbZTVvR3UAAIji75DUf/+OfDAA91m1TLetm3bcvnXy8V+IX/glgJ50JEADSCOeBJ3Mgnv0CihRNXe0+srjPJGQJv34IMPttU30DmYFR0kII2DDjooL40lS5bkRNIAkkOR6B0aQKyr55tvvlEqVKiQd3/r39voGJwwYYL1xf87OmDAANP1+OZbuXKl43VRnLz44otNeUF50rhAb5bKEkVd62XSAKKnUbr7jz32mOXzr9dpTvvQd1g/7Ouvvy46xKToZsySsWMGYxMM4AwkkDYCNIB4rDEsTOTkBw+j072Et956S12czU65nHHGGQo6q9yEpUuXmpRUsRd+Qt5hFNKX79hjj1W+//57N0VKfJysd2YYKyAJ5cUID+MHW9BODmM5+ZtrgBTjHli1alXevY1F56xCUAPIxx9/nKeTjfrZKk0vx6z87yINjMh2+z7zkl4px6UBJF21n4R3aFTEomzvWRlAUI7Ro0fb6jLoHKzXF2SG6JQpU/LkY2CAPtAAoqeR3H0aQOzr5oorrsi7x/XtAeyjUxCeD6ze3VauNDGLdMaMGfYJRnjG+L2pleXxxx+PMNVoRGepLNEQ+lMqDSB/sijlPQyu1Z75IFu0G954442ioyy2boY7Zqv1ezW21113XY4RPAdgiQCsy4q1IJH3v//978qIESOU8ePHK/juZCCBpBCgAcRjTegb0ZoC0G+XL1/uWiIWaDzkkENslXXDhg2VH3/80bU84zT9MmXKFH2RJ6O7IozcgxEpKyHLnRlWdZSU8rZq1SrvuenQoYNVdnksAAHOAAkAz8el6KQ76qijcvd1s2bNbEfWBDWAOC0g3KVLFx+5z7/k0EMPzZVD/37EPkaIM4RHgAaQ8FjGISkp79Aoyhple8/OAIIZZ/r1koz6Br+DrD10wgkn5OkyGET0gQYQPY3k7uu/3WA0Y/iTwPbt2/PaHlbPEI7BxbK2Vsdvv/2mdnAZ4+6zzz6O6378mWo0e1b6FYsJh+maKpqcm6VmqSzm0oV7hAaQcHmmVdp+++2X97426icvvzHYeeLEiUVFUWzdrP++sGL36quvqt5c7r33XqVu3bqO7NEn2bp1a9UYgnYbAwkUkwANIB7pDxo0yPYBh8XYbXjppZdsXZBAyaADac2aNW7FqfGMixTBgFKsAEuwsZMOi7lv2bKlWFmKJF2rBirqD2XPYkhKeTEyXv8yxgdOkHULslhXQctEA0hQgt6u/8c//pG7p7GujdNC5Ebdqj0LTz31lKtEn3766Vxa2rXa9tJLL3UlwykSGrmaPOPWrUtHJ/k89ycB/QfK3Llz/zzBvUQSSMo7NEw4cbT37AwgKEchN7JYANlq9HohBm+//XaeHoO7LWM7gwaQQhSTcZ4GEOd6wFqUcH9pfF8bf++7777KOLkuiNXshHbt2ikY2FesAJcxmK1izHMa1x7LUlniuB9oAImDcvLTwKDhd999V5k1a5bquu+WW25RTj31VJN7eKOOsPtdrlw5ZebMmUUteDF1c5MmTUz6VGOFmX6YAYh3gnbM7RbeADgQoai3VcknTgOIx1vAaWTr+eef7ygNfvIwLdi4ULleYcBCCjleZn5oiRrzVig/2nVhb+HG5bjjjrNUiOiohrsiWLWzELLYmeFUL0kpLxbd0j832EcjgSE8AjSAhMeykCR0XOs/3NGodApBDSAYrWN8frTfV111lVPSrs516tTJVv7UqVNdyWAkdwRoAHHHKSmxkvIODYtHXO09JwPIt99+q8BorOkwq+0LL7zgucjnnntunsx//etfJhk0gJiQJPIADSCFq+XDDz9U/IygxrP5yCOPFE4g4hj6d6FeB8CQmbaQpbLEwZ4GkDgopzcNzFhbvHixcvPNNyu1a9fOe6/rdYXVPlz6wqhSzFAM3QxjthWPMI/deuutxcTKtEuYAA0gHiofrjucHnyM4IXPwPnz5ysvv/yyAp+j8JsKP3hw2eP0gYZzZ599tuo/z0OWclExKs24NgkaUHGG9957T8EHIyzmTpxwrmnTpspXX30VZ/YiSStrnRmFICWlvBjdYbzHFi5cWCj7PO+BAA0gHmAFiPrDDz8o+ueqffv2BaUFNYAMGTLE9Pxoz5PduiMFM6WL0K1bN1v58AXLEB4BfUcJZ4CExzUqSfpnXXvmsE3brNG423tOBhDUVb9+/Wx1DviiDe4lwPc1RjhqdVStWjVL9600gHihWry4NIC4Y//RRx/ltUe0+99u26BBA2XTpk3uhEccC14PjPnErJQ0hiyVJQ7+NIDEQTkbacAYgtnydgN1jToEvzHAuNiLfcetmzFYzYqF8RgMRGh/3XfffQoGmmCg94svvqgMHz7c0c2/JieMQXfZuDNZijgJ0ADigfZDDz3kShloD3WhLaaN9erVS52m52fGhz7rMCYY0wvi91gv22kffvwmTJigTofG7BVjHpx+16xZU3nnnXecxCf+XFY6M9yCTkp533rrLdO9hvuQITwCNICEx9JJEgzfmp5EJ9+6deucoqvnghpAzjvvvFyaWtraFiOkgga81zR5xq3VKOqg6ZXy9TSApKv2k/IO9UOtmO29QgYQdA7oZ9EZ9Q5+L1myxHWxhw4dmqfDYDS2CjSAWFFJ3jEaQNzXydq1a5XDDjss7/63ep5wDM8cXNAZXcO5Ty2cmK+//ropv5UqVVI++eSTcBKIUUqWyhIXNhpA4iKdrXTQJsC6h3b6TX88Cd8ucermG264wZELDB/4XixkAH/44YeVypUrO8qC8YSBBOIkQAOIB9pOHzp6Jel2H7Mg4MrHj29iY7Yx+t2Y7uTJk43RQv+NdU+M6Xr5DSMI3BekNaS5M8MP86SU12o2Vhgdt36YZPUaGkCir1nMEtTrS+MCu3Y5CGoAcZqh
gVmLQcOFF16YVy59GeGTlyE8AjSAhMcyDklJeYf6KWsx23uFDCAozxlnnGGrd6CDzjzzTFfFhotWLOas6a1dd93Vdsay03fBEUcc4So9RoqeAA0g3hh///33itaprD0HTlvMXN28ebO3REKMbdXmuPvuu0NMIT5RWSpLXNS0e7V69epxJcl0MkQAuqLQAAq0QYIOVg4DWVy6uUOHDrk2kFH3V6hQQfV447Y8y5cvV2AwMcrRfteqVUvZsWOHW3GMRwKBCdAA4hIhpsw5PbxY9PWCCy5QevfurWAx8jZt2ihYeNHJ7ZX24B955JFKUPc9GP2uydO2s2fPdlk6/9Hg9gsfnZ07d1ZngaAsVapUMeVFy5PV9vTTT/efgSJfmebODD/oklJefGgZ7yV8NDCER4AGkPBYWkn67LPP8t4peH+4DUENID169DA9P9rzNGrUKLfZsI3Xt29fW/lhuNiyTbgET9AAkq5KT8o71A+1Yrb33BhAFi1aZKt3oN/QwfHxxx8XLDpmT2v6ENtzzjnH9hoaQGzRJOoEDSDeqwODivTPQaF9fPOiXRN3QOfaLrvskpfXZs2aKb///nvcWQmcXpbKEhiGBwE0gHiAxaiWBF566SUF69Q66bmnn37a8tq4D8ahmzFI244FvvG8BrjGspOH40lYR8prmRg/vQRoAHFZdwsWLHB8cFeuXGkpCbM73n//fdUXnpMxBI23IKNVRo4cacrf0qVLLfMU9UGUGY24wYMHm9YlsVN+xcprUBZp7szwU/aklBf3mPGDp2PHjn6KxGtsCNAAYgMmhMP4MD/++ONzOrtu3bqeRhYFNYD0798/l7ZRJwd5D2lorEYwaunAXQZDeARoAAmPZRySkvIODauscbX33BhAUCbtvaXpG+N2wIABBYvepEmTPP3otIgyDSAFcSYiAg0g7qsBru6cBkkYnyn97xo1asTq2hj6p2XLlnnPK9ytrFixwn2BExIzS2WJGykNIHETz2Z66K+DDtPrNP1+sQdaxqmb8V2qL7t+3+8Aa/13r14e9g8++OBUGq2z+SRkv1Q0gLisY4xaNT6s2m9MlUfDpVDAyDOMStGus9r69YN36623muSuWbOmUJYiPw9/hVgYvdD6ID179ow8L1EkkLXOjEKMklTevffeO++eP/HEEwtln+c9ENA6kjDzjSFcArfddlvu3sWo5Ndee81TAkENIE6+XcNwUeW0Bsi9997rqayM7EyABhBnPkk7m6R3aNhsomzvuTWAPP/88zndatXGxkCk7777zrboc+bMybses7mdAg0gTnSSc44GEHd1AfeyRgMgOqaeeOIJpU6dOnnPhtXzhWNwQ+RkNHSXE3ex/v3vf+flCe2padOmubs4YbGyVJa40dIAEjfx7Kb36KOP5ukUvZ6rXbt20Qoet27ec889bTnADZef8Mwzz9jKBGfMEmEggTgI0ADiknLz5s1tH1osYus2fP3118p+++1nKwsfZ5/5mEJ83XXXmWQmaW2NMWPGmPKnf6lg2mEapytnuTPD6p5OUnkPPfTQvHuqRYsWVlnmMZ8EaADxCa7AZZjtBp/ymv6D7vYaghpArGYMavmBcSRocFpj5Nlnnw0qntfrCNAAooORgt0kvUOjwhVFe8+tAQSLMRdawHn48OG2RYdLVk0XYjt9+nTbuDhBA4gjnsScpAGkcFW89957inGdn8aNGyvr1q1TL8Zit4XW2dGeHRhBPvroo8KJBoixceNGxTgQKq0DLLJUlgBV6vtSGkB8o+OFBgJoQxiNwJpew2DeYvRVFUM3262JUr58eQMx9z+xnACMSBpP4zat+ts9AcZMCgEaQFzUBAwJTjMYvPqt0zfEjQ8/fnfq1MlFrvKjXHHFFSaF8uuvv+ZHKvKvESNGmPKoL/9bb71V5Bx6T74UOjP0VJJUXq3Bq91DXHBUX1PB92kACc7QKAHTl/Wdc0cffbTiR08HNYCMHz/eVhfDz3/QgHeY9lwat6+//npQ8bxeR4AGEB2MFOwm6R0aJa6w23tuDSAoE9rkRr2j/41O059//tlUfMya1n/0Q1cXmt1NA4gJYyIP6L+7Ro8encg8FjNTb775pjpzQ/+coKPKaqTvgw8+qGARXH1cq/2GDRsqW7ZsiaxYZ511Vl4eBg4cGFlaUQvOUlmiZmUlX/se5CLoVnR4zCuBV155JU+36PVb3IOLi6Wbd9ttN0sG1apV84ozL77eA4KeK/Yvv/zyvLj8QQJREaABxAVZTP01PqT6315nbOCDql69eo4yvRoDLr300jx5UFxJDJgto2en3x83blwSs+yYp1LpzNAgJKm87du3z7uX0GHPEB4BGkDCY6lJ0utpzPZbvXq1dsrTNqgB5IUXXsh7dvR6uF+/fp7yYhW5devWtvLhJochPAI0gITHMg5JSXqHRl3eMNt7Xgwg27dvd/TjDX2HTlxjGDRoUJ7ewkyWQoEGkEKEknGeBhD7eoB7ZuPivxj05+TnHd+oBx54YN7zom9HaPtDhgyxTzjAGeMiwBh0UYyR2QGKkLs0S2XJFSrmHRpAYgZeAsnZLQLu97vND7Ji6uZatWpZ6ne8K4IEo5tR7V2BrZ8B4EHywmtLlwANIC7q/rzzzrNUAnhY69ev70KCOQqm4OsfeuM+fIF6CVjYUS+jXLlyXi6PLS4Wi7ebTRPG4ruxFeR/CZVSZwaKnKTytmvXLu+eb9CgQdzVn+n0aAAJt3rh21SvozHib+HChb7+7NbYwHvFSibcbunDJ598kpcXfb66du2qj+pr3+ieTpMP3e9nxouvTJTIRTSApKuik/QOjZpcmO09LwYQlAtrGWl6x2qLdQ3g6kILmzdvVqpUqZK7BulZzRLR4mtbGkA0Esne0gBiXT8wFlp19F122WXWF+iOYiQ01t6zer60Y1iQ3GnNHZ0417uTJ0/O+4484YQTFMyuTWPIUlmKyZ8GkGLSz2bavXv3ttRtXgco+6VTbN1s5wYsaP/ihx9+aMkV74xGjRr5xcXrSMATARpACuDCbA2nNTv8TteCH3StgWi1HTx4cIGc5Z++6qqrTPK2bt2aHykhv4w+lrXyX3vttQnJoftslFJnBqgkqbzHHnts3j0Pd0IM4RGgASQ8lpBUqENO04NRbefPn59XIH2Hoj7NMNbS2X333fOeTU1+zZo18/LAH8EJ0AASnGGcEpL0Do2j3GG19/T6qkePHgWzDtc96HzVdI/VVr8e0T333JMX9/rrry+YBiLQAOIKU9Ej0QBiXQV27kg++OAD6wsMR+HT3a6jUHvmMGM1rLBs2TKlUqVKuWf1mGOOidTNVlj5tpKTpbJYlS/OYzSAxEm7NNIyzszS9Bk68OMIxdbNTsbtHTt2+EaAvkmNpXGLNhsDCcRBgAaQApSXL19u+6DiwX3ppZcKSLA+/cYbbzjKPfXUU60vtDk6bNgwkzxt4TqbS4p2eOLEiaa8gqWfxYCLVoj/JVxqnRlJKq9+LQXcP8cff3yxb4dMpU8DSLjVWWwDyKuvvppXIKMLOa0hWqdOnbx4Xn9gJKYmy7jt3r27V3G
MX4AADSAFACXsdJLeoXGgCau959UAgrIZXVoZ9ZFm7IXrnIMOOiint+BCdv369a7w0ADiClPRI9EAYq4C3PdW+qh58+bmyA5HMKuzbdu2uefH+JxhQEQYAc+kPr9Y9w+Lh6cxZKksSeBPA0gSaiFbeXjqqacsddqmTZsiL2gSdHPPnj0tyw/97rZ9ZAcK64gY3xP4XbVqVbtLeJwEQiVAA0gBnE4LOWIRODdT5K2SWLJkieXDrykEr525d955p0ne+++/b5V00Y/Zlf2OO+4oet68ZkDfGNfqDtswRzx5zVOU8ZNU3n333Tfvnj/llFOiLHrJyaYBJNwqT5oBxMpoDt1Vvnz5PNcwXil89NFHec+lXi8+9NBDXsUxfgECNIAUAJSw00l6h8aBJqz2nh8DCFz97bLLLrb6CLoJM+OmTJmSF+eCCy5wjYYGENeoihqRBhAz/sWLF+fd99q7eujQoebIBY5gsXO4odVkGLdBO8x++OEHRe+SBQOgNmzYUCBXyTydpbIkhTANIEmpiezkA+6EjXoM30dxhCTo5lGjRpnKr/FYsGBBIAyNGze2lF23bt1AcnkxCbglQANIAVInnXSS5UMKJeB1loY+KadFaCEbi0d6CaNHjzblEwo0iQGNP02J6rfPPPNMErPrmKdS68xIUnlhgNTfP+iIYAiPAA0g4bGEpKQZQJzeQUEW+YNPa/1zqd+HcYQhXAI0gITLM2ppSXqHRl1WyA+rvefHAIL0sdaSXgcZ97HoJtYQ0B/H2iVuAw0gbkkVNx4NIGb+dq6Y/Q5UePLJJ/OeI/0ztWjRInMGXB756aefFK2DGzIxW+vLL790eXWyomWpLEkiq90f1atXT1K2mJcUE0CflF6HYb9evXqxlCgJutlu8Ao4eF2n2AgNrguNbPEbxxlIIA4CNIA4UEZDBdZeq4cUx2Ad9Ruw4LedXBzHmh5ewvTp003yXn75ZS8iYov7448/mvKKMq9Zsya2PISVUKl1ZiSlvL/88ovpHvK7Hk9Y90LW5NAAEm6Nrlq1SunSpYvSuXPnwH8NGzY03f/Qoc2aNbOV/d577+UVCIv+Go2I2jsJHRl+ww033GCZN+gOhvAJ0AASPtMoJSblHRplGfWyw2rv+TWAvPnmm5b6SNN1xm27du302S+4TwNIQUSJiEADiLka7r33Xstn47nnnjNHdnEEblvQQWh8pvB7woQJLiSYo6Ctrx+IiHXEMLMrjSFLZUkafxpAklYj6c/PTTfdZNJlF110USwFS4JuhmvDihUrmhhAn3vtozRCg6tlq/cEvo8ZSCAOAjSAOFB2GiGLBxcdWn6D0+JCkP2f//zHk+h3333XpEySOqPCav0T+APEgvNpC6XWmZGU8n7zzTem+z2IQTJt910c+aUBJA7K/tKAu0CrxiN81noJdqOjg7jww4hqq7xdeOGFXrLGuC4J0ADiElRCoiXlHRoXjrDaexjZq+kVrzOk27Rpk7tWk2G3nTFjhic0NIB4wlW0yDSAmNHbucH0+v2pl9yvXz/LZ23MmDH6aK72scA6OsS0Z3WfffZRgrp2RqeelxlerjLqIlKWyuKiuLFHoQEkduSZT7BHjx453aPpoKeffjqWcidFN7du3drEACy8rhNlhIZ11jSm+i2WHWAggTgI0ADiQHngwIGWDygeVoxC8RvQebvrrrvayq5UqZKCUXNeAmar6JUI9v1OY/aSrp+4jz76qCmvl112mR9RRb+m1DozklJeq3UGnn/++aLfD1nKAA0gya3NsAwgL774okkX493RsmVLX4XHCNC9997bUiY6QhnCJ0ADSPhMo5SYlHdolGXUyw6rvYfFMbU2rld3l3Z6TpOnbQ8//HDPA3FoANHXdnL3aQAx142V62Q8C9ddd505sssjVutRQibW2fESdu7cqfTq1Sv3zGOQ3PLly72IsIx7zTXXqDJR9rhClsoSFzOv6dAA4pUY4xcicMQRR+T0D3QYOu03btxY6LJQzidFN48dOzaPgdZWKlOmjPLtt9/6KisYanKMW8zYZSCBOAjQAOJAuX79+rYP6bnnnutwpfOpAQMG2MqFMujdu7ezAJuzxo4nGHCSGAYNGmQqfxgN22KUtdQ6M5JS3jlz5pjuoRUrVhTjFshsmjSAJLdqwzKAYFRijRo1TM8SGrd+Fi2dN2+eSRbeaccff3xyYaY8ZzSApKsCk/IOjYtaWO09/YjBDh06eMo+Zhc3atTIUjfpP8D9+LWGW0O9DP0+OlAYkkGABhBzPeiZ6O/bbt26mSO7PGK3cK7XNUD0M0mqVKmihLGm5fz585WyZcsqaN/E6XI5S2VxeRvEHo0GkNiRZzpBqwXQ43SznRTdvG3bNmWPPfawbOOMHz/e1z1g5bEG7x+4OcUgOgYSiIMADSA2lOFjVN8gNO7DKuonvPPOO0q5cuVsZWP2x8cff+xHtLp4kD6fxx13nC85UV6E2S9GZdqxY0dfSWKaNlwhoHPtiiuuUIoxwjjOzoxSK6/TTWHVAex11pSTfJ5TFBpAknsXWN3/0P1eXWChhIMHD7Z8Hz344IOeAVh1diJfnJ3lGaXrC2gAcY0qERHjajPA3zw6HR9//HFlwYIFCoydcYew2ntbt27N01F+3C889thjeTL0bWXswxAMZl7DaaedZiv3kEMO8SqO8SMioO9QinP0f0TFCUWs1UxqPAsY/Oc3XHrppabnAcZLL+1zuODUnk/4oH/11Vf9Zid3HVxn7b///qrck08+OXfcbics/ZmEstiVMUvHaQDJUm0WvyzoP9N0ELbom9uwYYPnjPnVI0nSzfDQomeh7R999NGeeeACu1mCQ4cO9SWPF5GAHwI0gNhQu++++ywfeO3B/+CDD2yutD+8du1a1XWWJsNq+8ADD9gLKHDGqKSgsJNmTcUCUvpyY2TP559/XqBk5tOXXHJJnhzIhFsxrNsSZzAac7SyBfGhb5X/UiuvFQP9sTPOOCOv/rEoNEO4BGgACZdnmNLCNICgkxLuJTTdpW29djKikxI+urXrte2hhx7q2a1MmKyyLosGkHTVcBxthrfffltB57v2DGKL5xAf1XGGsNp7xs6Afffd13Mx0BGhdYDquWj7w4cP9ywTF2AAjibDuN1zzz19yeRF4ROgAcSa6THHHGN5/z777LPWFxQ4auw4xDPhZcbWbbfdlstP+fLlFa9r8hizB8Mv/PZDZ2jPZ6H1McPSn0koi5FHVn/TAJLVmnVfru3btyuvvPKKOuDKzwx2LSW469N0hba9++67tdOut0H1SFJ084cffqi6/9JY6LfwxuElYDausW0KeTB0+3Wp5SV9xiUBjQANIBoJw/bYY481KUDtod9ll108d+hgBOxBBx1kKxOyzz//fM9y9dnG6F8tj9oWU82CBrg1Oeecc5S6desq7dq1U2666Sblu+++8yx26dKl6vRjLW/Y+plJg5eKXoZ+H0afLVu2eM6bnwugrPVp6/cxIyWsUGrldcOtdu3aeez79u3r5jLG8UCABhAPsGKOGqYBBFl/+OGH854nTZd58dt96623mmTA1c
RLL70UM53SSo4GkPTUdxxtBoy0tptlAiMI1otzCklr7yGvEydONOkWzNL2Guz0ZoUKFXx/fBvbIpruxBb6rxgzb7xyKYX4NIBY1/K4ceNMzxbuXXzveZ0RtWzZMksPB88995x14oaj999/vykvJ554ouLnD4v3wgWdcXAHZnphIXS7EFR/anKTUBYtL6WwpQGkFGrZvoz33nuv2omuf/82aNBAmTx5sv1FFmdgODF20PtxeR+GHkmSbtYbc/WMoWNheHIb7Nwk33DDDW5FMB4JhEKABhALjKtWrTI1wvQPPPbdzlpA5zUaYsbrjb/hWxCW0SDh66+/NqUDt0lBAkb1Wo1YxMwNTFdzO60ZU5iNo+9GjBjhK2uYJWPkp/89ZswYX3K9XvTaa6/Z5gMzFMIKpVbeQtzWrVtn4o4OEoZwCdAAEi7PMKXZdeT5cYGFfOHdYzVy87DDDlPwDigUvvzyS8v3BKc0FyIX/DwNIMEZxiUhjjbDhAkTTO9Hffto2rRptsVNYnsPmcUIcn0ZsI+BOF7Dpk2bFLRdjbIww9ZPsPuY18un+z8/ZMO/hgYQa6Ywchx11FGmZwL38C233GJ9kcVRGPqs5MB44SbgWxUGQ/2zE8X+kCFDHLMTRH9qgpNSFi0/pbClAaQUatm6jDt37lTq1KljqzswcNeN1xb01xkHNLRq1cqzIRi5DEOPJEU3ozzQ73B5ZaWT3a43DBmnnnqqSQa+PTlQBJQZ4iRAA4iBNhSp05R27eFv1qyZAp+iVgEfWWhsw2pcqEGH0SkYfRtWOPjgg/OUS9BFm6xmlWgMsIVRA53PdsYbHL/99tsVzJrRrsO0ZnTa+A2w9GuyrLZt27b1K9r1dbhP8FK1Sh/HMALTrXGoUKKlVt5CPNChYOTu1iBZSDbP/0mABpA/WSRtL2wDCMqH95mVCyv4uHdqnOJ917hxY9MzCf0IPckQLQEaQKLlG5b0uNoMep/zxvckfl933XW2RUpaew96B4YOq3LAkOFlhppWaOOaR2ijr169WjvteotRj1Yf88a84lsBC4kyFJcADSD2/DGbyjhTQruP+/fvX3CEL+7viy++2PScwu2UmzUtMUNE/42opR32Fs96ofwE0Z8gnKSy2Nd49s7QAJK9OnVbIriXL6Qr0O+Ets/PP/9sKRZtH7hh0svp06ePsmPHDsv4hQ4G1SOa/GLrZi0f2KKdpHclqGeFdpXT9x44du3aNY8vrkcf4meffaZPhvskEAsBGkD+hxlTYuHLrmXLlqYHVP+QG/ebNGmijk47/fTTVcMJXAwUMnpABhp7cHkF/+thhgsvvDAv/+jEDBJmzpyZJ89Yfu03RgpjtBA6p1euXKlgpCN8JhpHBIHX8uXLg2RJgSstLV2rbZAF/Nxk7NNPP7Vs7BvzglGLQcuK/JRaeQvVwfXXX59X/xj5wRA+ARpAwmcalsQoDCDIGxq4xpl60GunnHKK5QgqLLJs1PGI37RpU19uEsPiU0pyaABJfm3H2WbAzFpjW0T/+6677rIFloT2Hj6iMaABHRLQI/q8W+2fd955qps9fEQ7fYBrhUZnSbly5XJyYeB1G7CmHtYjwaAlOzdjVnncb7/9FKwrCBdBcblodVumUolHA4hzTWOW/t577517LvT3Mb7b5s+fbzKEYIDbrFmzlHr16pmu22uvvZQVK1Y4JyrP4np0TurTi2ofgzIKhSD6M2llKVTWLJ2nASRLtemtLHjvV69e3ZUOgSt6uOvF+wDrC6H9bGxnoH3wr3/9y1smDLGD6BGDKKVYutmYD/xes2aN7Wybk046SX1P6K/DOwKu+NHOMup19Bui3cZAAsUgQAOIpN6jRw+T5df4oIb1GzM04EsP7qqiCE888USekqlcubKrj0K7vGBkj3FWiR8W8KmImQx+ren6/OFl17Fjx7xy6vOEMocZMC2yZ8+eapoY6ezGwKXPDz6UMasIeb7gggs8Z63UylsIkHHmzYABAwpdwvM+CNAA4gNaTJdEZQBB9jHiyMpt46677qp06tRJueqqqxSscQTXFno9p+3jfcoRzzHdCDIZGkDiY+02pWK2GdDJbtdGKVu2rOOgjGK39+65554844SmU9xuMYLTabaaVn+9evXK6a7Zs2drhx236GADP7d5cYrXvXt3x7R4MnwCNIAUZgpXlk6DANExeOSRRypnnnmm6jLTyp3cbrvtplx55ZXKxo0bCyaIxXWxbqPTsxLmOTcLu/vVn0ksS8EKyFAEGkAyVJk+iqJvB/vVGRiYjAEV6OQPGvzqEbt049bNdvnAcfRfdu7c2VZvYyB4ly5d1MHhdoapbt26Kd9//71TMjxHApESoAFE4m3evLntg+xHkaIDHn4EsThQmzZtFLihgpuoMJRqobsB0/v23HPPvPIsXLiw0GWO52Ghbd++ve1HtR2jqlWrqkoSHx52LrIcE3Y4CfcDaGQ3bNhQ/SjFCCR0yiEvmMqNkXphhddffz2Pp1153RzHYpt+QqmV144R6tV4f2OGDEP4BGgACZ9pWBLtDCDz5s0LKwkFfqyd/Ooa9R1mgkDXM8RLQP/hN3fu3HgTZ2qWBIrdZhg2bJipzQKjCNyRFgrFbO898sgjpnwb9YzTb4xgd9PWROcE5KAz123AgqpOaXs5h9nBDPESoAHEHW+0sTGTHyN2vRj8oF9gWPQyohczJrw8N0HiwnWL0+Lnejp+9GdSy6IvV5b3aQDJcu26K9ujjz6qYFaBVz2BQaqXXnqpOrvTXUruYvnRI06S49TNTvnQzr388ssKXHx64Y31PjCbkIEEik2gDDIgb16GDBG47LLLxIMPPpgrEX7LhbRzv/3uyJHBYty4cUK6ABByIWr1T47yEXIUkJC+49W/GjVqCKkQhVw4SshphUJa1P0m5/k66XpLyA9aNU05ClDIBrlnGWm6oNTKi7qRLjqE/DDLVZOckSPk9Mrcb+6ER0DOWhLSxZHYfffdhXTbEZ5gSkoNATQPlixZIiZPnqw+Z3Lkj1i/fr2QHSOqvpfusgTuEzyTLVq0yLzOTWLFyVHzQs7KUbMmDSBCzsxJYjaZp5gJTJ8+XcgRz2LDhg1CjsgTcmaWOOGEE1znIuntPdcFYUQSkAReeuklIWdhqyxGjx4t5MxhcilAQLqLU9/90AVffPGFkK7pBNoAciaIkCN71T9pGBRysIyQg/2EnOlfQGJ6TgfVn+kpaTZyKjtWxeLFi9V7Ev0SDKVJAN8sclanePHFF1VdBX2F/qrNmzeLPfbYI9dPhb4qOfhZSI8SAjosqhCVHkmSbpZGbzF16lQhXd+rrPGNCN5yMLLAN2LNmjXV94OcFSIOPPDAqFBTLgl4IkADiCdc6Ygs151QjQ9abuXIF/VFEKcxQks7zq0c4SWkixaBF5tcWyXOpIuSVqmVF5B79+4tpJu3HO+RI0cKufhW7jd3wiNAA0h4LCmJBKIiQANIVGQplwRIICsEaADJSk2yHCRgJkADiJkJj5AACZAACVgToAHEmkvqj8KyLf1Q58ohF3gXcoGi3O8s7owZM0Yd1SVdjwnMj
sh6KLXybt26VcjFRIX0U65WrVyTQDXsYfYRQ/gEaAAJnyklkkDYBGgACZso5ZEACWSNAA0gWatRlocE/iRAA8ifLLhHAiRAAiTgTIAGEGc+qT378MMPi379+uXyD/cHkyZNyv3O4g7K+PTTT4vzzz9fjB8/PotFzCtTqZVXrkkg5CLyOQZyES11en7uAHdCJUADSKg4KYwEIiFAA0gkWCmUBEggQwRoAMlQZbIoJGAgQAOIAQh/kgAJkAAJ2BKgAcQWTbpPYLQ8/D7DFx8C3F9h7Q65WHi6C2aTe7n4u+r6CrMDxo4dK/r06WMTMxuHS6288OuJmT2rVq1SKxA+iFesWCEaNWqUjQpNYCloAElgpTBLJGAgQAOIAQh/kgAJkICBAA0gBiD8SQIZIkADSIYqk0UhARIggYgJ0AASMeBiisdaCVgzQQv9+/fPWxxdO56FLRb7POuss9TF+bA4VK1atbJQLNsylFp5p0yZIs4888wcj4EDB4r7778/95s74ROgASR8ppRIAmEToAEkbKKURwIkkDUCNIBkrUZZHhL4kwANIH+y4B4JkAAJkIAzARpAnPmk/myrVq3EwoUL1XJUqFBBwDiARdGzFDA7oGXLlmLx4sXivPPOE3CVlOVQiuVt2rSpOuMD9Vq9enWxZs0adZvlei522WgAKXYNMH0SKEyABpDCjBiDBEigtAnQAFLa9c/SZ5sADSDZrl+WjgRIgATCJEADSJg0EygLboKaNWsmdu7cqeaub9++AuuDZClgbZNevXqJMmXKiPfeey/zbpFKrbzG2R+Y+YEZIAzREqABJFq+lE4CYRCgASQMipRBAiSQZQI0gGS5dlm2UidAA0ip3wEsPwmQAAm4J0ADiHtWqY05aNAg8cADD6j5h5FgwYIFAp2bWQg//fSTaNy4sfjiiy9E586dxbRp07JQLNsylFp5t2zZohq01q1bpzLBmh8w6mENEIZoCdAAEi1fSieBMAjQABIGRcogARLIMgEaQLJcuyxbqROgAaTU7wCWnwRIgATcE6ABxD2r1MbcvHmzaNiwodiwYYNaBuwvX75clC9fPrVlQsb/+OMP0aVLF/Hiiy+qi7zDBVbz5s1TXSanzJdaecHi4osvVhe1x37ZsmXFnDlzRJs2bfCTIWICNIBEDJjiSSAEAjSAhACRIkiABDJNgAaQTFcvC1fiBGgAKfEbgMUnARIgAQ8EaADxACvNURctWiROOukk8euvv6rFuPHGG8Xw4cPTXCRx9dVXi7vvvlstw9ChQ8XNN9+c6vIUynyplXf27NmiXbt2OSzDhg0TN910U+43d6IlQANItHwpnQTCIEADSBgUKYMESCDLBGgAyXLtsmylToAGkFK/A1h+EiABEnBPgAYQ96xSH3PcuHGiT58+ajkwmh7uojp16pTKcs2cOVOcdtppat5PP/10MXXqVHUWSCoL4yLTpVZeuLxq0aKF+Oqrr1Q6bdu2FS+//LI6C8QFLkYJgQANICFApAgSiJgADSARA6Z4EiCB1BOgAST1VcgCkIAtARpAbNHwBAmQAAmQgIEADSAGIFn/eeWVV4pRo0apxaxSpYpYuHChOOqoo1JXbG12wBlnnCEmTpwoKlasmLoyeMlwKZV327ZtolWrVmLZsmUqosMPP1xgBtMee+zhBRnjBiRAA0hAgLycBGIgQANIDJCZBAmQQKoJ0ACS6upj5knAkQANII54eJIESIAESEBHgAYQHYxS2N25c6fAjIkZM2aoxa1du7ZYsmSJqFmzZuqKjwXBq1atmrp8+81wKZQX65x07dpVvPDCCyqmAw44QDV+YMsQLwEaQOLlzdRIwA8BGkD8UOM1JEACpUSABpBSqm2WtdQI0ABSajXO8pIACZCAfwI0gPhnl9orf/nlF4GZE5oR5MgjjxTz5s0T1apVS22ZmPFsEBg4cKAYPXq0Wpg6deqo92XdunWzUbiUlYIGkJRVGLNbkgRoACnJamehSYAEPBCgAcQDLEYlgZQRoAEkZRXG7JIACZBAEQnQAFJE+MVMGouhX3LJJWLChAlqNuAaCx0pDCRQLAKTJk0SvXr1UpNv2rSpOgukVq1axcpOyadLA0jJ3wIEkAICNICkoJKYRRIggaISoAGkqPiZOAlESoAGkEjxUjgJkAAJZIoADSCZqk7vhXnooYfE7bffri6OPnz4cO8CeAUJhERg7ty5omfPnuLss88Wd955p6hUqVJIkinGDwEaQPxQ4zUkEC8BGkDi5c3USIAE0keABpD01RlzTAJuCdAA4pYU45EACZAACdAAwnuABEiABEjARIAGEBMSHiCBxBGgASRxVcIMkQAJJIwADSAJqxBmhwRCJEADSIgwKYoESIAEMk6ABpCMVzCLRwIkQAJ+CNAA4ocaryGBeAnQABIvb6ZGAiSQPgI0gKSvzphjEnBLgAYQt6QYjwRIgARIgAYQ3gMkQAIkQAImAjSAmJDwAAkkjgANIImrEmaIBEggYQRoAElYhTA7JBAiARpAQoRJUSRAAiSQcQI0gGS8glk8EiABEvBDgAYQP9R4DQnES4AGkHh5MzUSIIH0EaABJH11xhyTgFsCNIC4JcV4JEACJEACNIDwHiABEiABEjARoAHEhIQHSCBxBGgASVyVMEMkQAIJI0ADSMIqhNkhgRAJ0AASIkyKIgESIIGME6ABJOMVzOKRAAmQgB8CNID4ocZrSCBeAjSAxMubqZEACaSPAA0g6asz5pgE3BKgAcQtKcYjARIgARKgAYT3AAmQAAmQgIkADSAmJDxAAokjQANI4qqEGSIBEkgYARpAElYhzA4JhEiABpAQYVIUCZAACWScAA0gGa9gFo8ESIAE/BBo3769WLp0qahatapYu3atHxG8hgRIIGIC9913n7j55pvVVJ577jnRunXriFOkeBIgARJIF4GZM2eKc845R830HXfcIfr27ZuuAjC3JEACtgTatWsnli1bJqpVqyY+/vhj23g8QQIkQAIkQAI0gPAeIAESIAESMBHgDBATEh4ggcQR4AyQxFUJM0QCJJAwApwBkrAKYXZIIEQCnAESIkyKIgESIIGME6ABJOMVzOKRAAmQgB8CNID4ocZrSCBeAjSAxMubqZEACaSPAA0g6asz5pgE3BKgAcQtKcYjARIgARKgAYT3AAmQAAmQgIkADSAmJDxAAokjQANI4qqEGSIBEkgYARpAElYhzA4JhEiABpAQYVIUCZAACWScAA0gGa9gFo8ESIAE/BCgAcQPNV5DAvESoAEkXt5MjQRIIH0EaABJX50xxyTglgANIG5JMR4JkAAJkAANICV+D+zYsUO88MILokWLFuKAAw4ocRrJKv7mzZsFPtpOPvlkse+++yYrc8xN5gnQAJL5KmYBM0CABpAMVCKLQAIkECkBGkAixUvhJFBUAjSAFBU/EycBEiCBVBGgASRV1RVuZtesWSO6desm3n33XdG9e3fxzDPPhJsApQUi8MADD4hBgwaJatWqifHjx4vOnTsHkseLScALARpAvNBiXBIoDgEaQIrDnamSAAmkhwANIOmpK+aU
BLwSoAHEKzHGJwESIIHSJUADSInW/aJFi9QO9R9++EGUKVNGPPnkk6JHjx4lSiOZxf7oo4/EX/7yF7Fp0yZRtmxZMWrUKHH55ZcnM7PMVeYI0ACSuSplgTJIgAaQDFYqi0QCJBAqARpAQsVJYSSQKAI0gCSqOpgZEiABEkg0ARpAEl090WRuwYIFomPHjmLr1q1qAiNHjhSDBw+OJjFKDURg2bJl4oQTThA///yzKufOO+8UQ4YMCSSTF5OAGwI0gLihxDgkUFwCNIAUlz9TJwESSD4BGkCSX0fMIQn4JUADiF9yvI4ESIAESo8ADSAlVufvvPOOaNWqVc74MWDAADF69OgSo5Cu4k6cOFGce+65uUxPmDAh73fuBHdIIEQCNICECJOiSCAiAjSARASWYiMj8OOPP4rffvtN7LXXXpGlQcEkoCdAA4ieBvfTTID601x7NICYmfBI+AT47IXPlBJJoBgEaAApBvUipfn5558LNBLWr1+v5qBDhw5i+vTpYpdddgklR+vWrRNI49tvvxXbtm0TNWrUEAceeKC6uHrFihVDSSNJQuIsb//+/cVDDz2kFn+33XYT8+fPVxeuTxIP5iVbBGgAKX59fvHFFwJ6Bn/oMDzooINE/fr1RfXq1WPN3Pbt2wVc8n3//fdi48aNqttE5KVevXrsxIy1JsyJ0QBiZpKmI7/++qt48cUXRaNGjUSDBg3SlHVfef3pp59Eu3btVJ0G/cZAAnEQoAEkeso7d+4UK1euFO+//75Yu3Zt3h8GcQ0dOjT6TGQ8BepP6wqmAcSaC4+GR4DPnj3LOPvD7HPBMyTgngANIO5ZpTrmjh07xF//+lfx9ttvq+VAJxrcK1WpUiVQuT777DOBxbpnzJihNnrthMGNU58+fdTF1itXrmwXLfHHi1Ve1B86pN966y2V0f777y+WLl0qatWqlXhmzGA6CdAAUpx6e/3118WkSZPEtGnTxJdffmmZiWrVqqnrA1155ZWibdu2lnGCHoTbPeh2GMmha9BRaxV233130aRJEzFw4EDRrVs3Ua5cOatoPBYRARpAIgIbsVgYNB977DFx++23CxgCrr76ajFixIiIUw0uft68eeKDDz7wLOj3338Xq1evFjNnzlQ7Rrt06SKef/75PDnoQMUM119++SXveKEfaAd17ty5ULTQz69YsUJAX3sNGBgEN7QM8RGgASQa1vAqMGvWLAHXyq+99prACGljwBqG0G933HGH8VTuN/TDo48+mvsd1U7dunVF+/btoxJfUG6U+rNg4hmOQANIhis3pKKl7dlLuk4sVn9YSLcDxZQ6AYXBFQE5+l45+uijA/81a9ZMOfnkk5Xu3bsrl156qTJs2DBFGg8UaVl2lQ+/kS677DJF3uvqn1z0XFm4cKFfUep1cvF0RRo0lP/H3pfAfzWs/x8UlbIWSSSRSlooXf1RUcoWilRC6KKUX1kSidy6FBdZEkJUCFlabqWF1O1KaKPQQqIFdStKtsx/3p9759OcfT9nzvk+83p9XmebeeaZ95x5PnPmPfMMH+wqyhXynY7c5QHjH72h8k4jsQrl5X82jA98FvFu1KgR4zOz04CD8iwBCHDCtPCu8QHuElDa9IvIBxUZH8Qrtm8nOyo/a9iwIXv99dcjKwDvdLNnn32WValSxbcu1apVYw8//DCDDArJIPCPf/yjWE/vvPNOMplSLoER4MQH4wN9jK+gKtYb2jMfIAwsM6mE6G/wlb06vWVb5Od85syZJrVfffXVQLL5qljGVzab5MV9g0/sCaQvX30dt2ok34DAP//5z2Jdcbe/hqd06QcBTlCyF154gTVu3LiIqbHt169fn91+++2ME09s69atruLRhzHKiOOaE6+uusQVIW77GZfeWZD7l7/8pfD+8NXRWVCXdEwYgSy2PVVtogrjYQm/PpRdDhHQclimWIr0wAMPMO4qKrYOWunSpVnTpk0Z3++B/fnnn5GWgc/O0enNZ+qGks9nvDE+g00n029HtV+/fqF0SDKxSuXlrjIYCCyBN19SniQUlFcJQoAIkOQqm896ZvgPEO06yPH//u//QhMPs2fPZnXr1jXpUbt27cJgBv6fuPu9wv8UbDgIfStd+QxLT4MeySGc35yIAMlG3YL44Cs+GHcbZ9lmskCAYCKQVXv3e4+7+7KstE8++YTdcsstrEuXLoXBVT8TbDDYmmTgq6l9YVGuXLmCvezUqRN7/vnnk1SV8uIIEAES/jWADcP/TcWKFS3ffb66n3Xr1o0tWLDAd2atW7e2lOnXtrjFv/baa33rFlWCuO1nVHpmUQ4RIFmsteR0zmLbU9EmqjQeltzbQznlEQEiQHzUKma18j0u2BtvvME6dOjgubOGTiFmxmL1g1vnDM+5OxEG0iKKgA7r8ccfX8wXsw63b98eWDR3X8Aw285LOdziPPPMM4H1SCqhiuW95ppriviXKVOGrV69Oik4KJ8ShAARIMlU9qBBg4rt2c1muj3nblUYbH6QYGXrKlSoUBjw4O6vLEVylzXskUceYfiPM+qG2aHcdZ9lOroZHQJEgESHZRyS0G/EbOljjjnG1EbkNqM6AYK2zl2nOpZBLo/T+YgRIzxBjVWvfP8zT3lidWzcK6llpfmeBp70ArENUofvnyQnp/OEESACJBzg8+bNK3ybWrVrvOM333xz4EkPX375pW5il1UeUd3DipQ0Qhr2M41yppUnESBpIa9+vllseyraRKtvxKB2OQvjf+q/2aRhGASIAAmBHmaSODV+vvE3w5J+eUAKH2iffvopwwegk4sRrDbhfuBDaPffpNyHu07Hl156KbBMvv8EQ5mMZYYbALjYgusTDIZdf/31hdUs8koFYxpcg0gBoaRqULW8cJUjY3v++eerCiHplWEEiACJv/KeeOIJkz21spV+7gWZCb1s2TIGMlXOB/9BGPTwEvCfZkWM9+7d20tyihMCASJAQoAXY1J8dGPFVM2aNXXtSm5j8rnqBEhU7hgOOOAA35Nw0MeRsbI7Rx80iQB3W3vvvbcnnXr06JGESpSHCwJEgLgAZPMYrmPgrln+5pDb31lnncX4puc2qb3dRp9FlhnX+f7775/apIw07ae3Wsh2LCJAsl1/cWqfxbanmk1UdTwszveGZOcbASJAQtTvqFGjHDttmLnmFHbs2MHuvPNOWxkYgHrzzTedRDg+g58++MMUncnjjjuO4aM8SOAb4rJatWoVZUEm/Bg7LXXmm+KZ0ghdxDHIYF0Q/f2mUb288GMrMMQRH3cUCIEoESACJEo0zbL4hqGWeyg1b96c8c3NCyT5mDFjGNzcXXjhhbauc2Q7gHO+4SibNWuWOUObO3C5KOpaljVw4ECbFNa3EV9OL3RZs2aNdQK6GwkCRIBEAmOkQtBvk/tL6MsZCUZjW1GdAGnSpImpfRvL4OUats1vgNsFL7Kx0lqecOQ3H6/xYZO96IM4mMlJIX0EiADxXwcbN25kYmDZ+L6jn4H/nrABq0srV67suT0Z9fBzDdd6aYU07WdaZU4yX/Ge0h4gSaKejbyy1vZUs4mqj4d
l4y0kLVVDYA8oxDsQFAIgwFc7aHyGq23KGTNmaC1btrR9Lh7wjW+1SZMmiUvdke+1oa1atUrjS4x1971c3Hjjjdpjjz1WjMoH0zTeASxe+zkZMmSIxsmKQhL+Ma/xTXK1K6+80lUEd4GidezYUeMbn1vGPeyww7T169dbPkvzpurl5YOnGl95U4SIu9fQ+Cxsjc/CLt6jE0IgDAKnnnqqxlcAaHwTdG3btm1hRFFaAwKcdNBOPPFEbcmSJcUnfMBUGzZsmMb9vhbvGU9gR/lqO1ebyTv82vz5843JLa+nTp2qcTJb94z7qy/UOffBr7vvdAFbz2e7a2vXrtVF69+/vzZ48GDdPbqIDoEHH3xQ4y52CgL5JuhaixYtohNOkgIhgPa3aNEijU9U0C6++GKN74lTsKPcTYzGVylYyuQEiDZ06FDLZ2nf5JNZtGbNmhXVOPPMM7VGjRoVr72e8L0DNL5CWCtfvrzXJIV4mzdv1pDWS3jxxRe1zp07e4kaKA7fAFpDv/yHH35wTc9XiWiwixTSR4C7PtK4i8iCInwTdI2vzElfKYU14PvxaHzllcZX6Zu05CsptHHjxhXsmumhzxt8dnbBRsrJzjjjDO3SSy/VuJcErVKlSoG+f/v06aPBbsmBE9Man0wi30rkPG37mUghU87klFNOKfR5OQGi4f+CAiEABLLY9lSziaqPh9GbTggEQkA1RiZL+ji5wMKGh/xDyVNxvvnmG8Y/CG1nwGDjTL9h+fLlutnF8D8NX9RBwo8//ljcvwS+XuHWy0/4/vvvGe/E2paPD676ERd73KyUV8x44Q2/gC0fPI0dG8qg5CAgVgVwAqTkFDqhkhpXD5599tnMbp8No0pbt24tbDQq2r3dcenSpcaklteYFWmU0bRpU8u4bjc52W2SxT9M3ZLR8xAI0AqQEODFlJRPfmGYPW0MfGDG1pWMyitAzjvvvGK7hqu7devWGYsW6zVWLmPGudFOWV03bNgwVl345B9PekA3PsEnVl1IuHcEaAWId6ywghT7f1m1L8ywh8vLqEKrVq3FO0WwAABAAElEQVSK+WAfH6yMDRuwqrVq1apFuSgH9inDTOY0Qtr2M40yJ52n+B6mFSBJI692fllseyrZxKyMh6n9FpJ2KiJALrBC1Iq8ubixo4jNaP0E+FE1yhDXfDaMH1GFuNiHQ6THkc8s9C1DJJD3EcFmnkHC6NGjdfrIui1evDiIyNjSZKW848eP12EKlxsUCIGoECACJCokzXJOPvnkYtutUaMGg7tCv4HPLC/KkO2pOOczID2JtNqLiq/u85TWGMlqQ3e4t6AQHwJEgMSHbRySMUAj2qh8VJUAwWCn7P8f+72lEbB3iIyX0zkIqLhCvXr1POtRp06duNQguT4RIALEG2DYDwx7ZVi1L7jxmzt3rjdBHmKtXr26aFtgFxcuXOghlXsUK5d56C+lEVSxn2mUPck8iQBJEu1s5JXFtqeaTczKeFg23kjSUiUEiAAJWBubNm0qdtysOoqPPvqoL8n4qLSSg3vwa+wnYNM6+WMRM+e+/fZbPyJ0cevXr1/QDX7og4bvvvvOtnzcVUtQsbGky0p5MSsSK3vk9yaK2VOxgEpCM4cAESDxVBk65aLNYqWg15UaRm0wwxwzJoUs45G7XzQmMV1j1YnVzGoQ8kECCHKjHhg8pRAfAkSAxIdtHJKPOOIIUxtBm1GVAOnatWtRX2z8jRXLaYSDDz64qIfRxhivg9ovt3JhZrycF/rmThu0n3DCCW4i6XlCCBAB4g40VutXr15d946L9x3/45h0FWXo169fMS/u3jMy0Zj8IfQWx5dffjky+X4EqWI//eicxbhEgGSx1uLVOYttTzWbmJXxsHjfJJKeRwSIAAlYq+hMiY6V1XHFihW+JD/00EO28rCBplf3KMiU+0DWyeL+mn3pIkfmPuoLsvDxaeXSQY7rdm630V2Q2c9ueQV9nrXyGt+boDO3g+JF6fKLABEg8dTtPffcU7TPd911V6hMjK605P8ikKNuYcOGDUVd5LRYFRIkGO0RZB566KFBRFEajwgQAeIRKEWiZYkAwcQZkB7CNmBlcVpBJkCwCkNelSL0k49xrCw2kh18/x120UUXFfGR88c5ESBpvS3mfIkAMWMi34HLZri+NL7D4rpnz55y9EjOhZuqq666KhJ5EAL3V0Ybi5UrcOWSdFDJfiZd9qTzIwIkacTVzi+rbU8lm5i18TC130jSTjUEiAAJWCMysyw6iOJ49NFH+5aKDymR3uqIgSqvAYSHLOP555/3mtQUD/7mn3rqKTZ79mzTM783+Aa5Or2g4+GHH+5XTKzxs1Ze8Qcl6hszylXbUyXWCiPhsSFABEg80PLN5Qt2kG/qG/qjnG+wq9vrSdgBHDFw6Ra2b99usslCxrRp09ySm55jkESkF8eg+4mYhNMNSwSIALGERdmbxsE50U5UXAFy66236trzzJkzU8NVJkCw11Dbtm11ugkcxRF7G0UZVq5cqVsth72x0NciAiRKlOOTRQSIM7b9+/e3bU8YlIuDQOAbqbPLLrssUtnvv/++qRwgLtMIKtnPNMqfZJ5EgCSJtvp5ZbXtqWQTszYepv5bSRqqhAARIAFrA5sbig8t47F79+6+pd5000228iDf66D2mjVrTDPjwri/8l0QmwTYfA4rWYxY+d0rxUa8creTKi/cYMnuzoDvE088oRwepFD2ECACJPo6gx0vVapUwQ7efffdkWRgtxcVyFAvwc7fN/zX//HHH15EFOLAFoH8N9p4WpXmGcJAEYkACQRbaomyQoDg4xeD/Mb2DF/9rVu3ZgMGDGCTJk3y3DcNC7iRAJkzZ45JN1lX2Nm1a9eGzbaYvlevXrr80GdHIAKkCJHSJ0SA2FcP/M7vs88+uvdbbksTJkywT6zYE6tv6TCTAIMWTzX7GbQcWUlHBEhWaip+Pant6TFO2yYmNR6mLzVdEQLOCBAB4oyP5VPjrHu5o4jzIJ1F46oNWeaxxx5rqYfVzfvuu0/XiQ3qysRKdph7VpvSoYzTp08PI1bZtEmWFySS/L40b95cWVxIsewgQARI9HUF14jwG4/BhrAuBYV2mA0tt39x7nUDXjsCBXIwuO41wIe3yFs+0r5EXhEMFo8IkGC4pZUqKwTI/fffb9me5baNc6w0O/vss9nIkSMZ9nqLKxgJEOTTpEkTRx2xF0AUAQMq5cuXL+YFcgWTjRCIAIkC4fhlEAFijzH2dzS2a3F9yimn2CdU7AncXx155JG6spQuXZql4WZZNfupWFVFrg4RIJFDmlmB1PZ2V50KNjHJ8bDdJaczQsAZASJAnPGxfOpkXNHZ+umnnyzT2d3EppJWG9GKDqifjzhs/ijS4Rhm43I7fYPcHzx4sE4v6Na+ffsgojKRJsnyDhkyRIdt2bJlGVzjUCAEwiBABEgY9JzTYkZMVOH222/XtX9h/72ursNMbpHGeMT/EvYZcQvYo6px48YmOW3atHFLSs9DIkAESEgAE06eBQIE/QdMnjHaA7dr2IsOHTqw5cuXR46qFQGCTZ
mddKpQoQLbsmVLaF2MLmpRRhGIABFIqH0kAsS6fmbMmOHYhtLaPNxaW+e78+fPN5UF38RJBxXtZ9IYJJ0fESBJI65mftT29PWigk1McjxMX3q6IgTsESACxB4b2ydOqzVatGhhm87uAVyhOH3EzZo1yy6p7j7clciz1CDz3nvv1cVJ4+LLL79k++67r66M2BzSL1GUhu5B8ky6vPPmzdNhi3rHPQqEQBgEiAAJg15yabt162Zq/7ABXjctxaxtbBJq9x+EQc3Ro0c7FqhHjx6m9LD5S5cudUxHD8MjQARIeAyTlJAFAuS5554ztWc7+2B1HzYDe3CsW7cuMmitCBD0eWvUqOGoK1ZFhwnI46ijjtLl8cEHHxRFEgFShELpEyJArKsHq7es2jDugQTF5IashJtvvtlUFuxhmXRQ0X4mjUHS+REBkjTiauZHbU9fL2nbxKTHw/SlpytCwB4BIkDssbF8gk1jnXylYja+n/DRRx8VXAjYdUDbtWvHsITNS/jwww9Nnb80N62EztAdpJBcvpNPPplt2rTJS5EyFyeN8mLGg3EAM+xHf+aAJ4UjR4AIkMghjUWg0b4KWztmzBjP+fXu3Vtno4UMccSAJlY+Wv0XPf7446a0cIszdepUz/lTxOAIEAESHLs0UmaBAMEEFdH2wxyxVx42JY4iWBEgkDt8+HBHXaED+khBw+uvv66Tj/9FORABIqOh7jkRIOa62bBhg+XejKLN33HHHcVE2OML7p+xx+Bdd93F0GfA4NrQoUMZ9thYtWpVMW4aJ+ibwL2o0B1H7DsZp1s+u3KqaD/tdM3LfSJA8lKT4cpBbW83fmnbxDTGw3aXns4IAWcEiABxxsf0VO5Eyx0tcb548WJTGrsb2KDxmGOO0XXYhBwca9euzX788Ue75Kb7xmX6e+yxR2IbVJqU+d8No3sm+KsHiZTXkFZ5Tz/9dN17dM455+QVYipXQggQAZIQ0CGzsRpQxcbmftxs7dy5kzVo0EBnQ+T/InEOdxJi75Lff/+9MAginoljpUqVGO37EbJSfSQnAsQHWApEtWqvaDt9+/ZVQLv/qlC5cmVXWyDau9sRE4bGjh0bumx2BMiOHTtYxYoVHfV99tlnA+d/2mmn6WSDEJEDESAyGuqey99uIM0osMIeX07td/bs2YWV+sOGDWPVq1fXtQNjOnxvNmvWrECGoE0mHbAqy6gT9EkjqGg/08AhyTyJAEkSbXXzora3u27StolpjYftRoDOCAF7BIgAscfG8kmvXr1MnSzR6cJMM69hypQpTP6gEzLEsWbNmmzlypVexRXiGTeyA4GSVsBsIXzQi/LgiM3ct23blpZKseabdnn79++vwxoDoNCJAiEQFAEiQIIil1w6zG7E6gzZzuK8e/fuvpXAXlTGDUSNcnF96KGHslF8XxCrlSetWrViIPYpJIcAESDJYR1FTlkgQDDx5pNPPmHTp08vuL8bNGgQa926tcnFqpV9sLqHTcOnTZsWCj65v4yJNHJwcyNbp04dy9Vrsgyr848//lhnW+Fuy9ivIgLECjn17hEBYq6TevXq6d5vue1iFSdWd+L/Xr7v5Rx77yRNMt1yyy0mPR999FFzoRO4o6L9TKDYqWZBBEiq8CuTObW93VWRlk1MezxsNwJ0RgjYI0AEiD02lk9ATNh1AK+88krLNOImfKnCLYhxo3JZHmbRQI6flR9CvlE3N31EuqiPy5YtY6eccoolThiYh3smzDjOS1ChvNioUH6PcI4BTQqEQFAEiAAJilxy6eTBb7n9Y+AuSPjiiy9YkBlUGJx8+umng2RJaUIiIL8D77zzTkhplDxuBLJAgNhhgFVf2FTzb3/7G6tataqpzyHbIOP5fvvtVyBV7GS73XciQL7//ntWtmxZR30mTZrkloXp+eWXX66TaTWgSgSICTYlbxABoq8WTFQwttGor7H5bVLBuE8PvqVV+wZK034mVQ9p5UMESFrIZyPfktj20rCJKoyHZeONJC3TRoAIEB81gM18nDqIt956a8Hf8XvvvcfefvttBh/s8JsOX6lwUeT0gYZnl156acHHqg+VilHBuBr3JsHASJLh008/ZfhgxGw/J5zwrGHDhuzbb79NUr3I81KpvJipacR87ty5kZeZBJYcBIgAUb+uscrP2O6xCiNMWLFiBbMbpDXmhetatWqxLVu2hMmS0oZAgAiQEOClkNSubankAssLLBhQGDdunO1kFytbgUk6QTdVdiJAoO/1119vsoWyDuiD+wnYHwGz4IWMAw880NJ9KxEgflBNLy4RIHrs33zzzeK7Ld5xqyOIS7StRx55hIFExCS+yZMns4EDBzq6cBayMAs57rBgwQJTWTAgrnJI2n6qjEUUuhEBEgWKJUNGSWh7SdtElcbDSsZbTKUMiwARID4QfPLJJ02dLNHJC3LE0uLOnTsXXAwEWfEhqw4ywahDGL/Hsmync/h6HT16dMEdCmbcGHVwuq5SpQpbtGiRk3jlnqla3o8++siEPeqFAiEQFAEiQIIil0y6f//736Y2X65cObZ69erQCqxZs4Ydd9xxJvlW9hwuuOCCxugaJrQSJMATAkSAeIJJmUh5IUBkQOFr+oILLvBkL6xWUciy7M7dCBAQt1buAGWbBT29hgEDBujKc9ttt1kmJQLEEhblbhIBoq+SO++8U/d+y+0E5yA+sNLLbXLDU089xfbdd19HWSBP4gyYfGjU/4EHHogzy0hlJ2E/I1VYQWFEgChYKRlQKa9tLwmbqOp4WAZeO1JRAQSIAPFRCU4fOsbOl5drrIKA66I///zThxbWUTHb35jn+PHjrSNHeBf7nhjz9XMNEgTuC7ISVC2v1eokfLxQIASCIkAESFDkkkl31VVXmWxvlB/9mzZtYuKj0otNb9OmDdu6dWsyhadciggQAVKEIhMneSRABPCwP24kBIiMIBN+3AgQ6NCuXTuTTZRt18UXXyxUdTzCRWulSpWKskqXLm27Ytnpu+CEE05wzIceJocAESB6rM8555zi+y23EZyXKVOm4M1An8L+avHixQXCxChHXB9++OHs119/tRcQ8onVBu1RTAQJqZbv5HHaT9/KZCyB6KsedNBBGdOc1FUBgby1vSRsoqrjYSq8T6SD+ggQAeKxjrBkDjNiRIfOeGzWrBnr2rUr69KlC8Nm5M2bN2fYeNHJ7ZWQUb9+fRbWXRFm+wt54jhz5kyPpQseDSwzPjrbtm1bWAWCspQvX96ki9DJ6nj++ecHVyDhlKqWFwOPRmwxQEqBEAiKABEgQZGLPx0GHPbaay9dmz/ppJPYH3/8EWnmIFGNdsXpGv95X331VaQ6kDBnBIgAccZHtad5JkCA9ZQpUxj2enOyE6+88orvavFCgMybN88xX5Azq1atcs0bq6dl/S+77DLbNESA2EKj1AMiQPTVgQl48jsun1977bX6yB6u4BpLlmE8j2uPsA8//NCUb4MGDTxorGaUuOynmqWNTisiQKLDsqRKykvbS8omqjoeVlLfXyq3PwSIAPGI15w5c0ydLLmDt3TpUktJWN3x2WefF
fylOpEhGMwKM3v3oYceMukHI5hGQJkxQNenTx/TviQyZvJ5WrpGgY8K5YUOxgHRc889N4rikYwSigARIGpWPNp606ZNdfYeLiiWLFkSmcJY2tyxY0ddHrK9djo/5JBDMufaMDLgUhBEBEgKoIfIMu8ECKBBnxd2wM5OBJmc4YUAQd7if8su7x49eiCaY6hXr55O948//tg2PhEgttAo9YAIEH11WM0QFm0m6OS5U089VdduhDwca9SoEfkEDZQIeyfJ+eB80KBB+sJm7CoO+5kxCHyrSwSIb8gogQUCeWh7adpEFcbDLKqVbhECJgSIADFBYn2jf//+pk6W6HRhqTwavVvAzDPM0hXprI5BfaUOHjzYJHflypVuKsX+HL7ksTG62/4gnTp1il2XJDJIs7wVK1bUvQMtWrRIosiUR04REANJWPlGQR0EnnnmGV07x6zmCRMmRKYg3OkZBwAxePHiiy+yatWq6fK2+g/DPbghcBo0jExZEsSIAMnWS1ASCBDUyMiRI21tRdWqVX1XmlcC5K233rLNF7YJE5F++OEH2/xnzZqlS4/V3E6BCBAndNR5RgSIvi4OOOAA3Xsu/5fD/WWQ8Oqrr9rKhHysEok6HH300aY8ly9fHnU2icuL2n4mXoCEMyQCJGHAc5xd1tueKjYxzfGwHL+eVLSIECACxCOQjRo1MnWyRIfx0ksv9SiFsXXr1rHKlSvbysLH2VcBXIjccccdJpkq7a0xYsQIk34CPxzhMiFq9y2eKyWGiGmUt2bNmjqMmzRpEkPJSGRJQYAIEPVqevPmzcxIdA4bNiwyRT/99FNm9Otat25dtn79+kIe2BDVzc++sOsgQbAxMYV4ESACJF58o5ZeUgiQXbt2mYhUYRswIcZvf88rAYJ8jzvuOF1fSOQrjgMHDrStVrhkFfFwnDhxom1cPCACxBEeZR4SAaKvCru9evbee299RB9XcBUNclNuP/J5lH0VqPXRRx+Z8oIbzjyEqO1nHjBxKgMRIE7o0DM/CGS57aloE9MYD/NT3xS3ZCJABIiHegeR4LSCwa9vU7kjLncOxfl5553nQSt9lN69e5s6gr/99ps+UspXQ4cONekoyowjDHeeQtLlFR1AgSltwJmntyn5shABkjzmbjl26NBBZ0N79uzplsTz8wULFhRWbgj7gSMGM6xmgz7xxBOFjVLluFbntWvXZtu2bfOsA0X0jwARIP4xSzNFSSFAgPGMGTN09kq2EX4n6HglQJAv+uRyXsZzkMg///wzouoCVk3LA8MgUtxWdxMBooNQ2Qv5u2v48OHK6pmUYvvss49lGznwwANDqfD3v//dUi7a4I033hhKtjHxbbfdZsrrzjvvNEbL7HWU9jOzIHhUXHz/0iboHgGjaI4IZLXtqWoTkx4Pc6xcekgIcASIAPHwGsD1h/EDSr72u2IDH1RWS9RkmX7JgOuuu06nIzq3KgaslpHLKZ+PGjVKRZVD6ZRkedu0aaPDFgPYFAiBoAgQARIUuXjSGTckB1Hudxa1nWZwz2jcuBikv5MvcPxHHXnkkTqbI9tzcY4OOYX4ECACJD5s45BckggQ4Ge32bJfNzV+CJCdO3c67kEC2wQS1xh69eqls2eYuegWiABxQ0iN50SA6Ovh8MMP173r4v8a/YAwwehCTsjFMcjkPiddrL6jFy1a5JQkc8+isp+ZK7hPhYkA8QkYRXdFIIttT2WbmOR4mGvlUoQSjwARIB5egSuuuMKyo4gO3bHHHutBgjkKluDLHUPjOfy8+wnY2FGWUapUKT/JE4uLzeLtVtOE2QQ+sQL4zCjJ8rZq1Ur3DtSqVcunthSdENiNABEgu7FI+2z8+PE6u3naaacxbFQeRcBgoVVH/4YbbnAVj1nc2GtI/u8xnmODdief+66ZUARHBIgAcYRHuYcljQDp0qWLpX3wO8nHDwGCSsdGyEZbJF9jXyO4uhBh69atrHz58sU0yM9qlYiIL45EgAgk1D4SAaKvH+M+X6JthP12/OKLL4ptSMgUxyjdU2GPMSFXHNGm8xaisp95w8VYHiJAjIjQdVgEstb2VLeJSY6Hha17Sp9/BIgAcaljrNZw2rMj6JLe1157zdR5E504HPv06eOimf7xLbfcYpK3fft2fSRFrow+lkW5b7/9dkU0jFaNpMp78skn696BE088MdqCkLQShQARIGpU98KFC1m5cuWKbbtx48aRupWyc1nx+eefewIAfr/tPhSEbe/bt68nWRTJPwJEgPjHLM0UJY0AMa5cEzYBA6V+gl8CBK77QL6K/KyO6IeL8OCDD+ri9u/fXzxyPBIB4giPMg+JANFXhdPEhV9//VUf2ccVvjut2hruoT1GFfr162fK59Zbb41KvDJyorKfyhQoJkWIAIkJ2BIsNmttLws2ManxsBL82lLRPSJABIgLUIsXLzZ1suTO3ZQpU1wkWD9+//33HeW2bt3aOqHN3bvuusskT2xca5Mktdtjx4416QpMsZF7HkNS5TVu/HnqqafmEU4qU0IIEAGSENAO2WzYsIHJA6bY1wcboUcV4EJLli/+2xo1auQrC+w31bJlS0u7Dpn77befL3kU2TsCRIB4x0qFmFbtDW0kryThuHHjLO3Cli1bfFWHXwIEwo0urYR9E8cmTZoUdIAdPOqoo4p6woUsbK+XQASIF5TSj0MEiL4OOnXqVHzfRXsQR6/vvl7i7ivsIyJkyccKFSrsjhTyDKs9ZNk4nz9/fkip6iWPyn6qV7JoNSICJFo8SRpjWWt7WbCJSY2H0ftLCLghQASIC0JOG/eUKVPG0xJ5qyw++OADU+dN7sz5HbweMmSISd5nn31mlXXq9+zKft9996WuWxwKJFXeQw89VPcOnHXWWXEUh2SWEASIAEm3ov/zn/8w2U0FCM6NGzdGqhQGDOT/HXE+YMAA3/lgs3O43RMyjMewgyq+FSohCYgAyVZFlzQCZO7cuSabsPfee/uutCAEyOrVq9lee+1lyl+2Te+99x57/fXXdXG6du3qWT8iQDxDlWpEIkD08D/88MO6d15uE3PmzNFH9nlVt25dS9nVq1f3Kck6OlbFyvriHHYVHhvyFqKyn3nDxVgeIkCMiNB1WASy1PayYhOTGg8LW/eUPv8IEAHiUsdnnHGGqaMlOl5+V2nIWU2aNMlWLuRjsyA/Yfjw4SZ5qs6GwcCewFA+vvrqq36KnJm4SZUXhJyMJz7MKRACQREgAiQocuHT/fTTT0x80KFNY3byN998E16wQYKdK8Ynn3zSENPb5UsvvaSzQbI9mjdvnjchFMsXAkSA+IIr9cgljQBBv062AzjHRp1+QxACBHl06NDBlL+sDzZmxp5K8j34qvYaiADxilS68YgA0eNvNxCFduB3D0q9ZMbgplNuT+Ic96MIcJcsZIpjUHfUUegTp4yo7GecOqogW/SXDzroIBXUIR1ygECW2l5WbGJS42E5eP2oCDEjQASIA8AYhMJMNdHBMh4xgyZowIbfRnnyNfb08BMmTpxokvf222/7EZFY3B9//NGkK8q+cuXKxHRIMqMkyvvL
L7+YMM3rB0GSdVeS8yICJJ3aR1uWifcqVaowzGSOIwwbNsxkN2CL33jjjUDZwZUMBjfl/zJxPnr06EAyKZEzAkSAOOOj2tOSRoDcc889JntwzTXX+K6WoATIggULTPkLm2R1bNWqlS/diADxBVdqkYkA0UMPt5Vly5a1bBt+vz/1khmrVq2apdy2bdsaowa6PuaYY0zysZIrjyEq+5lHbOQyEQEio0HnUSCQpbaXFZuYxHhYFHVPMvKPABEgDnXstkpj2bJlDqmdHzltQIePshdeeMFZgOHpJ598YuoQqrqiwmr/E/iMzePyZVRTEuX97rvvTPUfhqAzvF50WQIRIAIk+UrHhuIYJBADc5UqVWJhXRlioMNuRrPV3lFB/n9kpK6//vqi/qIcOI4YMUKORucRIUAESERAJiSmpBEgHTt2NNmDV155xTfamNkr7InfFdLNmzcvphUy7I5Tp071pRsRIL7gSi0yESBm6Js1a2bZLvzuAWaUjD10rNoXXEqHDYsWLTLJhvvfXbt2hRWtZPqo7KeShYtQKSJAIgSTRBUQyErby5JNTGI8jF5fQsALAkSAOKDUs2dPU0dLdOowKzdowGB16dKlbWWXK1eOgSX1E7BaRegmjkHdmPjJN0jckSNHmnS94YYbgojKRJokyrtixQoTpm+99VYm8CEl1USACJBk6wUf8J07dy62Y5DCixcvDq1Ev379CjLhJtEYrFwn4v/jjjvuMEb1fG21HxVkws8+hegRIAIkekzjlFjSCJATTjihaNNgBzA4unnzZt8QYwNl0bf1695z8uTJxbRChtXx+OOP9z0RhwgQ31WZSgIiQMywP/vss5btYo899mDff/+9OYGHO2jbVm0L97AaK2xA38Qo/7rrrgsrVtn0UdlPZQsYkWJEgEQEJIkpIpCVtpclm5jEeFixAumEEHBAgAgQB3COPfZYU0dLdLwuv/xyh5TOj3r06GErF/K7dOniLMDmacWKFXVyQeCoGHr16qXTE2WOYqBPxbJCpyTKO2vWLBOmS5YsURUS0isDCBABkmwlySsnypcvz6LYwwluIfbcc0+GAQ0rF4PyoJD4b8Oxffv2gQtvt7kq7QESGFLHhESAOMKj3MOSRIBYbSIa1DWnPKv8nHPO8VWvWF1cp04dUx9Jtnk4D7L3wQUXXGArFwMoFNRAQP6vs5oMoIaWyWqxY8cOtv/++1u+v88//3wgZay8EaBtwYUdXGSGDVbf5dOnTw8rVsn0UdpPJQsYoVJEgEQIJoliWWp7WbKJSYyH0etLCHhBgAgQG5Tgc934cSRfY+ZMkIClaqVKlbKVjdUfq1atCiLatPHcKaecEkhOnImw+sXY4T733HMDZQk3YXCFcOqpp7LevXsXXE0FEhRjoijL66TmfffdZ3qn/K4icpJPz0oeAkSAJFfnffv2LbZf+OWePXt26MzhOuuwww4ryD3zzDMt5VmtHMP/HDrUQQNmY8r/lTjH4CXZo6CIOqcjAsQZH9WeJkWAYC8hkI5jxoxhc+bMYXCvl3RAH1S2Bejfbty40bca27dv18kJ4qLnueee08mQ9cL5IYccwoCZ33D22WfbyoVfbgpqIEAEiHU9YPW9sS3g+sQTT7RO4HLXbgXogAEDXFK6P8ZEOaOucI0Xh23Lk/10Rzb7MYgAyX4dqlSCqPoucduRJG1i2PpJajwsrJ6UvmQgQASITT0/8sgjpo6W3PH6/PPPbVLa316zZg2D6yxZjvH88ccftxfg8sTYkcXHZhQzblyy9fUYm1/KZcZM56+//tqXDET+61//qpMDmXArhn1bVApRldetTO3atdPhUbt2bbck9JwQcESACBBHeCJ7+Pe//73Ydvfee2/m1we9UREMBsDHPvxiC1vrtB9U48aNi/FEfBxfe+01o2hP18YPB8jyO2PbU0YUqYAAESDZehGME0BEmwMJGlX4+OOPmXFTzJo1azIQnm5h586dbMaMGQwuNDds2OAW3fY5XN6JsonjAw88YBvf6YGRqIVt8xswECEIYaGPfBw4cKBfkYX4mIAjy5HPDzjggEAyKVH0CBABYo3pF198UZigIL+34hwry/0ErLQy2h3IwqSOoC615Pz79+9vamtXXnmlHCWS87zZz0hAUVwIESCKV1AC6qnWdwljR7zClZRN9KqPU7ykxsOcdKBnhIBAgAgQgYThePLJJ5s6WqJTuNdee/n2E4yPyaOOOspWJmSjIxdmI/Bx48aZ5GM5ctjw7rvvsssuu4xVr16dtWrVit1zzz3shx9+8C32ww8/LLhjETjiGGQlDf5UZBnyOUifbdu2+dZNTqBaeWXd7M6rVq2qw+Taa6+1i0r3CQFPCBAB4gmmUJEee+wxXbuFLWvRokWgHzY0hcsV7B0i20TMbMZG6HZh1KhRuvgiLey93xnRCxcutFzh+MYbb9hlT/dDIkAESEgAE0yOQUDRvoxHrGKNImClld0qE5Ag2C/OLgwbNqwwWCnrVqtWLTZ+/Hi7JJb3QZwYB0LDuI0dO3asCTes0vYbrFbKoqxlypQJPEBr7HvJ2MH1YByz0/2Wm+IzRgSI/VsgT8KQ31/0JzCo6DXg20lOL87vvPNOryIc48F+CZniOHHiRMc0fh/m0X76xSCL8YkAyWKtRaezan2XMHbEDypx28Qsjof5wY/illwEiACxqPtly5aZOlmisyWOXlctYLAeA1Mind0RfpHDkB8oxrp160z5wE1UmADXA1YzFrFyA0uavbo1gUsX4+y7oUOHBlINq2TscMT9ESNGBJKLRCqW160w69evN+GBAQMKhEAYBIgACYOee1rYZgyQOdmyKJ7ddtttjsqA5GjQoIGlHoMGDXJMKz/EQJ+VHBA6FOJDgAiQ+LCNWvK//vUvy3aGdo5VnFGE0aNH2+aBfCZMmGCZza5du1i1atVs02Lyi5eVz+jzGkmB008/3TeZKiuJFWRGW4iJOH7Dli1bGPquRllYURwk2A34yvIx+YlC+ggQAWJfB/jvhssr+b0V5173koSM1q1bm2RgRWgUJCD2NBQ6iWOFChVC2RUrRPJoP63Kmbd7RIDkrUa9l0fFvktQO+K91IzFbROzOB7mBz+KW7IRIALEUP8wpE5L2kXH66STTmLwsW4V8JGFzjZmvLkNcGG27lNPPWUlJtC9GjVq6DqJQTecFJlbrSoRGOAIUgOD7XbkDe7fe++9DKtmRDq4ecGgTdAApl/Isjq2bNkyqGimYnndCoMPbCMOXgk6N9n0vOQiQARIfHWPFRGyTTS236iu8f/jZU8pzKY2rhwROnTv3t11Fig2U+3WrZvJDsFVjZf840M6/5KJAMlGHaNvCRJBtCvjEas2vE4ocSqxvJ+QMQ9c33HHHZbJ4aLVKr58D303pP/5558tZaD/BHc3cpqrr76a/frrr5bx3W5i4BREhyxPnIPIgJstv6FPnz46ebCRy5cv9yumYBOtBnyFfuKIbwXYRwrpIkAEiDP+aAOy20zx/uKINgP7ZRfQvi+88EJdu0I6fB9+9dVXdsl83ccqElknnHfs2NGXDC+R82Q
/vZQ3L3GIAMlLTfovh4p9l6B2xE/p47aJWRwP84MfxS3ZCBAB8r/6h4sQ+Dtt2rSpqZNl7HTJ1/Xq1Sv4Nz///PMLxAmWo7mRHkiPwS+4vMKmQFGGq666Sqc/BjHDhGnTpunkyWWXz4877jiG2cIYjF+6dCnDTEf4ezbOCAZe2LQpTIArLTlv43mYDXxVLK8bVkYfkJjFSYEQCIsAESBhEbROP336dIaBRKPdiuMaA65eA1bpVaxY0VIv2O333nvPRISA4EZ5jj76aFO6gw8+uDBDyWv+FC8YAkSABMMtyVRffvmlJUFobPNY6RC2f4SVtUa58vX9999vWXQMcGJDYTmu3TncuQ4ePLgw0Qd7FuEdbNiwoS5tqVKl2KOPPmqZl9NN6IEJHPj4Nsq00ueKK65gU6ZMKQy0Og3SijwxWALdhCxsYu41YE897EeCSUt2bsaEXPlYuXJlhn0F4SIwrItWr7pSPD0CRIDo8bC6Wrlype0qsDPOOKPQB5DT4f8fbpbRhuT3Hef4JvwqIvIDeUKeMY+g+5TJZTCeZ91+GstTUq6JACkpNW0upyp9F1mzoHZEluF2HrdNzOJ4mBtm9JwQEAgQAcKRwCwS46w1Y0crqmus0IC/VbiriiO8+OKLuk7ivvvu6zhzx00HzFwzrioJggX8QWPlRtCZgLKe+LM799xzdeWUdUKZgwYVy+tWFuOs0h49ergloeeEgCsCRIC4QuQ7AjYcxT5Fsr2K89zvAME333zjOAkAA4f169dnF198MYNrCyt3Mvvssw+76aab2ObNm33jQwn8I0AEiH/M4k4BN1CdOnUq9FPq1q3raVKMbAcwuI6VyOjndO3a1Ze6GGS3m4Sz5557OhIs8rsk6+PnHJN7QEpgMNVvePDBB3XkhJ98ERf9eC/udjp37ly0wTNnzvSkJgbYgJ9fnaziX3LJJZ7ypEjRIUAEiDcs8W3atm1b2/cck/wuuOCCwsQ/O8K0ffv2bNOmTd4y9BALk+qM7Qj9KLhoiTpk2X5GjUWW5BEBkqXail7XtPsuxhKFsSNGWVbXSdjELI6HWWFF9wgBKwSIAOGoNGrUyNS5Mna2/FxjAB4+kLGBXPPmzRncUMFNVJAPQqtKc7oH1wQHHHCArjxz5851SuL6DLN42rRpY/tRbYcN/LOiI40PDzsXWa6Z20TAxnwYZKtdu3bhoxQzkLGJKHSBKxfM1AsaVCyvXVlQTmN9Y4UMBUIgLAJEgIRF0JweqyXs7GXU9+HOwmnzc7N2/70Dm4KVfJjV6WfAD4OuGFiMctannY50fzcC8offO++8s/sBnaWGwL///e/I2jk26PYb7rrrLlP+aJ9wR+oWRo4caTnb2s0+gbS57rrrCisk3PKwe/7000+b9HbLV36OFWxe+poYnEA6kLleAzaDl/MKc46VPhSSRYAIEH94v/322wzu2/y855gUgZWiUQcre3bRRRdFnU1RnlV+WbCfxQKUwBMiQEpgpRuKnGbfxaBK4TKMHbGSJ9+zkh2HTczSeJiMD50TAm4I7IEIvINDIUcI3HDDDdoTTzxRLBGu+cbhxeugJ9xPvDZq1CiNuwDQ+MbbhR+f5avxWcBapUqVCr9DDjlE451mjW96qXH3BRqfDRg0O9/pOCOu8Q/aQp58FqDGO6y+ZcgJVC8vdOVLFDU+UFlUm8821fiS9OI1nRACQRHgM5C1efPmafvtt5/G3XYEFUPpMowAdxejjR8/XoMtXLt2rcZd02h8hqjGV4JofPZn4ccHBjVOlmmc7Nf4Sr8MlzabqvNZ89ott9xSUJ4TIBrfdD6bBSGtI0Vg4sSJGl8Bpm3cuFHjs7Y1vtJZO+200zzlgc8CvjJCmzx5cqG9o82jz7d161Zt//33L/b10N/jE4g0vgpVgx2gQAioigB3k6bxFVUF9YYPH67xldKqqqqUXnwATHvzzTc17ta4YAM2bNhQsAN8opnG9/jQqlSpUvjv56tCtCOPPDIW3fFNM2bMGJ1s5Id+R1yB7GdcyMYjl5Nv2vz58wt9UoxLUCiZCKjWdwljR5xqMGmbmIXxMCe86BkhYESACBAjIjm45j6kC+SDKAqfCVz4iE2SjBB5J3nkM7y08847T8NHOd9bJcmsU8urS5cuGnd7Vsz/oYce0viGhcVrOiEEgiJABEhQ5CgdIZAcAkSAJIc15UQIEALZRIAIkGzWG2lNCHhBgAgQLyhRHEKAECAECAEgQARITt8DzMrjfqiLpeMbvGt8E7vidR5PRowYUZjVxV2PaVgNkvfA/d9qfHNNjftpLBS1dOnSBaILq3EoEAJhESACJCyClJ4QiB8BIkDix5hyIAQIgWwjQARItuuPtCcEnBAgAsQJHXpGCBAChAAhICNABIiMRo7On3rqKe36668vlgjuD15++eXidR5PUMZXXnlFu/LKK7Xnn38+j0XUlemFF17Q+AapxXt848GCu5riDTohBEIgQARICPAoKSGQEAJEgCQENGVDCBACmUWACJDMVh0pTgi4IkAEiCtEFIEQIAQIAULgfwgQAZLTVwGrA+D3Gf5aEeD+Cnt38M3Cc1livvl7wfUVVkM8++yz2tVXX53LcopCwc8lVrosW7ascAs++ZcsWaLVqVNHRKEjIRAKASJAQsFHiQmBRBAgAiQRmCkTQoAQyDACRIBkuPJIdULABQEiQFwAoseEACFACBACRQSIAClCkb8T7A2BPSJE6N69u25zdHE/D0ds9tmhQ4fC5rzYuPfwww/PQ7Fsy/D6669rF198cfF5z549tccee6x4TSeEQFgEiAAJiyClJwTiR4AIkPgxphwIAUIg2wgQAZLt+iPtCQEnBIgAcUKHnhEChAAhQAjICBABIqORw/PTTz9dmzt3bqFkZcqU0UAOYFP0PAWshmjatKk2f/587YorrtDgGirPAeVt2LBhYcUHynnQQQdpK1euLBzzXG4qW7IIEAGSLN6UGyEQBAEiQIKgRmkIAUKgJCFABEhJqm0qa0lDgAiQklbjVF5CgBAgBIIjQARIcOwykRJukU466SRt165dBX2vvfZaDfuD5Clgb5POnTtre+yxh/bpp5/m3g2UcfUHVn5gBQgFQiBKBIgAiRJNkkUIxIMAESDx4EpSCQFCID8IEAGSn7qkkhACRgSIADEiQteEACFACBACdggQAWKHTI7u9+rVS3v88ccLJQJJMGfOHA2Dm3kIP/30k1a3bl1t7dq1Wtu2bbUJEybkoVi2Zdi2bVuB4Fm/fn0hDvb8AMmFPUAoEAJRIkAESJRokixCIB4EiACJB1eSSggQAvlBgAiQ/NQllYQQMCJABIgREbomBAgBQoAQsEOACBA7ZHJ0f+vWrVrt2rW1jRs3FkqF88WLF2t77713pkv5559/ahdccIE2efLkwibvcIHVqFGjTJfJTflu3boVNnlHvD333FObNWuW1rx5c7dk9JwQ8I0AESC+IaMEhEDiCBABkjjklCEhQAhkDAEiQDJWYaQuIeADASJAfIBFUQkBQoAQKOEIEAFSQl6AefPmaWeccYb222+/FUp89913aw
MHDsx06fv27as98MADhTIMGDBA+9vf/pbp8rgpP3PmTK1Vq1bFaHfddZd2zz33FK/phBCIEgEiQKJEk2QRAvEgQARIPLiSVEKAEMgPAkSA5KcuqSSEgBEBIkCMiNA1IUAIEAKEgB0CRIDYIZPD+6NGjdKuvvrqQsmwegDuos4777xMlnTatGna2WefXdD9/PPP1958883CKpBMFsaD0nB51aRJE+3bb78txG7ZsqX29ttvF1aBeEhOUQgB3wgQAeIbMkpACCSOABEgiUNOGRIChEDGECACJGMVRuoSAj4QIALEB1gUlRAgBAiBEo4AESAl7AW46aabtIcffrhQ6vLly2tz587VGjRokDkUxGqIdu3aaWPHjtXKli2buTJ4VXjHjh3a6aefri1cuLCQ5Pjjj9ewomf//ff3KoLiEQK+ESACxDdklIAQSBwBIkASh5wyJAQIgYwhQARIxiqM1CUEfCBABIgPsCgqIUAIEAIlHAEiQErYC7Br1y4NKyamTp1aKHnVqlW1Dz74QKtSpUrmkMAG6BUqVMic3n4Uxj4nF154oTZp0qRCsiOOOKJAfuBIgRCIEwEiQOJEl2QTAtEgQARINDiSFEKAEMgvAkSA5LduqWSEABEg9A4QAoQAIUAIeEWACBCvSOUo3i+//KJh5YQgQerXr6+9++672oEHHpijUuajKD179tSGDx9eKEy1atUK9VS9evV8FI5KoTQCRIAoXT2kHCFQQIAIEHoRCAFCgBBwRoAIEGd86CkhkGUEiADJcu2R7oQAIUAIJIsAESDJ4q1MbtgM/a9//as2evTogk5wjYWBFArqIPDyyy9rnTt3LijUsGHDwiqQww8/XB0FSZNcI0AESK6rlwqXEwSIAMlJRVIxCAFCIDYEiACJDVoSTAikjgARIKlXASlACBAChEBmECACJDNVFY+iTz75pHbvvfcWNkcfOHBgPJmQ1EAIvPPOO1qnTp20Sy+9VBsyZIhWrly5QHIoESEQBAEiQIKgRmkIgWQRIAIkWbwpN0KAEMgeAkSAZK/OSGNCwCsCRIB4RYriEQKEACFACBABQu8AIUAIEAKEgAkBIkBMkNANQkA5BIgAUa5KSCFCgBBQDAEiQBSrEFKHEIgQASJAIgSTRBEChAAhkHMEiADJeQVT8QgBQoAQCIIAESBBUKM0hECyCBABkizelBshQAhkDwEiQLJXZ6QxIeAVASJAvCJF8QgBQoAQIASIAKF3gBAgBAgBQsCEABEgJkjoBiGgHAJEgChXJaQQIUAIKIYAESCKVQipQwhEiAARIBGCSaIIAUKAEMg5AkSA5LyCqXiEACFACARBgAiQIKhRGkIgWQSIAEkWb8qNECAEsocAESDZqzPSmBDwigARIF6RoniEACFACBACRIDQO0AIEAKEACFgQoAIEBMkdIMQUA4BIkCUqxJSiBAgBBRDgAgQxSqE1CEEIkSACJAIwSRRhAAhQAjkHAEiQHJewVQ8QoAQIASCIEAESBDUKA0hkCwCRIAkizflRggQAtlDgAiQ7NUZaUwIeEWACBCvSFE8QoAQIAQIASJA6B0gBAgBQoAQMCFABIgJErpBCCiHABEgylUJKUQIEAKKIUAEiGIVQuoQAhEiQARIhGCSKEKAECAEco4AESA5r2AqHiFACBACQRC4/PLLtRUrVmhly5bVZs+eHUQEpSEECIGYEXjyySe1UaNGFXL5xz/+oZ122mkx50jiCQFCgBDIFgIzZszQ7rzzzoLS3bt317p27ZqtApC2hAAhYIvAZZddpq1atUorX768NmvWLNt49IAQIAQIAUKAECAChN4BQoAQIAQIARMCtALEBAndIASUQ4BWgChXJaQQIUAIKIYArQBRrEJIHUIgQgRoBUiEYJIoQoAQIARyjgARIDmvYCoeIUAIEAJBECACJAhqlIYQSBYBIkCSxZtyIwQIgewhQARI9uqMNCYEvCJABIhXpCgeIUAIEAKEABEg9A4QAoQAIUAImBAgAsQECd0gBJRDgAgQ5aqEFCIECAHFECACRLEKIXUIgQgRIAIkQjBJFCFACBACOUeACJCcVzAVjxAgBAiBIAgQARIENUpDCCSLABEgyeJNuREChED2ECACJHt1RhoTAl4RIALEK1IUjxAgBAgBQoAIEHoHCAFCgBAgBEwIEAFigoRuEALKIUAEiHJVQgoRAoSAYggQAaJYhZA6hECECBABEiGYJIoQIAQIgZwjQARIzivYrXi//vqrNmnSJK1JkybaEUcc4RadnhMCsSGwdetWDR+pZ555pnbooYfGlg8J9oYAESDecKJYhECaCBABkib6lDchQAhkAQEiQLJQS6QjIRAMASJAguFGqQgBQoAQKIkIEAFSEmv9f2VeuXKl1r59e+2TTz7RLrnkEu3VV18twWhQ0dNG4PHHH9d69eqlHXjggdrzzz+vtW3bNm2VSnT+RICU6OqnwmcEASJAMlJRpCYhQAikhgARIKlBTxkTArEjQARI7BBTBoQAIUAI5AYBIkByU5X+CjJv3rzCAPN//vMfbY899tBeeuklrWPHjv6EUGxCIEIEVqxYof3lL3/RtmzZou25557aww8/rN14440R5kCi/CBABIgftCguIZAOAkSApIM75UoIEALZQYAIkOzUFWlKCPhFgAgQv4hRfEKAECAESi4CRICUwLqfM2eOdu6552rbt28vlP6hhx7S+vTpUwKRoCKrhsDChQu10047Tfv5558Lqg0ZMkS77bbbVFOzROhDBEiJqGYqZMYRIAIk4xVI6hMChEDsCBABEjvElAEhkBoCRICkBj1lTAgQAoRA5hAgAiRzVRZO4UWLFmmnn356kfzo0aOHNnz48HBCKTUhECECY8eO1S6//PKixNGjR+uuiw/oJFYEiACJFV4STghEggARIJHASEIIAUIgxwgQAZJO5f7444/a77//rh188MHpKEC5lggEiAApEdWceiHJnqVeBaQAIRAJAkSARAJjNoR8/fXXGjoJGzZsKCh8zjnnaBMnTtT22muvSAqwfv16DXl8//332o4dO7RDDjlEO/LIIwubq5ctWzaSPFQX8ttvv2mTJ0/W6tSpo9WqVUt1dX3pl2T9du/eXXvyyScL+u2zzz7ae++9pzVp0sSXvhQ5HAJEgITDL4rUa9eu1dDu8MMgwlFHHaUde+yx2kEHHRSFeM8yVNHDs8IlKCIRINmu7Dz1GfJUlmy/VaS9EQEiQIyIxH/9008/aa1atSr0X9CHCBP+/PNPbd26dcUf+kP4xjzssMO0GjVqaHvvvXcY8YmmzVNZEgXOITMiQBzAoUeRIBClPYtEIYWE7Ny5U4Mb802bNmmbN28uuNbH9+rRRx9N5LdC9USq7EaACJDdWOT67Ndff9X+3//7f9rHH39cKCcG0eBuqHz58qHK/dVXX2nYvHrq1KnaZ599ZisLbo2uvvrqwmbr++67r228rD5AZ/y5557T7r33Xg0d/b59+2pDhw7NanGKeqdVv3hfMQD/0UcfFXTBR86HH36oHX744UXd6CReBIgAiRdfO+n//ve/tZdff
lmbMGGC9s0331hGO/DAAwv75dx0001ay5YtLeOEvamKHmHLkff0RIBks4bz1GeIqiy7du3SsOLzl19+8VWp6Be0bdvWV5ooIi9ZskSDnfQbMDEIbmgpJIcAESD+sX733Xe1zz//3HfCP/74Q1u+fLk2bdo0bc2aNdoFF1ygvfXWW77lIMHcuXO1cePGaa+//rr23XffWcrYb7/9tDZt2mjt2rXTLr744sgm9VlmFuJmnsoSAoZYkhIBEgusuRKqgj3zAyjs6MiRI/0kCRS3evXqBfvpNzFclWP8DxOpMT6DCTBWAfa5Xr16Ws+ePbX27dtrpUqVsopG9wiBZBFgFDwhwGejsxNPPDH076STTmJnnnkmu+SSS9h1113H7rrrLsbJA8aZZU96BI10ww03MP5mFX5803PGO2JBRRXS8c3TGSc0GDdkRblCvtORL4NmvCMcKm+VEvMPf8b/oBhnunU4cAJEJTV966JC/XLyhfGB3iKujRo1YnyWge+yUIJgCHDCtIA977wEE0CpfCHABxoYH8Qrvu9OdlR+1rBhQ8YHB3zl5RRZFT2cdKRnuxH4xz/+UXxn3nnnnd0P6ExJBPLUZ4i6LK+++mrxXZZtnNs5XyXK+MrmxOubT+wJpC9ffZ24riU9w3/+85/FuuJuf0s6HK7lR1+br7AoYubWBp2ez5w50zU/Y4QvvviCnXXWWb7zr1u3buGb2igvzes8lSVNHJ3y/stf/lJ4V/jqaKdo9KyEIpC2PQsCO77rnOxqVM84Qe1LPU7MsGeffZZVqVLFt37VqlVjDz/8MIMMCoRAmghoaWaepbwfeOABxl1F+W7sXg1U6dKlWdOmTRnf/4Dx5bGRQjN9+nSd3pyFDSWfz3hjfAabTqbXcop4/fr1C6VD2onx4c9XfDC+vM8ShywTICrVL3cnxkDYifdmwIABaVd9icmfCJDkqprPemb4DxDveZDj//3f/4XuVKqiR3LIZz8nIkCyUYd56jPEVZZPPvmE3XLLLaxLly6scePGvibY3H777Ym+CHw1tS97Xa5cOYYJUJ06dWLPP/98orpSZowRAeLvLcCkvyD9EGMa7g7YX8Y8Np9VzLhLK13+IDkx6W7EiBGMu8Rl+DYYOHAgq1y5si6eyJ+70Q3dH/KtuEWCPJXFonjK3CICRJmqUFKRNO1ZUEBat25taduEjYvqeO2113pWcfbs2QwkszHv2rVrM/TBMIYJ+4wjxvnQ5zHGxTVfsce2bt3qOV+KSAhEjQARID4QBWPJ97hgb7zxBuvQoYNlo7Zq6NzlEwPridUPVs+N9/hSMQbSIoqAD9Xjjz++mC9WKmzfvj2waL6kmaEjatQ5yPUzzzwTWI+0EuIdeOGFF9gxxxzjiEFWCRAV6/eaa64pYl2mTBm2evXqtKq/ROVLBEgy1T1o0KDi+x3EjsppuFsVBpsfJKiiRxDdS3IaIkDUrv089RmSLgtWgWIgU7ZxdudYLRr3Smr5Tbv88ss96QViG6QO940tJ6fzhBEgAsQ74NwVHeNukj2933btUdwHYeEngNQQacUR31vcZbOlGO4ul+F7S8SVj2effXaiZlGDDQAAPgRJREFUNsGoYJ7KYiybatdEgKhWI+rok6Y9C4rCl19+qZv8Kdu1qM+5e0hPalqND1WoUIHhG4S7v7KUAewfeeQRhnFQo96Y5AL7TYEQSAMBIkBCoA7W1Nig5Wu+8TfDkn55QAofaJ9++mlhFovT8jGsNuF+4ENo99+kmH0i6/TSSy8Flsn3Y2AokywP53ADABdbWNYGQ3f99dcXVrPIM/eNaXANIgWEUhYCjDgY7Zo1a5rKb1W2LBIgqtYvXPLI79L555+fhVcm8zoSARJ/FT7xxBOe7ImVjbG7F2QmtCp6xI94/nIgAkTNOs1TnyHtsuA/387eyffRB00iwN2WcYa6rId83qNHjyRUojxcECACxAUg6XFUrlcOOOAAXxPu7r//flM7x6S9H3/8UdLO+pTvhWVKi3Z40UUXRe5VwVoD/d08lUVfMjWviABRs15U0Cotexam7PiOk/sRcZ3vv//+nkiIZcuWMUxAlfXAOOW8efM8FRPjnlaTp3v37u0pPUUiBKJGgAiQEIiOGjVKZwxkw4BzzFxzCjt27GB33nmnrQwYlzfffNNJhOMz7OMAf5hCr+OOO47hQzZI4JsdsVq1ahVlQSb8GC9YsMBW3Jw5c0xphC7iGGSwzjbDmB6gDuSyo16MfwSiPOKYNQJE9fqFj0qBLY74mKUQLwJEgMSL77/+9S9LFy/NmzdnfHPzAkk+ZswYBrdvF154oa27Pbld4HzPPfdks2bN8qy8Knp4Vpgi6hAgAkQHhxIXeeozqFAWuOU02jmra6y0liccxfUywCZb5W91DzM5KaSPABEg3uugSZMmnt9vq3de3EM/xmvAt6SVG9DXXnvNqwh23nnnWeo9ePBgzzKiiJinskSBRxIyiABJAuVs5pGGPQuDFFZT2Ln2E7Y1qiPcjboFuOUX4wFyvljh5idYrYjD9+qaNWv8iKG4hEAkCOwBKfyFphAAAb7aQePspW3KGTNmaC1btrR9Lh7wjW+1SZMmiUvdke+1oa1atUrjHUPdfS8XN954o/bYY48Vo/LBNI0bu+K1n5MhQ4ZonKwoJOEEgMY3QNKuvPJKVxF8eZvWsWNHjW98bhn3sMMO09avX2/5TJWb/M9TW7RokcYH4bWLL75Y474LNb4xtHbzzTdrfMahpZqcANGGDh1q+UzFm6rXLx+k1fhKoyJ0fEm8xmcUaHxGQfEenUSLwKmnnqrx2R2Fd33btm3RCi/h0niHUjvxxBO1JUuWFJHgJKs2bNgwjft9Ld4znsCO8tV2rjYTNmv+/PnG5KZrVfQwKUY3PCPAZ75q3MVOIT7fBF1r0aKF57QUMR4E8tRnUKEsmzdv1ipWrOipsl588UWtc+fOnuIGifTLL79o6Jf/8MMPrsn5KhENfWAK6SPA3Xxo3EVkQRG+CbrGV+akr5SCGvCJa1qzZs2Kmp155plao0aNitdeT9BeuTcArXz58q5J+ICfxl01F7515cic0NS4GzyNrwCXb9ue41uS78toanN8kK3wDcfdS9umjepBnsoSFSZJyDnllFMKfV4+6VPD/wUFQgAIpGHPwiLPV6wUxppkOWeccYZ26aWXatxzjFapUqVAY4J9+vQp4CHL5RNcND7BTr5lOp86darGJzzr7vM9zTSMC5QqVUp33+kCfSHuRUVbu3atLlr//v01TlLr7tEFIRA7ApHQKCVUiJMLLGx4yD+UPCHzzTffMN5JtJy5wl+AwmbbngRJkZYvX66bXQwfqvDfHCRg+bHYvwQzdODWy0/4/vvvGTfYtuXjRtSPuMTjciKLbdy40ZQv72TpXDOhrsQvSytAslK/YoaPwJgPFpvqhG5Eh4CY8cHJvuiEkqQCAsbVg/BVbedD1QgZNo7r1q1b0daI9mA8Ll261JjUdK2KHibF6IZnBGgFiGeoEouYpz6DCmXBymXMFDTaOKvrhg0bxlrPfPKPJz2gG5/gE6suJNw7
ArQCxBtW8ioKuCxZt26dt4QhYtm1Keyd4zdgRrOVXcCmu0mEPJUlCbyiykN8H8LrBQVCQCCQhj0TeQc9tmrVqmjDsLcZVumHDVjFUbVq1aJc2EjsywHvH27ByqY2bdrULZnlcz4hWqcD9ODkpWVcukkIxIkAucAKga68ubixw4XNaP2Es846y2QUhEzO/PoRVYiLfThEehz5agTfMkQCeR8RbAAeJIwePVqnj6zb4sWLg4hUIo3sYkwuU5YIkKzU7/jx43XvENySUYgPASJA4sP25JNPLr7LNWrUYHBX6Dfw1WhFGbLtEed8to+rSFX0cFWUItgiQASILTRKPshDn0EAm2RZsJ+AsG1uR5A2cQU+i9yzHnXq1IlLDZLrEwEiQNwBg592eb897O0Yd8DAnOxiWG7bcKPsN2CwUJYhn/txDeo3X8TPU1mClD/NNESApIm+mnmnYc/CIrF69eqiDUb/auHChWFFFtJbuRHFN6SXYLVfMfcA4yWpKc6gQYNM9hnuvigQAkkjQARIQMQ3bdpUNFJyB0ucP/roo74ko6Mp0hqP8GvsJ+zcuZPJH4uYOfftt9/6EaGLW79+/YJu8EMfNHz33Xe25eOuWoKKTT3dEUccYVmuLBEgWalfzALFSia5fUQxMyL1l0hRBYgAiadi0CkX7zBWCnpZqWGlCValYXaQkGU8cveLVsmK91TRo6gQnQRCgAiQQLClligPfQYBXpJlEauQjXbO6hoTiuIIGECV80Pf3GmD9hNOOCEONUhmAASIAHEHrWvXrsX3m7tvY/BOEHfgbhuLecptC0TMTz/95Dt7kBB2tuKSSy7xLc9PgjyVxU+5VYhLBIgKtaCWDmnYs7AI9OvXr2gPucvjsOKK6TEhTravOH/55ZeLz+1O4JnAavVt0D4WJlEb9YCtp0AIJI0AESABEYfhMDZi+XrFihW+JD/00EO28rDptlf3KMiU+0DWyeI+XH3pIkfmPuoLstChtHIDJcd1O7fb1CnI7Ge3vJJ6nuQAQBxlylr9GttJ0FkIcWCZN5lEgMRTo/fcc0/RPt91112hMjG6sJL/g0AWOgVV9HDSkZ65I0AEiDtGKsXIep9BxjLJssiDmliFIc9Ul+2eOI9jZbGR7OD777CLLrqoaM9F3uJIBIj8tqR7TgSIM/6YJAfSQ7y78CKQROB7SxbzFHnjyH3FB86e79doKbNMmTIsTpfLeSpLYPBTSkgESErAK5ptWvYsLBzCTdVVV10VVlQxPUhhY18NthDuz93Chg0bLG0pVoUECcYxHNj6Qw89NIgoSkMIhEKACJCA8MnMstxpwznfhM23VHxIGeXI1zBCXgMIDznt888/7zWpKR78zT/11FNs9uzZpmd+b6BDK+uF88MPP9yvGKXiG/9URPmysgIka/UrCBuBM2bQx/lBo9TLlrAyRIDEAzjfXL5gB/kmoZ46oE5a8E3ldHs9iXaBIwYznIIqejjpSM/cESACxB0jlWJkvc8gY5lkWWQCBH6k27Zta+pPyvYPfqujDCtXrtTNhMTeWOh7EAESJcrxySICxBnbW2+9VdeeZs6c6ZwgoqeyG065/YbZs0P+T5Rl4jzM97BbkfNUFreyqvacCBDVaiRdfdKyZ2FLPW7cOHbZZZeF/jaU9Xj//fd1th12EJM5vITt27eb0gqbOm3aNC8idHF69uxpkhd0PxGdYLogBHwiQASIT8BEdGxuKIyA8di9e3cRzfPxpptuspUH+V4HedesWWOaGRfG/ZXnArhExEZLWMlixMrvXiku2ST+OMkBgMQL5yPDpOoXbrBk9254n5544gkfmlJUrwgQAeIVKe/xYMdLlSpVsIN3332394QOMe32ogI5aBdU0cNOP7rvHQF5sAcuOCiojUCe+gxJlsVIgGB/AGN/Ur6GnV27dm1kL0OvXr10+aHPjkAESGQQxyqICBB7eDERCoSe3H5wDh/0rVu3ZgMGDGCTJk3y/B1qn5P+ye+//275XYi8O3XqpI/s42rKlCmmsoiyRTmzWlYpT2WRy5WVcyJAslJT8euZlj2Lv2TBcrAaX/RDBO+///6W9hR7nP3xxx+elcL4DSaIC1ssjuTJwzOEFDFCBIgACQCmcRa6aMTiOGHCBN9Sjas2hCwcjz32WM/y7rvvPp1xCbpMzXOGHiNabcCEsk2fPt2jBDWjJTkAoCYC/9UqyfoFaSa3j+bNm6sMTWZ1IwIk+qqDa0T4jd9nn31CuxQU2mE2tNwexLnTBryq6CHKQMfgCBABEhy7NFLmqc+QZFmMBAjqrkmTJpa2T9hA+L2OImBApXz58sW8QK5gshECESBRIBy/DCJA7DG+//77i++2aDtWR6wqPfvss9nIkSMZ9nUMG+zcqyDva665JrD4ZcuW2ZYH7vPiCHkqSxz4xC2TCJC4Ec6O/LTsmYoIwf3VkUceqbOHpUuXZn5cz9tNsoOdxveH14A9Taz+V2gvV68IUrwoESACJACaTsYVhsXvxm3YaM5qkyFhKPx8xGFjIpEOxzAblweAxjbJ4MGDdXpBt/bt29vGz8qDJAcAVMYkyfodMmSI7l0qW7YsgysgCtEiQARItHjK0rBiKqpg53fay+o6VfSICouSKIcIkGzVep76DEmWxYoAGT9+vK4vIPd9cV6hQgW2ZcuW0C+I0UVthw4dijKJAClCofQJESDW1YO+MybKGduO2zW+WdEOli9fbi3Yw92lS5fa5hvm29XJbQvIy507d3rQzl+UPJXFX8nViE0EiBr1kLYWadqztMtulf/8+fNNNtbvBuZYAWj3f4D/AexF6Rawj3Hjxo1NcsK4OnTLk54TAk4IEAHihI7NM6fVGi1atLBJZX8brlDsjAvuz5o1yz6x9ARL0eRZakh77733SjHSOf3yyy/ZvvvuqysjNof0SxSlo71zrkkOADhrkt7TpOt33rx5uncJ7znuUYgWASJAosUzLmndunUztQe0CfhaTTKookeSZVYhLyJAVKgF7zrkqc+QZFmsCBD0eWvUqGFp/2AD8cOq6DABeRx11FG6PD744IOiSCJAilAofUIEiHX1PPfcc7p3W7Qbr0cMgGG/nXXr1lln4HAXLhvt8mnUqJFDSvdHdm5bkF8Y0sYu5zyVxa6MKt8nAkTl2klOtzTtWXKl9J7TzTffbLKx2NfXT8BqP2yabmer8R8wevRoR5E9evQwpce4IIhjCoRAGggQAeITdcwsgfsSO0OA2el+wkcffVTYrNZOXrt27RiWsHkJH374oUmvpDays9MPuoMUksuHjeI2bdpklyRT95McAFARmDTqFzM8jH/GYQc5VMQ2bZ2IAEm7Brzlb7SvwtaOGTPGm4CIYqmiR0TFyYwYIkAyU1UFRfPUZ0iyLFYECAAdPny4rn8p7J84Yr++MCtEX3/9dZ18/C/KgQgQGQ11z4kAsa4bTEYTbSXMEe0Mm+36CVaTmYQOkBcmOO3TCZe9UYc8lSVqbJKQRwRIEiirn0ea9kw1dDA+A5fLwqbiiL14g7gv7N27t06OLBPnIEHgHcdqvPLxxx83pYU7xalTp6o
GGelTghAgAsRnZcudaKMBwPXixYs9S8QGjcccc4zJMAi5tWvXZj/++KNnecZl+nvssUfkm9Z5VuZ/EY3uiuCvHiRSXkKSAwAqYpZW/Z5++um6dnPOOeeoCE+mdSICJBvVZ2WDMPsxSvdWXpBQRQ8vuuYpDhEg2apNq3aCPl/fvn2zVRCubZJlsSNAduzYwSpWrKjrD4g+tDg+++yzgbE97bTTdLJBiMiBCBAZDXXP5W83kGYU/otA5cqVde+3aDNBjpgcOHbsWM/QYtWIXT4YUMPG4kGD1Wa7Ii9skh51yFNZosYmCXlEgCSBsvp5pGnPVEMHK1WFzRPHZs2aBVITbgMbNGhgkifkiiPca23cuLGQB+y3FXFSqVIlRvt+BKoGShQhAkSA+ASzV69etgbAz4wVdMDkDzphPMSxZs2abOXKlb60g89UkR5HEChphV27dhU+6GV9sJn7tm3b0lIplnyTHACIpQABhaZdv/3799e96xjwhU4UokOACJDosIxLEmbyYKBAtrM47969e1xZWspVRQ9L5XJ+kwiQbFVwnvoMSZZF7i9jIo0c3NzI1qlTx3JmoizD6vzjjz/W2Va42zL2M4gAsUJOvXtEgFjXCSbZffLJJ2z69OkFNyaDBg1irVu3NrlTNvYx7K6xx8a0adOsMzPcxWxh42puWe5XX31lSOH90mnj3pdeesm7II8x81QWj0VWKhoRIEpVR2rKpGnPUiu0Tca33HKLrv8C2/roo4/axHa/jf2KjRuqy/ZanB966KFsFN8XxMorQKtWrRgmf1MgBNJGgAgQnzUAYkI0cuPxyiuvdJSGTYCw5Mu4UbksB6s2IMfPyg+RqVE3N31EuqiPy5YtY6eccoolThiohruiODahi7ocXuQlOQDgRZ8k4qhQvy+//LLp/cKfM4XoECACJDos45IkD37L/yMYuEsyqKJHkmVWJS8Ze/ghp6A2AnnqMyRZFicC5Pvvv2dly5Y19Qlkmzhp0iTfL8bll1+uk2k1eEAEiG9YU0lABIg/2DF7Fxvo/u1vf2NVq1bVtQO5XVmd77fffgVSxUuOtWrVspX90EMPeRFhGefEE0+0lfviiy9apgl7M09lCYtF0umJAEka8Wzll5Q9UwkV495lGF8MO07yxRdfsCCrbNB/e/rpp1WCh3Qp4QgQAeLjBcBmz1adPXHv1ltvLfhAfe+999jbb7/N4IMdPvGwBAwue5w+0PDs0ksvZUuWLPGh0e6omJVm3JsEAyNJhk8//ZThgxEzgAQmdseGDRuyb7/9Nkn1YskryQGAWArgQ6hK9YvZasZ3a+7cuT5KQ1HdECACxA2h9J9jlZ+xHWCGTdJBFT2SLrcK+REBokIteNchT32GJMviRIAA/euvv95kC2XbiD64n7Bhwwbd/nwHHnigpftWIkD8oJpeXCJAgmOPwcNx48bZTmyT25k4x4Q8TPpzC3BfK9IYj8cdd5xbctvnJ510kq3cyZMn26YL8yBPZQmDQxppiQBJA/Vs5hmnPVMFkQULFpjsH9pIFGHFihW27k+NNhzXIIa3bNkSRdYkgxCIDAEiQHxA+eSTT5oMilVj93oPy8Q6d+5cWHYcZMWHrDrIBGO+Yfwey7KdzuF/efTo0YWlbmCXjTo4XVepUoUtWrTISbzyz5IcAEgDDFXr96OPPjK9a3gPKUSHABEg0WEZhyRs5Gm0r+XKlWOrV6+OIztbmaroYatgzh8QAZKtCs5TnyHJsrgRIPgot3IHKNtI+MT2GgYMGKCzr7fddptlUiJALGFR7iYRINFUCdrQBRdcoGsbchuTz61WTBm1GDZsmKOsoKsanVxgxeV/Pk9lMdaT6tdEgKheQ2rqF7U9U6WUmJAt22KcP/DAA5Gpt2bNGgaC2piH1TX6ZXBTanQfGpkyJIgQCIAAESA+QHP60LFq9G73sAoCrnzgOzRswOx3Y37jx48PK9Y1PfY9Mebr5xokCNwXZDUkOQCQBkaq1q/Vaiws1acQHQJEgESHZRySrrrqKpPtjbKD61VnVfTwqm/e4hEBkq0azVOfIcmyuBEgeAvatWtnsolyf/Tiiy/29LLARSs26hRpS5cubbti2em74IQTTvCUH0WKHwEiQKLFGH0NN8IRbdZtct9PP/3EsLpKtDXjsUOHDr4Vxzc1JoMYZYlr7HkSR8hTWeLAJ06ZRIDEiW7+ZUdlz1RBqnr16ib7F/XkuE2bNjHR7oRtdTq2adOGbd26VRWISI8SjgARIB5fACyZg19Tu8bdrFkz1rVrV9alSxeGzcibN2/OsPGik9srIat+/fosrPsezH4X8sRx5syZHksXPBpYZnx0tm3btrAKBGUpX768SRehk9Xx/PPPD65AyimTHABIo6iq1i/+RI3vEgZiKUSHABEg0WEZtaTFixezvfbaS9cG4PLhjz/+iDorR3mq6OGoZM4fEgGSrQrOU58hybJ4IUDmzZuns4nGPgIGbFetWuX6wmD1tJz2sssus01DBIgtNEo9IAIk+uqYMmUKw76Oclsxnr/yyiuuGd9xxx22MuBZ4K233nKVIUfAPpNGPcQ15LmRMrIsv+d5KovfsqcZXwzEHnTQQWmqQXlnGIGo7FnaEHz44Ycm+9egQYNY1MLEU2FbvRwxLvrVV1/FogsJJQT8IEAEiEe05syZ49jIly5daikJM1E+++wzNnDgQEcyBINZYWbvYrM4o/GBEUwjoMwYGOvTp49pXxKjjuI6LV3D4pPkAEBYXaNKr0L9QgfjAPC5554bVRFJDkeACBA1XwO8+02bNtXZ+3333Tfw/lFBS6mKHkH1z0s6IkCyVZN56jMkWRYvBAjeBPG/JfqWxmOPHj1cX5h69erp7OvHH39sm4YIEFtolHpABEg81YHv20MOOUTXXuQ252ViEvbbMe5hKcvA5EN8g7uFr7/+mmHFiJzWeF6jRg03MaGe56ksoYBIODERIAkDntPsorBnaUPTt29fkw0cNGhQpGrBPXrHjh1N+RjtrdU1/i+y7v4+UjBJWCoIEAHiEfb+/fvbNnQslcdgkFvAzDOnjdlgKB555BE3MZbPBw8ebNJv5cqVlnGTvAk/gdgY3W1/kE6dOiWpVmR5JTkAEJnSEQpKs34rVqyoe+dbtGgRYclIlBhIwscnBXUQeOaZZ3TvPWY1T5gwIXEFVdEj8YIrliERIIpViIs6eeozJFkWrwQIZotbfXSLe1iV/cMPP9jW0qxZs3TpsZrbKRAB4oSOOs+IAImvLkaOHKlrM6Kt4Vi1alVPGY8ZM8bRpVapUqUKEwnhdsUY1q5dW/Ax78XjAjwWxB3yVJa4sYpKPhEgUSFJcqKwZ2miePTRR5vs8fLlyyNTCS7IjZNEQCy/+OKLrFq1aqa85f8DcY6VWk4TSyJTlgQRAjYIEAFiA4zxdqNGjWwb9aWXXmqMbnu9bt06VrlyZVtZ6MB9FWB5mNWyW5X21hgxYoRtmWEQsYw6afcttpXk40GSAwA+1Eo8ahr1W7NmTd071aRJk8TLnecMiQBRr3Y3b97MjMQfNt5MOqiiR9LlVjE/IkBUrB
V7nfLUZ0iyLF4JEGy06bY5J1Zk2wW4ZBUf6ThOnDjRLmrhPhEgjvAo85AIkPiqAm3OOCAm2hAmv3n9ths1apTrZLkyZcqwM844g2HS3GmnncYOOOAAXXsV+dod77333viAkCTnqSxSsZQ9JQJE2arJnGJR2bM0Cv7RRx+Z7CHcTkUVPv30U2bcG7Zu3bps/fr1hSy2bNniuhebsM0gQVasWBGVaiSHEPCFABEgHuACkeC0guHpp5/2IGV3FLkjLgyBfDzvvPN2R/Z41rt3b5PR++233zymTiba0KFDTTrK5YbhzlpIcgBAdWySrl/R4RXvEG04Gu0bQgRItHhGIc3o3qFnz55RiPUtQxU9fCuewwREgGSrUvPUZ0iyLF4JELwN6JOLfoHVESTyzz//bHpxsGpa3tgZRIrb6m4iQEwwKnlD/u4aPny4kjpmWakZM2bYtjk/k/Ew+1pug1bt1+4evtO7d+/Orr76altdvvjii8RgzlNZEgMtYEbie5D2AAkIICXTIRCVPdMJTeDitttuM9m+O++8M5KcFyxYwNC+ZPuLFX5Wq/KeeOIJBrJajmt1Xrt2bbZt27ZI9CMhhIAfBIgA8YAWlnVZNVxxz++KDXxQWS1RE/Jw9EsGXHfddTod4U9VxYDVMnI55XPMmMlaSHIAIAvYJFm/bdq00b1LGLCnEB0CRIBEh2UUkoybzYEo9zqzMor8hQxV9BD6lPQjESDZegPy1GdIsix+CJCdO3c67kuAfic+0I2hV69euj4FVra6BSJA3BBS4zkRIPHXQ8OGDXXtR3zf+XW/snr1aoa2iL3NhAynIwgTrAhZtmxZgbC0s0txbQTshGyeyuJUzrSfEQGSdg3kL/+o7FmSyFiNLUax3wZc+MNTi2yHQTjPnDnTtngYxzzyyCN1aeT04hykDQVCIGkEiADxgPgVV1xh24CPPfZYDxLMUbAEXzR+qyP8q/sJ2NhRlgN/qSoGbBZvt5omzCbwaZXVrqONTahKYkiyflu1aqV752vVqlUSIY+tzESAxAatb8Hjx4/X2U24fsAmdEkHVfRIutwq50cEiMq1Y9YtT32GJMvihwAB6tj0U+4TG8/hsxquLkTYunUrK1++fDEN8rNaJSLiiyMRIAIJtY9EgMRfP126dCm2H7m9+Z3QJzT9z3/+w7CyvH379qxx48bs0EMPZRUqVCj4mcd+miA9nnvuOQbX0iLMmzfPUgfoc99994loiR/zVJbEwfOQIREgHkCiKL4QiNqe+co8QGTsqSHbXZyjnxM2YEKJFRl0ww03uIrG6j/sz2rUS74G0e20L5trJhSBEAiAABEgLqBhtYbTnh033nijiwTrx6+99pqjQejTp491Qpu7t9xyi0ne9u3bbWKne9voY1kYwttvvz1dxQLknuQAQAD1UkmSVP2efPLJunf+xBNPTKW8ec2UCBA1anbhwoWsXLlyxXcdAwFpLBlWRQ81akUdLYgAUacuvGiSpz5DkmXxS4DALYPbDHL0w0V48MEHizYWfdL+/fuLR45HIkAc4VHmIREg8VeFcXWo+LZL0u3UhRdeqGvHQgccsXlvlkKeyhI37kSAxI1wyZOvgj3zg3q/fv1Mtu/WW2/1I8Iy7t///neTXNjTzz//3DK+8ebvv//O7MgkYZ9L6qRhI1Z0nRwCRIC4YL148WLLhi8a7ZQpU1wkWD9+//33HeW2bt3aOqHN3bvuusskT2xKZJMktdtjx4416Qo8sZF71kKSAwBZwSap+jVudHrqqadmBaJM6EkESPrVtGHDBibbGOxzgw3Ikw6q6JF0ubOQHxEgWail3TrK7Vn0I3HM4gdgkmXxS4AAcaNLKxlvnDdp0qRQMXAleNRRRxX7pXAhC5vnJRAB4gWl9OMQARJ/HYwbN67YhuS2ho1xkwifffaZbqWsrEPbtm2TUCGyPPJUlshAcRBEBIgDOPQoEAJp2zO/SmO1h2zzcD5//ny/YnTx0Tey6uc1atRIF8/tAnsSt2zZ0qSf0He//fZzE0HPCYFIESACxAVOp42dscGPlyXyVll88MEHtoYABsHvYO6QIUNM8tCBUjHYlT3N5clBcbL6Y0D9ZXEwIygGxnRJ1S+Ww4s/TxzPOussoyp0HQIBIkBCgBdBUrhMqFevXvEdB+G3cePGCCT7E6GKHv60LjmxiQDJVl3nqc+QZFmCECDwv7/XXnsVbajcXxDn7733Hnv99dd1cbp27er5pSICxDNUqUYkAiR++OfOnatrR2hje++9d/wZ/y8HtFvRruUjbIDffUgSU9omozyVxaaIkd4mAiRSOEkYRyBte+anErBCX7Z5OEf/DF5swgQQKEa5uB4wYIBvsfBcAFflVvJwz+ukE98ZUwJCwAIBIkAsQJFvnXHGGbaN1e8qDVnupEmTbOXCEGAzaT9h+PDhJnlhmV8/+fuJiwE1KwP46quv+hGjRNwkBwCUKLAHJZKqXxCQ8nuEgQgK0SFABEh0WPqV9NNPPzHxQYd3HLOTv/nmG79iQsdXRY/QBcmxACJAslW5eeozJFmWIAQI3owOHTro+glynwHn5513HsOeSvJ97GXmNRAB4hWpdOMRARI//viGk9sRzrEpbxIBA4DY+9KYP667deuWhAqR5ZGnskQGiosg0V8+6KCDXGLSY0LAGwJp2jNvGu6OBRfyRtsX1EX/bqmM2bnrf/LJJ+Vons9feuklk55Cb+zfRIEQSAoBIkAckMbgD2aviMZpPD788MMOqZ0fYcNvozz5Gnt6+AkTJ040yXv77bf9iEgs7o8//mjSFWVfuXJlYjpElVGSAwBR6Ry3nCTq95dffjG9Q1H82ceNTZbkEwGSTm3h3ZaJ9ypVqjDMZE46qKJH0uXOWn5EgGSrxvLUZ0iyLEEJkAULFpj6CnJf23jeqlUrXy8UESC+4EotMhEg8UN/zz33mNraNddcE3vG8DHfoEEDU95o29g/Td4kPXZlQmaQp7KEhMJXciJAfMFFkT0gkJY986CaKcoxxxxjsn9Y3Ro2DBs2zCQXdvWNN94IJBoutUCKG/tduB49enQgmZSIEAiCABEgDqi5rdJYtmyZQ2rnRy1atLA0AMIovPDCC84CDE8/+eQTkzxVV1RY7X9y4IEHhl6qZ4AkkcskBwASKVAEmSRRv999953pfQ9DSEZQ7NyJIAIk+SrFxy98VYv/gUqVKrGwrgzhe9XPjGaUWhU9kq+B7OVIBEi26ixPfYYky4KZvcIu+l0h3bx582JaIcPuOHXqVF8vFBEgvuBKLTIRIPFD37FjR1M7e+WVV2LP2G6TXrRxDOBlKeSpLEniTgRIkmiXjLzSsmd+0V20aJHJ7sJF+K5du/yKMsW32l8YdtXvGKUs+PrrrzfpC5kjRoyQo9E5IRArAkSAOMDbs2dPy0aKhopZuUEDBm9Lly5tKxszVjCL3k/AahXoJf+CLlHzk2+QuCNHjtTpCZ1vuOGGIKJST5PkAEDqhfWoQBL1u2LFCtM79NZbb3nUkKJ5QYAIEC8oRRcHndXOnTsX32uQwosXLw6dQb9+/Qoy4SbRS1BFD
y+6UhzGiADJ1luQpz5DkmWpUKFC0Tb6dXc5efLkYlq5j2w8P/74431PxCECJBvtjwiQ+OvphBNO0LWzffbZh23evDnWjP/1r38x5GNsy7hGfypLIU9lSRp3IkCSRjz/+aVhz4Kgescdd5js33XXXRdElCmNlXt92FbkGTRY7VkMmdiLjQIhkBQCRIA4IH3ssceajIroZF1++eUOKZ0f9ejRw1Yu5Hfp0sVZgM3TihUr6uSCwFEx9OrVS6cnyhzFQF8aZU1yACCN8gXJM4n6nTVrlukdWrJkSRB1KY0NAkSA2AAT0215Vkz58uVZFHs4YQn0nnvuyfbYYw/PLgZV0SMmmHMnlgiQbFVpnvoMSZZFHuQ855xzfFU6NgKtU6eOqc8g+vPi+Mwzz/iSi8gXXHCBrVwMoFBQAwEiQOKtB6sNg+N2S/v5558zeWWYaMc41q9fn/3888/xFjpC6XkqS4SweBZFBIhnqCiiBwTSsGce1LKMYjVWOX36dMu4fm/K/5uyfW3fvr1fUcX48NYhyxLntAdIESI6SQABIkBsQIbPddEorY7PPvusTUrn21iqZrdRG/LB6o9Vq1Y5C7F52rhxY53Op5xyik3M9G5j9cv++++v0/Pcc88NpBCW4MEVwqmnnsp69+7N4Hop6ZDkAIAK5XXDN8r6dcrrvvvu071DaDt+V005yadnjBEBktxb0Ldv3+L7XLZsWTZ79uzQmcN11mGHHVaQe+aZZ3qSp4oenpSlSAUEiADJ1ouQVJ8Be/jgg3LMmDFszpw5Bbd2USOVVFm2b99etI/4r2/UqJHvojz33HM6GcZ+/SGHHMKAmd9w9tln28qFX24KaiAgD+R4XQ2phubZ0ALfm3Kbwrfsxo0bY1N+w4YN7KijjtLlKfLHfkFffvllqLyTsJ9CwbjLIvLJ85EIkDzXbvJli8qexW1HMHlY2D1xBCkMN8ZRBCtvG8gHpEvQgNUpQldxxAQXGsMJiiilC4IAESA2qD3yyCOmBioaKo6YreE3rFmzpuA6S5ZjPH/88cf9ii3GhxspWR46oNhwSKWADfFkHTHT+euvv/at4l//+ledHMiEWzHs25JkMJI5omwYSIwyqFJetzJFVb9u+bRr105X/7Vr13ZLQs99IkAEiE/AAkaXfT7vvffezK8PemO26PjC7zZ8wAp75GU/KFX0MJaHrp0RIALEGR/VnibRZ/j444+ZcVPMmjVrMnzMRhmSKAv0NX6Ew7b5DRiIEISwsIvyceDAgX5FFuJjAo4sRz4/4IADAsmkRNEjQASIGdOdO3eyGTNmMLiPxSB80ADXJfJ7j/MHHnggqDjXdJjcYbRvIv8jjzyShdmfE5knZT+RV9xlQR4lIRABUhJq2bmMqtmzJOxI//79Tbb3yiuvdAbK51Pj5Gpha1977TWfkv4b3UguQZ7fVb2BMqZEhICEABEgEhjy6cknn2wyKqLR77XXXr79BKODaTdbRciF0cJS/aBh3LhxJp2xOXrY8O6777LLLruMVa9enbVq1Yrdc8897IcffvAt9sMPPyy4YxHlxTHIShr8qcgy5HOQPtu2bfOtW5AE33//va0eWJESVYi7vKrVrxfcqlatqsP+2muv9ZKM4vhAgAgQH2AFjPrYY4/p3mPYshYtWgT6NWvWjMHlCvYOkW0iZjZjI3SnoIoeTjrSM2sEiACxxkXFu0n0GTCLzm5lBkgQ7BcXRUiiLELPsWPH6mwa7BtWafsNVitHIatMmTIM5QkSjH0R2fbC9WBUMzGD6EZpdiNABMhuLHCGzcGx2lR+X2vVqsXGjx+vj+hyBeLESEaEcRHtkh17++23GYhFWW9x3qBBA7Z+/Xo3EY7Pk7KfUCLusjgWNGcPiQDJWYX6LI5q9iwpO4I+nbB/4jhx4kSf6DlHHzVqlCkP5IUxQb+rZhcuXGjpBeeNN95wVoKeEgIRI0AEiAWgmD0iDInd0euqBQxeY2DKTo64D1+pYcgPFGPdunWmfOA2KUyA6wGrWX5YuTFgwADPS9bg0sU4+27o0KGBVMMqGYGb1XHEiBGB5PpNhA3rrPLHPaxQiCrEWV4V69cNN3zgGHHHAAmFaBEgAiRaPI3SYJsxQGZ8l6O+vu2224xZ665V0UOnFF14RoAIEM9QpR4xiT7D6NGjHW3KhAkTIsEhibIIRTE70GgXMRHHb9iyZQtD39UoCytsgwRMHjHKMl5j8hOF9BEgAmR3HezatYtVq1bN9t3FRDcvXg7wfWskAE8//XTfg2K7NbM/w+zuu+++m2ECorGN4fqss87y/D1qnwtjSdjPpMriVM68PSMCJG816r08KtqzJOwI9j012sIKFSpEbn9BcoBcNuaF60GDBnmuKEwGsZKDSX8UCIGkESACxIA4DKnTknZhAE466aTC0lVD8sIlPrLQ2cYsGLcBLszWfeqpp6zEBLpXo0YNnZEKuwmd1aoSgQGOIDUw+GxH3uD+vffeq+u0ws0LBm2CBjD9sg7G85YtWwYV7Tkd3hN8JBjzFteYgRmVP8M4y6ti/bpVAgYUBM7i6JWQdJNNz3cjQATIbiyiPsNsF7sPefFOR3HE/4/TnlKq6BE1viVJHhEg2ajtpPoM8j4+VjbkjjvuCA1YUmXBBzOIDqtygMiA6x2/oU+fPjp5sJHLly/3K4ZhELN169Y6WVZ64lthx44dvuVTgmgRIAJkN55wx2z1rsr38J0GW2G3kTi+HYwrSK6++mr266+/7s4oorPJkycz47et0BWr/vEfGJW757jtZ5JliQj+TIghAiQT1RSLkiras7jtCIC88847TXa8Y8eOsWCMFbdG7wLCBnfv3r3QH3LKGH2gbt26mfSFO1Onb1QnmfSMEAiDABEg/0MPLkJmzZrFmjZtamqgopFbHevVq1fwXXf++ecXiBMsR3MjPSAHg19weYVNo6MMV111lU5/DGKGCdOmTdPJs8IA94477rgCE4zB6aVLlzLMDoQPWCPbC7ywaVOYAFdadnrgfpjNmbzohc31rAy5USfMWgxbVugTZ3lVrF+3OjD6vMRMNgrRI0AESPSYQuL06dMZBheM9iKOa5C0dkEVPez0o/veECACxBtOacZKss+AlbVOtuT+++8PBUXcZQG5ggkNGGBt2LChY1lQziuuuIJNmTKFffXVVwxp3QIGS0qVKlWUi03MvQYMsmI/EkxasnMzZoV95cqVGfYVhPuHpFy0ei1TSYlHBMjumkY7wUa5Vu+q8R5cNw8ePLgwqQ/7k+H/xtgu0Z4effTR3RlEcIa29uabb7IzzzzTVs82bdoU2n0E2RVFxGE/0ypLsVAl4IQIkBJQyTZFVNGexWFHjMXHuJvRXgfdl8Mo2+oanlwqVqxoyhM6YGzvvffeMxEhmASNb82jjz7alO7ggw9mWMVCgRBIAwEiQDjqYEyNM1mMRiWqa8xiwWazcFcVR3jxxRd1Rmbffff19FFopwtYW7uZN34wgY9YrGSIYnYQ/uzOPfdcXTllXVDmKAOWeXfq1KmQZ926dT0RXLI++FDGqiLo3LVrV9+qxVleFevXDSDjypse
PXq4JaHnARAgAiQAaC5JvvjiC4YZi7J9iPPcrjOsih4ucNFjDwgQAeIBpISjpNlnwCD7/2/v3kJsev84jn8pxxxC48KQ3JBTiFKKJjmkMM4yiZJjxo1wZwwXUsKFM8mxpBxyKORUphEuROSCcoGQwviJn/z+zP/5PjXbHOxt7z3PWnutZ71XafZh7Wc939ezrZm9P2s9K91BOC1btsz5oIwwa9m2bVuDcCLXfaP+HZ/NNTfKyspS++Dr169n9e7QL9jUL9c+/Wn92bNnZ7VNVnInQADS0LL+740/vUezeUwP5NMA8vnz5w0bz/OeBgV6wFdlZWWTqbXq90enu9IwJojF1f4zCrUE4RPVNglAojoy4fQravszV/uRdHp6oHH9faLe1s+WOq15kMurV68yHiiuYfiQIUNqZ82aVasXO//TlKNt2rSpXb16de2HDx+C7CptI5BRgADE8IwYMaLJjqTxjiWX+/oFvM6LqhekLSkpqdVpqHSaKFd/JGYaUT1dufEF4qqqqjK95K/P6ZF1eqRNug/V6Wx0LsKpU6faI4fSTZH1142nWUGnH9AdaP/+/e2HUk2X9cLj2hc9TU//+HS13Llzx9n7Qy+2mc8SZL1RHN90Rjqujd/f+oGJxb0AAYh7Uz0SJt3+0vXjempxuoufR6Uf7oWT12L9D343b95MHkAEKy703wwVFRVN9jP695tOR5rrEmYtBw4caNLvXPaLenRiNn9r6pcT2q5+UM920QtE59KXTOvq2cEs4QoQgDT1PnjwoD17P9N79U/P6UFdy5Yts2dDNW0190euXbtWq2diderUKe3/Mf1MvXz58lq9RmfQS3P2n1GrJWirqLRPABKVkShcP6KyP6sTaM5+pK6NdD//1Pb06dPTre70cf0eRmd70X12LgeF6N+gevCJfufEgkChBVpoB8wfOCweCaxcuVL27NmTqkjvmwtpp+7ne8PMASiHDx8WMwWAmAtR238mwRWT8EpRUZH91717dzFzHou5EJ6Y06TFHCGU7+Zyfp1JxMV8oLXbNEcBitnZ5txGnF7gut6oj6+OjZmyS8wv3dQwmTNy5PHjx6n73HAnYM5akurqajEfSsVM2+GuYVpCAAFnAuaoeVmzZo1tzwQgYi4o6KxtGoqvwIULF8ScASbv3r0TMzWrmDOdZfTo0fEtiJ4j0AwBM02amLOwbQu7d+8Wc+ZwM1rz56X6FYA5C0rMtSnEzExg/+nnu5qaGuncuXPqc51+tjMHC4o5A1tMGOgUwIT4snbt2lSb5ihi6dOnj5iDCGXs2LFirusoZrqX1PNh3Mh3/xnFWsLwKvQ2zNHmcvfuXTFTu4l+L8GSTIEo7M/qy+e7H6nfxp9u6/cex48fb/BUaWmpmAMXGzwW9B0zpaicPn1a9Pujly9fipm+1P4O0X24/l/Uf/r7QvtlDggXMxtM0F2ifQSyEiAAyYopXiuZ607Y8KGu1+ZIYLtDCjOMqNt2mD/NEV4yefJk0T/UzbVVwtx0QbaVtHoVef78+WKmeUt5b9++XcxFTVP3ueFOgADEnSUtIRCUAAFIULK0iwACvggQgER3JM1ZHfbgJnP9Rht0mDP6pVWrVtHtcIae+VRLhjIj9xQBSOSGhA4hgAACkRUgAIns0DSvY3qkjpm7OdWIucC7PZIm9YCHN/bu3WuP6tKjhvTsCN+XpNVr5rYUczFRMdctsUOrH5D0iDU9+4jFvQABiHtTWkTAtQABiGtR2kMAAd8ECEB8G1HqQeC3AAHIbwtuIYAAAghkFiAAyewT22f3798vZr7UVP91+oOTJ0+m7vt4Q2s8deqULFy4UI4cOeJjiQ1qSlq9R48eFXMR+ZTBzJkz7amXqQe44VSAAMQpJ40hEIgAAUggrDSKAAIeCRCAeDSYlIJAIwECkEYg3EUAAQQQSCtAAJKWJt5P6NHyOu/z27dvbSE6/ZVeu0NPLfZxMRd/t1Nf6dkBhw4dkkWLFvlYZqqmpNWr83rqmT16erkuOr/ko0ePZMCAASkTbrgVIABx60lrCAQhQAAShCptIoCATwIEID6NJrUg0FCAAKShB/cQQAABBNILEICkt4n9M3qtBL1mQt2yYsWKBhdHr3vch596sc85c+bYL8b1okzFxcU+lJW2hqTVe+bMGZk1a1bKo7y8XHbu3Jm6zw33AgQg7k1pEQHXAgQgrkVpDwEEfBMgAPFtRKkHgd8CBCC/LbiFAAIIIJBZgAAks0/snx0zZoxUVVXZOtq2bSsaDuhF0X1a9OyAUaNGyd27d2XBggWiUyX5vCSx3mHDhtkzPnRcu3btKs+fP7c/fR7nQtdGAFLoEWD7CPxdgADk70asgQACyRYgAEn2+FO93wIEIH6PL9UhgAACLgUIQFxqRrAtnSZo+PDh8vPnT9u7pUuXil4fxKdFr21SVlYmLVq0kCdPnng/LVLS6m189oee+aFngLAEK0AAEqwvrSPgQoAAxIUibSCAgM8CBCA+jy61JV2AACTp7wDqRwABBLIXIADJ3iq2a65atUp27dpl+68hwe3bt0W/3PRh+fLliwwaNEhevnwpU6dOlfPnz/tQVtoaklbv58+fbaD15s0ba6LX/NBQT68BwhKsAAFIsL60joALAQIQF4q0gQACPgsQgPg8utSWdAECkKS/A6gfAQQQyF6AACR7q9iuWVNTI/3795d3797ZGvT2w4cPpXXr1rGtSTv+69cvKS0tlUuXLole5F2nwBoxYkSsa8rU+aTVqxaLFy+2F7XX2y1btpQbN25ISUmJ3mUJWIAAJGBgmkfAgQABiANEmkAAAa8FCEC8Hl6KS7gAAUjC3wCUjwACCOQgQACSA1acV62urpaxY8fKjx8/bBkbNmyQysrKOJck69atk61bt9oa1q9fL5s2bYp1PX/rfNLqvX79uowfPz7FUlFRIRs3bkzd50awAgQgwfrSOgIuBAhAXCjSBgII+CxAAOLz6FJb0gUIQJL+DqB+BBBAIHsBApDsrWK/5uHDh2XRokW2Dj2aXqeLmjx5cizrunLlikyaNMn2fcqUKXLu3Dl7Fkgsi8mi00mrV6e8GjlypLx+/drqjBs3Tq5evWrPAsmCi1UcCBCAOECkCQQCFiAACRiY5hFAIPYCBCCxH0IKQCCtAAFIWhqeQAABBBBoJEAA0gjE97urV6+WHTt22DI7dOggVVVVMnTo0NiVXXd2wIwZM+TEiRPSrl272NWQS4eTVO/Xr19lzJgx8uDBA0s0cOBA0TOYOnfunAsZ6zZTgACkmYC8HIEQBAhAQkBmEwggEGsBApBYDx+dRyCjAAFIRh6eRAABBBCoJ0AAUg8jCTd//vwpesbE5cuXbbk9e/aUe/fuSY8ePWJXvl4QvGPHjrHrd74dTkK9ep2TadOmycWLFy1Tr169bPihP1nCFSAACdebrSGQjwABSD5qvAYBBJIkQACSpNGm1qQJEIAkbcSpFwEEEMhfgAAkf7vYvvL79++iZ07UhSBDhgyRW7duSZcuXWJbEx33Q6C8vFx2795ti+ndu7d9X/bp08eP4mJWBQFIzAa
M7iZSgAAkkcNO0QggkIMAAUgOWKyKQMwECEBiNmB0FwEEECigAAFIAfELuWm9GPqSJUvk2LFjths6NZZ+kcKCQKEETp48KWVlZXbzw4YNs2eBFBcXF6o7id8uAUji3wIAxECAACQGg0QXEUCgoAIEIAXlZ+MIBCpAABIoL40jgAACXgkQgHg1nLkXs2/fPtm8ebO9OHplZWXuDfAKBBwJ3Lx5U+bNmydz586VLVu2SPv27R21TDP5CBCA5KPGaxAIV4AAJFxvtoYAAvETIACJ35jRYwSyFSAAyVaK9RBAAAEECEB4DyCAAAIINBEgAGlCwgMIRE6AACRyQ0KHEEAgYgIEIBEbELqDgEMBAhCHmDSFAAIIeC5AAOL5AFMeAgggkI8AAUg+arwGgXAFCEDC9WZrCCAQPwECkPiNGT1GIFsBApBspVgPAQQQQIAAhPcAAggggEATAQKQJiQ8gEDkBAhAIjckdAgBBCImQAASsQGhOwg4FCAAcYhJUwgggIDnAgQgng8w5SGAAAL5CBCA5KPGaxAIV4AAJFxvtoYAAvETIACJ35jRYwSyFSAAyVaK9RBAAAEECEB4DyCAAAIINBEgAGlCwgMIRE6AACRyQ0KHEEAgYgIEIBEbELqDgEMBAhCHmDSFAAIIeC5AAOL5AFMeAgggkI8AAUg+arwGgXAFCEDC9WZrCCAQPwECkPiNGT1GIFsBApBspVgPAQQQQCC0AOTQoUPy8eNHxBFAAAEEYiBw584du89u1aqVTJw4MQY9posIJE/gxYsX8vTpU1u4fgnQrVu35CFQMQIIIJBB4P3793L//n27xuDBg6V3794Z1uYpBBCIk0B1dbV8+vRJWrduLRMmTIhT1+krAgggkGiBvn37SmlpaagGoQUg/fr1k2fPnoVaHBtDAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQACBwgtMnz5dzp49G2pHCEBC5WZjCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggkT8DrAERPPf727VvyRpWKEUAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBIuEBRUZEMHDgwVIXQzgAJtSo2hgACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAokWIABJ9PBTPAIIIIAAAggggAACCCCAAAIIIIAAAggggAACfgoQgPg5rlSFAAIIIIAAAggggAACCCCAAAIIIIAAAggggECiBQhAEj38FI8AAggggAACCCCAAAIIIIAAAggggAACCCCAgJ8CBCB+jitVIYAAAggggAACCCCAAAIIIIAAAggggAACCCCQaAECkEQPP8UjgAACCCCAAAIIIIAAAggggAACCCCAAAIIIOCnAAGIn+NKVQgggAACCCCAAAIIIIAAAgikBP7551/577//SbduHVOPcQMBBBBAAAEEEPBd4P+HWdRHqMBkAAAAAABJRU5ErkJggg==)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5YyR95dEj6ex", + "papermill": { + "duration": 0.032181, + "end_time": "2020-12-01T01:33:28.007010", + "exception": false, + "start_time": "2020-12-01T01:33:27.974829", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "*Test error (%, median over 5 runs) on CIFAR-10 of residual networks with k = 1 and different block types. Time represents one training epoch*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VsUwjYPsj6ex", + "papermill": { + "duration": 0.029054, + "end_time": "2020-12-01T01:33:28.070359", + "exception": false, + "start_time": "2020-12-01T01:33:28.041305", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "The paper highlights that the block structure B(3,3) beats B(3,1) and B(3,1,3) by a little margin." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rlc-MxfZj6ex", + "papermill": { + "duration": 0.028867, + "end_time": "2020-12-01T01:33:28.128575", + "exception": false, + "start_time": "2020-12-01T01:33:28.099708", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Key Takeaways" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G0Ix8SDaj6ex", + "papermill": { + "duration": 0.030514, + "end_time": "2020-12-01T01:33:28.188471", + "exception": false, + "start_time": "2020-12-01T01:33:28.157957", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "The paper highlights a method, giving a total improvement of 4.4% over ResNet-1001 and showing that:-\n", + "\n", + "* widening consistently improves performance across residual networks of different depth\n", + "\n", + "* incresing both depth and width helps until the number of parameters becomes too high and stronger regularization is required\n", + "\n", + "* there doesn't seem to be a regularization effect from very high depth in residual networks as wide networks with the same number of parameters as thin ones can learn same or better representations. 
Furthermore, wide networks can successfully learn with a 2 or more times larger number of parameters than thin ones, which would require doubling the depth of thin networks, making them infeasibly expensive to train." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dXm6AEpdj6ex", + "papermill": { + "duration": 0.028903, + "end_time": "2020-12-01T01:33:28.247192", + "exception": false, + "start_time": "2020-12-01T01:33:28.218289", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Importing Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", + "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", + "id": "lJ_OiL_wj6ex", + "papermill": { + "duration": 40.550443, + "end_time": "2020-12-01T01:34:08.826937", + "exception": false, + "start_time": "2020-12-01T01:33:28.276494", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import trax\n", + "from trax import layers as tl\n", + "from trax.supervised import training\n", + "\n", + "# Trax offers the WideResnet architecture in it's models module\n", + "from trax.models.resnet import WideResnet\n", + "\n", + "trax.fastmath.set_backend('tensorflow-numpy')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P9PPQOMOj6ex", + "papermill": { + "duration": 0.029394, + "end_time": "2020-12-01T01:34:08.888184", + "exception": false, + "start_time": "2020-12-01T01:34:08.858790", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Downloading Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9Uto6Pgej6ex", + "papermill": { + "duration": 0.02981, + "end_time": "2020-12-01T01:34:08.947487", + "exception": false, + "start_time": "2020-12-01T01:34:08.917677", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Trax offers a rich collection of [.data](https://trax-ml.readthedocs.io/en/latest/trax.data.html) API's to create input pipelines. 
One of which is the [`trax.data.TFDS()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.TFDS) which returns an iterator of numpy arrays representing the dataset.\n", + "\n", + "If you'd like to learn more about the trax.data API's please checkout the notebook [here](https://www.kaggle.com/sauravmaheshkar/trax-data-explained) where I explain the most common API's in a in-depth manner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ihYJyhJoj6ex", + "papermill": { + "duration": 56.298163, + "end_time": "2020-12-01T01:35:05.275849", + "exception": false, + "start_time": "2020-12-01T01:34:08.977686", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "%%capture\n", + "train_stream = trax.data.TFDS('cifar10', keys=('image', 'label'), train=True)()\n", + "eval_stream = trax.data.TFDS('cifar10', keys=('image', 'label'), train=False)()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tqEE8bLXj6ex", + "papermill": { + "duration": 0.031813, + "end_time": "2020-12-01T01:35:05.346382", + "exception": false, + "start_time": "2020-12-01T01:35:05.314569", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Batch Generator" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3X4Yy6P9j6ex", + "papermill": { + "duration": 0.029693, + "end_time": "2020-12-01T01:35:05.405910", + "exception": false, + "start_time": "2020-12-01T01:35:05.376217", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Here, we create pre-processing pipelines, by using the [`Shuffle()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.inputs.Shuffle), [`Batch()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.inputs.Batch) and [`AddLossWeights()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.inputs.AddLossWeights) functions from the trax.data API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BvR6FwLxj6ex", + "papermill": { + "duration": 0.042864, + "end_time": "2020-12-01T01:35:05.478534", + "exception": false, + "start_time": "2020-12-01T01:35:05.435670", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "train_data_pipeline = trax.data.Serial(\n", + " trax.data.Shuffle(),\n", + " trax.data.Batch(64),\n", + " trax.data.AddLossWeights(),\n", + ")\n", + "\n", + "train_batches_stream = train_data_pipeline(train_stream)\n", + "\n", + "eval_data_pipeline = trax.data.Serial(\n", + " trax.data.Batch(64),\n", + " trax.data.AddLossWeights(),\n", + ")\n", + "\n", + "eval_batches_stream = eval_data_pipeline(eval_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZFSkOQIGj6ex", + "papermill": { + "duration": 0.030008, + "end_time": "2020-12-01T01:35:05.539520", + "exception": false, + "start_time": "2020-12-01T01:35:05.509512", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Model Architecture" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m3GvLNa1j6ex", + "papermill": { + "duration": 0.030691, + "end_time": "2020-12-01T01:35:05.601093", + "exception": false, + "start_time": "2020-12-01T01:35:05.570402", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We use the `WideResnet` architecture defined in `trax.models.resnet` module. By Default the \"widening factor\" is set to 1, thus we experiment with two values, 1 and 2. 
The Architecture doesn't contain a [`tl.LogSoftmax()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.LogSoftmax) function so we add it to our model using the [`tl.Serial()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Serial) combinator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZYMPoH0yj6ex", + "papermill": { + "duration": 0.050465, + "end_time": "2020-12-01T01:35:05.682174", + "exception": false, + "start_time": "2020-12-01T01:35:05.631709", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "thin_model = tl.Serial(\n", + " WideResnet(widen_factor=1),\n", + " tl.LogSoftmax()\n", + ")\n", + "\n", + "wide_model = tl.Serial(\n", + " WideResnet(widen_factor=2),\n", + " tl.LogSoftmax()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7_6akNEVj6ex", + "papermill": { + "duration": 0.030998, + "end_time": "2020-12-01T01:35:05.744169", + "exception": false, + "start_time": "2020-12-01T01:35:05.713171", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "When we have our model and the data, we use [`trax.supervised.training`](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#module-trax.supervised.training) to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HPzQ5xJHj6ex", + "papermill": { + "duration": 0.678771, + "end_time": "2020-12-01T01:35:06.454617", + "exception": false, + "start_time": "2020-12-01T01:35:05.775846", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "train_task = training.TrainTask(\n", + " labeled_data=train_batches_stream,\n", + " loss_layer=tl.CrossEntropyLoss(),\n", + " optimizer=trax.optimizers.Adam(0.01),\n", + " n_steps_per_checkpoint=1000,\n", + ")\n", + "\n", + "eval_task = training.EvalTask(\n", + " labeled_data=eval_batches_stream,\n", + " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n", + " n_eval_batches=20,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eaj_Y4FPj6ex", + "outputId": "55396574-ad00-4112-f560-06268d7efe21", + "papermill": { + "duration": 3162.496721, + "end_time": "2020-12-01T02:27:48.982225", + "exception": false, + "start_time": "2020-12-01T01:35:06.485504", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "training_loop = training.Loop(thin_model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir='./thin_model')\n", + "\n", + "training_loop.run(5000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3NvZ7a1Kj6ez", + "outputId": "84ea1d39-0fb6-4892-85fc-d340f562de3c", + "papermill": { + "duration": 6897.182439, + "end_time": "2020-12-01T04:22:46.210173", + "exception": false, + "start_time": "2020-12-01T02:27:49.027734", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "training_loop = training.Loop(wide_model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir='./wide_model')\n", + "\n", + "training_loop.run(5000)" + ] + } + ], + "metadata": { + "colab": { + "name": "illustrated-wideresnet.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + 
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "papermill": { + "duration": 10194.991178, + "end_time": "2020-12-01T04:22:46.481666", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2020-12-01T01:32:51.490488", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-7-2-Terraformer-From-Scratch.ipynb b/resources/examples/ipynb/Example-7-2-Terraformer-From-Scratch.ipynb new file mode 100644 index 000000000..199809d26 --- /dev/null +++ b/resources/examples/ipynb/Example-7-2-Terraformer-From-Scratch.ipynb @@ -0,0 +1,2586 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Vzsxj2EV3lfL" + }, + "source": [ + "# Scaling Transformers - Sparse Is Enough\n", + "\n", + "Licensed under the Apache License, Version 2.0\n", + "This colab contains all relevant code for the paper \"Sparse is Enough in Scaling Transformers\". We depend on the Trax library and the experiments in the paper were not run with the colab but in a distributed setup with the attached config files -- but with the code below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SMmztiOqenFD" + }, + "outputs": [], + "source": [ + "# Imports.\n", + "!pip install --upgrade -q trax==1.3.9\n", + "\n", + "import functools\n", + "import os\n", + "import random\n", + "import time\n", + "import numpy as np\n", + "\n", + "import jax\n", + "import trax\n", + "from trax import layers as tl\n", + "from trax import fastmath\n", + "from trax.fastmath import numpy as jnp\n", + "from trax.supervised import training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fi6zzlt15l-d" + }, + "source": [ + "## Main sparse layers\n", + "\n", + "This cell contains the implementation of our main sparse layers:\n", + "* sparse QKV layers\n", + "* sparse feed-forward blocks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kbTJBQ_fBz8d" + }, + "outputs": [], + "source": [ + "def SplitLastAxis(num_splits):\n", + " return tl.Fn(f'SplitLastAxis_{num_splits}',\n", + " lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1)))\n", + "\n", + "\n", + "def MergeLastTwoAxes():\n", + " return tl.Fn('MergeLastTwoAxes',\n", + " lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,)))\n", + "\n", + "\n", + "def LocallyConnectedDense(n_modules, n_units, kernel_size=1,\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6),\n", + " use_bias=True):\n", + " \"\"\"Layer using LocallyConnected1d for approximation of Dense layer.\n", + "\n", + " The layer splits the last axis of a tensor into `n_modules`, then runs\n", + " LocallyConnected1d (grouped convolution) on all those modules, and\n", + " concatenates their results. 
It is essentially a locally-sensitive\n", + " approximation of Dense layer, with number of parameters smaller by the factor\n", + " of `n_modules / kernel_size`.\n", + "\n", + " Args:\n", + " n_modules: Indicates how many modules (pixels) should be input and output\n", + " split into for processing.\n", + " n_units: how many outputs (filters) should each module generate.\n", + " kernel_size: The size of the kernel to be used.\n", + " kernel_initializer: Function that creates a matrix of (random) initial\n", + " connection weights `W` for the layer.\n", + " bias_initializer: Function that creates a vector of (random) initial\n", + " bias weights `b` for the layer.\n", + " use_bias: If `True`, compute an affine map `y = Wx + b`; else compute\n", + " a linear map `y = Wx`.\n", + "\n", + " Returns:\n", + " LocallyConnectedDense tl.Layer.\n", + " \"\"\"\n", + " if n_modules == 1:\n", + " return tl.Dense(n_units, kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer, use_bias=use_bias)\n", + " return tl.Serial(\n", + " SplitLastAxis(n_modules),\n", + " tl.LocallyConnected1d(\n", + " n_units, kernel_size, kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer, use_bias=use_bias, padding='WRAP'),\n", + " MergeLastTwoAxes())\n", + "\n", + "\n", + "class _RememberPad(tl.Layer):\n", + " \"\"\"Layer which remembers last N elements in predict mode.\"\"\"\n", + "\n", + " def __init__(self, n_items_to_remember, mode):\n", + " \"\"\"Returns a layer which remembers last N elements in predict mode.\n", + "\n", + " For predict mode, the layer remembers last N elements and pads with them.\n", + " For other modes, it pads with zeros. The layer pads/remembers elements from\n", + " the second axis.\n", + "\n", + " Args:\n", + " n_items_to_remember: Number of items to remember/pad with.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " \"\"\"\n", + " super().__init__(name='_RememberPad')\n", + " self._n_items_to_remember = n_items_to_remember\n", + " self._mode = mode\n", + " self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none\n", + "\n", + " def monkey_patched_mask(self):\n", + " # This is necessary for Terraformer model. 
See comments there.\n", + " # The mask will only be used in Terraformer in predict mode.\n", + " return None\n", + "\n", + " def forward(self, x):\n", + " if self._n_items_to_remember == 0:\n", + " return x\n", + " if self._mode == 'predict':\n", + " x = jnp.concatenate([self.state[0], x], axis=1)\n", + " if self._portal_mask is not None and 'init' in self.state[1]:\n", + " assert x.shape[0] == 1\n", + " mask = self._portal_mask.get_value()\n", + " count_padding = jnp.sum(mask == 0, dtype=jnp.int32)\n", + " self.state = (fastmath.dynamic_slice_in_dim(\n", + " x, x.shape[1] - (self._n_items_to_remember + count_padding),\n", + " self._n_items_to_remember, axis=1), {'forward': ()})\n", + " else:\n", + " self.state = (x[:, -self._n_items_to_remember:, ...], {'forward': ()})\n", + " else:\n", + " pad_widths = [[0, 0] for _ in range(len(x.shape))]\n", + " pad_widths[1][0] = self._n_items_to_remember\n", + " x = jnp.pad(x, pad_width=pad_widths, mode='constant')\n", + " return x\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer's weights.\"\"\"\n", + " if isinstance(input_signature, (list, tuple)):\n", + " input_signature = input_signature[0]\n", + " self.weights = ()\n", + " if self._mode == 'predict':\n", + " shape = list(input_signature.shape)\n", + " shape[1] = self._n_items_to_remember\n", + " self.state = (jnp.zeros(shape, dtype=jnp.float32), {'init': ()})\n", + " else:\n", + " self.state = ()\n", + "\n", + "\n", + "def LocallyConvDense(n_modules, n_units, mode, kernel_size=1,\n", + " length_kernel_size=1):\n", + " \"\"\"Layer using local convolutions for approximation of Dense layer.\n", + "\n", + " The layer splits the last axis of a tensor into `n_modules`, then runs\n", + " a convolution on all those modules, and concatenates their results.\n", + " It is similar to LocallyConnectedDense above, but shares weights.\n", + "\n", + " Args:\n", + " n_modules: Indicates how many modules (pixels) should be input and output\n", + " split into for processing.\n", + " n_units: how many outputs (filters) should each module generate.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " kernel_size: The size of the kernel to be used.\n", + " length_kernel_size: If > 1, also do causal convolution on the previous axis,\n", + " which is often the sentence length in sequence models.\n", + "\n", + " Returns:\n", + " LocallyConvDense tl.Layer.\n", + " \"\"\"\n", + " if n_modules == 1:\n", + " return tl.Dense(n_units)\n", + " if kernel_size % 2 != 1:\n", + " raise ValueError('Currently we only handle odd kernel sizes.')\n", + " half = (kernel_size - 1) // 2\n", + " pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]]\n", + " return tl.Serial(\n", + " SplitLastAxis(n_modules),\n", + " tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths, mode='constant')),\n", + " _RememberPad(length_kernel_size-1, mode=mode),\n", + " tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),\n", + " MergeLastTwoAxes()\n", + " )\n", + "\n", + "\n", + "def RandomLayer(layer_a, layer_b, prob_a):\n", + " \"\"\"Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`.\"\"\"\n", + " condition = tl.Serial(\n", + " tl.RandomUniform(),\n", + " tl.Fn('SmallerThan', lambda x: x < prob_a)\n", + " )\n", + " return tl.Cond(condition, layer_a, layer_b)\n", + "\n", + "\n", + "def SparseDenseWithOptions(n_units, d_input=None, sparsity_type=None,\n", + " sparsity=0, d_lowrank=None, prob_sparse=None,\n", + " mode=None, use_bias=True, use_bfloat16=False):\n", + 
" \"\"\"Configurable sparse version of Dense layer.\"\"\"\n", + " if prob_sparse is not None:\n", + " if mode is not None and mode != 'train':\n", + " # For non-training modes, we want to use a sparse variant.\n", + " # This is different than simply prob_sparse being None, as the weights of\n", + " # the model are different.\n", + " prob_sparse = 1.0\n", + " return RandomLayer(\n", + " SparseDenseWithOptions(n_units, d_input, sparsity_type, sparsity,\n", + " d_lowrank, use_bias=use_bias,\n", + " use_bfloat16=use_bfloat16),\n", + " tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16),\n", + " prob_sparse)\n", + "\n", + " if sparsity_type is None or sparsity_type == 'None' or sparsity == 0:\n", + " return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16)\n", + " if sparsity_type == 'mult':\n", + " return FactoredDense(sparsity, d_input, n_units, use_bias=use_bias,\n", + " use_bfloat16=use_bfloat16)\n", + "\n", + " assert not use_bfloat16 # use_bfloat16 is unsupported for other variants\n", + " if sparsity_type == 'local':\n", + " assert use_bias # use_bias = False is unsupported\n", + " assert n_units % sparsity == 0\n", + " return LocallyConnectedDense(sparsity, n_units/sparsity)\n", + " if sparsity_type == 'local3':\n", + " assert use_bias # use_bias = False is unsupported\n", + " assert n_units % sparsity == 0\n", + " return LocallyConnectedDense(sparsity, n_units/sparsity, kernel_size=3)\n", + "\n", + " raise ValueError('Unknown sparsity type: {}'.format(sparsity_type))\n", + "\n", + "\n", + "def FactoredDense(n_modules, d_in, d_out, use_bias=True, use_bfloat16=False):\n", + " r\"\"\"Returns a Dense-like layer, internally factored to use fewer parameters.\n", + "\n", + " This layer treats an activation vector as if divided into :math:`M`\n", + " subvectors (``n_modules`` 'modules'). It uses this factored view to compute\n", + " a :py:class:`Dense`-like mapping with high mixing/connectivity, but using\n", + " approximately :math:`1/M` the number of weights of a similarly dimensioned\n", + " :py:class:`Dense` layer.\n", + "\n", + " More specifically, each activation vector of dimensionality ``n_in`` is\n", + " multiplied element-wise (a generalized form of gating) with ``n_modules``\n", + " vectors also of dimensionality ``n_in``. The resulting vectors are projected\n", + " to the subvector/module dimensionality ``d_out / n_modules`` via a matrix\n", + " multiply, and finally reshaped back to a single vector of dimensionality\n", + " ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at\n", + " the end. 
All the above-mentioned non-input objects -- gating vectors,\n", + " projection matrix, and optional bias -- are trainable weights.\n", + "\n", + " Args:\n", + " n_modules: Number by which an activation vector is divided into subvectors\n", + " (modules) for the factored computation.\n", + " d_in: Last/innermost dimension of input array.\n", + " d_out: Last/innermost dimension of output array.\n", + " use_bias: If True, add bias vectors at the end of the layer; else end the\n", + " layer with the matrix multiply.\n", + " use_bfloat16: If True, use bfloat16 weights; else use float32 weights.\n", + " \"\"\"\n", + " if d_out % n_modules != 0:\n", + " raise ValueError(f'Value d_out ({d_out}) must be a multiple of arg '\n", + " f'n_modules ({n_modules}).')\n", + " d_module = d_out // n_modules\n", + "\n", + " def GatingVectors():\n", + " return tl.Weights(tl.RandomNormalInitializer(stddev=0.5),\n", + " shape=[n_modules, d_in],\n", + " use_bfloat16=use_bfloat16)\n", + "\n", + " def ProjectionMatrix():\n", + " return tl.Weights(tl.GlorotUniformInitializer(),\n", + " shape=[d_in, d_module],\n", + " use_bfloat16=use_bfloat16),\n", + "\n", + " def Bias():\n", + " return tl.Weights(tl.RandomNormalInitializer(1e-6),\n", + " shape=[d_out],\n", + " use_bfloat16=use_bfloat16),\n", + "\n", + " layers = [\n", + " GatingVectors(),\n", + " ProjectionMatrix(),\n", + " _GateAndProject(),\n", + " MergeLastTwoAxes(),\n", + " ]\n", + " if use_bias:\n", + " layers += [Bias(), tl.Add()]\n", + "\n", + " return tl.Serial(layers)\n", + "\n", + "\n", + "def _GateAndProject():\n", + " \"\"\"Returns a combined gating+projection layer that saves on memory.\"\"\"\n", + "\n", + " def f(projection, gating, x):\n", + " # Args arrive in reverse order because of how they were put on the stack.\n", + " # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules)\n", + " return jnp.einsum('...d,nd,dm->...nm', x, gating, projection)\n", + "\n", + " return tl.Fn('_GateAndProject', f)\n", + "\n", + "\n", + "def MultiplicativeConvCausalAttention(\n", + " d_feature, n_heads=1, sparsity=None, length_kernel_size=3, dropout=0.0,\n", + " force_no_dropout=False, max_inference_length=2048, share_qk=False,\n", + " output_layer_type='none', v_concat_type='none', mode='train'):\n", + " \"\"\"Returns a layer that maps activations to activations, with causal masking.\n", + "\n", + " Like `CausalAttention`, this layer type represents one pass of multi-head\n", + " self-attention with causal masking rather than padding-based masking. However,\n", + " for computing Q/K/V instead of a Dense layer it combines\n", + " FactoredDense layer with LocallyConvLayer.\n", + "\n", + " Args:\n", + " d_feature: Depth/dimensionality of feature embedding.\n", + " n_heads: Number of attention heads.\n", + " sparsity: The sparsity of the layer; usually it should be equal to n_heads.\n", + " length_kernel_size: Size of convolution kernel on the length dimension.\n", + " dropout: Probababilistic rate for internal dropout applied to attention\n", + " activations (based on query-key pairs) before dotting them with values.\n", + " force_no_dropout: If True, force dropout to be 0.0 independent of the above\n", + " value; used to override some configurations.\n", + " max_inference_length: maximum length for inference.\n", + " share_qk: if True, average Q and K embeddings and share for both Q and K.\n", + " output_layer_type: Which sparse layers to use for processing output from the\n", + " attention mechanism. 
One of `'none'`, `'mult'`, `'conv'`,\n", + " or `'multconv'`.\n", + " v_concat_type: What kind of concatenation to use when computing V tensor.\n", + " One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just\n", + " output from mutliplicative layer shared by Q, K, V. `'fixed'` means\n", + " using output from multiplicative layer concatenated, for each module,\n", + " with the layer input. `'original'` means using concatenation without\n", + " properly taking modules into account; this method was used in\n", + " experiments previously, so it is included for backwards-compatibility.\n", + " mode: One of `'train'`, `'eval'`, or `'predict'`.\n", + " \"\"\"\n", + " assert output_layer_type in ['none', 'mult', 'conv', 'multconv']\n", + " assert v_concat_type in ['original', 'fixed', 'none']\n", + "\n", + " dropout = 0.0 if force_no_dropout else dropout\n", + " sparsity = n_heads if sparsity is None else sparsity\n", + " d_module = d_feature // sparsity\n", + "\n", + " output_layers = []\n", + " if 'mult' in output_layer_type:\n", + " output_layers.append(FactoredDense(\n", + " sparsity, d_feature, d_feature))\n", + " if 'conv' in output_layer_type:\n", + " output_layers.append(LocallyConvDense(\n", + " sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size))\n", + "\n", + " if v_concat_type == 'original':\n", + " # 'original'` uses concatenation without properly taking modules into\n", + " # account; this method was used in experiments previously, so it is included\n", + " # for backwards-compatibility.\n", + " concat_layers = [tl.Concatenate()] # use permuted and original for v\n", + " elif v_concat_type == 'fixed':\n", + " # `'fixed'` uses the output from multiplicative layer concatenated, for each\n", + " # module, with the layer input. 
This means that every module in Conv layer\n", + " # has access both to parts of embeddings which were used to compute Q/K of\n", + " # this particular module, and it ha access to parts of the embedding which\n", + " # will be modified by this module.\n", + " concat_layers = [\n", + " tl.Parallel(\n", + " tl.Fn('Reshape1', lambda x: jnp.reshape( # pylint: disable=g-long-lambda\n", + " x, (x.shape[0], x.shape[1], sparsity, d_module))),\n", + " tl.Fn('Reshape2', lambda x: jnp.reshape( # pylint: disable=g-long-lambda\n", + " x, (x.shape[0], x.shape[1], sparsity, d_module)))),\n", + " tl.Concatenate(),\n", + " tl.Fn('Reshape3',\n", + " lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2*d_feature))),\n", + " ]\n", + " elif v_concat_type == 'none':\n", + " # `'none'` doesn't use concatenation: we throw away the original layer\n", + " # input and pass to Conv only output of shared Multiplicative layer.\n", + " concat_layers = [tl.Select([0], n_in=2)]\n", + "\n", + " if share_qk:\n", + " return tl.Serial(\n", + " tl.Select([0, 0]), # pre-qkv, pre-v-for-concat\n", + " FactoredDense(sparsity, d_feature, d_feature), # shared q k\n", + " tl.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat\n", + " LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads),\n", + " tl.Select([0, 0]), # use for q and k\n", + " tl.Parallel(\n", + " [],\n", + " [],\n", + " [concat_layers,\n", + " LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " ),\n", + " tl.DotProductCausalAttention(\n", + " dropout=dropout, max_inference_length=max_inference_length,\n", + " mode=mode),\n", + " tl.MergeHeads(n_heads),\n", + " output_layers,\n", + " )\n", + " return tl.Serial(\n", + " tl.Select([0, 0]), # duplicate activations\n", + " FactoredDense(sparsity, d_feature, d_feature), # shared q, k\n", + " tl.Select([0, 0, 0]), # use for q, k, v\n", + " tl.Parallel(\n", + " [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " [concat_layers,\n", + " LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,\n", + " length_kernel_size=length_kernel_size),\n", + " tl.SplitIntoHeads(n_heads)],\n", + " ),\n", + " tl.DotProductCausalAttention(\n", + " dropout=dropout, max_inference_length=max_inference_length,\n", + " mode=mode),\n", + " tl.MergeHeads(n_heads),\n", + " output_layers,\n", + " )\n", + "\n", + "\n", + "class DotProductCausalAttention(tl.Layer):\n", + " \"\"\"Layer that computes attention strengths by masking out the \"future\".\n", + "\n", + " Causal attention uses masking to prevent a given sequence position from\n", + " attending to positions greater than / following it. This is used, for\n", + " example, when training autoregressive sequence models, or when decoding a\n", + " sequence symbol by symbol.\n", + "\n", + " This layer performs the core per-head attention calculation. 
The layer\n", + " assumes that any splitting into attention heads precedes it, and that any\n", + " merging of attention heads will follow it.\n", + " \"\"\"\n", + "\n", + " def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'):\n", + " \"\"\"Creates a :py:class:`DotProductCausalAttention` instance.\n", + "\n", + " Args:\n", + " dropout: Probababilistic rate for attention dropout, which overrides\n", + " (sets to zero) some attention strengths derived from query-key\n", + " matching. As a result, on a given forward pass, some value vectors\n", + " don't contribute to the output, analogous to how regular dropout can\n", + " cause some node activations to be ignored. Applies only if layer is\n", + " created in ``'train'`` mode.\n", + " max_inference_length: Maximum sequence length allowed in non-training\n", + " modes.\n", + " mode: One of ``'train'``, ``'eval'``, or ``'predict'``.\n", + " \"\"\"\n", + " super().__init__(n_in=3, n_out=1)\n", + " self._dropout = dropout\n", + " self._mode = mode\n", + " self._max_len = max_inference_length\n", + " self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none\n", + "\n", + " def monkey_patched_mask(self):\n", + " # This is necessary for Terraformer model. See comments there.\n", + " # The mask will only be used in Terraformer in predict mode.\n", + " return None\n", + "\n", + " def forward(self, inputs):\n", + " \"\"\"Returns attention-computed activations.\n", + "\n", + " Args:\n", + " inputs: A (queries, keys, values) tuple.\n", + " \"\"\"\n", + " q, k, v = inputs\n", + "\n", + " if self._portal_mask is not None:\n", + " mask_for_predict = self._portal_mask.get_value()\n", + " else:\n", + " mask_for_predict = None\n", + "\n", + " if self._mode == 'predict':\n", + " self.state, mask = _fast_inference_update_state(\n", + " inputs, self.state,\n", + " mask_for_predict=mask_for_predict)\n", + " if self._portal_mask is not None:\n", + " (_, k, v, _) = self.state\n", + " else:\n", + " (k, v, _) = self.state\n", + " else:\n", + " sequence_length = q.shape[-2]\n", + " mask = _causal_mask(sequence_length)\n", + "\n", + " activations, attn_strengths = _per_head_attention(\n", + " q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)\n", + " if self._mode == 'viz':\n", + " self.state = attn_strengths\n", + " return activations\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer for fast inference, if in ``'predict'`` mode.\"\"\"\n", + " if self._mode == 'predict':\n", + " self.state = _fast_inference_init_state(\n", + " input_signature, self._max_len,\n", + " predict_mask=self._portal_mask)\n", + " \n", + "def _fast_inference_init_state(input_signature, buffer_length,\n", + " predict_mask=None):\n", + " \"\"\"Returns an initial state for causal attention layer fast inference.\"\"\"\n", + " def zeros_for(batch_size, shape_dtype):\n", + " shape, dtype = shape_dtype.as_tuple()\n", + " d_feature = shape[-1]\n", + " return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype)\n", + "\n", + " batch_size = input_signature[0].shape[0]\n", + " k = zeros_for(batch_size, input_signature[1])\n", + " v = zeros_for(batch_size, input_signature[2])\n", + " if predict_mask is not None:\n", + " mask_for_predict = jnp.zeros((buffer_length,)) != 0\n", + " return (mask_for_predict, k, v, jnp.array(0))\n", + " else:\n", + " return (k, v, jnp.array(0))\n", + "\n", + "\n", + "def _fast_inference_update_state(inputs, state, mask_for_predict=None):\n", + " 
\"\"\"Updates state of a causal attention layer for fast inference.\n", + "\n", + " The layer state stores arrays with cached values of keys and values,\n", + " as well as an index. To make shapes static, keys and values in the state are\n", + " long, and the index indicates where the new keys and values from inputs need\n", + " to be appended.\n", + "\n", + " During update, we append new_keys and new_values to keys and values at\n", + " position given by index. And we increment index by length of new keys.\n", + " We also create a mask to be 1 at appropriate positions (causal mask).\n", + "\n", + " Args:\n", + " inputs: a triple (new_queries, new_keys, new_values)\n", + " state: layer state with (keys, values, index)\n", + " mask_for_predict: mask used for predict mode. This is used only in\n", + " Terraformer.\n", + "\n", + " Returns:\n", + " Updated state and mask to be used.\n", + " \"\"\"\n", + " # Fast inference: run step-by-step, storing the sequence\n", + " # of keys and values calculated so far in state.\n", + " (_, new_k, new_v) = inputs\n", + " if mask_for_predict is not None:\n", + " (state_mask_for_predict, ks, vs, idx) = state\n", + " else:\n", + " (ks, vs, idx) = state\n", + " length = new_k.shape[1]\n", + " ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)\n", + " vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)\n", + " k_length = ks.shape[1]\n", + "\n", + " # Mask is of shape [1, q_length, k_length].\n", + " # Mask should be true for every pair of (query_token, key_token) such that\n", + " # index of query_token is equal or larger to index of key_token.\n", + " mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length))\n", + " <= jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))\n", + " if mask_for_predict is None:\n", + " return (ks, vs, idx + length), mask\n", + " else:\n", + " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", + " state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0,\n", + " axis=0)\n", + "\n", + " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", + " state_mask_for_predict != 0, jnp.ones((1,)) != 0,\n", + " jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0)\n", + "\n", + " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", + " state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0)\n", + " placeholder = jnp.reshape(state_mask_for_predict != 0,\n", + " (1, 1, mask.shape[2],))\n", + " mask = mask * placeholder\n", + "\n", + " return (state_mask_for_predict, ks, vs, idx + length), mask\n", + "\n", + "\n", + "def _causal_mask(length):\n", + " # Not all backends define jnp.tril. However, using np.tril is inefficient\n", + " # in that it creates a large global constant.\n", + " if fastmath.is_backend(fastmath.Backend.JAX):\n", + " return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)\n", + " else:\n", + " return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)\n", + "\n", + "\n", + "def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):\n", + " \"\"\"Computes new per-head activations via scaled dot-product attention.\n", + "\n", + " This function is the core of the attention mechanism. 
Given per-head\n", + " ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:\n", + "\n", + " - computes the scaled dot product of each Q-K pair;\n", + " - applies ``mask`` to screen out positions that come from padding tokens\n", + " (indicated by 0 value);\n", + " - [in ``'train'`` mode] applies dropout to Q-K dot products;\n", + " - computes Q-K attention strengths using a per-query softmax of the Q-K dot\n", + " products; and\n", + " - for each query position, combines V vectors according to the Q-K\n", + " attention strengths.\n", + "\n", + " Args:\n", + " queries: Per-head activations representing attention queries.\n", + " keys: Per-head activations representing attention keys.\n", + " values: Per-head activations to be combined by computed attention strengths.\n", + " mask: Mask that distinguishes positions with real content vs. padding.\n", + " dropout: Probababilistic rate for attention dropout, which overrides\n", + " (sets to zero) some attention strengths derived from query-key\n", + " matching. As a result, on a given forward pass, some value vectors\n", + " don't contribute to the output, analogous to how regular dropout can\n", + " cause some node activations to be ignored. Applies only in ``'train'``\n", + " mode.\n", + " mode: One of ``'train'``, ``'eval'``, or ``'predict'``.\n", + " rng: Single-use random number generator (JAX PRNG key).\n", + "\n", + " Returns:\n", + " Tuple of (activations, attn_strengths), where activations are new per-head\n", + " activation vectors and attn_strengths is a matrix of per-head attention\n", + " strengths.\n", + " \"\"\"\n", + " if dropout >= 1.0:\n", + " raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')\n", + "\n", + " d_feature = queries.shape[-1]\n", + "\n", + " dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)\n", + " if mask is not None:\n", + " dots = jnp.where(mask,\n", + " dots,\n", + " jnp.full_like(dots, -1e9))\n", + " attn_strengths = (\n", + " jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))\n", + " if dropout is not None and dropout > 0.0 and mode == 'train':\n", + " keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)\n", + " attn_strengths = jnp.where(keep,\n", + " attn_strengths / (1.0 - dropout),\n", + " jnp.zeros_like(attn_strengths))\n", + " activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)\n", + " attn_strengths = attn_strengths.astype(jnp.float32)\n", + " return activations, attn_strengths\n", + "\n", + "\n", + "class _RememberInReverse(tl.Layer):\n", + " \"\"\"Layer remembering the input in forward pass. For reversible models.\"\"\"\n", + "\n", + " def __init__(self, output=True):\n", + " \"\"\"Layer remembering the input in forward pass. For reversible models.\n", + "\n", + " During the first pass through the model this layer saves the input as\n", + " state, and returns the input unmodified. During the second pass through the\n", + " model the layer outputs the input from the first pass. This is used to\n", + " combat numerical stability problems in Terraformer. 
It doesn't do anything\n", + " in non-reversible models.\n", + "\n", + " Args:\n", + " output: Whether to pass the input or not.\n", + " \"\"\"\n", + " n_out = 1 if output else 0\n", + " self._output = output\n", + " super().__init__(name='_RememberInReverse', n_out=n_out)\n", + "\n", + " def forward(self, x):\n", + " if 'running_second_time_yes' in self.state[1]:\n", + " result = self.state[0]\n", + " else:\n", + " result = x\n", + " self.state = (x, {'running_second_time': ()})\n", + "\n", + " if self._output:\n", + " return result\n", + " else:\n", + " return tuple()\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer's weights.\"\"\"\n", + " if isinstance(input_signature, (list, tuple)):\n", + " input_signature = input_signature[0]\n", + " self.weights = ()\n", + " self.state = (jnp.zeros(input_signature.shape, dtype=jnp.int32),\n", + " {'running_second_time': ()})\n", + "\n", + "\n", + "class _RecallQuantMaskInReverse(tl.Layer):\n", + " \"\"\"Layer recalling quant mask from specific _RememberInReverse.\n", + "\n", + " This layer is needed for memory-efficient training of reversible model with\n", + " ff chunking. During forward pass it simply returns minus ones, which are\n", + " ignored in the controller. During reverse_and_grad it returns a quant_mask\n", + " which was memorized (saved to state) by a RememberInReverse layer.\n", + "\n", + " This enable us to save quant_mask right after chunking, and load it again\n", + " (when reversing) right before chunking.\n", + " \"\"\"\n", + "\n", + " def __init__(self, remember_layer, elements):\n", + " self._remember_layer = remember_layer\n", + " self._elements = elements\n", + " super().__init__(name='_RecallQuantMaskInReverse', n_in=1, n_out=2)\n", + "\n", + " def forward(self, x):\n", + " if (self._remember_layer.state and\n", + " 'running_second_time_yes' in self._remember_layer.state[1]):\n", + " # It's reverse_and_grad, so we pull the quant_mask from remembering layer.\n", + " result = self._remember_layer.state[0]\n", + " else:\n", + " result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32)\n", + " return (x, result)\n", + "\n", + "\n", + "class _SparseFFController(tl.Layer):\n", + " \"\"\"The controller part of Sparse Feed-Forward layer.\"\"\"\n", + "\n", + " def __init__(self, d_ff, n_elements_in_block, d_lowrank, temperature,\n", + " use_bfloat16, mode, kernel_initializer, bias_initializer,\n", + " also_return_nondiscrete_output):\n", + " \"\"\"Returns a sparse feed-forward block.\"\"\"\n", + " n_out = 2 if also_return_nondiscrete_output else 1\n", + " super().__init__(name=f'_SparseFFController_{d_ff}', n_in=2, n_out=n_out)\n", + " self._use_bfloat16 = use_bfloat16\n", + " self._d_ff = d_ff\n", + " self._d_lowrank = d_lowrank\n", + " # Q: what temperature is actually most useful in training?\n", + " self._temperature = temperature if mode == 'train' else 0.0\n", + " self._mode = mode\n", + " self._n_elements_in_block = n_elements_in_block\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " # Helper numbers as d_ff will be divided by n_elements_in_block.\n", + " assert self._d_ff % self._n_elements_in_block == 0\n", + " self._d1 = self._d_ff // self._n_elements_in_block\n", + " self._d2 = self._n_elements_in_block\n", + " self._also_return_nondiscrete_output = also_return_nondiscrete_output\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + 
"\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " x, recalled_quant_mask = x\n", + " m1, m2, mb = self.weights\n", + "\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " # Q: should we add bias and/or put relu after the low-rank m1 dot?\n", + " # Replacing multiplication and reshape by this einsum brings training speed\n", + " # improvement (see also reshape in initialization).\n", + " mask_logits = jnp.einsum('bd,dl,lxy->bxy', x, m1, m2) + mb\n", + "\n", + " if self._also_return_nondiscrete_output:\n", + " # Softmax.\n", + " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", + " log_mask = mask_logits - mask_logsumexp\n", + " mask = jnp.exp(log_mask)\n", + " # Gumbel-softmax with straight-through discretization.\n", + " if self._temperature == 0.0:\n", + " quant_mask = jnp.argmax(log_mask, axis=-1)\n", + " else:\n", + " u = fastmath.random.uniform(self.rng, mask.shape, jnp.float32, 1e-6,\n", + " 1.0 - 1e-6)\n", + " g = -jnp.log(-jnp.log(u))\n", + " quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", + " else:\n", + " quant_mask = jnp.argmax(mask_logits, axis=-1)\n", + "\n", + " if self._mode == 'train':\n", + " # We use recalled_quant_mask if it's different than -1; otherwise\n", + " # we use a quant_mask which we have just computed.\n", + " quant_mask = jnp.where(recalled_quant_mask == -1,\n", + " quant_mask, recalled_quant_mask)\n", + "\n", + " if self._also_return_nondiscrete_output:\n", + " return quant_mask, mask\n", + " else:\n", + " return quant_mask\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " x_input_signature = input_signature[0]\n", + " d_model = x_input_signature.shape[-1]\n", + " shape_m1 = (d_model, self._d_lowrank)\n", + " shape_m2 = (self._d_lowrank, self._d_ff)\n", + " shape_mb = (self._d_ff,)\n", + "\n", + " rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3)\n", + " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", + " m2 = self._kernel_initializer(shape_m2, rng_m2)\n", + " mb = self._bias_initializer(shape_mb, rng_mb)\n", + " if self._use_bfloat16:\n", + " m1 = m1.astype(jnp.bfloat16)\n", + " m2 = m2.astype(jnp.bfloat16)\n", + " mb = mb.astype(jnp.bfloat16)\n", + "\n", + " # Reshapes below, with einsum in feedforward, improve the training speed.\n", + " m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2])\n", + " mb = jnp.reshape(mb, [self._d1, self._d2])\n", + "\n", + " self.weights = (m1, m2, mb)\n", + "\n", + "\n", + "class _SparseFFMain(tl.Layer):\n", + " \"\"\"The main (non-controller) part of Sparse Feed-Forward layer.\"\"\"\n", + "\n", + " def __init__(self, d_ff, n_elements_in_block, d_lowrank, quant_prob,\n", + " use_bfloat16, big_weights_in_bfloat16, mode, kernel_initializer,\n", + " bias_initializer, multiply_by_controller_output, kernel_scaling):\n", + " \"\"\"Returns a sparse feed-forward block.\"\"\"\n", + " n_in = 3 if mode == 'train' or multiply_by_controller_output else 2\n", + " super().__init__(name=f'_SparseFFMain_{d_ff}', n_in=n_in, n_out=2)\n", + " self._mode = mode\n", + " self._use_bfloat16 = use_bfloat16\n", + " self._big_weights_in_bfloat16 = big_weights_in_bfloat16\n", + " self._d_ff = d_ff\n", + " self._d_lowrank = d_lowrank\n", + " self._quant_prob 
= quant_prob\n", + " self._n_elements_in_block = n_elements_in_block\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " # Helper numbers as d_ff will be divided by n_elements_in_block.\n", + " assert self._d_ff % self._n_elements_in_block == 0\n", + " self._d1 = self._d_ff // self._n_elements_in_block\n", + " self._d2 = self._n_elements_in_block\n", + " self._multiply_by_controller_output = multiply_by_controller_output\n", + " self._kernel_scaling = kernel_scaling\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " if self._mode == 'train' or self._multiply_by_controller_output:\n", + " quant_mask, mask, x = x\n", + " else:\n", + " quant_mask, x = x\n", + " original_quant_mask = quant_mask\n", + "\n", + " w1, w2, b2 = self.weights\n", + "\n", + " if self._mode == 'predict':\n", + " w1 = jnp.transpose(w1, (1, 2, 0)) # dm, d1, d2 -> d1, d2, dm\n", + " w2 = jnp.transpose(w2, (1, 0, 2)) # d2, d1, dm -> d1, d2, dm\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " if self._mode == 'train':\n", + " # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797\n", + " quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)\n", + " quant_mask = fastmath.stop_gradient(quant_mask)\n", + " quant_mask += mask - fastmath.stop_gradient(mask) # straight-through\n", + " # We will sometimes (quant_prob of the batches) use the soft-mask instead\n", + " # of the quantized mask to improve training stability (see paper above).\n", + " select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0)\n", + " quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)\n", + "\n", + " # In training, run full matmul to get benefits from the above tricks.\n", + " mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask\n", + " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n", + " if self._multiply_by_controller_output:\n", + " # We multiply only for quantized decisions, since for non-quantized\n", + " # decisions we've already multiplied the output.\n", + " mask_mult = jnp.where(select < self._quant_prob,\n", + " mask, jnp.ones_like(mask))\n", + " # Stop-gradient is here, because we already have a pass-through gradient\n", + " # (for quantized decisions).\n", + " mask_mult = fastmath.stop_gradient(mask_mult)\n", + " relu = relu * mask_mult\n", + " res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2\n", + " elif self._mode == 'predict':\n", + " # This implementation mimicks inference. 
It's not efficient for large\n", + " # size of joint_batch, but at inference that will be 1 most of the time.\n", + " # Shapes:\n", + " # quant_mask is [joint_batch, self._d1]\n", + " # w1 is [d_model, self._d1, self._d2]\n", + " # we'll index w1 with advanced numpy indexing, first range over\n", + " # self._d1 times the batch size, second range being quant_mask\n", + " batch_size = quant_mask.shape[0]\n", + " idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)\n", + " # flatten indices and select from w1\n", + " idx1 = jnp.reshape(idx1, [-1])\n", + " idx2 = jnp.reshape(quant_mask, [-1])\n", + " w = w1[idx1, idx2, :] # now we have per-element weights with batch dim\n", + " w = jnp.reshape(w, [batch_size, self._d1, -1])\n", + " mid = jnp.einsum('ai,aji->aj', x, w)\n", + " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n", + " if self._multiply_by_controller_output:\n", + " mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0]\n", + " relu = relu * mask_mult\n", + " # w2 is [self._d1, self._d2, d_model]\n", + " v = w2[idx1, idx2, :]\n", + " v = jnp.reshape(v, [batch_size, self._d1, -1])\n", + " res = jnp.einsum('ai,aij->aj', relu, v) + b2\n", + " else:\n", + " quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)\n", + " mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask\n", + " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n", + " if self._multiply_by_controller_output:\n", + " relu = relu * mask\n", + " res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2\n", + "\n", + " return original_quant_mask, jnp.reshape(res, x_shape)\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " d_model = input_signature[-1].shape[-1]\n", + " shape_w1 = (d_model, self._d_ff)\n", + " shape_w2 = (self._d_ff, d_model)\n", + " shape_b2 = (d_model,)\n", + "\n", + " rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)\n", + " if tl.N_WEIGHTS_SHARDS > 1:\n", + " # In sharded-weights mode, put the weights on CPU on init\n", + " # as they will be sharded later.\n", + " w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1))\n", + " w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2))\n", + " else:\n", + " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", + " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", + "\n", + " b2 = self._bias_initializer(shape_b2, rng_b2)\n", + " if self._use_bfloat16:\n", + " b2 = b2.astype(jnp.bfloat16)\n", + " if self._use_bfloat16 or self._big_weights_in_bfloat16:\n", + " w1 = w1.astype(jnp.bfloat16)\n", + " w2 = w2.astype(jnp.bfloat16)\n", + "\n", + " w1 = jnp.reshape(w1, (-1, self._d1, self._d2))\n", + " w2 = jnp.reshape(w2, (self._d2, self._d1, -1))\n", + "\n", + " if self._kernel_scaling:\n", + " # This keeps expected variance of the output regardless of N.\n", + " w2 = w2 * (self._n_elements_in_block ** 0.5)\n", + "\n", + " self.weights = (w1, w2, b2)\n", + "\n", + "\n", + "def SparseFF(\n", + " d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.1, quant_prob=0.3,\n", + " use_bfloat16=False, big_weights_in_bfloat16=False, mode='train',\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6),\n", + " dropout_rate=0.0, dropout_shared_axes=None, ff_chunk_size=0,\n", + " multiply_by_controller_output=False, kernel_scaling=False):\n", + " \"\"\"Returns Feed-forward block with sparsity.\n", + "\n", + " The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n", + " that 
takes an input, makes it of size d_ff (usually larger than it was) and\n",
+ " then brings it back to the original size after Relu. It is commonly used in\n",
+ " Transformer models where it often accounts for most of the trainable weights.\n",
+ "\n",
+ " The original block can be slow in decoding due to the need to fetch a lot of\n",
+ " weights from memory. This sparse block only allows one non-zero element\n",
+ " in a block of a specified size. This is trained with the straight-through\n",
+ " Gumbel-softmax trick.\n",
+ "\n",
+ " Args:\n",
+ " d_ff: Depth/dimensionality of FeedForward layer.\n",
+ " n_elements_in_block: The sparsity level. The layer is divided into blocks of\n",
+ " this size, and each block has only a single element active.\n",
+ " d_lowrank: The dimensionality of low-rank controller.\n",
+ " temperature: The temperature of the controller during training.\n",
+ " quant_prob: During training this proportion of blocks will have quantized\n",
+ " mask (i.e. a single element active). The rest will use a soft mask.\n",
+ " use_bfloat16: Whether to use bfloat16 for weights.\n",
+ " big_weights_in_bfloat16: Whether to use bfloat16 for the main weights of the\n",
+ " FeedForward layer.\n",
+ " mode: One of `'train'`, `'eval'`, or `'predict'`.\n",
+ " kernel_initializer: Function that creates a matrix of (random) initial\n",
+ " connection weights `W` for the layer.\n",
+ " bias_initializer: Function that creates a vector of (random) initial\n",
+ " bias weights `b` for the layer.\n",
+ " dropout_rate: Probability for dropping an activation value.\n",
+ " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n",
+ " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n",
+ " way to save memory and apply consistent masks to activation vectors at\n",
+ " different sequence positions.\n",
+ " ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.\n",
+ " multiply_by_controller_output: whether to multiply the middle activation\n",
+ " layer of FF by the controller output (i.e. softmax).\n",
+ " kernel_scaling: Whether to scale the kernel matrix (during init) to keep the\n",
+ " variance of the layer output constant regardless of n_elements_in_block.\n",
+ " \"\"\"\n",
+ "\n",
+ " if mode == 'train' or multiply_by_controller_output:\n",
+ " also_return_nondiscrete_output = True\n",
+ " else:\n",
+ " also_return_nondiscrete_output = False\n",
+ " controller = _SparseFFController(\n",
+ " d_ff=d_ff, n_elements_in_block=n_elements_in_block,\n",
+ " d_lowrank=d_lowrank, temperature=temperature,\n",
+ " use_bfloat16=use_bfloat16, mode=mode,\n",
+ " kernel_initializer=kernel_initializer,\n",
+ " bias_initializer=bias_initializer,\n",
+ " also_return_nondiscrete_output=also_return_nondiscrete_output)\n",
+ "\n",
+ " main = [\n",
+ " _SparseFFMain(\n",
+ " d_ff=d_ff, n_elements_in_block=n_elements_in_block,\n",
+ " d_lowrank=d_lowrank, quant_prob=quant_prob, use_bfloat16=use_bfloat16,\n",
+ " big_weights_in_bfloat16=big_weights_in_bfloat16, mode=mode,\n",
+ " kernel_initializer=kernel_initializer,\n",
+ " bias_initializer=bias_initializer,\n",
+ " multiply_by_controller_output=multiply_by_controller_output,\n",
+ " kernel_scaling=kernel_scaling),\n",
+ " # quant_mask, emb\n",
+ " tl.Select([1, 0]),\n",
+ " # emb, quant_mask\n",
+ " tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode),\n",
+ " tl.Select([1, 0]),\n",
+ " # quant_mask, emb\n",
+ " ]\n",
+ "\n",
+ " # We will \"remember\" quant_mask _after_ chunking, and \"recall\" this same\n",
+ " # quant_mask during reverse_and_grad _before_ chunking.\n",
+ " remembering = _RememberInReverse(output=False)\n",
+ " recalling = _RecallQuantMaskInReverse(\n",
+ " remember_layer=remembering, elements=d_ff//n_elements_in_block)\n",
+ "\n",
+ " return tl.BatchLeadingAxes(tl.Serial(\n",
+ " recalling, # emb, quant_mask\n",
+ " tl.Chunk(chunk_size=ff_chunk_size, layer=tl.Serial(\n",
+ " # emb, quant_mask\n",
+ " tl.Select((0, 1, 0)), # emb, quant_mask, emb\n",
+ " controller, # quant_mask, mask, emb\n",
+ " main, # quant_mask, emb/output\n",
+ " )),\n",
+ " remembering, # emb/output\n",
+ " ))\n",
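+ "\n",
+ "\n",
+ "# Minimal smoke-test sketch for SparseFF. The hyperparameters and shapes\n",
+ "# below are illustrative assumptions only (d_ff must be divisible by\n",
+ "# n_elements_in_block). Not called anywhere; run it manually after\n",
+ "# executing the cells above.\n",
+ "def _sparse_ff_example():\n",
+ "  from trax import shapes  # assumes the trax package is importable\n",
+ "  layer = SparseFF(d_ff=256, n_elements_in_block=32, d_lowrank=16,\n",
+ "                   mode='train')\n",
+ "  x = np.zeros((2, 8, 64), dtype=np.float32)  # (batch, length, d_model)\n",
+ "  layer.init(shapes.signature(x))\n",
+ "  y = layer(x)\n",
+ "  assert y.shape == x.shape  # d_ff stays internal; output keeps d_model\n",
+ "  return y\n",
+ "\n",
+ "\n",
+ "class BlockSparseFF(tl.Layer):\n",
+ " \"\"\"Feed-forward block with block sparsity.\n",
+ "\n",
+ " The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n",
+ " that takes an input, makes it of size d_ff (usually larger than it was) and\n",
+ " then brings it back to the original size after Relu. 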
It is commonly used in\n", + " Transformer models where it often accounts for most of the trainable weights.\n", + "\n", + " This block sparse layer mimics mixture of experts architecture.\n", + " It divides the dimension of d_ff in each weight matrix to # of blocks equal to\n", + " n_experts and activates only one non-zero block from the weights matrix.\n", + " This is trained with straight-through Gumbel softmax trick.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " d_ff,\n", + " n_experts=64,\n", + " temperature=0.7,\n", + " mode='train',\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6)):\n", + " \"\"\"Returns a block sparse feed-forward block.\"\"\"\n", + " super().__init__(name=f'BlockSparseFF_{d_ff}')\n", + " self._mode = mode\n", + " self._d_ff = d_ff\n", + " self._n_experts = n_experts\n", + " self._temperature = temperature if mode == 'train' else 0.0\n", + " self._n_elements_in_block = d_ff // n_experts\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " assert self._d_ff % self._n_experts == 0\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " m1, w1, w2, b2 = self.weights\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " # Q: check if we need bias and/or put relu after the m1 dot?\n", + " mask_logits = jnp.dot(x, m1)\n", + " # Softmax.\n", + " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", + " log_mask = mask_logits - mask_logsumexp\n", + " mask = jnp.exp(log_mask)\n", + " # Gumbel-softmax with straight-through discretization.\n", + " rng1, rng2 = fastmath.random.split(self.rng, 2)\n", + " u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)\n", + " g = -jnp.log(-jnp.log(u))\n", + " selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", + " if self._mode == 'train':\n", + " # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797\n", + " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", + " quant_mask = fastmath.stop_gradient(quant_mask)\n", + " quant_mask += mask - fastmath.stop_gradient(mask) # straight-through\n", + " # We will sometimes (50% of the batches) use the soft-mask instead of\n", + " # the quantized mask to improve training stability (see the paper above).\n", + " # Q: is selecting 50% of batches the best? Other %? 
Mixed in-batch?\n", + " select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)\n", + " quant_mask = jnp.where(select > 0.0, quant_mask, mask)\n", + " else:\n", + " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", + " quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])\n", + " batch_size = quant_mask.shape[0]\n", + "\n", + " if self._mode == 'predict' and batch_size == 1:\n", + " # This implementation mimicks inference for batch_size 1.\n", + " start_idx = selected_experts[0] * self._n_elements_in_block\n", + " # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]\n", + " w = fastmath.dynamic_slice(w1, [0, start_idx],\n", + " [w1.shape[0], self._n_elements_in_block])\n", + " mid = jnp.dot(x, w)\n", + " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n", + " # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]\n", + " v = fastmath.dynamic_slice(w2, [start_idx, 0],\n", + " [self._n_elements_in_block, w2.shape[-1]])\n", + " v = jnp.reshape(v, [self._n_elements_in_block, -1])\n", + " res = jnp.dot(relu, v) + b2\n", + " else:\n", + " expanded_mask = jnp.broadcast_to(\n", + " quant_mask,\n", + " (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))\n", + " expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))\n", + " mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff]\n", + " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n", + " res = jnp.dot(relu, w2) + b2\n", + "\n", + " return jnp.reshape(res, x_shape) # un-flatten if needed\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", + " d_model = input_signature.shape[-1]\n", + " shape_m1 = (d_model, self._n_experts)\n", + " shape_w1 = (d_model, self._d_ff)\n", + " shape_w2 = (self._d_ff, d_model)\n", + " shape_b2 = (d_model,)\n", + "\n", + " rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)\n", + " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", + " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", + " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", + " b2 = self._bias_initializer(shape_b2, rng_b2)\n", + "\n", + " self.weights = (m1, w1, w2, b2)\n", + "\n", + "\n", + "class SwitchSparseFF(tl.Layer):\n", + " \"\"\"Feed-forward block with switch-style block sparsity.\n", + "\n", + " The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n", + " that takes an input, makes it of size d_ff (usually larger than it was) and\n", + " then brings it back to the original size after Relu. 
It is commonly used in\n", + " Transformer models where it often accounts for most of the trainable weights.\n", + "\n", + " This block sparse layer mimics mixture of experts architecture.\n", + " It divides the dimension of d_ff in each weight matrix to # of blocks equal to\n", + " n_experts and activates only one non-zero block from the weights matrix.\n", + " This is trained with methods following the Switch Transformer.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " d_ff,\n", + " n_experts=64,\n", + " temperature=0.1,\n", + " mode='train',\n", + " kernel_initializer=tl.GlorotUniformInitializer(),\n", + " bias_initializer=tl.RandomNormalInitializer(1e-6)):\n", + " \"\"\"Returns a switch-style training block sparse feed-forward block.\"\"\"\n", + " super().__init__(name=f'SwitchSparseFF_{d_ff}')\n", + " self._mode = mode\n", + " self._d_ff = d_ff\n", + " self._n_experts = n_experts\n", + " self._temperature = temperature if mode == 'train' else 0.0\n", + " self._n_elements_in_block = d_ff // n_experts\n", + " self._kernel_initializer = kernel_initializer\n", + " self._bias_initializer = bias_initializer\n", + " assert self._d_ff % self._n_experts == 0\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Executes this layer as part of a forward pass through the model.\n", + "\n", + " Args:\n", + " x: Tensor of same shape and dtype as the input signature used to\n", + " initialize this layer.\n", + "\n", + " Returns:\n", + " Tensor of same shape and dtype as the input.\n", + " \"\"\"\n", + " m1, w1, w2, b2 = self.weights\n", + " x_shape = x.shape\n", + " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", + "\n", + " # Q: check if we need bias and/or put relu after the m1 dot?\n", + " mask_logits = jnp.dot(x, m1)\n", + " # Softmax.\n", + " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", + " log_mask = mask_logits - mask_logsumexp\n", + " mask = jnp.exp(log_mask)\n", + " # Gumbel noise to allow sampling from the softmax.\n", + " rng1, _ = fastmath.random.split(self.rng, 2)\n", + " u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)\n", + " g = -jnp.log(-jnp.log(u))\n", + " selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", + " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", + " quant_mask = fastmath.stop_gradient(quant_mask)\n", + " quant_mask *= mask # go to just the selected expert\n", + " quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])\n", + " batch_size = quant_mask.shape[0]\n", + "\n", + " if self._mode == 'predict' and batch_size == 1:\n", + " mask_flat = jnp.reshape(mask, [-1, self._n_experts])\n", + " selected_flat = jnp.reshape(selected_experts, [-1])\n", + " selected_mask_flat = mask_flat[np.arange(selected_flat.size),\n", + " selected_flat]\n", + " # This implementation mimicks inference for batch_size 1.\n", + " start_idx = selected_experts[0] * self._n_elements_in_block\n", + " # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]\n", + " w = fastmath.dynamic_slice(w1, [0, start_idx],\n", + " [w1.shape[0], self._n_elements_in_block])\n", + " mid = jnp.dot(x, w)\n", + " mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None]\n", + " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n", + " # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]\n", + " v = fastmath.dynamic_slice(w2, [start_idx, 0],\n", + " [self._n_elements_in_block, w2.shape[-1]])\n", + " v = jnp.reshape(v, [self._n_elements_in_block, -1])\n", 
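+ " # relu has shape [1, n_elements_in_block] (the activations of the single\n",
+ " # selected expert), so this dot restores the d_model dimension for the one\n",
+ " # token being decoded.\n",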
+ " res = jnp.dot(relu, v) + b2\n",
+ " else:\n",
+ " expanded_mask = jnp.broadcast_to(\n",
+ " quant_mask,\n",
+ " (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))\n",
+ " expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))\n",
+ " mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff]\n",
+ " relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)\n",
+ " res = jnp.dot(relu, w2) + b2\n",
+ "\n",
+ " return jnp.reshape(res, x_shape) # un-flatten if needed\n",
+ "\n",
+ " def init_weights_and_state(self, input_signature):\n",
+ " \"\"\"Randomly initializes this layer's weights.\"\"\"\n",
+ " d_model = input_signature.shape[-1]\n",
+ " shape_m1 = (d_model, self._n_experts)\n",
+ " shape_w1 = (d_model, self._d_ff)\n",
+ " shape_w2 = (self._d_ff, d_model)\n",
+ " shape_b2 = (d_model,)\n",
+ "\n",
+ " rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)\n",
+ " m1 = self._kernel_initializer(shape_m1, rng_m1)\n",
+ " w1 = self._kernel_initializer(shape_w1, rng_w1)\n",
+ " w2 = self._kernel_initializer(shape_w2, rng_w2)\n",
+ " b2 = self._bias_initializer(shape_b2, rng_b2)\n",
+ "\n",
+ " self.weights = (m1, w1, w2, b2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4-3_EPyP4c7K"
+ },
+ "outputs": [],
+ "source": [
+ "# SRU needs to be changed in order for concatenated encoder-decoder attention\n",
+ "# to work in predict mode.\n",
+ "\n",
+ "def MakeZeroState(depth_multiplier=1):\n",
+ " \"\"\"Makes zeros of shape like x but removing the length (axis 1).\"\"\"\n",
+ " def f(x): # pylint: disable=invalid-name\n",
+ " if len(x.shape) != 3:\n",
+ " raise ValueError(f'Layer input should be a rank 3 tensor representing'\n",
+ " f' (batch_size, sequence_length, feature_depth); '\n",
+ " f'instead got shape {x.shape}.')\n",
+ " return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]),\n",
+ " dtype=jnp.float32)\n",
+ " return tl.Fn('MakeZeroState', f)\n",
+ "\n",
+ "def InnerSRUCell():\n",
+ " \"\"\"The inner (non-parallel) computation of an SRU.\"\"\"\n",
+ " def f(cur_x_times_one_minus_f, cur_f, cur_state): # pylint: disable=invalid-name\n",
+ " res = cur_f * cur_state + cur_x_times_one_minus_f\n",
+ " return res, res\n",
+ " return tl.Fn('InnerSRUCell', f, n_out=2)\n",
+ "\n",
+ "\n",
+ "def ScanSRUCell(mode, monkey_patched_mask=None):\n",
+ " \"\"\"Scans the inner (non-parallel) SRU cell along the length axis.\"\"\"\n",
+ " if monkey_patched_mask is None:\n",
+ " return tl.Scan(InnerSRUCell(), axis=1, mode=mode)\n",
+ "\n",
+ " # This is necessary for the Terraformer model. 
See comments there.\n", + " # The mask will only be used in Terraformer in predict mode.\n", + " assert mode == 'predict'\n", + "\n", + " def update_mask(mask, x_times_one_minus_f): # pylint: disable=invalid-name\n", + " initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)\n", + " if initial.shape[1] > 1:\n", + " updated_mask = fastmath.dynamic_update_slice_in_dim(\n", + " initial != 0, mask != 0, 1, axis=1)\n", + " else:\n", + " updated_mask = initial\n", + " return updated_mask, x_times_one_minus_f\n", + "\n", + " def masked_inner_sru_cell(cur_mask, cur_x_times_one_minus_f, cur_f, # pylint: disable=invalid-name\n", + " cur_state):\n", + " res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask\n", + " + (1 - cur_mask) * cur_state)\n", + " return res, res\n", + "\n", + " return tl.Serial(\n", + " monkey_patched_mask.get_layer(),\n", + " tl.Fn('update_mask', update_mask, n_out=2),\n", + " tl.Scan(tl.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2),\n", + " axis=1, mode=mode),\n", + " )\n", + "\n", + "\n", + "def SRU(n_units, activation=None, mode='train'):\n", + " r\"\"\"SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.\n", + "\n", + " As defined in the paper:\n", + "\n", + " .. math::\n", + " y_t &= W x_t + B \\quad \\hbox{(include $B$ optionally)} \\\\\n", + " f_t &= \\sigma(Wf x_t + bf) \\\\\n", + " r_t &= \\sigma(Wr x_t + br) \\\\\n", + " c_t &= f_t \\times c_{t-1} + (1 - f_t) \\times y_t \\\\\n", + " h_t &= r_t \\times \\hbox{activation}(c_t) + (1 - r_t) \\times x_t\n", + "\n", + " We assume the input is of shape [batch, length, depth] and recurrence\n", + " happens on the length dimension. This returns a single layer. It's best\n", + " to use at least 2, they say in the paper, except inside a Transformer.\n", + "\n", + " Args:\n", + " n_units: output depth of the SRU layer.\n", + " activation: Optional activation function.\n", + " mode: if 'predict' then we save the previous state for one-by-one inference\n", + "\n", + " Returns:\n", + " The SRU layer.\n", + " \"\"\"\n", + " sigmoid_activation = tl.Sigmoid()\n", + " return tl.Serial( # x\n", + " tl.Branch(tl.Dense(3 * n_units), []), # r_f_y, x\n", + " tl.Split(n_items=3), # r, f, y, x\n", + " tl.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x\n", + " tl.Fn('',\n", + " lambda r, f, y: (y * (1.0 - f), f, r), # y * (1 - f), f, r, x\n", + " n_out=3),\n", + " tl.Parallel([], [], tl.Branch(MakeZeroState(), [])),\n", + " ScanSRUCell(mode=mode),\n", + " tl.Select([0], n_in=2), # act(c), r, x\n", + " activation if activation is not None else [],\n", + " tl.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),\n", + " # Set the name to SRU and don't print sublayers.\n", + " name=f'SRU_{n_units}', sublayers_to_print=[]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cyf_7nTr55gU" + }, + "source": [ + "## Terraformer\n", + "\n", + "The cells below contain the implementation of the Terraformer architecture:\n", + "* feed-forward and positional encoding blocks\n", + "* encoder and decoder blocks\n", + "* concatenation and stripping to combine the encoder and decoder\n", + "* the final Terraformer model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3eEe0xnOvG_X" + }, + "outputs": [], + "source": [ + "def _FeedForward(d_model, d_ff, dropout, activation, act_dropout,\n", + " use_bfloat16, mode):\n", + " \"\"\"Feed-forward block with layer normalization at start.\"\"\"\n", + " if act_dropout is 
None:\n", + " act_dropout = dropout\n", + " return [\n", + " tl.Dense(d_ff, use_bfloat16=use_bfloat16),\n", + " tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode),\n", + " activation(),\n", + " tl.Dense(d_model, use_bfloat16=use_bfloat16),\n", + " ]\n", + "\n", + "\n", + "def FeedForwardWithOptions(d_model,\n", + " d_ff,\n", + " dropout,\n", + " dropout_shared_axes,\n", + " ff_activation,\n", + " ff_dropout,\n", + " ff_chunk_size,\n", + " ff_use_sru,\n", + " ff_sparsity,\n", + " center_layernorm,\n", + " mode,\n", + " use_bfloat16=False,\n", + " ff_sparsity_type='1inN'):\n", + " \"\"\"Feed-Forward block with all the options.\n", + "\n", + " Args:\n", + " d_model: Final dimension of tensors at most points in the model, including\n", + " the initial embedding output.\n", + " d_ff: Size of special dense layer in the feed-forward part of each block.\n", + " dropout: Stochastic rate (probability) for dropping an activation value when\n", + " applying dropout within a block.\n", + " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", + " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", + " way to save memory and apply consistent masks to activation vectors at\n", + " different sequence positions.\n", + " ff_activation: Type of activation function at the end of each block; must be\n", + " an activation-type subclass of `Layer`.\n", + " ff_dropout: Stochastic rate (probability) for dropping an activation value\n", + " when applying dropout after the FF dense layer.\n", + " ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks\n", + " ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers\n", + " in addition to the feed-forward block (second int specifies sru size)\n", + " ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block\n", + " with this sparsity\n", + " center_layernorm: whether to use centering in LayerNorm (default) or if\n", + " to skip it, which is known as RMS normalization.\n", + " mode: If `'train'`, each block will include dropout; else, it will pass all\n", + " values through unaltered.\n", + " use_bfloat16: whether to use bfloat16 for weights (default: False).\n", + " ff_sparsity_type: string, if ff_sparsity >0,\n", + " use SparseFF if ff_sparsity_type=`'1inN'` and\n", + " use BlockSparseFF if ff_sparsity_type=`'Block'`\n", + " use SwitchSparseFF if ff_sparsity_type=`'Switch'`\n", + "\n", + " Returns:\n", + " A list of layers which maps vectors to vectors.\n", + " \"\"\"\n", + " if ff_sparsity and ff_sparsity_type == '1inN':\n", + " temperature, quant_prob = 0.1, 0.3\n", + " if isinstance(ff_sparsity, str):\n", + " # This is hacky but used to pass ff_sparsity in yaml sweep files.\n", + " ff_sparsity = [(float(x) if '.' 
in x else int(x))\n", + " for x in ff_sparsity.split()]\n", + " if isinstance(ff_sparsity, (list, tuple)):\n", + " if len(ff_sparsity) == 2:\n", + " n_elements_in_block, d_lowrank = ff_sparsity\n", + " else:\n", + " n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity\n", + " else:\n", + " assert isinstance(ff_sparsity, int)\n", + " n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity\n", + " ff = SparseFF(\n", + " d_ff,\n", + " n_elements_in_block=n_elements_in_block,\n", + " d_lowrank=d_lowrank,\n", + " temperature=temperature,\n", + " quant_prob=quant_prob,\n", + " use_bfloat16=use_bfloat16,\n", + " mode=mode,\n", + " dropout_rate=dropout,\n", + " dropout_shared_axes=dropout_shared_axes,\n", + " ff_chunk_size=ff_chunk_size)\n", + " elif ff_sparsity and ff_sparsity_type == 'Block':\n", + " ff = BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)\n", + " elif ff_sparsity and ff_sparsity_type == 'Switch':\n", + " ff = SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)\n", + " else:\n", + " ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,\n", + " use_bfloat16, mode)\n", + " res = [tl.LayerNorm(center=center_layernorm), ff]\n", + " if ff_sparsity_type != '1inN' or ff_sparsity == 0:\n", + " # SparseFF has Dropout and BatchLeadingAxes built-in.\n", + " res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,\n", + " mode=mode))\n", + " if ff_chunk_size > 0:\n", + " res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))\n", + " if ff_use_sru:\n", + " if isinstance(ff_use_sru, (list, tuple)):\n", + " sru_n_layers, sru_n_units = ff_use_sru\n", + " else:\n", + " sru_n_layers, sru_n_units = ff_use_sru, 32\n", + " sru = [SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)]\n", + " block = [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units)\n", + " ] + sru + [tl.Dense(d_model)]\n", + " res = tl.Residual(block, shortcut=res)\n", + " return [res]\n", + "\n", + "\n", + "def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,\n", + " masked, attention_dropout, output_dropout,\n", + " attention_chunk_size, mode):\n", + " \"\"\"Runs the supplied attention layer.\"\"\"\n", + " try:\n", + " attention = attention_type(\n", + " n_heads=n_heads,\n", + " d_qk=d_qk,\n", + " d_v=d_v,\n", + " causal=causal,\n", + " masked=masked,\n", + " output_dropout=output_dropout,\n", + " attention_dropout=attention_dropout,\n", + " mode=mode)\n", + " except TypeError: # No d_qk arguments in less advanced layers.\n", + " attention = attention_type(\n", + " d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode)\n", + " return tl.Chunk(attention, attention_chunk_size)\n", + "\n", + "\n", + "def PositionalEncoder(mode,\n", + " dropout=None,\n", + " max_len=None,\n", + " pos_type=None,\n", + " pos_axial_shape=None,\n", + " pos_d_axial_embs=None,\n", + " pos_start_from_zero_prob=1.0,\n", + " pos_max_offset_to_add=0,\n", + " use_bfloat16=False):\n", + " \"\"\"Returns the positional encoding layer depending on the arguments.\n", + "\n", + " Args:\n", + " mode: If `'predict'`, use fast inference. 
If `'train'`, each encoder/decoder\n", + " block will include dropout; else, it will pass all values through\n", + " unaltered.\n", + " dropout: Stochastic rate (probability) for dropping an activation\n", + " value when applying dropout after the embedding block.\n", + " max_len: Maximum symbol length for positional encoding.\n", + " pos_type: string, the type of positional embeddings to use.\n", + " pos_axial_shape: tuple of ints: input shape to use for the axial position\n", + " encoding. If unset, axial position encoding is disabled.\n", + " pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.\n", + " Tuple length must match pos_axial_shape, and values must sum to d_model.\n", + " pos_start_from_zero_prob: how often to start from 0 during training,\n", + " (if 1.0, we always start from position 0, if less, we randomize).\n", + " pos_max_offset_to_add: maximum offset to add to positions during training\n", + " when randomizing; this offset plus input length must still be less than\n", + " max_len for all training examples.\n", + " use_bfloat16: If `True`, use bfloat16 weights instead of the default\n", + " float32; this can save memory but may (rarely) lead to numerical issues.\n", + "\n", + " Returns:\n", + " A layer that will do the positional encoding.\n", + " \"\"\"\n", + " if not pos_type:\n", + " positional_encoding = tl.PositionalEncoding(\n", + " max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16,\n", + " start_from_zero_prob=pos_start_from_zero_prob,\n", + " max_offset_to_add=pos_max_offset_to_add, mode=mode)\n", + " elif pos_type == 'sin-cos':\n", + " positional_encoding = tl.SinCosPositionalEncoding(mode=mode)\n", + " elif pos_type == 'fixed-base':\n", + " positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)\n", + " elif pos_type == 'infinite':\n", + " positional_encoding = tl.InfinitePositionalEncoding(affine=False)\n", + " elif pos_type == 'infinite-affine':\n", + " positional_encoding = tl.InfinitePositionalEncoding()\n", + " elif pos_type == 'time-bin':\n", + " positional_encoding = tl.TimeBinPositionalEncoding()\n", + " else:\n", + " assert pos_d_axial_embs is not None\n", + " positional_encoding = tl.AxialPositionalEncoding(\n", + " shape=pos_axial_shape, d_embs=pos_d_axial_embs,\n", + " dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)),\n", + " dropout=dropout, mode=mode)\n", + "\n", + " return positional_encoding\n", + "\n", + "\n", + "def EmbeddingAndPositionalEncodings(input_vocab_size,\n", + " d_model,\n", + " mode,\n", + " embedding_dropout,\n", + " dropout_shared_axes,\n", + " max_len,\n", + " output_vocab_size=None,\n", + " pos_type=None,\n", + " pos_axial_shape=None,\n", + " pos_d_axial_embs=None,\n", + " pos_start_from_zero_prob=1.0,\n", + " pos_max_offset_to_add=0,\n", + " use_bfloat16=False):\n", + " \"\"\"Returns the embedder and positional encoder.\n", + "\n", + " Args:\n", + " input_vocab_size: Input vocabulary size -- each element of the input tensor\n", + " should be an integer in `range(vocab_size)`. These integers typically\n", + " represent token IDs from a vocabulary-based tokenizer.\n", + " d_model: Final dimension of tensors at most points in the model, including\n", + " the initial embedding output.\n", + " mode: If `'predict'`, use fast inference. 
If `'train'`, each encoder/decoder\n", + " block will include dropout; else, it will pass all values through\n", + " unaltered.\n", + " embedding_dropout: Stochastic rate (probability) for dropping an activation\n", + " value when applying dropout after the embedding block.\n", + " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", + " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", + " way to save memory and apply consistent masks to activation vectors at\n", + " different sequence positions.\n", + " max_len: Maximum symbol length for positional encoding.\n", + " output_vocab_size: If specified, gives the vocabulary size for the targets;\n", + " if None, then input and target integers (token IDs) are assumed to come\n", + " from the same vocabulary.\n", + " pos_type: string, the type of positional embeddings to use.\n", + " pos_axial_shape: tuple of ints: input shape to use for the axial position\n", + " encoding. If unset, axial position encoding is disabled.\n", + " pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.\n", + " Tuple length must match pos_axial_shape, and values must sum to d_model.\n", + " pos_start_from_zero_prob: how often to start from 0 during training,\n", + " (if 1.0, we always start from position 0, if less, we randomize).\n", + " pos_max_offset_to_add: maximum offset to add to positions during training\n", + " when randomizing; this offset plus input length must still be less than\n", + " max_len for all training examples.\n", + " use_bfloat16: If `True`, use bfloat16 weights instead of the default\n", + " float32; this can save memory but may (rarely) lead to numerical issues.\n", + "\n", + " Returns:\n", + " A tuple of (input encoder, output encoder, output vocab size used).\n", + " \"\"\"\n", + " # tokens --> vectors\n", + " def Embedder(vocab_size, embedding_mode):\n", + " if vocab_size is not None:\n", + " embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16)\n", + " else:\n", + " embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16)\n", + " return [\n", + " embedding,\n", + " tl.Dropout(rate=embedding_dropout,\n", + " shared_axes=dropout_shared_axes,\n", + " mode=embedding_mode),\n", + " ]\n", + "\n", + " # NOTE: Positional encodings are not shared between encoder and decoder.\n", + "\n", + " # Since encoder doesn't run stepwise, we do not use predict mode there.\n", + " encoder_mode = 'eval' if mode == 'predict' else mode\n", + " in_embedder = Embedder(input_vocab_size, encoder_mode)\n", + " in_encoder = in_embedder + [\n", + " PositionalEncoder(encoder_mode,\n", + " dropout=embedding_dropout,\n", + " max_len=max_len,\n", + " pos_type=pos_type,\n", + " pos_axial_shape=pos_axial_shape,\n", + " pos_d_axial_embs=pos_d_axial_embs,\n", + " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", + " pos_max_offset_to_add=pos_max_offset_to_add,\n", + " use_bfloat16=use_bfloat16)\n", + " ]\n", + "\n", + " # If output_vocab_size is None, we reuse the same embedding matrix, otherwise\n", + " # we initialize one.\n", + " assert input_vocab_size or output_vocab_size\n", + " if output_vocab_size is None:\n", + " out_embedder = in_embedder\n", + " else:\n", + " out_embedder = Embedder(output_vocab_size, mode)\n", + "\n", + " out_encoder = out_embedder + [\n", + " PositionalEncoder(mode,\n", + " dropout=embedding_dropout,\n", + " max_len=max_len,\n", + " pos_type=pos_type,\n", + " pos_axial_shape=pos_axial_shape,\n", + " pos_d_axial_embs=pos_d_axial_embs,\n", + " 
pos_start_from_zero_prob=pos_start_from_zero_prob,\n", + " pos_max_offset_to_add=pos_max_offset_to_add,\n", + " use_bfloat16=use_bfloat16)\n", + " ]\n", + "\n", + " # Set this to the value actually used.\n", + " if output_vocab_size is None:\n", + " output_vocab_size = input_vocab_size\n", + "\n", + " if input_vocab_size is None:\n", + " in_encoder = tl.AssertFunction('...a->...b', in_encoder)\n", + " else:\n", + " in_encoder = tl.AssertFunction('...->...d', in_encoder)\n", + " out_encoder = tl.AssertFunction('...->...d', out_encoder)\n", + "\n", + " return in_encoder, out_encoder, output_vocab_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2D3dQi9Q2bO7" + }, + "outputs": [], + "source": [ + "def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,\n", + " n_heads, attention_type, dropout, ff_activation,\n", + " ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,\n", + " attention_chunk_size, n_attention_layers=1,\n", + " n_feedforward_layers=1, center_layernorm=True,\n", + " use_bfloat16=False, mode='train'):\n", + " \"\"\"Reversible transformer decoder layer.\n", + "\n", + " Args:\n", + " d_model: int: depth of embedding\n", + " d_ff: int: depth of feed-forward layer\n", + " d_attention_key: int: depth of key vector for each attention head\n", + " d_attention_value: int: depth of value vector for each attention head\n", + " n_heads: int: number of attention heads\n", + " attention_type: subclass of tl.BaseCausalAttention: attention class to use\n", + " dropout: float: dropout rate (how much to drop out)\n", + " ff_activation: the non-linearity in feed-forward layer\n", + " ff_dropout: the dropout rate in feed-forward layer\n", + " ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward\n", + " ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks\n", + " ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity\n", + " attention_chunk_size: int, if > 0 run attention chunked at this size\n", + " n_attention_layers: how many residual causal attention layers should we\n", + " have before the feed-forward block (default: 1, the standard block)\n", + " n_feedforward_layers: how many FFNN layers should we have (default 1).\n", + " center_layernorm: whether to use centering in LayerNorm (default) or if\n", + " to skip it, which is known as RMS normalization.\n", + " use_bfloat16: whether to use bfloat16 for weights (default: False).\n", + " mode: str: 'train' or 'eval'\n", + "\n", + "\n", + " Returns:\n", + " the layer.\n", + " \"\"\"\n", + " # pylint: disable=g-complex-comprehension\n", + " def _Attn():\n", + " return ApplyAttentionLayer(\n", + " attention_type, d_model, n_heads, d_attention_key,\n", + " d_attention_value, True, False, dropout, dropout,\n", + " attention_chunk_size, mode)\n", + "\n", + " def _FF():\n", + " return FeedForwardWithOptions(\n", + " d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,\n", + " ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,\n", + " mode, use_bfloat16)\n", + "\n", + " def _attention_half_residual():\n", + " return [\n", + " tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),\n", + " attention_layer=_Attn(),\n", + " name='ReversibleHalfResidualDecoderAttn'),\n", + " tl.ReversibleSwap()\n", + " ]\n", + "\n", + " def _feed_forward():\n", + " return [\n", + " tl.ReversibleHalfResidual(_FF(),\n", + " name='ReversibleHalfResidualDecoderFF'),\n", + " tl.ReversibleSwap()\n", + " ]\n", + "\n", + " return 
([_attention_half_residual() for _ in range(n_attention_layers)]\n", + " + [_feed_forward() for _ in range(n_feedforward_layers)])\n", + "\n", + "\n", + "def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,\n", + " ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,\n", + " attention_chunk_size=0, center_layernorm=True,\n", + " use_bfloat16=False, use_two_swaps_per_block=True,\n", + " mode='train'):\n", + " \"\"\"Returns a list of layers that implements a Terraformer encoder block.\n", + "\n", + " The input to the layer is a pair, (activations, mask), where the mask was\n", + " created from the original source tokens to prevent attending to the padding\n", + " part of the input.\n", + "\n", + " Args:\n", + " d_model: int: depth of embedding\n", + " d_ff: int: depth of feed-forward layer\n", + " n_heads: int: number of attention heads\n", + " attention_type: subclass of tl.BaseCausalAttention: attention class to use\n", + " dropout: float: dropout rate (how much to drop out)\n", + " ff_activation: the non-linearity in feed-forward layer\n", + " ff_dropout: the dropout rate in feed-forward layer\n", + " ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward\n", + " ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks\n", + " ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity\n", + " attention_chunk_size: int, if > 0 run attention chunked at this size\n", + " center_layernorm: whether to use centering in LayerNorm (default) or if\n", + " to skip it, which is known as RMS normalization.\n", + " use_bfloat16: whether to use bfloat16 for weights (default: False)\n", + " use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder\n", + " block, otherwise use only one swap.\n", + " mode: str: 'train' or 'eval'\n", + "\n", + " Returns:\n", + " A list of layers that maps (activations, mask) to (activations, mask).\n", + " \"\"\"\n", + " if mode == 'predict':\n", + " # Mode 'predict' means that the decoder should be run one token at a time.\n", + " # The encoder only ever runs over full sequences, which is why it's switched\n", + " # to 'eval' mode instead.\n", + " mode = 'eval'\n", + "\n", + " def _Attn():\n", + " return ApplyAttentionLayer(\n", + " attention_type=attention_type, d_model=d_model, n_heads=n_heads,\n", + " d_qk=d_model//n_heads, d_v=d_model//n_heads, masked=True, causal=False,\n", + " attention_dropout=dropout, output_dropout=dropout,\n", + " attention_chunk_size=attention_chunk_size, mode=mode)\n", + "\n", + " def _FF():\n", + " return FeedForwardWithOptions(\n", + " d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,\n", + " ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,\n", + " mode, use_bfloat16)\n", + "\n", + " attention = _Attn()\n", + " if attention.n_out == 2:\n", + " attention = tl.Serial(\n", + " tl.Parallel([], _InsertAxes12()),\n", + " attention,\n", + " tl.Select([0], n_in=2)\n", + " )\n", + "\n", + " def _attention_half_residual():\n", + " return [\n", + " tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),\n", + " attention_layer=attention,\n", + " name='ReversibleHalfResidualEncoderAttn'),\n", + " tl.ReversibleSwap()\n", + " ]\n", + "\n", + " def _feed_forward():\n", + " layers = [\n", + " tl.ReversibleHalfResidual(_FF(),\n", + " name='ReversibleHalfResidualEncoderFF')\n", + " ]\n", + " if use_two_swaps_per_block:\n", + " layers.append(tl.ReversibleSwap())\n", + " return layers\n", + "\n", + " return _attention_half_residual() + 
_feed_forward()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ITiWrbEnAZyb" + }, + "outputs": [], + "source": [ + "# Arg shapes: (B, L1, H), (B, L2, H), (B, L1).\n", + "def _ConcatWithPadding(vec_e, vec_d, mask_e):\n", + " \"\"\"Concatenate with padding: see the ConcatWithPadding layer for details.\"\"\"\n", + " # pylint: disable=invalid-name\n", + " B, L1, H = vec_e.shape\n", + " L2 = vec_d.shape[1]\n", + " # pylint: enable=invalid-name\n", + "\n", + " if vec_d.shape != (B, L2, H):\n", + " raise ValueError(f'Shape of decoder vector, {vec_d.shape}, does not'\n", + " f' equal {(B, L2, H)}.')\n", + " if mask_e.shape != (B, L1):\n", + " raise ValueError(f'Shape of encoder mask, {mask_e.shape}, does not'\n", + " f' equal {(B, L1)}.')\n", + "\n", + " def _UpdateRow(x):\n", + " # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)\n", + " row_e, row_d, row_mask_e = x\n", + " # final_row - (L1+L2, H)\n", + " final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)\n", + " # Find the last real token/vector of the encoder.\n", + " e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)\n", + " # Starting after that index, update with the decoder row.\n", + " zero = jnp.array(0, dtype=e_idx.dtype) # avoid int32/int64 mismatch\n", + " return fastmath.dynamic_update_slice(final_row, row_d, (e_idx, zero))\n", + "\n", + " return fastmath.map(_UpdateRow, [vec_e, vec_d, mask_e])\n", + "\n", + "\n", + "def _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d):\n", + " \"\"\"Strip concatenate with padding: see the layer below for details.\"\"\"\n", + " # pylint: disable=invalid-name\n", + " B, L, H = vec_ed.shape\n", + " L1 = tok_e.shape[1]\n", + " L2 = tok_d.shape[1]\n", + " # pylint: enable=invalid-name\n", + " if L != L1 + L2:\n", + " raise ValueError(f'Length from encoder-decoder vectors ({L}) does not'\n", + " f' equal sum of lengths from encoder ({L1}) and decoder'\n", + " f' ({L2}).')\n", + " if tok_e.shape != (B, L1):\n", + " raise ValueError(f'Shape of encoder tokens, {tok_e.shape}, does not'\n", + " f' equal {(B, L1)}.')\n", + " if tok_d.shape != (B, L2):\n", + " raise ValueError(f'Shape of decoder tokens, {tok_d.shape}, does not'\n", + " f' equal {(B, L2)}.')\n", + "\n", + " def _UpdateRow(x):\n", + " # (L, H), (L1, H) & (L2, H)\n", + " row_ed, row_e, _ = x\n", + " mask_e = row_e != 0\n", + " len_e = jnp.sum(mask_e, dtype=jnp.int32)\n", + " # In `row_ed` start where encoder tokens/vecs end, i.e. are index `len_e`\n", + " # and pick up (L2, H) tensor slice from there.\n", + " zero = jnp.array(0, dtype=len_e.dtype) # avoid int32/int64 mismatch\n", + " return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H))\n", + "\n", + " return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])\n", + "\n", + "\n", + "class StripFromConcatenateWithPadding(tl.Layer):\n", + " \"\"\"Strips out the leading encoder tokens from the concatenated array.\"\"\"\n", + "\n", + " def __init__(self, mode='train'):\n", + " super().__init__(n_in=3, n_out=1)\n", + " self._mode = mode\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Sets layer-specific internal state.\"\"\"\n", + " del input_signature\n", + " self.state = jnp.array(0, dtype=jnp.int32)\n", + "\n", + " def forward(self, inputs):\n", + " vec_ed, tok_e, tok_d = inputs\n", + "\n", + " # In training/eval mode or at the first step predict mode i.e. when\n", + " # state.shape is (), i.e. 
at the first step, we do the actual computation.\n",
+ " if self._mode != 'predict' or not self.state.shape:\n",
+ " # Now state.shape will not evaluate to false.\n",
+ " self.state = self.state.reshape((1,))\n",
+ " return _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d)\n",
+ "\n",
+ " # In predict mode and on subsequent steps (i.e. after the first step) vec_ed\n",
+ " # is actually vec_d, since no concatenation happened at all.\n",
+ " return vec_ed\n",
+ "\n",
+ "\n",
+ "class ConcatWithPadding(tl.ReversibleLayer):\n",
+ " \"\"\"Concatenates two length-padded (B, L, H) arrays (of different lengths).\"\"\"\n",
+ "\n",
+ " def __init__(self, mode='train'):\n",
+ " super().__init__(n_in=5, n_out=3)\n",
+ " self._mode = mode\n",
+ "\n",
+ " def init_weights_and_state(self, input_signature):\n",
+ " \"\"\"Sets layer-specific internal state.\"\"\"\n",
+ " del input_signature\n",
+ " self.state = jnp.array(0, dtype=jnp.int32)\n",
+ "\n",
+ " def forward(self, inputs):\n",
+ " vec_e, vec_d, mask_e, tok_e, tok_d = inputs\n",
+ "\n",
+ " # In training/eval mode, or at the first step of predict mode (i.e. when\n",
+ " # state.shape is ()), we return the concatenated output.\n",
+ " if self._mode != 'predict' or not self.state.shape:\n",
+ " # Now state.shape will not evaluate to false.\n",
+ " self.state = self.state.reshape((1,))\n",
+ " return _ConcatWithPadding(vec_e, vec_d, mask_e), tok_e, tok_d\n",
+ "\n",
+ " # In predict mode and on subsequent steps (i.e. after the first step) we\n",
+ " # don't concatenate anymore, but just return the decoder vector.\n",
+ " return vec_d, tok_e, tok_d\n",
+ "\n",
+ " def reverse(self, output, weights=(), state=(), new_state=(), rng=None):\n",
+ " del state, new_state, rng, weights\n",
+ " assert self._mode != 'predict', 'cannot reverse in predict mode'\n",
+ " vecs_ed, toks_e, toks_d = output\n",
+ " vecs_d = _StripFromConcatenateWithPadding(vecs_ed, toks_e, toks_d)\n",
+ " mask_e = (toks_e != 0)\n",
+ " mask_e_float = mask_e.astype(jnp.float32)\n",
+ " vecs_e = vecs_ed[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n",
+ " return vecs_e, vecs_d, mask_e, toks_e, toks_d\n",
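+ "\n",
+ "\n",
+ "# Illustrative sketch: a tiny worked case of _ConcatWithPadding, with\n",
+ "# assumed shapes L1=3, L2=2, H=1. Not called anywhere; run it manually\n",
+ "# to inspect the output.\n",
+ "def _concat_with_padding_example():\n",
+ "  vec_e = jnp.array([[[1.], [2.], [0.]]])  # (1, 3, 1); last position is pad\n",
+ "  vec_d = jnp.array([[[5.], [6.]]])        # (1, 2, 1)\n",
+ "  mask_e = jnp.array([[1, 1, 0]])          # two real encoder tokens\n",
+ "  # Decoder rows are written right after the last real encoder row, so the\n",
+ "  # result is [[1], [2], [5], [6], [0]] with shape (1, 5, 1).\n",
+ "  return _ConcatWithPadding(vec_e, vec_d, mask_e)\n",
+ "\n",
+ "\n",
+ "class ConcatWithPadding2(tl.ReversibleLayer):\n",
+ " \"\"\"Concatenate with padding operating on pairs to combine with rev-nets.\"\"\"\n",
+ "\n",
+ " def __init__(self, mode='train'):\n",
+ " super().__init__(n_in=6, n_out=4)\n",
+ " self._mode = mode\n",
+ "\n",
+ " def init_weights_and_state(self, input_signature):\n",
+ " \"\"\"Sets layer-specific internal state.\"\"\"\n",
+ " del input_signature\n",
+ " self.state = jnp.array(0, dtype=jnp.int32)\n",
+ "\n",
+ " def forward(self, inputs):\n",
+ " vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d = inputs\n",
+ "\n",
+ " # In training/eval mode, or at the first step of predict mode (i.e. when\n",
+ " # state.shape is ()), we return the concatenated output.\n",
+ " if self._mode != 'predict' or not self.state.shape:\n",
+ " # Now state.shape will not evaluate to false.\n",
+ " self.state = self.state.reshape((1,))\n",
+ " # Calculate mask and concat_with_padding on the pairs.\n",
+ " vecs_ed1 = _ConcatWithPadding(vecs_e1, vecs_d, mask_e)\n",
+ " vecs_ed2 = _ConcatWithPadding(vecs_e2, vecs_d, mask_e)\n",
+ " return vecs_ed1, vecs_ed2, toks_e, toks_d\n",
+ "\n",
+ " # In predict mode and on subsequent steps (i.e. 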
after the first step) we\n", + " # don't concatenate anymore, but just return the decoder vector.\n", + " return vecs_d, vecs_d, toks_e, toks_d\n", + "\n", + " def reverse(self, output, weights=(), state=(), new_state=(), rng=None):\n", + " del state, new_state, rng, weights\n", + " assert self._mode != 'predict', 'cannot reverse in predict mode'\n", + " vecs_ed1, vecs_ed2, toks_e, toks_d = output\n", + " vecs_d = _StripFromConcatenateWithPadding(vecs_ed1, toks_e, toks_d)\n", + " mask_e = (toks_e != 0)\n", + " mask_e_float = mask_e.astype(jnp.float32)\n", + " vecs_e1 = vecs_ed1[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", + " vecs_e2 = vecs_ed2[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", + " return vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4FPVnsq8Ersd" + }, + "outputs": [], + "source": [ + "def Terraformer(input_vocab_size,\n", + " output_vocab_size=None,\n", + " d_model=512,\n", + " d_ff=2048,\n", + " d_attention_key=None,\n", + " d_attention_value=None,\n", + " n_encoder_layers=6,\n", + " n_decoder_layers=6,\n", + " n_heads=8,\n", + " dropout=0.1,\n", + " max_len=2048,\n", + " encoder_attention_type=tl.SelfAttention,\n", + " encoder_decoder_attention_type=tl.SelfAttention,\n", + " pos_type='fixed-base',\n", + " pos_axial_shape=(),\n", + " pos_d_axial_embs=None,\n", + " pos_start_from_zero_prob=1.0,\n", + " pos_max_offset_to_add=0,\n", + " ff_activation=tl.Relu,\n", + " ff_use_sru=(1, 32),\n", + " ff_chunk_size=0,\n", + " ff_dropout=None,\n", + " ff_sparsity=32,\n", + " loss_sparsity_type='mult',\n", + " loss_sparsity=0,\n", + " loss_d_lowrank=0,\n", + " loss_sparsity_prob=None,\n", + " attention_chunk_size=0,\n", + " n_layers_forget=0,\n", + " forget_dense=True,\n", + " n_decoder_attention_layers=2,\n", + " use_bfloat16=False,\n", + " reversible_encoder=False,\n", + " use_two_swaps_per_encoder_block=True,\n", + " center_layernorm=True,\n", + " half_before_layer=None,\n", + " double_after_layer=None,\n", + " mode='train'):\n", + " \"\"\"Returns a highly configurable Terraformer encoder-decoder model.\n", + "\n", + " This model maps paired text sequences (source and target) to float-valued\n", + " losses. If ``input_vocab_size`` is not ``None``, the layer takes\n", + " two input sequences:\n", + "\n", + " - inputs (2):\n", + "\n", + " - source: 2-D int array representing a batch of text strings via token\n", + " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", + " where sequence_length <= ``max_len``. Array elements are in\n", + " ``range(input_vocab_size)``, and 0 values mark padding positions.\n", + "\n", + " - target: 2-D int array representing a batch of text strings via token\n", + " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", + " where sequence_length <= ``max_len``. 
Array elements are in\n", + " ``range(output_vocab_size)``, and 0 values mark padding positions.\n", + "\n", + " - output: 1-D float array of losses; shape is `(batch_size)`.\n", + "\n", + " If ``input_vocab_size`` is ``None``, the layer takes three input sequences:\n", + "\n", + " - inputs (3):\n", + "\n", + " - source: 3-D float array representing a batch of already-embedded text\n", + " strings; shape is `(batch_size, sequence_length, d_model)`, where\n", + " sequence_length <= ``max_len``.\n", + "\n", + " - mask: 2-D int array representing active versus masked positions; 0\n", + " values mark masked (padding) positions.\n", + "\n", + " - target: 2-D int array representing a batch of text strings via token\n", + " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", + " where sequence_length <= ``max_len``. Array elements are in\n", + " ``range(output_vocab_size)``, and 0 values mark padding positions.\n", + "\n", + " - output: 1-D float array of losses; shape is `(batch_size)`.\n", + "\n", + " Args:\n", + " input_vocab_size: Input vocabulary size -- each element of the input tensor\n", + " should be an integer in ``range(vocab_size)``. These integers typically\n", + " represent token IDs from a vocabulary-based tokenizer.\n", + " output_vocab_size: If specified, gives the vocabulary size for the targets;\n", + " if ``None``, then input and target integers (token IDs) are assumed to\n", + " come from the same vocabulary.\n", + " d_model: Last/innermost dimension of activation arrays at most points in\n", + " the model, including the initial embedding output.\n", + " d_ff: Last/innermost dimension of special (typically wider)\n", + " :py:class:`Dense` layer in the feedforward part of each encoder block.\n", + " d_attention_key: Depth of key vectors in each attention head.\n", + " d_attention_value: Depth of value vectors in each attention head.\n", + " n_encoder_layers: Number of encoder blocks.\n", + " n_decoder_layers: Number of decoder blocks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Stochastic rate (probability) for dropping an activation value\n", + " when applying dropout within encoder/decoder blocks. The same rate is\n", + " also used for attention dropout in encoder/decoder blocks.\n", + " max_len: Maximum symbol length for positional encoding.\n", + " encoder_attention_type: Type of attention to use in the encoder; must be\n", + " an attention-type subclass of :py:class:`trax.layers.Layer`.\n", + " encoder_decoder_attention_type: Type of attention to use in the decoder;\n", + " must be an attention-type subclass of :py:class:`trax.layers.Layer`.\n", + " pos_type: String indicating the type of positional embeddings to use.\n", + " pos_axial_shape: Shape (tuple of ints) to use for the axial position\n", + " encoding. If unset, axial position encoding is disabled.\n", + " pos_d_axial_embs: Tuple of ints specifying the depth of position embedding\n", + " for each axis. Tuple length must match ``pos_axial_shape``, and values\n", + " must sum to ``d_model``.\n", + " pos_start_from_zero_prob: Stochastic rate (probability) for starting\n", + " positional encoding at position 0 during training. If 1.0, always start\n", + " from position 0; if < 1.0, the non-zero starts will be uniformly\n", + " distributed up to ``pos_max_offset_to_add``.\n", + " pos_max_offset_to_add: Maximum offset to add to positions during training\n", + " when randomizing. 
This offset plus input length must be less than\n",
+ " ``max_len`` for all training examples.\n",
+ " ff_activation: Type of activation function at the end of each block; must\n",
+ " be an activation-type subclass of :py:class:`trax.layers.Layer`.\n",
+ " ff_use_sru: If > 0, use this number of SRU layers in place of feedforward\n",
+ " layers.\n",
+ " ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this\n",
+ " size.\n",
+ " ff_dropout: Stochastic rate (probability) for dropping an activation value\n",
+ " at feedforward nonlinearities.\n",
+ " ff_sparsity: If > 0, use sparse feedforward blocks with this level of\n",
+ " sparsity.\n",
+ " loss_sparsity_type: String indicating the type of sparsity to use in the\n",
+ " loss layer; see :py:class:`SparseDenseWithOptions` for options. If\n",
+ " ``None``, use no sparsity.\n",
+ " loss_sparsity: If > 0, use this level of sparsity in the loss layer.\n",
+ " loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this\n",
+ " dimension, in the loss.\n",
+ " loss_sparsity_prob: Stochastic rate (probability) for using the sparse\n",
+ " version of the loss. If ``None``, use the sparse version exclusively.\n",
+ " attention_chunk_size: If > 0, compute attention using chunks of this size.\n",
+ " n_layers_forget: How often to have a forgetting block between layers.\n",
+ " forget_dense: If True, use :py:class:`Dense` instances as forget layers;\n",
+ " else use no-ops.\n",
+ " n_decoder_attention_layers: Number of attention layers in a decoder block.\n",
+ " use_bfloat16: If True, use bfloat16 for weights; else use float32.\n",
+ " reversible_encoder: If True, make the encoder reversible.\n",
+ " use_two_swaps_per_encoder_block: If True, ensure that there is an even\n",
+ " number of swaps across the encoder.\n",
+ " center_layernorm: If True, use centering in :py:class:`LayerNorm` (the\n",
+ " default); else omit centering (which is known as RMS normalization).\n",
+ " half_before_layer: If not None, specifies an n'th layer such that all\n",
+ " layers before the n'th use half the normal values for ``d_model`` and\n",
+ " ``d_ff``.\n",
+ " double_after_layer: If not None, specifies an n'th layer such that all\n",
+ " layers after the n'th use double the normal values for ``d_model`` and\n",
+ " ``d_ff``.\n",
+ " mode: If ``'train'``, include dropout in each encoder/decoder block; else\n",
+ " dropout layers have no effect.\n",
+ "\n",
+ " Returns:\n",
+ " A Terraformer encoder-decoder as a layer that maps from target and source\n",
+ " text sequences to a scalar loss.\n",
+ " \"\"\"\n",
+ " if mode == 'predict':\n",
+ " portal_mask = _PortalInput()\n",
+ " else:\n",
+ " portal_mask = None\n",
+ "\n",
+ " # Set default dimensions for attention head key and value sizes.\n",
+ " if (d_model / 2) % n_heads != 0:\n",
+ " raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')\n",
+ " if d_attention_key is None:\n",
+ " d_attention_key = d_model // n_heads\n",
+ " if d_attention_value is None:\n",
+ " d_attention_value = d_model // n_heads\n",
+ "\n",
+ " # Set values of d_model, d_ff and d_qkv for the first stage. Use integer\n",
+ " # division so the halved sizes remain ints (layer shapes must be integers).\n",
+ " d_model1, d_ff1 = d_model, d_ff\n",
+ " d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value\n",
+ " if half_before_layer:\n",
+ " d_model1, d_ff1 = d_model // 2, d_ff // 2\n",
+ " d_attention_key1 = d_attention_key // 2\n",
+ " d_attention_value1 = d_attention_value // 2\n",
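+ "\n",
+ " # With half_before_layer / double_after_layer the decoder runs in stages of\n",
+ " # different widths; the tl.ReversibleConcatenatePair layers appended at the\n",
+ " # stage boundaries below widen the activations by concatenating the\n",
+ " # reversible pair.\n",
+ " # Set values of d_model, d_ff and d_qkv for the final stage.\n",
+ " d_model2, 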
d_ff2 = d_model, d_ff\n", + " d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value\n", + " if double_after_layer:\n", + " d_model2, d_ff2 = d_model * 2, d_ff * 2\n", + " d_attention_key2 = d_attention_key * 2\n", + " d_attention_value2 = d_attention_value * 2\n", + "\n", + " # Vector embeddings.\n", + " in_encoder, out_encoder, output_vocab_size = (\n", + " EmbeddingAndPositionalEncodings(\n", + " input_vocab_size,\n", + " d_model1,\n", + " mode,\n", + " dropout,\n", + " [-2], # dropout_shared_axes\n", + " max_len,\n", + " output_vocab_size=output_vocab_size,\n", + " pos_type=pos_type,\n", + " pos_axial_shape=pos_axial_shape,\n", + " pos_d_axial_embs=pos_d_axial_embs,\n", + " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", + " pos_max_offset_to_add=pos_max_offset_to_add,\n", + " use_bfloat16=use_bfloat16)\n", + " )\n", + "\n", + " def _EncoderBlock():\n", + " return EncoderBlock(\n", + " d_model1,\n", + " d_ff1,\n", + " n_heads,\n", + " encoder_attention_type,\n", + " dropout=dropout,\n", + " ff_activation=ff_activation,\n", + " ff_dropout=ff_dropout,\n", + " ff_use_sru=ff_use_sru,\n", + " ff_chunk_size=ff_chunk_size,\n", + " ff_sparsity=ff_sparsity,\n", + " attention_chunk_size=attention_chunk_size,\n", + " center_layernorm=center_layernorm,\n", + " use_bfloat16=use_bfloat16,\n", + " use_two_swaps_per_block=use_two_swaps_per_encoder_block,\n", + " mode=mode)\n", + "\n", + " def _Encoder(): # vec_e mask_e tok_e tok_d tok_d\n", + " layers = [\n", + " tl.ReversibleSelect([0, 0]),\n", + " _ReversibleSerialForget(\n", + " [_EncoderBlock() for _ in range(n_encoder_layers)],\n", + " d_model1,\n", + " n_layers_forget,\n", + " forget_dense)\n", + " ]\n", + " if not reversible_encoder:\n", + " layers += [\n", + " _XYAvg(),\n", + " tl.Dense(d_model1, use_bfloat16=use_bfloat16),\n", + " tl.LayerNorm(),\n", + " ]\n", + " if mode == 'predict':\n", + " return tl.Cache(tl.Serial(layers))\n", + " else:\n", + " return tl.Serial(layers)\n", + "\n", + " if mode == 'predict':\n", + " global DotProductCausalAttention\n", + " DotProductCausalAttention.monkey_patched_mask = (\n", + " lambda x: portal_mask)\n", + " global _RememberPad\n", + " _RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access\n", + " lambda x: portal_mask)\n", + " global ScanSRUCell\n", + " originalScanSRUCell = ScanSRUCell\n", + " ScanSRUCell = functools.partial(ScanSRUCell,\n", + " monkey_patched_mask=portal_mask)\n", + "\n", + " decoder_blocks = []\n", + "\n", + " if isinstance(encoder_decoder_attention_type, (tuple, list)):\n", + " assert n_decoder_layers % len(encoder_decoder_attention_type) == 0\n", + " else:\n", + " encoder_decoder_attention_type = [encoder_decoder_attention_type]\n", + " for layer_idx in range(n_decoder_layers):\n", + " layer_attention_type = encoder_decoder_attention_type[\n", + " layer_idx % len(encoder_decoder_attention_type)]\n", + " # Grow d_model, d_ff, and d_qkv if requested.\n", + " d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1\n", + " if half_before_layer and layer_idx >= half_before_layer:\n", + " d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value\n", + " if double_after_layer and layer_idx > double_after_layer:\n", + " d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2\n", + " decoder_block = DecoderBlock(\n", + " d_m, d_f, d_k, d_v, n_heads,\n", + " attention_type=layer_attention_type,\n", + " dropout=dropout,\n", + " ff_activation=ff_activation,\n", + " ff_dropout=ff_dropout,\n", + " 
ff_use_sru=ff_use_sru,\n",
+    "        ff_chunk_size=ff_chunk_size,\n",
+    "        ff_sparsity=ff_sparsity,\n",
+    "        attention_chunk_size=attention_chunk_size,\n",
+    "        n_attention_layers=n_decoder_attention_layers,\n",
+    "        center_layernorm=center_layernorm,\n",
+    "        use_bfloat16=use_bfloat16,\n",
+    "        mode=mode)\n",
+    "    decoder_blocks.append(decoder_block)\n",
+    "    if half_before_layer and layer_idx == half_before_layer - 1:\n",
+    "      decoder_blocks.append(tl.ReversibleConcatenatePair())\n",
+    "    if double_after_layer and layer_idx == double_after_layer:\n",
+    "      decoder_blocks.append(tl.ReversibleConcatenatePair())\n",
+    "\n",
+    "  if mode == 'predict':\n",
+    "    # After initializing the decoder we can revert to the original state of\n",
+    "    # the previously monkey-patched classes/functions.\n",
+    "    DotProductCausalAttention.monkey_patched_mask = (\n",
+    "        lambda x: None)\n",
+    "    _RememberPad.monkey_patched_mask = (lambda x: None)  # pylint: disable=protected-access\n",
+    "    ScanSRUCell = originalScanSRUCell\n",
+    "\n",
+    "  def _Loss():\n",
+    "    return SparseDenseWithOptions(\n",
+    "        output_vocab_size,\n",
+    "        d_input=d_model2,\n",
+    "        sparsity_type=loss_sparsity_type,\n",
+    "        sparsity=loss_sparsity,\n",
+    "        d_lowrank=loss_d_lowrank,\n",
+    "        prob_sparse=loss_sparsity_prob,\n",
+    "        use_bfloat16=use_bfloat16,\n",
+    "        mode=mode)\n",
+    "\n",
+    "  def _enc_dec_concat():\n",
+    "    \"\"\"Layers to merge encoder and decoder.\"\"\"\n",
+    "    if reversible_encoder:\n",
+    "      return [\n",
+    "          tl.ReversibleSelect([0, 1, 4, 2, 3]),  # v_e v_d mask_e tok_e tok_d\n",
+    "          ConcatWithPadding2(mode=mode),  # v_ed v_ed tok_e tok_d\n",
+    "      ]\n",
+    "    else:\n",
+    "      return [\n",
+    "          tl.ReversibleSelect([0, 3, 1, 2]),  # v_e v_d mask_e tok_e tok_d\n",
+    "          ConcatWithPadding(mode=mode),  # v_ed tok_e tok_d\n",
+    "          tl.ReversibleSelect([0, 0]),  # v_ed v_ed tok_e tok_d\n",
+    "      ]\n",
+    "\n",
+    "  def _inp_layers():\n",
+    "    if input_vocab_size is not None:\n",
+    "      return tl.AssertFunction(\n",
+    "          'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth\n",
+    "          tl.Serial(  # tok_e tok_d\n",
+    "              tl.Select([0, 0, 0, 1]),\n",
+    "              tl.Parallel(in_encoder, [tl.PaddingMask(),\n",
+    "                                       _RemoveAxes12()])\n",
+    "          ))  # vec_e mask_e tok_e tok_d\n",
+    "    else:\n",
+    "      # Input in this case is vec_e, mask_e, tok_d. Where all downstream\n",
+    "      # operations expect tok_e, we give them mask_e instead, expecting\n",
+    "      # that downstream ops are only looking for padding/not padding.\n",
+    "      return tl.AssertFunction(\n",
+    "          'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth\n",
+    "          tl.Serial(  # vec_e mask_e tok_d\n",
+    "              tl.Select([0, 1, 1, 2]),\n",
+    "              tl.Parallel(in_encoder, [], _AsTokenIDs())\n",
+    "          ))  # vec_e mask_e tok_e tok_d\n",
+    "\n",
+    "  # Assemble and return the model.\n",
+    "  return tl.Serial(\n",
+    "      _inp_layers(),  # vec_e mask_e tok_e tok_d\n",
+    "      tl.Parallel([], portal_mask),\n",
+    "\n",
+    "      tl.Select([0, 1, 2, 3, 3]),  # Copy decoder tokens for use in loss.\n",
+    "\n",
+    "      # Embed in and out tokens; done together as weights may be shared.\n",
+    "      tl.Parallel([], [], [], [tl.ShiftRight(mode=mode),\n",
+    "                               out_encoder]),  # vec_e mask_e tok_e vec_d tok_d\n",
+    "\n",
+    "      # Encode; then concat encoder and decoder, given encoder mask.\n",
+    "      _Encoder(),  # vec_e mask_e tok_e vec_d tok_d\n",
+    "      _enc_dec_concat(),\n",
+    "\n",
+    "      # Run decoder blocks.\n",
+    "      _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,\n",
+    "                              forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d\n",
+    "      _XYAvg(),  # vec_ed tok_e tok_d\n",
+    "      tl.LayerNorm(),  # vec_ed tok_e tok_d\n",
+    "\n",
+    "      # Separate out the encoder part from the concatenated vector,\n",
+    "      # then compute loss.\n",
+    "      tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d\n",
+    "      StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d\n",
+    "      _Loss(),  # vec_d tok_d\n",
+    "  )\n",
+    "\n",
+    "\n",
+    "def _InsertAxes12():\n",
+    "  \"\"\"Returns a layer that inserts two internal size-1 axes into an array.\"\"\"\n",
+    "  return tl.Fn('InsertAxes12',\n",
+    "               lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])))\n",
+    "\n",
+    "\n",
+    "def _RemoveAxes12():\n",
+    "  \"\"\"Returns a layer that removes two internal size-1 axes from an array.\"\"\"\n",
+    "  return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2)))\n",
+    "\n",
+    "\n",
+    "def _AsTokenIDs():\n",
+    "  \"\"\"Returns a layer that makes mask values look like token ID ints.\"\"\"\n",
+    "  return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32))\n",
+    "\n",
+    "\n",
+    "def _XYAvg():\n",
+    "  \"\"\"Returns a layer that computes the element-wise average of two arrays.\"\"\"\n",
+    "  return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)\n",
+    "\n",
+    "\n",
+    "def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):\n",
+    "  \"\"\"ReversibleSerial but with a forgetting block every n_layers.\"\"\"\n",
+    "  if not n_layers or len(layers) <= n_layers + 1:\n",
+    "    return tl.ReversibleSerial(layers)\n",
+    "  layers1, layers2 = layers[:n_layers], layers[n_layers:]\n",
+    "\n",
+    "  if forget_dense:\n",
+    "    forgetting_layer = tl.Serial(\n",
+    "        _XYAvg(),\n",
+    "        tl.Dense(d_model),\n",
+    "        tl.Dup(),\n",
+    "    )\n",
+    "  else:\n",
+    "    forgetting_layer = tl.Select([0, 1])\n",
+    "\n",
+    "  return tl.Serial(\n",
+    "      tl.ReversibleSerial(layers1),\n",
+    "      forgetting_layer,\n",
+    "      _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense)\n",
+    "  )\n",
+    "\n",
+    "\n",
+    "def _ConvertToNaNsOnAnyZero():\n",
+    "  def _convert_to_nans(x, y):\n",
+    "    # If all values in y are non-zero, return x; otherwise return NaNs.\n",
+    "    return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y\n",
+    "  return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2)\n",
+    "\n",
+    "\n",
+    "class _PortalInput(tl.Layer):\n",
+    "  \"\"\"Portal input for monkey-patching of mask in predict mode.\"\"\"\n",
"\n", + " def __init__(self):\n", + " super().__init__(name='_PortalInput', n_out=1, n_in=1)\n", + " self._portal_output = _PortalOutput(self)\n", + "\n", + " def forward(self, x):\n", + " if isinstance(x, (list, tuple)):\n", + " x = x[0]\n", + " self.state = (x,)\n", + " return x\n", + "\n", + " def init_weights_and_state(self, input_signature):\n", + " \"\"\"Initializes this layer's weights.\"\"\"\n", + " if isinstance(input_signature, (list, tuple)):\n", + " input_signature = input_signature[0]\n", + " self.state = (jnp.zeros(input_signature.shape),)\n", + "\n", + " def get_value(self):\n", + " return self.state[0]\n", + "\n", + " def get_layer(self):\n", + " return self._portal_output\n", + "\n", + "\n", + "class _PortalOutput(tl.Layer):\n", + " \"\"\"Portal input for monkey-patching of mask in predict mode.\"\"\"\n", + "\n", + " def __init__(self, portal_input):\n", + " super().__init__(name='_PortalOutput', n_out=1, n_in=0)\n", + " self._portal_input = portal_input\n", + "\n", + " def forward(self, x):\n", + " return self._portal_input.get_value()\n", + "\n", + " def get_value(self):\n", + " return self._portal_input.get_value()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E0Rq71ML6XZu" + }, + "source": [ + "## Example training\n", + "\n", + "Here we show how the Terraformer can be trained on example inputs. The results for the paper were obtained with identical training but for different configurations of inputs and models, which are specified in the attached config files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oI5XQcltJmeE" + }, + "outputs": [], + "source": [ + "model = Terraformer(\n", + " input_vocab_size=12,\n", + " # small model for testing\n", + " d_model=128,\n", + " d_ff=512,\n", + " n_encoder_layers=2,\n", + " n_decoder_layers=2,\n", + " # setting sparsity\n", + " ff_use_sru=(1, 32),\n", + " ff_sparsity=32,\n", + " loss_sparsity=4,\n", + " encoder_decoder_attention_type=functools.partial(\n", + " MultiplicativeConvCausalAttention, sparsity=16, length_kernel_size=3),\n", + " )\n", + "\n", + "copy_inputs = trax.data.inputs.simple_sequence_copy_inputs(\n", + " vocab_size=10, batch_size=32, train_length=32,\n", + " eval_min_length=16, eval_max_length=32)\n", + "\n", + "# Training task.\n", + "train_task = training.TrainTask(\n", + " labeled_data=copy_inputs.train_stream(1),\n", + " loss_layer=tl.WeightedCategoryCrossEntropy(),\n", + " optimizer=trax.optimizers.Adam(0.0001),\n", + " n_steps_per_checkpoint=5,\n", + ")\n", + "\n", + "# Evaluaton task.\n", + "eval_task = training.EvalTask(\n", + " labeled_data=copy_inputs.eval_stream(1),\n", + " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n", + " n_eval_batches=2 # For less variance in eval numbers.\n", + ")\n", + "\n", + "# Training loop saves checkpoints to output_dir.\n", + "output_dir = os.path.expanduser('~/output_dir/')\n", + "!rm -rf {output_dir}\n", + "training_loop = training.Loop(model,\n", + " train_task,\n", + " eval_tasks=[eval_task],\n", + " output_dir=output_dir)\n", + "\n", + "# Run 2000 steps (batches).\n", + "training_loop.run(20)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Terraformer from scratch.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1mdBTceBJGE_yff5FvRAByrisUsc88Nw7", + "timestamp": 1635190861529 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + 
"nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-7-3-Attention-Visualization.ipynb b/resources/examples/ipynb/Example-7-3-Attention-Visualization.ipynb new file mode 100644 index 000000000..f0aa28094 --- /dev/null +++ b/resources/examples/ipynb/Example-7-3-Attention-Visualization.ipynb @@ -0,0 +1,1063 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7yuytuIllsv1" + }, + "source": [ + "# Attention Visualization in Trax\n", + "\n", + "For more information see the [tenso2tensor](https://trax-ml.readthedocs.io/en/latest/) visualization colab. All js tools are taken from the tensor2tensor version along with attention processing methods. The \"viz\" mode for a Trax model used in this colab [was added to Trax](https://github.com/google/trax/commit/e9a171379ef206a3e351b67cef91fe40bf37589c) with the attention visualization in mind. The colab re-uses some parts of the [Intro to Trax](https://github.com/google/trax/blob/master/trax/intro.ipynb) colab.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BIl27504La0G" + }, + "source": [ + "**General Setup**\n", + "\n", + "Execute the following few cells (once) before running of visualization codes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": {}, + "colab_type": "code", + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2020 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "import IPython.display as display\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "both", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 466 + }, + "colab_type": "code", + "id": "vlGjGoGMTt-D", + "outputId": "28f4556b-caef-47a1-bddd-7f51ecc064d8" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# For example, if trax is inside a 'src' directory\n", + "project_root = os.environ.get('TRAX_PROJECT_ROOT', '')\n", + "sys.path.insert(0, project_root)\n", + "\n", + "# Option to verify the import path\n", + "print(f\"Python will look for packages in: {sys.path[0]}\")\n", + "\n", + "# Import trax\n", + "import trax\n", + "\n", + "# Verify the source of the imported package\n", + "print(f\"Imported trax from: {trax.__file__}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "VCBjVMrZRS6q" + }, + "outputs": [], + "source": [ + "#@title Some cool tooling for attention (make sure that you run the cell)\n", + "def resize(att_mat, max_length=None):\n", + " \"\"\"Normalize attention matrices and reshape as necessary.\"\"\"\n", + " for i, att in enumerate(att_mat):\n", + " # Add extra batch dim for viz code to work.\n", + " if att.ndim == 3:\n", + " att = np.expand_dims(att, axis=0)\n", + " if max_length is not None:\n", + " # Sum across different attention values for each token.\n", + " att = att[:, :, :max_length, :max_length]\n", + " row_sums = np.sum(att, axis=2)\n", + " # Normalize\n", + " att /= row_sums[:, :, np.newaxis]\n", + " att_mat[i] = att\n", + " return att_mat\n", + "\n", + "\n", + "def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):\n", + " \"\"\"Compute representation of the attention ready for the d3 visualization.\n", + "\n", + " Args:\n", + " inp_text: list of strings, words to be displayed on the left of the vis\n", + " out_text: list of strings, words to be displayed on the right of the vis\n", + " enc_atts: numpy array, encoder self-attentions\n", + " [num_layers, batch_size, num_heads, enc_length, enc_length]\n", + " dec_atts: numpy array, decoder self-attentions\n", + " [num_layers, batch_size, num_heads, dec_length, dec_length]\n", + " encdec_atts: numpy array, encoder-decoder attentions\n", + " [num_layers, batch_size, num_heads, dec_length, enc_length]\n", + "\n", + " Returns:\n", + " Dictionary of attention representations with the structure:\n", + " {\n", + " 'all': Representations for showing all attentions at the same time.\n", + " 'inp_inp': Representations for showing encoder self-attentions\n", + " 'inp_out': Representations for showing encoder-decoder attentions\n", + " 'out_out': Representations for showing decoder self-attentions\n", + " }\n", + " and each sub-dictionary has structure:\n", + " {\n", + " 'att': list of inter attentions matrices, one for each attention head\n", + " 'top_text': list of strings, words to be displayed on the left of the vis\n", + " 'bot_text': list of strings, words to be displayed on the right of the vis\n", + " }\n", + " \"\"\"\n", + "\n", + " def get_full_attention(layer):\n", + " \"\"\"Get the full input+output - input+output attentions.\"\"\"\n", + " enc_att = enc_atts[layer][0]\n", + " dec_att = dec_atts[layer][0]\n", + " encdec_att = 
encdec_atts[layer][0]\n", + " enc_att = np.transpose(enc_att, [0, 2, 1])\n", + " dec_att = np.transpose(dec_att, [0, 2, 1])\n", + " encdec_att = np.transpose(encdec_att, [0, 2, 1])\n", + " # [heads, query_length, memory_length]\n", + " enc_length = enc_att.shape[1]\n", + " dec_length = dec_att.shape[1]\n", + " num_heads = enc_att.shape[0]\n", + " first = np.concatenate([enc_att, encdec_att], axis=2)\n", + " second = np.concatenate(\n", + " [np.zeros((num_heads, dec_length, enc_length)), dec_att], axis=2)\n", + " full_att = np.concatenate([first, second], axis=1)\n", + " return [ha.T.tolist() for ha in full_att]\n", + "\n", + " def get_inp_inp_attention(layer):\n", + " att = np.transpose(enc_atts[layer][0], (0, 2, 1))\n", + " return [ha.T.tolist() for ha in att]\n", + "\n", + " def get_out_inp_attention(layer):\n", + " att = np.transpose(encdec_atts[layer][0], (0, 2, 1))\n", + " return [ha.T.tolist() for ha in att]\n", + "\n", + " def get_out_out_attention(layer):\n", + " att = np.transpose(dec_atts[layer][0], (0, 2, 1))\n", + " return [ha.T.tolist() for ha in att]\n", + "\n", + " def get_attentions(get_attention_fn):\n", + " num_layers = len(enc_atts)\n", + " return [get_attention_fn(i) for i in range(num_layers)]\n", + "\n", + " attentions = {\n", + " 'all': {\n", + " 'att': get_attentions(get_full_attention),\n", + " 'top_text': inp_text + out_text,\n", + " 'bot_text': inp_text + out_text,\n", + " },\n", + " 'inp_inp': {\n", + " 'att': get_attentions(get_inp_inp_attention),\n", + " 'top_text': inp_text,\n", + " 'bot_text': inp_text,\n", + " },\n", + " 'inp_out': {\n", + " 'att': get_attentions(get_out_inp_attention),\n", + " 'top_text': inp_text,\n", + " 'bot_text': out_text,\n", + " },\n", + " 'out_out': {\n", + " 'att': get_attentions(get_out_out_attention),\n", + " 'top_text': out_text,\n", + " 'bot_text': out_text,\n", + " },\n", + " }\n", + "\n", + " return attentions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "47lzWIH5THcw" + }, + "outputs": [], + "source": [ + "#@title Some cool HTML and js stuff (make sure that you run the cell)\n", + "vis_html = \"\"\"\n", + " \n", + " Layer: \n", + " Attention: \n", + " \n", + "
\n", + "\"\"\"\n", + "\n", + "\n", + "def call_html():\n", + " display.display(display.HTML('''\n", + " \n", + " \n", + " '''))\n", + "\n", + "\n", + "vis_js = \"\"\"\n", + "/**\n", + " * @fileoverview Transformer Visualization D3 javascript code.\n", + " */\n", + "\n", + "requirejs(['jquery', 'd3'],\n", + "function($, d3) {\n", + "\n", + "var attention = window.attention;\n", + "\n", + "const TEXT_SIZE = 15;\n", + "const BOXWIDTH = TEXT_SIZE * 8;\n", + "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", + "const WIDTH = 2000;\n", + "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", + "const MATRIX_WIDTH = 150;\n", + "const head_colours = d3.scale.category10();\n", + "const CHECKBOX_SIZE = 20;\n", + "\n", + "function lighten(colour) {\n", + " var c = d3.hsl(colour);\n", + " var increment = (1 - c.l) * 0.6;\n", + " c.l += increment;\n", + " c.s -= increment;\n", + " return c;\n", + "}\n", + "\n", + "function transpose(mat) {\n", + " return mat[0].map(function(col, i) {\n", + " return mat.map(function(row) {\n", + " return row[i];\n", + " });\n", + " });\n", + "}\n", + "\n", + "function zip(a, b) {\n", + " return a.map(function (e, i) {\n", + " return [e, b[i]];\n", + " });\n", + "}\n", + "\n", + "\n", + "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", + " $(id).empty();\n", + " var svg = d3.select(id)\n", + " .append('svg')\n", + " .attr(\"width\", WIDTH)\n", + " .attr(\"height\", HEIGHT);\n", + "\n", + " var att_data = [];\n", + " for (var i=0; i < attention_heads.length; i++) {\n", + " var att_trans = transpose(attention_heads[i]);\n", + " att_data.push(zip(attention_heads[i], att_trans));\n", + " }\n", + "\n", + " renderText(svg, top_text, true, att_data, 0);\n", + " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", + "\n", + " renderAttentionHighlights(svg, att_data);\n", + "\n", + " svg.append(\"g\").classed(\"attention_heads\", true);\n", + "\n", + " renderAttention(svg, attention_heads);\n", + "\n", + " draw_checkboxes(config, 0, svg, attention_heads);\n", + "}\n", + "\n", + "\n", + "function renderText(svg, text, is_top, att_data, left_pos) {\n", + " var id = is_top ? 
\"top\" : \"bottom\";\n", + " var textContainer = svg.append(\"svg:g\")\n", + " .attr(\"id\", id);\n", + "\n", + " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", + " .selectAll(\"g\")\n", + " .data(att_data)\n", + " .enter()\n", + " .append(\"g\")\n", + " .selectAll(\"rect\")\n", + " .data(function(d) {return d;})\n", + " .enter()\n", + " .append(\"rect\")\n", + " .attr(\"x\", function(d, i, j) {\n", + " return left_pos + box_offset(j);\n", + " })\n", + " .attr(\"y\", function(d, i) {\n", + " return (+1) * BOXHEIGHT;\n", + " })\n", + " .attr(\"width\", BOXWIDTH/active_heads())\n", + " .attr(\"height\", function() { return BOXHEIGHT; })\n", + " .attr(\"fill\", function(d, i, j) {\n", + " return head_colours(j);\n", + " })\n", + " .style(\"opacity\", 0.0);\n", + "\n", + "\n", + " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", + " .data(text)\n", + " .enter()\n", + " .append(\"g\");\n", + "\n", + " tokenContainer.append(\"rect\")\n", + " .classed(\"background\", true)\n", + " .style(\"opacity\", 0.0)\n", + " .attr(\"fill\", \"lightgray\")\n", + " .attr(\"x\", left_pos)\n", + " .attr(\"y\", function(d, i) {\n", + " return (i+1) * BOXHEIGHT;\n", + " })\n", + " .attr(\"width\", BOXWIDTH)\n", + " .attr(\"height\", BOXHEIGHT);\n", + "\n", + " var theText = tokenContainer.append(\"text\")\n", + " .text(function(d) { return d; })\n", + " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", + " .style(\"cursor\", \"default\")\n", + " .style(\"-webkit-user-select\", \"none\")\n", + " .attr(\"x\", left_pos)\n", + " .attr(\"y\", function(d, i) {\n", + " return (i+1) * BOXHEIGHT;\n", + " });\n", + "\n", + " if (is_top) {\n", + " theText.style(\"text-anchor\", \"end\")\n", + " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", + " .attr(\"dy\", TEXT_SIZE);\n", + " } else {\n", + " theText.style(\"text-anchor\", \"start\")\n", + " .attr(\"dx\", + TEXT_SIZE)\n", + " .attr(\"dy\", TEXT_SIZE);\n", + " }\n", + "\n", + " tokenContainer.on(\"mouseover\", function(d, index) {\n", + " textContainer.selectAll(\".background\")\n", + " .style(\"opacity\", function(d, i) {\n", + " return i == index ? 1.0 : 0.0;\n", + " });\n", + "\n", + " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", + "\n", + " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", + " .selectAll(\".att_lines\")\n", + " .attr(\"stroke-opacity\", function(d) {\n", + " return 1.0;\n", + " })\n", + " .attr(\"y1\", function(d, i) {\n", + " if (is_top) {\n", + " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " } else {\n", + " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " }\n", + " })\n", + " .attr(\"x1\", BOXWIDTH)\n", + " .attr(\"y2\", function(d, i) {\n", + " if (is_top) {\n", + " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " } else {\n", + " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", + " }\n", + " })\n", + " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", + " .attr(\"stroke-width\", 2)\n", + " .attr(\"stroke\", function(d, i, j) {\n", + " return head_colours(j);\n", + " })\n", + " .attr(\"stroke-opacity\", function(d, i, j) {\n", + " if (is_top) {d = d[0];} else {d = d[1];}\n", + " if (config.head_vis[j]) {\n", + " if (d) {\n", + " return d[index];\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " });\n", + "\n", + "\n", + " function updateAttentionBoxes() {\n", + " var id = is_top ? \"bottom\" : \"top\";\n", + " var the_left_pos = is_top ? 
MATRIX_WIDTH + BOXWIDTH : 0;\n", + " svg.select(\"#\" + id)\n", + " .selectAll(\".attention_boxes\")\n", + " .selectAll(\"g\")\n", + " .selectAll(\"rect\")\n", + " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", + " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", + " .attr(\"width\", BOXWIDTH/active_heads())\n", + " .attr(\"height\", function() { return BOXHEIGHT; })\n", + " .style(\"opacity\", function(d, i, j) {\n", + " if (is_top) {d = d[0];} else {d = d[1];}\n", + " if (config.head_vis[j])\n", + " if (d) {\n", + " return d[index];\n", + " } else {\n", + " return 0.0;\n", + " }\n", + " else\n", + " return 0.0;\n", + "\n", + " });\n", + " }\n", + "\n", + " updateAttentionBoxes();\n", + " });\n", + "\n", + " textContainer.on(\"mouseleave\", function() {\n", + " d3.select(this).selectAll(\".background\")\n", + " .style(\"opacity\", 0.0);\n", + "\n", + " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", + " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", + " svg.selectAll(\".attention_boxes\")\n", + " .selectAll(\"g\")\n", + " .selectAll(\"rect\")\n", + " .style(\"opacity\", 0.0);\n", + " });\n", + "}\n", + "\n", + "function renderAttentionHighlights(svg, attention) {\n", + " var line_container = svg.append(\"g\");\n", + " line_container.selectAll(\"g\")\n", + " .data(attention)\n", + " .enter()\n", + " .append(\"g\")\n", + " .classed(\"line_heads\", true)\n", + " .selectAll(\"line\")\n", + " .data(function(d){return d;})\n", + " .enter()\n", + " .append(\"line\").classed(\"att_lines\", true);\n", + "}\n", + "\n", + "function renderAttention(svg, attention_heads) {\n", + " var line_container = svg.selectAll(\".attention_heads\");\n", + " line_container.html(null);\n", + " for(var h=0; h\").val(i).text(i));\n", + "}\n", + "\n", + "$(\"#layer\").on('change', function(e) {\n", + " config.layer = +e.currentTarget.value;\n", + " render();\n", + "});\n", + "\n", + "$(\"#att_type\").on('change', function(e) {\n", + " config.att_type = e.currentTarget.value;\n", + " render();\n", + "});\n", + "\n", + "$(\"button\").on('click', visualize);\n", + "\n", + "visualize();\n", + "\n", + "});\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-LQ89rFFsEdk" + }, + "source": [ + "## 1. 
Run a pre-trained Transformer\n",
+    "\n",
+    "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n",
+    "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n",
+    "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n",
+    "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n",
+    "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from trax import models\n",
+    "from trax.learning.supervised import decoding\n",
+    "from trax.data.encoder import encoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "djTiSLcaNFGa",
+    "outputId": "b5ad2955-5e1d-47aa-97bb-5d72a25ed76d"
+   },
+   "outputs": [],
+   "source": [
+    "# Create a Transformer model.\n",
+    "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n",
+    "model = models.Transformer(\n",
+    "    input_vocab_size=33300,\n",
+    "    d_model=512, d_ff=2048,\n",
+    "    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n",
+    "    max_len=2048, mode='predict')\n",
+    "\n",
+    "# Initialize using pre-trained weights.\n",
+    "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n",
+    "                     weights_only=True)\n",
+    "\n",
+    "# Tokenize a sentence.\n",
+    "sentence = 'It is nice to learn new things today!'\n",
+    "tokenized = list(encoder.tokenize(iter([sentence]),  # Operates on streams.\n",
+    "                                  vocab_dir='gs://trax-ml/vocabs/',\n",
+    "                                  vocab_file='ende_32k.subword'))[0]\n",
+    "\n",
+    "# Decode from the Transformer.\n",
+    "tokenized = tokenized[None, :]  # Add batch dimension.\n",
+    "tokenized_translation = decoding.autoregressive_sample(\n",
+    "    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.\n",
+    "\n",
+    "# De-tokenize.\n",
+    "tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.\n",
+    "translation = encoder.detokenize(tokenized_translation,\n",
+    "                                 vocab_dir='gs://trax-ml/vocabs/',\n",
+    "                                 vocab_file='ende_32k.subword')\n",
+    "print(translation)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 51
+    },
+    "colab_type": "code",
+    "id": "pWDPwZfSJeD3",
+    "outputId": "050d40bf-f28d-49ea-b69a-af2886cf92a4"
+   },
+   "outputs": [],
+   "source": [
+    "tokenized, tokenized_translation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "Lu6URNjbXIHv"
+   },
+   "source": [
+    "## 2. 
Prepare the tokens for visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kqNWMpNdMg9z" + }, + "outputs": [], + "source": [ + "def decode(single_token):\n", + " return encoder.detokenize(single_token,\n", + " vocab_dir='gs://trax-ml/vocabs/',\n", + " vocab_file='ende_32k.subword')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "H2fbJB_BMeRw" + }, + "outputs": [], + "source": [ + "def get_tokens_str(integers):\n", + " token_strs = []\n", + " for i in range(integers.shape[1]):\n", + " token_strs.append(decode(integers[:, i]))\n", + " return token_strs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "YkNT8rbgKM5-" + }, + "outputs": [], + "source": [ + "tokenized_translation_with_start = np.array([0] + list(tokenized_translation), dtype=np.int64)\n", + "tokenized_translation_with_start = tokenized_translation_with_start[np.newaxis, ...]\n", + "tokenized_translation = np.array(tokenized_translation, dtype=np.int64)\n", + "tokenized_translation = tokenized_translation[np.newaxis, ...]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "r-FVdSZPKQhs" + }, + "outputs": [], + "source": [ + "tokenized_str = get_tokens_str(tokenized)\n", + "tokenized_translation_str = get_tokens_str(tokenized_translation_with_start)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 223 + }, + "colab_type": "code", + "id": "Cy7edKBuKash", + "outputId": "c1e00dbe-f467-48df-eaaf-579f68ef788f" + }, + "outputs": [], + "source": [ + "tokenized_str, tokenized_translation_str" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1XxJSqAsOTBe" + }, + "outputs": [], + "source": [ + "max_len = max(tokenized.shape[1], tokenized_translation.shape[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Qju-9pPHOV6G" + }, + "outputs": [], + "source": [ + "tokenized_translation_pad = np.zeros((1, max_len), dtype=np.int64)\n", + "tokenized_translation_pad[:, :tokenized_translation.shape[1]] = tokenized_translation\n", + "\n", + "tokenized_pad = np.zeros((1, max_len), dtype=np.int64)\n", + "tokenized_pad[:, :tokenized.shape[1]] = tokenized" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "zGxBSk0gOfYi", + "outputId": "d83328fa-eec8-4631-d2b6-4fffc3f0b933" + }, + "outputs": [], + "source": [ + "tokenized_translation_pad.shape, tokenized_pad.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WqvjmRaCXign" + }, + "source": [ + "## 3. Create the same pre-trained model in the \"viz\" mode." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "Qb2F4Pj_OLMZ"
+   },
+   "outputs": [],
+   "source": [
+    "# Create a Transformer model in the \"viz\" mode.\n",
+    "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n",
+    "model_viz = models.Transformer(\n",
+    "    input_vocab_size=33300,\n",
+    "    d_model=512, d_ff=2048,\n",
+    "    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n",
+    "    max_len=2048, mode='viz')\n",
+    "\n",
+    "# Initialize using pre-trained weights.\n",
+    "model_viz.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n",
+    "                         weights_only=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "AxcrAfprO0rD"
+   },
+   "outputs": [],
+   "source": [
+    "# We run the viz model because later we want to inspect its state.\n",
+    "_ = model_viz((tokenized_pad, tokenized_translation_pad))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "lVCYSQSuXw6f"
+   },
+   "source": [
+    "## 4. Find the attention weights (aka dots)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "dsGuqdgnO2Lf"
+   },
+   "outputs": [],
+   "source": [
+    "attention_weights = []\n",
+    "\n",
+    "\n",
+    "def attention_sublayers(layer):\n",
+    "  if 'Attention' in layer.name:\n",
+    "    print(\"Found layer {}\".format(layer.name))\n",
+    "    attention_weights.append(layer.state)\n",
+    "  if layer.sublayers:\n",
+    "    for sublayer in layer.sublayers:\n",
+    "      attention_sublayers(sublayer)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 326
+    },
+    "colab_type": "code",
+    "id": "FA3ba2-DO5l4",
+    "outputId": "f66756b1-fa86-4582-bd04-9b464ae132eb"
+   },
+   "outputs": [],
+   "source": [
+    "attention_sublayers(model_viz)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "q36-o98QO7HC",
+    "outputId": "445fe1ce-f1fa-484a-9db4-b37f56915d7c"
+   },
+   "outputs": [],
+   "source": [
+    "len(attention_weights)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "LahOE6q6PB1B"
+   },
+   "outputs": [],
+   "source": [
+    "# Manual identification of the layers would be difficult, hence we rely on\n",
+    "# the attention_sublayers function.\n",
+    "enc_atts = attention_weights[:6]\n",
+    "dec_atts = attention_weights[6::2]  # these are the DotProductCausalAttention layers\n",
+    "encdec_atts = attention_weights[7::2]  # these are the PureAttention layers starting from the 6th layer on\n",
+    "\n",
+    "# Here we use a number of python utils inherited from tensor2tensor.\n",
+    "enc_atts_res = resize(enc_atts)\n",
+    "dec_atts_res = resize(dec_atts)\n",
+    "encdec_atts_res = resize(encdec_atts)\n",
+    "attention_dict = _get_attention(tokenized_str, tokenized_translation_str, enc_atts_res, dec_atts_res, encdec_atts_res)\n",
+    "attention_json = json.dumps(attention_dict)"
+   ]
+  },
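+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (a small added cell, not part of the original flow), each resized attention tensor should have shape `[batch_size, num_heads, query_length, memory_length]`, matching the layout documented in `_get_attention` above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Print per-layer counts and one example shape for each group of tensors.\n",
+    "print('encoder self-attention:', len(enc_atts_res), np.asarray(enc_atts_res[0]).shape)\n",
+    "print('decoder self-attention:', len(dec_atts_res), np.asarray(dec_atts_res[0]).shape)\n",
+    "print('encoder-decoder attention:', len(encdec_atts_res), np.asarray(encdec_atts_res[0]).shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "1DgBBfg-X6-d"
+   },
+   "source": [
+    "## 5. 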
Display attention" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "resources": { + "http://localhost:8080/static/components/requirejs/require.js": { + "data": "/** vim: et:ts=4:sw=4:sts=4
 * @license RequireJS 2.1.22 Copyright (c) 2010-2015, The Dojo Foundation All Rights Reserved.
 * Available via the MIT or new BSD license.
 * see: http://github.com/jrburke/requirejs for details
 */
//Not using strict: uneven strict support in browsers, #392, and causes
//problems with requirejs.exec()/transpiler plugins that may not be strict.
/*jslint regexp: true, nomen: true, sloppy: true */
/*global window, navigator, document, importScripts, setTimeout, opera */

var requirejs, require, define;
(function (global) {
    var req, s, head, baseElement, dataMain, src,
        interactiveScript, currentlyAddingScript, mainScript, subPath,
        version = '2.1.22',
        commentRegExp = /(\/\*([\s\S]*?)\*\/|([^:]|^)\/\/(.*)$)/mg,
        cjsRequireRegExp = /[^.]\s*require\s*\(\s*["']([^'"\s]+)["']\s*\)/g,
        jsSuffixRegExp = /\.js$/,
        currDirRegExp = /^\.\//,
        op = Object.prototype,
        ostring = op.toString,
        hasOwn = op.hasOwnProperty,
        ap = Array.prototype,
        isBrowser = !!(typeof window !== 'undefined' && typeof navigator !== 'undefined' && window.document),
        isWebWorker = !isBrowser && typeof importScripts !== 'undefined',
        //PS3 indicates loaded and complete, but need to wait for complete
        //specifically. Sequence is 'loading', 'loaded', execution,
        // then 'complete'. The UA check is unfortunate, but not sure how
        //to feature test w/o causing perf issues.
        readyRegExp = isBrowser && navigator.platform === 'PLAYSTATION 3' ?
                      /^complete$/ : /^(complete|loaded)$/,
        defContextName = '_',
        //Oh the tragedy, detecting opera. See the usage of isOpera for reason.
        isOpera = typeof opera !== 'undefined' && opera.toString() === '[object Opera]',
        contexts = {},
        cfg = {},
        globalDefQueue = [],
        useInteractive = false;

    function isFunction(it) {
        return ostring.call(it) === '[object Function]';
    }

    function isArray(it) {
        return ostring.call(it) === '[object Array]';
    }

    /**
     * Helper function for iterating over an array. If the func returns
     * a true value, it will break out of the loop.
     */
    function each(ary, func) {
        if (ary) {
            var i;
            for (i = 0; i < ary.length; i += 1) {
                if (ary[i] && func(ary[i], i, ary)) {
                    break;
                }
            }
        }
    }

    /**
     * Helper function for iterating over an array backwards. If the func
     * returns a true value, it will break out of the loop.
     */
    function eachReverse(ary, func) {
        if (ary) {
            var i;
            for (i = ary.length - 1; i > -1; i -= 1) {
                if (ary[i] && func(ary[i], i, ary)) {
                    break;
                }
            }
        }
    }

    function hasProp(obj, prop) {
        return hasOwn.call(obj, prop);
    }

    function getOwn(obj, prop) {
        return hasProp(obj, prop) && obj[prop];
    }

    /**
     * Cycles over properties in an object and calls a function for each
     * property value. If the function returns a truthy value, then the
     * iteration is stopped.
     */
    function eachProp(obj, func) {
        var prop;
        for (prop in obj) {
            if (hasProp(obj, prop)) {
                if (func(obj[prop], prop)) {
                    break;
                }
            }
        }
    }

    /**
     * Simple function to mix in properties from source into target,
     * but only if target does not already have a property of the same name.
     */
    function mixin(target, source, force, deepStringMixin) {
        if (source) {
            eachProp(source, function (value, prop) {
                if (force || !hasProp(target, prop)) {
                    if (deepStringMixin && typeof value === 'object' && value &&
                        !isArray(value) && !isFunction(value) &&
                        !(value instanceof RegExp)) {

                        if (!target[prop]) {
                            target[prop] = {};
                        }
                        mixin(target[prop], value, force, deepStringMixin);
                    } else {
                        target[prop] = value;
                    }
                }
            });
        }
        return target;
    }

    //Similar to Function.prototype.bind, but the 'this' object is specified
    //first, since it is easier to read/figure out what 'this' will be.
    function bind(obj, fn) {
        return function () {
            return fn.apply(obj, arguments);
        };
    }

    function scripts() {
        return document.getElementsByTagName('script');
    }

    function defaultOnError(err) {
        throw err;
    }

    //Allow getting a global that is expressed in
    //dot notation, like 'a.b.c'.
    function getGlobal(value) {
        if (!value) {
            return value;
        }
        var g = global;
        each(value.split('.'), function (part) {
            g = g[part];
        });
        return g;
    }

    /**
     * Constructs an error with a pointer to an URL with more information.
     * @param {String} id the error ID that maps to an ID on a web page.
     * @param {String} message human readable error.
     * @param {Error} [err] the original error, if there is one.
     *
     * @returns {Error}
     */
    function makeError(id, msg, err, requireModules) {
        var e = new Error(msg + '\nhttp://requirejs.org/docs/errors.html#' + id);
        e.requireType = id;
        e.requireModules = requireModules;
        if (err) {
            e.originalError = err;
        }
        return e;
    }

    if (typeof define !== 'undefined') {
        //If a define is already in play via another AMD loader,
        //do not overwrite.
        return;
    }

    if (typeof requirejs !== 'undefined') {
        if (isFunction(requirejs)) {
            //Do not overwrite an existing requirejs instance.
            return;
        }
        cfg = requirejs;
        requirejs = undefined;
    }

    //Allow for a require config object
    if (typeof require !== 'undefined' && !isFunction(require)) {
        //assume it is a config object.
        cfg = require;
        require = undefined;
    }

    function newContext(contextName) {
        var inCheckLoaded, Module, context, handlers,
            checkLoadedTimeoutId,
            config = {
                //Defaults. Do not set a default for map
                //config to speed up normalize(), which
                //will run faster if there is no default.
                waitSeconds: 7,
                baseUrl: './',
                paths: {},
                bundles: {},
                pkgs: {},
                shim: {},
                config: {}
            },
            registry = {},
            //registry of just enabled modules, to speed
            //cycle breaking code when lots of modules
            //are registered, but not activated.
            enabledRegistry = {},
            undefEvents = {},
            defQueue = [],
            defined = {},
            urlFetched = {},
            bundlesMap = {},
            requireCounter = 1,
            unnormalizedCounter = 1;

        /**
         * Trims the . and .. from an array of path segments.
         * It will keep a leading path segment if a .. will become
         * the first path segment, to help with module name lookups,
         * which act like paths, but can be remapped. But the end result,
         * all paths that use this function should look normalized.
         * NOTE: this method MODIFIES the input array.
         * @param {Array} ary the array of path segments.
         */
        function trimDots(ary) {
            var i, part;
            for (i = 0; i < ary.length; i++) {
                part = ary[i];
                if (part === '.') {
                    ary.splice(i, 1);
                    i -= 1;
                } else if (part === '..') {
                    // If at the start, or previous value is still ..,
                    // keep them so that when converted to a path it may
                    // still work when converted to a path, even though
                    // as an ID it is less than ideal. In larger point
                    // releases, may be better to just kick out an error.
                    if (i === 0 || (i === 1 && ary[2] === '..') || ary[i - 1] === '..') {
                        continue;
                    } else if (i > 0) {
                        ary.splice(i - 1, 2);
                        i -= 2;
                    }
                }
            }
        }

        /**
         * Given a relative module name, like ./something, normalize it to
         * a real name that can be mapped to a path.
         * @param {String} name the relative name
         * @param {String} baseName a real name that the name arg is relative
         * to.
         * @param {Boolean} applyMap apply the map config to the value. Should
         * only be done if this normalization is for a dependency ID.
         * @returns {String} normalized name
         */
        function normalize(name, baseName, applyMap) {
            var pkgMain, mapValue, nameParts, i, j, nameSegment, lastIndex,
                foundMap, foundI, foundStarMap, starI, normalizedBaseParts,
                baseParts = (baseName && baseName.split('/')),
                map = config.map,
                starMap = map && map['*'];

            //Adjust any relative paths.
            if (name) {
                name = name.split('/');
                lastIndex = name.length - 1;

                // If wanting node ID compatibility, strip .js from end
                // of IDs. Have to do this here, and not in nameToUrl
                // because node allows either .js or non .js to map
                // to same file.
                if (config.nodeIdCompat && jsSuffixRegExp.test(name[lastIndex])) {
                    name[lastIndex] = name[lastIndex].replace(jsSuffixRegExp, '');
                }

                // Starts with a '.' so need the baseName
                if (name[0].charAt(0) === '.' && baseParts) {
                    //Convert baseName to array, and lop off the last part,
                    //so that . matches that 'directory' and not name of the baseName's
                    //module. For instance, baseName of 'one/two/three', maps to
                    //'one/two/three.js', but we want the directory, 'one/two' for
                    //this normalization.
                    normalizedBaseParts = baseParts.slice(0, baseParts.length - 1);
                    name = normalizedBaseParts.concat(name);
                }

                trimDots(name);
                name = name.join('/');
            }

            //Apply map config if available.
            if (applyMap && map && (baseParts || starMap)) {
                nameParts = name.split('/');

                outerLoop: for (i = nameParts.length; i > 0; i -= 1) {
                    nameSegment = nameParts.slice(0, i).join('/');

                    if (baseParts) {
                        //Find the longest baseName segment match in the config.
                        //So, do joins on the biggest to smallest lengths of baseParts.
                        for (j = baseParts.length; j > 0; j -= 1) {
                            mapValue = getOwn(map, baseParts.slice(0, j).join('/'));

                            //baseName segment has config, find if it has one for
                            //this name.
                            if (mapValue) {
                                mapValue = getOwn(mapValue, nameSegment);
                                if (mapValue) {
                                    //Match, update name to the new value.
                                    foundMap = mapValue;
                                    foundI = i;
                                    break outerLoop;
                                }
                            }
                        }
                    }

                    //Check for a star map match, but just hold on to it,
                    //if there is a shorter segment match later in a matching
                    //config, then favor over this star map.
                    if (!foundStarMap && starMap && getOwn(starMap, nameSegment)) {
                        foundStarMap = getOwn(starMap, nameSegment);
                        starI = i;
                    }
                }

                if (!foundMap && foundStarMap) {
                    foundMap = foundStarMap;
                    foundI = starI;
                }

                if (foundMap) {
                    nameParts.splice(0, foundI, foundMap);
                    name = nameParts.join('/');
                }
            }

            // If the name points to a package's name, use
            // the package main instead.
            pkgMain = getOwn(config.pkgs, name);

            return pkgMain ? pkgMain : name;
        }

        function removeScript(name) {
            if (isBrowser) {
                each(scripts(), function (scriptNode) {
                    if (scriptNode.getAttribute('data-requiremodule') === name &&
                            scriptNode.getAttribute('data-requirecontext') === context.contextName) {
                        scriptNode.parentNode.removeChild(scriptNode);
                        return true;
                    }
                });
            }
        }

        function hasPathFallback(id) {
            var pathConfig = getOwn(config.paths, id);
            if (pathConfig && isArray(pathConfig) && pathConfig.length > 1) {
                //Pop off the first array value, since it failed, and
                //retry
                pathConfig.shift();
                context.require.undef(id);

                //Custom require that does not do map translation, since
                //ID is "absolute", already mapped/resolved.
                context.makeRequire(null, {
                    skipMap: true
                })([id]);

                return true;
            }
        }

        //Turns a plugin!resource to [plugin, resource]
        //with the plugin being undefined if the name
        //did not have a plugin prefix.
        function splitPrefix(name) {
            var prefix,
                index = name ? name.indexOf('!') : -1;
            if (index > -1) {
                prefix = name.substring(0, index);
                name = name.substring(index + 1, name.length);
            }
            return [prefix, name];
        }

        /**
         * Creates a module mapping that includes plugin prefix, module
         * name, and path. If parentModuleMap is provided it will
         * also normalize the name via require.normalize()
         *
         * @param {String} name the module name
         * @param {String} [parentModuleMap] parent module map
         * for the module name, used to resolve relative names.
         * @param {Boolean} isNormalized: is the ID already normalized.
         * This is true if this call is done for a define() module ID.
         * @param {Boolean} applyMap: apply the map config to the ID.
         * Should only be true if this map is for a dependency.
         *
         * @returns {Object}
         */
        function makeModuleMap(name, parentModuleMap, isNormalized, applyMap) {
            var url, pluginModule, suffix, nameParts,
                prefix = null,
                parentName = parentModuleMap ? parentModuleMap.name : null,
                originalName = name,
                isDefine = true,
                normalizedName = '';

            //If no name, then it means it is a require call, generate an
            //internal name.
            if (!name) {
                isDefine = false;
                name = '_@r' + (requireCounter += 1);
            }

            nameParts = splitPrefix(name);
            prefix = nameParts[0];
            name = nameParts[1];

            if (prefix) {
                prefix = normalize(prefix, parentName, applyMap);
                pluginModule = getOwn(defined, prefix);
            }

            //Account for relative paths if there is a base name.
            if (name) {
                if (prefix) {
                    if (pluginModule && pluginModule.normalize) {
                        //Plugin is loaded, use its normalize method.
                        normalizedName = pluginModule.normalize(name, function (name) {
                            return normalize(name, parentName, applyMap);
                        });
                    } else {
                        // If nested plugin references, then do not try to
                        // normalize, as it will not normalize correctly. This
                        // places a restriction on resourceIds, and the longer
                        // term solution is not to normalize until plugins are
                        // loaded and all normalizations to allow for async
                        // loading of a loader plugin. But for now, fixes the
                        // common uses. Details in #1131
                        normalizedName = name.indexOf('!') === -1 ?
                                         normalize(name, parentName, applyMap) :
                                         name;
                    }
                } else {
                    //A regular module.
                    normalizedName = normalize(name, parentName, applyMap);

                    //Normalized name may be a plugin ID due to map config
                    //application in normalize. The map config values must
                    //already be normalized, so do not need to redo that part.
                    nameParts = splitPrefix(normalizedName);
                    prefix = nameParts[0];
                    normalizedName = nameParts[1];
                    isNormalized = true;

                    url = context.nameToUrl(normalizedName);
                }
            }

            //If the id is a plugin id that cannot be determined if it needs
            //normalization, stamp it with a unique ID so two matching relative
            //ids that may conflict can be separate.
            suffix = prefix && !pluginModule && !isNormalized ?
                     '_unnormalized' + (unnormalizedCounter += 1) :
                     '';

            return {
                prefix: prefix,
                name: normalizedName,
                parentMap: parentModuleMap,
                unnormalized: !!suffix,
                url: url,
                originalName: originalName,
                isDefine: isDefine,
                id: (prefix ?
                        prefix + '!' + normalizedName :
                        normalizedName) + suffix
            };
        }

        function getModule(depMap) {
            var id = depMap.id,
                mod = getOwn(registry, id);

            if (!mod) {
                mod = registry[id] = new context.Module(depMap);
            }

            return mod;
        }

        function on(depMap, name, fn) {
            var id = depMap.id,
                mod = getOwn(registry, id);

            if (hasProp(defined, id) &&
                    (!mod || mod.defineEmitComplete)) {
                if (name === 'defined') {
                    fn(defined[id]);
                }
            } else {
                mod = getModule(depMap);
                if (mod.error && name === 'error') {
                    fn(mod.error);
                } else {
                    mod.on(name, fn);
                }
            }
        }

        function onError(err, errback) {
            var ids = err.requireModules,
                notified = false;

            if (errback) {
                errback(err);
            } else {
                each(ids, function (id) {
                    var mod = getOwn(registry, id);
                    if (mod) {
                        //Set error on module, so it skips timeout checks.
                        mod.error = err;
                        if (mod.events.error) {
                            notified = true;
                            mod.emit('error', err);
                        }
                    }
                });

                if (!notified) {
                    req.onError(err);
                }
            }
        }

        /**
         * Internal method to transfer globalQueue items to this context's
         * defQueue.
         */
        function takeGlobalQueue() {
            //Push all the globalDefQueue items into the context's defQueue
            if (globalDefQueue.length) {
                each(globalDefQueue, function(queueItem) {
                    var id = queueItem[0];
                    if (typeof id === 'string') {
                        context.defQueueMap[id] = true;
                    }
                    defQueue.push(queueItem);
                });
                globalDefQueue = [];
            }
        }

        handlers = {
            'require': function (mod) {
                if (mod.require) {
                    return mod.require;
                } else {
                    return (mod.require = context.makeRequire(mod.map));
                }
            },
            'exports': function (mod) {
                mod.usingExports = true;
                if (mod.map.isDefine) {
                    if (mod.exports) {
                        return (defined[mod.map.id] = mod.exports);
                    } else {
                        return (mod.exports = defined[mod.map.id] = {});
                    }
                }
            },
            'module': function (mod) {
                if (mod.module) {
                    return mod.module;
                } else {
                    return (mod.module = {
                        id: mod.map.id,
                        uri: mod.map.url,
                        config: function () {
                            return getOwn(config.config, mod.map.id) || {};
                        },
                        exports: mod.exports || (mod.exports = {})
                    });
                }
            }
        };
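
        //Example (illustrative): the handlers above back the three special
        //dependency names, which is what lets a CommonJS-style module work:
        //  define(['require', 'exports', 'module'], function (require, exports, module) {
        //      exports.answer = 42;
        //  });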

        function cleanRegistry(id) {
            //Clean up machinery used for waiting modules.
            delete registry[id];
            delete enabledRegistry[id];
        }

        function breakCycle(mod, traced, processed) {
            var id = mod.map.id;

            if (mod.error) {
                mod.emit('error', mod.error);
            } else {
                traced[id] = true;
                each(mod.depMaps, function (depMap, i) {
                    var depId = depMap.id,
                        dep = getOwn(registry, depId);

                    //Only force things that have not completed
                    //being defined (so are still in the registry),
                    //and only if they have not already been matched up
                    //in the module.
                    if (dep && !mod.depMatched[i] && !processed[depId]) {
                        if (getOwn(traced, depId)) {
                            mod.defineDep(i, defined[depId]);
                            mod.check(); //pass false?
                        } else {
                            breakCycle(dep, traced, processed);
                        }
                    }
                });
                processed[id] = true;
            }
        }
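
        //Example (illustrative, hypothetical ids): if 'a' depends on 'b' and
        //'b' depends on 'a', neither depCount reaches zero on its own.
        //breakCycle() walks the dependency trace, hands the waiting module the
        //current (possibly partial) export via defineDep(), and re-runs
        //check() so the cycle can complete.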

        function checkLoaded() {
            var err, usingPathFallback,
                waitInterval = config.waitSeconds * 1000,
                //It is possible to disable the wait interval by using waitSeconds of 0.
                expired = waitInterval && (context.startTime + waitInterval) < new Date().getTime(),
                noLoads = [],
                reqCalls = [],
                stillLoading = false,
                needCycleCheck = true;

            //Do not bother if this call was a result of a cycle break.
            if (inCheckLoaded) {
                return;
            }

            inCheckLoaded = true;

            //Figure out the state of all the modules.
            eachProp(enabledRegistry, function (mod) {
                var map = mod.map,
                    modId = map.id;

                //Skip things that are not enabled or in error state.
                if (!mod.enabled) {
                    return;
                }

                if (!map.isDefine) {
                    reqCalls.push(mod);
                }

                if (!mod.error) {
                    //If the module should be executed, and it has not
                    //been inited and time is up, remember it.
                    if (!mod.inited && expired) {
                        if (hasPathFallback(modId)) {
                            usingPathFallback = true;
                            stillLoading = true;
                        } else {
                            noLoads.push(modId);
                            removeScript(modId);
                        }
                    } else if (!mod.inited && mod.fetched && map.isDefine) {
                        stillLoading = true;
                        if (!map.prefix) {
                            //No reason to keep looking for unfinished
                            //loading. If the only stillLoading is a
                            //plugin resource though, keep going,
                            //because it may be that a plugin resource
                            //is waiting on a non-plugin cycle.
                            return (needCycleCheck = false);
                        }
                    }
                }
            });

            if (expired && noLoads.length) {
                //If wait time expired, throw error of unloaded modules.
                err = makeError('timeout', 'Load timeout for modules: ' + noLoads, null, noLoads);
                err.contextName = context.contextName;
                return onError(err);
            }

            //Not expired, check for a cycle.
            if (needCycleCheck) {
                each(reqCalls, function (mod) {
                    breakCycle(mod, {}, {});
                });
            }

            //If still waiting on loads, and the waiting load is something
            //other than a plugin resource, or there are still outstanding
            //scripts, then just try back later.
            if ((!expired || usingPathFallback) && stillLoading) {
                //Something is still waiting to load. Wait for it, but only
                //if a timeout is not already in effect.
                if ((isBrowser || isWebWorker) && !checkLoadedTimeoutId) {
                    checkLoadedTimeoutId = setTimeout(function () {
                        checkLoadedTimeoutId = 0;
                        checkLoaded();
                    }, 50);
                }
            }

            inCheckLoaded = false;
        }

        Module = function (map) {
            this.events = getOwn(undefEvents, map.id) || {};
            this.map = map;
            this.shim = getOwn(config.shim, map.id);
            this.depExports = [];
            this.depMaps = [];
            this.depMatched = [];
            this.pluginMaps = {};
            this.depCount = 0;

            /* this.exports this.factory
               this.depMaps = [],
               this.enabled, this.fetched
            */
        };

        Module.prototype = {
            init: function (depMaps, factory, errback, options) {
                options = options || {};

                //Do not do more inits if already done. Can happen if there
                //are multiple define calls for the same module. That is not
                //a normal, common case, but it is also not unexpected.
                if (this.inited) {
                    return;
                }

                this.factory = factory;

                if (errback) {
                    //Register for errors on this module.
                    this.on('error', errback);
                } else if (this.events.error) {
                    //If no errback already, but there are error listeners
                    //on this module, set up an errback to pass to the deps.
                    errback = bind(this, function (err) {
                        this.emit('error', err);
                    });
                }

                //Do a copy of the dependency array, so that
                //source inputs are not modified. For example
                //"shim" deps are passed in here directly, and
                //doing a direct modification of the depMaps array
                //would affect that config.
                this.depMaps = depMaps && depMaps.slice(0);

                this.errback = errback;

                //Indicate this module has been initialized
                this.inited = true;

                this.ignore = options.ignore;

                //Could have option to init this module in enabled mode,
                //or could have been previously marked as enabled. However,
                //the dependencies are not known until init is called. So
                //if enabled previously, now trigger dependencies as enabled.
                if (options.enabled || this.enabled) {
                    //Enable this module and dependencies.
                    //Will call this.check()
                    this.enable();
                } else {
                    this.check();
                }
            },

            defineDep: function (i, depExports) {
                //Because of cycles, the defined callback for a given
                //export can be called more than once.
                if (!this.depMatched[i]) {
                    this.depMatched[i] = true;
                    this.depCount -= 1;
                    this.depExports[i] = depExports;
                }
            },

            fetch: function () {
                if (this.fetched) {
                    return;
                }
                this.fetched = true;

                context.startTime = (new Date()).getTime();

                var map = this.map;

                //If the manager is for a plugin managed resource,
                //ask the plugin to load it now.
                if (this.shim) {
                    context.makeRequire(this.map, {
                        enableBuildCallback: true
                    })(this.shim.deps || [], bind(this, function () {
                        return map.prefix ? this.callPlugin() : this.load();
                    }));
                } else {
                    //Regular dependency.
                    return map.prefix ? this.callPlugin() : this.load();
                }
            },

            load: function () {
                var url = this.map.url;

                //Regular dependency.
                if (!urlFetched[url]) {
                    urlFetched[url] = true;
                    context.load(this.map.id, url);
                }
            },

            /**
             * Checks if the module is ready to define itself, and if so,
             * defines it.
             */
            check: function () {
                if (!this.enabled || this.enabling) {
                    return;
                }

                var err, cjsModule,
                    id = this.map.id,
                    depExports = this.depExports,
                    exports = this.exports,
                    factory = this.factory;

                if (!this.inited) {
                    // Only fetch if not already in the defQueue.
                    if (!hasProp(context.defQueueMap, id)) {
                        this.fetch();
                    }
                } else if (this.error) {
                    this.emit('error', this.error);
                } else if (!this.defining) {
                    //The factory could trigger another require call
                    //that would result in checking this module to
                    //define itself again. If already in the process
                    //of doing that, skip this work.
                    this.defining = true;

                    if (this.depCount < 1 && !this.defined) {
                        if (isFunction(factory)) {
                            try {
                                exports = context.execCb(id, factory, depExports, exports);
                            } catch (e) {
                                err = e;
                            }

                            // Favor return value over exports. If node/cjs in play,
                            // then will not have a return value anyway. Favor
                            // module.exports assignment over exports object.
                            if (this.map.isDefine && exports === undefined) {
                                cjsModule = this.module;
                                if (cjsModule) {
                                    exports = cjsModule.exports;
                                } else if (this.usingExports) {
                                    //exports already set the defined value.
                                    exports = this.exports;
                                }
                            }

                            if (err) {
                                // If there is an error listener, favor passing
                                // to that instead of throwing an error. However,
                                // only do it for define()'d modules. require
                                // errbacks should not be called for failures in
                                // their callbacks (#699). However if a global
                                // onError is set, use that.
                                if ((this.events.error && this.map.isDefine) ||
                                    req.onError !== defaultOnError) {
                                    err.requireMap = this.map;
                                    err.requireModules = this.map.isDefine ? [this.map.id] : null;
                                    err.requireType = this.map.isDefine ? 'define' : 'require';
                                    return onError((this.error = err));
                                } else if (typeof console !== 'undefined' &&
                                           console.error) {
                                    // Log the error for debugging. If promises could be
                                    // used, this would be different, but making do.
                                    console.error(err);
                                } else {
                                    // Do not want to completely lose the error. While this
                                    // will mess up processing and lead to similar results
                                    // as bug 1440, it at least surfaces the error.
                                    req.onError(err);
                                }
                            }
                        } else {
                            //Just a literal value
                            exports = factory;
                        }

                        this.exports = exports;

                        if (this.map.isDefine && !this.ignore) {
                            defined[id] = exports;

                            if (req.onResourceLoad) {
                                var resLoadMaps = [];
                                each(this.depMaps, function (depMap) {
                                    resLoadMaps.push(depMap.normalizedMap || depMap);
                                });
                                req.onResourceLoad(context, this.map, resLoadMaps);
                            }
                        }

                        //Clean up
                        cleanRegistry(id);

                        this.defined = true;
                    }

                    //Finished the define stage. Allow calling check again
                    //to allow define notifications below in the case of a
                    //cycle.
                    this.defining = false;

                    if (this.defined && !this.defineEmitted) {
                        this.defineEmitted = true;
                        this.emit('defined', this.exports);
                        this.defineEmitComplete = true;
                    }

                }
            },

            callPlugin: function () {
                var map = this.map,
                    id = map.id,
                    //Map already normalized the prefix.
                    pluginMap = makeModuleMap(map.prefix);

                //Mark this as a dependency for this plugin, so it
                //can be traced for cycles.
                this.depMaps.push(pluginMap);

                on(pluginMap, 'defined', bind(this, function (plugin) {
                    var load, normalizedMap, normalizedMod,
                        bundleId = getOwn(bundlesMap, this.map.id),
                        name = this.map.name,
                        parentName = this.map.parentMap ? this.map.parentMap.name : null,
                        localRequire = context.makeRequire(map.parentMap, {
                            enableBuildCallback: true
                        });

                    //If current map is not normalized, wait for that
                    //normalized name to load instead of continuing.
                    if (this.map.unnormalized) {
                        //Normalize the ID if the plugin allows it.
                        if (plugin.normalize) {
                            name = plugin.normalize(name, function (name) {
                                return normalize(name, parentName, true);
                            }) || '';
                        }

                        //prefix and name should already be normalized, no need
                        //for applying map config again either.
                        normalizedMap = makeModuleMap(map.prefix + '!' + name,
                                                      this.map.parentMap);
                        on(normalizedMap,
                            'defined', bind(this, function (value) {
                                this.map.normalizedMap = normalizedMap;
                                this.init([], function () { return value; }, null, {
                                    enabled: true,
                                    ignore: true
                                });
                            }));

                        normalizedMod = getOwn(registry, normalizedMap.id);
                        if (normalizedMod) {
                            //Mark this as a dependency for this plugin, so it
                            //can be traced for cycles.
                            this.depMaps.push(normalizedMap);

                            if (this.events.error) {
                                normalizedMod.on('error', bind(this, function (err) {
                                    this.emit('error', err);
                                }));
                            }
                            normalizedMod.enable();
                        }

                        return;
                    }

                    //If the resource is in a bundles config, just load that
                    //bundle file instead to resolve the plugin, as it is built
                    //into that bundle layer.
                    if (bundleId) {
                        this.map.url = context.nameToUrl(bundleId);
                        this.load();
                        return;
                    }

                    load = bind(this, function (value) {
                        this.init([], function () { return value; }, null, {
                            enabled: true
                        });
                    });

                    load.error = bind(this, function (err) {
                        this.inited = true;
                        this.error = err;
                        err.requireModules = [id];

                        //Remove temp unnormalized modules for this module,
                        //since they will never be resolved otherwise now.
                        eachProp(registry, function (mod) {
                            if (mod.map.id.indexOf(id + '_unnormalized') === 0) {
                                cleanRegistry(mod.map.id);
                            }
                        });

                        onError(err);
                    });

                    //Allow plugins to load other code without having to know the
                    //context or how to 'complete' the load.
                    load.fromText = bind(this, function (text, textAlt) {
                        /*jslint evil: true */
                        var moduleName = map.name,
                            moduleMap = makeModuleMap(moduleName),
                            hasInteractive = useInteractive;

                        //As of 2.1.0, support just passing the text, to reinforce
                        //fromText only being called once per resource. Still
                        //support old style of passing moduleName but discard
                        //that moduleName in favor of the internal ref.
                        if (textAlt) {
                            text = textAlt;
                        }

                        //Turn off interactive script matching for IE for any define
                        //calls in the text, then turn it back on at the end.
                        if (hasInteractive) {
                            useInteractive = false;
                        }

                        //Prime the system by creating a module instance for
                        //it.
                        getModule(moduleMap);

                        //Transfer any config to this other module.
                        if (hasProp(config.config, id)) {
                            config.config[moduleName] = config.config[id];
                        }

                        try {
                            req.exec(text);
                        } catch (e) {
                            return onError(makeError('fromtexteval',
                                             'fromText eval for ' + id +
                                            ' failed: ' + e,
                                             e,
                                             [id]));
                        }

                        if (hasInteractive) {
                            useInteractive = true;
                        }

                        //Mark this as a dependency for the plugin
                        //resource
                        this.depMaps.push(moduleMap);

                        //Support anonymous modules.
                        context.completeLoad(moduleName);

                        //Bind the value of that module to the value for this
                        //resource ID.
                        localRequire([moduleName], load);
                    });

                    //Use parentName here since the plugin's name is not reliable,
                    //could be some weird string with no path that actually wants to
                    //reference the parentName's path.
                    plugin.load(map.name, localRequire, load, config);
                }));

                context.enable(pluginMap, this);
                this.pluginMaps[pluginMap.id] = pluginMap;
            },
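
            //Example (illustrative): callPlugin() above drives loader plugins
            //that implement the AMD plugin API; a minimal plugin only needs
            //load():
            //  define({
            //      load: function (name, parentRequire, onLoad, config) {
            //          parentRequire([name], onLoad);
            //      }
            //  });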

            enable: function () {
                enabledRegistry[this.map.id] = this;
                this.enabled = true;

                //Set a flag mentioning that the module is enabling,
                //so that immediate calls to the defined callbacks
                //for dependencies do not inadvertently trigger a load
                //while the depCount is still zero.
                this.enabling = true;

                //Enable each dependency
                each(this.depMaps, bind(this, function (depMap, i) {
                    var id, mod, handler;

                    if (typeof depMap === 'string') {
                        //Dependency needs to be converted to a depMap
                        //and wired up to this module.
                        depMap = makeModuleMap(depMap,
                                               (this.map.isDefine ? this.map : this.map.parentMap),
                                               false,
                                               !this.skipMap);
                        this.depMaps[i] = depMap;

                        handler = getOwn(handlers, depMap.id);

                        if (handler) {
                            this.depExports[i] = handler(this);
                            return;
                        }

                        this.depCount += 1;

                        on(depMap, 'defined', bind(this, function (depExports) {
                            if (this.undefed) {
                                return;
                            }
                            this.defineDep(i, depExports);
                            this.check();
                        }));

                        if (this.errback) {
                            on(depMap, 'error', bind(this, this.errback));
                        } else if (this.events.error) {
                            // No direct errback on this module, but something
                            // else is listening for errors, so be sure to
                            // propagate the error correctly.
                            on(depMap, 'error', bind(this, function(err) {
                                this.emit('error', err);
                            }));
                        }
                    }

                    id = depMap.id;
                    mod = registry[id];

                    //Skip special modules like 'require', 'exports', 'module'
                    //Also, don't call enable if it is already enabled,
                    //important in circular dependency cases.
                    if (!hasProp(handlers, id) && mod && !mod.enabled) {
                        context.enable(depMap, this);
                    }
                }));

                //Enable each plugin that is used in
                //a dependency
                eachProp(this.pluginMaps, bind(this, function (pluginMap) {
                    var mod = getOwn(registry, pluginMap.id);
                    if (mod && !mod.enabled) {
                        context.enable(pluginMap, this);
                    }
                }));

                this.enabling = false;

                this.check();
            },

            on: function (name, cb) {
                var cbs = this.events[name];
                if (!cbs) {
                    cbs = this.events[name] = [];
                }
                cbs.push(cb);
            },

            emit: function (name, evt) {
                each(this.events[name], function (cb) {
                    cb(evt);
                });
                if (name === 'error') {
                    //Now that the error handler was triggered, remove
                    //the listeners, since this broken Module instance
                    //can stay around for a while in the registry.
                    delete this.events[name];
                }
            }
        };

        function callGetModule(args) {
            //Skip modules already defined.
            if (!hasProp(defined, args[0])) {
                getModule(makeModuleMap(args[0], null, true)).init(args[1], args[2]);
            }
        }

        function removeListener(node, func, name, ieName) {
            //Favor detachEvent because of IE9
            //issue, see attachEvent/addEventListener comment elsewhere
            //in this file.
            if (node.detachEvent && !isOpera) {
                //Probably IE. If not, it will throw an error, which will be
                //useful to know.
                if (ieName) {
                    node.detachEvent(ieName, func);
                }
            } else {
                node.removeEventListener(name, func, false);
            }
        }

        /**
         * Given an event from a script node, gets the requirejs info from it,
         * and then removes the event listeners on the node.
         * @param {Event} evt
         * @returns {Object}
         */
        function getScriptData(evt) {
            //Using currentTarget instead of target for Firefox 2.0's sake. Not
            //all old browsers will be supported, but this one was easy enough
            //to support and still makes sense.
            var node = evt.currentTarget || evt.srcElement;

            //Remove the listeners once here.
            removeListener(node, context.onScriptLoad, 'load', 'onreadystatechange');
            removeListener(node, context.onScriptError, 'error');

            return {
                node: node,
                id: node && node.getAttribute('data-requiremodule')
            };
        }

        function intakeDefines() {
            var args;

            //Any defined modules in the global queue, intake them now.
            takeGlobalQueue();

            //Make sure any remaining defQueue items get properly processed.
            while (defQueue.length) {
                args = defQueue.shift();
                if (args[0] === null) {
                    return onError(makeError('mismatch', 'Mismatched anonymous define() module: ' +
                        args[args.length - 1]));
                } else {
                    //args are id, deps, factory. Should be normalized by the
                    //define() function.
                    callGetModule(args);
                }
            }
            context.defQueueMap = {};
        }

        context = {
            config: config,
            contextName: contextName,
            registry: registry,
            defined: defined,
            urlFetched: urlFetched,
            defQueue: defQueue,
            defQueueMap: {},
            Module: Module,
            makeModuleMap: makeModuleMap,
            nextTick: req.nextTick,
            onError: onError,

            /**
             * Set a configuration for the context.
             * @param {Object} cfg config object to integrate.
             */
            configure: function (cfg) {
                //Make sure the baseUrl ends in a slash.
                if (cfg.baseUrl) {
                    if (cfg.baseUrl.charAt(cfg.baseUrl.length - 1) !== '/') {
                        cfg.baseUrl += '/';
                    }
                }

                //Save off the paths since they require special processing,
                //they are additive.
                var shim = config.shim,
                    objs = {
                        paths: true,
                        bundles: true,
                        config: true,
                        map: true
                    };

                eachProp(cfg, function (value, prop) {
                    if (objs[prop]) {
                        if (!config[prop]) {
                            config[prop] = {};
                        }
                        mixin(config[prop], value, true, true);
                    } else {
                        config[prop] = value;
                    }
                });

                //Reverse map the bundles
                if (cfg.bundles) {
                    eachProp(cfg.bundles, function (value, prop) {
                        each(value, function (v) {
                            if (v !== prop) {
                                bundlesMap[v] = prop;
                            }
                        });
                    });
                }

                //Merge shim
                if (cfg.shim) {
                    eachProp(cfg.shim, function (value, id) {
                        //Normalize the structure
                        if (isArray(value)) {
                            value = {
                                deps: value
                            };
                        }
                        if ((value.exports || value.init) && !value.exportsFn) {
                            value.exportsFn = context.makeShimExports(value);
                        }
                        shim[id] = value;
                    });
                    config.shim = shim;
                }

                //Adjust packages if necessary.
                if (cfg.packages) {
                    each(cfg.packages, function (pkgObj) {
                        var location, name;

                        pkgObj = typeof pkgObj === 'string' ? {name: pkgObj} : pkgObj;

                        name = pkgObj.name;
                        location = pkgObj.location;
                        if (location) {
                            config.paths[name] = pkgObj.location;
                        }

                        //Save pointer to main module ID for pkg name.
                        //Remove leading dot in main, so main paths are normalized,
                        //and remove any trailing .js, since different package
                        //envs have different conventions: some use a module name,
                        //some use a file name.
                        config.pkgs[name] = pkgObj.name + '/' + (pkgObj.main || 'main')
                                     .replace(currDirRegExp, '')
                                     .replace(jsSuffixRegExp, '');
                    });
                }

                //If there are any "waiting to execute" modules in the registry,
                //update the maps for them, since their info, like URLs to load,
                //may have changed.
                eachProp(registry, function (mod, id) {
                    //Skip modules that have already had init called, since it
                    //is too late to modify them, and ignore unnormalized ones
                    //since they are transient.
                    if (!mod.inited && !mod.map.unnormalized) {
                        mod.map = makeModuleMap(id, null, true);
                    }
                });

                //If a deps array or a config callback is specified, then call
                //require with those args. This is useful when require is defined as a
                //config object before require.js is loaded.
                if (cfg.deps || cfg.callback) {
                    context.require(cfg.deps || [], cfg.callback);
                }
            },
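
            //Example (illustrative, hypothetical values): a typical
            //configure() call; 'paths', 'bundles', 'config' and 'map' merge
            //additively as above, other props are simply replaced:
            //  context.configure({
            //      baseUrl: 'js/lib',
            //      paths: { jquery: 'jquery-3.6.0' },
            //      waitSeconds: 15
            //  });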

            makeShimExports: function (value) {
                function fn() {
                    var ret;
                    if (value.init) {
                        ret = value.init.apply(global, arguments);
                    }
                    return ret || (value.exports && getGlobal(value.exports));
                }
                return fn;
            },
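
            //Example (illustrative): a shim config entry that exercises the
            //exportsFn built above:
            //  shim: { backbone: { deps: ['underscore', 'jquery'],
            //                      exports: 'Backbone' } }
            //When init is also given, it runs with the global object as
            //'this' and its return value, if any, wins over the 'exports'
            //global lookup.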

            makeRequire: function (relMap, options) {
                options = options || {};

                function localRequire(deps, callback, errback) {
                    var id, map, requireMod;

                    if (options.enableBuildCallback && callback && isFunction(callback)) {
                        callback.__requireJsBuild = true;
                    }

                    if (typeof deps === 'string') {
                        if (isFunction(callback)) {
                            //Invalid call
                            return onError(makeError('requireargs', 'Invalid require call'), errback);
                        }

                        //If require|exports|module are requested, get the
                        //value for them from the special handlers. Caveat:
                        //this only works while module is being defined.
                        if (relMap && hasProp(handlers, deps)) {
                            return handlers[deps](registry[relMap.id]);
                        }

                        //Synchronous access to one module. If require.get is
                        //available (as in the Node adapter), prefer that.
                        if (req.get) {
                            return req.get(context, deps, relMap, localRequire);
                        }

                        //Normalize module name, if it contains . or ..
                        map = makeModuleMap(deps, relMap, false, true);
                        id = map.id;

                        if (!hasProp(defined, id)) {
                            return onError(makeError('notloaded', 'Module name "' +
                                        id +
                                        '" has not been loaded yet for context: ' +
                                        contextName +
                                        (relMap ? '' : '. Use require([])')));
                        }
                        return defined[id];
                    }

                    //Grab defines waiting in the global queue.
                    intakeDefines();

                    //Mark all the dependencies as needing to be loaded.
                    context.nextTick(function () {
                        //Some defines could have been added since the
                        //require call, collect them.
                        intakeDefines();

                        requireMod = getModule(makeModuleMap(null, relMap));

                        //Store if map config should be applied to this require
                        //call for dependencies.
                        requireMod.skipMap = options.skipMap;

                        requireMod.init(deps, callback, errback, {
                            enabled: true
                        });

                        checkLoaded();
                    });

                    return localRequire;
                }

                mixin(localRequire, {
                    isBrowser: isBrowser,

                    /**
                     * Converts a module name + .extension into a URL path.
                     * *Requires* the use of a module name. Unlike nameToUrl,
                     * it does not support plain URLs.
                     */
                    toUrl: function (moduleNamePlusExt) {
                        var ext,
                            index = moduleNamePlusExt.lastIndexOf('.'),
                            segment = moduleNamePlusExt.split('/')[0],
                            isRelative = segment === '.' || segment === '..';

                        //Have a file extension alias, and it is not the
                        //dots from a relative path.
                        if (index !== -1 && (!isRelative || index > 1)) {
                            ext = moduleNamePlusExt.substring(index, moduleNamePlusExt.length);
                            moduleNamePlusExt = moduleNamePlusExt.substring(0, index);
                        }

                        return context.nameToUrl(normalize(moduleNamePlusExt,
                                                relMap && relMap.id, true), ext,  true);
                    },
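
                    //Example (illustrative): with baseUrl 'js/', a call like
                    //  localRequire.toUrl('./data/options.json')
                    //resolves './' against the current module and keeps the
                    //'.json' extension rather than appending '.js'.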

                    defined: function (id) {
                        return hasProp(defined, makeModuleMap(id, relMap, false, true).id);
                    },

                    specified: function (id) {
                        id = makeModuleMap(id, relMap, false, true).id;
                        return hasProp(defined, id) || hasProp(registry, id);
                    }
                });

                //Only allow undef on top level require calls
                if (!relMap) {
                    localRequire.undef = function (id) {
                        //Bind any waiting define() calls to this context,
                        //fix for #408
                        takeGlobalQueue();

                        var map = makeModuleMap(id, relMap, true),
                            mod = getOwn(registry, id);

                        //Guard: the module may no longer be in the registry.
                        if (mod) {
                            mod.undefed = true;
                        }
                        removeScript(id);

                        delete defined[id];
                        delete urlFetched[map.url];
                        delete undefEvents[id];

                        //Clean queued defines too. Go backwards
                        //in array so that the splices do not
                        //mess up the iteration.
                        eachReverse(defQueue, function(args, i) {
                            if (args[0] === id) {
                                defQueue.splice(i, 1);
                            }
                        });
                        delete context.defQueueMap[id];

                        if (mod) {
                            //Hold on to listeners in case an attempt is made
                            //to reload the module using a different config.
                            if (mod.events.defined) {
                                undefEvents[id] = mod.events;
                            }

                            cleanRegistry(id);
                        }
                    };
                }

                return localRequire;
            },

            /**
             * Called to enable a module if it is still in the registry
             * awaiting enablement. A second arg, parent, the parent module,
             * is passed in for context, when this method is overridden by
             * the optimizer. Not shown here to keep code compact.
             */
            enable: function (depMap) {
                var mod = getOwn(registry, depMap.id);
                if (mod) {
                    getModule(depMap).enable();
                }
            },

            /**
             * Internal method used by environment adapters to complete a load event.
             * A load event could be a script load or just a load pass from a synchronous
             * load call.
             * @param {String} moduleName the name of the module to potentially complete.
             */
            completeLoad: function (moduleName) {
                var found, args, mod,
                    shim = getOwn(config.shim, moduleName) || {},
                    shExports = shim.exports;

                takeGlobalQueue();

                while (defQueue.length) {
                    args = defQueue.shift();
                    if (args[0] === null) {
                        args[0] = moduleName;
                        //If already found an anonymous module and bound it
                        //to this name, then this is some other anon module
                        //waiting for its completeLoad to fire.
                        if (found) {
                            break;
                        }
                        found = true;
                    } else if (args[0] === moduleName) {
                        //Found matching define call for this script!
                        found = true;
                    }

                    callGetModule(args);
                }
                context.defQueueMap = {};

                //Do this after the cycle of callGetModule in case the result
                //of those calls/init calls changes the registry.
                mod = getOwn(registry, moduleName);

                if (!found && !hasProp(defined, moduleName) && mod && !mod.inited) {
                    if (config.enforceDefine && (!shExports || !getGlobal(shExports))) {
                        if (hasPathFallback(moduleName)) {
                            return;
                        } else {
                            return onError(makeError('nodefine',
                                             'No define call for ' + moduleName,
                                             null,
                                             [moduleName]));
                        }
                    } else {
                        //A script that does not call define(), so just simulate
                        //the call for it.
                        callGetModule([moduleName, (shim.deps || []), shim.exportsFn]);
                    }
                }

                checkLoaded();
            },

            /**
             * Converts a module name to a file path. Supports cases where
             * moduleName may actually be just a URL.
             * Note that it **does not** call normalize on the moduleName,
             * it is assumed to have already been normalized. This is an
             * internal API, not a public one. Use toUrl for the public API.
             */
            nameToUrl: function (moduleName, ext, skipExt) {
                var paths, syms, i, parentModule, url,
                    parentPath, bundleId,
                    pkgMain = getOwn(config.pkgs, moduleName);

                if (pkgMain) {
                    moduleName = pkgMain;
                }

                bundleId = getOwn(bundlesMap, moduleName);

                if (bundleId) {
                    return context.nameToUrl(bundleId, ext, skipExt);
                }

                //If a colon is in the URL, it indicates a protocol is used and it is just
                //a URL to a file, or if it starts with a slash, contains a query arg (i.e. ?)
                //or ends with .js, then assume the user meant to use a URL and not a module id.
                //The slash is important for protocol-less URLs as well as full paths.
                if (req.jsExtRegExp.test(moduleName)) {
                    //Just a plain path, not module name lookup, so just return it.
                    //Add the extension if one was included. This is a bit wonky: only
                    //non-.js resources pass an extension, so this method probably needs reworking.
                    url = moduleName + (ext || '');
                } else {
                    //A module that needs to be converted to a path.
                    paths = config.paths;

                    syms = moduleName.split('/');
                    //For each module name segment, see if there is a path
                    //registered for it. Start with most specific name
                    //and work up from it.
                    for (i = syms.length; i > 0; i -= 1) {
                        parentModule = syms.slice(0, i).join('/');

                        parentPath = getOwn(paths, parentModule);
                        if (parentPath) {
                            //If an array, it means there are a few choices,
                            //Choose the one that is desired
                            if (isArray(parentPath)) {
                                parentPath = parentPath[0];
                            }
                            syms.splice(0, i, parentPath);
                            break;
                        }
                    }

                    //Join the path parts together, then figure out if baseUrl is needed.
                    url = syms.join('/');
                    url += (ext || (/^data\:|\?/.test(url) || skipExt ? '' : '.js'));
                    url = (url.charAt(0) === '/' || url.match(/^[\w\+\.\-]+:/) ? '' : config.baseUrl) + url;
                }

                return config.urlArgs ? url +
                                        ((url.indexOf('?') === -1 ? '?' : '&') +
                                         config.urlArgs) : url;
            },
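
            //Example (illustrative, hypothetical config): given
            //  baseUrl: 'js/', paths: { app: '../app' }
            //nameToUrl('app/main') yields 'js/../app/main.js', while a plain
            //URL such as 'libs/x.js' or '/x.js' is returned as-is (plus
            //urlArgs, when configured).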

            //Delegates to req.load. Broken out as a separate function to
            //allow overriding in the optimizer.
            load: function (id, url) {
                req.load(context, id, url);
            },

            /**
             * Executes a module callback function. Broken out as a separate function
             * solely to allow the build system to sequence the files in the built
             * layer in the right sequence.
             *
             * @private
             */
            execCb: function (name, callback, args, exports) {
                return callback.apply(exports, args);
            },

            /**
             * Callback for script loads, used to check the status of loading.
             *
             * @param {Event} evt the event from the browser for the script
             * that was loaded.
             */
            onScriptLoad: function (evt) {
                //Using currentTarget instead of target for Firefox 2.0's sake. Not
                //all old browsers will be supported, but this one was easy enough
                //to support and still makes sense.
                if (evt.type === 'load' ||
                        (readyRegExp.test((evt.currentTarget || evt.srcElement).readyState))) {
                    //Reset interactive script so a script node is not held
                    //onto for too long.
                    interactiveScript = null;

                    //Pull out the name of the module and the context.
                    var data = getScriptData(evt);
                    context.completeLoad(data.id);
                }
            },

            /**
             * Callback for script errors.
             */
            onScriptError: function (evt) {
                var data = getScriptData(evt);
                if (!hasPathFallback(data.id)) {
                    var parents = [];
                    eachProp(registry, function(value, key) {
                        if (key.indexOf('_@r') !== 0) {
                            each(value.depMaps, function(depMap) {
                                if (depMap.id === data.id) {
                                    parents.push(key);
                                }
                                return true;
                            });
                        }
                    });
                    return onError(makeError('scripterror', 'Script error for "' + data.id +
                                             (parents.length ?
                                             '", needed by: ' + parents.join(', ') :
                                             '"'), evt, [data.id]));
                }
            }
        };

        context.require = context.makeRequire();
        return context;
    }

    /**
     * Main entry point.
     *
     * If the only argument to require is a string, then the module that
     * is represented by that string is fetched for the appropriate context.
     *
     * If the first argument is an array, then it will be treated as an array
     * of dependency string names to fetch. An optional function callback can
     * be specified to execute when all of those dependencies are available.
     *
     * Make a local req variable to help Caja compliance (it assumes things
     * on a require that are not standardized), and to give a short
     * name for minification/local scope use.
     */
    req = requirejs = function (deps, callback, errback, optional) {

        //Find the right context, use default
        var context, config,
            contextName = defContextName;

        // Determine if there is a config object in the call.
        if (!isArray(deps) && typeof deps !== 'string') {
            // deps is a config object
            config = deps;
            if (isArray(callback)) {
                // Adjust args if there are dependencies
                deps = callback;
                callback = errback;
                errback = optional;
            } else {
                deps = [];
            }
        }

        if (config && config.context) {
            contextName = config.context;
        }

        context = getOwn(contexts, contextName);
        if (!context) {
            context = contexts[contextName] = req.s.newContext(contextName);
        }

        if (config) {
            context.configure(config);
        }

        return context.require(deps, callback, errback);
    };
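
    //Example (illustrative): the overloaded signatures accepted above:
    //  require(['app/main'], function (main) { main.start(); });
    //  require({ context: 'v2', baseUrl: 'js2/' }, ['app/main'], callback);
    //The second form creates or configures the named context before loading.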

    /**
     * Support require.config() to make it easier to cooperate with other
     * AMD loaders on globally agreed names.
     */
    req.config = function (config) {
        return req(config);
    };

    /**
     * Execute something after the current tick
     * of the event loop. Override for other envs
     * that have a better solution than setTimeout.
     * @param  {Function} fn function to execute later.
     */
    req.nextTick = typeof setTimeout !== 'undefined' ? function (fn) {
        setTimeout(fn, 4);
    } : function (fn) { fn(); };
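
    //Example (illustrative): environments with a finer-grained async
    //primitive can override this hook; a Node adapter might use:
    //  req.nextTick = process.nextTick;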

    /**
     * Export require as a global, but only if it does not already exist.
     */
    if (!require) {
        require = req;
    }

    req.version = version;

    //Used to filter out dependencies that are already paths.
    req.jsExtRegExp = /^\/|:|\?|\.js$/;
    req.isBrowser = isBrowser;
    s = req.s = {
        contexts: contexts,
        newContext: newContext
    };

    //Create default context.
    req({});

    //Exports some context-sensitive methods on global require.
    each([
        'toUrl',
        'undef',
        'defined',
        'specified'
    ], function (prop) {
        //Reference from contexts instead of early binding to default context,
        //so that during builds, the latest instance of the default context
        //with its config gets used.
        req[prop] = function () {
            var ctx = contexts[defContextName];
            return ctx.require[prop].apply(ctx, arguments);
        };
    });

    if (isBrowser) {
        head = s.head = document.getElementsByTagName('head')[0];
        //If BASE tag is in play, using appendChild is a problem for IE6.
        //When that browser dies, this can be removed. Details in this jQuery bug:
        //http://dev.jquery.com/ticket/2709
        baseElement = document.getElementsByTagName('base')[0];
        if (baseElement) {
            head = s.head = baseElement.parentNode;
        }
    }

    /**
     * Any errors that require explicitly generates will be passed to this
     * function. Intercept/override it if you want custom error handling.
     * @param {Error} err the error object.
     */
    req.onError = defaultOnError;

    /**
     * Creates the node for the load command. Only used in browser envs.
     */
    req.createNode = function (config, moduleName, url) {
        var node = config.xhtml ?
                document.createElementNS('http://www.w3.org/1999/xhtml', 'html:script') :
                document.createElement('script');
        node.type = config.scriptType || 'text/javascript';
        node.charset = 'utf-8';
        node.async = true;
        return node;
    };
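
    //Example (illustrative, hypothetical nonce variable): createNode() can be
    //overridden, or the onNodeCreated config hook used, to decorate script
    //tags, e.g. to attach a CSP nonce:
    //  require.config({
    //      onNodeCreated: function (node) {
    //          node.setAttribute('nonce', myNonceValue);
    //      }
    //  });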

    /**
     * Does the request to load a module for the browser case.
     * Make this a separate function to allow other environments
     * to override it.
     *
     * @param {Object} context the require context to find state.
     * @param {String} moduleName the name of the module.
     * @param {Object} url the URL to the module.
     */
    req.load = function (context, moduleName, url) {
        var config = (context && context.config) || {},
            node;
        if (isBrowser) {
            //In the browser so use a script tag
            node = req.createNode(config, moduleName, url);
            if (config.onNodeCreated) {
                config.onNodeCreated(node, config, moduleName, url);
            }

            node.setAttribute('data-requirecontext', context.contextName);
            node.setAttribute('data-requiremodule', moduleName);

            //Set up load listener. Test attachEvent first because IE9 has
            //a subtle issue in its addEventListener and script onload firings
            //that do not match the behavior of all other browsers with
            //addEventListener support, which fire the onload event for a
            //script right after the script execution. See:
            //https://connect.microsoft.com/IE/feedback/details/648057/script-onload-event-is-not-fired-immediately-after-script-execution
            //UNFORTUNATELY Opera implements attachEvent but does not follow the
            //script execution mode.
            if (node.attachEvent &&
                    //Check if node.attachEvent is artificially added by custom script or
                    //natively supported by browser
                    //read https://github.com/jrburke/requirejs/issues/187
                    //if we can NOT find [native code] then it must NOT be natively supported.
                    //in IE8, node.attachEvent does not have toString()
                    //Note the test for "[native code" with no closing brace, see:
                    //https://github.com/jrburke/requirejs/issues/273
                    !(node.attachEvent.toString && node.attachEvent.toString().indexOf('[native code') < 0) &&
                    !isOpera) {
                //Probably IE. IE (at least 6-8) do not fire
                //script onload right after executing the script, so
                //we cannot tie the anonymous define call to a name.
                //However, IE reports the script as being in 'interactive'
                //readyState at the time of the define call.
                useInteractive = true;

                node.attachEvent('onreadystatechange', context.onScriptLoad);
                //It would be great to add an error handler here to catch
                //404s in IE9+. However, onreadystatechange will fire before
                //the error handler, so that does not help. If addEventListener
                //is used, then IE will fire error before load, but we cannot
                //use that pathway given the connect.microsoft.com issue
                //mentioned above about not doing the 'script execute,
                //then fire the script load event listener before execute
                //next script' that other browsers do.
                //Best hope: IE10 fixes the issues,
                //and then destroys all installs of IE 6-9.
                //node.attachEvent('onerror', context.onScriptError);
            } else {
                node.addEventListener('load', context.onScriptLoad, false);
                node.addEventListener('error', context.onScriptError, false);
            }
            node.src = url;

            //For some cache cases in IE 6-8, the script executes before the end
            //of the appendChild execution, so to tie an anonymous define
            //call to the module name (which is stored on the node), hold on
            //to a reference to this node, but clear after the DOM insertion.
            currentlyAddingScript = node;
            if (baseElement) {
                head.insertBefore(node, baseElement);
            } else {
                head.appendChild(node);
            }
            currentlyAddingScript = null;

            return node;
        } else if (isWebWorker) {
            try {
                //In a web worker, use importScripts. This is not a very
                //efficient use of importScripts; it will block until
                //its script is downloaded and evaluated. However, if web workers
                //are in play, the expectation is that a build has been done so
                //that only one script needs to be loaded anyway. This may need
                //to be reevaluated if other use cases become common.
                importScripts(url);

                //Account for anonymous modules
                context.completeLoad(moduleName);
            } catch (e) {
                context.onError(makeError('importscripts',
                                'importScripts failed for ' +
                                    moduleName + ' at ' + url,
                                e,
                                [moduleName]));
            }
        }
    };
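    //Illustrative sketch (not part of the original source): a non-browser host
    //could override req.load with its own fetch-and-eval strategy, roughly:
    //
    //    req.load = function (context, moduleName, url) {
    //        var contents = readFileSomehow(url); //hypothetical helper
    //        req.exec(contents);
    //        context.completeLoad(moduleName);
    //    };
    //
    //This mirrors what the web worker branch above does with importScripts.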

    function getInteractiveScript() {
        if (interactiveScript && interactiveScript.readyState === 'interactive') {
            return interactiveScript;
        }

        eachReverse(scripts(), function (script) {
            if (script.readyState === 'interactive') {
                return (interactiveScript = script);
            }
        });
        return interactiveScript;
    }

    //Look for a data-main script attribute, which could also adjust the baseUrl.
    if (isBrowser && !cfg.skipDataMain) {
        //Figure out baseUrl. Get it from the script tag with require.js in it.
        eachReverse(scripts(), function (script) {
            //Set the 'head' where we can append children by
            //using the script's parent.
            if (!head) {
                head = script.parentNode;
            }

            //Look for a data-main attribute to set main script for the page
            //to load. If it is there, the path to data main becomes the
            //baseUrl, if it is not already set.
            dataMain = script.getAttribute('data-main');
            if (dataMain) {
                //Preserve dataMain in case it is a path (i.e. contains '?')
                mainScript = dataMain;

                //Set final baseUrl if there is not already an explicit one.
                if (!cfg.baseUrl) {
                    //Pull off the directory of data-main for use as the
                    //baseUrl.
                    src = mainScript.split('/');
                    mainScript = src.pop();
                    subPath = src.length ? src.join('/')  + '/' : './';

                    cfg.baseUrl = subPath;
                }

                //Strip off any trailing .js since mainScript is now
                //like a module name.
                mainScript = mainScript.replace(jsSuffixRegExp, '');

                //If mainScript is still a path, fall back to dataMain
                if (req.jsExtRegExp.test(mainScript)) {
                    mainScript = dataMain;
                }

                //Put the data-main script in the files to load.
                cfg.deps = cfg.deps ? cfg.deps.concat(mainScript) : [mainScript];

                return true;
            }
        });
    }
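    //Illustrative example (not part of the original source): the block above
    //is what makes the usual data-main bootstrap work, e.g.:
    //
    //    <script data-main="scripts/main" src="require.js"></script>
    //
    //Here cfg.baseUrl becomes 'scripts/' and 'main' is queued as a dependency.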

    /**
     * The function that handles definitions of modules. Differs from
     * require() in that a string for the module should be the first argument,
     * and the function to execute after dependencies are loaded should
     * return a value to define the module corresponding to the first argument's
     * name.
     */
    define = function (name, deps, callback) {
        var node, context;

        //Allow for anonymous modules
        if (typeof name !== 'string') {
            //Adjust args appropriately
            callback = deps;
            deps = name;
            name = null;
        }

        //This module may not have dependencies
        if (!isArray(deps)) {
            callback = deps;
            deps = null;
        }

        //If no name, and callback is a function, then figure out if it is a
        //CommonJS thing with dependencies.
        if (!deps && isFunction(callback)) {
            deps = [];
            //Remove comments from the callback string,
            //look for require calls, and pull them into the dependencies,
            //but only if there are function args.
            if (callback.length) {
                callback
                    .toString()
                    .replace(commentRegExp, '')
                    .replace(cjsRequireRegExp, function (match, dep) {
                        deps.push(dep);
                    });

                //May be a CommonJS thing even without require calls, but it
                //could still use exports and module. Avoid doing the exports
                //and module work though if it just needs require.
                //REQUIRES the function to expect the CommonJS variables in the
                //order listed below.
                deps = (callback.length === 1 ? ['require'] : ['require', 'exports', 'module']).concat(deps);
            }
        }

        //If in IE 6-8 and hit an anonymous define() call, do the interactive
        //work.
        if (useInteractive) {
            node = currentlyAddingScript || getInteractiveScript();
            if (node) {
                if (!name) {
                    name = node.getAttribute('data-requiremodule');
                }
                context = contexts[node.getAttribute('data-requirecontext')];
            }
        }

        //Always save off evaluating the def call until the script onload handler.
        //This allows multiple modules to be in a file without prematurely
        //tracing dependencies, and allows for anonymous module support,
        //where the module name is not known until the script onload event
        //occurs. If no context, use the global queue, and get it processed
        //in the onscript load callback.
        if (context) {
            context.defQueue.push([name, deps, callback]);
            context.defQueueMap[name] = true;
        } else {
            globalDefQueue.push([name, deps, callback]);
        }
    };

    define.amd = {
        jQuery: true
    };
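    //Illustrative examples (not part of the original source) of calls the
    //define() implementation above accepts:
    //
    //    //Named module with explicit dependencies:
    //    define('logger', ['dep'], function (dep) { return dep; });
    //
    //    //Anonymous CommonJS-style module; the regexp pass above scrapes the
    //    //require('./messages') call out of the function body:
    //    define(function (require, exports, module) {
    //        var messages = require('./messages');
    //        exports.hello = function () { return messages; };
    //    });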

    /**
     * Executes the text. Normally just uses eval, but can be modified
     * to use a better, environment-specific call. Only used for transpiling
     * loader plugins, not for plain JS modules.
     * @param {String} text the text to execute/evaluate.
     */
    req.exec = function (text) {
        /*jslint evil: true */
        return eval(text);
    };
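    //Illustrative note (not part of the original source): hosts that want the
    //evaluated text to run in global scope rather than inside this closure
    //could override this with an indirect eval, e.g.:
    //
    //    req.exec = function (text) { return (0, eval)(text); };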

    //Set up with config info.
    req(cfg);
}(this));
", + "headers": [ + [ + "content-type", + "application/javascript" + ] + ], + "ok": true, + "status": 200, + "status_text": "" + } + } + }, + "colab_type": "code", + "id": "k0j5zzpAPSFn", + "outputId": "cb5b1d88-054b-413e-d303-428e63bce694" + }, + "outputs": [], + "source": [ + "call_html()\n", + "display.display(display.HTML(vis_html))\n", + "display.display(display.Javascript('window.attention = %s' % attention_json))\n", + "display.display(display.Javascript(vis_js))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lydjSs3hgDVF" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Attention_Visualization_in_Trax.ipynb", + "provenance": [ + { + "file_id": "1bJu3Qx37FY9UpHqVMyXCTNb64v4Iw_v7", + "timestamp": 1598692842045 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/resources/examples/ipynb/Example-7-4-Knowledge-Tracing-Transformer.ipynb b/resources/examples/ipynb/Example-7-4-Knowledge-Tracing-Transformer.ipynb new file mode 100644 index 000000000..ffcad96da --- /dev/null +++ b/resources/examples/ipynb/Example-7-4-Knowledge-Tracing-Transformer.ipynb @@ -0,0 +1,2126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eGCe1pjznIQS" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2021 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lAula_PU9jqB" + }, + "source": [ + "## Intro\r\n", + "\r\n", + "This notebook trains a transformer model on the [EdNet dataset](https://github.com/riiid/ednet) using the [google/trax library](https://github.com/google/trax). The EdNet dataset is large set of student responses to multiple choice questions related to English language learning. A recent Kaggle competition, [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction), provided as subset of this data, consisting of 100 million responses to 13 thousand questions from 300 thousand students.\r\n", + "\r\n", + "The state of the art result, detailed in [SAINT+: Integrating Temporal Features for EdNet Correctness Prediction](https://arxiv.org/abs/2010.12042), achieves an AUC ROC of 0.7914. The winning solution in the [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction) competition achieved an AUC ROC of 0.820. This notebook achieves an AUC ROC of 0.776 implementing an approach similar to the state of the art approach, training for 25,000 steps. It demonstrates several techniques that may be useful to those getting started with the [google/trax library](https://github.com/google/trax) or deep learning in general. 
This notebook demonstrates how to:\r\n", + "\r\n", + "* Use BigQuery to perform feature engineering\r\n", + "* Create TFRecords with multiple sequences per record\r\n", + "* Modify the trax Transformer model to accommodate a knowledge tracing dataset:\r\n", + " * Utilize multiple encoder and decoder embeddings - aggregated either by concatenation or sum\r\n", + " * Include a custom metric - AUC ROC\r\n", + " * Utilize a combined padding and future mask\r\n", + "* Use trax's [gin-config](https://github.com/google/gin-config) integration to specify training parameters\r\n", + "* Display training progress using trax's tensorboard integration\r\n", + "\r\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CalebEverett/riiid_transformer/blob/master/riiid-trax-transformer.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tuG_-VFcpxLc" + }, + "outputs": [], + "source": [ + "# Choose a location for your storage bucket and BigQuery dataset to minimize data egress charges. Once you have\r\n", + "# created them, if you restart your notebook you can run this to see where your colab is running\r\n", + "# and factory reset until you get a location that is near your data.\r\n", + "!curl ipinfo.io" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_SQN6SX89XNq" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vo5bzc9z7nw_" + }, + "outputs": [], + "source": [ + "# \r\n", + "!git clone https://github.com/google/trax.git\r\n", + "!pip install ./trax\r\n", + "!pip install -U pyarrow\r\n", + "!pip install -U google-cloud-bigquery google-cloud-bigquery-storage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0W7kto2g7Sfa" + }, + "outputs": [], + "source": [ + "from functools import partial\r\n", + "import json\r\n", + "import math\r\n", + "import os\r\n", + "from pathlib import Path\r\n", + "import subprocess\r\n", + "import sys\r\n", + "import time\r\n", + "\r\n", + "import gin\r\n", + "from google.cloud import storage, bigquery\r\n", + "from google.cloud.bigquery import LoadJobConfig, QueryJobConfig, \\\r\n", + " SchemaField, SourceFormat\r\n", + "import jax\r\n", + "from jax.config import config\r\n", + "import pandas as pd\r\n", + "import numpy as np\r\n", + "import requests\r\n", + "import sqlite3\r\n", + "import trax\r\n", + "from trax import fastmath\r\n", + "from trax import layers as tl\r\n", + "from trax.fastmath import numpy as tnp\r\n", + "import tensorflow as tf\r\n", + "from tqdm.notebook import tqdm\r\n", + "import zipfile\r\n", + "\r\n", + "# Create google credentials and store in drive\r\n", + "# https://colab.research.google.com/drive/1LWhrqE2zLXqz30T0a0JqXnDPKweqd8ET\r\n", + "# \r\n", + "# Create a config.json file with variables for:\r\n", + "# \"BUCKET\": \"\",\r\n", + "# \"BQ_DATASET\": \"\",\r\n", + "# \"KAGGLE_USERNAME\": \"\",\r\n", + "# \"KAGGLE_KEY\": \"\",\r\n", + "# \"PROJECT\": \"\",\r\n", + "# \"LOCATION\": \"\"\r\n", + "from google.colab import drive\r\n", + "\r\n", + "DRIVE = Path('/content/drive/My Drive')\r\n", + "PATH = 'riiid-transformer'\r\n", + "\r\n", + "if not DRIVE.exists():\r\n", + " drive.mount(str(DRIVE.parent))\r\n", + "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = str(DRIVE/PATH/'google.json')\r\n", + "\r\n", + "with open(str(DRIVE/PATH/'config.json')) as f:\r\n", + " CONFIG = json.load(f)\r\n", + " os.environ = {**os.environ, 
**CONFIG}\r\n", + "\r\n", + "from kaggle.api.kaggle_api_extended import KaggleApi\r\n", + "kaggle_api = KaggleApi()\r\n", + "kaggle_api.authenticate()\r\n", + "\r\n", + "AUTO = tf.data.experimental.AUTOTUNE\r\n", + "BUCKET = os.getenv('BUCKET', 'riiid-transformer')\r\n", + "BQ_DATASET = os.getenv('BQ_DATASET', 'my_data')\r\n", + "LOCATION = os.getenv('LOCATION', 'us-central1')\r\n", + "PROJECT = os.getenv('PROJECT', 'fastai-caleb')\r\n", + "\r\n", + "bucket = storage.Client(project=PROJECT).get_bucket(BUCKET)\r\n", + "dataset = bigquery.Dataset(f'{PROJECT}.{BQ_DATASET}')\r\n", + "bq_client = bigquery.Client(project=PROJECT, location=LOCATION)\r\n", + "\r\n", + "%matplotlib inline\r\n", + "from matplotlib import pyplot as plt\r\n", + "\r\n", + "%load_ext tensorboard\r\n", + "\r\n", + "gin.enter_interactive_mode()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vL0eRGAnyK9x" + }, + "source": [ + "## Control Panel" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YaAhPw-zv1la" + }, + "source": [ + "These variables can be set to True to run the code in the sections described or False to skip over them after they have been run for the first time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MNrBIpVPyPGX" + }, + "outputs": [], + "source": [ + "USE_TPU = False\r\n", + "DOWNLOAD_DATASET = False\r\n", + "LOAD_DATA_TO_BQ = False\r\n", + "PERFORM_FEATURE_ENGINEERING = False\r\n", + "TEST_FEATURE_ENGNEERING = False\r\n", + "CREATE_TFRECORDS = False\r\n", + "TEST_TFRECORDS = False\r\n", + "TRAIN_MODEL = False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t8Jvva6lBRyI" + }, + "source": [ + "## Initialize TPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PsczFYbe80ei" + }, + "outputs": [], + "source": [ + "if USE_TPU:\r\n", + " if 'TPU_DRIVER_MODE' not in globals():\r\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'\r\n", + " resp = requests.post(url)\r\n", + " TPU_DRIVER_MODE = 1\r\n", + "\r\n", + " config.FLAGS.jax_xla_backend = \"tpu_driver\"\r\n", + " config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\r\n", + " print(config.FLAGS.jax_backend_target)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GXP1CnQXBtzd" + }, + "source": [ + "## Download Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YSAnW-bzBzCE" + }, + "outputs": [], + "source": [ + "if DOWNLOAD_DATASET:\r\n", + " kaggle_api.competition_download_cli('riiid-test-answer-prediction')\r\n", + " with zipfile.ZipFile('riiid-test-answer-prediction.zip', 'r') as zip_ref:\r\n", + " zip_ref.extractall()\r\n", + " for f in ['train.csv', 'questions.csv', 'lectures.csv']:\r\n", + " bucket.blob(f).upload_from_filename(f)\r\n", + "\r\n", + "if False:\r\n", + " for f in tqdm(['train.csv', 'questions.csv', 'lectures.csv']):\r\n", + " bucket.blob(f).download_to_filename(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wM1_VVnm-61P" + }, + "source": [ + "## Create BigQuery Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Eo0iR8Y-5Sv" + }, + "outputs": [], + "source": [ + "if False:\r\n", + " delete_contents=False\r\n", + " bq_client.delete_dataset(BQ_DATASET, delete_contents=delete_contents)\r\n", + " print(f'Dataset {dataset.dataset_id} deleted from project {dataset.project}.')\r\n", + "\r\n", + 
"try:\r\n", + " dataset = bq_client.get_dataset(dataset.dataset_id)\r\n", + " print(f'Dataset {dataset.dataset_id} already exists '\r\n", + " f'in location {dataset.location} in project {dataset.project}.')\r\n", + "except:\r\n", + " dataset = bq_client.create_dataset(dataset)\r\n", + " print(f'Dataset {dataset.dataset_id} created '\r\n", + " f'in location {dataset.location} in project {dataset.project}.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i7tZZN449eH-" + }, + "source": [ + "## Dtypes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qt70hdhk_j6p" + }, + "outputs": [], + "source": [ + "dtypes_orig = {\r\n", + " 'lectures': {\r\n", + " 'lecture_id': 'uint16',\r\n", + " 'tag': 'uint8',\r\n", + " 'part': 'uint8',\r\n", + " 'type_of': 'str',\r\n", + " },\r\n", + " 'questions': {\r\n", + " 'question_id': 'uint16',\r\n", + " 'bundle_id': 'uint16',\r\n", + " 'correct_answer': 'uint8',\r\n", + " 'part': 'uint8',\r\n", + " 'tags': 'str',\r\n", + " \r\n", + " },\r\n", + " 'train': {\r\n", + " 'row_id': 'int64',\r\n", + " 'timestamp': 'int64',\r\n", + " 'user_id': 'int32',\r\n", + " 'content_id': 'int16',\r\n", + " 'content_type_id': 'int8',\r\n", + " 'task_container_id': 'int16',\r\n", + " 'user_answer': 'int8',\r\n", + " 'answered_correctly': 'int8',\r\n", + " 'prior_question_elapsed_time': 'float32', \r\n", + " 'prior_question_had_explanation': 'bool'\r\n", + " }\r\n", + " \r\n", + "}\r\n", + "\r\n", + "dtypes_new = {\r\n", + " 'lectures': {},\r\n", + " 'questions': {\r\n", + " 'tags_array': 'str'\r\n", + " },\r\n", + " 'train': {\r\n", + " 'task_container_id_q': 'int16',\r\n", + " 'pqet_current': 'int32',\r\n", + " 'ts_delta': 'int32'\r\n", + " }\r\n", + "}\r\n", + "\r\n", + "dtypes = {}\r\n", + "for table_id in dtypes_orig:\r\n", + " dtypes[table_id] = {\r\n", + " **dtypes_orig[table_id],\r\n", + " **dtypes_new[table_id]\r\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zYIOhHEoDw-v" + }, + "source": [ + "### Big Query Table Schemas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q1LEgqZfDulc" + }, + "outputs": [], + "source": [ + "# \r\n", + "type_map = {\r\n", + " 'int64': 'INT64',\r\n", + " 'int32': 'INT64',\r\n", + " 'int16': 'INT64',\r\n", + " 'int8': 'INT64',\r\n", + " 'uint8': 'INT64',\r\n", + " 'uint16': 'INT64',\r\n", + " 'str': 'STRING',\r\n", + " 'bool': 'BOOL',\r\n", + " 'float32': 'FLOAT64'\r\n", + "}\r\n", + "\r\n", + "schemas_orig = {table: [SchemaField(f, type_map[t]) for f, t in\r\n", + " fields.items()] for table, fields in dtypes_orig.items()}\r\n", + "\r\n", + "schemas = {}\r\n", + "for table_id, fields in dtypes_new.items():\r\n", + " new_fields = [SchemaField(f, type_map[t]) for\r\n", + " f, t in fields.items() if 'array' not in f]\r\n", + " \r\n", + " new_array_feilds = [SchemaField(f, 'INT64', 'REPEATED') for\r\n", + " f, t in fields.items() if 'array' in f]\r\n", + "\r\n", + " new_fields += new_array_feilds\r\n", + "\r\n", + " schemas[table_id] = schemas_orig[table_id] + new_fields" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sv7wPwp2EJpH" + }, + "source": [ + "### Load Tables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EtBgHrBvC_H3" + }, + "outputs": [], + "source": [ + "def load_job_cb(future):\r\n", + " \"\"\"Prints update upon completion to output of last run cell.\"\"\"\r\n", + " \r\n", + " seconds = (future.ended - future.created).total_seconds()\r\n", + " 
print(f'Loaded {future.output_rows:,d} rows to table {future.job_id.split(\"_\")[0]} in '\r\n", + " f'{seconds:>4,.1f} sec, {int(future.output_rows / seconds):,d} per sec.')\r\n", + "\r\n", + "def load_csv_from_uri(table_id, schemas_orig):\r\n", + " full_table_id = f'{BQ_DATASET}.{table_id}'\r\n", + "\r\n", + " job_config = LoadJobConfig(\r\n", + " schema=schemas_orig[table_id],\r\n", + " source_format=SourceFormat.CSV,\r\n", + " skip_leading_rows=1\r\n", + " )\r\n", + "\r\n", + " uri = f'gs://{BUCKET}/{table_id}.csv'\r\n", + " load_job = bq_client.load_table_from_uri(uri, full_table_id,\r\n", + " job_config=job_config,\r\n", + " job_id_prefix=f'{table_id}_')\r\n", + " print(f'job {load_job.job_id} started')\r\n", + " load_job.add_done_callback(load_job_cb)\r\n", + " \r\n", + " return load_job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L44_o0NYEOcC" + }, + "outputs": [], + "source": [ + "if LOAD_DATA_TO_BQ:\r\n", + " for table_id in dtypes_orig:\r\n", + " lj = load_csv_from_uri(table_id, schemas_orig).result()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JUAg3Pz5ImSx" + }, + "source": [ + "### Update BiqQuery Schemas" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ym62FdoNgU8t" + }, + "source": [ + "Before performing feature engineering, we have to update the table schemas in Big Query to create columns for the new features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qByuVM7MIr8b" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " for table_id, schema in schemas.items():\r\n", + " table = bq_client.get_table(f'{BQ_DATASET}.{table_id}')\r\n", + " table.schema = schema\r\n", + " table = bq_client.update_table(table, ['schema'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oCHq9dJiFOPh" + }, + "source": [ + "## Feature Engineering" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d6tN2qREdc9A" + }, + "source": [ + "Using BigQuery for a dataset of 100 million rows is much faster than using local dataframes. In addition, you get to use the full power of SQL, including [window functions](https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts), which are especially useful for time series feature engineering.\r\n", + "\r\n", + "Feature engineering for this problem is fairly minimal and includes:\r\n", + "* Replacing missing null values for `prior_question_elapsed_time` and `prior_question_had_explanation` in the train table\r\n", + "* Replacing one missing tag value in the questions table\r\n", + "* Recalcuating the `task_container_id` as `task_container_id_q` so that it excludes lecture records and increases monotonically with `timetamp` so that the calucations for elapsed time and time delta, which depend on values from the immediately prior and immediately succeeding records, are calculated correctly.\r\n", + "* Calculating `pqet_current`, the time it took on average to answer the questions in the current `task_container_id_q`.\r\n", + "* Calculating `ts_delta`, the elapsed time between the last `task_container_id_q` and the current one.\r\n", + "* Creating `folds` table, in which users are assigned to one of 20 folds.\r\n", + "* Creating a `tags_array` field in the questions table, that returns an array of six elements populated with the tags assigned to each questions, padded with zeros if there are less than six." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X2ynWZqPFnSj" + }, + "outputs": [], + "source": [ + "def done_cb(future):\r\n", + " seconds = (future.ended - future.started).total_seconds()\r\n", + " print(f'Job {future.job_id} finished in {seconds} seconds.')\r\n", + "\r\n", + "def run_query(query, job_id_prefix=None, wait=True,\r\n", + " use_query_cache=True):\r\n", + "\r\n", + " job_config = QueryJobConfig(\r\n", + " use_query_cache=use_query_cache)\r\n", + "\r\n", + " query_job = bq_client.query(query, job_id_prefix=job_id_prefix,\r\n", + " job_config=job_config)\r\n", + " print(f'Job {query_job.job_id} started.')\r\n", + " query_job.add_done_callback(done_cb)\r\n", + " if wait:\r\n", + " query_job.result()\r\n", + " \r\n", + " return query_job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8Qo21D1ITicH" + }, + "outputs": [], + "source": [ + "def get_df_query_bqs(query, dtypes=None, fillna=None):\r\n", + " qj = bq_client.query(query)\r\n", + " df = qj.to_dataframe(create_bqstorage_client=True, progress_bar_type='tqdm_notebook')\r\n", + " if fillna is not None:\r\n", + " df = df.fillna(fillna)\r\n", + " try:\r\n", + " df = df.astype({c: dtypes.get(c, 'int32') for c in df.columns}) \r\n", + " except:\r\n", + " print('dtypes not applied.')\r\n", + " finally: \r\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N71-o9uQFSzV" + }, + "source": [ + "### Replace Missing Values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rnBL1HXxFWKX" + }, + "outputs": [], + "source": [ + "def update_missing_values(table_id='train', column_id=None, value=None):\r\n", + " return f\"\"\"\r\n", + " UPDATE {BQ_DATASET}.{table_id}\r\n", + " SET {column_id} = {value}\r\n", + " WHERE {column_id} is NULL;\r\n", + " \"\"\", sys._getframe().f_code.co_name + '_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e0qBG2XrGIMB" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " qj = run_query(*update_missing_values('train', 'prior_question_elapsed_time', '0'))\r\n", + " qj = run_query(*update_missing_values('train', 'prior_question_had_explanation', 'false'))\r\n", + " qj = run_query(*update_missing_values('questions', 'tags', '\"188\"'))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "elZXRogqL-pr" + }, + "source": [ + "### Recalculate Task Container Ids for Questions Only" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Li0UdfY2MeOm" + }, + "outputs": [], + "source": [ + "def update_task_container_id(table_id='train',\r\n", + " column_id='task_container_id',\r\n", + " excl_lectures=True):\r\n", + " excl_lec = 'WHERE content_type_id = 0' if excl_lectures else ''\r\n", + " \r\n", + " return f\"\"\"\r\n", + " UPDATE {BQ_DATASET}.{table_id} t\r\n", + " SET {column_id} = target.calc\r\n", + " FROM (\r\n", + " SELECT row_id, DENSE_RANK()\r\n", + " OVER (\r\n", + " PARTITION BY user_id\r\n", + " ORDER BY timestamp\r\n", + " ) calc\r\n", + " FROM {BQ_DATASET}.{table_id}\r\n", + " {excl_lec}\r\n", + " ) target\r\n", + " WHERE target.row_id = t.row_id\r\n", + " \"\"\", sys._getframe().f_code.co_name + '_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FGFisFdpMtGy" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " q = update_task_container_id(table_id='train',\r\n", + " 
column_id='task_container_id_q ',\r\n", + " excl_lectures=True)\r\n", + " qj = run_query(*q)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2HblfPhCG618" + }, + "source": [ + "### Calculate Current Question Elapsed Time and Timestamp Delta" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "29-ajMgjHEUl" + }, + "outputs": [], + "source": [ + "def update_pqet_current(table_id='train'):\r\n", + " return f\"\"\"\r\n", + " UPDATE {BQ_DATASET}.{table_id} t\r\n", + " SET t.pqet_current = CAST(p.pqet_current AS INT64)\r\n", + " FROM (\r\n", + " SELECT\r\n", + " row_id, LAST_VALUE(prior_question_elapsed_time) OVER (\r\n", + " PARTITION BY user_id ORDER BY task_container_id_q\r\n", + " RANGE BETWEEN 1 FOLLOWING AND 1 FOLLOWING) pqet_current\r\n", + " FROM {BQ_DATASET}.train \r\n", + " WHERE content_type_id = 0\r\n", + " ) p\r\n", + " WHERE t.row_id = p.row_id;\r\n", + " \r\n", + " UPDATE {BQ_DATASET}.{table_id}\r\n", + " SET pqet_current = 0\r\n", + " WHERE pqet_current IS NULL;\r\n", + " \r\n", + " \"\"\", sys._getframe().f_code.co_name + '_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "juf9vDrzIF2W" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " qj = run_query(*update_pqet_current())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "E9LnKsjVLgRk" + }, + "outputs": [], + "source": [ + "def update_ts_delta(table_id='train'):\r\n", + " return f\"\"\"\r\n", + " UPDATE {BQ_DATASET}.{table_id} t\r\n", + " SET t.ts_delta = timestamp - p.ts_prior\r\n", + " FROM (\r\n", + " SELECT\r\n", + " row_id, LAST_VALUE(timestamp) OVER (\r\n", + " PARTITION BY user_id ORDER BY task_container_id_q\r\n", + " RANGE BETWEEN 1 PRECEDING AND 1 PRECEDING) ts_prior\r\n", + " FROM {BQ_DATASET}.train \r\n", + " WHERE content_type_id = 0\r\n", + " ) p\r\n", + " WHERE t.row_id = p.row_id;\r\n", + " \r\n", + " UPDATE {BQ_DATASET}.{table_id}\r\n", + " SET ts_delta = 0\r\n", + " WHERE ts_delta IS NULL;\r\n", + " \"\"\", sys._getframe().f_code.co_name + '_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0-CEUJsoL1dC" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " qj = run_query(*update_ts_delta())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "99qz0H8Xb3i1" + }, + "source": [ + "### Create Folds Table\r\n", + "Assign users randomly to one of 20 folds. Store total records to facilitate filtering based on record count." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "N7UnAHF8cesC" + }, + "outputs": [], + "source": [ + "def create_table_folds(table_id='folds', n_folds=20):\r\n", + " return f\"\"\"\r\n", + " DECLARE f INT64;\r\n", + "\r\n", + " CREATE OR REPLACE TABLE {BQ_DATASET}.{table_id} (\r\n", + " user_id INT64,\r\n", + " fold INT64,\r\n", + " record_count INT64\r\n", + " );\r\n", + "\r\n", + " INSERT {BQ_DATASET}.{table_id} (user_id, fold, record_count)\r\n", + " SELECT f.user_id, CAST(FLOOR(RAND() * {n_folds}) AS INT64) fold, f.record_count\r\n", + " FROM (\r\n", + " SELECT user_id,\r\n", + " COUNT(row_id) record_count\r\n", + " FROM {BQ_DATASET}.train\r\n", + " WHERE content_type_id = 0\r\n", + " GROUP BY user_id\r\n", + " ) f\r\n", + " ORDER BY user_id;\r\n", + " \"\"\", sys._getframe().f_code.co_name + '_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SVPio880dPSe" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " qj = run_query(*create_table_folds())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "14dQwOnzdolg" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " df_folds = get_df_query_bqs(f\"\"\"\r\n", + " SELECT *\r\n", + " FROM {BQ_DATASET}.folds\r\n", + " \"\"\",\r\n", + " dtypes=dtypes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R9Y9Xhwee6f7" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " df_folds.groupby('fold').count().user_id.plot(kind='bar', title='Count of Users by Fold');" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qciROEyIoowx" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " df_folds.groupby('fold').mean().record_count.plot(kind='bar', title='Average Records per User by Fold');" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q5zS5bWenJaj" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " df_fold_ac = get_df_query_bqs(f\"\"\"\r\n", + " SELECT fold, SUM(answered_correctly) ac_sum, COUNT(answered_correctly) rec_count\r\n", + " FROM {BQ_DATASET}.train\r\n", + " JOIN {BQ_DATASET}.folds\r\n", + " ON train.user_id = folds.user_id\r\n", + " GROUP BY fold\r\n", + " \"\"\",\r\n", + " dtypes=dtypes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3TGelsfEn7xf" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " df_fold_ac.rec_count.plot(kind='bar', title='Count of Records by Fold');" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a1kcjfOkoGU_" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " (df_fold_ac.ac_sum / df_fold_ac.rec_count).plot(kind='bar', title='Percent Answered Correctly by Fold');" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nqg9tqEnOa7l" + }, + "source": [ + "### Create Tags Array on Questions Table\r\nWe need the tags as an array later when we create TFRecords. We also increment by one and pad with zeros to a fixed length of 6 so that they can be concatenated as a feature for modeling." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wr4peSSWPHiW" + }, + "outputs": [], + "source": [ + "def update_tags_array(table_id='questions', column_id='tags_array'):\r\n", + " \r\n", + " return f\"\"\"\r\n", + " UPDATE {BQ_DATASET}.{table_id} q\r\n", + " SET {column_id} = tp.tags_fixed_len\r\n", + " FROM (\r\n", + " WITH tags_padded AS (\r\n", + " WITH tags_table AS (SELECT question_id, tags FROM {BQ_DATASET}.{table_id})\r\n", + " SELECT question_id, ARRAY_CONCAT(ARRAY_AGG(CAST(tag AS INT64) + 1), [0,0,0,0,0]) tags_array\r\n", + " FROM tags_table, UNNEST(SPLIT(tags, ' ')) as tag\r\n", + " GROUP BY question_id\r\n", + " )\r\n", + " SELECT question_id,\r\n", + " ARRAY(SELECT x FROM UNNEST(tags_array) AS x WITH OFFSET off WHERE off < 6 ORDER BY off) tags_fixed_len\r\n", + " FROM tags_padded\r\n", + " ) tp\r\n", + " WHERE tp.question_id = q.question_id\r\n", + " \"\"\", sys._getframe().f_code.co_name + '_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "exf_kIXuRagG" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " qj = run_query(*update_tags_array())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pB2YEqhISnRz" + }, + "outputs": [], + "source": [ + "if PERFORM_FEATURE_ENGINEERING:\r\n", + " df_q = get_df_query_bqs('select * from my_data.questions', dtypes=dtypes)\r\n", + " print(df_q.head())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BrxKtvosXTfw" + }, + "source": [ + "## Feature Engineering Tests\r\n", + "* Features come back out of BigQuery with the same values they went in with\r\n", + "* `ts_delta` is equal to the difference between timestamps on consecutive records\r\n", + "* `pqet_current` is equal to `prior_question_elapsed_time` from the next record\r\n", + "* Visually inspect distributions of `ts_delta` and `pqet_current`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ynTomTnKY7F3" + }, + "source": [ + "### Load Sample from train.csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c9aNhgV7Yfzw" + }, + "outputs": [], + "source": [ + "if TEST_FEATURE_ENGNEERING:\r\n", + " df_train_samp = pd.read_csv('train.csv', nrows=100000)\r\n", + " df_train_samp.prior_question_had_explanation = df_train_samp.prior_question_had_explanation.fillna(False).astype(bool)\r\n", + " df_train_samp.prior_question_elapsed_time = df_train_samp.prior_question_elapsed_time.fillna(0)\r\n", + " user_ids_samp = df_train_samp.user_id.unique()[:-1]\r\n", + " print(len(user_ids_samp))\r\n", + " df_train_samp = df_train_samp[df_train_samp.user_id.isin(user_ids_samp) & (df_train_samp.content_type_id == 0)].reset_index(drop=True)\r\n", + " print(len(df_train_samp))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "elJgSDpmY0q0" + }, + "source": [ + "### Pull sample of corresponding user_ids from BigQuery" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GvQF1yIzZPHE" + }, + "outputs": [], + "source": [ + "if TEST_FEATURE_ENGNEERING:\r\n", + " df_bq_samp = get_df_query_bqs(f\"\"\"\r\n", + " SELECT *\r\n", + " FROM {BQ_DATASET}.train\r\n", + " WHERE user_id IN ({(',').join(map(str, user_ids_samp))})\r\n", + " AND content_type_id = 0\r\n", + " ORDER BY user_id, timestamp, row_id\r\n", + " \"\"\",\r\n", + " dtypes=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gP3mDb-phsSt" + }, + "source": [ + "### Tests" + ] + }, + 
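{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A minimal sketch (not from the original notebook) of the shift relationships asserted below, using a toy pandas frame with hypothetical values:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Toy example (hypothetical data): pqet_current for container n should equal\r\n", + "# prior_question_elapsed_time reported with container n + 1, and ts_delta is\r\n", + "# the difference between consecutive container timestamps.\r\n", + "toy = pd.DataFrame({'task_container_id_q': [1, 2, 3],\r\n", + "                    'timestamp': [0, 60000, 150000],\r\n", + "                    'prior_question_elapsed_time': [0, 20000, 30000],\r\n", + "                    'pqet_current': [20000, 30000, 0],\r\n", + "                    'ts_delta': [0, 60000, 90000]})\r\n", + "assert all(toy.pqet_current.shift(1).iloc[1:] == toy.prior_question_elapsed_time.iloc[1:])\r\n", + "assert all((toy.timestamp - toy.timestamp.shift(1)).iloc[1:] == toy.ts_delta.iloc[1:])" + ] + }, + 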
{ + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TukjFI4YatpD" + }, + "outputs": [], + "source": [ + "if TEST_FEATURE_ENGNEERING:\r\n", + " # values in columns are the same between train.csv and bq\r\n", + " for c in df_train_samp.columns:\r\n", + " assert all(df_train_samp[c] == df_bq_samp[c]), f'{c} is not the same'\r\n", + "\r\n", + " # pqet_current pulls prior_question_elapsed_time back one task_container_id for each user\r\n", + " df_bq_samp_tst = df_bq_samp[['user_id', 'task_container_id_q', 'prior_question_elapsed_time', 'pqet_current']].groupby(['user_id', 'task_container_id_q']).max()\r\n", + "\r\n", + " for user_id in user_ids_samp:\r\n", + " assert all(df_bq_samp_tst.loc[user_id].pqet_current.shift(1).iloc[1:] == df_bq_samp_tst.loc[user_id].prior_question_elapsed_time.iloc[1:])\r\n", + "\r\n", + " # ts_delta equal to timestamp from current task_container_id_q minus timestamp from prior task_container_id_q\r\n", + " df_bq_samp_tst = df_bq_samp[['user_id', 'task_container_id_q', 'timestamp', 'ts_delta']].groupby(['user_id', 'task_container_id_q']).max()\r\n", + "\r\n", + " for user_id in user_ids_samp:\r\n", + " assert all((df_bq_samp_tst.loc[user_id].timestamp - df_bq_samp_tst.loc[user_id].timestamp.shift(1)).iloc[1:] == df_bq_samp_tst.loc[user_id].ts_delta.iloc[1:])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jZhQStC4bIa_" + }, + "outputs": [], + "source": [ + "if TEST_FEATURE_ENGNEERING:\r\n", + " df_bq_samp.pqet_current.hist();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4WrlU_qZbisC" + }, + "outputs": [], + "source": [ + "if TEST_FEATURE_ENGNEERING:\r\n", + " df_bq_samp.ts_delta.hist();" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qR33_34dpD5R" + }, + "source": [ + "## Create TFRecords" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NjZunnJNqG7U" + }, + "source": [ + "We are going to create a set of TFRecords with one user per record and one fold per file. 
We are going to include the following columns as features:\r\n", + "* `user_id` - this won't get used as a feature, but is included to be able to tie back to the original data\r\n", + "* `content_id` - incremented by one to reserve 0 for the padding character\r\n", + "* `answered_correctly` - incremented by one to reserve 0 for the padding character\r\n", + "* `part`\r\n", + "* `pqet_current`\r\n", + "* `ts_delta`\r\n", + "* `tags` - already incremented by one with zeros as padding\r\n", + "* `task_container_id` - excluding lectures and already indexed to one\r\n", + "* `timestamp`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e3IhNw5C3epG" + }, + "outputs": [], + "source": [ + "def _int64_feature(value):\r\n", + " \r\n", + " if type(value) != type(list()):\r\n", + " value = [value]\r\n", + "\r\n", + " return tf.train.Feature(int64_list=tf.train.Int64List(value=value))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jBv7Qtjb3gr-" + }, + "outputs": [], + "source": [ + "def serialize_example(user_id, features):\r\n", + " \r\n", + " feature_names = ['content_id', 'answered_correctly', 'part', 'pqet_current', 'ts_delta', 'tags',\r\n", + " 'task_container_id', 'timestamp']\r\n", + " \r\n", + " feature = {'user_id': _int64_feature(user_id)}\r\n", + " \r\n", + " for i, n in enumerate(feature_names):\r\n", + " feature[n] = _int64_feature(features[i])\r\n", + "\r\n", + " return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8sG6HK_p3imX" + }, + "outputs": [], + "source": [ + "def parse_example(example):\r\n", + " \r\n", + " feature_names = {'content_id': tf.int32, 'answered_correctly': tf.int32, 'part': tf.int32,\r\n", + " 'pqet_current': tf.int32, 'ts_delta': tf.int64, 'tags': tf.int32,\r\n", + " 'task_container_id': tf.int32, 'timestamp': tf.int64}\r\n", + " \r\n", + " features = {'user_id': tf.io.FixedLenFeature([1], tf.int64)}\r\n", + " \r\n", + " for k, v in feature_names.items():\r\n", + " features[k] = tf.io.VarLenFeature(tf.int64)\r\n", + "\r\n", + " example = tf.io.parse_single_example(example, features)\r\n", + "\r\n", + " for k, v in feature_names.items():\r\n", + " example[k] = tf.cast(example[k].values, v)\r\n", + " \r\n", + " example['tags'] = tf.reshape(example['tags'], (tf.size(example['answered_correctly']), 6))\r\n", + "\r\n", + " return example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nh-44zO158H8" + }, + "outputs": [], + "source": [ + "def get_ds_tfrec_raw(folds=[0]):\r\n", + " file_pat = 'gs://{BUCKET}/tfrec/{f:02d}-*.tfrec'\r\n", + " file_pats = [file_pat.format(BUCKET=BUCKET, f=f) for f in folds]\r\n", + " options = tf.data.Options()\r\n", + "\r\n", + " ds = (tf.data.Dataset.list_files(file_pats)\r\n", + " .with_options(options)\r\n", + " .interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO)\r\n", + " .map(parse_example, num_parallel_calls=AUTO)\r\n", + " )\r\n", + " \r\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NEfOS7m2pKZp" + }, + "outputs": [], + "source": [ + "def get_df_tfrec(folds):\r\n", + " df_tfrec = get_df_query_bqs(f\"\"\"\r\n", + " SELECT fold, train.user_id, content_id + 1 content_id,\r\n", + " answered_correctly + 1 answered_correctly, part, pqet_current, ts_delta,\r\n", + " tags_array tags, task_container_id_q task_container_id, timestamp\r\n", + " FROM 
{BQ_DATASET}.train\r\n", + " JOIN {BQ_DATASET}.folds\r\n", + " ON train.user_id = folds.user_id\r\n", + " JOIN {BQ_DATASET}.questions\r\n", + " ON train.content_id = questions.question_id\r\n", + " WHERE fold IN ({(', ').join(map(str, folds))})\r\n", + " AND content_type_id = 0\r\n", + " ORDER BY user_id, timestamp, row_id\r\n", + " \"\"\",\r\n", + " dtypes=None)\r\n", + "\r\n", + " return df_tfrec" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HoYPoSKR3V5-" + }, + "outputs": [], + "source": [ + "def write_tfrecords(folds):\r\n", + " \r\n", + " df_tfrec = get_df_tfrec(folds)\r\n", + " \r\n", + " for f in folds:\r\n", + " groups_dict = (df_tfrec[df_tfrec.fold == f]\r\n", + " .groupby('user_id')\r\n", + " .apply(lambda r: (list(r['content_id'].values),\r\n", + " list(r['answered_correctly'].values),\r\n", + " list(r['part'].values),\r\n", + " list(r['pqet_current'].values.astype(np.int64)),\r\n", + " list(r['ts_delta'].values.astype(np.int64)),\r\n", + " list(np.concatenate(r['tags'].values)),\r\n", + " list(r['task_container_id'].values.astype(np.int64)),\r\n", + " list(r['timestamp'].values.astype(np.int64)),\r\n", + " ))).to_dict() \r\n", + " \r\n", + " out_path = f'gs://{BUCKET}/tfrec'\r\n", + " filename = f'{f:02d}-{len(groups_dict.keys())}.tfrec'\r\n", + " record_file = f'{out_path}/{filename}'\r\n", + "\r\n", + " with tf.io.TFRecordWriter(record_file) as writer:\r\n", + " for user_id, features in tqdm(groups_dict.items(), desc=f'Fold {f:02d}'):\r\n", + " writer.write(serialize_example(user_id, features))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "48jjQ7L9g1_M" + }, + "source": [ + "## Write TFRecords\r\n", + "\r\n", + "* Process in chunks to avoid running out of memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mpw-Nb7Dg8xL" + }, + "outputs": [], + "source": [ + "if CREATE_TFRECORDS:\r\n", + " fold_splits = np.array_split(np.arange(20), 10)\r\n", + " for folds in tqdm(fold_splits):\r\n", + " write_tfrecords(folds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Un9OMMVC4QQQ" + }, + "source": [ + "## Test TFRecords\r\n", + "\r\n", + "* Same number of users and records as in `df_folds`\r\n", + "* Values in tfrecords are the same as in original data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eIGmrsU28TJ4" + }, + "outputs": [], + "source": [ + "def test_tfrecord_folds(folds_test, n_sample=100):\r\n", + " pbar = tqdm(total=n_sample)\r\n", + " ds = get_ds_tfrec_raw(folds_test)\r\n", + " df = get_df_tfrec(folds_test)\r\n", + "\r\n", + " for b in ds.shuffle(10000).take(n_sample):\r\n", + " try:\r\n", + " for c in [c for c in df.columns if c not in ['tags', 'fold', 'user_id']]:\r\n", + " try:\r\n", + " assert all(df[df.user_id == b['user_id'].numpy()[0]][c] == b[c].numpy())\r\n", + " except:\r\n", + " print(f\"Error for user {b['user_id'].numpy()[0]}\")\r\n", + " user_tags = np.concatenate(df[df.user_id == b['user_id'].numpy()[0]].tags.values)\r\n", + " assert all(user_tags == (b['tags'].numpy().flatten()))\r\n", + " except:\r\n", + " print(f\"Error for user {b['user_id'].numpy()[0]}\")\r\n", + " finally:\r\n", + " pbar.update()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ByaOx2_23BJJ" + }, + "outputs": [], + "source": [ + "if TEST_TFRECORDS:\r\n", + " folds_test = list(range(20))\r\n", + " ds = get_ds_tfrec_raw(folds=folds_test)\r\n", + "\r\n", + " df_folds = 
get_df_query_bqs(f\"\"\"\r\n", + " SELECT *\r\n", + " FROM {BQ_DATASET}.folds\r\n", + " \"\"\",\r\n", + " dtypes=dtypes)\r\n", + "\r\n", + " user_ids = []\r\n", + " count = 0\r\n", + " for b in ds:\r\n", + " user_ids.append(b['user_id'].numpy()[0])\r\n", + " count += len(b['content_id'].numpy())\r\n", + "\r\n", + " assert len(set(user_ids)) == len(df_folds)\r\n", + " assert df_folds.record_count.sum() == count\r\n", + "\r\n", + " test_tfrecord_folds([10])\r\n", + "\r\n", + " b = next(iter(ds))\r\n", + " print(b)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OMfgCXo159d2" + }, + "source": [ + "## Dataset Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TcQ-dd1b6EKN" + }, + "outputs": [], + "source": [ + "@gin.configurable\r\n", + "def get_ds_tfrec(folds=None, max_len=None, min_len=None):\r\n", + " file_pat = 'gs://{BUCKET}/tfrec/{f:02d}-*.tfrec'\r\n", + " file_pats = [file_pat.format(BUCKET=BUCKET, f=f) for f in folds]\r\n", + " options = tf.data.Options()\r\n", + "\r\n", + " ds = (tf.data.Dataset.list_files(file_pats, shuffle=True)\r\n", + " .with_options(options)\r\n", + " .interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO)\r\n", + " .shuffle(10000)\r\n", + " .map(parse_example, num_parallel_calls=AUTO)\r\n", + " .filter(partial(filter_min_len, min_len=min_len))\r\n", + " .map(example_to_tuple, num_parallel_calls=AUTO)\r\n", + " .map(partial(trunc_seq, max_len=max_len), num_parallel_calls=AUTO)\r\n", + " .map(con_to_cat, num_parallel_calls=AUTO)\r\n", + " )\r\n", + "\r\n", + " ds = ds.repeat().prefetch(AUTO)\r\n", + " \r\n", + " def gen(generator=None):\r\n", + " del generator\r\n", + " for example in fastmath.dataset_as_numpy(ds):\r\n", + " yield example\r\n", + " \r\n", + " return gen" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "l20hHtYADfGp" + }, + "outputs": [], + "source": [ + "def filter_min_len(e, min_len):\r\n", + " return tf.size(e['content_id']) >= min_len" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2TcKGF8t70Gl" + }, + "outputs": [], + "source": [ + "def example_to_tuple(example):\r\n", + " return (example['content_id'], example['part'], example['tags'], example['task_container_id'],\r\n", + " example['answered_correctly'], example['pqet_current'], example['ts_delta'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5PjKM5Q57uS5" + }, + "outputs": [], + "source": [ + "def trunc_seq(*b, max_len=None):\r\n", + " \"\"\"Returns a sequence drawn randomly from available tokens with a max length\r\n", + " of max_len.\r\n", + " \"\"\"\r\n", + " \r\n", + " max_len = tf.constant(max_len)\r\n", + " seq_len = tf.size(b[0])\r\n", + " seq_end_min = tf.minimum(seq_len - 1, max_len)\r\n", + " seq_end = tf.maximum(max_len, tf.random.uniform((), seq_end_min, seq_len, dtype=tf.int32))\r\n", + " \r\n", + " def get_seq(m):\r\n", + " return m[seq_end-max_len:seq_end]\r\n", + " \r\n", + " return tuple(map(get_seq, b))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9aC7H12W7mRs" + }, + "outputs": [], + "source": [ + "# SAINT+ Elapsed Time = prior_question_elapsed_time and Lag Time = time_stamp_1 - timestamp_0\r\n", + "# Elapsed Time categorical - capped at 300 seconds, discrete value for each second\r\n", + "# Lag Time - discretized to minutes 0, 1, 2, 3, 4, 5, 10, 20, 30 ... 1440. 
150 discrete values.\r\n", + "\r\n", + "ts_delta_lookup = tf.concat([tf.range(6, dtype=tf.int32), tf.repeat(5, 5)], axis=0)\r\n", + "\r\n", + "cat = 10\r\n", + "while cat < 1440:\r\n", + " ts_delta_lookup = tf.concat([ts_delta_lookup, tf.repeat(cat, 10)], axis=0)\r\n", + " cat += 10\r\n", + " \r\n", + "ts_delta_lookup = tf.concat([ts_delta_lookup, [1440]], axis=0)\r\n", + "\r\n", + "def con_to_cat(*b):\r\n", + " \r\n", + " def pqet_cat(e, vocab_size=None, val_min=None, val_max=None):\r\n", + " e = tf.clip_by_value(e, val_min, val_max)\r\n", + " val_range = val_max - val_min\r\n", + " e = tf.cast((e - val_min) * (vocab_size - 1) / val_range, tf.int32)\r\n", + " return e\r\n", + " \r\n", + " def ts_delta_cat(e):\r\n", + " val_max = tf.cast(tf.reduce_max(ts_delta_lookup) * 60000, tf.float64)\r\n", + " e = tf.clip_by_value(tf.cast(e, tf.float64), 0, val_max)\r\n", + " e = tf.cast(e / 60000, tf.int32)\r\n", + " e = tf.gather(ts_delta_lookup, e)\r\n", + " return e\r\n", + " \r\n", + " pqet = pqet_cat(b[-2], vocab_size=300, val_min=0, val_max=300000)\r\n", + " ts_delta = ts_delta_cat(b[-1])\r\n", + " \r\n", + " return tuple((*b[:-2], pqet, ts_delta))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cWKX9-WHNIdJ" + }, + "source": [ + "## Metrics Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BUC43hA69diL" + }, + "outputs": [], + "source": [ + "def RocAucScore(num_thresholds=100, pos_label=2):\r\n", + " def f(y_score, y_true, weight): \r\n", + " weight = tnp.expand_dims(tnp.ravel(weight), -1)\r\n", + " \r\n", + " softmax=tl.Softmax(axis=-1)\r\n", + " y_score = tnp.ravel(softmax(y_score)[:, :, -1])\r\n", + " y_score = tnp.expand_dims(y_score, -1)\r\n", + " y_true = tnp.expand_dims(tnp.ravel(y_true) == pos_label, -1).astype(tnp.float32)\r\n", + " \r\n", + " thresholds = tnp.expand_dims(tnp.linspace(1, 0, num_thresholds), 0)\r\n", + " \r\n", + " threshold_counts = y_score > thresholds\r\n", + " \r\n", + " tps = tnp.logical_and(threshold_counts, y_true)\r\n", + " fps = tnp.logical_and(threshold_counts, tnp.logical_not(y_true))\r\n", + " \r\n", + " tps = tnp.sum(tps * weight, axis=0)\r\n", + " fps = tnp.sum(fps * weight, axis=0)\r\n", + " \r\n", + " tpr = tps / tps[-1]\r\n", + " fpr = fps / fps[-1]\r\n", + " \r\n", + " return tnp.trapz(tpr, fpr)\r\n", + " \r\n", + " return tl.Fn('RocAucScore', f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LtbnwInF9fMf" + }, + "outputs": [], + "source": [ + "metrics = {\r\n", + " 'loss': tl.WeightedCategoryCrossEntropy(),\r\n", + " 'accuracy': tl.WeightedCategoryAccuracy(),\r\n", + " 'sequence_accuracy': tl.MaskedSequenceAccuracy(),\r\n", + " 'auc_all': RocAucScore(),\r\n", + " 'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum())\r\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WA_-d9V_NN-h" + }, + "source": [ + "## Model Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QK7Kf4AM9u9P" + }, + "outputs": [], + "source": [ + "@gin.configurable\r\n", + "@tl.assert_shape('bl->b1ll')\r\n", + "def PaddingFutureMask(pad=0, block_self=False, tid=True, pad_end=False):\r\n", + " def f(x):\r\n", + " mask_pad = tnp.logical_not(tnp.equal(x, 0))[:, tnp.newaxis, tnp.newaxis, :]\r\n", + " \r\n", + " x_new = x\r\n", + " if pad_end:\r\n", + " x_new = tnp.where(tnp.equal(x, 0), tnp.max(x), x)\r\n", + " \r\n", + " if tid:\r\n", + " mask_future = x_new[:, :, tnp.newaxis] >= x_new[:, 
tnp.newaxis, :] + block_self\r\n", + " mask_future = mask_future[:, tnp.newaxis, :, :]\r\n", + " else:\r\n", + " mask_future = tnp.arange(x.shape[-1])[tnp.newaxis, tnp.newaxis, :, tnp.newaxis] \\\r\n", + " >= tnp.arange(x.shape[-1])[tnp.newaxis, :]\r\n", + " \r\n", + " return tnp.logical_and(mask_future, mask_pad)\r\n", + " \r\n", + " return tl.Fn(f'PaddingFutureMask({pad})', f)\r\n", + "\r\n", + "\r\n", + "# the only thing different here is the shape assertions to accomodate the change\r\n", + "# in mask shape from b11l to b1ll\r\n", + "\r\n", + "@tl.assert_shape('bld,b1ll->bld,b1ll')\r\n", + "@gin.configurable\r\n", + "def KTAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):\r\n", + " return tl.Serial(\r\n", + " tl.Select([0, 0, 0]),\r\n", + " tl.AttentionQKV(\r\n", + " d_feature, n_heads=n_heads, dropout=dropout, mode=mode),\r\n", + " )\r\n", + "\r\n", + "def my_add_loss_weights(generator, id_to_mask=None):\r\n", + " for example in generator:\r\n", + " weights = (example[0] != id_to_mask).astype(tnp.float32)\r\n", + " yield (*example, weights)\r\n", + "\r\n", + "@gin.configurable\r\n", + "def KTAddLossWeights(id_to_mask=0): # pylint: disable=invalid-name\r\n", + " return lambda g: my_add_loss_weights(g, id_to_mask=id_to_mask)\r\n", + "\r\n", + "def trim_tags(generator):\r\n", + " for example in generator:\r\n", + " # content_id, part, tags, tid, ac, pqet, ts_delta\r\n", + " yield (example[0], example[1], example[2][:, :, :6], example[3], example[4], example[5], example[6])\r\n", + "\r\n", + "@gin.configurable\r\n", + "def TrimTags():\r\n", + " return lambda g: trim_tags(g)\r\n", + "\r\n", + "@gin.configurable\r\n", + "def KTPositionalEncoder(max_position=10000.0, d_model=512, tid=False): \r\n", + " \"\"\"This is set up to perform standard positional encoding based on the\r\n", + " position in the sequence, but also to calculate position based on the\r\n", + " id of the task container to which the question belongs.\r\n", + " \"\"\"\r\n", + " def f(inputs):\r\n", + " # whether or not to use task_container_id or seq position\r\n", + " if tid:\r\n", + " position = tnp.expand_dims(inputs.astype(tnp.float32), -1)\r\n", + " else:\r\n", + " position = tnp.arange(inputs.shape[1])\r\n", + " \r\n", + " position = position.astype(tnp.float32)[tnp.newaxis, :, tnp.newaxis]\r\n", + "\r\n", + " i = tnp.expand_dims(tnp.arange(d_model, dtype=tnp.float32), 0)\r\n", + "\r\n", + " angles = 1 / tnp.power(max_position, (2 * (i // 2)) /\r\n", + " tnp.array(d_model, dtype=tnp.float32))\r\n", + "\r\n", + " angle_rads = position * angles\r\n", + "\r\n", + " # apply sin to even index in the array\r\n", + " sines = tnp.sin(angle_rads[:, :, 0::2])\r\n", + " # apply cos to odd index in the array\r\n", + " cosines = tnp.cos(angle_rads[:, :, 1::2])\r\n", + "\r\n", + " pos_encoding = tnp.concatenate([sines, cosines], axis=-1)\r\n", + "\r\n", + " return pos_encoding\r\n", + "\r\n", + " return tl.Fn('KTPositionalEncoder', f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Pmo6yeiXkBAQ" + }, + "outputs": [], + "source": [ + "@gin.configurable\r\n", + "def KTTransformer(d_model,\r\n", + " d_input,\r\n", + " d_part,\r\n", + " d_tags,\r\n", + " d_out,\r\n", + " d_pqet,\r\n", + " d_ts_delta,\r\n", + " d_tid,\r\n", + " embed_concat=False,\r\n", + " d_ff=2048,\r\n", + " n_encoder_layers=6,\r\n", + " n_decoder_layers=6,\r\n", + " n_heads=8,\r\n", + " max_len=2048,\r\n", + " dropout=0.1,\r\n", + " dropout_shared_axes=None,\r\n", + " mode='train',\r\n", + " 
ff_activation=tl.Relu):\r\n", + " \r\n", + " def Embedder(vocab_size, d_embed): # tokens --> vectors\r\n", + " return [\r\n", + " tl.Embedding(vocab_size, d_embed),\r\n", + " tl.Dropout(\r\n", + " rate=dropout, shared_axes=dropout_shared_axes, mode=mode),\r\n", + " ]\r\n", + "\r\n", + " # Encoder Embeddings\r\n", + " in_embedder = Embedder(*d_input)\r\n", + " part_embedder = Embedder(*d_part)\r\n", + " # Keeps the tags in the data batch tuple, but drops it if it\r\n", + " # isn't included in the embeddings.\r\n", + " if d_tags is not None:\r\n", + " tags_embedder = tl.Serial(Embedder(*d_tags), tl.Sum(axis=-2))\r\n", + " else:\r\n", + " tags_embedder = tl.Drop()\r\n", + " in_pos_encoder = KTPositionalEncoder(*d_tid)\r\n", + "\r\n", + " # Decoder Embeddings\r\n", + " out_embedder = Embedder(*d_out)\r\n", + " pqet_embedder = Embedder(*d_pqet)\r\n", + " ts_delta_embedder = Embedder(*d_ts_delta)\r\n", + " out_pos_encoder = KTPositionalEncoder(*d_tid)\r\n", + "\r\n", + " encoder_mode = 'eval' if mode == 'predict' else mode\r\n", + "\r\n", + " in_encoder = [tl.Parallel(in_embedder, part_embedder, tags_embedder, in_pos_encoder)]\r\n", + " out_encoder = [tl.Parallel(out_embedder, pqet_embedder, ts_delta_embedder, out_pos_encoder)]\r\n", + " \r\n", + " if embed_concat:\r\n", + " if d_tags is not None:\r\n", + " in_encoder += [tl.Concatenate(n_items=3), tl.Add()]\r\n", + " else:\r\n", + " in_encoder += [tl.Concatenate(n_items=2), tl.Add()]\r\n", + " out_encoder += [tl.Concatenate(n_items=3), tl.Add()]\r\n", + " else:\r\n", + " if d_tags is not None:\r\n", + " in_encoder += [tl.Add(), tl.Add(), tl.Add()]\r\n", + " else:\r\n", + " in_encoder += [tl.Add(), tl.Add()]\r\n", + " out_encoder += [tl.Add(), tl.Add(), tl.Add()]\r\n", + "\r\n", + " encoder_blocks = [\r\n", + " _KTEncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,\r\n", + " mode, ff_activation)\r\n", + " for i in range(n_encoder_layers)]\r\n", + "\r\n", + " encoder = tl.Serial(\r\n", + " in_encoder,\r\n", + " encoder_blocks,\r\n", + " tl.LayerNorm()\r\n", + " )\r\n", + "\r\n", + " encoder_decoder_blocks = [\r\n", + " _KTEncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,\r\n", + " mode, ff_activation)\r\n", + " for i in range(n_decoder_layers)]\r\n", + "\r\n", + " # output tuple - leading number is max index \r\n", + " return tl.Serial( # 7: 0:tok_e 1:tok_p 2:tok_t 3:tok_tid 4:tok_d 5:tok_pq, 6:tok_tsd 7:wts_l \r\n", + " tl.Select([0, 1, 2, 3, 3, 3, # 10: 0:tok_e 1:tok_p 2:tok_t 3:tok_tid 4:tok_tid 5: tok_tid\r\n", + " 4, 5, 6, 4]), # 6:tok_d 7:tok_pq, 8:tok_tsd 9:tok_d 10:wts_l\r\n", + "\r\n", + " # Encode.\r\n", + " tl.Parallel(\r\n", + " tl.Select([0, 1, 2, 3]),\r\n", + " PaddingFutureMask(tid=True)\r\n", + " ), # 10: tok_e tok_p tok_t tok_tid mask_combined tok_tid tok_d tok_pq tok_tsd tok_d wts_l\r\n", + " encoder, # 7: vec_e mask_combined tok_tid tok_d tok_pq tok_tsd tok_d wts_l\r\n", + " # Decode.\r\n", + " tl.Select([3, 4, 5, 2, 2, 0]), # 7: tok_d tok_pq tok_tsd tok_tid tok_tid vec_e tok_d wts_l\r\n", + " tl.Parallel(\r\n", + " tl.ShiftRight(mode=mode),\r\n", + " tl.ShiftRight(mode=mode), \r\n", + " tl.ShiftRight(mode=mode),\r\n", + " tl.ShiftRight(mode=mode),\r\n", + " tl.Serial(tl.ShiftRight(),\r\n", + " PaddingFutureMask(tid=False)),\r\n", + " ), # 7: tok_d tok_pq tok_tsd tok_tid mask_combined vec_e tok_d wts_l \r\n", + " out_encoder, # 4: vec_d mask_combined vec_e tok_d wts_l\r\n", + " encoder_decoder_blocks, # 4: vec_d mask_combined vec_e tok_d wts_l\r\n", + " tl.LayerNorm(), # 4: vec_d 
mask_combined vec_e tok_d wts_l\r\n", + "\r\n", + " # Map to output vocab.\r\n", + " tl.Select([0], n_in=3), # 3: vec_d tok_d wts_l\r\n", + " tl.Dense(d_out[0]), # vec_d .....\r\n", + " )\r\n", + "\r\n", + "\r\n", + "def _KTEncoderBlock(d_model, d_ff, n_heads,\r\n", + " dropout, dropout_shared_axes, mode, ff_activation):\r\n", + " \"\"\"Same as the default, but changes attention layer to KTAttention to \r\n", + " accept a combined padding and future mask.\r\n", + " \"\"\"\r\n", + " \r\n", + " attention = KTAttention(\r\n", + " d_model, n_heads=n_heads, dropout=dropout, mode=mode)\r\n", + "\r\n", + " feed_forward = _KTFeedForwardBlock(\r\n", + " d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)\r\n", + "\r\n", + " dropout_ = tl.Dropout(\r\n", + " rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", + "\r\n", + " return [\r\n", + " tl.Residual(\r\n", + " tl.LayerNorm(),\r\n", + " attention,\r\n", + " dropout_,\r\n", + " ),\r\n", + " tl.Residual(\r\n", + " feed_forward\r\n", + " ),\r\n", + " ]\r\n", + "\r\n", + "def _KTEncoderDecoderBlock(d_model, d_ff, n_heads,\r\n", + " dropout, dropout_shared_axes, mode, ff_activation):\r\n", + " \"\"\"Same as the default, but changes the first layer to KTAttention to \r\n", + " accept a combined padding and future mask.\r\n", + " \"\"\"\r\n", + " def _Dropout():\r\n", + " return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", + "\r\n", + " attention = KTAttention(\r\n", + " d_model, n_heads=n_heads, dropout=dropout, mode=mode)\r\n", + "\r\n", + " attention_qkv = tl.AttentionQKV(\r\n", + " d_model, n_heads=n_heads, dropout=dropout, mode=mode)\r\n", + "\r\n", + " feed_forward = _KTFeedForwardBlock(\r\n", + " d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)\r\n", + "\r\n", + " return [ # vec_d masks vec_e\r\n", + " tl.Residual(\r\n", + " tl.LayerNorm(), # vec_d ..... .....\r\n", + " attention, # vec_d ..... .....\r\n", + " _Dropout(), # vec_d ..... .....\r\n", + " ),\r\n", + " tl.Residual(\r\n", + " tl.LayerNorm(), # vec_d ..... 
.....\r\n", + " tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e\r\n", + " attention_qkv, # vec_d masks vec_e\r\n", + " _Dropout(), # vec_d masks vec_e\r\n", + " ),\r\n", + " tl.Residual(\r\n", + " feed_forward # vec_d masks vec_e\r\n", + " ),\r\n", + " ]\r\n", + "\r\n", + "def _KTFeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,\r\n", + " mode, activation):\r\n", + " \"\"\"Same as default.\r\n", + " \"\"\"\r\n", + " dropout_middle = tl.Dropout(\r\n", + " rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", + " dropout_final = tl.Dropout(\r\n", + " rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", + "\r\n", + " return [\r\n", + " tl.LayerNorm(),\r\n", + " tl.Dense(d_ff),\r\n", + " activation(),\r\n", + " dropout_middle,\r\n", + " tl.Dense(d_model),\r\n", + " dropout_final,\r\n", + " ]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RFlp9RAINR4d" + }, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5nj3MX9D97Nz" + }, + "outputs": [], + "source": [ + "# Configure hyperparameters.\r\n", + "\r\n", + "total_steps = 10000\r\n", + "\r\n", + "gin.clear_config()\r\n", + "gin.parse_config(f\"\"\"\r\n", + "import trax.layers\r\n", + "import trax.models\r\n", + "import trax.optimizers\r\n", + "import trax.data.inputs\r\n", + "import trax.supervised.trainer_lib\r\n", + "\r\n", + "# Parameters that will vary between experiments:\r\n", + "# ==============================================================================\r\n", + "# min_len = 12\r\n", + "# max_len = 64\r\n", + "# d_model = 512 # need to make sure this works with concat embeddings\r\n", + "# d_ff = 256\r\n", + "# n_encoder_layers = 2\r\n", + "# n_decoder_layers = 2\r\n", + "# n_heads = 2\r\n", + "# dropout = 0.0\r\n", + "\r\n", + "min_len = 12\r\n", + "max_len = 256\r\n", + "d_model = 512 # need to make sure this works with concat embeddings\r\n", + "d_ff = 1024\r\n", + "n_encoder_layers = 6\r\n", + "n_decoder_layers = 6\r\n", + "n_heads = 8\r\n", + "dropout = 0.1\r\n", + "\r\n", + "# Set to True to aggregate embeddings by concatenation. If set\r\n", + "# to False aggregation will be by sum.\r\n", + "embed_concat = True\r\n", + "\r\n", + "# (Vocab, depth) Uncomment to use with aggregation by concatenation.\r\n", + "d_input = (13500, 384)\r\n", + "d_part = (8, 8)\r\n", + "d_tags = (189, 120)\r\n", + "\r\n", + "# (Vocab, depth) Uncomment to use with aggregation by concatenation.\r\n", + "d_out = (3, 384)\r\n", + "d_pqet = (300, 64)\r\n", + "d_ts_delta = (150, 64)\r\n", + "\r\n", + "# Used for positional encodings if not None. 
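A quick arithmetic check on the (vocab, depth) pairs above: with embed_concat = True the encoder concatenates the content, part and tags embeddings (and the decoder its three embeddings) before adding the d_model-wide positional encoding, so each group of depths has to sum to d_model. A minimal sanity check of the values used here (illustrative, not part of the notebook):

# Depths copied from the config above; the tuple names are illustrative.
d_model = 512
encoder_depths = (384, 8, 120)    # d_input, d_part, d_tags
decoder_depths = (384, 64, 64)    # d_out, d_pqet, d_ts_delta
assert sum(encoder_depths) == d_model
assert sum(decoder_depths) == d_model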
Positional encoding based\r\n", + "# on sequence in batch if None.\r\n", + "d_tid = (10000, %d_model)\r\n", + "\r\n", + "# d_input = (13500, %d_model)\r\n", + "# d_part = (8, %d_model)\r\n", + "# d_tags = (189, %d_model)\r\n", + "# # d_tags = None\r\n", + "# d_out = (3, %d_model)\r\n", + "# d_pqet = (300, %d_model)\r\n", + "# d_ts_delta = (150, %d_model)\r\n", + "# d_tid = (10000, %d_model)\r\n", + "\r\n", + "total_steps = {total_steps}\r\n", + "\r\n", + "# Parameters for learning rate schedule:\r\n", + "# ==============================================================================\r\n", + "warmup_and_rsqrt_decay.n_warmup_steps = 3000\r\n", + "warmup_and_rsqrt_decay.max_value = 0.001\r\n", + "\r\n", + "# multifactor.constant = 0.01\r\n", + "# multifactor.factors = 'constant * linear_warmup * cosine_decay'\r\n", + "# multifactor.warmup_steps = 4000\r\n", + "# multifactor.steps_per_cycle = %total_steps\r\n", + "# multifactor.minimum = .0001\r\n", + "\r\n", + "# Parameters for Adam:\r\n", + "# ==============================================================================\r\n", + "# Adam.weight_decay_rate=0.0\r\n", + "Adam.b1 = 0.9\r\n", + "Adam.b2 = 0.999\r\n", + "Adam.eps = 1e-8\r\n", + "\r\n", + "# Parameters for input pipeline:\r\n", + "# ==============================================================================\r\n", + "get_ds_tfrec.min_len = %min_len\r\n", + "get_ds_tfrec.max_len = %max_len\r\n", + "train/get_ds_tfrec.folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]\r\n", + "eval/get_ds_tfrec.folds = [19]\r\n", + "\r\n", + "BucketByLength.boundaries = [32, 64, 128]\r\n", + "BucketByLength.batch_sizes = [512, 256, 128, 64]\r\n", + "# BucketByLength.batch_sizes = [16, 8, 4, 2]\r\n", + "\r\n", + "BucketByLength.strict_pad_on_len = True\r\n", + "\r\n", + "KTAddLossWeights.id_to_mask = 0\r\n", + "\r\n", + "train/make_additional_stream.stream = [\r\n", + " @train/get_ds_tfrec(),\r\n", + " @BucketByLength(),\r\n", + " @TrimTags(),\r\n", + " @KTAddLossWeights()\r\n", + "]\r\n", + "\r\n", + "eval/make_additional_stream.stream = [\r\n", + " @eval/get_ds_tfrec(),\r\n", + " @BucketByLength(),\r\n", + " @TrimTags(),\r\n", + " @KTAddLossWeights()\r\n", + "]\r\n", + "\r\n", + "make_inputs.train_stream = @train/make_additional_stream()\r\n", + "make_inputs.eval_stream = @eval/make_additional_stream()\r\n", + "\r\n", + "# Parameters for KTPositionalEncoder:\r\n", + "# ==============================================================================\r\n", + "KTPositionalEncoder.d_model = %d_model\r\n", + "\r\n", + "# Set to True to calculate positional encodings based on position in original\r\n", + "# full length sequence, False to be based on position in batch sequence.\r\n", + "KTPositionalEncoder.tid = False\r\n", + "\r\n", + "# Parameters for PaddingFutureMask:\r\n", + "# ==============================================================================\r\n", + "PaddingFutureMask.pad_end = False\r\n", + "\r\n", + "# Set to True to calculate future mask based on task container id (questions\r\n", + "# are delivered to users in groups identified by task_container id) or False\r\n", + "# to be based on the next question only.\r\n", + "PaddingFutureMask.tid = False\r\n", + "\r\n", + "# Parameters for KTTransformer:\r\n", + "# ==============================================================================\r\n", + "KTTransformer.d_model = %d_model\r\n", + "KTTransformer.d_input = %d_input\r\n", + "KTTransformer.d_part = %d_part\r\n", + "KTTransformer.d_tags = %d_tags\r\n", + 
"KTTransformer.d_out = %d_out\r\n", + "KTTransformer.d_pqet = %d_pqet\r\n", + "KTTransformer.d_ts_delta = %d_ts_delta\r\n", + "KTTransformer.d_tid = %d_tid\r\n", + "KTTransformer.embed_concat = %embed_concat\r\n", + "KTTransformer.d_ff = %d_ff\r\n", + "KTTransformer.n_encoder_layers = %n_encoder_layers\r\n", + "KTTransformer.n_decoder_layers = %n_decoder_layers\r\n", + "KTTransformer.n_heads = %n_heads\r\n", + "KTTransformer.dropout = %dropout\r\n", + "\r\n", + "# Parameters for train:\r\n", + "# ==============================================================================\r\n", + "train.inputs = @make_inputs\r\n", + "train.eval_frequency = 200\r\n", + "train.eval_steps = 20\r\n", + "train.checkpoints_at = {list(range(0,total_steps + 1, 2000))}\r\n", + "train.optimizer = @trax.optimizers.Adam\r\n", + "train.steps = %total_steps\r\n", + "train.model = @KTTransformer\r\n", + "train.lr_schedule_fn = @trax.supervised.lr_schedules.warmup_and_rsqrt_decay\r\n", + "\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PP46-cEAB4i_" + }, + "outputs": [], + "source": [ + "if False:\r\n", + " inputs = trax.data.inputs.make_inputs()\r\n", + " train_stream = inputs.train_stream(trax.fastmath.device_count())\r\n", + " train_eval_stream = inputs.train_eval_stream(trax.fastmath.device_count())\r\n", + " b = next(train_stream)\r\n", + " for i, m in enumerate(b):\r\n", + " print(i, m.shape)\r\n", + " b" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8OmbSWj5Cvt2" + }, + "outputs": [], + "source": [ + "if False:\r\n", + " model = KTTransformer()\r\n", + " model.init(trax.shapes.signature(b))\r\n", + " outs = model(b)\r\n", + " for i, m in enumerate(outs):\r\n", + " print(i, m.shape)\r\n", + " outs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dcHfhEEFNXXJ" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XGIfCGL9GprB" + }, + "outputs": [], + "source": [ + "run_no = 0\r\n", + "prefix = f'model_runs/{run_no:02d}'\r\n", + "output_dir = f'gs://{BUCKET}/{prefix}'\r\n", + "log_dir = output_dir[:-3]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5rhH1YNVHEPO" + }, + "outputs": [], + "source": [ + "%tensorboard --logdir $log_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wQcrZYSTGMhX" + }, + "outputs": [], + "source": [ + "if TRAIN_MODEL:\r\n", + " if False:\r\n", + " init_checkpoint = f'{output_dir}/model.pkl.gz'\r\n", + " else:\r\n", + " bucket.delete_blobs(list(bucket.list_blobs(prefix=prefix)))\r\n", + "\r\n", + " loop = trax.supervised.trainer_lib.train(output_dir, metrics=metrics)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "machine_shape": "hm", + "name": "Knowledge Tracing Transformer", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/trax/examples/README.md b/resources/examples/ipynb/README.md similarity index 100% rename from trax/examples/README.md rename to resources/examples/ipynb/README.md diff --git a/resources/examples/ipynb/models/reformer/image_generation.ipynb b/resources/examples/ipynb/models/reformer/image_generation.ipynb new file mode 100644 index 000000000..28d2487cb --- /dev/null +++ b/resources/examples/ipynb/models/reformer/image_generation.ipynb @@ -0,0 +1,412 @@ +{ + 
"nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Reformer: Image Generation", + "provenance": [], + "collapsed_sections": [ + "udDs_biH0n5U" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "udDs_biH0n5U", + "colab_type": "text" + }, + "source": [ + "#### Copyright 2020 Google LLC." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WPY-OyyM0pSs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "psnUF-8c02o_", + "colab_type": "text" + }, + "source": [ + "# Reformer: Image Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1lnRd_IoERdk", + "colab_type": "text" + }, + "source": [ + "This notebook was designed to run on TPU.\n", + "\n", + "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8PluCmWbZIpJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install JAX. 
This custom build raises the TPU timeout threshold, because the\n", + "# default limit of 2 minutes is too short for sampling very long sequences.\n", + "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", + "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", + "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", + "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", + "\n", + "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", + "import requests\n", + "import os\n", + "if 'TPU_DRIVER_MODE' not in globals():\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", + " resp = requests.post(url)\n", + " TPU_DRIVER_MODE = 1\n", + "\n", + "# The following is required to use TPU Driver as JAX's backend.\n", + "from jax.config import config\n", + "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", + "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", + "print(config.FLAGS.jax_backend_target)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yiPdBenoZwH6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", + "\n", + "from tensorflow.compat.v1.io.gfile import GFile\n", + "import gin\n", + "import os\n", + "import jax\n", + "import trax\n", + "from trax.models.beam_search import Search\n", + "from trax.supervised import inputs\n", + "\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "\n", + "from scipy.special import softmax" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yyxRk75iaAap", + "colab_type": "code", + "colab": {} + }, + "source": [ + "%matplotlib inline\n", + "from matplotlib import pyplot as plt" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FQ89jHCYfhpg" + }, + "source": [ + "## Load example data and model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qBvuw2h85WXE", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Normally we train on the full imagenet64 training set, which is quite large so\n", + "# we won't be loading it from this notebook. 
Instead, let's just load a few PNG\n", + "# images to use in our data pipeline.\n", + "DATA = []\n", + "for i in range(8):\n", + " img = plt.imread(GFile('gs://trax-ml/reformer/img{}.png'.format(i), 'rb'))\n", + " # Convert from RGBA floating-point to RGB integer representation.\n", + " img = np.asarray(img[:, :, :3] * 255, dtype=np.int32)\n", + " DATA.append(img)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "oBZh0Q2UEiaB", + "colab_type": "code", + "outputId": "d5adcac0-6f76-4c56-e6ef-74becaca87be", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 130 + } + }, + "source": [ + "# We can examine one of the images to make sure we've loaded it correctly.\n", + "plt.figure(figsize=(1.5, 1.5))\n", + "plt.axis('off')\n", + "plt.imshow(DATA[0])" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 5 + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAF8AAABfCAYAAACOTBv1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO29eaxk2X3f9znn3KXq1v72pdfp6Vk4\nC3tmyBluoklKpixSlk0pcbzJQhIkkaXAiWUpSOTEie0IiOEkBmwkCBAHhg1FSGzSlkjKDCVRpLgP\nZzjDmeF0z0z36/3129+rV3vde885+eN3q17T0jQ9DQQtIH2ABrrq3brL7/zOb/n+vud3lfee++Pe\nDH2vb+D/z+O+8O/huC/8ezjuC/8ejvvCv4cjuNMf/4+/9Re9wwHgnEZpj3MK4yRCyrVHe4V3coz3\n4HAYZ7HFvGqlOH/+Im988xWSigFAeYdTClNcxwIKj0eR2+JLL/+0kfMoLV8EPiBTclA7s6RDeODX\nPwtAZTlm41aXkw8sYm9clntkQDL4PsumDUCvvU8zewNIcPkYgDTz9IYZaVexsyXPEgaas6djMiu/\nG3RLHB7mHF6LKPcrcq5RjjKKB463ADA5ZB6+8p03+NDTpwCoJyF/+/98Xr1j4Vtl8E5+pxTgFRpP\n8RXay39cEa56QKGwKkBNJkQ5bJ4TBIDyhbAVyikwvjgPODTGO1Ivwg60xzqPLc4TBeCsJjOWsJih\nuoXmv/gXtMMhAMN2Sm25wnD9IsuJnLvX26MZ3CLubcr95BuMDkZcvbHF7uZAvhs/wGzyDGdOP0Zr\nVp5tc32DnUspYeMlAPYP1mlv1bi53uZULRKZaMVBd8wfvHQLgDj0rMyXeO7xFZK4UDQTvq1875ud\nezjuqPmB9+SFlis8hRIWy0A033uHLj5775jkbNN1pjRZnlFYjeKiigyPniZ4Bo9j7DRm8kvlMcqQ\nOdFyoxUOR+QduSg6x/7Zb3HDtinbKgBltUV9Y4t6Y4Z6+n0A0rZmPILdXgpAKQh48flNnnv4V/n5\nH38GgPm5OplJuHZzm+u3RIvPnDnF5vYup1Y/AcC1Nz7Dh37iU1y9dInf/eK/AmD7cBtGKXYogun0\nNWk6IptJ6I7lu3oluzvh57cJ0XtV2OWj75xy8mFih9ByjPfTCVIefGpxSmw7gHeTvxezqRzeycFa\nTa4HWnkCM/lCodCk1x0P/eN/BMCVwFH3muXahkzQ/qs0yOm+uYaviK1uGkdS2sLq7wIQph/nv/6r\nf4/V+ZNcu/o1AP7mf/s5fvEX/iK7Gylzhd2pNpt465hZmAHgodM/z/rFN3juxz7CtWuvifDOH5IR\nk/VFwN1xSr/f59Z6B63E7Ayz/t0JXwFq4nDRKOVxqImpBpQIuvhsvMK6XFbE5BCtyNOcQIN1E8Eq\ncuuJYrl87hxGabz15MVq0AZy5wkKr+xx0LHM/Dd/i/yxh+SYzZs08gvosTx8KT9AVc8ys7JP3O/I\nHY5eJ7F/gp/46G/Ic+SQlGP2Ll/BeLn+uSdO4dIRTz55ju+/dQmA2TDi+LFVgqgIHDwsnz5OHkac\nfebDAFx68wXqcRVdEQWZVR5nZxinOWkmqpal47sTfiFSeQjlUV6jsfjCKYoThomtcYDXCuudOFTA\nKMhthlIQFEqcA9oolJ9ESWKTPBAUBzmnQHnc5Fo+wycJJz/0LJ2OCLYZ3KLVvUqpslJM2AkWVYc0\nfYXlkWj+2XzAuHeF73z12wAsLZ7kyuWL1GoVlhdOA/CpP/soPqnzjedfZHHhGACDfofXXnyen/0P\n/wMAtq9fZn17l+6Vq+hcTFh/aOiPB0zwsdAYSqEmCkNKsTha6ytvK9n7Dvcejh+q+ZMwUikxMUop\nXBEyKjTgsEc+Eu8UGo2f2m5PlmcozXSqIw8aP/UBaIVRkkMExbnHTkyP1sXquAFP/M+/xltlg+ls\nAdDK9ghqyygt2rU8/h3ePdjg0VKXqHB0eZrz5sENLp8Xmx+VyqycOkZmHSqR+HxtZ5evffUL/Kmf\n+tPkA/HmX/nil/jYJ3+My1evAPC3/+avknvNw+96FxUt97jQ8OggYmJ3rfUMh552b0xWRCeTYOSu\nhK8mkQy+iNPVbcvF45VDTeJ9BQaFxaPcZIIUfpjKxEyOm7jtyXJF4RUY5cTcMIluQNtcHuJnPsXN\nH3uOK9fXWRrK71qt4/TJWfaHADxCxLl6H22H5L3i/sc5c5Ut3rv0IAAf/uk/iwlCxsMeX/29LwPQ\nnKvziU/8BLtbGyS1BID/9Jd+gd/67Of5L//aXwNg9vgCS4sn2b5ykbnSjtxTnmOUIVAiRm0slYqh\nisarifmchIjvUPhOia0H8A6MEoFMbJxxFtC44kLGWazyGCt2XSYPxi4j9RBP
bLxSZNZL4oWsLu81\nTmsmWuScRQUB2TU55uFf+HmujFIeW1omUiKg9OB1FnWVObsOwJOsobMe6cCQ9cUut/cq/N5bi1wu\nvQLA73znl1mYn2fY7tJNRwA8+4EP8MwHP8gHVj9CurcNwE/+1CepJiXOPfu03EA+5NRixHh3DWdl\nzaqghPMON9FGb0B5PBr9b4Tlf9S4b/Pv4bhzkoWaLhsBFo7sP4AzQaEFogkOCU8z5THFzDtvSccD\nkuAozs8dxFozwY0AXK5QgcYVq0Mph9rJqPwXvwJA7dy7SK68SKOySHtX7PIxdYYkfY1nh2sAlPw2\ndqQhs1NT0EsjvnztOH/nf5HznDpzCus8SoEJi4jEOrTN6e9u8iu/8ssAtJIaOSFnapJ0NQPFW5df\npJwYFOJjPB7t7TRqQ2kU4J0jKBIW696+UnhH4XtAKTP97HyRYN3mTJUWkyQCU1jnUMpNbZ73Dj8C\nHwVoJk4IlJGUDMQHBKFnaC2mCC197jBPvpvV/+wXAfj25i3y9pDWeMQwbopgwwEPre9yxv42IFBR\nThN8D1uEg7+/NsfH//zPcuLBUwCM05RARzjvyYbDqRC8tfzK3/hVPvcb/zcASw+e4sMf/hgXXrwO\nQDI/wOsyJmCaG3qX4ZxCqeK5AgUeBqOcJBLDO1t5e+Ny5yRLqamgrRMYQcx9YfPw4BR+ik/maAzW\nH4FvzgssYa2/zchJLjDJZjWKwEseYIoHSbdg4X/7O7StaM78wT7naHJzdZ6VfXF4ycEOy/oaRYBC\nIzLYPCNwKU6LVn/7fIc/96eb06hDEeByh/MWClQzKpf4n/7+3+dLX/gsKw8/DEC5UuPmzavMnX4c\ngK3rXyQbDrAGOkXeVE7m0cGRWQ8kFaRSDklzkcm1g7uNdiZJFBAojXUOoxSTONLi8UodZcFOS/Tj\nQE2QTu/RFlykpriN8x6fa8KgMEReMfZOVs4N+e7c//D36K4ssteTsKWysUtt6QS0D9nfFKd4OtxD\nlxVqAp8oSznqieMO5cuFxohKkpBnoolZmmO0IR2PSALRhm989Wv8r3/312iePIErnHCne0iiF3ng\nKUngWmf/CuW4TKNeZ2dTENIL3/wNRuMIncu5U6VRoSYOA6KSCD+I3l689x3uPRw/BNtR0zVlXSr4\nCx6vJniHwgo6Lz/w4ApzMvEzucvILSTakbmjwkiExvsj6MKiCDOL/VMfA2D9Q8+Rr7RYvXJVTt1s\nsjcYkAURp5uCwy/tvsAp/30CNS/3Y/bETzlHruoAvLg5h/30Z3j47LsAaM7P0e+PMHbMMJOb/K9+\n+ZdoHFshIMBq0WITKcaDIV//g28A8N6PfoSPPvEMLz7/LQ5vXQCgXK8RmQRd2E+nFc46RsNsaq6d\nv0uzI1anALp0UAhUkEs5QEzFtJKFBa9wePQkAsodGrFUxSon9QqtISvu0GuBivub8PD/+Gvy3fiA\nE99fp312DoDLVy8SLZ0mHd4kOhQj3/RDgrCO82IabP51jDboCqhDuadGvcJL332Jv/HXxXF/8Ec+\njDIljh9b5dqVNwC4dfMyQZwQGD8NFBSek48eZ3FOJrZVCvjcP/1HmPwaSdgQ4UVV0LpIGgEnGW2S\nmCnSO5HfOxa+vw32VV6jdA5YnJvYc4EcpuGod3ibg/PkxepIxzmDQzAlRxSa6aRmgpPKhWyAu5HS\n+Gf/hH0lWr39ygXU4x9gqyP2NQkStjsd7CjmrVQmZHZ8nGbnJrOVr8v9lMv4NMX2mab3T54J+dZa\nxN6+AG2//dnPkeUjtnd2prHv7EyTU0shEUp8GlApa5ZbfdKewNWvXtjCK8PqwhLuKNyTRL2IfjRS\n3XKFXORh7zLa8d7jVZGruhDwAjNPYnjtMCYiKpfk4oFBB2WiJKFaEwEd7G/y8f+4ypULF7jxhsC1\ncQjKmClG5Dsp0S/95zy6tMjogmSrs9XjKDKWOhK1DCslgjwlz9rEsTjTNg121AlS+6aclxyvPWSg\njTzaRx8f8fyVEO8LiFcHuCigXmmRu8JT+5hTczmRzbGTe/I53b1rAjAh8XoQG7x3KHWE6lrv0BPX\nqTzDUcpSyzBM5bthOsn1//C473Dv4fghoabBhDEAppRgggAdRuigLH/WHoXBFlplrccEHu8DTCin\n9qbM6cfOcfrRJ+jvS7b45X/1WQYH+5hQ5r63D3/moz/C9nyJN8Xq8J4PvZ/zl7/NfGkRgD0/oDzy\n+Laj48TmX2s7Hqts0S5LLO663yMuGyIDg26h6eOI5QXH9r6UGo23jK3HGTUNP8slQ57nGHWbffaA\nNtMCvs0dldji3BiFxI/OaLTzR6bAK3QQcH1zTFIU8FfmkrsTfn35QUlFEbvmvJXkpBjWOtBqavMM\nKdlAk9shm/2u3KDNKUUVUueYhOO9zj6BiXAF+LX4V/99vvPUA3QurlFZEef52sYGcS/moC02N7CW\nuBJjqxnNkRRTZs8sUt02dGJxiuw1aLBLWFbUyjJBy6MxT52q86/b8jkKK+h+Sqg12YTeUqqg1Bjl\nBZEF0GgCHGmBlztrGWdD9vdyjBYNiZKYUmDQgZhGrUNCownrCXlRtruyeZcO11qLK8IojcdZi7Vj\nUitCw2lwYHRRDvQ5SivKcYh1ReLhIXdjXArpKCvOC0GQkkqiiv+Zv8S181doLc4xzGVVNXb36Xzn\n2/T3DwA4e+YsG7ZOTERpRlbDYTbgm7P/Hu8ZSUFbVR9nvrKD1Y5SLOHgI6e2+f52ix/9UA2A9f2U\nV75rCSOYK4sGzwYdYlJAHUHovqg754XwvCcMAsplBb7AhHLoeVAF7B0HniiIQXnCovZsjtCZPzTu\n2/x7OO6o+YP+cPp/rQ2B1linCEPRTo0jzVLGmRQzlI8wYUQ6Gk39QKgTgrhEmvemKyQuVbAbfar/\n/X8nJz//XcLHniJVM5T6Eu3s7dwiWqmjHjgJwG5jCZPtk4YV9iNRp9JoDCj20+MAPDvfQ6ffQY0H\nZKFEYGF5j0899TJr7WcBeP1YTKmmOdgfsxLKCjZX98gxKO3QhQ11WmF0wMRY6rBENYnQpQCjZMUY\nZcBP0Ray3NEdDX/gu+AO+MKdIeXAoApq3sHWJuVGgyCI8Lnc4Mh60CFjLxfQTpH2utTiEma63sbY\nPEahyDPBTfprfUo/9xfQD0kB+8yZIZeCOq3OW9hMHjZspozcLP0dEdDmy6/TeO8DVDo3sPtC59Bl\nRWheYa6yC0C+8wUy06SaDHCDIskzitJsyEOzki+cGW3xZPkMaTrim1fkvm84h4kMJkiICmGpwBDq\ngL6Tey77EXG5KJYUyKvFSwI5wb+MolqO8TiUE9FOWAzvWPhprhn3xebur62x+vRzpGnKqLCLoXZo\nH1IORMs8llK5iVEBFBpjswxPiveW7p4IqfTX/yNqn/hxGscEF99MQwYuxdiArMDY48iQbbfxuayq\n5LGTxMaS12bo2uK7/pCazqkENwE
46DdYXszZS+eZj8ShdHstSjom7Ug221iBBxe/x1B/mFdfvQqA\nQWMAYxS6iNJ0WBLuqBfhG6Uko0dPgUSLx2s1FT5eSbkUmCSnd2AL3ln4h70e/bY86DhqEEYxzmpU\n4d3DwACOvED1tMtxNifNh2QFFBxHISooQ9Zm8MRzctyf+AjxMGSzJFFKMNilsn4d62Hu5Am5+MZr\n+GRMqkXLoyTkYGdMFEK9qNENjeayOsFuXyb1TLfGI0GfWrZBQ2rjxIsOZxyRFlOZ5ylYg+q+QJhL\nIuiM4EsedVTM8RLJ5dkEXgGtFdbpKRSuvEPfFszkCJzgUUd81jvI977DvYfjjpqflOt853vfA6Bq\nSjyMRRkwrqBWjyF3ClOogjcx1mY4DFHhFLPcYvtDvq5CnvtLPwfAIWOO54pLh4LVl3yf8twA26lS\nLRhie84wCldoHJewMk23mKtCevU1RolUslr+IpkrMYjkmFcrs1w+zHhP7GmNBNVcMtch9ISzslrT\n60OM9lQjw1ZbTIpyIcobKNgXIPUI4/QUgggMBWXRC90RpujtBLdU0zKfnxCypxDKOxb+Gzc3WN8v\n6Nglx2CsiEsRuRInqHOI9JHDGac5oAmjEJfJBFnn+M7738/euaf59YN9AD4wv8C4v89MqUhOQo0e\ntFg/UKg14eQEgwrdEDQyQXNhD/SIuBFQNSKQzD9AMF7HjiQLP6Y9zXLE74zeR70rcf5C5xIqSEHm\nAm0g23WMzYATDcl6d7oxQRgThjHKFA5XCaVxkuGWjEZ5gybFTUqrhamawO7eiRly6ojRqt1dAms3\ndnrs9OTBNtZ3+dSPOlym8boAloISqIy8EHSII3Cefn/MtRlJq7daq7jHz+HeWuPRZKIOnkvZkOG2\nZK8zpsTC7CKPrGxz+Xe/CUDr7DniGpTCglfvcvygj9kbE8fih3quykJ+QHcok/Hs6SF5ZZlu+xhf\nqT0BwPdKfxkXlliKRWHOlL/GA+PfRKURplhlXml0EGB8gFbhVLCpzadFf6U93jtQ+kibvSroR0fE\nAqc8yukpNWDiQ/6ocd/m38PxQzT/kPmmLMPz+7DXSWk1ckxRMNfjDpn1pCXR8n6rwjYD2t1Dbg0F\n/1h99t1cOf8qS9mABx//KADf2LrFUgBLi6sArAcxmUs5mRp+9Kf/PABfu/ga49xSGYldHt+6SPvC\nTZ4LDlheEU3vJE0eXC2xPnpAzpNU2LgJ26niMJXVMVpoYULLsPIIAK+WHyVY+nf5udf/AZ2XvyPP\noRO8AmscalIQ9prM5tNsySiH0x68meJoaIVyfppQOS+lE40XwA2mhLJ3LvxbXRYKvvpBJ+PSlZu8\n772PkhbbctaXZthujjksWADd7as4X6WjqyycOiO/23yLeVMi+cCH+Obr50WQzRbxKGR/LODb8Xib\nNKpQXl3lextX5bnoEW91p0I8Vq8QzVT5Sn6c+UVRiLnqPDeXHqJXOEU7PGR8RhMrSxKLQnQGPWpp\nzvUNKbrXW7OUlOI3z/4VVr56Ua7V0Wh0kbGqiVyxmZ9WqZQuODKe6fYm7cTm3wZqisA52p/A28v+\nh6CaIdMMN6qGnFhokKcpX3zsLADGDYn31okLwr6ZO015XCNqzTGel/i8PNylfPxdjLfXqV+Vhw1V\nyPGzdfq6gIbdAn1dQu9eZy4pqv7xPLulPmpWoIM03aR0pkL1Sk7FiELoqIkfHOIKfr4qNYj8EGMM\nfiAKUUnm2d3dZaEqSKhPIUtTuuUmN899XM7z+S+hUDivMbcZ4pG1mCnKqdAukOJVcYxTRVI1JQ4f\nlQ2nROE7yPfO8EKs0VqWfaV3QFKqMAxiakUE0N7eJjCLjPblwQbff5m9Rx+m2tHoumSvo8VTDLIR\npfXrtA5Ei+vPvYduNKS7K5pfCQ7IMweNJd54Qbbz+JMtllaWCdMCvlU1Bm7A4hPzuIHs9qiYLn7Y\nx0VSUx2T4rTBWodJJOsedIfM4qmMpYy4HyU0SyGmv44992MA2M98Flepo7yfbgD0GrLMipMFlDbk\nHiEFTzZwFBK30xRXTfclTN3sHXpb3He493DcUfOz4YBwVRKYxf/kk3z+yWPsZT2a1wQnaV68yTAs\nkRYAWfz4Y8zbnHQ5Io1F87i5waOrc6SDEVfPyHaeua0N0uEGw/ISAPrkAstpTvviGtFpccKDSoVR\n74BRYTQjbYjLIVWV4aoSn6u4ycAobGUBAJulKJtRDkP2C/JTvrtFuLLC7ljw/Pl8n3Sk6FeWWSo2\njeyXa8KYw0GxVcjhSa0jLBJI5RQGNwXVQLYqOXdkYhRCOvC3cVBvP/4dCb9WqxKcEPsaxHXmtnss\nvnWd/p6Yi3h2lujMaezrwmccPHaMwbFjqDyiWxIsJZqJWV/fZen0aT5s3wJg1OmyveXQPUmEfLXO\ndmrpzSzTjMXHJDnkPqBZFaGFY41xJTZtwFIsDndvOECX68ThxAhDbiO8tjQKa5vPNaA7JE0kX9mt\nL2Nu7VHVkFkBDX1UQpOjVHxkv60iTy1xgc5aJRs/UFN/i/KglMYUpsUqfjDlhaMS4zsV/vb6Fvam\noIPReJs/87Of4iqazceEz+jrdRpXXuHwOSE6NecVw+0+1g9Z6Iitbr+5R/PkAg+GPXqHYndt1sCV\n21NK3cXc0lxappZ5jBcnPDPq09El9ouKWDkqUXcZy9mQdsEIyOeWcSNHfyQPfywO0GmfnazEwUgm\nZDaMsblnqTjvMW154UsvE/30jxMuS/hZXj6DX7+Ei/wUKslzR+Yd5cIDey9O1/rbmDheFZPlp8eg\nJp8mvJ37SdYfy3HnGq4xjArayYPvepxxFBGfmefcoIiArr7Etfd9gsaizO7++VvYSo3l1XkSL3H1\nMOgzaB1jbf8mVSNw8aX9DnpmBhufA6C1sMxcbxNVbdHrTgrWcn1yCSMrYZnD1FIKI8bFuk96bfoq\nIayIadpQHpdoWtmQaFAU8GsxvdzQmRPf1c0Ve/GIoHOJWlf8i4/LgEQ6qgibc+dRqKM9YVMO3m1w\ngpbI53a6pBeC/pQuqe4Q6N9R+ItPP4hqSxi5WK8TDm4xPKjxWksepPTsaWZblv3LYpqau9vw2GlU\n3dO3kuRUT1fJ1tbo1xN6+wKS7Y0dw3SGxpL4k4Vxn3G5SXXQ5cGK2PN+rvEuJI4L5ttgwIEN0eWY\ngn1NyQSUAArbHVFFdTP6lIiKPKNuLIclTb8rZsfbLnNnKxzOPETckGMqvS0iZ3HkWCcnH+cWr/yU\nWm4oGnS4I2FOMJ7ppkEv/1dKTTF/9/aFrDsL/+mVKsvPSlXic40neKU2Q0yflYJ/k7814HquODUW\nLT/2yZ/gYKaFjmKGm1cBWM1GZLVZRkNH51BWzCOVCv0TZ2kUe2WbaYdtX8VVEyb7fkvVFicV1Atm\nwM64z/FQkZYr6KEoxLLR7I/HDJ2ELW1lOWlGbGSarNihuAssNg3HuwLQrQUNFpIHGISO+UC0
unv5\nJr6mcC6fopp5Jpu5j1obUBjzfMpY05ZiP9YPQglCp7xDavtvI/z0mSf43KKEh/F+m+buLZLEc1jM\n+G6Ws6hTlv4d2SisGoaNcYljt26QZ/IQ6+Um+QjKgz5JS3D4w/kVtjv7JDW5vFERD+Qjrm4dcLMI\nIxtJhUGqOCzg47g0Q9V3yYY9fFUqUONsQKDH1PsCVadJnQNVIlcWWyCvjWxEd2i4ZZbloWoxm1rR\nCebYX5MVWxqvQfVh2TlZmAtJnCy6CD29mtACdUGaBK8Ft7fTXhTgnCALfspvut915I/luKPmv9TO\nCC7+gRxYKTN6eIWg22WwXYRfKwtUdjpcjiWmDza20Yc77Ngyh3VJsip7HWpJBe89dlaSKhM6ymPL\nWipaMWMCwhBqWtEcil8IUk1l1MfGshI65RLlYJmIjL2BHLNYqdAfj8niyQY1RVaq0KLHQSoqvNSs\nkhvFYVOuHd66TN5cJsoGuOfl2XQ4J30enEcVxOAsywt8nuLcmsBDpuQTFHvSUNO4XzaJC/Lp3aTg\ncpfsBWZalMsioFKtQd4/JLpwjfTUowCcuXmdlZOrRF5ArKBaY7NSpdvfJXhdGMm1dz9Ca5xzo7HI\nXiSJzrn2derecGMkmMxNDccTQ2YzKBflv1LMqFRlPpGJ3tzdpDcYMAyr1Iud471uDxUHLJcKjlA/\n48rBmNmlEp3vS1H90tIT7O3sYiUgotJqsJXH1OszJBvCEbImIPQe78AW6eooy2TTw1SyhcNV4Ow0\nlBGaSGFAPEo2ens//d1kA8g7Fv5MluOrRdTS30Ot3eDazpCTg9cBmH/30yx+4GFu9iSSaA/GMOxQ\nXtuk8ZyEkWzvsD5S7C1XeeSa/K5SC8mTGk9bcYIXXYMbOxYbxcRGNG9kDLrvuVrEuk3nqJfFHYeN\nwsGu71M9fpor+1IRO39rB3vmNGFnh9GsrMberVvUg5yDDcnCcxtRmlOUD/sMnv+/AEhmTh7R1QvB\nWusJOQoVZUfOJNmSMdmT42/ffaKcAHQTX3E/yfrjOe4MrM3N0EolWTGf/l0WZyCbOUd1VTD2+Jl3\n8aby5J2CAXywi9rvkz7+EFcvijb2nzzD2a1rvG/vdbpV4emocsB+rgiLDRTLeweYHPZ0SLNolVLL\nU7TqkRoBzWqNBczhHldR1NdkM0QyO8+t3S3WhwIxN2cquJ0txlkP3RIb3wgNajzgsSVJ8C6+fINy\ndIB5/XWuS3mByqkBS7M1yvWjBnu580SRm2LDmiMG8+3FcadzVNHswLkAlMNhyJW77ci7EL4dD6i8\nJQjmzKkGW2ffT32oOFiQZb83PGC2nVK5LvZ1YDT9uUUaL3yP8H3CjWy4MeWyJjZldvpy82utBcrj\nAde6EvcvRRW8GRPYHkMnIWoQheg44eDKFQA6K/P0rCHt7DGzIILdc2P2PdSLsDIoRSSHB7BynO2i\nGliqBNSbi1y7KPc4zCEeG07/P/+cF8vFc3ZHDHo51caYWk2e7fqNXeZn65Tn40Ia0uDD3VY8QaUo\nb44g+4LNYL0/2vJ0tzvQZ9sDDgNxgPaBj9IbDll76glmFuSuq7t7tF9ap3FCNHpnt81iZOGRM+wX\n3fXy3OBKLcL2iNlCsP7KNVZnI0aZMAp6yQJV3ed4SdMbykrbzeukozYLzSLD1AOSxhytSs4rI3Hw\nw7hE4By1omVkK8+4UK0y4z2HRTFlM8uor/fpdguWm+vxjJ6nVeqyelLue2OzTeYSBt2MdlsmqTMY\ns9HeZe2m3PPqXMLyXImkHFuFgA8AAA7TSURBVEwZCd4Lb2ZScHFMGkIddT67a82PUvAFNXDv5DE2\nkjorowGlL79aCKhMOj+L78vFK6dXKe9epufniS6Kg3v6RImNNKZ3fZ/qKbncvB2QVGepdUTQpaRC\nrBS622Zc3O3YdcjzEenyKRGGtZyNyjA4pN8X4TeCEuXITCHmvOdolsoc1qq0ilU22uwx7GbTaCe3\nLVZvdVg89gB5KsvjPU8+wc5hj/Nv3GR/Q4TvrCMOK7iiicT6wYD17UMq1YCVlijkTLOECW4jTVmF\n1x7rHUdNiN7erd53uPdw3FHzt06dpjMr4Fd964DTVrH1jecZTeiC73k/s+Mcl8jqCK5dovLQInZk\nCbV4rsbcPOUbWxyszOJ7ggkFMytcdGVqWjRvr9NmPlKouMlhYYoqWUqjkjAeyOeRj9hKUuilHBYF\njvfPNOmMumwVzex0r0s8u8RONIP69vMArJaXua4zcoHgyMclzvTf4PqNm5giOTRhxKljxzl94ji7\nbVmN59+6ySuvXiCw8rtyHJOakNHAcaknmFR+0zLbqLJcEMSatRBw0qiy0Pj8DqSpO2M7RtPcl2Wo\nN9Y5vHyZaK6CP/6YTIhSDPM+1V1BFZsPHGMjConckBOrckPXtlLo7DPyJepFUT10OcHhLvWCRt6P\nG2weDrF+SK1eUP+WZti7eovasizfza0N6uESb7S7nC0il26e09VQLwouB60mb3YtJ8cHXK+K0sz0\n+5iyolN0Vjq5vofb/z2GvQOa85OMO8R78M4zW5f7/sizD/O+9zzIjetCX3z1/BX2d7ZJoioTvGZk\nPQeHfXb3Cs5poJmdr7HQLFOOi/Lj3XI14zTHviaJ0ZUAji20qOsydlYclb25RjPukS5KyJj2PKXZ\nEnkfhiW5aG3jRcauykFnxPqqUPjifEQjH+IK+nWSHeArdUoaRonACbFO6JiY8Vg0/+FGwog+fZdC\nkQjtpCmnyxV2vEy+6Q3YHVXIzl8iOSHZc7djyUoRUUlAvTPBW2xcuYTWRzvntfJFkyg13QCYeWk3\n+eBp2aD30OkT7B4ccP7SDd54/c1CeIYoiBnriZbDre091jc09UT80Px87e6Ev/3iC1On8ODWPkmp\nTudjT+NelSU9v9xkuLTMY1uyVBs31vjm4Bid62ucft+TAPRHDlsKaR6bRRU7C7U2uCjnVtGua5xb\nas2I0miIHohgX965TmlukUFB+WiEit3RCB0GbGSiaatpnW3XoVGYoes2ZvbaHj5JaBdtWXo2JSrP\n0hnLMadKu+x2tonD2tTjOTSRNnjnpECC7LJRRYEcwKmcVqvOh599kg++R/o4nF9b5/z5y2xvi3WI\n4wRtyozHnvZIlGb3xt7byve+w72H446aby5coBXJ8jl8+Al6T72b0td/n+GymJ2tlSVObHW5PF/s\n8Pjffx375z5CZWWFziQPacyztdnl9NIyUUe27wz1LPNBjC9KlINqndF4TGxzMi3EqjTQDJWjVmDl\nN7yj7y0HI095UZbyMNZ093foFTa4eyNjMa2xX4OsKyfv1iMCp3lsIBo4XruIVgFRqUpcoKFhEEm3\nFK0whYO0SssWoCnXUqopzjtU0RLgiUdO8tSjJ6Y2/9W3rvD6y28SmpCFgr0xuluHm7z4KtmflEbO\nwbvfzfDLn4dHTlOfFbvcczmDvUOyVJZd/4PnqM3
NUentcLMjNjfvj6lGMUp7xrkIKauUGQWKdlFw\n2Oz2KUeGUQZJSRZjpV6hZ9y0++Z+WGJ3b8Cwt8d4QxRif2mVVe/ovS6VtNJggf5ii8P2LUZVEVC/\nnBAN4diBZOrp61/GRHXKlTqlAjQMAiM9/b3CTfr1e4/CQrHZz3uL1j9oKJzzZECriPs/9v6n+ZFn\nH+etq+u89LKgup1bu3cn/N7cKvFHPgDA4NZ5Sk8+RrWqOVgSbEf969+h/+gjZBuCTkbvfZK4f5Vm\nM6K9d00e3kPQWmYvzdGJTIga9Hk5mMUMRWiVTHPy5CJ+0GW/4OC4KCYa9YgKcKVaSvjulVcIZ+v4\nolYQH94imatx4YpEJKWlId3DPmHF4iIRbITmPQsBM9+9CoDFEEclojhChyJY6x2hNqA0pphsoQ1G\nxUsTZEe69zlOKbw96mxulZqSpqzNMDrgsQeP8eRDIqNbe4dvK9/7Nv8ejjtqfvVD76OrBTH0SYm8\nHjHeOkTPS3Sjv/ES2bNP0esU6KRJiZKInp4jjCXxMaZOHA/ZPwhZrYnGxPN1rrx5DWxBZHr8XWzd\nuk6yuMqw6FlQH2UMcku1qPu+9dobuASyZoVSkdCX6yX6FzbIC7TxgCHdSkhSbhIV52lozanf/Me4\n81+Qh4pqhKUKQTkhKDZle2+lV8JtVSqtZWvrtFOUUgTa4JzHmMnvJlx8OcRrieudYxolLTdbdyf8\nwXILWyzDWqgY5DmkKZXrYlIWZ2e42j8kiIpKUjqgmcxwbXtIXLAAqM1QiXdx5YgqEv9pEzEkl8YZ\nQCOuc33s0HFMrYCH8+4eG8kMvatS5O5lfcZJDZMsYBE2XLNcZudqG1/8ZhDUqA0d1u/hTkg38Pde\n+zbx938LXxIFCeOYWr1JqVw+4lhajTEBjnQqEu89yvjpq0nwORaDVky3Remiu+Lt6NmkH8/ET9s7\nsJTvLPzFhKzYVZhf3ufhfI/umUfZuSJRy+Hx4/ir2wwWC45MnNDOHd3UkxY9iW/s7dKYy0kqMcMJ\njz+HQTdlYUG04nBtnXB2Dpdm6MKetqqKfH+DjZHYzEE5glqdmc+/wtonJc6ONofYPYUKJRGq7Wvy\nlSHJ8gLvtzJpK1/4JxDPYoouV0mjRVxOUJhpn0+MJFzGG9l9AtOXNUxJTzoi955A+yPat5cVMtk0\nh5IcRpLAH8T+37HwjQ7onRct18tNtj/9DdJnnqVVvOhlpdNh71vfYvTTPwPA1tIit8YDZlSX/YI/\nGdkRo7RGkOSoWELErd6ITm9IswgZ3/znn2b+F38OM3SMJhyYpMXG1ov0Z0Rje5WEmS9c52p4ktk1\nyWjDzNLWEd2KhIy2PqDZmqHpco594Z/KM9gBPkxozgo7rbawgjEB1uXT9pNaadBidoy/DY38gWK5\nI/TgpGR+2yFHxCqlFBZHYJh2a8G9vfDvO9x7OO4MrO3s00qLziMvb7Nz7ifxlw9ZWpPexbvnX2f0\nF/4y5aI43X7oFOVuD6/BFVxJpXNcskJSrmCKrGpjc4M532dcnLt2fIa9L1+k9MgcutD0vYtXGdZK\njCO5xfqr22xvz+CfydDLckz2rbc4CA15Sa7VOtGiVlJ8/NXfJ7z2glw/apA050iK85pAg/PShbDQ\n8kwpAqewKPSkkK7k9SR2ykyQ1sUa92/YeI26zeOqoh29spPtoXcp/OSwDd+SZOH6yU/wyImA65tt\n+q8XLwL42CdYtwNay2Lz+84y53N6xDAsYm+lsSomMgZfgGRB1iOvlqkVgu0MctbrA07EY0Y3hLM/\n3yrT9iWSwnF3X9AMziaUV0vEa5If7Kyvc3h6lYViMpqjER9Z+yrBN/8lLhZ/UirXmJlfJi4yTuXA\nGVPsIhTJBFqyW+kTOpGqRRk9BfGMDvAuxSk9LREqr7Aqm86FUcUuROWZNmW4W+rI+FvrdE7+lJwj\nsbwxSjm2cwFTNJoKEk9crzGeExg4thloj8tTTLExelyts3vYx820MDviqHPtaAzG7F0UfL9barGd\nGObDhJlZ+d2eN6hmg8OvSOq+MVOhfgaiUcDWV0SrD04vs1KPOBaJ5p9746s0X/pdXNwgjCQRay2u\nEJWr05BRa1A4jIG8YLEqL90HtTIE0x7/Hu3dtEu097k4U9RUm713hATYIkiQFzxQvGWj+OEdSFP3\nbf49HHcONY/9CK2qRDb9JCeIUpL1AYunBMvYaMyRd9r4R4XBFo4HWG1I+23yYtkN21B2W2AzGgWc\n0GnNEf3275EuCNV8o7XISj2SuLsA2w4rCfkNw+4NOU/lqRwqi7Sfv0S0UtATFxJWF+DMC/K+q/k3\nv4gLqygdMrMoBZdya17YZkVbGqc9RkeyTd9MXl1RRI+G6VshvDcoE0x9AF7hfC7U79sp4dofYT7K\n4b3CFf35OTr0nQu/p0ccluUsse8yHrVgY4vqqaKvQi9lbrlEu3iIik1RNidPx6RFA6T9gePxsEsc\nKIbzMmmbL1zjzPY2W3PCHC6vVKklmlI2YLMuoN1hO6bzpTb6QTEpdu4YnZduUF+pMLNV9OB5+Sbl\ntuLExS+LDKMKQRAxt3qSWuFghTF8FB56K2ZHaY0rNl54kCTLMaX+TfdDT6gf3smRzsuGaJjSBaeH\nFK2Qpq+wgj8Exv1bC9+WR2g/4a3UYbdP2e3ST+TBgryLXl3BOXECNks5GFsIPIOunDqql7BVTe4D\ntrfEns9/7it0HjnG4YoIv5aU8NWQYalMuyMONvuXe3ByDKdE0OZCm0HT4bpQuSDR1rlFzfLVF8Qx\nAjouM3/8DM3GHOnEUXqN0RpfcPFDpfCqeKFMsVNQGgG6H7TVTppc+CJeN55pows38RXFb6cvp9Hg\nbdGmeNIA/G57KatyCiNZ4q40ItZdopoiVZLU6MVZulGAKXrcZ5llPExRukS/yPoWKwE2H5P7KvzD\nz8gNrZ7k+uoS0bysoFYZwlqF1w7LLH1GSnTDuRi3HFO6KTD0brbLM43jLK2/wjCQ5Mxc+jRmBlwR\nEbUWVqg3i92Tk67ZKJz3hMXqdBZQHusySa4AcsiNx0xeLTiRrM+ZbNawThpyO+cmc4Z30tQ1nxQm\n8sLZHvX6mvbm+aPGfYd7D8ed4YVI44puf9m4RxQPqZZDLgvLj3huARuG6K7Y5V5uMQwZdDVUC6qG\nSSF37D7/Bk8/9ScBiFZ3eGNukbAkS7NWjtjqRAS/v84QMTNu8RDXn8VclwTu/Y0RC1/6Kje7mudi\nKYysPvWTjHYusCRlAmqzyyiMNNi+zQ57xbRjlDGa3DpCZaYvk7HFNk/nme6lst5hNLjb3o6cadAF\nlXwyrM+mWi70cYdz7o62fjKUv5M7vj/+Px33zc49HPeFfw/HfeHfw3Ff+Pdw3Bf+PRz3hX8Px/8L\nmha/p4Qii9cAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VXjtCPxl3I82", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# We'll be using a pre-trained 12-layer Reformer model.\n", + "# First, load the config (which sets all needed hyperparameters).\n", + "!gsutil cp gs://trax-ml/reformer/imgnet64/config.gin ./config.gin\n", + "gin.parse_config_file('./config.gin')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "NhiTshPPbvLY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Now we construct a ReformerLM instance and load the pre-trained weights.\n", + "# The 'predict' mode configures the model to accept single tokens at a time,\n", + "# instead of feeding in a complete image all at once.\n", + "model_infer = trax.models.ReformerLM(mode='predict')\n", + "model_infer.init_from_file(\n", + " 'gs://trax-ml/reformer/imgnet64/model.pkl', weights_only=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zY3hpgnI5Rgn", + "colab_type": "text" + }, + "source": [ + "## Sample from the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PnzRPCzFqIVi", + "colab_type": "text" + }, + "source": [ + "Now we're ready to sample from the pre-trained Reformer model. Unlike during training, sampling processes the images one pixel and channel value at a time. The TPU colab runtime has 8 cores so we can sample 8 images in parallel." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "W9ZetV91PujO", + "colab_type": "code", + "colab": {} + }, + "source": [ + "sampling_decoder = Search(\n", + " trax.models.ReformerLM,\n", + " model_infer.weights,\n", + " temperature=1.0,\n", + " max_decode_len=32*64*3,\n", + " )" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HOLawc5dB7QV", + "colab_type": "text" + }, + "source": [ + "Sampling is an inherently serial process and will take up to 9 minutes to run. A good chunk of that time will be spent on JIT-compiling the code, though, so the code cell below will finish faster when re-run for a second time." 
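The sequence lengths involved follow directly from flattening each 64x64 RGB image in raster order: the prompt-building cell below keeps the top 32 rows of each image, and the `max_decode_len=32*64*3` passed to `Search` above is exactly the number of remaining channel values the model has to generate. A small arithmetic sketch (illustrative, not part of the notebook):

# Each image becomes a raster-order sequence of 64*64 pixels x 3 channel values.
rows, cols, channels = 64, 64, 3
full_len = rows * cols * channels            # 12288 values per image
prompt_len = (rows // 2) * cols * channels   # top 32 rows kept as the prompt: 6144
decode_len = 32 * 64 * 3                     # max_decode_len used above: 6144
assert prompt_len == decode_len == full_len - prompt_len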
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "We9Jj9Rap3cB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 214 + }, + "outputId": "10b6142b-11f1-414d-9b63-353f721a6a82" + }, + "source": [ + "flat_prompt = []\n", + "for i, img in enumerate(DATA[:trax.fastmath.device_count()]):\n", + " img = img.reshape((-1, 64, 3))[:32, :, :]\n", + " flat_prompt.append(img.reshape((-1,)))\n", + "prompt = np.stack(flat_prompt, 0)\n", + "\n", + "print(\"Prompt:\")\n", + "plt.figure(figsize=(10, 10*8))\n", + "for i in range(prompt.shape[0]):\n", + " plt.subplot(1, 8, i+1)\n", + " plt.axis('off')\n", + " plt.imshow(prompt[i].reshape((-1, 64, 3)), aspect='equal')\n", + "plt.show()\n", + "\n", + "seqs, scores = sampling_decoder.decode(targets_prefix=prompt, batch_size=8)\n", + "\n", + "print(\"Sampled completions:\")\n", + "plt.figure(figsize=(10, 10*8))\n", + "for i in range(prompt.shape[0]):\n", + " plt.subplot(1, 8, i+1)\n", + " plt.axis('off')\n", + " plt.imshow(seqs[i, -1].reshape((-1, 64, 3)), aspect='equal')\n", + "\n", + "plt.figure(figsize=(10, 10*8))\n", + "for i in range(prompt.shape[0]):\n", + " plt.subplot(1, 8, i+1)\n", + " plt.axis('off')\n", + " img = jnp.concatenate([prompt[i], seqs[i, -1]], -1)\n", + " plt.imshow(img.reshape((-1, 64, 3)), aspect='equal')" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Prompt:\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAAsCAYAAABhRmIoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8eZRlV3Xm+dvn3OGNMUdkRM6DlKkh\nUxISCBCDEGCEwWAwduPC5bmqGozbhbtNuQu77Ha5DMvluapcVe62jSdsbFw2BmMMhUACoQFSQkNK\nqVTOGZkZkTG+ePMdztn9x30RKcal8Fq9ei1WfGtpKfIN97633z7nfOfb3z6iqmxhC1vYwha2sIUt\nfDvD/P/9AbawhS1sYQtb2MIW/r/GFuHZwha2sIUtbGEL3/bYIjxb2MIWtrCFLWzh2x5bhGcLW9jC\nFrawhS1822OL8GxhC1vYwha2sIVve2wRni1sYQtb2MIWtvBtj+BbPfkHv/AO9Xi8N4hRvBesV3Kj\niFNUweOx3uEwGBGefvokzzzwOJWqRdTjRbCAAwRFEfIcUDC2uC4ogQZk4mhkjrQH+//sY1RnYuYu\nt9izfxtu9gw5XQ4k99NurDCSPQNU8HlCmintXkbaEhaveMLAcO2+mMw16LZKrK3lrJ2PcMsxYoX9\nu0axOWQK937pGV5+616GKiG/9KGH5fkG7s4T5zRAMCJ4MTgRqkbIEJwJcMYxjKVhhNBYBEHE0MUR\nilAFVIQcsGIQASNCLEIPQECKMGEwZCgWg4jHqgcxRCheDF6VFDAIXkFVEVFyBVElFRDvEQUPJN6D\ngFclQvBAnueUMOTeY0VQVYwqHgjUc8/ebc87NgAPzGVqg5A/+/Vf4/Dhmzl89+tI+w7nhMwLKDgP\n6hWv4L3gvKIevC+uYQffHyAQAVM8r8og98BpkVVGipgZo4gIYoq8M0YRAyIQWUGMYkSxBsTIIO4e\nwRfX0Kv3lEEceM7jKlLksQ5uCIDyih3h847PK1/zcm2srHDo4AEaK8sM1Wp0W022TY5w+4tupNts\nUqtU2b1zL6+5624ee+RRTKycnr3IleUVQjUcO/0sf/OxewjDEcJSwr5DY4yMVhFRvDPkuafZalIp\n1VhZbdHr9jES4r0DySmVQ+JSSL1WoRxFbNteI6yME9gqWepBHMZk9JOEytDt7Nl1F8uNZzl79l4C\nGmT9RSqVnUTxfrq+yo1HXk2302Zp7gFWlo8xs+s13HnX23njCyfQIljPCz/7m2/T1cUuzliWly9y\n5tnTRBVL6pW5C02WOvCHH/hPXLN/ih95z/dDUieoZAR5lUo9pi3zlMeUzqqgudLtClEAzimBEWrb\nIEuUsckIu3SISydXEZ/SaLaZmBzmrte+iIPXHGL3rkNIXOIrxz/N7/7WnxD1lNf90Cv57u/6fpaW\nT/Inf/aXHL3/MkFJCMqKZDBhyrTbKaUZJfWe0SnD6iKM7ILOGrg2jE1BJOD6wlpLOX3UPe/Y/MgP\n7tMrjRSLYsXQbHbo93p4p0QlIYosy3mXP7r9Vdxyx6043UGWXKY7VOFif436noMENka9J0vXSPoN\n2mtLfPz+x1hb6/HI4xcgVH7+p9+JsauMDteIKxWsGhBDril/9/cP8ME/vQ8JTDGn4QjjiDe8+S7G\nhsuExlMKDHgHRlBvMXiyLEMUFEHFk+eet976aiIT8lu/90GG+6dp2WEe65R49omzmAguLj7/vAH4\nnff/hCZJk9Ba4iiimCGELPVcWrhML03odLu0kg69tAveERhBtUkt7FKNAypRndhGxVyAFvODxqS5\n0s8cqYIJY0rlCuVomDiuUYpiRALyzJHnOVnuaLd7NNbWmL18kU4/oZ8N5jh1pE5p9RMEC2LBWKJA\nCEyItUJgLNYKobUIivcg3oAWc5IVQ2CF+x85+bzjs7BwVkUVsASBLeYBoFQdx4iQ9BuIhARhBICx\nFh
EBVYwJsdZSzHdXbyny1bf/2n8D9JMOF44/SHN+juHhIXYduYNyffKrrgOg6nFPfhgrLdj1chSD\nX75I84mHWF1p0bI1/Pg0k3sPEg2Pgjii0ghhqY6YEFDCqII1IemTH6F8y9u/YWy+JeFxYlFfLAqo\nYFC8gFHBqUehSHkJEO9R8bg8JwgAURyCeAGrGAWPwarHm2Jxc94TBeCdIbOOMHcMORj5yEdohD16\njZT6TJXepZPMVJR2e5m4cwbN5+iv9jk3e4Wl+S6a7Ge8chsH9t3I6DjMX5pj8VRKOPwoK6uXaFyp\nc/FSgz2Vbay2Eu579DJxqGyfLPHiw9upxBax4bcKxddDlVwFa8AKhN6DsYQUREO9YdUW5CN3DmMs\nRjJiLLkqXYEICpIxWFQjDLl6YmMQ9WQYUgrSGKqAOEIvZCLFxKHQQwm0mERVlQwIEDJVrAoBQugc\noYGeGJx3uAEJNUDbe8oiqFj66hiSggjmKGZAUS1+c7GhIBHqMn7wZ9/LJ/7v3+PMb/9H7n7Xewkj\nQTKHV4M1oCp4D06VwIP6gsRIkVwF11AAxYjgBt+TAQ+x6jdeWJCXgjyKASMFEbIDYmOtYkxxnXWi\ngxQEqBgdxfeFb0x2RAoCWbyieM/6KNgMZrZNUauUCAPDgf17aawuE8UBQWjodruEYYXdu/bz2te+\nmiceO0ovaWF8xNNPPcULX3oHcxcuEYchB6/Zx6lT85SqFovFuZRKtYzLDc73qFYN3ncJwh71YUu3\nreRZhjHQ76UEYYDgqIZC1usRhE2iMCSKI5J+Rp7lBMYSRiGVkRGCkQOk+UXaSydo5ys0W1dwzYyw\nPMajj36YWsWztjJPEG9j14FDjEyObjpvvvPF/5JXvez1fPSe/4d/+6v/irEZeOub3sM9X/gfXFpo\n8iv/6heY2FbmX//s9/PyW76X1HdoN5dZWVkgCVeoVz3Li9BdFeJIKJcgS2FoSrCBkuVCdVjodhNq\nQZPxaUskNZqdFlNTYwRRCZtcIbhyAr/jjST9nNwJsVW6rQWePvEZnjo2S7NZ/OpRCHsOxAgJrbMp\n7bZjeqJMP+2Rpp5e2zAFpA6cQq8BrTa0VhWXbi42KoYoFEIxtHsOYy3eeZxTDBHOQQgcW1rmlgtr\ndCZCTkmbZt5mrdMmfOZJGs0mIzXDQ48dx6qn0/d0Us/liy1cpvSSJqn2uXCmx+//8Qd55UsOs7jU\nZdtMjbtediPX7BmnXotIk5Q4hJmd0xx5wXVsmyozXI6AjMwZRAXvDVme4xHCUoRzHnWeUikiTRxP\nLz6LX/LYzjxjwxHPXF6iPLKf0bGA5Ua+6dwBS5p4JBKseESKzY/3xaYwd0ruldwZvA9JcynmCC1h\n1BGYiDiI8RKBCM47AgyIRTQgMAYjJeJSjXK5TlwqE8cxcRRhjMU5T5Z5kn6KupgsC4jiHr3eGiGO\nVD24AK8O5xy5E1QUsZDmBms81lishcga4qiYkQTBeIfFbsxtTjdXnOl3VpGsRxCVkVKVMKqAGFRd\n8f3EYIxBZHBdLe6rBeVCB/crOI0MXqLfkORcfVwwYrGBJev3cKUIzZPBvFoIHcUqNBADZu5g/qkv\nMD0REtXHycZqLO3NmcuPkTnDiAgu62KlTpbngw1ugLUW9Q5BwRh6k7dR/iZx+JaEJ1AlFzZY5vo3\nNlqwdFXP+rmFRRwMWZ4xEG2KhZdCKQCL4km8wRjFiiXzDmsEjydST96DnX/yd8y6BmVXoyxXGJq7\nwtDwGEPpMdKGYWkppRQEHH14nhcfeh/vvPs2JieGyGyF8xcXuHD5MgcO7GV+YYm9O97A+Wf+By//\nzrdy7tQpPvTnvwf9FNfzNDuGNO2TjVVoJZ6havY8U6eAxaCmIIGhV7AGdQ4VgxhHIAa8p2/sgNR4\nYlUSFDWQekENZKIEKKEKGY4RI6Qe0sHSG4vBKwieQBWHEnpIjMOpYHBYAlItJtQizRSvHotskKmy\nQoAnFiFTUDyOgiAEIoDDqZJq8VkThFAVCySb22gVUEA8Pk148zv/Vz7xB3/M7777+/nn/9d/YWRy\nrCDIWmSOGrAq5AP1xvgB6diQWgbXG8REtFClioFjNl4iAipXiY4xYEWK/xvF2mIgGqMY3EAVKvap\nxQGc+tybbXyGDUqjxfWLnVbxaURls3yHMDSMjw7RbjZpra5yzbW7GanvoVIp022vMT5UY2pqis9/\n/l4uzZ7l5iNHCOIhvudtb2P/wYOcePIZtu/ezuT0Tn7nmf+KS2tkSaFiuFzp9xN6HSXPoiIPZBhj\nhHpdCUNLEASICLnL6fUSdu/dQz9PWFlrUomHMFjwDpelSFyl12+z1rxCaUiZnBwhzIYIzTCNZhMC\nx/btB1lenePcyUeIokle/Yo3cvsdd1Apb75iPjE5CcDdr/hnfOJzf0Z1bIq/+uhvs7IEK5fg4aNf\nYeYa2HH99Rx54S3cfOBGvuPOtzC/dILdd1zHCw6NMhY42r0WYUVwfkBKRVi6INQnIYhgbUlYaZ4j\nTcD3wZZKjI4PEWVL7F66h+0pPH3yfh454YhKCh1otFY5fuox7rvvDHNXILBCXFb27t1LJcp4fP4i\nYeToNxK+70ffyT987I9oxQm9ppDlIB66C4oGggQwdc3mYhObgH6vQ8crViDtZTjnAIP3inolzWF6\nz3Wc3zPOQytPMb84z/U7DvGr/+1DnL8CIxW4sgAzY/Dmu1/C9FSZrzx5gTRLSbKMeq1Gc+UKu3dt\n54ljOU8ce4xaDaIIPvH3jzI9BjfsHWNqxw7GxuqEwyVqlRHCIKKXZsUGAkcUGILQUy6FqBp6/QQj\nFmMDjDGUSxHzrRWaF+cYkmVCKREFyrZaj5tfvZ0vHG1sOndUIXOOIA/wAVgtlHdFcD4A8VhTIrBC\nLx8oJ1Ko4i73OBfiXIxKiA7UDYfgvEUkwpoIE8QEpkwcVqlX6tg4xiJ4VQILgQFrIrwL6KdCHFQJ\nbILmCcZCF4+mildL7hWPwXiwBnJjgOI6mYVMPZEtFHYrxZxtjUG0EBA2A5d06J07hs8SSqMzDO++\nntLQJLgUNWXwKd5niDEIHh2sISBYG6JhQbaKeVI2iI4qz/n76mZx8IuACFFUIjDKMD3i3iwaOLzL\nwSXgE8T3IBqhs7ZKOD7D0MxBAIK0RCwpExXLmgtRn6IuxQQRtFdJxKLqKVlL6JNCtY8qlLX9TePw\nLQlPzvoOWzb2vgJ48QOqN9AAim9dLA6pw0tRwlI/WBnwIJ4ifsUKYkQJbLFCCYb0gufg7/9nzgae\nITXM1OewK08wTE7rxGm02mDEeuKhRwjT1/Hz7/pVdkzu4fy5L/Bzv/hx3v0T72BpLmVidJzayAjq\nPGNTYxzc904unXyGF7/2VXzuk39ERkzWyWglKZ1Oh8uXmhix9LLO5jLIKwOVEW8c1hfBTEULyRlH\ngiVWV6glIqgaAlG8g9RCmYJ8RFLEKhKlj0V9MQIsgOYYAlI
8BiFW6BplGKFnoeYNoq7QG0RRNWSq\nhGoQih1KIkpjQCYsHqNCn6IkaRBy78kHyoUTqCMY9agY1HtiNjm6YJ0BA0reT/muf/HDNAn4m3//\nL/nJP/goSd8Xk/TgvkYHMQLUy4DfFO9fV3kKYUdxqli9en0FzAbZWf9PBmWrAfkRxRgw4gpFR77u\nYw7IT/HA+uB97jfX5z6gz6FGmzytPEv6lKKQ0AovfNGtOE0olQz1epWpsTFGh8pkeZulpQUmJ6dZ\nWmpSHYLFtRYnTp9m/sIc23ZO8rKXvojHH38lTzz5OK3VDplaWs2U+bllmmtd6vUhxBQEq16vEsWW\nUjnCmJDARHgjBD6lakrUaqOsrC7SS0NEPL2sS+b6DMczTG0bpSRtTHuNwK8xMVGhVhonigKaPYjK\nNcbtLi6eOcvkyLUcvu42to2Vn8NYnz8azRUAem6Nmw69ifFto1yav0KQLjMfPMPnH/o4/doz7Jm6\nnsnhYYKgBMD52ePsGq9wzYE7eOALn2BifIhm2sQjRCUIQs/k3kIGXDyvNOYUa4SwBJ1MKfmQKFni\nVv8su0eh5wzX1Jb5jgguTBuSC0qvmXC606G1BqEVAhHmF5RLZ9vcfffr+PI/fJgoiuiser7y5b8n\nyfqUx0ACiJ1lbcnRawn1Mcv+w4oZdpuLTTshywtFMs0dYSz0+x4xhszlhEFISWDkroMslC3TtX2I\nVx768hmWlqEMuD7cdM0kb3nDzVxeWeXhJy9z+tQynTRHA0Op1+PYo49w5xsmeNubb+b++x5n+zRM\nj8fM7Jxi7/7d1IbriLVAQGN1lU7epFItI2IJpFBROg6cyykFnnolZKgOvZ6jn3isqZJnSikMeel1\nN/DA7CJRxXPy1FnS8xfYdV3Inl3jm84dURA/UIc9hXqCIKoIIZEYiCy5F5IsLxQNA4EarAiBDYEY\nlQART2gKwmxtDARgS1gbEcUVKpUKpbiENSEeQXSgMCBYG4CHNMkohQFl6/HpMjYeIsibeNcprAeA\nMSFhVMcaAwYUQ2AAo0VZ3hYleQuE1hCJxYgF8y2X7q+DFUMcVsgTjzaXSRbOE1WHEZ8gYQmftsnX\nVrBxCc0SXNLEtxcLhXdiL/HMTdgwHgS2UGXWoSobis6GMk5BfATI8WT9FuOjq5h+A6ovx+LQs3+M\nzJ2G2WMsm1tZu/77GL/2JrzrIjbE+x4SVxg68ALKTlmbPUOeZqRpl6nsDG5lhVwisuoOOvEEoSi1\nbI5S/yTsvOUbxuFbRq34Cr5YakXxyNWFZlAycT4vFB8AI+RpTmAG/gwVcqdEcTDwhhjUKWqKQRHY\ngrXSdIz9u18gv/EgZv4iw/lxTJJRyleR2rWMbV8h7jSR/lO87VUn8DlUyjHLZ85iNeCWI3vxaZ+b\nbrqFY8+eYjyM2LVzB0FkMAoz+3aRhxHeeobiGqYqjIvi3RhJmpNmjixNNpVAqkroDVlRsSMb1HtF\nhY56Im8R4xGESDyBt/TFgzEYhMgrLYFQPE6FUChYv3pUit0JeIxCb1DKKgd5wfTV0336ONnKCjI9\nAwcP4t2gRIjBU3h4rBaqWnnwm63gCWTglRGw6wv7QOXwWpTnmqqEYojUFGTtn8R3riY+oqS9lB/4\nFz/Ax+pD/OEHfocf+OmfIgjB5W6QT+uqCYVXhwGP0Ks7CVXBIAW50XWFZTBxFMEvatxGB3K2bpCf\novadb3h9njswN5Jdr5anNnYtzzH1bPAinkOE5Op7ni/q9SGeevJxDt9wPb1+j+XVOcbG6gzV69xz\nz2d4/d13USmX8F5ZXV1jrdFkaHSSRrvFth3bictlvHM89OB93H77YTK6PH3iGSZH9uHzPlOTo2zf\nPkUUW1RT+v0uxuSEEoMEDNfHmJyYZmJigjKelfPPsrh8npGDhzBxjU7SLsqexhT+kDCm05onbZ0l\n6V+kVq/hNcJITESXIExotRuMT+/jhiOvZGpmqvBN6SZrNsCH/+I/c/+zf8H9932QyanX8Nhf38Py\nEtx551u5fOE0nSShvdxlz80Hef13vIOHj36MhfYsf/PpD/NDP/Ie/u6vf59zC7B/f5vhcUOnBd0V\nz75rx1hbXuPMcce2qYgX3LaDY4+dBQM37NlGvpTyum3Pcv12sIEwFEHXGO44DH95ytATw003v5Iv\nPnQfXiGOBJ9CyQjHj11idfGjJPmg7J97cmkws2uEfteycqFHc7ZLlgr1aagOZ7TWIMg3p4AFxlCK\nQ/IsQY3QaafkDmqRkKmQZV2u2b+L3KZkPaVSG2Fk2zif+9PPsLIExsLth7fx+te+gNnFBdY6Kc7b\nwnqAILmSmBLnn13g/pGHeMNrDjJcabJn7w7GxkcwQbHtjYISWR4iGEpBk3arg08dRjxhYMnynNx7\nwijCRAFLSYIVpVYTdkzX6XegkaSMlGfoN1a47cUv4+BNh5HD9/HiFx1k1+4xqpV407mTu3wwHh3q\ni9g6zfHYYtwbQ4QhDi1RWPhjLCCDclFgI2wwKGeJIQgsOCUwESqGIAzBRsSliFKpRGgDrISFCqRF\necggGGvwVUNZQ7a9YD8vuPY2hoeGERE67SbHnzlGrAFJ2uXy2ae5cOnLXG6epxTUsSYq5nijlIzF\nDDYNoREqAYTGFD5G2RxZro/vQSpD+PYa2ZXHyVZncTuuRVwHWxqHXgN99mO4zKFZH8lTAgNBXEE1\nw49dg9hgUOrSQgGDDaJztczFxnoCppg/JaCfZ6S+Rrw0i0anwIaQT9BniBWzl/bYfkZ23EBgLUm/\nRRiE+CwhKpcxQZnI56yhuDyhtTDLtnP/BVP1hHu+h3J1CPIu2liE2YfpzS9SeeE3jsPzoIm2WDzU\nYHComoGnpzC0qin8POIFK5C7DBEIpFCIjBVE/UbJoChJGBDFq0E0QysV9rz8dprNJiPBZUZb5yhV\nt2PsbrZJkzR9nJl+g2vzLl/6/ENMb9vD2TMnqderzEzt461vuR6tDPHFh4+ybWon3U6TJ48+zA/+\n+I+xcOEMlxaWaJ09R6dn6STdgqxYSyk0RGFIKQ5xWt1UAhk8ESAY8gGhSMmJgSE19HDEanE4RAyZ\n97jAIt5jxKAihWFYBY+wosqwKVQcBHJ1GDVk4hlWoSfQVYNbXCD/4IdI8hx1jr6D0vV7Cd/+vUW9\nWXMsQqqQSk6IxXhwKA4IVXFGMN7jKRSeSAvZX6VQl0pi6Kpi1JGr4/nbca9C1yUQ1geBkvQS3vz2\nN/H60Yinjj7Ie/7jb7L7wHbSJB3kBaiaQtPRor6rqhtkR9UTitkQVHSwpzIwIDtFOa8wLlPUeKXw\nPhmKEtfX15x1oDIONEy5SnCuDl5Fv/odX32FTSo8Z86eZ3hsnAPXXsNjXznKjYcPMj05Tr/X48A1\n17DW6GNo0u526bba3HnXnRx/5gJHjtzEDTcfpt1pk/dzLv7t3xLamJHhUXbu3FPsBAPHxHiFzBdK\nbBSUqFbqGK2g/Q
BUGK/NsLbQpeRWKFcTaqNVxqerXGo26WlGHhpCE2B6XTrLlzgvI1RrQzQXTqL9\nBp3WCJQquFwxzrFy+SStbo/h8euwtSkoGdIMOs3Ne78+9rGPI4/CUABh5QJhJaRSdXzxwb9lrQlZ\nBlme8aVjH+UrP3sf1x7Zw2/8wTs5PZewd+dOLjYWqATC3HllZBhsKDiFcT0MXCbPTxHGnubaMkEF\nDkzv5kC9xu27n+ZFhwRjCzobVyHJCo/Oj96g/PFRw/33f5YL59uQmw0/kA0hLBnmllaZimJSVfpN\nT70yzdhMyqVjU5y9dJTMCcN7hMqkZ6Q8wsRUicRvTlWOrdDredaz1TtPLQ7JfeFXwcA/+7434H0J\nl7fITEoUl/mNX3w3n7//KY4fn+XIjbtpt/usLvdYaXXp9yFJc3RQYs69cK5jWXvyAlF1nCMvvAEr\nFqeF0dhIQDftEscVJBiiUo2olaC3do6+nSbNAnKXEZiQfjeh3+kX446A1aWMuaDJTYe2U45KdPtL\nJGvwkpftZmTvAq8bqXDq7FFOPLFCp5fyI+9776bi00syHBnOF17FXME7h8ejxhfzgxOcdZiwULvx\nHjHFAi6RwQRm3cWHwxbPBRawSBiAGGwYIkYwYjBBUCgc3qJkBPEQmasT9RJ21CpMTV5LYJWZoTrT\n06MElRqvetGL6Pf6zJ48h9x4G0tnbubUlcs8PPsQp1efpmojxIRExhdz3KCaIKGgpIXvcZM2g6zX\nQOceQa+cQNOUcOIwvttA0gZ+9BpMYCkFOUYTlB4SerABGii+t4DvrSCl6iAy6zo4Az4gz1F5WK9z\nbdSExBjyqMwpP0Z5OWaqto2oNoy5/mWsjj7F8aVPMT4xw/ahOr2F45QroxiJi6YmIGsuQX+NvN/G\npS2Glx7D7/8Z9Oy9SDdERncj9WFUFmlluzjHCt9Y33kehMcPDEjrRiQvimBwMlgbvGAw6KC+l+VZ\nYaswEGnB8RyAEawUnV7eK8aCMR6dhSO/+Ss8W7bY5hVGs2WC+gxiqswkn+bm7hzXl1pE1Yw8zfnH\nxx8hKpXZvncnmfNIZZTTi0t84fOf5PVvfhN5t8e9n7qHV7/xtZw5d5Zf+rn3kavh0A03MDWsRf0P\nxTml19NCJva+WBg3AdFBbIwnY6AiYIouKBQD9FyONwF171FriNRvlFf6XgmM0lchMxAZIfPFjtEP\nSoh9yTc6qzKU2DuW/vBPcb0uGlhCl6CBof/UCfS3f5fof/sJMjGoFSK1pCgl5+hLYV6uqOBUyRgo\nTYOF2g9q1Tk58WByC1CSgWel+09QePS5g0LWhRAl7SX827/9Ij/9lttZecdJfuD//DW+83tfTdLP\nBsTGb0w4Xt2A6FwlPmw8v3719T/9hhmZAfEpLDqFV0fka8nOern1OR6dr7ruIO+/QTlvXeFZn3M2\nmTq43FEervGlL32Zaw/soVyq0G536LQ6jI2M0en2WFpa4Jr9e4hsgPeOiW2jfOGLn+P85dPcetut\ndBo5tXiE2dnL7Jvex/6d1/Kl4w+Re0O3m+ICxUYBuQ8pmTqVaJhyyVApl8lzR95f48C+a0m7c0zW\nh1hZWWH/vp08/MQTRLUyEnoqJQNJjysXHsXaGMlWca5L1OtRG5skCMvktgauTdbvE8Vr+KTF0twa\n5VKVC+cWNxcYoBzXSPpt1rxw7NhJKpUqoQ1ZazQIawbj4PzpBc6eWQAHDz72ZbIMwlB45thFjAjT\n+4S4orRXBd+BZsMQVSpErQgxcHk2p1xrcsu1B9gzPMYL7CNcPw3HTxuOHCp+zwuzymhNEAO7JuDH\nXuT489N9WghBDmkOJiw6DHvdQj3BC1Fg6TRzuk3hlptv5eP3fRQfCNG4MHEttNcgI6WTZTTbmyM8\nq82EgS8Ar0I1DnF5gnphaCjgv37gl1nud1lpLhKFAaJQr40zFo9wcNciQ7Ual+cXSZOMIIpoNBqs\nrLQL9d7nlOpjDNUCwmwJ7cVonhKKpRRHhe8LoVoZpp8WHY82HGIlXyXPLWFlO4EpkeYeY1zR4Wn9\nc0zVBs0M3W7OvfefZu/OYcLIcfedt3Bl+fP84x+u8v7/9iRvvHOIiYk6eWlk07nTyZuopoRI4a3z\nWjRpkOFJQJTceDJSnEnAesTmg5gqmfFkJkfFAA5jDQaLNQKDTSpGCwIFRSkLQARvhdVGTOv4LPHi\nWWKXQXWCZXeUtLVM8oLbyTyJux8AACAASURBVA4dZMdttxFFEZU4IZU1Tn3p7xlrX2afH+KWQy9j\ntnOEv5v9IivZJQiiwjSihauGEEQC0KLEtRm45gqSlfHhTqRSxffbZE/dh9GMYMftXFW2czAgPgWf\nI07Q7gq+eQWpTyPGDjpVZcNKgCrWJ4V4oZ6iG8yjPsPkCVWrDFVqrDU7rLbn2f2y7ybLu5QqNUra\nZZttsTR7nIWhKsNj04iYjVnXBhFOuviki01biBHq6UNcuhRzubudPY0y5RqYWo6WZljzSxw/+g/c\n8tZ3f8M4fEvC44vfGe+LNl7PwDDlc4wYUAcC1rHhU+l3OqQeYik6ujIHQVCUfTwGsYqIhyDAnc84\n9LlPcbweM2ICovo+0tUusamyw53gbb1ZquVl0q6l30xpLFdpNDs8+MX7mZqcpNdo0Ur73H7HHbzz\np9/Dzh0zpMsLvP+X38cTx75Iq51A3mPvzBhrxz7FUC3EIAS2mPgqNaVCUCggsrndaCkvGsXFGCoG\ncuMJJKA7KFEFGIwRSi5HrVDTnMAZEhFScRgxWCdYHJEvSoZOhASLNVB2GW52Frk4S2thkaTX5ekv\nH4PhiL/+Nz9ItSqQ9sgTg0Zlfu/TX+SpRx5n+EW3kmaORAqSpAPVJMHjFZKBh8eipDhQoasQeI8b\nkIyuKqFCSR2JGhLd/E7duYHQuU76pVDDFM9LXnyER1ccT504zw++9NV8+NerVKZu4z3v/w9cd3gn\nSVIMnI2Sm/qNPjG/7hreoCjF0FgvDqx7edbv/XVlKF0nOrJxBRlc5mpn1oaEtDGhrctK8tx6ln7N\nv58nev0enY5F85RWq02r3ebMyZPMTG/jRbfeSm91lb1791Kt1QnEsHBlkWp9Gwf2HSQqlzn6pSfo\npp4Lq4tM7dhOlqa4LGVyKKSVCDkG5zwhIbsnd3H93v20Vhc5d+YSy4tnuf7GQxzYvoPG4mni6ihz\nC02yNKCM4/vvfjtO4DMP/E8uLM1SiiJEHZ1mk4nJfQT9jHf987dz/0Of5czCHFe6XerRHly7z2rz\nGGe6TVoXHyZzhosLFzYXGCAvtQm7kA9J4XVzPVprPXodg3aETtszXDGEAvE49JeFLPPEk4LrQ3NV\nSdaE0XGL9Y5oTNk7Bq3lPgE1ggC+743fxU37d3Hl2Ge5e/rLdHuGl/9CMMiiQdcfAYVGXSyG33Mz\nfOC7PZ8/JnzgIciN4DLFBoKNio1cMdcZ4qrhoU+f5PhXTuIzqO9RahNC1of
[... base64-encoded image/png figure data omitted ...]",
      "text/plain": [
       "[matplotlib figure repr omitted]"
      ]
     },
     "metadata": {
      "tags": []
     }
    },
    {
     "output_type": "stream",
     "text": [
      "Sampled completions:\n"
     ],
     "name": "stdout"
    },
    {
     "output_type": "display_data",
     "data": {
      "image/png": "[... base64-encoded image/png figure data omitted ...]",
      "text/plain": [
       "[matplotlib figure repr omitted]"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
iUDIdcxHogQ2wKKF73z4Ndp1xaNHf8EbWxn/1d87YtLvo/oFv/W8\nQsgabxuEKnCuxnmBdoZhknM6uyZVQ7a23mC2eIn0liwb0tanCJHxn/zaN/lv/7vf4vD4Dr1M8eNP\nn3G8U3I2u8Do27VrgRguKLqEZYmITTK7ZrSbWoYg2p29j99PBwDw3UzbyEZEzDFzTmLsT/iTj/8x\n947/iLxq2d2dkKcpxhrSJEWHOI96uULKBONjWxFnHYIcb+J7US7KFwgBKWQEaBLE/WNU0SOsmwg4\nVIISUO7cRaUZwQVSqej1B9jWUDWBejG71dhsJByv8cTruedCFAzbijhPwxIlwBhB0gbqAMNuD1pV\nmt5bE4y+wnROZ+8FwcW2JsaKyOo4CDesZ1diJO6zQcQIlyDETWRD8LHFxM1hoYsOEB68iAd35y2J\nh8Z5Vuua65ef8639Mc5b2tYhhSeEwPr6y0taX+3S2nqP5mxN1h/TsqRYOabLBaK/x9o6jkoBraNa\nLuh//TvY1YrxckpzvaZMTQxpu/RkZYmxNZoE0Xju+mdI28fqFpFper6l9ttMlqfMekOGSWBJghcV\nue5TZgl6ren1C4Ztg88GBG0w6wsmMmo41NYBfbtkuT1CvYi31IvwWpMfy+mobtDtTVG383LdUoeB\nCHhrYjR3gMQLrHSIoHABslTRekXdtvQSQVFXhDyDJCXBk0iFs5ZMglYCbWI/LSGjRkc4z08evaQs\n0o5+9ogQ6+5BdLEem/ekwLk4czwy1u4TsNZFrZyU0c1EQAZH8CICzBBorEEh8N7RrCpSY8B3zUdl\ntHTfmv4CHv3Jv0ZJuLsz4vR6TbL9HWT/Dpcn3+fBwTbzlWWvL/nR00vOrtbc3xsxyAyvtMYLyaVO\nOVk0ZCl89njG3qhAKcn0zPB87tguBOsAW7uBxWXGwds7+HbNbNbi1rCYOUYEViLhyWnAmoaMAuUa\nvvPOMWYxZSIu8fNAszrHiQLnWrRu0E2L0S1WdwyP0eA016s15SynyDKClGRFn1oVnFe3y+G5f+9r\nPDx6wN2DHT776ccsZufs7Ix5/OinFHsPCblk0B+ilIXgKfKcV+dTIOCd5fLykrfffpv5fE6SJOzt\n7aGUYjqdorXGe8/Ozi7eO05PT7l355j79++zXs+RsuHOnbtczhYopUhC1DEdHh6ytTUhBI+xlslk\nDCrBkjIajVkuL8mLjLxMMdZzeTkjzXLKPGOSDfn82ROslVxPz9jeHrKYVlxd3d5arDIR++w48de1\nY90a3TQt61WDsQaBoNUt1/Oa/hCKMuHtt+7yuK15edZgutgG7UEbj5SBtvVkWULdxpP3eGtMUAJ4\nzvvfeIdiOOYXvvVNtrf3yZXn8z//HipkQB11Ox0bsNkhkzxuqmkSn31tOyZGBEgEwhHTo0PoDlfx\n1HrbMjqvyYrYCLPT3dxE8csYqmi9YFcEBiIm7OZpzOa50hH0FB3jEwgxgTeJombTsWTaQ9MEvvH2\nm4DkR48+AeC//NX77I9KZGIYb6W88azh7OwlB7vbGKHwvmBQ3keKlFpb9idvkhYlrVkjRU7T1iya\ncwbFiOASdnfG/MP/6F2sHTMqMmrpyIoRtYKJvPXoEEKXpC86b27oSiRdIGcXhdSxLdyAXlwHBG7y\nIrtdoyNvrYU8hX/z+/8Nf/sX/xk7O+OooRISmQDe44PEWocPDumj/KBtPalwBCUQxsWmzkmCsCZq\nJFOJKEvEG/eB6FRDBIJMECHyUTt7R4wGnzLs9Th5ds4v/tIvYZsVvdFtG/PGktIGbEftavc+O1bP\nmmgLj76eWPRrKtAyapSa2tDagHexBrjp1sEm70gpWgO95DWIVCIyPJtQHWE9QYnY6y6J94JNTlVX\nfgw3qKe7V93T4r1FSIH2kmZ2xfUP/4zsb73FslpQDgXj/W2sdnj/5czyVwKeUE4Y7k1weUkpSox+\nRZoXzC+ec3D3AfXsKVMUxc4Rl9MF6eoSVfZI21c0o/fZbRc8tTMKuYIkwS/P6DlN/tb7XBhBcfmc\nkZXoYkQVHGuZs93bR+WW7ZCyqBfMxC6ZfoVU0T4c8n28NMyXDYM8w6kxSlyw1pp20VA0K7xUSNEZ\n2kIAIZFhgxIlUgQIEkHU3VgCStzW5kfnAJMk0uGcJ0VhhCBYg5eKHI+XgdZ6ijyjtZYkSUgQaG3R\n3tLL0ljbJWBswKQB73yMw09ltMv7eFIJCHLi4dGEqCEaoFg6fZPcqkVMmdYmkISAlLJLMHWIEHN6\nkKprYCijWl9JVBBIbZC2RlpBKpPYc0xIWnl7iif4wGzVUirJ5bMTJlenvJqukDLwYuX4+oMJ52c1\nD7f7bE9SDvKCy+s1eV9xeVHx4bsjXJ5ileJlvebb7+/y2aNr3jkc8eCNPkVjeOks26M+qmoYbEmq\nVUKWNFyKaz64O4TKILUjAK8uClRWczAZMKs8zUyw0j+jv0h59K/+Odvf+SfY5gqb7iFkjnOCtm2w\nRmONZrFecDDssT/oYYTku197wGfPX0EiyEa3o99DCPzxH/8hH2WKrWGfyXDAYn7NYFDwyaMnfP39\nbyCkJM97nQNGY82UPMvJs4LJzjbOOebzOU3T8O1vf5vPPvsM5xzvv/8+T548QQg6IJRwenpKu15w\ndLTP0dERWmtm0yl5XoD03Lt3l+vra87OTjk4OKCp16wWU1aVYbx7yHw+o8hzrIs7gtWGoiyYzRbs\n7dzj1atL5vOKujbs7N8FlaKSFP1zdLxWEkaFoN2IFW++Im4qXLHzddyWvI+J/EWScHiwx4cf3sPp\nhpdnS0wAKwPSe/JUIaXCuQbtFKvG0U+gaWEw6JEVUW8UgMnWsHO3abwNKCk5mAh0UvL8UYxXFkKg\nVMytCZ2GyBpNIlKsj44sIQVBv65uiU2Nvdt1btu/z7MBWpsfDJ2GSNy4MgUCTeBMCFoX2JMCJaFP\noFSdRV0LchXzeZSAnhJkMqB9/LoXgopAnuf8yY8/BiQHpWJvpBC5JMmGUU8YDGlimdcLtvbuULuE\ndX1Gmgwp8y20WSETSSIzdrYGWO1JBCRFybyqqN2Sf/j33+M3/01CmQe+++GbfPb5K0iGZKPbu7R8\niCF1LgAuhsAmaXjdksnHeRS/3LH9fqPz2ST1xLU2rraiK4FJWh1ozG/zB3+aI+U/Y3tryGAQy1tS\ninjo9HFejXKBNw2+MbgkwXfGmVhLNBGoEkXnvH2IPNiL9kOVEIKLbIhzyGxEWhTsDPucXK1JESxW\nC44OdxHhduxpxy9G3BE2IuH4FR86wXfXSFncsJddynkISBHdwD4IqtWS3r5CiuiENETwKIHaSSbh\n9f6olL8pHYZOTxYssdRLBJk3zGV4/dAHH/PyQkfJxS85TPAkBLSzzC/AaEeOpF5rtg8ThAg3tvb/\nv+srAY/Pe+hByaJfYmbPwUicbhhmJVwvMIMxKu+T6ilH9ZLry2vYuofYOkKkgmubcbAtM
StFtTT0\nkoLhYExVSwbzp/T2jzHrcyRRDDX2M3rXF5h7f4vpxXMmRZ/1/AUhz0iCJ5EB7ypma420FSuRkrg5\nSX+XZHlKkymy5Qr8ZrK6eOe8v5nISoAVArWp1QpP4sOts4R98KigaJ2JOQ4hZvwoF9tXaGvZEQqZ\npMyMpW1bhlnOMhUsuhqqUoqcCHATIQgitnXwIsQgJikIQSCFxwaPJD7MCiBElsqGGDQYOkidAhZJ\nKgKVdyTWgYwCbktE8HGee5SI2hofHMoZKmsp1wakohEOJSRWKZS6fSaG94J+pliuFuwfjPnmO/sM\nJts8fnnBb/7bP+Xz85pvvrXLvKqZLRqKHYVNISNB5Rl1Y3h5taLIUlIJP/t8yr39PmtjmD6b0u/n\nyJ2UJ58+x6uMoh1TrWts1fLW4YTKN1TaYrWi1wuMj3L2R2OG2oDRPHpVM+wrBgcS73JO/uw3CE6j\n8iHrumXn4bdZmX50r9mWrWEf5zxaBEzreDI/RwvFcNAnSW8XAuZtTaYUEmLQX78kS3us5heYNlBV\nNffvv0FVrZhdXyElvPvO1zg5OSXLcuq65uXLl2xtxbLUz372M+7du0fbtjx//pzxeIwQ8OTJE7wP\n3Dk+5s7hLk275gc/eMRka4KQEu8dWhvG4y1yFdjd2QIBRVGQyG0QK7xzNHVNr19gq4oXL16wczBi\nPJ5weTEnz0sW66eUgz4ycxhv2dk74POnp8wXt8snAqibwGIl8HZj6v5iQb9jIEyDbrt2JxbI4OBw\nm7PLOZ988oxhEZ8bgKbxrK1HEMizFOdapIRUKVrnURL+/Ps/IcnjUlgva5787Bn7kx36vYLHjx7H\nhoSDnHcORyxX10irub5s8d2eU6SQpcSQUxnLKqEDQqJjbUOIIIiNNunnuIJ/rXmIJa3XFZvIPIWb\nwEPrAlMhaB3shkCpujVHQC4DrRfUNq6JAwVp19yUIGg6Zk3ebByef/TdNzEu6Q5WLYI+3/yFXYq8\nJB8e0jQtk3KbJC0h6WFMi1SSJB1Rtw3GrCnSPo1cYkxLIjPqVvGtb93lf/uXP0XLBzx5foIWPYaD\nkqQ3uPX4eN8543y4ie3wTtyUAENn+tiIzb2LDJAnxJZDN1Ot2xvobOsibsAyBOaL3+T//N1P+fe/\n+z/wjfe/jreBNJMMctkdVEVcV6slSkqMjS16pAgEZ5GtQWYpMnTRKHvbiDwn1FWcK0HQ2fnI+kOC\nzBiOt9nt9GveaB598pSvffjGzzmJeB1ZcoObBdZwoyn7wjRC3di5wGsPIrBqzE33edH9UikkLgiM\nVx17BsF31QHrI3vWAUxBHE827kW/AVzx4xsG1cdviSaczeuQJCEgRMI1MFtqjkpIEglK8uj7P+Jr\nf+eXvvTtf7VLqwL/4nOaJyfYa01YvaKXZbjRiHYwJMxr8tWKEsmqv8twqGgnAyotqJ+ecmg1sveA\ndPE544Fja+uIfDigNZoLsUNYVWi1Qyszes6QDnZY9e+xOn2EkDlLmdKb7ECzhiyWXCrrKPWMO9u7\n7G7vE3p76HSAUSN2dye0rmNDhO9uRqw9+iCQHoyIziVPIHamChASNgFHf+N5EwLGGRIvyIVAJQrV\nCaekCKRKssRFC6mLjoHWWyb9Iv4sEYnWNhaSlfcoAso58hDPIcIHBB7T0YZJiHofQ2ReVCculAGE\n2li/QXpHY7tOmEoSiH1qEhFTk1V3mtHORBstEuc9rtYIpxHdz3prUdYS9O17aS0WU4JwzJYrrs9m\n/NkPn/LxRz8j9Uu+/uYR33jriDwXPLs0eCd4Y3vIpFdyd6fkeJyRZoGLpaWvACmQSH74dIrTnt1S\nMEqgJ/skeY/jnQLjaoKqKMeBpxc1L55ahMooR5DniiLXhFBzJWvOg+LgjSGWlKoCL3NIFdoLrLME\nAs8++j2WL79Ps/gcW1+RqRSVZtSrlvVyTZqUeATKW67Wt9MbPHv0MaNeyf7ePv3BEJWkIFICJWUx\n4OLikjwvefbsOd5HKvnk5Iy7dx9gjSPNUqSSFB0rIaXk5cuXlGXJ7u4uWmt6vR5JknB8fESaJKgk\nIYTAnTvHLJcrnjx5Qln2uHv3Hp8/fUYIgaura+7cucvu7i4+wO7eHt4HrqdXtG1L07S8+867VKsa\n76BposDe+IplPYPEUzUVP/74x1hneeutN289b1wd+z017i/zO/BaZudNwAV3s3hJCdVqyfnZCbqd\ncbG8pm0sxgNE0YpDxEyezrljbKBMYuPPd792j6998DYAw8mEO/fvcPfeXQ6P7/P2u29RrywhSZCJ\nYDZfMppEu/3Glp4k8ZUmSpInCdZ0WTsbNAJx5daho4K6P7fVDQa6DfH15rT5eNOt/aZzu4+5NMsQ\neBng0gqMj6UsJaFUgX4ChQzMbWDtopBZSEiVRAIfPXp0879+ezcjeEvwPrILIlD0DpAqQ0lJ1b7C\nhiXeLxFuQZEP8N7jzBpnZgz6e6RlSZJsQQhoM6UoJty/MwKvWa9XpPmYfpmivOdqdnuwTKftsp0m\nxLm49gYXx8KHeBCLVvSurNXpljYMBHyhrCK4ccYGIXAheqyFe8Tv/cGv8+zZc/Iio9GOurYMhzmp\nUljvaK/OSZSKgNQ6WDegLd5ZZJIipSA9mMRIEOOiw1d0zEta4IVECMH24RHbo5IyVzRaMt7e5s13\nHrJarW49Nq+3uE1z1ddlVWs3LrWobXLd+BgbS4LWe+ym86izgEQbhyPpuq1HQGm9imxa9ztEJyCO\n4YbEykKkKm/m7WbObkqzIUTw4zf3oGNznbUIAjp4mtYyArJg4r1xgfVc8+Z3vslq9uWC7q8EPJl+\nhU4EbvWCNmlpJmOMUPSLLYIINFqTppZK5WQiYT18E6VbQlsjegV+ssvs9Dm93XdZhRQpA/P5mtyc\ncTdfkuyOsUWBG2+TNa/AJvT9iqAKiiTFt4aqdfQGfWamT1ZKEmFI+rtMRR/tNam5ppg/QWYpl1aR\n9wrCTd5mJMGF39SsJTLE2uWNhCdIEP72gMc7UiFAOFrvaLRGG4sIniRI0hDQITIwibQQAsJ5GmtZ\ne0cKKDyFFFjvMCEQRDxBOSlIhMBIMCEunpbY+DDWowXGRYFpQWdk8Y5URCu6J24CqYxiMhHi6VPF\nEjFKyk47KQjeda/ZI5ynXS7BW5SLGiVlLdLfHvBcmhVpIrmqK/qjgixPOF2uWM0FgwyadsUPny44\n3Ck4OpiwNoaPn81IgV98cw/XeL715oBKBqarlp1Rj1/7zj36vZS7R2Men9esZyt2J0OmVbRFl3mf\ntglIJXjjwZisdMgkofUK1c8ZZ57rc49wKUk54PBgi8EwYV1dc3D3Q6wQrNYxcHBQCKSeU198zsWz\nj5g1KwqV0ViJUZL+IKUs+lBresVXEqV/7bImakGurq7o9wbUjcYYh/cZw/4QIQQ//OEPOTw8Yr2u\naFoLQXF5ec14soNznqPDI2azGcfHx+zsbOO9
Z7FYMJ1OWSwWrNdrdnd3mU6nbJoJHR0fRlHfekWW\n5VxfX/P8+XP6/T7jrS3eePgw6oCMwTjL9fWUdbXm4PAgltPyDOc8TWPI0oKD/UPWq4q8t01vOOK9\nDz5ApYKj40OyXNA0t3SSAIRusQ3AX30mN3u8CNAtzpt1fLFY0StGvPfOAf/hL78TSzMBtPPUOgr5\nlZSxi7oQGBcFviRQ9odsbcfwyLKfc3Cwi1LRDQICbeJi3C8yEHD26hqIDieAxsYDBEpivCNTrzeX\nzYYiNkFaX3xbt0w7jSfnL8TtbzbmjVYldsp87XzpdEbawZWHV16wdvH7BVEDnkliGi7QeFjaQGsg\nTwSViRTWYb+kUAJJA97hQhLLHNJTDg5jxEU+Jk16nbZFQdAcHX4XVI/tyZsYfcZqcUmrK7K0T54f\n0Mv7aHHIew8NptH0x0NsSEFJesXtbenWhhshsvUB5zsxtt+wOVHjIzZjttGzhE3pamOC22hGvjj4\nAReihd06iZTwf/3OP+XsbIrWhsaD1hbvHUmSsf6LP8HJTnDp4s4dulBNCch+huz3EPu70C6I0d0S\nZBobNmcZIQR0W2FFj1nleDWdMZ0uEViEul1OkejcfBsrfvw7HiqCDzgLQQe8DTcg0bvIvAgRDwu1\nsQREFzsQ4yFwG3AZEEhMNwd9iG9JSl5bF70gOHEDNm/m7V9idjZMZrixqEN8HQSHCIFcSer1kh6R\nh2tah8pSVKIQSvDq0cmXjsNXAh7bHzCoZsjJDrLM2PNrXJLTXM0xTUXrFZUfkVQXSL9g4M7R/TFS\nKQrzAy5PnnFYtqSzM+5kmsG4j0wzDvfukfZ2kNkA5mvSqoZigiwz5skIFzzBGco+BD3l0hTsDgKN\nGOFlhuj1WYcU7TJEOYJ8zEh5RldXmLNLhIoqqC/E7OCUwKkNsOlq6BtqE9GZy//mlwoBEzwy+JtA\nwFQKrAgY72i9J8OTSk8uEgrho6Auwi6cC2gfotCto6ad0zgCZQjUwZF5GV1ezsXFy3mMd6Q3AWIO\nI6KYy/n4INsQOwGnIS6AInhcAOM9XgRciIK6ICRJl02kiAhetga9XFI4B9ogvac1FmFu77bZ2uoz\n2BozHI0oxwV7O2O++/4DnlxcYVygrT39NIlx7SLw6fka4wLf+3zKTy+WrFXGJ4/nnJ03fOuNfU6X\nK2QSeO/uLmmesTfM2BaOV1cLVBop1YuziuGwZHs7xYYW4VOUEJQjhQ2KpetxfL/HpKxIK4NAUqox\nGZIffv//Qa8arNN4o2l0XDCraonwNUlW8PnLUzIMpZJcXS0YpYHhKIvHoFuNTQQow0HsdC6FQIgE\nJWM68t07d+j3e6RpRn84ZP/wEGMDz56d4JxnvV5jjCHLMtI05fT0FePxmCzLODk5YW9vj6urK169\neoVSCU+ePOHq6prZdM719ZTt7R2EFCRJitaaEALLxYJXr14xn885efESgaAsS7I8o6oqjI2hlufn\n57z77gfM5yuW8zWrVcP5VcXu/h2sN1hfMd7qs1hekme3L4WCoOhBkf/153HTJsnZqEVTgmjxJgIa\nKSXrynJ+folz8SQvBdQaqnUbGZAQN8E0lViAABfnF3z608hm/PDPfswf/b/f4+Mf/5hPf/oT/vzP\nfkRtIVjPso6Na+vVDb8ORNNA/FXx35zgJroqtpbo9Bqd1iZsdtbbDs+m7Xp4vUH4zSax0T188XNE\n1zE8RLeWDbwKgrmPTjchIuhJRMwMGyrBKBFk6nXgG8C9UUqOIRUpQUR2QvsUXbfd4VLgXEWiBhAU\nQeQ0RrNaPUEJyenFY0CQl9uMBnt4kbIzPiBLeyzWFb/yq2/x7OVnXJ0+Y5RahtKCuX1bEtEx+f5m\nHOLnzkWg6B03QDmE2K9Nbu5Jp9mJPx+63/F6PP0GJBHXVeclzvw5//O/+Mc8fX5KpaNLdzwo8PNL\n9OkpiBTXGrwx0bHlPd7FTVumCrHVjwdOY0FG/YlvKpx1YA0hJAwnE0bDHkc7Q4a9HuNBgXUB095u\nfMINmiDqlIMHfqZZAAAgAElEQVS4eZ4iewLOCrwVOBM/9y72sgvEJa7RFilidIq1vmPDo7Vqk33U\n+AgUA5uDS/hL8xU27rjXZarwBcC5GXPvYmnSuw6QOfDWYkPgsjEs5xW7dLqsEFiuWozRWOOYXn85\n+/WVR9PEe4zPEa1kK9XM7IT+AFaLFf9fe2cOY9mRpecvIu767tsy82Vm7WSR7G72Oj0aoTE9gkae\nLNkCZAiQMDJly5YlT7YMWXLkyhEgDATMoCFDmhn1MtMryWazKllVub3Mt7+7xCYj7n1ZZJOcTnpq\n5AESWZXLy3fjxo3445z//0+W9ujdn6BVRloMSPMcnVfE2RC1+Bir/pBjfcH+cMIsH6J9wqpeMTjc\nY2YLfLQim1+xQTLwFT5OcCIi0xafKMT2OZV+kyQXCHNJWR+hBhnm5Bl1b5+0OSPKPVodIOKMjavQ\n5QIhQmpQtMSsgOI9kRXomyJXuNGdQ6S/PYfHOE9sg5cPOCIZTnxKBBKX8FBah3A2HLBbByzhHMoZ\nrI8DilXB/8JYSyYlwjtCwlAS4dDeE/lQioukak8sDiOCwMwS5JOO0HYC4ZEu9NuRsltXg6LDuKBq\nkSiMCP23HJIah7Qaaw1isaEebHEqIbKB7xTsMm8XRwdH/OY3J0SpYDpbMpttGPVS8lzx/HLNHz4Y\nUxjN/b2U80XJ6WzLvKwZZBF1aYkiwfFeAd7ys+dTjvZ7/PmPT/mz779BXRuuNg2DImF7uWIw7DMs\nQNiEsoF+njM+yrm+XlMZSyE8qIr1NsaZCF/U9I4kH/38kiLqE6uYPNU03pNJIFWkiefjizWpinES\ndOlQUYxwgahXa4tznvmiRg7HtxscL+n3C4SXaNOEViHekecFqdtydnbG/v6Ys7Mz6rpmsVjibPAc\nquuGKI53Ke3Oh0dKSVVVHB0dcXV1xdOnT/nVr37FYDBkOBxSbreMhj0eP37Mh7/5Ddvtltl8jooz\nmrrm+PgQYwxXV1N6RUGepVgL66rBOUuaFvTygrSX8+GHH4GPePLmO5y9uiAfHqKyAYgSFW05+fhv\nOTocslndTr3WxWoeTm5d7Mik7TSUsjVw84AFqyBKIlRVEcWSvBeRZYpmHUpYAEophLVEKpBInbE4\nGQDR4dERaZYA/5c//v4fkfR7fPVrX6NXDBDW8MO/+F9467mcLVuZ+SfZRbJda6z1of+iaxMHHRH2\ntY0F1/HzOmLo7x6esEG717gWu9d9/Yd8i42g/SMdQJOstaeSMJKhLU+xcyK+aayZScEoDgeotYEn\nkz6xCHwnSQkuxdsty3XNPVNjRUUapxi7wgqJb5bEcRZUjmKGiipEPMGYBTio6iW4hihOiKKC737r\nmH/7g//KX7aXEEcJeVHwH//Tf77dAH2qZOOdD/xI2WZX2i7qOwJ5u/kHvkoYMNEByTZjJgShy31b\nXwm
4R+AaT6RA+J/zg7/+DxD/e/aHbzAyjuriFUhFow1GV0jnAvjUGunasY4jeOcN/HwGaRbeum5C\n2wknsLoiGubko32yNMOYmvGwoDGaB+MRLz5e3W5odvNCtNLw8EWPbDk07f7FzXUigiGhpfW9cqGF\nkRXB+LDRpi0XOpwL1ZPG3GRtnISuR5F3hPYSrS9S4LfRytJ96xXQgsuOQL1TRIb7IVpK0ONezkc4\nFMF2QdcWOYxoSk1vLIm/IOP+hYCnt7mgzB9QvXpGEjUkB2+x1bB/fJ9q/orh4QGrRpF6iynn5PGA\nl1fn7NUrXNPA0VM2xuOufsEiOmYgPGex5qj5CN/fw0rJIE+5rqb0kERx4FMoVyPSx8R7PcxyRS+O\nMXbJ9tlzhpMjit4B9eIZZZSQrF+gYk+9jrhal6RnJ8j9NNw26RFOYIVESIcQEmGDjZGTYaEKGTPR\n9l353cM5h5Wmtdh2NEaQqojSWJwNJawMSSNEILMhcFKQpBmT0ZB6u8XJkC73KnBrgqOkIzCKAkx2\ngJMSrKHxoU+WEMFUy7YlLIMLaquW1ehdIC93c8kKggpLhKnsRSv0cwInDEooMBbblJSbNVxP6fWH\nmKAtCE3kbhl+vmG51TzOhrz5tXu899ErTl7NSLxgr8gY9hJELUilIEcwHuQMUsV8U3FdNmRxzCBW\nnC5LDoY9/vidY64WS/7bj0/ZG0cMJilnqy3f+eZX6EnN2XJK1EvJnGC2XrNYzYAIZz2lSoijGCEa\nZCRwNubsrKR/HNOsNqyuFXhDESVczrYYC3vDgsN+TpZHlKVmvt1wOBiRRBpHjPKSJI1ZG8k7e7cj\nWB72Mz786CMev/lVlssFcZqyXC5pnGdvb8TZ+SlFkZEkCbPZjPv372GNYD6fs1yuONwfB8fssiSO\nY4bDIZvNBmMMe3t7DAZ9Xrx40RJbLU1dMxgc4b1nvV7xne98m0obLi6vmM8XHB4eEUUxh5MJ+3tj\nTk9PKXoFF9NZ2LSlZLPdgFdEaY80HZGmBVUJ+5NHfPA3f81+P6ZJNoyGKVlqKTdrvoS4D4Ej60uc\nFtStff6OadC+njbBzbfW7BbPzXzJRsOLV3AxC9kda9mZFfbSmMVmG4zhXGiP0E8lSSI5e/kKorAU\nPn/+krRXcDAYEcUxJ8+et414PV//2kOmrzynH84/ATJctyCLAKCi1siu61jeNRAVkiBPt6/tzbcI\n59yOjN3F6+8jgKyb/3Qwp/WzDXxFwt5yiedSQBLBN9KWACBAekiVp4gEeSxg4bjeWkqjcInBGpCq\ngUbz8sWSR9trNME7ylHTmC3Hk29R1hLjIvrFkNVqynp+ikwKirggiwpqVyHFMUl8xf6jJ7z7Rwf8\n8IdXIATGala39JkJF9Z6uxDWOteVSVrXZdEq5KQMjTSVCuoq15Fsvce27Q92GR0PkQw/61SwS9BN\n+Du1FqHEdfnn/Pf/8ef8xQ++xz/7k3/H93/5l+joXjCdrSym0SRxDJstqj8gynKKf/o9GBaB0JwV\neCK88ki9DW86ycFDXVeM7t8n/iDD+hohI6qyYr2+HeDpUo67DuTt9dEeHLR1mErio/BN5UTgG/uw\nj2hr2XiNEpKrdYWpPVaLXSbIWYnbCGrrsHXoV2YjiQqE2TCm5kYJJwjPlNOBd+bMDbBxbWmySwo5\n2u8nAZDW1lLXdRD8GIsR4DSIWKErzfjg81vafOHRPRscYOorxuMerijYOEOeZERGMMj32DQ19fKU\nsi7ZxGNeXF6iBBwPYlSWwmZJNMjY7n+N/t6EOK4Ybl7hlCTeLIL5XjMlywpsNMaVK7xe0rdL8s05\n5WlNGqfkyRjnItK9Q8q6YbGaoZ/8Ab2DN0gO3qR4+A9ZOU/y7ARl2g1dgHCKTmLoRTBJEiJMVrhp\nO4F3re/A7x7WWxQShcO70KOqthbrHco6GgSRCEaAuQBhPU4bBIZG12xaYGKcCwThcAYIKNlbEgJ/\nRwqPsSGVGON3lviyPa147wLfB79zZA5lqu6E191kt1v0jDVBlhmwO03rQSq1xtcNotLUdYU2DVrX\neHP7k/r5csv+0ZDptuRnP/mQPo48jXk8Loil5Ocv59TOURrNlbPkWcLDyYC3jwc02tLgWVUNX3uw\nT5Ipfvnskrp2PBjF5ElC5Wr0KCZLPWeLFddTzXpVYmwJtiFPU7x1DPsxHk1VO4SPcdZRl4ZYSFIk\ncZTQm6TEqeR6WTHa67E3HhApKBuNd5ZGa9I4Zb4quaoiNqahJx2rTYUSnnJ9O4Ll+cUV+0f3mF5d\nIpVivliwLUuu5teslzMePnrI4yeP2Ww2fO2rXyHv9Vitl/R6GdvthnK7YbVacXx8zGq14vr6mtFo\nxIMHD8jznKqq8Z5QEmvdkwN/Nnjs/OhHPwYheeedt0mShHJbBpn7conWmgcPHlBWVQA7IoB0FUU0\nRmOsZ7GseX5yznyx5b33PsBVS/qppZ9H7I/3cUYSqxz3JTKDRIKmhKaBT8OCHam0baHyejbDWMNw\nkPDGkwO+8nRCEon2+Qi/U9U6lDcIi7sk8Ia194wP9nn06AEADx8f8/Sdxzx68oivfvWrfP1bXwcB\nWlvqxqI+wx9GyGBwsUvZdGen7kH07R9rU/ti93DeLjpiaPfSvq23vP5599GWGHYn5LaEILrsRfte\nGi04bRu1ClrQIyCWUCjYL+BHJzMqF8bKE9YfJeCNp5PQMy0pEHiSKEZFQzbbBVV1FqwJ6jW9Yh/t\nFXl2gIzH9Hr76EaH8ojesKkF//pf/Gm4/RGoqCOE3y5cR1Jur60rR4W2ES1nxbcTprsF3e7fKog6\n8ncYs278ROt7JogVxInosALCw6aWVCVsFn/Nrz/8AddXK2rjgjrJmKCAM5ZExERRHNbmcR9fN8g4\nbRuLa/AGohRMg3cWJyRxOqReLRjvH1CkSXAZl57jR/duOTqvI2Cxu2xoidk+0Cxcy8nxLkjUnQ1z\npzaOsi3dW2txjaU9n+NaLo61Htv+23Vy847L5ggcHk/bu+zmo5vXHdHZOdG+Tlue7F7LhV1MO89g\nb0xFsF/p5r8SHhVLyurzOadfOK1kVfNANcQ+QcqcB0WEd7ClptQ15caS1zVxr2C9WsCwx0A6PuaI\n2bZi7VIq1Sf3Jb5aY4tH1HtfIUkTmsGY5XSOTA/wzRa/eUkvUuSiB/sPaUZDjvg1pGMsMdZbyAYM\nc0UalejnH7JuUiqZMRU56tFj4hcvUIlAydBYE0G7ELVgBx/cN9s7Ebx5ZJue/u2F7AunjxNUxgRk\n7CzGORwW7wwuiii8ZWks2jqstzjpkVHwY+gVPXzLOA/ydkHSnkQkDuk81jki4Ul9yMYkxlGKUPpQ\nwmK9JcUSdyuD8GgczodSS+U8ot0UhGPn6iy9bxuR+iDbdJ7Eh5Odcx5XN0G91ZjwYVp6/S2jX6R8\n/fiIP/3OY7733cc82O8xGWVstWd/lFKMM96/WPBXH1wym254
/+SCHz9fYKzk7YMel+cznjyY8ORg\nwCSJEJGisZ4oVVgL9yZjnuwPWJWOJRtc3uDWjuuZpawiXr5oyNOENEkQTmG0pXGeyjRUtUVIwXxR\nUaQFOZLZwpKnMbX2JCrYAGjrWS8bhoOMnnMcHQ4ZxZrYKl4t1tRVw6QYcXZ1u9Nov+iFurh3zJYL\naqMRkUJGEfPlHO8d77//ASpSzOZz3n/vPcpyy+HhIb084969ezx58oTVasVkMmEymSCEYDabYW0w\nEVwul/R6PZbLJefn53jvGY2G5FnG/fv32Gy3vHz5ktFwiJQCbTTb7ZbNdst6vW4dmUuUksRRRN1o\nmkZzcvIx8/kcrTUff/yC4WDI8fGIk+e/Zr1acX52RZr2SNM+St1Org+A86iI0PH6U9FxDkS7YXVn\nFGchTSSHBwPuHY3oJVkoG3kwVjDMIUpT4kiFBobhxIDWnohgFpckgQS6WFRs1iXWGrQO2THnoKoM\nVRUyhO27uHlj3oUyhQoL+o6L3BkH4UP5vP36bjO/Nejx7Sm4XcPaDWy3pH3WR0uU7T524Kgt2Xjn\nuTTwIgjudu14IGwgvVix0A2zVYM1Fd5YgoQCvvH1+2yrC5J8xLq6wtgmlPbchkH/MXGckiQD9gb3\nSOIBziqECmtULx1T19dkSZ+yrPiT778d7mVLzOaW6zHctIdwzt2AunaDdU60/ned1w67Pk3Oth2+\nO7VQe/2+LUcGQNKOswygLEk8KiYoSAVsS0ndwIuf/hfWySQcspt6VyITTY1QMgDK+0PEYAhZD5Ks\nBTseGYd2Lw4LaR+BQ6QSW9comTKfT0myHuvVmji6nVACbsajUwd2VDLvOxVbB3Jk4PPYwOWRUtAY\nS6WDglU72wJLj+kyPG2fMk/gjHkbvI66FI1v+2t1nBzXtpu4Aedt/yz8LhtkO+J5m3ELwMfSU5LG\ne06AygafNWODItpq+4WtH78YR6cp2+yArashyWhMhHcVtvbM1ltiVSLvP2a9bcgzQTp9xUrHTBJN\nlMaMfv4/mb44QZZzuHqGbirs1Xvo/IAMh7QzembDSMYUaY+Lyduo4/uopkblD5APv0MkNY0y5MrT\nMyt0sYdXA6KsoTn/CV4vUKc/Q5YLxE//DhlFrTnZjbNmR0gLh7MAcQSiBUVtku+W+XfTUs3jzu/B\nW4wFhSR2Gm09qXdYZzBe4DWtE6dHZAm4Vj3mw+ZaiVA/td4FPx8haQjKiQiPVhB5R0zosSSFxIrg\nYGRbcpdsT3fmNRWady74HHlPQgA+ljARjfekIii2jLNBrm41Zl2y2WwRukLbBqFvy3CC1abhVy9f\n4pIe/YMR/fuHXGy3vPdqSrWu2S5K3j3a47tvHHK432OYJ0hXc7IqGQwK/tE3H3N+teTkes3RaIi3\njtJY1pVlqjWnesOmJ9HJkr1Bwl6/z+i4TzGISWTKYKRIC0WjG+IkwhqJ14EnNRpmVBbyPGZdr5hu\ntvT7Cuk9faXQ2uC8Ylyk5IOUdQX7x4eMe5KPztf0Rwn9oo+1FR++OiNYht5mbEqkVIzHe8RJwmg8\nZnI4Ybw35pvf/g6Xl5f0ej3KssQ5x3A4ZH9/n5OPT3j8+DEgOD8/YzAYYFtCuzGGPM+ZTqecnp7y\n9OlTkiRhb2+PouhhTFBdxXHgjvWynEcPH7Fer4njmCSOqaqK87NzmqYhSRIePX7M9OqK9WpFvyiw\n1qK1ppdljEcjrNFMpxccjMcMih5NbTg6vI8nRhBTlbcHyqiwMRnz23Oua8XgCMT7DgA5H1pSXE7X\nPHt2wfnZjG2lcQQJdpYEg7KyttSdR5f1LTcDZtNrPvrwGQCzq2tmV0ENc352ysmzk1CecjDqZ+Rp\nd6/9628MpYIIQMpuEb/JMkDnxeNveAuvnX5/5+gyF7+laLnFS+3AUJct8xjtmRp41YjdVQkgEdBT\njqMEfvqqRGuNcR4nIqyF5mTBd0cDrK4ZZhMGvQmJKhAMUQp0U3J6/WvOph8yGh4hlWCzuqaslwjh\nSKIcIRUqGTE+GPHwUbHL3Hn/Oe//CyIcXEPWoCPR0krPOyAVOCcCZ9qms+6mJGm9D9fXZsREV3rx\nfocdRTg/E8eCOAUZudb5RLBYQ09+hTgbgANb1eE2W4NwDiEF0lnibzwCGSOcw2/XgTIQp2A0ZnWN\nUK1CzWqkSFBpjzxLONgfc3H2ijhWlOvPb5/wufe9A3rc4GUIGSzrOkDSgc72a11pyXq07UYZjLbB\n4sEEno/TnbIrjLUNhYuwt7Z2Cfhwdg6E5OCo/gmSvQ8jbl13v3gt6wPtxo7FE7eCG63dbr1YXq1b\nQcDnPw1fCHhWS1g0EhdnZEkoFSW6wixfcv/eAWNpiZWh3hg2l2eIZsmmWXJux4wPjjD/5J+jzBRl\nNc34IbGpSNMj4nqG1hVFcZ9KeLyPsVIzuvgVs+WKRiq82bLxfayPSZsNMs7QIsVePqeqNenem0wG\nI+JygRk9ZiAjVtOr19QFqn0I/K5BqBXBk8OJm6fJO4GVis84UH7x/GlXHd0aFjggtgbhLHUTFFHa\nOhKhsMYipUM6ENogjaYRAWR4OkJzm1Vomfzah0yPFC50UW9JX85bEu8xzoXNrpXtKedDH67u6fU+\nSOIJpxcA4TVVy/dRbauKxvtgh46jdjpI9psG2YRMT2YczZfI8Lx5b0JiPKJp0JXjarbhw5dz3nx0\nSB1JvvnWfYZpRG08S+M4ngzJs5RUJMzKip++mmFc6D2jpCXPUpSSTLcV+4c5w9EA7zXWGeIoZIDW\nqxJvPP39iMlen/OLNVmRYrUnU5ZBP8ZLh7GOSGk2VYNTkCtBHKWIKEZFggcH/eCuXGrWlcYIzS9+\n84ym8oz7Oe+9OAdvGOURvVQyu17eamzW2y11XWOtZW9vj6Zpdv+vq4r1es3JyQlJHKOiiOPj41Cz\nThKqquSnP/07jDFIKanrmul0ilKK6XTK/v7+DgQVRUEcx4xGI46ODsmyjOvrK1arFSfPTzg/vyDL\nMpRSXFxMkVLy7rvv4r2nqkp+8fOfs91sQQimF1M2my0PHjwKi8tyhXOOKIqx2pLlfZrGsy0dL04u\nmF7O6eW9W88bZDhRquh1WnCI7uTmndtt9kC7SDqcs6w2Fca6mzkfBZl1FqvAw9CByeLaRdUamBxO\nODgM/dAmh3s8fuMBx/cOefj4CV//5rth87Jh3vhPLZmCNjPbWvIjQumjbT20Mwa8kea2QEN+qQTP\n7nc+TVjuNo1dZgM+tZm8dsJ/7fUgvCer4bTxvGhEew3hW5GAgww+vCrRLqY2ZrcTzZZLDocZDyWk\n+UOm84+xdsOmWaCtQKohaZQjRUJZrUBAv39EGhfk6aQ1lJMoGUN8zL/5V8EwznF7QvfNmPiWLP7J\n0t3OVbnNpGtL27E+ZBA6ECm5ATqdfH83lu1nKUO/tzj2xIlARR4Rhfl0sliG/k9Vg9QGYTSqapC2\nJUQ7T/b4Md668CJ
R6yhtGrxzRHGOSHPCg6CI8xwlFcvliqZ2DEcTVJKQF7d7ttrOVLtMoG8vZqcs\ndK35X5tpcdYHMNNmapwPZrpBrm9xrWLYm/ZZsiJYAXCTpbMu9DQLcvQwx5xtQU8LfnwLfujKYC70\n49oBL9uSmYN/IwIYxBHLzTY4O/tAmm6MxWiDaTSm+fzWEl8IeJrmCmE1cdqDeJ/YWXySocbHlJdT\nqmiAc5ZYXVGMxkRPvknBhnp2wXI2Q16+JE4G9Mf3mMQW4dYkqUQm++hSU8cFfnCEzCUqHqPimHj5\niroyRFlOZM8ReokcHTJfNyh9TdI/RtqS+fUaV0yoZUE2e8a6rsCGUkVYeIIyaZfCDMWt1pMnZHYC\nx0UQAf6WfAPtglLLOYe2FuUctXdoG8wMpQ5ZFWM0CZ5GOGIZZOJFJIlVoElr64KiopWaRwKylmwa\ni+B0WRH4RlqEJqNWQGeVWDuL96GRqceTEtpFGEJX+NpbhA+GYY0TCBdKKdpZmlD03p1gIkLbAGMM\n2mhc1WCdubXsGuDjkxfEzvLyfMHZ1YJYRDw9ntAYzziPqRrNyXTD//7glMTAtx8ccjRMOB4K+nHC\nu/f3eDTO2QrPj0+v8QJW64ZvPJwQ5Z5NWVFtt1Sl5Wph2G7KQLKznmGSYMWWg72M9cIgnKQ/zEky\nSSQiitQirSCyBfE2QwlBGsHeIKPoRaysIU9ihq3k3FaGWEbMVhVFntCLUiovqFqFwDq+HYdntLfH\ner1hdn3Nq5cv+Qd/+N0d4Lm4uODdd9/l6dOnrNZr+kVB0zRorTm+d4+6rjk8OmJ/f5/tdstyueTR\no0esViu+8Y2vE0URSZJQVRVXV1dcXV3xy1/8gtVqxXq9DgDFWg4ODnDOkmUZy+WCwWDAeDRmPl/s\nMkPj8R5HR4cMBwO01hRFn9Vy1fbsquj3eyyXoadWudWoqEev2OPp02/gnGI2vx0QBPB16Omjm89I\ngbT/tdZ++ogKhFLY4aRgfy8hSlqDOR88U9I0IlKKPFOBryYgVtDLJFIJ+v2ifW2D0cE93VqLbjTW\nC2wD68U2kDBff7+wU4tDeJ7dblMJGQYhAAVCifCAd9ya244Nr/1O+49PAIMOBP09FaHPyp6EsgJM\ntees4/S0wCeL4MXlgum6Ai9p6uBSGA9zOLvmXm9FT09Jkj693pC90RtIqbi4/Dv62QH90RFlvUZ4\nzeXVS2SUMl+dI71DmyqASa350z/59u46vkyGx3c3ozWXdG0WG8ROsm9tS0xuS1y7zAa729LWA0J2\nUMiwae+AAiHDIxVEbZYnigVKCqIEzuwH1NUGtlXgbDa6Fe6HDLzcK5C9HkgZJOlKgnN4GYxBfRzt\nuHN4UGmBNQ1Zr49KU1abFeWm5oOPTm83NogbAChaANjyZ/CB7tDokPXyrYzfdmDQB0NW01ocWOva\nA0BLLjYiGBcaF36/BSq2bXTQtazYlcss0AKaTxoPBrC1K5MZ2r/XyuRNuMnzKuxRMVC2hpll4/FK\nEqUJsfr8vfwLd3mVpozHB8RCBMXO3oD1/ALlIyJvkLZCN4rh+B7IHKoVvcnbDPoJPkkxQjC0axbW\nEQ3fwMoxa7nHxXxOfviQ9bpBWE3lYpSoaRqJGgzwMiIzG66qGJMOOb+cMUxLbDHBqZh4f0LPXXCy\nsLi8h+ntIV98RNQuRh6QrlNggRCSzne560nlnQw4vi1peWFvNYGwEDmLcIZYeIzxRM5DWxpqhCDC\ntUDDIozBGxcclq3b+TIIJW68GVzYsGvnSRCB/4MPUjof5OkJoJ3FWIPz3fdcK8Vn57ysXCAldxbm\ngrb1BUEaH3lBZj3O6dAjqc0MeatxRiMaA42lKSv8FyDmz4u3H004nAxZrbdcTa8RZYl2no3WLLaW\neakRieKNeyPiLGK2XeC9ZL7VPDjqcbFYsawN0jRsF5rpuuT+pCBKPWnhOdjzCGXYbkuqckmSe5IE\n9lPJq+mSQb6PlCll3bDVTXCS1R7dwKvLksnxiPEoQg08w35EZRyrTYnxsFpWZP2UTdmQFxl5HFP0\nEi4WW3oKJqMeWRRxMbPsTybs9W8nS5+t1hjnODg4oCgKfvKTvyWJE5pGc3x8zOXFBZv1mjeePOHk\n5ITT01MePLjPfD4LHAljMMYipWS73TKdTrl//z4/+tGPWa1WHBwc7Jx++/0+WZaFBSKOybKM7bZE\nKkW/P2A+n7O/f4CUiuVqRa+XM5td45yjqkqGwwHb7YZiMMAYy8XlFBVF5HmPPE+RUlD0+xTDISpJ\nmBw+5tXpFU/efIdHj5/eet6IlusiPqPEfINvbvgVYRMSiChms9ah2asLJeGwiYXTYiRDixbXgnsc\nRFJSNo7ZdMbpyzMAnj9/xdnpJR8/P+HlyUf88me/wrbPTlB7/vaSKQinTEQg88o2FSOEuFEPuJBR\n8DbsOB2341Zj033uwJL/5MeuHPXpofOf/NzxNj4dvt3AXtWeuQmbviT4i0368PEcrGvaEoOh1vCT\nH7yC5yc8UhXWbJhe/pI0KhBScbD/JnEyYLM1KCmYraYIpanWJ0SRwDqDVD2sL2mc4lvf+irDvfT2\nSLAN9xhZ1UwAAAUjSURBVPrmGW5HS9K+KTHuwJAL/K4AgF4rdXU/R8c9F+29bJtluq60JZAqNGGO\nU1CxJ4qhVPBqeYXTJqyjpl3bfVh3kzcPYdjHV+XNjRDBjiQIbDpAbEHFCDz9yX3wJUWeYustkfL0\ne7fr3xeux7ccp84rpy03dSW9djyMuyEPGxtSBdb5nZmjdxZtQwVDtyDJttkgZz3WSLQWWCODt5Dx\nbTanuz8tX8qIXcPWbtz9a2A1lB5boLQrkbUtyRoT7EFM+JvbxqLiCNM4+oPPz359MeBxjrWOkEmK\n2syppxt8ccB2dk0zOqCO++jVlNx71nVJ1TK9dXGMJ2IbFWwl9MyWbXnN1q/J9BV721Ocrcn7MbJa\ncTAaEg/v0RsPyOOMYeEox28y0guEsBR+Rd1/BKIhtQ3GpER7b3BfnJG3KLJ6dhomoWBn8hW7YKrn\nCQuR2HlSgFBBoRTtOD236+wsZEC52oTSUmQtjQ1GhJU1OFOFmmzbOVdYT+kM3jqSKKKDV8I4VDv5\nKhd6ETkXskNRJNsJ5oJplXfUzoaMFCFbhAugKN4pCsLq571DtmRt57ufCSee8PuBK4QP91k4G1xo\nnSNyGudCc1O0xX4Gn+Lvi2a+Yr0oKRRU1nFytWQ2nzFAoVTDs8s1e2nE9x7ts9pqau05GKRYLzm7\n3pD0UoaZJJaKw37GW/sFk37GS2+JWq6SijNEbNk7GGJshiVlriXjfsJ6s6ZpSrJMYLXA1LBee+JI\nMBnn2KZCiwisZLBXMM5TojwjSyXaCUzVkAtJ1Ri086QqZdSLWJQWKRQphidHMdeLJWwvbzU2UX/C\nVguWywVNU3NwcIC1ltFoyPn5eehiHimMMRRFwf7+/o6QnCQJSZIwHA5ZLBYc
Hh7y1ltv8ezZMx4+\nfMi9e/faklSFEAJrLYNBH601m82WsqwYjUY0TcX19TVlWXF1dY2UAucc17MZh4dHWGsZj8cM+n3K\nbUnRy0OPrSjm4YMH5FnGcrlkMOiT5ikOx3A45q/+6v9QFEPef/8jfv7TX9x63uyIlJ8x5bp93BnH\nJzgwHqw29LKIzUZzPatuXHJdMJgzhMOENiHz21WYIgmTowlvvv0EgDefPuLBowc8efqEN956hz/4\no2/v/pYQnjj+DDJ1q4AJqqj237bNmgQP1PBzUgQg1yGJLzM2HWj5FNj5RPrHfcbXXvv8Otj5dJmr\nM3l7WQnq9n1HEgYxvH8NtXFY3YRNTmtq2efX769wJ88ZX00ZDd9gU53jnKbcznk5/QDvSmqzppf1\nSbIDBqN3SJIU6zOiqCBSY4yp8DLmz/7lP27H9EuMT6D5YttSlNspfTpD1puN1bSIMWz2olVc3Qxa\n1wncA1K2QmrflW/C60gliOOQ2YnjUKFKYvjh+Qd4bfGN2WVqpFJIKcnefhiaN+sKbw2uafAqDkpY\n3UCU4p0BROiY3tToasto/4gIwfG9h8Rpn6Ojv7nt4LzOc3+t5BkeONe6JduWsGw7kLHLfgUhjfOh\nW7ox4VkyBpyRGAPGORrrW26UwLTs/a6U5doS4utOzrtMTpfV6cCXuWlv4TtOT1toSKTE6AYJNCZk\neLT11JXGaAM8/9xhEP6zoP5d3MVd3MVd3MVd3MXvUXzJvr13cRd3cRd3cRd3cRf//8Qd4LmLu7iL\nu7iLu7iL3/u4Azx3cRd3cRd3cRd38Xsfd4DnLu7iLu7iLu7iLn7v4w7w3MVd3MVd3MVd3MXvfdwB\nnru4i7u4i7u4i7v4vY//B8JgNN16YgdXAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAABJCAYAAAA5f/zBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8ebBlV3Xm+Vt7n+GObx5zHqRMDakB\nCQSIQQgwwmAwGLtx4fJcVQ3G7cLdptyFXXa7XIZwea4qV5W7beMJGxuXzWCMoRBIIDRASmhIZSqV\nc77MfJlvvO/OZ9h79R/nvpdiDD1HdHSE430RCr28wzn3rrv23t/+1re2qCpb2MIWtrCFLWxhC/+U\nYf7//gBb2MIWtrCFLWxhC/9fY4vwbGELW9jCFrawhX/y2CI8W9jCFrawhS1s4Z88tgjPFrawhS1s\nYQtb+CePLcKzhS1sYQtb2MIW/skj+HZP/sEvvEM9Hu8NYhTvBeuV3CjiFFXweKx3OAxGhKNHT/DM\ng09QqVpEPV4ECzhAUBQhzwEFY4vrghJoQCaORuZIe7Dvzz5OdTZm/lKL3fumcXOnyemyP3mAdmOF\nkewZoILPE9JMafcy0paweMUTBoZr98ZkrkG3VWJtLWftXIRbjhEr7Ns5is0hU7jvy8/w8tv2MFQJ\n+aUPPSLPN3B3HT+rAYIRwYvBiVA1QobgTIAzjmEsDSOExiIIIoYujlCEKqAi5IAVgwgYEWIRegAC\nUoQJgyFDsRhEPFY9iCFC8WLwqqSAQfAKqoqIkiuIKqmAeI8oeCDxHgS8KhGCB/I8p4Qh9x4rgqpi\nVPFAoJ5790w/79gAPDifqQ1C/uzXf41Dh27h0D2vI+07nBMyL6DgPKhXvIL3gvOKevC+uIYdfH+A\nQARM8bwqg9wDp0VWGSliZowiIogp8s4YRQyIQGQFMYoRxRoQI4O4ewRfXEOv3lMGceA5j6tIkcc6\nuCEAyiu2h887Pq98zcu1sbLCwQP7aawsM1Sr0W01mZ4c4Y4X3Ui32aRWqbJrxx5ec/c9PP7oY5hY\nOTV3gSvLK4RqOHLqWf7m4/cShiOEpYS9B8cYGa0ionhnyHNPs9WkUqqxstqi1+1jJMR7B5JTKofE\npZB6rUI5ipjeViOsjBPYKlnqQRzGZPSThMrQHezeeTfLjWc5c+Y+Ahpk/UUqlR1E8T66vsqNN72a\nbqfN0vyDrCwfYXbna7jr7rfzxhdOoEWwnhd+9jffpquLXZyxLC9f4PSzp4gqltQr8+ebLHXgDz/w\nn7hm3xQ/8p7vh6ROUMkI8iqVekxbLlMeUzqrguZKtytEATinBEaoTUOWKGOTEXbpIBdPrCI+pdFs\nMzE5zN2vfREHrjnIrp0HkbjEV499ht/9rT8h6imv+6FX8t3f9f0sLZ/gT/7sLzn8wCWCkhCUFclg\nwpRpt1NKs0rqPaNThtVFGNkJnTVwbRibgkjA9YW1lnLqsHvesfmRH9yrVxopFsWKodns0O/18E6J\nSkIUWZbzLn90x6u49c7bcLqdLLlEd6jChf4a9d0HCGyMek+WrpH0G7TXlvjEA4+zttbj0SfOQ6j8\n/E+/E2NXGR2uEVcqWDUghlxTPvZ3D/LBP70fCUwxp+EI44g3vPluxobLhMZTCgx4B0ZQbzF4sixD\nFBRBxZPnnrfe9moiE/Jbv/dBhvunaNlhHu+UePbJM5gILiw+/7wB+J33/4QmSZPQWuIoopghhCz1\nXFy4RC9N6HS7tJIOvbQL3hEYQbVJLexSjQMqUZ3YRsVcgBbzg8akudLPHKmCCWNK5QrlaJg4rlGK\nYkQC8syR5zlZ7mi3ezTW1pi7dIFOP6GfDeY4daROafUTBAtiwViiQAhMiLVCYCzWCqG1CIr3IN6A\nFnOSFUNghQcePfG847OwcEZFFbAEgS3mAaBUHceIkPQbiIQEYQSAsRYRAVWMCbHWUsx3V28p8rW3\n//p/A/STDuePPUTz8jzDw0PsvOlOyvXJr7kOgKrHPfVhrLRg58tRDH75As0nH2Z1pUXL1vDjM0zu\nOUA0PAriiEojhKU6YkJACaMK1oSkT32E8q1v/6ax+baEx4lFfbEooIJB8QJGBacehSLlJUC8R8Xj\n8pwgAERxCOIFrGIUPAarHm+Kxc15TxSAd4bMOsLcMeRg5CMfoRH26DVS6rNVehdPMFtR2u1l4s5p\nNJ+nv9rn7NwVli530WQf45Xb2b/3RkbH4fLFeRZPpoTDj7GyepHGlToXLjbYXZlmtZVw/2OXiENl\n22SJFx/aRiW2iA2/XSi+EarkKlgDViD0HowlpCAa6g2rtiAfuXMYYzGSEWPJVekKRFCQjMGiGmHI\n1RMbg6gnw5BSkMZQBcQReiETKSYOhR5KoMUkqqpkQICQqWJVCBBC5wgN9MTgvMMNSKgB2t5TFkHF\n0lfHkBREMEcxA4pq8ZuLDQWJUJfxgz/7Xj75f/8ep3/7P3LPu95LGAmSObwarAFVwXtwqgQe1Bck\nRorkKriGAihGBDf4ngx4iFW/8cKCvBTkUQwYKYiQHRAbaxVjiuusEx2kIEDF6Ci+L3xzsiNSEMji\nFcV71kfBZjA7PUWtUiIMDPv37aGxukwUBwShodvtEoYVdu3cx2tf+2qefPwwvaSF8RFHn36aF770\nTubPXyQOQw5cs5eTJy9TqlosFudSKtUyLjc436NaNXjfJQh71Ict3baSZxnGQL+XEoQBgqMaClmv\nRxA2icKQKI5I+hl5lhMYSxiFVEZGCEb2k+YXaC8dp52v0GxdwTUzwvIYjz32YWoVz9rKZYJ4mp37\nDzIyObrpvPnOF/9LXvWy1/PRe/8f/u2v/ivGZuGtb3oP937xf3Bxocmv/KtfYGK6zL/+2e/n5bd+\nL6nv0G4us7KyQBKuUK96lhehuyrEkVAuQZbC0JRgAyXLheqw0O0m1IIm4zOWSGo0Oy2mpsYIohI2\nuUJw5Th++xtJ+jm5E2KrdFsLHD3+WZ4+MkezWfzqUQi798cICa0zKe22Y2aiTD/tkaaeXtswBaQO\nnEKvAa02tFYVl24uNiqGKBRCMbR7DmMt3nmcUwwRzkEIHFla5tbza3QmQk5Km2beZq3TJnzmKRrN\nJiM1w8OPH8Oqp9P3dFLPpQstXKb0kiap9jl/usfv//EHeeVLDrG41GV6tsbdL7uRa3aPU69FpElK\nHMLsjhluesF1TE+VGS5HQEbmDKKC94Ysz/EIYSnCOY86T6kUkSaOo4vP4pc8tnOZseGIZy4tUR7Z\nx+hYwHIj33TugCVNPBIJVjwixebH+2JTmDsl90ruDN6HpLkUc4SWMOoITEQcxHiJQATnHQEGxCIa\nEBiDkRJxq
Ua5XCculYnjmDiKMMbinCfLPEk/RV1MlgVEcY9eb40QR6oeXIBXh3OO3AkqilhIc4M1\nHmss1kJkDXFUzEiCYLzDYjfmNqebK870O6tI1iOIykipShhVQAyqrvh+YjDGIDK4rhb31YJyoYP7\nFZxGBi/Rb0pyrj4uGLHYwJL1e7hShObJYF4thI5iFRqIAbN3cvnpLzIzERLVx8nGaiztyZnPj5A5\nw4gILutipU6W54MNboC1FvUOQcEYepO3U/4Wcfi2hCdQJRc2WOb6NzZasHRVz/oxPkUcDFmeMRBt\nioWXQikAi+JJvMEYxYol8w5rBI8nUk/egx1/8jHmXIOyq1GWKwzNX2FoeIyh9Ahpw7C0lFIKAg4/\ncpkXH3wf77zndiYnhshshXMXFjh/6RL79+/h8sISe7a/gXPP/A9e/p1v5ezJk3zoz38P+imu52l2\nDGnaJxur0Eo8Q9XseaZOAYtBTUECQ69gDeocKgYxjkAMeE/f2AGp8cSqJChqIPWCGshECVBCFTIc\nI0ZIPaSDpTcWg1cQPIEqDiX0kBiHU8HgsASkWkyoRZopXj0W2SBTZYUATyxCpqB4HAVBCEQAh1Ml\n1eKzJgihKhZINrfRKqCAeHya8OZ3/q988g/+mN999/fzz/+v/8LI5FhBkLXIHDVgVcgH6o3xA9Kx\nIbUMrjeIiWihShUDx2y8RARUrhIdY8CKFP83irXFQDRGMbiBKlTsU4vzqPS5N9v4DBuURovrFzut\n4tOIymb5DmFoGB8dot1s0lpd5ZprdzFS302lUqbbXmN8qMbU1BRf+MJ9XJw7wy033UQQD/E9b3sb\n+w4c4PhTz7Bt1zYmZ3bwO8/8V1xaI0sKFcPlSr+f0OsoeRYVeSDDGCPU60oYWoIgQETIXU6vl7Br\nz276ecLKWpNKPITBgne4LEXiKr1+m7XmFUpDyuTkCGE2RGiGaTSbEDi2bTvA8uo8Z088ShRN8upX\nvJE77ryTSnnzFfOJyUkA7nnFP+OTn/8zqmNT/NVHf5uVJVi5CI8c/iqz18D266/nphfeyi37b+Q7\n7noLl5eOs+vO63jBwVHGAke71yKsCM4PSKkIS+eF+iQEEawtCSvNs6QJ+D7YUonR8SGibIldS/ey\nLYWjJx7g0eOOqKTQgUZrlWMnH+f++08zfwUCK8RlZc+ePVSijCcuXyCMHP1Gwvf96Dv5+4//Ea04\nodcUshzEQ3dB0UCQAKau2VxsYhPQ73XoeMUKpL0M5xxg8F5Rr6Q5zOy+jnO7x3l45WkuL17m+u0H\n+dX/9iHOXYGRClxZgNkxePM9L2FmqsxXnzpPmqUkWUa9VqO5coVdO7fx5JGcJ488Tq0GUQSf/LvH\nmBmDG/aMMbV9O2NjdcLhErXKCGEQ0UuzYgOBIwoMQegpl0JUDb1+ghGLsQHGGMqliMutFZoX5hmS\nZUIpEQXKdK3HLa/exhcPNzadO6qQOUeQB/gArBbKuyI4H4B4rCkRWKGXD5QTKVRxl3ucC3EuRiVE\nB+qGQ3DeIhJhTYQJYgJTJg6r1Ct1bBxjEbwqgYXAgDUR3gX0UyEOqgQ2QfMEY6GLR1PFqyX3isdg\nPFgDuTFAcZ3MQqaeyBYKu5VizrbGIFoICJuBSzr0zh7BZwml0VmGd11PaWgSXIqaMvgU7zPEGASP\nDtYQEKwN0bAgW8U8KRtER5Xn/H11szj4RUCEKCoRGGWYHnFvDg0c3uXgEvAJ4nsQjdBZWyUcn2Vo\n9gAAQVoilpSJimXNhahPUZdiggjaqyRiUfWUrCX0SaHaRxXK2v6Wcfi2hCdnfYctG3tfAbz4AdUb\naADFty4Wh9ThpShhqR+sDHgQTxG/YgUxogS2WKEEQ3rec+D3/zNnAs+QGmbr89iVJxkmp3X8FFpt\nMGI98dCjhOnr+Pl3/SrbJ3dz7uwX+blf/ATv/ol3sDSfMjE6Tm1kBHWesakxDux9JxdPPMOLX/sq\nPv+pPyIjJutktJKUTqfDpYtNjFh6WWdzGeSVgcqINw7ri2CmooXkjCPBEqsr1BIRVA2BKN5BaqFM\nQT4iKWIVidLHor4YARZAcwwBKR6DECt0jTKM0LNQ8wZRV+gNoqgaMlVCNQjFDiURpTEgExaPUaFP\nUZI0CLn35APlwgnUEYx6VAzqPTGbHF2wzoABJe+nfNe/+GGaBPzNv/+X/OQffJSk74tJenBfo4MY\nAeplwG+K96+rPIWwozhVrF69vgJmg+ys/yeDstWA/IhiDBhxhaIj3/AxB+SneGB98D73m+tzH9Dn\nUKNNHt6ZJX1KUUhohRe+6DacJpRKhnq9ytTYGKNDZbK8zdLSApOTMywtNakOweJai+OnTnH5/DzT\nOyZ52UtfxBNPvJInn3qC1mqHTC2tZsrl+WWaa13q9SHEFASrXq8SxZZSOcKYkMBEeCMEPqVqStRq\no6ysLtJLQ0Q8vaxL5voMx7NMTY9SkjamvUbg15iYqFArjRNFAc0eROUa43YnF06fYXLkWg5ddzvT\nY+XnMNbnj0ZzBYCeW+Pmg29ifHqUi5evEKTLXA6e4QsPf4J+7Rl2T13P5PAwQVAC4NzcMXaOV7hm\n/508+MVPMjE+RDNt4hGiEgShZ3JPIQMunlMa84o1QliCTqaUfEiULHGbf5Zdo9Bzhmtqy3xHBOdn\nDMl5pddMONXp0FqD0AqBCJcXlItn2txzz+v4yt9/mCiK6Kx6vvqVvyPJ+pTHQAKInWVtydFrCfUx\ny75Dihl2m4tNOyHLC0UyzR1hLPT7HjGGzOWEQUhJYOTuAyyULTO1vYhXHv7KaZaWoQy4Ptx8zSRv\necMtXFpZ5ZGnLnHq5DKdNEcDQ6nX48hjj3LXGyZ425tv4YH7n2DbDMyMx8zumGLPvl3UhuuItUBA\nY3WVTt6kUi0jYgmkUFE6DpzLKQWeeiVkqA69nqOfeKypkmdKKQx56XU38ODcIlHFc+LkGdJz59l5\nXcjuneObzh1RED9Qhz2FeoIgqgghkRiILLkXkiwvFA0DgRqsCIENgRiVABFPaArCbG0MBGBLWBsR\nxRUqlQqluIQ1IR5BdKAwIFgbgIc0ySiFAWXr8ekyNh4iyJt41ymsB4AxIWFUxxoDBhRDYACjRVne\nFiV5C4TWEInFiAXzbZfub4AVQxxWyBOPNpdJFs4RVYcRnyBhCZ+2yddWsHEJzRJc0sS3FwuFd2IP\n8ezN2DAeBLZQZdahKhuKzoYyTkF8BMjxZP0W46OrmH4Dqi/H4tAzf4zMn4K5Iyyb21i7/vsYv/Zm\nvOsiNsT7HhJXGNr/AspOWZs7TZ5mpGmXqew0bmWFXCKy6nY68QShKLVsnlL/BOy49ZvG4dtGrfgK\nvlhqRfHI1YVmUDJxPi8UHwAj5GlOYAb+DBVyp0RxMPCGGNQpaopBEdiCtdJ0jP27XyC/8QDm8gWG\n82OYJKOUryK1axnbtkLcaSL9p3nbq47jc6iUY5ZPn8FqwK037cGnfW
6++VaOPHuS8TBi547tBJHB\nKMzu3UkeRnjrGYprmKowLop3YyRpTpo5sjTZVAKpKqE3ZEXFjmxQ7xUVOuqJvEWMRxAi8QTe0hcP\nxmAQIq+0BELxOBVCoWD96lEpdifgMQq9QSmrHOQF01dP9+gxspUVZGYWDhzAu0GJEIOn8PBYLVS1\n8uA3W8ETyMArI2DXF/aByuG1KM81VQnFEKkpyNo/iu9cTXxESXspP/AvfoCP14f4ww/8Dj/w0z9F\nEILL3SCf1lUTCq8OAx6hV3cSqoJBCnKj6wrLYOIogl/UuI0O5GzdID9F7Tvf8Po8d2BuJLteLU9t\n7FqeY+rZ4EU8hwjJ1fc8X9TrQzz91BMcuuF6ev0ey6vzjI3VGarXuffez/L6e+6mUi7hvbK6usZa\no8nQ6CSNdovp7duIy2W8czz80P3cccchMrocPf4MkyN78XmfqclRtm2bIootqin9fhdjckKJQQKG\n62NMTswwMTFBGc/KuWdZXD7HyIGDmLhGJ2kXZU9jCn9IGNNpXSZtnSHpX6BWr+E1wkhMRJcgTGi1\nG4zP7OWGm17J1OxU4ZvSTdZsgA//xX/mgWf/ggfu/yCTU6/h8b++l+UluOuut3Lp/Ck6SUJ7ucvu\nWw7w+u94B48c/jgL7Tn+5jMf5od+5D187K9/n7MLsG9fm+FxQ6cF3RXP3mvHWFte4/Qxx/RUxAtu\n386Rx8+AgRt2T5Mvpbxu+lmu3wY2EIYi6BrDnYfgL08aemK4+ZZX8qWH78crxJHgUygZ4diRi6wu\nfpQkH5T9c08uDWZ3jtDvWlbO92jOdclSoT4D1eGM1hoE+eYUsMAYSnFIniWoETrtlNxBLRIyFbKs\nyzX7dpLblKynVGojjEyP8/k//SwrS2As3HFomte/9gXMLS6w1klx3hbWAwTJlcSUOPfsAg+MPMwb\nXnOA4UqT3Xu2MzY+ggmKbW8UlMjyEMFQCpq0Wx186jDiCQNLlufk3hNGESYKWEoSrCi1mrB9pk6/\nA40kZaQ8S7+xwu0vfhkHbj6EHLqfF7/oADt3jVGtxJvOndzlg/HoUF/E1mmOxxbj3hgiDHFoicLC\nH2MBGZSLAhthg0E5SwxBYMEpgYlQMQRhCDYiLkWUSiVCG2AlLFQgLcpDBsFYg68ayhoy/YJ9vODa\n2xkeGkZE6LSbHHvmCLEGJGmXS2eOcv7iV7jUPEcpqGNNVMzxRikZixlsGkIjVAIIjSl8jLI5slwf\n341UhvDtNbIrT5CtzuG2X4u4DrY0Dr0G+uzHcZlDsz6SpwQGgriCaoYfuwaxwaDUpYUCBhtE52qZ\ni431BEwxf0pAP89IfY14aQ6NToINIZ+gzxArZg/tsX2MbL+BwFqSfoswCPFZQlQuY4Iykc9ZQ3F5\nQmthjumz/wVT9YS7v4dydQjyLtpYhLlH6F1epPLCbx6H50ETbbF4qMHgUDUDT09haFVT+HnEC1Yg\ndxkiEEihEBkriPqNkkFRkjAgileDaIZWKux++R00m01GgkuMts5Sqm7D2F1MS5M0fYLZfoNr8y5f\n/sLDzEzv5szpE9TrVWan9vLWt1yPVob40iOHmZ7aQbfT5KnDj/CDP/5jLJw/zcWFJVpnztLpWTpJ\ntyAr1lIKDVEYUopDnFY3lUAGTwQIhnxAKFJyYmBIDT0csVocDhFD5j0usIj3GDGoSGEYVsEjrKgy\nbAoVB4FcHUYNmXiGVegJdNXgFhfIP/ghkjxHnaPvoHT9HsK3f29Rb9Yci5AqpJITYjEeHIoDQlWc\nEYz3eAqFJ9JC9lcp1KWSGLqqGHXk6nj+dtyr0HUJhPVBoCS9hDe//U28fjTi6cMP8Z7/+Jvs2r+N\nNEkHeQGqptB0tKjvquoG2VH1hGI2BBUd7KkMDMhOUc4rjMsUNV4pvE+GosT1jTVnHaiMAw1TrhKc\nq4NX0a99x9deYZMKz+kz5xgeG2f/tdfw+FcPc+OhA8xMjtPv9dh/zTWsNfoYmrS7XbqtNnfdfRfH\nnjnPTTfdzA23HKLdaZP3cy787d8S2piR4VF27Nhd7AQDx8R4hcwXSmwUlKhW6hitoP0AVBivzbK2\n0KXkVihXE2qjVcZnqlxsNulpRh4aQhNgel06yxc5JyNUa0M0F06g/Qad1giUKrhcMc6xcukErW6P\n4fHrsLUpKBnSDDrNzXu/Pv7xTyCPwVAAYeU8YSWkUnV86aG/Za0JWQZZnvHlIx/lqz97P9fetJvf\n+IN3cmo+Yc+OHVxoLFAJhPlzysgw2FBwCuN6CLhEnp8kjD3NtWWCCuyf2cX+eo07dh3lRQcFYws6\nG1chyQqPzo/eoPzxYcMDD3yO8+fakJsNP5ANISwZ5pdWmYpiUlX6TU+9MsPYbMrFI1OcuXiYzAnD\nu4XKpGekPMLEVInEb05Vjq3Q63nWs9U7Ty0OyX3hV8HAP/u+N+B9CZe3yExKFJf5jV98N1944GmO\nHZvjpht30W73WV3usdLq0u9DkubooMSce+Fsx7L21Hmi6jg3vfAGrFicFkZjIwHdtEscV5BgiEo1\nolaC3tpZ+naGNAvIXUZgQvrdhH6nX4w7AlaXMuaDJjcf3EY5KtHtL5GswUtetouRPQu8bqTCyTOH\nOf7kCp1eyo+8772bik8vyXBkOF94FXMF7xwejxpfzA9OcNZhwkLtxnvEFAu4RAYTmHUXHw5bPBdY\nwCJhAGKwYYgYwYjBBEGhcHiLkhHEQ2SuTtRL2F6rMDV5LYFVZofqzMyMElRqvOpFL6Lf6zN34ixy\n4+0snb6Fk1cu8cjcw5xaPUrVRogJiYwv5rhBNUFCQUkL3+MmbQZZr4HOP4peOY6mKeHEIXy3gaQN\n/Og1mMBSCnKMJig9JPRgAzRQfG8B31tBStVBZNZ1cAZ8QJ6j8rBe59qoCYkx5FGZk36M8nLMVG2a\nqDaMuf5lrI4+zbGlTzM+Mcu2oTq9hWOUK6MYiYumJiBrLkF/jbzfxqUthpcex+/7GfTMfUg3REZ3\nIfVhVBZpZTs5ywrfXN95HoTHDwxI60YkL4pgcDJYG7xgMOigvpflWWGrMBBpwfEcgBGsFJ1e3ivG\ngjEenYObfvNXeLZssc0rjGbLBPVZxFSZTT7DLd15ri+1iKoZeZrzD088SlQqs23PDjLnkcoopxaX\n+OIXPsXr3/wm8m6P+z59L69+42s5ffYMv/Rz7yNXw8EbbmBqWIv6H4pzSq+nhUzsfbEwbgKig9gY\nT8ZARcAUXVAoBui5HG8C6t6j1hCp3yiv9L0SGKWvQmYgMkLmix2jH5QQ+5JvdFZlKLF3LP3hn+J6\nXTSwhC5BA0P/6ePob/8u0f/2E2RiUCtEaklRSs7Rl8K8XFHBqZIxUJoGC7Uf1KpzcuLB5BagJAPP\nSvcfofDocweFrAshStpL+Ld/+yV++i13sPKOE/zA//lrfOf3vpqknw2Ijd+YcLy6AdG5SnzYeH79\n6ut/+g0zMgPiU1h0Cq+Oy
NeTnfVy63M8Ol9z3UHef5Ny3rrCsz7nbDJ1cLmjPFzjy1/+Ctfu3025\nVKHd7tBpdRgbGaPT7bG0tMA1+3YT2QDvHRPTo3zxS5/n3KVT3Hb7bXQaObV4hLm5S+yd2cu+Hdfy\n5WMPk3tDt5viAsVGAbkPKZk6lWiYcslQKZfJc0feX2P/3mtJu/NM1odYWVlh394dPPLkk0S1MhJ6\nKiUDSY8r5x/D2hjJVnGuS9TrURubJAjL5LYGrk3W7xPFa/ikxdL8GuVSlfNnFzcXGKAc10j6bda8\ncOTICSqVKqENWWs0CGsG4+DcqQXOnF4ABw89/hWyDMJQeObIBYwIM3uFuKK0VwXfgWbDEFUqRK0I\nMXBpLqdca3LrtfvZPTzGC+yjXD8Dx04ZbjpY/J7n55TRmiAGdk7Aj73I8een+rQQghzSHExYdBj2\nuoV6gheiwNJp5nSbwq233MYn7v8oPhCicWHiWmivQUZKJ8totjdHeFabCQNfAF6Fahzi8gT1wtBQ\nwH/9wC+z3O+y0lwkCgNEoV4bZywe4cDORYZqNS5dXiRNMoIootFosLLSLtR7n1OqjzFUCwizJbQX\no3lKKJZSHBW+L4RqZZh+WnQ82nCIlXyVPLeElW0EpkSae4xxRYen9c8xVRs0M3S7Ofc9cIo9O4YJ\nI8c9d93KleUv8A9/uMr7/9tTvPGuISYm6uSlkU3nTidvopoSIoW3zmvRpEGGJwFRcuPJSHEmAesR\nmw9iqmTGk5kcFQM4jDUYLNYIDDapGC0IFBSlLAARvBVWGzGtY3PEi2eIXQbVCZbdYdLWMskL7iA7\neIDtt99OFEVU4oRU1jj55VQFs2QAACAASURBVL9jrH2JvX6IWw++jLnOTXxs7kusZBchiArTiBau\nGkIQCUCLEtdm4JorSFbGhzuQShXfb5M9fT9GM4Ltd3BV2c7BgPgUfI44Qbsr+OYVpD6DGDvoVJUN\nKwGqWJ8U4oV6im4wj/oMkydUrTJUqbHW7LDavsyul303Wd6lVKlR0i7TtsXS3DEWhqoMj80gYjZm\nXRtEOOniky42bSFGqKcPc/FizKXuNnY3ypRrYGo5WpplzS9x7PDfc+tb3/1N4/BtCY8vfme8L9p4\nPQPDlM8xYkAdCFjHhk+l3+mQeoil6OjKHARBUfbxGMQqIh6CAHcu4+DnP82xesyICYjqe0lXu8Sm\nynZ3nLf15qiWl0m7ln4zpbFcpdHs8NCXHmBqcpJeo0Ur7XPHnXfyzp9+Dzu2z5IuL/D+X34fTx75\nEq12AnmPPbNjrB35NEO1EIMQ2GLiq9SUCkGhgMjmdqOlvGgUF2OoGMiNJ5CA7qBEFWAwRii5HLVC\nTXMCZ0hESMVhxGCdYHFEvigZOhESLNZA2WW4uTnkwhythUWSXpejXzkCwxF//W9+kGpVIO2RJwaN\nyvzeZ77E048+wfCLbiPNHIkUJEkHqkmCxyskAw+PRUlxoEJXIfAeNyAZXVVChZI6EjUkuvmdunMD\noXOd9Euhhimel7z4Jh5bcTx9/Bw/+NJX8+Ffr1KZup33vP8/cN2hHSRJMXA2Sm7qN/rE/LpreIOi\nFENjvTiw7uVZv/c3lKF0nejIxhVkcJmrnVkbEtLGhLYuK8lz61n6df9+nuj1e3Q6Fs1TWq02rXab\n0ydOMDszzYtuu43e6ip79uyhWqsTiGHhyiLV+jT79x4gKpc5/OUn6aae86uLTG3fRpamuCxlciik\nlQg5Buc8ISG7Jndy/Z59tFYXOXv6IsuLZ7j+xoPs37adxuIp4uoo8wtNsjSgjOP773k7TuCzD/5P\nzi/NUYoiRB2dZpOJyb0E/Yx3/fO388DDn+P0wjxXul3q0W5cu89q8winu01aFx4hc4YLC+c3Fxgg\nL7UJu5APSeF1cz1aaz16HYN2hE7bM1wxhALxOPSXhSzzxJOC60NzVUnWhNFxi/WOaEzZMwat5T4B\nNYIAvu+N38XN+3Zy5cjnuGfmK3R7hpf/QjDIokHXHwGFRl0sht9zC3zguz1fOCJ84GHIjeAyxQaC\njYqNXDHXGeKq4eHPnODYV0/gM6jvVmoTQtaHyohBXZ/5C0WJajNYbvSxoWViuE5jrcFa4hAtcvb9\n/+4nafWWWL24wsnTRzEBSO64eHaRobFdjI7XUe+ZmR7hvi8f5UMfOUkkMDoBr3r5frKlU8zsKLPr\n4G6i0o10ejnlMKTZ7pM7Q6lSR9WyttIlzxKsFaxpkxNjwjLRxBj4jEgzjA7KLap4dQXBzjJ6LWXu\nfIM0ddw+sYubrhniSqvBycujzOy9lu98fZ+1pfP00pxLzbVN507Dnyc0gkqZSALEGrxxOJ+Thl28\nelKXkkQdMkmxJQblbYv30DUZqTEERghNQF96VMKim1I1oUsbCHD0UZ9QokrrcomLT85R0TbToWM3\nKSMlpR4HYC6x2hXyuuLmvkQ8U8cff4DqgZu5+JUHWTv+FBPNx4hNj4NTMZ3mY1xjhJtGyizLa/hP\n6ccQbHG8CErfWqIB4dks9LO/St7LIHPYwgmJrdSQ0Qnc8nlMFILrQtZB8j74tNhcagdsj/zMg0h9\nFjMyWzScFPI54dnfhfF9yMithfv/8p9Aax5ZnkNO3wfZrQTZDPVkBLfjhey89U5UU6xR1PWozN7E\njruGmOmuIZ0V3KmnaF12lIIu1YlZxu21aBV61Vny2RuIW4v0/SF2lk+zyxtcEpGe/AiapHjnqV9+\njDtueOW3jMO379JC8N4PFqqrRiRvA3DZxtSQiWI9eHWkSZdKUJiWcw+xMawvVz6X4vwGUWQxo/pv\n3kv91huonDnMcHWaxlKPHbKfSvoUd/ROUdIFXN9A5ggkoJ1G/PhPvZc9+/fgfFG2sGGIcx7jcjpL\nl3nve3+G0UqdnJD99UuMBMKzpw9TrliUMYw6ZHCOjQDqPcGgTX4zcB5Co4gfeFA8IDmxgAssikN9\nwcht7iEwGIFIHUaEjuRYNdQEYnE4wGHQZ4+RPv4E882MaP8+7N49lEdHOfs3H+XX3/cK9s5sI3bL\nOFnGTrcJ21Vcf4R3ve0QP3eszjCGNVEyUQyWUJQEQ+ZzUl036xbkxqghGpQm84Hc6ynITuA9iAV1\n2H9El5ZHCpl9YBos+g900PboSdOEGw7s5iNffZCfesub0YVH+cV3/jg/9jO/xD1veWlR5lonJwOf\n1KDIBXA1H4ukZF2RKWx/xYuLImrhq9L1NsivPzuCr7HufEPHwdc5djauu35b0WJjsBl02l1GhuvE\ncczhRx+lXq9x3cHrSJI+88sLnH1mjhuuu5njz55kdGKIExcucv0NNdoJXDx9iiiKOX72AhKEENap\n12tE5RF2H7ieo898BckTwiAm0irHj82xNt/gxbddQ1wVoixipdFAGp7hkSEik1MZqXDxyipZFvDY\nE0+wY+dOXv+y15IlHe47/BCnzx6hXptk2+ReXNLk7z77P7npuv1Mb9/Pxz99H2vZKu
NTkwRhwMW5\n83SabcJanTzvbS4wwE/++M/xB3/4K2R4cmdIV6C1Av0e1CODNwpeKY8LUR36DQgiIXWQdZXpkZjV\nXsLSoufADTtRGoxUpkiSDtY7jEJ7dZ5LcxnnloXHEnjpQeXzP+/JnMd5yHIIopR2D/ptxalQK8Nq\nSzh6uTDkRgMTPIA6oVZXSuKQdoQIuFRoXRIm99U49JIhmskFVpcE74XhSfCpsHJlcxuJajWiXolZ\nbjTJ8pw4CsjyHPEwXp3mr/72E4yNjlMf2c5otcriwmXGZoQ0beDzDmF1hHOXE7bt2s2v/PIBQqsE\nhZGOXrKXIIpQl7HW7BEGAUmeEocxzgUkvRyvKQKEIiS9HCIoiaeTOvpKsZF1joXFFu1OwtpahzxT\nriyscnlxgYmJCX7ou1/ILdeN8+y5S8xfgbWeo9HOudg5zvY948TXTnHsqeOIb246dzRs4E1EbhzG\nRoVq7wXnUyTuI5ojLiW0fYjSYv0QC1JYLrw6VARXyME4gYwOUSBYGxVdcNojo4NzjsbCMJeeuEgl\nbzJcC5jMl6j4HqIWYyNSb6lWa0SVKkGcMzpmyHtrpKvzxPTYNbSK2xFCmmPCnKltJfJ+inVXKHV7\n3FJ+MYf9I8VxHUCGR2y2aXUHICwHmLxoQ7KiYAriTdonn38GasME7QSbtUEc5Bmow/uUHHDJeUzj\nMmZoGjXF2ikKnbUxaq2HkPSv8OFUUQ6wh0h6e8niO1jrX2GVKo36KCbtEjbO4ae3k3tHqTRCEITY\n9ln8ic/ByinKy18GzYlG9sP+UTRswcTLqFR2sNvVySUib1+BZ/6BZG2IhQaUrzuIVKagOkW+6y5K\ncf1bxuHbEp7CwlCE12/sdq8uBiKC8x4Rj0rhsdA+aBRgBge5iR0ciIcShErPOUzmsTffwvZ//W4e\nvnyJvNFjNOnTi0doh10OXFxiv/skViFnBLSNy1M+d2qCd12zhyRNiy4T1eL8EECd473/x/v4xJ//\nJTPX7OGVr3w1xw6fpzLZRU0ZG4D6DO8FEY8JBBS6/ZxKlDNe3dxuy+BR5wgG8VEMOTmKEqrijSGx\nhq5kGAwVb8hNcbaERyhjCQQ6vS6BtbjAUDlxlJX5LpXXvoFwpEyZgPTkMW546Au85c2HaLYTHn9w\nnvryFVzY44XvGOapBxdoXVzGTozx1lv6fOpYRsUaeq02XF6k2WozOjFG/opXYG2IVXBe6SKE4jdM\nycYrPVGq3mOBjgh1FKuK3WzNBnDODGr3hUHaK1jVovtAipJXmqXs3TnJ73/6M/zEW7+HbOkZ/vy3\nfoYHP/cm3vsf/ncqNUuWuecwDDYIy7oqs/6HrNObge9G18/W0avdGsUPpYNS1HPKWTIgR1/vxXmO\n0nP1sec8PCg3bhadbp8giOn129x5552g8KlPfZoDB/bTT/rs2DXF8uoK5y/Os33PXroXlnj06FGW\nl5apVYdYbnep1GsYG9BYa+C8o1QqY8wIWX+U3CVkeZ9LjUXoBSxcWSLzCbXqMJgyi8srVMsB/X6H\ndqVDmmd0M4X8MlkOmnXZN/0SJmcnecmP/y+cv/JyauM7+dCHP8vlxTUunD3JZ+99lG3bdzEyuYsL\nc6dhWWh1OywuL/Cy6/ehYUxv9Vu3h34rvOtdv0AvXeb9v/PfQQt1OK4VBLfZzhkZEWZ2lulnPYKq\nYOLiR8jXPNVgiLd83z381Yc/wsIK+GyON7/+XUyO7OCzn/57otgzNgUPPvEoyf6DaN/w0UuzGJ3n\n7tsEG0CWwF98Et70OthXVVotIXewsgwfeVC577xhxDrSQPA51OoKRhifhXxl0MZLUeLymULU4bGH\ne4zsFMQJnSYMjcHQtNJc29y4qoaWuYsNJCjOj7FisCYgcxmJgTe9+TvYs30/n3v4Ph4/fporl5e5\nMHeZkycvE8fwwlt2MDY+Sa06xNKVDqPjo9TCAFsqUSsb1Ct51qdmLTjIck+/Z1HfRYwnQummijoh\ncwlJlrG22qXR7NPrpnQ6fdqdBBNYAmMIAkOeeSYmh3jHm1/L/l01njkzz1/fe44sM1QrDfbvnEV8\nSu6EfupJMsetN+7n8Scf2XTuBHEHJMWbBCelossUwfkMTL84q8XnRGGCVYfgERMgmuM1x/lC0ROx\ng+khxImwEjmgwpgdp5QO0U1zut0KK4+fJEha1CsRU1mD2HcITEo3CcFnVOsxkzum8c5Buka94snj\nCr350/jFE/TP3EsUlchaDeKawfiY+vAklaFtlC9f4juWt0P1xZy0j+AG5zhZU3hlNwufJZCnxTEk\noSMMKUhs0iY98QBZaZTo0hU06GNNcYaaGgFicu3h9SJy4Wns+E5MfZJiRCp6w3dzZfUW8naPib03\nEZYr9HsN5o5+kQZXCCav5eLRwwSxoZakrD32WQjq2JldUE9Rn2KzJq57GRs64ijHKDgt4Vv7MXWH\n15shncStzWPCiFLSpil30vZN1moxQ/OfxI7tJm+W8GsJsrIIt739m+fItwuSDHbXbuBxKdaD9fY7\nC+QYLE6L8pXX4mwD53TdSYpb960gBFqYmd0VmPrv/56GUyZXV7iVES5sn2TbyiKV1UVmzTm6PRiO\nLC7PCHyKNyEPH23ybhGEAJ97vDrIE6Jyid/4tV/j3k99nG0HD1Ku1rlw4SwTew9x5fynyXpdnIXM\nThZSr0BAoWxUyyFpbjm3ukkTmMuxxiI+I1GD14Rs0IZv85xSEIJa+lZIrNBVj8+VTKBkhcALSo7k\nCaVulysf/TRr+w8wfst1GN8hfeIE2cQojcceozq5h6Tcpddz3GJDHm/02PfKCotPTSA7L3H+aJfZ\nkqF79jLLrReQ3fdFsnab0p23s+3mmzn35ccIKLoIKiokxpNp0aZatoZW5nACkVcyhViEsnq8Fv6k\nwG+uIwAgz4UgKBQdGRjWvS3M0tj1CnBxwvPUeI0P/sPf8zM//KMsnHoU8+Tf8MNvfJxf+p0PcNPt\ne0n6/QHRKNSWdZIySNKNsyHYOChQv0Z9MeulMZFBd8HX/daDt3+Nx0ef495Zf06fQ/af89wmxUHE\nhDxz/BSBVVrNFsNDNe6++9UsLy8ShhHepBx+6jCjozN86eHHMUEIeYczcxc4cOB6+s5RLlXZvXs3\nDz30UGHCDyMWFq4wPbWNCxfPoMaSOGF5eZkbD1zPkdMXuHjyq4yORezbPU6tZPC5I4yWmZicYHhs\nlNFymUq5ThhXWJw/y0gwjTYsUZrwDx/7KP8vZ+8dZdl11/l+9j755sqhq6pzUqtbUssKlmTLAmdj\nYeB5DMaMPfiRBh4LDO+9Ab9hmAFm8BrSY94aZswANsFmwLYMxsiSbQUrWJZaUrc651Q53Lr5xL33\n++PcasmMbSj2Wt19b1Wtvrf23Wef3/7+viGJBa5vURkdwhhwqgHPH32WydFJrly7ysLiAnfd9wAD\nw9tZbjZx/JHNTQzwo//qdl4+doKRWi4Xd
wOBFRi6DYFWBulCasWYTLC+TG47YKDbA0WLW/a9nuIH\nLd5497vZPrMDozL+34//MgduqnL27Coyy3eu4xcus3V4kkg5/MXxATxrnXtvk7gevG4//O6nBXfv\nEdx7i6Hdgi8e0zyyAL60kRasKUWaipz8auetVmU0ljRk/ZsTCjpxbjzaPg3FCpTKEEXQbYC0N7dw\nriw2GKzWiKKIJE7IRM6NsSwXVutsHRtj+exLPPzki9T8KXZN7uOeg6MMDAxQqQZIS2J7Aiw4e+4q\nFy5c4ZVjZ2h2W0xNjDIzNYjn2iQxNBqt3KTSEiRJTJoYoiQ3e0tTRZr25dKWg1+wcH2PSq1EmioS\nlZEkGUOVgAfu2MJYYPH8hTleOKUIAod2J+WOm8aYW82oN0IsCxwLAkfQ7GXoks9v/O/v3vTake46\nQrho4ZBiIUW+y2utQYYYlSCMxgHc/iFJa4kWCdLk4pqN1rUlXTIdg0lIzA4OyrsQC+s8+cyzFLbs\nxbeaDNJGyAwRhfR0QtExFF3BQEmBH9CLMmr+Os2FWZrrHeTCNfzhSaLZC6y88AXE+gpBTWI7kvXL\nhrGxJuMFRXnyIEk0yER3jfvdaeLSzczFJ/D0q4kF/3AL+8eGjpK8G2GynC7gCRxLkURtTP0qWXgG\nWmtQdHGKReTwXopbd5OszZNdPwNRSHLpWeTQJO72OxB+BYTBIuXJP/4VxgYmqQ8WcbceoLzlFqrD\nM3jdRdJjn2VpNsXeVcQKhrAdl87CNdIMBkZmkE4JWR7HsjXC9rlRWvgp69PvpLDzzRRqY6jOItbO\nQVae+jNIt7E0OY2ZgPLFhxBtl8axCyStHjW/jeHbC5C+M2m5j+bbQqL6kQMYgcL0iW4SIzS5YW3u\nqSIVaFfcMGMymcSxc65IbDRCCG79zY/RnhxjrdOhuLBKeXwGGk3qi8tsd9aQgUCkgFAEbgctJTgp\no9WILM1Ik7zYSOKIgi155mtP8V9/7TeobZ1BJxGtdpOCHGPHbZMM7P6XBF5AtVLha5/7HaLYRWYZ\niZAIR+I5Nq5vYbubW0CpUaSZIRYCmcbESiMsC0saYgmBMYjMwSs4iNTQsjcobrlsFcD3Ijh1lIsv\nXmHqPW/PJ7vdQbc7rP3pX1M5dICBm3YzumWK+vxLOKU5zq1rDu+sUhsroNMDDCYhhw6HrC/C9ORe\nBv76ZQa2jjDnTqOPnOC6sal86AeJ4xRHQSYt0LmTc4LIlTbkbQJJruxJtCDVikBInH47arNDQB9N\nM32lVH7DwhJ9f6G8IjGAUppyweL3PvVn/J8f/gnOPPclZqY1H/ngj/CBf/0LvP/Hv480TZFCYVsO\nWhsc1yZNDGkGhYJApRlK6292R+4XPUaYV3k4N3g9r+4Y5pvhov4v8CrHx/RNgG6gQRttrddWPZsY\nb3rgLTz++FdyZCYokaSaa1evc/nyZaZnJnCdCfYfOMjCcp1Gd51SqULWSdi5YxeN9SblUoUgCGi1\nWoyMjFAqlbAsSafXQKksb606DuVikajU49iJV7j91juZmd5Bq34ZIzpo5SKMjecF1Ot1EpEwWNxG\nnKZIRzM4PESpWuPy3BXm13tcnVsnEj6NdoNDt95KL01ZmLvKwYMzXLqwzmq9zo7de9i3fx+9ToIl\nfBy5ebfcNTnLzE2DLPTquKEg7uVKz/IQ2B2BV4RrJzWDBUFlGuxhQa8OsRLEqebzT3yM2sgSf/uV\nJh/8336O46e/RLt9BD/YgyUL2G6uRLLchKXwCo4aJsPnvz5bxKbHPbcLDuyHoZcMv/Jpwd/vh8fO\naJYKkxza43L06BUC16VoNJkwRD1BUDCkTRgdEyzHIVqAyUBbkKR5dwAhiHrgFaCzImjUNV6wuXUz\nOjSYKzyRSMtGZYo07tBpw7aRrfhjN+EM383v/Zv3IbKIJGyhohYq6qDSECFtpO1j2TY3v+4gzn13\nYxXKnJ5d5LGvPcsv/8rv0OrBz//4YQqFKr3Y0Gm3AQijGJ3me560JYXAzrmBcZJ3iqUmShXNTsjk\noM9dt08hLMGTRy+ywx/B84somnS6KeOjBcaHq8wur9PuhJQCm0Y7xQiFJQ2drsbbMrjptSO8bs45\ngb5ytl+/2HlcjZAbvr79IXNzWGk27rM2oFHGIEmxjEHKUe5J3kp89RRnrixQ9kFmbfYWBd1mg0Zi\nUQgcTKbo6Zh11aNclAz7BlkqcOX4S6wvZBjXxyzPY3k+jbMvMX9+jkrg4Nsab9DCH4E4lHTW6njV\nFcrVgOzcPGOp4B3BAZ4r1TiTPo1jzDftXf/Uobohdv9QqMlQrsk9i2KF7w0zcOhm7KmDJLWtCLvA\nijOI0C0q1CndZTDtM3RefJbo1FcQfhln6maEHdBaPseuXYcZL1e4+sgniJo9Dh94K36pwsKxT1Pr\nnWP3gbcy3+2AHMaqjCKGp0mNRsV1sB3wawiviOiu5GvJgNs8R+Ox3+Ern/5D7jy8F9lZwvQ6dBaX\nwC+x3E44t+5w181bcdoe4dxCvk/7NlbR/7bz8I/48PR7mTrJPXQwGCGRRuQ+vQZ0H8HRBjKdkiko\nSE3az99ykRiTS9sVAidVzN13F9nkAFsuX8HUaqz1eqS2y/Zaj/HVF9hmTmCLEaS1lt9ktCYTFY4s\nDrN0fYXayDDdboSlYsLU8Eu/+BGqU5PY2CiZYbmCuBfy9JPPcMcDb+KBg7dz5BtfJ6iUca0CUgq0\nFGilicK0f0rf3CLy0lx5pXRKIdPYRhNhKCjQgUfPaGKTYqe5LA9loWSO7LjS0HMU6omnaK2kTD34\nXWS9HlIKisogCwW2/Mj7OP3Jv8AcfZG7H3wr83PL3HawwM57xrGaTdaeT1n1ztNai7h5r4N7e5Xf\n/b1H6RmHju3g3L+d7Cd+lNLwAKYXYYQkkzkR1JAiTZ5UJYVmUCtmTY7guSZ3x3aNpGM0Ba2xrM2z\n5LTOUZw8a4y+bFbcKJIRos/Hyf8oJfBtwe//+R/z73/u3/DkZ/6Q6e1b+dtPfIwjzz7Hr/7WRykN\nVXj6ySN4QYnnvvRZXjl6EhvFvkP38+bvexsTkyMMT1TI4gwlNOKG8eXG6/TXdE7s4VWcKf/OPxz/\n0KdnA2HiNf+lRtywXf+njn/x/vfz2S/8DdOjQ8SJIgp7rK7WGR0bJ8sSlHZ55cQ5SpUqW6YmWV2r\nUy0PUKsNsG26RLvV4djJ41QqFUZHR1FK0el0KBdHqVTKKD3NxcunQCWUSj4Fr8jpU6colgrs3j6M\nbXzSXg8bl26vhzIpVatKohWVoocCrs0v0ut1STNFZtdox+BXfMZHpojamsxYqEiSxAn1+hp333M3\nUzNb6HZWqbeadNMM+c9YN1Qi0qbEdMFyQSbQXYckMphM0G0b4h7oQh7X0FmFyjAMFUF1JcZbYnQG\nRsRVwpPPMGw0TjBOHLtUqhXGuqCFz+pKzMKypsAqFX+YnnL5zUdS/h874Y5Dkh9+OzxwS97ierEL\n
fsmmEXfwyqDCnOc2XAU/AETuoxWG0FzNDQ0zBE4ZxrbA8qymFwrKRYEnIFUGy4FCZXNTk2YGIVPC\nKCZJuszPZdzxulv5+Z/5BRjezer8LHFnjTTqoFUKKs1bE1qDTrEcH8v2+62cXFThWA7ThSI/+a43\n8tMfeC/PHDvPb/3H/8yxI4+z/eZJBgaLCMvC9Rz8okMBjeu69OKENM3IUpOvvzBmpObyzu/aycp6\nm0dfuorWLraUzIoO2weqWKGkMmhx295Rnjl+BUtUcLRNmqW4UhKqXP3k+y6feuQZ3vqvNzc/wt5A\nYflm2t0GgmvxzYcYDNLuG+YagyHDGJEfprMMY2CwdDPu+Yu0eutYImZyeoY33jTOlLrIQyu5J06n\n3SGw29iyTcuGwIHlhiS1UtbmMq6teoy6EaM6giyhvrhIN4TxSq4UdkKNHxhGhzVuySbttekmhl5n\njfKgwx67wWRwgL8LbC5ET5OpLLcv2cTIOinSAmHLnLzd0rmbepgw+qPvzaX3KsZlEXSG3RWsMkVk\nuVTqL1Fc+wxO7WbCCxdJTz+G8MvYo7vwihMUxiZI05jSbW/HG99Oc/kcvYaDt+MBGL2ZQjDEeH0Z\ntbrE8Pg47vQQwmiyNCbuNvBay9iZJlu7jkkBDbFxuN6d5O43v4PJ4grm8lmyzhpjM8NknRX80KC3\nb6ex1qDWWsOzDcb1cgAm+mc6LW80EKS0+7C97MNpur9AcuRGk9uJ60zn5FSRW2wnJrf1TxEYmcdH\ndBdhKugxc2KOxu5hLl05jzu+nSScxW2G1EyI7VTQZhKVPY0lLWQRRFNTrRT5hZ//ae59wxsRls/0\n1BauXj7D/OwlbK+AbZkb4Y5b908zNjzCgG/zhU/+F6zsKsXiGMjcmA+dS8kLBQv0t5Yff6fhG0WS\nKPzMEJncusvNz14QpSg7j56wlMGWEqE1tslbKjWtCS5eor4QMXPfYawwJqgVUFqg1js4jsQp++QZ\nJxaPv3KeXTrhkU8t81PfW+Jyo8N0tUh4+iyu73DiRUitLktrLSp+wHqYUCi8jumxYVpRTCjAw2BS\ng23nn0dmTP65GWj2Q2GFyCNDUr3x+ecmke3Nd7RQOl8t9Hk8eR5M7h66YUtwA1MROUKojUFmEf/+\n93+T/zI8wOUXvky7F1K/9gwffOMbeceHfpKvP/znJN01JqZKRKtdikXFV77wLC989d9RHRrklsPv\n5V/81E8zMDZBpjKM+gcsmxtGi//gE/8WERH/qzlhDhtt8JIMYLTYdEvrngfewCOPfon3vPudGFUi\n04KgNEgUK3bvmSZJFPVGm1YnBiy279pNFMXML63Q7V6hUChw33338vjjj1OtVmg1WxQKBTKVIaTm\nzMnzTE1NkmYRwnMRSKngdQAAIABJREFUysITDo12h+PHr3P7bQdR9hq9qI2KEyxPU29EFIIWg0PD\nuI5HkkC9k7Bn2xaefP44pUoNLygSxSFGJGRhRH21zuzsPN1eC9sxrDfmSbOYdqNFmiYEQWFzEwMs\nzsakkcEuSlqzOWG5NgrFAUG3kfcoaxOGaMXQ6xqcIiQdycx+gy4bGmvw4vNwYM8pnOS32Tv9QXTH\nwx6vYYkAGQQ065pWG4JAIDJY665SK1RphAX+25ESgVdn24xgfFTz9XMCVcyRocC1GRkcZGm+haMF\nqgu91ODYArsKvWYeOSOkQCuFVxKMlgQTe33mlhPKIwbZEzQ7kMaC5urmFk7usJxRbzQZqdX4/F/9\nEXfc+1bCqy+zfv0MRmfkkX4uUkqQNloIjMoQlo3lBji2Q16m5/5gwpJE7RWiXh1r9hiHKwM89Nf/\njadfOsOv/9ovgHKx3Vz84DoWSkOr3SPLNJ0oodWK2b+tyu27xmi0Qz771AUcy0FrB9uCoYEiQcGi\nEJTZMRFg+yGLyz26XUm1ZJNkmmrRoxG3SU2GaztoLXn6yxc2vXYcJ79ElXlNdxte49L+qkjhBkpi\ncuRnQ5gjcRDGYzq4naocoyQG6YkQowS6OMi77zvErbUrRF2Pdr1DImx8K8pzJGWRRpbRrWu2DWes\nhbkasx3btJZibop6uCLi+oUrTI2UaKTguQo3VswuCjo9xZYZD0vGtNe7tLsxlYGISiGjWs54j7yH\nR/xBznb/FqE2h57qOEXZ3HCijsMYKRTO4E50888guDsnzKku6AauKDLV/DMIFXrNJlvOCM/+PXpo\nJ8y/QjayC1kZxysOUAsc3MEhyjM7kSO7kNImy1KUX8GpDFDsXmXLgd0IuRXjuWCPo+OIjjWFbc0R\nlCbR4x/FPf6nJE/9AcYGk6UU0xVKJUly/gW4chwRjGB0D6s6ggMMTO4imDtBsLCM5VXRlo1KCujk\n28eSfGfSshCARhiJkBmg+nEAffWW0RiVgTZkQpLEWX7R+7rvZJkzyyUiD027nlD90z/h7LHTiJvv\nYam1SMEusNxqoSKPc8kwQ/E0tdYsQ8WnMUGASRJUF1KtObTT4RsXG3zxb79AmkUsr6yAgqHBGtvG\nHVwElhAUA8nEQJeks8Arp5cwwmLL6Hh+YxJgdB/a7HvevDaT6Z86MqVxMoPKFJZloQU42hALg6sV\ntjJExhA7NkKluNIiFhoPwZotEMcuEtx1C36U4Rc8mt0eKjNkWUIhk6RRwp4P/RClwKf72FeIZpcJ\ne12+cXmBqe17eaqTMnzrfVy8cBm1fJmCL5FJQiwkeA6rn/oK1kuXsd/2FsKxSYpK0bPyqIhQ6xtB\nrsoYEgSZMGRoXG36ztl5Blpk8vjQzQ6t8skWhtwyXcAGFGj6Vslig3Dcd+oWxuQ+T0nIz/67/5tf\n+7kWpeWzZNhMjCU8/tDHGSpJ/JKPJwqMDVqMTli8boeg4xRpxxmt7vP89kef48CB+7n7Le9iYsdB\nHM/Ks1voFzH5y95gH4uNtX7jKNhf/2ygQK+9Mb36TPc3zs2q9rMk45ZDt/LEE0/zkZ/9GU6fPE7R\ndShXKszONykGNtu2b6PRajM6Oo5j27SSNq7jMjo6yurqKp/4xCd55zvfyaWLFzl0yy19T5U1Go0m\n+/bvo15fpFwpkXUalAtDqJJN0G0T9gzPPn+Bg3t34tkB651L1EoFRoeHCdwhOi2NMS08z2Vxfpa1\nlSVavQzXC3BtiygCzw1orLdoNlq4jodTqxInHeIMXMfBcV1sx/kWJo//+OisCnptgbQhDA06gfq8\nQFoGxxOUa+B2IHMFB+4aZ60dMr/QwPLgzQdvZdJvInDBk7hDNb701MOceK7BvrsnmBwPWLwaEiX5\ntT+9TZI1LRYXMyLRZGx8hp5b4i/PV/n+6DIN2+F/nE45sGMcOy5RERDrFraTobIerhBkXUNmeUSJ\nxvZTLAdUokGAp2B1yVAKNHYm6K2AE4AbwHDJoNPNzU+jlbC4tMyPfeBD/Ntf+ijJ+hrrJx5Da4W0\n7L6wJMtVgypXkTlOhuVoRLxKoTBD3DiNEQFaZ0i7gjEth
F1BygIaQa+9Rnz0S9y1ZTd/9VeP8tM/\n85O0WsuowKcb5n5YaZqSZBqtFd//pm1cma/zwqU6aWZwLIcs1Xi+Talg5cWO67DcaWBjs3e0zNHz\nq/iujSUkcZZgeQ6tKMRzHHzL4eVjs1zcvMAPzzH9+JwN1Cb/ujGv8uxumL+zUQDljxSmz/lLsPQo\nIQuUdRWtV4n8rTRq8JbtAzjNc/gjhri1QsXrMbeksXxBQdio0iAF2WFs1CVNG1RLNm4BGh3N6Q5I\nNyDNNK4LgZtSbxl2TtuEPU2SZVxbLSKCGiOlGRJfYHiGxfklpm4+hHAshlWbN3gHWNUniZKzm5ob\nk6YYLVEahMlwxncR7Hs92aWvIA98Hk0PGV3FNE/C+iyMv4/MWWT9xS/iXj9K9/hRro48yLi1SjFs\noZbOk219HTg2vRe+SHjxBcy2w9j3fwhnYBSkRdzrUGuepLjNQieLYFqYuIBwJunEMyRekUL9G+jk\nq9BZQKRTCHcMEy/hCdjzpu+maWzkvgdhfB+6tUK0uoxxB2DHOOO77iQoZljN55C1SYzl4/iC4Pbv\n+bbz8J0LHmMwIgOdx68LQGjQUuOXBpC2lds+FwqUysOs1xd564+XuHz6NNfPXMBzQFgWWhhMK8H9\nyM+xf3yMRqOIIGW85RAWfewsIUsbeF5KgyorYoZEncUjw0gDKUjL5oGbI5675JBKG+3aVIoDZDoF\n47FtOMNVGUoYjMlor10FaaG0wfbykDEpZE6iJnd6DqOE8QGLMJGEyeYq5ixRyEwjpUDofpQBOR8m\ncyxQCtuAZRRS57dJS1gkZBSSjDDLKDdbqNEBVqIQlRmMylGgOMtIJZh2mzOf+xxycZmG5+EJyVfP\nrDGqr1MlI5m9ws6b9/Lpl1t0Fq/TyTKqWIRxSiYtrpw4h33kJCO/8n/RGxrGpIpMbtzCNYnJ84CM\nzvO/BNAzCmNstNH4ZoOevnktkk5zObq28jgIITdg4w3bcXK0pN8qvcGrMQKlBVkv5Sc++h/49L/9\nMTwr5eJywq23Hebi2WPoTLGy3qZQ9Lh8vs2JbsrerSnDAylpy6dUtIlWniBefJEXL+9iZOaD7Dpw\nAKVTbkRFvLZFhcgX9gaCeeNv0a9/XhuTkf9eygAm9+/Qm9SlW5ZDqjRTM1v5k7/4FP/z03/Jb33s\nYwhhY6SgGFRYWFgkyTJmZxeYnJ6hVCrTarfoLvbYuXMH3/M9D3Lt6nXuuPP1NJtNLOkwNDSEZUsW\nF+fxA49aYSTP47HLONUJbpk+zOj4JF999CGunHqUHVNDOH4FIQVFx2Xfnr2UKxUWFhbIsgTHdUlx\niLMUO3DodNps27qdRqvN6VPnuHz5GsVimTvu3kuSdrAdH2kFFAo23W63H2y5udHrAZmhF0KvK1CJ\noTIEcUuQNgwiBa0gszVLSy1WOiE9BWOF2/nQD3+C9NIjJPWLICyuX5tj3wCs7TAEwsZxPMIIpAth\nU4BSpEg8T7LW1ExNeNRGtlMb9HlB7GNsyzbCr/wBWTugHcWsNyNW1tvU10NsoagU8mTrdickTsHx\noejljrNK5TfQoS2QxhmTg9DpgePmr48Um+bGnbywzH//T7/CBz74QbrzV0iTFCwHow1SCiwrw/fK\n6OYCOouxHJdw/iJZ6xxptMb1eQvLalC86UdQV57A8osUqhmy4COkhV3ZifRLGG+I3spVvF6LP/nj\nT/JjP/bDNNsdkixBG0GYJuyaKrNlyOMzX7/KgO8hpEZluWrRcWw8z8L1JFHYIwptfOmwf6bAi2fn\niRObMOzijldIlaEXxQQFhyxRZKnhsUfPMr3Jdh/QD6Kmz8nZINrRL4Jy2kKeGvDNhxgL+ganuZDF\nkR08tZtF2qBn6NR73DY1wsVLl7il0qG1pvA8i/nVNVpxkfEtk5hSmZHpCfbsGWE8OkfcgG59mV5B\noAdTzl+BzPFYWl6lZRy62jA6WCBOeyy1BRM7xpjcNoQztJ/y5G6GHJvLJ8/Raq5gUo3utWg0Esq+\n4qbyAY6bzRU8QoCWLsJkWLbP4Ht+CREkXDv1PNlvvpmB3btYfPk47pZpossdvC1dkh3j1FcdVjtb\naOz7MGNyleraVVKnDO0lVK+FKJdJpY/nDiDmzqG7bUxtGAs7b58qjWm/AIP7MMEbQJYxS5+kHDSo\nLH8aetcguIf0YpfemWMksU2t0C9Qr75EeYeDs/Yy4vJjqI6FSC0yAjSSNG2jrp/AbzUg6iCsDPEv\n/wNXJz7E9m8zD/8IadnCcjwsv4Bl20jHRdoBQuYOo0oZLNtgjI3l2BgrYPuBW9m+/yDd+jyPP/S3\n9NbrWI6kU4fvfeANLI/4TB8+yKlLzzHij7FmegSRwTQ0LR1ytaE5UFyiEdyMbh/FCyxcC3ptRRq7\n2E4Byyhild9MszQj8C2yLMv9BSBf6NJCaY3KNEVPoXUMlp8ncQsNRiBtm2uLMYWCYXJ4c/C7m8Yo\nI0kUOHa/1BECtEJbAqToE7mhb7mIURrHZKRCEhy6mfWXj7L4+AqOJVG+Q+C4pFmGrTWJ0YSz8xSi\nkFBAN4xYz1Le9LpJHn34MSaqRSYmJri+uEwS9QjbhmLRIjIGmaUIyyXKMizLRx1/BfmmBzAiT8i2\n+rEWcX9T8IVFYhSpyYP0BBotDD0DDiZXL2xyZFnO2REyRwQ3bHAQ9FugG06lpu/WudHeygvTVBnC\nRPLge99P68ojeMevc3G5iR8Uub7eZMv4KLbtstBcZX5hlfl6xMjwIAe3ZsyM23QGNX//yDIFZ457\nSzVsdqM2+Dv0i5h+i+qbacyvjcl9zaXQf/u5mRoYI1Gmb/u+Sf6X6L9mqlOCUpn3/8RPMnPodn7+\nwx9msuYSxxHVWhUSSbmco2tKKTzPY3FxkcWlRbZt3cnevfuJo5TZa3O0Oy3uvOsw7U6T5eVF9uw9\nkKuH0gwZjLLrwAPcd8/d1EbG6CRdrp39OstLTYwlKRZdjNbML15kb20fh28/yJHnX2bP3gM8f/Q4\nStiQZfi+R7FYZL3Z4cyZs3nUgIyAjCTtYrseWhmM0Hiet+nIDcg5O61WTiwWicEXEC4ZhDQoJVmd\nM9hFw46REodqw5yNZ+kIzYNv/F7W1pc5fXWZ1bPn6HRSRjzDwNgAu70hCpUhWj1JsQrdDpTKmqFS\nlZW0R1zMSJsgbIvKwADj4zV2bp0GLSl4cPTCedprCVkMrg0lActdgWNpCoFF4OREf50ZjJNv1sWC\nQCY5iFgI8rRxT8BA2abTzWi3wAo2Nze/9os/wwc+9KP05q5iTJ7ZpEzennIs6Jx8iujq4+jqvTBy\nE5W9h2icPMmVI+fIsNn97l+mc/U4dTFO1a1RX1tmfTHE0isE1SLSPUZxfA/FqduwStNkUQv7+kv8\n5m9/nMN33M2u3ZNMSof7RrcxIQpUWhX27xrjCwtnud41eDZ4jkPgS1xH0g0THCFx0HRFxHqniU4E\nb78FVr0DfP35
ARE23pxbkddhT+/EySCJIMZGyD84lgsYau\n/B5J3gcLpqr9ZTYaYS1CCqQ1JB+eg0wQ1uK2a08ZSDLQCr26Q0RBoWYUUqREWZdOnnJ4MOL6zWuS\nJKJcf337hK+97i3Q4x4vg89gGdsCkhZ0hu+1pSXjUKYdZdDKeIsH7Xk+VrXKLj/Wxhcu/N4a7BJw\n/uzsCcneUf0LJHvnR9zY9nrxVtYHwsaOwZEEwY1SdrdeLG/XQRDw9XfDNwKe1RIWjcQmOXnqS0Wp\nqtDLVzw6PWQkDUmkqTeazc0bRLNk0yy5MiNGh8fof/lviPSUyCia0WMSXZFlxyT1DKUqiuIRlXA4\nl2CkYnj998yWKxoZ4fSWjethXELWbJBJjhIZ5uYFVa3Ixu8y6Q9JygV6+IS+jFlNb99SF0ThJnC7\nBqFGeE8OK+7vJmcFRkZ8xYHym+dPWHVUMCywQGI0whrqxiuilLGkIsJog5QWaUEojdSKRniQ4WgJ\nzSGrEJj8yvlMjxTWd1EPpC/rDKlzaGv9Zhdke5F1vg9Xe/c65yXx+NMLgHCKKvB9otCqonHO26Fj\nqa3ykv2mQTY+05NrS/MtMjzvnk5ItUM0Daqy3M42fPZqzrvnR9Sx5IffecQgi6m1Y6ktJ5MBnTwj\nEymzsuKj1zO09b1nImno5BlRJJluKw6OOgyGfZxTGKtJYp8BWq9KnHb0DmIm4x5X12vyIsMoRx4Z\n+r0EJy3aWOJIsakabASdSJDEGSJOiGLB2WHPuyuXinWl0ELx6//7nKZyjHodPn55BU4z7MR0M8ns\nbvmgsVlvt9R1jTGG8XhM0zS7r+uqYr1ec3FxQZokRHHMycmJr1mnKVVV8tFHf4vWGikldV0znU6J\noojpdMrBwcEOBBVFQZIkDIdDjo+PyPOcu7tbVqsVFy8uuLq6Js9zoiji+nqKlJL3338f5xxVVfLr\nX/2K7WYLQjC9nrLZbDk7O/eLy3KFtZY4TjDKkHd6NI1jW1peXlwzvZnT7XQfPG+Q/kQZxW/Tgn20\nJzdn7W6zB8IiabHWsNpUaGPv53zsZdZ5EnkehvJMFhsWVaNhcjTh8Mj3Q5scjXnyzhknp0c8fvKU\nD374vt+8jJ837ktLpiBkZoMlP8KXPkLroZ0x4L00NwAN+a0SPLu/+TJhud00dpkN+NJm8tYJ/63n\nA/+ajILLxvGyEeE9+B/FAg5z+Oy2RNmEWuvdTjRbLjka5DyWkHUeM51/jjEbNs0CZQQyGpDFHaRI\nKasVCOj1jsmSgk42CYZykkgmkJzwn/69N4yzPJzQfT8mLpDFv1i627kqh0y6MoSO9T6D0IJIyT3Q\naeX7u7EMH6X0/d6SxJGkgih2iNjPp4vF0vd/qhqk0gitiKoGaQIh2jryJ09wxvoniYOjtG5w1hIn\nHUTWwd8IEUmnQyQjlssVTW0ZDCdEaUqneNi9FTpT7TKBLryZnbLQBvO/kGmxxnkwEzI11nkzXS/X\nN9igGHY63EtGeCsA7rN0xvqeZl6O7ueYNQH0BPDjAvihLYNZ349rB7xMIDN7/0YE0E9ilputd3Z2\nnjTdaINWGt0odPP1rSW+EfA0zS3CKJKsC8kBiTW4NCcanVDeTKniPtYakuiWYjgifvpDCjbUs2uW\nsxny5hVJ2qc3OmWSGIRdk2YSmR6gSkWdFLj+MbIjiZIRUZKQLF9TV5o47xCbK4RaIodHzNcNkboj\n7Z0gTcn8bo0tJtSyIJ89Z11XYHypwi88Xpm0S2H64lbw5PGZHc9xEcSAeyDfQFmv1LLWoowhspba\nWZTxZoZS+ayK1ooURyMsifQy8SKWJJGnSStjvaIiSM1jAXkgmybCO11WeL6REr7JqBHQWiXW1uCc\nb2TqcGT4dhEa3xW+dgbhvGFYYwXC+lKKsobGF713J5gY3zZAa43SCls1GKsfLLsG+PziJYk1vLpa\n8OZ2QSJinp1MaLRj1EmoGsXFdMP/+vSSVMOPz444HqScDAS9JOX9R2PORx22wvGLyzucgNW64cPH\nE+KOY1NWVNstVWm4XWi2m9KT7IxjkKYYseVwnLNeaISV9AYd0lwSi5giM0gjiE1Bss2JhCCLYdzP\nKboxK6PppAmDIDk3lSaRMbNVRdFJ6cYZlRNUQSGwTh7G4RmOx6zXG2Z3d7x+9Yp/+k9+ugM819fX\nvP/++zx79ozVek2vKGiaBqUUJ6en1HXN0fExBwcHbLdblssl5+fnrFYrPvzwA+I4Jk1Tqqri9vaW\n29tbfvPrX7NarViv1x6gGMPh4SHWGvI8Z7lc0O/3GQ1HzOeLXWZoNBpzfHzEoN9HKUVR9FgtV6Fn\nV0Wv12W59D21yq0iirt0izHPnn2ItRGz+cOAIICrfU8f1XxFCiR8aYz58hEV8KWwo0nBwTglToPB\nnPOeKVkWE0cRnTzyfDUBSQTdXCIjQa9XhOfWaOXd040xqEZhnMA0sF5sPQnz7dcLO7U4+PvZ7jYV\nn2EQAohARMLf4C235qFjw1t/Ez75AjBoQdA/UBH6quyJLyvAVDnetJyeAHzyGF7eLJiuK3CSpvYu\nhcmgA2/uOO2u6Kopadqj2x0wHr6DlBHXN39LLz+kNzymrNcIp7i5fYWMM+arK6SzKF15MKkUf/an\nP969j2+T4XHtxQjmkjZksUHsJPvGBGJyKHHtMhvsLkuoB/jsoJB+094BBXyGR0YQhyxPnAgiKYhT\neGM+pa42sK08Z7NRQbjvM/ByXCC7XZDSS9IjCdbipDcGdUm8487hIMoKjG7Iuz2iLGO1WVFuaj79\n3eXDxgZxDwBFAICBP4PzdIdG+ayXCzJ+04JB5w1ZdbA4MMaGA0AgF2vhjQu19X8fgIoJjQ7alhW7\ncpkBAqD5ovGgB1u7Mpkm/L8gk9f+Is8rv0clQBkMM8vG4SJJnKUk0dfv5d+4y0dZxmh0SCKEV+yM\n+6zn10QuJnYaaSpUEzEYnYLsQLWiO3mPfi/FpRlaCAZmzcJY4sE7GDliLcdcz+d0jh6zXjcIo6hs\nQiRqmkYS9fs4GZPrDbdVgs4GXN3MGGQlpphgo4TkYELXXnOxMNhOF90dI1/+jjgsRg6QtlVggRCS\n1ne57UnlrPQ4PpS0nDAPmkAYiK1BWE0iHFo7YusglIYaIYixAWgYhNY4bb3DsrE7XwYRiXtvBus3\n7No6UoTn/+C8lM55eXoKKGvQRmNd+zMbpPjsnJcj60nJrYW5ILS+wEvjYyfIjcNa5XskhcyQMwqr\nFaLR0BiassJ9A2L+unjvfMLRZMBqveV2eocoS5R1bJRisTXMS4VII945HZLkMbPtAuck863i7LjL\n9WLFstZI3bBdKKbrkkeTgjhzZIXjcOwQkWa7LanKJWnHkaZwkEleT5f0OwdImVHWDVvVeCdZ5VAN\nvL4pmZwMGQ1jor5j0IuptGW1KdEOVsuKvJexKRs6RU4nSSi6KdeLLd0IJsMueRxzPTMcTCaMew+T\npc9Wa7S1H
B4eUhQFv/zl/yFNUppGcXJyws31NZv1mneePuXi4oLLy0vOzh4xn888R0JrtDZIKdlu\nt0ynUx49esTPf/4LVqsVh4eHO6ffXq9Hnud+gUgS8jxnuy2RUUSv12c+n3NwcIiUEcvVim63w2x2\nh7WWqioZDPpstxuKfh+tDdc3U6I4ptPp0ulkSCkoej2KwYAoTZkcPeH15S1P3/0u50+ePXjeiMB1\nEV9RYr7HN/f8Cr8JCUScsFkr3+zV+pKw38T8aTGWvkWLDeAeC7GUlI1lNp1x+eoNAC9evObN5Q2f\nv7jg1cXv+M3f/T0m3Dte7fn7S6bAnzIRnswrQypGCHGvHrA+o+CM33FabseDxqb92IIl98XHrhz1\n5aFzX/zY8ja+HC5sYK9rx1z7TV/i/cUmPfh8DsY2ocSgqRX88i9ew4sLzqMKozdMb35DFhcIGXF4\n8C5J2mez1URSMFtNEZGiWl8QxwJjNTLqYlxJYyN+9KPvMxhnD0eCIezbm6e/HIGkfV9i3IEh6/ld\nHgC9Vepqf4+Wey7CtQzNMm1b2hLIyDdhTjKIEkecQBnB6+UtVmm/juqwtju/7qbvHsGgh6vK+wsh\nvB2JF9i0gNhAlCBw9CaPwJUUnQxTb4kjR6/7sP59/v24wHFqvXJCuakt6YXx0PaePKyNTxUY63Zm\njs4alPEVDBVAkgnZIGscRkuUEhgtvbeQdiGb016fwJfSYtewtR139xZY9aXHAJR2JbLQkqzR3h5E\n+/+5bQxREqMbS6//9dmvbwY81rJWMTLNiDZz6ukGVxyynd3RDA+pkx5qNaXjHOu6pApMb1Wc4IjZ\nxgVbCV29ZVvesXVrcnXLeHuJNTWdXoKsVhwOBySDU7qjPp0kZ1BYytG7DNUCIQyFW1H3zkE0ZKZB\n64x4/A6PxBs6AUVWzy/9JBTsTL4S6031HH4hEjtPChCRVyjFO07Pwzo7C+lRrtK+tBQbQ2O8EWFl\nNFZXviYbOucK4yitxhlLGse08EpoSxQmX2V9LyJrfXYojmWYYNabVjlLbY3PSOGzRVgPipKdosCv\nfs5ZZCBrW9f+jj/x+L/3XCGcv87CGu9Cay2xVVjrm5uiDOYr+BT/UDTzFetFSRFBZSwXt0tm8xl9\nIqKo4fnNmnEW87PzA1ZbRa0ch/0M4yRv7jak3YxBLklkxFEv5zsHBZNezitniANXKUpyRGIYHw7Q\nJseQMVeSUS9lvVnTNCV5LjBKoGtYrx1JLJiMOpimQokYjKQ/Lhh1MuJOTp5JlBXoqqEjJFWjUdaR\nRRnDbsyiNEgRkaF5epxwt1jC9uZBYxP3JmyVYLlc0DQ1h4eHGGMYDgdcXV35LuZxhNaaoig4ODjY\nEZLTNCVNUwaDAYvFgqOjI77zne/w/PlzHj9+zOnpaShJVQghMMbQ7/dQSrHZbCnLiuFwSNNU3N3d\nUZYVt7d3SCmw1nI3m3F0dIwxhtFoRL/Xo9yWFN2O77EVJzw+O6OT5yyXS/r9Hlknw2IZDEb85V/+\nb4piwCef/I5fffTrB8+bHZHyK6Zcu49bbfkCB8aBUZpuHrPZKO5m1b1LrvUGcxp/mFDaZ37bClMs\nYXI84d33ngLw7rNzzs7PePrsKe9857v80R//ePe/hHAkyVeQqYMCxquiwucmZE28B6r/PSk8kGuR\nxLcZmxa0fAnsfCH9Y7/ie299fBvsfLnM1Zq8vaoEdXjdsYR+Ap/cQa0tRjV+k1OKWvb47Scr7MUL\nRrdThoN32FRXWKsot3NeTT/F2ZJar+nmPdL8kP7wu6RphnE5cVwQRyO0rnAy4T/+u38RxvRbjI+n\n+WJCKcrulD6tIev9xqoDYvSbvQiKq/tBazuBO0DKIKR2bfnGP4+MBEniMztJ4itUaQJ/c/UpThlc\no3eZGhlFSCnJ33vsmzerCmc0tmlwUeKVsKqBOMNZDQjfMb2pUdWW4cExMYKT08ckWY/j479+6OC8\nzXN/q+Tpbzgb3JJNICybFmTssl9eSGOd75autb+XtAarJVqDtpbGuMCNEujA3m9LWTaUEN92ct5l\nctqsTgu+9H17C9dyekKhIZUSrRok0Gif4VHGUVcKrTTw4muHQbivgvr72Mc+9rGPfexjH39A8S37\n9u5jH/vYxz72sY99/P8Te8Czj33sYx/72Mc+/uBjD3j2sY997GMf+9jHH3zsAc8+9rGPfexjH/v4\ng4894NnHPvaxj33sYx9/8LEHPPvYxz72sY997OMPPv4fdIUDdUTUECYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "olF4PpORpCTK", + "colab_type": "code", + "colab": {} + }, + "source": [], + "execution_count": 0, + "outputs": [] + } + ] +} diff --git a/resources/examples/ipynb/models/reformer/machine_translation.ipynb b/resources/examples/ipynb/models/reformer/machine_translation.ipynb new file mode 100644 index 000000000..8e4745f5a --- /dev/null +++ b/resources/examples/ipynb/models/reformer/machine_translation.ipynb @@ -0,0 +1,380 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Reformer: Machine Translation", + "provenance": [], + "collapsed_sections": [ + "udDs_biH0n5U" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "udDs_biH0n5U", + "colab_type": "text" + }, + "source": [ + "#### Copyright 2020 Google LLC." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WPY-OyyM0pSs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "psnUF-8c02o_", + "colab_type": "text" + }, + "source": [ + "# Reformer: Machine Translation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1lnRd_IoERdk", + "colab_type": "text" + }, + "source": [ + "This notebook was designed to run on TPU.\n", + "\n", + "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8PluCmWbZIpJ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install JAX.\n", + "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", + "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", + "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", + "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", + "\n", + "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", + "import requests\n", + "import os\n", + "if 'TPU_DRIVER_MODE' not in globals():\n", + " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", + " resp = requests.post(url)\n", + " TPU_DRIVER_MODE = 1\n", + "\n", + "# The following is required to use TPU Driver as JAX's backend.\n", + "from jax.config import config\n", + "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", + "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", + "print(config.FLAGS.jax_backend_target)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yiPdBenoZwH6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", + "\n", + "from tensorflow.compat.v1.io.gfile import GFile\n", + "import gin\n", + "import os\n", + "import pickle\n", + "import jax\n", + "import trax\n", + "from trax.models.beam_search import Search\n", + "from trax.supervised import inputs\n", + "\n", + "from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder\n", + "\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "\n", + "from scipy.special import softmax" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "uCX88z9iXB7s", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Install sacreBLEU\n", + "!pip install sacrebleu\n", + "import sacrebleu" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FQ89jHCYfhpg" + }, + "source": [ + "## Load WMT14 data" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8S3h28Q9b_9B", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Download the newstest2014 English-to-German translation pairs\n", + "!sacrebleu -t wmt14/full -l en-de --echo src > wmt14-en-de.src\n", + "!sacrebleu -t wmt14/full -l en-de --echo ref > wmt14-en-de.ref" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "CBv2SDnWZEI7", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Load the source text and reference translations into Python\n", + "refs = []\n", + "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):\n", + " if line.endswith('\\n'):\n", + " line = line[:-1]\n", + " refs.append(line)\n", + "srcs = []\n", + "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):\n", + " if line.endswith('\\n'):\n", + " line = line[:-1]\n", + " srcs.append(line)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "CbYw4eMXZGKa", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Set up our sub-word tokenizer\n", + "tokenizer = SubwordTextEncoder(\n", + " 'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')" 
+ ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2NbOslppZGZ0", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Encode source sentences using the tokenizer\n", + "input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)\n", + "for i, x in enumerate(srcs):\n", + " x = tokenizer.encode(x)\n", + " assert len(x) <= 127\n", + " input_ids[i, :len(x)] = x\n", + " input_ids[i, len(x)] = 1" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YwzU64GmZTb2", + "colab_type": "text" + }, + "source": [ + "## Load the pre-trained model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VXjtCPxl3I82", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# We'll be using a pre-trained reversible transformer-base model.\n", + "# First, load the config (which sets all needed hyperparameters).\n", + "!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin\n", + "gin.parse_config_file('./config.gin')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "IediBe8MXyLf", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Now we load the pre-trained model weights.\n", + "with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:\n", + " model_weights = pickle.load(f)['weights']" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zY3hpgnI5Rgn", + "colab_type": "text" + }, + "source": [ + "## Beam search decoding" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fc_VlhrBYW0u", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Set up beam search.\n", + "beam_decoder = Search(\n", + " trax.models.Reformer, model_weights,\n", + " beam_size=4,\n", + " alpha=0.6, # For length normalization, set to 0.6 following Vaswani et al.\n", + " eos_id=1, # The stop token has id 1 in the vocabulary we use.\n", + " max_decode_len=146,\n", + " )" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "bynTpreMYXPs", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 71 + }, + "outputId": "cfd24e01-617b-4beb-a5f2-98a7ce2e1449" + }, + "source": [ + "pred_ids = []\n", + "preds = []\n", + "BATCH_SIZE = 1024\n", + "for start in range(0, input_ids.shape[0], BATCH_SIZE):\n", + " print(start, '/', input_ids.shape[0], flush=True)\n", + " batch = input_ids[start:start+BATCH_SIZE]\n", + " seqs, scores = beam_decoder.decode(batch, batch_size=BATCH_SIZE)\n", + " # Select highest scoring output.\n", + " batch_pred_ids = seqs[:, -1]\n", + " pred_ids.append(batch_pred_ids)\n", + " preds.extend([\n", + " tokenizer.decode(pred.tolist(), strip_extraneous=True)\n", + " for pred in batch_pred_ids\n", + " ])" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0 / 3003\n", + "1024 / 3003\n", + "2048 / 3003\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "c5Gq4qF_YY2i", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "37a5e24f-9264-4d7a-dd74-065758c9a7e4" + }, + "source": [ + "bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True, tokenize='intl')\n", + "print(bleu)" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "text": [ + "BLEU = 27.86 59.5/33.5/21.3/14.2 (BP = 1.000 ratio = 1.020 hyp_len = 65943 ref_len = 
64676)\n"
+            ],
+            "name": "stdout"
+          }
+        ]
+      },
+      {
+        "cell_type": "code",
+        "metadata": {
+          "id": "olF4PpORpCTK",
+          "colab_type": "code",
+          "colab": {}
+        },
+        "source": [],
+        "execution_count": 0,
+        "outputs": []
+      }
+    ]
+}
diff --git a/resources/examples/ipynb/models/reformer/text_generation.ipynb b/resources/examples/ipynb/models/reformer/text_generation.ipynb
new file mode 100644
index 000000000..7465b0023
--- /dev/null
+++ b/resources/examples/ipynb/models/reformer/text_generation.ipynb
@@ -0,0 +1,544 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "Reformer: Text Generation",
+      "provenance": [],
+      "collapsed_sections": [
+        "udDs_biH0n5U"
+      ]
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "accelerator": "TPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "udDs_biH0n5U",
+        "colab_type": "text"
+      },
+      "source": [
+        "#### Copyright 2020 Google LLC."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "WPY-OyyM0pSs",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Licensed under the Apache License, Version 2.0 (the \"License\")\n",
+        "# you may not use this file except in compliance with the License.\n",
+        "# You may obtain a copy of the License at\n",
+        "\n",
+        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
+        "\n",
+        "# Unless required by applicable law or agreed to in writing, software\n",
+        "# distributed under the License is distributed on an \"AS IS\" BASIS\n",
+        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+        "# See the License for the specific language governing permissions and\n",
+        "# limitations under the License."
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "psnUF-8c02o_",
+        "colab_type": "text"
+      },
+      "source": [
+        "# Reformer: Text Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "1lnRd_IoERdk",
+        "colab_type": "text"
+      },
+      "source": [
+        "This notebook was designed to run on TPU.\n",
+        "\n",
+        "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "8PluCmWbZIpJ",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Install JAX.\n",
+        "!pip install --upgrade jax\n",
+        "!pip install --upgrade jaxlib\n",
+        "!pip install --upgrade trax\n",
+        "\n",
+        "# Make sure the Colab Runtime is set to Accelerator: TPU.\n",
+        "import requests\n",
+        "import os\n",
+        "\n",
+        "if 'TPU_DRIVER_MODE' not in globals():\n",
+        "  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n",
+        "  resp = requests.post(url)\n",
+        "  TPU_DRIVER_MODE = 1\n",
+        "\n",
+        "# The following is required to use TPU Driver as JAX's backend.\n",
+        "from jax.config import config\n",
+        "\n",
+        "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
+        "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
+        "print(config.FLAGS.jax_backend_target)"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "yiPdBenoZwH6",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "!pip install --upgrade -q sentencepiece\n",
+        "!pip install --upgrade -q gin\n",
+        "\n",
+        "from tensorflow.compat.v1.io.gfile import GFile\n",
+        "import gin\n",
+        "import os\n",
+        "import trax\n",
+        "\n",
+        "import numpy as np\n",
+        "\n",
+        "from sentencepiece import SentencePieceProcessor"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "FQ89jHCYfhpg"
+      },
+      "source": [
+        "## Setting up data and model"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "9_OCIqghSyfs",
+        "colab_type": "text"
+      },
+      "source": [
+        "In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of \"Crime and Punishment\" on *each* of the 8 TPU cores (over 500,000 tokens per 8GB of memory)."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "tYSOVGR47LVL",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Import a copy of \"Crime and Punishment\", by Fyodor Dostoevsky\n",
+        "with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:\n",
+        "  text = f.read()\n",
+        "\n",
+        "# The file read above includes metadata and licensing information.\n",
+        "# For training our language model, we will only use the actual novel text.\n",
+        "start = text.find('CRIME AND PUNISHMENT')  # skip header\n",
+        "start = text.find('CRIME AND PUNISHMENT', start + 1)  # skip header\n",
+        "start = text.find('CRIME AND PUNISHMENT', start + 1)  # skip translator preface\n",
+        "end = text.rfind('End of Project')  # skip extra text at the end\n",
+        "text = text[start:end].strip()"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "mMntV3H-6OR0",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 102
+        },
+        "outputId": "c8d4386c-cf5d-4dc4-92d9-24391fa2f30e"
+      },
+      "source": [
+        "# Load a BPE vocabulary with 320 types.
This mostly consists of single letters\n", + "# and pairs of letters, but it has some common words and word pieces, too.\n", + "!gsutil cp gs://trax-ml/reformer/cp.320.* .\n", + "\n", + "TOKENIZER = SentencePieceProcessor()\n", + "TOKENIZER.load('cp.320.model')" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Copying gs://trax-ml/reformer/cp.320.model...\n", + "Copying gs://trax-ml/reformer/cp.320.vocab...\n", + "/ [2 files][239.0 KiB/239.0 KiB] \n", + "Operation completed over 2 objects/239.0 KiB. \n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 4 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "HnJzxSi_77zP", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "f8b2050b-0233-40e4-88f1-e546a1541b31" + }, + "source": [ + "# Tokenize\n", + "IDS = TOKENIZER.EncodeAsIds(text)\n", + "IDS = np.asarray(IDS, dtype=np.int32)\n", + "PAD_AMOUNT = 512 * 1024 - len(IDS)\n", + "print(\"Number of tokens:\", IDS.shape[0])" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Number of tokens: 513812\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bzQ7G9uGSga5", + "colab_type": "text" + }, + "source": [ + "As we see above, \"Crime and Punishment\" has just over half a million tokens with the BPE vocabulary we have selected.\n", + "\n", + "Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.\n", + "\n", + "We have 8 TPU cores, so we will separately randomize the amount of padding for each core." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PdAwmpS220ub", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "c0919b3d-4c63-4d2f-db44-3aeccaf4d966" + }, + "source": [ + "# Set up the data pipeline.\n", + "def my_inputs(n_devices):\n", + " while True:\n", + " inputs = []\n", + " mask = []\n", + " pad_amounts = np.random.choice(PAD_AMOUNT, n_devices)\n", + " for i in range(n_devices):\n", + " inputs.append(np.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", + " mode='constant'))\n", + " mask.append(np.pad(np.ones_like(IDS, dtype=np.float32),\n", + " (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", + " mode='constant'))\n", + " inputs = np.stack(inputs)\n", + " mask = np.stack(mask)\n", + " yield (inputs, inputs, mask)\n", + "\n", + "\n", + "print(\"(device count, tokens per device) = \",\n", + " next(my_inputs(trax.fastmath.device_count()))[0].shape)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "(device count, tokens per device) = (8, 524288)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Ei90LdK024r_", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Configure hyperparameters.\n", + "gin.parse_config(\"\"\"\n", + "import trax.layers\n", + "import trax.models\n", + "import trax.optimizers\n", + "import trax.data.inputs\n", + "import trax.supervised.trainer_lib\n", + "\n", + "# Parameters that will vary between experiments:\n", + "# ==============================================================================\n", + "train.model = @trax.models.ReformerLM\n", + "# Our model will have 6 layers, alternating between the LSH attention proposed\n", + "# in the Reformer paper and local attention within a certain context window.\n", + "n_layers = 6\n", + "attn_type = [\n", + " @trax.layers.SelfAttention,\n", + " @LSHSelfAttention,\n", + " @trax.layers.SelfAttention,\n", + " @LSHSelfAttention,\n", + " @trax.layers.SelfAttention,\n", + " @LSHSelfAttention,\n", + " ]\n", + "share_qk = False # LSH attention ignores this flag and always shares q & k\n", + "n_heads = 2\n", + "attn_kv = 64\n", + "dropout = 0.05\n", + "n_tokens = 524288\n", + "\n", + "# Parameters for multifactor:\n", + "# ==============================================================================\n", + "multifactor.constant = 0.01\n", + "multifactor.factors = 'constant * linear_warmup * cosine_decay'\n", + "multifactor.warmup_steps = 100\n", + "multifactor.steps_per_cycle = 900\n", + "\n", + "# Parameters for Adam:\n", + "# ==============================================================================\n", + "Adam.weight_decay_rate=0.0\n", + "Adam.b1 = 0.86\n", + "Adam.b2 = 0.92\n", + "Adam.eps = 1e-9\n", + "\n", + "# Parameters for SelfAttention:\n", + "# ==============================================================================\n", + "trax.layers.SelfAttention.attention_dropout = 0.05\n", + "trax.layers.SelfAttention.chunk_len = 64\n", + "trax.layers.SelfAttention.n_chunks_before = 1\n", + "trax.layers.SelfAttention.n_parallel_heads = 1\n", + "\n", + "# Parameters for LSHSelfAttention:\n", + "# ==============================================================================\n", + "LSHSelfAttention.attention_dropout = 0.0\n", + "LSHSelfAttention.chunk_len = 64\n", + "LSHSelfAttention.n_buckets = [64, 128]\n", + "LSHSelfAttention.n_chunks_after = 0\n", + "LSHSelfAttention.n_chunks_before = 1\n", + "LSHSelfAttention.n_hashes = 1\n", + 
"LSHSelfAttention.n_parallel_heads = 1\n", + "LSHSelfAttention.predict_drop_len = 128\n", + "LSHSelfAttention.predict_mem_len = 1024\n", + "\n", + "# Parameters for ReformerLM:\n", + "# ==============================================================================\n", + "ReformerLM.attention_type = %attn_type\n", + "ReformerLM.d_attention_key = %attn_kv\n", + "ReformerLM.d_attention_value = %attn_kv\n", + "ReformerLM.d_model = 256\n", + "ReformerLM.d_ff = 512\n", + "ReformerLM.dropout = %dropout\n", + "ReformerLM.ff_activation = @trax.layers.Relu\n", + "ReformerLM.max_len = %n_tokens\n", + "ReformerLM.mode = 'train'\n", + "ReformerLM.n_heads = %n_heads\n", + "ReformerLM.n_layers = %n_layers\n", + "ReformerLM.vocab_size = 320\n", + "ReformerLM.axial_pos_shape = (512, 1024)\n", + "ReformerLM.d_axial_pos_embs= (64, 192)\n", + "\"\"\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "RGGt0WaT3a-h", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Set up a Trainer.\n", + "output_dir = os.path.expanduser('~/train_dir/')\n", + "!rm -f ~/train_dir/model.pkl.gz # Remove old model\n", + "\n", + "trainer = trax.supervised.Trainer(\n", + " model=trax.models.ReformerLM,\n", + " loss_fn=trax.layers.CrossEntropyLoss(),\n", + " optimizer=trax.optimizers.Adam,\n", + " lr_schedule=trax.lr.multifactor(),\n", + " inputs=data.preprocessing.inputs.Inputs(my_inputs),\n", + " output_dir=output_dir)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "y6VQkmKO3a1L", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "outputId": "3c933bab-b49d-4e18-caf6-3dfc3e220938" + }, + "source": [ + "# Run one training step, to make sure the model fits in memory.\n", + "# The first time trainers.train_epoch is called, it will JIT the entire network\n", + "# architecture, which takes around 2 minutes. The JIT-compiled model is saved\n", + "# so subsequent runs will be much faster than the first.\n", + "trainer.train_epoch(n_steps=1, n_eval_steps=1)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "Step 1: Ran 1 train steps in 155.17 secs\n", + "Step 1: Evaluation\n", + "Step 1: train accuracy | 0.00343633\n", + "Step 1: train loss | 6.36618853\n", + "Step 1: train neg_log_perplexity | -6.36618853\n", + "Step 1: train sequence_accuracy | 0.00000000\n", + "Step 1: train weights_per_batch_per_core | 513812.00000000\n", + "Step 1: eval accuracy | 0.00340154\n", + "Step 1: eval loss | 6.36649418\n", + "Step 1: eval neg_log_perplexity | -6.36649418\n", + "Step 1: eval sequence_accuracy | 0.00000000\n", + "Step 1: eval weights_per_batch_per_core | 513812.00000000\n", + "Step 1: Finished evaluation\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "EFnX4G6z3asD", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Train for 600 steps total\n", + "# The first ~20 steps are slow to run, but after that it reaches steady-state\n", + "# speed. This will take at least 30 minutes to run to completion, but can safely\n", + "# be interrupted by selecting \"Runtime > Interrupt Execution\" from the menu.\n", + "# The language model won't be exceptionally good when trained for just a few\n", + "# steps and with minimal regularization. 
+        "# see what it learns.\n",
+        "trainer.train_epoch(n_steps=9, n_eval_steps=1)\n",
+        "for _ in range(59):\n",
+        "  trainer.train_epoch(n_steps=10, n_eval_steps=1)"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "zY3hpgnI5Rgn",
+        "colab_type": "text"
+      },
+      "source": [
+        "## Sample from the model"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "ffeLSbJk35pv",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# As we report in the Reformer paper, increasing the number of hashing rounds\n",
+        "# helps with quality. We can even increase the number of hashing rounds at\n",
+        "# evaluation time only.\n",
+        "\n",
+        "gin.parse_config(\"\"\"LSHSelfAttention.n_hashes = 4\"\"\")"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "-BwIjdl6_2tX",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Load the trained Reformer in 'predict' mode\n",
+        "model = trax.models.ReformerLM(mode='predict')\n",
+        "model.init_from_file(os.path.join(output_dir, 'model.pkl.gz'),\n",
+        "                     weights_only=True)\n",
+        "\n",
+        "# Sample from ReformerLM\n",
+        "output_token_ids = trax.supervised.decoding.autoregressive_sample(\n",
+        "    model, temperature=0.0)\n",
+        "\n",
+        "# Decode token IDs\n",
+        "# Reformer outputs a batch with one item; we access it using [0]\n",
+        "# tolist() converts from int64 to int, the type SentencePiece expects\n",
+        "TOKENIZER.DecodeIds(output_token_ids[0].tolist())\n"
+      ],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "s5f5QAmZBgPj",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [],
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
diff --git a/resources/examples/python/base.py b/resources/examples/python/base.py
new file mode 100644
index 000000000..e372322e9
--- /dev/null
+++ b/resources/examples/python/base.py
@@ -0,0 +1,369 @@
+"""Machine learning utilities for dataset handling, model training, and evaluation."""
+
+import time
+
+from enum import Enum
+from typing import Dict, Generator, Optional, Tuple, Union
+
+import datasets
+import numpy as np
+
+from absl import logging
+from sklearn.datasets import load_digits, load_iris
+
+import trax.fastmath as fastmath
+
+from trax.fastmath import numpy as jnp
+from trax.fastmath.jax import jax
+from trax.utils import shapes
+
+# Set global logging verbosity
+logging.set_verbosity(logging.INFO)
+
+# Constants
+DEFAULT_BATCH_SIZE = 32
+LOG_INTERVAL = 10
+
+
+class DeviceType(Enum):
+    """Supported device types for computation."""
+
+    CPU = "cpu"
+    GPU = "gpu"
+
+
+class Dataset(Enum):
+    """Supported datasets."""
+
+    IRIS = "iris"
+    DIGITS = "digits"
+    MNIST = "mnist"
+    IMDB = "imdb"
+
+
+class Splits(Enum):
+    """Supported dataset splits."""
+
+    TRAIN = "train"
+    TEST = "test"
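+
+# Example (using only the helpers defined below in this file): a typical
+# train-split pipeline wires these pieces together as
+#
+#   X, y = load_dataset(Dataset.MNIST.value, Splits.TRAIN.value)
+#   batches = create_batch_generator(X, y, batch_size=DEFAULT_BATCH_SIZE, seed=0)
+#   data_batch, labels_batch, weights_batch = next(batches)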
+
+
+def load_mnist(split: str = Splits.TRAIN.value) -> Tuple[np.ndarray, np.ndarray]:
+    """Load MNIST via Hugging Face Datasets as flattened, normalized arrays."""
+    # Use 'mnist' for the standard MNIST dataset
+    dataset = datasets.load_dataset("mnist", split=split)
+
+    # Pre-allocate arrays with the correct shape
+    num_examples = len(dataset)
+    X = np.zeros((num_examples, 784), dtype=np.float32)
+    y = np.zeros(num_examples, dtype=np.int64)
+
+    # Process each example in the dataset
+    i = 0
+    for image, label in zip(dataset["image"], dataset["label"]):
+        # Flatten image from (28, 28) to (784,) and normalize
+        X[i] = np.array(image).reshape(-1).astype(np.float32) / 255.0
+        y[i] = label
+        i += 1
+
+    return X, y
+
+
+def load_imdb(split: str = Splits.TRAIN.value) -> Tuple[np.ndarray, np.ndarray]:
+    """Load the IMDB sentiment dataset as text and labels."""
+    dataset = datasets.load_dataset("imdb", split=split)
+    texts = np.array(dataset["text"], dtype=object)
+    labels = np.array(dataset["label"], dtype=np.int64)
+    return texts, labels
+
+
+def load_dataset(
+    dataset_name: str = Dataset.IRIS.value,
+    split: str = Splits.TRAIN.value,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Load a dataset by name and split.
+
+    Args:
+        dataset_name: Name of the dataset to load.
+        split: Which split to load ('train' or 'test').
+
+    Returns:
+        Tuple of (data, labels) arrays. For the sklearn datasets (iris and
+        digits) the split is simulated with an 80/20 train/test partition.
+    """
+    if dataset_name == Dataset.IRIS.value:
+        dataset = load_iris()
+        data, labels = dataset.data, dataset.target
+        # For sklearn datasets, we'll simulate train/test split
+        if split == "test":
+            # Use last 20% as test
+            test_size = len(data) // 5
+            return data[-test_size:], labels[-test_size:]
+        else:
+            # Use first 80% as train
+            train_size = len(data) - (len(data) // 5)
+            return data[:train_size], labels[:train_size]
+
+    elif dataset_name == Dataset.DIGITS.value:
+        dataset = load_digits()
+        data, labels = dataset.data, dataset.target
+        # For sklearn datasets, we'll simulate train/test split
+        if split == "test":
+            # Use last 20% as test
+            test_size = len(data) // 5
+            return data[-test_size:], labels[-test_size:]
+        else:
+            # Use first 80% as train
+            train_size = len(data) - (len(data) // 5)
+            return data[:train_size], labels[:train_size]
+
+    elif dataset_name == Dataset.MNIST.value:
+        x, y = load_mnist(split=split)
+        return x, y
+    elif dataset_name == Dataset.IMDB.value:
+        x, y = load_imdb(split=split)
+        return x, y
+    else:
+        raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+
+def create_batch_generator(
+    data: np.ndarray,
+    labels: np.ndarray,
+    weights: Optional[np.ndarray] = None,
+    batch_size: int = DEFAULT_BATCH_SIZE,
+    seed: Optional[int] = None,
+) -> Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None]:
+    """
+    Create an infinite generator that produces shuffled batches.
+
+    Args:
+        data: Data array, shape [n_examples, n_features].
+        labels: Labels array, shape [n_examples].
+        weights: Optional sample weights array, shape [n_examples]. If None, uses all ones.
+        batch_size: Number of samples per batch.
+        seed: Random seed for reproducibility.
+
+    Yields:
+        A tuple (data_batch, labels_batch, weights_batch).
+    """
+    n_samples = data.shape[0]
+
+    # Convert inputs to arrays and prepare weights if needed
+    data = np.asarray(data)
+    labels = np.asarray(labels)
+    weights = np.ones_like(labels) if weights is None else np.asarray(weights)
+
+    # Initialize random number generator and shuffle indices
+    rng = np.random.default_rng(seed)
+    indices = np.arange(n_samples)
+    rng.shuffle(indices)
+
+    batch_index = 0
+    while True:
+        # Get batch indices with wraparound handling
+        end_index = batch_index + batch_size
+        if end_index <= n_samples:
+            batch_indices = indices[batch_index:end_index]
+        else:
+            overflow = end_index - n_samples
+            batch_indices = np.concatenate(
+                [indices[batch_index:], indices[:overflow]], axis=0
+            )
+            rng.shuffle(indices)
+
+        # Yield copies of the selected rows as NumPy arrays
+        yield (
+            np.array(data[batch_indices]),
+            np.array(labels[batch_indices]),
+            np.array(weights[batch_indices]),
+        )
+
+        # Update index for the next batch
+        batch_index = (batch_index + batch_size) % n_samples
+        if batch_index == 0:
+            rng.shuffle(indices)
+
+
+def graph_batch_generator(nodes, adjacency, labels, batch_size=32, seed=0):
+    """Yield random (nodes, adjacency, labels, weights) batches for graph models."""
+    rng = np.random.default_rng(seed)
+    n = nodes.shape[0]
+    while True:
+        idx = rng.choice(n, batch_size, replace=False)
+        yield nodes[idx], adjacency[idx], labels[idx], np.ones(batch_size)
+
+
+def initialize_model(model_with_loss, example_batch) -> Tuple[float, float]:
+    """
+    Initialize and compile a model using an example batch.
+
+    Args:
+        model_with_loss: Model with loss function to initialize.
+        example_batch: Example batch for initialization.
+
+    Returns:
+        Tuple of (initialization_time, compilation_time) in seconds.
+    """
+    logging.info("Initializing model...")
+    init_start = time.time()
+    _, _ = model_with_loss.init(shapes.signature(example_batch))
+    init_time = time.time() - init_start
+    logging.info(f"Model initialization time: {init_time:.4f} seconds")
+
+    logging.info("Compiling model with first batch...")
+    compile_start = time.time()
+    _ = model_with_loss(example_batch)
+    compile_time = time.time() - compile_start
+    logging.info(f"Compilation time: {compile_time:.4f} seconds")
+
+    return init_time, compile_time
+
+
+def _get_target_device(device_type: str):
+    """Helper function to get the target device."""
+    if device_type == DeviceType.CPU.value:
+        return fastmath.devices(DeviceType.CPU.value)[0]
+    elif device_type == DeviceType.GPU.value:
+        return fastmath.devices(DeviceType.GPU.value)[0]
+    else:
+        raise ValueError(f"Unsupported device type: {device_type}")
+
+
+def train_model(
+    trainer,
+    batch_generator: Generator[Tuple[np.ndarray, np.ndarray, np.ndarray], None, None],
+    num_steps: int,
+    base_rng,
+    device_type: str = DeviceType.CPU.value,
+) -> list:
+    """
+    Train a model for a specified number of steps.
+
+    Args:
+        trainer: The model trainer.
+        batch_generator: Generator that produces training batches.
+        num_steps: Number of training steps.
+        base_rng: Base random number generator.
+        device_type: Type of device to use for training ("cpu" or "gpu").
+
+    Returns:
+        List of loss values for each training step.
+    """
+    logging.info(f"\n\n{'='*20} RUNNING ON {device_type.upper()} {'='*20}")
+    logging.info(
+        f"Backend: {fastmath.backend_name()}, Global devices: {fastmath.global_device_count()}"
+    )
+
+    losses = []
+    training_start = time.time()
+
+    # Set target device via context
+    target_device = _get_target_device(device_type)
+
+    with jax.default_device(target_device):
+        for step in range(num_steps):
+            step_start = time.time()
+            step_rng, base_rng = fastmath.random.split(base_rng)
+            batch = next(batch_generator)
+            loss = trainer.one_step(batch, step_rng, step=step)
+            step_time = time.time() - step_start
+            losses.append(loss)
+
+            # Log progress at regular intervals
+            if step % LOG_INTERVAL == 0 or step == num_steps - 1:
+                logging.info(
+                    f"Step {step}, Loss: {loss:.4f}, Step time: {step_time:.4f} sec"
+                )
+
+    # Print training summary
+    training_time = time.time() - training_start
+    avg_step_time = training_time / num_steps
+    logging.info(
+        f"Total training time: {training_time:.4f} sec, Average step: {avg_step_time:.4f} sec"
+    )
+
+    return losses
+
+
+def compute_accuracy(predicted: jnp.ndarray, true_labels: jnp.ndarray) -> float:
+    """
+    Compute classification accuracy.
+
+    Args:
+        predicted: 1D array of integer class predictions, shape [N].
+        true_labels: 1D array of integer ground-truth labels, shape [N].
+
+    Returns:
+        Accuracy as a float between 0 and 1.
+    """
+    return jnp.mean(predicted == true_labels)
+
+
+def evaluate_model(
+    trainer,
+    batch_gen: Generator[Tuple[np.ndarray, ...], None, None],
+    device_type: str = DeviceType.CPU.value,
+    num_batches: int = 100,
+) -> Dict[str, float]:
+    """
+    Evaluate a trained model on test data.
+
+    Args:
+        trainer: The trained model trainer.
+        batch_gen: Generator that produces evaluation batches.
+        device_type: Type of device to use for evaluation.
+        num_batches: Number of batches to evaluate.
+
+    Returns:
+        Dictionary with evaluation metrics including accuracy and mean loss.
+    """
+    logging.info(f"\n\n{'='*20} EVALUATING MODEL {'='*20}")
+
+    # Set up evaluation environment
+    target_device = _get_target_device(device_type)
+    dummy_rng = fastmath.random.get_prng(10)
+
+    # Initialize evaluation metrics
+    total_loss = 0.0
+    total_accuracy = 0.0
+
+    # Evaluate model on a test set
+    with fastmath.jax.jax.default_device(target_device):
+        for i in range(num_batches):
+            batch = next(batch_gen)
+
+            # Get model predictions
+            predictions = trainer.model_with_loss.sublayers[0](
+                batch,
+                weights=trainer.model_with_loss.sublayers[0].weights,
+                state=trainer.model_with_loss.sublayers[0].state,
+                rng=dummy_rng,
+            )
+
+            # Calculate accuracy
+            predicted = jnp.argmax(predictions[0], axis=1)
+            labels = predictions[1]
+            batch_accuracy = compute_accuracy(predicted, labels)
+            total_accuracy += batch_accuracy
+
+            # Calculate loss
+            batch_loss = trainer.model_with_loss(batch, rng=dummy_rng)
+            total_loss += batch_loss
+
+            # Log progress
+            if i % LOG_INTERVAL == 0 or i == num_batches - 1:
+                logging.info(
+                    f"Test batch {i}, Accuracy: {batch_accuracy:.4f}, Loss: {batch_loss:.4f}"
+                )
+
+    # Calculate final metrics
+    mean_accuracy = total_accuracy / num_batches
+    mean_loss = total_loss / num_batches
+
+    # Log summary
+    logging.info("\nTest results:")
+    logging.info(f"  Mean accuracy: {mean_accuracy:.4f}")
+    logging.info(f"  Mean loss: {mean_loss:.4f}")
+
+    return {"accuracy": float(mean_accuracy), "loss": float(mean_loss)}
diff --git a/resources/examples/python/gnn/20ng/train.py b/resources/examples/python/gnn/20ng/train.py
new file mode 100644
index 000000000..a08838e58
--- /dev/null
+++ b/resources/examples/python/gnn/20ng/train.py
@@ -0,0 +1,160 @@
+import re
+
+from collections import Counter
+
+import datasets
+import numpy as np
+
+import trax.fastmath as fastmath
+
+from resources.examples.python.base import (
+    DeviceType,
+    evaluate_model,
+    graph_batch_generator,
+    initialize_model,
+    train_model,
+)
+from trax import layers as tl
+from trax import optimizers
+from trax.fastmath import numpy as jnp
+from trax.models import gnn
+from trax.trainers import jax as trainers
+
+MAX_LEN = 2000
+VOCAB_SIZE = 200_000
+WINDOW_SIZE = 5
+
+
+def clean(text):
+    t = text.lower()
+    t = re.sub(r"\S+@\S+", " ", t)  # emails
+    t = re.sub(r"http\S+|www\.\S+", " ", t)
+    t = re.sub(r"[_A-Za-z]:/[^ \n]+", " ", t)  # paths/urls
+    t = re.sub(r"[^a-z0-9\s]", " ", t)
+    t = re.sub(r"\s+", " ", t).strip()
+    return t
+
+
+def build_vocab(texts, min_freq=5):
+    counter = Counter()
+    for t in texts:
+        counter.update(clean(t).split()[:MAX_LEN])
+    vocab = {"<pad>": 0, "<unk>": 1}
+    for w, c in counter.most_common():
+        if c < min_freq or len(vocab) >= VOCAB_SIZE:
+            break
+        vocab[w] = len(vocab)
+    return vocab
+
+
+def encode(text, vocab):
+    tokens = [vocab.get(w, 1) for w in text.lower().split()[:MAX_LEN]]
+    if len(tokens) < MAX_LEN:
+        tokens += [0] * (MAX_LEN - len(tokens))
+    return np.array(tokens)
+
+
+def window_adjacency(length=MAX_LEN, window=WINDOW_SIZE):
+    """Create adjacency connecting tokens within a sliding window."""
+    adj = np.zeros((length, length), dtype=np.float32)
+    for i in range(length):
+        left, right = max(0, i - window), min(length, i + window + 1)
+        for j in range(left, right):
+            if i != j:
+                adj[i, j] = 1.0
+    np.fill_diagonal(adj, 1.0)
+    return adj
+
+
+def load_data():
+    train_ds = datasets.load_dataset("SetFit/20_newsgroups", split="train")
+    test_ds = datasets.load_dataset("SetFit/20_newsgroups", split="test")
+    # train_ds = datasets.load_dataset("imdb", split="train[:2000]")
+    # test_ds = datasets.load_dataset("imdb", split="test[:1000]")
datasets.load_dataset("imdb", split="test[:1000]") + + vocab = build_vocab(train_ds["text"]) + x_train = np.stack([encode(t, vocab) for t in train_ds["text"]]) + y_train = np.array(train_ds["label"], dtype=np.int64) + x_test = np.stack([encode(t, vocab) for t in test_ds["text"]]) + y_test = np.array(test_ds["label"], dtype=np.int64) + + adj = window_adjacency() + a_train = np.broadcast_to(adj, (x_train.shape[0], MAX_LEN, MAX_LEN)) + a_test = np.broadcast_to(adj, (x_test.shape[0], MAX_LEN, MAX_LEN)) + + return (x_train, a_train, y_train), (x_test, a_test, y_test), len(vocab) + + +def attention_pool(): + """Compute weighted average of node embeddings.""" + return tl.Serial( + tl.Branch( + None, + tl.Serial( + tl.Dense(1), + tl.Flatten(n_axes_to_keep=2), + tl.Softmax(), + ), + ), + tl.Fn( + "AttnPool", + lambda x, w: jnp.sum(x * w, axis=1), + ), + ) + + +def build_model(vocab_size): + return tl.Serial( + tl.Parallel(tl.Embedding(vocab_size, 512), None), + gnn.GraphAttentionNet(hidden_sizes=(512, 64, 32), num_heads=16), + tl.Select([0]), + attention_pool(), + tl.Dense(20), + tl.Select([0, 2, 3]), + ) + + +def main(): + DEFAULT_BATCH_SIZE = 8 + STEPS_NUMBER = 14_000 + + (x_train, a_train, y_train), (x_test, a_test, y_test), vocab_size = load_data() + batch_gen = graph_batch_generator( + x_train, a_train, y_train, batch_size=DEFAULT_BATCH_SIZE + ) + example_batch = next(batch_gen) + + model_with_loss = tl.Serial( + build_model(vocab_size), tl.CrossEntropyLossWithLogSoftmax() + ) + initialize_model(model_with_loss, example_batch) + + optimizer = optimizers.Adam(0.00001) + trainer = trainers.Trainer(model_with_loss, optimizer) + + base_rng = fastmath.random.get_prng(0) + train_model( + trainer, + batch_gen, + num_steps=STEPS_NUMBER, + base_rng=base_rng, + device_type=DeviceType.GPU.value, + ) + + test_batch_gen = graph_batch_generator( + x_test, a_test, y_test, batch_size=DEFAULT_BATCH_SIZE + ) + + # Evaluate model on a test set + test_results = evaluate_model( + trainer=trainer, + batch_gen=test_batch_gen, + device_type=DeviceType.CPU.value, + num_batches=50, + ) + + print(f"Final test accuracy: {test_results['accuracy']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/resources/examples/python/gnn/imdb/train.py b/resources/examples/python/gnn/imdb/train.py new file mode 100644 index 000000000..6728b3da3 --- /dev/null +++ b/resources/examples/python/gnn/imdb/train.py @@ -0,0 +1,122 @@ +from collections import Counter + +import datasets +import numpy as np + +import trax.fastmath as fastmath + +from resources.examples.python.base import ( + DeviceType, + evaluate_model, + graph_batch_generator, + initialize_model, + train_model, +) +from trax import layers as tl +from trax import optimizers +from trax.models import gnn +from trax.trainers import jax as trainers + +MAX_LEN = 400 +VOCAB_SIZE = 36_000 + + +def build_vocab(texts): + counter = Counter() + for t in texts: + counter.update(t.lower().split()[:MAX_LEN]) + vocab = {"": 0, "": 1} + for i, (w, _) in enumerate(counter.most_common(VOCAB_SIZE - 2), start=2): + vocab[w] = i + return vocab + + +def encode(text, vocab): + tokens = [vocab.get(w, 1) for w in text.lower().split()[:MAX_LEN]] + if len(tokens) < MAX_LEN: + tokens += [0] * (MAX_LEN - len(tokens)) + return np.array(tokens) + + +def chain_adjacency(length=MAX_LEN): + adj = np.zeros((length, length), dtype=np.float32) + for i in range(length - 1): + adj[i, i + 1] = 1 + adj[i + 1, i] = 1 + return adj + + +def load_data(): + train_ds = datasets.load_dataset("imdb", split="train") + 
test_ds = datasets.load_dataset("imdb", split="test") + # train_ds = datasets.load_dataset("imdb", split="train[:2000]") + # test_ds = datasets.load_dataset("imdb", split="test[:1000]") + + vocab = build_vocab(train_ds["text"]) + x_train = np.stack([encode(t, vocab) for t in train_ds["text"]]) + y_train = np.array(train_ds["label"], dtype=np.int64) + x_test = np.stack([encode(t, vocab) for t in test_ds["text"]]) + y_test = np.array(test_ds["label"], dtype=np.int64) + + adj = chain_adjacency() + a_train = np.broadcast_to(adj, (x_train.shape[0], MAX_LEN, MAX_LEN)) + a_test = np.broadcast_to(adj, (x_test.shape[0], MAX_LEN, MAX_LEN)) + + return (x_train, a_train, y_train), (x_test, a_test, y_test), len(vocab) + + +def build_model(vocab_size): + return tl.Serial( + tl.Parallel(tl.Embedding(vocab_size, 512), None), + gnn.GraphAttentionNet(hidden_sizes=(512, 64, 32), num_heads=2), + tl.Select([0]), + tl.Mean(axis=1), + tl.Dense(2), + tl.Select([0, 2, 3]), + ) + + +def main(): + DEFAULT_BATCH_SIZE = 16 + STEPS_NUMBER = 20_000 + + (x_train, a_train, y_train), (x_test, a_test, y_test), vocab_size = load_data() + batch_gen = graph_batch_generator( + x_train, a_train, y_train, batch_size=DEFAULT_BATCH_SIZE + ) + example_batch = next(batch_gen) + + model_with_loss = tl.Serial( + build_model(vocab_size), tl.CrossEntropyLossWithLogSoftmax() + ) + initialize_model(model_with_loss, example_batch) + + optimizer = optimizers.Adam(0.0001) + trainer = trainers.Trainer(model_with_loss, optimizer) + + base_rng = fastmath.random.get_prng(0) + train_model( + trainer, + batch_gen, + num_steps=STEPS_NUMBER, + base_rng=base_rng, + device_type=DeviceType.GPU.value, + ) + + test_batch_gen = graph_batch_generator( + x_test, a_test, y_test, batch_size=DEFAULT_BATCH_SIZE + ) + + # Evaluate model on a test set + test_results = evaluate_model( + trainer=trainer, + batch_gen=test_batch_gen, + device_type=DeviceType.CPU.value, + num_batches=500, + ) + + print(f"Final test accuracy: {test_results['accuracy']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/resources/examples/python/gnn/mnist/train.py b/resources/examples/python/gnn/mnist/train.py new file mode 100644 index 000000000..0d187029d --- /dev/null +++ b/resources/examples/python/gnn/mnist/train.py @@ -0,0 +1,102 @@ +import numpy as np + +import trax.fastmath as fastmath + +from resources.examples.python.base import ( + Dataset, + DeviceType, + Splits, + evaluate_model, + graph_batch_generator, + initialize_model, + load_dataset, + train_model, +) +from trax import layers as tl +from trax import optimizers +from trax.models import gnn +from trax.trainers import jax as trainers + + +def grid_adjacency(height=28, width=28): + """Returns 4-neighbor adjacency for an image grid.""" + n = height * width + adj = np.zeros((n, n), dtype=np.float32) + for y in range(height): + for x in range(width): + idx = y * width + x + if x > 0: + adj[idx, idx - 1] = 1 + if x < width - 1: + adj[idx, idx + 1] = 1 + if y > 0: + adj[idx, idx - width] = 1 + if y < height - 1: + adj[idx, idx + width] = 1 + return adj + + +def create_graph_data(images): + nodes = images.reshape((images.shape[0], 28 * 28, 1)).astype(np.float32) + adj = grid_adjacency() + adj = np.broadcast_to(adj, (images.shape[0], adj.shape[0], adj.shape[1])) + return nodes, adj + + +def build_model(): + return tl.Serial( + gnn.GraphAttentionNet(hidden_sizes=(128, 64, 32, 16)), + tl.Select([0]), + tl.Mean(axis=1), + tl.Dense(10), + tl.Select([0, 2, 3]), + ) + + +def main(): + DEFAULT_BATCH_SIZE = 8 + STEPS_NUMBER = 
20_000
+
+    images, labels = load_dataset(Dataset.MNIST.value)
+    nodes, adjacency = create_graph_data(images)
+
+    batch_generator = graph_batch_generator(
+        nodes, adjacency, labels, batch_size=DEFAULT_BATCH_SIZE
+    )
+    example_batch = next(batch_generator)
+
+    model_with_loss = tl.Serial(build_model(), tl.CrossEntropyLossWithLogSoftmax())
+    initialize_model(model_with_loss, example_batch)
+
+    optimizer = optimizers.Adam(0.0001)
+    trainer = trainers.Trainer(model_with_loss, optimizer)
+
+    base_rng = fastmath.random.get_prng(0)
+    train_model(
+        trainer,
+        batch_generator,
+        STEPS_NUMBER,
+        base_rng,
+        device_type=DeviceType.GPU.value,
+    )
+
+    images, labels = load_dataset(Dataset.MNIST.value, Splits.TEST.value)
+    nodes, adjacency = create_graph_data(images)
+
+    test_batch_gen = graph_batch_generator(
+        nodes, adjacency, labels, batch_size=DEFAULT_BATCH_SIZE
+    )
+
+    # Evaluate model on a test set
+    test_results = evaluate_model(
+        trainer=trainer,
+        batch_gen=test_batch_gen,
+        device_type=DeviceType.CPU.value,
+        num_batches=100,
+    )
+
+    print(f"Final test accuracy: {test_results['accuracy']:.4f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/resources/examples/python/nn/digits/train.py b/resources/examples/python/nn/digits/train.py
new file mode 100644
index 000000000..56d989070
--- /dev/null
+++ b/resources/examples/python/nn/digits/train.py
@@ -0,0 +1,80 @@
+import trax.fastmath as fastmath
+
+from resources.examples.python.base import (
+    Dataset,
+    DeviceType,
+    Splits,
+    create_batch_generator,
+    evaluate_model,
+    initialize_model,
+    load_dataset,
+    train_model,
+)
+from trax import layers as tl
+from trax import optimizers
+from trax.trainers import jax as trainers
+
+
+def build_model():
+    # Build your model with loss function
+    model = tl.Serial(
+        tl.Dense(16, use_bias=True), tl.Relu(), tl.Dense(10, use_bias=False)
+    )
+    model_with_loss = tl.Serial(model, tl.CrossEntropyLossWithLogSoftmax())
+    return model_with_loss
+
+
+def main():
+    # Default setup
+    DEFAULT_BATCH_SIZE = 8
+    STEPS_NUMBER = 20_000
+
+    # Load data
+    X, y = load_dataset(Dataset.DIGITS.value)
+    batch_generator = create_batch_generator(
+        X, y, batch_size=DEFAULT_BATCH_SIZE, seed=42
+    )
+    example_batch = next(batch_generator)
+
+    # Build and initialize model
+    model_with_loss = build_model()
+    initialize_model(model_with_loss, example_batch)
+
+    # Setup optimizer and trainers
+    optimizer = optimizers.Adam(0.001)
+    trainer = trainers.Trainer(model_with_loss, optimizer)
+
+    base_rng = fastmath.random.get_prng(0)
+
+    # Run training on CPU and/or GPU
+    train_model(
+        trainer,
+        batch_generator,
+        STEPS_NUMBER,
+        base_rng,
+        device_type=DeviceType.GPU.value,
+    )
+
+    # Load test data
+    test_data, test_labels = load_dataset(
+        dataset_name=Dataset.DIGITS.value, split=Splits.TEST.value
+    )
+
+    # Create batch generator for test data
+    test_batch_gen = create_batch_generator(
+        test_data, test_labels, None, DEFAULT_BATCH_SIZE, 0
+    )
+
+    # Evaluate model on a test set
+    test_results = evaluate_model(
+        trainer=trainer,
+        batch_gen=test_batch_gen,
+        device_type=DeviceType.CPU.value,
+        num_batches=100,
+    )
+
+    print(f"Final test accuracy: {test_results['accuracy']:.4f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/resources/examples/python/nn/imdb/train.py b/resources/examples/python/nn/imdb/train.py
new file mode 100644
index 000000000..f546ca375
--- /dev/null
+++ b/resources/examples/python/nn/imdb/train.py
@@ -0,0 +1,336 @@
+import time
+
+import numpy as np
+
+from absl import logging
+from trax.layers import
diff --git a/resources/examples/python/nn/imdb/train.py b/resources/examples/python/nn/imdb/train.py
new file mode 100644
index 000000000..f546ca375
--- /dev/null
+++ b/resources/examples/python/nn/imdb/train.py
@@ -0,0 +1,336 @@
+import time
+
+import numpy as np
+
+from absl import logging
+
+import trax.fastmath as fastmath
+
+from trax import layers as tl
+from trax import optimizers, shapes
+from trax.layers import CrossEntropyLossWithLogSoftmax
+from trax.trainers import jax as trainers
+
+# from trax.data.encoder import encoder as encoder
+# from trax.data.loader.tf import base as dataset
+# from trax.data.preprocessing import inputs as preprocessing
+from trax.fastmath import numpy as jnp
+
+
+def Transpose():  # pylint: disable=invalid-name
+    # Give your custom layer a name so it can be identified later.
+    layer_name = "Transpose"
+
+    # Custom function for the custom layer
+    def f(x):  # pylint: disable=invalid-name
+        assert len(x.shape) == 3 or len(x.shape) == 2, (
+            "Houston, we've got a problem: cannot automatically transpose "
+            "this stream - the input is not a 2-d or 3-d array; "
+            "use trax.data.Batch(n) first, where n >= 1."
+        )
+        if len(x.shape) == 2:
+            return jnp.transpose(x)
+
+        return jnp.transpose(x, (0, 2, 1))
+
+    return tl.Fn(layer_name, f, n_out=1)
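A quick shape check of the Fn-based Transpose layer defined above (a sketch; weight-less Fn layers can be called directly, as this script itself does further down):

# Sketch: exercising the Transpose layer defined above.
import numpy as np

layer = Transpose()
out = layer(np.zeros((32, 128, 8)))              # (batch, length, features)
assert out.shape == (32, 8, 128)                 # last two axes swapped
assert layer(np.zeros((5, 7))).shape == (7, 5)   # plain 2-d transpose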
+def run_training(device_type="cpu", num_steps=100):
+    """Run training with the specified device configuration."""
+
+    # ====== Determine the target device =======
+    if device_type == "cpu":
+        target_device = fastmath.devices("cpu")[0]
+    elif device_type == "gpu":
+        target_device = fastmath.devices("gpu")[0]
+    else:
+        raise ValueError(f"Unsupported device_type: {device_type}")
+
+    # Set the logging level to INFO or lower
+    logging.set_verbosity(logging.INFO)
+
+    print(f"\n\n{'='*20} RUNNING ON {device_type.upper()} {'='*20}")
+    logging.info(f"Backend name: {fastmath.backend_name()}")
+    logging.info(f"Backend device count: {fastmath.global_device_count()}")
+    logging.info(f"Backend local device count: {fastmath.local_device_count()}")
+    logging.info(f"JAX devices: {fastmath.devices(device_type)}")
+    logging.info(f"JAX target device: {target_device}")
+
+    # ====== Create data pipeline =======
+    # VOCAB_TYPE = "subword"
+    # VOCAB_FILE = "en_8k.subword"
+    #
+    # vocab_size = encoder.vocab_size(VOCAB_TYPE, VOCAB_FILE)
+    #
+    # train_stream = dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
+    # eval_stream = dataset.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
+    #
+    # data_pipeline = preprocessing.Serial(
+    #     preprocessing.ConvertToUnicode(keys=[0]),
+    #     encoder.Tokenize(keys=[0], vocab_type=VOCAB_TYPE, vocab_file=VOCAB_FILE),
+    #     preprocessing.Shuffle(),
+    #     # preprocessing.FilterByLength(max_length=1000_000, length_keys=[0]),
+    #     preprocessing.AddLossWeights(),
+    #     lambda g: map(lambda x: (x[0], np.asarray(x[1]), x[2]), g),
+    #     preprocessing.ClassificationVector(vocab_size=vocab_size),
+    #     preprocessing.Batch(batch_size=32),
+    #     lambda g: map(lambda x: (jnp.asarray(x[0]), jnp.asarray(x[1]), jnp.asarray(x[2])), g),
+    # )
+
+    def create_batch_generator(
+        batch_size=32, feature_dim=10_000, num_classes=20, seed=42
+    ):
+        """
+        Creates a generator that yields random example batches.
+
+        Args:
+            batch_size: Size of each batch
+            feature_dim: Dimension of feature vectors
+            num_classes: Number of possible classes
+            seed: Random seed for reproducibility
+
+        Returns:
+            A generator that yields (features, labels, weights) tuples
+        """
+        # Initialize the RNG key
+        key = fastmath.random.get_prng(seed)
+
+        while True:
+            # Split the key for this iteration to get two independent random keys
+            key, subkey1, subkey2 = fastmath.random.split(key, 3)
+
+            # Generate features, labels and weights
+            features = fastmath.random.randint(
+                subkey1, (batch_size, feature_dim), minval=0, maxval=10_000
+            )
+            labels = fastmath.random.randint(
+                subkey2, (batch_size,), minval=0, maxval=num_classes
+            )
+            weights = jnp.ones((batch_size,))
+
+            # Yield the batch
+            yield (features, labels, weights)
+
+    train_batches_stream = create_batch_generator()
+    example_batch = next(
+        train_batches_stream
+    )  # Cache first batch to ensure fair comparison
+    # train_batches_stream = data_pipeline(train_stream)
+    # example_batch = next(train_batches_stream)  # Cache first batch to ensure fair comparison
+
+    # ====== Create and initialize model =======
+    mode = "train"
+    model = tl.Serial(
+        tl.Embedding(vocab_size=10_000, d_feature=1),
+        Transpose(),
+        tl.Dropout(rate=0.1, mode=mode),
+        tl.LeakyRelu(a=0.1),
+        # 20 output classes, matching num_classes in the batch generator.
+        tl.Dense(20, use_bias=False),
+    )
+
+    model_with_loss = tl.Serial(model, CrossEntropyLossWithLogSoftmax())
+
+    # Initialize model
+    print("Initializing model...")
+    init_start = time.time()
+    _, _ = model_with_loss.init(shapes.signature(example_batch))
+    init_time = time.time() - init_start
+    print(f"Model initialization time: {init_time:.4f} seconds")
+
+    # First run to compile
+    print("Compiling model with first batch...")
+    compile_start = time.time()
+    y = model_with_loss(example_batch)
+    compile_time = time.time() - compile_start
+    print(f"Compilation time: {compile_time:.4f} seconds")
+
+    # Setup optimizer
+    # optimizer = optimizers.Adafactor(0.001)
+    optimizer = optimizers.SGD(0.0001)
+    trainer = trainers.Trainer(model_with_loss, optimizer)
+
+    base_rng = fastmath.random.get_prng(0)
+
+    # ====== Training loop with timing =======
+    print(f"Starting training for {num_steps} steps...")
+    training_start = time.time()
+    losses = []
+
+    with fastmath.jax.jax.default_device(target_device):
+        for i in range(num_steps):
+            step_start = time.time()
+
+            # Split the RNG to get a new key for this step
+            step_rng, base_rng = fastmath.random.split(base_rng)
+
+            # Get batch (use cached first batch for first iteration to ensure fair comparison)
+            if i == 0:
+                batch = example_batch
+            else:
+                batch = next(train_batches_stream)
+
+            # Training step
+            loss = trainer.one_step(batch, step_rng, step=i)
+            step_time = time.time() - step_start
+            losses.append(loss)
+
+            # Print progress
+            if i % 10 == 0 or i == num_steps - 1:
+                print(f"Step {i}, Loss: {loss:.4f}, Step time: {step_time:.4f} seconds")
+
+    training_time = time.time() - training_start
+    avg_step_time = training_time / num_steps
+
+    print(f"\n{'='*50}")
+    print(f"Device: {device_type.upper()}")
+    print(f"Total training time for {num_steps} steps: {training_time:.4f} seconds")
+    print(f"Average step time: {avg_step_time:.4f} seconds")
+    print(f"Final loss: {losses[-1]:.4f}")
+    print(f"{'='*50}\n")
+
+    return {
+        "device": device_type,
+        "init_time": init_time,
+        "compile_time": compile_time,
+        "total_training_time": training_time,
+        "avg_step_time": avg_step_time,
+        "final_loss": losses[-1],
+    }
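One caveat about the per-step timings above: JAX dispatches work asynchronously, so time.time() around one_step can measure dispatch rather than computation. A hedged sketch of a stricter timing helper (assumes the loss is a JAX array or pytree of arrays):

# Sketch: more robust step timing under JAX's asynchronous dispatch.
import time
import jax

def timed_step(trainer, batch, rng, step):
    start = time.time()
    loss = trainer.one_step(batch, rng, step=step)
    jax.block_until_ready(loss)  # wait for the result, not just the dispatch
    return loss, time.time() - start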
+
+
+# Run and compare
+NUM_STEPS = 5_000  # Use a smaller number for testing, then increase for the full benchmark
+
+# CPU run
+cpu_results = run_training(device_type="cpu", num_steps=NUM_STEPS)
+
+# GPU run
+gpu_results = run_training(device_type="gpu", num_steps=NUM_STEPS)
+
+# Print comparison
+print("\n" + "=" * 50)
+print("PERFORMANCE COMPARISON: CPU vs GPU")
+print("=" * 50)
+print(f"{'Metric':<25} {'CPU':<15} {'GPU':<15} {'Speedup':<10}")
+print("-" * 65)
+
+for metric in ["init_time", "compile_time", "total_training_time", "avg_step_time"]:
+    cpu_val = cpu_results[metric]
+    gpu_val = gpu_results[metric]
+    speedup = cpu_val / gpu_val if gpu_val > 0 else float("inf")
+    print(f"{metric:<25} {cpu_val:.4f}s{' ':<9} {gpu_val:.4f}s{' ':<9} {speedup:.2f}x")
+
+print("=" * 65)
+
+
+def create_batch_generator(batch_size=32, feature_dim=10_000, num_classes=20, seed=42):
+    """
+    Creates a generator that yields random example batches.
+
+    Args:
+        batch_size: Size of each batch
+        feature_dim: Dimension of feature vectors
+        num_classes: Number of possible classes
+        seed: Random seed for reproducibility
+
+    Returns:
+        A generator that yields (features, labels, weights) tuples
+    """
+    # Initialize the RNG key
+    key = fastmath.random.get_prng(seed)
+
+    while True:
+        # Split the key for this iteration to get two independent random keys
+        key, subkey1, subkey2 = fastmath.random.split(key, 3)
+
+        # Generate features, labels and weights
+        features = fastmath.random.randint(
+            subkey1, (batch_size, feature_dim), minval=0, maxval=10_000
+        )
+        labels = fastmath.random.randint(
+            subkey2, (batch_size,), minval=0, maxval=num_classes
+        )
+        weights = jnp.ones((batch_size,))
+
+        # Yield the batch
+        yield (features, labels, weights)
+
+
+mode = "train"
+model = tl.Serial(
+    tl.Embedding(vocab_size=10_000, d_feature=1),
+    Transpose(),
+    tl.Dropout(rate=0.1, mode=mode),
+    tl.LeakyRelu(a=0.1),
+    tl.Dense(20, use_bias=False),
+)
+
+# CrossEntropyLossWithLogSoftmax() adds roughly 6 seconds of overhead compared to
+# running the bare sequence core.LogSoftmax(), _CrossEntropy(), _WeightedMean().
+# On GPU the run takes roughly 40-50 seconds and on CPU roughly 60-65 seconds,
+# i.e. a difference of 10-15 seconds per operation.
+# The accelerated version brings this down to 11.7531 seconds on cuda:0 versus
+# 64.3705 seconds on TFRT_CPU_0.
+model_with_loss = tl.Serial(model, CrossEntropyLossWithLogSoftmax())
+
+model_with_loss_accelerated = tl.Accelerate(model_with_loss)
+
+batch_generator = create_batch_generator(batch_size=32)
+example_batch = next(batch_generator)
+
+# Initialize model
+print("Initializing model...")
+init_start = time.time()
+_, _ = model_with_loss_accelerated.init(shapes.signature(example_batch))
+init_time = time.time() - init_start
+print(f"Model initialization time: {init_time:.4f} seconds")
+
+device = fastmath.jax.jax.devices("cpu")[0]
+with fastmath.jax.jax.default_device(device):
+    start_time = time.time()
+    for _ in range(5_000):
+        example_batch = next(batch_generator)
+        y = model_with_loss_accelerated(example_batch)
+    cpu_time = time.time() - start_time
+    print(f"{device} time: {cpu_time:.4f} seconds")
+    print(y)
+
+
+x = np.array([[[1, 2, 3]]])
+transpose_layer = Transpose()
+result = transpose_layer(x)
+transpose_layer_accelerated = tl.Accelerate(transpose_layer)
+result = transpose_layer_accelerated(x)
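To isolate what tl.Accelerate buys, a small sketch that times the plain layer against its accelerated wrapper, reusing the weight-less Transpose layer and the input x defined just above (the helper name time_calls is illustrative):

# Sketch: average per-call latency, plain vs. accelerated layer.
import time

def time_calls(layer, batch, n=100):
    layer(batch)  # warm-up call (pays compilation cost for Accelerate)
    start = time.time()
    for _ in range(n):
        layer(batch)
    return (time.time() - start) / n

print("plain:      ", time_calls(transpose_layer, x))
print("accelerated:", time_calls(transpose_layer_accelerated, x))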
+# Define a sample computation
+def compute():
+    # Get a random key
+    key = fastmath.random.get_prng(0)
+
+    result = jnp.zeros((4_000, 20))
+
+    for _ in range(100):
+        key, subkey1, subkey2 = fastmath.random.split(key, 3)
+
+        x = fastmath.random.normal(subkey1, (4_000, 10_000))
+        y = fastmath.random.normal(subkey2, (10_000, 20))
+
+        # Accumulate so every iteration contributes to the returned value
+        # (a plain assignment would discard all but the last product).
+        result = result + jnp.dot(x, y)
+
+    return result
+
+
+# Run on CPU
+cpu_device = fastmath.jax.jax.devices("cpu")[0]
+with fastmath.jax.jax.default_device(cpu_device):
+    start_time = time.time()
+    compute().block_until_ready()
+    cpu_time = time.time() - start_time
+    print(f"CPU time: {cpu_time:.4f} seconds")
+
+# Run on GPU
+gpu_device = fastmath.jax.jax.devices("gpu")[0]
+with fastmath.jax.jax.default_device(gpu_device):
+    start_time = time.time()
+    compute().block_until_ready()
+    gpu_time = time.time() - start_time
+    print(f"GPU time: {gpu_time:.4f} seconds")
diff --git a/resources/examples/python/nn/iris/train.py b/resources/examples/python/nn/iris/train.py
new file mode 100644
index 000000000..a8290548b
--- /dev/null
+++ b/resources/examples/python/nn/iris/train.py
@@ -0,0 +1,80 @@
+import trax.fastmath as fastmath
+
+from resources.examples.python.base import (
+    Dataset,
+    DeviceType,
+    Splits,
+    create_batch_generator,
+    evaluate_model,
+    initialize_model,
+    load_dataset,
+    train_model,
+)
+from trax import layers as tl
+from trax import optimizers
+from trax.trainers import jax as trainers
+
+
+def build_model():
+    # Build your model with loss function
+    model = tl.Serial(
+        tl.Dense(16, use_bias=True), tl.Relu(), tl.Dense(3, use_bias=False)
+    )
+    model_with_loss = tl.Serial(model, tl.CrossEntropyLossWithLogSoftmax())
+    return model_with_loss
+
+
+def main():
+    # Default setup
+    DEFAULT_BATCH_SIZE = 8
+    STEPS_NUMBER = 20_000
+
+    # Load data
+    X, y = load_dataset(Dataset.IRIS.value)
+    batch_generator = create_batch_generator(
+        X, y, batch_size=DEFAULT_BATCH_SIZE, seed=42
+    )
+    example_batch = next(batch_generator)
+
+    # Build and initialize model
+    model_with_loss = build_model()
+    initialize_model(model_with_loss, example_batch)
+
+    # Setup optimizer and trainers
+    optimizer = optimizers.SGD(0.1)
+    trainer = trainers.Trainer(model_with_loss, optimizer)
+
+    base_rng = fastmath.random.get_prng(0)
+
+    # Run training on CPU and/or GPU
+    train_model(
+        trainer,
+        batch_generator,
+        STEPS_NUMBER,
+        base_rng,
+        device_type=DeviceType.GPU.value,
+    )
+
+    # Load test data
+    test_data, test_labels = load_dataset(
+        dataset_name=Dataset.IRIS.value, split=Splits.TEST.value
+    )
+
+    # Create batch generator for test data
+    test_batch_gen = create_batch_generator(
+        test_data, test_labels, None, DEFAULT_BATCH_SIZE, 0
+    )
+
+    # Evaluate model on a test set
+    test_results = evaluate_model(
+        trainer=trainer,
+        batch_gen=test_batch_gen,
+        device_type=DeviceType.CPU.value,
+        num_batches=100,
+    )
+
+    print(f"Final test accuracy: {test_results['accuracy']:.4f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/resources/examples/python/nn/mnist/from_scratch/dataset.py b/resources/examples/python/nn/mnist/from_scratch/dataset.py
new file mode 100644
index 000000000..b9925106f
--- /dev/null
+++ b/resources/examples/python/nn/mnist/from_scratch/dataset.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Load pickled MNIST data."""
+import gzip
+import os
+import pickle
+import random
+import urllib.request
+
+import numpy as np
+
+
+def load():
+    """Loads the dataset.
+
+    Looks for the dataset at /tmp/mnist.pkl.gz and downloads it if it is not
+    there already.
+
+    Note: The training data is shuffled.
+
+    Returns:
+      ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)).
+      Shapes:
+        train_x: num_training_examples x image_size
+        train_y: num_training_examples x num_classes
+        valid_x: num_validation_examples x image_size
+        valid_y: num_validation_examples x num_classes
+        test_x: num_test_examples x image_size
+        test_y: num_test_examples x num_classes
+    """
+    filepath = _maybe_download()
+    with gzip.open(filepath, "rb") as f:
+        training_data, validation_data, test_data = pickle.load(f, encoding="bytes")
+    training_data = (training_data[0], [to_one_hot(x) for x in training_data[1]])
+    validation_data = (validation_data[0], [to_one_hot(x) for x in validation_data[1]])
+    test_data = (test_data[0], [to_one_hot(x) for x in test_data[1]])
+
+    def shuffle(data):
+        zipped = list(zip(*data))
+        random.shuffle(zipped)
+        shuffled = zip(*zipped)
+        # Convert the zip object to a tuple of lists to make it subscriptable
+        return tuple(list(x) for x in shuffled)
+
+    return (shuffle(training_data), validation_data, test_data)
+
+
+def to_one_hot(label, num_classes=10):
+    vec = np.zeros(num_classes, dtype=np.float32)
+    vec[label] = 1.0
+    return vec
+
+
+def _maybe_download():
+    """Downloads the MNIST dataset if it is not there already."""
+    data_url = "http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz"
+    filename = data_url.split("/")[-1]
+    filepath = os.path.join(_get_data_dir(), filename)
+    if not os.path.exists(filepath):
+
+        def _progress(count, block_size, total_size):
+            # end="" keeps the carriage-return progress display on one line
+            # instead of printing a new line per downloaded block.
+            print(
+                "\r>> Downloading %s %.1f%%"
+                % (filename, float(count * block_size) / float(total_size) * 100.0),
+                end="",
+                flush=True,
+            )
+
+        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
+        print()
+        statinfo = os.stat(filepath)
+        print("Successfully downloaded %s %d bytes." % (filename, statinfo.st_size))
+    else:
+        print("Data already present on disk.")
+    return filepath
+
+
+def _get_data_dir():
+    return "/tmp"
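A short usage sketch for the loader above; the split sizes noted in the comments are the conventional mnist.pkl.gz splits and are an assumption here:

# Sketch: consuming dataset.load(); shapes follow the docstring above.
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load()
print(len(train_x))       # typically 50000 training examples
print(train_x[0].shape)   # (784,) - a flattened 28x28 image
print(train_y[0].shape)   # (10,)  - a one-hot label vector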
diff --git a/resources/examples/python/nn/mnist/from_scratch/model.py b/resources/examples/python/nn/mnist/from_scratch/model.py
new file mode 100644
index 000000000..8b1a49a1e
--- /dev/null
+++ b/resources/examples/python/nn/mnist/from_scratch/model.py
@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model for training on MNIST data."""
+import tensorflow as tf
+
+from numpy import float32, int32
+
+from trax.tf import numpy as np
+
+
+class Model(object):
+    """A simple neural network with dense layers and sigmoid non-linearity.
+
+    The network consists of `len(hidden_layers) + 1` dense layers. The sizes of
+    the hidden layers are specified by the user in `hidden_layers` and the
+    network takes care of adding layers to match the input and output size.
+
+    Attributes:
+      weights: A list of 2-d float32 arrays containing the layer weights.
+      biases: A list of 1-d float32 arrays containing the layer biases.
+
+    Methods:
+      forward: Performs a forward pass on a batch of flattened images. Output
+        is returned as a batch of one-hot vectors of the classes.
+      train: Performs a forward and backward pass and updates the weights and
+        biases.
+      evaluate: Evaluates the network on a batch of examples.
+    """
+
+    def __init__(self, hidden_layers, input_size=784, num_classes=10):
+        """Initializes the neural network.
+
+        Args:
+          hidden_layers: List of ints specifying the sizes of hidden layers. Could
+            be empty.
+          input_size: Length of the input array. The network receives the input
+            image as a flattened 1-d array. Defaults to 784 (28*28), the default
+            image size for MNIST.
+          num_classes: The number of output classes. Defaults to 10.
+        """
+        hidden_layers = [input_size] + hidden_layers + [num_classes]
+        self.weights = []
+        self.biases = []
+        for i in range(len(hidden_layers) - 1):
+            self.weights.append(
+                np.array(
+                    np.random.randn(hidden_layers[i + 1], hidden_layers[i]),
+                    copy=False,
+                    dtype=float32,
+                )
+            )
+            self.biases.append(
+                np.array(
+                    np.random.randn(hidden_layers[i + 1]), copy=False, dtype=float32
+                )
+            )
+
+    def forward(self, x):
+        """Performs the forward pass.
+
+        Args:
+          x: 2-d array of size batch_size x image_size.
+
+        Returns:
+          A 2-d array of size batch_size x num_classes.
+        """
+
+        def sigmoid(x):
+            return 1.0 / (1.0 + np.exp(-x))
+
+        for w, b in zip(self.weights, self.biases):
+            x = sigmoid(np.dot(w, x.T).T + b)
+        return x
+
+    def train(self, x, y, learning_rate=0.01):
+        """Runs a single training pass.
+
+        Args:
+          x: 2-d array of size batch_size x image_size.
+          y: 2-d array of size batch_size x num_classes in one-hot notation.
+          learning_rate: The learning rate.
+        """
+        x = np.array(x, copy=False)
+        y = np.array(y, copy=False)
+
+        def mean_squared_error(x, y):
+            diff = x - y
+            return np.sum(diff * diff) / len(x)
+
+        wb_tensors = self.weights + self.biases
+        with tf.GradientTape() as g:
+            g.watch(wb_tensors)
+            loss = mean_squared_error(self.forward(x), y)
+        gradients = g.gradient(loss, wb_tensors)
+        gradients = [np.asarray(grad) for grad in gradients]
+
+        new_weights_and_biases = []
+        for v, dv in zip(self.weights + self.biases, gradients):
+            new_weights_and_biases.append(v - learning_rate * dv)
+
+        total_len = len(new_weights_and_biases)
+        self.weights = new_weights_and_biases[: total_len // 2]
+        self.biases = new_weights_and_biases[total_len // 2 :]
+
+    def evaluate(self, x, y):
+        """Returns the number of correct predictions.
+
+        Args:
+          x: 2-d array of size batch_size x image_size.
+          y: 2-d array of size batch_size x num_classes.
+
+        Returns:
+          A scalar, the number of correct predictions.
+        """
+        y_actual = np.argmax(y, axis=1)
+        y_predicted = np.argmax(self.forward(x), axis=1)
+        return int(np.sum(np.array(y_actual == y_predicted, copy=False, dtype=int32)))
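A minimal sketch of driving the Model class above by hand, with random stand-in data (shapes follow the class docstring; train() converts inputs with its own trax.tf numpy):

# Sketch: one manual train/eval cycle with the Model class above.
import numpy as onp

model = Model([30])                                    # 784 -> 30 -> 10 network
x = onp.random.rand(50, 784).astype(onp.float32)       # a batch of fake images
y = onp.eye(10, dtype=onp.float32)[onp.random.randint(0, 10, size=50)]  # one-hot
model.train(x, y, learning_rate=0.01)                  # one gradient step on MSE
print(model.evaluate(x, y), "of 50 correct")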
diff --git a/resources/examples/python/nn/mnist/from_scratch/train.py b/resources/examples/python/nn/mnist/from_scratch/train.py
new file mode 100644
index 000000000..ccac475af
--- /dev/null
+++ b/resources/examples/python/nn/mnist/from_scratch/train.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Perform training."""
+from absl import app, flags
+from six.moves import range
+
+from resources.examples.python.nn.mnist.from_scratch import dataset
+from resources.examples.python.nn.mnist.from_scratch import model as model_lib
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer("batch_size", 50, "Batch size.")
+flags.DEFINE_integer("num_training_iters", 10000, "Number of iterations to train for.")
+flags.DEFINE_integer(
+    "validation_steps", 100, "Validation is performed once every this many training steps."
+)
+flags.DEFINE_float("learning_rate", 5.0, "Learning rate.")
+
+
+def train(batch_size, learning_rate, num_training_iters, validation_steps):
+    """Runs the training."""
+    print("Loading data")
+    training_data, validation_data, test_data = dataset.load()
+    print(
+        "Loaded dataset with {} training, {} validation and {} test examples.".format(
+            len(training_data[0]), len(validation_data[0]), len(test_data[0])
+        )
+    )
+
+    assert len(training_data[0]) % batch_size == 0
+    assert len(validation_data[0]) % batch_size == 0
+    assert len(test_data[0]) % batch_size == 0
+
+    def build_iterator(data, infinite=True):
+        """Build the iterator for inputs."""
+        index = 0
+        size = len(data[0])
+        while True:
+            if index + batch_size > size:
+                if infinite:
+                    index = 0
+                else:
+                    return
+            yield (
+                data[0][index : index + batch_size],
+                data[1][index : index + batch_size],
+            )
+            index += batch_size
+
+    train_iter = build_iterator(training_data)
+    model = model_lib.Model([30])
+
+    for i in range(num_training_iters):
+        train_x, train_y = next(train_iter)
+        model.train(train_x, train_y, learning_rate)
+        if (i + 1) % validation_steps == 0:
+            validation_iter = build_iterator(validation_data, infinite=False)
+            correct_predictions = 0
+            for valid_x, valid_y in validation_iter:
+                correct_predictions += model.evaluate(valid_x, valid_y)
+            print(
+                "{}/{} correct validation predictions.".format(
+                    correct_predictions, len(validation_data[0])
+                )
+            )
+
+
+def main(unused_argv):
+    train(
+        FLAGS.batch_size,
+        FLAGS.learning_rate,
+        FLAGS.num_training_iters,
+        FLAGS.validation_steps,
+    )
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/resources/examples/python/nn/mnist/train.py b/resources/examples/python/nn/mnist/train.py
new file mode 100644
index 000000000..96732bc15
--- /dev/null
+++ b/resources/examples/python/nn/mnist/train.py
@@ -0,0 +1,80 @@
+import trax.fastmath as fastmath
+
+from resources.examples.python.base import (
+    Dataset,
+    DeviceType,
+    Splits,
+    create_batch_generator,
+    evaluate_model,
+    initialize_model,
+    load_dataset,
+    train_model,
+)
+from trax import layers as tl
+from trax import optimizers
+from trax.trainers import jax as trainers
+
+
+def build_model():
+    # Build your model with loss function
+    model = tl.Serial(
+        tl.Dense(128, use_bias=True), tl.Relu(), tl.Dense(10, use_bias=False)
+    )
+    model_with_loss = tl.Serial(model, tl.CrossEntropyLossWithLogSoftmax())
+    return model_with_loss
+
+
+def main():
+    # Default setup
DEFAULT_BATCH_SIZE = 8 + STEPS_NUMBER = 20_000 + + # Load data + X, y = load_dataset(Dataset.MNIST.value) + batch_generator = create_batch_generator( + X, y, batch_size=DEFAULT_BATCH_SIZE, seed=42 + ) + example_batch = next(batch_generator) + + # Build and initialize model + model_with_loss = build_model() + initialize_model(model_with_loss, example_batch) + + # Setup optimizer and trainers + optimizer = optimizers.Adam(0.001) + trainer = trainers.Trainer(model_with_loss, optimizer) + + base_rng = fastmath.random.get_prng(0) + + # Run training on CPU and/or GPU + train_model( + trainer, + batch_generator, + STEPS_NUMBER, + base_rng, + device_type=DeviceType.GPU.value, + ) + + # Load test data + test_data, test_labels = load_dataset( + dataset_name=Dataset.MNIST.value, split=Splits.TEST.value + ) + + # Create batch generator for test data + test_batch_gen = create_batch_generator( + test_data, test_labels, None, DEFAULT_BATCH_SIZE, 0 + ) + + # Evaluate model on a test set + test_results = evaluate_model( + trainer=trainer, + batch_gen=test_batch_gen, + device_type=DeviceType.CPU.value, + num_batches=100, + ) + + print(f"Final test accuracy: {test_results['accuracy']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/trax/supervised/configs/bert.gin b/resources/supervised/configs/bert.gin similarity index 95% rename from trax/supervised/configs/bert.gin rename to resources/supervised/configs/bert.gin index 9f46a31d9..97e7827c0 100644 --- a/trax/supervised/configs/bert.gin +++ b/resources/supervised/configs/bert.gin @@ -17,8 +17,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.layers.metrics # Parameters for TFDS data pipeline: diff --git a/trax/supervised/configs/bert_glue_classification.gin b/resources/supervised/configs/bert_glue_classification.gin similarity index 100% rename from trax/supervised/configs/bert_glue_classification.gin rename to resources/supervised/configs/bert_glue_classification.gin diff --git a/trax/supervised/configs/bert_glue_regression.gin b/resources/supervised/configs/bert_glue_regression.gin similarity index 100% rename from trax/supervised/configs/bert_glue_regression.gin rename to resources/supervised/configs/bert_glue_regression.gin diff --git a/trax/supervised/configs/bert_glue_sweep_regression_task.yaml b/resources/supervised/configs/bert_glue_sweep_regression_task.yaml similarity index 100% rename from trax/supervised/configs/bert_glue_sweep_regression_task.yaml rename to resources/supervised/configs/bert_glue_sweep_regression_task.yaml diff --git a/trax/supervised/configs/bert_glue_sweep_single_sentence.yaml b/resources/supervised/configs/bert_glue_sweep_single_sentence.yaml similarity index 100% rename from trax/supervised/configs/bert_glue_sweep_single_sentence.yaml rename to resources/supervised/configs/bert_glue_sweep_single_sentence.yaml diff --git a/trax/supervised/configs/bert_glue_sweep_two_sentences.yaml b/resources/supervised/configs/bert_glue_sweep_two_sentences.yaml similarity index 100% rename from trax/supervised/configs/bert_glue_sweep_two_sentences.yaml rename to resources/supervised/configs/bert_glue_sweep_two_sentences.yaml diff --git a/trax/supervised/configs/bert_pretraining.gin b/resources/supervised/configs/bert_pretraining.gin similarity index 100% rename from trax/supervised/configs/bert_pretraining.gin rename to 
resources/supervised/configs/bert_pretraining.gin diff --git a/trax/supervised/configs/bert_pretraining_onlymlm.gin b/resources/supervised/configs/bert_pretraining_onlymlm.gin similarity index 100% rename from trax/supervised/configs/bert_pretraining_onlymlm.gin rename to resources/supervised/configs/bert_pretraining_onlymlm.gin diff --git a/trax/supervised/configs/bert_pretraining_onlynsp.gin b/resources/supervised/configs/bert_pretraining_onlynsp.gin similarity index 100% rename from trax/supervised/configs/bert_pretraining_onlynsp.gin rename to resources/supervised/configs/bert_pretraining_onlynsp.gin diff --git a/trax/supervised/configs/c4.gin b/resources/supervised/configs/c4.gin similarity index 100% rename from trax/supervised/configs/c4.gin rename to resources/supervised/configs/c4.gin diff --git a/trax/supervised/configs/c4_pretrain_16gb_adafactor.gin b/resources/supervised/configs/c4_pretrain_16gb_adafactor.gin similarity index 99% rename from trax/supervised/configs/c4_pretrain_16gb_adafactor.gin rename to resources/supervised/configs/c4_pretrain_16gb_adafactor.gin index df43e9f8f..b26c8cb08 100644 --- a/trax/supervised/configs/c4_pretrain_16gb_adafactor.gin +++ b/resources/supervised/configs/c4_pretrain_16gb_adafactor.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib include 'c4_trax_data.gin' diff --git a/trax/supervised/configs/c4_trax_data.gin b/resources/supervised/configs/c4_trax_data.gin similarity index 98% rename from trax/supervised/configs/c4_trax_data.gin rename to resources/supervised/configs/c4_trax_data.gin index d0b5ad3a8..cda964022 100644 --- a/trax/supervised/configs/c4_trax_data.gin +++ b/resources/supervised/configs/c4_trax_data.gin @@ -13,7 +13,7 @@ # limitations under the License. 
import trax.data -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Macros: # ============================================================================== diff --git a/trax/supervised/configs/cond_skipping_transformer_lm1b.gin b/resources/supervised/configs/cond_skipping_transformer_lm1b.gin similarity index 98% rename from trax/supervised/configs/cond_skipping_transformer_lm1b.gin rename to resources/supervised/configs/cond_skipping_transformer_lm1b.gin index d5dbd7b7c..280c0d0d0 100644 --- a/trax/supervised/configs/cond_skipping_transformer_lm1b.gin +++ b/resources/supervised/configs/cond_skipping_transformer_lm1b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/gru_copy.gin b/resources/supervised/configs/gru_copy.gin similarity index 97% rename from trax/supervised/configs/gru_copy.gin rename to resources/supervised/configs/gru_copy.gin index e40e13d00..d31d862d7 100644 --- a/trax/supervised/configs/gru_copy.gin +++ b/resources/supervised/configs/gru_copy.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib n_symbols = 32 length = 16 diff --git a/trax/supervised/configs/hourglass_cifar10.gin b/resources/supervised/configs/hourglass_cifar10.gin similarity index 98% rename from trax/supervised/configs/hourglass_cifar10.gin rename to resources/supervised/configs/hourglass_cifar10.gin index aca156772..16b85d86b 100644 --- a/trax/supervised/configs/hourglass_cifar10.gin +++ b/resources/supervised/configs/hourglass_cifar10.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib train_steps = 100000 diff --git a/trax/supervised/configs/hourglass_enwik8.gin b/resources/supervised/configs/hourglass_enwik8.gin similarity index 98% rename from trax/supervised/configs/hourglass_enwik8.gin rename to resources/supervised/configs/hourglass_enwik8.gin index 75914641a..5ba384084 100644 --- a/trax/supervised/configs/hourglass_enwik8.gin +++ b/resources/supervised/configs/hourglass_enwik8.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: diff --git a/trax/supervised/configs/hourglass_imagenet32.gin b/resources/supervised/configs/hourglass_imagenet32.gin similarity index 98% rename from trax/supervised/configs/hourglass_imagenet32.gin rename to resources/supervised/configs/hourglass_imagenet32.gin index 95782ccce..5764a5f7d 100644 --- a/trax/supervised/configs/hourglass_imagenet32.gin +++ b/resources/supervised/configs/hourglass_imagenet32.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/hourglass_imagenet64.gin b/resources/supervised/configs/hourglass_imagenet64.gin similarity index 98% rename from trax/supervised/configs/hourglass_imagenet64.gin rename to 
resources/supervised/configs/hourglass_imagenet64.gin index f2f3515a4..41147eb60 100644 --- a/trax/supervised/configs/hourglass_imagenet64.gin +++ b/resources/supervised/configs/hourglass_imagenet64.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters that will vary between experiments: # ============================================================================== diff --git a/trax/supervised/configs/hourglass_wiki40b.gin b/resources/supervised/configs/hourglass_wiki40b.gin similarity index 98% rename from trax/supervised/configs/hourglass_wiki40b.gin rename to resources/supervised/configs/hourglass_wiki40b.gin index 9d97ecf88..432035afd 100644 --- a/trax/supervised/configs/hourglass_wiki40b.gin +++ b/resources/supervised/configs/hourglass_wiki40b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Macros: # ============================================================================== diff --git a/trax/supervised/configs/layerdrop_every_transformer_lm1b.gin b/resources/supervised/configs/layerdrop_every_transformer_lm1b.gin similarity index 98% rename from trax/supervised/configs/layerdrop_every_transformer_lm1b.gin rename to resources/supervised/configs/layerdrop_every_transformer_lm1b.gin index 593219a55..5dfdf0300 100644 --- a/trax/supervised/configs/layerdrop_every_transformer_lm1b.gin +++ b/resources/supervised/configs/layerdrop_every_transformer_lm1b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/layerdrop_transformer_lm1b.gin b/resources/supervised/configs/layerdrop_transformer_lm1b.gin similarity index 98% rename from trax/supervised/configs/layerdrop_transformer_lm1b.gin rename to resources/supervised/configs/layerdrop_transformer_lm1b.gin index bf7077d44..a1187a90a 100644 --- a/trax/supervised/configs/layerdrop_transformer_lm1b.gin +++ b/resources/supervised/configs/layerdrop_transformer_lm1b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin b/resources/supervised/configs/layerdrop_ushape_transformer_lm1b.gin similarity index 98% rename from trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin rename to resources/supervised/configs/layerdrop_ushape_transformer_lm1b.gin index f3c002c2b..b6d849331 100644 --- a/trax/supervised/configs/layerdrop_ushape_transformer_lm1b.gin +++ b/resources/supervised/configs/layerdrop_ushape_transformer_lm1b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/lstm_lm1b.gin b/resources/supervised/configs/lstm_lm1b.gin similarity index 97% rename from trax/supervised/configs/lstm_lm1b.gin rename to 
resources/supervised/configs/lstm_lm1b.gin index 95cac448b..457722289 100644 --- a/trax/supervised/configs/lstm_lm1b.gin +++ b/resources/supervised/configs/lstm_lm1b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin b/resources/supervised/configs/lstm_seq2seq_wmt_ende.gin similarity index 96% rename from trax/supervised/configs/lstm_seq2seq_wmt_ende.gin rename to resources/supervised/configs/lstm_seq2seq_wmt_ende.gin index 0c0b16419..627b3a824 100644 --- a/trax/supervised/configs/lstm_seq2seq_wmt_ende.gin +++ b/resources/supervised/configs/lstm_seq2seq_wmt_ende.gin @@ -15,8 +15,8 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/mlp_mnist.gin b/resources/supervised/configs/mlp_mnist.gin similarity index 94% rename from trax/supervised/configs/mlp_mnist.gin rename to resources/supervised/configs/mlp_mnist.gin index 52e4934a9..fda8d1b3a 100644 --- a/trax/supervised/configs/mlp_mnist.gin +++ b/resources/supervised/configs/mlp_mnist.gin @@ -15,8 +15,8 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/reformer_addition.gin b/resources/supervised/configs/reformer_addition.gin similarity index 98% rename from trax/supervised/configs/reformer_addition.gin rename to resources/supervised/configs/reformer_addition.gin index e96f0edc8..c51b57e6c 100644 --- a/trax/supervised/configs/reformer_addition.gin +++ b/resources/supervised/configs/reformer_addition.gin @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import trax.data +import trax.data.loader.tf.base import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib vocab_size = 13 # For addition, base = vocab_size - 3. diff --git a/trax/supervised/configs/reformer_bair_robot_pushing.gin b/resources/supervised/configs/reformer_bair_robot_pushing.gin similarity index 96% rename from trax/supervised/configs/reformer_bair_robot_pushing.gin rename to resources/supervised/configs/reformer_bair_robot_pushing.gin index 0b7bbfc15..87a82716e 100644 --- a/trax/supervised/configs/reformer_bair_robot_pushing.gin +++ b/resources/supervised/configs/reformer_bair_robot_pushing.gin @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import trax.data +import trax.data.loader.tf.base import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters that will vary between experiments: # ============================================================================== @@ -28,7 +28,7 @@ dropout = 0.0 # Parameters for batcher: # ============================================================================== -batcher.data_streams = @data.data_streams +batcher.data_streams = @data_streams batcher.batch_size_per_device = 1 batcher.eval_batch_size = 8 batcher.max_eval_length = 24576 diff --git a/trax/supervised/configs/reformer_c4.gin b/resources/supervised/configs/reformer_c4.gin similarity index 95% rename from trax/supervised/configs/reformer_c4.gin rename to resources/supervised/configs/reformer_c4.gin index 8a694a609..5d58c3696 100644 --- a/trax/supervised/configs/reformer_c4.gin +++ b/resources/supervised/configs/reformer_c4.gin @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import trax.data +import trax.data.loader.tf.base +import trax.data.preprocessing.tf.c4 import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib MAX_EVAL_LENGTH = 1024 # Get back to 65536 BUCKETS = ([1025], [8, 1]) @@ -47,7 +48,7 @@ LSHSelfAttention.n_hashes = 2 # Parameters for batcher: # ============================================================================== -batcher.data_streams = @data.data_streams +batcher.data_streams = @data_streams batcher.batch_size_per_device = 8 batcher.eval_batch_size = 8 batcher.max_eval_length = %MAX_EVAL_LENGTH @@ -59,7 +60,7 @@ data_streams.data_dir = None data_streams.dataset_name = 'c4' data_streams.input_name = 'targets' data_streams.target_name = 'text' -data_streams.preprocess_fn=@data.c4_preprocess +data_streams.preprocess_fn=@c4_preprocess # Parameters for c4_preprocess: # ============================================================================== diff --git a/trax/supervised/configs/reformer_cifar10.gin b/resources/supervised/configs/reformer_cifar10.gin similarity index 91% rename from trax/supervised/configs/reformer_cifar10.gin rename to resources/supervised/configs/reformer_cifar10.gin index b301040b6..3b4d23726 100644 --- a/trax/supervised/configs/reformer_cifar10.gin +++ b/resources/supervised/configs/reformer_cifar10.gin @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import trax.data +import trax.data.loader.tf.base +import trax.data.preprocessing.tf.c4 import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters that will vary between experiments: # ============================================================================== @@ -25,7 +26,7 @@ n_layers = 9 # Parameters for batcher: # ============================================================================== -batcher.data_streams = @data.data_streams +batcher.data_streams = @data_streams batcher.batch_size_per_device = 1 batcher.eval_batch_size = 8 batcher.max_eval_length = 12288 # 64 * 64 * 3 @@ -34,7 +35,7 @@ batcher.max_eval_length = 12288 # 64 * 64 * 3 # ============================================================================== data_streams.data_dir = None data_streams.dataset_name = 'cifar10' -data_streams.preprocess_fn = @data.cifar10_augmentation_flatten_preprocess +data_streams.preprocess_fn = @cifar10_augmentation_flatten_preprocess # Parameters for multifactor: # ============================================================================== diff --git a/trax/supervised/configs/reformer_copy.gin b/resources/supervised/configs/reformer_copy.gin similarity index 95% rename from trax/supervised/configs/reformer_copy.gin rename to resources/supervised/configs/reformer_copy.gin index 891f24c3b..63b0d103c 100644 --- a/trax/supervised/configs/reformer_copy.gin +++ b/resources/supervised/configs/reformer_copy.gin @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import trax.data +import trax.data.loader.tf.base +import trax.data.preprocessing.inputs import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib vocab_size = 13 # For addition, base = vocab_size - 3. 
max_len = 32 @@ -99,8 +100,8 @@ TransformerLM.vocab_size = %vocab_size # Parameters for train: # ============================================================================== -train.inputs = @trax.data.sequence_copy_inputs -# train.inputs = @trax.data.addition_inputs +train.inputs = @sequence_copy_inputs +# train.inputs = @addition_inputs train.eval_frequency = 100 train.eval_steps = 10 train.optimizer = @trax.optimizers.Adam diff --git a/trax/supervised/configs/reformer_enwik8.gin b/resources/supervised/configs/reformer_enwik8.gin similarity index 98% rename from trax/supervised/configs/reformer_enwik8.gin rename to resources/supervised/configs/reformer_enwik8.gin index 20938af95..95b7c1f5b 100644 --- a/trax/supervised/configs/reformer_enwik8.gin +++ b/resources/supervised/configs/reformer_enwik8.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters that will vary between experiments: # ============================================================================== diff --git a/trax/supervised/configs/reformer_imagenet64.gin b/resources/supervised/configs/reformer_imagenet64.gin similarity index 98% rename from trax/supervised/configs/reformer_imagenet64.gin rename to resources/supervised/configs/reformer_imagenet64.gin index fbd5eba5c..eae5fe7f8 100644 --- a/trax/supervised/configs/reformer_imagenet64.gin +++ b/resources/supervised/configs/reformer_imagenet64.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters that will vary between experiments: # ============================================================================== diff --git a/trax/supervised/configs/reformer_imagenet64_testing.gin b/resources/supervised/configs/reformer_imagenet64_testing.gin similarity index 98% rename from trax/supervised/configs/reformer_imagenet64_testing.gin rename to resources/supervised/configs/reformer_imagenet64_testing.gin index 878649ee9..9df2a9537 100644 --- a/trax/supervised/configs/reformer_imagenet64_testing.gin +++ b/resources/supervised/configs/reformer_imagenet64_testing.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters that will vary between experiments: # ============================================================================== diff --git a/trax/supervised/configs/reformer_pc_enpl.gin b/resources/supervised/configs/reformer_pc_enpl.gin similarity index 98% rename from trax/supervised/configs/reformer_pc_enpl.gin rename to resources/supervised/configs/reformer_pc_enpl.gin index e306bb12e..7ac599e34 100644 --- a/trax/supervised/configs/reformer_pc_enpl.gin +++ b/resources/supervised/configs/reformer_pc_enpl.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib import t5.data.preprocessors diff --git a/trax/supervised/configs/reformer_wmt_ende.gin b/resources/supervised/configs/reformer_wmt_ende.gin similarity index 90% rename from trax/supervised/configs/reformer_wmt_ende.gin rename to resources/supervised/configs/reformer_wmt_ende.gin index bb0a873f6..475989238 100644 --- a/trax/supervised/configs/reformer_wmt_ende.gin +++ 
b/resources/supervised/configs/reformer_wmt_ende.gin @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import trax.data +import trax.data.loader.tf.base +import trax.data.preprocessing.tf.wmt import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== -batcher.data_streams = @data.data_streams +batcher.data_streams = @data_streams batcher.batch_size_per_device = 256 batcher.eval_batch_size = 64 batcher.max_eval_length = 512 @@ -30,8 +31,8 @@ batcher.id_to_mask = 0 # Parameters for data_streams: # ============================================================================== data_streams.data_dir = None -data_streams.dataset_name = 't2t_translate_ende_wmt32k' -data_streams.preprocess_fn = @data.wmt_preprocess +data_streams.dataset_name = 't2t_wmt14_translate/de-en' +data_streams.preprocess_fn = @wmt_preprocess # Parameters for multifactor: # ============================================================================== diff --git a/trax/supervised/configs/reformer_wmt_ende_big.gin b/resources/supervised/configs/reformer_wmt_ende_big.gin similarity index 90% rename from trax/supervised/configs/reformer_wmt_ende_big.gin rename to resources/supervised/configs/reformer_wmt_ende_big.gin index b807a6ae7..82bbd7c66 100644 --- a/trax/supervised/configs/reformer_wmt_ende_big.gin +++ b/resources/supervised/configs/reformer_wmt_ende_big.gin @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import trax.data +import trax.data.loader.tf.base +import trax.data.preprocessing.tf.wmt import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== -batcher.data_streams = @data.data_streams +batcher.data_streams = @data_streams batcher.batch_size_per_device = 256 batcher.eval_batch_size = 64 batcher.max_eval_length = 512 @@ -30,8 +31,8 @@ batcher.id_to_mask = 0 # Parameters for data_streams: # ============================================================================== data_streams.data_dir = None -data_streams.dataset_name = 't2t_translate_ende_wmt32k' -data_streams.preprocess_fn = @data.wmt_preprocess +data_streams.dataset_name = 'wmt14_translate/de-en' +data_streams.preprocess_fn = @wmt_preprocess # Parameters for multifactor: # ============================================================================== diff --git a/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin b/resources/supervised/configs/resnet50_frn_imagenet_8gb.gin similarity index 98% rename from trax/supervised/configs/resnet50_frn_imagenet_8gb.gin rename to resources/supervised/configs/resnet50_frn_imagenet_8gb.gin index d08d1b378..720e98c52 100644 --- a/trax/supervised/configs/resnet50_frn_imagenet_8gb.gin +++ b/resources/supervised/configs/resnet50_frn_imagenet_8gb.gin @@ -16,7 +16,7 @@ import trax.data import trax.supervised.lr_schedules import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin 
b/resources/supervised/configs/resnet50_imagenet_8gb_testing.gin similarity index 97% rename from trax/supervised/configs/resnet50_imagenet_8gb_testing.gin rename to resources/supervised/configs/resnet50_imagenet_8gb_testing.gin index 5ffdc2a9d..ab252c285 100644 --- a/trax/supervised/configs/resnet50_imagenet_8gb_testing.gin +++ b/resources/supervised/configs/resnet50_imagenet_8gb_testing.gin @@ -16,7 +16,7 @@ import trax.data import trax.supervised.lr_schedules import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin b/resources/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin similarity index 98% rename from trax/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin rename to resources/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin index 7e27fa945..2a1eef59e 100644 --- a/trax/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin +++ b/resources/supervised/configs/rezero_wmt_ende_16gb_adafactor_testing.gin @@ -15,7 +15,7 @@ import trax.data import trax.models.research.rezero import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/rse_addition.gin b/resources/supervised/configs/rse_addition.gin similarity index 98% rename from trax/supervised/configs/rse_addition.gin rename to resources/supervised/configs/rse_addition.gin index f1060c713..bd226353a 100644 --- a/trax/supervised/configs/rse_addition.gin +++ b/resources/supervised/configs/rse_addition.gin @@ -16,7 +16,7 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib vocab_size = 5 # For arithmetic operations, base = vocab_size - 3. 
batch_size = 128 diff --git a/trax/supervised/configs/rse_addition_sweep.yaml b/resources/supervised/configs/rse_addition_sweep.yaml similarity index 100% rename from trax/supervised/configs/rse_addition_sweep.yaml rename to resources/supervised/configs/rse_addition_sweep.yaml diff --git a/trax/supervised/configs/scientific_papers_reformer_lm.gin b/resources/supervised/configs/scientific_papers_reformer_lm.gin similarity index 99% rename from trax/supervised/configs/scientific_papers_reformer_lm.gin rename to resources/supervised/configs/scientific_papers_reformer_lm.gin index b472c90e2..f1868f7f8 100644 --- a/trax/supervised/configs/scientific_papers_reformer_lm.gin +++ b/resources/supervised/configs/scientific_papers_reformer_lm.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Macros: # ============================================================================== diff --git a/trax/supervised/configs/scientific_papers_terraformer.gin b/resources/supervised/configs/scientific_papers_terraformer.gin similarity index 99% rename from trax/supervised/configs/scientific_papers_terraformer.gin rename to resources/supervised/configs/scientific_papers_terraformer.gin index 53fbbd7ce..4bc29177d 100644 --- a/trax/supervised/configs/scientific_papers_terraformer.gin +++ b/resources/supervised/configs/scientific_papers_terraformer.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Macros: # ============================================================================== diff --git a/trax/supervised/configs/scientific_papers_terraformer_favor.gin b/resources/supervised/configs/scientific_papers_terraformer_favor.gin similarity index 99% rename from trax/supervised/configs/scientific_papers_terraformer_favor.gin rename to resources/supervised/configs/scientific_papers_terraformer_favor.gin index 437fd5b9c..4fa0f3719 100644 --- a/trax/supervised/configs/scientific_papers_terraformer_favor.gin +++ b/resources/supervised/configs/scientific_papers_terraformer_favor.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Macros: # ============================================================================== diff --git a/trax/supervised/configs/scientific_papers_terraformer_pretrained.gin b/resources/supervised/configs/scientific_papers_terraformer_pretrained.gin similarity index 99% rename from trax/supervised/configs/scientific_papers_terraformer_pretrained.gin rename to resources/supervised/configs/scientific_papers_terraformer_pretrained.gin index 387f65e73..f823aa09c 100644 --- a/trax/supervised/configs/scientific_papers_terraformer_pretrained.gin +++ b/resources/supervised/configs/scientific_papers_terraformer_pretrained.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Macros: # ============================================================================== diff --git a/trax/supervised/configs/skipping_transformer_lm1b.gin b/resources/supervised/configs/skipping_transformer_lm1b.gin similarity index 98% rename from trax/supervised/configs/skipping_transformer_lm1b.gin rename to resources/supervised/configs/skipping_transformer_lm1b.gin index d5dbd7b7c..280c0d0d0 
100644 --- a/trax/supervised/configs/skipping_transformer_lm1b.gin +++ b/resources/supervised/configs/skipping_transformer_lm1b.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/sp_sweep.yaml b/resources/supervised/configs/sp_sweep.yaml similarity index 100% rename from trax/supervised/configs/sp_sweep.yaml rename to resources/supervised/configs/sp_sweep.yaml diff --git a/trax/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin b/resources/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin similarity index 99% rename from trax/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin rename to resources/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin index d62d1cfac..7cf6ba02f 100644 --- a/trax/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin +++ b/resources/supervised/configs/sparse_c4_pretrain_16gb_adafactor.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib include 'c4_trax_data.gin' diff --git a/trax/supervised/configs/sparse_lm1b_pretrain_16gb.gin b/resources/supervised/configs/sparse_lm1b_pretrain_16gb.gin similarity index 98% rename from trax/supervised/configs/sparse_lm1b_pretrain_16gb.gin rename to resources/supervised/configs/sparse_lm1b_pretrain_16gb.gin index cfd546fd8..3740056ac 100644 --- a/trax/supervised/configs/sparse_lm1b_pretrain_16gb.gin +++ b/resources/supervised/configs/sparse_lm1b_pretrain_16gb.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib ff_chunk_size = 0 diff --git a/trax/supervised/configs/t5_aqua_parallel.gin b/resources/supervised/configs/t5_aqua_parallel.gin similarity index 99% rename from trax/supervised/configs/t5_aqua_parallel.gin rename to resources/supervised/configs/t5_aqua_parallel.gin index 90bdd8dc6..08330f1d5 100644 --- a/trax/supervised/configs/t5_aqua_parallel.gin +++ b/resources/supervised/configs/t5_aqua_parallel.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_drop.gin b/resources/supervised/configs/t5_drop.gin similarity index 98% rename from trax/supervised/configs/t5_drop.gin rename to resources/supervised/configs/t5_drop.gin index 514994e45..760164f1f 100644 --- a/trax/supervised/configs/t5_drop.gin +++ b/resources/supervised/configs/t5_drop.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_glue_classification.gin b/resources/supervised/configs/t5_glue_classification.gin similarity index 98% rename from trax/supervised/configs/t5_glue_classification.gin rename to resources/supervised/configs/t5_glue_classification.gin index 
374bf1b78..f0edccbe3 100644 --- a/trax/supervised/configs/t5_glue_classification.gin +++ b/resources/supervised/configs/t5_glue_classification.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_glue_classification_mnli.gin b/resources/supervised/configs/t5_glue_classification_mnli.gin similarity index 98% rename from trax/supervised/configs/t5_glue_classification_mnli.gin rename to resources/supervised/configs/t5_glue_classification_mnli.gin index 884543e85..245c91adf 100644 --- a/trax/supervised/configs/t5_glue_classification_mnli.gin +++ b/resources/supervised/configs/t5_glue_classification_mnli.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_glue_classification_parallel.gin b/resources/supervised/configs/t5_glue_classification_parallel.gin similarity index 98% rename from trax/supervised/configs/t5_glue_classification_parallel.gin rename to resources/supervised/configs/t5_glue_classification_parallel.gin index 0920f3f21..2383a49ef 100644 --- a/trax/supervised/configs/t5_glue_classification_parallel.gin +++ b/resources/supervised/configs/t5_glue_classification_parallel.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_glue_classification_two_constants.gin b/resources/supervised/configs/t5_glue_classification_two_constants.gin similarity index 98% rename from trax/supervised/configs/t5_glue_classification_two_constants.gin rename to resources/supervised/configs/t5_glue_classification_two_constants.gin index b93a44945..7a6ccdcd4 100644 --- a/trax/supervised/configs/t5_glue_classification_two_constants.gin +++ b/resources/supervised/configs/t5_glue_classification_two_constants.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa.gin b/resources/supervised/configs/t5_mathqa.gin similarity index 98% rename from trax/supervised/configs/t5_mathqa.gin rename to resources/supervised/configs/t5_mathqa.gin index e0698a877..2b1652d98 100644 --- a/trax/supervised/configs/t5_mathqa.gin +++ b/resources/supervised/configs/t5_mathqa.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git 
a/trax/supervised/configs/t5_mathqa_drop_loop.gin b/resources/supervised/configs/t5_mathqa_drop_loop.gin similarity index 99% rename from trax/supervised/configs/t5_mathqa_drop_loop.gin rename to resources/supervised/configs/t5_mathqa_drop_loop.gin index 7f0797dc8..6506245fa 100644 --- a/trax/supervised/configs/t5_mathqa_drop_loop.gin +++ b/resources/supervised/configs/t5_mathqa_drop_loop.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa_drop_sweep.yaml b/resources/supervised/configs/t5_mathqa_drop_sweep.yaml similarity index 100% rename from trax/supervised/configs/t5_mathqa_drop_sweep.yaml rename to resources/supervised/configs/t5_mathqa_drop_sweep.yaml diff --git a/trax/supervised/configs/t5_mathqa_multi.gin b/resources/supervised/configs/t5_mathqa_multi.gin similarity index 98% rename from trax/supervised/configs/t5_mathqa_multi.gin rename to resources/supervised/configs/t5_mathqa_multi.gin index 05b229b09..af304e1a5 100644 --- a/trax/supervised/configs/t5_mathqa_multi.gin +++ b/resources/supervised/configs/t5_mathqa_multi.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa_parallel.gin b/resources/supervised/configs/t5_mathqa_parallel.gin similarity index 99% rename from trax/supervised/configs/t5_mathqa_parallel.gin rename to resources/supervised/configs/t5_mathqa_parallel.gin index 8b12fee12..c4edc49ac 100644 --- a/trax/supervised/configs/t5_mathqa_parallel.gin +++ b/resources/supervised/configs/t5_mathqa_parallel.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa_parallel_full.gin b/resources/supervised/configs/t5_mathqa_parallel_full.gin similarity index 99% rename from trax/supervised/configs/t5_mathqa_parallel_full.gin rename to resources/supervised/configs/t5_mathqa_parallel_full.gin index 0e2678f21..9cda0fb22 100644 --- a/trax/supervised/configs/t5_mathqa_parallel_full.gin +++ b/resources/supervised/configs/t5_mathqa_parallel_full.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa_parallel_full_correct_order.gin b/resources/supervised/configs/t5_mathqa_parallel_full_correct_order.gin similarity index 99% rename from trax/supervised/configs/t5_mathqa_parallel_full_correct_order.gin rename to resources/supervised/configs/t5_mathqa_parallel_full_correct_order.gin index bea97effe..d78e7321e 100644 --- 
a/trax/supervised/configs/t5_mathqa_parallel_full_correct_order.gin +++ b/resources/supervised/configs/t5_mathqa_parallel_full_correct_order.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa_parallel_full_order.gin b/resources/supervised/configs/t5_mathqa_parallel_full_order.gin similarity index 99% rename from trax/supervised/configs/t5_mathqa_parallel_full_order.gin rename to resources/supervised/configs/t5_mathqa_parallel_full_order.gin index 1a2ddf25d..7138b1a97 100644 --- a/trax/supervised/configs/t5_mathqa_parallel_full_order.gin +++ b/resources/supervised/configs/t5_mathqa_parallel_full_order.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin b/resources/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin similarity index 99% rename from trax/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin rename to resources/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin index f270e8b99..0d127a99a 100644 --- a/trax/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin +++ b/resources/supervised/configs/t5_mathqa_parallel_with_drop_annot.gin @@ -16,8 +16,8 @@ import trax.data import trax.layers import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib import trax.models.research.bert import trax.layers.metrics diff --git a/trax/supervised/configs/t5_sweep.yaml b/resources/supervised/configs/t5_sweep.yaml similarity index 100% rename from trax/supervised/configs/t5_sweep.yaml rename to resources/supervised/configs/t5_sweep.yaml diff --git a/trax/supervised/configs/t5_sweep_temperature.yaml b/resources/supervised/configs/t5_sweep_temperature.yaml similarity index 100% rename from trax/supervised/configs/t5_sweep_temperature.yaml rename to resources/supervised/configs/t5_sweep_temperature.yaml diff --git a/trax/supervised/configs/terraformer_c4_medium.gin b/resources/supervised/configs/terraformer_c4_medium.gin similarity index 99% rename from trax/supervised/configs/terraformer_c4_medium.gin rename to resources/supervised/configs/terraformer_c4_medium.gin index 6537b35c5..beb09281c 100644 --- a/trax/supervised/configs/terraformer_c4_medium.gin +++ b/resources/supervised/configs/terraformer_c4_medium.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib include 'c4_trax_data.gin' diff --git a/trax/supervised/configs/terraformer_copy.gin b/resources/supervised/configs/terraformer_copy.gin similarity index 97% rename from trax/supervised/configs/terraformer_copy.gin rename to resources/supervised/configs/terraformer_copy.gin index 83a59a6f2..a6379c3c1 100644 --- a/trax/supervised/configs/terraformer_copy.gin +++ b/resources/supervised/configs/terraformer_copy.gin @@ 
-19,7 +19,7 @@ include 'reformer_copy.gin' import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for ConfigurableTerraformer: # ============================================================================== diff --git a/trax/supervised/configs/terraformer_copy_self_attn.gin b/resources/supervised/configs/terraformer_copy_self_attn.gin similarity index 100% rename from trax/supervised/configs/terraformer_copy_self_attn.gin rename to resources/supervised/configs/terraformer_copy_self_attn.gin diff --git a/trax/supervised/configs/terraformer_purelsh_copy.gin b/resources/supervised/configs/terraformer_purelsh_copy.gin similarity index 98% rename from trax/supervised/configs/terraformer_purelsh_copy.gin rename to resources/supervised/configs/terraformer_purelsh_copy.gin index 75c844adb..401c0cd53 100644 --- a/trax/supervised/configs/terraformer_purelsh_copy.gin +++ b/resources/supervised/configs/terraformer_purelsh_copy.gin @@ -19,7 +19,7 @@ include 'terraformer_copy.gin' import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for PureLSHSelfAttention: diff --git a/trax/supervised/configs/terraformer_wmt_ende.gin b/resources/supervised/configs/terraformer_wmt_ende.gin similarity index 96% rename from trax/supervised/configs/terraformer_wmt_ende.gin rename to resources/supervised/configs/terraformer_wmt_ende.gin index 91566cc65..9540a0277 100644 --- a/trax/supervised/configs/terraformer_wmt_ende.gin +++ b/resources/supervised/configs/terraformer_wmt_ende.gin @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import trax.data +import trax.data.loader.tf.base import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib +import trax.data.preprocessing.tf.wmt MAX_EVAL_LENGTH = 512 # Get back to 65536 BUCKETS = ([513], [64, 1]) @@ -34,7 +35,7 @@ batcher.id_to_mask = 0 # Parameters for data_streams: # ============================================================================== data_streams.data_dir = None -data_streams.dataset_name = 't2t_translate_ende_wmt32k' +data_streams.dataset_name = 't2t_wmt14_translate/de-en' data_streams.preprocess_fn = @data.wmt_preprocess # Parameters for multifactor: diff --git a/trax/supervised/configs/transformer_big_lm1b_8gb.gin b/resources/supervised/configs/transformer_big_lm1b_8gb.gin similarity index 98% rename from trax/supervised/configs/transformer_big_lm1b_8gb.gin rename to resources/supervised/configs/transformer_big_lm1b_8gb.gin index f97a96ef5..25b900d46 100644 --- a/trax/supervised/configs/transformer_big_lm1b_8gb.gin +++ b/resources/supervised/configs/transformer_big_lm1b_8gb.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/transformer_finetune_squad_16gb.gin b/resources/supervised/configs/transformer_finetune_squad_16gb.gin similarity index 100% rename from trax/supervised/configs/transformer_finetune_squad_16gb.gin rename to resources/supervised/configs/transformer_finetune_squad_16gb.gin diff --git a/trax/supervised/configs/transformer_imdb_8gb.gin b/resources/supervised/configs/transformer_imdb_8gb.gin similarity index 97% rename from trax/supervised/configs/transformer_imdb_8gb.gin rename to resources/supervised/configs/transformer_imdb_8gb.gin index 429a3b62e..89e49cf20 100644 --- a/trax/supervised/configs/transformer_imdb_8gb.gin +++ b/resources/supervised/configs/transformer_imdb_8gb.gin @@ -15,8 +15,8 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib # Parameters for the inputs pipeline: # ============================================================================== diff --git a/trax/supervised/configs/transformer_imdb_tfds.gin b/resources/supervised/configs/transformer_imdb_tfds.gin similarity index 98% rename from trax/supervised/configs/transformer_imdb_tfds.gin rename to resources/supervised/configs/transformer_imdb_tfds.gin index df346cf99..310ab88c0 100644 --- a/trax/supervised/configs/transformer_imdb_tfds.gin +++ b/resources/supervised/configs/transformer_imdb_tfds.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib import t5.data.preprocessors diff --git a/trax/supervised/configs/transformer_lm1b_8gb_testing.gin b/resources/supervised/configs/transformer_lm1b_8gb_testing.gin similarity index 97% rename from trax/supervised/configs/transformer_lm1b_8gb_testing.gin rename to resources/supervised/configs/transformer_lm1b_8gb_testing.gin index 992698396..f72035a53 100644 --- a/trax/supervised/configs/transformer_lm1b_8gb_testing.gin +++ b/resources/supervised/configs/transformer_lm1b_8gb_testing.gin @@ -15,7 +15,7 @@ import trax.data import 
trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Module trax.data: # ============================================================================== diff --git a/trax/supervised/configs/transformer_lm_cnndailymail.gin b/resources/supervised/configs/transformer_lm_cnndailymail.gin similarity index 98% rename from trax/supervised/configs/transformer_lm_cnndailymail.gin rename to resources/supervised/configs/transformer_lm_cnndailymail.gin index 13c39db3d..c623d3515 100644 --- a/trax/supervised/configs/transformer_lm_cnndailymail.gin +++ b/resources/supervised/configs/transformer_lm_cnndailymail.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/transformer_lm_wmt_ende_16gb.gin b/resources/supervised/configs/transformer_lm_wmt_ende_16gb.gin similarity index 98% rename from trax/supervised/configs/transformer_lm_wmt_ende_16gb.gin rename to resources/supervised/configs/transformer_lm_wmt_ende_16gb.gin index 0913538a4..c296c7e97 100644 --- a/trax/supervised/configs/transformer_lm_wmt_ende_16gb.gin +++ b/resources/supervised/configs/transformer_lm_wmt_ende_16gb.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/transformer_lm_wmt_ende_8gb.gin b/resources/supervised/configs/transformer_lm_wmt_ende_8gb.gin similarity index 98% rename from trax/supervised/configs/transformer_lm_wmt_ende_8gb.gin rename to resources/supervised/configs/transformer_lm_wmt_ende_8gb.gin index 859f26290..6c17543d8 100644 --- a/trax/supervised/configs/transformer_lm_wmt_ende_8gb.gin +++ b/resources/supervised/configs/transformer_lm_wmt_ende_8gb.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/transformer_ptb_16gb.gin b/resources/supervised/configs/transformer_ptb_16gb.gin similarity index 96% rename from trax/supervised/configs/transformer_ptb_16gb.gin rename to resources/supervised/configs/transformer_ptb_16gb.gin index 9c893f15d..399539f66 100644 --- a/trax/supervised/configs/transformer_ptb_16gb.gin +++ b/resources/supervised/configs/transformer_ptb_16gb.gin @@ -15,8 +15,8 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.lr_schedules -import trax.supervised.trainer_lib +import trax.learning.supervised.lr_schedules +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin b/resources/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin similarity index 98% rename from trax/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin rename to resources/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin index 7ff9cc9b9..6f19a4e59 100644 --- 
a/trax/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin +++ b/resources/supervised/configs/transformer_wmt_ende_16gb_adafactor_testing.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/transformer_wmt_ende_8gb.gin b/resources/supervised/configs/transformer_wmt_ende_8gb.gin similarity index 98% rename from trax/supervised/configs/transformer_wmt_ende_8gb.gin rename to resources/supervised/configs/transformer_wmt_ende_8gb.gin index 143f8ca43..b909666ac 100644 --- a/trax/supervised/configs/transformer_wmt_ende_8gb.gin +++ b/resources/supervised/configs/transformer_wmt_ende_8gb.gin @@ -15,7 +15,7 @@ import trax.data import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/configs/wide_resnet_cifar10_8gb.gin b/resources/supervised/configs/wide_resnet_cifar10_8gb.gin similarity index 97% rename from trax/supervised/configs/wide_resnet_cifar10_8gb.gin rename to resources/supervised/configs/wide_resnet_cifar10_8gb.gin index 81d775838..5962cddbe 100644 --- a/trax/supervised/configs/wide_resnet_cifar10_8gb.gin +++ b/resources/supervised/configs/wide_resnet_cifar10_8gb.gin @@ -16,7 +16,7 @@ import trax.data import trax.supervised.lr_schedules import trax.models import trax.optimizers -import trax.supervised.trainer_lib +import trax.learning.supervised.trainer_lib # Parameters for batcher: # ============================================================================== diff --git a/trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz b/resources/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz similarity index 100% rename from trax/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz rename to resources/supervised/testdata/reformerlm_copy_lsh_attn.pkl.gz diff --git a/trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz b/resources/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz similarity index 100% rename from trax/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz rename to resources/supervised/testdata/terraformer_copy_lsh_attn.pkl.gz diff --git a/trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz b/resources/supervised/testdata/terraformer_copy_self_attn.pkl.gz similarity index 100% rename from trax/supervised/testdata/terraformer_copy_self_attn.pkl.gz rename to resources/supervised/testdata/terraformer_copy_self_attn.pkl.gz diff --git a/trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz b/resources/supervised/testdata/terraformer_purelsh_copy.pkl.gz similarity index 100% rename from trax/supervised/testdata/terraformer_purelsh_copy.pkl.gz rename to resources/supervised/testdata/terraformer_purelsh_copy.pkl.gz diff --git a/trax/supervised/testdata/transformer_copy.pkl.gz b/resources/supervised/testdata/transformer_copy.pkl.gz similarity index 100% rename from trax/supervised/testdata/transformer_copy.pkl.gz rename to resources/supervised/testdata/transformer_copy.pkl.gz diff --git a/trax/supervised/testdata/transformerlm_copy.pkl.gz b/resources/supervised/testdata/transformerlm_copy.pkl.gz similarity index 100% rename from trax/supervised/testdata/transformerlm_copy.pkl.gz rename to 
resources/supervised/testdata/transformerlm_copy.pkl.gz diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..14186d7b7 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,88 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.10. +target-version = "py310" + +[lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F", "I", "W"] +ignore = [ + "E302", # We'll handle module-level function spacing manually +] + +# Allow fix for all enabled rules (when `--fix` is provided). +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[lint.isort] +# Enforce two blank lines before top-level classes and functions +required-imports = [] +section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] +# Equivalent to Black's import style +known-first-party = [] +known-third-party = [] +lines-between-types = 1 + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. +# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = true + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" diff --git a/setup.py b/setup.py index 0f1f0ec57..565cb4507 100644 --- a/setup.py +++ b/setup.py @@ -16,60 +16,66 @@ # coding=utf-8 """Install trax.""" -from setuptools import find_packages -from setuptools import setup +from setuptools import find_packages, setup setup( - name='trax', - version='1.4.1', - description='Trax', + name="trax", + version="2.0.0", + description="Trax", long_description=( - 'Trax helps you understand deep learning. We start with basic maths and' - ' go through layers, models, supervised and reinforcement learning. We ' - 'get to advanced deep learning results, including recent papers and ' - 'state-of-the-art models.' + "Trax helps you understand deep learning. We start with basic maths and" + " go through layers, models, supervised and reinforcement learning. We " + "get to advanced deep learning results, including recent papers and " + "state-of-the-art models."
), - author='Google Inc.', - author_email='no-reply@google.com', - url='http://github.com/google/trax', - license='Apache 2.0', + author="Google Inc.", + author_email="no-reply@google.com", + url="http://github.com/google/trax", + license="Apache 2.0", packages=find_packages(), install_requires=[ - 'absl-py', - 'funcsigs', - 'gin-config', - 'gym', - 'jax', - 'jaxlib', - 'matplotlib', - 'numpy', - 'psutil', - 'scipy', - 'six', - 'tensorflow-datasets', - 'tensorflow-text', + "absl-py==2.2.0", + "gin-config==0.5.0", + "jax==0.5.3", + "jaxlib==0.5.3", + "numpy==1.26.4", + "psutil==7.0.0", + "scipy==1.15.2", + "tensorflow-datasets==4.9.8", + "tensorflow-text==2.17.0", ], extras_require={ - 'tensorflow': ['tensorflow>=1.15.0'], - 'tensorflow_gpu': ['tensorflow-gpu>=1.15.0'], - 't5': ['t5>=0.4.0'], - 'tests': [ - 'attrs', - 'jupyter', - 'mock', - 'parameterized', - 'pylint', - 'pytest', - 'wrapt==1.11.*', + "tensorflow": ["tensorflow==2.17.0"], + "tensorflow_cuda": ["tensorflow[and-cuda]==2.17.0"], + "t5": [ + "t5==0.9.4", + "seqio==0.0.18", + ], + "rl": [ + "gym==0.26.2", + ], + "viz": [ + "matplotlib==3.10.1", + ], + "examples": [ + "datasets==3.5.0", + ], + "tests": [ + "attrs==25.3.0", + "jupyter", + "mock==5.1.0", + "parameterized==0.9.0", + "pylint==3.3.6", + "pytest==8.3.5", + "wrapt==1.17.2", ], - 't2t': ['tensor2tensor',], }, classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], - keywords='tensorflow machine learning jax', + keywords="tensorflow machine learning jax", ) diff --git a/tests/data/encoder/encoder_test.py b/tests/data/encoder/encoder_test.py new file mode 100644 index 000000000..e7b41d09b --- /dev/null +++ b/tests/data/encoder/encoder_test.py @@ -0,0 +1,683 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
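+# Quick orientation (added comments; a sketch that mirrors the API the tests
+# below exercise, not an exhaustive reference):
+#
+#   token_counts = collections.Counter(corpus.split(" "))
+#   encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
+#       100, token_counts, 2, 10)  # binary-searches a min token count in [2, 10]
+#   ids = encoder.encode("some text")  # -> list of subtoken ids
+#   text = encoder.decode(ids)         # inverse of encode for in-alphabet input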
+ +"""Tests for trax.data.text_encoder.""" + +import collections +import io +import os +import random +import shutil +import string + +import gin +import mock +import numpy as np +import six +import tensorflow.compat.v1 as tf + + +# import tensorflow.compat.v1 as tf +from six.moves import ( + range, # pylint: disable=redefined-builtin # pylint: disable=redefined-builtin +) + +from tests.data.utils import ( # relative import + _TESTDATA, + _spm_path, +) +from trax.data.encoder import encoder as text_encoder + + +class NativeToUnicodeTest(tf.test.TestCase): + def test_native_to_unicode(self): + s = r"foo bar" + s_unicode = text_encoder.native_to_unicode(s) + self.assertEqual(s_unicode, "foo bar") + + +class EscapeUnescapeTokenTest(tf.test.TestCase): + def test_escape_token(self): + escaped = text_encoder._escape_token( + "Foo! Bar.\nunder_score back\\slash", + set("abcdefghijklmnopqrstuvwxyz .\n") | text_encoder._ESCAPE_CHARS, + ) + + self.assertEqual( + "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_", escaped + ) + + def test_unescape_token(self): + unescaped = text_encoder._unescape_token( + "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_" + ) + + self.assertEqual("Foo! Bar.\nunder_score back\\slash", unescaped) + + +class TokenTextEncoderTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + tf.gfile.MakeDirs(cls.test_temp_dir) + + def test_save_and_reload(self): + """Test that saving and reloading doesn't change the vocab. + + Note that this test reads and writes to the filesystem, which necessitates + that this test size be "large". + """ + + corpus = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z" + vocab_filename = os.path.join(self.test_temp_dir, "abc.vocab") + + # Make text encoder from a list and store vocab to fake filesystem. + encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) + encoder.store_to_file(vocab_filename) + + # Load back the saved vocab file from the fake_filesystem. + new_encoder = text_encoder.TokenTextEncoder(vocab_filename) + + self.assertEqual(encoder._id_to_token, new_encoder._id_to_token) + self.assertEqual(encoder._token_to_id, new_encoder._token_to_id) + + def test_reserved_tokens_in_corpus(self): + """Test that we handle reserved tokens appearing in the corpus.""" + corpus = "A B {} D E F {} G {}".format( + text_encoder.EOS, text_encoder.EOS, text_encoder.PAD + ) + + encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) + + all_tokens = encoder._id_to_token.values() + + # If reserved tokens are removed correctly, then the set of tokens will + # be unique. + self.assertEqual(len(all_tokens), len(set(all_tokens))) + + +class SubwordTextEncoderTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + tf.gfile.MakeDirs(cls.test_temp_dir) + + def test_encode_decode(self): + corpus = ( + "This is a corpus of text that provides a bunch of tokens from which " + "to build a vocabulary. It will be used when strings are encoded " + "with a TextEncoder subclass. The encoder was coded by a coder." 
+ ) + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) - {" "} + + original = "This is a coded sentence encoded by the SubwordTextEncoder." + token_counts.update(original.split(" ")) + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + # Encoding should be reversible. + encoded = encoder.encode(original) + decoded = encoder.decode(encoded) + self.assertEqual(original, decoded) + + # The substrings coded and coder are frequent enough in the corpus that + # they should appear in the vocabulary even though they are substrings + # of other included strings. + subtoken_strings = {encoder.all_subtoken_strings[i] for i in encoded} + self.assertIn("encoded_", subtoken_strings) + self.assertIn("coded_", subtoken_strings) + self.assertIn("TextEncoder", encoder.all_subtoken_strings) + self.assertIn("coder", encoder.all_subtoken_strings) + + # Every character in the corpus should be in the encoders alphabet and + # its subtoken vocabulary. + self.assertTrue(alphabet.issubset(encoder._alphabet)) + for a in alphabet: + self.assertIn(a, encoder.all_subtoken_strings) + + def test_unicode(self): + corpus = "Cat emoticons. \U0001F638 \U0001F639 \U0001F63A \U0001F63B" + token_counts = collections.Counter(corpus.split(" ")) + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + self.assertIn("\U0001F638", encoder._alphabet) + self.assertIn("\U0001F63B", encoder.all_subtoken_strings) + + def test_small_vocab(self): + corpus = "The quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) - {" "} + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 10, token_counts, 2, 10 + ) + + # All vocabulary elements are in the alphabet and subtoken strings even + # if we requested a smaller vocabulary to assure all expected strings + # are encodable. + self.assertTrue(alphabet.issubset(encoder._alphabet)) + for a in alphabet: + self.assertIn(a, encoder.all_subtoken_strings) + + def test_long_tokens(self): + """Subword tokenization should still run efficiently with long tokens. + + To make it run efficiently, we need to use the `max_subtoken_length` + argument when calling SubwordTextEncoder.build_to_target_size. + """ + token_length = 4000 + num_tokens = 50 + target_vocab_size = 600 + max_subtoken_length = 10 # Set this to `None` to get problems. + max_count = 500 + + # Generate some long random strings. + random.seed(0) + long_tokens = [] + for _ in range(num_tokens): + long_token = "".join( + [random.choice(string.ascii_uppercase) for _ in range(token_length)] + ) + long_tokens.append(long_token) + + corpus = " ".join(long_tokens) + token_counts = collections.Counter(corpus.split(" ")) + alphabet = set(corpus) - {" "} + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + target_vocab_size, + token_counts, + 1, + max_count, + num_iterations=1, + max_subtoken_length=max_subtoken_length, + ) + + # All vocabulary elements are in the alphabet and subtoken strings even + # if we requested a smaller vocabulary to assure all expected strings + # are encodable. 
+ self.assertTrue(alphabet.issubset(encoder._alphabet)) + for a in alphabet: + self.assertIn(a, encoder.all_subtoken_strings) + + def test_custom_reserved_tokens(self): + """Test that we can pass custom reserved tokens to SubwordTextEncoder.""" + corpus = "The quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + start_symbol = "<S>" + end_symbol = "</S>" + reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, end_symbol] + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 10, token_counts, 2, 10, reserved_tokens=reserved_tokens + ) + + # Make sure that reserved tokens appear in the right places. + self.assertEqual(encoder.decode([2]), start_symbol) + self.assertEqual(encoder.decode([3]), end_symbol) + + # Make sure that we haven't messed up the ability to reconstruct. + reconstructed_corpus = encoder.decode(encoder.encode(corpus)) + self.assertEqual(corpus, reconstructed_corpus) + + def test_encodable_when_not_in_alphabet(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + original = "This has UPPER CASE letters that are out of alphabet" + + # Early versions could have an infinite loop when breaking into subtokens + # if there was any out-of-alphabet characters in the encoded string. + encoded = encoder.encode(original) + decoded = encoder.decode(encoded) + + self.assertEqual(original, decoded) + encoded_str = "".join(encoder.all_subtoken_strings[i] for i in encoded) + self.assertIn("\\84;", encoded_str) + + @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579")) + def test_raises_exception_when_not_encodable(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + # Deliberately exclude some required encoding chars from the alphabet + # and token list, making some strings unencodable. + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + original = "This has UPPER CASE letters that are out of alphabet" + + # Previously there was a bug which produced an infinite loop in this case. + with self.assertRaises(AssertionError): + encoder.encode(original) + + def test_load_from_file(self): + # Test a vocab file with words not wrapped with single quotes + encoder = text_encoder.SubwordTextEncoder() + correct_vocab = ["the", "and", "of"] + vocab = io.StringIO("the\n" "and\n" "of\n") + encoder._load_from_file_object(vocab) + self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab) + + # Test a vocab file with words wrapped in single quotes + encoder = text_encoder.SubwordTextEncoder() + vocab = io.StringIO('"the"\n' '"and"\n' '"of"\n') + encoder._load_from_file_object(vocab) + self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab) + + def test_reserved_token_chars_not_in_alphabet(self): + corpus = "dog" + token_counts = collections.Counter(corpus.split(" ")) + encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 100 + ) + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder1.store_to_file(filename) + encoder2 = text_encoder.SubwordTextEncoder(filename=filename) + + self.assertEqual(encoder1._alphabet, encoder2._alphabet) + + for t in text_encoder.RESERVED_TOKENS: + for c in t: + # Verify that encoders can encode all reserved token chars.
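+ # (any failure here would surface as an exception from encode())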
+ encoder1.encode(c) + encoder2.encode(c) + + def test_save_and_reload(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + # Deliberately exclude some required encoding chars from the alphabet + # and token list, making some strings unencodable. + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder.store_to_file(filename) + new_encoder = text_encoder.SubwordTextEncoder(filename) + + self.assertEqual(encoder._alphabet, new_encoder._alphabet) + self.assertEqual(encoder.all_subtoken_strings, new_encoder.all_subtoken_strings) + self.assertEqual( + encoder._subtoken_string_to_id, new_encoder._subtoken_string_to_id + ) + self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len) + + def test_save_and_reload_no_single_quotes(self): + corpus = "the quick brown fox jumps over the lazy dog" + token_counts = collections.Counter(corpus.split(" ")) + + # Deliberately exclude some required encoding chars from the alphabet + # and token list, making some strings unencodable. + encoder = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 10 + ) + + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder.store_to_file(filename, add_single_quotes=False) + new_encoder = text_encoder.SubwordTextEncoder(filename) + + self.assertEqual(encoder._alphabet, new_encoder._alphabet) + self.assertEqual(encoder.all_subtoken_strings, new_encoder.all_subtoken_strings) + self.assertEqual( + encoder._subtoken_string_to_id, new_encoder._subtoken_string_to_id + ) + self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len) + + def test_build_from_generator(self): + corpus = "The quick brown fox jumps over the lazy dog" + + def gen(): + for _ in range(3): + yield corpus + + start_symbol = "<S>" + end_symbol = "</S>" + reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, end_symbol] + encoder = text_encoder.SubwordTextEncoder.build_from_generator( + gen(), 10, reserved_tokens=reserved_tokens + ) + + # Make sure that reserved tokens appear in the right places. + self.assertEqual(encoder.decode([2]), start_symbol) + self.assertEqual(encoder.decode([3]), end_symbol) + + self.assertEqual( + "hi%s" % start_symbol, encoder.decode(encoder.encode("hi") + [2]) + ) + + # Make sure that we haven't messed up the ability to reconstruct.
+ reconstructed_corpus = encoder.decode(encoder.encode(corpus)) + self.assertEqual(corpus, reconstructed_corpus) + + +class OneHotClassLabelEncoderTest(tf.test.TestCase): + def test_one_hot_encode(self): + encoder = text_encoder.OneHotClassLabelEncoder( + class_labels=["zero", "one", "two"] + ) + self.assertEqual(encoder.encode("zero"), [1, 0, 0]) + self.assertEqual(encoder.encode("one"), [0, 1, 0]) + self.assertEqual(encoder.encode("two"), [0, 0, 1]) + + def test_one_hot_decode(self): + encoder = text_encoder.OneHotClassLabelEncoder( + class_labels=["zero", "one", "two"] + ) + self.assertEqual(encoder.decode([1, 0, 0]), "zero") + self.assertEqual(encoder.decode([0, 1, 0]), "one") + self.assertEqual(encoder.decode([0, 0, 1]), "two") + + +class TokenizerTest(tf.test.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def test_encode(self): + self.assertListEqual( + ["Dude", " - ", "that", "'", "s", "so", "cool", "."], + text_encoder.encode("Dude - that's so cool."), + ) + self.assertListEqual( + ["Łukasz", "est", "né", "en", "1981", "."], + text_encoder.encode("Łukasz est né en 1981."), + ) + self.assertListEqual( + [" ", "Spaces", "at", "the", "ends", " "], + text_encoder.encode(" Spaces at the ends "), + ) + self.assertListEqual(["802", ".", "11b"], text_encoder.encode("802.11b")) + self.assertListEqual( + ["two", ". \n", "lines"], text_encoder.encode("two. \nlines") + ) + + def test_decode(self): + self.assertEqual( + "Dude - that's so cool.", + text_encoder.decode(["Dude", " - ", "that", "'", "s", "so", "cool", "."]), + ) + + def test_invertibility_on_random_strings(self): + for _ in range(1000): + s = "".join(six.unichr(random.randint(0, 65535)) for _ in range(10)) + self.assertEqual(s, text_encoder.decode(text_encoder.encode(s))) + + def test_tokenize_detokenize_character_level(self): + def dataset(): + yield "I have a cat." + + # Character-level. + tok_char = list(text_encoder.tokenize(dataset(), vocab_type="char")) + self.assertAllEqual(tok_char[0], np.array([ord(c) for c in "I have a cat."])) + detok = text_encoder.detokenize(tok_char[0], vocab_type="char") + self.assertEqual(detok, "I have a cat.") + + def test_tokenize_detokenize_sentencepiece(self): + def dataset(): + yield "I have a cat." + + # Sentencepiece. + tok_spc = list( + text_encoder.tokenize( + dataset(), + vocab_type="sentencepiece", + vocab_dir=_TESTDATA, + vocab_file="sentencepiece.model", + ) + ) + + self.assertAllEqual(tok_spc[0], np.array([[27, 43, 3, 9, 1712, 5]])) + + detok = text_encoder.detokenize( + list(tok_spc[0]), + vocab_type="sentencepiece", + vocab_dir=_TESTDATA, + vocab_file="sentencepiece.model", + ) + + self.assertEqual(detok, "I have a cat.") + + def test_tokenize_detokenize_subword(self): + def dataset(): + yield "I have a cat." + + # Subword. + tok_sbw = list( + text_encoder.tokenize( + dataset(), + vocab_type="subword", + vocab_dir=_TESTDATA, + vocab_file="en_8k.subword", + ) + ) + self.assertAllEqual(tok_sbw[0], np.array([139, 96, 12, 2217, 2, 21])) + detok = text_encoder.detokenize( + tok_sbw[0], + vocab_type="subword", + vocab_dir=_TESTDATA, + vocab_file="en_8k.subword", + ) + self.assertEqual(detok, "I have a cat.") + + def test_tokenize_detokenize_bert_lowercase(self): + def dataset(): + yield "I have a cat."
+ + # bert-lowercase + tok_sbw = list( + text_encoder.tokenize( + dataset(), + vocab_type="bert-lowercase", + vocab_dir=_TESTDATA, + vocab_file="bert_uncased_vocab.txt", + ) + ) + self.assertAllEqual(tok_sbw[0], np.array([1045, 2031, 1037, 4937, 1012])) + + detok = text_encoder.detokenize( + tok_sbw[0], + vocab_type="bert-lowercase", + vocab_dir=_TESTDATA, + vocab_file="bert_uncased_vocab.txt", + ) + self.assertEqual(detok, "i have a cat .") + # Note: the BERT tokenizer is not reversible, hence the difference + # from the original input. + + def test_tokenize_keys_reservedids(self): + def dataset(): + yield ("Cat.", "Dog.") + + tok_char1 = list( + text_encoder.tokenize(dataset(), vocab_type="char", n_reserved_ids=5) + ) + self.assertAllEqual(tok_char1[0][0], np.array([ord(c) + 5 for c in "Cat."])) + self.assertAllEqual(tok_char1[0][1], np.array([ord(c) + 5 for c in "Dog."])) + + tok_char2 = list( + text_encoder.tokenize( + dataset(), keys=[0], vocab_type="char", n_reserved_ids=2 + ) + ) + self.assertAllEqual(tok_char2[0][0], np.array([ord(c) + 2 for c in "Cat."])) + self.assertEqual(tok_char2[0][1], "Dog.") + + def test_tokenize_dict(self): + def dataset(): + yield {"a": "Cat.", "b": "Dog."} + + tok_char1 = list(text_encoder.tokenize(dataset(), vocab_type="char")) + self.assertAllEqual(tok_char1[0]["a"], np.array([ord(c) for c in "Cat."])) + self.assertAllEqual(tok_char1[0]["b"], np.array([ord(c) for c in "Dog."])) + + tok_char2 = list( + text_encoder.tokenize(dataset(), keys=["a"], vocab_type="char") + ) + self.assertAllEqual(tok_char2[0]["a"], np.array([ord(c) for c in "Cat."])) + self.assertEqual(tok_char2[0]["b"], "Dog.") + + def test_vocab_size_character_level(self): + # Character-level. + char_size = text_encoder.vocab_size(vocab_type="char", n_reserved_ids=11) + self.assertEqual(char_size, 256 + 11) + + def test_vocab_size_sentencepiece(self): + # Sentencepiece. + spc_size = text_encoder.vocab_size( + vocab_type="sentencepiece", + vocab_dir=_TESTDATA, + vocab_file="sentencepiece.model", + ) + self.assertEqual(spc_size, 32000) + + def test_vocab_size_subword_level(self): + sbw_size = text_encoder.vocab_size( + vocab_type="subword", + vocab_dir=_TESTDATA, + vocab_file="en_8k.subword", + ) + self.assertEqual(sbw_size, 8183) + + def test_vocab_size_bert_uncased(self): + # Bert_uncased. + sbw_size = text_encoder.vocab_size( + vocab_type="bert-lowercase", + vocab_dir=_TESTDATA, + vocab_file="bert_uncased_vocab.txt", + ) + self.assertEqual(sbw_size, 30522) + + def test_sentencepiece_tokenize(self): + def dataset(): + yield "I have a cat." + + # Assume _spm_path() returns the correct path to your SentencePiece model. + # Use the new name: SentencePieceTokenizer. + tokenizer_fn = text_encoder.SentencePieceTokenizer(_spm_path()) + + # tokenizer_fn is now a function that expects a generator (stream) of examples. + tokenized_gen = tokenizer_fn(dataset()) + + # Get the first tokenized example using next() + first_example = next(tokenized_gen) + # Convert to list if needed + toks = list(first_example) + self.assertSequenceEqual([27, 43, 3, 9, 1712, 5], toks) + + +class TestTokenCounts(tf.test.TestCase): + def setUp(self): + super(TestTokenCounts, self).setUp() + self.corpus_path = os.path.join(_TESTDATA, "corpus-*.txt") + self.vocab_path = os.path.join(_TESTDATA, "vocab-*.txt") + + def test_corpus_token_counts_split_on_newlines(self): + token_counts = text_encoder.corpus_token_counts( + self.corpus_path, corpus_max_lines=0, split_on_newlines=True + ) + + expected = { + "'": 2, + ".": 2, + ". ": 1, + "... ": 1, + "Groucho": 1, + "Marx": 1, + "Mitch": 1, + "Hedberg": 1, + "I": 3, + "in": 2, + "my": 2, + "pajamas": 2, + } + self.assertDictContainsSubset(expected, token_counts) + self.assertNotIn(".\n\n", token_counts) + self.assertNotIn("\n", token_counts) + + def test_corpus_token_counts_no_split_on_newlines(self): + token_counts = text_encoder.corpus_token_counts( + self.corpus_path, corpus_max_lines=0, split_on_newlines=False + ) + + if ".\r\n\r\n" in token_counts.keys(): + token_counts.update({"\n\n": token_counts.pop(".\r\n\r\n")}) + + if "\r\n" in token_counts.keys(): + token_counts.update({"\n": token_counts.pop("\r\n")}) + + if ".\n\n" in token_counts.keys(): + token_counts.update({"\n\n": token_counts.pop(".\n\n")}) + + self.assertDictContainsSubset({"\n\n": 2, "\n": 3}, token_counts) + + def test_corpus_token_counts_split_with_max_lines(self): + token_counts = text_encoder.corpus_token_counts( + self.corpus_path, corpus_max_lines=5, split_on_newlines=True + ) + + self.assertIn("slept", token_counts) + self.assertNotIn("Mitch", token_counts) + + def test_corpus_token_counts_no_split_with_max_lines(self): + token_counts = text_encoder.corpus_token_counts( + self.corpus_path, corpus_max_lines=5, split_on_newlines=False + ) + + self.assertIn("slept", token_counts) + self.assertNotIn("Mitch", token_counts) + self.assertDictContainsSubset({".\n\n": 1, "\n": 2, ".\n": 1}, token_counts) + + def test_vocab_token_counts(self): + token_counts = text_encoder.vocab_token_counts(self.vocab_path, 0) + + expected = { + "lollipop": 8, + "reverberated": 12, + "kattywampus": 11, + "balderdash": 10, + "jiggery-pokery": 14, + } + self.assertDictEqual(expected, token_counts) + + def test_vocab_token_counts_with_max_lines(self): + # vocab-1 has 2 lines, vocab-2 has 3 + token_counts = text_encoder.vocab_token_counts(self.vocab_path, 5) + + expected = { + "lollipop": 8, + "reverberated": 12, + "kattywampus": 11, + "balderdash": 10, + } + self.assertDictEqual(expected, token_counts) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/data/loader/tf/base_test.py b/tests/data/loader/tf/base_test.py new file mode 100644 index 000000000..befe7e26b --- /dev/null +++ b/tests/data/loader/tf/base_test.py @@ -0,0 +1,363 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Tests for trax.data.tf.datasets.""" +from unittest import mock + +import gin +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from tests.data.utils import ( # relative import + _TESTDATA, + _load_dataset, + _spm_path, + _test_dataset_ints, + assert_dataset, +) +from trax.data.loader.tf import base as ds +from trax.data.preprocessing import inputs +from trax.data.preprocessing.inputs import batcher # noqa: F401 + + +class TFDatasetTest(tf.test.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def test_TFDS_single_host_with_eval_holdout(self): + train_ds_gen = ds.TFDS( + "c4/en:2.3.0", + data_dir=_TESTDATA, + train=True, + host_id=0, + keys=("text",), + n_hosts=1, + eval_holdout_size=0.1, + ) + + result = None + + try: + # Just ensure that this doesn't crash. + for d in train_ds_gen(): + break + + result = True + except Exception as e: + self.fail( + f"test_TFDS_single_host_with_eval_holdout() raised an unexpected exception: {e}" + ) + + self.assertIsNotNone( + result, + "test_TFDS_single_host_with_eval_holdout() returned None unexpectedly", + ) + + valid_ds_gen = ds.TFDS( + "c4/en:2.3.0", + data_dir=_TESTDATA, + train=False, + host_id=0, + keys=("text",), + n_hosts=1, + eval_holdout_size=0.1, + ) + + result = None + + try: + # Just ensure that this doesn't crash. + for d in valid_ds_gen(): + break + + result = True + except Exception as e: + self.fail( + f"test_TFDS_single_host_with_eval_holdout() raised an unexpected exception: {e}" + ) + + self.assertIsNotNone( + result, + "test_TFDS_single_host_with_eval_holdout() returned None unexpectedly", + ) + + def test_TFDS_single_host_with_eval_holdout_no_valid_split(self): + train_ds_gen = ds.TFDS( + "para_crawl/ende", + data_dir=_TESTDATA, + train=True, + host_id=0, + keys=("en", "de"), + n_hosts=1, + eval_holdout_size=0.1, + ) + + result = None + + try: + # Just ensure that this doesn't crash. + for d in train_ds_gen(): + break + + result = True + except Exception as e: + self.fail( + f"test_TFDS_single_host_with_eval_holdout() raised an unexpected exception: {e}" + ) + + self.assertIsNotNone( + result, + "test_TFDS_single_host_with_eval_holdout() returned None unexpectedly", + ) + + # para_crawl doesn't have a validation set, see that this still doesn't + # crash because of eval_holdout_set. + valid_ds_gen = ds.TFDS( + "para_crawl/ende", + data_dir=_TESTDATA, + train=False, + host_id=0, + keys=("en", "de"), + n_hosts=1, + eval_holdout_size=0.1, + ) + + result = None + + try: + # Just ensure that this doesn't crash. 
+ for d in valid_ds_gen(): + break + result = True + except Exception as e: + self.fail( + f"test_TFDS_single_host_with_eval_holdout() raised an unexpected exception: {e}" + ) + + self.assertIsNotNone( + result, + "test_TFDS_single_host_with_eval_holdout() returned None unexpectedly", + ) + + def test_TFDS_mnli_split_is_eval(self): + with mock.patch("tensorflow_datasets.load") as tfds_load: + with mock.patch( + "trax.data.loader.tf.base.download_and_prepare", + lambda _, data_dir: data_dir, + ): + _ = ds.TFDS("glue/mnli", keys=("premise", "hypothesis"), train=False) + call_kwargs = tfds_load.call_args[1] + self.assertEqual(call_kwargs["split"], "validation_matched") + + def test_TFDS_mnli_split_is_alt_eval(self): + with mock.patch("tensorflow_datasets.load") as tfds_load: + with mock.patch( + "trax.data.loader.tf.base.download_and_prepare", + lambda _, data_dir: data_dir, + ): + _ = ds.TFDS( + "glue/mnli", + keys=("premise", "hypothesis"), + train=False, + use_alt_eval=True, + ) + call_kwargs = tfds_load.call_args[1] + self.assertEqual(call_kwargs["split"], "validation_mismatched") + + def test_generic_text_dataset_preprocess_fn(self): + # self.skipTest("google.protobuf.json_format.ParseError ...") + dataset = _load_dataset("squad/v1.1:3.0.0") + + (example,) = tfds.as_numpy(dataset.take(1)) + + self.assertNotIn("inputs", example) + self.assertNotIn("targets", example) + + proc_dataset = ds.generic_text_dataset_preprocess_fn( + dataset, + spm_path=_spm_path(), + text_preprocess_fns=[lambda _ds, training: ds.squad_t5(_ds, None)], + copy_pretokenized=True, + debug_print_examples=True, + debug_print_examples_rate=1.0, + ) + + (proc_example,) = tfds.as_numpy(proc_dataset.take(1)) + + self.assertIn("inputs", proc_example) + self.assertIn("targets", proc_example) + + self.assertEqual(proc_example["inputs"].dtype, tf.int64) + self.assertEqual(proc_example["targets"].dtype, tf.int64) + + # TODO(afrozm): Why does this test take so much time? + def test_inputs_using_generic_text_dataset_preprocess_fn(self): + gin.bind_parameter("generic_text_dataset_preprocess_fn.spm_path", _spm_path()) + gin.bind_parameter( + "generic_text_dataset_preprocess_fn.text_preprocess_fns", + [lambda _ds, training: ds.squad_t5(_ds, None)], + ) + + # Just make sure this doesn't throw. + def data_streams(): + return ds.data_streams( + "squad", + data_dir=_TESTDATA, + input_name="inputs", + target_name="targets", + bare_preprocess_fn=ds.generic_text_dataset_preprocess_fn, + shuffle_buffer_size=1, + ) + + n_devices = 3 + + squad_inputs = inputs.batcher( + data_streams=data_streams, + max_eval_length=512, + buckets=( + [ + 513, + ], + [n_devices, n_devices], + ), + ) + + eval_stream = squad_inputs.eval_stream(n_devices) + inps, tgts, _ = next(eval_stream) + + # We can only assert that the batch dim gets divided by n_devices. + self.assertEqual(inps.shape[0] % n_devices, 0) + self.assertEqual(tgts.shape[0] % n_devices, 0) + + def test_filter_dataset_on_len(self): + # {1, 2}, {2, 4}, {3, 6} ... {10, 20} + dataset = _test_dataset_ints(range(1, 11), range(2, 21, 2)) + + ds1 = ds.filter_dataset_on_len( + dataset, True, {"inputs": [4, 8], "targets": [14, 20]} + ) + # Only {7, 14} and {8, 16} satisfy this. + self.assertLen(list(ds1.as_numpy_iterator()), 2) + + ds2 = ds.filter_dataset_on_len( + dataset, + False, + len_map={"inputs": [4, 8], "targets": [14, 20]}, + filter_on_eval=False, + ) + # This is eval and we aren't supposed to filter it. 
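+        # (With filter_on_eval left as False, the eval dataset passes through
+        # unfiltered, so all 10 examples survive.)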
+        self.assertLen(list(ds2.as_numpy_iterator()), 10)
+
+        ds3 = ds.filter_dataset_on_len(
+            dataset,
+            False,
+            len_map={"inputs": [4, 8], "targets": [14, 20]},
+            filter_on_eval=True,
+        )
+        # This is eval and we are asked to filter it.
+        self.assertLen(list(ds3.as_numpy_iterator()), 2)
+
+    def test_truncate_dataset_on_len(self):
+        dataset = _test_dataset_ints([5, 6, 7], [8, 9, 10])
+        ds1 = ds.truncate_dataset_on_len(
+            dataset, True, len_map={"inputs": 6, "targets": 4}
+        )
+        expected_ds = _test_dataset_ints([5, 6, 6], [4, 4, 4])
+
+        # Training: should truncate.
+        assert_dataset(ds1, list(expected_ds.as_numpy_iterator()))
+
+        # Not training: shouldn't truncate.
+        ds2 = ds.truncate_dataset_on_len(
+            dataset, False, len_map={"inputs": 6, "targets": 4}
+        )
+        assert_dataset(ds2, list(dataset.as_numpy_iterator()))
+
+        # Not training, but asked to truncate on eval: should truncate.
+        ds3 = ds.truncate_dataset_on_len(
+            dataset, False, len_map={"inputs": 6, "targets": 4}, truncate_on_eval=True
+        )
+        assert_dataset(ds3, list(expected_ds.as_numpy_iterator()))
+
+    def test_get_t5_preprocessor_by_name(self):
+        gin.clear_config()
+
+        gin.parse_config(
+            """
+            get_t5_preprocessor_by_name.name = 'rekey_t5'
+            get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other', 'targets': 'text'}}
+            """
+        )
+
+        prep_rekey = ds.get_t5_preprocessor_by_name()
+        og_dataset = tf.data.Dataset.from_tensors(
+            {"text": "That is good.", "other": "That is bad."}
+        )
+        training = True
+        dataset = prep_rekey(og_dataset, training)
+        assert_dataset(dataset, {"inputs": "That is bad.", "targets": "That is good."})
+
+    def test_pad_dataset_to_length(self):
+        dataset = _test_dataset_ints([5, 6, 7], [6, 7, 8])
+        ds1 = ds.pad_dataset_to_length(
+            dataset, True, len_map={"inputs": 7, "targets": 10}
+        )
+
+        expected_ds = [
+            {
+                "inputs": np.array([1, 1, 1, 1, 1, 0, 0], dtype=np.int64),
+                "targets": np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=np.int64),
+            },
+            {
+                "inputs": np.array([1, 1, 1, 1, 1, 1, 0], dtype=np.int64),
+                "targets": np.array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64),
+            },
+            {
+                "inputs": np.array([1, 1, 1, 1, 1, 1, 1], dtype=np.int64),
+                "targets": np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0], dtype=np.int64),
+            },
+        ]
+
+        assert_dataset(ds1, expected_ds)
+
+    def test_lm_token_preprocessing(self):
+        dataset = _test_dataset_ints([1, 2, 3], [3, 2, 1])
+        ds1 = ds.lm_token_preprocessing(dataset, True)
+
+        expected_ds = [
+            {
+                "inputs": np.array([1, 0, 1, 1, 1], dtype=np.int64),
+                "targets": np.array([1, 0, 1, 1, 1], dtype=np.int64),
+                "mask": np.array([0, 0, 1, 1, 1], dtype=np.int64),
+            },
+            {
+                "inputs": np.array([1, 1, 0, 1, 1], dtype=np.int64),
+                "targets": np.array([1, 1, 0, 1, 1], dtype=np.int64),
+                "mask": np.array([0, 0, 0, 1, 1], dtype=np.int64),
+            },
+            {
+                "inputs": np.array([1, 1, 1, 0, 1], dtype=np.int64),
+                "targets": np.array([1, 1, 1, 0, 1], dtype=np.int64),
+                "mask": np.array([0, 0, 0, 0, 1], dtype=np.int64),
+            },
+        ]
+
+        assert_dataset(ds1, expected_ds)
+
+
+if __name__ == "__main__":
+    tf.test.main()
diff --git a/tests/data/preprocessing/inputs_test.py b/tests/data/preprocessing/inputs_test.py
new file mode 100644
index 000000000..70a54d99b
--- /dev/null
+++ b/tests/data/preprocessing/inputs_test.py
@@ -0,0 +1,1125 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.data.inputs.""" + +import itertools +import os + +import numpy as np + +from absl.testing import absltest, parameterized + +from trax.data.preprocessing import inputs as data +from trax.data.preprocessing.inputs import ConvertToUnicode + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../resources/data/testdata")) + + +def _spm_path(): + return os.path.join(_TESTDATA, "sentencepiece.model") + + +class InputsTest(parameterized.TestCase): + def test_convert_to_unicode(self): + def dataset1(): + yield (b"Audentes fortuna iuvat.", b"Fortune favors the bold.") + + def dataset2(): + yield (b"\x81aabb", b"Value") + + convert_function1 = ConvertToUnicode(keys=[0]) + convert_output1 = next(convert_function1(dataset1())) + self.assertEqual(convert_output1[0], "Audentes fortuna iuvat.") + self.assertEqual(convert_output1[1], b"Fortune favors the bold.") + self.assertIsInstance(convert_output1[0], str) + self.assertIsInstance(convert_output1[1], bytes) + + # Contains an invalid bytes array from the point of view of UTF-8. + try: + convert_function2 = ConvertToUnicode(keys=[0]) + convert_output2 = next(convert_function2(dataset2())) + except UnicodeDecodeError: + self.fail("ConvertToUnicode threw UnicodeDecodeError.") + self.assertEqual(convert_output2[0], "aabb") + self.assertIsInstance(convert_output2[0], str) + + @parameterized.named_parameters( + ("zero", 0), + ("negative", -5), + ) + def test_shuffle_data_raises_error_queue_size(self, queue_size): + samples = iter(range(10)) + with self.assertRaises(ValueError): + _ = list(data.shuffle(samples, queue_size)) + + @parameterized.named_parameters( + ("one", 1), + ("two", 2), + ("twenty", 20), + ) + def test_shuffle_data_queue_size(self, queue_size): + samples = iter(range(100, 200)) + shuffled_stream = data.shuffle(samples, queue_size) + first_ten = [next(shuffled_stream) for _ in range(10)] + + # Queue size limits how far ahead/upstream the current sample can reach. + self.assertLess(first_ten[0], 100 + queue_size) + self.assertLess(first_ten[3], 103 + queue_size) + self.assertLess(first_ten[9], 109 + queue_size) + + unshuffled_first_ten = list(range(100, 110)) + if queue_size == 1: # Degenerate case: no shuffling can happen. 
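+            # A queue of size one emits each sample as soon as it arrives,
+            # so the original order is preserved.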
+            self.assertEqual(first_ten, unshuffled_first_ten)
+        if queue_size > 1:
+            self.assertNotEqual(first_ten, unshuffled_first_ten)
+
+    @parameterized.named_parameters(
+        ("qsize_100_n_001", 100, 1),
+        ("qsize_100_n_099", 100, 99),
+        ("qsize_100_n_100", 100, 100),
+        ("qsize_100_n_101", 100, 101),
+        ("qsize_100_n_199", 100, 199),
+    )
+    def test_shuffle_data_yields_all_samples(self, queue_size, n_samples):
+        samples = iter(range(n_samples))
+        shuffled_stream = data.shuffle(samples, queue_size)
+        self.assertLen(list(shuffled_stream), n_samples)
+
+    def test_batch_data(self):
+        dataset = ((i, i + 1) for i in range(10))
+        batches = data.batch(dataset, 10)
+        batch = next(batches)
+        self.assertLen(batch, 2)
+        self.assertEqual(batch[0].shape, (10,))
+
+    def test_batch_data_padding(self):
+        dataset = (([1] * (10 - i), i + 1) for i in range(10))
+        batches = data.batch(dataset, 10)
+        batch = next(batches)
+        self.assertEqual(batch[0].shape, (10, 10))
+        self.assertTrue(np.array_equal(batch[0][-1], np.asarray([1] + 9 * [0])))
+
+    def test_batch_exception_size(self):
+        dataset = ((i, i + 1) for i in range(10))
+        with self.assertRaises(ValueError):
+            batches = data.batch(dataset, 0)
+            next(batches)
+
+    def test_serial(self):
+        dataset = lambda _: ((i, i + 1) for i in range(10))
+        batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
+        batch = next(batches())
+        self.assertLen(batch, 2)
+        self.assertEqual(batch[0].shape, (10,))
+
+    def test_serial_composes(self):
+        """Check that data.Serial works inside another data.Serial."""
+        dataset = lambda _: ((i, i + 1) for i in range(10))
+        serial1 = data.Serial(dataset, data.Shuffle(3))
+        batches = data.Serial(serial1, data.Batch(10))
+        batch = next(batches())
+        self.assertLen(batch, 2)
+        self.assertEqual(batch[0].shape, (10,))
+
+    def test_count_and_skip(self):
+        dataset = lambda _: ((i, i + 1) for i in range(10))
+        examples = data.Serial(dataset, data.CountAndSkip("toy_data"))
+        ex_generator = examples()
+        ex1 = next(ex_generator)
+        self.assertEqual(ex1, (0, 1))
+        self.assertEqual(data.data_counters["toy_data"], 1)
+        ex2 = next(ex_generator)
+        self.assertEqual(ex2, (1, 2))
+        self.assertEqual(data.data_counters["toy_data"], 2)
+        ex3 = next(examples())  # new generator, will skip
+        self.assertEqual(ex3, (2, 3))
+        self.assertEqual(data.data_counters["toy_data"], 3)
+        data.data_counters["toy_data"] = 0  # reset
+        ex4 = next(examples())  # new generator, was reset
+        self.assertEqual(ex4, (0, 1))
+        self.assertEqual(data.data_counters["toy_data"], 1)
+
+    def test_parallel(self):
+        """Basic test of the parallel combinator."""
+        dataset1 = lambda: (i for i in range(10))
+        dataset2 = lambda: (i for i in range(10, 20))
+        parallel = data.Parallel([dataset1, dataset2])
+        generator = parallel()
+
+        self.assertEqual(next(generator), 0)
+        self.assertEqual(next(generator), 10)
+        self.assertEqual(next(generator), 1)
+        self.assertEqual(next(generator), 11)
+        self.assertEqual(next(generator), 2)
+        self.assertEqual(next(generator), 12)
+
+    def test_parallel_with_gen_not_none(self):
+        """Test of the parallel combinator with a non-None generator."""
+        dataset1 = lambda _: (i for i in range(10))
+        dataset2 = lambda _: (i for i in range(10, 20))
+        parallel = data.Parallel([dataset1, dataset2])
+
+        def test_generator():
+            yield 0
+
+        generator = parallel(gen=test_generator)
+
+        self.assertEqual(next(generator), 0)
+        self.assertEqual(next(generator), 10)
+        self.assertEqual(next(generator), 1)
+        self.assertEqual(next(generator), 11)
+        self.assertEqual(next(generator), 2)
+        self.assertEqual(next(generator), 12)
+
+    def test_parallel_with_weights(self):
+        """Test of the parallel combinator with weights."""
+        dataset1 = lambda: (i for i in range(10))
+        dataset2 = lambda: (i for i in range(10, 20))
+        parallel = data.Parallel([dataset1, dataset2], counters=(2, 1))
+        generator = parallel()
+
+        self.assertEqual(next(generator), 0)
+        self.assertEqual(next(generator), 10)
+        self.assertEqual(next(generator), 1)
+        self.assertEqual(next(generator), 11)
+        self.assertEqual(next(generator), 2)
+        self.assertEqual(next(generator), 3)
+        self.assertEqual(next(generator), 12)
+        self.assertEqual(next(generator), 4)
+        self.assertEqual(next(generator), 5)
+        self.assertEqual(next(generator), 13)
+
+    def test_parallel_with_weights_and_minimum(self):
+        """Test of the parallel combinator with weights and minimum."""
+        dataset1 = lambda: (i for i in range(10))
+        dataset2 = lambda: (i for i in range(10, 110))
+        parallel = data.Parallel(
+            [dataset1, dataset2], counters=(10, 100), reweight_by_minimum=True
+        )
+        generator = parallel()
+
+        self.assertEqual(next(generator), 0)
+        self.assertEqual(next(generator), 10)
+        self.assertEqual(next(generator), 11)
+        self.assertEqual(next(generator), 12)
+        self.assertEqual(next(generator), 13)
+        self.assertEqual(next(generator), 14)
+        self.assertEqual(next(generator), 15)
+        self.assertEqual(next(generator), 16)
+        self.assertEqual(next(generator), 17)
+        self.assertEqual(next(generator), 18)
+        self.assertEqual(next(generator), 19)
+        self.assertEqual(next(generator), 1)
+        self.assertEqual(next(generator), 20)
+        self.assertEqual(next(generator), 21)
+        self.assertEqual(next(generator), 22)
+        self.assertEqual(next(generator), 23)
+        self.assertEqual(next(generator), 24)
+        self.assertEqual(next(generator), 25)
+        self.assertEqual(next(generator), 26)
+        self.assertEqual(next(generator), 27)
+        self.assertEqual(next(generator), 28)
+        self.assertEqual(next(generator), 29)
+        self.assertEqual(next(generator), 2)
+
+    def test_parallel_with_gradual_reweighting(self):
+        """Test of the parallel combinator with gradual reweighting."""
+        dataset1 = lambda: (i for i in itertools.cycle(range(1)))
+        dataset2 = lambda: (i for i in itertools.cycle(range(10, 30)))
+        dataset3 = lambda: (i for i in itertools.cycle(range(30, 70)))
+        parallel = data.Parallel(
+            [dataset2, dataset1, dataset3],
+            counters=(20, 1, 40),
+            gradually_reweight=True,
+        )
+        generator = parallel()
+
+        for _ in range(3):
+            self.assertEqual(next(generator), 0)
+            for i in range(20):
+                self.assertEqual(next(generator), 10 + i)
+                self.assertEqual(next(generator), 30 + 2 * i)
+                self.assertEqual(next(generator), 30 + 2 * i + 1)
+
+    def test_parallel_with_gradual_reweighting_remainders(self):
+        """Test of the parallel combinator with gradual reweighting and remainders."""
+        dataset1 = lambda: (i for i in itertools.cycle(range(1)))
+        dataset2 = lambda: (i for i in itertools.cycle(range(10, 30)))
+        dataset3 = lambda: (i for i in itertools.cycle(range(30, 80)))
+        parallel = data.Parallel(
+            [dataset2, dataset1, dataset3],
+            counters=(20, 1, 50),
+            gradually_reweight=True,
+            use_remainders=True,
+        )
+        generator = parallel()
+
+        for _ in range(3):
+            self.assertEqual(next(generator), 0)
+            for i in range(20):
+                self.assertEqual(next(generator), 10 + i)
+                self.assertEqual(next(generator), 30 + 2 * i)
+                self.assertEqual(next(generator), 30 + 2 * i + 1)
+            # Here we process the remainder from dataset 3:
+            for i in range(10):
+                self.assertEqual(next(generator), 70 + i)
+
+    def test_parallel_with_gradual_reweighting_remainders_big(self):
"""Test of the parallel ccmbinator with weights and minimum.""" + dataset1 = lambda: (i for i in itertools.cycle(range(1))) + dataset2 = lambda: (i for i in itertools.cycle(range(10, 30))) + dataset3 = lambda: (i for i in itertools.cycle(range(30, 80))) + dataset4 = lambda: (i for i in itertools.cycle(range(100, 220))) + parallel = data.Parallel( + [dataset2, dataset1, dataset4, dataset3], + counters=(20, 1, 120, 50), + gradually_reweight=True, + use_remainders=True, + ) + generator = parallel() + + for _ in range(3): + self.assertEqual(next(generator), 0) + for i in range(20): + self.assertEqual(next(generator), 10 + i) + for j in range(2): + self.assertEqual(next(generator), 30 + 2 * i + j) + for k in range(2): + self.assertEqual(next(generator), 100 + 2 * 2 * i + 2 * j + k) + # Here we process the remainder from datasets 3 and 4: + for i in range(10): + self.assertEqual(next(generator), 70 + i) + for i in range(40): + self.assertEqual(next(generator), 180 + i) + + def test_parallel_with_weights_three_datasets(self): + """Check that data.Serial works inside another data.Serial.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + dataset3 = lambda: (i for i in range(20, 30)) + parallel = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3)) + generator = parallel() + + self.assertEqual(next(generator), 0) # (1,0,0) + self.assertEqual(next(generator), 10) # (1,1,0) + self.assertEqual(next(generator), 20) # (1,1,1) + self.assertEqual(next(generator), 1) # (2,1,1) + self.assertEqual(next(generator), 21) # (2,1,2) + self.assertEqual(next(generator), 22) # (2,1,3) + self.assertEqual(next(generator), 2) # (1,0,0) + self.assertEqual(next(generator), 11) # (1,1,0) + self.assertEqual(next(generator), 23) # (1,1,1) + self.assertEqual(next(generator), 3) # (2,1,1) + self.assertEqual(next(generator), 24) # (2,1,2) + self.assertEqual(next(generator), 25) # (2,1,3) + self.assertEqual(next(generator), 4) # (1,0,0) + + def test_stack_parallel(self): + """Test of stacked parallel ccmbinators.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + dataset3 = lambda: (i for i in range(20, 30)) + parallel_lev0 = data.Parallel([dataset1, dataset2]) + parallel_lev1 = data.Parallel([parallel_lev0, dataset3]) + generator = parallel_lev1() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 20) + self.assertEqual(next(generator), 10) + self.assertEqual(next(generator), 21) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 22) + self.assertEqual(next(generator), 11) + self.assertEqual(next(generator), 23) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 24) + self.assertEqual(next(generator), 12) + + def test_parallel_with_zero_counters(self): + """Test of stacked parallel ccmbinators.""" + dataset1 = lambda: (i for i in range(10)) + dataset2 = lambda: (i for i in range(10, 20)) + dataset3 = lambda: (i for i in range(20, 30)) + parallel = data.Parallel([dataset1, dataset2, dataset3], counters=[1, 0, 1]) + generator = parallel() + + self.assertEqual(next(generator), 0) + self.assertEqual(next(generator), 20) + self.assertEqual(next(generator), 1) + self.assertEqual(next(generator), 21) + self.assertEqual(next(generator), 2) + self.assertEqual(next(generator), 22) + self.assertEqual(next(generator), 3) + self.assertEqual(next(generator), 23) + + def test_serial_with_python(self): + dataset = lambda _: ((i, i + 1) for i in range(10)) + batches = 
data.Serial( + dataset, + lambda g: map(lambda x: (x[0], x[1] + 1), g), + lambda g: filter(lambda x: x[0] % 2 == 1, g), + data.Batch(2), + ) + batch = next(batches()) + self.assertLen(batch, 2) + (xs, ys) = batch + # First tuple after filtering is (1, 3) = (1, 2+1). + self.assertEqual(xs[0], 1) + self.assertEqual(ys[0], 3) + # Second tuple after filtering is (3, 5). + self.assertEqual(xs[1], 3) + self.assertEqual(ys[1], 5) + + def test_pad_to_max_dims(self): + tensors1 = [np.zeros((3, 10)), np.ones((3, 10))] + padded1 = data.pad_to_max_dims(tensors1) + self.assertEqual(padded1.shape, (2, 3, 10)) + tensors2 = [np.zeros((2, 10)), np.ones((3, 9))] + padded2 = data.pad_to_max_dims(tensors2) + self.assertEqual(padded2.shape, (2, 3, 10)) + tensors3 = [np.zeros((8, 10)), np.ones((8, 9))] + padded3 = data.pad_to_max_dims(tensors3, 12) + self.assertEqual(padded3.shape, (2, 12, 12)) + tensors4 = [np.zeros((2, 10)), np.ones((3, 9))] + padded4 = data.pad_to_max_dims(tensors4, 12) + self.assertEqual(padded4.shape, (2, 4, 12)) + + def test_pad_to_length(self): + tensors1 = [(np.zeros((5)), np.ones((3)))] + pad_to_length_function1 = data.PadToLength( + len_map={0: 10, 1: 11}, pad_value={0: 0, 1: 1} + ) + padded1 = next(pad_to_length_function1(tensors1)) + self.assertEqual(padded1[0].shape, (10,)) + self.assertEqual(padded1[1].shape, (11,)) + + tensors2 = [(np.zeros((15)), np.ones((20)))] + pad_to_length_function2 = data.PadToLength( + len_map={0: 10, 1: 10}, pad_value={0: 0, 1: 1}, multiple=True + ) + padded2 = next(pad_to_length_function2(tensors2)) + self.assertEqual(padded2[0].shape, (20,)) + self.assertEqual(padded2[1].shape, (20,)) + + def test_concatenate_lm_input(self): + tensors1 = [(np.zeros((5)), np.ones((3)))] + + lm_input_function1 = data.ConcatenateToLMInput(pad_to_length=10) + lm_input_1 = next(lm_input_function1(tensors1)) + self.assertEqual(lm_input_1[0].shape, (10,)) + self.assertEqual(lm_input_1[1].shape, (10,)) + self.assertEqual(lm_input_1[2].shape, (10,)) + self.assertEqual( + lm_input_1[2].all(), + np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]]).all(), + ) + + tensors2 = [(np.zeros((5)), np.ones((3)))] + lm_input_function2 = data.ConcatenateToLMInput() + lm_input_2 = next(lm_input_function2(tensors2)) + self.assertEqual(lm_input_2[0].shape, (8,)) + self.assertEqual(lm_input_2[1].shape, (8,)) + self.assertEqual(lm_input_2[2].shape, (8,)) + self.assertEqual( + lm_input_2[2].all(), + np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]]).all(), + ) + + def test_truncate_to_length_no_arg(self): + """Tests that a no-arg call leaves shapes unchanged.""" + + def data_stream(): + while True: + yield (np.zeros((1, 5)), np.ones((1, 5))) + + stream_fn = data.TruncateToLength() + y0, y1 = next(stream_fn(data_stream())) + self.assertEqual(y0.shape, (1, 5)) + self.assertEqual(y1.shape, (1, 5)) + + @parameterized.named_parameters( + ("none", None, ((1, 5), (1, 5))), + ("large_values", {0: (1, 77), 1: (1, 88)}, ((1, 5), (1, 5))), + ("small_values", {0: (1, 3), 1: (1, 2)}, ((1, 3), (1, 2))), + ) + def test_truncate_to_length_len_map(self, len_map, out_shapes): + """Tests that truncation occurs when len_map values are small enough.""" + + def data_stream(): + while True: + yield (np.zeros((1, 5)), np.ones((1, 5))) + + stream_fn = data.TruncateToLength(len_map=len_map) + y0, y1 = next(stream_fn(data_stream())) + self.assertEqual(y0.shape, out_shapes[0]) + self.assertEqual(y1.shape, out_shapes[1]) + + def test_truncate_to_length_questionable_behavior(self): + # Use of np.reshape in 
TruncateToLength allows non-truncation results + # without warning. As long as the target shape (len_map value) is + # lexicographically prior to the data shape, then np.reshape can happen, + # even if it results in *adding* values to the overall array. + # + # This test passes as a marker of the questionable behavior, and should + # *fail* -- and then be removed -- when the function is + # clarified/re-implemented. + # + # TODO(jonni): Determine desired behavior, and fit implementation to it. + x = np.arange(21).reshape((1, 21, 1)) + + def data_stream(): + while True: + yield x + + stream_fn = data.TruncateToLength(len_map={0: (1, 4, 6)}) + (y,) = next(stream_fn(data_stream())) + self.assertEqual(y.shape, (1, 4, 6)) + self.assertEqual(y[0, 3, 1], 19) + self.assertEqual(y[0, 3, 2], 20) # end of original values [0..20] + self.assertEqual(y[0, 3, 3], 0) # added value + self.assertEqual(y[0, 3, 4], 1) # added value + self.assertEqual(y[0, 3, 5], 2) # added value + + def test_filter_empty_examples(self): + tensors1 = [ + (np.zeros((0,)), np.ones((1, 5))), + (np.zeros((1, 5)), np.ones((1, 5))), + ] + + filter_empty_examples_function1 = data.FilterEmptyExamples() + filtered1 = next(filter_empty_examples_function1(tensors1)) + self.assertEqual(filtered1[0].shape, (1, 5)) + self.assertEqual(filtered1[1].shape, (1, 5)) + + filter_empty_examples_function2 = data.FilterEmptyExamples(axes=[1]) + filtered2 = next(filter_empty_examples_function2(tensors1)) + self.assertEqual(filtered2[0].shape, (0,)) + self.assertEqual(filtered2[1].shape, (1, 5)) + + def test_append_value(self): + tensors1 = [(np.zeros((1, 5)), np.ones((1, 5)))] + + append_value_function1 = data.AppendValue() + unmodified = next(append_value_function1(tensors1)) + self.assertEqual(unmodified[0].shape, (1, 5)) + self.assertEqual(unmodified[1].shape, (1, 5)) + + append_value_function2 = data.AppendValue({0: [[5]], 1: [[4]]}) + appended = next(append_value_function2(tensors1)) + self.assertEqual(appended[0].shape, (1, 6)) + self.assertEqual( + appended[0].all(), np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 5.0]]).all() + ) + self.assertEqual(appended[1].shape, (1, 6)) + self.assertEqual( + appended[1].all(), np.array([[1.0, 1.0, 1.0, 1.0, 1.0, 4.0]]).all() + ) + + def test_pad_to_max_dims_boundary_list(self): + tensors = [np.zeros((1, 15, 31)), np.ones((2, 10, 35)), np.ones((4, 2, 3))] + padded_tensors = data.pad_to_max_dims(tensors, boundary=(None, 15, 20)) + # no boundary, only max in the first dim, 15 is already the max len in + # second dim, last dim padded to multiple of 20. + # The outer dim is the batch here. + self.assertEqual(padded_tensors.shape, (3, 4, 15, 40)) + + def test_pad_to_max_dims_strict_pad_on_len(self): + tensors = [np.ones((15,)), np.ones((12,)), np.ones((14,))] + padded_tensors = data.pad_to_max_dims( + tensors, boundary=10, strict_pad_on_len=True + ) + self.assertEqual(padded_tensors.shape, (3, 20)) + + def test_bucket_by_length(self): + def fake_generator(length, num_examples=1): + for _ in range(num_examples): + yield (np.ones((length,)), np.ones((length,))) + + def length_function(example): + return max(example[0].shape[0], example[1].shape[0]) + + batches = list( + data.bucket_by_length( + fake_generator(5, 6), length_function, [20], [2], strict_pad_on_len=True + ) + ) + + # We'll get three batches of 2 examples each. 
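+        # (fake_generator(5, 6) yields 6 length-5 examples; with batch size 2
+        # they form 3 batches, and strict_pad_on_len pads each sequence up to
+        # the bucket boundary of 20.)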
+ self.assertLen(batches, 3) + self.assertIsInstance(batches[0], tuple) + self.assertLen(batches[0], 2) + self.assertEqual((2, 20), batches[0][0].shape) + self.assertEqual((2, 20), batches[0][1].shape) + + @parameterized.named_parameters( + ("encdec_on", True), + ("encdec_off", False), + ) + def test_addition_inputs_exceptions(self, encdec): + vocab_size = 5 + batch_size = 256 + seq_length = 64 + # Check if max/min lengths are validated for train stream + with self.assertRaises(ValueError): + inputs = data.addition_inputs( + vocab_size=vocab_size, + batch_size=batch_size, + train_length=2, + eval_min_length=1, + eval_max_length=seq_length, + pad_to_multiple=seq_length, + encdec=encdec, + ) + train_stream = inputs.train_stream(n_devices=1) + for _ in range(10): + next(train_stream) + + # Check if max/min lengths are validated for eval stream + with self.assertRaises(ValueError): + inputs = data.addition_inputs( + vocab_size=vocab_size, + batch_size=batch_size, + train_length=seq_length, + eval_min_length=1, + eval_max_length=seq_length, + pad_to_multiple=seq_length, + encdec=True, + ) + eval_stream = inputs.eval_stream(n_devices=1) + for _ in range(10): + next(eval_stream) + + def test_addition_inputs_constraints(self): + vocab_size = 5 + batch_size = 256 + seq_length = 64 + inputs = data.addition_inputs( + vocab_size=vocab_size, + batch_size=batch_size, + train_length=seq_length, + eval_min_length=seq_length, + eval_max_length=seq_length, + pad_to_multiple=seq_length, + encdec=True, + ) + + # Check if max length is respected for train stream + train_stream = inputs.train_stream(n_devices=1) + for _ in range(10): + x, y, weights = next(train_stream) + self.assertEqual(x.shape[1], seq_length) + self.assertEqual(y.shape[1], seq_length) + self.assertEqual(weights.shape[1], seq_length) + + # Check if max length is respected for eval stream + eval_stream = inputs.eval_stream(n_devices=1) + for _ in range(10): + x, y, weights = next(eval_stream) + self.assertEqual(x.shape[1], seq_length) + self.assertEqual(y.shape[1], seq_length) + self.assertEqual(weights.shape[1], seq_length) + + def _get_span_lengths(self, x): + span_lengths = [] + curr_len = 0 + for i in range(1, len(x)): + # 1 -> 0 + if x[i] == 0 and x[i - 1] == 1: + span_lengths.append(curr_len) + curr_len = 0 + # 1 -> 1 or 0 -> 1 + elif (x[i] == 1 and x[i - 1] == 1) or (x[i] == 1 and x[i - 1] == 0): + curr_len += 1 + if curr_len != 0: + span_lengths.append(curr_len) + return span_lengths + + def test_random_spans_noise_mask(self): + length = 100 + noise_density = 0.15 + mean_noise_span_length = 3.0 + + # Take 5 random seed1, seed2 values. + for seed in np.random.randint(0, 100, (5, 2)): + is_noise = data.random_spans_noise_mask( + length, + noise_density, + mean_noise_span_length, + seed1=seed[0], + seed2=seed[1], + ) + is_noise = is_noise.astype(np.int32) + # noise_density fraction of tokens are produced + self.assertEqual(np.sum(is_noise), noise_density * length) + # Get span lengths and make sure the average is what we expect. 
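+            # With noise_density=0.15 and length=100 there are 15 noise
+            # tokens, so an average span length of 3.0 corresponds to 5 spans.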
+ actual_span_lengths = self._get_span_lengths(is_noise) + average_span_length = sum(actual_span_lengths) / len(actual_span_lengths) + self.assertEqual(mean_noise_span_length, average_span_length) + + @absltest.skip("The version of the dataset you are trying is to old") + def test_process_c4_with_span_corruption(self): + def process_c4_with_span_corruption( + spm_path=None, + extra_ids=0, + train=False, + max_length=100, + noise_density=0.15, + mean_noise_span_length=3.0, + seed1=None, + seed2=None, + ): + return data.Serial( + data.TFDS( + "c4/en:2.3.0", data_dir=_TESTDATA, keys=("text",), train=train + ), + data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids), + data.generate_sequential_chunks(max_length=max_length), + data.generate_random_noise_mask( + noise_density=noise_density, + mean_noise_span_length=mean_noise_span_length, + seed1=seed1, + seed2=seed2, + ), + data.consume_noise_mask(vocab_size=32000 + extra_ids), + data.FilterEmptyExamples(), + data.AppendValue(val={0: [1], 1: [1]}), + data.PadToLength(len_map={0: 100, 1: 30}, pad_value={0: 0, 1: 0}), + data.AddLossWeights(id_to_mask=0), + data.Batch(batch_size=2), + ) + + gen = process_c4_with_span_corruption(spm_path=_spm_path(), seed1=0, seed2=1) + + examples = [] + for i, ex in enumerate(gen()): + if i == 100: + break + examples.append(ex) + + self.assertLen(examples, 100) + example = examples[0] + + batched_input, batched_output, batched_loss_weights = example + + self.assertSequenceEqual( + batched_input.tolist(), + # pylint: disable=bad-continuation,bad-whitespace + [ + [ + 37, + 2335, + 113, + 3977, + 227, + 7306, + 45, + 3, + 9, + 4716, + 147, + 8, + 71, + 2658, + 65, + 118, + 4313, + 38, + 3, + 9, + 13065, + 32, + 31999, + 9, + 5704, + 26, + 109, + 6, + 6862, + 6, + 4728, + 45, + 8, + 3796, + 24093, + 11834, + 4716, + 30, + 8, + 1379, + 13, + 31998, + 130, + 718, + 12, + 8, + 24124, + 1343, + 300, + 4357, + 1714, + 31997, + 1373, + 47, + 16487, + 3168, + 16, + 321, + 7943, + 5, + 3, + 4868, + 3856, + 5700, + 75, + 7, + 200, + 2231, + 6, + 11163, + 9, + 6, + 113, + 47, + 5330, + 45, + 14354, + 6, + 47, + 31996, + 20721, + 3654, + 44, + 8, + 3112, + 5, + 14599, + 11, + 8067, + 31995, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 277, + 828, + 43, + 5899, + 46, + 16, + 10952, + 139, + 160, + 1687, + 56, + 539, + 30, + 2875, + 41, + 31122, + 2307, + 137, + 2702, + 2780, + 15, + 7, + 31999, + 44, + 8, + 3112, + 11, + 30, + 569, + 783, + 5, + 3, + 17701, + 6, + 2194, + 26, + 23, + 1336, + 6321, + 1694, + 30, + 31998, + 196, + 56, + 1852, + 1423, + 25, + 5, + 27, + 183, + 8032, + 31997, + 217, + 149, + 1513, + 11, + 2238, + 25, + 1800, + 5, + 96, + 2703, + 44, + 3065, + 12537, + 11163, + 9, + 535, + 71, + 9363, + 14886, + 646, + 44, + 8, + 3112, + 243, + 23281, + 12, + 8, + 31996, + 346, + 402, + 17, + 99, + 83, + 11, + 773, + 3668, + 1280, + 31995, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ], + # pylint: enable=bad-continuation,bad-whitespace + ) + + self.assertSequenceEqual( + batched_output.tolist(), + # pylint: disable=bad-continuation,bad-whitespace + [ + [ + 31999, + 1639, + 7, + 15480, + 5, + 11163, + 31998, + 2083, + 9997, + 5076, + 31997, + 265, + 11, + 8, + 31996, + 3, + 31995, + 1343, + 2487, + 106, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 31999, + 12, + 8, + 15480, + 130, + 646, + 31998, + 1376, + 10, + 96, + 31997, + 62, + 410, + 59, + 31996, + 96, + 31995, + 94, + 608, + 10, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ], + # pylint: 
enable=bad-continuation,bad-whitespace + ) + + self.assertSequenceEqual( + batched_loss_weights.tolist(), + # pylint: disable=bad-continuation,bad-whitespace + [ + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + ], + # pylint: enable=bad-continuation,bad-whitespace + ) + + def test_prefix_lm_last_output_batch_is_short(self): + prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) + examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7, 8]])) + self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) + self.assertSequenceEqual(([6, 7], [8]), examples[1]) + self.assertLen(examples, 2) + + def test_prefix_lm_last_input_batch_is_short(self): + prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) + examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6]])) + self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) + self.assertLen(examples, 1) + + def test_prefix_lm_last_input_batch_exists_but_no_output(self): + prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) + examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7]])) + self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) + self.assertLen(examples, 1) + + def test_unbatch(self): + unbatch_fn = data.UnBatch() + batched_inputs = [ + # First batch - 3 examples + ( + np.arange(3 * 2).reshape(3, -1), + np.arange(3 * 3).reshape(3, -1), + np.arange(3 * 4).reshape(3, -1), + ), + # Second batch - 4 examples + ( + np.arange(4 * 2).reshape(4, -1), + np.arange(4 * 3).reshape(4, -1), + np.arange(4 * 4).reshape(4, -1), + ), + ] + examples = list(unbatch_fn(batched_inputs)) + self.assertLen(examples, 3 + 4) + + def test_sine_shape(self): + inputs = data.sine_inputs(batch_size=3, length=5) + train_batch = next(inputs.train_stream(n_devices=1)) + eval_batch = next(inputs.eval_stream(n_devices=1)) + # (observations, actions, observations, mask) + self.assertLen(train_batch, 4) + self.assertLen(eval_batch, 4) + for (x, y) in zip(train_batch, eval_batch): + self.assertEqual(x.shape, (3, 5)) + self.assertEqual(y.shape, (3, 5)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/data/preprocessing/tf/dataset/bert_test.py b/tests/data/preprocessing/tf/dataset/bert_test.py new file mode 100644 index 000000000..e206b9afb --- /dev/null +++ b/tests/data/preprocessing/tf/dataset/bert_test.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.data.tf_inputs.""" + +import gin +import numpy as np +import tensorflow as tf + +from tests.data.utils import ( # relative import + _TESTDATA, +) +from trax.data.loader.tf.base import next_sentence_prediction_tf +from trax.data.preprocessing.inputs import batcher # noqa: F401 +from trax.data.preprocessing.tf import bert as inputs_bert + + +class InputsBertTest(tf.test.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def test_create_bert_inputs(self): + inputs_sentences_1 = [np.array([100, 150, 200])] + inputs_sentences_2 = [np.array([300, 500])] + labels = [np.array(1)] + + create_inputs_1 = inputs_bert.CreateBertInputs(False) + create_inputs_2 = inputs_bert.CreateBertInputs(True) + for res in create_inputs_1(zip(inputs_sentences_1, labels)): + values, segment_embs, _, label, weight = res + self.assertAllEqual(values, np.array([101, 100, 150, 200, 102])) + self.assertAllEqual(segment_embs, np.zeros(5)) + self.assertEqual(label, np.int64(1)) + self.assertEqual(weight, np.int64(1)) + + for res in create_inputs_2(zip(inputs_sentences_1, inputs_sentences_2, labels)): + values, segment_embs, _, label, weight = res + self.assertAllEqual( + values, np.array([101, 100, 150, 200, 102, 300, 500, 102]) + ) + exp_segment = np.concatenate((np.zeros(5), np.ones(3))) + self.assertAllEqual(segment_embs, exp_segment) + self.assertEqual(label, np.int64(1)) + self.assertEqual(weight, np.int64(1)) + + def test_bert_next_sentence_prediction_inputs(self): + stream = inputs_bert.BertNextSentencePredictionInputs( + "c4/en:2.3.0", data_dir=_TESTDATA, train=False, shuffle_size=1 + ) + exp_sent1 = "The woman who died after falling from" + exp_sent2 = "The woman who died after falling from" + sent1, sent2, label = next(stream()) + print(sent1, sent2, label) + + self.assertIn(exp_sent1, sent1, "exp_sent1 powinien być częścią sent1") + self.assertIn(exp_sent2, sent1, "exp_sent1 powinien być częścią sent1") + self.assertFalse(label) + + def test_mask_random_tokens(self): + """Test only standard tokens. + + This test deals with sentences composed of two parts: [100 CLS tokens, 100 + chosen standard tokens]. CLS is the token that is added at the beginning of + the sentence and there is only one token in standard scenario. It is never + masked because it is not a part of the sentence. + This tests whether mask_random_tokens will: + - mask only standard tokens + - mask expected number of tokens (15 percent candidates for masking) + """ + cls_token = 101 + mask_token = 103 + example_standard_token = 1001 + test_case_row = np.array([cls_token] * 100 + [example_standard_token] * 100) + test_case = [(test_case_row.copy(),)] + + out, original_tokens, token_weights = next( + inputs_bert.mask_random_tokens(test_case) + ) + # test whether original tokens are unchanged + self.assertAllEqual(test_case_row, original_tokens) + + self.assertEqual(1, token_weights.sum()) + self.assertEqual( + 15, (token_weights > 0).sum() + ) # we should have 15 candidates for masking + + # 101 is a special token, so only 1001 should be masked + self.assertAllEqual(out[:100], test_case_row[:100]) + + # Each candidate has 0.8 probability to be masked while others have 0, so + # no more than 15 tokens with MASK + self.assertLessEqual((out == mask_token).sum(), 15) + + def test_next_sentence_prediction_tf(self): + # Create dummy dataset with two examples. + def data_generator(): + yield {"text": "This is the first sentence. This is the second sentence."} + yield {"text": "Another example text. 
And a follow-up sentence."} + + output_signature = {"text": tf.TensorSpec(shape=(), dtype=tf.string)} + dataset = tf.data.Dataset.from_generator( + data_generator, output_signature=output_signature + ) + + preprocess = next_sentence_prediction_tf() + processed_ds = preprocess(dataset) + + # Collect results for analysis + examples = [] + for example in processed_ds.take(10): + examples.append( + { + "inputs": example["inputs"].numpy().decode("utf-8"), + "targets": example["targets"].numpy().decode("utf-8"), + } + ) + tf.print(example) + + # Check if we have at least some examples + self.assertGreater(len(examples), 0) + + for example in examples: + # Check the output structure + self.assertIn("inputs", example) + self.assertIn("targets", example) + + # Verify that outputs have correct format + inputs = example["inputs"] + self.assertIn("sentence1:", inputs) + self.assertIn("sentence2:", inputs) + + # Check if label is one of the expected values + self.assertIn(example["targets"], ["next", "not_next"]) + + # Extract sentences for further analysis + parts = inputs.split("sentence2:") + sent1_part = parts[0].strip() + sent1 = sent1_part.replace("sentence1:", "").strip() + sent2 = parts[1].strip() + + # Check if sentences are not empty + self.assertTrue(len(sent1) > 0) + self.assertTrue(len(sent2) > 0) + + # Check relationship between label and sentences + if example["targets"] == "next": + # For "next", both sentences should come from the same document + # We can't fully test this due to randomness, but we can check + # if the format matches the expected pattern + exp_sent1 = "This is the first sentence" + exp_sent2 = "This is the second sentence" + self.assertTrue( + (exp_sent1 in sent1 and exp_sent2 in sent2) + or ( + "Another example text" in sent1 + and "And a follow-up sentence" in sent2 + ) + or not (exp_sent1 in sent1 and "And a follow-up sentence" in sent2) + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/data/preprocessing/tf/dataset/c4_test.py b/tests/data/preprocessing/tf/dataset/c4_test.py new file mode 100644 index 000000000..8d751d742 --- /dev/null +++ b/tests/data/preprocessing/tf/dataset/c4_test.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.data.tf_inputs.""" +import collections + +import gin +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from tests.data.utils import ( # relative import + _TESTDATA, + _c4_dataset, + _spm_path, + _t5_gin_config, +) +from trax.data.loader.tf import base as ds +from trax.data.preprocessing.inputs import batcher # noqa: F401 +from trax.data.preprocessing.tf.c4 import c4_bare_preprocess_fn, c4_preprocess + + +class TFDatasetC4Test(tf.test.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def test_c4_bare_preprocess_fn(self): + dataset = _c4_dataset() + + example = list(tfds.as_numpy(dataset.take(1)))[0] + + # Targets are NOT in the example. 
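+        # (Raw C4 examples carry only a "text" field; tokenization below adds
+        # "inputs" and "targets".)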
+ self.assertNotIn("targets", example) + self.assertIn("text", example) + text = example["text"] + + # This should convert the dataset to an inputs/targets that are tokenized. + dataset = c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) + + example = list(tfds.as_numpy(dataset.take(1)))[0] + + # Earlier text is now stored in targets_pretokenized + self.assertIn("targets_pretokenized", example) + self.assertEqual(example["targets_pretokenized"], text) + + # Targets are now tokenized. + self.assertIn("targets", example) + self.assertIsInstance(example["targets"], np.ndarray) + self.assertEqual(example["targets"].dtype, np.int64) + self.assertGreater(len(example["targets"]), 0) + self.assertEqual(example["targets"][-1], 1) # we add EOS at the end. + + self.assertIn("inputs", example) + self.assertEqual(len(example["inputs"]), 171) + + def test_c4_preprocess(self): + def load_c4_dataset(split="train"): + dataset = _c4_dataset(split=split) + return dataset.map(lambda example: (example, example["text"])) + + def examine_processed_dataset(proc_dataset): + count = 0 + lengths = [] + for example in tfds.as_numpy(proc_dataset): + count += 1 + ex = example[0] + # Targets are in the example. + self.assertIn("targets", ex) + self.assertEqual(ex["targets"].dtype, np.int64) + lengths.append(len(ex["targets"])) + return count, lengths + + unfiltered_count = 0 + for example in tfds.as_numpy(load_c4_dataset()): + unfiltered_count += 1 + # Targets are NOT in the example. + self.assertNotIn("targets", example[0]) + + proc_dataset = c4_preprocess(load_c4_dataset(), False, 2048) + + # `examine_processed_dataset` has some asserts in it. + proc_count, char_lengths = examine_processed_dataset(proc_dataset) + + # Both the original and filtered datasets have examples. + self.assertGreater(unfiltered_count, 0) + self.assertGreater(proc_count, 0) + + # Because we filter out some entries on length. + self.assertLess(proc_count, unfiltered_count) + + # Preprocess using the sentencepiece model in testdata. + spc_proc_dataset = c4_preprocess( + load_c4_dataset(), False, 2048, tokenization="spc", spm_path=_spm_path() + ) + + spc_proc_count, spc_lengths = examine_processed_dataset(spc_proc_dataset) + + # spc shortens the target sequence a lot, should be almost equal to + # unfiltered + self.assertLessEqual(proc_count, spc_proc_count) + self.assertEqual(unfiltered_count, spc_proc_count) + + # Assert all spc_lengths are lesser than their char counterparts. + for spc_len, char_len in zip(spc_lengths, char_lengths): + self.assertLessEqual(spc_len, char_len) + + def test_c4(self): + gin.bind_parameter("c4_preprocess.max_target_length", 2048) + gin.bind_parameter("c4_preprocess.tokenization", "spc") + gin.bind_parameter("c4_preprocess.spm_path", _spm_path()) + + result = None + + try: + # Just make sure this doesn't throw. + result = ds.data_streams( + "c4", + data_dir=_TESTDATA, + input_name="targets", + target_name="text", + preprocess_fn=c4_preprocess, + ) + except Exception as e: + self.fail(f"data_streams() raised an unexpected exception: {e}") + + self.assertIsNotNone(result, "data_streams() returned None unexpectedly") + + def test_c4_bare_preprocess_fn_denoising_objective(self): + _t5_gin_config() + + dataset = _c4_dataset() + dataset = c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) + + example = list(tfds.as_numpy(dataset.take(1)))[0] + + # Assertions now. 
+ self.assertIn("targets", example) + targets = example["targets"] + self.assertIsInstance(targets, np.ndarray) + self.assertEqual(targets.dtype, np.int64) + self.assertGreater(len(targets), 0) + + self.assertIn("inputs", example) + _inputs = example["inputs"] # pylint: disable=invalid-name + self.assertIsInstance(_inputs, np.ndarray) + self.assertEqual(_inputs.dtype, np.int64) + self.assertGreater(len(_inputs), 0) + + # WHP inputs will have the bulk of the text. + self.assertGreater(len(targets), len(_inputs)) + + # WHP there will be one sentinel token in the inputs and targets. + # We new tokenizer so there is no sentinel any more + inputs_counter = collections.Counter(_inputs.tolist()) + targets_counter = collections.Counter(targets.tolist()) + self.assertEqual(0, inputs_counter[31999]) + self.assertEqual(0, targets_counter[31999]) + + self.assertEqual(0, inputs_counter[1]) + self.assertEqual(1, targets_counter[1]) + + def test_c4_pretrain(self): + _t5_gin_config() + + gin.bind_parameter("c4_bare_preprocess_fn.spm_path", _spm_path()) + + gin.bind_parameter("batcher.batch_size_per_device", 8) + gin.bind_parameter("batcher.eval_batch_size", 8) + gin.bind_parameter("batcher.max_eval_length", 50) + gin.bind_parameter("batcher.buckets", ([51], [8, 1])) + + result = None + + try: + # Just make sure this doesn't throw. + result = ds.data_streams( + "c4", + data_dir=_TESTDATA, + input_name="inputs", + target_name="targets", + bare_preprocess_fn=c4_bare_preprocess_fn, + ) + except Exception as e: + self.fail(f"data_streams() raised an unexpected exception: {e}") + + self.assertIsNotNone(result, "data_streams() returned None unexpectedly") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/data/preprocessing/tf/dataset/math_test.py b/tests/data/preprocessing/tf/dataset/math_test.py new file mode 100644 index 000000000..b52e32b15 --- /dev/null +++ b/tests/data/preprocessing/tf/dataset/math_test.py @@ -0,0 +1,175 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.data.tf_inputs.""" + + +import gin +import tensorflow as tf + +from trax.data.preprocessing.inputs import batcher # noqa: F401 +from trax.data.preprocessing.tf import math as dataset_math + + +class TFDatasetMathTest(tf.test.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def test_process_single_mathqa_example_0(self): + # This is the first problem in the MathQA dataset. + example = { + "Problem": "the banker ' s gain of a certain sum due 3 years hence at 10 % " + "per annum is rs . 36 . what is the present worth ?", + "Rationale": '"explanation : t = 3 years r = 10 % td = ( bg × 100 ) / tr = ( ' + "36 × 100 ) / ( 3 × 10 ) = 12 × 10 = rs . 120 td = ( pw × tr )" + " / 100 ⇒ 120 = ( pw × 3 × 10 ) / 100 ⇒ 1200 = pw × 3 pw = " + '1200 / 3 = rs . 400 answer : option a"', + "options": "a ) rs . 400 , b ) rs . 300 , c ) rs . 500 , d ) rs . 
350 , e ) " + "none of these", + "correct": "a", + "annotated_formula": "divide(multiply(const_100, divide(multiply(36, const_100), " + "multiply(3, 10))), multiply(3, 10))", + "linear_formula": "multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)|multiply(#2,const_100)|divide(#3,#1)|", + "category": "gain", + } + + ( + answer_num, + python_result, + python_program, + list_op, + list_num, + ) = dataset_math.process_single_mathqa_example(example) + + self.assertEqual(answer_num, 400) # we know it, because correct answer is a) + self.assertEqual(python_result, [3600.0, 30.0, 120.0, 12000.0, 400.0]) + + self.assertEqual( + python_program, + [ + "t0 = n2 * 100.0", + "t1 = n0 * n1", + "t2 = t0 / t1", + "t3 = t2 * 100.0", + "t4 = t3 / t1", + ], + ) + self.assertEqual( + list_op, + [ + "multiply(n2,const_100)", + "multiply(n0,n1)", + "divide(#0,#1)", + "multiply(#2,const_100)", + "divide(#3,#1)", + ], + ) + self.assertEqual(list_num, [3.0, 10.0, 36.0]) + + def test_process_single_mathqa_example_1(self): + # This is the third problem in the MathQA dataset. + example = { + "Problem": "sophia finished 2 / 3 of a book . she calculated that she " + "finished 90 more pages than she has yet to read . how long is her" + " book ?", + "Rationale": "let xx be the total number of pages in the book , then she " + "finished 23 ⋅ x 23 ⋅ x pages . then she has x − 23 ⋅ x = " + "13 ⋅ xx − 23 ⋅ x = 13 ⋅ x pages left . 23 ⋅ x − 13 " + "⋅ x = 9023 ⋅ x − 13 ⋅ x = 90 13 ⋅ x = 9013 ⋅ x = 90 x" + " = 270 x = 270 so the book is 270 pages long . answer : b", + "options": "a ) 229 , b ) 270 , c ) 877 , d ) 266 , e ) 281", + "correct": "b", + "annotated_formula": "divide(90, subtract(const_1, divide(2, 3)))", + "linear_formula": "divide(n0,n1)|subtract(const_1,#0)|divide(n2,#1)", + "category": "general", + } + + ( + answer_num, + python_result, + python_program, + list_op, + list_num, + ) = dataset_math.process_single_mathqa_example(example) + + self.assertEqual(answer_num, 270) # we know it, because correct answer is b) + self.assertAllClose( + python_result, [0.6666666666666666, 0.33333333333333337, 269.99999999999994] + ) + self.assertEqual( + python_program, ["t0 = n0 / n1", "t1 = 1.0 - t0", "t2 = n2 / t1"] + ) + self.assertEqual( + list_op, ["divide(n0,n1)", "subtract(const_1,#0)", "divide(n2,#1)"] + ) + self.assertEqual(list_num, [2.0, 3.0, 90.0]) + + def test_process_single_mathqa_example_with_import(self): + # This is a training MathQA problem which involve an import. + example = { + "Problem": "the length of a rectangular garden is three times its width . if " + "the area of the rectangular garden is 588 square meters , then " + "what is the width of the rectangular garden ?", + "Rationale": '"let x be the width of the garden . 3 x ^ 2 = 588 x ^ 2 = 196 x ' + '= 14 the answer is c ."', + "options": "a ) 12 , b ) 13 , c ) 14 , d ) 15 , e ) 16", + "correct": "c", + "annotated_formula": "sqrt(divide(588, const_3))", + "linear_formula": "divide(n0,const_3)|sqrt(#0)|", + "category": "geometry", + } + + ( + answer_num, + python_result, + python_program, + list_op, + list_num, + ) = dataset_math.process_single_mathqa_example(example) + + self.assertEqual(answer_num, 14) # we know it, because correct answer is c) + self.assertAllClose(python_result, [196, 14]) + self.assertEqual( + python_program, ["t0 = n0 / 3.0", "t1 = math.sqrt(max(0, t0))"] + ) + self.assertEqual(list_op, ["divide(n0,const_3)", "sqrt(#0)"]) + self.assertEqual(list_num, [588]) + + # Below we execute twice the Python program and once the DSL program. 
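+        # We exec() the generated Python source directly, run it again through
+        # execute_mathqa_program(), and finally interpret the linear formula
+        # with execute_mathqa_dsl_program(); all three must agree on 14.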
+ target_values = "import math\n" + problem = example["Problem"] + for i in range(len(list_num)): + target_values += "n{} = {}\n".format(i, list_num[i]) + problem += " n{} = {}".format(i, list_num[i]) + target_values += "\n".join(python_program[:-1]) + final_line = python_program[-1].split("=")[1] + target_values += "\nanswer ={}".format(final_line) + var_dict = {} + exec(target_values, globals(), var_dict) # pylint: disable=exec-used + self.assertAllClose(var_dict["answer"], 14) + self.assertAllClose( + dataset_math.execute_mathqa_program(problem, target_values.split("\n")), 14 + ) + self.assertAllClose( + dataset_math.execute_mathqa_dsl_program( + problem, [example["linear_formula"]] + ), + 14, + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/data/utils.py b/tests/data/utils.py new file mode 100644 index 000000000..acd03bf14 --- /dev/null +++ b/tests/data/utils.py @@ -0,0 +1,179 @@ +import os + +from typing import Any, Mapping, Optional, Sequence, Union + +import gin +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from absl.testing import absltest +from t5.data import preprocessors as t5_processors + +from trax.data.loader.tf import base as ds + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../resources/data/testdata")) +_CONFIG_DIR = os.path.normpath( + os.path.join(pkg_dir, "../../resources/supervised/configs/") +) +_SUPERVISED_TESTDATA = os.path.normpath( + os.path.join(pkg_dir, "../../resources/supervised/testdata") +) + +# _ProxyTest is required because py2 does not allow instantiating +# absltest.TestCase directly. +class _ProxyTest(absltest.TestCase): + """Instance of TestCase to reuse methods for testing.""" + + maxDiff = None + + def runTest(self): + pass + + +_pyunit_proxy = _ProxyTest() + + +def _test_dataset_ints(inp_lengths, tgt_lengths): + """Create a test dataset of int64 tensors of given shapes.""" + + def generator(): + for inp_len, tgt_len in zip(inp_lengths, tgt_lengths): + inp = np.ones([inp_len], dtype=np.int64) + tgt = np.ones([tgt_len], dtype=np.int64) + yield {"inputs": inp, "targets": tgt} + + types = {"inputs": tf.int64, "targets": tf.int64} + shapes = {"inputs": tf.TensorShape([None]), "targets": tf.TensorShape([None])} + return tf.data.Dataset.from_generator( + generator, output_types=types, output_shapes=shapes + ) + + +def _load_dataset(name, split="train"): + return tfds.load(name=name, split=split, data_dir=_TESTDATA, shuffle_files=False) + + +def _c4_dataset(split="train"): + return _load_dataset("c4:2.3.0", split=split) + + +def _spm_path(): + return os.path.join(_TESTDATA, "sentencepiece.model") + + +def _t5_gin_config(): + # The following pages worth of gin configuration are required because a lot + # of T5 functions have `gin.REQUIRED` in code, i.e. you cannot use these + # functions at all without having configured gin. + + noise_density = 0.15 + max_input_length = 50 + + # What preprocessors to apply - we select a random chunk of the document if + # it exceeds a certain lengths (`select_random_chunk`), then split up long + # examples (`split_tokens`) and finally the denoising objective (`denoise`). + # + # In addition to this T5 concates multiple documents together to reduce + # padding (`reduce_concat_tokens`) after `select_random_chunk`, but we skip + # that since we don't do sequence packing. 
+ gin.bind_parameter( + "unsupervised_preprocessors.preprocessors", + [ + ds._PREPROCESSOR_REGISTRY["select_random_chunk_t5"], + ds._PREPROCESSOR_REGISTRY["split_tokens_t5"], + ds._PREPROCESSOR_REGISTRY["denoise_t5"], + ], + ) + + # select_random_chunk + gin.bind_parameter("select_random_chunk.feature_key", "targets") + gin.bind_parameter("select_random_chunk.max_length", max_input_length) + + # reduce_concat_tokens + gin.bind_parameter("random_spans_helper.extra_tokens_per_span_inputs", 1) + gin.bind_parameter("random_spans_helper.extra_tokens_per_span_targets", 1) + gin.bind_parameter("random_spans_helper.inputs_length", max_input_length) + gin.bind_parameter("random_spans_helper.mean_noise_span_length", 3.0) + gin.bind_parameter("random_spans_helper.noise_density", noise_density) + + # split_tokens + gin.bind_parameter( + "split_tokens.max_tokens_per_segment", + t5_processors.random_spans_tokens_length(), + ) + + # denoise + gin.bind_parameter("denoise.inputs_fn", t5_processors.noise_span_to_unique_sentinel) + gin.bind_parameter("denoise.noise_density", noise_density) + gin.bind_parameter("denoise.noise_mask_fn", t5_processors.random_spans_noise_mask) + gin.bind_parameter( + "denoise.targets_fn", t5_processors.nonnoise_span_to_unique_sentinel + ) + + +def _maybe_as_bytes(v): + if isinstance(v, list): + return [_maybe_as_bytes(x) for x in v] + if isinstance(v, str): + return tf.compat.as_bytes(v) + return v + + +def assert_dataset( + dataset: tf.data.Dataset, + expected: Union[Mapping[str, Any], Sequence[Mapping[str, Any]]], + expected_dtypes: Optional[Mapping[str, tf.DType]] = None, + rtol=1e-7, + atol=0, +): + """Tests whether the entire dataset == expected or [expected]. + + Args: + dataset: a tf.data dataset + expected: either a single example, or a list of examples. Each example is a + dictionary. + expected_dtypes: an optional mapping from feature key to expected dtype. + rtol: the relative tolerance. + atol: the absolute tolerance. 
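+
+    Example (illustrative only, not from the original test suite):
+      ds = tf.data.Dataset.from_tensors({"inputs": [1, 2]})
+      assert_dataset(ds, {"inputs": [1, 2]})  # passes
+      assert_dataset(ds, {"inputs": [1, 3]})  # raises an assertion error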
+ """ + + if not isinstance(expected, list): + expected = [expected] + actual = list(tfds.as_numpy(dataset)) + _pyunit_proxy.assertEqual(len(actual), len(expected)) + + def _compare_dict(actual_dict, expected_dict): + _pyunit_proxy.assertEqual(set(actual_dict.keys()), set(expected_dict.keys())) + for key, actual_value in actual_dict.items(): + if isinstance(actual_value, dict): + _compare_dict(actual_value, expected_dict[key]) + elif isinstance(actual_value, tf.RaggedTensor) or isinstance( + actual_value, tf.compat.v1.ragged.RaggedTensorValue + ): + actual_value = actual_value.to_list() + np.testing.assert_array_equal( + np.array(actual_value, dtype=object), + np.array(_maybe_as_bytes(expected_dict[key]), dtype=object), + key, + ) + elif ( + isinstance(actual_value, np.floating) + or isinstance(actual_value, np.ndarray) + and np.issubdtype(actual_value.dtype, np.floating) + ): + np.testing.assert_allclose( + actual_value, expected_dict[key], err_msg=key, rtol=rtol, atol=atol + ) + else: + np.testing.assert_array_equal( + actual_value, _maybe_as_bytes(expected_dict[key]), key + ) + + for actual_ex, expected_ex in zip(actual, expected): + _compare_dict(actual_ex, expected_ex) + + if expected_dtypes: + actual_dtypes = {k: dataset.element_spec[k].dtype for k in expected_dtypes} + _pyunit_proxy.assertDictEqual(expected_dtypes, actual_dtypes) diff --git a/tests/fastmath/jax/config.py b/tests/fastmath/jax/config.py new file mode 100644 index 000000000..4c68441a5 --- /dev/null +++ b/tests/fastmath/jax/config.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + + +def bool_env(varname: str, default: bool) -> bool: + """Read an environment variable and interpret it as a boolean. + + True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1'; + false values are 'n', 'no', 'f', 'false', 'off', and '0'. + + Args: + varname: the name of the variable + default: the default boolean value + Raises: ValueError if the environment variable is anything else. 
+ """ + val = os.getenv(varname, str(default)) + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError("invalid truth value %r for environment %r" % (val, varname)) + + +class Config(object): + def __init__(self): + self.values = {} + self.meta = {} + self.FLAGS = NameSpace(self.read) + self.use_absl = False + + def update(self, name, val): + if self.use_absl: + setattr(self.absl_flags.FLAGS, name, val) + else: + self.check_exists(name) + if name not in self.values: + raise Exception("Unrecognized config option: {}".format(name)) + self.values[name] = val + + def read(self, name): + if self.use_absl: + return getattr(self.absl_flags.FLAGS, name) + else: + self.check_exists(name) + return self.values[name] + + def add_option(self, name, default, opt_type, meta_args, meta_kwargs): + if name in self.values: + raise Exception("Config option {} already defined".format(name)) + self.values[name] = default + self.meta[name] = (opt_type, meta_args, meta_kwargs) + + def check_exists(self, name): + if name not in self.values: + raise Exception("Unrecognized config option: {}".format(name)) + + def DEFINE_bool(self, name, default, *args, **kwargs): + self.add_option(name, default, bool, args, kwargs) + + def DEFINE_integer(self, name, default, *args, **kwargs): + self.add_option(name, default, int, args, kwargs) + + def DEFINE_string(self, name, default, *args, **kwargs): + self.add_option(name, default, str, args, kwargs) + + def DEFINE_enum(self, name, default, *args, **kwargs): + self.add_option(name, default, "enum", args, kwargs) + + def config_with_absl(self): + # Run this before calling `app.run(main)` etc + from absl import app, flags as absl_flags + + self.use_absl = True + self.absl_flags = absl_flags + absl_defs = { + bool: absl_flags.DEFINE_bool, + int: absl_flags.DEFINE_integer, + str: absl_flags.DEFINE_string, + "enum": absl_flags.DEFINE_enum, + } + + for name, val in self.values.items(): + flag_type, meta_args, meta_kwargs = self.meta[name] + absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) + + app.call_after_init(lambda: self.complete_absl_config(absl_flags)) + + def complete_absl_config(self, absl_flags): + for name, _ in self.values.items(): + self.update(name, getattr(absl_flags.FLAGS, name)) + + def parse_flags_with_absl(self): + global already_configured_with_absl + if not already_configured_with_absl: + import absl.flags + + self.config_with_absl() + absl.flags.FLAGS(sys.argv, known_only=True) + self.complete_absl_config(absl.flags) + already_configured_with_absl = True + + +class NameSpace(object): + def __init__(self, getter): + self._getter = getter + + def __getattr__(self, name): + return self._getter(name) + + +config = Config() +flags = config +FLAGS = flags.FLAGS + +already_configured_with_absl = False + +flags.DEFINE_bool( + "jax_enable_checks", + bool_env("JAX_ENABLE_CHECKS", False), + help="Turn on invariant checking (core.skip_checks = False)", +) + +flags.DEFINE_bool( + "tf_numpy_additional_tests", True, "Run tests added specifically for TF numpy" +) diff --git a/tests/fastmath/jax/lax_numpy_einsum_test.py b/tests/fastmath/jax/lax_numpy_einsum_test.py new file mode 100644 index 000000000..691d9102d --- /dev/null +++ b/tests/fastmath/jax/lax_numpy_einsum_test.py @@ -0,0 +1,372 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +from collections import defaultdict # pylint: disable=g-importing-member + +import numpy as np +import tensorflow.compat.v2 as tf + +from absl.testing import absltest, parameterized + +import tests.fastmath.jax.utils as jtu +import trax.tf.numpy as jnp + +from tests.fastmath.jax.config import config + +config.parse_flags_with_absl() + + +class EinsumTest(jtu.TestCase): + def _check(self, s, *ops): + a = np.einsum(s, *ops) + b = jnp.einsum(s, *ops) + self.assertAllClose(a, b, check_dtypes=True, atol=1e-4, rtol=1e-4) + + def test_three_operands_1(self): + r = self.rng() + x = r.randn(3) + y = r.randn(4) + z = r.randn(5) + s = "i,j,k->ijk" + self._check(s, x, y, z) + + def test_three_operands_2(self): + r = self.rng() + x = r.randn(3) + y = r.randn(4) + z = r.randn(5) + s = "i,j,k->ijk" + self._check(s, x, y, z) + + def test_two_operands_1(self): + r = self.rng() + x = r.randn(3, 4) + y = r.randn(4) + s = "ij,j->i" + self._check(s, x, y) + + def test_two_operands_2(self): + r = self.rng() + x = r.randn(3, 4, 5) + y = r.randn(4) + s = "ijk,j->i" + self._check(s, x, y) + + def test_two_operands_3(self): + r = self.rng() + x = r.randn(3, 4, 3) + y = r.randn(3) + s = "iji,i->j" + self._check(s, x, y) + + def test_two_operands_4(self): + r = self.rng() + x = r.randn(3, 4) + y = r.randn(3, 4) + s = "ij,ij->" + self._check(s, x, y) + + def test_two_operands_5(self): + r = self.rng() + x = r.randn(10, 2, 3) + y = r.randn(3, 4) + s = "nij,jk->nik" + self._check(s, x, y) + + def test_two_operands_6(self): + # based on https://github.com/google/jax/issues/37#issuecomment-448572187 + r = self.rng() + x = r.randn(2, 1) + y = r.randn(2, 3, 4) + s = "sa,shb->shab" + self._check(s, x, y) + + def test_one_operand_1(self): + r = self.rng() + x = r.randn(3, 4, 5) + s = "ijk->j" + self._check(s, x) + + def test_one_operand_2(self): + r = self.rng() + x = r.randn(3, 4, 5) + s = "ijk->kij" + self._check(s, x) + + def test_one_operand_3(self): + r = self.rng() + x = r.randn(3, 4, 5) + s = "ijk->ki" + self._check(s, x) + + def test_one_operand_4(self): + r = self.rng() + x = r.randn(3, 4, 5) + s = "ijk->ki" + self._check(s, x) + + def test_one_operand_5(self): + r = self.rng() + x = r.randn(2, 3, 4, 5) + s = "...ijk->...ki" + self._check(s, x) + + def test_one_operand_6(self): + r = self.rng() + x = r.randn(3, 4, 5) + s = "...ijk->ki" + self._check(s, x) + + def test_one_operand_7(self): + r = self.rng() + x = r.randn(3, 3) + s = "ii->" + self._check(s, x) + + def test_one_operand_8(self): + r = self.rng() + x = r.randn(3, 3) + s = "ij->" + self._check(s, x) + + def test_one_operand_9(self): + r = self.rng() + x = r.randn(3, 3, 3) + s = "iii->" + self._check(s, x) + + def test_one_operand_10(self): + r = self.rng() + x = r.randn(3, 3) + s = "ii->i" + self._check(s, x) + + def test_one_operand_11(self): + r = self.rng() + x = r.randn(3, 3, 4) + s = "iij->i" + self._check(s, x) + + def test_one_operand_12(self): + r = self.rng() + x = r.randn(3, 3, 3) 
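+        # "iii->i" keeps only the generalized diagonal entries x[i, i, i];
+        # _check compares jnp.einsum against onp.einsum on this spec.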
+ s = "iii->i" + self._check(s, x) + + def test_one_operand_13(self): + r = self.rng() + x = r.randn(3, 3, 5, 4, 4) + s = "iijkk->i" + self._check(s, x) + + def test_one_operand_14(self): + r = self.rng() + x = r.randn(3, 3, 5, 4, 4) + s = "iijkk->ik" + self._check(s, x) + + def test_one_operand_15(self): + r = self.rng() + x = r.randn(3, 3, 5, 4, 4) + s = "iijkl->il" + self._check(s, x) + + def test_one_operand_16(self): + r = self.rng() + x = r.randn(3, 3) + s = "ij->ij" + self._check(s, x) + + def test_tf_unsupported_1(self): + # from https://www.tensorflow.org/api_docs/python/tf/einsum + r = self.rng() + x = r.randn(2, 3, 5, 1) + y = r.randn(3, 4, 5, 1) + s = "ij...,jk...->ik..." + self._check(s, x, y) + + def test_tf_unsupported_2(self): + # from https://www.tensorflow.org/api_docs/python/tf/einsum + r = self.rng() + x = r.randn(2, 3, 3) + y = r.randn(4) + s = "ijj,k->ik" + self._check(s, x, y) + + def test_tf_unsupported_3(self): + # from https://www.tensorflow.org/api_docs/python/tf/einsum + r = self.rng() + x = r.randn(2, 3) + y = r.randn(2, 3) + z = r.randn(3, 4) + s = "ij,ij,jk->ik" + self._check(s, x, y, z) + + # these tests are based on https://github.com/dask/dask/pull/3412/files + @parameterized.named_parameters( + { + "testcase_name": "_{}_dtype={}".format( + einstr, dtype.__name__ + ), # pylint: disable=g-complex-comprehension + "einstr": einstr, + "dtype": dtype, + } + for einstr in [ + "abc,bad->abcd", + "abcdef,bcdfg->abcdeg", + "ea,fb,abcd,gc,hd->efgh", + "ab,b", + "aa", + "a,a->", + "a,a->a", + "a,a", + "a,b", + "a,b,c", + "a", + "ba,b", + "ba,b->", + "defab,fedbc->defac", + "ab...,bc...->ac...", + "a...a", + "abc...->cba...", + "...ab->...a", + "a...a->a...", + # Following 2 from # https://stackoverflow.com/a/19203475/1611416 + "...abc,...abcd->...d", + "ab...,b->ab...", + # https://github.com/dask/dask/pull/3412#discussion_r182413444 + "aa->a", + "ab,ab,c->c", + "aab,bc->ac", + "aab,bcc->ac", + "fdf,cdd,ccd,afe->ae", + "fff,fae,bef,def->abd", + ] + # TODO(wangpeng): Add jnp.bool_ to dtype list + for dtype in [jnp.float32, jnp.int32, jnp.complex64] + ) + def test_from_dask(self, einstr, dtype): + r = jtu.rand_default() + if "->" in einstr: + input_str, _ = einstr.split("->") + else: + input_str = einstr + input_names = input_str.split(",") + + dims = itertools.cycle([2, 3, 4]) + shapes = defaultdict(lambda: next(dims)) + input_shapes = [ + tuple(shapes[c] for c in names.replace("...", "01")) + for names in input_names + ] + operands = [r(shape, dtype) for shape in input_shapes] + + self._check(einstr, *operands) + + def test_ordered_front_batch_dim_case(self): + x = np.ones((1, 8, 20, 4)) + y = np.ones((1, 8, 20, 4)) + s = "ijkl,ijml->ijkm" + self._check(s, x, y) + + # pylint: disable=invalid-name + def test_einsum_path(self): + # just check examples from np.einsum_path docstring + a = self.rng().rand(2, 2) + b = self.rng().rand(2, 5) + c = self.rng().rand(5, 2) + + path_info = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy") + self.assertEqual(str(path_info[0]), "['einsum_path', (1, 2), (0, 1)]") + self.assertEqual( + path_info[1].split("\n")[0], " Complete contraction: ij,jk,kl->il" + ) + + # check this doesn't crash + I = self.rng().rand(10, 10, 10, 10) + C = self.rng().rand(10, 10) + np.einsum_path("ea,fb,abcd,gc,hd->efgh", C, C, I, C, C, optimize="greedy") + + @jtu.disable + def test_einsum_kpmurphy_example(self): + # code from an email with @murphyk + N = 2 + C = 3 + D = 4 + K = 5 + T = 6 + r = self.rng() + S = r.randn(N, T, K) + W = r.randn(K, D) + V 
= r.randn(D, C) + L = np.zeros((N, C)) + for n in range(N): + for c in range(C): + s = 0 + for d in range(D): + for k in range(K): + for t in range(T): + s += S[n, t, k] * W[k, d] * V[d, c] + L[n, c] = s + + path = jnp.einsum_path("ntk,kd,dc->nc", S, W, V, optimize="optimal")[0] + rtol = 1e-2 if jtu.device_under_test() == "tpu" else None + self.assertAllClose( + L, + jnp.einsum("ntk,kd,dc->nc", S, W, V, optimize=path), + check_dtypes=False, + rtol=rtol, + ) + + # pylint: enable=invalid-name + + @jtu.disable + def test_contraction_broadcasting(self): + r = self.rng() + x = r.randn(3, 4, 5) + y = r.randn(3, 1, 6) + s = "cij,cjk->cik" + self._check(s, x, y) + + @jtu.disable + def test_batch_broadcasting(self): + r = self.rng() + x = r.randn(1, 4, 5) + y = r.randn(3, 5, 6) + s = "cij,cjk->cik" + self._check(s, x, y) + + @jtu.disable + def test_batch_and_contraction_broadcasting(self): + r = self.rng() + x = r.randn(1, 4, 5) + y = r.randn(3, 1, 6) + s = "cij,cjk->cik" + self._check(s, x, y) + + @jtu.disable + def test_broadcasting_issue_2189(self): + r = self.rng() + x = r.randn(2, 1, 3, 3) + y = r.randn(2, 4, 3) + s = "...ij,...j" + self._check(s, x, y) + + +if __name__ == "__main__": + tf.enable_v2_behavior() + absltest.main() diff --git a/tests/fastmath/jax/lax_numpy_indexing_test.py b/tests/fastmath/jax/lax_numpy_indexing_test.py new file mode 100644 index 000000000..c261fa56a --- /dev/null +++ b/tests/fastmath/jax/lax_numpy_indexing_test.py @@ -0,0 +1,1331 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import enum +import itertools + +from functools import partial + +import numpy as onp +import tensorflow.compat.v2 as tf + +from absl.testing import absltest, parameterized + +import tests.fastmath.jax.utils as jtu +import trax.tf.extensions as npe +import trax.tf.numpy as jnp + +from tests.fastmath.jax.config import config + +config.parse_flags_with_absl() + + +# We disable the whitespace continuation check in this file because otherwise it +# makes the test name formatting unwieldy. +# pylint: disable=bad-continuation +# We also disable undefined-variable till we start enabling tests. 
+# pylint: disable=undefined-variable + + +def subvals(lst, replace): + lst = list(lst) + for i, v in replace: + lst[i] = v + return tuple(lst) + + +float_dtypes = [onp.float32, onp.float64] +int_dtypes = [onp.int32, onp.int64] +bool_types = [onp.bool_] +default_dtypes = float_dtypes + int_dtypes +all_dtypes = float_dtypes + int_dtypes + bool_types + +IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"]) + + +suppress_deprecated_indexing_warnings = partial( + jtu.ignore_warning, category=FutureWarning, message="Using a non-tuple sequence.*" +) + + +STATIC_INDEXING_TESTS = [ + ( + "OneIntIndex", + [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ], + ), + ( + "TwoIntIndices", + [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ], + ), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ( + "OneSliceIndex", + [ + IndexSpec(shape=(10,), indexer=slice(1, 3)), + IndexSpec(shape=(10,), indexer=slice(1, -1)), + IndexSpec(shape=(10,), indexer=slice(None, -1)), + IndexSpec(shape=(10,), indexer=slice(None, None, None)), + IndexSpec(shape=(10, 8), indexer=slice(1, 3)), + IndexSpec(shape=(10, 8), indexer=slice(1, None)), + IndexSpec(shape=(10, 8), indexer=slice(None, 3)), + IndexSpec(shape=(10, 8), indexer=slice(-3, None)), + ], + ), + ( + "OneSliceIndexNegativeStride", + [ + IndexSpec(shape=(10,), indexer=slice(3, 1, -1)), + IndexSpec(shape=(10,), indexer=slice(1, 8, -1)), # empty result + IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10,), indexer=slice(None, None, -1)), + IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1)), + IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1)), # empty result + IndexSpec(shape=(10, 8), indexer=slice(None, None, -1)), + ], + ), + ( + "OneSliceIndexNonUnitStride", + [ + IndexSpec(shape=(10,), indexer=slice(0, 8, 2)), + IndexSpec(shape=(10,), indexer=slice(0, 8, 3)), + IndexSpec(shape=(10,), indexer=slice(1, 3, 2)), + IndexSpec(shape=(10,), indexer=slice(1, None, 2)), + IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, 2)), + IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, -2)), + ], + ), + ( + "TwoSliceIndices", + [ + IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2))), + IndexSpec(shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None))), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2))), + ], + ), + ( + "OneColonIndex", + [ + IndexSpec(shape=(3,), indexer=slice(None)), + IndexSpec(shape=(3, 4), indexer=slice(None)), + ], + ), + ( + "MultipleColonIndices", + [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), + ], + ), + ( + "MixedSliceIndices", + [ + IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2))), + IndexSpec(shape=(10, 4), indexer=(1, slice(None))), + ], + ), + ( + "EllipsisIndex", + [ + IndexSpec(shape=(3,), indexer=Ellipsis), + IndexSpec(shape=(3, 4), indexer=Ellipsis), + 
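+            # NumPy semantics: Ellipsis expands to the full slices needed to
+            # cover the remaining axes, so (0, Ellipsis) means x[0, :, :] and
+            # (Ellipsis, 2, 3) means x[:, 2, 3].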
IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), + ], + ), + ( + "NoneIndex", + [ + IndexSpec(shape=(), indexer=None), + IndexSpec(shape=(), indexer=(None, None)), + IndexSpec(shape=(), indexer=(Ellipsis, None)), + IndexSpec(shape=(3,), indexer=None), + IndexSpec(shape=(3, 4), indexer=None), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), + ], + ), + ( + "EmptyIndex", + [ + IndexSpec(shape=(), indexer=()), + IndexSpec(shape=(3,), indexer=()), + IndexSpec(shape=(3, 4), indexer=()), + ], + ), +] + +STATIC_INDEXING_GRAD_TESTS = [ + ( + "OneIntIndex", + [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ], + ), + ( + "TwoIntIndices", + [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ], + ), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ( + "OneSliceIndex", + [ + IndexSpec(shape=(5,), indexer=slice(1, 3)), + IndexSpec(shape=(5,), indexer=slice(1, -1)), + IndexSpec(shape=(5,), indexer=slice(None, -1)), + IndexSpec(shape=(5,), indexer=slice(None, None, None)), + IndexSpec(shape=(5, 4), indexer=slice(1, 3)), + IndexSpec(shape=(5, 4), indexer=slice(1, None)), + IndexSpec(shape=(5, 4), indexer=slice(None, 3)), + IndexSpec(shape=(5, 4), indexer=slice(-3, None)), + ], + ), + ( + "TwoSliceIndices", + [ + IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4), indexer=(slice(1, None), slice(None, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, None), slice(0, 2))), + ], + ), + ( + "OneColonIndex", + [ + IndexSpec(shape=(3,), indexer=slice(None)), + IndexSpec(shape=(3, 4), indexer=slice(None)), + ], + ), + ( + "MultipleColonIndices", + [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), + ], + ), + ( + "MixedSliceIndices", + [ + IndexSpec(shape=(5, 4), indexer=(slice(None), slice(0, 2))), + IndexSpec(shape=(5, 4), indexer=(1, slice(None))), + ], + ), + ( + "EllipsisIndex", + [ + IndexSpec(shape=(3,), indexer=Ellipsis), + IndexSpec(shape=(3, 4), indexer=Ellipsis), + IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), + ], + ), + ( + "NoneIndex", + [ + IndexSpec(shape=(), indexer=None), + IndexSpec(shape=(), indexer=(None, None)), + IndexSpec(shape=(), indexer=(Ellipsis, None)), + IndexSpec(shape=(3,), indexer=None), + IndexSpec(shape=(3, 4), indexer=None), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), + ], + ), + # TODO(mattjj): these fail for uninteresting dtype reasons + # ("EmptyIndex", + # [IndexSpec(shape=(), indexer=()), + # IndexSpec(shape=(3,), indexer=()), + # IndexSpec(shape=(3, 4), indexer=()), + # ]), +] + +ADVANCED_INDEXING_TESTS = [ + ( + "One1DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([0, 1])), + IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), + IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), + 
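+            # As in NumPy, negative values in an index array count back from
+            # the end of the axis.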
IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), + IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), + ], + ), + ( + "One2DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])), + IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])), + IndexSpec( + shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]]) + ), + ], + ), + ( + "Two1DIntArrayIndicesNoBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), onp.array([1, 2]))), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2, 0, 1]), onp.array([-1, 0, -1, 2])), + ), + ], + ), + ( + "Two1DIntArrayIndicesWithBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), onp.array([1, 2]))), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([[0, 2, 0, 1]]), onp.array([-1, 0, -1, 2])), + ), + ], + ), + ( + "TupleOfListsOfPythonInts", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]])), + ], + ), + ( + "TupleOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, onp.array([[2, 3, 0, 3]]))), + ], + ), + ( + "TupleOfListsOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), + IndexSpec( + shape=(3, 4, 5), indexer=([[0], [-1]], onp.array([[2, 3, 0, 3]])) + ), + ], + ), +] + +ADVANCED_INDEXING_TESTS_NO_REPEATS = [ + ( + "One1DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([0, 1])), + IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 0])), + IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), + IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), + # Fails with a TF/XLA error. 
+ # IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), + ], + ), + ( + "One2DIntArrayIndex", + [ + IndexSpec(shape=(3,), indexer=onp.array([[0, 1]])), + IndexSpec(shape=(6, 6), indexer=onp.array([[1, 2, 0], [3, 4, -1]])), + ], + ), + ( + "Two1DIntArrayIndicesNoBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), onp.array([1, 2]))), + IndexSpec( + shape=(4, 5, 6), + indexer=(onp.array([0, 2, 1, 3]), onp.array([-1, 0, -2, 1])), + ), + ], + ), + ( + "Two1DIntArrayIndicesWithBroadcasting", + [ + IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), onp.array([1, 2]))), + IndexSpec( + shape=(4, 5, 6), + indexer=(onp.array([[0, 2, -1, 1]]), onp.array([-1, 0, -2, 2])), + ), + ], + ), + ( + "TupleOfListsOfPythonInts", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]])), + ], + ), + ( + "TupleOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, onp.array([[2, 3, 0]]))), + ], + ), + ( + "TupleOfListsOfPythonIntsAndIntArrays", + [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], onp.array([[2, 3, 0]]))), + ], + ), +] + +MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [ + ( + "SlicesAndOneIntArrayIndex", + [ + IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))), + IndexSpec(shape=(2, 3), indexer=(slice(0, 2), onp.array([0, 2]))), + IndexSpec( + shape=(3, 4, 5), indexer=(Ellipsis, onp.array([0, 2]), slice(None)) + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(Ellipsis, onp.array([[0, 2], [1, 3]]), slice(None)), + ), + ], + ), + ( + "SlicesAndTwoIntArrayIndices", + [ + IndexSpec( + shape=(3, 4, 5), + indexer=(Ellipsis, onp.array([0, 2]), onp.array([-1, 2])), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), Ellipsis, onp.array([-1, 2])), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), onp.array([-1, 2]), Ellipsis), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), onp.array([-1, 2]), slice(1, 3)), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), slice(1, 3), onp.array([-1, 2])), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=( + onp.array([0, 2, -2]), + slice(None, None, 2), + onp.array([-1, 2, 1]), + ), + ), + ], + ), + ( + "NonesAndIntArrayIndices", + [ + IndexSpec( + shape=(3, 4, 5), indexer=(onp.array([0, 2]), None, onp.array([-1, 2])) + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(onp.array([0, 2]), None, None, onp.array([-1, 2])), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=(Ellipsis, onp.array([0, 2]), None, None, onp.array([-1, 2])), + ), + ], + ), + ( + "IntArrayWithInt32Type", + [IndexSpec(shape=(3, 4), indexer=(Ellipsis, onp.array(1, dtype=onp.int32)))], + ), +] + +MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [ + ( + "SlicesAndOneIntArrayIndex", + [ + IndexSpec( + shape=(3, 4, 5), + indexer=(Ellipsis, onp.array([[0, 2], [1, 1]]), slice(None)), + ), + ], + ), + ( + "SlicesAndTwoIntArrayIndices", + [ + IndexSpec( + shape=(3, 4, 5), + indexer=( + onp.array([0, 2, -2]), + slice(None, None, 2), + onp.array([-1, 2, -1]), + ), + ), + IndexSpec( + shape=(3, 4, 5), + indexer=( + onp.array([[0, 2], [2, 0]]), + Ellipsis, + onp.array([[1, 0], [1, 0]]), + ), + ), + ], + ), +] + + +def dynamic_slice_reference(operand, start_indices, slice_sizes): + out = onp.zeros(slice_sizes, dtype=operand.dtype) + idx = tuple( + slice(start, start + size) for start, 
size in zip(start_indices, slice_sizes) + ) + section = operand[idx] + out[tuple(slice(None, stop) for stop in section.shape)] = section + return out + + +def dynamic_update_slice_reference(operand, update, start_indices): + slices = tuple(map(slice, start_indices, onp.add(start_indices, update.shape))) + updated_operand = onp.copy(operand) + updated_operand[slices] = update + return updated_operand + + +class IndexingTest(jtu.TestCase): + """Tests for Numpy indexing translation rules.""" + + @parameterized.named_parameters( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testStaticIndexing(self, shape, dtype, rng_factory, indexer): + # TODO(rohanj): Revisit passing in self.rng() to this to customize further. + # This would need updating lax_numpy_test as well. + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype)] + onp_fun = lambda x: x[indexer] + jnp_fun = lambda x: jnp.asarray(x)[indexer] + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def _ReplaceSlicesWithTuples(self, idx): + """Helper method to replace slices with tuples for dynamic indexing args.""" + if isinstance(idx, slice): + triple = idx.start, idx.stop, idx.step + isnone = [i for i, elt in enumerate(triple) if elt is None] + zeros = itertools.repeat(0) + nones = itertools.repeat(None) + out = subvals(triple, zip(isnone, zeros)) + return out, lambda out: slice(*subvals(out, zip(isnone, nones))) + elif isinstance(idx, (tuple, list)) and idx: + t = type(idx) + elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) + return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) + else: + return idx, lambda x: x + + @parameterized.named_parameters( + { + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in [ + ( + "OneSliceIndex", + [ + IndexSpec(shape=(5,), indexer=slice(1, 3)), + IndexSpec(shape=(5, 4), indexer=slice(1, 3)), + ], + ), + ( + "TwoSliceIndices", + [ + IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), + ], + ), + ( + "NonUnitStrides", + [ + IndexSpec(shape=(3,), indexer=slice(None, None, -1)), + IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), + IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)), + ], + ), + ( + "OnlyStartOrStopDynamic", + [ + IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), + ], + ), + ] + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testDynamicIndexingWithSlices(self, shape, dtype, rng_factory, indexer): + rng = rng_factory() + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + def onp_fun(x, unpacked_indexer): + indexer = pack_indexer(unpacked_indexer) + return x[indexer] + + jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) + + args_maker = lambda: [rng(shape, 
dtype), unpacked_indexer] + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + # TODO(wangpeng): check_xla_forced_compile is turned off because some + # compile-time-constant requirements are violated. Investigate and turn it + # on. + self._CompileAndCheck( + jnp_fun, + args_maker, + check_dtypes=True, + check_eval_on_shapes=False, + check_incomplete_shape=True, + check_xla_forced_compile=False, + ) + + @parameterized.named_parameters( + { + "testcase_name": "{}_inshape={}_indexer={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in [ + ( + "OneIntIndex", + [ + IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ], + ), + ( + "TwoIntIndices", + [ + IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ], + ), + ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ] + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer): + # TODO(rohanj): Revisit passing in self.rng() to this to customize further. + # This would need updating lax_numpy_test as well. + rng = rng_factory() + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + def onp_fun(x, unpacked_indexer): + indexer = pack_indexer(unpacked_indexer) + return x[indexer] + + jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) + + args_maker = lambda: [rng(shape, dtype), unpacked_indexer] + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @parameterized.named_parameters( + { + "testcase_name": "_{}_inshape={}_indexer={}".format( # pylint: disable=g-complex-comprehension + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "name": name, + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in ADVANCED_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testAdvancedIntegerIndexing(self, name, shape, dtype, rng_factory, indexer): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), indexer] + onp_fun = lambda x, idx: x[idx] + jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) + + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + # TODO(wangpeng): check_xla_forced_compile is turned off for + # ListOfPythonIntsAndIntArrays because it throws "The number of output + # elements has to equal to number of input elements that are sliced when + # input indices are not constant". Investigate and turn it on. 
+ check_xla = name != "ListOfPythonIntsAndIntArrays" + self._CompileAndCheck( + jnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_xla_forced_compile=check_xla, + ) + + @parameterized.named_parameters( + { + "testcase_name": "_{}_inshape={}_indexer={}".format( # pylint: disable=g-complex-comprehension + name, jtu.format_shape_dtype_string(shape, dtype), indexer + ), + "name": name, + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + } + for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS + for shape, indexer in index_specs + for dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + def testMixedAdvancedIntegerIndexing( + self, name, shape, dtype, rng_factory, indexer + ): + rng = rng_factory() + indexer_with_dummies = [ + e if isinstance(e, onp.ndarray) else () for e in indexer + ] + substitutes = [ + (i, e) for i, e in enumerate(indexer) if not isinstance(e, onp.ndarray) + ] + args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] + + def np_fun(x, indexer_with_dummies): + idx = type(indexer)(subvals(indexer_with_dummies, substitutes)) + return x[idx] + + jnp_fun = lambda x, idx: np_fun(jnp.asarray(x), idx) + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + # TODO(wangpeng): check_xla_forced_compile is turned off for + # IntArrayWithInt32Type because it throws "The number of output elements has + # to equal to number of input elements that are sliced when input indices + # are not constant". Investigate and turn it on. + check_xla = name != "IntArrayWithInt32Type" + self._CompileAndCheck( + jnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_xla_forced_compile=check_xla, + ) + + def testAdvancedIndexingManually(self): + x = onp.random.RandomState(0).randn(3, 4, 5) + index_array = onp.array([0, 2, -1, 0]) + + op = lambda x, index_array: x[..., index_array, :] + cop = npe.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2, check_dtypes=True) + + op = lambda x, index_array: x[..., index_array, :, index_array, None] + cop = npe.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2, check_dtypes=True) + + op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] + cop = npe.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2, check_dtypes=True) + + # Note that we don't currently allow __iter__ in graph mode. So this test only + # iterates over eager tensor. 
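+    # For example, `a, b, c = jnp.arange(3)` succeeds eagerly because Python
+    # can step through the concrete tensor's first axis, whereas a symbolic
+    # (graph-mode) tensor has no concrete elements to yield.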
+ def testUnpacking(self): + def foo(x): + a, b, c = x + return a + b + c + + a1 = foo(onp.arange(3)) + a2 = foo(jnp.arange(3)) + + self.assertAllClose(a1, a2, check_dtypes=True) + + def testBooleanIndexingArray1D(self): + idx = onp.array([True, True, False]) + x = jnp.asarray(onp.arange(3)) + ans = x[idx] + expected = onp.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList1D(self): + idx = [True, True, False] + x = jnp.asarray(onp.arange(3)) + ans = x[idx] + expected = onp.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingArray2DBroadcast(self): + idx = onp.array([True, True, False, True]) + x = onp.arange(8).reshape(4, 2) + ans = jnp.asarray(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList2DBroadcast(self): + idx = [True, True, False, True] + x = onp.arange(8).reshape(4, 2) + ans = jnp.asarray(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingArray2D(self): + idx = onp.array([[True, False], [False, True], [False, False], [True, True]]) + x = onp.arange(8).reshape(4, 2) + ans = jnp.asarray(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingDynamicShape(self): + x = onp.zeros(3) + i = onp.array([True, True, False]) + ans = x[i] + expected = jnp.asarray(x)[i] + self.assertAllClose(ans, expected, check_dtypes=True) + + def testIssue187(self): + x = jnp.ones((5, 5)) + x[[0, 2, 4], [0, 2, 4]] # doesn't crash + + x = onp.arange(25).reshape((5, 5)) + ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) + expected = x[[0, 2, 4], [0, 2, 4]] + self.assertAllClose(ans, expected, check_dtypes=False) + + # TODO(agarwal): Fix this use case. 
+ @jtu.disable + def testIndexingEmptyDimension(self): + # Issue 2671: XLA error when indexing into dimension of size 0 + x = jnp.ones((2, 0)) + # The following work, even on axis 1 of size 0 + _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] + + with self.assertRaisesRegex( + IndexError, "index .* is out of bounds for axis .* with size 0" + ): + _ = onp.ones((2, 0))[0, 0] # The numpy error + with self.assertRaisesRegex( + IndexError, "index is out of bounds for axis .* with size 0" + ): + _ = x[0, 0] # JAX indexing + with self.assertRaisesRegex( + IndexError, "index is out of bounds for axis .* with size 0" + ): + npe.jit(lambda i: x[0, i])(0) # JAX indexing under jit + + def testBooleanIndexingWithEmptyResult(self): + # based on a TensorFlow Probability test that started failing after #1623 + x = jnp.array([-1]) + mask = jnp.array([False]) + ans = x[mask] # doesn't crash + + expected = onp.array([-1])[onp.array([False])] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testFloatIndexingError(self): + error_regex = "only integers, slices.*are valid indices" + # Verify onp behavior + with self.assertRaisesRegex(IndexError, error_regex): + _ = onp.zeros((2, 2))[(0, 0.0)] + # Test jnp + with self.assertRaisesRegex(IndexError, error_regex): + jnp.zeros(2)[0.0] + with self.assertRaisesRegex(IndexError, error_regex): + jnp.zeros((2, 2))[(0, 0.0)] + # Test with jit + with self.assertRaisesRegex(IndexError, error_regex): + npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.0)) + + def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 + array = jnp.ones(5) + self.assertAllClose(array, array[:10], check_dtypes=True) + + @parameterized.named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_start_indices={}_size_indices={}".format( # pylint: disable=g-complex-comprehension + jtu.format_shape_dtype_string(shape, dtype), + start_indices, + size_indices, + ), + "shape": shape, + "dtype": dtype, + "start_indices": start_indices, + "size_indices": size_indices, + "rng_factory": rng_factory, + } + for shape, start_indices, size_indices in [ + [(3,), onp.array((1,)), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(5, 3), (1, -2), (3, 1)], + [(5, 3), onp.array((1, 1)), (3, 1)], + [(7, 5, 3), onp.array((4, 1, 0)), (2, 0, 1)], + [(), (), ()], + ] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testDynamicSlice(self, shape, dtype, start_indices, size_indices, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] + op = lambda x, starts: npe.dynamic_slice(x, starts, size_indices) + self._CompileAndCheck(op, args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_start_indices={}_size_indices={}".format( # pylint: disable=g-complex-comprehension + jtu.format_shape_dtype_string(shape, dtype), + start_indices, + size_indices, + ), + "shape": shape, + "dtype": dtype, + "start_indices": start_indices, + "size_indices": size_indices, + "rng_factory": rng_factory, + } + for shape, start_indices, size_indices in [ + [(3,), (1,), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(5, 3), (1, -2), (3, 1)], + [(7, 5, 3), (4, 1, 0), (2, 0, 1)], + [(), (), ()], + ] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testDynamicSliceAgainstNumpy( + self, shape, dtype, start_indices, size_indices, rng_factory + ): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] + op = lambda x, s: 
npe.dynamic_slice(x, s, size_indices) + numpy_op = lambda x, s: dynamic_slice_reference(x, s, size_indices) + self._CheckAgainstNumpy(numpy_op, op, args_maker) + + def testDynamicSliceInDim(self): + rng = jtu.rand_default() + x = rng((6, 7), onp.int32) + self.assertAllClose( + npe.dynamic_slice_in_dim(x, 2, 3), x[2:5], check_dtypes=True + ) + + +def _broadcastable_shapes(shape): + """Returns all shapes that broadcast to `shape`.""" + + def f(rshape): + yield [] + if rshape: + for s in f(rshape[1:]): + yield rshape[0:1] + s + if rshape[0] != 1: + for s in f(rshape[1:]): + yield [1] + s + + for x in f(list(reversed(shape))): + yield list(reversed(x)) + + +def _update_shape(shape, indexer): + return onp.zeros(shape)[indexer].shape + + +class UpdateOps(enum.Enum): + UPDATE = 0 + ADD = 1 + # MUL = 2 + MIN = 3 + MAX = 4 + + def np_fn(op, indexer, x, y): # pylint: disable=no-self-argument + x = x.copy() + x[indexer] = { + UpdateOps.UPDATE: lambda: y, + UpdateOps.ADD: lambda: x[indexer] + y, + # UpdateOps.MUL: lambda: x[indexer] * y, + UpdateOps.MIN: lambda: onp.minimum(x[indexer], y), + UpdateOps.MAX: lambda: onp.maximum(x[indexer], y), + }[op]() + return x + + def tfnp_fn(op, indexer, x, y): # pylint: disable=no-self-argument + return { + UpdateOps.UPDATE: npe.index_update, + UpdateOps.ADD: npe.index_add, + # UpdateOps.MUL: npe.index_mul, + UpdateOps.MIN: npe.index_min, + UpdateOps.MAX: npe.index_max, + }[op](x, indexer, y) + + +# a test to workaround b/123559667 +def has_non_trivial_stride(indexer): + def has(idx): + return isinstance(idx, slice) and idx.step not in (1, -1, None) + + return any(has(idx) for idx in tf.nest.flatten(indexer)) + + +class IndexedUpdateTest(jtu.TestCase): + @parameterized.named_parameters( + jtu.cases_from_list( + { # pylint: disable=g-complex-comprehension + "testcase_name": "_{}_{}_{}_{}".format( + jtu.format_shape_dtype_string(shape, dtype), + indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), + op.name, + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + "update_shape": update_shape, + "update_dtype": update_dtype, + "op": op, + } + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for op in UpdateOps + for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) + for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) + for update_dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testStaticIndexing( + self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op + ): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) + self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) + # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error "Missing + # xla_context 0-th output from". Investigate. 
+ check_xla = not has_non_trivial_stride(indexer) and not ( # b/123559667 + isinstance(indexer, slice) and indexer.stop == 8 and indexer.step == -1 + ) + self._CompileAndCheck( + tfnp_fn, + args_maker, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @parameterized.named_parameters( + jtu.cases_from_list( + { # pylint: disable=g-complex-comprehension + "testcase_name": "_{}_{}_{}_{}".format( + jtu.format_shape_dtype_string(shape, dtype), + indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), + op.name, + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + "update_shape": update_shape, + "update_dtype": update_dtype, + "op": op, + } + for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS + for shape, indexer in index_specs + for op in UpdateOps + for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) + for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) + for update_dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testAdvancedIndexing( + self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op + ): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) + self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) + self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True) + + @parameterized.named_parameters( + jtu.cases_from_list( + { # pylint: disable=g-complex-comprehension + "testcase_name": "_{}_{}_{}_{}".format( + jtu.format_shape_dtype_string(shape, dtype), + indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), + op.name, + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + "update_shape": update_shape, + "update_dtype": update_dtype, + "op": op, + } + for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + for shape, indexer in index_specs + for op in UpdateOps + for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) + for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) + for update_dtype in all_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testMixedAdvancedIndexing( + self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op + ): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) + self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) + check_xla = not has_non_trivial_stride(indexer) # b/123559667 + self._CompileAndCheck( + tfnp_fn, + args_maker, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @parameterized.named_parameters( + jtu.cases_from_list( + { # pylint: disable=g-complex-comprehension + "testcase_name": "_{}_{}_{}_{}".format( + jtu.format_shape_dtype_string(shape, dtype), + indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), + op.name, + ), + "shape": shape, + "dtype": dtype, + "rng_factory": rng_factory, + "indexer": indexer, + "update_shape": update_shape, + "update_dtype": update_dtype, + "op": op, + } + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer in index_specs + for op in [UpdateOps.ADD, 
UpdateOps.UPDATE] + for dtype in float_dtypes + for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) + for update_dtype in float_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testStaticIndexingGrads( + self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op + ): + rng = rng_factory() + tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) + x = rng(shape, dtype) + y = rng(update_shape, update_dtype) + self.check_grads(tfnp_fn, (x, y), rtol=1e-3, atol=1e-3, delta=1.0) + + @parameterized.named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_start_indices={}_update_shape={}".format( # pylint: disable=g-complex-comprehension + jtu.format_shape_dtype_string(shape, dtype), + start_indices, + update_shape, + ), + "shape": shape, + "dtype": dtype, + "start_indices": start_indices, + "update_shape": update_shape, + "rng_factory": rng_factory, + } + for shape, start_indices, update_shape in [ + [(3,), (1,), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(5, 3), (1, -2), (3, 1)], + [(7, 5, 3), (4, 1, 0), (2, 0, 1)], + [(), (), ()], + ] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testDynamicUpdateSlice( + self, shape, dtype, start_indices, update_shape, rng_factory + ): + rng = rng_factory() + + def args_maker(): + return [ + rng(shape, dtype), + rng(update_shape, dtype), + onp.array(start_indices), + ] + + # update's shape must be fully known. + # TODO(wangpeng): Support turning off check_incomplete_shape for individual + # arguments. + self._CompileAndCheck( + npe.dynamic_update_slice, args_maker, check_incomplete_shape=False + ) + + @parameterized.named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_start_indices={}_update_shape={}".format( # pylint: disable=g-complex-comprehension + jtu.format_shape_dtype_string(shape, dtype), + start_indices, + update_shape, + ), + "shape": shape, + "dtype": dtype, + "start_indices": start_indices, + "update_shape": update_shape, + "rng_factory": rng_factory, + } + for shape, start_indices, update_shape in [ + [(3,), (1,), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(5, 3), (1, -2), (3, 1)], + [(7, 5, 3), (4, 1, 0), (2, 0, 1)], + [(), (), ()], + ] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testDynamicUpdateSliceAgainstNumpy( + self, shape, dtype, start_indices, update_shape, rng_factory + ): + rng = rng_factory() + + def args_maker(): + return [ + rng(shape, dtype), + rng(update_shape, dtype), + onp.array(start_indices), + ] + + self._CheckAgainstNumpy( + dynamic_update_slice_reference, npe.dynamic_update_slice, args_maker + ) + + def testDynamicUpdateSliceInDim(self): + rng = jtu.rand_default() + x = rng((6, 7), onp.int32) + y = rng((3, 7), onp.int32) + z = x.copy() + z[2:5] = y + self.assertAllClose( + npe.dynamic_update_slice_in_dim(x, y, 2, 0), z, check_dtypes=True + ) + + +if __name__ == "__main__": + tf.config.set_soft_device_placement(False) + jnp.enable_numpy_behavior() + absltest.main() diff --git a/tests/fastmath/jax/lax_numpy_test.py b/tests/fastmath/jax/lax_numpy_test.py new file mode 100644 index 000000000..8504eac47 --- /dev/null +++ b/tests/fastmath/jax/lax_numpy_test.py @@ -0,0 +1,4772 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import functools +import itertools +import operator +import unittest +import warnings + +from functools import partial +from unittest import SkipTest + +import numpy as onp +import six +import tensorflow.compat.v2 as tf + +from absl.testing import absltest, parameterized + +import tests.fastmath.jax.utils as jtu +import trax.tf.extensions as npe +import trax.tf.numpy as lnp + +from tests.fastmath.jax.config import FLAGS, config + +config.parse_flags_with_absl() + + +nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] +nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes # type: ignore +empty_array_shapes = [ + (0,), + (0, 4), + (3, 0), +] + +scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] +array_shapes = nonempty_array_shapes + empty_array_shapes +nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes +nonempty_shapes = scalar_shapes + nonempty_array_shapes +all_shapes = scalar_shapes + array_shapes + +# TODO(wangpeng): float_dtypes = [lnp.bfloat16, onp.float16, onp.float32, +# onp.float64] +float_dtypes = [onp.float16, onp.float32, onp.float64] +complex_dtypes = [onp.complex64, onp.complex128] +int_dtypes = [onp.int32, onp.int64] +unsigned_dtypes = [onp.uint32, onp.uint64] +bool_dtypes = [onp.bool_] +default_dtypes = float_dtypes + int_dtypes +inexact_dtypes = float_dtypes + complex_dtypes +number_dtypes = float_dtypes + complex_dtypes + int_dtypes +all_dtypes = number_dtypes + bool_dtypes + + +python_scalar_dtypes = [lnp.bool_, lnp.int_, lnp.float_, lnp.complex_] + + +def _valid_dtypes_for_shape(shape, dtypes): + # Not all (shape, dtype) pairs are valid. In particular, Python scalars only + # have one type in each category (float, bool, etc.) 
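+    # For instance, (jtu.PYTHON_SCALAR_SHAPE, onp.float16) is filtered out
+    # below, since onp.float16 is not one of the Python scalar dtypes.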
+ if shape is jtu.PYTHON_SCALAR_SHAPE: + return [t for t in dtypes if t in python_scalar_dtypes] + return dtypes + + +def _shape_and_dtypes(shapes, dtypes): + for shape in shapes: + for dtype in _valid_dtypes_for_shape(shape, dtypes): + yield (shape, dtype) + + +OpRecord = collections.namedtuple( + "OpRecord", + [ + "name", + "nargs", + "dtypes", + "shapes", + "rng_factory", + "diff_modes", + "test_name", + "check_dtypes", + "tolerance", + "inexact", + "check_incomplete_shape", + ], +) + + +def op_record( + name, + nargs, + dtypes, + shapes, + rng_factory, + diff_modes, + test_name=None, + check_dtypes=True, + tolerance=None, + inexact=False, + check_incomplete_shape=True, +): + test_name = test_name or name + return OpRecord( + name, + nargs, + dtypes, + shapes, + rng_factory, + diff_modes, + test_name, + check_dtypes, + tolerance, + inexact, + check_incomplete_shape, + ) + + +def minus(a, b): + return [x for x in a if x not in b] + + +JAX_ONE_TO_ONE_OP_RECORDS = [ + op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), + op_record( + "exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "float_power", + 2, + inexact_dtypes, + all_shapes, + partial(jtu.rand_default, scale=1), + ["rev"], + tolerance={ + # TODO(wangpeng): lnp.bfloat16: 1e-2, + onp.float32: 1e-3, + onp.float64: 1e-12, + onp.complex64: 2e-4, + onp.complex128: 1e-12, + }, + check_dtypes=False, + ), + op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "greater", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "greater_equal", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "less", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "less_equal", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []), + op_record( + "maximum", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf, + [], + ), + op_record( + "minimum", + 2, + minus(all_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf, + [], + ), + op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "nextafter", + 2, + [f for f in float_dtypes if f not in (lnp.bfloat16, onp.float16)], + all_shapes, + jtu.rand_default, + ["rev"], + inexact=True, + tolerance=0, + ), + op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), + op_record( + "array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"] + ), + op_record("reciprocal", 1, inexact_dtypes, 
all_shapes, jtu.rand_default, []), + op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "signbit", + 1, + default_dtypes + bool_dtypes, + all_shapes, + jtu.rand_some_inf_and_nan, + ["rev"], + ), + op_record( + "sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record( + "cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record( + "tan", + 1, + number_dtypes, + all_shapes, + partial(jtu.rand_uniform, -1.5, 1.5), + ["rev"], + tolerance={onp.complex64: 3e-5, onp.complex128: 4e-14}, + inexact=True, + ), + # TODO(wangpeng): Add float16 support + op_record( + "sinh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_default, + ["rev"], + inexact=True, + ), + op_record( + "cosh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_default, + ["rev"], + inexact=True, + ), + # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to + # ~float32 precision. + # TODO(b/143135720): on GPU, tanh has only ~float32 precision. + op_record( + "tanh", + 1, + number_dtypes, + all_shapes, + jtu.rand_default, + ["rev"], + tolerance={onp.float64: 1e-7, onp.complex128: 1e-7}, + inexact=True, + ), + op_record( + "arcsin", + 1, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arccos", + 1, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arctan", + 1, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arctan2", + 2, + minus(float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), + op_record( + "arcsinh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_positive, + ["rev"], + inexact=True, + ), + op_record( + "arccosh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_positive, + ["rev"], + inexact=True, + ), + op_record( + "arctanh", + 1, + minus(number_dtypes, [onp.float16]), + all_shapes, + jtu.rand_small, + ["rev"], + inexact=True, + ), +] + +JAX_COMPOUND_OP_RECORDS = [ + # angle has inconsistent 32/64-bit return types across numpy versions. + op_record( + "angle", + 1, + number_dtypes, + all_shapes, + jtu.rand_default, + [], + check_dtypes=False, + inexact=True, + ), + op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "cbrt", 1, default_dtypes, all_shapes, jtu.rand_default, ["rev"], inexact=True + ), + op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "divide", + 2, + number_dtypes, + all_shapes, + jtu.rand_nonzero, + ["rev"], + inexact=six.PY3, + ), + op_record( + "divmod", + 2, + minus(int_dtypes + float_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + ), + op_record( + "exp2", + 1, + number_dtypes, + all_shapes, + jtu.rand_default, + ["rev"], + tolerance={ + # TODO(wangpeng): lnp.bfloat16: 2e-2, + onp.float16: 1e-2 + }, + inexact=True, + ), + # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 + # precision. 
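+ # expm1 is therefore exercised twice below: once on large positive inputs
+ # ("expm1_large") and once on small positive ones, where naively computing
+ # exp(x) - 1 would lose most significant digits to cancellation.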
+ op_record( + "expm1", + 1, + number_dtypes, + all_shapes, + jtu.rand_positive, + [], + test_name="expm1_large", + tolerance={onp.float64: 1e-8}, + inexact=True, + ), + op_record( + "expm1", + 1, + number_dtypes, + all_shapes, + jtu.rand_small_positive, + [], + tolerance={onp.float64: 1e-8}, + inexact=True, + ), + op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "floor_divide", + 2, + minus(number_dtypes, complex_dtypes), + all_shapes, + jtu.rand_nonzero, + ["rev"], + ), + op_record( + "heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], inexact=True + ), + op_record( + "hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], inexact=True + ), + op_record( + "kron", + 2, + number_dtypes, + nonempty_shapes, + jtu.rand_default, + [], + check_incomplete_shape=False, + ), + op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record( + "isfinite", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf_and_nan, + [], + ), + op_record( + "isinf", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf_and_nan, + [], + ), + op_record( + "isnan", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_inf_and_nan, + [], + ), + op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record( + "log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record( + "log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record( + "log1p", + 1, + number_dtypes, + all_shapes, + jtu.rand_positive, + [], + test_name="log1p_large", + tolerance={onp.float64: 1e-12}, + inexact=True, + ), + op_record( + "log1p", + 1, + number_dtypes, + all_shapes, + jtu.rand_small_positive, + [], + tolerance={onp.float64: 1e-12}, + inexact=True, + ), + op_record( + "logaddexp", + 2, + float_dtypes, + all_shapes, + jtu.rand_some_inf_and_nan, + ["rev"], + tolerance={onp.float64: 1e-12}, + inexact=True, + ), + op_record( + "logaddexp2", + 2, + float_dtypes, + all_shapes, + jtu.rand_some_inf_and_nan, + ["rev"], + tolerance={onp.float16: 1e-2}, + inexact=True, + ), + op_record( + "polyval", + 2, + number_dtypes, + nonempty_nonscalar_array_shapes, + jtu.rand_default, + [], + check_dtypes=False, + tolerance={onp.float16: 1e-2, onp.float64: 1e-12}, + check_incomplete_shape=False, + ), + op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "power", + 2, + number_dtypes, + all_shapes, + jtu.rand_positive, + ["rev"], + tolerance={onp.complex128: 1e-14}, + ), + op_record( + "rad2deg", + 1, + float_dtypes, + all_shapes, + jtu.rand_default, + [], + tolerance={onp.float64: 5e-6}, + ), + op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), + op_record( + "remainder", + 2, + minus(default_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + tolerance={onp.float16: 1e-2}, + ), + op_record( + "mod", 2, minus(default_dtypes, [onp.float16]), all_shapes, jtu.rand_nonzero, [] + 
), + op_record( + "sinc", + 1, + [t for t in number_dtypes if t != lnp.bfloat16], + all_shapes, + jtu.rand_default, + ["rev"], + tolerance={onp.complex64: 1e-5}, + inexact=True, + check_dtypes=False, + ), + op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), + op_record( + "sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], inexact=True + ), + op_record( + "transpose", + 1, + all_dtypes, + all_shapes, + jtu.rand_default, + ["rev"], + check_dtypes=False, + ), + op_record( + "true_divide", + 2, + all_dtypes, + all_shapes, + jtu.rand_nonzero, + ["rev"], + inexact=True, + ), + op_record( + "diff", + 1, + number_dtypes, + nonzerodim_shapes, + jtu.rand_default, + ["rev"], + check_incomplete_shape=False, + ), +] + +JAX_BITWISE_OP_RECORDS = [ + op_record( + "bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), + op_record( + "bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), + op_record( + "bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), + op_record( + "bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [] + ), +] + +JAX_REDUCER_RECORDS = [ + op_record( + "mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True + ), + op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "nanmean", + 1, + minus(inexact_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_some_nan, + [], + inexact=True, + ), + op_record( + "nanprod", + 1, + minus(inexact_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_nan, + [], + ), + op_record( + "nansum", + 1, + minus(number_dtypes, complex_dtypes), + all_shapes, + jtu.rand_some_nan, + [], + ), +] + +JAX_REDUCER_NO_DTYPE_RECORDS = [ + op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), + op_record( + "max", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_default, + [], + ), + op_record( + "min", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_default, + [], + ), + op_record( + "var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True + ), + op_record( + "std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], inexact=True + ), +] + +JAX_ARGMINMAX_RECORDS = [ + op_record( + "argmin", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_some_equal, + [], + ), + op_record( + "argmax", + 1, + minus(all_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_some_equal, + [], + ), +] + +JAX_OPERATOR_OVERLOADS = [ + op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []), + op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "__pow__", + 2, + inexact_dtypes, + 
all_shapes, + jtu.rand_positive, + [], + tolerance={onp.float32: 2e-4, onp.complex64: 2e-4, onp.complex128: 1e-14}, + ), + op_record( + "__mod__", + 2, + minus(default_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + tolerance={onp.float16: 1e-1}, + ), + op_record("__floordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record( + "__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], inexact=True + ), + op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, []), + # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2 + op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []), + # TODO(mattjj): investigate these failures + # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + # TODO(mattjj): lshift, rshift +] + +JAX_RIGHT_OPERATOR_OVERLOADS = [ + op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + op_record( + "__rpow__", + 2, + inexact_dtypes, + all_shapes, + jtu.rand_positive, + [], + tolerance={onp.float32: 2e-4, onp.complex64: 1e-3}, + ), + op_record( + "__rmod__", + 2, + minus(default_dtypes, [onp.float16]), + all_shapes, + jtu.rand_nonzero, + [], + tolerance={onp.float16: 1e-1}, + ), + op_record("__rfloordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), + op_record( + "__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], inexact=True + ), + # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []), + # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), + # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), +] + +numpy_version = tuple(map(int, onp.version.version.split("."))) +if numpy_version >= (1, 15): + JAX_COMPOUND_OP_RECORDS += [ + op_record( + "isclose", + 2, + [t for t in all_dtypes if t != lnp.bfloat16], + all_shapes, + jtu.rand_small_positive, + [], + ), + op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default, []), + op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default, []), + ] + JAX_REDUCER_NO_DTYPE_RECORDS += [ + op_record( + "ptp", + 1, + minus(number_dtypes, complex_dtypes), + nonempty_shapes, + jtu.rand_default, + [], + ), + ] + +if six.PY2: + JAX_OPERATOR_OVERLOADS += [ + op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + ] + JAX_RIGHT_OPERATOR_OVERLOADS += [ + op_record("__rdiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), + ] + + +CombosWithReplacement = itertools.combinations_with_replacement + + +def _dtypes_are_compatible_for_bitwise_ops(args): + if len(args) <= 1: + return True + is_signed = lambda dtype: lnp.issubdtype(dtype, onp.signedinteger) + width = lambda dtype: lnp.iinfo(dtype).bits + x, y = args + # `lnp.iinfo(dtype).bits` can't be called on bools, so we convert bools to + # ints. 
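+ # E.g. the pair (lnp.bool_, lnp.int32) is treated as (int32, int32) by the
+ # width/signedness checks below.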
+ if x == lnp.bool_:
+ x = lnp.int32
+ if y == lnp.bool_:
+ y = lnp.int32
+ if width(x) > width(y):
+ x, y = y, x
+ if x == lnp.uint32 and y == lnp.uint64:
+ return False
+ # The following condition seems a little ad hoc, but seems to capture what
+ # numpy actually implements.
+ return (
+ is_signed(x) == is_signed(y)
+ or (width(x) == 32 and width(y) == 32)
+ or (width(x) == 32 and width(y) == 64 and is_signed(y))
+ )
+
+
+ def _shapes_are_broadcast_compatible(shapes):
+ accumulator = onp.zeros([])
+ for shape in shapes:
+ try:
+ accumulator = accumulator + onp.zeros(shape)
+ except ValueError:
+ return False
+ return True
+
+
+ def _shapes_are_equal_length(shapes):
+ return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
+
+
+ def _promote_like_lnp(fun, inexact=False):
+ """Decorator that promotes the arguments of `fun` to `lnp.result_type(*args)`.
+
+ lnp and onp have different type promotion semantics; this decorator allows
+ tests to make an onp reference implementation act more like an lnp
+ implementation.
+ """
+
+ def wrapper(*args, **kw):
+ flat_args = tf.nest.flatten(args)
+ if inexact and not any(
+ lnp.issubdtype(lnp.result_type(x).as_numpy_dtype, lnp.inexact)
+ for x in flat_args
+ ):
+ dtype = lnp.result_type(lnp.float_, *flat_args)
+ else:
+ dtype = lnp.result_type(*flat_args)
+ dtype = dtype.as_numpy_dtype
+ args = tf.nest.map_structure(lambda a: onp.asarray(a, dtype), args)
+ return fun(*args, **kw)
+
+ return wrapper
+
+
+ def new_test(f):
+ def wrapper(self, *args, **kwargs):
+ if not FLAGS.tf_numpy_additional_tests:
+ self.skipTest("Newly added test is disabled, since flag is False.")
+ else:
+ f(self, *args, **kwargs)
+
+ return wrapper
+
+
+ def named_parameters(ls):
+ """A version that allows an empty param list."""
+
+ def noop(_):
+ def wrapper(self, *args, **kwargs):
+ self.skipTest("Empty parameter list")
+
+ return wrapper
+
+ if isinstance(ls, (list, tuple)) and not ls:
+ return noop
+ if isinstance(ls, itertools.chain):
+ try:
+ first = next(ls)
+ except StopIteration:
+ return noop
+ else:
+ ls = itertools.chain([first], ls)
+ return parameterized.named_parameters(ls)
+
+
+ # TODO(wangpeng): Enable all disabled tests in this class
+ class LaxBackedNumpyTests(jtu.TestCase):
+ """Tests for LAX-backed Numpy implementation."""
+
+ def _GetArgsMaker(self, rng, shapes, dtypes, onp_arrays=True):
+ def f():
+ out = [
+ rng(shape, dtype or lnp.float_) for shape, dtype in zip(shapes, dtypes)
+ ]
+ return out if onp_arrays else [lnp.asarray(a) for a in out]
+
+ return f
+
+ @named_parameters(
+ itertools.chain.from_iterable(
+ jtu.cases_from_list(
+ {
+ "testcase_name": jtu.format_test_name_suffix(
+ rec.test_name, shapes, dtypes
+ ),
+ "rng_factory": rec.rng_factory,
+ "shapes": shapes,
+ "dtypes": dtypes,
+ "onp_op": getattr(onp, rec.name),
+ "lnp_op": getattr(lnp, rec.name),
+ "check_dtypes": rec.check_dtypes,
+ "tolerance": rec.tolerance,
+ "inexact": rec.inexact,
+ "check_incomplete_shape": rec.check_incomplete_shape,
+ }
+ for shapes in filter(
+ _shapes_are_broadcast_compatible,
+ CombosWithReplacement(rec.shapes, rec.nargs),
+ )
+ for dtypes in itertools.product(
+ *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)
+ )
+ )
+ for rec in itertools.chain(
+ JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS
+ )
+ )
+ )
+ def testOp(
+ self,
+ onp_op,
+ lnp_op,
+ rng_factory,
+ shapes,
+ dtypes,
+ check_dtypes,
+ tolerance,
+ inexact,
+ check_incomplete_shape,
+ ):
+ # TODO(b/147769803): Remove this skipping
+ if lnp_op.__name__ == "kron" and shapes == ((2, 
3, 4), (2, 3, 4)): + self.skipTest("Case disabled because of b/147769803") + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) + tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) + tol = functools.reduce( + jtu.join_tolerance, [tolerance, tol, jtu.default_tolerance()] + ) + self._CheckAgainstNumpy( + _promote_like_lnp(onp_op, inexact), + lnp_op, + args_maker, + check_dtypes=check_dtypes, + tol=tol, + ) + # tf.math.pow doesn't support int32/int64 on XLA (b/169191476). + check_xla = not ( + lnp_op.__name__ == "power" + and set(dtypes).intersection((onp.int32, onp.int64)) + ) + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=check_dtypes, + atol=tol, + rtol=tol, + check_incomplete_shape=check_incomplete_shape, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes + ), + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtypes": dtypes, + "name": rec.name, + "tol": rec.tolerance, + } + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes) + ) + ) + for rec in JAX_OPERATOR_OVERLOADS + ) + ) + def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): + rng = rng_factory() + # onp and lnp arrays have different type promotion rules; force the use of + # lnp arrays. + args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) + fun = lambda *xs: getattr(operator, name.strip("_"))(*xs) + scalar_arg = ( + jtu.PYTHON_SCALAR_SHAPE in shapes + or jtu.NUMPY_SCALAR_SHAPE in shapes + or () in shapes + ) + empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) + self._CompileAndCheck( + fun, + args_maker, + check_dtypes=True, # not scalar_arg and not empty_shape, + atol=tol, + rtol=tol, + ) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes + ), + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtypes": dtypes, + "name": rec.name, + "op_tolerance": rec.tolerance, + } + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes) + ) + ) + for rec in JAX_RIGHT_OPERATOR_OVERLOADS + ) + ) + def testRightOperatorOverload( + self, name, rng_factory, shapes, dtypes, op_tolerance + ): + if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: + raise SkipTest() # TODO(mattjj): clean up + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) + fun = lambda fst, snd: getattr(snd, name)(fst) + tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) + scalar_arg = ( + jtu.PYTHON_SCALAR_SHAPE in shapes + or jtu.NUMPY_SCALAR_SHAPE in shapes + or () in shapes + ) + empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes) + self._CompileAndCheck( + fun, + args_maker, + check_dtypes=True, # not scalar_arg and not empty_shape, + atol=tol, + rtol=tol, + ) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes + ), + "rng_factory": rec.rng_factory, + "shapes": shapes, + 
"dtypes": dtypes, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + } + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(rec.shapes, rec.nargs), + ) + for dtypes in filter( + _dtypes_are_compatible_for_bitwise_ops, + CombosWithReplacement(rec.dtypes, rec.nargs), + ) + ) + for rec in JAX_BITWISE_OP_RECORDS + ) + ) + def testBitwiseOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + has_python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + if onp_op == onp.bitwise_not and has_python_scalar: + # For bitwise_not with a Python `int`, npe.jit may choose a different + # dtype for the `int` from onp's choice, which may result in a different + # result value, so we skip _CompileAndCheck. + return + # Numpy does value-dependent dtype promotion on Python/numpy/array scalars + # which `jit` can't do (when np.result_type is called inside `jit`, tensor + # values are not available), so we skip dtype check in this case. + check_dtypes = not ( + set(shapes) & set([jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE, ()]) + ) + self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), + axis, + "None" if out_dtype is None else onp.dtype(out_dtype).name, + keepdims, + ), + "rng_factory": rec.rng_factory, + "shape": shape, + "dtype": dtype, + "out_dtype": out_dtype, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + "axis": axis, + "keepdims": keepdims, + "inexact": rec.inexact, + } + for shape in rec.shapes + for dtype in rec.dtypes + for out_dtype in [None] + rec.dtypes + for axis in set(range(-len(shape), len(shape))) | set([None]) + for keepdims in [False, True] + ) + for rec in JAX_REDUCER_RECORDS + ) + ) + def testReducer( + self, + onp_op, + lnp_op, + rng_factory, + shape, + dtype, + out_dtype, + axis, + keepdims, + inexact, + ): + rng = rng_factory() + + def onp_fun(x): + x_cast = x if dtype != lnp.bfloat16 else x.astype(onp.float32) + t = out_dtype if out_dtype != lnp.bfloat16 else onp.float32 + return onp_op(x_cast, axis, dtype=t, keepdims=keepdims) + + onp_fun = _promote_like_lnp(onp_fun, inexact) + lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = { + onp.float16: 1e-2, + onp.float32: 1e-3, + onp.complex64: 1e-3, + onp.float64: 1e-5, + onp.complex128: 1e-5, + } + tol = jtu.tolerance(dtype, tol_spec) + tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=lnp.bfloat16 not in (dtype, out_dtype), + tol=tol, + ) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), + axis, + keepdims, + ), + "rng_factory": rec.rng_factory, + "shape": shape, + "dtype": dtype, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + "axis": axis, + "keepdims": keepdims, + "inexact": rec.inexact, + } + for rec in JAX_REDUCER_NO_DTYPE_RECORDS + for 
shape in rec.shapes
+ for dtype in rec.dtypes
+ for axis in set(range(-len(shape), len(shape))) | set([None])
+ for keepdims in [False, True]
+ )
+ )
+ def testReducerNoDtype(
+ self, onp_op, lnp_op, rng_factory, shape, dtype, axis, keepdims, inexact
+ ):
+ rng = rng_factory()
+ onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
+ onp_fun = _promote_like_lnp(onp_fun, inexact)
+ lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
+ args_maker = lambda: [rng(shape, dtype)]
+ self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
+ self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
+
+ @named_parameters(
+ jtu.cases_from_list(
+ {
+ "testcase_name": "_shape={}_axis={}".format(
+ jtu.format_shape_dtype_string(shape, dtype), axis
+ ),
+ "shape": shape,
+ "dtype": dtype,
+ "axis": axis,
+ }
+ for shape in all_shapes
+ for dtype in all_dtypes
+ for axis in set(range(-len(shape), len(shape))) | set([None])
+ )
+ )
+ def testCountNonzero(self, shape, dtype, axis):
+ rng = jtu.rand_some_zero()
+ onp_fun = lambda x: onp.count_nonzero(x, axis)
+ lnp_fun = lambda x: lnp.count_nonzero(x, axis)
+ args_maker = lambda: [rng(shape, dtype)]
+ self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
+ self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
+
+ @named_parameters(
+ jtu.cases_from_list(
+ {
+ "testcase_name": "_shape={}".format(
+ jtu.format_shape_dtype_string(shape, dtype)
+ ),
+ "shape": shape,
+ "dtype": dtype,
+ }
+ for shape in all_shapes
+ for dtype in all_dtypes
+ )
+ )
+ def testNonzero(self, shape, dtype):
+ rng = jtu.rand_some_zero()
+ onp_fun = lambda x: onp.nonzero(x)
+ lnp_fun = lambda x: lnp.nonzero(x)
+ args_maker = lambda: [rng(shape, dtype)]
+ self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
+ # The shapes of `nonzero`'s results are value-dependent, so `eval_on_shapes`
+ # won't return concrete shapes.
+ # Also, `nonzero` requires a known rank.
+ # Turns off XLA check because there are no XLA kernels for `Where`, which
+ # XLA can't support because its output shape is dynamic. 
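+ # E.g. onp.nonzero([0, 1, 0, 2]) returns (array([1, 3]),): the result length
+ # depends on the input's values, not just on its shape.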
+ self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_eval_on_shapes=False, + check_incomplete_shape=True, + check_unknown_rank=False, + check_experimental_compile=False, + check_xla_forced_compile=False, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "{}_inshape={}_axis={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), + axis, + ), + "rng_factory": rec.rng_factory, + "shape": shape, + "dtype": dtype, + "onp_op": getattr(onp, rec.name), + "lnp_op": getattr(lnp, rec.name), + "axis": axis, + } + for rec in JAX_ARGMINMAX_RECORDS + for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) + for axis in range(-len(shape), len(shape)) + ) + ) + def testArgMinMax(self, onp_op, lnp_op, rng_factory, shape, dtype, axis): + rng = rng_factory() + if dtype == onp.complex128 and jtu.device_under_test() == "gpu": + raise unittest.SkipTest("complex128 reductions not supported on GPU") + + def onp_fun(array_to_reduce): + return onp_op(array_to_reduce, axis).astype(lnp.int_) + + def lnp_fun(array_to_reduce): + return lnp_op(array_to_reduce, axis) + + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + axes, + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "axes": axes, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for lhs_shape, rhs_shape, axes in [ + [(2,), (2,), (-1, -1, -1, None)], # scalar output + [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors + [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors + [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting + [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes + [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting + [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors + [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting + [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing + [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)], # same as before + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement( + minus(number_dtypes, complex_dtypes), 2 + ) + ) + ) + def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + axisa, axisb, axisc, axis = axes + lnp_fun = lambda a, b: lnp.cross(a, b, axisa, axisb, axisc, axis) + + def onp_fun(a, b): + a = a.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else a + b = b.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else b + out = onp.cross(a, b, axisa, axisb, axisc, axis) + return out.astype(lnp.promote_types(lhs_dtype, rhs_dtype)) + + tol_spec = { + # TODO(wangpeng): dtypes.bfloat16: 3e-1, + onp.float16: 0.15 + } + tol = max( + jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec) + ) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": 
"_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for name, lhs_shape, rhs_shape in [ + ("matrix-scalar", (3, 3), ()), + ("scalar-matrix", (), (3, 3)), + ("matrix-vector", (4, 5), (5,)), + ("vector-matrix", (6,), (6, 4)), + ("matrix-matrix", (3, 4), (4, 5)), + ("tensor-vector", (4, 3, 2), (2,)), + ("vector-tensor", (2,), (3, 2, 4)), + ("tensor-matrix", (4, 3, 2), (2, 5)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-tensor", (2, 3, 4), (5, 4, 1)), + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = { + onp.float16: 1e-2, + onp.float32: 1e-5, + onp.float64: 1e-14, + onp.complex128: 1e-14, + } + if jtu.device_under_test() == "tpu": + tol[onp.float32] = tol[onp.complex64] = 2e-1 + + def onp_dot(x, y): + x = x.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else x + y = y.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else y + # `onp.dot(x, y).dtype` sometimes differs from `onp.result_type(x, y)` + # (e.g. when x is float64[] and y is complex64[3,3], or when x is + # float16[3,3] and y is int64[]). We ignore this corner case and pretend + # that they agree. + return onp.dot(x, y).astype(onp.result_type(x, y)) + + self._CheckAgainstNumpy( + onp_dot, lnp.dot, args_maker, check_dtypes=True, tol=tol + ) + # We disable dtype check in the following cases because `np.dot` does + # value-dependent type promotion in those cases. + check_dtypes = () not in (lhs_shape, rhs_shape) + # XLA lacks int32/int64 MatMul kernels (b/168657656). 
+ check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + self._CompileAndCheck( + lnp.dot, + args_maker, + check_dtypes=check_dtypes, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("matrix-vector", (3, 3), (3,)), + ("vector-matrix", (3,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("vector-tensor", (3,), (5, 3, 2)), + ("tensor-vector", (5, 3, 2), (2,)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-matrix", (5, 2, 3), (3, 2)), + ("tensor-tensor", (5, 3, 4), (5, 4, 1)), + ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1)), + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + + def onp_fun(x, y): + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return ( + onp.matmul(x, y).astype(dtype), + onp.array(x).__matmul__(y).astype(dtype), + onp.array(y).__rmatmul__(x).astype(dtype), + ) + + def lnp_fun(x, y): + return ( + lnp.matmul(x, y), + lnp.array(x).__matmul__(y), + lnp.array(y).__rmatmul__(x), + ) + + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = { + onp.float16: 1e-2, + onp.float32: 2e-2, + onp.float64: 1e-12, + onp.complex128: 1e-12, + } + if jtu.device_under_test() == "tpu": + tol[onp.float32] = tol[onp.complex64] = 4e-2 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int32/int64 MatMul kernels (b/168657656). + check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("vector-matrix", (9,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("tensor-vector", (5, 3, 2), (30,)), + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + @new_test + def testVDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = { + onp.float16: 1e-2, + onp.float32: 2e-2, + onp.float64: 1e-12, + onp.complex128: 1e-12, + } + self._CheckAgainstNumpy( + onp.vdot, lnp.vdot, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int32/int64 MatMul kernels (b/168657656). 
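+ # Note: vdot flattens non-vector inputs, which is why size-compatible pairs
+ # such as (9,) with (3, 3) appear in the shape list above.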
+ check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + self._CompileAndCheck( + lnp.vdot, + args_maker, + check_dtypes=True, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + axes, + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "axes": axes, + "rng_factory": rng_factory, + } + for rng_factory in [jtu.rand_default] + for lhs_shape, rhs_shape, axes in [ + [(2, 3, 4), (5, 6, 7), 0], # from issue #740 + [(2, 3, 4), (3, 4, 5, 6), 2], + [(2, 3, 4), (5, 4, 3, 6), [1, 2]], + [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], + [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], + ] + for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2) + ) + ) + def testTensordot( + self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng_factory + ): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + lnp_fun = lambda a, b: lnp.tensordot(a, b, axes) + + def onp_fun(a, b): + a = a if lhs_dtype != lnp.bfloat16 else a.astype(onp.float32) + b = b if rhs_dtype != lnp.bfloat16 else b.astype(onp.float32) + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.tensordot(a, b, axes).astype(dtype) + + tol = { + onp.float16: 1e-1, + onp.float32: 1e-3, + onp.float64: 1e-12, + onp.complex64: 1e-3, + onp.complex128: 1e-12, + } + if jtu.device_under_test() == "tpu": + tol[onp.float32] = tol[onp.complex64] = 2e-1 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int32/int64 MatMul kernels (b/168657656). + check_xla = not set((lhs_dtype, rhs_dtype)).intersection((onp.int32, onp.int64)) + + tol = {onp.float64: 1e-14, onp.float16: 0.04, onp.complex128: 6e-15} + tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol)) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + atol=tol, + rtol=tol, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + ), + "lhs_shape": lhs_shape, + "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, + "rhs_dtype": rhs_dtype, + "rng_factory": jtu.rand_default, + } + # TODO(phawkins): support integer dtypes too. 
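+ # onp.inner contracts over the last axis, so the filter below only admits
+ # operand pairs whose final dimensions match (scalars are always allowed).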
+ for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + if len(jtu._dims_of_shape(lhs_shape)) == 0 + or len(jtu._dims_of_shape(rhs_shape)) == 0 + or lhs_shape[-1] == rhs_shape[-1] + ) + ) + def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + + def onp_fun(lhs, rhs): + lhs = lhs if lhs_dtype != lnp.bfloat16 else lhs.astype(onp.float32) + rhs = rhs if rhs_dtype != lnp.bfloat16 else rhs.astype(onp.float32) + dtype = lnp.promote_types(lhs_dtype, rhs_dtype) + return onp.inner(lhs, rhs).astype(dtype) + + lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs) + tol_spec = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 2e-6} + if jtu.device_under_test() == "tpu": + tol_spec[onp.float32] = tol_spec[onp.complex64] = 2e-1 + tol = max( + jtu.tolerance(lhs_dtype, tol_spec), jtu.tolerance(rhs_dtype, tol_spec) + ) + # TODO(phawkins): there are float32/float64 disagreements for some inputs. + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=False, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_amin={}_amax={}".format( + jtu.format_shape_dtype_string(shape, dtype), a_min, a_max + ), + "shape": shape, + "dtype": dtype, + "a_min": a_min, + "a_max": a_max, + "rng_factory": jtu.rand_default, + } + for shape in all_shapes + for dtype in minus(number_dtypes, complex_dtypes) + for a_min, a_max in [ + (-1, None), + (None, 1), + (-1, 1), + (-onp.ones(1), None), + (None, onp.ones(1)), + (-onp.ones(1), onp.ones(1)), + ] + ) + ) + def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory): + rng = rng_factory() + + # Convert bounds to the correct dtype if they're arrays + if a_min is not None and hasattr(a_min, "astype"): + a_min = a_min.astype(dtype) + if a_max is not None and hasattr(a_max, "astype"): + a_max = a_max.astype(dtype) + + onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) + lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max) + + # Define args_maker as a function to ensure proper dtype conversion + def args_maker(): + x = rng(shape, dtype) + # Ensure input has the correct dtype - force conversion to numpy array first + if hasattr(x, "astype"): + x = onp.asarray(x, dtype=dtype) + return [x] + + tol_spec = {onp.float64: 2e-7} + tol = jtu.tolerance(dtype, tol_spec) + is_x32_scalar = dtype in [onp.int32, onp.float32] and shape in [ + jtu.NUMPY_SCALAR_SHAPE, + (), + ] + # Turns check_dtypes off if is_x32_scalar is True because there is + # a weird promotion inconsistency in numpy: + # ``` + # print(np.result_type(np.ones([], np.int32), 1)) + # print(np.result_type(np.ones([1], np.int32), 1)) + # print(np.result_type(np.int32(1), 1)) + # print(np.result_type(np.int32, 1)) + # print(np.result_type(np.ones([], np.float32), 1)) + # print(np.result_type(np.ones([1], np.float32), 1)) + # print(np.result_type(np.float32(1), 1)) + # print(np.result_type(np.float32, 1)) + # ``` + # >>> + # int64 + # int32 + # int64 + # int32 + # float64 + # float32 + # float64 + # float32 + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=not is_x32_scalar, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=not is_x32_scalar, + atol=tol, + rtol=tol, + 
check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_amin={}_amax={}".format( + jtu.format_shape_dtype_string(shape, dtype), a_min, a_max + ), + "shape": shape, + "dtype": dtype, + "a_min": a_min, + "a_max": a_max, + "rng_factory": jtu.rand_default, + } + for shape in array_shapes + [jtu.NUMPY_SCALAR_SHAPE] + for dtype in minus(number_dtypes, complex_dtypes) + for a_min, a_max in [ + (-1, None), + (None, 1), + (-1, 1), + (-onp.ones(1), None), + (None, onp.ones(1)), + (-onp.ones(1), onp.ones(1)), + ] + ) + ) + @new_test + def testClipAsMethodStaticBounds(self, shape, dtype, a_min, a_max, rng_factory): + rng = rng_factory() + + # Fix the a_min and a_max types + if a_min is not None and hasattr(a_min, "astype"): + a_min = a_min.astype(dtype) + if a_max is not None and hasattr(a_max, "astype"): + a_max = a_max.astype(dtype) + + onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) + lnp_fun = lambda x: lnp.asarray(x).clip(a_min=a_min, a_max=a_max) + + # Modified args_maker to ensure dtype consistency + def args_maker(): + x = rng(shape, dtype) + # Force conversion to numpy array with the exact required dtype + if hasattr(x, "astype"): + x = onp.asarray(x, dtype=dtype) + return [x] + + tol_spec = {onp.float64: 2e-7} + tol = jtu.tolerance(dtype, tol_spec) + is_x32_scalar = dtype in [onp.int32, onp.float32] and shape in [ + jtu.NUMPY_SCALAR_SHAPE, + (), + ] + + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=not is_x32_scalar, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=not is_x32_scalar, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_decimals={}".format( + jtu.format_shape_dtype_string(shape, dtype), decimals + ), + "shape": shape, + "dtype": dtype, + "decimals": decimals, + "rng_factory": jtu.rand_default, + } + for shape, dtype in _shape_and_dtypes( + all_shapes, minus(number_dtypes, complex_dtypes) + ) + for decimals in [0, 1, -2] + ) + ) + def testRoundStaticDecimals(self, shape, dtype, decimals, rng_factory): + rng = rng_factory() + if lnp.issubdtype(dtype, onp.integer) and decimals < 0: + self.skipTest("Integer rounding with decimals < 0 not implemented") + onp_fun = lambda x: onp.round(x, decimals=decimals) + lnp_fun = lambda x: lnp.round(x, decimals=decimals) + args_maker = lambda: [rng(shape, dtype)] + tol = { + # TODO(b/154768983): lnp.bfloat16: 5e-2, + onp.float16: 1e-2 + } + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=check_dtypes, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + def testOperatorRound(self): + self.assertAllClose( + round(onp.float32(7.532), 1), round(lnp.float32(7.5), 1), check_dtypes=True + ) + self.assertAllClose( + round(onp.float32(1.234), 2), + round(lnp.float32(1.234), 2), + check_dtypes=True, + ) + self.assertAllClose( + round(onp.float32(1.234)), round(lnp.float32(1.234)), check_dtypes=False + ) + self.assertAllClose( + round(onp.float32(7.532), 1), + round(lnp.array(7.5, lnp.float32), 1), + check_dtypes=True, + ) + self.assertAllClose( + round(onp.float32(1.234), 2), + round(lnp.array(1.234, lnp.float32), 2), + check_dtypes=True, + ) + self.assertAllClose( + round(onp.float32(1.234)), + round(lnp.array(1.234, lnp.float32)), + check_dtypes=False, + ) + + @named_parameters( + 
jtu.cases_from_list( + { + "testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format( + jtu.format_shape_dtype_string(shape, dtype), + mode, + pad_width_rank, + constant_values_rank, + ), + "shape": shape, + "dtype": dtype, + "mode": mode, + "pad_width_rank": pad_width_rank, + "constant_values_rank": constant_values_rank, + "rng_factory": jtu.rand_default, + "irng_factory": partial(jtu.rand_int, 3), + } + for mode, constant_values_rank, shapes in [ + ("constant", 0, all_shapes), + ("constant", 1, all_shapes), + ("constant", 2, all_shapes), + ("symmetric", None, nonempty_shapes), + ("reflect", None, nonempty_shapes), + ("wrap", None, nonempty_shapes), + ] + for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) + for pad_width_rank in range(3) + ) + ) + @jtu.disable + def testPad( + self, + shape, + dtype, + mode, + pad_width_rank, + constant_values_rank, + rng_factory, + irng_factory, + ): + rng = rng_factory() + irng = irng_factory() + pad_width = irng([len(shape), 2][2 - pad_width_rank :], onp.int32) + + def onp_fun(x, kwargs): + if pad_width.size == 0: + return x + return onp.pad(x, pad_width, mode=mode, **kwargs) + + def lnp_fun(x, kwargs): + return lnp.pad(x, pad_width, mode=mode, **kwargs) + + def args_maker(): + kwargs = {} + if constant_values_rank: + kwargs["constant_values"] = rng( + [len(shape), 2][2 - constant_values_rank :], dtype + ) + return rng(shape, dtype), kwargs + + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + ) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape=[{}]_reps={}".format( + jtu.format_shape_dtype_string(shape, dtype), reps + ), + "shape": shape, + "dtype": dtype, + "reps": reps, + "rng_factory": jtu.rand_default, + } + for reps in [(), (2,), (3, 4), (2, 3, 4)] + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + ) + ) + def testTile(self, shape, dtype, reps, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.tile(arg, reps) + lnp_fun = lambda arg: lnp.tile(arg, reps) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = {onp.float64: 2e-7} + tol = jtu.tolerance(dtype, tol_spec) + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + tol=tol, + ) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( + axis, + ",".join(str(d) for d in base_shape), + ",".join(onp.dtype(dtype).name for dtype in arg_dtypes), + ), + "axis": axis, + "base_shape": base_shape, + "arg_dtypes": arg_dtypes, + "rng_factory": jtu.rand_default, + } + for num_arrs in [3] + for arg_dtypes in CombosWithReplacement(default_dtypes, num_arrs) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape) + 1, len(base_shape)) + ) + ) + def testConcatenate(self, axis, base_shape, arg_dtypes, rng_factory): + rng = rng_factory() + wrapped_axis = axis % len(base_shape) + shapes = [ + base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis + 1 :] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes) + ] + + def onp_fun(*args): + # TODO(nareshmodi): enable once bfloat16 has better support + # args = [x if x.dtype != bfloat16 else x.astype(onp.float32) + # for x in args] + dtype = functools.reduce(lnp.promote_types, arg_dtypes) + return 
onp.concatenate(args, axis=axis).astype(dtype) + + lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( + axis, + ",".join(str(d) for d in base_shape), + ",".join(onp.dtype(dtype).name for dtype in arg_dtypes), + ), + "axis": axis, + "base_shape": base_shape, + "arg_dtypes": arg_dtypes, + "rng_factory": jtu.rand_default, + } + for arg_dtypes in CombosWithReplacement(default_dtypes, 2) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape) + 1, len(base_shape)) + ) + ) + def testAppend(self, axis, base_shape, arg_dtypes, rng_factory): + rng = rng_factory() + wrapped_axis = axis % len(base_shape) + shapes = [ + base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis + 1 :] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes) + ] + + def onp_fun(arr, values): + arr = arr.astype(onp.float32) if lnp.bfloat16 == arr.dtype else arr + values = ( + values.astype(onp.float32) if lnp.bfloat16 == values.dtype else values + ) + out = onp.append(arr, values, axis=axis) + return out.astype(lnp.promote_types(*arg_dtypes)) + + lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape=[{}]_axis={}_repeats={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, repeats + ), + "axis": axis, + "shape": shape, + "dtype": dtype, + "repeats": repeats, + "rng_factory": jtu.rand_default, + } + for repeats in [0, 1, 2] + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + for axis in [None] + list(range(-len(shape), len(shape))) + ) + ) + def testRepeat(self, axis, shape, dtype, repeats, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis) + onp_fun = _promote_like_lnp(onp_fun) + lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) + + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False + ) + + def testIssue1233(self): + """ + Following numpy test suite from `test_repeat` at https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py + """ + tol = 1e-5 + + def test_single(m, args_maker, repeats, axis): + lax_ans = lnp.repeat(m, repeats, axis) + numpy_ans = onp.repeat(m, repeats, axis) + + self.assertAllClose( + lax_ans, numpy_ans, check_dtypes=True, rtol=tol, atol=tol + ) + + lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis) + # Turns off XLA check because there are no XLA kernels for `Where` used by + # tf.repeat (b/169192730). 
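+ # (Incomplete-shape checking is off here, likely because array-valued
+ # `repeats` make the output length value-dependent, as with nonzero above.)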
+ self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=False, + check_experimental_compile=False, + check_xla_forced_compile=False, + ) + + m = lnp.array([1, 2, 3, 4, 5, 6]) + args_maker = lambda: [m] + + for repeats in [ + 2, + [1, 3, 2, 1, 1, 2], + [1, 3, 0, 1, 1, 2], + [2], + lnp.array([1, 3, 2, 1, 1, 2]), + lnp.array([2]), + ]: + test_single(m, args_maker, repeats, None) + + m_rect = m.reshape((2, 3)) + args_maker = lambda: [m_rect] + + for repeats in [2, [2, 1], [2], lnp.array([2, 1]), lnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=0) + + for repeats in [2, [1, 3, 2], [2], lnp.array([1, 3, 2]), lnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=1) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), axis, out_dtype + ), + "axis": axis, + "shape": shape, + "dtype": dtype, + "out_dtype": out_dtype, + "rng_factory": jtu.rand_default, + "lnp_op": getattr(lnp, op), + "onp_op": getattr(onp, op), + } + for op in ["cumsum", "cumprod"] + for dtype in default_dtypes + for out_dtype in default_dtypes + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape))) + ) + ) + def testCumSumProd( + self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_factory + ): + rng = rng_factory() + onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype) + lnp_fun = lambda arg: lnp_op(arg, axis=axis, dtype=out_dtype) + + args_maker = lambda: [rng(shape, dtype)] + + tol = max(jtu.tolerance(dtype), jtu.tolerance(out_dtype)) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + # XLA lacks int64 Cumsum/Cumprod kernels (b/168841378). 
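+ # Separately, float16 cumulative ops accumulate rounding error along the
+ # scan, hence the looser rtol chosen below for that out_dtype.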
+ check_xla = out_dtype != onp.int64 + rtol = None + if out_dtype == onp.float16: + rtol = 2e-3 + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + rtol=rtol, + check_incomplete_shape=True, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_dtype={}_m={}_n={}_k={}".format( + onp.dtype(dtype).name, m, n, k + ), + "m": m, + "n": n, + "k": k, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for n in [0, 4] + for m in [None, 0, 1, 3, 4] + for k in list(range(-4, 4)) + ) + ) + def testTri(self, m, n, k, dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype) + lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_shape={}_k={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), k + ), + "dtype": dtype, + "shape": shape, + "op": op, + "k": k, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for op in ["tril", "triu"] + for k in list(range(-3, 3)) + ) + ) + def testTriLU(self, dtype, shape, op, k, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: getattr(onp, op)(arg, k=k) + lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + # Incomplete shape support is not implemented at the moment. 
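+ # (Hence check_incomplete_shape=False below.) For reference,
+ # onp.tril([[1, 2], [3, 4]]) -> [[1, 0], [3, 4]] and
+ # onp.triu([[1, 2], [3, 4]]) -> [[1, 2], [0, 4]]; `k` picks the boundary
+ # diagonal (positive above the main diagonal, negative below).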
+ self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False + ) + + @named_parameters( + jtu.cases_from_list( + {"testcase_name": "_ndim={}_n={}".format(ndim, n), "ndim": ndim, "n": n} + for ndim in [0, 1, 4] + for n in [0, 1, 7] + ) + ) + def testDiagIndices(self, ndim, n): + onp.testing.assert_equal(onp.diag_indices(n, ndim), lnp.diag_indices(n, ndim)) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k + ), + "dtype": dtype, + "shape": shape, + "k": k, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] + for k in list(range(-4, 4)) + ) + ) + def testDiag(self, shape, dtype, k, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.diag(arg, k) + lnp_fun = lambda arg: lnp.diag(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( + jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2 + ), + "dtype": dtype, + "shape": shape, + "offset": offset, + "axis1": axis1, + "axis2": axis2, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in [ + a + for a in range(-len(shape), len(shape)) + if a % len(shape) != axis1 % len(shape) + ] + for offset in list(range(-4, 4)) + ) + ) + def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng_factory): + rng = rng_factory() + onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2) + lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_n={}".format(onp.dtype(dtype).name, n), + "dtype": dtype, + "n": n, + } + for dtype in default_dtypes + for n in list(range(4)) + ) + ) + def testIdentity(self, n, dtype): + onp_fun = lambda: onp.identity(n, dtype) + lnp_fun = lambda: lnp.identity(n, dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format( + jtu.format_shape_dtype_string(shape, dtype), + out_dtype, + offset, + axis1, + axis2, + ), + "dtype": dtype, + "out_dtype": out_dtype, + "shape": shape, + "offset": offset, + "axis1": axis1, + "axis2": axis2, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for out_dtype in [None] + number_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in range(-len(shape), len(shape)) + if (axis1 % len(shape)) != (axis2 % len(shape)) + for offset in list(range(-4, 4)) + ) + ) + def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng_factory): + rng = rng_factory() + + def 
onp_fun(arg): + if out_dtype == lnp.bfloat16: + return onp.trace(arg, offset, axis1, axis2, onp.float32).astype( + lnp.bfloat16 + ) + else: + return onp.trace(arg, offset, axis1, axis2, out_dtype) + + lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}".format( + jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis + ), + "shape": shape, + "axis": axis, + "dtypes": dtypes, + "rng_factory": rng_factory, + } + for dtypes in [ + [onp.float32], + [onp.float32, onp.float32], + [onp.float32, onp.int32, onp.float32], + [onp.float32, onp.int64, onp.float32], + [onp.float32, onp.int32, onp.float64], + ] + for shape in [(), (2,), (3, 4), (1, 100)] + for axis in range(-len(shape), len(shape) + 1) + for rng_factory in [jtu.rand_default] + ) + ) + def testStack(self, shape, axis, dtypes, rng_factory): + rng = rng_factory() + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + onp_fun = _promote_like_lnp(partial(onp.stack, axis=axis)) + lnp_fun = partial(lnp.stack, axis=axis) + self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, True, check_incomplete_shape=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_{}".format( + op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes) + ), + "shape": shape, + "op": op, + "dtypes": dtypes, + "rng_factory": rng_factory, + } + for op in ["hstack", "vstack", "dstack"] + for dtypes in [ + [onp.float32], + [onp.float32, onp.float32], + [onp.float32, onp.int32, onp.float32], + [onp.float32, onp.int64, onp.float32], + [onp.float32, onp.int32, onp.float64], + ] + for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)] + for rng_factory in [jtu.rand_default] + ) + ) + def testHVDStack(self, shape, op, dtypes, rng_factory): + rng = rng_factory() + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + onp_fun = _promote_like_lnp(getattr(onp, op)) + lnp_fun = getattr(lnp, op) + self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(lnp_fun, args_maker, True, check_incomplete_shape=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_outdtype={}".format( + jtu.format_shape_dtype_string(shape, fill_value_dtype), + onp.dtype(out_dtype).name if out_dtype else "None", + ), + "shape": shape, + "fill_value_dtype": fill_value_dtype, + "out_dtype": out_dtype, + "rng_factory": jtu.rand_default, + } + for shape in array_shapes + [3, onp.array(7, dtype=onp.int32)] + for fill_value_dtype in default_dtypes + for out_dtype in [None] + default_dtypes + ) + ) + def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype) + lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype) + args_maker = lambda: [rng((), fill_value_dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype), + "onp_op": 
getattr(onp, op), + "lnp_op": getattr(lnp, op), + "shape": shape, + "dtype": dtype, + } + for op in ["zeros", "ones"] + for shape in [ + 2, + (), + (2,), + (3, 0), + onp.array((4, 5, 6), dtype=onp.int32), + onp.array(4, dtype=onp.int32), + ] + for dtype in all_dtypes + ) + ) + def testZerosOnes(self, onp_op, lnp_op, shape, dtype): + rng = jtu.rand_default() + + def args_maker(): + return [] + + onp_op = partial(onp_op, shape, dtype) + lnp_op = partial(lnp_op, shape, dtype) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_filldtype={}_outdtype={}".format( + jtu.format_shape_dtype_string(shape, in_dtype), + onp.dtype(fill_value_dtype).name, + onp.dtype(out_dtype).name, + ), + "shape": shape, + "in_dtype": in_dtype, + "fill_value_dtype": fill_value_dtype, + "out_dtype": out_dtype, + "rng_factory": jtu.rand_default, + } + for shape in array_shapes + for in_dtype in default_dtypes + for fill_value_dtype in default_dtypes + for out_dtype in default_dtypes + ) + ) + def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype) + lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype) + args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}_{}sections".format( + jtu.format_shape_dtype_string(shape, dtype), axis, num_sections + ), + "shape": shape, + "num_sections": num_sections, + "axis": axis, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for shape, axis, num_sections in [ + ((3,), 0, 3), + ((12,), 0, 3), + ((12, 4), 0, 4), + ((12, 4), 1, 2), + ((2, 3, 4), -1, 2), + ((2, 3, 4), -2, 3), + ] + for dtype in default_dtypes + ) + ) + def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.split(x, num_sections, axis=axis) + lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}_{}sections".format( + jtu.format_shape_dtype_string(shape, dtype), axis, num_sections + ), + "shape": shape, + "num_sections": num_sections, + "axis": axis, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for shape, axis, num_sections in [ + ((12, 4), 0, 4), + ((12, 4), 1, 2), + ((2, 3, 4), 2, 2), + ((4, 3, 4), 0, 2), + ] + for dtype in default_dtypes + ) + ) + def testHVDSplit(self, shape, num_sections, axis, dtype, rng_factory): + rng = rng_factory() + + def fn(module, axis): + if axis == 0: + return module.vsplit + elif axis == 1: + return module.hsplit + else: + assert axis == 2 + return module.dsplit + + onp_fun = lambda x: fn(onp, axis)(x, num_sections) + lnp_fun = lambda x: fn(lnp, axis)(x, num_sections) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + 
self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_outshape={}_order={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + order, + ), + "arg_shape": arg_shape, + "out_shape": out_shape, + "dtype": dtype, + "order": order, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for order in ["C", "F"] + for arg_shape, out_shape in [ + (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), + ((), (1, 1, 1)), + ((7, 0), (0, 42, 101)), + ((3, 4), 12), + ((3, 4), (12,)), + ((3, 4), -1), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)), + ] + ) + ) + def testReshape(self, arg_shape, out_shape, dtype, order, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.reshape(x, out_shape, order=order) + lnp_fun = lambda x: lnp.reshape(x, out_shape, order=order) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_outshape={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + ), + "arg_shape": arg_shape, + "out_shape": out_shape, + "dtype": dtype, + "rng_factory": jtu.rand_default, + } + for dtype in default_dtypes + for arg_shape, out_shape in [ + ((7, 0), (0, 42, 101)), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)), + ] + ) + ) + def testReshapeMethod(self, arg_shape, out_shape, dtype, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.reshape(x, out_shape) + lnp_fun = lambda x: x.reshape(*out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_expanddim={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), dim + ), + "arg_shape": arg_shape, + "dtype": dtype, + "dim": dim, + "rng_factory": jtu.rand_default, + } + for arg_shape in [(), (3,), (3, 4)] + for dtype in default_dtypes + for dim in range(-len(arg_shape) + 1, len(arg_shape)) + ) + ) + def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.expand_dims(x, dim) + lnp_fun = lambda x: lnp.expand_dims(x, dim) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_axes=({},{})".format( + jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2 + ), + "arg_shape": arg_shape, + "dtype": dtype, + "ax1": ax1, + "ax2": ax2, + "rng_factory": jtu.rand_default, + } + for arg_shape, ax1, ax2 in [ + ((3, 4), 0, 1), + ((3, 4), 1, 0), + ((3, 4, 5), 1, 2), + ((3, 4, 5), -1, -2), + ((3, 4, 5), 0, 1), + ] + for dtype in default_dtypes + ) + ) + def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) + lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) + args_maker = lambda: [rng(arg_shape, dtype)] + 
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_axes=({},{})".format( + jtu.format_shape_dtype_string(arg_shape, dtype), source, destination + ), + "arg_shape": arg_shape, + "dtype": dtype, + "source": source, + "destination": destination, + "rng_factory": jtu.rand_default, + } + for arg_shape, source, destination in [ + (tuple(range(6)), (0, 2), (3, 5)), + (tuple(range(6)), (0, 2), (-1, -3)), + (tuple(range(6)), (-6, -4), (3, 5)), + (tuple(range(6)), (-6, -4), (-1, -3)), + (tuple(range(6)), 0, 4), + (tuple(range(6)), -6, -2), + (tuple(range(6)), tuple(range(6)), tuple(range(6))), + (tuple(range(6)), tuple(range(6)), tuple(reversed(range(6)))), + (tuple(range(6)), (), ()), + ] + for dtype in default_dtypes + ) + ) + @new_test + def testMoveaxisStaticAxes( + self, arg_shape, dtype, source, destination, rng_factory + ): + rng = rng_factory() + onp_fun = lambda x: onp.moveaxis(x, source, destination) + lnp_fun = lambda x: lnp.moveaxis(x, source, destination) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_inshape={}_axis={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), ax + ), + "arg_shape": arg_shape, + "dtype": dtype, + "ax": ax, + "rng_factory": jtu.rand_default, + } + for arg_shape, ax in [ + ((3, 1), None), + ((3, 1), 1), + ((1, 3, 1), (0, 2)), + ((1, 4, 1), (0,)), + ] + for dtype in default_dtypes + ) + ) + def testSqueeze(self, arg_shape, dtype, ax, rng_factory): + rng = rng_factory() + onp_fun = lambda x: onp.squeeze(x, ax) + lnp_fun = lambda x: lnp.squeeze(x, ax) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_axis={}_weights={}_returned={}".format( + jtu.format_shape_dtype_string(shape, dtype), + axis, + ( + None + if weights_shape is None + else jtu.format_shape_dtype_string(weights_shape, dtype) + ), + returned, + ), + "rng_factory": jtu.rand_default, + "shape": shape, + "dtype": dtype, + "axis": axis, + "weights_shape": weights_shape, + "returned": returned, + } + for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes) + for axis in set(range(-len(shape), len(shape))) | set([None]) + # `weights_shape` is either `None`, same as the averaged axis, or same as + # that of the input + for weights_shape in ( + [None, shape] + if axis is None or len(shape) == 1 + else [None, (shape[axis],), shape] + ) + for returned in [False, True] + ) + ) + def testAverage(self, shape, dtype, axis, weights_shape, returned, rng_factory): + rng = rng_factory() + if weights_shape is None: + onp_fun = lambda x: onp.average(x, axis, returned=returned) + lnp_fun = lambda x: lnp.average(x, axis, returned=returned) + args_maker = lambda: [rng(shape, dtype)] + else: + onp_fun = lambda x, weights: onp.average(x, axis, weights, returned) + lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned) + args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)] + onp_fun = 
_promote_like_lnp(onp_fun, inexact=True) + tol = { + # TODO(b/154768983): lnp.bfloat16: 1e-1, + onp.float16: 1e-1, + onp.float32: 1e-3, + onp.float64: 2e-7, + onp.complex64: 1e-3, + onp.complex128: 1e-10, + } + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + try: + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol + ) + except ZeroDivisionError: + self.skipTest("don't support checking for ZeroDivisionError") + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=check_dtypes, + rtol=tol, + atol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_arg{}_ndmin={}".format(i, ndmin), + "arg": arg, + "ndmin": ndmin, + "dtype": dtype, + } + for i, (arg, dtype) in enumerate( + [ + ([True, False, True], lnp.bool_), + (3.0, lnp.float_), + ([1, 2, 3], lnp.int_), + ([1.0, 2.0, 3.0], lnp.float_), + ([[1, 2], [3, 4], [5, 6]], lnp.int_), + ([[1, 2.0], [3, 4], [5, 6]], lnp.float_), + ([[1.0, 2j], [3.0, 4.0], [5.0, 6.0]], lnp.complex_), + ( + [ + [3, onp.array(2, dtype=lnp.float_), 1], + onp.arange(3.0, dtype=lnp.float_), + ], + lnp.float_, + ), + ] + ) + for ndmin in [None, onp.ndim(arg), onp.ndim(arg) + 1, onp.ndim(arg) + 2] + ) + ) + def testArray(self, arg, ndmin, dtype): + args_maker = lambda: [arg] + dtype = lnp.canonicalize_dtype(dtype) + if ndmin is not None: + onp_fun = partial(onp.array, ndmin=ndmin, dtype=dtype) + lnp_fun = partial(lnp.array, ndmin=ndmin) + else: + onp_fun = partial(onp.array, dtype=dtype) + lnp_fun = lnp.array + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + static_argnums=[0], + ) + + def testIssue121(self): + assert not onp.isscalar(lnp.array(3)) + + @jtu.disable + def testArrayMethod(self): + class arraylike(object): + dtype = onp.float32 + + def __array__(self, dtype=None): + return 3.0 + + a = arraylike() + ans = lnp.array(a) + assert ans == 3.0 + + @jtu.skip_on_devices("tpu") # TODO(b/32368900): TPUs don't support uint8 yet. + @jtu.disable + def testMemoryView(self): + ans = lnp.array(bytearray(b"\x2a")) + self.assertAllClose(ans, onp.array([0x2A], dtype=onp.uint8), check_dtypes=True) + + def testAllClose(self): + rng = onp.random.RandomState(0) + x = rng.randn(2, 2) + y = rng.randn(2) + + def same(list1, list2): + allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3) + elements_close = list(map(allclose, list1, list2)) + return lnp.all(lnp.array(elements_close)) + + csame = npe.jit(same) + + a1 = same((x, y), (x, y)) + a2 = csame((x, y), (x, y)) + a3 = csame((x, y), (x, 2 * y)) + + self.assertTrue(a1) + self.assertTrue(a2) + self.assertFalse(a3) + + @jtu.skip_on_devices("tpu") # TODO(mattjj): investigate this failure + @jtu.disable + def testOnesBroadcastingConstantHandler(self): + # TODO(mattjj): update this test for jax3 + self.skipTest("test needs jax3 update") + + def fun(x): + ones = lnp.ones((3, 4)) + assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0) + + # To check that the constant handler generates a Broadcast for stride-zero + # arrays, we monkey-patch the client instance. + # TODO(mattjj): once we have better HLO dumping and inspecting facilities, + # we can check the HLO more directly. 
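+ # The patching idiom below swaps in a wrapper that records the call and
+ # then delegates: `was_called.append(True) or Broadcast(*args)` evaluates
+ # the original call because `list.append` returns None, which is falsy.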
+ c = x._node.c + Broadcast = c.Broadcast # pylint: disable=invalid-name + was_called = [] + c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args) + out = x + ones # the ndarray constant handler should call Broadcast here + assert was_called, "Broadcast was not called." + + return out + + fun = api.jit(fun) + out_val = fun(lnp.ones(4)) + self.assertAllClose(out_val, onp.full((3, 4), 2.0), check_dtypes=False) + + def testZeroStridesConstantHandler(self): + raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1) + const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6)) + + def fun(x): + return x * const + + fun = npe.jit(fun) + out_val = fun(3.0) + self.assertAllClose(out_val, 3.0 * const, check_dtypes=False) + + def testIsInstanceNdarrayDuringTracing(self): + arr = onp.ones(3) + + @npe.jit + def f(x): + self.assertIsInstance(x, lnp.ndarray) + return lnp.sum(x) + + f(arr) + + @jtu.disable + def testNonArrayErrorMessage(self): + x = [1.0, 2.0] + y = onp.array([3.0, 4.0]) + + def g(x, y): + return lnp.add(x, y) + + def f(x, y): + return lnp.dot(x, y) + + self.assertRaises(TypeError, lambda: g(x, y)) + self.assertRaises(TypeError, lambda: f(x, y)) + self.assertRaises(TypeError, lambda: api.jit(g)(x, y)) + self.assertRaises(TypeError, lambda: api.jit(f)(x, y)) + + @jtu.disable + def testAbstractionErrorMessage(self): + @api.jit + def f(x, n): + for _ in range(n): + x = x * x + return x + + self.assertRaises(TypeError, lambda: f(3.0, 3)) + + @api.jit + def g(x): + if x > 0.0: + return x * 2 + else: + return x + 2 + + self.assertRaises(TypeError, lambda: g(3.0)) + + @jtu.disable + def testTracingPrimitiveWithNoTranslationErrorMessage(self): + # TODO(mattjj): update this for jax3 + self.skipTest("test needs jax3 update") + foo = lnp._not_implemented(lambda x: x) + + # No error if there's no tracing. 
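+ # Op-by-op execution never consults the missing translation rule, so the
+ # eager call succeeds; only the traced (jitted) call afterwards is
+ # expected to raise NotImplementedError.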
+ foo(onp.arange(3)) + + cfoo = api.jit(foo) + self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3))) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "axis": axis, + } + for shape in [(3,), (2, 3)] + for dtype in default_dtypes + for axis in list(range(-len(shape), len(shape))) + + [None] # Test negative axes + for rng_factory in [jtu.rand_default] + ) + ) + def testFlip(self, shape, dtype, axis, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.flip(x, axis) + onp_op = lambda x: onp.flip(x, axis) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + } + for shape in [(3,), (2, 3), (3, 2, 4)] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testFlipud(self, shape, dtype, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.flipud(x) + onp_op = lambda x: onp.flipud(x) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + } + for shape in [(3, 2), (2, 3), (3, 2, 4)] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testFliplr(self, shape, dtype, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.fliplr(x) + onp_op = lambda x: onp.fliplr(x) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_k={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), k, axes + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "k": k, + "axes": axes, + } + for shape, axes in [ + [(2, 3), (0, 1)], + [(2, 3), (1, 0)], + [(4, 3, 2), (0, 2)], + [(4, 3, 2), (2, 1)], + ] + for k in range(-3, 4) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testRot90(self, shape, dtype, k, axes, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.rot90(x, k, axes) + onp_op = lambda x: onp.rot90(x, k, axes) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_k={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), k, axes + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "k": k, + "axes": axes, + } + for shape, axes in [ + [(2, 3), (-2, -1)], + [(2, 3), (-2, 1)], + [(4, 3, 2), (-1, -2)], + [(4, 3, 2), (2, -2)], + ] + for k 
in range(-3, 4) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + @new_test + # These tests are only added as a separate test from testRot90 since we would + # like to measure coverage directly against the existing baseline. Once we + # stop measuring that, we can combine this test with the above. + def testRot90Additional(self, shape, dtype, k, axes, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + lnp_op = lambda x: lnp.rot90(x, k, axes) + onp_op = lambda x: onp.rot90(x, k, axes) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + # TODO(mattjj): test infix operator overrides + + def testRavel(self): + rng = onp.random.RandomState(0) + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + self._CompileAndCheck( + lambda x: x.ravel(), + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + ) + + def testAstype(self): + rng = onp.random.RandomState(0) + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + op = lambda x: x.astype(lnp.int32) + self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True) + self._CompileAndCheck( + op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + # TODO(mattjj): test other ndarray-like method overrides + + def testOnpMean(self): + # from https://github.com/google/jax/issues/125 + x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.0) + ans = onp.mean(x) + self.assertAllClose(ans, onp.array(1.0 / 3), check_dtypes=False) + + @jtu.disable + def testArangeOnFloats(self): + # from https://github.com/google/jax/issues/145 + expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_) + ans = lnp.arange(0.0, 1.0, 0.1) + self.assertAllClose(expected, ans, check_dtypes=True) + + def testSortManually(self): + def _test(*args, **kwargs): + raw_ans = lnp.sort(*args, **kwargs) + fn_ans = npe.jit(lnp.sort, static_argnums=(1,))(*args, **kwargs) + expected = onp.sort(*args, **kwargs) + + self.assertAllClose(expected, raw_ans, check_dtypes=True) + self.assertAllClose(expected, fn_ans, check_dtypes=True) + + # manual tests for sort are nice because we don't have to worry about ties. + # lax.sort is tested combinatorially. 
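+ # Axis handling is the interesting part: for a = [[1, 4], [3, 1]],
+ # onp.sort(a, axis=None) flattens to [1, 1, 3, 4], onp.sort(a, 0) sorts
+ # columns to [[1, 1], [3, 4]], and the default axis=-1 sorts rows to
+ # [[1, 4], [1, 3]].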
+ _test(onp.array([16, 15, 23, 42, 8, 4])) + _test(onp.array([[1, 4], [3, 1]]), None) + _test(onp.array([[1, 4], [3, 1]])) + _test(onp.array([[1, 4], [3, 1]]), 0) + + def testArgsortManually(self): + def _test(*args, **kwargs): + raw_ans = lnp.argsort(*args, **kwargs) + fn_ans = npe.jit(lnp.argsort, static_argnums=(1,))(*args, **kwargs) + expected = onp.argsort(*args, **kwargs) + + self.assertAllClose(expected, raw_ans, check_dtypes=True) + self.assertAllClose(expected, fn_ans, check_dtypes=True) + + _test(onp.array([16, 15, 23, 42, 8, 4])) + _test(onp.array([[16, 15, 23], [42, 8, 4]]), 0) + _test(onp.array([[16, 15, 23], [42, 8, 4]]), 1) + _test(onp.array([[16, 15, 23], [42, 8, 4]]), None) + _test(onp.array([[16, 15, 23], [42, 8, 4]])) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_shifts={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), shifts, axis + ), + "rng_factory": rng_factory, + "shape": shape, + "dtype": dtype, + "shifts": shifts, + "axis": axis, + } + for dtype in all_dtypes + for shape in [(3, 4), (3, 4, 5), (7, 4, 0)] + for shifts, axis in [ + (3, None), + (1, 1), + ((3,), (0,)), + ((-2,), (-2,)), + ((1, 2), (0, -1)), + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testRoll(self, shape, dtype, shifts, axis, rng_factory): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype), onp.array(shifts)] + lnp_op = partial(lnp.roll, axis=axis) + onp_op = partial(onp.roll, axis=axis) + self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_index={}_axis={}_mode={}".format( + jtu.format_shape_dtype_string(shape, dtype), + jtu.format_shape_dtype_string(index_shape, index_dtype), + axis, + mode, + ), + "rng_factory": rng_factory, + "rng_indices_factory": rng_indices_factory, + "shape": shape, + "index_shape": index_shape, + "dtype": dtype, + "index_dtype": index_dtype, + "axis": axis, + "mode": mode, + } + for shape in [(3,), (3, 4), (3, 4, 5)] + for index_shape in scalar_shapes + [(3,), (2, 1, 3)] + for axis in itertools.chain(range(-len(shape), len(shape)), [None]) + for dtype in all_dtypes + for index_dtype in int_dtypes + for mode in ["wrap", "clip"] + for rng_factory in [jtu.rand_default] + for rng_indices_factory in [partial(jtu.rand_int, -5, 5)] + ) + ) + def testTake( + self, + shape, + dtype, + index_shape, + index_dtype, + axis, + mode, + rng_factory, + rng_indices_factory, + ): + def args_maker(): + x = rng(shape, dtype) + i = rng_indices(index_shape, index_dtype) + return x, i + + rng = rng_factory() + rng_indices = rng_indices_factory() + lnp_op = lambda x, i: lnp.take(x, i, axis=axis, mode=mode) + onp_op = lambda x, i: onp.take(x, i, axis=axis, mode=mode) + self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_ishape={}_axis={}".format( + jtu.format_shape_dtype_string(x_shape, dtype), i_shape, axis + ), + "rng_factory": rng_factory, + "x_shape": x_shape, + "i_shape": i_shape, + "dtype": dtype, + "axis": axis, + } + for x_shape, i_shape in filter( + _shapes_are_equal_length, + filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(nonempty_nonscalar_array_shapes, 2), + ), + ) + for axis in itertools.chain(range(len(x_shape)), [-1], 
[None]) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + ) + ) + def testTakeAlongAxis(self, x_shape, i_shape, dtype, axis, rng_factory): + rng = rng_factory() + i_shape = onp.array(i_shape) + if axis is None: + i_shape = [onp.prod(i_shape, dtype=onp.int64)] + else: + # Test the case where the size of the axis doesn't necessarily broadcast. + i_shape[axis] *= 3 + i_shape = list(i_shape) + + def args_maker(): + x = rng(x_shape, dtype) + n = onp.prod(x_shape, dtype=onp.int32) if axis is None else x_shape[axis] + i = rng(i_shape, onp.int32) % (2 * n - 1) - (n - 1) + return x, i + + lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis) + + if hasattr(onp, "take_along_axis"): + onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis) + self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_n={}_increasing={}".format( + jtu.format_shape_dtype_string([shape], dtype), n, increasing + ), + "dtype": dtype, + "shape": shape, + "n": n, + "increasing": increasing, + "rng_factory": jtu.rand_default, + } + for dtype in inexact_dtypes + for shape in [0, 5] + for n in [2, 4] + for increasing in [False, True] + ) + ) + def testVander(self, shape, dtype, n, increasing, rng_factory): + rng = rng_factory() + + def onp_fun(arg): + arg = arg.astype(onp.float32) if dtype == lnp.bfloat16 else arg + return onp.vander(arg, N=n, increasing=increasing) + + lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing) + args_maker = lambda: [rng([shape], dtype)] + # np.vander seems to return float64 for all floating types. We could obey + # those semantics, but they seem like a bug. 
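+ # Hence check_dtypes=False in both checks below. For reference,
+ # onp.vander([1, 2, 3], N=3) -> [[1, 1, 1], [4, 2, 1], [9, 3, 1]]:
+ # column j holds x**(N-1-j), and increasing=True reverses the columns.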
+ self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol={onp.float32: 1e-3} + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=False, + check_incomplete_shape=True, + rtol={onp.complex128: 2e-15}, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + "nan_to_num", [shape], [dtype] + ), + "rng_factory": jtu.rand_some_inf_and_nan, + "shape": shape, + "dtype": dtype, + } + for shape in all_shapes + for dtype in inexact_dtypes + ) + ) + @jtu.disable + def testNanToNum(self, rng_factory, shape, dtype): + rng = rng_factory() + dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type + + def onp_fun(x): + if dtype == lnp.bfloat16: + x = onp.where(onp.isnan(x), dtype(0), x) + x = onp.where(onp.isposinf(x), lnp.finfo(dtype).max, x) + x = onp.where(onp.isneginf(x), lnp.finfo(dtype).min, x) + return x + else: + return onp.nan_to_num(x).astype(dtype) + + args_maker = lambda: [rng(shape, dtype)] + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy( + onp_fun, lnp.nan_to_num, args_maker, check_dtypes=check_dtypes + ) + self._CompileAndCheck(lnp.nan_to_num, args_maker, check_dtypes=check_dtypes) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes), + "rng_factory": jtu.rand_default, + "shapes": shapes, + "dtypes": dtypes, + } + for shapes, dtypes in ( + ((), ()), + (((7,),), (onp.int32,)), + (((3,), (4,)), (onp.int32, onp.int32)), + (((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)), + ) + ) + ) + def testIx_(self, rng_factory, shapes, dtypes): + rng = rng_factory() + args_maker = lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] + self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp.ix_, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}".format( + op, + jtu.format_shape_dtype_string(a_shape, a_dtype), + jtu.format_shape_dtype_string(q_shape, q_dtype), + axis, + keepdims, + ), + "a_rng": jtu.rand_default(), + "q_rng": q_rng, + "op": op, + "a_shape": a_shape, + "a_dtype": a_dtype, + "q_shape": q_shape, + "q_dtype": q_dtype, + "axis": axis, + "keepdims": keepdims, + } + for (op, q_rng) in ( + ("percentile", jtu.rand_uniform(low=0.0, high=100.0)), + ("quantile", jtu.rand_uniform(low=0.0, high=1.0)), + ("median", jtu.rand_uniform(low=0.0, high=1.0)), + ) + for a_dtype in float_dtypes + for a_shape, axis in ( + ((7,), None), + ((47, 7), 0), + ((4, 101), 1), + ) + for q_dtype in [onp.float32] + for q_shape in scalar_shapes + [(4,)] + for keepdims in [False, True] + ) + ) + @jtu.disable + def testQuantile( + self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims + ): + if op == "quantile" and numpy_version < (1, 15): + raise SkipTest("Numpy < 1.15 does not have np.quantile") + if op == "median": + args_maker = lambda: [a_rng(a_shape, a_dtype)] + else: + args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] + + def onp_fun(*args): + args = [ + x if lnp.result_type(x) != lnp.bfloat16 else onp.asarray(x, onp.float32) + for x in args + ] + return getattr(onp, op)(*args, axis=axis, keepdims=keepdims) + + lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims) + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. 
It's not clear we want to mimic + Numpy here. + tol_spec = {onp.float32: 2e-4, onp.float64: 5e-6} + tol = max(jtu.tolerance(a_dtype, tol_spec), jtu.tolerance(q_dtype, tol_spec)) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol + ) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "shape": shape, + "dtype": dtype, + } + for shape in all_shapes + for dtype in all_dtypes + ) + ) + def testWhereOneArgument(self, shape, dtype): + rng = jtu.rand_some_zero() + onp_fun = lambda x: onp.where(x) + lnp_fun = lambda x: lnp.where(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) + # Turns off XLA check because there are no XLA kernels for `Where`, which + # XLA can't support because its output shape is dynamic. + self._CompileAndCheck( + lnp.where, + args_maker, + check_dtypes=True, + check_eval_on_shapes=False, + check_incomplete_shape=True, + check_unknown_rank=False, + check_experimental_compile=False, + check_xla_forced_compile=False, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}".format( + "_".join( + jtu.format_shape_dtype_string(shape, dtype) + for shape, dtype in zip(shapes, dtypes) + ) + ), + "rng_factory": jtu.rand_default, + "shapes": shapes, + "dtypes": dtypes, + } + for shapes in filter( + _shapes_are_broadcast_compatible, CombosWithReplacement(all_shapes, 3) + ) + for dtypes in CombosWithReplacement(all_dtypes, 3) + ) + ) + def testWhereThreeArgument(self, rng_factory, shapes, dtypes): + rng = rng_factory() + + # Create a custom args_maker that forces correct dtypes + def custom_args_maker(): + args = self._GetArgsMaker(rng_factory(), shapes, dtypes)() + # Explicitly cast each argument to its specified dtype + return [onp.asarray(arg, dtype=dtype) for arg, dtype in zip(args, dtypes)] + + def onp_fun(cond, x, y): + return _promote_like_lnp(partial(onp.where, cond))(x, y) + + self._CheckAgainstNumpy( + onp_fun, lnp.where, custom_args_maker, check_dtypes=True + ) + self._CompileAndCheck( + lnp.where, custom_args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def testWhereScalarPromotion(self): + x = lnp.where(lnp.array([True, False]), 3, lnp.ones((2,), dtype=lnp.float32)) + self.assertEqual(x.dtype, onp.dtype(onp.float32)) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + "", shapes, (onp.bool_,) * n + dtypes + ), + "rng_factory": jtu.rand_default, + "shapes": shapes, + "dtypes": dtypes, + } + for n in range(0, 3) + for shapes in filter( + _shapes_are_broadcast_compatible, + CombosWithReplacement(all_shapes, 2 * n + 1), + ) + for dtypes in CombosWithReplacement(all_dtypes, n + 1) + ) + ) + def testSelect(self, rng_factory, shapes, dtypes): + rng = rng_factory() + n = len(dtypes) - 1 + + def args_maker(): + condlist = [rng(shape, onp.bool_) for shape in shapes[:n]] + choicelist = [ + rng(shape, dtype) for shape, dtype in zip(shapes[n:-1], dtypes[:n]) + ] + default = rng(shapes[-1], dtypes[-1]) + return condlist, choicelist, default + + # TODO(phawkins): float32/float64 type mismatches + def onp_fun(condlist, choicelist, default): + choicelist = [ + x if lnp.bfloat16 != lnp.result_type(x) else x.astype(onp.float32) + for x in choicelist + ] + dtype = lnp.result_type(default, *choicelist).as_numpy_dtype + return onp.select( + 
condlist, + [onp.asarray(x, dtype=dtype) for x in choicelist], + onp.asarray(default, dtype=dtype), + ) + + self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker, check_dtypes=False) + self._CompileAndCheck( + lnp.select, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + rtol={onp.float64: 1e-7, onp.complex128: 1e-7}, + ) + + @jtu.disable + def testIssue330(self): + x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash + self.assertEqual(x[0, 0], 1) + + @jtu.disable + def testScalarDtypePromotion(self): + orig_numpy_result = (1 + onp.eye(1, dtype=onp.float32)).dtype + jax_numpy_result = (1 + lnp.eye(1, dtype=lnp.float32)).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + @jtu.disable + def testSymmetrizeDtypePromotion(self): + x = onp.eye(3, dtype=onp.float32) + orig_numpy_result = ((x + x.T) / 2).dtype + + x = lnp.eye(3, dtype=lnp.float32) + jax_numpy_result = ((x + x.T) / 2).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + @jtu.disable + def testIssue347(self): + # https://github.com/google/jax/issues/347 + def test_fail(x): + x = lnp.sqrt(lnp.sum(x**2, axis=1)) + ones = lnp.ones_like(x) + x = lnp.where(x > 0.5, x, ones) + return lnp.sum(x) + + x = lnp.array([[1, 2], [3, 4], [0, 0]], dtype=lnp.float64) + result = api.grad(test_fail)(x) + assert not onp.any(onp.isnan(result)) + + def testIssue453(self): + # https://github.com/google/jax/issues/453 + a = onp.arange(6) + 1 + ans = lnp.reshape(a, (3, 2), order="F") + expected = onp.reshape(a, (3, 2), order="F") + self.assertAllClose(ans, expected, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_op={}_dtype={}".format(op, pytype.__name__), + "pytype": pytype, + "dtype": dtype, + "op": op, + } + for pytype, dtype in [ + (int, lnp.int_), + (float, lnp.float_), + (bool, lnp.bool_), + (complex, lnp.complex_), + ] + for op in ["atleast_1d", "atleast_2d", "atleast_3d"] + ) + ) + def testAtLeastNdLiterals(self, pytype, dtype, op): + # Fixes: https://github.com/google/jax/issues/634 + onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype) + lnp_fun = lambda arg: getattr(lnp, op)(arg) + args_maker = lambda: [pytype(2)] + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def testLongLong(self): + self.assertAllClose( + onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True + ) + + def testArange(self): + # test cases inspired by dask tests at + # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92 + self.assertAllClose( + lnp.arange(77), onp.arange(77, dtype=lnp.int_), check_dtypes=True + ) + self.assertAllClose( + lnp.arange(2, 13), onp.arange(2, 13, dtype=lnp.int_), check_dtypes=True + ) + self.assertAllClose( + lnp.arange(4, 21, 9), + onp.arange(4, 21, 9, dtype=lnp.int_), + check_dtypes=True, + ) + self.assertAllClose( + lnp.arange(53, 5, -3), + onp.arange(53, 5, -3, dtype=lnp.int_), + check_dtypes=True, + ) + # TODO(mattjj): make these tests work when enable_x64=True + self.assertAllClose( + lnp.arange(77, dtype=float), onp.arange(77, dtype=float), check_dtypes=True + ) + self.assertAllClose( + lnp.arange(2, 13, dtype=int), + onp.arange(2, 13, dtype=int), + check_dtypes=True, + ) + self.assertAllClose( + lnp.arange(0, 1, -0.5), + onp.arange(0, 1, -0.5, dtype=lnp.float_), + check_dtypes=True, + ) + + self.assertRaises(TypeError, lambda: lnp.arange()) + + # # The following have been 
disabled since they test JAX specific behavior + # # test that lnp.arange(N) doesn't instantiate an ndarray + # self.assertFalse(type(lnp.arange(77)) == type(onp.arange(77))) + # self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77))) + + # # test that lnp.arange(N, dtype=int32) doesn't instantiate an ndarray + # self.assertFalse(type(lnp.arange(77, dtype=lnp.int32)) == + # type(onp.arange(77, dtype=onp.int32))) + # self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) == + # type(lax.iota(onp.int32, 77))) + + def testIssue830(self): + a = lnp.arange(4, dtype=lnp.complex64) + self.assertEqual(a.dtype, lnp.complex64) + + def testIssue728(self): + assert lnp.allclose(lnp.eye(5000), onp.eye(5000)) + self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) + + def testIssue746(self): + lnp.arange(12).reshape(3, 4) # doesn't crash + + def testIssue764(self): + x = lnp.linspace(190, 200, 4) + f = npe.grad(lambda x: lnp.sum(lnp.tanh(x))) + # Expected values computed with autograd in float64 precision. + expected = onp.array( + [3.71669453e-165, 4.72999108e-168, 6.01954653e-171, 7.66067839e-174], + onp.float64, + ) + self.assertAllClose(f(x), expected, check_dtypes=False) + + @jtu.disable + def testIssue776(self): + """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" + + def f(u): + y = ( + onp.ones( + 10, + ) + .at[[2, 4, 5]] + .add(u) + ) + # The transpose rule for lax.tie_in returns a symbolic zero for its first + # argument. + return lax.tie_in(y, 7.0) + + self.assertAllClose( + onp.zeros( + 3, + ), + api.grad(f)( + onp.ones( + 3, + ) + ), + check_dtypes=True, + ) + + @jtu.disable + def testIssue777(self): + x = lnp.linspace(-200, 0, 4, dtype=onp.float32) + f = npe.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x)))) + self.assertAllClose( + f(x), onp.array([0.0, 0.0, 0.0, 0.25], dtype=onp.float32), check_dtypes=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]), + "dtype": dtype, + "op": op, + } + for dtype in float_dtypes + for op in ( + "sqrt", + "arccos", + "arcsin", + "arctan", + "sin", + "cos", + "tan", + "sinh", + "cosh", + "tanh", + "arccosh", + "arcsinh", + "arctanh", + "exp", + "log", + "expm1", + "log1p", + ) + ) + ) + def testMathSpecialFloatValues(self, op, dtype): + onp_op = getattr(onp, op) + lnp_op = getattr(lnp, op) + dtype = onp.dtype(lnp.canonicalize_dtype(dtype)).type + for x in ( + onp.nan, + -onp.inf, + -100.0, + -2.0, + -1.0, + 0.0, + 1.0, + 2.0, + 100.0, + onp.inf, + lnp.finfo(dtype).max, + onp.sqrt(lnp.finfo(dtype).max), + onp.sqrt(lnp.finfo(dtype).max) * 2.0, + ): + if ( + op in ("sin", "cos", "tan", "arctan") + and jtu.device_under_test() == "tpu" + ): + continue # TODO(b/132196789, b/134175194): fix and reenable. + # TODO(b/158006398): fix and reenable. 
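+ # The float16 skip below is presumably a precision issue: NumPy computes
+ # float16 ufuncs via float32 and rounds back, while the TF kernels may
+ # round elsewhere, and these steep functions (e.g. cosh near the float16
+ # max of ~65504) turn a one-ulp difference into a tolerance failure.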
+ if ( + op + in ( + "cosh", + "arccosh", + "arcsinh", + "arcsin", + "sinh", + "arccos", + "arctan", + "arctanh", + ) + and dtype == onp.float16 + ): + continue + x = dtype(x) + expected = onp_op(x) + actual = lnp_op(x) + tol = jtu.tolerance(dtype, {onp.float32: 1e-3, onp.float64: 1e-7}) + self.assertAllClose(expected, actual, check_dtypes=True, atol=tol, rtol=tol) + + def testIssue883(self): + # from https://github.com/google/jax/issues/883 + + @partial(npe.jit, static_argnums=(1,)) + def f(x, v): + return x + + x = lnp.ones((10, 10)) + v = lnp.array([1, 2, 3]) + first_call = f(x, v) + second_call = f(x, v) # doesn't crash + + def testReductionOfOutOfBoundsAxis(self): # Issue 888 + x = lnp.ones((3, 4)) + self.assertRaises(tf.errors.InvalidArgumentError, lambda: lnp.sum(x, axis=2)) + + @jtu.disable + def testIssue956(self): + self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1))) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}".format( + shape, dtype, out_dtype, axis, ddof, keepdims + ), + "shape": shape, + "dtype": dtype, + "out_dtype": out_dtype, + "axis": axis, + "ddof": ddof, + "keepdims": keepdims, + "rng_factory": rng_factory, + } + for shape in [(5,), (10, 5)] + for dtype in all_dtypes + for out_dtype in inexact_dtypes + for axis in [None, 0, -1] + for ddof in [0, 1, 2] + for keepdims in [False, True] + for rng_factory in [jtu.rand_default] + ) + ) + def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + + def onp_fun(x): + out = onp.var( + x.astype(lnp.promote_types(onp.float32, dtype)), + axis=axis, + ddof=ddof, + keepdims=keepdims, + ) + return out.astype(out_dtype) + + lnp_fun = partial( + lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims + ) + tol = jtu.tolerance( + out_dtype, + { + onp.float16: 1e-1, + onp.float32: 1e-3, + onp.float64: 1e-3, + onp.complex128: 1e-6, + }, + ) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=True, tol=tol + ) + self._CompileAndCheck( + lnp_fun, + args_maker, + check_dtypes=True, + rtol=tol, + atol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( + shape, dtype, rowvar, ddof, bias + ), + "shape": shape, + "dtype": dtype, + "rowvar": rowvar, + "ddof": ddof, + "bias": bias, + "rng_factory": rng_factory, + } + for shape in [(5,), (10, 5), (5, 10)] + for dtype in all_dtypes + for rowvar in [True, False] + for bias in [True, False] + for ddof in [None, 2, 3] + for rng_factory in [jtu.rand_default] + ) + ) + @jtu.skip_on_devices("gpu") # TODO(b/138003641): test fails on GPU. 
+ @jtu.disable + def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + onp_fun = partial(onp.cov, rowvar=rowvar, ddof=ddof, bias=bias) + lnp_fun = partial(lnp.cov, rowvar=rowvar, ddof=ddof, bias=bias) + tol = {onp.float32: 1e-5, onp.float64: 1e-13, onp.complex128: 1e-13} + tol = 7e-2 if jtu.device_under_test() == "tpu" else tol + tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) + self._CheckAgainstNumpy( + onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol + ) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, atol=tol, rtol=tol + ) + + def testIssue967(self): + self.assertRaises(TypeError, lambda: lnp.zeros(1.5)) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( + shape, dtype, rowvar, ddof, bias + ), + "shape": shape, + "dtype": dtype, + "rowvar": rowvar, + "ddof": ddof, + "bias": bias, + "rng_factory": rng_factory, + } + for shape in [(5,), (10, 5), (3, 10)] + for dtype in number_dtypes + for rowvar in [True, False] + for bias in [True, False] + for ddof in [None, 2, 3] + for rng_factory in [jtu.rand_default] + ) + ) + @jtu.disable + def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + mat = onp.asarray([rng(shape, dtype)]) + onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) + lnp_fun = partial(lnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) + if not onp.any(onp.isclose(onp.std(mat), 0.0)): + self._CheckAgainstNumpy( + onp_fun, + lnp_fun, + args_maker, + check_dtypes=False, + tol=1e-2 if jtu.device_under_test() == "tpu" else None, + ) + self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_shapes={}_dtype={}_indexing={}_sparse={}".format( + shapes, jtu.dtype_str(dtype), indexing, sparse + ), + "shapes": shapes, + "dtype": dtype, + "indexing": indexing, + "sparse": sparse, + "rng_factory": rng_factory, + } + for shapes in [(), (5,), (5, 3)] + for dtype in number_dtypes + for indexing in ["xy", "ij"] + for sparse in [False] # TODO(nareshmodi): Make sparse work + for rng_factory in [jtu.rand_default] + ) + ) + def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker( + rng, [(x,) for x in shapes], [dtype] * len(shapes) + ) + onp_fun = partial(onp.meshgrid, indexing=indexing, sparse=sparse) + lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse) + self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_retstep={}_dtype={}" + ).format(start_shape, stop_shape, num, endpoint, retstep, dtype), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "retstep": retstep, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for retstep in [True, False] + for dtype in number_dtypes + + [ + None, + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testLinspace( + self, start_shape, stop_shape, num, endpoint, 
retstep, dtype, rng_factory + ): + if not endpoint and onp.issubdtype(dtype, onp.integer): + # TODO(b/157597565): Support all dtypes when the tf op supports endpoint + # Currently, subtracting the step early leads to rounding errors for + # integers. + return + rng = rng_factory() + # relax default tolerances slightly + tol = jtu.tolerance(dtype if dtype else onp.float32) * 10 + args_maker = self._GetArgsMaker(rng, [start_shape, stop_shape], [dtype, dtype]) + start, stop = args_maker() + ndim = len(onp.shape(start + stop)) + for axis in range(-ndim, ndim): + lnp_op = lambda start, stop: lnp.linspace( + start, + stop, + num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + onp_op = lambda start, stop: onp.linspace( + start, + stop, + num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + self._CheckAgainstNumpy( + onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol + ) + # floating-point compute between jitted platforms and non-jit + rounding + # cause unavoidable variation in integer truncation for some inputs. + if dtype in ( + inexact_dtypes + + [ + None, + ] + ): + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=False, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_base={}_dtype={}" + ).format( + start_shape, + stop_shape, + num, + endpoint, + base, + dtype.__name__ if dtype else "None", + ), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "base": base, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for base in [10.0, 2, onp.e] + for dtype in inexact_dtypes + + [ + None, + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testLogspace( + self, start_shape, stop_shape, num, endpoint, base, dtype, rng_factory + ): + if ( + dtype in int_dtypes + and jtu.device_under_test() in ("gpu", "tpu") + and not FLAGS.enable_x64 + ): + raise unittest.SkipTest( + "GPUx32 truncated exponentiation" + " doesn't exactly match other platforms." + ) + rng = rng_factory() + # relax default tolerances slightly + tol = { + onp.float16: 2e-2, + onp.float32: 1e-2, + onp.float64: 1e-6, + onp.complex64: 1e-3, + onp.complex128: 1e-6, + } + args_maker = self._GetArgsMaker(rng, [start_shape, stop_shape], [dtype, dtype]) + start, stop = args_maker() + ndim = len(onp.shape(start + stop)) + for axis in range(-ndim, ndim): + lnp_op = lambda start, stop: lnp.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis + ) + onp_op = lambda start, stop: onp.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis + ) + self._CheckAgainstNumpy( + onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol + ) + if dtype in ( + inexact_dtypes + + [ + None, + ] + ): + # Why do compiled and op-by-op float16 np.power numbers differ + # slightly more than expected? 
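+ # A plausible account: float16 eps is 2**-10 (~1e-3), and in
+ # base ** linspace(...) a one-ulp difference in the exponent scales the
+ # result by roughly log(base) * eps, so ~1e-2 of absolute slack is
+ # needed for values of order one.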
+ atol = {onp.float16: 1e-2} + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=False, + atol=atol, + rtol=tol, + check_incomplete_shape=True, + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": ( + "_start_shape={}_stop_shape={}_num={}_endpoint={}" "_dtype={}" + ).format(start_shape, stop_shape, num, endpoint, dtype), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, + "endpoint": endpoint, + "dtype": dtype, + "rng_factory": rng_factory, + } + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + # NB: numpy's geomspace gives nonsense results on integer types + for dtype in inexact_dtypes + + [ + None, + ] + for rng_factory in [jtu.rand_default] + ) + ) + def testGeomspace(self, start_shape, stop_shape, num, endpoint, dtype, rng_factory): + rng = rng_factory() + # relax default tolerances slightly + tol = {onp.float16: 4e-3, onp.float32: 2e-3, onp.complex128: 1e-14} + + def args_maker(): + """Test the set of inputs onp.geomspace is well-defined on.""" + start, stop = self._GetArgsMaker( + rng, [start_shape, stop_shape], [dtype, dtype] + )() + # onp.geomspace can't handle differently ranked tensors + # w. negative numbers! + start, stop = lnp.broadcast_arrays(start, stop) + if dtype in complex_dtypes: + return start, stop + # to avoid NaNs, non-complex start and stop cannot + # differ in sign, elementwise + start = start * lnp.sign(start) * lnp.sign(stop) + return start, stop + + start, stop = args_maker() + ndim = len(onp.shape(start + stop)) + for axis in range(-ndim, ndim): + + def lnp_op(start, stop): + return lnp.geomspace( + start, stop, num, endpoint=endpoint, dtype=dtype, axis=axis + ) + + def onp_op(start, stop): + start = start.astype(onp.float32) if dtype == lnp.bfloat16 else start + stop = stop.astype(onp.float32) if dtype == lnp.bfloat16 else stop + return onp.geomspace( + start, + stop, + num, + endpoint=endpoint, + dtype=dtype if dtype != lnp.bfloat16 else onp.float32, + axis=axis, + ).astype(dtype) + + self._CheckAgainstNumpy( + onp_op, lnp_op, args_maker, check_dtypes=False, tol=tol + ) + if dtype in ( + inexact_dtypes + + [ + None, + ] + ): + self._CompileAndCheck( + lnp_op, + args_maker, + check_dtypes=False, + atol=tol, + rtol=tol, + check_incomplete_shape=True, + ) + + @jtu.disable + def testDisableNumpyRankPromotionBroadcasting(self): + try: + prev_flag = FLAGS.jax_numpy_rank_promotion + FLAGS.jax_numpy_rank_promotion = "allow" + lnp.ones(2) + lnp.ones((1, 2)) # works just fine + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + try: + prev_flag = FLAGS.jax_numpy_rank_promotion + FLAGS.jax_numpy_rank_promotion = "raise" + self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2))) + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + try: + prev_flag = FLAGS.jax_numpy_rank_promotion + FLAGS.jax_numpy_rank_promotion = "warn" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + lnp.ones(2) + lnp.ones((1, 2)) + assert len(w) > 0 + msg = str(w[-1].message) + expected_msg = ( + "Following NumPy automatic rank promotion for add on " + "shapes (2,) (1, 2)." 
+ ) + self.assertEqual(msg[: len(expected_msg)], expected_msg) + + prev_len = len(w) + lnp.ones(2) + 3 + self.assertEqual(len(w), prev_len) # don't want to warn for scalars + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + def testStackArrayArgument(self): + # tests https://github.com/google/jax/issues/1271 + @npe.jit + def foo(x): + return lnp.stack(x) + + foo(onp.zeros(2)) # doesn't crash + + @npe.jit + def foo(x): + return lnp.concatenate(x) + + foo(onp.zeros((2, 2))) # doesn't crash + + @jtu.disable + def testReluGradientConstants(self): + # This is a regression test that verifies that constants associated with the + # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the + # outermost jaxpr. This was producing some large materialized constants for + # every relu activation in a model. + def body(i, xy): + x, y = xy + y = y + jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.0)))(x) + return x, y + + f = lambda y: lax.fori_loop(0, 5, body, (y, y)) + wrapped = linear_util.wrap_init(f) + pv = partial_eval.PartialVal( + (jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit) + ) + _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv]) + self.assertFalse( + any( + onp.array_equal(x, onp.full((3, 4), 2.0, dtype=onp.float32)) + for x in consts + ) + ) + + @named_parameters( + { + "testcase_name": "_from={}_to={}".format(from_shape, to_shape), + "rng_factory": rng_factory, + "from_shape": from_shape, + "to_shape": to_shape, + } + for from_shape, to_shape in [ + [(1, 3), (4, 3)], + [(3,), (2, 1, 3)], + [(3,), (3, 3)], + [(1,), (3,)], + ] + for rng_factory in [jtu.rand_default] + ) + def testBroadcastTo(self, from_shape, to_shape, rng_factory): + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [from_shape], [onp.float32]) + onp_op = lambda x: onp.broadcast_to(x, to_shape) + lnp_op = lambda x: lnp.broadcast_to(x, to_shape) + self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) + self._CompileAndCheck( + lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True + ) + + def testBroadcastToIssue1522(self): + self.assertRaisesRegex( + Exception, + "Unable to broadcast", + lambda: lnp.broadcast_to(onp.ones((2, 3)), (1, 3)), + ) + + def testBroadcastToIntIssue1548(self): + self.assertAllClose( + lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)), check_dtypes=False + ) + + def testBroadcastToOnScalar(self): + self.assertIsInstance(lnp.broadcast_to(10.0, ()), lnp.ndarray) + self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray) + + @jtu.disable + def testPrecision(self): + ones_1d = onp.ones((2,)) + ones_2d = onp.ones((2, 2)) + ones_3d = onp.ones((2, 2, 2)) + HIGHEST = lax.Precision.HIGHEST + + jtu.assert_dot_precision(None, lnp.dot, ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.dot, precision=HIGHEST), ones_1d, ones_1d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.dot, precision=HIGHEST), ones_3d, ones_3d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.matmul, precision=HIGHEST), ones_2d, ones_2d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.vdot, precision=HIGHEST), ones_1d, ones_1d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.tensordot, axes=2, precision=HIGHEST), ones_2d, ones_2d + ) + jtu.assert_dot_precision( + HIGHEST, + partial(lnp.tensordot, axes=(0, 0), precision=HIGHEST), + ones_1d, + ones_1d, + ) + jtu.assert_dot_precision( + HIGHEST, + partial(lnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), + ones_1d, + ones_1d, + ) + jtu.assert_dot_precision( + HIGHEST, 
partial(lnp.einsum, "i,i", precision=HIGHEST), ones_1d, ones_1d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.einsum, "ij,ij", precision=HIGHEST), ones_2d, ones_2d + ) + jtu.assert_dot_precision( + HIGHEST, partial(lnp.inner, precision=HIGHEST), ones_1d, ones_1d + ) + + @named_parameters( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}_{}_{}".format( + shape, + jtu.dtype_str(key_dtype), + jtu.dtype_str(value_dtype), + dimension, + ).replace(" ", ""), + "shape": shape, + "key_dtype": key_dtype, + "value_dtype": value_dtype, + "dimension": dimension, + "rng_factory": rng_factory, + } + for shape in all_shapes + for key_dtype in minus(number_dtypes, complex_dtypes) + for value_dtype in all_dtypes + for dimension in range(-len(shape), len(shape)) + for rng_factory in [jtu.rand_default] + ) + ) + @new_test + def testSortKeyValue(self, shape, key_dtype, value_dtype, dimension, rng_factory): + def onp_ref(keys, values): + idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape])) + idxs[dimension] = onp.argsort(keys, axis=dimension) + return keys[tuple(idxs)], values[tuple(idxs)] + + rng = rng_factory() + args_maker = self._GetArgsMaker(rng, [shape, shape], [key_dtype, value_dtype]) + op = partial(npe.sort_key_val, dimension=dimension) + self._CheckAgainstNumpy(onp_ref, op, args_maker, check_dtypes=True) + # sort_key_val requires known rank. + # XLA only has TopKV2 (used by tf.argsort) kernels on those dtypes + # (b/169194137). + check_xla = key_dtype in (onp.uint32, onp.int32, onp.float32, lnp.bfloat16) + self._CompileAndCheck( + op, + args_maker, + check_dtypes=True, + check_incomplete_shape=True, + check_unknown_rank=False, + check_experimental_compile=check_xla, + check_xla_forced_compile=check_xla, + ) + + +# Most grad tests are at the lax level (see lax_test.py), but we add some here +# as needed for e.g. particular compound ops of interest. 
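+#
+# GradTestSpec below bundles an op with everything needed to check its
+# gradients: the number of arguments, the differentiation order to test, an
+# RNG factory for inputs, the dtypes to cover, a display name, and an
+# optional tolerance override.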
+ +GradTestSpec = collections.namedtuple( + "GradTestSpec", ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"] +) + + +def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): + return GradTestSpec(op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + + +GRAD_TEST_RECORDS = [ + grad_test_spec( + lnp.arcsinh, + nargs=1, + order=2, + rng_factory=jtu.rand_positive, + dtypes=[onp.float64, onp.complex64], + tol=1e-4, + ), + grad_test_spec( + lnp.arccosh, + nargs=1, + order=2, + rng_factory=jtu.rand_positive, + dtypes=[onp.float64, onp.complex64], + tol=1e-4, + ), + grad_test_spec( + lnp.arctanh, + nargs=1, + order=2, + rng_factory=partial(jtu.rand_uniform, -0.9, 0.9), + dtypes=[onp.float64, onp.complex64], + tol=1e-4, + ), +] + +GradSpecialValuesTestSpec = collections.namedtuple( + "GradSpecialValuesTestSpec", ["op", "values", "order"] +) + +GRAD_SPECIAL_VALUE_TEST_RECORDS = [ + GradSpecialValuesTestSpec(lnp.arcsinh, [0.0, 1000.0], 2), + GradSpecialValuesTestSpec(lnp.arccosh, [1000.0], 2), + GradSpecialValuesTestSpec(lnp.arctanh, [0.0], 2), + # TODO(wangpeng): Add `GradSpecialValuesTestSpec(lnp.sinc, [0.], 1)` +] + + +def num_float_bits(dtype): + return lnp.finfo(dtypes.canonicalize_dtype(dtype)).bits + + +class NumpyGradTests(jtu.TestCase): + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": jtu.format_test_name_suffix( + rec.name, shapes, itertools.repeat(dtype) + ), + "op": rec.op, + "rng_factory": rec.rng_factory, + "shapes": shapes, + "dtype": dtype, + "order": rec.order, + "tol": rec.tol, + } + for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs) + for dtype in rec.dtypes + ) + for rec in GRAD_TEST_RECORDS + ) + ) + @jtu.disable + def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): + rng = rng_factory() + tol = {onp.float32: 1e-1, onp.complex64: 1e-1} + args = tuple(rng(shape, dtype) for shape in shapes) + check_grads(op, args, order, ["fwd", "rev"], tol, tol) + + @named_parameters( + itertools.chain.from_iterable( + jtu.cases_from_list( + { + "testcase_name": "_{}_{}".format(rec.op.__name__, special_value), + "op": rec.op, + "special_value": special_value, + "order": rec.order, + } + for special_value in rec.values + ) + for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS + ) + ) + @jtu.disable + def testOpGradSpecialValue(self, op, special_value, order): + check_grads( + op, (special_value,), order, ["fwd", "rev"], atol={onp.float32: 3e-3} + ) + + @jtu.disable + def testTakeAlongAxisIssue1521(self): + # https://github.com/google/jax/issues/1521 + idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1)) + + def f(x): + y = x * lnp.arange(3.0).reshape((1, 3)) + return lnp.take_along_axis(y, idx, -1).sum() + + check_grads(f, (1.0,), order=1) + + +if __name__ == "__main__": + tf.enable_v2_behavior() + lnp.enable_numpy_behavior() + absltest.main() diff --git a/tests/fastmath/jax/utils.py b/tests/fastmath/jax/utils.py new file mode 100644 index 000000000..713c4b46f --- /dev/null +++ b/tests/fastmath/jax/utils.py @@ -0,0 +1,995 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +import re +import sys +import unittest +import warnings +import zlib + +from contextlib import contextmanager +from distutils.util import strtobool +from functools import partial +from typing import Dict, Sequence, Union + +import numpy as onp +import numpy.random as npr +import scipy +import tensorflow.compat.v2 as tf + +from absl.testing import parameterized + +import trax.tf.extensions as npe +import trax.tf.numpy as tf_np + +from tests.fastmath.jax.config import flags + +tree_map = tf.nest.map_structure +tree_multimap = tf.nest.map_structure + + +FLAGS = flags.FLAGS + + +# TODO(wangpeng): Remove this flag after broken tests are fixed +flags.DEFINE_bool("enable_x64", strtobool("False"), "Enable 64-bit types to be used.") + + +flags.DEFINE_enum( + "test_dut", + "", + enum_values=["", "cpu", "gpu", "tpu"], + help="Describes the device under test in case special consideration is required.", +) + + +flags.DEFINE_integer( + "num_generated_cases", 10, help="Number of generated cases to test" +) + + +EPS = 1e-4 + + +# Default dtypes corresponding to Python scalars. 
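+# (_dtype() below consults this table so that plain Python bools/ints/floats/
+# complex values map onto deterministic NumPy dtypes instead of whatever
+# onp.asarray would happen to infer.)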
+python_scalar_dtypes = { + bool: onp.dtype(onp.bool_), + int: onp.dtype(onp.int_), + float: onp.dtype(onp.float_), + complex: onp.dtype(onp.complex_), +} + + +def _dtype(x): + if isinstance(x, tf.Tensor): + return x.dtype.as_numpy_dtype + return ( + getattr(x, "dtype", None) + or onp.dtype(python_scalar_dtypes.get(type(x), None)) + or onp.asarray(x).dtype + ) + + +def is_sequence(x): + try: + iter(x) + except TypeError: + return False + else: + return True + + +_default_tolerance = { + onp.dtype(onp.bool_): 0, + onp.dtype(onp.int8): 0, + onp.dtype(onp.int16): 0, + onp.dtype(onp.int32): 0, + onp.dtype(onp.int64): 0, + onp.dtype(onp.uint8): 0, + onp.dtype(onp.uint16): 0, + onp.dtype(onp.uint32): 0, + onp.dtype(onp.uint64): 0, + # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-2, + onp.dtype(onp.float16): 1e-3, + onp.dtype(onp.float32): 1e-6, + onp.dtype(onp.float64): 1e-15, + onp.dtype(onp.complex64): 1e-6, + onp.dtype(onp.complex128): 1e-15, +} + + +def default_tolerance(): + return _default_tolerance + + +default_gradient_tolerance = { + # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-1, + onp.dtype(onp.float16): 1e-2, + onp.dtype(onp.float32): 2e-3, + onp.dtype(onp.float64): 1e-5, + onp.dtype(onp.complex64): 1e-3, + onp.dtype(onp.complex128): 1e-5, +} + + +def _assert_numpy_allclose(a, b, atol=None, rtol=None): + # TODO(b/154768983): + # a = a.astype(onp.float32) if a.dtype == dtypes.bfloat16 else a + # b = b.astype(onp.float32) if b.dtype == dtypes.bfloat16 else b + kw = {} + if atol: + kw["atol"] = atol + if rtol: + kw["rtol"] = rtol + onp.testing.assert_allclose(a, b, **kw) + + +def tolerance(dtype, tol=None): + tol = {} if tol is None else tol + if not isinstance(tol, dict): + return tol + tol = {onp.dtype(key): value for key, value in tol.items()} + dtype = onp.dtype(dtype) + return tol.get(dtype, default_tolerance()[dtype]) + + +def _normalize_tolerance(tol): + tol = tol or 0 + if isinstance(tol, dict): + return {onp.dtype(k): v for k, v in tol.items()} + else: + return {k: tol for k in _default_tolerance} + + +def join_tolerance(tol1, tol2): + tol1 = _normalize_tolerance(tol1) + tol2 = _normalize_tolerance(tol2) + out = tol1 + for k, v in tol2.items(): + out[k] = max(v, tol1.get(k, 0)) + return out + + +def _assert_numpy_close(a, b, atol=None, rtol=None): + assert a.shape == b.shape + atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol)) + rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol)) + _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size) + + +def check_eq(xs, ys): + tree_all(tree_multimap(_assert_numpy_allclose, xs, ys)) + + +def check_close(xs, ys, atol=None, rtol=None): + assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol) + tree_all(tree_multimap(assert_close, xs, ys)) + + +def inner_prod(xs, ys): + def contract(x, y): + return onp.real(onp.dot(onp.conj(x).reshape(-1), y.reshape(-1))) + + return tree_reduce(onp.add, tree_multimap(contract, xs, ys)) + + +add = partial(tree_multimap, lambda x, y: onp.add(x, y, dtype=_dtype(x))) +sub = partial(tree_multimap, lambda x, y: onp.subtract(x, y, dtype=_dtype(x))) +conj = partial(tree_map, lambda x: onp.conj(x, dtype=_dtype(x))) + + +def scalar_mul(xs, a): + return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs) + + +def rand_like(rng, x): + shape = onp.shape(x) + dtype = _dtype(x) + randn = lambda: onp.asarray(rng.randn(*shape), dtype=dtype) + if onp.issubdtype(dtype, onp.complexfloating): + return randn() + dtype.type(1.0j) * randn() + else: + return randn() + 
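+# numerical_jvp approximates the Jacobian-vector product of `f` with a
+# central difference,
+#
+#   jvp(f, x, v) ~= (f(x + eps * v) - f(x - eps * v)) / (2 * eps),
+#
+# whose truncation error is O(eps**2), versus O(eps) for a one-sided
+# difference.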
+
+def numerical_jvp(f, primals, tangents, eps=EPS):
+    delta = scalar_mul(tangents, eps)
+    f_pos = f(*add(primals, delta))
+    f_neg = f(*sub(primals, delta))
+    return scalar_mul(sub(f_pos, f_neg), 0.5 / eps)
+
+
+def _merge_tolerance(tol, default):
+    if tol is None:
+        return default
+    if not isinstance(tol, dict):
+        return tol
+    out = default.copy()
+    for k, v in tol.items():
+        out[onp.dtype(k)] = v
+    return out
+
+
+def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS):
+    atol = _merge_tolerance(atol, default_gradient_tolerance)
+    rtol = _merge_tolerance(rtol, default_gradient_tolerance)
+    rng = onp.random.RandomState(0)
+    tangent = tree_map(partial(rand_like, rng), args)
+    v_out, t_out = f_jvp(args, tangent)
+    v_out_expected = f(*args)
+    t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
+    # In principle we should expect exact equality of v_out and v_out_expected,
+    # but due to nondeterminism especially on GPU (e.g., due to convolution
+    # autotuning) we only require "close".
+    check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
+    check_close(t_out, t_out_expected, atol=atol, rtol=rtol)
+
+
+def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS):
+    atol = _merge_tolerance(atol, default_gradient_tolerance)
+    rtol = _merge_tolerance(rtol, default_gradient_tolerance)
+    _rand_like = partial(rand_like, onp.random.RandomState(0))
+    v_out, vjpfun = f_vjp(*args)
+    v_out_expected = f(*args)
+    check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
+    tangent = tree_map(_rand_like, args)
+    tangent_out = numerical_jvp(f, args, tangent, eps=eps)
+    cotangent = tree_map(_rand_like, v_out)
+    cotangent_out = conj(vjpfun(conj(cotangent)))
+    ip = inner_prod(tangent, cotangent_out)
+    ip_expected = inner_prod(tangent_out, cotangent)
+    check_close(ip, ip_expected, atol=atol, rtol=rtol)
+
+
+def device_under_test():
+    return FLAGS.test_dut
+
+
+def if_device_under_test(device_type: Union[str, Sequence[str]], if_true, if_false):
+    """Chooses `if_true` or `if_false` based on device_under_test."""
+    if device_under_test() in (
+        [device_type] if isinstance(device_type, str) else device_type
+    ):
+        return if_true
+    else:
+        return if_false
+
+
+def supported_dtypes():
+    if device_under_test() == "tpu":
+        return {
+            onp.bool_,
+            onp.int32,
+            onp.uint32,
+            dtypes.bfloat16,
+            onp.float32,
+            onp.complex64,
+        }
+    else:
+        return {
+            onp.bool_,
+            onp.int8,
+            onp.int16,
+            onp.int32,
+            onp.int64,
+            onp.uint8,
+            onp.uint16,
+            onp.uint32,
+            onp.uint64,
+            dtypes.bfloat16,
+            onp.float16,
+            onp.float32,
+            onp.float64,
+            onp.complex64,
+            onp.complex128,
+        }
+
+
+def skip_if_unsupported_type(dtype):
+    if dtype not in supported_dtypes():
+        raise unittest.SkipTest(f"Type {dtype} not supported on {device_under_test()}")
+
+
+def skip_on_devices(*disabled_devices):
+    """A decorator for test methods to skip the test on certain devices."""
+
+    def skip(test_method):
+        @functools.wraps(test_method)
+        def test_method_wrapper(self, *args, **kwargs):
+            device = device_under_test()
+            if device in disabled_devices:
+                test_name = getattr(test_method, "__name__", "[unknown test]")
+                raise unittest.SkipTest(
+                    f"{test_name} not supported on {device.upper()}."
+ ) + return test_method(self, *args, **kwargs) + + return test_method_wrapper + + return skip + + +def skip_on_flag(flag_name, skip_value): + """A decorator for test methods to skip the test when flags are set.""" + + def skip(test_method): # pylint: disable=missing-docstring + @functools.wraps(test_method) + def test_method_wrapper(self, *args, **kwargs): + flag_value = getattr(FLAGS, flag_name) + if flag_value == skip_value: + test_name = getattr(test_method, "__name__", "[unknown test]") + raise unittest.SkipTest( + f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}" + ) + return test_method(self, *args, **kwargs) + + return test_method_wrapper + + return skip + + +# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432 +# Delete this code after the minimum jaxlib version is 0.1.46 or greater. +skip_on_mac_linalg_bug = partial( + unittest.skipIf, + ( + sys.platform == "darwin" + and scipy.version.version > "1.1.0" + and lib.version < (0, 1, 46) + ), + "Test fails on Mac with new scipy (issue #432)", +) + + +def format_test_name_suffix(opname, shapes, dtypes): + arg_descriptions = ( + format_shape_dtype_string(shape, dtype) for shape, dtype in zip(shapes, dtypes) + ) + return "{}_{}".format(opname.capitalize(), "_".join(arg_descriptions)) + + +# We use special symbols, represented as singleton objects, to distinguish +# between NumPy scalars, Python scalars, and 0-D arrays. +class ScalarShape: + def __len__(self): + return 0 + + def __getitem__(self, i): + raise IndexError(f"index {i} out of range.") + + +class _NumpyScalar(ScalarShape): + pass + + +class _PythonScalar(ScalarShape): + pass + + +NUMPY_SCALAR_SHAPE = _NumpyScalar() +PYTHON_SCALAR_SHAPE = _PythonScalar() + + +def _dims_of_shape(shape): + """Converts `shape` to a tuple of dimensions.""" + if type(shape) in (list, tuple): + return shape + elif isinstance(shape, ScalarShape): + return () + else: + raise TypeError(type(shape)) + + +def _cast_to_shape(value, shape, dtype): + """Casts `value` to the correct Python type for `shape` and `dtype`.""" + if shape is NUMPY_SCALAR_SHAPE: + # explicitly cast to NumPy scalar in case `value` is a Python scalar. + return onp.dtype(dtype).type(value) + elif shape is PYTHON_SCALAR_SHAPE: + # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998 + return onp.asarray(value).item() + elif type(shape) in (list, tuple): + assert onp.shape(value) == tuple(shape) + return value + else: + raise TypeError(type(shape)) + + +def dtype_str(dtype): + return onp.dtype(dtype).name + + +def format_shape_dtype_string(shape, dtype): + if shape is NUMPY_SCALAR_SHAPE: + return dtype_str(dtype) + elif shape is PYTHON_SCALAR_SHAPE: + return "py" + dtype_str(dtype) + elif type(shape) in (list, tuple): + shapestr = ",".join(str(dim) for dim in shape) + return "{}[{}]".format(dtype_str(dtype), shapestr) + elif type(shape) is int: + return "{}[{},]".format(dtype_str(dtype), shape) + elif isinstance(shape, onp.ndarray): + return "{}[{}]".format(dtype_str(dtype), shape) + else: + raise TypeError(type(shape)) + + +def _rand_dtype(rand, shape, dtype, scale=1.0, post=lambda x: x): + """Produce random values given shape, dtype, scale, and post-processor. + + Args: + rand: a function for producing random values of a given shape, e.g. a + bound version of either onp.RandomState.randn or onp.RandomState.rand. + shape: a shape value as a tuple of positive integers. + dtype: a numpy dtype. + scale: optional, a multiplicative scale for the random values (default 1). 
+        post: optional, a callable for post-processing the random values
+          (default identity).
+
+    Returns:
+        An ndarray of the given shape and dtype using random values based on a
+        call to rand but scaled, converted to the appropriate dtype, and
+        post-processed.
+    """
+    r = lambda: onp.asarray(scale * rand(*_dims_of_shape(shape)), dtype)
+    if onp.issubdtype(dtype, onp.complexfloating):
+        vals = r() + 1.0j * r()
+    else:
+        vals = r()
+    return _cast_to_shape(onp.asarray(post(vals), dtype), shape, dtype)
+
+
+def rand_default(scale=3):
+    randn = npr.RandomState(0).randn
+    return partial(_rand_dtype, randn, scale=scale)
+
+
+def rand_nonzero():
+    post = lambda x: onp.where(x == 0, onp.array(1, dtype=x.dtype), x)
+    randn = npr.RandomState(0).randn
+    return partial(_rand_dtype, randn, scale=3, post=post)
+
+
+def rand_positive():
+    post = lambda x: x + 1
+    rand = npr.RandomState(0).rand
+    return partial(_rand_dtype, rand, scale=2, post=post)
+
+
+def rand_small():
+    randn = npr.RandomState(0).randn
+    return partial(_rand_dtype, randn, scale=1e-3)
+
+
+def rand_not_small(offset=10.0):
+    post = lambda x: x + onp.where(x > 0, offset, -offset)
+    randn = npr.RandomState(0).randn
+    return partial(_rand_dtype, randn, scale=3.0, post=post)
+
+
+def rand_small_positive():
+    rand = npr.RandomState(0).rand
+    return partial(_rand_dtype, rand, scale=2e-5)
+
+
+def rand_uniform(low=0.0, high=1.0):
+    assert low < high
+    rand = npr.RandomState(0).rand
+    post = lambda x: x * (high - low) + low
+    return partial(_rand_dtype, rand, post=post)
+
+
+def rand_some_equal():
+    randn = npr.RandomState(0).randn
+    rng = npr.RandomState(0)
+
+    def post(x):
+        x_ravel = x.ravel()
+        if len(x_ravel) == 0:
+            return x
+        flips = rng.rand(*onp.shape(x)) < 0.5
+        return onp.where(flips, x_ravel[0], x)
+
+    return partial(_rand_dtype, randn, scale=100.0, post=post)
+
+
+def rand_some_inf():
+    """Return a random sampler that produces infinities in floating types."""
+    rng = npr.RandomState(1)
+    base_rand = rand_default()
+
+    # TODO: Complex numbers are not correctly tested; the `if` blocks below
+    # should be switched in order, and the relevant tests fixed.
+
+    def rand(shape, dtype):
+        """The random sampler function."""
+        if not onp.issubdtype(dtype, onp.floating):
+            # only float types have inf
+            return base_rand(shape, dtype)
+
+        if onp.issubdtype(dtype, onp.complexfloating):
+            base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype
+            out = rand(shape, base_dtype) + onp.array(1j, dtype) * rand(
+                shape, base_dtype
+            )
+            return _cast_to_shape(out, shape, dtype)
+
+        dims = _dims_of_shape(shape)
+        posinf_flips = rng.rand(*dims) < 0.1
+        neginf_flips = rng.rand(*dims) < 0.1
+
+        vals = base_rand(shape, dtype)
+        vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals)
+        vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals)
+
+        return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype)
+
+    return rand
+
+
+def rand_some_nan():
+    """Return a random sampler that produces nans in floating types."""
+    rng = npr.RandomState(1)
+    base_rand = rand_default()
+
+    def rand(shape, dtype):
+        """The random sampler function."""
+        if onp.issubdtype(dtype, onp.complexfloating):
+            base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype
+            out = rand(shape, base_dtype) + onp.array(1j, dtype) * rand(
+                shape, base_dtype
+            )
+            return _cast_to_shape(out, shape, dtype)
+
+        if not onp.issubdtype(dtype, onp.floating):
+            # only float types have nan
+            return base_rand(shape, dtype)
+
+        dims = _dims_of_shape(shape)
+        nan_flips = rng.rand(*dims) < 0.1
+
+        vals = base_rand(shape, dtype)
+        vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals)
+
+        return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype)
+
+    return rand
+
+
+def rand_some_inf_and_nan():
+    """Return a random sampler that produces infinities and nans in floating types."""
+    rng = npr.RandomState(1)
+    base_rand = rand_default()
+
+    # TODO: Complex numbers are not correctly tested; the `if` blocks below
+    # should be switched in order, and the relevant tests fixed.
+
+    def rand(shape, dtype):
+        """The random sampler function."""
+        if not onp.issubdtype(dtype, onp.floating):
+            # only float types have inf and nan
+            return base_rand(shape, dtype)
+
+        if onp.issubdtype(dtype, onp.complexfloating):
+            base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype
+            out = rand(shape, base_dtype) + onp.array(1j, dtype) * rand(
+                shape, base_dtype
+            )
+            return _cast_to_shape(out, shape, dtype)
+
+        dims = _dims_of_shape(shape)
+        posinf_flips = rng.rand(*dims) < 0.1
+        neginf_flips = rng.rand(*dims) < 0.1
+        nan_flips = rng.rand(*dims) < 0.1
+
+        vals = base_rand(shape, dtype)
+        vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals)
+        vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals)
+        vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals)
+
+        return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype)
+
+    return rand
+
+
+# TODO(mattjj): doesn't handle complex types
+def rand_some_zero():
+    """Return a random sampler that produces some zeros."""
+    rng = npr.RandomState(1)
+    base_rand = rand_default()
+
+    def rand(shape, dtype):
+        """The random sampler function."""
+        dims = _dims_of_shape(shape)
+        zeros = rng.rand(*dims) < 0.5
+
+        vals = base_rand(shape, dtype)
+        vals = onp.where(zeros, onp.array(0, dtype=dtype), vals)
+
+        return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype)
+
+    return rand
+
+
+def rand_int(low, high=None):
+    randint = npr.RandomState(0).randint
+
+    def fn(shape, dtype):
+        return randint(low, high=high, size=shape, dtype=dtype)
+
+    return fn
+
+
+def rand_unique_int():
+    randchoice = npr.RandomState(0).choice
+
+    def fn(shape, dtype):
+        return randchoice(
+            onp.arange(onp.prod(shape), dtype=dtype), size=shape, replace=False
+        )
+
+    return fn
+
+
+def rand_bool():
+    rng = npr.RandomState(0)
+
+    def generator(shape, dtype):
+        return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype)
+
+    return generator
+
+
+def check_raises(thunk, err_type, msg):
+    try:
+        thunk()
+        assert False
+    except err_type as e:
+        assert str(e).startswith(msg), "\n{}\n\n{}\n".format(e, msg)
+
+
+def check_raises_regexp(thunk, err_type, pattern):
+    try:
+        thunk()
+        assert False
+    except err_type as e:
+        assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern)
+
+
+def _iter_eqns(jaxpr):
+    # TODO(necula): why doesn't this search in params?
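+    # Yields this jaxpr's equations, then recurses into any sub-jaxprs so
+    # callers (e.g. assert_dot_precision below) see every equation in the
+    # whole program.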
+ for eqn in jaxpr.eqns: + yield eqn + for subjaxpr in core.subjaxprs(jaxpr): + yield from _iter_eqns(subjaxpr) + + +def assert_dot_precision(expected_precision, fun, *args): + jaxpr = api.make_jaxpr(fun)(*args) + precisions = [ + eqn.params["precision"] + for eqn in _iter_eqns(jaxpr.jaxpr) + if eqn.primitive == lax.dot_general_p + ] + for precision in precisions: + msg = "Unexpected precision: {} != {}".format(expected_precision, precision) + assert precision == expected_precision, msg + + +_CACHED_INDICES: Dict[int, Sequence[int]] = {} + + +def cases_from_list(xs): + xs = list(xs) + n = len(xs) + k = min(n, FLAGS.num_generated_cases) + # Random sampling for every parameterized test is expensive. Do it once and + # cache the result. + indices = _CACHED_INDICES.get(n) + if indices is None: + rng = npr.RandomState(42) + _CACHED_INDICES[n] = indices = rng.permutation(n) + return [xs[i] for i in indices[:k]] + + +def cases_from_gens(*gens): + sizes = [1, 3, 10] + cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1 + for size in sizes: + for i in range(cases_per_size): + yield ("_{}_{}".format(size, i),) + tuple(gen(size) for gen in gens) + + +def to_np(a): + return tf.nest.map_structure(tf_np.asarray, a) + + +def to_tf_fn(f): + return lambda *args: f(*to_np(args)) + + +class TestCase(parameterized.TestCase): + """Base class for tests including numerical checks and boilerplate.""" + + # copied from jax.test_util + def setUp(self): + super().setUp() + self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) + + # copied from jax.test_util + def rng(self): + return self._rng + + # TODO(mattjj): this obscures the error messages from failures, figure out how + # to re-enable it + # def tearDown(self) -> None: + # assert core.reset_trace_state() + + def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None): + """Assert that x and y are close (up to numerical tolerances).""" + self.assertEqual(x.shape, y.shape) + atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) + rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + + _assert_numpy_allclose(x, y, atol=atol, rtol=rtol) + + if check_dtypes: + self.assertDtypesMatch(x, y) + + def assertDtypesMatch(self, x, y): + if FLAGS.enable_x64: + self.assertEqual(_dtype(x), _dtype(y)) + + def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None): + """Assert that x and y, either arrays or nested tuples/lists, are close.""" + if isinstance(x, dict): + self.assertIsInstance(y, dict) + self.assertEqual(set(x.keys()), set(y.keys())) + for k in x: + self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol) + elif is_sequence(x) and not hasattr(x, "__array__"): + self.assertTrue(is_sequence(y) and not hasattr(y, "__array__")) + self.assertEqual(len(x), len(y)) + for x_elt, y_elt in zip(x, y): + self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol) + elif hasattr(x, "__array__") or onp.isscalar(x): + self.assertTrue(hasattr(y, "__array__") or onp.isscalar(y)) + if check_dtypes: + self.assertDtypesMatch(x, y) + x = onp.asarray(x) + y = onp.asarray(y) + self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol) + elif x == y: + return + else: + raise TypeError((type(x), type(y))) + + def assertMultiLineStrippedEqual(self, expected, what): + """Asserts two strings are equal, after stripping each line.""" + ignore_space_re = re.compile(r"\s*\n\s*") + expected_clean = re.sub(ignore_space_re, "\n", expected.strip()) + what_clean = 
re.sub(ignore_space_re, "\n", what.strip()) + self.assertMultiLineEqual( + expected_clean, + what_clean, + msg="Found\n{}\nExpecting\n{}".format(what, expected), + ) + + def _CheckAgainstNumpy( + self, numpy_reference_op, lax_op, args_maker, check_dtypes=True, tol=None + ): + args = args_maker() + lax_ans = lax_op(*args) + numpy_ans = numpy_reference_op(*args) + self.assertAllClose( + numpy_ans, lax_ans, check_dtypes=check_dtypes, atol=tol, rtol=tol + ) + + def _CompileAndCheck( + self, + fun, + args_maker, + check_dtypes=True, + rtol=None, + atol=None, + check_eval_on_shapes=True, + check_incomplete_shape=True, + check_unknown_rank=True, + static_argnums=(), + check_experimental_compile=True, + check_xla_forced_compile=True, + ): + """Compiles the function and checks the results. + + Args: + fun: the function to be checked. + args_maker: a callable that returns a tuple which will be used as the + positional arguments. + check_dtypes: whether to check that the result dtypes from non-compiled + and compiled runs agree. + rtol: relative tolerance for allclose assertions. + atol: absolute tolerance for allclose assertions. + check_eval_on_shapes: whether to run `eval_on_shapes` on the function and + check that the result shapes and dtypes are correct. + check_incomplete_shape: whether to check that the function can handle + incomplete shapes (including those with and without a known rank). + check_unknown_rank: (only has effect when check_incomplete_shape is True) + whether to check that the function can handle unknown ranks. + static_argnums: indices of arguments to be treated as static arguments for + `jit` and `eval_on_shapes`. + check_experimental_compile: whether to check compilation with + experimental_compile=True (in addition to compilation without the flag). + check_xla_forced_compile: whether to check compilation with + forced_compile=True (in addition to compilation without the flag). This + flag is different from experimental_compile because it enforces + whole-function compilation while the latter doesn't. TPU requires + whole-function compilation. + """ + args = args_maker() + + for x in args: + if not hasattr(x, "dtype"): + # If there is a input that doesn't have dtype info, jit and + # eval_on_shapes may pick a different dtype for it than numpy, so we + # skip the dtype check. + check_dtypes = False + + python_ans = fun(*args) + + python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans) + onp_shapes = tf.nest.map_structure( + lambda x: onp.shape(onp.asarray(x)), python_ans + ) + self.assertEqual(python_shapes, onp_shapes) + + def check_compile(**kwargs): + # `wrapped_fun` and `python_should_be_executing` are used to check that + # when the jitted function is called the second time, the original Python + # function won't be executed. + def wrapped_fun(*args): + self.assertTrue(python_should_be_executing) + return fun(*args) + + cfun = npe.jit(wrapped_fun, static_argnums=static_argnums, **kwargs) + python_should_be_executing = True + monitored_ans = cfun(*args) + + python_should_be_executing = False + compiled_ans = cfun(*args) + + self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol) + self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + + # Run `cfun` with a different set of arguments to check that changing + # arguments won't cause recompilation. 
+ + new_args = args_maker() + + skip_retracing_test = False + for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)): + if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new): + # If the old and new arguments result in different dtypes (because + # they fall into different value ranges), tf-numpy will retrace, so we + # skip the no-retrace test. + skip_retracing_test = True + + if not skip_retracing_test: + python_should_be_executing = True + new_python_ans = fun(*new_args) + python_should_be_executing = False + compiled_ans = cfun(*new_args) + self.assertAllClose( + new_python_ans, compiled_ans, check_dtypes, atol, rtol + ) + + check_compile() + if check_experimental_compile: + check_compile(experimental_compile=True) + if check_xla_forced_compile: + check_compile(xla_forced_compile=True) + + if check_eval_on_shapes: + # Check that npe.eval_on_shapes can get complete output shapes given + # complete input shapes. + cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums) + compiled_ans = cfun(*args) + flat_python_ans = tf.nest.flatten(python_ans) + flat_compiled_ans = tf.nest.flatten(compiled_ans) + self.assertEqual(len(flat_python_ans), len(flat_compiled_ans)) + for a, b in zip(flat_python_ans, flat_compiled_ans): + if hasattr(a, "shape"): + self.assertEqual(a.shape, b.shape) + if check_dtypes and hasattr(a, "dtype"): + self.assertEqual(tf.as_dtype(a.dtype), b.dtype) + + # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we + # skip incomplete-shape checks, since shape specs need dtype. It's OK to + # skip since the same incomplete-shape checks will run for []-shaped arrays. + if check_incomplete_shape and all(hasattr(x, "dtype") for x in args): + # Check partial shapes with known ranks. + # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not + # `shape`. + if all(hasattr(x, "shape") for x in args): + specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args] + cfun = npe.jit( + fun, static_argnums=static_argnums, input_signature=specs + ) + compiled_ans = cfun(*args) + self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + + if check_unknown_rank: + # Check unknown ranks. + specs = [tf.TensorSpec(None, x.dtype) for x in args] + cfun = npe.jit( + fun, static_argnums=static_argnums, input_signature=specs + ) + compiled_ans = cfun(*args) + self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) + + def check_grads(self, f, args, atol=None, rtol=None, delta=None): + """Check gradients against finite differences. + + Args: + f: function to check at ``f(*args)``. + args: a list or tuple of argument values. + atol: absolute tolerance for gradient equality. + rtol: relative tolerance for gradient equality. + delta: step size used for finite differences. + """ + if delta is None: + # Optimal stepsize for central difference is O(epsilon^{1/3}). 
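+            # (Truncation error of a central difference is O(delta**2) while
+            # floating-point cancellation error is O(eps / delta); balancing
+            # the two gives delta ~ eps**(1/3).)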
+ dtype = tf_np.result_type(*args) + epsilon = onp.finfo(dtype).eps + delta = epsilon ** (1.0 / 3.0) + theoretical, numerical = tf.test.compute_gradient( + to_tf_fn(f), args, delta=delta + ) + self.assertAllClose( + theoretical, numerical, check_dtypes=False, atol=atol, rtol=rtol + ) + + +@contextmanager +def ignore_warning(**kw): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", **kw) + yield + + +def disable(_): + def wrapper(self, *args, **kwargs): + self.skipTest("Test is disabled") + + return wrapper diff --git a/tests/fastmath/jax/vmap_test.py b/tests/fastmath/jax/vmap_test.py new file mode 100644 index 000000000..f870da19f --- /dev/null +++ b/tests/fastmath/jax/vmap_test.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +import numpy as np +import tensorflow.compat.v2 as tf + +from absl.testing import parameterized +from tensorflow.python.ops.numpy_ops import ( + np_math_ops, +) + +import trax.tf.numpy as tf_np + +from trax.tf import extensions + + +class VmapTest(tf.test.TestCase, parameterized.TestCase): + def test_vmap_in_axes_list(self): + # https://github.com/google/jax/issues/2367 + dictionary = {"a": 5.0, "b": tf_np.ones(2)} + x = tf_np.zeros(3) + y = tf_np.arange(3.0) + + def f(dct, x, y): + return dct["a"] + dct["b"] + x + y + + out1 = extensions.vmap(f, (None, 0, 0))(dictionary, x, y) + out2 = extensions.vmap(f, [None, 0, 0])(dictionary, x, y) + self.assertAllClose(out1, out2) + + def test_vmap_in_axes_tree_prefix_error(self): + # https://github.com/google/jax/issues/795 + self.assertRaisesRegex( + ValueError, + "vmap in_axes specification must be a tree prefix of the corresponding " + r"value, got specification \(0, 0\) for value tree ", + lambda: extensions.vmap(lambda x: x, in_axes=(0, 0))(tf_np.ones(3)), + ) + + def test_vmap_in_axes_leaf_types(self): + with self.assertRaisesRegex( + TypeError, r"vmap in_axes must be an int, None, or .*" + ): + extensions.vmap(lambda x: x, in_axes=(tf_np.array([1.0, 2.0]),))( + tf_np.array([1.0, 2.0]) + ) + + def test_vmap_out_axes_leaf_types(self): + with self.assertRaisesRegex( + TypeError, r"vmap out_axes must be an int, None, or .*" + ): + extensions.vmap(lambda x: x, out_axes=(tf_np.array([1.0, 2.0]),))( + tf_np.array([1.0, 2.0]) + ) + + def test_vmap_unbatched_object_passthrough_issue_183(self): + # https://github.com/google/jax/issues/183 + fun = lambda f, x: f(x) + vfun = extensions.vmap(fun, (None, 0)) + ans = vfun(lambda x: x + 1, tf_np.arange(3)) + self.assertAllClose(ans, np.arange(1, 4)) + + def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): + # https://github.com/google/jax/issues/705 + with self.assertRaisesRegex( + ValueError, "vmap must have at least one non-None value in in_axes" + ): + # If the output is mapped, there must be a non-None in_axes + extensions.vmap(lambda x: x, in_axes=None)(tf_np.array([1.0, 2.0])) + + # Error is: TypeError: only integer scalar arrays can be converted to a + # 
scalar index + with self.assertRaisesRegex( + ValueError, + "vmap out_axes specification must be a tree prefix of the " + "corresponding value.*", + ): + extensions.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))( + tf_np.array([1.0, 2.0]) + ) + + def test_vmap_structured_in_axes(self): + a, b, c, d = 2, 3, 4, 5 + k = 6 # batch size + x = np.ones((k, a, b)) # batch axis in different locations + y = np.ones((b, k, c)) + z = np.ones((c, d, k)) + + def foo(tree_arg): + x, (y, z) = tree_arg + return tf_np.dot(x, tf_np.dot(y, z)) + + tree = (x, (y, z)) + vfoo = extensions.vmap(foo, in_axes=((0, (1, 2)),)) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + Point = collections.namedtuple("Point", ["x", "y"]) + tree = (x, Point(y, z)) + vfoo = extensions.vmap(foo, in_axes=((0, Point(1, 2)),)) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + def foo2(tree_arg): + x, dct = tree_arg + y, z = dct["a"], dct["b"] + return tf_np.dot(x, tf_np.dot(y, z)) + + tree = (x, {"a": y, "b": z}) + vfoo = extensions.vmap(foo2, in_axes=((0, {"a": 1, "b": 2}),)) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + tree = (x, collections.OrderedDict([("a", y), ("b", z)])) + vfoo = extensions.vmap( + foo2, in_axes=((0, collections.OrderedDict([("a", 1), ("b", 2)])),) + ) + self.assertEqual(vfoo(tree).shape, (6, 2, 5)) + + def test_vmap_out_axes(self): + f = extensions.vmap(lambda x: x, out_axes=0) + inp = tf_np.arange(6).reshape([2, 3]) + self.assertAllClose(inp, f(inp)) + self.assertAllClose([inp, inp], f((inp, inp))) + + f = extensions.vmap(lambda x: x, out_axes=-1) + self.assertAllClose(inp.T, f(inp)) + + f = extensions.vmap(lambda x: x, out_axes=None) + self.assertAllClose(inp[0], f(inp)) + + f = extensions.vmap(lambda x: x, out_axes=([0], (-1, None), {"a": 1})) + a, b, c = f(([inp], (inp, inp), {"a": inp})) + self.assertAllClose([inp], a) + self.assertAllClose((inp.T, inp[0]), b) + self.assertAllClose(inp.T, c["a"]) + + def test_negative_axes(self): + x = np.arange(3 * 4 * 5).reshape(3, 4, 5) + self.assertAllClose( + extensions.vmap(tf_np.sum, in_axes=-3)(x), tf_np.sum(x, axis=(1, 2)) + ) + self.assertAllClose( + extensions.vmap(tf_np.sum, in_axes=-2)(x), tf_np.sum(x, axis=(0, 2)) + ) + self.assertAllClose( + extensions.vmap(tf_np.sum, in_axes=-1)(x), tf_np.sum(x, axis=(0, 1)) + ) + + identity = lambda y: y + self.assertAllClose(x, extensions.vmap(identity, in_axes=0, out_axes=-3)(x)) + self.assertAllClose( + x.transpose(1, 0, 2), extensions.vmap(identity, in_axes=0, out_axes=-2)(x) + ) + self.assertAllClose( + x.transpose(1, 2, 0), extensions.vmap(identity, in_axes=0, out_axes=-1)(x) + ) + + self.assertAllClose( + np.full((5,), 7), + extensions.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -1))( + np.arange(5), 7 + )[1], + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + np_math_ops.enable_numpy_methods_on_tensor() + tf.test.main() diff --git a/tests/fastmath/ops_test.py b/tests/fastmath/ops_test.py new file mode 100644 index 000000000..8643e3a64 --- /dev/null +++ b/tests/fastmath/ops_test.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.fastmath.ops.""" + +import collections + +import gin +import jax.numpy as jnp +import numpy as onp + +from absl.testing import parameterized +from tensorflow import test + +from trax import fastmath + +_TestNamedtuple = collections.namedtuple("_TestNamedtuple", ["x"]) + + +class BackendTest(test.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def override_gin(self, bindings): + gin.parse_config_files_and_bindings(None, bindings) + + def test_backend_imports_correctly(self): + backend = fastmath.backend() + self.assertEqual(jnp, backend["np"]) + self.assertNotEqual(onp, backend["np"]) + + self.override_gin("backend.name = 'numpy'") + + backend = fastmath.backend() + self.assertNotEqual(jnp, backend["np"]) + self.assertEqual(onp, backend["np"]) + + def test_backend_can_be_set(self): + self.assertEqual(fastmath.backend_name(), "jax") + fastmath.set_backend("tensorflow-numpy") + self.assertEqual(fastmath.backend_name(), "tensorflow-numpy") + fastmath.set_backend(None) + self.assertEqual(fastmath.backend_name(), "jax") + + def test_numpy_backend_delegation(self): + # Assert that we are getting JAX's numpy backend. + backend = fastmath.backend() + numpy = fastmath.numpy + self.assertEqual(jnp, backend["np"]) + + # Assert that `numpy` calls the appropriate gin configured functions and + # properties. + self.assertTrue(numpy.isinf(numpy.inf)) + self.assertEqual(jnp.isinf, numpy.isinf) + self.assertEqual(jnp.inf, numpy.inf) + + # Assert that we will now get the pure numpy backend. + + self.override_gin("backend.name = 'numpy'") + + backend = fastmath.backend() + numpy = fastmath.numpy + self.assertEqual(onp, backend["np"]) + + # Assert that `numpy` calls the appropriate gin configured functions and + # properties. + self.assertTrue(numpy.isinf(numpy.inf)) + self.assertEqual(onp.isinf, numpy.isinf) + self.assertEqual(onp.inf, numpy.inf) + + @parameterized.named_parameters( + ("_" + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP) + ) + def test_fori_loop(self, backend): + with fastmath.use_backend(backend): + res = fastmath.fori_loop(2, 5, lambda i, x: x + i, 1) + self.assertEqual(res, 1 + 2 + 3 + 4) + + def test_nested_map(self): + inp = {"a": ([0, 1], 2), "b": _TestNamedtuple(3)} + out = {"a": ([1, 2], 3), "b": _TestNamedtuple(4)} + self.assertEqual(fastmath.nested_map(lambda x: x + 1, inp), out) + + def test_nested_stack(self): + inp = [ + {"a": ([0, 1], 2), "b": _TestNamedtuple(3)}, + {"a": ([1, 2], 3), "b": _TestNamedtuple(4)}, + ] + out = {"a": ([[0, 1], [1, 2]], [2, 3]), "b": _TestNamedtuple([3, 4])} + onp.testing.assert_equal(fastmath.nested_stack(inp), out) + + def test_names_match(self): + # Names match up. + for backend_enum, backend_obj in fastmath.ops._backend_dict.items(): + self.assertEqual(backend_enum.value, backend_obj["name"]) + + # Every backend appears in the dictionary. 
+ for backend_enum in fastmath.ops.Backend: + self.assertIn(backend_enum, fastmath.ops._backend_dict) + + def test_use_backend_str(self): + with fastmath.use_backend("tensorflow-numpy"): + self.assertEqual(fastmath.backend_name(), "tensorflow-numpy") + + def test_use_backend_enum(self): + with fastmath.use_backend(fastmath.Backend.NUMPY): + self.assertEqual(fastmath.backend_name(), "numpy") + + +if __name__ == "__main__": + test.main() diff --git a/tests/layers/acceleration_test.py b/tests/layers/acceleration_test.py new file mode 100644 index 000000000..fcd135c54 --- /dev/null +++ b/tests/layers/acceleration_test.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for acceleration.""" + +import numpy as np + +from absl.testing import absltest + +from tests.fastmath.jax import config +from trax import fastmath +from trax import layers as tl +from trax.utils import shapes + + +class AccelerationTest(absltest.TestCase): + def test_accelerated_same_result(self): + layer = tl.Dense(2) + x = np.random.uniform(size=(8, 7)) + layer.init(shapes.signature(x)) + y = layer(x) + z = tl.Accelerate(layer)(x) + for i in range(8): + self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) + self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) + + def test_accelerated_pad(self): + layer = tl.Dense(2) + x = np.random.uniform(size=(3, 7)) + layer.init(shapes.signature(x)) + y = layer(x) + z = tl.Accelerate(layer)(x) + self.assertEqual(z.shape, y.shape) + for i in range(3): + self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) + self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) + + def test_accelerated_weighted_category_accuracy(self): + """Test multi-device aggregation of weights.""" + layer = tl.Accelerate(tl.WeightedCategoryAccuracy()) + weights = np.array([1.0, 1.0, 1.0, 0.0]) + targets = np.array([0, 1, 2, 3]) + + model_outputs = np.array( + [ + [0.2, 0.1, 0.7, 0.0], + [0.2, 0.1, 0.7, 0.0], + [0.2, 0.1, 0.7, 0.0], + [0.2, 0.1, 0.7, 0.0], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(np.mean(accuracy), 1 / 3) + + def test_chunk_memory(self): + """Test chunking here to exercise accelerator memory usage.""" + layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128)) + chunked = tl.Chunk(layer, 256) + x = np.random.uniform(size=(16 * 1024, 16)) + chunked.init(shapes.signature(x)) + y = chunked(x) + z = tl.Accelerate(chunked)(x) + self.assertEqual(y.shape, (16 * 1024, 128)) + self.assertEqual(z.shape, (16 * 1024, 128)) + + def test_chunk_grad_memory(self): + """Test chunking gradient here to exercise accelerator memory usage.""" + layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(24)) + chunked = tl.Chunk(layer, 256) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits, new_state = chunked.pure_fn(x, weights, state, rng) + loss = fastmath.numpy.mean(logits) + return loss, (new_state, logits) + + 
gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + x = np.random.uniform(size=(32 * 1024, 16)) + chunked.init(shapes.signature(x)) + weights, _, logits = mock_training_step( + x, chunked.weights, chunked.state, fastmath.random.get_prng(0) + ) + self.assertEqual(logits.shape, (32 * 1024, 24)) + self.assertEqual(weights[1][0][0][0].shape, (16, 1024 * 1024)) + + +if __name__ == "__main__": + config.config_with_absl() + absltest.main() diff --git a/tests/layers/activation_fns_test.py b/tests/layers/activation_fns_test.py new file mode 100644 index 000000000..e5f4e1970 --- /dev/null +++ b/tests/layers/activation_fns_test.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for activation function layers.""" + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl + + +class ActivationFnsTest(absltest.TestCase): + def test_relu(self): + layer = tl.Relu() + x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) + y = layer(x) + self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]) + + def test_parametric_relu(self): + layer = tl.ParametricRelu(a=0.25) + x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) + y = layer(x) + self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 0.5, 0.75, 1.25]) + + def test_leaky_relu(self): + layer = tl.LeakyRelu(a=0.125) + x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) + y = layer(x) + self.assertEqual(tl.to_list(y), [-0.25, -0.125, 0.0, 2.0, 3.0, 5.0]) + + def test_hard_sigmoid(self): + layer = tl.HardSigmoid() + x = np.array([-1.5, -0.5, -0.25, 0.0, 0.25, 0.5, 1.5]) + y = layer(x) + self.assertEqual(tl.to_list(y), [0.0, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]) + + def test_hard_tanh(self): + layer = tl.HardTanh() + x = np.array([-1.5, -0.5, -0.25, 0.0, 0.25, 0.5, 1.5]) + y = layer(x) + self.assertEqual(tl.to_list(y), [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/assert_shape_test.py b/tests/layers/assert_shape_test.py new file mode 100644 index 000000000..7c79c06b5 --- /dev/null +++ b/tests/layers/assert_shape_test.py @@ -0,0 +1,281 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for assert shape layers.""" + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl + + +class AssertFunctionTest(absltest.TestCase): + """Test AssertFunction layer.""" + + def test_simple_pass(self): + layer = tl.AssertFunction("abc->abc", tl.Dropout(rate=0.1)) + x = np.ones((2, 5, 20)) + layer(x) + + def test_simple_fail(self): + layer = tl.AssertFunction("abc->cba", tl.Dropout(rate=0.1)) + x = np.ones((2, 5, 20)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_reduce_rank_ellipsis_pass(self): + layer = tl.AssertFunction("...ab->...c", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + layer(x) + + def test_reduce_rank_explicit_pass(self): + layer = tl.AssertFunction("xyzab->xyzc", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + layer(x) + + def test_reduce_rank_to_one_pass(self): + layer = tl.AssertFunction("abcde->x", tl.Flatten(n_axes_to_keep=0)) + x = np.ones((1, 2, 3, 4, 5)) + layer(x) + + def test_reduce_rank_explicit_fail1(self): + layer = tl.AssertFunction("abcde->abcde", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_reduce_rank_explicit_fail2(self): + layer = tl.AssertFunction("abcde->abcd", tl.Flatten(n_axes_to_keep=3)) + x = np.ones((1, 2, 3, 4, 5)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_two_outputs_pass(self): + layer = tl.AssertFunction( + "...cd->...x,...cd", + tl.Branch( + tl.Flatten(n_axes_to_keep=2), + tl.Dropout(rate=0.1), + ), + ) + x = np.ones((1, 2, 3, 4)) + layer(x) + + def test_numeric_dimensions_pass(self): + layer = tl.AssertFunction( + "...34->1234,...34", + tl.Branch( + tl.Dropout(rate=0.1), + tl.Select([0]), + ), + ) + x = np.ones((1, 2, 3, 4)) + layer(x) + + def test_too_many_outputs_fail(self): + layer = tl.AssertFunction( + "...cd->...x,...cd,...cd,...cd", + tl.Branch( + tl.Flatten(n_axes_to_keep=2), + tl.Dropout(rate=0.1), + tl.Serial(), + ), + ) + x = np.ones((1, 2, 3, 4)) + with self.assertRaises(tl.LayerError): + layer(x) + + def test_multi_output_rank_fail(self): + layer = tl.AssertFunction( + "...34->...x,...y", + tl.Branch( + tl.Flatten(n_axes_to_keep=3), + tl.Serial(), + ), + ) + x = np.ones((1, 2, 3, 4)) + with self.assertRaises(tl.LayerError): + layer(x) + + +class AssertShapeTest(absltest.TestCase): + """Test AssertShape layer.""" + + def test_simple_pass(self): + layer = tl.AssertShape("aba,ba") + x = [np.ones((10, 5, 10)), np.zeros((5, 10))] + y = layer(x) + self.assertEqual(y, x) + + def test_same_shapes_pass(self): + layer = tl.AssertShape("aba,ba") + x = [np.ones((5, 5, 5)), np.zeros((5, 5))] + y = layer(x) + self.assertEqual(y, x) + + def test_single_arg_pass(self): + layer = tl.AssertShape("a") + x = np.ones((5,)) + y = layer(x) + self.assertEqual(y.tolist(), x.tolist()) + + def test_scalar_pass(self): + layer = tl.AssertShape("") + x = np.ones(()) + y = layer(x) + self.assertEqual(y.tolist(), x.tolist()) + + def test_square_matrix_pass(self): + layer = tl.AssertShape("aa") + x = np.ones((3, 3)) + y = layer(x) + self.assertEqual(y.tolist(), x.tolist()) + + def test_vector_scalar_pass(self): + layer = tl.AssertShape("a,") + x = [np.ones((5,)), np.zeros(())] + y = layer(x) + self.assertEqual(y, x) + + def test_three_args_pass(self): + layer = tl.AssertShape("a,b,a") + x = [np.ones((5,)), np.zeros((2)), np.zeros((5))] + y = layer(x) + self.assertEqual(y, x) + + def test_multiple_matching_dims_pass(self): + layer = tl.AssertShape("a,b,a,ab") + x = 
[np.ones((5,)), np.zeros((2)), np.zeros((5)), np.zeros((5, 2))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_numeric_dims_pass(self):
+        layer = tl.AssertShape("23,1,93")
+        x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_numeric_dims_fail(self):
+        layer = tl.AssertShape("24,1,93")
+        x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_ellipsis_middle_pass(self):
+        layer = tl.AssertShape("a...bc,abc")
+        x = [np.ones((1, 5, 5, 2, 3)), np.zeros((1, 2, 3))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_ellipsis_prefix_pass(self):
+        layer = tl.AssertShape("...bc,abc")
+        x = [np.ones((5, 5, 2, 3)), np.zeros((1, 2, 3))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_ellipsis_matching_zero_dims_pass(self):
+        layer = tl.AssertShape("...bc,abc")
+        x = [np.ones((2, 3)), np.zeros((1, 2, 3))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_ellipsis_matching_ellipsis_pass(self):
+        layer = tl.AssertShape("...bc,...bc")
+        x = [np.ones((1, 2, 3)), np.zeros((1, 2, 3))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_prefix_ellipsis_matching_suffix_ellipsis_pass(self):
+        layer = tl.AssertShape("bb...,...bb")
+        x = [np.ones((2, 2, 5, 6)), np.zeros((5, 6, 2, 2))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_middle_ellipsis_fail(self):
+        layer = tl.AssertShape("ab...cde,2")
+        x = [np.ones((2, 3, 4, 5)), np.zeros((2))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_short_middle_ellipsis_fail(self):
+        layer = tl.AssertShape("b...c,2")
+        x = [np.ones((2)), np.zeros((2))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_double_ellipsis_fail(self):
+        layer = tl.AssertShape("b......c,2")
+        x = [np.ones((2, 3, 4, 5)), np.zeros((2))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_typo_ellipsis_fail(self):
+        layer = tl.AssertShape("b..c,2")
+        x = [np.ones((2, 3, 4, 5)), np.zeros((2))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_ellipsis_matching_ellipsis_fail(self):
+        layer = tl.AssertShape("...a,...b")
+        x = [np.ones((1, 2, 3, 7)), np.zeros((1, 2, 8))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_ellipsis_numeric_pass(self):
+        layer = tl.AssertShape("...22,...3")
+        x = [np.ones((1, 2, 3, 2, 2)), np.zeros((1, 2, 3, 3))]
+        y = layer(x)
+        self.assertEqual(y, x)
+
+    def test_prefix_and_suffix_ellipsis_fail(self):
+        layer = tl.AssertShape("...c...,2")
+        x = [np.ones((2, 3, 4, 5)), np.zeros((2))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_ellipsis_too_few_dims_fail(self):
+        layer = tl.AssertShape("...abc,2")
+        x = [np.ones((4, 5)), np.zeros((2))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_ellipses_matching_dims_fail(self):
+        layer = tl.AssertShape("...2,...8")
+        x = [np.ones((1, 2, 3, 9)), np.zeros((1, 3, 3, 8))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_dims_matching_fail(self):
+        layer = tl.AssertShape("aba,ab")
+        x = [np.ones((10, 5, 10)), np.ones((5, 8))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_rank_fail(self):
+        layer = tl.AssertShape("aba,ab")
+        x = [np.ones((10, 5, 10)), np.ones((5, 10, 4))]
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+    def test_square_matrix_fail(self):
+        layer = tl.AssertShape("aa")
+        x = np.ones((10, 5))
+        with self.assertRaises(tl.LayerError):
+            layer(x)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/layers/attention_test.py
b/tests/layers/attention_test.py new file mode 100644 index 000000000..a843a7c60 --- /dev/null +++ b/tests/layers/attention_test.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.layers.attention.""" + +import functools + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl + +from tests.layers import test_utils +from trax.utils import shapes + + +class AttentionTest(absltest.TestCase): + def test_simple_call(self): + layer = tl.CausalAttention(d_feature=4, n_heads=2) + x = [ + np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ), + np.array([[[[1, 0, 1]]]]), + ] + _, _ = layer.init(shapes.signature(x)) + + y, mask = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + self.assertEqual(mask.shape, (1, 1, 1, 3)) + + def test_shift_right(self): + # Test shifts right on axis=1 + layer = tl.ShiftRight() + x = np.array( + [ + [[9, 9, 9], [8, 8, 8], [7, 7, 7], [6, 6, 6]], + [[99, 98, 97], [96, 95, 94], [93, 92, 91], [90, 89, 88]], + ] + ) + y = layer(x) + self.assertEqual(x.shape, y.shape) + self.assertEqual( + tl.to_list(y), + [ + [[0, 0, 0], [9, 9, 9], [8, 8, 8], [7, 7, 7]], + [[0, 0, 0], [99, 98, 97], [96, 95, 94], [93, 92, 91]], + ], + ) + + def test_shift_right_float(self): + layer = tl.ShiftRight() + x = np.array( + [ + [[9, 9, 9], [8, 8, 8], [7, 7, 7], [6, 6, 6]], + [[99, 98, 97], [96, 95, 94], [93, 92, 91], [90, 89, 88]], + ] + ).astype(np.float32) + x /= 2.0 + self.assertEqual(x.dtype, np.float32) + + y = layer(x) + self.assertEqual(y.dtype, np.float32) + self.assertEqual( + tl.to_list(y), + [ + [[0.0, 0.0, 0.0], [4.5, 4.5, 4.5], [4.0, 4.0, 4.0], [3.5, 3.5, 3.5]], + [ + [0.0, 0.0, 0.0], + [49.5, 49.0, 48.5], + [48.0, 47.5, 47.0], + [46.5, 46.0, 45.5], + ], + ], + ) + + def test_padding_mask(self): + layer = tl.PaddingMask() + x = np.array( + [ + [1.0, 2.0, 3.0, 4.0, 0.0], + [1.0, 2.0, 3.0, 0.0, 0.0], + [1.0, 2.0, 0.0, 0.0, 0.0], + ] + ) + y = layer(x) + self.assertEqual(x.shape, (3, 5)) + self.assertEqual(y.shape, (3, 1, 1, 5)) + np.testing.assert_equal( + y, + [ + [[[True, True, True, True, False]]], + [[[True, True, True, False, False]]], + [[[True, True, False, False, False]]], + ], + ) + + +class CausalAttentionTest(absltest.TestCase): + def test_simple_call(self): + layer = tl.CausalAttention(d_feature=4, n_heads=2) + x = np.array( + [ + [ + [2, 5, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ] + ) + _, _ = layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + def test_deterministic_eval(self): + d_model = 32 + seq_len = 3 + x_shape = (1, seq_len, d_model) + inp = np.ones(x_shape).astype(np.float32) + + model_fn = functools.partial( + tl.CausalAttention, + d_feature=d_model, + n_heads=4, + ) + + test_utils.test_eval_is_deterministic(inp, model_fn) + + def test_predict_equals_eval(self): + d_model = 32 + seq_len = 10 + x_shape = (1, seq_len, d_model) + inp = np.ones(x_shape).astype(np.float32) + + model_fn 
= functools.partial( + tl.CausalAttention, + d_feature=d_model, + n_heads=4, + ) + + test_utils.test_eval_equals_predict(inp, model_fn) + + +class PositionalEncodingTest(absltest.TestCase): + def test_simple_call(self): + layer = tl.PositionalEncoding(max_len=8) + x = np.array([[[2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0]]]) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (1, 2, 4)) + + def test_predict(self): + layer = tl.PositionalEncoding(max_len=8) + x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]]) + self.assertEqual(x.shape, (1, 4, 2)) + layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, (1, 4, 2)) + layer = tl.PositionalEncoding(max_len=8, mode="predict") + layer.init(shapes.signature(x[:, :1, :])) + y0 = layer(x[:, :1, :]) # just the first token + self.assertEqual(y0.shape, (1, 1, 2)) + self.assertTrue(np.array_equal(y0, y[:, :1, :])) + y1 = layer(x[:, 1:3, :]) # now the next 2 tokens + self.assertEqual(y1.shape, (1, 2, 2)) + self.assertTrue(np.array_equal(y1, y[:, 1:3, :])) + y2 = layer(x[:, 3:4, :]) # final one token + self.assertEqual(y2.shape, (1, 1, 2)) + self.assertTrue(np.array_equal(y2, y[:, 3:4, :])) + + def test_predict_equals_eval(self): + x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]]) + self.assertEqual(x.shape, (1, 4, 2)) + + layer_eval = tl.PositionalEncoding(max_len=8, d_feature=4, mode="eval") + layer_eval.init(shapes.signature(x)) + + output_eval = layer_eval(x) + + layer_predict = tl.PositionalEncoding(max_len=8, d_feature=4, mode="predict") + layer_predict.init(shapes.signature(x)) + layer_predict.weights = layer_eval.weights + + output_predict = layer_predict(x) + self.assertTrue(np.array_equal(output_eval, output_predict)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/base_test.py b/tests/layers/base_test.py new file mode 100644 index 000000000..45102b78e --- /dev/null +++ b/tests/layers/base_test.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Trax base layer classes and generic layer-creating functions.""" + +import numpy as np + +from absl.testing import absltest, parameterized + +import trax.layers as tl + +from trax import fastmath +from trax.fastmath import numpy as jnp +from trax.utils import shapes + +BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] +CUSTOM_GRAD_BACKENDS = [fastmath.Backend.JAX] + + +class BaseLayerTest(parameterized.TestCase): + def test_call_raises_error(self): + layer = tl.Layer() + x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) + with self.assertRaisesRegex(tl.LayerError, "NotImplementedError"): + _ = layer(x) + + def test_set_weighs_raises_error(self): + layer = tl.Layer() + layer.weights = 1.0 # can assign weights + with self.assertRaisesRegex(ValueError, "weighs"): + layer.weighs = 1.0 # cannot assign weighs + + def test_forward_raises_error(self): + layer = tl.Layer() + x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) + with self.assertRaises(NotImplementedError): + _ = layer.forward(x) + + def test_init_returns_empty_weights_and_state(self): + layer = tl.Layer() + input_signature = shapes.ShapeDtype((2, 5)) + weights, state = layer.init(input_signature) + self.assertEmpty(weights) + self.assertEmpty(state) + + def test_output_signature_no_weights(self): + shape_2_3_5 = shapes.ShapeDtype((2, 3, 5)) + input_signature = (shape_2_3_5, shape_2_3_5) + layer = tl.Fn("2in1out", lambda x, y: x + y) + output_signature = layer.output_signature(input_signature) + self.assertEqual(output_signature, shape_2_3_5) + + shape_5_7 = shapes.ShapeDtype((5, 7)) + input_signature = shape_5_7 + layer = tl.Fn("1in3out", lambda x: (x, 2 * x, 3 * x), n_out=3) + output_signature = layer.output_signature(input_signature) + self.assertEqual(output_signature, (shape_5_7, shape_5_7, shape_5_7)) + + # TODO(jonni): Define/test behavior of output signature for layers w/weights. + + @parameterized.named_parameters([("_" + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) + def test_custom_zero_grad(self, backend): + class IdWithZeroGrad(tl.Layer): + def forward(self, x): + return x + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + return (jnp.zeros_like(grad), ()) + + with fastmath.use_backend(backend): + layer = IdWithZeroGrad() + rng = fastmath.random.get_prng(0) + input_signature = shapes.ShapeDtype((9, 17)) + random_input = fastmath.random.uniform( + rng, input_signature.shape, minval=-1.0, maxval=1.0 + ) + layer.init(input_signature) + f = lambda x: jnp.mean(layer(x)) + grad = fastmath.grad(f)(random_input) + self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. + self.assertEqual(sum(sum(grad * grad)), 0.0) # Each one is 0. + + @parameterized.named_parameters([("_" + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) + def test_custom_id_grad(self, backend): + class IdWithIdGrad(tl.Layer): + def forward(self, x): + return x + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + return (inputs, ()) + + with fastmath.use_backend(backend): + layer = IdWithIdGrad() + rng = fastmath.random.get_prng(0) + input_signature = shapes.ShapeDtype((9, 17)) + random_input = fastmath.random.uniform( + rng, input_signature.shape, minval=-1.0, maxval=1.0 + ) + layer.init(input_signature) + f = lambda x: jnp.mean(layer(x)) + grad = fastmath.grad(f)(random_input) + self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. 
+            self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
+
+    def test_weights_and_state_signature(self):
+        class MyLayer(tl.Layer):
+            def init_weights_and_state(self, input_signature):
+                self.weights = jnp.zeros((2, 3))
+                self.state = jnp.ones(input_signature.shape)
+
+            def forward(self, inputs):
+                return self.weights + self.state
+
+        layer = MyLayer()
+        w, s = layer.weights_and_state_signature(jnp.zeros((3, 4)))
+        self.assertEqual(w.shape, (2, 3))
+        self.assertEqual(s.shape, (3, 4))
+
+    def test_custom_name(self):
+        layer = tl.Layer()
+        self.assertIn("Layer", str(layer))
+        self.assertNotIn("CustomLayer", str(layer))
+
+        layer = tl.Layer(name="CustomLayer")
+        self.assertIn("CustomLayer", str(layer))
+
+
+class PureLayerTest(absltest.TestCase):
+    def test_forward(self):
+        layer = tl.PureLayer(
+            lambda x: 2 * x[0]
+        )  # PureLayer casts its input to the tuple (input,), so x is a tuple.
+
+        # Use Layer.__call__.
+        in_0 = np.array([1, 2])
+        out_0 = layer(in_0, weights=jnp.zeros((2, 3)))
+        self.assertEqual(out_0.tolist(), [2, 4])
+        self.assertEmpty(layer.weights)
+
+        # Use PureLayer.forward.
+        in_1 = np.array([3, 4])
+        out_1 = layer.forward(in_1)
+        self.assertEqual(out_1.tolist(), [6, 8])
+
+        # Use Layer.pure_fn
+        in_2 = np.array([5, 6])
+        out_2, _ = layer.pure_fn(in_2, tl.EMPTY_WEIGHTS, tl.EMPTY_WEIGHTS, None)
+        self.assertEqual(out_2.tolist(), [10, 12])
+
+
+class FnTest(absltest.TestCase):
+    def test_bad_f_has_default_arg(self):
+        with self.assertRaisesRegex(ValueError, "default arg"):
+            _ = tl.Fn("", lambda x, sth=None: x)
+
+    def test_bad_f_has_keyword_arg(self):
+        with self.assertRaisesRegex(ValueError, "keyword arg"):
+            _ = tl.Fn("", lambda x, **kwargs: x)
+
+    def test_bad_f_has_variable_arg(self):
+        with self.assertRaisesRegex(ValueError, "variable arg"):
+            _ = tl.Fn("", lambda *args: args[0])
+
+    def test_forward(self):
+        layer = tl.Fn(
+            "SumAndMax", lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2
+        )
+
+        x0 = np.array([1, 2, 3, 4, 5])
+        x1 = np.array([10, 20, 30, 40, 50])
+
+        y0, y1 = layer((x0, x1))
+        self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55])
+        self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50])
+
+        y2, y3 = layer.forward((x0, x1))
+        self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55])
+        self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50])
+
+        (y4, y5), state = layer.pure_fn(
+            (x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE, None
+        )
+        self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55])
+        self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50])
+        self.assertEqual(state, tl.EMPTY_STATE)
+
+    def test_weights_state(self):
+        layer = tl.Fn(
+            "2in2out", lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)), n_out=2
+        )
+        layer.init_weights_and_state(None)
+        self.assertEmpty(layer.weights)
+        self.assertEmpty(layer.state)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/layers/combinators_test.py b/tests/layers/combinators_test.py
new file mode 100644
index 000000000..f5f8c8baf
--- /dev/null
+++ b/tests/layers/combinators_test.py
@@ -0,0 +1,748 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for combinator layers."""
+
+import numpy as np
+
+from absl.testing import absltest, parameterized
+
+import trax.layers as tl
+
+from trax import fastmath
+from trax.utils import shapes
+
+
+def DivideBy(val):  # pylint: disable=invalid-name
+    """Returns a simple division layer with n_in == 1 and n_out == 1."""
+    return tl.Fn("DivideBy", lambda x: x / val)
+
+
+def ReturnConst(val):  # pylint: disable=invalid-name
+    """Returns a simple const layer with n_in == 0 and n_out == 1."""
+    return tl.Fn("ReturnConst", lambda: val)
+
+
+def SmallerThan(val):  # pylint: disable=invalid-name
+    """Checks if the input is smaller than a certain value."""
+    return tl.Fn("SmallerThan", lambda x: x < val)
+
+
+# TODO(jonni): Consider a more generic home for this utility function.
+def as_list(outputs):
+    """Converts layer outputs to a nested list, for easier equality testing.
+
+    Args:
+      outputs: A tensor or tuple/list of tensors coming from the forward
+          application of a layer. Each tensor is NumPy ndarray-like, which
+          complicates simple equality testing (e.g., via `assertEquals`):
+          such tensors require equality testing to use either `all` (all
+          elements match) or `any` (at least one element matches), which is not
+          directly supported in absltest.
+
+    Returns:
+      A nested list structure containing all the output values, but now directly
+      testable using `assertEquals`.
+    """
+    if isinstance(outputs, (list, tuple)):
+        return [as_list(y) for y in outputs]
+    else:
+        return outputs.tolist()
+
+
+class SerialTest(absltest.TestCase):
+    def test_none_is_no_op(self):
+        layer = tl.Serial(None)
+        xs = [np.array([1, 2, 3, 4]), np.array([10, 20, 30])]
+        ys = layer(xs)
+        self.assertEqual(as_list(ys), [[1, 2, 3, 4], [10, 20, 30]])
+
+    def test_empty_list_is_no_op(self):
+        layer = tl.Serial([])
+        xs = [np.array([1, 2, 3, 4]), np.array([10, 20, 30])]
+        ys = layer(xs)
+        self.assertEqual(as_list(ys), [[1, 2, 3, 4], [10, 20, 30]])
+
+    def test_one_in_one_out(self):
+        layer = tl.Serial(DivideBy(3))
+        x = np.array([3, 6, 9, 12])
+        y = layer(x)
+        self.assertEqual(as_list(y), [1, 2, 3, 4])
+
+    def test_zero_in_one_out(self):
+        layer = tl.Serial(ReturnConst(np.array([3, 4, 5, 6])))
+        y = layer(())
+        self.assertEqual(as_list(y), [3, 4, 5, 6])
+
+    def test_one_in_two_out(self):
+        layer = tl.Serial(DivideBy(3), ReturnConst(np.array([3, 4, 5, 6])))
+        x = np.array([3, 6, 9, 12])
+        y = layer(x)
+        self.assertEqual(as_list(y), [[3, 4, 5, 6], [1, 2, 3, 4]])
+
+    def test_const_div(self):
+        layer = tl.Serial(ReturnConst(np.array([3, 6, 9, 12])), DivideBy(3))
+        y = layer(())
+        self.assertEqual(as_list(y), [1, 2, 3, 4])
+
+    def test_div_div(self):
+        layer = tl.Serial(DivideBy(2.0), DivideBy(5.0))
+        x = np.array([10, 20, 30])
+        y = layer(x)
+        self.assertEqual(as_list(y), [1, 2, 3])
+
+    def test_dup_dup(self):
+        layer = tl.Serial(tl.Dup(), tl.Dup())
+        x = np.array([1, 2, 3])
+        ys = layer(x)
+        self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]])
+
+    def test_default_name(self):
+        layer = tl.Serial(tl.Dup(), tl.Dup())
+        self.assertIn("Serial", str(layer))
+
+    def test_custom_name(self):
+        layer = tl.Serial(tl.Dup(), tl.Dup(), name="Branch")
+        self.assertIn("Branch", str(layer))
+
+    def test_weights(self):
+        model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7))
+        self.assertIsInstance(model.weights, tuple)
+        self.assertLen(model.weights, 3)
+
+    def test_flat_weights_and_state(self):
+        model = tl.Serial(
+            tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup())
+        )
+        sample_input_signature = shapes.signature(np.zeros((2, 3)))
+        model.init(sample_input_signature)
+        flat_weights, flat_state = tl.flatten_weights_and_state(
+            model.weights, model.state
+        )
+        # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers.
+        # So after making them flat, there are 4 trainable weights.
+        self.assertLen(flat_weights, 4)
+        self.assertEmpty(flat_state)
+        model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7))
+        sig = model2.weights_and_state_signature(sample_input_signature)
+        weights2, state2 = tl.unflatten_weights_and_state(flat_weights, flat_state, sig)
+        model2.weights = weights2
+        model2.state = state2
+        self.assertLen(model2.weights, 3)
+        self.assertEqual(model.weights[1], model2.weights[0])
+        self.assertEqual(model.weights[2][0], model2.weights[2])
+
+    def test_flat_weights_and_state_shared(self):
+        shared = tl.Dense(5)
+        model = tl.Serial(tl.Dense(5), shared, tl.Serial(shared, tl.Dup()))
+        sample_input_signature = shapes.signature(np.zeros((2, 3)))
+        model.init(sample_input_signature)
+        flat_weights, flat_state = tl.flatten_weights_and_state(
+            model.weights, model.state
+        )
+        # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers.
+        # So after making them flat, there are 4 trainable weights.
+        self.assertLen(flat_weights, 4)
+        self.assertEmpty(flat_state)
+        model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(5))
+        sig = model2.weights_and_state_signature(sample_input_signature)
+        weights2, state2 = tl.unflatten_weights_and_state(flat_weights, flat_state, sig)
+        model2.weights = weights2
+        model2.state = state2
+        self.assertLen(model2.weights, 3)
+        self.assertEqual(model.weights[0], model2.weights[0])
+        self.assertEqual(model.weights[1], model2.weights[2])
+
+    def test_assign_sublayer_weights(self):
+        layer = tl.Dense(5, use_bias=False)
+        model = tl.Serial(tl.Serial(layer, tl.Dense(6)), tl.Dense(7))
+        sample_input = np.array([1, 2, 3, 4, 5])
+        weights, _ = model.init(shapes.signature(sample_input))
+        new_layer_weights = np.random.uniform(size=weights[0][0].shape)
+        layer.weights = new_layer_weights
+        self.assertIs(model.weights[0][0], new_layer_weights)
+
+    def test_shared_weights(self):
+        layer = tl.Dense(5)
+        model = tl.Serial(layer, layer)
+        sample_input = np.array([1, 2, 3, 4, 5])
+        weights, _ = model.init(shapes.signature(sample_input))
+        self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE)
+
+    def test_shared_weights_nested(self):
+        layer = tl.Dense(5)
+        model = tl.Serial(layer, tl.Serial(layer))
+        sample_input = np.array([1, 2, 3, 4, 5])
+        weights, _ = model.init(shapes.signature(sample_input))
+        self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE)
+
+    def test_shared_weights_double_nested(self):
+        layer = tl.Dense(5)
+        model = tl.Serial(tl.Serial(layer), tl.Serial(layer))
+        sample_input = np.array([1, 2, 3, 4, 5])
+        weights, _ = model.init(shapes.signature(sample_input))
+        self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE)
+
+    def test_shared_weights_for_shared_serial(self):
+        layer = tl.Serial(tl.Dense(5), tl.Dense(5))
+        model = tl.Serial(layer, layer)
+        sample_input = np.array([1, 2, 3, 4, 5])
+        # Init gives weights reflecting weight sharing.
+        weights, _ = model.init(shapes.signature(sample_input))
+        self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE)
+        self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE)
+        # Forward pass runs successfully.
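+        # (Weight sharing is represented by a cache marker: as asserted above,
+        # the second occurrence of the shared sublayer stores
+        # tl.GET_WEIGHTS_FROM_CACHE rather than a second copy of its weights,
+        # so only one (w, b) pair per Dense layer is stored and trained.)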
+ y = model(sample_input) + self.assertEqual(y.shape, (5,)) + + def test_state(self): + model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) + self.assertIsInstance(model.state, tuple) + self.assertLen(model.state, 3) + + def test_set_rng_recurse_two_levels(self): + dense_00 = tl.Dense(2) + dense_01 = tl.Dense(2) + dense_10 = tl.Dense(2) + dense_11 = tl.Dense(2) + layer = tl.Serial( + tl.Serial(dense_00, dense_01), + tl.Serial(dense_10, dense_11), + ) + input_signature = shapes.ShapeDtype((1, 2)) + + _, _ = layer.init(input_signature) + weights = layer.weights + dense_00_w, dense_00_b = weights[0][0] + dense_01_w, dense_01_b = weights[0][1] + dense_10_w, dense_10_b = weights[1][0] + dense_11_w, dense_11_b = weights[1][1] + + # Setting rng's recursively during init should yield differing weights. + self.assertFalse(np.array_equal(dense_00_w, dense_01_w)) + self.assertFalse(np.array_equal(dense_00_b, dense_01_b)) + self.assertFalse(np.array_equal(dense_10_w, dense_11_w)) + self.assertFalse(np.array_equal(dense_10_b, dense_11_b)) + + +class ParallelTest(absltest.TestCase): + def test_dup_dup(self): + layer = tl.Parallel(tl.Dup(), tl.Dup()) + xs = [np.array([1, 2, 3]), np.array([10, 20])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [10, 20], [10, 20]]) + + def test_div_div(self): + layer = tl.Parallel(DivideBy(0.5), DivideBy(3.0)) + xs = [np.array([1, 2, 3]), np.array([30, 60])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[2, 4, 6], [10, 20]]) + + def test_two_no_ops(self): + layer = tl.Parallel(tl.Select([0]), tl.Select([0])) + xs = (np.array([1, 2, 3]), np.array([10, 20])) + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3], [10, 20]]) + + def test_default_name(self): + layer = tl.Parallel(tl.Dup(), tl.Dup()) + self.assertIn("Parallel", str(layer)) + + def test_custom_name(self): + layer = tl.Parallel(tl.Dup(), tl.Dup(), name="DupDup") + self.assertIn("DupDup", str(layer)) + + def test_weights(self): + model = tl.Parallel(tl.Dense(3), tl.Dense(5)) + self.assertIsInstance(model.weights, tuple) + self.assertLen(model.weights, 2) + + def test_shared_weights(self): + layer = tl.Dense(5) + model = tl.Parallel(layer, layer) + sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_nested(self): + layer = tl.Dense(5) + model = tl.Parallel([layer, tl.Dense(2)], [layer, tl.Dense(2)]) + sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) + + def test_shared_weights_for_shared_parallel(self): + layer = tl.Parallel(tl.Dense(5), tl.Dense(7)) + model = tl.Parallel(layer, layer) + sample_input = [ + np.array([1, 2, 3]), + np.array([10, 20, 30]), + np.array([100, 200, 300]), + np.array([1000, 2000, 3000]), + ] + # Init gives weights reflecting weight sharing. + weights, _ = model.init(shapes.signature(sample_input)) + self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) + self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) + # Forward pass runs successfully. 
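+        # (As in the Serial case above, only the first occurrence of the
+        # shared sublayer holds concrete weights; the second holds the
+        # tl.GET_WEIGHTS_FROM_CACHE marker.)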
+ y0, y1, y2, y3 = model(sample_input) + self.assertEqual(y0.shape, (5,)) + self.assertEqual(y1.shape, (7,)) + self.assertEqual(y2.shape, (5,)) + self.assertEqual(y3.shape, (7,)) + + def test_state(self): + model = tl.Parallel(tl.Dense(3), tl.Dense(5)) + self.assertIsInstance(model.state, tuple) + self.assertLen(model.state, 2) + + +class ConcatenateTest(absltest.TestCase): + def test_n_in_n_out(self): + layer = tl.Concatenate() + self.assertEqual(layer.n_in, 2) + self.assertEqual(layer.n_out, 1) + + def test_with_defaults(self): + layer = tl.Concatenate() # Default n_items=2, axis=-1 + xs = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[10, 20, 30], [40, 50, 60]])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]]) + + def test_axis_0(self): + layer = tl.Concatenate(axis=0) + xs = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[10, 20, 30], [40, 50, 60]])] + y = layer(xs) + self.assertEqual(as_list(y), [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]]) + + def test_axis_1(self): + layer = tl.Concatenate(axis=1) + xs = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[10, 20, 30], [40, 50, 60]])] + y = layer(xs) + self.assertEqual(as_list(y), [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]]) + + def test_n_items_is_not_default(self): + layer = tl.Concatenate(n_items=3) + xs = [ + np.array([[1, 2, 3], [4, 5, 6]]), + np.array([[10, 20, 30], [40, 50, 60]]), + np.array([[100, 200, 300], [400, 500, 600]]), + ] + y = layer(xs) + self.assertEqual(y.shape, (2, 9)) + self.assertEqual( + as_list(y), + [ + [1, 2, 3, 10, 20, 30, 100, 200, 300], + [4, 5, 6, 40, 50, 60, 400, 500, 600], + ], + ) + + def test_repr(self): + layer = tl.Concatenate() + self.assertEqual(repr(layer), "Concatenate_in2") + + layer = tl.Concatenate(axis=0) + self.assertEqual(repr(layer), "Concatenate_axis0_in2") + + layer = tl.Concatenate(axis=1) + self.assertEqual(repr(layer), "Concatenate_axis1_in2") + + layer = tl.Concatenate(n_items=3) + self.assertEqual(repr(layer), "Concatenate_in3") + + +class BranchTest(absltest.TestCase): + def test_noop_dup(self): + layer = tl.Branch(tl.Select([0]), tl.Dup()) + x = np.array([1, 2, 3]) + ys = layer(x) + self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + def test_add_div(self): + layer = tl.Branch(tl.Add(), DivideBy(0.5)) + xs = [np.array([1, 2, 3]), np.array([10, 20, 30])] + ys = layer(xs) + self.assertEqual(as_list(ys), [[11, 22, 33], [2, 4, 6]]) + + def test_one_sublayer(self): + layer = tl.Branch(DivideBy(0.5)) + x = np.array([1, 2, 3]) + ys = layer(x) + self.assertEqual(as_list(ys), [2, 4, 6]) + + def test_default_name(self): + layer = tl.Branch(tl.Add(), DivideBy(0.5)) + self.assertIn("Branch", str(layer)) + + def test_printing_sublayers(self): + layer = tl.Branch(tl.Add(), tl.Add()) + expected_result = "Branch_in2_out2[\n Add_in2\n Add_in2\n]" + self.assertEqual(expected_result, str(layer)) + + +class SelectTest(absltest.TestCase): + def test_computes_n_in(self): + layer = tl.Select([0, 0]) + self.assertEqual(layer.n_in, 1) + + layer = tl.Select([1, 0]) + self.assertEqual(layer.n_in, 2) + + layer = tl.Select([2]) + self.assertEqual(layer.n_in, 3) + + def test_given_n_in(self): + layer = tl.Select([0], n_in=2) + self.assertEqual(layer.n_in, 2) + + layer = tl.Select([0], n_in=3) + self.assertEqual(layer.n_in, 3) + + def test_first_of_3(self): + layer = tl.Select([0], n_in=3) + xs = [np.array([1, 2, 3]), np.array([10, 20]), np.array([100])] + y = layer(xs) + self.assertEqual(as_list(y), [1, 2, 3]) + + def test_second_of_3(self): 
+        layer = tl.Select([1], n_in=3)
+        xs = [np.array([1, 2, 3]), np.array([10, 20]), np.array([100])]
+        y = layer(xs)
+        self.assertEqual(as_list(y), [10, 20])
+
+
+class DropTest(absltest.TestCase):
+    def test_drop(self):
+        layer = tl.Drop()
+        x = np.array([1, 2, 3])
+        y = layer(x)
+        self.assertEqual(as_list(y), [])
+
+
+class SwapTest(absltest.TestCase):
+    def test_swap(self):
+        layer = tl.Swap()
+        xs = [np.array([1, 2, 3]), np.array([10, 20, 30])]
+        ys = layer(xs)
+        self.assertEqual(as_list(ys), [[10, 20, 30], [1, 2, 3]])
+
+
+class ChunkTest(absltest.TestCase):
+    def test_chunk(self):
+        layer = tl.Dense(4)
+        x = np.array([[1, 2, 3], [4, 5, 6]])
+        layer.init(x)
+        y = layer(x)
+        z = tl.Chunk(layer, 1)(x)
+        self.assertLess(np.sum((y - z) ** 2), 1e-5)  # y == z up to numerics
+
+    def test_chunk_uneven_numbers(self):
+        layer = tl.Dense(4)
+        x = np.array([[1, 2, 3], [4, 5, 6]])
+        layer.init(x)
+        y = layer(x)
+        z = tl.Chunk(layer, 3)(x)  # By default it should just pass
+        self.assertLess(np.sum((y - z) ** 2), 1e-5)  # y == z up to numerics
+        chunk_with_test = tl.Chunk(layer, 3, pass_unchunkable=False)
+        self.assertRaises(tl.LayerError, lambda: chunk_with_test(x))
+
+
+class SerialWithSideOutputsTest(absltest.TestCase):
+    def test_serial_with_side_outputs_div_div(self):
+        def some_layer():
+            return tl.Parallel(DivideBy(2.0), DivideBy(5.0))
+
+        layer = tl.SerialWithSideOutputs([some_layer(), some_layer()])
+        xs = (np.array([1, 2, 3]), np.array([10, 20, 30, 40, 50]), np.array([100, 200]))
+        ys = layer(xs)
+        output_shapes = [y.shape for y in ys]
+        self.assertEqual(output_shapes, [(3,), (5,), (2,)])
+
+
+BACKENDS = [fastmath.Backend.JAX]
+
+
+@parameterized.named_parameters(("_" + b.value, b) for b in BACKENDS)
+class ScanTest(parameterized.TestCase):
+    def _AddWithCarry(self):  # pylint: disable=invalid-name
+        del self
+
+        def f(x, carry):
+            res = x + carry
+            return res, res  # output and carry are the same
+
+        return tl.Fn("AddWithCarry", f, n_out=2)
+
+    def test_default_axis(self, backend):
+        with fastmath.use_backend(backend):
+            layer = tl.Scan(self._AddWithCarry())
+            xs = [
+                np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]]),
+                np.array([9000, 8000, 7000, 6000]),
+            ]
+            ys = layer(xs)
+            self.assertEqual(
+                as_list(ys),
+                [
+                    [
+                        [9000, 8001, 7002, 6003],
+                        [9000, 8011, 7022, 6033],
+                        [9000, 8111, 7222, 6333],
+                    ],
+                    [9000, 8111, 7222, 6333],
+                ],
+            )
+
+    def test_axis_1(self, backend):
+        with fastmath.use_backend(backend):
+            layer = tl.Scan(self._AddWithCarry(), axis=1)
+            xs = [
+                np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]]),
+                np.array([9000, 8000, 7000]),
+            ]
+            ys = layer(xs)
+            self.assertEqual(
+                as_list(ys),
+                [
+                    [
+                        [9000, 9001, 9003, 9006],
+                        [8000, 8010, 8030, 8060],
+                        [7000, 7100, 7300, 7600],
+                    ],
+                    [9006, 8060, 7600],
+                ],
+            )
+
+    def test_predict(self, backend):
+        with fastmath.use_backend(backend):
+            layer = tl.Scan(self._AddWithCarry(), axis=1, mode="predict")
+            xs = [np.array([[0, 1, 2]]), np.array([90])]
+            ys = layer(xs)
+            self.assertEqual(as_list(ys), [[[90, 91, 93]], [93]])
+            xs = [np.array([[3, 4]]), np.array([90])]
+            ys = layer(xs)
+            self.assertEqual(as_list(ys), [[[96, 100]], [100]])
+
+    def test_multi_input(self, backend):
+        def _MultiInputFn():  # pylint: disable=invalid-name
+            def f(a, b, carry):
+                return a + b, b, carry + 1
+
+            return tl.Fn("MultiInputFn", f, n_out=2)
+
+        with fastmath.use_backend(backend):
+            layer = tl.Scan(_MultiInputFn(), axis=1)
+            xs = [
+                np.array([[0, 1, 2], [0, 10, 20]]),
+                np.array([[4, 5, 6], [40, 50, 60]]),
+                np.array([9000,
8000]), + ] + ys = layer(xs) + self.assertEqual( + as_list(ys), + [[[4, 6, 8], [40, 60, 80]], [[4, 5, 6], [40, 50, 60]], [9003, 8003]], + ) + + def test_no_carry(self, backend): + def _AddOne(): # pylint: disable=invalid-name + return tl.Fn("AddOne", lambda x: x + 1) + + with fastmath.use_backend(backend): + layer = tl.Scan(_AddOne(), n_carry=0) + x = np.array([[1, 3, 7], [10, 30, 70]]) + y = layer(x) + self.assertEqual(as_list(y), [[2, 4, 8], [11, 31, 71]]) + + +class CondTest(absltest.TestCase): + def test_basic_true(self): + cond = ReturnConst(True) + true = ReturnConst([2]) + false = ReturnConst([5]) + layer = tl.Cond(cond, true, false) + layer.init(()) + xs = tuple() + ys = layer(xs) + self.assertEqual(as_list(ys), 2) + + def test_basic_false(self): + cond = ReturnConst(False) + true = ReturnConst([2]) + false = ReturnConst([5]) + layer = tl.Cond(cond, true, false) + layer.init(()) + xs = tuple() + ys = layer(xs) + self.assertEqual(as_list(ys), 5) + + def test_complex_blocks(self): + cond = ReturnConst(True) + true = DivideBy(2.0) + false = DivideBy(4.0) + layer = tl.Cond(cond, true, false) + xs = [np.arange(5).astype(np.float32)] + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [0.0, 0.5, 1.0, 1.5, 2.0]) + + def test_condition_func_true(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = DivideBy(4.0) + layer = tl.Cond(cond, true, false) + xs = (np.array(2.0), np.array([4.0, 12.0])) + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [2.0, 6.0]) + + def test_condition_func_false(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = DivideBy(4.0) + layer = tl.Cond(cond, true, false) + xs = (np.array(4.0), np.array([4.0, 12.0])) + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [1.0, 3.0]) + + def test_condition_func_default_false(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + layer = tl.Cond(cond, true) + xs = (np.array(4.0), np.array([4.0, 12.0])) + layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual(as_list(ys), [4.0, 12.0]) + + def test_exception_n_out(self): + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = tl.Dup() + self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) + + def test_exception_n_in(self): + cond = SmallerThan(3.0) + true = ReturnConst(2.0) + false = DivideBy(2.0) + self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) + + def test_exception_run1(self): + # We expect exactly one input. + cond = SmallerThan(3.0) + true = ReturnConst(2.0) + false = ReturnConst(5.0) + + def init_and_run(layer, xs): + layer.init(shapes.signature(xs)) + layer(xs) + + # It will pass with one input. + xs = np.array(4.0) + layer = tl.Cond(cond, true, false) + init_and_run(layer, xs) + # It will fail with zero or two inputs. + for xs in ((), (np.array(4.0), np.array([4.0, 12.0]))): + layer = tl.Cond(cond, true, false) + # pylint: disable=cell-var-from-loop + self.assertRaises(Exception, lambda: init_and_run(layer, xs)) + + def test_exception_run2(self): + # We expect exactly two inputs. + cond = SmallerThan(3.0) + true = DivideBy(2.0) + false = DivideBy(5.0) + + def init_and_run(layer, xs): + layer.init(shapes.signature(xs)) + layer(xs) + + # It will pass with two inputs. + xs = (np.array(4.0), np.array([4.0, 12.0])) + layer = tl.Cond(cond, true, false) + init_and_run(layer, xs) + # It will fail with zero or one input. 
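+        # Note: `(np.array(4.0))` below is a bare array, not a 1-tuple (there
+        # is no trailing comma), so it reaches the layer as a single input,
+        # which is exactly the failure case this loop exercises.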
+        for xs in ((), (np.array(4.0))):
+            # pylint: disable=cell-var-from-loop
+            self.assertRaises(Exception, lambda: init_and_run(layer, xs))
+
+    def test_weights_and_state(self):
+        cond = SmallerThan(3.0)
+        true = tl.Dense(5)
+        false = tl.Dense(5)
+        different = tl.Dense(5)
+        layer = tl.Cond(cond, true, false)
+        xs = (np.array(2.0), np.array([0.0, 1.0, 2.0]))
+        layer.init(shapes.signature(xs))
+
+        # weights
+        self.assertEqual(
+            as_list(layer.weights), as_list((cond.weights, true.weights, false.weights))
+        )
+        self.assertNotEqual(as_list(true.weights), as_list(false.weights))
+        self.assertNotEqual(as_list(true.weights), as_list(different.weights))
+
+        false.weights = true.weights
+        self.assertEqual(
+            as_list(layer.weights), as_list((cond.weights, true.weights, true.weights))
+        )
+
+        layer.weights = (cond.weights, true.weights, different.weights)
+        self.assertEqual(
+            as_list(layer.weights),
+            as_list((cond.weights, true.weights, different.weights)),
+        )
+        # state
+        self.assertEqual(
+            as_list(layer.state), as_list((cond.state, true.state, false.state))
+        )
+        # Just check that simple assignments (the setter from base.Layer) work
+        # correctly with Cond.init_weights_and_state; all states are empty, so
+        # there is no point in checking equality.
+        false.state = true.state
+        layer.state = (cond.state, true.state, different.state)
+
+
+class BatchLeadingAxesTest(absltest.TestCase):
+    def _Id3Dim(self):  # pylint: disable=invalid-name
+        del self
+
+        def f(x):
+            assert len(x.shape) == 3
+            return x
+
+        return tl.Fn("Id3Dim", f, n_out=1)
+
+    def test_2axes(self):
+        layer = tl.BatchLeadingAxes(self._Id3Dim(), n_last_axes_to_keep=2)
+        ys = layer(np.zeros((3, 4, 5)))
+        self.assertEqual(ys.shape, (3, 4, 5))
+        ys = layer(np.zeros((2, 3, 4, 5)))
+        self.assertEqual(ys.shape, (2, 3, 4, 5))
+        ys = layer(np.zeros((1, 2, 3, 4, 5)))
+        self.assertEqual(ys.shape, (1, 2, 3, 4, 5))
+
+
+class BidirectionalTest(absltest.TestCase):
+    def test_dimensionality(self):
+        x = np.ones((2, 3, 8))
+        layer = tl.Bidirectional(tl.GRU(n_units=8))
+        input_signature = shapes.signature(x)
+        _, _ = layer.init(input_signature)
+        yhat = layer(x)
+
+        self.assertEqual(yhat.shape, (2, 3, 8 + 8))
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/layers/convolution_test.py b/tests/layers/convolution_test.py
new file mode 100644
index 000000000..3feb89807
--- /dev/null
+++ b/tests/layers/convolution_test.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for convolution layers.""" + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl + +from trax.utils import shapes + + +class ConvolutionTest(absltest.TestCase): + def test_call(self): + layer = tl.Conv(30, (3, 3)) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 3, 3, 30)) + + def test_use_bias_true(self): + layer = tl.Conv(30, (3, 3), use_bias=True) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 3, 3, 30)) + + self.assertIsInstance(layer.weights, tuple) + self.assertLen(layer.weights, 2) + self.assertEqual(layer.weights[0].shape, (3, 3, 20, 30)) + self.assertEqual(layer.weights[1].shape, (30,)) + + def test_use_bias_false(self): + layer = tl.Conv(30, (3, 3), use_bias=False) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 3, 3, 30)) + # With use_bias=False, layer.weights is just 'w' and there is no 'b'. + self.assertEqual(layer.weights.shape, (3, 3, 20, 30)) + + def test_call_rebatch(self): + layer = tl.Conv(30, (3, 3)) + x = np.ones((2, 9, 5, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (2, 9, 3, 3, 30)) + + +class CausalConvolutionTest(absltest.TestCase): + def test_causal_conv(self): + layer = tl.CausalConv(filters=30, kernel_width=3) + x = np.ones((9, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 5, 30)) + + # TODO(ddohan): How to test for causality? Gradient check between positions? + + def test_causal_conv_use_bias_false(self): + layer = tl.CausalConv(filters=30, kernel_width=3, use_bias=False) + x = np.ones((9, 5, 20)) + layer.init(shapes.signature(x)) + + y = layer(x) + self.assertEqual(y.shape, (9, 5, 30)) + + self.assertEqual(layer.weights.shape, (3, 20, 30)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/core_test.py b/tests/layers/core_test.py new file mode 100644 index 000000000..b3528c3eb --- /dev/null +++ b/tests/layers/core_test.py @@ -0,0 +1,478 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for core layers.""" + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl +import trax.layers.initializers as init + +from trax.fastmath import numpy as jnp +from trax.utils import shapes + + +class DenseTest(absltest.TestCase): + """Test Dense layer per se and as a key example of trainable layers.""" + + def test_call_before_init_raises_error(self): + layer = tl.Dense(5) + x = np.array([1, 2, 3]) + + # Without init, layer lacks the weights it needs for forward computation. + with self.assertRaises(tl.LayerError): + _ = layer(x) + + def test_call_uses_and_caches_supplied_weights(self): + layer = tl.Dense(4) + x = np.array([2, 3]) + + # Weights from random initialization are cached in the layer. 
+        _, _ = layer.init(shapes.signature(x))
+        w_init, b_init = layer.weights
+
+        # Call the layer with externally specified weights.
+        w = np.array([[10000, 20000, 30000, 40000], [100, 200, 100, 200]])
+        b = np.array([9, 8, 7, 6])
+        y = layer(x, weights=(w, b))
+
+        # Using the weights keyword arg overrides any previously cached weights ...
+        self.assertEqual(y.tolist(), [20309, 40608, 60307, 80606])
+        self.assertNotEqual(w.tolist(), w_init.tolist())
+        self.assertNotEqual(b.tolist(), b_init.tolist())
+
+        # ... and does not overwrite the old weights.
+        w_cached, b_cached = layer.weights
+        self.assertNotEqual(w.tolist(), w_cached.tolist())
+        self.assertNotEqual(b.tolist(), b_cached.tolist())
+
+    def test_separate_instances_have_separate_weights(self):
+        # Two dense layer instances: each will get its own initial weights (w, b).
+        model = tl.Serial(tl.Dense(5), tl.Dense(5))
+
+        sample_input = np.array([1, 2, 3, 4, 5])
+        _, _ = model.init(shapes.signature(sample_input))
+        weights_0 = model.sublayers[0].weights
+        weights_1 = model.sublayers[1].weights
+
+        w0, b0 = weights_0
+        w1, b1 = weights_1
+        self.assertNotEqual(w0.tolist(), w1.tolist())
+        self.assertNotEqual(b0.tolist(), b1.tolist())
+
+    def test_shared_instance_means_shared_weights(self):
+        # Same dense layer instance in two places --> shared weights.
+        layer = tl.Dense(5)
+        model = tl.Serial(layer, layer)
+        sample_input = np.array([1, 2, 3, 4, 5])
+        weights, _ = model.init(shapes.signature(sample_input))
+        self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE)
+
+    def test_call_no_bias(self):
+        layer = tl.Dense(4, use_bias=False)
+        x = np.array([2, 5, 3])
+        _, _ = layer.init(shapes.signature(x))
+
+        w = np.array([[100, 200, 300, 400], [10, 10, 10, 10], [1, 2, 1, 2]])
+        y = layer(x, weights=w)
+        self.assertEqual(y.tolist(), [253, 456, 653, 856])
+
+    def test_new_weights_use_bias(self):
+        layer = tl.Dense(4)
+        x = np.array([1, 2])
+        _, _ = layer.init(shapes.signature(x))
+        self.assertLen(layer.weights, 2)
+        self.assertEqual(layer.weights[0].shape, (2, 4))
+        self.assertEqual(layer.weights[1].shape, (4,))
+
+    def test_new_weights_no_bias(self):
+        layer = tl.Dense(4, use_bias=False)
+        x = np.array([1, 2])
+        _, _ = layer.init(shapes.signature(x))
+        self.assertEqual(layer.weights.shape, (2, 4))
+
+    def test_init_twice_weights_same_shape(self):
+        layer = tl.Dense(4, use_bias=False)
+        x = np.array([1, 2])
+        w1, _ = layer.init(shapes.signature(x))
+        w2, _ = layer.init(shapes.signature(x))
+        self.assertEqual(w1.shape, (2, 4))
+        self.assertEqual(w2.shape, (2, 4))
+
+    def test_save_to_file_and_init_to_file(self):
+        layer1 = tl.Dense(4, use_bias=False)
+        layer2 = tl.Dense(4, use_bias=False)
+        x = np.array([1, 2])
+        w1, _ = layer1.init(shapes.signature(x))
+        layer1.save_to_file("/tmp/dense_weights", input_signature=shapes.signature(x))
+        w2, _ = layer2.init_from_file("/tmp/dense_weights")
+        self.assertEqual(w1.shape, (2, 4))
+        self.assertEqual(w2.shape, (2, 4))
+        self.assertEqual(w1.tolist(), w2.tolist())
+
+
+class EmbeddingTest(absltest.TestCase):
+    def test_forward(self):
+        layer = tl.Embedding(10, 3)  # vocab_size=10, d_feature=3
+        _, _ = layer.init(None)  # Embedding init doesn't use input signature.
+        x = np.array([2, 3, 5, 3, 2])
+        y = layer(x)
+        self.assertEqual(y.shape, (5, 3))
+
+        # For distinct in-domain token IDs, resulting vectors should be distinct.
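+        # (With random-normal initialization, two rows of the embedding
+        # table coinciding is vanishingly unlikely.)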
+ self.assertNotEqual(y[0].tolist(), y[1].tolist()) + self.assertNotEqual(y[0].tolist(), y[2].tolist()) + self.assertNotEqual(y[1].tolist(), y[2].tolist()) + + # For repeats of a token id, resulting vectors should match. + self.assertEqual(y[0].tolist(), y[4].tolist()) + self.assertEqual(y[1].tolist(), y[3].tolist()) + + def test_negative_inputs_clip_to_zero(self): + layer = tl.Embedding(10, 3) + _, _ = layer.init(None) + x = np.array([0, 2, 3, -2, -3]) + y = layer(x) + self.assertNotEqual(y[0].tolist(), y[1].tolist()) + self.assertNotEqual(y[0].tolist(), y[2].tolist()) + self.assertEqual(y[0].tolist(), y[3].tolist()) + self.assertEqual(y[0].tolist(), y[4].tolist()) + + def test_large_inputs_clip_to_upper_bound(self): + layer = tl.Embedding(10, 3) + _, _ = layer.init(None) + x = np.array([2, 3, 9, 10, 20]) + y = layer(x) + + # vocab_size of 10 means max valid token id is 9. + self.assertNotEqual(y[2].tolist(), y[0].tolist()) + self.assertNotEqual(y[2].tolist(), y[1].tolist()) + self.assertEqual(y[2].tolist(), y[3].tolist()) + self.assertEqual(y[2].tolist(), y[4].tolist()) + + def test_new_weights(self): + layer = tl.Embedding(20, 5) + _, _ = layer.init(None) + + # Default weights sampled from Gaussian, mu = 0, sigma = 1. + w = layer.weights + self.assertEqual(w.shape, (20, 5)) + self.assertLess(np.abs(np.mean(w)), 0.4) # .4 is 4 sigma deviation + + def test_explicit_kernel_initializer(self): + def f(shape, rng): + del rng + n_elements = np.prod(shape) + return np.arange(n_elements).reshape(shape) + + layer = tl.Embedding(5, 2, kernel_initializer=f) + _, _ = layer.init(None) + x = np.array([0, 1, 2, 3, 4]) + y = layer(x) + self.assertEqual(y.tolist(), [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) + + +class DropoutTest(absltest.TestCase): + def test_call_in_train_mode(self): + layer = tl.Dropout(rate=0.1, mode="train") + x = np.ones((2, 5, 1000)) # 10,000 values + y = layer(x) + self.assertEqual(y.shape, (2, 5, 1000)) + + # Dropout is stochastic; test it nonflakily at 4 sigmas (.99994). 
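+        # The number of surviving activations is Binomial(N=10000, q=0.9);
+        # its mean and standard deviation are computed below.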
+        n_remaining = np.count_nonzero(y)
+        mu_of_remaining = 9000  # N * q: 10000 * .9
+        sigma_of_remaining = 30  # sqrt(N * p * q): sqrt(10000 * .1 * .9)
+        self.assertLess(np.abs(n_remaining - mu_of_remaining), 4 * sigma_of_remaining)
+
+    def test_call_in_eval_mode_does_no_dropout(self):
+        layer = tl.Dropout(rate=0.1, mode="eval")
+        x = np.ones((2, 5, 1000))
+        y = layer(x)
+        self.assertEqual(np.count_nonzero(y), 10_000)
+
+    def test_new_weights(self):
+        layer = tl.Dropout(rate=0.1, mode="train")
+        layer.init(None)
+        self.assertEmpty(layer.weights)
+
+
+class WeightsTest(absltest.TestCase):
+    """Test Weights layer."""
+
+    def test_simple(self):
+        layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32))
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.tolist(), 0.0)
+
+    def test_shape(self):
+        layer = tl.Weights(init.RandomNormalInitializer(), (5, 10, 3))
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.shape, (5, 10, 3))
+
+    def test_simple_custom_initializer(self):
+        layer = tl.Weights(init.RandomNormalInitializer())
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.shape, ())
+        self.assertNotEqual(y.tolist(), 0.0)
+
+    def test_custom_initializer_shape(self):
+        layer = tl.Weights(
+            lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), (2, 2)
+        )
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.tolist(), [[0.0, 0.0], [0.0, 0.0]])
+
+        layer = tl.Weights(init.RandomNormalInitializer(), (2, 2))
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.shape, (2, 2))
+        self.assertNotEqual(y.tolist(), [[0.0, 0.0], [0.0, 0.0]])
+
+
+class SummaryScalarTest(absltest.TestCase):
+    def test_passes(self):
+        layer = tl.SummaryScalar("test")
+        x = np.array([[3.0, 5.0], [2.0, 6.0]])
+        y = layer(x)
+        self.assertEqual(y.tolist(), [[3.0, 5.0], [2.0, 6.0]])
+        self.assertEqual(layer.state["summary_test"].tolist(), 4.0)
+
+
+class RandomUniformTest(absltest.TestCase):
+    """Test RandomUniform layer."""
+
+    def test_simple(self):
+        layer = tl.RandomUniform()
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.shape, ())
+        self.assertBetween(y, 0.0, 1.0)
+
+    def test_shape(self):
+        layer = tl.RandomUniform(shape=(5, 10, 3))
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.shape, (5, 10, 3))
+
+    def test_simple_range(self):
+        layer = tl.RandomUniform(1.0, 2.0, shape=(1000,))
+        layer.init(())
+        y = layer(())
+        self.assertEqual(y.shape, (1000,))
+        self.assertBetween(min(y.tolist()), 1.0, 2.0)
+        self.assertBetween(max(y.tolist()), 1.0, 2.0)
+        self.assertBetween(1.5, min(y.tolist()), max(y.tolist()))
+
+
+class LocallyConnected1dTest(absltest.TestCase):
+    def test_shape_kernel1(self):
+        for padding in ["WRAP", "SAME", "VALID"]:
+            layer = tl.LocallyConnected1d(6, 1, padding=padding)
+            x = np.array([[0, 1], [2, 3], [4, 5]])
+            layer.init(shapes.signature(x))
+            y = layer(x)
+            self.assertEqual(y.shape, (3, 6))
+
+    def test_shape_kernel3(self):
+        for padding in ["WRAP", "SAME"]:
+            layer = tl.LocallyConnected1d(6, 3, padding=padding)
+            x = np.array([[0, 1], [2, 3], [4, 5]])
+            layer.init(shapes.signature(x))
+            y = layer(x)
+            self.assertEqual(y.shape, (3, 6))
+
+        for padding in ["VALID"]:
+            layer = tl.LocallyConnected1d(6, 3, padding=padding)
+            x = np.array([[0, 1], [2, 3], [4, 5]])
+            layer.init(shapes.signature(x))
+            y = layer(x)
+            self.assertEqual(y.shape, (1, 6))
+
+
+class FlattenTest(absltest.TestCase):
+    def test_keep_default(self):
+        layer = tl.Flatten()
+        x = np.ones((1, 2, 3, 4, 5))
+        y = layer(x)
+        # Default is to leave the first axis untouched and flatten the rest.
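+        # For example, (1, 2, 3, 4, 5) flattens to (1, 2 * 3 * 4 * 5) = (1, 120).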
+ self.assertEqual(y.shape, (1, 2 * 3 * 4 * 5)) + + def test_keep_3(self): + layer = tl.Flatten(n_axes_to_keep=3) + x = np.ones((1, 2, 3, 4, 5)) + y = layer(x) + self.assertEqual(y.shape, (1, 2, 3, 4 * 5)) + + def test_keep_max_number(self): + layer = tl.Flatten(n_axes_to_keep=4) + x = np.ones((1, 2, 3, 4, 5)) + y = layer(x) + self.assertEqual(y.shape, (1, 2, 3, 4, 5)) + + def test_keep_too_many_raises_error(self): + layer = tl.Flatten(n_axes_to_keep=5) + with self.assertRaises(tl.LayerError): + x = np.ones((1, 2, 3, 4, 5)) + _ = layer(x) + + +class LogSoftmaxTest(absltest.TestCase): + def test_call(self): + layer = tl.LogSoftmax() + x = np.array([[2.0, 1.0, -10.0], [1.0, 1.0, -10.0]]) + y = layer(x) + np.testing.assert_allclose( + y, [[-0.313, -1.313, -12.313], [-0.693, -0.693, -11.693]], atol=0.001 + ) + + +class SoftmaxTest(absltest.TestCase): + def test_call(self): + layer = tl.Softmax() + x = np.array([[2.0, 1.0, -10.0], [1.0, 1.0, -10.0]]) + y = layer(x) + np.testing.assert_allclose( + y, [[0.731, 0.269, 0.00000449], [0.500, 0.500, 0.00000835]], atol=0.001 + ) + + +class CoreFunctionsTest(absltest.TestCase): + def test_one_hot(self): + targets = np.array([2, 0, 1]) + n_categories = 5 + target_distributions = tl.one_hot(targets, n_categories) + self.assertEqual( + tl.to_list(target_distributions), + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ], + ) + + def test_log_softmax(self): + activations = np.array([[2.0, 1.0, -10.0], [1.0, 1.0, -10.0]]) + log_probabilities = tl.log_softmax(activations) + np.testing.assert_allclose( + log_probabilities, + [[-0.313, -1.313, -12.313], [-0.693, -0.693, -11.693]], + atol=0.001, + ) + + def test_log_gaussian_pdf(self): + x = np.zeros((2, 5), dtype=np.float32) + mu = x + dsigma = np.eye(5)[None, :, :] + sigma = np.concatenate([dsigma, 2 * dsigma], axis=0) + prob = tl.log_gaussian_pdf(x, mu, sigma) + self.assertEqual(prob.shape, (2,)) + self.assertEqual(int(prob[0]), -4) + self.assertEqual(int(prob[1]), -6) + + def test_log_gaussian_diag_pdf(self): + x = np.zeros((2, 5), dtype=np.float32) + mu = x + sigma = np.ones((5,))[None, :] + sigma = np.concatenate([sigma, 2 * sigma], axis=0) + prob = tl.log_gaussian_diag_pdf(x, mu, sigma) + self.assertEqual(prob.shape, (2,)) + self.assertEqual(int(prob[0]), -4) + self.assertEqual(int(prob[1]), -6) + + +class StopGradientTest(absltest.TestCase): + def test_passes(self): + layer = tl.StopGradient() + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2, 2)) + self.assertEqual(y.tolist(), [[3.0, 5.0], [2.0, 6.0]]) + + +class MinMaxTest(absltest.TestCase): + def test_min(self): + layer = tl.Min() + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [3.0, 2.0]) + + layer = tl.Min(axis=0) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [2.0, 5.0]) + + layer = tl.Min(axis=None) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, ()) + self.assertEqual(y.tolist(), 2.0) + + layer = tl.Min(keepdims=True) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2, 1)) + self.assertEqual(y.tolist(), [[3.0], [2.0]]) + + def test_max(self): + layer = tl.Max() + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [5.0, 6.0]) + + layer = tl.Max(axis=0) + x = np.array([[3.0, 5.0], [2.0, 
6.0]]) + y = layer(x) + self.assertEqual(y.shape, (2,)) + self.assertEqual(y.tolist(), [3.0, 6.0]) + + layer = tl.Max(axis=None) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, ()) + self.assertEqual(y.tolist(), 6.0) + + layer = tl.Max(axis=0, keepdims=True) + x = np.array([[3.0, 5.0], [2.0, 6.0]]) + y = layer(x) + self.assertEqual(y.shape, (1, 2)) + self.assertEqual(y.tolist(), [[3.0, 6.0]]) + + +class ClassifierLayersTest(absltest.TestCase): + def test_threshold_to_binary(self): + layer = tl.ThresholdToBinary() + x = np.array([0.30, 0.49, 0.50, 0.51, 0.70]) + y = layer(x) + self.assertEqual(y.tolist(), [0, 0, 0, 1, 1]) + + def test_arg_max(self): + layer = tl.ArgMax() + x = np.array([[0.10, 0.90, 0.20, 0.80], [0.22, 0.88, 0.11, 0.99]]) + y = layer(x) + self.assertEqual(y.tolist(), [1, 3]) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/deconvolution_test.py b/tests/layers/deconvolution_test.py similarity index 72% rename from trax/layers/deconvolution_test.py rename to tests/layers/deconvolution_test.py index f1111f21e..ff6ddb590 100644 --- a/trax/layers/deconvolution_test.py +++ b/tests/layers/deconvolution_test.py @@ -15,23 +15,24 @@ """Tests for Deconvolution layers.""" -from absl.testing import absltest import numpy as np -from trax import shapes +from absl.testing import absltest + import trax.layers as tl +from trax.utils import shapes -class ConvTransposeTest(absltest.TestCase): - def test_call(self): - layer = tl.ConvTranspose(30, (3, 3)) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) +class ConvTransposeTest(absltest.TestCase): + def test_call(self): + layer = tl.ConvTranspose(30, (3, 3)) + x = np.ones((9, 5, 5, 20)) + layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (9, 7, 7, 30)) + y = layer(x) + self.assertEqual(y.shape, (9, 7, 7, 30)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/initializers_test.py b/tests/layers/initializers_test.py new file mode 100644 index 000000000..34ab5808d --- /dev/null +++ b/tests/layers/initializers_test.py @@ -0,0 +1,97 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for initializers.""" + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl + +from trax import fastmath +from trax.utils import test_utils + +INPUT_SHAPE = (5, 7, 20) + + +def rng(): # Can't be a constant, because JAX has to init itself in main first. 
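+    # A fixed seed keeps the initializer outputs below deterministic across runs.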
+ return fastmath.random.get_prng(0) + + +class InitializersTest(absltest.TestCase): + def test_random_normal(self): + f = tl.RandomNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_lecun_uniform(self): + f = tl.LeCunUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_random_uniform(self): + f = tl.RandomUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_glorot_normal(self): + f = tl.GlorotNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_glorot_uniform(self): + f = tl.GlorotUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_lecun_normal(self): + f = tl.LeCunNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_kaiming_normal(self): + f = tl.KaimingNormalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_kaiming_uniform(self): + f = tl.KaimingUniformInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_orthogonal(self): + f = tl.OrthogonalInitializer() + init_value = f(INPUT_SHAPE, rng()) + self.assertEqual(init_value.shape, INPUT_SHAPE) + + def test_from_file(self): + params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]) + # `create_tempfile` needs access to --test_tmpdir, however in the OSS world + # pytest doesn't run `absltest.main`, so we need to manually parse the flags + test_utils.ensure_flag("test_tmpdir") + filename = self.create_tempfile("params.npy").full_path + with open(filename, "wb") as f: + np.save(f, params) + f = tl.InitializerFromFile(filename) + init_value = f(params.shape, rng()) + np.testing.assert_almost_equal( + tl.to_list(init_value), tl.to_list(params), decimal=4 + ) + # self.assertEqual('%s' % init_value, '%s' % params) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/metrics_test.py b/tests/layers/metrics_test.py new file mode 100644 index 000000000..36b41d622 --- /dev/null +++ b/tests/layers/metrics_test.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for metrics layers.""" + +import numpy as np + +from absl.testing import absltest + +import trax.layers as tl + + +class MetricsTest(absltest.TestCase): + def test_category_accuracy(self): + layer = tl.CategoryAccuracy() + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets]) + self.assertEqual(accuracy, 1 / 3) + + def test_weighted_category_accuracy_even_weights(self): + layer = tl.WeightedCategoryAccuracy() + weights = np.array([1.0, 1.0, 1.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1 / 3) + + def test_weighted_category_accuracy_uneven_weights(self): + layer = tl.WeightedCategoryAccuracy() + weights = np.array([1.0, 5.0, 2.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.625) + + def test_category_cross_entropy(self): + layer = tl.CategoryCrossEntropy() + targets = np.array([0, 1]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.001, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.665, places=3) + + # First item near perfect, second item more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.333, places=3) + + def test_category_cross_entropy_with_label_smoothing(self): + epsilon = 0.01 + layer = tl.CategoryCrossEntropy(label_smoothing=epsilon) + targets = np.array([0, 1]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.069, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.682, places=3) + + # First item near perfect, second item more right than wrong. 
+ model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.375, places=3) + + def test_weighted_category_cross_entropy(self): + layer = tl.WeightedCategoryCrossEntropy() + targets = np.array([0, 1]) + weights = np.array([30, 10]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.001, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.665, places=3) + + # First item (with 75% weight) near perfect, second more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.167, places=3) + + def test_weighted_category_cross_entropy_with_label_smoothing(self): + epsilon = 0.01 + layer = tl.WeightedCategoryCrossEntropy(label_smoothing=epsilon) + targets = np.array([0, 1]) + weights = np.array([30, 10]) + + # Near-perfect prediction (for both items in batch). + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 9.0, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.069, places=3) + + # More right than wrong (for both items in batch). + model_outputs = np.array([[2.2, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.682, places=3) + + # First item (with 75% weight) near perfect, second more right than wrong. + model_outputs = np.array([[9.0, 2.0, 0.0, -2.0], [2.0, 2.2, 0.0, -2.0]]) + loss = layer([model_outputs, targets, weights]) + self.assertAlmostEqual(loss, 0.222, places=3) + + def test_masked_sequence_accuracy(self): + layer = tl.MaskedSequenceAccuracy() + targets = np.array([[0, 1, 0, 0], [1, 0, 1, 0]]) + weights = np.array([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0]]) + + # Model gets both sequences right; output in final position would give + # wrong category but is ignored. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.35, 0.65]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.35, 0.65]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + # Model gets the first element of the first sequence barely wrong. + model_outputs = np.array( + [ + [[0.45, 0.55], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.5) + + # Model gets second-to-last element of each sequence barely wrong. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.48, 0.52], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.51, 0.49], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.0) + + def test_binary_cross_entropy(self): + layer = tl.BinaryCrossEntropy() + targets = np.array([1, 1, 0, 0, 0]) + + # Near-perfect prediction for all five items in batch. + model_outputs = np.array([9.0, 9.0, -9.0, -9.0, -9.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 0.000123, places=6) + + # More right than wrong for all five items in batch. 
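+        # A correct-side logit of magnitude 1 gives sigmoid(1) ~= 0.731, i.e. loss ~= -log(0.731) ~= 0.313.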
+ model_outputs = np.array([1.0, 1.0, -1.0, -1.0, -1.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 0.313, places=3) + + # Near-perfect for 2, more right than wrong for 3. + model_outputs = np.array([9.0, 1.0, -1.0, -1.0, -9.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 0.188, places=3) + + # More wrong than right for all five. + model_outputs = np.array([-1.0, -1.0, 1.0, 1.0, 1.0]) + metric_output = layer([model_outputs, targets]) + self.assertAlmostEqual(metric_output, 1.313, places=3) + + def test_accuracy_even_weights(self): + layer = tl.Accuracy() + weights = np.array([1.0, 1.0, 1.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1 / 3) + + def test_accuracy_uneven_weights(self): + layer = tl.Accuracy() + weights = np.array([1.0, 5.0, 2.0]) + targets = np.array([0, 1, 2]) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.1, 0.7, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0], [0.2, 0.7, 0.1, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.625) + + model_outputs = np.array( + [[0.7, 0.2, 0.1, 0.0], [0.7, 0.2, 0.1, 0.0], [0.7, 0.2, 0.1, 0.0]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.125) + + def test_accuracy_binary_classifier(self): + layer = tl.Accuracy(classifier=tl.ThresholdToBinary()) + targets = np.array([[0, 0, 1, 1], [1, 1, 1, 0]]) + weights = np.ones_like(targets) + + model_outputs = np.array( + [[0.499, 0.500, 0.501, 0.502], [0.503, 0.502, 0.501, 0.500]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + model_outputs = np.array( + [[0.498, 0.499, 0.500, 0.501], [0.502, 0.501, 0.500, 0.499]] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.75) + + def test_sequence_accuracy_weights_all_ones(self): + layer = tl.SequenceAccuracy() + targets = np.array([[0, 1, 0, 1], [1, 0, 1, 1]]) + weights = np.ones_like(targets) + + # Model gets both sequences right; for each position in each sequence, the + # category (integer ID) selected by argmax matches the target category. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.4, 0.6]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.4, 0.6]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + # Model gets the first element of the first sequence barely wrong. + model_outputs = np.array( + [ + [[0.45, 0.55], [0.2, 0.8], [0.7, 0.3], [0.4, 0.6]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.4, 0.6]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.5) + + # Model gets the last element of each sequence barely wrong. 
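+        # Sequence accuracy requires every position to match, so a single miss zeroes out that sequence.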
+ model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.55, 0.45]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.52, 0.48]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.0) + + def test_sequence_accuracy_last_position_zero_weight(self): + layer = tl.SequenceAccuracy() + targets = np.array([[0, 1, 0, 0], [1, 0, 1, 0]]) + weights = np.array([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0]]) + + # Model gets both sequences right; output in final position would give + # wrong category but is ignored. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.35, 0.65]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.35, 0.65]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 1.0) + + # Model gets the first element of the first sequence barely wrong. + model_outputs = np.array( + [ + [[0.45, 0.55], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.1, 0.9], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.5) + + # Model gets second-to-last element of each sequence barely wrong. + model_outputs = np.array( + [ + [[0.9, 0.1], [0.2, 0.8], [0.48, 0.52], [0.6, 0.4]], + [[0.3, 0.7], [0.8, 0.2], [0.51, 0.49], [0.6, 0.4]], + ] + ) + accuracy = layer([model_outputs, targets, weights]) + self.assertEqual(accuracy, 0.0) + + def test_binary_cross_entropy_loss(self): + # TODO(jonni): Clarify desired semantics/naming, then test it. + layer = tl.BinaryCrossEntropyLoss() + xs = [np.ones((9, 1)), np.ones((9, 1)), np.ones((9, 1))] + y = layer(xs) + self.assertEqual(y.shape, ()) + + def test_cross_entropy_loss(self): + # TODO(jonni): Clarify desired semantics/naming, then test it. + layer = tl.CrossEntropyLoss() + xs = [np.ones((9, 4, 4, 20)), np.ones((9, 4, 4)), np.ones((9, 4, 4))] + y = layer(xs) + self.assertEqual(y.shape, ()) + + def test_l2_loss(self): + layer = tl.L2Loss() + + model_outputs = np.array([[1.0, 1.0], [1.0, 1.0]]) + targets = np.array([[1.0, 1.0], [1.0, 0.0]]) + weights = np.array([[1.0, 1.0], [1.0, 0.0]]) + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.0) + + weights = np.array([[1.0, 0.0], [0.0, 1.0]]) + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.5) + + def test_smooth_l1_loss(self): + layer = tl.SmoothL1Loss() + + model_outputs = np.array([[1.0, 1.0], [1.0, 2.0]]) + targets = np.array([[1.0, 1.0], [1.0, 0.0]]) + l1_dist = 2 + + weights = np.array([[1.0, 1.0], [1.0, 0.0]]) + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.0) + + weights = np.array([[1.0, 0.0], [0.0, 1.0]]) + sum_weights = 2 + + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, (l1_dist - 0.5) / sum_weights) + + model_outputs = np.array([[1.0, 1.0], [1.0, 1.5]]) + targets = np.array([[1.0, 1.0], [1.0, 1.0]]) + l1_dist = 0.5 + loss = layer([model_outputs, targets, weights]) + np.testing.assert_allclose(loss, 0.5 * l1_dist**2 / sum_weights) + + def test_macro_averaged_f_score(self): + # predictions = [1, 1, 2, 1, 1]. + model_outputs = np.array( + [[0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]] + ) + targets = np.array([1, 2, 2, 3, 1]) + # Category indices starting with `0`. + layer = tl.MacroAveragedFScore() + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.333, places=3) + # Excluding the padding index `0`. 
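+        # Over classes {1, 2, 3} the per-class F1 scores are 2/3, 2/3 and 0, so the macro average is 4/9.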
+ layer = tl.MacroAveragedFScore(initial_category_index=1) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.444, places=3) + + def test_weighted_f_score(self): + # predictions = [1, 1, 2, 1, 1]. + model_outputs = np.array( + [[0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]] + ) + targets = np.array([1, 2, 2, 3, 1]) + # Category indices starting with `0`. + layer = tl.WeightedFScore() + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.533, places=3) + # Excluding the padding index `0`. + layer = tl.WeightedFScore(initial_category_index=1) + loss = layer([model_outputs, targets]) + self.assertAlmostEqual(loss, 0.533, places=3) + + def test_names(self): + layer = tl.L2Loss() + self.assertEqual("L2Loss_in3", str(layer)) + layer = tl.Accuracy() + self.assertEqual("Accuracy_in3", str(layer)) + layer = tl.SequenceAccuracy() + self.assertEqual("SequenceAccuracy_in3", str(layer)) + layer = tl.BinaryCrossEntropyLoss() + self.assertEqual("BinaryCrossEntropyLoss_in3", str(layer)) + layer = tl.CrossEntropyLoss() + self.assertEqual("CrossEntropyLoss_in3", str(layer)) + layer = tl.BinaryCrossEntropySum() + self.assertEqual("BinaryCrossEntropySum_in3", str(layer)) + layer = tl.CrossEntropySum() + self.assertEqual("CrossEntropySum_in3", str(layer)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/normalization_test.py b/tests/layers/normalization_test.py new file mode 100644 index 000000000..cfa9561e7 --- /dev/null +++ b/tests/layers/normalization_test.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for normalization layers.""" + +import numpy as np + +from absl.testing import absltest, parameterized + +import trax.layers as tl + +from trax import fastmath +from trax.utils import shapes + + +class BatchNormTest(parameterized.TestCase): + def test_forward_shape(self): + layer = tl.BatchNorm() + x = np.ones((30, 20, 70)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + @parameterized.named_parameters( + ("jax32", fastmath.Backend.JAX, np.float32), + ("tf32", fastmath.Backend.TFNP, np.float32), + ("tf64", fastmath.Backend.TFNP, np.float64), + ) + def test_forward_dtype(self, backend, dtype): + with fastmath.use_backend(backend): + layer = tl.BatchNorm() + x = np.ones((3, 2, 7)).astype(dtype) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.dtype, dtype) + + @parameterized.named_parameters( + ("momentum_999", 0.999), + ("momentum_900", 0.900), + ("momentum_800", 0.800), + ) + def test_forward(self, momentum): + layer = tl.BatchNorm(momentum=momentum) + x = np.array( + [ + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], + ] + ).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + running_mean, running_var, n_batches = layer.state + + fraction_old = momentum + fraction_new = 1.0 - momentum + mean_of_x = 11.5 # mean of range(24) + var_of_x = 47.9167 # variance of range(24) + np.testing.assert_allclose( + running_mean, 0.0 * fraction_old + mean_of_x * fraction_new + ) + np.testing.assert_allclose( + running_var, 1.0 * fraction_old + var_of_x * fraction_new, rtol=1e-6 + ) + self.assertEqual(n_batches, 1) + eps = 1e-5 + np.testing.assert_allclose( + y, (x - mean_of_x) / np.sqrt(var_of_x + eps), rtol=1e-6 + ) + + def test_new_weights_and_state(self): + layer = tl.BatchNorm() + x = np.ones((3, 2, 7)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + + running_mean, running_var, n_batches = layer.state + np.testing.assert_allclose(running_mean, 0.0) + np.testing.assert_allclose(running_var, 1.0) + self.assertEqual(n_batches, 0) + + +class LayerNormTest(parameterized.TestCase): + def test_forward_shape(self): + layer = tl.LayerNorm() + x = np.ones((3, 2, 7)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + @parameterized.named_parameters( + ("jax32", fastmath.Backend.JAX, np.float32), + ("tf32", fastmath.Backend.TFNP, np.float32), + ("tf64", fastmath.Backend.TFNP, np.float64), + ) + def test_forward_dtype(self, backend, dtype): + with fastmath.use_backend(backend): + layer = tl.LayerNorm() + x = np.ones((3, 2, 7)).astype(dtype) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.dtype, dtype) + + +class FilterResponseNormTest(parameterized.TestCase): + @parameterized.named_parameters( + ("learn_epsilon_false", False), + ("learn_epsilon_true", True), + ) + def test_forward_shape(self, learn_epsilon): + layer = tl.FilterResponseNorm(learn_epsilon=learn_epsilon) + + B, H, W, C = 64, 5, 7, 3 # pylint: disable=invalid-name + x = np.ones((B, H, W, C)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/pooling_test.py b/tests/layers/pooling_test.py new file mode 100644 index 000000000..262d8ec79 --- /dev/null +++ b/tests/layers/pooling_test.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# 
Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for pooling layers."""
+
+import numpy as np
+
+from absl.testing import absltest
+
+import trax.layers as tl
+
+
+class MaxPoolTest(absltest.TestCase):
+    def test_forward_shape(self):
+        layer = tl.MaxPool(pool_size=(2, 2), strides=(1, 2))
+        x = np.ones((11, 6, 4, 17))
+        y = layer(x)
+        self.assertEqual(y.shape, (11, 5, 2, 17))
+
+    def test_forward(self):
+        layer = tl.MaxPool(pool_size=(2, 2), strides=(2, 2))
+        x = np.array(
+            [
+                [
+                    [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]],
+                    [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]],
+                ]
+            ]
+        )
+        y = layer(x)
+        self.assertEqual(tl.to_list(y), [[[[7, 5, 6], [70, 50, 60]]]])
+
+    def test_padding_default(self):
+        layer = tl.MaxPool(pool_size=(3,), strides=(3,))
+
+        # Discard incomplete window at end: [[3, 6], [4, 5]].
+        x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]])
+        y = layer(x)
+        self.assertEqual(tl.to_list(y), [[[2, 9]]])
+
+    def test_padding_same(self):
+        layer = tl.MaxPool(pool_size=(3,), strides=(3,), padding="SAME")
+
+        # One padding position needed; add at end.
+        x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]])
+        y = layer(x)
+        self.assertEqual(tl.to_list(y), [[[2, 9], [4, 6]]])
+
+        # Two padding positions needed; add one at end and one at start.
+        x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6]]])
+        y = layer(x)
+        self.assertEqual(tl.to_list(y), [[[1, 9], [3, 7]]])
+
+
+class SumPoolTest(absltest.TestCase):
+    def test_forward_shape(self):
+        layer = tl.SumPool(pool_size=(2, 2), strides=(1, 2))
+        x = np.ones((11, 6, 4, 17))
+        y = layer(x)
+        self.assertEqual(y.shape, (11, 5, 2, 17))
+
+    def test_forward(self):
+        layer = tl.SumPool(pool_size=(2, 2), strides=(2, 2))
+        x = np.array(
+            [
+                [
+                    [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]],
+                    [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]],
+                ]
+            ]
+        )
+        y = layer(x)
+        self.assertEqual(tl.to_list(y), [[[[16, 10, 14], [160, 100, 140]]]])
+
+    def test_padding_same(self):
+        layer = tl.SumPool(pool_size=(3,), strides=(3,), padding="SAME")
+
+        # One padding position needed; add at end.
+        x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]])
+        y = layer(x)
+        self.assertEqual(tl.to_list(y), [[[3, 24], [7, 11]]])
+
+        # Two padding positions needed; add one at end and one at start.
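+        # Windows become [pad, 0, 1] and [2, 3, pad]; padded slots contribute 0 to each sum.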
+ x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[1, 17], [5, 13]]]) + + +class AvgPoolTest(absltest.TestCase): + def test_forward_shape(self): + layer = tl.AvgPool(pool_size=(2, 2), strides=(1, 2)) + x = np.ones((11, 6, 4, 17)) + y = layer(x) + self.assertEqual(y.shape, (11, 5, 2, 17)) + + def test_forward(self): + layer = tl.AvgPool(pool_size=(2, 2), strides=(2, 2)) + x = np.array( + [ + [ + [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], + [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], + ] + ] + ) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[[4.0, 2.5, 3.5], [40, 25, 35]]]]) + + def test_padding_same(self): + layer = tl.AvgPool(pool_size=(3,), strides=(3,), padding="SAME") + + # One padding position needed; add at end. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6], [4, 5]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[1, 8], [3.5, 5.5]]]) + + # Two padding positions needed; add one at end and one at start. + x = np.array([[[0, 9], [1, 8], [2, 7], [3, 6]]]) + y = layer(x) + self.assertEqual(tl.to_list(y), [[[0.5, 8.5], [2.5, 6.5]]]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/research/efficient_attention_test.py b/tests/layers/research/efficient_attention_test.py new file mode 100644 index 000000000..ef64d5b20 --- /dev/null +++ b/tests/layers/research/efficient_attention_test.py @@ -0,0 +1,597 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.layers.research.efficient_attention.""" + +import jax +import numpy as np + +from absl.testing import parameterized +from tensorflow import test + +from trax import fastmath +from trax.fastmath import numpy as jnp +from trax.layers.research import efficient_attention +from trax.utils import shapes + + +class EfficientAttentionTest(test.TestCase, parameterized.TestCase): + def test_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.SelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_lsh_ff(self): + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHFF(d_ff=1024 * 8, n_buckets=[16, 8]) + x = np.ones((3, 7, 1024)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_self_attention_tf(self): + with fastmath.use_backend(fastmath.Backend.TFNP): + layer = efficient_attention.SelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_lsh_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHSelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + causal=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def _run_forward_and_backward(self, model, inp, weights, state): + def forward(inp, weights): + return model.pure_fn(inp, weights, state, rng=jax.random.PRNGKey(0)) + + out, vjpfun, new_state = jax.vjp(forward, inp, weights, has_aux=True) + inp_grad, weights_grad = vjpfun(fastmath.numpy.ones_like(inp)) + return out, new_state, inp_grad, weights_grad + + def _test_equivalence_to_reference_code( + self, model_cls, inp, input_signature, common_kwargs, *test_kwargs + ): + ref_model = model_cls(use_reference_code=True, **common_kwargs) + rng = fastmath.random.get_prng(123) + weights, state = ref_model.init(input_signature, rng) + + ref_all = self._run_forward_and_backward(ref_model, inp, weights, state) + ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all + + for kwargs in test_kwargs: + test_model = model_cls(**common_kwargs, **kwargs) + state = test_model.init(input_signature, rng)[1] + test_all = self._run_forward_and_backward(test_model, inp, weights, state) + test_out, test_state, test_inp_grad, test_weights_grad = test_all + + self.assertEqual(jax.tree_structure(ref_out), jax.tree_structure(test_out)) + self.assertEqual( + jax.tree_structure(ref_state), jax.tree_structure(test_state) + ) + self.assertEqual( + jax.tree_structure(ref_inp_grad), jax.tree_structure(test_inp_grad) + ) + self.assertEqual( + jax.tree_structure(ref_weights_grad), + jax.tree_structure(test_weights_grad), + ) + + check_close = lambda x, y: 
self.assertAllClose(x, y, rtol=2e-3, atol=2e-3) + fastmath.nested_map_multiarg(check_close, ref_out, test_out) + fastmath.nested_map_multiarg(check_close, ref_state, test_state) + fastmath.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad) + fastmath.nested_map_multiarg( + check_close, ref_weights_grad, test_weights_grad + ) + + def test_batching_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + common_kwargs = dict( + n_heads=6, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=5, + n_chunks_before=1, + n_chunks_after=0, + attention_dropout=0.2, + output_dropout=0.1, + mode="train", + ) + test_kwargs = [] + for n_parallel_heads in [1, 3, 6, 12]: + for use_python_loop in [True, False]: + test_kwargs.append( + dict( + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + ) + ) + + x = jax.random.uniform( + jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32 + ) + input_signature = shapes.signature(x) + self._test_equivalence_to_reference_code( + efficient_attention.SelfAttention, + x, + input_signature, + common_kwargs, + *test_kwargs, + ) + + def test_batching_lsh_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + common_kwargs = dict( + n_heads=6, + d_qk=7, + d_v=17, + causal=True, + chunk_len=5, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + attention_dropout=0.2, + output_dropout=0.1, + mode="train", + ) + test_kwargs = [] + for n_parallel_heads in [1, 3, 6, 12]: + for use_python_loop in [True, False]: + test_kwargs.append( + dict( + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + ) + ) + + x = jax.random.uniform( + jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32 + ) + input_signature = shapes.signature(x) + self._test_equivalence_to_reference_code( + efficient_attention.LSHSelfAttention, + x, + input_signature, + common_kwargs, + *test_kwargs, + ) + + def _test_fast_inference( + self, model_cls, x, input_signature, common_kwargs, *test_kwargs + ): + ref_model = model_cls(use_reference_code=True, mode="eval", **common_kwargs) + weights, state = ref_model.init(input_signature) + + ref_out, _ = ref_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0)) + + def get_slice(pytree, i): + def get_slice_for_val(x): + if isinstance(x, shapes.ShapeDtype): + return shapes.ShapeDtype( + shape=x.shape[:1] + (1,) + x.shape[2:], dtype=x.dtype + ) + else: + return x[:, i : i + 1] + + return jax.tree_map(get_slice_for_val, pytree) + + seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1] + + for kwargs in test_kwargs: + test_model = model_cls(mode="predict", **common_kwargs, **kwargs) + cur_state = test_model.init(get_slice(input_signature, 0))[1] + out = [] + for i in range(seqlen): + cur_out, cur_state = test_model.pure_fn( + get_slice(x, i), weights, cur_state, jax.random.PRNGKey(0) + ) + out.append(cur_out) + out = jnp.concatenate(out, axis=1) + + self.assertAllClose(out, ref_out, rtol=1e-3, atol=1e-3) + + def test_fast_inference_self_attention(self): + with fastmath.use_backend(fastmath.Backend.JAX): + common_kwargs = dict( + n_heads=6, + d_qk=7, + d_v=17, + share_qk=False, + causal=True, + chunk_len=5, + n_chunks_before=1, + n_chunks_after=0, + attention_dropout=0.0, + output_dropout=0.0, + ) + test_kwargs = [] + for n_parallel_heads in [1, 3, 6, 12]: + for use_python_loop in [True, False]: + test_kwargs.append( + dict( + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + ) + ) + + x = jax.random.uniform( + 
jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32 + ) + input_signature = shapes.signature(x) + self._test_fast_inference( + efficient_attention.SelfAttention, + x, + input_signature, + common_kwargs, + *test_kwargs, + ) + + def _test_lsh_self_attention_deterministic_given_seed(self, causal=False): + # Once the initialization and the call seeds are pinned down we have + # deterministic output. + with fastmath.use_backend(fastmath.Backend.JAX): + layer = efficient_attention.LSHSelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + causal=causal, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + x = np.ones((3, 32, 8)).astype(np.float32) + + def get_output(): + _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0)) + return layer(x, rng=jax.random.PRNGKey(1)) + + ys = [get_output() for _ in range(10)] + + self.assertEqual(ys[0].shape, x.shape) + + for y in ys[1:]: + np.testing.assert_array_almost_equal(ys[0], y, decimal=6) + + def test_lsh_determinism_causal(self): + self._test_lsh_self_attention_deterministic_given_seed(causal=True) + + def test_lsh_determinism_non_causal(self): + self._test_lsh_self_attention_deterministic_given_seed(causal=False) + + def test_lsh_self_attention_masked_non_causal(self): + # Test that when the input that is in the masked area changes the attention + # for the un-masked outputs doesn't change, but the masked region does + # change. + with fastmath.use_backend(fastmath.Backend.JAX): + # Set a fixed seed for deterministic hashing + rng_seed = 42 + + layer = efficient_attention.LSHSelfAttention( + n_heads=5, + d_qk=7, + d_v=17, + causal=False, + masked=True, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=2, + n_buckets=4, + use_reference_code=True, + attention_dropout=0.0, + mode="train", + ) + + batch = 5 + max_len = 32 + hidden = 8 + + # Use fixed seed for reproducible input + input_rng = jax.random.PRNGKey(rng_seed) + x = jax.random.uniform(input_rng, shape=(batch, max_len, hidden)) + + mask = np.ones((batch, max_len)).astype(bool) + mask_rng = jax.random.PRNGKey(rng_seed + 1) + rngs = jax.random.randint(mask_rng, (batch,), minval=1, maxval=max_len - 1) + + # Set some suffix of each mask[b] to 0. + for i in range(batch): + mask[i, rngs[i] :] = 0 + + # Fix rngs and get the output for the LSH layer. + def get_output(x, mask): + xs = [x, mask] + params_rng = jax.random.PRNGKey(rng_seed + 2) + _, _ = layer.init(shapes.signature(xs), params_rng) + # Use the same RNG for both forward passes to ensure deterministic hashing + forward_rng = jax.random.PRNGKey(rng_seed + 3) + return layer(xs, rng=forward_rng) + + # Get the attention output for masked x. + y = get_output(x, mask) + + # Create a modified input with a different seed, but only for masked regions + mod_input_rng = jax.random.PRNGKey(rng_seed + 10) # Different seed + x_modified = np.copy(x) # Create a copy to modify + + # Change x, but only in the masked regions. + for i in range(batch): + # Generate modifications with fixed seed for reproducibility + modifications = jax.random.uniform( + mod_input_rng, shape=(max_len - rngs[i], hidden) + ) + x_modified[i, rngs[i] :] = modifications + + y2 = get_output(x_modified, mask) + + for i in range(batch): + # y and y2 should be identical in the non-masked part. 
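+                # Perturbations confined to the masked suffix must not leak into the visible prefix.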
+                np.testing.assert_array_almost_equal(
+                    y[i, : rngs[i]], y2[i, : rngs[i]], decimal=6
+                )
+
+    @parameterized.named_parameters(("_weights_2", 2), ("_weights_3", 3))
+    def test_pure_lsh_wrapper_causal_non_masked(self, num_weights):
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            n_heads = 5
+            batch, seqlen, d_head = 3, 32, 8
+            n_hashes = 2
+            d_model = n_heads * d_head
+            layer = efficient_attention.PureLSHSelfAttentionWrapper(
+                n_heads=n_heads,
+                d_qk=d_head,
+                d_v=d_head,
+                causal=True,
+                masked=False,
+                chunk_len=8,
+                n_chunks_before=1,
+                n_chunks_after=0,
+                n_hashes=n_hashes,
+                n_buckets=4,
+                bias=False,
+                pure_lsh_implementation=efficient_attention.PureLSHSelfAttention,
+                mode="train",
+                num_weights=num_weights,
+            )
+
+            rng = jax.random.PRNGKey(0)
+            rng, x_rng = jax.random.split(rng)
+
+            input_shape = (batch, seqlen, d_model)
+            x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32)
+
+            inp = x
+            w, s = layer.init(shapes.signature(inp))
+            o = layer(inp)
+
+            # Get the actual weights.
+            weights = fastmath.tree_leaves(w)
+            # Assert number of weights is as expected, the extra 1 is for output.
+            self.assertLen(weights, num_weights + 1)
+
+            # Assert each weight is of the expected shape.
+            for i in range(num_weights + 1):
+                self.assertEqual(weights[i].shape, (d_model, d_model))
+
+            # Test that the output and the input shape match.
+            self.assertEqual(inp.shape, o.shape)
+
+            # Assert state is the shape expected.
+            state = fastmath.tree_leaves(s)
+            self.assertLen(state, 2)
+            # buckets
+            self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen))
+            # rngs
+            self.assertEqual(state[1].shape, (batch * n_heads, 2))
+
+    @parameterized.named_parameters(("_weights_2", 2), ("_weights_3", 3))
+    def test_pure_lsh_wrapper_non_causal_masked(self, num_weights):
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            n_heads = 5
+            batch, seqlen, d_head = 3, 32, 8
+            n_hashes = 2
+            d_model = n_heads * d_head
+            layer = efficient_attention.PureLSHSelfAttentionWrapper(
+                n_heads=n_heads,
+                d_qk=d_head,
+                d_v=d_head,
+                causal=False,
+                masked=True,
+                chunk_len=8,
+                n_chunks_before=1,
+                n_chunks_after=0,
+                n_hashes=n_hashes,
+                n_buckets=4,
+                bias=False,
+                pure_lsh_implementation=efficient_attention.PureLSHSelfAttention,
+                mode="train",
+                num_weights=num_weights,
+            )
+
+            rng = jax.random.PRNGKey(0)
+            rng, x_rng = jax.random.split(rng)
+
+            input_shape = (batch, seqlen, d_model)
+            x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32)
+            mask = jnp.ones((batch, seqlen), dtype=jnp.int32)
+
+            inp = (x, mask)
+            w, s = layer.init(shapes.signature(inp))
+            o = layer(inp)
+
+            # Get the actual weights.
+            weights = fastmath.tree_leaves(w)
+            # Assert number of weights is as expected, the extra 1 is for output.
+            self.assertLen(weights, num_weights + 1)
+
+            # Assert each weight is of the expected shape.
+            for i in range(num_weights + 1):
+                self.assertEqual(weights[i].shape, (d_model, d_model))
+
+            # Test that the output and the x's shape match.
+            self.assertEqual(x.shape, o.shape)
+
+            # Assert state is the shape expected.
+            state = fastmath.tree_leaves(s)
+            self.assertLen(state, 2)
+            # buckets
+            self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen))
+            # rngs
+            self.assertEqual(state[1].shape, (batch * n_heads, 2))
+
+    def test_lsh_and_pure_lsh_self_attention_equivalence(self):
+        # Given the same weight matrices and random numbers, do these produce the
+        # same output.
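+        # LSHSelfAttention owns its q/k/v/output projections; PureLSHSelfAttention consumes pre-projected (qk, v).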
+ with fastmath.use_backend(fastmath.Backend.JAX): + n_heads = 4 + d_head = 4 + d_model = n_heads * d_head + pure_lsh_layer = efficient_attention.PureLSHSelfAttention( + n_heads=n_heads, + d_qk=d_head, + d_v=d_head, + causal=True, + masked=False, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=4, + n_buckets=8, + use_reference_code=False, + attention_dropout=0.0, + use_python_loop=True, + bias=False, + mode="train", + ) + lsh_layer = efficient_attention.LSHSelfAttention( + n_heads=n_heads, + d_qk=d_head, + d_v=d_head, + causal=True, + masked=False, + chunk_len=8, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=4, + n_buckets=8, + use_reference_code=False, + attention_dropout=0.0, + use_python_loop=True, + mode="train", + ) + + batch, seqlen = 3, 32 + input_shape = (batch, seqlen, d_model) + + x = jax.random.uniform( + jax.random.PRNGKey(0), input_shape, dtype=jnp.float32 + ) + lsh_layer_input = x + + call_rng = jax.random.PRNGKey(42) + + lsh_layer_weights, lsh_layer_state = lsh_layer.init( + shapes.signature(lsh_layer_input) + ) + lsh_layer.rng = call_rng + lsh_layer_output = lsh_layer(lsh_layer_input) + + # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head), + # (n_heads, d_head, d_model) + # Abbreviated as - hmn, hmn, hnm + w_qk, w_v, w_o = lsh_layer_weights + + qk = jnp.einsum("blm,hmn->bhln", x, w_qk) + qk = qk.reshape((-1, qk.shape[2], qk.shape[3])) + + v = jnp.einsum("blm,hmn->bhln", x, w_v) + v = v.reshape((-1, v.shape[2], v.shape[3])) + + pure_lsh_layer_input = (qk, v) + _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input)) + pure_lsh_layer.rng = call_rng + pure_lsh_layer.state = lsh_layer_state + pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input) + + # b*h,l,n + pure_lsh_layer_output = pure_lsh_layer_output.reshape( + (batch, -1) + pure_lsh_layer_output.shape[1:] + ) + pure_lsh_layer_output_projected = jnp.einsum( + "bhld,hdm->blm", pure_lsh_layer_output, w_o + ) + + diff = pure_lsh_layer_output_projected - lsh_layer_output + avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff)) + + self.assertLess(avg_diff, 1e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tests/layers/research/flash_attention_test.py b/tests/layers/research/flash_attention_test.py new file mode 100644 index 000000000..63f3a4f6f --- /dev/null +++ b/tests/layers/research/flash_attention_test.py @@ -0,0 +1,38 @@ +# coding=utf-8 +"""Tests for flash_attention.""" + +import numpy as np +from absl.testing import absltest + +import jax +from trax import fastmath +from trax.fastmath import numpy as jnp +from trax.layers.research import flash_attention + + +def _naive_attention(q, k, v, mask=None): + logits = jnp.einsum("bqd,bkd->bqk", q, k) + if mask is not None: + logits = jnp.where(mask[:, None, :], -1e9, logits) + weights = jax.nn.softmax(logits, axis=-1) + return jnp.einsum("bqk,bkd->bqd", weights, v) + + +class FlashAttentionTest(absltest.TestCase): + def test_matches_naive(self): + with fastmath.use_backend(fastmath.Backend.JAX): + batch, seqlen, d = 2, 7, 4 + q = jnp.arange(batch * seqlen * d).reshape((batch, seqlen, d)) / 100.0 + k = jnp.arange(batch * seqlen * d).reshape((batch, seqlen, d)) / 50.0 + v = jnp.arange(batch * seqlen * d).reshape((batch, seqlen, d)) / 25.0 + mask = jnp.arange(seqlen)[None, :] >= 5 + out_ref = _naive_attention(q, k, v, mask) + out_flash = flash_attention.flash_attention( + q, k, v, block_size=4, mask=mask + ) + self.assertEqual(out_ref.shape, out_flash.shape) + np.testing.assert_allclose(out_ref, 
out_flash, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/research/position_encodings_test.py b/tests/layers/research/position_encodings_test.py new file mode 100644 index 000000000..e8d3ad98f --- /dev/null +++ b/tests/layers/research/position_encodings_test.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.layers.research.position_encodings.""" + +import functools + +import absl.testing.absltest as unittest +import numpy as np +import parameterized + +import trax.layers.research.position_encodings as pe + +from trax import fastmath + + +@parameterized.parameterized_class( + [ + # {'Encoding': pe.FixedBasePositionalEncoding}, + {"Encoding": pe.InfinitePositionalEncoding}, + {"Encoding": functools.partial(pe.InfinitePositionalEncoding, affine=False)}, + { + "Encoding": functools.partial( + pe.TimeBinPositionalEncoding, time_bin_length=5 + ) + }, + ] +) +class PositionEncodingsTest(unittest.TestCase): + """Position encodings conform to the position encodings protocol.""" + + @parameterized.parameterized.expand( + [ + (1, 100, 8), # typical + (1, 1, 8), # short + (1, 100, 1), # narrow + (2, 100, 8), # batched + ] + ) + def test_training(self, n, t, c): + encoding = self.Encoding() + input_ntc = np.random.randn(n, t, c) + encoding.init(input_ntc) + output_ntc = encoding(input_ntc) + self.assertEqual(output_ntc.shape, input_ntc.shape) + self.assertTrue(np.not_equal(output_ntc, input_ntc).any()) + + @parameterized.parameterized.expand( + [ + (1, 100, 8), # typical + (1, 100, 1), # narrow + (2, 100, 8), # batched + ] + ) + def test_inference(self, n, t, c): + # Get the eval mode outputs: + encoding = self.Encoding(mode="eval") + input_ntc = np.random.randn(n, t, c) + rng = fastmath.random.get_prng(1234) + encoding.init(input_ntc, rng=rng) + output_ntc = encoding(input_ntc) + + is_random = self.Encoding == pe.InfinitePositionalEncoding + + # Get the predict mode outputs: + encoding_pred = self.Encoding(mode="predict") + encoding_pred.init(input_ntc[:, 0:1, :], rng=rng) + output_ntc0 = encoding_pred(input_ntc[:, 0:1, :]) + if not is_random: + np.testing.assert_allclose(output_ntc0, output_ntc[:, 0:1, :], atol=1e-4) + + output_ntc1 = encoding_pred(input_ntc[:, 1:2, :]) + if not is_random: + np.testing.assert_allclose(output_ntc1, output_ntc[:, 1:2, :], atol=1e-4) + + output_ntc2 = encoding_pred(input_ntc[:, 2:3, :]) + if not is_random: + np.testing.assert_allclose(output_ntc2, output_ntc[:, 2:3, :], atol=1e-4) + + +class SinCosEncodingsTest(unittest.TestCase): + """Position encodings conform to the position encodings protocol.""" + + @parameterized.parameterized.expand( + [ + (1, 100, 8), # typical + (1, 1, 8), # short + (2, 100, 8), # batched + ] + ) + def test_training(self, n, t, c): + encoding = pe.SinCosPositionalEncoding() + input_ntc = np.random.randn(n, t, c) + encoding.init(input_ntc) + output_ntc = encoding(input_ntc) + 
self.assertEqual(output_ntc.shape, input_ntc.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/trax/layers/research/rel_attention_test.py b/tests/layers/research/rel_attention_test.py similarity index 60% rename from trax/layers/research/rel_attention_test.py rename to tests/layers/research/rel_attention_test.py index 50918ff78..3e79cbb5a 100644 --- a/trax/layers/research/rel_attention_test.py +++ b/tests/layers/research/rel_attention_test.py @@ -29,26 +29,46 @@ """Tests for trax.layers.relattention.""" -from absl.testing import absltest import numpy as np +from absl.testing import absltest + import trax.layers as tl import trax.layers.research.rel_attention as ra class RelAttentionTest(absltest.TestCase): + def test_fast_shift_matrix(self): + layer = ra._fast_matrix_shift + x = np.array( + [ + [ + [ + [-3.0, -2.0, -1.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + ] + ] + ] + ).astype(np.float32) - def test_fast_shift_matrix(self): - layer = ra._fast_matrix_shift - x = np.array([[[[-3., -2., -1., 0.], [-3., -2., -1., - 0.], [-3., -2., -1., 0.], - [-3., -2., -1., 0.]]]]).astype(np.float32) + y = layer(x) + self.assertEqual(y.dtype, np.float32) + self.assertEqual( + tl.to_list(y), + [ + [ + [ + [0.0, 0.0, -3.0, -2.0], + [-1.0, 0.0, 0.0, -3.0], + [-2.0, -1.0, 0.0, 0.0], + [-3.0, -2.0, -1.0, 0.0], + ] + ] + ], + ) - y = layer(x) - self.assertEqual(y.dtype, np.float32) - self.assertEqual( - tl.to_list(y), [[[[0., 0., -3., -2.], [-1., 0., 0., -3.], - [-2., -1., 0., 0.], [-3., -2., -1., 0.]]]]) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/research/rotary_positional_embedding_test.py b/tests/layers/research/rotary_positional_embedding_test.py similarity index 52% rename from trax/layers/research/rotary_positional_embedding_test.py rename to tests/layers/research/rotary_positional_embedding_test.py index 8e049d11e..3a110c9c3 100644 --- a/trax/layers/research/rotary_positional_embedding_test.py +++ b/tests/layers/research/rotary_positional_embedding_test.py @@ -15,32 +15,33 @@ """Tests for trax.layers.research.rotary_positional_embedding.""" -from absl.testing import absltest import numpy as np + +from absl.testing import absltest + from trax.layers.research import rotary_positional_embedding as rotary_pe class RelAttentionTest(absltest.TestCase): + def test_rotary_monotonicity(self): + layer = rotary_pe.Rotate() + batch_size = 1 + seq_len = 32 + d_model = 512 + shape = (batch_size, seq_len, d_model) + q, k = np.ones(shape).astype(np.float32), np.ones(shape).astype(np.float32) + q, k = layer(q), layer(k) - def test_rotary_monotonicity(self): - layer = rotary_pe.Rotate() - batch_size = 1 - seq_len = 32 - d_model = 512 - shape = (batch_size, seq_len, d_model) - q, k = np.ones(shape).astype(np.float32), np.ones(shape).astype(np.float32) - q, k = layer(q), layer(k) - - self.assertEqual(q.dtype, np.float32) - self.assertEqual(q.shape, shape) + self.assertEqual(q.dtype, np.float32) + self.assertEqual(q.shape, shape) - # Test monotonicity of the resulting dot_product for the two first tokens - # in close proximity - dot_product = np.einsum('bnd, bmd -> bnm', q, k) + # Test monotonicity of the resulting dot_product for the two first tokens + # in close proximity + dot_product = np.einsum("bnd, bmd -> bnm", q, k) - self.assertTrue((dot_product[0, 0, :9] > dot_product[0, 0, 1:10]).all()) - self.assertTrue((dot_product[0, 1, 1:10] > dot_product[0, 1, 2:11]).all()) + 
self.assertTrue((dot_product[0, 0, :9] > dot_product[0, 0, 1:10]).all())
+        self.assertTrue((dot_product[0, 1, 1:10] > dot_product[0, 1, 2:11]).all())


-if __name__ == '__main__':
-  absltest.main()
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/layers/research/sparsity_test.py b/tests/layers/research/sparsity_test.py
new file mode 100644
index 000000000..2de4a9adc
--- /dev/null
+++ b/tests/layers/research/sparsity_test.py
@@ -0,0 +1,516 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for trax.layers.research.sparsity."""
+
+import functools
+
+import jax
+import numpy as np
+
+from absl.testing import parameterized
+from tensorflow import test
+
+import trax.layers as tl
+
+from tests.layers import test_utils
+from trax import fastmath
+from trax.layers.research import sparsity
+from trax.utils import shapes
+
+
+class EfficientFeedForwardTest(test.TestCase, parameterized.TestCase):
+    def test_blocksparse_ff_train(self):
+        d_model = 1024
+        n_experts = 64
+        d_ff = d_model * 8
+        x_shape = (3, 7, d_model)
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            layer = sparsity.BlockSparseFF(
+                d_ff=d_ff, n_experts=n_experts, temperature=0.7, mode="train"
+            )
+            x = np.ones(x_shape).astype(np.float32)
+            _, _ = layer.init(shapes.signature(x))
+            y = layer(x)
+            self.assertEqual(y.shape, x.shape)
+
+    def test_blocksparse_ff_predict_equals_eval(self):
+        d_model = 1024
+        n_experts = 64
+        d_ff = d_model * 8
+        x_shape = (1, 1, d_model)
+        temperature = 0.7
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            x = np.ones(x_shape).astype(np.float32)
+            input_signature = shapes.signature(x)
+            common_kwargs = dict(
+                d_ff=d_ff,
+                n_experts=n_experts,
+                temperature=temperature,
+            )
+            eval_model = sparsity.BlockSparseFF(mode="eval", **common_kwargs)
+            weights, state = eval_model.init(input_signature)
+            eval_out, _ = eval_model.pure_fn(
+                x, weights, state, rng=jax.random.PRNGKey(0)
+            )
+            pred_model = sparsity.BlockSparseFF(mode="predict", **common_kwargs)
+            _, _ = pred_model.init(input_signature)
+            pred_out, _ = pred_model.pure_fn(
+                x, weights, state, rng=jax.random.PRNGKey(0)
+            )
+            self.assertEqual(eval_out.shape, x.shape)
+            # eval_out and pred_out should be identical.
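+            # Predict mode reuses the eval-mode weights and state, so outputs must agree to ~6 decimals.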
+            np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :])
+
+    def test_sparse_ff_predict_equals_eval(self):
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            d_model = 64
+            seq_len = 6
+            x_shape = (1, seq_len, d_model)
+            inp = np.ones(x_shape).astype(np.float32)
+
+            model_fn = functools.partial(
+                sparsity.SparseFF,
+                d_ff=256,
+                temperature=0.7,
+                n_elements_in_block=8,
+            )
+
+            configs = [
+                {"multiply_by_controller_output": True},
+                {"multiply_by_controller_output": False},
+                {"ff_chunk_size": 2},
+            ]
+
+            test_utils.test_eval_equals_predict_configs(inp, model_fn, configs)
+
+    @parameterized.named_parameters(
+        ("_mode_train", "train"), ("_mode_eval", "eval"), ("_mode_predict", "predict")
+    )
+    def test_sparse_ff_with_chunking(self, mode):
+        d_model = 8
+        n_elements_in_block = 2
+        d_ff = 16
+        x_shape = (2, 8, d_model)
+        temperature = 0.7
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            x = np.ones(x_shape).astype(np.float32)
+            input_signature = shapes.signature(x)
+            model = sparsity.SparseFF(
+                d_ff=d_ff,
+                n_elements_in_block=n_elements_in_block,
+                temperature=temperature,
+                ff_chunk_size=4,
+                mode=mode,
+            )
+            weights, state = model.init(input_signature)
+            out, _ = model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0))
+            self.assertEqual(out.shape, x.shape)
+
+    @parameterized.named_parameters(
+        ("_mode_train", "train"), ("_mode_eval", "eval"), ("_mode_predict", "predict")
+    )
+    def test_sparse_ff_multiply(self, mode):
+        d_model = 8
+        n_elements_in_block = 2
+        d_ff = 16
+        x_shape = (2, 8, d_model)
+        temperature = 0.7
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            x = np.ones(x_shape).astype(np.float32)
+            input_signature = shapes.signature(x)
+            model = sparsity.SparseFF(
+                d_ff=d_ff,
+                n_elements_in_block=n_elements_in_block,
+                temperature=temperature,
+                ff_chunk_size=4,
+                mode=mode,
+                multiply_by_controller_output=True,
+            )
+            weights, state = model.init(input_signature)
+            out, _ = model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0))
+            self.assertEqual(out.shape, x.shape)
+
+    def test_sparse_ff_kernel_scaling(self):
+        d_model = 8
+        n_elements_in_block = 2
+        d_ff = 16
+        x_shape = (2, 8, d_model)
+        temperature = 0.7
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            x = np.ones(x_shape).astype(np.float32)
+            input_signature = shapes.signature(x)
+            model = sparsity.SparseFF(
+                d_ff=d_ff,
+                n_elements_in_block=n_elements_in_block,
+                temperature=temperature,
+                ff_chunk_size=4,
+                mode="train",
+                kernel_scaling=True,
+            )
+            weights, state = model.init(input_signature)
+            out, _ = model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0))
+            self.assertEqual(out.shape, x.shape)
+
+    def test_switchsparse_ff_train(self):
+        d_model = 1024
+        n_experts = 64
+        d_ff = d_model * 8
+        x_shape = (3, 7, d_model)
+        layer = sparsity.SwitchSparseFF(d_ff=d_ff, n_experts=n_experts, mode="train")
+        x = np.ones(x_shape).astype(np.float32)
+        layer.init(shapes.signature(x))
+        y = layer(x)
+        self.assertEqual(y.shape, x.shape)
+
+    def test_switchsparse_ff_predict_equals_eval(self):
+        d_model = 1024
+        n_experts = 64
+        d_ff = d_model * 8
+        x_shape = (1, 1, d_model)
+        x = np.ones(x_shape).astype(np.float32)
+        input_signature = shapes.signature(x)
+        eval_model = sparsity.SwitchSparseFF(
+            mode="eval", d_ff=d_ff, n_experts=n_experts
+        )
+        weights, state = eval_model.init(input_signature)
+        eval_out, _ = eval_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0))
+        pred_model = sparsity.SwitchSparseFF(
+            mode="predict", d_ff=d_ff, n_experts=n_experts
+        )
+        pred_model.init(input_signature)
+        pred_out, _ = pred_model.pure_fn(x, weights, state, rng=jax.random.PRNGKey(0))
+        self.assertEqual(eval_out.shape, x.shape)
+        # eval_out and pred_out should be identical.
+        np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :])
+
+
+class ReversibleReshapePermuteTest(test.TestCase):
+    def test_reversible_permute(self):
+        layer = sparsity.ReversibleReshapePermute()
+        x = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7]])
+        layer.init(shapes.signature(x))
+        ys = layer(x)
+        self.assertEqual(
+            tl.to_list(ys), [[1, 3, 5, 7, 2, 4, 6, 8], [0, 2, 4, 6, 1, 3, 5, 7]]
+        )
+        rev_x = layer.reverse(ys, weights=layer.weights)
+        self.assertEqual(tl.to_list(x), tl.to_list(rev_x))
+
+
+class ReversibleRandomPermuteTest(test.TestCase):
+    def test_reversible_permute(self):
+        layer = sparsity.ReversibleRandomPermute()
+        x = np.array(
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
+                [0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 11, 12, 13],
+                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
+            ]
+        )
+        layer.init(shapes.signature(x))
+        ys = layer(x)
+        # this assert will fail once per ~87B runs, but it's okay
+        self.assertNotEqual(tl.to_list(ys), tl.to_list(x))
+
+        self.assertEqual(tl.to_list(ys[0]), tl.to_list(ys[2]))
+        self.assertNotEqual(tl.to_list(ys[0]), tl.to_list(ys[1]))
+        rev_x = layer.reverse(ys, weights=layer.weights)
+        self.assertEqual(tl.to_list(x), tl.to_list(rev_x))
+
+
+class LocallyConnectedDenseTest(test.TestCase):
+    def test_simple_call(self):
+        layer = sparsity.LocallyConnectedDense(2, 8)
+        x = np.array([[2, 5, 3, 4], [0, 1, 2, 3]])
+        _, _ = layer.init(shapes.signature(x))
+
+        y = layer(x)
+        self.assertEqual(y.shape, (2, 16))
+
+
+class SparseDenseWithOptionsTest(test.TestCase):
+    def test_simple_call(self):
+        d_input, d_output = 16, 32
+        settings = [
+            (None, 0, 0, False),
+            (None, 0, 0, True),
+            ("einsum", 0, 0, False),
+            ("lowrank", 0, 8, False),
+            ("mult", 2, 0, False),
+            ("mult", 2, 0, True),
+            ("local", 2, 0, False),
+            ("local3", 2, 0, False),
+        ]
+        for stype, sparsity_level, d_lowrank, use_bfloat16 in settings:
+            layer = sparsity.SparseDenseWithOptions(
+                d_output,
+                d_input=d_input,
+                sparsity_type=stype,
+                sparsity=sparsity_level,
+                d_lowrank=d_lowrank,
+                use_bfloat16=use_bfloat16,
+            )
+            x = np.ones((1, 1, d_input))
+            _, _ = layer.init(shapes.signature(x))
+            y = layer(x)
+            self.assertEqual(
+                y.shape,
+                (1, 1, d_output),
+                msg="[{}->{}] {} - {} - {} - {}".format(
+                    d_input, d_output, stype, sparsity_level, d_lowrank, use_bfloat16
+                ),
+            )
+
+
+class ModularCausalAttentionTest(test.TestCase):
+    def test_simple_call(self):
+        layer = sparsity.ModularCausalAttention(d_feature=4, n_heads=2, sparsity=2)
+        x = np.array(
+            [
+                [
+                    [2, 5, 3, 4],
+                    [0, 1, 2, 3],
+                    [0, 1, 2, 3],
+                ]
+            ]
+        )
+        _, _ = layer.init(shapes.signature(x))
+
+        y = layer(x)
+        self.assertEqual(y.shape, (1, 3, 4))
+
+
+class LowRankCausalAttentionTest(test.TestCase):
+    def test_simple_call(self):
+        layer = sparsity.LowRankCausalAttention(d_feature=4, n_heads=2, lowrank=2)
+        x = np.array(
+            [
+                [
+                    [2, 5, 3, 4],
+                    [0, 1, 2, 3],
+                    [0, 1, 2, 3],
+                ]
+            ]
+        )
+        _, _ = layer.init(shapes.signature(x))
+
+        y = layer(x)
+        self.assertEqual(y.shape, (1, 3, 4))
+
+
+class MultiplicativeCausalAttentionTest(test.TestCase):
+    def test_simple_call(self):
+        layer = sparsity.MultiplicativeCausalAttention(
+            d_feature=4, n_heads=2, sparsity=2
+        )
+        x = np.array(
+            [
+                [
+                    [2, 5, 3, 4],
+                    [0, 1, 2, 3],
+                    [0, 1, 2, 3],
+                ]
+            ]
+        )
+        _, _ = layer.init(shapes.signature(x))
+
+        y = layer(x)
+        self.assertEqual(y.shape, (1, 3, 4))
+
+
+class MultiplicativeModularCausalAttentionTest(test.TestCase):
+    def test_simple_call(self):
+        layer = sparsity.MultiplicativeModularCausalAttention(
+            d_feature=4, n_heads=2, sparsity=2
+        )
+        x = np.array(
+            [
+                [
+                    [2, 5, 3, 4],
+                    [0, 1, 2, 3],
+                    [0, 1, 2, 3],
+                ]
+            ]
+        )
+        _, _ = layer.init(shapes.signature(x))
+
+        y = layer(x)
+        self.assertEqual(y.shape, (1, 3, 4))
+
+
+class MultiplicativeConvCausalAttentionTest(test.TestCase):
+    def test_simple_call(self):
+        layer = sparsity.MultiplicativeConvCausalAttention(
+            d_feature=4, n_heads=2, sparsity=2
+        )
+        x = np.array(
+            [
+                [
+                    [2, 5, 3, 4],
+                    [0, 1, 2, 3],
+                    [0, 1, 2, 3],
+                ]
+            ]
+        )
+        _, _ = layer.init(shapes.signature(x))
+
+        y = layer(x)
+        self.assertEqual(y.shape, (1, 3, 4))
+
+    def test_various_calls(self):
+        list_kwargs = []
+        for share_qk in [True, False]:
+            for output in ["none", "mult", "conv", "multconv"]:
+                for concat in ["original", "fixed", "none"]:
+                    kwargs = {
+                        "share_qk": share_qk,
+                        "output_layer_type": output,
+                        "v_concat_type": concat,
+                    }
+                    list_kwargs.append(kwargs)
+        for kwargs in list_kwargs:
+            layer = sparsity.MultiplicativeConvCausalAttention(
+                d_feature=4, n_heads=2, sparsity=2, **kwargs
+            )
+            x = np.array(
+                [
+                    [
+                        [2, 5, 3, 4],
+                        [0, 1, 2, 3],
+                        [0, 1, 2, 3],
+                    ]
+                ]
+            )
+            _, _ = layer.init(shapes.signature(x))
+
+            y = layer(x)
+            self.assertEqual(y.shape, (1, 3, 4))
+
+    def test_predict_equals_eval(self):
+        with fastmath.use_backend(fastmath.Backend.JAX):
+            d_model = 32
+            seq_len = 5
+            x_shape = (1, seq_len, d_model)
+            inp = np.ones(x_shape).astype(np.float32)
+
+            model_fn = functools.partial(
+                sparsity.MultiplicativeConvCausalAttention,
+                d_feature=d_model,
+                n_heads=4,
+                sparsity=4,
+            )
+
+            list_kwargs = []
+            for share_qk in [True, False]:
+                for output in ["none", "mult", "conv", "multconv"]:
+                    for concat in ["original", "fixed", "none"]:
+                        kwargs = {
+                            "share_qk": share_qk,
+                            "output_layer_type": output,
+                            "v_concat_type": concat,
+                        }
+                        list_kwargs.append(kwargs)
+
+            test_utils.test_eval_equals_predict_configs(inp, model_fn, list_kwargs)
+
+
+class FavorTest(test.TestCase):
+    def test_call_and_grad(self):
+        layer_partial = tl.Serial(
+            tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
+            sparsity.Favor(d_feature=4, n_heads=2),
+            tl.Select([0], n_in=2),
+        )
+        layer = tl.Serial(
+            tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
+            sparsity.Favor(d_feature=4, n_heads=2),
+            tl.Select([0], n_in=2),
+            tl.WeightedCategoryCrossEntropy(),
+        )
+        x = np.ones((1, 2), dtype=np.int32)
+        w = np.ones_like(x).astype(np.float32)
+        x_sig = shapes.signature(x)
+        w_sig = shapes.signature(w)
+        layer_partial.init(x_sig)
+        y = layer_partial(x)
+        self.assertEqual(y.shape, (1, 2, 4))
+        layer.init((x_sig, x_sig, w_sig))
+        y = layer((x, x, w))
+        self.assertEqual(y.shape, ())
+        state = layer.state
+        rng = fastmath.random.get_prng(0)
+        fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0]
+        g = fastmath.grad(fwd)(layer.weights, (x, x, w))
+        self.assertEqual(g[0][1][0].shape, (3, 4))
+
+    def test_call_and_grad_approximate_softmax(self):
+        layer_partial = tl.Serial(
+            tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()),
+            sparsity.Favor(
+                d_feature=12,
+                n_heads=3,
+                n_random_features=128,
+                use_approximate_softmax=True,
+            ),
+            tl.Select([0], n_in=2),
+        )
+        layer = tl.Serial(
+            tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()),
+            sparsity.Favor(
+                d_feature=12,
+                n_heads=3,
+                n_random_features=128,
+                use_approximate_softmax=True,
+            ),
+            tl.Select([0], n_in=2),
+            tl.WeightedCategoryCrossEntropy(),
+        )
+        x = np.ones((3, 5), dtype=np.int32)
+        w = np.ones_like(x).astype(np.float32)
+        x_sig = shapes.signature(x)
+        w_sig = shapes.signature(w)
+        layer_partial.init(x_sig)
+        y = layer_partial(x)
+        self.assertEqual(y.shape, (3, 5, 12))
+        layer.init((x_sig, x_sig, w_sig))
+        y = layer((x, x, w))
+        self.assertEqual(y.shape, ())
+        state = layer.state
+        rng = fastmath.random.get_prng(0)
+        fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0]
+        g = fastmath.grad(fwd)(layer.weights, (x, x, w))
+        self.assertEqual(g[0][1][0].shape, (11, 12))
+
+    def test_causal_call_and_grad(self):
+        layer = tl.Serial(
+            tl.Dense(4), sparsity.CausalFavor(d_feature=4, n_heads=2), tl.L2Loss()
+        )
+        x = np.random.uniform(size=(1, 2, 4)).astype(np.float32)
+        w = np.ones_like(x)
+        x_sig = shapes.signature(x)
+        w_sig = shapes.signature(w)
+        layer.init((x_sig, x_sig, w_sig))
+        y = layer((x, x, w))
+        self.assertEqual(y.shape, ())
+        state = layer.state
+        rng = fastmath.random.get_prng(0)
+        fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0]
+        g = fastmath.grad(fwd)(layer.weights, (x, x, w))
+        self.assertEqual(g[0][0].shape, (4, 4))
+
+
+if __name__ == "__main__":
+    test.main()
diff --git a/trax/layers/reversible_test.py b/tests/layers/reversible_test.py
similarity index 63%
rename from trax/layers/reversible_test.py
rename to tests/layers/reversible_test.py
index 14fb67eaf..cf8159cb5 100644
--- a/trax/layers/reversible_test.py
+++ b/tests/layers/reversible_test.py
@@ -15,27 +15,26 @@
 
 """Tests for reversible layers."""
 
-from absl.testing import absltest
-from absl.testing import parameterized
 import numpy as np
-from trax import fastmath
+from absl.testing import absltest, parameterized
+
 import trax.layers as tl
+from trax import fastmath
 
 BACKENDS = [fastmath.Backend.JAX]
 
 
 class ReversibleLayerTest(parameterized.TestCase):
-
-  @parameterized.named_parameters([('_' + b.value, b) for b in BACKENDS])
-  def test_reversible_swap(self, backend):
-    with fastmath.use_backend(backend):
-      layer = tl.ReversibleSwap()
-      xs = [np.array([1, 2]), np.array([10, 20])]
-      ys = layer(xs)
-      self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])
+    @parameterized.named_parameters([("_" + b.value, b) for b in BACKENDS])
+    def test_reversible_swap(self, backend):
+        with fastmath.use_backend(backend):
+            layer = tl.ReversibleSwap()
+            xs = [np.array([1, 2]), np.array([10, 20])]
+            ys = layer(xs)
+            self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])
 
 
-if __name__ == '__main__':
-  absltest.main()
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/layers/rnn_test.py b/tests/layers/rnn_test.py
new file mode 100644
index 000000000..991a54fbc
--- /dev/null
+++ b/tests/layers/rnn_test.py
@@ -0,0 +1,75 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for rnn layers.""" + +import numpy as np + +from absl.testing import absltest, parameterized + +import trax.layers as tl + +from trax import fastmath +from trax.utils import shapes + +BACKENDS = [fastmath.Backend.JAX] + + +@parameterized.named_parameters(("_" + b.value, b) for b in BACKENDS) +class RnnTest(parameterized.TestCase): + def test_conv_gru_cell(self, backend): + with fastmath.use_backend(backend): + layer = tl.ConvGRUCell(9, kernel_size=(3, 3)) + x = np.ones((8, 1, 7, 9)) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_gru_cell(self, backend): + with fastmath.use_backend(backend): + layer = tl.GRUCell(9) + xs = [np.ones((8, 7, 9)), np.ones((8, 7, 9))] + _, _ = layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual([y.shape for y in ys], [(8, 7, 9), (8, 7, 9)]) + + def test_lstm_cell(self, backend): + with fastmath.use_backend(backend): + layer = tl.LSTMCell(9) + xs = [np.ones((8, 9)), np.ones((8, 18))] + _, _ = layer.init(shapes.signature(xs)) + ys = layer(xs) + self.assertEqual([y.shape for y in ys], [(8, 9), (8, 18)]) + + def test_sru(self, backend): + with fastmath.use_backend(backend): + layer = tl.SRU(7) + x = np.ones((8, 9, 7), np.float32) + _, _ = layer.init(shapes.signature(x)) + y = layer(x) + self.assertEqual(y.shape, x.shape) + + def test_names(self, backend): + with fastmath.use_backend(backend): + layer = tl.LSTM(3) + self.assertEqual("LSTM_3", str(layer)) + layer = tl.GRU(5) + self.assertEqual("GRU_5", str(layer)) + layer = tl.SRU(7) + self.assertEqual("SRU_7", str(layer)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/layers/test_utils.py b/tests/layers/test_utils.py new file mode 100644 index 000000000..c4d0cad37 --- /dev/null +++ b/tests/layers/test_utils.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for testing.""" + +import copy +import functools + +import numpy as np +import pytest + +from absl.testing import absltest + +from trax import fastmath +from trax import layers as tl +from trax.utils import shapes + + +@absltest.skip +@pytest.mark.skip(reason="This is helper method not direct test") +def test_eval_is_deterministic(inp, model_fn, message=""): + """Utility method for testing if eval mode is deterministic. + + Args: + inp: input fed to the model. It can be a tensor, or a tuple of tensors. + model_fn: function creating a model after calling with `mode` argument. + message: Optional message to show when outputs of eval/predict mode don't + match. 
+ """ + with fastmath.use_backend(fastmath.Backend.JAX): + model_eval1 = model_fn(mode="eval") + model_eval2 = model_fn(mode="eval") + + input_signature = shapes.signature(inp) + model_eval1.init(input_signature) + model_eval2.init(input_signature) + model_eval1.save_to_file("/tmp/unique_weights") + model_eval2.init_from_file( + "/tmp/unique_weights", weights_only=True, input_signature=input_signature + ) + + rng = fastmath.random.get_prng(0) + output_eval1 = model_eval1(inp, rng=rng) + if not isinstance(output_eval1, (tuple, list)): + # We will automatically check each and every tensor returned. + output_eval1 = [output_eval1] + + output_eval2 = model_eval2(inp, rng=rng) + if not isinstance(output_eval2, (tuple, list)): + # We will automatically check each and every tensor returned. + output_eval2 = [output_eval2] + + np.testing.assert_equal(len(output_eval1), len(output_eval2)) + for out1, out2 in zip(output_eval1, output_eval2): + np.testing.assert_array_almost_equal( + out1, out2, decimal=5, err_msg="Non-deterministic.{}".format(message) + ) + + +@absltest.skip +@pytest.mark.skip(reason="This is helper method not direct test") +def test_eval_equals_predict( + inp, model_fn, seq_axis=1, seq_tensor=None, init_tokens=3, message="" +): + """Utility method for testing equivalence of predict and eval modes. + + Args: + inp: input fed to the model. It can be a tensor, or a tuple of tensors. + model_fn: function creating a model after calling with `mode` argument. + seq_axis: axis of sequence_length. In predict mode we iterate over this + axis. By default `1`, which is 2nd dimension. + seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor + in this tuple on which we iterate the sequence. + init_tokens: how many tokens should be passed to the first `predict` call. + message: Optional message to show when outputs of eval/predict mode don't + match. + """ + with fastmath.use_backend(fastmath.Backend.JAX): + model_eval = model_fn(mode="eval") + model_predict = model_fn(mode="predict") + + input_signature = shapes.signature(inp) + model_eval.init(input_signature) + model_predict.init(input_signature) + model_eval.save_to_file("/tmp/unique_weights") + model_predict.init_from_file( + "/tmp/unique_weights", weights_only=True, input_signature=input_signature + ) + + rng = fastmath.random.get_prng(0) + output_eval = model_eval(inp, rng=rng) + if not isinstance(output_eval, (tuple, list)): + # We will automatically check each and every tensor returned. + output_eval = [output_eval] + + if seq_tensor is None: + length = inp.shape[seq_axis] + else: + length = inp[seq_tensor].shape[seq_axis] + + assert length >= init_tokens + 2 # Required to properly test predict mode. + indices_list = [(0, init_tokens)] + [ + (i, i + 1) for i in range(init_tokens, length) + ] + + for indices in indices_list: + start, end = indices + if seq_tensor is None: + new_inp = inp.take(indices=np.arange(start, end), axis=seq_axis) + else: + new_inp = list(inp) + new_inp[seq_tensor] = new_inp[seq_tensor].take( + indices=np.arange(start, end), axis=seq_axis + ) + + output_predict = model_predict(new_inp, rng=rng) + if not isinstance(output_predict, (tuple, list)): + # We will automatically check each and every tensor returned. 
+                output_predict = [output_predict]
+
+            np.testing.assert_equal(len(output_predict), len(output_eval))
+            for outp, oute in zip(output_predict, output_eval):
+                np.testing.assert_array_almost_equal(
+                    oute.take(indices=np.arange(start, end), axis=seq_axis),
+                    outp.take(indices=np.arange(0, end - start), axis=seq_axis),
+                    decimal=5,
+                    err_msg="Error on element {} out of {}.{}".format(
+                        indices, length, message
+                    ),
+                )
+
+
+@pytest.mark.skip(reason="This is a helper function, not a direct test")
+def test_eval_equals_predict_configs(
+    inp, model_fn, configs, seq_axis=1, seq_tensor=None, message=""
+):
+    """Utility method for testing equivalence of predict and eval modes.
+
+    This function iterates over a list of dictionaries `configs`, and runs the
+    test on models with each configuration.
+
+    Args:
+      inp: input fed to the model. It can be a tensor, or a tuple of tensors.
+      model_fn: function creating a model after calling with `mode` argument.
+      configs: List of dictionaries, which contain configs to be fed into
+        `model_fn`.
+      seq_axis: axis of sequence_length. In predict mode we iterate over this
+        axis. By default `1`, which is 2nd dimension.
+      seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor
+        in this tuple on which we iterate the sequence.
+      message: Optional message to show when outputs of eval/predict mode don't
+        match.
+    """
+    for config in configs:
+        model_fn_configured = functools.partial(model_fn, **config)
+        test_eval_equals_predict(
+            inp,
+            model_fn_configured,
+            seq_axis=seq_axis,
+            seq_tensor=seq_tensor,
+            message=" Config: {}.{}".format(config, message),
+        )
+
+
+@pytest.mark.skip(reason="This is a helper function, not a direct test")
+def test_eval_equals_predict_discrete(model_fn, vocab_size=10, length=5, batch_size=3):
+    """Tests the equivalence of eval and predict modes for discrete models."""
+    with fastmath.use_backend(fastmath.Backend.JAX):
+        model_slow = model_fn(mode="eval", vocab_size=vocab_size)
+        model_fast = model_fn(mode="predict", vocab_size=vocab_size)
+        rng = fastmath.random.get_prng(0)
+        input_signature = shapes.ShapeDtype((batch_size, 1), np.int32)
+        # Given the same rng, both models initialize with the same parameters.
+        model_slow.init(input_signature, rng)
+        model_fast.init(input_signature, rng)
+
+        buf = np.zeros((batch_size, length), dtype=np.int32)
+        next_sym = np.zeros((batch_size, 1), dtype=np.int32)
+
+        for index in range(length):
+            logits_slow = model_slow(buf, rng=rng)
+            logits_fast = model_fast(next_sym, rng=rng)
+            np.testing.assert_array_almost_equal(
+                logits_slow[:, index, :],
+                logits_fast[:, 0, :],
+                decimal=5,
+            )
+            next_sym = np.random.randint(vocab_size, size=(batch_size, 1))
+            buf[:, index] = next_sym[:, 0]
+
+
+class MockTransformerLM(tl.Layer):
+    r"""Mock TransformerLM for testing autoregressive sampling routines.
+
+    Mimics the behavior of a perfectly-trained, deterministic TransformerLM.
+    Allows specifying the \sigma^* -> \sigma function implemented by the model
+    and making assertions about the input sequence passed to the model.
+
+    Supports two modes: stateful "predict" for fast inference, and stateless
+    non-"predict" ("train", "eval" etc).
+
+    Useful for testing any logic that relies on autoregressive sampling, as it
+    removes the additional layer of complexity related to training a model or
+    maintaining a pretrained one. Makes the tests run MUCH faster.
+
+    Does not support acceleration. Do not wrap in tl.Accelerate().
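+
+    A minimal usage sketch (the sequence_fn below is illustrative):
+
+      model = MockTransformerLM(
+          sequence_fn=lambda seq: (seq[-1] + 1) % 10, mode="predict", vocab_size=10)
+      logits = model(np.array([[0]]))  # argmax over the last axis yields 1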
+ """ + + def __init__(self, sequence_fn, mode, vocab_size): + super().__init__() + + self._sequence_fn = sequence_fn + self._mode = mode + self._vocab_size = vocab_size + + self._prediction_buffers = None + + @property + def state(self): + return copy.deepcopy(self._prediction_buffers) + + @state.setter + def state(self, state): + self._prediction_buffers = copy.deepcopy(state) + + def _output_symbol_predict(self, input_symbols, prediction_buffer): + prediction_buffer.extend(input_symbols) + output_symbol = self._sequence_fn(np.array(prediction_buffer)) + return np.array([output_symbol]) + + def _output_symbols_eval(self, input_symbols, prediction_buffer): + del prediction_buffer + + # Add a leading 0 token to imitate ShiftRight. + input_symbols = np.concatenate(([0], input_symbols)) + + # Call sequence_fn repeatedly along the input sequence. + return np.array( + [ + self._sequence_fn(input_symbols[:end]) + for end in range(1, len(input_symbols)) + ] + ) + + def _symbols_to_logits(self, symbols): + # Assert that symbols are discrete. + assert np.issubdtype(symbols.dtype, np.integer) + # Assert that 0 <= symbols < vocab_size. + np.testing.assert_array_less(-1, symbols) + np.testing.assert_array_less(symbols, self._vocab_size) + + # Return almost-determinisitc logits: + # e^1000 / (e^1000 + vocab_size) ~= 1 + return tl.one_hot(symbols, n_categories=self._vocab_size) * 1000.0 + + def __call__(self, inputs, rng=None): + del rng + + assert inputs.ndim == 2, "The input sequences should have exactly two axes." + + if self._prediction_buffers is None: + # Initialize the buffer. + batch_size = inputs.shape[0] + # [[]] * batch_size would create multiple references to the same + # list, and we want separate lists. + self._prediction_buffers = [[] for _ in range(batch_size)] + + if self._mode == "predict": + output_fn = self._output_symbol_predict + else: + output_fn = self._output_symbols_eval + + # Calculate the output separately for each sequence in the batch. + output_symbols = np.array( + [ + output_fn(input_seq, pred_buffer) + for (input_seq, pred_buffer) in zip(inputs, self._prediction_buffers) + ] + ) + return self._symbols_to_logits(output_symbols) + + def assert_prediction_buffers_equal(self, expected_buffers): + if self._prediction_buffers is None: + batch_size = expected_buffers.shape[0] + actual_buffers = np.empty((batch_size, 0)) + else: + actual_buffers = np.array(self._prediction_buffers) + + np.testing.assert_array_equal(actual_buffers, expected_buffers) diff --git a/tests/layers/test_utils_test.py b/tests/layers/test_utils_test.py new file mode 100644 index 000000000..8c7cc4cc8 --- /dev/null +++ b/tests/layers/test_utils_test.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.layers.test_utils.""" + +import functools + +import numpy as np + +from absl.testing import absltest + +from tests.layers import test_utils +from trax.learning.supervised import decoding + + +def arithmetic_sequence(input_seq, limit=10): + # Increment the last symbol. Wrap to [0, 10). + return (input_seq[-1] + 1) % limit + + +class TestUtilsTest(absltest.TestCase): + def test_mock_transformer_lm_eval_equals_predict(self): + model_fn = functools.partial( + test_utils.MockTransformerLM, + sequence_fn=arithmetic_sequence, + vocab_size=10, + ) + test_utils.test_eval_equals_predict_discrete(model_fn, vocab_size=10) + + def test_mock_transformer_lm_decodes_arithmetic_sequence(self): + model = test_utils.MockTransformerLM( + sequence_fn=arithmetic_sequence, + vocab_size=10, + mode="predict", + ) + output = decoding.autoregressive_sample( + model, max_length=5, start_id=0, eos_id=-1, accelerate=False + ) + + # Sequence including the leading 0 and the last predicted symbol. + full_seq = list(range(6)) + # decoding.autoregressive_sample doesn't return the leading 0. + np.testing.assert_array_equal(output, [full_seq[1:]]) + # The prediction buffers don't include the last predicted symbol. + model.assert_prediction_buffers_equal([full_seq[:-1]]) + + def test_mock_transformer_lm_rewinds(self): + model = test_utils.MockTransformerLM( + sequence_fn=arithmetic_sequence, + vocab_size=10, + mode="predict", + ) + sample_3 = functools.partial( + decoding.autoregressive_sample, + max_length=3, + eos_id=-1, + accelerate=False, + ) + + # Generate the 3 initial symbols. + init_output = sample_3(model, start_id=0) + np.testing.assert_array_equal(init_output, [[1, 2, 3]]) + state = model.state + + # Generate the next 3 symbols. + next_output = sample_3(model, start_id=init_output[0, -1]) + np.testing.assert_array_equal(next_output, [[4, 5, 6]]) + + # Rewind and generate the last 3 symbols again. + model.state = state + next_output = sample_3(model, start_id=init_output[0, -1]) + np.testing.assert_array_equal(next_output, [[4, 5, 6]]) + + # Check the buffers. + model.assert_prediction_buffers_equal([[0, 1, 2, 3, 4, 5]]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/reinforcement/actor_critic_joint_test.py b/tests/learning/reinforcement/actor_critic_joint_test.py new file mode 100644 index 000000000..2a0850dfc --- /dev/null +++ b/tests/learning/reinforcement/actor_critic_joint_test.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RL training.""" + +import functools + +from absl.testing import absltest + +from trax import layers as tl +from trax import models +from trax import optimizers as opt +from trax.learning.reinforcement import actor_critic_joint +from trax.learning.reinforcement import task as rl_task +from trax.learning.supervised import lr_schedules +from trax.utils import test_utils + + +class ActorCriticJointTest(absltest.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + def test_awrjoint_save_restore(self): + """Check save and restore of joint AWR trainers.""" + task = rl_task.RLTask("CartPole-v0", initial_trajectories=2, max_steps=2) + joint_model = functools.partial( + models.PolicyAndValue, + body=lambda mode: tl.Serial(tl.Dense(4), tl.Relu()), + ) + tmp_dir = self.create_tempdir().full_path + trainer1 = actor_critic_joint.AWRJoint( + task, + joint_model=joint_model, + optimizer=opt.Adam, + batch_size=4, + train_steps_per_epoch=1, + n_trajectories_per_epoch=2, + output_dir=tmp_dir, + ) + trainer1.run(2) + self.assertEqual(trainer1.current_epoch, 2) + self.assertEqual(trainer1._trainer.step, 2) + # Agent 2 starts where agent 1 stopped. + trainer2 = actor_critic_joint.AWRJoint( + task, + joint_model=joint_model, + optimizer=opt.Adam, + batch_size=4, + train_steps_per_epoch=1, + n_trajectories_per_epoch=2, + output_dir=tmp_dir, + ) + trainer2.run(1) + self.assertEqual(trainer2.current_epoch, 3) + self.assertEqual(trainer2._trainer.step, 3) + trainer1.close() + trainer2.close() + + def test_jointppotrainer_cartpole(self): + """Test-runs joint PPO on CartPole.""" + + task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=2) + joint_model = functools.partial( + models.PolicyAndValue, + body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()), + ) + lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda + constant=1e-2, warmup_steps=100, factors="constant * linear_warmup" + ) + + trainer = actor_critic_joint.PPOJoint( + task, + joint_model=joint_model, + optimizer=opt.Adam, + lr_schedule=lr, + batch_size=4, + train_steps_per_epoch=2, + n_trajectories_per_epoch=5, + ) + trainer.run(2) + self.assertEqual(2, trainer.current_epoch) + + def test_jointawrtrainer_cartpole(self): + """Test-runs joint AWR on cartpole.""" + task = rl_task.RLTask("CartPole-v0", initial_trajectories=1, max_steps=2) + joint_model = functools.partial( + models.PolicyAndValue, + body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()), + ) + lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda + constant=1e-2, warmup_steps=100, factors="constant * linear_warmup" + ) + trainer = actor_critic_joint.AWRJoint( + task, + joint_model=joint_model, + optimizer=opt.Adam, + lr_schedule=lr, + batch_size=4, + train_steps_per_epoch=2, + n_trajectories_per_epoch=5, + ) + trainer.run(2) + self.assertEqual(2, trainer.current_epoch) + + def test_jointa2ctrainer_cartpole(self): + """Test-runs joint A2C on cartpole.""" + task = rl_task.RLTask("CartPole-v0", initial_trajectories=1, max_steps=2) + joint_model = functools.partial( + models.PolicyAndValue, + body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()), + ) + lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda + constant=1e-2, warmup_steps=100, factors="constant * linear_warmup" + ) + trainer = actor_critic_joint.A2CJoint( + task, + joint_model=joint_model, + optimizer=opt.RMSProp, + lr_schedule=lr, + batch_size=2, + train_steps_per_epoch=1, + n_trajectories_per_epoch=1, + ) + 
+        self.assertEqual(2, trainer.current_epoch)
+
+    def test_jointawrtrainer_cartpole_transformer(self):
+        """Test-runs joint AWR on cartpole with Transformer."""
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=1, max_steps=2)
+        body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
+            d_model=4, d_ff=4, n_layers=1, n_heads=1, mode=mode
+        )
+        joint_model = functools.partial(models.PolicyAndValue, body=body)
+        trainer = actor_critic_joint.AWRJoint(
+            task,
+            joint_model=joint_model,
+            optimizer=opt.Adam,
+            batch_size=4,
+            train_steps_per_epoch=2,
+            n_trajectories_per_epoch=2,
+            max_slice_length=2,
+        )
+        trainer.run(2)
+        self.assertEqual(2, trainer.current_epoch)
+
+    def test_jointa2ctrainer_cartpole_transformer(self):
+        """Test-runs joint A2C on cartpole with Transformer."""
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=1, max_steps=2)
+        body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
+            d_model=4, d_ff=4, n_layers=1, n_heads=1, mode=mode
+        )
+        joint_model = functools.partial(models.PolicyAndValue, body=body)
+        trainer = actor_critic_joint.A2CJoint(
+            task,
+            joint_model=joint_model,
+            optimizer=opt.RMSProp,
+            batch_size=4,
+            train_steps_per_epoch=2,
+            n_trajectories_per_epoch=2,
+        )
+        trainer.run(2)
+        self.assertEqual(2, trainer.current_epoch)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/learning/reinforcement/actor_critic_test.py b/tests/learning/reinforcement/actor_critic_test.py
new file mode 100644
index 000000000..431a90030
--- /dev/null
+++ b/tests/learning/reinforcement/actor_critic_test.py
@@ -0,0 +1,295 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for RL training.""" + +import functools + +from absl.testing import absltest, parameterized + +from trax import layers as tl +from trax import models +from trax import optimizers as opt +from trax.learning.reinforcement import actor_critic, advantages +from trax.learning.reinforcement import task as rl_task +from trax.learning.supervised import lr_schedules +from trax.utils import test_utils + + +class ActorCriticTest(parameterized.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + def test_a2ctrainer_save_restore(self): + """Check save and restore of A2C trainers.""" + task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=20) + body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) + policy_model = functools.partial(models.Policy, body=body) + value_model = functools.partial(models.Value, body=body) + tmp_dir = self.create_tempdir().full_path + trainer1 = actor_critic.A2C( + task, + value_model=value_model, + value_optimizer=opt.Adam, + value_batch_size=2, + value_train_steps_per_epoch=1, + policy_model=policy_model, + policy_optimizer=opt.Adam, + policy_batch_size=2, + policy_train_steps_per_epoch=2, + n_trajectories_per_epoch=2, + n_shared_layers=1, + output_dir=tmp_dir, + ) + trainer1.run(2) + self.assertEqual(trainer1.current_epoch, 2) + self.assertEqual(trainer1._value_trainer.step, 2) + self.assertEqual(trainer1._policy_trainer.step, 4) + # Trainer 2 starts where trainers 1 stopped. + trainer2 = actor_critic.A2C( + task, + value_model=value_model, + value_optimizer=opt.Adam, + value_batch_size=2, + value_train_steps_per_epoch=1, + policy_model=policy_model, + policy_optimizer=opt.Adam, + policy_batch_size=2, + policy_train_steps_per_epoch=2, + n_trajectories_per_epoch=2, + n_shared_layers=1, + output_dir=tmp_dir, + ) + trainer2.run(1) + self.assertEqual(trainer2.current_epoch, 3) + self.assertEqual(trainer2._value_trainer.step, 3) + self.assertEqual(trainer2._policy_trainer.step, 6) + trainer1.close() + trainer2.close() + + def test_sanity_a2ctrainer_cartpole(self): + """Test-runs a2c on cartpole.""" + task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=2) + body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) + policy_model = functools.partial(models.Policy, body=body) + value_model = functools.partial(models.Value, body=body) + lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda + constant=1e-4, warmup_steps=100, factors="constant * linear_warmup" + ) + trainer = actor_critic.A2C( + task, + n_shared_layers=1, + value_model=value_model, + value_optimizer=opt.Adam, + value_lr_schedule=lr, + value_batch_size=2, + value_train_steps_per_epoch=2, + policy_model=policy_model, + policy_optimizer=opt.Adam, + policy_lr_schedule=lr, + policy_batch_size=2, + policy_train_steps_per_epoch=2, + n_trajectories_per_epoch=2, + ) + trainer.run(2) + self.assertEqual(2, trainer.current_epoch) + + def test_sanity_ppo_cartpole(self): + """Run PPO and check whether it correctly runs for 2 epochs.s.""" + task = rl_task.RLTask("CartPole-v1", initial_trajectories=0, max_steps=200) + + lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda + constant=1e-3, warmup_steps=100, factors="constant * linear_warmup" + ) + + body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) + policy_model = functools.partial(models.Policy, body=body) + value_model = functools.partial(models.Value, body=body) + trainer = actor_critic.PPO( + task, + n_shared_layers=1, + value_model=value_model, + 
+            value_lr_schedule=lr,
+            value_batch_size=128,
+            value_train_steps_per_epoch=10,
+            policy_model=policy_model,
+            policy_optimizer=opt.Adam,
+            policy_lr_schedule=lr,
+            policy_batch_size=128,
+            policy_train_steps_per_epoch=10,
+            n_trajectories_per_epoch=10,
+        )
+
+        trainer.run(2)
+        self.assertEqual(2, trainer.current_epoch)
+
+    def test_sanity_loopawr(self):
+        """Test-runs LoopAWR."""
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=2)
+        body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
+        model_fn = functools.partial(models.PolicyAndValue, body=body)
+        trainer = actor_critic.LoopAWR(
+            task,
+            model_fn,
+            batch_size=2,
+            network_eval_at=(lambda _: True),
+            policy_n_steps_per_epoch=2,
+            value_n_steps_per_epoch=2,
+            n_trajectories_per_epoch=1,
+            n_eval_episodes=1,
+        )
+        trainer.run(2)
+        self.assertEqual(2, trainer.current_epoch)
+
+    @parameterized.named_parameters(
+        ("default", None), ("thresholds", ((70, 1.0, 0), (90, 4.0, 0)))
+    )
+    def test_sanity_awrtrainer_transformer_cartpole(self, thresholds):
+        """Test-runs AWR on cartpole with Transformer."""
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=2, max_steps=2)
+        body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
+            d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode
+        )
+        policy_model = functools.partial(models.Policy, body=body)
+        value_model = functools.partial(models.Value, body=body)
+        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
+            constant=1e-2, warmup_steps=100, factors="constant * linear_warmup"
+        )
+        trainer = actor_critic.AWR(
+            task,
+            thresholds=thresholds,
+            n_shared_layers=0,
+            max_slice_length=2,
+            added_policy_slice_length=1,
+            value_model=value_model,
+            value_optimizer=opt.Adam,
+            value_lr_schedule=lr,
+            value_batch_size=2,
+            value_train_steps_per_epoch=2,
+            policy_model=policy_model,
+            policy_optimizer=opt.Adam,
+            policy_lr_schedule=lr,
+            policy_batch_size=2,
+            policy_train_steps_per_epoch=2,
+            n_trajectories_per_epoch=1,
+            n_eval_episodes=1,
+        )
+        trainer.run(2)
+        self.assertEqual(2, trainer.current_epoch)
+
+    def test_sampling_awrtrainer_cartpole(self):
+        """Test-runs Sampling AWR on cartpole."""
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=20)
+        body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
+        policy_model = functools.partial(models.Policy, body=body)
+        value_model = functools.partial(models.Value, body=body)
+        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
+            constant=1e-2, warmup_steps=100, factors="constant * linear_warmup"
+        )
+        trainer = actor_critic.SamplingAWR(
+            task,
+            n_shared_layers=0,
+            added_policy_slice_length=1,
+            value_model=value_model,
+            value_optimizer=opt.Adam,
+            value_lr_schedule=lr,
+            value_batch_size=2,
+            value_train_steps_per_epoch=2,
+            policy_model=policy_model,
+            policy_optimizer=opt.Adam,
+            policy_lr_schedule=lr,
+            policy_batch_size=2,
+            policy_train_steps_per_epoch=2,
+            n_trajectories_per_epoch=2,
+            advantage_estimator=advantages.monte_carlo,
+            advantage_normalization=False,
+            q_value_n_samples=3,
+            q_value_aggregate="max",
+            reweight=False,
+        )
+        trainer.run(1)
+        self.assertEqual(1, trainer.current_epoch)
+
+    def test_sampling_awrtrainer_cartpole_sample_all_discrete(self):
+        """Test-runs Sampling AWR on cartpole with n_actions == n_samples."""
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=20)
+        body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
+        policy_model = functools.partial(models.Policy, body=body)
+        value_model = functools.partial(models.Value, body=body)
+        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
+            constant=1e-2, warmup_steps=100, factors="constant * linear_warmup"
+        )
+        trainer = actor_critic.SamplingAWR(
+            task,
+            n_shared_layers=0,
+            added_policy_slice_length=1,
+            value_model=value_model,
+            value_optimizer=opt.Adam,
+            value_lr_schedule=lr,
+            value_batch_size=2,
+            value_train_steps_per_epoch=2,
+            policy_model=policy_model,
+            policy_optimizer=opt.Adam,
+            policy_lr_schedule=lr,
+            policy_batch_size=2,
+            policy_train_steps_per_epoch=2,
+            n_trajectories_per_epoch=2,
+            advantage_estimator=advantages.monte_carlo,
+            advantage_normalization=False,
+            q_value_n_samples=2,
+            q_value_aggregate="max",
+            reweight=False,
+        )
+        trainer.run(1)
+        self.assertEqual(1, trainer.current_epoch)
+
+    def test_sampling_awrtrainer_mountain_car(self):
+        """Test-runs Sampling AWR on MountainCarContinuous."""
+        task = rl_task.RLTask(
+            "MountainCarContinuous-v0", initial_trajectories=0, max_steps=2
+        )
+        body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
+        policy_model = functools.partial(models.Policy, body=body)
+        value_model = functools.partial(models.Value, body=body)
+        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
+            constant=1e-2, warmup_steps=100, factors="constant * linear_warmup"
+        )
+        trainer = actor_critic.SamplingAWR(
+            task,
+            n_shared_layers=0,
+            added_policy_slice_length=1,
+            value_model=value_model,
+            value_optimizer=opt.Adam,
+            value_lr_schedule=lr,
+            value_batch_size=2,
+            value_train_steps_per_epoch=2,
+            policy_model=policy_model,
+            policy_optimizer=opt.Adam,
+            policy_lr_schedule=lr,
+            policy_batch_size=2,
+            policy_train_steps_per_epoch=2,
+            n_trajectories_per_epoch=2,
+            advantage_estimator=advantages.monte_carlo,
+            advantage_normalization=False,
+            q_value_n_samples=3,
+        )
+        trainer.run(1)
+        self.assertEqual(1, trainer.current_epoch)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/learning/reinforcement/advantages_test.py b/tests/learning/reinforcement/advantages_test.py
new file mode 100644
index 000000000..e38e54fd7
--- /dev/null
+++ b/tests/learning/reinforcement/advantages_test.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for trax.reinforcement.advantages.""" + +import functools + +import numpy as np + +from absl.testing import absltest, parameterized + +from trax.learning.reinforcement import advantages + + +def calc_bias_and_variance(x, true_mean): + sample_mean = np.mean(x) + bias = np.mean(np.abs(sample_mean - true_mean)) + variance = np.mean((x - sample_mean) ** 2) + return (bias, variance) + + +def estimate_advantage_bias_and_variance( + advantage_fn, + mean_reward=1.23, + reward_noise=0.45, + discount_mask=None, + discount_true_return=True, + true_value=False, + n_samples=10000, + length=5, + gamma=0.9, + margin=1, + **advantage_kwargs, +): + advantage_fn = advantage_fn(gamma, margin, **advantage_kwargs) + rewards = np.random.normal( + loc=mean_reward, scale=reward_noise, size=(n_samples, length) + ) + if discount_mask is None: + discount_mask = np.ones_like(rewards) + gammas = advantages.mask_discount(gamma, discount_mask) + returns = advantages.discounted_returns(rewards, gammas) + + true_returns = advantages.discounted_returns( + np.full(returns.shape, fill_value=mean_reward), gammas=gammas + ) + if true_value: + values = true_returns + else: + values = np.zeros_like(returns) + + dones = np.zeros_like(returns, dtype=bool) + adv = advantage_fn(rewards, returns, values, dones, discount_mask) + if discount_true_return: + mean_return = true_returns[0, 0] + else: + mean_return = mean_reward * length + return calc_bias_and_variance(adv[:, 0], mean_return - values[:, 0]) + + +class AdvantagesTest(parameterized.TestCase): + @parameterized.named_parameters( + ("monte_carlo", advantages.monte_carlo), + ("td_k", advantages.td_k), + ("td_lambda", advantages.td_lambda), + ("gae", advantages.gae), + ) + def test_shapes(self, advantage_fn): + rewards = np.array([[1, 1, 1]], dtype=np.float32) + returns = np.array([[3, 2, 1]], dtype=np.float32) + values = np.array([[2, 2, 2]], dtype=np.float32) + dones = np.array([[False, False, True]]) + discount_mask = np.ones_like(rewards) + adv1 = advantage_fn(gamma=1, margin=1)( + rewards, returns, values, dones, discount_mask + ) + self.assertEqual(adv1.shape, (1, 2)) + adv2 = advantage_fn(gamma=1, margin=2)( + rewards, returns, values, dones, discount_mask + ) + self.assertEqual(adv2.shape, (1, 1)) + + def test_monte_carlo_bias_is_zero(self): + (bias, _) = estimate_advantage_bias_and_variance( + advantages.monte_carlo, margin=3 + ) + np.testing.assert_allclose(bias, 0, atol=0.1) + + def test_td_k_variance_lower_than_monte_carlo(self): + (_, var_td_3) = estimate_advantage_bias_and_variance(advantages.td_k, margin=3) + (_, var_mc) = estimate_advantage_bias_and_variance(advantages.monte_carlo) + self.assertLess(var_td_3, var_mc) + + @parameterized.named_parameters(("1_2", 1, 2), ("2_3", 2, 3)) + def test_td_k_bias_decreases_with_k(self, k1, k2): + (bias1, _) = estimate_advantage_bias_and_variance(advantages.td_k, margin=k1) + (bias2, _) = estimate_advantage_bias_and_variance(advantages.td_k, margin=k2) + self.assertGreater(bias1, bias2) + + @parameterized.named_parameters(("1_2", 1, 2), ("2_3", 2, 3)) + def test_td_k_variance_increases_with_k(self, k1, k2): + (_, var1) = estimate_advantage_bias_and_variance(advantages.td_k, margin=k1) + (_, var2) = estimate_advantage_bias_and_variance(advantages.td_k, margin=k2) + self.assertLess(var1, var2) + + def test_td_lambda_variance_lower_than_monte_carlo(self): + (_, var_td_095) = estimate_advantage_bias_and_variance( + advantages.td_lambda, lambda_=0.95 + ) + (_, var_mc) = 
+        self.assertLess(var_td_095, var_mc)
+
+    @parameterized.named_parameters(
+        ("td_lambda_0.5_0.7", advantages.td_lambda, 0.5, 0.7),
+        ("td_lambda_0.7_0.9", advantages.td_lambda, 0.7, 0.9),
+        ("gae_0.5_0.7", advantages.gae, 0.5, 0.7),
+        ("gae_0.7_0.9", advantages.gae, 0.7, 0.9),
+    )
+    def test_bias_decreases_with_lambda(self, advantage_fn, lambda1, lambda2):
+        (bias1, _) = estimate_advantage_bias_and_variance(advantage_fn, lambda_=lambda1)
+        (bias2, _) = estimate_advantage_bias_and_variance(advantage_fn, lambda_=lambda2)
+        self.assertGreater(bias1, bias2)
+
+    @parameterized.named_parameters(("0.5_0.7", 0.5, 0.7), ("0.7_0.9", 0.7, 0.9))
+    def test_variance_increases_with_lambda(self, lambda1, lambda2):
+        (_, var1) = estimate_advantage_bias_and_variance(
+            advantages.td_lambda, lambda_=lambda1
+        )
+        (_, var2) = estimate_advantage_bias_and_variance(
+            advantages.td_lambda, lambda_=lambda2
+        )
+        self.assertLess(var1, var2)
+
+    @parameterized.named_parameters(
+        ("monte_carlo", advantages.monte_carlo),
+        ("td_k", advantages.td_k),
+        ("td_lambda", advantages.td_lambda),
+        ("gae", advantages.gae),
+    )
+    def test_advantage_future_return_is_zero_at_done(self, advantage_fn):
+        rewards = np.array([[1, 1, 1]], dtype=np.float32)
+        returns = np.array([[3, 2, 1]], dtype=np.float32)
+        values = np.array([[2, 2, 2]], dtype=np.float32)
+        dones = np.array([[False, True, False]])
+        discount_mask = np.ones_like(rewards)
+        adv = advantage_fn(gamma=0.9, margin=1)(
+            rewards, returns, values, dones, discount_mask
+        )
+        target_returns = values[:, :-1] + adv
+        # Assert that in the "done" state the future return in the advantage is
+        # zero, i.e. the return is equal to the reward.
+        np.testing.assert_almost_equal(target_returns[0, 1], rewards[0, 1])
+
+    @parameterized.named_parameters(
+        ("monte_carlo", advantages.monte_carlo),
+        # Disabled for TD-k because the differences are too small.
+        # ('td_k', advantages.td_k),
+        ("td_lambda", advantages.td_lambda),
+        ("gae", advantages.gae),
+    )
+    def test_bias_and_variance_with_non_const_discount_mask(self, advantage_fn):
+        non_const_discount_mask = np.array([[1, 0, 1, 0, 1]])
+        const_discount_mask = np.ones_like(non_const_discount_mask)
+        est_bias_and_variance = functools.partial(
+            estimate_advantage_bias_and_variance,
+            advantage_fn,
+            length=const_discount_mask.shape[1],
+            # Set gamma to a small value to accentuate the differences.
+            gamma=0.5,
+            # We want to measure error due to the discount, so compare with the
+            # undiscounted return.
+            discount_true_return=False,
+            # Use true values to remove the value estimation error.
+            true_value=True,
+        )
+        (bias_non_const, var_non_const) = est_bias_and_variance(
+            discount_mask=non_const_discount_mask
+        )
+        (bias_const, var_const) = est_bias_and_variance(
+            discount_mask=const_discount_mask
+        )
+        self.assertLess(bias_non_const, bias_const)
+        self.assertGreater(var_non_const, var_const)
+
+    @parameterized.named_parameters(
+        ("monte_carlo", advantages.monte_carlo),
+        ("td_k", advantages.td_k),
+        ("td_lambda", advantages.td_lambda),
+        ("gae", advantages.gae),
+    )
+    def test_future_return_is_zero_iff_discount_mask_is_on(self, advantage_fn):
+        # (... when gamma=0)
+        rewards = np.array([[1, 2, 3, 4]], dtype=np.float32)
+        values = np.array([[5, 6, 7, 8]], dtype=np.float32)
+        dones = np.zeros_like(rewards, dtype=bool)
+        discount_mask = np.array([[1, 0, 1, 0]], dtype=bool)
+        gammas = advantages.mask_discount(0.0, discount_mask)
+        returns = advantages.discounted_returns(rewards, gammas)
+        adv = advantage_fn(gamma=0.0, margin=1)(
+            rewards, returns, values, dones, discount_mask
+        )
+        target_returns = values[:, :-1] + adv
+        # Assert that, in the states with discount_mask on, the future return in
+        # the advantage is zero, i.e. the return is equal to the reward.
+        rewards = rewards[:, :-1]
+        discount_mask = discount_mask[:, :-1]
+        np.testing.assert_almost_equal(
+            target_returns[discount_mask], rewards[discount_mask]
+        )
+        # Assert the converse.
+        with np.testing.assert_raises(AssertionError):
+            np.testing.assert_almost_equal(
+                target_returns[~discount_mask], rewards[~discount_mask]
+            )
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/learning/reinforcement/distributions_test.py b/tests/learning/reinforcement/distributions_test.py
new file mode 100644
index 000000000..7538fbe6f
--- /dev/null
+++ b/tests/learning/reinforcement/distributions_test.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for trax.learning.reinforcement.distributions."""
+
+import gin
+import gym
+import numpy as np
+
+from absl.testing import absltest, parameterized
+
+from trax.learning.reinforcement import distributions
+
+
+class DistributionsTest(parameterized.TestCase):
+    def setUp(self):
+        super().setUp()
+        gin.clear_config()
+
+    @parameterized.named_parameters(
+        ("discrete", gym.spaces.Discrete(n=4), ""),
+        ("multi_discrete", gym.spaces.MultiDiscrete(nvec=[5, 5]), ""),
+        (
+            "gaussian_const_std",
+            gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)),
+            "Gaussian.learn_std = None",
+        ),
+        (
+            "gaussian_shared_std",
+            gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)),
+            'Gaussian.learn_std = "shared"',
+        ),
+        (
+            "gaussian_separate_std",
+            gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)),
+            'Gaussian.learn_std = "separate"',
+        ),
+    )
+    def test_shapes(self, space, gin_config):
+        gin.parse_config(gin_config)
+
+        batch_shape = (2, 3)
+        distribution = distributions.create_distribution(space)
+        inputs = np.random.random(batch_shape + (distribution.n_inputs,))
+        point = distribution.sample(inputs)
+        self.assertEqual(point.shape, batch_shape + space.shape)
+        # Check if the datatypes are compatible, i.e. either both floating or both
+        # integral.
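+        # (NumPy dtype objects are never instances of Python's `float`, so the
+        # check below compares floating-point kinds via np.issubdtype instead.)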
+        self.assertEqual(
+            np.issubdtype(point.dtype, np.floating),
+            np.issubdtype(space.dtype, np.floating),
+        )
+        log_prob = distribution.log_prob(inputs, point)
+        self.assertEqual(log_prob.shape, batch_shape)
+
+    @parameterized.named_parameters(("1d", 1), ("2d", 2))
+    def test_gaussian_probability_sums_to_one(self, n_dims):
+        std = 1.0
+        n_samples = 10000
+
+        distribution = distributions.Gaussian(shape=(n_dims,), std=std)
+        means = np.random.random((3, n_dims))
+        # Monte Carlo integration over [mean - 3 * std, mean + 3 * std] across
+        # all dimensions.
+        means = np.broadcast_to(means, (n_samples,) + means.shape)
+        probs = (6 * std) ** n_dims * np.mean(
+            np.exp(
+                distribution.log_prob(
+                    means, np.random.uniform(means - 3 * std, means + 3 * std)
+                )
+            ),
+            axis=0,
+        )
+        # Should sum to one. High tolerance because of variance and cutting off the
+        # tails.
+        np.testing.assert_allclose(probs, np.ones_like(probs), atol=0.05)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/learning/reinforcement/normalization_test.py b/tests/learning/reinforcement/normalization_test.py
new file mode 100644
index 000000000..0a6d76139
--- /dev/null
+++ b/tests/learning/reinforcement/normalization_test.py
@@ -0,0 +1,68 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for trax.reinforcement.normalization.""" + +import numpy as np + +from absl.testing import absltest + +from trax.learning.reinforcement import normalization +from trax.utils import shapes + + +class NormalizationTest(absltest.TestCase): + def test_running_mean(self): + x = np.random.uniform(size=10) + state = normalization.running_mean_init(shape=()) + for i in range(len(x)): + state = normalization.running_mean_update(x[i], state) + np.testing.assert_almost_equal( + normalization.running_mean_get_mean(state), np.mean(x[: i + 1]) + ) + + def test_running_variance(self): + x = np.random.uniform(size=10) + state = normalization.running_mean_and_variance_init(shape=()) + for i in range(len(x)): + state = normalization.running_mean_and_variance_update(x[i], state) + np.testing.assert_almost_equal( + normalization.running_mean_and_variance_get_variance(state), + np.var(x[: i + 1]), + ) + + def test_normalize_collect(self): + x = np.random.uniform(size=(2, 3, 4, 5)) + normalize = normalization.Normalize(mode="collect") + normalize.init(shapes.signature(x)) + old_state = normalize.state + y = normalize(x) + with self.assertRaises(AssertionError): + np.testing.assert_equal(normalize.state, old_state) + with self.assertRaises(AssertionError): + np.testing.assert_almost_equal(x, y) + + def test_normalize_train(self): + x = np.random.uniform(size=(2, 3, 4, 5)) + normalize = normalization.Normalize(mode="train", epsilon=0.0) + normalize.init(shapes.signature(x)) + old_state = normalize.state + y = normalize(x) + np.testing.assert_equal(normalize.state, old_state) + np.testing.assert_almost_equal(x, y) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/reinforcement/space_serializer_test.py b/tests/learning/reinforcement/space_serializer_test.py new file mode 100644 index 000000000..212f97f4f --- /dev/null +++ b/tests/learning/reinforcement/space_serializer_test.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.reinforcement.space_serializer.""" +import gin +import gym +import numpy as np + +from tensorflow import test + +from trax.learning.reinforcement import space_serializer + + +class BoxSpaceSerializerTest(test.TestCase): + def _make_space_and_serializer( + self, + low=-10, + high=10, + shape=(2,), + # Weird vocab_size to test that it doesn't only work with powers of 2. + vocab_size=257, + # Enough precision to represent float32s accurately. 
+        precision=4,
+    ):
+        gin.bind_parameter("BoxSpaceSerializer.precision", precision)
+        space = gym.spaces.Box(low=low, high=high, shape=shape)
+        serializer = space_serializer.create(space, vocab_size=vocab_size)
+        return (space, serializer)
+
+    def _sample_batch(self, space):
+        return np.reshape(space.sample(), (1,) + space.shape)
+
+    def test_representation_length(self):
+        (space, serializer) = self._make_space_and_serializer()
+        input_array = self._sample_batch(space)
+        representation = serializer.serialize(input_array)
+        self.assertEqual(representation.shape, (1, serializer.representation_length))
+
+    def test_commutes(self):
+        (space, serializer) = self._make_space_and_serializer()
+        input_array = self._sample_batch(space)
+        representation = serializer.serialize(input_array)
+        output_array = serializer.deserialize(representation)
+        # Test to 5 decimal places to reduce flakiness.
+        np.testing.assert_array_almost_equal(input_array, output_array, decimal=5)
+
+    def test_representation_changes(self):
+        (space, serializer) = self._make_space_and_serializer()
+        array1 = self._sample_batch(space)
+        array2 = -array1
+        (repr1, repr2) = tuple(map(serializer.serialize, (array1, array2)))
+        self.assertFalse(np.array_equal(repr1, repr2))
+
+    def test_bounds_space(self):
+        gin.bind_parameter("BoxSpaceSerializer.max_range", (-10.0, 10.0))
+        (_, serializer) = self._make_space_and_serializer(
+            # The range is too wide to represent, so it needs to be clipped.
+            low=-1e18,
+            high=1e18,
+            shape=(1,),
+        )
+        input_array = np.array([[1.2345]])
+        representation = serializer.serialize(input_array)
+        output_array = serializer.deserialize(representation)
+        np.testing.assert_array_almost_equal(input_array, output_array)
+
+    def test_significance_map(self):
+        (_, serializer) = self._make_space_and_serializer(shape=(2,))
+        np.testing.assert_array_equal(
+            serializer.significance_map, [0, 1, 2, 3, 0, 1, 2, 3]
+        )
+
+    def test_serializes_boundaries(self):
+        vocab_size = 256
+        precision = 4
+        (_, serializer) = self._make_space_and_serializer(
+            low=-1,
+            high=1,
+            shape=(1,),
+            vocab_size=vocab_size,
+            precision=precision,
+        )
+        input_array = np.array([[-1, 1]])
+        representation = serializer.serialize(input_array)
+        np.testing.assert_array_equal(
+            representation, [[0] * precision + [vocab_size - 1] * precision]
+        )
+
+
+class DiscreteSpaceSerializerTest(test.TestCase):
+    def setUp(self):
+        super().setUp()
+        self._space = gym.spaces.Discrete(n=2)
+        self._serializer = space_serializer.create(self._space, vocab_size=2)
+
+    def _sample_batch(self):
+        return np.reshape(self._space.sample(), (1,) + self._space.shape)
+
+    def test_representation_length(self):
+        input_array = self._sample_batch()
+        representation = self._serializer.serialize(input_array)
+        self.assertEqual(
+            representation.shape, (1, self._serializer.representation_length)
+        )
+
+    def test_commutes(self):
+        input_array = self._sample_batch()
+        representation = self._serializer.serialize(input_array)
+        output_array = self._serializer.deserialize(representation)
+        np.testing.assert_array_almost_equal(input_array, output_array)
+
+    def test_representation_changes(self):
+        array1 = self._sample_batch()
+        array2 = 1 - array1
+        (repr1, repr2) = tuple(map(self._serializer.serialize, (array1, array2)))
+        self.assertFalse(np.array_equal(repr1, repr2))
+
+    def test_significance_map(self):
+        np.testing.assert_array_equal(self._serializer.significance_map, [0])
+
+
+class MultiDiscreteSpaceSerializerTest(test.TestCase):
+    def setUp(self):
+        super().setUp()
+        self._space = 
gym.spaces.MultiDiscrete(nvec=[2, 2]) + self._serializer = space_serializer.create(self._space, vocab_size=2) + + def _sample_batch(self): + return np.reshape(self._space.sample(), (1,) + self._space.shape) + + def test_representation_length(self): + input_array = self._sample_batch() + representation = self._serializer.serialize(input_array) + self.assertEqual( + representation.shape, (1, self._serializer.representation_length) + ) + + def test_commutes(self): + input_array = self._sample_batch() + representation = self._serializer.serialize(input_array) + output_array = self._serializer.deserialize(representation) + np.testing.assert_array_almost_equal(input_array, output_array) + + def test_representation_changes(self): + array1 = self._sample_batch() + array2 = 1 - array1 + (repr1, repr2) = tuple(map(self._serializer.serialize, (array1, array2))) + self.assertFalse(np.array_equal(repr1, repr2)) + + def test_significance_map(self): + np.testing.assert_array_equal(self._serializer.significance_map, [0, 0]) + + +if __name__ == "__main__": + test.main() diff --git a/tests/learning/reinforcement/task_test.py b/tests/learning/reinforcement/task_test.py new file mode 100644 index 000000000..13054ae8e --- /dev/null +++ b/tests/learning/reinforcement/task_test.py @@ -0,0 +1,370 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
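Conceptually, the BoxSpaceSerializer exercised above maps each float into [0, 1) relative to its bounds and expands it into `precision` base-`vocab_size` digits, which is why serialize and deserialize only commute up to a small rounding error. A hedged sketch of that scheme (our simplified names, not the trax implementation):

def serialize(x, low=-10.0, high=10.0, vocab_size=257, precision=4):
    # Expand the normalized value into `precision` base-`vocab_size` digits.
    frac = (x - low) / (high - low)
    digits = []
    for _ in range(precision):
        frac *= vocab_size
        digit = min(int(frac), vocab_size - 1)
        digits.append(digit)
        frac -= digit
    return digits

def deserialize(digits, low=-10.0, high=10.0, vocab_size=257):
    # Fold the digits back, least significant first.
    frac = 0.0
    for digit in reversed(digits):
        frac = (frac + digit) / vocab_size
    return low + frac * (high - low)

x = 1.2345
assert abs(deserialize(serialize(x)) - x) < 1e-5  # commutes approximately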
+ +"""Tests for RL training.""" + +import os + +import gym +import numpy as np + +from absl.testing import absltest + +from trax.learning.reinforcement import task as rl_task +from trax.utils import test_utils + + +class DummyEnv: + """Dummy Env class for testing.""" + + observation_space = gym.spaces.Box(-2, 2, shape=(2,)) + action_space = gym.spaces.Discrete(2) + + def reset(self): + self._step = 0 + return np.ones((2,)) + + def step(self, action): + del action + info = { + "control_mask": self._step % 2 == 0, + "discount_mask": self._step % 3 == 0, + } + self._step += 1 + return np.ones((2,)), 0.0, False, info + + +class TaskTest(absltest.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + def _extend( + self, + trajectory, + action=0, + dist_inputs=0, + reward=0, + done=False, + new_observation=0, + ): + trajectory.extend( + action=action, + dist_inputs=dist_inputs, + reward=reward, + done=done, + new_observation=new_observation, + ) + + def test_trajectory_len(self): + """Test that trajectory length is equal to the number of observations.""" + tr = rl_task.Trajectory(observation=0) + for _ in range(5): + self._extend(tr) + self.assertLen(tr, 6) + + def test_empty_trajectory_last_observation(self): + """Test that last_observation is the one passed in __init__.""" + tr = rl_task.Trajectory(observation=123) + self.assertEqual(tr.last_observation, 123) + + def test_nonempty_trajectory_last_observation(self): + """Test that last_observation is the one passed in the last extend().""" + tr = rl_task.Trajectory(observation=123) + for _ in range(5): + self._extend(tr) + self._extend(tr, new_observation=321) + self.assertEqual(tr.last_observation, 321) + + def test_trajectory_done_get_and_set(self): + """Test that we can get and set the `done` flag of a trajectory.""" + tr = rl_task.Trajectory(observation=123) + self._extend(tr) + self.assertFalse(tr.done) + tr.done = True + self.assertTrue(tr.done) + + def test_trajectory_suffix_len(self): + """Test that a trajectory suffix has the correct length.""" + tr = rl_task.Trajectory(observation=0) + for _ in range(5): + self._extend(tr) + tr_suffix = tr.suffix(length=3) + self.assertLen(tr_suffix, 3) + + def test_trajectory_suffix_observations(self): + """Test that a trajectory suffix has the correct observations.""" + tr = rl_task.Trajectory(observation=0) + for obs in range(1, 6): + self._extend(tr, new_observation=obs) + tr_suffix = tr.suffix(length=4) + self.assertEqual([ts.observation for ts in tr_suffix.timesteps], [2, 3, 4]) + self.assertEqual(tr_suffix.last_observation, 5) + + def test_trajectory_to_np_shape(self): + """Test that the shape of a to_np result matches the trajectory length.""" + tr = rl_task.Trajectory(observation=np.zeros((2, 3))) + for _ in range(5): + self._extend(tr, new_observation=np.zeros((2, 3))) + tr_np = tr.to_np() + self.assertEqual(tr_np.observation.shape, (len(tr), 2, 3)) + self.assertEqual(tr_np.action.shape, (len(tr),)) + + def test_trajectory_to_np_shape_after_extend(self): + """Test that the shape of a to_np result grows after calling extend().""" + tr = rl_task.Trajectory(observation=0) + for _ in range(5): + self._extend(tr) + len_before = tr.to_np().observation.shape[0] + self._extend(tr) + len_after = tr.to_np().observation.shape[0] + self.assertEqual(len_after, len_before + 1) + + def test_trajectory_to_np_observations(self): + """Test that to_np returns correct observations.""" + tr = rl_task.Trajectory(observation=0) + for obs in range(1, 3): + self._extend(tr, 
new_observation=obs)
+        tr_np = tr.to_np()
+        np.testing.assert_array_equal(tr_np.observation, [0, 1, 2])
+
+    def test_trajectory_to_np_adds_margin(self):
+        """Test that to_np adds a specified margin."""
+        tr = rl_task.Trajectory(observation=2)
+        for _ in range(2):
+            self._extend(tr, new_observation=2)
+        tr_np = tr.to_np(margin=2)
+        np.testing.assert_array_equal(tr_np.observation, [2, 2, 2, 0])
+        np.testing.assert_array_equal(tr_np.mask, [1, 1, 0, 0])
+
+    def test_trajectory_to_np_without_margin_cuts_last_observation(self):
+        """Test that to_np with margin=0 cuts the last observation."""
+        tr = rl_task.Trajectory(observation=0)
+        for obs in range(1, 4):
+            self._extend(tr, new_observation=obs)
+        tr_np = tr.to_np(margin=0)
+        np.testing.assert_array_equal(tr_np.observation, [0, 1, 2])
+
+    def test_task_random_initial_trajectories_and_max_steps(self):
+        """Test generating initial random trajectories, stopping at max steps."""
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=9)
+        stream = task.trajectory_slice_stream(max_slice_length=1)
+        next_slice = next(stream)
+        self.assertEqual(next_slice.observation.shape, (1, 2))
+
+    def test_time_limit_terminates_episodes(self):
+        """Test that episodes are terminated upon reaching `time_limit` steps."""
+        task = rl_task.RLTask(
+            DummyEnv(), initial_trajectories=3, max_steps=10, time_limit=10
+        )
+        trajectories = task.trajectories[0]  # Get trajectories from epoch 0.
+        self.assertLen(trajectories, 3)
+        for trajectory in trajectories:
+            self.assertTrue(trajectory.done)
+            # max_steps + 1 (the initial observation doesn't count).
+            self.assertLen(trajectory, 11)
+
+    def test_max_steps_doesnt_terminate_episodes(self):
+        """Test that episodes are not terminated upon reaching `max_steps` steps."""
+        task = rl_task.RLTask(
+            DummyEnv(), initial_trajectories=2, max_steps=5, time_limit=10
+        )
+        trajectories = task.trajectories[0]  # Get trajectories from epoch 0.
+        self.assertLen(trajectories, 2)
+        # The trajectory should be cut in half. The first half should not be "done".
+        self.assertFalse(trajectories[0].done)
+        self.assertLen(trajectories[0], 6)  # max_steps + 1
+        # The second half should be "done".
+        self.assertTrue(trajectories[1].done)
+        self.assertLen(trajectories[1], 6)  # max_steps + 1
+
+    def test_collects_specified_number_of_interactions(self):
+        """Test that the specified number of interactions are collected."""
+        task = rl_task.RLTask(
+            DummyEnv(), initial_trajectories=0, max_steps=3, time_limit=20
+        )
+        task.collect_trajectories(policy=(lambda _: (0, 0)), n_interactions=10)
+        trajectories = task.trajectories[1]  # Get trajectories from epoch 1.
+        n_interactions = 0
+        for trajectory in trajectories:
+            n_interactions += len(trajectory) - 1
+        self.assertEqual(n_interactions, 10)
+
+    def test_collects_specified_number_of_trajectories(self):
+        """Test that the specified number of trajectories are collected."""
+        task = rl_task.RLTask(
+            DummyEnv(), initial_trajectories=0, max_steps=3, time_limit=20
+        )
+        task.collect_trajectories(policy=(lambda _: (0, 0)), n_trajectories=3)
+        trajectories = task.trajectories[1]  # Get trajectories from epoch 1.
+        self.assertLen(trajectories, 3)
+
+    def test_task_save_init(self):
+        """Test saving and re-initialization."""
+        task1 = rl_task.RLTask(
+            DummyEnv(), initial_trajectories=13, max_steps=9, gamma=0.9
+        )
+        self.assertLen(task1.trajectories[0], 13)
+        self.assertEqual(task1.max_steps, 9)
+        self.assertEqual(task1.gamma, 0.9)
+        temp_file = os.path.join(self.create_tempdir().full_path, "task.pkl")
+        task1.save_to_file(temp_file)
+        task2 = rl_task.RLTask(
+            DummyEnv(), initial_trajectories=3, max_steps=19, gamma=1.0
+        )
+        self.assertLen(task2.trajectories[0], 3)
+        self.assertEqual(task2.max_steps, 19)
+        self.assertEqual(task2.gamma, 1.0)
+        task2.init_from_file(temp_file)
+        self.assertLen(task2.trajectories[0], 13)
+        self.assertEqual(task2.max_steps, 9)
+        self.assertEqual(task2.gamma, 0.9)
+
+    def test_task_epochs_index_minusone(self):
+        """Test that the epoch index -1 means the last epoch and updates to it."""
+        obs = np.zeros((2,))
+        tr1 = rl_task.Trajectory(obs)
+        self._extend(tr1, new_observation=obs, done=True)
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
+        stream = task.trajectory_slice_stream(epochs=[-1], max_slice_length=1)
+        next_slice = next(stream)
+        np.testing.assert_equal(next_slice.observation, np.zeros((1, 2)))
+        task.collect_trajectories(policy=(lambda _: (0, 0)), n_trajectories=1)
+        next_slice = next(stream)
+        np.testing.assert_equal(next_slice.observation, np.ones((1, 2)))
+
+    def test_trajectory_slice_stream_shape(self):
+        """Test the shape yielded by the trajectory stream."""
+        obs = np.zeros((12, 13))
+        tr1 = rl_task.Trajectory(obs)
+        self._extend(tr1, new_observation=obs, done=True)
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
+        stream = task.trajectory_slice_stream(max_slice_length=1)
+        next_slice = next(stream)
+        self.assertEqual(next_slice.observation.shape, (1, 12, 13))
+
+    def test_trajectory_slice_stream_long_slice(self):
+        """Test the trajectory stream with slices of longer length."""
+        obs = np.zeros((12, 13))
+        tr1 = rl_task.Trajectory(obs)
+        self._extend(tr1, new_observation=obs)
+        self._extend(tr1, new_observation=obs, done=True)
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
+        stream = task.trajectory_slice_stream(max_slice_length=2)
+        next_slice = next(stream)
+        self.assertEqual(next_slice.observation.shape, (2, 12, 13))
+
+    def test_trajectory_slice_stream_sampling_uniform(self):
+        """Test whether the trajectory stream samples uniformly."""
+        # Long trajectory of 0s.
+        tr1 = rl_task.Trajectory(0)
+        for _ in range(100):
+            self._extend(tr1)
+        self._extend(tr1, new_observation=200, done=True)
+        # Short trajectory of 101.
+        tr2 = rl_task.Trajectory(101)
+        self._extend(tr2, new_observation=200, done=True)
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1, tr2], max_steps=9)
+
+        # Stream of both. Check that we're sampling by slice, not by trajectory.
+        stream = task.trajectory_slice_stream(max_slice_length=1)
+        slices = []
+        for _ in range(10):
+            next_slice = next(stream)
+            assert next_slice.observation.shape[0] == 1
+            slices.append(next_slice.observation[-1])
+        mean_obs = sum(slices) / float(len(slices))
+        # The average should be close to 1 when sampling uniformly from one
+        # hundred 0-observations and a single 101.
+        self.assertLess(mean_obs, 31)  # Sampling 101 even 3 times is unlikely.
+        self.assertLen(slices, 10)
+
+    def test_trajectory_slice_stream_sampling_by_trajectory(self):
+        """Test whether the trajectory stream samples by trajectory."""
+        # Long trajectory of 0s.
+        tr1 = rl_task.Trajectory(0)
+        for _ in range(100):
+            self._extend(tr1)
+        self._extend(tr1, new_observation=200, done=True)
+        # Short trajectory of 101.
+        tr2 = rl_task.Trajectory(101)
+        self._extend(tr2, new_observation=200, done=True)
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1, tr2], max_steps=9)
+
+        # Stream of both. Check that we're sampling by trajectory.
+        stream = task.trajectory_slice_stream(
+            max_slice_length=1, sample_trajectories_uniformly=True
+        )
+        slices = []
+        for _ in range(10):
+            next_slice = next(stream)
+            assert next_slice.observation.shape[0] == 1
+            slices.append(next_slice.observation[-1])
+        mean_obs = sum(slices) / float(len(slices))
+        # The average should be around 50, sampling from {0, 101} uniformly.
+        # Sampling 101 fewer than 2 times has low probability (but it is
+        # possible, so the test may be flaky).
+        self.assertGreater(mean_obs, 20)
+        self.assertLen(slices, 10)
+
+    def test_trajectory_slice_stream_margin(self):
+        """Test the trajectory stream with an added margin."""
+        tr1 = rl_task.Trajectory(0)
+        self._extend(tr1, new_observation=1)
+        self._extend(tr1, new_observation=1)
+        self._extend(
+            tr1, new_observation=1, action=1, dist_inputs=2, reward=3, done=True
+        )
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
+
+        # Stream of slices without the final state.
+        stream1 = task.trajectory_slice_stream(max_slice_length=4, margin=3)
+        got_done = False
+        for _ in range(20):
+            next_slice = next(stream1)
+            self.assertEqual(next_slice.observation.shape, (4,))
+            if next_slice.done[0]:
+                # In the slice, first we have the last timestep in the actual
+                # trajectory, so observation = 1.
+                # Then comes the first timestep in the margin, which has the final
+                # observation from the trajectory: observation = 1.
+                # The remaining timesteps have 0 observations.
+                np.testing.assert_array_equal(next_slice.observation, [1, 1, 0, 0])
+                # In the margin, done = True and mask = 0.
+                for i in range(1, next_slice.observation.shape[0]):
+                    self.assertTrue(next_slice.done[i])
+                    self.assertFalse(next_slice.mask[i])
+                got_done = True
+        # Assert that we got a done somewhere; otherwise the test is not triggered.
+        # Not getting a done has low probability (1/2^20), but it is possible,
+        # so the test may be flaky.
+        self.assertTrue(got_done)
+
+    def test_trajectory_batch_stream_propagates_env_info(self):
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=4)
+        stream = task.trajectory_batch_stream(batch_size=1, max_slice_length=4)
+        tr_slice = next(stream)
+        # control_mask = step % 2 == 0, discount_mask = step % 3 == 0.
+        np.testing.assert_array_equal(tr_slice.env_info.control_mask, [[1, 0, 1, 0]])
+        np.testing.assert_array_equal(tr_slice.env_info.discount_mask, [[1, 0, 0, 1]])
+
+    def test_trajectory_batch_stream_shape(self):
+        task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=10)
+        batch_stream = task.trajectory_batch_stream(
+            batch_size=3, min_slice_length=4, max_slice_length=4
+        )
+        batch = next(batch_stream)
+        self.assertEqual(batch.observation.shape, (3, 4, 2))
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/learning/reinforcement/training_test.py b/tests/learning/reinforcement/training_test.py
new file mode 100644
index 000000000..0c58be2f6
--- /dev/null
+++ b/tests/learning/reinforcement/training_test.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for RL training."""
+
+import functools
+import math
+import os
+import pickle
+
+import tensorflow as tf
+
+from absl.testing import absltest
+
+from trax import layers as tl
+from trax import models
+from trax import optimizers as opt
+from trax.learning.reinforcement import task as rl_task
+from trax.learning.reinforcement import training
+from trax.learning.supervised import lr_schedules
+from trax.utils import test_utils
+
+
+class TrainingTest(absltest.TestCase):
+    def setUp(self):
+        super().setUp()
+        test_utils.ensure_flag("test_tmpdir")
+        self._model_fn = functools.partial(
+            models.Policy,
+            body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
+                tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()
+            ),
+        )
+
+    def test_policy_gradient_smoke(self):
+        """Smoke test of PolicyGradient."""
+        task = rl_task.RLTask("CartPole-v0", max_steps=2)
+        tmp_dir = self.create_tempdir().full_path
+        agent = training.PolicyGradient(
+            task,
+            model_fn=self._model_fn,
+            optimizer=opt.Adam,
+            batch_size=2,
+            n_trajectories_per_epoch=2,
+            n_eval_episodes=1,
+            output_dir=tmp_dir,
+        )
+        agent.run(1)
+        self.assertEqual(agent.current_epoch, 1)
+
+    def test_expert_iteration_smoke(self):
+        """Smoke test of ExpertIteration."""
+        task = rl_task.RLTask("CartPole-v0", max_steps=2)
+        tmp_dir = self.create_tempdir().full_path
+        agent = training.ExpertIteration(
+            task,
+            model_fn=self._model_fn,
+            optimizer=opt.Adam,
+            batch_size=2,
+            n_trajectories_per_epoch=2,
+            n_train_steps_per_epoch=2,
+            n_eval_episodes=1,
+            output_dir=tmp_dir,
+        )
+        agent.run(1)
+        self.assertEqual(agent.current_epoch, 1)
+
+    def test_policy_gradient_save_restore(self):
+        """Check save and restore of the policy agent."""
+        task = rl_task.RLTask("CartPole-v0", max_steps=2)
+        tmp_dir = self.create_tempdir().full_path
+        agent1 = training.PolicyGradient(
+            task,
+            model_fn=self._model_fn,
+            optimizer=opt.Adam,
+            batch_size=2,
+            n_trajectories_per_epoch=2,
+            n_eval_episodes=1,
+            output_dir=tmp_dir,
+        )
+        agent1.run(1)
+        agent1.run(1)
+        self.assertEqual(agent1.current_epoch, 2)
+        self.assertEqual(agent1.loop.step, 2)
+        # Agent 2 starts where agent 1 stopped.
+        agent2 = training.PolicyGradient(
+            task,
+            model_fn=self._model_fn,
+            optimizer=opt.Adam,
+            batch_size=2,
+            n_trajectories_per_epoch=2,
+            n_eval_episodes=1,
+            output_dir=tmp_dir,
+        )
+        agent2.run(1)
+        self.assertEqual(agent2.current_epoch, 3)
+        self.assertEqual(agent2.loop.step, 3)
+        # Manually set the saved epoch to 1.
+        dictionary = {
+            "epoch": 1,
+            "avg_returns": [0.0],
+            "avg_returns_temperature_0.0": {200: [0.0]},
+        }
+        with tf.io.gfile.GFile(os.path.join(tmp_dir, "reinforcement.pkl"), "wb") as f:
+            pickle.dump(dictionary, f)
+
+        # Agent 3 restores from a checkpoint with an Agent/Loop step mismatch,
+        # so it should fail.
+        def agent3_fn():
+            return training.PolicyGradient(
+                task,
+                model_fn=self._model_fn,
+                optimizer=opt.Adam,
+                batch_size=2,
+                n_trajectories_per_epoch=2,
+                n_eval_episodes=1,
+                output_dir=tmp_dir,
+            )
+
+        self.assertRaises(ValueError, agent3_fn)
+        agent1.close()
+        agent2.close()
+
+    def test_policy_gradient_cartpole(self):
+        """Trains a policy on CartPole."""
+        task = rl_task.RLTask("CartPole-v0", max_steps=200)
+        lr = lambda: lr_schedules.multifactor(constant=1e-2, factors="constant")
+        max_avg_returns = -math.inf
+        for _ in range(2):
+            agent = training.PolicyGradient(
+                task,
+                model_fn=self._model_fn,
+                optimizer=opt.Adam,
+                lr_schedule=lr,
+                batch_size=128,
+                eval_temperatures=[0.0, 0.5],
+                n_eval_episodes=1,
+                n_trajectories_per_epoch=2,
+            )
+            # Assert that we get to 200 at some point and then exit so the test is as
+            # fast as possible.
+            for ep in range(200):
+                agent.run(1)
+                self.assertEqual(agent.current_epoch, ep + 1)
+                if agent.avg_returns[-1] == 200.0:
+                    for eval_t in agent._eval_temperatures:
+                        self.assertEqual(
+                            len(agent._avg_returns_temperatures[eval_t][200]),
+                            len(agent.avg_returns),
+                        )
+                    return
+            max_avg_returns = max(max_avg_returns, agent.avg_returns[-1])
+        self.fail(
+            "The expected score of 200 has not been reached. "
+            "Maximum at end was {}.".format(max_avg_returns)
+        )
+
+    def test_dqntrainer_cartpole(self):
+        """Test-runs DQN on CartPole."""
+
+        task = rl_task.RLTask("CartPole-v0", initial_trajectories=0, max_steps=2)
+        value_body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
+
+        lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
+            constant=1e-2, warmup_steps=100, factors="constant * linear_warmup"
+        )
+
+        trainer = training.DQN(
+            task,
+            value_body=value_body,
+            value_optimizer=opt.Adam,
+            value_lr_schedule=lr,
+            value_batch_size=4,
+            value_train_steps_per_epoch=2,
+            n_trajectories_per_epoch=5,
+        )
+        trainer.run(2)
+        self.assertEqual(2, trainer.current_epoch)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/learning/reinforcement/value_tasks_test.py b/tests/learning/reinforcement/value_tasks_test.py
new file mode 100644
index 000000000..c0e0b0c87
--- /dev/null
+++ b/tests/learning/reinforcement/value_tasks_test.py
@@ -0,0 +1,201 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
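In test_policy_gradient_save_restore above, the pickled agent epoch is rewound to 1 while the Loop checkpoint remains at step 3, and constructing a new agent is then expected to raise ValueError. A minimal sketch of that kind of consistency check, with hypothetical names (the real check lives inside trax's RL agent initialization, not in this helper):

def check_agent_loop_consistency(agent_epoch, loop_step, steps_per_epoch=1):
    # Hypothetical: with one training step per epoch, a restored agent epoch
    # must line up with the restored Loop step.
    if loop_step != agent_epoch * steps_per_epoch:
        raise ValueError(
            "Restored agent epoch {} does not match Loop step {}.".format(
                agent_epoch, loop_step
            )
        )

check_agent_loop_consistency(agent_epoch=3, loop_step=3)  # consistent, passes
# check_agent_loop_consistency(agent_epoch=1, loop_step=3)  # raises ValueError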
+ +"""Tests for trax.reinforcement.value_tasks.""" + +import numpy as np + +from absl.testing import absltest + +from trax import layers as tl +from trax import models +from trax import optimizers as opt +from trax.learning.reinforcement import ( + advantages, + distributions, + policy_tasks, + value_tasks, +) +from trax.learning.reinforcement import task as rl_task +from trax.learning.supervised import lr_schedules, training + + +class ValueTasksTest(absltest.TestCase): + def setUp(self): + super().setUp() + self._model_fn = lambda mode: tl.Serial( # pylint: disable=g-long-lambda + tl.Dense(64), tl.Relu(), tl.Dense(1) + ) + self._task = rl_task.RLTask( + "CartPole-v0", gamma=0.5, max_steps=10, initial_trajectories=100 + ) + self._trajectory_batch_stream = self._task.trajectory_batch_stream( + batch_size=256, epochs=[-1], max_slice_length=2 + ) + + def _value_error(self, value_fn): + errors = [] + for _ in range(10): + batch = next(self._trajectory_batch_stream) + values = value_fn(batch) + errors.append(np.mean((values - batch.return_) ** 2)) + return np.mean(errors) + + def test_value_tasks_smoke(self): + # Smoke test for train + eval. + train_model = self._model_fn(mode="train") + eval_model = self._model_fn(mode="eval") + train_task = value_tasks.ValueTrainTask( + self._trajectory_batch_stream, + optimizer=opt.Adam(), + lr_schedule=lr_schedules.constant(1e-3), + advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1), + model=train_model, + target_model=eval_model, + ) + eval_task = value_tasks.ValueEvalTask(train_task) + loop = training.Loop( + model=train_model, + eval_model=eval_model, + tasks=[train_task], + eval_tasks=[eval_task], + eval_at=(lambda _: True), + ) + loop.run(n_steps=1) + + def test_value_error_high_without_syncs(self): + train_model = self._model_fn(mode="train") + eval_model = self._model_fn(mode="eval") + train_task = value_tasks.ValueTrainTask( + self._trajectory_batch_stream, + optimizer=opt.Adam(), + lr_schedule=lr_schedules.constant(1e-3), + advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1), + model=train_model, + target_model=eval_model, + # Synchronize just once, at the end of training. + sync_at=(lambda step: step == 100), + ) + loop = training.Loop( + model=train_model, + eval_model=eval_model, + tasks=[train_task], + ) + + # Assert that before training, the error is high. + error_before = self._value_error(train_task.value) + self.assertGreater(error_before, 2.0) + + loop.run(n_steps=100) + + # Assert that after training, the error is smaller, but still high. + error_after = self._value_error(train_task.value) + + self.assertLess(error_after, 2.0) + self.assertGreater(error_after, 0.8) + + def test_value_error_low_with_syncs(self): + min_error = np.inf + for _ in range(5): + train_model = self._model_fn(mode="train") + eval_model = self._model_fn(mode="eval") + train_task = value_tasks.ValueTrainTask( + self._trajectory_batch_stream, + optimizer=opt.Adam(), + lr_schedule=lr_schedules.constant(1e-3), + advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1), + model=train_model, + target_model=eval_model, + # Synchronize often throughout training. + sync_at=(lambda step: step % 10 == 0), + ) + loop = training.Loop( + model=train_model, + eval_model=eval_model, + tasks=[train_task], + ) + + # Assert that before training, the error is high. + error_before = self._value_error(train_task.value) + self.assertGreater(error_before, 2.0) + + loop.run(n_steps=100) + + # Assert that after training, the error is small. 
+ error_after = self._value_error(train_task.value) + + if error_after < 0.8: + return + + min_error = min(min_error, error_after) + + self.fail(f"Even after 5 trials, min error_after({min_error}) is not < 0.8") + + def test_integration_with_policy_tasks(self): + # Integration test for policy + value training and eval. + optimizer = opt.Adam() + lr_schedule = lr_schedules.constant(1e-3) + advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1) + policy_dist = distributions.create_distribution(self._task.action_space) + body = lambda mode: tl.Dense(64) + train_model = models.PolicyAndValue(policy_dist, body=body) + eval_model = models.PolicyAndValue(policy_dist, body=body) + + head_selector = tl.Select([1]) + value_train_task = value_tasks.ValueTrainTask( + self._trajectory_batch_stream, + optimizer, + lr_schedule, + advantage_estimator, + model=train_model, + target_model=eval_model, + head_selector=head_selector, + ) + value_eval_task = value_tasks.ValueEvalTask( + value_train_task, head_selector=head_selector + ) + + # Drop the value head - just tl.Select([0]) would pass it, and it would + # override the targets. + head_selector = tl.Select([0], n_in=2) + policy_train_task = policy_tasks.PolicyTrainTask( + self._trajectory_batch_stream, + optimizer, + lr_schedule, + policy_dist, + advantage_estimator, + # Plug a trained critic as our value estimate. + value_fn=value_train_task.value, + head_selector=head_selector, + ) + policy_eval_task = policy_tasks.PolicyEvalTask( + policy_train_task, head_selector=head_selector + ) + + loop = training.Loop( + model=train_model, + eval_model=eval_model, + tasks=[policy_train_task, value_train_task], + eval_tasks=[policy_eval_task, value_eval_task], + eval_at=(lambda _: True), + # Switch the task every step. + which_task=(lambda step: step % 2), + ) + # Run for a couple of steps to make sure there are a few task switches. + loop.run(n_steps=10) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/supervised/callbacks_test.py b/tests/learning/supervised/callbacks_test.py new file mode 100644 index 000000000..d4efb3ca4 --- /dev/null +++ b/tests/learning/supervised/callbacks_test.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
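In test_integration_with_policy_tasks above, `which_task=(lambda step: step % 2)` makes the Loop alternate between the policy task and the value task on every step, while `value_fn=value_train_task.value` plugs the trained critic into the policy objective. A quick sketch of the alternation schedule, assuming only that which_task(step) returns the index of the task trained at that step:

# Indices into tasks=[policy_train_task, value_train_task].
which_task = lambda step: step % 2
schedule = [which_task(step) for step in range(1, 11)]
print(schedule)  # [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] - a task switch on every step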
+ +"""Tests for trax.supervised.callbacks.""" + +import functools +import io + +from unittest import mock + +import gym +import numpy as np + +from absl.testing import absltest, parameterized + +from tests.layers import test_utils as tl_test_utils +from trax import models +from trax.data.preprocessing import inputs +from trax.learning.reinforcement import serialization_utils, space_serializer +from trax.learning.supervised import callbacks, lr_schedules, trainer_lib, training +from trax.utils import test_utils + + +def random_inputs(seq_len, batch_size): + def stream_fn(num_devices): + del num_devices + while True: + x = np.random.uniform(size=(batch_size, seq_len)) + y = np.random.uniform(size=(batch_size, seq_len)) + mask = np.ones_like(x).astype(np.float32) + yield (x, y, x, mask) + + return inputs.Inputs( + train_stream=stream_fn, + eval_stream=stream_fn, + ) + + +def make_multibonacci_modulo(history_length, limit): + """Creates a function that generates the Multibonacci sequence modulo n.""" + + def sequence_fn(seq): + return np.sum(seq[-history_length:]) % limit + + return sequence_fn + + +def generate_trajectory(sequence_fn, space, n_steps): + """Generates random actions and observations that follow sequence_fn.""" + act = [space.sample() for _ in range(n_steps)] + obs = [space.sample()] + + for o, a in zip( + obs, + act[:-1], # Don't generate the last observation. + ): + context = list(np.array([o, a]).flatten()) + symbols = [] + for _ in range(np.array(o).size): + symbol = sequence_fn(context + symbols) + symbols.append(symbol) + obs.append(np.reshape(symbols, space.shape)) + + obs = np.array([obs]) + act = np.array([act]) + return (obs, act) + + +def make_singleton_eval_task(observations, actions): + """Creates an EvalTask with just one example.""" + mask = np.ones(observations.shape[:2]) + + def data(): + while True: + yield (observations, actions, observations, mask) + + return training.EvalTask( + labeled_data=data(), + metrics=[], + ) + + +def make_serialized_model(seq_model, space, vocab_size): + srl = space_serializer.create(space, vocab_size) + return serialization_utils.SerializedModel( + functools.partial(seq_model, vocab_size=vocab_size), + observation_serializer=srl, + action_serializer=srl, + significance_decay=0.7, + ) + + +class CallbacksTest(parameterized.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + @mock.patch("sys.stdout", new_callable=io.StringIO) + def test_serialized_model_evaluation(self, mock_stdout): + precision = 1 + vocab_size = 2 + srl = space_serializer.BoxSpaceSerializer( + space=gym.spaces.Box(shape=(), low=0.0, high=1.0), + vocab_size=vocab_size, + precision=precision, + ) + + def inner_model(mode): + return models.TransformerLM( + mode=mode, + vocab_size=vocab_size, + d_model=2, + d_ff=4, + n_layers=1, + n_heads=1, + ) + + serialized_model_fn = functools.partial( + serialization_utils.SerializedModel, + inner_model, + observation_serializer=srl, + action_serializer=srl, + significance_decay=0.7, + ) + eval_callback = functools.partial( + callbacks.SerializedModelEvaluation, eval_at=5 + ) + + output_dir = self.create_tempdir().full_path + trainer_lib.train( + output_dir=output_dir, + model=serialized_model_fn, + inputs=functools.partial(random_inputs, seq_len=4, batch_size=64), + lr_schedule_fn=functools.partial(lr_schedules.constant, 0.01), + callbacks=[eval_callback], + steps=10, + ) + self.assertTrue(_has_metric("pred_error", mock_stdout)) + + @parameterized.product( + context_lengths=((2,), (1, 3)), 
+ horizon_lengths=((1,), (1, 2)), + ) + def test_srl_eval_feeds_correct_sequence(self, context_lengths, horizon_lengths): + vocab_size = 10 + n_steps = 5 + + multibonacci_modulo = make_multibonacci_modulo(2, vocab_size) + space = gym.spaces.Discrete(n=vocab_size) + (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) + eval_task = make_singleton_eval_task(obs, act) + seq_model = functools.partial( + tl_test_utils.MockTransformerLM, + sequence_fn=multibonacci_modulo, + ) + serialized_model = make_serialized_model(seq_model, space, vocab_size) + callback = callbacks.SerializedModelEvaluation( + loop=None, + eval_task=eval_task, + model=serialized_model, + context_lengths=context_lengths, + horizon_lengths=horizon_lengths, + accelerate_model=False, + ) + callback.evaluate(weights=None) + + expected_seq = np.zeros(2 * n_steps + 1) + expected_seq[1::2] = obs + expected_seq[2::2] = act + seen_len = (context_lengths[-1] + horizon_lengths[-1]) * 2 + callback.predict_model.assert_prediction_buffers_equal( + [expected_seq[:seen_len]] + ) + + @parameterized.named_parameters(("one_symbol", 1), ("two_symbols", 2)) + def test_srl_eval_reports_zero_error_for_perfect_model(self, precision): + vocab_size = 100 + n_steps = 5 + + multibonacci_modulo = make_multibonacci_modulo(2 * precision, vocab_size) + space = gym.spaces.MultiDiscrete(nvec=([vocab_size] * precision)) + (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) + eval_task = make_singleton_eval_task(obs, act) + seq_model = functools.partial( + tl_test_utils.MockTransformerLM, + sequence_fn=multibonacci_modulo, + ) + serialized_model = make_serialized_model(seq_model, space, vocab_size) + callback = callbacks.SerializedModelEvaluation( + loop=None, + eval_task=eval_task, + model=serialized_model, + context_lengths=(1,), + horizon_lengths=(4,), + accelerate_model=False, + ) + metrics = callback.evaluate(weights=None) + error = next(value for (name, value) in metrics.items() if "pred_error" in name) + assert error == 0 + + +def _has_metric(metric_name, stdout): + log = stdout.getvalue() + metric_logs = [line for line in log.split("\n") if metric_name in line] + return bool(metric_logs) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/supervised/decoding_test.py b/tests/learning/supervised/decoding_test.py new file mode 100644 index 000000000..abd5c0408 --- /dev/null +++ b/tests/learning/supervised/decoding_test.py @@ -0,0 +1,541 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
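The callbacks tests above are built around make_multibonacci_modulo: each new symbol is the sum of the previous history_length symbols, modulo limit, so with history_length=2 the sequence degenerates to Fibonacci mod n. A quick usage example (assuming the helper from callbacks_test above is in scope; it uses np.sum internally):

sequence_fn = make_multibonacci_modulo(history_length=2, limit=10)
seq = [1, 1]
for _ in range(6):
    seq.append(int(sequence_fn(seq)))
print(seq)  # [1, 1, 2, 3, 5, 8, 3, 1] - Fibonacci modulo 10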
+ +"""Tests for decoding.""" + +import functools +import os + +import gin +import numpy as np + +from tensorflow.compat.v2 import test + +from tests.data.utils import ( # relative import + _CONFIG_DIR, + _SUPERVISED_TESTDATA, +) +from tests.fastmath.jax.config import config +from trax import fastmath, models +from trax import layers as tl +from trax.learning.supervised import decoding +from trax.utils import shapes + + +class DecodingTest(test.TestCase): + def test_autoregressive_sample_transformerlm(self): + model = models.TransformerLM( + 10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode="predict" + ) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + s1 = decoding.autoregressive_sample( + model, batch_size=1, eos_id=-1, max_length=10 + ) + self.assertEqual(s1.shape[0], 1) + self.assertEqual(s1.shape[1], 10) + batch_per_device = 2 // fastmath.local_device_count() + model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) + s2 = decoding.autoregressive_sample(model, batch_size=2, max_length=10) + self.assertEqual(s2.shape[0], 2) + self.assertLess(s2.shape[1], 11) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + prefix = np.array([[1, 2, 3]]) + s3 = decoding.autoregressive_sample( + model, prefix, eos_id=-1, max_length=10, batch_size=1 + ) + self.assertEqual(s3.shape[0], 1) + self.assertEqual(s3.shape[1], 10) + + def test_autoregressive_sample_transformerlm_tfnp(self): + with fastmath.use_backend(fastmath.Backend.TFNP): + model = models.TransformerLM( + 10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode="predict" + ) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + s1 = decoding.autoregressive_sample( + model, batch_size=1, eos_id=-1, max_length=10 + ) + self.assertEqual(s1.shape[0], 1) + self.assertEqual(s1.shape[1], 10) + batch_per_device = 2 // fastmath.local_device_count() + model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) + s2 = decoding.autoregressive_sample(model, batch_size=2, max_length=10) + self.assertEqual(s2.shape[0], 2) + self.assertLess(s2.shape[1], 11) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + prefix = np.array([[1, 2, 3]]) + s3 = decoding.autoregressive_sample( + model, prefix, eos_id=-1, max_length=10, batch_size=1 + ) + self.assertEqual(s3.shape[0], 1) + self.assertEqual(s3.shape[1], 10) + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def _pure_lsh_self_attention_fn(self, n_chunks_after=0): + return functools.partial( + tl.PureLSHSelfAttentionWrapper, + attention_dropout=0.0, + chunk_len=16, + n_buckets=[32, 32], + n_chunks_after=n_chunks_after, + n_chunks_before=1, + n_hashes=2, + n_parallel_heads=1, + max_length_for_buckets=1024, + predict_drop_len=128, + predict_mem_len=1024, + num_weights=2, + bias=False, + pure_lsh_implementation=tl.PureLSHSelfAttention, + ) + + def _timebin_self_attention_fn(self, use_reference_code=False, chunk_len=64): + return functools.partial( + tl.SelfAttention, + attention_dropout=0.05, + chunk_len=chunk_len, + n_chunks_before=1, + n_parallel_heads=1, + use_reference_code=use_reference_code, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def test_autoregressive_sample_reformerlm(self): + lsh_self_attention = self._lsh_self_attention_fn() + timebin_self_attention = self._timebin_self_attention_fn() + + model = 
models.ReformerLM( + vocab_size=256, + d_model=256, + d_ff=512, + d_attention_key=128, + d_attention_value=128, + n_layers=2, + n_heads=2, + dropout=0.05, + max_len=65536, + attention_type=[timebin_self_attention, lsh_self_attention], + pos_axial_shape=(256, 256), + pos_d_axial_embs=(128, 128), + ff_activation=tl.Relu, + ff_use_sru=0, + mode="predict", + ) + model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) + s1 = decoding.autoregressive_sample( + model, batch_size=1, eos_id=-1, max_length=10 + ) + self.assertEqual(s1.shape[0], 1) + self.assertEqual(s1.shape[1], 10) + + def test_autoregressive_sample_transformer(self): + model = models.Transformer( + 10, + d_model=32, + d_ff=64, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=2, + mode="predict", + ) + inputs = np.ones((1, 3), dtype=np.int32) + model.init( + (shapes.signature(inputs), shapes.ShapeDtype((1, 1), dtype=np.int32)) + ) + s = decoding.autoregressive_sample( + model, inputs=inputs, eos_id=-1, max_length=10 + ) + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_transformerlm_quality(self): + pred_model = models.TransformerLM( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_layers=2, + vocab_size=13, + mode="predict", + ) + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model_path = os.path.join(_SUPERVISED_TESTDATA, "transformerlm_copy.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape11, shape11) + ) + inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) + s = decoding.autoregressive_sample( + pred_model, inputs, max_length=6, temperature=0.0 + ) + self.assertEqual(str(s[0]), "[3 7 5 3 2 4]") + + def test_autoregressive_sample_transformerlm_quality_eval(self): + eval_model = models.TransformerLM( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_layers=2, + vocab_size=13, + mode="eval", + ) + model_path = os.path.join(_SUPERVISED_TESTDATA, "transformerlm_copy.pkl.gz") + eval_model.init_from_file(model_path) + inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) + s = decoding.autoregressive_sample( + eval_model, inputs, eval_mode=True, max_length=6, temperature=0.0 + ) + self.assertEqual(str(s[0]), "[3 7 5 3 2 4]") + + def test_autoregressive_sample_transformerlm_quality_beam(self): + pred_model = models.TransformerLM( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_layers=2, + vocab_size=13, + mode="predict", + ) + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model_path = os.path.join(_SUPERVISED_TESTDATA, "transformerlm_copy.pkl.gz") + pred_model.init_from_file( + model_path, weights_only=True, input_signature=(shape11, shape11) + ) + inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) + s = decoding.beam_search(pred_model, inputs, n_beams=3, max_length=6) + self.assertEqual(len(s), 3) # 3 beams + self.assertEqual(str(s[0][0][0]), "[3 7 5 3 2 4]") + self.assertEqual(str(s[1][0][0]), "[3 7 5 3 2 2]") # different from above + self.assertEqual(str(s[2][0][0]), "[3 7 5 3 3 2]") # different from above + + def test_autoregressive_sample_transformer_quality(self): + pred_model = models.Transformer( + d_model=64, + d_ff=128, + dropout=0.05, + max_len=256, + n_heads=2, + n_encoder_layers=2, + n_decoder_layers=2, + input_vocab_size=13, + mode="predict", + ) + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model_path = os.path.join(_SUPERVISED_TESTDATA, "transformer_copy.pkl.gz") + pred_model.init_from_file( + 
model_path, weights_only=True, input_signature=(shape11, shape11)
+        )
+        inputs = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32)
+        s = decoding.autoregressive_sample(
+            pred_model, inputs=inputs, eos_id=1, max_length=10, temperature=0.0
+        )
+        self.assertEqual(str(s[0]), "[3 7 5 3 2 4 1]")
+
+    def test_autoregressive_sample_terraformer_lsh(self):
+        max_len = 128
+
+        pred_model = models.ConfigurableTerraformer(
+            mode="predict",
+            d_model=256,
+            d_ff=512,
+            dropout=0.05,
+            max_len=max_len,
+            n_heads=4,
+            n_encoder_layers=1,
+            n_decoder_layers=1,
+            ff_use_sru=1,
+            d_attention_key=64,
+            d_attention_value=64,
+            encoder_attention_type=self._lsh_self_attention_fn(),
+            encoder_decoder_attention_type=self._lsh_self_attention_fn(),
+            input_vocab_size=256,
+            pos_axial_shape=None,
+        )
+
+        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
+        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
+        pred_model.init(input_signature=(shape1l, shape11))
+
+        # Input of the form 0w0w: a marker 0 followed by a word, repeated twice.
+        inputs = np.array(
+            [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32
+        )
+        inputs = np.pad(
+            inputs,
+            [(0, 0), (0, max_len - inputs.shape[1])],
+            mode="constant",
+            constant_values=0,
+        )
+        s = decoding.autoregressive_sample(
+            pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0
+        )
+
+        self.assertEqual(s.shape[0], 1)
+        self.assertEqual(s.shape[1], 10)
+
+    def test_autoregressive_sample_terraformer_lsh_attn_quality(self):
+        gin.add_config_file_search_path(_CONFIG_DIR)
+        max_len = 32  # 32 is the max length we trained the checkpoint for.
+        test_lengths = [8, 16, 32]
+        vocab_size = 13
+        # The checkpoint is correct on ~90% of sequences; set the random seed
+        # to deflake.
+        np.random.seed(0)
+        for test_len in test_lengths:
+            gin.clear_config()
+            gin.parse_config_file("terraformer_copy.gin")
+            gin.bind_parameter("LSHSelfAttention.predict_mem_len", 2 * max_len)
+            gin.bind_parameter("LSHSelfAttention.predict_drop_len", 2 * max_len)
+
+            pred_model = models.ConfigurableTerraformer(mode="predict")
+
+            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
+            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
+
+            model_path = os.path.join(
+                _SUPERVISED_TESTDATA, "terraformer_copy_lsh_attn.pkl.gz"
+            )
+            pred_model.init_from_file(
+                model_path, weights_only=True, input_signature=(shape1l, shape11)
+            )
+            initial_state = pred_model.state
+
+            for _ in range(2):  # Set low to make the test run reasonably fast.
+                # Pick a length in [1, test_len] at random.
+                inp_len = np.random.randint(low=1, high=test_len + 1)
+                inputs = np.random.randint(
+                    low=1, high=vocab_size - 1, size=(1, max_len)
+                )
+                # TODO(jaszczur): properly fix padding in terraformer predict mode,
+                # and add a test here.
+                s = decoding.autoregressive_sample(
+                    pred_model,
+                    inputs=inputs,
+                    eos_id=-1,
+                    max_length=inp_len,
+                    temperature=0.0,
+                )
+                np.testing.assert_equal(s[0], inputs[0, :inp_len])
+                pred_model.state = initial_state
+        gin.clear_config()  # Make sure to not affect other tests.
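All of the quality tests here decode with temperature=0.0, where autoregressive sampling reduces to greedy decoding: feed the argmax token back as the next input until max_length or EOS. A hedged standalone sketch of that loop (our names; trax's decoding.autoregressive_sample additionally handles batching, prefixes, and acceleration):

import numpy as np

def greedy_decode(next_token_logits_fn, start_token, max_length, eos_id=1):
    # next_token_logits_fn maps the last token, shape (1, 1), to logits over
    # the vocabulary, shape (1, vocab_size), updating internal state as it goes.
    token = np.array([[start_token]], dtype=np.int32)
    output = []
    for _ in range(max_length):
        logits = next_token_logits_fn(token)
        token = np.argmax(logits, axis=-1).reshape(1, 1).astype(np.int32)
        output.append(int(token[0, 0]))
        if output[-1] == eos_id:
            break
    return output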
+ + def test_autoregressive_sample_reformerlm_lsh(self): + max_len = 32 + + pred_model = models.ReformerLM( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=2 * max_len, + n_heads=4, + n_layers=3, + ff_use_sru=0, + d_attention_key=64, + d_attention_value=64, + attention_type=functools.partial( + tl.LSHSelfAttention, + chunk_len=16, + n_hashes=2, + n_buckets=[32, 32], + predict_drop_len=max_len, + predict_mem_len=max_len, + max_length_for_buckets=1024, + ), + vocab_size=13, + pos_type="fixed-base", + pos_d_axial_embs=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + pred_model.init(shape11) + + # 0w0 + inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) + inputs = np.pad( + inputs, + [(0, 0), (0, max_len - inputs.shape[1])], + mode="constant", + constant_values=0, + ) + s = decoding.autoregressive_sample( + pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0 + ) + + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_reformerlm_lsh_quality(self): + max_len = 32 + + pred_model = models.ReformerLM( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=2 * max_len, + n_heads=4, + n_layers=3, + ff_use_sru=0, + d_attention_key=64, + d_attention_value=64, + attention_type=functools.partial( + tl.LSHSelfAttention, + chunk_len=16, + n_hashes=2, + n_buckets=[32, 32], + predict_drop_len=max_len, + predict_mem_len=max_len, + max_length_for_buckets=1024, + ), + vocab_size=13, + pos_type="fixed-base", + pos_d_axial_embs=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + + model_path = os.path.join( + _SUPERVISED_TESTDATA, "reformerlm_copy_lsh_attn.pkl.gz" + ) + pred_model.init_from_file( + model_path, weights_only=True, input_signature=shape11 + ) + + # 0w0 + inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) + inp_len = inputs.shape[1] + s = decoding.autoregressive_sample( + pred_model, + inputs=inputs, + eos_id=-1, + max_length=inp_len - 2, + temperature=0.0, + ) + + np.testing.assert_equal(s[0], inputs[0, 1 : inp_len - 1]) + # pylint: enable=unreachable + + def test_autoregressive_sample_terraformer_pure_lsh(self): + max_len = 128 + + pred_model = models.ConfigurableTerraformer( + mode="predict", + d_model=256, + d_ff=512, + dropout=0.05, + max_len=max_len, + n_heads=4, + n_encoder_layers=1, + n_decoder_layers=1, + ff_use_sru=1, + d_attention_key=64, + d_attention_value=64, + encoder_attention_type=self._pure_lsh_self_attention_fn(n_chunks_after=1), + encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(), + input_vocab_size=256, + pos_axial_shape=None, + ) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) + pred_model.init(input_signature=(shape1l, shape11)) + + # 0w0w + inputs = np.array( + [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32 + ) + inputs = np.pad( + inputs, + [(0, 0), (0, max_len - inputs.shape[1])], + mode="constant", + constant_values=0, + ) + s = decoding.autoregressive_sample( + pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0 + ) + + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 10) + + def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self): + gin.add_config_file_search_path(_CONFIG_DIR) + max_len = 32 # 32 is the max length we trained the checkpoint for. 
+        test_lengths = [8, 16, 32]
+        vocab_size = 13
+        # The checkpoint is correct on ~90% of sequences; set the random seed
+        # to deflake.
+        np.random.seed(0)
+        for test_len in test_lengths:
+            gin.clear_config()
+            gin.parse_config_file("terraformer_purelsh_copy.gin")
+            gin.bind_parameter("PureLSHSelfAttention.predict_mem_len", 2 * max_len)
+            gin.bind_parameter("PureLSHSelfAttention.predict_drop_len", 2 * max_len)
+            gin.bind_parameter("PureLSHSelfAttentionWrapper.bias", False)
+            gin.bind_parameter("PureLSHSelfAttentionWrapper.num_weights", 2)
+
+            pred_model = models.ConfigurableTerraformer(mode="predict")
+
+            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
+            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
+
+            model_path = os.path.join(
+                _SUPERVISED_TESTDATA, "terraformer_purelsh_copy.pkl.gz"
+            )
+            pred_model.init_from_file(
+                model_path, weights_only=True, input_signature=(shape1l, shape11)
+            )
+            initial_state = pred_model.state
+
+            for _ in range(2):  # Set low to make the test run reasonably fast.
+                # Pick a length in [1, test_len] at random.
+                inp_len = np.random.randint(low=1, high=test_len + 1)
+                inputs = np.random.randint(
+                    low=1, high=vocab_size - 1, size=(1, max_len)
+                )
+                # TODO(jaszczur): properly fix padding in terraformer predict mode,
+                # and add a test here.
+                s = decoding.autoregressive_sample(
+                    pred_model,
+                    inputs=inputs,
+                    eos_id=-1,
+                    max_length=inp_len,
+                    temperature=0.0,
+                )
+
+                np.testing.assert_equal(s[0], inputs[0, :inp_len])
+                pred_model.state = initial_state
+        gin.clear_config()  # Make sure to not affect other tests.
+
+
+if __name__ == "__main__":
+    config.config_with_absl()
+    test.main()
diff --git a/tests/learning/supervised/decoding_timing_test.py b/tests/learning/supervised/decoding_timing_test.py
new file mode 100644
index 000000000..75b055287
--- /dev/null
+++ b/tests/learning/supervised/decoding_timing_test.py
@@ -0,0 +1,501 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Timing tests for decoding."""
+
+import copy
+import functools
+import gc
+import os
+import time
+
+import numpy as np
+import psutil
+
+from tensorflow.compat.v2 import test
+
+from tests.fastmath.jax.config import config
+from trax import fastmath, models
+from trax import layers as tl
+from trax.learning.supervised import decoding
+from trax.utils import shapes
+
+
+def _size_of_model(model):
+    def _size(x):
+        try:
+            return x.size
+        except Exception:  # pylint: disable=broad-except
+            return 0
+
+    sizes = fastmath.nested_map(_size, model.weights)
+    total_size = sum(fastmath.tree_flatten(sizes))
+    return total_size
+
+
+def _recurrent_delete(w):
+    if "delete" in dir(w):
+        # The object has a 'delete' method, so it is a DeviceArray or something
+        # similar, and we want to delete it.
+ w.delete() + elif isinstance(w, (list, tuple)): + for x in w: + _recurrent_delete(x) + elif isinstance(w, dict): + for x in w.values(): + _recurrent_delete(x) + else: + raise ValueError("Unknown type encountered in weights: {}".format(type(w))) + + +def _memory_usage(): + gc.collect() + return psutil.Process(os.getpid()).memory_info().rss + + +class DecodingTimingTest(test.TestCase): + def _terraformer_decoding_time(self, settings): + # Garbage collection influences the timing, so we turn it off. + gc.disable() + max_len = 16 + + def _self_attention_fn(): + return functools.partial( + tl.SelfAttention, + predict_drop_len=2 * max_len, + predict_mem_len=2 * max_len, + ) + + def _causal_attention_fn(): + attn_layer, attn_kwargs = settings["attn"] + return functools.partial( + attn_layer, max_inference_length=2 * max_len, **attn_kwargs + ) + + if settings["model"] == "terraformer": + pred_model = models.ConfigurableTerraformer( + mode="predict", + d_model=settings["d_model"], + d_ff=settings["d_ff"], + dropout=0.1, + max_len=max_len, + n_heads=settings["n_heads"], + n_encoder_layers=settings["encoder_layers"], + n_decoder_layers=settings["decoder_layers"], + encoder_attention_type=_self_attention_fn(), + encoder_decoder_attention_type=_causal_attention_fn(), + input_vocab_size=settings["vocab"], + ff_sparsity=settings["ff_sparsity"], + ff_use_sru=settings["ff_use_sru"], + ff_dropout=0.1, + # ff_chunk_size=1024, + # attention_chunk_size=1, + n_decoder_attention_layers=settings["attention_layers"], + loss_sparsity=settings["loss_sparsity"], + pos_axial_shape=None, + use_bfloat16=True, + ) + elif settings["model"] == "transformer": + pred_model = models.ConfigurableTransformer( + mode="predict", + d_model=settings["d_model"], + d_ff=settings["d_ff"], + dropout=0.1, + max_len=max_len, + n_heads=settings["n_heads"], + n_encoder_layers=settings["encoder_layers"], + n_decoder_layers=settings["decoder_layers"], + # encoder_attention_type=_self_attention_fn(), + encoder_decoder_attention_type=_causal_attention_fn(), + input_vocab_size=settings["vocab"], + ff_sparsity=settings["ff_sparsity"], + ff_use_sru=settings["ff_use_sru"], + # ff_dropout=0.1, + # ff_chunk_size=1024, + # attention_chunk_size=1, + # n_decoder_attention_layers=settings['attention_layers'], + loss_sparsity=settings["loss_sparsity"], + pos_axial_shape=None, + # enc_dec_attention_sparsity=settings['enc_dec_sparsity'], + # use_bfloat16=True, + ) + else: + assert False + # We put acceleration outside of autoregressive_sample_stream, because + # we want to have a separate run (separate input) for model compilation. + pred_model = tl.Accelerate(pred_model) + + shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) + pred_model.init(input_signature=(shape1l, shape11)) + original_state = copy.deepcopy(pred_model.state) + + inputs_warmup = np.zeros((1, max_len), dtype=np.int32) + inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len) + + # This is a warm-up run, for compilation. + result, current_time = [], time.time() + elapsed_warmup_times = [] + for index, sample in zip( + range(0, 4), + decoding.autoregressive_sample_stream( + pred_model, inputs_warmup, temperature=0.0, accelerate=False + ), + ): + del index # unused + result.append(sample[:, None]) # to be sure that the result is computed + + current_time, start_time = time.time(), current_time + elapsed_warmup_times.append(current_time - start_time) + + # This is a real decoding timing run that we measure. 
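+        # (elapsed_times[1] is later reported as the encoding time and
+        # elapsed_times[2:] as the per-token decoding times.)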
+ pred_model.state = original_state + result, current_time = [], time.time() + elapsed_times = [] + for index, sample in zip( + range(12), + decoding.autoregressive_sample_stream( + pred_model, inputs, temperature=0.0, accelerate=False + ), + ): + del index # unused + result.append(sample[:, None]) # to be sure that the result is computed + + current_time, start_time = time.time(), current_time + elapsed_times.append(current_time - start_time) + peak_memory = _memory_usage() + + if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]): + print( + "WARNING! High variance found in elapsed times! Settings: {} ; " + "elapsed times: {} ; Probably more warm-up steps should be used, " + "or model size should be increased.".format(settings, elapsed_times) + ) + # Check resulting shapes. + s = np.concatenate(result, axis=1) + self.assertEqual(s.shape[0], 1) + self.assertEqual(s.shape[1], 12) + model_size = int(_size_of_model(pred_model)) + + # We delete the model weights, because in some situations they won't be + # deleted automatically. + _recurrent_delete(pred_model.weights) + gc.enable() + return model_size, elapsed_times, peak_memory + + def test_autoregressive_sample_terraformer_timing(self): + template_to_use = "medium_transformer" + + settings_templates = { + # full model + # # 54B params + # 'full_model': { + # 'encoder_layers': 6, 'decoder_layers': 36, 'vocab': 32000, + # 'attention_layers': 2, + # 'd_ff': 64*1024, 'd_model': 96*96, 'n_heads': 96, + # 'ff_use_sru': (1, 64), 'ff_sparsity': (256, 32), + # 'loss_sparsity': 8, + # 'attn': (tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 64})}, + # 1/18 of model (1/6 of encoder, 1/18 of decoder, full vocab) + # 4B params + # 'big_terraformer': { + # 'model': 'terraformer', + # 'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000, + # 'attention_layers': 2, + # 'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96, + # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, + # 'attn': (tl.CausalAttention, {})}, + # 'big_transformer': { + # 'model': 'transformer', + # 'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000, + # 'attention_layers': 2, + # 'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96, + # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, + # 'attn': (tl.CausalAttention, {})}, + # medium model + # 275M params (only decoder) + "medium_transformer": { + "model": "transformer", + "encoder_layers": 2, + "decoder_layers": 24, + "vocab": 32000, + "attention_layers": 2, + "d_ff": 4 * 1024, + "d_model": 1024, + "n_heads": 16, + "ff_use_sru": 0, + "ff_sparsity": 0, + "loss_sparsity": 0, + "attn": (tl.CausalAttention, {}), + }, + # 'medium_terraformer': { + # 'model': 'terraformer', + # 'encoder_layers': 2, 'decoder_layers': 24, 'vocab': 32000, + # 'attention_layers': 2, + # 'd_ff': 4*1024, 'd_model': 1024, 'n_heads': 16, + # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, + # 'attn': (tl.CausalAttention, {})}, + } + + sweep_settings = { + # 'big_transformer': [ # for big + # dict(), # baseline + # {'ff_sparsity': (256, 32)}, # + Sparse FF + # {'attn': ( # + Sparse QKV + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 64}), + # 'd_ff': 64*1024, + # }, + # {'ff_sparsity': (256, 32), + # 'attn': ( # + Sparse FF+QKV + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 64}), + # 'd_ff': 64*1024, + # }, + # ], + "medium_transformer": [ # for medium + dict(), # baseline + { + "ff_sparsity": 64, + "attn": ( # Sparse 
FF+QKV + tl.MultiplicativeConvCausalAttention, + {"length_kernel_size": 3, "sparsity": 16}, + ), + "d_ff": 6 * 1024, + }, + # {'ff_sparsity': 64, # Sparse FF+QKV + Loss + # 'attn': ( + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 16}), + # 'd_ff': 6*1024, + # 'loss_sparsity': 4, + # }, + # {'attn': ( # Sparse QKV + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 16}), + # 'd_ff': 6*1024, + # }, + # {'loss_sparsity': 4}, # Sparse Loss + # {'ff_sparsity': 64}, # Sparse FF + # {'ff_sparsity': 128}, # + Sparse FF 128 + # APPENDIX below + # different loss layers + # {'loss_sparsity': 8}, + # {'loss_sparsity': 2}, + # {'loss_sparsity': 0}, + ], + # 'big_terraformer': [ # for big terraformer + # dict(), # baseline + # {'ff_sparsity': 64}, # + Sparse FF / Sparse FF 64 + # {'ff_sparsity': 64, + # 'attn': ( # + Sparse FF+QKV + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 16}), + # 'd_ff': 6*1024, + # }, + # {'ff_sparsity': 64, # + Sparse FF+QKV+Loss + # 'attn': ( + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 16}), + # 'd_ff': 6*1024, + # 'loss_sparsity': 4, + # }, + # ], + # 'medium_terraformer': [ # for medium terraformer + # {'ff_sparsity': 64, # + Sparse FF+QKV+Loss + # 'attn': ( + # tl.MultiplicativeConvCausalAttention, + # {'length_kernel_size': 3, 'sparsity': 16}), + # 'd_ff': 6*1024, + # 'loss_sparsity': 4, + # }, + # ], + } + + encoding_times = [] + decoding_times = [] + sizes = [] + memories = [] + messages = [] + for override_settings in sweep_settings[template_to_use]: + settings = copy.deepcopy(settings_templates[template_to_use]) + settings.update(override_settings) + + init_memory = _memory_usage() + size, elapsed_times, peak_memory = self._terraformer_decoding_time(settings) + + # TODO(jaszczur): Why is elapsed_times[0] always small? + encoding_time = elapsed_times[1] + decoding_time_10 = sum(elapsed_times[2:]) + + after_memory = _memory_usage() + model_memory_gigabytes = (peak_memory - init_memory) / 1024**3 + decoding_time_diff = (max(elapsed_times[2:]) - min(elapsed_times[2:])) / 2 + decoding_time_diff_percent = int( + decoding_time_diff / np.mean(elapsed_times) * 100 + ) + message = ( + "\n\n" + "Params: {}\n" + "Settings: {}\n" + "Override: {}\n" + "Init memory: {:.1f} GiB\n" + "Peak memory: {:.1f} GiB\n" + "After memory: {:.1f} GiB\n" + "Estimated model memory: {:.1f} GiB\n" + "Times for each step: {}\n" + "Time for encoding: {:.4f} s\n" + "Time for decoding 10 tokens: {:.4f} s +/- {} %\n" + "\n\n".format( + size, + settings, + override_settings, + init_memory / 1024**3, + peak_memory / 1024**3, + after_memory / 1024**3, + model_memory_gigabytes, + elapsed_times, + encoding_time, + decoding_time_10, + decoding_time_diff_percent, + ) + ) + print(message) + messages.append(message) + encoding_times.append(encoding_time) + decoding_times.append(decoding_time_10) + sizes.append(size) + memories.append(model_memory_gigabytes) + + print("Final results (recap):") + for message in messages: + print(message) + + # This is useful for copying results into a spreadsheet etc. + # for i in range(len(sweep_settings)): + # print('{}\t{}\t{}\t{:.1f}'.format( + # sizes[i], encoding_times[i], decoding_times[i], memories[i])) + + def test_loss_layer_timing(self): + all_settings = [ + # The first run is sometimes slower, less reliable. 
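+            # (Hence the baseline settings appear twice below, so the first
+            # measurement can serve as a warm-up.)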
+ { + "output": 32000, + "input": 2048, + "prob": None, + "type": None, + "sparsity": 0, + "lowrank": 0, + "use_bias": False, + }, + { + "output": 32000, + "input": 2048, + "prob": None, + "type": None, + "sparsity": 0, + "lowrank": 0, + "use_bias": False, + }, + { + "output": 32000, + "input": 2048, + "prob": None, + "type": "einsum", + "sparsity": 0, + "lowrank": 0, + "use_bias": False, + }, + { + "output": 32000, + "input": 2048, + "prob": None, + "type": "mult", + "sparsity": 2, + "lowrank": 0, + "use_bias": False, + }, + { + "output": 32000, + "input": 2048, + "prob": None, + "type": None, + "sparsity": 0, + "lowrank": 0, + "use_bias": True, + }, + { + "output": 32000, + "input": 2048, + "prob": None, + "type": "einsum", + "sparsity": 0, + "lowrank": 0, + "use_bias": True, + }, + { + "output": 32000, + "input": 2048, + "prob": None, + "type": "mult", + "sparsity": 2, + "lowrank": 0, + "use_bias": True, + }, + ] + + messages = [] + for settings in all_settings: + pred_model = tl.SparseDenseWithOptions( + n_units=settings["output"], + d_input=settings["input"], + sparsity_type=settings["type"], + sparsity=settings["sparsity"], + d_lowrank=settings["lowrank"], + prob_sparse=settings["prob"], + use_bias=settings["use_bias"], + mode="predict", + ) + pred_model = tl.Accelerate(pred_model) + + shape1l = shapes.ShapeDtype((1, settings["input"])) + pred_model.init(input_signature=shape1l) + inputs = np.ones((1, settings["input"])) + + total_time = 0.0 + for counter in range(-50, 100): + start_time = time.time() + y = pred_model(inputs) + self.assertEqual(y.shape, (1, settings["output"])) + elapsed_time = time.time() - start_time + if counter >= 0: + total_time += elapsed_time + + message = ( + "\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n" + % (_size_of_model(pred_model), settings, total_time) + ) + messages.append(message) + print(message) + + print("Final results (recap):") + for message in messages: + print(message) + + +if __name__ == "__main__": + config.config_with_absl() + test.main() diff --git a/tests/learning/supervised/history_test.py b/tests/learning/supervised/history_test.py new file mode 100644 index 000000000..1b28ff89e --- /dev/null +++ b/tests/learning/supervised/history_test.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.supervised.history.""" + +from absl.testing import absltest + +from trax.learning.supervised import history as trax_history + + +class HistoryTest(absltest.TestCase): + def test_unknown_mode(self): + history = trax_history.History() + history.append("train", "metric1", 1, 0.1) + self.assertEqual(history.get("unknown_mode", "metric1"), []) + + def test_unknown_metric(self): + history = trax_history.History() + history.append("train", "metric1", 1, 0.1) + self.assertEqual(history.get("train", "unknown_metric"), []) + + def test_serializer_and_deserializer(self): + history = trax_history.History() + history.append("train", "metric1", 1, 0.1) + json_object = history.to_dict() + history2 = trax_history.History.from_dict(json_object) + self.assertEqual(history2.get("train", "metric1"), [(1, 0.1)]) + + def test_modes(self): + history = trax_history.History() + history.append("train", "metric1", 1, 0.1) + history.append("test", "metric2", 2, 0.2) + self.assertEqual(history.modes, ["test", "train"]) + + def test_metrics_for_mode(self): + history = trax_history.History() + history.append("train", "metric1", 1, 0.1) + history.append("train", "metric2", 2, 0.2) + self.assertEqual(history.metrics_for_mode("train"), ["metric1", "metric2"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/supervised/lr_schedules_test.py b/tests/learning/supervised/lr_schedules_test.py new file mode 100644 index 000000000..7b33556f2 --- /dev/null +++ b/tests/learning/supervised/lr_schedules_test.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests of learning rate schedules.""" + +import math + +from absl.testing import absltest + +from trax.learning.supervised import lr_schedules + + +class LRFunctionsTest(absltest.TestCase): + def test_warmup(self): + lr_fn = lr_schedules.warmup(9, 0.01) + + # Linear warm-up. + self.assertAlmostEqual(0.001, lr_fn(1)) + self.assertAlmostEqual(0.002, lr_fn(2)) + self.assertAlmostEqual(0.005, lr_fn(5)) + self.assertAlmostEqual(0.009, lr_fn(9)) + + # Constant thereafter. + self.assertAlmostEqual(0.01, lr_fn(10)) + self.assertAlmostEqual(0.01, lr_fn(11)) + self.assertAlmostEqual(0.01, lr_fn(20)) + self.assertAlmostEqual(0.01, lr_fn(300)) + self.assertAlmostEqual(0.01, lr_fn(4000)) + + def test_constant(self): + lr_fn = lr_schedules.constant(0.02) + self.assertEqual(0.02, lr_fn(1)) + self.assertEqual(0.02, lr_fn(20)) + self.assertEqual(0.02, lr_fn(300)) + self.assertEqual(0.02, lr_fn(4000)) + self.assertEqual(0.02, lr_fn(50000)) + self.assertEqual(0.02, lr_fn(600000)) + self.assertEqual(0.02, lr_fn(7000000)) + self.assertEqual(0.02, lr_fn(80000000)) + self.assertEqual(0.02, lr_fn(900000000)) + + def test_warmup_and_rsqrt_decay(self): + lr_fn = lr_schedules.warmup_and_rsqrt_decay(24, 0.25) + + # Warm-up. 
+ self.assertAlmostEqual(0.01, lr_fn(1)) + self.assertAlmostEqual(0.02, lr_fn(2)) + self.assertAlmostEqual(0.23, lr_fn(23)) + self.assertAlmostEqual(0.24, lr_fn(24)) + + # Reciprocal square-root decay. + self.assertAlmostEqual(0.25 * (5 / math.sqrt(25)), lr_fn(25)) + self.assertAlmostEqual(0.25 * (5 / math.sqrt(26)), lr_fn(26)) + self.assertAlmostEqual(0.25 * (5 / math.sqrt(27)), lr_fn(27)) + self.assertAlmostEqual(0.25 * (5 / math.sqrt(300)), lr_fn(300)) + self.assertAlmostEqual(0.25 * (5 / math.sqrt(4000)), lr_fn(4000)) + self.assertAlmostEqual(0.25 * (5 / math.sqrt(50000)), lr_fn(50000)) + + def test_cosine_sawtooth(self): + tail_fn = lr_schedules._CosineSawtoothTail(180, min_value=0.1) + lr_fn = lr_schedules._BodyAndTail(0.3, tail_start=0, tail_fn=tail_fn) + + # First cycle + self.assertAlmostEqual(0.29998477, lr_fn(1)) + self.assertAlmostEqual(0.28660254, lr_fn(30)) + self.assertAlmostEqual(0.25, lr_fn(60)) + self.assertAlmostEqual(0.20, lr_fn(90)) + self.assertAlmostEqual(0.15, lr_fn(120)) + self.assertAlmostEqual(0.10001523, lr_fn(179)) + + # Second cycle + self.assertEqual(0.3, lr_fn(180)) + self.assertAlmostEqual(0.29998477, lr_fn(181)) + self.assertAlmostEqual(0.28660254, lr_fn(210)) + self.assertAlmostEqual(0.25, lr_fn(240)) + self.assertAlmostEqual(0.20, lr_fn(270)) + self.assertAlmostEqual(0.15, lr_fn(300)) + self.assertAlmostEqual(0.10001523, lr_fn(359)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/supervised/mnist_test.py b/tests/learning/supervised/mnist_test.py new file mode 100644 index 000000000..0bd8c7464 --- /dev/null +++ b/tests/learning/supervised/mnist_test.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test training an MNIST model 100 steps (saves time vs. 2000 steps).""" + +import io + +from unittest import mock + +from absl.testing import absltest + +from trax import layers as tl +from trax.data.loader.tf import base as dataset +from trax.data.preprocessing import inputs as preprocessing +from trax.learning.supervised import training +from trax.optimizers import adam + + +class MnistTest(absltest.TestCase): + @mock.patch("sys.stdout", new_callable=io.StringIO) + def test_train_mnist_single_task(self, mock_stdout): + """Train MNIST model a bit, to compare to other implementations.""" + mnist_model = _build_model(two_heads=False) + (task, eval_task) = _mnist_tasks() + training_session = training.Loop( + mnist_model, + tasks=[task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 20 == 0, + ) + + training_session.run(n_steps=100) + self.assertEqual(training_session.step, 100) + + # Assert that we reach at least 80% eval accuracy. 
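+        # (_read_metric, defined at the bottom of this file, parses the last
+        # logged value of the metric from the captured stdout.)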
+ self.assertGreater(_read_metric("WeightedCategoryAccuracy", mock_stdout), 0.8) + + @mock.patch("sys.stdout", new_callable=io.StringIO) + def test_train_mnist_multitask(self, mock_stdout): + """Train two-head MNIST model a bit, to compare to other implementations.""" + mnist_model = _build_model(two_heads=True) + # MNIST classification task. + (cls_task, cls_eval_task) = _mnist_tasks(head=tl.Select([0], n_in=2)) + (train_batches_stream, eval_batches_stream) = _mnist_brightness_dataset() + # Auxiliary brightness prediction task. + reg_task = training.TrainTask( + train_batches_stream, + tl.Serial(tl.Select([1]), tl.L2Loss()), + adam.Adam(0.001), + ) + reg_eval_task = training.EvalTask( + eval_batches_stream, + [tl.Serial(tl.Select([1]), tl.L2Loss())], + n_eval_batches=1, + metric_names=["L2"], + ) + training_session = training.Loop( + mnist_model, + tasks=[cls_task, reg_task], + eval_tasks=[cls_eval_task, reg_eval_task], + eval_at=lambda step_n: step_n % 20 == 0, + which_task=lambda step_n: step_n % 2, + ) + + training_session.run(n_steps=1_000) + self.assertEqual(training_session.step, 1_000) + + # Assert that we reach at least 80% eval accuracy on MNIST. + self.assertGreater(_read_metric("WeightedCategoryAccuracy", mock_stdout), 0.8) + # Assert that we get below 0.03 brightness prediction error. + self.assertLess(_read_metric("L2", mock_stdout), 0.03) + + +def _build_model(two_heads): + cls_head = tl.Dense(10) + if two_heads: + reg_head = tl.Dense(1) + heads = tl.Branch(cls_head, reg_head) + else: + heads = cls_head + return tl.Serial( + tl.Fn("ScaleInput", lambda x: x / 255), + tl.Flatten(), + tl.Dense(512), + tl.Relu(), + tl.Dense(512), + tl.Relu(), + heads, + ) + + +def _mnist_brightness_dataset(): + """Loads (and caches) a MNIST mean brightness data set.""" + train_stream = dataset.TFDS("mnist", keys=("image", "label"), train=True)() + eval_stream = dataset.TFDS("mnist", keys=("image", "label"), train=False)() + + train_data_pipeline = preprocessing.Serial( + lambda g: map( + lambda item: (lambda x, y: (x, (x / 255).mean().flatten()))(*item), g + ), + preprocessing.Batch(8), + preprocessing.AddLossWeights(), + ) + train_batches_stream = train_data_pipeline(train_stream) + + eval_data_pipeline = preprocessing.Serial( + lambda g: map( + lambda item: (lambda x, y: (x, (x / 255).mean().flatten()))(*item), g + ), + preprocessing.Batch(8), + preprocessing.AddLossWeights(), + ) + eval_batches_stream = eval_data_pipeline(eval_stream) + + return train_batches_stream, eval_batches_stream + + +def _mnist_tasks(head=None): + """Creates MNIST training and evaluation tasks. + + Args: + head: Adaptor layer to put before loss and accuracy layers in the tasks. + + Returns: + A pair (train_task, eval_task) consisting of the MNIST training task and the + MNIST evaluation task using cross-entropy as loss and accuracy as metric. 
+ """ + train_stream = dataset.TFDS("mnist", keys=("image", "label"), train=True)() + eval_stream = dataset.TFDS("mnist", keys=("image", "label"), train=False)() + + train_data_pipeline = preprocessing.Serial( + preprocessing.Batch(8), + preprocessing.AddLossWeights(), + ) + train_batches_stream = train_data_pipeline(train_stream) + + eval_data_pipeline = preprocessing.Serial( + preprocessing.Batch(8), + preprocessing.AddLossWeights(), + ) + eval_batches_stream = eval_data_pipeline(eval_stream) + + loss = tl.WeightedCategoryCrossEntropy() + accuracy = tl.WeightedCategoryAccuracy() + if head is not None: + loss = tl.Serial(head, loss) + accuracy = tl.Serial(head, accuracy) + task = training.TrainTask( + train_batches_stream, + loss, + adam.Adam(0.001), + ) + eval_task = training.EvalTask( + eval_batches_stream, + [loss, accuracy], + n_eval_batches=10, + metric_names=["CrossEntropy", "WeightedCategoryAccuracy"], + ) + return (task, eval_task) + + +def _read_metric(metric_name, stdout): + log = stdout.getvalue() + metric_log = [line for line in log.split("\n") if metric_name in line][-1] + return float(metric_log.strip().split(" ")[-1]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/learning/supervised/trainer_lib_test.py b/tests/learning/supervised/trainer_lib_test.py new file mode 100644 index 000000000..a888fa9d8 --- /dev/null +++ b/tests/learning/supervised/trainer_lib_test.py @@ -0,0 +1,581 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for trax.supervised.trainer_lib.""" + +import functools +import os + +import jax +import tensorflow.compat.v2 as tf + +from absl.testing import absltest, parameterized + +from tests.fastmath.jax.config import config +from trax import fastmath, models +from trax import layers as tl +from trax import optimizers as trax_opt +from trax.data.preprocessing import inputs as inputs_lib +from trax.fastmath import numpy as jnp +from trax.learning.supervised import lr_schedules as lr +from trax.learning.supervised import trainer_lib +from trax.tf import extensions as npe +from trax.tf import numpy as tf_np +from trax.utils import shapes as trax_shapes +from trax.utils import test_utils + + +def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)): + """Make trainer_lib.inputs.Inputs.""" + batch_size = 2 * jax.device_count() + + def input_stream(n_devices): + del n_devices + key = fastmath.random.get_prng(0) + while True: + keys = fastmath.random.split(key, 4) + key = keys[0] + inputs = fastmath.random.uniform(keys[1], [batch_size] + list(input_shape)) + targets = fastmath.random.randint( + keys[2], [batch_size], dtype=jnp.int32, minval=0, maxval=n_classes + ) + weights = fastmath.random.uniform(keys[3], [batch_size]) + if with_weights: + yield inputs, targets, weights + else: + yield inputs, targets + + def input_stream_masked(n_devices): + return inputs_lib.add_loss_weights(input_stream(n_devices)) + + return inputs_lib.Inputs(input_stream_masked) + + +def _test_inputs_lm(vocab_size, seq_len, per_device_batch_size=2): + """Make trainer_lib.inputs.Inputs for language model.""" + batch_size = per_device_batch_size * jax.device_count() + + def input_stream(_): + def make_batch(key): + return fastmath.random.randint( + key, [batch_size, seq_len], dtype=jnp.int32, minval=0, maxval=vocab_size + ) + + key = fastmath.random.get_prng(0) + while True: + keys = fastmath.random.split(key, 3) + key = keys[0] + inputs = make_batch(keys[1]) + targets = make_batch(keys[2]) + yield inputs, targets + + def input_stream_masked(n_devices): + return inputs_lib.add_loss_weights(input_stream(n_devices)) + + return inputs_lib.Inputs(input_stream_masked) + + +BACKENDS = [fastmath.Backend.JAX] +BACKENDS_AND_CONFIGS = [(fastmath.Backend.JAX, [("Simple", None)])] + + +def short_name(b): + if b == fastmath.Backend.JAX: + return "jax" + else: + return "tf" + + +def opt_name(opt): + if opt is None: + return "None" + return opt.__name__ + + +def _pure_lsh_self_attention_fn(n_chunks_after=0): + return functools.partial( + tl.PureLSHSelfAttentionWrapper, + attention_dropout=0.1, + chunk_len=16, + n_buckets=[32, 32], + n_chunks_after=n_chunks_after, + n_chunks_before=1, + n_hashes=2, + n_parallel_heads=1, + max_length_for_buckets=1024, + predict_drop_len=128, + predict_mem_len=1024, + num_weights=2, + bias=False, + pure_lsh_implementation=tl.PureLSHSelfAttention, + ) + + +def _mixed_lsh_self_attention_fn(n_chunks_after=0): + return functools.partial( + tl.PureLSHSelfAttentionWrapper, + attention_dropout=0.1, + chunk_len=16, + n_buckets=[32, 32], + n_chunks_after=n_chunks_after, + n_chunks_before=1, + n_hashes=2, + n_parallel_heads=1, + max_length_for_buckets=1024, + predict_drop_len=128, + predict_mem_len=1024, + num_weights=2, + bias=False, + pure_lsh_implementation=tl.MixedLSHSelfAttention, + ) + + +class TraxTest(parameterized.TestCase): + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super().__init__(methodName) + if npe.tpu_devices(): + # Initialize TPU for TF + resolver = 
tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local") + tf.tpu.experimental.initialize_tpu_system(resolver) + + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + self._old_is_allow_float64 = tf_np.is_allow_float64() + tf_np.set_allow_float64(False) + + def tearDown(self): + tf_np.set_allow_float64(self._old_is_allow_float64) + super().tearDown() + + def _test_train_eval_predict(self, backend, model_name="Simple", optimizer=None): + with fastmath.use_backend(backend): + # Prepare model and inputs + steps = 2 + eval_steps = 2 + + if model_name == "Simple": + n_classes = 4 + + # Adds Dropout and BatchNorm to test state handling. + def model_fn(mode="train"): + return tl.Serial( + tl.Dropout(mode=mode, rate=0.1), + tl.BatchNorm(mode=mode), + models.MLP(layer_widths=(16, 16, n_classes), mode=mode), + ) + + inputs = _test_inputs(n_classes) + n_in = 1 + elif model_name == "Resnet50": + n_classes = 4 + model_fn = models.Resnet50 + inputs = _test_inputs(n_classes, input_shape=(224, 224, 3)) + n_in = 1 + elif model_name == "Transformer": + vocab_size = 32 + seq_len = 16 + inputs = _test_inputs_lm(vocab_size, seq_len) + model_fn = functools.partial( + models.Transformer, input_vocab_size=vocab_size + ) + n_in = 2 + else: + raise ValueError("Unrecognized model name: " + model_name) + + kwargs = {} + if optimizer is not None: + kwargs["optimizer"] = optimizer + + # Train and evaluate + output_dir = self.create_tempdir().full_path + loop = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, # eval at every step. + **kwargs, + ) + + # Assert total train steps + self.assertEqual(steps, loop.step) + + inputs = inputs.train_stream(1) + + # Predict with final weights + model = model_fn() + weights = loop.model.weights + state = loop.model.state + model(next(inputs)[:n_in], weights=weights, state=state) + + # Predict with weights loaded from file. 
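+            # (init_from_file reloads the checkpoint written above by
+            # trainer_lib.train into a freshly constructed model.)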
+ model = model_fn() + model.init_from_file(os.path.join(output_dir, "model.pkl.gz")) + model(next(inputs)[:n_in]) + + @parameterized.named_parameters( + ( + "_%s_%s_%s" + % ( + short_name(backend), + model_name, + opt_name(opt), + ), # pylint: disable=g-complex-comprehension + backend, + model_name, + opt, + ) + for backend, configs in BACKENDS_AND_CONFIGS + for model_name, opt in configs + ) + def test_train_eval_predict(self, backend, model_name, opt): + self._test_train_eval_predict(backend, model_name, opt) + + @parameterized.parameters(BACKENDS) + def test_train_eval_predict_sm3(self, backend): + self._test_train_eval_predict(backend, "Simple", trax_opt.SM3) + + @parameterized.parameters(BACKENDS) + def test_train_restart(self, backend): + with fastmath.use_backend(backend): + # Prepare model and inputs + n_classes = 4 + steps = 2 + eval_steps = 2 + model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) + inputs = _test_inputs(n_classes) + + # Train and evaluate + output_dir = self.create_tempdir().full_path + trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + ) + + # Restart training + loop = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=(2 * steps), + eval_steps=eval_steps, + eval_frequency=1, + ) + + # Assert total train steps + self.assertEqual(loop.step, 2 * steps) + + @parameterized.parameters(BACKENDS) + def test_train_permanent_checkpoints(self, backend): + with fastmath.use_backend(backend): + # Prepare model and inputs + n_classes = 4 + steps = 5 + eval_steps = 2 + model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) + inputs = _test_inputs(n_classes) + + # Train and evaluate + output_dir = self.create_tempdir().full_path + + # Steps 1 -> 5 + loop = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + permanent_checkpoint_frequency=2, + ) + + # Steps 6 -> 10 + loop = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=(2 * steps), + eval_steps=eval_steps, + eval_frequency=1, + permanent_checkpoints_at=[7, 8, 10], + ) + + path = os.path.join(output_dir, "model.pkl.gz") + self.assertTrue(tf.io.gfile.exists(path)) + + for step in range(11): + filename = "model_{}.pkl.gz".format(step) + path = os.path.join(output_dir, filename) + if step in [1, 2, 4, 7, 8, 10]: + self.assertTrue( + tf.io.gfile.exists(path), + msg="No model for step: {} in dir {}.".format( + step, tf.io.gfile.listdir(output_dir) + ), + ) + else: + self.assertFalse( + tf.io.gfile.exists(path), + msg="Model for step: {} in dir {}.".format( + step, tf.io.gfile.listdir(output_dir) + ), + ) + + # Assert total train steps + self.assertEqual(loop.step, 10) + + @parameterized.parameters(BACKENDS) + def test_train_restart_with_same_steps(self, backend): + with fastmath.use_backend(backend): + # Prepare model and inputs + n_classes = 4 + steps = 2 + eval_steps = 2 + model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) + inputs = _test_inputs(n_classes) + + # Train and evaluate + output_dir = self.create_tempdir().full_path + trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + ) + + # Restart training + loop = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + ) + + # Assert total train steps + 
self.assertEqual(loop.step, steps) + + def test_train_with_pure_lsh_attention(self, backend=fastmath.Backend.JAX): + with fastmath.use_backend(backend): + # Prepare model and inputs + def model(mode="train"): + return models.ConfigurableTerraformer( + mode=mode, + d_model=16, + d_ff=16, + n_heads=2, + dropout=0.05, + n_decoder_layers=1, + n_encoder_layers=1, + input_vocab_size=256, + encoder_attention_type=_pure_lsh_self_attention_fn(), + encoder_decoder_attention_type=_pure_lsh_self_attention_fn(), + ) + + max_len = 128 + inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len) + + steps = 1 + eval_steps = 1 + + # Train and evaluate + output_dir = self.create_tempdir().full_path + trainer_lib.train( + output_dir, + model=model, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + ) + + # Read checkpoint + model_file = os.path.join(output_dir, "model.pkl.gz") + + shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32) + shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32) + + model_predict = model(mode="predict") + model_predict.init_from_file( + model_file, weights_only=True, input_signature=(shape1l, shape11) + ) + + def test_train_with_mixed_lsh_attention(self, backend=fastmath.Backend.JAX): + with fastmath.use_backend(backend): + # Prepare model and inputs + + def model(mode="train"): + return models.ConfigurableTerraformer( + mode=mode, + d_model=16, + d_ff=16, + n_heads=2, + dropout=0.05, + n_decoder_layers=1, + n_encoder_layers=1, + input_vocab_size=256, + encoder_attention_type=_mixed_lsh_self_attention_fn(), + encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(), + ) + + max_len = 128 + inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len) + + steps = 1 + eval_steps = 1 + + # Train and evaluate + output_dir = self.create_tempdir().full_path + trainer_lib.train( + output_dir, + model=model, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + ) + + # Read checkpoint + model_file = os.path.join(output_dir, "model.pkl.gz") + + shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32) + shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32) + + model_predict = model(mode="predict") + model_predict.init_from_file( + model_file, weights_only=True, input_signature=(shape1l, shape11) + ) + + @parameterized.parameters(BACKENDS) + def test_train_fills_in_missing_eval_metrics(self, backend): + with fastmath.use_backend(backend): + # Prepare model and inputs + n_classes = 4 + steps = 2 + eval_steps = 2 + model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) + inputs = _test_inputs(n_classes) + additional_eval_stream = trainer_lib.NamedStream( + # deliberately duplicating eval data + stream=inputs.eval_stream(1), + name="additional_eval_task", + ) + + # Train and evaluate + output_dir = self.create_tempdir().full_path + loop = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + eval_frequency=1, + additional_eval_streams=[additional_eval_stream], + ) + + self.assertLen(loop.eval_tasks, 2) + eval_task_1, eval_task_2 = loop.eval_tasks + self.assertCountEqual(eval_task_1.metrics, eval_task_2.metrics) + self.assertCountEqual(eval_task_1.metric_names, eval_task_2.metric_names) + + @parameterized.named_parameters( + ("_%s" % short_name(backend), backend) for backend in BACKENDS + ) + def test_train_with_weights(self, backend): + with fastmath.use_backend(backend): + # Prepare model and inputs + n_classes = 4 + steps = 2 + eval_steps = 
2 + model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) + inputs = _test_inputs(n_classes, with_weights=True) + + # Train and evaluate + output_dir = self.create_tempdir().full_path + state = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + steps=steps, + eval_steps=eval_steps, + ) + + # Assert total train steps + self.assertEqual(state.step, steps) + + @parameterized.parameters(BACKENDS) + def test_reset_twice(self, backend): + with fastmath.use_backend(backend): + n_classes = 4 + model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) + inputs = _test_inputs(n_classes) + + trainer = trainer_lib.Trainer( + model=model_fn, + loss_fn=tl.WeightedCategoryCrossEntropy(), + optimizer=trax_opt.SM3, + lr_schedule=lr.multifactor(), + inputs=inputs, + ) + + output_dir1 = self.create_tempdir(name="output_dir1").full_path + trainer.reset(output_dir1) + trainer.evaluate(1) + output_dir2 = self.create_tempdir(name="output_dir2").full_path + trainer.reset(output_dir2) + trainer.evaluate(1) + + def test_tf_xla_forced_compile(self): + # TODO(wangpeng): re-enable this test + self.skipTest("Needs --config=cuda to pass this test") + old_flag = fastmath.tf.tf_xla_forced_compile_enabled() + fastmath.tf.set_tf_xla_forced_compile(True) + self._test_train_eval_predict("tf") + fastmath.tf.set_tf_xla_forced_compile(old_flag) + + +class EpochsTest(absltest.TestCase): + def test_cuts_epoch_when_total_steps_reached(self): + epoch_steps = trainer_lib.epochs( + total_steps=5, steps_to_skip=0, epoch_steps=[1, 2, 3] + ) + self.assertEqual(list(epoch_steps), [1, 2, 2]) + + def test_skips_full_epoch(self): + epoch_steps = trainer_lib.epochs( + total_steps=4, steps_to_skip=2, epoch_steps=[2, 2] + ) + self.assertEqual(list(epoch_steps), [2]) + + def test_skips_part_of_epoch(self): + epoch_steps = trainer_lib.epochs( + total_steps=4, steps_to_skip=1, epoch_steps=[2, 2] + ) + self.assertEqual(list(epoch_steps), [1, 2]) + + +if __name__ == "__main__": + config.config_with_absl() + tf.compat.v1.enable_eager_execution() + absltest.main() diff --git a/tests/learning/supervised/training_test.py b/tests/learning/supervised/training_test.py new file mode 100644 index 000000000..b8f98b8fe --- /dev/null +++ b/tests/learning/supervised/training_test.py @@ -0,0 +1,760 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for supervised training: core classes and flows.""" + +import collections +import os +import time + +import numpy as np + +from absl.testing import absltest + +from tests.fastmath.jax.config import config +from trax import data, fastmath, optimizers +from trax import layers as tl +from trax.layers import base +from trax.learning.supervised import callbacks, training +from trax.models import transformer +from trax.utils import shapes, test_utils + + +class TrainingTest(absltest.TestCase): + def setUp(self): + super().setUp() + test_utils.ensure_flag("test_tmpdir") + + def test_loop_no_eval_task(self): + """Runs a training loop with no eval task(s).""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + training_session = training.Loop(model, [task]) + # Loop should initialize and run successfully, even with no eval task. + training_session.run(n_steps=5) + + def test_loop_checkpoint_low_metric(self): + """Runs a training loop that saves checkpoints for low metric values.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_metric = tl.L2Loss() + eval_task = training.EvalTask( + _very_simple_data(), [eval_metric], metric_names=["l2_loss"] + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + output_dir=tmp_dir, + eval_at=lambda step_n: step_n % 2 == 0, + checkpoint_at=lambda step_n: step_n % 2 == 0, + checkpoint_low_metric="l2_loss", + ) + call_counter = collections.Counter() + loop.save_checkpoint = lambda name: call_counter.update([name]) + loop.run(n_steps=10) + + # Eval metric steadily descends, so low checkpoint triggered all 5 times. + # High checkpoint not defined, so never triggered. + self.assertEqual(call_counter["model"], 5) + self.assertEqual(call_counter["lowest_l2_loss"], 5) + self.assertEqual(call_counter["highest_l2_loss"], 0) + + def test_loop_checkpoint_high_metric(self): + """Runs a training loop that saves checkpoints for high metric values.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_metric = tl.L2Loss() + eval_task = training.EvalTask( + _very_simple_data(), [eval_metric], metric_names=["l2_loss"] + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + output_dir=tmp_dir, + eval_at=lambda step_n: step_n % 2 == 0, + checkpoint_at=lambda step_n: step_n % 2 == 0, + checkpoint_high_metric="l2_loss", + ) + call_counter = collections.Counter() + loop.save_checkpoint = lambda name: call_counter.update([name]) + loop.run(n_steps=10) + + # Eval metric steadily descends, so high checkpoint triggered only once. + # Low checkpoint not defined, so never triggered. 
+ self.assertEqual(call_counter["model"], 5) + self.assertEqual(call_counter["lowest_l2_loss"], 0) + self.assertEqual(call_counter["highest_l2_loss"], 1) + + def test_train_dense_layer(self): + """Trains a very simple network on a very simple task.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + training_session = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=15) + self.assertEqual(15, training_session.step) + training_session.run(n_steps=5) + self.assertEqual(20, training_session.step) + + def test_loop_with_initialized_model(self): + """Check that loop does not re-initialize an already initialized model.""" + model = tl.Serial(tl.Dense(1)) + example_data = next(_very_simple_data()) + model.init(example_data) + w = model.weights[0][0] + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + ) + self.assertEqual(0, loop.step) + self.assertEqual(loop.model.weights[0][0], w) + + def test_train_save_restore_dense(self): + """Saves and restores a checkpoint to check for equivalence.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + train_data = data.Serial( + lambda _: _very_simple_data(), data.CountAndSkip("simple_data") + ) + task = training.TrainTask(train_data(), tl.L2Loss(), optimizers.Adam(0.0001)) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + + def _make_model_and_session(): + m = tl.Serial(tl.Dense(1)) + ts = training.Loop( + m, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + return m, ts + + model, training_session = _make_model_and_session() + self.assertEqual(0, training_session.step) + training_session.run(n_steps=1) + training_session.save_checkpoint("model") + self.assertEqual(data.inputs.data_counters["simple_data"], 2) + data.inputs.data_counters["simple_data"] = 0 # reset manually + self.assertEqual(data.inputs.data_counters["simple_data"], 0) # check + model2, training_session2 = _make_model_and_session() + self.assertEqual(data.inputs.data_counters["simple_data"], 2) # restored + + x = np.ones((8, 1)) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertEqual(str(y1), str(y2)) + + training_session2.run(n_steps=1) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertNotEqual(str(y1), str(y2)) + + slots1 = training_session._trainer_per_task[0].slots + slots2 = training_session2._trainer_per_task[0].slots + np.testing.assert_array_equal(slots1, slots2) + + def test_train_save_restore_sharded(self): + """Saves and restores a sharded checkpoint to check for equivalence.""" + if fastmath.local_device_count() < 2: + return # multi-accelerator only + base.N_WEIGHTS_SHARDS = 
fastmath.local_device_count() + train_data = data.Serial( + lambda _: _very_simple_data(2, 2), data.CountAndSkip("simple_data") + ) + task = training.TrainTask(train_data(), tl.L2Loss(), optimizers.Adam(0.0001)) + eval_task = training.EvalTask( + _very_simple_data(2, 2), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + + def _make_model_and_session(): + m = tl.Serial(tl.Dense(2)) + ts = training.Loop( + m, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + return m, ts + + _, training_session = _make_model_and_session() + self.assertEqual(0, training_session.step) + training_session.run(n_steps=1) + training_session.save_checkpoint("model") + _, training_session2 = _make_model_and_session() + training_session2.run(n_steps=1) + base.N_WEIGHTS_SHARDS = 1 + + def test_train_save_restore_transformer(self): + """Saves and restores a checkpoint to check for equivalence.""" + vocab_size = 8 + task = training.TrainTask( + _very_simple_transformer_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_transformer_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + + def _make_model_and_session(): + m = transformer.TransformerLM( + vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.0 + ) + ts = training.Loop( + m, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + return m, ts + + model, training_session = _make_model_and_session() + self.assertEqual(0, training_session.step) + training_session.run(n_steps=1) + training_session.save_checkpoint("model") + model2, training_session2 = _make_model_and_session() + + x = np.ones((2, 2)).astype(np.int32) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertEqual(str(y1), str(y2)) + + training_session2.run(n_steps=1) + y1 = model(x, rng=fastmath.random.get_prng(0)) + y2 = model2(x, rng=fastmath.random.get_prng(0)) + self.assertNotEqual(str(y1), str(y2)) + + def test_train_dense_layer_with_momentum(self): + """Trains with an optimizer that has slots / requires initialization.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Momentum(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["Momentum.L2Loss"], + ) + training_session = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=20) + self.assertEqual(20, training_session.step) + + def test_train_dense_layer_evals(self): + """Trains a very simple network on a very simple task, 2 epochs.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), + [tl.L2Loss()], # deliberately re-using training data + ) + training_session = training.Loop( + model, [task], eval_tasks=[eval_task], eval_at=lambda step_n: False + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=10) + self.assertEqual(10, training_session.step) + training_session.run_evals() + self.assertEqual(10, training_session.step) # Unchanged + 
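+    # The next test relies on Loop writing TensorBoard summaries under
+    # <output_dir>/train and <output_dir>/eval, one new events file per run()
+    # call; the file counts below check exactly that.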
+ def test_summaries_are_written(self): + """Training writes down metrics when writing is turned on.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + metric_names=["SGD.L2Loss"], + ) + tmp_dir = self.create_tempdir().full_path + training_session = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + expected_train_metric_dir = os.path.join(tmp_dir, "train") + expected_eval_metric_dir = os.path.join(tmp_dir, "eval") + for directory in [expected_train_metric_dir, expected_eval_metric_dir]: + self.assertFalse( + os.path.isdir(directory), "Failed for directory %s." % directory + ) + training_session.run(n_steps=15) + time.sleep(1) # wait for the files to be closed + for directory in [expected_train_metric_dir, expected_eval_metric_dir]: + self.assertTrue( + os.path.isdir(directory), "Failed for directory %s." % directory + ) + self.assertEqual( + 1, _count_files(directory), "Failed for directory %s." % directory + ) + training_session.run(n_steps=5) + time.sleep(1) # wait for the files to be closed + for directory in [expected_train_metric_dir, expected_eval_metric_dir]: + self.assertEqual( + 2, _count_files(directory), "Failed for directory %s." % directory + ) + + def test_restores_step(self): + """Training restores step from directory where it saved it.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + loop2 = training.Loop(model, [task], output_dir=tmp_dir) + self.assertEqual(4, loop2.step) + + def test_restores_memory_efficient_from_standard(self): + """Training restores step from directory where it saved it.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + model = tl.Serial(tl.Dense(4), tl.Dense(1)) + task_std = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Adam(0.0001) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task_std], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + task_memeff = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Adam + ) + loop2 = training.Loop( + model, [task_memeff], output_dir=tmp_dir, use_memory_efficient_trainer=True + ) + loop2.run(2) + self.assertEqual(6, loop2.step) + + def test_restores_from_smaller_model(self): + """Training restores from a checkpoint created with smaller model.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + model1 = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.Adam(0.01) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model1, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(2) + model2 = tl.Serial(tl.Dense(1), tl.Dense(1)) + loop2 = training.Loop(model2, [task], output_dir=tmp_dir) + self.assertEqual(2, loop2.step) + + def test_restore_fails_different_model(self): + """Training restores from a checkpoint created with smaller model.""" + model1 = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), 
optimizers.SGD(0.01) + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model1, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(2) + model2 = tl.Serial(tl.Dense(2)) + with self.assertRaises(IndexError): + training.Loop(model2, [task], output_dir=tmp_dir) + + def test_restores_step_bfloat16(self): + """Training restores step from directory where it saved it, w/ bfloat16.""" + self.skipTest("Broken by https://github.com/google/jax/pull/11234") + model = tl.Serial(tl.Dense(1, use_bfloat16=True)) + # We'll also use Adafactor with bfloat16 to check restoring bfloat slots. + opt = optimizers.Adafactor(0.01, do_momentum=True, momentum_in_bfloat16=True) + task = training.TrainTask(_very_simple_data(), tl.L2Loss(), opt) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + loop2 = training.Loop(model, [task], output_dir=tmp_dir) + self.assertEqual(4, loop2.step) + loop2.run(2) # check that continued training works + self.assertEqual(6, loop2.step) + + def test_restores_step_sharded(self): + """Training restores step from directory where it saved it, sharded.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask(_very_simple_data(), tl.L2Loss(), optimizers.SGD) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + use_memory_efficient_trainer=True, + ) + loop.run(4) + loop2 = training.Loop( + model, [task], output_dir=tmp_dir, use_memory_efficient_trainer=True + ) + self.assertEqual(4, loop2.step) + + def test_restores_step_sharded_bfloat16(self): + """Training restores step from where it saved it, sharded and bfloat16.""" + model = tl.Serial(tl.Dense(1, use_bfloat16=True)) + task = training.TrainTask(_very_simple_data(), tl.L2Loss(), optimizers.SGD) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + use_memory_efficient_trainer=True, + ) + loop.run(4) + loop2 = training.Loop( + model, [task], output_dir=tmp_dir, use_memory_efficient_trainer=True + ) + self.assertEqual(4, loop2.step) + loop2.run(2) # check that continued training works + self.assertEqual(6, loop2.step) + + def test_restores_history(self): + """Training restores history from directory where it saved it.""" + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), + [tl.L2Loss()], # deliberately re-using training data + ) + tmp_dir = self.create_tempdir().full_path + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n % 2 == 0, + checkpoint_at=lambda step_n: step_n % 2 == 0, + output_dir=tmp_dir, + ) + loop.run(4) + loop2 = training.Loop(model, [task], output_dir=tmp_dir) + self.assertLen(loop2.history.modes, 2) + self.assertLen(loop2.history.metrics_for_mode("train"), 6) + self.assertLen(loop2.history.metrics_for_mode("eval"), 1) + for mode, metric in [ + ("train", "metrics/L2Loss"), + ("train", "training/learning_rate"), + ("train", "training/steps per second"), + ("train", "training/gradients_l2"), + ("train", "training/loss"), + ("train", "training/weights_l2"), + ("eval", "metrics/L2Loss"), + ]: + self.assertLen(loop2.history.get(mode, metric), 1) + 
self.assertEqual(2, loop2.history.get(mode, metric)[0][0]) + + def test_trains_on_two_tasks(self): + """Trains a very simple network on two very simple tasks.""" + model = tl.Serial(tl.Dense(3), tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + eval_task = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + ) + training_session = training.Loop( + model, + tasks=(task, task), + eval_tasks=(eval_task, eval_task), + which_task=lambda step_n: step_n % 2, + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=15) + self.assertEqual(15, training_session.step) + training_session.run(n_steps=5) + self.assertEqual(20, training_session.step) + + def test_train_one_task_eval_two_tasks(self): + """Trains a very simple network on one task and evaluates on two tasks.""" + model = tl.Serial(tl.Dense(3), tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + export_prefix_1 = "eval_1" + eval_task_1 = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + export_prefix=export_prefix_1, + ) + export_prefix_2 = "eval_2" + eval_task_2 = training.EvalTask( + _very_simple_data(), # deliberately re-using training data + [tl.L2Loss()], + export_prefix=export_prefix_2, + ) + training_session = training.Loop( + model, + tasks=(task,), + eval_tasks=(eval_task_1, eval_task_2), + ) + self.assertEqual(0, training_session.step) + training_session.run(n_steps=5) + self.assertEqual(5, training_session.step) + export_prefixes = [task.export_prefix for task in training_session.eval_tasks] + self.assertCountEqual([export_prefix_1, export_prefix_2], export_prefixes) + + def test_can_predict_with_trained_model(self): + model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2))) + train_tasks, eval_tasks = [], [] + for output_dim in [1, 2]: + # The head we select from the model: 0 for output_dim 1 and 1 for 2. + head_index = output_dim - 1 + train_tasks.append( + training.TrainTask( + _very_simple_data(output_dim), + tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss()), + optimizers.SGD(0.01), + ) + ) + eval_tasks.append( + training.EvalTask( + _very_simple_data(output_dim), # deliberately re-use training data + [tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())], + ) + ) + tmp_dir = self.create_tempdir().full_path + training_session = training.Loop( + model, + tasks=train_tasks, + eval_tasks=eval_tasks, + checkpoint_at=lambda step_n: step_n == 1, + output_dir=tmp_dir, + which_task=lambda step_n: step_n % 2, + ) + training_session.run(n_steps=2) + + trained_model = training_session.eval_model + inp = next(_very_simple_data())[0] + out = trained_model(inp) + self.assertEqual( + shapes.signature(out), + (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))), + ) + + def test_train_memory_efficient(self): + """Trains a large network in a memory-efficient way.""" + # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU + # and CPU when you run it locally, but it's too big for unit-testing. + ram_limited = True # Set to False to run this test locally. + if fastmath.global_device_count() == 1 and ram_limited: + return + + # Create the model. 
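+        # Reversible blocks (ReversibleHalfResidual + ReversibleSwap) let the
+        # memory-efficient trainer recompute activations during the backward
+        # pass instead of storing them, so the weights below, not the
+        # activations, dominate memory use.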
+ n_layers = 16 # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram + model = tl.Serial( + tl.Embedding(9, 16 * 1024), + tl.Dup(), + [ + [tl.ReversibleHalfResidual(tl.Dense(16 * 1024)), tl.ReversibleSwap()] + for _ in range(n_layers) + ], + tl.Concatenate(), + tl.Dense(9), + ) + + # Create inputs. + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + + def _data_gen(): + while True: + yield labeled_batch + + # Run training. + loss_layer = tl.WeightedCategoryCrossEntropy() + task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor) + eval_task = training.EvalTask(_data_gen(), [tl.WeightedCategoryCrossEntropy()]) + loop = training.Loop( + model, + [task], + eval_tasks=[eval_task], + eval_at=lambda step_n: step_n == 2, + use_memory_efficient_trainer=True, + ) + self.assertEqual(0, loop.step) + loop.run(n_steps=2) + self.assertEqual(2, loop.step) + + def test_initializes_step_callbacks_with_loop_instance(self): + """Runs a training loop, asserting that callbacks are initialized.""" + + class ActualLoop: + # Wrapper object to make the Loop reference mutable. + loop = None + + class TestCallback(callbacks.TrainingStepCallback): + def __init__(self, loop): + super().__init__(loop) + ActualLoop.loop = loop + + def call_at(self, step): + return False + + def on_step_begin(self, step): + del step + + def on_step_end(self, step): + del step + + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + expected_loop = training.Loop(model, [task], callbacks=[TestCallback]) + self.assertIs(ActualLoop.loop, expected_loop) + + def test_calls_step_callbacks(self): + """Runs a training loop, asserting that callbacks are called.""" + call_at_steps = [1, 3, 4] + begin_steps = [] + end_steps = [] + test_case = self + + class TestCallback(callbacks.TrainingStepCallback): + def call_at(self, step): + return step in call_at_steps + + def on_step_begin(self, step): + begin_steps.append(step) + + def on_step_end(self, step): + # Assert that on_step_begin() was called before. + test_case.assertIn(step, begin_steps) + end_steps.append(step) + + model = tl.Serial(tl.Dense(1)) + task = training.TrainTask( + _very_simple_data(), tl.L2Loss(), optimizers.SGD(0.01) + ) + loop = training.Loop(model, [task], callbacks=[TestCallback]) + loop.run(n_steps=5) + + # Assert that the callback has been called at the appropriate steps. 
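+        # List equality (not set equality) also verifies that the hooks fired
+        # in step order and exactly once per selected step.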
+        self.assertEqual(begin_steps, call_at_steps)
+        self.assertEqual(end_steps, call_at_steps)
+
+
+def _very_simple_data(output_dim=1, input_dim=1):
+    """Returns stream of labeled data that maps small integers to constant pi."""
+    inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
+    inputs_batch = np.concatenate([inputs_batch] * input_dim, axis=1)
+    targets_batch = np.pi * np.ones((8, output_dim))
+    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
+    while True:
+        yield labeled_batch
+
+
+def _very_simple_transformer_data():
+    """Returns stream of labeled all-ones data shaped for a tiny transformer."""
+    inputs_batch = np.ones((2, 2)).astype(np.int32)
+    targets_batch = np.ones((2, 2, 8)).astype(np.int32)
+    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
+    while True:
+        yield labeled_batch
+
+
+def _count_files(path):
+    """Returns number of files in a given directory."""
+    return len(
+        [
+            filename
+            for filename in os.listdir(path)
+            if os.path.isfile(os.path.join(path, filename))
+        ]
+    )
+
+
+if __name__ == "__main__":
+    config.config_with_absl()
+    absltest.main()
diff --git a/tests/models/atari_cnn_test.py b/tests/models/atari_cnn_test.py
new file mode 100644
index 000000000..704dc1f83
--- /dev/null
+++ b/tests/models/atari_cnn_test.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for trax.models.atari_cnn."""
+
+import functools
+import operator as op
+
+import numpy as np
+
+from tensorflow import test
+
+from trax.models import atari_cnn
+from trax.utils.shapes import ShapeDtype
+
+
+class AtariCnnTest(test.TestCase):
+    def test_computes(self):
+        hidden_size = (4, 4)
+        output_size = 6
+
+        model = atari_cnn.AtariCnn(hidden_sizes=hidden_size, output_size=output_size)
+
+        B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
+        input_signature = ShapeDtype((1, 1) + OBS)
+
+        _, _ = model.init(input_signature)
+        x = np.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
+            B, T + 1, *OBS
+        )
+        y = model(x)
+        self.assertEqual((B, T + 1, output_size), y.shape)
+
+
+class FrameStackMLPTest(test.TestCase):
+    def test_computes(self):
+        hidden_size = (4, 4)
+        output_size = 6
+        model = atari_cnn.FrameStackMLP(
+            hidden_sizes=hidden_size, output_size=output_size
+        )
+        B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
+        input_signature = ShapeDtype((1, 1, OBS))
+
+        _, _ = model.init(input_signature)
+        x = np.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS)
+        y = model(x)
+
+        self.assertEqual((B, T + 1, output_size), y.shape)
+
+
+if __name__ == "__main__":
+    test.main()
diff --git a/tests/models/gnn_test.py b/tests/models/gnn_test.py
new file mode 100644
index 000000000..6645e1dd3
--- /dev/null
+++ b/tests/models/gnn_test.py
@@ -0,0 +1,74 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Graph Neural Network models.""" + +import numpy as np + +from absl.testing import absltest + +from trax.models import gnn +from trax.utils import shapes + + +class GNNTest(absltest.TestCase): + def setUp(self): + super().setUp() + base_adj = np.array( + [ + [0, 1, 0, 0], + [1, 0, 1, 0], + [0, 1, 0, 1], + [0, 0, 1, 0], + ], + dtype=np.float32, + ) + self.adj = np.stack([base_adj, base_adj]) + self.features = np.ones((2, 4, 3), dtype=np.float32) + self.edge_features = np.ones((2, 4, 4, 1), dtype=np.float32) + + def test_graph_conv_net_forward_shape(self): + model = gnn.GraphConvNet(hidden_sizes=(5, 2)) + _, _ = model.init([shapes.signature(self.features), shapes.signature(self.adj)]) + out_features, out_adj = model([self.features, self.adj]) + self.assertEqual(out_features.shape, (2, 4, 2)) + self.assertEqual(out_adj.shape, (2, 4, 4)) + + def test_graph_attention_net_forward_shape(self): + model = gnn.GraphAttentionNet(hidden_sizes=(5, 2), num_heads=2) + _, _ = model.init([shapes.signature(self.features), shapes.signature(self.adj)]) + out_features, out_adj = model([self.features, self.adj]) + self.assertEqual(out_features.shape, (2, 4, 2)) + self.assertEqual(out_adj.shape, (2, 4, 4)) + + def test_graph_edge_net_forward_shape(self): + model = gnn.GraphEdgeNet(node_sizes=(5, 2), edge_sizes=(3, 2)) + model.init( + [ + shapes.signature(self.features), + shapes.signature(self.edge_features), + shapes.signature(self.adj), + ] + ) + out_features, out_edges, out_adj = model( + [self.features, self.edge_features, self.adj] + ) + self.assertEqual(out_features.shape, (2, 4, 2)) + self.assertEqual(out_edges.shape, (2, 4, 4, 2)) + self.assertEqual(out_adj.shape, (2, 4, 4)) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/mlp_test.py b/tests/models/mlp_test.py similarity index 69% rename from trax/models/mlp_test.py rename to tests/models/mlp_test.py index 40d335610..f2497835c 100644 --- a/trax/models/mlp_test.py +++ b/tests/models/mlp_test.py @@ -15,24 +15,22 @@ """Tests for MLP.""" -from absl.testing import absltest import numpy as np -from trax import fastmath -from trax import shapes +from absl.testing import absltest + from trax.models import mlp +from trax.utils import shapes class MLPTest(absltest.TestCase): - - def test_mlp_forward_shape(self): - model = mlp.MLP(layer_widths=(32, 16, 8)) - x = np.ones((7, 28, 28, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (7, 8)) - + def test_mlp_forward_shape(self): + model = mlp.MLP(layer_widths=(32, 16, 8)) + x = np.ones((7, 28, 28, 3)).astype(np.float32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (7, 8)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/neural_gpu_test.py b/tests/models/neural_gpu_test.py similarity index 69% rename from trax/models/neural_gpu_test.py rename to tests/models/neural_gpu_test.py index 0eaa77dbf..488b1365b 100644 --- a/trax/models/neural_gpu_test.py +++ b/tests/models/neural_gpu_test.py @@ -15,22 +15,22 @@ 
"""Tests for trax.models.neural_gpu.""" -from absl.testing import absltest import numpy as np -from trax import shapes +from absl.testing import absltest + from trax.models import neural_gpu +from trax.utils import shapes class NeuralGPUTest(absltest.TestCase): - - def test_ngpu(self): - model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22) - x = np.ones((3, 5, 7)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, 7, 22)) + def test_ngpu(self): + model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=22) + x = np.ones((3, 5, 7)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 5, 7, 22)) -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/reformer/reformer_e2e_test.py b/tests/models/reformer/reformer_e2e_test.py new file mode 100644 index 000000000..05e37139e --- /dev/null +++ b/tests/models/reformer/reformer_e2e_test.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End to end test for Reformer.""" +import os + +import gin + +from absl.testing import absltest + +from trax.data.encoder import encoder as encoder +from trax.learning.supervised import trainer_lib +from trax.utils import test_utils + +pkg_dir, _ = os.path.split(__file__) +_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../../resources/data/testdata")) +_CONFIG_DIR = os.path.normpath( + os.path.join(pkg_dir, "../../../resources/supervised/configs") +) + + +class ReformerE2ETest(absltest.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + gin.add_config_file_search_path(_CONFIG_DIR) + test_utils.ensure_flag("test_tmpdir") + + def test_reformer_wmt_ende(self): + batch_size_per_device = 2 + steps = 1 + n_layers = 2 + d_ff = 32 + + tokenizer = encoder.SubwordTextEncoder( + filename=os.path.join( + _TESTDATA, "vocab.translate_ende_wmt32k.32768.subwords" + ) + ) + + gin.parse_config_file("reformer_wmt_ende.gin") + + gin.bind_parameter("data_streams.data_dir", _TESTDATA) + gin.bind_parameter("wmt_preprocess.tokenizer", tokenizer) + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("Reformer.n_encoder_layers", n_layers) + gin.bind_parameter("Reformer.n_decoder_layers", n_layers) + gin.bind_parameter("Reformer.d_ff", d_ff) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + def test_reformer_copy(self): + batch_size_per_device = 2 + steps = 1 + n_layers = 2 + d_ff = 32 + d_model = 32 + + gin.parse_config_file("reformer_copy.gin") + + gin.bind_parameter("data_streams.data_dir", _TESTDATA) + gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device) + gin.bind_parameter("train.steps", steps) + gin.bind_parameter("ReformerLM.n_layers", n_layers) + gin.bind_parameter("ReformerLM.d_ff", d_ff) + 
gin.bind_parameter("ReformerLM.d_model", d_model) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/reformer/reformer_test.py b/tests/models/reformer/reformer_test.py new file mode 100644 index 000000000..bda3f0c18 --- /dev/null +++ b/tests/models/reformer/reformer_test.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Reformer models.""" + +import functools + +import gin +import numpy as np + +from absl.testing import absltest, parameterized + +from trax import fastmath +from trax import layers as tl +from trax.models.reformer import reformer +from trax.utils import shapes + +BACKENDS = [fastmath.Backend.JAX] + + +def short_name(b): + if b == fastmath.Backend.JAX: + return "jax" + else: + return "tf" + + +class ReformerTest(parameterized.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def _timebin_self_attention_fn(self, use_reference_code=False): + return functools.partial( + tl.SelfAttention, + attention_dropout=0.05, + chunk_len=64, + n_chunks_before=1, + n_parallel_heads=1, + use_reference_code=use_reference_code, + ) + + def test_reformer_lm_forward_shape(self): + vocab_size = 16 + model = reformer.ReformerLM( + vocab_size, + d_model=32, + d_ff=64, + d_attention_key=16, + d_attention_value=16, + n_layers=1, + n_heads=2, + max_len=16, + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + @absltest.skip + def test_reformer_lm_lsh(self): + """ + Problems with: + - res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)), + probably dropout_shared_axes should be [] + - Scan in Chunk res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size)) shape assertion is wrong + """ + lsh_self_attention = self._lsh_self_attention_fn() + timebin_self_attention = self._timebin_self_attention_fn() + + model = reformer.ReformerLM( + vocab_size=256, + d_model=256, + d_ff=512, + d_attention_key=64, + d_attention_value=64, + n_layers=2, + n_heads=2, + dropout=0.05, + max_len=65536, + attention_type=[timebin_self_attention, lsh_self_attention], + pos_axial_shape=(256, 256), + pos_d_axial_embs=(64, 192), + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=8192, + mode="train", + ) + x = (np.ones((1, 65536)).astype(np.int32), np.ones((1, 65536)).astype(np.int32)) + weights, state = model.init(shapes.signature(x)) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + 
logits, new_state = model.pure_fn(x, weights, state, rng) + loss = fastmath.numpy.mean(logits[..., 0]) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + weights, state, logits = mock_training_step( + x, weights, state, fastmath.random.get_prng(0) + ) + self.assertEqual(logits.shape, (1, 65536, 256)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/configurable_transformer_test.py b/tests/models/research/configurable_transformer_test.py new file mode 100644 index 000000000..e31fdafef --- /dev/null +++ b/tests/models/research/configurable_transformer_test.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer models.""" + +import functools + +import numpy as np + +from absl.testing import absltest, parameterized + +from tests.layers import test_utils +from trax import fastmath +from trax import layers as tl +from trax.models.research import configurable_transformer as ct +from trax.utils import shapes + + +class ConfigurableTransformerTest(parameterized.TestCase): + def test_transformer_lm_forward_shape(self): + vocab_size = 16 + model = ct.ConfigurableTransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2 + ) + x = np.ones((3, 5)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 5, vocab_size)) + + def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + model = ct.ConfigurableTransformer( + input_vocab_size, + output_vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + ) + xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + y, _ = model(xs) + + vocab_size = output_vocab_size or input_vocab_size + self.assertEqual(y.shape, (3, 5, vocab_size)) + + @parameterized.named_parameters( + ("same_vocab", 16, None), ("same_size", 16, 16), ("different_size", 16, 50) + ) + def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + """Run the Transformer forward and check output shape.""" + self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) + + def test_dot_product_causal_attention_fast_inference(self): + self._test_fast_inference(length=5) + + def _test_fast_inference(self, length): + with fastmath.use_backend(fastmath.Backend.JAX): + model_fn = functools.partial( + ct.ConfigurableTransformerLM, + vocab_size=16, + d_model=4, + d_ff=8, + n_layers=2, + n_heads=2, + ) + batch_size = 2 + inp = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_equals_predict(inp, model_fn) + + def test_sparse_configurable_transformer_fast_inference(self): + self._test_sparse_fast_inference(length=5) + + def 
_test_sparse_fast_inference(self, length): + with fastmath.use_backend(fastmath.Backend.JAX): + vocab_size = 16 + d_model = 4 + batch_size = 2 + + encoder_decoder_attention_type = functools.partial( + tl.MultiplicativeConvCausalAttention, + sparsity=2, + length_kernel_size=1, + ) + + model_fn = functools.partial( + ct.ConfigurableTransformer, + input_vocab_size=vocab_size, + d_model=d_model, + d_ff=8, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + loss_sparsity=2, + ff_sparsity=2, + encoder_decoder_attention_type=encoder_decoder_attention_type, + ff_use_sru=(1, 4), + ) + + inp = np.random.randint(vocab_size, size=(batch_size, length)) + out = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_equals_predict((inp, out), model_fn, seq_tensor=1) + + @parameterized.named_parameters( + ("positional_encoding", None), + ("fixed_base_positional_encoding", "fixed-base"), + ("infinite_positional_encoding", "infinite"), + ("infinite_affine_positional_encoding", "infinite-affine"), + ("axial_positional_encoding", (2, 16)), + ) + def test_positional_encoder(self, pos_axial_shape): + # dim should divide FixedBasePositionalEncoding.n_digits + batch, length, dim = 2, 32, 8 + input_shape = (batch, length, dim) + vocab_size = 32 + x = np.random.randint(0, vocab_size - 1, input_shape) + # should sum to dim + pos_d_axial_embs = (4, 4) + + positional_encoding = ct.PositionalEncoder( + "train", + dropout=0.1, + max_len=length, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + ) + _, _ = positional_encoding.init(shapes.signature(x)) + y = positional_encoding(x) + self.assertEqual(y.shape, input_shape) + + @parameterized.named_parameters( + ("input_vocab_size_only", 32, None), + ("output_vocab_size_only", None, 32), + ("same_input_output_vocab_size", 32, 32), + ("different_input_output_vocab_size", 32, 16), + ) + def test_embedding_and_positional_encodings( + self, input_vocab_size, output_vocab_size + ): + d_model = 16 + max_len = 32 + batch = 2 + input_shape = (batch, max_len) + output_vocab_size_expected = output_vocab_size or input_vocab_size + x_out = np.random.randint(0, output_vocab_size_expected - 1, input_shape) + if input_vocab_size is None: + x_in = np.random.uniform(size=list(input_shape) + [2]) + else: + x_in = np.random.randint(0, input_vocab_size - 1, input_shape) + + ( + in_encoder, + out_encoder, + output_vocab_size_result, + ) = ct.EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + "train", + 0.1, + [-2], + max_len, + output_vocab_size=output_vocab_size, + pos_axial_shape=None, + pos_d_axial_embs=None, + ) + + self.assertEqual(output_vocab_size_result, output_vocab_size_expected) + + model_in = tl.Serial(in_encoder) + model_out = tl.Serial(out_encoder) + + model_in.init(shapes.signature(x_in)) + model_out.init(shapes.signature(x_out)) + + y = model_in(x_in) + self.assertEqual(y.shape, input_shape + (d_model,)) + + y = model_out(x_out) + self.assertEqual(y.shape, input_shape + (d_model,)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/hourglass_test.py b/tests/models/research/hourglass_test.py new file mode 100644 index 000000000..b307e2e69 --- /dev/null +++ b/tests/models/research/hourglass_test.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Hourglass model."""
+
+import gin
+import jax
+import numpy as np
+
+from absl.testing import absltest, parameterized
+
+import trax.layers.research.resampling as resampling
+import trax.models.research.hourglass as hourglass
+
+from trax import fastmath
+from trax import layers as tl
+from trax.utils import shapes
+
+
+class HourglassTest(parameterized.TestCase):
+    def _check_forward_shape(self, model, input_shape, output_vocab_size):
+        x = np.ones(input_shape).astype(np.int32)
+        model.init(shapes.signature(x))
+        y = model(x)
+        self.assertEqual(y.shape, (*input_shape, output_vocab_size))
+
+    def test_hourglass_lm_forward_shape(self):
+        d_model = 16
+        vocab_size = 7
+        model = hourglass.HourglassLM(
+            vocab_size,
+            hierarchy="2@3 2@6 2@3",
+            vanilla_layers=(1, 1),
+            d_model=d_model,
+            d_ff=d_model,
+            n_heads=2,
+        )
+
+        batch_size, seq_len = 3, 24
+        self._check_forward_shape(
+            model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size
+        )
+
+    def test_lsh_attention_in_vanilla(self):
+        d_model = 16
+        vocab_size = 7
+
+        gin.bind_parameter(
+            "PureLSHSelfAttentionWrapper.pure_lsh_implementation",
+            tl.PureLSHSelfAttention,
+        )
+        gin.bind_parameter("PureLSHSelfAttention.chunk_len", 2)
+
+        model = hourglass.HourglassLM(
+            vocab_size,
+            hierarchy="2@3",
+            vanilla_layers=(1, 1),
+            d_model=d_model,
+            d_ff=d_model,
+            n_heads=2,
+            vanilla_attn_type=tl.PureLSHSelfAttentionWrapper,
+            downsampling_fn=resampling.LinearPooling,
+            upsampling_fn=resampling.LinearUpsampling,
+        )
+
+        batch_size, seq_len = 3, 12
+        self._check_forward_shape(
+            model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size
+        )
+
+    def _test_autoregressive_property(self, model, input_shape, output_vocab_size):
+        rng_1 = jax.random.PRNGKey(0)
+        rng_2 = jax.random.PRNGKey(1)
+
+        def _get_output_logits(uninitialized_eval_model: tl.Layer, x):
+            input_signature = shapes.signature(x)
+            uninitialized_eval_model.init(input_signature, rng=rng_1, use_cache=False)
+
+            output_logits, *_ = uninitialized_eval_model(x, rng=rng_1)
+            return output_logits
+
+        def check_autoregressive_property(model):
+            with fastmath.use_backend(fastmath.Backend.JAX):
+                x_1 = jax.random.randint(rng_1, input_shape, 0, output_vocab_size)
+                y_1 = _get_output_logits(model, x_1)
+
+                x_2 = jax.random.randint(rng_2, input_shape, 0, output_vocab_size)
+
+                for i in range(input_shape[1]):
+                    # In a causal LM, logits at positions <= i depend only on
+                    # tokens before position i, so replacing the suffix from
+                    # position i onward must not change them.
+                    masked_x_2 = np.concatenate((x_1[:, :i], x_2[:, i:]), axis=1)
+
+                    y_2 = _get_output_logits(model, masked_x_2)
+                    self.assertEqual(y_2.shape[0], input_shape[1])
+                    np.testing.assert_array_almost_equal(y_1[: i + 1], y_2[: i + 1])
+
+        check_autoregressive_property(model)
+
+    def test_hourglass_lm_autoregressive_property(self):
+        d_model = 8
+        vocab_size = 26
+
+        model_single_stage = hourglass.HourglassLM(
+            vocab_size,
+            hierarchy="2@4",
+            vanilla_layers=(1, 1),
+            d_model=d_model,
+            d_ff=d_model,
+            n_heads=2,
+        )
+
+        model_multi_stage = hourglass.HourglassLM(
+            vocab_size,
+            hierarchy="2@3 2@6 2@3",
+            vanilla_layers=(1, 1),
+            d_model=d_model,
+            d_ff=d_model,
+            n_heads=2,
+        )
+
+        input_shape = (1, 12)
+        self._test_autoregressive_property(
model_single_stage, input_shape, output_vocab_size=vocab_size + ) + self._test_autoregressive_property( + model_multi_stage, input_shape, output_vocab_size=vocab_size + ) + + def test_parse_hourglass_hierarchy(self): + self.assertEqual(hourglass._parse_hierarchy("6@3"), ([6], [3])) + self.assertEqual( + hourglass._parse_hierarchy("3@2 2@6 5@24 2@6 3@2"), ([3, 2, 5], [2, 3, 4]) + ) + self.assertRaises(ValueError, hourglass._parse_hierarchy, "1@2 2@3 1@2") + self.assertRaises(ValueError, hourglass._parse_hierarchy, "1@2 2@3") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/layerdrop_transformer_test.py b/tests/models/research/layerdrop_transformer_test.py new file mode 100644 index 000000000..044a668fc --- /dev/null +++ b/tests/models/research/layerdrop_transformer_test.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Reformer models.""" + +import numpy as np + +from absl.testing import absltest + +from trax.models.research import layerdrop_transformer +from trax.utils import shapes + + +class SkippingTransformerTest(absltest.TestCase): + def test_skipping_transformer_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.SkippingTransformerLM( + vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16 + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +class LayerDropTransformerTest(absltest.TestCase): + def test_layerdrop_transformer_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.LayerDropTransformerLM( + vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16 + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + def test_layerdrop_layerwise_skip_fraction(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.LayerDropTransformerLM( + vocab_size, + d_model=16, + d_ff=32, + n_layers=2, + n_heads=2, + max_len=16, + skip_fraction=[0.2, 0.8], + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +class EveryOtherLayerDropTransformerTest(absltest.TestCase): + def test_everyother_layerdrop_transformer_forward(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = layerdrop_transformer.EveryOtherLayerDropTransformerLM( + vocab_size, + d_model=16, + d_ff=32, + n_layers=2, + n_heads=2, + max_len=16, 
+ skip_mode="1half", + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/rezero_test.py b/tests/models/research/rezero_test.py new file mode 100644 index 000000000..81609d72b --- /dev/null +++ b/tests/models/research/rezero_test.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ReZero models.""" + +import numpy as np + +from absl.testing import absltest + +from trax import layers as tl +from trax.models.research import rezero +from trax.utils import shapes + + +class ResidualZeroTest(absltest.TestCase): + def test_residual_layer_forward(self): + """Tests that the forward pass runs and returns the expected shape.""" + model = rezero.ResidualZero(tl.Dense(5)) + x = [np.arange(5).astype(np.float32)] + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0, 3.0, 4.0]) + + +class ReZeroTransformerLMTest(absltest.TestCase): + def test_rezero_lm_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = rezero.ReZeroTransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2, max_len=16 + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +class ReZeroTransformerTest(absltest.TestCase): + def test_rezero_forward_shape(self): + """Tests that the forward pass runs and returns the expected shape.""" + vocab_size = 16 + model = rezero.ReZeroTransformer( + vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + max_len=16, + ) + xs = [np.ones((1, 8)).astype(np.int32), np.ones((1, 8)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + ys = model(xs) + self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/rse_test.py b/tests/models/research/rse_test.py new file mode 100644 index 000000000..965ab58b2 --- /dev/null +++ b/tests/models/research/rse_test.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Residual Shuffle-Exchange Networks.""" + +import numpy as np + +from absl.testing import absltest + +from trax.models.research import rse +from trax.utils import shapes + + +class RSETest(absltest.TestCase): + def test_rsu_forward_shape(self): + batch_size = 3 + seq_len = 32 + d_model = 17 + model = rse.ResidualSwitchUnit(d_model=d_model, dropout=0.1, mode="train") + x = np.ones((batch_size, seq_len, d_model)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (batch_size, seq_len, d_model)) + + def test_shuffle_layer(self): + shuffle_layer = rse.ShuffleLayer() + x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) + print(x.shape) + _, _ = shuffle_layer.init(shapes.signature(x)) + y = shuffle_layer(x) + expected_output = np.array([[[0], [2], [4], [6], [1], [3], [5], [7]]]) + self._assert_equal_tensors(y, expected_output) + + def test_shuffle_layer_log_times_is_identity(self): + seq_len = 8 + d_model = 17 + shuffle_layer = rse.ShuffleLayer() + x = _input_with_indice_as_values(seq_len, d_model) + _, _ = shuffle_layer.init(shapes.signature(x)) + y = x + for _ in range(int(np.log2(seq_len))): + y = shuffle_layer(y) + self._assert_equal_tensors(x, y) + + def test_reverse_shuffle_layer(self): + reverse_shuffle_layer = rse.ReverseShuffleLayer() + x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) + print(x.shape) + _, _ = reverse_shuffle_layer.init(shapes.signature(x)) + y = reverse_shuffle_layer(x) + expected_output = np.array([[[0], [4], [1], [5], [2], [6], [3], [7]]]) + self._assert_equal_tensors(y, expected_output) + + def test_reverse_shuffle_layer_log_times_is_identity(self): + seq_len = 8 + d_model = 17 + reverse_shuffle_layer = rse.ReverseShuffleLayer() + x = _input_with_indice_as_values(seq_len, d_model) + _, _ = reverse_shuffle_layer.init(shapes.signature(x)) + y = x + for _ in range(int(np.log2(seq_len))): + y = reverse_shuffle_layer(y) + self._assert_equal_tensors(x, y) + + def test_rse_forward_shape(self): + vocab_size = 12 + seq_len = 32 + model = rse.ResidualShuffleExchange( + vocab_size=vocab_size, + d_model=17, + dropout=0.1, + input_dropout=0.05, + mode="train", + ) + x = np.ones((3, seq_len)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, seq_len, vocab_size)) + + def _assert_equal_tensors(self, x, y): + self.assertEqual(y.shape, x.shape) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + self.assertEqual( + x[i][j][k], + y[i][j][k], + f"Tensors differ on index [{i}][{j}][{k}].", + ) + + +def _input_with_indice_as_values(length, dim): + """Retuns np.array of size (1, length, dim) where x[0, a, b] = a.""" + positions = [] + for i in range(length): + positions.append([i] * dim) + positions_input = np.array(positions) + positions_input = np.expand_dims(positions_input, axis=0) + return positions_input + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/terraformer_e2e_test.py b/tests/models/research/terraformer_e2e_test.py new file mode 100644 index 000000000..334b0656b --- /dev/null +++ b/tests/models/research/terraformer_e2e_test.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""End to end test for Terraformer."""
+
+import os
+
+import gin
+
+from absl.testing import absltest
+
+from trax.data.encoder import encoder as encoder
+from trax.learning.supervised import trainer_lib
+from trax.utils import test_utils
+
+pkg_dir, _ = os.path.split(__file__)
+_TESTDATA = os.path.normpath(os.path.join(pkg_dir, "../../../resources/data/testdata"))
+_CONFIG_DIR = os.path.normpath(
+    os.path.join(pkg_dir, "../../../resources/supervised/configs")
+)
+
+
+class TerraformerE2ETest(absltest.TestCase):
+    def setUp(self):
+        super().setUp()
+        test_utils.ensure_flag("test_tmpdir")
+        gin.clear_config()
+        gin.add_config_file_search_path(_CONFIG_DIR)
+
+    def test_terraformer_wmt_ende(self):
+        batch_size_per_device = 2
+        steps = 1
+        n_layers = 2
+        d_ff = 32
+
+        tokenizer = encoder.SubwordTextEncoder(
+            filename=os.path.join(
+                _TESTDATA, "vocab.translate_ende_wmt32k.32768.subwords"
+            )
+        )
+
+        gin.parse_config_file("terraformer_wmt_ende.gin")
+
+        gin.bind_parameter("data_streams.data_dir", _TESTDATA)
+        gin.bind_parameter("wmt_preprocess.tokenizer", tokenizer)
+        gin.bind_parameter("wmt_preprocess.max_length", 20)
+        gin.bind_parameter("wmt_preprocess.max_eval_length", 25)
+        gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device)
+        gin.bind_parameter(
+            "batcher.buckets", ([512], [batch_size_per_device, batch_size_per_device])
+        )
+        gin.bind_parameter("train.steps", steps)
+        gin.bind_parameter("ConfigurableTerraformer.n_encoder_layers", n_layers)
+        gin.bind_parameter("ConfigurableTerraformer.n_decoder_layers", n_layers)
+        gin.bind_parameter("ConfigurableTerraformer.d_ff", d_ff)
+
+        output_dir = self.create_tempdir().full_path
+        _ = trainer_lib.train(output_dir=output_dir)
+
+    def test_terraformer_copy(self):
+        batch_size_per_device = 2
+        steps = 1
+        n_layers = 2
+        d_ff = 32
+
+        gin.parse_config_file("terraformer_copy.gin")
+
+        gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device)
+        gin.bind_parameter("batcher.buckets", ([64], [1, 1]))  # batch size 1.
+        gin.bind_parameter("train.steps", steps)
+        gin.bind_parameter("ConfigurableTerraformer.n_encoder_layers", n_layers)
+        gin.bind_parameter("ConfigurableTerraformer.n_decoder_layers", n_layers)
+        gin.bind_parameter("ConfigurableTerraformer.d_ff", d_ff)
+
+        output_dir = self.create_tempdir().full_path
+        _ = trainer_lib.train(output_dir=output_dir)
+
+    def test_terraformer_purelsh_copy(self):
+        batch_size_per_device = 2
+        steps = 1
+        n_layers = 2
+        d_ff = 32
+
+        gin.parse_config_file("terraformer_purelsh_copy.gin")
+
+        gin.bind_parameter("batcher.batch_size_per_device", batch_size_per_device)
+        gin.bind_parameter("batcher.buckets", ([64], [1, 1]))  # batch size 1.
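+        # batcher.buckets is (length boundaries, per-bucket batch sizes):
+        # sequences up to 64 tokens fall into the single bucket and are
+        # batched with batch size 1.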
+ gin.bind_parameter("train.steps", steps) + gin.bind_parameter("ConfigurableTerraformer.n_encoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.n_decoder_layers", n_layers) + gin.bind_parameter("ConfigurableTerraformer.d_ff", d_ff) + + output_dir = self.create_tempdir().full_path + _ = trainer_lib.train(output_dir=output_dir) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/research/terraformer_memory_test.py b/tests/models/research/terraformer_memory_test.py similarity index 83% rename from trax/models/research/terraformer_memory_test.py rename to tests/models/research/terraformer_memory_test.py index 8c4a78601..cf7750c2b 100644 --- a/trax/models/research/terraformer_memory_test.py +++ b/tests/models/research/terraformer_memory_test.py @@ -23,14 +23,10 @@ from absl.testing import absltest - class TerraformerMemoryTest(absltest.TestCase): + def test_terraformer_memory(self): + pass # TODO(jonni): Figure out an OSS-compatible memory test. - def test_terraformer_memory(self): - pass # TODO(jonni): Figure out an OSS-compatible memory test. - - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/terraformer_oom_test.py b/tests/models/research/terraformer_oom_test.py new file mode 100644 index 000000000..bf68a9650 --- /dev/null +++ b/tests/models/research/terraformer_oom_test.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for OOM for Terraformer .""" + +import functools +import operator + +import gin +import numpy as np + +from absl.testing import absltest + +from trax import fastmath +from trax import layers as tl +from trax.models.research import terraformer +from trax.utils import shapes + + +class TerraformerOOMTest(absltest.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def test_terraformer_one_step(self): + d_model = 1024 + vocab_size = 14041 + max_len = 16384 + pos_axial = (128, 128) # should multiply to max_len + pos_d_axial_embs = (512, 512) # sum to d model + + assert operator.mul(*pos_axial) == max_len + assert sum(pos_d_axial_embs) == d_model + + d_ff = 4096 + n_heads = 8 + d_attn = d_model // n_heads + + n_buckets = 128 + encoder_chunk_len = (2 * max_len) // n_buckets # 256 + decoder_chunk_len = 2 * encoder_chunk_len # 512 + encoder_n_chunks_after = 1 # since its not causal. 
+ + lsh_self_attention = functools.partial( + self._lsh_self_attention_fn(), n_buckets=n_buckets + ) + + encoder_lsh_self_attention = functools.partial( + lsh_self_attention, + n_chunks_after=encoder_n_chunks_after, + chunk_len=encoder_chunk_len, + ) + + decoder_lsh_self_attention = functools.partial( + lsh_self_attention, n_chunks_after=0, chunk_len=decoder_chunk_len + ) + + model = terraformer.ConfigurableTerraformer( + vocab_size, + d_model=d_model, + d_ff=d_ff, + d_attention_key=d_attn, + d_attention_value=d_attn, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=n_heads, + dropout=0.05, + max_len=max_len, + encoder_attention_type=encoder_lsh_self_attention, + encoder_decoder_attention_type=decoder_lsh_self_attention, + pos_axial_shape=pos_axial, + pos_d_axial_embs=pos_d_axial_embs, + ff_activation=tl.Relu, + ff_use_sru=0, + mode="train", + ) + + def random_sentence(): + return np.random.randint( + low=1, high=vocab_size - 1, size=(1, max_len), dtype=np.int32 + ) + + x = [random_sentence(), random_sentence()] + weights, state = model.init(shapes.signature(x)) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) + # This returns [logits, decoder tokens] + logits = logits_and_dec_toks[0] + loss = fastmath.numpy.mean(logits[..., 0]) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + weights, state, logits = mock_training_step( + x, weights, state, fastmath.random.get_prng(0) + ) + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/terraformer_test.py b/tests/models/research/terraformer_test.py new file mode 100644 index 000000000..1c4510216 --- /dev/null +++ b/tests/models/research/terraformer_test.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Terraformer models.""" + +import functools + +import gin +import numpy as np + +from absl.testing import absltest, parameterized + +from tests.layers import test_utils +from trax import fastmath +from trax import layers as tl +from trax.models.research import terraformer +from trax.utils import shapes + +BACKENDS = [fastmath.Backend.JAX] + + +def short_name(b): + if b == fastmath.Backend.JAX: + return "jax" + else: + return "tf" + + +class TerraformerTest(parameterized.TestCase): + def setUp(self): + super().setUp() + gin.clear_config() + + def _lsh_self_attention_fn(self): + return functools.partial( + tl.LSHSelfAttention, + attention_dropout=0.0, + chunk_len=64, + n_buckets=[32, 32], + n_chunks_after=0, + n_chunks_before=1, + n_hashes=1, + n_parallel_heads=1, + predict_drop_len=128, + predict_mem_len=1024, + ) + + def _timebin_self_attention_fn(self, use_reference_code=False): + return functools.partial( + tl.SelfAttention, + attention_dropout=0.05, + chunk_len=64, + n_chunks_before=1, + n_parallel_heads=1, + use_reference_code=use_reference_code, + ) + + @parameterized.named_parameters( + [ + ("_%s_efficient" % short_name(backend), backend, tl.SelfAttention, False) + for backend in BACKENDS + ] + + [ + ("_%s_causal" % short_name(backend), backend, tl.CausalAttention, False) + for backend in BACKENDS + ] + + + # NOTE: tl.SelfAttention is not currently working for this case. + [ + ("_%s_preembed" % short_name(backend), backend, tl.CausalAttention, True) + for backend in BACKENDS + ] + ) + def test_terraformer_quick(self, backend, encoder_attention_type, preembed): + with fastmath.use_backend(backend): + vocab_size = 2 + input_vocab_size = None if preembed else vocab_size + output_vocab_size = vocab_size if preembed else None + max_len = 2 + + model = terraformer.ConfigurableTerraformer( + input_vocab_size, + d_model=4, + d_ff=4, + n_encoder_layers=1, + n_decoder_layers=1, + n_heads=2, + dropout=0.05, + max_len=max_len, + pos_type=None, + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=2, + mode="train", + output_vocab_size=output_vocab_size, + encoder_attention_type=encoder_attention_type, + ) + + if preembed: + model_inputs = [ + np.ones((1, max_len, 3)).astype(np.float32), + np.ones((1, max_len)).astype(bool), + ] + else: + model_inputs = [np.ones((1, max_len)).astype(np.int32)] + x = model_inputs + [np.ones((1, max_len)).astype(np.int32)] + model.init(shapes.signature(x)) + + logits, dec_toks = model(x) + del dec_toks + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + def test_terraformer_deterministic_eval(self): + with fastmath.use_backend(fastmath.Backend.JAX): + vocab_size = 16 + d_model = 4 + batch_size = 2 + length = 5 + + model_fn = functools.partial( + terraformer.ConfigurableTerraformer, + vocab_size, + d_model=d_model, + d_ff=16, + n_encoder_layers=0, + n_decoder_layers=1, + n_heads=2, + dropout=0.0, + max_len=length * 2, + pos_type=None, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + ) + + inp = np.random.randint(vocab_size, size=(batch_size, length)) + out = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_is_deterministic((inp, out), model_fn) + + def test_terraformer_predict_equals_eval(self): + with fastmath.use_backend(fastmath.Backend.JAX): + vocab_size = 16 + d_model = 8 + batch_size = 1 + length = 5 + + model_fn = functools.partial( + terraformer.ConfigurableTerraformer, + vocab_size, + d_model=d_model, + d_ff=16, + n_encoder_layers=1, + n_decoder_layers=1, + 
n_heads=2, + ff_use_sru=(1, 8), # ? is SRU working? + dropout=0.0, + max_len=(length + 7) * 2, + pos_type=None, + reversible_encoder=True, + n_decoder_attention_layers=1, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + ) + + # Token id of 0 indicates padding; and predict mode doesn't support it. + inp = np.random.randint(1, vocab_size, size=(batch_size, length)) + inp[:, -2:] = 0 + out = np.zeros((batch_size, length), dtype=np.int32) + + test_utils.test_eval_equals_predict( + (inp, out), model_fn, seq_axis=1, seq_tensor=-1, init_tokens=1 + ) + + def test_terraformer_doubling(self): + vocab_size = 2 + max_len = 2 + + model = terraformer.ConfigurableTerraformer( + vocab_size, + d_model=8, + d_ff=16, + n_encoder_layers=1, + n_decoder_layers=6, + n_heads=2, + dropout=0.05, + max_len=max_len, + pos_type=None, + half_before_layer=2, + double_after_layer=2, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + mode="train", + ) + + x = [ + np.ones((1, max_len)).astype(np.int32), + np.ones((1, max_len)).astype(np.int32), + ] + model.init(shapes.signature(x)) + + logits, dec_toks = model(x) + del dec_toks + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + def test_terraformer_one_step(self): + vocab_size = 32 + max_len = 256 + pos_axial = 16 + assert pos_axial * pos_axial == max_len + + chunk_len = 32 + + # Since 2 * chunk_len * n_buckets should be max_len. + n_buckets = max_len // (2 * chunk_len) + + lsh_self_attention = functools.partial( + self._lsh_self_attention_fn(), chunk_len=chunk_len, n_buckets=n_buckets + ) + + timebin_self_attention = self._timebin_self_attention_fn() + + model = terraformer.ConfigurableTerraformer( + vocab_size, + d_model=32, + d_ff=64, + d_attention_key=64, + d_attention_value=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + dropout=0.05, + max_len=max_len, + encoder_attention_type=lsh_self_attention, + encoder_decoder_attention_type=[timebin_self_attention, lsh_self_attention], + pos_axial_shape=(pos_axial, pos_axial), + pos_d_axial_embs=(64, 192), + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=64, + ff_sparsity=8, + mode="train", + ) + + x = [ + np.ones((1, max_len)).astype(np.int32), + np.ones((1, max_len)).astype(np.int32), + ] + weights, state = model.init(shapes.signature(x)) + + @fastmath.jit + def mock_training_step(x, weights, state, rng): + def compute_mock_loss(weights): + logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) + # This returns [logits, decoder tokens] + logits = logits_and_dec_toks[0] + loss = fastmath.numpy.mean(logits[..., 0]) + return loss, (new_state, logits) + + gradients, (new_state, logits) = fastmath.grad( + compute_mock_loss, has_aux=True + )(weights) + new_weights = fastmath.nested_map_multiarg( + lambda w, g: w - 1e-4 * g, weights, gradients + ) + return new_weights, new_state, logits + + weights, state, logits = mock_training_step( + x, weights, state, fastmath.random.get_prng(0) + ) + + self.assertEqual(logits.shape, (1, max_len, vocab_size)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/research/transformer2_test.py b/tests/models/research/transformer2_test.py new file mode 100644 index 000000000..be037074a --- /dev/null +++ b/tests/models/research/transformer2_test.py @@ -0,0 +1,468 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer models.""" + +import numpy as np + +from absl.testing import absltest + +from trax.models.research import transformer2 +from trax.utils import shapes + + +class Transformer2Test(absltest.TestCase): + def test_concat_with_padding(self): + vec_e = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # vec_e[:,:,0] != 0 + mask_e = np.array( + [ + [True, True, False, False, False, False], + [True, True, True, False, False, False], + ] + ) + + vec_d = np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + layer = transformer2.ConcatWithPadding(mode="train") + inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d + layer.init(shapes.signature(inp)) + y, _, _ = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + def test_concat_with_padding_predict(self): + vec_e = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # vec_e[:,:,0] != 0 + mask_e = np.array( + [ + [True, True, False, False, False, False], + [True, True, True, False, False, False], + ] + ) + + vec_d = np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + layer = transformer2.ConcatWithPadding(mode="predict") + inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d + _, _ = layer.init(shapes.signature(inp)) + y, _, _ = layer(inp) + + np.testing.assert_equal( + y, + 
np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + # On subsequent runs however, we should get vec_d only. + for _ in range(2): + y, _, _ = layer(inp) + np.testing.assert_equal(y, vec_d) + + def test_concat_with_padding2(self): + vec_e = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + # vec_e[:,:,0] != 0 + mask_e = np.array( + [ + [True, True, False, False, False, False], + [True, True, True, False, False, False], + ] + ) + + vec_d = np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + layer = transformer2.ConcatWithPadding2(mode="train") + inp = (vec_e, vec_e, vec_d, mask_e, vec_e, vec_d) + layer.init(shapes.signature(inp)) + y1, y2, _, _ = layer(inp) + + np.testing.assert_equal( + y1, + np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + np.testing.assert_equal( + y2, + np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + def test_strip_from_concatenate_with_padding(self): + enc_dec = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 
1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) + tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) + + layer = transformer2.StripFromConcatenateWithPadding(mode="train") + inp = (enc_dec, tok_e, tok_d) + _, _ = layer.init(shapes.signature(inp)) + y = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + def test_strip_from_concatenate_with_padding_predict(self): + enc_dec = np.array( + [ + [ + [7, 5, 2, 8, 8, 8, 6, 7], + [8, 2, 6, 2, 1, 1, 4, 2], + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [4, 3, 1, 7, 5, 6, 2, 1], + [6, 9, 9, 4, 1, 3, 2, 1], + [3, 8, 2, 4, 7, 9, 4, 1], + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ) + + tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) + tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) + + layer = transformer2.StripFromConcatenateWithPadding(mode="predict") + inp = (enc_dec, tok_e, tok_d) + _, _ = layer.init(shapes.signature(inp)) + y = layer(inp) + + np.testing.assert_equal( + y, + np.array( + [ + [ + [4, 7, 7, 4, 8, 9, 9, 9], + [6, 8, 2, 9, 3, 6, 6, 8], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + [ + [3, 7, 5, 6, 2, 9, 3, 1], + [4, 7, 3, 2, 1, 1, 1, 6], + [4, 7, 3, 2, 1, 1, 1, 6], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + ] + ), + ) + + # On subsequent runs however, we should get enc_dec only. 
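+        # (Presumably the layer's predict-mode state records that the first call
+        # already stripped the encoder prefix, so later calls pass inputs through.)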
+ for _ in range(2): + y = layer(inp) + np.testing.assert_equal(y, enc_dec) + + def test_transformer_noencdec_forward_shape(self): + input_vocab_size = 16 + output_vocab_size = 16 + + model = transformer2.Transformer2( + input_vocab_size, + output_vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + ) + + enc_toks = np.array([[6, 2, 0, 0, 0, 0], [6, 3, 7, 0, 0, 0]]) + dec_toks = np.array([[4, 2, 0, 0], [8, 5, 0, 0]]) + + xs = [enc_toks, dec_toks] + _, _ = model.init(shapes.signature(xs)) + + # decoder output, decoder mask + ys = model(xs) + + # (B, L2, H) + self.assertEqual( + ys[0].shape, (dec_toks.shape[0], dec_toks.shape[1], output_vocab_size) + ) + + self.assertEqual(ys[1].shape, dec_toks.shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/models/resnet_test.py b/tests/models/resnet_test.py similarity index 55% rename from trax/models/resnet_test.py rename to tests/models/resnet_test.py index 3742d67ae..1e9491475 100644 --- a/trax/models/resnet_test.py +++ b/tests/models/resnet_test.py @@ -15,31 +15,29 @@ """Tests for Resnet models.""" -from absl.testing import absltest import numpy as np -from trax import fastmath -from trax import shapes +from absl.testing import absltest + from trax.models import resnet +from trax.utils import shapes class ResnetTest(absltest.TestCase): - - def test_resnet(self): - model = resnet.Resnet50(d_hidden=8, n_output_classes=10) - x = np.ones((3, 256, 256, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 10)) - - def test_wide_resnet(self): - model = resnet.WideResnet(n_blocks=1, n_output_classes=10) - x = np.ones((3, 32, 32, 3)).astype(np.float32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 10)) - - - -if __name__ == '__main__': - absltest.main() + def test_resnet(self): + model = resnet.Resnet50(d_hidden=8, n_output_classes=10) + x = np.ones((3, 256, 256, 3)).astype(np.float32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 10)) + + def test_wide_resnet(self): + model = resnet.WideResnet(n_blocks=1, n_output_classes=10) + x = np.ones((3, 32, 32, 3)).astype(np.float32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 10)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/rl_test.py b/tests/models/rl_test.py new file mode 100644 index 000000000..139687723 --- /dev/null +++ b/tests/models/rl_test.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RL.""" + +from unittest import mock + +import numpy as np + +from absl.testing import absltest + +from trax.models import rl +from trax.utils import shapes + + +class RLTest(absltest.TestCase): + def test_policy_forward_shape(self): + mock_dist = mock.MagicMock() + mock_dist.n_inputs = 4 + model = rl.Policy(policy_distribution=mock_dist) + x = np.ones((2, 3)) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (2, 4)) + + def test_value_forward_shape(self): + model = rl.Value() + x = np.ones((2, 3)) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (2, 1)) + + def test_policy_and_value_forward_shape(self): + mock_dist = mock.MagicMock() + mock_dist.n_inputs = 4 + model = rl.PolicyAndValue(policy_distribution=mock_dist) + x = np.ones((2, 3)) + _, _ = model.init(shapes.signature(x)) + ys = model(x) + self.assertEqual([y.shape for y in ys], [(2, 4), (2, 1)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/rnn_test.py b/tests/models/rnn_test.py new file mode 100644 index 000000000..86bf8d689 --- /dev/null +++ b/tests/models/rnn_test.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RNNs.""" + +from absl.testing import absltest, parameterized + +from trax import fastmath +from trax.fastmath import numpy as jnp +from trax.models import rnn +from trax.utils import shapes + +BACKENDS = [fastmath.Backend.JAX] + + +@parameterized.named_parameters(("_" + b.value, b) for b in BACKENDS) +class RNNTest(parameterized.TestCase): + def test_rnnlm_forward_shape(self, backend): + with fastmath.use_backend(backend): + model = rnn.RNNLM(vocab_size=20, d_model=16) + x = (jnp.ones((3, 28)).astype(jnp.int32),) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 28, 20)) + + def test_grulm_forward_shape(self, backend): + with fastmath.use_backend(backend): + model = rnn.GRULM(vocab_size=20, d_model=16) + x = jnp.ones((3, 28)).astype(jnp.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 28, 20)) + + def test_lstmseq2seqattn_forward_shape(self, backend): + with fastmath.use_backend(backend): + model = rnn.LSTMSeq2SeqAttn( + input_vocab_size=20, target_vocab_size=20, d_model=16 + ) + x = jnp.ones((3, 28)).astype(jnp.int32) + _, _ = model.init([shapes.signature(x), shapes.signature(x)]) + ys = model([x, x]) + self.assertEqual([y.shape for y in ys], [(3, 28, 20), (3, 28)]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/models/transformer_test.py b/tests/models/transformer_test.py new file mode 100644 index 000000000..fce7a343f --- /dev/null +++ b/tests/models/transformer_test.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer models.""" + +import functools + +import numpy as np + +from absl.testing import absltest, parameterized + +from tests.layers import test_utils +from trax.models import transformer +from trax.utils import shapes + + +class TransformerTest(parameterized.TestCase): + def test_transformer_lm_forward_shape(self): + vocab_size = 16 + model = transformer.TransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2 + ) + x = np.ones((3, 5)).astype(np.int32) + _, _ = model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (3, 5, vocab_size)) + + def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + model = transformer.Transformer( + input_vocab_size, + output_vocab_size, + d_model=32, + d_ff=64, + n_encoder_layers=2, + n_decoder_layers=2, + n_heads=2, + ) + xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] + _, _ = model.init(shapes.signature(xs)) + y, _ = model(xs) + + vocab_size = output_vocab_size or input_vocab_size + self.assertEqual(y.shape, (3, 5, vocab_size)) + + @parameterized.named_parameters( + ("same_vocab", 16, None), ("same_size", 16, 16), ("different_size", 16, 50) + ) + def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + """Run the Transformer forward and check output shape.""" + self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) + + def test_dot_product_causal_attention_fast_inference(self): + model_fn = functools.partial( + transformer.TransformerLM, d_model=4, d_ff=8, n_layers=2, n_heads=2 + ) + test_utils.test_eval_equals_predict_discrete(model_fn) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/optimizers/optimizers_test.py b/tests/optimizers/optimizers_test.py similarity index 51% rename from trax/optimizers/optimizers_test.py rename to tests/optimizers/optimizers_test.py index 583f73655..43d37b435 100644 --- a/trax/optimizers/optimizers_test.py +++ b/tests/optimizers/optimizers_test.py @@ -15,36 +15,35 @@ """Tests for supervised training optimizers.""" -from absl.testing import absltest - import numpy as np +from absl.testing import absltest + from trax import optimizers from trax.optimizers import momentum class OptimizersTest(absltest.TestCase): - - def test_slots(self): - weights_shape = (3, 5) - weight_tree = np.arange(15).reshape(weights_shape) - - # SGD - an optimizer that doesn't use slots. - opt_1 = optimizers.SGD(.01) - self.assertIsNone(opt_1.slots) - opt_1.tree_init(weight_tree) - self.assertIsInstance(opt_1.slots, tuple) - self.assertLen(opt_1.slots, 1) - self.assertIsNone(opt_1.slots[0]) - - # Momentum - an optimizer with slots - opt_2 = momentum.Momentum(.01) - self.assertIsNone(opt_2.slots) - opt_2.tree_init(weight_tree) - self.assertIsInstance(opt_2.slots, tuple) - self.assertLen(opt_2.slots, 1) - self.assertEqual(weights_shape, opt_2.slots[0].shape) - - -if __name__ == '__main__': - absltest.main() + def test_slots(self): + weights_shape = (3, 5) + weight_tree = np.arange(15).reshape(weights_shape) + + # SGD - an optimizer that doesn't use slots. 
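+        # (Slots hold per-parameter optimizer state, e.g. momentum buffers, which
+        # is why plain SGD reports a single None slot below.)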
+        opt_1 = optimizers.SGD(0.01)
+        self.assertIsNone(opt_1.slots)
+        opt_1.tree_init(weight_tree)
+        self.assertIsInstance(opt_1.slots, tuple)
+        self.assertLen(opt_1.slots, 1)
+        self.assertIsNone(opt_1.slots[0])
+
+        # Momentum - an optimizer with slots
+        opt_2 = momentum.Momentum(0.01)
+        self.assertIsNone(opt_2.slots)
+        opt_2.tree_init(weight_tree)
+        self.assertIsInstance(opt_2.slots, tuple)
+        self.assertLen(opt_2.slots, 1)
+        self.assertEqual(weights_shape, opt_2.slots[0].shape)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/resources/examples/python/mnist/train_test.py b/tests/resources/examples/python/mnist/train_test.py
new file mode 100644
index 000000000..65c4d7412
--- /dev/null
+++ b/tests/resources/examples/python/mnist/train_test.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test that the example training script works on fake data."""
+from unittest import mock
+
+import numpy as np
+import tensorflow as tf
+
+from resources.examples.python.nn.mnist.from_scratch import dataset, train
+
+
+class TFNumpyMnistExampleTest(tf.test.TestCase):
+    def testRuns(self):
+        with mock.patch.object(dataset, "load", new=fake_mnist_data):
+            train.train(
+                batch_size=1,
+                learning_rate=0.1,
+                num_training_iters=10,
+                validation_steps=5,
+            )
+            train.train(
+                batch_size=2,
+                learning_rate=0.1,
+                num_training_iters=5,
+                validation_steps=2,
+            )
+            train.train(
+                batch_size=10,
+                learning_rate=0.1,
+                num_training_iters=1,
+                validation_steps=1,
+            )
+
+
+def fake_mnist_data():
+    def gen_examples(num_examples):
+        x = np.array(np.random.randn(num_examples, 784), copy=False, dtype=np.float32)
+        y = np.zeros((num_examples, 10), dtype=np.float32)
+        y[:, 0] = 1.0  # One-hot labels: class 0 for every fake example.
+        return (x, y)
+
+    return (gen_examples(100), gen_examples(10), gen_examples(10))
+
+
+if __name__ == "__main__":
+    tf.compat.v1.enable_eager_execution()
+    tf.test.main()
diff --git a/tests/tf/extensions/extensions_test.py b/tests/tf/extensions/extensions_test.py
new file mode 100644
index 000000000..309ba6a81
--- /dev/null
+++ b/tests/tf/extensions/extensions_test.py
@@ -0,0 +1,1170 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for tf numpy mathematical methods.""" +import functools +import itertools + +import jax +import numpy as np +import tensorflow as tf + +from absl.testing import parameterized + +import trax.tf.numpy as tf_np + +from tests.fastmath.jax.config import flags +from trax.tf import extensions + +FLAGS = flags.FLAGS + +flags.DEFINE_bool("requires_tpu", False, "Requires TPU.") + + +def generate_params_inputs_targets(num_examples=1000): + params = (tf_np.asarray(tf.constant(5.0)), tf_np.asarray(tf.constant(0.0))) + + params_true = (tf_np.asarray(tf.constant(3.0)), tf_np.asarray(tf.constant(2.0))) + + inputs = tf_np.asarray(tf.random.normal(shape=[num_examples])) + noise = tf_np.asarray(tf.random.normal(shape=[num_examples])) + targets = inputs * params_true[0] + params_true[1] + noise + + return params, params_true, inputs, targets + + +def loss_fn(params, inputs, targets): + predicted = params[0] * inputs + params[1] + loss = tf.reduce_mean(input_tensor=tf.square(predicted - targets)) + return tf_np.asarray(loss) + + +def train_step(params, inputs, targets, learning_rate=0.1): + grad_fn = extensions.grad(loss_fn) + grads = grad_fn(params, inputs, targets) + new_w = params[0] - (grads[0] * learning_rate) + new_b = params[1] - (grads[1] * learning_rate) + + return new_w, new_b + + +def uniform(rng, shape, dtype): + if np.issubdtype(dtype, np.integer): + minval = None + else: + minval = 0 + return tf_np.asarray(rng.uniform(shape=shape, dtype=dtype, minval=minval)) + + +def to_np(a): + return tf.nest.map_structure(tf_np.asarray, a) + + +def to_tf_fn(f): + return lambda *args: f(*to_np(args)) + + +def scan_reference(f, init, xs): + carry = init + ys = [] + for x in xs: + (carry, y) = f(carry, x) + ys.append(tf_np.reshape(y, (1,) + y.shape)) + ys = tf_np.concatenate(ys, 0) + return carry, ys + + +def spec(*args): + return tf.TensorSpec(args, tf.float32) + + +class ExtensionsTest(tf.test.TestCase, parameterized.TestCase): + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super().__init__(methodName) + physical_devices = tf.config.experimental.list_physical_devices("CPU") + tf.config.experimental.set_virtual_device_configuration( + physical_devices[0], + [ + tf.config.experimental.VirtualDeviceConfiguration(), + tf.config.experimental.VirtualDeviceConfiguration(), + ], + ) + if extensions.tpu_devices(): + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local") + tf.tpu.experimental.initialize_tpu_system(resolver) + + def _hasGPU(self): + physical_devices = tf.config.experimental.list_physical_devices("GPU") + return physical_devices + + def testCustomGrad(self): + """Test for custom_grad.""" + x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) + y_shape = tf.TensorShape([]) + dtype = np.float32 + scale1 = 5.0 + scale2 = 6.0 + + def fwd(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + @extensions.custom_grad + def f(a, b): + y = fwd(a, b) + + def vjp(dy): + return dy * scale1 * a, dy * scale2 * b + + return y, vjp + + rng = tf.random.Generator.from_seed(1234) + x, dy = tf.nest.map_structure( + lambda shape: uniform(rng, shape, dtype), [x_shape, y_shape] + ) + expected_y = fwd(*x) + expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1]) + y, vjp = extensions.vjp(f, *x) + dx = vjp(dy) + self.assertAllClose(expected_y, y) + self.assertAllClose(expected_dx, dx) + + @parameterized.named_parameters( + [ + ( # pylint: disable=g-complex-comprehension + ("_%s_%s_%s" % (decorator_id, x_struct, y_struct)) + .replace(" ", "") + .replace("None", 
""), + decorator, + x_struct, + y_struct, + ) + for y_struct in [[None, ()], (None, (), [], (None, ((), None)))] + for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))] + for decorator_id, decorator in enumerate([lambda f: f, extensions.jit]) + ] + ) + def testCustomGradStructure(self, decorator, x_struct, y_struct): + """Tests that custom_grad can handle structured inputs/outputs.""" + + def zeros(x): + return tf.nest.map_structure(lambda _: tf_np.zeros([], np.float32), x) + + def get_struct(x): + return tf.nest.map_structure(lambda _: None, x) + + @extensions.custom_grad + def f(*x): + del x + + def vjp(dy): + self.assertEqual(y_struct, get_struct(dy)) + return zeros(x_struct) + + return zeros(y_struct), vjp + + x, dy = zeros([x_struct, y_struct]) + + @decorator + def run(x, dy): + y, vjp = extensions.vjp(f, *x) + dx = vjp(dy) + return dx, y + + dx, y = run(x, dy) + self.assertEqual(x_struct, get_struct(dx)) + self.assertEqual(y_struct, get_struct(y)) + + @parameterized.named_parameters( + [("_%s" % has_aux, has_aux) for has_aux in [True, False]] + ) + def testVjp(self, has_aux): + x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) + y_shape = tf.TensorShape([]) + dtype = np.float32 + + def f(a, b): + y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + if has_aux: + return y, tf_np.asarray(1) + else: + return y + + rng = tf.random.Generator.from_seed(1234) + x, dy_list = tf.nest.map_structure( + lambda shape: uniform(rng, shape, dtype), [x_shape, [y_shape] * 2] + ) + tf_x = x + outputs = extensions.vjp(f, *x, has_aux=has_aux) + if has_aux: + y, vjp, aux = outputs + else: + y, vjp = outputs + with tf.GradientTape(persistent=True) as tape: + tape.watch(tf_x) + outputs = f(*x) + if has_aux: + expected_y, expected_aux = outputs + self.assertAllClose(expected_aux, aux) + else: + expected_y = outputs + self.assertAllClose(expected_y, y) + for dy in dy_list: + expected_dx = tape.gradient(expected_y, tf_x, output_gradients=dy) + self.assertAllClose(expected_dx, vjp(dy)) + + def testGrad(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + g = extensions.grad(f) + + def compare(a, b): + with tf.GradientTape() as tape: + tape.watch(a) + r = f(a, b) + expected = tape.gradient(r, a) + self.assertAllEqual(expected, g(a, b)) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + compare(a, b) + + def testGradNonArrayOutput(self): + def f(_): + return 1.0 + + g = extensions.grad(f) + with self.assertRaisesWithPredicateMatch( + ValueError, r"result .* must be an ndarray" + ): + g(tf_np.asarray(1.0)) + + def testGradNonScalarOutput(self): + def f(a): + return a + + g = extensions.grad(f) + with self.assertRaisesWithPredicateMatch( + ValueError, r"result .* must be a scalar" + ): + g(tf_np.asarray([1.0, 2.0])) + + @extensions.jit + def g_jitted(a): + return extensions.grad(f)(a) + + g_jitted(tf_np.asarray(1.0)) + with self.assertRaisesWithPredicateMatch( + ValueError, r"result .* must be a scalar" + ): + g_jitted(tf_np.asarray([1.0, 2.0])) + + def testJit(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_jitted = extensions.jit(f) + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + self.assertAllClose(f(a, b), f_jitted(a, b)) + # Call again since the code path is different on second call + self.assertAllClose(f(a, b), f_jitted(a, b)) + + def testJitNoUnnecessaryTracing(self): + def num_traces(f): + return len(f.tf_function._list_all_concrete_functions_for_serialization()) + + def 
check_trace_only_once(arg1, arg2):
+            @extensions.jit
+            def f(a):
+                return a + 1
+
+            self.assertAllEqual(0, num_traces(f))
+            f(arg1)
+            self.assertAllEqual(1, num_traces(f))
+            f(arg2)
+            self.assertAllEqual(1, num_traces(f))
+
+        check_trace_only_once(1, 2)
+        check_trace_only_once(1.1, 2.1)
+        check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2))
+        check_trace_only_once(
+            tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2)
+        )
+
+    def _testEvalOnShapes(self, transformer, allow_static_outputs):
+        # A class that's not convertible to a tensor.
+        class Thing:
+            def __init__(self, value):
+                self.value = value
+
+        def f(a, b, reverse=False):
+            res = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
+            res = (res, 10)
+            if allow_static_outputs:
+                res = res + (Thing(20),)
+            if reverse:
+                res = tuple(reversed(res))
+            return res
+
+        f_prime = transformer(
+            f, static_argnums=(2,), allow_static_outputs=allow_static_outputs
+        )
+        shape = [10]
+        dtype = np.float16
+        a = tf_np.zeros(shape=shape, dtype=dtype)
+        b = tf_np.zeros(shape=shape, dtype=dtype)
+        expected, *_ = f(a, b)
+        got = f_prime(a, b)
+
+        def check(got):
+            self.assertIsInstance(got[0], (tf.TensorSpec, tf_np.ndarray))
+            self.assertAllEqual(expected.shape, got[0].shape)
+            self.assertAllEqual(expected.dtype, got[0].dtype)
+            if allow_static_outputs:
+                self.assertIsInstance(got[1], int)
+                self.assertEqual(10, got[1])
+                self.assertIsInstance(got[2], Thing)
+                self.assertEqual(20, got[2].value)
+            else:
+                self.assertIsInstance(got[1], (tf.TensorSpec, tf_np.ndarray))
+                self.assertAllEqual((), got[1].shape)
+
+        check(got)
+        # Call again since the code path is different on the second call.
+        got = f_prime(a, b)
+        check(got)
+        # Retrace and check again.
+        got = f_prime(a, b, True)
+        check(tuple(reversed(got)))
+        got = f_prime(a, b, True)
+        check(tuple(reversed(got)))
+
+    @parameterized.named_parameters(("_%s" % b, b) for b in [False, True])
+    def testEvalOnShapes(self, allow_static_outputs):
+        self._testEvalOnShapes(extensions.eval_on_shapes, allow_static_outputs)
+
+    def testEvalOnShapesNested(self):
+        transformer = functools.partial(
+            extensions.eval_on_shapes, allow_static_outputs=True
+        )
+
+        @transformer
+        def outer():
+            @transformer
+            def inner():
+                return 1
+
+            return inner() + 2
+
+        r = outer()
+        self.assertIsInstance(r, int)
+        self.assertEqual(3, r)
+
+    def testJitOfEvalOnShapes(self):
+        """Tests that eval_on_shapes can be called within jit."""
+
+        def transformer(f, **kwargs):
+            def f_prime(*args):
+                res = extensions.eval_on_shapes(f, **kwargs)(*args)
+                return tf.nest.map_structure(
+                    lambda x: tf_np.zeros(x.shape, x.dtype), res
+                )
+
+            return extensions.jit(f_prime, kwargs.get("static_argnums", ()))
+
+        self._testEvalOnShapes(transformer, False)
+
+    def testEvalOnShapesNoUnnecessaryTracing(self):
+        def num_traces(f):
+            return len(f._tf_function._list_all_concrete_functions_for_serialization())
+
+        def check_trace_only_once(arg1, arg2):
+            @extensions.eval_on_shapes
+            def f(a):
+                return a + 1
+
+            self.assertAllEqual(0, num_traces(f))
+            f(arg1)
+            self.assertAllEqual(1, num_traces(f))
+            f(arg2)
+            self.assertAllEqual(1, num_traces(f))
+
+        check_trace_only_once(1, 2)
+        check_trace_only_once(1.1, 2.1)
+        check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2))
+        check_trace_only_once(
+            tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2)
+        )
+
+    @parameterized.parameters(
+        {
+            "lhs_np": np.ones((5, 3)),
+            "rhs_np": np.ones((3, 2)),
+            "dims": (((1,), (0,)), ((), ())),
+        },
+        {
+            "lhs_np": np.ones((5, 3)),
+            "rhs_np": np.ones((5, 3)),
+            "dims": (((0,
1), (0, 1)), ((), ())), + }, + { + "lhs_np": np.ones((5, 3, 2)), + "rhs_np": np.ones((2, 3, 2)), + "dims": (((1, 2), (1, 0)), ((), ())), + }, + { + "lhs_np": np.ones((6, 5, 3)), + "rhs_np": np.ones((6, 3, 2)), + "dims": (((2,), (1,)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((6, 3, 5)), + "rhs_np": np.ones((6, 3, 2)), + "dims": (((1,), (1,)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((5, 3, 2, 2)), + "rhs_np": np.ones((5, 2, 2, 6)), + "dims": (((2, 3), (1, 2)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((2, 2, 5, 3)), + "rhs_np": np.ones((2, 2, 3, 2)), + "dims": (((3,), (2,)), ((0, 1), (0, 1))), + }, + { + "lhs_np": np.ones((2, 2, 5, 2)), + "rhs_np": np.ones((2, 2, 3, 2)), + "dims": (((3,), (1,)), ((0,), (0,))), + }, + { + "lhs_np": np.ones((2, 2, 5, 3, 3)), + "rhs_np": np.ones((2, 3, 2, 3, 2)), + "dims": (((4,), (1,)), ((0,), (0,))), + }, + ) + def test_tf_dot_general(self, lhs_np, rhs_np, dims): + ans = jax.lax.dot_general(lhs_np, rhs_np, dims) + result = extensions.tf_dot_general(lhs_np, rhs_np, dims) + self.assertAllClose(result, np.array(ans)) + + @parameterized.named_parameters( + [ + ( + "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" # pylint: disable=g-complex-comprehension + "_lhs_dilation={}_rhs_dilation={}" + "_feature_group_count={}_batch_group_count={}_dims={}" + "_perms={}".format( + lhs_shape, + rhs_shape, + strides, + padding, + lhs_dilation, + rhs_dilation, + feature_group_count, + batch_group_count, + ",".join(dimension_numbers), + perms, + ), + lhs_shape, + rhs_shape, + strides, + padding, + lhs_dilation, + rhs_dilation, + feature_group_count, + batch_group_count, + dimension_numbers, + perms, + ) + for batch_group_count, feature_group_count in [(1, 1)] + for lhs_shape, rhs_shape in [ + ( + (b * batch_group_count, i * feature_group_count, 9, w), + (j * feature_group_count * batch_group_count, i, 4, 5), + ) + for w in [1, 10] + for b, i, j in itertools.product([2, 3], repeat=3) + ] + for strides in [(1, 1), (2, 1)] + for padding in ["SAME"] + for lhs_dilation, rhs_dilation in [(None, (1, 1))] + for dimension_numbers, perms in [ + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) + ] + ] + ) + def testConvGeneralDilated( + self, + lhs_shape, + rhs_shape, + strides, + padding, + lhs_dilation, + rhs_dilation, + feature_group_count, + batch_group_count, + dimension_numbers, + perms, + ): + lhs_perm, rhs_perm = perms # permute to compatible shapes + + lhs = np.transpose(np.ones(lhs_shape), lhs_perm) + rhs = np.transpose(np.ones(rhs_shape), rhs_perm) + + jax_conv = jax.lax.conv_general_dilated( + lhs, + rhs, + strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count, + batch_group_count, + ) + + tf_conv = extensions.tf_conv_general_dilated( + lhs, + rhs, + strides, + padding, + None, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count, + batch_group_count, + ) + + self.assertAllClose(tf_conv, tf_np.asarray(jax_conv)) + + def testConv(self): + y = extensions.conv( + np.ones([5, 320, 480, 3], dtype=np.float32), + np.ones([3, 4, 3, 11], dtype=np.float32), + [1, 1], + "SAME", + ("NHWC", "HWIO", "NHWC"), + ) + self.assertAllClose(y.shape, [5, 320, 480, 11]) + self.assertAllClose( + y, + tf.nn.conv2d( + input=tf.ones([5, 320, 480, 3], dtype=tf.float32), + filters=tf.ones([3, 4, 3, 11], dtype=tf.float32), + strides=1, + padding="SAME", + ), + ) + + def testAvgPool(self): + y = extensions.avg_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") + self.assertAllEqual( + y, + tf.nn.pool( + input=tf.ones([5, 
320, 480, 3]), + window_shape=[3, 5], + pooling_type="AVG", + padding="VALID", + strides=[2, 3], + ), + ) + + def testMaxPool(self): + y = extensions.max_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") + self.assertAllEqual( + y, + tf.nn.pool( + input=tf.ones([5, 320, 480, 3]), + window_shape=[3, 5], + pooling_type="MAX", + padding="VALID", + strides=[2, 3], + ), + ) + + def assertDTypesEqual(self, a, b): + get_dtype = lambda t: t.dtype + self.assertEqual( + tf.nest.map_structure(get_dtype, a), tf.nest.map_structure(get_dtype, b) + ) + + @parameterized.named_parameters( + ( + f"_{jit_scan}_{jit_f}", + jit_scan, + jit_f, + ) # pylint: disable=g-complex-comprehension + for jit_f in [False, True] + for jit_scan in ["no", "no_xla", "xla_forced_compile"] + ) + def testScanImpl(self, jit_scan, jit_f): + rng = np.random.RandomState(0) + + d = rng.randn(2) + + def f(c, a): + assert a.shape == (3,) + assert c.shape == (4,) + b = tf_np.cos( + tf_np.sum(tf_np.sin(a)) + + tf_np.sum(tf_np.cos(c)) + + tf_np.sum(tf_np.tan(d)) + ) + c = tf_np.sin(c * b) + assert b.shape == () # pylint: disable=g-explicit-bool-comparison + return c, b + + if jit_f: + f = extensions.jit(f) + + if jit_scan == "no_xla": + scan = extensions.jit(extensions.scan, static_argnums=(0,)) + elif jit_scan == "xla_forced_compile": + scan = extensions.jit( + extensions.scan, static_argnums=(0,), xla_forced_compile=True + ) + else: + scan = extensions.scan + + xs = rng.randn(5, 3) + c = rng.randn(4) + + ans = scan(f, c, xs) + expected = scan_reference(f, c, xs) + if jit_scan == "xla_forced_compile": + # xla.compile doesn't preserve list-vs-tuple properly for the outputs, so + # we canonicalize them to lists here. + expected = list(expected) + ans = list(ans) + self.assertDTypesEqual(expected, ans) + self.assertAllClose(expected, ans) + + def testScanStruct(self): + rng = np.random.RandomState(0) + + d = rng.randn(2) + + def f(c_g_i, a_e_h): + c_g, i = c_g_i + c, g = c_g + a, e_h = a_e_h + e, h = e_h + assert a.shape == (3,) + assert e.shape == () # pylint: disable=g-explicit-bool-comparison + assert c.shape == (4,) + assert g.shape == (2,) + assert i is None + assert h is None + b = tf_np.cos( + tf_np.sum(tf_np.sin(a)) + + tf_np.sum(tf_np.cos(c)) + + tf_np.sum(tf_np.tan(d)) + ) + f = tf_np.cos(a) + c = tf_np.sin(c * b) + g = tf_np.sin(g * b) + assert b.shape == () # pylint: disable=g-explicit-bool-comparison + assert f.shape == (3,) + return [(c, g), i], (b, [f, h]) + + xs = (rng.randn(5, 3), [rng.randn(5), None]) + init = [(rng.randn(4), rng.randn(2)), None] + + c_g_i, b_f_h = extensions.scan(f, init, xs) + self.assertIsInstance(c_g_i, list) + self.assertIsInstance(b_f_h, tuple) + c_g, i = c_g_i + c, g = c_g + self.assertIsInstance(c_g, tuple) + self.assertEqual((4,), c.shape) + self.assertEqual((2,), g.shape) + self.assertIsNone(i) + b, f_h = b_f_h + f, h = f_h + self.assertIsInstance(f_h, list) + self.assertEqual((5,), b.shape) + self.assertEqual((5, 3), f.shape) + self.assertIsNone(h) + + @parameterized.named_parameters( + ( + f"_{jit_grad}_{jit_scan}_{jit_f}", + jit_grad, + jit_scan, + jit_f, + ) # pylint: disable=g-complex-comprehension + for jit_f in [False, True] + for jit_scan in ["no", "no_xla", "xla_forced_compile"] + for jit_grad in ["no", "no_xla", "xla_forced_compile"] + ) + def testScanGrad(self, jit_grad, jit_scan, jit_f): + rng = np.random.RandomState(0) + + d = rng.randn(2) + + def f(c, a): + assert a.shape == (3,) + assert c.shape == (4,) + b = ( + tf_np.sum(tf_np.sin(a)) + + tf_np.sum(tf_np.sin(c)) + + 
tf_np.sum(tf_np.sin(d)) + ) + c = tf_np.sin(c * b) + assert b.shape == () # pylint: disable=g-explicit-bool-comparison + return c, b + + if jit_f: + f = extensions.jit(f) + + if jit_scan == "no_xla": + scan = extensions.jit(extensions.scan, static_argnums=(0,)) + elif jit_scan == "xla_forced_compile": + # TODO(b/187107596): Remove `skipTest` + self.skipTest( + "Taking gradients of `jit(scan, experimental_compile=True)` triggers " + "'Support for TensorList crossing the XLA/TF boundary is not " + "implemented' error" + ) + # `xla_forced_compile=True` doesn't support gradients, so we use + # `experimental_compile=True`. + scan = extensions.jit( + extensions.scan, static_argnums=(0,), experimental_compile=True + ) + else: + scan = extensions.scan + + xs = tf_np.asarray(rng.randn(5, 3)) + c = tf_np.asarray(rng.randn(4)) + + def losses(scan, c, xs): + c, ys = scan(f, c, xs) + return tf_np.concatenate( + tf.nest.flatten( + tf.nest.map_structure(lambda a: tf_np.reshape(a, [-1]), (c, ys)) + ) + ) + + def loss(scan, c, xs): + return tf_np.sum(losses(scan, c, xs)) + + def grad_origin(c, xs): + return extensions.grad(functools.partial(loss, scan))(c, xs) + + if jit_grad == "no_xla": + grad_jit = extensions.jit(grad_origin) + elif jit_grad == "xla_forced_compile": + grad_jit = extensions.jit(grad_origin, xla_forced_compile=True) + else: + grad_jit = grad_origin + + ans = grad_jit(c, xs) + expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs) + self.assertDTypesEqual(expected, ans) + self.assertAllClose(expected, ans) + + theoretical, numerical = tf.test.compute_gradient( + to_tf_fn(functools.partial(losses, scan)), (c, xs) + ) + self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4) + + @parameterized.named_parameters( + (f"_{i}", *args) # pylint: disable=g-complex-comprehension + for i, args in enumerate( + [ + ( + lambda c, x: (c + 1, tf_np.sum(c + x, 0)), + [spec(2), spec(4, 3, 2)], + [spec(2), spec(4, 2)], + ), + ( + lambda c, x: (c + 1, tf_np.sum(c + x, 0)), + [spec(2), spec(0, 3, 2), 0], + [spec(2), spec(0, 2)], + ), + ] + ) + ) + def testScanShape(self, f, inputs, expected_outputs): + outputs = extensions.eval_on_shapes( + functools.partial(extensions.scan, f), static_argnums=(2,) + )(*inputs) + self.assertAllEqual(expected_outputs, outputs) + + def testMap(self): + shape = [2, 3] + dtype = tf_np.int32 + xs1 = tf_np.zeros(shape, dtype) + xs2 = tf_np.ones(shape, dtype) + ys_expected = [xs2 + 10, xs1 + 20] + + def f(x): + self.assertIsInstance(x, tuple) + for a in x: + self.assertEqual(a.shape, shape[1:]) + x1, x2 = x + return [x2 + 10, x1 + 20] + + ys = extensions.tf_map(f, (xs1, xs2)) + self.assertIsInstance(ys, list) + self.assertAllClose(ys, ys_expected) + + def testPrng(self): + self.assertAllEqual(tf_np.asarray(123, np.int64), extensions.prng(123)) + + def testUniform(self): + minval = 0.43 + maxval = 3.10 + shape = [13, 34, 29] + atol = 0.1 + outputs = extensions.uniform(123, shape, minval=minval, maxval=maxval) + self.assertAllClose((minval + maxval) / 2.0, np.mean(outputs), atol=atol) + + def testNormal(self): + shape = [13, 34, 29] + atol = 0.1 + outputs = extensions.normal(123, shape) + self.assertAllClose(0, np.mean(outputs), atol=atol) + self.assertAllClose(1, np.std(outputs), atol=atol) + + def testBernoulli(self): + mean = 0.23 + shape = [13, 34, 29] + atol = 0.1 + outputs = extensions.bernoulli(123, mean, shape) + self.assertAllClose(mean, np.mean(outputs), atol=atol) + + def testBernoulliWrongShape(self): + mean = [0.1, 0.2] + shape = [3] + with 
self.assertRaisesIncompatibleShapesError(): + extensions.bernoulli(123, mean, shape) + + def testDatasetAsNumpy(self): + arrs = extensions.dataset_as_numpy([tf.constant([1, 2]), tf.constant([3, 4])]) + for a in arrs: + self.assertIsInstance(a, tf_np.ndarray) + with self.assertRaisesWithPredicateMatch( + ValueError, + r"dataset_as_numpy must be run in eager mode outside tf.function", + ): + + @tf.function + def f(): + return extensions.dataset_as_numpy([tf.constant([1, 2])]) + + f() + + def _get_two_devices(self, require_same_type=False): + tpus = extensions.tpu_devices() + if FLAGS.requires_tpu: + if len(tpus) == 2: + res = tpus + else: + raise ValueError( + "This test requires 2 TPU cores but %s are found" % len(tpus) + ) + else: + if len(tpus) == 2: + res = tpus + elif self._hasGPU() and not require_same_type: + res = ("CPU:0", "GPU:0") + else: + res = ("CPU:0", "CPU:1") + return res + + def testPmap(self): + devices = self._get_two_devices() + + @functools.partial(extensions.pmap, devices=devices) + def return_three(f): + return f, f + 1.0, f + 2.0 + + result = return_three(tf.ones((2, 20))) + # The function returned 3 items, so we got 3 items back. + self.assertLen(result, 3) + + # Each of the items should be a ShardedNdarray that when converted to tensor + # should produce a tensor of shape (2, 20) + converted = tf.nest.map_structure(tf.convert_to_tensor, result) + + self.assertLen(result, 3) + + self.assertAllEqual(converted[0].shape, converted[1].shape) + self.assertAllEqual(converted[0].shape, converted[2].shape) + + self.assertAllEqual(converted[0], tf.ones((2, 20))) + self.assertAllEqual(converted[1], 1 + tf.ones((2, 20))) + self.assertAllEqual(converted[2], 2 + tf.ones((2, 20))) + + @functools.partial(extensions.pmap, devices=devices) + def return_one(f): + return f + 2.0 + + result = return_one(tf.ones((2, 20))) + + # Only a single item is returned, so we can convert it directly. + converted = tf.convert_to_tensor(value=result) + self.assertAllEqual(converted, 2 + tf.ones((2, 20))) + + @functools.partial(extensions.pmap, devices=devices) + def return_list(f): + return [f + 2.0] + + result = return_list(tf.ones((2, 20))) + + # A singleton list is returned. + self.assertLen(result, 1) + converted = tf.convert_to_tensor(value=result[0]) + self.assertAllEqual(converted, 2 + tf.ones((2, 20))) + + def testGradSimpleModel(self): + params, params_true, inputs, targets = generate_params_inputs_targets() + + for _ in range(50): + params = train_step(params, inputs, targets) + + # This is not trained super well, but it usually gets "close". + self.assertAllClose(params[0], params_true[0], atol=1e-1) + self.assertAllClose(params[1], params_true[1], atol=1e-1) + + # NOTE: Compare to testGradSimpleModel to see the differences when pmapping. 
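+    # The pmapped variant below replicates the params on each device, reshapes the
+    # batch to add a leading device axis, and psum-averages the per-device updates.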
+ def testPmapSimpleModel(self): + devices = self._get_two_devices(require_same_type=True) + n_devices = len(devices) + + params, params_true, inputs, targets = generate_params_inputs_targets() + + def _train_and_reduce(params, inputs, targets, learning_rate=0.1): + new_w, new_b = train_step(params, inputs, targets, learning_rate) + + return ( + extensions.psum(new_w) / n_devices, + extensions.psum(new_b) / n_devices, + ) + + train_step_pmapped = extensions.pmap(_train_and_reduce, devices=devices) + + def replicate(x, num_devices=2): + return tf_np.broadcast_to(x, (num_devices,) + x.shape) + + params = tf.nest.map_structure(replicate, params) + + def reshape(x, num_devices=2): + x_shape = list(x.shape) + batch_size = x_shape[0] + batch_size_per_device = batch_size // num_devices + + # New shape. + new_shape_prefix = [num_devices, batch_size_per_device] + return tf_np.reshape(x, new_shape_prefix + x_shape[1:]) + + inputs = tf.nest.map_structure(reshape, inputs) + targets = tf.nest.map_structure(reshape, targets) + + for _ in range(50): + params = train_step_pmapped(params, inputs, targets) + + # PMAP returns sharded tensors. + + # Since the inputs are identical, the returned tensors should be identical + self.assertAllClose(params[0][0], params[0][1]) + self.assertAllClose(params[1][0], params[1][1]) + + # This is not trained super well, but it usually gets "close". + self.assertAllClose(params[0][0], params_true[0], atol=1e-1) + self.assertAllClose(params[1][0], params_true[1], atol=1e-1) + + def testPsum(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(f): + return extensions.psum(f) + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + pmapped = extensions.pmap(reduce_sum, devices=devices) + result = pmapped(data) + + self.assertAllClose(result[0], 4) + self.assertAllClose(result[1], 4) + + def testPsumStruct(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(a): + a = extensions.psum(a) + tf.nest.map_structure(lambda x: self.assertIsInstance(x, tf_np.ndarray), a) + return a + + data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)] + pmapped = extensions.pmap(reduce_sum, devices=devices) + result = pmapped(data) + + self.assertIsInstance(result[0][0], tf_np.ndarray) + self.assertIsInstance(result[0][1], tf_np.ndarray) + self.assertIsInstance(result[1][0], tf_np.ndarray) + self.assertIsInstance(result[1][1], tf_np.ndarray) + self.assertAllClose(result[0][0], 4) + self.assertAllClose(result[0][1], 4) + self.assertAllClose(result[1][0], 6) + self.assertAllClose(result[1][1], 6) + + def testPmean(self): + if extensions.tpu_devices(): + self.skipTest("pmean for TPU is not supported yet") + devices = self._get_two_devices(require_same_type=True) + + def reduce_mean(f): + return extensions.pmean(f) + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + pmapped = extensions.pmap(reduce_mean, devices=devices) + result = pmapped(data) + + self.assertAllClose(result[0], 2) + self.assertAllClose(result[1], 2) + + def testAxisName(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(f): + return extensions.psum(f, axis_name="foo") + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) + pmapped(data) + + def testWrongAxisName(self): + devices = self._get_two_devices(require_same_type=True) + + def reduce_sum(f): + return extensions.psum(f, axis_name="bar") + + data = 
tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + with self.assertRaisesWithPredicateMatch( + ValueError, r"axis_name (.*) is not equal to that of the surrounding" + ): + pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) + pmapped(data) + + def testNoNestedPmap(self): + devices = self._get_two_devices(require_same_type=True) + + def f(x): + return x + 1.0 + + data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) + with self.assertRaisesWithPredicateMatch( + ValueError, r"Nested pmap is not supported" + ): + f = extensions.pmap(f, devices=devices) + f = extensions.pmap(f, devices=devices) + f(data) + + def testVmap(self): + fn1 = extensions.vmap(lambda z: z * z) + + x = tf_np.arange(10) + self.assertAllClose(x * x, fn1(x)) + + y = tf.range(10) + np_y = tf_np.asarray(y) + output = fn1(y) + self.assertIsInstance(output, tf_np.ndarray) + self.assertAllClose(np_y * np_y, output) + + fn2 = extensions.vmap(lambda x, y: x + y) + x = tf_np.random.randn(10, 3) + y = tf_np.random.randn(10, 2, 3) + self.assertAllClose(tf_np.expand_dims(x, 1) + y, fn2(x, y)) + + def testRemat(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.grad(f_remat)(a, b) + expected = extensions.grad(f)(a, b) + self.assertAllClose(actual, expected) + + def testRematLambdaFunction(self): + f = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.grad(f_remat)(a, b) + expected = extensions.grad(f)(a, b) + self.assertAllClose(actual, expected) + + def testRematJit(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.jit(extensions.grad(f_remat))(a, b) + expected = extensions.jit(extensions.grad(f))(a, b) + self.assertAllClose(actual, expected) + + def testRematJitXla(self): + def f(a, b): + return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) + + f_remat = extensions.remat(f) + + shape = [10] + a = tf_np.random.randn(*shape) + b = tf_np.random.randn(*shape) + + actual = extensions.jit(extensions.grad(f_remat), xla_forced_compile=True)(a, b) + expected = extensions.jit(extensions.grad(f), xla_forced_compile=True)(a, b) + self.assertAllClose(actual, expected) + + actual = extensions.jit(extensions.grad(f_remat), experimental_compile=True)( + a, b + ) + expected = extensions.jit(extensions.grad(f), experimental_compile=True)(a, b) + self.assertAllClose(actual, expected) + + def testStaticStopGradient(self): + self.assertEqual(extensions.stop_gradient(5.0), 5.0) + self.assertEqual(type(extensions.stop_gradient(5.0)), type(5.0)) + + self.assertEqual(extensions.stop_gradient(tf_np.asarray(5.0)), 5.0) + self.assertNotEqual( + type(extensions.stop_gradient(tf_np.asarray(5.0))), type(5.0) + ) + + +if __name__ == "__main__": + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/trax/tf_numpy/public_symbol_test.py b/tests/tf/public_symbol_test.py similarity index 80% rename from trax/tf_numpy/public_symbol_test.py rename to tests/tf/public_symbol_test.py index 23f3ebe4e..fd848ac41 100644 --- a/trax/tf_numpy/public_symbol_test.py +++ b/tests/tf/public_symbol_test.py @@ -15,24 +15,22 @@ """Tests different ways to use the public tf-numpy module.""" import 
numpy as onp - import tensorflow as tf import tensorflow.experimental.numpy as np1 -from tensorflow.experimental import numpy as np2 # pylint: disable=reimported +from tensorflow.experimental import numpy as np2 # pylint: disable=reimported np3 = tf.experimental.numpy class PublicSymbolTest(tf.test.TestCase): - - def testSimple(self): - a = 0.1 - b = 0.2 - for op in [np1.add, np2.add, np3.add]: - self.assertAllClose(onp.add(a, b), op(a, b)) + def testSimple(self): + a = 0.1 + b = 0.2 + for op in [np1.add, np2.add, np3.add]: + self.assertAllClose(onp.add(a, b), op(a, b)) if __name__ == "__main__": - tf.compat.v1.enable_eager_execution() - tf.test.main() + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tests/trainers/base_test.py b/tests/trainers/base_test.py new file mode 100644 index 000000000..2ea038722 --- /dev/null +++ b/tests/trainers/base_test.py @@ -0,0 +1,383 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for accelerated optimization of loss layers.""" + +import time + +import numpy as np + +from absl.testing import absltest +from trainers.base import ReversibleSerialTrainer, Trainer, extract_reversible_blocks + +from tests.fastmath.jax.config import config +from trax import fastmath, optimizers +from trax import layers as tl +from trax.layers import base +from trax.models.research import terraformer +from trax.utils import shapes + + +class TrainerTest(absltest.TestCase): + def _assert_all_equal(self, t1, t2, tol=1e-5): + def eq(x1, x2): + diff = np.maximum(np.abs(x1 - x2) - tol, 0.0) + self.assertLessEqual( + np.sum(diff), 0.0, msg=f"\n{x1}\n !=\n{x2}\n diff:\n{x1-x2}" + ) + + fastmath.nested_map_multiarg(eq, t1, t2) + + def test_run_simple_task(self): + """Runs an accelerated optimizer on a simple task.""" + inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch + targets_batch = np.pi * np.ones_like(inputs_batch) + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss()) + loss_layer.init(labeled_batch) + optimizer = optimizers.SGD(0.01) + optimizer.tree_init(loss_layer.weights) + trainer = Trainer(loss_layer, optimizer) + rng = fastmath.random.get_prng(0) + trainer.one_step(labeled_batch, rng) + + def test_run_sharded_terraformer(self): + """Runs Terraformer with sharded weights (only on 2+-device systems).""" + if fastmath.local_device_count() == 1: + return + base.N_WEIGHTS_SHARDS = fastmath.local_device_count() + inputs_batch = np.arange(8).reshape((2, 4)) + 1 + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) + input_sig = (int_sig, int_sig, int_sig) + # We want to test rng propagation too, so adding some dropout layers. 
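+        # (The model below is kept deliberately small; this test only checks that
+        # a sharded one_step runs, not the quality of the result.)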
+ model = terraformer.ConfigurableTerraformer( + 20, + d_model=8, + d_ff=32, + n_heads=1, + dropout=0.0, + n_encoder_layers=2, + n_decoder_layers=2, + ff_sparsity=(4, 8, 0.0, 1.0), + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + pos_type=None, + reversible_encoder=True, + ) + loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) + model_with_loss = tl.Serial(model, loss) + rng_init = fastmath.random.get_prng(12) + model_with_loss.init(input_sig, rng=rng_init) + + # Make a step with the trainers. + optimizer = optimizers.Adafactor(0.01) + split_w = fastmath.nested_map( + lambda x: x[0], tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS) + ) + optimizer.tree_init(split_w) + trainer = Trainer(model_with_loss, optimizer) + rng_step1 = fastmath.random.get_prng(7) + trainer.one_step(labeled_batch, rng_step1) + # Reset shards back to default. + base.N_WEIGHTS_SHARDS = 1 + + def test_run_reversible_slots(self): + """Tests that slots can be read and assigned in reversible trainers.""" + layers = [tl.Dense(4), tl.Dup()] + rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()] + loss_layer = tl.Serial( + tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(), tl.CrossEntropyLoss() + ) + trainer = ReversibleSerialTrainer( + [(layers, rev_layers)], loss_layer, optimizers.Adam + ) + slots = trainer.slots + trainer.slots = slots + self.assertEqual(slots, trainer.slots) + + def test_run_reversible_same_as_default_basic(self): + """Runs the reversible trainers, check results are the same as default.""" + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + # We want to test rng propagation too, so adding some dropout layers. + first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) + rev_layers = [ + tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), + tl.ReversibleSwap(), + tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), + tl.ReversibleSwap(), + ] + loss_layer = tl.Serial( + tl.Concatenate(), + tl.Dense(19), + tl.Dropout(0.3), + tl.LogSoftmax(), + tl.CrossEntropyLoss(), + ) + model = tl.Serial([first_layer] + rev_layers + [loss_layer]) + rng_init = fastmath.random.get_prng(12) + model.init(labeled_batch, rng=rng_init) + optimizer_fn = optimizers.Adam # to test slots + + # Make 2 steps with the original trainers. + optimizer = optimizer_fn() + optimizer.tree_init(model.weights) + trainer = Trainer(model, optimizer) + rng_step1 = fastmath.random.get_prng(7) + rng_step2 = fastmath.random.get_prng(8) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + first_layer_weights1 = first_layer.weights + rev_layer0_weights1 = rev_layers[0].weights + rev_layer2_weights1 = rev_layers[2].weights + loss_layer_weights1 = loss_layer.weights + + # Now make 2 steps with reversible trainers. + model.init(labeled_batch, rng=rng_init) + trainer = ReversibleSerialTrainer( + [(first_layer.sublayers, rev_layers)], loss_layer, optimizer_fn + ) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + + # Check that weights end up the same. 
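+        # Both runs start from the same rng_init and use identical step rngs
+        # and learning rates, so the reversible trainer should reproduce the
+        # default trainer's weights up to the tolerance in _assert_all_equal.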
+ self._assert_all_equal(loss_layer_weights1, loss_layer.weights) + self._assert_all_equal(rev_layer2_weights1, rev_layers[2].weights) + self._assert_all_equal(rev_layer0_weights1, rev_layers[0].weights) + self._assert_all_equal(first_layer_weights1, first_layer.weights) + + def test_run_reversible_same_as_default_extended(self): + """Runs the reversible trainers, check results are the same as default.""" + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + # We want to test rng propagation too, so adding some dropout layers. + first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) + rev_layers1 = [ + tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), + tl.ReversibleSwap(), + tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), + tl.ReversibleSwap(), + ] + mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup()) + rev_layers2 = [ + tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)), + tl.ReversibleSwap(), + ] + loss_layer = tl.Serial( + tl.Concatenate(), + tl.Dense(19), + tl.Dropout(0.3), + tl.LogSoftmax(), + tl.CrossEntropyLoss(), + ) + model = tl.Serial( + [first_layer] + rev_layers1 + [mid_layer] + rev_layers2 + [loss_layer] + ) + rng_init = fastmath.random.get_prng(12) + model.init(labeled_batch, rng=rng_init) + optimizer_fn = optimizers.Adam # to test slots + + # Make 3 steps with the original trainers. + optimizer = optimizer_fn() + optimizer.tree_init(model.weights) + trainer = Trainer(model, optimizer) + rng_step1 = fastmath.random.get_prng(7) + rng_step2 = fastmath.random.get_prng(8) + rng_step3 = fastmath.random.get_prng(9) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + first_layer_weights1 = first_layer.weights + rev_layer12_weights1 = rev_layers1[2].weights + mid_layer_weights1 = mid_layer.weights + rev_layer20_weights1 = rev_layers2[0].weights + loss_layer_weights1 = loss_layer.weights + + # Now make 3 steps with reversible trainers. + model.init(labeled_batch, rng=rng_init) + # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why? + trainer = ReversibleSerialTrainer( + [(first_layer.sublayers, rev_layers1), (mid_layer.sublayers, rev_layers2)], + loss_layer, + optimizer_fn, + memoize_jit=False, + ) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + + # Check that weights end up the same. + self._assert_all_equal(loss_layer_weights1, loss_layer.weights) + self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights) + self._assert_all_equal(mid_layer_weights1, mid_layer.weights) + self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights) + self._assert_all_equal(first_layer_weights1, first_layer.weights) + + def test_run_reversible_same_as_default_terraformer(self): + """Runs the reversible trainers, check results are the same as default.""" + inputs_batch = np.arange(8).reshape((2, 4)) + 1 + targets_batch = 2 * inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) + input_sig = (int_sig, int_sig, int_sig) + # We want to test rng propagation too, so adding some dropout layers. 
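+        # Below, extract_reversible_blocks splits [model, loss] into pairs of
+        # (standard layers, reversible layers), and loss_chunk_size=4 chunks
+        # the final loss computation to lower peak memory.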
+ model = terraformer.ConfigurableTerraformer( + 20, + d_model=8, + d_ff=32, + n_heads=1, + dropout=0.0, + n_encoder_layers=2, + n_decoder_layers=2, + ff_sparsity=(4, 8, 0.0, 1.0), + pos_type=None, + reversible_encoder=True, + ) + loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) + optimizer_fn = optimizers.Adafactor + blocks, loss_layer = extract_reversible_blocks([model, loss], loss_chunk_size=4) + blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks] + model_with_loss = tl.Serial(model, loss) + rng_init = fastmath.random.get_prng(12) + model_with_loss.init(input_sig, rng=rng_init) + + # Make 3 steps with the original trainers. + optimizer = optimizer_fn() + optimizer.tree_init(model_with_loss.weights) + trainer = Trainer(model_with_loss, optimizer) + rng_step1 = fastmath.random.get_prng(7) + rng_step2 = fastmath.random.get_prng(8) + rng_step3 = fastmath.random.get_prng(9) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + first_weights = blocks_serial[0][0].weights + first_rev_weights = blocks[0][1][0].weights + loss_weights = loss_layer.weights + + # Now make 3 steps with reversible trainers. + model_with_loss.init(input_sig, rng=rng_init) + trainer = ReversibleSerialTrainer(blocks, loss_layer, optimizer_fn) + trainer.one_step(labeled_batch, rng_step1) + trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) + trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) + + # Check that weights end up the same. + self._assert_all_equal(loss_weights, loss_layer.weights) + self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights) + self._assert_all_equal(first_weights, blocks_serial[0][0].weights) + + def test_run_reversible_large_weights(self): + """Runs the reversible trainers with a lot of weights to test memory use.""" + # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU + # and CPU when you run it locally, but it's too big for unit-testing. + ram_limited = True # Set to False to run this test locally. + if fastmath.global_device_count() == 1 and ram_limited: + return + + # Create inputs and rngs. + inputs_batch = np.arange(8).reshape((2, 4)) + targets_batch = inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup()) + rng_init = fastmath.random.get_prng(12) + rng_step = fastmath.random.get_prng(13) + + # Initialize layers. + first_layer.init(labeled_batch, rng=rng_init) + n_layers = 18 # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram + rev_layers = [] + int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32) + shape = shapes.ShapeDtype((2, 4, 16 * 1024)) + sig = (shape, shape) + for _ in range(n_layers): + layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024)) + layer.init(sig, rng=rng_init) + layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory + rev_layers.append(layer) + rev_layers.append(tl.ReversibleSwap()) + loss_layer = tl.Serial( + tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(), tl.CrossEntropyLoss() + ) + loss_layer.init((shape, shape, int_shape, int_shape)) + optimizer_fn = optimizers.Adafactor + + # Make a step with reversible trainers. + trainer = ReversibleSerialTrainer( + [(first_layer, rev_layers)], loss_layer, optimizer_fn + ) + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. 
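+        # The bound above is deliberately loose: this test exercises memory
+        # handling (weights are kept in host memory via tl.on_cpu), not the
+        # numerical value of the loss.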
+ # Set to true to run again, e.g., for profiling. + run_twice = False + if run_twice: + t = time.time() + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + print("Took %.3f seconds to run, loss %s" % (time.time() - t, loss)) + + def test_run_reversible_weights_transfer_xprof(self): + """Runs the reversible trainers and profiles weight transfer stats.""" + run_this_test = False # We only run this test manually. + if not run_this_test or fastmath.global_device_count() == 1: # TPU only + return + + # Create inputs and rngs. + inputs_batch = np.ones((1024, 128), dtype=np.int32) + targets_batch = inputs_batch + labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) + first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup()) + rng_init = fastmath.random.get_prng(12) + rng_step = fastmath.random.get_prng(13) + + # Initialize layers. + first_layer.init(labeled_batch, rng=rng_init) + n_layers = 6 + rev_layers = [] + int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32) + shape = shapes.ShapeDtype((1024, 128, 1024)) + sig = (shape, shape) + for _ in range(n_layers): + layer = tl.ReversibleHalfResidual(tl.Dense(1024)) + layer.init(sig, rng=rng_init) + layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory + rev_layers.append(layer) + rev_layers.append(tl.ReversibleSwap()) + loss_layer = tl.Serial( + tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(), tl.CrossEntropyLoss() + ) + loss_layer.init((shape, shape, int_shape, int_shape)) + optimizer_fn = optimizers.SGD + + # Make a step with reversible trainers. + trainer = ReversibleSerialTrainer( + [(first_layer, rev_layers)], loss_layer, optimizer_fn + ) + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + # We profile here. + t = time.time() + loss, _ = trainer.one_step(labeled_batch, rng_step) + self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. + print("Took %.3f seconds to run, loss %s" % (time.time() - t, loss)) + + +if __name__ == "__main__": + config.config_with_absl() + absltest.main() diff --git a/trax/import_test.py b/tests/utils/import_test.py similarity index 69% rename from trax/import_test.py rename to tests/utils/import_test.py index 00051b8fb..07920e91d 100644 --- a/trax/import_test.py +++ b/tests/utils/import_test.py @@ -19,18 +19,15 @@ class ImportTest(absltest.TestCase): + def test_import_trax(self): + try: + # Import trax + import trax # pylint: disable=g-import-not-at-top - def test_import_trax(self): - try: - # Import trax - import trax # pylint: disable=g-import-not-at-top - # Access a few symbols. - dir(trax.fastmath) - dir(trax.layers) - dir(trax.models) - except ImportError as e: - raise e + # Access a few symbols. + dir(trax.fastmath) + dir(trax.layers) + dir(trax.models) + except ImportError as e: + raise e -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/tests/utils/shapes_test.py b/tests/utils/shapes_test.py new file mode 100644 index 000000000..d333fb64e --- /dev/null +++ b/tests/utils/shapes_test.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.shapes.""" +import numpy as np + +from absl.testing import absltest + +from trax.utils import shapes +from trax.utils.shapes import ShapeDtype + + +class ShapesTest(absltest.TestCase): + def test_constructor_and_read_properties(self): + sd = ShapeDtype((2, 3), np.int32) + self.assertEqual(sd.shape, (2, 3)) + self.assertEqual(sd.dtype, np.int32) + + def test_default_dtype_is_float32(self): + sd = ShapeDtype((2, 3)) + self.assertEqual(sd.shape, (2, 3)) + self.assertEqual(sd.dtype, np.float32) + + def test_signature_on_ndarray(self): + array = np.array([[2, 3, 5, 7], [11, 13, 17, 19]], dtype=np.int16) + sd = shapes.signature(array) + self.assertEqual(sd.shape, (2, 4)) + self.assertEqual(sd.dtype, np.int16) + + def test_shape_dtype_repr(self): + sd = ShapeDtype((2, 3)) + repr_string = "{}".format(sd) + self.assertEqual( + repr_string, "ShapeDtype{shape:(2, 3), dtype:<class 'numpy.float32'>}" + ) + + def test_splice_signatures(self): + sd1 = ShapeDtype((1,)) + sd2 = ShapeDtype((2,)) + sd3 = ShapeDtype((3,)) + sd4 = ShapeDtype((4,)) + sd5 = ShapeDtype((5,)) + + # Signatures can be ShapeDtype instances, tuples of 2+ ShapeDtype instances, + # or empty tuples. + sig1 = sd1 + sig2 = (sd2, sd3, sd4) + sig3 = () + sig4 = sd5 + spliced = shapes.splice_signatures(sig1, sig2, sig3, sig4) + self.assertEqual(spliced, (sd1, sd2, sd3, sd4, sd5)) + + def test_len_signature(self): + """Signatures of all sizes should give correct length when asked.""" + x1 = np.array([1, 2, 3]) + x2 = np.array([10, 20, 30]) + inputs0 = () + inputs1 = x1 # NOT in a tuple + inputs2 = (x1, x2) + + sig0 = shapes.signature(inputs0) + sig1 = shapes.signature(inputs1) + sig2 = shapes.signature(inputs2) + + # pylint: disable=g-generic-assert + self.assertEqual(len(sig0), 0) + self.assertEqual(len(sig1), 1) + self.assertEqual(len(sig2), 2) + # pylint: enable=g-generic-assert + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/utils/trax2keras_test.py b/tests/utils/trax2keras_test.py new file mode 100644 index 000000000..027bd2505 --- /dev/null +++ b/tests/utils/trax2keras_test.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Tests for trax2keras.""" + +import os + +import numpy as onp +import tensorflow.compat.v2 as tf + +from absl.testing import absltest, parameterized + +import trax + +from trax import fastmath as math_lib +from trax import layers +from trax.fastmath import numpy as jnp +from trax.models import mlp, transformer +from trax.utils import trax2keras +from trax.utils.trax2keras import read_values, to_arrays, to_tensors + +tf.enable_v2_behavior() + + +def has_gpu(): + return bool(tf.config.list_physical_devices("GPU")) + + +def dummy_inputs(rng, input_sig): + def f(sig): + shape = sig.shape + if shape and shape[0] is None: + shape = (2,) + tuple(shape[1:]) + if onp.issubdtype(sig.dtype, onp.integer): + minval = 1 + # Must specify maxval for integer dtype. + # TODO(afrozm): Revisit after TF 2.3 + maxval = 10000 + else: + minval = 0 + maxval = 1 + return rng.uniform(shape=shape, dtype=sig.dtype, minval=minval, maxval=maxval) + + return math_lib.nested_map(f, input_sig) + + +def Mod(n): # pylint: disable=invalid-name + return layers.Fn("Mod", lambda x: x % n) + + +# Format: +# (trax-layer maker, input shapes, input dtype, can handle None batch size?) +_LAYERS = [ + (lambda: layers.Dense(3), tf.TensorShape([4]), onp.float32, True), + (mlp.MLP, tf.TensorShape([4]), onp.float32, False), + ( + lambda: layers.Serial(Mod(8), transformer.TransformerLM(8)), + tf.TensorShape([4]), + onp.int32, + False, + ), +] + +_RNG_UPDATERS = [ + lambda x: x, + lambda rng: math_lib.random.split(rng, 1)[0], +] + + +# Needs tf.test.TestCase for `assertAllClose` and `get_temp_dir` +class Trax2KerasTest(tf.test.TestCase, parameterized.TestCase): + def _assert_nested_close(self, expected, actual, **kwargs): + """Assert that two pytrees match element-wise.""" + + tf.nest.assert_same_structure(expected, actual) + for exp_leaf, act_leaf in zip( + tf.nest.flatten(expected), tf.nest.flatten(actual) + ): + exp_tensor = tf.convert_to_tensor(exp_leaf) + act_tensor = tf.convert_to_tensor(act_leaf) + super().assertAllClose(exp_tensor, act_tensor, **kwargs) + + @parameterized.named_parameters( + [ + { + "testcase_name": "_%s_%s_%s_%s_%s_%s" + % ( # pylint: disable=g-complex-comprehension + layer_id, + rng_updater_id, + batch_size, + trax_has_weights, + explicit_build, + use_model, + ), + "layer_id": layer_id, + "rng_updater_id": rng_updater_id, + "batch_size": batch_size, + "trax_has_weights": trax_has_weights, + "explicit_build": explicit_build, + "use_model": use_model, + } + for use_model in [True, False] + for explicit_build in [True, False] + for trax_has_weights in [True, False] + for batch_size in [2, None] + for rng_updater_id in [1] + for layer_id in range(len(_LAYERS)) + ] + ) + def testTrain( + self, + layer_id, + rng_updater_id, + batch_size, + trax_has_weights, + explicit_build, + use_model, + ): + """Tests training (forward and backward pass) for AsKeras. + + Args: + layer_id: an integer, the index into `_LAYERS`. + rng_updater_id: an integer, the index into `_RNG_UPDATERS`. + batch_size: an integer or `None`, the value for the `batch_size` argument + in `AsKeras.__init__`. + trax_has_weights: bool, whether to make the trax layer contain weights at + the time when `AsKeras.build` is called. + explicit_build: bool, whether to explicitly call `AsKeras.build`. + use_model: bool, whether to build a `tf.keras.Model` out of the + `AsKeras` layer and use the model to do the training instead of + the bare layer. If `True`, we will also test checkpointing and restoring + using the model. 
+ """ + with trax.fastmath.use_backend("tensorflow-numpy"): + make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = _LAYERS[ + layer_id + ] + # We make a fresh trax layer for each test case, so that different test + # cases won't interfere with each other. + trax_layer = make_trax_layer() + if not allow_none_batch and batch_size is None: + self.skipTest("This Trax layer can't handle None batch size.") + rng_updater = _RNG_UPDATERS[rng_updater_id] + input_shapes = math_lib.nested_map( + lambda s: [batch_size] + s, input_shapes_no_batch + ) + input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype) + initializer_rng = math_lib.random.get_prng(765) + weights, state = trax_layer.init(input_sig, rng=initializer_rng) + generator = tf.random.Generator.from_seed(567) + + def get_inputs(): + return dummy_inputs(generator, input_sig) + + if trax_has_weights: + trax_layer(to_arrays(get_inputs()), weights=weights, state=state) + rng = math_lib.random.get_prng(1234) + keras_layer = trax2keras.AsKeras( + trax_layer, + batch_size=batch_size, + initializer_rng=initializer_rng, + rng=rng, + rng_updater=rng_updater, + ) + if explicit_build: + keras_layer.build(input_shapes) + if use_model: + x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype) + y = keras_layer(x) + keras_model = tf.keras.Model(inputs=x, outputs=y) + lr = 0.1 # learning rate + for _ in range(3): + inputs = get_inputs() + with tf.GradientTape() as trax_tape: + trax_tape.watch(tf.nest.flatten(weights)) + trax_outputs, state = trax_layer.pure_fn( + to_arrays(inputs), weights=weights, state=state, rng=rng + ) + trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights])) + # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor` + # before multiplication. + weights = tf.nest.map_structure( + lambda w, g: w + jnp.asarray(lr * tf.convert_to_tensor(g), w.dtype), + weights, + trax_grads, + ) + rng = rng_updater(rng) + with tf.GradientTape() as keras_tape: + if use_model: + keras_outputs = keras_model(inputs) + else: + keras_outputs = keras_layer(inputs) + if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1: + keras_outputs = keras_outputs[0] + self.assertAllClose(to_tensors(trax_outputs), keras_outputs, atol=1e-5) + keras_grads = keras_tape.gradient( + keras_outputs, keras_layer.trainable_variables + ) + tf.nest.map_structure( + lambda v, g: v.assign_add( # pylint: disable=g-long-lambda + tf.cast(lr * tf.convert_to_tensor(g), v.dtype) + ), + keras_layer.trainable_variables, + keras_grads, + ) + self._assert_nested_close( + to_tensors(weights), + read_values(keras_layer._weights), + rtol=2e-6, + atol=4.5e-4 if has_gpu() else 1e-6, + ) + self._assert_nested_close( + to_tensors(state), read_values(keras_layer._state) + ) + self._assert_nested_close( + to_tensors(rng), read_values(keras_layer._rng) + ) + if use_model: + fname = os.path.join(self.get_temp_dir(), "checkpoint") + keras_model.save(fname) + loaded_model = tf.keras.models.load_model(fname) + for _ in range(2): + inputs = get_inputs() + self.assertAllClose(keras_model(inputs), loaded_model(inputs)) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/__init__.py b/trax/__init__.py index 8747ec520..48a3348d2 100644 --- a/trax/__init__.py +++ b/trax/__init__.py @@ -14,13 +14,3 @@ # limitations under the License. 
"""Trax top level import.""" - -from trax import data -from trax import fastmath -from trax import layers -from trax import models -from trax import optimizers -from trax import shapes -from trax import supervised -from trax.supervised import lr_schedules as lr -from trax.trax2keras import AsKeras diff --git a/trax/data/__init__.py b/trax/data/__init__.py index 9f1ed919b..614272206 100644 --- a/trax/data/__init__.py +++ b/trax/data/__init__.py @@ -31,74 +31,10 @@ """ -from trax.data.debug_data_pipeline import debug_pipeline -from trax.data.inputs import add_loss_weights -from trax.data.inputs import addition_inputs -from trax.data.inputs import AddLossWeights -from trax.data.inputs import AppendValue -from trax.data.inputs import batch -from trax.data.inputs import Batch -from trax.data.inputs import bucket_by_length -from trax.data.inputs import BucketByLength -from trax.data.inputs import CastTo -from trax.data.inputs import ConcatenateToLMInput -from trax.data.inputs import consume_noise_mask -from trax.data.inputs import CountAndSkip -from trax.data.inputs import Dup -from trax.data.inputs import FilterByLength -from trax.data.inputs import FilterEmptyExamples -from trax.data.inputs import generate_random_noise_mask -from trax.data.inputs import generate_sequential_chunks -from trax.data.inputs import Log -from trax.data.inputs import MLM -from trax.data.inputs import PadToLength -from trax.data.inputs import Parallel -from trax.data.inputs import Prefetch -from trax.data.inputs import PrefixLM -from trax.data.inputs import random_spans_noise_mask -from trax.data.inputs import sequence_copy_inputs -from trax.data.inputs import Serial -from trax.data.inputs import shuffle -from trax.data.inputs import Shuffle -from trax.data.inputs import simple_sequence_copy_inputs -from trax.data.inputs import sine_inputs -from trax.data.inputs import TruncateToLength -from trax.data.inputs import UnBatch -from trax.data.inputs import UniformlySeek -from trax.data.tf_inputs import add_eos_to_output_features -from trax.data.tf_inputs import BertGlueEvalStream -from trax.data.tf_inputs import BertGlueTrainStream -from trax.data.tf_inputs import BertNextSentencePredictionInputs -from trax.data.tf_inputs import cifar10_augmentation_flatten_preprocess -from trax.data.tf_inputs import cifar10_augmentation_preprocess -from trax.data.tf_inputs import ConvertToUnicode -from trax.data.tf_inputs import CorpusToRandomChunks -from trax.data.tf_inputs import CreateAnnotatedDropInputs -from trax.data.tf_inputs import CreateAquaInputs -from trax.data.tf_inputs import CreateBertInputs -from trax.data.tf_inputs import CreateDropInputs -from trax.data.tf_inputs import CreateMathQAInputs -from trax.data.tf_inputs import data_streams -from trax.data.tf_inputs import detokenize -from trax.data.tf_inputs import downsampled_imagenet_flatten_bare_preprocess -from trax.data.tf_inputs import filter_dataset_on_len -from trax.data.tf_inputs import lm1b_preprocess -from trax.data.tf_inputs import mask_random_tokens -from trax.data.tf_inputs import sentencepiece_tokenize -from trax.data.tf_inputs import SentencePieceTokenize -from trax.data.tf_inputs import squeeze_targets_preprocess -from trax.data.tf_inputs import T5GlueEvalStream -from trax.data.tf_inputs import T5GlueEvalStreamsParallel -from trax.data.tf_inputs import T5GlueEvalTasks -from trax.data.tf_inputs import T5GlueTrainStream -from trax.data.tf_inputs import T5GlueTrainStreamsParallel -from trax.data.tf_inputs import TFDS -from trax.data.tf_inputs import tokenize -from 
trax.data.tf_inputs import Tokenize -from trax.data.tf_inputs import truncate_dataset_on_len -from trax.data.tf_inputs import vocab_size -from trax.data.tf_inputs import wmt_concat_preprocess -from trax.data.tf_inputs import wmt_preprocess + + + + diff --git a/trax/data/benchamrking/__init__.py b/trax/data/benchamrking/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/benchamrking/benchmark.py b/trax/data/benchamrking/benchmark.py new file mode 100644 index 000000000..af4add86c --- /dev/null +++ b/trax/data/benchamrking/benchmark.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorFlow data sources and associated preprocessing functions.""" + + +import gin +import tensorflow_datasets as tfds + +from trax import data +from trax import layers as tl +from trax.data.encoder import encoder as tokenizer +from trax.data.loader.tf.base import TFDS, generic_text_dataset_preprocess_fn, t5_data +from trax.data.preprocessing.tf import bert as bert +from trax.fastmath import numpy as jnp +from trax.learning import supervised + +_GLUE_KEYS = { + "cola": ("sentence",), + "sst2": ("sentence",), + "mrpc": ("sentence1", "sentence2"), + "qqp": ("question1", "question2"), + "stsb": ("sentence1", "sentence2"), + "mnli": ("premise", "hypothesis"), + "qnli": ("question", "sentence"), + "rte": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + + +# Labels inferred from the T5 paper: https://arxiv.org/pdf/1910.10683.pdf +_GLUE_LABELS = { + "cola": ("unacceptable", "acceptable"), + "sst2": ("negative", "positive"), + "mrpc": ("not_equivalent", "equivalent"), + "qqp": ("not_duplicate", "duplicate"), + "stsb": ("sentence1", "sentence2"), + "mnli": ("entailment", "neutral", "contradiction"), + "qnli": ("entailment", "not_entailment"), + "rte": ("entailment", "not_entailment"), + "wnli": ("sentence1", "sentence2"), +} + +# Defining separate TrainStream and EvalStream functions (below) +# makes gin configuration expressions more direct. A single gin line can +# configure each; for example: +# +# BertGlueTrainStream.benchmark = 'mnli' +# BertGlueEvalStream.benchmark = 'mnli' + + +# pylint: disable=invalid-name +@gin.configurable(module="trax.data") +def BertGlueTrainStream(benchmark=gin.REQUIRED): + """Returns a Bert-preprocessed training stream for ``benchmark``. + + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. + """ + return _BertGlueDataStream(benchmark + "_t") + + +# GLUE evals need special handling because one eval in particular, MNLI, has +# two different eval sets: "matched" and "mismatched". The code in this module +# distinguishes between the two using the suffixes '_e' versus '_e2', +# respectively. +def _ensure_eval_suffix(benchmark): + """Returns a string ending in an eval suffix; adds ``'_e'`` suffix if needed.
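+
+    For example, ``'cola'`` becomes ``'cola_e'``, while ``'rte_e'`` and
+    ``'mnli_e2'`` are returned unchanged.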
+ + Args: + benchmark: Name of a benchmark or task, that might already include an + eval-indicating suffix (``'_e'`` or ``'_e2'``). + """ + if benchmark.endswith("_e") or benchmark.endswith("_e2"): + return benchmark + else: + return benchmark + "_e" + + +@gin.configurable(module="trax.data") +def BertGlueEvalStream(benchmark=gin.REQUIRED): + """Returns a Bert-preprocessed eval data stream for ``benchmark``. + + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. If the benchmark includes an alternate + eval (e.g., MNLI's "mismatched" eval/validation split), you can + specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. + """ + return _BertGlueDataStream(_ensure_eval_suffix(benchmark)) + + +def _BertGlueDataStream(benchmark_id): + """Returns a Bert-preprocessed data stream for ``benchmark_id``. + + Args: + benchmark_id: String that indicates the name and data split of a GLUE + benchmark. Data splits are indicated as underscore suffixes, e.g., + ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE + benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, + alternate "mismatched" eval/validation split). + """ + benchmark_id = _ensure_eval_suffix(benchmark_id) + benchmark, split = benchmark_id.rsplit("_", 1) + glue_data = TFDS( + f"glue/{benchmark}", + keys=_GLUE_KEYS[benchmark], + train=(split == "t"), + use_alt_eval=(split == "e2"), + ) + return data.Serial( + glue_data, + tokenizer.Tokenize(), + bert.CreateBertInputs(), + data.Shuffle(), + data.PadToLength(), + data.TruncateToLength(), + data.Batch(), + ) + + +@gin.configurable(module="trax.data") +def T5GlueTrainStream(benchmark=gin.REQUIRED): + """Returns a T5-preprocessed training data stream for ``benchmark``. + + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. + """ + return _T5GlueDataStream(benchmark + "_t") + + +@gin.configurable(module="trax.data") +def T5GlueTrainStreamsParallel( + benchmark_list=gin.REQUIRED, + counters=None, + reweight_by_minimum=False, + gradually_reweight=False, +): + """Returns a parallel set of training streams, based on ``benchmark_list``. + + Args: + benchmark_list: List of simple lower-case names of GLUE benchmarks, e.g., + ``'cola'``, ``'mnli'``, ``'rte'``. + counters: a list of counters to be passed to data.Parallel, e.g., + [8551, 392702, 2490] would be a reasonable counterpart to + benchmark_list = ["cola", "mnli", "rte"], see + https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/glue_utils.py#L42 + for more details on counters. + reweight_by_minimum: divide by the minimal counter. + gradually_reweight: a more refined reweighting policy, see inputs.py + for more details. + """ + stream_list = list(map(T5GlueTrainStream, benchmark_list)) + return data.Parallel( + stream_list, + counters=counters, + reweight_by_minimum=reweight_by_minimum, + gradually_reweight=gradually_reweight, + )() + + +@gin.configurable(module="trax.data") +def T5GlueEvalStream(benchmark=gin.REQUIRED): + """Returns a T5-preprocessed eval data stream for ``benchmark``. + + Args: + benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, + ``'mnli'``, ``'rte'``. If the benchmark includes an alternate + eval (e.g., MNLI's "mismatched" eval/validation split), you can + specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. 
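+
+    As with the Bert streams above, a single gin line can configure this
+    stream, e.g.: ``T5GlueEvalStream.benchmark = 'mnli_e2'``.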
+ """ + return _T5GlueDataStream(_ensure_eval_suffix(benchmark)) + + +@gin.configurable(module="trax.data") +def T5GlueEvalStreamsParallel(benchmark_list=gin.REQUIRED): + """Returns a parallel set of T5 eval streams, based on ``benchmark_list``. + + Args: + benchmark_list: List of strings, each of which is a simple lower-case name + of a GLUE benchmark, e.g., ``'cola'``, ``'mnli'``, ``'rte'``. If a + benchmark includes an alternate eval (e.g., MNLI's "mismatched" + eval/validation split), you can specify it with an ``'_e2'`` suffix, + e.g., ``'mnli_e2'``. + """ + stream_list = list(map(T5GlueEvalStream, benchmark_list)) + return data.Parallel(stream_list)() + + +def _T5GlueDataStream(benchmark_id, t5_tokenization=False): + """Returns a T5-preprocessed data stream for ``benchmark_id``. + + Args: + benchmark_id: String that indicates the name and data split of a GLUE + benchmark. Data splits are indicated as underscore suffixes, e.g., + ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE + benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, + alternate "mismatched" eval/validation split). + t5_tokenization: if true, then use t5_tokenization. + """ + return data.Serial( + _t5_glue_data_split(benchmark_id) + if t5_tokenization + else _t5_glue_data_split_no_token(benchmark_id), + tokenizer.Tokenize(), + data.Shuffle(), + data.PadToLength(), + data.TruncateToLength(), + data.Batch(), + ) + + +@gin.configurable(module="trax.data") +def T5GlueEvalTasks(benchmark_list=gin.REQUIRED): + """Returns a list of T5 GLUE eval tasks, based on ``benchmark_list``. + + Args: + benchmark_list: List of strings, each of which indicates the name and + data split of a GLUE benchmark. Data splits are indicated as underscore + suffixes, e.g., ``'cola_t'`` (Cola benchmark, training split), + ``'rte_e'`` (RTE benchmark, eval/validation split), and ``'mnli_e2'`` + (MNLI alternate "mismatched" eval/validation split). 
+ """ + task_list = list(map(_T5GlueEvalTask, benchmark_list)) + return task_list + + +def _T5GlueEvalTask(benchmark_id): + """Returns a T5 GLUE eval task, based on ``benchmark_id``.""" + eval_data = T5GlueEvalStream(benchmark_id) + benchmark_id = _ensure_eval_suffix(benchmark_id) + metrics = [tl.WeightedCategoryAccuracy(), tl.SequenceAccuracy()] + benchmark, split = benchmark_id.rsplit("_", 1) + if benchmark == "cola": + name_upper = "Cola" + elif benchmark == "mnli": + name_upper = "MNLI_matched" if split == "e" else "MNLI_mismatched" + else: + name_upper = benchmark.upper() + return supervised.training.EvalTask( + eval_data(), + metrics, + metric_names=[f"{name_upper} accuracy", f"{name_upper} sequence accuracy"], + ) + + +def _t5_glue_data_split_no_token(benchmark_id): + """Returns a GLUE data split prepared with the standard T5 preprocessor.""" + benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) + dataset = tfds.load(name=f"glue/{benchmark}", split=split) + processed_dataset = t5_data().preprocessors.glue( # pylint: disable=g-long-lambda + dataset, benchmark_name=benchmark, label_names=_GLUE_LABELS[benchmark] + ) + + def stream_of_inputs_targets_weights(generator=None): + del generator + while True: + for example in processed_dataset: + input_values = example["inputs"].numpy() + target_values = example["targets"].numpy() + yield (input_values, target_values, jnp.array([1] * len(target_values))) + + return stream_of_inputs_targets_weights + + +def _t5_glue_data_split(benchmark_id): + """Returns a GLUE data split prepared with the standard T5 preprocessor.""" + benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) + dataset = tfds.load(name=f"glue/{benchmark}", split=split) + processed_dataset = generic_text_dataset_preprocess_fn( + dataset, + spm_path=t5_data().DEFAULT_SPM_PATH, + text_preprocess_fns=[ + lambda ds, training: t5_data().preprocessors.glue( # pylint: disable=g-long-lambda + ds, benchmark_name=benchmark, label_names=_GLUE_LABELS[benchmark] + ) + ], + copy_pretokenized=True, + debug_print_examples=True, + debug_print_examples_rate=0.05, + ) + dataset_as_numpy = tfds.as_numpy(processed_dataset) + + def stream_of_inputs_targets_weights(generator=None): + del generator + while True: + for example in dataset_as_numpy: + input_values = example["inputs"] + target_values = example["targets"] + yield ( + jnp.array(input_values), + jnp.array(target_values), + jnp.array([1] * len(target_values)), + ) + + return stream_of_inputs_targets_weights + + +def _t5_glue_benchmark_and_split(benchmark_id): + benchmark, mode = benchmark_id.rsplit("_", 1) + if mode == "t": + split = "train" + elif benchmark == "mnli": + split = "validation_mismatched" if mode == "e2" else "validation_matched" + else: + split = "validation" + return benchmark, split diff --git a/trax/data/builder/__init__.py b/trax/data/builder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/builder/subword.py b/trax/data/builder/subword.py new file mode 100644 index 000000000..654c0f836 --- /dev/null +++ b/trax/data/builder/subword.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Program to build a SubwordTextEncoder. + +The flags --min_count and --corpus_max_lines will affect the size of the +vocabulary. Try changing these flags until you get a vocabulary +of the size you want. + +Example usage: + +python trax/data/builder/subword.py \ + --corpus_filepattern=$DATA_DIR/my_problem-train-* \ + --corpus_max_lines=12345 \ + --output_filename=$DATA_DIR/my_problem.subword_text_encoder \ + --logtostderr + +""" + +from absl import app, flags + +from trax.data.encoder import encoder as text_encoder +from trax.data.encoder import encoder as tokenizer + +flags.DEFINE_string( + "output_filename", + "/tmp/my.subword_text_encoder", + "where to store the SubwordTextEncoder", +) +flags.DEFINE_string("corpus_filepattern", "", "Corpus of one or more text files") +flags.DEFINE_string( + "vocab_filepattern", + "", + "One or more vocabulary files " '(one word per line as "word,count")', +) +flags.DEFINE_integer("min_count", 5, "Minimum subtoken count in corpus") +flags.DEFINE_integer("corpus_max_lines", 10000, "How many lines of corpus to read") +flags.DEFINE_integer("num_iterations", 4, "Number of iterations") +flags.DEFINE_bool("split_on_newlines", True, "Break corpus into lines.") + +FLAGS = flags.FLAGS + + +def main(unused_argv): + if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern: + raise ValueError( + "Must only provide one of --corpus_filepattern or --vocab_filepattern" + ) + + elif FLAGS.corpus_filepattern: + token_counts = tokenizer.corpus_token_counts( + FLAGS.corpus_filepattern, + FLAGS.corpus_max_lines, + split_on_newlines=FLAGS.split_on_newlines, + ) + + elif FLAGS.vocab_filepattern: + token_counts = tokenizer.vocab_token_counts( + FLAGS.vocab_filepattern, FLAGS.corpus_max_lines + ) + + else: + raise ValueError( + "Must provide one of --corpus_filepattern or --vocab_filepattern" + ) + + encoder = text_encoder.SubwordTextEncoder() + encoder.build_from_token_counts(token_counts, FLAGS.min_count, FLAGS.num_iterations) + encoder.store_to_file(FLAGS.output_filename) + + +if __name__ == "__main__": + app.run(main) diff --git a/trax/data/debugger/__init__.py b/trax/data/debugger/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/debug_data_pipeline.py b/trax/data/debugger/data_pipeline.py similarity index 52% rename from trax/data/debug_data_pipeline.py rename to trax/data/debugger/data_pipeline.py index 7506149cc..15d0b0ece 100644 --- a/trax/data/debug_data_pipeline.py +++ b/trax/data/debugger/data_pipeline.py @@ -17,25 +17,27 @@ import functools -from absl import logging import gin +from absl import logging + + +@gin.configurable(denylist=["f"]) +def debug_pipeline(f, debug=False, method="pow", log_prefix=None): + """Decorator for input pipeline generators that logs examples at intervals.""" + if not debug: + return f + + assert method in ("pow", "every") + + @functools.wraps(f) + def wrapper(*args, **kwargs): + count = 0 + prefix = log_prefix or f.__name__ + for example in f(*args, **kwargs): + count += 1 + # Log at powers of two: (count & (count - 1)) == 0 iff count is a power of 2. + if method == "every" or (method == "pow" and (count & (count - 1)) == 0): + logging.info("%s example[%d] = %r", prefix, count,
example) + yield example -@gin.configurable(denylist=['f']) -def debug_pipeline(f, debug=False, method='pow', log_prefix=None): - """Decorator for input pipeline generators that logs examples at intervals.""" - if not debug: - return f - - assert method in ('pow', 'every') - @functools.wraps(f) - def wrapper(*args, **kwargs): - count = 0 - prefix = log_prefix or f.__name__ - for example in f(*args, **kwargs): - count += 1 - if method == 'every' or (method == 'pow' and (count & count - 1 == 0)): - logging.info('%s example[%d] = %r', prefix, count, example) - yield example - - return wrapper + return wrapper diff --git a/trax/data/encoder/__init__.py b/trax/data/encoder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/encoder/encoder.py b/trax/data/encoder/encoder.py new file mode 100644 index 000000000..add5e73f9 --- /dev/null +++ b/trax/data/encoder/encoder.py @@ -0,0 +1,1767 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Encoders for text data. + +* TextEncoder: base class +* ByteTextEncoder: for ascii text +* TokenTextEncoder: with user-supplied vocabulary file +* SubwordTextEncoder: invertible +* BertEncoder: for compatible tokenizers with original bert +""" +import collections +import itertools +import math +import os +import re +import sys +import tempfile +import time +import unicodedata + +import gin +import numpy as np +import six +import tensorflow as tf +import tensorflow_text as tft + +from absl import logging + +from trax.data.debugger import data_pipeline as debug_data_pipeline + +# Reserved tokens for things like padding and EOS symbols. +PAD = "<pad>" +EOS = "<EOS>" +RESERVED_TOKENS = [PAD, EOS] +NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) +PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 +EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 +RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] + +# Regular expression for unescaping token strings. +# '\u' is converted to '_' +# '\\' is converted to '\' +# '\213;' is converted to unichr(213) +_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") +_ESCAPE_CHARS = set("\\_u;0123456789") + +# This set contains all letter and number characters.
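+# (i.e., every code point whose Unicode category starts with "L" or "N").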
+_ALPHANUMERIC_CHAR_SET = set( + six.unichr(i) + for i in range(sys.maxunicode) + if ( + unicodedata.category(six.unichr(i)).startswith("L") + or unicodedata.category(six.unichr(i)).startswith("N") + ) +) + + +# Unicode utility functions that work with Python 2 and 3 +def native_to_unicode(s): + if is_unicode(s): + return s + try: + return to_unicode(s) + except UnicodeDecodeError: + res = to_unicode(s, ignore_errors=True) + logging.info("Ignoring Unicode error, outputting: %s", res) + return res + + +def is_unicode(s): + return isinstance(s, six.text_type) + + +def to_unicode(s, ignore_errors=False): + if is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) + + +def to_unicode_ignore_errors(s): + return to_unicode(s, ignore_errors=True) + + +def to_unicode_utf8(s): + return s.decode("utf-8") + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +def strip_ids(ids, ids_to_strip): + """Strip ids_to_strip from the end of ids.""" + ids = list(ids) + while ids and ids[-1] in ids_to_strip: + ids.pop() + return ids + + +def _escape_token(token, alphabet): + """Escape away underscores and OOV characters and append '_'. + + This allows the token to be expressed as the concatenation of a list + of subtokens from the vocabulary. The underscore acts as a sentinel + which allows us to invertibly concatenate multiple such lists. + + Args: + token: A unicode string to be escaped. + alphabet: A set of all characters in the vocabulary's alphabet. + + Returns: + escaped_token: An escaped unicode string. + + Raises: + ValueError: If the provided token is not unicode. + """ + if not isinstance(token, six.text_type): + raise ValueError("Expected string type for token, got %s" % type(token)) + + token = token.replace("\\", "\\\\").replace("_", "\\u") + ret = [c if c in alphabet and c != "\n" else r"\%d;" % ord(c) for c in token] + return "".join(ret) + "_" + + +def _unescape_token(escaped_token): + """Inverse of _escape_token(). + + Args: + escaped_token: a unicode string + + Returns: + token: a unicode string + """ + + def match(m): + if m.group(1) is None: + return "_" if m.group(0) == "\\u" else "\\" + + try: + return six.unichr(int(m.group(1))) + except (ValueError, OverflowError) as _: + return "\u3013" # Unicode for undefined character. + + trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token + return _UNESCAPE_REGEX.sub(match, trimmed) + + +def _bert_is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _bert_is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _bert_is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class, but we treat them as punctuation anyway, for + # consistency. + if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True): + """Reads files matching a wildcard pattern, yielding the contents. + + Args: + filepattern: A wildcard pattern matching one or more files. + max_lines: If set, stop reading after reading this many lines. + split_on_newlines: A boolean. If true, then split files by lines and strip + leading and trailing whitespace from each line. Otherwise, treat each + file as a single string. + + Yields: + The contents of the files as lines, if split_on_newlines is True, or + the entire contents of each file if False. + """ + filenames = sorted(tf.io.gfile.glob(filepattern)) + lines_read = 0 + for filename in filenames: + with tf.io.gfile.GFile(filename) as f: + if split_on_newlines: + for line in f: + yield line.strip() + lines_read += 1 + if max_lines and lines_read >= max_lines: + return + + else: + if max_lines: + doc = [] + for line in f: + doc.append(line) + lines_read += 1 + if max_lines and lines_read >= max_lines: + yield "".join(doc) + return + yield "".join(doc) + + else: + yield f.read() + + +def corpus_token_counts(text_filepattern, corpus_max_lines, split_on_newlines=True): + """Read the corpus and compute a dictionary of token counts. + + Args: + text_filepattern: A pattern matching one or more files. + corpus_max_lines: An integer; maximum total lines to read. + split_on_newlines: A boolean. If true, then split files by lines and strip + leading and trailing whitespace from each line. Otherwise, treat each + file as a single string. + + Returns: + a dictionary mapping token to count. + """ + counts = collections.Counter() + for doc in _read_filepattern( + text_filepattern, + max_lines=corpus_max_lines, + split_on_newlines=split_on_newlines, + ): + counts.update(encode(doc)) + + return counts + + +def vocab_token_counts(text_filepattern, max_lines): + """Read a vocab file and return a dictionary of token counts. + + Reads a two-column CSV file of tokens and their frequency in a dataset. The + tokens are presumed to be generated by encode() or the equivalent. + + Args: + text_filepattern: A pattern matching one or more files. + max_lines: An integer; maximum total lines to read. + + Returns: + a dictionary mapping token to count. + """ + ret = {} + for i, line in enumerate(_read_filepattern(text_filepattern, max_lines=max_lines)): + if "," not in line: + logging.warning("Malformed vocab line #%d '%s'", i, line) + continue + + token, count = line.rsplit(",", 1) + ret[token] = int(count) + + return ret + + +def _get_vocab(vocab_type="subword", vocab_file=None, vocab_dir=None, extra_ids=0): + """Gets the vocabulary object for tokenization; see tokenize for details.""" + if vocab_type not in ["char", "subword", "sentencepiece", "bert", "bert-lowercase"]: + raise ValueError( + 'vocab_type must be "subword", "char", "sentencepiece", "bert" or "bert-lowercase" ' + f"but got {vocab_type}" + ) + + if vocab_type == "char": + # Note that we set num_reserved_ids=0 below. We could instead pass + # the value n_reserved_ids from tokenize here -- ByteTextEncoder does + # exactly the same thing as tokenize above, i.e., adds num_reserved_ids.
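+        # For instance, a (hypothetical) call _get_vocab(vocab_type="char")
+        # returns a byte-level encoder with no reserved-ID offset applied here.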
+ return ByteTextEncoder(num_reserved_ids=0) + + vocab_dir = vocab_dir or "gs://trax-ml/vocabs/" + path = os.path.join(vocab_dir, vocab_file) + + if vocab_type == "subword": + return SubwordTextEncoder(path) + + if vocab_type == "bert": + return BertEncoder(path, do_lower_case=False) + + if vocab_type == "bert-lowercase": + return BertEncoder(path, do_lower_case=True) + + if vocab_type == "sentencepiece": + return SentencePieceEncoder(path, extra_ids=extra_ids) + + return None + + +def vocab_size(vocab_type="subword", vocab_file=None, vocab_dir=None, n_reserved_ids=0): + """Returns the size of the vocabulary (number of symbols used). + + This function can be used to set the size of the final layers of a model that + needs to predict symbols from a given vocabulary. More precisely, if this + function returns N then the last layer size should be set to at least N (it + can be more). Note that this function does take reserved IDs into account. + + Args: + vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. + vocab_file: Name of the vocabulary file. + vocab_dir: Directory which contains the vocabulary file. + n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused. + + Returns: + An integer, the number of symbols used (including reserved IDs). + """ + vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) + return vocab.vocab_size + n_reserved_ids + + +""" + Encoders and decoders +""" + + +def encode(text): + """Encode a unicode string as a list of tokens. + + Args: + text: a unicode string + Returns: + a list of tokens as Unicode strings + """ + if not text: + return [] + ret = [] + token_start = 0 + # Classify each character in the input string + is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text] + for pos in range(1, len(text)): + if is_alnum[pos] != is_alnum[pos - 1]: + token = text[token_start:pos] + if token != " " or token_start == 0: + ret.append(token) + token_start = pos + final_token = text[token_start:] + ret.append(final_token) + return ret + + +def decode(tokens): + """Decode a list of tokens to a unicode string. + + Args: + tokens: a list of Unicode strings + Returns: + a unicode string + """ + token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] + ret = [] + for i, token in enumerate(tokens): + if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: + ret.append(" ") + ret.append(token) + return "".join(ret) + + +class TextEncoder: + """Base class for converting from ints to/from human readable strings.""" + + def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): + self._num_reserved_ids = num_reserved_ids + + @property + def num_reserved_ids(self): + return self._num_reserved_ids + + def encode(self, s): + """Transform a human-readable string into a sequence of int IDs. + + The IDs should be in the range [num_reserved_ids, vocab_size). IDs [0, + num_reserved_ids) are reserved. + + EOS is not appended. + + Args: + s: human-readable string to be converted. + + Returns: + ids: list of integers + """ + return [int(w) + self._num_reserved_ids for w in s.split()] + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int IDs into a human-readable string. + + EOS is not expected in IDs. + + Args: + ids: list of integers to be converted. + strip_extraneous: bool, whether to strip off extraneous tokens (EOS and + PAD). + + Returns: + s: human-readable string. 
+ """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + """Transform a sequence of int IDs into a their string versions. + + This method supports transforming individual input/output IDs to their + string versions so that sequence to/from text conversions can be visualized + in a human readable format. + + Args: + ids: list of integers to be converted. + + Returns: + strs: list of human-readable string. + """ + decoded_ids = [] + for id_ in ids: + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) + else: + decoded_ids.append(id_ - self._num_reserved_ids) + return [str(d) for d in decoded_ids] + + @property + def vocab_size(self): + raise NotImplementedError() + + +class ByteTextEncoder(TextEncoder): + """Encodes each byte to an id. For 8-bit strings only.""" + + def encode(self, s): + numres = self._num_reserved_ids + # Python3: explicitly convert to UTF-8 + return [c + numres for c in s.encode("utf-8")] + + def decode(self, ids, strip_extraneous=False): + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + # Python3: join byte arrays and then decode string + return b"".join(decoded_ids).decode("utf-8", "replace") + + def decode_list(self, ids): + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + # Python3: join byte arrays and then decode string + return decoded_ids + + @property + def vocab_size(self): + return 2**8 + self._num_reserved_ids + + +class ClassLabelEncoder(TextEncoder): + """Encoder for class labels.""" + + def __init__(self, class_labels=None, class_labels_fname=None): + super(ClassLabelEncoder, self).__init__(num_reserved_ids=0) + + if class_labels_fname: + with tf.io.gfile.GFile(class_labels_fname) as f: + class_labels = [label.strip() for label in f.readlines()] + + assert class_labels + self._class_labels = class_labels + + def encode(self, s): + label_str = s + return self._class_labels.index(label_str) + + def decode(self, ids, strip_extraneous=False): + del strip_extraneous + label_id = ids + if isinstance(label_id, list): + assert len(label_id) == 1 + (label_id,) = label_id + if isinstance(label_id, np.ndarray): + label_id = np.squeeze(label_id) + return self._class_labels[label_id] + + def decode_list(self, ids): + return [self._class_labels[i] for i in ids] + + @property + def vocab_size(self): + return len(self._class_labels) + + +class OneHotClassLabelEncoder(ClassLabelEncoder): + """One-hot encoder for class labels.""" + + def encode( + self, label_str, on_value=1, off_value=0 + ): # pylint: disable=arguments-differ + e = np.full(self.vocab_size, off_value, dtype=np.int32) + e[self._class_labels.index(label_str)] = on_value + return e.tolist() + + def decode(self, ids, strip_extraneous=False): + del strip_extraneous + label_id = ids + if isinstance(label_id, np.ndarray): + label_id = np.squeeze(label_id).astype(np.int8).tolist() + assert isinstance(label_id, list) + assert len(label_id) == self.vocab_size + return self._class_labels[label_id.index(1)] + + @property + def 
vocab_size(self): + return len(self._class_labels) + + +class TokenTextEncoder(TextEncoder): + """Encoder based on a user-supplied vocabulary (file or list).""" + + def __init__( + self, + vocab_filename, + reverse=False, + vocab_list=None, + replace_oov=None, + num_reserved_ids=NUM_RESERVED_TOKENS, + ): + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. + + Args: + vocab_filename: If not None, the full filename to read vocab from. If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when encoding + will be replaced by this string (which must be in vocab). + num_reserved_ids: Number of IDs to save for reserved tokens like . + """ + super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + self._reverse = reverse + self._replace_oov = replace_oov + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) + + def encode(self, s): + """Converts a space-separated string of tokens to a list of ids.""" + sentence = s + tokens = sentence.strip().split() + if self._replace_oov is not None: + tokens = [ + t if t in self._token_to_id else self._replace_oov for t in tokens + ] + ret = [self._token_to_id[tok] for tok in tokens] + return ret[::-1] if self._reverse else ret + + def decode(self, ids, strip_extraneous=False): + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + seq = reversed(ids) if self._reverse else ids + return [self._safe_id_to_token(i) for i in seq] + + @property + def vocab_size(self): + return len(self._id_to_token) + + def _safe_id_to_token(self, idx): + return self._id_to_token.get(idx, "ID_%d" % idx) + + def _init_vocab_from_file(self, filename): + """Load vocab from a file. + + Args: + filename: The file to load vocabulary from. + """ + with tf.io.gfile.GFile(filename) as f: + tokens = [token.strip() for token in f.readlines()] + + def token_gen(): + for token in tokens: + yield token + + self._init_vocab(token_gen(), add_reserved_tokens=False) + + def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. + + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. + + Args: + vocab_list: A list of tokens. 
+ """ + + def token_gen(): + for token in vocab_list: + if token not in RESERVED_TOKENS: + yield token + + self._init_vocab(token_gen()) + + def _init_vocab(self, token_generator, add_reserved_tokens=True): + """Initialize vocabulary with tokens from token_generator.""" + + self._id_to_token = {} + non_reserved_start_index = 0 + + if add_reserved_tokens: + self._id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) + + self._id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index) + ) + + # _token_to_id is the reverse of _id_to_token + self._token_to_id = dict((v, k) for k, v in six.iteritems(self._id_to_token)) + + def store_to_file(self, filename): + """Write vocab file to disk. + + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. + + Args: + filename: Full path of the file to store the vocab to. + """ + with tf.io.gfile.GFile(filename, "w") as f: + for i in range(len(self._id_to_token)): + f.write(self._id_to_token[i] + "\n") + + +class SubwordTextEncoder(TextEncoder): + """Class for invertibly encoding text using a limited vocabulary. + + Invertibly encodes a native string as a sequence of subtokens from a limited + vocabulary. + + A SubwordTextEncoder is built from a corpus (so it is tailored to the text in + the corpus), and stored to a file. See subword.py. + + It can then be loaded and used to encode/decode any text. + + Encoding has four phases: + + 1. Tokenize into a list of tokens. Each token is a unicode string of either + all alphanumeric characters or all non-alphanumeric characters. We drop + tokens consisting of a single space that are between two alphanumeric + tokens. + + 2. Escape each token. This escapes away special and out-of-vocabulary + characters, and makes sure that each token ends with an underscore, and + has no other underscores. + + 3. Represent each escaped token as a the concatenation of a list of subtokens + from the limited vocabulary. Subtoken selection is done greedily from + beginning to end. That is, we construct the list in order, always picking + the longest subtoken in our vocabulary that matches a prefix of the + remaining portion of the encoded token. + + 4. Concatenate these lists. This concatenation is invertible due to the + fact that the trailing underscores indicate when one list is finished. + + """ + + def __init__(self, filename=None): + """Initialize and read from a file, if provided. + + Args: + filename: filename from which to read vocab. If None, do not load a vocab + """ + self._alphabet = set() + self.filename = filename + if filename is not None: + self._load_from_file(filename) + super(SubwordTextEncoder, self).__init__() + + def encode(self, s): + """Converts a native string to a list of subtoken IDs. + + Args: + s: a native string. + + Returns: + a list of integers in the range [0, vocab_size) + """ + return self._tokens_to_subtoken_ids(encode(native_to_unicode(s))) + + def encode_without_tokenizing(self, token_text): + """Converts string to list of subtoken IDs without calling tokenizer. + + This treats `token_text` as a single token and directly converts it + to subtoken IDs. This may be useful when the default tokenizer doesn't + do what we want (e.g., when encoding text with tokens composed of lots of + nonalphanumeric characters). It is then up to the caller to make sure that + raw text is consistently converted into tokens. 
Only use this if you are + sure that `encode` doesn't suit your needs. + + Args: + token_text: A native string representation of a single token. + + Returns: + A list of subword token IDs; i.e., integers in the range [0, vocab_size). + """ + return self._tokens_to_subtoken_ids([native_to_unicode(token_text)]) + + def decode(self, ids, strip_extraneous=False): + """Converts a sequence of subtoken IDs to a native string. + + Args: + ids: a list of integers in the range [0, vocab_size) + strip_extraneous: bool, whether to strip off extraneous tokens (EOS and + PAD). + + Returns: + a native string + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return decode(self._subtoken_ids_to_tokens(ids)) + + def decode_list(self, ids): + return [self._subtoken_id_to_subtoken_string(s) for s in ids] + + @property + def vocab_size(self): + """The subtoken vocabulary size.""" + return len(self._all_subtoken_strings) + + def _tokens_to_subtoken_ids(self, tokens): + """Converts a list of tokens to a list of subtoken IDs. + + Args: + tokens: a list of strings. + + Returns: + a list of integers in the range [0, vocab_size) + """ + ret = [] + for token in tokens: + ret.extend(self._token_to_subtoken_ids(token)) + return ret + + def _token_to_subtoken_ids(self, token): + """Converts token to a list of subtoken IDs. + + Args: + token: a string. + + Returns: + a list of integers in the range [0, vocab_size) + """ + cache_location = hash(token) % self._cache_size + cache_key, cache_value = self._cache[cache_location] + if cache_key == token: + return cache_value + ret = self._escaped_token_to_subtoken_ids(_escape_token(token, self._alphabet)) + self._cache[cache_location] = (token, ret) + return ret + + def _subtoken_ids_to_tokens(self, subtokens): + """Converts a list of subtoken IDs to a list of tokens. + + Args: + subtokens: a list of integers in the range [0, vocab_size) + + Returns: + a list of strings. + """ + concatenated = "".join( + [self._subtoken_id_to_subtoken_string(s) for s in subtokens] + ) + split = concatenated.split("_") + ret = [] + for t in split: + if t: + unescaped = _unescape_token(t + "_") + if unescaped: + ret.append(unescaped) + return ret + + def _subtoken_id_to_subtoken_string(self, subtoken): + """Converts a subtoken integer ID to a subtoken string.""" + if 0 <= subtoken < self.vocab_size: + return self._all_subtoken_strings[subtoken] + return "" + + def _escaped_token_to_subtoken_strings(self, escaped_token): + """Converts an escaped token string to a list of subtoken strings. + + Args: + escaped_token: An escaped token as a unicode string. + + Returns: + A list of subtokens as unicode strings. + """ + # NOTE: This algorithm is greedy; it won't necessarily produce the "best" + # list of subtokens. + ret = [] + start = 0 + token_len = len(escaped_token) + while start < token_len: + for end in range(min(token_len, start + self._max_subtoken_len), start, -1): + subtoken = escaped_token[start:end] + if subtoken in self._subtoken_string_to_id: + ret.append(subtoken) + start = end + break + + else: # Did not break + # If there is no possible encoding of the escaped token then one of the + # characters in the token is not in the alphabet. This should be + # impossible and would be indicative of a bug. + assert False, "Token substring not found in subtoken vocabulary." + + return ret + + def _escaped_token_to_subtoken_ids(self, escaped_token): + """Converts an escaped token string to a list of subtoken IDs. 
+ + Args: + escaped_token: An escaped token as a unicode string. + + Returns: + A list of subtoken IDs as integers. + """ + return [ + self._subtoken_string_to_id[subtoken] + for subtoken in self._escaped_token_to_subtoken_strings(escaped_token) + ] + + @classmethod + def build_from_generator( + cls, generator, target_size, max_subtoken_length=None, reserved_tokens=None + ): + """Builds a SubwordTextEncoder from the generated text. + + Args: + generator: yields text. + target_size: int, approximate vocabulary size to create. + max_subtoken_length: Maximum length of a subtoken. If this is not set, + then the runtime and memory use of creating the vocab is quadratic in + the length of the longest token. If this is set, then it is instead + O(max_subtoken_length * length of longest token). + reserved_tokens: List of reserved tokens. The global variable + `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this + argument is `None`, it will use `RESERVED_TOKENS`. + + Returns: + SubwordTextEncoder with `vocab_size` approximately `target_size`. + """ + token_counts = collections.defaultdict(int) + for item in generator: + for tok in encode(native_to_unicode(item)): + token_counts[tok] += 1 + encoder = cls.build_to_target_size( + target_size, + token_counts, + 1, + 1e3, + max_subtoken_length=max_subtoken_length, + reserved_tokens=reserved_tokens, + ) + return encoder + + @classmethod + def build_to_target_size( + cls, + target_size, + token_counts, + min_val, + max_val, + max_subtoken_length=None, + reserved_tokens=None, + num_iterations=4, + ): + """Builds a SubwordTextEncoder that has `vocab_size` near `target_size`. + + Uses simple recursive binary search to find a minimum token count that most + closely matches the `target_size`. + + Args: + target_size: Desired vocab_size to approximate. + token_counts: A dictionary of token counts, mapping string to int. + min_val: An integer; lower bound for the minimum token count. + max_val: An integer; upper bound for the minimum token count. + max_subtoken_length: Maximum length of a subtoken. If this is not set, + then the runtime and memory use of creating the vocab is quadratic in + the length of the longest token. If this is set, then it is instead + O(max_subtoken_length * length of longest token). + reserved_tokens: List of reserved tokens. The global variable + `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this + argument is `None`, it will use `RESERVED_TOKENS`. + num_iterations: An integer; how many iterations of refinement. + + Returns: + A SubwordTextEncoder instance. + + Raises: + ValueError: If `min_val` is greater than `max_val`. + """ + if min_val > max_val: + raise ValueError( + "Lower bound for the minimum token count " + "is greater than the upper bound." + ) + if target_size < 1: + raise ValueError("Target size must be positive.") + + if reserved_tokens is None: + reserved_tokens = RESERVED_TOKENS + + def bisect(min_val, max_val): + """Bisection to find the right size.""" + present_count = (max_val + min_val) // 2 + logging.info("Trying min_count %d", present_count) + subtokenizer = cls() + subtokenizer.build_from_token_counts( + token_counts, + present_count, + num_iterations, + max_subtoken_length=max_subtoken_length, + reserved_tokens=reserved_tokens, + ) + + # Being within 1% of the target size is ok. + is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size + # If min_val == max_val, we can't do any better than this. 
+ if is_ok or min_val >= max_val or present_count < 2: + return subtokenizer + + if subtokenizer.vocab_size > target_size: + other_subtokenizer = bisect(present_count + 1, max_val) + else: + other_subtokenizer = bisect(min_val, present_count - 1) + + if other_subtokenizer is None: + return subtokenizer + + if abs(other_subtokenizer.vocab_size - target_size) < abs( + subtokenizer.vocab_size - target_size + ): + return other_subtokenizer + return subtokenizer + + return bisect(min_val, max_val) + + def build_from_token_counts( + self, + token_counts, + min_count, + num_iterations=4, + reserved_tokens=None, + max_subtoken_length=None, + ): + """Train a SubwordTextEncoder based on a dictionary of word counts. + + Args: + token_counts: a dictionary of Unicode strings to int. + min_count: an integer - discard subtokens with lower counts. + num_iterations: an integer. how many iterations of refinement. + reserved_tokens: List of reserved tokens. The global variable + `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this + argument is `None`, it will use `RESERVED_TOKENS`. + max_subtoken_length: Maximum length of a subtoken. If this is not set, + then the runtime and memory use of creating the vocab is quadratic in + the length of the longest token. If this is set, then it is instead + O(max_subtoken_length * length of longest token). + + Raises: + ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it + is not clear what the space is being reserved for, or when it will be + filled in. + """ + if reserved_tokens is None: + reserved_tokens = RESERVED_TOKENS + else: + # There is not complete freedom in replacing RESERVED_TOKENS. + for default, proposed in zip(RESERVED_TOKENS, reserved_tokens): + if default != proposed: + raise ValueError( + "RESERVED_TOKENS must be a prefix of " "reserved_tokens." + ) + + # Initialize the alphabet. Note, this must include reserved tokens or it can + # result in encoding failures. + alphabet_tokens = itertools.chain( + six.iterkeys(token_counts), [native_to_unicode(t) for t in reserved_tokens] + ) + + self._init_alphabet_from_tokens(alphabet_tokens) + + # Bootstrap the initial list of subtokens with the characters from the + # alphabet plus the escaping characters. + self._init_subtokens_from_list( + list(self._alphabet), reserved_tokens=reserved_tokens + ) + + # We build iteratively. On each iteration, we segment all the words, + # then count the resulting potential subtokens, keeping the ones + # with high enough counts for our new vocabulary. + if min_count < 1: + min_count = 1 + for i in range(num_iterations): + logging.info("Iteration %d", i) + + # Collect all substrings of the encoded token that break along current + # subtoken boundaries. 
+            subtoken_counts = collections.defaultdict(int)
+            for token, count in six.iteritems(token_counts):
+                iter_start_time = time.time()
+                escaped_token = _escape_token(token, self._alphabet)
+                subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
+                start = 0
+                for subtoken in subtokens:
+                    last_position = len(escaped_token) + 1
+                    if max_subtoken_length is not None:
+                        last_position = min(last_position, start + max_subtoken_length)
+
+                    for end in range(start + 1, last_position):
+                        new_subtoken = escaped_token[start:end]
+                        subtoken_counts[new_subtoken] += count
+                    start += len(subtoken)
+                iter_time_secs = time.time() - iter_start_time
+                if iter_time_secs > 0.1:
+                    logging.info(
+                        "Processing token [%s] took %.2f seconds, consider "
+                        "setting Text2TextProblem.max_subtoken_length to a "
+                        "smaller value.",
+                        token,
+                        iter_time_secs,
+                    )
+
+            # Array of sets of candidate subtoken strings, by length.
+            len_to_subtoken_strings = []
+            for subtoken_string, count in six.iteritems(subtoken_counts):
+                lsub = len(subtoken_string)
+                if count >= min_count:
+                    while len(len_to_subtoken_strings) <= lsub:
+                        len_to_subtoken_strings.append(set())
+                    len_to_subtoken_strings[lsub].add(subtoken_string)
+
+            # Consider the candidates longest to shortest, so that if we accept
+            # a longer subtoken string, we can decrement the counts of its prefixes.
+            new_subtoken_strings = []
+            for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1):
+                subtoken_strings = len_to_subtoken_strings[lsub]
+                for subtoken_string in subtoken_strings:
+                    count = subtoken_counts[subtoken_string]
+                    if count >= min_count:
+                        # Exclude alphabet tokens here, as they must be included later,
+                        # explicitly, regardless of count.
+                        if subtoken_string not in self._alphabet:
+                            new_subtoken_strings.append((count, subtoken_string))
+                        for l in range(1, lsub):
+                            subtoken_counts[subtoken_string[:l]] -= count
+
+            # Include the alphabet explicitly to guarantee all strings are encodable.
+            new_subtoken_strings.extend(
+                (subtoken_counts.get(a, 0), a) for a in self._alphabet
+            )
+            new_subtoken_strings.sort(reverse=True)
+
+            # Reinitialize to the candidate vocabulary.
+            new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings]
+            if reserved_tokens:
+                escaped_reserved_tokens = [
+                    _escape_token(native_to_unicode(t), self._alphabet)
+                    for t in reserved_tokens
+                ]
+                new_subtoken_strings = escaped_reserved_tokens + new_subtoken_strings
+
+            self._init_subtokens_from_list(new_subtoken_strings)
+            logging.info("vocab_size = %d", self.vocab_size)
+
+    @property
+    def all_subtoken_strings(self):
+        return tuple(self._all_subtoken_strings)
+
+    def dump(self):
+        """Debugging dump of the current subtoken vocabulary."""
+        subtoken_strings = [
+            (i, s) for s, i in six.iteritems(self._subtoken_string_to_id)
+        ]
+        print(
+            ", ".join("{0} : '{1}'".format(i, s) for i, s in sorted(subtoken_strings))
+        )
+
+    def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None):
+        """Initialize token information from a list of subtoken strings.
+
+        Args:
+          subtoken_strings: a list of subtokens
+          reserved_tokens: List of reserved tokens. We must have `reserved_tokens`
+            as None or the empty list, or else the global variable `RESERVED_TOKENS`
+            must be a prefix of `reserved_tokens`.
+
+        Raises:
+          ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it
+            is not clear what the space is being reserved for, or when it will be
+            filled in.
+ """ + if reserved_tokens is None: + reserved_tokens = [] + + if reserved_tokens: + self._all_subtoken_strings = reserved_tokens + subtoken_strings + else: + self._all_subtoken_strings = subtoken_strings + + # we remember the maximum length of any subtoken to avoid having to + # check arbitrarily long strings. + self._max_subtoken_len = max([len(s) for s in subtoken_strings]) + self._subtoken_string_to_id = { + s: i + len(reserved_tokens) for i, s in enumerate(subtoken_strings) if s + } + # Initialize the cache to empty. + self._cache_size = 2**20 + self._cache = [(None, None)] * self._cache_size + + def _init_alphabet_from_tokens(self, tokens): + """Initialize alphabet from an iterable of token or subtoken strings.""" + # Include all characters from all tokens in the alphabet to guarantee that + # any token can be encoded. Additionally, include all escaping characters. + self._alphabet = { + c for token in tokens for c in token + } # pylint: disable=g-complex-comprehension + self._alphabet |= _ESCAPE_CHARS + + def _load_from_file_object(self, f): + """Load from a file object. + + Args: + f: File object to load vocabulary from + """ + subtoken_strings = [] + for line in f: + s = line.rstrip() + # Some vocab files wrap words in single quotes, but others don't + if (s.startswith("'") and s.endswith("'")) or ( + s.startswith('"') and s.endswith('"') + ): + s = s[1:-1] + subtoken_strings.append(native_to_unicode(s)) + self._init_subtokens_from_list(subtoken_strings) + self._init_alphabet_from_tokens(subtoken_strings) + + def _load_from_file(self, filename): + """Load from a vocab file.""" + if not tf.io.gfile.exists(filename): + raise ValueError("File %s not found" % filename) + with tf.io.gfile.GFile(filename) as f: + self._load_from_file_object(f) + + def store_to_file(self, filename, add_single_quotes=True): + with tf.io.gfile.GFile(filename, "w") as f: + for subtoken_string in self._all_subtoken_strings: + if add_single_quotes: + f.write("'" + subtoken_string + "'\n") + else: + f.write(subtoken_string + "\n") + + +class ImageEncoder: + """Encoder class for saving and loading images.""" + + def __init__(self, num_reserved_ids=0, height=None, width=None, channels=3): + assert num_reserved_ids == 0 + self._height = height + self._width = width + self._channels = channels + + @property + def num_reserved_ids(self): + return 0 + + def encode(self, s): + """Transform a string with a filename into a list of RGB integers. + + Args: + s: path to the file with an image. + + Returns: + ids: list of integers + """ + try: + import matplotlib.image as im # pylint: disable=g-import-not-at-top + except ImportError as e: + logging.warning( + "Reading an image requires matplotlib to be installed: %s", e + ) + raise NotImplementedError("Image reading not implemented.") + return im.imread(s) + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int IDs into an image file. + + Args: + ids: list of integers to be converted. + strip_extraneous: unused + + Returns: + Path to the temporary file where the image was saved. + + Raises: + ValueError: if the IDs are not of the appropriate size. 
+ """ + del strip_extraneous + _, tmp_file_path = tempfile.mkstemp("_decode.png") + if self._height is None or self._width is None: + size = int(math.sqrt(len(ids) / self._channels)) + length = size * size * self._channels + else: + size = None + length = self._height * self._width * self._channels + if len(ids) != length: + raise ValueError( + "Length of ids (%d) must be height (%d) x width (%d) x " + "channels (%d); %d != %d.\n Ids: %s" + % ( + len(ids), + self._height, + self._width, + self._channels, + len(ids), + length, + " ".join([str(i) for i in ids]), + ) + ) + # TF2 eager implementation: build image tensor and write PNG directly. + raw = tf.convert_to_tensor(ids, dtype=tf.uint8) + if size is None: + img = tf.reshape(raw, [self._height, self._width, self._channels]) + else: + img = tf.reshape(raw, [size, size, self._channels]) + png = tf.image.encode_png(img) + # Use TF2 IO API to write the encoded PNG to file. + tf.io.write_file(tmp_file_path, png) + return tmp_file_path + + def decode_list(self, ids): + """Transform a sequence of int IDs into an image file. + + Args: + ids: list of integers to be converted. + + Returns: + Singleton list: path to the temporary file where the image was saved. + """ + return [self.decode(ids)] + + @property + def vocab_size(self): + return 256 + + +class RealEncoder: + """Encoder class for saving and loading float values.""" + + def encode(self, s): + """Transform a string (space separated float values) into a float array. + + Args: + s: space separated float values. + + Returns: + Array of float values. + """ + return [float(w) for w in s.split()] + + def decode(self, ids, strip_extraneous=False): + """Transform sequence of float values into string (float values). + + Args: + ids: array of floats to be converted. + strip_extraneous: unused + + Returns: + String having space separated float values. + + Raises: + ValueError: if the IDs are not of the appropriate size. + """ + del strip_extraneous + return " ".join([str(i) for i in ids]) + + +class BertEncoder: + """Encoder Class that is compatible with models trained in original BERT library.""" + + def __init__(self, vocab_file, do_lower_case=True): + self._vocab = self.load_vocab(vocab_file) + self._inv_vocab = {v: k for k, v in self._vocab.items()} + self._basic_tokenizer = BertBasicEncoder(do_lower_case=do_lower_case) + self._wordpiece_tokenizer = BertWordpieceTokenizer(vocab=self._vocab) + + def load_vocab(self, vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with tf.io.gfile.GFile(vocab_file, "r") as reader: + while True: + token = native_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + def encode(self, text): + return self._convert_tokens_to_ids(self.tokenize(text)) + + # Note: Because encoding by BertEncoder is not unique text decoded + # from token ids is not unique. 
+    def decode(self, ids):
+        """Returns a text that, when encoded, would yield the provided ids."""
+        tokens = self._convert_ids_to_tokens(ids)
+        if not tokens:
+            return ""
+        retarr = [tokens[0]]
+        for token in tokens[1:]:
+            if token.startswith("##"):
+                # Strip exactly the two-character "##" continuation prefix.
+                retarr.append(token[2:])
+            else:
+                retarr.append(" ")
+                retarr.append(token)
+        return "".join(retarr)
+
+    @property
+    def vocab_size(self):
+        return len(self._vocab)
+
+    def tokenize(self, text):
+        split_tokens = []
+        for token in self._basic_tokenizer.tokenize(text):
+            for sub_token in self._wordpiece_tokenizer.tokenize(token):
+                split_tokens.append(sub_token)
+
+        return split_tokens
+
+    def _convert_tokens_to_ids(self, tokens):
+        return [self._vocab[token] for token in tokens]
+
+    def _convert_ids_to_tokens(self, ids):
+        return [self._inv_vocab[token_id] for token_id in ids]
+
+
+class BertBasicEncoder:
+    """Part of BertEncoder; tokenization (punctuation splitting, lower casing)."""
+
+    def __init__(self, do_lower_case=True):
+        """Constructs a BasicTokenizer.
+
+        Args:
+          do_lower_case: Whether to lower case the input.
+        """
+        self.do_lower_case = do_lower_case
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text."""
+        text = native_to_unicode(text)
+        text = self._clean_text(text)
+
+        text = self._tokenize_chinese_chars(text)
+
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if self.do_lower_case:
+                token = token.lower()
+                token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text):
+        """Splits punctuation on a piece of text."""
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _bert_is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)
+            or (cp >= 0x20000 and cp <= 0x2A6DF)
+            or (cp >= 0x2A700 and cp <= 0x2B73F)
+            or (cp >= 0x2B740 and cp <= 0x2B81F)
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)
+        ):
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xFFFD or _bert_is_control(char):
+                continue
+            if _bert_is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+
+class SentencePieceEncoder:
+    """SentencePiece tokenizer with support for extra_ids like in T5."""
+
+    def __init__(self, spm_path, extra_ids=0):
+        with tf.io.gfile.GFile(spm_path, "rb") as f:
+            sp_model = f.read()
+        self.tokenizer = tft.SentencepieceTokenizer(model=sp_model)
+        self.extra_ids = extra_ids
+        # Note: We assume vocab size includes EOS, PAD, etc.
+        self.vocab_size = self.tokenizer.vocab_size().numpy()
+        self.total_vocab_size = self.vocab_size + self.extra_ids
+
+    def encode(self, text):
+        # Tokenize the text into base SentencePiece IDs.
+        tokens = self.tokenizer.tokenize([text])
+        # Back to numpy from tf; flat_values drops the batch dimension.
+        return tokens.flat_values.numpy()
+
+    def decode(self, ids):
+        # Convert IDs back to text, dropping extra_ids, which have no
+        # entry in the SentencePiece model.
+        ids = [i for i in ids if i < self.vocab_size]
+        ids_tensor = tf.constant([ids], dtype=tf.int32)
+        text = self.tokenizer.detokenize(ids_tensor)
+        # Back to numpy from tf.
+        return text.numpy()[0].decode("utf-8")
+
+
+"""A simple invertible tokenizer.
+
+Converts from a unicode string to a list of tokens
+(represented as Unicode strings).
+
+This tokenizer has the following desirable properties:
+  - It is invertible.
+  - Alphanumeric characters are broken away from non-alphanumeric characters.
+  - A single space between words does not produce an extra token.
+  - The full Unicode punctuation and separator set is recognized.
+
+The tokenization algorithm is as follows:
+
+1. Split the text into a list of tokens, splitting at every boundary of an
+   alphanumeric character and a non-alphanumeric character. This produces
+   a list which alternates between "alphanumeric tokens"
+   (strings of alphanumeric characters) and "non-alphanumeric tokens"
+   (strings of non-alphanumeric characters).
+
+2. Remove every token consisting of a single space, unless it is
+   the very first or very last token in the list. These tokens are now
+   implied by the fact that there are two adjacent alphanumeric tokens.
+
+e.g. u"Dude - that's so cool."
+  -> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]
+"""
+
+
+# Tokenization.
+@debug_data_pipeline.debug_pipeline
+def tokenize(
+    stream,
+    keys=None,
+    vocab_type="subword",
+    vocab_file=None,
+    vocab_dir=None,
+    n_reserved_ids=0,
+):
+    """Tokenize examples from the stream.
+
+    This function assumes that `stream` generates either strings or tuples/dicts
+    containing strings at some `keys`. This function maps these strings to
+    numpy arrays of integers -- the tokenized version of each string.
+
+    Args:
+      stream: A python generator yielding strings, tuples or dicts.
+      keys: which keys of the tuple/dict to tokenize (by default: all)
+      vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'.
+      vocab_file: Name of the vocabulary file.
+      vocab_dir: Directory which contains the vocabulary file.
+ n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused; + This is common for example when reserving the 0 for padding and 1 for EOS, + but it's only needed if these symbols are not already included (and thus + reserved) in the vocab_file. + + Yields: + Examples from stream with strings at `keys` replaced by np.arrays of + integers -- the tokenized version of these strings. + """ + vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) + for example in stream: + if isinstance(example, (list, tuple)): + new_example = [] + for i, x in enumerate(example): + if keys is None or i in keys: + new_example.append(np.array(vocab.encode(x)) + n_reserved_ids) + else: + new_example.append(x) + output = tuple(new_example) + yield output + elif isinstance(example, dict): + new_example = {} + for k in example: + if keys is None or k in keys: + new_example[k] = np.array(vocab.encode(example[k])) + n_reserved_ids + else: + new_example[k] = example[k] + yield new_example + else: + output = np.array(vocab.encode(example)) + n_reserved_ids + yield output + + +@gin.configurable(module="trax.data") +def Tokenize( # pylint: disable=invalid-name + keys=None, + vocab_type="subword", # pylint: disable=invalid-name + vocab_file=None, + vocab_dir=None, + n_reserved_ids=0, +): + """Returns a function that maps text to integer arrays; see `tokenize`.""" + return lambda g: tokenize( # pylint: disable=g-long-lambda + g, + keys=keys, + vocab_type=vocab_type, + vocab_file=vocab_file, + vocab_dir=vocab_dir, + n_reserved_ids=n_reserved_ids, + ) + + +def detokenize( + x, vocab_type="subword", vocab_file=None, vocab_dir=None, n_reserved_ids=0 +): + """Maps integer arrays to text; the opposite of `tokenize`. + + In many cases (all char- and subword-type vocabularies and most sentencepiece + ones) the tokenization is invertible, so detokenize(tokenize(x)) = x. In some + more rare cases this can remove some spacing, but it is still often useful + to run detokenize to get a readable version for a tokenized string. + + Args: + x: a list or numpy array of integers. + vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. + vocab_file: Name of the vocabulary file. + vocab_dir: Directory which contains the vocabulary file. + n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused; + This is common for example when reserving the 0 for padding and 1 for EOS, + but it's only needed if these symbols are not already included (and thus + reserved) in the vocab_file. + + Returns: + A string corresponding to the de-tokenized version of x. + """ + vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) + x_unreserved = np.array(x) - n_reserved_ids + return str(vocab.decode(x_unreserved.tolist())) + + +@gin.configurable(module="trax.data") +def SentencePieceTokenizer(spm_path=None, extra_ids=0): + """ + Returns a generator function that tokenizes a stream of text using + SentencePiece and supports extra IDs. + + Args: + spm_path: Path to the SentencePiece model file. Must be provided. + extra_ids: Number of extra IDs to reserve. + + Returns: + A function that takes a generator of text examples and yields tokenized + numpy arrays. 
+ """ + if spm_path is None: + raise ValueError("spm_path must be provided.") + + def tokenize(stream, spm_path, extra_ids): + vocab_file = os.path.basename(spm_path) + vocab_dir = os.path.dirname(spm_path) + vocab = _get_vocab( + vocab_type="sentencepiece", + vocab_file=vocab_file, + vocab_dir=vocab_dir, + extra_ids=extra_ids, + ) + for example in stream: + # Optionally replace print with logging.debugger + # logging.debugger("Tokenizing example: %s", example) + if isinstance(example, tuple): + example = example[0] + yield np.array(vocab.encode(example), dtype=np.int64) + + return lambda g: tokenize(g, spm_path=spm_path, extra_ids=extra_ids) + + +class BertWordpieceTokenizer: + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = native_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens diff --git a/trax/data/inputs.py b/trax/data/inputs.py deleted file mode 100644 index de15497f4..000000000 --- a/trax/data/inputs.py +++ /dev/null @@ -1,1590 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Data sources and input processing. - -Trax authors recommend constructing input pipelines using layer-like functions -and combinators. For example, following is an input pipeline for training -sentiment analysis tasks on the IMDB dataset:: - - from trax import data - - inputs = data.Serial( - data.TFDS('imdb_reviews', keys=('text', 'label'), train=True), - data.Tokenize(vocab_file='en_8k.subword', keys=[0]), - data.Shuffle(), - data.FilterByLength(max_length=2048, length_keys=[0]), - data.BucketByLength(boundaries=[ 32, 128, 512, 2048], - batch_sizes=[128, 32, 8, 2, 1], - length_keys=[0]), - data.AddLossWeights() - ) - -Each of these functions creates a Python generator of tuples of data arrays. 
-For example:: - - data.TFDS('imdb_reviews', keys=('text', 'label'), train=True), - -creates a generator of examples (tuples of NumPy :py:class:`ndarray` objects) -from the TFDS imdb_reviews dataset, see here: -https://www.tensorflow.org/datasets/catalog/imdb_reviews - -As you can see on the website above, this dataset has 'text' and 'label' fields -and we create tuples containing the text and the label from the training split -by specifying keys=('text', 'label'), train=True. - -Other functions, like ``Tokenize`` and ``Shuffle``, take a generator and output -another generator, in this way converting tuples into other tuples or mixing -the training stream. For example, ``Tokenize(..., keys=[0])`` tokenizes the -first element of a tuple -- converting it from text to a NumPy integer array. -And ``Shuffle`` randomizes the order of examples. - -Note that all elements in the data pipeline are just functions on generators, -so you can use Python's `map` and `filter` and other native functions too. -For example, you can create an input pipeline for a language model reading -lines from `my_file.txt` as follows:: - - inputs = data.Serial( - lambda _: open('my_file.txt'), - lambda g: map(lambda line: line.strip(), g), - data.Tokenize(vocab_file='en_8k.subword'), - lambda g: filter(lambda x: x.shape[0] < 513, g), # At most 512 tokens. - data.Shuffle(), - lambda g: map(lambda x: (x, x)), # Language models have inputs = targets. - data.BucketByLength(boundaries=[ 32, 64, 128, 256, 512], - batch_sizes=[ 32, 16, 8, 4, 2, 1]), - data.AddLossWeights(id_to_mask=0) - ) - -""" - -import math -import multiprocessing.dummy as mp # using threads for now -import os -import pickle -import random -import time - -from absl import logging -import gin -import jax -import numpy as np -import tensorflow as tf - -from trax import fastmath -from trax import shapes -from trax.data import debug_data_pipeline -from trax.fastmath import numpy as jnp - - -def Serial(*fns): # pylint: disable=invalid-name - """Combines generator functions into one that runs them serially.""" - def composed_fns(generator=None): - for f in fastmath.tree_flatten(fns): - generator = f(generator) - return generator - return composed_fns - - -# TODO(jonni): Rename to Blend/Merge/Mix/Interleave/...? -def Parallel( # pylint: disable=invalid-name - fns=None, - counters=None, - reweight_by_minimum=False, - gradually_reweight=False, - use_remainders=False): - """Combines generator functions into one that runs them in parallel. - - Args: - fns: a sequence of datasets which are combined in parallel. - counters: a sequence of ints with same length as fns, please see comments on - its use below. - reweight_by_minimum: if set to True, then we re-weight every counter by the - minimal counter. E.g. counters (10000, 100000) are translated to (1, 10) - and hence for every 10 examples from the second dataset we are getting - 1 example from the first dataset. Without reweighting first we would see - 20 examples from the first and second dataset and then 90 thousand eamples - only from the first dataset. - gradually_reweight: if set to True, then we loop through the generators - using a recursive rule defined in emit_examples. First we sort generators - by the counters. If we have datasets with counters 1, 20, 40 - (after sorting) then we yield examples (a(b c^2)^20)^*, where examples of - type a come from the first dataset, of type b from the second and of type - c from the third. The exponents are obtained through divisions of - subsequent counters. 
- use_remainders: if set to True as weell as gradually_reweight is set to - True and counters are 1, 20, 45 then after dealing with all examples in - the format (a(b c^2)^20)^*, the generator yields the remaining 5 examples - from the dataset with counter 45. - Returns: - parallel_generator: the generator yields samples according to given; - if counters are not given then samples are genereted uniformly. - - Example 1: - - gen = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3)) - - defines a generator that yields 33% examples from dataset1, 16% examples from - dataset2 and 50% examples from dataset3. - - Example 2: - - gen = data.Parallel([dataset1, dataset2, dataset3], counters=(20, 50, 30)) - - defines a generator that yields 20% examples from dataset1, 50% examples from - dataset2 and 30% examples from dataset3. - """ - - if counters: - assert len(counters) == len(fns) - # Remove generators with zero counters - counters = list(counters) - fns = list(fns) - non_zeros = [j for j in range(len(counters)) if counters[j] != 0] - counters = [counters[j] for j in non_zeros] - fns = [fns[j] for j in non_zeros] - else: - counters = [1] * len(fns) - - if reweight_by_minimum: - counters = [math.floor(counter / min(counters)) for counter in counters] - - def emit_examples(sorted_counters_with_gens, prev_counter): - if sorted_counters_with_gens: - _, counter, generator = sorted_counters_with_gens[0] - repeats = math.floor(counter / prev_counter) - for _ in range(repeats): - yield next(generator) - yield from emit_examples(sorted_counters_with_gens[1:], counter) - - def parallel_generator(gen=None): - # If gradually_reweight is set to False then - # current_counters are increased step by step; they are reset to 0s when - # current_counters[idx] == counters[idx] for all idx. See - # test_parallel_with_weights_three_datasets for an example of how - # current_counters are changed during computation. - # If gradually_reweight is set to False then we loop using a - # recursive rule defined in emit_examples. - - generators = [] - for f in fns: - if gen: - generators.append(f(gen)) - else: - # This handles the case when the function f cannot be - # called on None. - generators.append(f()) - - if gradually_reweight: - counters_with_gens = zip(range(len(generators)), counters, generators) - sorted_counters_with_gens = sorted(counters_with_gens, key=lambda x: x[1]) - while True: - yield from emit_examples(sorted_counters_with_gens, min(counters)) - if use_remainders: - # Below we are dealing with remainders. 
- fractions = [] - for i in range(len(sorted_counters_with_gens)): - _, counter, generator = sorted_counters_with_gens[i] - processed = 1 - for fraction in fractions: - processed *= fraction - remainder = counter - processed - for _ in range(remainder): - yield next(generator) - if i < len(sorted_counters_with_gens) - 1: - _, next_counter, _ = sorted_counters_with_gens[i + 1] - fractions.append(math.floor(next_counter / counter)) - else: - current_counters = [0] * len(generators) - while True: - for idx, generator in enumerate(generators): - if current_counters[idx] < counters[idx]: - current_counters[idx] += 1 - # instead of checking current_counters[idx] == counters[idx] for - # all idx, we check the equivalent condition: - if sum(current_counters) == sum(counters): - current_counters = [0] * len(generators) - yield next(generator) - - return parallel_generator - - -@gin.configurable(module='trax.data') -def Shuffle(queue_size=1024): # pylint: disable=invalid-name - """Returns a shuffle function with the given queue size.""" - return lambda g: shuffle(g, queue_size) - - -@gin.configurable(module='trax.data') -def Batch(batch_size): # pylint: disable=invalid-name - """Returns a batching function with given batch size.""" - return lambda g: batch(g, batch_size) - - -@gin.configurable(module='trax.data') -def Dup(): # pylint: disable=invalid-name - """Duplicates (copies) the top element (inputs). - - The generator stream is augmented in the following way: - - - If the stream consists of a single element `(inputs, )`, - the inputs simply get copied to `(inputs, inputs)`. - - If the stream consists of multiple elements, for example - `(inputs, weights)`, the rest of elements get moved toward - the right side `(inputs, inputs, weights)`. - - Returns: - the duplicating function. - """ - def _copy(xs): - x, *rest = xs - return (x, x, *rest) - return lambda g: map(lambda x: _copy(x), g) # pylint: disable=unnecessary-lambda - - -@gin.configurable(module='trax.data') -def FilterEmptyExamples(axes=None, debug=False): # pylint: disable=invalid-name - """Filters empty examples. - - Filters any example that has an array of size (0,) (if axes=None). - Alternatively, checks only axes provided in `axes' list. Contrary to - FilterByLength used with several elements with length_axis, here the example - would be filtered if ANY of the dimensions listed in `axes' contains an empty - array. - - Args: - axes: list of indices to check, if None, all of them. - debug: If true, emits a log everytime we filter out an empty example. - - Returns: - Function filtering empty examples. - """ - def _filter_examples(generator): - for example in generator: - correct = True - for i, unused_tuple_element in enumerate(example): - if axes is None or i in axes: - if example[i].shape == (0,): - correct = False - break - if correct: - yield example - elif debug: - logging.info('Filtered example: %r', example) - return _filter_examples - - -@gin.configurable(module='trax.data') -def FilterByLength(max_length, min_length=0, # pylint: disable=invalid-name - length_keys=None, length_axis=0): - """Returns a function that filters out examples by length. - - Args: - max_length: int. If not None, indicates maximum length. - min_length: int. If not None, indicates minimum length. - length_keys: (list) which example keys to take into account. - length_axis: which shape axis to take into account. - Returns: - a function that filters out examples by length. 
- """ - - assert max_length is not None or min_length is not None - length_keys = length_keys or [0, 1] - length_fn = lambda x: _length_fn(x, length_axis, length_keys) - def filtered(gen): - for example in gen: - example_len = length_fn(example) - - # Checking max length boundary. - if max_length is not None: - if example_len > max_length: - continue - # Checking min length boundary. - if min_length is not None: - if example_len < min_length: - continue - # Within bounds. - yield example - return filtered - - -@gin.configurable(module='trax.data') -def TruncateToLength(len_map=None): # pylint: disable=invalid-name - """Returns a stream function that resizes items as specified by ``len_map``. - - Args: - len_map: Dictionary that specifies maximum shapes for potentially multiple - features per stream item. For example, given a stream of tokenized - string pairs, one could enforce a maximum length of 256 tokens for each - string by using ``len_map={0: (256,), 1: (256,)}``. - """ - @debug_data_pipeline.debug_pipeline - def _truncate_to_length(generator): - for example in generator: - if isinstance(example, np.ndarray): - example = (example,) - if isinstance(example, (list, tuple)): - example = list(example) - if len_map is not None: - for key, max_len in len_map.items(): - example_len = example[key].shape - if example_len > max_len: - example[key] = np.resize(example[key], max_len) - output = tuple(example) - else: - output = None - raise ValueError(f'Unknown example type: {example}') - yield output - - return _truncate_to_length - - -@gin.configurable(module='trax.data') -def PadToLength( # pylint: disable=invalid-name - len_map=None, pad_value=0, multiple=False): - """Pads the values to lengths given in `len_map'. - - len_map contains a dictionary of example keys to dimension sizes. - - Args: - len_map: dict of int to int, we pad examples to lengths - given by the values of the dict. If multiple is True, the dimensions are - padded to multiple of this value. - pad_value: dict of int to int. The value gets applied to - constant_values on numpy.pad per given dimension. - multiple: boolean. If False, pads to the value of len_map. If True, pads to - closest multiple of value of len_map. - Returns: - Function to pad examples to given lengths. 
- """ - @debug_data_pipeline.debug_pipeline - def _pad_to_length(generator): - for example in generator: - if isinstance(example, (list, tuple)): - example = list(example) - for key, value in len_map.items(): - array_length = example[key].shape[0] - if multiple: - padding_len = array_length - ((array_length // value) * value) - else: - padding_len = max([0, value-example[key].shape[0]]) - example[key] = np.pad(example[key], - pad_width=(0, padding_len), - mode='constant', - constant_values=pad_value[key]) - output = tuple(example) - else: - if not isinstance(example, np.ndarray): - raise ValueError(f'example isn\'t nparray, but should be: {example}') - array_length = example.shape[0] - if multiple: - padding_len = ( - array_length - ((array_length // len_map[0]) * len_map[0])) - else: - padding_len = max(0, len_map[0] - array_length) - output = np.pad(example, - pad_width=(0, padding_len), - mode='constant', - constant_values=pad_value[0]) - yield output - if len_map is None: - raise ValueError('len_map parameter should be provided.') - return _pad_to_length - - -@gin.configurable(module='trax.data') -def BucketByLength(boundaries, batch_sizes, # pylint: disable=invalid-name - length_keys=None, length_axis=0, strict_pad_on_len=False): - """Returns a function for bucketing inputs, see `bucket_by_length`.""" - length_keys = length_keys or [0, 1] - # In all cases so far, we use a length function of the following form. - length_fn = lambda x: _length_fn(x, length_axis, length_keys) - return lambda g: bucket_by_length( # pylint: disable=g-long-lambda - g, length_fn, boundaries, batch_sizes, strict_pad_on_len) - - -@gin.configurable(module='trax.data') -def MLM(vocab_size=None, # pylint:disable=invalid-name - max_length=None, - noise_density=0.15, - mean_noise_span_length=3.0): - """Pipeline that just does MLM.""" - return Serial( - # Generate sequential chunks. - generate_sequential_chunks(max_length=max_length), - # Generate mask and chunk. - generate_random_noise_mask( - noise_density=noise_density, - mean_noise_span_length=mean_noise_span_length), - # Consume mask and chunk to give (input, targets). - consume_noise_mask(vocab_size=vocab_size), - ) - - -@gin.configurable(module='trax.data') -def PrefixLM(input_length=128, output_length=512): # pylint:disable=invalid-name - """Chunks examples so as to make inputs/outputs of specified lenghts.""" - def _f(generator): - for example in generator: - n_tokens = len(example) - # Iterate: - # |--------|<---- input_length ---->|<- output_length ->|--------------| - # ^ ^ ^ ^ - # | | | | - # 0 input_begin_idx input_end_idx output_end_idx - input_begin_idx = 0 - # While you can make an input batch, keep going. - while input_begin_idx + input_length < n_tokens: - input_end_idx = input_begin_idx + input_length - output_end_idx = min(input_end_idx + output_length, n_tokens) - yield (example[input_begin_idx:input_end_idx], - example[input_end_idx:output_end_idx]) - # Update the indices. - input_begin_idx = output_end_idx - return _f - - -@gin.configurable(module='trax.data') -def ConcatenateToLMInput(pad_to_length=None): # pylint: disable=invalid-name - """Prepares the input needed for training of Language Models. - - Each example needs to contain two elements (input and target). - Input is concatenated to target and, if pad_to_length is given, padded to - length provided. - The loss_weights indicates only the target, without input nor padding. - - Args: - pad_to_length: int, total length of padding of input and target arrays. 
- Returns: - Function to return input for a LM. - """ - @debug_data_pipeline.debug_pipeline - def _concatenate_to_lm_input(generator): - for example in generator: - if isinstance(example, (list, tuple)) and (len(example) == 2): - concatenated = np.concatenate((example[0], example[1]), axis=-1) - loss_weights = np.concatenate((np.zeros_like(example[0]), - np.ones_like(example[1]))) - if pad_to_length is not None: - padding_len = pad_to_length - ( - example[0].shape[0] + example[1].shape[0]) - if padding_len < 0: - raise ValueError( - 'Example lengths ' - f'({example[0].shape[0]}, {example[1].shape[0]}) ' - f'longer than pad_to_length ({pad_to_length}).') - loss_weights = np.pad(loss_weights, (0, padding_len), 'constant') - concatenated = np.pad(concatenated, (0, padding_len), 'constant') - output = (concatenated, concatenated, loss_weights) - elif isinstance(example, (list, tuple)) and (len(example) == 1): - # Make x into (x, x) - output = (example[0], example[0]) - elif isinstance(example, np.ndarray): - # Make x into (x, x) - output = (example, example) - else: - output = None - raise ValueError(f'Unknown input to ConcatenateToLMInput: {example}') - yield output - - return _concatenate_to_lm_input - - -@gin.configurable(module='trax.data') -def CastTo(dtype=np.int32, indices=(0, 1,), debug=False): # pylint: disable=invalid-name - """Casts the given indices to the given dtype.""" - def _cast_fn(generator): - debug_count = 0 - for example in generator: - debug_count += 1 - assert isinstance(example, tuple) - example = list(example) - dtype_mismatch = False - original_index_and_dtype = [] - for i in range(len(example)): - if i not in indices: - continue - original_type = example[i].dtype - if original_type != dtype: - if not (original_type == np.int64 and dtype == np.int32): - # Downcasting from np.int64 to np.int32 is OK - original_index_and_dtype.append((i, original_type)) - example[i] = example[i].astype(dtype) - dtype_mismatch = True - if debug and dtype_mismatch and original_index_and_dtype: - logging.info('dtype mismatch in example[%d] = %r was earlier: %r', - debug_count, example, original_index_and_dtype) - yield tuple(example) - return _cast_fn - - -@gin.configurable(module='trax.data') -def AppendValue(val=None): # pylint: disable=invalid-name - """Appends values provided in 'val` to inputs. - - val are keyed by example keys, its values contain appended tensors. - - Args: - val: dict of int to tensors. Specific keys get the tensors specified in - values appended. - Returns: - Funtion to append tensors to examples. 
- """ - @debug_data_pipeline.debug_pipeline - def _append_value(generator): - for example in generator: - if isinstance(example, tuple): - example = list(example) - if val is not None: - for key, value in val.items(): - example[key] = np.append(example[key], value, -1) - output = tuple(example) - else: - if not isinstance(example, np.ndarray): - raise ValueError(f'example isn\'t nparray, but should be: {example}') - output = np.append(example, val[0]) - yield output - - return _append_value - - -@gin.configurable(module='trax.data') -def AddLossWeights(id_to_mask=None): # pylint: disable=invalid-name - """Returns a function to add loss weights; see `add_loss_weights`.""" - return lambda g: add_loss_weights(g, id_to_mask=id_to_mask) - - -@gin.configurable(module='trax.data') -def UnBatch(): # pylint: disable=invalid-name - """Returns a function which unbatches.""" - def _unbatch(generator): - for batched_example in generator: - # batched_example is usually like: - # (batched_inputs, batched_outputs) or - # (batched_inputs, batched_outputs, batched_weights) - assert isinstance(batched_example, tuple) - # assert all lengths are the same. - batch_sizes = list(set(map(lambda example: example.shape[0], - batched_example))) - assert len(batch_sizes) == 1 - # Now unbatch examples. - for example_idx in range(batch_sizes[0]): - yield tuple(map(lambda x: x[example_idx], batched_example)) # pylint: disable=cell-var-from-loop - return _unbatch - - -@gin.configurable(module='trax.data') -def Prefetch(n_prefetch=2): # pylint: disable=invalid-name - """Pre-fetches a number of examples from generator in a separate process.""" - def prefetch(generator): - in_q, out_q = mp.Queue(), mp.Queue() - p = mp.Process(target=_generator_process, args=(generator, in_q, out_q)) - for _ in range(n_prefetch): - in_q.put(None) - p.start() - while True: - yield out_q.get() - in_q.put(None) - return prefetch - - -@gin.configurable(module='trax.data') -def UniformlySeek(name=None, host_id=None, n_hosts=None, dataset_size=None): # pylint: disable=invalid-name - """Sets each host at (dataset_size/n_hosts)-th of the dataset.""" - if not dataset_size: - dataset_size = 2 ** 18 # 512 * 512 - logging.error( - 'No dataset size given to Uniformly seek, assuming: %d', dataset_size) - assert name - host_id = jax.process_index() if host_id is None else host_id - n_hosts = n_hosts or jax.host_count() - each_host = int(dataset_size / n_hosts) - def _f(generator): - # Each host seeks to the appropriate point in the dataset. - num_to_seek = int(host_id * each_host) - start_time = time.time() - logging.info('Dataset[%s] host_id[%d] is seeking to position[%d]', - name, host_id, num_to_seek) - for _ in range(num_to_seek): - next(generator) - logging.info('Dataset[%s] host_id[%d] reached position[%d]. 
' - 'Time taken [%s] seconds', - name, host_id, num_to_seek, time.time() - start_time) - for example in generator: - yield example - return _f - - -@gin.configurable(module='trax.data') -def CountAndSkip(name): # pylint: disable=invalid-name - """Returns a function that counts and skips examples (see above).""" - return lambda g: count_and_skip(g, name) - - -@gin.configurable(module='trax.data') -def Log(n_steps_per_example=1, only_shapes=True): # pylint: disable=invalid-name - """Creates a logging component of the input pipeline.""" - def log(stream): - counter = 0 - for example in stream: - item_to_log = example - if only_shapes: - item_to_log = fastmath.nested_map(shapes.signature, example) - if counter % n_steps_per_example == 0: - logging.info(str(item_to_log)) - print(item_to_log) - counter += 1 - yield example - return log - - -def shuffle(samples, queue_size): - """Shuffles a sample stream using a random-out next-in queue of given size. - - Args: - samples: Stream of samples for eventual use as training data or eval data. - queue_size: Minimum number of samples within which the streamed shuffling - takes place. - - Yields: - Shuffled stream of samples, ready for further processing, e.g., grouping - into batches. - """ - if queue_size < 1: - raise ValueError(f'Arg queue_size ({queue_size}) is less than 1.') - if queue_size == 1: - logging.warning('Queue size of 1 results in no shuffling.') - queue = [] - try: - # Prep: fill the queue. - for _ in range(queue_size): - queue.append(next(samples)) - - # Core streaming shuffle: yield sample from random location in queue, then - # fill that location with new sample from input stream. - for sample in samples: - i = np.random.randint(queue_size) - yield queue[i] - queue[i] = sample - except StopIteration: - # Only get here if the initial queue fill fails. - logging.warning( - 'Not enough samples (%d) to fill initial queue (size %d).', - len(queue), queue_size) - - # No new samples coming in; shuffle and drain the queue. - np.random.shuffle(queue) - for sample in queue: - yield sample - - -def batch(generator, batch_size): - """Batch and pad generator as in tf.data.Dataset.padded_batch.""" - if batch_size <= 0: - raise ValueError(f'Batch size must be positive, but is {batch_size}.') - buf = [] - i = 0 - for example in generator: - buf.append(example) # Examples are tuples of tensors. - if len(buf) == batch_size: - # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)] - # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3]) - try: - batched_example = tuple( - pad_to_max_dims([np.asarray(tensor) for tensor in x]) - for x in zip(*buf)) - except ValueError as e: - for j in range(len(buf)): - logging.error('Batch[%d][%d] input shape: %r output shape: %r', - i, j, buf[j][0].shape, buf[j][1].shape) - for j in range(len(buf)): - logging.error('Batch[%d][%d] input: %r', i, j, buf[j][0]) - logging.error('Batch[%d][%d] output: %r', i, j, buf[j][1]) - raise e - i += 1 - yield batched_example - buf = [] - - -def pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False): - """Pad a tuple of tensors to a joint dimension and return their batch. - - For example, a pair of tensors of shape (2, 10) and (3, 9) will be padded - to (3, 10) both and the returned tensor will have shape (2, 3, 10). - - When boundary is specified, we try to pad all unknown dimensions to boundary - if possible, which can help reduce the number of different shapes occurring - in the tensors and speed up XLA compilation. 
So, for example, a pair of tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12).
-
-  One special case occurs when boundary is much higher than the padding length
-  that we'd use without boundary. For example, tensors (2, 10) and (3, 9) with
-  boundary=12 could end up padded to (12, 12), but this is very wasteful in
-  the first dimension. In that case, we will use the closest power-of-2 instead
-  of the boundary, so we will end up padding to (4, 12) instead of (12, 12).
-
-  Args:
-    tensors: a tuple or list of tensors to pad
-    boundary: int or None; if given, expand the padded dimensions to this size
-    strict_pad_on_len: bool; if true we pad on the length dimension, dim[0]
-      strictly as a multiple of boundary.
-
-  Returns:
-    a tensor, the tensors padded together
-  """
-  # TODO(afrozm): Unify this later.
-  if ((boundary is not None) and
-      (strict_pad_on_len or isinstance(boundary, (list, tuple)))):
-    ndim = tensors[0].ndim
-    if not isinstance(boundary, (list, tuple)):
-      boundary = [boundary] * ndim
-
-    if ndim != len(boundary):
-      raise ValueError(f'ndim != len(boundary) - '
-                       f'ndim({ndim}) vs boundary({boundary}) '
-                       f'len(boundary) = {len(boundary)}.')
-
-    max_len_per_dim = [0] * ndim
-    for tensor in tensors:
-      max_len_per_dim = [
-          max(e, s) for e, s in zip(tensor.shape, max_len_per_dim)]
-
-    # Round everything up to a multiple of boundary in the respective dimension.
-    len_per_dim = [
-        max_len_per_dim[i] if not b else b * math.ceil(max_len_per_dim[i] / b)
-        for i, b in enumerate(boundary)]
-
-    padded_tensors = [
-        np.pad(t, [(0, len_per_dim[i] - t.shape[i]) for i in range(ndim)],
-               mode='constant', constant_values=t.dtype.type(0))
-        for t in tensors]
-
-    return np.stack(padded_tensors)
-
-  max_len_to_pad = []
-  padding_needed = False
-  dim = len(tensors[0].shape)
-  for i in range(dim):
-    max_len = max([t.shape[i] for t in tensors])
-    min_len = min([t.shape[i] for t in tensors])
-    if max_len == min_len and max_len == boundary:  # No padding needed.
-      max_len_to_pad.append(max_len)
-    elif boundary is None:
-      max_len_to_pad.append(max_len)
-      padding_needed = True
-    else:
-      padding_needed = True
-      cur_boundary = max(max_len, boundary)
-      if 2 * max_len < cur_boundary:
-        cur_boundary = 2**int(np.ceil(np.log2(max_len)))
-      max_len_to_pad.append(cur_boundary)
-  if not padding_needed:
-    return np.stack(tensors)
-  padded_tensors = []
-  for t in tensors:
-    pad_widths = [(0, max_len_to_pad[i] - t.shape[i]) for i in range(dim)]
-    padded_t = np.pad(t, pad_widths, mode='constant',
-                      constant_values=t.dtype.type(0))
-    padded_tensors.append(padded_t)
-  return np.stack(padded_tensors)
-
-
-def bucket_by_length(generator, length_fn, boundaries, batch_sizes,
-                     strict_pad_on_len=False):
-  """Bucket by length, like tf.data.experimental.bucket_by_sequence_length.
-
-  This function draws examples from the provided `generator` and puts each
-  example into a bucket depending on `l = length_fn(example)`: the bucket
-  chosen is the one whose pair of `boundaries` l falls between. When a bucket
-  reaches its batch size, as specified by `batch_sizes`, this function
-  generates a batch of padded examples from that bucket.
-
-  Args:
-    generator: python generator to draw data from.
-    length_fn: a function taking the example and returning the length.
-    boundaries: a list of bucket boundaries.
-    batch_sizes: a list of batch sizes.
-    strict_pad_on_len: bool; if true we pad on the length dimension, dim[0]
-      strictly as a multiple of boundary.
-
-  Yields:
-    An input batch, which comes from one of the buckets.
- """ - buckets = [[] for _ in range(len(batch_sizes))] - boundaries = boundaries + [math.inf] # Max boundary is unlimited. - for example in generator: - length = length_fn(example) - # `bucket_idx` will always be < len(boundaries), since boundaries is right - # padded by `math.inf`. - bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b]) - buckets[bucket_idx].append(example) - if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]: - batched = zip(*buckets[bucket_idx]) - boundary = boundaries[bucket_idx] - boundary = None if boundary == math.inf else boundary - padded_batch = tuple( - pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched) - yield padded_batch - buckets[bucket_idx] = [] - - -@debug_data_pipeline.debug_pipeline -def add_loss_weights(generator, id_to_mask=None): - """Add weights to inputs without weights and masks by id if requested. - - The generator stream is augmented in the following way: - - - If the stream consists of pairs `(inputs, targets)`, a loss mask is added - that is creates as a tensor of ones of the same shape as targets. - - If `id_to_mask` is not `None`, and the stream (after the previous point) - has triples `(inputs, targets, weights)`, the weights are multiplied by a - 0/1 mask that is 0 iff targets is equal to `id_to_mask` (1 otherwise). - - Args: - generator: Stream of tuples. - id_to_mask: If not None, int-valued id that represents padding, as opposed - to true target IDs. - - Yields: - Examples from the augmented stream. - """ - for example in generator: - if len(example) > 3 or len(example) < 2: - assert id_to_mask is None, 'Cannot automatically mask this stream.' - yield example - else: - if len(example) == 2: - weights = np.ones_like(example[1]).astype(np.float32) - else: - weights = example[2].astype(np.float32) - mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32) - weights *= mask - output = (example[0], example[1], weights) - yield output - - -@gin.configurable(module='trax.data') -def generate_random_noise_mask(noise_density=0.15, - mean_noise_span_length=3.0, - seed1=None, - seed2=None): - """Returns a function that generates a random noise mask.""" - def _f(generator): - for example in generator: - length = len(example) - noise_mask = random_spans_noise_mask( - length, noise_density=noise_density, - mean_noise_span_length=mean_noise_span_length, - seed1=seed1, seed2=seed2, example=example) - yield (example, noise_mask) - return _f - - -@gin.configurable(module='trax.data') -def consume_noise_mask(vocab_size=32100): - """Consumes (tokens, noise mask) and returns (inputs, targets).""" - def _noise_span_to_unique_sentinel(tokens, noise_mask): - prev_token_is_noise = np.pad( - noise_mask[:-1], [1, 0], mode='constant', constant_values=False) - first_noise_tokens = np.logical_and(noise_mask, - np.logical_not(prev_token_is_noise)) - subsequent_noise_tokens = np.logical_and(noise_mask, prev_token_is_noise) - sentinel = vocab_size - np.cumsum(first_noise_tokens) - tokens = np.where(first_noise_tokens, sentinel, tokens) - return tokens[np.logical_not(subsequent_noise_tokens)] - - def _f(generator): - for tokens, noise_mask in generator: - # Returns inputs and targets. 
-      yield (_noise_span_to_unique_sentinel(tokens, noise_mask),
-             _noise_span_to_unique_sentinel(tokens, np.logical_not(noise_mask)))
-  return _f
-
-
-@gin.configurable(module='trax.data')
-def generate_sequential_chunks(max_length=None):
-  """Returns a function that generates chunks of at most max_length length."""
-  def _f(generator):
-    for example in generator:
-      n_tokens = len(example)
-      if n_tokens <= max_length:
-        yield example
-      else:
-        n_segments = int(math.ceil(float(n_tokens) / float(max_length)))
-        for i in range(n_segments):
-          start = max_length * i
-          end = min(start + max_length, n_tokens)
-          yield example[start:end]
-  return _f
-
-
-@gin.configurable(module='trax.data')
-def addition_input_stream(
-    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, min_length=gin.REQUIRED,
-    max_length=gin.REQUIRED, pad_to_multiple=32, encdec=False):
-  """Data stream for the add problem: <S>x+y<S>(x+y).
-
-  Args:
-    vocab_size: how many symbols to use.
-    batch_size: how large are the batches.
-    min_length: minimal length of w.
-    max_length: maximal length of w.
-    pad_to_multiple: int, pad length to be multiple of this number.
-    encdec: bool, if True return encoder-decoder style inputs (default: False)
-
-  Returns:
-    python generator of tuples of data examples
-  """
-  base = vocab_size - 3  # We use 0 to pad, base+1 as "+" and base+2 as "<S>".
-  def single_example(max_length, min_length):
-    """Generate a single example of the addition problem."""
-    add_len = (min_length - 1) // 2
-    l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len
-    l2 = np.random.randint(max_length - l1 - 1) + 1
-    n1 = random_number_lower_endian(l1, base)
-    n2 = random_number_lower_endian(l2, base)
-    result = lower_endian_to_number(n1, base) + lower_endian_to_number(
-        n2, base)
-    inp = n1 + [base] + n2
-    tgt = number_to_lower_endian(result, base)
-    if encdec:
-      x = [i + 1 for i in inp]
-      y = [i + 1 for i in tgt]
-      weights = [1] * len(tgt)
-      candidate_example = (np.array(x), np.array(y), np.array(weights))
-      if any(len(sample) > max_length for sample in candidate_example):
-        # sample too long, try again
-        return single_example(max_length, min_length)
-      return (np.array(x), np.array(y), np.array(weights))
-    else:
-      x = [base+2] + [i+1 for i in inp] + [base+2] + [i+1 for i in tgt]
-      weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt))
-      return (np.array(x), np.array(x), np.array(weights))
-
-  def batches(max_length, min_length):
-    """Batches of examples."""
-    if max_length < 3 or min_length < 3:
-      raise ValueError('Maximum/minimum length must be at least 3.')
-    while True:
-      ex = [single_example(max_length, min_length) for _ in range(batch_size)]
-      padded_batch = [pad_to_max_dims(x, boundary=pad_to_multiple,
-                                      strict_pad_on_len=True)
-                      for x in zip(*ex)]
-      yield tuple(padded_batch)
-
-  return batches(max_length, min_length)
-
-
-# This is a straightforward translation of T5's random_spans_noise_mask.
-def random_spans_noise_mask(length,
-                            noise_density=0.15,
-                            mean_noise_span_length=3.0,
-                            seed1=None,
-                            seed2=None,
-                            example=None):
-  """Computes span corruption masks given input parameters."""
-  # Passed in case we want to use it for debugging/logging
-  del example
-  orig_length = length
-  # increase length to avoid degeneracy
-  length = max(length, 2)
-  num_noise_tokens = int(round(length * noise_density))
-  # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
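  # Worked example (illustrative values): length=100, noise_density=0.15 and
  # mean_noise_span_length=3.0 give num_noise_tokens = round(100 * 0.15) = 15
  # and, below, num_noise_spans = round(15 / 3.0) = 5, so the mask interleaves
  # 5 non-noise spans with 5 noise spans over the 100 positions.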
- num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) - num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) - # avoid degeneracy by ensuring positive number of noise spans - num_noise_spans = max(num_noise_spans, 1) - num_nonnoise_tokens = length - num_noise_tokens - - # Pick the lengths of the noise spans and the non-noise spans - def randomly_segment(num_items, num_segments, seed): - x = np.arange(num_items - 1) < num_segments - 1 - # Set random seed if passed (only in tests for now). - if seed is not None: - np.random.seed(seed) - np.random.shuffle(x) - first_in_segment = np.pad(x, (1, 0), mode='constant') - segment_id = np.cumsum(first_in_segment) - - y = np.roll(segment_id, 1) - y[0] = 0 - idxs = np.pad(np.squeeze(np.argwhere(segment_id - y), axis=1), - (1, 0), - mode='constant') - segment_lengths = np.add.reduceat(np.ones_like(segment_id), idxs, axis=0) - return segment_lengths - - noise_span_lengths = randomly_segment( - num_noise_tokens, num_noise_spans, seed1) - nonnoise_span_lengths = randomly_segment( - num_nonnoise_tokens, num_noise_spans, seed2) - interleaved_span_lengths = np.reshape( - np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), - [num_noise_spans * 2]) - span_starts = np.cumsum(interleaved_span_lengths)[:-1] - span_start_indicator = np.zeros(length) # all 0s to begin with - span_start_indicator[span_starts] = 1 - span_num = np.cumsum(span_start_indicator) - is_noise = np.equal(span_num % 2, 1) - return is_noise[:orig_length] - - -def lower_endian_to_number(l, base): - """Helper function: convert a list of digits in the given base to a number.""" - return sum([d * (base**i) for i, d in enumerate(l)]) - - -def number_to_lower_endian(n, base): - """Helper function: convert a number to a list of digits in the given base.""" - if n < base: - return [n] - return [n % base] + number_to_lower_endian(n // base, base) - - -def random_number_lower_endian(length, base): - """Helper function: generate a random number as a lower-endian digits list.""" - if length == 1: # Last digit can be 0 only if length is 1. - return [np.random.randint(base)] - prefix = [np.random.randint(base) for _ in range(length - 1)] - return prefix + [np.random.randint(base - 1) + 1] # Last digit is not 0. - - -data_counters = {} # Used by {load,save}_data_counters and count_and_skip - - -def count_and_skip(generator, name): - """Count the number of items in the generator, skip already counted ones. - - This function counts the number of processed examples and puts it into - the global variable `counters`. This variable can be saved and restored, - and if restored, this function will skip examples until the restored counter - is reached. When the data generator is deterministic, this allows to restore - the data reading process from a checkpoint. - - Args: - generator: generator for examples in the dataset. - name: string, a unique id that we use to count the examples - - Yields: - The examples from generator but first skip the number specified in the - global variable counters[name] and next increment this variable every - time a new example appears. - """ - global data_counters - local_counter = 0 - for example in generator: - local_counter += 1 - # This check must be inside the loop due to asynchronous initializations. 
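    # Illustrative restore scenario: if data_counters[name] was reloaded from
    # a checkpoint as 1000, the first 1000 examples drawn here fail the
    # local_counter check and are skipped; yielding resumes at example 1001,
    # reproducing the pre-checkpoint read position.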
- if name not in data_counters: - data_counters[name] = 0 - if local_counter > data_counters[name]: - data_counters[name] += 1 - yield example - - -def save_data_counters(output_dir, host_id=None): - """Checkpoint data counters.""" - global data_counters - host_id = jax.process_index() if host_id is None else host_id - fname = os.path.join(output_dir, 'data_counters%d.pkl' % host_id) - with tf.io.gfile.GFile(fname, 'wb') as f: - pickle.dump(data_counters, f) - - -def load_data_counters(output_dir, host_id=None): - """Checkpoint data counters.""" - global data_counters - host_id = jax.process_index() if host_id is None else host_id - fname = os.path.join(output_dir, 'data_counters%d.pkl' % host_id) - if not tf.io.gfile.exists(fname): - logging.info('Did not load data counters as %s does not exist.', fname) - return - with tf.io.gfile.GFile(fname, 'rb') as f: - obj = pickle.load(f) - data_counters = obj - - -def _generator_process(generator, in_q, out_q): - for example in generator: - in_q.get() - out_q.put(example) - - -def _buckets_for_length(bucket_length, batch_size, max_eval_length, n_devices, - training): - """Creates heuristically a set of bucket boundaries and sizes. - - The middle boundary is set to `bucket_length` and the corresponding batch - size is set to `batch_size`. We also create buckets of 1/2 and 1/4 length - with 2x and 4x batch size, and buckets of 2x and 4x and larger length with - 1/2 and 1/4 batch size respectively, and batch size 1 for the final one. - - Args: - bucket_length: the length of the middle bucket. - batch_size: the batch size for the middle bucket. - max_eval_length: the longest bucket length if training=False. - n_devices: number of devices, batch sizes are divisible by that. - training: bool, whether we are training or evaluating. - - Returns: - a pair of lists of integers, (bucket_boundaries, bucket_batch_sizes). - """ - bucket_boundaries = [bucket_length // 4, bucket_length // 2, - bucket_length, bucket_length * 2, - bucket_length * 4, bucket_length * 8, - bucket_length * 16] - if not training: - max_eval_length = max_eval_length or bucket_length * 32 - # Set last bucket boundary to be max_eval_length, cut off boundaries - # that are larger than this. - bucket_boundaries = ( - [b for b in bucket_boundaries if b < max_eval_length] + - [max_eval_length] - ) - bucket_boundaries.append(max_eval_length) - bucket_batch_sizes = [batch_size * 4, batch_size * 2, - batch_size, batch_size // 2, - batch_size // 4, batch_size // 8, - batch_size // 16, 1] - if not training: - # The last bucket batch size is always 1, but the one-but-last is - # sized to accommodate the final length = bucket_boundaries[-1], which - # we changed for eval above -- so adjusting here too. - - # Resize if needed, since bucket_batch_sizes may not be the same size - # anymore. - bucket_batch_sizes = bucket_batch_sizes[:len(bucket_boundaries)] + [1] - bucket_batch_sizes[-2] = batch_size // max_eval_length - # Make batch sizes divisible by n_devices. 
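  # Worked example (illustrative values): bucket_length=32, batch_size=16 and
  # training=True give boundaries [8, 16, 32, 64, 128, 256, 512] and batch
  # sizes [64, 32, 16, 8, 4, 2, 1, 1]; the rounding below with n_devices=8
  # then lifts the sub-device sizes, yielding [64, 32, 16, 8, 8, 8, 8, 8].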
- bucket_batch_sizes = [max(b // n_devices, 1) * n_devices - for b in bucket_batch_sizes] - return (bucket_boundaries, bucket_batch_sizes) - - -def _length_fn(example, length_axis, length_keys): - """Length is the maximum of shape on length_axis over length_keys.""" - if isinstance(example, (list, tuple)): - return max([example[i].shape[length_axis] for i in length_keys]) - return example.shape[length_axis] - - -# ######################################################################## -# Inputs class used by Trainer, and associated helper functions. -# -# Note: In the planned move from Trainer to Loop, the Inputs class should be -# deprecated and finally removed. - - -class Inputs: - """Inputs bundle. - - Inputs bundle holds input streams and shapes for a training run. - It contains stream-creating functions that return python generators - of (input_batch, target_batch) tuples. - - * train_stream: training data that will be used for training - may include all the augmentation or selection the training wants - the shape of examples is [batch_fn.batch_size, ...] - * train_eval_stream: training data used for evaluation - examples from training data but usually without augmentation - the shape of examples is [batch_fn.eval_batch_size, ...] - * eval_stream: evaluation data stream - examples from evaluation data, usually without augmentation - the shape of examples is [batch_fn.eval_batch_size, ...] - * input_shape: the shape of inputs - the [...] above, without batch size - * input_dtype: the data type of inputs - * target_shape: the shape of targets - the [...] above, without batch size - * target_dtype: the data type of targets - """ - - def __init__(self, train_stream, eval_stream=None, train_eval_stream=None): - """Initialize a new set of inputs. - - Args: - train_stream: a function taking n_devices (an int) and returning - a python generator of training batches. - eval_stream: a function taking n_devices (an int) and returning - a python generator of validation batches; - if None, then the training generator will be used for evaluation. - train_eval_stream: a function taking n_devices (an int) and returning - a python generator of batches from - the training set used for evaluation (if None, use train_stream). - """ - if not callable(train_stream): - raise ValueError('Trax Inputs should be initialized with a function. ' - 'Did you forget the n_devices argument? If your inputs ' - 'do not use it, try lambda _: [your-inputs].') - - self._train_stream = train_stream - self._eval_stream = eval_stream or self._train_stream - - # TODO(lukaszkaiser): should we get rid of this one day? - self._train_eval_stream = train_eval_stream or self._train_stream - - # Peek into the train stream to get an example shape. 
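    # Note: this draws one batch from a fresh single-device stream purely to
    # record example shapes and dtypes; the peeked batch itself is not reused
    # for training.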
- example_train_batch = next(train_stream(1)) - self._input_shape = tuple(example_train_batch[0].shape)[1:] - self._input_dtype = example_train_batch[0].dtype - self._target_shape = tuple(example_train_batch[-1].shape)[1:] - self._target_dtype = example_train_batch[-1].dtype - self._example_shape = [x.shape for x in example_train_batch] - self._example_dtype = [x.dtype for x in example_train_batch] - - def train_stream(self, n_devices): - return self._train_stream(n_devices) - - def eval_stream(self, n_devices): - return self._eval_stream(n_devices) - - def train_eval_stream(self, n_devices): - return self._train_stream(n_devices) - - @property - def input_shape(self): - """Example input shape, without batch dimension.""" - return self._input_shape - - @property - def target_shape(self): - """Example target shape, without batch dimension.""" - return self._target_shape - - @property - def input_dtype(self): - """Dtype of the input.""" - return self._input_dtype - - @property - def target_dtype(self): - """Dtype of the target.""" - return self._target_dtype - - @property - def example_shape_dtype(self): - """Shape and Dtype of an example batch.""" - return self._example_shape, self._example_dtype - - -# Batching and Inputs creation helpers. - - -@gin.configurable(module='trax.data') -def make_inputs(train_stream=gin.REQUIRED, eval_stream=None): - """Create Inputs from two streams; mostly for use in gin configs.""" - if isinstance(train_stream, (list, tuple)): - train_stream = Serial(train_stream)() - if isinstance(eval_stream, (list, tuple)): - eval_stream = Serial(eval_stream)() - eval_stream_fn = None if eval_stream is None else lambda _: eval_stream - return Inputs(train_stream=lambda _: train_stream, - eval_stream=eval_stream_fn) - - -@gin.configurable(module='trax.data') -def make_additional_stream(stream=gin.REQUIRED): - """Create a stream mostly for use in gin configs for additional tasks.""" - return Serial(stream)() - - -@gin.configurable(module='trax.data') -def make_parallel_stream(streams=gin.REQUIRED, counters=None): - """Create a parallel stream for use in gin configs for additional tasks.""" - return Parallel(streams, counters=counters)() - - -@gin.configurable(module='trax.data') -def batcher(data_streams=gin.REQUIRED, variable_shapes=True, - batch_size_per_device=32, batch_size=None, eval_batch_size=32, - bucket_length=32, buckets=None, - buckets_include_inputs_in_length=False, - batch_shuffle_size=None, max_eval_length=None, - # TODO(afrozm): Unify padding logic. - id_to_mask=None, strict_pad_on_len=False): - """Batcher: create trax Inputs from single-example data-streams.""" - # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. - # For now leaving the arguments as in batch_fn to reduce gin config changes. - if callable(data_streams): # If we pass a function, e.g., through gin, call. 
- train_stream, eval_stream = data_streams() - else: - train_stream, eval_stream = data_streams - # pylint: disable=g-long-lambda - batch_train_stream = lambda n_devices: batch_fn( - train_stream(), True, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - batch_eval_stream = lambda n_devices: batch_fn( - eval_stream(), False, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - batch_train_eval_stream = lambda n_devices: batch_fn( - train_stream(), False, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - # pylint: enable=g-long-lambda - return Inputs(train_stream=batch_train_stream, - eval_stream=batch_eval_stream, - train_eval_stream=batch_train_eval_stream) - - -def batch_fn(dataset, training, n_devices, variable_shapes, - batch_size_per_device=32, batch_size=None, eval_batch_size=32, - bucket_length=32, buckets=None, - buckets_include_inputs_in_length=False, - batch_shuffle_size=None, max_eval_length=None, - id_to_mask=None, strict_pad_on_len=False): - """Batching function.""" - # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. - # After that, create a proper doc-string; we may also not need to pass both - # training and eval arguments here, as batcher calls the function separately - # now and it's not under gin-config any more -- consider reducing args. - batch_size = batch_size or batch_size_per_device * n_devices - # If bucketing is not specified, check if target shapes are variable. - cur_batch_size = batch_size if training else eval_batch_size - # Make cur_batch_size divisible by n_devices. - cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices - # Create heuristic buckets if none are specified. - if buckets is None: - logging.info('Heuristically setting bucketing to %s based on shapes ' - 'of target tensors.', variable_shapes) - if variable_shapes: - buckets = _buckets_for_length( - bucket_length, cur_batch_size, max_eval_length, n_devices, training) - - if buckets: - logging.info('Bucketing with buckets %s.', str(buckets)) - def example_length(x): - """The length function used by bucket_by_sequence_length to bucket.""" - # The input x is a tuple to go on the stack, typically either - # (input, target) or (input, target, mask). - example_inputs, target = x[0], x[1] - # Length is the shape of axis 0 here (no batch yet). - other_length = 0 # We include input length only if asked. - if buckets_include_inputs_in_length: - other_length = example_inputs.shape[0] - return max(target.shape[0], other_length) - boundaries, batch_sizes = buckets - dataset = bucket_by_length( - dataset, example_length, boundaries, batch_sizes, strict_pad_on_len) - else: - logging.info('Not Bucketing cur_batch_size %d.', cur_batch_size) - dataset = batch(dataset, cur_batch_size) - if training and batch_shuffle_size is not None: - dataset = shuffle(dataset, batch_shuffle_size) - return add_loss_weights(dataset, id_to_mask) - - -# Example input functions. 
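# A minimal composition sketch (illustrative; the toy generator and sizes are
# made up) showing how the helpers above chain into an Inputs object -- the
# same wiring batcher()/batch_fn() perform internally:

def _toy_stream_sketch(n_devices):
  del n_devices  # The toy data does not depend on the device count.
  def gen():
    while True:
      length = np.random.randint(2, 8)
      x = np.random.randint(1, 10, size=(length,))
      yield (x, x)  # (input, target) pairs, as batch() expects.
  # Pad-batch groups of 4 examples, then mask out padding (id 0) in weights.
  return add_loss_weights(batch(gen(), 4), id_to_mask=0)

# Usage: inputs = Inputs(train_stream=_toy_stream_sketch)
#        xs, ys, ws = next(inputs.train_stream(1))  # leading dim 4 each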
- - -@gin.configurable(module='trax.data') -def random_inputs( - input_shape=gin.REQUIRED, input_dtype=jnp.int32, input_range=(0, 255), - output_shape=gin.REQUIRED, output_dtype=jnp.int32, output_range=(0, 9)): - """Make random Inputs for debugging. - - Args: - input_shape: the shape of inputs (including batch dimension). - input_dtype: the type of the inputs (int32 by default). - input_range: the range of inputs (defaults to (0, 255)). - output_shape: the shape of outputs (including batch dimension). - output_dtype: the type of the outputs (int32 by default). - output_range: the range of outputs (defaults to (0, 9)). - - Returns: - trax.inputs.Inputs - """ - def random_minibatches(n_devices): - """Generate a stream of random mini-batches.""" - assert input_range[0] % n_devices == 0 - if input_dtype in [jnp.float16, jnp.float32, jnp.float64]: - rand = np.random.uniform - else: - rand = np.random.random_integers - while True: - inp = rand(input_range[0], input_range[1], input_shape) - inp = inp.astype(input_dtype) - out = rand(output_range[0], output_range[1], output_shape) - out = out.astype(output_dtype) - yield inp, out - - return Inputs(random_minibatches) - - -@gin.configurable(module='trax.data') -def sequence_copy_inputs( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED, - eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, reverse=False, - pad_to_multiple=32): - """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*. - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - train_length: maximum length of w for training. - eval_min_length: minimum length of w for eval. - eval_max_length : maximum length of w for eval. - reverse: bool (optional, false by default): reverse the second sequence. - pad_to_multiple: int, pad length to be multiple of this number. - - Returns: - trax.inputs.Inputs - """ - def random_minibatches(length_list): - """Generate a stream of random mini-batches.""" - while True: - length = random.choice(length_list) - assert length % 2 == 0 - w_length = (length // 2) - 1 - w = np.random.randint(low=1, high=vocab_size-1, - size=(batch_size, w_length)) - zero = np.zeros([batch_size, 1], np.int32) - loss_weights = np.concatenate([np.zeros((batch_size, w_length+2)), - np.ones((batch_size, w_length))], axis=1) - if reverse: - x = np.concatenate([zero, w, zero, jnp.flip(w, axis=1)], axis=1) - else: - x = np.concatenate([zero, w, zero, w], axis=1) - x = _pad_to_multiple_of(x, pad_to_multiple, 1) - loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) - yield (x, x, loss_weights) # Here inputs and targets are the same. - - train_lengths = [2*(i+2) for i in range(train_length - 1)] - eval_lengths = [2*(i+1) for i in range(eval_min_length, eval_max_length)] - return Inputs( - train_stream=lambda _: random_minibatches(train_lengths), - eval_stream=lambda _: random_minibatches(eval_lengths) - ) - - -@gin.configurable(module='trax.data') -def simple_sequence_copy_inputs( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED, - eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, - pad_to_multiple=32): - """Inputs for the sequence copy problem: w for w in [1..vocab_size-1]*. - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - train_length: maximum length of w for training. - eval_min_length: minimum length of w for eval. - eval_max_length : maximum length of w for eval. 
- pad_to_multiple: int, pad length to be multiple of this number. - - Returns: - trax.inputs.Inputs - """ - def random_minibatches(length_list): - """Generate a stream of random mini-batches.""" - while True: - length = random.choice(length_list) - x = np.random.randint(low=1, high=vocab_size-1, - size=(batch_size, length)) - loss_weights = np.ones((batch_size, length)) - x = _pad_to_multiple_of(x, pad_to_multiple, 1) - loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) - yield (x, x, loss_weights) # Here inputs and targets are the same. - - train_lengths = list(range(1, train_length + 1)) - eval_lengths = list(range(eval_min_length, eval_max_length + 1)) - return Inputs( - train_stream=lambda _: random_minibatches(train_lengths), - eval_stream=lambda _: random_minibatches(eval_lengths) - ) - - -@gin.configurable(module='trax.data') -def addition_inputs( - vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED, - eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, - pad_to_multiple=32, encdec=False): - """Inputs for the add problem: x+y(x+y). - - Args: - vocab_size: how many symbols to use. - batch_size: how large are the batches. - train_length: maximal length of w for training. - eval_min_length: minimal length of w for eval. - eval_max_length: maximal length of w for eval. - pad_to_multiple: int, pad length to be multiple of this number. - encdec: bool, if True return encoder-decoder style inputs (default: False) - - Returns: - trax.inputs.Inputs - """ - train_stream = addition_input_stream( - vocab_size, batch_size, 3, train_length, pad_to_multiple, encdec) - eval_stream = addition_input_stream( - vocab_size, batch_size, eval_min_length, eval_max_length, pad_to_multiple, - encdec) - return Inputs( - train_stream=lambda _: train_stream, - eval_stream=lambda _: eval_stream - ) - - -@gin.configurable(module='trax.data') -def sine_inputs( - batch_size=gin.REQUIRED, - length=gin.REQUIRED, - max_phase=(2 * math.pi), - min_period=0.1, - max_period=10.0, -): - """Sinusoids of random period and phase. - - Args: - batch_size (int): Number of examples in a batch. - length (int): Length of each sequence. - max_phase (float): Maximum phase of the sinusoids. - min_period (float): Minimum period of the sinusoids. - max_period (float): Maximum period of the sinusoids. - - Returns: - trax.inputs.Inputs - """ - def random_series(): - while True: - phase = np.random.uniform(0, max_phase) - period = np.exp(np.random.uniform(np.log(min_period), np.log(max_period))) - x = np.arange(length) - yield np.sin((x - phase) / period) - - def random_minibatches(_): - minibatch = [] - for series in random_series(): - minibatch.append(series) - if len(minibatch) == batch_size: - obs = np.stack(minibatch) - minibatch.clear() - act = np.zeros_like(obs, dtype=np.int32) - mask = np.ones_like(obs) - yield (obs, act, obs, mask) - - return Inputs(train_stream=random_minibatches, eval_stream=random_minibatches) - - -def _pad_to_multiple_of(x, y, axis): - """Pads x to multiple of y on the given axis.""" - pad_len = np.ceil(x.shape[axis] / float(y)) * y - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = (0, int(pad_len - x.shape[axis])) - return np.pad(x, pad_widths, mode='constant', - constant_values=x.dtype.type(0)) diff --git a/trax/data/inputs_test.py b/trax/data/inputs_test.py deleted file mode 100644 index e4cf5c0bd..000000000 --- a/trax/data/inputs_test.py +++ /dev/null @@ -1,774 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data.inputs.""" - -import itertools -import os - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -from trax import data - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') - - -def _spm_path(): - return os.path.join(_TESTDATA, 'sentencepiece.model') - - -class InputsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('zero', 0), - ('negative', -5), - ) - def test_shuffle_data_raises_error_queue_size(self, queue_size): - samples = iter(range(10)) - with self.assertRaises(ValueError): - _ = list(data.shuffle(samples, queue_size)) - - @parameterized.named_parameters( - ('one', 1), - ('two', 2), - ('twenty', 20), - ) - def test_shuffle_data_queue_size(self, queue_size): - samples = iter(range(100, 200)) - shuffled_stream = data.shuffle(samples, queue_size) - first_ten = [next(shuffled_stream) for _ in range(10)] - - # Queue size limits how far ahead/upstream the current sample can reach. - self.assertLess(first_ten[0], 100 + queue_size) - self.assertLess(first_ten[3], 103 + queue_size) - self.assertLess(first_ten[9], 109 + queue_size) - - unshuffled_first_ten = list(range(100, 110)) - if queue_size == 1: # Degenerate case: no shuffling can happen. 
-      self.assertEqual(first_ten, unshuffled_first_ten)
-    if queue_size > 1:
-      self.assertNotEqual(first_ten, unshuffled_first_ten)
-
-  @parameterized.named_parameters(
-      ('qsize_100_n_001', 100, 1),
-      ('qsize_100_n_099', 100, 99),
-      ('qsize_100_n_100', 100, 100),
-      ('qsize_100_n_101', 100, 101),
-      ('qsize_100_n_199', 100, 199),
-  )
-  def test_shuffle_data_yields_all_samples(self, queue_size, n_samples):
-    samples = iter(range(n_samples))
-    shuffled_stream = data.shuffle(samples, queue_size)
-    self.assertLen(list(shuffled_stream), n_samples)
-
-  def test_batch_data(self):
-    dataset = ((i, i+1) for i in range(10))
-    batches = data.batch(dataset, 10)
-    batch = next(batches)
-    self.assertLen(batch, 2)
-    self.assertEqual(batch[0].shape, (10,))
-
-  def test_batch_data_padding(self):
-    dataset = (([1] * (10 - i), i+1) for i in range(10))
-    batches = data.batch(dataset, 10)
-    batch = next(batches)
-    self.assertEqual(batch[0].shape, (10, 10))
-    self.assertTrue(np.array_equal(batch[0][-1], np.asarray([1] + 9 * [0])))
-
-  def test_batch_exception_size(self):
-    dataset = ((i, i + 1) for i in range(10))
-    with self.assertRaises(ValueError):
-      batches = data.batch(dataset, 0)
-      next(batches)
-
-  def test_serial(self):
-    dataset = lambda _: ((i, i+1) for i in range(10))
-    batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
-    batch = next(batches())
-    self.assertLen(batch, 2)
-    self.assertEqual(batch[0].shape, (10,))
-
-  def test_serial_composes(self):
-    """Check that data.Serial works inside another data.Serial."""
-    dataset = lambda _: ((i, i+1) for i in range(10))
-    serial1 = data.Serial(dataset, data.Shuffle(3))
-    batches = data.Serial(serial1, data.Batch(10))
-    batch = next(batches())
-    self.assertLen(batch, 2)
-    self.assertEqual(batch[0].shape, (10,))
-
-  def test_count_and_skip(self):
-    dataset = lambda _: ((i, i+1) for i in range(10))
-    examples = data.Serial(dataset, data.CountAndSkip('toy_data'))
-    ex_generator = examples()
-    ex1 = next(ex_generator)
-    self.assertEqual(ex1, (0, 1))
-    self.assertEqual(data.inputs.data_counters['toy_data'], 1)
-    ex2 = next(ex_generator)
-    self.assertEqual(ex2, (1, 2))
-    self.assertEqual(data.inputs.data_counters['toy_data'], 2)
-    ex3 = next(examples())  # new generator, will skip
-    self.assertEqual(ex3, (2, 3))
-    self.assertEqual(data.inputs.data_counters['toy_data'], 3)
-    data.inputs.data_counters['toy_data'] = 0  # reset
-    ex4 = next(examples())  # new generator, was reset
-    self.assertEqual(ex4, (0, 1))
-    self.assertEqual(data.inputs.data_counters['toy_data'], 1)
-
-  def test_parallel(self):
-    """Basic test of the parallel combinator."""
-    dataset1 = lambda: (i for i in range(10))
-    dataset2 = lambda: (i for i in range(10, 20))
-    parallel = data.Parallel([dataset1, dataset2])
-    generator = parallel()
-
-    self.assertEqual(next(generator), 0)
-    self.assertEqual(next(generator), 10)
-    self.assertEqual(next(generator), 1)
-    self.assertEqual(next(generator), 11)
-    self.assertEqual(next(generator), 2)
-    self.assertEqual(next(generator), 12)
-
-  def test_parallel_with_gen_not_none(self):
-    """Test of the parallel combinator with a non-None generator."""
-    dataset1 = lambda _: (i for i in range(10))
-    dataset2 = lambda _: (i for i in range(10, 20))
-    parallel = data.Parallel([dataset1, dataset2])
-
-    def test_generator():
-      yield 0
-
-    generator = parallel(gen=test_generator)
-
-    self.assertEqual(next(generator), 0)
-    self.assertEqual(next(generator), 10)
-    self.assertEqual(next(generator), 1)
-    self.assertEqual(next(generator), 11)
-    self.assertEqual(next(generator), 2)
-    self.assertEqual(next(generator), 12)
-
-  def test_parallel_with_weights(self):
-    """Test of the parallel combinator with weights."""
-    dataset1 = lambda: (i for i in range(10))
-    dataset2 = lambda: (i for i in range(10, 20))
-    parallel = data.Parallel([dataset1, dataset2], counters=(2, 1))
-    generator = parallel()
-
-    self.assertEqual(next(generator), 0)
-    self.assertEqual(next(generator), 10)
-    self.assertEqual(next(generator), 1)
-    self.assertEqual(next(generator), 11)
-    self.assertEqual(next(generator), 2)
-    self.assertEqual(next(generator), 3)
-    self.assertEqual(next(generator), 12)
-    self.assertEqual(next(generator), 4)
-    self.assertEqual(next(generator), 5)
-    self.assertEqual(next(generator), 13)
-
-  def test_parallel_with_weights_and_minimum(self):
-    """Test of the parallel combinator with weights and minimum."""
-    dataset1 = lambda: (i for i in range(10))
-    dataset2 = lambda: (i for i in range(10, 110))
-    parallel = data.Parallel([dataset1, dataset2],
-                             counters=(10, 100),
-                             reweight_by_minimum=True)
-    generator = parallel()
-
-    self.assertEqual(next(generator), 0)
-    self.assertEqual(next(generator), 10)
-    self.assertEqual(next(generator), 11)
-    self.assertEqual(next(generator), 12)
-    self.assertEqual(next(generator), 13)
-    self.assertEqual(next(generator), 14)
-    self.assertEqual(next(generator), 15)
-    self.assertEqual(next(generator), 16)
-    self.assertEqual(next(generator), 17)
-    self.assertEqual(next(generator), 18)
-    self.assertEqual(next(generator), 19)
-    self.assertEqual(next(generator), 1)
-    self.assertEqual(next(generator), 20)
-    self.assertEqual(next(generator), 21)
-    self.assertEqual(next(generator), 22)
-    self.assertEqual(next(generator), 23)
-    self.assertEqual(next(generator), 24)
-    self.assertEqual(next(generator), 25)
-    self.assertEqual(next(generator), 26)
-    self.assertEqual(next(generator), 27)
-    self.assertEqual(next(generator), 28)
-    self.assertEqual(next(generator), 29)
-    self.assertEqual(next(generator), 2)
-
-  def test_parallel_with_gradual_reweighting(self):
-    """Test of the parallel combinator with gradual reweighting."""
-    dataset1 = lambda: (i for i in itertools.cycle(range(1)))
-    dataset2 = lambda: (i for i in itertools.cycle(range(10, 30)))
-    dataset3 = lambda: (i for i in itertools.cycle(range(30, 70)))
-    parallel = data.Parallel([dataset2, dataset1, dataset3],
-                             counters=(20, 1, 40),
-                             gradually_reweight=True)
-    generator = parallel()
-
-    for _ in range(3):
-      self.assertEqual(next(generator), 0)
-      for i in range(20):
-        self.assertEqual(next(generator), 10 + i)
-        self.assertEqual(next(generator), 30 + 2 * i)
-        self.assertEqual(next(generator), 30 + 2 * i + 1)
-
-  def test_parallel_with_gradual_reweighting_remainders(self):
-    """Test of parallel with gradual reweighting and remainders."""
-    dataset1 = lambda: (i for i in itertools.cycle(range(1)))
-    dataset2 = lambda: (i for i in itertools.cycle(range(10, 30)))
-    dataset3 = lambda: (i for i in itertools.cycle(range(30, 80)))
-    parallel = data.Parallel([dataset2, dataset1, dataset3],
-                             counters=(20, 1, 50),
-                             gradually_reweight=True,
-                             use_remainders=True)
-    generator = parallel()
-
-    for _ in range(3):
-      self.assertEqual(next(generator), 0)
-      for i in range(20):
-        self.assertEqual(next(generator), 10 + i)
-        self.assertEqual(next(generator), 30 + 2 * i)
-        self.assertEqual(next(generator), 30 + 2 * i + 1)
-    # Here we process the remainder from dataset 3:
-    for i in range(10):
-      self.assertEqual(next(generator), 70 + i)
-
-  def test_parallel_with_gradual_reweighting_remainders_big(self):
-    """Test of parallel with gradual reweighting and big remainders."""
-    dataset1 = lambda: (i for i in itertools.cycle(range(1)))
-    dataset2 = lambda: (i for i in itertools.cycle(range(10, 30)))
-    dataset3 = lambda: (i for i in itertools.cycle(range(30, 80)))
-    dataset4 = lambda: (i for i in itertools.cycle(range(100, 220)))
-    parallel = data.Parallel([dataset2, dataset1, dataset4, dataset3],
-                             counters=(20, 1, 120, 50),
-                             gradually_reweight=True,
-                             use_remainders=True)
-    generator = parallel()
-
-    for _ in range(3):
-      self.assertEqual(next(generator), 0)
-      for i in range(20):
-        self.assertEqual(next(generator), 10 + i)
-        for j in range(2):
-          self.assertEqual(next(generator), 30 + 2 * i + j)
-          for k in range(2):
-            self.assertEqual(next(generator), 100 + 2 * 2 * i + 2 * j + k)
-    # Here we process the remainder from datasets 3 and 4:
-    for i in range(10):
-      self.assertEqual(next(generator), 70 + i)
-    for i in range(40):
-      self.assertEqual(next(generator), 180 + i)
-
-  def test_parallel_with_weights_three_datasets(self):
-    """Test of the parallel combinator with weights and three datasets."""
-    dataset1 = lambda: (i for i in range(10))
-    dataset2 = lambda: (i for i in range(10, 20))
-    dataset3 = lambda: (i for i in range(20, 30))
-    parallel = data.Parallel(
-        [dataset1, dataset2, dataset3], counters=(2, 1, 3))
-    generator = parallel()
-
-    self.assertEqual(next(generator), 0)   # (1,0,0)
-    self.assertEqual(next(generator), 10)  # (1,1,0)
-    self.assertEqual(next(generator), 20)  # (1,1,1)
-    self.assertEqual(next(generator), 1)   # (2,1,1)
-    self.assertEqual(next(generator), 21)  # (2,1,2)
-    self.assertEqual(next(generator), 22)  # (2,1,3)
-    self.assertEqual(next(generator), 2)   # (1,0,0)
-    self.assertEqual(next(generator), 11)  # (1,1,0)
-    self.assertEqual(next(generator), 23)  # (1,1,1)
-    self.assertEqual(next(generator), 3)   # (2,1,1)
-    self.assertEqual(next(generator), 24)  # (2,1,2)
-    self.assertEqual(next(generator), 25)  # (2,1,3)
-    self.assertEqual(next(generator), 4)   # (1,0,0)
-
-  def test_stack_parallel(self):
-    """Test of stacked parallel combinators."""
-    dataset1 = lambda: (i for i in range(10))
-    dataset2 = lambda: (i for i in range(10, 20))
-    dataset3 = lambda: (i for i in range(20, 30))
-    parallel_lev0 = data.Parallel([dataset1, dataset2])
-    parallel_lev1 = data.Parallel([parallel_lev0, dataset3])
-    generator = parallel_lev1()
-
-    self.assertEqual(next(generator), 0)
-    self.assertEqual(next(generator), 20)
-    self.assertEqual(next(generator), 10)
-    self.assertEqual(next(generator), 21)
-    self.assertEqual(next(generator), 1)
-    self.assertEqual(next(generator), 22)
-    self.assertEqual(next(generator), 11)
-    self.assertEqual(next(generator), 23)
-    self.assertEqual(next(generator), 2)
-    self.assertEqual(next(generator), 24)
-    self.assertEqual(next(generator), 12)
-
-  def test_parallel_with_zero_counters(self):
-    """Test of the parallel combinator with zero counters."""
-    dataset1 = lambda: (i for i in range(10))
-    dataset2 = lambda: (i for i in range(10, 20))
-    dataset3 = lambda: (i for i in range(20, 30))
-    parallel = data.Parallel([dataset1, dataset2, dataset3], counters=[1, 0, 1])
-    generator = parallel()
-
-    self.assertEqual(next(generator), 0)
-    self.assertEqual(next(generator), 20)
-    self.assertEqual(next(generator), 1)
-    self.assertEqual(next(generator), 21)
-    self.assertEqual(next(generator), 2)
-    self.assertEqual(next(generator), 22)
-    self.assertEqual(next(generator), 3)
-    self.assertEqual(next(generator), 23)
-
-  def test_serial_with_python(self):
dataset = lambda _: ((i, i+1) for i in range(10)) - batches = data.Serial( - dataset, - lambda g: map(lambda x: (x[0], x[1] + 1), g), - lambda g: filter(lambda x: x[0] % 2 == 1, g), - data.Batch(2) - ) - batch = next(batches()) - self.assertLen(batch, 2) - (xs, ys) = batch - # First tuple after filtering is (1, 3) = (1, 2+1). - self.assertEqual(xs[0], 1) - self.assertEqual(ys[0], 3) - # Second tuple after filtering is (3, 5). - self.assertEqual(xs[1], 3) - self.assertEqual(ys[1], 5) - - def test_pad_to_max_dims(self): - tensors1 = [np.zeros((3, 10)), np.ones((3, 10))] - padded1 = data.inputs.pad_to_max_dims(tensors1) - self.assertEqual(padded1.shape, (2, 3, 10)) - tensors2 = [np.zeros((2, 10)), np.ones((3, 9))] - padded2 = data.inputs.pad_to_max_dims(tensors2) - self.assertEqual(padded2.shape, (2, 3, 10)) - tensors3 = [np.zeros((8, 10)), np.ones((8, 9))] - padded3 = data.inputs.pad_to_max_dims(tensors3, 12) - self.assertEqual(padded3.shape, (2, 12, 12)) - tensors4 = [np.zeros((2, 10)), np.ones((3, 9))] - padded4 = data.inputs.pad_to_max_dims(tensors4, 12) - self.assertEqual(padded4.shape, (2, 4, 12)) - - def test_pad_to_length(self): - tensors1 = [(np.zeros((5)), np.ones((3)))] - pad_to_length_function1 = data.inputs.PadToLength(len_map={0: 10, - 1: 11}, - pad_value={0: 0, - 1: 1}) - padded1 = next(pad_to_length_function1(tensors1)) - self.assertEqual(padded1[0].shape, (10,)) - self.assertEqual(padded1[1].shape, (11,)) - - tensors2 = [(np.zeros((15)), np.ones((20)))] - pad_to_length_function2 = data.inputs.PadToLength(len_map={0: 10, - 1: 10}, - pad_value={0: 0, - 1: 1}, - multiple=True) - padded2 = next(pad_to_length_function2(tensors2)) - self.assertEqual(padded2[0].shape, (20,)) - self.assertEqual(padded2[1].shape, (20,)) - - def test_concatenate_lm_input(self): - tensors1 = [(np.zeros((5)), np.ones((3)))] - - lm_input_function1 = data.inputs.ConcatenateToLMInput(pad_to_length=10) - lm_input_1 = next(lm_input_function1(tensors1)) - self.assertEqual(lm_input_1[0].shape, (10,)) - self.assertEqual(lm_input_1[1].shape, (10,)) - self.assertEqual(lm_input_1[2].shape, (10,)) - self.assertEqual(lm_input_1[2].all(), - np.array([[0., 0., 0., 0., 0., - 1., 1., 1., 0., 0.]]).all()) - - tensors2 = [(np.zeros((5)), np.ones((3)))] - lm_input_function2 = data.inputs.ConcatenateToLMInput() - lm_input_2 = next(lm_input_function2(tensors2)) - self.assertEqual(lm_input_2[0].shape, (8,)) - self.assertEqual(lm_input_2[1].shape, (8,)) - self.assertEqual(lm_input_2[2].shape, (8,)) - self.assertEqual(lm_input_2[2].all(), - np.array([[0., 0., 0., 0., 0., - 1., 1., 1.]]).all()) - - def test_truncate_to_length_no_arg(self): - """Tests that a no-arg call leaves shapes unchanged.""" - def data_stream(): - while True: - yield (np.zeros((1, 5)), np.ones((1, 5))) - stream_fn = data.inputs.TruncateToLength() - y0, y1 = next(stream_fn(data_stream())) - self.assertEqual(y0.shape, (1, 5)) - self.assertEqual(y1.shape, (1, 5)) - - @parameterized.named_parameters( - ('none', None, ((1, 5), (1, 5))), - ('large_values', {0: (1, 77), 1: (1, 88)}, ((1, 5), (1, 5))), - ('small_values', {0: (1, 3), 1: (1, 2)}, ((1, 3), (1, 2))), - ) - def test_truncate_to_length_len_map(self, len_map, out_shapes): - """Tests that truncation occurs when len_map values are small enough.""" - def data_stream(): - while True: - yield (np.zeros((1, 5)), np.ones((1, 5))) - stream_fn = data.inputs.TruncateToLength(len_map=len_map) - y0, y1 = next(stream_fn(data_stream())) - self.assertEqual(y0.shape, out_shapes[0]) - self.assertEqual(y1.shape, 
out_shapes[1]) - - def test_truncate_to_length_questionable_behavior(self): - # Use of np.reshape in TruncateToLength allows non-truncation results - # without warning. As long as the target shape (len_map value) is - # lexicographically prior to the data shape, then np.reshape can happen, - # even if it results in *adding* values to the overall array. - # - # This test passes as a marker of the questionable behavior, and should - # *fail* -- and then be removed -- when the function is - # clarified/re-implemented. - # - # TODO(jonni): Determine desired behavior, and fit implementation to it. - x = np.arange(21).reshape((1, 21, 1)) - def data_stream(): - while True: - yield x - stream_fn = data.inputs.TruncateToLength(len_map={0: (1, 4, 6)}) - (y,) = next(stream_fn(data_stream())) - self.assertEqual(y.shape, (1, 4, 6)) - self.assertEqual(y[0, 3, 1], 19) - self.assertEqual(y[0, 3, 2], 20) # end of original values [0..20] - self.assertEqual(y[0, 3, 3], 0) # added value - self.assertEqual(y[0, 3, 4], 1) # added value - self.assertEqual(y[0, 3, 5], 2) # added value - - def test_filter_empty_examples(self): - tensors1 = [(np.zeros((0,)), np.ones((1, 5))), - (np.zeros((1, 5)), np.ones((1, 5)))] - - filter_empty_examples_function1 = data.inputs.FilterEmptyExamples() - filtered1 = next(filter_empty_examples_function1(tensors1)) - self.assertEqual(filtered1[0].shape, (1, 5)) - self.assertEqual(filtered1[1].shape, (1, 5)) - - filter_empty_examples_function2 = data.inputs.FilterEmptyExamples(axes=[1]) - filtered2 = next(filter_empty_examples_function2(tensors1)) - self.assertEqual(filtered2[0].shape, (0,)) - self.assertEqual(filtered2[1].shape, (1, 5)) - - def test_append_value(self): - tensors1 = [(np.zeros((1, 5)), np.ones((1, 5)))] - - append_value_function1 = data.inputs.AppendValue() - unmodified = next(append_value_function1(tensors1)) - self.assertEqual(unmodified[0].shape, (1, 5)) - self.assertEqual(unmodified[1].shape, (1, 5)) - - append_value_function2 = data.inputs.AppendValue({0: [[5]], - 1: [[4]]}) - appended = next(append_value_function2(tensors1)) - self.assertEqual(appended[0].shape, (1, 6)) - self.assertEqual(appended[0].all(), - np.array([[0., 0., 0., 0., 0., 5.]]).all()) - self.assertEqual(appended[1].shape, (1, 6)) - self.assertEqual(appended[1].all(), - np.array([[1., 1., 1., 1., 1., 4.]]).all()) - - def test_pad_to_max_dims_boundary_list(self): - tensors = [np.zeros((1, 15, 31)), np.ones((2, 10, 35)), np.ones((4, 2, 3))] - padded_tensors = data.inputs.pad_to_max_dims( - tensors, boundary=(None, 15, 20)) - # no boundary, only max in the first dim, 15 is already the max len in - # second dim, last dim padded to multiple of 20. - # The outer dim is the batch here. - self.assertEqual(padded_tensors.shape, (3, 4, 15, 40)) - - def test_pad_to_max_dims_strict_pad_on_len(self): - tensors = [np.ones((15,)), np.ones((12,)), np.ones((14,))] - padded_tensors = data.inputs.pad_to_max_dims( - tensors, boundary=10, strict_pad_on_len=True) - self.assertEqual(padded_tensors.shape, (3, 20)) - - def test_bucket_by_length(self): - def fake_generator(length, num_examples=1): - for _ in range(num_examples): - yield (np.ones((length,)), np.ones((length,))) - - def length_function(example): - return max(example[0].shape[0], example[1].shape[0]) - - batches = list(data.bucket_by_length(fake_generator(5, 6), - length_function, - [20], - [2], - strict_pad_on_len=True)) - - # We'll get three batches of 2 examples each. 
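    # 6 examples of length 5, one boundary (20) and batch size 2: each pair
    # is padded on dim 0 up to the 20 boundary (strict_pad_on_len), hence
    # the (2, 20) shapes asserted below.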
- self.assertLen(batches, 3) - self.assertIsInstance(batches[0], tuple) - self.assertLen(batches[0], 2) - self.assertEqual((2, 20), batches[0][0].shape) - self.assertEqual((2, 20), batches[0][1].shape) - - @parameterized.named_parameters( - ('encdec_on', True), - ('encdec_off', False), - ) - def test_addition_inputs_exceptions(self, encdec): - vocab_size = 5 - batch_size = 256 - seq_length = 64 - # Check if max/min lengths are validated for train stream - with self.assertRaises(ValueError): - inputs = data.inputs.addition_inputs( - vocab_size=vocab_size, - batch_size=batch_size, - train_length=2, - eval_min_length=1, - eval_max_length=seq_length, - pad_to_multiple=seq_length, - encdec=encdec) - train_stream = inputs.train_stream(n_devices=1) - for _ in range(10): - next(train_stream) - - # Check if max/min lengths are validated for eval stream - with self.assertRaises(ValueError): - inputs = data.inputs.addition_inputs( - vocab_size=vocab_size, - batch_size=batch_size, - train_length=seq_length, - eval_min_length=1, - eval_max_length=seq_length, - pad_to_multiple=seq_length, - encdec=True) - eval_stream = inputs.eval_stream(n_devices=1) - for _ in range(10): - next(eval_stream) - - def test_addition_inputs_constraints(self): - vocab_size = 5 - batch_size = 256 - seq_length = 64 - inputs = data.inputs.addition_inputs( - vocab_size=vocab_size, - batch_size=batch_size, - train_length=seq_length, - eval_min_length=seq_length, - eval_max_length=seq_length, - pad_to_multiple=seq_length, - encdec=True) - - # Check if max length is respected for train stream - train_stream = inputs.train_stream(n_devices=1) - for _ in range(10): - x, y, weights = next(train_stream) - self.assertEqual(x.shape[1], seq_length) - self.assertEqual(y.shape[1], seq_length) - self.assertEqual(weights.shape[1], seq_length) - - # Check if max length is respected for eval stream - eval_stream = inputs.eval_stream(n_devices=1) - for _ in range(10): - x, y, weights = next(eval_stream) - self.assertEqual(x.shape[1], seq_length) - self.assertEqual(y.shape[1], seq_length) - self.assertEqual(weights.shape[1], seq_length) - - def _get_span_lengths(self, x): - span_lengths = [] - curr_len = 0 - for i in range(1, len(x)): - # 1 -> 0 - if x[i] == 0 and x[i - 1] == 1: - span_lengths.append(curr_len) - curr_len = 0 - # 1 -> 1 or 0 -> 1 - elif ((x[i] == 1 and x[i - 1] == 1) or - (x[i] == 1 and x[i - 1] == 0)): - curr_len += 1 - if curr_len != 0: - span_lengths.append(curr_len) - return span_lengths - - def test_random_spans_noise_mask(self): - length = 100 - noise_density = 0.15 - mean_noise_span_length = 3.0 - - # Take 5 random seed1, seed2 values. - for seed in np.random.randint(0, 100, (5, 2)): - is_noise = data.random_spans_noise_mask(length, - noise_density, - mean_noise_span_length, - seed1=seed[0], - seed2=seed[1]) - is_noise = is_noise.astype(np.int32) - # noise_density fraction of tokens are produced - self.assertEqual(np.sum(is_noise), noise_density * length) - # Get span lengths and make sure the average is what we expect. 
- actual_span_lengths = self._get_span_lengths(is_noise) - average_span_length = ( - sum(actual_span_lengths) / len(actual_span_lengths)) - self.assertEqual(mean_noise_span_length, average_span_length) - - def test_process_c4_with_span_corruption(self): - def process_c4_with_span_corruption(spm_path=None, - extra_ids=0, - train=False, - max_length=100, - noise_density=0.15, - mean_noise_span_length=3.0, - seed1=None, - seed2=None): - return data.Serial( - data.TFDS( - 'c4/en:2.3.0', data_dir=_TESTDATA, keys=('text',), train=train), - data.SentencePieceTokenize(spm_path=spm_path, extra_ids=extra_ids), - data.generate_sequential_chunks(max_length=max_length), - data.generate_random_noise_mask( - noise_density=noise_density, - mean_noise_span_length=mean_noise_span_length, - seed1=seed1, seed2=seed2), - data.consume_noise_mask(vocab_size=32000 + extra_ids), - data.FilterEmptyExamples(), - data.AppendValue(val={0: [1], 1: [1]}), - data.PadToLength(len_map={0: 100, 1: 30}, pad_value={0: 0, 1: 0}), - data.AddLossWeights(id_to_mask=0), - data.Batch(batch_size=2) - ) - - gen = process_c4_with_span_corruption( - spm_path=_spm_path(), seed1=0, seed2=1) - - examples = [] - for i, ex in enumerate(gen()): - if i == 100: - break - examples.append(ex) - - self.assertLen(examples, 100) - example = examples[0] - - batched_input, batched_output, batched_loss_weights = example - - self.assertSequenceEqual( - batched_input.tolist(), - # pylint: disable=bad-continuation,bad-whitespace - [[ 37, 2335, 113, 3977, 227, 7306, 45, 3, 9, - 4716, 147, 8, 71, 2658, 65, 118, 4313, 38, - 3, 9, 13065, 32, 31999, 9, 5704, 26, 109, - 6, 6862, 6, 4728, 45, 8, 3796, 24093, 11834, - 4716, 30, 8, 1379, 13, 31998, 130, 718, 12, - 8, 24124, 1343, 300, 4357, 1714, 31997, 1373, 47, - 16487, 3168, 16, 321, 7943, 5, 3, 4868, 3856, - 5700, 75, 7, 200, 2231, 6, 11163, 9, 6, - 113, 47, 5330, 45, 14354, 6, 47, 31996, 20721, - 3654, 44, 8, 3112, 5, 14599, 11, 8067, 31995, - 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0], - [ 277, 828, 43, 5899, 46, 16, 10952, 139, 160, - 1687, 56, 539, 30, 2875, 41, 31122, 2307, 137, - 2702, 2780, 15, 7, 31999, 44, 8, 3112, 11, - 30, 569, 783, 5, 3, 17701, 6, 2194, 26, - 23, 1336, 6321, 1694, 30, 31998, 196, 56, 1852, - 1423, 25, 5, 27, 183, 8032, 31997, 217, 149, - 1513, 11, 2238, 25, 1800, 5, 96, 2703, 44, - 3065, 12537, 11163, 9, 535, 71, 9363, 14886, 646, - 44, 8, 3112, 243, 23281, 12, 8, 31996, 346, - 402, 17, 99, 83, 11, 773, 3668, 1280, 31995, - 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0]] - # pylint: enable=bad-continuation,bad-whitespace - ) - - self.assertSequenceEqual( - batched_output.tolist(), - # pylint: disable=bad-continuation,bad-whitespace - [[31999, 1639, 7, 15480, 5, 11163, 31998, 2083, 9997, - 5076, 31997, 265, 11, 8, 31996, 3, 31995, 1343, - 2487, 106, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0], - [31999, 12, 8, 15480, 130, 646, 31998, 1376, 10, - 96, 31997, 62, 410, 59, 31996, 96, 31995, 94, - 608, 10, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0]] - # pylint: enable=bad-continuation,bad-whitespace - ) - - self.assertSequenceEqual( - batched_loss_weights.tolist(), - # pylint: disable=bad-continuation,bad-whitespace - [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], - [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]] - # pylint: enable=bad-continuation,bad-whitespace - ) - - def test_prefix_lm_last_output_batch_is_short(self): - prefix_lm_fn = data.PrefixLM(input_length=2, 
output_length=3) - examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7, 8]])) - self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) - self.assertSequenceEqual(([6, 7], [8]), examples[1]) - self.assertLen(examples, 2) - - def test_prefix_lm_last_input_batch_is_short(self): - prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) - examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6]])) - self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) - self.assertLen(examples, 1) - - def test_prefix_lm_last_input_batch_exists_but_no_output(self): - prefix_lm_fn = data.PrefixLM(input_length=2, output_length=3) - examples = list(prefix_lm_fn([[1, 2, 3, 4, 5, 6, 7]])) - self.assertSequenceEqual(([1, 2], [3, 4, 5]), examples[0]) - self.assertLen(examples, 1) - - def test_unbatch(self): - unbatch_fn = data.UnBatch() - batched_inputs = [ - # First batch - 3 examples - (np.arange(3*2).reshape(3, -1), - np.arange(3*3).reshape(3, -1), - np.arange(3*4).reshape(3, -1)), - # Second batch - 4 examples - (np.arange(4*2).reshape(4, -1), - np.arange(4*3).reshape(4, -1), - np.arange(4*4).reshape(4, -1)), - ] - examples = list(unbatch_fn(batched_inputs)) - self.assertLen(examples, 3 + 4) - - def test_sine_shape(self): - inputs = data.sine_inputs(batch_size=3, length=5) - train_batch = next(inputs.train_stream(n_devices=1)) - eval_batch = next(inputs.eval_stream(n_devices=1)) - # (observations, actions, observations, mask) - self.assertLen(train_batch, 4) - self.assertLen(eval_batch, 4) - for (x, y) in zip(train_batch, eval_batch): - self.assertEqual(x.shape, (3, 5)) - self.assertEqual(y.shape, (3, 5)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/data/loader/__init__.py b/trax/data/loader/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/loader/tf/__init__.py b/trax/data/loader/tf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/loader/tf/base.py b/trax/data/loader/tf/base.py new file mode 100644 index 000000000..91904a86c --- /dev/null +++ b/trax/data/loader/tf/base.py @@ -0,0 +1,1560 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorFlow data sources and associated preprocessing functions.""" + +import functools +import itertools +import json +import os +import random +import re + +import gin +import jax +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from absl import logging + +from trax import fastmath +from trax.data.encoder.encoder import SentencePieceEncoder +from trax.data.preprocessing.tf.math import ( + convert_float_to_mathqa, + convert_to_subtract, +) + +# How many examples from the stream to skip at random during training. +# For now, we skip at most 100K examples for efficiency.
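+# (With this cap, each training replica begins reading its stream at an
+# independent random offset of up to _MAX_SKIP_EXAMPLES examples; see the
+# skip logic in _shuffle_data below.)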
+_MAX_SKIP_EXAMPLES = 100_000 + +_T2T_TO_TFDS_MAP = { + # Translation + "t2t_translate_ende_wmt32k": "wmt14_translate/de-en", + "t2t_wmt14_translate/de-en": "wmt14_translate/de-en", + + # Language modeling + "t2t_languagemodel_lm1b32k": "lm1b", + "t2t_languagemodel_ptb10k": "ptb_text_only", + + # Byte/text corpora + "t2t_enwik8_l2k": "enwik8", + "t2t_enwik8_l65k": "enwik8", + + # Sentiment/classification + "t2t_sentiment_imdb": "imdb_reviews", + + # Summarization + "t2t_summarize_cnn_dailymail32k": "cnn_dailymail", + + # Vision + "t2t_image_imagenet224": "imagenet2012", + "t2t_image_imagenet64_gen_flat_rev": "downsampled_imagenet/64x64", + + # Video + "t2t_video_bair_robot_pushing": "bair_robot_pushing_small", +} + + +def t5_data(): + """Get the T5 data module if available.""" + module = None + try: + import t5.data # pylint: disable=g-import-not-at-top + + module = t5.data + except ImportError as e: + # A missing t5 package surfaces as ImportError, not AttributeError. + logging.error("pip install t5") + raise e + return module + + +def random_split_text_tf(max_words_per_segment=512, text_key="text"): + """ + Returns a TFDS preprocessing function that chunks long text randomly. + """ + + def preprocess_fn(dataset): + def random_chunk(example): + text = example[text_key] + # Basic whitespace tokenizer (can be replaced with SentencePiece) + tokens = tf.strings.split([text]).values + length = tf.size(tokens) + + max_len = tf.minimum(length, max_words_per_segment) + start = tf.random.uniform( + shape=[], maxval=length - max_len + 1, dtype=tf.int32 + ) + chunk = tokens[start : start + max_len] + + # Rejoin into string or keep as tokens depending on downstream + example[text_key] = tf.strings.reduce_join(chunk, separator=" ") + return example + + return dataset.map(random_chunk, num_parallel_calls=tf.data.AUTOTUNE) + + return preprocess_fn + + +def _select_features(example, feature_list=None): + """Select a subset of features from the example dict.""" + feature_list = feature_list or ["inputs", "targets"] + return {f: example[f] for f in feature_list if f in example} + + +def next_sentence_prediction_tf( + text_key="text", label_sentences=True, buffer_size=50000 +): + """ + Returns a TFDS preprocessing function for NSP. + Each example must contain a text_key (e.g., 'text') with paragraph(s). + """ + + def preprocess_fn(dataset): + # First, buffer examples into memory + dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True) + + # Create a second shuffled dataset for random next sentences + other_dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True) + + # Zip datasets together + combined = tf.data.Dataset.zip((dataset, other_dataset)) + + def create_nsp_example(a, b): + # Get the raw text tensors + text_a = a[text_key] + text_b = b[text_key] + + # Helper to obtain first and second sentences robustly in TF2. + def first_two_sentences(text): + # Split on '. ' into a RaggedTensor, then densify a single row. + rt = tf.strings.split([text], sep=". ") + dense = rt.to_tensor(default_value="") # shape [1, N] + n = tf.shape(dense)[1] + first = tf.cond( + tf.greater(n, 0), + lambda: dense[0, 0], + lambda: tf.constant("", dtype=tf.string), + ) + second = tf.cond( + tf.greater(n, 1), + lambda: dense[0, 1], + lambda: first, + ) + return first, second + + first_sentence, a_second = first_two_sentences(text_a) + b_first, _ = first_two_sentences(text_b) + + # Random decision: use text from B or a subsequent sentence from A + use_random = tf.random.uniform(()) < 0.5 + second_sentence = tf.cond(use_random, lambda: b_first, lambda: a_second) + + # Format output + input_text = tf.strings.join( + ["sentence1: ", first_sentence, " sentence2: ", second_sentence] + ) + label = tf.where(use_random, "not_next", "next") + + return {"inputs": input_text, "targets": label} + + return combined.map(create_nsp_example) + + return preprocess_fn + + +def no_preprocess(dataset, training): + del training + return dataset + + +def download_and_prepare(dataset_name, data_dir): + """Downloads and prepares TFDS dataset, mapping from T2T if needed. + + Args: + dataset_name: tfds dataset or t2t problem name prefixed by 't2t_'. + data_dir: location of existing dataset or None. + + Returns: + data_dir: path string of downloaded data. + """ + # Translate legacy T2T dataset names to TFDS equivalents + if dataset_name in _T2T_TO_TFDS_MAP: + dataset_name = _T2T_TO_TFDS_MAP[dataset_name] + + if not data_dir: + data_dir = os.path.expanduser("~/tensorflow_datasets/") + dl_dir = os.path.join(data_dir, "download") + logging.info( + "No dataset directory provided. " + "Downloading and generating dataset for %s inside data directory %s. " + "For large datasets it is better to prepare datasets manually!", + dataset_name, + data_dir, + ) + + tf.io.gfile.makedirs(data_dir) + tf.io.gfile.makedirs(dl_dir) + # Download and prepare TFDS dataset. + tfds_builder = tfds.builder(dataset_name) + tfds_builder.download_and_prepare(download_dir=dl_dir) + else: + data_dir = os.path.expanduser(data_dir) + return data_dir + + +def dataset_to_stream(dataset, input_name): + """Takes a tf.Dataset and creates a numpy stream of ready batches.""" + # All input-pipeline processing should be on CPU. + for example in fastmath.dataset_as_numpy(dataset): + features = example[0] + + if not isinstance(features[input_name], np.ndarray): + inp = np.array(features[input_name]).reshape(1, -1) + else: + inp = features[input_name] + + if not isinstance(example[1], np.ndarray): + out = np.array(example[1]).reshape(1, -1) + else: + out = example[1] + + mask = features["mask"] if "mask" in features else None + # Some accelerators don't handle uint8 well, cast to int. + # (Check the array dtype; `isinstance(arr, np.uint8)` is never true + # for an ndarray.) + if inp.dtype == np.uint8: + inp = inp.astype(np.int32) + if out.dtype == np.uint8: + out = out.astype(np.int32) + yield (inp, out) if mask is None else (inp, out, mask) + + +@gin.configurable(module="trax.data") +def data_streams( + dataset_name, + data_dir=None, + preprocess_fn=no_preprocess, + bare_preprocess_fn=None, + shuffle_buffer_size=1024, + eval_holdout_size=0, + input_name=None, + target_name=None, +): + """Creates `(train, eval)` data sources from ``dataset_name``. + + Args: + dataset_name: Name of dataset belonging to TFDS or T2T. T2T dataset names + must start with ``'t2t_'``. + data_dir: Directory where the data is located. + preprocess_fn: Function to use for pre-processing after appending targets to + inputs.
+ bare_preprocess_fn: Function to use for pre-processing before appending + targets to inputs. + shuffle_buffer_size: Size of the shuffle buffer. + eval_holdout_size: If greater than 0, specifies a fraction of training data + to siphon off and use as eval data, in place of a separate eval split. + input_name: Name of the inputs from the dictionary. + target_name: Name of the outputs either from the dictionary or as a result + of post-processing. + + Returns: + A pair of functions, `(f, g)` for use as data sources; call `f()` to get an + iterator of training data samples, and call `g()` to get an iterator of eval + data samples. + """ + data_dir = download_and_prepare(dataset_name, data_dir) + + cache = [] + + def stream(which): + """Create the stream, cache TF streams if needed.""" + if not cache: + cache.append( + _train_and_eval_streams( + dataset_name, + data_dir, + preprocess_fn, + bare_preprocess_fn, + shuffle_buffer_size, + eval_holdout_size, + input_name, + target_name, + ) + ) + + (train_ds, eval_ds, input_name_c) = cache[0] + dataset = eval_ds if which == "eval" else train_ds + return dataset_to_stream(dataset, input_name_c) + + train_stream = lambda: stream("train") + eval_stream = lambda: stream("eval") + return train_stream, eval_stream + + +def load_translation_dataset( + dataset_name="wmt14_translate/de-en", + data_dir=None, + train_shuffle_files=True, + eval_shuffle_files=False, + input_key="en", + target_key="de", +): + """ + Loads translation dataset and prepares train/eval tf.data.Datasets with mapped (inputs, targets). + """ + data_dir = os.path.expanduser(data_dir or "~/tensorflow_datasets") + builder = tfds.builder(dataset_name, data_dir=data_dir) + builder.download_and_prepare() + + def _map_example(example): + return {"inputs": example[input_key], "targets": example[target_key]} + + # Load and preprocess splits + train_ds = tfds.load( + dataset_name, + split="train", + shuffle_files=train_shuffle_files, + data_dir=data_dir, + ).map(_map_example) + + eval_ds = tfds.load( + dataset_name, + split="validation", + shuffle_files=eval_shuffle_files, + data_dir=data_dir, + ).map(_map_example) + + supervised_keys = (["inputs"], ["targets"]) + + return train_ds, eval_ds, supervised_keys + + +def _train_and_eval_streams( + dataset, + data_dir, + preprocess_fn, + bare_preprocess_fn, + shuffle_buffer_size, + eval_holdout_size, + input_name, + target_name, +): + """Return train and eval batches with input name and shape.""" + (train_data, eval_data, keys) = _train_and_eval_dataset( + dataset, data_dir, eval_holdout_size + ) + # If provided select input_name/target_name else fall back to keys if that is + # available, else [None]. + input_names = ( + [input_name] + if input_name is not None + else keys[0] + if keys is not None + else [None] + ) + target_names = ( + [target_name] + if target_name is not None + else keys[1] + if keys is not None + else [None] + ) + + train_batches = _shuffle_data( + train_data, + target_names, + True, + shuffle_buffer_size, + preprocess_fn, + bare_preprocess_fn, + ) + eval_batches = _shuffle_data( + eval_data, + target_names, + False, + shuffle_buffer_size, + preprocess_fn, + bare_preprocess_fn, + ) + return (train_batches, eval_batches, input_names[0]) + + +def _train_and_eval_dataset( + dataset_name, + data_dir, + eval_holdout_size, + train_shuffle_files=True, + eval_shuffle_files=False, + use_alt_eval=False, + subsplit=None, + require_train_split=True, +): + """Return train and evaluation datasets and supervised keys. + + Args: + dataset_name: a string, the name of the dataset; if it starts with 't2t_' + then we'll search T2T Problem registry for it, otherwise we assume it is a + dataset from TFDS and load it from there. + data_dir: directory where the data is located. + eval_holdout_size: float from 0 to <1; if >0 use this much of training data + for evaluation (instead of looking for a pre-specified VALIDATION split). + train_shuffle_files: Boolean determining whether or not to shuffle the train + files at startup. Set to False if you want data determinism. + eval_shuffle_files: Boolean determining whether or not to shuffle the test + files at startup. Set to False if you want data determinism. + use_alt_eval: If True, use the dataset's alternate/secondary eval split; + else use the dataset's default/only eval split. Currently, only the + `glue/mnli` dataset provides an alternate eval split, and this arg is + ignored for other datasets. + subsplit: a pair of floats (x, y), both in [0, 1], saying which part of the + full training dataset we should return (default: all of it, [0, 1]). + require_train_split: if True, raise an error when the dataset provides no + train split. + + Returns: + a 3-tuple consisting of: + * the train tf.Dataset + * the eval tf.Dataset + * supervised_keys: information on what's the input and what's the target, + i.e., a pair of lists with input and target feature names. + """ + # Translate legacy T2T dataset names to TFDS equivalents early. + if dataset_name in _T2T_TO_TFDS_MAP: + dataset_name = _T2T_TO_TFDS_MAP[dataset_name] + logging.info("Building TF data pipeline for %s", dataset_name) + if dataset_name.startswith("t2t_"): + return _train_and_eval_dataset_v1( + dataset_name[4:], data_dir, train_shuffle_files, eval_shuffle_files + ) + dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) + info = dataset_builder.info + splits = dataset_builder.info.splits + has_train_split = tfds.Split.TRAIN in splits + + train_split = None + eval_split = None + + if dataset_name == "c4/multilingual": + train_split = "en" + has_train_split = True + elif has_train_split: + train_split = tfds.Split.TRAIN + elif require_train_split: + raise ValueError("To train we require a train split in the dataset.") + + if train_split is not None: + train_examples = info.splits[train_split].num_examples + eval_holdout_examples = int(train_examples * eval_holdout_size) + if eval_holdout_examples > 0 or subsplit is not None: + if subsplit is None: + subsplit = (0, 1) + n_train = train_examples - eval_holdout_examples + train_start = int(n_train * subsplit[0]) + train_end = int(n_train * subsplit[1]) + if train_end - train_start < 1: + raise ValueError( + "Requested train subsplit has no examples: " + "n_train %d subsplit %s" % (n_train, subsplit) + ) + # Eval holdout examples from the end of the training set. + if eval_holdout_examples > 0: + eval_split = f"{train_split}[-{eval_holdout_examples}:]" + # Shard the training set for this host.
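+            # (Worked example: with 1000 train examples, eval_holdout_size=0.1
+            # and subsplit=(0.5, 1.0), n_train is 900, so this host reads
+            # train[450:900] and the eval holdout is train[-100:].)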
+ train_split = f"{train_split}[{train_start}:{train_end}]" + + if dataset_name == "glue/mnli": + eval_split = "validation_mismatched" if use_alt_eval else "validation_matched" + elif dataset_name == "c4/multilingual": + eval_split = "en-validation" + elif eval_split is None: + if tfds.Split.VALIDATION not in splits and "test" not in splits: + raise ValueError("We require a validation or test split in the dataset.") + eval_split = tfds.Split.VALIDATION + if tfds.Split.VALIDATION not in splits: + eval_split = tfds.Split.TEST + + train = None + if train_split is not None: + train = tfds.load( + name=dataset_name, + split=train_split, + data_dir=data_dir, + shuffle_files=train_shuffle_files, + ) + valid = tfds.load( + name=dataset_name, + split=eval_split, + data_dir=data_dir, + shuffle_files=eval_shuffle_files, + ) + keys = None + if info.supervised_keys: + keys = ([info.supervised_keys[0]], [info.supervised_keys[1]]) + return train, valid, keys + + +def _train_and_eval_dataset_v1( + dataset_name="wmt14_translate/de-en", + data_dir=None, + train_shuffle_files=True, + eval_shuffle_files=False, +): + """Return train and evaluation datasets, feature info and supervised keys.""" + train_ds, eval_ds, supervised_keys = load_translation_dataset( + dataset_name=dataset_name, + data_dir=data_dir, + train_shuffle_files=train_shuffle_files, + eval_shuffle_files=eval_shuffle_files, + input_key="en", + target_key="de", + ) + + # You can take an example to determine input key if needed + examples = list(tfds.as_numpy(train_ds.take(1))) + input_key = "inputs" if "inputs" in examples[0] else "targets" + return train_ds, eval_ds, ([input_key], ["targets"]) + + +def _shuffle_data( + dataset, + target_names, + training, + shuffle_buffer_size, + preprocess_fn, + bare_preprocess_fn, +): + """Shuffle the given dataset and run pre-processing.""" + + def append_targets(example): + """Append targets to the example dictionary. Needed for Keras.""" + if len(target_names) == 1: + return (example, example[target_names[0]]) + targets = {} + for name in target_names: + targets[name] = example[name] + return (example, targets) + + # `bare_preprocess_fn` is called before appending targets etc. + if bare_preprocess_fn is not None: + dataset = bare_preprocess_fn(dataset, training) + dataset = dataset.map(append_targets) + # TODO(pkozakowski): Repeat both the training and evaluation set, so we don't + # have incomplete batches during evaluation. This will be a problem when we + # add an option to evaluate on the whole dataset, then we'll need to think of + # a different solution. + dataset = dataset.repeat() + if training: + # Skip a random fraction at the beginning of the stream. The skip is + # essential for synchronous highly-parallel training to avoid multiple + # replicas reading the same data in lock-step. + dataset = dataset.skip(random.randint(0, _MAX_SKIP_EXAMPLES)) + dataset = preprocess_fn(dataset, training) + dataset = dataset.shuffle(shuffle_buffer_size) + return dataset.prefetch(8) + + +@gin.configurable(module="trax.data") +def TFDS( # pylint: disable=invalid-name + dataset_name, + data_dir=None, + tfds_preprocess_fn=None, + keys=None, + train=True, + use_alt_eval=False, + shuffle_train=True, + host_id=None, + n_hosts=None, + eval_holdout_size=0, +): + """Creates a data source from TensorFlow dataset ``dataset_name``. + + Args: + dataset_name: Name of the dataset, as registered in TensorFlow datasets + (e.g., ``'glue/mnli'``). + data_dir: Directory where the data is located. 
+ tfds_preprocess_fn: If specified, function that applies to items in raw + dataset (before selecting specific features). + keys: Tuple of dataset-specific strings that select features from the + dataset. + train: If True, select the training split from the dataset; else select an + eval split. + use_alt_eval: If True, and if ``train`` is False, select the dataset's + alternate eval split if it has one (or fall back to the dataset's only + eval split). This currently affects only the `glue/mnli` dataset. + shuffle_train: If True, have TensorFlow pre-shuffle the training data; else + receive training data in deterministic sequence. + host_id: Integer id used for tracking data subsplits, in cases where + ``n_hosts`` > 1. + n_hosts: If greater than 1, prepare data subsplits for the given number of + hosts. + eval_holdout_size: If greater than 0, specifies a fraction of training data + to siphon off and use as eval data, in place of a separate eval split. + + Returns: + A function `f` for use as a training or eval data source; call `f()` to get + an iterator of data samples. + """ + data_dir = download_and_prepare(dataset_name, data_dir) + + # Try to query JAX multi-host info; fall back to single-host CPU if JAX + # backends (e.g., CUDA) are unavailable in the current environment. + try: + host_id = jax.process_index() if host_id is None else host_id + n_hosts = n_hosts or jax.process_count() + except Exception: + host_id = 0 if host_id is None else host_id + n_hosts = n_hosts or 1 + if n_hosts > 1: + subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts) + else: + subsplit = None + train_data, eval_data, _ = _train_and_eval_dataset( + dataset_name, + data_dir, + eval_holdout_size, + train_shuffle_files=shuffle_train, + use_alt_eval=use_alt_eval, + subsplit=subsplit, + require_train_split=train, + ) + if train and train_data is None: + raise ValueError( + f"Dataset {dataset_name} does not provide a train split for training." + ) + dataset = train_data if train else eval_data + dataset = dataset if tfds_preprocess_fn is None else tfds_preprocess_fn(dataset) + + def select_from(example): + return tuple(example[k] for k in keys) + + dataset = dataset.map(select_from) + dataset = dataset.repeat() + + def gen(generator=None): + del generator + for example in fastmath.dataset_as_numpy(dataset): + yield example + + return gen + + +@gin.configurable(module="trax.data") +def CorpusToRandomChunks( + dataset_name, num_tokens=512, train=True +): # pylint: disable=invalid-name + return TFDS( + dataset_name, + tfds_preprocess_fn=random_split_text_tf( + max_words_per_segment=num_tokens, + text_key="text", + ), + train=train, + keys=["text"], + ) + + +@gin.configurable(module="trax.data") +def CreateAquaInputs( # pylint: disable=invalid-name + dataset_path=None, + train=True, + cumulative=False, + rationale=False, + correct_answer=False, + correct_answer_given_reasoning=False, + partial_reasoning=True, + order_prediction=False, +): + """Prepares Aqua inputs. + + Args: + dataset_path: a path with the Aqua dataset. + train: if True, then generate training examples, otherwise generate + validation examples (the dataset has also a test set). + cumulative: if set to True, then generate examples in the format input - + problem + step1 + step2 + step3, target - step4. If set to False, then + examples are in the format input - problem, target - all operations. + rationale: if set to True, then input is the problem and the target is the + rationale.
+ correct_answer: if set to True, then input is the problem plus all possible + answers and the target is the correct answer. + correct_answer_given_reasoning: if set to True, then input is the problem + plus reasoning (aka rationale) plus all possible answers and the target is + the correct answer. + partial_reasoning: an additional option related to + correct_answer_given_reasoning; if set to True, then we take a random + prefix of the reasoning. + order_prediction: if set to True, then input is the problem and a list of + all operations; with probability 0.5 two operations are swapped; the task + consists in detecting whether the operations were swapped. A similar + additional task was considered in https://arxiv.org/pdf/1909.11942.pdf and + in a recent work of Piotr Piękos, henrykm@ and mateuszm@. + + Returns: + aqua_yield_examples: a generator of Aqua examples; the generator yields + non-tokenized examples - they can be further processed using for example + the tokenize function from this module + """ + if train: + dataset_path = os.path.join(dataset_path, "train.json") + else: + dataset_path = os.path.join(dataset_path, "dev.json") + # Opening with GFile allows using remotely stored files, e.g. + # in a gs bucket. + dataset_handle = tf.io.gfile.GFile(dataset_path, "r") + dataset = [] + for line in dataset_handle: + dataset.append(json.loads(line)) + + def aqua_yield_examples(generator=None): + del generator + while True: + for example in itertools.cycle(dataset): + input_prefix = example["question"] + steps = example["rationale"].split("\n") + if cumulative: + for i in range(len(steps)): + input_values = "infer cumulative rationale: " + input_prefix + target_values = steps[i] + input_prefix += " " + steps[i] + yield ( + input_values, + target_values, + np.array([1] * len(target_values)), + ) + elif rationale: + input_values = "infer full rationale: " + input_prefix + target_values = example["rationale"] + yield ( + input_values, + target_values, + np.array([1] * len(target_values)), + ) + elif correct_answer: + input_values = "infer correct answer: " + input_prefix + input_values += " " + " ".join(example["options"]) + target_values = example["correct"] + yield ( + input_values, + target_values, + np.array([1] * len(target_values)), + ) + elif correct_answer_given_reasoning: + input_values = ( + "infer correct answer given reasoning: " + input_prefix + ) + if partial_reasoning: + reasoning_list = example["rationale"].split("\n") + reasoning_list = reasoning_list[ + 0 : np.random.randint(0, len(reasoning_list)) + ] + reasoning = "\n".join(reasoning_list) + else: + reasoning = example["rationale"] + # Use the (possibly partial) reasoning computed above, so that + # partial_reasoning actually takes effect. + input_values += ( + " " + reasoning + " " + " ".join(example["options"]) + ) + target_values = example["correct"] + yield ( + input_values, + target_values, + np.array([1] * len(target_values)), + ) + elif order_prediction: + if np.random.uniform() < 0.5 and len(steps) >= 2: + idx = range(len(steps)) + i1, i2 = random.sample(idx, 2) + steps[i1], steps[i2] = steps[i2], steps[i1] + target_values = "not_ordered" + else: + target_values = "ordered" + input_values = ( + "order prediction: " + input_prefix + " " + "\n".join(steps) + ) + yield ( + input_values, + target_values, + np.array([1] * len(target_values)), + ) + else: + raise ValueError( + "One of the boolean parameters of the Aqua generator must be set to True."
+ ) + + return aqua_yield_examples + + +@gin.configurable(module="trax.data") +def CreateAnnotatedDropInputs( # pylint: disable=invalid-name + dataset_path=None, + train=True, + single_file=True, + unique=False, + total_number_of_samples=None, + percentile=1.0, +): + r"""Prepares annotated Drop inputs. + + Example of an annotated input which can be used with this interface: + + { + 'passage': 'The Armenian Prelature of Cyprus was established in 973 by + Catholicos Khatchig I. Historically, the Prelature has been under the + jurisdiction of the Catholicosate of the Great House of Cilicia, while today + it is the oldest theme that falls under its jurisdiction. Since 2014 the + Prelate, a Catholicosal Vicar General, has been Archbishop Nareg Alemezian. + The parish priest in Nicosia is Fr. Momik Habeshian, while the parish priest + in Larnaca and Limassol is Fr. Mashdots Ashkarian. For centuries, the + Prelature building was located within the Armenian compound in Victoria + street in walled Nicosia; when that area was taken over by Turkish-Cypriot + extremists in 1963-1964, the Prelature was temporarily housed in Aram + Ouzounian street and, later on, in Kyriakos Matsis street in Ayios + Dhometios. Thanks to the efforts of Bishop Zareh Aznavorian and with + financial aid from the Evangelical Church of Westphalia, the new Prelature + building was erected in 1983, next to the Virgin Mary church and the Nareg + school in Nicosia, by architects Athos Dikaios & Alkis Dikaios; it was + officially inaugurated on 4 March 1984, during the pastoral visit of + Catholicos Karekin II. By initiative of Archbishop Varoujan Hergelian, in + 1998 the basement of the building was renovated and the "Vahram Utidjian" + Hall was formed; previously a store room, it became a reality from the + proceeds of the auction in 1994 of the art collection that Vahram Utidjian + had donated to the Prelature in 1954. It was inaugurated on 3 February 1999 + by Catholicos Aram I; numerous charity, communal and cultural events take + place there. The Prelature\'s consistory houses a collection of + ecclesiastical relics, some of which were previously in the old Virgin Mary + church or the Magaravank.', + 'question': 'How many years after the Vahram Utidjian was donated to the + Prelature was it sold at an auction?', + 'answer': 40, + 'calculation': 'subtract(n8,n9)' + } + + In this example the calculation is formulated using the notation from the + MathQA dataset, but this is not required. subtract(n8,n9) means that the + answer 40 can be obtained through the subtraction of the 9th and the 10th + number in the input. The input consists of the passage concatenated with the + question. The annotations can be generated using, for example, a method + from the paper https://arxiv.org/abs/1909.00109. + + Args: + dataset_path: a path with the annotated Drop dataset. + train: if True, then generate training examples, otherwise generate + validation examples (the dataset has also a test set). + single_file: if True, then look just for one file. If False, read all + json files in a given directory and assume that each file contains one + example. Applied only to training data. + unique: if set to True, then the generator will provide at most one question + per passage. + total_number_of_samples: if set to a positive integer, then the total number + of unique samples will be bounded by total_number_of_samples. + percentile: the percentile of the train dataset used for training; default + set to 1., though setting to a lower value can be interesting when the + train set is combined with another source of data. + + Returns: + drop_annotated_yield_examples: a generator of annotated Drop examples; + the generator yields non-tokenized examples - they can be further processed + using for example the tokenize function from this module. + """ + if train: + if single_file: + dataset_path = os.path.join(dataset_path, "train_annotated.json") + else: + dataset_path = os.path.join(dataset_path, "dev_annotated.json") + + def load_dataset(): + dataset = [] + if single_file: + # Opening with GFile allows using remotely stored files, e.g. + # in a gs bucket. + dataset_handle = tf.io.gfile.GFile(dataset_path, "r") + for line in dataset_handle: + dataset.append(json.loads(line)) + else: + all_files = tf.io.gfile.listdir(dataset_path) + for filename in all_files: + if "json" in filename: + print("Loading data from file {}".format(filename)) + with tf.io.gfile.GFile(os.path.join(dataset_path, filename)) as f: + for line in f: + dataset.append(json.loads(line)) + print("The total size of the dataset {}".format(len(dataset))) + return dataset[: int(len(dataset) * percentile)] + + def drop_annotated_yield_examples(generator=None): + del generator + while True: + passages = set() + unique_examples = set() + # Notice that below we enable a poor man's RL loop + # aka the DAgger algorithm: https://arxiv.org/pdf/1011.0686.pdf + # tl;dr: after parsing all examples we re-load the dataset - this + # may come in handy if a prediction service generates new examples. + dataset = load_dataset() + for example in dataset: + # If total_number_of_samples is not None and we have reached this + # number of samples, then we re-load the dataset. + if total_number_of_samples: + if len(unique_examples) >= total_number_of_samples: + break + # Do we have a pre-calculated input in the example? + if "input" in example.keys(): + question = example["input"] + # Remove the old prompt + question = question[question.find(":") + 2 :] + else: + # If input is not present, then we expect that this is an + # original drop example. + if unique and example["passage"] in passages: + continue + passages.add(example["passage"]) + question = example["passage"] + " " + example["question"] + list_num = [ + float( + num.replace(",", "").rstrip(".").lstrip(".") + ) # pylint: disable=g-complex-comprehension + for num in re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", + question, + ) + ] + for i in range(len(list_num)): + question += " n{} = {}".format(i, list_num[i]) + input_values = "drop annotated question: " + question + target_values = example["calculation"] + unique_examples.add((input_values, target_values)) + yield ( + input_values, + target_values, + np.array([1] * len(target_values), dtype=np.int32), + ) + + return drop_annotated_yield_examples + + +@gin.configurable(module="trax.data") +def CreateDropInputs(train=True, mathqa_format=False): # pylint: disable=invalid-name + """Prepares Drop inputs. + + Args: + train: if True, then generate training examples, otherwise generate + validation examples (the dataset has also a test set). + mathqa_format: if True, then floats in targets are converted to the + MathQA convention and wrapped in the subtract operation. + E.g. "3.13" is converted to "subtract(const_3_13,const_0)".
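+      A minimal sketch of that branch (`convert_to_subtract` and
+      `convert_float_to_mathqa` are imported at the top of this module; the
+      output format shown is the one claimed above):
+
+        target = "3.13"
+        if target.replace(".", "", 1).isdigit():
+          target = convert_to_subtract(convert_float_to_mathqa(target))
+        # target is now "subtract(const_3_13,const_0)"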
+ + Returns: + drop_yield_examples: a generator of Drop examples; the generator yields + non-tokenized examples - they can be further processed using for example + the tokenize function from this module + """ + if train: + dataset = tfds.load(name="drop", split="train") + else: + dataset = tfds.load(name="drop", split="dev") + dataset = tfds.as_numpy(dataset) + + def drop_yield_examples(generator=None): + del generator + while True: + for example in itertools.cycle(dataset): + input_values = ( + "drop question: " + + example["passage"].decode("utf-8") + + " " + + example["question"].decode("utf-8") + ) + target_values = example["answer"].decode("utf-8") + # Apparently the dataset has some empty "target values" - + # when such a value is encountered, the Tokenizer decides to assign + # to it a float32 tensor and the training fails. + if not target_values: + continue + if mathqa_format: + if target_values.replace(".", "", 1).isdigit(): + target_values = convert_to_subtract( + convert_float_to_mathqa(target_values) + ) + yield input_values, target_values, np.array( + [1] * len(target_values), dtype=np.int32 + ) + + return drop_yield_examples + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def downsampled_imagenet_flatten_bare_preprocess(dataset, training): + """Preprocessing for downsampled_imagenet. + + Args: + dataset: the dataset. + training: unused option. + + Returns: + Flattened dataset. + + Preprocessing for downsampled_imagenet 32x32 and 64x64 generation from + http://arxiv.org/abs/1601.06759 (page 8). + """ + del training + + def flatten_image(features): + img = features["image"] + flat = tf.cast(tf.reshape(img, [-1]), tf.int64) + + new_features = {"image": flat} + return new_features + + return dataset.map(flatten_image) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def concat_preprocess(dataset, training, pad_symbol=0): + """Pre-processing function that concatenates input and target for LM.""" + del training + + def concat(features, targets): + inp = features["inputs"] + pad = tf.expand_dims(tf.zeros_like(inp[0]) + pad_symbol, axis=0) + concat = tf.concat([pad, inp, pad, targets], axis=0) + # Note: we're updating existing features dictionary here, so make sure + # it is not re-used in some other ways outside of this function. 
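+        # Resulting layout is [pad, inputs..., pad, targets...]; e.g. with
+        # pad_symbol=0, inputs=[3, 4] and targets=[5] the concatenated
+        # sequence is [0, 3, 4, 0, 5].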
+ features["inputs"] = concat + return features, concat + + dataset = dataset.map(concat) + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def squeeze_targets_preprocess(dataset, training): + """Pre-processing function that squeezes last axis of targets.""" + del training + + def squeeze(features, targets): + if targets.shape[-1] == 1: + targets = tf.squeeze(targets, axis=-1) + return features, targets + + dataset = dataset.map(squeeze) + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def lm1b_preprocess(dataset, training, max_target_length=-1, max_eval_target_length=-1): + """Preprocessing for LM1B: filter out targets exceeding maximum length.""" + + def target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_target_length + 1) + + def eval_target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_eval_target_length + 1) + + if max_target_length > 0 and training: + dataset = dataset.filter(target_right_length) + + if max_eval_target_length > 0 and not training: + dataset = dataset.filter(eval_target_right_length) + + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def lm_token_preprocessing(dataset, training): + """Concatenates inputs, 0, targets, with masking only for targets.""" + del training + + def concat_and_add_mask(x): + inp = x["inputs"] + targets = x["targets"] + pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0) + concat = tf.concat([inp, pad, targets], axis=0) + mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0) + x["inputs"] = concat + x["targets"] = concat + x["mask"] = mask + return x + + dataset = dataset.map(concat_and_add_mask) + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def bair_robot_pushing_preprocess(dataset, training): + """Pre-processing function that concatenates input and target frames.""" + del training + + def concat_and_add_mask(features, targets): + """Concatenate input and output frames to form a language modeling setup.""" + inp = features["inputs"] + concat = tf.concat([inp, targets], axis=0) + mask = tf.concat([tf.zeros_like(inp), tf.ones_like(targets)], axis=0) + concat = tf.reshape(concat, (-1,)) + mask = tf.reshape(mask, (-1,)) + concat = tf.cast(concat, tf.int32) + mask = tf.cast(mask, tf.float32) + features["inputs"] = features["targets"] = concat + features["mask"] = mask + return features, concat + + dataset = dataset.map(concat_and_add_mask) + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def filter_dataset_on_len(dataset, training, len_map=None, filter_on_eval=False): + """Filters a dataset of lengths given in `len_map`. + + Args: + dataset: `tf.data.Dataset` the dataset to filter. + training: bool, true if we are in training mode. + len_map: optional dict of str to (int, int). We filter examples where a + feature's size is beyond the specified bounds. Ex: + {'inputs': (1, 512), 'targets': (64, 128)} will keep only those examples + where 1 <= len(inputs) <= 512 and 64 <= len(targets) <= 128. + filter_on_eval: bool if true, we will filter in eval mode also. + + Returns: + a filtered `tf.data.Dataset`. 
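+
+    Example (a sketch; assumes a dataset of dicts of 1-D tensors):
+
+      ds = filter_dataset_on_len(
+          ds, training=True,
+          len_map={'inputs': (1, 512), 'targets': (64, 128)})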
+ """ + if (len_map is None) or (not training and not filter_on_eval): + return dataset + + assert isinstance(len_map, dict) + for k, bounds in len_map.items(): + # pylint: disable=cell-var-from-loop + # TODO(afrozm): Investigate `cell-var-from-loop` - since this is WAI and + # there is a test too. + def within_bounds(x, key, len_bounds): + size = tf.shape(x[key])[0] + min_len, max_len = len_bounds + return (min_len <= size) and (size <= max_len) + + dataset = dataset.filter(lambda x: within_bounds(x, k, bounds)) + # pylint: enable=cell-var-from-loop + + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def truncate_dataset_on_len(dataset, training, len_map=None, truncate_on_eval=False): + """Truncates features in an example to lengths given in `len_map`. + + Args: + dataset: `tf.data.Dataset` the dataset to filter. + training: bool, true if we are in training mode. + len_map: optional dict of str to int, we truncate examples where a feature's + size is beyond the max. Ex: {'inputs': 512, 'targets': 64} will truncate + examples to be within those bounds. + truncate_on_eval: bool if true, we will truncate in eval mode also. + + Returns: + a filtered `tf.data.Dataset`. + """ + if (len_map is None) or (not training and not truncate_on_eval): + return dataset + + assert isinstance(len_map, dict) + + def truncate_example(x): + for key, max_len in len_map.items(): + x_len = tf.shape(x[key])[0] + if x_len > max_len: + x[key] = x[key][:max_len, ...] + return x + + return dataset.map(truncate_example) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def pad_dataset_to_length(dataset, training, len_map=None): + """Pad features less than specified length to specified length.""" + del training + if len_map is None: + return dataset + + def pad_to_len(x): + for key, max_len in len_map.items(): + x_shape = tf.shape(x[key]) + x_len = x_shape[0] + if x_len < max_len: + pad_shape = [ + max_len - x_len, + ] + zeros = tf.zeros(pad_shape, dtype=x[key].dtype) + x[key] = tf.concat([x[key], zeros], 0) + return x + + return dataset.map(pad_to_len) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def add_eos_to_output_features(dataset, training, output_features="targets", eos=1): + """Adds `EOS` to all features in `output_features`.""" + del training + if not isinstance(output_features, (list, tuple)): + output_features = [output_features] + + def add_eos(x): + for output_feature in output_features: + x[output_feature] = tf.concat([x[output_feature], [eos]], axis=0) + return x + + return dataset.map(add_eos) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def select_random_chunk_t5( + dataset, training, sequence_length=None, output_features=None +): + """Select a random chunk from the input tokens.""" + del training + + def select_chunk(features): + if sequence_length is None: + return features + + tokens = features["inputs"] + seq_len = tf.shape(tokens)[0] + + max_start = tf.maximum(seq_len - sequence_length, 0) + start_index = tf.random.uniform( + [], minval=0, maxval=max_start + 1, dtype=tf.int32 + ) + + chunk = tokens[start_index : start_index + sequence_length] + + features["inputs"] = chunk + features["targets"] = chunk + + return features + + return dataset.map(select_chunk, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def split_tokens_t5(dataset, training, sequence_length=None, 
output_features=None): + """Split tokens into two parts.""" + del training + + def split(features): + if sequence_length is None: + return features + + tokens = features["inputs"] + seq_len = tf.shape(tokens)[0] + + split_point = seq_len // 2 + + features["inputs"] = tokens[:split_point] + features["targets"] = tokens[split_point:] + + return features + + return dataset.map(split, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def denoise_t5( + dataset, training, sequence_length=None, output_features=None, noise_density=0.15 +): + """Apply denoising to the tokens.""" + del training + + def apply_noise(features): + if sequence_length is None: + return features + + tokens = features["inputs"] + + mask = tf.random.uniform(tf.shape(tokens), minval=0, maxval=1) < noise_density + noisy_tokens = tf.where(mask, tf.zeros_like(tokens), tokens) + + features["inputs"] = noisy_tokens + features["targets"] = tokens + + return features + + return dataset.map(apply_noise, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +def _pad_punctuation(text): + """Adds spaces around punctuation.""" + # Add space around punctuation. + text = tf.strings.regex_replace(text, r"([[:punct:]])", r" \1 ") + # Collapse consecutive whitespace into one space. + text = tf.strings.regex_replace(text, r"\s+", " ") + return text + + +def _string_join(lst): + # Join on space, but collapse consecutive spaces. + out = tf.strings.join(lst, separator=" ") + return tf.strings.regex_replace(out, r"\s+", " ") + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def squad_t5(dataset, training, include_context=True): + """Convert SQuAD examples to a text2text pair. + + SQuAD produces examples with this form: + {'id': <id>, 'context': <article>, 'question': <question>, + 'answers': {'text': [<n answers>]}} + This function will return examples of the format: + {'inputs': 'question: <question> context: <article>', + 'targets': '<answer_0>', + 'id': <id>, 'question': <question>, 'context': <article>, + 'answers': [<n answers>]}, + + Args: + dataset: a `tf.data.Dataset` of SQuAD examples to process. + training: unused; present for pipeline signature compatibility. + include_context: a boolean; if False, the context is dropped from the + inputs and the question is prefixed with 'squad trivia question:'. + + Returns: + A preprocessed `tf.data.Dataset` with examples in the format listed above. + """ + del training + + def squad(x): + a = _pad_punctuation(x["answers"]["text"]) + q = _pad_punctuation(x["question"]) + c = _pad_punctuation(x["context"]) + if include_context: + inputs = _string_join(["question:", q, "context:", c]) + else: + inputs = _string_join(["squad trivia question:", q]) + return { + "inputs": inputs, + "targets": a[0], + "id": x["id"], + "context": c, + "question": q, + "answers": a, + } + + return dataset.map(squad, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def rekey_t5(dataset, training, key_map=None): + """Replace the feature keys according to the mapping in `key_map`. + + For example, if the dataset returns examples of the format: + {'foo': 'something', 'bar': 'something else'} + and key_map = {'boo': 'foo', 'spar': 'bar'} then this function will return + examples with the format + {'boo': 'something', 'spar': 'something else'} + + If a mapping is to an empty key name or None, the new value is set to an empty + string. + + Args: + dataset: a `tf.data.Dataset` of examples to process. + training: unused; present for pipeline signature compatibility. + key_map: dictionary mapping new keys to original keys + + Returns: + A `tf.data.Dataset` with examples rekeyed as listed above. + """ + + del training + + def rekey(x): + if key_map: + return { + new_key: x[old_key] if old_key else "" + for new_key, old_key in key_map.items() + } + return x + + return dataset.map(rekey, num_parallel_calls=tf.data.experimental.AUTOTUNE) + + +_PREPROCESSOR_REGISTRY = { + "next_sentence_prediction_tf": next_sentence_prediction_tf, + "random_split_text_tf": random_split_text_tf, + "select_random_chunk_t5": select_random_chunk_t5, + "split_tokens_t5": split_tokens_t5, + "denoise_t5": denoise_t5, + "squad_t5": squad_t5, + "rekey_t5": rekey_t5, +} + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def unsupervised_preprocessors( + dataset, training, sequence_length=None, output_features=None, preprocessors=None +): + """ + Apply a series of unsupervised preprocessors. + + Args: + dataset: Input TensorFlow dataset + training: unused; the sub-preprocessors are invoked with training=None + sequence_length: Maximum sequence length + output_features: Optional output features dictionary + preprocessors: List of preprocessing functions to apply + + Returns: + Preprocessed dataset + """ + del training + + if preprocessors is None: + return dataset + + for preprocessor in preprocessors: + dataset = preprocessor( + dataset, + None, + sequence_length=sequence_length, + output_features=output_features, + ) + + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def generic_text_dataset_preprocess_fn( + dataset, + training=True, + text_preprocess_fns=None, + token_preprocess_fns=None, + spm_path=None, + copy_pretokenized=False, + debug_print_examples=False, + debug_print_examples_rate=0.01, +): + """Pre-processes, tokenizes and post-processes a `tf.data.Dataset`. + + Args: + dataset: `tf.data.Dataset` to process. + training: boolean, set to True if training, False otherwise. + text_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> + `tf.data.Dataset`; this operates before tokenization. Typically used to + select which fields we want to learn over or change something into "text + to text" form.
+ token_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> + `tf.data.Dataset`, this operates after tokenization. Since this can view + the tokenized fields, this can be used to filter on length etc. + spm_path: None or str, path to a sentencepiece model to use for tokenization + by default uses the 32k vocabulary from T5. + copy_pretokenized: bool, if True retains the original fields after + tokenization. + debug_print_examples: bool, if True this prints examples to the logging + stream for inspection, both before and after tokenization. + debug_print_examples_rate: float, [0, 1.0], on average this fraction of + dataset examples will be printed out in each phase i.e. pre and post + tokenization. + + Returns: + a `tf.data.Dataset` with all the preprocessing and tokenization performed. + """ + + # The assumption is that `text_preprocess_fns` finally gives us a dataset + # which has `inputs` and `targets`. + if text_preprocess_fns is not None: + for text_preprocess_fn in text_preprocess_fns: + dataset = text_preprocess_fn(dataset, training) + + # Print debugging examples if needed before tokenization. + if debug_print_examples: + + def print_examples(x): + if np.random.uniform() < debug_print_examples_rate: + tf.print(x, output_stream=logging.info) + return x + + dataset = dataset.map(print_examples) + + # Vocabulary for tokenization. + tokenizer = SentencePieceEncoder(spm_path) + + # Tokenize the inputs and targets. + def tokenize_fields(example): + inputs = example.get("inputs", example["targets"]) + targets = example["targets"] + + tokenized_inputs = tf.cast(tokenizer.encode(inputs), tf.int64) + tokenized_targets = tf.cast(tokenizer.encode(targets), tf.int64) + + new_example = { + "inputs": tokenized_inputs, + "targets": tokenized_targets, + } + if copy_pretokenized: + new_example["inputs_pretokenized"] = inputs + new_example["targets_pretokenized"] = targets + + return new_example + + dataset = dataset.map(tokenize_fields) + + # Apply the token-preprocessors. + if token_preprocess_fns is not None: + for token_preprocess_fn in token_preprocess_fns: + dataset = token_preprocess_fn(dataset, training) + + if debug_print_examples: + + def print_examples_and_shapes(x): + if np.random.uniform() < debug_print_examples_rate: + tf.print( + "inputs_shape:", + tf.size(x["inputs"]), + "targets_shape:", + tf.size(x["targets"]), + "inputs:", + x["inputs"], + "targets:", + x["targets"], + output_stream=logging.info, # or use a custom stream that writes to logging.info + ) + return x + + dataset = dataset.map(print_examples_and_shapes) + + return dataset + + +@gin.configurable(module="trax.data") +def get_t5_preprocessor_by_name(name=None, fn_kwargs=None): + """Returns a closure of any T5 preprocessor function with its arguments. + + The main use-case is to use this (with gin scopes) to make any preprocessor + function available in a gin file to configure and use. + + See: `TFInputs.test_gin_configurable_preprocessors` + + Args: + name: str, name of the preprocessor function to configure. + fn_kwargs: optional dictionary, the arguments to configure, these will be + partially applied to the function given by `name`. + + Returns: + a closure of the preprocessor function along with its arguments, this + function takes two arguments only, dataset and boolean training and ignores + the training and calls the t5 processor with the dataset (and closed over + arguments only). 
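+
+    Example (a sketch of gin bindings; the exact values shown are assumptions,
+    not shipped configs):
+
+      get_t5_preprocessor_by_name.name = 'rekey_t5'
+      get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other',
+                                                           'targets': 'text'}}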
+ """ + + if name is None or name not in _PREPROCESSOR_REGISTRY: + raise ValueError(f"Unknown or missing preprocessor name: '{name}'.") + + fn = _PREPROCESSOR_REGISTRY[name] + if fn_kwargs: + fn = functools.partial(fn, **fn_kwargs) + + # Ensure compatibility with trax data preprocessing signature + return lambda ds, training: fn(ds, training) diff --git a/trax/data/preprocessing/__init__.py b/trax/data/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/data/preprocessing/inputs.py b/trax/data/preprocessing/inputs.py new file mode 100644 index 000000000..f98df5586 --- /dev/null +++ b/trax/data/preprocessing/inputs.py @@ -0,0 +1,2154 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data sources and input processing. + +Trax authors recommend constructing input pipelines using layer-like functions +and combinators. For example, following is an input pipeline for training +sentiment analysis tasks on the IMDB dataset:: + + from trax import data + + inputs = data.Serial( + data.TFDS('imdb_reviews', keys=('text', 'label'), train=True), + data.Tokenize(vocab_file='en_8k.subword', keys=[0]), + data.Shuffle(), + data.FilterByLength(max_length=2048, length_keys=[0]), + data.BucketByLength(boundaries=[ 32, 128, 512, 2048], + batch_sizes=[128, 32, 8, 2, 1], + length_keys=[0]), + data.AddLossWeights() + ) + +Each of these functions creates a Python generator of tuples of data arrays. +For example:: + + data.TFDS('imdb_reviews', keys=('text', 'label'), train=True), + +creates a generator of examples (tuples of NumPy :py:class:`ndarray` objects) +from the TFDS imdb_reviews dataset, see here: +https://www.tensorflow.org/datasets/catalog/imdb_reviews + +As you can see on the website above, this dataset has 'text' and 'label' fields +and we create tuples containing the text and the label from the training split +by specifying keys=('text', 'label'), train=True. + +Other functions, like ``Tokenize`` and ``Shuffle``, take a generator and output +another generator, in this way converting tuples into other tuples or mixing +the training stream. For example, ``Tokenize(..., keys=[0])`` tokenizes the +first element of a tuple -- converting it from text to a NumPy integer array. +And ``Shuffle`` randomizes the order of examples. + +Note that all elements in the data pipeline are just functions on generators, +so you can use Python's `map` and `filter` and other native functions too. +For example, you can create an input pipeline for a language model reading +lines from `my_file.txt` as follows:: + + inputs = data.Serial( + lambda _: open('my_file.txt'), + lambda g: map(lambda line: line.strip(), g), + data.Tokenize(vocab_file='en_8k.subword'), + lambda g: filter(lambda x: x.shape[0] < 513, g), # At most 512 tokens. + data.Shuffle(), + lambda g: map(lambda x: (x, x)), # Language models have inputs = targets. 
+      data.BucketByLength(boundaries=[ 32, 64, 128, 256, 512],
+                          batch_sizes=[ 32, 16,   8,   4,   2, 1]),
+      data.AddLossWeights(id_to_mask=0)
+    )
+
+"""
+
+import math
+import multiprocessing.dummy as mp  # using threads for now
+import os
+import pickle
+import random
+import time
+
+from typing import Optional, Sequence, Union
+
+import gin
+import jax
+import numpy as np
+import tensorflow as tf
+
+from absl import logging
+
+from trax import fastmath
+from trax.data.debugger import data_pipeline as debug_data_pipeline
+from trax.fastmath import numpy as jnp
+from trax.utils import shapes
+
+
+def Serial(*fns):  # pylint: disable=invalid-name
+    """Combines generator functions into one that runs them serially."""
+
+    def composed_fns(generator=None):
+        for f in fastmath.tree_flatten(fns):
+            generator = f(generator)
+        return generator
+
+    return composed_fns
+
+
+# TODO(jonni): Rename to Blend/Merge/Mix/Interleave/...?
+def Parallel(  # pylint: disable=invalid-name
+    fns=None,
+    counters=None,
+    reweight_by_minimum=False,
+    gradually_reweight=False,
+    use_remainders=False,
+):
+    """Combines generator functions into one that runs them in parallel.
+
+    Args:
+      fns: a sequence of datasets which are combined in parallel.
+      counters: a sequence of ints with the same length as fns; please see the
+        comments on its use below.
+      reweight_by_minimum: if set to True, then we re-weight every counter by
+        the minimal counter. E.g. counters (10000, 100000) are translated to
+        (1, 10) and hence for every 10 examples from the second dataset we are
+        getting 1 example from the first dataset. Without reweighting first we
+        would see 20 examples from the first and second dataset and then 90
+        thousand examples only from the first dataset.
+      gradually_reweight: if set to True, then we loop through the generators
+        using a recursive rule defined in emit_examples. First we sort
+        generators by the counters. If we have datasets with counters 1, 20, 40
+        (after sorting) then we yield examples (a(b c^2)^20)^*, where examples
+        of type a come from the first dataset, of type b from the second and of
+        type c from the third. The exponents are obtained through divisions of
+        subsequent counters.
+      use_remainders: if set to True, and gradually_reweight is set to True as
+        well, and counters are 1, 20, 45, then after dealing with all examples
+        in the format (a(b c^2)^20)^*, the generator yields the remaining 5
+        examples from the dataset with counter 45.
+
+    Returns:
+      parallel_generator: the generator yields samples according to the given
+        counters; if counters are not given then samples are generated
+        uniformly.
+
+    Example 1:
+
+      gen = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3))
+
+    defines a generator that yields 33% examples from dataset1, 16% examples
+    from dataset2 and 50% examples from dataset3.
+
+    Example 2:
+
+      gen = data.Parallel([dataset1, dataset2, dataset3], counters=(20, 50, 30))
+
+    defines a generator that yields 20% examples from dataset1, 50% examples
+    from dataset2 and 30% examples from dataset3.
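+
+    Example 3 (editor's illustration of reweight_by_minimum):
+
+      gen = data.Parallel([dataset1, dataset2], counters=(10, 100),
+                          reweight_by_minimum=True)
+
+    rescales the counters to (1, 10), so for every 10 examples taken from
+    dataset2 the generator yields 1 example from dataset1.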
+ """ + + if counters: + assert len(counters) == len(fns) + # Remove generators with zero counters + counters = list(counters) + fns = list(fns) + non_zeros = [j for j in range(len(counters)) if counters[j] != 0] + counters = [counters[j] for j in non_zeros] + fns = [fns[j] for j in non_zeros] + else: + counters = [1] * len(fns) + + if reweight_by_minimum: + counters = [math.floor(counter / min(counters)) for counter in counters] + + def emit_examples(sorted_counters_with_gens, prev_counter): + if sorted_counters_with_gens: + _, counter, generator = sorted_counters_with_gens[0] + repeats = math.floor(counter / prev_counter) + for _ in range(repeats): + yield next(generator) + yield from emit_examples(sorted_counters_with_gens[1:], counter) + + def parallel_generator(gen=None): + # If gradually_reweight is set to False then + # current_counters are increased step by step; they are reset to 0s when + # current_counters[idx] == counters[idx] for all idx. See + # test_parallel_with_weights_three_datasets for an example of how + # current_counters are changed during computation. + # If gradually_reweight is set to False then we loop using a + # recursive rule defined in emit_examples. + + generators = [] + for f in fns: + if gen: + generators.append(f(gen)) + else: + # This handles the case when the function f cannot be + # called on None. + generators.append(f()) + + if gradually_reweight: + counters_with_gens = zip(range(len(generators)), counters, generators) + sorted_counters_with_gens = sorted(counters_with_gens, key=lambda x: x[1]) + while True: + yield from emit_examples(sorted_counters_with_gens, min(counters)) + if use_remainders: + # Below we are dealing with remainders. + fractions = [] + for i in range(len(sorted_counters_with_gens)): + _, counter, generator = sorted_counters_with_gens[i] + processed = 1 + for fraction in fractions: + processed *= fraction + remainder = counter - processed + for _ in range(remainder): + yield next(generator) + if i < len(sorted_counters_with_gens) - 1: + _, next_counter, _ = sorted_counters_with_gens[i + 1] + fractions.append(math.floor(next_counter / counter)) + else: + current_counters = [0] * len(generators) + while True: + for idx, generator in enumerate(generators): + if current_counters[idx] < counters[idx]: + current_counters[idx] += 1 + # instead of checking current_counters[idx] == counters[idx] for + # all idx, we check the equivalent condition: + if sum(current_counters) == sum(counters): + current_counters = [0] * len(generators) + yield next(generator) + + return parallel_generator + + +@gin.configurable(module="trax.data") +def Shuffle(queue_size=1024): # pylint: disable=invalid-name + """Returns a shuffle function with the given queue size.""" + return lambda g: shuffle(g, queue_size) + + +@gin.configurable(module="trax.data") +def Batch(batch_size): # pylint: disable=invalid-name + """Returns a batching function with given batch size.""" + return lambda g: batch(g, batch_size) + + +@gin.configurable(module="trax.data") +def Dup(): # pylint: disable=invalid-name + """Duplicates (copies) the top element (inputs). + + The generator stream is augmented in the following way: + + - If the stream consists of a single element `(inputs, )`, + the inputs simply get copied to `(inputs, inputs)`. + - If the stream consists of multiple elements, for example + `(inputs, weights)`, the rest of elements get moved toward + the right side `(inputs, inputs, weights)`. + + Returns: + the duplicating function. 
+ """ + + def _copy(xs): + x, *rest = xs + return (x, x, *rest) + + return lambda g: map(lambda x: _copy(x), g) # pylint: disable=unnecessary-lambda + + +@gin.configurable(module="trax.data") +def FilterEmptyExamples(axes=None, debug=False): # pylint: disable=invalid-name + """Filters empty examples. + + Filters any example that has an array of size (0,) (if axes=None). + Alternatively, checks only axes provided in `axes' list. Contrary to + FilterByLength used with several elements with length_axis, here the example + would be filtered if ANY of the dimensions listed in `axes' contains an empty + array. + + Args: + axes: list of indices to check, if None, all of them. + debug: If true, emits a log everytime we filter out an empty example. + + Returns: + Function filtering empty examples. + """ + + def _filter_examples(generator): + for example in generator: + correct = True + for i, unused_tuple_element in enumerate(example): + if axes is None or i in axes: + if example[i].shape == (0,): + correct = False + break + if correct: + yield example + elif debug: + logging.info("Filtered example: %r", example) + + return _filter_examples + + +@gin.configurable(module="trax.data") +def FilterByLength( + max_length, + min_length=0, # pylint: disable=invalid-name + length_keys=None, + length_axis=0, +): + """Returns a function that filters out examples by length. + + Args: + max_length: int. If not None, indicates maximum length. + min_length: int. If not None, indicates minimum length. + length_keys: (list) which example keys to take into account. + length_axis: which shape axis to take into account. + Returns: + a function that filters out examples by length. + """ + + assert max_length is not None or min_length is not None + length_keys = length_keys or [0, 1] + length_fn = lambda x: _length_fn(x, length_axis, length_keys) + + def filtered(gen): + for example in gen: + example_len = length_fn(example) + + # Checking max length boundary. + if max_length is not None: + if example_len > max_length: + continue + # Checking min length boundary. + if min_length is not None: + if example_len < min_length: + continue + # Within bounds. + yield example + + return filtered + + +@gin.configurable(module="trax.data") +def TruncateToLength(len_map=None): # pylint: disable=invalid-name + """Returns a stream function that resizes items as specified by ``len_map``. + + Args: + len_map: Dictionary that specifies maximum shapes for potentially multiple + features per stream item. For example, given a stream of tokenized + string pairs, one could enforce a maximum length of 256 tokens for each + string by using ``len_map={0: (256,), 1: (256,)}``. + """ + + @debug_data_pipeline.debug_pipeline + def _truncate_to_length(generator): + for example in generator: + if isinstance(example, np.ndarray): + example = (example,) + if isinstance(example, (list, tuple)): + example = list(example) + if len_map is not None: + for key, max_len in len_map.items(): + example_len = example[key].shape + if example_len > max_len: + example[key] = np.resize(example[key], max_len) + output = tuple(example) + else: + output = None + raise ValueError(f"Unknown example type: {example}") + yield output + + return _truncate_to_length + + +@gin.configurable(module="trax.data") +def PadToLength( # pylint: disable=invalid-name + len_map=None, pad_value=0, multiple=False +): + """Pads the values to lengths given in `len_map'. + + len_map contains a dictionary of example keys to dimension sizes. 
+
+    Args:
+      len_map: dict of int to int, we pad examples to lengths given by the
+        values of the dict. If multiple is True, the dimensions are padded to
+        a multiple of this value.
+      pad_value: dict of int to int. The value gets applied as constant_values
+        on numpy.pad per given dimension.
+      multiple: boolean. If False, pads to the value of len_map. If True, pads
+        to the closest multiple of the value of len_map.
+
+    Returns:
+      Function to pad examples to given lengths.
+    """
+
+    @debug_data_pipeline.debug_pipeline
+    def _pad_to_length(generator):
+        for example in generator:
+            if isinstance(example, (list, tuple)):
+                example = list(example)
+                for key, value in len_map.items():
+                    array_length = example[key].shape[0]
+                    if multiple:
+                        padding_len = array_length - ((array_length // value) * value)
+                    else:
+                        padding_len = max([0, value - example[key].shape[0]])
+                    example[key] = np.pad(
+                        example[key],
+                        pad_width=(0, padding_len),
+                        mode="constant",
+                        constant_values=pad_value[key],
+                    )
+                output = tuple(example)
+            else:
+                if not isinstance(example, np.ndarray):
+                    raise ValueError(f"example isn't nparray, but should be: {example}")
+                array_length = example.shape[0]
+                if multiple:
+                    padding_len = array_length - (
+                        (array_length // len_map[0]) * len_map[0]
+                    )
+                else:
+                    padding_len = max(0, len_map[0] - array_length)
+                output = np.pad(
+                    example,
+                    pad_width=(0, padding_len),
+                    mode="constant",
+                    constant_values=pad_value[0],
+                )
+            yield output
+
+    if len_map is None:
+        raise ValueError("len_map parameter should be provided.")
+    return _pad_to_length
+
+
+@gin.configurable(module="trax.data")
+def BucketByLength(
+    boundaries,
+    batch_sizes,  # pylint: disable=invalid-name
+    length_keys=None,
+    length_axis=0,
+    strict_pad_on_len=False,
+):
+    """Returns a function for bucketing inputs, see `bucket_by_length`."""
+    length_keys = length_keys or [0, 1]
+    # In all cases so far, we use a length function of the following form.
+    length_fn = lambda x: _length_fn(x, length_axis, length_keys)
+    return lambda g: bucket_by_length(  # pylint: disable=g-long-lambda
+        g, length_fn, boundaries, batch_sizes, strict_pad_on_len
+    )
+
+
+@gin.configurable(module="trax.data")
+def MLM(
+    vocab_size=None,  # pylint:disable=invalid-name
+    max_length=None,
+    noise_density=0.15,
+    mean_noise_span_length=3.0,
+):
+    """Pipeline that just does MLM."""
+    return Serial(
+        # Generate sequential chunks.
+        generate_sequential_chunks(max_length=max_length),
+        # Generate mask and chunk.
+        generate_random_noise_mask(
+            noise_density=noise_density, mean_noise_span_length=mean_noise_span_length
+        ),
+        # Consume mask and chunk to give (input, targets).
+        consume_noise_mask(vocab_size=vocab_size),
+    )
+
+
+@gin.configurable(module="trax.data")
+def PrefixLM(input_length=128, output_length=512):  # pylint:disable=invalid-name
+    """Chunks examples so as to make inputs/outputs of specified lengths."""
+
+    def _f(generator):
+        for example in generator:
+            n_tokens = len(example)
+            # Iterate:
+            # |------|<---- input_length ---->|<- output_length ->|-----------|
+            # ^      ^                        ^                   ^
+            # |      |                        |                   |
+            # 0      input_begin_idx          input_end_idx       output_end_idx
+            input_begin_idx = 0
+            # While you can make an input batch, keep going.
+            while input_begin_idx + input_length < n_tokens:
+                input_end_idx = input_begin_idx + input_length
+                output_end_idx = min(input_end_idx + output_length, n_tokens)
+                yield (
+                    example[input_begin_idx:input_end_idx],
+                    example[input_end_idx:output_end_idx],
+                )
+                # Update the indices.
input_begin_idx = output_end_idx
+
+    return _f
+
+
+@gin.configurable(module="trax.data")
+def ConcatenateToLMInput(pad_to_length=None):  # pylint: disable=invalid-name
+    """Prepares the input needed for training of Language Models.
+
+    Each example needs to contain two elements (input and target).
+    Input is concatenated to target and, if pad_to_length is given, padded to
+    the length provided.
+    The loss_weights indicate only the target, without input or padding.
+
+    Args:
+      pad_to_length: int, total length of padding of input and target arrays.
+
+    Returns:
+      Function to return input for a LM.
+    """
+
+    @debug_data_pipeline.debug_pipeline
+    def _concatenate_to_lm_input(generator):
+        for example in generator:
+            if isinstance(example, (list, tuple)) and (len(example) == 2):
+                concatenated = np.concatenate((example[0], example[1]), axis=-1)
+                loss_weights = np.concatenate(
+                    (np.zeros_like(example[0]), np.ones_like(example[1]))
+                )
+                if pad_to_length is not None:
+                    padding_len = pad_to_length - (
+                        example[0].shape[0] + example[1].shape[0]
+                    )
+                    if padding_len < 0:
+                        raise ValueError(
+                            "Example lengths "
+                            f"({example[0].shape[0]}, {example[1].shape[0]}) "
+                            f"longer than pad_to_length ({pad_to_length})."
+                        )
+                    loss_weights = np.pad(loss_weights, (0, padding_len), "constant")
+                    concatenated = np.pad(concatenated, (0, padding_len), "constant")
+                output = (concatenated, concatenated, loss_weights)
+            elif isinstance(example, (list, tuple)) and (len(example) == 1):
+                # Make x into (x, x)
+                output = (example[0], example[0])
+            elif isinstance(example, np.ndarray):
+                # Make x into (x, x)
+                output = (example, example)
+            else:
+                raise ValueError(f"Unknown input to ConcatenateToLMInput: {example}")
+            yield output
+
+    return _concatenate_to_lm_input
+
+
+@gin.configurable(module="trax.data")
+def CastTo(
+    dtype=np.int32,
+    indices=(
+        0,
+        1,
+    ),
+    debug=False,
+):  # pylint: disable=invalid-name
+    """Casts the given indices to the given dtype."""
+
+    def _cast_fn(generator):
+        debug_count = 0
+        for example in generator:
+            debug_count += 1
+            assert isinstance(example, tuple)
+            example = list(example)
+            dtype_mismatch = False
+            original_index_and_dtype = []
+            for i in range(len(example)):
+                if i not in indices:
+                    continue
+                original_type = example[i].dtype
+                if original_type != dtype:
+                    if not (original_type == np.int64 and dtype == np.int32):
+                        # Downcasting from np.int64 to np.int32 is OK.
+                        original_index_and_dtype.append((i, original_type))
+                    example[i] = example[i].astype(dtype)
+                    dtype_mismatch = True
+            if debug and dtype_mismatch and original_index_and_dtype:
+                logging.info(
+                    "dtype mismatch in example[%d] = %r was earlier: %r",
+                    debug_count,
+                    example,
+                    original_index_and_dtype,
+                )
+            yield tuple(example)
+
+    return _cast_fn
+
+
+@gin.configurable(module="trax.data")
+def AppendValue(val=None):  # pylint: disable=invalid-name
+    """Appends values provided in `val` to inputs.
+
+    `val` is keyed by example keys; its values contain the tensors to append.
+
+    Args:
+      val: dict of int to tensors. Specific keys get the tensors specified in
+        values appended.
+
+    Returns:
+      Function to append tensors to examples.
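+
+    Example (editor's sketch): ``AppendValue(val={0: [0]})`` appends a 0
+    (e.g. an end-of-sequence id) to ``example[0]`` of every tuple example.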
+ """ + + @debug_data_pipeline.debug_pipeline + def _append_value(generator): + for example in generator: + if isinstance(example, tuple): + example = list(example) + if val is not None: + for key, value in val.items(): + example[key] = np.append(example[key], value, -1) + output = tuple(example) + else: + if not isinstance(example, np.ndarray): + raise ValueError(f"example isn't nparray, but should be: {example}") + output = np.append(example, val[0]) + yield output + + return _append_value + + +@gin.configurable(module="trax.data") +def AddLossWeights(id_to_mask=None): # pylint: disable=invalid-name + """Returns a function to add loss weights; see `add_loss_weights`.""" + return lambda g: add_loss_weights(g, id_to_mask=id_to_mask) + + +@gin.configurable(module="trax.data") +def UnBatch(): # pylint: disable=invalid-name + """Returns a function which unbatches.""" + + def _unbatch(generator): + for batched_example in generator: + # batched_example is usually like: + # (batched_inputs, batched_outputs) or + # (batched_inputs, batched_outputs, batched_weights) + assert isinstance(batched_example, tuple) + # assert all lengths are the same. + batch_sizes = list( + set(map(lambda example: example.shape[0], batched_example)) + ) + assert len(batch_sizes) == 1 + # Now unbatch examples. + for example_idx in range(batch_sizes[0]): + yield tuple( + map(lambda x: x[example_idx], batched_example) + ) # pylint: disable=cell-var-from-loop + + return _unbatch + + +@gin.configurable(module="trax.data") +def Prefetch(n_prefetch=2): # pylint: disable=invalid-name + """Pre-fetches a number of examples from generator in a separate process.""" + + def prefetch(generator): + in_q, out_q = mp.Queue(), mp.Queue() + p = mp.Process(target=_generator_process, args=(generator, in_q, out_q)) + for _ in range(n_prefetch): + in_q.put(None) + p.start() + while True: + yield out_q.get() + in_q.put(None) + + return prefetch + + +@gin.configurable(module="trax.data") +def UniformlySeek( + name=None, host_id=None, n_hosts=None, dataset_size=None +): # pylint: disable=invalid-name + """Sets each host at (dataset_size/n_hosts)-th of the dataset.""" + if not dataset_size: + dataset_size = 2**18 # 512 * 512 + logging.error( + "No dataset size given to Uniformly seek, assuming: %d", dataset_size + ) + assert name + host_id = jax.process_index() if host_id is None else host_id + n_hosts = n_hosts or jax.host_count() + each_host = int(dataset_size / n_hosts) + + def _f(generator): + # Each host seeks to the appropriate point in the dataset. + num_to_seek = int(host_id * each_host) + start_time = time.time() + logging.info( + "Dataset[%s] host_id[%d] is seeking to position[%d]", + name, + host_id, + num_to_seek, + ) + for _ in range(num_to_seek): + next(generator) + logging.info( + "Dataset[%s] host_id[%d] reached position[%d]. 
" "Time taken [%s] seconds", + name, + host_id, + num_to_seek, + time.time() - start_time, + ) + for example in generator: + yield example + + return _f + + +@gin.configurable(module="trax.data") +def CountAndSkip(name): # pylint: disable=invalid-name + """Returns a function that counts and skips examples (see above).""" + return lambda g: count_and_skip(g, name) + + +@gin.configurable(module="trax.data") +def Log(n_steps_per_example=1, only_shapes=True): # pylint: disable=invalid-name + """Creates a logging component of the input pipeline.""" + + def log(stream): + counter = 0 + for example in stream: + item_to_log = example + if only_shapes: + item_to_log = fastmath.nested_map(shapes.signature, example) + if counter % n_steps_per_example == 0: + logging.info(str(item_to_log)) + print(item_to_log) + counter += 1 + yield example + + return log + + +def shuffle(samples, queue_size): + """Shuffles a sample stream using a random-out next-in queue of given size. + + Args: + samples: Stream of samples for eventual use as training data or eval data. + queue_size: Minimum number of samples within which the streamed shuffling + takes place. + + Yields: + Shuffled stream of samples, ready for further processing, e.g., grouping + into batches. + """ + if queue_size < 1: + raise ValueError(f"Arg queue_size ({queue_size}) is less than 1.") + if queue_size == 1: + logging.warning("Queue size of 1 results in no shuffling.") + queue = [] + try: + # Prep: fill the queue. + for _ in range(queue_size): + queue.append(next(samples)) + + # Core streaming shuffle: yield sample from random location in queue, then + # fill that location with new sample from input stream. + for sample in samples: + i = np.random.randint(queue_size) + yield queue[i] + queue[i] = sample + except StopIteration: + # Only get here if the initial queue fill fails. + logging.warning( + "Not enough samples (%d) to fill initial queue (size %d).", + len(queue), + queue_size, + ) + + # No new samples coming in; shuffle and drain the queue. + np.random.shuffle(queue) + for sample in queue: + yield sample + + +def batch(generator, batch_size): + """Batch and pad generator as in tf.data.Dataset.padded_batch.""" + if batch_size <= 0: + raise ValueError(f"Batch size must be positive, but is {batch_size}.") + buf = [] + i = 0 + for example in generator: + buf.append(example) # Examples are tuples of tensors. + if len(buf) == batch_size: + # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)] + # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3]) + try: + batched_example = tuple( + pad_to_max_dims([np.asarray(tensor) for tensor in x]) + for x in zip(*buf) + ) + except ValueError as e: + for j in range(len(buf)): + logging.error( + "Batch[%d][%d] input shape: %r output shape: %r", + i, + j, + buf[j][0].shape, + buf[j][1].shape, + ) + for j in range(len(buf)): + logging.error("Batch[%d][%d] input: %r", i, j, buf[j][0]) + logging.error("Batch[%d][%d] output: %r", i, j, buf[j][1]) + raise e + i += 1 + yield batched_example + buf = [] + + +def pad_tf_tensors(tensors, boundary=None, strict_pad_on_len=False): + """ + Pad RaggedTensors to a consistent size with advanced padding options. 
+
+    Args:
+        tensors: A list of TensorFlow RaggedTensors or Tensors to pad
+        boundary: Optional boundary for padding
+        strict_pad_on_len: If True, pad strictly to boundary multiples
+
+    Returns:
+        A padded batch of tensors
+    """
+    # Ensure inputs are RaggedTensors or Tensors
+    if not all(isinstance(a, (tf.RaggedTensor, tf.Tensor)) for a in tensors):
+        raise ValueError("All input tensors must be RaggedTensors or Tensors")
+
+    # Get the number of dimensions
+    dim = tensors[0].shape.rank
+
+    # Handle boundary input
+    if boundary is not None:
+        if not isinstance(boundary, (list, tuple)):
+            boundary = [boundary] * dim
+
+        if len(boundary) != dim:
+            raise ValueError(
+                f"Length of boundary ({len(boundary)}) must match tensor dimensions ({dim})"
+            )
+    else:
+        boundary = [None] * dim
+
+    # Extract lengths for each dimension
+    def get_tensor_lengths(tensors, dim_index):
+        """Safely extract lengths for a given dimension."""
+        lengths = []
+        for t in tensors:
+            # For the first dimension (row lengths)
+            if dim_index == 0:
+                if isinstance(t, tf.RaggedTensor):
+                    lengths.append(int(t.nrows()))
+                else:
+                    lengths.append(t.shape[0])
+            # For subsequent dimensions, get the row lengths
+            else:
+                # Flatten and get max length of inner dimension
+                flat_values = t.flat_values if isinstance(t, tf.RaggedTensor) else t
+                # Handle multi-dimensional ragged tensors
+                if dim_index < flat_values.shape.ndims:
+                    flat_length = flat_values.shape[dim_index - 1]
+                    lengths.append(flat_length)
+                else:
+                    lengths.append(0)
+        return lengths
+
+    # Compute padding lengths
+    max_len_to_pad = []
+    padding_needed = False
+
+    for i in range(dim):
+        lengths = get_tensor_lengths(tensors, i)
+
+        # Determine max length
+        max_len = max(lengths)
+
+        # Handle boundary and strict padding
+        cur_boundary = boundary[i]
+
+        if cur_boundary is None:
+            # No boundary specified, use max length
+            max_len_pad = max_len
+        elif strict_pad_on_len:
+            # Strictly pad to boundary multiples
+            max_len_pad = math.ceil(max_len / cur_boundary) * cur_boundary
+        else:
+            # Use boundary with intelligent power-of-2 adjustment
+            if max_len <= 0:
+                max_len_pad = 0
+            else:
+                cur_boundary = max(max_len, cur_boundary)
+                if 2 * max_len < cur_boundary:
+                    max_len_pad = 2 ** int(np.ceil(np.log2(max_len)))
+                else:
+                    max_len_pad = cur_boundary
+
+        max_len_to_pad.append(max_len_pad)
+
+        # Check if padding is needed
+        if max_len_pad != max_len:
+            padding_needed = True
+
+    # If no padding is needed, stack the tensors
+    if not padding_needed:
+        return tf.stack(tensors)
+
+    # Pad each tensor
+    padded_tensors = []
+    for t in tensors:
+        if isinstance(t, tf.RaggedTensor):
+            # Densify the ragged tensor directly to the padded target shape.
+            padded_t = t.to_tensor(default_value=0, shape=max_len_to_pad)
+        else:
+            # Dense tensors get an explicit pad spec per dimension.
+            padding_spec = [
+                [0, max_pad - t.shape[i]] for i, max_pad in enumerate(max_len_to_pad)
+            ]
+            padded_t = tf.pad(t, padding_spec, constant_values=0)
+        padded_tensors.append(padded_t)
+
+    return tf.stack(padded_tensors)
+
+
+def pad_np_tensors(tensors, boundary=None, strict_pad_on_len=False):
+    """Pad a tuple of tensors to a joint dimension and return their batch.
+
+    For example, a pair of tensors of shape (2, 10) and (3, 9) will be padded
+    to (3, 10) both and the returned tensor will have shape (2, 3, 10).
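+
+    A minimal sketch (editor's illustration)::
+
+      batched = pad_np_tensors([np.zeros((2, 10)), np.zeros((3, 9))])
+      batched.shape  # -> (2, 3, 10)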
+
+    When boundary is specified, we try to pad all unknown dimensions to boundary
+    if possible, which can help reduce the number of different shapes occurring
+    in the tensors and speed up XLA compilation. So, for example, a pair of
+    tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12).
+
+    One special case occurs when boundary is much higher than the padding length
+    that we'd use without boundary. For example, tensors (2, 10) and (3, 9) with
+    boundary=12 could end up padded to (12, 12), but this is very wasteful in
+    the first dimension. In that case, we will use the closest power-of-2 instead
+    of the boundary, so we will end up padding to (4, 12) instead of (12, 12).
+
+    Args:
+      tensors: a tuple or list of tensors to pad
+      boundary: int or None; if given, expand the padded dimensions to this size
+      strict_pad_on_len: bool; if true we pad on the length dimension, dim[0]
+        strictly as a multiple of boundary.
+
+    Returns:
+      a tensor, the tensors padded together
+    """
+    # TODO(afrozm): Unify this later.
+    if not all(isinstance(a, np.ndarray) for a in tensors):
+        raise ValueError("All input tensors must be numpy arrays")
+
+    if (boundary is not None) and (
+        strict_pad_on_len or isinstance(boundary, (list, tuple))
+    ):
+        ndim = tensors[0].ndim
+        if not isinstance(boundary, (list, tuple)):
+            boundary = [boundary] * ndim
+
+        if ndim != len(boundary):
+            raise ValueError(
+                f"ndim != len(boundary) - "
+                f"ndim({ndim}) vs boundary({boundary}) "
+                f"len(boundary) = {len(boundary)}."
+            )
+
+        max_len_per_dim = [0] * ndim
+        for tensor in tensors:
+            max_len_per_dim = [max(e, s) for e, s in zip(tensor.shape, max_len_per_dim)]
+
+        # Round everything up to a multiple of boundary in the respective dimension.
+        len_per_dim = [
+            max_len_per_dim[i] if not b else b * math.ceil(max_len_per_dim[i] / b)
+            for i, b in enumerate(boundary)
+        ]
+
+        padded_tensors = [
+            np.pad(
+                t,
+                [(0, len_per_dim[i] - t.shape[i]) for i in range(ndim)],
+                mode="constant",
+                constant_values=t.dtype.type(0),
+            )
+            for t in tensors
+        ]
+
+        return np.stack(padded_tensors)
+
+    max_len_to_pad = []
+    padding_needed = False
+    dim = len(tensors[0].shape)
+    for i in range(dim):
+        max_len = max([t.shape[i] for t in tensors])
+        min_len = min([t.shape[i] for t in tensors])
+        if max_len == min_len and max_len == boundary:  # No padding needed.
+            max_len_to_pad.append(max_len)
+        elif boundary is None:
+            max_len_to_pad.append(max_len)
+            padding_needed = True
+        else:
+            padding_needed = True
+            cur_boundary = max(max_len, boundary)
+            if 2 * max_len < cur_boundary:
+                cur_boundary = 2 ** int(np.ceil(np.log2(max_len)))
+            max_len_to_pad.append(cur_boundary)
+    if not padding_needed:
+        return np.stack(tensors)
+    padded_tensors = []
+    for t in tensors:
+        pad_widths = [(0, max_len_to_pad[i] - t.shape[i]) for i in range(dim)]
+        padded_t = np.pad(
+            t, pad_widths, mode="constant", constant_values=t.dtype.type(0)
+        )
+        padded_tensors.append(padded_t)
+    return np.stack(padded_tensors)
+
+
+def pad_jax_arrays(
+    arrays: Sequence[jax.Array],
+    boundary: Optional[Union[int, Sequence[int]]] = None,
+    strict_pad_on_len: bool = False,
+) -> jax.Array:
+    """Pad a sequence of JAX Arrays to a joint dimension and return their batch.
+
+    For example, a pair of arrays of shape (2, 10) and (3, 9) will be padded
+    to (3, 10) both and the returned array will have shape (2, 3, 10).
+ + When boundary is specified, we try to pad all unknown dimensions to boundary + if possible, which can help reduce the number of different shapes occurring + in the arrays and speed up XLA compilation. So, for example, a pair of + arrays of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12). + + One special case occurs when boundary is much higher than the padding length + that we'd use without boundary. For example, arrays (2, 10) and (3, 9) with + boundary=12 could end up padded to (12, 12), but this is very wasteful in + the first dimension. In that case, we will use the closest power-of-2 instead + of the boundary, so we will end up padding to (4, 12) instead of (12, 12). + + Args: + arrays: a sequence of JAX Arrays to pad + boundary: int or None; if given, expand the padded dimensions to this size + or can be a sequence matching the number of dimensions + strict_pad_on_len: bool; if true we pad on the length dimension, dim[0] + strictly as a multiple of boundary. + + Returns: + a JAX Array, the arrays padded together + """ + # Ensure inputs are JAX Arrays + if not all(isinstance(a, jax.Array) for a in arrays): + raise ValueError("All inputs must be JAX Arrays") + + # Handle case with list/tuple boundary or strict padding + if (boundary is not None) and ( + strict_pad_on_len or isinstance(boundary, (list, tuple)) + ): + ndim = arrays[0].ndim + if not isinstance(boundary, (list, tuple)): + boundary = [boundary] * ndim + + if ndim != len(boundary): + raise ValueError( + f"ndim != len(boundary) - " + f"ndim({ndim}) vs boundary({boundary}) " + f"len(boundary) = {len(boundary)}." + ) + + # Find maximum length per dimension + max_len_per_dim = [0] * ndim + for array in arrays: + max_len_per_dim = [max(e, s) for e, s in zip(array.shape, max_len_per_dim)] + + # Round everything up to a multiple of boundary in the respective dimension + len_per_dim = [ + max_len_per_dim[i] if not b else b * math.ceil(max_len_per_dim[i] / b) + for i, b in enumerate(boundary) + ] + + # Pad each array to the target dimensions + padded_arrays = [ + jnp.pad( + a, + [(0, len_per_dim[i] - a.shape[i]) for i in range(ndim)], + mode="constant", + constant_values=a.dtype.type(0), + ) + for a in arrays + ] + + return jnp.stack(padded_arrays) + + # Handle the simpler case (similar to pad_np_tensors second part) + max_len_to_pad = [] + padding_needed = False + dim = arrays[0].ndim + + for i in range(dim): + max_len = max([a.shape[i] for a in arrays]) + min_len = min([a.shape[i] for a in arrays]) + + if max_len == min_len and max_len == boundary: # No padding needed + max_len_to_pad.append(max_len) + elif boundary is None: + max_len_to_pad.append(max_len) + padding_needed = True + else: + padding_needed = True + cur_boundary = max(max_len, boundary) + if 2 * max_len < cur_boundary: + cur_boundary = 2 ** int(jnp.ceil(jnp.log2(max_len))) + max_len_to_pad.append(cur_boundary) + + if not padding_needed: + return jnp.stack(arrays) + + padded_arrays = [] + for a in arrays: + pad_widths = [(0, max_len_to_pad[i] - a.shape[i]) for i in range(dim)] + padded_a = jnp.pad( + a, pad_widths, mode="constant", constant_values=a.dtype.type(0) + ) + padded_arrays.append(padded_a) + + return jnp.stack(padded_arrays) + + +def pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False): + """ + Unified padding function. Depending on the type of input tensors, it either applies + dense padding (using NumPy) or uses TensorFlow operations for RaggedTensors. + + Args: + tensors: A list or tuple of tensors to pad. 
They must be all np.ndarray, all tf.RaggedTensor/tf.Tensor, or all
+            jax.Array.
+        boundary: Optional boundary for padding.
+        strict_pad_on_len: If True, pad strictly to boundary multiples.
+
+    Returns:
+        A batched tensor with consistent dimensions.
+    """
+    if all(isinstance(t, tf.RaggedTensor) or isinstance(t, tf.Tensor) for t in tensors):
+        return pad_tf_tensors(tensors, boundary, strict_pad_on_len)
+    elif all(isinstance(t, np.ndarray) for t in tensors):
+        return pad_np_tensors(tensors, boundary, strict_pad_on_len)
+    elif all(isinstance(t, jax.Array) for t in tensors):
+        return pad_jax_arrays(tensors, boundary, strict_pad_on_len)
+    else:
+        raise ValueError(
+            "Mixed tensor types not supported. All tensors must be either "
+            "tf.RaggedTensor, tf.Tensor, jax.Array or np.ndarray."
+        )
+
+
+def bucket_by_length(
+    generator, length_fn, boundaries, batch_sizes, strict_pad_on_len=False
+):
+    """Bucket by length, like tf.data.experimental.bucket_by_sequence_length.
+
+    This function draws examples from the provided `generator` and puts an
+    example into a bucket depending on `l = length_fn(example)`. Which bucket
+    is used depends on between which `boundaries` is l. When a bucket reaches
+    its batch size, as specified by `batch_sizes`, it generates a batch of
+    padded examples from this bucket.
+
+    Args:
+      generator: python generator to draw data from.
+      length_fn: a function taking the example and returning the length.
+      boundaries: a list of bucket boundaries.
+      batch_sizes: a list of batch sizes.
+      strict_pad_on_len: bool; if true we pad on the length dimension, dim[0]
+        strictly as a multiple of boundary.
+
+    Yields:
+      An input batch, which comes from one of the buckets.
+    """
+    buckets = [[] for _ in range(len(batch_sizes))]
+    boundaries = boundaries + [math.inf]  # Max boundary is unlimited.
+    for example in generator:
+        length = length_fn(example)
+        # `bucket_idx` will always be < len(boundaries), since boundaries is
+        # right padded by `math.inf`.
+        bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b])
+        buckets[bucket_idx].append(example)
+        if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]:
+            batched = zip(*buckets[bucket_idx])
+            boundary = boundaries[bucket_idx]
+            boundary = None if boundary == math.inf else boundary
+            padded_batch = tuple(
+                pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched
+            )
+            yield padded_batch
+            buckets[bucket_idx] = []
+
+
+@debug_data_pipeline.debug_pipeline
+def add_loss_weights(generator, id_to_mask=None):
+    """Add weights to inputs without weights and masks by id if requested.
+
+    The generator stream is augmented in the following way:
+
+    - If the stream consists of pairs `(inputs, targets)`, a loss mask is
+      added that is created as a tensor of ones of the same shape as targets.
+    - If `id_to_mask` is not `None`, and the stream (after the previous point)
+      has triples `(inputs, targets, weights)`, the weights are multiplied by
+      a 0/1 mask that is 0 iff targets is equal to `id_to_mask` (1 otherwise).
+
+    Args:
+      generator: Stream of tuples.
+      id_to_mask: If not None, int-valued id that represents padding, as
+        opposed to true target IDs.
+
+    Yields:
+      Examples from the augmented stream.
+    """
+    for example in generator:
+        if len(example) > 3 or len(example) < 2:
+            assert id_to_mask is None, "Cannot automatically mask this stream."
+            yield example
+        else:
+            if len(example) == 2:
+                weights = np.ones_like(example[1]).astype(np.float32)
+            else:
+                weights = example[2].astype(np.float32)
+            if id_to_mask is not None:
+                mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32)
+                weights *= mask
+            output = (example[0], example[1], weights)
+            yield output
+
+
+@gin.configurable(module="trax.data")
+def generate_random_noise_mask(
+    noise_density=0.15, mean_noise_span_length=3.0, seed1=None, seed2=None
+):
+    """Returns a function that generates a random noise mask."""
+
+    def _f(generator):
+        for example in generator:
+            length = len(example)
+            noise_mask = random_spans_noise_mask(
+                length,
+                noise_density=noise_density,
+                mean_noise_span_length=mean_noise_span_length,
+                seed1=seed1,
+                seed2=seed2,
+                example=example,
+            )
+            yield (example, noise_mask)
+
+    return _f
+
+
+@gin.configurable(module="trax.data")
+def consume_noise_mask(vocab_size=32100):
+    """Consumes (tokens, noise mask) and returns (inputs, targets)."""
+
+    def _noise_span_to_unique_sentinel(tokens, noise_mask):
+        prev_token_is_noise = np.pad(
+            noise_mask[:-1], [1, 0], mode="constant", constant_values=False
+        )
+        first_noise_tokens = np.logical_and(
+            noise_mask, np.logical_not(prev_token_is_noise)
+        )
+        subsequent_noise_tokens = np.logical_and(noise_mask, prev_token_is_noise)
+        sentinel = vocab_size - np.cumsum(first_noise_tokens)
+        tokens = np.where(first_noise_tokens, sentinel, tokens)
+        return tokens[np.logical_not(subsequent_noise_tokens)]
+
+    def _f(generator):
+        for tokens, noise_mask in generator:
+            # Returns inputs and targets.
+            yield (
+                _noise_span_to_unique_sentinel(tokens, noise_mask),
+                _noise_span_to_unique_sentinel(tokens, np.logical_not(noise_mask)),
+            )
+
+    return _f
+
+
+@gin.configurable(module="trax.data")
+def generate_sequential_chunks(max_length=None):
+    """Returns a function that generates chunks of at most max_length length."""
+
+    def _f(generator):
+        for example in generator:
+            n_tokens = len(example)
+            if n_tokens <= max_length:
+                yield example
+            else:
+                n_segments = int(math.ceil(float(n_tokens) / float(max_length)))
+                for i in range(n_segments):
+                    start = max_length * i
+                    end = min(start + max_length, n_tokens)
+                    yield example[start:end]
+
+    return _f
+
+
+@gin.configurable(module="trax.data")
+def addition_input_stream(
+    vocab_size=gin.REQUIRED,
+    batch_size=gin.REQUIRED,
+    min_length=gin.REQUIRED,
+    max_length=gin.REQUIRED,
+    pad_to_multiple=32,
+    encdec=False,
+):
+    """Data stream for the add problem: <S>x+y<S>(x+y).
+
+    Args:
+      vocab_size: how many symbols to use.
+      batch_size: how large are the batches.
+      min_length: minimal length of w.
+      max_length: maximal length of w.
+      pad_to_multiple: int, pad length to be multiple of this number.
+      encdec: bool, if True return encoder-decoder style inputs (default: False)
+
+    Returns:
+      python generator of tuples of data examples
+    """
+    base = vocab_size - 3  # We use 0 to pad, base+1 as "+" and base+2 as "<S>".
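+    # Editor's illustration (assuming vocab_size=13, hence base=10): for
+    # 12 + 34, the lower-endian digit lists are n1=[2, 1] and n2=[4, 3] and
+    # the target 46 becomes [6, 4]; digits are shifted by +1 so that 0 stays
+    # the pad token, "+" maps to base+1 and "<S>" to base+2.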
+
+    def single_example(max_length, min_length):
+        """Generate a single addition example."""
+        add_len = (min_length - 1) // 2
+        l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len
+        l2 = np.random.randint(max_length - l1 - 1) + 1
+        n1 = random_number_lower_endian(l1, base)
+        n2 = random_number_lower_endian(l2, base)
+        result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base)
+        inp = n1 + [base] + n2
+        tgt = number_to_lower_endian(result, base)
+        if encdec:
+            x = [i + 1 for i in inp]
+            y = [i + 1 for i in tgt]
+            weights = [1] * len(tgt)
+            candidate_example = (np.array(x), np.array(y), np.array(weights))
+            if any(len(sample) > max_length for sample in candidate_example):
+                # Sample too long, try again.
+                return single_example(max_length, min_length)
+            return (np.array(x), np.array(y), np.array(weights))
+        else:
+            x = [base + 2] + [i + 1 for i in inp] + [base + 2] + [i + 1 for i in tgt]
+            weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt))
+            return (np.array(x), np.array(x), np.array(weights))
+
+    def batches(max_length, min_length):
+        """Batches of examples."""
+        if max_length < 3 or min_length < 3:
+            raise ValueError("Maximum/minimum length must be at least 3.")
+        while True:
+            ex = [single_example(max_length, min_length) for _ in range(batch_size)]
+            padded_batch = [
+                pad_to_max_dims(x, boundary=pad_to_multiple, strict_pad_on_len=True)
+                for x in zip(*ex)
+            ]
+            yield tuple(padded_batch)
+
+    return batches(max_length, min_length)
+ if seed is not None: + np.random.seed(seed) + np.random.shuffle(x) + first_in_segment = np.pad(x, (1, 0), mode="constant") + segment_id = np.cumsum(first_in_segment) + + y = np.roll(segment_id, 1) + y[0] = 0 + idxs = np.pad( + np.squeeze(np.argwhere(segment_id - y), axis=1), (1, 0), mode="constant" + ) + segment_lengths = np.add.reduceat(np.ones_like(segment_id), idxs, axis=0) + return segment_lengths + + noise_span_lengths = randomly_segment(num_noise_tokens, num_noise_spans, seed1) + nonnoise_span_lengths = randomly_segment( + num_nonnoise_tokens, num_noise_spans, seed2 + ) + interleaved_span_lengths = np.reshape( + np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), + [num_noise_spans * 2], + ) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros(length) # all 0s to begin with + span_start_indicator[span_starts] = 1 + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + return is_noise[:orig_length] + + +def lower_endian_to_number(l, base): + """Helper function: convert a list of digits in the given base to a number.""" + return sum([d * (base**i) for i, d in enumerate(l)]) + + +def number_to_lower_endian(n, base): + """Helper function: convert a number to a list of digits in the given base.""" + if n < base: + return [n] + return [n % base] + number_to_lower_endian(n // base, base) + + +def random_number_lower_endian(length, base): + """Helper function: generate a random number as a lower-endian digits list.""" + if length == 1: # Last digit can be 0 only if length is 1. + return [np.random.randint(base)] + prefix = [np.random.randint(base) for _ in range(length - 1)] + return prefix + [np.random.randint(base - 1) + 1] # Last digit is not 0. + + +data_counters = {} # Used by {load,save}_data_counters and count_and_skip + + +def count_and_skip(generator, name): + """Count the number of items in the generator, skip already counted ones. + + This function counts the number of processed examples and puts it into + the global variable `counters`. This variable can be saved and restored, + and if restored, this function will skip examples until the restored counter + is reached. When the data generator is deterministic, this allows to restore + the data reading process from a checkpoint. + + Args: + generator: generator for examples in the dataset. + name: string, a unique id that we use to count the examples + + Yields: + The examples from generator but first skip the number specified in the + global variable counters[name] and next increment this variable every + time a new example appears. + """ + global data_counters + local_counter = 0 + for example in generator: + local_counter += 1 + # This check must be inside the loop due to asynchronous initializations. 
+        if name not in data_counters:
+            data_counters[name] = 0
+        if local_counter > data_counters[name]:
+            data_counters[name] += 1
+            yield example
+
+
+def save_data_counters(output_dir, host_id=None):
+    """Checkpoint data counters."""
+    global data_counters
+    host_id = jax.process_index() if host_id is None else host_id
+    fname = os.path.join(output_dir, "data_counters%d.pkl" % host_id)
+    with tf.io.gfile.GFile(fname, "wb") as f:
+        pickle.dump(data_counters, f)
+
+
+def load_data_counters(output_dir, host_id=None):
+    """Restore checkpointed data counters."""
+    global data_counters
+    host_id = jax.process_index() if host_id is None else host_id
+    fname = os.path.join(output_dir, "data_counters%d.pkl" % host_id)
+    if not tf.io.gfile.exists(fname):
+        logging.info("Did not load data counters as %s does not exist.", fname)
+        return
+    with tf.io.gfile.GFile(fname, "rb") as f:
+        obj = pickle.load(f)
+    data_counters = obj
+
+
+def _generator_process(generator, in_q, out_q):
+    for example in generator:
+        in_q.get()
+        out_q.put(example)
+
+
+def _buckets_for_length(
+    bucket_length, batch_size, max_eval_length, n_devices, training
+):
+    """Creates heuristically a set of bucket boundaries and sizes.
+
+    The middle boundary is set to `bucket_length` and the corresponding batch
+    size is set to `batch_size`. We also create buckets of 1/2 and 1/4 length
+    with 2x and 4x batch size, and buckets of 2x and 4x and larger length with
+    1/2 and 1/4 batch size respectively, and batch size 1 for the final one.
+
+    Args:
+      bucket_length: the length of the middle bucket.
+      batch_size: the batch size for the middle bucket.
+      max_eval_length: the longest bucket length if training=False.
+      n_devices: number of devices, batch sizes are divisible by that.
+      training: bool, whether we are training or evaluating.
+
+    Returns:
+      a pair of lists of integers, (bucket_boundaries, bucket_batch_sizes).
+    """
+    bucket_boundaries = [
+        bucket_length // 4,
+        bucket_length // 2,
+        bucket_length,
+        bucket_length * 2,
+        bucket_length * 4,
+        bucket_length * 8,
+        bucket_length * 16,
+    ]
+    if not training:
+        max_eval_length = max_eval_length or bucket_length * 32
+        # Set last bucket boundary to be max_eval_length, cut off boundaries
+        # that are larger than this.
+        bucket_boundaries = [b for b in bucket_boundaries if b < max_eval_length] + [
+            max_eval_length
+        ]
+    bucket_batch_sizes = [
+        batch_size * 4,
+        batch_size * 2,
+        batch_size,
+        batch_size // 2,
+        batch_size // 4,
+        batch_size // 8,
+        batch_size // 16,
+        1,
+    ]
+    if not training:
+        # The last bucket batch size is always 1, but the one-but-last is
+        # sized to accommodate the final length = bucket_boundaries[-1], which
+        # we changed for eval above -- so adjusting here too.
+
+        # Resize if needed, since bucket_batch_sizes may not be the same size
+        # anymore.
+        bucket_batch_sizes = bucket_batch_sizes[: len(bucket_boundaries)] + [1]
+        bucket_batch_sizes[-2] = batch_size // max_eval_length
+    # Make batch sizes divisible by n_devices.
+ bucket_batch_sizes = [ + max(b // n_devices, 1) * n_devices for b in bucket_batch_sizes + ] + return (bucket_boundaries, bucket_batch_sizes) + + +def _length_fn(example, length_axis, length_keys): + """Length is the maximum of shape on length_axis over length_keys.""" + if isinstance(example, (list, tuple)): + return max([example[i].shape[length_axis] for i in length_keys]) + return example.shape[length_axis] + + +# ######################################################################## +# Inputs class used by Trainer, and associated helper functions. +# +# Note: In the planned move from Trainer to Loop, the Inputs class should be +# deprecated and finally removed. + + +class Inputs: + """Inputs bundle. + + Inputs bundle holds input streams and shapes for a training run. + It contains stream-creating functions that return python generators + of (input_batch, target_batch) tuples. + + * train_stream: training data that will be used for training + may include all the augmentation or selection the training wants + the shape of examples is [batch_fn.batch_size, ...] + * train_eval_stream: training data used for evaluation + examples from training data but usually without augmentation + the shape of examples is [batch_fn.eval_batch_size, ...] + * eval_stream: evaluation data stream + examples from evaluation data, usually without augmentation + the shape of examples is [batch_fn.eval_batch_size, ...] + * input_shape: the shape of inputs + the [...] above, without batch size + * input_dtype: the data type of inputs + * target_shape: the shape of targets + the [...] above, without batch size + * target_dtype: the data type of targets + """ + + def __init__(self, train_stream, eval_stream=None, train_eval_stream=None): + """Initialize a new set of inputs. + + Args: + train_stream: a function taking n_devices (an int) and returning + a python generator of training batches. + eval_stream: a function taking n_devices (an int) and returning + a python generator of validation batches; + if None, then the training generator will be used for evaluation. + train_eval_stream: a function taking n_devices (an int) and returning + a python generator of batches from + the training set used for evaluation (if None, use train_stream). + """ + if not callable(train_stream): + raise ValueError( + "Trax Inputs should be initialized with a function. " + "Did you forget the n_devices argument? If your inputs " + "do not use it, try lambda _: [your-inputs]." + ) + + self._train_stream = train_stream + self._eval_stream = eval_stream or self._train_stream + + # TODO(lukaszkaiser): should we get rid of this one day? + self._train_eval_stream = train_eval_stream or self._train_stream + + # Peek into the train stream to get an example shape. 
example_train_batch = next(train_stream(1))
+
+        self._input_shape = tuple(example_train_batch[0].shape)[1:]
+        self._input_dtype = example_train_batch[0].dtype
+        self._target_shape = tuple(example_train_batch[-1].shape)[1:]
+        self._target_dtype = example_train_batch[-1].dtype
+        self._example_shape = [x.shape for x in example_train_batch]
+        self._example_dtype = [x.dtype for x in example_train_batch]
+
+    def train_stream(self, n_devices):
+        return self._train_stream(n_devices)
+
+    def eval_stream(self, n_devices):
+        return self._eval_stream(n_devices)
+
+    def train_eval_stream(self, n_devices):
+        return self._train_eval_stream(n_devices)
+
+    @property
+    def input_shape(self):
+        """Example input shape, without batch dimension."""
+        return self._input_shape
+
+    @property
+    def target_shape(self):
+        """Example target shape, without batch dimension."""
+        return self._target_shape
+
+    @property
+    def input_dtype(self):
+        """Dtype of the input."""
+        return self._input_dtype
+
+    @property
+    def target_dtype(self):
+        """Dtype of the target."""
+        return self._target_dtype
+
+    @property
+    def example_shape_dtype(self):
+        """Shape and Dtype of an example batch."""
+        return self._example_shape, self._example_dtype
+
+
+# Batching and Inputs creation helpers.
+
+
+@gin.configurable(module="trax.data")
+def make_inputs(train_stream=gin.REQUIRED, eval_stream=None):
+    """Create Inputs from two streams; mostly for use in gin configs."""
+    if isinstance(train_stream, (list, tuple)):
+        train_stream = Serial(train_stream)()
+    if isinstance(eval_stream, (list, tuple)):
+        eval_stream = Serial(eval_stream)()
+    eval_stream_fn = None if eval_stream is None else lambda _: eval_stream
+    return Inputs(train_stream=lambda _: train_stream, eval_stream=eval_stream_fn)
+
+
+@gin.configurable(module="trax.data")
+def make_additional_stream(stream=gin.REQUIRED):
+    """Create a stream mostly for use in gin configs for additional tasks."""
+    return Serial(stream)()
+
+
+@gin.configurable(module="trax.data")
+def make_parallel_stream(streams=gin.REQUIRED, counters=None):
+    """Create a parallel stream for use in gin configs for additional tasks."""
+    return Parallel(streams, counters=counters)()
+
+
+@gin.configurable(module="trax.data")
+def batcher(
+    data_streams=gin.REQUIRED,
+    variable_shapes=True,
+    batch_size_per_device=32,
+    batch_size=None,
+    eval_batch_size=32,
+    bucket_length=32,
+    buckets=None,
+    buckets_include_inputs_in_length=False,
+    batch_shuffle_size=None,
+    max_eval_length=None,
+    # TODO(afrozm): Unify padding logic.
+    id_to_mask=None,
+    strict_pad_on_len=False,
+):
+    """Batcher: create trax Inputs from single-example data-streams."""
+    # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming.
+    # For now leaving the arguments as in batch_fn to reduce gin config changes.
+    if callable(data_streams):  # If we pass a function, e.g., through gin, call.
+ train_stream, eval_stream = data_streams() + else: + train_stream, eval_stream = data_streams + # pylint: disable=g-long-lambda + batch_train_stream = lambda n_devices: batch_fn( + train_stream(), + True, + n_devices, + variable_shapes, + batch_size_per_device, + batch_size, + eval_batch_size, + bucket_length, + buckets, + buckets_include_inputs_in_length, + batch_shuffle_size, + max_eval_length, + id_to_mask, + strict_pad_on_len, + ) + batch_eval_stream = lambda n_devices: batch_fn( + eval_stream(), + False, + n_devices, + variable_shapes, + batch_size_per_device, + batch_size, + eval_batch_size, + bucket_length, + buckets, + buckets_include_inputs_in_length, + batch_shuffle_size, + max_eval_length, + id_to_mask, + strict_pad_on_len, + ) + batch_train_eval_stream = lambda n_devices: batch_fn( + train_stream(), + False, + n_devices, + variable_shapes, + batch_size_per_device, + batch_size, + eval_batch_size, + bucket_length, + buckets, + buckets_include_inputs_in_length, + batch_shuffle_size, + max_eval_length, + id_to_mask, + strict_pad_on_len, + ) + # pylint: enable=g-long-lambda + return Inputs( + train_stream=batch_train_stream, + eval_stream=batch_eval_stream, + train_eval_stream=batch_train_eval_stream, + ) + + +def batch_fn( + dataset, + training, + n_devices, + variable_shapes, + batch_size_per_device=32, + batch_size=None, + eval_batch_size=32, + bucket_length=32, + buckets=None, + buckets_include_inputs_in_length=False, + batch_shuffle_size=None, + max_eval_length=None, + id_to_mask=None, + strict_pad_on_len=False, +): + """Batching function.""" + # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. + # After that, create a proper doc-string; we may also not need to pass both + # training and eval arguments here, as batcher calls the function separately + # now and it's not under gin-config any more -- consider reducing args. + batch_size = batch_size or batch_size_per_device * n_devices + # If bucketing is not specified, check if target shapes are variable. + cur_batch_size = batch_size if training else eval_batch_size + # Make cur_batch_size divisible by n_devices. + cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices + # Create heuristic buckets if none are specified. + if buckets is None: + logging.info( + "Heuristically setting bucketing to %s based on shapes " + "of target tensors.", + variable_shapes, + ) + if variable_shapes: + buckets = _buckets_for_length( + bucket_length, cur_batch_size, max_eval_length, n_devices, training + ) + + if buckets: + logging.info("Bucketing with buckets %s.", str(buckets)) + + def example_length(x): + """The length function used by bucket_by_sequence_length to bucket.""" + # The input x is a tuple to go on the stack, typically either + # (input, target) or (input, target, mask). + example_inputs, target = x[0], x[1] + # Length is the shape of axis 0 here (no batch yet). + other_length = 0 # We include input length only if asked. + if buckets_include_inputs_in_length: + other_length = example_inputs.shape[0] + return max(target.shape[0], other_length) + + boundaries, batch_sizes = buckets + dataset = bucket_by_length( + dataset, example_length, boundaries, batch_sizes, strict_pad_on_len + ) + else: + logging.info("Not Bucketing cur_batch_size %d.", cur_batch_size) + dataset = batch(dataset, cur_batch_size) + if training and batch_shuffle_size is not None: + dataset = shuffle(dataset, batch_shuffle_size) + return add_loss_weights(dataset, id_to_mask) + + +# Example input functions. 
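+
+# Illustrative sketch (hypothetical names, not part of the library): `batcher`
+# expects `data_streams` to provide two zero-argument callables, each returning
+# a generator of single (input, target) examples, e.g.:
+#
+#   def my_train_stream():  # hypothetical user-defined stream
+#       while True:
+#           yield (np.zeros((10,), np.int32), np.ones((10,), np.int32))
+#
+#   inputs = batcher(data_streams=(my_train_stream, my_train_stream),
+#                    variable_shapes=False, batch_size=8)
+#   inp, tgt, weights = next(inputs.train_stream(n_devices=1))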
+
+
+@gin.configurable(module="trax.data")
+def random_inputs(
+    input_shape=gin.REQUIRED,
+    input_dtype=jnp.int32,
+    input_range=(0, 255),
+    output_shape=gin.REQUIRED,
+    output_dtype=jnp.int32,
+    output_range=(0, 9),
+):
+    """Make random Inputs for debugging.
+
+    Args:
+      input_shape: the shape of inputs (including batch dimension).
+      input_dtype: the type of the inputs (int32 by default).
+      input_range: the range of inputs (defaults to (0, 255)).
+      output_shape: the shape of outputs (including batch dimension).
+      output_dtype: the type of the outputs (int32 by default).
+      output_range: the range of outputs (defaults to (0, 9)).
+
+    Returns:
+      trax.inputs.Inputs
+    """
+
+    def random_minibatches(n_devices):
+        """Generate a stream of random mini-batches."""
+        # The batch dimension must be divisible across devices.
+        assert input_shape[0] % n_devices == 0
+        if input_dtype in [jnp.float16, jnp.float32, jnp.float64]:
+            rand = np.random.uniform
+        else:
+            # np.random.random_integers is deprecated; use randint with an
+            # inclusive upper bound to keep the same semantics.
+            rand = lambda low, high, size: np.random.randint(low, high + 1, size)
+        while True:
+            inp = rand(input_range[0], input_range[1], input_shape)
+            inp = inp.astype(input_dtype)
+            out = rand(output_range[0], output_range[1], output_shape)
+            out = out.astype(output_dtype)
+            yield inp, out
+
+    return Inputs(random_minibatches)
+
+
+@gin.configurable(module="trax.data")
+def sequence_copy_inputs(
+    vocab_size=gin.REQUIRED,
+    batch_size=gin.REQUIRED,
+    train_length=gin.REQUIRED,
+    eval_min_length=gin.REQUIRED,
+    eval_max_length=gin.REQUIRED,
+    reverse=False,
+    pad_to_multiple=32,
+):
+    """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*.
+
+    Args:
+      vocab_size: how many symbols to use.
+      batch_size: how large are the batches.
+      train_length: maximum length of w for training.
+      eval_min_length: minimum length of w for eval.
+      eval_max_length: maximum length of w for eval.
+      reverse: bool (optional, False by default): reverse the second sequence.
+      pad_to_multiple: int, pad length to be multiple of this number.
+
+    Returns:
+      trax.inputs.Inputs
+    """
+
+    def random_minibatches(length_list):
+        """Generate a stream of random mini-batches."""
+        while True:
+            length = random.choice(length_list)
+            assert length % 2 == 0
+            w_length = (length // 2) - 1
+            w = np.random.randint(
+                low=1, high=vocab_size - 1, size=(batch_size, w_length)
+            )
+            zero = np.zeros([batch_size, 1], np.int32)
+            loss_weights = np.concatenate(
+                [np.zeros((batch_size, w_length + 2)), np.ones((batch_size, w_length))],
+                axis=1,
+            )
+            if reverse:
+                x = np.concatenate([zero, w, zero, jnp.flip(w, axis=1)], axis=1)
+            else:
+                x = np.concatenate([zero, w, zero, w], axis=1)
+            x = _pad_to_multiple_of(x, pad_to_multiple, 1)
+            loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1)
+            yield (x, x, loss_weights)  # Here inputs and targets are the same.
+
+    train_lengths = [2 * (i + 2) for i in range(train_length - 1)]
+    eval_lengths = [2 * (i + 1) for i in range(eval_min_length, eval_max_length)]
+    return Inputs(
+        train_stream=lambda _: random_minibatches(train_lengths),
+        eval_stream=lambda _: random_minibatches(eval_lengths),
+    )
+
+
+@gin.configurable(module="trax.data")
+def simple_sequence_copy_inputs(
+    vocab_size=gin.REQUIRED,
+    batch_size=gin.REQUIRED,
+    train_length=gin.REQUIRED,
+    eval_min_length=gin.REQUIRED,
+    eval_max_length=gin.REQUIRED,
+    pad_to_multiple=32,
+):
+    """Inputs for the sequence copy problem: w for w in [1..vocab_size-1]*.
+
+    Args:
+      vocab_size: how many symbols to use.
+      batch_size: how large are the batches.
+      train_length: maximum length of w for training.
+      eval_min_length: minimum length of w for eval.
+ eval_max_length : maximum length of w for eval. + pad_to_multiple: int, pad length to be multiple of this number. + + Returns: + trax.inputs.Inputs + """ + + def random_minibatches(length_list): + """Generate a stream of random mini-batches.""" + while True: + length = random.choice(length_list) + x = np.random.randint(low=1, high=vocab_size - 1, size=(batch_size, length)) + loss_weights = np.ones((batch_size, length)) + x = _pad_to_multiple_of(x, pad_to_multiple, 1) + loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1) + yield (x, x, loss_weights) # Here inputs and targets are the same. + + train_lengths = list(range(1, train_length + 1)) + eval_lengths = list(range(eval_min_length, eval_max_length + 1)) + return Inputs( + train_stream=lambda _: random_minibatches(train_lengths), + eval_stream=lambda _: random_minibatches(eval_lengths), + ) + + +@gin.configurable(module="trax.data") +def addition_inputs( + vocab_size=gin.REQUIRED, + batch_size=gin.REQUIRED, + train_length=gin.REQUIRED, + eval_min_length=gin.REQUIRED, + eval_max_length=gin.REQUIRED, + pad_to_multiple=32, + encdec=False, +): + """Inputs for the add problem: x+y(x+y). + + Args: + vocab_size: how many symbols to use. + batch_size: how large are the batches. + train_length: maximal length of w for training. + eval_min_length: minimal length of w for eval. + eval_max_length: maximal length of w for eval. + pad_to_multiple: int, pad length to be multiple of this number. + encdec: bool, if True return encoder-decoder style inputs (default: False) + + Returns: + trax.inputs.Inputs + """ + train_stream = addition_input_stream( + vocab_size, batch_size, 3, train_length, pad_to_multiple, encdec + ) + eval_stream = addition_input_stream( + vocab_size, + batch_size, + eval_min_length, + eval_max_length, + pad_to_multiple, + encdec, + ) + return Inputs( + train_stream=lambda _: train_stream, eval_stream=lambda _: eval_stream + ) + + +@gin.configurable(module="trax.data") +def sine_inputs( + batch_size=gin.REQUIRED, + length=gin.REQUIRED, + max_phase=(2 * math.pi), + min_period=0.1, + max_period=10.0, +): + """Sinusoids of random period and phase. + + Args: + batch_size (int): Number of examples in a batch. + length (int): Length of each sequence. + max_phase (float): Maximum phase of the sinusoids. + min_period (float): Minimum period of the sinusoids. + max_period (float): Maximum period of the sinusoids. + + Returns: + trax.inputs.Inputs + """ + + def random_series(): + while True: + phase = np.random.uniform(0, max_phase) + period = np.exp(np.random.uniform(np.log(min_period), np.log(max_period))) + x = np.arange(length) + yield np.sin((x - phase) / period) + + def random_minibatches(_): + minibatch = [] + for series in random_series(): + minibatch.append(series) + if len(minibatch) == batch_size: + obs = np.stack(minibatch) + minibatch.clear() + act = np.zeros_like(obs, dtype=np.int32) + mask = np.ones_like(obs) + yield (obs, act, obs, mask) + + return Inputs(train_stream=random_minibatches, eval_stream=random_minibatches) + + +def _pad_to_multiple_of(x, y, axis): + """Pads x to multiple of y on the given axis.""" + pad_len = np.ceil(x.shape[axis] / float(y)) * y + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (0, int(pad_len - x.shape[axis])) + return np.pad(x, pad_widths, mode="constant", constant_values=x.dtype.type(0)) + + +@gin.configurable(module="trax.data") +def ConvertToUnicode(keys=None): # pylint: disable=invalid-name + """Converts to Unicode UTF-8 elements of an example. 
+
+    Useful for when TFDS outputs byte arrays. Conversion errors are ignored.
+
+    Args:
+      keys: tuple/list of example dimensions to convert.
+
+    Returns:
+      Function converting chosen elements of an example to UTF-8.
+    """
+
+    @debug_data_pipeline.debug_pipeline
+    def _convert_to_unicode_str(stream):
+        for example in stream:
+            if isinstance(example, (list, tuple)):
+                new_example = []
+                for i, x in enumerate(example):
+                    if keys is None or i in keys:
+                        new_example.append(_to_unicode(x))
+                    else:
+                        new_example.append(x)
+                output = tuple(new_example)
+                yield output
+            elif isinstance(example, dict):
+                new_example = {}
+                for k in example:
+                    if keys is None or k in keys:
+                        new_example[k] = _to_unicode(example[k])
+                    else:
+                        new_example[k] = example[k]
+                yield new_example
+            else:
+                output = _to_unicode(example)
+                yield output
+
+    return _convert_to_unicode_str
+
+
+def _to_unicode(s):
+    # Decoding errors are ignored (e.g. byte sequences not allowed by UTF-8)
+    # so that we do not end up with incomplete examples (with empty values).
+    return str(s, encoding="utf-8", errors="ignore")
+
+
+@gin.configurable(module="trax.data")
+def ClassificationVector(vocab_size=8192):  # pylint: disable=invalid-name
+    """Returns a function converting token sequences to vocabulary-sized vectors.
+
+    Note: the vectors are not one-hot; each present token's position holds the
+    token ID itself (see classification_vector below).
+    """
+    return lambda g: classification_vector(g, vocab_size=vocab_size)
+
+
+@debug_data_pipeline.debug_pipeline
+def classification_vector(generator, vocab_size=8192):
+    """Convert token sequences to classification vectors.
+
+    The generator stream is transformed by replacing token sequences with
+    vectors where each position contains the token ID if that token appears
+    in the text, otherwise 0.
+
+    Args:
+      generator: Stream of tuples where the first element is a token sequence.
+      vocab_size: Size of the vocabulary (defines length of the vector).
+
+    Yields:
+      Examples with token sequences converted to classification vectors.
+    """
+    for example in generator:
+        tokens = example[0]
+
+        # Create a zero vector of vocab_size length.
+        class_vector = np.zeros(vocab_size, dtype=np.int32)
+
+        # Set the token ID at the positions corresponding to present tokens.
+        for token_id in tokens:
+            if 0 <= token_id < vocab_size:  # Ensure token_id is in valid range.
+                class_vector[token_id] = token_id
+
+        # Create the output tuple with the classification vector replacing tokens.
+        output = (class_vector,) + example[1:]
+        yield output
diff --git a/trax/data/preprocessing/tf/__init__.py b/trax/data/preprocessing/tf/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/trax/data/preprocessing/tf/bert.py b/trax/data/preprocessing/tf/bert.py
new file mode 100644
index 000000000..1ed17ae54
--- /dev/null
+++ b/trax/data/preprocessing/tf/bert.py
@@ -0,0 +1,201 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""TensorFlow data sources and associated prepocessing functions.""" + +import functools + +import gin +import numpy as np + +from trax.data.loader.tf.base import TFDS, next_sentence_prediction_tf + + +@gin.configurable(module="trax.data") +def BertNextSentencePredictionInputs( + dataset_name, # pylint: disable=invalid-name + data_dir=None, + text_key="text", + train=True, + shuffle_size=50000, +): + """Defines a stream for the next sentence prediction task.""" + stream = TFDS( + dataset_name, + data_dir=data_dir, + tfds_preprocess_fn=next_sentence_prediction_tf( + text_key=text_key, + label_sentences=True, + buffer_size=shuffle_size, + ), + keys=["inputs", "targets"], + train=train, + ) + + def split_stream(generator=None): + # split string with 'sentence1:' and 'sentence2:' into two separate strings + for inputs, targets in stream(generator): + # Extract inputs and targets from the dictionary + + text_str = str(inputs)[:-1] # removes last '"' which is always at the end + print(text_str) + sentences = text_str.split("sentence1: ")[1].split(" sentence2: ") + if len(sentences) != 2: + # 'sentence2:' appeared in the text and got mixed up with the label + continue + sent1, sent2 = sentences + yield sent1, sent2, targets == "next" + + return split_stream + + +def BertSingleSentenceInputs( + batch, + labeled=True, + cls_id=101, + sep_id=102, # pylint: disable=invalid-name +): + """Prepares inputs for BERT: add [SEP], [CLS] and create embeddings.""" + if labeled: + for sent1, label in batch: + value_vector = np.concatenate(([cls_id], sent1, [sep_id])) + segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32) + yield value_vector, segment_embs, segment_embs, label, np.int32(1) + else: + for (sent1,) in batch: # row is a tuple with 1 element + value_vector = np.concatenate(([cls_id], sent1, [sep_id])) + segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32) + yield value_vector, segment_embs, segment_embs + + +def BertDoubleSentenceInputs( + batch, labeled=True, cls_id=101, sep_id=102 # pylint: disable=invalid-name +): + """Prepares inputs for BERT models by adding [SEP] and [CLS] tokens and creating segment embeddings.""" + if labeled: + for sent1, sent2, label in batch: + value_vector = np.concatenate(([cls_id], sent1, [sep_id], sent2, [sep_id])) + + segment_embs = np.zeros(sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32) + second_sent_start = sent1.shape[0] + 2 + segment_embs[second_sent_start:] = 1 + yield value_vector, segment_embs, segment_embs, label, np.int32(1) + else: + for sent1, sent2 in batch: + value_vector = np.concatenate(([cls_id], sent1, [sep_id], sent2, [sep_id])) + + segment_embs = np.zeros(sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32) + second_sent_start = sent1.shape[0] + 2 + segment_embs[second_sent_start:] = 1 + yield value_vector, segment_embs, segment_embs + + +@gin.configurable(module="trax.data") +def CreateBertInputs( + double_sentence=True, # pylint: disable=invalid-name + labeled=True, + cls_id=101, + sep_id=102, +): + bert_inputs_fn = ( + BertDoubleSentenceInputs if double_sentence else BertSingleSentenceInputs + ) + return functools.partial( + bert_inputs_fn, labeled=labeled, cls_id=cls_id, sep_id=sep_id + ) + + +@gin.configurable(module="trax.data") +def mask_random_tokens( + batch, + explicit_vocab_size=30522, + masking_prob=0.15, + cls_id=101, + sep_id=102, + mask_id=103, + vocab_start_id=999, +): + """Prepares input for the masking task. 
+
+    Preparation consists of masking a masking_prob fraction of the non-special
+    tokens in each input row; round(masking_prob * num_nonspecial_tokens)
+    random tokens are selected, and each selected token is either
+    - replaced with the [MASK] token with 80% probability,
+    - replaced with a random token with 10% probability,
+    - or left unchanged with 10% probability.
+    The implementation is based on
+    https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L342
+
+    Examples:
+    - batch is a stream with each row having tuple (token_ids,). The function
+      yields rows of the form (modified_token_ids, original_tokens,
+      token_weights), where modified_token_ids have [MASK] tokens or random
+      tokens according to the procedure described above.
+    - batch is a stream with each row having tuple (token_ids,
+      segment_embeddings, nsp_label, nsp_weight). The function yields rows of
+      the form (modified_token_ids, segment_embeddings, nsp_label, nsp_weight,
+      original_tokens, token_weights).
+
+    Args:
+      batch: stream of inputs. Each row in the stream is a tuple whose first
+        element is an array of tokens.
+      explicit_vocab_size: the total size of the vocabulary.
+      masking_prob: determines the fraction of non-special tokens to be
+        selected for masking.
+      cls_id: id of the special CLS token.
+      sep_id: id of the special SEP token.
+      mask_id: id of the special MASK token.
+      vocab_start_id: id of the first non-special token in the vocabulary.
+
+    Yields:
+      a stream with tokens masked for MLM training and 2 appended arrays:
+      - original tokens: a copy of the original tokens, used as a label for
+        MLM training
+      - token_weights: weights distributed uniformly over the selected tokens
+        (their sum is 1). Other tokens have 0 weight.
+    """
+    for token_ids, *row_rest in batch:
+        original_tokens = token_ids.copy()
+
+        # Choose tokens for prediction: a masking_prob fraction of all
+        # non-special tokens.
+        is_special_token = np.logical_or(
+            token_ids == cls_id, token_ids == sep_id
+        )  # CLS and SEP tokens
+        is_special_token = np.logical_or(is_special_token, token_ids == 0)  # padding
+        viable_ids = np.arange(token_ids.shape[0])[~is_special_token]
+        num_to_sample = round(masking_prob * viable_ids.shape[0])
+        if num_to_sample == 0:
+            # The sentence is too short to select the given fraction of tokens.
+            continue
+        candidate_ids = np.random.choice(viable_ids, num_to_sample, replace=False)
+
+        # Create the weights.
+        token_weights = np.zeros(token_ids.shape)
+        token_weights[candidate_ids] = 1 / candidate_ids.shape[0]
+
+        prob_scores = np.random.random(candidate_ids.shape)
+
+        # Change 80% of the selected tokens to [MASK].
+        mask_token_ids = candidate_ids[prob_scores < 0.8]
+        token_ids[mask_token_ids] = mask_id
+
+        # Change 10% of the selected tokens to a random token.
+        random_token_ids = candidate_ids[(0.8 <= prob_scores) & (prob_scores < 0.9)]
+        token_ids[random_token_ids] = np.random.randint(
+            vocab_start_id, explicit_vocab_size, random_token_ids.shape[0]
+        )
+
+        # The rest (10%) is left unchanged.
+        yield (token_ids, *row_rest, original_tokens, token_weights)
diff --git a/trax/data/preprocessing/tf/c4.py b/trax/data/preprocessing/tf/c4.py
new file mode 100644
index 000000000..ddea6e17a
--- /dev/null
+++ b/trax/data/preprocessing/tf/c4.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow data sources and associated preprocessing functions."""
+
+import functools
+
+import gin
+import tensorflow as tf
+import tensorflow_text as tf_text
+
+from trax.data.loader.tf.base import (
+    add_eos_to_output_features,
+    pad_dataset_to_length,
+    truncate_dataset_on_len,
+    unsupervised_preprocessors,
+)
+
+
+@gin.configurable(module="trax.data", denylist=["dataset", "training"])
+def c4_preprocess(
+    dataset, training, max_target_length=-1, tokenization=None, spm_path=None
+):
+    """Pre-processing function for the C4 dataset."""
+    del training
+
+    def unicode_decode_chars(features, targets):
+        targets = tf.strings.unicode_decode(features["text"], "UTF-8")
+        targets = tf.cast(targets, tf.int64)
+        features["targets"] = targets
+        features["inputs"] = targets
+        return (features, targets)
+
+    def spc_tokenize(tokenizer, features, targets):
+        del targets
+        tokenized_text = tokenizer.tokenize(features["text"])
+        features["targets"] = tf.cast(tokenized_text, tf.int64)
+        features["inputs"] = features["targets"]
+        return features, features["targets"]
+
+    if tokenization == "spc":
+        if not spm_path:
+            raise ValueError(
+                "A valid SentencePiece model path (`spm_path`) must be provided."
+            )
+
+        with tf.io.gfile.GFile(spm_path, "rb") as f:
+            spc_model = f.read()
+        tokenizer = tf_text.SentencepieceTokenizer(model=spc_model)
+        dataset = dataset.map(functools.partial(spc_tokenize, tokenizer))
+    else:
+        dataset = dataset.map(unicode_decode_chars)
+
+    def target_right_length(_, target):
+        return tf.less(tf.shape(target)[0], max_target_length + 1)
+
+    if max_target_length > 0:
+        dataset = dataset.filter(target_right_length)
+
+    return dataset
+
+
+@gin.configurable(module="trax.data", denylist=["dataset", "training"])
+def c4_bare_preprocess_fn(
+    dataset,
+    training=True,
+    spm_path=None,
+    copy_pretokenized=True,
+    sequence_length=None,
+):
+    """Preprocesses the C4 dataset to generate 'inputs' and 'targets'.
+
+    This version is T5-free and uses tensorflow_text with a SentencePiece
+    model for tokenization.
+    """
+
+    # Load the SentencePiece model.
+    with tf.io.gfile.GFile(spm_path, "rb") as f:
+        sp_model = f.read()
+    tokenizer = tf_text.SentencepieceTokenizer(model=sp_model)
+
+    # Rekey: move "text" to "targets", optionally keeping a pretokenized copy.
+    def rekey(example):
+        ret = {"targets": example["text"]}
+        if copy_pretokenized:
+            ret["targets_pretokenized"] = example["text"]
+        return ret
+
+    dataset = dataset.map(rekey, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Tokenize using SentencePiece.
+    def tokenize(example):
+        tokens = tokenizer.tokenize(example["targets"])
+        tokens = tf.cast(tokens, tf.int64)
+        example["inputs"] = tokens
+        example["targets"] = tokens
+        return example
+
+    dataset = dataset.map(tokenize, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Preprocess the tokens - the exact preprocessors are set via gin.
+    dataset = unsupervised_preprocessors(
+        dataset, training, sequence_length=sequence_length
+    )
+
+    # Add EOS.
+    dataset = add_eos_to_output_features(dataset, training)
+
+    # Truncate and then pad the examples -- all examples have the same shape.
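+    # Hypothetical illustration: with a gin config such as
+    # sequence_length={"inputs": 512, "targets": 512}, every example below
+    # ends up truncated and padded to exactly 512 tokens.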
+ dataset = truncate_dataset_on_len(dataset, training, sequence_length, True) + dataset = pad_dataset_to_length(dataset, training, sequence_length) + + return dataset diff --git a/trax/data/preprocessing/tf/cifar.py b/trax/data/preprocessing/tf/cifar.py new file mode 100644 index 000000000..f9b9a70d0 --- /dev/null +++ b/trax/data/preprocessing/tf/cifar.py @@ -0,0 +1,83 @@ +import gin +import tensorflow as tf + + +# Makes the function accessible in gin configs, even with all args denylisted. +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def cifar10_no_augmentation_preprocess(dataset, training): + del training + + def cast_image(features, targets): + features["image"] = tf.cast(features["image"], tf.float32) / 255.0 + return features, targets + + dataset = dataset.map(cast_image) + return dataset + + +def _cifar_augment_image(image): + """Image augmentation suitable for CIFAR-10/100. + + As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5). + + Args: + image: a Tensor. + + Returns: + Tensor of the same shape as image. + """ + image = tf.image.resize_with_crop_or_pad(image, 40, 40) + image = tf.image.random_crop(image, [32, 32, 3]) + image = tf.image.random_flip_left_right(image) + return image + + +# Makes the function accessible in gin configs, even with all args denylisted. +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def cifar10_augmentation_preprocess(dataset, training): + """Preprocessing for cifar10 with augmentation (see below).""" + + def augment(features, targets): + features["image"] = _cifar_augment_image(features["image"]) + return features, targets + + def cast_image(features, targets): + features["image"] = tf.cast(features["image"], tf.float32) / 255.0 + return features, targets + + if training: + dataset = dataset.map(augment) + dataset = dataset.map(cast_image) + return dataset + + +@gin.configurable(module="trax.data", denylist=["dataset", "training"]) +def cifar10_augmentation_flatten_preprocess( + dataset, training, predict_image_train_weight=0.01 +): + """Preprocessing for cifar10 that flattens it and appends targets.""" + + def augment(features, targets): + features["image"] = _cifar_augment_image(features["image"]) + return features, targets + + def flatten_image(features, targets): + """Flatten the image.""" + img = features["image"] + flat = tf.cast(tf.reshape(img, [-1]), tf.int64) + tgt = tf.expand_dims(targets, axis=0) + flat_with_target = tf.concat([flat, tgt], axis=0) + new_features = {} + new_features["image"] = flat_with_target + predict_image_weight = predict_image_train_weight if training else 0.0 + mask_begin = tf.ones_like(flat) + mask_begin = tf.cast(mask_begin, tf.float32) * predict_image_weight + mask_end = tf.cast(tf.ones_like(tgt), tf.float32) + new_features["mask"] = tf.concat([mask_begin, mask_end], axis=0) + return new_features, flat_with_target + + if training: + dataset = dataset.map(augment) + dataset = dataset.map(flatten_image) + + return dataset diff --git a/trax/data/preprocessing/tf/math.py b/trax/data/preprocessing/tf/math.py new file mode 100644 index 000000000..8913c6322 --- /dev/null +++ b/trax/data/preprocessing/tf/math.py @@ -0,0 +1,859 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow data sources and associated preprocessing functions."""
+
+import itertools
+import json
+import math
+import os
+import random
+import re
+
+import gin
+import numpy as np
+import scipy
+import scipy.special
+import tensorflow as tf
+
+
+def compute_single_result(op_name, num_args):
+    """An implementation of the most popular ops from the MathQA dataset."""
+    # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/
+    # and specifically line 142 and following in new_DataStructure.py
+    # for an implementation which covers more details.
+    if op_name == "add":
+        return num_args[0] + num_args[1]
+    elif op_name == "circle_arc":
+        return num_args[0] / 360 * math.pi * 2 * num_args[1]
+    elif op_name == "circle_area":
+        return math.pi * num_args[0] ** 2
+    elif op_name == "circle_sector_area":
+        return num_args[1] / 360 * math.pi * (num_args[0] ** 2)
+    elif op_name == "circumface":
+        return 2 * math.pi * num_args[0]
+    elif op_name == "choose":
+        return scipy.special.comb(num_args[0], num_args[1])
+    elif op_name == "cosine":
+        return math.cos(num_args[0])
+    elif op_name == "cube_edge_by_volume":
+        return num_args[0] ** (1 / 3)
+    elif op_name == "combined_work":
+        return 1 / (
+            min(num_args[0], 1 / num_args[0]) + min(num_args[1], 1 / num_args[1])
+        )
+    elif op_name == "count_interval":
+        return num_args[0] - num_args[1] + 1
+    elif op_name == "diagonal":
+        return math.sqrt(num_args[0] ** 2 + num_args[1] ** 2)
+    elif op_name == "divide" or op_name == "speed":
+        if num_args[1] != 0:
+            return num_args[0] / num_args[1]
+        else:
+            return 0
+    elif op_name == "factorial":
+        return math.factorial(min(15, int(num_args[0])))
+    elif op_name == "floor":
+        return math.floor(num_args[0])
+    elif op_name == "find_work":
+        return 1 / (
+            max(min(num_args[0], 1 / num_args[0]), min(num_args[1], 1 / num_args[1]))
+            - min(min(num_args[0], 1 / num_args[0]), min(num_args[1], 1 / num_args[1]))
+        )
+    elif op_name == "from_percent":
+        return num_args[0] / 100
+    elif op_name == "gain_percent":
+        return 100 + num_args[0]
+    elif op_name == "gcd":
+        # scipy has no top-level gcd; use math.gcd instead.
+        return math.gcd(int(num_args[0]), int(num_args[1]))
+    elif op_name == "inverse":
+        if num_args[0] != 0:
+            return 1 / num_args[0]
+        else:
+            return 0
+    elif op_name == "lcm":
+        # scipy has no top-level lcm; use math.lcm (Python 3.9+) instead.
+        return math.lcm(int(num_args[0]), int(num_args[1]))
+    elif op_name == "log":
+        return math.log(max(1e-5, num_args[0]), 2)
+    elif op_name == "loss_percent":
+        return 100 - num_args[0]
+    elif op_name == "max":
+        return max(num_args[0], num_args[1])
+    elif op_name == "multiply":
+        return num_args[0] * num_args[1]
+    elif op_name == "negate_percent":
+        return 100 - num_args[0]
+    elif op_name == "negate":
+        return -num_args[0]
+    elif op_name == "original_price_before_loss":
+        return num_args[1] * 100 / (100 + 1e-5 - num_args[0])
+    elif op_name == "original_price_before_gain":
+        return num_args[1] * 100 / (100 + num_args[0])
+    elif op_name == "permutation":
+        n, m = min(num_args[0], num_args[1]), max(num_args[0], num_args[1])
+        return math.factorial(int(m)) / math.factorial(int(m - n))
+    elif op_name == "power":
+        return num_args[0] ** min(num_args[1], 5)
+    elif op_name == "percent":
+        return num_args[0] / 100 * num_args[1]
+    elif op_name == "price_after_gain" or op_name == "p_after_gain":
+        return (1 + num_args[0] / 100) * num_args[1]
+    elif op_name == "price_after_loss" or op_name == "p_after_loss":
+        return (1 - num_args[0] / 100) * num_args[1]
+    elif op_name == "quadrilateral_area":
+        return num_args[0] * (num_args[1] + num_args[2]) / 2
+    elif op_name == "reminder":
+        return num_args[0] % num_args[1]
+    elif op_name == "rectangle_area":
+        return num_args[0] * num_args[1]
+    elif op_name == "rectangle_perimeter":
+        return 2 * (num_args[0] + num_args[1])
+    elif op_name == "rhombus_area":
+        return num_args[0] * num_args[1] / 2
+    elif op_name == "sine":
+        return math.sin(num_args[0])
+    elif op_name == "sqrt":
+        return math.sqrt(max(0, num_args[0]))
+    elif op_name == "subtract":
+        return num_args[0] - num_args[1]
+    elif op_name == "square_edge_by_perimeter":
+        return num_args[0] / 4
+    elif op_name == "square_edge_by_area":
+        return math.sqrt(num_args[0])
+    elif op_name == "square_area":
+        return num_args[0] ** 2
+    elif op_name == "surface_cube":
+        return 6 * num_args[0] ** 2
+    elif op_name == "surface_rectangular_prism":
+        return 2 * (
+            num_args[0] * num_args[1]
+            + num_args[0] * num_args[2]
+            + num_args[1] * num_args[2]
+        )
+    elif op_name == "semi_circle_perimiter":
+        return math.pi * num_args[0] + 2 * num_args[0]
+    elif op_name == "square_perimeter" or op_name == "rhombus_perimeter":
+        return 4 * num_args[0]
+    elif op_name == "surface_sphere":
+        return 4 * math.pi * num_args[0] ** 2
+    elif op_name == "speed_ratio_steel_to_stream":
+        return (num_args[0] + num_args[1]) / (num_args[0] - num_args[1])
+    elif op_name == "speed_in_still_water":
+        return (num_args[0] + num_args[1]) / 2
+    elif op_name == "stream_speed":
+        return (num_args[0] - num_args[1]) / 2
+    elif op_name == "trapezium_area":
+        return num_args[0] * (num_args[1] + num_args[2]) / 2
+    elif op_name == "triangle_area":
+        return num_args[0] * num_args[1] / 2
+    elif op_name == "triangle_perimeter":
+        return num_args[0] + num_args[1] + num_args[2]
+    elif op_name == "triangle_area_three_edges":
+        # Heron's formula
+        s = (num_args[0] + num_args[1] + num_args[2]) / 2
+        return math.sqrt(
+            max(0, s * (s - num_args[0]) * (s - num_args[1]) * (s - num_args[2]))
+        )
+    elif op_name == "union_prob":
+        return num_args[0] + num_args[1] - num_args[2]
+    elif op_name == "negate_prob":
+        return 1 - num_args[0]
+    elif op_name == "volume_cube":
+        return num_args[0] ** 3
+    elif op_name == "volume_cone":
+        return math.pi * num_args[0] ** 2 * num_args[1] / 3
+    elif op_name == "volume_cylinder":
+        return math.pi * num_args[0] ** 2 * num_args[1]
+    elif op_name == "volume_rectangular_prism":
+        return num_args[0] * num_args[1] * num_args[2]
+    elif op_name == "volume_sphere":
+        return 4 / 3 * math.pi * num_args[0] ** 3
+
+
+def compute_result(list_op, list_num):
+    """Python execution of MathQA ops."""
+    # The last of the temporary results is the final answer.
+    temporary_results = []
+    for op in list_op:
+        op_name = op.split("(")[0]
+        start_bracket = op.find("(")
+        end_bracket = op.find(")")
+        op_args = op[start_bracket + 1 : end_bracket].split(",")
+        num_args = []
+        for arg in op_args:
+            # The hash stands for a number stored in temporary_results.
+            # For example #2 refers to the third temporary result.
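+            # Worked example: for list_op = ["add(n0,n1)", "multiply(#0,n0)"]
+            # and list_num = [2.0, 3.0], "#0" resolves to 5.0 and the returned
+            # list is [5.0, 10.0].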
+ if arg[0] == "#": + temp_index = int( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + )[0] + ) + num_args.append(temporary_results[temp_index]) + # The n prefix stands for numbers which listed in list_num - + # originally they were contained in the text. + elif arg[0] == "n": + n_index = int( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + )[0] + ) + num_args.append(list_num[n_index]) + elif arg[0] == "c": + if arg == "const_pi": + constant = math.pi + elif arg == "const_deg_to_rad": + constant = math.pi / 180 + else: + consts = re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + ) + if len(consts) == 1: + constant = float(consts[0]) + else: + constant1 = float(consts[0]) + constant2 = float("0." + consts[1]) + constant = constant1 + constant2 + num_args.append(constant) + temporary_results.append(compute_single_result(op_name, num_args)) + return temporary_results + + +def single_op_to_python_command(op_name, num_args): + """An implementation of the most popular ops from the MathQA dataset.""" + # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ + # and specfically line 142 and following in new_DataStructure.py + # for an implementation which covers more details. + if op_name == "add": + return "{} + {}".format(num_args[0], num_args[1]) + elif op_name == "circle_arc": + return "{} / 360 * math.pi * 2 * {}".format(num_args[0], num_args[1]) + elif op_name == "circle_area": + return "math.pi * {}**2".format(num_args[0]) + elif op_name == "circle_sector_area": + return "{} / 360 * math.pi * ({}**2)".format(num_args[1], num_args[0]) + elif op_name == "circumface": + return "2 * math.pi * {}".format(num_args[0]) + elif op_name == "choose": + return "scipy.special.comb({}, {})".format(num_args[0], num_args[1]) + elif op_name == "cosine": + return "math.cos({})".format(num_args[0]) + elif op_name == "cube_edge_by_volume": + return "{}**(1 / 3)".format(num_args[0]) + elif op_name == "combined_work": + return "1 / (min({}, 1 / {}) + min({}, 1 / {}))".format( + num_args[0], num_args[0], num_args[1], num_args[1] + ) + elif op_name == "count_interval": + return "{} - {} + 1".format(num_args[0], num_args[1]) + elif op_name == "diagonal": + return "math.sqrt({}**2 + {}**2)".format(num_args[0], num_args[1]) + elif op_name == "divide" or op_name == "speed": + # safe divide + if num_args[1] != 0: + return "{} / {}".format(num_args[0], num_args[1]) + else: + return "0" + elif op_name == "factorial": + return "math.factorial(min(15, int({})))".format(num_args[0]) + elif op_name == "floor": + return "math.floor({})".format(num_args[0]) + elif op_name == "find_work": + return ( + "1 / (max(min({}, 1 / {}), min({}, 1 / {})) - min(min({}, 1 / {}), " + "min({}, 1 / {})))" + ).format( + num_args[0], + num_args[0], + num_args[1], + num_args[1], + num_args[0], + num_args[0], + num_args[1], + num_args[1], + ) + elif op_name == "from_percent": + return "{} / 100".format(num_args[0]) + elif op_name == "gain_percent": + return "100 + {}".format(num_args[0]) + elif op_name == "gcd": + return "scipy.gcd(int({}), int({}))".format(num_args[0], num_args[1]) + elif op_name == "inverse": + # safe inverse + if num_args[0] != 0: + return "1 / {}".format(num_args[0]) + else: + return "0" + elif op_name == "lcm": + return "scipy.lcm(int({}), int({}))".format(num_args[0], num_args[1]) + elif op_name == "log": + return "math.log(max(1e-5, {}), 2)".format(num_args[0]) + elif op_name == "loss_percent": + return "100 - 
{}".format(num_args[0]) + elif op_name == "max": + return "max({},{})".format(num_args[0], num_args[1]) + elif op_name == "multiply": + return "{} * {}".format(num_args[0], num_args[1]) + elif op_name == "negate_percent": + return "100 - {}".format(num_args[0]) + elif op_name == "negate": + return "-{}".format(num_args[0]) + elif op_name == "original_price_before_loss": + return "{} * 100 / (100 + 1e-5 - {}) # original price before loss".format( + num_args[1], num_args[0] + ) + elif op_name == "original_price_before_gain": + return "{} * 100 / (100 + {}) # original_price_before gain".format( + num_args[1], num_args[0] + ) + elif op_name == "permutation": + return ( + "math.factorial(int(max({}, {}))) / math.factorial(int(max({}, {}) " + "- min({}, {}))) # find all permutations" + ).format( + num_args[0], num_args[1], num_args[0], num_args[1], num_args[0], num_args[1] + ) + elif op_name == "power": + return "{}**min({}, 5)".format(num_args[0], num_args[1]) + elif op_name == "percent": + return "{} / 100 * {}".format(num_args[0], num_args[1]) + elif op_name == "price_after_gain" or op_name == "p_after_gain": + return "(1 + {} / 100) * {}".format(num_args[0], num_args[1]) + elif op_name == "price_after_loss" or op_name == "price_after_loss": + return "(1 - {} / 100) * {}".format(num_args[0], num_args[1]) + elif op_name == "quadrilateral_area": + return "{} * ({} + {}) / 2 # quadrilateral area".format( + num_args[0], num_args[1], num_args[2] + ) + elif op_name == "reminder": + return "{} % {}".format(num_args[0], num_args[1]) + elif op_name == "rectangle_area": + return "{} * {} # area of rectangle".format(num_args[0], num_args[1]) + elif op_name == "rectangle_perimeter": + return "2 * ({} + {}) # perimetere of rectangle".format( + num_args[0], num_args[1] + ) + elif op_name == "rhombus_area": + return "{} * {} / 2".format(num_args[0], num_args[1]) + elif op_name == "sine": + return "math.sin({})".format(num_args[0]) + elif op_name == "sqrt": + return "math.sqrt(max(0, {}))".format(num_args[0]) + elif op_name == "subtract": + return "{} - {}".format(num_args[0], num_args[1]) + elif op_name == "square_edge_by_perimeter": + return "{} / 4. 
+    elif op_name == "square_edge_by_area":
+        return "math.sqrt({})  # square edge given area".format(num_args[0])
+    elif op_name == "square_area":
+        return "{}**2".format(num_args[0])
+    elif op_name == "surface_cube":
+        return "6 * {}**2  # surface of a cube".format(num_args[0])
+    elif op_name == "surface_rectangular_prism":
+        return "2 * ({} * {} + {} * {} + {} * {})  # surface of a rectangular prism".format(
+            num_args[0], num_args[1], num_args[0], num_args[2], num_args[1], num_args[2]
+        )
+    elif op_name == "semi_circle_perimiter":
+        return "math.pi * {} + 2 * {}  # perimeter of a semi-circle".format(
+            num_args[0], num_args[0]
+        )
+    elif op_name == "square_perimeter" or op_name == "rhombus_perimeter":
+        return "4 * {}".format(num_args[0])
+    elif op_name == "surface_sphere":
+        return "4 * math.pi * {}**2".format(num_args[0])
+    elif op_name == "speed_ratio_steel_to_stream":
+        return "({} + {}) / ({} - {})".format(
+            num_args[0], num_args[1], num_args[0], num_args[1]
+        )
+    elif op_name == "speed_in_still_water":
+        # Parenthesized to match compute_single_result above.
+        return "({} + {}) / 2".format(num_args[0], num_args[1])
+    elif op_name == "stream_speed":
+        # Parenthesized to match compute_single_result above.
+        return "({} - {}) / 2".format(num_args[0], num_args[1])
+    elif op_name == "trapezium_area":
+        return "{} * ({} + {}) / 2".format(num_args[0], num_args[1], num_args[2])
+    elif op_name == "triangle_area":
+        return "{} * {} / 2".format(num_args[0], num_args[1])
+    elif op_name == "triangle_perimeter":
+        return "{} + {} + {}  # perimeter of a triangle".format(
+            num_args[0], num_args[1], num_args[2]
+        )
+    elif op_name == "triangle_area_three_edges":
+        return (
+            "(lambda s, a, b, c: math.sqrt(max(0, s * (s - a) * (s - b) * (s - "
+            "c))))(({} + {} + {}) / 2, {}, {}, {})  # Heron's formula"
+        ).format(
+            num_args[0], num_args[1], num_args[2], num_args[0], num_args[1], num_args[2]
+        )
+    elif op_name == "union_prob":
+        return "{} + {} - {}".format(num_args[0], num_args[1], num_args[2])
+    elif op_name == "negate_prob":
+        return "1 - {}".format(num_args[0])
+    elif op_name == "volume_cube":
+        return "{}**3".format(num_args[0])
+    elif op_name == "volume_cone":
+        return "math.pi * {}**2 * {} / 3".format(num_args[0], num_args[1])
+    elif op_name == "volume_cylinder":
+        return "math.pi * {}**2 * {}".format(num_args[0], num_args[1])
+    elif op_name == "volume_rectangular_prism":
+        return "{} * {} * {}".format(num_args[0], num_args[1], num_args[2])
+    elif op_name == "volume_sphere":
+        return "4 / 3 * math.pi * {}**3".format(num_args[0])
+
+
+def compute_program(list_op):
+    """Python translation of MathQA ops (emits one assignment per op)."""
+    # The last of the temporary results is the final answer.
+    temporary_results = []
+    num_op = 0
+    for op in list_op:
+        op_name = op.split("(")[0]
+        start_bracket = op.find("(")
+        end_bracket = op.find(")")
+        op_args = op[start_bracket + 1 : end_bracket].split(",")
+        num_args = []
+        for arg in op_args:
+            # The hash stands for a number stored in temporary_results.
+            # For example #2 refers to the third temporary result.
+            if arg[0] == "#":
+                temp_index = int(
+                    re.findall(
+                        r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg
+                    )[0]
+                )
+                num_args.append("t{}".format(temp_index))
+            # The n prefix stands for numbers which are listed in list_num --
+            # originally they were contained in the text.
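+            # Here "n" arguments stay symbolic (e.g. "n2"); the generated
+            # program defines n2 = <third number from the text> before use.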
+ elif arg[0] == "n": + # n_index = int( + # re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', + # arg)[0]) + num_args.append(arg) + elif arg[0] == "c": + if arg == "const_pi": + constant = math.pi + elif arg == "const_deg_to_rad": + constant = math.pi / 180 + else: + consts = re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", arg + ) + if len(consts) == 1: + constant = float(consts[0]) + else: + constant1 = float(consts[0]) + constant2 = float("0." + consts[1]) + constant = constant1 + constant2 + num_args.append(str(constant)) + temporary_result = "t{} = {}".format( + num_op, single_op_to_python_command(op_name, num_args) + ) + temporary_results.append(temporary_result) + num_op += 1 + return temporary_results + + +def compute_nums(question): + """Finds numbers in a string and convert them to floats.""" + # The funny looking replace is needed to deal with numbers such as 4,000 + # TODO(henrykm) deal with numbers written as words "one", "two", ... + return [ + float(num.replace(",", "")) + for num in re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", question + ) + ] + + +def compute_ops(linear_formula): + list_op = linear_formula.split("|") + # In some cases the list of operations contains a superflous last element, + # namely an empty string. + if not list_op[-1]: + list_op = list_op[:-1] + return list_op + + +def process_single_mathqa_example(example): + """Execute a single example and verify coherence of a MathQA problem. + + Args: + example: a dictionary with the following fields: Problem - a natural + language formulation of the problem Rationale - a natural language + solution of the problem options - five possible answers ( a) b) c) d) and + e) ) correct - the letter representing the correct answer + annotated_formula - formula representing the full solution linear_formula + - a string of operations separated by the | character, e.g. + multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)| + multiply(#2,const_100)|divide(#3,#1)| category - a natural language + description of the category to which a given problem belongs. + + Returns: + answer_num: numerical answer contained in the example + python_result: numerical answers computed in Python, including intermediate + results. The answer_num should be close python_result[-1] + list_op: list of arithmetic operations + list_num: list of identified numbers in the text + """ + question = example["Problem"] + list_num = compute_nums(question) + list_op = compute_ops(example["linear_formula"]) + answers = example["options"] + correct_answer = example["correct"] + index = answers.find("{} )".format(correct_answer)) + answer_string = re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", answers[index:] + ) + # The if statement deals with empty lists - they are needed to treat + # a correct non-numerical answer e) None of the above. Here we do not want + # non-numerical answers, hence we return None. + if answer_string: + answer_num = float( + re.findall( + r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", answers[index:] + )[0].replace(",", "") + ) + else: + return None + # The if statements below deals with answers written as fractions e.g. + # a ) 1 / 2 , b ) 1 / 3 , c ) 1 / 5 , d ) 10 / 30 , e ) 2 / 5 ? 
+    index_end_of_answer = index + len(str(answer_num)) + 3
+    if index_end_of_answer < len(answers) and answers[index_end_of_answer] == "/":
+        answer_denom = float(
+            re.findall(
+                r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?",
+                answers[index_end_of_answer:],
+            )[0].replace(",", "")
+        )
+        answer_num /= answer_denom
+    python_result = compute_result(list_op, list_num)
+    python_program = compute_program(list_op)
+    return answer_num, python_result, python_program, list_op, list_num
+
+
+def convert_float_to_mathqa(number):
+    floor = int(float(number))
+    if floor == number:
+        return "const_" + str(floor)
+    else:
+        return "const_" + str(floor) + "_" + str(number)[len(str(floor)) + 1 :]
+
+
+def convert_to_subtract(const_string):
+    return "subtract({},const_0)".format(const_string)
+
+
+def execute_mathqa_dsl_program(problem, dsl_code):
+    """Executes the DSL code for a given problem.
+
+    Args:
+      problem: problem formulation (needed to get parameters).
+      dsl_code: DSL code.
+
+    Returns:
+      the result of executing the DSL code.
+    """
+    n0_loc = problem.find("n0")
+    list_num = compute_nums(problem[n0_loc:])
+    # The list contains _all_ numbers in the string, hence in particular
+    # for n0 = 2.0 n1 = 3.0 we are getting list_num = [0.0, 2.0, 1.0, 3.0],
+    # so that below we are filtering the odd occurrences.
+    assert len(list_num) % 2 == 0
+    list_num = [list_num[2 * i + 1] for i in range(int(len(list_num) / 2))]
+
+    # dsl_code is a list of strings; since all DSL programs are single liners,
+    # we need to guess the correct line. For now we use the same location as
+    # in the ground-truth examples, that is the first line.
+    list_op = compute_ops(dsl_code[0])
+
+    try:
+        results = compute_result(list_op, list_num)[-1]
+    except:  # pylint: disable=bare-except
+        results = None
+    return results
+
+
+def is_number(s):
+    try:
+        float(s)
+        return True
+    except:  # pylint: disable=bare-except
+        return False
+
+
+def execute_mathqa_program(problem, program):
+    """Executes the Python program for a given problem.
+
+    Args:
+      problem: problem formulation (not needed, but we want the same API as
+        in the DSL case).
+      program: Python code.
+
+    Returns:
+      the result of executing the Python code.
+    """
+    del problem  # problem only needed in the DSL version.
+    # Programs are lists of strings. We need to concatenate them in order to exec.
+    program = "\n".join(program)
+    var_dict = {}
+    try:
+        # If exec finishes without raising, the generated code has defined an
+        # "answer" variable which we read back from var_dict.
+        exec(program, globals(), var_dict)  # pylint: disable=exec-used
+        if "answer" in var_dict and is_number(var_dict["answer"]):
+            return float(var_dict["answer"])
+        else:
+            return None
+    except:  # pylint: disable=bare-except
+        return None
+
+
+@gin.configurable(module="trax.data")
+def CreateMathQAInputs(  # pylint: disable=invalid-name
+    dataset_path=None,
+    train=True,
+    test=False,
+    challenge=False,
+    tolerance=0.01,
+    cumulative=True,
+    python_code=False,
+    full_dict=False,
+    partial_results=True,
+    nlp_rationale=False,
+    correct_answer=False,
+    answer_in_mathqa_format=True,
+    correct_answer_given_reasoning=False,
+    category=False,
+    order_prediction=False,
+    reduced_operation_name=True,
+    qed=False,
+):
+    """Prepares MathQA inputs.
+
+    The generation procedure leaves a lot of parameters to be set by the user.
+    Currently we support only correct examples in the following sense:
+    python execution agrees with the declared answer up to 1%.
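+
+    For example, with tolerance=0.01 a Python result of 3.14 matches a declared
+    answer of 3.13 (the relative difference is about 0.3%).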
+
+    According to this criterion, wrong examples such as
+      problem: calculate 85184 ÷ ? = 352
+      operations: ['multiply(n0,n1)']
+    are ignored (this should be divide(n0,n1) in this case).
+
+    Args:
+      dataset_path: a path with the MathQA dataset.
+      train: if True, then generate training examples; if train, test and
+        challenge are all set to False, generate validation examples.
+      test: if train is set to False and test is set to True,
+        then generate test examples.
+      challenge: if train and test are set to False and challenge is set to
+        True, then generate challenge examples.
+      tolerance: if for a given example the relative difference between the
+        Python result and the result declared in the dataset exceeds this
+        level, then the example is dropped; tolerances ranging from 0.1 to
+        0.001 yield from 18K to 21K examples.
+      cumulative: if set to True, then generate examples in the format
+        input - problem + numbers + op1 + op2 + op3; target - op4.
+        If set to False, then examples are in the format
+        input - problem + numbers; target - all operations.
+      python_code: if set to True, then generate Python code instead of
+        MathQA commands.
+      full_dict: if set to True, then Python examples are returned together
+        with the DSL code and the NLP rationale.
+      partial_results: if set to True, then partial results will be reported
+        as part of the input, e.g. input - problem + numbers + op1 + #1 + op2
+        + #2 + op3 + #3; target - op4, where #k is the partial result from
+        operation opk. Active only if cumulative is set to True.
+      nlp_rationale: if set to True, then the input is the problem and the
+        target is the NLP rationale.
+      correct_answer: if set to True, then the input is the problem plus all
+        possible answers and the target is the correct answer.
+      answer_in_mathqa_format: if set to True, then convert the numerical
+        answer to the MathQA format and wrap it in the subtract operation.
+        E.g. "3.13" is converted to "subtract(const_3_13,const_0)".
+      correct_answer_given_reasoning: if set to True, then the input is the
+        problem plus the linear formula plus all possible answers and the
+        target is the correct answer.
+      category: if set to True, then the input is the problem and the target
+        is its category.
+      order_prediction: if set to True, then the input is the problem and a
+        list of all operations; with probability 0.5 two operations are
+        swapped; the task consists in detecting whether the operations were
+        swapped. See the order prediction task in CreateAquaInputs in this
+        file.
+      reduced_operation_name: if set to True, then in order prediction
+        consider only the operation token without parameters.
+      qed: if set to True, then the reasoning is finished with an additional
+        operation qed.
+
+    Returns:
+      mathqa_yield_examples: a generator of MathQA examples; the generator
+        yields non-tokenized examples - they can be further processed using
+        for example the tokenize function from this module.
+    """
+    if train:
+        dataset_path = os.path.join(dataset_path, "train.json")
+    elif test:
+        dataset_path = os.path.join(dataset_path, "test.json")
+    elif challenge:
+        dataset_path = os.path.join(dataset_path, "challenge_test.json")
+    else:
+        dataset_path = os.path.join(dataset_path, "dev.json")
+    # Opening with GFile allows using remotely stored files, e.g.
+    # in a gs bucket.
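+    # Note: json.load below reads the whole split into memory, which is fine
+    # for MathQA-sized files.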
+    dataset_handle = tf.io.gfile.GFile(dataset_path, "r")
+    dataset = json.load(dataset_handle)
+
+    def mathqa_yield_examples(generator=None):
+        del generator
+        while True:
+            for example in itertools.cycle(dataset):
+                result = process_single_mathqa_example(example)
+                # TODO(henrykm): Remove the first two ifs.
+                if not result:
+                    continue
+                answer_num, python_result, python_program, list_op, list_num = result
+                if not answer_num or not python_result[-1]:
+                    continue
+                if qed:
+                    list_op.append("qed")
+                if math.isclose(answer_num, python_result[-1], rel_tol=tolerance):
+                    input_prefix = example["Problem"]
+                    for i in range(len(list_num)):
+                        input_prefix += " n{} = {}".format(i, list_num[i])
+                    if cumulative:
+                        for i in range(len(list_op)):
+                            input_values = input_prefix
+                            target_values = list_op[i]
+                            input_prefix += " " + list_op[i]
+                            if partial_results:
+                                input_prefix += " #{} = {}".format(i, answer_num)
+                            yield (
+                                input_values,
+                                target_values,
+                                np.array([1] * len(target_values)),
+                            )
+                    elif python_code:
+                        input_values = "# " + input_prefix
+                        target_values = ""
+                        for command in python_program:
+                            if "math" in command:
+                                target_values += "import math\n"
+                                break
+                        for command in python_program:
+                            if "scipy" in command:
+                                target_values += "import scipy\n"
+                                break
+                        for i in range(len(list_num)):
+                            target_values += "n{} = {}\n".format(i, list_num[i])
+                        target_values += "\n".join(python_program[:-1])
+                        final_line = python_program[-1].split("=")[1]
+                        target_values += "\nanswer ={}".format(final_line)
+                        var_dict = {}
+                        # We generate Python code and check whether the computed
+                        # answer is correct.
+                        exec(
+                            target_values, globals(), var_dict
+                        )  # pylint: disable=exec-used
+                        if math.isclose(
+                            answer_num, var_dict["answer"], rel_tol=tolerance
+                        ):
+                            if full_dict:
+                                yield (
+                                    input_values,
+                                    target_values,
+                                    example["linear_formula"],
+                                    example["Rationale"],
+                                )
+                            else:
+                                yield (
+                                    input_values,
+                                    target_values,
+                                    np.array([1] * len(target_values)),
+                                )
+                    elif nlp_rationale:
+                        input_values = "infer full rationale: " + input_prefix
+                        target_values = example["Rationale"]
+                        yield (
+                            input_values,
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+                    elif correct_answer:
+                        input_values = "infer correct answer: " + input_prefix
+                        input_values += " " + example["options"]
+                        if answer_in_mathqa_format:
+                            target_values = str(answer_num)
+                            target_values = convert_to_subtract(
+                                convert_float_to_mathqa(target_values)
+                            )
+                        else:
+                            target_values = example["correct"]
+                        yield (
+                            input_values,
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+                    elif correct_answer_given_reasoning:
+                        input_values = (
+                            "infer correct answer given reasoning: " + input_prefix
+                        )
+                        input_values += (
+                            " " + " ".join(list_op) + " " + example["options"]
+                        )
+                        target_values = example["correct"]
+                        yield (
+                            input_values,
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+                    elif category:
+                        input_values = "infer category: " + input_prefix
+                        target_values = example["category"]
+                        yield (
+                            input_values,
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+                    elif order_prediction:
+                        if np.random.uniform() < 0.5 and len(list_op) >= 2:
+                            idx = range(len(list_op))
+                            i1, i2 = random.sample(idx, 2)
+                            list_op[i1], list_op[i2] = list_op[i2], list_op[i1]
+                            target_values = "not_ordered"
+                        else:
+                            target_values = "ordered"
+                        if reduced_operation_name:
+                            list_op = [op.split("(")[0] for op in list_op]
+                        input_values = (
+                            "order prediction: "
+                            + input_prefix
+                            + " "
+                            + " ".join(list_op)
+                        )
+                        yield (
+                            input_values,
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+                    else:
+                        input_values = "infer full calculation: " + input_prefix
+                        target_values = example["linear_formula"]
+                        yield (
+                            input_values,
+                            target_values,
+                            np.array([1] * len(target_values)),
+                        )
+
+    return mathqa_yield_examples
diff --git a/trax/data/preprocessing/tf/video.py b/trax/data/preprocessing/tf/video.py
new file mode 100644
index 000000000..417f8c255
--- /dev/null
+++ b/trax/data/preprocessing/tf/video.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow data sources and associated preprocessing functions."""
+
+import gin
+
+
+@gin.configurable(module="trax.data", denylist=["hparams"])
+def bair_robot_pushing_hparams(
+    hparams=None, video_num_input_frames=1, video_num_target_frames=15
+):
+    """Sets video frame counts on hparams, or returns them if hparams is None."""
+    if hparams is not None:
+        hparams.video_num_input_frames = video_num_input_frames
+        hparams.video_num_target_frames = video_num_target_frames
+    else:
+        return video_num_input_frames, video_num_target_frames
diff --git a/trax/data/preprocessing/tf/wmt.py b/trax/data/preprocessing/tf/wmt.py
new file mode 100644
index 000000000..9d12137e1
--- /dev/null
+++ b/trax/data/preprocessing/tf/wmt.py
@@ -0,0 +1,92 @@
+import gin
+import numpy as np
+import tensorflow as tf
+
+
+# TODO(lukaszkaiser): find a single more abstract way of text pre-processing.
+@gin.configurable(module="trax.data", denylist=["dataset", "training"])
+def wmt_preprocess(
+    dataset, training, max_length=-1, max_eval_length=-1, tokenizer=None
+):
+    """Preprocessing for WMT: filter out examples exceeding maximum length."""
+
+    def train_right_length(example):
+        input_length = tf.strings.length(example["inputs"])
+        target_length = tf.strings.length(example["targets"])
+        max_tensor_length = tf.maximum(input_length, target_length)
+        return tf.less(max_tensor_length, max_length + 1)
+
+    def eval_right_length(example):
+        input_length = tf.strings.length(example["inputs"])
+        target_length = tf.strings.length(example["targets"])
+        max_tensor_length = tf.maximum(input_length, target_length)
+        return tf.less(max_tensor_length, max_eval_length + 1)
+
+    # Elements arrive as 2-tuples; keep only the first component (the
+    # feature dict with "inputs" and "targets").
+    dataset = dataset.map(lambda x, y: x)
+
+    if max_length > 0 and training:
+        dataset = dataset.filter(train_right_length)
+
+    if max_eval_length > 0 and not training:
+        dataset = dataset.filter(eval_right_length)
+
+    def tokenize_example(encoder, example):
+        """Tokenize examples using a SubwordTextEncoder.
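+
+        Note: tokenization goes through tf.py_function below, because
+        SubwordTextEncoder.encode is plain Python and cannot be traced
+        directly into the tf.data graph.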
+
+        Args:
+          encoder: A trax.data.encoder.encoder.SubwordTextEncoder instance.
+          example: A dictionary with 'inputs' and 'targets' keys containing
+            text tensors.
+
+        Returns:
+          A dictionary with tokenized 'inputs' and 'targets'.
+        """
+
+        def _encode_text(text_tensor):
+            # Convert tensor to string
+            if hasattr(text_tensor, "numpy"):
+                # Handle TensorFlow tensor
+                text = text_tensor.numpy()
+                if isinstance(text, bytes):
+                    text = text.decode("utf-8")
+            else:
+                # Already string or bytes
+                text = text_tensor
+                if isinstance(text, bytes):
+                    text = text.decode("utf-8")
+
+            # Use the encoder's encode method directly
+            return np.array(encoder.encode(text), dtype=np.int64)
+
+        # Use tf.py_function to run the Python tokenizer inside the
+        # TensorFlow graph.
+        encoded_inputs = tf.py_function(_encode_text, [example["inputs"]], tf.int64)
+
+        encoded_targets = tf.py_function(_encode_text, [example["targets"]], tf.int64)
+
+        # Update the example with encoded data
+        return {"inputs": encoded_inputs, "targets": encoded_targets}, encoded_targets
+
+    # Apply tokenization to the dataset.
+    dataset = dataset.map(
+        lambda example: tokenize_example(tokenizer, example),
+        num_parallel_calls=tf.data.AUTOTUNE,
+    )
+
+    return dataset
+
+
+@gin.configurable(module="trax.data", denylist=["dataset", "training"])
+def wmt_concat_preprocess(dataset, training, max_length=-1, max_eval_length=-1):
+    """Preprocessing for WMT: filter out examples exceeding maximum length and concatenate."""
+    dataset = wmt_preprocess(dataset, training, max_length, max_eval_length)
+
+    def concat_and_add_mask(features, targets):
+        inp = features["inputs"]
+        pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0)
+        concat = tf.concat([inp, pad, targets], axis=0)
+        mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0)
+        features["inputs"] = concat
+        features["mask"] = mask
+        return features, concat
+
+    dataset = dataset.map(concat_and_add_mask)
+    return dataset
diff --git a/trax/data/text_encoder.py b/trax/data/text_encoder.py
deleted file mode 100644
index 245d9f312..000000000
--- a/trax/data/text_encoder.py
+++ /dev/null
@@ -1,1338 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Encoders for text data.
-
-* TextEncoder: base class
-* ByteTextEncoder: for ascii text
-* TokenTextEncoder: with user-supplied vocabulary file
-* SubwordTextEncoder: invertible
-* BertEncoder: for compatible tokenizers with original bert
-"""
-
-import collections
-import itertools
-import math
-import re
-import tempfile
-import time
-import unicodedata
-
-from absl import logging
-import numpy as np
-import six
-import tensorflow as tf
-from trax.data import tokenizer
-
-# Reserved tokens for things like padding and EOS symbols.
-PAD = "<pad>"
-EOS = "<EOS>"
-RESERVED_TOKENS = [PAD, EOS]
-NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
-PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
-EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
-RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
-
-# Regular expression for unescaping token strings.
-# '\u' is converted to '_'
-# '\\' is converted to '\'
-# '\213;' is converted to unichr(213)
-_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
-_ESCAPE_CHARS = set(u"\\_u;0123456789")
-
-
-# Unicode utility functions that work with Python 2 and 3
-def native_to_unicode(s):
-    if is_unicode(s):
-        return s
-    try:
-        return to_unicode(s)
-    except UnicodeDecodeError:
-        res = to_unicode(s, ignore_errors=True)
-        logging.info("Ignoring Unicode error, outputting: %s", res)
-        return res
-
-
-def is_unicode(s):
-    return isinstance(s, six.text_type)
-
-
-def to_unicode(s, ignore_errors=False):
-    if is_unicode(s):
-        return s
-    error_mode = "ignore" if ignore_errors else "strict"
-    return s.decode("utf-8", errors=error_mode)
-
-
-def to_unicode_ignore_errors(s):
-    return to_unicode(s, ignore_errors=True)
-
-
-def to_unicode_utf8(s):
-    return s.decode("utf-8")
-
-
-def strip_ids(ids, ids_to_strip):
-    """Strip ids_to_strip from the end IDs."""
-    ids = list(ids)
-    while ids and ids[-1] in ids_to_strip:
-        ids.pop()
-    return ids
-
-
-class TextEncoder:
-    """Base class for converting from ints to/from human readable strings."""
-
-    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
-        self._num_reserved_ids = num_reserved_ids
-
-    @property
-    def num_reserved_ids(self):
-        return self._num_reserved_ids
-
-    def encode(self, s):
-        """Transform a human-readable string into a sequence of int IDs.
-
-        The IDs should be in the range [num_reserved_ids, vocab_size). IDs [0,
-        num_reserved_ids) are reserved.
-
-        EOS is not appended.
-
-        Args:
-          s: human-readable string to be converted.
-
-        Returns:
-          ids: list of integers
-        """
-        return [int(w) + self._num_reserved_ids for w in s.split()]
-
-    def decode(self, ids, strip_extraneous=False):
-        """Transform a sequence of int IDs into a human-readable string.
-
-        EOS is not expected in IDs.
-
-        Args:
-          ids: list of integers to be converted.
-          strip_extraneous: bool, whether to strip off extraneous tokens (EOS and
-            PAD).
-
-        Returns:
-          s: human-readable string.
-        """
-        if strip_extraneous:
-            ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
-        return " ".join(self.decode_list(ids))
-
-    def decode_list(self, ids):
-        """Transform a sequence of int IDs into a their string versions.
-
-        This method supports transforming individual input/output IDs to their
-        string versions so that sequence to/from text conversions can be visualized
-        in a human readable format.
-
-        Args:
-          ids: list of integers to be converted.
-
-        Returns:
-          strs: list of human-readable string.
-        """
-        decoded_ids = []
-        for id_ in ids:
-            if 0 <= id_ < self._num_reserved_ids:
-                decoded_ids.append(RESERVED_TOKENS[int(id_)])
-            else:
-                decoded_ids.append(id_ - self._num_reserved_ids)
-        return [str(d) for d in decoded_ids]
-
-    @property
-    def vocab_size(self):
-        raise NotImplementedError()
-
-
-class ByteTextEncoder(TextEncoder):
-    """Encodes each byte to an id.
For 8-bit strings only.""" - - def encode(self, s): - numres = self._num_reserved_ids - # Python3: explicitly convert to UTF-8 - return [c + numres for c in s.encode("utf-8")] - - def decode(self, ids, strip_extraneous=False): - if strip_extraneous: - ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) - numres = self._num_reserved_ids - decoded_ids = [] - int2byte = six.int2byte - for id_ in ids: - if 0 <= id_ < numres: - decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) - else: - decoded_ids.append(int2byte(id_ - numres)) - # Python3: join byte arrays and then decode string - return b"".join(decoded_ids).decode("utf-8", "replace") - - def decode_list(self, ids): - numres = self._num_reserved_ids - decoded_ids = [] - int2byte = six.int2byte - for id_ in ids: - if 0 <= id_ < numres: - decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) - else: - decoded_ids.append(int2byte(id_ - numres)) - # Python3: join byte arrays and then decode string - return decoded_ids - - @property - def vocab_size(self): - return 2**8 + self._num_reserved_ids - - -class ClassLabelEncoder(TextEncoder): - """Encoder for class labels.""" - - def __init__(self, class_labels=None, class_labels_fname=None): - super(ClassLabelEncoder, self).__init__(num_reserved_ids=0) - - if class_labels_fname: - with tf.io.gfile.GFile(class_labels_fname) as f: - class_labels = [label.strip() for label in f.readlines()] - - assert class_labels - self._class_labels = class_labels - - def encode(self, s): - label_str = s - return self._class_labels.index(label_str) - - def decode(self, ids, strip_extraneous=False): - del strip_extraneous - label_id = ids - if isinstance(label_id, list): - assert len(label_id) == 1 - label_id, = label_id - if isinstance(label_id, np.ndarray): - label_id = np.squeeze(label_id) - return self._class_labels[label_id] - - def decode_list(self, ids): - return [self._class_labels[i] for i in ids] - - @property - def vocab_size(self): - return len(self._class_labels) - - -class OneHotClassLabelEncoder(ClassLabelEncoder): - """One-hot encoder for class labels.""" - - def encode(self, label_str, on_value=1, off_value=0): # pylint: disable=arguments-differ - e = np.full(self.vocab_size, off_value, dtype=np.int32) - e[self._class_labels.index(label_str)] = on_value - return e.tolist() - - def decode(self, ids, strip_extraneous=False): - del strip_extraneous - label_id = ids - if isinstance(label_id, np.ndarray): - label_id = np.squeeze(label_id).astype(np.int8).tolist() - assert isinstance(label_id, list) - assert len(label_id) == self.vocab_size - return self._class_labels[label_id.index(1)] - - @property - def vocab_size(self): - return len(self._class_labels) - - -class TokenTextEncoder(TextEncoder): - """Encoder based on a user-supplied vocabulary (file or list).""" - - def __init__(self, - vocab_filename, - reverse=False, - vocab_list=None, - replace_oov=None, - num_reserved_ids=NUM_RESERVED_TOKENS): - """Initialize from a file or list, one token per line. - - Handling of reserved tokens works as follows: - - When initializing from a list, we add reserved tokens to the vocab. - - When initializing from a file, we do not add reserved tokens to the vocab. - - When saving vocab files, we save reserved tokens to the file. - - Args: - vocab_filename: If not None, the full filename to read vocab from. If this - is not None, then vocab_list should be None. - reverse: Boolean indicating if tokens should be reversed during encoding - and decoding. 
- vocab_list: If not None, a list of elements of the vocabulary. If this is - not None, then vocab_filename should be None. - replace_oov: If not None, every out-of-vocabulary token seen when encoding - will be replaced by this string (which must be in vocab). - num_reserved_ids: Number of IDs to save for reserved tokens like . - """ - super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) - self._reverse = reverse - self._replace_oov = replace_oov - if vocab_filename: - self._init_vocab_from_file(vocab_filename) - else: - assert vocab_list is not None - self._init_vocab_from_list(vocab_list) - - def encode(self, s): - """Converts a space-separated string of tokens to a list of ids.""" - sentence = s - tokens = sentence.strip().split() - if self._replace_oov is not None: - tokens = [ - t if t in self._token_to_id else self._replace_oov for t in tokens - ] - ret = [self._token_to_id[tok] for tok in tokens] - return ret[::-1] if self._reverse else ret - - def decode(self, ids, strip_extraneous=False): - return " ".join(self.decode_list(ids)) - - def decode_list(self, ids): - seq = reversed(ids) if self._reverse else ids - return [self._safe_id_to_token(i) for i in seq] - - @property - def vocab_size(self): - return len(self._id_to_token) - - def _safe_id_to_token(self, idx): - return self._id_to_token.get(idx, "ID_%d" % idx) - - def _init_vocab_from_file(self, filename): - """Load vocab from a file. - - Args: - filename: The file to load vocabulary from. - """ - with tf.io.gfile.GFile(filename) as f: - tokens = [token.strip() for token in f.readlines()] - - def token_gen(): - for token in tokens: - yield token - - self._init_vocab(token_gen(), add_reserved_tokens=False) - - def _init_vocab_from_list(self, vocab_list): - """Initialize tokens from a list of tokens. - - It is ok if reserved tokens appear in the vocab list. They will be - removed. The set of tokens in vocab_list should be unique. - - Args: - vocab_list: A list of tokens. - """ - - def token_gen(): - for token in vocab_list: - if token not in RESERVED_TOKENS: - yield token - - self._init_vocab(token_gen()) - - def _init_vocab(self, token_generator, add_reserved_tokens=True): - """Initialize vocabulary with tokens from token_generator.""" - - self._id_to_token = {} - non_reserved_start_index = 0 - - if add_reserved_tokens: - self._id_to_token.update(enumerate(RESERVED_TOKENS)) - non_reserved_start_index = len(RESERVED_TOKENS) - - self._id_to_token.update( - enumerate(token_generator, start=non_reserved_start_index)) - - # _token_to_id is the reverse of _id_to_token - self._token_to_id = dict( - (v, k) for k, v in six.iteritems(self._id_to_token)) - - def store_to_file(self, filename): - """Write vocab file to disk. - - Vocab files have one token per line. The file ends in a newline. Reserved - tokens are written to the vocab file as well. - - Args: - filename: Full path of the file to store the vocab to. - """ - with tf.io.gfile.GFile(filename, "w") as f: - for i in range(len(self._id_to_token)): - f.write(self._id_to_token[i] + "\n") - - -def _escape_token(token, alphabet): - """Escape away underscores and OOV characters and append '_'. - - This allows the token to be expressed as the concatenation of a list - of subtokens from the vocabulary. The underscore acts as a sentinel - which allows us to invertibly concatenate multiple such lists. - - Args: - token: A unicode string to be escaped. - alphabet: A set of all characters in the vocabulary's alphabet. - - Returns: - escaped_token: An escaped unicode string. 
- - Raises: - ValueError: If the provided token is not unicode. - """ - if not isinstance(token, six.text_type): - raise ValueError("Expected string type for token, got %s" % type(token)) - - token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") - ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token] - return u"".join(ret) + "_" - - -def _unescape_token(escaped_token): - """Inverse of _escape_token(). - - Args: - escaped_token: a unicode string - - Returns: - token: a unicode string - """ - - def match(m): - if m.group(1) is None: - return u"_" if m.group(0) == u"\\u" else u"\\" - - try: - return six.unichr(int(m.group(1))) - except (ValueError, OverflowError) as _: - return u"\u3013" # Unicode for undefined character. - - trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token - return _UNESCAPE_REGEX.sub(match, trimmed) - - -class SubwordTextEncoder(TextEncoder): - """Class for invertibly encoding text using a limited vocabulary. - - Invertibly encodes a native string as a sequence of subtokens from a limited - vocabulary. - - A SubwordTextEncoder is built from a corpus (so it is tailored to the text in - the corpus), and stored to a file. See text_encoder_build_subword.py. - - It can then be loaded and used to encode/decode any text. - - Encoding has four phases: - - 1. Tokenize into a list of tokens. Each token is a unicode string of either - all alphanumeric characters or all non-alphanumeric characters. We drop - tokens consisting of a single space that are between two alphanumeric - tokens. - - 2. Escape each token. This escapes away special and out-of-vocabulary - characters, and makes sure that each token ends with an underscore, and - has no other underscores. - - 3. Represent each escaped token as a the concatenation of a list of subtokens - from the limited vocabulary. Subtoken selection is done greedily from - beginning to end. That is, we construct the list in order, always picking - the longest subtoken in our vocabulary that matches a prefix of the - remaining portion of the encoded token. - - 4. Concatenate these lists. This concatenation is invertible due to the - fact that the trailing underscores indicate when one list is finished. - - """ - - def __init__(self, filename=None): - """Initialize and read from a file, if provided. - - Args: - filename: filename from which to read vocab. If None, do not load a vocab - """ - self._alphabet = set() - self.filename = filename - if filename is not None: - self._load_from_file(filename) - super(SubwordTextEncoder, self).__init__() - - def encode(self, s): - """Converts a native string to a list of subtoken IDs. - - Args: - s: a native string. - - Returns: - a list of integers in the range [0, vocab_size) - """ - return self._tokens_to_subtoken_ids(tokenizer.encode(native_to_unicode(s))) - - def encode_without_tokenizing(self, token_text): - """Converts string to list of subtoken IDs without calling tokenizer. - - This treats `token_text` as a single token and directly converts it - to subtoken IDs. This may be useful when the default tokenizer doesn't - do what we want (e.g., when encoding text with tokens composed of lots of - nonalphanumeric characters). It is then up to the caller to make sure that - raw text is consistently converted into tokens. Only use this if you are - sure that `encode` doesn't suit your needs. - - Args: - token_text: A native string representation of a single token. 
- - Returns: - A list of subword token IDs; i.e., integers in the range [0, vocab_size). - """ - return self._tokens_to_subtoken_ids([native_to_unicode(token_text)]) - - def decode(self, ids, strip_extraneous=False): - """Converts a sequence of subtoken IDs to a native string. - - Args: - ids: a list of integers in the range [0, vocab_size) - strip_extraneous: bool, whether to strip off extraneous tokens (EOS and - PAD). - - Returns: - a native string - """ - if strip_extraneous: - ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) - return tokenizer.decode(self._subtoken_ids_to_tokens(ids)) - - def decode_list(self, ids): - return [self._subtoken_id_to_subtoken_string(s) for s in ids] - - @property - def vocab_size(self): - """The subtoken vocabulary size.""" - return len(self._all_subtoken_strings) - - def _tokens_to_subtoken_ids(self, tokens): - """Converts a list of tokens to a list of subtoken IDs. - - Args: - tokens: a list of strings. - - Returns: - a list of integers in the range [0, vocab_size) - """ - ret = [] - for token in tokens: - ret.extend(self._token_to_subtoken_ids(token)) - return ret - - def _token_to_subtoken_ids(self, token): - """Converts token to a list of subtoken IDs. - - Args: - token: a string. - - Returns: - a list of integers in the range [0, vocab_size) - """ - cache_location = hash(token) % self._cache_size - cache_key, cache_value = self._cache[cache_location] - if cache_key == token: - return cache_value - ret = self._escaped_token_to_subtoken_ids( - _escape_token(token, self._alphabet)) - self._cache[cache_location] = (token, ret) - return ret - - def _subtoken_ids_to_tokens(self, subtokens): - """Converts a list of subtoken IDs to a list of tokens. - - Args: - subtokens: a list of integers in the range [0, vocab_size) - - Returns: - a list of strings. - """ - concatenated = "".join( - [self._subtoken_id_to_subtoken_string(s) for s in subtokens]) - split = concatenated.split("_") - ret = [] - for t in split: - if t: - unescaped = _unescape_token(t + "_") - if unescaped: - ret.append(unescaped) - return ret - - def _subtoken_id_to_subtoken_string(self, subtoken): - """Converts a subtoken integer ID to a subtoken string.""" - if 0 <= subtoken < self.vocab_size: - return self._all_subtoken_strings[subtoken] - return u"" - - def _escaped_token_to_subtoken_strings(self, escaped_token): - """Converts an escaped token string to a list of subtoken strings. - - Args: - escaped_token: An escaped token as a unicode string. - - Returns: - A list of subtokens as unicode strings. - """ - # NOTE: This algorithm is greedy; it won't necessarily produce the "best" - # list of subtokens. - ret = [] - start = 0 - token_len = len(escaped_token) - while start < token_len: - for end in range( - min(token_len, start + self._max_subtoken_len), start, -1): - subtoken = escaped_token[start:end] - if subtoken in self._subtoken_string_to_id: - ret.append(subtoken) - start = end - break - - else: # Did not break - # If there is no possible encoding of the escaped token then one of the - # characters in the token is not in the alphabet. This should be - # impossible and would be indicative of a bug. - assert False, "Token substring not found in subtoken vocabulary." - - return ret - - def _escaped_token_to_subtoken_ids(self, escaped_token): - """Converts an escaped token string to a list of subtoken IDs. - - Args: - escaped_token: An escaped token as a unicode string. - - Returns: - A list of subtoken IDs as integers. 
- """ - return [ - self._subtoken_string_to_id[subtoken] - for subtoken in self._escaped_token_to_subtoken_strings(escaped_token) - ] - - @classmethod - def build_from_generator(cls, - generator, - target_size, - max_subtoken_length=None, - reserved_tokens=None): - """Builds a SubwordTextEncoder from the generated text. - - Args: - generator: yields text. - target_size: int, approximate vocabulary size to create. - max_subtoken_length: Maximum length of a subtoken. If this is not set, - then the runtime and memory use of creating the vocab is quadratic in - the length of the longest token. If this is set, then it is instead - O(max_subtoken_length * length of longest token). - reserved_tokens: List of reserved tokens. The global variable - `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this - argument is `None`, it will use `RESERVED_TOKENS`. - - Returns: - SubwordTextEncoder with `vocab_size` approximately `target_size`. - """ - token_counts = collections.defaultdict(int) - for item in generator: - for tok in tokenizer.encode(native_to_unicode(item)): - token_counts[tok] += 1 - encoder = cls.build_to_target_size( - target_size, - token_counts, - 1, - 1e3, - max_subtoken_length=max_subtoken_length, - reserved_tokens=reserved_tokens) - return encoder - - @classmethod - def build_to_target_size(cls, - target_size, - token_counts, - min_val, - max_val, - max_subtoken_length=None, - reserved_tokens=None, - num_iterations=4): - """Builds a SubwordTextEncoder that has `vocab_size` near `target_size`. - - Uses simple recursive binary search to find a minimum token count that most - closely matches the `target_size`. - - Args: - target_size: Desired vocab_size to approximate. - token_counts: A dictionary of token counts, mapping string to int. - min_val: An integer; lower bound for the minimum token count. - max_val: An integer; upper bound for the minimum token count. - max_subtoken_length: Maximum length of a subtoken. If this is not set, - then the runtime and memory use of creating the vocab is quadratic in - the length of the longest token. If this is set, then it is instead - O(max_subtoken_length * length of longest token). - reserved_tokens: List of reserved tokens. The global variable - `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this - argument is `None`, it will use `RESERVED_TOKENS`. - num_iterations: An integer; how many iterations of refinement. - - Returns: - A SubwordTextEncoder instance. - - Raises: - ValueError: If `min_val` is greater than `max_val`. - """ - if min_val > max_val: - raise ValueError("Lower bound for the minimum token count " - "is greater than the upper bound.") - if target_size < 1: - raise ValueError("Target size must be positive.") - - if reserved_tokens is None: - reserved_tokens = RESERVED_TOKENS - - def bisect(min_val, max_val): - """Bisection to find the right size.""" - present_count = (max_val + min_val) // 2 - logging.info("Trying min_count %d", present_count) - subtokenizer = cls() - subtokenizer.build_from_token_counts( - token_counts, - present_count, - num_iterations, - max_subtoken_length=max_subtoken_length, - reserved_tokens=reserved_tokens) - - # Being within 1% of the target size is ok. - is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size - # If min_val == max_val, we can't do any better than this. 
- if is_ok or min_val >= max_val or present_count < 2: - return subtokenizer - - if subtokenizer.vocab_size > target_size: - other_subtokenizer = bisect(present_count + 1, max_val) - else: - other_subtokenizer = bisect(min_val, present_count - 1) - - if other_subtokenizer is None: - return subtokenizer - - if (abs(other_subtokenizer.vocab_size - target_size) < - abs(subtokenizer.vocab_size - target_size)): - return other_subtokenizer - return subtokenizer - - return bisect(min_val, max_val) - - def build_from_token_counts(self, - token_counts, - min_count, - num_iterations=4, - reserved_tokens=None, - max_subtoken_length=None): - """Train a SubwordTextEncoder based on a dictionary of word counts. - - Args: - token_counts: a dictionary of Unicode strings to int. - min_count: an integer - discard subtokens with lower counts. - num_iterations: an integer. how many iterations of refinement. - reserved_tokens: List of reserved tokens. The global variable - `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this - argument is `None`, it will use `RESERVED_TOKENS`. - max_subtoken_length: Maximum length of a subtoken. If this is not set, - then the runtime and memory use of creating the vocab is quadratic in - the length of the longest token. If this is set, then it is instead - O(max_subtoken_length * length of longest token). - - Raises: - ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it - is not clear what the space is being reserved for, or when it will be - filled in. - """ - if reserved_tokens is None: - reserved_tokens = RESERVED_TOKENS - else: - # There is not complete freedom in replacing RESERVED_TOKENS. - for default, proposed in zip(RESERVED_TOKENS, reserved_tokens): - if default != proposed: - raise ValueError("RESERVED_TOKENS must be a prefix of " - "reserved_tokens.") - - # Initialize the alphabet. Note, this must include reserved tokens or it can - # result in encoding failures. - alphabet_tokens = itertools.chain( - six.iterkeys(token_counts), - [native_to_unicode(t) for t in reserved_tokens]) - - self._init_alphabet_from_tokens(alphabet_tokens) - - # Bootstrap the initial list of subtokens with the characters from the - # alphabet plus the escaping characters. - self._init_subtokens_from_list( - list(self._alphabet), reserved_tokens=reserved_tokens) - - # We build iteratively. On each iteration, we segment all the words, - # then count the resulting potential subtokens, keeping the ones - # with high enough counts for our new vocabulary. - if min_count < 1: - min_count = 1 - for i in range(num_iterations): - logging.info("Iteration %d", i) - - # Collect all substrings of the encoded token that break along current - # subtoken boundaries. 
- subtoken_counts = collections.defaultdict(int) - for token, count in six.iteritems(token_counts): - iter_start_time = time.time() - escaped_token = _escape_token(token, self._alphabet) - subtokens = self._escaped_token_to_subtoken_strings(escaped_token) - start = 0 - for subtoken in subtokens: - last_position = len(escaped_token) + 1 - if max_subtoken_length is not None: - last_position = min(last_position, start + max_subtoken_length) - - for end in range(start + 1, last_position): - new_subtoken = escaped_token[start:end] - subtoken_counts[new_subtoken] += count - start += len(subtoken) - iter_time_secs = time.time() - iter_start_time - if iter_time_secs > 0.1: - logging.info( - "Processing token [%s] took {%d} seconds, consider " - "setting Text2TextProblem.max_subtoken_length to a " - "smaller value.", token, iter_time_secs) - - # Array of sets of candidate subtoken strings, by length. - len_to_subtoken_strings = [] - for subtoken_string, count in six.iteritems(subtoken_counts): - lsub = len(subtoken_string) - if count >= min_count: - while len(len_to_subtoken_strings) <= lsub: - len_to_subtoken_strings.append(set()) - len_to_subtoken_strings[lsub].add(subtoken_string) - - # Consider the candidates longest to shortest, so that if we accept - # a longer subtoken string, we can decrement the counts of its prefixes. - new_subtoken_strings = [] - for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1): - subtoken_strings = len_to_subtoken_strings[lsub] - for subtoken_string in subtoken_strings: - count = subtoken_counts[subtoken_string] - if count >= min_count: - # Exclude alphabet tokens here, as they must be included later, - # explicitly, regardless of count. - if subtoken_string not in self._alphabet: - new_subtoken_strings.append((count, subtoken_string)) - for l in range(1, lsub): - subtoken_counts[subtoken_string[:l]] -= count - - # Include the alphabet explicitly to guarantee all strings are encodable. - new_subtoken_strings.extend( - (subtoken_counts.get(a, 0), a) for a in self._alphabet) - new_subtoken_strings.sort(reverse=True) - - # Reinitialize to the candidate vocabulary. - new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings] - if reserved_tokens: - escaped_reserved_tokens = [ - _escape_token(native_to_unicode(t), self._alphabet) - for t in reserved_tokens - ] - new_subtoken_strings = escaped_reserved_tokens + new_subtoken_strings - - self._init_subtokens_from_list(new_subtoken_strings) - logging.info("vocab_size = %d", self.vocab_size) - - @property - def all_subtoken_strings(self): - return tuple(self._all_subtoken_strings) - - def dump(self): - """Debugging dump of the current subtoken vocabulary.""" - subtoken_strings = [ - (i, s) for s, i in six.iteritems(self._subtoken_string_to_id) - ] - print(u", ".join( - u"{0} : '{1}'".format(i, s) for i, s in sorted(subtoken_strings))) - - def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None): - """Initialize token information from a list of subtoken strings. - - Args: - subtoken_strings: a list of subtokens - reserved_tokens: List of reserved tokens. We must have `reserved_tokens` - as None or the empty list, or else the global variable `RESERVED_TOKENS` - must be a prefix of `reserved_tokens`. - - Raises: - ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it - is not clear what the space is being reserved for, or when it will be - filled in. 
- """ - if reserved_tokens is None: - reserved_tokens = [] - - if reserved_tokens: - self._all_subtoken_strings = reserved_tokens + subtoken_strings - else: - self._all_subtoken_strings = subtoken_strings - - # we remember the maximum length of any subtoken to avoid having to - # check arbitrarily long strings. - self._max_subtoken_len = max([len(s) for s in subtoken_strings]) - self._subtoken_string_to_id = { - s: i + len(reserved_tokens) for i, s in enumerate(subtoken_strings) if s - } - # Initialize the cache to empty. - self._cache_size = 2**20 - self._cache = [(None, None)] * self._cache_size - - def _init_alphabet_from_tokens(self, tokens): - """Initialize alphabet from an iterable of token or subtoken strings.""" - # Include all characters from all tokens in the alphabet to guarantee that - # any token can be encoded. Additionally, include all escaping characters. - self._alphabet = {c for token in tokens for c in token} # pylint: disable=g-complex-comprehension - self._alphabet |= _ESCAPE_CHARS - - def _load_from_file_object(self, f): - """Load from a file object. - - Args: - f: File object to load vocabulary from - """ - subtoken_strings = [] - for line in f: - s = line.rstrip() - # Some vocab files wrap words in single quotes, but others don't - if ((s.startswith("'") and s.endswith("'")) or - (s.startswith("\"") and s.endswith("\""))): - s = s[1:-1] - subtoken_strings.append(native_to_unicode(s)) - self._init_subtokens_from_list(subtoken_strings) - self._init_alphabet_from_tokens(subtoken_strings) - - def _load_from_file(self, filename): - """Load from a vocab file.""" - if not tf.io.gfile.exists(filename): - raise ValueError("File %s not found" % filename) - with tf.io.gfile.GFile(filename) as f: - self._load_from_file_object(f) - - def store_to_file(self, filename, add_single_quotes=True): - with tf.io.gfile.GFile(filename, "w") as f: - for subtoken_string in self._all_subtoken_strings: - if add_single_quotes: - f.write("'" + subtoken_string + "'\n") - else: - f.write(subtoken_string + "\n") - - -class ImageEncoder: - """Encoder class for saving and loading images.""" - - def __init__(self, num_reserved_ids=0, height=None, width=None, channels=3): - assert num_reserved_ids == 0 - self._height = height - self._width = width - self._channels = channels - - @property - def num_reserved_ids(self): - return 0 - - def encode(self, s): - """Transform a string with a filename into a list of RGB integers. - - Args: - s: path to the file with an image. - - Returns: - ids: list of integers - """ - try: - import matplotlib.image as im # pylint: disable=g-import-not-at-top - except ImportError as e: - logging.warning( - "Reading an image requires matplotlib to be installed: %s", e) - raise NotImplementedError("Image reading not implemented.") - return im.imread(s) - - def decode(self, ids, strip_extraneous=False): - """Transform a sequence of int IDs into an image file. - - Args: - ids: list of integers to be converted. - strip_extraneous: unused - - Returns: - Path to the temporary file where the image was saved. - - Raises: - ValueError: if the IDs are not of the appropriate size. 
- """ - del strip_extraneous - _, tmp_file_path = tempfile.mkstemp("_decode.png") - if self._height is None or self._width is None: - size = int(math.sqrt(len(ids) / self._channels)) - length = size * size * self._channels - else: - size = None - length = self._height * self._width * self._channels - if len(ids) != length: - raise ValueError("Length of ids (%d) must be height (%d) x width (%d) x " - "channels (%d); %d != %d.\n Ids: %s" % - (len(ids), self._height, self._width, self._channels, - len(ids), length, " ".join([str(i) for i in ids]))) - with tf.Graph().as_default(): - raw = tf.constant(ids, dtype=tf.uint8) - if size is None: - img = tf.reshape(raw, [self._height, self._width, self._channels]) - else: - img = tf.reshape(raw, [size, size, self._channels]) - png = tf.image.encode_png(img) - op = tf.write_file(tmp_file_path, png) - with tf.Session() as sess: - sess.run(op) - return tmp_file_path - - def decode_list(self, ids): - """Transform a sequence of int IDs into an image file. - - Args: - ids: list of integers to be converted. - - Returns: - Singleton list: path to the temporary file where the image was saved. - """ - return [self.decode(ids)] - - @property - def vocab_size(self): - return 256 - - -class RealEncoder: - """Encoder class for saving and loading float values.""" - - def encode(self, s): - """Transform a string (space separated float values) into a float array. - - Args: - s: space separated float values. - - Returns: - Array of float values. - """ - return [float(w) for w in s.split()] - - def decode(self, ids, strip_extraneous=False): - """Transform sequence of float values into string (float values). - - Args: - ids: array of floats to be converted. - strip_extraneous: unused - - Returns: - String having space separated float values. - - Raises: - ValueError: if the IDs are not of the appropriate size. - """ - del strip_extraneous - return " ".join([str(i) for i in ids]) - - -class BertEncoder: - """Encoder Class that is compatible with models trained in original BERT library.""" - - def __init__(self, vocab_file, do_lower_case=True): - self._vocab = self.load_vocab(vocab_file) - self._inv_vocab = {v: k for k, v in self._vocab.items()} - self._basic_tokenizer = BertBasicEncoder(do_lower_case=do_lower_case) - self._wordpiece_tokenizer = BertWordpieceTokenizer(vocab=self._vocab) - - def load_vocab(self, vocab_file): - """Loads a vocabulary file into a dictionary.""" - vocab = collections.OrderedDict() - index = 0 - with tf.io.gfile.GFile(vocab_file, "r") as reader: - while True: - token = native_to_unicode(reader.readline()) - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 - return vocab - - def encode(self, text): - return self._convert_tokens_to_ids(self.tokenize(text)) - - # Note: Because encoding by BertEncoder is not unique text decoded - # from token ids is not unique. 
- def decode(self, ids): - """Returns a text that encoded would yield provided ids.""" - tokens = self._convert_ids_to_tokens(ids) - if not tokens: - return "" - retarr = [tokens[0]] - for token in tokens[1:]: - if token.startswith("##"): - retarr.append(token.lstrip("#")) - else: - retarr.append(" ") - retarr.append(token) - return "".join(retarr) - - @property - def vocab_size(self): - return len(self._vocab) - - def tokenize(self, text): - split_tokens = [] - for token in self._basic_tokenizer.tokenize(text): - for sub_token in self._wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) - - return split_tokens - - def _convert_tokens_to_ids(self, tokens): - return [self._vocab[token] for token in tokens] - - def _convert_ids_to_tokens(self, ids): - return [self._inv_vocab[token_id] for token_id in ids] - - -class BertBasicEncoder: - """Part of BertEncoder; tokenization (punctuation splitting, lower casing).""" - - def __init__(self, do_lower_case=True): - """Constructs a BasicTokenizer. - - Args: - do_lower_case: Whether to lower case the input. - """ - self.do_lower_case = do_lower_case - - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = native_to_unicode(text) - text = self._clean_text(text) - - text = self._tokenize_chinese_chars(text) - - orig_tokens = whitespace_tokenize(text) - split_tokens = [] - for token in orig_tokens: - if self.do_lower_case: - token = token.lower() - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _bert_is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. 
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xfffd or _bert_is_control(char): - continue - if _bert_is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - -class BertWordpieceTokenizer: - """Runs WordPiece tokenziation.""" - - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): - self.vocab = vocab - self.unk_token = unk_token - self.max_input_chars_per_word = max_input_chars_per_word - - def tokenize(self, text): - """Tokenizes a piece of text into its word pieces. - - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. - For example: - input = "unaffable" - output = ["un", "##aff", "##able"] - Args: - text: A single token or whitespace separated tokens. This should have - already been passed through `BasicTokenizer. - - Returns: - A list of wordpiece tokens. - """ - - text = native_to_unicode(text) - - output_tokens = [] - for token in whitespace_tokenize(text): - chars = list(token) - if len(chars) > self.max_input_chars_per_word: - output_tokens.append(self.unk_token) - continue - - is_bad = False - start = 0 - sub_tokens = [] - while start < len(chars): - end = len(chars) - cur_substr = None - while start < end: - substr = "".join(chars[start:end]) - if start > 0: - substr = "##" + substr - if substr in self.vocab: - cur_substr = substr - break - end -= 1 - if cur_substr is None: - is_bad = True - break - sub_tokens.append(cur_substr) - start = end - - if is_bad: - output_tokens.append(self.unk_token) - else: - output_tokens.extend(sub_tokens) - return output_tokens - - -def _bert_is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _bert_is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False - - -def _bert_is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. 
- if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens diff --git a/trax/data/text_encoder_build_subword.py b/trax/data/text_encoder_build_subword.py deleted file mode 100644 index 1df9d85cc..000000000 --- a/trax/data/text_encoder_build_subword.py +++ /dev/null @@ -1,80 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Program to build a SubwordTextEncoder. - -The flags --min_count and --corpus_max_lines will affect the size of the -vocabulary. Try changing these flags until you get a vocabulary -of the size you want. - -Example usage: - -python trax/data/text_encoder_build_subword.py \ - --corpus_filepattern=$DATA_DIR/my_problem-train-* \ - --corpus_max_lines=12345 \ - --output_filename=$DATA_DIR/my_problem.subword_text_encoder \ - --logtostderr - -""" - -from absl import app -from absl import flags - -from trax.data import text_encoder -from trax.data import tokenizer - -flags.DEFINE_string('output_filename', '/tmp/my.subword_text_encoder', - 'where to store the SubwordTextEncoder') -flags.DEFINE_string('corpus_filepattern', '', - 'Corpus of one or more text files') -flags.DEFINE_string( - 'vocab_filepattern', '', 'One or more vocabulary files ' - '(one word per line as "word,count")') -flags.DEFINE_integer('min_count', 5, 'Minimum subtoken count in corpus') -flags.DEFINE_integer('corpus_max_lines', 10000, - 'How many lines of corpus to read') -flags.DEFINE_integer('num_iterations', 4, 'Number of iterations') -flags.DEFINE_bool('split_on_newlines', True, 'Break corpus into lines.') - -FLAGS = flags.FLAGS - - -def main(unused_argv): - if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern: - raise ValueError( - 'Must only provide one of --corpus_filepattern or --vocab_filepattern') - - elif FLAGS.corpus_filepattern: - token_counts = tokenizer.corpus_token_counts( - FLAGS.corpus_filepattern, - FLAGS.corpus_max_lines, - split_on_newlines=FLAGS.split_on_newlines) - - elif FLAGS.vocab_filepattern: - token_counts = tokenizer.vocab_token_counts(FLAGS.vocab_filepattern, - FLAGS.corpus_max_lines) - - else: - raise ValueError( - 'Must provide one of --corpus_filepattern or --vocab_filepattern') - - encoder = text_encoder.SubwordTextEncoder() - encoder.build_from_token_counts(token_counts, FLAGS.min_count, - FLAGS.num_iterations) - encoder.store_to_file(FLAGS.output_filename) - - -if __name__ == '__main__': - app.run(main) diff --git a/trax/data/text_encoder_test.py b/trax/data/text_encoder_test.py deleted file mode 100644 index 791f13e9b..000000000 --- a/trax/data/text_encoder_test.py +++ /dev/null @@ -1,376 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data.text_encoder.""" - -import collections -import io -import os -import random -import shutil -import string - -import mock -from six.moves import range # pylint: disable=redefined-builtin -import tensorflow.compat.v1 as tf -from trax.data import text_encoder - - -class NativeToUnicodeTest(tf.test.TestCase): - - def test_native_to_unicode(self): - s = r"foo bar" - s_unicode = text_encoder.native_to_unicode(s) - self.assertEqual(s_unicode, u"foo bar") - - -class EscapeUnescapeTokenTest(tf.test.TestCase): - - def test_escape_token(self): - escaped = text_encoder._escape_token( - "Foo! Bar.\nunder_score back\\slash", - set("abcdefghijklmnopqrstuvwxyz .\n") | text_encoder._ESCAPE_CHARS) - - self.assertEqual( - "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_", escaped) - - def test_unescape_token(self): - unescaped = text_encoder._unescape_token( - "\\70;oo\\33; \\66;ar.\\10;under\\uscore back\\\\slash_") - - self.assertEqual( - "Foo! Bar.\nunder_score back\\slash", unescaped) - - -class TokenTextEncoderTest(tf.test.TestCase): - - @classmethod - def setUpClass(cls): - """Make sure the test dir exists and is empty.""" - cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") - shutil.rmtree(cls.test_temp_dir, ignore_errors=True) - tf.gfile.MakeDirs(cls.test_temp_dir) - - def test_save_and_reload(self): - """Test that saving and reloading doesn't change the vocab. - - Note that this test reads and writes to the filesystem, which necessitates - that this test size be "large". - """ - - corpus = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z" - vocab_filename = os.path.join(self.test_temp_dir, "abc.vocab") - - # Make text encoder from a list and store vocab to fake filesystem. - encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) - encoder.store_to_file(vocab_filename) - - # Load back the saved vocab file from the fake_filesystem. - new_encoder = text_encoder.TokenTextEncoder(vocab_filename) - - self.assertEqual(encoder._id_to_token, new_encoder._id_to_token) - self.assertEqual(encoder._token_to_id, new_encoder._token_to_id) - - def test_reserved_tokens_in_corpus(self): - """Test that we handle reserved tokens appearing in the corpus.""" - corpus = "A B {} D E F {} G {}".format(text_encoder.EOS, - text_encoder.EOS, - text_encoder.PAD) - - encoder = text_encoder.TokenTextEncoder(None, vocab_list=corpus.split()) - - all_tokens = encoder._id_to_token.values() - - # If reserved tokens are removed correctly, then the set of tokens will - # be unique. 
- self.assertEqual(len(all_tokens), len(set(all_tokens))) - - -class SubwordTextEncoderTest(tf.test.TestCase): - - @classmethod - def setUpClass(cls): - """Make sure the test dir exists and is empty.""" - cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") - shutil.rmtree(cls.test_temp_dir, ignore_errors=True) - tf.gfile.MakeDirs(cls.test_temp_dir) - - def test_encode_decode(self): - corpus = ( - "This is a corpus of text that provides a bunch of tokens from which " - "to build a vocabulary. It will be used when strings are encoded " - "with a TextEncoder subclass. The encoder was coded by a coder.") - token_counts = collections.Counter(corpus.split(" ")) - alphabet = set(corpus) - {" "} - - original = "This is a coded sentence encoded by the SubwordTextEncoder." - token_counts.update(original.split(" ")) - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - - # Encoding should be reversible. - encoded = encoder.encode(original) - decoded = encoder.decode(encoded) - self.assertEqual(original, decoded) - - # The substrings coded and coder are frequent enough in the corpus that - # they should appear in the vocabulary even though they are substrings - # of other included strings. - subtoken_strings = {encoder.all_subtoken_strings[i] for i in encoded} - self.assertIn("encoded_", subtoken_strings) - self.assertIn("coded_", subtoken_strings) - self.assertIn("TextEncoder", encoder.all_subtoken_strings) - self.assertIn("coder", encoder.all_subtoken_strings) - - # Every character in the corpus should be in the encoders alphabet and - # its subtoken vocabulary. - self.assertTrue(alphabet.issubset(encoder._alphabet)) - for a in alphabet: - self.assertIn(a, encoder.all_subtoken_strings) - - def test_unicode(self): - corpus = "Cat emoticons. \U0001F638 \U0001F639 \U0001F63A \U0001F63B" - token_counts = collections.Counter(corpus.split(" ")) - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - - self.assertIn("\U0001F638", encoder._alphabet) - self.assertIn("\U0001F63B", encoder.all_subtoken_strings) - - def test_small_vocab(self): - corpus = "The quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - alphabet = set(corpus) - {" "} - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 10, token_counts, 2, 10) - - # All vocabulary elements are in the alphabet and subtoken strings even - # if we requested a smaller vocabulary to assure all expected strings - # are encodable. - self.assertTrue(alphabet.issubset(encoder._alphabet)) - for a in alphabet: - self.assertIn(a, encoder.all_subtoken_strings) - - def test_long_tokens(self): - """Subword tokenization should still run efficiently with long tokens. - - To make it run efficiently, we need to use the `max_subtoken_length` - argument when calling SubwordTextEncoder.build_to_target_size. - """ - token_length = 4000 - num_tokens = 50 - target_vocab_size = 600 - max_subtoken_length = 10 # Set this to `None` to get problems. - max_count = 500 - - # Generate some long random strings. 
- random.seed(0) - long_tokens = [] - for _ in range(num_tokens): - long_token = "".join([random.choice(string.ascii_uppercase) - for _ in range(token_length)]) - long_tokens.append(long_token) - - corpus = " ".join(long_tokens) - token_counts = collections.Counter(corpus.split(" ")) - alphabet = set(corpus) - {" "} - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - target_vocab_size, token_counts, 1, max_count, num_iterations=1, - max_subtoken_length=max_subtoken_length) - - # All vocabulary elements are in the alphabet and subtoken strings even - # if we requested a smaller vocabulary to assure all expected strings - # are encodable. - self.assertTrue(alphabet.issubset(encoder._alphabet)) - for a in alphabet: - self.assertIn(a, encoder.all_subtoken_strings) - - def test_custom_reserved_tokens(self): - """Test that we can pass custom reserved tokens to SubwordTextEncoder.""" - corpus = "The quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - start_symbol = "" - end_symbol = "" - reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol, - end_symbol] - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 10, token_counts, 2, 10, reserved_tokens=reserved_tokens) - - # Make sure that reserved tokens appear in the right places. - self.assertEqual(encoder.decode([2]), start_symbol) - self.assertEqual(encoder.decode([3]), end_symbol) - - # Make sure that we haven't messed up the ability to reconstruct. - reconstructed_corpus = encoder.decode(encoder.encode(corpus)) - self.assertEqual(corpus, reconstructed_corpus) - - def test_encodable_when_not_in_alphabet(self): - corpus = "the quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - original = "This has UPPER CASE letters that are out of alphabet" - - # Early versions could have an infinite loop when breaking into subtokens - # if there was any out-of-alphabet characters in the encoded string. - encoded = encoder.encode(original) - decoded = encoder.decode(encoded) - - self.assertEqual(original, decoded) - encoded_str = "".join(encoder.all_subtoken_strings[i] for i in encoded) - self.assertIn("\\84;", encoded_str) - - @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579")) - def test_raises_exception_when_not_encodable(self): - corpus = "the quick brown fox jumps over the lazy dog" - token_counts = collections.Counter(corpus.split(" ")) - - # Deliberately exclude some required encoding chars from the alphabet - # and token list, making some strings unencodable. - encoder = text_encoder.SubwordTextEncoder.build_to_target_size( - 100, token_counts, 2, 10) - original = "This has UPPER CASE letters that are out of alphabet" - - # Previously there was a bug which produced an infinite loop in this case. 
- @mock.patch.object(text_encoder, "_ESCAPE_CHARS", new=set("\\_;13579"))
- def test_raises_exception_when_not_encodable(self):
- corpus = "the quick brown fox jumps over the lazy dog"
- token_counts = collections.Counter(corpus.split(" "))
-
- # Deliberately exclude some required encoding chars from the alphabet
- # and token list, making some strings unencodable.
- encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
- 100, token_counts, 2, 10)
- original = "This has UPPER CASE letters that are out of alphabet"
-
- # Previously there was a bug which produced an infinite loop in this case.
- with self.assertRaises(AssertionError):
- encoder.encode(original)
-
- def test_load_from_file(self):
- # Test a vocab file with words not wrapped with single quotes
- encoder = text_encoder.SubwordTextEncoder()
- correct_vocab = ["the", "and", "of"]
- vocab = io.StringIO("the\n"
- "and\n"
- "of\n")
- encoder._load_from_file_object(vocab)
- self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)
-
- # Test a vocab file with words wrapped in single quotes
- encoder = text_encoder.SubwordTextEncoder()
- vocab = io.StringIO("'the'\n"
- "'and'\n"
- "'of'\n")
- encoder._load_from_file_object(vocab)
- self.assertAllEqual(encoder.all_subtoken_strings, correct_vocab)
-
- def test_reserved_token_chars_not_in_alphabet(self):
- corpus = "dog"
- token_counts = collections.Counter(corpus.split(" "))
- encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size(
- 100, token_counts, 2, 100)
- filename = os.path.join(self.test_temp_dir, "out.voc")
- encoder1.store_to_file(filename)
- encoder2 = text_encoder.SubwordTextEncoder(filename=filename)
-
- self.assertEqual(encoder1._alphabet, encoder2._alphabet)
-
- for t in text_encoder.RESERVED_TOKENS:
- for c in t:
- # Verify that encoders can encode all reserved token chars.
- encoder1.encode(c)
- encoder2.encode(c)
-
- def test_save_and_reload(self):
- corpus = "the quick brown fox jumps over the lazy dog"
- token_counts = collections.Counter(corpus.split(" "))
-
- # Deliberately exclude some required encoding chars from the alphabet
- # and token list, making some strings unencodable.
- encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
- 100, token_counts, 2, 10)
-
- filename = os.path.join(self.test_temp_dir, "out.voc")
- encoder.store_to_file(filename)
- new_encoder = text_encoder.SubwordTextEncoder(filename)
-
- self.assertEqual(encoder._alphabet, new_encoder._alphabet)
- self.assertEqual(encoder.all_subtoken_strings,
- new_encoder.all_subtoken_strings)
- self.assertEqual(encoder._subtoken_string_to_id,
- new_encoder._subtoken_string_to_id)
- self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len)
-
- def test_save_and_reload_no_single_quotes(self):
- corpus = "the quick brown fox jumps over the lazy dog"
- token_counts = collections.Counter(corpus.split(" "))
-
- # Deliberately exclude some required encoding chars from the alphabet
- # and token list, making some strings unencodable.
- encoder = text_encoder.SubwordTextEncoder.build_to_target_size(
- 100, token_counts, 2, 10)
-
- filename = os.path.join(self.test_temp_dir, "out.voc")
- encoder.store_to_file(filename, add_single_quotes=False)
- new_encoder = text_encoder.SubwordTextEncoder(filename)
-
- self.assertEqual(encoder._alphabet, new_encoder._alphabet)
- self.assertEqual(encoder.all_subtoken_strings,
- new_encoder.all_subtoken_strings)
- self.assertEqual(encoder._subtoken_string_to_id,
- new_encoder._subtoken_string_to_id)
- self.assertEqual(encoder._max_subtoken_len, new_encoder._max_subtoken_len)
-
- def test_build_from_generator(self):
-
- corpus = "The quick brown fox jumps over the lazy dog"
-
- def gen():
- for _ in range(3):
- yield corpus
-
- start_symbol = "<S>"
- end_symbol = "</S>"
- reserved_tokens = text_encoder.RESERVED_TOKENS + [start_symbol,
- end_symbol]
- encoder = text_encoder.SubwordTextEncoder.build_from_generator(
- gen(), 10, reserved_tokens=reserved_tokens)
-
- # Make sure that reserved tokens appear in the right places.
- self.assertEqual(encoder.decode([2]), start_symbol)
- self.assertEqual(encoder.decode([3]), end_symbol)
-
- self.assertEqual("hi%s" % start_symbol,
- encoder.decode(encoder.encode("hi") + [2]))
-
- # Make sure that we haven't messed up the ability to reconstruct.
- reconstructed_corpus = encoder.decode(encoder.encode(corpus))
- self.assertEqual(corpus, reconstructed_corpus)
-
-
-class OneHotClassLabelEncoderTest(tf.test.TestCase):
-
- def test_one_hot_encode(self):
- encoder = text_encoder.OneHotClassLabelEncoder(
- class_labels=["zero", "one", "two"])
- self.assertEqual(encoder.encode("zero"), [1, 0, 0])
- self.assertEqual(encoder.encode("one"), [0, 1, 0])
- self.assertEqual(encoder.encode("two"), [0, 0, 1])
-
- def test_one_hot_decode(self):
- encoder = text_encoder.OneHotClassLabelEncoder(
- class_labels=["zero", "one", "two"])
- self.assertEqual(encoder.decode([1, 0, 0]), "zero")
- self.assertEqual(encoder.decode([0, 1, 0]), "one")
- self.assertEqual(encoder.decode([0, 0, 1]), "two")
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/trax/data/tf_inputs.py b/trax/data/tf_inputs.py
deleted file mode 100644
index 239b4d3a1..000000000
--- a/trax/data/tf_inputs.py
+++ /dev/null
@@ -1,2755 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""TensorFlow data sources and associated preprocessing functions."""
-
-import functools
-import itertools
-import json
-import math
-import os
-import random
-import re
-
-from absl import logging
-import gin
-import jax
-import numpy as np
-import scipy
-import scipy.special
-import tensorflow as tf
-from tensorflow import estimator as tf_estimator
-import tensorflow_datasets as tfds
-import tensorflow_text as tf_text
-from trax import data
-from trax import fastmath
-from trax import layers as tl
-from trax import supervised
-from trax.data import debug_data_pipeline
-from trax.data import text_encoder
-from trax.fastmath import numpy as jnp
-
-# How many examples from the stream to skip at random during training.
-# For now, we skip at most 100K examples for efficiency.
-# TODO(lukaszkaiser): can we improve efficiency, should that be changed?
-_MAX_SKIP_EXAMPLES = 1e5
-
-
-def t5_data():
- """Get the T5 data module if available."""
- module = None
- try:
- import t5.data # pylint: disable=g-import-not-at-top
- module = t5.data
- except (ImportError, AttributeError) as e: # t5 missing or incompatible
- logging.error('pip install t5')
- raise e
- return module
-
-
-def no_preprocess(dataset, training):
- del training
- return dataset
-
-
-def t2t_problems():
- # Load t2t problems on request only, this should save some import time.
- from tensor2tensor import problems_colab as t2tp # pylint: disable=g-import-not-at-top
- return t2tp
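For orientation before the stream factories below: `data_streams` (defined next) is this module's main entry point. A hypothetical usage sketch, assuming the pre-removal `trax.data.tf_inputs` import path and an illustrative TFDS dataset name:

```python
# Hypothetical usage of data_streams (defined below); 'mnist' is illustrative.
from trax.data import tf_inputs  # pre-removal import path

train_stream_fn, eval_stream_fn = tf_inputs.data_streams('mnist')
inp, out = next(train_stream_fn())  # one (input, target) numpy example
```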
-
-
-# TODO(jonni): Rename function to better match its return values.
-@gin.configurable(module='trax.data')
-def data_streams(dataset_name,
- data_dir=None,
- preprocess_fn=no_preprocess,
- bare_preprocess_fn=None,
- shuffle_buffer_size=1024,
- eval_holdout_size=0,
- input_name=None,
- target_name=None):
- """Creates `(train, eval)` data sources from ``dataset_name``.
-
- Args:
- dataset_name: Name of dataset belonging to TFDS or T2T. T2T dataset names
- must start with ``'t2t_'``.
- data_dir: Directory where the data is located.
- preprocess_fn: Function to use for pre-processing after appending targets to
- inputs.
- bare_preprocess_fn: Function to use for pre-processing before appending
- targets to inputs.
- shuffle_buffer_size: Size of the shuffle buffer.
- eval_holdout_size: If greater than 0, specifies a fraction of training data
- to siphon off and use as eval data, in place of a separate eval split.
- input_name: Name of the inputs from the dictionary.
- target_name: Name of the outputs either from the dictionary or as a result
- of post-processing.
-
- Returns:
- A pair of functions, `(f, g)` for use as data sources; call `f()` to get an
- iterator of training data samples, and call `g()` to get an iterator of eval
- data samples.
- """
- data_dir = download_and_prepare(dataset_name, data_dir)
-
- cache = []
-
- def stream(which):
- """Create the stream, cache TF streams if needed."""
- if not cache:
- cache.append(
- _train_and_eval_streams(dataset_name, data_dir, preprocess_fn,
- bare_preprocess_fn, shuffle_buffer_size,
- eval_holdout_size, input_name, target_name))
-
- (train_ds, eval_ds, input_name_c) = cache[0]
- dataset = eval_ds if which == 'eval' else train_ds
- return dataset_to_stream(dataset, input_name_c)
-
- train_stream = lambda: stream('train')
- eval_stream = lambda: stream('eval')
- return train_stream, eval_stream
-
-
-def dataset_to_stream(dataset, input_name):
- """Takes a tf.Dataset and creates a numpy stream of ready batches."""
- # All input-pipeline processing should be on CPU.
- for example in fastmath.dataset_as_numpy(dataset):
- features = example[0]
- inp, out = features[input_name], example[1]
- mask = features['mask'] if 'mask' in features else None
- # Some accelerators don't handle uint8 well, cast to int.
- # (Check the array dtype: `isinstance` against np.uint8 misses ndarrays.)
- if np.asarray(inp).dtype == np.uint8:
- inp = inp.astype(np.int32)
- if np.asarray(out).dtype == np.uint8:
- out = out.astype(np.int32)
- yield (inp, out) if mask is None else (inp, out, mask)
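`dataset_to_stream` above is easy to exercise on a toy in-memory dataset; a minimal sketch (the feature name `'image'` and the pre-removal import path are illustrative):

```python
import numpy as np
import tensorflow as tf
from trax.data import tf_inputs  # pre-removal import path

# Two toy (features, target) examples; uint8 values get cast to int32.
ds = tf.data.Dataset.from_tensor_slices(
    ({'image': np.zeros((2, 4), dtype=np.uint8)},
     np.ones((2,), dtype=np.uint8)))
inp, out = next(tf_inputs.dataset_to_stream(ds, 'image'))
assert inp.dtype == np.int32 and out.dtype == np.int32
```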
-
-
-def _train_and_eval_streams(dataset, data_dir, preprocess_fn,
- bare_preprocess_fn, shuffle_buffer_size,
- eval_holdout_size, input_name, target_name):
- """Return train and eval batches with input name and shape."""
- (train_data, eval_data,
- keys) = _train_and_eval_dataset(dataset, data_dir, eval_holdout_size)
- # If provided, select input_name/target_name; else fall back to keys if
- # available, else [None].
- input_names = ([input_name] if input_name is not None else
- keys[0] if keys is not None else [None])
- target_names = ([target_name] if target_name is not None else
- keys[1] if keys is not None else [None])
-
- train_batches = _shuffle_data(train_data, target_names, True,
- shuffle_buffer_size, preprocess_fn,
- bare_preprocess_fn)
- eval_batches = _shuffle_data(eval_data, target_names, False,
- shuffle_buffer_size, preprocess_fn,
- bare_preprocess_fn)
- return (train_batches, eval_batches, input_names[0])
-
-
-def _shuffle_data(dataset, target_names, training, shuffle_buffer_size,
- preprocess_fn, bare_preprocess_fn):
- """Shuffle the given dataset and run pre-processing."""
-
- def append_targets(example):
- """Append targets to the example dictionary. Needed for Keras."""
- if len(target_names) == 1:
- return (example, example[target_names[0]])
- targets = {}
- for name in target_names:
- targets[name] = example[name]
- return (example, targets)
-
- # `bare_preprocess_fn` is called before appending targets etc.
- if bare_preprocess_fn is not None:
- dataset = bare_preprocess_fn(dataset, training)
- dataset = dataset.map(append_targets)
- # TODO(pkozakowski): Repeat both the training and evaluation set, so we don't
- # have incomplete batches during evaluation. This will be a problem when we
- # add an option to evaluate on the whole dataset, then we'll need to think of
- # a different solution.
- dataset = dataset.repeat()
- if training:
- # Skip a random fraction at the beginning of the stream. The skip is
- # essential for synchronous highly-parallel training to avoid multiple
- # replicas reading the same data in lock-step.
- dataset = dataset.skip(random.randint(0, _MAX_SKIP_EXAMPLES))
- dataset = preprocess_fn(dataset, training)
- dataset = dataset.shuffle(shuffle_buffer_size)
- return dataset.prefetch(8)
-
-
-def _train_and_eval_dataset(dataset_name,
- data_dir,
- eval_holdout_size,
- train_shuffle_files=True,
- eval_shuffle_files=False,
- use_alt_eval=False,
- subsplit=None):
- """Return train and evaluation datasets and the supervised keys.
-
- Args:
- dataset_name: a string, the name of the dataset; if it starts with 't2t_'
- then we'll search T2T Problem registry for it, otherwise we assume it is a
- dataset from TFDS and load it from there.
- data_dir: directory where the data is located.
- eval_holdout_size: float from 0 to <1; if >0 use this much of training data
- for evaluation (instead of looking for a pre-specified VALIDATION split).
- train_shuffle_files: Boolean determining whether or not to shuffle the train
- files at startup. Set to False if you want data determinism.
- eval_shuffle_files: Boolean determining whether or not to shuffle the test
- files at startup. Set to False if you want data determinism.
- use_alt_eval: If True, use the dataset's alternate/secondary eval split;
- else use the dataset's default/only eval split. Currently, only the
- `glue/mnli` dataset provides an alternate eval split, and this arg is
- ignored for other datasets.
- subsplit: a pair of floats (x, y), both in [0, 1], saying which part of the
- full training dataset we should return (default: all of it, [0, 1]).
-
- Returns:
- a 3-tuple consisting of:
- * the train tf.Dataset
- * the eval tf.Dataset
- * supervised_keys: information on what the input is and what the target
- is, i.e., a pair of lists with input and target feature names.
- """ - logging.info('Building TF data pipeline for %s', dataset_name) - if dataset_name.startswith('t2t_'): - return _train_and_eval_dataset_v1(dataset_name[4:], data_dir, - train_shuffle_files, eval_shuffle_files) - dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) - info = dataset_builder.info - splits = dataset_builder.info.splits - if dataset_name != 'c4/multilingual' and tfds.Split.TRAIN not in splits: - raise ValueError('To train we require a train split in the dataset.') - train_split = tfds.Split.TRAIN if dataset_name != 'c4/multilingual' else 'en' - eval_split = None - train_examples = info.splits[train_split].num_examples - eval_holdout_examples = int(train_examples * eval_holdout_size) - if eval_holdout_examples > 0 or subsplit is not None: - if subsplit is None: - subsplit = (0, 1) - n_train = train_examples - eval_holdout_examples - train_start = int(n_train * subsplit[0]) - train_end = int(n_train * subsplit[1]) - if train_end - train_start < 1: - raise ValueError('Requested train subsplit has no examples: ' - 'n_train %d subsplit %s' % (n_train, subsplit)) - # Eval holdout examples from the end of the training set. - if eval_holdout_examples > 0: - eval_split = f'{train_split}[-{eval_holdout_examples}:]' - # Shard the training set for this host. - train_split = f'{train_split}[{train_start}:{train_end}]' - - if dataset_name == 'glue/mnli': - eval_split = ( - 'validation_mismatched' if use_alt_eval else 'validation_matched') - elif dataset_name == 'c4/multilingual': - eval_split = 'en-validation' - elif eval_split is None: - if tfds.Split.VALIDATION not in splits and 'test' not in splits: - raise ValueError('We require a validation or test split in the dataset.') - eval_split = tfds.Split.VALIDATION - if tfds.Split.VALIDATION not in splits: - eval_split = tfds.Split.TEST - - train = tfds.load( - name=dataset_name, - split=train_split, - data_dir=data_dir, - shuffle_files=train_shuffle_files) - valid = tfds.load( - name=dataset_name, - split=eval_split, - data_dir=data_dir, - shuffle_files=eval_shuffle_files) - keys = None - if info.supervised_keys: - keys = ([info.supervised_keys[0]], [info.supervised_keys[1]]) - return train, valid, keys - - -# TODO(jonni): Consider renaming this function. -@gin.configurable(module='trax.data') -def TFDS( # pylint: disable=invalid-name - dataset_name, - data_dir=None, - tfds_preprocess_fn=None, - keys=None, - train=True, - use_alt_eval=False, - shuffle_train=True, - host_id=None, - n_hosts=None, - eval_holdout_size=0): - """Creates a data source from TensorFlow dataset ``dataset_name``. - - Args: - dataset_name: Name of the dataset, as registered in TensorFlow datasets - (e.g., ``'glue/mnli'``). - data_dir: Directory where the data is located. - tfds_preprocess_fn: If specified, function that applies to items in raw - dataset (before selecting specific features). - keys: Tuple of dataset-specific strings that select features from the - dataset. - train: If True, select the training split from the dataset; else select an - eval split. - use_alt_eval: If True, and if ``train`` is False, select the dataset's - alternate eval split if it has one (or fall back to the dataset's only - eval split). This currently affects only the `glue/mnli` dataset. - shuffle_train: If True, have TensorFlow pre-shuffle the training data; else - receive training data in deterministic sequence. - host_id: Integer id used for tracking data subsplits, in cases where - ``n_hosts`` > 1. 
- n_hosts: If greater than 1, prepare data subsplits for the given number of
- hosts.
- eval_holdout_size: If greater than 0, specifies a fraction of training data
- to siphon off and use as eval data, in place of a separate eval split.
-
- Returns:
- A function `f` for use as a training or eval data source; call `f()` to get
- an iterator of data samples.
- """
- data_dir = download_and_prepare(dataset_name, data_dir)
-
- host_id = jax.process_index() if host_id is None else host_id
- n_hosts = n_hosts or jax.host_count()
- if n_hosts > 1:
- subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts)
- else:
- subsplit = None
- train_data, eval_data, _ = (
- _train_and_eval_dataset(dataset_name,
- data_dir,
- eval_holdout_size,
- train_shuffle_files=shuffle_train,
- use_alt_eval=use_alt_eval,
- subsplit=subsplit))
- dataset = train_data if train else eval_data
- dataset = dataset if tfds_preprocess_fn is None else tfds_preprocess_fn(
- dataset)
-
- def select_from(example):
- return tuple(example[k] for k in keys)
-
- dataset = dataset.map(select_from)
- dataset = dataset.repeat()
-
- def gen(generator=None):
- del generator
- for example in fastmath.dataset_as_numpy(dataset):
- yield example
-
- return gen
-
-
-def _select_features(example, feature_list=None):
- """Select a subset of features from the example dict."""
- feature_list = feature_list or ['inputs', 'targets']
- return {f: example[f] for f in feature_list if f in example}
-
-
-def _eager_dataset_iterator(dataset):
- for item in dataset:
- flat = tf.nest.flatten(item)
- flat = [el.numpy() for el in flat]
- yield tf.nest.pack_sequence_as(item, flat)
-
-
-def _train_and_eval_dataset_v1(problem_name, data_dir, train_shuffle_files,
- eval_shuffle_files):
- """Return train and evaluation datasets and the supervised keys."""
- with tf.device('cpu:0'):
- problem = t2t_problems().problem(problem_name)
- hparams = None
- if problem_name == 'video_bair_robot_pushing':
- hparams = problem.get_hparams()
- bair_robot_pushing_hparams(hparams)
- train_dataset = problem.dataset(
- tf_estimator.ModeKeys.TRAIN,
- data_dir,
- shuffle_files=train_shuffle_files,
- hparams=hparams)
- train_dataset = train_dataset.map(_select_features)
- eval_dataset = problem.dataset(
- tf_estimator.ModeKeys.EVAL,
- data_dir,
- shuffle_files=eval_shuffle_files,
- hparams=hparams)
- eval_dataset = eval_dataset.map(_select_features)
- # TODO(lukaszkaiser): remove this need for one example, just input_key.
- examples = list(tfds.as_numpy(train_dataset.take(1)))
- # We use 'inputs' as input except for purely auto-regressive tasks like
- # language models where 'targets' are used as input_key.
- input_key = 'inputs' if 'inputs' in examples[0] else 'targets'
- supervised_keys = ([input_key], ['targets'])
- return train_dataset, eval_dataset, supervised_keys
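The module's tokenization helpers come next; a small round-trip sketch using the `'char'` vocabulary (which needs no vocab file), again assuming the pre-removal import path:

```python
from trax.data import tf_inputs  # pre-removal import path

# Char-level round trip: tokenize yields numpy int arrays, detokenize inverts.
ids = next(tf_inputs.tokenize(iter(['Hello world']), vocab_type='char'))
assert tf_inputs.detokenize(ids, vocab_type='char') == 'Hello world'
```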
-
-
-# Tokenization.
-@debug_data_pipeline.debug_pipeline
-def tokenize(stream,
- keys=None,
- vocab_type='subword',
- vocab_file=None,
- vocab_dir=None,
- n_reserved_ids=0):
- """Tokenize examples from the stream.
-
- This function assumes that `stream` generates either strings or tuples/dicts
- containing strings at some `keys`. This function maps these strings to
- numpy arrays of integers -- the tokenized version of each string.
-
- Args:
- stream: A python generator yielding strings, tuples or dicts.
- keys: which keys of the tuple/dict to tokenize (by default: all)
- vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'.
- vocab_file: Name of the vocabulary file.
- vocab_dir: Directory which contains the vocabulary file.
- n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused;
- this is common for example when reserving 0 for padding and 1 for EOS,
- but it's only needed if these symbols are not already included (and thus
- reserved) in the vocab_file.
-
- Yields:
- Examples from stream with strings at `keys` replaced by np.arrays of
- integers -- the tokenized version of these strings.
- """
- vocab = _get_vocab(vocab_type, vocab_file, vocab_dir)
- for example in stream:
- if isinstance(example, (list, tuple)):
- new_example = []
- for i, x in enumerate(example):
- if keys is None or i in keys:
- new_example.append(np.array(vocab.encode(x)) + n_reserved_ids)
- else:
- new_example.append(x)
- output = tuple(new_example)
- yield output
- elif isinstance(example, dict):
- new_example = {}
- for k in example:
- if keys is None or k in keys:
- new_example[k] = np.array(vocab.encode(example[k])) + n_reserved_ids
- else:
- new_example[k] = example[k]
- yield new_example
- else:
- output = np.array(vocab.encode(example)) + n_reserved_ids
- yield output
-
-
-@gin.configurable(module='trax.data')
-def Tokenize( # pylint: disable=invalid-name
- keys=None,
- vocab_type='subword', # pylint: disable=invalid-name
- vocab_file=None,
- vocab_dir=None,
- n_reserved_ids=0):
- """Returns a function that maps text to integer arrays; see `tokenize`."""
- return lambda g: tokenize( # pylint: disable=g-long-lambda
- g,
- keys=keys,
- vocab_type=vocab_type,
- vocab_file=vocab_file,
- vocab_dir=vocab_dir,
- n_reserved_ids=n_reserved_ids)
-
-
-def detokenize(x,
- vocab_type='subword',
- vocab_file=None,
- vocab_dir=None,
- n_reserved_ids=0):
- """Maps integer arrays to text; the opposite of `tokenize`.
-
- In many cases (all char- and subword-type vocabularies and most sentencepiece
- ones) the tokenization is invertible, so detokenize(tokenize(x)) = x. In some
- rarer cases this can remove some spacing, but it is still often useful
- to run detokenize to get a readable version of a tokenized string.
-
- Args:
- x: a list or numpy array of integers.
- vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'.
- vocab_file: Name of the vocabulary file.
- vocab_dir: Directory which contains the vocabulary file.
- n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused;
- this is common for example when reserving 0 for padding and 1 for EOS,
- but it's only needed if these symbols are not already included (and thus
- reserved) in the vocab_file.
-
- Returns:
- A string corresponding to the de-tokenized version of x.
- """
- vocab = _get_vocab(vocab_type, vocab_file, vocab_dir)
- x_unreserved = np.array(x) - n_reserved_ids
- return str(vocab.decode(x_unreserved.tolist()))
-
-
-def _to_unicode(s):
- # Decoding errors are ignored (e.g. sequences not allowed by UTF-8),
- # so that we are not left with incomplete examples (with empty values).
- return str(s, encoding='utf-8', errors='ignore')
-
-
-@gin.configurable(module='trax.data')
-def ConvertToUnicode(keys=None): # pylint: disable=invalid-name
- """Converts elements of an example to Unicode (UTF-8).
-
- Useful when TFDS outputs byte arrays. All conversion errors are ignored.
-
- Args:
- keys: tuple/list of example dimensions to convert.
-
- Returns:
- Function converting chosen elements of an example to UTF-8.
- """ - - @debug_data_pipeline.debug_pipeline - def _convert_to_unicode_str(stream): - for example in stream: - if isinstance(example, (list, tuple)): - new_example = [] - for i, x in enumerate(example): - if keys is None or i in keys: - new_example.append(_to_unicode(x)) - else: - new_example.append(x) - output = tuple(new_example) - yield output - elif isinstance(example, dict): - new_example = {} - for k in example: - if keys is None or k in keys: - new_example[k] = _to_unicode(example[k]) - else: - new_example[k] = example[k] - yield new_example - else: - output = _to_unicode(example) - yield output - - return _convert_to_unicode_str - - -def vocab_size(vocab_type='subword', - vocab_file=None, - vocab_dir=None, - n_reserved_ids=0): - """Returns the size of the vocabulary (number of symbols used). - - This function can be used to set the size of the final layers of a model that - needs to predict symbols from a given vocabulary. More precisely, if this - function returns N then the last layer size should be set to at least N (it - can be more). Note that this function does take reserved IDs into account. - - Args: - vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. - vocab_file: Name of the vocabulary file. - vocab_dir: Directory which contains the vocabulary file. - n_reserved_ids: An int, offset added so 0, ..., n_reserved_ids-1 are unused. - - Returns: - An integer, the number of symbols used (including reserved IDs). - """ - vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) - return vocab.vocab_size + n_reserved_ids - - -def _get_vocab(vocab_type='subword', vocab_file=None, vocab_dir=None, - extra_ids=0): - """Gets the vocabulary object for tokenization; see tokenize for details.""" - if vocab_type not in [ - 'char', 'subword', 'sentencepiece', 'bert', 'bert-lowercase' - ]: - raise ValueError( - 'vocab_type must be "subword", "char", "sentencepiece", "bert" or "bert-lowercase" ' - f'but got {vocab_type}') - - if vocab_type == 'char': - # Note that we set num_reserved_ids=0 below. We could instead pass - # the value n_reserved_ids from tokenize here -- ByteTextEncoder does - # exactly the same thing as tokenize above, ie., adds num_reserved_ids. - return text_encoder.ByteTextEncoder(num_reserved_ids=0) - - vocab_dir = vocab_dir or 'gs://trax-ml/vocabs/' - path = os.path.join(vocab_dir, vocab_file) - - if vocab_type == 'subword': - return text_encoder.SubwordTextEncoder(path) - - if vocab_type == 'bert': - return text_encoder.BertEncoder(path, do_lower_case=False) - - if vocab_type == 'bert-lowercase': - return text_encoder.BertEncoder(path, do_lower_case=True) - - assert vocab_type == 'sentencepiece' - return t5_data().SentencePieceVocabulary(sentencepiece_model_file=path, - extra_ids=extra_ids) - - -# Makes the function accessible in gin configs, even with all args denylisted. -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def cifar10_no_augmentation_preprocess(dataset, training): - del training - - def cast_image(features, targets): - features['image'] = tf.cast(features['image'], tf.float32) / 255.0 - return features, targets - - dataset = dataset.map(cast_image) - return dataset - - -def _cifar_augment_image(image): - """Image augmentation suitable for CIFAR-10/100. - - As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5). - - Args: - image: a Tensor. - - Returns: - Tensor of the same shape as image. 
- """ - image = tf.image.resize_with_crop_or_pad(image, 40, 40) - image = tf.image.random_crop(image, [32, 32, 3]) - image = tf.image.random_flip_left_right(image) - return image - - -# Makes the function accessible in gin configs, even with all args denylisted. -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def cifar10_augmentation_preprocess(dataset, training): - """Preprocessing for cifar10 with augmentation (see below).""" - - def augment(features, targets): - features['image'] = _cifar_augment_image(features['image']) - return features, targets - - def cast_image(features, targets): - features['image'] = tf.cast(features['image'], tf.float32) / 255.0 - return features, targets - - if training: - dataset = dataset.map(augment) - dataset = dataset.map(cast_image) - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def cifar10_augmentation_flatten_preprocess(dataset, - training, - predict_image_train_weight=0.01): - """Preprocessing for cifar10 that flattens it and appends targets.""" - - def augment(features, targets): - features['image'] = _cifar_augment_image(features['image']) - return features, targets - - def flatten_image(features, targets): - """Flatten the image.""" - img = features['image'] - flat = tf.cast(tf.reshape(img, [-1]), tf.int64) - tgt = tf.expand_dims(targets, axis=0) - flat_with_target = tf.concat([flat, tgt], axis=0) - new_features = {} - new_features['image'] = flat_with_target - predict_image_weight = predict_image_train_weight if training else 0.0 - mask_begin = tf.ones_like(flat) - mask_begin = tf.cast(mask_begin, tf.float32) * predict_image_weight - mask_end = tf.cast(tf.ones_like(tgt), tf.float32) - new_features['mask'] = tf.concat([mask_begin, mask_end], axis=0) - return new_features, flat_with_target - - if training: - dataset = dataset.map(augment) - dataset = dataset.map(flatten_image) - - return dataset - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def downsampled_imagenet_flatten_bare_preprocess(dataset, training): - """Preprocessing for downsampled_imagenet. - - Args: - dataset: the dataset. - training: unused option. - - Returns: - Flattened dataset. - - Preprocessing for downsampled_imagenet 32x32 and 64x64 generation from - http://arxiv.org/abs/1601.06759 (page 8). - """ - del training - - def flatten_image(features): - img = features['image'] - flat = tf.cast(tf.reshape(img, [-1]), tf.int64) - - new_features = {'image': flat} - return new_features - - return dataset.map(flatten_image) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def concat_preprocess(dataset, training, pad_symbol=0): - """Pre-processing function that concatenates input and target for LM.""" - del training - - def concat(features, targets): - inp = features['inputs'] - pad = tf.expand_dims(tf.zeros_like(inp[0]) + pad_symbol, axis=0) - concat = tf.concat([pad, inp, pad, targets], axis=0) - # Note: we're updating existing features dictionary here, so make sure - # it is not re-used in some other ways outside of this function. 
- features['inputs'] = concat
- return features, concat
-
- dataset = dataset.map(concat)
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def squeeze_targets_preprocess(dataset, training):
- """Pre-processing function that squeezes last axis of targets."""
- del training
-
- def squeeze(features, targets):
- if targets.shape[-1] == 1:
- targets = tf.squeeze(targets, axis=-1)
- return features, targets
-
- dataset = dataset.map(squeeze)
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def lm1b_preprocess(dataset,
- training,
- max_target_length=-1,
- max_eval_target_length=-1):
- """Preprocessing for LM1B: filter out targets exceeding maximum length."""
-
- def target_right_length(_, target):
- return tf.less(tf.shape(target)[0], max_target_length + 1)
-
- def eval_target_right_length(_, target):
- return tf.less(tf.shape(target)[0], max_eval_target_length + 1)
-
- if max_target_length > 0 and training:
- dataset = dataset.filter(target_right_length)
-
- if max_eval_target_length > 0 and not training:
- dataset = dataset.filter(eval_target_right_length)
-
- return dataset
-
-
-# TODO(lukaszkaiser): find a single more abstract way of text pre-processing.
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def wmt_preprocess(dataset, training, max_length=-1, max_eval_length=-1):
- """Preprocessing for WMT: filter out examples exceeding maximum length."""
-
- def train_right_length(example, target):
- l = tf.maximum(tf.shape(example['inputs'])[0], tf.shape(target)[0])
- return tf.less(l, max_length + 1)
-
- def eval_right_length(example, target):
- l = tf.maximum(tf.shape(example['inputs'])[0], tf.shape(target)[0])
- return tf.less(l, max_eval_length + 1)
-
- if max_length > 0 and training:
- dataset = dataset.filter(train_right_length)
-
- if max_eval_length > 0 and not training:
- dataset = dataset.filter(eval_right_length)
-
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def wmt_concat_preprocess(dataset, training, max_length=-1, max_eval_length=-1):
- """Preprocessing for WMT: filter exceeding maximum length and concatenate."""
- dataset = wmt_preprocess(dataset, training, max_length, max_eval_length)
-
- def concat_and_add_mask(features, targets):
- inp = features['inputs']
- pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0)
- concat = tf.concat([inp, pad, targets], axis=0)
- mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0)
- features['inputs'] = concat
- features['mask'] = mask
- return features, concat
-
- dataset = dataset.map(concat_and_add_mask)
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def lm_token_preprocessing(dataset, training):
- """Concatenates inputs, 0, targets, with masking only for targets."""
- del training
-
- def concat_and_add_mask(x):
- inp = x['inputs']
- targets = x['targets']
- pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0)
- concat = tf.concat([inp, pad, targets], axis=0)
- mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0)
- x['inputs'] = concat
- x['targets'] = concat
- x['mask'] = mask
- return x
-
- dataset = dataset.map(concat_and_add_mask)
- return dataset
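The concat-and-mask pattern shared by `wmt_concat_preprocess` and `lm_token_preprocessing` above is easiest to see on toy tensors:

```python
import tensorflow as tf

# [inputs, 0, targets] with a loss mask that is 1 only on target positions.
inp = tf.constant([11, 12, 13])
tgt = tf.constant([21, 22])
pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0)
concat = tf.concat([inp, pad, tgt], axis=0)
mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(tgt)], axis=0)
print(concat.numpy(), mask.numpy())  # -> [11 12 13  0 21 22] [0 0 0 0 1 1]
```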
-
-
-@gin.configurable(module='trax.data', denylist=['hparams'])
-def bair_robot_pushing_hparams(hparams=None,
- video_num_input_frames=1,
- video_num_target_frames=15):
- if hparams is not None:
- hparams.video_num_input_frames = video_num_input_frames
- hparams.video_num_target_frames = video_num_target_frames
- else:
- return video_num_input_frames, video_num_target_frames
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def bair_robot_pushing_preprocess(dataset, training):
- """Pre-processing function that concatenates input and target frames."""
- del training
-
- def concat_and_add_mask(features, targets):
- """Concatenate input and output frames to form a language modeling setup."""
- inp = features['inputs']
- concat = tf.concat([inp, targets], axis=0)
- mask = tf.concat([tf.zeros_like(inp), tf.ones_like(targets)], axis=0)
- concat = tf.reshape(concat, (-1,))
- mask = tf.reshape(mask, (-1,))
- concat = tf.cast(concat, tf.int32)
- mask = tf.cast(mask, tf.float32)
- features['inputs'] = features['targets'] = concat
- features['mask'] = mask
- return features, concat
-
- dataset = dataset.map(concat_and_add_mask)
- return dataset
-
-
-def sentencepiece_tokenize(stream, spm_path=None, extra_ids=0):
- """Sentencepiece tokenization."""
- spm_path = spm_path or t5_data().DEFAULT_SPM_PATH
- vocab_file = os.path.basename(spm_path)
- vocab_dir = os.path.dirname(spm_path)
- vocab = _get_vocab(vocab_type='sentencepiece',
- vocab_file=vocab_file,
- vocab_dir=vocab_dir,
- extra_ids=extra_ids)
- for example in stream:
- # example could either be str or (str,)
- if isinstance(example, tuple):
- example = example[0]
- yield np.array(vocab.encode(example))
-
-
-@gin.configurable(module='trax.data')
-def SentencePieceTokenize( # pylint: disable=invalid-name
- spm_path=None,
- extra_ids=0):
- """Returns a function that maps text to integer arrays."""
- return lambda g: sentencepiece_tokenize( # pylint: disable=g-long-lambda
- g,
- spm_path=spm_path,
- extra_ids=extra_ids)
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def c4_preprocess(dataset,
- training,
- max_target_length=-1,
- tokenization=None,
- spm_path=None):
- """Pre-processing function for C4 dataset."""
- del training
-
- def unicode_decode_chars(features, targets):
- targets = tf.strings.unicode_decode(features['text'], 'UTF-8')
- targets = tf.cast(targets, tf.int64)
- features['targets'] = targets
- features['inputs'] = targets
- return (features, targets)
-
- def spc_tokenize(tokenizer, features, targets):
- del targets
- tokenized_text = tokenizer.tokenize(features['text'])
- features['targets'] = tf.cast(tokenized_text, tf.int64)
- features['inputs'] = features['targets']
- return features, features['targets']
-
- if tokenization == 'spc':
- spm_path = spm_path or t5_data().DEFAULT_SPM_PATH
- with tf.compat.v1.gfile.GFile(spm_path, 'rb') as f:
- spc_model = f.read()
- tokenizer = tf_text.SentencepieceTokenizer(model=spc_model)
- dataset = dataset.map(functools.partial(spc_tokenize, tokenizer))
- else:
- dataset = dataset.map(unicode_decode_chars)
-
- def target_right_length(_, target):
- return tf.less(tf.shape(target)[0], max_target_length + 1)
-
- if max_target_length > 0:
- dataset = dataset.filter(target_right_length)
-
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def c4_bare_preprocess_fn(dataset,
- training=True,
- spm_path=None,
- copy_pretokenized=True,
- sequence_length=None):
- """Returns a dataset that contains 'inputs' and 'targets' from C4."""
- # Set target key to be equal to the text content.
- dataset = t5_data().preprocessors.rekey(
- dataset, key_map={
- 'targets': 'text',
- 'inputs': None
- })
-
- # Vocabulary for tokenization.
- extra_ids = 0
- vocab = t5_data().SentencePieceVocabulary(
- sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH,
- extra_ids=extra_ids)
- feature = t5_data().Feature(vocab)
- output_features = {'targets': feature, 'inputs': feature}
-
- # Tokenize the targets.
- keys = output_features
-
- def encode_string_features_fn(features):
- """Encode all specified features that are strings and return a dictionary.
-
- Args:
- features: a dictionary
-
- Returns:
- a dictionary
- """
- ret = {}
- for k, v in features.items():
- if k in keys and v.dtype == tf.string:
- if copy_pretokenized:
- ret['%s_pretokenized' % k] = v
- v = tf.cast(output_features[k].vocabulary.encode_tf(v), tf.int64)
- ret[k] = v
- return ret
-
- dataset = dataset.map(
- encode_string_features_fn,
- num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
- # Preprocess the tokens - the exact preprocessors are set via gin.
- dataset = t5_data().preprocessors.unsupervised(
- dataset, sequence_length=sequence_length, output_features=output_features)
-
- # Add EOS.
- dataset = add_eos_to_output_features(dataset, training)
-
- # Truncate and then pad the examples -- all examples have the same shape.
- dataset = truncate_dataset_on_len(dataset, training, sequence_length, True)
- dataset = pad_dataset_to_length(dataset, training, sequence_length)
-
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def filter_dataset_on_len(dataset,
- training,
- len_map=None,
- filter_on_eval=False):
- """Filters a dataset by feature lengths given in `len_map`.
-
- Args:
- dataset: `tf.data.Dataset` the dataset to filter.
- training: bool, true if we are in training mode.
- len_map: optional dict of str to (int, int). We filter examples where a
- feature's size is beyond the specified bounds. Ex:
- {'inputs': (1, 512), 'targets': (64, 128)} will keep only those examples
- where 1 <= len(inputs) <= 512 and 64 <= len(targets) <= 128.
- filter_on_eval: bool if true, we will filter in eval mode also.
-
- Returns:
- a filtered `tf.data.Dataset`.
- """
- if (len_map is None) or (not training and not filter_on_eval):
- return dataset
-
- assert isinstance(len_map, dict)
- for k, bounds in len_map.items():
- # pylint: disable=cell-var-from-loop
- # TODO(afrozm): Investigate `cell-var-from-loop` - since this is WAI and
- # there is a test too.
- def within_bounds(x, key, len_bounds):
- size = tf.shape(x[key])[0]
- min_len, max_len = len_bounds
- # Use tf.logical_and: Python `and` on symbolic tensors would fail when
- # this function is traced by `dataset.filter`.
- return tf.logical_and(min_len <= size, size <= max_len)
-
- dataset = dataset.filter(lambda x: within_bounds(x, k, bounds))
- # pylint: enable=cell-var-from-loop
-
- return dataset
-
-
-@gin.configurable(module='trax.data', denylist=['dataset', 'training'])
-def truncate_dataset_on_len(dataset,
- training,
- len_map=None,
- truncate_on_eval=False):
- """Truncates features in an example to lengths given in `len_map`.
-
- Args:
- dataset: `tf.data.Dataset` the dataset to filter.
- training: bool, true if we are in training mode.
- len_map: optional dict of str to int, we truncate examples where a feature's
- size is beyond the max. Ex: {'inputs': 512, 'targets': 64} will truncate
- examples to be within those bounds.
- truncate_on_eval: bool if true, we will truncate in eval mode also.
-
- Returns:
- a truncated `tf.data.Dataset`.
- """ - if (len_map is None) or (not training and not truncate_on_eval): - return dataset - - assert isinstance(len_map, dict) - - def truncate_example(x): - for key, max_len in len_map.items(): - x_len = tf.shape(x[key])[0] - if x_len > max_len: - x[key] = x[key][:max_len, ...] - return x - - return dataset.map(truncate_example) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def pad_dataset_to_length(dataset, training, len_map=None): - """Pad features less than specified length to specified length.""" - del training - if len_map is None: - return dataset - - def pad_to_len(x): - for key, max_len in len_map.items(): - x_shape = tf.shape(x[key]) - x_len = x_shape[0] - if x_len < max_len: - pad_shape = [ - max_len - x_len, - ] - zeros = tf.zeros(pad_shape, dtype=x[key].dtype) - x[key] = tf.concat([x[key], zeros], 0) - return x - - return dataset.map(pad_to_len) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def add_eos_to_output_features(dataset, - training, - output_features='targets', - eos=1): - """Adds `EOS` to all features in `output_features`.""" - del training - if not isinstance(output_features, (list, tuple)): - output_features = [output_features] - - def add_eos(x): - for output_feature in output_features: - x[output_feature] = tf.concat([x[output_feature], [eos]], axis=0) - return x - - return dataset.map(add_eos) - - -@gin.configurable(module='trax.data', denylist=['dataset', 'training']) -def generic_text_dataset_preprocess_fn(dataset, - training=True, - text_preprocess_fns=None, - token_preprocess_fns=None, - spm_path=None, - copy_pretokenized=False, - debug_print_examples=False, - debug_print_examples_rate=0.01): - """Pre-processes, tokenizes and post-processes a `tf.data.Dataset`. - - Args: - dataset: `tf.data.Dataset` to process. - training: boolean, set to True if training, False otherwise. - text_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> - `tf.data.Dataset` this operates before tokenization. Typically used to - select which fields we want to learn over or change something into "text - to text" form. - token_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool -> - `tf.data.Dataset`, this operates after tokenization. Since this can view - the tokenized fields, this can be used to filter on length etc. - spm_path: None or str, path to a sentencepiece model to use for tokenization - by default uses the 32k vocabulary from T5. - copy_pretokenized: bool, if True retains the original fields after - tokenization. - debug_print_examples: bool, if True this prints examples to the logging - stream for inspection, both before and after tokenization. - debug_print_examples_rate: float, [0, 1.0], on average this fraction of - dataset examples will be printed out in each phase i.e. pre and post - tokenization. - - Returns: - a `tf.data.Dataset` with all the preprocessing and tokenization performed. - """ - - # The assumption is that `text_preprocess_fns` finally gives us a dataset - # which has `inputs` and `targets`. - if text_preprocess_fns is not None: - for text_preprocess_fn in text_preprocess_fns: - dataset = text_preprocess_fn(dataset, training) - - # Print debugging examples if needed before tokenization. - if debug_print_examples: - - def print_examples(x): - if np.random.uniform() < debug_print_examples_rate: - tf.print(x, output_stream=logging.info) - return x - - dataset = dataset.map(print_examples) - - # Vocabulary for tokenization. 
- extra_ids = 0 - vocab = t5_data().SentencePieceVocabulary( - sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH, - extra_ids=extra_ids) - feature = t5_data().Feature(vocab) - output_features = {'targets': feature, 'inputs': feature} - - # Tokenize the inputs and targets. - dataset = t5_data().preprocessors.tokenize( - dataset, output_features, copy_pretokenized=copy_pretokenized) - - # Apply the token-preprocessors. - if token_preprocess_fns is not None: - for token_preprocess_fn in token_preprocess_fns: - dataset = token_preprocess_fn(dataset, training) - - if debug_print_examples: - - def print_examples_and_shapes(x): - if np.random.uniform() < debug_print_examples_rate: - tf.print( - { - 'inputs_shape': tf.size(x['inputs']), - 'targets_shape': tf.size(x['targets']), - 'inputs': x['inputs'], - 'targets': x['targets'], - }, - output_stream=logging.info) - return x - - dataset = dataset.map(print_examples_and_shapes) - - return dataset - - -@gin.configurable(module='trax.data') -def get_t5_preprocessor_by_name(name=None, fn_kwargs=None): - """Returns a closure of any T5 preprocessor function with its arguments. - - The main use-case is to use this (with gin scopes) to make any preprocessor - function available in a gin file to configure and use. - - See: `TFInputs.test_gin_configurable_preprocessors` - - Args: - name: str, name of the preprocessor function to configure. - fn_kwargs: optional dictionary, the arguments to configure, these will be - partially applied to the function given by `name`. - - Returns: - a closure of the preprocessor function along with its arguments, this - function takes two arguments only, dataset and boolean training and ignores - the training and calls the t5 processor with the dataset (and closed over - arguments only). - """ - - assert name is not None - f = getattr(t5_data().preprocessors, name) - if fn_kwargs is not None: - f = functools.partial(f, **fn_kwargs) - return lambda ds, unused_training: f(ds) - - -def download_and_prepare(dataset_name, data_dir): - """Downloads and prepares T2T or TFDS dataset. - - Args: - dataset_name: tfds dataset or t2t problem name prefixed by 't2t_'. - data_dir: location of existing dataset or None. - - Returns: - data_dir: path string of downloaded data. - """ - if not data_dir: - data_dir = os.path.expanduser('~/tensorflow_datasets/') - dl_dir = os.path.join(data_dir, 'download') - logging.info( - 'No dataset directory provided. ' - 'Downloading and generating dataset for %s inside data directory %s ' - 'For large datasets it is better to prepare datasets manually!', - dataset_name, data_dir) - if dataset_name.startswith('t2t_'): - # Download and run dataset generator for T2T problem. - data_dir = os.path.join(data_dir, dataset_name) - tf.io.gfile.makedirs(data_dir) - tf.io.gfile.makedirs(dl_dir) - t2t_problems().problem(dataset_name[len('t2t_'):]).generate_data( - data_dir, dl_dir) - else: - # Download and prepare TFDS dataset. 
- tfds_builder = tfds.builder(dataset_name)
- tfds_builder.download_and_prepare(download_dir=dl_dir)
- else:
- data_dir = os.path.expanduser(data_dir)
- return data_dir
-
-
-def BertSingleSentenceInputs(batch, # pylint: disable=invalid-name
- labeled=True,
- cls_id=101,
- sep_id=102):
- """Prepares inputs for BERT: add [SEP], [CLS] and create embeddings."""
- if labeled:
- for sent1, label in batch:
- value_vector = np.concatenate(([cls_id], sent1, [sep_id]))
- segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32)
- yield value_vector, segment_embs, segment_embs, label, np.int32(1)
- else:
- for (sent1,) in batch: # row is a tuple with 1 element
- value_vector = np.concatenate(([cls_id], sent1, [sep_id]))
- segment_embs = np.zeros(sent1.shape[0] + 2, dtype=np.int32)
- yield value_vector, segment_embs, segment_embs
-
-
-def BertDoubleSentenceInputs(batch, # pylint: disable=invalid-name
- labeled=True,
- cls_id=101,
- sep_id=102):
- """Prepares inputs for BERT models by adding [SEP] and [CLS] tokens and creating segment embeddings."""
- if labeled:
- for sent1, sent2, label in batch:
- value_vector = np.concatenate(
- ([cls_id], sent1, [sep_id], sent2, [sep_id]))
-
- segment_embs = np.zeros(
- sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32)
- second_sent_start = sent1.shape[0] + 2
- segment_embs[second_sent_start:] = 1
- yield value_vector, segment_embs, segment_embs, label, np.int32(1)
- else:
- for sent1, sent2 in batch:
- value_vector = np.concatenate(
- ([cls_id], sent1, [sep_id], sent2, [sep_id]))
-
- segment_embs = np.zeros(
- sent1.shape[0] + sent2.shape[0] + 3, dtype=np.int32)
- second_sent_start = sent1.shape[0] + 2
- segment_embs[second_sent_start:] = 1
- yield value_vector, segment_embs, segment_embs
-
-
-@gin.configurable(module='trax.data')
-def CreateBertInputs(double_sentence=True, # pylint: disable=invalid-name
- labeled=True,
- cls_id=101,
- sep_id=102):
- bert_inputs_fn = BertDoubleSentenceInputs if double_sentence else BertSingleSentenceInputs
- return functools.partial(
- bert_inputs_fn, labeled=labeled, cls_id=cls_id, sep_id=sep_id)
-
-
-@gin.configurable(module='trax.data')
-def mask_random_tokens(batch,
- explicit_vocab_size=30522,
- masking_prob=0.15,
- cls_id=101,
- sep_id=102,
- mask_id=103,
- vocab_start_id=999):
- """Prepares input for the masking task.
-
- Preparation consists of masking a masking_prob fraction of non-special tokens
- in each input row; round(masking_prob * num_nonspecial_tokens) random tokens
- are selected, out of which each token is either
- - replaced with the [MASK] token with 80% probability,
- - replaced with a random token with 10% probability,
- - or left unchanged with 10% probability.
- The implementation is based on
- https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L342
-
- Examples:
- - batch is a stream with each row having tuple (token_ids,). Function yields
- rows of form (modified_token_ids, original_tokens, token_weights), where
- modified_token_ids have [MASK] tokens or random tokens according to the
- procedure described above.
- - batch is a stream with each row having tuple (token_ids, segment_embeddings,
- nsp_label, nsp_weight). Function yields rows of form (modified_token_ids,
- segment_embeddings, nsp_label, nsp_weight, original_tokens, token_weights).
-
- Args:
- batch: stream of inputs. Each row in the stream is a tuple whose first
- element is an array of tokens.
- explicit_vocab_size: the total size of the vocabulary.
- masking_prob: Determines the fraction of non-special tokens to be selected
- for masking.
- cls_id: id of the special CLS token.
- sep_id: id of the special SEP token.
- mask_id: id of the special MASK token.
- vocab_start_id: id of first non-special token in the vocabulary.
-
- Yields:
- a stream with tokens masked for MLM training and 2 appended arrays:
- - original tokens: a copy of original tokens used as a label for mlm
- training
- - token_weights: weights distributed uniformly over selected tokens (sum
- is 1). Other tokens have 0 weight.
- """
- for token_ids, *row_rest in batch:
- original_tokens = token_ids.copy()
-
- # Choose tokens for prediction: a masking_prob fraction of
- # all non-special tokens.
- is_special_token = np.logical_or(token_ids == cls_id,
- token_ids == sep_id) # CLS and SEP tokens
- is_special_token = np.logical_or(is_special_token,
- token_ids == 0) # padding
- viable_ids = np.arange(token_ids.shape[0])[~is_special_token]
- num_to_sample = round(masking_prob * viable_ids.shape[0])
- if num_to_sample == 0:
- # sentence is too short to select the given fraction of tokens to mask
- continue
- candidate_ids = np.random.choice(viable_ids, num_to_sample, replace=False)
-
- # create weights
- token_weights = np.zeros(token_ids.shape)
- token_weights[candidate_ids] = 1 / candidate_ids.shape[0]
-
- prob_scores = np.random.random(candidate_ids.shape)
-
- # change 80% of tokens to [MASK]
- mask_token_ids = candidate_ids[prob_scores < 0.8]
- token_ids[mask_token_ids] = mask_id
-
- # change 10% of tokens to a random token
- random_token_ids = candidate_ids[(0.8 <= prob_scores) & (prob_scores < 0.9)]
- token_ids[random_token_ids] = np.random.randint(vocab_start_id,
- explicit_vocab_size,
- random_token_ids.shape[0])
-
- # the rest (10%) is left unchanged
- yield (token_ids, *row_rest, original_tokens, token_weights)
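A toy run of `mask_random_tokens` makes the 80/10/10 rule above concrete (the ids are the BERT defaults from the signature; the import path is the pre-removal one):

```python
import numpy as np
from trax.data.tf_inputs import mask_random_tokens  # pre-removal import path

row = (np.array([101, 5, 6, 7, 8, 9, 102]),)  # [CLS] w1..w5 [SEP]
masked, original, weights = next(mask_random_tokens([row]))
# round(0.15 * 5) == 1 token is selected: its weight is 1.0, the rest are 0,
# and with 80% probability it now reads 103 ([MASK]) in `masked`.
```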
-
-
-@gin.configurable(module='trax.data')
-def BertNextSentencePredictionInputs(dataset_name, # pylint: disable=invalid-name
- data_dir=None,
- text_key='text',
- train=True,
- shuffle_size=50000):
- """Defines a stream for the next sentence prediction task."""
- stream = TFDS(
- dataset_name,
- data_dir=data_dir,
- tfds_preprocess_fn=functools.partial(
- t5_data().preprocessors.next_sentence_prediction,
- text_key=text_key,
- label_sentences=True,
- buffer_size=shuffle_size),
- keys=['inputs', 'targets'],
- train=train)
-
- def split_stream(generator=None):
- # split string with 'sentence1:' and 'sentence2:' into two separate strings
- for text, target in stream(generator):
- text_str = str(text)[:-1] # removes last '"' which is always at the end
- sentences = text_str.split('sentence1: ')[1].split(' sentence2: ')
- if len(sentences) != 2:
- # 'sentence2:' appeared in the text and got mixed up with the label
- continue
- sent1, sent2 = sentences
- yield sent1, sent2, target == 'next'
-
- return split_stream
-
-
-@gin.configurable(module='trax.data')
-def CorpusToRandomChunks(dataset_name, num_tokens=512, train=True): # pylint: disable=invalid-name
- return TFDS(
- dataset_name,
- tfds_preprocess_fn=functools.partial(
- t5_data().preprocessors.random_split_text,
- max_words_per_segment=num_tokens),
- train=train,
- keys=['text'])
-
-
-_GLUE_KEYS = {
- 'cola': ('sentence',),
- 'sst2': ('sentence',),
- 'mrpc': ('sentence1', 'sentence2'),
- 'qqp': ('question1', 'question2'),
- 'stsb': ('sentence1', 'sentence2'),
- 'mnli': ('premise', 'hypothesis'),
- 'qnli': ('question', 'sentence'),
- 'rte': ('sentence1', 'sentence2'),
- 'wnli': ('sentence1', 'sentence2'),
-}
-
-
-# Labels inferred from the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
-_GLUE_LABELS = {
- 'cola': ('unacceptable', 'acceptable'),
- 'sst2': ('negative', 'positive'),
- 'mrpc': ('not_equivalent', 'equivalent'),
- 'qqp': ('not_duplicate', 'duplicate'),
- 'stsb': ('sentence1', 'sentence2'),
- 'mnli': ('entailment', 'neutral', 'contradiction'),
- 'qnli': ('entailment', 'not_entailment'),
- 'rte': ('entailment', 'not_entailment'),
- 'wnli': ('sentence1', 'sentence2'),
-}
-
-# Defining separate TrainStream and EvalStream functions (below)
-# makes gin configuration expressions more direct. A single gin line can
-# configure each; for example:
-#
-# BertGlueTrainStream.benchmark = 'mnli'
-# BertGlueEvalStream.benchmark = 'mnli'
-
-
-# pylint: disable=invalid-name
-@gin.configurable(module='trax.data')
-def BertGlueTrainStream(benchmark=gin.REQUIRED):
- """Returns a Bert-preprocessed training stream for ``benchmark``.
-
- Args:
- benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``,
- ``'mnli'``, ``'rte'``.
- """
- return _BertGlueDataStream(benchmark + '_t')
-
-
-# GLUE evals need special handling because one eval in particular, MNLI, has
-# two different eval sets: "matched" and "mismatched". The code in this module
-# distinguishes between the two using the suffixes '_e' versus '_e2',
-# respectively.
-def _ensure_eval_suffix(benchmark):
- """Returns a string ending in an eval suffix; adds ``'_e'`` suffix if needed.
-
- Args:
- benchmark: Name of a benchmark or task, that might already include an
- eval-indicating suffix (``'_e'`` or ``'_e2'``).
- """
- if benchmark.endswith('_e') or benchmark.endswith('_e2'):
- return benchmark
- else:
- return benchmark + '_e'
-
-
-@gin.configurable(module='trax.data')
-def BertGlueEvalStream(benchmark=gin.REQUIRED):
- """Returns a Bert-preprocessed eval data stream for ``benchmark``.
-
- Args:
- benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``,
- ``'mnli'``, ``'rte'``. If the benchmark includes an alternate
- eval (e.g., MNLI's "mismatched" eval/validation split), you can
- specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``.
- """
- return _BertGlueDataStream(_ensure_eval_suffix(benchmark))
-
-
-def _BertGlueDataStream(benchmark_id):
- """Returns a Bert-preprocessed data stream for ``benchmark_id``.
-
- Args:
- benchmark_id: String that indicates the name and data split of a GLUE
- benchmark. Data splits are indicated as underscore suffixes, e.g.,
- ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE
- benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark,
- alternate "mismatched" eval/validation split).
- """
- benchmark_id = _ensure_eval_suffix(benchmark_id)
- benchmark, split = benchmark_id.rsplit('_', 1)
- glue_data = TFDS(f'glue/{benchmark}',
- keys=_GLUE_KEYS[benchmark],
- train=(split == 't'),
- use_alt_eval=(split == 'e2'))
- return data.Serial(
- glue_data,
- data.Tokenize(),
- data.CreateBertInputs(),
- data.Shuffle(),
- data.PadToLength(),
- data.TruncateToLength(),
- data.Batch(),
- )
-
-
-@gin.configurable(module='trax.data')
-def T5GlueTrainStream(benchmark=gin.REQUIRED):
- """Returns a T5-preprocessed training data stream for ``benchmark``.
-
- Args:
- benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``,
- ``'mnli'``, ``'rte'``.
- """
- return _T5GlueDataStream(benchmark + '_t')
-
-
-@gin.configurable(module='trax.data')
-def T5GlueTrainStreamsParallel(benchmark_list=gin.REQUIRED,
- counters=None,
- reweight_by_minimum=False,
- gradually_reweight=False):
- """Returns a parallel set of training streams, based on ``benchmark_list``.
- - Args: - benchmark_list: List of simple lower-case names of GLUE benchmarks, e.g., - ``'cola'``, ``'mnli'``, ``'rte'``. - counters: a list of counters to be passed to data.Parallel, e.g., - [8551, 392702, 2490] would be a reasonable counterpart to - benchmark_list = ["cola", "mnli", "rte"], see - https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/glue_utils.py#L42 - for more details on counters. - reweight_by_minimum: divide by the minimal counter. - gradually_reweight: a more refined reweighting policy, see inputs.py - for more details. - """ - stream_list = list(map(T5GlueTrainStream, benchmark_list)) - return data.Parallel( - stream_list, - counters=counters, - reweight_by_minimum=reweight_by_minimum, - gradually_reweight=gradually_reweight)() - - -@gin.configurable(module='trax.data') -def T5GlueEvalStream(benchmark=gin.REQUIRED): - """Returns a T5-preprocessed eval data stream for ``benchmark``. - - Args: - benchmark: Simple lower-case name of a GLUE benchmark, e.g., ``'cola'``, - ``'mnli'``, ``'rte'``. If the benchmark includes an alternate - eval (e.g., MNLI's "mismatched" eval/validation split), you can - specify it with an ``'_e2'`` suffix, e.g., ``'mnli_e2'``. - """ - return _T5GlueDataStream(_ensure_eval_suffix(benchmark)) - - -@gin.configurable(module='trax.data') -def T5GlueEvalStreamsParallel(benchmark_list=gin.REQUIRED): - """Returns a parallel set of T5 eval streams, based on ``benchmark_list``. - - Args: - benchmark_list: List of strings, each of which is a simple lower-case name - of a GLUE benchmark, e.g., ``'cola'``, ``'mnli'``, ``'rte'``. If a - benchmark includes an alternate eval (e.g., MNLI's "mismatched" - eval/validation split), you can specify it with an ``'_e2'`` suffix, - e.g., ``'mnli_e2'``. - """ - stream_list = list(map(T5GlueEvalStream, benchmark_list)) - return data.Parallel(stream_list)() - - -def _T5GlueDataStream(benchmark_id, t5_tokenization=False): - """Returns a T5-preprocessed data stream for ``benchmark_id``. - - Args: - benchmark_id: String that indicates the name and data split of a GLUE - benchmark. Data splits are indicated as underscore suffixes, e.g., - ``'cola_t'`` (Cola benchmark, training split), ``'rte_e'`` (RTE - benchmark, eval/validation split), and ``'mnli_e2'`` (MNLI benchmark, - alternate "mismatched" eval/validation split). - t5_tokenization: if true, then use t5_tokenization. - """ - return data.Serial( - _t5_glue_data_split(benchmark_id) - if t5_tokenization else _t5_glue_data_split_no_token(benchmark_id), - data.Tokenize(), - data.Shuffle(), - data.PadToLength(), - data.TruncateToLength(), - data.Batch(), - ) - - -@gin.configurable(module='trax.data') -def T5GlueEvalTasks(benchmark_list=gin.REQUIRED): - """Returns a list of T5 GLUE eval tasks, based on ``benchmark_list``. - - Args: - benchmark_list: List of strings, each of which indicates the name and - data split of a GLUE benchmark. Data splits are indicated as underscore - suffixes, e.g., ``'cola_t'`` (Cola benchmark, training split), - ``'rte_e'`` (RTE benchmark, eval/validation split), and ``'mnli_e2'`` - (MNLI alternate "mismatched" eval/validation split). 
- """ - task_list = list(map(_T5GlueEvalTask, benchmark_list)) - return task_list - - -def _T5GlueEvalTask(benchmark_id): - """Returns a T5 GLUE eval task, based on ``benchmark_id``.""" - eval_data = T5GlueEvalStream(benchmark_id) - benchmark_id = _ensure_eval_suffix(benchmark_id) - metrics = [tl.WeightedCategoryAccuracy(), tl.SequenceAccuracy()] - benchmark, split = benchmark_id.rsplit('_', 1) - if benchmark == 'cola': - name_upper = 'Cola' - elif benchmark == 'mnli': - name_upper = 'MNLI_matched' if split == 'e' else 'MNLI_mismatched' - else: - name_upper = benchmark.upper() - return supervised.training.EvalTask( - eval_data(), - metrics, - metric_names=[f'{name_upper} accuracy', - f'{name_upper} sequence accuracy']) - - -def _t5_glue_data_split_no_token(benchmark_id): - """Returns a GLUE data split prepared with the standard T5 preprocessor.""" - benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) - dataset = tfds.load(name=f'glue/{benchmark}', split=split) - processed_dataset = t5_data().preprocessors.glue( # pylint: disable=g-long-lambda - dataset, - benchmark_name=benchmark, - label_names=_GLUE_LABELS[benchmark]) - - def stream_of_inputs_targets_weights(generator=None): - del generator - while True: - for example in processed_dataset: - input_values = example['inputs'].numpy() - target_values = example['targets'].numpy() - yield (input_values, - target_values, - jnp.array([1] * len(target_values))) - - return stream_of_inputs_targets_weights - - -def _t5_glue_data_split(benchmark_id): - """Returns a GLUE data split prepared with the standard T5 preprocessor.""" - benchmark, split = _t5_glue_benchmark_and_split(benchmark_id) - dataset = tfds.load(name=f'glue/{benchmark}', split=split) - processed_dataset = generic_text_dataset_preprocess_fn( - dataset, - spm_path=t5_data().DEFAULT_SPM_PATH, - text_preprocess_fns=[ - lambda ds, training: t5_data().preprocessors.glue( # pylint: disable=g-long-lambda - ds, - benchmark_name=benchmark, - label_names=_GLUE_LABELS[benchmark]) - ], - copy_pretokenized=True, - debug_print_examples=True, - debug_print_examples_rate=0.05) - dataset_as_numpy = tfds.as_numpy(processed_dataset) - - def stream_of_inputs_targets_weights(generator=None): - del generator - while True: - for example in dataset_as_numpy: - input_values = example['inputs'] - target_values = example['targets'] - yield (jnp.array(input_values), - jnp.array(target_values), - jnp.array([1] * len(target_values))) - - return stream_of_inputs_targets_weights - - -def _t5_glue_benchmark_and_split(benchmark_id): - benchmark, mode = benchmark_id.rsplit('_', 1) - if mode == 't': - split = 'train' - elif benchmark == 'mnli': - split = 'validation_mismatched' if mode == 'e2' else 'validation_matched' - else: - split = 'validation' - return benchmark, split -# pylint: enable=invalid-name - - -def compute_single_result(op_name, num_args): - """An implementation of the most popular ops from the MathQA dataset.""" - # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ - # and specfically line 142 and following in new_DataStructure.py - # for an implementation which covers more details. 
- if op_name == 'add': - return num_args[0] + num_args[1] - elif op_name == 'circle_arc': - return num_args[0] / 360 * math.pi * 2 * num_args[1] - elif op_name == 'circle_area': - return math.pi * num_args[0]**2 - elif op_name == 'circle_sector_area': - return num_args[1] / 360 * math.pi * (num_args[0]**2) - elif op_name == 'circumface': - return 2 * math.pi * num_args[0] - elif op_name == 'choose': - return scipy.special.comb(num_args[0], num_args[1]) - elif op_name == 'cosine': - return math.cos(num_args[0]) - elif op_name == 'cube_edge_by_volume': - return num_args[0]**(1 / 3) - elif op_name == 'combined_work': - return 1 / ( - min(num_args[0], 1 / num_args[0]) + min(num_args[1], 1 / num_args[1])) - elif op_name == 'count_interval': - return num_args[0] - num_args[1] + 1 - elif op_name == 'diagonal': - return math.sqrt(num_args[0]**2 + num_args[1]**2) - elif op_name == 'divide' or op_name == 'speed': - if num_args[1] != 0: - return num_args[0] / num_args[1] - else: - return 0 - elif op_name == 'factorial': - return math.factorial(min(15, int(num_args[0]))) - elif op_name == 'floor': - return math.floor(num_args[0]) - elif op_name == 'find_work': - return 1 / ( - max( - min(num_args[0], 1 / num_args[0]), min( - num_args[1], 1 / num_args[1])) - min( - min(num_args[0], 1 / num_args[0]), - min(num_args[1], 1 / num_args[1]))) - elif op_name == 'from_percent': - return num_args[0] / 100 - elif op_name == 'gain_percent': - return 100 + num_args[0] - elif op_name == 'gcd': - return scipy.gcd(int(num_args[0]), int(num_args[1])) - elif op_name == 'inverse': - if num_args[0] != 0: - return 1 / num_args[0] - else: - return 0 - elif op_name == 'lcm': - return scipy.lcm(int(num_args[0]), int(num_args[1])) - elif op_name == 'log': - return math.log(max(1e-5, num_args[0]), 2) - elif op_name == 'loss_percent': - return 100 - num_args[0] - elif op_name == 'max': - return max(num_args[0], num_args[1]) - elif op_name == 'multiply': - return num_args[0] * num_args[1] - elif op_name == 'negate_percent': - return 100 - num_args[0] - elif op_name == 'negate': - return -num_args[0] - elif op_name == 'original_price_before_loss': - return num_args[1] * 100 / (100 + 1e-5 - num_args[0]) - elif op_name == 'original_price_before_gain': - return num_args[1] * 100 / (100 + num_args[0]) - elif op_name == 'permutation': - n, m = min(num_args[0], num_args[1]), max(num_args[0], num_args[1]) - return math.factorial(int(m)) / math.factorial(int(m - n)) - elif op_name == 'power': - return num_args[0]**min(num_args[1], 5) - elif op_name == 'percent': - return num_args[0] / 100 * num_args[1] - elif op_name == 'price_after_gain' or op_name == 'p_after_gain': - return (1 + num_args[0] / 100) * num_args[1] - elif op_name == 'price_after_loss' or op_name == 'price_after_loss': - return (1 - num_args[0] / 100) * num_args[1] - elif op_name == 'quadrilateral_area': - return num_args[0] * (num_args[1] + num_args[2]) / 2 - elif op_name == 'reminder': - return num_args[0] % num_args[1] - elif op_name == 'rectangle_area': - return num_args[0] * num_args[1] - elif op_name == 'rectangle_perimeter': - return 2 * (num_args[0] + num_args[1]) - elif op_name == 'rhombus_area': - return num_args[0] * num_args[1] / 2 - elif op_name == 'sine': - return math.sin(num_args[0]) - elif op_name == 'sqrt': - return math.sqrt(max(0, num_args[0])) - elif op_name == 'subtract': - return num_args[0] - num_args[1] - elif op_name == 'square_edge_by_perimeter': - return num_args[0] / 4 - elif op_name == 'square_edge_by_area': - return math.sqrt(num_args[0]) - 
elif op_name == 'square_area': - return num_args[0]**2 - elif op_name == 'surface_cube': - return 6 * num_args[0]**2 - elif op_name == 'surface_rectangular_prism': - return 2 * ( - num_args[0] * num_args[1] + num_args[0] * num_args[2] + - num_args[1] * num_args[2]) - elif op_name == 'semi_circle_perimiter': - return math.pi * num_args[0] + 2 * num_args[0] - elif op_name == 'square_perimeter' or op_name == 'rhombus_perimeter': - return 4 * num_args[0] - elif op_name == 'surface_sphere': - return 4 * math.pi * num_args[0]**2 - elif op_name == 'speed_ratio_steel_to_stream': - return (num_args[0] + num_args[1]) / (num_args[0] - num_args[1]) - elif op_name == 'speed_in_still_water': - return (num_args[0] + num_args[1]) / 2 - elif op_name == 'stream_speed': - return (num_args[0] - num_args[1]) / 2 - elif op_name == 'trapezium_area': - return num_args[0] * (num_args[1] + num_args[2]) / 2 - elif op_name == 'triangle_area': - return num_args[0] * num_args[1] / 2 - elif op_name == 'triangle_perimeter': - return num_args[0] + num_args[1] + num_args[2] - elif op_name == 'triangle_area_three_edges': - # Heron's formula - s = (num_args[0] + num_args[1] + num_args[2]) / 2 - return math.sqrt( - max(0, - s * (s - num_args[0]) * (s - num_args[1]) * (s - num_args[2]))) - elif op_name == 'union_prob': - return num_args[0] + num_args[1] - num_args[2] - elif op_name == 'negate_prob': - return 1 - num_args[0] - elif op_name == 'volume_cube': - return num_args[0]**3 - elif op_name == 'volume_cone': - return math.pi * num_args[0]**2 * num_args[1] / 3 - elif op_name == 'volume_cylinder': - return math.pi * num_args[0]**2 * num_args[1] - elif op_name == 'volume_rectangular_prism': - return num_args[0] * num_args[1] * num_args[2] - elif op_name == 'volume_sphere': - return 4 / 3 * math.pi * num_args[0]**3 - - -def compute_result(list_op, list_num): - """Python execution of MathQA ops.""" - # The last of temporary results is the final answer. - temporary_results = [] - for op in list_op: - op_name = op.split('(')[0] - start_bracket = op.find('(') - end_bracket = op.find(')') - op_args = op[start_bracket + 1:end_bracket].split(',') - num_args = [] - for arg in op_args: - # The hash stands for a number stored in temporary_results. - # For example #2 refers to the third temporary result. - if arg[0] == '#': - temp_index = int( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - arg)[0]) - num_args.append(temporary_results[temp_index]) - # The n prefix stands for numbers which listed in list_num - - # originally they were contained in the text. - elif arg[0] == 'n': - n_index = int( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - arg)[0]) - num_args.append(list_num[n_index]) - elif arg[0] == 'c': - if arg == 'const_pi': - constant = math.pi - elif arg == 'const_deg_to_rad': - constant = math.pi / 180 - else: - consts = re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', arg) - if len(consts) == 1: - constant = float(consts[0]) - else: - constant1 = float(consts[0]) - constant2 = float('0.' 
+ consts[1]) - constant = constant1 + constant2 - num_args.append(constant) - temporary_results.append(compute_single_result(op_name, num_args)) - return temporary_results - - -def single_op_to_python_command(op_name, num_args): - """An implementation of the most popular ops from the MathQA dataset.""" - # See https://gitlab.cs.washington.edu/amini91/mathqa-categorization/ - # and specfically line 142 and following in new_DataStructure.py - # for an implementation which covers more details. - if op_name == 'add': - return '{} + {}'.format(num_args[0], num_args[1]) - elif op_name == 'circle_arc': - return '{} / 360 * math.pi * 2 * {}'.format(num_args[0], num_args[1]) - elif op_name == 'circle_area': - return 'math.pi * {}**2'.format(num_args[0]) - elif op_name == 'circle_sector_area': - return '{} / 360 * math.pi * ({}**2)'.format(num_args[1], num_args[0]) - elif op_name == 'circumface': - return '2 * math.pi * {}'.format(num_args[0]) - elif op_name == 'choose': - return 'scipy.special.comb({}, {})'.format(num_args[0], num_args[1]) - elif op_name == 'cosine': - return 'math.cos({})'.format(num_args[0]) - elif op_name == 'cube_edge_by_volume': - return '{}**(1 / 3)'.format(num_args[0]) - elif op_name == 'combined_work': - return '1 / (min({}, 1 / {}) + min({}, 1 / {}))'.format( - num_args[0], num_args[0], num_args[1], num_args[1]) - elif op_name == 'count_interval': - return '{} - {} + 1'.format(num_args[0], num_args[1]) - elif op_name == 'diagonal': - return 'math.sqrt({}**2 + {}**2)'.format(num_args[0], num_args[1]) - elif op_name == 'divide' or op_name == 'speed': - # safe divide - if num_args[1] != 0: - return '{} / {}'.format(num_args[0], num_args[1]) - else: - return '0' - elif op_name == 'factorial': - return 'math.factorial(min(15, int({})))'.format(num_args[0]) - elif op_name == 'floor': - return 'math.floor({})'.format(num_args[0]) - elif op_name == 'find_work': - return ('1 / (max(min({}, 1 / {}), min({}, 1 / {})) - min(min({}, 1 / {}), ' - 'min({}, 1 / {})))').format(num_args[0], num_args[0], num_args[1], - num_args[1], num_args[0], num_args[0], - num_args[1], num_args[1]) - elif op_name == 'from_percent': - return '{} / 100'.format(num_args[0]) - elif op_name == 'gain_percent': - return '100 + {}'.format(num_args[0]) - elif op_name == 'gcd': - return 'scipy.gcd(int({}), int({}))'.format(num_args[0], num_args[1]) - elif op_name == 'inverse': - # safe inverse - if num_args[0] != 0: - return '1 / {}'.format(num_args[0]) - else: - return '0' - elif op_name == 'lcm': - return 'scipy.lcm(int({}), int({}))'.format(num_args[0], num_args[1]) - elif op_name == 'log': - return 'math.log(max(1e-5, {}), 2)'.format(num_args[0]) - elif op_name == 'loss_percent': - return '100 - {}'.format(num_args[0]) - elif op_name == 'max': - return 'max({},{})'.format(num_args[0], num_args[1]) - elif op_name == 'multiply': - return '{} * {}'.format(num_args[0], num_args[1]) - elif op_name == 'negate_percent': - return '100 - {}'.format(num_args[0]) - elif op_name == 'negate': - return '-{}'.format(num_args[0]) - elif op_name == 'original_price_before_loss': - return '{} * 100 / (100 + 1e-5 - {}) # original price before loss'.format( - num_args[1], num_args[0]) - elif op_name == 'original_price_before_gain': - return '{} * 100 / (100 + {}) # original_price_before gain'.format( - num_args[1], num_args[0]) - elif op_name == 'permutation': - return ('math.factorial(int(max({}, {}))) / math.factorial(int(max({}, {}) ' - '- min({}, {}))) # find all permutations').format( - num_args[0], num_args[1], num_args[0], 
num_args[1], num_args[0], - num_args[1]) - elif op_name == 'power': - return '{}**min({}, 5)'.format(num_args[0], num_args[1]) - elif op_name == 'percent': - return '{} / 100 * {}'.format(num_args[0], num_args[1]) - elif op_name == 'price_after_gain' or op_name == 'p_after_gain': - return '(1 + {} / 100) * {}'.format(num_args[0], num_args[1]) - elif op_name == 'price_after_loss' or op_name == 'price_after_loss': - return '(1 - {} / 100) * {}'.format(num_args[0], num_args[1]) - elif op_name == 'quadrilateral_area': - return '{} * ({} + {}) / 2 # quadrilateral area'.format( - num_args[0], num_args[1], num_args[2]) - elif op_name == 'reminder': - return '{} % {}'.format(num_args[0], num_args[1]) - elif op_name == 'rectangle_area': - return '{} * {} # area of rectangle'.format(num_args[0], num_args[1]) - elif op_name == 'rectangle_perimeter': - return '2 * ({} + {}) # perimetere of rectangle'.format( - num_args[0], num_args[1]) - elif op_name == 'rhombus_area': - return '{} * {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'sine': - return 'math.sin({})'.format(num_args[0]) - elif op_name == 'sqrt': - return 'math.sqrt(max(0, {}))'.format(num_args[0]) - elif op_name == 'subtract': - return '{} - {}'.format(num_args[0], num_args[1]) - elif op_name == 'square_edge_by_perimeter': - return '{} / 4. # square edge given perimeter'.format(num_args[0]) - elif op_name == 'square_edge_by_area': - return 'math.sqrt({}) # square edge given area'.format(num_args[0]) - elif op_name == 'square_area': - return '{}**2'.format(num_args[0]) - elif op_name == 'surface_cube': - return '6 * {}**2 # surface of a cube'.format(num_args[0]) - elif op_name == 'surface_rectangular_prism': - return '2 * ({} * {} + {} * {} + {} * {}) # surface of a rectangular prism'.format( - num_args[0], num_args[1], num_args[0], num_args[2], num_args[1], - num_args[2]) - elif op_name == 'semi_circle_perimiter': - return 'math.pi * {} + 2 * {} # perimeter of a semi-circle'.format( - num_args[0], num_args[0]) - elif op_name == 'square_perimeter' or op_name == 'rhombus_perimeter': - return '4 * {}'.format(num_args[0]) - elif op_name == 'surface_sphere': - return '4 * math.pi * {}**2'.format(num_args[0]) - elif op_name == 'speed_ratio_steel_to_stream': - return '({} + {}) / ({} - {})'.format(num_args[0], num_args[1], num_args[0], - num_args[1]) - elif op_name == 'speed_in_still_water': - return '{} + {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'stream_speed': - return '{} - {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'trapezium_area': - return '{} * ({} + {}) / 2'.format(num_args[0], num_args[1], num_args[2]) - elif op_name == 'triangle_area': - return '{} * {} / 2'.format(num_args[0], num_args[1]) - elif op_name == 'triangle_perimeter': - return '{} + {} + {} # perimeter of a triangle'.format( - num_args[0], num_args[1], num_args[2]) - elif op_name == 'triangle_area_three_edges': - return ("(lambda s, a, b, c: math.sqrt(max(0, s * (s - a) * (s - b) * (s - " - "c))))(({} + {} + {}) / 2, {}, {}, {}) # Heron's formula").format( - num_args[0], num_args[1], num_args[2], num_args[0], num_args[1], - num_args[2]) - elif op_name == 'union_prob': - return '{} + {} - {}'.format(num_args[0], num_args[1], num_args[2]) - elif op_name == 'negate_prob': - return '1 - {}'.format(num_args[0]) - elif op_name == 'volume_cube': - return '{}**3'.format(num_args[0]) - elif op_name == 'volume_cone': - return 'math.pi * {}**2 * {} / 3'.format(num_args[0], num_args[1]) - elif op_name == 'volume_cylinder': - return 'math.pi * 
{}**2 * {}'.format(num_args[0], num_args[1]) - elif op_name == 'volume_rectangular_prism': - return '{} * {} * {}'.format(num_args[0], num_args[1], num_args[2]) - elif op_name == 'volume_sphere': - return '4 / 3 * math.pi * {}**3'.format(num_args[0]) - - -def compute_program(list_op): - """Python execution of MathQA ops.""" - # The last of temporary results is the final answer. - temporary_results = [] - num_op = 0 - for op in list_op: - op_name = op.split('(')[0] - start_bracket = op.find('(') - end_bracket = op.find(')') - op_args = op[start_bracket + 1:end_bracket].split(',') - num_args = [] - for arg in op_args: - # The hash stands for a number stored in temporary_results. - # For example #2 refers to the third temporary result. - if arg[0] == '#': - temp_index = int( - re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - arg)[0]) - num_args.append('t{}'.format(temp_index)) - # The n prefix stands for numbers which listed in list_num - - # originally they were contained in the text. - elif arg[0] == 'n': - # n_index = int( - # re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - # arg)[0]) - num_args.append(arg) - elif arg[0] == 'c': - if arg == 'const_pi': - constant = math.pi - elif arg == 'const_deg_to_rad': - constant = math.pi / 180 - else: - consts = re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', arg) - if len(consts) == 1: - constant = float(consts[0]) - else: - constant1 = float(consts[0]) - constant2 = float('0.' + consts[1]) - constant = constant1 + constant2 - num_args.append(str(constant)) - temporary_result = 't{} = {}'.format( - num_op, single_op_to_python_command(op_name, num_args)) - temporary_results.append(temporary_result) - num_op += 1 - return temporary_results - - -def compute_nums(question): - """Finds numbers in a string and convert them to floats.""" - # The funny looking replace is needed to deal with numbers such as 4,000 - # TODO(henrykm) deal with numbers written as words "one", "two", ... - return [ - float(num.replace(',', '')) for num in re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question) - ] - - -def compute_ops(linear_formula): - list_op = linear_formula.split('|') - # In some cases the list of operations contains a superflous last element, - # namely an empty string. - if not list_op[-1]: - list_op = list_op[:-1] - return list_op - - -def process_single_mathqa_example(example): - """Execute a single example and verify coherence of a MathQA problem. - - Args: - example: a dictionary with the following fields: Problem - a natural - language formulation of the problem Rationale - a natural language - solution of the problem options - five possible answers ( a) b) c) d) and - e) ) correct - the letter representing the correct answer - annotated_formula - formula representing the full solution linear_formula - - a string of operations separated by the | character, e.g. - multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)| - multiply(#2,const_100)|divide(#3,#1)| category - a natural language - description of the category to which a given problem belongs. - - Returns: - answer_num: numerical answer contained in the example - python_result: numerical answers computed in Python, including intermediate - results. 
The answer_num should be close to python_result[-1]
-    list_op: list of arithmetic operations
-    list_num: list of identified numbers in the text
-  """
-  question = example['Problem']
-  list_num = compute_nums(question)
-  list_op = compute_ops(example['linear_formula'])
-  answers = example['options']
-  correct_answer = example['correct']
-  index = answers.find('{} )'.format(correct_answer))
-  answer_string = re.findall(
-      r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', answers[index:])
-  # The if statement deals with empty lists - they are needed to treat
-  # a correct non-numerical answer e) None of the above. Here we do not want
-  # non-numerical answers, hence we return None.
-  if answer_string:
-    answer_num = float(
-        re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?',
-                   answers[index:])[0].replace(',', ''))
-  else:
-    return None
-  # The if statement below deals with answers written as fractions, e.g.
-  # a ) 1 / 2 , b ) 1 / 3 , c ) 1 / 5 , d ) 10 / 30 , e ) 2 / 5 ?
-  index_end_of_answer = index + len(str(answer_num)) + 3
-  if index_end_of_answer < len(answers) and answers[index_end_of_answer] == '/':
-    answer_denom = float(
-        re.findall(r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?',
-                   answers[index_end_of_answer:])[0].replace(',', ''))
-    answer_num /= answer_denom
-  python_result = compute_result(list_op, list_num)
-  python_program = compute_program(list_op)
-  return answer_num, python_result, python_program, list_op, list_num
-
-
-def convert_float_to_mathqa(number):
-  floor = int(float(number))
-  if floor == number:
-    return 'const_' + str(floor)
-  else:
-    return 'const_' + str(floor) + '_' + str(number)[len(str(floor)) + 1:]
-
-
-def convert_to_subtract(const_string):
-  return 'subtract({},const_0)'.format(const_string)
-
-
-def execute_mathqa_dsl_program(problem, dsl_code):
-  """Executes the DSL code for a given problem.
-
-  Args:
-    problem: problem formulation (needed to get parameters).
-    dsl_code: DSL code.
-
-  Returns:
-    the result of executing the DSL code.
-  """
-  n0_loc = problem.find('n0')
-  list_num = compute_nums(problem[n0_loc:])
-  # The list contains _all_ numbers in the string, hence in particular
-  # for n0 = 2.0 n1 = 3.0 we are getting list_num = [0.0, 2.0, 1.0, 3.0],
-  # so below we keep only the odd-indexed occurrences.
-  assert len(list_num) % 2 == 0
-  list_num = [list_num[2 * i + 1] for i in range(int(len(list_num) / 2))]
-
-  # dsl_code is a list of strings; since all DSL programs are single liners,
-  # we need to guess the correct line. For now we use the same location as
-  # in the ground truth examples, that is the first line.
-  list_op = compute_ops(dsl_code[0])
-
-  try:
-    results = compute_result(list_op, list_num)[-1]
-  except:  # pylint: disable=bare-except
-    results = None
-  return results
-
-
-def is_number(s):
-  try:
-    float(s)
-    return True
-  except:  # pylint: disable=bare-except
-    return False
-
-
-def execute_mathqa_program(problem, program):
-  """Executes the Python code for a given problem.
-
-  Args:
-    problem: problem formulation (not needed, but we want the same API as
-      in the DSL case).
-    program: Python code.
-
-  Returns:
-    the result of executing the Python code.
-  """
-  del problem  # problem only needed in the DSL version.
-  # Programs are lists of strings. We need to concatenate them in order to exec.
-  program = '\n'.join(program)
-  var_dict = {}
-  try:
-    # The logic of this is the following: if exec with timeout is working
-    # without exceptions, then we can call exec again and gather the variables.
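The exec call this comment describes continues just below. As a self-contained illustration of the same run-and-extract pattern, here is a hedged sketch; the helper name and the restricted-globals allowlist are assumptions of the sketch, not part of the original code.

```python
import math

# Sketch of the run-and-extract pattern: exec the generated source in a
# scratch namespace, then read back the 'answer' variable. Restricting
# globals to an allowlist is an extra precaution taken only in this sketch.
def run_generated_program(program_lines):
    scratch = {}
    try:
        exec('\n'.join(program_lines), {'math': math}, scratch)  # pylint: disable=exec-used
        answer = scratch.get('answer')
        return float(answer) if isinstance(answer, (int, float)) else None
    except Exception:
        return None

print(run_generated_program(['n0 = 3', 'n1 = 4', 'answer = math.sqrt(n0**2 + n1**2)']))  # 5.0
```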
-    exec(program, globals(), var_dict)  # pylint: disable=exec-used
-    if 'answer' in var_dict and is_number(var_dict['answer']):
-      return float(var_dict['answer'])
-    else:
-      return None
-  except:  # pylint: disable=bare-except
-    return None
-
-
-@gin.configurable(module='trax.data')
-def CreateMathQAInputs(  # pylint: disable=invalid-name
-    dataset_path=None,
-    train=True,
-    test=False,
-    challenge=False,
-    tolerance=0.01,
-    cumulative=True,
-    python_code=False,
-    full_dict=False,
-    partial_results=True,
-    nlp_rationale=False,
-    correct_answer=False,
-    answer_in_mathqa_format=True,
-    correct_answer_given_reasoning=False,
-    category=False,
-    order_prediction=False,
-    reduced_operation_name=True,
-    qed=False):
-  """Prepares MathQA inputs.
-
-  The generation procedure leaves a lot of parameters to be set by the user.
-  Currently we support only correct examples in the following sense:
-  Python execution agrees with the declared answer up to 1%.
-
-  According to this criterion, wrong examples such as
-  problem: calculate 85184 ÷ ? = 352
-  operations ['multiply(n0,n1)']
-  are ignored (this should be divide(n0,n1) in this case).
-
-  Args:
-    dataset_path: a path with the MathQA dataset.
-    train: if True, then generate training examples; if train, test and
-      challenge are all set to False, generate validation examples.
-    test: if train is set to False and test is set to True,
-      then generate test examples.
-    challenge: if train and test are set to False and challenge is set to True,
-      then generate challenge examples.
-    tolerance: if, for a given example, the relative difference between the
-      Python result and the result declared in the dataset exceeds this level,
-      then the example is dropped; tolerances ranging from 0.1 to 0.001 yield
-      from 18K to 21K examples.
-    cumulative: if set to True, then generate examples in the format
-      input - problem + numbers + op1 + op2 + op3, target - op4. If set to
-      False, then examples are in the format input - problem + numbers,
-      target - all operations.
-    python_code: if set to True, then generates Python code instead of
-      MathQA commands.
-    full_dict: if set to True, then Python examples are returned together with
-      the DSL code and the NLP rationale.
-    partial_results: if set to True, then partial results will be reported as
-      part of the input, e.g. input - problem + numbers + op1 + #1 + op2 + #2 +
-      op3 + #3, target - op4, where #k is the partial result from operation
-      opk. Active only if cumulative is set to True.
-    nlp_rationale: if set to True, then input is the problem and the target is
-      the NLP rationale.
-    correct_answer: if set to True, then input is the problem plus all possible
-      answers and the target is the correct answer.
-    answer_in_mathqa_format: if set to True, then convert the numerical answer
-      to the MathQA format and wrap it in the subtract operation.
-      E.g. "3.13" is converted to "subtract(const_3_13,const_0)".
-    correct_answer_given_reasoning: if set to True, then input is the problem
-      plus linear formula plus all possible answers and the target is the
-      correct answer.
-    category: if set to True, then input is the problem and the target is its
-      category.
-    order_prediction: if set to True, then input is the problem and a list of
-      all operations; with probability 0.5 two operations are swapped; the task
-      consists in detecting whether the operations were swapped. See the
-      order prediction task in CreateAquaInputs in this file.
-    reduced_operation_name: if set to True, then in order prediction consider
-      only the operation token without parameters.
-    qed: if set to True, then the reasoning is finished with an additional
-      operation qed.
-
-  Returns:
-    mathqa_yield_examples: a generator of MathQA examples; the generator yields
-      non-tokenized examples - they can be further processed using for example
-      the tokenize function from this module
-  """
-  if train:
-    dataset_path = os.path.join(dataset_path, 'train.json')
-  elif test:
-    dataset_path = os.path.join(dataset_path, 'test.json')
-  elif challenge:
-    dataset_path = os.path.join(dataset_path, 'challenge_test.json')
-  else:
-    dataset_path = os.path.join(dataset_path, 'dev.json')
-  # Opening with GFile allows using remotely stored files, e.g.
-  # in a gs bucket.
-  dataset_handle = tf.io.gfile.GFile(dataset_path, 'r')
-  dataset = json.load(dataset_handle)
-
-  def mathqa_yield_examples(generator=None):
-    del generator
-    while True:
-      for example in itertools.cycle(dataset):
-        result = process_single_mathqa_example(example)
-        # TODO(henrykm): Remove the first two ifs.
-        if not result:
-          continue
-        answer_num, python_result, python_program, list_op, list_num = result
-        if not answer_num or not python_result[-1]:
-          continue
-        if qed:
-          list_op.append('qed')
-        if math.isclose(answer_num, python_result[-1], rel_tol=tolerance):
-          input_prefix = example['Problem']
-          for i in range(len(list_num)):
-            input_prefix += ' n{} = {}'.format(i, list_num[i])
-          if cumulative:
-            for i in range(len(list_op)):
-              input_values = input_prefix
-              target_values = list_op[i]
-              input_prefix += ' ' + list_op[i]
-              if partial_results:
-                input_prefix += ' #{} = {}'.format(i, answer_num)
-              yield input_values, target_values, np.array([1] *
-                                                          len(target_values))
-          elif python_code:
-            input_values = '# ' + input_prefix
-            target_values = ''
-            for command in python_program:
-              if 'math' in command:
-                target_values += 'import math\n'
-                break
-            for command in python_program:
-              if 'scipy' in command:
-                target_values += 'import scipy\n'
-                break
-            for i in range(len(list_num)):
-              target_values += 'n{} = {}\n'.format(i, list_num[i])
-            target_values += '\n'.join(python_program[:-1])
-            final_line = python_program[-1].split('=')[1]
-            target_values += '\nanswer ={}'.format(final_line)
-            var_dict = {}
-            # We generate Python code and want to check whether the answer
-            # is correct.
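The verification exec resumes just below. For intuition, here is a hypothetical example (not taken from the dataset) of the Python target this branch assembles for the linear formula 'add(n0,n1)|multiply(#0,n2)|':

```python
# Hypothetical illustration of a target assembled by the python_code branch
# for linear_formula 'add(n0,n1)|multiply(#0,n2)|' with list_num [2.0, 3.0,
# 4.0]; compute_program turns each op into 't{i} = ...' and the last
# temporary becomes 'answer'.
target_values = (
    'n0 = 2.0\n'
    'n1 = 3.0\n'
    'n2 = 4.0\n'
    't0 = n0 + n1\n'
    'answer = t0 * n2'
)
var_dict = {}
exec(target_values, globals(), var_dict)  # pylint: disable=exec-used
assert var_dict['answer'] == 20.0
```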
-            exec(target_values, globals(), var_dict)  # pylint: disable=exec-used
-            if math.isclose(answer_num, var_dict['answer'], rel_tol=tolerance):
-              if full_dict:
-                yield input_values, target_values, example[
-                    'linear_formula'], example['Rationale']
-              else:
-                yield input_values, target_values, np.array([1] *
-                                                            len(target_values))
-          elif nlp_rationale:
-            input_values = 'infer full rationale: ' + input_prefix
-            target_values = example['Rationale']
-            yield input_values, target_values, np.array([1] *
-                                                        len(target_values))
-          elif correct_answer:
-            input_values = 'infer correct answer: ' + input_prefix
-            input_values += ' ' + example['options']
-            if answer_in_mathqa_format:
-              target_values = str(answer_num)
-              target_values = convert_to_subtract(
-                  convert_float_to_mathqa(target_values))
-            else:
-              target_values = example['correct']
-            yield input_values, target_values, np.array([1] *
-                                                        len(target_values))
-          elif correct_answer_given_reasoning:
-            input_values = 'infer correct answer given reasoning: ' + input_prefix
-            input_values += ' ' + ' '.join(list_op) + ' ' + example['options']
-            target_values = example['correct']
-            yield input_values, target_values, np.array([1] *
-                                                        len(target_values))
-          elif category:
-            input_values = 'infer category: ' + input_prefix
-            target_values = example['category']
-            yield input_values, target_values, np.array([1] *
-                                                        len(target_values))
-          elif order_prediction:
-            if np.random.uniform() < 0.5 and len(list_op) >= 2:
-              idx = range(len(list_op))
-              i1, i2 = random.sample(idx, 2)
-              list_op[i1], list_op[i2] = list_op[i2], list_op[i1]
-              target_values = 'not_ordered'
-            else:
-              target_values = 'ordered'
-            if reduced_operation_name:
-              list_op = [op.split('(')[0] for op in list_op]
-            input_values = 'order prediction: ' + input_prefix + ' ' + ' '.join(
-                list_op)
-            yield input_values, target_values, np.array([1] *
-                                                        len(target_values))
-          else:
-            input_values = 'infer full calculation: ' + input_prefix
-            target_values = example['linear_formula']
-            yield input_values, target_values, np.array([1] *
-                                                        len(target_values))
-
-  return mathqa_yield_examples
-
-
-@gin.configurable(module='trax.data')
-def CreateAquaInputs(  # pylint: disable=invalid-name
-    dataset_path=None,
-    train=True,
-    cumulative=False,
-    rationale=False,
-    correct_answer=False,
-    correct_answer_given_reasoning=False,
-    partial_reasoning=True,
-    order_prediction=False):
-  """Prepares Aqua inputs.
-
-  Args:
-    dataset_path: a path with the Aqua dataset.
-    train: if True, then generate training examples, otherwise generate
-      validation examples (the dataset also has a test set).
-    cumulative: if set to True, then generate examples in the format
-      input - problem + step1 + step2 + step3, target - step4. If set to False,
-      then examples are in the format input - problem, target - all operations.
-    rationale: if set to True, then input is the problem and the target is the
-      rationale.
-    correct_answer: if set to True, then input is the problem plus all possible
-      answers and the target is the correct answer.
-    correct_answer_given_reasoning: if set to True, then input is the problem
-      plus reasoning (aka rationale) plus all possible answers and the target is
-      the correct answer.
-    partial_reasoning: an additional option related to
-      correct_answer_given_reasoning; if set to True, then we take a random
-      prefix of the reasoning.
- order_prediction: if set to True, then input is the problem and a list of - all operations; with probability 0.5 two operations are swapped; the task - consists in detecting whether the operations were swapped. A similar - additional task was considered in https://arxiv.org/pdf/1909.11942.pdf and - in a recent work of Piotr Piękos, henrykm@ and mateuszm@. - - Returns: - aqua_yield_examples: a generator of Aqua examples; the generator yields - non-tokenized examples - they can be further processed using for example - the tokenize function from this module - """ - if train: - dataset_path = os.path.join(dataset_path, 'train.json') - else: - dataset_path = os.path.join(dataset_path, 'dev.json') - # Opening with GFile allows to use remotely stored files, e.g. - # in a gs bucket. - dataset_handle = tf.io.gfile.GFile(dataset_path, 'r') - dataset = [] - for line in dataset_handle: - dataset.append(json.loads(line)) - - def aqua_yield_examples(generator=None): - del generator - while True: - for example in itertools.cycle(dataset): - input_prefix = example['question'] - steps = example['rationale'].split('\n') - if cumulative: - for i in range(len(steps)): - input_values = 'infer cumulative rationale: ' + input_prefix - target_values = steps[i] - input_prefix += ' ' + steps[i] - yield input_values, target_values, np.array([1] * - len(target_values)) - elif rationale: - input_values = 'infer full rationale: ' + input_prefix - target_values = example['rationale'] - yield input_values, target_values, np.array([1] * len(target_values)) - elif correct_answer: - input_values = 'infer correct answer: ' + input_prefix - input_values += ' ' + ' '.join(example['options']) - target_values = example['correct'] - yield input_values, target_values, np.array([1] * len(target_values)) - elif correct_answer_given_reasoning: - input_values = 'infer correct answer given reasoning: ' + input_prefix - if partial_reasoning: - reasoning_list = example['rationale'].split('\n') - reasoning_list = reasoning_list[0:np.random - .randint(0, len(reasoning_list))] - reasoning = '\n'.join(reasoning_list) - else: - reasoning = example['rationale'] - input_values += ' ' + example['rationale'] + ' ' + ' '.join( - example['options']) - target_values = example['correct'] - yield input_values, target_values, np.array([1] * len(target_values)) - elif order_prediction: - if np.random.uniform() < 0.5 and len(steps) >= 2: - idx = range(len(steps)) - i1, i2 = random.sample(idx, 2) - steps[i1], steps[i2] = steps[i2], steps[i1] - target_values = 'not_ordered' - else: - target_values = 'ordered' - input_values = 'order prediction: ' + input_prefix + ' ' + '\n'.join( - steps) - yield input_values, target_values, np.array([1] * len(target_values)) - else: - raise ValueError( - 'One of the boolean parameters of the Aqua generator must be set to True.' - ) - - return aqua_yield_examples - - -@gin.configurable(module='trax.data') -def CreateDropInputs( # pylint: disable=invalid-name - train=True, mathqa_format=False): - """Prepares Drop inputs. - - Args: - train: if True, then generate training examples, otherwhise generate - validation examples (the dataset has also a test set). - mathqa_format: if True, then floats in targets are converted to the - the MathQA convention and wrapped in the subtract operation. - E.g. "3.13" is converted to "subtract(const_3_13,const_0)". 
- - Returns: - drop_yield_examples: a generator of Drop examples; the generator yields - non-tokenized examples - they can be further processed using for example - the tokenize function from this module - """ - if train: - dataset = tfds.load(name='drop', split='train') - else: - dataset = tfds.load(name='drop', split='dev') - dataset = tfds.as_numpy(dataset) - - def drop_yield_examples(generator=None): - del generator - while True: - for example in itertools.cycle(dataset): - input_values = 'drop question: ' + example['passage'].decode( - 'utf-8') + ' ' + example['question'].decode('utf-8') - target_values = example['answer'].decode('utf-8') - # Apparently the dataset has some empty "target values" - - # when such a value is encountered, the Tokenizer decides to assign - # to it a float32 tensor and the training fails. - if not target_values: - continue - if mathqa_format: - if target_values.replace('.', '', 1).isdigit(): - target_values = convert_to_subtract( - convert_float_to_mathqa(target_values)) - yield input_values, target_values, np.array( - [1] * len(target_values), dtype=np.int32) - - return drop_yield_examples - - -@gin.configurable(module='trax.data') -def CreateAnnotatedDropInputs( # pylint: disable=invalid-name - dataset_path=None, - train=True, - single_file=True, - unique=False, - total_number_of_samples=None, - percentile=1.): - r"""Prepares annotated Drop inputs. - - Example of an annotated input which can be used with this interface: - - { - 'passage': 'The Armenian Prelature of Cyprus was established in 973 by - Catholicos Khatchig I. Historically, the Prelature has been under the - jurisdiction of the Catholicosate of the Great House of Cilicia, while today - it is the oldest theme that falls under its jurisdiction. Since 2014 the - Prelate, a Catholicosal Vicar General, has been Archbishop Nareg Alemezian. - The parish priest in Nicosia is Fr. Momik Habeshian, while the parish priest - in Larnaca and Limassol is Fr. Mashdots Ashkarian. For centuries, the - Prelature building was located within the Armenian compound in Victoria - street in walled Nicosia; when that area was taken over by Turkish-Cypriot - extremists in 1963-1964, the Prelature was temporarily housed in Aram - Ouzounian street and, later on, in Kyriakos Matsis street in Ayios - Dhometios. Thanks to the efforts of Bishop Zareh Aznavorian and with - financial aid from the Evangelical Church of Westphalia, the new Prelature - building was erected in 1983, next to the Virgin Mary church and the Nareg - school in Nicosia, by architects Athos Dikaios & Alkis Dikaios; it was - officially inaugurated on 4 March 1984, during the pastoral visit of - Catholicos Karekin II. By initiative of Archbishop Varoujan Hergelian, in - 1998 the basement of the building was renovated and the "Vahram Utidjian" - Hall was formed; previously a store room, it became a reality from the - proceeds of the auction in 1994 of the art collection that Vahram Utidjian - had donated to the Prelature in 1954. It was inaugurated on 3 February 1999 - by Catholicos Aram I; numerous charity, communal and cultural events take - place there. 
The Prelature\'s consistory houses a collection of
-    ecclesiastical relics, some of which were previously in the old Virgin Mary
-    church or the Magaravank.',
-    'question': 'How many years after the Vahram Utidjian was donated to the
-    Prelature was it sold at an auction?',
-    'answer': 40,
-    'calculation': 'subtract(n8,n9)'
-  }
-
-  In this example the calculation is formulated using the notation from the
-  MathQA dataset, but this is not required. subtract(n8,n9) means that the
-  answer 40 can be obtained through the subtraction of the 9th and the 10th
-  number in the input. The input consists of the passage concatenated with the
-  question. The annotations can be generated using, for example, a method
-  from the paper https://arxiv.org/abs/1909.00109.
-
-  Args:
-    dataset_path: a path with the annotated Drop dataset.
-    train: if True, then generate training examples, otherwise generate
-      validation examples (the dataset also has a test set).
-    single_file: if True, then look just for one file. If False, read all
-      json files in a given directory and assume that each file contains one
-      example. Applied only to training data.
-    unique: if set to True, then the generator will provide at most one question
-      per passage.
-    total_number_of_samples: if set to a positive integer, then the total number
-      of unique samples will be bounded by total_number_of_samples.
-    percentile: the percentile of the train dataset used for training; defaults
-      to 1., though setting a lower value can be interesting when the train set
-      is combined with another source of data.
-
-  Returns:
-    drop_annotated_yield_examples: a generator of annotated Drop examples;
-      the generator yields non-tokenized examples - they can be further processed
-      using for example the tokenize function from this module.
-  """
-  if train:
-    if single_file:
-      dataset_path = os.path.join(dataset_path, 'train_annotated.json')
-  else:
-    dataset_path = os.path.join(dataset_path, 'dev_annotated.json')
-
-  def load_dataset():
-    dataset = []
-    if single_file:
-      # Opening with GFile allows using remotely stored files, e.g.
-      # in a gs bucket.
-      dataset_handle = tf.io.gfile.GFile(dataset_path, 'r')
-      for line in dataset_handle:
-        dataset.append(json.loads(line))
-    else:
-      all_files = tf.io.gfile.listdir(dataset_path)
-      for filename in all_files:
-        if 'json' in filename:
-          print('Loading data from file {}'.format(filename))
-          with tf.io.gfile.GFile(os.path.join(dataset_path, filename)) as f:
-            for line in f:
-              dataset.append(json.loads(line))
-    print('The total size of the dataset {}'.format(len(dataset)))
-    return dataset[:int(len(dataset) * percentile)]
-
-  def drop_annotated_yield_examples(generator=None):
-    del generator
-    while True:
-      passages = set()
-      unique_examples = set()
-      # Notice that below we enable a poor man's RL loop,
-      # aka the DAgger algorithm: https://arxiv.org/pdf/1011.0686.pdf
-      # tl;dr: after parsing all examples we re-load the dataset - this
-      # may come in handy if a prediction service generates new examples.
-      dataset = load_dataset()
-      for example in dataset:
-        # If total_number_of_samples is not None and we have reached this
-        # number of samples, then we re-load the dataset.
-        if total_number_of_samples:
-          if len(unique_examples) >= total_number_of_samples:
-            break
-        # Do we have a pre-calculated input in the example?
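The branch answering that question follows below. When no pre-calculated input is present, the generator annotates the raw passage and question with the numbers it finds; that step is sketched here, with an invented snippet of text (the variable names mirror the code below, but the example is hypothetical).

```python
import re

# Sketch of the number-annotation step performed further below for original
# Drop examples: extract the numbers with the same regex the module uses
# everywhere and append them to the question as n0, n1, ...
NUM_RE = r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?'

question = 'donated to the Prelature in 1954; it was sold at auction in 1994.'
list_num = [float(num.replace(',', '')) for num in re.findall(NUM_RE, question)]
for i, num in enumerate(list_num):
    question += ' n{} = {}'.format(i, num)
print(question)  # ... n0 = 1954.0 n1 = 1994.0
```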
- if 'input' in example.keys(): - question = example['input'] - # Remove the old prompt - question = question[question.find(':') + 2:] - else: - # If input is not present, then we expect that this is an - # original drop example. - if unique and example['passage'] in passages: - continue - passages.add(example['passage']) - question = example['passage'] + ' ' + example['question'] - list_num = [ - float(num.replace(',', '').rstrip('.').lstrip('.')) # pylint: disable=g-complex-comprehension - for num in re.findall( - r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', - question) - ] - for i in range(len(list_num)): - question += ' n{} = {}'.format(i, list_num[i]) - input_values = 'drop annotated question: ' + question - target_values = example['calculation'] - unique_examples.add((input_values, target_values)) - yield input_values, target_values, np.array( - [1] * len(target_values), dtype=np.int32) - - return drop_annotated_yield_examples diff --git a/trax/data/tf_inputs_test.py b/trax/data/tf_inputs_test.py deleted file mode 100644 index 376f59c2b..000000000 --- a/trax/data/tf_inputs_test.py +++ /dev/null @@ -1,873 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data.tf_inputs.""" - -import collections -import os -from unittest import mock - -import gin -import numpy as np -from t5.data import assert_dataset -from t5.data import preprocessors as t5_processors -import tensorflow as tf -import tensorflow_datasets as tfds -from trax.data import inputs # pylint: disable=unused-import -from trax.data import tf_inputs - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') - - -def _test_dataset_ints(inp_lengths, tgt_lengths): - """Create a test dataset of int64 tensors of given shapes.""" - - def generator(): - for inp_len, tgt_len in zip(inp_lengths, tgt_lengths): - inp = np.ones([inp_len], dtype=np.int64) - tgt = np.ones([tgt_len], dtype=np.int64) - yield {'inputs': inp, 'targets': tgt} - - types = {'inputs': tf.int64, 'targets': tf.int64} - shapes = {'inputs': tf.TensorShape([None]), 'targets': tf.TensorShape([None])} - return tf.data.Dataset.from_generator( - generator, output_types=types, output_shapes=shapes) - - -def _load_dataset(name, split='train'): - return tfds.load( - name=name, split=split, data_dir=_TESTDATA, shuffle_files=False) - - -def _c4_dataset(split='train'): - return _load_dataset('c4:2.3.0', split=split) - - -def _spm_path(): - return os.path.join(_TESTDATA, 'sentencepiece.model') - - -def _t5_gin_config(): - # The following pages worth of gin configuration are required because a lot - # of T5 functions have `gin.REQUIRED` in code, i.e. you cannot use these - # functions at all without having configured gin. 
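A minimal illustration of the pattern that comment describes; make_denoiser is a made-up configurable for this sketch, not part of T5 or Trax.

```python
import gin

# Minimal illustration of why _t5_gin_config is needed: a parameter declared
# as gin.REQUIRED raises at call time unless a value is bound first.
@gin.configurable
def make_denoiser(noise_density=gin.REQUIRED):
    return noise_density

gin.bind_parameter('make_denoiser.noise_density', 0.15)
assert make_denoiser() == 0.15
```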
- - noise_density = 0.15 - max_input_length = 50 - - # What preprocessors to apply - we select a random chunk of the document if - # it exceeds a certain lengths (`select_random_chunk`), then split up long - # examples (`split_tokens`) and finally the denoising objective (`denoise`). - # - # In addition to this T5 concates multiple documents together to reduce - # padding (`reduce_concat_tokens`) after `select_random_chunk`, but we skip - # that since we don't do sequence packing. - gin.bind_parameter('unsupervised.preprocessors', [ - t5_processors.select_random_chunk, - t5_processors.split_tokens, - t5_processors.denoise, - ]) - - # select_random_chunk - gin.bind_parameter('select_random_chunk.feature_key', 'targets') - gin.bind_parameter('select_random_chunk.max_length', max_input_length) - - # reduce_concat_tokens - gin.bind_parameter('random_spans_helper.extra_tokens_per_span_inputs', 1) - gin.bind_parameter('random_spans_helper.extra_tokens_per_span_targets', 1) - gin.bind_parameter('random_spans_helper.inputs_length', max_input_length) - gin.bind_parameter('random_spans_helper.mean_noise_span_length', 3.0) - gin.bind_parameter('random_spans_helper.noise_density', noise_density) - - # split_tokens - gin.bind_parameter('split_tokens.max_tokens_per_segment', - t5_processors.random_spans_tokens_length()) - - # denoise - gin.bind_parameter('denoise.inputs_fn', - t5_processors.noise_span_to_unique_sentinel) - gin.bind_parameter('denoise.noise_density', noise_density) - gin.bind_parameter('denoise.noise_mask_fn', - t5_processors.random_spans_noise_mask) - gin.bind_parameter('denoise.targets_fn', - t5_processors.nonnoise_span_to_unique_sentinel) - - -class TFInputsTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - - def test_TFDS_single_host_with_eval_holdout(self): - train_ds_gen = tf_inputs.TFDS( - 'c4/en:2.3.0', - data_dir=_TESTDATA, - train=True, - host_id=0, - keys=('text',), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. - for d in train_ds_gen(): - print(f'Train: {d}') - break - - valid_ds_gen = tf_inputs.TFDS( - 'c4/en:2.3.0', - data_dir=_TESTDATA, - train=False, - host_id=0, - keys=('text',), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. - for d in valid_ds_gen(): - print(f'Eval: {d}') - break - - def test_TFDS_single_host_with_eval_holdout_no_valid_split(self): - train_ds_gen = tf_inputs.TFDS( - 'para_crawl/ende', - data_dir=_TESTDATA, - train=True, - host_id=0, - keys=('en', 'de'), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. - for d in train_ds_gen(): - print(f'Train: {d}') - break - - # para_crawl doesn't have a validation set, see that this still doesn't - # crash because of eval_holdout_set. - valid_ds_gen = tf_inputs.TFDS( - 'para_crawl/ende', - data_dir=_TESTDATA, - train=False, - host_id=0, - keys=('en', 'de'), - n_hosts=1, - eval_holdout_size=0.1) - - # Just ensure that this doesn't crash. 
- for d in valid_ds_gen(): - print(f'Eval: {d}') - break - - def test_TFDS_mnli_split_is_eval(self): - with mock.patch('tensorflow_datasets.load') as tfds_load: - with mock.patch('trax.data.tf_inputs.download_and_prepare', - lambda _, data_dir: data_dir): - _ = tf_inputs.TFDS('glue/mnli', - keys=('premise', 'hypothesis'), - train=False) - call_kwargs = tfds_load.call_args[1] - self.assertEqual(call_kwargs['split'], 'validation_matched') - - def test_TFDS_mnli_split_is_alt_eval(self): - with mock.patch('tensorflow_datasets.load') as tfds_load: - with mock.patch('trax.data.tf_inputs.download_and_prepare', - lambda _, data_dir: data_dir): - _ = tf_inputs.TFDS('glue/mnli', - keys=('premise', 'hypothesis'), - train=False, - use_alt_eval=True) - call_kwargs = tfds_load.call_args[1] - self.assertEqual(call_kwargs['split'], 'validation_mismatched') - - def test_convert_to_unicode(self): - - def dataset1(): - yield (b'Audentes fortuna iuvat.', b'Fortune favors the bold.') - - def dataset2(): - yield (b'\x81aabb', b'Value') - - convert_function1 = tf_inputs.ConvertToUnicode(keys=[0]) - convert_output1 = next(convert_function1(dataset1())) - self.assertEqual(convert_output1[0], 'Audentes fortuna iuvat.') - self.assertEqual(convert_output1[1], b'Fortune favors the bold.') - self.assertIsInstance(convert_output1[0], str) - self.assertIsInstance(convert_output1[1], bytes) - - # Contains an invalid bytes array from the point of view of UTF-8. - try: - convert_function2 = tf_inputs.ConvertToUnicode(keys=[0]) - convert_output2 = next(convert_function2(dataset2())) - except UnicodeDecodeError: - self.fail('ConvertToUnicode threw UnicodeDecodeError.') - self.assertEqual(convert_output2[0], 'aabb') - self.assertIsInstance(convert_output2[0], str) - - def test_tokenize_detokenize(self): - - def dataset(): - yield 'I have a cat.' - - # Character-level. - tok_char = list(tf_inputs.tokenize(dataset(), vocab_type='char')) - self.assertAllEqual(tok_char[0], - np.array([ord(c) for c in 'I have a cat.'])) - detok = tf_inputs.detokenize(tok_char[0], vocab_type='char') - self.assertEqual(detok, 'I have a cat.') - - # Sentencepiece. - tok_spc = list( - tf_inputs.tokenize( - dataset(), - vocab_type='sentencepiece', - vocab_dir=_TESTDATA, - vocab_file='sentencepiece.model')) - self.assertAllEqual(tok_spc[0], np.array([27, 43, 3, 9, 1712, 5])) - detok = tf_inputs.detokenize( - list(tok_spc[0]), - vocab_type='sentencepiece', - vocab_dir=_TESTDATA, - vocab_file='sentencepiece.model') - self.assertEqual(detok, 'I have a cat.') - - # Subword. 
- tok_sbw = list( - tf_inputs.tokenize( - dataset(), - vocab_type='subword', - vocab_dir=_TESTDATA, - vocab_file='en_8k.subword')) - self.assertAllEqual(tok_sbw[0], np.array([139, 96, 12, 2217, 2, 21])) - detok = tf_inputs.detokenize( - tok_sbw[0], - vocab_type='subword', - vocab_dir=_TESTDATA, - vocab_file='en_8k.subword') - self.assertEqual(detok, 'I have a cat.') - - # bert-lowercase - tok_sbw = list( - tf_inputs.tokenize( - dataset(), - vocab_type='bert-lowercase', - vocab_dir=_TESTDATA, - vocab_file='bert_uncased_vocab.txt')) - self.assertAllEqual(tok_sbw[0], np.array([1045, 2031, 1037, 4937, 1012])) - detok = tf_inputs.detokenize( - tok_sbw[0], - vocab_type='bert-lowercase', - vocab_dir=_TESTDATA, - vocab_file='bert_uncased_vocab.txt') - self.assertEqual(detok, 'i have a cat .') - # note: BERT tokenizer is not reversible, therefore - # difference between original input - - def test_tokenize_keys_reservedids(self): - - def dataset(): - yield ('Cat.', 'Dog.') - - tok_char1 = list( - tf_inputs.tokenize(dataset(), vocab_type='char', n_reserved_ids=5)) - self.assertAllEqual(tok_char1[0][0], np.array([ord(c) + 5 for c in 'Cat.'])) - self.assertAllEqual(tok_char1[0][1], np.array([ord(c) + 5 for c in 'Dog.'])) - - tok_char2 = list( - tf_inputs.tokenize( - dataset(), keys=[0], vocab_type='char', n_reserved_ids=2)) - self.assertAllEqual(tok_char2[0][0], np.array([ord(c) + 2 for c in 'Cat.'])) - self.assertEqual(tok_char2[0][1], 'Dog.') - - def test_tokenize_dict(self): - - def dataset(): - yield {'a': 'Cat.', 'b': 'Dog.'} - - tok_char1 = list(tf_inputs.tokenize(dataset(), vocab_type='char')) - self.assertAllEqual(tok_char1[0]['a'], np.array([ord(c) for c in 'Cat.'])) - self.assertAllEqual(tok_char1[0]['b'], np.array([ord(c) for c in 'Dog.'])) - - tok_char2 = list( - tf_inputs.tokenize(dataset(), keys=['a'], vocab_type='char')) - self.assertAllEqual(tok_char2[0]['a'], np.array([ord(c) for c in 'Cat.'])) - self.assertEqual(tok_char2[0]['b'], 'Dog.') - - def test_vocab_size(self): - # Character-level. - char_size = tf_inputs.vocab_size(vocab_type='char', n_reserved_ids=11) - self.assertEqual(char_size, 256 + 11) - # Sentencepiece. - spc_size = tf_inputs.vocab_size( - vocab_type='sentencepiece', - vocab_dir=_TESTDATA, - vocab_file='sentencepiece.model') - self.assertEqual(spc_size, 32000) - # Subword. - sbw_size = tf_inputs.vocab_size( - vocab_type='subword', vocab_dir=_TESTDATA, vocab_file='en_8k.subword') - self.assertEqual(sbw_size, 8183) - # Bert_uncased. - sbw_size = tf_inputs.vocab_size( - vocab_type='bert-lowercase', - vocab_dir=_TESTDATA, - vocab_file='bert_uncased_vocab.txt') - self.assertEqual(sbw_size, 30522) - - def test_c4_bare_preprocess_fn(self): - dataset = _c4_dataset() - - example = list(tfds.as_numpy(dataset.take(1)))[0] - - # Targets are NOT in the example. - self.assertNotIn('targets', example) - self.assertIn('text', example) - text = example['text'] - - # This should convert the dataset to an inputs/targets that are tokenized. - dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path()) - - example = list(tfds.as_numpy(dataset.take(1)))[0] - - # Earlier text is now stored in targets_pretokenized - self.assertIn('targets_pretokenized', example) - self.assertEqual(example['targets_pretokenized'], text) - - # Targets are now tokenized. 
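For reference while reading the assertions that follow: the ids come from the same sentencepiece round-trip checked in test_tokenize_detokenize above, and the trailing 1 is the T5/SentencePiece EOS convention (0 is padding).

```python
import numpy as np

# Per the sentencepiece round-trip in test_tokenize_detokenize above,
# 'I have a cat.' tokenizes to [27, 43, 3, 9, 1712, 5]; with the T5
# convention of appending EOS (id 1), a tokenized target looks like:
toy_targets = np.array([27, 43, 3, 9, 1712, 5, 1], dtype=np.int64)
assert toy_targets.dtype == np.int64
assert toy_targets[-1] == 1  # EOS at the end, as the assertions below expect
```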
-    self.assertIn('targets', example)
-    self.assertIsInstance(example['targets'], np.ndarray)
-    self.assertEqual(example['targets'].dtype, np.int64)
-    self.assertGreater(len(example['targets']), 0)
-    self.assertEqual(example['targets'][-1], 1)  # we add EOS at the end.
-
-    # Inputs exist but are empty, because the T5 'unsupervised' preprocessor
-    # wasn't gin-configured with any.
-    self.assertIn('inputs', example)
-    self.assertEqual(len(example['inputs']), 0)
-
-  def test_c4_preprocess(self):
-
-    def load_c4_dataset(split='train'):
-      dataset = _c4_dataset(split=split)
-      return dataset.map(lambda example: (example, example['text']))
-
-    def examine_processed_dataset(proc_dataset):
-      count = 0
-      lengths = []
-      for example in tfds.as_numpy(proc_dataset):
-        count += 1
-        ex = example[0]
-        # Targets are in the example.
-        self.assertIn('targets', ex)
-        self.assertEqual(ex['targets'].dtype, np.int64)
-        lengths.append(len(ex['targets']))
-      return count, lengths
-
-    unfiltered_count = 0
-    for example in tfds.as_numpy(load_c4_dataset()):
-      unfiltered_count += 1
-      # Targets are NOT in the example.
-      self.assertNotIn('targets', example[0])
-
-    proc_dataset = tf_inputs.c4_preprocess(load_c4_dataset(), False, 2048)
-
-    # `examine_processed_dataset` has some asserts in it.
-    proc_count, char_lengths = examine_processed_dataset(proc_dataset)
-
-    # Both the original and filtered datasets have examples.
-    self.assertGreater(unfiltered_count, 0)
-    self.assertGreater(proc_count, 0)
-
-    # Because we filter out some entries based on length.
-    self.assertLess(proc_count, unfiltered_count)
-
-    # Preprocess using the sentencepiece model in testdata.
-    spc_proc_dataset = tf_inputs.c4_preprocess(
-        load_c4_dataset(),
-        False,
-        2048,
-        tokenization='spc',
-        spm_path=_spm_path())
-
-    spc_proc_count, spc_lengths = examine_processed_dataset(spc_proc_dataset)
-
-    # SPC shortens the target sequences a lot, so the processed count should
-    # be almost equal to the unfiltered one.
-    self.assertLessEqual(proc_count, spc_proc_count)
-    self.assertEqual(unfiltered_count, spc_proc_count)
-
-    # Assert all spc_lengths are at most their char counterparts.
-    for spc_len, char_len in zip(spc_lengths, char_lengths):
-      self.assertLessEqual(spc_len, char_len)
-
-  def test_c4(self):
-    gin.bind_parameter('c4_preprocess.max_target_length', 2048)
-    gin.bind_parameter('c4_preprocess.tokenization', 'spc')
-    gin.bind_parameter('c4_preprocess.spm_path', _spm_path())
-
-    # Just make sure this doesn't throw.
-    _ = tf_inputs.data_streams(
-        'c4',
-        data_dir=_TESTDATA,
-        input_name='targets',
-        target_name='text',
-        preprocess_fn=tf_inputs.c4_preprocess)
-
-  def test_c4_bare_preprocess_fn_denoising_objective(self):
-    _t5_gin_config()
-
-    dataset = _c4_dataset()
-    dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path())
-
-    example = list(tfds.as_numpy(dataset.take(1)))[0]
-
-    # Assertions now.
-
-    self.assertIn('targets', example)
-    targets = example['targets']
-    self.assertIsInstance(targets, np.ndarray)
-    self.assertEqual(targets.dtype, np.int64)
-    self.assertGreater(len(targets), 0)
-
-    self.assertIn('inputs', example)
-    _inputs = example['inputs']  # pylint: disable=invalid-name
-    self.assertIsInstance(_inputs, np.ndarray)
-    self.assertEqual(_inputs.dtype, np.int64)
-    self.assertGreater(len(_inputs), 0)
-
-    # With high probability (WHP), inputs will have the bulk of the text.
-    self.assertGreater(len(_inputs), len(targets))
-
-    # WHP there will be one sentinel token in the inputs and targets.
- inputs_counter = collections.Counter(_inputs.tolist()) - targets_counter = collections.Counter(targets.tolist()) - self.assertEqual(1, inputs_counter[31999]) - self.assertEqual(1, targets_counter[31999]) - - def test_c4_pretrain(self): - _t5_gin_config() - - gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path()) - - gin.bind_parameter('batcher.batch_size_per_device', 8) - gin.bind_parameter('batcher.eval_batch_size', 8) - gin.bind_parameter('batcher.max_eval_length', 50) - gin.bind_parameter('batcher.buckets', ([51], [8, 1])) - - # Just make sure this doesn't throw. - _ = tf_inputs.data_streams( - 'c4', - data_dir=_TESTDATA, - input_name='inputs', - target_name='targets', - bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn) - - def test_generic_text_dataset_preprocess_fn(self): - dataset = _load_dataset('squad/v1.1:3.0.0') - - example, = tfds.as_numpy(dataset.take(1)) - - self.assertNotIn('inputs', example) - self.assertNotIn('targets', example) - - proc_dataset = tf_inputs.generic_text_dataset_preprocess_fn( - dataset, - spm_path=_spm_path(), - text_preprocess_fns=[lambda ds, training: t5_processors.squad(ds)], - copy_pretokenized=True, - debug_print_examples=True, - debug_print_examples_rate=1.0) - - proc_example, = tfds.as_numpy(proc_dataset.take(1)) - - self.assertIn('inputs', proc_example) - self.assertIn('targets', proc_example) - - self.assertEqual(proc_example['inputs'].dtype, np.int32) - self.assertEqual(proc_example['targets'].dtype, np.int32) - - # TODO(afrozm): Why does this test take so much time? - def test_inputs_using_generic_text_dataset_preprocess_fn(self): - gin.bind_parameter('generic_text_dataset_preprocess_fn.spm_path', - _spm_path()) - gin.bind_parameter('generic_text_dataset_preprocess_fn.text_preprocess_fns', - [lambda ds, training: t5_processors.squad(ds)]) - - # Just make sure this doesn't throw. - def data_streams(): - return tf_inputs.data_streams( - 'squad', - data_dir=_TESTDATA, - input_name='inputs', - target_name='targets', - bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn, - shuffle_buffer_size=1) - - n_devices = 3 - - squad_inputs = inputs.batcher( - data_streams=data_streams, - max_eval_length=512, - buckets=([ - 513, - ], [n_devices, n_devices])) - - eval_stream = squad_inputs.eval_stream(n_devices) - inps, tgts, _ = next(eval_stream) - - # We can only assert that the batch dim gets divided by n_devices. - self.assertEqual(inps.shape[0] % n_devices, 0) - self.assertEqual(tgts.shape[0] % n_devices, 0) - - def test_filter_dataset_on_len(self): - # {1, 2}, {2, 4}, {3, 6} ... {10, 20} - ds = _test_dataset_ints(range(1, 11), range(2, 21, 2)) - - ds1 = tf_inputs.filter_dataset_on_len(ds, True, { - 'inputs': [4, 8], - 'targets': [14, 20] - }) - # Only {7, 14} and {8, 16} satisfy this. - self.assertLen(list(ds1.as_numpy_iterator()), 2) - - ds2 = tf_inputs.filter_dataset_on_len( - ds, - False, - len_map={ - 'inputs': [4, 8], - 'targets': [14, 20] - }, - filter_on_eval=False) - # This is eval and we aren't supposed to filter it. - self.assertLen(list(ds2.as_numpy_iterator()), 10) - - ds3 = tf_inputs.filter_dataset_on_len( - ds, - False, - len_map={ - 'inputs': [4, 8], - 'targets': [14, 20] - }, - filter_on_eval=True) - # This is eval and we are asked to filter it. 
- self.assertLen(list(ds3.as_numpy_iterator()), 2) - - def test_truncate_dataset_on_len(self): - ds = _test_dataset_ints([5, 6, 7], [8, 9, 10]) - ds1 = tf_inputs.truncate_dataset_on_len( - ds, True, len_map={ - 'inputs': 6, - 'targets': 4 - }) - expected_ds = _test_dataset_ints([5, 6, 6], [4, 4, 4]) - - # training, should filter. - assert_dataset(ds1, list(expected_ds.as_numpy_iterator())) - - # not Training, shouldn't filter. - ds2 = tf_inputs.truncate_dataset_on_len( - ds, False, len_map={ - 'inputs': 6, - 'targets': 4 - }) - assert_dataset(ds2, list(ds.as_numpy_iterator())) - - # not Training, but asked to filter, should filter. - ds3 = tf_inputs.truncate_dataset_on_len( - ds, False, len_map={ - 'inputs': 6, - 'targets': 4 - }, truncate_on_eval=True) - assert_dataset(ds3, list(expected_ds.as_numpy_iterator())) - - def test_get_t5_preprocessor_by_name(self): - gin.clear_config() - - gin.parse_config(""" - get_t5_preprocessor_by_name.name = 'rekey' - get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other', 'targets': 'text'}} - """) - prep_rekey = tf_inputs.get_t5_preprocessor_by_name() - og_dataset = tf.data.Dataset.from_tensors({ - 'text': 'That is good.', - 'other': 'That is bad.' - }) - training = True - dataset = prep_rekey(og_dataset, training) - assert_dataset(dataset, { - 'inputs': 'That is bad.', - 'targets': 'That is good.' - }) - - def test_pad_dataset_to_length(self): - ds = _test_dataset_ints([5, 6, 7], [6, 7, 8]) - ds1 = tf_inputs.pad_dataset_to_length( - ds, True, len_map={ - 'inputs': 7, - 'targets': 10 - }) - - expected_ds = [ - { - 'inputs': np.array([1, 1, 1, 1, 1, 0, 0], dtype=np.int64), - 'targets': np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 1, 1, 1, 1, 0], dtype=np.int64), - 'targets': np.array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 1, 1, 1, 1, 1], dtype=np.int64), - 'targets': np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0], dtype=np.int64), - }, - ] - - assert_dataset(ds1, expected_ds) - - def test_lm_token_preprocessing(self): - ds = _test_dataset_ints([1, 2, 3], [3, 2, 1]) - ds1 = tf_inputs.lm_token_preprocessing(ds, True) - - # pylint: disable=bad-whitespace - expected_ds = [ - { - 'inputs': np.array([1, 0, 1, 1, 1], dtype=np.int64), - 'targets': np.array([1, 0, 1, 1, 1], dtype=np.int64), - 'mask': np.array([0, 0, 1, 1, 1], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 0, 1, 1], dtype=np.int64), - 'targets': np.array([1, 1, 0, 1, 1], dtype=np.int64), - 'mask': np.array([0, 0, 0, 1, 1], dtype=np.int64), - }, - { - 'inputs': np.array([1, 1, 1, 0, 1], dtype=np.int64), - 'targets': np.array([1, 1, 1, 0, 1], dtype=np.int64), - 'mask': np.array([0, 0, 0, 0, 1], dtype=np.int64), - }, - ] - # pylint: enable=bad-whitespace - - assert_dataset(ds1, expected_ds) - - def test_create_bert_inputs(self): - inputs_sentences_1 = [np.array([100, 150, 200])] - inputs_sentences_2 = [np.array([300, 500])] - labels = [np.array(1)] - - create_inputs_1 = tf_inputs.CreateBertInputs(False) - create_inputs_2 = tf_inputs.CreateBertInputs(True) - for res in create_inputs_1(zip(inputs_sentences_1, labels)): - values, segment_embs, _, label, weight = res - self.assertAllEqual(values, np.array([101, 100, 150, 200, 102])) - self.assertAllEqual(segment_embs, np.zeros(5)) - self.assertEqual(label, np.int64(1)) - self.assertEqual(weight, np.int64(1)) - - for res in create_inputs_2( - zip(inputs_sentences_1, inputs_sentences_2, labels)): - values, segment_embs, _, label, weight = res - 
self.assertAllEqual(values, - np.array([101, 100, 150, 200, 102, 300, 500, 102])) - exp_segment = np.concatenate((np.zeros(5), np.ones(3))) - self.assertAllEqual(segment_embs, exp_segment) - self.assertEqual(label, np.int64(1)) - self.assertEqual(weight, np.int64(1)) - - def test_mask_random_tokens(self): - """Test only standard tokens. - - This test deals with sentences composed of two parts: [100 CLS tokens, 100 - chosen standard tokens]. CLS is the token that is added at the beginning of - the sentence and there is only one token in standard scenario. It is never - masked because it is not a part of the sentence. - This tests whether mask_random_tokens will: - - mask only standard tokens - - mask expected number of tokens (15 percent candidates for masking) - """ - cls_token = 101 - mask_token = 103 - example_standard_token = 1001 - test_case_row = np.array([cls_token] * 100 + [example_standard_token] * 100) - test_case = [(test_case_row.copy(),)] - - out, original_tokens, token_weights = next( - tf_inputs.mask_random_tokens(test_case)) - # test whether original tokens are unchanged - self.assertAllEqual(test_case_row, original_tokens) - - self.assertEqual(1, token_weights.sum()) - self.assertEqual( - 15, - (token_weights > 0).sum()) # we should have 15 candidates for masking - - # 101 is a special token, so only 1001 should be masked - self.assertAllEqual(out[:100], test_case_row[:100]) - - # Each candidate has 0.8 probability to be masked while others have 0, so - # no more than 15 tokens with MASK - self.assertLessEqual((out == mask_token).sum(), 15) - - def test_bert_next_sentence_prediction_inputs(self): - stream = tf_inputs.BertNextSentencePredictionInputs( - 'c4/en:2.3.0', data_dir=_TESTDATA, train=False, shuffle_size=1) - exp_sent1 = 'Police were called to the carriageway around 6.' - exp_sent2 = 'I am sorry we did not see how lost and alone you felt.' - sent1, sent2, label = next(stream()) - self.assertEqual(exp_sent1, sent1) - self.assertEqual(exp_sent2, sent2) - self.assertFalse(label) - - def test_process_single_mathqa_example_0(self): - # This is the first problem in the MathQA dataset. - example = { - 'Problem': - "the banker ' s gain of a certain sum due 3 years hence at 10 % " - 'per annum is rs . 36 . what is the present worth ?', - 'Rationale': - '"explanation : t = 3 years r = 10 % td = ( bg × 100 ) / tr = ( ' - '36 × 100 ) / ( 3 × 10 ) = 12 × 10 = rs . 120 td = ( pw × tr )' - ' / 100 ⇒ 120 = ( pw × 3 × 10 ) / 100 ⇒ 1200 = pw × 3 pw = ' - '1200 / 3 = rs . 400 answer : option a"', - 'options': - 'a ) rs . 400 , b ) rs . 300 , c ) rs . 500 , d ) rs . 
350 , e ) ' - 'none of these', - 'correct': - 'a', - 'annotated_formula': - 'divide(multiply(const_100, divide(multiply(36, const_100), ' - 'multiply(3, 10))), multiply(3, 10))', - 'linear_formula': - 'multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)|multiply(#2,const_100)|divide(#3,#1)|', - 'category': - 'gain' - } - - answer_num, python_result, python_program, list_op, list_num = tf_inputs.process_single_mathqa_example( - example) - self.assertEqual(answer_num, - 400) # we know it, because correct answer is a) - self.assertEqual(python_result, [3600.0, 30.0, 120.0, 12000.0, 400.0]) - - self.assertEqual(python_program, [ - 't0 = n2 * 100.0', 't1 = n0 * n1', 't2 = t0 / t1', 't3 = t2 * 100.0', - 't4 = t3 / t1' - ]) - self.assertEqual(list_op, [ - 'multiply(n2,const_100)', 'multiply(n0,n1)', 'divide(#0,#1)', - 'multiply(#2,const_100)', 'divide(#3,#1)' - ]) - self.assertEqual(list_num, [3.0, 10.0, 36.0]) - - def test_process_single_mathqa_example_1(self): - # This is the third problem in the MathQA dataset. - example = { - 'Problem': - 'sophia finished 2 / 3 of a book . she calculated that she ' - 'finished 90 more pages than she has yet to read . how long is her' - ' book ?', - 'Rationale': - 'let xx be the total number of pages in the book , then she ' - 'finished 23 ⋅ x 23 ⋅ x pages . then she has x − 23 ⋅ x = ' - '13 ⋅ xx − 23 ⋅ x = 13 ⋅ x pages left . 23 ⋅ x − 13 ' - '⋅ x = 9023 ⋅ x − 13 ⋅ x = 90 13 ⋅ x = 9013 ⋅ x = 90 x' - ' = 270 x = 270 so the book is 270 pages long . answer : b', - 'options': 'a ) 229 , b ) 270 , c ) 877 , d ) 266 , e ) 281', - 'correct': 'b', - 'annotated_formula': 'divide(90, subtract(const_1, divide(2, 3)))', - 'linear_formula': 'divide(n0,n1)|subtract(const_1,#0)|divide(n2,#1)', - 'category': 'general' - } - - answer_num, python_result, python_program, list_op, list_num = tf_inputs.process_single_mathqa_example( - example) - self.assertEqual(answer_num, - 270) # we know it, because correct answer is b) - self.assertAllClose( - python_result, - [0.6666666666666666, 0.33333333333333337, 269.99999999999994]) - self.assertEqual(python_program, - ['t0 = n0 / n1', 't1 = 1.0 - t0', 't2 = n2 / t1']) - self.assertEqual(list_op, - ['divide(n0,n1)', 'subtract(const_1,#0)', 'divide(n2,#1)']) - self.assertEqual(list_num, [2.0, 3.0, 90.0]) - - def test_process_single_mathqa_example_with_import(self): - # This is a training MathQA problem which involve an import. - example = { - 'Problem': - 'the length of a rectangular garden is three times its width . if ' - 'the area of the rectangular garden is 588 square meters , then ' - 'what is the width of the rectangular garden ?', - 'Rationale': - '\"let x be the width of the garden . 3 x ^ 2 = 588 x ^ 2 = 196 x ' - '= 14 the answer is c .\"', - 'options': - 'a ) 12 , b ) 13 , c ) 14 , d ) 15 , e ) 16', - 'correct': - 'c', - 'annotated_formula': - 'sqrt(divide(588, const_3))', - 'linear_formula': - 'divide(n0,const_3)|sqrt(#0)|', - 'category': - 'geometry' - } - - answer_num, python_result, python_program, list_op, list_num = tf_inputs.process_single_mathqa_example( - example) - self.assertEqual(answer_num, 14) # we know it, because correct answer is c) - self.assertAllClose(python_result, [196, 14]) - self.assertEqual( - python_program, - ['t0 = n0 / 3.0', 't1 = math.sqrt(max(0, t0))']) - self.assertEqual(list_op, ['divide(n0,const_3)', 'sqrt(#0)']) - self.assertEqual(list_num, [588]) - - # Below we execute twice the Python program and once the DSL program. 
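For readers skimming these tests, the correspondence they exercise between the MathQA DSL and the generated Python can be summarized in a short sketch (distilled from the expected values in the book-problem test above; this is illustration only, not additional library API). The test code below then builds and runs exactly such a program:

    # Each DSL op reads problem numbers (n0, n1, ...) or earlier results
    # (#0, #1, ...) and becomes one Python temporary t<i>.
    dsl = ['divide(n0,n1)', 'subtract(const_1,#0)', 'divide(n2,#1)']
    python_program = ['t0 = n0 / n1', 't1 = 1.0 - t0', 't2 = n2 / t1']

    # Running the Python form with the book problem's numbers:
    n0, n1, n2 = 2.0, 3.0, 90.0
    t0 = n0 / n1   # ~0.667: fraction of the book already read
    t1 = 1.0 - t0  # ~0.333: fraction left to read
    t2 = n2 / t1   # ~270.0: total pages, matching answer b)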
- target_values = 'import math\n' - problem = example['Problem'] - for i in range(len(list_num)): - target_values += 'n{} = {}\n'.format(i, list_num[i]) - problem += ' n{} = {}'.format(i, list_num[i]) - target_values += '\n'.join(python_program[:-1]) - final_line = python_program[-1].split('=')[1] - target_values += '\nanswer ={}'.format(final_line) - var_dict = {} - exec(target_values, globals(), var_dict) # pylint: disable=exec-used - self.assertAllClose(var_dict['answer'], 14) - self.assertAllClose( - tf_inputs.execute_mathqa_program(problem, target_values.split('\n')), - 14) - self.assertAllClose( - tf_inputs.execute_mathqa_dsl_program(problem, - [example['linear_formula']]), 14) - - - def test_sentencepiece_tokenize(self): - def dataset(): - yield 'I have a cat.' - - examples = [] - for example in tf_inputs.sentencepiece_tokenize(dataset(), _spm_path()): - examples.append(example) - toks = list(examples[0]) - self.assertSequenceEqual([27, 43, 3, 9, 1712, 5], toks) - - -if __name__ == '__main__': - tf.test.main() diff --git a/trax/data/tokenizer.py b/trax/data/tokenizer.py deleted file mode 100644 index 64081f4da..000000000 --- a/trax/data/tokenizer.py +++ /dev/null @@ -1,188 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A simple invertible tokenizer. - -Converts from a unicode string to a list of tokens -(represented as Unicode strings). - -This tokenizer has the following desirable properties: - - It is invertible. - - Alphanumeric characters are broken away from non-alphanumeric characters. - - A single space between words does not produce an extra token. - - The full Unicode punctuation and separator set is recognized. - -The tokenization algorithm is as follows: - -1. Split the text into a list of tokens, splitting at every boundary of an - alphanumeric character and a non-alphanumeric character. This produces - a list which alternates between "alphanumeric tokens" - (strings of alphanumeric characters) and "non-alphanumeric tokens" - (strings of non-alphanumeric characters). - -2. Remove every token consisting of a single space, unless it is - the very first or very last token in the list. These tokens are now - implied by the fact that there are two adjacent alphanumeric tokens. - -e.g. u"Dude - that's so cool." - -> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."] -""" - -import collections -import sys -import unicodedata - -from absl import logging -import six -import tensorflow as tf - - -# This set contains all letter and number characters. -_ALPHANUMERIC_CHAR_SET = set( - six.unichr(i) for i in range(sys.maxunicode) - if (unicodedata.category(six.unichr(i)).startswith("L") or - unicodedata.category(six.unichr(i)).startswith("N"))) - - -def encode(text): - """Encode a unicode string as a list of tokens. 
- - Args: - text: a unicode string - Returns: - a list of tokens as Unicode strings - """ - if not text: - return [] - ret = [] - token_start = 0 - # Classify each character in the input string - is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text] - for pos in range(1, len(text)): - if is_alnum[pos] != is_alnum[pos - 1]: - token = text[token_start:pos] - if token != u" " or token_start == 0: - ret.append(token) - token_start = pos - final_token = text[token_start:] - ret.append(final_token) - return ret - - -def decode(tokens): - """Decode a list of tokens to a unicode string. - - Args: - tokens: a list of Unicode strings - Returns: - a unicode string - """ - token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] - ret = [] - for i, token in enumerate(tokens): - if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: - ret.append(u" ") - ret.append(token) - return "".join(ret) - - -def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True): - """Reads files matching a wildcard pattern, yielding the contents. - - Args: - filepattern: A wildcard pattern matching one or more files. - max_lines: If set, stop reading after reading this many lines. - split_on_newlines: A boolean. If true, then split files by lines and strip - leading and trailing whitespace from each line. Otherwise, treat each - file as a single string. - - Yields: - The contents of the files as lines, if split_on_newlines is True, or - the entire contents of each file if False. - """ - filenames = sorted(tf.io.gfile.glob(filepattern)) - lines_read = 0 - for filename in filenames: - with tf.io.gfile.GFile(filename) as f: - if split_on_newlines: - for line in f: - yield line.strip() - lines_read += 1 - if max_lines and lines_read >= max_lines: - return - - else: - if max_lines: - doc = [] - for line in f: - doc.append(line) - lines_read += 1 - if max_lines and lines_read >= max_lines: - yield "".join(doc) - return - yield "".join(doc) - - else: - yield f.read() - - -def corpus_token_counts( - text_filepattern, corpus_max_lines, split_on_newlines=True): - """Read the corpus and compute a dictionary of token counts. - - Args: - text_filepattern: A pattern matching one or more files. - corpus_max_lines: An integer; maximum total lines to read. - split_on_newlines: A boolean. If true, then split files by lines and strip - leading and trailing whitespace from each line. Otherwise, treat each - file as a single string. - - Returns: - a dictionary mapping token to count. - """ - counts = collections.Counter() - for doc in _read_filepattern( - text_filepattern, - max_lines=corpus_max_lines, - split_on_newlines=split_on_newlines): - counts.update(encode(doc)) - - return counts - - -def vocab_token_counts(text_filepattern, max_lines): - """Read a vocab file and return a dictionary of token counts. - - Reads a two-column CSV file of tokens and their frequency in a dataset. The - tokens are presumed to be generated by encode() or the equivalent. - - Args: - text_filepattern: A pattern matching one or more files. - max_lines: An integer; maximum total lines to read. - - Returns: - a dictionary mapping token to count. 
- """ - ret = {} - for i, line in enumerate( - _read_filepattern(text_filepattern, max_lines=max_lines)): - if "," not in line: - logging.warning("Malformed vocab line #%d '%s'", i, line) - continue - - token, count = line.rsplit(",", 1) - ret[token] = int(count) - - return ret diff --git a/trax/data/tokenizer_test.py b/trax/data/tokenizer_test.py deleted file mode 100644 index 593ebe83d..000000000 --- a/trax/data/tokenizer_test.py +++ /dev/null @@ -1,136 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.data..tokenizer.""" -import os -import random - -import six -from six.moves import range # pylint: disable=redefined-builtin -import tensorflow.compat.v1 as tf -from trax.data import tokenizer - - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, "testdata") - - -class TokenizerTest(tf.test.TestCase): - - def test_encode(self): - self.assertListEqual( - [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."], - tokenizer.encode(u"Dude - that's so cool.")) - self.assertListEqual([u"Łukasz", u"est", u"nÊ", u"en", u"1981", u"."], - tokenizer.encode(u"Łukasz est nÊ en 1981.")) - self.assertListEqual([u" ", u"Spaces", u"at", u"the", u"ends", u" "], - tokenizer.encode(u" Spaces at the ends ")) - self.assertListEqual([u"802", u".", u"11b"], tokenizer.encode(u"802.11b")) - self.assertListEqual([u"two", u". \n", u"lines"], - tokenizer.encode(u"two. \nlines")) - - def test_decode(self): - self.assertEqual( - u"Dude - that's so cool.", - tokenizer.decode( - [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."])) - - def test_invertibility_on_random_strings(self): - for _ in range(1000): - s = u"".join(six.unichr(random.randint(0, 65535)) for _ in range(10)) - self.assertEqual(s, tokenizer.decode(tokenizer.encode(s))) - - -class TestTokenCounts(tf.test.TestCase): - - def setUp(self): - super(TestTokenCounts, self).setUp() - self.corpus_path = os.path.join(_TESTDATA, "corpus-*.txt") - self.vocab_path = os.path.join(_TESTDATA, "vocab-*.txt") - - def test_corpus_token_counts_split_on_newlines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=0, split_on_newlines=True) - - expected = { - u"'": 2, - u".": 2, - u". ": 1, - u"... 
": 1, - u"Groucho": 1, - u"Marx": 1, - u"Mitch": 1, - u"Hedberg": 1, - u"I": 3, - u"in": 2, - u"my": 2, - u"pajamas": 2, - } - self.assertDictContainsSubset(expected, token_counts) - self.assertNotIn(u".\n\n", token_counts) - self.assertNotIn(u"\n", token_counts) - - def test_corpus_token_counts_no_split_on_newlines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=0, split_on_newlines=False) - - self.assertDictContainsSubset({u".\n\n": 2, u"\n": 3}, token_counts) - - def test_corpus_token_counts_split_with_max_lines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=5, split_on_newlines=True) - - self.assertIn(u"slept", token_counts) - self.assertNotIn(u"Mitch", token_counts) - - def test_corpus_token_counts_no_split_with_max_lines(self): - token_counts = tokenizer.corpus_token_counts( - self.corpus_path, corpus_max_lines=5, split_on_newlines=False) - - self.assertIn(u"slept", token_counts) - self.assertNotIn(u"Mitch", token_counts) - self.assertDictContainsSubset({ - u".\n\n": 1, - u"\n": 2, - u".\n": 1 - }, token_counts) - - def test_vocab_token_counts(self): - token_counts = tokenizer.vocab_token_counts(self.vocab_path, 0) - - expected = { - u"lollipop": 8, - u"reverberated": 12, - u"kattywampus": 11, - u"balderdash": 10, - u"jiggery-pokery": 14, - } - self.assertDictEqual(expected, token_counts) - - def test_vocab_token_counts_with_max_lines(self): - # vocab-1 has 2 lines, vocab-2 has 3 - token_counts = tokenizer.vocab_token_counts(self.vocab_path, 5) - - expected = { - u"lollipop": 8, - u"reverberated": 12, - u"kattywampus": 11, - u"balderdash": 10, - } - self.assertDictEqual(expected, token_counts) - - -if __name__ == "__main__": - tf.test.main() diff --git a/trax/examples/Deep_N_Gram_Models.ipynb b/trax/examples/Deep_N_Gram_Models.ipynb deleted file mode 100644 index 1972d2df5..000000000 --- a/trax/examples/Deep_N_Gram_Models.ipynb +++ /dev/null @@ -1,1040 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - }, - "papermill": { - "duration": 297.094983, - "end_time": "2020-10-19T05:28:36.576660", - "environment_variables": {}, - "exception": null, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2020-10-19T05:23:39.481677", - "version": "2.1.0" - }, - "colab": { - "name": "Deep N-Gram Models", - "provenance": [], - "include_colab_link": true - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lAAzPCP8n05S" - }, - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CcV2B-3LnvBk" - }, - "source": [ - "Author - [@SauravMaheshkar](https://github.com/SauravMaheshkar)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.024472, - "end_time": "2020-10-19T05:23:45.163806", - "exception": false, - "start_time": "2020-10-19T05:23:45.139334", - "status": "completed" - }, - "tags": [], - "id": "uEg7rw6fnr0q" - }, - "source": [ - "# Downloading the Trax Package" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.024546, - "end_time": "2020-10-19T05:23:45.211638", - "exception": false, - "start_time": "2020-10-19T05:23:45.187092", - "status": "completed" - }, - "tags": [], - "id": "7iVotT-qnr0q" - }, - "source": [ - "[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information." - ] - }, - { - "cell_type": "code", - "metadata": { - "_kg_hide-input": false, - "_kg_hide-output": true, - "execution": { - "iopub.execute_input": "2020-10-19T05:23:45.265606Z", - "iopub.status.busy": "2020-10-19T05:23:45.264326Z", - "iopub.status.idle": "2020-10-19T05:24:40.876515Z", - "shell.execute_reply": "2020-10-19T05:24:40.877287Z" - }, - "papermill": { - "duration": 55.642763, - "end_time": "2020-10-19T05:24:40.877583", - "exception": false, - "start_time": "2020-10-19T05:23:45.234820", - "status": "completed" - }, - "tags": [], - "id": "LTV7nHkWnr0q" - }, - "source": [ - "%%capture\n", - "!pip install trax" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.121469, - "end_time": "2020-10-19T05:24:41.120599", - "exception": false, - "start_time": "2020-10-19T05:24:40.999130", - "status": "completed" - }, - "tags": [], - "id": "s4e-X6Ranr0s" - }, - "source": [ - "# Importing Packages" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.117117, - "end_time": "2020-10-19T05:24:41.355694", - "exception": false, - "start_time": "2020-10-19T05:24:41.238577", - "status": "completed" - }, - "tags": [], - "id": "zaoHVZj0nr0s" - }, - "source": [ - "In this notebook we will use the following packages:\n", - "\n", - "* [**Pandas**](https://pandas.pydata.org/) is a fast, powerful, flexible and easy to use open-source data analysis and manipulation tool, built on top of the Python programming language. 
It offers a fast and efficient DataFrame object for data manipulation with integrated indexing.\n", - "* [**os**](https://docs.python.org/3/library/os.html) module provides a portable way of using operating system dependent functionality.\n", - "* [**trax**](https://trax-ml.readthedocs.io/en/latest/trax.html) is an end-to-end library for deep learning that focuses on clear code and speed.\n", - "* [**random**](https://docs.python.org/3/library/random.html) module implements pseudo-random number generators for various distributions.\n", - "* [**itertools**](https://docs.python.org/3/library/itertools.html) module implements a number of iterator building blocks inspired by constructs from APL, Haskell, and SML. Each has been recast in a form suitable for Python." - ] - }, - { - "cell_type": "code", - "metadata": { - "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", - "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", - "execution": { - "iopub.execute_input": "2020-10-19T05:24:41.598509Z", - "iopub.status.busy": "2020-10-19T05:24:41.597670Z", - "iopub.status.idle": "2020-10-19T05:24:54.656423Z", - "shell.execute_reply": "2020-10-19T05:24:54.655287Z" - }, - "papermill": { - "duration": 13.181434, - "end_time": "2020-10-19T05:24:54.656623", - "exception": false, - "start_time": "2020-10-19T05:24:41.475189", - "status": "completed" - }, - "tags": [], - "id": "h8vjYA-8nr0s" - }, - "source": [ - "import pandas as pd \n", - "import os\n", - "import trax\n", - "import trax.fastmath.numpy as np\n", - "import random as rnd\n", - "from trax import fastmath\n", - "from trax import layers as tl" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.118759, - "end_time": "2020-10-19T05:24:54.899617", - "exception": false, - "start_time": "2020-10-19T05:24:54.780858", - "status": "completed" - }, - "tags": [], - "id": "ZZaUGa2Lnr0s" - }, - "source": [ - "# Loading the Data" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.122704, - "end_time": "2020-10-19T05:24:55.144895", - "exception": false, - "start_time": "2020-10-19T05:24:55.022191", - "status": "completed" - }, - "tags": [], - "id": "WbwaTxIFnr0s" - }, - "source": [ - "For this project, I've used the [gothic-literature](https://www.kaggle.com/charlesaverill/gothic-literature), [shakespeare-plays](https://www.kaggle.com/kingburrito666/shakespeare-plays) and [shakespeareonline](https://www.kaggle.com/kewagbln/shakespeareonline) datasets from the Kaggle library. 
\n", - "\n", - "We perform the following steps for loading in the data:\n", - "\n", - "* Iterate over all the directories in the `/kaggle/input/` directory\n", - "* Filter out `.txt` files\n", - "* Make a `lines` list containing the individual lines from all the datasets combined" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:55.385118Z", - "iopub.status.busy": "2020-10-19T05:24:55.384122Z", - "iopub.status.idle": "2020-10-19T05:24:55.716407Z", - "shell.execute_reply": "2020-10-19T05:24:55.715479Z" - }, - "papermill": { - "duration": 0.456359, - "end_time": "2020-10-19T05:24:55.716572", - "exception": false, - "start_time": "2020-10-19T05:24:55.260213", - "status": "completed" - }, - "tags": [], - "id": "a7mudKI5nr0s" - }, - "source": [ - "directories = os.listdir('/kaggle/input/')\n", - "lines = []\n", - "for directory in directories:\n", - " for filename in os.listdir(os.path.join('/kaggle/input',directory)):\n", - " if filename.endswith(\".txt\"):\n", - " with open(os.path.join(os.path.join('/kaggle/input',directory), filename)) as files:\n", - " for line in files: \n", - " processed_line = line.strip()\n", - " if processed_line:\n", - " lines.append(processed_line)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.113664, - "end_time": "2020-10-19T05:24:55.951966", - "exception": false, - "start_time": "2020-10-19T05:24:55.838302", - "status": "completed" - }, - "tags": [], - "id": "EPifypFdnr0s" - }, - "source": [ - "## Pre-Processing" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.119888, - "end_time": "2020-10-19T05:24:56.194726", - "exception": false, - "start_time": "2020-10-19T05:24:56.074838", - "status": "completed" - }, - "tags": [], - "id": "eU58tWP3nr0s" - }, - "source": [ - "### Converting to Lowercase\n", - "\n", - "Converting all the characters in the `lines` list to **lowercase**." - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:56.496346Z", - "iopub.status.busy": "2020-10-19T05:24:56.470575Z", - "iopub.status.idle": "2020-10-19T05:24:56.569027Z", - "shell.execute_reply": "2020-10-19T05:24:56.569637Z" - }, - "papermill": { - "duration": 0.253923, - "end_time": "2020-10-19T05:24:56.569875", - "exception": false, - "start_time": "2020-10-19T05:24:56.315952", - "status": "completed" - }, - "tags": [], - "id": "QAxU3uzunr0s" - }, - "source": [ - "for i, line in enumerate(lines):\n", - " lines[i] = line.lower()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.11122, - "end_time": "2020-10-19T05:24:56.795120", - "exception": false, - "start_time": "2020-10-19T05:24:56.683900", - "status": "completed" - }, - "tags": [], - "id": "voNUJBrRnr0s" - }, - "source": [ - "### Converting into Tensors\n", - "\n", - "Creating a function to convert each line into a tensor by converting each character into it's ASCII value. And adding a optional `EOS`(**End of statement**) character." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:57.032580Z", - "iopub.status.busy": "2020-10-19T05:24:57.029673Z", - "iopub.status.idle": "2020-10-19T05:24:57.037237Z", - "shell.execute_reply": "2020-10-19T05:24:57.036444Z" - }, - "papermill": { - "duration": 0.131432, - "end_time": "2020-10-19T05:24:57.037392", - "exception": false, - "start_time": "2020-10-19T05:24:56.905960", - "status": "completed" - }, - "tags": [], - "id": "J0F2sUJfnr0s" - }, - "source": [ - "def line_to_tensor(line, EOS_int=1):\n", - " \n", - " tensor = []\n", - " for c in line:\n", - " c_int = ord(c)\n", - " tensor.append(c_int)\n", - " \n", - " tensor.append(EOS_int)\n", - "\n", - " return tensor" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.109763, - "end_time": "2020-10-19T05:24:57.259043", - "exception": false, - "start_time": "2020-10-19T05:24:57.149280", - "status": "completed" - }, - "tags": [], - "id": "zYT5__Danr0s" - }, - "source": [ - "### Creating a Batch Generator\n", - "\n", - "Here, we create a `batch_generator()` function to yield a batch and mask generator. We perform the following steps:\n", - "\n", - "* Shuffle the lines if not shuffled\n", - "* Convert the lines into a Tensor\n", - "* Pad the lines if it's less than the maximum length\n", - "* Generate a mask " - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:57.491159Z", - "iopub.status.busy": "2020-10-19T05:24:57.490293Z", - "iopub.status.idle": "2020-10-19T05:24:57.503719Z", - "shell.execute_reply": "2020-10-19T05:24:57.502899Z" - }, - "papermill": { - "duration": 0.134497, - "end_time": "2020-10-19T05:24:57.503870", - "exception": false, - "start_time": "2020-10-19T05:24:57.369373", - "status": "completed" - }, - "tags": [], - "id": "V-D_5L_snr0s" - }, - "source": [ - "def data_generator(batch_size, max_length, data_lines, line_to_tensor=line_to_tensor, shuffle=True):\n", - " \n", - " index = 0 \n", - " cur_batch = [] \n", - " num_lines = len(data_lines) \n", - " lines_index = [*range(num_lines)] \n", - "\n", - " if shuffle:\n", - " rnd.shuffle(lines_index)\n", - " \n", - " while True:\n", - " \n", - " if index >= num_lines:\n", - " index = 0\n", - " if shuffle:\n", - " rnd.shuffle(lines_index)\n", - " \n", - " line = data_lines[lines_index[index]] \n", - " \n", - " if len(line) < max_length:\n", - " cur_batch.append(line)\n", - " \n", - " index += 1\n", - " \n", - " if len(cur_batch) == batch_size:\n", - " \n", - " batch = []\n", - " mask = []\n", - " \n", - " for li in cur_batch:\n", - "\n", - " tensor = line_to_tensor(li)\n", - "\n", - " pad = [0] * (max_length - len(tensor))\n", - " tensor_pad = tensor + pad\n", - " batch.append(tensor_pad)\n", - "\n", - " example_mask = [0 if t == 0 else 1 for t in tensor_pad]\n", - " mask.append(example_mask)\n", - " \n", - " batch_np_arr = np.array(batch)\n", - " mask_np_arr = np.array(mask)\n", - " \n", - " \n", - " yield batch_np_arr, batch_np_arr, mask_np_arr\n", - " \n", - " cur_batch = []\n", - " " - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.113922, - "end_time": "2020-10-19T05:24:57.728762", - "exception": false, - "start_time": "2020-10-19T05:24:57.614840", - "status": "completed" - }, - "tags": [], - "id": "biglhqPjnr0s" - }, - "source": [ - "# Defining the Model" - ] - }, - { - "cell_type": 
"markdown", - "metadata": { - "papermill": { - "duration": 0.110544, - "end_time": "2020-10-19T05:24:57.950897", - "exception": false, - "start_time": "2020-10-19T05:24:57.840353", - "status": "completed" - }, - "tags": [], - "id": "6JgMdnTonr0s" - }, - "source": [ - "## Gated Recurrent Unit\n", - "\n", - "This function generates a GRU Language Model, consisting of the following layers:\n", - "\n", - "* ShiftRight()\n", - "* Embedding()\n", - "* GRU Units(Number specified by the `n_layers` parameter)\n", - "* Dense() Layer\n", - "* LogSoftmax() Activation" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:58.183193Z", - "iopub.status.busy": "2020-10-19T05:24:58.182383Z", - "iopub.status.idle": "2020-10-19T05:24:58.186370Z", - "shell.execute_reply": "2020-10-19T05:24:58.185685Z" - }, - "papermill": { - "duration": 0.124594, - "end_time": "2020-10-19T05:24:58.186525", - "exception": false, - "start_time": "2020-10-19T05:24:58.061931", - "status": "completed" - }, - "tags": [], - "id": "MSA3bpCHnr0s" - }, - "source": [ - "def GRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", - " model = tl.Serial(\n", - " tl.ShiftRight(mode=mode), \n", - " tl.Embedding( vocab_size = vocab_size, d_feature = d_model), \n", - " [tl.GRU(n_units=d_model) for _ in range(n_layers)], \n", - " tl.Dense(n_units = vocab_size), \n", - " tl.LogSoftmax() \n", - " )\n", - " return model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.150132, - "end_time": "2020-10-19T05:24:58.463252", - "exception": false, - "start_time": "2020-10-19T05:24:58.313120", - "status": "completed" - }, - "tags": [], - "id": "9A0JtfgCnr0s" - }, - "source": [ - "## Long Short Term Memory\n", - "\n", - "This function generates a LSTM Language Model, consisting of the following layers:\n", - "\n", - "* ShiftRight()\n", - "* Embedding()\n", - "* LSTM Units(Number specified by the `n_layers` parameter)\n", - "* Dense() Layer\n", - "* LogSoftmax() Activation" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:58.713423Z", - "iopub.status.busy": "2020-10-19T05:24:58.712488Z", - "iopub.status.idle": "2020-10-19T05:24:58.717162Z", - "shell.execute_reply": "2020-10-19T05:24:58.716096Z" - }, - "papermill": { - "duration": 0.129976, - "end_time": "2020-10-19T05:24:58.717410", - "exception": false, - "start_time": "2020-10-19T05:24:58.587434", - "status": "completed" - }, - "tags": [], - "id": "ScuXPmvLnr0s" - }, - "source": [ - "def LSTMLM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", - " model = tl.Serial(\n", - " tl.ShiftRight(mode=mode), \n", - " tl.Embedding( vocab_size = vocab_size, d_feature = d_model), \n", - " [tl.LSTM(n_units=d_model) for _ in range(n_layers)], \n", - " tl.Dense(n_units = vocab_size), \n", - " tl.LogSoftmax() \n", - " )\n", - " return model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.130305, - "end_time": "2020-10-19T05:24:58.971978", - "exception": false, - "start_time": "2020-10-19T05:24:58.841673", - "status": "completed" - }, - "tags": [], - "id": "zWVaUwG1nr0s" - }, - "source": [ - "## Simple Recurrent Unit\n", - "\n", - "This function generates a SRU Language Model, consisting of the following layers:\n", - "\n", - "* ShiftRight()\n", - "* Embedding()\n", - "* SRU Units(Number specified by the `n_layers` 
parameter)\n", - "* Dense() Layer\n", - "* LogSoftmax() Activation" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:59.219038Z", - "iopub.status.busy": "2020-10-19T05:24:59.218146Z", - "iopub.status.idle": "2020-10-19T05:24:59.221200Z", - "shell.execute_reply": "2020-10-19T05:24:59.221764Z" - }, - "papermill": { - "duration": 0.12795, - "end_time": "2020-10-19T05:24:59.221979", - "exception": false, - "start_time": "2020-10-19T05:24:59.094029", - "status": "completed" - }, - "tags": [], - "id": "ECzZRknPnr0s" - }, - "source": [ - "def SRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):\n", - " model = tl.Serial(\n", - " tl.ShiftRight(mode=mode), \n", - " tl.Embedding( vocab_size = vocab_size, d_feature = d_model), \n", - " [tl.SRU(n_units=d_model) for _ in range(n_layers)], \n", - " tl.Dense(n_units = vocab_size), \n", - " tl.LogSoftmax() \n", - " )\n", - " return model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:24:59.461999Z", - "iopub.status.busy": "2020-10-19T05:24:59.460669Z", - "iopub.status.idle": "2020-10-19T05:24:59.465622Z", - "shell.execute_reply": "2020-10-19T05:24:59.466443Z" - }, - "papermill": { - "duration": 0.132413, - "end_time": "2020-10-19T05:24:59.466681", - "exception": false, - "start_time": "2020-10-19T05:24:59.334268", - "status": "completed" - }, - "tags": [], - "id": "1i8UlSvhnr0s", - "outputId": "f4894449-5399-48c8-e22d-a8fa05be3615" - }, - "source": [ - "GRUmodel = GRULM(n_layers = 5)\n", - "LSTMmodel = LSTMLM(n_layers = 5)\n", - "SRUmodel = SRULM(n_layers = 5)\n", - "print(GRUmodel)\n", - "print(LSTMmodel)\n", - "print(SRUmodel)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Serial[\n", - " ShiftRight(1)\n", - " Embedding_256_512\n", - " GRU_512\n", - " GRU_512\n", - " GRU_512\n", - " GRU_512\n", - " GRU_512\n", - " Dense_256\n", - " LogSoftmax\n", - "]\n", - "Serial[\n", - " ShiftRight(1)\n", - " Embedding_256_512\n", - " LSTM_512\n", - " LSTM_512\n", - " LSTM_512\n", - " LSTM_512\n", - " LSTM_512\n", - " Dense_256\n", - " LogSoftmax\n", - "]\n", - "Serial[\n", - " ShiftRight(1)\n", - " Embedding_256_512\n", - " SRU_512\n", - " SRU_512\n", - " SRU_512\n", - " SRU_512\n", - " SRU_512\n", - " Dense_256\n", - " LogSoftmax\n", - "]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.117255, - "end_time": "2020-10-19T05:24:59.712882", - "exception": false, - "start_time": "2020-10-19T05:24:59.595627", - "status": "completed" - }, - "tags": [], - "id": "As2O2Zj8nr0t" - }, - "source": [ - "## Hyperparameters" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.113458, - "end_time": "2020-10-19T05:24:59.939569", - "exception": false, - "start_time": "2020-10-19T05:24:59.826111", - "status": "completed" - }, - "tags": [], - "id": "cxIs1y_Gnr0t" - }, - "source": [ - "Here, we declare `the batch_size` and the `max_length` hyperparameters for the model." 
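Concretely (a back-of-the-envelope sketch, not a cell from the notebook), the values declared in the next cell pin down the tensor shapes used throughout training:

    batch_size, max_length, vocab_size = 32, 64, 256
    # data_generator yields (batch, batch, mask), each entry of shape (32, 64);
    # GRULM/LSTMLM/SRULM then map these token IDs to log-probabilities of
    # shape (batch_size, max_length, vocab_size) == (32, 64, 256).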
- ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:25:00.173212Z", - "iopub.status.busy": "2020-10-19T05:25:00.172118Z", - "iopub.status.idle": "2020-10-19T05:25:00.176348Z", - "shell.execute_reply": "2020-10-19T05:25:00.175587Z" - }, - "papermill": { - "duration": 0.121757, - "end_time": "2020-10-19T05:25:00.176474", - "exception": false, - "start_time": "2020-10-19T05:25:00.054717", - "status": "completed" - }, - "tags": [], - "id": "BLKz_gfKnr0t" - }, - "source": [ - "batch_size = 32\n", - "max_length = 64" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.111425, - "end_time": "2020-10-19T05:25:00.399880", - "exception": false, - "start_time": "2020-10-19T05:25:00.288455", - "status": "completed" - }, - "tags": [], - "id": "zUKNlXAmnr0t" - }, - "source": [ - "# Creating Evaluation and Training Dataset" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:25:00.637648Z", - "iopub.status.busy": "2020-10-19T05:25:00.634400Z", - "iopub.status.idle": "2020-10-19T05:25:00.641032Z", - "shell.execute_reply": "2020-10-19T05:25:00.641698Z" - }, - "papermill": { - "duration": 0.130539, - "end_time": "2020-10-19T05:25:00.641885", - "exception": false, - "start_time": "2020-10-19T05:25:00.511346", - "status": "completed" - }, - "tags": [], - "id": "TYJepc9Knr0t" - }, - "source": [ - "eval_lines = lines[-1000:] # Create a holdout validation set\n", - "lines = lines[:-1000] # Leave the rest for training" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.112994, - "end_time": "2020-10-19T05:25:00.871007", - "exception": false, - "start_time": "2020-10-19T05:25:00.758013", - "status": "completed" - }, - "tags": [], - "id": "1DbI1fFSnr0t" - }, - "source": [ - "# Training the Models" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.112218, - "end_time": "2020-10-19T05:25:01.096544", - "exception": false, - "start_time": "2020-10-19T05:25:00.984326", - "status": "completed" - }, - "tags": [], - "id": "8LKJoIzenr0t" - }, - "source": [ - "Here, we create a function to train the models. 
This function does the following:\n", - "\n", - "* Creating a Train and Evaluation Generator that cycles infinetely using the `itertools` module\n", - "* Train the Model using Adam Optimizer\n", - "* Use the Accuracy Metric for Evaluation" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:25:01.335062Z", - "iopub.status.busy": "2020-10-19T05:25:01.330866Z", - "iopub.status.idle": "2020-10-19T05:25:01.339390Z", - "shell.execute_reply": "2020-10-19T05:25:01.338695Z" - }, - "papermill": { - "duration": 0.130503, - "end_time": "2020-10-19T05:25:01.339549", - "exception": false, - "start_time": "2020-10-19T05:25:01.209046", - "status": "completed" - }, - "tags": [], - "id": "i4-fSW3Tnr0t" - }, - "source": [ - "from trax.supervised import training\n", - "import itertools\n", - "\n", - "def train_model(model, data_generator, batch_size=32, max_length=64, lines=lines, eval_lines=eval_lines, n_steps=10, output_dir = 'model/'): \n", - "\n", - " \n", - " bare_train_generator = data_generator(batch_size, max_length, data_lines=lines)\n", - " infinite_train_generator = itertools.cycle(bare_train_generator)\n", - " \n", - " bare_eval_generator = data_generator(batch_size, max_length, data_lines=eval_lines)\n", - " infinite_eval_generator = itertools.cycle(bare_eval_generator)\n", - " \n", - " train_task = training.TrainTask(\n", - " labeled_data=infinite_train_generator, \n", - " loss_layer=tl.CrossEntropyLoss(), \n", - " optimizer=trax.optimizers.Adam(0.0005),\n", - " n_steps_per_checkpoint=1 \n", - " )\n", - "\n", - " eval_task = training.EvalTask(\n", - " labeled_data=infinite_eval_generator, \n", - " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n", - " n_eval_batches=1 \n", - " )\n", - " \n", - " training_loop = training.Loop(model,\n", - " train_task,\n", - " eval_tasks=[eval_task],\n", - " output_dir = output_dir\n", - " )\n", - "\n", - " training_loop.run(n_steps=n_steps)\n", - " \n", - " return training_loop\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:25:01.602437Z", - "iopub.status.busy": "2020-10-19T05:25:01.601617Z", - "iopub.status.idle": "2020-10-19T05:26:21.063884Z", - "shell.execute_reply": "2020-10-19T05:26:21.062700Z" - }, - "papermill": { - "duration": 79.597768, - "end_time": "2020-10-19T05:26:21.064134", - "exception": false, - "start_time": "2020-10-19T05:25:01.466366", - "status": "completed" - }, - "tags": [], - "id": "dykzx2t1nr0t" - }, - "source": [ - "GRU_training_loop = train_model(GRUmodel, data_generator,n_steps=10, output_dir = 'model/GRU')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:26:21.431594Z", - "iopub.status.busy": "2020-10-19T05:26:21.430465Z", - "iopub.status.idle": "2020-10-19T05:27:55.049767Z", - "shell.execute_reply": "2020-10-19T05:27:55.049034Z" - }, - "papermill": { - "duration": 93.801876, - "end_time": "2020-10-19T05:27:55.049974", - "exception": false, - "start_time": "2020-10-19T05:26:21.248098", - "status": "completed" - }, - "tags": [], - "id": "4w9jvGYDnr0t" - }, - "source": [ - "LSTM_training_loop = train_model(LSTMmodel, data_generator, n_steps = 10, output_dir = 'model/LSTM')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-19T05:27:55.406482Z", - "iopub.status.busy": 
"2020-10-19T05:27:55.405074Z", - "iopub.status.idle": "2020-10-19T05:28:36.239692Z", - "shell.execute_reply": "2020-10-19T05:28:36.238806Z" - }, - "papermill": { - "duration": 41.004194, - "end_time": "2020-10-19T05:28:36.239938", - "exception": false, - "start_time": "2020-10-19T05:27:55.235744", - "status": "completed" - }, - "tags": [], - "id": "PWePFGVKnr0t" - }, - "source": [ - "SRU_training_loop = train_model(SRUmodel, data_generator, n_steps = 10, output_dir = 'model/SRU')" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/trax/examples/Fashion_MNIST_with_Trax.ipynb b/trax/examples/Fashion_MNIST_with_Trax.ipynb deleted file mode 100644 index aa56e7311..000000000 --- a/trax/examples/Fashion_MNIST_with_Trax.ipynb +++ /dev/null @@ -1,399 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "executionInfo": { - "elapsed": 436, - "status": "ok", - "timestamp": 1607381103381, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "1ecEWLK0nsyg" - }, - "outputs": [], - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "executionInfo": { - "elapsed": 447, - "status": "ok", - "timestamp": 1607381103836, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "vxLvhYV5XrvS", - "outputId": "f399419a-f30c-462d-b66e-61fa55c1a466" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "!pip install -q -U trax" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "executionInfo": { - "elapsed": 34658, - "status": "ok", - "timestamp": 1607381138504, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "ssFKSDd3X9Xj", - "outputId": "9eba95c4-ba52-461f-ea42-6a7b1d671a3f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensorflow-numpy\n" - ] - } - ], - "source": [ - "import trax\n", - "# Use the tensorflow-numpy backend.\n", - "trax.fastmath.set_backend('tensorflow-numpy')\n", - "print(trax.fastmath.backend_name())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "executionInfo": { - "elapsed": 18987, - "status": "ok", - "timestamp": 1607381157508, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "OHKt1_SaYGZW" - }, - "outputs": [], - "source": [ - "# https://www.tensorflow.org/datasets/catalog/fashion_mnist\n", - "train_stream = trax.data.TFDS('fashion_mnist', keys=('image', 'label'), train=True)()\n", - "eval_stream = trax.data.TFDS('fashion_mnist', keys=('image', 'label'), train=False)()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "executionInfo": { - "elapsed": 470, - "status": "ok", - "timestamp": 1607381157985, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "AfGtZHo4YYf6" - }, - "outputs": [], - "source": [ - "train_data_pipeline = trax.data.Serial(\n", - " trax.data.Shuffle(),\n", - " trax.data.Batch(8),\n", - ")\n", - "\n", - "train_batches_stream = train_data_pipeline(train_stream)\n", - "\n", - "eval_data_pipeline = trax.data.Batch(8)\n", - "\n", - "eval_batches_stream = eval_data_pipeline(eval_stream)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "executionInfo": { - "elapsed": 907, - "status": "ok", - "timestamp": 1607381158899, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "T75v8i91ZKcp", - "outputId": "5711f41d-2bf6-498d-fe44-247e16fadb07" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "batch shape (image, label) = [(8, 28, 28, 1), (8,)]\n" - ] - } - ], - "source": [ - "example_batch = next(train_batches_stream)\n", - "print(f'batch shape (image, label) = {[x.shape for x in example_batch]}')" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "executionInfo": { - "elapsed": 430, - "status": "ok", - "timestamp": 1607381159334, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "SbRlJX9_ZRLj" - }, - "outputs": [], - "source": [ - "from trax import layers as tl\n", - "from trax.models.resnet import Resnet50\n", - "\n", - "def get_model(n_output_classes=10):\n", - " model = tl.Serial(\n", - " tl.ToFloat(),\n", - "\n", - " tl.Conv(32, (3, 3), (1, 1), 'SAME'),\n", - " tl.LayerNorm(),\n", - " tl.Relu(),\n", - " tl.MaxPool(),\n", - "\n", - " 
tl.Conv(64, (3, 3), (1, 1), 'SAME'),\n", - " tl.LayerNorm(),\n", - " tl.Relu(),\n", - " tl.MaxPool(),\n", - "\n", - " tl.Flatten(),\n", - " tl.Dense(n_output_classes),\n", - " )\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "executionInfo": { - "elapsed": 944, - "status": "ok", - "timestamp": 1607381160283, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "zv6LSQZdaV6z" - }, - "outputs": [], - "source": [ - "from trax.supervised import training\n", - "\n", - "train_task = training.TrainTask(\n", - " labeled_data=train_batches_stream,\n", - " loss_layer=tl.CategoryCrossEntropy(),\n", - " optimizer=trax.optimizers.Adam(0.01),\n", - " n_steps_per_checkpoint=100,\n", - ")\n", - "\n", - "eval_task = training.EvalTask(\n", - " labeled_data=eval_batches_stream,\n", - " metrics=[tl.CategoryCrossEntropy(), tl.CategoryAccuracy()],\n", - " n_eval_batches=20,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "executionInfo": { - "elapsed": 14526, - "status": "ok", - "timestamp": 1607381174829, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "Rcz3ngZCa_9i", - "outputId": "3ece3594-8835-416d-d968-205e804f4bcc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Step 1: Total number of trainable weights: 451658\n", - "Step 1: Ran 1 train steps in 1.20 secs\n", - "Step 1: train CategoryCrossEntropy | 2.94750214\n", - "Step 1: eval CategoryCrossEntropy | 211.32588081\n", - "Step 1: eval CategoryAccuracy | 0.12500000\n", - "\n", - "Step 100: Ran 99 train steps in 1.60 secs\n", - "Step 100: train CategoryCrossEntropy | 33.01021576\n", - "Step 100: eval CategoryCrossEntropy | 4.50655540\n", - "Step 100: eval CategoryAccuracy | 0.61250000\n", - "\n", - "Step 200: Ran 100 train steps in 1.53 secs\n", - "Step 200: train CategoryCrossEntropy | 1.78586197\n", - "Step 200: eval CategoryCrossEntropy | 0.89368055\n", - "Step 200: eval CategoryAccuracy | 0.76250000\n", - "\n", - "Step 300: Ran 100 train steps in 0.98 secs\n", - "Step 300: train CategoryCrossEntropy | 0.81385994\n", - "Step 300: eval CategoryCrossEntropy | 0.64747319\n", - "Step 300: eval CategoryAccuracy | 0.77500000\n", - "\n", - "Step 400: Ran 100 train steps in 0.95 secs\n", - "Step 400: train CategoryCrossEntropy | 0.59235722\n", - "Step 400: eval CategoryCrossEntropy | 0.61784569\n", - "Step 400: eval CategoryAccuracy | 0.78750000\n", - "\n", - "Step 500: Ran 100 train steps in 1.01 secs\n", - "Step 500: train CategoryCrossEntropy | 0.52771598\n", - "Step 500: eval CategoryCrossEntropy | 0.41176467\n", - "Step 500: eval CategoryAccuracy | 0.85000000\n", - "\n", - "Step 600: Ran 100 train steps in 1.03 secs\n", - "Step 600: train CategoryCrossEntropy | 0.54706430\n", - "Step 600: eval CategoryCrossEntropy | 0.61605544\n", - "Step 600: eval CategoryAccuracy | 0.77500000\n", - "\n", - "Step 700: Ran 100 train steps in 1.02 secs\n", - "Step 700: train CategoryCrossEntropy | 0.60464281\n", - "Step 700: eval CategoryCrossEntropy | 0.40039212\n", - "Step 700: eval CategoryAccuracy | 0.86250000\n", - "\n", - "Step 800: Ran 100 train steps in 1.01 secs\n", - "Step 800: train CategoryCrossEntropy | 0.49882782\n", - "Step 800: eval CategoryCrossEntropy | 0.69752997\n", - "Step 800: eval CategoryAccuracy | 0.72500000\n", - "\n", - "Step 900: Ran 100 train steps in 1.03 secs\n", - "Step 900: train CategoryCrossEntropy | 
0.47269714\n", - "Step 900: eval CategoryCrossEntropy | 0.57425045\n", - "Step 900: eval CategoryAccuracy | 0.80625000\n", - "\n", - "Step 1000: Ran 100 train steps in 1.06 secs\n", - "Step 1000: train CategoryCrossEntropy | 0.53420645\n", - "Step 1000: eval CategoryCrossEntropy | 0.58350748\n", - "Step 1000: eval CategoryAccuracy | 0.79375000\n" - ] - } - ], - "source": [ - "import os\n", - "\n", - "model = get_model()\n", - "\n", - "training_loop = training.Loop(model, \n", - " train_task, \n", - " eval_tasks=[eval_task], \n", - " output_dir='./cnn_model')\n", - "\n", - "training_loop.run(1000)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "executionInfo": { - "elapsed": 530, - "status": "ok", - "timestamp": 1607381175378, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "AMhqFx6HbOs_" - }, - "outputs": [], - "source": [ - "" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "Fashion MNIST with Trax.ipynb", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/examples/Knowledge_Tracing_Transformer.ipynb b/trax/examples/Knowledge_Tracing_Transformer.ipynb deleted file mode 100644 index 3e431f757..000000000 --- a/trax/examples/Knowledge_Tracing_Transformer.ipynb +++ /dev/null @@ -1,2126 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Knowledge Tracing Transformer", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "cells": [ - { - "cell_type": "code", - "metadata": { - "id": "eGCe1pjznIQS" - }, - "source": [ - "#@title\n", - "# Copyright 2021 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lAula_PU9jqB" - }, - "source": [ - "## Intro\r\n", - "\r\n", - "This notebook trains a transformer model on the [EdNet dataset](https://github.com/riiid/ednet) using the [google/trax library](https://github.com/google/trax). The EdNet dataset is large set of student responses to multiple choice questions related to English language learning. A recent Kaggle competition, [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction), provided as subset of this data, consisting of 100 million responses to 13 thousand questions from 300 thousand students.\r\n", - "\r\n", - "The state of the art result, detailed in [SAINT+: Integrating Temporal Features for EdNet Correctness Prediction](https://arxiv.org/abs/2010.12042), achieves an AUC ROC of 0.7914. The winning solution in the [Riiid! 
Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction) competition achieved an AUC ROC of 0.820. This notebook achieves an AUC ROC of 0.776 implementing an approach similar to the state of the art approach, training for 25,000 steps. It demonstrates several techniques that may be useful to those getting started with the [google/trax library](https://github.com/google/trax) or deep learning in general. This notebook demonstrates how to:\r\n", - "\r\n", - "* Use BigQuery to perform feature engineering\r\n", - "* Create TFRecords with multiple sequences per record\r\n", - "* Modify the trax Transformer model to accommodate a knowledge tracing dataset:\r\n", - " * Utilize multiple encoder and decoder embeddings - aggregated either by concatenation or sum\r\n", - " * Include a custom metric - AUC ROC\r\n", - " * Utilize a combined padding and future mask\r\n", - "* Use trax's [gin-config](https://github.com/google/gin-config) integration to specify training parameters\r\n", - "* Display training progress using trax's tensorboard integration\r\n", - "\r\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/CalebEverett/riiid_transformer/blob/master/riiid-trax-transformer.ipynb)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "tuG_-VFcpxLc" - }, - "source": [ - "# Choose a location for your storage bucket and BigQuery dataset to minimize data egress charges. Once you have\r\n", - "# created them, if you restart your notebook you can run this to see where your colab is running\r\n", - "# and factory reset until you get a location that is near your data.\r\n", - "!curl ipinfo.io" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_SQN6SX89XNq" - }, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vo5bzc9z7nw_" - }, - "source": [ - "# \r\n", - "!git clone https://github.com/google/trax.git\r\n", - "!pip install ./trax\r\n", - "!pip install -U pyarrow\r\n", - "!pip install -U google-cloud-bigquery google-cloud-bigquery-storage" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "0W7kto2g7Sfa" - }, - "source": [ - "from functools import partial\r\n", - "import json\r\n", - "import math\r\n", - "import os\r\n", - "from pathlib import Path\r\n", - "import subprocess\r\n", - "import sys\r\n", - "import time\r\n", - "\r\n", - "import gin\r\n", - "from google.cloud import storage, bigquery\r\n", - "from google.cloud.bigquery import LoadJobConfig, QueryJobConfig, \\\r\n", - " SchemaField, SourceFormat\r\n", - "import jax\r\n", - "from jax.config import config\r\n", - "import pandas as pd\r\n", - "import numpy as np\r\n", - "import requests\r\n", - "import sqlite3\r\n", - "import trax\r\n", - "from trax import fastmath\r\n", - "from trax import layers as tl\r\n", - "from trax.fastmath import numpy as tnp\r\n", - "import tensorflow as tf\r\n", - "from tqdm.notebook import tqdm\r\n", - "import zipfile\r\n", - "\r\n", - "# Create google credentials and store in drive\r\n", - "# https://colab.research.google.com/drive/1LWhrqE2zLXqz30T0a0JqXnDPKweqd8ET\r\n", - "# \r\n", - "# Create a config.json file with variables for:\r\n", - "# \"BUCKET\": \"\",\r\n", - "# \"BQ_DATASET\": \"\",\r\n", - "# \"KAGGLE_USERNAME\": \"\",\r\n", - "# \"KAGGLE_KEY\": \"\",\r\n", - "# \"PROJECT\": \"\",\r\n", - "# \"LOCATION\": \"\"\r\n", - "from google.colab import drive\r\n", - 
"\r\n", - "DRIVE = Path('/content/drive/My Drive')\r\n", - "PATH = 'riiid-transformer'\r\n", - "\r\n", - "if not DRIVE.exists():\r\n", - " drive.mount(str(DRIVE.parent))\r\n", - "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = str(DRIVE/PATH/'google.json')\r\n", - "\r\n", - "with open(str(DRIVE/PATH/'config.json')) as f:\r\n", - " CONFIG = json.load(f)\r\n", - " os.environ = {**os.environ, **CONFIG}\r\n", - "\r\n", - "from kaggle.api.kaggle_api_extended import KaggleApi\r\n", - "kaggle_api = KaggleApi()\r\n", - "kaggle_api.authenticate()\r\n", - "\r\n", - "AUTO = tf.data.experimental.AUTOTUNE\r\n", - "BUCKET = os.getenv('BUCKET', 'riiid-transformer')\r\n", - "BQ_DATASET = os.getenv('BQ_DATASET', 'my_data')\r\n", - "LOCATION = os.getenv('LOCATION', 'us-central1')\r\n", - "PROJECT = os.getenv('PROJECT', 'fastai-caleb')\r\n", - "\r\n", - "bucket = storage.Client(project=PROJECT).get_bucket(BUCKET)\r\n", - "dataset = bigquery.Dataset(f'{PROJECT}.{BQ_DATASET}')\r\n", - "bq_client = bigquery.Client(project=PROJECT, location=LOCATION)\r\n", - "\r\n", - "%matplotlib inline\r\n", - "from matplotlib import pyplot as plt\r\n", - "\r\n", - "%load_ext tensorboard\r\n", - "\r\n", - "gin.enter_interactive_mode()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vL0eRGAnyK9x" - }, - "source": [ - "## Control Panel" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YaAhPw-zv1la" - }, - "source": [ - "These variables can be set to True to run the code in the sections described or False to skip over them after they have been run for the first time." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "MNrBIpVPyPGX" - }, - "source": [ - "USE_TPU = False\r\n", - "DOWNLOAD_DATASET = False\r\n", - "LOAD_DATA_TO_BQ = False\r\n", - "PERFORM_FEATURE_ENGINEERING = False\r\n", - "TEST_FEATURE_ENGNEERING = False\r\n", - "CREATE_TFRECORDS = False\r\n", - "TEST_TFRECORDS = False\r\n", - "TRAIN_MODEL = False" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "t8Jvva6lBRyI" - }, - "source": [ - "## Initialize TPU" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PsczFYbe80ei" - }, - "source": [ - "if USE_TPU:\r\n", - " if 'TPU_DRIVER_MODE' not in globals():\r\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'\r\n", - " resp = requests.post(url)\r\n", - " TPU_DRIVER_MODE = 1\r\n", - "\r\n", - " config.FLAGS.jax_xla_backend = \"tpu_driver\"\r\n", - " config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\r\n", - " print(config.FLAGS.jax_backend_target)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GXP1CnQXBtzd" - }, - "source": [ - "## Download Dataset" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "YSAnW-bzBzCE" - }, - "source": [ - "if DOWNLOAD_DATASET:\r\n", - " kaggle_api.competition_download_cli('riiid-test-answer-prediction')\r\n", - " with zipfile.ZipFile('riiid-test-answer-prediction.zip', 'r') as zip_ref:\r\n", - " zip_ref.extractall()\r\n", - " for f in ['train.csv', 'questions.csv', 'lectures.csv']:\r\n", - " bucket.blob(f).upload_from_filename(f)\r\n", - "\r\n", - "if False:\r\n", - " for f in tqdm(['train.csv', 'questions.csv', 'lectures.csv']):\r\n", - " bucket.blob(f).download_to_filename(f)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": 
"wM1_VVnm-61P" - }, - "source": [ - "## Create BigQuery Dataset" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "_Eo0iR8Y-5Sv" - }, - "source": [ - "if False:\r\n", - " delete_contents=False\r\n", - " bq_client.delete_dataset(BQ_DATASET, delete_contents=delete_contents)\r\n", - " print(f'Dataset {dataset.dataset_id} deleted from project {dataset.project}.')\r\n", - "\r\n", - "try:\r\n", - " dataset = bq_client.get_dataset(dataset.dataset_id)\r\n", - " print(f'Dataset {dataset.dataset_id} already exists '\r\n", - " f'in location {dataset.location} in project {dataset.project}.')\r\n", - "except:\r\n", - " dataset = bq_client.create_dataset(dataset)\r\n", - " print(f'Dataset {dataset.dataset_id} created '\r\n", - " f'in location {dataset.location} in project {dataset.project}.')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i7tZZN449eH-" - }, - "source": [ - "## Dtypes" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qt70hdhk_j6p" - }, - "source": [ - "dtypes_orig = {\r\n", - " 'lectures': {\r\n", - " 'lecture_id': 'uint16',\r\n", - " 'tag': 'uint8',\r\n", - " 'part': 'uint8',\r\n", - " 'type_of': 'str',\r\n", - " },\r\n", - " 'questions': {\r\n", - " 'question_id': 'uint16',\r\n", - " 'bundle_id': 'uint16',\r\n", - " 'correct_answer': 'uint8',\r\n", - " 'part': 'uint8',\r\n", - " 'tags': 'str',\r\n", - " \r\n", - " },\r\n", - " 'train': {\r\n", - " 'row_id': 'int64',\r\n", - " 'timestamp': 'int64',\r\n", - " 'user_id': 'int32',\r\n", - " 'content_id': 'int16',\r\n", - " 'content_type_id': 'int8',\r\n", - " 'task_container_id': 'int16',\r\n", - " 'user_answer': 'int8',\r\n", - " 'answered_correctly': 'int8',\r\n", - " 'prior_question_elapsed_time': 'float32', \r\n", - " 'prior_question_had_explanation': 'bool'\r\n", - " }\r\n", - " \r\n", - "}\r\n", - "\r\n", - "dtypes_new = {\r\n", - " 'lectures': {},\r\n", - " 'questions': {\r\n", - " 'tags_array': 'str'\r\n", - " },\r\n", - " 'train': {\r\n", - " 'task_container_id_q': 'int16',\r\n", - " 'pqet_current': 'int32',\r\n", - " 'ts_delta': 'int32'\r\n", - " }\r\n", - "}\r\n", - "\r\n", - "dtypes = {}\r\n", - "for table_id in dtypes_orig:\r\n", - " dtypes[table_id] = {\r\n", - " **dtypes_orig[table_id],\r\n", - " **dtypes_new[table_id]\r\n", - " }" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zYIOhHEoDw-v" - }, - "source": [ - "### Big Query Table Schemas" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "q1LEgqZfDulc" - }, - "source": [ - "# \r\n", - "type_map = {\r\n", - " 'int64': 'INT64',\r\n", - " 'int32': 'INT64',\r\n", - " 'int16': 'INT64',\r\n", - " 'int8': 'INT64',\r\n", - " 'uint8': 'INT64',\r\n", - " 'uint16': 'INT64',\r\n", - " 'str': 'STRING',\r\n", - " 'bool': 'BOOL',\r\n", - " 'float32': 'FLOAT64'\r\n", - "}\r\n", - "\r\n", - "schemas_orig = {table: [SchemaField(f, type_map[t]) for f, t in\r\n", - " fields.items()] for table, fields in dtypes_orig.items()}\r\n", - "\r\n", - "schemas = {}\r\n", - "for table_id, fields in dtypes_new.items():\r\n", - " new_fields = [SchemaField(f, type_map[t]) for\r\n", - " f, t in fields.items() if 'array' not in f]\r\n", - " \r\n", - " new_array_feilds = [SchemaField(f, 'INT64', 'REPEATED') for\r\n", - " f, t in fields.items() if 'array' in f]\r\n", - "\r\n", - " new_fields += new_array_feilds\r\n", - "\r\n", - " schemas[table_id] = schemas_orig[table_id] + new_fields" - ], - "execution_count": null, - "outputs": [] - }, - { - 
"cell_type": "markdown", - "metadata": { - "id": "sv7wPwp2EJpH" - }, - "source": [ - "### Load Tables" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "EtBgHrBvC_H3" - }, - "source": [ - "def load_job_cb(future):\r\n", - " \"\"\"Prints update upon completion to output of last run cell.\"\"\"\r\n", - " \r\n", - " seconds = (future.ended - future.created).total_seconds()\r\n", - " print(f'Loaded {future.output_rows:,d} rows to table {future.job_id.split(\"_\")[0]} in '\r\n", - " f'{seconds:>4,.1f} sec, {int(future.output_rows / seconds):,d} per sec.')\r\n", - "\r\n", - "def load_csv_from_uri(table_id, schemas_orig):\r\n", - " full_table_id = f'{BQ_DATASET}.{table_id}'\r\n", - "\r\n", - " job_config = LoadJobConfig(\r\n", - " schema=schemas_orig[table_id],\r\n", - " source_format=SourceFormat.CSV,\r\n", - " skip_leading_rows=1\r\n", - " )\r\n", - "\r\n", - " uri = f'gs://{BUCKET}/{table_id}.csv'\r\n", - " load_job = bq_client.load_table_from_uri(uri, full_table_id,\r\n", - " job_config=job_config,\r\n", - " job_id_prefix=f'{table_id}_')\r\n", - " print(f'job {load_job.job_id} started')\r\n", - " load_job.add_done_callback(load_job_cb)\r\n", - " \r\n", - " return load_job" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "L44_o0NYEOcC" - }, - "source": [ - "if LOAD_DATA_TO_BQ:\r\n", - " for table_id in dtypes_orig:\r\n", - " lj = load_csv_from_uri(table_id, schemas_orig).result()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JUAg3Pz5ImSx" - }, - "source": [ - "### Update BiqQuery Schemas" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ym62FdoNgU8t" - }, - "source": [ - "Before performing feature engineering, we have to update the table schemas in Big Query to create columns for the new features." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qByuVM7MIr8b" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " for table_id, schema in schemas.items():\r\n", - " table = bq_client.get_table(f'{BQ_DATASET}.{table_id}')\r\n", - " table.schema = schema\r\n", - " table = bq_client.update_table(table, ['schema'])" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oCHq9dJiFOPh" - }, - "source": [ - "## Feature Engineering" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d6tN2qREdc9A" - }, - "source": [ - "Using BigQuery for a dataset of 100 million rows is much faster than using local dataframes. 
In addition, you get to use the full power of SQL, including [window functions](https://cloud.google.com/bigquery/docs/reference/standard-sql/analytic-function-concepts), which are especially useful for time series feature engineering.\r\n", - "\r\n", - "Feature engineering for this problem is fairly minimal and includes:\r\n", - "* Replacing null values for `prior_question_elapsed_time` and `prior_question_had_explanation` in the train table\r\n", - "* Replacing one missing tag value in the questions table\r\n", - "* Recalculating the `task_container_id` as `task_container_id_q` so that it excludes lecture records and increases monotonically with `timestamp`, ensuring that the calculations for elapsed time and time delta, which depend on values from the immediately prior and immediately succeeding records, are performed correctly.\r\n", - "* Calculating `pqet_current`, the time it took on average to answer the questions in the current `task_container_id_q`.\r\n", - "* Calculating `ts_delta`, the elapsed time between the last `task_container_id_q` and the current one.\r\n", - "* Creating a `folds` table, in which users are assigned to one of 20 folds.\r\n", - "* Creating a `tags_array` field in the questions table that holds an array of six elements populated with the tags assigned to each question, padded with zeros if there are fewer than six." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "X2ynWZqPFnSj" - }, - "source": [ - "def done_cb(future):\r\n", - " seconds = (future.ended - future.started).total_seconds()\r\n", - " print(f'Job {future.job_id} finished in {seconds} seconds.')\r\n", - "\r\n", - "def run_query(query, job_id_prefix=None, wait=True,\r\n", - " use_query_cache=True):\r\n", - "\r\n", - " job_config = QueryJobConfig(\r\n", - " use_query_cache=use_query_cache)\r\n", - "\r\n", - " query_job = bq_client.query(query, job_id_prefix=job_id_prefix,\r\n", - " job_config=job_config)\r\n", - " print(f'Job {query_job.job_id} started.')\r\n", - " query_job.add_done_callback(done_cb)\r\n", - " if wait:\r\n", - " query_job.result()\r\n", - " \r\n", - " return query_job" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "8Qo21D1ITicH" - }, - "source": [ - "def get_df_query_bqs(query, dtypes=None, fillna=None):\r\n", - " qj = bq_client.query(query)\r\n", - " df = qj.to_dataframe(create_bqstorage_client=True, progress_bar_type='tqdm_notebook')\r\n", - " if fillna is not None:\r\n", - " df = df.fillna(fillna)\r\n", - " try:\r\n", - " df = df.astype({c: dtypes.get(c, 'int32') for c in df.columns}) \r\n", - " except:\r\n", - " print('dtypes not applied.')\r\n", - " finally: \r\n", - " return df" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N71-o9uQFSzV" - }, - "source": [ - "### Replace Missing Values" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "rnBL1HXxFWKX" - }, - "source": [ - "def update_missing_values(table_id='train', column_id=None, value=None):\r\n", - " return f\"\"\"\r\n", - " UPDATE {BQ_DATASET}.{table_id}\r\n", - " SET {column_id} = {value}\r\n", - " WHERE {column_id} IS NULL;\r\n", - " \"\"\", sys._getframe().f_code.co_name + '_'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "e0qBG2XrGIMB" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " qj = run_query(*update_missing_values('train', 'prior_question_elapsed_time', '0'))\r\n", - " qj = 
run_query(*update_missing_values('train', 'prior_question_had_explanation', 'false'))\r\n", - " qj = run_query(*update_missing_values('questions', 'tags', '\"188\"'))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "elZXRogqL-pr" - }, - "source": [ - "### Recalculate Task Container Ids for Questions Only" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Li0UdfY2MeOm" - }, - "source": [ - "def update_task_container_id(table_id='train',\r\n", - " column_id='task_container_id',\r\n", - " excl_lectures=True):\r\n", - " excl_lec = 'WHERE content_type_id = 0' if excl_lectures else ''\r\n", - " \r\n", - " return f\"\"\"\r\n", - " UPDATE {BQ_DATASET}.{table_id} t\r\n", - " SET {column_id} = target.calc\r\n", - " FROM (\r\n", - " SELECT row_id, DENSE_RANK()\r\n", - " OVER (\r\n", - " PARTITION BY user_id\r\n", - " ORDER BY timestamp\r\n", - " ) calc\r\n", - " FROM {BQ_DATASET}.{table_id}\r\n", - " {excl_lec}\r\n", - " ) target\r\n", - " WHERE target.row_id = t.row_id\r\n", - " \"\"\", sys._getframe().f_code.co_name + '_'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "FGFisFdpMtGy" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " q = update_task_container_id(table_id='train',\r\n", - " column_id='task_container_id_q ',\r\n", - " excl_lectures=True)\r\n", - " qj = run_query(*q)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2HblfPhCG618" - }, - "source": [ - "### Calculate Current Question Elapsed Time and Timestamp Delta" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "29-ajMgjHEUl" - }, - "source": [ - "def update_pqet_current(table_id='train'):\r\n", - " return f\"\"\"\r\n", - " UPDATE {BQ_DATASET}.{table_id} t\r\n", - " SET t.pqet_current = CAST(p.pqet_current AS INT64)\r\n", - " FROM (\r\n", - " SELECT\r\n", - " row_id, LAST_VALUE(prior_question_elapsed_time) OVER (\r\n", - " PARTITION BY user_id ORDER BY task_container_id_q\r\n", - " RANGE BETWEEN 1 FOLLOWING AND 1 FOLLOWING) pqet_current\r\n", - " FROM {BQ_DATASET}.train \r\n", - " WHERE content_type_id = 0\r\n", - " ) p\r\n", - " WHERE t.row_id = p.row_id;\r\n", - " \r\n", - " UPDATE {BQ_DATASET}.{table_id}\r\n", - " SET pqet_current = 0\r\n", - " WHERE pqet_current IS NULL;\r\n", - " \r\n", - " \"\"\", sys._getframe().f_code.co_name + '_'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "juf9vDrzIF2W" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " qj = run_query(*update_pqet_current())" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "E9LnKsjVLgRk" - }, - "source": [ - "def update_ts_delta(table_id='train'):\r\n", - " return f\"\"\"\r\n", - " UPDATE {BQ_DATASET}.{table_id} t\r\n", - " SET t.ts_delta = timestamp - p.ts_prior\r\n", - " FROM (\r\n", - " SELECT\r\n", - " row_id, LAST_VALUE(timestamp) OVER (\r\n", - " PARTITION BY user_id ORDER BY task_container_id_q\r\n", - " RANGE BETWEEN 1 PRECEDING AND 1 PRECEDING) ts_prior\r\n", - " FROM {BQ_DATASET}.train \r\n", - " WHERE content_type_id = 0\r\n", - " ) p\r\n", - " WHERE t.row_id = p.row_id;\r\n", - " \r\n", - " UPDATE {BQ_DATASET}.{table_id}\r\n", - " SET ts_delta = 0\r\n", - " WHERE ts_delta IS NULL;\r\n", - " \"\"\", sys._getframe().f_code.co_name + '_'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - 
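The two window-function UPDATEs above are easy to misread, so here is a small pandas cross-check of their semantics on a toy frame, with one row per task container purely to keep the sketch short: `pqet_current` pulls `prior_question_elapsed_time` back from the *next* container, while `ts_delta` is the gap to the *previous* one.

```python
# Toy cross-check of the LAST_VALUE windows above (one row per container).
import pandas as pd

df = pd.DataFrame({
    'user_id': [1, 1, 1],
    'task_container_id_q': [1, 2, 3],
    'timestamp': [0, 25000, 70000],
    'prior_question_elapsed_time': [0, 20000, 35000],
})

# RANGE BETWEEN 1 FOLLOWING AND 1 FOLLOWING -> next container's value.
df['pqet_current'] = (df.groupby('user_id')['prior_question_elapsed_time']
                        .shift(-1).fillna(0).astype('int64'))

# timestamp minus RANGE BETWEEN 1 PRECEDING AND 1 PRECEDING -> gap to prior.
df['ts_delta'] = (df.groupby('user_id')['timestamp']
                    .diff().fillna(0).astype('int64'))

print(df[['pqet_current', 'ts_delta']].values.tolist())
# [[20000, 0], [35000, 25000], [0, 45000]]
```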
"metadata": { - "id": "0-CEUJsoL1dC" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " qj = run_query(*update_ts_delta())" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "99qz0H8Xb3i1" - }, - "source": [ - "### Create Folds Table\r\n", - "Assign users randomly to one of 20 folds. Store total records to facilitate filtering based on record count." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "N7UnAHF8cesC" - }, - "source": [ - "def create_table_folds(table_id='folds', n_folds=20):\r\n", - " return f\"\"\"\r\n", - " DECLARE f INT64;\r\n", - "\r\n", - " CREATE OR REPLACE TABLE {BQ_DATASET}.{table_id} (\r\n", - " user_id INT64,\r\n", - " fold INT64,\r\n", - " record_count INT64\r\n", - " );\r\n", - "\r\n", - " INSERT {BQ_DATASET}.{table_id} (user_id, fold, record_count)\r\n", - " SELECT f.user_id, CAST(FLOOR(RAND() * {n_folds}) AS INT64) fold, f.record_count\r\n", - " FROM (\r\n", - " SELECT user_id,\r\n", - " COUNT(row_id) record_count\r\n", - " FROM {BQ_DATASET}.train\r\n", - " WHERE content_type_id = 0\r\n", - " GROUP BY user_id\r\n", - " ) f\r\n", - " ORDER BY user_id;\r\n", - " \"\"\", sys._getframe().f_code.co_name + '_'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "SVPio880dPSe" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " qj = run_query(*create_table_folds())" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "14dQwOnzdolg" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " df_folds = get_df_query_bqs(f\"\"\"\r\n", - " SELECT *\r\n", - " FROM {BQ_DATASET}.folds\r\n", - " \"\"\",\r\n", - " dtypes=dtypes)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "R9Y9Xhwee6f7" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " df_folds.groupby('fold').count().user_id.plot(kind='bar', title='Count of Users by Fold');" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "qciROEyIoowx" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " df_folds.groupby('fold').mean().record_count.plot(kind='bar', title='Average Records per User by Fold');" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "q5zS5bWenJaj" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " df_fold_ac = get_df_query_bqs(f\"\"\"\r\n", - " SELECT fold, SUM(answered_correctly) ac_sum, COUNT(answered_correctly) rec_count\r\n", - " FROM {BQ_DATASET}.train\r\n", - " JOIN {BQ_DATASET}.folds\r\n", - " ON train.user_id = folds.user_id\r\n", - " GROUP BY fold\r\n", - " \"\"\",\r\n", - " dtypes=dtypes)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "3TGelsfEn7xf" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " df_fold_ac.rec_count.plot(kind='bar', title='Count of Records by Fold');" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "a1kcjfOkoGU_" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " (df_fold_ac.ac_sum / df_fold_ac.rec_count).plot(kind='bar', title='Percent Answered Correctly by Fold');" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Nqg9tqEnOa7l" - }, - "source": [ - "### Create Tags Array on Questions 
Table\r\n", - "We need the tags as an array later when we create TFRecords. We also increment by one and pad with zeros to a fixed length of 6 so that they can be concatentated as a feature for modeling." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "wr4peSSWPHiW" - }, - "source": [ - "def update_tags_array(table_id='questions', column_id='tags_array'):\r\n", - " \r\n", - " return f\"\"\"\r\n", - " UPDATE {BQ_DATASET}.{table_id} q\r\n", - " SET {column_id} = tp.tags_fixed_len\r\n", - " FROM (\r\n", - " WITH tags_padded AS (\r\n", - " WITH tags_table AS (SELECT question_id, tags FROM {BQ_DATASET}.{table_id})\r\n", - " SELECT question_id, ARRAY_CONCAT(ARRAY_AGG(CAST(tag AS INT64) + 1), [0,0,0,0,0]) tags_array\r\n", - " FROM tags_table, UNNEST(SPLIT(tags, ' ')) as tag\r\n", - " GROUP BY question_id\r\n", - " )\r\n", - " SELECT question_id,\r\n", - " ARRAY(SELECT x FROM UNNEST(tags_array) AS x WITH OFFSET off WHERE off < 6 ORDER BY off) tags_fixed_len\r\n", - " FROM tags_padded\r\n", - " ) tp\r\n", - " WHERE tp.question_id = q.question_id\r\n", - " \"\"\", sys._getframe().f_code.co_name + '_'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "exf_kIXuRagG" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " qj = run_query(*update_tags_array())" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "pB2YEqhISnRz" - }, - "source": [ - "if PERFORM_FEATURE_ENGINEERING:\r\n", - " df_q = get_df_query_bqs('select * from my_data.questions', dtypes=dtypes)\r\n", - " print(df_q.head())" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BrxKtvosXTfw" - }, - "source": [ - "## Feature Engineering Tests\r\n", - "* Features come back out of Biq Query with the same values they went in with\r\n", - "* `ts_delta` is equal to difference between timestamps on consecutive records\r\n", - "* `pqet_current` is equal to `prior_question_elapsed_time` from next record\r\n", - "* visually inspect distributions of `ts_delta` and `pqet_current`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ynTomTnKY7F3" - }, - "source": [ - "### Load Sample from train.csv" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "c9aNhgV7Yfzw" - }, - "source": [ - "if TEST_FEATURE_ENGNEERING:\r\n", - " df_train_samp = pd.read_csv('train.csv', nrows=100000)\r\n", - " df_train_samp.prior_question_had_explanation = df_train_samp.prior_question_had_explanation.fillna(False).astype(bool)\r\n", - " df_train_samp.prior_question_elapsed_time = df_train_samp.prior_question_elapsed_time.fillna(0)\r\n", - " user_ids_samp = df_train_samp.user_id.unique()[:-1]\r\n", - " print(len(user_ids_samp))\r\n", - " df_train_samp = df_train_samp[df_train_samp.user_id.isin(user_ids_samp) & (df_train_samp.content_type_id == 0)].reset_index(drop=True)\r\n", - " print(len(df_train_samp))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "elJgSDpmY0q0" - }, - "source": [ - "### Pull sample of corresponding user_ids from BigQuery" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "GvQF1yIzZPHE" - }, - "source": [ - "if TEST_FEATURE_ENGNEERING:\r\n", - " df_bq_samp = get_df_query_bqs(f\"\"\"\r\n", - " SELECT *\r\n", - " FROM {BQ_DATASET}.train\r\n", - " WHERE user_id IN ({(',').join(map(str, user_ids_samp))})\r\n", - " AND content_type_id = 0\r\n", - " ORDER BY user_id, timestamp, 
row_id\r\n", - " \"\"\",\r\n", - " dtypes=None)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gP3mDb-phsSt" - }, - "source": [ - "### Tests" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "TukjFI4YatpD" - }, - "source": [ - "if TEST_FEATURE_ENGNEERING:\r\n", - " # values in columns are the same between train.csv and bq\r\n", - " for c in df_train_samp.columns:\r\n", - " assert all(df_train_samp[c] == df_bq_samp[c]), f'{c} is not the same'\r\n", - "\r\n", - " # pqet_current pulls prior_question_elapsed_time back one task_container_id for each user\r\n", - " df_bq_samp_tst = df_bq_samp[['user_id', 'task_container_id_q', 'prior_question_elapsed_time', 'pqet_current']].groupby(['user_id', 'task_container_id_q']).max()\r\n", - "\r\n", - " for user_id in user_ids_samp:\r\n", - " assert all(df_bq_samp_tst.loc[user_id].pqet_current.shift(1).iloc[1:] == df_bq_samp_tst.loc[user_id].prior_question_elapsed_time.iloc[1:])\r\n", - "\r\n", - " # ts_delta equal to timestamp from current task_container_id_q minus timestamp from prior task_container_id_q\r\n", - " df_bq_samp_tst = df_bq_samp[['user_id', 'task_container_id_q', 'timestamp', 'ts_delta']].groupby(['user_id', 'task_container_id_q']).max()\r\n", - "\r\n", - " for user_id in user_ids_samp:\r\n", - " assert all((df_bq_samp_tst.loc[user_id].timestamp - df_bq_samp_tst.loc[user_id].timestamp.shift(1)).iloc[1:] == df_bq_samp_tst.loc[user_id].ts_delta.iloc[1:])" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "jZhQStC4bIa_" - }, - "source": [ - "if TEST_FEATURE_ENGNEERING:\r\n", - " df_bq_samp.pqet_current.hist();" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "4WrlU_qZbisC" - }, - "source": [ - "if TEST_FEATURE_ENGNEERING:\r\n", - " df_bq_samp.ts_delta.hist();" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qR33_34dpD5R" - }, - "source": [ - "## Create TFRecords" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NjZunnJNqG7U" - }, - "source": [ - "We are going to create a set of TFRecords with one user per record and one fold per file. 
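Since the writer shown a little further down names each file `{fold:02d}-{n_users}.tfrec`, selecting folds later reduces to building file patterns. A sketch, assuming `BUCKET` from the setup cell and the 19-train/1-eval fold split used in the gin configuration at the end:

```python
# Sketch: one fold per file makes train/eval selection a file-pattern concern.
# Assumes BUCKET is the storage bucket name defined in the setup cell.
train_pats = [f'gs://{BUCKET}/tfrec/{f:02d}-*.tfrec' for f in range(19)]
eval_pats = [f'gs://{BUCKET}/tfrec/{f:02d}-*.tfrec' for f in (19,)]
```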
We are going to include the following columns as features:\r\n", - "* `user_id` - this won't get used as a feature, but is included to be able to tie back to the original data\r\n", - "* `content_id` - incremented by one to reserve 0 for the padding character\r\n", - "* `answered_correctly` - incremented by one to reserve 0 for the padding character\r\n", - "* `part`\r\n", - "* `pqet_current`\r\n", - "* `ts_delta`\r\n", - "* `tags` - already incremented by one with zeros as padding\r\n", - "* `task_container_id` - excluding lectures and already indexed to one\r\n", - "* `timestamp`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "e3IhNw5C3epG" - }, - "source": [ - "def _int64_feature(value):\r\n", - " \r\n", - " if type(value) != type(list()):\r\n", - " value = [value]\r\n", - "\r\n", - " return tf.train.Feature(int64_list=tf.train.Int64List(value=value))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "jBv7Qtjb3gr-" - }, - "source": [ - "def serialize_example(user_id, features):\r\n", - " \r\n", - " feature_names = ['content_id', 'answered_correctly', 'part', 'pqet_current', 'ts_delta', 'tags',\r\n", - " 'task_container_id', 'timestamp']\r\n", - " \r\n", - " feature = {'user_id': _int64_feature(user_id)}\r\n", - " \r\n", - " for i, n in enumerate(feature_names):\r\n", - " feature[n] = _int64_feature(features[i])\r\n", - "\r\n", - " return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "8sG6HK_p3imX" - }, - "source": [ - "def parse_example(example):\r\n", - " \r\n", - " feature_names = {'content_id': tf.int32, 'answered_correctly': tf.int32, 'part': tf.int32,\r\n", - " 'pqet_current': tf.int32, 'ts_delta': tf.int64, 'tags': tf.int32,\r\n", - " 'task_container_id': tf.int32, 'timestamp': tf.int64}\r\n", - " \r\n", - " features = {'user_id': tf.io.FixedLenFeature([1], tf.int64)}\r\n", - " \r\n", - " for k, v in feature_names.items():\r\n", - " features[k] = tf.io.VarLenFeature(tf.int64)\r\n", - "\r\n", - " example = tf.io.parse_single_example(example, features)\r\n", - "\r\n", - " for k, v in feature_names.items():\r\n", - " example[k] = tf.cast(example[k].values, v)\r\n", - " \r\n", - " example['tags'] = tf.reshape(example['tags'], (tf.size(example['answered_correctly']), 6))\r\n", - "\r\n", - " return example" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Nh-44zO158H8" - }, - "source": [ - "def get_ds_tfrec_raw(folds=[0]):\r\n", - " file_pat = 'gs://{BUCKET}/tfrec/{f:02d}-*.tfrec'\r\n", - " file_pats = [file_pat.format(BUCKET=BUCKET, f=f) for f in folds]\r\n", - " options = tf.data.Options()\r\n", - "\r\n", - " ds = (tf.data.Dataset.list_files(file_pats)\r\n", - " .with_options(options)\r\n", - " .interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO)\r\n", - " .map(parse_example, num_parallel_calls=AUTO)\r\n", - " )\r\n", - " \r\n", - " return ds" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NEfOS7m2pKZp" - }, - "source": [ - "def get_df_tfrec(folds):\r\n", - " df_tfrec = get_df_query_bqs(f\"\"\"\r\n", - " SELECT fold, train.user_id, content_id + 1 content_id,\r\n", - " answered_correctly + 1 answered_correctly, part, pqet_current, ts_delta,\r\n", - " tags_array tags, task_container_id_q task_container_id, timestamp\r\n", - " FROM {BQ_DATASET}.train\r\n", - " JOIN 
{BQ_DATASET}.folds\r\n", - " ON train.user_id = folds.user_id\r\n", - " JOIN {BQ_DATASET}.questions\r\n", - " ON train.content_id = questions.question_id\r\n", - " WHERE fold IN ({(', ').join(map(str, folds))})\r\n", - " AND content_type_id = 0\r\n", - " ORDER BY user_id, timestamp, row_id\r\n", - " \"\"\",\r\n", - " dtypes=None)\r\n", - "\r\n", - " return df_tfrec" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "HoYPoSKR3V5-" - }, - "source": [ - "def write_tfrecords(folds):\r\n", - " \r\n", - " df_tfrec = get_df_tfrec(folds)\r\n", - " \r\n", - " for f in folds:\r\n", - " groups_dict = (df_tfrec[df_tfrec.fold == f]\r\n", - " .groupby('user_id')\r\n", - " .apply(lambda r: (list(r['content_id'].values),\r\n", - " list(r['answered_correctly'].values),\r\n", - " list(r['part'].values),\r\n", - " list(r['pqet_current'].values.astype(np.int64)),\r\n", - " list(r['ts_delta'].values.astype(np.int64)),\r\n", - " list(np.concatenate(r['tags'].values)),\r\n", - " list(r['task_container_id'].values.astype(np.int64)),\r\n", - " list(r['timestamp'].values.astype(np.int64)),\r\n", - " ))).to_dict() \r\n", - " \r\n", - " out_path = f'gs://{BUCKET}/tfrec'\r\n", - " filename = f'{f:02d}-{len(groups_dict.keys())}.tfrec'\r\n", - " record_file = f'{out_path}/{filename}'\r\n", - "\r\n", - " with tf.io.TFRecordWriter(record_file) as writer:\r\n", - " for user_id, features in tqdm(groups_dict.items(), desc=f'Fold {f:02d}'):\r\n", - " writer.write(serialize_example(user_id, features))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "48jjQ7L9g1_M" - }, - "source": [ - "## Write TFRecords\r\n", - "\r\n", - "* Process in chunks to avoid running out of memory." 
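Before committing to a multi-hour write of 100 million rows, it is cheap to sanity-check the serialize/parse pair on made-up values. A sketch that assumes the `serialize_example` and `parse_example` cells above have been run:

```python
# Toy round-trip check for the serialize/parse pair defined earlier
# (all values below are made up; six tag slots per question).
record = serialize_example(
    42,                                     # user_id
    ([5, 9],                                # content_id
     [2, 1],                                # answered_correctly (plus one)
     [1, 1],                                # part
     [0, 21000],                            # pqet_current
     [0, 60000],                            # ts_delta
     [1, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0],  # tags, flattened
     [1, 2],                                # task_container_id
     [0, 60000]))                           # timestamp

example = parse_example(record)
assert example['tags'].shape == (2, 6)      # reshaped to (n_questions, 6)
assert list(example['content_id'].numpy()) == [5, 9]
```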
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Mpw-Nb7Dg8xL" - }, - "source": [ - "if CREATE_TFRECORDS:\r\n", - " fold_splits = np.array_split(np.arange(20), 10)\r\n", - " for folds in tqdm(fold_splits):\r\n", - " write_tfrecords(folds)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Un9OMMVC4QQQ" - }, - "source": [ - "## Test TFRecords\r\n", - "\r\n", - "* Same number of users and records as in `df_folds`\r\n", - "* Values in tfrecords are the same as in original data" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eIGmrsU28TJ4" - }, - "source": [ - "def test_tfrecord_folds(folds_test, n_sample=100):\r\n", - " pbar = tqdm(total=n_sample)\r\n", - " ds = get_ds_tfrec_raw(folds_test)\r\n", - " df = get_df_tfrec(folds_test)\r\n", - "\r\n", - " for b in ds.shuffle(10000).take(n_sample):\r\n", - " try:\r\n", - " for c in [c for c in df.columns if c not in ['tags', 'fold', 'user_id']]:\r\n", - " try:\r\n", - " assert all(df[df.user_id == b['user_id'].numpy()[0]][c] == b[c].numpy())\r\n", - " except:\r\n", - " print(f\"Error for user {b['user_id'].numpy()[0]}\")\r\n", - " user_tags = np.concatenate(df[df.user_id == b['user_id'].numpy()[0]].tags.values)\r\n", - " assert all(user_tags == (b['tags'].numpy().flatten()))\r\n", - " except:\r\n", - " print(f\"Error for user {b['user_id'].numpy()[0]}\")\r\n", - " finally:\r\n", - " pbar.update()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "ByaOx2_23BJJ" - }, - "source": [ - "if TEST_TFRECORDS:\r\n", - " folds_test = list(range(20))\r\n", - " ds = get_ds_tfrec_raw(folds=folds_test)\r\n", - "\r\n", - " df_folds = get_df_query_bqs(f\"\"\"\r\n", - " SELECT *\r\n", - " FROM {BQ_DATASET}.folds\r\n", - " \"\"\",\r\n", - " dtypes=dtypes)\r\n", - "\r\n", - " user_ids = []\r\n", - " count = 0\r\n", - " for b in ds:\r\n", - " user_ids.append(b['user_id'].numpy()[0])\r\n", - " count += len(b['content_id'].numpy())\r\n", - "\r\n", - " assert len(set(user_ids)) == len(df_folds)\r\n", - " assert df_folds.record_count.sum() == count\r\n", - "\r\n", - " test_tfrecord_folds([10])\r\n", - "\r\n", - " b = next(iter(ds))\r\n", - " print(b)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OMfgCXo159d2" - }, - "source": [ - "## Dataset Functions" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "TcQ-dd1b6EKN" - }, - "source": [ - "@gin.configurable\r\n", - "def get_ds_tfrec(folds=None, max_len=None, min_len=None):\r\n", - " file_pat = 'gs://{BUCKET}/tfrec/{f:02d}-*.tfrec'\r\n", - " file_pats = [file_pat.format(BUCKET=BUCKET, f=f) for f in folds]\r\n", - " options = tf.data.Options()\r\n", - "\r\n", - " ds = (tf.data.Dataset.list_files(file_pats, shuffle=True)\r\n", - " .with_options(options)\r\n", - " .interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTO)\r\n", - " .shuffle(10000)\r\n", - " .map(parse_example, num_parallel_calls=AUTO)\r\n", - " .filter(partial(filter_min_len, min_len=min_len))\r\n", - " .map(example_to_tuple, num_parallel_calls=AUTO)\r\n", - " .map(partial(trunc_seq, max_len=max_len), num_parallel_calls=AUTO)\r\n", - " .map(con_to_cat, num_parallel_calls=AUTO)\r\n", - " )\r\n", - "\r\n", - " ds = ds.repeat().prefetch(AUTO)\r\n", - " \r\n", - " def gen(generator=None):\r\n", - " del generator\r\n", - " for example in fastmath.dataset_as_numpy(ds):\r\n", - " yield example\r\n", - " \r\n", - " return gen" - ], - "execution_count": null, 
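`get_ds_tfrec` above ends by wrapping the repeating `tf.data` pipeline in a plain Python generator, because trax's training loop consumes generators of numpy values rather than `tf.data` datasets. A stripped-down sketch of just that bridge, on a toy dataset:

```python
# Stripped-down version of the tf.data -> trax bridge used in get_ds_tfrec.
import tensorflow as tf
from trax import fastmath

ds = tf.data.Dataset.from_tensor_slices(
    (tf.range(4), tf.range(4) * 10)).repeat()

def gen(generator=None):
    del generator  # trax passes a source generator; it is not needed here
    for example in fastmath.dataset_as_numpy(ds):
        yield example  # tuples of numpy values, one per dataset component

print(next(gen()))  # (0, 0)
```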
- "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "l20hHtYADfGp" - }, - "source": [ - "def filter_min_len(e, min_len):\r\n", - " return tf.size(e['content_id']) >= min_len" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "2TcKGF8t70Gl" - }, - "source": [ - "def example_to_tuple(example):\r\n", - " return (example['content_id'], example['part'], example['tags'], example['task_container_id'],\r\n", - " example['answered_correctly'], example['pqet_current'], example['ts_delta'])" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "5PjKM5Q57uS5" - }, - "source": [ - "def trunc_seq(*b, max_len=None):\r\n", - " \"\"\"Returns a sequence drawn randomly from available tokens with a max length\r\n", - " of max_len.\r\n", - " \"\"\"\r\n", - " \r\n", - " max_len = tf.constant(max_len)\r\n", - " seq_len = tf.size(b[0])\r\n", - " seq_end_min = tf.minimum(seq_len - 1, max_len)\r\n", - " seq_end = tf.maximum(max_len, tf.random.uniform((), seq_end_min, seq_len, dtype=tf.int32))\r\n", - " \r\n", - " def get_seq(m):\r\n", - " return m[seq_end-max_len:seq_end]\r\n", - " \r\n", - " return tuple(map(get_seq, b))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "9aC7H12W7mRs" - }, - "source": [ - "# SAINT+ Elapsed Time = prior_question_elapsed_time and Lag Time = time_stamp_1 - timestamp_0\r\n", - "# Elapsed Time categorical - capped at 300 seconds, discrete value for each second\r\n", - "# Lag Time - discretized to minutes 0, 1, 2, 3, 4, 5, 10, 20, 30 ... 1440. 150 discrete values.\r\n", - "\r\n", - "ts_delta_lookup = tf.concat([tf.range(6, dtype=tf.int32), tf.repeat(5, 5)], axis=0)\r\n", - "\r\n", - "cat = 10\r\n", - "while cat < 1440:\r\n", - " ts_delta_lookup = tf.concat([ts_delta_lookup, tf.repeat(cat, 10)], axis=0)\r\n", - " cat += 10\r\n", - " \r\n", - "ts_delta_lookup = tf.concat([ts_delta_lookup, [1440]], axis=0)\r\n", - "\r\n", - "def con_to_cat(*b):\r\n", - " \r\n", - " def pqet_cat(e, vocab_size=None, val_min=None, val_max=None):\r\n", - " e = tf.clip_by_value(e, val_min, val_max)\r\n", - " val_range = val_max - val_min\r\n", - " e = tf.cast((e - val_min) * (vocab_size - 1) / val_range, tf.int32)\r\n", - " return e\r\n", - " \r\n", - " def ts_delta_cat(e):\r\n", - " val_max = tf.cast(tf.reduce_max(ts_delta_lookup) * 60000, tf.float64)\r\n", - " e = tf.clip_by_value(tf.cast(e, tf.float64), 0, val_max)\r\n", - " e = tf.cast(e / 60000, tf.int32)\r\n", - " e = tf.gather(ts_delta_lookup, e)\r\n", - " return e\r\n", - " \r\n", - " pqet = pqet_cat(b[-2], vocab_size=300, val_min=0, val_max=300000)\r\n", - " ts_delta = ts_delta_cat(b[-1])\r\n", - " \r\n", - " return tuple((*b[:-2], pqet, ts_delta))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cWKX9-WHNIdJ" - }, - "source": [ - "## Metrics Functions" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BUC43hA69diL" - }, - "source": [ - "def RocAucScore(num_thresholds=100, pos_label=2):\r\n", - " def f(y_score, y_true, weight): \r\n", - " weight = tnp.expand_dims(tnp.ravel(weight), -1)\r\n", - " \r\n", - " softmax=tl.Softmax(axis=-1)\r\n", - " y_score = tnp.ravel(softmax(y_score)[:, :, -1])\r\n", - " y_score = tnp.expand_dims(y_score, -1)\r\n", - " y_true = tnp.expand_dims(tnp.ravel(y_true) == pos_label, -1).astype(tnp.float32)\r\n", - " \r\n", - " thresholds = tnp.expand_dims(tnp.linspace(1, 0, 
num_thresholds), 0)\r\n", - " \r\n", - " threshold_counts = y_score > thresholds\r\n", - " \r\n", - " tps = tnp.logical_and(threshold_counts, y_true)\r\n", - " fps = tnp.logical_and(threshold_counts, tnp.logical_not(y_true))\r\n", - " \r\n", - " tps = tnp.sum(tps * weight, axis=0)\r\n", - " fps = tnp.sum(fps * weight, axis=0)\r\n", - " \r\n", - " tpr = tps / tps[-1]\r\n", - " fpr = fps / fps[-1]\r\n", - " \r\n", - " return tnp.trapz(tpr, fpr)\r\n", - " \r\n", - " return tl.Fn('RocAucScore', f)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "LtbnwInF9fMf" - }, - "source": [ - "metrics = {\r\n", - " 'loss': tl.WeightedCategoryCrossEntropy(),\r\n", - " 'accuracy': tl.WeightedCategoryAccuracy(),\r\n", - " 'sequence_accuracy': tl.MaskedSequenceAccuracy(),\r\n", - " 'auc_all': RocAucScore(),\r\n", - " 'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum())\r\n", - "}" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WA_-d9V_NN-h" - }, - "source": [ - "## Model Functions" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QK7Kf4AM9u9P" - }, - "source": [ - "@gin.configurable\r\n", - "@tl.assert_shape('bl->b1ll')\r\n", - "def PaddingFutureMask(pad=0, block_self=False, tid=True, pad_end=False):\r\n", - " def f(x):\r\n", - " mask_pad = tnp.logical_not(tnp.equal(x, 0))[:, tnp.newaxis, tnp.newaxis, :]\r\n", - " \r\n", - " x_new = x\r\n", - " if pad_end:\r\n", - " x_new = tnp.where(tnp.equal(x, 0), tnp.max(x), x)\r\n", - " \r\n", - " if tid:\r\n", - " mask_future = x_new[:, :, tnp.newaxis] >= x_new[:, tnp.newaxis, :] + block_self\r\n", - " mask_future = mask_future[:, tnp.newaxis, :, :]\r\n", - " else:\r\n", - " mask_future = tnp.arange(x.shape[-1])[tnp.newaxis, tnp.newaxis, :, tnp.newaxis] \\\r\n", - " >= tnp.arange(x.shape[-1])[tnp.newaxis, :]\r\n", - " \r\n", - " return tnp.logical_and(mask_future, mask_pad)\r\n", - " \r\n", - " return tl.Fn(f'PaddingFutureMask({pad})', f)\r\n", - "\r\n", - "\r\n", - "# the only thing different here is the shape assertions to accomodate the change\r\n", - "# in mask shape from b11l to b1ll\r\n", - "\r\n", - "@tl.assert_shape('bld,b1ll->bld,b1ll')\r\n", - "@gin.configurable\r\n", - "def KTAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):\r\n", - " return tl.Serial(\r\n", - " tl.Select([0, 0, 0]),\r\n", - " tl.AttentionQKV(\r\n", - " d_feature, n_heads=n_heads, dropout=dropout, mode=mode),\r\n", - " )\r\n", - "\r\n", - "def my_add_loss_weights(generator, id_to_mask=None):\r\n", - " for example in generator:\r\n", - " weights = (example[0] != id_to_mask).astype(tnp.float32)\r\n", - " yield (*example, weights)\r\n", - "\r\n", - "@gin.configurable\r\n", - "def KTAddLossWeights(id_to_mask=0): # pylint: disable=invalid-name\r\n", - " return lambda g: my_add_loss_weights(g, id_to_mask=id_to_mask)\r\n", - "\r\n", - "def trim_tags(generator):\r\n", - " for example in generator:\r\n", - " # content_id, part, tags, tid, ac, pqet, ts_delta\r\n", - " yield (example[0], example[1], example[2][:, :, :6], example[3], example[4], example[5], example[6])\r\n", - "\r\n", - "@gin.configurable\r\n", - "def TrimTags():\r\n", - " return lambda g: trim_tags(g)\r\n", - "\r\n", - "@gin.configurable\r\n", - "def KTPositionalEncoder(max_position=10000.0, d_model=512, tid=False): \r\n", - " \"\"\"This is set up to perform standard positional encoding based on the\r\n", - " position in the sequence, but also to calculate position based on 
the\r\n", - " id of the task container to which the question belongs.\r\n", - " \"\"\"\r\n", - " def f(inputs):\r\n", - " # whether or not to use task_container_id or seq position\r\n", - " if tid:\r\n", - " position = tnp.expand_dims(inputs.astype(tnp.float32), -1)\r\n", - " else:\r\n", - " position = tnp.arange(inputs.shape[1])\r\n", - " \r\n", - " position = position.astype(tnp.float32)[tnp.newaxis, :, tnp.newaxis]\r\n", - "\r\n", - " i = tnp.expand_dims(tnp.arange(d_model, dtype=tnp.float32), 0)\r\n", - "\r\n", - " angles = 1 / tnp.power(max_position, (2 * (i // 2)) /\r\n", - " tnp.array(d_model, dtype=tnp.float32))\r\n", - "\r\n", - " angle_rads = position * angles\r\n", - "\r\n", - " # apply sin to even index in the array\r\n", - " sines = tnp.sin(angle_rads[:, :, 0::2])\r\n", - " # apply cos to odd index in the array\r\n", - " cosines = tnp.cos(angle_rads[:, :, 1::2])\r\n", - "\r\n", - " pos_encoding = tnp.concatenate([sines, cosines], axis=-1)\r\n", - "\r\n", - " return pos_encoding\r\n", - "\r\n", - " return tl.Fn('KTPositionalEncoder', f)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Pmo6yeiXkBAQ" - }, - "source": [ - "@gin.configurable\r\n", - "def KTTransformer(d_model,\r\n", - " d_input,\r\n", - " d_part,\r\n", - " d_tags,\r\n", - " d_out,\r\n", - " d_pqet,\r\n", - " d_ts_delta,\r\n", - " d_tid,\r\n", - " embed_concat=False,\r\n", - " d_ff=2048,\r\n", - " n_encoder_layers=6,\r\n", - " n_decoder_layers=6,\r\n", - " n_heads=8,\r\n", - " max_len=2048,\r\n", - " dropout=0.1,\r\n", - " dropout_shared_axes=None,\r\n", - " mode='train',\r\n", - " ff_activation=tl.Relu):\r\n", - " \r\n", - " def Embedder(vocab_size, d_embed): # tokens --> vectors\r\n", - " return [\r\n", - " tl.Embedding(vocab_size, d_embed),\r\n", - " tl.Dropout(\r\n", - " rate=dropout, shared_axes=dropout_shared_axes, mode=mode),\r\n", - " ]\r\n", - "\r\n", - " # Encoder Embeddings\r\n", - " in_embedder = Embedder(*d_input)\r\n", - " part_embedder = Embedder(*d_part)\r\n", - " # Keeps the tags in the data batch tuple, but drops it if it\r\n", - " # isn't included in the embeddings.\r\n", - " if d_tags is not None:\r\n", - " tags_embedder = tl.Serial(Embedder(*d_tags), tl.Sum(axis=-2))\r\n", - " else:\r\n", - " tags_embedder = tl.Drop()\r\n", - " in_pos_encoder = KTPositionalEncoder(*d_tid)\r\n", - "\r\n", - " # Decoder Embeddings\r\n", - " out_embedder = Embedder(*d_out)\r\n", - " pqet_embedder = Embedder(*d_pqet)\r\n", - " ts_delta_embedder = Embedder(*d_ts_delta)\r\n", - " out_pos_encoder = KTPositionalEncoder(*d_tid)\r\n", - "\r\n", - " encoder_mode = 'eval' if mode == 'predict' else mode\r\n", - "\r\n", - " in_encoder = [tl.Parallel(in_embedder, part_embedder, tags_embedder, in_pos_encoder)]\r\n", - " out_encoder = [tl.Parallel(out_embedder, pqet_embedder, ts_delta_embedder, out_pos_encoder)]\r\n", - " \r\n", - " if embed_concat:\r\n", - " if d_tags is not None:\r\n", - " in_encoder += [tl.Concatenate(n_items=3), tl.Add()]\r\n", - " else:\r\n", - " in_encoder += [tl.Concatenate(n_items=2), tl.Add()]\r\n", - " out_encoder += [tl.Concatenate(n_items=3), tl.Add()]\r\n", - " else:\r\n", - " if d_tags is not None:\r\n", - " in_encoder += [tl.Add(), tl.Add(), tl.Add()]\r\n", - " else:\r\n", - " in_encoder += [tl.Add(), tl.Add()]\r\n", - " out_encoder += [tl.Add(), tl.Add(), tl.Add()]\r\n", - "\r\n", - " encoder_blocks = [\r\n", - " _KTEncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,\r\n", - " mode, ff_activation)\r\n", - " for i in 
range(n_encoder_layers)]\r\n", - "\r\n", - " encoder = tl.Serial(\r\n", - " in_encoder,\r\n", - " encoder_blocks,\r\n", - " tl.LayerNorm()\r\n", - " )\r\n", - "\r\n", - " encoder_decoder_blocks = [\r\n", - " _KTEncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,\r\n", - " mode, ff_activation)\r\n", - " for i in range(n_decoder_layers)]\r\n", - "\r\n", - " # output tuple - leading number is max index \r\n", - " return tl.Serial( # 7: 0:tok_e 1:tok_p 2:tok_t 3:tok_tid 4:tok_d 5:tok_pq, 6:tok_tsd 7:wts_l \r\n", - " tl.Select([0, 1, 2, 3, 3, 3, # 10: 0:tok_e 1:tok_p 2:tok_t 3:tok_tid 4:tok_tid 5: tok_tid\r\n", - " 4, 5, 6, 4]), # 6:tok_d 7:tok`_pq, 8:tok_tsd 9:tok_d 10:wts_l\r\n", - "\r\n", - " # Encode.\r\n", - " tl.Parallel(\r\n", - " tl.Select([0, 1, 2, 3]),\r\n", - " PaddingFutureMask(tid=True)\r\n", - " ), # 10: tok_e tok_p tok_t tok_tid mask_combined tok_tid tok_d tok_pq tok_tsd tok_d wts_l\r\n", - " encoder, # 7: vec_e mask_combined tok_tid tok_d tok_pq tok_tsd tok_d wts_l\r\n", - " # Decode.\r\n", - " tl.Select([3, 4, 5, 2, 2, 0]), # 7: tok_d tok_pq tok_tsd tok_tid tok_tid vec_e tok_d wts_l\r\n", - " tl.Parallel(\r\n", - " tl.ShiftRight(mode=mode),\r\n", - " tl.ShiftRight(mode=mode), \r\n", - " tl.ShiftRight(mode=mode),\r\n", - " tl.ShiftRight(mode=mode),\r\n", - " tl.Serial(tl.ShiftRight(),\r\n", - " PaddingFutureMask(tid=False)),\r\n", - " ), # 7: tok_d tok_pq tok_tsd tok_tid mask_combined vec_e tok_d wts_l \r\n", - " out_encoder, # 4: vec_d mask_combined vec_e tok_d wts_l\r\n", - " encoder_decoder_blocks, # 4: vec_d mask_combined vec_e tok_d wts_l\r\n", - " tl.LayerNorm(), # 4: vec_d mask_combined vec_e tok_d wts_l\r\n", - "\r\n", - " # Map to output vocab.\r\n", - " tl.Select([0], n_in=3), # 3: vec_d tok_d wts_l\r\n", - " tl.Dense(d_out[0]), # vec_d .....\r\n", - " )\r\n", - "\r\n", - "\r\n", - "def _KTEncoderBlock(d_model, d_ff, n_heads,\r\n", - " dropout, dropout_shared_axes, mode, ff_activation):\r\n", - " \"\"\"Same as the default, but changes attention layer to KTAttention to \r\n", - " accept a combined padding and future mask.\r\n", - " \"\"\"\r\n", - " \r\n", - " attention = KTAttention(\r\n", - " d_model, n_heads=n_heads, dropout=dropout, mode=mode)\r\n", - "\r\n", - " feed_forward = _KTFeedForwardBlock(\r\n", - " d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)\r\n", - "\r\n", - " dropout_ = tl.Dropout(\r\n", - " rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", - "\r\n", - " return [\r\n", - " tl.Residual(\r\n", - " tl.LayerNorm(),\r\n", - " attention,\r\n", - " dropout_,\r\n", - " ),\r\n", - " tl.Residual(\r\n", - " feed_forward\r\n", - " ),\r\n", - " ]\r\n", - "\r\n", - "def _KTEncoderDecoderBlock(d_model, d_ff, n_heads,\r\n", - " dropout, dropout_shared_axes, mode, ff_activation):\r\n", - " \"\"\"Same as the default, but changes the first layer to KTAttention to \r\n", - " accept a combined padding and future mask.\r\n", - " \"\"\"\r\n", - " def _Dropout():\r\n", - " return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", - "\r\n", - " attention = KTAttention(\r\n", - " d_model, n_heads=n_heads, dropout=dropout, mode=mode)\r\n", - "\r\n", - " attention_qkv = tl.AttentionQKV(\r\n", - " d_model, n_heads=n_heads, dropout=dropout, mode=mode)\r\n", - "\r\n", - " feed_forward = _KTFeedForwardBlock(\r\n", - " d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)\r\n", - "\r\n", - " return [ # vec_d masks vec_e\r\n", - " tl.Residual(\r\n", - " tl.LayerNorm(), # vec_d ..... 
.....\r\n", - " attention, # vec_d ..... .....\r\n", - " _Dropout(), # vec_d ..... .....\r\n", - " ),\r\n", - " tl.Residual(\r\n", - " tl.LayerNorm(), # vec_d ..... .....\r\n", - " tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e\r\n", - " attention_qkv, # vec_d masks vec_e\r\n", - " _Dropout(), # vec_d masks vec_e\r\n", - " ),\r\n", - " tl.Residual(\r\n", - " feed_forward # vec_d masks vec_e\r\n", - " ),\r\n", - " ]\r\n", - "\r\n", - "def _KTFeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,\r\n", - " mode, activation):\r\n", - " \"\"\"Same as default.\r\n", - " \"\"\"\r\n", - " dropout_middle = tl.Dropout(\r\n", - " rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", - " dropout_final = tl.Dropout(\r\n", - " rate=dropout, shared_axes=dropout_shared_axes, mode=mode)\r\n", - "\r\n", - " return [\r\n", - " tl.LayerNorm(),\r\n", - " tl.Dense(d_ff),\r\n", - " activation(),\r\n", - " dropout_middle,\r\n", - " tl.Dense(d_model),\r\n", - " dropout_final,\r\n", - " ]" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RFlp9RAINR4d" - }, - "source": [ - "## Configuration" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "5nj3MX9D97Nz" - }, - "source": [ - "# Configure hyperparameters.\r\n", - "\r\n", - "total_steps = 10000\r\n", - "\r\n", - "gin.clear_config()\r\n", - "gin.parse_config(f\"\"\"\r\n", - "import trax.layers\r\n", - "import trax.models\r\n", - "import trax.optimizers\r\n", - "import trax.data.inputs\r\n", - "import trax.supervised.trainer_lib\r\n", - "\r\n", - "# Parameters that will vary between experiments:\r\n", - "# ==============================================================================\r\n", - "# min_len = 12\r\n", - "# max_len = 64\r\n", - "# d_model = 512 # need to make sure this works with concat embeddings\r\n", - "# d_ff = 256\r\n", - "# n_encoder_layers = 2\r\n", - "# n_decoder_layers = 2\r\n", - "# n_heads = 2\r\n", - "# dropout = 0.0\r\n", - "\r\n", - "min_len = 12\r\n", - "max_len = 256\r\n", - "d_model = 512 # need to make sure this works with concat embeddings\r\n", - "d_ff = 1024\r\n", - "n_encoder_layers = 6\r\n", - "n_decoder_layers = 6\r\n", - "n_heads = 8\r\n", - "dropout = 0.1\r\n", - "\r\n", - "# Set to True to aggregate embeddings by concatenation. If set\r\n", - "# to False aggregation will be by sum.\r\n", - "embed_concat = True\r\n", - "\r\n", - "# (Vocab, depth) Uncomment to use with aggregation by concatenation.\r\n", - "d_input = (13500, 384)\r\n", - "d_part = (8, 8)\r\n", - "d_tags = (189, 120)\r\n", - "\r\n", - "# (Vocab, depth) Uncomment to use with aggregation by concatenation.\r\n", - "d_out = (3, 384)\r\n", - "d_pqet = (300, 64)\r\n", - "d_ts_delta = (150, 64)\r\n", - "\r\n", - "# Used for positional encodings if not None. 
Positional encoding based\r\n", - "# on sequence in batch if None.\r\n", - "d_tid = (10000, %d_model)\r\n", - "\r\n", - "# d_input = (13500, %d_model)\r\n", - "# d_part = (8, %d_model)\r\n", - "# d_tags = (189, %d_model)\r\n", - "# # d_tags = None\r\n", - "# d_out = (3, %d_model)\r\n", - "# d_pqet = (300, %d_model)\r\n", - "# d_ts_delta = (150, %d_model)\r\n", - "# d_tid = (10000, %d_model)\r\n", - "\r\n", - "total_steps = {total_steps}\r\n", - "\r\n", - "# Parameters for learning rate schedule:\r\n", - "# ==============================================================================\r\n", - "warmup_and_rsqrt_decay.n_warmup_steps = 3000\r\n", - "warmup_and_rsqrt_decay.max_value = 0.001\r\n", - "\r\n", - "# multifactor.constant = 0.01\r\n", - "# multifactor.factors = 'constant * linear_warmup * cosine_decay'\r\n", - "# multifactor.warmup_steps = 4000\r\n", - "# multifactor.steps_per_cycle = %total_steps\r\n", - "# multifactor.minimum = .0001\r\n", - "\r\n", - "# Parameters for Adam:\r\n", - "# ==============================================================================\r\n", - "# Adam.weight_decay_rate=0.0\r\n", - "Adam.b1 = 0.9\r\n", - "Adam.b2 = 0.999\r\n", - "Adam.eps = 1e-8\r\n", - "\r\n", - "# Parameters for input pipeline:\r\n", - "# ==============================================================================\r\n", - "get_ds_tfrec.min_len = %min_len\r\n", - "get_ds_tfrec.max_len = %max_len\r\n", - "train/get_ds_tfrec.folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]\r\n", - "eval/get_ds_tfrec.folds = [19]\r\n", - "\r\n", - "BucketByLength.boundaries = [32, 64, 128]\r\n", - "BucketByLength.batch_sizes = [512, 256, 128, 64]\r\n", - "# BucketByLength.batch_sizes = [16, 8, 4, 2]\r\n", - "\r\n", - "BucketByLength.strict_pad_on_len = True\r\n", - "\r\n", - "KTAddLossWeights.id_to_mask = 0\r\n", - "\r\n", - "train/make_additional_stream.stream = [\r\n", - " @train/get_ds_tfrec(),\r\n", - " @BucketByLength(),\r\n", - " @TrimTags(),\r\n", - " @KTAddLossWeights()\r\n", - "]\r\n", - "\r\n", - "eval/make_additional_stream.stream = [\r\n", - " @eval/get_ds_tfrec(),\r\n", - " @BucketByLength(),\r\n", - " @TrimTags(),\r\n", - " @KTAddLossWeights()\r\n", - "]\r\n", - "\r\n", - "make_inputs.train_stream = @train/make_additional_stream()\r\n", - "make_inputs.eval_stream = @eval/make_additional_stream()\r\n", - "\r\n", - "# Parameters for KTPositionalEncoder:\r\n", - "# ==============================================================================\r\n", - "KTPositionalEncoder.d_model = %d_model\r\n", - "\r\n", - "# Set to True to calculate positional encodings based on position in the original\r\n", - "# full-length sequence, False to base them on position in the batch sequence.\r\n", - "KTPositionalEncoder.tid = False\r\n", - "\r\n", - "# Parameters for PaddingFutureMask:\r\n", - "# ==============================================================================\r\n", - "PaddingFutureMask.pad_end = False\r\n", - "\r\n", - "# Set to True to calculate the future mask based on task container id (questions\r\n", - "# are delivered to users in groups identified by task_container_id) or False\r\n", - "# to base it on the next question only.\r\n", - "PaddingFutureMask.tid = False\r\n", - "\r\n", - "# Parameters for KTTransformer:\r\n", - "# ==============================================================================\r\n", - "KTTransformer.d_model = %d_model\r\n", - "KTTransformer.d_input = %d_input\r\n", - "KTTransformer.d_part = %d_part\r\n", - "KTTransformer.d_tags = %d_tags\r\n", -
"KTTransformer.d_out = %d_out\r\n", - "KTTransformer.d_pqet = %d_pqet\r\n", - "KTTransformer.d_ts_delta = %d_ts_delta\r\n", - "KTTransformer.d_tid = %d_tid\r\n", - "KTTransformer.embed_concat = %embed_concat\r\n", - "KTTransformer.d_ff = %d_ff\r\n", - "KTTransformer.n_encoder_layers = %n_encoder_layers\r\n", - "KTTransformer.n_decoder_layers = %n_decoder_layers\r\n", - "KTTransformer.n_heads = %n_heads\r\n", - "KTTransformer.dropout = %dropout\r\n", - "\r\n", - "# Parameters for train:\r\n", - "# ==============================================================================\r\n", - "train.inputs = @make_inputs\r\n", - "train.eval_frequency = 200\r\n", - "train.eval_steps = 20\r\n", - "train.checkpoints_at = {list(range(0,total_steps + 1, 2000))}\r\n", - "train.optimizer = @trax.optimizers.Adam\r\n", - "train.steps = %total_steps\r\n", - "train.model = @KTTransformer\r\n", - "train.lr_schedule_fn = @trax.supervised.lr_schedules.warmup_and_rsqrt_decay\r\n", - "\"\"\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "PP46-cEAB4i_" - }, - "source": [ - "if False:\r\n", - " inputs = trax.data.inputs.make_inputs()\r\n", - " train_stream = inputs.train_stream(trax.fastmath.device_count())\r\n", - " train_eval_stream = inputs.train_eval_stream(trax.fastmath.device_count())\r\n", - " b = next(train_stream)\r\n", - " for i, m in enumerate(b):\r\n", - " print(i, m.shape)\r\n", - " b" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "8OmbSWj5Cvt2" - }, - "source": [ - "if False:\r\n", - " model = KTTransformer()\r\n", - " model.init(trax.shapes.signature(b))\r\n", - " outs = model(b)\r\n", - " for i, m in enumerate(outs):\r\n", - " print(i, m.shape)\r\n", - " outs" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dcHfhEEFNXXJ" - }, - "source": [ - "## Training" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "XGIfCGL9GprB" - }, - "source": [ - "run_no = 0\r\n", - "prefix = f'model_runs/{run_no:02d}'\r\n", - "output_dir = f'gs://{BUCKET}/{prefix}'\r\n", - "log_dir = output_dir[:-3]" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "5rhH1YNVHEPO" - }, - "source": [ - "%tensorboard --logdir $log_dir" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "wQcrZYSTGMhX" - }, - "source": [ - "if TRAIN_MODEL:\r\n", - " if False:\r\n", - " init_checkpoint = f'{output_dir}/model.pkl.gz'\r\n", - " else:\r\n", - " bucket.delete_blobs(list(bucket.list_blobs(prefix=prefix)))\r\n", - "\r\n", - " loop = trax.supervised.trainer_lib.train(output_dir, metrics=metrics)" - ], - "execution_count": null, - "outputs": [] - } - ] -} diff --git a/trax/examples/MathQA_Python_generation_notebook.ipynb b/trax/examples/MathQA_Python_generation_notebook.ipynb deleted file mode 100644 index 09ba56aa3..000000000 --- a/trax/examples/MathQA_Python_generation_notebook.ipynb +++ /dev/null @@ -1,174 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "oILRLCWN_16u" - }, - "outputs": [], - "source": [ - "#@title License\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# 
https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lpqiZgTy4DqT" - }, - "source": [ - "How to generate the MathQA-Python dataset?\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "1. Download the dataset from the MathQA project webpage: https://math-qa.github.io/\n", - "2. Create the mathqa directory in the local colab drive.\n", - "3. Unpack the json files (train.json, dev.json, test.json, challenge_test.json) and place them in the mathqa directory.\n", - "4. Run the cells below - they will generate the MathQA-Python dataset for the test split. \n", - "5. Repeat the process for other splits if needed.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "B8nqRq0Qhcf8" - }, - "outputs": [], - "source": [ - "!pip install -U git+https://github.com/google/trax.git@220a62303ebf4ad18871aa5607b4dda2f064f2d2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "v4RKdd18hqRH" - }, - "outputs": [], - "source": [ - "from trax import data\n", - "import json\n", - "import numpy as np\n", - "import os\n", - "import tensorflow as tf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TAyU75naIFW5" - }, - "outputs": [], - "source": [ - "dataset_path = '/content/mathqa/'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "L-RZ9MeajaWC" - }, - "outputs": [], - "source": [ - "mathqa_test_gen = data.CreateMathQAInputs(dataset_path=dataset_path, cumulative=False, python_code=True, full_dict=True, train=False, test=True)()\n", - "def read_all_problems(mathqa_gen):\n", - " problems = []\n", - " questions = set()\n", - " index = 0\n", - " while True:\n", - " problem = next(mathqa_gen)\n", - " problem_dict = {}\n", - " if problem[0] in questions:\n", - " break\n", - " else:\n", - " problem_dict['text'] = problem[0]\n", - " problem_dict['code'] = problem[1]\n", - " problem_dict['dsl_code'] = problem[2]\n", - " problem_dict['reasoning'] = problem[3].strip('\\\"').strip(\"\\'\")\n", - " problem_dict['answer'] = data.tf_inputs.execute_mathqa_program(problem[0], problem[1].split('\\n'))\n", - " problem_dict['task_id'] = index\n", - " np.testing.assert_almost_equal(problem_dict['answer'], data.tf_inputs.execute_mathqa_dsl_program(problem[0], [problem[2]]))\n", - " problems.append(problem_dict)\n", - " questions.add(problem[0])\n", - " index += 1\n", - " return problems" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "K96xIQDQjyrS" - }, - "outputs": [], - "source": [ - "test_problems = read_all_problems(mathqa_test_gen)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "K5y7244_j3mB" - }, - "outputs": [], - "source": [ - "len(test_problems)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "emEvo5iAucGl" - }, - "outputs": [], - "source": [ - "test_problems[0]" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "MathQA_Python_generation_notebook.ipynb", - "private_outputs": true, - "provenance": [ - { - 
"file_id": "1pdlfcJ8F4-QhBWe3KRKJW_iSov7zl6Ve", - "timestamp": 1626376876263 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/examples/NER_using_Reformer.ipynb b/trax/examples/NER_using_Reformer.ipynb deleted file mode 100644 index 9e0251ef3..000000000 --- a/trax/examples/NER_using_Reformer.ipynb +++ /dev/null @@ -1,1669 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - }, - "papermill": { - "duration": 29594.274789, - "end_time": "2020-10-20T22:20:01.092204", - "environment_variables": {}, - "exception": null, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2020-10-20T14:06:46.817415", - "version": "2.1.0" - }, - "colab": { - "name": "NER using Reformer", - "provenance": [], - "include_colab_link": true - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eGCe1pjznIQS" - }, - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zpj2rPQdm8nb" - }, - "source": [ - "Author - [@SauravMaheshkar](https://github.com/SauravMaheshkar)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8LURHZ84v9-i", - "papermill": { - "duration": 0.034262, - "end_time": "2020-10-20T14:06:51.973823", - "exception": false, - "start_time": "2020-10-20T14:06:51.939561", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "# Install Dependencies\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yEuPYcg3BoAb", - "papermill": { - "duration": 0.031347, - "end_time": "2020-10-20T14:06:52.037011", - "exception": false, - "start_time": "2020-10-20T14:06:52.005664", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "Install the latest version of the [Trax](https://github.com/google/trax) Library." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "_kg_hide-output": true, - "collapsed": true, - "execution": { - "iopub.execute_input": "2020-10-20T14:06:52.106075Z", - "iopub.status.busy": "2020-10-20T14:06:52.105239Z", - "iopub.status.idle": "2020-10-20T14:07:45.817343Z", - "shell.execute_reply": "2020-10-20T14:07:45.816507Z" - }, - "id": "u4GfFPtWv0eb", - "papermill": { - "duration": 53.749037, - "end_time": "2020-10-20T14:07:45.817478", - "exception": false, - "start_time": "2020-10-20T14:06:52.068441", - "status": "completed" - }, - "tags": [], - "outputId": "59aaef48-c9fc-4af9-9043-a1f7c7745749" - }, - "source": [ - "!pip install -q -U trax" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\u001b[31mERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.\r\n", - "\r\n", - "We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.\r\n", - "\r\n", - "pytorch-lightning 0.9.0 requires tensorboard==2.2.0, but you'll have tensorboard 2.3.0 which is incompatible.\r\n", - "kfac 0.2.3 requires tensorflow-probability==0.8, but you'll have tensorflow-probability 0.7.0 which is incompatible.\u001b[0m\r\n", - "\u001b[33mWARNING: You are using pip version 20.2.3; however, version 20.2.4 is available.\r\n", - "You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.\u001b[0m\r\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "25jBohslvAaM", - "papermill": { - "duration": 0.031968, - "end_time": "2020-10-20T14:07:45.882676", - "exception": false, - "start_time": "2020-10-20T14:07:45.850708", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "# Introduction\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "drjA2GYE4g_F", - "papermill": { - "duration": 0.031988, - "end_time": "2020-10-20T14:07:45.947830", - "exception": false, - "start_time": "2020-10-20T14:07:45.915842", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "---\n", - "\n", - "**Named-entity recognition** (NER) is a subtask of *information extraction* that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc.\n", - "\n", - "To evaluate the quality of an NER system's output, several measures have been defined. The usual measures are called **precision**, **recall**, and **F1 score**. However, several issues remain in just how to calculate those values. State-of-the-art NER systems for English produce near-human performance. For example, the best system entering MUC-7 scored an F-measure of 93.39%, while human annotators scored 97.60% and 96.95%."
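, - "\n", - "For reference, the standard definitions (not spelled out in this notebook) are, writing $TP$, $FP$ and $FN$ for the true-positive, false-positive and false-negative counts:\n", - "\n", - "$$P = \\frac{TP}{TP + FP}, \\qquad R = \\frac{TP}{TP + FN}, \\qquad F_1 = \\frac{2PR}{P + R}$$"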
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SbzpsLnB6Rt_", - "papermill": { - "duration": 0.031674, - "end_time": "2020-10-20T14:07:46.011670", - "exception": false, - "start_time": "2020-10-20T14:07:45.979996", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "# Importing Packages" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:07:46.082693Z", - "iopub.status.busy": "2020-10-20T14:07:46.081926Z", - "iopub.status.idle": "2020-10-20T14:07:57.865757Z", - "shell.execute_reply": "2020-10-20T14:07:57.865118Z" - }, - "id": "2pGNHjR46RFs", - "papermill": { - "duration": 11.822159, - "end_time": "2020-10-20T14:07:57.865897", - "exception": false, - "start_time": "2020-10-20T14:07:46.043738", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "import trax # Our Main Library\n", - "from trax import layers as tl\n", - "import os # For os dependent functionalities\n", - "import numpy as np # For scientific computing\n", - "import pandas as pd # For basic data analysis\n", - "import random as rnd # For using random functions" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qQFItoSJGeti", - "papermill": { - "duration": 0.032906, - "end_time": "2020-10-20T14:07:57.931601", - "exception": false, - "start_time": "2020-10-20T14:07:57.898695", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "# Pre-Processing" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jdC7V8KspbHb", - "papermill": { - "duration": 0.032062, - "end_time": "2020-10-20T14:07:57.996789", - "exception": false, - "start_time": "2020-10-20T14:07:57.964727", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Loading the Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CcLS-1P0IePt", - "papermill": { - "duration": 0.032255, - "end_time": "2020-10-20T14:07:58.061951", - "exception": false, - "start_time": "2020-10-20T14:07:58.029696", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "Let's load the `ner_dataset.csv` file into a dataframe and see what it looks like" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:07:58.132552Z", - "iopub.status.busy": "2020-10-20T14:07:58.131692Z", - "iopub.status.idle": "2020-10-20T14:07:59.524742Z", - "shell.execute_reply": "2020-10-20T14:07:59.523989Z" - }, - "id": "q83GgD2JWlTz", - "papermill": { - "duration": 1.430809, - "end_time": "2020-10-20T14:07:59.524871", - "exception": false, - "start_time": "2020-10-20T14:07:58.094062", - "status": "completed" - }, - "tags": [], - "outputId": "3f67e377-6450-41d2-ca19-b1f8f79f4766" - }, - "source": [ - "data = pd.read_csv(\"/kaggle/input/entity-annotated-corpus/ner_dataset.csv\",encoding = 'ISO-8859-1')\n", - "data = data.fillna(method = 'ffill')\n", - "data.head()" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Sentence #WordPOSTag
0Sentence: 1ThousandsNNSO
1Sentence: 1ofINO
2Sentence: 1demonstratorsNNSO
3Sentence: 1haveVBPO
4Sentence: 1marchedVBNO
\n", - "
" - ], - "text/plain": [ - " Sentence # Word POS Tag\n", - "0 Sentence: 1 Thousands NNS O\n", - "1 Sentence: 1 of IN O\n", - "2 Sentence: 1 demonstrators NNS O\n", - "3 Sentence: 1 have VBP O\n", - "4 Sentence: 1 marched VBN O" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 3 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DCss4fg8IQwN", - "papermill": { - "duration": 0.032562, - "end_time": "2020-10-20T14:07:59.590814", - "exception": false, - "start_time": "2020-10-20T14:07:59.558252", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Creating a Vocabulary File" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "etNMRldEImgg", - "papermill": { - "duration": 0.032586, - "end_time": "2020-10-20T14:07:59.656501", - "exception": false, - "start_time": "2020-10-20T14:07:59.623915", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "We can see there's a column for the words in each sentence. Thus, we can extract this column using the [`.loc()`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.loc.html) and store it into a `.txt` file using the [`.savetext()`](https://numpy.org/doc/stable/reference/generated/numpy.savetxt.html) function from numpy." - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:07:59.752532Z", - "iopub.status.busy": "2020-10-20T14:07:59.729721Z", - "iopub.status.idle": "2020-10-20T14:08:02.024882Z", - "shell.execute_reply": "2020-10-20T14:08:02.025503Z" - }, - "id": "tw9ewglyIa_0", - "papermill": { - "duration": 2.336183, - "end_time": "2020-10-20T14:08:02.025687", - "exception": false, - "start_time": "2020-10-20T14:07:59.689504", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Extract the 'Word' column from the dataframe\n", - "words = data.loc[:, \"Word\"]\n", - "\n", - "## Convert into a text file using the .savetxt() function\n", - "np.savetxt(r'words.txt', words.values, fmt=\"%s\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "skW3Wz9YKULq", - "papermill": { - "duration": 0.032752, - "end_time": "2020-10-20T14:08:02.092503", - "exception": false, - "start_time": "2020-10-20T14:08:02.059751", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Creating a Dictionary for Vocabulary" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LWMeXo8LCkkG", - "papermill": { - "duration": 0.032646, - "end_time": "2020-10-20T14:08:02.158153", - "exception": false, - "start_time": "2020-10-20T14:08:02.125507", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "Here, we create a Dictionary for our vocabulary by reading through all the sentences in the dataset." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:02.234561Z", - "iopub.status.busy": "2020-10-20T14:08:02.233556Z", - "iopub.status.idle": "2020-10-20T14:08:02.865465Z", - "shell.execute_reply": "2020-10-20T14:08:02.866059Z" - }, - "id": "C9TxwknFKStf", - "papermill": { - "duration": 0.675227, - "end_time": "2020-10-20T14:08:02.866282", - "exception": false, - "start_time": "2020-10-20T14:08:02.191055", - "status": "completed" - }, - "tags": [], - "outputId": "b7574311-badd-4623-da2a-5ef101db0b00" - }, - "source": [ - "vocab = {}\n", - "with open('words.txt') as f:\n", - " for i, l in enumerate(f.read().splitlines()):\n", - " vocab[l] = i\n", - " print(\"Number of words:\", len(vocab))\n", - " vocab['<PAD>'] = len(vocab)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Number of words: 35178\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zj-hlBvzpl5x", - "papermill": { - "duration": 0.035449, - "end_time": "2020-10-20T14:08:02.936000", - "exception": false, - "start_time": "2020-10-20T14:08:02.900551", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Extracting Sentences from the Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wYUmK0skDFU7", - "papermill": { - "duration": 0.033405, - "end_time": "2020-10-20T14:08:03.003298", - "exception": false, - "start_time": "2020-10-20T14:08:02.969893", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "Next, we extract sentences from the dataset and create (X, y) pairs for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:03.081064Z", - "iopub.status.busy": "2020-10-20T14:08:03.080247Z", - "iopub.status.idle": "2020-10-20T14:08:03.083412Z", - "shell.execute_reply": "2020-10-20T14:08:03.083967Z" - }, - "id": "J_iN8EMIWyNM", - "papermill": { - "duration": 0.047324, - "end_time": "2020-10-20T14:08:03.084165", - "exception": false, - "start_time": "2020-10-20T14:08:03.036841", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "class Get_sentence(object):\n", - " def __init__(self,data):\n", - " self.n_sent=1\n", - " self.data = data\n", - " agg_func = lambda s:[(w,p,t) for w,p,t in zip(s[\"Word\"].values.tolist(),\n", - " s[\"POS\"].values.tolist(),\n", - " s[\"Tag\"].values.tolist())]\n", - " self.grouped = self.data.groupby(\"Sentence #\").apply(agg_func)\n", - " self.sentences = [s for s in self.grouped]" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:03.245176Z", - "iopub.status.busy": "2020-10-20T14:08:03.176225Z", - "iopub.status.idle": "2020-10-20T14:08:10.354304Z", - "shell.execute_reply": "2020-10-20T14:08:10.353652Z" - }, - "id": "OXZjM3UeW3ur", - "papermill": { - "duration": 7.236033, - "end_time": "2020-10-20T14:08:10.354445", - "exception": false, - "start_time": "2020-10-20T14:08:03.118412", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "getter = Get_sentence(data)\n", - "sentence = getter.sentences" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:10.581301Z", - "iopub.status.busy": "2020-10-20T14:08:10.518364Z", - "iopub.status.idle": "2020-10-20T14:08:10.588073Z", - "shell.execute_reply": "2020-10-20T14:08:10.587321Z" -
}, - "id": "_ZKrFo7cW5RX", - "papermill": { - "duration": 0.196933, - "end_time": "2020-10-20T14:08:10.588222", - "exception": false, - "start_time": "2020-10-20T14:08:10.391289", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "words = list(set(data[\"Word\"].values))\n", - "words_tag = list(set(data[\"Tag\"].values))\n", - "\n", - "word_idx = {w : i+1 for i ,w in enumerate(words)}\n", - "tag_idx = {t : i for i ,t in enumerate(words_tag)}" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:10.718207Z", - "iopub.status.busy": "2020-10-20T14:08:10.689128Z", - "iopub.status.idle": "2020-10-20T14:08:11.290981Z", - "shell.execute_reply": "2020-10-20T14:08:11.291803Z" - }, - "id": "yxWy9E-gXPRJ", - "papermill": { - "duration": 0.669432, - "end_time": "2020-10-20T14:08:11.292061", - "exception": false, - "start_time": "2020-10-20T14:08:10.622629", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "X = [[word_idx[w[0]] for w in s] for s in sentence]\n", - "y = [[tag_idx[w[2]] for w in s] for s in sentence]" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UM6bvNKkpyYP", - "papermill": { - "duration": 0.034986, - "end_time": "2020-10-20T14:08:11.365543", - "exception": false, - "start_time": "2020-10-20T14:08:11.330557", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Making a Batch Generator" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jO-C08uzDqDf", - "papermill": { - "duration": 0.034216, - "end_time": "2020-10-20T14:08:11.434628", - "exception": false, - "start_time": "2020-10-20T14:08:11.400412", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "Here, we create a batch generator for training. 
" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:11.521714Z", - "iopub.status.busy": "2020-10-20T14:08:11.512674Z", - "iopub.status.idle": "2020-10-20T14:08:11.525251Z", - "shell.execute_reply": "2020-10-20T14:08:11.524640Z" - }, - "id": "kLaPXRDtXe6E", - "papermill": { - "duration": 0.056187, - "end_time": "2020-10-20T14:08:11.525386", - "exception": false, - "start_time": "2020-10-20T14:08:11.469199", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "def data_generator(batch_size, x, y,pad, shuffle=False, verbose=False):\n", - "\n", - " num_lines = len(x)\n", - " lines_index = [*range(num_lines)]\n", - " if shuffle:\n", - " rnd.shuffle(lines_index)\n", - " \n", - " index = 0 \n", - " while True:\n", - " buffer_x = [0] * batch_size \n", - " buffer_y = [0] * batch_size \n", - "\n", - " max_len = 0\n", - " for i in range(batch_size):\n", - " if index >= num_lines:\n", - " index = 0\n", - " if shuffle:\n", - " rnd.shuffle(lines_index)\n", - " \n", - " buffer_x[i] = x[lines_index[index]]\n", - " buffer_y[i] = y[lines_index[index]]\n", - " \n", - " lenx = len(x[lines_index[index]]) \n", - " if lenx > max_len:\n", - " max_len = lenx \n", - " \n", - " index += 1\n", - "\n", - " X = np.full((batch_size, max_len), pad)\n", - " Y = np.full((batch_size, max_len), pad)\n", - "\n", - "\n", - " for i in range(batch_size):\n", - " x_i = buffer_x[i]\n", - " y_i = buffer_y[i]\n", - "\n", - " for j in range(len(x_i)):\n", - "\n", - " X[i, j] = x_i[j]\n", - " Y[i, j] = y_i[j]\n", - "\n", - " if verbose: print(\"index=\", index)\n", - " yield((X,Y))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_xtaMdPWp8NW", - "papermill": { - "duration": 0.034404, - "end_time": "2020-10-20T14:08:11.594978", - "exception": false, - "start_time": "2020-10-20T14:08:11.560574", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Splitting into Test and Train " - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:11.681160Z", - "iopub.status.busy": "2020-10-20T14:08:11.670506Z", - "iopub.status.idle": "2020-10-20T14:08:11.718703Z", - "shell.execute_reply": "2020-10-20T14:08:11.717823Z" - }, - "id": "RWYE1ndgX2up", - "papermill": { - "duration": 0.089107, - "end_time": "2020-10-20T14:08:11.718853", - "exception": false, - "start_time": "2020-10-20T14:08:11.629746", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "from sklearn.model_selection import train_test_split\n", - "x_train,x_test,y_train,y_test = train_test_split(X,y,test_size = 0.1,random_state=1)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MAeHnfnjx-Am", - "papermill": { - "duration": 0.034597, - "end_time": "2020-10-20T14:08:11.788761", - "exception": false, - "start_time": "2020-10-20T14:08:11.754164", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "# Building the Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "30w3W3IGzsP-", - "papermill": { - "duration": 0.038502, - "end_time": "2020-10-20T14:08:11.869814", - "exception": false, - "start_time": "2020-10-20T14:08:11.831312", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## The Reformer Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ISjwwZJLx_5j", - "papermill": { - "duration": 0.035572, - "end_time": "2020-10-20T14:08:11.940351", - 
"exception": false, - "start_time": "2020-10-20T14:08:11.904779", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "In this notebook, we use the Reformer, which is a more efficient of Transformer that uses reversible layers and locality-sensitive hashing. You can read the original paper [here](https://arxiv.org/abs/2001.04451). \n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cLrrFjeuzxVn", - "papermill": { - "duration": 0.034724, - "end_time": "2020-10-20T14:08:12.010232", - "exception": false, - "start_time": "2020-10-20T14:08:11.975508", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "### Locality-Sensitive Hashing\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fjo8QzSw2PbN", - "papermill": { - "duration": 0.034683, - "end_time": "2020-10-20T14:08:12.079753", - "exception": false, - "start_time": "2020-10-20T14:08:12.045070", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "---\n", - "The biggest problem that one might encounter while using Transformers, for huge corpora is the handling of the attention layer. Reformer introduces Locality Sensitive Hashing to solve this problem, by computing a hash function that groups similar vectors together. Thus, a input sequence is rearranged to bring elements with the same hash together and then divide into segments(or *chunks*, *buckets*) to enable parallel processing. Thus, we can apply Attention to these chunks (rather than the whole input sequence) to reduce the computational load." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u0YsTmPq13el", - "papermill": { - "duration": 0.03446, - "end_time": "2020-10-20T14:08:12.150541", - "exception": false, - "start_time": "2020-10-20T14:08:12.116081", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "![Reformer 
LSH.png](data:image/png;base64,...) (inline base64 image data for the Reformer LSH diagram omitted)"
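
A minimal NumPy sketch of the bucketing idea described above (illustrative only; this is not Trax's actual Reformer implementation, and the sizes and names are invented for the example):

    import numpy as np

    def lsh_bucket_ids(vectors, n_buckets, rng):
        # Angular LSH: project onto n_buckets // 2 random directions and take
        # the argmax over [h, -h] as the bucket id, so that similar vectors
        # tend to receive the same id.
        proj = rng.normal(size=(vectors.shape[-1], n_buckets // 2))
        h = vectors @ proj
        return np.argmax(np.concatenate([h, -h], axis=-1), axis=-1)

    rng = np.random.default_rng(0)
    seq = rng.normal(size=(128, 64))            # (positions, d_model)
    buckets = lsh_bucket_ids(seq, 8, rng)       # one bucket id per position
    order = np.argsort(buckets, kind='stable')  # same-hash positions adjacent
    chunks = np.split(seq[order], 8)            # fixed-size chunks; attention
                                                # would run within each chunk

Sorting by bucket id is what makes the trick cheap: attention then only needs to look at a fixed-size chunk (plus, in the real Reformer, its neighboring chunk) instead of the full sequence.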
/x69apTcZ2pBPPtwoThH4QJmEXW/TtJgAR/dAkZ8k/ZYumIELXC9QCVji6c1gNSCeIcAunqZrPiySyEMdcBHt4AbAHN9Vo/nxwBKK4KPwge0JP5gnv5CA+jgEjKeDT9Fzme+lL5gymF+tXKvG1tSnbUq1mABfysW0KuW7/OOLODLmD/zd8VKrqq1y9kMvnjzNIA2xkrh7yfgD+QhwRYfwUcLl4aBEGJlwo4a6oC/FQsIWC0rWzVd+Xme+9OXCsAVV67gCxZtALThn1yt4Lcb4cqWNneY4An8K6ReVSvmoesuIN3Jq4Avuuz/RWqwgL8VCwiF8IrzDKCbrpcgX+3VAW0utgCu0KT5CGCp92+Y4hZfIZcu/lUBKNe7g+DyzqPTdmdYa6gD/mYsIEjZHMzeusysNULiHYBys9QuLp4PX4m3ilexB6BYA4gm9YMAHCzgg1lAR5e/2IvS1AreA5Btlsi3efGywlWg0zwCuTcE4GAB7w/AQneqKDBvBwA+YedKxqWP2UK+JrjavbPARoiCqt/GYrIjAKv9kiZDHfD6FvDvJgDrgT0ArEc8GUChQFLwlSe8EZz7AMAVj6CFWwAKc+FNgS+fZ8J51wBC0ziC1kmh1AIoXC2Euw8gS7gv7A9qhAx1wAexgGVVg7OYt4BfU+doJOQ1gMYJbplpQfgMu2Gw8jeTVT9gApmNMaKnd7xed8OsAYSmMB9PBwt4fQtYt1g3BSCTLkgPX7iy+lWN5q5/4S5H1fcQ7rK9A6tI8N0mov5GrgPcTVovQx3woS3gthy+QW+YIJAyOnF/mrUGC/hZreANSI/kjuU8Q+FctB9n0FAHvHULeMsAMlf5p22RudVgAR+9H5C6uhvRUAccLOBVNVjAwQJeVadNy/zOOC2TG6ZlGgE8eVpm9HHTMokbMEzLvJAk/54SAPjPHwj9ZcLe/0ZGBACpiP8AAKfU7P4pAEin6kGDn4g4e2cTOiIASF+kmz1R2RkBgP+mLgQB/AOhfwGAP1KBHAAsqFPixHQq4o8r9kRn524BrE3SX+z/MRHleDQaj8aVRisI3XwepWxOze4fjX1VizjdxoOvJywc7+tpskt2HLJadmp/jJVfi/iCoZs/5izYv4xpLWIkxmSqE9eqp7pLczwu2Gh3HSO4ju1lwH+OV0u1fgeg+CZls0ntlKP9P0S+l+qifoaEzXRu8H84x2ovq6MxsSX8zQNYV2kKLCy6xW2bpkrZ1LQrkEwzyyHStYw9nEVmec3TE4Rl6JryMuOkmiIkN5EsqbkTKNuw+anITNMoIvqmq4yqLErPyjr1v90XgHJh6tTxsllGPJ04MSXrFQnxfjKciWCnyyQrj2DKk4K6/euIYbJMbecwniGbzI2CZWzsuFJ2kjQy6Ccpvb+pmwV0XtUyNL2e3pIAW4Qzq4lc6Vhwx8JuHXB3BuDcfFHKXhZOw80Us5Z5KS0wCS/Ognlh+WKbuDJBu8uQH6bzIt7B6zrF0qZOJeJZmrfP6XLCZRAfpKGypcFwyqKgb5wzMxQAOu2gCVA3nzfYP+nFxTwIfeOrua/HAhCf4TKNjp6hTc0s3cWz57b50UuVA0xpGCkc8IbDO/awSxWFy2VlQWWU0gZO2bPGt6cpt342z/JdQq41Cw3ZEUVBX5s/M9ZrdPz0uIyGUuOwUHe9MoMr1Teosx4NQNykPUkOKoNq3mHz7YY7eiwBpcssyKzUXPoeCh5NsUxDO03JafumgrkxJ2WwzNa8OkFgiqoCw/sXzduvGixoWodKhrOD91VFWTArSq/vsgSPByDIKebhPiCZqXa+U7TshlWUBEWS2r7qdbO9NClSKJxEU3FmKphJKTuY254rsllsPPHM0Fay2iona9XrkPls/065UNudLfcrGj30kADCTQ/nO3vizTrGEuG8sVZdz0M8s4WrfBvrOqXXNUPZDJoOYECLOaBbjyTzNDG3PMi8OHZQBIb6HcMqngFPe9nV6obJ7n7ulxXYRoPaMdVR0KoHBRDNw7xY36aiuz+YlwQtFmGvuJNOWcyTwm6/+X4x3xo44djLWbqLJKwk6dDyoDIcpFlSGOL7c0PjJEu6m11rXY+R9ixcnw74T2fLDtdv0MMCCAdbyRKfjD/r83jLZWaqzh81NrGdu5yFOV35kTk0WutZgDIrmAWWI/V7cvrKEVj6ulAdTOcZ0XwpDVW8o5qdWVVt0V+XvtKJ4RI69rUY9MAAYk8BNjrTDt0lexLhjFreBkrfxsYmlKzZcpbFTaZAWMvmrhVsVCfFMujQBiDkRrNNR4hnBUu74R2wDM0vkRraxk3yZ7nKsPSVXpxB/SPvXWtt0EMDqGv3ibGF2CQnSBuNkpMaymfpReFsVhz0gBm7VrDlESZBFvVuOmpBVWwvO64TztPD3sG9ituRRJL1Pa2P/QvQ3E3mPftaDHpwACEKWJle/RuoeHZcDrc1NlHQNEmTJRBVPR1T1wr2FmlL65UFROltTbAj5AAg6Rc1a+tmKX27vGXvmRJ+mgYFWO247+006eEBtApthnoumymyeVmL4ZZzU1fvftT1qInjFAE5OI395TtUcOB0ljT30DTLzWeNw7ACmkXZevDvvOGPozOqbAnWujAwfZIeHUA1w9e1/rw7yUn2GwdOrbhrlRRxspwtC7+5oGoqmLGLe552bFF6xZKqOrrKWia21zx8sZU/67WMKLTci5nuaxGF0XGhvx4dwHA9CODmwdLgLNOgveEtHHXvZUFlHKQ5dvotocA69EagC2Z4zkGT+8Jh6kfDEAdywlkap4be907DH5uzAXxBsizWWTba1RP04AB6811BdTg+0iqVaTuBjc1ep1XhtmsFmiZZqrue18TgmIcxF0fuC0fyZ+0DNq69TPYHi+vqOvxRuRaktpUGewbT5FjTXw8OYL0PujY+0kV+Unidxoj3T5LN6+5SUuhRk6z0hMyLDmMe6PYA7ZJmCL2MLH335MBpymIWNhrTbsMf2NeyTMJcyCipV39NbeveemwAneXB8br21ScFO0l6tRadYtY4qIumMAmCjqPNWJWLoF2SRQeHu7thCJPWwx9wtbo6WE8iPLwpDdJ9LYXuH3LSo36szha0gx4aQDc97jiRcZIc+2tRypdZ1sXvb512tO5aaZKwgaewWC7Drn1oLhSAySzbc3rwt1Uxo/aqeA7Y0lrvYOvwhwuVgBn2teizNnc+mcZXeuqhAcyXTU2H7mP/VemL5XCXo6HlSze1ofSvCmYcNZktm0dNmhL1rK2Dp+r4LtQNFPoOFtvqYMvwx9q1YOtiURLVX3/eb3iJ1iMDKKnnBS2BWXt7RG58VGWrsyrDloehcN93RmBVeTxbpl39hqVjJbPCtpP2XKDs2WGyIkLfQf0poIc/BLoWJPsdQR7d+eTNW3vlu+mRATRN/VGt/p/5Xn1NFUvzG9825tHg5YXGBgrYqJsXncyWQZJabT00OPyRNN0jZQcz21PU8AcwHszSumsBdj7Rp1HUTJGeemAAWyaCmMdHDou7PCEngDHXJ1oeOhemPnAcNUm0N11LNyNkJ6rmBJBT7SrR3XQuVAc
bO3CaXQvcaElfsc6SwZj2kAZQGB+radIiw16Ghi8F2dqLz1msrReArRNBaDag0D1sbIqQKIdlft48D93pMk9NDl27PnFX5dg4Jj3fjcMfzszarw5W6ZUZmOH8qFnUofPJeK7OAgAjXCC1YWLK5hpH4+OwZLMIhfMKkdEY8+pxv1bLWuA6qc2aT4nTdFEfANVRVehYxPhIc7NDNT0TYS2D/Px5HjhqkiyPR020nKA+4ZJ2X1CJYYKlbrqKMp1tfAcB5mVSlA05bPB0aJC8xMDwl68+LzwvO145IeLrnC0adjqabZbW5zPHs3EJwPUygHwDIPX+zA5wVrxHr2YfADtOBPGzw+KTnuF4VCope55R5bxLdQlSQocuPWpSN0eNHSG6l3gZ5vWi2+jhsm0bKytNbM+n+7qx86lTxt2wh0c1oS9f02f8vXjDn56GAQpVpaTF/Sp3AKD+Xo8myeqMs7cqeqaX1i+4aABQv826gK5KaaG/0ADqMl/gDZEO7+FU2wNAr/Oh9fER1zK4veCkkd0jPxrz2JOxYKa1GzXZDN2V5IRL3UOz33DoOvtDQN5my6QkClCDp8ORzh8UAQAr9CDv6gmKUx8XI1/wxRIXHteHLCZT+CjYO1rCtemrA6hieQTgrCqa9QERrrY0r75AADM8y1KvjI4Lo68657YHgEUPdyNo723mj/iBYX0BhnPbNi7M0PIgSTUWzK3CEdhloj1jnLTZN3Z76J77gtHDZTP7A6dxLANI2iemkrhtng51nT0o8uWrx0frTm/4oBIoES2eKKFCvnY8XHBbOHzEbK6YXFf1NgAqvlpXW3gqUFsAV47KoCK4BXDOc2HzGAG0sbzHHRwy7qiSW91fou4A+kmvCrKIgyCSDVM+jhXhpBEZBQHd8gjPmedRCccjgMJl2u4tLSr3hXhmmv2B3c9ybTLXxea+7+BWXTwdaurjWNMkaIQ4E/6MbfdQrwrH59UuC+vtWxlaQKb3ihPAjc/rRTDu7fCmvSPW+8vx/UbIa7IFUOilyucLNhtZ+FHpZs+q+KA6oBv03TZb5kVqJ12cTkWY2Ck9EW075nG24mUYY5u3fc6jK6A1lVBVOnTiEwqTCvPaXGZXWUltsLi1v7NBxml37dLdMGDtuMOWuigMpgAgElVvhDjw3XjB0nVxuQUQxzifcE9CXrgSVGuEpNMtgL7etct12fKVV+lNx+Px23i9nVdHdQYwCvoXgE4C1fpOAC6X5AhBb58vUl6qT9Jx1MSeKRVl8yQrj48TaVIEOI2jcUp8ONssNNPR0+FQzrync3VNX75W24GMn9et3WJFAhhzl6+blhsAKze3ZWMjhCGudQD1sQvk1eHLAhR+DIBi2bsI1I1NL2ztN6lmmDdNGmlxRuiZnZoPLDZNlvOGFbrWcrOgyg720MyzPTuHRfkypTsOmZ5KMktz0dXT4Vjk+EoXffnPiqmM44ar8OF10Qgg7r4l+HJTLG8A5EumA1UjgKs5pOtiaS4UR0PtR2y2Yikcvu2n+ZAi2DQI16htY9Pcc7zrWlHhrDw4zOiM0FMNHSF61GRWW6Fro9rwh55gUk3YxenHy2DZ/k6IqEiK5cmju8buxxZ9+ZogCup9CigEWNfL1wDmWwDhg3zGFvGcb3qRNwAWaM3E9P24FbxALi3A2GFyAvROnnHv6gwbIe7TM8D5DE2aRQ5n7VHr6AigmPV8lZ10V6biGmlR4/10o3S2a3k49fX4eqxt1SpFdYToUZPZ4aiJSIvD+XFeHCRFkQSxV7Y3q1D2MoN296l11+McdBbUAbF/hE8Vq4ZEMrYGkG36R0bYVaK/8rebHs3XWwSDNQO9eruO6GolSGha435y8NAm2M+C5hP7eObVSAhury5f4O8nxdwpf+mc244AduyD3ginotfq5nk6Ox4fERYUVLXD9qap913byqS2KR961GRWbB26xJFbAK4QmRS2XQRF2Ik/Rw/9OCHUIE6rwMosONH2YyNEOdVewXBpuqorqjdZ+NUt9YT0qz5pZ7t6OkYBuXjweqPhKufrX67PPL9K1PE9qRN0dHq6T9vD9JzqAOlfuhvG6zAIt9Pe+gI7Odm8PorROOaBk0bcFmeE3upUFdsbNfGWtTaA0NM41m0WhSXrnntfs0S2XI/ESv9wsLirTh4UqbxhOtrPTRv44NwnnfgkdQMw61MnpnzeajaNNHBoOvyTxjwIqcLQnVcXjpqksyJIt00T6R24FtiJ50KpPaN7aOqGnNUHi3vJXp5EYA93rGjVp8H6IeoEoDPvXhqYRt2FPUs1DH46I5dDFcHy9Mr7cXZ6d4RESztLEnToUjl+2J9EslsYGqd44IINDQk4wdF6I8pOlid0ZVrzU/DoAaDfZZ3Rj1UXAN20dcnZ7aHE+gIbCZzem6eGMQ97lvlZcvqoW139O0JKbS89IGY+WwZlLZ9118CN+8Kxp0Nz7+DxMtStOmlQ5PEcUvPO1eEOo+4qnc1S6kFsnBFO7kA7OFc26zsMYSUe02N2Rew5AFiQ7ep7DbM/Dsbi6Ckf7HAqSTd1WG76SI8HYNc6FBR3baPuKlwWDjWTfM8ZoUNSrUKn075phIltadeCjeuMF4GVqwijHJYr9wU9YdhLzeuN6KkkvYxa7xVnHhDArn3Q7aPuW2+rhvnsMq87I/RoOzTL6W1FpYfupMcrRCrHxrmfgWH2m3DCdFlks/buYwXNnKZ1B8njew+KPBqAHQfhVNYy6l7vWjmYP9I05tHViZPIzqy575sQVugSXPuIiCTtJa5Zb9NzTfCIAhvHbTdU6oVmOl9a854iBj0agHYXA9g66u5G6bxeoxflfDM+AgVzk5NeNzf2Rpl3+TgU9rXMMysx7v0RVaPAsyDMmyyrLn1xgsm82Se/JhwsNi1DXZMI+s0UeTAA1bzDuwqlr/Ge6zGP46+r8RHDmId3WjnsmPbYqqtCJgNkRNBp8Ss9ajLHuSa1sbs9x+5qwdO8ZcsJETX4DhJ57DdT5MEA7DAI17a+AHatEB2xTpZkxv08TiiHcRiik+HEHuVgtu5R7jb7oxKOmgTzYjfX5DCXWKAf+Qoen55Yhvr4yLDP8ll3BaBg88ZJoGsp0b4jSMumVtjyoHcSzIsgM3fR9l5JsNmv6zhd9HoOtvPSHXJdArdpZQS9l4ydztEtS1brjB9KOnE6T82L3utlqK0Ocw31oEjH23BXANpZYDJxTlYYByWUaUtJfb/8dNk45oGBa2cEYUMN0GXkDYbala9P1SqF25KYp3xoCcfWu25uMhYaVmYRFmteGQHlRtB0SZYJ1bun54skxkVaXe072FoUgw2WHZvDdwWgWKamAcc8CUy74ua2YcqHiI+6VvZiqn1nBJw/As+AHB7BSSN2e5eyEx564dQzVP10Dma+aecXuv88SjPDAqaFcIIs11s8bVfocus5qCaYbCYBNK45UBbzqhvRAGq0DDvuQXJXALLMtO0eQx84+qaooKCnfMgii5PAb75nzrKse0pLP0jKjGyFijApWqtBIk0Dw2ofUNA3z/1lVpCQvb1uYdq2EBrOuvrraqw3c0
2OqgCuyMNltfWb19ynomxchtoxNTbCoOOqdvcFoGNsZtrGXlB4NmTxjQ+OnMgG5WRyVDA7EINeGyWE5AxZQWWB4WVSRWJly+P1KRl2cwRJQI14+UlAb3yM17G7Rhw1SZZB6BeNm59U7gsWmZoXJkFADlLiWxJ081G9LwCZ8aIy0ztnz4IiIwpgFx72kuoZ9JZQbzriXobwLUVglBWBaYtKho0PyA7J0SxIqZ1P44D2W3bTIqIn1c+SNDsYVMQtFwOo9zZeOiBapAH1yrsiDQJ6rV8XCO3UJ3VnABpbjKZST/qKXoFK+b7vUHVvCPO9oy4w5TjwPV2ECkWU5+twyA4ZruysWFJVKNPaWMJ0Bzyv8Yz5LClC4s2VxbLIqG6nOLSsklpdQcfu1Lt5ZwD+ZuSqSy5FT0v6BqJVvx2RT9IA4KCr6uYBdB1arqDDPGaIKKQhcDPhpUnKlKo0ZMdhniE7pos0ZeeMizRkx3iR5ntuusjm53vzAApOS4R02JTRYbz0DIEsocMSU6peaUp1SofZpot0MzpsZMqO75iys6DDClOqyqbDXtgzHUj0Hd08gJL/8dtmfcOF/Xsi7NvvJ+yXP1GBP0Uep8L+xFnwNyrwbynj//0zIe5F/6TCfv1/bPoDFfg/AODviLD/BgB//C9Cf1gw/mcqEAGkwv4MAJKdQDMAkGzscWEdrjC6VQwAkpNAX+4YwG+a9UcNYHOYBvB7KpAjgERYBSARiAD++rtmfYcAfkeE/YAAUoF/RQAJGQH814cAOAcAqRaQCwCSq+lZACA5Fvr02wLwm1MB/OZDAPx5APA3BuDJFvBjABws4AMC+O2HWMCPKYIHC/iAALT03kkAACAASURBVA4WcADwM/TJdcDBAg4A1jVYwAHAq2qoAw4AXlWDBRwAvKqGOuAA4FU1WMABwKtqqAMOAF5VgwUcALyqhjrgAOBVNVjAHYB/3/t5BOBhYA3Ag8ABwO4a6oCDBbyqBgs4AHhVDXXAAcCrarCAA4BX1VAHHAC8qgYLOAB4VQ11wAHAqwqnZTbCsJ6WSaDygQBSHOlpmUSYCcCfNYDNgd9dYVrm/NRpmQgguVPNHQP4J0oA4E9U2N8m7PV7MiIASIV9jwCSqQKAP/xKCADkVNgPr2zyDyrwFwCQCvsVAfwzIbSAf6ACEUAi6O9/4GyeEKsYeCMAMKeWOAAAn6iI2Yo9Z1Tg670CeOrKCKvTV0ZY0mHL01dGWNFh4cesjOCbsvMRKyM8sXc68F4BHPTYGgAcdFUNAA66qgYAB11VNw+g8j5nqdBKrudJ5qrWLfwuKH2BwvuExUg3usJFwgUSmyxVAFbLJzvkyu5R+w5YF4p0nIpFrvj1ARK2rZgb26dsPn+icquEW29b52443F36IqX1mRfJPOoCEUBs54/hNbQ5FX1Ghhi0OCXSkfLPBdCyEEDrUwGEC3Ss+BMBxIuUsXWJbd67yqMuEACcc9vxp1wYAEzeWs/AZ4ffzN/75JDS5wIoBwA/RgYAPY5PWPACAZS6pNbr+eMPAfUFrCksEUCxqTOodc5lVYuA8tuT8m2udwQV3vaY2R0COFjAD5IBwKKye0UCANqcv/lgzXCccJUyh2PpHFYW0Fr3ZYsJ5+8IBYbNXOa+L5657h73mZxDGO5PAr+nk9dL5HwA8OK6MQBn28Flm088b/XkVgA+BQDgk69mUDgDgBFfj15PVx58p1jIbZXzOWNvPFLCeVs4ko3eHZXA8QGPhcUHALvoNw/geLz5rOuANvCjWXtBAH0snH2WPPt800RG5mTgybeUoVmEo3XIWwJn4fjguF19NdQBO+k3D+DiEEDVAOA737OTT7jJreLoeSOhArkGcIlBi9FowQOhQbxHAAcL+EEy1QGrFq4fmQDkC77dfM8LV4CjwjAA0NoDMONpAIoVx2sbAOyk3zyAceX++r7YAxCRWtUABLiq/ErcWtHhscTGCfM3RzOe4F/rhrJurwyt4E76zQPI3lYQVECpuQVwNYH7wtM9AMFKTt70cIlbkRYDlHAVzyu2BvBlgfZwIplYOGz8LJh6vsNGyBXqgP4awM8birs1AMX72l9wC2AOf78DaFsAcSRE8IWOgYF8BM9qAr9f4XFxbI1APDhOYRjwLHDLpucLjYTEfu6f8v8pkfI1gOVJsfVZ+0aI1wB+4FU1XSQAGH3ijY0MAOotcXXXM9bq9D63Cr5wPCZ9gVvbSuahRVB5lYTw/cpAOL7ekni9a66T64OrMPgtxEXMSGSdqDA+LZ4G8MS4p0kD2FXhRU6pATxVJ2bBAODH6DJlih+dpDxLs/KUiCUAmJ92ziiK08LqHwteY69jXq0iPeEERyoFkyfHhTt70g36dACvKmV9at1xd9aP9XGC6hs57+yzVF70Ih8UQChFP9G/aXtWKEk//gyf6UXVJHwJLtdielAAfetT23gblR9vdi9rf04SvAT+xRJ7TAChKMyvcd5PME+XtT+nKcJWzIX0kABKss31sRIXfDCkLmp/ThPe30u9BA8JYP4ZIDTI+xTuL2l/TpR3uZfgEQG84O3pJ/9TxjPAzn7isEmzLveKPyCA8moPqPycmufVXrCdZHypwesHBPBqRZT8rDHka1Ux9uRdqpl3MQDl6Br9Hg26nn1Qn8XF9Wz8TvmFOrouByC/drFQCWtIVzq182lt74vZn9N1qa5+ANCfTwN8c2U2XWA/qptN507gszyFE3iJwB+rAHCPQ2cx0b5X+WyaYhQVrJLqPbCWfAFBXjqdX+3WaKMQXaejFs/tf94w2dr+XNMMbvpaz8zDl68+T7IVOqU+P2dL9Hse8SLj8CFE91L0OPX4KJvA74K/ZykPcCbIPJyuA8aVo2qsAfT4NJvzi6yIcIIcoa3QVYaqRO5+VhsEtW4E+Nfo7dyoGm1SZxbEX74GK/QyzVjGtV8qIKedALcACvYyggOfUgiEL+bPjCGE7mvA9Hym17lOSOI8kPELY1VC11AeXW+wXlqx0g/kk8yvtj/CvuaoXFUIl2fWvL58tXnhIDFPUz/3Qy6recJ7ACqeQcjimemQ9Am+QyPjOIKnEDCrFq0W2h/V0p+uVAmLLe96I6Vw5tjyc/uzasJof8qrDHhvBW975Jz7wkMd0J5ynrnseb0E6xGAThWAALoawHw9PWS9sOxKVn/4bLPIwnXKYKndSK/1SHwrjuPPtL94tiv3B1ZutOeloVvB0gZmnhdMStdlRwACVu46pALQ4VU1C8tpHVD9gRYwZtezgEoj8KmLju3JA/zh9J91dqmQ+Os6BwpP3/HzLvnLf/Al00sLFdgQETngFeGsX5vp6XJI4TNW9Ry1BVDXAdnrjE1xsrCzMYdQAx9BBREOu04d0KkQsK/TSYv4f5799W0LH35sX68drOBOx+dfM07LnBdjsGnymacpX+GyavAbjRt/KhZQJgOSq2IJNlIDGLyC9YcW73QdkKyX7JDQAIYyeFXMdzOIP1d5dT/86zwTiQB+nhO2Kj/X4jZI+lUWzqsGYD9gMsWlDpgMRwt8o9xwsvSQIi+ZZH6B/YDFZAn2LULUS
vzhBxP8nqliMtt0PTijQH+RXKteovGLrzZGFX+uEzb6RsefO7XySCo+vxpAjYRcy4ydLt0GuaK7+qXGpjpLoBG88vCTj2/BWa/d4wCI05Cu2S3mfD4MzvWGHTcS53Z8UQA+XWFO2Xly7Ot6Q6grOGHL6Cqe3/tyz7zvj+OOddVhKZC4jhP2tWcowYWfVe95HAB/o7q2W9a5GgAcdFUNAA66qgYAB11VA4CDrqpWAOczxp723FsWCWPEQkeDBvVWK4DTyWYN1EovI+Zczed50MOpHcBxHcDViDFnsICDLqQvX5Ujo0h3osoc10pkIpeRw5gX6ZGtLYCiOgoAdHPB/ByFX/jRbUyHG3SX+vI1febvHDf9cHBh6Cn6oD7zBQvw22gHYInOz6UGEHdomGpn6Iy5YzjuZbCIg04U7hXnMHeEfqcjhK+EfzZ+yKsNV9cA4maG2tV0DSBq8S71RCXByY2GBw0ya71RDXrZCxzP5Fa124cErgRuxLUGMERHffR63gGYoRerbo9kq6tew6A71ma3TB4zsai2a6i2myn15g1bAJNqAlK0A9BH9mT19bUmYg66e60BdAHA1+fc8TcA+jx1vHAHYPqM+0P4YgugwE0LAcBCf3/tyxh0r1oDWHJP768lNwBm1bdbAG1tFuW2EcLGTxjdfQ+umftBdy8EsPB8bP3ysecsNgBa3FaWLoKrjmgXtw4OcRuvCsCCh7gBDoCZec50fu3LGHSvAgBfcXtzXIfoifNslazrgAXnK9ygcDyqumHEiPP3GFrFC4a7sT6/YtXvFUiFuuJoqAIOOlG6Fbxx5lWCScncyskWvxQufoG/2Mbld/23FCgEz1XXd8oddLfatIJvR8kLpafQfSUDn0X8RAbO2ZiOWJ/94j2TR76y7IlUwV7JsGfPfyYDp/XLn9GnsOTzitKrtJ+osJcZI+OtnusNSOeZ0tMrK57I0AwunwxUEZ3q7QG4+Nv3hH7KXE6Ffc+F/RMV9rcJeyVT5fV5rZ7hFCz9nx8I/SVhnAr7gXs5GfiPn+qXP/43deQvoeR/oMRlSAX++ccR4/8iI9aXlHN4FDcLmqWzUdkcVk5SxkMiYsxVzKmIyZev8trLDR9o8f23hP4GAP6RCPsGAPw9FfH7CfuFTPWnQwCpA/8EAP775++++7npv38AgL82B333MwJIRfz1EMAfqCP/AgD+FyUA8Ecq7A8A4J/JiIcAUs/FAwDJEa8EACQ5AgCfqLD49hxSAcBvGlUB2Bz2zR81gETECkAi8BhA4kAN4He/a9R3GsDmsN99hwBSEX84ApA68soAOh8CoPUoAH7z4QB+MwD4Gwfwm0e0gD8PAFJB19JgAU8A8O8DgBfT5S3gN4MFHADsrsECDgBeVYMFHAC8qgYLeAKAQx3wchos4AkADhbwcrrVjuibtoADgJfTSRbw28ECDgBeSIMFPAHAoQ54OQ0W8AQABwt4OQ0WsAuAfzcCWA+tA3gQcwDwUIMF7ALgIU19LeAu4gDggQYL2APALU9NAG4gG4rgXhos4AkADnXAy2mwgAOAV9VgAQcAr6rBAp4A4NAPeDkNFvAEAAcLeDkNFnAA8Kr6oGmZfzp7WuYfAcB//Pxdo6ppmVTgelpmY1jjtMzGI295WmbwQNMyR/wnQhwApMJ+AgDpiFP2TgfWdzz1DKdgAf+FEF8yMuwX7kV0xIMnPqaPDDerMTZJhnTYmBki1heGcPgLtbwDAEgFvnAAkFr84QUnplOpvt8egHoPZKvpP8thFogIlJ4VN0eMrYjFdKr1N1dQx8GRuEEvJZ/ZdKBQdNjBjr85faTnWtTiA5blenRgzmI68PDySdn4aChVj+aEy789AAf9pjQAOOiqGgAcdFUNAA66qgYAB11VA4CDrqoBwEFX1QDgoKtqAHDQVTUAOOiqGgAcdFUNAA66qgYAB11V9w+gBAn4T7rXzkkH3VVm+8jdXFjfXdvuGUApPD+KLcvWsqwy8m9327CGzHo3m9k+cvWFxdsLi/HCur9fdwugcOBplr6nRPXSSSGUk8N30Q0+V53Z+E4y20fSieDCcmd7YVIoz8cn43S8sPsEUPj62TW8Z1I5pR07t7R9511lto/gtbIjRzVcmCsQTL8Lg30AdE0J9rC6Z8r1YjtXhqcmvdKObmQDsrvKbC+pyI4804Wp3C69Vio0gMLuRI/iztF3mb8J86oPDjlx5TJyHcsyXXcl4dux97EZ6aJumZW3kdle8roYOAmX77SgpQFMuaWPL/A2FD51bAOAkmeM2fEegAX/yBLFdeyym7mAZx9f2bDcVWZ7yWsHa3NkbJvfLQTQ5c8r/Kz0DCkeUsc2WUCkbTLfA5B9ZK1axT2ek3Ts6JpV/LvKbB+JqCt+KM/8biGAES8Rn3zOR4Wf8HEKVJWjBVhCdybscYDxZTbJvArAPIEftg2hRcSKXATPL4EAOO1JCrfQT6Fczp3FgjSkJ0vmdo8LxwiRffzKfJL6Zza/XmZ7ybHzfoWcb/v0nUAAJwv2DthYU75aRBP+MhZswWcLKF1dzqfpG3eZ+8yTEa8A1Bzyd8kE912eqcn760Qp/jYJcI5rBv+e+Hs6PZjxfL6UFfUu3ZUVX6eNqazyfjLbR6KPXV9LlhYZBwAUgEqINTelmcEiOEKLGHLBeKo30YbP8IVVAei+2/AdfLa562IdcDzDuBkeiSkxtnoC4t8v3BpxWioTzQK7co3KlXOSNXOvk9k+8kzWjBZ9PwBAREYiewp3r3cRwBn3HAcpxFqhBz8mWEnc1AEXU5YtZgVL5lUjZLypAyoe69ReFnDU9KIAupF1YhXJ+/yS7a4y20v+SVaAoXUnCkQA8OnJDu23xT6Ai2rVhpzhNwjgdMx2AObcXZXlO4O2cx1AD77ZAri6JIAyjk7uZySv/aMk4/J+MttHbn7qi4U3pbl+8eWrxxdjEJS3ewA+SwGCEnYD4ITtAJQcCl+oAHL1WQCKzo/FxVFx6dYAEHH+mWP/d5XZPnKjznVUfWF1jwsi9pevhV4dRwJ3uhtGA2jrpY4ctgUwwzqiv+mGeXke4yJC03U/4HjxwQAKq0uTGof7y3i94EptRPwc+9lbd5XZPnK7Gfa104W+sDJ39kaBGu3nl/+FCIHmAM772JLsaWpLd8VDawXQbQCU/NnO+AZAGyG1MKIGcM5tsQVQt4JH7JIAyg6PVHrR1jcBh/r1iHi0GQmCh3qpzLRm5J4y20du1IE/lW+dLiq/hBIudDsWlFvHNvBLMK4aXs7UY/l0IpkzWQnmFtNpAr9WcDfVBH8spiM8pDrNFFAWE/jLHVnwafSkxAobIdOcxVBY63XklmSHdk91eKQqt2PncDTaVX5sr8eLZPxJVau7ymwvRXHbEdK3LX9j8TbX5wpnN9idx0cMf/m/F8vhB8kt2x6HKvWja3o/4WFXAwzyc5qXbtlmvW4os73kN1ivmoRv502eMWxvZL7h9ty+O1be8uaJyDY5NMm86rgXn9LF5h+/4jW1Zda3o8/LbB95trn96wJ+hiNcx44x+LiAuHkAnZY3z7H9tlezGpBr
[... remainder of base64-encoded PNG figure data omitted ...])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WyRZCUtO2Dbm", - "papermill": { - "duration": 0.035247, - "end_time": "2020-10-20T14:08:12.220409", - "exception": false, - "start_time": "2020-10-20T14:08:12.185162", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "### Reversible Layers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xAdYG2122Jt7", - "papermill": { - "duration": 0.03461, - "end_time": "2020-10-20T14:08:12.289666", - "exception": false, - "start_time": "2020-10-20T14:08:12.255056", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "---\n", - "\n", - "Locality-Sensitive Hashing solves the computation problem, but a memory issue remains. Reformer addresses it by recomputing the input of each layer on demand during back-propagation, rather than storing it in memory. This is accomplished with Reversible Layers (*activations from the last layer are used to recover the activations of any intermediate layer*).\n", - "\n", - "Reversible layers keep two sets of activations for each layer:\n", - "\n", - "- One follows the standard procedure, in which activations are added as they pass through the network.\n", - "\n", - "- The other captures only the changes. Thus, if we run the network in reverse, we simply subtract the changes applied at each layer. A minimal sketch of this idea follows the figure below." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cgA4DL7g30bG", - "papermill": { - "duration": 0.038825, - "end_time": "2020-10-20T14:08:12.363527", - "exception": false, - "start_time": "2020-10-20T14:08:12.324702", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "![Reformer Reversible.png](data:image/png;base64,[... base64-encoded PNG data omitted ...])" - ] - },
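- { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "*Editor's sketch (not part of the original notebook):* the cell below illustrates the reversible-residual idea above with plain NumPy. Here `f` and `g` are hypothetical stand-ins for the attention and feed-forward sublayers; the point is only that the inverse pass recovers the inputs exactly, so intermediate activations need not be stored." - ] - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import numpy as np\n", - "\n", - "def f(x):  # stand-in for an attention sublayer\n", - "    return np.tanh(x)\n", - "\n", - "def g(x):  # stand-in for a feed-forward sublayer\n", - "    return 0.5 * x\n", - "\n", - "def reversible_forward(x1, x2):\n", - "    # Standard reversible-residual pair: y1 = x1 + f(x2), y2 = x2 + g(y1).\n", - "    y1 = x1 + f(x2)\n", - "    y2 = x2 + g(y1)\n", - "    return y1, y2\n", - "\n", - "def reversible_inverse(y1, y2):\n", - "    # Run in reverse: subtract the changes in the opposite order.\n", - "    x2 = y2 - g(y1)\n", - "    x1 = y1 - f(x2)\n", - "    return x1, x2\n", - "\n", - "x1, x2 = np.random.randn(4), np.random.randn(4)\n", - "y1, y2 = reversible_forward(x1, x2)\n", - "r1, r2 = reversible_inverse(y1, y2)\n", - "assert np.allclose(x1, r1) and np.allclose(x2, r2)  # inputs recovered exactly" - ], - "execution_count": null, - "outputs": [] - },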
- { - "cell_type": "markdown", - "metadata": { - "id": "5IGhItKo6kIr", - "papermill": { - "duration": 0.035579, - "end_time": "2020-10-20T14:08:12.433667", - "exception": false, - "start_time": "2020-10-20T14:08:12.398088", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "## Model Architecture" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BTv1SBEn9-Wa", - "papermill": { - "duration": 0.034786, - "end_time": "2020-10-20T14:08:12.503419", - "exception": false, - "start_time": "2020-10-20T14:08:12.468633", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "We will perform the following steps:\n", - "\n", - "* Take input tensors from our data generator\n", - "\n", - "* Produce semantic embeddings from an embedding layer\n", - "\n", - "* Feed these into our Reformer language model\n", - "\n", - "* Run the output through a linear layer\n", - "\n", - "* Run that through a log-softmax layer to get predicted classes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4s0vDUd--pY4", - "papermill": { - "duration": 0.034523, - "end_time": "2020-10-20T14:08:12.572892", - "exception": false, - "start_time": "2020-10-20T14:08:12.538369", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "We use:\n", - "\n", - "1. [`tl.Serial()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Serial): Combinator that applies layers serially (by function composition). It is commonly used to construct deep networks and uses stack semantics to manage data for its sublayers.\n", - "2. [`tl.Embedding()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Embedding): Initializes a trainable embedding layer that maps discrete tokens/ids to vectors.\n", - "\n", - "3. 
[`trax.models.reformer.Reformer()`](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.reformer.reformer.Reformer): Creates a Reversible Transformer encoder-decoder model.\n", - "\n", - "4. [`tl.Dense()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Dense): Creates a Dense(*fully-connected, affine*) layer\n", - "\n", - "5. [`tl.LogSoftmax()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.LogSoftmax): Creates a layer that applies log softmax along one tensor axis.\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:12.654167Z", - "iopub.status.busy": "2020-10-20T14:08:12.653125Z", - "iopub.status.idle": "2020-10-20T14:08:12.656426Z", - "shell.execute_reply": "2020-10-20T14:08:12.655814Z" - }, - "id": "gDqWqFKT6a6r", - "papermill": { - "duration": 0.046731, - "end_time": "2020-10-20T14:08:12.656598", - "exception": false, - "start_time": "2020-10-20T14:08:12.609867", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "def NERmodel(tags, vocab_size=35181, d_model = 50):\n", - "\n", - " model = tl.Serial(\n", - " # tl.Embedding(vocab_size, d_model),\n", - " trax.models.reformer.Reformer(vocab_size, d_model, ff_activation=tl.LogSoftmax),\n", - " tl.Dense(tags),\n", - " tl.LogSoftmax()\n", - " )\n", - "\n", - " return model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "_kg_hide-output": true, - "collapsed": true, - "execution": { - "iopub.execute_input": "2020-10-20T14:08:12.748207Z", - "iopub.status.busy": "2020-10-20T14:08:12.747072Z", - "iopub.status.idle": "2020-10-20T14:08:12.754662Z", - "shell.execute_reply": "2020-10-20T14:08:12.754047Z" - }, - "id": "NsCct_PV8kEi", - "papermill": { - "duration": 0.062424, - "end_time": "2020-10-20T14:08:12.754804", - "exception": false, - "start_time": "2020-10-20T14:08:12.692380", - "status": "completed" - }, - "tags": [], - "outputId": "fc664cfd-87a1-4f98-cadc-8fdcafc9929f" - }, - "source": [ - "model = NERmodel(tags = 17)\n", - "\n", - "print(model)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Serial_in2_out2[\n", - " Serial_in2_out2[\n", - " Select[0,1,1]_in2_out3\n", - " Branch_out2[\n", - " []\n", - " [PaddingMask(0), Squeeze]\n", - " ]\n", - " Serial_in2_out2[\n", - " Embedding_35181_512\n", - " Dropout\n", - " PositionalEncoding\n", - " Dup_out2\n", - " ReversibleSerial_in3_out3[\n", - " ReversibleHalfResidual_in3_out3[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention_in2\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in3_out3[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention_in2\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in3_out3[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention_in2\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " 
Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in3_out3[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention_in2\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in3_out3[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention_in2\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in3_out3[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention_in2\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ]\n", - " XYAvg_in2\n", - " LayerNorm\n", - " ]\n", - " Select[2,0,1]_in3_out3\n", - " ShiftRight(1)\n", - " Embedding_50_512\n", - " Dropout\n", - " PositionalEncoding\n", - " Dup_out2\n", - " ReversibleSerial_in4_out4[\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in4_out4[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " EncDecAttention_in3\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in4_out4[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " EncDecAttention_in3\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in4_out4[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " EncDecAttention_in3\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " SelfAttention\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in4_out4[\n", - " Serial[\n", - " LayerNorm\n", - " ]\n", - " EncDecAttention_in3\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " LayerNorm\n", - " Dense_2048\n", - " Dropout\n", - " LogSoftmax\n", - " Dense_512\n", - " Dropout\n", - " ]\n", - " ]\n", - " ReversibleSwap_in2_out2\n", - " ReversibleHalfResidual_in2_out2[\n", - " Serial[\n", - " 
LayerNorm\n", - "          ]\n", - "          SelfAttention\n", - "        ]\n", - "        ReversibleSwap_in2_out2\n", - "        ReversibleHalfResidual_in4_out4[\n", - "          Serial[\n", - "            LayerNorm\n", - "          ]\n", - "          EncDecAttention_in3\n", - "        ]\n", - "        ReversibleSwap_in2_out2\n", - "        ReversibleHalfResidual_in2_out2[\n", - "          Serial[\n", - "            LayerNorm\n", - "            Dense_2048\n", - "            Dropout\n", - "            LogSoftmax\n", - "            Dense_512\n", - "            Dropout\n", - "          ]\n", - "        ]\n", - "        ReversibleSwap_in2_out2\n", - "        ReversibleHalfResidual_in2_out2[\n", - "          Serial[\n", - "            LayerNorm\n", - "          ]\n", - "          SelfAttention\n", - "        ]\n", - "        ReversibleSwap_in2_out2\n", - "        ReversibleHalfResidual_in4_out4[\n", - "          Serial[\n", - "            LayerNorm\n", - "          ]\n", - "          EncDecAttention_in3\n", - "        ]\n", - "        ReversibleSwap_in2_out2\n", - "        ReversibleHalfResidual_in2_out2[\n", - "          Serial[\n", - "            LayerNorm\n", - "            Dense_2048\n", - "            Dropout\n", - "            LogSoftmax\n", - "            Dense_512\n", - "            Dropout\n", - "          ]\n", - "        ]\n", - "        ReversibleSwap_in2_out2\n", - "      ]\n", - "      XYAvg_in2\n", - "      LayerNorm\n", - "      Select[0]_in3\n", - "      Dense_50\n", - "      LogSoftmax\n", - "    ]\n", - "    Dense_17\n", - "    LogSoftmax\n", - "]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1GsNxS4JYETt", - "papermill": { - "duration": 0.041676, - "end_time": "2020-10-20T14:08:12.833227", - "exception": false, - "start_time": "2020-10-20T14:08:12.791551", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "# Train the Model" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:12.920850Z", - "iopub.status.busy": "2020-10-20T14:08:12.919773Z", - "iopub.status.idle": "2020-10-20T14:08:12.924442Z", - "shell.execute_reply": "2020-10-20T14:08:12.923657Z" - }, - "id": "9nhKmsUkYFgD", - "papermill": { - "duration": 0.051837, - "end_time": "2020-10-20T14:08:12.924577", - "exception": false, - "start_time": "2020-10-20T14:08:12.872740", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "from trax.supervised import training\n", - "\n", - "rnd.seed(33)\n", - "\n", - "batch_size = 64\n", - "\n", - "train_generator = trax.data.inputs.add_loss_weights(\n", - "    data_generator(batch_size, x_train, y_train, vocab['<PAD>'], True),\n", - "    id_to_mask=vocab['<PAD>'])\n", - "\n", - "eval_generator = trax.data.inputs.add_loss_weights(\n", - "    data_generator(batch_size, x_test, y_test, vocab['<PAD>'], True),\n", - "    id_to_mask=vocab['<PAD>'])" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:13.011288Z", - "iopub.status.busy": "2020-10-20T14:08:13.010310Z", - "iopub.status.idle": "2020-10-20T14:08:13.013512Z", - "shell.execute_reply": "2020-10-20T14:08:13.012867Z" - }, - "id": "3CZWK9HgY_lj", - "papermill": { - "duration": 0.05051, - "end_time": "2020-10-20T14:08:13.013644", - "exception": false, - "start_time": "2020-10-20T14:08:12.963134", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "def train_model(model, train_generator, eval_generator, train_steps=1, output_dir='model'):\n", - "    train_task = training.TrainTask(\n", - "        train_generator, \n", - "        loss_layer = tl.CrossEntropyLoss(), \n", - "        optimizer = trax.optimizers.Adam(0.01), \n", - "        n_steps_per_checkpoint=10\n", - "    )\n", - "\n", - "    eval_task = training.EvalTask(\n", - "        labeled_data = eval_generator, \n", - "        metrics = [tl.CrossEntropyLoss(), tl.Accuracy()], \n", - "        n_eval_batches = 10 \n", - "    )\n", - "\n", - "    training_loop = training.Loop(\n",
" model, \n", - " train_task, \n", - " eval_tasks = eval_task, \n", - " output_dir = output_dir) \n", - "\n", - " training_loop.run(n_steps = train_steps)\n", - " return training_loop" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-10-20T14:08:13.124230Z", - "iopub.status.busy": "2020-10-20T14:08:13.103251Z", - "iopub.status.idle": "2020-10-20T22:19:59.586075Z", - "shell.execute_reply": "2020-10-20T22:19:59.584833Z" - }, - "id": "Y8kYOG9xZNF7", - "papermill": { - "duration": 29506.536646, - "end_time": "2020-10-20T22:19:59.586493", - "exception": false, - "start_time": "2020-10-20T14:08:13.049847", - "status": "completed" - }, - "tags": [], - "outputId": "29557238-28fb-4d50-b22d-fd50a203cc52" - }, - "source": [ - "train_steps = 100\n", - "training_loop = train_model(model, train_generator, eval_generator, train_steps)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "Step 1: Ran 1 train steps in 815.40 secs\n", - "Step 1: train CrossEntropyLoss | 2.97494578\n", - "Step 1: eval CrossEntropyLoss | 5.96823492\n", - "Step 1: eval Accuracy | 0.85458949\n", - "\n", - "Step 10: Ran 9 train steps in 6809.59 secs\n", - "Step 10: train CrossEntropyLoss | 5.27117538\n", - "Step 10: eval CrossEntropyLoss | 5.19212604\n", - "Step 10: eval Accuracy | 0.85005882\n", - "\n", - "Step 20: Ran 10 train steps in 5372.06 secs\n", - "Step 20: train CrossEntropyLoss | 6.68565750\n", - "Step 20: eval CrossEntropyLoss | 4.00950582\n", - "Step 20: eval Accuracy | 0.81635543\n", - "\n", - "Step 30: Ran 10 train steps in 1040.84 secs\n", - "Step 30: train CrossEntropyLoss | 3.92878985\n", - "Step 30: eval CrossEntropyLoss | 3.32506871\n", - "Step 30: eval Accuracy | 0.78096363\n", - "\n", - "Step 40: Ran 10 train steps in 3624.02 secs\n", - "Step 40: train CrossEntropyLoss | 3.41684675\n", - "Step 40: eval CrossEntropyLoss | 3.47973170\n", - "Step 40: eval Accuracy | 0.84054841\n", - "\n", - "Step 50: Ran 10 train steps in 195.43 secs\n", - "Step 50: train CrossEntropyLoss | 2.64065409\n", - "Step 50: eval CrossEntropyLoss | 2.21273057\n", - "Step 50: eval Accuracy | 0.84472065\n", - "\n", - "Step 60: Ran 10 train steps in 1060.08 secs\n", - "Step 60: train CrossEntropyLoss | 2.35068488\n", - "Step 60: eval CrossEntropyLoss | 2.66343498\n", - "Step 60: eval Accuracy | 0.84561690\n", - "\n", - "Step 70: Ran 10 train steps in 1041.36 secs\n", - "Step 70: train CrossEntropyLoss | 2.30295134\n", - "Step 70: eval CrossEntropyLoss | 1.31594980\n", - "Step 70: eval Accuracy | 0.84971260\n", - "\n", - "Step 80: Ran 10 train steps in 1178.78 secs\n", - "Step 80: train CrossEntropyLoss | 1.15712142\n", - "Step 80: eval CrossEntropyLoss | 1.15898243\n", - "Step 80: eval Accuracy | 0.84357584\n", - "\n", - "Step 90: Ran 10 train steps in 2033.67 secs\n", - "Step 90: train CrossEntropyLoss | 1.06345284\n", - "Step 90: eval CrossEntropyLoss | 0.93652567\n", - "Step 90: eval Accuracy | 0.84781972\n", - "\n", - "Step 100: Ran 10 train steps in 2001.96 secs\n", - "Step 100: train CrossEntropyLoss | 1.04488492\n", - "Step 100: eval CrossEntropyLoss | 1.02899926\n", - "Step 100: eval Accuracy | 0.85163420\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dRwN9mp74kZG", - "papermill": { - "duration": 0.058348, - "end_time": "2020-10-20T22:19:59.703317", - "exception": false, - "start_time": "2020-10-20T22:19:59.644969", - "status": "completed" - 
}, - "tags": [] - }, - "source": [ - "# References" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "--6G7L9w4mNg", - "papermill": { - "duration": 0.058998, - "end_time": "2020-10-20T22:19:59.820862", - "exception": false, - "start_time": "2020-10-20T22:19:59.761864", - "status": "completed" - }, - "tags": [] - }, - "source": [ - "---\n", - "\n", - "* [Google AI Blog- Reformer: The Efficient Transformer](https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html)\n", - "\n", - "* [Google AI Blog- Transformer: A Novel Neural Network Architecture for Language Understanding](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html)\n", - "\n", - "* [Trax: Deep Learning with Clear Code and Speed](https://github.com/google/trax)\n", - "\n", - "* [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)\n", - "\n", - "* [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n", - "\n", - "* [Illustrating the Reformer](https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)" - ] - } - ] -} \ No newline at end of file diff --git a/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb b/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb deleted file mode 100644 index 1a4679647..000000000 --- a/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb +++ /dev/null @@ -1,2249 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "NMT with Transformers/Reformers using Trax.ipynb", - "provenance": [], - "collapsed_sections": [ - "_WpKodqa9dmJ", - "VY6_SnLM9dms", - "WD0ZqedYIpr3", - "r_8UOdZ_9dnO", - "v5IDVjXl9dnU", - "4U_V6nNQ_37u" - ], - "toc_visible": true, - "machine_shape": "hm", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lAAzPCP8n05S" - }, - "source": [ - "#@title\n", - "# Copyright 2021 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hqqdEx7xtHuH" - }, - "source": [ - "# **NMT with Transformers/Reformers using Trax**\n", - "\n", - "A guide to Neural Machine Translation using ***Transformers/Reformers***. Includes a detailed tutorial using ***Trax*** in Google Colaboratory.\n", - "\n", - "Machine translation is an important task in natural language processing and could be useful not only for translating one language to another but also for word sense disambiguation. 
\n", - "In this notebook you will:\n", - "* Learn how to preprocess your training and evaluation data.\n", - "* Implement an encoder-decoder system with attention.\n", - "* Understand how attention works.\n", - "* Build the NMT model from scratch using Trax.\n", - "* Generate translations using greedy and Minimum Bayes Risk (MBR) decoding.\n", - "\n", - "This notebook adapts a number of cells from the [Natural Language Processing Specialization](https://www.coursera.org/specializations/natural-language-processing)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R8u7YU2uqOXH" - }, - "source": [ - "# Part (-1): Run on TPU" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pO10zU6I87dc" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8QCsYnkLv59s", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "29c114d1-c940-4411-fcf1-984a34b7f9fa" - }, - "source": [ - "# Install JAX/TRAX.\n", - "!pip install --upgrade -q jax\n", - "!pip install --upgrade -q jaxlib\n", - "!pip install --upgrade -q trax" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\u001b[K |████████████████████████████████| 34.7MB 123kB/s \n", - "\u001b[K |████████████████████████████████| 522kB 5.8MB/s \n", - "\u001b[K |████████████████████████████████| 3.4MB 9.7MB/s \n", - "\u001b[K |████████████████████████████████| 215kB 52.9MB/s \n", - "\u001b[K |████████████████████████████████| 3.8MB 35.4MB/s \n", - "\u001b[K |████████████████████████████████| 1.2MB 53.5MB/s \n", - "\u001b[K |████████████████████████████████| 368kB 52.9MB/s \n", - "\u001b[K |████████████████████████████████| 71kB 7.9MB/s \n", - "\u001b[K |████████████████████████████████| 1.9MB 56.5MB/s \n", - "\u001b[K |████████████████████████████████| 3.2MB 53.5MB/s \n", - "\u001b[K |████████████████████████████████| 890kB 59.9MB/s \n", - "\u001b[?25h Building wheel for sacremoses (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wsN3Jxi6vquW", - "outputId": "ea9f5590-0f13-43f0-dcfe-f075941ca07f" - }, - "source": [ - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "grpc://10.43.185.50:8470\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QVw5457jqlOm" - }, - "source": [ - "# Part (0): Important Imports" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nA7u_MqG9dmQ", - "outputId": "741a1e11-319a-4742-e38e-6217da1295e9" - }, - "source": [ - "import trax\n", - "from trax.data import inputs\n", - "from trax import layers as tl\n", - "from trax.supervised import training\n", - "\n", - "import numpy as np\n", - "\n", - "from termcolor import colored\n", - "import random\n", - "\n", - "!pip list | grep trax" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "trax 1.3.7 \n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aByNRLKr9dmG" - }, - "source": [ - "# Part (1): Data Preparation\n", - "\n", - "**You can jump directly to the Trax Data Pipeline (optional) section and skip sections 1.1 to 1.5.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_WpKodqa9dmJ" - }, - "source": [ - "## 1.1 Importing the Data\n", - "We will be using [ParaCrawl](https://paracrawl.eu/), a large multi-lingual translation dataset created by the European Union. All of these datasets are available via [TFDS para_crawl](https://www.tensorflow.org/datasets/catalog/para_crawl). We use the English-French dataset here. You can try the other available languages by changing the `dataset_name` and `keys`, or even try other datasets available on TFDS.\n", - "\n", - "Notice: the first download of the dataset will take a while, so it is preferable to specify a `data_dir` on Google Drive rather than in the Colab runtime. You may also prefer a dataset other than para_crawl, since para_crawl is large.",
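 - "\n", - "For example, to try the English-German pair instead, the call might change as below (a sketch; the `'para_crawl/ende'` config name is an assumption, so check the [TFDS catalog](https://www.tensorflow.org/datasets/catalog/para_crawl) for the exact name):\n", - "```python\n", - "train_stream_fn = trax.data.TFDS('para_crawl/ende', # config name assumed; verify in the TFDS catalog\n", - " keys=('en', 'de'),\n", - " eval_holdout_size=0.01, # 1% for eval\n", - " train=True)\n", - "```"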
- ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "S-cIYEHwrhoZ", - "outputId": "8cb90c0f-db5f-4f07-8e1f-a4e4637d9f33" - }, - "source": [ - "# MOUNT DRIVE\n", - "from google.colab import drive\n", - "drive.mount('/content/drive')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Mounted at /content/drive\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "jEJaYJ5C9dmb" - }, - "source": [ - "# This will download the train dataset if no data_dir is specified.\n", - "train_stream_fn = trax.data.TFDS('para_crawl/enfr',\n", - " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", - " keys=('en', 'fr'),\n", - " eval_holdout_size=0.01, # 1% for eval\n", - " train=True)\n", - "\n", - "# Get generator function for the eval set\n", - "eval_stream_fn = trax.data.TFDS('para_crawl/enfr',\n", - " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", - " keys=('en', 'fr'),\n", - " eval_holdout_size=0.01, # 1% for eval\n", - " train=False)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kk-x0gW9-qsD" - }, - "source": [ - "You can work with your own datasets instead of loading one with TFDS. Opening a file as shown below creates that generator for you; don't forget to make another function for the eval set.\n", - "\n", - "```python\n", - "def train_stream_fn():\n", - " # provide an infinite generator\n", - " while True:\n", - " # open the first language file (e.g. English sentences)\n", - " with open('lang1.csv','r') as f1:\n", - " # open the second language file (e.g. French sentences)\n", - " with open('lang2.csv','r') as f2:\n", - " # loop over the two files to combine the two translations together and yield them.\n", - " for l1, l2 in zip(f1,f2):\n", - " yield (l1, l2)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tPlcZf3RLNAg" - }, - "source": [ - "Notice that TFDS returns a generator *function*.\n", - "\n", - "Let's print a sample pair from our train and eval data. Notice that the raw output is represented in bytes (denoted by the `b'` prefix); these will be converted to strings internally in the next steps." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "16UrIf259dml", - "outputId": "9a860216-c9fa-4f29-e28a-e9c535feefd4" - }, - "source": [ - "train_stream = train_stream_fn()\n", - "print(colored('train data (en, fr) tuple:', 'red'), next(train_stream))\n", - "print()\n", - "\n", - "eval_stream = eval_stream_fn()\n", - "print(colored('eval data (en, fr) tuple:', 'red'), next(eval_stream))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\u001b[31mtrain data (en, fr) tuple:\u001b[0m (b'Our soldiers are in a poor state; the Germans want grain, they will take it and go back, making it impossible for Soviet power to continue in existence.', b\"Nos soldats ne valent rien ; les Allemands veulent du bl\\xc3\\xa9, ils le prendront et ils battront en retraite apr\\xc3\\xa8s avoir rendu impossible l'existence du pouvoir des Soviets. Dire que la d\\xc3\\xa9mobilisation cesse, c'est se condamner \\xc3\\xa0 \\xc3\\xaatre balay\\xc3\\xa9. 
Notes\")\n", - "\n", - "\u001b[31meval data (en, fr) tuple:\u001b[0m (b'These scrumptious brownies can be part of a healthful eating plan.', b\"Ces succulents brownies peuvent faire partie d'un r\\xc3\\xa9gime alimentaire \\xc3\\xa9quilibr\\xc3\\xa9.\")\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kWUH9_PNIe5g" - }, - "source": [ - "Now that we have imported our corpus, we will be preprocessing the sentences into a format that our model can accept. This will be composed of several steps:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VY6_SnLM9dms" - }, - "source": [ - "## 1.2 Tokenization and Formatting\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PWP3GAoXHiwo" - }, - "source": [ - "**Tokenizing the sentences using subword representations:** we want to represent each sentence as an array of integers instead of strings. For our application, we will use *subword* representations to tokenize our sentences. This is a common technique to avoid out-of-vocabulary words by allowing parts of words to be represented separately. For example, instead of having separate entries in your vocabulary for \"fear\", \"fearless\", \"fearsome\", \"some\", and \"less\", you can simply store \"fear\", \"some\", and \"less\" then allow your tokenizer to combine these subwords when needed. This allows it to be more flexible so you won't have to save uncommon words explicitly in your vocabulary (e.g. *stylebender*, *nonce*, etc). Tokenizing is done with the `trax.data.Tokenize()` command. The combined subword vocabulary for English, German and French (i.e. `endefr_32k.subword`) is provided by trax. Feel free to open this file to see how the subwords look like." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Q8R2RxvK9dmt" - }, - "source": [ - "# global variables that state the filename and directory of the vocabulary file\n", - "VOCAB_FILE = 'endefr_32k.subword'\n", - "VOCAB_DIR = 'gs://trax-ml/vocabs/'\n", - "\n", - "# Tokenize the dataset.\n", - "tokenized_train_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)\n", - "tokenized_eval_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yrmCi915HTKA" - }, - "source": [ - "**Append an end-of-sentence token to each sentence:** We will assign a token (i.e. in this case `1`) to mark the end of a sentence. This will be useful in inference/prediction so we'll know that the model has completed the translation." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "RuolzODV9dm0" - }, - "source": [ - "# Append EOS at the end of each sentence.\n", - "\n", - "# Integer assigned as end-of-sentence (EOS)\n", - "EOS = 1\n", - "\n", - "# generator helper function to append EOS to each sentence\n", - "def append_eos(stream):\n", - " for (inputs, targets) in stream:\n", - " inputs_with_eos = list(inputs) + [EOS]\n", - " targets_with_eos = list(targets) + [EOS]\n", - " yield np.array(inputs_with_eos), np.array(targets_with_eos)\n", - "\n", - "# append EOS to the train data\n", - "tokenized_train_stream = append_eos(tokenized_train_stream)\n", - "\n", - "# append EOS to the eval data\n", - "tokenized_eval_stream = append_eos(tokenized_eval_stream)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rbaYhKr99dm8" - }, - "source": [ - "**Filter long sentences:** We will place a limit on the number of tokens per sentence to ensure we won't run out of memory. This is done with the `trax.data.FilterByLength()` method and you can see its syntax below." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Miw7Uu849dm9", - "outputId": "3a6b35a5-f257-42fb-b914-59534d3f2a76" - }, - "source": [ - "# Filter out sentences that are too long so we don't run out of memory.\n", - "# length_keys=[0, 1] means we filter both English and French sentences, so\n", - "# both must be no longer than 512 tokens for training / 1024 for eval.\n", - "filtered_train_stream = trax.data.FilterByLength(\n", - " max_length=512, length_keys=[0, 1])(tokenized_train_stream)\n", - "filtered_eval_stream = trax.data.FilterByLength(\n", - " max_length=1024, length_keys=[0, 1])(tokenized_eval_stream)\n", - "\n", - "# print a sample input-target pair of tokenized sentences\n", - "train_input, train_target = next(filtered_train_stream)\n", - "print(colored(f'Single tokenized example input:', 'red'), train_input)\n", - "print(colored(f'Single tokenized example target:', 'red'), train_target)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\u001b[31mSingle tokenized example input:\u001b[0m [ 107 4 1624 1039 4 6211 34 2544 533 1272 19535 757\n", - " 15 694 3252 371 8538 3 1]\n", - "\u001b[31mSingle tokenized example target:\u001b[0m [ 812 578 28 485 1791 2 11044 49 18 31 9859 5\n", - " 10 3965 2994 3077 26 285 12502 5005 49 21 7275 11759\n", - " 3 1]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WD0ZqedYIpr3" - }, - "source": [ - "## 1.3 Tokenize & Detokenize Helper Functions\n", - "\n", - "- tokenize(): converts a text sentence to its corresponding token list (i.e. list of indices). Also converts words to subwords (parts of words).\n", - "- detokenize(): converts a token list to its corresponding sentence (i.e. string)."
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "OyO5I2e_9dnD" - }, - "source": [ - "# Setup helper functions for tokenizing and detokenizing sentences\n", - "def tokenize(input_str, vocab_file=None, vocab_dir=None):\n", - " \"\"\"Encodes a string to an array of integers\n", - " Args:\n", - " input_str (str): human-readable string to encode\n", - " vocab_file (str): filename of the vocabulary text file\n", - " vocab_dir (str): path to the vocabulary file\n", - " Returns:\n", - " numpy.ndarray: tokenized version of the input string\n", - " \"\"\"\n", - " # Set the encoding of the \"end of sentence\" as 1\n", - " EOS = 1\n", - " # Use the trax.data.tokenize method. It takes streams and returns streams,\n", - " # we get around it by making a 1-element stream with `iter`.\n", - " inputs = next(trax.data.tokenize(iter([input_str]),\n", - " vocab_file=vocab_file, vocab_dir=vocab_dir))\n", - " # Mark the end of the sentence with EOS\n", - " inputs = list(inputs) + [EOS]\n", - " # Adding the batch dimension to the front of the shape\n", - " batch_inputs = np.reshape(np.array(inputs), [1, -1])\n", - " return batch_inputs\n", - "\n", - "def detokenize(integers, vocab_file=None, vocab_dir=None):\n", - " \"\"\"Decodes an array of integers to a human readable string\n", - " Args:\n", - " integers (numpy.ndarray): array of integers to decode\n", - " vocab_file (str): filename of the vocabulary text file\n", - " vocab_dir (str): path to the vocabulary file \n", - " Returns:\n", - " str: the decoded sentence.\n", - " \"\"\"\n", - " # Remove the dimensions of size 1\n", - " integers = list(np.squeeze(integers))\n", - " # Set the encoding of the \"end of sentence\" as 1\n", - " EOS = 1\n", - " # Remove the EOS to decode only the original tokens\n", - " if EOS in integers:\n", - " integers = integers[:integers.index(EOS)] \n", - " return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NKfYr4SA9dnH" - }, - "source": [ - "Let's see how we might use these functions:" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Xb7UEVAS9dnI", - "outputId": "dc1cc233-77ef-4ee2-93cc-34f7ddc586c2" - }, - "source": [ - "# Detokenize an input-target pair of tokenized sentences\n", - "print(colored(f'Single detokenized example input:', 'red'), detokenize(train_input, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))\n", - "print(colored(f'Single detokenized example target:', 'red'), detokenize(train_target, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))\n", - "print()\n", - "\n", - "# Tokenize and detokenize a word that is not explicitly saved in the vocabulary file.\n", - "# See how it combines the subwords 'hell' and 'o' to form the word 'hello'.\n", - "print(colored(f\"tokenize('hello'): \", 'green'), tokenize('hello', vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\u001b[31mSingle detokenized example input:\u001b[0m In the longer term the emphasis on increasing road capacity encourages car-based urban development patterns.\n", - "\u001b[31mSingle detokenized example target:\u001b[0m Au niveau du long terme, insister sur l’accroissement de la capacité routière encourage le développement urbain basé sur les véhicules personnels.\n", - "\n", - "\u001b[32mtokenize('hello'): \u001b[0m [[11068 5505 1]]\n" - ], - "name": "stdout" - } 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "r_8UOdZ_9dnO" - }, - "source": [ - "## 1.4 Bucketing\n", - "\n", - "Bucketing the tokenized sentences is an important technique used to speed up training in NLP.\n", - "Here is a \n", - "[nice article describing it in detail](https://medium.com/@rashmi.margani/how-to-speed-up-the-training-of-the-sequence-model-using-bucketing-techniques-9e302b0fd976)\n", - "but the gist is very simple. Our inputs have variable lengths and we want to make these the same when batching groups of sentences together. One way to do that is to pad each sentence to the length of the longest sentence in the dataset. This might lead to some wasted computation though. For example, if there are multiple short sentences with just two tokens, do we want to pad these when the longest sentence is composed of 100 tokens? Instead of padding with 0s to the maximum length of a sentence each time, we can group our tokenized sentences by length into buckets, as in this image (from the article above):\n", - "\n", - "![alt text](https://miro.medium.com/max/700/1*hcGuja_d5Z_rFcgwe9dPow.png)\n", - "\n", - "We batch the sentences with similar length together (e.g. the blue sentences in the image above) and only add minimal padding to make them have equal length (usually up to the nearest power of two). This allows us to waste less computation when processing padded sequences.\n", - "In Trax, it is implemented in the [bucket_by_length](https://github.com/google/trax/blob/5fb8aa8c5cb86dabb2338938c745996d5d87d996/trax/supervised/inputs.py#L378) function." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "MUlfg9kX9dnP" - }, - "source": [ - "# Bucketing to create streams of batches.\n", - "\n", - "# Buckets are defined in terms of boundaries and batch sizes.\n", - "# Batch_sizes[i] determines the batch size for items with length < boundaries[i]\n", - "# So below, we'll take a batch of 128 sentences of length < 8, 128 if length is\n", - "# between 8 and 16, and so on. A batch size of 128 is also used if the length is over 256.\n", - "boundaries = [ 8, 16, 32, 64, 128, 256]\n", - "batch_sizes = [128, 128, 128, 128, 128, 128, 128]\n", - "# Notice all are 128. As we are using TPUs, we need the same batch_size to run in parallel.\n", - "# You can use different batch_sizes if you are using GPU or CPU.\n", - "\n", - "# Create the generators.\n", - "train_batch_stream = trax.data.BucketByLength(\n", - " boundaries, batch_sizes,\n", - " length_keys=[0, 1] # As before: count inputs and targets to length.\n", - ")(filtered_train_stream)\n", - "\n", - "eval_batch_stream = trax.data.BucketByLength(\n", - " boundaries, batch_sizes,\n", - " length_keys=[0, 1] # As before: count inputs and targets to length.\n", - ")(filtered_eval_stream)\n", - "\n", - "# Add masking for the padding (0s).\n", - "train_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)\n", - "eval_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(eval_batch_stream)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "v5IDVjXl9dnU" - }, - "source": [ - "## 1.5 Exploring the data" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vX-ukU52No8Q" - }, - "source": [ - "We will now display some of our data. You will see the functions defined above (i.e. 
`tokenize()` and `detokenize()`) in action." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zI_Rea2Q9dnV", - "outputId": "2db581ca-3b3a-450f-9b91-1d924420fa51" - }, - "source": [ - "input_batch, target_batch, mask_batch = next(train_batch_stream)\n", - "\n", - "# let's see the data type of a batch\n", - "print(\"input_batch data type: \", type(input_batch))\n", - "print(\"target_batch data type: \", type(target_batch))\n", - "\n", - "# let's see the shape of this particular batch (batch length, sentence length)\n", - "print(\"input_batch shape: \", input_batch.shape)\n", - "print(\"target_batch shape: \", target_batch.shape)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "input_batch data type: <class 'numpy.ndarray'>\n", - "target_batch data type: <class 'numpy.ndarray'>\n", - "input_batch shape: (128, 64)\n", - "target_batch shape: (128, 64)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wE_ilByVN8zT" - }, - "source": [ - "The `input_batch` and `target_batch` are Numpy arrays consisting of tokenized English sentences and French sentences respectively. These tokens will later be used to produce embedding vectors for each word in the sentence (so the embedding for a sentence will be a matrix).\n", - "\n", - "We can now visually inspect some of the data. You can run the cell below several times to shuffle through the sentences." - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vd_71uRi9dnb", - "outputId": "c25d1a92-8953-4d1b-c916-9bfde6882414" - }, - "source": [ - "# pick a random index less than the batch size.\n", - "index = random.randrange(len(input_batch))\n", - "\n", - "# use the index to grab an entry from the input and target batch\n", - "print(colored('THIS IS THE ENGLISH SENTENCE: \\n', 'red'), detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", - "print(colored('THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \\n ', 'red'), input_batch[index], '\\n')\n", - "print(colored('THIS IS THE FRENCH TRANSLATION: \\n', 'red'), detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", - "print(colored('THIS IS THE TOKENIZED VERSION OF THE FRENCH TRANSLATION: \\n', 'red'), target_batch[index], '\\n')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\u001b[31mTHIS IS THE ENGLISH SENTENCE: \n", - "\u001b[0m If we do not give the question the same attention here as other alternative propositions, it is because this is an overarching debate that tends to be divisive within social movements and left-wing political parties. It is indeed a major question that would require a lengthy discussion to be dealt with adequately. 
\n", - "\n", - "\u001b[31mTHIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \n", - " \u001b[0m [ 600 100 271 83 993 4 329 4 470 914 859 60\n", - " 221 3577 1819 2 62 27 382 64 27 50 9052 7500\n", - " 260 876 29 8480 16 14 53 17380 1965 439 392 16958\n", - " 11 2220 15 6347 798 563 39 186 27 2780 17 858\n", - " 329 29 150 3630 17 29834 88 1682 14 53 9780 58\n", - " 20015 3 1 0] \n", - "\n", - "\u001b[31mTHIS IS THE FRENCH TRANSLATION: \n", - "\u001b[0m Si nous n’y accordons pas ici la mÃĒme attention qu’aux autres propositions d’alternatives, c’est que le dÊbat traverse et divise tant les mouvements sociaux que les partis de gauche et qu’il est nÊcessaire d’y consacrer de nombreuses pages pour faire le tour de la question. \n", - "\n", - "\u001b[31mTHIS IS THE TOKENIZED VERSION OF THE FRENCH TRANSLATION: \n", - "\u001b[0m [ 983 108 30 31 88 14331 281 102 1257 10 280 914\n", - " 103 31 89 265 1819 24 31 22892 2 162 31 37\n", - " 36 26 1124 28501 12 17380 32 854 21 16959 4369 36\n", - " 21 10943 16 5 19387 12 103 31 72 37 1101 24\n", - " 31 88 17458 5 1397 4763 40 287 26 3619 5 10\n", - " 329 3 1 0] \n", - "\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UDSPHBZaeRAW" - }, - "source": [ - "## Trax Data Pipeline (optional)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WP2RACXYeTse" - }, - "source": [ - "Those were the steps needed to prepare the data (steps from 1.1 to 1.5) But you could simply use [Trax data pipeline](https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html#Data) `trax.data.Serial` in the next cell. **if you run this cell you should skip (steps from 1.1 to 1.5).**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BUPhstH70Xzu" - }, - "source": [ - "You can work with your own datasets instead of loading your dataset with TFDS you can simply replace the TFDS call with an `lambda _: train_stream_fn()`\n", - "Everything in tf.Serial is a generator. Opening a file as shown above creates that generator for you.\n", - "\n", - "```python\n", - "def train_stream_fn():\n", - " # open the first language file (e.g. English sentences)\n", - " with open('lang1.csv','r') as f1:\n", - " # open the second language file (e.g. French sentences)\n", - " with open('lang2.csv','r') as f2:\n", - " # looping over the two files to combine the two translation toghether and yields them.\n", - " for l1, l2 in zip(f1,f2):\n", - " yield (l1, l2)\n", - "```\n", - "\n", - "and then add\n", - "```python\n", - "lambda _: train_stream_fn()\n", - "```\n", - "to `trax.data.Serial()` instead of \n", - "```python\n", - "trax.data.TFDS('para_crawl/enfr',\n", - " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", - " keys=('en', 'fr'),\n", - " eval_holdout_size=0.01, # 1% for eval\n", - " train=True)\n", - "```\n", - "for both the training and eval streams." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "iwwICFzpCTMa", - "outputId": "ed0b59ac-88d8-423f-b958-c683b6964f37" - }, - "source": [ - "# MOUNT DRIVE\n", - "from google.colab import drive\n", - "drive.mount('/content/drive')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Mounted at /content/drive\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "52LAeuAETG9o" - }, - "source": [ - "# if you run this cell you should skip (steps from 1.1 to 1.5).\n", - "\n", - "# global variables that state the filename and directory of the vocabulary file\n", - "VOCAB_FILE = 'endefr_32k.subword'\n", - "VOCAB_DIR = 'gs://trax-ml/vocabs/'\n", - "\n", - "EOS = 1\n", - "\n", - "# generator helper function to append EOS to each sentence\n", - "def append_eos(stream):\n", - " for (inputs, targets) in stream:\n", - " inputs_with_eos = list(inputs) + [EOS]\n", - " targets_with_eos = list(targets) + [EOS]\n", - " yield np.array(inputs_with_eos), np.array(targets_with_eos)\n", - "\n", - "train_batches_stream = trax.data.Serial(\n", - " trax.data.TFDS('para_crawl/enfr',\n", - " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", - " keys=('en', 'fr'),\n", - " eval_holdout_size=0.01, # 1% for eval\n", - " train=True), # replace TFDS with lambda _: train_stream_fn() if you want to run with your own data\n", - " trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),\n", - " lambda _: append_eos(_),\n", - " trax.data.Shuffle(),\n", - " trax.data.FilterByLength(max_length=512, length_keys=[0, 1]),\n", - " trax.data.BucketByLength(boundaries = [ 8, 16, 32, 64, 128, 256],\n", - " batch_sizes = [128, 128, 128, 128, 128, 128, 128],\n", - " length_keys=[0, 1]),\n", - " trax.data.AddLossWeights(id_to_mask=0)\n", - " )\n", - "\n", - "eval_batches_stream = trax.data.Serial(\n", - " trax.data.TFDS('para_crawl/enfr',\n", - " data_dir='/content/drive/MyDrive/Colab Notebooks/data/',\n", - " keys=('en', 'fr'),\n", - " eval_holdout_size=0.01, # 1% for eval\n", - " train=False),\n", - " trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),\n", - " lambda _: append_eos(_),\n", - " trax.data.Shuffle(),\n", - " trax.data.FilterByLength(max_length=1024, length_keys=[0, 1]),\n", - " trax.data.BucketByLength(boundaries = [ 8, 16, 32, 64, 128, 256],\n", - " batch_sizes = [128, 128, 128, 128, 128, 128, 128],\n", - " length_keys=[0, 1]),\n", - " trax.data.AddLossWeights(id_to_mask=0)\n", - " )" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "wWvu5PraqBQx", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "34480b60-6a40-4be6-f51c-4e481c2e1bc3" - }, - "source": [ - "# Exploring the data\n", - "train_batch_stream = train_batches_stream()\n", - "eval_batch_stream = eval_batches_stream()\n", - "input_batch, target_batch, mask_batch = next(train_batch_stream)\n", - "# let's see the data type of a batch\n", - "print(\"input_batch data type: \", type(input_batch))\n", - "print(\"target_batch data type: \", type(target_batch))\n", - "# let's see the shape of this particular batch (batch length, sentence length)\n", - "print(\"input_batch shape: \", input_batch.shape)\n", - "print(\"target_batch shape: \", target_batch.shape)\n", - "\n", - "# pick a random index less than the batch size.\n", - "index = random.randrange(len(input_batch))\n", - "# use the index to grab an entry from the input 
and target batch\n", - "print(colored('ENGLISH SENTENCE: \\n', 'red'), trax.data.detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", - "print(colored('THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \\n ', 'red'), input_batch[index], '\\n')\n", - "print(colored('THE FRENCH TRANSLATION: \\n', 'red'), trax.data.detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\\n')\n", - "print(colored('THE TOKENIZED VERSION OF THE FRENCH TRANSLATION: \\n', 'red'), target_batch[index], '\\n')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "input_batch data type: \n", - "target_batch data type: \n", - "input_batch shape: (128, 32)\n", - "target_batch shape: (128, 32)\n", - "\u001b[31mENGLISH SENTENCE: \n", - "\u001b[0m Instead they can submit section 2 of the application to Health Canada and have it added to their original application. \n", - "\n", - "\u001b[31mTHE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \n", - " \u001b[0m [13284 182 136 9547 1098 112 8 4 487 14 6051 141\n", - " 11 82 62 4183 14 148 1957 487 3 1 0 0\n", - " 0 0 0 0 0 0 0 0] \n", - "\n", - "\u001b[31mTHE FRENCH TRANSLATION: \n", - "\u001b[0m Ils peuvent plutôt soumettre la section 2 du formulaire à SantÊ Canada pour qu'elle soit ajoutÊe à leur demande originale. \n", - "\n", - "\u001b[31mTHE TOKENIZED VERSION OF THE FRENCH TRANSLATION: \n", - "\u001b[0m [ 2621 589 2726 13657 10 1098 112 28 19418 6 23 7163\n", - " 141 40 103 7 252 419 11275 23 267 775 4948 3\n", - " 1 0 0 0 0 0 0 0] \n", - "\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M-69Mr-9_VEk" - }, - "source": [ - "# Part (2): Model\n", - "\n", - "Now that we’ve seen preprocessing, it’s time to move into Modeling itself. 
Trax allows the use of predefined models, such as:\n", - " - Seq2Seq with Attention\n", - " - BERT\n", - " - Transformer\n", - " - Reformer\n", - "\n", - "We will be using the Transformer in this notebook, as Trax provides a pretrained Transformer NMT model trained on an English-to-German dataset. We are now going to train it on the English-to-French dataset and get results very close to the ones provided by the Google Brain team.\n", - "\n", - "You can simply change `trax.models.Transformer` in the next cell to `trax.models.Reformer` to use the Reformer model.\n", - "\n", - "```python\n", - "# You can check the available pretrained models and vocab files provided by Trax by running:\n", - "!gsutil ls gs://trax-ml/\n", - "```" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "SkAuvdOErAlP" - }, - "source": [ - "# Create a Transformer model.\n", - "model = trax.models.Transformer(\n", - " input_vocab_size=33600,\n", - " d_model=512, d_ff=2048, dropout=0.1,\n", - " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", - " max_len=2048, mode='train')\n", - "\n", - "# Pre-trained Transformer model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", - "# Initialize Transformer using pre-trained weights.\n", - "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", - " weights_only=True)\n", - "\n", - "# You could also initialize the model from an output checkpoint.\n", - "# Simply change 'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz' to 'output_dir/ + last_checkpoint'.\n", - "# For example:\n", - "# model.init_from_file('/content/drive/MyDrive/Colab Notebooks/Transformer_FR_pretrained_336/model.pkl.gz',\n", - "# weights_only=True)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2p0AGzlKQusn" - }, - "source": [ - "You could have a peek at the model layers." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "CvmdtOfeZ9Ff" - }, - "source": [ - "# model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E8FfOMp59doX" - }, - "source": [ - "# Part (3): Training\n", - "We will now train our model in this section. Doing supervised training in Trax is pretty straightforward (short example [here](https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html#Supervised-training)). We will be instantiating three classes for this: `TrainTask`, `EvalTask`, and `Loop`. Let's take a closer look at each of these in the sections below." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "re1ZHUac9doY" - }, - "source": [ - "## 3.1 TrainTask\n", - "\n", - "The [TrainTask](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.training.TrainTask) class allows us to define the labeled data to use for training and the feedback mechanisms to compute the loss and update the weights. 
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gFP83q7S9doZ" - }, - "source": [ - "train_task = training.TrainTask(\n", - " # use the train batch stream as labeled data\n", - " labeled_data= train_batch_stream,\n", - " # use the cross entropy loss with LogSoftmax\n", - " loss_layer= tl.CrossEntropyLossWithLogSoftmax(),\n", - " # use the Adafactor optimizer with learning rate of 0.001\n", - " optimizer= trax.optimizers.Adafactor(learning_rate=0.001, epsilon1=1e-30),\n", - " # have 500 warmup steps\n", - " lr_schedule= trax.lr.multifactor(constant=1.0, warmup_steps=500),\n", - " # have a checkpoint every 100 steps\n", - " n_steps_per_checkpoint= 100,\n", - " # saving a checkpoint every 1000 steps on the output_dir\n", - " n_steps_per_permanent_checkpoint = 1000\n", - ")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7EQI-c999doi" - }, - "source": [ - "## 3.2 EvalTask\n", - "\n", - "The [EvalTask](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.training.EvalTask) on the other hand allows us to see how the model is doing while training. For our application, we want it to report the cross entropy loss with LogSoftmax and accuracy." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "u5hVQ0Qd9doj" - }, - "source": [ - "eval_task = training.EvalTask(\n", - " # use the eval batch stream as labeled data\n", - " labeled_data=eval_batch_stream,\n", - " # use the cross entropy loss with LogSoftmax and accuracy as metrics\n", - " metrics=[tl.CrossEntropyLossWithLogSoftmax(), tl.WeightedCategoryAccuracy()],\n", - " # you could specify the number of eval batch by n_eval_batches = 64 or any other number\n", - " # but it not specified here as we want to evaluate the whole eval data\n", - " # n_eval_batches = 64\n", - ")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "14pSLHEw9dol" - }, - "source": [ - "## 3.3 Loop\n", - "\n", - "The [Loop](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.training.Loop) class defines the model we will train as well as the train and eval tasks to execute. Its `run()` method allows us to execute the training for a specified number of steps." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QdnRbEAz9dom" - }, - "source": [ - "# define the output directory\n", - "output_dir = '/content/drive/MyDrive/Colab Notebooks/Transformer_FR_pretrained_336'\n", - "\n", - "# # remove old model if it exists. 
restarts training.\n", - "# !rm -rf output_dir\n", - "\n", - "# define the training loop\n", - "training_loop = training.Loop(model,\n", - " train_task,\n", - " eval_tasks=[eval_task],\n", - " output_dir=output_dir)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "bRk-1Wsu9doo", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1459c595-b218-4be8-d6ea-805147ca20c5" - }, - "source": [ - "# Start Training!\n", - "training_loop.run(5000)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "Step 1: Total number of trainable weights: 80370196\n", - "Step 1: Ran 1 train steps in 130.79 secs\n", - "Step 1: train CrossEntropyLossWithLogSoftmax | 9.85205269\n", - "Step 1: eval CrossEntropyLossWithLogSoftmax | 9.71523285\n", - "Step 1: eval WeightedCategoryAccuracy | 0.08454708\n", - "\n", - "Step 100: Ran 99 train steps in 487.01 secs\n", - "Step 100: train CrossEntropyLossWithLogSoftmax | 7.02561331\n", - "Step 100: eval CrossEntropyLossWithLogSoftmax | 5.81694698\n", - "Step 100: eval WeightedCategoryAccuracy | 0.22895759\n", - "\n", - "Step 200: Ran 100 train steps in 177.36 secs\n", - "Step 200: train CrossEntropyLossWithLogSoftmax | 5.24865103\n", - "Step 200: eval CrossEntropyLossWithLogSoftmax | 4.46555328\n", - "Step 200: eval WeightedCategoryAccuracy | 0.34939855\n", - "\n", - "Step 300: Ran 100 train steps in 70.54 secs\n", - "Step 300: train CrossEntropyLossWithLogSoftmax | 4.33149672\n", - "Step 300: eval CrossEntropyLossWithLogSoftmax | 3.83377051\n", - "Step 300: eval WeightedCategoryAccuracy | 0.42087770\n", - "\n", - "Step 400: Ran 100 train steps in 67.50 secs\n", - "Step 400: train CrossEntropyLossWithLogSoftmax | 3.96885633\n", - "Step 400: eval CrossEntropyLossWithLogSoftmax | 4.21221638\n", - "Step 400: eval WeightedCategoryAccuracy | 0.34850734\n", - "\n", - "Step 500: Ran 100 train steps in 71.36 secs\n", - "Step 500: train CrossEntropyLossWithLogSoftmax | 3.83042574\n", - "Step 500: eval CrossEntropyLossWithLogSoftmax | 3.64122152\n", - "Step 500: eval WeightedCategoryAccuracy | 0.41977778\n", - "\n", - "Step 600: Ran 100 train steps in 76.19 secs\n", - "Step 600: train CrossEntropyLossWithLogSoftmax | 3.35245323\n", - "Step 600: eval CrossEntropyLossWithLogSoftmax | 3.23352051\n", - "Step 600: eval WeightedCategoryAccuracy | 0.47931033\n", - "\n", - "Step 700: Ran 100 train steps in 79.46 secs\n", - "Step 700: train CrossEntropyLossWithLogSoftmax | 3.10961103\n", - "Step 700: eval CrossEntropyLossWithLogSoftmax | 3.15166879\n", - "Step 700: eval WeightedCategoryAccuracy | 0.46864897\n", - "\n", - "Step 800: Ran 100 train steps in 81.76 secs\n", - "Step 800: train CrossEntropyLossWithLogSoftmax | 2.93149567\n", - "Step 800: eval CrossEntropyLossWithLogSoftmax | 3.35616207\n", - "Step 800: eval WeightedCategoryAccuracy | 0.44215840\n", - "\n", - "Step 900: Ran 100 train steps in 79.12 secs\n", - "Step 900: train CrossEntropyLossWithLogSoftmax | 2.82020950\n", - "Step 900: eval CrossEntropyLossWithLogSoftmax | 2.36338472\n", - "Step 900: eval WeightedCategoryAccuracy | 0.57924318\n", - "\n", - "Step 1000: Ran 100 train steps in 104.90 secs\n", - "Step 1000: train CrossEntropyLossWithLogSoftmax | 2.68358994\n", - "Step 1000: eval CrossEntropyLossWithLogSoftmax | 2.56412911\n", - "Step 1000: eval WeightedCategoryAccuracy | 0.54872346\n", - "\n", - "Step 1100: Ran 100 train steps in 83.85 secs\n", - "Step 1100: train CrossEntropyLossWithLogSoftmax | 
2.58823633\n", - "Step 1100: eval CrossEntropyLossWithLogSoftmax | 2.62969518\n", - "Step 1100: eval WeightedCategoryAccuracy | 0.52579981\n", - "\n", - "Step 1200: Ran 100 train steps in 88.05 secs\n", - "Step 1200: train CrossEntropyLossWithLogSoftmax | 2.51080465\n", - "Step 1200: eval CrossEntropyLossWithLogSoftmax | 2.52758622\n", - "Step 1200: eval WeightedCategoryAccuracy | 0.53838688\n", - "\n", - "Step 1300: Ran 100 train steps in 90.28 secs\n", - "Step 1300: train CrossEntropyLossWithLogSoftmax | 2.46406817\n", - "Step 1300: eval CrossEntropyLossWithLogSoftmax | 2.31228042\n", - "Step 1300: eval WeightedCategoryAccuracy | 0.57349908\n", - "\n", - "Step 1400: Ran 100 train steps in 90.89 secs\n", - "Step 1400: train CrossEntropyLossWithLogSoftmax | 2.39312744\n", - "Step 1400: eval CrossEntropyLossWithLogSoftmax | 2.07446051\n", - "Step 1400: eval WeightedCategoryAccuracy | 0.63776493\n", - "\n", - "Step 1500: Ran 100 train steps in 94.41 secs\n", - "Step 1500: train CrossEntropyLossWithLogSoftmax | 2.35005140\n", - "Step 1500: eval CrossEntropyLossWithLogSoftmax | 2.32324076\n", - "Step 1500: eval WeightedCategoryAccuracy | 0.57490349\n", - "\n", - "Step 1600: Ran 100 train steps in 93.66 secs\n", - "Step 1600: train CrossEntropyLossWithLogSoftmax | 2.31463027\n", - "Step 1600: eval CrossEntropyLossWithLogSoftmax | 2.30394077\n", - "Step 1600: eval WeightedCategoryAccuracy | 0.57527685\n", - "\n", - "Step 1700: Ran 100 train steps in 94.34 secs\n", - "Step 1700: train CrossEntropyLossWithLogSoftmax | 2.23612332\n", - "Step 1700: eval CrossEntropyLossWithLogSoftmax | 2.14128780\n", - "Step 1700: eval WeightedCategoryAccuracy | 0.60737085\n", - "\n", - "Step 1800: Ran 100 train steps in 91.89 secs\n", - "Step 1800: train CrossEntropyLossWithLogSoftmax | 2.23225784\n", - "Step 1800: eval CrossEntropyLossWithLogSoftmax | 2.52646112\n", - "Step 1800: eval WeightedCategoryAccuracy | 0.52996337\n", - "\n", - "Step 1900: Ran 100 train steps in 92.47 secs\n", - "Step 1900: train CrossEntropyLossWithLogSoftmax | 2.17324281\n", - "Step 1900: eval CrossEntropyLossWithLogSoftmax | 2.50346041\n", - "Step 1900: eval WeightedCategoryAccuracy | 0.53843790\n", - "\n", - "Step 2000: Ran 100 train steps in 113.97 secs\n", - "Step 2000: train CrossEntropyLossWithLogSoftmax | 2.14945030\n", - "Step 2000: eval CrossEntropyLossWithLogSoftmax | 2.17006946\n", - "Step 2000: eval WeightedCategoryAccuracy | 0.59461838\n", - "\n", - "Step 2100: Ran 100 train steps in 95.56 secs\n", - "Step 2100: train CrossEntropyLossWithLogSoftmax | 2.13827372\n", - "Step 2100: eval CrossEntropyLossWithLogSoftmax | 2.08724308\n", - "Step 2100: eval WeightedCategoryAccuracy | 0.60836774\n", - "\n", - "Step 2200: Ran 100 train steps in 95.67 secs\n", - "Step 2200: train CrossEntropyLossWithLogSoftmax | 2.10173893\n", - "Step 2200: eval CrossEntropyLossWithLogSoftmax | 1.84778070\n", - "Step 2200: eval WeightedCategoryAccuracy | 0.66895270\n", - "\n", - "Step 2300: Ran 100 train steps in 99.50 secs\n", - "Step 2300: train CrossEntropyLossWithLogSoftmax | 2.07918763\n", - "Step 2300: eval CrossEntropyLossWithLogSoftmax | 2.18345213\n", - "Step 2300: eval WeightedCategoryAccuracy | 0.58705992\n", - "\n", - "Step 2400: Ran 100 train steps in 100.64 secs\n", - "Step 2400: train CrossEntropyLossWithLogSoftmax | 2.06117034\n", - "Step 2400: eval CrossEntropyLossWithLogSoftmax | 2.10406661\n", - "Step 2400: eval WeightedCategoryAccuracy | 0.61676151\n", - "\n", - "Step 2500: Ran 100 train steps in 99.83 secs\n", - "Step 2500: train 
CrossEntropyLossWithLogSoftmax | 2.02957368\n", - "Step 2500: eval CrossEntropyLossWithLogSoftmax | 2.14476347\n", - "Step 2500: eval WeightedCategoryAccuracy | 0.58916318\n", - "\n", - "Step 2600: Ran 100 train steps in 99.18 secs\n", - "Step 2600: train CrossEntropyLossWithLogSoftmax | 2.01416183\n", - "Step 2600: eval CrossEntropyLossWithLogSoftmax | 1.93275166\n", - "Step 2600: eval WeightedCategoryAccuracy | 0.64312029\n", - "\n", - "Step 2700: Ran 100 train steps in 105.54 secs\n", - "Step 2700: train CrossEntropyLossWithLogSoftmax | 1.98193300\n", - "Step 2700: eval CrossEntropyLossWithLogSoftmax | 1.70399988\n", - "Step 2700: eval WeightedCategoryAccuracy | 0.67673290\n", - "\n", - "Step 2800: Ran 100 train steps in 103.56 secs\n", - "Step 2800: train CrossEntropyLossWithLogSoftmax | 1.98967147\n", - "Step 2800: eval CrossEntropyLossWithLogSoftmax | 2.26600218\n", - "Step 2800: eval WeightedCategoryAccuracy | 0.57345927\n", - "\n", - "Step 2900: Ran 100 train steps in 106.24 secs\n", - "Step 2900: train CrossEntropyLossWithLogSoftmax | 1.96915519\n", - "Step 2900: eval CrossEntropyLossWithLogSoftmax | 1.83608222\n", - "Step 2900: eval WeightedCategoryAccuracy | 0.64486253\n", - "\n", - "Step 3000: Ran 100 train steps in 123.97 secs\n", - "Step 3000: train CrossEntropyLossWithLogSoftmax | 1.95634198\n", - "Step 3000: eval CrossEntropyLossWithLogSoftmax | 1.96534336\n", - "Step 3000: eval WeightedCategoryAccuracy | 0.61572355\n", - "\n", - "Step 3100: Ran 100 train steps in 101.74 secs\n", - "Step 3100: train CrossEntropyLossWithLogSoftmax | 1.93480480\n", - "Step 3100: eval CrossEntropyLossWithLogSoftmax | 2.07498431\n", - "Step 3100: eval WeightedCategoryAccuracy | 0.58850276\n", - "\n", - "Step 3200: Ran 100 train steps in 106.99 secs\n", - "Step 3200: train CrossEntropyLossWithLogSoftmax | 1.92409098\n", - "Step 3200: eval CrossEntropyLossWithLogSoftmax | 1.75855708\n", - "Step 3200: eval WeightedCategoryAccuracy | 0.65730476\n", - "\n", - "Step 3300: Ran 100 train steps in 104.48 secs\n", - "Step 3300: train CrossEntropyLossWithLogSoftmax | 1.89491403\n", - "Step 3300: eval CrossEntropyLossWithLogSoftmax | 1.59045112\n", - "Step 3300: eval WeightedCategoryAccuracy | 0.68677223\n", - "\n", - "Step 3400: Ran 100 train steps in 106.04 secs\n", - "Step 3400: train CrossEntropyLossWithLogSoftmax | 1.89865410\n", - "Step 3400: eval CrossEntropyLossWithLogSoftmax | 1.97923064\n", - "Step 3400: eval WeightedCategoryAccuracy | 0.61284077\n", - "\n", - "Step 3500: Ran 100 train steps in 103.33 secs\n", - "Step 3500: train CrossEntropyLossWithLogSoftmax | 1.88124871\n", - "Step 3500: eval CrossEntropyLossWithLogSoftmax | 1.87313676\n", - "Step 3500: eval WeightedCategoryAccuracy | 0.64622855\n", - "\n", - "Step 3600: Ran 100 train steps in 103.66 secs\n", - "Step 3600: train CrossEntropyLossWithLogSoftmax | 1.86413860\n", - "Step 3600: eval CrossEntropyLossWithLogSoftmax | 1.93156767\n", - "Step 3600: eval WeightedCategoryAccuracy | 0.61913443\n", - "\n", - "Step 3700: Ran 100 train steps in 104.92 secs\n", - "Step 3700: train CrossEntropyLossWithLogSoftmax | 1.85850441\n", - "Step 3700: eval CrossEntropyLossWithLogSoftmax | 1.82424903\n", - "Step 3700: eval WeightedCategoryAccuracy | 0.64653146\n", - "\n", - "Step 3800: Ran 100 train steps in 108.30 secs\n", - "Step 3800: train CrossEntropyLossWithLogSoftmax | 1.84594476\n", - "Step 3800: eval CrossEntropyLossWithLogSoftmax | 1.89609993\n", - "Step 3800: eval WeightedCategoryAccuracy | 0.63271540\n", - "\n", - "Step 3900: Ran 100 train 
steps in 106.10 secs\n", - "Step 3900: train CrossEntropyLossWithLogSoftmax | 1.81989634\n", - "Step 3900: eval CrossEntropyLossWithLogSoftmax | 2.22522569\n", - "Step 3900: eval WeightedCategoryAccuracy | 0.57733458\n", - "\n", - "Step 4000: Ran 100 train steps in 122.86 secs\n", - "Step 4000: train CrossEntropyLossWithLogSoftmax | 1.82069206\n", - "Step 4000: eval CrossEntropyLossWithLogSoftmax | 1.69654596\n", - "Step 4000: eval WeightedCategoryAccuracy | 0.66414380\n", - "\n", - "Step 4100: Ran 100 train steps in 105.58 secs\n", - "Step 4100: train CrossEntropyLossWithLogSoftmax | 1.81504095\n", - "Step 4100: eval CrossEntropyLossWithLogSoftmax | 1.82490277\n", - "Step 4100: eval WeightedCategoryAccuracy | 0.64850724\n", - "\n", - "Step 4200: Ran 100 train steps in 103.48 secs\n", - "Step 4200: train CrossEntropyLossWithLogSoftmax | 1.79890764\n", - "Step 4200: eval CrossEntropyLossWithLogSoftmax | 1.94082224\n", - "Step 4200: eval WeightedCategoryAccuracy | 0.61675012\n", - "\n", - "Step 4300: Ran 100 train steps in 105.42 secs\n", - "Step 4300: train CrossEntropyLossWithLogSoftmax | 1.78839958\n", - "Step 4300: eval CrossEntropyLossWithLogSoftmax | 1.64301860\n", - "Step 4300: eval WeightedCategoryAccuracy | 0.68001771\n", - "\n", - "Step 4400: Ran 100 train steps in 107.00 secs\n", - "Step 4400: train CrossEntropyLossWithLogSoftmax | 1.77166772\n", - "Step 4400: eval CrossEntropyLossWithLogSoftmax | 1.98841286\n", - "Step 4400: eval WeightedCategoryAccuracy | 0.61240149\n", - "\n", - "Step 4500: Ran 100 train steps in 107.03 secs\n", - "Step 4500: train CrossEntropyLossWithLogSoftmax | 1.77967501\n", - "Step 4500: eval CrossEntropyLossWithLogSoftmax | 1.55927932\n", - "Step 4500: eval WeightedCategoryAccuracy | 0.69401485\n", - "\n", - "Step 4600: Ran 100 train steps in 104.64 secs\n", - "Step 4600: train CrossEntropyLossWithLogSoftmax | 1.77094245\n", - "Step 4600: eval CrossEntropyLossWithLogSoftmax | 1.78488588\n", - "Step 4600: eval WeightedCategoryAccuracy | 0.64639080\n", - "\n", - "Step 4700: Ran 100 train steps in 107.92 secs\n", - "Step 4700: train CrossEntropyLossWithLogSoftmax | 1.77308905\n", - "Step 4700: eval CrossEntropyLossWithLogSoftmax | 1.85960603\n", - "Step 4700: eval WeightedCategoryAccuracy | 0.63444734\n", - "\n", - "Step 4800: Ran 100 train steps in 103.30 secs\n", - "Step 4800: train CrossEntropyLossWithLogSoftmax | 1.76320994\n", - "Step 4800: eval CrossEntropyLossWithLogSoftmax | 1.62576365\n", - "Step 4800: eval WeightedCategoryAccuracy | 0.67749906\n", - "\n", - "Step 4900: Ran 100 train steps in 104.57 secs\n", - "Step 4900: train CrossEntropyLossWithLogSoftmax | 1.75034785\n", - "Step 4900: eval CrossEntropyLossWithLogSoftmax | 2.15475702\n", - "Step 4900: eval WeightedCategoryAccuracy | 0.59079874\n", - "\n", - "Step 5000: Ran 100 train steps in 132.27 secs\n", - "Step 5000: train CrossEntropyLossWithLogSoftmax | 1.74079084\n", - "Step 5000: eval CrossEntropyLossWithLogSoftmax | 1.76899207\n", - "Step 5000: eval WeightedCategoryAccuracy | 0.64984393\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SftZvkIl_4ko" - }, - "source": [ - "## More Steps (optional)\n", - "\n", - "As we have specified the `n_steps_per_permanent_checkpoint` in `training.TrainTask` it saves checkpoint in `output_dir` after the specified number of steps. 
So, if you face a runtime disconnection or want to train the model for more steps to improve the results, you can load the last saved checkpoint using `training_loop.load_checkpoint`. \n", - "\n", - "This is optional; you could instead use `model.init_from_file` as in the Part (2): Model cells, changing 'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz' to 'output_dir/ + last_checkpoint'." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "LBq6EZy6_4Lo" - }, - "source": [ - "output_dir = '/content/drive/MyDrive/Colab Notebooks/Transformer_FR_pretrained_336/'\n", - "\n", - "# This loads a checkpoint:\n", - "training_loop.load_checkpoint(directory=output_dir, filename=\"model.pkl.gz\")\n", - "# Continue training:\n", - "training_loop.run(5000)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8JHimNjsmHUR" - }, - "source": [ - "## TensorBoard (optional)\n", - "The Trax training loop optimizes training and creates TensorBoard logs and model checkpoints for you. You can simply visualize them using the following:\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "sfCy3oZAuron" - }, - "source": [ - "# Load the TensorBoard notebook extension\n", - "%load_ext tensorboard" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "kdcy6NP1uxOI" - }, - "source": [ - "%tensorboard --logdir output_dir" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bIy5wc90m0ZW" - }, - "source": [ - "If it is not loading properly and, for example, your `output_dir` is:\n", - "\n", - "```python\n", - "output_dir = '/content/drive/MyDrive/Colab Notebooks/Transformer_FR_pretrained_336'\n", - "```\n", - "add:\n", - "```\n", - "%cd '/content/drive/MyDrive/Colab Notebooks/'\n", - "```\n", - "before:\n", - "```\n", - "%tensorboard --logdir output_dir\n", - "```\n", - "and change it to:\n", - "```\n", - "%tensorboard --logdir Transformer_FR_pretrained_336\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0WXTnjBJ9dov" - }, - "source": [ - "# Part (4): Testing\n", - "\n", - "We will now use the model we just trained to translate English sentences to French. We will implement this with two functions: the first identifies the next symbol (i.e. output token); the second takes care of producing the entire translated string.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5g9O_h-R9do0" - }, - "source": [ - "## 4.1 Decoding" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "xH-imC6U-jBn" - }, - "source": [ - "# Setup helper functions for tokenizing and detokenizing sentences\n", - "def tokenize(input_str, vocab_file=None, vocab_dir=None):\n", - " \"\"\"Encodes a string to an array of integers\n", - " Args:\n", - " input_str (str): human-readable string to encode\n", - " vocab_file (str): filename of the vocabulary text file\n", - " vocab_dir (str): path to the vocabulary file\n", - " Returns:\n", - " numpy.ndarray: tokenized version of the input string\n", - " \"\"\"\n", - " # Set the encoding of the \"end of sentence\" as 1\n", - " EOS = 1\n", - " # Use the trax.data.tokenize method. 
It takes streams and returns streams,\n", - " # we get around it by making a 1-element stream with `iter`.\n", - " inputs = next(trax.data.tokenize(iter([input_str]),\n", - " vocab_file=vocab_file, vocab_dir=vocab_dir))\n", - " # Mark the end of the sentence with EOS\n", - " inputs = list(inputs) + [EOS]\n", - " # Adding the batch dimension to the front of the shape\n", - " batch_inputs = np.reshape(np.array(inputs), [1, -1])\n", - " return batch_inputs\n", - "\n", - "def detokenize(integers, vocab_file=None, vocab_dir=None):\n", - " \"\"\"Decodes an array of integers to a human readable string\n", - " Args:\n", - " integers (numpy.ndarray): array of integers to decode\n", - " vocab_file (str): filename of the vocabulary text file\n", - " vocab_dir (str): path to the vocabulary file \n", - " Returns:\n", - " str: the decoded sentence.\n", - " \"\"\"\n", - " # Remove the dimensions of size 1\n", - " integers = list(np.squeeze(integers))\n", - " # Set the encoding of the \"end of sentence\" as 1\n", - " EOS = 1\n", - " # Remove the EOS to decode only the original tokens\n", - " if EOS in integers:\n", - " integers = integers[:integers.index(EOS)] \n", - " return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R3ud8xnDGL-5" - }, - "source": [ - "There are several ways to get the next token when translating a sentence. For instance, we can just get the most probable token at each step (i.e. greedy decoding) or get a sample from a distribution. We can generalize the implementation of these two approaches by using the `tl.logsoftmax_sample()` method." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "cD8F14b49do1" - }, - "source": [ - "def next_symbol(model, input_tokens, cur_output_tokens, temperature):\n", - " \"\"\"Returns the index of the next token.\n", - " Args:\n", - " model: the NMT model.\n", - " input_tokens (np.ndarray 1 x n_tokens): tokenized representation of the input sentence\n", - " cur_output_tokens (list): tokenized representation of previously translated words\n", - " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", - " 0.0: same as argmax, always pick the most probable token\n", - " 1.0: sampling from the distribution (can sometimes say random things)\n", - " Returns:\n", - " int: index of the next token in the translated sentence\n", - " float: log probability of the next symbol\n", - " \"\"\"\n", - " # set the length of the current output tokens\n", - " token_length = len(cur_output_tokens)\n", - " # calculate the next power of 2 for the padding length (e.g. 5 tokens -> pad to 8)\n", - " padded_length = np.power(2, int(np.ceil(np.log2(token_length + 1))))\n", - " # pad cur_output_tokens up to the padded_length\n", - " padded = cur_output_tokens + [0] * (padded_length - token_length) \n", - " # the model expects the output to have a batch axis in front, so\n", - " # convert the `padded` list to a numpy array of shape (1, padded_length),\n", - " # where the leading 1 is the batch axis.\n", - " padded_with_batch = np.expand_dims(padded, axis=0)\n", - " # the model prediction.\n", - " output, _ = model((input_tokens, padded_with_batch)) \n", - " # get log probabilities from the last token output\n", - " log_probs = output[0, token_length, :]\n", - " # get the next symbol by getting a logsoftmax sample\n", - " symbol = int(tl.logsoftmax_sample(log_probs, temperature))\n", - " return symbol, float(log_probs[symbol])" - ], - "execution_count": null, - 
"outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R0KlObsa9dpE" - }, - "source": [ - "The `sampling_decode()` function will call the `next_symbol()` function above several times until the next output is the end-of-sentence token (i.e. `EOS`). It takes in an input string and returns the translated version of that string.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "bwIB-MQl9dpF" - }, - "source": [ - "def sampling_decode(input_sentence, model = None, temperature=0.0, vocab_file=None, vocab_dir=None):\n", - " \"\"\"Returns the translated sentence.\n", - " Args:\n", - " input_sentence (str): sentence to translate.\n", - " model: the NMT model.\n", - " temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", - " 0.0: same as argmax, always pick the most probable token\n", - " 1.0: sampling from the distribution (can sometimes say random things)\n", - " vocab_file (str): filename of the vocabulary\n", - " vocab_dir (str): path to the vocabulary file\n", - " Returns:\n", - " tuple: (list, str, float)\n", - " list of int: tokenized version of the translated sentence\n", - " float: log probability of the translated sentence\n", - " str: the translated sentence\n", - " \"\"\" \n", - " # encode the input sentence\n", - " input_tokens = tokenize(input_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", - " # initialize the list of output tokens\n", - " cur_output_tokens = []\n", - " # initialize an integer that represents the current output index\n", - " cur_output = 0 \n", - " # Set the encoding of the \"end of sentence\" as 1\n", - " EOS = 1\n", - " # check that the current output is not the end of sentence token\n", - " while cur_output != EOS: \n", - " # update the current output token by getting the index of the next word\n", - " cur_output, log_prob = next_symbol(model, input_tokens, cur_output_tokens, temperature)\n", - " # append the current output token to the list of output tokens\n", - " cur_output_tokens.append(cur_output) \n", - " # detokenize the output tokens\n", - " sentence = detokenize(cur_output_tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", - " return cur_output_tokens, log_prob, sentence" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "diQYEDgF9dpG", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "40d8f201-3aa6-42fc-ee96-fa09f2c6959f" - }, - "source": [ - "# Test the function above. Try varying the temperature setting with values from 0 to 1.\n", - "# Run it several times with each setting and see how often the output changes.\n", - "sampling_decode(\"Hello.\", model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "([9431, 489, 3, 1], 15.834756851196289, 'Bonjour.')" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 16 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uRbgTBWt9dpO" - }, - "source": [ - "We have set a default value of `0` to the temperature setting in our implementation of `sampling_decode()` above. As you may have noticed in the `logsoftmax_sample()` method, this setting will ultimately result in greedy decoding. This algorithm generates the translation by getting the most probable word at each step. It gets the argmax of the output array of your model and then returns that index. See the testing function and sample inputs below. 
You'll notice that the output will remain the same each time you run it." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "g1txjY-x9dpP" - }, - "source": [ - "def greedy_decode_test(sentence, model=None, vocab_file=None, vocab_dir=None):\n", - "    \"\"\"Prints the input and output of our NMT model using greedy decode\n", - "    Args:\n", - "        sentence (str): a custom string.\n", - "        model: the NMT model.\n", - "        vocab_file (str): filename of the vocabulary\n", - "        vocab_dir (str): path to the vocabulary file\n", - "    Returns:\n", - "        str: the translated sentence\n", - "    \"\"\" \n", - "    _, _, translated_sentence = sampling_decode(sentence, model, vocab_file=vocab_file, vocab_dir=vocab_dir) \n", - "    print(\"English: \", sentence)\n", - "    print(\"French: \", translated_sentence)\n", - "    return translated_sentence" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "i7XKz-9I9dpS", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "89ebe68b-1522-40e7-bc40-ee525c509235" - }, - "source": [ - "# put a custom string here\n", - "your_sentence = 'I love languages.'\n", - "greedy_decode_test(your_sentence, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "English:  I love languages.\n", - "French:  J'aime les langues.\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "M8UlR7LS9dpU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "40223d3f-77b5-43c5-cdef-42192673211f" - }, - "source": [ - "greedy_decode_test('You are almost done with the assignment!', model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "English:  You are almost done with the assignment!\n", - "French:  Vous êtes presque terminé avec le contrat !\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sf80_9T29dpX" - }, - "source": [ - "## 4.2 Minimum Bayes-Risk Decoding\n", - "\n", - "Getting the most probable token at each step may not necessarily produce the best results. Another approach is to do Minimum Bayes Risk Decoding or MBR. The general steps to implement this are:\n", - "\n", - "1. take several random samples\n", - "2. score each sample against all other samples\n", - "3. select the one with the highest score" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hp_qzJ8u9dpX" - }, - "source": [ - "\n", - "### 4.2.1 Generating samples\n", - "\n", - "First, let's build a function to generate several samples. You can use the `sampling_decode()` function you developed earlier to do this easily. We want to record the token list and log probability for each sample as these will be needed in the next step. Note that the samples must be generated with a non-zero temperature: at `temperature=0.0` (pure argmax), every sample would be the same greedy translation." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4iSRPOrI9dpX" - }, - "source": [ - "def generate_samples(sentence, n_samples, model=None, temperature=0.6, vocab_file=None, vocab_dir=None):\n", - "    \"\"\"Generates samples using sampling_decode()\n", - "    Args:\n", - "        sentence (str): sentence to translate.\n", - "        n_samples (int): number of samples to generate\n", - "        model: the NMT model.\n", - "        temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", - "            0.0: same as argmax, always pick the most probable token\n", - "            1.0: sampling from the distribution (can sometimes say random things)\n", - "        vocab_file (str): filename of the vocabulary\n", - "        vocab_dir (str): path to the vocabulary file \n", - "    Returns:\n", - "        tuple: (list, list)\n", - "            list of lists: token list per sample\n", - "            list of floats: log probability per sample\n", - "    \"\"\"\n", - "    # define lists to contain samples and probabilities\n", - "    samples, log_probs = [], []\n", - "    # run a for loop to generate n samples\n", - "    for _ in range(n_samples):\n", - "        # get a sample using the sampling_decode() function\n", - "        sample, logp, _ = sampling_decode(sentence, model, temperature, vocab_file=vocab_file, vocab_dir=vocab_dir)\n", - "        # append the token list to the samples list\n", - "        samples.append(sample)\n", - "        # append the log probability to the log_probs list\n", - "        log_probs.append(logp) \n", - "    return samples, log_probs" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "LlYC8y8H9dpZ", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "241edf8f-3921-46be-930b-00058cd6efb5" - }, - "source": [ - "# generate 4 samples with the default temperature (0.6)\n", - "generate_samples('I love languages.', 4, model, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "([[769, 7, 31720, 21, 15267, 3, 1],\n", - "  [769, 7, 31720, 13, 15267, 3, 1],\n", - "  [254, 31720, 21, 15267, 3, 1],\n", - "  [769, 7, 31720, 13, 15267, 3, 1]],\n", - " [18.705636978149414, 18.2911319732666, 19.461563110351562, 18.2911319732666])" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 36 - }, - "id": "VR6FLNdcILll", - "outputId": "b228a403-306f-4c74-b5ea-8c0f129906a4" - }, - "source": [ - "detokenize([769, 31, 31720, 21, 15267, 3, 1], VOCAB_FILE, VOCAB_DIR)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'J’aime les langues.'" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 22 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HonzLcOP9dpb" - }, - "source": [ - "### 4.2.2 Comparing overlaps\n", - "\n", - "Let us now build the functions to compare one sample against another. There are several metrics available and you can try experimenting with any one of these. We will be calculating scores for unigram overlaps. One of the simpler metrics is the [Jaccard similarity](https://en.wikipedia.org/wiki/Jaccard_index), which is the intersection over the union of two sets.",
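- "\n", - "$$J(A, B) = \\frac{|A \\cap B|}{|A \\cup B|}$$\n", - "\n", - "For example, the unigram sets $\\{1, 2, 3\\}$ and $\\{2, 3, 4\\}$ share 2 of the 4 distinct tokens between them, so their Jaccard similarity is $2/4 = 0.5$."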
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IB7ipzoZ9dpc" - }, - "source": [ - "def jaccard_similarity(candidate, reference):\n", - " \"\"\"Returns the Jaccard similarity between two token lists\n", - " Args:\n", - " candidate (list of int): tokenized version of the candidate translation\n", - " reference (list of int): tokenized version of the reference translation\n", - " Returns:\n", - " float: overlap between the two token lists\n", - " \"\"\" \n", - " # convert the lists to a set to get the unique tokens\n", - " can_unigram_set, ref_unigram_set = set(candidate), set(reference) \n", - " # get the set of tokens common to both candidate and reference\n", - " joint_elems = can_unigram_set.intersection(ref_unigram_set)\n", - " # get the set of all tokens found in either candidate or reference\n", - " all_elems = can_unigram_set.union(ref_unigram_set)\n", - " # divide the number of joint elements by the number of all elements\n", - " overlap = len(joint_elems) / len(all_elems)\n", - " return overlap" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CZRis5hp9dph" - }, - "source": [ - "One of the more commonly used metrics in machine translation is the ROUGE score. For unigrams, this is called ROUGE-1 and you can output the scores for both precision and recall when comparing two samples. To get the final score, you will want to compute the F1-score as given by:\n", - "\n", - "$$score = 2* \\frac{(precision * recall)}{(precision + recall)}$$\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WRhPTgv09dpi" - }, - "source": [ - "# for making a frequency table easily\n", - "from collections import Counter\n", - "\n", - "def rouge1_similarity(system, reference):\n", - " \"\"\"Returns the ROUGE-1 score between two token lists\n", - " Args:\n", - " system (list of int): tokenized version of the system translation\n", - " reference (list of int): tokenized version of the reference translation\n", - " Returns:\n", - " float: overlap between the two token lists\n", - " \"\"\" \n", - " # make a frequency table of the system tokens\n", - " sys_counter = Counter(system) \n", - " # make a frequency table of the reference tokens\n", - " ref_counter = Counter(reference)\n", - " # initialize overlap to 0\n", - " overlap = 0\n", - " # run a for loop over the sys_counter object\n", - " for token in sys_counter: \n", - " # lookup the value of the token in the sys_counter dictionary \n", - " token_count_sys = sys_counter.get(token,0)\n", - " # lookup the value of the token in the ref_counter dictionary \n", - " token_count_ref = ref_counter.get(token,0)\n", - " # update the overlap by getting the smaller number between the two token counts above\n", - " overlap += min(token_count_sys, token_count_ref) \n", - " # get the precision (i.e. number of overlapping tokens / number of system tokens)\n", - " precision = overlap / sum(sys_counter.values()) \n", - " # get the recall (i.e. 
number of overlapping tokens / number of reference tokens)\n", - "    recall = overlap / sum(ref_counter.values()) \n", - "    if precision + recall != 0:\n", - "        # compute the f1-score\n", - "        rouge1_score = 2 * ((precision * recall)/(precision + recall))\n", - "    else:\n", - "        rouge1_score = 0 \n", - "    return rouge1_score" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qn3wLqSb9dpp" - }, - "source": [ - "### 4.2.3 Overall score\n", - "\n", - "We will now build a function to generate the overall score for a particular sample. As mentioned earlier, we need to compare each sample with all other samples. For instance, if we generated 30 sentences, we will need to compare sentence 1 to sentences 2 to 30. Then, we compare sentence 2 to sentences 1 and 3 to 30, and so forth. At each step, we get the average score of all comparisons to get the overall score for a particular sample. To illustrate, these will be the steps to generate the scores of a 4-sample list.\n", - "\n", - "1. Get similarity score between sample 1 and sample 2\n", - "2. Get similarity score between sample 1 and sample 3\n", - "3. Get similarity score between sample 1 and sample 4\n", - "4. Get average score of the first 3 steps. This will be the overall score of sample 1.\n", - "5. Iterate and repeat until samples 1 to 4 have overall scores.\n", - "\n", - "We will be storing the results in a dictionary for easy lookups." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Umtj0NLX9dpp" - }, - "source": [ - "def average_overlap(similarity_fn, samples, *ignore_params):\n", - "    \"\"\"Returns the mean overlap score of each candidate sentence in the samples\n", - "    Args:\n", - "        similarity_fn (function): similarity function used to compute the overlap\n", - "        samples (list of lists): tokenized version of the translated sentences\n", - "        *ignore_params: additional parameters will be ignored\n", - "    Returns:\n", - "        dict: scores of each sample\n", - "            key: index of the sample\n", - "            value: score of the sample\n", - "    \"\"\" \n", - "    # initialize dictionary\n", - "    scores = {}\n", - "    # run a for loop for each sample\n", - "    for index_candidate, candidate in enumerate(samples): \n", - "        # initialize overlap to 0.0\n", - "        overlap = 0.0\n", - "        # run a for loop for each sample\n", - "        for index_sample, sample in enumerate(samples): \n", - "            # skip if the candidate index is the same as the sample index\n", - "            if index_candidate == index_sample:\n", - "                continue \n", - "            # get the overlap between candidate and sample using the similarity function\n", - "            sample_overlap = similarity_fn(candidate, sample) \n", - "            # add the sample overlap to the total overlap\n", - "            overlap += sample_overlap \n", - "        # get the score for the candidate by averaging over the other samples\n", - "        score = overlap / (len(samples) - 1)\n", - "        # save the score in the dictionary. use index as the key.\n", - "        scores[index_candidate] = score \n", - "    return scores" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-w7LL7lm9dpx" - }, - "source": [ - "It is also common to see the weighted mean being used to calculate the overall score instead of just the arithmetic mean.",
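- "\n", - "With per-sample weights $w_j = e^{\\log p_j}$ (the linear-scale probability of sample $j$), the overall score of candidate $i$ computed by the function below is\n", - "\n", - "$$score_i = \\frac{\\sum_{j \\neq i} w_j \\cdot overlap(s_i, s_j)}{\\sum_{j \\neq i} w_j}$$\n", - "\n", - "where $s_i$ and $s_j$ are the token lists of samples $i$ and $j$."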
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "o70TS8PG9dpy" - }, - "source": [ - "def weighted_avg_overlap(similarity_fn, samples, log_probs):\n", - "    \"\"\"Returns the weighted mean overlap score of each candidate sentence in the samples\n", - "    Args:\n", - "        similarity_fn (function): similarity function used to compute the overlap\n", - "        samples (list of lists): tokenized version of the translated sentences\n", - "        log_probs (list of float): log probability of the translated sentences\n", - "    Returns:\n", - "        dict: scores of each sample\n", - "            key: index of the sample\n", - "            value: score of the sample\n", - "    \"\"\"\n", - "    # initialize dictionary\n", - "    scores = {} \n", - "    # run a for loop for each sample\n", - "    for index_candidate, candidate in enumerate(samples): \n", - "        # initialize overlap and weighted sum\n", - "        overlap, weight_sum = 0.0, 0.0 \n", - "        # run a for loop for each sample\n", - "        for index_sample, (sample, logp) in enumerate(zip(samples, log_probs)):\n", - "            # skip if the candidate index is the same as the sample index \n", - "            if index_candidate == index_sample:\n", - "                continue \n", - "            # convert log probability to linear scale\n", - "            sample_p = float(np.exp(logp))\n", - "            # update the weighted sum\n", - "            weight_sum += sample_p\n", - "            # get the unigram overlap between candidate and sample\n", - "            sample_overlap = similarity_fn(candidate, sample) \n", - "            # update the overlap\n", - "            overlap += sample_p * sample_overlap \n", - "        # get the score for the candidate\n", - "        score = overlap / weight_sum\n", - "        # save the score in the dictionary. use index as the key.\n", - "        scores[index_candidate] = score\n", - "    return scores" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l5jgBrPu9dp4" - }, - "source": [ - "### 4.2.4 Putting it all together\n", - "\n", - "We will now put everything together and develop the `mbr_decode()` function.",
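- "\n", - "Schematically, it chains the pieces we have built (a sketch of the data flow only):\n", - "\n", - "```python\n", - "samples, log_probs = generate_samples(sentence, n_samples, model, temperature, vocab_file, vocab_dir)\n", - "scores = score_fn(similarity_fn, samples, log_probs)  # e.g. weighted_avg_overlap + rouge1_similarity\n", - "best = max(scores, key=scores.get)                    # index of the highest-scoring sample\n", - "```"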
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "S58nPXgY9dp5" - }, - "source": [ - "def mbr_decode(sentence, n_samples=4, score_fn=weighted_avg_overlap, similarity_fn=rouge1_similarity, model=model,\n", - "               temperature=0.6, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR):\n", - "    \"\"\"Returns the translated sentence using Minimum Bayes Risk decoding\n", - "    Args:\n", - "        sentence (str): sentence to translate.\n", - "        n_samples (int): number of samples to generate\n", - "        score_fn (function): function that generates the score for each sample\n", - "        similarity_fn (function): function used to compute the overlap between a\n", - "            pair of samples\n", - "        model: the NMT model.\n", - "        temperature (float): parameter for sampling ranging from 0.0 to 1.0.\n", - "            0.0: same as argmax, always pick the most probable token\n", - "            1.0: sampling from the distribution (can sometimes say random things)\n", - "        vocab_file (str): filename of the vocabulary\n", - "        vocab_dir (str): path to the vocabulary file\n", - "    Returns:\n", - "        tuple: (str, int, dict)\n", - "            str: the translated sentence\n", - "            int: index of the highest-scoring sample\n", - "            dict: score of each sample\n", - "    \"\"\"\n", - "    # generate samples\n", - "    samples, log_probs = generate_samples(sentence, n_samples,\n", - "                                          model, temperature,\n", - "                                          vocab_file, vocab_dir) \n", - "    # use the scoring function to get a dictionary of scores\n", - "    scores = score_fn(similarity_fn, samples, log_probs)\n", - "    # find the key with the highest score\n", - "    max_index = max(scores, key=scores.get) \n", - "    # detokenize the token list associated with the max_index\n", - "    translated_sentence = detokenize(samples[max_index], vocab_file, vocab_dir)\n", - "    return (translated_sentence, max_index, scores)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ab1LHo-59dp8" - }, - "source": [ - "# put a custom string here\n", - "your_sentence = 'She speaks English, French and German.'" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "BhgGWv7c9dp_", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ae5e00cb-0935-45f7-9fbe-c96f0e12dfc1" - }, - "source": [ - "mbr_decode(your_sentence)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "('Elle parle anglais, français et Allemand.',\n", - " 1,\n", - " {0: 0.909090909090909,\n", - "  1: 0.9730044973480663,\n", - "  2: 0.9730044973480663,\n", - "  3: 0.9730044973480663})" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 29 - } - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 36 - }, - "id": "QqyR1Ym6A_Ah", - "outputId": "eb7db397-28a6-41c7-c44e-0c314574d147" - }, - "source": [ - "mbr_decode('You have completed the tutorial.')[0]" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'Vous avez terminé le tutorial.'" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 30 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RPbqDGUY8Vp_" - }, - "source": [ - "# **Resources**\n", - "\n", - "- [Natural Language Processing Specialization](https://www.coursera.org/specializations/natural-language-processing)\n", - "\n", - "- [Trax documentation](https://trax-ml.readthedocs.io/en/latest/index.html)\n", - "\n", - "- [Trax community](https://gitter.im/trax-ml/community)" - ] - 
} - ] -} diff --git a/trax/examples/Terraformer_from_scratch.ipynb b/trax/examples/Terraformer_from_scratch.ipynb deleted file mode 100644 index 9e3eaea6c..000000000 --- a/trax/examples/Terraformer_from_scratch.ipynb +++ /dev/null @@ -1,2587 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "Vzsxj2EV3lfL" - }, - "source": [ - "# Scaling Transformers - Sparse Is Enough\n", - "\n", - "Licensed under the Apache License, Version 2.0", - "\n", - "This colab contains all relevant code for the paper \"Sparse is Enough in Scaling Transformers\". It depends on the Trax library; the experiments in the paper were not run in this colab but in a distributed setup with the attached config files, using the code below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "SMmztiOqenFD" - }, - "outputs": [], - "source": [ - "# Imports.\n", - "!pip install --upgrade -q trax==1.3.9\n", - "\n", - "import functools\n", - "import os\n", - "import random\n", - "import time\n", - "import numpy as np\n", - "\n", - "import jax\n", - "import trax\n", - "from trax import layers as tl\n", - "from trax import fastmath\n", - "from trax.fastmath import numpy as jnp\n", - "from trax.supervised import training" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fi6zzlt15l-d" - }, - "source": [ - "## Main sparse layers\n", - "\n", - "This cell contains the implementation of our main sparse layers:\n", - "* sparse QKV layers\n", - "* sparse feed-forward blocks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kbTJBQ_fBz8d" - }, - "outputs": [], - "source": [ - "def SplitLastAxis(num_splits):\n", - "  return tl.Fn(f'SplitLastAxis_{num_splits}',\n", - "               lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1)))\n", - "\n", - "\n", - "def MergeLastTwoAxes():\n", - "  return tl.Fn('MergeLastTwoAxes',\n", - "               lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,)))\n", - "\n", - "\n", - "def LocallyConnectedDense(n_modules, n_units, kernel_size=1,\n", - "                          kernel_initializer=tl.GlorotUniformInitializer(),\n", - "                          bias_initializer=tl.RandomNormalInitializer(1e-6),\n", - "                          use_bias=True):\n", - "  \"\"\"Layer using LocallyConnected1d for approximation of Dense layer.\n", - "\n", - "  The layer splits the last axis of a tensor into `n_modules`, then runs\n", - "  LocallyConnected1d (grouped convolution) on all those modules, and\n", - "  concatenates their results. 
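\n", - "\n", - "  (Illustrative note: with `n_modules=16` and `kernel_size=1`, that is roughly\n", - "  16x fewer weights than a `tl.Dense` with the same output size.)\n", - "\n", - "  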
It is essentially a locally-sensitive\n", - "  approximation of a Dense layer, with the number of parameters smaller by a\n", - "  factor of `n_modules / kernel_size`.\n", - "\n", - "  Args:\n", - "    n_modules: Indicates how many modules (pixels) the input and output\n", - "      should be split into for processing.\n", - "    n_units: how many outputs (filters) each module should generate.\n", - "    kernel_size: The size of the kernel to be used.\n", - "    kernel_initializer: Function that creates a matrix of (random) initial\n", - "        connection weights `W` for the layer.\n", - "    bias_initializer: Function that creates a vector of (random) initial\n", - "        bias weights `b` for the layer.\n", - "    use_bias: If `True`, compute an affine map `y = Wx + b`; else compute\n", - "        a linear map `y = Wx`.\n", - "\n", - "  Returns:\n", - "    LocallyConnectedDense tl.Layer.\n", - "  \"\"\"\n", - "  if n_modules == 1:\n", - "    return tl.Dense(n_units, kernel_initializer=kernel_initializer,\n", - "                    bias_initializer=bias_initializer, use_bias=use_bias)\n", - "  return tl.Serial(\n", - "      SplitLastAxis(n_modules),\n", - "      tl.LocallyConnected1d(\n", - "          n_units, kernel_size, kernel_initializer=kernel_initializer,\n", - "          bias_initializer=bias_initializer, use_bias=use_bias, padding='WRAP'),\n", - "      MergeLastTwoAxes())\n", - "\n", - "\n", - "class _RememberPad(tl.Layer):\n", - "  \"\"\"Layer which remembers last N elements in predict mode.\"\"\"\n", - "\n", - "  def __init__(self, n_items_to_remember, mode):\n", - "    \"\"\"Returns a layer which remembers last N elements in predict mode.\n", - "\n", - "    For predict mode, the layer remembers last N elements and pads with them.\n", - "    For other modes, it pads with zeros. The layer pads/remembers elements from\n", - "    the second axis.\n", - "\n", - "    Args:\n", - "      n_items_to_remember: Number of items to remember/pad with.\n", - "      mode: One of `'train'`, `'eval'`, or `'predict'`.\n", - "    \"\"\"\n", - "    super().__init__(name='_RememberPad')\n", - "    self._n_items_to_remember = n_items_to_remember\n", - "    self._mode = mode\n", - "    self._portal_mask = self.monkey_patched_mask()  # pylint: disable=assignment-from-none\n", - "\n", - "  def monkey_patched_mask(self):\n", - "    # This is necessary for Terraformer model. 
See comments there.\n", - "    # The mask will only be used in Terraformer in predict mode.\n", - "    return None\n", - "\n", - "  def forward(self, x):\n", - "    if self._n_items_to_remember == 0:\n", - "      return x\n", - "    if self._mode == 'predict':\n", - "      x = jnp.concatenate([self.state[0], x], axis=1)\n", - "      if self._portal_mask is not None and 'init' in self.state[1]:\n", - "        assert x.shape[0] == 1\n", - "        mask = self._portal_mask.get_value()\n", - "        count_padding = jnp.sum(mask == 0, dtype=jnp.int32)\n", - "        self.state = (fastmath.dynamic_slice_in_dim(\n", - "            x, x.shape[1] - (self._n_items_to_remember + count_padding),\n", - "            self._n_items_to_remember, axis=1), {'forward': ()})\n", - "      else:\n", - "        self.state = (x[:, -self._n_items_to_remember:, ...], {'forward': ()})\n", - "    else:\n", - "      pad_widths = [[0, 0] for _ in range(len(x.shape))]\n", - "      pad_widths[1][0] = self._n_items_to_remember\n", - "      x = jnp.pad(x, pad_width=pad_widths, mode='constant')\n", - "    return x\n", - "\n", - "  def init_weights_and_state(self, input_signature):\n", - "    \"\"\"Initializes this layer's weights.\"\"\"\n", - "    if isinstance(input_signature, (list, tuple)):\n", - "      input_signature = input_signature[0]\n", - "    self.weights = ()\n", - "    if self._mode == 'predict':\n", - "      shape = list(input_signature.shape)\n", - "      shape[1] = self._n_items_to_remember\n", - "      self.state = (jnp.zeros(shape, dtype=jnp.float32), {'init': ()})\n", - "    else:\n", - "      self.state = ()\n", - "\n", - "\n", - "def LocallyConvDense(n_modules, n_units, mode, kernel_size=1,\n", - "                     length_kernel_size=1):\n", - "  \"\"\"Layer using local convolutions for approximation of Dense layer.\n", - "\n", - "  The layer splits the last axis of a tensor into `n_modules`, then runs\n", - "  a convolution on all those modules, and concatenates their results.\n", - "  It is similar to LocallyConnectedDense above, but shares weights.\n", - "\n", - "  Args:\n", - "    n_modules: Indicates how many modules (pixels) the input and output\n", - "      should be split into for processing.\n", - "    n_units: how many outputs (filters) each module should generate.\n", - "    mode: One of `'train'`, `'eval'`, or `'predict'`.\n", - "    kernel_size: The size of the kernel to be used.\n", - "    length_kernel_size: If \u003e 1, also do causal convolution on the previous axis,\n", - "      which is often the sentence length in sequence models.\n", - "\n", - "  Returns:\n", - "    LocallyConvDense tl.Layer.\n", - "  \"\"\"\n", - "  if n_modules == 1:\n", - "    return tl.Dense(n_units)\n", - "  if kernel_size % 2 != 1:\n", - "    raise ValueError('Currently we only handle odd kernel sizes.')\n", - "  half = (kernel_size - 1) // 2\n", - "  pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]]\n", - "  return tl.Serial(\n", - "      SplitLastAxis(n_modules),\n", - "      tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths, mode='constant')),\n", - "      _RememberPad(length_kernel_size-1, mode=mode),\n", - "      tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),\n", - "      MergeLastTwoAxes()\n", - "  )\n", - "\n", - "\n", - "def RandomLayer(layer_a, layer_b, prob_a):\n", - "  \"\"\"Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`.\"\"\"\n", - "  condition = tl.Serial(\n", - "      tl.RandomUniform(),\n", - "      tl.Fn('SmallerThan', lambda x: x \u003c prob_a)\n", - "  )\n", - "  return tl.Cond(condition, layer_a, layer_b)\n", - "\n", - "\n", - "def SparseDenseWithOptions(n_units, d_input=None, sparsity_type=None,\n", - "                           sparsity=0, d_lowrank=None, prob_sparse=None,\n", - "                           mode=None, use_bias=True, 
use_bfloat16=False):\n", - "  \"\"\"Configurable sparse version of Dense layer.\"\"\"\n", - "  if prob_sparse is not None:\n", - "    if mode is not None and mode != 'train':\n", - "      # For non-training modes, we want to use a sparse variant.\n", - "      # This is different from simply prob_sparse being None, as the weights of\n", - "      # the model are different.\n", - "      prob_sparse = 1.0\n", - "    return RandomLayer(\n", - "        SparseDenseWithOptions(n_units, d_input, sparsity_type, sparsity,\n", - "                               d_lowrank, use_bias=use_bias,\n", - "                               use_bfloat16=use_bfloat16),\n", - "        tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16),\n", - "        prob_sparse)\n", - "\n", - "  if sparsity_type is None or sparsity_type == 'None' or sparsity == 0:\n", - "    return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16)\n", - "  if sparsity_type == 'mult':\n", - "    return FactoredDense(sparsity, d_input, n_units, use_bias=use_bias,\n", - "                         use_bfloat16=use_bfloat16)\n", - "\n", - "  assert not use_bfloat16  # use_bfloat16 is unsupported for other variants\n", - "  if sparsity_type == 'local':\n", - "    assert use_bias  # use_bias = False is unsupported\n", - "    assert n_units % sparsity == 0\n", - "    return LocallyConnectedDense(sparsity, n_units // sparsity)\n", - "  if sparsity_type == 'local3':\n", - "    assert use_bias  # use_bias = False is unsupported\n", - "    assert n_units % sparsity == 0\n", - "    return LocallyConnectedDense(sparsity, n_units // sparsity, kernel_size=3)\n", - "\n", - "  raise ValueError('Unknown sparsity type: {}'.format(sparsity_type))\n", - "\n", - "\n", - "def FactoredDense(n_modules, d_in, d_out, use_bias=True, use_bfloat16=False):\n", - "  r\"\"\"Returns a Dense-like layer, internally factored to use fewer parameters.\n", - "\n", - "  This layer treats an activation vector as if divided into :math:`M`\n", - "  subvectors (``n_modules`` 'modules'). It uses this factored view to compute\n", - "  a :py:class:`Dense`-like mapping with high mixing/connectivity, but using\n", - "  approximately :math:`1/M` the number of weights of a similarly dimensioned\n", - "  :py:class:`Dense` layer.\n", - "\n", - "  More specifically, each activation vector of dimensionality ``n_in`` is\n", - "  multiplied element-wise (a generalized form of gating) with ``n_modules``\n", - "  vectors also of dimensionality ``n_in``. The resulting vectors are projected\n", - "  to the subvector/module dimensionality ``d_out / n_modules`` via a matrix\n", - "  multiply, and finally reshaped back to a single vector of dimensionality\n", - "  ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at\n", - "  the end. 
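\n", - "\n", - "  (Illustrative note: with ``n_modules=4`` and ``d_in = d_out = 512``, this\n", - "  layer uses roughly ``4*512 + 512*128 + 512`` weights, versus\n", - "  ``512*512 + 512`` for a similarly dimensioned ``Dense`` layer.)\n", - "\n", - "  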
All the above-mentioned non-input objects -- gating vectors,\n", - "  projection matrix, and optional bias -- are trainable weights.\n", - "\n", - "  Args:\n", - "    n_modules: Number by which an activation vector is divided into subvectors\n", - "        (modules) for the factored computation.\n", - "    d_in: Last/innermost dimension of input array.\n", - "    d_out: Last/innermost dimension of output array.\n", - "    use_bias: If True, add bias vectors at the end of the layer; else end the\n", - "        layer with the matrix multiply.\n", - "    use_bfloat16: If True, use bfloat16 weights; else use float32 weights.\n", - "  \"\"\"\n", - "  if d_out % n_modules != 0:\n", - "    raise ValueError(f'Value d_out ({d_out}) must be a multiple of arg '\n", - "                     f'n_modules ({n_modules}).')\n", - "  d_module = d_out // n_modules\n", - "\n", - "  def GatingVectors():\n", - "    return tl.Weights(tl.RandomNormalInitializer(stddev=0.5),\n", - "                      shape=[n_modules, d_in],\n", - "                      use_bfloat16=use_bfloat16)\n", - "\n", - "  def ProjectionMatrix():\n", - "    return tl.Weights(tl.GlorotUniformInitializer(),\n", - "                      shape=[d_in, d_module],\n", - "                      use_bfloat16=use_bfloat16)\n", - "\n", - "  def Bias():\n", - "    return tl.Weights(tl.RandomNormalInitializer(1e-6),\n", - "                      shape=[d_out],\n", - "                      use_bfloat16=use_bfloat16)\n", - "\n", - "  layers = [\n", - "      GatingVectors(),\n", - "      ProjectionMatrix(),\n", - "      _GateAndProject(),\n", - "      MergeLastTwoAxes(),\n", - "  ]\n", - "  if use_bias:\n", - "    layers += [Bias(), tl.Add()]\n", - "\n", - "  return tl.Serial(layers)\n", - "\n", - "\n", - "def _GateAndProject():\n", - "  \"\"\"Returns a combined gating+projection layer that saves on memory.\"\"\"\n", - "\n", - "  def f(projection, gating, x):\n", - "    # Args arrive in reverse order because of how they were put on the stack.\n", - "    # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules)\n", - "    return jnp.einsum('...d,nd,dm-\u003e...nm', x, gating, projection)\n", - "\n", - "  return tl.Fn('_GateAndProject', f)\n", - "\n", - "\n", - "def MultiplicativeConvCausalAttention(\n", - "    d_feature, n_heads=1, sparsity=None, length_kernel_size=3, dropout=0.0,\n", - "    force_no_dropout=False, max_inference_length=2048, share_qk=False,\n", - "    output_layer_type='none', v_concat_type='none', mode='train'):\n", - "  \"\"\"Returns a layer that maps activations to activations, with causal masking.\n", - "\n", - "  Like `CausalAttention`, this layer type represents one pass of multi-head\n", - "  self-attention with causal masking rather than padding-based masking. However,\n", - "  for computing Q/K/V, instead of a Dense layer it combines a\n", - "  FactoredDense layer with a LocallyConvDense layer.\n", - "\n", - "  Args:\n", - "    d_feature: Depth/dimensionality of feature embedding.\n", - "    n_heads: Number of attention heads.\n", - "    sparsity: The sparsity of the layer; usually it should be equal to n_heads.\n", - "    length_kernel_size: Size of convolution kernel on the length dimension.\n", - "    dropout: Probabilistic rate for internal dropout applied to attention\n", - "        activations (based on query-key pairs) before dotting them with values.\n", - "    force_no_dropout: If True, force dropout to be 0.0 independent of the above\n", - "        value; used to override some configurations.\n", - "    max_inference_length: maximum length for inference.\n", - "    share_qk: if True, average Q and K embeddings and share for both Q and K.\n", - "    output_layer_type: Which sparse layers to use for processing output from the\n", - "        attention mechanism. 
One of `'none'`, `'mult'`, `'conv'`,\n", - "        or `'multconv'`.\n", - "    v_concat_type: What kind of concatenation to use when computing V tensor.\n", - "        One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just\n", - "        output from multiplicative layer shared by Q, K, V. `'fixed'` means\n", - "        using output from multiplicative layer concatenated, for each module,\n", - "        with the layer input. `'original'` means using concatenation without\n", - "        properly taking modules into account; this method was used in\n", - "        experiments previously, so it is included for backwards-compatibility.\n", - "    mode: One of `'train'`, `'eval'`, or `'predict'`.\n", - "  \"\"\"\n", - "  assert output_layer_type in ['none', 'mult', 'conv', 'multconv']\n", - "  assert v_concat_type in ['original', 'fixed', 'none']\n", - "\n", - "  dropout = 0.0 if force_no_dropout else dropout\n", - "  sparsity = n_heads if sparsity is None else sparsity\n", - "  d_module = d_feature // sparsity\n", - "\n", - "  output_layers = []\n", - "  if 'mult' in output_layer_type:\n", - "    output_layers.append(FactoredDense(\n", - "        sparsity, d_feature, d_feature))\n", - "  if 'conv' in output_layer_type:\n", - "    output_layers.append(LocallyConvDense(\n", - "        sparsity, d_module, mode=mode, kernel_size=3,\n", - "        length_kernel_size=length_kernel_size))\n", - "\n", - "  if v_concat_type == 'original':\n", - "    # `'original'` uses concatenation without properly taking modules into\n", - "    # account; this method was used in experiments previously, so it is included\n", - "    # for backwards-compatibility.\n", - "    concat_layers = [tl.Concatenate()]  # use permuted and original for v\n", - "  elif v_concat_type == 'fixed':\n", - "    # `'fixed'` uses the output from multiplicative layer concatenated, for each\n", - "    # module, with the layer input. 
This means that every module in Conv layer\n", - "    # has access both to parts of embeddings which were used to compute Q/K of\n", - "    # this particular module, and it has access to parts of the embedding which\n", - "    # will be modified by this module.\n", - "    concat_layers = [\n", - "        tl.Parallel(\n", - "            tl.Fn('Reshape1', lambda x: jnp.reshape(  # pylint: disable=g-long-lambda\n", - "                x, (x.shape[0], x.shape[1], sparsity, d_module))),\n", - "            tl.Fn('Reshape2', lambda x: jnp.reshape(  # pylint: disable=g-long-lambda\n", - "                x, (x.shape[0], x.shape[1], sparsity, d_module)))),\n", - "        tl.Concatenate(),\n", - "        tl.Fn('Reshape3',\n", - "              lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2*d_feature))),\n", - "    ]\n", - "  elif v_concat_type == 'none':\n", - "    # `'none'` doesn't use concatenation: we throw away the original layer\n", - "    # input and pass to Conv only output of shared Multiplicative layer.\n", - "    concat_layers = [tl.Select([0], n_in=2)]\n", - "\n", - "  if share_qk:\n", - "    return tl.Serial(\n", - "        tl.Select([0, 0]),  # pre-qkv, pre-v-for-concat\n", - "        FactoredDense(sparsity, d_feature, d_feature),  # shared q k\n", - "        tl.Select([0, 0]),  # pre-qk, pre-v, pre-v-for-concat\n", - "        LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", - "                         length_kernel_size=length_kernel_size),\n", - "        tl.SplitIntoHeads(n_heads),\n", - "        tl.Select([0, 0]),  # use for q and k\n", - "        tl.Parallel(\n", - "            [],\n", - "            [],\n", - "            [concat_layers,\n", - "             LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,\n", - "                              length_kernel_size=length_kernel_size),\n", - "             tl.SplitIntoHeads(n_heads)],\n", - "        ),\n", - "        tl.DotProductCausalAttention(\n", - "            dropout=dropout, max_inference_length=max_inference_length,\n", - "            mode=mode),\n", - "        tl.MergeHeads(n_heads),\n", - "        output_layers,\n", - "    )\n", - "  return tl.Serial(\n", - "      tl.Select([0, 0]),  # duplicate activations\n", - "      FactoredDense(sparsity, d_feature, d_feature),  # shared q, k\n", - "      tl.Select([0, 0, 0]),  # use for q, k, v\n", - "      tl.Parallel(\n", - "          [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", - "                            length_kernel_size=length_kernel_size),\n", - "           tl.SplitIntoHeads(n_heads)],\n", - "          [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,\n", - "                            length_kernel_size=length_kernel_size),\n", - "           tl.SplitIntoHeads(n_heads)],\n", - "          [concat_layers,\n", - "           LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,\n", - "                            length_kernel_size=length_kernel_size),\n", - "           tl.SplitIntoHeads(n_heads)],\n", - "      ),\n", - "      tl.DotProductCausalAttention(\n", - "          dropout=dropout, max_inference_length=max_inference_length,\n", - "          mode=mode),\n", - "      tl.MergeHeads(n_heads),\n", - "      output_layers,\n", - "  )\n", - "\n", - "\n", - "class DotProductCausalAttention(tl.Layer):\n", - "  \"\"\"Layer that computes attention strengths by masking out the \"future\".\n", - "\n", - "  Causal attention uses masking to prevent a given sequence position from\n", - "  attending to positions greater than / following it. This is used, for\n", - "  example, when training autoregressive sequence models, or when decoding a\n", - "  sequence symbol by symbol.\n", - "\n", - "  This layer performs the core per-head attention calculation. 
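\n", - "\n", - "  (Illustrative note: in training mode, the causal mask computed below for a\n", - "  length-L input is a lower-triangular boolean array of shape [1, L, L], so\n", - "  each position attends only to itself and earlier positions.)\n", - "\n", - "  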
The layer\n", - "  assumes that any splitting into attention heads precedes it, and that any\n", - "  merging of attention heads will follow it.\n", - "  \"\"\"\n", - "\n", - "  def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'):\n", - "    \"\"\"Creates a :py:class:`DotProductCausalAttention` instance.\n", - "\n", - "    Args:\n", - "      dropout: Probabilistic rate for attention dropout, which overrides\n", - "          (sets to zero) some attention strengths derived from query-key\n", - "          matching. As a result, on a given forward pass, some value vectors\n", - "          don't contribute to the output, analogous to how regular dropout can\n", - "          cause some node activations to be ignored. Applies only if layer is\n", - "          created in ``'train'`` mode.\n", - "      max_inference_length: Maximum sequence length allowed in non-training\n", - "          modes.\n", - "      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.\n", - "    \"\"\"\n", - "    super().__init__(n_in=3, n_out=1)\n", - "    self._dropout = dropout\n", - "    self._mode = mode\n", - "    self._max_len = max_inference_length\n", - "    self._portal_mask = self.monkey_patched_mask()  # pylint: disable=assignment-from-none\n", - "\n", - "  def monkey_patched_mask(self):\n", - "    # This is necessary for Terraformer model. See comments there.\n", - "    # The mask will only be used in Terraformer in predict mode.\n", - "    return None\n", - "\n", - "  def forward(self, inputs):\n", - "    \"\"\"Returns attention-computed activations.\n", - "\n", - "    Args:\n", - "      inputs: A (queries, keys, values) tuple.\n", - "    \"\"\"\n", - "    q, k, v = inputs\n", - "\n", - "    if self._portal_mask is not None:\n", - "      mask_for_predict = self._portal_mask.get_value()\n", - "    else:\n", - "      mask_for_predict = None\n", - "\n", - "    if self._mode == 'predict':\n", - "      self.state, mask = _fast_inference_update_state(\n", - "          inputs, self.state,\n", - "          mask_for_predict=mask_for_predict)\n", - "      if self._portal_mask is not None:\n", - "        (_, k, v, _) = self.state\n", - "      else:\n", - "        (k, v, _) = self.state\n", - "    else:\n", - "      sequence_length = q.shape[-2]\n", - "      mask = _causal_mask(sequence_length)\n", - "\n", - "    activations, attn_strengths = _per_head_attention(\n", - "        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)\n", - "    if self._mode == 'viz':\n", - "      self.state = attn_strengths\n", - "    return activations\n", - "\n", - "  def init_weights_and_state(self, input_signature):\n", - "    \"\"\"Initializes this layer for fast inference, if in ``'predict'`` mode.\"\"\"\n", - "    if self._mode == 'predict':\n", - "      self.state = _fast_inference_init_state(\n", - "          input_signature, self._max_len,\n", - "          predict_mask=self._portal_mask)\n", - "    \n", - "def _fast_inference_init_state(input_signature, buffer_length,\n", - "                               predict_mask=None):\n", - "  \"\"\"Returns an initial state for causal attention layer fast inference.\"\"\"\n", - "  def zeros_for(batch_size, shape_dtype):\n", - "    shape, dtype = shape_dtype.as_tuple()\n", - "    d_feature = shape[-1]\n", - "    return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype)\n", - "\n", - "  batch_size = input_signature[0].shape[0]\n", - "  k = zeros_for(batch_size, input_signature[1])\n", - "  v = zeros_for(batch_size, input_signature[2])\n", - "  if predict_mask is not None:\n", - "    mask_for_predict = jnp.zeros((buffer_length,)) != 0\n", - "    return (mask_for_predict, k, v, jnp.array(0))\n", - "  else:\n", - "    return (k, v, jnp.array(0))\n", - "\n", - "\n", - "def _fast_inference_update_state(inputs, state, mask_for_predict=None):\n", - 
\"\"\"Updates state of a causal attention layer for fast inference.\n", - "\n", - " The layer state stores arrays with cached values of keys and values,\n", - " as well as an index. To make shapes static, keys and values in the state are\n", - " long, and the index indicates where the new keys and values from inputs need\n", - " to be appended.\n", - "\n", - " During update, we append new_keys and new_values to keys and values at\n", - " position given by index. And we increment index by length of new keys.\n", - " We also create a mask to be 1 at appropriate positions (causal mask).\n", - "\n", - " Args:\n", - " inputs: a triple (new_queries, new_keys, new_values)\n", - " state: layer state with (keys, values, index)\n", - " mask_for_predict: mask used for predict mode. This is used only in\n", - " Terraformer.\n", - "\n", - " Returns:\n", - " Updated state and mask to be used.\n", - " \"\"\"\n", - " # Fast inference: run step-by-step, storing the sequence\n", - " # of keys and values calculated so far in state.\n", - " (_, new_k, new_v) = inputs\n", - " if mask_for_predict is not None:\n", - " (state_mask_for_predict, ks, vs, idx) = state\n", - " else:\n", - " (ks, vs, idx) = state\n", - " length = new_k.shape[1]\n", - " ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)\n", - " vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)\n", - " k_length = ks.shape[1]\n", - "\n", - " # Mask is of shape [1, q_length, k_length].\n", - " # Mask should be true for every pair of (query_token, key_token) such that\n", - " # index of query_token is equal or larger to index of key_token.\n", - " mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length))\n", - " \u003c= jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))\n", - " if mask_for_predict is None:\n", - " return (ks, vs, idx + length), mask\n", - " else:\n", - " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", - " state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0,\n", - " axis=0)\n", - "\n", - " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", - " state_mask_for_predict != 0, jnp.ones((1,)) != 0,\n", - " jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0)\n", - "\n", - " state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(\n", - " state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0)\n", - " placeholder = jnp.reshape(state_mask_for_predict != 0,\n", - " (1, 1, mask.shape[2],))\n", - " mask = mask * placeholder\n", - "\n", - " return (state_mask_for_predict, ks, vs, idx + length), mask\n", - "\n", - "\n", - "def _causal_mask(length):\n", - " # Not all backends define jnp.tril. However, using np.tril is inefficient\n", - " # in that it creates a large global constant.\n", - " if fastmath.is_backend(fastmath.Backend.JAX):\n", - " return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)\n", - " else:\n", - " return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)\n", - "\n", - "\n", - "def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):\n", - " \"\"\"Computes new per-head activations via scaled dot-product attention.\n", - "\n", - " This function is the core of the attention mechanism. 
Given per-head\n", - "  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:\n", - "\n", - "  - computes the scaled dot product of each Q-K pair;\n", - "  - applies ``mask`` to screen out positions that come from padding tokens\n", - "    (indicated by 0 value);\n", - "  - [in ``'train'`` mode] applies dropout to Q-K dot products;\n", - "  - computes Q-K attention strengths using a per-query softmax of the Q-K dot\n", - "    products; and\n", - "  - for each query position, combines V vectors according to the Q-K\n", - "    attention strengths.\n", - "\n", - "  Args:\n", - "    queries: Per-head activations representing attention queries.\n", - "    keys: Per-head activations representing attention keys.\n", - "    values: Per-head activations to be combined by computed attention strengths.\n", - "    mask: Mask that distinguishes positions with real content vs. padding.\n", - "    dropout: Probabilistic rate for attention dropout, which overrides\n", - "        (sets to zero) some attention strengths derived from query-key\n", - "        matching. As a result, on a given forward pass, some value vectors\n", - "        don't contribute to the output, analogous to how regular dropout can\n", - "        cause some node activations to be ignored. Applies only in ``'train'``\n", - "        mode.\n", - "    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.\n", - "    rng: Single-use random number generator (JAX PRNG key).\n", - "\n", - "  Returns:\n", - "    Tuple of (activations, attn_strengths), where activations are new per-head\n", - "    activation vectors and attn_strengths is a matrix of per-head attention\n", - "    strengths.\n", - "  \"\"\"\n", - "  if dropout \u003e= 1.0:\n", - "    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')\n", - "\n", - "  d_feature = queries.shape[-1]\n", - "\n", - "  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)\n", - "  if mask is not None:\n", - "    dots = jnp.where(mask,\n", - "                     dots,\n", - "                     jnp.full_like(dots, -1e9))\n", - "  attn_strengths = (\n", - "      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))\n", - "  if dropout is not None and dropout \u003e 0.0 and mode == 'train':\n", - "    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)\n", - "    attn_strengths = jnp.where(keep,\n", - "                               attn_strengths / (1.0 - dropout),\n", - "                               jnp.zeros_like(attn_strengths))\n", - "  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)\n", - "  attn_strengths = attn_strengths.astype(jnp.float32)\n", - "  return activations, attn_strengths\n", - "\n", - "\n", - "class _RememberInReverse(tl.Layer):\n", - "  \"\"\"Layer remembering the input in forward pass. For reversible models.\"\"\"\n", - "\n", - "  def __init__(self, output=True):\n", - "    \"\"\"Layer remembering the input in forward pass. For reversible models.\n", - "\n", - "    During the first pass through the model this layer saves the input as\n", - "    state, and returns the input unmodified. During the second pass through the\n", - "    model the layer outputs the input from the first pass. This is used to\n", - "    combat numerical stability problems in Terraformer. 
It doesn't do anything\n", - "    in non-reversible models.\n", - "\n", - "    Args:\n", - "      output: Whether to pass the input or not.\n", - "    \"\"\"\n", - "    n_out = 1 if output else 0\n", - "    self._output = output\n", - "    super().__init__(name='_RememberInReverse', n_out=n_out)\n", - "\n", - "  def forward(self, x):\n", - "    if 'running_second_time_yes' in self.state[1]:\n", - "      result = self.state[0]\n", - "    else:\n", - "      result = x\n", - "    self.state = (x, {'running_second_time': ()})\n", - "\n", - "    if self._output:\n", - "      return result\n", - "    else:\n", - "      return tuple()\n", - "\n", - "  def init_weights_and_state(self, input_signature):\n", - "    \"\"\"Initializes this layer's weights.\"\"\"\n", - "    if isinstance(input_signature, (list, tuple)):\n", - "      input_signature = input_signature[0]\n", - "    self.weights = ()\n", - "    self.state = (jnp.zeros(input_signature.shape, dtype=jnp.int32),\n", - "                  {'running_second_time': ()})\n", - "\n", - "\n", - "class _RecallQuantMaskInReverse(tl.Layer):\n", - "  \"\"\"Layer recalling quant mask from specific _RememberInReverse.\n", - "\n", - "  This layer is needed for memory-efficient training of reversible model with\n", - "  ff chunking. During forward pass it simply returns minus ones, which are\n", - "  ignored in the controller. During reverse_and_grad it returns a quant_mask\n", - "  which was memorized (saved to state) by a RememberInReverse layer.\n", - "\n", - "  This enables us to save quant_mask right after chunking, and load it again\n", - "  (when reversing) right before chunking.\n", - "  \"\"\"\n", - "\n", - "  def __init__(self, remember_layer, elements):\n", - "    self._remember_layer = remember_layer\n", - "    self._elements = elements\n", - "    super().__init__(name='_RecallQuantMaskInReverse', n_in=1, n_out=2)\n", - "\n", - "  def forward(self, x):\n", - "    if (self._remember_layer.state and\n", - "        'running_second_time_yes' in self._remember_layer.state[1]):\n", - "      # It's reverse_and_grad, so we pull the quant_mask from remembering layer.\n", - "      result = self._remember_layer.state[0]\n", - "    else:\n", - "      result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32)\n", - "    return (x, result)\n", - "\n", - "\n", - "class _SparseFFController(tl.Layer):\n", - "  \"\"\"The controller part of Sparse Feed-Forward layer.\"\"\"\n", - "\n", - "  def __init__(self, d_ff, n_elements_in_block, d_lowrank, temperature,\n", - "               use_bfloat16, mode, kernel_initializer, bias_initializer,\n", - "               also_return_nondiscrete_output):\n", - "    \"\"\"Returns a sparse feed-forward block.\"\"\"\n", - "    n_out = 2 if also_return_nondiscrete_output else 1\n", - "    super().__init__(name=f'_SparseFFController_{d_ff}', n_in=2, n_out=n_out)\n", - "    self._use_bfloat16 = use_bfloat16\n", - "    self._d_ff = d_ff\n", - "    self._d_lowrank = d_lowrank\n", - "    # Q: what temperature is actually most useful in training?\n", - "    self._temperature = temperature if mode == 'train' else 0.0\n", - "    self._mode = mode\n", - "    self._n_elements_in_block = n_elements_in_block\n", - "    self._kernel_initializer = kernel_initializer\n", - "    self._bias_initializer = bias_initializer\n", - "    # Helper numbers as d_ff will be divided by n_elements_in_block.\n", - "    assert self._d_ff % self._n_elements_in_block == 0\n", - "    self._d1 = self._d_ff // self._n_elements_in_block\n", - "    self._d2 = self._n_elements_in_block\n", - "    self._also_return_nondiscrete_output = also_return_nondiscrete_output\n", - "\n", - "  def forward(self, x):\n", - "    \"\"\"Executes this layer as part of a forward pass through the model.\n", - 
"\n", - " Args:\n", - " x: Tensor of same shape and dtype as the input signature used to\n", - " initialize this layer.\n", - "\n", - " Returns:\n", - " Tensor of same shape and dtype as the input.\n", - " \"\"\"\n", - " x, recalled_quant_mask = x\n", - " m1, m2, mb = self.weights\n", - "\n", - " x_shape = x.shape\n", - " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", - "\n", - " # Q: should we add bias and/or put relu after the low-rank m1 dot?\n", - " # Replacing multiplication and reshape by this einsum brings training speed\n", - " # improvement (see also reshape in initialization).\n", - " mask_logits = jnp.einsum('bd,dl,lxy-\u003ebxy', x, m1, m2) + mb\n", - "\n", - " if self._also_return_nondiscrete_output:\n", - " # Softmax.\n", - " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", - " log_mask = mask_logits - mask_logsumexp\n", - " mask = jnp.exp(log_mask)\n", - " # Gumbel-softmax with straight-through discretization.\n", - " if self._temperature == 0.0:\n", - " quant_mask = jnp.argmax(log_mask, axis=-1)\n", - " else:\n", - " u = fastmath.random.uniform(self.rng, mask.shape, jnp.float32, 1e-6,\n", - " 1.0 - 1e-6)\n", - " g = -jnp.log(-jnp.log(u))\n", - " quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", - " else:\n", - " quant_mask = jnp.argmax(mask_logits, axis=-1)\n", - "\n", - " if self._mode == 'train':\n", - " # We use recalled_quant_mask if it's different than -1; otherwise\n", - " # we use a quant_mask which we have just computed.\n", - " quant_mask = jnp.where(recalled_quant_mask == -1,\n", - " quant_mask, recalled_quant_mask)\n", - "\n", - " if self._also_return_nondiscrete_output:\n", - " return quant_mask, mask\n", - " else:\n", - " return quant_mask\n", - "\n", - " def init_weights_and_state(self, input_signature):\n", - " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", - " x_input_signature = input_signature[0]\n", - " d_model = x_input_signature.shape[-1]\n", - " shape_m1 = (d_model, self._d_lowrank)\n", - " shape_m2 = (self._d_lowrank, self._d_ff)\n", - " shape_mb = (self._d_ff,)\n", - "\n", - " rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3)\n", - " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", - " m2 = self._kernel_initializer(shape_m2, rng_m2)\n", - " mb = self._bias_initializer(shape_mb, rng_mb)\n", - " if self._use_bfloat16:\n", - " m1 = m1.astype(jnp.bfloat16)\n", - " m2 = m2.astype(jnp.bfloat16)\n", - " mb = mb.astype(jnp.bfloat16)\n", - "\n", - " # Reshapes below, with einsum in feedforward, improve the training speed.\n", - " m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2])\n", - " mb = jnp.reshape(mb, [self._d1, self._d2])\n", - "\n", - " self.weights = (m1, m2, mb)\n", - "\n", - "\n", - "class _SparseFFMain(tl.Layer):\n", - " \"\"\"The main (non-controller) part of Sparse Feed-Forward layer.\"\"\"\n", - "\n", - " def __init__(self, d_ff, n_elements_in_block, d_lowrank, quant_prob,\n", - " use_bfloat16, big_weights_in_bfloat16, mode, kernel_initializer,\n", - " bias_initializer, multiply_by_controller_output, kernel_scaling):\n", - " \"\"\"Returns a sparse feed-forward block.\"\"\"\n", - " n_in = 3 if mode == 'train' or multiply_by_controller_output else 2\n", - " super().__init__(name=f'_SparseFFMain_{d_ff}', n_in=n_in, n_out=2)\n", - " self._mode = mode\n", - " self._use_bfloat16 = use_bfloat16\n", - " self._big_weights_in_bfloat16 = big_weights_in_bfloat16\n", - " self._d_ff = d_ff\n", - " self._d_lowrank = d_lowrank\n", - " 
self._quant_prob = quant_prob\n", - "    self._n_elements_in_block = n_elements_in_block\n", - "    self._kernel_initializer = kernel_initializer\n", - "    self._bias_initializer = bias_initializer\n", - "    # Helper numbers as d_ff will be divided by n_elements_in_block.\n", - "    assert self._d_ff % self._n_elements_in_block == 0\n", - "    self._d1 = self._d_ff // self._n_elements_in_block\n", - "    self._d2 = self._n_elements_in_block\n", - "    self._multiply_by_controller_output = multiply_by_controller_output\n", - "    self._kernel_scaling = kernel_scaling\n", - "\n", - "  def forward(self, x):\n", - "    \"\"\"Executes this layer as part of a forward pass through the model.\n", - "\n", - "    Args:\n", - "      x: Tensor of same shape and dtype as the input signature used to\n", - "          initialize this layer.\n", - "\n", - "    Returns:\n", - "      Tensor of same shape and dtype as the input.\n", - "    \"\"\"\n", - "    if self._mode == 'train' or self._multiply_by_controller_output:\n", - "      quant_mask, mask, x = x\n", - "    else:\n", - "      quant_mask, x = x\n", - "    original_quant_mask = quant_mask\n", - "\n", - "    w1, w2, b2 = self.weights\n", - "\n", - "    if self._mode == 'predict':\n", - "      w1 = jnp.transpose(w1, (1, 2, 0))  # dm, d1, d2 -\u003e d1, d2, dm\n", - "      w2 = jnp.transpose(w2, (1, 0, 2))  # d2, d1, dm -\u003e d1, d2, dm\n", - "    x_shape = x.shape\n", - "    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.\n", - "\n", - "    if self._mode == 'train':\n", - "      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797\n", - "      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)\n", - "      quant_mask = fastmath.stop_gradient(quant_mask)\n", - "      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through\n", - "      # We will sometimes (quant_prob of the batches) use the soft-mask instead\n", - "      # of the quantized mask to improve training stability (see paper above).\n", - "      select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0)\n", - "      quant_mask = jnp.where(select \u003c self._quant_prob, quant_mask, mask)\n", - "\n", - "      # In training, run full matmul to get benefits from the above tricks.\n", - "      mid = jnp.einsum('bd,dxy-\u003ebxy', x, w1) * quant_mask\n", - "      relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", - "      if self._multiply_by_controller_output:\n", - "        # We multiply only for quantized decisions, since for non-quantized\n", - "        # decisions we've already multiplied the output.\n", - "        mask_mult = jnp.where(select \u003c self._quant_prob,\n", - "                              mask, jnp.ones_like(mask))\n", - "        # Stop-gradient is here, because we already have a pass-through gradient\n", - "        # (for quantized decisions).\n", - "        mask_mult = fastmath.stop_gradient(mask_mult)\n", - "        relu = relu * mask_mult\n", - "      res = jnp.einsum('bxy,yxd-\u003ebd', relu, w2) + b2\n", - "    elif self._mode == 'predict':\n", - "      # This implementation mimics inference. 
It's not efficient for large\n", - " # size of joint_batch, but at inference that will be 1 most of the time.\n", - " # Shapes:\n", - " # quant_mask is [joint_batch, self._d1]\n", - " # w1 is [d_model, self._d1, self._d2]\n", - " # we'll index w1 with advanced numpy indexing, first range over\n", - " # self._d1 times the batch size, second range being quant_mask\n", - " batch_size = quant_mask.shape[0]\n", - " idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)\n", - " # flatten indices and select from w1\n", - " idx1 = jnp.reshape(idx1, [-1])\n", - " idx2 = jnp.reshape(quant_mask, [-1])\n", - " w = w1[idx1, idx2, :] # now we have per-element weights with batch dim\n", - " w = jnp.reshape(w, [batch_size, self._d1, -1])\n", - " mid = jnp.einsum('ai,aji-\u003eaj', x, w)\n", - " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", - " if self._multiply_by_controller_output:\n", - " mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0]\n", - " relu = relu * mask_mult\n", - " # w2 is [self._d1, self._d2, d_model]\n", - " v = w2[idx1, idx2, :]\n", - " v = jnp.reshape(v, [batch_size, self._d1, -1])\n", - " res = jnp.einsum('ai,aij-\u003eaj', relu, v) + b2\n", - " else:\n", - " quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)\n", - " mid = jnp.einsum('bd,dxy-\u003ebxy', x, w1) * quant_mask\n", - " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", - " if self._multiply_by_controller_output:\n", - " relu = relu * mask\n", - " res = jnp.einsum('bxy,yxd-\u003ebd', relu, w2) + b2\n", - "\n", - " return original_quant_mask, jnp.reshape(res, x_shape)\n", - "\n", - " def init_weights_and_state(self, input_signature):\n", - " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", - " d_model = input_signature[-1].shape[-1]\n", - " shape_w1 = (d_model, self._d_ff)\n", - " shape_w2 = (self._d_ff, d_model)\n", - " shape_b2 = (d_model,)\n", - "\n", - " rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)\n", - " if tl.N_WEIGHTS_SHARDS \u003e 1:\n", - " # In sharded-weights mode, put the weights on CPU on init\n", - " # as they will be sharded later.\n", - " w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1))\n", - " w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2))\n", - " else:\n", - " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", - " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", - "\n", - " b2 = self._bias_initializer(shape_b2, rng_b2)\n", - " if self._use_bfloat16:\n", - " b2 = b2.astype(jnp.bfloat16)\n", - " if self._use_bfloat16 or self._big_weights_in_bfloat16:\n", - " w1 = w1.astype(jnp.bfloat16)\n", - " w2 = w2.astype(jnp.bfloat16)\n", - "\n", - " w1 = jnp.reshape(w1, (-1, self._d1, self._d2))\n", - " w2 = jnp.reshape(w2, (self._d2, self._d1, -1))\n", - "\n", - " if self._kernel_scaling:\n", - " # This keeps expected variance of the output regardless of N.\n", - " w2 = w2 * (self._n_elements_in_block ** 0.5)\n", - "\n", - " self.weights = (w1, w2, b2)\n", - "\n", - "\n", - "def SparseFF(\n", - " d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.1, quant_prob=0.3,\n", - " use_bfloat16=False, big_weights_in_bfloat16=False, mode='train',\n", - " kernel_initializer=tl.GlorotUniformInitializer(),\n", - " bias_initializer=tl.RandomNormalInitializer(1e-6),\n", - " dropout_rate=0.0, dropout_shared_axes=None, ff_chunk_size=0,\n", - " multiply_by_controller_output=False, kernel_scaling=False):\n", - " \"\"\"Returns Feed-forward block with sparsity.\n", - "\n", - " The original (non-sparse) FF block is a triple 
Dense(d_ff)-Relu-Dense\n",
- "  that takes an input, makes it of size d_ff (usually larger than it was) and\n",
- "  then brings it back to the original size after Relu. It is commonly used in\n",
- "  Transformer models where it often accounts for most of the trainable weights.\n",
- "\n",
- "  The original block can be slow in decoding due to the need to fetch a lot of\n",
- "  weights from memory. This sparse block only allows one non-zero element\n",
- "  in a block of a specified size. This is trained with the straight-through\n",
- "  Gumbel-softmax trick.\n",
- "\n",
- "  Args:\n",
- "    d_ff: Depth/dimensionality of FeedForward layer.\n",
- "    n_elements_in_block: The sparsity level. The layer is divided into blocks of\n",
- "      this size, and each block has only a single element active.\n",
- "    d_lowrank: The dimensionality of the low-rank controller.\n",
- "    temperature: The temperature of the controller during training.\n",
- "    quant_prob: During training this proportion of blocks will have quantized\n",
- "      mask (i.e. a single element active). The rest will use a soft mask.\n",
- "    use_bfloat16: Whether to use bfloat16 for weights.\n",
- "    big_weights_in_bfloat16: Whether to use bfloat16 for the main weights of the\n",
- "      FeedForward layer.\n",
- "    mode: One of `'train'`, `'eval'`, or `'predict'`.\n",
- "    kernel_initializer: Function that creates a matrix of (random) initial\n",
- "      connection weights `W` for the layer.\n",
- "    bias_initializer: Function that creates a vector of (random) initial\n",
- "      bias weights `b` for the layer.\n",
- "    dropout_rate: Probability for dropping an activation value.\n",
- "    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n",
- "      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n",
- "      way to save memory and apply consistent masks to activation vectors at\n",
- "      different sequence positions.\n",
- "    ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks.\n",
- "    multiply_by_controller_output: whether to multiply the middle activation\n",
- "      layer of FF by controller output (i.e. 
softmax).\n",
- "    kernel_scaling: Whether to scale the kernel matrix (during init) so that the\n",
- "      variance of the layer output does not depend on n_elements_in_block.\n",
- "  \"\"\"\n",
- "\n",
- "  if mode == 'train' or multiply_by_controller_output:\n",
- "    also_return_nondiscrete_output = True\n",
- "  else:\n",
- "    also_return_nondiscrete_output = False\n",
- "  controller = _SparseFFController(\n",
- "      d_ff=d_ff, n_elements_in_block=n_elements_in_block,\n",
- "      d_lowrank=d_lowrank, temperature=temperature,\n",
- "      use_bfloat16=use_bfloat16, mode=mode,\n",
- "      kernel_initializer=kernel_initializer,\n",
- "      bias_initializer=bias_initializer,\n",
- "      also_return_nondiscrete_output=also_return_nondiscrete_output)\n",
- "\n",
- "  main = [\n",
- "      _SparseFFMain(\n",
- "          d_ff=d_ff, n_elements_in_block=n_elements_in_block,\n",
- "          d_lowrank=d_lowrank, quant_prob=quant_prob, use_bfloat16=use_bfloat16,\n",
- "          big_weights_in_bfloat16=big_weights_in_bfloat16, mode=mode,\n",
- "          kernel_initializer=kernel_initializer,\n",
- "          bias_initializer=bias_initializer,\n",
- "          multiply_by_controller_output=multiply_by_controller_output,\n",
- "          kernel_scaling=kernel_scaling),\n",
- "      # quant_mask, emb\n",
- "      tl.Select([1, 0]),\n",
- "      # emb, quant_mask\n",
- "      tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode),\n",
- "      tl.Select([1, 0]),\n",
- "      # quant_mask, emb\n",
- "  ]\n",
- "\n",
- "  # We will \"remember\" quant_mask _after_ chunking, and \"recall\" this same\n",
- "  # quant_mask during reverse_and_grad _before_ chunking.\n",
- "  remembering = _RememberInReverse(output=False)\n",
- "  recalling = _RecallQuantMaskInReverse(\n",
- "      remember_layer=remembering, elements=d_ff//n_elements_in_block)\n",
- "\n",
- "  return tl.BatchLeadingAxes(tl.Serial(\n",
- "      recalling,  # emb, quant_mask\n",
- "      tl.Chunk(chunk_size=ff_chunk_size, layer=tl.Serial(\n",
- "          # emb, quant_mask\n",
- "          tl.Select((0, 1, 0)),  # emb, quant_mask, emb\n",
- "          controller,  # quant_mask, mask, emb\n",
- "          main,  # quant_mask, emb/output\n",
- "      )),\n",
- "      remembering,  # emb/output\n",
- "  ))\n",
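- "\n",
- "\n",
- "# A minimal shape-check for SparseFF -- only a sketch, assuming the\n",
- "# notebook's usual imports; the sizes below are illustrative, not tuned.\n",
- "from trax import shapes as trax_shapes\n",
- "\n",
- "_example_ff = SparseFF(d_ff=256, n_elements_in_block=32, d_lowrank=64,\n",
- "                       mode='train')\n",
- "_example_x = jnp.zeros((2, 8, 64), dtype=jnp.float32)  # (batch, len, d_model)\n",
- "_example_ff.init(trax_shapes.signature(_example_x))\n",
- "# The block maps activations to activations of the same shape.\n",
- "assert _example_ff(_example_x).shape == _example_x.shape\n",
- "\n",
- "\n",
- "class BlockSparseFF(tl.Layer):\n",
- "  \"\"\"Feed-forward block with block sparsity.\n",
- "\n",
- "  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n",
- "  that takes an input, makes it of size d_ff (usually larger than it was) and\n",
- "  then brings it back to the original size after Relu. 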
It is commonly used in\n", - " Transformer models where it often accounts for most of the trainable weights.\n", - "\n", - " This block sparse layer mimics mixture of experts architecture.\n", - " It divides the dimension of d_ff in each weight matrix to # of blocks equal to\n", - " n_experts and activates only one non-zero block from the weights matrix.\n", - " This is trained with straight-through Gumbel softmax trick.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " d_ff,\n", - " n_experts=64,\n", - " temperature=0.7,\n", - " mode='train',\n", - " kernel_initializer=tl.GlorotUniformInitializer(),\n", - " bias_initializer=tl.RandomNormalInitializer(1e-6)):\n", - " \"\"\"Returns a block sparse feed-forward block.\"\"\"\n", - " super().__init__(name=f'BlockSparseFF_{d_ff}')\n", - " self._mode = mode\n", - " self._d_ff = d_ff\n", - " self._n_experts = n_experts\n", - " self._temperature = temperature if mode == 'train' else 0.0\n", - " self._n_elements_in_block = d_ff // n_experts\n", - " self._kernel_initializer = kernel_initializer\n", - " self._bias_initializer = bias_initializer\n", - " assert self._d_ff % self._n_experts == 0\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Executes this layer as part of a forward pass through the model.\n", - "\n", - " Args:\n", - " x: Tensor of same shape and dtype as the input signature used to\n", - " initialize this layer.\n", - "\n", - " Returns:\n", - " Tensor of same shape and dtype as the input.\n", - " \"\"\"\n", - " m1, w1, w2, b2 = self.weights\n", - " x_shape = x.shape\n", - " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", - "\n", - " # Q: check if we need bias and/or put relu after the m1 dot?\n", - " mask_logits = jnp.dot(x, m1)\n", - " # Softmax.\n", - " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", - " log_mask = mask_logits - mask_logsumexp\n", - " mask = jnp.exp(log_mask)\n", - " # Gumbel-softmax with straight-through discretization.\n", - " rng1, rng2 = fastmath.random.split(self.rng, 2)\n", - " u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)\n", - " g = -jnp.log(-jnp.log(u))\n", - " selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", - " if self._mode == 'train':\n", - " # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797\n", - " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", - " quant_mask = fastmath.stop_gradient(quant_mask)\n", - " quant_mask += mask - fastmath.stop_gradient(mask) # straight-through\n", - " # We will sometimes (50% of the batches) use the soft-mask instead of\n", - " # the quantized mask to improve training stability (see the paper above).\n", - " # Q: is selecting 50% of batches the best? Other %? 
Mixed in-batch?\n",
- "      select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)\n",
- "      quant_mask = jnp.where(select \u003e 0.0, quant_mask, mask)\n",
- "    else:\n",
- "      quant_mask = tl.one_hot(selected_experts, self._n_experts)\n",
- "    quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])\n",
- "    batch_size = quant_mask.shape[0]\n",
- "\n",
- "    if self._mode == 'predict' and batch_size == 1:\n",
- "      # This implementation mimics inference for batch_size 1.\n",
- "      start_idx = selected_experts[0] * self._n_elements_in_block\n",
- "      # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]\n",
- "      w = fastmath.dynamic_slice(w1, [0, start_idx],\n",
- "                                 [w1.shape[0], self._n_elements_in_block])\n",
- "      mid = jnp.dot(x, w)\n",
- "      relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n",
- "      # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]\n",
- "      v = fastmath.dynamic_slice(w2, [start_idx, 0],\n",
- "                                 [self._n_elements_in_block, w2.shape[-1]])\n",
- "      v = jnp.reshape(v, [self._n_elements_in_block, -1])\n",
- "      res = jnp.dot(relu, v) + b2\n",
- "    else:\n",
- "      expanded_mask = jnp.broadcast_to(\n",
- "          quant_mask,\n",
- "          (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))\n",
- "      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))\n",
- "      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]\n",
- "      relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n",
- "      res = jnp.dot(relu, w2) + b2\n",
- "\n",
- "    return jnp.reshape(res, x_shape)  # un-flatten if needed\n",
- "\n",
- "  def init_weights_and_state(self, input_signature):\n",
- "    \"\"\"Randomly initializes this layer's weights.\"\"\"\n",
- "    d_model = input_signature.shape[-1]\n",
- "    shape_m1 = (d_model, self._n_experts)\n",
- "    shape_w1 = (d_model, self._d_ff)\n",
- "    shape_w2 = (self._d_ff, d_model)\n",
- "    shape_b2 = (d_model,)\n",
- "\n",
- "    rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)\n",
- "    m1 = self._kernel_initializer(shape_m1, rng_m1)\n",
- "    w1 = self._kernel_initializer(shape_w1, rng_w1)\n",
- "    w2 = self._kernel_initializer(shape_w2, rng_w2)\n",
- "    b2 = self._bias_initializer(shape_b2, rng_b2)\n",
- "\n",
- "    self.weights = (m1, w1, w2, b2)\n",
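- "\n",
- "\n",
- "# A minimal shape-check for BlockSparseFF -- again only a sketch, reusing\n",
- "# the example input above; with d_ff=256 and n_experts=8 each expert block\n",
- "# has 256 // 8 = 32 units, and exactly one block is active per position.\n",
- "_example_bsff = BlockSparseFF(d_ff=256, n_experts=8, mode='train')\n",
- "_example_bsff.init(trax_shapes.signature(_example_x))\n",
- "assert _example_bsff(_example_x).shape == _example_x.shape\n",
- "\n",
- "\n",
- "class SwitchSparseFF(tl.Layer):\n",
- "  \"\"\"Feed-forward block with switch-style block sparsity.\n",
- "\n",
- "  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense\n",
- "  that takes an input, makes it of size d_ff (usually larger than it was) and\n",
- "  then brings it back to the original size after Relu. 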
It is commonly used in\n", - " Transformer models where it often accounts for most of the trainable weights.\n", - "\n", - " This block sparse layer mimics mixture of experts architecture.\n", - " It divides the dimension of d_ff in each weight matrix to # of blocks equal to\n", - " n_experts and activates only one non-zero block from the weights matrix.\n", - " This is trained with methods following the Switch Transformer.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " d_ff,\n", - " n_experts=64,\n", - " temperature=0.1,\n", - " mode='train',\n", - " kernel_initializer=tl.GlorotUniformInitializer(),\n", - " bias_initializer=tl.RandomNormalInitializer(1e-6)):\n", - " \"\"\"Returns a switch-style training block sparse feed-forward block.\"\"\"\n", - " super().__init__(name=f'SwitchSparseFF_{d_ff}')\n", - " self._mode = mode\n", - " self._d_ff = d_ff\n", - " self._n_experts = n_experts\n", - " self._temperature = temperature if mode == 'train' else 0.0\n", - " self._n_elements_in_block = d_ff // n_experts\n", - " self._kernel_initializer = kernel_initializer\n", - " self._bias_initializer = bias_initializer\n", - " assert self._d_ff % self._n_experts == 0\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Executes this layer as part of a forward pass through the model.\n", - "\n", - " Args:\n", - " x: Tensor of same shape and dtype as the input signature used to\n", - " initialize this layer.\n", - "\n", - " Returns:\n", - " Tensor of same shape and dtype as the input.\n", - " \"\"\"\n", - " m1, w1, w2, b2 = self.weights\n", - " x_shape = x.shape\n", - " x = jnp.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x.\n", - "\n", - " # Q: check if we need bias and/or put relu after the m1 dot?\n", - " mask_logits = jnp.dot(x, m1)\n", - " # Softmax.\n", - " mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)\n", - " log_mask = mask_logits - mask_logsumexp\n", - " mask = jnp.exp(log_mask)\n", - " # Gumbel noise to allow sampling from the softmax.\n", - " rng1, _ = fastmath.random.split(self.rng, 2)\n", - " u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)\n", - " g = -jnp.log(-jnp.log(u))\n", - " selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)\n", - " quant_mask = tl.one_hot(selected_experts, self._n_experts)\n", - " quant_mask = fastmath.stop_gradient(quant_mask)\n", - " quant_mask *= mask # go to just the selected expert\n", - " quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])\n", - " batch_size = quant_mask.shape[0]\n", - "\n", - " if self._mode == 'predict' and batch_size == 1:\n", - " mask_flat = jnp.reshape(mask, [-1, self._n_experts])\n", - " selected_flat = jnp.reshape(selected_experts, [-1])\n", - " selected_mask_flat = mask_flat[np.arange(selected_flat.size),\n", - " selected_flat]\n", - " # This implementation mimicks inference for batch_size 1.\n", - " start_idx = selected_experts[0] * self._n_elements_in_block\n", - " # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]\n", - " w = fastmath.dynamic_slice(w1, [0, start_idx],\n", - " [w1.shape[0], self._n_elements_in_block])\n", - " mid = jnp.dot(x, w)\n", - " mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None]\n", - " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", - " # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]\n", - " v = fastmath.dynamic_slice(w2, [start_idx, 0],\n", - " [self._n_elements_in_block, w2.shape[-1]])\n", - " v = jnp.reshape(v, [self._n_elements_in_block, 
-1])\n", - " res = jnp.dot(relu, v) + b2\n", - " else:\n", - " expanded_mask = jnp.broadcast_to(\n", - " quant_mask,\n", - " (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))\n", - " expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))\n", - " mid = jnp.dot(x, w1) * expanded_mask # [joint_batch, d_ff]\n", - " relu = jnp.where(mid \u003c= 0, jnp.zeros_like(mid), mid)\n", - " res = jnp.dot(relu, w2) + b2\n", - "\n", - " return jnp.reshape(res, x_shape) # un-flatten if needed\n", - "\n", - " def init_weights_and_state(self, input_signature):\n", - " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", - " d_model = input_signature.shape[-1]\n", - " shape_m1 = (d_model, self._n_experts)\n", - " shape_w1 = (d_model, self._d_ff)\n", - " shape_w2 = (self._d_ff, d_model)\n", - " shape_b2 = (d_model,)\n", - "\n", - " rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)\n", - " m1 = self._kernel_initializer(shape_m1, rng_m1)\n", - " w1 = self._kernel_initializer(shape_w1, rng_w1)\n", - " w2 = self._kernel_initializer(shape_w2, rng_w2)\n", - " b2 = self._bias_initializer(shape_b2, rng_b2)\n", - "\n", - " self.weights = (m1, w1, w2, b2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4-3_EPyP4c7K" - }, - "outputs": [], - "source": [ - "# SRU needs to be changed in order for concatenated encoder-decoder attention\n", - "# to work in predict mode.\n", - "\n", - "def MakeZeroState(depth_multiplier=1):\n", - " \"\"\"Makes zeros of shape like x but removing the length (axis 1).\"\"\"\n", - " def f(x): # pylint: disable=invalid-name\n", - " if len(x.shape) != 3:\n", - " raise ValueError(f'Layer input should be a rank 3 tensor representing'\n", - " f' (batch_size, sequence_length, feature_depth); '\n", - " f'instead got shape {x.shape}.')\n", - " return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]),\n", - " dtype=jnp.float32)\n", - " return tl.Fn('MakeZeroState', f)\n", - "\n", - "def InnerSRUCell():\n", - " \"\"\"The inner (non-parallel) computation of an SRU.\"\"\"\n", - " def f(cur_x_times_one_minus_f, cur_f, cur_state): # pylint: disable=invalid-name\n", - " res = cur_f * cur_state + cur_x_times_one_minus_f\n", - " return res, res\n", - " return tl.Fn('InnerSRUCell', f, n_out=2)\n", - "\n", - "\n", - "def ScanSRUCell(mode, monkey_patched_mask=None):\n", - " \"\"\"The inner (non-parallel) computation of an SRU.\"\"\"\n", - " if monkey_patched_mask is None:\n", - " return tl.Scan(InnerSRUCell(), axis=1, mode=mode)\n", - "\n", - " # This is necessary for Terraformer model. 
See comments there.\n",
- "  # The mask will only be used in Terraformer in predict mode.\n",
- "  assert mode == 'predict'\n",
- "\n",
- "  def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name\n",
- "    initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)\n",
- "    if initial.shape[1] \u003e 1:\n",
- "      updated_mask = fastmath.dynamic_update_slice_in_dim(\n",
- "          initial != 0, mask != 0, 1, axis=1)\n",
- "    else:\n",
- "      updated_mask = initial\n",
- "    return updated_mask, x_times_one_minus_f\n",
- "\n",
- "  def masked_inner_sru_cell(cur_mask, cur_x_times_one_minus_f, cur_f,  # pylint: disable=invalid-name\n",
- "                            cur_state):\n",
- "    res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask\n",
- "           + (1 - cur_mask) * cur_state)\n",
- "    return res, res\n",
- "\n",
- "  return tl.Serial(\n",
- "      monkey_patched_mask.get_layer(),\n",
- "      tl.Fn('update_mask', update_mask, n_out=2),\n",
- "      tl.Scan(tl.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2),\n",
- "              axis=1, mode=mode),\n",
- "  )\n",
- "\n",
- "\n",
- "def SRU(n_units, activation=None, mode='train'):\n",
- "  r\"\"\"SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.\n",
- "\n",
- "  As defined in the paper:\n",
- "\n",
- "  .. math::\n",
- "    y_t \u0026= W x_t + B \\quad \\hbox{(include $B$ optionally)} \\\\\n",
- "    f_t \u0026= \\sigma(Wf x_t + bf) \\\\\n",
- "    r_t \u0026= \\sigma(Wr x_t + br) \\\\\n",
- "    c_t \u0026= f_t \\times c_{t-1} + (1 - f_t) \\times y_t \\\\\n",
- "    h_t \u0026= r_t \\times \\hbox{activation}(c_t) + (1 - r_t) \\times x_t\n",
- "\n",
- "  We assume the input is of shape [batch, length, depth] and recurrence\n",
- "  happens on the length dimension. This returns a single layer. The paper\n",
- "  recommends stacking at least two of them, except inside a Transformer.\n",
- "\n",
- "  Args:\n",
- "    n_units: output depth of the SRU layer.\n",
- "    activation: Optional activation function.\n",
- "    mode: if 'predict' then we save the previous state for one-by-one inference\n",
- "\n",
- "  Returns:\n",
- "    The SRU layer.\n",
- "  \"\"\"\n",
- "  sigmoid_activation = tl.Sigmoid()\n",
- "  return tl.Serial(  # x\n",
- "      tl.Branch(tl.Dense(3 * n_units), []),  # r_f_y, x\n",
- "      tl.Split(n_items=3),  # r, f, y, x\n",
- "      tl.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x\n",
- "      tl.Fn('',\n",
- "            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x\n",
- "            n_out=3),\n",
- "      tl.Parallel([], [], tl.Branch(MakeZeroState(), [])),\n",
- "      ScanSRUCell(mode=mode),\n",
- "      tl.Select([0], n_in=2),  # act(c), r, x\n",
- "      activation if activation is not None else [],\n",
- "      tl.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),\n",
- "      # Set the name to SRU and don't print sublayers.\n",
- "      name=f'SRU_{n_units}', sublayers_to_print=[]\n",
- "  )"
- ]
- },
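- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# A minimal SRU smoke test -- a sketch with illustrative sizes. The layer\n",
- "# maps (batch, length, n_units) activations to the same shape, with the\n",
- "# recurrence running along the length axis; the input depth must equal\n",
- "# n_units because of the final gate.\n",
- "from trax import shapes as _sru_shapes\n",
- "\n",
- "_sru_layer = SRU(16)\n",
- "_sru_x = jnp.zeros((2, 5, 16), dtype=jnp.float32)\n",
- "_sru_layer.init(_sru_shapes.signature(_sru_x))\n",
- "assert _sru_layer(_sru_x).shape == (2, 5, 16)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "cyf_7nTr55gU"
- },
- "source": [
- "## Terraformer\n",
- "\n",
- "The cells below contain the implementation of the Terraformer architecture:\n",
- "* feed-forward and positional encoding blocks\n",
- "* encoder and decoder blocks\n",
- "* concatenation and stripping to combine the encoder and decoder\n",
- "* the final Terraformer model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "3eEe0xnOvG_X"
- },
- "outputs": [],
- "source": [
- "def _FeedForward(d_model, d_ff, dropout, activation, act_dropout,\n",
- "                 use_bfloat16, mode):\n",
- "  \"\"\"Feed-forward block with layer normalization at 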
start.\"\"\"\n", - " if act_dropout is None:\n", - " act_dropout = dropout\n", - " return [\n", - " tl.Dense(d_ff, use_bfloat16=use_bfloat16),\n", - " tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode),\n", - " activation(),\n", - " tl.Dense(d_model, use_bfloat16=use_bfloat16),\n", - " ]\n", - "\n", - "\n", - "def FeedForwardWithOptions(d_model,\n", - " d_ff,\n", - " dropout,\n", - " dropout_shared_axes,\n", - " ff_activation,\n", - " ff_dropout,\n", - " ff_chunk_size,\n", - " ff_use_sru,\n", - " ff_sparsity,\n", - " center_layernorm,\n", - " mode,\n", - " use_bfloat16=False,\n", - " ff_sparsity_type='1inN'):\n", - " \"\"\"Feed-Forward block with all the options.\n", - "\n", - " Args:\n", - " d_model: Final dimension of tensors at most points in the model, including\n", - " the initial embedding output.\n", - " d_ff: Size of special dense layer in the feed-forward part of each block.\n", - " dropout: Stochastic rate (probability) for dropping an activation value when\n", - " applying dropout within a block.\n", - " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", - " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", - " way to save memory and apply consistent masks to activation vectors at\n", - " different sequence positions.\n", - " ff_activation: Type of activation function at the end of each block; must be\n", - " an activation-type subclass of `Layer`.\n", - " ff_dropout: Stochastic rate (probability) for dropping an activation value\n", - " when applying dropout after the FF dense layer.\n", - " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks\n", - " ff_use_sru: int or pair of ints; if \u003e 0, we use this many SRU layers\n", - " in addition to the feed-forward block (second int specifies sru size)\n", - " ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block\n", - " with this sparsity\n", - " center_layernorm: whether to use centering in LayerNorm (default) or if\n", - " to skip it, which is known as RMS normalization.\n", - " mode: If `'train'`, each block will include dropout; else, it will pass all\n", - " values through unaltered.\n", - " use_bfloat16: whether to use bfloat16 for weights (default: False).\n", - " ff_sparsity_type: string, if ff_sparsity \u003e0,\n", - " use SparseFF if ff_sparsity_type=`'1inN'` and\n", - " use BlockSparseFF if ff_sparsity_type=`'Block'`\n", - " use SwitchSparseFF if ff_sparsity_type=`'Switch'`\n", - "\n", - " Returns:\n", - " A list of layers which maps vectors to vectors.\n", - " \"\"\"\n", - " if ff_sparsity and ff_sparsity_type == '1inN':\n", - " temperature, quant_prob = 0.1, 0.3\n", - " if isinstance(ff_sparsity, str):\n", - " # This is hacky but used to pass ff_sparsity in yaml sweep files.\n", - " ff_sparsity = [(float(x) if '.' 
in x else int(x))\n", - " for x in ff_sparsity.split()]\n", - " if isinstance(ff_sparsity, (list, tuple)):\n", - " if len(ff_sparsity) == 2:\n", - " n_elements_in_block, d_lowrank = ff_sparsity\n", - " else:\n", - " n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity\n", - " else:\n", - " assert isinstance(ff_sparsity, int)\n", - " n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity\n", - " ff = SparseFF(\n", - " d_ff,\n", - " n_elements_in_block=n_elements_in_block,\n", - " d_lowrank=d_lowrank,\n", - " temperature=temperature,\n", - " quant_prob=quant_prob,\n", - " use_bfloat16=use_bfloat16,\n", - " mode=mode,\n", - " dropout_rate=dropout,\n", - " dropout_shared_axes=dropout_shared_axes,\n", - " ff_chunk_size=ff_chunk_size)\n", - " elif ff_sparsity and ff_sparsity_type == 'Block':\n", - " ff = BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)\n", - " elif ff_sparsity and ff_sparsity_type == 'Switch':\n", - " ff = SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)\n", - " else:\n", - " ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,\n", - " use_bfloat16, mode)\n", - " res = [tl.LayerNorm(center=center_layernorm), ff]\n", - " if ff_sparsity_type != '1inN' or ff_sparsity == 0:\n", - " # SparseFF has Dropout and BatchLeadingAxes built-in.\n", - " res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,\n", - " mode=mode))\n", - " if ff_chunk_size \u003e 0:\n", - " res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))\n", - " if ff_use_sru:\n", - " if isinstance(ff_use_sru, (list, tuple)):\n", - " sru_n_layers, sru_n_units = ff_use_sru\n", - " else:\n", - " sru_n_layers, sru_n_units = ff_use_sru, 32\n", - " sru = [SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)]\n", - " block = [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units)\n", - " ] + sru + [tl.Dense(d_model)]\n", - " res = tl.Residual(block, shortcut=res)\n", - " return [res]\n", - "\n", - "\n", - "def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,\n", - " masked, attention_dropout, output_dropout,\n", - " attention_chunk_size, mode):\n", - " \"\"\"Runs the supplied attention layer.\"\"\"\n", - " try:\n", - " attention = attention_type(\n", - " n_heads=n_heads,\n", - " d_qk=d_qk,\n", - " d_v=d_v,\n", - " causal=causal,\n", - " masked=masked,\n", - " output_dropout=output_dropout,\n", - " attention_dropout=attention_dropout,\n", - " mode=mode)\n", - " except TypeError: # No d_qk arguments in less advanced layers.\n", - " attention = attention_type(\n", - " d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode)\n", - " return tl.Chunk(attention, attention_chunk_size)\n", - "\n", - "\n", - "def PositionalEncoder(mode,\n", - " dropout=None,\n", - " max_len=None,\n", - " pos_type=None,\n", - " pos_axial_shape=None,\n", - " pos_d_axial_embs=None,\n", - " pos_start_from_zero_prob=1.0,\n", - " pos_max_offset_to_add=0,\n", - " use_bfloat16=False):\n", - " \"\"\"Returns the positional encoding layer depending on the arguments.\n", - "\n", - " Args:\n", - " mode: If `'predict'`, use fast inference. 
If `'train'`, each encoder/decoder\n", - " block will include dropout; else, it will pass all values through\n", - " unaltered.\n", - " dropout: Stochastic rate (probability) for dropping an activation\n", - " value when applying dropout after the embedding block.\n", - " max_len: Maximum symbol length for positional encoding.\n", - " pos_type: string, the type of positional embeddings to use.\n", - " pos_axial_shape: tuple of ints: input shape to use for the axial position\n", - " encoding. If unset, axial position encoding is disabled.\n", - " pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.\n", - " Tuple length must match pos_axial_shape, and values must sum to d_model.\n", - " pos_start_from_zero_prob: how often to start from 0 during training,\n", - " (if 1.0, we always start from position 0, if less, we randomize).\n", - " pos_max_offset_to_add: maximum offset to add to positions during training\n", - " when randomizing; this offset plus input length must still be less than\n", - " max_len for all training examples.\n", - " use_bfloat16: If `True`, use bfloat16 weights instead of the default\n", - " float32; this can save memory but may (rarely) lead to numerical issues.\n", - "\n", - " Returns:\n", - " A layer that will do the positional encoding.\n", - " \"\"\"\n", - " if not pos_type:\n", - " positional_encoding = tl.PositionalEncoding(\n", - " max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16,\n", - " start_from_zero_prob=pos_start_from_zero_prob,\n", - " max_offset_to_add=pos_max_offset_to_add, mode=mode)\n", - " elif pos_type == 'sin-cos':\n", - " positional_encoding = tl.SinCosPositionalEncoding(mode=mode)\n", - " elif pos_type == 'fixed-base':\n", - " positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)\n", - " elif pos_type == 'infinite':\n", - " positional_encoding = tl.InfinitePositionalEncoding(affine=False)\n", - " elif pos_type == 'infinite-affine':\n", - " positional_encoding = tl.InfinitePositionalEncoding()\n", - " elif pos_type == 'time-bin':\n", - " positional_encoding = tl.TimeBinPositionalEncoding()\n", - " else:\n", - " assert pos_d_axial_embs is not None\n", - " positional_encoding = tl.AxialPositionalEncoding(\n", - " shape=pos_axial_shape, d_embs=pos_d_axial_embs,\n", - " dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)),\n", - " dropout=dropout, mode=mode)\n", - "\n", - " return positional_encoding\n", - "\n", - "\n", - "def EmbeddingAndPositionalEncodings(input_vocab_size,\n", - " d_model,\n", - " mode,\n", - " embedding_dropout,\n", - " dropout_shared_axes,\n", - " max_len,\n", - " output_vocab_size=None,\n", - " pos_type=None,\n", - " pos_axial_shape=None,\n", - " pos_d_axial_embs=None,\n", - " pos_start_from_zero_prob=1.0,\n", - " pos_max_offset_to_add=0,\n", - " use_bfloat16=False):\n", - " \"\"\"Returns the embedder and positional encoder.\n", - "\n", - " Args:\n", - " input_vocab_size: Input vocabulary size -- each element of the input tensor\n", - " should be an integer in `range(vocab_size)`. These integers typically\n", - " represent token IDs from a vocabulary-based tokenizer.\n", - " d_model: Final dimension of tensors at most points in the model, including\n", - " the initial embedding output.\n", - " mode: If `'predict'`, use fast inference. 
If `'train'`, each encoder/decoder\n", - " block will include dropout; else, it will pass all values through\n", - " unaltered.\n", - " embedding_dropout: Stochastic rate (probability) for dropping an activation\n", - " value when applying dropout after the embedding block.\n", - " dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing\n", - " along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful\n", - " way to save memory and apply consistent masks to activation vectors at\n", - " different sequence positions.\n", - " max_len: Maximum symbol length for positional encoding.\n", - " output_vocab_size: If specified, gives the vocabulary size for the targets;\n", - " if None, then input and target integers (token IDs) are assumed to come\n", - " from the same vocabulary.\n", - " pos_type: string, the type of positional embeddings to use.\n", - " pos_axial_shape: tuple of ints: input shape to use for the axial position\n", - " encoding. If unset, axial position encoding is disabled.\n", - " pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.\n", - " Tuple length must match pos_axial_shape, and values must sum to d_model.\n", - " pos_start_from_zero_prob: how often to start from 0 during training,\n", - " (if 1.0, we always start from position 0, if less, we randomize).\n", - " pos_max_offset_to_add: maximum offset to add to positions during training\n", - " when randomizing; this offset plus input length must still be less than\n", - " max_len for all training examples.\n", - " use_bfloat16: If `True`, use bfloat16 weights instead of the default\n", - " float32; this can save memory but may (rarely) lead to numerical issues.\n", - "\n", - " Returns:\n", - " A tuple of (input encoder, output encoder, output vocab size used).\n", - " \"\"\"\n", - " # tokens --\u003e vectors\n", - " def Embedder(vocab_size, embedding_mode):\n", - " if vocab_size is not None:\n", - " embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16)\n", - " else:\n", - " embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16)\n", - " return [\n", - " embedding,\n", - " tl.Dropout(rate=embedding_dropout,\n", - " shared_axes=dropout_shared_axes,\n", - " mode=embedding_mode),\n", - " ]\n", - "\n", - " # NOTE: Positional encodings are not shared between encoder and decoder.\n", - "\n", - " # Since encoder doesn't run stepwise, we do not use predict mode there.\n", - " encoder_mode = 'eval' if mode == 'predict' else mode\n", - " in_embedder = Embedder(input_vocab_size, encoder_mode)\n", - " in_encoder = in_embedder + [\n", - " PositionalEncoder(encoder_mode,\n", - " dropout=embedding_dropout,\n", - " max_len=max_len,\n", - " pos_type=pos_type,\n", - " pos_axial_shape=pos_axial_shape,\n", - " pos_d_axial_embs=pos_d_axial_embs,\n", - " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", - " pos_max_offset_to_add=pos_max_offset_to_add,\n", - " use_bfloat16=use_bfloat16)\n", - " ]\n", - "\n", - " # If output_vocab_size is None, we reuse the same embedding matrix, otherwise\n", - " # we initialize one.\n", - " assert input_vocab_size or output_vocab_size\n", - " if output_vocab_size is None:\n", - " out_embedder = in_embedder\n", - " else:\n", - " out_embedder = Embedder(output_vocab_size, mode)\n", - "\n", - " out_encoder = out_embedder + [\n", - " PositionalEncoder(mode,\n", - " dropout=embedding_dropout,\n", - " max_len=max_len,\n", - " pos_type=pos_type,\n", - " pos_axial_shape=pos_axial_shape,\n", - " pos_d_axial_embs=pos_d_axial_embs,\n", - " 
pos_start_from_zero_prob=pos_start_from_zero_prob,\n", - " pos_max_offset_to_add=pos_max_offset_to_add,\n", - " use_bfloat16=use_bfloat16)\n", - " ]\n", - "\n", - " # Set this to the value actually used.\n", - " if output_vocab_size is None:\n", - " output_vocab_size = input_vocab_size\n", - "\n", - " if input_vocab_size is None:\n", - " in_encoder = tl.AssertFunction('...a-\u003e...b', in_encoder)\n", - " else:\n", - " in_encoder = tl.AssertFunction('...-\u003e...d', in_encoder)\n", - " out_encoder = tl.AssertFunction('...-\u003e...d', out_encoder)\n", - "\n", - " return in_encoder, out_encoder, output_vocab_size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2D3dQi9Q2bO7" - }, - "outputs": [], - "source": [ - "def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,\n", - " n_heads, attention_type, dropout, ff_activation,\n", - " ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,\n", - " attention_chunk_size, n_attention_layers=1,\n", - " n_feedforward_layers=1, center_layernorm=True,\n", - " use_bfloat16=False, mode='train'):\n", - " \"\"\"Reversible transformer decoder layer.\n", - "\n", - " Args:\n", - " d_model: int: depth of embedding\n", - " d_ff: int: depth of feed-forward layer\n", - " d_attention_key: int: depth of key vector for each attention head\n", - " d_attention_value: int: depth of value vector for each attention head\n", - " n_heads: int: number of attention heads\n", - " attention_type: subclass of tl.BaseCausalAttention: attention class to use\n", - " dropout: float: dropout rate (how much to drop out)\n", - " ff_activation: the non-linearity in feed-forward layer\n", - " ff_dropout: the dropout rate in feed-forward layer\n", - " ff_use_sru: int; if \u003e 0, we use this many SRU layers instead of feed-forward\n", - " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks\n", - " ff_sparsity: int, if \u003e 0 use sparse feed-forward block with this sparsity\n", - " attention_chunk_size: int, if \u003e 0 run attention chunked at this size\n", - " n_attention_layers: how many residual causal attention layers should we\n", - " have before the feed-forward block (default: 1, the standard block)\n", - " n_feedforward_layers: how many FFNN layers should we have (default 1).\n", - " center_layernorm: whether to use centering in LayerNorm (default) or if\n", - " to skip it, which is known as RMS normalization.\n", - " use_bfloat16: whether to use bfloat16 for weights (default: False).\n", - " mode: str: 'train' or 'eval'\n", - "\n", - "\n", - " Returns:\n", - " the layer.\n", - " \"\"\"\n", - " # pylint: disable=g-complex-comprehension\n", - " def _Attn():\n", - " return ApplyAttentionLayer(\n", - " attention_type, d_model, n_heads, d_attention_key,\n", - " d_attention_value, True, False, dropout, dropout,\n", - " attention_chunk_size, mode)\n", - "\n", - " def _FF():\n", - " return FeedForwardWithOptions(\n", - " d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,\n", - " ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,\n", - " mode, use_bfloat16)\n", - "\n", - " def _attention_half_residual():\n", - " return [\n", - " tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),\n", - " attention_layer=_Attn(),\n", - " name='ReversibleHalfResidualDecoderAttn'),\n", - " tl.ReversibleSwap()\n", - " ]\n", - "\n", - " def _feed_forward():\n", - " return [\n", - " tl.ReversibleHalfResidual(_FF(),\n", - " name='ReversibleHalfResidualDecoderFF'),\n", - " tl.ReversibleSwap()\n", - " ]\n", - 
"\n", - " return ([_attention_half_residual() for _ in range(n_attention_layers)]\n", - " + [_feed_forward() for _ in range(n_feedforward_layers)])\n", - "\n", - "\n", - "def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,\n", - " ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,\n", - " attention_chunk_size=0, center_layernorm=True,\n", - " use_bfloat16=False, use_two_swaps_per_block=True,\n", - " mode='train'):\n", - " \"\"\"Returns a list of layers that implements a Terraformer encoder block.\n", - "\n", - " The input to the layer is a pair, (activations, mask), where the mask was\n", - " created from the original source tokens to prevent attending to the padding\n", - " part of the input.\n", - "\n", - " Args:\n", - " d_model: int: depth of embedding\n", - " d_ff: int: depth of feed-forward layer\n", - " n_heads: int: number of attention heads\n", - " attention_type: subclass of tl.BaseCausalAttention: attention class to use\n", - " dropout: float: dropout rate (how much to drop out)\n", - " ff_activation: the non-linearity in feed-forward layer\n", - " ff_dropout: the dropout rate in feed-forward layer\n", - " ff_use_sru: int; if \u003e 0, we use this many SRU layers instead of feed-forward\n", - " ff_chunk_size: int; if \u003e 0, chunk feed-forward into this-sized chunks\n", - " ff_sparsity: int, if \u003e 0 use sparse feed-forward block with this sparsity\n", - " attention_chunk_size: int, if \u003e 0 run attention chunked at this size\n", - " center_layernorm: whether to use centering in LayerNorm (default) or if\n", - " to skip it, which is known as RMS normalization.\n", - " use_bfloat16: whether to use bfloat16 for weights (default: False)\n", - " use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder\n", - " block, otherwise use only one swap.\n", - " mode: str: 'train' or 'eval'\n", - "\n", - " Returns:\n", - " A list of layers that maps (activations, mask) to (activations, mask).\n", - " \"\"\"\n", - " if mode == 'predict':\n", - " # Mode 'predict' means that the decoder should be run one token at a time.\n", - " # The encoder only ever runs over full sequences, which is why it's switched\n", - " # to 'eval' mode instead.\n", - " mode = 'eval'\n", - "\n", - " def _Attn():\n", - " return ApplyAttentionLayer(\n", - " attention_type=attention_type, d_model=d_model, n_heads=n_heads,\n", - " d_qk=d_model//n_heads, d_v=d_model//n_heads, masked=True, causal=False,\n", - " attention_dropout=dropout, output_dropout=dropout,\n", - " attention_chunk_size=attention_chunk_size, mode=mode)\n", - "\n", - " def _FF():\n", - " return FeedForwardWithOptions(\n", - " d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,\n", - " ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,\n", - " mode, use_bfloat16)\n", - "\n", - " attention = _Attn()\n", - " if attention.n_out == 2:\n", - " attention = tl.Serial(\n", - " tl.Parallel([], _InsertAxes12()),\n", - " attention,\n", - " tl.Select([0], n_in=2)\n", - " )\n", - "\n", - " def _attention_half_residual():\n", - " return [\n", - " tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),\n", - " attention_layer=attention,\n", - " name='ReversibleHalfResidualEncoderAttn'),\n", - " tl.ReversibleSwap()\n", - " ]\n", - "\n", - " def _feed_forward():\n", - " layers = [\n", - " tl.ReversibleHalfResidual(_FF(),\n", - " name='ReversibleHalfResidualEncoderFF')\n", - " ]\n", - " if use_two_swaps_per_block:\n", - " layers.append(tl.ReversibleSwap())\n", - " return layers\n", - "\n", - " 
return _attention_half_residual() + _feed_forward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ITiWrbEnAZyb"
- },
- "outputs": [],
- "source": [
- "# Arg shapes: (B, L1, H), (B, L2, H), (B, L1).\n",
- "def _ConcatWithPadding(vec_e, vec_d, mask_e):\n",
- "  \"\"\"Concatenate with padding: see the ConcatWithPadding layer for details.\"\"\"\n",
- "  # pylint: disable=invalid-name\n",
- "  B, L1, H = vec_e.shape\n",
- "  L2 = vec_d.shape[1]\n",
- "  # pylint: enable=invalid-name\n",
- "\n",
- "  if vec_d.shape != (B, L2, H):\n",
- "    raise ValueError(f'Shape of decoder vector, {vec_d.shape}, does not'\n",
- "                     f' equal {(B, L2, H)}.')\n",
- "  if mask_e.shape != (B, L1):\n",
- "    raise ValueError(f'Shape of encoder mask, {mask_e.shape}, does not'\n",
- "                     f' equal {(B, L1)}.')\n",
- "\n",
- "  def _UpdateRow(x):\n",
- "    # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)\n",
- "    row_e, row_d, row_mask_e = x\n",
- "    # final_row - (L1+L2, H)\n",
- "    final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)\n",
- "    # Find the last real token/vector of the encoder.\n",
- "    e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)\n",
- "    # Starting after that index, update with the decoder row.\n",
- "    zero = jnp.array(0, dtype=e_idx.dtype)  # avoid int32/int64 mismatch\n",
- "    return fastmath.dynamic_update_slice(final_row, row_d, (e_idx, zero))\n",
- "\n",
- "  return fastmath.map(_UpdateRow, [vec_e, vec_d, mask_e])\n",
- "\n",
- "\n",
- "def _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d):\n",
- "  \"\"\"Strip concatenate with padding: see the layer below for details.\"\"\"\n",
- "  # pylint: disable=invalid-name\n",
- "  B, L, H = vec_ed.shape\n",
- "  L1 = tok_e.shape[1]\n",
- "  L2 = tok_d.shape[1]\n",
- "  # pylint: enable=invalid-name\n",
- "  if L != L1 + L2:\n",
- "    raise ValueError(f'Length from encoder-decoder vectors ({L}) does not'\n",
- "                     f' equal sum of lengths from encoder ({L1}) and decoder'\n",
- "                     f' ({L2}).')\n",
- "  if tok_e.shape != (B, L1):\n",
- "    raise ValueError(f'Shape of encoder tokens, {tok_e.shape}, does not'\n",
- "                     f' equal {(B, L1)}.')\n",
- "  if tok_d.shape != (B, L2):\n",
- "    raise ValueError(f'Shape of decoder tokens, {tok_d.shape}, does not'\n",
- "                     f' equal {(B, L2)}.')\n",
- "\n",
- "  def _UpdateRow(x):\n",
- "    # (L, H), (L1, H) \u0026 (L2, H)\n",
- "    row_ed, row_e, _ = x\n",
- "    mask_e = row_e != 0\n",
- "    len_e = jnp.sum(mask_e, dtype=jnp.int32)\n",
- "    # In `row_ed` start where encoder tokens/vecs end, i.e. at index `len_e`,\n",
- "    # and pick up the (L2, H) tensor slice from there.\n",
- "    zero = jnp.array(0, dtype=len_e.dtype)  # avoid int32/int64 mismatch\n",
- "    return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H))\n",
- "\n",
- "  return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])\n",
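- "\n",
- "\n",
- "# A tiny worked example of the padding-aware concatenation above (only a\n",
- "# sketch; the values are arbitrary). With encoder length L1=3, of which 2\n",
- "# positions are real, and decoder length L2=2, the decoder rows are written\n",
- "# right after the last real encoder row: [e1, e2, d1, d2, 0].\n",
- "_vec_e = jnp.array([[[1.], [2.], [0.]]])  # (B=1, L1=3, H=1)\n",
- "_vec_d = jnp.array([[[3.], [4.]]])        # (B=1, L2=2, H=1)\n",
- "_mask_e = jnp.array([[1, 1, 0]])          # two real encoder positions\n",
- "assert _ConcatWithPadding(_vec_e, _vec_d, _mask_e).shape == (1, 5, 1)\n",
- "\n",
- "\n",
- "class StripFromConcatenateWithPadding(tl.Layer):\n",
- "  \"\"\"Strips out the leading encoder tokens from the concatenated array.\"\"\"\n",
- "\n",
- "  def __init__(self, mode='train'):\n",
- "    super().__init__(n_in=3, n_out=1)\n",
- "    self._mode = mode\n",
- "\n",
- "  def init_weights_and_state(self, input_signature):\n",
- "    \"\"\"Sets layer-specific internal state.\"\"\"\n",
- "    del input_signature\n",
- "    self.state = jnp.array(0, dtype=jnp.int32)\n",
- "\n",
- "  def forward(self, inputs):\n",
- "    vec_ed, tok_e, tok_d = inputs\n",
- "\n",
- "    # In training/eval mode or at the first step predict mode i.e. when\n",
- "    # state.shape is (), i.e. 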
at the first step, we do the actual computation\n",
- "    if self._mode != 'predict' or not self.state.shape:\n",
- "      # Now state.shape will not evaluate to false.\n",
- "      self.state = self.state.reshape((1,))\n",
- "      return _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d)\n",
- "\n",
- "    # In predict mode and on subsequent steps (i.e. after the first step) vec_ed\n",
- "    # is actually vec_d, since no concatenation happened at all.\n",
- "    return vec_ed\n",
- "\n",
- "\n",
- "class ConcatWithPadding(tl.ReversibleLayer):\n",
- "  \"\"\"Concatenates two length padded (B, L, H) arrays (of different lengths).\"\"\"\n",
- "\n",
- "  def __init__(self, mode='train'):\n",
- "    super().__init__(n_in=5, n_out=3)\n",
- "    self._mode = mode\n",
- "\n",
- "  def init_weights_and_state(self, input_signature):\n",
- "    \"\"\"Sets layer-specific internal state.\"\"\"\n",
- "    del input_signature\n",
- "    self.state = jnp.array(0, dtype=jnp.int32)\n",
- "\n",
- "  def forward(self, inputs):\n",
- "    vec_e, vec_d, mask_e, tok_e, tok_d = inputs\n",
- "\n",
- "    # In training/eval mode or at the first step predict mode i.e. when\n",
- "    # state.shape is (), i.e. at first step, we return the concatenated output.\n",
- "    if self._mode != 'predict' or not self.state.shape:\n",
- "      # Now state.shape will not evaluate to false.\n",
- "      self.state = self.state.reshape((1,))\n",
- "      return _ConcatWithPadding(vec_e, vec_d, mask_e), tok_e, tok_d\n",
- "\n",
- "    # In predict mode and on subsequent steps (i.e. after the first step) we\n",
- "    # don't concatenate anymore, but just return the decoder vector.\n",
- "    return vec_d, tok_e, tok_d\n",
- "\n",
- "  def reverse(self, output, weights=(), state=(), new_state=(), rng=None):\n",
- "    del state, new_state, rng, weights\n",
- "    assert self._mode != 'predict', 'cannot reverse in predict mode'\n",
- "    vecs_ed, toks_e, toks_d = output\n",
- "    vecs_d = _StripFromConcatenateWithPadding(vecs_ed, toks_e, toks_d)\n",
- "    mask_e = (toks_e != 0)\n",
- "    mask_e_float = mask_e.astype(jnp.float32)\n",
- "    vecs_e = vecs_ed[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n",
- "    return vecs_e, vecs_d, mask_e, toks_e, toks_d\n",
- "\n",
- "\n",
- "class ConcatWithPadding2(tl.ReversibleLayer):\n",
- "  \"\"\"Concatenate with padding operating on pairs to combine with rev-nets.\"\"\"\n",
- "\n",
- "  def __init__(self, mode='train'):\n",
- "    super().__init__(n_in=6, n_out=4)\n",
- "    self._mode = mode\n",
- "\n",
- "  def init_weights_and_state(self, input_signature):\n",
- "    \"\"\"Sets layer-specific internal state.\"\"\"\n",
- "    del input_signature\n",
- "    self.state = jnp.array(0, dtype=jnp.int32)\n",
- "\n",
- "  def forward(self, inputs):\n",
- "    vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d = inputs\n",
- "\n",
- "    # In training/eval mode or at the first step predict mode i.e. when\n",
- "    # state.shape is (), i.e. at first step, we return the concatenated output.\n",
- "    if self._mode != 'predict' or not self.state.shape:\n",
- "      # Now state.shape will not evaluate to false.\n",
- "      self.state = self.state.reshape((1,))\n",
- "      # Calculate mask and concat_with_padding on the pairs.\n",
- "      vecs_ed1 = _ConcatWithPadding(vecs_e1, vecs_d, mask_e)\n",
- "      vecs_ed2 = _ConcatWithPadding(vecs_e2, vecs_d, mask_e)\n",
- "      return vecs_ed1, vecs_ed2, toks_e, toks_d\n",
- "\n",
- "    # In predict mode and on subsequent steps (i.e. 
after the first step) we\n", - " # don't concatenate anymore, but just return the decoder vector.\n", - " return vecs_d, vecs_d, toks_e, toks_d\n", - "\n", - " def reverse(self, output, weights=(), state=(), new_state=(), rng=None):\n", - " del state, new_state, rng, weights\n", - " assert self._mode != 'predict', 'cannot reverse in predict mode'\n", - " vecs_ed1, vecs_ed2, toks_e, toks_d = output\n", - " vecs_d = _StripFromConcatenateWithPadding(vecs_ed1, toks_e, toks_d)\n", - " mask_e = (toks_e != 0)\n", - " mask_e_float = mask_e.astype(jnp.float32)\n", - " vecs_e1 = vecs_ed1[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", - " vecs_e2 = vecs_ed2[:, :toks_e.shape[1], :] * mask_e_float[:, :, None]\n", - " return vecs_e1, vecs_e2, vecs_d, mask_e, toks_e, toks_d" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4FPVnsq8Ersd" - }, - "outputs": [], - "source": [ - "def Terraformer(input_vocab_size,\n", - " output_vocab_size=None,\n", - " d_model=512,\n", - " d_ff=2048,\n", - " d_attention_key=None,\n", - " d_attention_value=None,\n", - " n_encoder_layers=6,\n", - " n_decoder_layers=6,\n", - " n_heads=8,\n", - " dropout=0.1,\n", - " max_len=2048,\n", - " encoder_attention_type=tl.SelfAttention,\n", - " encoder_decoder_attention_type=tl.SelfAttention,\n", - " pos_type='fixed-base',\n", - " pos_axial_shape=(),\n", - " pos_d_axial_embs=None,\n", - " pos_start_from_zero_prob=1.0,\n", - " pos_max_offset_to_add=0,\n", - " ff_activation=tl.Relu,\n", - " ff_use_sru=(1, 32),\n", - " ff_chunk_size=0,\n", - " ff_dropout=None,\n", - " ff_sparsity=32,\n", - " loss_sparsity_type='mult',\n", - " loss_sparsity=0,\n", - " loss_d_lowrank=0,\n", - " loss_sparsity_prob=None,\n", - " attention_chunk_size=0,\n", - " n_layers_forget=0,\n", - " forget_dense=True,\n", - " n_decoder_attention_layers=2,\n", - " use_bfloat16=False,\n", - " reversible_encoder=False,\n", - " use_two_swaps_per_encoder_block=True,\n", - " center_layernorm=True,\n", - " half_before_layer=None,\n", - " double_after_layer=None,\n", - " mode='train'):\n", - " \"\"\"Returns a highly configurable Terraformer encoder-decoder model.\n", - "\n", - " This model maps paired text sequences (source and target) to float-valued\n", - " losses. If ``input_vocab_size`` is not ``None``, the layer takes\n", - " two input sequences:\n", - "\n", - " - inputs (2):\n", - "\n", - " - source: 2-D int array representing a batch of text strings via token\n", - " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", - " where sequence_length \u003c= ``max_len``. Array elements are in\n", - " ``range(input_vocab_size)``, and 0 values mark padding positions.\n", - "\n", - " - target: 2-D int array representing a batch of text strings via token\n", - " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", - " where sequence_length \u003c= ``max_len``. 
Array elements are in\n", - " ``range(output_vocab_size)``, and 0 values mark padding positions.\n", - "\n", - " - output: 1-D float array of losses; shape is `(batch_size)`.\n", - "\n", - " If ``input_vocab_size`` is ``None``, the layer takes three input sequences:\n", - "\n", - " - inputs (3):\n", - "\n", - " - source: 3-D float array representing a batch of already-embedded text\n", - " strings; shape is `(batch_size, sequence_length, d_model)`, where\n", - " sequence_length \u003c= ``max_len``.\n", - "\n", - " - mask: 2-D int array representing active versus masked positions; 0\n", - " values mark masked (padding) positions.\n", - "\n", - " - target: 2-D int array representing a batch of text strings via token\n", - " IDs plus padding markers; shape is `(batch_size, sequence_length)`,\n", - " where sequence_length \u003c= ``max_len``. Array elements are in\n", - " ``range(output_vocab_size)``, and 0 values mark padding positions.\n", - "\n", - " - output: 1-D float array of losses; shape is `(batch_size)`.\n", - "\n", - " Args:\n", - " input_vocab_size: Input vocabulary size -- each element of the input tensor\n", - " should be an integer in ``range(vocab_size)``. These integers typically\n", - " represent token IDs from a vocabulary-based tokenizer.\n", - " output_vocab_size: If specified, gives the vocabulary size for the targets;\n", - " if ``None``, then input and target integers (token IDs) are assumed to\n", - " come from the same vocabulary.\n", - " d_model: Last/innermost dimension of activation arrays at most points in\n", - " the model, including the initial embedding output.\n", - " d_ff: Last/innermost dimension of special (typically wider)\n", - " :py:class:`Dense` layer in the feedforward part of each encoder block.\n", - " d_attention_key: Depth of key vectors in each attention head.\n", - " d_attention_value: Depth of value vectors in each attention head.\n", - " n_encoder_layers: Number of encoder blocks.\n", - " n_decoder_layers: Number of decoder blocks.\n", - " n_heads: Number of attention heads.\n", - " dropout: Stochastic rate (probability) for dropping an activation value\n", - " when applying dropout within encoder/decoder blocks. The same rate is\n", - " also used for attention dropout in encoder/decoder blocks.\n", - " max_len: Maximum symbol length for positional encoding.\n", - " encoder_attention_type: Type of attention to use in the encoder; must be\n", - " an attention-type subclass of :py:class:`trax.layers.Layer`.\n", - " encoder_decoder_attention_type: Type of attention to use in the decoder;\n", - " must be an attention-type subclass of :py:class:`trax.layers.Layer`.\n", - " pos_type: String indicating the type of positional embeddings to use.\n", - " pos_axial_shape: Shape (tuple of ints) to use for the axial position\n", - " encoding. If unset, axial position encoding is disabled.\n", - " pos_d_axial_embs: Tuple of ints specifying the depth of position embedding\n", - " for each axis. Tuple length must match ``pos_axial_shape``, and values\n", - " must sum to ``d_model``.\n", - " pos_start_from_zero_prob: Stochastic rate (probability) for starting\n", - " positional encoding at position 0 during training. If 1.0, always start\n", - " from position 0; if \u003c 1.0, the non-zero starts will be uniformly\n", - " distributed up to ``pos_max_offset_to_add``.\n", - " pos_max_offset_to_add: Maximum offset to add to positions during training\n", - " when randomizing. 
This offset plus input length must be less than\n",
- "      ``max_len`` for all training examples.\n",
- "    ff_activation: Type of activation function at the end of each block; must\n",
- "      be an activation-type subclass of :py:class:`trax.layers.Layer`.\n",
- "    ff_use_sru: If \u003e 0, use this number of SRU layers in place of feedforward\n",
- "      layers.\n",
- "    ff_chunk_size: If \u003e 0, chunk each feedforward layer into chunks of this\n",
- "      size.\n",
- "    ff_dropout: Stochastic rate (probability) for dropping an activation value\n",
- "      at feedforward nonlinearities.\n",
- "    ff_sparsity: If \u003e 0, use sparse feedforward blocks with this level of\n",
- "      sparsity.\n",
- "    loss_sparsity_type: String indicating the type of sparsity to use in the\n",
- "      loss layer; see :py:class:`SparseDenseWithOptions` for options. If\n",
- "      ``None``, use no sparsity.\n",
- "    loss_sparsity: If \u003e 0, use this level of sparsity in the loss layer.\n",
- "    loss_d_lowrank: If \u003e 0, use a (low-rank) intermediate layer, with this\n",
- "      dimension, in the loss.\n",
- "    loss_sparsity_prob: Stochastic rate (probability) for using the sparse\n",
- "      version of the loss. If ``None``, use the sparse version exclusively.\n",
- "    attention_chunk_size: If \u003e 0, compute attention using chunks of this size.\n",
- "    n_layers_forget: How often to have a forgetting block between layers.\n",
- "    forget_dense: If True, use :py:class:`Dense` instances as forget layers;\n",
- "      else use no-ops.\n",
- "    n_decoder_attention_layers: Number of attention layers in a decoder block.\n",
- "    use_bfloat16: If True, use bfloat16 for weights; else use float32.\n",
- "    reversible_encoder: If True, make the encoder be reversible.\n",
- "    use_two_swaps_per_encoder_block: If True, ensure that there is an even\n",
- "      number of swaps across the encoder.\n",
- "    center_layernorm: If True, use centering in :py:class:`LayerNorm` (the\n",
- "      default); else omit centering (which is known as RMS normalization).\n",
- "    half_before_layer: If not None, specifies an n'th layer such that all\n",
- "      layers before the n'th use half the normal values for ``d_model`` and\n",
- "      ``d_ff``.\n",
- "    double_after_layer: If not None, specifies an n'th layer such that all\n",
- "      layers after the n'th use double the normal values for ``d_model`` and\n",
- "      ``d_ff``.\n",
- "    mode: If ``'train'``, include dropout in each encoder/decoder block; else\n",
- "      dropout layers have no effect.\n",
- "\n",
- "  Returns:\n",
- "    A Terraformer encoder-decoder as a layer that maps from target and source\n",
- "    text sequences to a scalar loss.\n",
- "  \"\"\"\n",
- "  if mode == 'predict':\n",
- "    portal_mask = _PortalInput()\n",
- "  else:\n",
- "    portal_mask = None\n",
- "\n",
- "  # Set default dimensions for attention head key and value sizes.\n",
- "  if (d_model / 2) % n_heads != 0:\n",
- "    raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')\n",
- "  if d_attention_key is None:\n",
- "    d_attention_key = d_model // n_heads\n",
- "  if d_attention_value is None:\n",
- "    d_attention_value = d_model // n_heads\n",
- "\n",
- "  # Set values of d_model, d_ff and d_qkv for the first stage.\n",
- "  d_model1, d_ff1 = d_model, d_ff\n",
- "  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value\n",
- "  if half_before_layer:\n",
- "    d_model1, d_ff1 = d_model // 2, d_ff // 2  # keep sizes integral\n",
- "    d_attention_key1 = d_attention_key // 2\n",
- "    d_attention_value1 = d_attention_value // 2\n",
- "\n",
- "  # Set values of d_model, d_ff and d_qkv for the 
final stage.\n", - " d_model2, d_ff2 = d_model, d_ff\n", - " d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value\n", - " if double_after_layer:\n", - " d_model2, d_ff2 = d_model * 2, d_ff * 2\n", - " d_attention_key2 = d_attention_key * 2\n", - " d_attention_value2 = d_attention_value * 2\n", - "\n", - " # Vector embeddings.\n", - " in_encoder, out_encoder, output_vocab_size = (\n", - " EmbeddingAndPositionalEncodings(\n", - " input_vocab_size,\n", - " d_model1,\n", - " mode,\n", - " dropout,\n", - " [-2], # dropout_shared_axes\n", - " max_len,\n", - " output_vocab_size=output_vocab_size,\n", - " pos_type=pos_type,\n", - " pos_axial_shape=pos_axial_shape,\n", - " pos_d_axial_embs=pos_d_axial_embs,\n", - " pos_start_from_zero_prob=pos_start_from_zero_prob,\n", - " pos_max_offset_to_add=pos_max_offset_to_add,\n", - " use_bfloat16=use_bfloat16)\n", - " )\n", - "\n", - " def _EncoderBlock():\n", - " return EncoderBlock(\n", - " d_model1,\n", - " d_ff1,\n", - " n_heads,\n", - " encoder_attention_type,\n", - " dropout=dropout,\n", - " ff_activation=ff_activation,\n", - " ff_dropout=ff_dropout,\n", - " ff_use_sru=ff_use_sru,\n", - " ff_chunk_size=ff_chunk_size,\n", - " ff_sparsity=ff_sparsity,\n", - " attention_chunk_size=attention_chunk_size,\n", - " center_layernorm=center_layernorm,\n", - " use_bfloat16=use_bfloat16,\n", - " use_two_swaps_per_block=use_two_swaps_per_encoder_block,\n", - " mode=mode)\n", - "\n", - " def _Encoder(): # vec_e mask_e tok_e tok_d tok_d\n", - " layers = [\n", - " tl.ReversibleSelect([0, 0]),\n", - " _ReversibleSerialForget(\n", - " [_EncoderBlock() for _ in range(n_encoder_layers)],\n", - " d_model1,\n", - " n_layers_forget,\n", - " forget_dense)\n", - " ]\n", - " if not reversible_encoder:\n", - " layers += [\n", - " _XYAvg(),\n", - " tl.Dense(d_model1, use_bfloat16=use_bfloat16),\n", - " tl.LayerNorm(),\n", - " ]\n", - " if mode == 'predict':\n", - " return tl.Cache(tl.Serial(layers))\n", - " else:\n", - " return tl.Serial(layers)\n", - "\n", - " if mode == 'predict':\n", - " global DotProductCausalAttention\n", - " DotProductCausalAttention.monkey_patched_mask = (\n", - " lambda x: portal_mask)\n", - " global _RememberPad\n", - " _RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access\n", - " lambda x: portal_mask)\n", - " global ScanSRUCell\n", - " originalScanSRUCell = ScanSRUCell\n", - " ScanSRUCell = functools.partial(ScanSRUCell,\n", - " monkey_patched_mask=portal_mask)\n", - "\n", - " decoder_blocks = []\n", - "\n", - " if isinstance(encoder_decoder_attention_type, (tuple, list)):\n", - " assert n_decoder_layers % len(encoder_decoder_attention_type) == 0\n", - " else:\n", - " encoder_decoder_attention_type = [encoder_decoder_attention_type]\n", - " for layer_idx in range(n_decoder_layers):\n", - " layer_attention_type = encoder_decoder_attention_type[\n", - " layer_idx % len(encoder_decoder_attention_type)]\n", - " # Grow d_model, d_ff, and d_qkv if requested.\n", - " d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1\n", - " if half_before_layer and layer_idx \u003e= half_before_layer:\n", - " d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value\n", - " if double_after_layer and layer_idx \u003e double_after_layer:\n", - " d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2\n", - " decoder_block = DecoderBlock(\n", - " d_m, d_f, d_k, d_v, n_heads,\n", - " attention_type=layer_attention_type,\n", - " dropout=dropout,\n", - " ff_activation=ff_activation,\n", 
- " ff_dropout=ff_dropout,\n", - " ff_use_sru=ff_use_sru,\n", - " ff_chunk_size=ff_chunk_size,\n", - " ff_sparsity=ff_sparsity,\n", - " attention_chunk_size=attention_chunk_size,\n", - " n_attention_layers=n_decoder_attention_layers,\n", - " center_layernorm=center_layernorm,\n", - " use_bfloat16=use_bfloat16,\n", - " mode=mode)\n", - " decoder_blocks.append(decoder_block)\n", - " if half_before_layer and layer_idx == half_before_layer - 1:\n", - " decoder_blocks.append(tl.ReversibleConcatenatePair())\n", - " if double_after_layer and layer_idx == double_after_layer:\n", - " decoder_blocks.append(tl.ReversibleConcatenatePair())\n", - "\n", - " if mode == 'predict':\n", - " # After initializing the decoder we can revert to original state of\n", - " # previously monkey-patched classes/functions.\n", - " DotProductCausalAttention.monkey_patched_mask = (\n", - " lambda x: None)\n", - " _RememberPad.monkey_patched_mask = (lambda x: None) # pylint: disable=protected-access\n", - " ScanSRUCell = originalScanSRUCell\n", - "\n", - " def _Loss():\n", - " return SparseDenseWithOptions(\n", - " output_vocab_size,\n", - " d_input=d_model2,\n", - " sparsity_type=loss_sparsity_type,\n", - " sparsity=loss_sparsity,\n", - " d_lowrank=loss_d_lowrank,\n", - " prob_sparse=loss_sparsity_prob,\n", - " use_bfloat16=use_bfloat16,\n", - " mode=mode)\n", - "\n", - " def _enc_dec_concat():\n", - " \"\"\"Layers to merge encoder and decoder.\"\"\"\n", - " if reversible_encoder:\n", - " return [\n", - " tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d\n", - " ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d\n", - " ]\n", - " else:\n", - " return [\n", - " tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d\n", - " ConcatWithPadding(mode=mode), # v_ed tok_e tok_d\n", - " tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d\n", - " ]\n", - "\n", - " def _inp_layers():\n", - " if input_vocab_size is not None:\n", - " return tl.AssertFunction(\n", - " 'bl,br-\u003ebld,bl,bl,br', # b: batch, l/r: enc/dec length, d: vec depth\n", - " tl.Serial( # tok_e tok_d\n", - " tl.Select([0, 0, 0, 1]),\n", - " tl.Parallel(in_encoder, [tl.PaddingMask(),\n", - " _RemoveAxes12()])\n", - " )) # vec_e mask_e tok_e tok_d\n", - " else:\n", - " # Input in this case is vec_e, mask_e, tok_d. 
Where all downstream\n", - " # operations expect tok_e, we give it instead mask_e, expecting that\n", - " # downstream ops are only looking for padding/not padding.\n", - " return tl.AssertFunction(\n", - " 'blf,bl,br-\u003ebld,bl,bl,br', # f: in-feature depth, d: out-vector depth\n", - " tl.Serial( # vec_e mask_e tok_d\n", - " tl.Select([0, 1, 1, 2]),\n", - " tl.Parallel(in_encoder, [], _AsTokenIDs())\n", - " )) # vec_e mask_e tok_e tok_d\n", - "\n", - " # Assemble and return the model.\n", - " return tl.Serial(\n", - " _inp_layers(), # vec_e mask_e tok_e tok_d\n", - " tl.Parallel([], portal_mask),\n", - "\n", - " tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss.\n", - "\n", - " # Embed in and out tokens; done together as weights may be shared.\n", - " tl.Parallel([], [], [], [tl.ShiftRight(mode=mode),\n", - " out_encoder]), # vec_e mask_e tok_e vec_d tok_d\n", - "\n", - " # Encode; then concat encoder and decoder, given encoder mask.\n", - " _Encoder(), # vec_e mask_e tok_e vec_d tok_d\n", - " _enc_dec_concat(),\n", - "\n", - " # Run decoder blocks.\n", - " _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,\n", - " forget_dense), # vec_ed1 vec_ed2 tok_e tok_d\n", - " _XYAvg(), # vec_ed tok_e tok_d\n", - " tl.LayerNorm(), # vec_ed tok_e tok_d\n", - "\n", - " # Separate out the encoder part from the concatenated vector,\n", - " # then compute loss.\n", - " tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d\n", - " StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d\n", - " _Loss(), # vec_d tok_d\n", - " )\n", - "\n", - "\n", - "def _InsertAxes12():\n", - " \"\"\"Returns a layer that inserts two internal size-1 axes into an array.\"\"\"\n", - " return tl.Fn('InsertAxes12',\n", - " lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])))\n", - "\n", - "\n", - "def _RemoveAxes12():\n", - " \"\"\"Returns a layer that removes two internal size-1 axes from an array.\"\"\"\n", - " return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2)))\n", - "\n", - "\n", - "def _AsTokenIDs():\n", - " \"\"\"Returns a layer that makes mask values look like token ID ints.\"\"\"\n", - " return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32))\n", - "\n", - "\n", - "def _XYAvg():\n", - " \"\"\"Returns a layer that computes the element-wise average of two arrays.\"\"\"\n", - " return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)\n", - "\n", - "\n", - "def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):\n", - " \"\"\"ReversibleSerial but with a forgetting block every n_layers.\"\"\"\n", - " if not n_layers or len(layers) \u003c= n_layers + 1:\n", - " return tl.ReversibleSerial(layers)\n", - " layers1, layers2 = layers[:n_layers], layers[n_layers:]\n", - "\n", - " if forget_dense:\n", - " forgetting_layer = tl.Serial(\n", - " _XYAvg(),\n", - " tl.Dense(d_model),\n", - " tl.Dup(),\n", - " )\n", - " else:\n", - " forgetting_layer = tl.Select([0, 1])\n", - "\n", - " return tl.Serial(\n", - " tl.ReversibleSerial(layers1),\n", - " forgetting_layer,\n", - " _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense)\n", - " )\n", - "\n", - "\n", - "def _ConvertToNaNsOnAnyZero():\n", - " def _convert_to_nans(x, y):\n", - " # If all values in y are non-zero, return x; otherwise return NaNs (x/0.).\n", - " return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y\n", - " return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2)\n", - "\n", - "\n", - "class _PortalInput(tl.Layer):\n", - " \"\"\"Portal input for monkey-patching of mask in predict 
mode.\"\"\"\n", - "\n", - " def __init__(self):\n", - " super().__init__(name='_PortalInput', n_out=1, n_in=1)\n", - " self._portal_output = _PortalOutput(self)\n", - "\n", - " def forward(self, x):\n", - " if isinstance(x, (list, tuple)):\n", - " x = x[0]\n", - " self.state = (x,)\n", - " return x\n", - "\n", - " def init_weights_and_state(self, input_signature):\n", - " \"\"\"Initializes this layer's weights.\"\"\"\n", - " if isinstance(input_signature, (list, tuple)):\n", - " input_signature = input_signature[0]\n", - " self.state = (jnp.zeros(input_signature.shape),)\n", - "\n", - " def get_value(self):\n", - " return self.state[0]\n", - "\n", - " def get_layer(self):\n", - " return self._portal_output\n", - "\n", - "\n", - "class _PortalOutput(tl.Layer):\n", - " \"\"\"Portal input for monkey-patching of mask in predict mode.\"\"\"\n", - "\n", - " def __init__(self, portal_input):\n", - " super().__init__(name='_PortalOutput', n_out=1, n_in=0)\n", - " self._portal_input = portal_input\n", - "\n", - " def forward(self, x):\n", - " return self._portal_input.get_value()\n", - "\n", - " def get_value(self):\n", - " return self._portal_input.get_value()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E0Rq71ML6XZu" - }, - "source": [ - "## Example training\n", - "\n", - "Here we show how the Terraformer can be trained on example inputs. The results for the paper were obtained with identical training but for different configurations of inputs and models, which are specified in the attached config files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oI5XQcltJmeE" - }, - "outputs": [], - "source": [ - "model = Terraformer(\n", - " input_vocab_size=12,\n", - " # small model for testing\n", - " d_model=128,\n", - " d_ff=512,\n", - " n_encoder_layers=2,\n", - " n_decoder_layers=2,\n", - " # setting sparsity\n", - " ff_use_sru=(1, 32),\n", - " ff_sparsity=32,\n", - " loss_sparsity=4,\n", - " encoder_decoder_attention_type=functools.partial(\n", - " MultiplicativeConvCausalAttention, sparsity=16, length_kernel_size=3),\n", - " )\n", - "\n", - "copy_inputs = trax.data.inputs.simple_sequence_copy_inputs(\n", - " vocab_size=10, batch_size=32, train_length=32,\n", - " eval_min_length=16, eval_max_length=32)\n", - "\n", - "# Training task.\n", - "train_task = training.TrainTask(\n", - " labeled_data=copy_inputs.train_stream(1),\n", - " loss_layer=tl.WeightedCategoryCrossEntropy(),\n", - " optimizer=trax.optimizers.Adam(0.0001),\n", - " n_steps_per_checkpoint=5,\n", - ")\n", - "\n", - "# Evaluaton task.\n", - "eval_task = training.EvalTask(\n", - " labeled_data=copy_inputs.eval_stream(1),\n", - " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n", - " n_eval_batches=2 # For less variance in eval numbers.\n", - ")\n", - "\n", - "# Training loop saves checkpoints to output_dir.\n", - "output_dir = os.path.expanduser('~/output_dir/')\n", - "!rm -rf {output_dir}\n", - "training_loop = training.Loop(model,\n", - " train_task,\n", - " eval_tasks=[eval_task],\n", - " output_dir=output_dir)\n", - "\n", - "# Run 2000 steps (batches).\n", - "training_loop.run(20)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "Terraformer from scratch.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1mdBTceBJGE_yff5FvRAByrisUsc88Nw7", - "timestamp": 1635190861529 - } - ], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" 
- } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/examples/earlystopping.ipynb b/trax/examples/earlystopping.ipynb deleted file mode 100644 index 86f614eaf..000000000 --- a/trax/examples/earlystopping.ipynb +++ /dev/null @@ -1,752 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "trax", - "language": "python", - "name": "trax" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "colab": { - "name": "earlystopping.ipynb", - "provenance": [], - "collapsed_sections": [] - } - }, - "cells": [ - { - "cell_type": "code", - "metadata": { - "id": "6NWA5uxOmBVz" - }, - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:20:49.409509Z", - "start_time": "2021-10-21T19:20:49.407066Z" - }, - "id": "r9WfLoXBP6Hc" - }, - "source": [ - "import warnings\n", - "warnings.filterwarnings('ignore')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.763601Z", - "start_time": "2021-10-21T19:20:49.410970Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:06.758715Z", - "iopub.status.busy": "2020-12-04T15:35:06.758464Z", - "iopub.status.idle": "2020-12-04T15:35:37.278247Z", - "shell.execute_reply": "2020-12-04T15:35:37.277568Z", - "shell.execute_reply.started": "2020-12-04T15:35:06.758651Z" - }, - "id": "OLUMD0tPP6Hd" - }, - "source": [ - "import collections\n", - "import functools\n", - "import os\n", - "import sys\n", - "import time\n", - "\n", - "import numpy as np\n", - "import psutil\n", - "import trax\n", - "from absl import logging\n", - "from trax import fastmath\n", - "from trax import layers as tl\n", - "from trax.fastmath import numpy as jnp\n", - "from trax.supervised import training" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.778216Z", - "start_time": "2021-10-21T19:21:15.765216Z" - }, - "id": "nG4CK5NsP6He" - }, - "source": [ - "class MyLoop(training.Loop):\n", - " def __init__(\n", - " self,\n", - " *args, **kwargs\n", - " ):\n", - " super().__init__(\n", - " *args, **kwargs\n", - " )\n", - " self._stop_training = False\n", - "\n", - " def run(self, n_steps=1):\n", - " \"\"\"Same as ``training.Loop.run``, but with added logic to break the loop\n", - " when the early stopping condition is satisfied. 
\n", - " \"\"\"\n", - " \n", - " with self._open_summary_writers() as (\n", - " train_summary_writers,\n", - " eval_summary_writers,\n", - " ):\n", - " process = psutil.Process(os.getpid())\n", - " loss_acc, step_acc = 0.0, 0\n", - " start_time = time.time()\n", - " optimizer_metrics_acc = collections.defaultdict(float)\n", - " for i in range(n_steps):\n", - " prev_task_index = self._which_task(self._step)\n", - " self._step += 1\n", - " task_index = self._which_task(self._step)\n", - " task_changed = task_index != prev_task_index\n", - "\n", - " if task_changed:\n", - " loss_acc, step_acc = 0.0, 0\n", - " \n", - " loss, optimizer_metrics = self._run_one_step(task_index, task_changed)\n", - "\n", - " optimizer_metrics, loss = fastmath.nested_map(\n", - " functools.partial(tl.mean_or_pmean, self._n_devices),\n", - " (optimizer_metrics, loss),\n", - " )\n", - "\n", - " loss_acc += loss\n", - " # Log loss every 50 steps, every step in memory-efficient trainer.\n", - " if self._step % 50 == 0 or self._use_memory_efficient_trainer:\n", - " self._log_step(\"Loss: %.4f\" % loss, stdout=False)\n", - " step_acc += 1\n", - " for metric_name, value in optimizer_metrics.items():\n", - " optimizer_metrics_acc[metric_name] += value\n", - "\n", - "\n", - " if self._checkpoint_at(self.step):\n", - " self.save_checkpoint(\"model\")\n", - " if self._permanent_checkpoint_at(self.step):\n", - " self.save_checkpoint(f\"model_{self.step}\")\n", - " if self._eval_at(self.step):\n", - " logging.info(\n", - " \"cpu memory use (MB): %.2f\",\n", - " process.memory_info().rss / float(1024 * 1024),\n", - " )\n", - " elapsed_time = time.time() - start_time\n", - " self._log_training_progress(\n", - " task=self._tasks[task_index],\n", - " total_loss=loss_acc,\n", - " n_steps=step_acc,\n", - " elapsed_time=elapsed_time,\n", - " optimizer_metrics=optimizer_metrics_acc,\n", - " summary_writer=train_summary_writers[task_index],\n", - " )\n", - " self.run_evals(eval_summary_writers)\n", - " loss_acc, step_acc = 0.0, 0\n", - " start_time = time.time()\n", - " optimizer_metrics_acc = collections.defaultdict(float)\n", - "\n", - " if self._checkpoint_at(self.step):\n", - " if self._checkpoint_low_metric is not None and self._at_lowest():\n", - " self.save_checkpoint(f\"lowest_{self._checkpoint_low_metric}\")\n", - " if self._checkpoint_high_metric is not None and self._at_highest():\n", - " self.save_checkpoint(f\"highest_{self._checkpoint_high_metric}\")\n", - " \n", - " \n", - " for callback in self._callbacks:\n", - " if callback.call_at(self.step):\n", - " if callback.__class__.__name__ == 'EarlyStopping':\n", - " #added to check for earlystopping callback after \n", - " # history was updated.\n", - " #callback.on_step_end execute before history was \n", - " #updated. \n", - " best_step = callback.on_step_begin_with_history(self.step)\n", - " \n", - " if not self._stop_training and self.step == n_steps:\n", - " self._log_step(\"Did not meet early stopping condition.\")\n", - " \n", - " \n", - " if self._stop_training:\n", - " # added to stop the training.\n", - " self._log_step(f\"Early stopping... 
\"\n", - " f\" the best step at {best_step}\")\n", - " break\n", - " \n", - " self._eval_model.weights = self._model.weights" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.800616Z", - "start_time": "2021-10-21T19:21:15.780224Z" - }, - "id": "rfncVhM7P6Hg" - }, - "source": [ - "def callback_earlystopper(\n", - " monitor=None, \n", - " min_delta=0, \n", - " patience=0, \n", - " mode=\"auto\", \n", - " restore_best_checkpoint=True\n", - "):\n", - " \"\"\"Wrap the EarlyStopping class into a callable.\n", - "\n", - " Returns an early stopping.\n", - "\n", - " Args:\n", - " monitor: Quantity to be monitored.\n", - "\n", - " min_delta: Minimum change in the monitored quantity\n", - " to qualify as an improvement, i.e. an absolute\n", - " change of less than min_delta, will count as no\n", - " improvement.\n", - "\n", - " patience: ``patience`` times ``n_steps_per_checkpoint`` will be\n", - " the total number of steps without improvement\n", - " after which training will be stopped.\n", - "\n", - " mode: One of ``{\"auto\", \"min\", \"max\"}``. In ``min``(``max``) mode,\n", - " training will stop when the quantity monitored has stopped\n", - " decreasing(increasing) during the number of steps assigned\n", - " in ``patience``; in ``\"auto\"``\n", - " mode, the direction is automatically inferred\n", - " from the name of the monitored quantity.\n", - "\n", - " restore_best_checkpoint: Whether to restore model from\n", - " the checkpoint with the best value of the monitored quantity.\n", - " If False, the model weights obtained at the last step of\n", - " training are used. If True and there is an early stopping,\n", - " the best checkpoint will be restored.\n", - " \"\"\"\n", - "\n", - " if mode not in [\"auto\", \"max\", \"min\"]:\n", - " self._loop._log_step(\n", - " f\"Early stopping mode='{mode}' is unknown, \" \"fallback to 'auto' mode\"\n", - " )\n", - " mode = \"auto\"\n", - "\n", - " class EarlyStopping:\n", - " \"\"\"Create a call back taht activates early stopping.\n", - "\n", - " Activate early stopping.\n", - " \"\"\"\n", - "\n", - " def __init__(self, loop):\n", - " \"\"\"Configures an early stopping.\n", - " This is inspired by keras.callbacks.EarlyStopping.\n", - "\n", - " Args:\n", - " loop: training ``Loop`` from the current training.\n", - "\n", - " \"\"\"\n", - "\n", - " self._loop = loop\n", - " self.monitor = monitor\n", - " self.min_delta = jnp.abs(min_delta)\n", - " self.patience = jnp.maximum(patience, 1)\n", - "\n", - " self.restore_best_checkpoint = restore_best_checkpoint\n", - "\n", - " if mode == \"min\":\n", - " self.monitor_op = jnp.less\n", - " elif mode == \"max\":\n", - " self.monitor_op = jnp.greater\n", - " else:\n", - " if self.monitor.endswith(\"Accuracy\"):\n", - " self.monitor_op = jnp.greater\n", - " else:\n", - " self.monitor_op = jnp.less\n", - "\n", - " if self.monitor_op == np.greater:\n", - " self.min_delta *= 1\n", - " else:\n", - " self.min_delta *= -1\n", - "\n", - " self.wait = 0\n", - " self.stopped_step = 1\n", - " self.best = jnp.inf if self.monitor_op == jnp.less else -jnp.inf\n", - " self.best_step = 1\n", - " self.best_checkpoint_path = None\n", - "\n", - " def _is_metric_exist(self):\n", - " metric_names = [\n", - " name\n", - " for eval_task in self._loop._eval_tasks\n", - " for name in eval_task.metric_names\n", - " ]\n", - " return self.monitor in metric_names\n", - "\n", - " def call_at(self, step):\n", - " return 
self._loop._eval_at(step)\n", - "\n", - " def on_step_begin(self, step):\n", - " if not self._is_metric_exist():\n", - " # Raise error if the monitor name is not in evaluation task.\n", - " self._loop._log_step(\n", - " f\"Early Stopping metric '{self.monitor}' \" \"is not in eval_tasks.\"\n", - " )\n", - " self._loop._log_step(\n", - " \"Select one of \" f\"them from here {self.metric_names}.\"\n", - " )\n", - "\n", - " raise SystemExit(\"Monitoring metric not found.\")\n", - "\n", - " def on_step_end(self, step):\n", - " pass\n", - "\n", - " def on_step_begin_with_history(self, step):\n", - " if self.restore_best_checkpoint and self.best_checkpoint_path is None:\n", - " self._loop.save_checkpoint(\"best_checkpoint\")\n", - " self.best_checkpoint_path = os.path.join(\n", - " self._loop._output_dir, \"best_checkpoint.pkl.gz\"\n", - " )\n", - "\n", - " self.wait += 1\n", - " current_step, current = self._get_monitor_value()\n", - "\n", - " if current is None:\n", - " return\n", - "\n", - " if self._is_improvement(current, self.best):\n", - " self.best = current\n", - " self.best_step = current_step\n", - " self._loop.save_checkpoint(\"best_checkpoint\")\n", - "\n", - " # reset wait\n", - " self.wait = 0\n", - "\n", - " if self.wait >= self.patience and step > 1:\n", - " self.stopped_step = current_step\n", - " self._loop._stop_training = True\n", - "\n", - " if (\n", - " self.restore_best_checkpoint\n", - " and self.best_checkpoint_path is not None\n", - " ):\n", - " self._loop.load_checkpoint(self.best_checkpoint_path)\n", - " self._loop._log_step(\n", - " f\"Best checkpoint was restored from Step {self.best_step}.\"\n", - " )\n", - "\n", - " return self.best_step\n", - "\n", - " def _is_improvement(self, monitor_value, reference_value):\n", - " return self.monitor_op(monitor_value - self.min_delta, reference_value)\n", - "\n", - " def _get_monitor_value(self):\n", - " step, monitor_value = self._loop.history.get(\n", - " \"eval\", \"metrics/\" + self.monitor\n", - " )[-1]\n", - " return step, monitor_value\n", - "\n", - " return EarlyStopping" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sJHUx_nSP6Hh" - }, - "source": [ - "## Linear Regression\n", - "## Generate data for linear model" - ] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.805691Z", - "start_time": "2021-10-21T19:21:15.802488Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:37.279761Z", - "iopub.status.busy": "2020-12-04T15:35:37.279529Z", - "iopub.status.idle": "2020-12-04T15:35:37.283375Z", - "shell.execute_reply": "2020-12-04T15:35:37.282592Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.279738Z" - }, - "id": "dKYZQY-pP6Hi" - }, - "source": [ - "def get_data_linear():\n", - " while True:\n", - " x=np.random.randint(low=1, high=10) * 1.0\n", - " y=x * 2.0 - 1\n", - " yield (np.array([x]), np.array([y]))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.811465Z", - "start_time": "2021-10-21T19:21:15.807568Z" - }, - "id": "SCTZW1pBP6Hj" - }, - "source": [ - "data_linear = get_data_linear()\n", - "print(next(data_linear))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.815839Z", - "start_time": "2021-10-21T19:21:15.813113Z" - }, - "execution": { - "iopub.execute_input": 
"2020-12-04T15:35:37.292101Z", - "iopub.status.busy": "2020-12-04T15:35:37.291815Z", - "iopub.status.idle": "2020-12-04T15:35:37.296048Z", - "shell.execute_reply": "2020-12-04T15:35:37.295266Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.292054Z" - }, - "id": "4pcAhWJMP6Hk" - }, - "source": [ - "data_pipeline = trax.data.Serial(trax.data.Batch(50), trax.data.AddLossWeights(),)\n", - "data_stream = data_pipeline(data_linear)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2vK15-1oP6Hl" - }, - "source": [ - "## Build a simple linear model" - ] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.821323Z", - "start_time": "2021-10-21T19:21:15.817944Z" - }, - "id": "xzN0oZBCP6Hl" - }, - "source": [ - "model_linear = tl.Serial(tl.Dense(1))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qi0bM41PP6Hl" - }, - "source": [ - "## Train a linear model" - ] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.849200Z", - "start_time": "2021-10-21T19:21:15.822986Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:37.302605Z", - "iopub.status.busy": "2020-12-04T15:35:37.302292Z", - "iopub.status.idle": "2020-12-04T15:35:37.311629Z", - "shell.execute_reply": "2020-12-04T15:35:37.311016Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.302575Z" - }, - "id": "d0_9qZHVP6Hm" - }, - "source": [ - "# Use the same data_stream for both training and evaluation\n", - "train_task = training.TrainTask(\n", - " labeled_data=data_stream,\n", - " loss_layer=tl.L2Loss(),\n", - " optimizer=trax.optimizers.SGD(0.01),\n", - " n_steps_per_checkpoint=10,\n", - ")\n", - "\n", - "eval_task = training.EvalTask(\n", - " labeled_data=data_stream, metrics=[tl.L2Loss()], n_eval_batches=15,\n", - ")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R5ngyoYSP6Hm" - }, - "source": [ - "## Add early stopping function" - ] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.857420Z", - "start_time": "2021-10-21T19:21:15.854694Z" - }, - "id": "SKetNF4LP6Hm" - }, - "source": [ - "earlystopping=callback_earlystopper(monitor='L2Loss',min_delta=1e-4)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:15.993079Z", - "start_time": "2021-10-21T19:21:15.859169Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:37.313247Z", - "iopub.status.busy": "2020-12-04T15:35:37.313032Z", - "iopub.status.idle": "2020-12-04T15:35:37.442811Z", - "shell.execute_reply": "2020-12-04T15:35:37.442187Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.313221Z" - }, - "id": "D2XjQO80P6Hn" - }, - "source": [ - "# Delete the training folder\n", - "!rm -r linear_model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:16.491382Z", - "start_time": "2021-10-21T19:21:15.996076Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:37.444083Z", - "iopub.status.busy": "2020-12-04T15:35:37.443918Z", - "iopub.status.idle": "2020-12-04T15:35:39.043136Z", - "shell.execute_reply": "2020-12-04T15:35:39.042484Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.444063Z" - }, - "id": 
"mCrc_bXZP6Hn" - }, - "source": [ - "model_linear = tl.Serial(tl.Dense(1))\n", - "training_loop = MyLoop(\n", - " model=model_linear, tasks=train_task, eval_tasks=[eval_task], output_dir=\"./linear_model\",\n", - " callbacks=[earlystopping]\n", - ")\n", - "# training_loop.save_checkpoint(f'step_{training_loop.step}')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:17.643821Z", - "start_time": "2021-10-21T19:21:16.492560Z" - }, - "id": "kFURD6T4P6Hn" - }, - "source": [ - "training_loop.run(1500)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lg_ONworP6Hn" - }, - "source": [ - "## Change patience \n", - "patience = 10 means it will wait for 10 x 10 = 100 steps (patience * n_steps_per_checkpoint ) to before making a decision to stop." - ] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:17.648180Z", - "start_time": "2021-10-21T19:21:17.645555Z" - }, - "id": "IStFKG7GP6Hn" - }, - "source": [ - "earlystopping=callback_earlystopper(monitor='L2Loss',patience=10,min_delta=1e-4)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:17.781483Z", - "start_time": "2021-10-21T19:21:17.650003Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:37.313247Z", - "iopub.status.busy": "2020-12-04T15:35:37.313032Z", - "iopub.status.idle": "2020-12-04T15:35:37.442811Z", - "shell.execute_reply": "2020-12-04T15:35:37.442187Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.313221Z" - }, - "id": "pihrcvTtP6Ho" - }, - "source": [ - "# Delete the training folder\n", - "!rm -r linear_model" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:17.814600Z", - "start_time": "2021-10-21T19:21:17.783984Z" - }, - "execution": { - "iopub.execute_input": "2020-12-04T15:35:37.444083Z", - "iopub.status.busy": "2020-12-04T15:35:37.443918Z", - "iopub.status.idle": "2020-12-04T15:35:39.043136Z", - "shell.execute_reply": "2020-12-04T15:35:39.042484Z", - "shell.execute_reply.started": "2020-12-04T15:35:37.444063Z" - }, - "id": "UvjDLZd3P6Ho" - }, - "source": [ - "model_linear = tl.Serial(tl.Dense(1))\n", - "training_loop = MyLoop(\n", - " model=model_linear, tasks=train_task, eval_tasks=[eval_task], output_dir=\"./linear_model\",\n", - " callbacks=[earlystopping]\n", - ")\n", - "# training_loop.save_checkpoint(f'step_{training_loop.step}')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:23.010609Z", - "start_time": "2021-10-21T19:21:17.816246Z" - }, - "scrolled": false, - "id": "bAsft27BP6Ho" - }, - "source": [ - "training_loop.run(1500)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6HyIjZWBP6Ho" - }, - "source": [ - "## Make a prediction " - ] - }, - { - "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2021-10-21T19:21:23.047842Z", - "start_time": "2021-10-21T19:21:23.012208Z" - }, - "execution": { - "iopub.execute_input": "2020-12-05T04:36:10.040691Z", - "iopub.status.busy": "2020-12-05T04:36:10.040407Z", - "iopub.status.idle": "2020-12-05T04:36:10.114322Z", - "shell.execute_reply": "2020-12-05T04:36:10.113606Z", - 
"shell.execute_reply.started": "2020-12-05T04:36:10.040657Z" - }, - "id": "d7bVzat7P6Ho" - }, - "source": [ - "test_data=np.array([[2.0],[3.0],[10.0],[44.0]])\n", - "model_linear(test_data)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "lxl9PhnKP6Hp" - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/trax/examples/illustrated_wideresnet.ipynb b/trax/examples/illustrated_wideresnet.ipynb deleted file mode 100644 index b627d1296..000000000 --- a/trax/examples/illustrated_wideresnet.ipynb +++ /dev/null @@ -1,896 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - }, - "papermill": { - "duration": 10194.991178, - "end_time": "2020-12-01T04:22:46.481666", - "environment_variables": {}, - "exception": null, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2020-12-01T01:32:51.490488", - "version": "2.1.0" - }, - "colab": { - "name": "illustrated-wideresnet.ipynb", - "provenance": [] - } - }, - "cells": [ - { - "cell_type": "code", - "metadata": { - "id": "A00Q5PP0j8ZH" - }, - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." 
- ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zTfrmAx5kBwR" - }, - "source": [ - "# Author\n", - "\n", - "SauravMaheshkar- [@MaheshkarSaurav](https://twitter.com/MaheshkarSaurav)" - ] - }, - { - "cell_type": "code", - "metadata": { - "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", - "_kg_hide-input": true, - "_kg_hide-output": true, - "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", - "execution": { - "iopub.execute_input": "2020-12-01T01:32:57.405689Z", - "iopub.status.busy": "2020-12-01T01:32:57.404445Z", - "iopub.status.idle": "2020-12-01T01:33:26.950549Z", - "shell.execute_reply": "2020-12-01T01:33:26.949786Z" - }, - "papermill": { - "duration": 29.585875, - "end_time": "2020-12-01T01:33:26.950713", - "exception": false, - "start_time": "2020-12-01T01:32:57.364838", - "status": "completed" - }, - "tags": [], - "id": "pgp28DB-j6ev" - }, - "source": [ - "%%capture\n", - "!pip install --upgrade trax" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.035652, - "end_time": "2020-12-01T01:33:27.018778", - "exception": false, - "start_time": "2020-12-01T01:33:26.983126", - "status": "completed" - }, - "tags": [], - "id": "uUPujeMDj6ew" - }, - "source": [ - "# Introduction" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.035743, - "end_time": "2020-12-01T01:33:27.097892", - "exception": false, - "start_time": "2020-12-01T01:33:27.062149", - "status": "completed" - }, - "tags": [], - "id": "x8qmALfWj6ew" - }, - "source": [ - "Prior to the introduction of [Wide Residual Networks](https://arxiv.org/pdf/1605.07146.pdf) (WRNs) by Sergey Zagoruyko and Nikos Komodakis, deep residual networks were shown to have a fractional increase in performance but at the cost of **doubling** the number of layers. This led to the problem of diminishing feature reuse and overall made the models slow to train. WRNs showed that having a wider residual network leads to better performance and increased the then SOTA results on CIFAR, SVHN and COCO. \n", - "\n", - "In this notebook we run through a simple demonstration of training a WideResnet on the `cifar10` dataset using the [Trax](https://github.com/google/trax) framework. Trax is an end-to-end library for deep learning that focuses on **clear code and speed**. It is actively used and maintained in the *Google Brain team*." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.031704, - "end_time": "2020-12-01T01:33:27.164144", - "exception": false, - "start_time": "2020-12-01T01:33:27.132440", - "status": "completed" - }, - "tags": [], - "id": "SAGHOSLHj6ew" - }, - "source": [ - "# Issues with Traditional Residual Networks" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l644xzpolDH6" - }, - "source": [ - "![Screenshot 2020-12-01 at 10.04.11 AM.png](data:image/png;base64,... inline base64 screenshot data omitted ...)"
wPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTck3FhRWUm4saIYuAkEEAgg4HojhD4WCTcBCpfNCCCAAAIIIOCMgOsxGQk3zlQ1bhQBBBBAAIG4F3A9riLhJu6rKA+IAAIIIICAMwIk3JBwY0VlJeHGimLgJhBAIICA640Q+lgk3AQoXDYjgAACCCCAgDMCrsdkJNw4U9W4UQQQQAABBOJewPW4ioSbuK+iPCACCCCAAALOCJBwQ8KNFZWVhBsrioGbQACBAAKuN0LoY5FwE6Bw2YwAAggggAACzgi4HpORcONMVeNGEUAAAQQQiHsB1+MqEm7ivorygAgggAACCDgjQMINCTdWVFYSbqwoBm4CAQQCCLjeCKGPRcJNgMJlMwIIIIAAAgg4I+B6TEbCjTNVjRtFAAEEEEAg7gVcj6tIuIn7KsoDIoAAAggg4IwACTcOJ9zs3r1bhg4dKsuXLzcVrm3bttK0aVPJkiWLMxXQu1ESbjwJXhFAwEYBlxshTp8+LdOmTZPXX3/d0FaqVEn69esnl19+uY3UF7wnvisuyMNOBBBAAAEE4l7A5Zjshx9+kL59+8qePXskZ86cMnDgQKlataqTZUZM5mSxcdMIIIAAAggkE3A5rtJ779+/v/z555+SN29e0eSbW2+9NdnzufKGuMqVkuI+EUAAAQQQCCxAwo2jCTd79+6VZs2amaDy3LlzpoSzZcsmFSpUkFGjRkmGDBkCl7qFewgsLSwUbgkBBHwBVxsh9Pvh2WefldWrV8uJEyfM8+j3gzZGvP3221KoUCH/GV1Y4bvChVLiHhFAAAEEEIicgKsx2bJly6RXr15+PKZC+vd7586d5YEHHogcWISuTEwWIVguiwACCCCAQBQFXI2rFi5cKEOGDDkvrurTp4/UqlUrioLh+SjiqvA4chUEEEAAAQRiKUDCjaMJNyNGjJBZs2bJmTNnktUfbbR7+eWX5ZZbbkm23fY3BJa2lxD3h0BiC7jaCPH999/LM888I8ePH09WgJkyZZJGjRpJ9+7dk223/Q3fFbaXEPeHAAIIIIBAZAVcjMk0Abphw4ZmZJuUOtmzZ5cFCxaIvrq0EJO5VFrcKwIIIIAAAqkLuBhX6SjO9erVk0OHDp33UHny5JH58+eLtnm5tBBXuVRa3CsCCCCAAAKpC5Bw42jCTePGjWXbtm2plurdd98tgwcPTnWfrRsJLG0tGe4LAQRUwMVGCL3vcePGmX90PeVSuHBh+eijj1Jutvo93xVWFw83hwACCBciyFoAAEAASURBVCCAQMQFXIzJ9u3bJ/fdd5+cOnUqVZ8nnnhCWrZsKRkzZkx1v40biclsLBXuCQEEEEAAgeAEXIyrtmzZYuKmY8eOnfewWbNmlXfeeUeuuuqq8/bZvIG4yubS4d4QQAABBBBImwAJN44m3LRq1Up+/vnngKWsPeh0eOpcuXIFPMamHQSWNpUG94IAAikFXGyE0Gd477335N///recPHky5SNJzpw55cYbbzQ9rqtWrSraMGH7wneF7SXE/SGAAAIIIBBZARdjsqNHj5rpDVKOTptUqlSpUtKxY0dnRqolJktaeqwjgAACCCDgpoCLcdWePXvMdJypJTLryDZXXHGFSXTWUXAKFCjgRMEQVzlRTNwkAggggAACFxQg4cbRhBsdHnHo0KHJ5ipNWdIaYPbt21duvvnmlLuse09gaV2RcEMIIJBEwMVGCL197VF9//33n5dwo8k1SZNwdNhdbYy49957pWTJkkme3K5VvivsKg/uBgEEEEAAgWgLuBqTdenSRVasWHHelNCZM2cWnRrBW6pUqSIdOnSQEiVKeJusfCUms7JYuCkEEEAAAQSCEnA1rtLRAdetWydnz571n1eTbXSKziNHjphtOnLgHXfcIQ0aNJDbb79dNOaydSGusrVkuC8EEEAAAQTSLkDCjaMJNzoP/PDhw2X27Nl+A53+gNqjRw/5+uuv5YsvvvBrwUMPPSTt2rWzel54Aku/uFhBAAELBVxthFBK/U7o3bu3n2CjjQz33HOPSbD55JNPZMmSJXL8+HFfvXTp0qZBonbt2pI7d25/uw0rfFfYUArcAwIIIIAAArETcDUmO3jwoLRt29Z
MC60j3eiPQNrr+vXXX5fly5fL+PHj5a+//vJhNQm6TZs2UqhQIX+bTSvEZDaVBveCAAIIIIBAaAKuxlV79+4VTbrZvXu3SWbWZBudNn3IkCGybNkymTNnjvz+++8+Sv78+f1OZtdee62/3ZYV4ipbSoL7QAABBBBAIHQBEm4cTbjxinzevHkycOBA83bu3Lmm0U6TcRYuXCgvvfSSHD582Oy75pprpH///lKmTBnvVKteCSytKg5uBgEEUgi42gjhPYb+yKM/3OiQu88++6w8+OCD3i7RaQ4+/fRT0yDx448/+ts1ifOuu+4yyTeVKlUyPwz5O2O0wndFjOD5WAQQQAABBCwRcDkm017YvXr1Mp1jrr/+epkwYYLf21r/bn/77bdl+vTpfpJ0tmzZRDvPNGvWTC655BJLSuD/3QYxmVXFwc0ggAACCCAQkoDLcZUmMHfu3FlWrlwpFStWNNOpa+KNLhpzff/996ad67PPPvNjK913ww03mHauv//975IrVy7dFPOFuCrmRcANIIAAAgggkG4BEm4cT7hZvXq1PPnkk6YiLF261G+w0w2a5f3888/Lt99+a/ZrL7oWLVpIq1atkh1ndsb4XwSWMS4APh4BBC4o4HIjhPdgNWrUkGPHjsmwYcOkevXq3uZkr7/99pvoqDeawLl//35/n05RqAk7OjKO9hqK1cJ3Razk+VwEEEAAAQTsEHA9Jnv11Vdl2rRpUrlyZRk5cuR5qPo3/NixY00s5u3MmzevtG7d2kwTmiVLFm9zTF+JyWLKz4cjgAACCCAQFgHX4yr93UPbsGrVqiWDBg1K1eTQoUOyePFik3yzfv16/xhNbNakG51yqkKFCpIhQwZ/X7RXiKuiLc7nIYAAAgggEH4BEm7iOOFGq4uOdvPBBx+INuydOHHC1KBSpUqZ0W6uu+668NeoEK9IYBkiHKchgEBUBFxvhFCktCTceJinT5+Wb775xjRI6HC82nPIW2677TbTIFGtWjXRBopoLnxXRFObz0IAAQQQQMA+Addjsosl3Hjiv/76q+mpvWLFCm+TFClSxHS2qVmzZkx/FNIbIibzi4UVBBBAAAEEnBVwPa5KS8JN0sLZuHGjSdCZP3++6EjQ3lKsWDHTyax+/fpSsGBBb3PUXomrokbNByGAAAIIIBAxARJu4jzhxqs527ZtM5ne3nQh2jOuXbt2Zohqb7hF79hYvBJYxkKdz0QAgbQKuN4Ioc8ZTMJNUpd9+/bJggULTPLNli1b/F25c+eWOnXqmOQbTeSMxsJ3RTSU+QwEEEAAAQTsFXA9Jktrwo1XAjpa7WuvvSb/+c9/vE1Srlw56dixo5k+wd8Y5RVisiiD83EIIIAAAghEQMD1uCrYhBuPUKdb//rrr03yjXY20w7LuugoN1WqVDHtXFWrVpVojSxIXOWVDK8IIIAAAgi4K0DCTYIk3GgV1REKdPhqHaJaRy/QRYdM7Nevn+ktZzbE6F8EljGC52MRQCBNAq43QuhDhppw4wFpA8S6detM4o0Ox3v06FFvl5QsWVIaNmwotWvXFp32IFIL3xWRkuW6CCCAAAIIuCHgekwWbMKNlsrZs2dl4cKFMmbMGPnjjz/8grrzzjulQ4cOcs011/jborVCTBYtaT4HAQQQQACByAm4HleFmnCTVFSn89QRb+bMmSPbt2/3d+XLl0/q1atnkm8iPUsAcZXPzgoCCCCAAALOCpBwk0AJN14t1eETBwwYIDpMtS7Zs2eXp59+Who1ahSzoakJLL3S4RUBBGwUcL0RQk3Tm3CTtFyOHTsmn3/+ucyePVt++OEHf1fmzJnlrrvuMg0St9xyi4R7BDW+K3xqVhBAAAEEEEhIAddjslASbryC1imiZ86cKRMnTpTDhw+bzRkzZjR/xz/++ONSoEAB79CIvxKTRZyYD0AAAQQQQCDiAq7HVeFIuPGQtZOZtm9p4s2nn34qGnd5S9myZU07V61ateSSSy7xNoftlbgqbJRcCAEEEEAAgZgJkHCTgAk3Wtt06MQJEybI5MmTTY853aZDJvbq1Yu5ShWDBQEEEEgi4HojhD5KOBNuktDI1q1bZe7cueafvXv3+rsuv/xyueeee8w82FdeeaW/PT0rNEKkR49zEUAAAQQQcF/A9ZgsPQk3XukdPHhQJk2aJO+//74/cq12omnatKk88sgjkjNnTu/QiL0Sk0WMlgsjgAACCCAQNQHX46pwJtwkRT9y5IgsWbLEJN+sXbvW35UtWzbTttagQQO56aabRBOfw7EQV4VDkWsggAACCCAQWwESbhI04cardjo9iI52oz+Y6pI7d27p1q2baMa2zlsarYXAMlrSfA4CCIQi4HojhD5zpBJuPE+dtlCdtDfQV199ZaYx9PZVqlTJTDlVvXp1M6qatz3YV74rghXjeAQQQAABBOJLwPWYLBwJN16J7ty500wztWjRIm+T5M+fX5544gnTC1tHHozUQkwWKVmuiwACCCCAQPQEXI+rIpVwk7QENm/ebNq5dNqpAwcO+Lu0Y5km3mhHs0KFCvnbQ1khrgpFjXMQQAABBBCwS4CEmwRPuNHqePz4cXnjjTfkvffe82tnzZo1pXv37qLzlUZjIbCMhjKfgQACoQq43gihzx3phJukttoIsXDhQjPl1KZNm/xduXLlkjp16phGidKlSwed2Ml3hU/JCgIIIIAAAgkp4HpMFs6EG68C/PLLL6LXXb16tbdJrr76aunQoYPceeedQcdb/kUusEJMdgEcdiGAAAIIIOCIgOtxVTQSbryiPH36tCxfvty0c+nr2bNnzS7tsKxxkSbfaNyVNWtW75Q0vxJXpZmKAxFAAAEEELBWgIQbEm78yrlq1SoZOHCg/PHHH2ab9o7TKaaqVq3qHxOpFQLLSMlyXQQQCIeA640QahDNhBvPXOfAXr9+vekNpAk4R48e9XZJ8eLFTYNE3bp105zcyXeFz8cKAggggAACCSngekwWiYQbrQgac+mPP6+99ppoT2xvqVChgnTs2FFuuOEGb1NYXonJwsLIRRBAAAEEEIipgOtxVTQTbpIWlE6nriPezJ49W7Zt2+bvypMnj2gblybflCxZ0t9+sRXiqosJsR8BBBBAAAH7BUi4IeEmWS3VOUpfeeUV+eSTT/zt9957rzzzzDOiIxNEaiGwjJQs10UAgXAIuN4IoQaxSLhJaq+jqX3xxRemQSJpD+xMmTJJtWrVTIOEfhfo+0AL3xWBZNiOAAIIIIBAYgi4HpNFKuHGK32d4nPevHkyduxY0R+DvEVHsG3fvr0UK1bM25SuV2KydPFxMgIIIIAAAlYIuB5XxSrhxis8TXj+8ccfTSezJUuWmFkEvH06qrP+plK7dm3RRJwLLcRVF9JhHwIIIIAAAm4IkHBDwk2qNfXrr7+WF154Qfbv32/2Fy5cWPr27SuVKlVK9fj0biSwTK8g5yOAQCQFXG+EUJtYJ9wkLZ8dO3bI3LlzTXLn7t27/V0FCxY0819ro0TRokX97d4K3xWeBK8IIIAAAggkpoDrMVmkE268WqGJztOnT5e333
7bH2FQk5r/8Y9/SKtWrdI8uqB3vZSvxGQpRXiPAAIIIICAewKux1WxTrhJWuI6ovOnn35q2rnWrFnj79Ippu666y7TyUx/V8mYMaO/z1shrvIkeEUAAQQQQMBdARJuSLgJWHv//PNPGTZsmHz++ef+MU2aNDE947Jnz+5vC8cKgWU4FLkGAghESsD1Rgh1sSnhxisn7YW9cuVKM+rNl19+KTontrfcdNNNpkFCe2R73zl8V3g6vCKAAAIIIJCYAq7HZNFKuPFqh3agmTBhgnz44YeicZcuOXPmlObNm4v+be/FWN7xaX0lJkurFMchgAACCCBgr4DrcZVNCTdJS/m3334ziTfa0czrzKz7tUOzdjC755575IorrvBPIa7yKVhBAAEEEEDAWQESbki4uWDl1aERFy1aJMOHD5fDhw+bY6+66irp37+/lCtX7oLnBrOTwDIYLY5FAIFoC7jeCKFeNibcJC3HgwcPysKFC03yzcaNG/1d+qOQDsGrjRLaIztDhgz+PlYQQAABBBBAILEEXI/Jop1w49WOrVu3yhtvvJGsM42OLNimTRupX7/+Baf09K6R9JW/35NqsI4AAggggICbAq7HVbYm3Hi1QTuVrVixwrRzLVu2zE9+1v233XabaeeqXr26mWbdO4dXBBBAAAEEEHBTgIQbEm7SVHN1yo8hQ4aYIFFP0OEPW7RoIS1btpQsWbKk6RoXOogGuwvpsA8BBGIt4HojhPrZnnCTtIx/+eUXMwe2JuB4yZ66v2zZsibpplmzZlKoUKGkp7COAAIIIIAAAgkg4HpMFquEG69q/PTTTzJ69Gj58ccfvU1SvHhxeeqpp6Ry5cppTmzm73efjxUEEEAAAQScFXA9rrI94SZpxdi3b58sWLDAtHVt2bLF33XJJZeYkQdbt24tOtIzCwIIIIAAAgi4KUDCDQk3aa65OtrNRx99JKNGjRKdE16X66+/3ox2o4106VlosEuPHucigECkBVxvhFAflxJuvPI8ceKE6FRTc+bMke+++87bLJkzZzbTTemIN3Xr1jXv/Z2sIIAAAggggEDcCrgek8U64UYrhv5d/9VXX8lrr70m27Zt8+vKLbfcYhJvSpcu7W8LtMLf74Fk2I4AAggggIA7Aq7HVS4l3Hi1QuOwdevWmXauxYsXy9GjR71dUrFiRdPJ7JFHHpECBQr421lBAAEEEEAAAfsFSLhxPOFm1apV0qFDB1PTli5dGpUfHbdv3y6DBg2SNWvWmM/VHz7btWsnDz/8cNBDUXv/idBg50nwigACNgq43gihpnfddZdJlhw2bJjokLWuLTt37jS9sSdOnCg6H7a36LzXzZs3NyOuaRIoCwIIIIAAAgjEr4DrMZl2Xnn33XfNaDIjR46MaUHpNAezZ8+Wt956Sw4cOODfiyYzt23bVjTGCrTw93sgGbYjgAACCCDgjoDrcZWLCTdJa8exY8fMdJ/a0Uz/8ZasWbPK/fffb5Jv7r777pB/b/GuxysCCCCAAAIIRF6AhBsSbkKqZWfOnJHp06ebeeC1oU6XChUqSN++faVo0aJBX5MGu6DJOAEBBKIo4HojhFK5nnCjz6DfFWfPnjUNEhMmTJBZs2aJjoLjLVWrVjUNEo0bNxYdlpcFAQQQQAABBOJLwPWYzKaEG69mHDlyRKZOnWr+8eIq7VTz4IMPmmmk8+TJ4x3qv/L3u0/BCgIIIIAAAs4KuB5XeQk3tWvXloEDBzpbDhpXbdy4UbSD2aRJk0Q7nHlLsWLFTDzWokULue6667zNvCKAAAIIIICAZQIk3JBwk64quWnTJhkwYIBs2LDBXCd79uzSqVMnk4WdIUOGNF+bBrs0U3EgAgjEQMD1Rggli5eEm6TFr72xtZe4Jt/oiG/eosk2TZo0Mck3VapUkWC+j7xr8IoAAggggAAC9gm4HpPZmHDjlfLevXtl3Lhx8vHHH5tpp3R77ty5zY88msysva29hb/fPQleEUAAAQQQcFfA9bgqnhJuvFqknZwXLVok48ePNyMRnjp1ytslNWvWNO1cDzzwgOTIkcPfzgoCCCCAAAIIxF6AhBsSbtJdC3WEG/2xUzOwdeQBXbQBrnfv3lKoUKE0XZ8GuzQxcRACCMRIwPVGCGWLx4SbpNVBpznU76J33nlH9u/f7+8qVaqUaZB47LHHpHDhwv52VhBAAAEEEEDAPQHXYzKbE2682rB582Z5/fXX5euvv/Y2memldBrpWrVqScaMGc3f+/5OVhBAAAEEEEDASQHX46p4TLhJWpH27NljRiDU5Ju1a9f6u/LmzSsPP/ywtG7dWipVqkQnM1+GFQQQQAABBGInQMINCTdhq30///yzGe3mt99+M9fUEQa6du0qderUuWjgR8JN2IqBCyGAQAQEXG+EUJJ4T7jxil2nQpg9e7ZJvlm4cKHfQztTpkxyzz33mOSb+vXrS5YsWbxTeEUAAQQQQAABRwRcj8lcSLjxqsLq1atl9OjRsn79em+TaCJzx44dpUOHDv42VhBAAAEEEEDATQHX46p4T7jxatW5c+dk5cqVpp1r2rRp8tdff3m7pHz58qadq2nTpnLZZZf521lBAAEEEEAAgegKkHBDwk1Ya9zx48dlzJgxMn36dP+6NWrUkO7du8ull17qb0u5QsJNShHeI4CATQKuN0KoZaIk3CStN9u2bZPJkyebebB1CkRvufzyy0VHvGnZsqWUKVPG28wrAggggAACCFgu4HpM5lLCjVYF/YHn008/NSPe7Ny5068d9erVk2HDhpkfefyNrCCAAAIIIICAUwKux1WJknCTtFIdPXpUPvzwQ5N889lnn/m7tFNZo0aNTPJN7dq1RTudsSCAAAIIIIBA9ARIuCHhJiK1bdWqVTJo0CDZtWuXub4m2/Ts2VOqVauW6ueRcJMqCxsRQMASAdcbIZQxERNuvOqj0x1+9dVXpkFi5syZcuzYMW+XVKlSxTRIPPjgg5InTx5/OysIIIAAAgggYJ+A6zGZawk3Xg04efKkfPDBByaW8npV69RSLVq0kIEDB0qRIkW8Q3lFAAEEEEAAAUcEXI+rEjHhJmnV0mlAJ06cKJMmTRLtcOYtGpc1b97cdDIrUaKEt5lXBBBAAAEEEIigAAk3JNxErHodOXJERo4cKXPmzPE/Q6fzeOaZZ0Snm0q6kHCTVIN1BBCwTcD1Rgj1TOSEm6T16eDBg2YUtgkTJsi3337r78qZM6c0btzYzIFdtWrVi06F6J/ICgIIIIAAAghETcD1mMzVhBuvgA8dOiRTpkyRGTNmiI5uq0uOHDnM3/g9evQgedmD4hUBBBBAAAEHBFyPqxI94carYmfOnDEjEo4fP14++ugj0URpb9HOz61bt5Z//OMfkitXLm8zrwgggAACCCAQZgESbki4CXOVOv9yy5Ytk8GDB8v+/fvNTp3Ko2/fvnLLLbf4B5Nw41OwggACFgq43gihpCTcnF+x1q5da3pq6w9He/fu9Q/QHkCtWrUy007RY9tnYQUBBBBAAIGYC7gek7mecONVgCuvvFL69etnpu7Uaad0ueyyy
6R///7Spk0byZo1q3corwgggAACCCBgqYDrcRUJN+dXrH379sm0adNEk2/WrFnjH5A7d2556KGHTFuX/g6TIUMGfx8rCCCAAAIIIJB+ARJuSLhJfy1KwxV0RIEXX3zRZFt7h+v0HU8++aRkz55dSLjxVHhFAAEbBVxvhFBTEm4C1yzt/fPJJ5+Y5Jv58+eLTkGli06VUK9ePdMgce+99/LjUWBC9iCAAAIIIBAVAddjsnhJuPH+ftcfcnRkm4ULF/rlr4nLL7zwgulJzY85PgsrCCCAAAIIWCfgelxFws2Fq9T3339v2rmmTp0qBw4c8A8uW7asaedq1qyZFCpUyN/OCgIIIIAAAgiELkDCDQk3odeeEM5ctGiRDB8+XHQoal2KFStmesHp0IYsCCCAgK0CrjdCqCsJN2mrXTt37jS9tXXKqY0bN/onFSxYULQxQke+KVeunL+dFQQQQAABBBCInoDrMVm8Jdx4Jb948WLp3r27/PDDD94mqVy5svnbX6fqZEEAAQQQQAAB+wRcj6tIuElbndJpQHWqKW3nWrJkiXijE2bOnFkaNGhg2rnq1q0r+p4FAQQQQAABBEITIOGGhJvQak46ztqzZ48MGTJEvvnmG3MV7fXWq1cvMyQ1Q0+nA5ZTEUAgYgKuN0IoDAk3wVUPbYBYunSpaZCYMWOGHD161L/AbbfdZhokdDjevHnz+ttZQQABBBBAAIHICrgek8Vrwo2Wuo4QqFMY9O7dW7Zu3epXhEaNGsnQoUOldOnS/jZWEEAAAQQQQCD2Aq7HVSTcBF+HfvvtN5k0aZJMnDhRdN1bChcuLM2bN5eWLVtKqVKlvM28IoAAAggggEAaBUi4IeEmjVUlvIfpD5kff/yxaIPjsWPHzMUrVqwob7/9tpQvXz68H8bVEEAAgXQKuN4IoY9Pwk3oleCvv/4STbrR3kBesqheLUeOHGa6BB2lrVq1amYKqtA/hTMRQAABBBBA4GICrsdk8Zxw45Wd9qIePXq0DB48WHRqaV0yZcokTzzxhBndVn/QYUEAAQQQQACB2Au4HleRcBN6HdJE6c8//9y0c82aNUtOnDjhX+yOO+4Qbedq3LixXHLJJf52VhBAAAEEEEAgsAAJNyTcBK4dUdijU3cMHDjQH3paR7jR9127djWNclG4BT4CAQQQuKiA640Q+oAk3Fy0mNN0wPr1602DhCaI7t692z/nuuuuMz2BtEeQTpfIggACCCCAAALhF3A9JkuEhBuv1Pft22dGtn3ttdfk5MmTZnOuXLmkW7du0qVLF37A8aB4RQABBBBAIEYCrsdVJNyEp+IcOHBA3n33XdPWtWrVKv+iGrc1adLEjPB8++23i85SwIIAAggggAACqQuQcEPCTeo1I4pbNaNaRwzQaaW8bGoN4iZPniwlSpSI4p3wUQgggEDqAq43QuhTkXCTetmGuvXUqVMyb9480yAxd+5cOXPmjLmUNkDUqVPHNEg0bNhQsmXLFupHcB4CCCCAAAIIpBBwPSZLpIQbr+g2b95sppnSH3K8RUe5GTBggImXMmfO7G3mFQEEEEAAAQSiKOB6XEXCTfgry5o1a8x0U1OmTJH9+/f7H6DTTLVq1Uoee+wxYbRCn4UVBBBAAAEEfAESbki48StDLFf+9re/yc8//2yCNi+TOmfOnDJ8+HBp3749GdSxLBw+GwEExPVGCC1CEm4iV5F37dol2hgxfvx42bBhg/9B+fPnl6ZNm5pGiQoVKvjbWUEAAQQQQACB0ARcj8kSMeHGK+mVK1ea0W2++OILb5OULl1ahg0bJg0aNOBvfl+FFQQQQAABBKIj4HpcNWjQINEOULVr1zYj5kdHLfyfor+L2LZop+jZs2ebTmYLFy6Uc+fOmVvUaULr169vppzS1yxZsth269wPAggggAACMREg4SaOEm6WLVvm7DRMXmCpIwYMGTJENGD2RguoVauWCe6KFi0ak/9I+FAEEEDA9UYILUEv4ebFF1+UatWqOVmo3neFrTevDRA6YtuECRPkvffek8OHD/u3WqlSJZN48/DDD8ull17qb2cFAQQQQAABBNIu4HpM5iXcVKlSRV555ZW0P7hlR4Yak2msNH/+fOnevbusW7fOfyqNTbWzzW233eZvYwUBBBBAAAEEIivgelxFwk1k64d39W3btpmZCCZOnCibNm3yNkuhQoVM52kd+aZMmTL+dlYQQAABBBBIRAESbki4saLep2yw095vOkTh+vXrzf3lzZtXRo8ebUYKYL5QK4qMm0AgoQRcb4TQwiLhJrpVVpNt3n//fZN8s3TpUv/DdYqpBx54wPQGqlGjhmTMmNHfxwoCCCCAAAIIXFjA9Zgs0RNuvNI9ffq0+eGmb9++8vvvv3ub5cEHHzQdcIoXL+5vYwUBBBBAAAEEIiPgelxFwk1k6kWgq549e1a++uor0841c+ZMOXbsmH9o5cqVTTuXxnJ58uTxt7OCAAIIIIBAogiQcEPCjRV1PWXCjd6UBm19+vQxPf+8YQv1R8oxY8ZIwYIFrbhvbgIBBBJDwPVGCC0lEm5iV1f/85//mDmwJ0+enOxHpauvvlpatmwpLVq0EF1nQQABBBBAAIELC7gek5Fwk7x8jxw5IiNHjjTTSh06dMjs1KkJdFppTca57LLLkp/AOwQQQAABBBAIm4DrcRUJN2GrCkFf6ODBgzJ9+nSTfPPtt9/65+fMmVMaN25sRni+8847mTLUl2EFAQQQQCDeBUi4IeHGijqeWsKNd2OaOd28eXPZsmWL2aTJNm+++abcd9993iG8IoAAAhEVcL0RQnFIuIloFUnTxbU394IFC0yDxJw5c0Tf66Ijt/397383DRL63ZY9e/Y0XY+DEEAAAQQQSDQB12MyEm5Sr7G7d++WgQMHytixY/34SHtH9+zZUzp37iw5cuRI/US2IoAAAggggEDIAq7HVSTchFz0YT1x7dq1ppPZ22+/LXv37vWvXaJECdPOpbMYFClSxN/OCgIIIIAAAvEoQMINCTdW1OsLJdzoDWpvty5dushbb73l368m4WiDpU43xYIAAghEUsD1Rgi1IeEmkjUk+GvrD0vvvPOOjB8/Xn7++Wf/Avny5ZNHH33UNErcfPPN/nZWEEAAAQQQQEDE9ZiMhJsL12IdFbBXr14ya9Ys/8CiRYuK/qDWrFkzyZQpk7+dFQQQQAABBBBIn4DrcRUJN+kr/3CfffLkSfnkk09MJ7P58+eLTkGli06lXrduXdPO1aBBA8maNWu4P5rrIYAAAgggEHMBEm5IuIl5JdQbuFjCjXeT8+bNk8cff9yfkqNYsWImg/ruu+/2DuEVAQQQCLuA640QCkLCTdirRVguqFMm6vC7EyZMkHfffdckmHoXrlixommQeOSRR6RAgQLeZl4RQAABBBBIWAHXYzISbtJWdb/55hvp2rWrLF++3D+hfPny8uKLL0qdOnWYnsBXYQUBBBBAAIHQBVyPq0i4Cb3sI33mzp07RadV17aujRs3+h+n04VqEnWrVq3khhtu8LezggACCCCAgOsCJNyQcGNFHU5r
wo3e7P79+6VDhw5mnlDv5p966ikz77vOE8qCAAIIhFvA9UYI9SDhJty1IvzXO3LkiOnRrQ0SX375pf8B2vtHp5pq3bq1aIIpvbt9GlYQQAABBBJMwPWYjISbtFdYTUr+6KOP5LnnnhMd+cZbNBYaPny43HTTTd4mXhFAAAEEEEAgBAHX4yoSbkIo9CifovHc0qVLTeLNjBkz5OjRo/4d3Hrrraad66GHHmIGA1+FFQQQQAABVwVIuCHhxoq6G0zCjXfDGqS1b9/eJODotpIlS4rOFVq5cmXvEF4RQACBsAi43gihCCTchKUqRO0i2gNo0qRJ5p8dO3b4n6sju7Vo0cL8c9111/nbWUEAAQQQQCARBFyPyUi4Cb6Wnjp1SsaNGyf/+te/RKfk9JamTZvK888/L1dffbW3iVcEEEAAAQQQCELA9biKhJsgCtuCQw8dOiTvvfeeSb7R0Qy9JXv27PLPf/7TjHpTvXp1MwWVt49XBBBAAAEEXBEg4YaEGyvqaigJN3rjv//+uzzxxBMyd+5c8xw6J2iPHj2kf//+ki1bNiuejZtAAAH3BVxvhNASIOHGzXp45swZWbRokWmQ+Pjjj0V/dPKWmjVrmgaJBx54QHLkyOFt5hUBBBBAAIG4FXA9JiPhJvSqqT/S6Mg2L7/8st87Wv/m79Spk/Ts2VMuvfTS0C/OmQgggAACCCSggOtxFQk37lba9evXy8SJE820U0kTqrVjWcuWLaV58+aiHc5YEEAAAQQQcEWAhBsSbqyoq6Em3OjN69CEOv3G008/LYcPHzbPc+ONN8qUKVNEX1kQQACB9Aq43gihz0/CTXprQezP37t3r0ydOlXGjx8vP/30k39DefPmlYcfftgk39xyyy2SIUMGfx8rCCCAAAIIxJOA6zEZCTfpr407d+40o91oPHT27FlzQU226dOnj5l6mo436TfmCggggAACiSHgelxFwo379VQ7lc2bN8/8tqMdqrXTmS7arlW7dm3TztWoUSM6Vrtf1DwBAgggEPcCJNyQcGNFJU9Pwo33AJs3bzYZ0F9++aXZlCVLFhkwYIB069ZNMmfO7B3GKwIIIBC0gOuNEPrAJNwEXezWnqCJpqtWrTINEtOmTZODBw/691q+fHnTIPHoo49KwYIF/e2sIIAAAgggEA8CrsdkJNyErxb+/PPP8txzz8mcOXP8i15zzTUyePBgeeihh5iOwFdhBQEEEEAAgdQFXI+rSLhJvVxd3bpr1y7TgVqTqjds2OA/Rv78+UWnEm3VqpVUqFDB384KAggggAACNgmQcEPCjRX1MRwJN/og2sNNGzF1SOkTJ06YZ6tcubK8/fbbUrJkSSuelZtAAAH3BFxvhFBxEm7cq3dpueNjx47JBx98YJJvPvvsM/8UTTpt2LChaZCoU6eOZMqUyd/HCgIIIIAAAq4KuB6TkXAT/pqnHW60k813333nX7xSpUpm+qkaNWr42/4Pe+cBHlXxtfGXEEooofeOdBAEAVEEFEG6CoogLRQp0luoAQx/emjSi0BoUkQ6KiBFKSIICKj03otAgCSk8t0zfntNQjbZTbLJvZv3Po/u3Lkzc8/8ZnQns++cwwQJkAAJkAAJkEBkAmZfV1FwE3k8neVODpn9+uuvap9rzZo1ekQD6V+lSpXUPlerVq0YTtRZBpz9IAESIAEnIUDBDQU3hpjKCSW4sXRG4oC2a9cOv//+u8pyc3PDpEmT0L17d550s0DiJwmQgM0EzL4JIR2l4Mbm4TZtQfH05uvrq+JgX79+Xe9Hvnz5VPxriYNdrFgxPZ8JEiABEiABEjAbAbOvySi4ccyMkx9m1q5dqw7eyHrIcjVs2BATJ05EuXLlLFn8JAESIAESIAES+H8CZl9XUXDj/FP52bNn+Pbbb5X4Zv/+/XqHJYRos2bNlPimdu3a/L1HJ8MECZAACZBAUhGg4IaCm6Sae5Hem9CCG2lcYoBOmDABo0ePRmhoqHpfnTp11AKtQIECkd7PGxIgARKIiYDZNyGkbxTcxDTCzvVMYl6Ltxtxw7thwwYEBwfrHaxZsyY6deqEjz/+GOnTp9fzmSABEiABEiABMxAw+5qMghvHzjLxcjtv3jy1B/Dw4UP1MhcXF7Rv317liQiZFwmQAAmQAAmQwL8EzL6uouAmec3kc+fOqQNmS5cuxe3bt/XOFypUCHLATNZ7kuZFAiRAAiRAAklBgIIbCm6SYt699E5HCG4sLzl27Bjatm0LifEul7u7O2bMmKE84KRIkcJSjJ8kQAIkYJWA2TchpGMU3FgdXqd+ID82ffPNN0p888cff+h9zZgxI1q2bKlOA8l3ML8PdTRMkAAJkAAJGJiA2ddkFNwkzuR6/PixOnwzffp0PdS0eL3t378/Bg0apPYEEscSvoUESIAESIAEjEvA7OsqCm6MO7ccaZkcrP7xxx/VoeotW7boB61lX+u9995T+1xNmzZF2rRpHWkG2yYBEiABEiCBSAQouKHgJtKESKobRwpupE/Pnz/HiBEjMGXKFIi7abk++ugjzJ8/Hzlz5lT3/BcJkAAJWCNg9k0I6RcFN9ZGN/nkHz9+XG1IrFy5Eo8ePdI7XqZMGbUh0aZNG+TKlUvPZ4IESIAESIAEjEbA7GsyCm4Sd0Zdu3YNI0eOxLJly/R9gOzZs2PUqFHo0qULUqdOnbgG8W0kQAIkQAIkYCACZl9XUXBjoMmURKbcu3cPK1asUIfMLIetxZTMmTOjdevWaq+rYsWKPGSWROPD15IACZBAciJAwQ0FN4aY744W3Fg6uW/fPuVe8NKlSyorR44cSnQjqmdeJEACJGCNgNk3IaRfFNxYG93kly8i1I0bNyrxzU8//aT/AOXq6orGjRurDYkGDRpA7nmRAAmQAAmQgJEImH1NRsFN0symEydOYPDgwdi+fbtuQLFixZQXnGbNmvFHGJ0KEyRAAiRAAsmJgNnXVRTcJKfZGnNf5YD14cOH1T7XqlWr8PTpU71ChQoV1D6XCHCyZcum5zNBAiRAAiRAAglJgIIbCm4Scj7Fua3EEtyIgc+ePcPAgQOV0MZisISckjBTon7mRQIkQAJRCZh9E0L6Q8FN1FHlvRC4evUqJP71kiVLcOXKFR1K7ty54eHhoeJglyxZUs9nggRIgARIgASSkoDZ12QU3CTl7AF27twJT09PiADHclWrVg0+Pj54++23LVn8JAESIAESIIFkQcDs6yoKbpLFNLW7k/7+/li/fr3yevPzzz/r9cWzoUQ86NixI+rUqYOUKVPqz5ggARIgARIggfgSoOCGgpv4zqEEqZ+YghuLwRLrs1OnTrh165bKyp8/v1JB161b11KEnyRAAiSgCJh9E0I6QcENJ3NMBMLDw7F37161IfHdd98hKChIL169enW1IfHpp58iQ4YMej4TJEACJEACJJDYBMy+JqPgJrFnzMvvkzWPhNccPnw4rl+/rheQH2AmTJgACo11JEyQAAmQAAk4OQGzr6souHHyCZoA3btw4QJ8fX3VPzdv3tRbLFCggIqC0L59exQtWlTPZ4IESIAESIA
E4kqAghsKbuI6dxK0XlIIbqQDDx8+RK9evfDNN9/o/enevTsmTZqE9OnT63lMkAAJJG8CZt+EkNGj4CZ5z2F7ev/o0SOsXr1aiVB///13vap8L7Zo0UKJb9566y2GX9DJMEECJEACJJBYBMy+JqPgJrFmSuzvCQwMxMyZMzFu3Dj4+fmpCnLSuUuXLhg1ahRy5coVeyMsQQIkQAIkQAImJmD2dRUFNyaefIlselhYGHbs2KH2uTZt2oSQkBDdgnfffVftc0mY0XTp0un5TJAACZAACZCAPQQouKHgxp754rCySSW4sXRo3bp16NatG/755x+VJfHcJcSG/KDIiwRIgATMvgkhI0jBDedxXAhIyAUJN7VixQr9O1LakdPf4oZXQjLmyZMnLk2zDgmQAAmQAAnYTcDsazIKbuwecodXkD2AsWPHYtasWfqPLyIyHjRoEPr370/vfg4fAb6ABEiABEggqQiYfV1FwU1SzRxzv/fBgwfK2+GiRYtw6tQpvTPu7u5o1aqV2uuqXLkyD5npZJggARIgARKwhQAFNxTc2DJPHF4mqQU30sE7d+6gc+fO2Lp1q+qvi4uLiu/u7e2NNGnSOJwBX0ACJGBcAmbfhBCyFNwYd36ZwTIJMbVlyxZ1Gmj79u2QcAxyyUnwhg0bqg2JRo0aIVWqVGboDm0kARIgARIwKQGzr8kouDHuxLt06RK8vLywatUq3cjcuXND9gNEZOzq6qrnM0ECJEACJEACzkDA7OsqCm6cYRYmXR9evHiBo0ePqn0uiX5g8XgoFpUrV06t/9q0aYMcOXIknZF8MwmQAAmQgGkIUHBDwY0hJqsRBDcCQhZacpK/b9++ePr0qWLz6quvYtmyZXjttdcMwYpGkAAJJD4Bs29CCDEKbhJ/3jjrG2/cuKG8wC1evBjy45TlypkzJ9q1a6c2JUqXLm3J5icJkAAJkAAJJBgBs6/JKLhJsKngsIaOHDmiDt78/PPP+jtkXTNx4kQ0btyYp511KkyQAAmQAAmYnYDZ11UU3Jh9BhrHfgk1umHDBiW+2bVrl26YHCr74IMP1D7X+++/TwG2ToYJEiABEiCBqAQouKHgJuqcSJJ7owhuLJ2/evUq2rdvj71796osWVxJHPfBgwdzYWWBxE8SSEYEzL4JIUNFwU0ymrCJ1FXxcrNv3z6IG14JzSgbFJarWrVqakOiRYsWELe8vEiABEiABEggIQiYfU1GwU1CzALHtyEHcb7//nsVVurvv//WX1izZk34+PigatWqeh4TJEACJEACJGBWAmZfV1FwY9aZZ2y7L1++DF9fX3Uo+/r167qxefPmhYeHh9rrKlasmJ7PBAmQAAmQAAkIAQpuKLgxxH8JRhPcCBT5IXHmzJkYMmQInj9/rjiJnUuXLkXJkiUNwY1GkAAJJA4Bs29CCCUKbhJnriTXt4jr3TVr1qjTQBH/e0mXLh2aN2+uNiRq1KjBU+HJdYKw3yRAAiSQQAQifsckUJOJ2gwFN4mKO94vCw0NVT+4jBw5Erdv39bbE0Hx2LFj8corr+h5TJAACZAACZCA2QiYfV1FwY3ZZpy57A0LC8Pu3bvVITPxfhMcHKx3QETYEnL0k08+Qfr06fV8JkiABEiABJIvAQpuKLgxxOw3ouDGAubMmTNKvXz48GGV5ebmhgkTJqBnz55wcXGxFOMnCZCAExMw+yaEDA0FN048QQ3WtT///FOdBFq+fDnu37+vWycngDp06KC+U/Ply6fnM0ECJEACJEACthIw+5qMghtbR9pY5fz9/TFt2jQVVurZs2fKOPGC2717d3h5eSF79uzGMpjWkAAJkAAJkIANBMy+rqLgxoZBZpEEIfDw4UN88803Snzzxx9/6G1mzJgRIsTu1KkT5PetFClS6M+YIAESIAESSF4EKLih4MYQM97IghsBJCfbRGTj7e2t0pJXu3Zt9YNiwYIF5ZYXCZCAExMw+yaEDA0FN048QQ3aNTn9s23bNuX1RsIyiOc4uUSsWr9+fXUaqEmTJkidOrVBe0CzSIAESIAEjEbA7GsyCm6MNqPss+fevXsYPXo05s+fr+8LSOjMoUOHok+fPpDDObxIgARIgARIwCwEzL6uouDGLDPNuew8fvy42udauXIlHj16pHeudOnSap+rbdu2yJUrl57PBAmQAAmQQPIgQMENBTeGmOlGF9xYIMmCql27dpDT+3LJ5ppsmkr8TiqYLZT4SQLOR8DsmxAyIhTcON+8NFOPbt26hWXLlqlNifPnz+umy4lw2YwQV7zlypXT85kgARIgARIggegImH1NRsFNdKNqvrxz584pkc369et14/Pnzw/54U/WNSlTptTzmSABEiABEiABoxIw+7qKghujzqzkYdfz58+xadMmtc+1c+dOvHjxQnXc1dUVjRs3VvtcDRo0gNzzIgESIAEScH4CFNxQcGOIWW4WwY3ACgoKgsRw9/Hx0RdSH3zwARYsWED1siFmE40ggYQnYPZNCCFCwU3Czwu2aD8B2YA4cOCAcsO7du1aBAQE6I1UqVJFbUi0bNkSmTNn1vOZIAESIAESIAELAbOvySi4sYykc3wePHgQnp6ekE/LVb58eUyaNAnvv/8+D+VYoPCTBEiABEjAkATMvq6i4MaQ0ypZGnX16lUsXbpURUO4cuWKziB37tzq8LYcMitZsqSezwQJkAAJkIDzEaDghoIbQ8xqMwluLMDkB0PxbHPx4kWVJaf0582bh48//thShJ8kQAJOQsDsmxAyDBTcOMlkdKJuPH36FCK6Wbx4caQfqtKmTYtPPvlEiW9q1aqlQlA5UbfZFRIgARIggXgQMPuajIKbeAy+QauKmHjjxo0YPHgwInrxq1OnjhLeVKxY0aCW0ywSIAESIIHkTsDs6yoKbpL7DDZe/yWU+t69e9Uhs++++04d3LZYWb16dbXP1bx5c2TMmNGSzU8SIAESIAEnIUDBDQU3hpjKZhTcCLhnz55h0KBBmDt3rs6xdevWmDlzJrJkyaLnMUECJGBuAmbfhBD6FNyYew46u/WnT59WJ4Ek7NTdu3f17hYpUgQdOnRA+/btUaBAAT2fCRIgARIggeRJwOxrMgpunHfehoSEYOHChfjyyy9x//591VEJO92mTRsVaqpQoULO23n2jARIgARIwJQEzL6uouDGlNMu2Rj96NEjrF69Wh0y+/333/V+p0+fHp9++ik6deqEt956ix4RdTJMkAAJkIC5CVBwQ8GNIWawWQU3Fng7duxQCuWbN2+qrHz58iklc7169SxF+EkCJGBiAmbfhBD0FNyYeAImI9Plx6offvhBbUhs3boVYWFhqvfyg5WEZhA3vB9++CHSpEmTjKiwqyRAAiRAAhYCZl+TUXBjGUnn/Xzy5AkmT56MKVOm6KEzZd3Su3dvDB06lAdznHfo2TMSIAESMB0Bs6+rKLgx3ZRLtgafOHFCHTJbsWIF/vnnH51DiRIl1D5Xu3btkCdPHj2fCRIgARIgAfMRoOCGghtDzFqzC24EoqiWZRNNFk6Wq1u3bvDx8UGGDBksWfwkARIwIQ
Gzb0IIcgpuTDjxkrnJd+7cwfLly5X45syZMzqNrFmzqtPiIr6pUKGCns8ECZAACZCA8xMw+5qMghvnn6OWHt66dQujRo1S6xgJLyCXeMH18vJCjx49KB62gOInCZAACZBAkhEw+7qKgpskmzp8cRwJBAUFYcuWLWp9uH37dljWiClTpkTDhg2V+KZRo0ZIlSpVHN/AaiRAAiRAAklFgIIbCm6Sau5Feq8zCG4sHZL4nCK0efDggcoqWrQoli5dirfffttShJ8kQAImI2D2TQjBTcGNySYdzdUJvHjxAocOHVIbEuKOV8I5Wq5KlSqpDYlWrVrxxLgFCj9JgARIwIkJmH1NRsGNE09OK13766+/MGTIEIjnPstVuHBhjBs3Di1atICLi4slm58kQAIkQAIkkKgEzL6uouAmUacLX5bABG7cuKF+M1q8eDEuXbqkt54zZ060bdtW7XWVKVNGz2eCBEiABEjA2AQouKHgxhAz1JkENwL07t276Nq1KzZt2qT4SiiMgQMHYvTo0UibNq0hmNMIEiAB2wmYfRNCekrBje3jzZLGJSBim3Xr1inxzb59+3RDJVRDs2bN1IZE7dq1+eOVToYJEiABEnAuAmZfk1Fw41zz0Z7e7N27F56envj999/1aq+//rryiPvuu+/qeUyQAAmQAAmQQGIRMPu6ioKbxJopfI8jCYiXG9nfWrRokdrvCgwM1F9XrVo1tc8lIm13d3c9nwkSIAESIAHjEaDghoIbQ8xKZxPcCFQ5kS+ebfr06QOJ4y5X2bJlVXiMihUrqnv+iwRIwBwEzL4JIZQpuDHHXKOVthM4d+4cfH191T+3b9/WKxYqVAjt27dHhw4dIGleJEACJEACzkPA7GsyCm6cZy7GpSfyg8ratWsxbNgwXL58WW9CQgdMmDAB5cqV0/OYIAESIAESIAFHEzD7uoqCG0fPELaf2AT8/PywZs0adcgs4n+fbm5uaN68OTp16oQaNWpADnfzIgESIAESMBYBCm4ouDHEjHRGwY0F7LVr19SPfrt371ZZrq6uGDlyJIYOHQpJ8yIBEjA+gYh/5Bjf2ugtpOAmei7MNT+B0NBQSOxrccO7efNmyL1csgHx3nvvqdNATZs2pYc58w81e0ACJEACMPuajIIbTmIhEBQUhLlz50J+KHz48KGCIqGlRCzs7e2NfPnyERQJkAAJkAAJOJyA2ddVFNw4fIrwBUlIQMKSLlmyBMuWLcP9+/d1S4oVK6bWjB4eHlwz6lSYIAESIIGkJ0DBDQU3ST8LNQucWXAjgOUk2+zZszF48GBY3AJWqVJFLZhKlSpliDGgESRAAtYJmH0TQnpGwY318eUT5yFw7949rFixQrni/fvvv/WOZc6cGa1bt1biG/Eyx9NAOhomSIAESMBUBMy+JqPgxlTTzeHGPn78GOPHj4fMCxHhyCUnmPv3749BgwYxdIDDR4AvIAESIIHkTcDs6yoKbpL3/E0uvQ8ODsa2bdvUIbPvv/9e/c4kfRexdv369dU+V5MmTZA6derkgoT9JAESIAFDEqDghoIbQ0xMZxfcWCBL+It27drpJzPTpk2rXEf36tVLLZIs5fhJAiRgLAJm34QQmhTcGGtO0RrHEpCwjkeOHFEbEqtWrdJDO8pbK1SooDYkRICTLVs2xxrC1kmABEiABBKUgNnXZBTcJOh0cJrGxCvuiBEjVPhpWcPIlSNHDowaNQpdunRBqlSpnKav7AgJkAAJkIBxCJh9XUXBjXHmEi1JHAK3bt1SB7jFw/P58+f1l2bPnh1t2rRRe12vvvqqns8ECZAACZBA4hGg4IaCm8SbbTG8KbkIbgSBhLqYNGkSvvzyS4SEhCgq8kO4uAgsXLiwuue/SIAEjEXA7JsQQpOCG2PNKVqTeAQCAgLw3XffKfHN3r179RfL6Z8PP/xQxcCuU6cOUqZMqT9jggRIgARIwJgEzL4mo+DGmPPKKFb98ccfyrPNzp07dZOKFy+uvOA0a9aMHvp0KkyQAAmQAAkkBAGzr6souEmIWcA2zEhABNoHDhxQ3p3Xrl0L2feyXJUrV1b7XC1btoR4e+ZFAiRAAiSQOAQouKHgJnFmWixvSU6CGwuKEydOoG3btjh16pTKypgxI6ZPn65icDLUhYUSP0nAGATMvgkhFCm4McZcohVJS+DixYvw9fVV/9y4cUM3Jn/+/Gjfvr36Di5atKiezwQJkAAJkICxCJh9TUbBjbHmk1Gt2bFjhxLeyJ6B5XrzzTfh4+OD6tWrW7L4SQIkQAIkQALxImD2dRUFN/EaflZ2EgJPnz6FiG7E683Bgwf1XklkhY8//liJb2rVqsXoCjoZJkiABEjAMQQouKHgxjEzy85Wk6PgRhBJnHbxdCMeb8LDwxW1xo0bY+HChcidO7edFFmcBEjAUQTMvgkhXCi4cdTsYLtmJBAWFgY5PS4bEhs3btQ9zklf3n33XeWGV06Sp0uXzozdo80kQAIk4LQEzL4mo+DGaadmgndM1iorV66El5cXrl+/rrfftGlT5fGmZMmSeh4TJEACJEACJBAXAmZfV1FwE5dRZx1nJnDmzBkVRWHp0qW4e/eu3tUiRYqoA2YeHh4oWLCgns8ECZAACZBAwhGg4IaCm4SbTfFoKbkKbizIRH0sC54LFy6orGzZsmHu3Llo3ry5pQg/SYAEkpCA2TchBB0FN0k4gfhqQxN48OCB+kFr0aJFutc5Mdjd3R2fffaZOg0kLnnpfc7Qw0jjSIAEkgkBs6/JKLhJJhM1AbsZGBiIGTNmYNy4cXjy5IlqWcJgdunSBaNGjUKuXLkS8G1sigRIgARIIDkRMPu6ioKb5DRb2Vd7CISEhOCHH35Qh8y2bt0KEXLLJftadevWVftcEmI9TZo09jTLsiRAAiRAAjEQoOCGgpsYpkfiPUrughsh7e/vj8GDB2P27Nk6ePmhb9asWciaNauexwQJkEDiEzD7JoQQo+Am8ecN32guAhID+9ixY2pDQk6U+/n56R0oV66c8nrTpk0b5MiRQ89nggRIgARIIHEJmH1NRsFN4s4XZ3qbCITHjh2r9gvkRxS5MmTIAE9PTwwYMADp06d3pu6yLyRAAiRAAolAwOzrKgpuEmGS8BWmJ3Dnzh0sX75c7XWJBxzLJb83tW7dWu11vfbaa5ZsfpIACZAACcSRAAU3FNzEceokbDUKbv7jKSEuOnTogJs3b6rMPHnyQE7dN2jQ4L9CTJEACSQI2ZFoAABAAElEQVQqAbNvQggsCm4SdcrwZSYnIKfJN2zYoDYkdu3apfcmVapUaNKkiToN9P7778PV1VV/FlMiNDRUnSSSE+m8SIAESIAE4k7A7GsyCm7iPvas+S+BS5cuYfjw4Vi9erWORMJRjx49Wu0j2Lo20SszQQIkQAIkkGwJmH1dRcFNsp267HgcCMghs0OHDql9LllHPnv2TG+lYsWKap+rVatWyJIli54fWyI4OBipU6eOrRifkwAJkECyIEDBDQU3hpjoFNxEHobHjx+jd+/eSn1seSIuo6dMmaJOsVny+EkCJJA4BMy+C
SGUKLhJnLnCtzgfgcuXL0PiXy9ZsgTXrl3TO5g3b14VDlJEssWLF9fzo0vIxka/fv0wbNgw5MyZM7oizCMBEiABErCBgNnXZBTc2DDILGITgSNHjijvNj///LNevnTp0pg4cSIaN27MUJg6FSZIgARIgASsETD7uoqCG2sjy3wSiJmAiG3WrVunxDf79u3TC0uIqaZNmyqvN++99x5cXFz0Z9El5s2bB1l/1qpVK7rHzCMBEiCBZEWAghsKbgwx4Sm4iX4Y5HR9165dcf/+fVWgSJEi6ke/GjVqRF+BuSRAAg4hYPZNCIFCwY1DpgYbTUYEwsPDId5uFi9erLzfBAUF6b2vWbOm2pD45JNPrIZ0mDNnDnr06IG9e/dyM0InxwQJkAAJ2EfA7GsyCm7sG2+WjpmACHq3bduGQYMG4fTp03ph+dHDx8cHVapU0fOYIAESIAESIIGoBMy+rqLgJuqI8p4E7Cdw7tw5+Pr6qn9u376tN1CwYEHlPbF9+/YoXLiwnh8xIeGqJDpDu3bt1CG12AQ6EesyTQIkQALORoCCGwpuDDGnKbixPgz37t1TopuNGzeqQilSpED//v0xZswYpE2b1npFPiEBEkgwAmbfhBAQFNwk2HRgQySAhw8fYtWqVSrk4/Hjx3UiGTJkQMuWLZX4plq1apFOl0uYqnTp0qmy3t7e8PLyivW0kN4wEyRAAiRAAoqA2ddkFNxwIjuCgISuFE98I0eOhPzwYblatGiBsWPH4pVXXrFk8ZMESIAESIAEdAJmX1dRcKMPJRMkEG8Csp7cvn27OmS2efNmyL3lEm83HTt2VN5v3NzcLNnqs1KlSpB9MTkg/t133yFHjhyRnvOGBEiABJILAQpuKLgxxFyn4CbmYZCTa8uXL0evXr3w5MkTVbhMmTJYtmwZXn/99Zgr8ykJkEC8CZh9E0IAUHAT72nABkggWgKysSA/cq1YsQKPHj3Sy4hbXdmQaNu2LXLlyqXyhw8frkI9pEqVClWrVlUufLkZoSNjggRIgARiJWD2NRkFN7EOMQvEg4C/vz+mTp2KSZMmQUIFyCVrju7du2PEiBHIli1bPFpnVRIgARIgAWcjYPZ1FQU3zjYj2R+jEJAD4LLHtWjRIvz999+6WZkzZ0arVq3UXpcIbeRg+MmTJ1GhQgW15syYMSPk0DijM+jImCABEkhGBCi4oeDGENOdghvbhuH69etqQfPTTz+pCq6urmrjbOjQoWpRY1srLEUCJGAvAbNvQkh/Kbixd9RZngTsI/D8+XNs2rRJnQbauXMnRCwrl3xXN2rUCJ06dULFihUh4SHlpFDq1KkhHnG4GWEfZ5YmARJI3gTMviaj4CZ5z9/E6v3du3cxevRozJ8/H2FhYeq1mTJlguwb9O7dG1FPJieWXXwPCZAACZCAsQiYfV1FwY2x5hOtcT4Csq915MgRtc8lXp4tB8Glp+XLl1f7XK1bt0aTJk1w6NAhtQ8mYm/xujhs2DB6dXa+KcEekQAJxECAghsKbmKYHon3iIIb21mHh4dj7ty58PT0hISnkKty5crK242cpudFAiSQ8ATMvgkhRCi4Sfh5wRZJwBqBa9euYenSpWpT4sqVK3qx3LlzI0uWLJAY2ZYfwGQzYtSoUepHMMa71lExQQIkQALREjD7moyCm2iHlZkOInD27Fn1Y8f69ev1N+TPn1+Fp27Tpg1Spkyp5zNBAiRAAiSQ/AiYfV1FwU3ym7PscdIRCAgIUCGjFi9ejL179+qGyGEy8eB8+PBhBAcHq/y0adNCfu9bt24dsmfPrpdlggRIgAScmQAFNxTcGGJ+U3Bj/zCcP38eHh4e+PXXX1XlNGnSYPz48ejTpw/Vw/bjZA0SiJGA2TchpHMU3MQ4xHxIAg4hICJZ2YiQDQmJZS1ecKK7ZDOiWrVq+Pbbb7kZER0g5pEACZDA/xMw+5qMghtO5aQgcODAAXVgx7J3IDbIqWQJPVWvXr2kMInvJAESIAESMAABs6+rKLgxwCSiCcmSwMWLF+Hr66v+uXHjRrQMRIhjCTH19ttvR1uGmSRAAiTgTAQouKHgxhDzmYKbuA2DnI738fFRbvpCQkJUI7Vq1cKSJUtUyIq4tcpaJEACUQmYfRNC+kPBTdRR5T0JJC6BU6dOoW/fvtizZ48ebiqiBeLpxt3dXYWlql69esRHTJMACZAACfw/AbOvySi44VROKgISEmDDhg0YMmQI5PCO5apbt64S3rz22muWLH6SAAmQAAkkEwJmX1dRcJNMJiq7aVgCz549w/Dhw/H1119DPOBEd8le15dffqnWoPTqHB0h5pEACTgLAQpuKLgxxFym4CZ+w3Dy5Em0bdsW8ilXhgwZMG3aNBVHM0WKFPFrnLVJgARg9k0IGUIKbjiRSSBpCIhb3XHjxmHr1q1wdXVFUFBQjIZIeIdBgwapcA/cjIgRFR+SAAkkQwJmX5NRcJMMJ63BuiwHdRYsWABvb2/cv39fWSd7BhJiasyYMShYsKDBLKY5JEACJEACjiJg9nUVBTeOmhlslwRiJnDnzh3MnDkTM2bMgHh2tia2sbQi+1wi7v7xxx/p1dkChZ8kQAJOR4CCGwpuDDGpKbiJ/zBIjEzZNJswYYJa6EiLjRo1wsKFC5EnT574v4AtkEAyJmD2TQgZOgpu7J/At27dws8//2x/RdZI9gTkFPnRo0exZcsW/RS5iG1EQCM/dMnz2K5SpUqhX79+ygVvbGX5nARIgASsEWjWrBkk9KyzXGZfk1Fw4ywz0fz9ePLkifKWO2XKFAQGBqoOyf8rJET10KFDkTlzZvN3kj0gARIgARKIkYDZ11UU3MQ4vJEeiqfd5cuXgwdzI2HhjZ0E/Pz81IHviN4SZa9LojDYss8lZSWcaa5cuex8M4uTAAkkZwLy/xc3NzfMnj3b0BgouKHgxhATlIKbhBsGicvu4eGh/8CXNWtWzJkzBy1atEi4l7AlEkhmBMy+CSHDRcGN/ZNWTl40aNAAJSu/ZX9l1kj2BOSUT3hYKEI1gY38Ex6qfco/WjpE83ITFhKspYO1T+1ZeJhVXvlLlkX6jJmsPucDEiABEoiOgPz/5/zxw3jw4AGyZcsWXRFT5pl9TUbBjSmnnVMbffPmTXypuflfvHixfnBH9hC8vLzQvXt3pxLsOfVAsnMkQAIkEAcCZl9XUXBj+6DPnTsXvXv1wKBOH9leiSVJIBoC4Zq4Jlj2uUL/3e8KCdE+w0Igh8FDgkO1Z9pel3oWCmtHzXLnzoV8efMBDMwQDWFmkQAJRCVw+cZdrPr+oE3Cvqh1E/OeghsKbhJzvll9FwU3VtHE6YG/v7+Kizlr1iy9vghuRAHoTBvueueYIAEHEzD7JoTgoeDG/kkigpsufQfif6t22F+ZNUjATgKBAf548vAB/B7cx/2b1xAU6I8LJ39HBvesePeTNshT+BU7W2RxEiCB5EwgNDQMn1ctSMGNwSYBBTcGGxCaoxP4888/1R7Ctm3b9LzChQursJiyl8AwlzoWJkiABEjAaQiYfa+Lghvbp6IIbr5fvRBLxve0
vRJLkkA8CYRpf5M+f/5ceVN8+vSJJsQJw927d5EqVSrkzJkDhQsXQQoXqm7iiZnVScDpCYjgpnZ7bzwNeG7ovlJwQ8GNISYoBTeOGYZdu3ahQ4cOuH79unpB7ty5sWjRIjRs2NAxL2SrJOCkBMy+CSHDUqtWLQRpXjUmTZqEmjVrmnKkEvu7goIbU04TGk0CJEACJKARoODGmNNg+vTpWL16Nd58801MmzbNmEbaYFVir8lsMIlFEoiAhJzw9PRUoTEtTVauXFmFnxIBPy8SIAESIAHnIWD2vS4KbmyfixTc2M6KJUmABEiABIxFgIIbY42HVWvMvrA8evQoevToofp34MABpEyZ0mpfjfyAG3aOGx2JrSlx2JcuXaq/5PPPP8fUqVORMWNGPY8JEiAB6wTM/l0hPaPgxvr4WntCwY01MswnARIgARIwOgEKbow5QhTcGHNcaFVkAhIWc+3atRg6dCiuXLmiP2zUqBEmTpyIsmXL6nlMkAAJkAAJmJeA2fe6LIKbevXqwdvb27QDkRi/i1BwY9rpQcNJgARIINkToODGJFPA7AtLCm5MMtEMYObGjRvRtWtX3Lt3T1kj7qF9fX3Vj/AGMI8mkIChCZj9u0LgUnBj/xSj4MZ+ZqxBAiRAAiRgDAIU3BhjHKJaQcFNVCK8NzIB8Y45Z84cyA+ajx49UqZKaCnxojt69GjkzZvXyObTNhIgARIggVgImH2vi4KbWAY4wmMKbiLAYJIESIAESMBUBCi4MclwmX1hScGNSSaaQcy8f/8+unXrhvXr1yuLUqRIgb59+2Ls2LFwc3MziJU0gwSMR8Ds3xVClIIb++cVBTf2M2MNEiABEiABYxCg4MYY4xDVCgpuohLhvRkIiNhmwoQJ+Oqrr1SIWrFZ9g8GDBigwk+5u7uboRu0kQRIgARIIAoBs+91UXATZUBjuKXgJgY4fEQCJEACJGBoAhTcGHp4/jPO7AtLCm7+G0umbCPw4sULrFy5Ej179oSEm5KrdOnSWLZsGSQ2Oy8SIIGXCZj9u0J6RMHNy+MaWw4FN7ER4nMSIAESIAGjEqDgxpgjQ8GNMceFVtlG4OrVqxgxYgRWrFgB2VeQK0eOHBg1ahS6dOmCVKlS2dYQS5EACZAACRiCgNn3uii4sX0aUXBjOyuWJAESIAESMBYBCm6MNR5WrTH7wpKCG6tDywexELhx4wY6duyInTt3qpIpU6aEl5cXhg8fzo2yWNjxcfIjYPbvChkxCm7sn7cU3NjPjDVIgARIgASMQYCCG2OMQ1QrKLiJSoT3ZiRw/PhxDBo0CD/99JNufvHixZUXnKZNm0I86fIiARIgARIwPgGz73VRcGP7HKPgxnZWLEkCJEACJGAsAhTcGGs8rFpj9oUlBTdWh5YPbCAgp9LmzZuHgQMHIiAgQNV4/fXXlbebMmXK2NACi5BA8iBg9u8KGSUKbuyfqxTc2M+MNUiABEiABIxBgIIbY4xDVCsouIlKhPdmJrBjxw4VUurkyZN6N9588034+PigevXqeh4TJEACJEACxiRg9r0uCm5sn1cU3NjOiiVJgARIgASMRYCCG2ONh1VrzL6wpODG6tDygR0ELly4AA8PDxw8eFDVSpMmDcaOHYu+fftCPN/wIoHkTsDs3xUyfhTc2D+LKbixnxlrkAAJkAAJGIMABTfGGIeoVlBwE5UI781OICwsTIWYEm+54kXXcomnm/Hjx6NkyZKWLH6SAAmQAAkYjIDZ97oouLF9QlFwYzsrliQBRxMIDw/H9dsPcPbKLdy69whPA56rV2bOmA4lCuVB+ZKF4ZY2taPNsNq+0e2zajgfOC0BCm5MMrRmX1hScGOSiWYCM2WjbPLkyRg5ciSCg4OVxTVq1ICvry+KFi1qgh7QRBJwHAGzf1cIGQpu7J8fFNzYz4w1SMARBGRdcufKBVw7+zeuXziNmxfP4OalC/D3e4gX2kZFOvfM6D5+DkpWesMRr4+1TaPbF2sHWMApCVBwY8xhpeDGmONCq+JPIDAwEDNmzMC4cePw5MkT1aAc3unatavaY8iVK1f8X8IWSIAESIAEEpSA2fe6KLixfTpQcGM7K5YkAUcR2HvkLyxc+xN+/+siHvr5W31NFvd0aFrnDfRqXR/5c2e3Wi6hHxjdvoTuL9szDwEKbkwyVmZfWFJwY5KJZiIzT506hbZt2+LEiRPK6vTp02Pq1Kno3LkzY7GbaBxpasISMPt3hdCg4Mb+OUHBjf3MWIMEEppAYIA/Zg38HH8d+sVq024ZM6H3lEUoXflNq2Uc9cDo9jmq32zX+AQouDHmGFFwY8xxoVUJR+DBgwcYM2YM5syZg5CQENVwhgwZMGjQIPTv3x+yv8CLBEiABEjAGATMvtdFwY3t84iCG9tZsSQJOIKAeI2ZsmQLfJZsxosXtr2hUJ5smDrEAzUrl7WtQjxKGd2+eHSNVZ2AAAU3JhlEsy8sKbgxyUQzmZlyUnv06NHKBbR82crVoEEDfP3118ibN6/JekNzSSD+BMz+XSEEKLixfx5QcGM/M9YggYQmIIKWGf064PSRA6ppl5SuyJo7H7Lmyo3Lf51ASNBzJLXgxsj2JfR4sD3zEKDgxphjRcGNMceFViU8gYsXL2L48OFYs2aN3niePHng7e2NDh06wNXVVc9nggRIgARIIGkImH2vi4Ib2+cNBTe2s2JJEnAEAYugZYrvFuTPlRVvVyqNNyoUR4nCeZE1U3oEBYfi/NXb+HHfcWzdexQBQf8K14sWyIll43uhZBHH/iZndPscMSZs0zwEKLgxyViZfWFJwY1JJppJzZT/Ptq1a4dz586pHmTJkgWzZ89Gy5Yt6e3GpGNKs+NGwOzfFdJrCm7sH3sKbuxnxhokkNAERAS8f9NqhIeFokCJsshfvBTSax5tbl48h0ndP4Pf/TtJKrgxun0JPR5szzwEKLgx5lhRcGPMcaFVjiNw+PBheHp64pdf/vNUV6ZMGUycOBGNGjXivoLj0LNlEiABEoiVgNn3uii4iXWI9QIU3OgomCCBJCNw4NgZhIWF4a1KpeCqhV61dn3/yzH0Hb8Ej54EaGtloOdn9eH1xcdwcXGxViVB8o1uX4J0ko2YkgAFNyYZNrMvLCm4MclEM7GZAQEBGDp0qIrHbulG8+bNlYvo7NkTL4ak5d38JIGkIGD27wphRsGN/TOHghv7mbEGCSQWAaMIbqz11+j2WbOb+c5DgIIbY44lBTfGHBda5VgCLzS/+Vu3bsXgwYNx+vRp/WXy94mPjw+qVKmi5zFBAiRAAiSQeATMvtdFwY3tc4WCG9tZsSQJJDWB8LBwDJ6yAr6bflamVNM84ayY0AuZ3I0RmtXo9iX1+PH9CU+AgpuEZ+qQFs2+sKTgxiHTgo1GQ2D37t3K9fO1a9fU01y5cqkQU40bN46mNLNIwLkImP27QkaDghv75yQFN/YzYw37CAQHBeHmpXP459Z1hAQHaZ5bMiN/iVLImjOPXQ3dv3kdNy6eRcBTP6RJ64acBQojb5FicE2V2q52IhYWd7L3b17TPMmcRaC/PzJ
myYpCpcohU1ZjiG2NLmgxun0Rx5pp5yRAwY0xx5WCG2OOC61KHAKhoaFYsmQJRo4ciTt37ugvFQ+6Y8eORdGiRfU8JkiABEiABBxPwOx7XRTc2D5HKLixnRVLxo3A86BgnLt8G1dv30dwSCgyZ0yHssUKIHeOLHY1eE2rf+bSTfg9DUTaNKlQNH8uFC+cB6lTxT0cqYRr+vvidVy7dR8iBM+fKxvKlSiotR/3PTO7OhWHwqu27Ue/Cb4IC3+BskXzYc3UfshlJ8s4vNbmKka3z+aOsKApCFBwY4phAsy+sKTgxiQTzUnM9PPzQ79+/dQmmaVLHTt2xLRp0+Du7m7J4icJOB0Bs39XyIBQcGP/tKTgxn5mrGEbgStn/sL2lfPxxy8/IVATyUS8XFxSIn/J0qj1YWtUb/IJ0rqli/hYT0soo0M/rMfOVYtw4/wZbdMgXH8miUw5cqNGk09Rt1WnGEUyh7ZvxqYFU/BC+yO+6vsf4MMu/XDq4F5sXjgdl//6I1K7qdKkRaXaDfBx90HIma9gpPeJQGfr4lk4uG2dyq9W/yN80LlvrC5vxZ3uhrmT8fuubRBfuTU/+gwN23WL1HZ0N3ERtDzze4TlE71w9fQp9a5K79THR90GIHXq6DdZxLYfls7F/q3fQgOB3IWLod3QsTYJouJiX3T9ZB4JxJUABTdxJefYehTcOJYvWzcHAX9NyDtlyhRMmjQJkpYrVapU6NGjB7y8vJAtWzZzdIRWkgAJkIDJCZh9r4uCG9snIAU3trNiSfsI/HnuGuau3o7tB07A71lgpMopXVKgnCa6ad2kJlo0eAvp3NJEem65EUHM+p2HsGDtT5ow5gbCtf2XiFfu7JnRskF1dP30PWTPminiIz0tIp/h07/BPi1sk5u2x+PdqzleL/MKFm/Yg4Vau7cfPNbLSiJXtkzoorXX+ZO6cEsbeU/on0dP0H/iMpy9egvp06aBd8/mePv1MpHqR3dz9/4j9JmwFFdu3UNGt7QY368VKr9aLLqiseat3LIPAyYtVYKbiqUL45tJvSP1/fdTF+A1YzUePwtAujRpMKjTB6hfo6LVdu8/9MPgyStw+vItFabqo9pV0NejcZyFTLHZZ9UQPiCBOBCg4CYO0JKiitkXlhTcJMWs4Ts3b96Mzp074969ewpGoUKF4Ovri3feeYdwSMApCZj9u0IGhYIb+6cmBTf2M2ONmAmEhgRjx6rFSszy3P9pjIVTaOKT91p0RJtBo18q9+ThP1jhMwK/79yK8PCwl55HzMhfvDQ+9/4KhUuVjZitp3evW47l44cpYU3Nj1ohV6Ei2Dh/KkKeR94o0StoiQIly6LnxPnIVbBIxGyc3L8Hs4d0Q1DAM+QrVgr9ZyxHttx5I5WJenNP86IzpWdr3L16CW4Z3NFr8tcoU7V61GIv3cdV0PLXb/sxd1h3PHv0D1Jp3oDaDh6Dmh+2fKl9yTi6+0d8/WU/BD57grTpM6LTqCmoUqdRtGWjZsbVvqjt8J4E4kqAgpu4knNsPQpuHMuXrZuLwN27d+Ht7Y0FCxZARK5yZcqUSYW07t27N9zc3MzVIVpLAiRAAiYjYPa9LgpubJ9wFNzYzoolbSMgApev1+3C5MWb8DQgKMZK2vYWPv/4PYzTBChRLxG3DJ22Cpv3HFHikqjPI96Lp5dZIzpp3mkKRcxWaRHtdPKarQl/TmqCG1dMHNgWP+77Az/sPy7np6K9RBDUVhMDje7dMpLoRg6UjZq1FvPX7lR1239YCxMHtIFLSpdo27Fkbtp1GD3HLMJzzZbKZYtgxcTeyJbF/kPqwnagJrZZ9f1B1XSrRtUxZbAHXFOmtLwKEtZp4qJNmLFiG0LDXuCVgrmwdFxPlCzy8h6csBk951s1XiJmerV4QfiO646CeXPo7dmTsMU+e9pjWRKIjQAFN7ERMshzsy8sKbgxyERKhmY8ePAAX3zxBdat+/ckuyDo27cvxo0bx42xZDgfnL3LZv+ukPGh4Mb+WUrBjf3MWMM6AYu3lA3zpyBME97IJV5oqtZpglfKV0LGzFkhIpzr50/j1K97Ne8yJzQhyGdo7zUhUqPPAwM0Ly3DcXDLt9of/i80DzIpUeL1N1C17gfImb8Qnj1+hJMHd2seY75HsFZWrgIlyqCnz0Lk0kJNRb0iCm5yaPX9/rmv/eEehtdq1EWFGu8ha6688H/yCCcO7MbhHVt0IU4NzRONx9BxkcJWiQeZqb3a4dKfx+CaOg0+/3IaqtX/MOorI93v27wWvmMHKyYlXq+GPlMWIb175khloruJq6BFNk5+WD4P6+f4qHdmyZVPE/ksRNGyFSK95tblC5gxoBPuXLmgGDfw6I6mX3jC1fW/DY5IFaLcxNW+KM3wlgTiTICCmzijc2hFCm4cipeNm5TA2bNnMWTIEGzcuFHvQYECBTBmzBi0bt0aKSP8uKAXYIIESIAESCDeBMy+10XBje1TgIIb21mxZOwEROwxa9UPmLhwE4JD/xVNixeaD9+tjCqvvoIsmdIrEc6ZCzfw06FT+OPMZbRuVAOTNdFIxCsgMAiDp6zAmh8PKmGLCGDefK0kPnqvCgrly4HHfs/w068nsXXvMfg//3cfTTzmLNHEIoXz5YzYFCIKblK5uigxyeXr95Ajqzs+rvMGqpQvBvcMaXHl5gOs+eEAfv/zkvKkk04LW/XVsA74SCsT8dp/9LS25zULT/yfKzHL2in9YhSoyPt7aWKbDZroxkVTGA3q+AH6tW8cq9fniO+UtLBd9f0BjJixSjHMljkDloz5Am9WLBW1KPye+KPb6K8VI3lY/+0KmOWlebnOmD5SWemveLcRhlnc02HeqC6oXe3VSGVsvbHHPlvbZDkSiI0ABTexETLIc7MvLCm4MchESqZmyA99q1atUq6fHz/+1y1fqVKlsGzZMlSpUiWZUmG3nZGA2b8rZEwouLF/ZlJwYz8z1rBO4PSRg5g1qCv8/R5q7ltdVGimNoP+hyw5ckVb6eKp43hw+zre0MI8Rbx+/WEDFo8eiJCg50iZKjWadOyFxh17RhK+SPnjP+/U/ij3xBNNQCPecmo1a6O85bhqdSJeEQU3ku+eLYe2qTABr9euH7GY5kknHHu+XYZV00YjNDgImXPmUR5sCpYoHancxgXTsEnzkCMhrqrW+1DzrjPdasim4KAg5W3m+J4flajlk17D0NCjW6T2rN3ER9ASGOCvRD6Hf9yoREslK7+F7uPnIJPWd7kCNI82C0f1g9glV7m33kG3sbORIVPsQiBVQftXfOyztMFPEogPAQpu4kPPcXUpuHEcW7ZsfgL79++Hp6cnDh06pHemQoUKKvTU+++/r+cxQQIkQAIkkDAEzL7XRcGN7fOAghvbWbFk7AREjNLRaw4ePQlQ4pJGNStifP/WyKWJbqK7jv11Cdfv/IMPNSFNxOu7HYfQb/wSBGpildTa4ab+WoijXm0bvhTmaPv+P9Bvgi/uP3qqwiF5fFALYzVvOalTuerNRRTcWDKrVSiOWcM7auKdyOKcp1oYpr5a6KfNe35XRUWossC7Wy
QvN35P/bW9sdk4cPyssm3qEA8tLJZ1b8xnLt1EywHTcPPeI+TIkhErffqgYunIXqEtdsmn2HD09GUEBYWo7Ofa55Wb97DnyF/47cQ55bUmc0Y3jO7VUoXjcnGJ3rvOmUs34DFsNi5p4iLXlCnQt20jeHb8UPfGc/LMVbQfPlvxFyHS4E4foVebBrEKgRLKvoh9ZpoE4kqAgpu4kkvkemZfWFJwk8gThq+LlsDNmzfRqVMnbN++XT2XE2jDhg1T8ddTazEzeZGA2QmY/btC+FNwY/8spODGfmasET0BEZYsHNkbR7QQUHKVrlIdPSbN1wQcWaKvYCVXhCCzPLvg79/2qRJvNvoYHbx8kFqL1xzd9dMaX6ya6q08uYg3nX5fLXsptFREwU1K11T4rP9I1GnZMbrmNE83j/GV5vXl3NFD2h/vruiohVh6u/Enkcpe+vskpvfx0IQ+9yDeY/rPXI4CxUpGKmO5uXLmL0zr0w5+9+9YFfBYykb9jK+g5e6Nq5g18HNcP/e3EkBJn5v3Ga5O0W9ZNEOF/QoPC4V4/ek1eRGiCoui2hP1Pr72RW2P9yRgLwEKbuwlljjlKbhJHM58i3kJyKGe9evXK483Fy5c0DtSt25dJbx57bXX9DwmSIAESIAE4kfA7HtdFNzYPv4U3NjOiiVjJvA8KFiFTdq0+1+xSo3XS2HR/7ppXm0yxlwxylMRdHQYPgc//35aPfm0XjUVNiltmuh/S1qkha8aOXON8qgj3nRWaYKWciUK6q1GFdzkz5UVyyf0ilRGL6wlRNTSdshMJRoqlDc71k0f8JLXnDmrtqtQTGHhL9CwRkXM+7JLJFFOxPYWrN2hwlBJeKfoBDwRy0r6z3NX0bzfVDx4/CzqI2TK4IbGtV5H1xZ1UfqV/C89j5qxadcRTZC0RHnEcU+fFjM1LzcNa1aChOvqPGo+9h09o6p8WLsypg9pjwzpYw/bmpD2RbWX9yRgLwEKbuwllkTlzb6wpOAmiSYOX/sSAdkYmz9/PgYOHAh/f3/1vGLFisrbTbly5V4qzwwSMBMBs39XCGsKbuyfcRTc2M+MNaIncP38GUzu2VoJS1K7pccX42ajYq260ReOIffin39gel8PPH34AOm0sEu9tfBLpbQwTNauxw/uYWqftrh2+k/lQaZl/1F4v1WnSMUjCm7yFCmOAbNWInuefJHKWG7Ey82ar8Zix4r5KuuDzv3QtNsAy2P1GfQ8EHOH9sAfP29X7/y07wjUb9M5UhnLzdbFs/Dd7InKG06Vuo3RefQMq+IhSx3LZ0IIWk4e2IP5Xr2V16HUbunQdvBYJYJaOKovAjRxkYxVB6+JeLNBU8trbf5MCPtsfhkLkkA0BCi4iQaKAbIouDHAINAEUxAICQlR+wve3t6QcNZyice+tm3bQn5gLVjwvx9YTNEhGkkCJEACBiRg9r0uCm5sn1QU3NjOiiVjJnD64g182n8a7jx4jAxuaZQIpd7b9guij/99Ga08pyvBiXhyWTq+J96KJmySxZp7/zxGq4Ff4cS5a5DQU6N7tUCXT//bV4squGnXpCZ8PNvqnl4s7Vg+7z/0U+39cfYq5P3LJ/ZGtQolLI/V598Xrmtea6bjttbXXNkyYfVkEfkUilRGbkQ8JF5mRNgiXmQme7ZDq8Y1XioXMSMmQYv0r3TR/Pj8k/fwSb03kSb1f558IrZhSYdqYeEnzN+ghfn6ESIOKl4otyaC+gKrtx3A/G93qrxSRfJi2YSeKJI/ei/blrYsnwlpn6VNfpJAXAlQcBNXcolcz+wLSwpuEnnC8HWxErh48SI8PDxw4MABVVY83Ejs9f79+zP2eqz0WMCoBMz+XSFcKbixf3ZRcGM/M9aInsCvP2zE11/2U55mCpYqp0IxZc4e2aVt9DUj5+7fug6LvQdoMZ1DUbjsa5rHmqXIlDV75EIR7kQgs2LSSOxeu0TlvqV5o+ky+qsIJYCIgpuK79ZHjwlzXwpPFbHCjysWYo0WVgp4oXnC6YTWnt4RH6v0L5tWY+nYIQgLDUGpKm8pYVC6DO6Ryj3z07zl9O+A88cPwzV1GnTSvOXYI2xJCEGL8Nm2ZDY2zp+ibE2fOavqu3jccXFJibqtPkfzXkNi5BGpUxFuEsK+CM0xSQJ2E6Dgxm5kiVKBgptEwcyXOBGBJ0+eKM82U6dORWBgoOpZGs2zX58+fTB06FBkzhx96AInQsCukAAJkIDDCJh9r4uCG9unBgU3trNiyZgJrN/5G3qNWaQ8zZTXPMys0kQoObPZvx779seD6D1usQqdVLF0YXwzqTeyZ81k9eWyfzN02jdYvH6PKtOifjXMGvHf4a6IghsRrEwZ5IHWTayLXgICg9BBE8nsPvyXEg4tHN0Ndd4qH+n90mbP/32NjbuPqNBZQzs3RV+PRpHKyM0hzVtOu//3llO0QE6sndLvpTBWUSuJSOfwqQt4HhKKF5pIxj/wOS5cvYvdv53Cn+evI1w73C79+Oi9qpiohevK5J4+ahOR7h/5PUW3Lxeq/siDIvly4N4/fvB/HqwERbNHfI73q9sujEpo+yIZyxsSsJMABTd2Akuq4mZfWFJwk1Qzh++NiUCYpqqVDTEvLy8EBwerom+//TZ8fX3xyiuvWK26bt06lC1bFqVLl7Zahg9IICkImP27QphRcGP/zKHgxn5mrBE9gS2LZmL9HPHk8gKVajdA9wnz4arFp7b32rRwOjbOm6zaqVKnMbqOmxNrO7u0sFIrJo1QXmTKVH0bg+atifTaiIKbmh+1gsfwCTEKZCOWr/1pB21TYUyk9uRGwjVN7dUGd69esuqJ56/f9mvhsTojUAuTlfeVkkqEZM2zzksv0DISStASGOCPxV/2x5Gf/g33ZXmXhP3qNn5OjIImS9noPhPKvujaZh4J2EKAghtbKCV+GQpuEp853+gcBCSM9ahRo7BkyRLIDy5yZc2aFSNGjMAXX3wBEeHwIgESIAESsI+A2fe6KLixfbwpuLGdFUvGTGDa0q0Yv3CDtscENK5ZEQvHfAHXlPbvb01ZvAUTF29U7XzwbmXM9+4SazuLv9utRDciRqlZuTS++2qgbmxEwU1azSPM3FFd0Pid1/XnURMRy6dPm1p7f1dE56lHBEa9xy5GkCaMqfrqK1ipCYMyu2fQm5N16dh56zHzmx9UX2LzrKNXtJII1t6zcss+jJ3/HfyeBSrRTc9WDTCsS1Or3nosTf2lCXXaayKiK7fuW7I0pikwsP0H6OfRONb6eqUYEvGxL4Zm+YgEYiRAwU2MeIzz0OwLSwpujDOXaMnLBP7880+0a9cOx48fVw/Tp0+PyZMno2vXrsoddNQa4iJarufPn3PDLCoc3icpAbN/Vwg8Cm7sn0IU3NjPjDVeJiB/fFvCMIngplbT1krU4uLi8nLhGHKitvNOs7ZoN2yc5okl5naUdx0tRJJ4m3ml/OsY4bs50lsiCmhsaTNi+Xc/8dDc5o6L1J7ciPB2+fjh2Lt+ufq+r9e2Gz7tPUy3Vfqyaoo3flq9SNWt0/JzfDZgpP78pQajyUhIQcv1C
2dVqK5/bl1Xb8qcM6/mledrFC1bIZo325aVkPbZ9kaWIoHIBCi4iczDKHcU3BhlJGiHWQmcOnUKQ4YMwffff693oUiRIhg3bhw+/fRTu9YSegNMkAAJkEAyJWD2vS4KbmyfuBTc2M6KJa0TkL2c0XPWYc7q7f+JSwZpYZti2ZeK2mLUdjw+qIlJEv4plnZE/NJzzNcICQ1H5bJF8cOC4XrTEQU0bprgZp4moGlYs5L+PGoiYvl0aVJhgebhJjrBzc27/6iwUmcu30LGdGmweFwPvFOlrN7cXS3cVCvPr3BSC3Ulzxd4v+wpRy9sY0L4zFuzE/+b+63yAJQ/V1bNk1BflCoaffj3iM0u3bhXEyWtVIwk/33Na8+ckZ8jU8aYPeREbCO2dHzsi61tPieB6AhQcBMdFQPmmX1hScGNAScVTYpEQDzcSEgp2QCTH+DkqlevHhYtWoR8+SIvEmSj7Pr16+jWrRtmzZoVqR3ekEBSEjD7d4Wwo+DG/hlEwY39zFjjZQLyh+ia6WOwY+UCbUPiBWwRtbzcCtRp7kjtfKwJbobGLrg5tH0zFo7so8JZFX31dYxc6njBjdh/cv8ezB7SDUEBz1CgRBn0m7EcWXPmVl17cPsmpvZui1sXz8JNCzXVa/LXKFO1enTdtpqXUIIWGZ/vfedgg+Y5SERJcqV2S48OXhPtCnEV1dCEsi9qu7wnAVsJUHBjK6nELUfBTeLy5tucl8Du3bvh6emJY8eO6Z2sXLkyfHx88M477+h5TJAACZAACVgnYPa9LgpurI9t1CcU3EQlwvu4EJD9E+9Z32Lu2h1KcGOrUCbqu6K20/7DWpg4sE2sgptNuw6j++ivVTirxBLciK2jZq3F/LU7VZ87NXsX4/q20r3FbN17FD1GL0RAUIgmAiqCFRN7I1uWyCHVo/bflvtrmpea5v2n4tL1e8pLzfShHdGiwVsxVn385Bm6jVqIXYf/1MuVKpIXvuN74JUC/+7H6Q/imYiLffF8JasnYwIU3Jhk8M2+sKTgxiQTjWbi8OHDytvN2bNnFQ2JtS6imlatWqnT77dv30bhwoVVCCo3Nzds3rwZderUITkSMAQBs39XCEQKbuyfShTc2M+MNV4mIH+cb9ZCQW1aMFUJbqrU1UJBjY09FFTUlqSdTQumaW1NU+1Urfchuo6ZGWP4J2ljz/qVWD5uqCbYCUPJym9h6IJvIzUd0WONLWKgiOWtebiRFzzze6SFlWqHS38eQ6o0adHlfzNQpU4j9W7xurPIuz9Cg4NQ4vVq6DNlEdK72xfzO6EELScP7sX84b3g7/dQ2Wb5V478hTQh0CIULBG3MJcJZZ/FHn6SgL0EKLixl1jilKfgJnE48y3Jg4CsjVavXo1hw4bh6tWreqcbN26MCRMmqHDVeiYTJEACJEACLxEw+14XBTcvDanVDApurKLhAzsIyNprmu9WLRTUJiU++bB2Zcz7MvZQUFFfIe1MXrwZk323qHaavlcV80Z11kUsUctb7pdv/hmePssRFv4C1SuVxMaZgyyPENFjTUJ6uJEX7D96Gh5DZ+GJ/3OUKJQHa6b2Rf7c2TXvM2HoO24J1vz4K1y0yA2DOmqhm9proZti8dSjGx1D4pl/INoOmYX9x85ov50Bwzo3Q1+Pf/fUoqsmtkxYsBGztNBWwifiJeM0fUh7ZEjvFjE7Xml77YvXy1g52ROg4MYkU8DsC0sKbkwy0WimIhAYGKg2w2Sj2XJ9/PHHkEX/xo0b0a9fP/j7+6tHWbJkwfnz55EtWzZLUX6SQJIRMPt3hYCj4Mb+6UPBjf3MWCN6Avu3rMPi0QMQHhaKouUqoe90X7hntf/7be+GVVg2bohqp3jFqug7bUmMQhXZxIgYzkpEOt3Hz4lkZEQBTUIKbuQlmzSh0cZ5U7QNlHBUa9gMn4+aAvm7/+uRvXF4xxZtEyIlPuk1DA09ukWyyZabhBC03Lt5DTMHfo7rZ//SNjBc8PaHLfD4wV2c2r9bmVDurXfQbexsZMhknxhIKieEfbZwYBkSsEaAghtrZJI2n4KbpOXPtzsngaCgIMyePVt51n306JHqpPzQ0bFjR3h7eyNv3ryxdvzGjRvImTMnUqdOHWtZFiABEiABZyFg9r0uCm5sn4kU3NjOiiVjJrDmh4PoO36xCnUUH48uK7fsw0CfpaqdN8oXw0rNM0wmd+thj2R/K2I4KxHpLBjdVTfWkYIbv6f+muBmNg4cP4s0qVwxTROvNK//Js5fvY3PBk7H1VsPkDOrO1ZM6o2KpYvoNsUnIYKWdprIZ9/RfwU3Qz9vqsQ81trcsucI+mjin6cBQciUwQ1ftHgfy7f8gpv3HiGVqwsGd/oIvdo0SBAxkNhgr33W7GY+CdhCgIIbWygZoIzZF5YU3BhgEtEEuwns3bsX7du310+h5ciRA+nSpdPvpcFUqVLh3XffhfzgnUJkvLxIIAkJmP27QtBRcGP/BKLgxn5mrBE9gUt/n8T0Pu3w5J/7cMuYCb01jy6lK78ZfeEYcs8e+w1f9e+IgCeP4Z4thybcWYqiZStYrfHM7zG+6tce5/84ogQlH/cYjMYde0Yq70jBzb/99tD6fQ/Z8hZA/xnLNLFQOKb0aovH924hc848Wt7yOHmRia+g5XlgAHzHDsZvP2xQHoPE+0/3CXPx7PFDzBjQCXevXlKCoAYe3dH0C0+4uqaMxC22m/jaF1v7fE4CsRGg4CY2QknznIKbpOHOtyYPAiK2GT9+PGbMmAER4cgl3nMHDBiAQYMGIWPGjFZByP6DCG4kxHVCnEq2+iI+IAESIAEDETD7XhcFN7ZPJgpubGfFkjETOHH6Cj7znI77j54qYcfS8T01bzOlYq4UzdNDJ86h7eAZePw0EDmyZMRKnz4xilUkXFLrQTNw+NRF5U1mWJem6NPuP48vjhTciPlzVm3XBD/fKu8xTWpVwhzNI8+qbQcwbPpKJRpq8HYFzPfuBre0CSPevnnnH3zafxrOaaIe15QpMGWQB1o1rhENSeDs5VuaOGemHn6qb9tG8Oz4oeZ55yCGTFmhwl1lcU+neRHqgtrVXo22DXsz7bHP3rZZngSiEqDgJioRg96bfWFJwY1BJxbNipXAkydP0L9/fyxatMhqWRHhTJ06FV27/qdWtlqYD0jAgQTM/l0haCi4sX+CUHBjPzPWiJ5AYIA/5gzppntOEU8zHUdORlq3dNFXsJKrwjT19sClU0eVGLVuq874tI+XVTHI4Z1b8PWX/RGsiUvSZ8qKPtMWo8RrVSK17kjBTdDzQMwd2gN//Lxdcw3sis/6j0SY5uVn7fSxKsSVhNfqPHoGUqdJE8kmW27iI2iRk1E7vvka384cj7CQYGTJlQ89fRbglXKvqVf/vut7LeTVAAQ+ewK3DO7opHnmqfxeQ1vM0svExz69ESZIIB4EKLiJBzwHVqXgxoFw2TQJ/D8BCS/l5eWFFStW6EzkkM+XX36Jzp07q8M9+oP/T9SuXRt79uxB79698dVXX0V9zHsSIAEScEoCZt/rouDG9mlJwY3trFgyZgL+Ac/R
acQ87Dp0ShUUTzPTh7ZHOjf79nUe+T1FK88Z+P2vS9r+FtDt0/cxsscnmrgk+sNOm3ZrHlzGLob/c20PRxOPrNA84lQtX1w31tGCm9MXb6CFJoC5/eAxcmfPjMVjvsCUxVuw6/CfyoPMZM92VgUxupF2JL7Zug+DJq9AUEgosmXOgG8m9UGlskVfakG87/Qcswg/7j+hntWrXh6zvT5X3oKEyciZq+G7Ya/mbfoFypcoiCVju6Ng3hwvtWNvhq322dsuy5NAdAQouImOigHzzL6wpODGgJOKJtlFYObMmWpTy1olEd0cO3YMJUuWtFaE+STgcAJm/64QQBTc2D9NKLixnxlrWCdwdPcPWDCyL4ICniGlayq8r4llPurWH2nSRh9D+dH9u7h34xpKVowskNm5ahFWT/sfwkJDlBikzeCxqN6o2Usvvnbub8wZ3A13rl5Uz6rUaYzPR09/6X2OFNzIi3/ZtBpLxw5R9r5S/nXl4ebyX8fhmjqNErK82aDpS7bbkhEfQcvpIwcxZ+gXePrwAVKlSaudkhqDd5p+pr9WhAob503G90vnqPBduQoVVV6J8hUtoZeJLREf+2Jrm89JwBYCFNzYQinxy1Bwk/jM+cbkS+D48ePw9PTErl27dAjFixfHhAkT0LRp00iedMuUKYPTp08rz7sTJ05Ez56RPQLqDTBBAiRAAk5EwOx7XRTc2D4ZKbixnRVLxk5g2y/H0Ot/X6vwRRKuSMQynp0+tOrd5a4mUrly8x7eqBB5T2Xhtz9h1Kw1CAkNh3v6tJg0sA0+fv9lb9B/nb+Oz0fOw4Vrd5RxH7xbGbO8OkV6n6MFN9J+T63PGzXhT0qXFPhIExrt0cQ2D/38UbRATqyd0g+F8uW0Cu/arfu4df8RqpYrph1Ic7FaTh7sPfKXJi5aosrLfcMaFTFX86gTVdQkHqQnL9mMacu2Ki87YsfScT1Rqmg+qaaue/88Rkevufjt5AUlbGr+fjX4aOKgqG05wj6LDfwkgfgSoOAmvgQTqb7ZF5YU3CTSROFrHELg22+/VaGlAgICrLYv4aRKlCiBkydPMp66VUp84GgCZv+uED4U3Ng/Syi4sZ8Za1gnEBwcjG+/Gotda5Yo7y4pUrigSLmKqPnRZ8rrjHvW7AjQPKrcOH8aJ/btwtG9P6Lyuw3R3mtCpEYlTNS84T3w58G9Kj+15iWneqNPUL1xc+TMXwjiBef4Lzuwc9ViLWzTbVUmS+786DV5IYqWKa/uI/7L0YKbuzeuYmqvNipEkyVE5AvtZE3eV0qqcFLZ8/y3ERDRrojpy6dP4eDWdZp3nDA9O+CZH47t+VF57xHxToUadZAp23+bG3JCqvJ7jV4K3fXg9k3MGtQZV/46oW02uOCdj9toMbe9X1pj+Gthu+Z59dK9ElV8px4+956G9FpIsKhXQtoXtW3ek0BcCVBwE1dyjq1HwY1j+bJ1EohKQNYcO3bsUMKbU6f+PYktZd566y34+PioT1lfSOipkJAQVV3Sa9asQZMmTaI2x3sSIAEScCoCZt/rGj16NL7//nvUq1cP3t7eph2bN954w+G2U3DjcMTJ6gUiPpHwSou+26VCLLloGzCVyhRBmyY1Na8zxVSIqCfPAvHXxev46cApbNt3DI1rVsLkwR6ROEmYqG6jFiovMfIgvRaOqXn9N9Gy4dsonDc7Hj7xx/b9f2DB2p+UZxkpkz9XVi08eA9UKF1YbvXL0YIbedH6nb+ht+ZlR7zOyJ6TtsxUVzut3z6ebWMU0oiI5nNN+JInRxbUfbM8Xn/1FZQolAfZNe814tXnweOnOHnuGrbtPab1+bgKAyWNS38Xa15pKpYu8u/LIvz7B41rrzGL4aexzpguDaYN7YAPa0c+tCfFj2pehDoMm60YptbCpY/s/gk6N68TKYyqI+yLYCqTJBAvAhTcxAtf4lU2+8IyouDm4MGDkf4nmXgU4/+mxFhYxt9KtpBQBG7duqX+EFq+fDkCAwNjbVY2vLp166bCS8VamAVIwAEEzP5dIUgouLF/YlBwYz8z1oiZgISWWqN5p/ll4yrlOSXm0sA7zdq+JLiROnevXda85fTBxZNHY2sC7poIpYPXJFSsVTfaso4W3MiPWMvHD8fe9cv194vwpk7Lz/HZgJE2rV0P/bgJC7X+ilcfWy8R07ToNxL123TWq0iIq+UTvXBg8xptY+QFxONOj0kLkDVnbr1MxMT182cwc2An3Lt+RYXEatKpNz7o3Bcpo7g4Tij7Ir6baRKILwEKbuJL0DH1LYIb+bFfQuea9eLf72YdueRrt1qPaPsPEmrq5s2bOohmzZqhQ4cO+Oyzz/Ds2TM9Xzzt/vzzz6hcubKexwQJkAAJOBsBs+91UXBj+4yk4MZ2VixpG4Fn/oHwnv0tVmz9RXlXia2Wxwc1XxLcSB35Ib376K9VaKnY2siZ1R1TNdFOvbf/DQcesXxiCG5u3v0HLQdMx5nLt/RXi9BlgXc31Hnr5QNueiEtIYKWjpro5WlAUMTsGNNF8+fElEHt8PbrpV8qd+HqbbTX2jt75bbyuNPjs/oY2rWp1ZBcSzfuhddXq/BcE0vlyJIRC0d3Q/VKpfR2E9o+vWEmSCABCFBwkwAQE6MJsy8sKbhJjFnCdyQkgT/++AMVK1a0u0kR3ciphXfeecfuuqxAAvElYPbvCuk/BTf2zwIKbuxnxhqxE5AfwY/s3KKFK5qtebM5owk/wl+qlM49C6rV/wgN2nZFjnwFXnouGU8e/oPvl81V4p0AzRtL1CtlqtQoX/1dNO3miYIlXv7j3FLe0YIbec/J/Xswe0g3FU5L7t0yuGsed75GmarV5TbWK6EELbvW+GLVtNEIDQ5Cphy50WPiPOVdKCYD5N1LxgxStqdzz4zO3tNfEi8llH0x2cFnJGAvAQpu7CWWOOUpuEkcznwLCVgjIAd+vvrqK4wfPx5PnjzRi4kYWMS4Ea/MmTNDwlIVLlw4YjbTJEACJOA0BMy+10XBje1TkYIb21mxpO0EQrUDVpt3/46ZK37A3xdvIDzKWkpayuKeDs3qvIHureqhYJ4c0Tb+z6MnmLXyRyXeefz05cPZ4pXlvWqvYkjnj1CmWPR7ZIkhuAkPD9dCYK3F/LU7de82lcsWwYqJvZEti3u0fbNknrtyCxMXbsS+Y6fx6In1aA9SPq/mBefT+m9pXmhqI2e2zJYm9E8RO/UZtwSb9/57CK921bKY92VnZMmUUS8TNfE8KBjDp63Cck0gJcP0uuaRaNGYL5AvVzZVNCHti/pu3pNAfAlQcBNfgolU3+wLSwpuEmmi8DUJRkA2sfbv3489e/Zg69atOHHiBNKmTQsJKxUaGhrjezJlyoTLly8jS5YsMZbjQxJIaAJm/64QHhTc2D8rKLixnxlr2E5ATlrfvHgWFzQvNY/v3dHCTIUjnXsm5NdCLRWvWBVu6dLb1Jj/Uz+cPfqbaivQ/xlSpU6tiXQKoWSlN6yKdWxqmIVIgARMTYCCG2MOHwU3xhwXWpX8CDx48ABjxozB7Nmzre5DuLi4IH/+/JBDQ9y
DSH5zhD0mgeRAwOx7XRTc2D5LKbixnRVL2k8gPCwcpy/d1LzUXMTte49UmKksmdKjVJG8qPpqMWRI72ZTo35P/fHrH+dw5tItiKgkdWpXFMqXE2++VtyqWMemhg1USERKF67e0fp4AzfuPsRTLRxUqMYvnRZOK7cmtHm1eEGUKZ7fqqcaR3fF6PY5uv9s35gEKLgx5ri8ZJXZF5YU3Lw0yA3V+wAAQABJREFUpMwwGYGgoCAcOXIEP/30E7Zs2YI///zTqgDH1dVVxeWVcnICjRcJJBYBs39XCCcKbuyfLRTc2M+MNUiABEiABIxBgIIbY4xDVCsouIlKhPckkHQEnj59qkJGnTt3zqoRqTUhc/ny5dWhoTRp0lgtxwckQAIkYEYCZt/rouDG9llHwY3trFiSBEiABEjAWAQouDHWeFi1xuwLSwpurA4tH5iUwPPnz3H48GHs3LkT27ZtUwIcOVkWHBysXDyLN5w5c+aoOOsm7SLNNiEBs39XCHIKbuyfeBTc2M+MNUiABEiABIxBgIIbY4xDVCsouIlKhPckkDQEHj58iBo1auDixYuQQ0AxXRLeun79+vjuu+948CcmUHxGAiRgOgJm3+ui4Mb2KUfBje2sWJIESIAESMBYBCi4MdZ4WLXG7AtLCm6sDi0fOAkBEeAcOnRIF+BICCq5/P39kS5dOifpJbthdAJm/64QvhTc2D/LKLixnxlrkAAJkAAJGIMABTfGGIeoVlBwE5UI70kg8Qns27cPbdu2xa1btxASEmKTAXLwp2fPnvDx8bGpPAuRAAmQgBkImH2vi4Ib22cZBTe2s2JJEiABEiABYxGg4MZY42HVGrMvLCm4sTq0MT5YtWoV7t69G2MZPjQmAfF0c/v2bRQoUADi+YaXMQmUK1cOderUMaZxcbDK7N8V0mUKbuwfeApu7GfGGiRAAiRAAsYgQMGNMcYhqhUU3EQlwnsSSFwCBw8eRPXq1dVLxXNNYGCgzQakSpUKM2fORNeuXW2uw4IkQAIkYGQCZt/rouDG9tlFwY3trFiSBEiABEjAWAQouDHWeFi1xuwLSwpurA5tjA9KF82HZwHPUaVcsRjL8SEJkID9BLb9fBTt2nfAwoUL7a9s0Bpm/64QrBTc2D+5KLixnxlrkAAJkAAJGIMABTfGGIeoViRnwY0cemnVqlVUJLwnARIgARIgAdMQCAsLc6rDf2bf66Lgxvb/dCi4sZ0VS5IACZAACRiLAAU3xhoPq9aYfWFJwY3VoY3xQcVSheHZ6QO8X/21GMvxIQmQgP0Epizegn9ccmDBggX2VzZoDbN/VwhWCm7sn1wU3NjPjDVIgARIgASMQYCCG2OMQ1QrkrvgZvyoIVg7tW9ULLwnAUMQCAsLx7OnT+Gn/fPkyRM8evQIjx891sJZByAsPAyurq4IDQ1VtjZu2ACZMmcyhN00ggRIwPEE/vHzx9ttRoCCG8eztucNFNzYTouCG9tZsSQJkAAJkICxCFBwY6zxsGqN2X9EpeDG6tDG+ICCmxjx8CEJxIsABTfxwuewyhTc2I+Wghv7mbEGCZAACZCAMQhQcGOMcYhqRXIX3MyZPBobZnpGxcJ7EjA8gaCgIPj5+cHvsR9u3ryJ8uVfRdZs2QxvNw0kARJIGAIPHz9FyUZ9KbhJGJwJ1goFN7ajpODGdlYsSQIkQAIkYCwCFNwYazysWkPBjVU0ifrgjTfeSNT3UXCTqLj5smRGgIIbYw44BTf2jwsFN/YzYw0SIAESIAFjEKDgxhjjENUKCm4ouIk6J3hPAiRAAiRgfAIU3BhzjCi4sX1cKLixnRVLkgAJkAAJGIsABTfGGg+r1lBwYxVNoj6g4CZRcfNlJOBQAhTcOBRvnBun4MZ+dBTc2M+MNUiABEiABIxBgIIbY4xDVCsouKHgJuqc4D0JkAAJkIDxCVBwY8wxouDG9nERwc13y+bCd1x32yuxJAmQAAmQAAkYgMDlG/fw/uf/Q2BQiAGssW5CihfaZf2x8z+h4MYYY0zBjTHGgVaQQEIQoOAmISgmfBsU3NjPVAQ3DRo0QLEKle2vzBokQAKxEggM8MeL8HCkTusGV1fXWMuzAAmQgO0EwkLDcPmv43jw4AGyOVHYE7P//U7BDQU3tv9XzJIkQAIkQAJGIUDBjVFGIrIdFNxE5hHTnQhuunen2CYmRnxGAiRAAiRgbAJGl7NQcPPbb8aeQbFYd/ToUfTo0UOVOnjwIFxcXGKpYczHFNwYc1xoFQnEhQAFN3Gh5vg6FNzYz/jOnTuQ71ZeJEACjiHQqVMnPH78GF5eXqhY8f/YOwt4qYo9jv+45KW7uzukFJGnokgpKd3d3R2Sl7x0d4l0K2ERAlKCIt1KScclLm9mZNfdvbt7t/ecvb/zPs89MfOf/3xn3rvnjL/5/4t4pxFaJYEoTqBSpUqIHTt2wFCg4EYbQ+nK9/uKFSswfRwFN9oYQXpBAiRAAiTgDAEKbpyh5buyFNw4zvrFixd4/vy54xVYkgRIwCUCvXr1wqxZs9CwYUNMnTrVJRusRAIkYJ1AwoQJrT/QyF0Kbii40cRUdGXBzh3Hi+TOjJ7Nv0C59wu7Y4Z1SYAErBCg4MYKFA3couBGA4NAF0iABMwIZMiQAdeuXcOWLVtQsWJFs2e8IAESIAFrBCi4sUbF9/dc+X6n4Mb348QWSYAESIAEPEOAghvPcPS0FQpuPE2U9kiABNwl0L17d0yYMAFNmjTBggUL3DXH+iRAAjoiQMENBTeamK6uLNi54zgFN+7QY10SsE+Aghv7fPz1lIIbf5FnuyRAArYIZM6cGZcvX8aGDRvwxRdf2CrG+yRAAiRgJEDBjRGFX09c+X6n4MavQ8bGSYAESIAE3CBAwY0b8LxYlYIbL8KlaRIgAZcIyAg3ISEhaNCgAZYsWeKSDVYiARLQJwEKbii40cTMdWXBzh3HKbhxh57+6j4Pe4Fzl//C+au3cPPufTx7/gKxYsVAyqSJkD9HRuTIlNqv6di07p+zI07BjbPEfFOeghvfcGYrJEACjhPIli0bLly4gLVr16JatWqOV2RJEiCBKEuAghttDL0r3+8U3Ghj7OgFCUgC4eHhuPrXHfx56QZu3LqHR0//TTOSOEFc5MyUBgVzZUZwnFh+g6V1//wGhg37jQAFN35Db7dhCm7s4uFDEiABPxDo27cvRo8ejbp162L58uV+8IBNkgAJ+IsABTcU3Phr7pm168qCnZkBJy8ouHESmA6LSxHLgnV7sGHXIfx+7iqevXhltRfRogHZM6ZG8+ofo17lD3y2qKR1/6zCcvAmBTcOgvJxMQpufAyczZEACURKIGfOnDh79ixWr16NmjVrRlqeBUiABEiAghttzAFXvt8puNHG2NGLqE3g+0OnMOfrnTh86jz+efDEJowkCeOi2icl0bF+eaRPndxmOU8/0Lp/nu4v7emHAAU32hwrCm60OS70igSiMoEBAwZgxIgRqFWrFlatWhWVUbDvJBDlCFBwQ8GNJia9Kwt27jhOwY079PRR98HDJ6jfOxS/nDjnkMNSeFP23Y
..." alt="Open In Colab"/></a>
N7RZFWbEKqVb6UiHb8sc3NTvPX7MagKavUekvq5Ikxb3gbEV0oh12763f+gm5jFinRT+IEwSItVwujeMnT/tl1hA9JwEkCFNw4CcxHxSm48RFoNkMCJOAwgb1796J06dJIkEBEW3z40OF6LEgCJKB/AhTcUHCjiVnsyoKdO47rVXBj6PNTEWXl8MlzOHnuGv4RebKlCCSVWOB4J09WIR7JYFMkY6gvf+XC1PHTl/HrqfO4/c8jEf4XSJEsEYrmzYoCuTKqkMKm5fV6LqPNnDx7ReRav4m/b9/Hs+diN5dQ7iRMEAeZ0qRE0XxZkD51cr91T+v+uQKGghtXqHm/DgU33mfMFkiABJwj0L17d0yYMAGNGzfGwoULnavM0iRAAlGSAAU32hh2V77fKbjRxtgFqhcy8ssfYrPLYbG+8deteyrNVJJE8VSqpRIizXT8eMEOdf3BoyfYf+yMiDB8AzKCTiwRPVimkX6vcA6bYh2HDGuokFwLOnf5b9HHa7h28x88Eqm3Xgl+MlV5aiG0KZAjI/LmSO+3NSGt+6ehoaQrPiRAwY0PYTvRFAU3TsBiURIgAZ8QOHDgAN577z0EBwfj6VP7AmefOMRGSIAEfEaAghsKbnw22ew15MqCnT17kT3Tu+Amsv7xOQn4kwAFN/6kb7ttCm5ss+ETEiAB/xDo1asXQkJC0KBBAyxZssQ/TrBVEiABXRGg4EYbw+XK9zsFN9oYO3pBAiRAAiTgPAEKbpxn5osaFNz4gjLbIAEScIbAoUOHUKJECSHajoWwMMfTeDrTBsuSAAlokwAFNzoX3Bw+fBgdOnRQs2vfvn12U9xocwr+65UrC3bu9IeCG3fosS4J2CdAwY19Pv56SsGNv8izXRIgAVsE+vbti9GjR6Nu3bpYvny5rWK8TwIkQAJGAnoX3EycOBGrVq1CqVKlVIQvY8d0duLK9zsFNzobZLpLAiRAAiRgJEDBjRGFpk4ouNHUcNAZEiABQeDIkSMoWrQookePjlevXpEJCZBAFCJAwQ0FN5qY7q4s2LnjOAU37tBjXRKwT4CCG/t8/PWUght/kWe7JEACtggMGDAAI0aMQK1atdS/gLZVjvdJgARIwECAghsDCf/+uvL9TsGNf8eMrZMACZAACbhOgIIb19l5s6ZBcFO+fHkMGTLEm0151bYr71VedYjGSYAEXCZw/PhxFC5cWNV/8+aNy3ZYkQRIQH8EKLih4EYTs9bXL5YU3Ghi2OlEgBKg4EabA0vBjTbHhV6RQFQmMHjwYMhF0urVq2PNmjVRGQX7TgIk4CABCm4cBOXlYq58v1Nw4+VBoXkSIAESIAGvEaDgxmto3TJMwY1b+FiZBEjACwROnTqF/PnzK8uvX7/WbUYSL6ChSRIIeAIU3FBwo4lJ7sqCnTuOU3DjDj3WJQH7BCi4sc/HX08puPEXebZLAiRgi8Dw4cMxaNAgVKlSBevXr7dVjPdJgARIwEiAghsjCr+euPL9TsGNX4eMjZMACZAACbhBgIIbN+B5sSoFN16ES9MkQAIuETh9+jTy5Mmj6r548QIxY8Z0yQ4rkQAJ6I8ABTcU3Ghi1rqyYOeO4xTcuEOPdUnAPgEKbuzz8ddTCm78RZ7tkgAJ2CIwcuRI9O/fH5UrV8amTZuMxQy7gKJFi2a8xxMSIAESkAQouNHGPHDl+52CG22MHb0gARIgARJwngAFN84z80UNCm58QZltkAAJOEPg7NmzyJkzp6ry9OlTBAcHq3OZXkr+NygoyBlzLEsCJKAjAhTcUHCjienqyoKdO45TcOMOPdYlAfsEKLixz8dfTym48Rd5tksCJCAJvHr1Su3smTp1Klq2bIlYsWJhzJgx6NOnDypUqICtW7dC7v5Zu3Yt6tati82bN6NSpUqERwIkQAJmBCi4McPhtwtXvt8puPHbcLFhEiABEiABNwlQcOMmQC9Vp+DGS2BplgRIwGECH374If755x/s3LkTKVOmxIULF5AtWzZV/9GjR4gXLx6OHz+OIkWKoESJErrfQOIwGBYkgShIgIIbCm40Me1dWbBzx3EKbtyhx7okYJ8ABTf2+fjrKQU3/iLPdkmABAwE5EKD3OETJ04cfPzxx0p0I1NJ5cqVCwUKFMCWLVtU0WfPnuHBgwdImDChoSp/SYAESEARoOBGGxPBle93Cm60MXb0ggRIgARIwHkCFNw4z8wXNSi48QVltkECJGCPgCFVutxUJte2PvvsM4wbN05V6dChA9asWaPWt+Ra2OLFi9GwYUN75viMBEhAxwQouKHgRhPT15UFO3ccp+DGHXqsSwL2CVBwY5+Pv55ScOMv8myXBEjAQGDBggXo1KkTHj9+rG7FiBFDRb6RqaNkaF3DUbFiRaP4xnCPvyRAAiQgCVBwo4154Mr3OwU32hg7ekECJEACJOA8AQpunGfmixoU3PiCMtsgARKwR0BGt0mbNi3CwsJUMcM6l7yQ6aPCw8PVfbnx7M6dOyrijbrBf5AACQQcAQpuKLjRxKR2ZcHOHcfzZE2HLOlSoOx7Bd0xw7o+JvD0yRPVYmzxghI9enQft87mHCUwZ/VOfPBJJcyZM8fRKpovp/d/uSMBU3Cj+WlGB0kg4AlIoU2KFCnw/Plzm31NkCABlixZgipVqtgswwckQAJRl4De38kmTpyIVatWoVSpUpgwYYJuB9KV73cpuKlXrx5Gda2n237TcRLQCoGXL1/ipUjFGT16EGLHCdaKW/SDBAKWgBTchCzYhNevX6t/gRooHdX7exUFN4EyE9kPEtA3AZkOXaZJt3fINS4Z4ZkHCZBA4BKg4IaCG03MblcW7NxxvH2dz3Du6t+IHhTNHTOs62MCl69cVS2mSZ1GpKGI4ePW2ZyjBF6Hv0GV+q3Rrl07R6tovpzeFyEkYApuND/N6CAJRAkCdevWVf+y2TSijWnH48aNi3v37ql0U6b3eU4CJEACkoDe38misuDm559/xqRJkwLqX1Tyf5Uk4C8CZ8+exbFjx5AmTRqULl3aX26wXRKIMgRkhAIZlVOKZmXEgkA59P5eRcFNoMxE9oME9E1gx44d+PLLL/Ho0SOrHZEby1auXAkZzZkHCZBA4BKg4IaCG03Mbl8LbjTRaTrhNIHYsWPjhdjFderUKeTNm9fp+qxAAq4S0PsihOw3BTeujj7rkQAJeJLA3r17Ub58eWNaKVPbchFbRj9YunSp6W2ekwAJkICRgN7fyaKy4MY4iDwhARJwm0BoaCg6d+4MuaN68+bNbtujARIggahJQO/vVRTcRM15y16TgNYIyOhnMpqz3Dxm7YgfPz5k6qmYMWNae8x7JEACAUKAghsKbjQxlSm40cQwaN4JCm40P0QB66DeFyHkwFBwE7DTkx0jAV0RkJFt0qdPjxs3bkTwW+762bRpk/r/qwgPeYMESIAEBAG9v5NRcMNpTAIk4AkCFNx4giJtkAAJ6P29ioIbzmESIAGtEOjduzfkt55M+2l6yKhoTZo0wbx580xv85wESCAACVBwQ8GNJqY1BTeaGAbNO0HBjeaHKGAd1PsihBwYCm4Cd
nqyYySgOwIjRozAV199hefPn5v5niRJEty5cyegwrSbdZAXJEACbhPQ+zsZBTduTwEaIAESEAQouOE0IAES8AQBvb9XUXDjiVlAGyRAAp4gcO7cORQoUCDCOpeMbrN9+3a8//77nmiGNkiABDRMgIIbCm40MT0puNHEMGjeCQpuND9EAeug3hch5MBQcBOw05MdIwHdEZDRbbJmzYqwsDCj7zFixEC7du0wefJk4z2ekAAJkIAlAb2/k1FwYzmivCYBEnCFAAU3rlBjHRIgAUsCen+vouDGckR5TQIk4E8CRYoUwbFjx8xcSJ48OW7dugWZQp0HCZBAYBOg4IaCG03McApuNDEMmneCghvND1HAOqj3RQg5MBTcBOz0ZMdIQJcEypQpg59++snoe9y4cbF3714ULlzYeI8nJEACJGBJQO/vZBTcWI4or0mABFwhQMGNK9RYhwRIwJKA3t+rKLixHFFekwAJ+JPAwoUL0aFDBzx58kS5ITeWdenSBSEhIf50i22TAAn4iAAFNxTc+Giq2W+Gghv7fPj0XwIU3HAm+IuA3hchJDcKbvw1e9guCZCANQJr1qxBgwYNjOF206dPjytXrnDXjzVYvEcCJGAkoPd3MgpujEPJExIgATcIUHDjBjxWJQESMBLQ+3sVBTfGoeQJCZCABghIoY2MaGNInx4nThwcPnwY+fLl04B3dIEESMDbBCi4oeDG23PMIfsU3DiEKcoXouAmyk8BvzE9Q68AAEAASURBVAHQ+yKEBEfBjd+mDxv2MYExY8Zg2zeLECN6DB+3zOacIRAeHo7vD53Cm7eVMqdNjmwZ0jhjgmX9QODV61colCszJi/d6ofW2SQJAHp/J6PghrOYBEjAEwQouPEERdogARLQ+3sVBTecwyRAAlojUK9ePaxYsUK5JTeWXb16VWsu0h8SIAEvEaDghoIbL00t58xScOMcr6hamoKbqDry/u+33hchJEEKbvw/j+iBbwjUqVMH508eRKMqH/qmQbbiMoFVW/di37EzCAqKhv6tqyF5kkQu22JF3xDYsPswHj5+goO/nfdNg2yFBCwI6P2djIIbiwHlJQmQgEsEKLhxCRsrkQAJWBDQ+3sVBTcWA2rn8sSJE/j+++/tlOAjEiABTxC4cOECZsyYAbnJrHz58vj00089YZY2SCDKE4gVKxbatGmjaQ4U3FBwo4kJSsGNJoZB805QcKP5IQpYB/W+CCEHhoKbgJ2e7JgFAbmbJF+KcLSuXc7iCS+1RuDU2av4sMkQ5MiUGvuWj9Cae/THCoH1O3/Bwg0/4OdfT1t5ylsk4H0Cen8no+DG+3OELZBAVCBAwU1UGGX2kQS8T0Dv71UU3Dg+R6QAoF27dqhUooTjlViSBEjAaQLhb95g59GjePnqFT4qVBBxY8dx2gYrkAAJmBO4fvcOjp2/gDfif19aPii4oeBGE/OTghtNDIPmnaDgRvNDFLAO6n0RQg4MBTcBOz3ZMQsCFNxYANH4ZYr3m6Nbo0ro27q6xj2le5IABTecB/4moPd3Mgpu/D2D2D4JBAYBCm4CYxzZCxLwNwG9v1dRcOP4DJKCmw3z5mJGhw6OV2JJEiABlwgMWLQYy0VEqQsL5rtUn5VIgATMCVy6+TcqDx6KJ8+fmz/Q2BUFNxTcaGJKUnCjiWHQvBMU3Gh+iALWQb0vQsiBoeAmYKcnO2ZBgIIbCyAav3z0+CkSxAsGokXTuKd0TxKg4IbzwN8E9P5ORsGNv2cQ2yeBwCBAwU1gjCN7QQL+JqD39yoKbhyfQRTcOM6KJUnAXQKvRTqpp2FhSBAs1rp4kAAJuE2Aghu3EfrGgN5fLA8fPowOb5XJ+/btQ1BQkG/AebgVCm48DDRAzVFwE6ADq4Nu6f1vhURMwY0OJhpd9AgBCm48gpFGSMAqAQpurGLhTR8S0Ps7GQU3PpwsbIoEApgABTcBPLjsGgn4kIDe36souHF8slBw4zgrliQBEiABEtAWAQputDUeNr3R+4slBTc2h5YPApAABTcBOKg66ZLe/1ZIzBTc6GSy0U23CVBw4zZCGiABmwQouLGJhg98REDv72QU3PhoorAZEghwAhTcBPgAs3sk4CMCen+vouDG8YlCwY3jrFiSBEiABEhAWwQouNHWeNj0Ru8vlhTc2BxaPghAAhTcBOCg6qRLev9bITFTcKOTyUY33SZAwY3bCGmABGwSoODGJho+8BEBvb+TUXDjo4nCZkggwAlQcBPgA8zukYCPCOj9vYqCG8cnCgU3jrNiSRIgARIgAW0RoOBGW+Nh0xu9v1hScGNzaPkgAAlQcBOAg6qTLun9b4XETMGNTiYb3XSbAAU3biOkARKwSYCCG5to+MBHBPT+TkbBjY8mCpshgQAnQMFNgA8wu0cCPiKg9/cqCm4cnygU3DjOiiVJgARIgAS0RYCCG22Nh01v9P5iScGNzaHlgwAkQMFNAA6qTrqk978VEjMFNzqZbHTTbQIU3LiNkAZIwCYBCm5souEDHxHQ+zsZBTc+mihshgQCnAAFNwE+wOweCfiIgN7fqyi4cXyiUHDjOCuWJAESIAES0BYBCm60NR42vdH7iyUFNzaHlg8CkAAFNwE4qDrpkt7/VkjMFNzoZLLRTbcJUHDjNkJdGXge9gLnLv+F81dv4ebd+3j2/AVixYqBlEkTIX+OjMiRKTWCgoL81iet++csGApunCXG8p4moPd3MgpuPD0jaI8EoiYBCm6i5riz1yTgaQJ6f6+i4MbxGUHBjeOsWJIESIAESEBbBCi40dZ42PRG7y+WFNzYHFo+CEACFNwE4KDqpEt6/1shMVNwo5PJRjfdJkDBjdsINW9AilgWrNuDDbsO4fdzV/HsxSurPkeLBmTPmBrNq3+MepU/QHCcWFbLefqm1v1zp78U3LhDj3U9QUDv72QU3HhiFtAGCZAABTecAyRAAp4goPf3KgpuHJ8FFNw4zoolfU/gzsMHaBU6BacuXUbyhAkxp0tn5M2Y0feOsEUSIAFNEqDgRpPDEtEpvb9YUnATcUx5J3AJUHATuGOr9Z7p/W+F5EvBjdZnGf3zFAEKbjxFUrt2Hjx8gvq9Q/HLiXMOOSmFN2XfLYhJvRshVYokDtVxp5DW/XOnbxTcuEOPdT1BQO/vZBTceGIW0AYJkAAFN5wDJEACniCg9/cqCm4cnwUU3DjOiiV9T+D2g/toNmmyUXCzoFtX5MuUyfeOsEUSIAFNEqDgRpPDEtEpvb9YUnATcUx5J3AJUHATuGOr9Z7p/W+F5EvBjdZnGf3zFAEKbjxFUrt2DIKWI79fUGmj/lcsHwrlzoQs6VIiXtzYuCcEOcdPX8LqHftx+OQFhL95Aym6qf5JSUzs08TrkW607p87I0vBjTv0WNcTBPT+TkbBjSdmAW2QAAlQcMM5QAIk4AkCen+vouDG8VlAwY3jrFjS9wQouPE9c0dbvHbntoo+dPX2HWRIkRyzO3VE+uQpHK3OciTgEQIU3HgEo/eN6P3FkoIb788RtqAdAhTcaGcsoponev9bIceLgpuoNmujbn8puAn8
sX/2/AV2/HwU+XJkRI5MaWx2WKZ2GjlrHWav/g6vw98gYbw4WDCiHcoUz2ezjiceaN0/d/pIwY079FjXEwT0/k5GwY0nZgFtkAAJUHDDOUACJOAJAnp/r6LgxvFZQMGN46xY0vcEKLjxPXNHW7xy+xYajh2Hq3eE4CZ5cizp1QMZU6R0tDrLkYBHCFBw4xGM3jei9xdLCm68P0fYgnYIUHCjnbGIap7o/W+FHC8KbqLarI26/aXgJuqOvbWeX795F3W7T8YfF6+rKDe9m1VB1yaVERQUZK24z+9p3T9LIBTcWBLhta8J6P2djIIbX88YtkcCgUmAgpvAHFf2igR8TUDv71UU3Dg+Yyi4cZwVS/qeAAU3vmfuaIsU3DhKiuW8SYCCG2/S9aBtvb9YUnDjwclAU5onQMGN5ocoYB3U+98KOTAU3ATs9GTHLAgEguDm3oNH+P38ddz+54EQiURD6uSJkS97BsSPF2zRW9uXL16+whkhMrl4/Raeh71C4oRxkSdrOqRPndx2JQeehL14JXy7iis3buONSNWUPlUy5M+ZEXFix3Kgtu+LvHr9Gi0HzMDmH4+qxptW+xCju9XXjOBG6/5ZjhgFN5ZEeO1rAnp/J6Pgxtczhu2RQGASoOAmMMeVvSIBXxPQ+3sVBTeOzxgKbhxnxZK+J0DBje+ZO9oiBTeOkmI5bxKg4MabdD1oW+8vlhTceHAy0JTmCVBwo/khClgH9f63Qg4MBTcBOz3ZMQsCehXcSOHFlu+PYO6aXfj11Hm8fBVu1rO4sWPifyIVUqtaZVGqSG6bgpGbd+5jxspvsXrHftz656GZjSAh3imQIwNa1y6HKmWLI1bMGGbPDRdSrNN/0nL8dOQ0gmPFwtCOX6Jo3myYv24P5ny9E3+JNkyPVMkSKb9a1vwUwXHMhTf7jp7GoNBVePw8DDkzpsb43o2RImki0+pWz3ftP4HhM9bg+cuXyJctA+Z91dZquchuSq7N+8/A1p+Oqgg3HetVQP821c34Hf7tHAaErsT9x08RN3Zs9Gr+Bcp/UMSmaSmE6j1uqYiac0PZrPpxcXRpXNkmT5uGxANH/LNX39fPKLjxNXG2Z0lA7+9kFNxYjiivSYAEXCFAwY0r1FiHBEjAkoDe36souLEcUdvXFNzYZsMnviEgRTV/XruOfx49QoLgYOTKkB5pkyZTjbsruLn/+DF+v3IFtx88wEuxBlQ4axZkT5vOaseu3bmNM8KPB0+fIo5Y78qUMiVypE2LmDGsr49ZNWJx83V4OM7/dUP892+Ei/PUSZIgT8aMan3Jomikly9fic1z16/jyi2xeU6shyWOFw8506dDumTubZ6LtGEbBSi4sQGGt31KgIIbn+J2vTG9v1hScOP62LOm/ghQcKO/MQsUj/X+t0KOAwU3gTIb2Y/ICOhRcHPr7n30nbgCW374Fa/D39jtYvzg2Jg9rDU+LVUoQrlfT11AxxHzcPby3xGemd6IHhQNNcu9ixGd6yJRwnimj9S5jGLTfMA07Nh7QghuYmBMj4bY/tMxbPv5qIhqE6G4uiFtNvy8DIZ1qmMmupF9q9djMo6fuSIWG2JixpBWqFjmHetG3t6VIpRuoxdixdZ9StDSqX5FDGhbw24dWw8viwg/dUT75678LRZTYmDqgBZKbGRaPvx1OMbM24DQpVuEAOYNsmVMhUUjOyBXlrSmxdS5ZDNs+mrM/WYXwgWMAjkyYuHIdsiYNkWEso7ccMQ/R+z4qgwFN74izXZsEdD7OxkFN7ZGlvdJgAScIUDBjTO0WJYESMAWAb2/V1FwY2tkI96n4CYiE97xDYFzN65jwrr1+P7Eb3j+4oWx0RjRo6Nk7lzoWrUqMqRIjmaTJuPUpctInjAhFnTrinyZMhnLypPLt26i17z5QlTzEFnTpMaoJo1x5+FDhG7YiB9+O2m0LTea9a9bB00//dRY/4UQrmw4cACLdu7E6avX1FqO8aE4SZU4Maq//z6alvtEtG97g9jmX37BZNGeXAv6vGQJtKtUCZsOHsSMLVtx8e+/VRRog92kCRKgdpkyaFWhPBIJ0Uxkx8379zBv+7fKTykcMj1kn/Jkyohmok+VS5SwKQ6SwqW+CxfhghD+pEiUEGOaNUXmVKlNTZmd33n4AP1F+bM3/opQ/tTlyxi4ZIkQSD3GX3fvKiFTTDFmaZIlg/TH9EiVJDHGtWjuN1GQqS88D0wCFNzoZFz1/mJJwY1OJhrd9AgBCm48gpFGXCCg978VsssU3Lgw8KyiSwJ6E9zI9FHths/HThHRRR7yw7VQ7kyoJEQpuUUKqGAhUrl596GKeiPL3BZRa6yJVs5cuoGm/abjzOW/lJ0EcWPji4+K46OS+YWoJi4uXr2FNd8dwKHfzqvFAdlOoy/KYHjnOhHSQZkKbmLGCFJiElk/RdKEqPFJSRQvmB0J48fBpet3sGrbXhw+eUHZlIKayf2aoqooYzjk7p5Rs9djshCzSLFOnQqlMLFvE8jFFVvHOdGHukIkc0mkrUoifF8yuiNKFsppq7jN+w8ePlGRa77evl/59/47ubBARMpJkihBhDqybJthc43jUL50ISHOaY5ECcwXRmR/ZXSbJ89fKN9mDm6Fj98tEMGeIzec8c8Re74oQ8GNLyizDXsE9P5ORsGNvdHlMxIgAUcJUHDjKCmWIwESsEdA7+9VFNzYG13zZxTcmPPglW8I7Pv9d/SYOw9/37tns8Ek8eOjZ80aWPb993YFN2eFcKdRyHjcvH9fRK9Ji9YVK2DcN2vUtalxudbVr05tNCtXTt2+++ghhixdju2HD4sNbuaRpE3ryfPcIupOSPPmEcQ+hnLL9+zBoCVL1frSlx98gGQJEwiRzA4lRjGUsfyVoqLxQoyS1k6EmqPnzqHn/AVCKPPvep6lDcN19KAgVHnvPQwUgiJrIp6/7/2DJuMniAg5N5SIaFGPbsiZLr2heoRfKfJpNmEi/hAiJCk6Mi1/+MwZtAydggdPnkSoZ3lDRvRZ3runXXGPZR1ek4AzBCi4cYaWH8vq/cWSghs/Th427XMCFNz4HDkbfEtA738rZDcouOF0jioE9CS4kWKUsXM3YNKSLSqyjYxe07N5FbSoWdZqeiIphNn+0xFky5AK+XP+t9tH3u8xdhFWbtunhjlN8sSYOrAZyhTLZzbsz8NeYNy8jZi2cruK5CLbC+3fDJ9/VMysnKngxvDg3UI5MFWUzZQupeGW+n0k0jB1Gb0IG/ccVtdSqDJ7aBuzKDcHT5xFg96huPfwKTKLSDArxnVG9kxpzOyYXixa/z36TFiqfPxIpNGaP6It4scLNi1idn5FCHNOX7yuBD1vhKrnweNn+OP8NWwTaaQuXLulyhbKmRFThIAmTzbbiw2nL1xD437TcEGIi2JEj4YuDSuhZ7MqCIoepGycOH0ZTfpPw9W/74odRUHo3bwqOjaoYJaeysyxtxee8s+abV/fo+DG18TZniUBvb+TUXBjOaK8JgEScIUABTeuUGMdEiABSwJ6f6+i4MZyRG1fU3Bjmw2feIfAH1evoO3U6So1kmx
BCmuqlSqFErlyqggtf167hvX79ilxiHwmj3siNZStCDemghspNpHRVu6K9FR5hEimcsmSyJUunbIr0zFlTZ0aHxYsiKdhYRgsBDJrRTtyrUgKVmT7lYoXF1F1UuC+EJPIyDvbf/0Vz0RZeeTNmAHT2rcTqaZSqWvTf5gKbtImTYp/hL8yQnOpvHlQUdhMI4Qnt0XUnZ1Hj2L3seN4IdJDyePjwoUwoWVLJIwb19ScOpcRgCSn82/FNvFFuq0KxYqhTP58qvxlkVpqo4isc+TsOePmuTr/+x8G1quL2DFjmtnzpODmz+vXMGHNWtXHU5evqAhCMgVXPhFpR7I3PZKJqED969QSqbSSmt7mOQl4jAAFNx5D6V1Den+xpODGu/OD1rVFgIIbbY1HVPJG738r5FhRcBOVZmzU7queBDcyKk29npNx+cYdJfDo0aQKujauZBR4ODqSR3+/iPq9JuP2vUcqBVRIr0aoXeF9q9WfPH2OdiKSy1YhRpHHJ+8VxLzhbRBXiG8Mh6XgJn2qpCrKTH4hWrF2/HL8DBr2maIENZnSJsc3k7ojs4kw5/GTZyJF1QzsPnhK9XNU1/poUu0ja6bw9FkYmg+cqSLNSNHLV53qorkQINk7Zq76FoOnfK0WHyzLZc2QEvUqlkbjqv9D4oT/LuJYljG93rDrELqOXoBHT8OQMF4cJdKRKbDu3nuIloNn4adfT6viVT4uhkl9mtgVAhnsetI/g01//VJw4y/ybNdAQO/vZBTcGEaSvyRAAu4QoODGHXqsSwIkYCCg9/cqCm4MIxn5LwU3kTNiCc8RCBMpnPouWIj1+/cro1nTpEFom1ZCzPLfxjH5QApsBi5egm0i+owUxMjDEcGNLCfFM1J40qfWl4gXJ468FeHYINqXKZZkKqtYMWKgXeVKaCvSQMUU56bHLiGQkeVkiqpoIkJOXWF3cP16EcqZCm5kfSl46Vq9GpqLVE/RTUQoMpLOOiHyGb58BR49e6baHiQEMvU+Ml8Hk6mu+i9ajDV79yp3ZKSYEBEN5/28eU3dg+Q5ecMGzBXRdKTAJ27s2KqcFOaYHp4U3BjsXrl9Cw3HjsPVO3eQIXlyLOnVAxlTmG/EM5TlLwl4iwAFN94i62G7en+xpODGwxOC5jRNgIIbTQ9PQDun978VcnAouAnoKcrOmRDQk+Bm/prd6D95uYrkUiBHRqwI6YRUKZKY9Max06nLtmH4jDVKcFI8f1YsD+lsV1yyc98JtBo8U4lKUiVLhJXjuoiIOf+JaSwFN40+L4OQng1tCoFu//MA9UQKqGN/XkbiBMFYMqYT3rVIAWXa17IlC2DeV20QL27ERZEjpy6gnhAP3b3/GJnSJMNy4VvOzGntgrAnaJFprsoUz4tODSqieIHsdu3Ih3LxYvSsdZi6YruKOpQjU2ohSGqLlVv2Ytbq79S93FnSYvHoDsiSPuKOJ2sNeNI/a/Z9eY+CG1/SZlvWCOj9nYyCG2ujynskQALOEqDgxlliLE8CJGCNgN7fqyi4sTaq1u9RcGOdC+96h8CxC+fRcvIU3BUCFhmxRaZU+vSdd6w2du3ObbSYNFlFupEFHBXcFMuZAzM6tEeyBAmt2n349CnaT5+Ovad+V8+rilRMo5o2iRAVxlB58a5dGLlylYpKI9MrzevaOYJAyFJwIyP2jGzS2KpNKboZu/obzNuxQ63VFc2RHXM6dULit9F8ZLvH33KSQh8ZPWZYwwaoWbq0wSWz3yfPn6P7nLn49sgRdf+jggUwpV07Jb4xFKTgxkCCv4FGgIIbnYyo3l8sKbjRyUSjmx4hQMGNRzDSiAsE9P63QnaZghsXBp5VdElAL4Kb8Nfh6DhiHr7ecUDsoAFa1fwEwzrVjjQ9keWgSIFI2yFzsH73IWWnY70K6N+mul07N2/fQ+1uE3HqwnWx0yY6pg9qgSplSxhNmwpuogdFw/hejVH/8w+Mzy1PZFSapiIVk4xgI9NUzRnWBp+UKmhW7Ozlv1C3xyQVzSdZ4vhYNrYziubLalZGptgaN38jxi/crBYkapd/FxP6NLWaXsu04uXrt3BKpJCS+6HCBQ8p1jkhxD/fCmHRzbsPVNEEcWNjULsv0ajK/+yykYXvPXiENoKp7I88sqRLgVvCzpPnL5SgaNrAFij3fmH1zJF/eNo/R9r0VhkKbrxFlnYdJaD3dzIKbhwdaZYjARKwR4CCG3t0+IwESMBRAnp/r6LgxtGRBii4cZwVS7pPYNa2bQgRYpNwEbWmdL68QhjTwWYUGtnatM2bMXHtOlXeEcFNDBFNZoiIQGMZMcbUcylmafFW9CNTUElxzru5c5sWMTu/9eA+mgvhz6lLl1X0nP51aqOJiFxjepgKbhIIIdGsjh3wbp48pkXMzmVarWYTJuHm/fuQPsjyJXLlMpYx5WRNkGMs+PZkz/Hj6DRzllibeo6UiRJhfrcuZqIgCm4sifE6UAhQcKOTkdT7iyUFNzqZaHTTIwQouPEIRhpxgYDe/1bILlNw48LAs4ouCehFcPPo8VM07DsVe4/8qdIsTejdBHUrWd/JYm8gHjx6gkZ9pmLfsTPKzqS+TW2mkzLYeR72Ak37z1Bpm6TYZ3Dbmmhfv4LhMUwFN3FixcCMwa1Q+cOixueWJ6bl48WJhVlDW+Oz0uaCFCkM6jJyAVZt348g0Wg3kTqrZ/MqZuKXOzJSTq9QHP3jktilExPTBrW0266lH5bXMvLOsOnfYPWO/SoyTZKEcTFbiIE+LJ7PsmiE61Nnr6KJEBFdunHb+EymuOrR5AuR9quyzWg/xsIOnLjjnwPmvVKEghuvYKVRJwjo/Z2MghsnBptFSYAEbBKg4MYmGj4gARJwgoDe36souHF8sCm4cZwVS7pH4LVc+5k9B1sOHlTpmbqJlEvtK1e2a/Tgn3+i9ZSpePDkiUMRbpKIKDEyAk3hrNls2l0rUjr1mb9ARTEulDUL5nTuJGwnslleRqQZumw5lu7ercrI6DXjW7YwK28quMmfORPmdemMFIkSm5UxvXgaFoY2U6bgZxFlR6bAGtqgvlEkZMmpVfny6FGzhipnasP0/Ob9e2g6YSJOX72m0lSNE5GDKpcsaSxCwY0RBU8CjAAFNzoZUL2/WFJwo5OJRjc9QoCCG49gpBEXCOj9b4XsMgU3Lgw8q+iSgF4EN6ZpmIKFqGWmEKlULGM9xK69gXDFjoyu0274XKz57hcVFadro8ro26qasRlTAY0jvpmWl0IZKWqxFNxI41t+PIJ2Q2bjadhLFM6dWaXQSp70vwUPmeqq5aCZeCwi5hQUKa5WjOuMlMlsL14YHbZzIgVJrUW0ml0HflOlKpcpgulCQBQshEGRHYvWf4++E5fh5atwVbSciNojowElShAvsqoOP3fHP4cb8WBBCm48CJOmXCKg93cyCm5cGnZWIgESsCBAwY0FEF6SAAm4REDv71UU3Dg+7BTcOM6KJd0jIFM5tQoNxcE/zyhRyPhWLVGpeHG7Rq/evoWGIeNx5fZthwQ3GVOkwJKe3ZEhRUqbdqds2oRJ69bjjYiyU0G0H9
q6FaKLyDj2Dim2GbJ0mYq0UypvHizt2dOsuKngppxIkTWlbRvEjBHDrIzphRTxSHvL9uxR4qPWFSugV82aqogpJxmxR6amspVOymAz7OVLtJs6FXtO/Kbs9a71JaRQx3BQcGMgwd9AI0DBjU5GVO8vlhTc6GSi0U2PEKDgxiMYacQFAnr/WyG7TMGNCwPPKrokoCfBTd3uk3D8zBVIUcusYa1R4QPXBDcGO1LsMnNoq0jtWKaz6tqwEvq2rm4cb1MBjScFN7fu3ke9HpNVny0j4UifeoYsweJNPyoRUKf6FdGvdTWzCDhGB508kSKRDl/NR9jLV8iUNjlWT+yGLOlT2bVy/+FjtBkshDoHTxrL5c6SFgtHtUe2DKmN9zxx4op/nmjXFRsU3LhCjXU8SUDv72QU3HhyNtAWCURdAhTcRN2xZ89JwJME9P5eRcGN47OBghvHWbGkewRuv03NdFKkZooXJ44SpXxY0DzluGULMnJLMxG55Q8RucWRlFLZ04q1me5dkTZpMktT6loKXcZ8vRrzvv1WCW7qfvghhjVsYDd6jKy48cAB9Jw7Dy9FlJ53smfDN/37m9k3FdzUeP99jG7W1K5N6ccEkSpr5tatyo96wo+vGjdSNk05xYkVC5OEIEiKeOwd0l73OXOVn9FE5Oj2lSuhW/X/1vIouLFHj8/0TICCG52Mnt5fLCm40clEo5seIUDBjUcw0ogLBPT+t0J2mYIbFwaeVXRJQC+CmwcPn6BBnyk4cPysSgUV2q8Zvixfymnm0k69XpNx8Lfzb+00F3bes2vnhRCetBgwHdt+Pq7ELf1bVUfnRpWMdbwluAkXiwOjZq/H5KVbxGID0KByaYT0aiT8jo7L12+hVveJuHD1FmTqpyWjO6JkoZxGn9w5Ofr7RdTpMRH/PHiCpIniYXlIFxTNl9WmSZn+arTwc+rybSoVlWnBKh8Xw6Q+TRA/XrDpbbfOnfXPrcbcrEzBjZsAWd1tAnp/J6Pgxu0pQAMkQAKCAAU3nAYkQAKeIKD39yoKbhyfBRTcOM6KJd0jIIUkzSZNxikhuIkfHIypIgpMmQIF7BqVgpvmEyfh9ytXHRLc5EwnBTfdkDpJUqt2pTBltBDczH8ruJFCl6EOCG42//ILegjBzYtXr1A4W1asHTDAzL6p4EZGoxnVtEmkgpvJGzZg2qbNRuHPCBPBjYFT8FvBzacOCG6kIGj9/v0qwk3bSpXQowYFN2aDxIuAJEDBjU6GVe8vlhTc6GSi0U2PEKDgxiMYacQFAnr/WyG7TMGNCwPPKrokoBfBjRR1tBWpjtbvPuRWRBcpnmk5cCa2/nRU2enZ9At0b/q53cgwpmmoYkSPhtB+5iIdbwlu5IQ69Ns5NOgdqsQvWdOnxMrxXVS0mRVbfkb3sYtU+qaPiufD/BFtPSZqkYKWuj0n4e79x0rMs2xsZxQvkN3m/N605xA6j1yAR0/DkCh+MNrWLoclIvLO9Vv3RLjgIPRuXhUdG1Swy9imcSsPnPXPigmf3aLgxmeo2ZANAnp/J6PgxsbA8jYJkIBTBCi4cQoXC5MACdggoPf3qqFDh2Lbtm0oL1KqDBkyxEYvtX+7ZMmSXneSghuvI2YDbwk8ePIELUVKqcNnzkJGbpkoUkp9VrSoXT7X795RKaUu3bzpMcHNlI0bMWXjJiV0qVSihIogEz0oyK4fK374AYMWLxEbr8Lxbu5cWN67t1l5U8FN+WJFMaVNG7tpqqSd4ctXYPGuXUog07xcOfSrU1vZlJxaTJ6MX8+eU5vQxohoOdVK2d+E91IIgTpMn47vjh5T9qTYRopuDIcrEW6aishCp0VkoVSJE2NRj27ImS69wZz6vSLTfY0dh6t37iBD8uRY0qsHMtpJ5WVWmRck4CECFNx4CKS3zej9xdJUcLP/rbLR28y8Yd8XL5be8Js2fUuAghvf8mZr/xHQ+98K2ROD4CYkJAQffPDBf53T0Rn/VuhosPzoql4ENxLR9BU7MHTaapUfupiIuLI8pBOSJErgFD0ZNWb8gk0IWbBRRY1xRKxy4PgZNBSil/uPnqmILytExJd3TCK+eFNw8/jJMzQfMAO7D55S4pWx3RuiRrl30XrwLBFx55iK0jOicz00q/GxUxzsFd605zDaD5uDZy9eIV3KJPhmUndkz5TGapU/L95Ao75TVKQdKUbqItJt9WxWBau270Of8UvxNOylEu3MHNwKH79rf5eW1Qas3HTGPyvVfXqLghuf4mZjVgjo/Z2Mghsrg8pbJEACThOg4MZpZKxAAiRghYDe36souLEyqDZuUXBjAwxve5yAFIV0nDET3x45giCR9qhP7Vpo8dlndts5eu6cEJ+E4t7jxx4R3MjGVv34IwYK8Yzc7FYsZw7M6dQJieLFs+mHFMeYpqGSIp0pIjqP6WEquJERcOZ16YIk8eObFjE7D3v5Eu2mTcOe4ycUiwF166DJp5+qMqacZHqoTlW+QIfPP7cbMcc0DZWMFG0p0pGRggwCGpmaS6bdypsxk5lPphdSTNMoZDyu3L5NwY0pGJ5rjgAFN5obEusO6f3FkoIb6+PKu4FJgIKbwBxXPfRK738rJGMKbvQw0+ijJwjoSXBz4vRl1BfCl7/v3Bc7f2JgVNf6aPBFGacx7D1yGo37TsWDx8+QMF4czBrSGp+Usp4jW0bE6Tt+GRaLiC3yeL9ILiwa1R6JEvy38OBNwY1sc/6a3eg/eblY+HiDz94viI71K6C5iNJz8+4DZEqTDMvHdUHOzGllUbePZ89fqGg163YdVLZKv5Mbi0V/E8SPG8H2g0dP0OGredguUm3JQ/o2bUALJEoYD5LJoCkrsXDd90ogVTBnRiwY0Q4Z06aIYMeZG87454xdb5Wl4MZbZGnXUQJ6fyej4MbRkWY5EiABewQouLFHh89IgAQcJaD39yoKbhwdaYCCG8dZsaR7BKRwZeqmTQjdIDeFvUE5kSZpcpvWiB0zpk3Di3buxIiVq5Q4RgpFFnTrinyZzIUiZ29cV+KQm/fviygs9lNKyYYO/fknWk0R62Qikoy0ObdLJxTMktWmD/eF2EdG5pERZ6RQqLtF9BhZ0VRwkzRBAmWzcNZsNm1e/PsvIYCZpAQtMr3WtHZt8UH+/Kq8JafS+fJiRocOiBcnjk17pn2SQp95XTvDtP27jx6ipRAuHTt/AQlEezM7dsB7efLYtPfTyZNoP30GHj97ZlNwc+3ObRV96PKtW0ibNCmW9e6JTClT2bTJByTgDQIU3HiDqhds6v3FkoIbL0wKmtQsAQpuNDs0Ae+Y3v9WyAGi4Cbgpyk7+JaAngQ3cqdN/0krsGDdHhWdJmXShAjp2RAVy7xjczxPnrmMeHHjqDRMhkJPnj5H26GzRYSYf4Ui+bNnwOxhrZHDIoqLjIazbPNPGBS6SnxQhyFWjOgY3b0+Gn7xP4Mp9ettwc3Zy3+hbo9JuHzjDpInjo8PiubFBpFaK1wsxtQu/y4m9GmKWDFjmPlkenHwxFkkS5IA2TKkNr0d4Vz2Y+rSrZi4eAvChNBIpoOS0XOaVv8oQ
tnw1+EYJ6IETVy8WQmBsmZIiUUjOyB31nTGsrfuilzkIjrPLyfOqfRdX4rIPCE9GyFucGxjGXniDf/MGvDjBQU3foTPphUBvb+TUXDDiUwCJOAJAhTceIIibZAACej9vYqCG8fnMAU3jrNiSfcJ7P/jD7SdOg0Pnz6FFKbMaN8OxXPlsmr4zsMHqqwUusjDU4IbGS2n+aRJSnwiI8g0E5Fl+tT60mYKqC2HDqHXvPl4FhaGxELMMluIVYrlzGnms6ngxmhTRPCxlapq5tatmLB2nRIS5cuUUQh0OgthSxKjzQOCU5u3nKRAZlLrVvioUCHjc9MTGRFnyNJlkGmv5FFSpLya1bEjEsb9bzPZCxFRp9PMWSq6kPRJ9re5jehC0t6gJUtVJCBpz1ZKKcuoOdbEULI+DxLwJgEKbrxJ14O29f5iScGNBycDTWmeAAU3mh+igHVQ738r5MBQcBOw05MdsyCgJ8GNdP3KX7dFiqWZOHb6kupJsIh0U+nDoqj+aUn8n73zgI+q2ML4Rwu9GSBIB6kCIgoiXUSQpoAU6R1Ch9CLgFKVIr1Kld6kiYAISueJgEqA0Amd0GsK7c2MbtxNdpPte+/ud38/yd57Z86c+c+8l9nJN+e8+UY2pEjmpyK//BF8Hj/+ehh/nDiHKYPbxBLlHDp2VolBZLQceWXP7I/Ahh+hskh7lF5EaDl3+SaW/bgXP2w/qFIryTLVy72NqV+0MYluI5+7WnAjhUY9Ry8QaZoOyOaUeEVobZAiaRJMH9oetUT/LV1SNDR8xhos+GEnShTNiw/fK4xiBXMhV9aMSCui1shTQldu3MFBIYpZ+/MBHDlxUQl5pL3q5YqJ/raN1V/5bsueI+g2cr6KEpQ6RVJMHNgatT8sKV+ZXIePn0frQdNxXXCWgqWhneujfYOPkPDfXOCu8s/ECQ/eUHDjQfhsWhHQ+5qMghtOZBIgAWcQoODGGRRpgwRIQO/rKgpurJ/DFNxYz4olHSfwJCJCCD9mqVRK0poUh0zs0B6Z079mYlwKRMav/QELRYQbuU8kL2cJbqQtGTlntIic80zYloKWr5o3Q53SpeUrk+vEpUsqDdaFGzfU8+olS2J82zZIntT0cJWx4EYWlGKXUa1aoqYoH/OSoqNec76DjMgjxS896tRGp5o1TcQ5klOvOXOw/eifqvqbObKraEBvvG4a8Vnuc63aswejlq8Qac7l4bnEGNasKRpXND08J41Ikc/4NWvVPpgU+cgoN1n9M5i4J+0t/+03fL1qtbInX1oS3Mj2Ok2bhj3BIi28SGM1UAiMmleubNIPE+O8IQEXEKDgxgVQXWFS7wtLCm5cMStoU6sEKLjR6sh4v196/10hR4iCG++fp+zhPwT0JriRXoecvyrEHvPw56nQeIdRCnJmfRUYS3AjK2769RD6jl+CO/cfx2tHppKaMrg1crweOyWSqwU30rnNu4+g85dzxJf7Z9G+yjRNy8f3QCb/dNHPYn4wCFpmrNimogLFfG/uXoYD/rhsMYzr0wwBGf87TWQoe1ZE3GklRDSnLl4XmwYJ0KVxNQwMrAuZE9vctWj9b/hi8nJEiAg6GUWkne+Gd0RZkapKXq7wz5wPnnpGwY2nyLNdAwG9r8kouDGMJH+SAAk4QoCCG0fosS4JkICBgN7XVRTcGEYy/p8U3MTPiCWcS0BGb+k6cxbuPnqkDBfIlhVtqlYVKY4KqvRSx0MvqRRNv/59TAlI5CEmKUBxpuBGponqKQQtu48FKx+kgOazMmXEf6WRIyAT5Ptf/vwLi7b/ghv37qkyMm3SzG5dUDRX7lhAjAU3aVOmVP7KVFmfV6iA2qVLIYu/P249eICtfxzG9zt2KvvSyNtv5BHporqYRLcxGD985gy6irROUpgjr2zCRquqVfDBW0Uh27hw/QZW792HTf/7n9iDilJlqhR/G+PatTOJbqNeiH9OXr6E9pOm4Nrdu+rR23lyo0P16ngnX151H3LpMtbs24+fDx9W4yDOvsWZUkpWmrl5s4rUI4U6sr/FRX9yibRSiYXwR17pUqVEy48qwz91GnXPf0jA2QQouHE2URfZ0/vCkoIbF00MmtUkAQpuNDksPuGU3n9XyEGi4MYnpio7KQjoUXAjB+7eg0eYtWI7vt+4C7fNCGaEZgR5c2RGu3qV0ahG2VhpjAyD/3dIKMbMXY9dh4Lx7PlLw+Pon6+lTYlWtT9A58YfI62IfGPucofgRqZnatJnMv46fUm5IPvXvWkNDBJCF0O0GHO+yWcy/dSsFT/jmKgrU0VZuhInSoAi+XKiXf3KqFP5PSQVYqWY1+Mn4eghou1s/O2weiUj5sz6sj3Sp00ds2j0fURkFAZPXI7FP+5Wop9338yNeSM7IWuAvyrjTP+iG9XIBwpuNDIQPuyG3tdkFNz48ORl10nAiQQouHEiTJoiAR8moPd1FQU31k9eCm6sZ8WSziEgxRnr9u/HiGXL8Sg83KLRZH5+QhBSDTv++gvHL4Y6VXAjG5V/qO8zdx6OnD1n0QfDi4xp02J0yxaoXLy44ZHJT2PBjYyUI9Oi/yiEMPKnpStXQACmduqIwjlzWiqCLSKd1RCR3skgTrJYULyQ0YLGieg72TLEPjwn60nuC7ZvV1FuokTaKEuXFPP0rV8PS3fuFCKdKxYj3Mj6N+7dVRGADGm/YtrMnD49lvXvi1wBcad9j1mP9yRgLQEKbqwl5eFyel9YUnDj4QnE5t1KgIIbt+JmY0YE9P67QnaFghujAeVHryagV8GNYVCehkfij+CzCD57BXeF8CZJ4oQIyJAO7xTKI8Qj2ZEwUUJD0Th/hl4NE2mVzuDilTA8eybC56ZKjkJvZEXpYvnF5/9yPMdpROMvnzyNQPCZSzh76SZu3LqP8IhIlZ8qTepkyPl6JrxbODeyZTYNnevOLmndP3tYUHBjDzXWcSYBva/JKLhx5mygLRLwXQIU3Pju2LPnJOBMAnpfV1FwY/1soODGelYs6TwCUvyxOzhYiD/WIOTKVXFgyVSYkiNjRvSsW0dFvWk3eYpLBDeyN3cePcScn7aqtEwPnjyJ1UGZoqli0SLKl0LZc8R6b3hgLLhp/MEH6N+gPmb/tAVLhGglpqhIRkyuJKLU9G/YAHkyv24wYfHnsYsXMGndeuw9fkKlwIpZMH2qVJBttq8mDs8JsUxc1wuRQmv5rl2YunGTirhjXDaBOO0mU00NbdwY2TNlRJtvJ8YruJH1bz98gBW7dmOnEEZduHFTRcWR4ysvCm4UBv7jQgIU3LgQrjNN631hScGNM2cDbWmdAAU3Wh8h7/VP778r5MhQcOO985M9MyWgd8GNaW94RwLaIkDBjbbGwxe90fuajIIbX5y17DMJOJ8ABTfOZ0qLJOCLBPS+rqLgxvpZS8GN9axY0vkEpADk8Nmz+OvCBZFm6QlSJU+GwjlyoGT+/JBpntx1PXz6FL+fOoVTV6+qdFBSaJNdiH5KFchvMWKMsW8xBTfDmzcTackTClHLfRw4GYLzN26oaDeZRKSc90UkmrxZshpXt+rzpVthwsfT
uHL7NmSEmtTJk6NA1qwoIVilSWHb4TkpLtovUnudEf2NePYMUrRTPE8ekRLqDSSykELdKidZiATcTICCGzcDt7c5vS8sKbixd+RZT48EKLjR46h5h896/10hR4GCG++Yi+xF/AQouImfEUuQgL0EKLixlxzrOYuA3tdkFNw4aybQDgn4NgEKbnx7/Nl7EnAWAb2vqyi4sX4mUHBjPSuWJAFLBCwJbiyV53MSIAHnEKDgxjkcXW5F7wtLCm5cPkXYgIYIUHCjocHwMVf0/rtCDhcFNz42aX24uxTc+PDgs+suJ0DBjcsRs4F4COh9TUbBTTwDzNckQAJWEaDgxipMLEQCJBAPAb2vqyi4iWeAjV5TcGMEgx9JwE4CFNzYCY7VSMBBAhTcOAjQXdX1vrCk4MZdM4XtaIEABTdaGAXf9EHvvyvkqFFw45tz1xd7TcGNL446++wuAhTcuIs027FEQO9rMgpuLI0sn5MACdhCgIIbW2ixLAmQgCUCel9XUXBjaWRjP6fgJjYTPiEBWwlQcGMrMZYnAecQoODGORxdbkXvC0sKblw+RdiAhghQcKOhwfAxV/T+u0IOFwU3PjZpfbi7FNz48OCz6y4nQMGNyxGzgXgI6H1NRsFNPAPM1yRAAlYRoODGKkwsRAIkEA8Bva+rKLiJZ4CNXlNwYwSDH0nATgIU3NgJjtVIwEECFNw4CNBd1fW+sKTgxl0zhe24m8CzZ89Uk0mSJIlu2pzg5uXLl4iKikKyZMmiy/EDCTibgN5/V0geFNw4e1bQnlYJUHCj1ZGhX95AgIIbbxhFffdB72syCm70Pf/oPQl4isDTp0+RIkWK6OYtCW4iIiLg5+eHhAkTRpflBxIgARKwREDv6yoKbiyNbOznFNzEZsInJGArAQpubCXG8iTgHAIU3DiHo8ut6H1hScGNy6cIG/AQgREjRmDo0KH44YcfUKNGDUixjbHgpmDBgti/fz/Kly+P3Llz4/z58x7ylM36AgG9/66QY0TBjS/MVPZREqDghvOABFxHgIIb17GlZesI6H1NRsGNdePMUiRAAqYEEiRIgMqVK2PGjBnInz8/Ygpu7ty5g0WLFqF37944dOgQSpQoYWqAdyRAAiRghoDe11UU3JgZVAuPKLixAIaPScAGAvtOnMCqPXvx6tUrlH2zEOqXK4dEFDnbQJBFScA+AhTc2MfN7bX0vrCk4MbtU4YNuonA6dOnUaBAAaRKlQrPnz9HpUqVsG3bNsiINtWqVVNiG/n58ePH+O6779CuXTs3ecZmfJGA3n9XyDGj4MYXZ65v9pmCG98cd/baPQQouHEPZ7ZimYDe12QU3FgeW74hARKwTKB27drYuHEjkidPjowZMyJr1qw4cOAAcubMCX9/fwQHB6s//shIwfK/xIkTWzbGNyRAAiTwLwG9r6souLF+KlNwYz0rliQBEiABEtAWAQputDUeFr3R+8KSghuLQ8sXXkCgSJEiOH78eJw9kVFvbt26hdSpU8dZji9JwBECev9dIftOwY0jM4B19USAghs9jRZ91RsBCm70NmLe56/e12QU3HjfnGSPSMAdBPbu3Yvq1aurA0eW2pMim06dOqnoN5bK8DkJkAAJGBPQ+7qKghvj0Yz7MwU3cfPhWxIgARIgAe0SoOBGu2Nj4pneF5YU3JgMJ2+8jIAMl9y3b1/IfOWWrrp166q0U5be8zkJOIOA3n9XSAYU3DhjJtCGHgg0atQI69auQeX3i+jBXZ/2MSrqOSJFFLvECRMhebIkPs1CL50/8Ndp5Mv5On4/dk4vLtNPLyOg9zUZBTdeNiHZHRJwEwGZuiAgIEAdNrLUZIoUKXDw4EEULVrUUhE+JwESIAETAnpfV1FwYzKccd5QcBMnHr4kAacReCb2uC7dCkPiRImQLUNGppxyGlka8mUCFNzoZPT1vrCk4EYnE41u2kXg7t27yJIlCyIjI83Wl1Ft1q5diypVqph9z4ck4CwCev9dITlQcOOs2UA7WiewY8cOnDtHMYDWx0n6t337dqxZswYFCxZEUFCQHlymj4KAXJvVqlWLLEjAIwT0viaj4MYj04aNkoBXEBg8eDDGjx+PqKgos/3JkycP18BmyfAhCZCAJQJ6X1dRcGNpZGM/l4Kbzp07I2emTLFf8gkJkIDTCLx48QJX7txR9rJnyICECRM6zTYNkYCvErj98CGeRESoFLpaZpBAnJJ4pWUHXe2b3heWFNy4eobQvqcJfPzxx/j555/NupEuXTrcvn0biYRimBcJuJKA3n9XSDYU3LhyhtA2CZCAPQTGjh2L/v37o1q1atiyZYs9JliHBEjAxwjofU1GwY2PTVh2lwScSEAKymXa7Qix2RzzSp48OYYPH44+ffrEfMV7EiABErBIQO/rKgpuLA5trBfXrl2jKDMWFT4gAecTuH79Oj7//HNleOvWrZARCHmRAAk4TkD+DbhMmTKOG3KhBQpu/vc/F+J1vWkKblzPmC14lsDGjRvRrFkzPHr0yMQRmZ+8S5cumDRpkslz3pCAKwjofRNCMqHgxhUzgzZJgAQcITB69GjI09oyWsqmTZscMcW6JEACPkJA72syCm58ZKKymyTgIgJvvfUWjh07Fsu6n58fQkNDkTlz5ljv+IAESIAELBHQ+7qKghtLI8vnJEACniJw5swZ5M+fXzX/9OlTSFE0LxIgAd8gQMENBTeamOmlSpXShB90QnsEZLhkf39/PH782MQ5uVg5dOgQChcubPKcNyTgCgJ634SQTCi4ccXMoE0SIAFHCMgN0i+//BJ16tTBunXrHDHFuiRAAj5CQO9rMgpufGSispsk4CICs2fPRu/evfHkyROTFsqXL4/du3ebPOMNCZAACcRHQO/rKgpu4hthvicBEnA3gZMnT+LNN99Uzcq/ayVJksTdLrA9EiABDxGg4IaCGw9NPdNmKbgx5cE7UwKBgYGYN28eZA5Mw5UvXz6cPn3acMufJOBSAnrfhJBwKLhx6RShcRIgATsIDBkyBCNHjkSDBg2watUqOyywCgmQgK8R0PuajIIbX5ux7C8JOJfA/fv3VRSbyMjIaMOpU6fG3Llz0bBhw+hn/EACJEAC1hDQ+7qKghtrRpllSIAE3ElARiKUEQnl9fLlSyRIkMCdzbMtEiABDxKg4IaCGw9Ov/+apuDmPxb8FJvA4cOHlVjAcIorWbJkGDNmDHr27Bm7MJ+QgAsI6H0TQiKh4MYFE4MmSYAEHCIwYMAAfPPNN2jcuDGWLVvmkC1WJgES8A0Cel+TUXDjG/OUvSQBVxKoWbMmfvrpp+gmZPTfO3fuMGVBNBF+IAESsJaA3tdVFNxYO9IsRwIk4C4CR44cwbvvvotEiRLh+fPn7mqW7ZAACWiAAAU3FNxoYBoCFNxoYhg068SrV6+QM2dOXL58WfkoQ/FduXIFmTJl0qzPdMy7COh9E0KOBgU33jUn2RsS8AYCffr0wYQJE9CiRQssWrTIG7rEPpAACbiYgN7XZBTcuHiC0DwJ+ACBzZs3o169ejBEuWnatCmWLFniAz1nF0mABJxNQO/
rKgpunD0jaI8ESMBRAr///rv6W6c8MB4eHu6oOdYnARLQEQEKbii40cR0peBGE8OgaSdGjx6NoUOHqrRSZcqUwb59+zTtL53zLgJ634SQo0HBjXfNSfaGBLyBQI8ePTBlyhS0adNGpY70hj6xDyRAAq4loPc1GQU3rp0ftE4CvkDg2bNnSJcuHZ4+fYqkSZPil19+Qbly5Xyh6+wjCZCAkwnofV1FwY2TJwTNkQAJOExA/s1KrstSpUqFR48eOWyPBkiABPRDgIIbCm40MVspuNHEMGjaCRnRJm/evEicODEWLlyI+vXra9pfOuddBPS+CSFHg4Ib75qT7A0JeAOBLl26YMaMGQgMDMSsWbO8oUvsAwmQgIsJ6H1NRsGNiycIzZOAjxDo1KkT5syZA39/f9y8eRMJEiTwkZ6zmyRAAs4koPd1FQU3zpwNtEUCJOAMArt27cIHH3yA9OnT4+7du84wSRskQAI6IUDBDQU3mpiqFNxoYhg070SuXLkQGhqqwvHJsHy8SMBdBPS+CSE5UXDjrtnCdkiABKwlIIU28o9FUngzbdo0a6uxHAmQgA8T0PuajIIbH5687DoJOJHA0aNH8c4772DQoEEYNWqUEy3TFAmQgC8R0Pu6ioIbX5qt7CsJ6IOAjDxYpUoVZMyYEWFhYfpwml6SAAk4hQAFNxTcOGUiOWqEghtHCfpG/ZUrV2L37t2YPn26b3SYvdQMAb1vQkiQFNxoZjrRERIggX8JyFRSCxYsQM+ePSH/CM2LBEiABOIjoPc1ma8KbuShiVOnTsU3vHxPAiRgJYFXr16hV69e6Nu3L7JkyWJlLRYjARKwl4Cfn5+KWGBvfa3W0/u6ioIbrc4s+kUCvktg69atqF69ulqfXb161XdBsOck4IMEKLih4EYT097dgptly5bhwoULmug7nSABbyRQqFAhfPbZZ17TNb1vQsiBoODGa6YjO0ICXkOgRYsWWLx4Mfr06YNx48Z5Tb/YERIgAdcR0PuazFcFN7LfUhxQ+I1srpsctEwCJEACJEACLiBw+cZtREQ+Q+Sz5y6w7lmTel9XUXDj2fnD1kmABGIT2LRpEz799FPkyJFDZWqIXYJPSIAEvJUABTc6F9wcOnQI3bp1U/PzwIEDus3b7G7BTdniBXDx6i1Ueq+wt/5vm/0iAY8R+GnPEXxYpTrWrFnjMR+c3bDeNyEkjwoVKiAqKkr9Ubt8+fLORuQWe+7+XeGWTrEREvBhAo0bN8aKFSswYMAAjBkzxodJsOskQAKl1wz1AAAk1ElEQVTWEtD7msyXBTd7Nq/AnOGB1g41y5EACZAACZCAJggcP3MZDYImIOzuQ03440wn9L6uouDGmbOBtkiABJxBYN26deoQcp48eXDu3DlnmKQNEiABnRCg4IaCG01MVXf/EbViyTfRqHpZNKhWWhP9pxMk4E0E5q7ZgSNXnmHVqlVe0y29b0LIgaDgxmumIztCAl5DoEGDBkqcOWTIEAwfPtxr+sWOkAAJuI6A3tdkFNxQcOO6/3XQMgmQAAmQgCsIUHDjCqrOsUnBjXM40goJkIDzCKxevRoNGzZEgQIFEBIS4jzDtEQCJKB5AhTcUHCjiUlKwY0mhoFOkIBTCFBw4xSMTjdCwY3TkdIgCZCAgwTq1KmDDRs2QG6UDh061EFrrE4CJOALBCi40cYo2/r9XQqNGOFGG2NHL0iABEiABGwjQMGNbbzcWZqCG3fSZlskQALWEFi2bBmaNm2KwoULIzg42JoqLEMCJOAlBCi4oeBGE1PZ1g07R51mhBtHCbI+CVgmQMGNZTaefEPBjSfps20SIAFzBGrVqoXNmzdj9OjRGDhwoLkifEYCJEACJgQouDHB4bEbW7+/U3DjsaFiwyRAAiRAAg4SoODGQYAurE7BjQvh0jQJkIBdBL7//nu0bNkSb7/9No4ePWqXDVYiARLQJwEKbii40cTMtXXDzlGnKbhxlCDrk4BlAhTcWGbjyTcU3HiSPtsmARIwR6BatWrYtm0bxo4di759+5orwmckQAIkYEKAghsTHB67sfX7OwU3HhsqNkwCJEACJOAgAQpuHATowuoU3LgQLk2TAAnYRWDevHlo164dSpQogUOHDtllg5VIgAT0SYCCGwpuNDFzbd2wc9RpCm4cJcj6JGCZAAU3ltl48g0FN56kz7ZJgATMEahcuTJ27twJ+YfYnj17mivCZyRAAiRgQoCCGxMcHrux9fs7BTceGyo2TAIkQAIk4CABCm4cBOjC6hTcuBAuTZMACdhFYPbs2ejYsSNKly6N/fv322WDlUiABPRJgIIbCm40MXNt3bBz1GkKbhwlyPokYJkABTeW2XjyDQU3nqTPtkmABMwRqFixInbv3o2pU6eia9eu5orwGQmQAAmYEKDgxgSHx25s/f5OwY3HhooNkwAJkAAJOEiAghsHAbqwOgU3LoRL0yRAAnYRmDZtGrp164by5cur/S67jLASCZCALglQcEPBjSYmrq0bdo46TcGNowRZnwQsE6DgxjIbT76h4MaT9Nk2CZCAOQJlypTBgQMHMGvWLAQGBporwmckQAIkYEKAghsTHB67sfX7OwU3HhsqNkwCJEACJOAgAQpuHATowuoU3LgQLk2TAAnYRWDSpEkICgrChx9+iB07dthlg5VIgAT0SYCCGwpuNDFzbd2wc9RpCm4cJai/+hGRUTgbeh3nLofh5p37CI+Igp9fYmR6LS2K5MuBfDkzI2HChB7rmNb9swUMBTe20HJfWQpu3MeaLZEACVhH4L333lM5refOnYu2bdtaV4mlSIAEfJoABTfaGH5bv79TcKONcaMXJCAJvHz5Epev38api9dwLeweHj2NUGDSpU6B/Dlfx1sFciF5Mj+PwdK6fx4Dw4Y9RoCCG4+hj7dhCm7iRcQCJEACbiYwbtw49OvXD1WrVsW2bdvc3DqbIwES8CQBCm4ouPHk/Itu29YNu+iKdn6g4MZOcDqrJkUsC9b9ig07DuHE2csIj3putgcJEgB5c2RG288+RJNa5d22uaR1/8zCsuIhBTdWQPJAEQpuPACdTZIACcRJoHjx4vjzzz+xcOFCtGzZMs6yfEkCJEACkgAFN9qYB7Z+f6fgRhvjRi98m8Bvh47ju1W/4I/j53D3wROLMNKnSYG6H5VCt6bVkC1zBovlnP1C6/45u7+0px8CFNxod6wouNHu2NAzEvBVAmPGjMGgQYNQo0YNbN682VcxsN8k4JMEKLih4EYTE9/WDTtHnabgxlGC+qj/4OETNO0/Bf/7+6xVDkvhTeX338Kk/i0QkDG9VXUcKaR1/+ztGwU39pJzbT0KblzLl9ZJgARsJ1C0aFEEBwdj6dKlaNKkie0GWIMESMDnCFBwo40ht/X7OwU32hg3euG7BGTUmAkLNmHcgo149co6Djlf98e3A1qiQonC1lVwoJTW/XOga6zqBQQouNHuIFJwo92xoWck4KsERowYgaFDh6J27dpYv369r2Jgv0nAJwlQcEPBjSYmvq0bdo46TcGNowT1Ud8gaDly4rxKG1VRbBQVK5
gTubNmQsoUSXFPCHL+CrmI1dsO4I/g83gpdp6k6OYzcZpr4oBWLo90o3X/7B1lCm7sJefaehTcuJYvrZMACdhOoFChQggJCcHKlSvRsGFD2w2wBgmQgM8RoOBGG0Nu6/d3Cm60MW70wncJGAQtExZuQraA11DunUIoVSwf8ufKgtfSpkSkiAZ8RqTg3rrnKH787TCeRj5TsPJkz4Tvx3RDgdxZXApP6/65tPM0rnkCFNxod4gouNHu2NAzEvBVAlJsI0U39erVw5o1a3wVA/tNAj5JgIIbCm40MfFt3bBz1GkKbhwlqI/64RFR2Lb3KArny4F8Ihe5pUumdho9ex3mrN6OFy9fIU3KZFgwqjMqlHTtSS6t+2eJV3zPKbiJj5Bn3lNw4xnubJUESMAygbx58+LcuXP44YcfULduXcsF+YYESIAE/iVAwY02poKt398puNHGuNEL3yaw70gIXrx4gTLvFETiRIkswvhp9xH0HLNAHFB6qg4kdW1cDV90qoeECRNarOOMF1r3zxl9pA19EqDgRrvjRsGNdseGnpGArxKQ6aRkWqlGjRph+fLlvoqB/SYBnyRAwQ0FN5qY+LZu2DnqNAU3jhL0vvpXb95B496TcfLCVbWp1L9NbQS1quXyTSVrSWrdP+N+UHBjTEM7nym40c5Y0BMSIIF/COTKlQuhoaHYuHEjPvnkE2IhARIggXgJUHATLyK3FLD1+zsFN24ZFjZCAk4h8PLFS/SfsAQLN+xS9t4XkXCWfN0NadOkdIp9R41o3T9H+8f62iNAwY32xsTgEQU3BhL8SQIkoBUCffv2xfjx49GsWTMsXrxYK27RDxIgATcQoOCGghs3TLP4m7B1wy5+i3GX8BbBzb0Hj3Di3FXcuvtAiEQSIHOGdCicNztSpUweNwCjt1HPnuO0EJlcuBqGiMjnSJcmBQrlyYpsmTMYlbL9o/Tt2OlLuHXvEZIl80MBEWEmT/YAzQhYYvbouTjp1f6Lmfhx91H1qnXdD/B1r6aa8Vfr/hnzpODGmIZ2PlNwo52xoCckQAL/EMiaNSuuXbuGLVu2oFq1asRCAiRAAvESoOAmXkRuKWDr93cKbtwyLD7diIxae/rCdYRevwW5x5EudQq1N5I5Y3qbuFwS9UPOX8WDR+FIljQJ8mQLQL5cr8MvSWKb7BgXlumaTpy7jEvXbuGVSGOdLcAfRfLnEPb9jItp6vPyzXsR9PVCFQG4sNgfWvltEAJsZOnKDmndP1f2nbbdT4CCG/czt7ZFCm6sJcVyJEAC7iIQFBSESZMmoXXr1pg/f767mmU7JEACGiBAwQ0FNxqYhoCtG3aOOq1nwY0UXmz+7Qjmrt2Bw8fP4dnzlyY4UohNoYoiFVKHhpVRpnhBi4KRm7fvY+aKn7F62wGE3X1oYiOhEO8UzZcdgZ9XRe3KJS1uLl0UIp1uo+YLUc1D5BVimokDWuFJeCTGzdtgkvdbGpc23xKbSgM71MWH7xc1aU/e7D8agqFTVuJxRCTy58iMCf1bIuNraWOVi/lgx4G/MWLmWkQ8e4bCb2THvJGdYhax6l5ybTt4Jn4SOcuFq+jWpDoGd/wsmp/MKT5n9S9YvGE3XopNspxZMmJsn6bI8XpGi/ZlSOQvp6/Go6fhSJUsKQYFfma27xYNGL2Izz+joh7/SMGNx4fArAMU3JjFwockQAIeJBAQEICwsDBs374dH330kQc9YdMkQAJ6IUDBjTZGytbv7xTcaGPcvNGLYHHIZ+aKbdi27y88eBxu0sVECROgiDiQ1PSTCvi8ehmkSJ7U5L3hRgpifth+EHNW/SKEMVfU933DO/lTHmxqVL0sAsUeSwYLexRS5DN40jLsEXsAyf388FW3Bnj3zTcwf92v+E7YvS72X4yvAP+0as+mff0qSC4OKBlfd8T+Sq9vvsep0GtIKfYRvuraAOXefdO4iNnPN2/dQ4+vF+HitTCkTp4MY4KaoETRvGbLxvdw6aY96D12kRLcFC+UC8vGdjfp+x/HzuKLKStw//FTpEiaFP3afopq5YtbNCsPifUfv0REFL6m9lvqfFgSPVvWsrjXZNHQvy/i8y+++nxPArYQoODGFlruLUvBjXt5szUSIIH4CXTt2hXTp09H+/btMWfOnPgrsAQJkIDXEKDghoIbTUxmWzfsHHVar4KbsDv3MXDicmzedVhtfMTFIZXYTJozPBBVyhSLVezw8fNCKDMPZ0JvxHpn/EBuUNWv+j5G9WhsNnzwKbFZUr/nBNwQm0cFxKmv/u3qYPiMNWKD55axGZPPaVImw4R+LVDno1Imz2XfmvSZjL/EhpkUDc38sgNqVHjHpEzMGylC6SVOXS3/ab/atOnetIbKLR6znDX3oUI81Ei0f/bSDSTzS4xpX7RTYiPjutLH9kNnY/+fp1V7DQSbcX1bmN24u3LjNloPmoE/T4UqsVHLOhUxontjJBW27bms8c8eu66oQ8GNK6g6bpOCG8cZ0gIJkIBzCfj7++Pu3bv47bffULFiRecapzUSIAGvJEDBjTaG1dbv7xTcaGPcvMkLKXCR3zvHz98gDrhExtk1eaCmXb3KGC0EKDEvKW6Reywbfz0U7x6LjPQybUhbEZ0mZ0wzkKKdtl9MF8Kfv4XgJjG+6dMcW/f8iS17j4qoNrGKqwdyv6W5EAMN797IRHQjD/sMm7YKs1dtV3Vb1a6Ib3o3Q8JECc0b+vfphh2/o+vIeYgQvpQonBtLvukO//Rp4qxj7qVk20eIbeQ+i7ya1CyrDkQlTpQourhM6/SNOGg1ZclmPH/xCm/kCMCi0V1RIHeW6DKGD5LN8Bmr1XjJw0tF8+XAwtGdkUMcYrLnssY/e+yyDglYIkDBjSUynn9OwY3nx4AekAAJmBLo2LEjZs+ejc6dOyvhjelb3pEACXgzAQpuKLjRxPy2dcPOUaf1KLiRKZo6j5iPX0REF3nJiDHFCuZETSFKKSg2fpILkcrNOw9V1BtZ5paIWmNOtHL64jUlBDkdel3ZSZ0iKT6tVBKVShURopoUuHA5DGvF6a5Dx86pk12ynRafVsCIHo1ihTw2FtykF3WTipNcYXceqEg2dauUUiIcubl05MR5LP1xL66JE1fyyi/SSy0b1x05s2ZS9/Ifuak0Zs56TBYbNrJOI3ECbeLAVjDe1Iku/O+Hs6IPjYVIRgp8ZPuLRV7xUsXyxywW7/2Dh0/U6axVWw+oPpd9pwAWiEg56dOmjlVXipXaDJ6h+uKXOBGGdK6PDg0+io6EIyuER0Sh3/jFWLl1v+rLe0XFybaRnREgTsbZc9ninz32nV2HghtnE3WOPQpunMORVkiABJxHIE2aNHj06BH27t2LsmXLOs8wLZEACXgtAQputDG0tn5/p+BGG+PmLV5Isce05VvwzXcbEPX8heqWjEJTu1IJlBTfvdOnTalEOCFnr+CXg8fwZ8gFNK1ZHuNFFF3j66mIztt/wpLo7+1SAFP67QKoI6L85syaEfcfPFb7Lz+KCMNPxHd8ecmIOQuEWCSX0V6GfG4suEmSOKESk8i9lYyvpUE9cdio5
Ft5kSZVMly8ehsrt+zDH8Hn1d6DPGw0eVDrWAeS9h4+iZYDp+HhkwglZlk1IShOgYpsv5sQ26wTohu5h9OvzacIalXLZJ9C+hnfJdku/2kfhkxZrhj6p0ul9kZKi+jJMS+5T9Fx+NzoPapq5YqJg0ttkTZ1SpOisr8yuo1kKPdtZg3rYHfkX1v8M3GCNyTgAAEKbhyA5+KqFNy4GDDNkwAJ2Eygbdu2KpVU9+7dMXnyZJvrswIJkIB+CVBwQ8GNJmavrRt2jjqtN8GNFKOMnbsBkxZvVqeuZPSavm1ro139ymZD8MrNlq17juANkebJ+PSVfC5PKq3Y8s9JpdfFptS0IW1QoURhE6Qy//n4eRsxfcVWdVpJtjdlcBt8IjawjC9jwY18LjeWAhtUQT8R6SZmWGQZ6rm1EKpIcUziRAkwUkR7aSv8N75+//sMmvWfgnsPnyKXOO20fHwP5BXiHEvXovW/YcC3S5SPlUQarfmjOiFVyuSWiquc6SEXrioRjMydLkNOnxQho7eINFLnr4SpesVE2qupYpOo0BvZLNpZvHEXBk9chnDBM4PYgPpueEcR4rmQKi/HSgpOZKSfSHEyTG78zRfinZJWhHKWOd2d4Z9Fx930goIbN4G2sRkKbmwExuIkQAJOJXDx4kUkTpwY2bL99/s1RYoUCA8Px8GDB6PTi6rfzw8eIF06+0SqTnWaxkiABDRHgIIbbQyJrd/fKbjRxrh5ixdSjNLmixlq30CKS2pWKI4xvZpaPOByRByauXzjTqwItmt/PoigMQvU93p5mKaXSHHUrXmNWHss2/b+iSARWffWvUcq0m3LTytilIiW45fkv+i1xoIbA+f3i+XDNLGPYnzQSL57JNIw9RSpnzb++ocqKoUqc77qaLKH8uDREyG4mY59R09B+vbtgJYiLZZlcXLI+ato1HsirobdQ8b0qbF0XA8UL5Tb4Eqsn9KHwycvIDLymXoXIX7KlOG/HjqO//11Wu2xpEudHMO7NVLpuBImNB9dJ+T8FbQcNB3nhbhI7vP0bF4TfdvUjo7G83dIKFoNnq74y/2i/m3roFuz6vEKgZzlX6yO8wEJ2EGAghs7oLmpCgU3bgLNZkiABMwSiIiIQHBwMEqU+O9vZi1btsT333+PXr16YcKECdH17t+/j7Rp04q1pAi9yIsESMArCVBwQ8GNJia2rRt2jjqtN8GNjErTpO9khF67rTYx+rSqjaCWNaM3MazlcfSEONnVb7LaKJJhjseJ1E6WNm2ePI1AZ3Fa6SchRpHXR6XfwrwRHU3SJ8UU3NQQObtnDG2HlCmSxXJJClEmLvwRY+dvVCe56lZ+DzOGtTOJYPP4SbgIwzwTO38/rvo5JqgpWtWtFMuWfCBPo7UdMkudprIk4IlZcdbKnzFs6irVfsx3ebJnQpMa5SBTP6VLkyrma5N7KUgaMnkFvt+4W9l6u0BOdcotW+YMOHA0BO1E2qkwEWFIpqYaIcJDW+qDiVFx4yz/Ytp19z0FN+4mbl17FNxYx4mlSIAEXEMgMDBQ5a8eOXIkmjdvjhw5ciBJkiR4/vw5Dh8+jCJFimDPnj2oUaMGoqKihDjWQv4F17hHqyRAAjohQMGNNgbK1u/vFNxoY9y8wQv5XVymTdqw8x+xSvl3C6p9CnPRaePqrxR0yANBu/44qYo1/Ph9lTYpWVI/s9XmiUM1Q6euVBF15KGa5ULQUkQc1jFcMQU32QJeUxF4jcsYysqfUtTSfMBUJRrKmSUD1kzqHStqzozl21QqphcvX0HutcwSabdjHmwy2JyzSux1iDRUMr2TOQGPoZzhZ/DpUDQI+ha37z82PIr+mTZVctSq+C4CP68S50EkQ4UNOw4JQdICFRFHphCXB5hkenCZrqv9sNnYczhEFa39YQlMGtAqzkNSBpvO9M9gkz9JwF4CFNzYS8719Si4cT1jtkACJGCZwOnTp1GgQAHUq1dPCWzef/99NGvWDMuXL0f//v0xZswYhISEYOrUqZg5cyauX7+OzJkzWzbINyRAAromQMENBTeamMC2btg56rTeBDfz1+7E4MnL1OaJzHe9XKRjCsiY3mYM05ZuwYiZa5VIpGSRPCKtU484xSW/7P8bHYbNUhsnAf5psWJ8T5NNJWPBjRTwzBChgWt98K5Fv3aKcM7thsxU9kqJkMpLRU7xtGlMww0b97VyqaKYN7KjWQGPPKXWRIiH7ogNopyv+2OZ8C1/rtj5wo2diUvQIkM5Vyj5Jro3q2FVNJqbIj1Wa3Gq7pAIBS2FyZ9XK4OeQgTVRYiUDgthk3zWtGY5ddLO0qadsW/yszP9i2nbnfcU3LiTtvVtUXBjPSuWJAEScD6BEydOoHDhwkiWLJkS0+TOnVttPMiWSpYsiWPHjikBjkwxNXr0aAwcOND5TtAiCZCA7glQcKONIbT1+zsFN9oYN2/wQkaobdhrIm7cvg8ZiVeKUD4u97bNXZOHkZr0naQEJzKSy6IxXVHGTNokg+GwO/fRRKSz/ktE7pWpp4Z3+xwdGlYxvDZJKSUftvikAsb1bW7xkNStuw+UvT9PhUK2v1jsjbwfIz32ibOXRdSaSbgu+vrPfowU+eSMbtPwQYqHZJQZKWyRUWTG922BJrXKG16b/RmXoEX2r1CebCqicv2PS4vU4f9F8jFn7PmLF/h69jqR5murisicL2dmIYLqhBWb92H26u3qWcHcWfD9112RO1uAOROxnjnTv1jG+YAEbCRAwY2NwNxYnIIbN8JmUyRAAmYJZMqUCbdv30bKlCnh5+enIjuHhYUhX758ePjwofovMjISspwU3PAiARLwXgIU3FBwo4nZbeuGnaNO60lwI3NUdxs1D6u2HVQijg71P8Lw7p/HG4I3JiO5CdLpy++wfuchZadbk+oY3PGzOO1IUcnnYjPruAhPLMMYy+g1tUVkGsNlLLjJIgRAayb3Rr44UkDJTRO5OSZDMb8lToPJlFGZ/E1TRpwJvY7GfSapaD4yX/jSsT3wbuE8hibVTxktZ7yIlDNBRMx5KU7Af17tfRFiubVJSGeTCv/ehIoQycfFBp08M/9S8JBinb/FBtfPQlh0884DVSp1iqQY2rkBWtSuGCcbWVimwJJRduRmn+STXZxMO3/5pkpZ9e6buYVYqBOyBvgru9b842z/rGnTFWUouHEFVcdtUnDjOENaIAEScIxAnjx5cOHChTiNJE2aFJcuXVKbEXEW5EsSIAGfJEDBjTaG3dbv7xTcaGPcvMGLH7b/D91EhJuo5y8s7ilY08/VW/ej++j56lBT8UK5sGxsd2R4La3FqnIPYqBIKz3/h19VGbkHMW1I++jyxhFupGBlQr+WaPqJZdGLjNjbWohkZHRfKRySaao/KvNWtD35QdrsOmKu2sORqbMGtq+rDvmYFBI3B0W0nBb/RsuRkXtXTQiKlcYqZh0p0vn92FlEiDTYr0QEnSfhETgbehM7/3cMwWcuq30W2Y86Yv/nG5GuK+ZBqZj27j14hI5iv0n2R165s2ZEmNhjeRIRpQRF04e0Q9Wy1gujnO1fTH95TwK2
EKDgxhZa7i1LwY17ebM1EiCB2ATk/w/JSDZSVGPpkmKcb7/9Fh06dLBUhM9JgAS8gAAFNxTcaGIa27ph56jTehLcyI2G5gOnYd+RUyrN0rf9W6GxiJxi6yVzgLcYMA37/zyt7Ewa2NpiOimDbRmuufXgmSptk4zYMqxTfXRpWt3wGsaCG3mKafXEXnEKTIzLF86TFSu/DYoVqUcKg3qOXoCVWw9Abir1ElFj+rYVOcCNcobflqfB+k3B0ZMXISPTTB/aPs7IOtEOW/ggT5cNn7EGq7cdUKev0qdJgTliw+uDkoUt1Pjv8QKx4TZ06gpEiM0ww5XptTRiwywwzhNyhrLW/HTEP2vsO7sMBTfOJuocexTcOIcjrZAACdhPQEauGT58eJwbEVWrVsW2bdvsb4Q1SYAEvJoABTfaGF5bv79TcKONcfMGLyYu+hFjvlunDrnUqlAc34lDLokTJbK5axPmb8I389crO59WKoHZX3WI146MxitFN/LQT4UShbB2cp/odo0FNzK19Mx4ov8al0+ZzE+0H2g2Uo8UGHUfNR+RQhjzXtE3xIGk7iZRiqUQaNSsHzB12RbVl/gi60Q7bOFDlGhn6aY9GDV7LR48DlfRfLqKw1qDOtS1GK3HYEqKEloJEdHFa7cMj9TeU59Wn4qU6LXirR9dKY4PjvgXh1m+IoE4CVBwEycej76k4Maj+Nk4CZCAIHDu3DmVIj0iIsIiDxnp+ebNm0iTJo3FMnxBAiSgfwIU3FBwo4lZbOuGnaNO60lwYxxqWKZtmiU2YmQ+bFsve+zI6DqdxYmqtWKTRwpuglrUwkCx0WK4jAU0MkSwFNxkjiPVlXF5S4IbaXvz7iPo/OUcPI18hrcL5lIptIxPm8lUV+2HzsJjcSrMUqQcg4/W/pSCpEBxImuHSHslL7l5J1NkWcqRbrAbLk5sdf5qDn7cfVQ9kiGchwhhUqAIL20sEjKUt/envf7Z254j9Si4cYSe6+pScOM6trRMAiRgHYHQ0FCV39rSyZ/UqVNj5cqVqF79P3GvdZZZigRIwFcIUHCjjZG29fs7BTfaGDe9eyHFJfKgzIwV2/4Tl/QTaZuMDudY08eYdlp+WgFjZfqneOxI8UvXkXPx7PlLlBBReLfMGRzdnLGAxpp9G+Py8hCRPPBjLjXW1Zt3VFqpkAvXIKPxzh/dxeRg0E0RbbdJ38n4W6S6ku/nfBU7Uk60k1Z+kHxmrdwu0pGvVhGAsgW8JqIT90RBcWgqvmvR+t+EKGmpYiTLVhVRe2Sk5LSpTVOJx2cnrveO+BeXXb4jAUsEKLixRMbzzym48fwY0AMSIAGgUKFC0SnTY/JIIP6o1rBhQ6xYsSLmK96TAAl4GQEKbii40cSUtnXDzlGn9Sa4aSzydstc4XLjZraInFK9vH2CG4MduaEzS5zgis9OzHRWQc1rYmDgZ9H4jQU0zhTcGOdHj3naS/rUd9xifL9ptxIBdW9aA4MCxWmreDbHop2O48P6X+QG2j+nx3KK9FBSQBRfjvFDIgxzmy9mqrRS0rQUJjWpUQ5f926KZEn94mjN9lf2+Gd7K47XoODGcYausEDBjSuo0iYJkICtBIoWLYrg4GCz1dKnT49bt24hkR0n1c0a5EMSIAGvI0DBjTaG1Nbv7xTcaGPc9O6FFFp8NW01Zq76WQlurBXKxOx3TDutRDrpb/o0i3dPYcOO39F5+FyVzspdghvp67BpqzB71XbV57afVcLonk2io8X8+NthdBn+nTqsVKJwbiz5pjv80zt+evqSiFLToNe3ImV22L8RktuICMllYqI0ub//8DE6DhOHmH7/b50n94kWjumCN7JnNinr6I09/jnaJuv7LgEKbrQ79hTcaHds6BkJ+BIB+V1n8ODBCA8Pj9VtmU5KRnEuW7ZsrHd8QAIk4F0EKLih4EYTM9rWDTtHndaT4ObBwydoJvJxH/zrjNromDKoDRpUi3ujwxwfaadJv8kiT/e5f+20FXZKmysa/UyG6233xQxs2fuXEpIM7vAZerSoGf3eVYIbuak0Zs56TF6yWW0qNatVDuP6tRB+J0Lo1TA07D1RbfzI1E+Lv+6GUsXyR/vkyIejJy6gUZ+JuPvgCV5LmxLLxvXEu+LkmqVLniZrI/hIpsaXDCE9onsjtKpbyfixw59t9c/hBu00QMGNneBcXI2CGxcDpnkSIAGrCEyZMgUDBgyItRHh5+eHoKAgfP3111bZYSESIAHfJEDBjTbG3dbv7xTcaGPc9O6F3CeYuPBHkQpqg9onqP1hCcz6Mv5UUDH7Le2Mn78R4xduUnbqVn4Ps4a1jxaxxCxvuF+8cZc6/PPi5SuUfacA1k/tZ3gF44g1zoxwIxvYe/gkWoo04w+fRCB/ztdFau6eyJY5g4g+Y5qOu18bkbqplUjd5ITDSI+fhKO5SEm+90iI2gsa1P4z9BTpvi1d0pevxR7ONJHaSvIxvuQ4TRrQCqlSJjd+7NBnW/1zqDFW9nkCFNxodwpQcKPdsaFnJOBLBK5fv47cuXObTZ+eJUsWXLlyRaynxCltXiRAAl5NgIIbCm40McFt3bBz1Gk9CW7kxkUnkepo/c5DaqPD3oguUjzTfsgs/LTnqLLTt/Wn6N36kzg3Y4zTUCVOlABTBpmKdFwluJHjKyPHNOs/RYlf8mTLhBUTeqpoM8s370XvsYtUiOJKJQtj/qhOTtu4kYKWxn0n4c79x5BinqVje6Bk0bxmp5vcUBsyZTkWrd+lcrgXy59D5HF/U508i3r+ApkzpMN8kU/eUn2zRuN5aIt/8Zhy6WsKblyK127jFNzYjY4VSYAEnEhA5q3OkSMHoqKiTKwmTZoUJ0+eVJsUJi94QwIkQAJGBCi4MYLhwY+2fn+n4MaDg+VlTa/csh89x8xXqY4cieiydNMe9Bm3SNkp9VZeLBWRYdKmsZz2SIp0jNNZSZHOHBF92HC5UnAj00u3HDgd+46eQtIkiTFRiFfk4akzodfRuM8khF67jUyvpcGSsd1RvFBug0sO/ZSClhZC5LPn8D+Cm4Ht6ioxjyWjm349hB6jF+DR00ikTZUcnT6visUiKvHVsHuQabf7t62Dbs2qx7n/ZMm2uee2+mfOBp+RgLUEKLixlpT7y1Fw437mbJEESMA8gZIlS+KPP/4weZk8eXIMGzYM/fv3N3nOGxIgAe8k8H8AAAD//wP0kZUAAEAASURBVOydCXxM1/vGH0EkQmKNndqqqqW21q5FS1UtbflZY18rtoqd1L7EWluILfa1VaooRYtSey21L0HsBBESkvjfc/wznSQzydxkZnLv5Lmfj85dznnPe77n6FxnnvO+aV4rB1Lx8ffff+u694cPH4a3t7fsw4EDB5AmTRpd9uejjz6yq981K76L5p9XRdN6le3ablIbm7NqO0bOXodo5a9rhVJFsNKvF7J6ZFZlLjo6GlMWb4bf4k0Qf+s/qVgKi8Z2RyY3V7N2Dv5zAW0G/oDHoS+QzcMNq/z6oJzSfsxx/uotfNNnCu48eIx3CufFumn9kDtn1pjH8T6Ny5cqkg9rpvZ
FLjPln4W9QMdhc7Hr0BmkT+eESd+1wdefVUJX33nYuu8E0qVNg7G9W6LD17XitZPUG5t3H8G3owLw4mUk8nlmxfrp36FYoTwmza38ZS8GT12B5xGvkCNLJgSM6oZy7xaBj99SrPvtoGT84ftFsWhMD+TKkcWkDbU31fin1rY1yy9Y/zuO3XyFtWvXWtNsitrS+3eFgFejRg28fPkSfn5+qF69eoryTGrj9v6uSKqfrEcCJJAwgcqVK+PgwYOxClWsWBGHDh2KdY8XJEACJBCXgN7fyaZNm4Y1a9agSpUqmDp1atzu6eZa7TuZ6PfeLasxf1RX3fSRjmqTwD9nr6GFz3TcDwmFRyZXBI7viarl3lHtrPFaR86smbHCrzfKlixs1s7jp8/QasAPOHTqMpyUdbchXZqgt9cXhvIRyhpCx2GzsX3/Sbg6p4P/yK6oX6Oc4XncE+PyGTOkV/5udEPdah/ELWa4FmtCo+asQ1T0a3xZsxzm+HbGqi37MWT6CkRGvcbn1cpg3shucHVxNtRJzknwnYdo1m8aLgTdlmsvUwa0RcsGpv8NKdZ5vAbPxJUb92TZPm2+gE+HRliz7S8MmrJcrplkdc8If98uqFXp/eS4Zairxj9DJZ6QQBIJnLl4A037TsG9R0+TaEG71fT+XjVy5Ehs3boV9erVw/fff69d0Il4pva9KhFzfEwCJJACBAICAtC3b1+EhYUZWs+QIQOuX78OT09Pwz2ekAAJOC6BNBTcUHCjhelt7xdLvQluTp4LQitF+CKELS7K4s34vq3QumEN1UO3/9g5tB08C0+evYC7mwvmfd8VdaqUNmnn5atIDJ6yAks3/ymfVy1bQlnM+hYemd0M5Y0FNNYW3IhGFm3YhaEzVsoFpLpVS8O71efoONwfdx8+QaE82bFych+8/VZegz/JOXkR/hK9xy3GT7+/+bGvmrJot1Tpb+ZMGeOZPX72KtoPmY3geyFwTpcWw7p/ja7NPoWTkxOu37qP9kPn4OSF63Ihrl2TjzHKuzkyKOOWnEONf8lpxxp1KbixBkXr26DgxvpMaZEESCBpBBYvXoyePXvi+fPn0kCmTJkwb948tGzZMmkGWYsESCDVEND7D0MU3FBwk2r+stqoo2HPw+WawO8HT8kWmtT+ENMHt0NG1wyqWgx5EoqWPj/gyJkrysY1oFuzzzDi228UwUhak3Z+3nUYvccuQpiybiDEI8sn9sKHpYsbyhoLaGwhuDl7+Sb+pwhgbitrQrmVDT2LxnTHlEWb8fuh03KD0mQfL7OCGIOTKk7EBqMBk5cjQlkXyq5sMFo5qXeszVcxpp6EhqHnmIXYtu8feUus28we1gke7m4QTEbMXI0lP+2Rm8dKv10Qi8f2QMG8OWOqJ/nTUv+S3AArkoARAQpujGBo7JSCG40NCN0hgVRMICQkBLlz55abXWMw1KlTBzt27Ii55CcJkICDE6DghhFuNDHFKbhJeBgio6IwdPoqLP5pt4yc4pnNHX4+bRLcMXX6QhDcMrqgcP5cBuNicar7yPlKhJg3iyHvFSsgdxkWjxPFRUTDWaEssIz4YQ2evYiQopIJ37VCm4Y1DbbEia0FNxeV3VQt+k9H0K0HMopM9fLvQix0iUg//6tXCVMHtYdzevNClkMnLyK7slutaIHcsfyOeyEWgmYt/xXTlm6RC0oioo6IntP+q0/iFsWDR0/QSYmys//Yebkw91WdjzBlgJdkHVN4z+EzMhLPoydhELvVJnzXGi2+qBbz2PBpC/8MxlPwhIKbFISfQNMU3CQAh49IgATsSuDJkyfImTMnXr16JdvNmDEjHj58CBcXF7v6wcZIgAT0R4CCG22Mmdp/vzPCjTbGzVG82PLnMXiPXoDQ5xFSbCLEMj4dG5mN7nJXEalcC76Hj8q8HQtBwLqd8J21Bq8io+WGpEn9WytRdeNHQRY/tnca4Y9L1+/I+g0/qYBZwzrGas/Wghthv6fS543KekhapzRorAiNditiG7HmUKSAJ9ZO6YtC+czvnhYbg27dD8GH7xWDU1qnWBziXoj1jN5jF8vy4ln96mUxV4moE1fUFB0VjclKBOVpS3+Rm6SEH4HjeuIdJZpxzHHv4WN0UCIX/33yklw/aapELfZTxEFxbdnCvxgf+EkCySVAwU1yCdquPgU3tmNLyyRAAuoJfPLJJ9izZ4+sKNa3NmzYgPr166s3xBokQAK6JEDBDQU3mpi4ahfskuu03iLciP5ev31fCVHsjxPnrsnui11TX3xcHl99+hHeLZofGZXQwSLyy5HTV/DL7qM48u9l/DC0QzxRzuFTl+SCh4iWI44CubMr0VnqoLYS2jersgvp8o27WPnLPvy446BMrSTKfK6ENp45rEOs6Dbivq0FN0Jo1EeJOrNm2wHRnFygEemwhIhl9ojOaKD039whREOj5qzH4h93ocL7xVDrw1Io885beCtfTiX0dEYlFHM0biphkg8qCz8bfjuAY/9ek0IeYU+EY56pLKAZR/MR90XUn1FKaq8AJWWSEP0IwdKScT3iLWyJtkXI53Hzf5SLd/lzZcOSsd+iTMm3hBl52MK/GNsp/UnBTUqPgOn2KbgxzSWhux06dICIxJHexXzqvYTq8xkJkIB5Aq/CXxgeipSo6TJQbGMAwhMSsAIB8Xes7Cf1cGzXVitY044JCm60MRZq//1OwY02xs1RvBDiE5FeaeGG32WKJZHiqdy7hdH6yxpK1JliECminipRfc9cvoGd+09hy95jaKCkd5o8sG0sBCJNVDffABklRjxwU9ZURNrx5vWr4a28OfDoaRi2K+ms56/dKSPLiDKm/m0v7ttacCPa+HHH3+ilRNkRUWdEVB6xNiIOL6XfYkNWQkIaIaLppAhf8igpvT+tXBrllfTXbysbr0R6bBHV58HjUBmld8ueY0qfj8s0UMK26O8iJSqNqXRbWxWu3mMWyQjKmTNmwLTB7dGoVkVRLdZxVIkiJCIEi+g8IkLwiB7foHPTOjJCcExBW/gXY5ufJJBcAhTcJJeg7epTcGM52xMnTjDShuW4WJIEkkTg+PHjWLt2LaKU37REOinx/yiREYEHCZBA8gk4Ozujd+/eyTdkQwsU3FBwY8PpZblptQt2lls2XVKPghvRk3NXgpUFjYU4cT7IdMeM7iYUxnjz7sPwUcIDP3z8zKiG6VORSuqHoe1RME/8sL+2FtwIj8TutR7fzzcs+Ih7IhTxqsm94Zk9i7g0ecQIWuas3m5YiDJZ0OimWKirW7UM/JSdbbmUhai4xzpF+DNg8jIZ9SexHOQimlCf8UvkDjRhR6SoChjZBTmyeUiztvAvrr8pdU3BTUqRT7hdCm4S5mPqaYsWLXAv2hV1W3c29Zj3SIAEkkHg1L5dWDNjDF4rItVv/QKQt8h/aRmSYZZVSYAE/p/A/k1rcenUUZw5+CY9rKOAoeBGGyOp9t/vFNxoY9wcyYtnYS8wUtkMs/yXP2V0lcT61lZJyR1XcCPqXL15Fz1GLZCppRKzISINT1VEO3WVDUlxD3sIboLvPkTz76bj3NVbhu
aF0GX+yG5mU4XHFBSClg6K6EVEBbL0KJLfU0bzrVa+ZLwql5RoxO0Ue+ev3ZYRd75tUQ+DuzYxm5IrcOMeDJuxCuGKWEoIogJGdUNVZY0k5rC2fzF2+UkC1iBAwY01KNrGBgU3lnOdO3cuevToISPGW16LJUmABNQQiIp+jSuXr8gqWbJ4KJGdc6ipzrIkQAJmCNy48wh/nbig/M77/zsOzJRL6dsU3FBwk9JzULavdsEuuU7rVXAj+i1yjfuv3oGlm/5QdiHFF8yInU7FCuZGp69rKzuzqsYL1RvD7uS5IIxfsBF/HD4to7DE3I/5zObhhnaNPkaPFnVl/u2Y+8af9hDciBDELfvPwD8XrsumRf96taqPIcpiTmIKYZF+yn/1bzil1BW7wMwd6dKmwXvFC6HTN7VlaOYMSvSguMdpxUa7obNleiuRcmpAh8bo1ebzBH0QYavbDZqFM4pQSoR97tL0U4zq9T+DaWv6ZzCqgRMKbjQwCCZcoODGBJREbrVq1QrRuYujXusuiZTkYxIgAbUEXkaEo0vlonDzyIbZu0+prc7yJEACiRDYqwhu/t6xCULc5kgHBTfaGE21/36n4EYb4+ZoXoiIuJt2HcHM5Vvx7+Wbhoi1xv0UG2VEGugeLeua3EQkyj4MeYpZK7ZJ8c7j0P8i8MXYEVFZRETgQZ0b410lyq2pwx6CG7Fpx3fWWsxbu8OwqahCqcJYPrGXkkrb3ZRbhnsXrt3CxICN2HvsLEKePjfcN3WSV9l81KxeFSUKTS2Tm5yE2Km3Eol4056jsrqIJuz/fWdk9chsypy8Fx7xEkOnrcIyRSAl1snLKxGJFo7pjny5ssvn1vTPrBN8QAJJJEDBTRLB2aEaBTeWQxaCm19XB2Dx+J6WV2JJEiAB1QS2bd2GW7du45tvvoa7R8LvZ6qNswIJpFICYpNErXYjlc0D4ZomQMENBTeamKBqF+yS67SeBTcxfX/+IkJJH3UJpy/dxCNFeCNEILlyZEG5kkUU8UiBBMMJx9gQn0GKKOTgyYu4dvMeXr2KQuZMrihZNB8qK/nNMyuplxzhENFmTl+8ruRcv4s79x/jRbiyq0tR7rhndkGhPJ4oryxS5c+dcopjrfundg5QcKOWmH3KU3CjnjMFN+qZsQYJqCEwqVszlK5Wh6I2NdBYlgQsJEDBjYWg7FxMCE/WrFmDKlWqYOrUqXZu3XrNqf33OwU31mNPS/EJREdF46yyyeXImcu4fS9EppnKqmwgeqdwXnyopJfO5GZZetgnoWE4oOycPHflFoSoxFnZiFMonycqf1DcrFgnvjfaviNESpeC7ih9vImbdx8hVEm9FanwEynKcytCm/eLF8S7xfObjVRj695p3T9b95/2tUmAghttjovwioIby8eGghvLWbEkCSSHwLWr13Dk8GF806xpcsywLgmQgBEBCm6MYGj5VO875A4r//P29vaWiA8cOKBoCJTwHzo81C7YJbeLjiC4SS4D1icBWxGg4MZWZJNnl4Ib9fwouFHPjDVIgARIgAS0QYCCG22MQ1wvKLjpGhcJr0mABEiABEhA0wQouNHu8FBwY/nYUHBjOSuWJAESIAES0BYBCm60NR5mvaHgxiwauz6g4MauuNkYCdiUAAU3NsWbZOMU3KhHR8GNemasQQIkQAIkoA0CFNxoYxziekHBDQU3cecEr0mABEiABLRNgIIb7Y4PBTeWjw0FN5azYkkSIAESIAFtEaDgRlvjYdYbCm7MorHrAwpu7IqbjZGATQlQcGNTvEk2TsGNenQU3KhnxhokQAIkQALaIEDBjTbGIa4XFNxQcBN3TvCaBEiABEhA2wQouNHu+FBwY/nYUHBjOSuWJAESIAES0BYBCm60NR5mvaHgxiwauz6g4MauuNkYCdiUAAU3NsWbZOMU3KhHR8GNemasQQIkQAIkoA0CFNxoYxziekHBDQU3cecEr0mABEiABLRNgIIb7Y4PBTeWjw0FN5azYkkSIAESIAFtEaDgRlvjYdYbCm7MorHrAwpu7IqbjZGATQlQcGNTvEk2TsGNenQU3KhnxhokQAIkQALaIEDBjTbGIa4XFNxQcBN3TvCaBEiABEhA2wQouNHu+FBwY/nYUHBjOSuWJAESIAES0BYBCm60NR5mvaHgxiwauz6g4MauuNkYCdiUAAU3NsWbZOMU3KhHR8GNemasQQIkQAIkoA0CFNxoYxziekHBDQU3cecEr0mABEiABLRNgIIb7Y4PBTeWjw0FN5azYkkSIAESIAFtEaDgRlvjYdYbCm7MorHrAwpu7IqbjZGATQlQcGNTvEk2TsGNenQU3KhnxhokYG0C0dHReHT3Fq6fP4PrF/5F8OVzCL5yEY/u3EJU5CukcXJCfa/uaNi5D5yUc3sfWvfP3jzYnnYIUHCjnbEw9oSCGwpujOcDz0mABEiABLRPgIIb7Y4RBTeWjw0FN5azYkkSsDUBsY504/YDnL92C7fuhSD0ebhsMkvmjHi7UB6ULvEWXF2cbe2GWfta98+s43zgsAQouNHJ0FJwo42BouBGG+NAL0jAGgQouLEGRevboOBGPVMKbtQzYw0SsDaBX5f6Y92MsXj9Otqk6TRp0qBBh15o0r1/ighutO6fSWi8mSoIUHCjzWGm4IaCG23OTHpFAiRAAiRgjgAFN+bIpPx9Cm4sHwMKbixnxZIkYCsCew6fQcDanThy5jIePQkz20xW94xoUucjeLeqh/y5c5gtZ+0HWvfP2v2lPf0QoOBGJ2NFwY02BoqCG22MA70gAWsQoODGGhStb4OCG/VMKbhRz4w1SMDaBLYEzsH6H8YpgpvXEOKajB5ZkfetorgXfANP7t+R91JScKN1/6w9HrSnHwIU3GhzrCi4oeBGmzOTXpkjEB0VDae09o+gZ84f3icBErA/AQpu7M/c0hYpuLGUFEDBjeWsWJIEbEFARI2Zsngz/BZvUta3LGuhUJ7smDqoLWpUKGVZhWSU0rp/yegaqzoAAQpudDKIFNxoY6AouNHGONALErAGAQpurEHR+jYouFHPlIIb9cxYgwSsTeDM3/tw5fRx5C/2DgqWKIWsnrkRGRmJWT6dcXLvzhQX3GjdP2uPB+3phwAFN9ocq9QsuOnXrx/qVi2tzYGhVyQQh0BkZDSePn2CiIgI5M6dW3nfiFOAlyRAAqmGwJWb9xAUfB8RryIdrs96/12EghvLpyQFN5azYkkSsAWBGEHLlCWbkT9XNlQrVxIflSmOt9/Ki2weboh4GYmLQbexbe9x/LLnKJ5HvJJuFCngiaXjvVGicF5buGWwqXX/DI7yJFUSoOBGJ8Ou9xfLw4cPw9vbW9I+cOCA/NFBJ+hjuWlvwU3VsiXw8tUrNKhZPpYfvCABEkg+gTVbD+DdcpWwfv365BvTiAW9f1cIjBTcqJ9MFNyoZ8YaJGAPAi9fvtSM4MZUf7Xunymfec/xCFBwo80xTa2Cm1OnTsER3qe1OavolTUJhIeH47fffpN/hNhGHJ06dYK916ys2SfaIgESSD4BV1dXiPUBRzv0/t1MwY3lM5KCG
8tZsSQJ2IrA/mPnEBUVhSrl3kG6tGnNNvPrn8fQZ/xihDx9LkXfPVvUw7DuX9s8jbrW/TMLjA8cngAFNzoZYr2/WFJwk7SJNqLH1/j3UjDSOnGbUtII2rfWs+fhOHnhOsq9WwQuzuns2zhbU00gKvo1Pm7YGj179lRdV6sV9P5dIbhScKN+dlFwo54Za6gj8OxJCG5cPIcnD+5B/Cs6q2ceJYrLu3DN6GaxochXL3Hz8kXcv3kNLyPC4eaeRYkGUwI58uS32IapgkI0cvPiWdy7GQQR7zZH3vyKb+/BOUMGU8Xtek/rghat+2fXwWJjKUaAgpsUQ59gw6lVcJMgFD4kAQ0QEN/d8+bNw+jRo3H//n3pUcGCBTFq1Ci0bt0aaRP4UUQD7tMFEiABEkgSAb2vdVFwY/mwU3BjOSuWJIGUJiBSmg6cshxLfv5DulJJiYSzfII3PNwtXyu0ZR+07p8t+07bKUOAgpuU4a66Vb2/WFJwo3rIWUFHBF4rP/AtWbIEvXv3RmhoKJo1a4Y1a9boqAd01VEI6P27QowDBTfqZyMFN+qZsUbiBCIjo3B01xb8vnYJLp88hqjIN2FiY2qmd3HFe5Vq4NPmHfFOhcpmd7CE3L+Lbcvn48CvP+LpQ0WwY3SkSeOEAu+8i7otu+Cjz75EuvTORk//OxVinRV+vvj38D5FSOOC5n19UeT9sti1NhA7Vi/C43u3/yusnHnkyKX41QGftuyIDIqfxsfTRw+xeOwA3LpyARlcM6J5nxF498OqxkVMnj+6dweLx/jg3o1rcHFzQ5sBY1GsdMIRCJMqaLl08qjs7/NnT+Cs+N+ka3+U+/gzk36Jm08e3sfSCUMRfOmcFER99FlDfNnR2yzPGENJ9S+mPj9JwBoEKLixBkXr26DgxvpMaZEEkkNAhM9ftWoVhg8fjqtXr0pT2bNnx9ChQ9G9e3e4uLgkxzzrkgAJkICmCeh9rYuCG8unFwU3lrNiyaQRCI94iQtXbyPo9n0ls0QksmTOiFLFCiB3zqyqDF5X6p+7EownoS/gkiE9iuTPheJv5YFz+qRvwhbpmv69fAPXb91X9pO9VlI6Zcd7bxdU7JteK1PlsI0Kr9qyD30nLIHY2FyqSD6smdoXuVSytJFr0qzW/bNl32nb/gQouLE/8yS1qPcXSwpukjTsrKQDAnfv3kWXLl2wadMm6W3x4sURGBiIypUr68B7uuhoBPT+XSHGg4Ib9bOSghv1zFgjYQKPlUg2KyYNVwQ3WxEdHZVgYWdXN/SYMBcfVK8dr9wlRaiz4Pu+uHPtUrxnxjecnNKi8hdfoeV338vIN8bPxLmxMCS9IrhpM2gsjv+5Ayf2bJeLEHHLi2ths0aTlmjxnW8s0Y34wWr11NHYsSpA1v346zbSXmI7wg9u34QFvn0Q+TICRd4rhz7Tl8A9W3ZTTRvuGfudRokM1KBDLzTp3t+sOCmmogjdu9F/CrYsmY3oqEjkKlQEvaYsRL4ib8cUMXyKNtbNGIudivDo9etoFChRCt6TF8AzX0FDGXMnSfXPnD3eJ4GkEKDgJinUbF+HghvbM2YLJGAJAfFjy9atWzF48GCcPHlSVnFThL/9+vVD//794e7ubokZliEBEiABXRPQ+1oXBTeWTz8KbixnxZLqCJxWshLMXb0d2/f/gyfPXsSqLLJLvKeIblp9WQP/+7wKMrqajpgsBDE/7jiI+Wt3KsKYm4hW3tOMj9w5sqD551XRtVlt5MjmYfzIcC5EPkOnr8ReJW2Tq7MzRno3Rfl3i2LRT7sRoNi9/eCxoaw4yZXdA10Ue52/+RSuLrGFNw9DnqLfxKU4H3QLbi4ZMLJnU1Qr/26s+qYu7t4PQe8Jgbh26x4yu7pgfN+WqPB+MVNFE723YvNefDcpUApuypZ8Cysn9YrV9yOnLmHYD6vx+NlzZFQiUQ/o2BD1qpc1a/f+oycYOHk5zl69JdNUNa5VEX3aNkiykCkx/8w6wgckkAQCFNwkAVpKVNH7iyUFNykxa9imrQls2LAB3bp1w4MHD2RTIjXRhAkTIBbAeJBAShDQ+3eFYEbBjfqZQ8GNemasYZ6ASB81b3hvnNr3uywkotAUerc0KnzyOfIVLaFEXHGBEORcPnUc/yhlnj66j25jZqF8rXqxjN66chEzfTrh9tU3YhsXt8yo+GkDGRVHpJO6e/0aDm7/CZdOHJFCEdFOza9aoWX/kfHSQRkLQ9KmS48c+QrgnlLfPXtOVKrXGMXLVIBrJncZfeavX9fj0j9HpU0RhafjiClKmUaxfDt7+C/88F1HvHj2VIpZvpu1IkGBimh/oSIc+nvbRuUf/E5o1KUvGnbuk6hwxthvNYIb4WzY08fwH9bLMA4f1PwMnUdNh1vm2Is2QqywbOJQvHzxHG4e2dB17EyUrvJxrP6au0iOf+Zs8j4JqCVAwY1aYvYpT8GNfTizFRJIiMDBgwcxcOBA/Pnnn7JYunTp0LVrVxnlJleuXAlV5TMSIAEScCgCel/rouDG8ulIwY3lrFjSMgJC4LJg/e+YvOhnhD6PSLCSsk8Knb6ujXGKACXuIcQtg6etwqbdh6W4JO5z42sR6WXW8I5KdJpCxrfluRDtdBw2WxH+nFQEN+kwsX8bbNt7Alv3HVfWseIVlzeEIKiNIgYa1at5LNGN2FDmO2st5q3dIeu2a1QTE79rDae0TqYN/f/dn38/hJ5jFiJc8aVCqcJYPrEXsmdVL+IWbPsrYptVv/4lLbf8oiqmDGyLdEYpTkVap4kLf8YPy7cgMuo1ihbMhcBxPVGicN54Pgo2o+ask+MlxEzvFy+IJeN6oGDenPHKWnLDEv8sscMyJGApAQpuLCWVwuX0/mJJwU0KTyA2b1UCISEh8Pb2xooVK6Td/PnzY/HixahTp45V26ExElBLQO/fFaK/FNyoHXWAghv1zFjDNAHxj3URWeWXRTNlZBsRvaZx1374rEUHk+mJhGDj2O5tyF2oKN56p5TBqLgfOG4g9itiEHFk8cyDziOno9RH1QxlxMnLiAhsnDcV25b5y0guor0uiqikQu36scsp9mb5dMbJvTsN998u9xE6fj8NufLHXsB4rohoFo/yweGdv8iyQqjSffycWFFuwkKfSMHN+SMHkFZJY9Vu6ERUb9jMYDvuSfDlC5jcszVC7gZLkU+f6YEoUqpM3GLxrpMraLlx6bzS7064G3RFWTBJp0TI8ZZin5hoPFfPnpJcHt66ASFEatKtP+q365GoECjG0eT6F2OHnySQHAIU3CSHnu3qUnBjO7a0TAKJETh79iyGDBmCjRs3Goq2aNECo0ePRtGiRQ33eEICJEACqYWA3te6KLixfKZScGM5K5ZMnIAQe8xatRUTA37GSyVtujhEFJpGn1RAxfeLIquHmxThnLt0EzsPnsKJc1fR6ovqmKyIRoyP5y8iMHDKcqzZ9pcUtggBTOUPSqBx7YoolC8nHj95hp0HTuKXPccQFv5SVhURcxYrYpG38nkam4Kx4CZ9OicpJrl64x5yZnPH13U+QsXSxeCeyQXXgh9gzdb9
OHL6ioykk1FJWzVjSHs0VsoYH/uOnkXbwbPwNCxcilnWTumboEBFtO+tiG1+UkQ3TorCaECHhujbroHF60gxbQu2q37dj+E/rJIMs2fJpKRh747KZd+JKWL4fPI0DN1GLZCMxM161cpg1rCO8Mgce9O66K+IbiMYZnXPCH/fLqhV6X2DHTUnavxTY5dlSSAhAhTcJERHQ8/0/mJJwY2GJhNdSRaB7du3o2PHjggODpZ22rZti+nTpyNLlizJssvKJGANAnr/rhAMKLhRPxMouFHPjDVMExBRaab2bosHwUFS4CEiuQiRR4zAw3St+Hcvnz6BGX3b4enD+xApoLwGjzcraHnxPAzzh/fCcUW4I473q9XGtxP94eKa0WDYWBgibmbNnR+9py6KJfIxFFZOzh8/jB/6dUDYk0dKNJxC8JmzErkKvGVcBL8u9cf6H8ZJYVHZT+qh29hZsUQ5xoW3r1iANdNHS1GQKQGPcVnjc2O/1Ua4ibHz92+bsHj0AISHhcooPp0UkZGIJvT00UPMHdwdZw/vl0VF9KAOvlPhmjH2gkWMHVOf1vDPlF3eIwE1BCi4UUPLfmUpuLEfa7ZEAjEEbty4ge+//x5LlixR3k+i5e26deti/PjxKFvWfOj9mPr8JAESIAFHJaD3tS4KbiyfmRTcWM6KJRMnIMQoHYbNQcjT51Jc8kWNshjfrxVyKaIbU8exM1dw485DNFKENMbHht8Oou/4xXihiFWc06VFPyXFkXeb+vHSHG3fdwJ9JyzB/ZBQmQ6pbcOaGKtEy3FOn85gzlhwE3OzUpnimDW0gyLeiS3OCVXSMPVRUj9t2n1EFhVClfkju8WKcvMkNEwR3MzG/uPnpW9TB7VV0mJVjTEd7/PclWA0/24agu+FIGfWzFjh1xtlSxaOVy7mhvDh6NmriIh4JW+FK5/Xgu9h9+Ez+PufCzJqTZbMrhjl3Vym43JyMh1d59yVm2g7ZDauKOKidGnToE+bL+DToZEhGs/Jc0HKZrjZkr8QIg3s2BjerT9PVAhkLf9i+stPEkgOAQpukkPPjnX1/mJJwY0dJwubsgmBZ8+ewcfHB/7+/tJ+zpw5MX/+fDRu3Ngm7dEoCSSFgN6/K0SfKbhRP/IU3KhnxhqmCfy+ZglWTPaVwpICJUqh74ylyOaZ23ThBO5uCZyjiFnGy7ROxZR0TyIiTCYP0wsawsyJvb8r4pEeiHj+DB45cqHvD8tiiWmMhSGifM0mreA1ZLxZIdATRegztXcbBP17ChmV9FVCnFNCiYhjfFy/cBZTe7XB43u3TbYZU1ZEzJn5XScpbBFRZLyGTkDNRs1jHif4aex3UgU3kcourB/nTMK2pXOlOCj3W8UUQdI87Nu8FjtWLpD38iqpvnpPWYhcBc0vkphy1Br+mbLLeySghgAFN2po2a8sBTf2Y82WSODhw4cyNfXMmTOVHzPepDr48MMP5b1PPvmEgEiABEgg1RPQ+1oXBTeWT2EKbixnxZIJEwiPeCnTJv28641YpXr5d7BwdDclqk3mhCvGeSoEHe2HzsEfR87KJ83qVpJpk1wyOMcp+eZyoZK+asTMNTKijoims0oRtLz3dkFD2biCm/y5smHZBO9YZQyFlRMhamkzaKYUDRXKmwPrp38XL2rOnFXbZSqmqOjXqF+9LPy/7xJLlGNsb/7a32QaKpHeyZSAx7isOD99IQhN+07Fg8fP4j6CRyZXNKhZHl3/9ylKFs0f73ncGz//flgRJC2WEXHc3VwwU4lyU79GOYh0XZ1952Hv0XOySqNaFTB9UDtkcnONayLetTX9i2ecN0hAJQEKblQCS6nien+xpOAmpWYO27UGgf3790NEsrl8+bI016RJEym88fSMrTq2Rlu0QQLJIaD37wrRdwpu1M8ACm7UM2ON+ASioqKw4Pu+OLBlg7ITJw3qNO+EFt+NSHQ3SVxLQiAyf9i3OPTbZmnnc68e+MZ7UIJ2Ht27g6nerXHz4lmZ4qnzqBmoVLehwbSxMMTJKa0iepmIj5u0MDyPexL+4jlm9u+EMwf+gEhT1WPCXHxQvXasYsLmghG9/t9PJ3zVYwC+7Ogdq4y4OH/sb8zo1xHPn4YgV6Ei6Ddzebw0VvEq/f8NY7+TKrgRpp49CcHcId/K/ohrTyVaz+MH9/BS6acQFAleZWuoT2tpLf+ETzxIIKkEKLhJKjnb1qPgxrZ8aZ0EBIGwsDDMmDEDkyZNwpMnTySUEiVKYNy4cRBrDuLdgQcJkAAJkIDyg/Pff+saAwU3lg8fBTeWs2LJhAmcvXwTzfpNw50Hj5HJNYMUodSt9kHClUw8Pf7vVbT0mS4FJyKSS+D4nqhiIm1STNV7Dx+jZf8Z+OfCdYjUU6O8/4cuzT6NeRwrpZS46fVlDfj5tDFEejEU/P+T+4+eSHsnzgdBtL9sYi9UKvN2rGL/XrqhRK2ZjttKX3Nl98DqyULkEzv9uqggxEMiyowQtogoMpN9vNCyQfVYtuJeJCRoEf0rWSQ/On1TG9/UrYwMzv9F8olrR1xHKuuOE+b9pKT52gYhDipeKLciguqO1Vv2Y966HfLeO4XzYumEniicP5cpE/HuWdO/eMZ5gwRUEqDgRiWwlCqu9xdLCm5Sauaw3eQQELvLRowYAT8/P2WX/mt4eHhA7Dpr3bo1F7+SA5Z1bUZA798VAgwFN+qnBwU36pmxRnwCIpKLEJacP/KXTCfVbthE1LAwkouxtbDQJ4qdDrhw9KC00364n9l0UjH1XirftzN9OuPUvt/l92vTXsNQv223mMcwFoakc84g0z9VqF3f8DzuiXF5ZyU1Vfdxc1C25n8LHDHlD2zdiIUj+yHyZQSKf1ARvactiRWJR6RzWD9zArYunSPfAxKLrBNjN+bT2I/kCG6EvesX/pUiovs3g2LMS74NO/XGl8oftWm/hBFr+mdwiickoJIABTcqgdmpOAU3dgLNZlIlgVevXmHhwoUQP8DeuXNHMsiXL59MJ9WuXTukS5fwjxWpEho7TQIkkKoJ6H2ti4Iby6cvBTeWs2LJhAn8uONveI9ZKCPNlFYizKxSRCie2c1HXjZnbd22v9Br3CKZOqlsybewclIv5MjmYa64TAs6eNpKLPpxtyzzv3qVMGt4Z0N54wg3QrAyZUBbtPrSvOjl+YsItFdEMrsOnZHCoYBR3VCnSmmDPXEibPYcvQAbdx2WqbMGd26CPm2/iFVGXBxUouV4/X+0nCIFPLF2St94aaziVhIinUOnLiH8VSReKyKZsBfhuBR0F7v+PoXTF28gWvnNTPSjce0PMVFJ1+XhnnCa85Anoej2fYDsj2ircL6cuPfwCcLCX0pB0ezhnfBZVcuFUdb2L27/eU0CaghQcKOGVgqW1fuLJQU3KTh52HSSCJw4cQJt2rTB6dOnZf06depg0aJFKFCgQJLssRIJ2IOA3r8rBCMKbtTPFApu1DNjjfgEjNMwpc/goohaZqN8rXrxCyZyJyl2RHSdgBG9cXDrT1Jw06BDL3z97QBDS8bCEEt8i1XexRU9xs81Kbh5eOcWpiiRdW5
dPg8Xt8zo6Tcf71WqYWg35P5dTOvthevnTiNDxkzoPn5OvEg5hsImToz9SK7gRpjftX4ZVkwajqjIN7mzS1evg65jfoBbZvOLPSbcMtyytn8GwzwhARUEKLhRAcuORSm4sSNsNpVqCEgh7/r1GDZsGC5evCj7nSVLFgwePBje3t5wdU08bH6qgcWOkgAJkIARAb2vdVFwYzSYiZxScJMIID62mMC0wF8wPuAnZfMU0KBGWQSM6Y50adNaXD+m4JRFmzFx0UZpp+EnFTBvZJdE7SzasAtCdCPEKDUqlMSGGf1jzMWKcOOiRISZ69sFDT4ub3ge98RYoOPm4qy03xWmIvUIgVGvsYsQoQhjPny/qLJ21AtZ3DMZzIn30LH+P2Lmyq2yL4lF1jFUNHPyUmlnxea9GDtvA548eyFFNz1bfo4hXZqYjdYTY+qMItRpp4iIrt26H3NLYZoG/ds1RN+2DRKtb6iUwEly/EvALB+RQIIEKLhJEI92Hur9xZKCG+3MJXqSMIHIyEhMnDhR7i4T52LRS0S46d69e4LpMBK2yqckYB8Cev+uEJQouFE/Vyi4Uc+MNeITEEKZKb1a4/rZ0xCiFiEuKfdx3fgFE7kTy44iduk+bnaiduKms/qivTe+6TnQ0JKxMMSaghux4LB66mjsWBUgI9jUatYerXxGGqLFHPn9V8wb3guvwl+gyHvl0Gf6Erhny27wK7ETY7+TK7h59uQx/Id+i9N/7TE0m7doCXj7BSDPW0UN99ScWNM/Ne2yLAkYE6DgxpiGds4puNHOWNATxyCwc+dODBo0CEePHpUdcnFxQe/evTFw4EBkzZrVMTrJXpAACZCAjQjofa2LghvLJwYFN5azYknzBMRaz6g56zFn9fb/xCUDlLRNTk7mK5l4EtdO24Y1MEmkf0rEjhC/9ByzAK8io1GhVBFsnT/UYN1YQOOqCG78FQFN/RrlDM/jnhiXz5ghPeYrEW5MCW6C7z6UaaXOXb2FzBkzYNG4b/FxxVIGc3eVdFMtfWbgpJLqSjyfPzJ+pBxDYQtPBB//NTsweu46GQEof65sSiShPninSL5ELQRu3KOIklZIRqLwZ0rUnjkjOsEjc8IRchI1bFQgOf4ZmeEpCVhMgIIbi1GlbEG9v1hScJOy84etW0bg/Pnz8PLywqFDh2SFypUrIzAwEMWLF7fMAEuRQAoT0Pt3hcBHwY36SUTBjXpmrBGfQNjTx29SQR37W6Yq6uA7BdUafBO/YCJ3hJ3pfdrh4gkllG3adOj4/VRU/eLrBGtFvnqJWQO64cQf22WEm6+/HYQGHXoa6hgLQ6wpuBENnD38F374riNeKCm18hQuhu9mrUCOPPkRGRmFxaO/w/7N6xSfnNCoS1807Nwn0YUVg9PKibHfyRHcCF9+muuHrYFzlPDEUcZNoOKnDdDBdypcM6pflLCWf7Ec4gUJqCRAwY1KYHYqTsGNnUCzGYcncOTIERnBRghuxCFSQHbs2FGmrhZppHiQAAmQAAkkTkDva10U3CQ+xjElKLiJIcHP5BAQQouRs9Zh7trfpODGUqFM3Dbj2mnXqCYm9m+d6LrQz78fQo9RC2Q6K3sJboSvvrPWYt7aHbLPHb/6BOP6tDREi/llz1F8OyoAzyNeKSKgwlg+sReyZ3WP22XV19eVKDVN+03FlRv3ZJSa6YM74H+fV0nQzuOnz9DNNwC/H3qTWUIUfqdwXiwZ/y2KFsidYF21D5Pin9o2WJ4EYghQcBNDQuOfen+xpOBG4xMslbsnXkhmzZold5eFh4cjffr0GDVqFHx8fAy73FM5InZfJwT0/l0hMFNwo36yUXCjnhlrxCcgRB3zh32LQ79tlqKX+u16yrROie3ciWtJiGdmD+qO47u3STsNO/eVYpWE7BinoTIl0jEWhlhbcBMW+kQKbs4fOYB0zhnQfrifFAjdunoJU3t54UFwENyze8roNkVKlYnb3QSvjf1OjuDm8M4tWDjyO4SHhcJVSR9Vt1Vn/PHTKoTcDUbadOnRpFt/1G/XI9FFn7jOWsu/uHZ5TQJqCFBwo4aW/cpScGM/1mzJMQlcuHBBpo5at26doYPffPMNxowZgxIlShju8YQESIAESCBxAnpf66LgJvExjilBwU0MCX4mh4D4rWfakl+UVFA/S/FJo1oV4P994qmg4rYp7ExetAmTl2yWdprU/hD+vp0NIpa45WOul236Az5+yxAV/RpVy5XAxpn/pUw3jlhjzQg3ou19R8+i7eBZeBoWjrcL5cGaqX2QP3cOJfpMFPqMW4w12w7AKU0aDOigpG5qp6RuSiRST0x/Evp8FvYCbQbNwr5j55Q1QGBI56/Qp+0XZqsIXybM34hZSmorwcf4EOM0fVA7ZHKzXppVtf4Z+8NzElBLgIIbtcRSqLzeXywpuEmhicNmEyUQFBSE9u3bY/fu3bJs6dKlsWzZMohPHiSgNwJ6/64QvCm4UT/rKLhRz4w1TBP4dak/1s0YqywkRKNo6fLoOyMQmTzUpTkQCxI/z5+GTQHTFDuvUapSDfScvCDBCCznlag6M/p1wHMlOo5blmzoN2Mpir5f1uCksTDE2oIb0Yjo9/ofxsnoMeVrf4GuY2Zi789rsMJvOKKjIvFBzbpKiq3ZyKCkyFJzGPudVMFN8JULUhB0N+iKjBjUoIO3FDDt/2Udlk0cJtNduXlkQ9exM1G6ysdq3LNaBB5VjbIwCcQhQMFNHCAauaTgRiMDQTd0R+D27dsQP6wuWLAAImWmOGrVqoUJEyagYsWKuusPHSYBEiABLRDQ+1rX999/j23btqFevXoQ53o9PvroI5u7TsGNzRGnmgbWbP0LfcYvkqmOkhPRZcXmvejvFyjtfFS6GFYokWE83M1HGBZrYsbprIRIZ/6orgbuthTcPAkNUwQ3s7H/+HlkSJ8O0xTxStN6lXEx6DZa9J+OoFsP4JnNHcsn9ULZkoUNPiXnRAhavBSRz96jbwQ3gzs1kWIeczY37z6M3or4J/R5BDwyKWno//cZlm3+E8H3QpA+nRMGdmwM79afW0UMJHxQ6585v3mfBCwhQMGNJZQ0UEbvL5bGgpuDBw9qgGjSXLDHi2XSPGMttQTEj4BLliyRedNDQ0Pll7jIqe7r6wtnZ2e15lieBDRBQO/fFQJijOBm8uTJqFatmia4qnXC3t8VFNyoHSGWN0fg6tlTMh3Uk/t3ZLSX1gPG4OOvWporbvb+WSVajEzTpESPcc3krohBZuGD6rVNlhcRcZZNGKZEbFkhn5eoUBm9piyEmxLJJeYwFq7YQnBz4+I5TPFujcf3bsMjZ270nDQPPyuCodN/7ZERZLyGTkDNRs1j3LH409jvpAhuRPSdgBF9lFRbv8k2S1evg66jZ8DNPYsUy6yeOgq71wVKgVTBd95DT78AeOYraDf/LG6IBUkgAQIU3CQAJwUfxQhuqlatiilTpqSgJ8lr2t7vZMnzlrX1TODx48eYNGkSpk+fjhcvXsiulC1bVgptPv30Uxn1T8/9o+8kQAIkkJIE9L7WRcGN5bOHghvLWbFkwg
T+OXsNLXym435IqBR2BI7vqUSbeSfhSiaeHvznAtoM/AGPQ18gZ9bMysas3gmKVUS6pFYDfsChU5dlNJkhXZqgt9d/EV9sKbgR7s9ZtV0R/KyT0WO+rFkOc5SIPKu27MeQ6SukaOjzamUwb2Q3uLpY5/ev4DsP0azfNFxQRD3p0qbBlAFt0bJBdRMkgfNXbyninJmG9FN92nwBnw6NlMg7f2HQlOUy3VVW94xKFKEuqFXpfZM21N5U459a2yxPAnEJUHATl4hGr/X+YknBjUYnVip16+7du+jSpQs2bdokCRQvXhyBgYGoXLlyKiXCbjsKAb1/V4hxoOBG/Wyk4EY9M9YwTUCklVo52VcRcSyR0WlEKqW2g8ejfK16pisod6+dO6NEr8mIXAX/2x3z4nkY5g31VoQi22W9Am+/q0SImYu8hYvFsiN2/vyxcRVWTRmJly/CkDa9M1oPHINPvmoVq5yxcMUWghthf8GIXjKdlpNTWnxYtyFOHfgDYY8fIVehIug3czly5S8UyydLLoz9Viu4EbviNwVMx+aFP8goO8KPXpMXIl/Rtw1NP35wD7MHdsXF44fkD3mV6n+FtkMmwMU1o6FMQifJ8S8hu3xGAmoIUHCjhpb9ylJwYz/WbEnfBIS4Zvbs2Rg3bhxCQkJkZ4oWLSpTRzVr1sxqu3P1TYnekwAJkEDyCOh9rYuCG8vHn4Iby1mxZMIEwp6Ho+Nwf/x+8JQsKCLNTB/cDhldMyRcMc7TkCehaOnzA46cuSJTJnVr9hlGfPuNIi5JG6fkm8ufdykRXMYuQlj4SwjxyHIlIs6HpYsbytpacHP28k38TxHA3H7wGLlzZMGiMd0xZdFm/H7otIwgM9nHy6wgxuCkipOVv+zFgMnLEfEqEtmzZMLKSb1RrlSReBZE9J2eYxZi275/5LO6VUtj9rBOMlqQYDJi5mos+WkPopVN8qXfLojFY3ugYN6c8eyovWGpf2rtsjwJmCJAwY0pKhq8p/cXSwpuNDipUqlLGzZsQLdu3fDgwQNJoGfPnnLXmZub+VCAqRQVu61DAnr/rhDIKbhRP/EouFHPjDXME7gffAOzB3XFtTNv/hEsBC7la9dHpbqNUaB4SWRwdUXI/Xu4fPIojuzagsunjqOT79R4opyL/xzBrAFdIaLliCN73gL4tHlHlKlWS0ZouXP9CvZuWoODW3/Cq4hwWabsx3XRaeS0WNFtxANjYYgtBDeijQNbN2LhyH6IfBkhxSsiEp44ajZpBa8h45HWzGKKLKT859mTx9ixaiFCH7/5sU3cfx0dhdMH/8D9m0GyWOFSZVG41AfyPOY/+RQRUk0lilA6RWxkfBzbsx0Bvn3xQoly4+KWGR1G+OHDT780LiLPL508pnDuIqPzCMFSU+8h+Kxlx3g/8Fnbv3iO8AYJJJEABTdJBGfjahTc2BgwzeueQGRkpNy0I35EvXnzpuxPrly5ZMTcTp06IX369LrvIztAAiRAAlohoPe1LgpuLJ9JFNxYzoolEyew5c9j8B69QKYvEumKhFjGp2Mjs9Fd7ioilWvB9/BRmf82OolWAtbthO+sNXgVGQ13NxdM6t8aX38Wf+P2mYs30GmEPy5df7MO1vCTCpg1rGOs9mwtuBH2eyp93qgIf9I6pUFjRWi0WxHbPHoShiIFPLF2Sl8UyudpFt71W/dx634IPnyvmJLW3MlsOfFgz+EzirhosSwvrutXL4u5SkSduKKm6KhoTF68CdOW/iKj7Ag/Asf1xDtF8olq8rj38DE6DJuLv09eksKmpp9Vgp8iDopryxb+xfjATxJILgEKbpJL0E719f5iScGNnSYKmzFLQOw28/b2xooVb1JW5M+fH4sXL0adOnXM1uEDEtAbAb1/VwjeFNyon3UU3KhnxhoJEwi+fAHzfXsh6N83O4ESKp2QAObwzi0IHD8Yz0IeJmRCPhOppDr5TkPOfAXilbWH4ObhnVsyrdSty+cN7WfImEmJzDPHbDosQ0HlRAiV/L5tgXvXrxrfTvS8VOWa8J68IFZUmtvXLsuUXLevXlSEM2lRr003fPXtQKRLZ3oH1a71y2RkIiEWcs+eU0YTKqnwND6s6Z+xXZ6TQHIJUHCTXIK2qU/BjW240qr+CQhB7saNGzF06FCcPXtWdsjd3R0DBgxAnz59wI08+h9j9oAESEB7BPS+1kXBjeVzioIby1mxZOIEhPhEpFdauOF3mWLJKU0alHu3MFp/WUOJOlNMpoh6+uwFzly+gZ37T2HL3mNoUKMcJg9sG8u4SBPVzTdARokRD9yUdExN61VG8/rV8FbeHHj0NAzb953A/LU7ZWQZUSZ/rmxYMvZblCn5lrg0HLYW3IiGftzxN3opUXZE1Bmly0oE6zfNeyn99vNpk6CQRohoOinClzw5s+LTyqVR/v2ieLtQHuRQoteIqD4PHofi5IXr2LLnmNLn4zINlLAu+rtIiUpTtuR/0a/ftApsVbh6j1mEJwrrzBkzYNrg9mhUq2LMY8PnUSWKUPshsyVDZ2X9a0SPb9C5aZ1YG8ps4Z/BAZ6QQDIJUHCTTID2qq73F0sKbuw1U9iOKQLbt29Hx44dERwcLB97eXlhxowZyJIli6nivEcCuiWg9+8KAZ6CG/XTj4Ib9cxYI3ECz56EYNuy+UrKp5UIffQmKpxxLZEiKVehovj0f+1RtWGzWIIR43JXz57Cj3P98O/BPxEV+cr4kTx3y5JNSSHVGp97dZWRb+IVUG7YQ3Aj0lutnjpaiVITINNpCT+KvFcOfaYvgXu27KbcinXPWoIWkY5roW8/HPn9F2lfCHK6j5uNTB5ZY7VnfPEyIgIrJg3Hn8pYiR8CRSSdnn7zkT13XkMxa/lnMMgTErASAQpurATSymYouLEyUJpzCAJ//PEHBg0ahIMHD8r+ODs7Q0TMHTJkCLJnT/xdwSEgsBMkQAIkkAIE9L7WRcGN5ZOGghvLWbGkZQSehb3AyNnrsPyXP2V0lcRqtW1YI57gRtQRP6T3GLVAppZKzIZnNndMVUQ7davFjnAs6tlDcBN89yGafzcd567eMrgqhC7zR3ZDnSqlDfdMnQhBSwdF9BL6PMLUY5P3iuT3xJQBXqhWvmS855eCbqOdYu/8tdsy4s63LephcNcmZlNyBW7cg2EzViFcEUvlzJoZAaO6oWq5dwx2re2fwTBPSMAKBCi4sQJEe5jQ+4slBTf2mCVsIy6BZ8+ewcfHB/7+/vJRzpw5MX/+fDRu3DhuUV6TgEMQ0Pt3hRgECm7UT0UKbtQzYw3LCYS/eC7TRwWdP4MwJW1S2nTpkCWHJwq/VxYF33430VRLMS3dVdIqXTj2txIJ5joiX71CxkyZka/YOyhR7kPl3D2mGD9JgARSGQEKbrQ54BTcaHNc6FXKEPjnn38wePBgbN26VTrg5OQEsYln5MiRKFiwYMo4xVZJgARIIBUR0PtaFwU3lk9WCm4sZ8WSlhOIjIrCpl1HMHP5Vvx7+SaiY0K+GJnI6p4RX9X5CD1a1kXBPDmNnvx3+jDkKWat2CbFO49DX/z34P/PR
FSW2pXex6DOjfFusfjRm0UxewhuxIYy31lrMW/tDkN0mwqlCmP5xF7InjXh9bcL125hYsBG7D12FiFPn8fro/GNvEoUnGb1qihRaGrBM3v8je1C7NR73GJs2nNUVqv1YSn4f98ZWT0yG5uJdR4e8RJDp63CMkUgJYapvBKRaOGY7shMC9G/AABAAElEQVSX64243Zr+xWqYFyRgBQIU3FgBoj1M6P3FkoIbe8wStmFMYP/+/Wjbti0uX74sbzdp0kQKbzw9zeeoNK7PcxLQIwG9f1cI5hTcqJ95FNyoZ8YaJEACJEAC2iBAwY02xiGuFxTcxCXC69RI4MqVKxgxYgRWrnwTQU4waNSoEcaOHYtSpUqlRiTsMwmQAAmkCAG9r3VRcGP5tKHgxnJWLKmeQHRUNM5eCVai1FzG7XshMs1UVg83vFM4Lz58vxgyublaZPRJaBgOnLiAc1duQYhKnJ3ToVA+T1T+oLhZsY5FhjVUSIiULgXdUfp4EzfvPkKokg4qUuGXUUmnlVsR2rxfvCDeLZ7fbKQaW3dF6/7Zuv+0r00CFNxoc1zieaX3F0sKbuINKW/YiECEklZBLIr5+fnJtAoeHh6YOXMmWrdureSsVJJW8iABByag9+8KMTQU3KifoBTcqGfGGiRAAiRAAtogQMGNNsYhrhcU3MQlwuvURODu3bsYM2YM5s2bh1dKVD5xVK9eHRMmTECVKlVSEwr2lQRIgAQ0QUDva10U3Fg+jSi4sZwVS5IACZAACWiLAAU32hoPs97o/cWSghuzQ8sHViRw4sQJtGnTBqdPn5ZW69Spg0WLFqFAAdMh/KzYNE2RgCYI6P27QkCk4Eb9VKLgRj0z1iABEiABEtAGAQputDEOcb2g4CYuEV6nBgJPnz7FlClT5J+wsDDZ5ffffx/jx49H/fr1uYEnNUwC9pEESECTBPS+1kXBjeXTioIby1mxJAmQAAmQgLYIUHCjrfEw643eXywpuDE7tHxgBQKRkZGYOHEixD9gxLmrq6uMcNO9e3eI/Oo8SCC1END7d4UYJwpu1M9WCm7UM2MNEiABEiABbRCg4EYb4xDXCwpu4hLhtSMTEFFy/f39ZVSbBw8eyK4WKlQIo0ePRsuWLZE2bVpH7j77RgIkQAKaJ6D3tS4KbiyfYhTcWM6KJUmABEiABLRFgIIbbY2HWW/0/mJJwY3ZoeWDZBI4f/48vLy8cOjQIWmpcuXKCAwMRPHixZNpmdVJQH8E9P5dIYhTcKN+3lFwo54Za5AACZAACWiDAAU32hiHuF5QcBOXCK8dkUBUVBRWrlwpU1Jfu3ZNdjFHjhwYNmwYunXrhgwZMjhit9knEiABEtAdAb2vdVFwY/mUo+DGclYsSQIkQAIkoC0CFNxoazzMeqP3F0sKbswOLR8kkUB0dDRmzZqFgQMHIjw8HOnTp8eoUaPg4+PDHWhJZMpq+ieg9+8KMQIU3KifhxTcqGfGGiRAAiRAAtogQMGNNsYhrhcU3MQlwmtHIvD69Wts2bIFQ4YMwalTp2TXMmXKhO+++07+yZw5syN1l30hARIgAd0T0PtaFwU3lk9BCm4sZ8WSJEACJEAC2iJAwY22xsOsN3p/saTgxuzQ8kESCAQFBaF9+/bYvXu3rF26dGksW7YM4pMHCaRmAnr/rhBjR8GN+hlMwY16ZqxBAiRAAiSgDQIU3GhjHOJ6QcFNXCK8dhQCf/31l9y0s2/fPtklsXFHRLMRUW08PT0dpZvsBwmQAAk4FAG9r3VRcGP5dKTgxnJWLEkCJEACJKAtAhTcaGs8zHqj9xdLCm7MDi0fqCAgdqItWbIEvXv3RmhoKJycnDBo0CD4+vrC2dlZhSUWJQHHJKD37woxKhTcqJ+bFNyoZ8YaJEACJEAC2iBAwY02xiGuFxTcxCXCa70TOHPmjIxos2nTJtmVNGnSQLxDjxw5EkWKFNF79+g/CZAACTg0Ab2vdVFwY/n0pODGclYsSQIkQAIkoC0CFNxoazzMeqP3F8tDhw6hV69esn8HDx4020+tP/joo4+07qLD+nf37l106dIFMQtkxYsXR2BgICpXruywfWbHSEAtAb1/V4j+Vq9eHa9evcLkyZNRrVo1tQg0Ud7e3xUU3Ghi2OkECZAACZBAEghQcJMEaHaoMnXqVKxduxZVq1bFlClT7NCibZqw9zuZbXpBq8khcP36dblBZ+nSpRBpqcXx+eefY/z48ShTpkxyTLMuCZAACZCAnQjofa2LghvLJwoFN5azYkkSIAESIAFtEaDgRlvjYdYbvb9YUnBjdmj5wAICGzZskGGeHzx4IEv37NkTEyZMgJubmwW1WYQEUg8BvX9XiJGi4Eb9fG3RogX+PnUetZp6qa/MGiSQTALRr6Nx99oV/HtkPyrVbQQ39yzJtMjqJEACqYnAgV83AEqkiTMH/3Sobuv9nYyCG4eajqmyM2LtQIhqZs+ejYiICMmgUqVKch2hZs2aqZIJO00CJEACeiWg9/eqGMGNEHyKKO16PewhZKbgRq+zw/H8Dn0aigsXzsPd3QPF3y7ueB1kj0iABKxOgIIbqyO1jUG9v1hScGObeeHoVkNCQuDt7Y0VK1bIrubPnx+LFy9GnTp1HL3r7B8JJImA3r8rRKcpuFE/9JMmTcLOnTuRLl069ZVZgwSSSEBEogoODsaNGzfw7NkzaaVYsWIQEeh4kAAJkIClBCIjIyF+BB81apSlVXRRTu/vZBTc6GKa0UkTBMLCwiBSovn5+eHp06eyRMmSJTFu3Dg0atRI0felMVGLt0iABEiABLRMQO/vVRTcWD67hOCmR48eyOeZ1fJKLEkCViQQGRWFVy9fIer/IyOKV0e3jBmt2AJNkQAJOCqBByGhiHgVidevX2u6i2kUB7XtoY3x6f3FkoIbG08QBzS/fft2dOzYUf6YJ7rn5eWFGTNmIEsW7px3wOFml6xEQO/fFQIDBTdWmgw0QwI2InD8+HGIRTAhhn3+/LlsxcnJCQ0aNEDfvn3x8ccf26hlmiUBEiAB/RDQ+zsZBTf6mWv09A0BIQQOCAiQ4j2RjlocYsPOyJEj5VoChelvOPG/JEACJKBHAnp/r6LgxvJZJ77Dg4KCLK/AkiRgBQJi0/fmzZvx008/4datWwaLhQsXxldffSX/8F3SgIUnJEACCRAQa+QVKlRIoETKP6Lg5u+/U34UkuEBBTfJgJfKqopd8j4+PvD395c9z5kzJ+bPn4/GjRunMhLsLgmoJ6D3RQjRYwpu1I87a5CArQmEh4dj7dq1Umhz8OBBQ3O5cuVCp06d0KVLFxQsWNBwnyckQAIkkNoJ6P2djIKb1D6D9dP/aGX3sXhHGTZsGC5fviwdz5YtG4YMGSJ3yLu6uuqnM/SUBEiABEjAJAG9v1dRcGNyWHmTBFKUgIjvcODAAcyZMwfr1q3Dy5cvpT/p06fH119/je7du8s1akZHTNFhYuMkQAI2IEDBDQU3NphW6k3aI1epeq8cp8b+/fvRtm1bw0JZkyZNpPDG09PTcTrJnpCADQnofRFCoKHgxoYThKZJQCUB
8cOVEMCKdI4PHz401K5Zs6ZcfBDf087Ozob7PCEBEiABEnhDQO/vZBTccCZrnYD4kWTHjh0YNGgQRPQ9cQhxjYi2JzbwMDKu1keQ/pEACZCA5QT0/l5FwY3lY82SJGBrAqGhoTJis4jcfPLkSUNzYhNZ165dZcYFsbmMBwmQAAk4KgEKbii40cTcpuDGNsMQERGBESNGyDzrYuHMw8MDM2fOROvWrZlj3TbIadVBCeh9EUIMCwU3Djo52S3dEIiMjMSWLVtkNBuR3jHmcHd3lykZunXrhlKlSsXc5icJkAAJkIAJAnp/J6PgxsSg8pZmCIgIykJos3v3bulT2rRp0blzZ7mmkCdPHs34SUdIgARIgASsQ0Dv71UU3FhnHtAKCSSHwOnTp+U617JlyyBEN+IQ0Wvq1asnN5TVr18f4p2SBwmQAAk4OgEKbii40cQcp+DG+sNw4sQJtGnTBuKlRxx16tTBokWLUKBAAes3Rosk4OAE9L4IIYaHghsHn6TsnmYJ3LlzBwsWLMC8efNw8+ZNg58ffPCBTMnQokULZMqUyXCfJyRAAiRAAuYJ6P2djIIb82PLJylH4Pz58xg6dCg2bNhgcKJZs2YYM2YMihcvbrjHExIgARIgAccioPf3KgpuHGs+sjf6ISA2ef/4449SaLN3716D4zly5ECHDh1kRJsiRYoY7vOEBEiABFIDAQpuKLjRxDyn4MZ6wyB20E+YMAEjR46EOBfhn/38/KSi2MnJyXoN0RIJpCICel+EEENFwU0qmrDsaooTEFHl9uzZIxcffvrpJ/l9LJzKkCEDxA9YPXr0gHj3Yc7qFB8qOkACJKAzAnp/J6PgRmcTzsHdDQ4OlusGYmNOVFSU7O2nn36K8ePHo3z58g7ee3aPBEiABEhA7+9VFNxwDpOAfQlcu3YN8+fPl5vK7t+/b2i8SpUq8renb775Bi4uLob7PCEBEiCB1ESAghsKbjQx3ym4sc4wiJ1pXl5eEKGgxVG5cmUEBgZyV5p18NJKKiag90UIMXQU3KTiCcyu243A48ePsXTpUvj7++Ps2bOGdsXOnu7du6Ndu3YQO354kAAJkAAJJI2A3t/JKLhJ2rizlnUJhISEYOLEiZgxYwbCw8Ol8QoVKsiNO7Vr17ZuY7RGAiRAAiSgWQJ6f6+i4EazU4uOORABIcoWadHnzJmDX3/9FWKDmTjc3NzQunVrudZVpkwZB+oxu0ICJEACSSNAwQ0FN0mbOVauRcFN8oBGR0dj1qxZGDhwoFwwS58+PUaNGgUfHx/myEweWtYmAUlA74sQohMU3HAyk4DtCBw7dkxGs1m5ciWeP38uGxJR5b788ku5+CB2izPKnO340zIJkEDqIaD3dzIKblLPXNViT1+8eIGZM2fKCDZCJCwOkTJq7NixEDuSGXlPi6NGn0iABEjAdgT0/l5FwY3t5gYtk8C9e/cgoiCK9Ogisk3MUapUKRm1WYht3N3dY27zkwRIgARSPQEKbii40cRfAgpukj4MQUFBaN++PXbv3i2NlC5dWu6up7I46UxZkwTiEtD7IoToDwU3cUeV1ySQPALiR6u1a9dKoY3x/yNy586NTp06oXPnzihYsGDyGmFtEiABEiCBWASM/38b64FOLii40clAOZibItX04sWLIX6YvHXrluxdnjx54Ovriw4dOkBs2OFBAiRAAiSQ+gjo/b2KgpvUN2fZY9sSENFr9u/fL9e51q9fj5cvX8oGxbuiEGeLyM3VqlWjSNu2w0DrJEACOiVAwQ0FN5qYuhTcqB8G8QK0ZMkS9O7dG6GhoXLnvIhwIxbNMmTIoN4ga5AACZgloPdFCNExCm7MDi8fkIAqApcuXZIpo8QPV48ePTLU/fjjj+Uun8aNG/OHKwMVnpAACZCAdQno/Z2MghvrzgdaS5iAWDP48ccfMXToUIj00+Lw8PCQkXHFOkLGjBkTNsCnJEACJEACDk1A7+9VFNw49PRk5+xIQPy2tHz5cim0OXXqlKHlQoUKoWvXrlKgnStXLsN9npAACZAACcQnQMENBTfxZ0UK3KHgRh30u3fvokuXLti0aZOsKMJABwYGonLlyuoMsTQJkIBFBPS+CCE6ScGNRUPNQiRgkoDYGf7LL7/InNU7duwwlBHhc9u2bYtu3brh3XffNdznCQmQAAmQgG0I6P2djIIb28wLWo1PQETAHTRoEA4dOiQfik05vXr1kveyZcsWvwLvkAAJkAAJpDoCen+vouAm1U1ZdtjKBIS4Zu7cuVi2bBmePXsmrYsUo59//rmMZiM+06ZNa+VWaY4ESIAEHJMABTcU3GhiZlNwY/kwbNiwQf6w9+DBA1mpZ8+emDBhAtzc3Cw3wpIkQAKqCOh9EUJ0loIbVUPOwiQgCdy+fRsBAQHyz82bNw1UypYtK6PZtGjRgt+/Bio8IQESIAHbE9D7OxkFN7afI6m9hePHj2Pw4MHYvn27ROHk5CRTUIsfJfPnz5/a8bD/JEACJEACRgT0/l5FwY3RYPKUBCwkEBERAfH7khDa7Nu3z1ArR44c6Nixo4xoU7hwYcN9npAACZAACVhGgIIbCm4smyk2LkXBTeKAQ0JC4O3tjRUrVsjCYrFMpLOoU6dO4pVZggRIIFkE9L4IITpPwU2ypgArpyICIv3Cnj17ZDSbjRs3QkS3EYfYGd68eXO5y+fDDz9kzupUNCfYVRIgAe0Q0Ps7GQU32plLjubJ5cuXMWzYMKxevdrQtSZNmmDs2LEoWbKk4R5PSIAESIAESCCGgN7fqyi4iRlJfpJA4gSuXbuGefPmYeHChbh//76hQtWqVeWGsq+//lquexke8IQESIAESEAVAQpuKLhRNWFsVZiCm4TJit1pQmEcHBwsC3p5eWHGjBnIkiVLwhX5lARIwCoE9L4IISBQcGOVqUAjDkzg8ePHMj2jv78/zp07Z+hpsWLFZGS5du3aIXv27Ib7PCEBEiABErA/Ab2/k1FwY/854+gt3rlzB6NHj8b8+fMNIuGaNWvKKLiVKlVy9O6zfyRAAiRAAskgoPf3KgpukjH4rJoqCERFRWHr1q0ymo34FBvMxJEpUya0adNGrnWVLl06VbBgJ0mABEjA1gQouKHgxtZzzCL7FNyYxiRyZ/r4+ED8+CeOnDlzyoW0xo0bm67AuyRAAjYhoPdFCAGFghubTA0adQACR48elYsPK1euxIsXL2SPRPqFhg0bymg2IpKcuOZBAiRAAiSQ8gT0/k5GwU3KzyFH8eDp06fw8/ODmFPPnz+X3SpTpgzGjx+PevXqMRKfoww0+0ECJEACNiSg9/cqCm5sODloWtcE7t27JyPZiIg2QUFBhr689957MppN69atkTlzZsN9npAACZAACSSfAAU3FNwkfxZZwQIFN/Eh7t+/H23btoUIDS0OEQ5aCG88PT3jF+YdEiABmxLQ+yKEgEPBjU2nCI3rjIAQ1qxZs0amjTp8+LDB+9y5c6Nz587yT4ECBQz3eUICJEACJKANAnp/J6PgRhvzSM9ehIeHS6GwSBX18OFD2ZXChQtjzJgxMvUlRcJ
6Hl36TgIkQAL2JaD39yoKbuw7X9iatgmI6DX79u2T74nr16/Hq1evpMPp06dH06ZN5YYykT4qTZo02u4IvSMBEiABnRKg4IaCG01MXQpu/huGiIgIjBgxQu5WEy9KHh4emDlzJoTymC9E/3HiGQnYk4DeFyEEKwpu7Dlj2JZWCVy4cEGKV5csWYKQkBCDm5988olcfBAR5MRiBA8SIAESIAFtEtD7OxkFN9qcV3rwSqQEWLZsGXx9fXH9+nXpstiMM3z4cHTp0gXOzs566AZ9JAESIAES0BABvb9XUXCjoclEV1KMgIh6uHz5cim0OX36tMGPt956C127dkWHDh24gdtAhSckQAIkYDsCFNxQcGO72aXCMgU3b2AdP34cXl5eiHk5EmksFi1aBO6yVzGZWJQEbEBA74sQAgkFNzaYGDSpCwKRkZHYvHmzjGazc+dOg89C0CoiyXXr1g0lS5Y03OcJCZAACZCAdgno/Z2Mghvtzi2teiY24Yj3mCFDhuDMmTPSzUyZMsnU0/369YM450ECJEACJEACSSGg9/cqCm6SMuqs4ygETp48KUU2Qmzz7Nkz2S2xWbt+/foybVTdunWRNm1aR+ku+0ECJEACmidAwQ0FN5qYpKldcCN+DJwwYQJGjhwJce7q6ioj3HTv3h0MCa2JKUonUjkBvS9CiOGj4CaVT+JU2P1bt24hICBA/gkODjYQKFeunFx8aN68Odzc3Az3eUICJEACJKB9Anp/J6PgRvtzTEseirQAgwYNgkg3LQ4RxUasEQwdOhQ5c+bUkqv0hQRIgARIQIcE9P5eRcGNDicdXU4WAZEZQaSLmjNnDv766y+DLfFe2KlTJxn1UES24UECJEACJGB/AhTcUHBj/1lnosXULLg5f/68jGpz6NAhSaZy5coIDAxE8eLFTZDiLRIggZQgoPdFCMGMgpuUmDls094ExC7wXbt2yV0+GzduhEi/IA4XFxcIgY34kapixYpM0WjvgWF7JEACJGAlAnp/J6PgxkoTwcHNnDp1Ska0+eWXX2RPxW7lNm3ayA06/BHFwQef3SMBEiABOxLQ+3sVBTd2nCxsKkUJXLlyBfPmzZOZEB48eGDwRaz1inWur776ChkyZDDc5wkJkAAJkID9CVBwQ8GN/WediRZTo+AmOjoas2bNwsCBAxEeHo706dNj1KhRMjQ0w/2ZmCS8RQIpSEDvixACHQU3KTiB2LTNCYSEhEixqr+/P4SQNeYQ4lWRMqpdu3bIli1bzG1+kgAJkAAJ6JSA3t/JKLjR6cSzk9tBQUEYMWIEli1bBiEiFkeDBg0wbtw4vP/++3bygs2QAAmQAAmkFgJ6f6+i4Ca1zNTU2U+xgezXX3+VG8q2bdtmeDcU6US9vLzkWhffD1Pn3GCvSYAEtEmAghsKbjQxM1Ob4EYspLVv3x67d++W/EuXLo2lS5eiTJkymhgPOkECJBCbgN4XIURvKLiJPaa8cgwCR44ckaF0V69ejRcvXshOCdFqw4YN5S6f2rVrMzWjYww1e0ECJEACkoDe38kouOFENkXg/v37GDt2rPxB5eXLl7JIlSpVMHHiRFSrVs1UFd4jARIgARIggWQT0Pt7FQU3yZ4CNKBBAnfv3sXChQtlRJvr168bPBS/H4loNq1atULmzJkN93lCAiRAAiSgDQIU3FBwo4mZmFoEN2KX2pIlS9C7d2+EhobKHwFFhBtfX1+G/dPETKQTJGCagN4XIUSvKLgxPba8qz8Cz58/hxDYzJ07F0JwE3PkyZMHnTt3ln/y588fc5ufJEACJEACDkRA7+9kFNw40GS0QleePXsGMScmT54s1weEyVKlSsmINl9++SVTYFqBMU2QAAmQAAmYJ6D39yoKbsyPLZ/oi4D4zWjv3r1yQ9mPP/6IV69eyQ44OzujadOmUmgjxNgizSgPEiABEiABbRKg4IaCG03MzNQguLlz5w66du2KTZs2SeYizUVgYCAqV66siTGgEyRAAuYJ6H0RQvSMghvz48sn+iAgUkWJlFFCuPr48WOD07Vq1UKPHj1kVBuRnpEHCZAACZCA4xLQ+zsZBTeOOzfV9ExEsZk/fz5Gjx6Ne/fuyaoFChSQKabbtGkDpphWQ5NlSYAESIAEkkpA7+9VFNwkdeRZTysEnj59KlOJig1lZ86cMbhVuHBh+TtShw4dkDNnTsN9npAACZAACWiXAAU3FNxoYnY6uuBm/fr1Mq/mw4cPJe+ePXtiwoQJcHNz0wR/OkECJJAwAb0vQojeUXCT8BjzqTYJiF09QqgqFh9+//13g5MeHh4yNWO3bt1QokQJw32ekAAJkAAJODYBvb+TUXDj2PMzsd5FR0fLKH3Dhw/HlStXZPHs2bNj6NChcueyi4tLYib4nARIgARIgASsRkDv71UU3FhtKtCQnQmcOHFCrnOtWLECYWFhsnURveaLL76QG8rq1q3L9Oh2HhM2RwIkQALJJUDBDQU3yZ1DVqnvqIKbkJAQeHt7Q7w8iUOkuFi8eDHq1KljFW40QgIkYB8Cel+EEJQouLHPXGEr1iEQHByMgIAA+efWrVsGo+XLl5eLD82bN0fGjBkN93lCAiRAAiSQOgjo/Z2MgpvUMU/j9lKkCdi2bRsGDx6Mf/75Rz4W7zH9+vVD//79IYTEPEiABEiABEjA3gT0/l5FwY29ZwzbSw6B8PBwrFu3TgptDhw4YDDl6emJTp06oUuXLihUqJDhPk9IgARIgAT0RYCCGwpuNDFjHVFws337dnTs2BHiR0NxeHl5YcaMGciSJYsmmNMJEiABywnofRFC9JSCG8vHmyVThoDY9b1r1y65+PDzzz8jKipKOiJ2e7do0ULu/K5YsWLKOMdWSYAESIAENEFA7+9kFNxoYhrZ1YmDBw9i0KBB+OOPP2S76dKlkykChg0bhty5c9vVFzZGAiRAAiRAAsYE9P5eRcGN8WjyXKsELl++jHnz5mHRokWIyX4gfK1Ro4Zc5/rqq6/g7OysVffpFwmQAAmQgIUEKLih4MbCqWLbYo4kuHn27Bl8fHzg7+8voYk8myI/e+PGjW0LkdZJgARsRkDvixACDAU3NpseNJxMAiIa3JIlS6TQ5uLFiwZrb7/9tkzH2LZtW2TLls1wnyckQAIkQAKpl4De38kouEk9c/fs2bMyVdRPP/1k6LSI0Dd69GgUK1bMcI8nJEACJEACJJBSBPT+XkXBTUrNHLabGAGxgWzLli1ynUtEOYw5MmfOLDdli/To7733XsxtfpIACZAACTgAAQpuKLjRxDR2FMHN/v37IX4YFMplcTRp0kQKb0RoQB4kQAL6JaD3RQhBnoIb/c4/R/X88OHDmDNnDlavXg0RWlccadOmRaNGjWTaqFq1akHksOZBAiRAAiRAAjEE9P5ORsFNzEg67ufNmzchfgAUqaRF9D5x1K1bF+PHj0fZsmUdt+PsGQmQAAmQgO4I6P29ioIb3U05h3f4zp07WLhwoYxoc+PGDUN/y5QpI6PZtGrVCpkyZTLc5wkJkAAJkIDjEKDghoIbTcxmvQtuxA+Fvr6+8P
Pzg8jP7u7ujpkzZ6JNmzb8sVATM4xOkEDyCOh9EUL0/v/YOw9oqanu7W8R6b13RWmCUqQsEaWJICoIiFKk946KdARFOiIgSJemgBQpUqQKIgiLrhQBQUGqdOkC8n7vs79/8g7D3Htn7iSZ5PKcta6TyWTOOfklkj37PHtvCm7Cuwf4bWsIXL9+XWbNmqVRPtu3bzc7zZIli9arRt3qrFmzmvu5QQIkQAIkQAK+BLxuk1Fw43s149b2hQsXZNCgQeoHMITEJUqU0H3lypWLWyfLsyEBEiABEogTBLxuV1FwEyduQ8+fBNaC1q9frwFl8+fPlzt37ug5oUzUW2+9pQFlzz77LNeIPH+leQIkQAIkED0BCm4ouIn+DnHoUy8Lbnbu3KnCmr179yqtChUqaE3O7NmzO0SPw5AACdhNwOtOCPCh4Mbuu4T9R0dg//79mvFt2rRpcunSJfNQPDNbt24tVapUkUceecTczw0SIAESIAESCETA6zYZBTeBrqq390FMPHLkSBk8eLD8/fffejJ58+aV/v37S40aNbi44u3Ly9mTAAmQQJwm4HW7ioKbOH17uv7kYPdNnz5dfV379u0z55szZ04tj964cWNJnz69uZ8bJEACJEACcZsABTcU3LjiDvei4AZqZUSwffTRR6pcTpw4sWa4wcJhvHjxXMGVkyABErCGgNedEKBAwY019wJ7CZ7A7du3ZdGiRZrN5vvvvze/mCpVKoHjATWr8+TJY+7nBgmQAAmQAAnERMDrNhkFNzFdYe98Djtn8uTJ6g84deqUThxZ+rD416hRI4kfP753ToYzJQESIAESeCAJeN2uouDmgbxtI37SCL4eO3aszJgxQyC8RsNa0KuvvqrZbCpWrMi1oYhfJU6ABEiABJwnQMENBTfO33UBRvSa4ObAgQPSoEED2bJli55NyZIlBVH7uXPnDnB23EUCJOB1Al53QoA/BTdevwu9M//jx4/LxIkT9c9YgMLsixcvrtlsatWqJUmSJPHOCXGmJEACJEACriHgdZuMghvX3EqxngjKBsybN0969uwpv/32m/YDMXH37t2lffv2gkAcNhIgARIgARLwAgGv21UU3HjhLosbc0S50Dlz5qjQZvPmzeZJZciQQZo3b64l0nPkyGHu5wYJkAAJkMCDR4CCGwpuXHHXe0Vwc/fuXRk9erR07dpVYGih/EXfvn2lc+fO8vDDD7uCJSdBAiRgPQGvOyFAhIIb6+8L9vg/Ang+rlmzRmtWL168WP7991/9EItOderUUaFNsWLF/vcFbpEACZAACZBALAh43Saj4CYWF91FX4Gt061bN9m2bZvOKlGiRNKxY0f1D6ROndpFM+VUSIAESIAESCBmAl63qyi4ifka84jwCBw+fFhLRk2ZMkXOnz9vdlamTBn1c1WvXl0SJEhg7ucGCZAACZDAg0uAghsKblxx93tBcHP06FEtgbF27VplVrBgQa3TWahQIVcw5CRIgATsI+B1JwTIUHBj3/3xIPd84cIFmTp1qkb5HDp0yESBUlEosdiwYUPhApSJhRskQAIkQAJhEvC6TUbBTZg3QIS+vn37ds1gs2rVKp0Bgm2aNGkiffr0EZSRYiMBEiABEiABLxLwul1FwY0X7zr3z/nOnTuydOlS9XOtWLHCnHCKFCm04gHKoxcoUMDczw0SIAESIAESAAEKbii4ccX/CW4W3CBlNBYTEbl25coVrcGJDDdwriVMmNAV/DgJEiABewl43QkBOhTc2HuPPEi947mIkoqoWT179mzN+Ibzx+JTtWrVtGZ1uXLl5KGHHnqQsPBcSYAESIAEHCDgdZuMghsHbhILh0DJqF69emkJAaPbmjVrSr9+/SRv3rzGLr6SAAmQAAmQgCcJeN2uouDGk7edayeNkuiTJk2SCRMmCEqlG61w4cLq50L25mTJkhm7+UoCJEACJEAC9xCg4IaCm3tuiEi9cavg5vTp09KyZUv59ttvFU2uXLk0q03JkiUjhYrjkgAJRICA150QQEbBTQRunDg25LVr12TWrFkqtNmxY4d5dojsbtGihTRr1kyyZMli7ucGCZAACZAACVhNwOs2GQU3Vt8R9vSHBReUjsaiC6Kc0cqXLy+DBg2S4sWL2zMoeyUBEiABEiABhwl43a6i4MbhGyYODoeAsnXr1qmfa8GCBabdhyDrt956S4U2WLdiQFkcvPg8JRIgARKwmAAFNxTcWHxLxa47Nwpu5s2bJ0gRaNTnbNu2rQwePFiSJk0au5Pkt0iABDxLwOtOCICn4Mazt1/EJ/7rr79qzepp06bJ33//bc7npZde0rJRVapUkfjx45v7uUECJEACJEACdhHwuk1GwY1dd4Y1/cLOGTJkiIwYMUKuX7+unRYpUkSFNrB7uNhiDWf2QgIkQAIk4A4CXrerKLhxx33kxVlcunRJg6rHjRsn8HkZ7fHHH1c/V6NGjSRdunTGbr6SAAmQAAmQQIwEKLih4CbGm8SJA9wkuLl48aK0b99eZsyYoaeeLVs2mTJlilSoUMEJFByDBEjAhQS87oQAUgpuXHhjuXhKt2/floULF8qYMWM02seYaurUqaVx48YqSM2dO7exm68kQAIkQAIk4AgBr9tkFNw4cpuEPMjNmzfl888/lwEDBsiFCxf0+0888YSWjkJ0c7x48ULuk18gARIgARIgAbcT8LpdRcGN2+8w980P2ZpRHn3mzJmmuBp2HgLJWrduLRBY0+5z33XjjEiABEjACwQouKHgxhX3qVsENytWrJCmTZvKiRMnlEuDBg1k5MiRkipVKldw4iRIgAQiQ8DrTghQo+AmMveO10Y9duyYTJw4Uf9QVtFoJUqUUOdDrVq1JHHixMZuvpIACZAACZCAowS8bpNRcOPo7RLjYCgXNX36dOnTp48cP35cj8+YMaP07t1bS2UmSJAgxj54AAmQAAmQAAl4lYDX7SoKbrx65zk77xs3bsicOXNUaON7z8Pma968uf7lyJHD2UlxNBIgARIggThHgIIbCm5ccVNHWnBz9epV6dy5s5bMAJD06dPLhAkTpFq1aq7gw0mQAAlEloDvD7LIziT2o1NwE3t2cf2bd+/eldWrV2s2m8WLFwveo0FYU7duXRXaFC1aNK5j4PmRAAmQAAl4gIDXbTIKbtxxk/3nP/+RRYsWSY8ePcwyAilSpJAuXbpIx44dJVmyZO6YKGdBAiRAAiRAAjYS8LpdRcGNjTdHHOj60KFDutaDygVGBkOcVtmyZdXPhXUfiqvjwIXmKZAACZCASwhQcEPBjStuxUgKbjZu3CgNGzaUw4cPK4vq1aurMZYhQwZXsOEkSIAEIk/A604IEKTgJvL3kdtmcP78eS2ZiJrVxjMQc8ybN686H/BsZIY3t101zocESIAEHmwCXrfJKLiJ/P27fv166datm2zatEkng4WWdu3aSffu3SVdunSRnyBnQAIkQAIkQAIOEfC6XUXBjUM3ioeGQfbCJUuWaEDZqlWrzJlDWA0fV6tWrSR//vzmfm6QAAmQAAmQgFUEKLih4MaqeymsfiIhuEGdd
qSOHjp0qCDCDYbXqFGjpH79+vLQQw+FdT78MgmQQNwi4HUnBK4GBTdx656M7dngeYf7GTWrZ8+eLf/88492FT9+fM3q1qZNG4324XMwtoT5PRIgARIgATsJeN0mo+DGzrsj+r5/+eUXFdUsW7ZMD4wXL56ghPRHH30kLCMQPTt+SgIkQAIkEDcJeN2uouAmbt6XsTmrkydPyqRJk7Q8ulEmFP0UKVJE4OeqU6eOJE2aNDZd8zskQAIkQAIkEBQBCm4ouAnqRrH7IKcFNzt37lRhzd69e/XUKlSoIJMnT5bs2bPbfarsnwRIwIMEvO6EAHIKbjx441k45WvXrsnMmTM1ymfXrl1mz1mzZpWWLVtKs2bNJHPmzOZ+bpAACZAACZCAGwl43Saj4Mb5u+qPP/6Q3r17y4wZMzTQBjOoWrWqDBgwQAoUKOD8hDgiCZAACZAACbiEgNftKgpuXHIjRWgaCChbu3atBpQtXLhQkN0GLWHChFK7dm3N3FyiRAkGVkfo+nBYEiABEnjQCFBwQ8GNK+55pwQ3MLwGDRqkUWzYTpw4sWa4ad26tSDCjY0ESIAEAhHwuhMC50TBTaArG/f37du3T50P06dPl8uXL5snXLFiRXU+vPbaa4LsNmwkQAIkQAIk4AUCXrfJKLhx7i47c+aM9OvXT8tF3759Wwd+/vnn1R9QqlQp5ybCkUiABEiABEjApQS8bldRcOPSG8vmaV26dEmmTZumNt7+/fvN0XLlyqUloxo1aiRp06Y193ODBEiABEiABJwgQMENBTdO3GcxjuGE4ObAgQOaMnrLli06n2effVaNszx58sQ4Px5AAiTwYBPwuhMCV4+CmwfnHr5165YgumfMmDHyww8/mCeeJk0aady4sWa0yZ07t7mfGyRAAiRAAiTgFQJet8kouLH/Trty5YoMGzZM/65evaoDPv300zJw4EB55ZVXGOVs/yXgCCRAAiRAAh4h4HW7ioIbj9xoFk1z+/btGlCG7M03btzQXhFAjcyFCKZGBQMGVFsEm92QAAmQAAmETICCGwpuQr5p7PiCnYKbu3fvyujRo6Vr165y8+ZNeeSRR6Rv377SuXNnefjhh+04HfZJAiQQxwh43QmBy0HBTRy7KQOczp9//ikTJkzQutV//fWXeQSesahZ/eabb2pmN/MDbpAACZAACZCAxwh43Saj4Ma+G+6ff/6R8ePHa1abs2fP6kCPPvqofPzxx1K3bl3+9rcPPXsmARIgARLwKAGv21UU3Hj0xgth2hDWzJ49WwPKtm7dan4zU6ZM0rx5c/3Lnj27uZ8bJEACJEACJBApAhTcUHATqXvvnnHtEtwcPXpUo/lRzxOtYMGCgrIahQoVumd8viEBEiCB6Ah43QmBc6PgJror7N3PICpduXKlRvksWbJE8B4tSZIkuriEKJ9nnnnGuyfImZMACZAACZCADwGv22QU3PhcTIs2//33X0Gkc+/eveXIkSPaa7p06aRXr15aViBhwoQWjcRuSIAESIAESCBuEfC6XUXBTdy6H33P5uDBg1oyaurUqXLx4kXzo3Llymk2m2rVqmlQtfkBN0iABEiABEggwgQouKHgJsK34P8f3mrBzX/+8x+BQdaxY0dBSmmkE0SGmz59+ggdbq645JwECXiKgNedEIBNwY2nbrkYJ3vu3DmZMmWKRnIfPnzYPD5fvnyazaZ+/fqSKlUqcz83SIAESIAESCAuEPC6TUbBjXV3IX7zL1u2THr06CG//PKLdpw0aVLp1KmT/qVIkcK6wdgTCZAACZAACcRBAl63qyi4iVs35Z07d2Tx4sWazWb16tXmyaVMmVIaNmyoQuonn3zS3M8NEiABEiABEnATAQpuKLhxxf1opeDm9OnT0qJFCzXQcHK5cuXSrDYlS5Z0xblyEiRAAt4j4HUnBIhTcOO9+85/xlhY2rx5s2azmTNnjqB0Alr8+PGlRo0aGuVTpkwZeeihh/y/yvckQAIkQAIkECcIeN0mo+DGmttw06ZNGlDz448/aocoG92qVSvp2bOnZMyY0ZpB2AsJkAAJkAAJxHECXrerKLiJGzfoyZMnZeLEifp34sQJ86SQrRnl0WvXri0QVbORAAmQAAmQgJsJUHBDwY0r7k+rBDfz5s1TR9v58+f1vNq2bSuDBw+mUeaKq8xJkIB3CXjdCQHyFNx49/67evWqlkoYM2aM/Pzzz+aJoE41BKZNmzaVzJkzm/u5QQIkQAIkQAJxlYDXbTIKbsK7M/ft26cZbRYtWqQdQWRct25d6du3rzz++OPhdc5vkwAJkAAJkMADRsDrdhUFN969YRFQ9v3332tA2cKFCwUlQtESJUqkAhuURy9evDgDyrx7iTlzEiABEnjgCFBwQ8GNK276cAU3qOXZvn17mTFjhp5PtmzZtNRGhQoVXHF+nAQJkIC3CXjdCQH6FNx47x7cu3evOh+mT5+u5RGNM6hUqZJms3n11Vc1u42xn68kQAIkQAIkENcJeN0mo+AmdnfosWPHtDz0tGnT5O7du9pJ5cqVZeDAgVKoUKHYdcpvkQAJkAAJkMADTsDrdhUFN967gbGGA3tu7NixcvDgQfMEcufOrUHUjRo1kjRp0pj7uUECJEACJEACXiFAwQ0FN664V8MR3KxYsUKj+42Ugw0aNJCRI0dKqlSpXHFunAQJkID3CXjdCYErQMGNN+7DW7duyfz589X5sH79enPScDg0adJEWrZsqaUSzQ+4QQIkQAIkQAIPEAGv22QU3IR2syJzLUQ1o0ePNktpwneALLYoo8lGAiRAAiRAAiQQewJet6souIn9tXf6m9u2bRNkbf7666/lxo0bOvzDDz8sVatW1YCyF198UeLFi+f0tDgeCZAACZAACVhGgIIbCm4su5nC6Sg2ghuU2OjcubOMGzdOh06fPr1MmDBBqlWrFs5U+F0SIAESuI+A150QOCEKbu67rK7acfToUX2GTZo0Sc6cOWPO7dlnn1Xnw5tvvimJEyc293ODBEiABEiABB5EAl63ySi4Ce6uvXbtmowYMUKGDBkily9f1i/ly5dPBgwYoL/3UUqKjQRIgARIgARIIDwCXrerKLgJ7/rb/e3r16+rwAbZbCC4MRpKojdv3lz/UKWAjQRIgARIgATiAgEKbii4ccV9HKrgZsOGDYIUg4cPH9b5V69eXYU3GTJkcMX5cBIkQAJxi4DXnRC4GhTcuO+eREkEZGmD82Hp0qVmiYQkSZLI22+/rUKbIkWKuG/inBEJkAAJkAAJRIiA120yCm6iv3Fu374tEB/37dtXTp8+rQdjIeajjz4SZLKNHz9+9B3wUxIgARIgARIggaAJeN2uouAm6Evt6IEHDhzQdZqpU6fKpUuXzLHLly8vbdq00aw2jzzyiLmfGyRAAiRAAiQQFwhQcEPBjSvu42AFNzdv3pTevXvLJ598Iv/5z38kRYoUMmrUKKlfv74wys0Vl5KTIIE4ScDrTghcFApu3HNrnjt3TiZPnizjx4+X33//3ZzYk08+qc4HPNNSpkxp7ucGCZAACZAACZDA/yfgdZuM
gpvAdzJEyHPnzpVevXrJoUOH9KDUqVNLjx49pG3btszyFxgb95IACZAACZBAWAS8bldRcBPW5bf0yxBNf/vttxpQtmbNGrNv+LYaN24srVq1krx585r7uUECJEACJEACcY0ABTcU3Ljing5GcLNz504V1uzdu1fnXKFCBV2wzJ49uyvOgZMgARKIuwS87oTAlaHgJrL3J0SimzZtUufDnDlz5NatWzohRGq/8cYbms2mdOnSFI9G9jJxdBIgARIgAZcT8LpNRsHN/TfYqlWrpFu3brJjxw79ECU033nnHenSpYukSpXq/i9wDwmQAAmQAAmQgCUEvG5XUXBjyW0QVicnTpyQiRMn6t/JkyfNvooWLaoBZbVr1xZkcWYjARIgARIggbhOgIIbCm5ccY9HJ7i5c+eODBo0SNNIYxsOuKFDh+riZLx48Vwxf06CBEggbhPwuhMCV4eCm8jco1euXJEZM2ao0OaXX34xJwGxaMuWLaVp06aSKVMmcz83SIAESIAESIAEoibgdZuMgpv/XdutW7dK9+7dxYiCfvjhh6VZs2aa0TZLliz/O5BbJEACJEACJEACthDwul1FwY0tt0WMnSIz4ffff69+rkWLFsm///6r30mUKJHUqVNH12yKFy8eYz88gARIgARIgATiEgEKbjwkuEF0fP/+/dUJZSzObdmyRTp06KD35ObNm81786effpLDhw9rRhhzp4s3ohLcoOYnarXjPNGeffZZmTZtmuTJk8fFZ8OpkQAJxDUCXnNCzJ8/X0sSvfjii+alCCS4QWmjMWPG6GKHF+onR/WsME/SRRt79uxR58OXX34pEN2gofRhpUqV1Pnw6quvChaW2EiABEiABEiABIIn4CWb7M8//5R58+ZpSaSECRPqSQYS3GDRAqUmixQpIogG9kILxyY7ePCg9OzZU9kY5/rWW2/Jxx9/zN/5BhC+kgAJkAAJkIADBLxkV/3222+yfPly9acgUzBaIMENgnUnTJggZcuWlfz58ztAMfwhwrGrwh89+B4uXLggU6dOlXHjxgmuh9GwToOSUQ0bNpQ0adIYu/lKAiRAAiRAAg8UAQpuPCa4KVmypN6gL730krz88su63alTJ31dsWKFrF+/XpYtWyYov5QhQwatnemFO9rfsITTcfTo0dK1a1e5efOmYCH4o48+ks6dO4thVHvhvDhHEiCBuEHAS04IEG/evLns3r1bChYsKBB2lClTRqpUqSKoqYyMYVj0wTMDzgq0DRs2eOLfVv9nhU7eRf/5559/BGKnsWPHyo8//mjOLG3atNKkSRPNaPPEE0+Y+7lBAiRAAiRAAiQQGgEv2WQIgHn77bdVBA07DL/hlyxZInPnzpVSpUpp4AyyuyxevFhOnTqlv32rV68eGpAIHR0bmwxlBvCb/osvvjAjoVEmeuDAgVKsWLEInQmHJQESIAESIIEHl4CX7Cpkxmvfvr2kT59e/VsIMEOAE/xaWCOpV6+eZl2BXXX27FkZMmSIoGy3F1ps7ConzwvsEaz39ddf6zoNxkYA2euvv65lo8qXL8/y6E5eEI5FAiRAAiTgSgIU3HhIcIM7CAptZHhBqj7Uv7xx44Yg8w0aDB0IUyBQwWLqJ598Il5J3+drWB49elQaN24sa9eu1fPCgvH06dOlUKFC+p7/IQESIAGnCXjJCQE2yA6GlPwQ2CClK16NFK/4HM+P69ev6w/iatWq6QIP9ru9+T4r3DTXI0eO6PN50qRJ6tgx5gaRbJs2baRmzZp6HYz9fCUBEiABEiABEogdAa/ZZG+++aYcO3bsnt/qxpnjNzsCTWCnYXvlypX6anzu5tdQbLJLly7J4MGDZeTIkeq/wHkhkw9E4BDcsJEACZAACZAACUSGgJfsKqx/IKAMWVYQjIt1EAQ9GQ22FLLbwPeVPHlyFeJ4JatwKHaVcb52v8JnOGvWLA0o2759uzkcyn62aNFCfY5Zs2Y193ODBEiABEiABB50AhTceExwg6iw2rVry61bt6K9d71oWMJwRlrCjh07avmNePHi6SJwnz59PON4jPai8EMSIAHPEvCSEwKQ8e8pIqRPnz4dLXM4JCDkzJs3b7THueVDNzkh4MRBliBks1m6dKkpfk2aNKlGs7du3VoKFy7sFnScBwmQAAmQAAnECQJes8lmz56tEcG+C0L+FwIlJxGZjd+9XmnB2GQIDkLWWmSwuXjxop5arly5tEw2xMj4vc9GAiRAAiRAAiQQOQJes6vgv0JQLoQ1UTWIbFCqEusLXmnB2FVOncv+/fu1ZBQCviGaNhpE0vBzIWujF0rSG/PmKwmQAAmQAAk4RYCCG48JbnBj1K9f/546mf43ixcNy0cffVTV0Uj7iAZHHAxoo4SW/znyPQmQAAk4ScBrTgiwwY9jpOyPTqCZKVMmWbBggWdSv7rBCYHUxJMnT1YHxJH/ZrYxGmqDI5sN0hinTJnS2M1XEiABEiABEiABCwl4zSbDQoVR1jMqDIkTJ9bsL8js6pUWnU2GRTAE0nz44Ydy4sQJPSXYnBAUNW3alIs0XrnInCcJkAAJkECcJ+A1uwp2RZ06daL1cyGwDHZIzpw5PXP9orOrnDgJZFtctGiRBpR9//335pCpUqXSKgStWrWSPHnymPu5QQIkQAIkQAIkcD8BCm48KLiBATR8+HCzZqb/ZfWaYQlDbtiwYXL+/Hk9lbZt22rKaWQJYCMBEiABNxDwmhMCzCAMqVGjhpYpCMQwQYIEuujRsGHDQB+7cl+knBDIGPTTTz9phPq8efNM5w6iet544w2N8nnhhRc8I1xy5cXlpEiABEiABEggCAJetMnw+9Y3Fb//aaZJk0az5SHTjVdaIJsM9hKE3D179hRER6NBhNy1a1fp0KGD8Pe9V64u50kCJEACJPCgEPCiXYUgp0OHDkV5ibJnzy5z586N8nM3fhDIrnJinsePH5eJEyfq36lTp8whixcvrn6uWrVqaUl68wNukAAJkAAJkAAJREmAghsPCm6uXr0qlStXjnIR1SuG5eXLl1Vog5IcaNmyZZMpU6awjnuU/7vyAxIggUgR8KITAqwQRbx3796A2CAWmT9/vqRPnz7g527c6bQT4sqVK/LVV19plM/u3btNJDly5JCWLVsq34wZM5r7uUECJEACJEACJGAvAS/aZGvXrpWPP/5Yrl+/fh+c+PHjawZb2BVeav42Gc6xW7dusmXLFj0NBAG1b99e96VNm9ZLp8a5kgAJkAAJkMADQ8CLdtXChQtlxIgRAQORYX+g7FHt2rU9dQ397So7J3/37l1Zs2aNBpSh0gDKpaMh4yKyB4FfsWLF7JwC+yYBEiABEiCBOEmAghsPCm5wJ3bq1Ek2btx4303pFcNy8+bNWrsdGRjQGjRooGm0kaqQjQRIgATcRsCLTggw/O6772TIkCFy48aN+5AWKFBAS07d94GLdzjlhIC4ZuzYsfLll18KRK5oiDp/+eWXtWwURK8o38hGAiRAAiRAAiTgLAEv2mQo7/nSSy/JP//8cx8sZBycNWuWZM2a9b7P3LzDsMl27dq
lohojiCZevHjSqFEjLSeFQCA2EiABEiABEiAB9xLwol0VXSAyAsu+/fZbSZ06tXuhB5iZYVcF+MiyXagsMPW/pbbGjRt3T4YglIqCyAbZr73GzTI47IgESIAESIAELCBAwY1HBTcobdGrV6/7ouTcblgiqm/UqFGaahr3LwQ2iIRDmmk2EiABEnArAS86IcASQptKlSqZJZAMvohc6dKli2ZLM/Z54dVOJwQWwb755hsV2mzYsMHEkS5dOs1k06JFC3n88cfN/dwgARIgARIgARJwnoBXbTJkuFm2bJmg7JJvy5Url2bT893nhW3YRx988IGKhYz5VqtWTQYMGCBPPvmksYuvJEACJEACJEACLibgVbvqvffe07Lf/mifeeYZzdziv9/t7+3ydcHuRPZBBJTNnj3bzAqEADLYbW3atJFy5cqxPLrbbxDOjwRIgARIwBMEKLjxqODmzp07UrFixfsEN242LBH91rdvXzl58qT+z1GmTBkV2qBmvV2GpSf+L+QkSYAEXE/Aq04IgO3Zs6emi/WFjGhqRCJDeOOlZsez4siRIzJ+/HjN9mNkXQOT5557Tp0PNWvWFGSPYyMBEiABEiABEog8Aa/aZHv27JF27dqZCx0gmShRInnnnXd0wSPyZIObAaKjJ0+eLIsWLRL4JNDwu37QoEHy7LPPBtcJjyIBEiABEiABEnAFAa/aTrpF1gAAQABJREFUVcj6D+Gvb7nOJEmSaHBy+fLlXcE2lElY7eu6du2aiqIhtNmxY4c5FWRURDBZs2bNJEuWLOZ+bpAACZAACZAACYRPgIIbjwpucOmHDh0q8+fPN6Pk3GpYImvAhAkTZMaMGXrHYp7vv/++ZlZAiQ40qw1L7ZT/IQESIAGLCHjVCYHT37Ztm2az8XVEwAGBCGSvNaueFahRvXz5co18QtktI9o8adKkUr9+fU2nW7BgQa/h4XxJgARIgARIIM4T8KpNBlvjtddeEwhWjIbstMh6kzx5cmOXa1+xcIPf8zNnzjRFQ4UKFZKBAwdqyU3jd71rT4ATIwESIAESIAESuI+AV+2qQIHICJRatWqVIMDMa80qX9evv/6qJaOmTZsmf//9t4kBpU1RNqpKlSoSP358cz83SIAESIAESIAErCNAwY1HBTeoV4qUgL179zYjy2BQfvHFF/LYY48JnHduaAcOHND67X/88YdOp3jx4qo2z5gx4z3Ts8qwvKdTviEBEiABiwh41Qlx+/ZtOX78uEavYKEEDc8HlJNC2thkyZJZRMiZbsJ9Vpw5c0afk8hoc/ToUXPSTz31lDof6tWrJylSpDD3c4MESIAESIAESMBdBLxok0Fsgyx6RmYYQ+gLwcqHH34omTJlcm0q/1u3bmnJzSlTpsjly5f1ZsicObMG/9SpU0fixYvnrhuEsyEBEiABEiABEgiagFftKvh2Pvvss3uyOb/wwgsa4JshQwbX2lVRXZhwfF3w+y1cuFADytatW2cOkTp1amncuLG0atVKcufObe7nBgmQAAmQAAmQgD0EKLjxgODm5s2bsn//fv3bt2+fQK187NixKO8IKJVhSKF2uvGXM2dOQX1OpxqU5tOnT9eFTWQSgMq8ffv2UqNGjYBOuXAMS6fOieOQAAk8uAS84ITAv7VH/lseCc8JPDPw+ttvv5mizEBXL3v27OZzAs+LvHnzurrMVGyeFVjUQrrhMWPGyLx58wTOCDQIj1AuClE+zz//vOccMoGuJ/eRAAmQAAmQQFwn4AWbDFls8Jvd9+/ixYtRXhpkuIEdli9fPn3Nnz+/+AeoRPllmz4wsgEiU+1ff/2lo2DhpkmTJloCC7YTGwmQAAmQAAmQgLcJeMGugmgZNpWxJoJtQwQciH7KlClNewq2FeyqdOnSBTrUNfti4+vC2tDEiRP17/Tp0+a5lChRQv1ctWrVcrV/z5wwN0iABEiABEggjhCg4MbFgpvDhw/L3LlzBeUuUJYpUEPqZiNCLtDnxj6ou9944w2pWrWqwFEWTvvzzz9l9uzZ0rlz54DdIGvARx99pIYwDihQoID06dNHcuTIEfB47IyNYRllZ/yABEiABCwm4GYnxKVLl+Tbb7/V6GNjQSS2pw9xZOXKlVWIkitXrth2Y9v3QnlWwAHz1VdfCWpW79mzx5zTo48+Ki1bttQFo0gvZpmT4gYJkAAJkAAJkEBQBNxqk0GgsmnTJv39bsUcUdoSwmBkJHQyey18Cxs2bFD76ffff9drkjhxYkEWwNq1awvKb6KFYpPpF/gfEiABEiABEiAB1xGwwmax46QQyLt+/XoNmtqxY0fYQxQtWlTefPNNDbZyY0mlYO2qu3fvyurVqzWgbPHixYL3aLDV6tatq0IbnCsbCZAACZAACZCA8wQouHGZ4CYqgxJOtjx58twT9QYBCyLlUH8TZUFQpxTZcJDRwDeazrdsBvpB3U4YmYiiC7XBkfj666/LuXPnZPPmzfd8HUbenDlz1OhD6mlk1GnRooW8/fbbMdYHDdawvGdAviEBEiABhwi40QmBLDYQZa5cudLM2gIceDYggseIkkbGsyRJkuiiDcpLzZo1S5DZBuJJ32fFwYMH7+mnSJEi+qwoXbp0jP+GO3QZglrc+eWXX3SRCGIblF9EgzgVQiJks8GrkxnfnGLDcUiABEiABEjgQSDgNpvs77//liVLlqjw+eTJk+YlQIlKI9us8YroaiyODBw4UDO/IoDlwoULZjZbI3ob+4yWJk0aqV69umaVSZ8+vbE75NcBAwYIop9RfiGq9vPPP8vnn38usKXQYC9B9NOoUaP7gnb4+z0qitxPAiRAAiRAAt4h4Da7CjbQokWLZP78+VqO0yCZKlUq089l+LvSpk0rX375pa5DIANf8+bNdb0CvjLD14WsOLDVjIagK2Tfx9oG+nRLi8muQvZElPccN26cIEDbaMhSDT9Xw4YNXXU+xvz4SgIkQAIkQAIPEgEKblwkuEEEft++fXUR1LgJn376aV3wLFu2rCRIkMDYfc8rapS++OKLWv/9ng/+7w1EOXDsffPNN2Y6aHyE73Xr1k1goAbbkKoQxiyEN1jkNSLcTp06JR9//LEYqnNkRkBWm2BrhMZkWAY7Px5HAiRAAnYQcJMTAj+0hwwZIj/88IN5qkYWMwgwsTATqOGH+dSpU+8TSxrHotQS6j1DxGMstOAziHM++OADQaR1pFtUzwpkgUO5KJSN+umnn8xpYmGradOmmtEGpRXZSIAESIAESIAEvE3ALTYZMsFgMWjUqFEa9AKqEKhUqFBBRSpPPfVUwHKVEAPjGNhkEEf7N/SLhRTYNb6ZbtE3FpOwoBJqZDb6QQbaQoUKyfjx4/2H1PGQERCZbYz28ssva/BMlixZjF33vEZlk91zEN+QAAmQAAmQAAm4moBb7CqsMyA4DHaKUQYc6yAVK1ZUuwrCEgRS+TcEBL/22mtqkwWyWWBXGcFqCFT27RtCFZRdihcvnn+3jr8PZFdh7rg+sNFQacCofgA7sFq1atKmTRvBelEgLo6fAAckARIgARIgARIQCm5cILhBNh
gIWRCND2MKWWgqVaqkBmUgJ5z/fXv9+nVJlChRjAYisuds3LhRs9Bs375du0HkHSLrkPUmprZ3715VTWO+yJbQu3dvKVOmjEb0DR8+XDAPGHn169eXZs2aRSkQCjROIMMy0HHcRwIkQAKRIOAWJ8T3338vgwcPNiN0nnnmGRVlQkAZ0+ILni/4d9oQSkbH8cCBA7rQs2LFCsG/+WjIVoasZSg7Fanm/6xAqQM4ZCZPnqyRTMa8nn/+eX1eoZRiJOdrzIevJEACJEACJEAC1hBwg02GTDH9+/eXrVu36kkhgAXR0lj8CCaY5dq1a0HZY1euXNHf2hDfnDhxQseCfwC/wx9//PGggGKO77//vi7SwG6EONloCJqZMGGCCnuMfaVKlZJWrVrFGDjjb5MZ3+crCZAACZAACZCAdwi4wa5C9mUE8e7evVvBZcqUSeDLqVq1qqRMmTJGmMHaVQhINsqxnzlzRvuFGBkBZtmyZYtxHDsP8LWrcD4zZ85Um23Xrl3msFmzZtVgMqy5ZM6c2dzPDRIgARIgARIgAXcQoOAmwoIbqKwRbfbHH3/oHQEnWK9evSSQKtvKW2bLli3Sr18/MQzM8uXLS5cuXaJMP4hFWpShQmYFo5UsWVKj+IxIOBinyGqDrDyhNl/DMtTv8ngSIAESsJtApJ0QSIH7ySefaOlAnCtKCuBZYfe/nf4LSo899pj+Ox+bkoRWXCOcLyKfli1bplE+y5cvV6Eq+kZpRQg+EaUUm+eQFfNjHyRAAiRAAiRAAvYSiKRNBvEyykcZwSY4UycEybB9ZsyYoeIYBNFAZN2yZUupW7dutGUyUWoa5RVQdhoNpRNgO126dEnLEiADLvpDg+3Utm1bKVy4sL6P6T9226Axjc/PSYAESIAESIAEwicQSbvq7t27mmEZYmBkb0E2v8aNG0uj/5ayjCmgLJwzR5YblGdCtkHMAUFa7du3V/F0pLLdwK5C+Stks5k+fbpcvnzZPEVk+YGfC5l87ORiDsgNEiABEiABEiCBWBGg4CaCghuU7ujZs6cuHkbCuEM665EjR2q5Kdw9WMCFYRdI1Q21N8qXGJkO/O82KM/btWsniRMn9v8oqPd02AWFiQeRAAlEiEAknRAnT57UVLEQv6AhAxoilZMnT+4IDSwu+ZZMQCYzCDZRytDJBsEnonsQiX306FFzaCwQwflQr149x5iYg3ODBEiABEiABEjAUQKRssmwIDN06FBZsGCBni8iixE442TJTWT2QwlqBO2gIaPfwIEDNUOu7vD5D+zGBg0a3LNgg4Ws2rVr6zkgoAYNYmqUJEC2xFBKEvD3uw9sbpIACZAACZCARwlEyq6C4Bd2FMo8oT366KP6PphM/1ah/vXXX3XMI0eOaJcop4n1D9hLTjWIf7DeAg6+ZeNRKh7iIwisc+fO7dR0OA4JkAAJkAAJkEAYBCi4iZDgxldsAycXMs1gwRBZZJxuP/74o3Tv3l2j25ACG+U5fEU3a9asUceeUSvUd35wyg0YMEDKlSvnuzvkbTrsQkbGL5AACThIIFJOCF+xDQQ2KAMIR8SQIUMcj2xBKYMOHTpoSYPYim4g3rlw4UJQ5RZweXE8RDYQ/KCcFiK80VDLu2bNmrpA9Nxzz4W0QKQd8D8kQAIkQAIkQAKeJBAJm8xfbINMfxD/QgScK1cuRzligQqR4Cg1gAY7COVGUZbaaMiMCLENfAywpQK1DBkyaKnQypUrx2phib/fA1HlPhIgARIgARLwFoFI2FX+YhvYVfA3oVS473qEEySx1jFs2DAtNYXxkE0G2ftDFd2gXBX8dcF+D8LohQsX6rjwkRkN9hWE0Fgfim1Qs9EXX0mABEiABEiABJwlQMFNBAQ3vmIblJBCnfRRo0bpgiFKhuC9kw1ZbmbNmmUOCecbnHgwcv/66y+Ngrtx44b5ue8GUhnCoBwxYoTkyZPH96OQtumwCwkXDyYBEnCYQCScEL5im9SpU0vnzp21rjX+PcaP706dOjlKAQzeeecdc+EGqXZRZzvYTDdYrHr33XcF/WzevDnauaNm9XfffScodWCUXMQXIFBFhE+TJk0Ezyo2EiABEiABEiCBB4uA0zaZv9gGJZp2796ttgxsEZQkQNCKUw2ZaWALoVyU0Xwz3WDhCHNENhyjXJRxnPGKEpzNmjXTEgrGvlBf+fs9VGI8ngRIgARIgATcR8Bpu8pfbAMfEwJ9YVvlyJFDJk2apOsMTpGCSBk20bFjx8whkVW6d+/eQYtnNm7cqP45lIKKbm0ENiV4I6Bsw4YNpm8NVQ+M8uhYJ2IjARIgARIgARLwJgEKbhwW3CA9YI8ePTRKH0bUp59+qo4uCG2wsAj1MiLlnEoXCCMPmRLQsHgLhyGU1XAejh49Wg3MgwcPak3T6G5xGIco54HSUkh7GGqjwy5UYjyeBEjASQJOOyFOnTqlWc8Q9QKxzeeffy6PP/644Ic8ykkhWvm9996Tt956yxEMWLSBEwKLPKgbjXlt27ZNILpBeany5ctHOw+kyUUJxfXr12sk+FdffRXweDxv8Fxavny53Lx50zwG0dt4vnTs2DFop4f5ZW6QAAmQAAmQAAnEGQJO2mSwt/Bb2SgjBaEL0vtDGGyIWhCVjbLMiRIlsp0xMv117dpVF2nSpUsnb7/9tpaIxsAQ3cAmQ+ba7du3R1kKGlkKU6VKJV988YVkyZIl1nPm7/dYo+MXSYAESIAESMA1BJy0q2DHfPjhh2YZKfi2kLkY2WEQVAU/GNZKEBjsm7nPLljwUyGL886dOyVr1qxSpUoVGTdunA4H0Q0y3cDnFV1btmyZDBo0SO2utWvXBsxKc+nSJVmyZInak8jkYzRkr4afC9kGK1SoYOzmKwmQAAmQAAmQgEcJUHDjoOAGRlXdunUFUWeG2MZwzEHhDUMTUf8QuyCNIpxodjYY1cg2AIV1+/bt1WGHuqVIXQjRDcaH0hsGaEwNmW5wDlgUhYgo1EaHXajEeDwJkICTBJx2QrRo0UL27t17j9jGON/Zs2fL8OHDHcuKdv78eWnatKmKbIoXL65j4997ZNzZunWrika//PJLjUYy5uj7iow8iFrav3+/Pv/y58+vzzjjGDwTUS4KotM9e/YYu3UxqGrVqlKtWjVzQYjPChMPN0iABEiABEjggSTgpE327bffavlkgDbENgZ0LAphcQiLRCiv3L9//xgXZYzvxvYVWWW//vprtb1QBjpfvnyydOlSzTiIPhFVjd/zt27dinYIiG6QpRZZdqOLxI6uE9pk0dHhZyRAAiRAAiTgDQJO2lUIvEJwL5ohtjEoIbMxgrwgaobwBcHKsFfsahBVI2MzBDPJkiVTITIEMIa/DePCj1W7du0op4DynrDH4NPCusiKFSskadKkejz6h38Lfi5k8DHWVlByqmzZslKjRg1dGzLOkXZVlJj5AQmQAAmQAAl4hgAFNw4JbiBqgZBl165d8th/S2Igk4x/LU7fSDk4z6CqNgQ5Vt9RvtkKsJiJSDnDyEN66kaNGmkWnqjGRUYbtAQJEkjRokW1DBZeYxslR
8MyKtLcTwIk4AYCTjohkIYWZf3wQxzRx3ge+Db8cHcqKxqyzLRt21bFP3h2TZw4UZInT67TwWcQ4hw+fFiefvppfWb516uGaBPPPqTnNRZ/EDkEp8Px48c1wmfx4sVy+fJl8xQLFSqkUT5wQuAZ49v4rPClwW0SIAESIAESePAIOGWTQVCDDDLI7ofoaywM+TcspMDOgY3ToEED3fY/xqr3sJ2GDh2q3SHrTunSpc2uZ8yYoeIZc0cIG/4LXsF+lTZZsKR4HAmQAAmQAAm4l4BTdhUENbCVIDyBYBlBZv4NAV0QuSATDvxQKLNkV5s2bZpmKIQP67PPPtO1DWMsZC7E5/BHIbgMQhzfBp8cslDPnTtXxTb4DGs8EBQhQ/XKlStl3rx5cujQIfNrCK7G+guCygIFWNOuMlFxgwRIgARIgAQ8S4CCG4cEN4hEQ0QaUhFiARWppwM1OPawgIkMM1hsHDBggOWRcoGyFUCJbTRkI3j11VfVuWjsg5EJQQ6OQ3YepKyGwAYLp4ZQxzg2Nq80LGNDjd8hARJwioBTTgiIVxo2bKgZw+CAgCMiUHMiKxqEoqhbvXr1as02g8xr/qLKAwcO6BzhEDEypRnzPXPmjDpRzp07p+dj7Edq4CJFisiWLVuMXZIkSRJNo1u9enUtOWV+4LfBZ4UfEL4lARIgARIggQeMgBM2GWwglBhA+czs2bPrYktUgTCIWkbZTDS8Iirb6oYsuCglaswLWXN9G7LcoJwUFoB8G3wPxrwRfQ0bDAs+WDhCCev06dPr73q8htpok4VKjMeTAAmQAAmQgPsIOGFXwX+F7DXIepw3b15dF/Fdh/ClsnDhQi3RhH0DBw7ULIK+n1uxHZPtBlEQSohCMFOgQAGZMGGCWdocvi9kNUSWZt8y6LC3sF6CQGuItY0Gewllo1ARIKpzxrG0qwxifCUBEiABEiAB7xKg4MYBwc3Ro0dVlY3INxhsSEcdXUMZEUTKwSkGNTdU3Va16LIVGGMguw2MYAhp4LTDQmjr1q3V+IPD0QqBjTGW8UrD0iDBVxIgATcScMoJAcElRCzIajNp0qRof5AjKxpEORDp2JEVDalxkY0NizPIuIMsNoEa5ok/HIfsPDlz5pQ///xT54bMNVgciqrlypVLU+miPraRejeqY7Gfz4ro6PAzEiABEiABEoj7BJywyRCVjGyC+N2LRZaobCCDdnRR0sYxsX2Fnde8eXNdvPHPTIs+kZ3WPwI8VapUuugDOwu/3/GHQBkjS2Fs5+L7PdpkvjS4TQIkQAIkQALeJOCEXYXgLdhTEJzAZnriiSeihYWSl8jeh+z6yDaDsuRWtWCzE/pm/8caDbLzYF0HFQJ27tx5j9jGf24o2/naa68JAspggwXTaFcFQ4nHkAAJkAAJkIC7CVBw44DgBtFxiOSHw8tYvIzptoBSGvVK0fCKlIPhtmCyFWAMpFJMmTKlzrdbt26a7tE/c0G4c/H/Pg1LfyJ8TwIk4CYCTjghjFrRvsKVmBicPn1aM8xYnRUNdaz79u2rw6Ou9UsvvRTlVBCtZAiFENGDZx6EohAERdUgrhk2bJigfFQoIk4+K6Iiyv0kQAIkQAIk8GAQsNsmQznM119/XRdS6tWrJ+3atYsRLIJUEO28ZMkSSZYsmUZu+5cfiLGTAAf4ZqYtUaKEfPrpp/eJsWFvbdq0SR77b+lP+BzwWx621cyZM1UEHaBbS3bRJrMEIzshARIgARIggYgSsNuuQiZ/lOZEZhhDuBLTCWP9Amsh69atkzRp0ggEO5kyZYrpazF+fvLkSfVdXbx4UTPnwHZDNsComq9QCOWi4BuDEMcol+7/PdhfsBtxvhALhdJoV4VCi8eSAAmQAAmQgDsJUHBjs+AGNUrr1KmjVx8ZAJ566qmg7wRkCkBWAdQTHTlypBQrVizo7wY6MNhsBb7fNTIXZM6cWeuPYi52NBqWdlBlnyRAAlYRsNsJAYfCm2++KSdOnBBkGWvVqlXQU9+3b59mIbMqKxqideAkgEMkurJWvhNEVjTMGw0lCKNyQOgB//0P0u0irS7EnKE0PitCocVjSYAESIAESCDuEbDbJvvyyy/l888/13JLyHQT7IIJyg907NhRduzYodlkUEYamWZi2/wz0+J3OcQ80TXYkyhHCrsMdlbnzp2jOzysz2iThYWPXyYBEiABEiABVxCw267CugbWNyAMRtaaYNcVYAfBLwabBhlxkCEnmKzIUUG9evWqZgzEOg0y5mBeRtnNqL6D4DKs6Rw7dkwDk9EH/GRRNYh3UIYd54tqAaE02lWh0OKxJEACJEACJOBOAhTc2Cy4GTJkiMyfP19rfsLpFkrzj5SDkw0Gamwa6rpDiY0WU7YC3/6RNQHZdWBk4lxKly7t+7Fl2zQsLUPJjkiABGwgYLcTYuPGjdKpUyd1PixatEjSpUsX0ln4ZkXr3r27RmaH1MH/HYxSUKitjVJQL7/8svTp0yfoDDS1atUSlFAMtmEBC+ccSgY3PiuCpcvjSIAESIAESCBuErDTJsMiCoQqyCCIRR5DTBwsSWTHgR2FhRlk8UNJBAiRQ20QzvTq1Utg36VOnVoz5mABJ5j23XffyUcffSSJEyfWjDvhLE5FNx5tsujo8DMSIAESIAES8AYBO+0qiGbg74F/qUuXLlpOPBQq586dk8aNG8vZs2elZMmSMnTo0Psy/QXTH9Y04HvCuWbMmFEz5qRNmzaYr2r5dKzHBNtg98EGRFZCZK8OttGuCpYUjyMBEiABEiAB9xKg4MZGwQ2Uz6jZCQPzww8/1MXLUG8F30g5ONkg2oHTLZSGKDtkEQglW4Fv/5j78uXLBWmskaLajkbD0g6q7JMESMAqAnY6ITDHd999V8sBoHSTIY4Mde6+WdFGjBghxYsXD6mLcBaJICxFiajoon2imgwcJ8OHD4/q43v281lxDw6+IQESIAESIIEHjoCdNtn69et1QSh+/Pjy7bffahmDUAFDbINSm1hcqlSpkvoBQimfifHGjh0r06ZN04UabIeSJRdZBlESC+US3nvvPXnrrbdCPYWgjqdNFhQmHkQCJEACJEACriZgp12FUpv9+vXTbC/YDjXrC8ChhBMyL9+4cUNLNb3//vsh8UQwMwKIFyxYoOMjU06uXLmC6gPZcDD2lStXgjre/yAIp4M9Z9pV/vT4ngRIgARIgAS8R4CCGxsFN7Nnz9ZFRNQbRcaCUJTNvrcSnHWIlEPmgYIFC2qkXLCprfEdOPxgHIaarcCYw969e7UPvP/6669jnWXH6C/QKw3LQFS4jwRIwC0E7HRCYGEG5aTQ8OMf/87HpvlmRUM0M6JwcubMGVRXEHdCmLlr1y7Jli2bijtTpkwZ1HfPnDlzX5YapOZNkSKFRnVjG394bsHZgIhrvGKOiOBG5jYsDAXT+KwIhhKPIQESIAESIIG4S8BOmwy20NatW6Vy5cqa5S+2FGFPtW3bVoXIzZs3N39LB9OfsTiFY0PJTOvbt1FKOnv27AKfBEocWN1ok1lNlP2RAAmQAAmQgPME7LKr4J9CpsAD
Bw4IsiEjyCy2DRmhIbRBn6GKibGOgYA0iJ+RdQYBX8G25557Tn1W+C7Ghj2FNR5ksYF/y/B1wceFbfi48If3169f12yJ2A6m0a4KhhKPIQESIAESIAF3E6DgxkbBTcOGDdWwRPrDli1bhnUn+EbKVaxYUdNExxQpF062Av/J4hx+/fVXwTm1bt3a/+Ow39OwDBshOyABErCRgF1OCEwZIpvJkydLnjx5NJo5pn/bozvN2GRFg+MACzrLli2T5MmTq1Dn0UcfjW6Y+z5DZhvU4kYWtokTJ2rE0FdffXXfceHu4LMiXIL8PgmQAAmQAAl4m4BdNpmvgBj2TIECBcICZZR2Qico8YRsNzG17du3S4cOHVSoA/8BfoPHpuFcqlWrpotEEGCHkiEn2PFokwVLiseRAAmQAAmQgHsJ2GVX/f7771K3bl098Tlz5kiOHDnCgoA+IJiBv+yTTz6RUqVKxdjfjz/+KJ07d9bjINipWbNmjN/xPQB+Lohsjhw5InXq1NGPIGQO1V/m22dU27SroiLD/SRAAiRAAiTgHQIU3NgkuEEZqfLly6uTCwup+fPnD/uuQKRcu3btBLVHkfEGf1G1cLIVBOrTWEQtVqyYjB49OtAhYe2jYRkWPn6ZBEjAZgJ2OSEw7Y4dO2ot6SZNmmi62nBPJdSsaFOnTpVx48apYAZlA4sWLRrrKSB6CcJMOEGQPjfYaJ5gB+SzIlhSPI4ESIAESIAE4iYBu2yydevWSbdu3TRyeenSpWrLhEvQEFWjRNXnn38uhQoVirJL38y0yLDTu3fvsOZgBMy88847Urt27SjHje0HtMliS47fIwESIAESIAH3ELDLrlq8eLH0799fxSkQqVjRILSZN2+e+plgY+XOnTvKbuGbgngZ6zPhZtjBIG+88YacOHFCMyDCTrO60a6ymij7IwESIAESIAHnCVBwY5PgZs+ePSqIQcT/2rVrNd2gFZfXN1Luww8/1DJR/v1aka3Av89NmzZp+kekRly9enVYzj//vvGehmUgKtxHAiTgFgJ2OSHw7zUiniGSgfPg+eeft+SUg82Khn/Pe/XqpWP27NlTqlSpEtb4EISWK1dOIPqEiKdw4cJh9ef/ZT4r/InwPQmQAAmQAAk8WATsssnGjh2rmQYRMT1s2DBLoMLOg3Bm1apVglKdCMTJmjXrfX0jMy3KQB8/flxtJwigUa4gnDZ06FD55ptv1F8Av4HVjTaZ1UTZHwmQAAmQAAk4T8Auu2rIkCEyf/58S+0Q+JuQsQZrFOnTp5cpU6ZIunTp7oOGTH+wq86ePSsoCwWbCOsz4bQPPvhA7TmUg+/UqVM4XQX8Lu2qgFi4kwRIgARIgAQ8RYCCG5sEN3PnzlVHXb58+QTZA6xsMUXKWZmtwJj3xYsXtZY93uPcUA/eykbD0kqa7IsESMBqAnY5IU6ePCk1atTQ6S5ZsiSgsyC25xJTVjQIQ1EiEOKYBg0aSJs2bWI71D3fg2Nj7969YkdENZ8V96DmGxIgARIgARJ44AjYZZOhlNOWLVtizCQbKvB//vlH2rZtK7C7UIIAJZ5QwtNot27dkvbt28vPP/8s2bJl0/KcEOeE2+yILPedE20yXxrcJgESIAESIAFvErDLrjIy7b377ruaYcYqOteuXdPM0IcPHxasuUAw7ZtZ+fr169KqVSs5ePCgljofP368IHg43DZjxgwZNWqUlumELWd1o11lNVH2RwIkQAIkQALOE6DgxibBDeq0IxsNaqcjNbWVzT9SDuWe4JxDszpbge+8cS6nT5+Wvn37SsWKFX0/CnubhmXYCNkBCZCAjQTsckKsWbNGkFkG0TlYGLG6RZUVDUIfCGMgpkRGGqT6RW1qK5oRUY3nBJ4XVjY+K6ykyb5IgARIgARIwHsE7LDJfDMOIrsNstxY2S5cuKB216lTpwQlmkeMGCEoM4VxYSvBXoMIB7/rc+TIYcnQhw4dknr16mlfsDetWGzynRhtMl8a3CYBEiABEiABbxKww65CUBf8TMhIg6DhggULWgoHaxMoyQ77qkyZMjJw4ED1Z/3777+6BvPjjz9K2rRpNbNgxowZLRl7586dGrCGDIQonw47zspGu8pKmuyLBEiABEiABCJDgIIbmwQ3derUkT/++EO6d+8ur7/+uuVXF5Fy7dq1k927d2uk3MSJEwV13+3IVmBMHueC8lh169YVRABa2WhYWkmTfZEACVhNwA4nBOb4+eefy5dffikvvPCCprm1et7ozz8r2hNPPCHNmzfXZ1T+/PllzJgxkihRIsuGRqaefv366YLRnDlzLOsXHfFZYSlOdkYCJEACJEACniNgh03mm3Fw6dKlukhjNRj4Bpo1ayaIzK5atar6CZCZFpHXKHOAqOlnnnnGsmGxyPXiiy8K/Aaw9azsG5OkTWbZpWJHJEACJEACJBAxAnbYVfv375dGjRrJQw89pOsIVvqbDFD79u3TNRDYORAYY40EJTlnzpwpCRMmVPsKGXCsarDfYFehffXVV5o9x6q+0Q/tKitpsi8SIAESIAESiAwBCm5sEty88sorqrQeOXKkbUYTMhNA0Y1Iuaeeekprvl+6dEnKly+vi51WZSswbk1E4n399deCc0MteisbDUsrabIvEiABqwnY4YTAHJFZBplt7KoDjTF8s6KlSJFCcubMqWULMmXKpJHUiPyxsm3btk2dHRhr5cqVVnZt2/PU0kmyMxIgARIgARIgAdsI2GGTYdEGv6sRrbxhwwbb5o6SVSi5effuXc0Ya9hJH3zwgbz66quWj1uzZk31EQwYMEB9BFYOwN/vVtJkXyRAAiRAAiQQGQJ22FXos2PHjpImTRpZtmyZbSeGoGAEB6NVqlRJVqxYoduDBw/WzDf6xsL/IIvz5cuXNXCuaNGiFvZMwY2lMNkZCZAACZAACUSIAAU3NgluoHqG+hm1RIsUKWLb5UWkHMqCoEYpmh3ZCozJIzJu+vTpUqFCBRX0GPuteKXDzgqK7IMESMAuAnY4ITDXPn36qFPg7bfflvbt29s1fY1ubtu2rezZs0fHQI1r1J1Gthur2y+//KI1tRHFtG7dOku757PCUpzsjARIgARIgAQ8R8AOm2zXrl3SqlUrSZIkiZYJsBPKwoULZdCgQeYQiADH2HY0ZKb9/fff5cMPP5SXX37Z0iFok1mKk52RAAmQAAmQQEQI2GFXrV+/Xrp06SII8oLdY2dDxmhkjjYa/Grwr9nRqlSpImfPnpVPP/1UnnvuOUuHoF1lKU52RgIkQAIkQAIRIUDBjU2Cm9KlS8utW7d0QRPZZ+xqSBUNwc2BAwd0iMaNG0vLli1tGQ415VG6yo7SJzQsbblk7JQESMAiAnY4ITA1o1Sfnf92GwiMf8PxPleuXDJlyhR55JFHjI8tezXSByPL2k8//WRZv+iIzwpLcbIzEiABEiABEvAcATtsMmSeQcnkVKlSyfLly21lgvJVKD+NEgiwlfD7ukCBAraMCTEP7LIePXpoGSsrB6FNZiVN9kUCJEACJEACkSFgh121Zs0a6dmzp2TPnl3mzp1
r64kdPnxYGjRoIP/++69mKoQAB1md7WhvvPGGnDhxQoXTZcuWtXQI2lWW4mRnJEACJEACJBARAhTc2CS4QVknZJ0ZN26cFC5c2JaLizIhSJMItTgWTW/fvq3jDBw4UMqVK2f5mDgX1JnHuSEttZWNhqWVNNkXCZCA1QTscEJgjighsGrVKqlfv74gA41dzYgwQv8JEiRQQSiic7AAg7raVjZk0WnWrJnWzf7hhx+s7JqCG0tpsjMSIAESIAES8B4BO2yyHTt2SJs2bSRZsmSyevVq26BcuXJFmjdvLkeOHJHkyZML3qPcAkTQGTNmtHzcevXqyaFDhzSjYuXKlS3tn7/fLcXJzkiABEiABEggIgTssKvgB+ratatkzpxZFixYYNt5nT9/XoOQT58+bdpVGHPy5MmSOnVqy8etWrWqnDlzRoYNGyalSpWytH/aVZbiZGckQAIkQAIkEBECFNzYJLiBQ+vixYsycuRI2xYIZ82apf0jMg7pDLdu3SozZszQRU6UskJ5KSsbzgVj4txQBsXKRsPSSprsiwRIwGoCdjghMMd+/frJkiVL5K233pL33nvP6mlrf4hsRuYzRFLXqlVLnn/+eXnnnXc0AggiH4h9rGzbt29X8VCKFClk5cqVVnZt2/PU0kmyMxIgARIgARIgAdsI2GGT7d27VxdsEMTy448/2jJ3ZKaFrYdsOiixgPIHED4jUy0yD44fP16SJk1q6dhvvvmmHDt2TINlEDRjZePvdytpsi8SIAESIAESiAwBO+yqzZs3q88JouJly5bZcmI3b95UvxNsuMcee0zXReDn+vPPP+Xpp5+W0aNH6/qIlYNXqlRJ/v77b+27WLFiVnZNX5elNNkZCZAACZAACUSGAAU3NgluatasKcePH7elXjpuFd9sBZ07dxakNbx7965069ZNP4NRC0U3nHlWNdR+R4ptnNv7779vVbfaDx12luJkZyRAAhYTsMMJgSkOHz5cZs+eLS+99JJ8/PHHFs9aNPqmSZMmcu7cOY3AGTJkiDz88MOaGW3QoEE6ntVZ0RAZ3qtXL1uimfissPwWYYckQAIkQAIk4CkCdthkyDhTu3Zt5QA7BplurGy+mWmTJEmiZacff/xxOXv2rMBOw+tzzz0nsNPix49vydAYE/bl1atXbQkCok1myWViJyRAAiRAAiQQUQJ22FW7d+/WjH6wadatW2eZbWOAwvpH7969NSshyoFi/SNLliy6DgO76vLly2oD9e3b17KMzqgqgDJSKF2FzIRPPvmkMR1LXmlXWYKRnZAACZAACZBARAlQcGOT4KZ79+6ydu1addxBYW1l889W8O6775rd37hxQ1q3bq212p944gmZMGGCZZFycELCGcka8CZubpAACTwgBOxwQgDd0qVLVWhjR21rlDVs1aqVHDx4UHLnzq2R01jkMdqoUaNsyYpm9Fu6dGldODLGs+KVTggrKLIPEiABEiABEvAuATtsMmSfefHFFzUbIDLPFC1a1FJAM2fOlM8++0yMzLTPPvus2T/stBYtWggita3MeIjgHwTKoCG6HAE5VjbaZFbSZF8kQAIkQAIkEBkCdthV8EXBroL4d/r06ZInTx5LTw5ZASF6QWZC2G0FCxY0+9+1a5e0a9dOYNs1bdpUhT/mh2Fs/Prrr9K4cWO15b7//ntJlChRGL3d/1XaVfcz4R4SIAESIAES8BoBCm5sEtx8+eWXavQVKlRIFzmtujFQKzRQtgLf/n0j5UqWLClDhw4NW01+7do1NZYxDs4Ni7dWNhqWVtJkXyRAAlYTsMMJgTn+8ccfUqdOHZ3uqlWrtO60FXNH1A0ynqEsQtq0adUZkSFDhnu6tisrWps2bWTHjh0q9mnUqNE9Y4b7hs+KcAny+yRAAiRAAiTgbQJ22WQov/nzzz/rIk29evUsgxQoM61/5xs2bBBkrcXCVKdOnQSloMJtsCs/+OADyZgxoyxatCjc7u77Pm2y+5BwBwmQAAmQAAl4joBddhX8XPB3ISD59ddft4wLRMTIXIOG14oVK97XN7LzI0s/Gl5ffvll3Q7nP/Pnz9eAMpQB/eqrr8LpKuB3aVcFxMKdJEACJEACJOApAhTc2CS42bZtmzrroHhes2aNlvAI986IKVuBb/+//fabRsoh4w0cdnDchdN27typmXMSJEggUHJbleramBMNS4MEX0mABNxIwC4nBIQxiPxBVDNqTFtVB3rkyJEya9YsrVmN6J98+fIFxGp1VjSIeCpUqCB4XmEOVv/bbnV/AaFwJwmQAAmQAAmQgGsJ2GWTGWU+YZf179/fkvP3zUyLbLHRZb79+uuvZcSIEVr64JNPPtFSoOFMAhl1kFkH5Q+MMqLh9Of/Xdpk/kT4ngRIgARIgAS8R8AuuwpiGIhjqlWrpsFgVpDB2gSy18CP1rx5c81gE1W/EydOlC+++ELXL+BrK1y4cFSHBrW/X79+smTJEqlSpYr07NkzqO+EchDtqlBo8VgSIAESIAEScCcBCm5sEtxcuXJF64XismPRM2fOnGHdATAmu3btKoh+iypbgf8AGzdulPfff18j5d577z1NUe1/TLDvjTTYTz31lNacD/Z7wR5HwzJYUjyOBEggEgTsckLgXIyI6rZt20r9+vXDPj0j8gYdDR48WMqUKRNtn1ZmRfvzzz/NZ82KFSskZcqU0Y4d6od8VoRKjMeTAAmQAAmQQNwiYJdNZkRDZ8mSRWBLhdt8M9M+//zzapM9/PDDUXaL7DYQ2nzzzTeSOHFiLQ0dTlZZI+Mgyk03bNgwynFj+wFtstiS4/dIgARIgARIwD0E7LKr5s6dK8OGDdPgr6lTp4Z9wvA1NWvWTC5fvqwZa/r06aMi5ag6hl3Vu3dvQca/FClSqPgGpdxj25D98NChQ7o2U7169dh2E+X3aFdFiYYfkAAJkAAJkIBnCFBwY5PgBncAMsscO3ZMI9kQ0RZOQ7Qbot4SJkyoJaqiylbgP8bs2bMF0XoPPfSQOvBKlSrlf0hQ7xGNt3nzZkvryvsOTMPSlwa3SYAE3EbALicEztPIRlO8eHEZNWpUWKeOf6chsESmmfbt28vbb78dVH++WdFq1qypYs2gvuh30Lx58/RZY9VilV/3lmfM8e+f70mABEiABEiABNxNwC6bzFc0DHsmW7ZssQaBTH8tWrTQhZk8efLIuHHjJEmSJDH2d+fOHbXBYM+hFOjkyZMlXbp0MX7P/4CrV6/Kq6++Kv/8848g002JEiX8Dwn7PX+/h42QHZAACZAACZBAxAnYZVft2bNHBTIQG3/77bcaPBzbk/3777+1L6yxFCpUSP1myMAfU4MdBL/YL7/8IhDbIOMNxDehNgSpVa1aVQOaIR4Kdk0mlHFoV4VCi8eSAAmQAAmQgDsJUHBjo+AGZTymTJmiRh2EL/HixYvVXYAot6FDh+p3hwwZIqVLlw6pH0TKwWkY20g5GLRGHXmcE4xbqxsNS6uJsj8SIAErCdjlhMAcDUcEtvGsePTRR7EZcjt8+LCm1cUiD9L2IisaxJbBtnCzoiGCCHW6jxw5Ig
0aNBBEVlvd+Kywmij7IwESIAESIAFvEbDLJoMdg0yDiF6uW7eudOjQIVZgkJm2S5cuArsKYhmIZiCeCbZdu3ZN7bnff/9dF3Qg1kGZ6lCaEXSTJk0aWbRokTzyyCOhfD2oY2mTBYWJB5EACZAACZCAqwnYZVfBHkIw16lTpzSrc+PGjWPF4fbt2yqa2bVrl4qhIZoJJZPyxYsXtfTUyZMn5ZlnntGAt1DtIqM8VdasWQWZe2K7vhMdANpV0dHhZyRAAiRAAiTgDQIU3NgouEEaaSx6ItMAMhjExnjyzVYApx+cf6E2RMp17txZNm3aJOnTp1cRUCiRckZ2HaS0nj59ekgLuMHONTZsgu2bx5EACZBAuATsckJgXljggfNh//79Km7s1KlTyNM9f/68OhFOnz6tUcyffvqp1qoOtaM5c+YIvhubrGjbtm3Tetr4LkoxZM6cOdThYzyez4oYEfEAEiABEiABEojTBOy0ySBOGThwoCRLlkwWL16sASuhwkR2WQhekJl2woQJkjdv3lC70MWppk2byoULF6Rs2bIyYMCAoBd34HuoVauWZtpt0qSJZtoJeQJBfIE2WRCQeAgJkAAJkAAJuJyAnXbVV199JaNHj9a1iAULFoTso4Kv7OOPP5Zly5ZJ8uTJZdKkSbEKUENQGOwqiJpfe+016dmzZ9BrGxD8vP7662qTdezYUYPM7LiktKvsoMo+SYAESIAESMBZAhTc2Ci4waXs3r27rF27VlC3HZlmQmnhZivwHQtGJdJao0+kPhw7dmxQDsQbN26oMYrv9+jRQ1Mo+vZr1TYNS6tIsh8SIAE7CNjphMB8ly5dqo4ElBvAAk/SpEmDPo2bN29K27ZtZe/evfLYY48Jom/gjIhti21WNGTU+eGHHzQLG7Kx2dH4rLCDKvskARIgARIgAe8QsNMmg01VpUoVuXLlinTr1k2DZ0IhY5TWxHdik5nWdyzYda1bt5Zbt25p5h3YesE0BOygHDRKOGBxK5TsOsH0bxxDm8wgwVcSIAESIAES8C4BO+0qlIKCXQVbBuLh8uXLhwQK5ZuQ6Q82DUpkFi1aNKTv+x68detWtY+QeQfZmJGVOZi2cuVK6d27twqplyxZEpavLbrxaFdFR4efkQAJkAAJkIA3CFBwY7PgZufOneoow+2A0lBIPxhMsypbge9YyHyAKDdEypUpU0aj92JKgwgn3eDBg7XGKWquhprO2nf86LZpWEZHh5+RAAlEmoCdTgicG2pLoyY0HBLIcGOU8YvpvBHF3KtXL/n+++8lVapUWrYgS5YsMX0t2s9jkxUNaYJr1Kih2XoQwVSsWLFox4jth3xWxJYcv0cCJEACJEACcYOA3TYZFnRmzpwpTzzxhCAyO9jynMgmCxsOtplVEdCw7xD0ghZs8Mt7770nP/30k7z44ovSv39/2y46bTLb0LJjEiABEiABEnCMgN12FWwRBJUVLlxYxTPBntjq1avV14XjkZEGwp1wm5HJEP0EIwBChh0EL+/evVtF2BBj29VoV9lFlv2SAAmQAAmQgHMEKLixWXDjWwv+ueeek2HDhsXotLM6W4Hv7bRv3z4VAGFxt169elr+w/dz3+1Lly5pqkTUO3377be1Zqrv51Zu07C0kib7IgESsJqA3U4IzHfMmDFatg/1qGfNmiVp0qSJ8TQQ7YOoH9SgxveffvrpGL8TzAGhZEXDc65Lly7y448/aoYdzD3Yxalg5uJ7DJ8VvjS4TQIkQAIkQAIPHgG7bbITJ07IG2+8oWBh30BQHFPzzUxbvXp1tYussoWmTZum2WkR3Y0y1dGJmjdu3KiiH8wXNiIWt+xqtMnsIst+SYAESIAESMA5AnbbVQcPHjSzyfTr108qVKgQ48nt2bNH1y5QzgmZaJCRxqqGADEIqhMkSKD2VYECBaLsevny5fLhhx/q5/hOrly5ojw23A9oV4VLkN8nARIgARIggcgToODGZsENLvH27du13Ae2Y1Jl25GtAOP6Nt9IOZS8Qi3SQA1RdDg2bdq0uvibIkWKQIdZso+GpSUY2QkJkIBNBOx2QmDaKF9Qt25dOXv2rGYhGzRoULTCFaMMFb6LutYvvfQSNi1rwWZFQz3tvn376riICi9RooRlc/DviM8KfyJ8TwIkQAIkQAIPFgEnbDIjyw2yu86YMSPaLLXITIsssn/99ZfaQJ9++qnEjx/fsosCYTOiw1HGIFmyZDJp0iQVOPsPgCyJsCMxH5RsQOS2nY02mZ102TcJkAAJkAAJOEPACbsKQhvYMVhXQIAW1hmiaidPnpSmTZsKgn/LlSunNlBM2fmj6ivQfqy7YL1j3bp1GuQ2efJkyZQp032Hwi9Xp04duXr1qu3ZbTA47ar7LgF3kAAJkAAJkIDnCFBw44DgBncFMtvMnTtXkiRJoimqAxlzOG7s2LGCKDarsxWgb982ffp0zYaASLkRI0ZI8eLFfT8W39SNmHupUqXu+dzqNzQsrSbK/kiABKwk4IQTAvPdvHmz1pXG9kcffSSVKlXC5n0NQs4OHToI6k8jxS0WeuxovlnRAmU6O3PmjDohkBEHEeCIBLez8VlhJ132TQIkQAIkQALuJ+CETYaMsw0bNpSjR49KkSJF5PPPP5dAiz04DlHXsJcee+wxFcNAFGN1Q4Q3ylTt2LFDUDr0iy++kNSpU98zDCKwEYmNEqNYzPL//J6DLXhDm8wCiOyCBEiABEiABCJMwAm7CqIViILhP3rhhRdkyJAhAYPLcFyzZs3kyJEjkj9/fl23gPjZ6gb7rVWrVrJ//355/PHHZeLEiZI0aVJzGIidUSYUJTozZ86sGXF8PzcPtHCDdpWFMNkVCZAACZAACUSIAAU3Dglubty4IfXr15fjx4+ruAXpoP2ddlB7Q/WNZke2At97zDdSDkYjIuVy5syphyAqDiruy5cvy2uvvWbWTPX9vtXbNCytJsr+SIAErCTghBPCmC8y2yxcuFCSJ0+uAs306dMbH+nrn3/+qRE/yIjz8ssvS58+fQI6K+75UhhvfLOioWZ1tWrVtDc8R9577z3ZtGmTLv4gxS5EpXY2PivspMu+SYAESIAESMD9BJyyyfbu3auLPrB33n33XalVq9Y9cBAhjey1a9euVXELRDAQw9jVkMEGi1DHjh2TggULyqhRoyRhwoQ63A8//CBdu3bVbdiRZcuWtWsaZr+0yUwU3CABEiABEiABzxJwyq7COBAPo8GHVbly5XuY3blzR/1LW7ZskYwZMwoyz0SXCeeeL8fizblz56Rx48aaYfrZZ5+VTz75xMxQuHjxYs2sg24hui5atGgsRgjtK7SrQuPFo0mABEiABEjAjQQouHFIcIOL//PPP0vLli31PkBd+Pfff99cJHUqW4HvTegbKQfFNoxZZLxp3769HDhwQDJkyKCLvXZE6fnOA9s0LP2J8D0JkICbCDjlhMA5I1tMvXr15NSpU5I7d25BjemUKVMqDiy2IL0uxJuFChXSxRbUnra7+WZFGz58uApH8TpnzhwdGtnZEAFud+Ozwm7C7J8ESIAESIAE3
E3ASZvMyD6L38gDBw6U0qVLm3DGjBkjsI+QmRbHPfXUU+Zndm1AbAM7EIExFStW1GyIe/bs0QWs69eva2ZEZEh0otEmc4IyxyABEiABEiABewk4aVcNHjxYFixYoIJhZNMvVqyYnhzEzfgMgWcI4powYYLkypXL3hP/b++//fabZoxGkLSxTgPBT+fOneXWrVvy5ptvaqYb2yfy3wFoVzlBmWOQAAmQAAmQgL0EKLhxUHCDS4lyUXDIoRnGnOE4Q7YCKLx79+5tCnH0QBv/A2cdIuWQMQHpGiHCgcGJaDlk4SlcuLCNo/+vaxqW/2PBLRIgAfcRcNIJgbPfvXu3ih+R6tYQ3SROnFj3QbyZLVs2LSdgCHHsJgYHyIABAwSRPnCAYMEJZQvQmjdvros/ds8B/fNZ4QRljkECJEACJEAC7iXgpE2GxRZEY+/cuVMDUwzRjW/kMzLUVqhQwTFgu3btknbt2gkiwatUqSJr1qwRiG2QrXb8+PGSIkUKR+ZCm8wRzByEBEiABEiABGwl4KRdBXsFpThRygnrDoboBqUwjUoA2FeyZElbz9m3840bN2pANHxeKJOO6gOw/7BGguw28MM50WhXOUGZY5AACZAACZCAvQQouHFYcIPLifJN+EOD6Ab1SGfMmKHils8++0ycyFagg//ff5ApAWkUIfhBg9H76aefOpIy8f+mwEVUAwRfSYAEXEnASSeEAQCLOyhhYIhuENHco0cPrS2NsgU5cuQwDnXk1TcrmjFgkyZNVHDz0EMPGbtsfaUTwla87JwESIAESIAEXE/AaZsMUc8ooWmIbiCwwW/5w4cPS6tWraRRo0aOM/vuu+80u40xMMQ2WBRKkyaNscv2V9pktiPmACRAAiRAAiRgOwGn7SoE/nbo0MEU3SCwC+WckOEZmWWwTuJ0mzt3rop/jHEhtoEACGXenWq0q5wizXFIgARIgARIwD4CFNxEQHCDy+krunn11Vd14bRatWpm2RD7Lvn9PUNwg0VdZNqJhNgGM6Jhef914R4SIAH3EHDaCWGcua/oBil1a9euLSgB6EQNaWMOxiuEP4MGDTIz2zgttsE8+KwwrgZfSYAESIAESODBJBAJm8xfdIOsNxAi161b17HMtL5Xe8OGDdK9e3edQyTENpgLbTLfK8JtEiABEiABEvAmgUjYVf6iG6xJ/PPPP1KrVi3HISKzzapVq1TI/O+//2pmG6fFNjhp2lWOX3oOSKmIr3wAAEAASURBVAIkQAIkQAKWE6DgJkKCG1xJX9FN1qxZ5YMPPnCshBPGv3v3rsyfP19GjRqlhm2kxDaYCw1LUGAjARJwK4FIOCEMFr6iG/w7jRS8qCUdL1484xDbX1HG6uOPPxYINNEiIbbBuHxWgAIbCZAACZAACTy4BCJlk/mKbkC/VKlSKnpJly6dYxcDGWlHjBghS5cu1TEjJbbB4LTJHLvsHIgESIAESIAEbCMQKbvKV3SDk0N5TmS4capsOsa8ePGiDBkyRNauXYu3ERPbYGzaVaDARgIkQAIkQALeJkDBTQQFN7h14CxD+aZr167pnVSnTh1p2bKllpmy89Y6efKk9O/fX7Zv367DwFnXp08fyZcvn53DRtk3Dcso0fADEiABFxCIlBPCOPWDBw9qxA3KF6A988wz0qtXL8mSJYtxiC2viDIaP368zJw5U/tPkiSJZkR77bXXIhLRzWeFLZeZnZIACZAACZCAZwhE0iZDtj+UbkLpATSUGsDi0EsvvWS7XbR582b9/X727FkdG2O+//77ji5M6cD/9x/aZL40uE0CJEACJEAC3iQQSbsKayHDhw+XJUuWKDyUxuzWrZuULl3adpjr1q3TDM6XLl3SsapWrSrIYJg0aVLbxw40AO2qQFS4jwRIgARIgAS8RYCCmwgLbnC7nDlzRp1nhpGLBVSUDUGpKasNvRMnTsg333wjCxYsEETpPfTQQ1KvXj1p3ry5JEiQIGJ3Lw3LiKHnwCRAAkEQMP59DuJQ2w65deuWfPHFFzJ9+nRB2ttEiRJJ9erVtcZ1tmzZLB0Xjo9ly5bJ119/LXhuoJUoUUJ69OghmTJlsnSsUDrjsyIUWjyWBEiABEiABOIeATfYZAhaQea/06dPK+BixYppGYTnnntOHn74Ycugw97bsWOHzJkzR3744QftF5HfXbt2lfLly1s2Tmw6ok0WG2r8DgmQAAmQAAm4i4Ab7CqUyhwwYIBcuHBB4cCeQnkp+KCwbmFVQ6b/LVu2yOzZs2XTpk3aLTIVws+FMSPZaFdFkj7HJgESIAESIAFrCFBw4wLBDS4lnGmLFi0S1AmFEAYtceLE8sorr0jNmjUFGWhi2wyDEpF4GzduNLvJnj27ZrV56qmnzH2R2qBhGSnyHJcESCAYAm5wQhjz3Lt3r/Tt21eOHj1q7FLnAMpM4d/ScEpNHTlyRObNm6dim+vXr2v/EPZ06NBBxT1WOjvMyYewwWdFCLB4KAmQAAmQAAnEQQJusckgTkZp5oULF5qUM2fOrL/dkQkwnJIIsMGWL1+umXT++OMPs/8yZcqo2AYR4JFutMkifQU4PgmQAAmQAAmET8AtdtXff/8tw4YNk5UrV5onlSNHDrWrwg1Ivnr1qmbRQQDysWPHzP4rV66sGZxTpEhh7ovUBu2qSJHnuCRAAiRAAiRgHQEKblwiuDEuKVJEz58/Xx13qCVqtDx58mgt0fz588uTTz6pApz48eMbH9/zCucfyo/8+uuv+vfLL7/IX3/9ZR6TK1cuwcJspUqVbC9dZQ4awwYNyxgA8WMSIIGIEnCLE8KAgFJPK1as0IjnQ4cOGbslQ4YMUrBgQX1O4FmRN2/eKDOl3blzR7CIYzwr9u3bJwcOHDD7wkJRtWrVNIMO+nVD47PCDVeBcyABEiABEiCByBFwm00GITQy0KxevVr+/fdfBYPMsU8//bT+foc9hj9kCIxKuHzu3DmBHbZ//359xe93Q/iM7zz//PP6+7148eJR9uH0FaFN5jRxjkcCJEACJEAC1hNwm121c+dOM7MfAojREARm+LmMdZH06dMHtIkQ0IxKAvBzGbbVzz//LPChoSETYbly5dSuKlSokO5zw39oV7nhKnAOJEACJEACJBAeAQpuXCa4MS4nSoesWbNGo9pgIPo3OPFQQgSG4pUrVzRDTurUqQViG1+1tvE9ZDzwNSijcvYZxzv9SsPSaeIcjwRIIBQCbnNCGHOHMwHOA2SlWbt2rbnQY3yOV2QzQ3lCRAzh+GTJkgkcFygVZTgdfI/HotBbb72lpQoSJkzo+1HEt/msiPgl4ARIgARIgARIIKIE3GqTnT9/XoNmULoZAhr/lipVKsHi0O3bt+XmzZvyyCOPaEZbfA9//g3R1lWrVpUaNWoISk67rdEmc9sV4XxIgARIgARIIHQCbrWrIJoxApIvXbp034kh2x/KQcGnhT+sk8B/BZvKKE3l+yUcj4Ay/LkloMx3frSrfGlwmwRIgARIgAS8SYCCG5cKbnxvJ2Sr2bVrl5mFACU/YmowNJEVx4ioQ115NxqUxnnQsDRI
8JUESMCNBNzqhPBlhQxp27ZtMyN58OyAeDOm9thjj5nPCkT4ICuOWxufFW69MpwXCZAACZAACThDwO02GTIIwh5D0IyRRTCQAMefFjILGlHb+fLlkxIlSrgmG63/XPGeNlkgKtxHAiRAAiRAAt4i4Ha7Cj6trVu33mNX+VYEiIo2BDbGmghekSUQayVubbSr3HplOC8SIAESIAESCJ4ABTceENz4X06jZNTp06fl6NGjMnXqVEEGm+7du6uaO2fOnNGWnPLvzw3vaVi64SpwDiRAAlERcLsTItC8jZJRKBuFiJ/hw4dreYJatWoJFnIyZswYbcmpQH1Geh+fFZG+AhyfBEiABEiABCJLwIs2GUTRKNt5+fJlzWK7ceNGyZ07t9SuXVszDyJQJrqSU5ElHnh02mSBuXAvCZAACZAACXiJgNfsKqNkFOwqZPxfunSp7NixQ0tOvf7664IMgQgii6rklFuvDe0qt14ZzosESIAESIAEgidAwY0HBTe+lxd13hs1aqSlpeC482qjYenVK8d5k8CDQcBrTohAV+WVV17R1LojR470bFQynxWBriz3kQAJkAAJkMCDQ8DrNtnEiRPliy++kLJly8qgQYM8e+Fok3n20nHiJEACJEACJGAS8Lpd9emnn8qcOXO0DGePHj3M8/LaBu0qr10xzpcESIAESIAE7idAwQ0FN/ffFRHYQ8MyAtA5JAmQQNAEvO6EwIlScBP05eaBJEACJEACJEACLiXgdZuMghuX3licFgmQAAmQAAk8gAS8bldRcPMA3rQ8ZRIgARIgARJwKQEKbii4ccWtScGNKy4DJ0ECJBAFAa87IXBaFNxEcXG5mwRIgARIgARIwDMEvG6TUXDjmVuNEyUBEiABEiCBOE/A63YVBTdx/hblCZIACZAACZCAZwhQcEPBjStuVgpuXHEZOAkSIIEoCHjdCYHTouAmiovL3SRAAiRAAiRAAp4h4HWbjIIbz9xqnCgJkAAJkAAJxHkCXrerKLiJ87coT5AESIAESIAEPEOAghsKblxxs1Jw44rLwEmQAAlEQcDrTgicFgU3UVxc7iYBEiABEiABEvAMAa/bZBTceOZW40RJgARIgARIIM4T8LpdRcFNnL9FeYIkQAIkQAIk4BkCFNxQcOOKm5WCG1dcBk6CBEggCgJed0LgtCi4ieLicjcJkAAJkAAJkIBnCHjdJqPgxjO3GidKAiRAAiRAAnGegNftKgpu4vwtyhMkARIgARIgAc8QoOCGghtX3KwU3LjiMnASJEACURDwuhMCp0XBTRQXl7tJgARIgARIgAQ8Q8DrNhkFN5651ThREiABEiABEojzBLxuV1FwE+dvUZ4gCZAACZAACXiGAAU3FNy44mal4MYVl4GTIAESiIKA150QOC0KbqK4uNxNAiRAAiRAAiTgGQJet8kouPHMrcaJkgAJkAAJ/D/27gNuiuL+4/iIIEV670gRiSV2YxR7RRFbjDV2xIq9YReJUSOWRE1saDRiNGpAYyMSC1Y0dlRCr0ov0tH733fy3/Oee26f58re7c7eZ14vvLt9bndn3rvezc38ZgaB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAE
EEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwQIuCHgJhI3KwE3kbgMZAIBBHwEXG+EULEIuPG5uGxGAAEEEEAAAWcEXK+TEXDjzK1GRhFAAAEEEIi9gOv1KgJuYn+LUkAEEEAAAQScESDghoCbSNysBNxE4jKQCQQQ8BFwvRFCxSLgxufishkBBBBAAAEEnBFwvU5GwI0ztxoZRQABBBBAIPYCrterCLiJ/S1KARFAAAEEEHBGgIAbAm4icbMScBOJy0AmEEDAR8D1RggVi4Abn4vLZgQQQAABBBBwRsD1OhkBN87camQUAQQQQACB2Au4Xq8i4Cb2tygFRAABBBBAwBkBAm4IuInEzUrATSQuA5lAAAEfAdcbIVQsAm58Li6bEUAAAQQQQMAZAdfrZATcOHOrkVEEEEAAAQRiL+B6vYqAm9jfohQQAQQQQAABZwRCCbhZu3atWb9+fSSQxo8fH4l8FJqJb775xgwaNMhsuOGG5rXXXiv0MKHvt+OOO4aeh/QMyLN+/frpm3iOAAJlFvjhhx/MmjVrynzW7Kdz/btCpTriiCPMokWLzG233Wai9pmbXb361ijmu1GjRtUzyhYEEEAAAQRiJLBq1SqTSCQiUSLX62QjRowwjz76qNltt93M0KFDI2FaSCaiVifbaKONTN26dQspCvsggAACCCBQVgHqVcFx33333ebZZ581Bx98sLn00kuDO3CZjxS1epX6RNQ3QkIAAQQQQACB3AVCCbhRsM1///tfs2zZstxzyjuzCnz99dfm5JNPtpWgt99+O+t72JifwMYbb2w23XRTAm7yY+PdCJREYObMmWbOnDklOXalHTQOM9xE6ZqpU6dHjx6mRYsWUcoWeUEAAQQQQCBwgeXLl5tJkyYZDZwhFScQlxluilMIdu82bdqY7t27mw022CDYA3M0BBBAAAEESiCwdOlSW6+KymDkEhSxbIeMyww3ZQPL4UQdOnQwXbp0oV6VgxVvQQABBBBAIF0glIAbLwN0pHoShT8ScFO4XbY91Vi3ySabmDp16mT7M9sQQCAEgcWLF5vJkycbzXhDKlyAgJvC7TL31Kw2vXv3JjAzE4bXCCCAAAKxFVi3bp0dNKPgG1LhAgTcFG6Xuad+s+u3u37DkxBAAAEEEHBJQEHMEydONCtWrHAp25HLKwE3wV0SzWijQWUtW7YM7qAcCQEEEEAAgQoSCDXgRs5a3mLKlCl0pBZ40xFwUyBcxm5qrOvWrZtp27Ztxl94iQACURBYvXq17eRZuXJlFLLjZB4IuAnmsrVu3dqOoiYwMxhPjoIAAggg4I6Alp
XSoJm5c+e6k+mI5ZSAm2AuiJY60Ky0mp2WhAACCCCAgIsCqldNmzbNzJs3z8XsRyLPBNwEcxkaNGhgB5U1bNgwmANyFAQQQAABBCpQIPSAG5mrI1VR3VrDlJSfAAE3+Xlle7fWe1djXePGjbP9mW0IIBARgR9//NEGaC5cuDAiOXIrGwTcFHe9tEyBAjPbtWtX3IHYGwEEEEAAAccFGDRT+AUk4KZwO2/PZs2amV69ehkt70lCAAEEEEDAdYH58+fbwBu1eZHyEyDgJj+vbO/WjDaa2UYz3JAQQAABBBBAoHCBSATcKPt0pBZ2EQm4KczN24vGOk+CRwTcEfjuu+/M9OnTjUYDkXIXIOAmd6vMdxKYmSnCawQQQACBShfQYJn//ve/DJrJ80Yg4CZPsIy3d+rUyXTu3DljKy8RQAABBBBwW0BLS6letWbNGrcLUubcE3BTOLgGlXXp0sV06NCh8IOwJwIIIIAAAgikBCITcOPl6NtvvzUzZsygI9UDqeWRgJtagGr4c8eOHW3Fsoa38CcEEIiowPLly82kSZOM1r0m5SZAwE1uTpnvatq0qR1FXa9evcw/8RoBBBBAAIGKFvjhhx/s7IOa8YaUmwABN7k5Zb5Ls9lo9HWLFi0y/8RrBBBAAAEEYiGwfv162861dOnSWJSnHIUg4KYwZbVvabZAtXeREEAAAQQQQCAYgcgF3KhY6khVVPe6deuCKWUMjqJlt6ZMmWIeeugh28h01VVXGUUiZwu40WxBashr1aqVjYxXA2jPnj3NgQceaOrUqRMDjeKKoCkS5UFjXXGO7I1A2AL6jtB3hb4zSFUF3n//ffPmm2+aCy+8MDXdfraAmw8//NCMGTPGXHnllVUPwCsroJE+GvGj71sSAggggAACCGQXmDt3rpk5cyaDZjJ4VFe97LLL7D9v9HC2gJvvv//e/O53vzNnnHGG6dq1a8ZReNmoUSPTu3dvU79+fTAQQAABBBCIvcCsWbPM7NmzY1/OfAuo+tJ1111nhgwZYvs8tH+2gJslS5aYm2++2Vx88cWmbdu2+Z4m9u9v3Lix2XTTTY1mciYhgAACCCCAQHACkQy4UfEqvSN1woQJZty4cWby5Mn2nyrb6al169bmhBNOMJqlRY14CqR55JFHzNixY82LL75o5s2bZ9fe1KhDLx199NG289V7XYmPDRs2tI11DRo0qMTiU2YEYiegZaXUwaOOHtJPAs8//7wZNmyYHa2y+eabm6OOOsoMHTrUqOHh2muvtcs/6Lviyy+/tMEk7777rt154sSJ5j//+Y/94b333nub5s2b/3TQCnqmwEyNotZa1iQEEEAAAQQQqF1g2bJldlQ2g2Z+stJI9b59+9oNeuzXr5/54osvzMiRI80uu+xifv3rX5uXX37Z/obXrI3PPvus/X3/0xF4pnaP7t27M3CIWwEBBBBAoKIE1HajPgHVJUj/E9Bg5D333NO+2GKLLWwdS22Bo0ePNgcccIDZd999zSuvvGJee+01GwSu582aNTOLFy+2A9K0HOo222xj+vTpU7Gk7dq1M926dWNQWcXeARQcAQQQQKCUApENuFGh1ZGq5aW0zFSlJTW8ffLJJ6nAI3WCZiYvEtlbUkXBJHqeHmSTvo9G6T/99NPpmyrquWb8UQcqs/xU1GWnsBUioJm8NAuY3+df3BnUCPPWW2+lgjS13JYCkbykGVr0naqkqWMVUKLGCgUfasa0X/ziF+a3v/2tef311+179J82bdqYESNGGHV0VFLSd6lG++iRhAACCCCAAAK5C+i3qGYf1Ahk0v8Ebr/9dvPMM88YzUKrmVpU/9Jz1c1U11i5cqV9ozp/NICG9D8B+ahDSB1DJAQQQAABBCpRYM2aNUb9AV5doRINNCD5s88+s0HdCkD65ptvbD3Ks/DautTWr/Ytz2qnnXYyd911lw1y/vOf/2xXANA+er8Gp2mAWSUl+SiAudLa9yrpGlNWBBBAAIHwBSIdcOPxLFy40HakqmGqUpMqlJdccomZP39+TgSqQKoyld75rFF0mmqx0pIsNDV3+/btK63olBeBihLQaBV18uix0tKKFSvM1VdfbT/3FXyU2Qjh56Gp+RVUc/nll1cJ0PHer1lx9ttvP+9l7B81o40CMxWQREIAAQQQQACB/AUU4Dt9+nTz3Xff5b9zDPdQ3fTUU0+tcblsBd5o2YP+/ftbAdm9/fbbtkNJHUaVtsyUBhYp+FlLHpAQQAABBBCoZAH1hUydOtUsWLCgIhnuueceW34FLH/11VdGbV+1JQU4X3PNNXZWm5deeqna2zVLjpbyrJSkQCTVq+RCQgABBBBAAIHSCTgRcKPiqwNVUd2qYFVqev/99835559fa/HVUXjDDTfYqRXfe+89O1NOixYt7JIilVa50kwOqlQ2adKkVjfegAAC7gsoyFAz3SjopJKTlnVQA4O+N/ySghEViPn1118bBbZmS/fdd5/Zdttts/0pVttkoVngOnToEKtyURgEEEAAAQTCElDHkDqIKm3QjJbmVIeQtzS0DLwZaf2uhQJM1CGkQGiNwn7iiSdSA2f0t3vvvddsueWWfrvHanvTpk1Nr1697IyMsSoYhUEAAQQQQKAIAQXjKqDZm7m4iEM5u6va+7R8em3BMgow2W677cw777yTtaxHHnmkufTSS7P+LW4btUy86lUMKovblaU8CCCAAAJRFHAm4EZ4qlip4Uprb1ZqOv300+267zWVXxXLMWPGVHwjlYJsFGyjoBsSAghUloDWcdaSSpXcGKHOnb322ivVYZN5B+i7QsGYS5cutctK7bzzzjb45o033jBaL7xv375m//33z9wtdq/1HaEGCHXwkBBAAAEEEEAgOAFN668ZXipp0MyNN95o61Eaga2gZi0HUVvafffdzXXXXWeuvPJK88EHH1R7+3HHHWcGDx5cbXvcNijwWQHQCoQmIYAAAggggEBVAS3ZqXpVbYG8VfeK36uBAweazz//PGvBVIfQrMXqP/r1r39tTjvtNDsg780337QDubVc5YknnmiXn8p6gJhslEOnTp3sv5gUiWIggAACCCAQeQGnAm48zTlz5phZs2ZVZEfqY489ZjSdYk1pwIABZsiQITW9JfZ/0/JRmnqbxrrYX2oKiICvgGZ5mTRpUo1T+PvuHJM/aJYbBWBmSxoxrVEumsWmT58+2d4S+20KzFSwjSxICCCAAAIIIBC8gAbNqD6mYN5KS+oQ0+/3Bx54wLfomoH2+uuvN4888oiZMGFC1vedd9555vjjj8/6tzhsVH1UnWNa2pOEAAIIIIAAAv4C69ats/UqtXdValK98oQTTshafLXtqP41aNAgc8opp2R9T9w31q1b17ZzNWvWLO5FpXwIIIAAAghESsDJgBsJakS+opVV0ayk9M0335iTTjrJt8ha/10BOZtvvrnve+L8BzXWde/e3bRq1SrOxaRsCCCQo4B+aGsEkEYCVWKaMWOGHdWTWfY6derYJR7UwXPggQdm/rkiXhOYWRGXmUIigAACCEREYPbs2Ub/KnH2wWuvvda8+uqrW
a/ExhtvbH7xi1+YsWPHmjPPPNMcccQR1un111+3szXqd/0xxxxj1HkSx6T2C81Kq0cSAggggAACCOQmoBmdNSC5UtM+++xjNJtgZlJb16677mpuvfXWihyEq3ql6lVappSEAAIIIIAAAuUVcDbgRkyV2JH6448/mgMOOMAsX748653Spk0bM3r06IIqlYqOV8e0gno0g9Ahhxxifvazn2U9T9gbNS251q9VeTUqUEnLo/Tu3ZvGurAvDudHIGIC6tjRWtf6zKjE5NcQseeee9a69nUcvdQAo8DM1q1bx7F4lAkBBBBAAIHICmjQjEYlr1+/PrJ5LEXGtNTp4YcfXu3QGiyy3XbbmfHjx9slpfr161ftPXHeoBltNLONHEgIIIAAAgggkJ/AokWLzJQpU3yXEc/vaG69W4PHXn755SqZVluPAnifffZZU4mzu6iPRG1dzPZf5bbgBQIIIIAAAmUTcDrgRkrqSJ02bZqZN29e2dDCPtHll19u3njjjWrZqFevnjn55JPt+qTV/uizQW4PP/yw+eqrr2zjp6b8VtLa6XfeeWck1vpUw6yCgPRPAUGa2UiR/M2bN7d59KafprHO5yKzGQEErMCCBQvM1KlT7cwulUSiJQY1ajo9qWNDS015AYvpf8v1uQJAFcT09ddf289mdZ6dffbZue5e1vfpu0151RJSu+yyS1HlLmvGORkCCCCAAAIxE1izZo2tN2QblRyzolYpjmap1e/Z9OQt76mBLhdddFH6n2L9XB1Bam/o0KFDrMtJ4RBAAAEEECi1gAakTpw40axatarUp4rU8VWnyrYCwA033GAHKheTWZcGJC9evNhe+5133pl6VTEXnX0RQAABBBAIQMD5gBvPYP78+TbwRh2AcU9PPfWUGT58eLViKuDm6aefNlomI9+kjlIF3mjteBlq5oMbb7zRqBEwzLRy5Urz3HPPGTXM6gfEW2+9lYrc32mnncwtt9xiZ7Xp2LFjmNnk3Agg4IiAPlMUuKdGiUpJH330kTnnnHOqFHfAgAFGgTiFJI3C/vvf/247jRTE4i0NoaWpNMoo7KTvC43y8oI0NZJer9UAdffdd5uzzjor7CxyfgQQQAABBCpaQHUHBUHrN3ylpPvuu888+uijVYqr6f67detmHnzwwdB/d1fJWAlfqM2iV69epmnTpiU8C4dGAAEEEECgcgTUjq82j4ULF1ZMoVWX3GOPPezqB16hNcBKA8sKSVEfkKxrrBkTNWBabZregGTVpQ866CDbH6SZ/0kIIIAAAgggEJ5AbAJuRKhRcqpwqLMtzkkzvBx//PHVitinTx8bMFPtDzluUCCPZrXRiLPzzjvPrhWf465ledu3335rLrjgAhtYpTyeccYZdjkUzXRDQgABBHIV0GwnCsJYsmRJrrs4/T4FVKohwpvBTIVRcKZGFheT1JijoJ1PP/3ULgVw2WWXmUMPPbSYQwayr2aAU7CN6gQvvfSS0egkJQWjvvDCC2b77bcP5DwcBAEEEEAAAQSKE1Dnhpb9rIRBMx988IEZPHhwFTAtffDXv/7VTv9f5Q95vJCdKzMOqiNMSzwo8H2TTTZhKak8rjNvRQABBBBAoDYBtZvPmDEjNSiqtve7/vdzzz3XfPjhh6liaGnO6667LvW6kCdRHZCs1R1effVVW3f697//bds0VT7VJW+77baKmimxkOvKPggggAACCJRDIFYBNwJTxUgBKXHuSFUZ99prL7Nu3brUPaKZaNTZ2b9//9S2fJ7omDfddJNd/1QBLKqsbbXVVvkcouTvVSS3Gim1nFTjxo2NAoQqbZ37kiNzAgQqSGD27NlG/7wZWuJc9FNPPdVMmDAhVcRbb73V7L777qnXhTxRZ4kCbt555x3TqlUrc8cdd9gZxwo5Vin2UYCRyjlq1Ch7eJX3H//4h2nRokUpTscxEUAAAQQQQKAAgUoZNKN607777mvbKzymI4880lx66aXey7weoz7jYOZIbHUAKuB9zpw5RktojRw50jASO69LzpsRQAABBBCoVWD58uV2MHJ6n0GtOzn6hueff94MGzYslft99tmnyuvUH/J8EvUByVoyXqsSqG6ptji1c/Xt2zfPUvJ2BBBAAAEEEAhaIHYBNx7QrFmzbEeq9zpOjy+++KKtWKWXacMNN7SRzhtvvHH65pyfL1iwwM4eo0awLbbYwi5ZpdFn+SRFlWt0nZajKjQfNZ1Px7/iiivM999/b4OBXn75ZcNSUjWJ8TcEEKhNYOnSpbbxX0GHcU5aqkD/vHTcccdVG2Xt/S3XR3WYnH/++TYIcrvttrPBLQqGjEpSJ87tt99uHn/8cTtz2yWXXGJuvvlmOwIoKnkkHwgggAACCCDwv0Ez+h2qelmc05lnnmk++eSTVBE1u+zOO++cel3Ik6jOOOiNxK5bt655++23beC3gtzVbqGA6IsuuqiQ4rIPAggggAACCNQioGAbrQCg4Js4J82SePTRR6eK2K5du9SAq9TGPJ9EfUCy6lWfffaZOfvss23AzY477mhncm7btm2eJeXtCCCAAAIIIBC0QGwDbgSlWW40202cOlIVvXzsscfaZbMWLVqUuh/Uyfmvf/0r9TrfJxohd+WVV9pglqOOOspceOGFeXVKagSbZsVRUMxdd91VkkCYv/zlL+ZPf/qTnY3ixBNPtJ3HqmiSEEAAgWIEtAyhGiM0wjqOSZ/PJ5xwgl3T2yufAisfeugh72VBj+PGjTPXXHON/ZGvz+RBgwbl9b2h2YXee+89owaCrl27FpQHv53U2KDAzwEDBpj333/fPn/iiSfsa7992I4AAggggAAC4QrEedCMZB944IEq9S/NQKhlkotJUZ5xUMHPm266qV3i8ze/+Y1ZtWoVI7GLudjsiwACCCCAQI4CCnLV7HJaZiquSW1aqlulJ832ouXEC03FDkheuXKl0TLnmlm52KDqzDI0atTI9OzZ0wwdOtQGL+san3XWWeYPf/hDXm1xmcflNQIIIIAAAggEIxDrgBsRxa0j9b777rORy1rG4+KLL07dBapwaf33QpMqqZr9QEtTab3TvffeO69DlTrgRsdXhVKz2ijIRoE3aqAkIYAAAkEI6Ifq1KlTzfz584M4XKSO8corr9hpdbXEkv4paXSxgjQbNmxYUF71mazP4ccee8wuB6AlCXfddde8jvXmm2/az3Uth7jffvvlta/fm7V+9SabbGLatGljg3kUcKNr2rt3b/v90b17d79d2Y4AAggggAACERBYvHixDRKO06AZj1XLcKbP7KIZAu+9917vzwU9RnXGQXU0qY1Cdc6nn37aKDhbwUGMxC7oMrMTAggggAACBQloJrwpU6YYteHEKam++Ktf/coWKX3w3PXXX28OPPDAgota7IDkefPm2UHMaoO6+uqrbT2o4Myk7di6dWuj9izNWnT44Yeb119/3dSvX9+MGDHCHHPMMWnv5CkCCCCAAAIIhCUQ+4AbwapSqSmNXe9IVQVZI8OuvfZas8ceexitTep1nqqTUZ2ninbONyn6Wks1ffDBB3ZmGs1Q06VLl7wO
I+NSzXDTvHlzew0PPvhg2yGuZaS0rNbWW2+dVx55MwIIIFCbgH4ca1rauDRG6Me4ptg98sgjzbvvvmu++OKLFMEf//hHs8MOO6Re5/NEx1WgzMcff2y6detmtCRChw4d8jmECTrgRo0NGkXtLWmo8mm2Nn1PqvxaWkrvISGAAAIIIIBAtAU0aGbixIlGv1PjlBSg/MILL6SKpMEur732mqlXr15qW75Pip1xMOiR2BtssIHp3LlzasZb1amvuuoqRmLne2F5PwIIIIAAAgEJaIY51asU+BqXdMMNN9iZqvv162dnePHKddhhh9k+Du91vo/FDkgOOuBG9Sq1uWm5LCUtJ3XQQQcZzRitwWYvvfSS6dOnT77F5P0IIIAAAgggUAKBigi48dxc7khVh6HWfFdgjWYVUIVLM7xMmDDBK565++67zU477ZR6nesTLbt1wQUX2ICk3XbbzaghMN9OyVIF3KixrlOnTuaZZ56xwUb6cbDnnnua5557zjRr1izXIvI+BBBAIGcBjY7RElPq7HE93XLLLXaml5EjR5r777/faFklL51++ulG/wpJX331lZ1lTUsbanYaBYLm21kUZMCNvg969eplZ0BTedauXWtHUj/11FN2RNHw4cPNeeedV0hR2QcBBBBAAAEEQhDQ70vNPqip/eOQNADouOOOszMDpgcSaSmErbbaqqAiyqjYGQeD7BjSTLSqj6X/Tl+6dCkjsQu6uuyEAAIIIIBAcALqV1D7v2aGcT19+OGH5txzzzVaBUD9F+kz4CsI5cknnyyoiEEMSA6yXqXAbA0qa9y4cao8jz76qBk4cKDRTJAKNtIsgoUMvk4dkCcIIIAAAgggEJhARQXcSO3777+3HanqjHMpabkoNcZp+Q5v9hnNKJBeicy2BrxmxXnvvfds455fef/5z3+am2++2c7ocM4555jjjz/e762+24MOuFFjnaag9ma3ueSSS4xm3lG68sorzY033sj6pL5Xgz8ggECxAvrxOmnSJKNOAlfTp59+agYNGmQUdKNZ0bSO9OWXX54qjma30SwwhaRnn33W3H777XZXLY2gGWTyTUEF3CgoU8GZ6UlrlWsaYQUGaXmp0aNHB75+dvr5eI4AAggggAACpRH47rvv7OyDWv7T1aS8n3322XYpT9Vb1DniJf3+1iy2haQgZhwMqmNIMwxq+QR1DqUnRmKna/AcAQQQQACBcAW0FOWsWbOMq/UqDcRVvUmzugwdOtQGnmgFgPQBc1pWPT34N1fxIAYkB1Wvatq0qQ1iTh/Y5g3G1iw8GpCt8quPhIQAAggggAAC0RCouIAbsbvWkaqRfVrzXMt3HHLIIak7Z+zYsWbIkCGp19tuu62N7vY2qJxnnXWW6du3rznppJO8zVUeVVlTsI2mtlZl7tZbbzXbbLNNlffk8iLIgBs11imC25tlR9H3hx56qHnrrbdMkyZNbJCRpk8kIYAAAqUWUEOEpmp1LWnK4PRGCM2Kptlo0j879RmrZQwU4JieRo0aZX75y1+atm3bpm9OPV+3bp3R9L1axrBly5Y28OZnP/tZ6u+5Pik24CY9MDPznJpWV+tYL1u2zJbl+eefN61atcp8G68RQAABBBBAwAEBVwfNeLSqhygAWjMNfvPNN+aaa67x/mR23XXXVBBzamOOT4KYcTCIjiHVGTWiXPXNzMRI7EwRXiOAAAIIIBCugAaWaYCZ+g1cS3fccYcZM2aMXTJc7VFK6vvQcudeUt/G7rvv7r20j5988omdCbmmlQGCGJAcRL1Ky7VrsHVmverbb781/fv3Nx999JFp0aKF0UA4rQJAQgABBBBAAIFoCFRkwI1HP3PmTKPI7qgkTTOtSu/mm2+eWppDo9ZUcezatasZNmxYlcrWwoULzcEHH5zKvjof1QHaoEEDG6mu2Qc+//xzu4yIF7ySevP/P5k/f7658MILbUVb0eHap5BOyaACbjQTgRrrFKntJVUkVaFUxVKdui+//LL18P7OIwIIIFBKAQX9abYwlxoj9FmugJbHH3/cBip6PkcddZTRd5+XHnzwQbPlllt6L40aIS699FLzt7/9zQbTpP6Q9kTfm1qGULPIbL311ua2226zAZtpb8npaTEBN5oyV6Oos3236ftIs6BptI+SlpLSklLp3ys5ZZA3IYAAAggggEBkBBTwq84hBdO6lPSb/dhjjzWHHXaYneVGHTEDBgxIFUEDSjQSu5B6ShAzDhbTMaQ867e7fsNnS4zEzqbCNgQQQAABBMIX0Mz/WkpdQc2upP/85z9GMwOqfUeDxLykpaUU4OslDb5Sm5WXtFTUySefbJee0kzI2VJQA5KLqVdtuOGGpkePHr5tcRqIrPqkBtNtt912RgFC7du3z1YctiGAAAIIIIBACAIVHXAjb3WkaspAVazCTBqNrxkDlBT4olln1KF4/vnn24qUOk01A01mUiVSgTpeUsVTHapabkrLhzzyyCOmXbt23p+rPb7//vt2lpwVK1aYI444wlx88cVGFbx8U7EBN2qs69atW9YZFbSUlsqlzm6VV2XKnKo63/zyfgQQQCAfAU1PO3HiRKMf6lFJ+kxU8IuCJDUzmJf0I1xLR917773VZixTEIp+lHvpzDPPtA0Peq0ATM2GdsYZZ9gf8d57Mh91/GuvvdZoKl91IGnt7EI6iQoNuGndurXp3r277znVYKQlrjTqSQGoWoqxkCWvMsvNawQQQAABBBAIV0DLHyhweO7cueFmJMezK7+qkylIWb9hVS9R0uytWirLS/qt36tXL++lfRw3bpwNZNlss82qbPdeBDXjYKEdQwp61qy06XVQL2/eIyOxPQkeEUAAAQQQiJ6A6inqU1BdIErJy0/6rMsKYD799NPtzDUaOJye3n77bduf4W1TnUp1KyWV8eqrr7YB23fddZdvO1JQA5ILrVc1bNjQ1qv06JcUaKTVD9QHc+qpp5o///nPBfXh+B2f7QgggAACCCBQnEDFB9yIT52G6kjVEhxhJVUav/jii9Tp1YCloBJ18t59993VOk29Nz788MN2Bhvvtfeo2W7++Mc/+u6n96mCpnU/dQytCaoK6P777+8dIvWoPGhWmdoaNhW8o87fAw44wAYLpQ6Q8UTn2nfffW2Ajf5UU2OdOpRPO+0022GqMslCHcQkBBBAoNwC+szUEn8LFiwo96mrnU+jkTRiRyN89IP8qquusp+ramhQx46Wkxo0aFC1/V588UU7+4v3B01Tq4AUfb7rh/v2229v98+cutZ7vwzuueceuySCPrs1k0zmVL3ee2t7zDfgRnlSYGZNQaQ6p5ZX0KgldW717NnTfn9ldmLVljf+jgACCCCAAALRFdDIXs0+GPagmdqENKhGs9SOGDHCdqJ479eSUgoM9pIGvWjQjJfU6aO63O9+9zvf3/NBzThYSMdQs2bNbICQfp/XlBiJXZMOf0MAAQQQQCAaAqp3KPBG7T1hp5EjRxoFxigdfvjhdlZ+zW6ogV6qd6gfQ21R6UmDrtQXkV4vvP/++23dS0t6fvbZZ7YupvqLXwpqQHIh9So
tjaWZbWoaAK2+GQ14+8c//mH7cDRbteqKJAQQQAABBBCIjgABN/9/LVSpVKOdIqbDSFrCQw1S6Ukz3GhZkG233TZ9c5Xn6iRV41z6UieaCUcVypr200E0q80VV1xhxo8fb6cg1Kw4mhI6M2lGB1VU5VNTmj59uo0Y1ww9CqrxS6oYaxYFLWNSW2OdGhIPOugg8+mnn9o8Pv/882aHHXbwOzTbEUAAgZILaESyPu80UiaspMYAzYCWnrTk3tdff21+/vOf26CYbJ0gaojo16+f0ahoL+k7Q98H+kxWoGZNn99q6ND3lT6TFayjhpCOHTt6h0o9ainADz/8sEaj2bNn22Wvdtlll1QAZuoAaU8UaLPNNtvYETyNGzdO+0v2p08++aQ55ZRTbMCqlmzQa29EefY92IoAAggggAACrglosIyWQghz0Ixnpt/kCoBu3ry5t8nWFTVwRANrNEtrevr73/9ufv/736c27bHHHvb3uzaojnb22Wfb3+UKqPZLQc04mG/HUKdOnUznzp39slVlOyOxq3DwAgEEEEAAgcgKqE1I9SoFdoSZ9tlnH9s+5eVB9Q7lTYPOFGSiwJRsKbNfRYOY1SamdjsF6fjtp2OpTyiXAcmq72lAck1GqpcqqFozUauty28wm86r2Zs1UC6XAWIa/KeBZRowrnqYgrrTl4fX8UgIIIAAAgggEK4AATcZ/pr2WKPiy92Rqg5cLdGhTkxVCPfbbz/bOKeKZW3p9ddftyPnli9fbmca0KwH2TpAM4+jivRFF11klxFRJfC3v/1twZ2SqpzedttttoPVrwM28/zKozpsa0pjx461S4EsWbLE7LbbbmbUqFGmRYsWNe3C3xBAAIGSCyhwRZ+h+tEfRtJ3RbYZbBSQeOutt9Y4y1jmqGrlf6eddrJLGda0LIDe9+WXX9qAG40s32uvvexSiNmW+NOPfzVE1PRdquAdGarho6bPdX0nqqNKSwvWtnSVvosGDx5sl9PSexV8qlHjJAQQQAABBBCIn4BGMmtQiOolYSUNWlFwrzpUBg4caJfqVHuC6i3qQLnjjjuq1V/UaaJRyl5SXUrHUD1M9TR1LGlp0MwR3N77Vd8JasbBXANuVB/TzIHpQUVefrI9MhI7mwrbEEAAAQQQiK6ABvNOnjzZqA0+rNS/f/9qs0prsK76GjTA1y+98cYbdrbm9L9r5hgNZNbgtJqS6l25DEiWzQMPPFDjUvMy1EA41em0FLpfUr1xiy22sH0pymdtSQOQjzvuOFtH1KoBzz77rMllQFptx+XvCCCAAAIIIBCcAAE3WSwVuKJOwPQZALK8rSSbli5dahvW8h2Nr05fNWo1adIk53ypsqZpqtUhqo5bzTpTaMon4EZTJKqxrqYOVi8fyp+WulIeFRykztPaOly9fXlEAAEESimg74hJkybZmb1KeZ5sx9ZnomajUeeMOpv0efrrX//aTimrDpGakqYLVmOCgmc0Gvv444+3nUO17adjakS2Oo6UNMOOzlloymVJqVwCM9PPr7JpVpv33nvPqNFC0+0qWJOEAAIIIIAAAvEV0IjjmTNn1hjoW4rS67e7ljBIT/qdq/yoo+WRRx6xI5zT/+49V6dJ+gyymt1WdTr9ntd+GvXsl3KdcTCokdia/VD1Kg2WUZtBLh08jMT2u3psRwABBBBAINoCs2bNMpqROIyk2Zyvu+46G/Sj/oO9997bLidV29LiaiP7wx/+YJc/V77VDqRZb9q2bVtrMYIckJxLILPqUZtuuqkNfrbqAABAAElEQVTJNngtW2bV56K+EfWJKGnAtgK06R/JpsU2BBBAAAEEwhMg4MbHXh2pqnAp+CaOSRHXmtHmxRdftI16CmzZfvvtCy5qrgE3akhUpTKXgCItZaVZDV544QU7W8Nf//pXc+ihhxacR3ZEAAEEghbQj3p18KhDI4y0ePFiuxRi165dc/6xrnzqO0DLYqnRIpdOE+2j70U1fGjmMY1u1lIIxUxhW1PAjRpWNPNNLiN9lDcvqXHmkEMOsTO3abYffX/U1jDj7csjAggggAACCLgroCAUBUKXc9CMlg3Q0gf6LZyeFDCswOiaZp194oknzN13352+mw1o0TJMtc0Cm+uMg0GMxFawjQK0NRJbwUTDhg3LqX7GSOwql5YXCCCAAAIIOCWgWW5Uj1DbUbmT6nLTpk0zbdq0yXlmPS+Pc+bMsXWWDh06eJtqfQxyQHJtATdqn+rWrVuNS01lZlgB3ocffrjRCgdaEv7pp5+2KyNkvo/XCCCAAAIIIBCuAAE3NfirI1XTQWuZqbglVQC19JRG1W222WZGDXtaX7TQlEvAjY6vDtRcI7DVYKr1SVXB7927t12epKbpGAvNO/shgAACxQpoKQN9nmpkclyTGi40q40CjBRoo6l5NbVvockv4EadOvrMzyUwM/PcWn5BedR1OPPMM+0IJwXvkBBAAAEEEEAg/gKa9VWDZrT0Z7mSgntVJ1LwjZaA6tevnznrrLNqrSMpr0OHDjVjxoyxnS763XvhhRfajpTa8h7kjIN+HUMKsFGHUCGBy4zEru0K8ncEEEAAAQSiL6CZ9CdOnFjjEkrRL0XNOQx6QLJfvUp9IerTqGkGQ7+cfv7557Z+qVmHttpqK6Ml3Dt16uT3drYjgAACCCCAQEgCBNzkAL9w4UKjKZHj1JGq5TaGDBliK81afuPyyy83xXRK1hRwU2hjndYj/c1vfmMbL4844gijGW781rHP4TLyFgQQQKCkAupoUSePHuOYtCa2ZrhRo4uWklJgS64BlNk8sgXc5BuYmX5cdVxpmYO//e1v9rtixIgRdpa09PfwHAEEEEAAAQTiLaBBM5rF77vvvitbQVesWGGXXtBo6nyWePbyquWnNIo7lxT0jIPZOoa0xIFmpc11FsTMfDMSO1OE1wgggAACCLgpoPZ+9YksWLDAzQLUkuugByRnq1dpMJnqVZr1v5D0l7/8xQwcONDO4qh+koceesjksiR8IediHwQQQAABBBAoXICAmxzt1IGqqO7Vq1fnuEd036bK8oMPPmjUGakgm6uuuspGSheTY7+AGzXW9erVK6+GR+VDx7vsssvMHXfcYTt0NWpw8ODBxWSRfRFAAIGSCygwUzPdaMabOCV9JmtphJEjR9qlqxR4o7W0i0npATf777+/0bJY7du3L/iQmnlHo8q1zIJGZL/88sumT58+BR+PHRFAAAEEEEDAXQF1DKmDSHWYOKWgZxzM7Bhq0aKF7RQqpiOHkdhxuuMoCwIIIIAAAsaovqCA5rjVq4IekJxZr9KMNlqSs9BBzmpj1MyJ6sdR3ey+++4zp512GrckAggggAACCERQgICbPC5KXDpSNcW2ZrT56KOPTNu2bc2dd95pl3pavny5rbxpOY98kyrct912m/nwww/NXXfdZder1+g+RXBr3fd8kzqrDz30UDNu3Dg73eKoUaPMLrvsku9heD8CCCAQisDcuXPt0ksauRyHpJHKl1xyiVEHiqau1ed8586diyqaF3Cj2dbOOeecvAMzM0/+yiuvmKOPPtosW7
bMBt489dRTRiPGSQgggAACCCBQmQIrV660sw/GYdCMdwWDnnEwvWNIS3Nusskm3qkKfmQkdsF07IgAAggggEBkBdSfoFmdNbtwHJL6MoIekOzVqzbbbDPzpz/9yXTp0qUoKs3YePDBB9s+nI4dO5p//vOfZptttinqmOyMAAIIIIAAAqURIOCmAFfXO1JVOb7ooovM/PnzzQ477GBuueUWqzB8+HCzxx57mN133z1vlcyAm+23395WKrWcVCHp7bffNocddpidslIVyRdffNFoim4SAggg4IqAAj8mTZpkp311Jc9++fziiy/MpZdeahYvXmy/I4YOHWqXbdKSi82bNy9otI4Cbm666Sajzp3jjz/e79Q5bdd30I033miUL6UbbrjBXH311Tnty5sQQAABBBBAIL4CGjQzefJkW4dxvZSq7wQ946DXMaR2AS3hXOgIbM+WkdieBI8IIIAAAgjET2D9+vW2nUuDslxPpRiQrHqV+lxUr3r88ceLrle99dZbtn9EA5P33HNP89xzz9k2ONftyT8CCCCAAAJxFCDgpsCr6nJH6ksvvWSGDRtmVEn+1a9+Zc4991zz6KOP2k7hM844wy4Xki+LGv8UZa2lVK6//nrTu3fvfA+Rer8a6a644gqjACAd94ADDjDPPPMMMxWkhHiCAAKuCGjkj4JuNIOYy0mzxWhWG83Yo+lstW60llnU2tEXX3yxnS0t3/J99tlndulAjfrRzDTFpCVLlthGCAXxaFYbLX3Vv3//Yg7JvggggAACCCAQIwEtxTRr1ixbl3G1WKWYcVDtGhdeeKHRgJmHH3646I4hRmK7eneRbwQQQAABBHIX0JLeqlu5nEoxIFlBPOeff35g9SoNKtOAMvWPqP9GqxTUqVPHZXbyjgACCCCAQGwFCLgp4tKqI1WVM1WmXEmqoN1///02wEYdp5rasH379jY6WrMXNG3atOCiaCkqLSFVyJJU6Sd9/vnnzcCBA40a65R+/vOfm3/84x+me/fu6W/jOQIIIOCEgD5rZ8yYYb799lsn8puZSQVn6gf+mDFj7Oe7fvD36tXL3HHHHXZq23xnRdPIaX2eT5s2LbX+9C9/+cvM0+b1+uWXXzbHHnusUeCNvtM0K9q2226b1zF4MwIIIIAAAgjEW0ABKwqEVt3GxRT0jIOtWrWyMxZqqQLVm4IIuGEktot3FnlGAAEEEEAgfwHNgKxZBDVw1sUU9IDktm3bmmbNmtmVBLRE56BBg4oKjlEb4iGHHGKXk1KQze9//3tzwQUXuEhNnhFAAAEEEKgIAQJuirzM6kidPn16KjikyMOVZfenn37aRkSrQqyOzwMPPNAMHjzYVgoLzUDLli1Njx49Ch4Rp5kOFKWtGRM++ugjs3r16lRWtCyVGgM1HeOQIUPMbrvtlvobTxBAAAFXBBYsWGCmTp1qR6a4kmflU51TmsVGnTz6ka9l/lauXGn69etnjjrqqLwaEBo0aGBnQCs2MDPdT4E7J554ohk3bpwdtd6kSRM7847yRkIAAQQQQAABBNIFNGhGvzlXrFiRvtmJ50HNOKjf1127drVByrNnz7Z1uqACbhiJ7cStRCYRQAABBBAIREDt9xqMrDYil1KQA5LVTqYAmzZt2gRGIFetIKDZ/9etW2fb3TQ4WTNP169fP7DzcCAEEEAAAQQQCE6AgJuALF3qSF21apVd/kmR0ppVQI1rhU5HqMY6zZLToUOHgCQ5DAIIIBBfAX3+qpMnPagw6qVVnq+99lqjEctKmglNy0oNGDAgryDLYgMz0500k41G93z88cdm/PjxRt/BCoD1kgJ7NDvaMcccYwNKFVxKQgABBBBAAAEEJKA6gwJ2582b5wxIUDMO1qtXz85KqwBlJY1O/93vfmc7ihiJ7cztQEYRQAABBBCIjICCV6ZMmWIWLlwYmTzlkpEgBiQr+EWz/Wtp82KTHJ944gkzatQo8/nnn9tAJm3zktq1NNh5jz32MDfddJNp166d9yceEUAAAQQQQCACAgTcBHgRFM2tqG6XOlKLKb4a67SsSDHLUBVzfvZFAAEEXBTQ7GKadlcdHK4kjX5+4YUX7Eia/fbbz3Tq1CnnrCsws3PnzqZjx44578MbEUAAAQQQQACBUgvMnz/fBt6kd2aU+pyFHj+IGQcVZKNOIf2ODzoxEjtoUY6HAAIIIICAWwIa2Kvl1NMHQ0W5BMUOSG7evLnp2bOnqVu3bpSLSd4QQAABBBBAoEwCBNwEDK2OVK0Lr9H3cU6NGze2jXUbbbRRnItJ2RBAAIGSCcyZM8fMmjXLmcaIQiAIzCxEjX0QQAABBBBAoFwCWlpKg2bWrFlTrlMWdJ5iZxxs3769XUZKgdBBJUZiByXJcRBAAAEEEIiHwPLly229SssgxTlpEJoGlpEQQAABBBBAAAFPgIAbTyLgR80GoH+uRHXnU3xNWditWzcTZGNdPufnvQgggEBcBDRaWUGaWiYgbonAzLhdUcqDAAIIIIBAPAVUD9Psg1EfNFPIjINafqB79+6mVatW8bx4lAoBBBBAAAEEIiWgYBsFMyv4Jm5Js9loVhvNbkNCAAEEEEAAAQTSBQi4SdcI+HncOlLr1KljG+tat24dsBSHQwABBCpXYO3atWbixIlGI6zjkgjMjMuVpBwIIIAAAghUjoBmHlRQS1xSgwYNTO/evU3Dhg3jUiTKgQACCCCAAAIOCGgAspaX0jJTcUmNGjWy9ar69evHpUiUAwEEEEAAAQQCFCDgJkDMbIfS1NSK6na9I1WNdVrvXZVLEgIIIIBAsAJqjJg2bZqZN29esAcu89EIzCwzOKdDAAEEEEAAgUAFNMuNZrtxffbBli1bmh49ehjNcENCAAEEEEAAAQTCEFi4cKGZOnWq+eGHH8I4fWDn1OBjzRioNi8SAggggAACCCCQTYCAm2wqAW9TR6oql/Pnzw/4yOU5XIsWLex0iTTWlcebsyCAQOUK6HtCgTc//vijcwga5aNR1ARmOnfpyDACCCCAAAIIpAlo0IxmH1y5cmXaVjeeatnnzp07m44dO7qRYXKJAAIIIIAAArEWWLVqla1XrV692rlyKsCma9euRrM4kxBAAAEEEEAAgZoECLipSSfgv2nmgunTpzvTkUpjXcA3AIdDAAEEchDQjGiaGU2dPa4krV/dq1cvRlG7csHIJwIIIIAAAgjUKKDgZw2aWbBgQY3vi9If69WrZ+tjTZs2jVK2yAsCCCCAAAIIVLiAZriZMmWKWbRokTMSG220kZ3tv3Hjxs7kmYwigAACCCCAQHgCBNyU2d6VjtS6devaxrpmzZqVWYjTIYAAAghoGQMtZ6BlDaKcFJjZqVMn+y/K+SRvCCCAAAIIIIBAIQKuDJpRZ5CWgFbnEAkBBBBAAAEEEIiiwNy5c83MmTONVgOIclLwsupV6h8hIYAAAggggAACuQgQcJOLUsDvUUfqpEmTzNKlSwM+cjCH23jjje2yIDTWBePJURBAAIFCBWbNmmVmz55d6O4l3Y/AzJLycnAEEEAAAQQQiIjA999/b2cfXLt2bURyV
DUbWuagW7duRoHQJAQQQAABBBBAIMoCy5Yts/0i69ati2Q2tSxnly5dIpk3MoUAAggggAAC0RUg4CbEaxPFjtS2bduaTTbZhMa6EO8LTo0AAgikC2iWG812o2DNqCQFZmq0T/369aOSJfKBAAIIIIAAAgiUTCCKg2bq1Kljunfvblq3bl2ycnNgBBBAAAEEEEAgaAEFMWspdQU1RyVtuOGGpmfPnqZFixZRyRL5QAABBBBAAAGHBAi4CfliLV682K5hGnZHqhrrFGjTpk2bkEU4PQIIIIBApsCaNWtsY4SWJQw76XtCnTuMog77SnB+BBBAAAEEECi3gJZBmDNnTrlPW+18DRo0sMHPjRo1qvY3NiCAAAIIIIAAAlEX0LJS06dPN999913oWW3YsKGd7V/1KxICCCCAAAIIIFCIAAE3hagFvI86UidOnGhWrlwZ8JFzO5xmKNBMBZqxgIQAAgggEE2BH3/80UybNs3Mnz8/lAwqMFPLFWgmNBICCCCAAAIIIFCpAho0o9kHf/jhh1AImjdvbnr16mU0EpuEAAIIIIAAAgi4LLBgwQIzdepUozavMFKrVq1Mjx49jNq8SAgggAACCCCAQKECBNwUKhfwfqpUqnKpSmY5kxrrNF1i3bp1y3lazoUAAgggUKDAvHnz7CigcjZGEJhZ4MViNwQQQAABBBCIpcDq1avt7IPlHDSj2QU7depk/8USlUIhgAACCCCAQEUKqD6lJaZUvypXUr2qa9eupn379uU6JedBAAEEEEAAgRgLEHATsYuraRQ1naKmVSx1UmNd586dS30ajo8AAgggELCA1rlWY4TWvS51atasmR1FTWBmqaU5PgIIIIAAAgi4JKDg5ylTppiFCxeWPNuqh2lWG9XLSAgggAACCCCAQNwENHPgpEmTzJIlS0petHr16tnZ/ps0aVLyc3ECBBBAAAEEEKgMAQJuInidS92RqsY6zWqj2W1ICCCAAAJuCqxfv942RixdurRkBSAws2S0HBgBBBBAAAEEYiLw7bffmhkzZpRs0IyWftYS0JpxkIQAAggggAACCMRZYPbs2Ub/SjUYWUE2qlcp6IaEAAIIIIAAAggEJUDATVCSAR9n3bp1tiN12bJlgR65UaNGpnfv3jTWBarKwRBAAIHwBGbOnGnmzJkTaAYUmKk1rFu0aBHocTkYAggggAACCCAQR4Hly5fb3+9Bzz7Ypk0b0717d6NlD0gIIIAAAggggEAlCGhgmWa70UCzIJOWj9IyUtSrglTlWAgggAACCCAgAQJuIn4faKTc3LlzA8ll69atbWNdnTp1AjkeB0EAAQQQiIbA4sWLzeTJk42m4C02KTBTo30aNGhQ7KHYHwEEEEAAAQQQqBgBDZrRkp8Kvik26Td7t27dTNu2bYs9FPsjgAACCCCAAALOCaxZs8bWq1asWFF03jfccEPbJ9KqVauij8UBEEAAAQQQQACBbAIE3GRTidi2RYsW2bXhC+1IVWOdorfbtWsXsZKRHQQQQACBoARWr15tGyNWrlxZ8CEJzCyYjh0RQAABBBBAAAG7/IEGzWiZqUKTlo5S8LOWkiIhgAACCCCAAAKVKqBlpaZOnWrmz59fMIEGk2m2/4YNGxZ8DHZEAAEEEEAAAQRqEyDgpjahiPxdHakTJ040q1atyitHG220kW2sa9y4cV778WYEEEAAAfcEfvzxRxuguXDhwrwyr+l0NYqawMy82HgzAggggAACCCCQVUB1MXUQ5TtoplmzZqZXr15Gy3uSEEAAAQQQQAABBIwNuJk2bZpRm1c+qWXLlna5dM1wQ0IAAQQQQAABBEopQMBNKXUDPna+HalNmza1wTY01gV8ITgcAgggEHEBjarW6GqNBqotEZhZmxB/RwABBBBAAAEE8hfQYBkNmtHgmVxSp06dTOfOnXN5K+9BAAEEEEAAAQQqSkBLS2npTi01VVvSoLIuXbqYDh061PZW/o4AAggggAACCAQiQMBNIIzlPUguHakdO3a0Fcvy5oyzIYAAAghERWD58uVm0qRJZu3atb5ZUmCmRlHXq1fP9z38AQEEEEAAAQQQQKAwAc1wM2XKFKNlov2SBsj06NHDtGjRwu8tbEcAAQQQQAABBCpeYP369bada+nSpb4Wat9SO5fau0gIIIAAAggggEC5BAi4KZd0wOdRR6qiutetW1flyJoisWfPnjTWVVHhBQIIIFCZAvqO0HeFvjMyk0b6aMSPRv6QEEAAAQQQQAABBEonMHfuXDNz5sxqsw82atTI9O7d29SvX790J+fICCCAAAIIIIBAjARmzZplZs+eXa1EjRs3trP9ayZnEgIIIIAAAgggUE4BAm7KqR3wuTI7Uhs2bGgb6xo0aBDwmTgcAggggICrAlpWSh086uhRUmCmRlFrLWsSAggggAACCCCAQHkEli1bZkdle4NmWrdubbp3727q1KlTngxwFgQQQAABBBBAICYCixcvtrMIatYbpXbt2plu3boxqCwm15diIIAAAggg4JoAATeuXbGM/KojdcaMGXamG3Wg0liXAcRLBBBAAAEroKUM5syZY2dBU4AmCQEEEEAAAQQQQKC8AlrqU0t+tmrVynYMlffsnA0BBBBAAAEEEIiPwJo1a2y9SsE2CmQmIYAAAggggAACYQkQcBOWPOdFAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQcFKAgBsnLxuZRgABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEAhLgICbsOQ5LwIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggICTAgTcOHnZyDQCCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIBAWAIE3IQlz3kRQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEnBQg4MbJy0amEUAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBMISIOAmLHnOiwACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIOCkAAE3Tl42Mo0AAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCAQlgABN2HJc14EEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABJwUIuHHyspFpBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAgbAECLgJS57zIoAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCDgpQMCNk5eNTCOAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAgiEJUDATVjynBcBBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEDASQECbpy8bGQaAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQ
QAABBBBAICwBAm7Ckue8CCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAk4KEHDj5GUj0wgggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAJhCRBwE5Y850UAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBwUoCAGycvG5lGAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQCEuAgJuw5DkvAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAgJMCBNw4ednINAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggEBYAgTchCXPeRFAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQScFCDgxsnLRqYRQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEwhIg4CYsec6LAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggg4KQAATdOXjYyjQACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIBCWAAE3YclzXgQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEnBQi4cfKykWkEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQACBsAQIuAlLnvMigAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIOClAwI2Tl41MI4AAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCIQlQMBNWPKcFwEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQMBJAQJunLxsZBoBBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAgLAECbsKS57wIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACTgoQcOPkZSPTCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAmEJEHATljznRQABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEHBSgIAbJy8bmUYAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBAIS4CAm7DkOS8CCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIICAkwIE3Dh52cg0AggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAQFgCBNyEJc95EUAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBJwUIODGyctGphFAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQTCEiDgJix5zosAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCDgpAABN05eNjKNAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAgggEJYAATdhyXNeBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAScFCLhx8rKRaQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBAAIGwBAi4CUue8yKAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAgg4KUDAjZOXjUwjgAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIhCVAwE1Y8o6cd+7cuaZBgwamRYsWZcnx6tWrzaxZs0yvXr3Kcr6gTvLZZ5+ZZ5991my55ZbmV7/6VVCHjcVx/vWvf5n333/f1K1b1xx44IFm6623jkW5KAQCCMRPIJfvvFJ93vP9F7/7qdASLVu2zLzyyitm3Lhx5pprrjGtW7cu9FDsF5LAqlWrzNixY+2/gQMHmj59+oSUE06LAAIIREcgl3pW0Ll1sX5Vqrpm0LZhHI+2hTDUfzpnWP5r1qwxY8aMMTr/eeedZ3r27PlTpniGAAIIRFSglPWe2j4XXaz/RPQy2mzx/Redq1PbvR+dnJITBBCoSIEECYEsAj/++GPinnvuSWy88caJE044Ics7SrNp8ODBiTp16iT0+P3335fmJAEc9Ycffki88847iSFDhiR69+6dSH542H9nn312AEePzyHuvPPOlI1nNGLEiPgUkJIggEAsBGr6zivX5z3ff7G4lQouxJdffpnQd+a+++6bqFevXuq78+uvvy74mOxYXoHJkyfbuvMhhxySaNSoUeoavvjii+XNCGdDAAEEIiZQUz2r1Fl1oX5Vrrpmqa1LfXzaFkotXPPxy+2/ZMmSxNNPP5047rjjEk2bNk3Vq1577bWaM8pfEUAAgZAFSlXvyedz0YX6T8iXKefT8/2XM1XJ3pjPvV+yTHBgBBBAIAcBk8N7eEuFCcyZM8d2+ChAYoMNNkgMHz68bAL33nuvDbjRuRXI8vnnn5ft3LWdKBkdnnj44YcTRx99dKJly5apH/xeIIkeCbj5SVE/MJIzI1Vz6tat209v4hkCCCAQskC277wwPu/5/gv5Rijz6RcvXpy4//77bSdChw4dqn1XenULAm7KfGHyPN3o0aNt3S85M6PvNSTgJk9U3o4AArESyFbPKmcBo1q/CqOuWU73oM9F20LQovkdr1z+EyZMSNxwww2Jvn37JpIzJGetWxFwk9+1490IIFBegaDrPYV+Lka1/lPeq1H82fj+K96w0CMUeu8Xej72QwABBIIQIOAmCMUYHeONN95ItG/f3v6wrV+/fmLkyJFlL92oUaPszDrqbNII4ah0VKjSnD7q3OsMS38k4Oan22XRokVZG0jktXLlyp/eyDMEEEAgJAG/77ywPu/5/gvpRgjhtK+++qrvd2R6vYKAmxAuTh6nTC73Vet1jEo9No9i8VYEEEAgEAG/elYgB8/jIFGsX4VV18yDLVJvpW0h3MtRLv/kclG11qsIuAn3XuDsCCDgL1CKek8xn4tRrP/460XzL3z/hXddirn3w8s1Z0YAgUoX2EAAyYZ9EgLmiSeeMCeddJJZv369adKkiUl2EJjkyJJQZD755BNzwAEHmHnz5plkkIt58sknzRFHHBFKXtJPmlzywSQbx8y///1vc/PNN6f/yT5PBtyY5FJc1bZX6obOnTub2bNnVyl+x44dq22r8gZeIIAAAmUQqO07L6zPe77/ynDxI3CK5Egp8+6779o6TnK2G3P55ZebTz/9tFrOkgE3ZrPNNqu2nQ3REJg4caJJTm9skoHEZtiwYUZru2cm1af79euXuZnXCCCAQKwFaqtnlbvwUaxfhVXXLLd9UOe
jbSEoycKOUw5/1avmzp1rli9fbu6++24zZsyYaplNBtyYvffeu9p2NiCAAAJhCpSq3lPs52IU6z9hXqdCzs33XyFqxe9T7L1ffA44AgIIIFCAQKVHHFH+/wk89NBDqaWcGjZsmFBUdtjp448/TjRr1syOcNloo40SY8eODTtLVc6/5557Vht9www3VYgSuq+SH0upf3Xq1EmMGDGi6pt4hQACCJRZIN/vvHJ/3vP9V+YbIgKnU70r/fvSe84MNxG4ODlmYebMmXYpVu/aeY/McJMjIG9DAIHYCORbzypXwaNcvyp3XbNc5kGeh7aFIDXzP1a5/efPn5+1XsUMN/lfO/ZAAIHSCpSr3lPo52KU6z+lvTLBHJ3vv2AcizlKofd+MedkXwQQQKAQAWa4SbaGV3rSyNsBAwaYH374wWywwQZm9OjRpn///pFgeeutt8y+++5r1q5da5o2bWrGjx9vevfuHYm8nXHGGeaBBx6okhdmuKnCYV8omv6ll17S8nV21qLtt9+++pvYggACCJRJoJDvvDA+7/n+K9MNUcNpjjvuOJNcZtMMHz68hndl/1O++y5dutQ0b9682sGY4aYaSaQ3tGrVyiSnna6SR2a4qcLBCwQQiLlAIfWscpJEtX4VRl2znO5BnYu2haAkCztOuf2z1auY4aawa8deCCBQGoFy13sK/VyMav2nNFcl+KPy/Re8ab5HLPTez/c8vD9+Avm2z8ZPgBKVVaCQKB32iY9Acnq2ROPGjVOjqs8///zIFe7aa69N5W/rrbdOrFq1KhJ5lFXyf9Yq/5jhJhKXhkwggAACWQUK/c4L6/Oe77+sl7EsGx9++GH7/X7CCSfkfb5C9k0uMVWlPuHVL5jhJm/+UHfo3r17teuYbIQNNU+cHAEEECiXQKH1rHLlzztPFOtXYdU1PRMeEYiiQHIZj2r1Kma4ieKVIk8IVKZAGPWeYj4Xo1j/qcw7p/ZSF3Odaz+6m+/AxM3rFnauC2mfDTvPnN9tAWa4KWt4U7ROphlt+vbta9577z2bsU033dR8/vnnpn79+pHK6Jo1a8w222xjNMpb6ayzzjL33ntv6Hm84IILzF133VUlH8xwU4WDFwgggEBkBIr5zgvr857vv3BuH9U3dthhB7NixQqTDLgxjz32WM4ZKWbf5LKLdja49JPpeJtttln6Jp5HWKBXr15m8uTJVXLIDDdVOHiBAAIxFSimnlVukijWr8Kqa5bbnvMhkI9Aly5dzKxZs6rswgw3VTh4gQACIQmEVe8p5nMxivWfkC5f5E9bzHWOfOEKzCAmBcJV8G7FtM9WMBtFL1bA7Xghcl+MwLBhw6qMFhk5cmQxhyvpvspb8l5P/XvqqadKer5cDs4otFyUeA8CCCAQDYFivvPC/Lzn+6+898/KlSsTP//5z1P1jXxmuClmX5Uyuaxn6rxenSf5A7G8AJytKIGePXtWu4bMcFMUKTsjgIAjAsXUs8IoYtTqV2HWNcPw55wI5CLAaPZclHgPAgiEIRBWvafYz8Wo1X/CuHYunLPY6+xCGfPNIyb5ilX2+4ttn61sPUpfjAAz3BQbseTo/lp7cqeddjLr1q2zJdhiiy3s7DbJzp5Ilmj9+vVmk002MbNnz7b5a9q0qfnmm29M+/btQ8svo9BCo+fECCCAQF4CxX7nhfl5z/dfXpe6qDcnK9TmmGOOMcmg3tRxcp3hpph9vZMxw40n4e4jM9y4e+3IOQIIFC5QbD2r8DMXvmfU6ldh1jULV2RPBEorwGj20vpydAQQKEwgzHpPsZ+LUav/FHYF4r9Xsdc5jkKYxPGqlqZMQbTPliZnHLUiBIqJ1mFfdwX233//KiNwhw8fHvnC3HjjjVXyrFFgYSZGoYWpz7kRQACB3AWK/c4L+/Oe77/cr3Ux7xwyZEiVekbyh0Ai1xluitnXyzMz3HgS7j4yw427146cI4BA4QLF1rMKP3Nxe0apfhV2XbM4SfZGoDQCjGYvjStHRQCB4gTCrPcE8bkYpfpPcVcivnsHcZ3jpoNJ3K5o6coTRPts6XLHkeMuwAw3FRFWVbWQn332mdl6661TGzWieubMmaZjx46pbYU8WbRokZkyZYr9N3XqVLNq1SrTrFkz06dPH7PPPvuYjTbaqJDDpvYZP368nZXH29CgQQN7rg4dOnibyvqYyyg0RY5/+OGH5r///a+RiWbmSS5VYf+1bt26qPyW0lvHnjBhgv2nNV51DTULUjH3iNa3HTt2rHn//ffN1VdfXXDZNZJA9/CcOXPsDE0aTa68bbbZZqZ+/foFH5cdEUAgngJBfOeF/XnP91/1ezNZQTcfffSR/X799ttv7ef/2rVrTfJHuOnatavZfvvtzYYbblh9xyxbdKzrrrvODB06tNpfa5vhpph9M08W5Aw3+s7VTICffvqpmThxoqlbt67ZcccdzW677WYaNmyYeeqcXq9evdokl0cyX3zxhbnmmmtM+qyIOt9rr71mzzV//nw7K+EOO+xgNt9885yvg18mSlGWzHMFVbdghptMWV4jgEDcBYKoZ9VmFNRndOZ5olS/CruuSdsCbQuZ/3/U9DqItp3kVP/m7bffNtOmTbMzWbdq1cq2k2211VamZcuW9vRBjWYvR12yJi/+hgAC8REoZb2nXJ+LUar/uHZn8P1XmitW6nv/xx9/NGPGjDH/+c9/zJVXXpkqhLa/+uqr5ssvvzTKg+oghx12WOrvNT0Jsk0y23nUvjl69GgzY8YMc9FFF1V5i9pA33nnHfu3xYsXG/Ux/uxnP7P9r+ntdFV2yvFFqcuVYzYKeluQ7bMFZYCdEJBA8kYkVZjAiSeeWGUEd9++fYsSSHawJA4++OBEtpHZusX0Lxlokjj22GMTn3/+ecHnSi5/lWjcuHGVvIc5y01No9C+//77xC233JJINhBUya/noUeZTZ48OW+PUnrr+uy1116+eU42fCT23Xdfm/dkcEsi2ZlXa/4/+OCDRLIBMZFc/sseV/dJvknX/tZbb0106tTJN2/JjtXE7rvvnkj+cMj38LwfAQRiLBDEd15Yn/feZeH7z5NIJJKN8olBgwYl2rVr5/t9oO9Y/f3cc89NJH84/7RzlmcLFixI7L333r7HSgbBJPR9l/kvucRloph9s2Qlaz3q66+/zvZW3236Dtxjjz0SyaDkrGVSOZJB0IlRo0b5HiP9D8kf+YmXXnopcdJJJ9m6nFeP0T2plGygSNx7772+38/JoOhEsjMl/ZA5Pw+6LJknLkXdItcZbpINPYlkUHq1+yrzPvNe6xqQEEAAgSgKBFHPylauUnxGZ54nSvWrsOqatC38r73Kq994j+VqW1i4cGEiuXR6TvWBZCevvYWXL1+e6N27d077qB6hdqnMpHqFV8fIfEwGZ2e+PfU6iLad5EC0xGmnnZZo0qRJ1rqqrsHhhx+emDRpUqLY0eylrkumYHiCAAIVI1CKek85Pxd1oaJQ/+H7L3v9o9K+/0p97yeDVRI33HBDolu3brbOkRx8lvqsSg6QTyQHxVerixxyyCG2nSv1xownQbdJZhw+kR
xYaNsyk4HINm9q2/RScvCbrSMlB9RVy7fqT8nB6Ik//elPiWTwkLdLzo9Bl+v000/3rWuq7nnkkUdWy1syKKrGfbTfpptuWm0/bQi6fTbrSdiIQA4CBNzkgBSnt8yaNStRr169Kh/Kl112WUFF1AfZL3/5yyrHSs6AklBjUXKUeOKUU05JbLzxxlX+roCNYgIiDjjggCrHU4fSvHnzCsp/sTv5NYrpCyo5g1CVfHoNN5mPyv9vf/vbnLJSam8FtKR/YavSkRzBnvj973+fOOOMMxLeF316GZIj57PmPTmjT+L666+3DUHp79fzfANu3nrrrcSWW25ZxVMdoNqmxqD0PHvHP/nkkxPJEfZZ88ZGBBCoHIGgvvPK/Xmf7QpV8vefPNQodMUVV9gfX+nfK/rBrE4HBVwquDf9b3qenF0vcfPNNyeSM85lY00MHjy42j6Zx8j2WvdWMftmy0y2wOVcA26WLFmSOOeccxL6fsyW38xtOldyZE8iOUKrWlZ0rL/+9a82UFpBIZn76rWuR3JUTSLzvsz2XtU7H3vssWrn8dsQZFn8zlGqukWuATeqJ91xxx2JgQMHZvXV9dlll11sXVpTfqshioQAAghETSCoelZmuUr1GZ15Hr3O/B4Lq32h3HVN2hZ+6ugKs21B1+E3v/lN1qBrr07VvXt328bmtXstW7bM1kGTMzpmrUN4+ylYRQHT48aNq3brf/zxx7ajQ/Vk7/3eY+ZyqkG27dxzzz3VBtEpuGm77baz/9LbKv0CchQkVlsqR12ytjzwdwQQiJ9AKeo95fpczLwaYdd/+P4ziUr//ivVva9grgcffNAOrMtsH/MCbt577z07AMqr+2Q+aoBUZipVm6TOo0AaBQYlZ6mpVi/zAm7k5TewLjP/22yzTSI5801mEbK+LlW5Hn300YSCbjInT/DyOmDAgGr5+eqrr2y9WIMIM/v7vP1U/82Wgm6fzXYOtiGQiwABN7koxeg9f/zjH6t9cD/zzDN5l1AVo8ygEv2YTy4/VOVYmsFFHWHeh6Ie1SGm0SqFpGHDhlU5lo73l7/8pZBDFb1Ptkaxo446KjWTS3qZa3t+55131pifUnur0SK9EpKtE06j+TNnv0kuD1Ul3wo22mmnnapdo/Ty5xpwo1HzZ511VpXGJ91LyeUsEkuXLk2dV1/GLVq0qHZOBeOo0kBCAIHKFQjqO6+cn/d+V6tSv//koc/8zAYhBfC+8MILCc3A4iW9T07pjfXe948ChPXDOzNptLDqEfp36KGHVvsuUdCD9/f0xxUrViSK2TczH3pdaMCN8uHNIqfyasTHI488YkfGKGBH35t+s/ho/fnM70rVSTw3v0fNiKcR2X5/z9yu7+n07+5s5de2oMuSeZ5S1y1yDbjx8qVR6pmNJgceeGBRM0J6x+YRAQQQKLVAUPUsL5+l/oz2zpP+GJX6VTnrmrQt/C/YJkptCxqslFl30mvVEfw6TDSrsgLOs+2n2aWzBVWn3/t6rrbA9P3VWeGlINt2dMxLL720yrk0OE8B3un1QwXKnHrqqVXel54/Pa8t4KbUdUnPh0cEEKg8gaDrPeX6XMx2paJS/+H7739Xp9K+/0p172uGIL9ADdUhFHCjvsrmzZvXWNfQ/+vpqZRtkmr/yazrpL9WwI2CcdK35fJcbZlqt6wplbJc3nmnT5+edVbDbAE33j56nDlzZkJB55ll9Qu4Cbp9Nj0vPEcgHwECbvLRisF7NV1X5geVplfLN+kHfPpxNDJm9erVWQ+j6WrT36vnCugoJGm0W+axMkfgFHLcQvbJ1iiWnjd9iauDSyOWtt1226wdad77FdX8yiuv+GajlN7qaNPMRF5edtxxR9+p8zSaKv29b775ZpU8a6RUr169ElqmbOedd04d0zu2HnMNuNHMOun7aSmr5NryVc7nvfCreNQWyOTtzyMCCMRTIKjvvHJ+3vtdiUr9/pOHfjSnfx8ooGTixIl+VImxY8cmGjVqVGUf7a/vEb+ZbnQwzeqWfh49z7WOUcy+XkEKCbjRj9f07+Vjjjkm649qdWJqia3M8un1fffd52XBPuqHrUYEPfnkk9WCq739vdmEFKyrjpE33njDLt/17LPPZp2WV/vddNNNVc6T+aIUZck8R6nrFvkE3KxatSqx3377pa6JRnLLnYQAAgi4IhBUPcsrb6k/o73zpD9GpX5VzrombQv/qxNGqW1BA9eyLZ+tJRBqSlOmTMnasXT77bfXtFvqbwrM9up2ar9KnyU4yLadzI5dzaBY05KjWq7Uy1fmY00BN+WoS6bweIIAAhUnEGS9p1yfi34XKSr1H77/ql6hSvj+K+W9369fPztY3G8wuNqvtt9+e1vHUF1E/XbpA9C9OoeWNkpPpWyTVJ9W//797Ux/3vn9HtVmqDJed911Ca1Yoj68bPn39j/iiCPSi1HteSnLlX4yLSXq5cl7rC3gRvvfdddd1fbzC7hJP18Q7bPpx+M5AvkIEHCTj5bj71VnS+vWrat8UGkUeC4jX9KLrgCNzCUG1LFVU8qcPkzT2xaSFNSTGamqUd1hpJoaxbQE05w5c6pkS40XNU1v1rZt26xBS6X2VkOK92Wnx/POO69KvjNf/O1vf0u9/9VXX838c5XXZ599duq93jlyCbjR1H7pMxTUNltNekORdx49HnTQQVXywwsEEKgcgaC+8yRWrs/7mq5OJX7/yUOz2KR/ruv5u+++WxOV/ZumL83cT68vvPBC332L+VFWzL5ehgoJuEmfuaZDhw4JBcb6JY1eybY8pEbMaKaVbClbQ5znqrpcZuCtjqFRNNnWwlZAbk2p1GUpR90i14Abzba06667pu5RrbWtGftICCCAgCsCQdazVOZyfEZns41K/apcdU3aFoxdlloOfimstgUt8e3VsdIfFURSU9Isy+nv13ONls4lPfTQQ6l9Ne1/TanQth3lP3M2v9dff72mU9m/qZMos1x6XVPATanrkrVmmjcggEBsBYKs95Tzc9HvgkSl/qP88f1X9SrF+fuvnPf+KaeckrUeobrEFltskdAScUrq20qfMbpHjx5VZoEuZ5uk6mLZ6j7aptmpJ0yYUPVmSb7S0p9dunTx3U+zU2dL5SxXtoHyuQTcPPfcc9XKRcBNtqvJtigJEHATpatR4rxoaq3MD20tB5Bveuedd6odR8fV1MR+KXP5Kb0/ffSM337ZtquDJ7McKlu5k1+j2NVXX11jVi6++OJq+ffK89hjj1Xbt9Te999/f5X8aAaAmpJ+ZOjLTXl+/vnna3prQhHBXtm8x9oCbjRlcuYyFRo5X1vKts6ltpEQQKAyBYL6zpNeuT7va7tSlfb9t3LlymrfB/vss09tTKm/77nnntW+g/Rd9Mknn6Tek/6kmKCZYvb18pBvwM2oUaOqlO+BBx7wDuX7qBkGve/j9Mfhw4dn3UfBM+nv854roMZroMi24+OPP15tv5qCvEtdlnLVLXIJuNESDel1FjU0pC+nkM2TbQgggEDUBIKsZ5XrM9rPMAr1q3LVNWlbMHZWPr97wdue/j3t1X1K3bagY
NzMwBSd+8Ybb/SylfXxX//6V7U6l2ZQnjt3btb3p29MX7J1/Pjx6X+q9ryQth0dRDM+e4Z61EzIuSTNgJO+n/fcL+Cm1HXJXPLMexBAIL4CQdZ7yvW5WNvViEL9R3nk+6/qlYrz91857/1sASWqS2h5+sw6kn6LPP3004nRo0dXmRW73G2S2QJMlGfNvF3TbN1aJssv6Eazy2Smcpfrtttuq1anyyXgJlvdk4CbzKvJ66gJEHATtStSwvxkG+2tpX/yTZ9++mm1D0nNOlPTuoCaoeT/2jsXqCuqso/PpyB5AW9YJp8mlnhFwVtq3peYgBcMM1MDYSlmAUuRQpdlluCFZRaihSSmpmVm6CpDSLNYZUqIrjRzYXgjNbxUYgRe0vnmf/r2cS575sy8Z2bOnHN+ey3Yc9u33553P8955tl7mx/IJo7zsGxUH9uycDfeeGOjZLnftxnFpKymCRKUhoM/PuCAAyLJi+Y9bdq0QF20Pdjy5csj9fBfmDJlSi2NlJGkYKt7I4ebM888M1AfzVJKE+66665AOnHVSkMECECgOwnkJfNEr6zxvlFPdZv8szmHLFy4sBGm+n1tieSXseZYcsYWmnGaaSatqUtWhxvtyWzapDi8sp7J1x8vWLAgkMak1weXuBBeWVBpHnnkkbjHa9flVG3y9sdx25gW3ZaydItGDjdy9tJKRIaJVhXMutJkInhuQgACECiJQJ56VlljdByaKuhXZematt/nedpysC3EvWXprp966ql1HcHoCppprUlPcUF6hH97UZNOW7QlhRdffLG+erS2QG8UbO9OI9uO9MWwfnvrrbc2Kqp2X1tvyjZl2mPiOIebonXJVJXmIQhAoGMJ5KX3lDkuNuqMKug/po7IP0PCdTtV/pX97i9ZsiSiQ0iXCG+p/j756FHZNsn7778/Umfp6WlsfXHtlTN3eNGDstuFw0303eJK5xLA4aZz+zbSMu3jbH6kmrjRXn6RTLwL+rGvVVBMHorPOecc26P1azaHm0YzaOqJQwejRo0KlK3yNXCXHWxGsTFjxqSqhpyN/Pz8x1oKzh+K5q1Z7f7ydaztrZI+pmk7DxlONJsqKdiWY04yyshpq2/fvoH63H777UlFBO7Nnj3bHTRoUG3m+PTp0xO31ggk5AQCEOg4AnnJPIEpa7xv1AndJP/efvttd+ONNw7IA8mnV155pRGm+n3tB77VVltF8tA2mJrBEg7NOM00k9bUI/xBQu2VHLUFOa6En5e8bPRvzpw5ER4qZ7PNNrMVU7tmm22tjy5JQbpLnz59ImVpi6pwKLotZeoWSQ430pn69etXY6K+u+KKK8IoOIcABCDQNgTy0rPKHKPj4FZBvypL18S2UG3bgrZakl4W/tdoCybbdk+777573Ctfu66Vc0w5aT4+ZbXtqBCtTGnKULz55pu70vHThgEDBgTSKw+bw03RumTa+vIcBCDQuQTy0nvKGhfT9EQV9B9TT+SfIfHfuBPlX9nvviY7+XUQHcsO8/LLLwdhx5y1wiapLdvDddYW8GnDvvvuG0mv/Pw7VLSiXTjcpO1BnusEAjjcdEIvpmyDzXtRHsQ9CVrGTLPML7vsMnfRokWJWWjp/J122iky4Gs5456E008/PZKXZlKVHWxGMRk60gatLhQWojrXknfhUCTvuNnum2yySc2RKYtBJFzvrEYZ22oEjT7qhcvkHAIQgIAI5Cnzyhzvk3qvm+SfthIMy0jJpaxh6tSpkXyUr82ZsxmnmWbSmjaFHWhUzziHm6uvvtrarjCzLOerVq0yVQnEPXG4UQa22da2jyRFt6VM3SLO4Uazuc0sbc1Qsm0hGoDOCQQgAIGKE8hLzypzjI5DWgX9qkxdE9tC3JtQjeuaQBTW38aOHZtYuREjRkTSKI+HH37Ymk7vgNl6QA7ub7zxhvU5/8Wsth1tPRpuhz4EZQm27U5aoUtmqTPPQgACnUkgD72nzHExTS9UQf/x1xP59z6NTpN/rXj3bQ43G2644fuQGxy1wibZrMPNDTfcENG9pItpgroJrWgXDjeGPnE3EMDhpht6+f/bGF6qWQOubR+/PJDoB7z2PdQP//XWW8862PfU4ebzn/98JL+i2pHEolmjmD62hA0QOr/mmmuSirXea4b3O++842qZYltddG3nnXd2tWdiT0JWo4yN6Zo1a3pSNGkgAIEuJ5CnzLONTVkcLPMa77tJ/l1yySURubTLLrtkfqttS7JKtl1++eWRvJpxmmkmralIFoebM844I8InTo6nvf7QQw+ZqgTinjrcpDUSFd0W299vH8EiPgAAJ0RJREFUUbqFzeHm8MMPD6xG1AqdNdChnEAAAhDIgUBeelaZY3Rcs6ugX9k4tELXFCNsC3FvSjnXZ86cGdHxkpxinnjiiYCe4df7Jk6caK20bHXmubR6SVbbzi233FIvw5Q1evRoa33iLlZFl4yrH9chAIHuIZCH3lPmuJimZ6qg//jrifx7n0anyb9WvPvNOty0wibZrMPN2rVr6xO9jO6leMqUKfWXqxXtwuGmjp+DLiCAw00XdLJporaP8g+2Ov7iF79obucSa6uHGTNmuNttt129LNuWECq7pw43NmPU8ccfn0v9s2Riq0cWo5i2jgr3h841Gz9tyIv3HXfcEWukMXXU+6PVirKErEaZ8PKCW265ZZbieBYCEIBAnUCeMq8K470aZqtHp8o/m/Fnxx13rPdv2oMVK1ZYZa3yD4dmnGaaSWvqkcXh5sADDwy0a9y4cSab3OOiHW6KbkuZuoXN4cboUSbW1plPPvlk7v1EhhCAAATKJJCXnlXmGB3Hpwr6la0O2BaCPVY124JWAr7ooovcCy+8MNM/zT5OCtrqoHfv3gE9TzrEvHnzrMnGjx8fedboHLKnaIvVcBg5cmQ9zZIlS8K3redZ+Z999tn1Mkx9Gm1FHy447QfHonXJcL04hwAEuo9AHnpPmeNimh6y6R6tsC+ZuiL/DAnXrYr8y0vXacW736zDTStsks063OgNGjhwYET/OuGEE+ovVyvahcNNHT8HXUAAh5su6GTTRNvenFmMOCYfW6wf3/qh36dPn/qgrqUAtRe0lqc9+OCD69fNj+2eOtzYtoc46qijbNUq9JpNMc3CU16nhoU//tznPtew3kXwvvLKK6318ddNS+9Nnz7d1ao4aUJWo4zfUUvlqjztM0+AAAQgkJVAnjKvleO9v93dJP80A9Yvf4xM8PNIc7xu3bpIPsrr6KOPjiRvxmmmmbSmIlkcbsIGmCINY0U73BTdljJ1izQON3r/tHpgmu0bzLtBDAEIQKBqBPLSs8oco+MYVkG/aqWuiW0h7s1Ivm77kBPWXW3nRx55ZHLG3t0TTzwxor9qS/JweOmll+ozmWU7+eAHPxhJp8lV/qBJVGYV6iFDhvhvJR5nte0cd9xxkbp8/etfTywjfDOsI4qnbUup8HNF6sXhOnIOAQh0B4E89J4yx8U0vVIF/SdcT+Tff4mE5Vqr5F9euk4r3n1b3bNsKdUKm2QeDjeHHHJIRP/SqssmtKJdONwY+sTdQACH
m27o5f9vo5aKDf/gT7t8bBym1157zT3rrLPqP9iV/1ZbbeVqqbh33323nixPh5vJkydH2nHSSSfVyyrroFmjmOqpGUfhPkladaho3t///vfdjTbaKFKncB33228/d/ny5Q1RZzXKbL755pGyV65c2bAcHoAABCAQJpCnzGvFeB9uj867Sf7JkTYse3QuOZg1SC8J5zVs2LBINs04zTST1lQki8NNWF7uscceJpvc46IdbopuSzh/vQtF6RY2h5uPf/zjkfdPddAsSQIEIACBdiWQl55V5hgdx7oK+lUrdE1sC3FvRLrrjz76qFW+h3XO8Hkah5tFixZZ837qqacClTv//PPrz8mOdNVVV9XPTbnHHHNMII1W5DH3rr322sC9pJOstp3wqjMqM8nWZSs77QfH8DhSpF5sqyfXIACBzieQh95T5riYpkeqoP+E64n8+y+Rqsi/vHSdVrz7zTrctMImmYfDzSmnnFLX84y+57c9taJdONyERzrOO5kADjed3Luhtvl/jJsB99RTTw09lf70vvvuiziMjB071v373/8eySRPhxvbvqlZVpaJVK6HF/IwiskQYPrCxHGzfsrirT3Ahw4dGqmXqZ+JN910U3fZsmWJ9LIaZbbYYotIubYZTImFchMCEICARyBPmVf2eB/Xgd0k/7Tam5E3/lg/mrOGPffcM5KXfoSGQzNOM82kNfXI4nATdtjVTB2t5lNEKNrhpui2lKlb2Bxufv7zn7vhLVPMO33ZZZcV0WXkCQEIQKBwAnnpWWWO0XFQqqBfla1rYluIexvSX8/rI5StRE1e23777SP66wUXXFB/XCvlbbbZZrVnevXq5T777LOuthwPb0ele6tWraql00rFW2+9dS2NJlqtXr26nl+jg6y2nV133TVS/2OPPbZRMYH7aT84Fq1LBirFCQQg0JUE8tB7yhwX03RSFfSfcD2Rf/8lUhX5l5eu04p3v1mHm1bYJPNwuNE3UmNvMvGECRPqf2qtaBcON3X8HHQBARxuuqCTTRNtWwZpScSeBK1gs/766wcG8KT9mPN0uLEJBn3oKjvkYRQbPHhwgKEE4fe+971IU8rm/Z///MedPXu2G54pZAS1ifv37+++8MILkfqaC1mNMraPollmXZlyiSEAAQjkKfPKHO+Teq6b5J/NgUWyR84LWYNN1tp0FluZp512WqrimklrCsjicLPXXntF9If58+ebrHKNi3a4KbotZeoWNoebBQsW1D50mQ9cRodSrC0dfvnLX+baX2QGAQhAoAwCeelZZY7RcVyqoF+VqWtiW4h7E7JdlxOLJktl/feDH/wgVUGXXHJJRNcbMGBAfSVp/2o2fkdy21YB+ntV0PZSRg/RlvBZQlbbjs0GmHXlmbQfHIvWJbNw4lkIQKAzCeSh95Q5LqbphSroP7Z6Iv9ctyryLy9dpxXvfrMONzYbX9E2yTwcbj796U/XdT2j81188cX1P7VWtAuHmzp+DrqAAA43XdDJpok33nhjZMDdf//9ze3UsZax3XjjjQN5aT9pzZaJCzbB+vvf/z7u8cTrtn0fZ82alZimiJt5GMXMjCQjABXLkOEPreSt1YomTZrkalaUv47+44kTJ/qrGzjOapSx7Ykro8x7770XyDftSVGz/dOWz3MQgEDrCOQl89SCssb7RrS6Sf7NmzfPKncuuuiiRpgi922z56+//vrIc7YfnlV1uLH9iP7MZz4TaVPaCzJGXHfdde6///3vSJKiHW6KbkuZukWcw42gasU+Odj4dSgda1b2c889F+HOBQhAAAJVJpCXnlXmGB3Hswr6VVm6JraF9rEtaGJTeJKb9IZ77rmnZnvbbrvt6jrFH//4x/rrLUffsK4h53MFbWdl7j300EP1NGkOstp2xowZUy/LlNmvXz9Xk7vShrQfHIvWJdPWl+cgAIHOJZCH3lPmuJimJ6qg/9jqifxL73DTLvKvFe9+sw43rbBJ5uFwY9vS/N57763/qbWiXTjc1PFz0AUEcLjpgk42TbQN2vqRnjVoD2jzg9nEN9xwQ2I2eTrc2Gav9GTGe2KFU9xs1iimJYANPxNr2eBwKJr3nDlz3N122826FZipy5NPPunatr9SveU0FBeyGmXOPffcCBOVcdttt8UVEXt96dKlbp8+fdyZM2fGPsMNCECgcwnkJfNEqKzxvlFvdJP8k1OukY3+eIcddmiEKXB/7dq11nxWrlwZeE4n7eRwM23atEi75Axt29Yz0tDQhddff90dOHBgLb9HHnkkdNd1i3a4KbotZeoWSQ43AquZRf732Rzrbxsn4cirxwUIQKDCBPLSs8oco+NwVkG/KkvXxLbQXrYFbcFkdAUTn3TSSa5WKTLnw4cPD7zacmixOarIpmJWU8y60owKyGrb+drXvlavo6mr4p/+9KeB+iad2Nph23K8aF0yqY7cgwAEuoNAHnpPmeNiml6pgv4TV0/k3/9GZGg7y79WvPvNOty0wiZpG2c+9KEPxf2ZWK9vs802gXdHW436J9a1ol043Fi7iosdSgCHmw7tWFuz3n77bXeTTTYJDLobbLBBptVDlIf2evb/YNaxhFhS0Ao44TQ9XeFmq622CuSl2cL6WFR2aNYopllIYSZf/vKXA80og7dZqvHOO+8MlB0+kYPQiBEjInVWG+I+8GU1ysydO9ea/6BBgzLNhFJ9PvKRj9TySrtkc7i9nEMAAu1NIA+ZZwiUMd6bspLibpJ/WtlMY39YTur8d7/7XRKmwD2bHDKzfAMPeic2h5u0q8Y0k9bUw3wE8bdZ9bcFrUbjf84ca1norEFtVHrNeraFoh1uim5LmbpFI4cb7Ul/xBFHWPvu9NNPt+HnGgQgAIFKEshLzypzjI4DWQX9qgxdE9uCU5O/7WRb+NnPfhbRGWTD869us3jx4sirfeGFF0bSGV1R8TXXXBNJ0+iCTaeW7hoXbrrpJmsdNBkvbUjrcFO0Lpm2vjwHAQh0LoE89J4yx8U0PVEF/Seunsi/dA437SL/WvHuN+tw0wqbZLMON1odKryqctim2Yp22RxuNAmgUZCTtl9/1XGaxSPysM82qhv3IRBHAIebODIden3kyJGRgeqxxx5L3do4L8g//OEPiXkMHTo0Uu4DDzyQmMZ2U7N/wwPt3nvvbXu08GvNGsXOO++8QFvkyPTyyy8H6l0Gb+Nwc9ZZZwXKtp1ottSOO+4YqLf649lnn7U97mplnHB/JRllVq9e7fbt2zeSRnnoI+Kbb75pLcd/UY5Bhx12WC0PzdzSxy0CBCDQnQSalXmGWhnjvSkrLu42+ScO3/rWt6zyII28MhyNjPPLIv3YtwVtV+V/TsfDhg2zPRq51kxak1kWh5vHH3+8Pks5XOc77rjDZNkwvuCCC+ptXrJkifX5oh1uim5LmbpFI4cbAf7b3/7mapZSuN90/p3vfMfaB1yEAAQgUEUCeehZZY7RNoZV0a/K0DWxLfzX4Ubytl1sC7K/DBgwwKozqB1xW8Q//fTTsXqi7E49mbCW1bbz4osvuppVbdN3li1bZvtzjFyztX3hwoWR54r
WJSMFcgECEOhKAs3qPWWOi406qCr6T1w9kX9R2d/O8q8V736zDjd6N8u2STbrcBN2NJGTtnTCcCi7Xbbydtppp3C1Iuc2R51tt9028lz4Qh722XCenEMgLQEcbtKS6pDnZs2aFfnBm2V2yy9+8YtIev2AnjFjRiyhv/zlL64G+PAPbf/+gbGJQze0z3Q4n6lTp4aeKue0GaPYP//5z4hjiZbBDYcyeJuPkfJsT+PQIt7+PtCWUvKOtQVtTeF/1hy/9dZbtsdr1yZNmmRNo7T77bef+9JLL8WmlQI3ZMiQenrbcouxibkBAQh0HIFmZZ4BUsZ4b8qKi7tN/onDP/7xD1fbJBnZYeLNN988lbyyLamvLaneeecdK+bZs2dHytpll12sz4YvNpPW5JXF4UZpRo0aFamvGPXq1cuNcyoyZWmG3vjx4+vpjz/+eHMrEtt0OK3S1yhsueWW9fxN38XJ5aLbUpZukcbhRtzuu+++yMwjMRLrLCs4NeoD7kMAAhAokkBeelZZY7SNRVX0qzJ0TWwL7zvcSOa2i23hK1/5SkSfMnpV0irFhx9+uDVdT1fU64ltZ8yYMdY6aNLca6+9ZvuTrF976qmnrE5DP/zhD+vP+A+K1iX9ZXEMAQh0J4E89J4yx8WkXqqK/pNUR+RfUG9pd/lX9rtvc7iRvSVLKNsmaXO4kf0z7rubvy361hae2DV58mT/I/XjstulSYFGdzWxVuKRM3dckJ5oW+lQNuJGIQ/7bKMyuA+BOAI43MSR6dDrtmVgtQd02vDwww9HBkgNlJtuumnEY1LCQEsAhgd7M7Bef/319WK1KkmaYFNue+K4k6asRs/YjGLjxo1rlKx2P7x35b777uvanFDK4G0cbtQv3/jGNxrWf86cOYF34Oijj45Nc//99weeNX2fZFyRg1Z4+TuTTvHWW29dq6dmMWkPSikJ4iQlwv9htqeGpNjGcAMCEGg7As3KPNPgMsZ7U1Zc3G3yz3C49dZbrXLk4osvNo/Exj/5yU8CaeWIkuTMYFuuVHJHcskffv3rX9dWvvFfbyat8tZWiH5ZZ46TVhB89NFHrWlMWs3A0zYDfgcjlXP11VcHtiKQ4+zy5cv9Tawfr1271lpGo61EpQPaZLn0Qlsoui1l6RY2g4A+cNqCbdaN+k4GFek4BAhAAAJVJ5CXnlXWGG3jWRX9qgxdE9tC8MOVZG472Ba0mrDNKVtO4UkfYLS1ttEJ/XFPt3bviW3Hto26qctuu+0WO5lKqwHKKcc8649POOEE6yrGReuStvGDaxCAQHcRyEPvKXNcTOqdqug/SXVE/gX1lnaXf2W/+7L9+fUHc6xJ8FlCmTZJm8ON6m1b3SjchiuvvDLQ3n79+rmvvvpq+LH6eZntinPaPvDAA13bN+Hbb789YLM0fWfiRislNmufrUPiAAI9IIDDTQ+gtXuS4cOHBwZgGRnSBi35bJvprAFPM5m/+c1vuj/+8Y9rDhFaGswMhLZ48ODB7qJFi9xLL73U1bK28q5uFE455ZRAnvqRnmRkaJRfM/dtRjF9JHnuuedis1Vdwx9YPvzhD7tSIm2hDN5+h5sNN9zQXbFiha0q9WvhPrjlllvq98IH+qhn6/vf/va34UcD57al5mz5xF074IADXH0kJEAAAhBoRuYZemWM96asuDg89na6/PNzOPfccyOyRM4zCxYs8D8WONaPazkD++XEFVdcEXgmfLJ06dLA8yatfgTK8K/l9+fOnev26dOn9tw999xTz6KZtMokbpsH/VBMClqZxtQzLhYrLbsqx5rwM9K/krb4fOKJJyJplEejH/xxxpSkLZOKbkvRuoUcm/yOv4b1vHnzrF2oFZjiZp9vs802rj4cESAAAQhUnUAeepbaWPQYHcexKvpVGbomtoXghysjpxvFVbAtHHXUURF97IYbboh7rWvXZQ8J68K77757Ypqkmz217WgiVBzj7bff3pWepHdTQasZy1FIqy/HpdH1j33sY66YnHzyyYEqF61LBgrjBAIQ6EoCeeg9ZY6LcZ1UFf0nrn7mOvIvqLu0u/wr893X9yqbLvGb3/zGvF6p47JsknEON/3790+cPPjd73434pytVV4ahbLaJduXVhu39Yeuq6633XabqwUKZH81zw0cOLB+bK4p1kQ3fU/Wqk+aVBgOzdpnw/lxDoEsBHC4yUKrQ561LRuo5VrThtNOO8062PkHPv+xPHDPPvvshmkefPDBxCpo+4Pw9gSNjAyJGTZ58+abb7a2SR+0LrvssoDjjbxn5Vz0qU99KpBm5513Djxnq1LRvP0ON+o3GTfiZqBriwp/32rv87ig7Z0GDRoUeN6kPe644xo6xJx33nnWtCaPuHjHHXdsuDxxXJ25DgEIdB6BZmWeiJQ13sfR71b5Z3jIkUEG9fC4L0cSyds1a9aYR2tbTWllG7/jg2YFa+/fRkFOsY2chU0dpNf4QzNptVLboYceGmmfytKPTX/7/GXq+JVXXnE/8YlPWNOausbFvXv3TnSceffdd12t3GdLf8wxx7h6L21BK/adeOKJ1nR77LFH/aNKOG2RbTFlFaVbqM1hh2rD7aCDDoplJUcumyOU0soJ+tprr7WugGjaQwwBCECg1QTy0LNMG4oao03+4bhK+lVZuia2heCHKyOr4+Kq2BbCqzbqQ0OcHuZ/z8N2ODnN9CQ0Y9uRHh+2g9l4G4d2c+8DH/iAVZc09xXvv//+geaUoUsGCuQEAhDoOgJ56D1ljou2DqqS/mOrn/8a8s+ut7Sr/Cvr3ddqVLvuuqtVj9hnn33cVatW+V+zhsdl2STjHG6k88h+N3bsWFerTcv+KCcWjUdTp06NONuk2cFCjS6rXSpLti2/DtfoWM42zz//fMM0tkn9zdhnVVcCBJohgMNNM/TaOG3YI1szttMGeQ6m+bijj11XXXVVTQhoJnTSQHrOOec0LP7uu+8O5KGVYWzbMDXMKKcHNHhrRZ+4FX/UXn1EGTBgQETwaZuFM844o7YdUqPqFM077HBj+mnYsGGu9kvVDCMtNXnkkUcG+Os8/BFQs7HPP/98V9uUada8ycsWy0gkz+bp06dbEYivlAYpFLb04Wv6oCqmNs9WawFchAAEuoZAMzJPkMoa7+M6pFvlX5iHZqjYnBO22GILV9sbyrEhbJyXrJk/f344q9jzND8C9dHgzTffjOSRJa2cbL761a/W5KWWeg3LNP+5nI31w1qzPWyGAelCuu9P0+h46NChte2mIo3wLmhrD8102XPPPRPzFNsJEya4WopW4emnn65t76gtDpLK11ajkte/+tWvaun8/+XdFn/eOs5bt9CqSZpN3WgmtmbtfOELX3BnzpzpGif3O++805UO5ncOs3GTvjt+/PhYZ+hwGzmHAAQgUDaBZvUsU9+8x2iTb1xcJf2qLF0T20J72hb0cVT6gNETZGdLE/zbiMk+k2UbhTxtO9Lv0qw+Y9on3Uh6knRNc83E0iO1nbg+MtlC0bqkrUyuQQAC3UUgD72nzHEx3DtV0n/CdQufI//ed7jpFPlX1LsvB5Rp06bVvl
- ...[base64 PNG payload truncated: embedded image for Figure 1]...)"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.031242,
-      "end_time": "2020-12-01T01:33:27.288600",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.257358",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "Br6sY3Skj6ew"
-    },
-    "source": [
-     "Figure 1: *Various ResNet Blocks*"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.030469,
-      "end_time": "2020-12-01T01:33:27.350166",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.319697",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "1Aa3qJM8j6ew"
-    },
-    "source": [
-     "## Diminishing Feature Reuse"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.029283,
-      "end_time": "2020-12-01T01:33:27.409595",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.380312",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "aByhtJHsj6ew"
-    },
-    "source": [
-     "A **residual block with an identity mapping**, which allows us to train very deep networks, is also a **weakness**. As the gradient flows through the network, there is nothing to force it through the residual block weights, so a block can avoid learning anything useful during training. As a result, only a few blocks may learn valuable representations, while many blocks share very little information and contribute little to the final goal. An earlier attempt to address this problem used a special case of dropout applied to residual blocks, in which an identity scalar weight is added to each residual block and dropout is applied to that weight.\n",
-     "\n",
-     "Since widening the residual blocks increases the number of parameters, the authors studied the effect of dropout as a way to regularize training and prevent overfitting. They argued that dropout should be inserted between the convolutional layers instead of in the identity part of the block, and showed that this placement yields consistent gains and new SOTA results (see the sketch after this cell)."
-    ]
-   },
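To make the dropout placement concrete, here is a minimal sketch of a *basic-wide* block with dropout between its two 3x3 convolutions. It uses PyTorch modules purely for illustration: the class name, the dropout rate, and the pre-activation (BN, ReLU, conv) ordering are assumptions for the sketch, not the notebook's own implementation.

```python
import torch
import torch.nn as nn

class WideBasicBlock(nn.Module):
    """Wide residual block with dropout between the convolutions (illustrative sketch)."""

    def __init__(self, in_planes: int, planes: int, dropout_rate: float = 0.3, stride: int = 1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # Dropout goes between the convolutions, not on the identity path.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        # 1x1 projection only when the shortcut must change shape.
        self.shortcut = nn.Identity()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(torch.relu(self.bn1(x)))
        out = self.dropout(out)                      # regularizes the widened block
        out = self.conv2(torch.relu(self.bn2(out)))
        return out + self.shortcut(x)                # x_{l+1} = x_l + F(x_l, W_l)
```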
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.029729,
-      "end_time": "2020-12-01T01:33:27.469022",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.439293",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "reNP-uCgj6ew"
-    },
-    "source": [
-     "The paper [Wide Residual Networks](https://arxiv.org/pdf/1605.07146.pdf) attempts to answer the question of how wide deep residual networks should be, and addresses the problem of training them."
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.029851,
-      "end_time": "2020-12-01T01:33:27.529228",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.499377",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "RKfIYWqoj6ew"
-    },
-    "source": [
-     "# Residual Networks"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.030099,
-      "end_time": "2020-12-01T01:33:27.588818",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.558719",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "c7uSNYYuj6ew"
-    },
-    "source": [
-     "$\large \n",
-     "x_{l+1} = x_l + \mathbb{F}(x_l, W_l)\n",
-     "$\n",
-     "\n",
-     "This is the representation of a residual block with an identity mapping.\n",
-     "\n",
-     "* $x_l$ and $x_{l+1}$ represent the input and output of the $l$-th unit in the network\n",
-     "\n",
-     "* $\mathbb{F}$ is a residual function\n",
-     "\n",
-     "* $W_l$ are the parameters of the block"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.029364,
-      "end_time": "2020-12-01T01:33:27.647782",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.618418",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "E-3X8obdj6ew"
-    },
-    "source": [
-     "Figures 1(a) and 1(c) show the fundamental difference between the *basic* and the *basic-wide* blocks used."
-    ]
-   },
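As a toy illustration of the update $x_{l+1} = x_l + \mathbb{F}(x_l, W_l)$, the NumPy sketch below uses a hypothetical one-layer residual function. The shapes and the choice of $\mathbb{F}$ are made up for the example; the point is that the identity shortcut always carries $x_l$ through unchanged, so a block whose residual branch contributes nothing is a no-op, which is exactly how a block can avoid learning.

```python
import numpy as np

rng = np.random.default_rng(0)
x_l = rng.standard_normal(16)                # input to the l-th unit
W_l = 0.01 * rng.standard_normal((16, 16))   # parameters of the residual function

def F(x, W):
    """Toy residual function; in a real network this is a stack of conv/BN/ReLU layers."""
    return np.maximum(W @ x, 0.0)

x_next = x_l + F(x_l, W_l)        # the identity term guarantees x_l flows through

# A block whose residual branch outputs zeros simply passes its input along:
x_noop = x_l + 0.0 * F(x_l, W_l)
assert np.allclose(x_noop, x_l)
```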
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "papermill": {
-      "duration": 0.032155,
-      "end_time": "2020-12-01T01:33:27.709374",
-      "exception": false,
-      "start_time": "2020-12-01T01:33:27.677219",
-      "status": "completed"
-     },
-     "tags": [],
-     "id": "mgf_2paVj6ew"
-    },
-    "source": [
-     "# Architecture"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "metadata": {
-     "id": "Rc5ECr36lQSU"
-    },
-    "source": [
-     "![Screenshot 2020-12-01 at 10.04.48 AM.png](data:image/png;base64,...[payload truncated: 1510x830 screenshot of the paper's architecture table]...)"
-    ]
-   },
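The screenshot above (the paper's architecture table) did not survive extraction, so here is a small sketch of the WRN-n-k bookkeeping it summarizes: depth = 6n + 4, three groups of n blocks, and base widths 16/32/64 scaled by the widening factor k. These numbers come from the Wide Residual Networks paper linked earlier, not from the lost image itself, so treat the helper below as an assumption-laden sketch rather than the notebook's own code.

```python
def wide_resnet_config(depth: int = 28, k: int = 10):
    """Return (blocks per group, channel widths) for a WRN-depth-k, per the paper."""
    assert (depth - 4) % 6 == 0, "depth must be of the form 6n + 4"
    n = (depth - 4) // 6
    widths = [16, 16 * k, 32 * k, 64 * k]  # initial conv, then groups conv2..conv4
    return n, widths

# WRN-28-10: 4 blocks per group, widths [16, 160, 320, 640]
print(wide_resnet_config(28, 10))
```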
y66csDwnwybUruzPRjDO37FdP195lSpUopX3zxRWAoc+fOVQoXLqy68griXsdtRegHoHmxMC+ivFaOSJzv2rWrdjqRW/SVRJot9EpDmA8dKTJMOIEkCPNhjt925srH+2ccXO14xx0f5e9CmlRH7lvNz6N0/Nprr/lq6jfffCPM76ijjlIOHjzoK08I876wIREI5JQAhPmc4kfhZgIrVqwQ3py0GyDdpLy6ChgxYkTG73mtWrUUcmEgCk888YRt2RdffLEoSda45s2bW/IktzxuFhe0e7DwIszTCwWNnXmbtfI2F5DgTl8RmPPTH8fpEsGOE4R5GwPmeTSE+eR1AJrZZv4hdMUVV/hqCI0L2lhFXy598sknvvLRJyIXOPQlFuUbhwsvEhXq1KmTaQeV60eU19ogEucfeugh7XSitugriTJXpJWFMB8pXmSeQAKyC/Nhjt925snH+2cYXOkL8IkTJyoDBgxQ6HfoypUr7RDnPD7q34X0tab2HKnfNmjQwFfb7fQHWg/Ib4Aw75cc0oFA7ghAmM8de5RsQ+CSSy4R3vC0m1/ZsmUV+pQsW6CF/bp06WLIa9asWbbJfv/9d9WlgVaOebts2TLbtKIT9NBCn7aZ83EzW57yq1y5siUt5eVFmLd7CKd8goRevXoJ60b5UptXrVoVJHtPac2CncYbwrwnjHlzMYT55Jma3Jhpf9e0pVnuftYcGTRoUCYfekGabZ2PbKToJS+57Dr66KMz+dIP1ygD3afCFOW1uorE+WwLlGtpZdqir8hkjdzWBcJ8bvmjdPkI2P0m8PK7IspWhTV+29UxX++fQbnSJC+aGKd/DqN9Wt/G629jO9uEGR/170L6fW/nE/7jjz/21BRy43vSSSdZ2NLEEXqJ5DdAmPdLDulAIHcEgqlzuas3Sk4xAadZ3vqHgnPPPVchn3nkm00LdIPbtGmTQm+ftRmMWppbb71Vu8x2O2PGDNUdgZZGvyXf6l58zT/wwAOWG23dunWVffv22ZavP2F30/fyAE2+jvVt0O8XFBToi/O0T4vnFilSRJh3kyZNPOUV5GJ6aNG3Sb9PL2UQQMBMAMK8mUg0x3PmzFFuvPFGpVq1agp9OdSvXz+FfhR7DYsWLbKMyW5ezJrLefbZZy1jBY1Vfv7R4l8kjpNvev2YQ18S0QKwUQWasUb3EH2ZQWbKm+spEufjGEfRV8yWwHEYBCDMh0EReaSJQBTCvGzjt529knb/lIUrzbY3P+von0HKlCmjxOG+z86u5vi4fhfqJ3roedCzodvf+VR3srM+vbZPz2NBAoT5IPSQFgRyQwDCfG64o9QsBO655x7hjUq7YZm39NBAPuKLFy8uTEdv9d0K0UOGDBHmQWWSuOQmjB8/3iJc0yzPH374wU1yZceOHbZ1oM8I3QbzFwN6bmvWrHGbjfC61q1bC+tID/5xhc8++0xYB2rn1VdfHVc1UE6CCECYj95Yu3fvtrieob9JGgPp09xdu3a5qgS5mTG/YB06dKirtPqLxo4dK/x6ST8ehrHfvXt3fbGh7pOoUK9ePcN4F6Yor1U2bnEefUUjj23YBCDMh00U+SWdQNjCvGzjt519knb/lImrm3XFaLKCLCGu34X0xSS5rhE9O953332ucFAel156qSUPmnho53bXVcb8IgjzbknhOhCQhwCEeXlsgZroCNDMdDvhV3QTdIoj8WL16tW63LPvXnfddZYbpVYG+U+nmfl2gRaXNbuwKV++vPL555/bJbHEv/zyy7blk6sfN4Fmbjr5gn/qqafcZGN7DbkF0pho2woVKsS26Cv1EZEPf60u5GPQrQBo20icSB0BCPPRm/Stt96yjA3a3yVtSWx//fXXbcdRGl8HDx5seLlJrmeefPJJz5V/9913Dfno6xHmPo35Xu8zbhtDL5XNPwCjEOW1+ojEeXJfFkVAX4mCKvIkAhDm0Q9AwEggbGFepvHb2NL/HSXx/ikT1759+zo+z2nPUbRGXK5D3L8Lqc16V4YaC9qSK0Cnr+xpvbarrrrKwpaej9etWxcYJYT5wAiRAQjETgDCfOzIUaBbAvQp2Pnnn2+5aelvfNn2aYahn0/s6GZK7nBIDBKVUb9+fXURnL1792aas3TpUqVDhw6W66tWrap4eWB5//33LbNEzXXo1q2bQjNA7AJ9ykeLI5rT6Y/JX2BQf8gnn3yyoQyqVxxh7dq1yp133mkoW982bb9FixYK2QUBBDQCEOY1EtFtySWY9jfotCVb0BdAU6ZMURf1pplOw4cPV04//XRDenLf4ufveObMmbZjuFO9/Jyjl4RRhUmTJhl4RCnKa20wi/MVK1bUToW6RV8JFScy0xGAMK+DgV0Q4ATCFuZlGb+djJvE+6dMXGkShZtnIvoyMZchV78LSWM44YQThIyaNm2qzJ0714CFJp588803yuWXX25JQ8/EYYjyVCCEeQN2HIBAIghAmE+EmfK3kjTjmWbqlSxZ0nIDc3pQIJcJI0aMUA4ePBgIHi1qY/bpqy+XZknS2/ISJUpY6lesWDGla9euys6dO7PWgdp52WWXKccdd5wlH315+n3yQU9169ixYyZ/ejCiuKJFi7rOh2bVX3zxxb78P5OIptWpcOHCSlD3OJmGmHYWL16stGvXTn3ZcNppp1m+SNDqYLel2fONGjVS0992222m3HGYTwQgzEdv7T179gh/FNj9fdrF04JYI0eO9PUVDi3OJRqX7coKGv/2229HBpZm4mtretAPQD8vm/1UTi/Ok8gZRUBfiYIq8iQCEObRD0DASCBsYV6G8dvYQutREu+fMnGliV7a84fTc9LAgQOt8COMkel3Ia1t16pVq8zvYTOnWrVqKVdeeaVCk8XoC3rzeTpu06aN44Q7ryghzHslhutBIPcEIMzn3gaogQsCv/zyi0J+50nsFt3QtDgSh6+55hpl48aNLnJ1dwl9bkaLDTZr1syyCKFWrn5buXJlta7ff/+9uwL4VbQooj4PL/vVq1fPlEP+l72k1V/r5y09PbBpfv3J9VBUYf78+b7bpW8j7VN9EfKXAIT5eGxP4wm9bDS79TL/PZqPS5curf7AoUXAnVyGZWsFzZY35x3VMb2cjXLRV2rr119/rdCLV1qILc5AXzHQjDla9ySqgL4SFdn8zhfCfH7bH623EghbmKcScj1+W1tpjUni/VMmruS6lFyVOj1DPfPMM1bwEcbI+Lvwww8/VM444wxHTmaG5E/ePKs+DGwQ5sOgiDxAIF4Chag4PkgggEAiCHABm/FPwBifDan+42+pGRdyGBfDGZ9JzfgiKoy7aImsLfwFAeM3UMaFf7Zhwwa2ZcsWxmeuM+5aQP3HXecw7n6H8RcEkdVBtoz5LAn23HPPMb6yPDvllFNkqx7qAwIGAtz9kjp28IdWxmdSGc7hIHwC/CsaNmbMGLZq1Sq2efNm9d/WrVsZ/6qJcfco6j/+1Q7jP2bYhRdeyLibMMZnZ4VfEeQoPQH0FelNlKgK0ngyb9489RmRf5WYqLqjsiAQBQHuu5zxr08tWb/00kuMu+K0xHuJwPjth
Zb7a2Xhyn31s+XLl6vPz9zNK+NfJLKbbrqJcbezamPotzGNuQiM8ZcqbPLkyYxPbFCfeUk74JMbWLly5Rj3Ic+qVKnCLrroIsZn0TPu7jYSZPyLU0Z9Rx+orJ9//lkfhX0QAAGJCECYl8gYqAoIgAAIgEC0BCDMR8sXuYMACICADAQgzMtgBdRBJgJRCvMytRN1iZ4A/6KRlS1blpFgT2H79u3qcfQlowQ3BCDMu6GEa0BALgIQ5uWyB2oDAiAAAiAQIQEI8xHCRdYgAAIgIAkBCPOSGALVkIYAhHlpTJH4ikybNo21bNlSbQd3G8Y+/fTTxLcpTQ2AMJ8ma6It+UIAwny+WBrtBAEQAAEQYBDm0QlAAARAIP0EIMyn38ZooTcCEOa98cLVYgJ87TXVfazmDpL7e2fcV7r4YsTmhACE+ZxgR6EgEIgAhPlA+JAYBEAABEAgSQQgzCfJWqgrCIAACPgjAGHeHzekSi8BO2H+xRdfZHfffXd6G46WhUaAXNjccMMNbMKECWqebdq0YZMmTQotf2QUDgEI8+FwRC4gECcBCPNx0kZZIAACIAACOSUAYT6n+FE4CIAACMRCAMJ8LJhRSIII2Anzffv2ZY8//niCWoKq5oIAifK0SPCoUaPU4mvUqKEusE0LmiLIReDYY49VF57V1wqLv+ppYB8E5CMAYV4+m6BGIAACIAACERGAMB8RWGQLAiAAAhIRgDAvkTFQFSkI2AnzzZo1Y7NmzZKijqiEnAS2bt2qzpSfOXOmWsHjjz9eFeVPOOEEOSucx7XauXMnK1euHFMUxUABwrwBBw5AQDoCEOalMwkqBAIgAAIgEBUBCPNRkUW+IAACICAPAQjz8tgCNZGDgJ0wT7UbPnw4e+ihh1jhwoXlqCxqIQ2BtWvXsiZNmrD169erdapVqxajxV9r1qwpTR1Rkf8S2Lt3L+vUqRMbN26cBQmEeQsSRICAVAQgzEtlDlQGBEAABEAgSgIQ5qOki7xBAARAQA4CEOblsANqIQ8BJ2GealmhQgW2f/9+ds4551gq3blzZ9ayZUtLPCLST2DFihXqYq9FihRhDz/8MHvsscdY8eLF099wSVvYvHlzYc0KCgrYypUr2Y4dO4TnIcwLsSASBKQhAGFeGlOgIiAAAiAAAlETgDAfNWHkDwIgAAK5JwBhPvc2QA3kIkALdrZt29ZXpR555BE2dOhQX2mRKPkERo4cyS666CJ2+umnJ78xCW9BoUKFfLUAwrwvbEgEArERgDAfG2oUBAIgAAIgkGsCEOZzbQGUDwIgAALRE4AwHz1jlJAsAnv27GGDBg1i+/bt81zxpk2bYsa8Z2pIAALhExgzZgxbvny554zJ7zx97YAAAiAgJwEI83LaBbUCARAAARCIgACE+QigIksQAAEQkIwAhHnJDILqgAAIgAAIgAAIgAAICAlAmBdiQSQIgAAIgEAaCUCYT6NV0SYQAAEQMBKAMG/kgSMQAAEQAAEQAAEQAAE5CUCYl9MuqBUIgAAIgEAEBCDMRwAVWYIACICAZAQgzEtmEFQHBEAABEAABEAABEBASADCvBALIkEABEAABNJIAMJ8Gq2KNoEACICAkQCEeSMPHIEACIAACIAACIAACMhJAMK8nHZBrUAABEAABCIgAGE+AqjIEgRAAAQkIwBhXjKDoDogAAIgAAIgAAIgAAJCAhDmhVgQCQIgAAIgkEYCEObTaFW0CQRAAASMBCDMG3ngCARAAARAAARAAARAQE4CEObltAtqBQIgAAIgEAEBCPMRQEWWIAACICAZAQjzkhkE1QEBEAABEAABEAABEBASgDAvxIJIEAABEACBNBKAMJ9Gq6JNIAACIGAkAGHeyANHIAACIAACIAACIAACchKAMC+nXVArEAABEACBCAhAmI8AKrIEARAAAckIQJiXzCCoDgiAAAiAAAiAAAiAgJAAhHkhFkSCAAiAAAikkQCE+TRaFW0CARAAASMBCPNGHjgCARAAARAAARAAARCQkwCEeTntglqBAAiAAAhEQADCfARQkSUIgAAISEYAwrxkBkF1QAAEQAAEQAAEQAAEhAQgzAuxIBIEQAAEQCCNBCDMp9GqaBMIgAAIGAlAmDfywBEIgAAIgAAIgAAIgICcBCDMy2kX1AoEQAAEQCACAm3atGG//vorq1ixIps8eXIEJSBLEAABEACBXBO4/fbb2Y8//siKFy/OZs2alevqRFr+hg0b2JIlS0Ipo2XLlqxQoUKh5IVMQAAEQAAE5CKwadMmtnjx4lAqdcUVV7AiRYqEkhcyAYF8JwBhPt97ANoPAiAAAnlEADPm88jYaCoIgEDeEsinGfNjxoxh7du3D8XWf/75JytWrFgoeSETEAABEAABuQj8v//3/9iNN94YSqUKCgpYqVKlQskLmYBAvhOAMJ/vPQDtBwEQAIE8IgBhPo+MjaaCAAjkLQEI8/5MD2HeHzekAgEQAIEkEIAwnwQroY75SADCfD5aHW0GARAAgTwlAGE+Tw2PZoMACOQVAQjz/swNYd4fN6QCARAAgSQQgDCfBCuhjvlIAMJ8PlodbQYBEACBPCUAYT5PDY9mgwAI5BUBCPP+zA1h3h83pAIBEACBJBCAMJ8EK6GO+UgAwnw+Wh1tBgEQAIE8JQBhPk8Nj2aDAAjkFYF8FubLlCnDtm3b5sveWMjPFzYkAgEQAIFEEFAUhR06dMhzXT/88ENGi73qA3zM62lgHwSCEYAwH4wfUoMACIAACCSIAIT5BBkLVQUBEAABnwTyXZjfuXOnT3JIBgIgAAIgAAJGAjNnzmSXXnqpIRLCvAEHDkAgEAEI84HwITEIgAAIgECSCECYT5K1UFcQAAEQ8EcAwjyEeX89B6lAAARAAATMBCDMm4ngGATCJQBhPlyeyA0EQAAEQEBiAhDmJTYOqgYCIAACIRGAMA9hPqSuhGxAAARAIO8JQJjP+y4AABETgDAfMWBkDwIgAAIgIA8BCPPy2AI1AQEQAIGoCECYhzAfVd9CviAAAiCQbwQgzOebxdHeuAlAmI+bOMoDARAAARDIGQEI8zlDj4JBAARAIDYCEOYhzMfW2VAQCIAACKScAIT5lBsYzcs5AQjzOTcBKgACIAACIBAXAQjzcZFGOSAAAiCQOwIQ5iHM5673oWQQAAEQSBcBCPPpsidaIx8BCPPy2QQ1AgEQAAEQiIgAhPmIwCJbEAABEJCIAIR5CPMSdUdUBQRAAAQSTQDCfKLNh8ongACE+QQYCVUEARAAARAIhwCE+XA4IhcQAAEQkJkAhHkI8zL3T9QNBEAABJJEAMJ8kqyFuiaRAIT5JFoNdQYBEAABEPBFAMK8L2xIBAIgAAKJIgBhPj+F+c2bN7Off/6Z/fbbb2zPnj2sUqVKrGrVquz4449nRxxxRKL6sEyVBdfs1vjzzz/ZtGnTWO3atRk9a8oU/v77b7Z8+XL2/fffs59++snw7+abb2Z9+vSRqbqx1mXZsmVs6tSpbMaMGez6669n999/f6zlJ6UwCPNJsRTqmVQCEOaTajnUGwRAAARAwDMBCPOe
kSEBCIAACCSOAIT5/BHm161bx5577jn2wQcfqMKjXWe94IILWPv27dm1117LSpYsaXcZ4v+PALi66wp//fUXGz16NBs8eDBbv349e+SRR9jQoUPdJY7wqiVLljASUz/99FP22WefsV27dllKK1y4sFrfIUOGWM6lNaKgoIDNnj1bFeKnT5/ONmzYkGmqLLbLVEiiHQjzEhkDVUklAQjzqTQrGgUCIAACICAiAGFeRAVxIAACIJAuAhDm0y/Mb9++nXXr1o2NGzeOHTx40HUHPuqoo9ioUaPYlVde6TpNPl0Iru6sTX3utddeY4MGDVJnoGupcinuHjhwgE2YMEF9UbVo0SKtSoZtvXr1WIsWLRi9qDrvvPPYkUceaTiftoNDhw6xxYsXs48++oh9+OGHbMGCBYxepohCLm0nqo9McRDmZbIG6pJGAhDm02hVtAkEQAAEQEBIAMK8EAsiQQAEQCBVBCDMp1uYJ3Gtbdu26gxlvx330UcfZfk0U9gNJ3DNTokE+fHjx7OBAweytWvXWhLkQtylOj399NPsiSeeYH/88YelTvSFSLt27ViHDh1Yw4YNLefTHHH22WezL7/80lUTc2E7VxWT4CII8xIYAVVINQEI86k2LxoHAiAAAiCgJwBhXk8D+yAAAiCQTgIQ5tMrzNOsV5rtTrODg4ZXX32V3XHHHUGzSUV6cHU2I/lpf+ONN9iAAQPY6tWrbS+OW9ydP38+69Spk+pD3lypYsWKsc6dO6s+5NM+M97cdu2Y2HzzzTfs8MMPV79sEL1M0a6N23ZauUnYQphPgpVQxyQTgDCfZOuh7iAAAiAAAp4IQJj3hAsXgwAIgEAiCUCYT6cwTy4pyAXHvn37DP2S4urWrctOOukkRn6zaZFLWuySZoArimK4Vn9AYt2qVavUBWL18fm2D672FidXKG+++Sbr37+/2lfsr/zvmbjE3f3797MuXbqwl19+WdjHL7nkEnUWPT33IvyXAI0FvXv3VtcDEDGJy3aismWPgzAvu4VQv6QTgDCfdAui/iAAAiAAAq4JQJh3jQoXggAIgEBiCUCYT58wT2J8gwYN2A8//JDpl+Qru2/fvrbuOebNm6e679CnyST+v50ePXrYCnXma9N4DK72Vp0yZQqj/qH1nyJFijCahU6iuF2IQ9z99ddf2VVXXcW++OILSzXoxdSwYcNY165dLecQwdju3btZpUqVLC/3iE0ctkuqDSDMJ9VyqHdSCECYT4qlUE8QAAEQAIHABCDMB0aIDEAABEBAegIQ5tMnzJP/bBJJKZBASgu43nrrrVn7Irm8IX/0JLKKwjHHHMM2b94sOpUXceBqb2byT75kyRLVddI111zDLrvsMlamTBlV9H7qqaeECaMWd8ktS8uWLdnPP/9sKZ/c1bz11ltqPS0nEZEhcMopp2RetmQi+U7UttOXlbR9CPNJsxjqmzQCEOaTZjHUFwRAAARAwDcBCPO+0SEhCIAACCSGAIT5dAnzBQUFrFq1amzr1q3qjGXy9X3ttde67o+///47O/XUUxltRWHnzp2q4Co6l+Y4cHW27qxZs1idOnXY0Ucfbbhw27ZtrEKFCkIXMlGKu7Nnz1ZnypPdzKF8+fLs008/Vfu5+RyOjQTOO+881c2VMRbCvJmH/hjCvJ4G9kEgfAIQ5sNnihxBAARAAAQkJQBhXlLDoFogAAIgECIBCPPpEuaff/55dt9996k9ZOzYseyWW27x3FvGjx9vm27p0qWsXr16nvNMegJw9W/Bo446ipFAbw5RCfMrVqxgJCjTSyRzKF68OPvoo49Yo0aNzKdwLCBw0UUXsblz51rORGU7S0EJjIAwn0CjocqJIgBhPlHmQmVBAARAAASCEIAwH4Qe0oIACIBAMghAmLeKd8mwnLiWp59+Olu2bJk6W3jy5Mnii7LE/vbbb5aZz1oS8tVNbkvyLYCrf4tXrVqVbdiwwZJBFOIufelB/XPdunWW8goVKsTefvtt1qZNG8s5RIgJQJgXc3GKhTDvRAfnQCA4AQjzwRkiBxAAARAAgYQQgDCfEEOhmiAAAiAQgACE+fQI88uXL1dns9MM5e+++85WXHfTXcif/JYtWyyX0szncuXKWeLTHAGuwawblzBPayQ0bdqUzZ8/X1hh+pLk2WefFZ5DpJgAhHkxF6dYCPNOdHAOBIITgDAfnCFyAAEQAAEQSAgBCPMJMRSqCQIgAAIBCECYT48wT647JkyYwP7xj3+wxo0bB+gVTM1j1apVhjyOPfZYtnHjRkNcPhyAazArxyXM9+7dmw0aNEhY2eOOO46Ri5vSpUsLzyNSTADCvJiLUyyEeSc6OAcCwQlAmA/OEDmAAAiAAAgkhACE+YQYCtUEARAAgQAEIMynR5gP0A0MSfft26cKmH///bch/oorrmDTpk0zxHk5OHjwIOvQoQNbvHgxu/vuu9k999zjJbnva3fs2ME6deqkCrP9+vVT3fz4zixAwqi4BqhSLEnjEObXrl3LateuzWjWvCi89957rFWrVqJTkcWlob9DmPfePSDMe2eGFCDghQCEeS+0cC0IgAAIgECiCUCYT7T5UHkQAAEQcEUAwjyEeXNHWbBggbp4pjmeBKfmzZubo10fv//++4zEfS0MGDCA0SznKMPWrVtZs2bNVL/7VE716tXZmjVroizSNu+ouNoWKMmJOIT51q1bsylTpghbfO6559q6txEmCCkyDf0dwrz3zgBh3jszpAABLwQgzHuhhWtBAARAAAQSTQDCfKLNh8qDAAiAgCsCEOYhzJs7CrkDMQvmtGDmpEmTzJd6Ol64cCEjkVRRlEy6/v37sz59+mSOw9wxi/KUd506dRj5jM9FiIprLtripcyohflZs2Y5vjB68803Wdu2bb1UOZRr09Df/Qjz9GXIn3/+6ZlhqVKlWJEiRTynky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK/vzOvWrVPF6z179mSiScymBTVJOAsann76adalSxdDNuRe5rHHHjPEBT0QifIVKlRgs2fPVtsXNH+v6aPm6rU+cV4ftTDfokUL9sEHHwibVKVKFfbTTz+xYsWKCc9HHZn0/u5HmG/fvj0bM2aMZ7Q33ngje/311z2nky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK91ZprJTm5f5syZo0Wxs846i5FLjqOOOioTF3RnxIgRrGvXroZs+vbtyx5//HFDnN8D2UT5uLj65RV1uiiF+S1btjBa2NW8HoLWpp49e2YWhD106BD79ttv2eeff84o3a5du9QZ2pUqVWJHH300a9SoEatRo4aWNLRtkvu7H2GeXArNmzeP0cuoDz/8kO3du9eWZcWKFVX3UuRiisYa80s724QSn4AwL7FxULV0EOA3VQQQAAEQAAEQyAsC//jHP+hbc4X/SMmL9qKRIAACIJCPBC644AJ1rC9dunTqmz969Gi1rXRvo39lypRJfZu9NPCJJ54w8OHuP5Tdu3d7ycL1tcOHDze
URfbgwrzr9HYX/vHHH0q9evUMefOZ8gp3X2OXJPL4OLlG3hgfBRx//PEGe2h/f4888oiP3IxJnnzySWHeWhmffPKJUlBQoIwcOVKpVq2a47WFChVSGjdurLz22msK/2LEWFDAo6T2d+KhsdRv3dpu8ODBwvR8rQqFf70SkKqcyfnLCEubqQ8igAAIhEOA/OEhgAAIgAAIgEBeEIAwnxdmRiNBAATynACE+TzvALz5fLaxQkKbXnirWbOmsnPnzkjhDBs2zFAmlc9d2vguUzZRPldcfQOMKGGUwnzdunUtfUjrx4cddpjy3HPPKXw2vO012rXmLb2ofP7550MlksT+HkSY51+uKLVq1TKwL1++vMJ9/ofKVbbMIMzLZhHUJ20EIMynzaJoDwiAAAiAgC0BCPO2aHACBEAABFJDAMJ8akzpqyHfffedwhdkNYhnmkh55JFHKkOGDFH4Yo6+8naTaOjQoZay+WKwbpIarpFNlM81VwOcHB9EJcyvX7/e0ne0vhvWduDAgaHSS1p/9yvMb9++XTnnnHMM9rnsssuUX375JVSeMmYGYV5Gq6BOaSIAYT5N1kRbQAAEQAAEHAlAmHfEg5MgAAIgkAoCEOZTYUbPjeC+tpWbb75ZKVq0qEE8Ewma9evXVzZu3Oi5DLcJSPw3l9urVy+3yRWZRHmZuLoGGPGFUQnzkydPtvQbcz+iY3JZ1bFjR4UvxKpMnTpV4QvFKtOmTVP4ugbKSSedlDWPbt26hUooSf3djzBPL0xOPfXUDFdyEUR/z/QFST4ECPP5YGW0MZcEIMznkj7KBgEQAAEQiJUAhPlYcaMwEAABEMgJAQjzOcGek0LJb/a4ceOUJk2aKCSWiURMu7gqVaooS5YsiazeIl/UfOHOrOXJIMrLzDUrwBguiEqY7927t2MfJkG+f//+Cs3edgovvfSSUrJkSce8SNQPMySlv3sV5mkth2OPPTbDklwCvfvuu2Gikz4vCPPSmwgVTDgBCPMJNyCqDwIgAAIg4J4AhHn3rHAlCIAACCSVAIT5pFrOe72POeaYjGBmJ8A7xZM4/9tvv3kv2GWKQYMGWerXo0cP29QyiPJUOdm52gKM6URUwnyLFi0s/UXrv8WLF1cWLFjguoVLly5VZ9Zr6c1bEpsPHDjgOj83Fyahv3sR5ufMmaOQ+yuNHY0Xy5Ytc4MiVddAmE+VOdEYCQkUojrxgQYBBEAABEAABFJP4OSTT2YrV65kNWrUYKtXr059e9FAEAABEMhHAhdeeCGbN28e4zMb2a5du1KNYMyYMax9+/aZNvIZtYwvcJo5TvsOX+CVrVmzhh08eJAVFBSwbdu2qce7d+923fSWLVuy//znP66v93oh9+nNuI95Q7JHH32Ucfcfhji+sCRr1qwZ48JfJr5ChQps9uzZrE6dOpm4OHaSwDUODnZlVK1alW3YsMFymrhxn+uWeLcRDRo0YPwrDuHlHTp0YHwmvPCcXeSMGTPY5Zdfbneavfzyy+yuu+6yPe/nhOz9/aKLLmJz5861NM1suwkTJrBbb72V8ZcX6rW1a9dmxJO/lLGkTXvEzJkz2aWXXmpoJo23pUqVMsThAARAwCcBCV8WoEogAAIgAAIgEAkBzJiPBCsyBQEQAAGpCGDGvFTmiL0yhw4dUmi28IMPPqgcfvjhmdmu/Oey7f6iRYsirSe5HzGX371790yZdjPlZZqdKyPXDMCYd6KaMV+tWjVLP9H6zaxZs3y1slGjRrZ58okqCn+p5Stfp0Qy93c3M+ZHjBhhcI1F9xT+0s+pyak+hxnzqTYvGicBAbiykcAIqAIIgAAIgEA8BCDMx8MZpYAACIBALglAmM8lfbnK/umnn9QFYbP5n2/Xrl3kFe/Xr59FIOWzdG0XepVJlDfDkYmruW5xHEclzJctW9bSRzRhnl7e+AkTJ060zZPypoVjowiy9ncnYZ5ePj300EMGXldeeaWyf//+KBAlJk8I84kxFSqaUAJwZcPvRgggAAIgAAL5QQCubPLDzmglCIBAfhOAK5v8cWXjtqe/+OKLrFOnTraXcz/SjFzJFClSxPaaME707duXccHSMStyX/Pxxx+zunXrOl4nw0lZuMbNIipXNtT/uDhsac5hhx2WcaliOZklgtw88Zn4bOPGjcIrR44cyR544AHhuaCRMvZ3O1c2Xbp0YVu2bGFvvfVWptnkQopcC0U9LmQKlHQHrmwkNQyqlRoCEOZTY0o0BARAAARAIBsBCPPZCOE8CIAACCSfAIR5CPOiXjxs2DDG3ceITqlxX331FTvjjDNsz4d14vHHH2fc1YcwuySJ8loDZOGq1SeObVTCPF/gVSjAlytXTl0/wW/bBg8ezHr16iVM3rlzZ/b0008Lz4URKVt/txPmRW0tUaIEmz9/PqtXr57odN7EQZjPG1OjoTkiAGE+R+BRLAiAAAiAQPwEIMzHzxwlggAIgEDcBCDMQ5i363Nt27ZltKijKNBCurfddpvoVOhxd9xxBxs9erQlX1r4NQkz5c0Vl4WruV5RHUclzB933HFs06ZNlmrTFx07duywxLuNoAWEaWFhUfjnP//Jpk6dKjoVWpxM/d2LME8AqlevzuilHb0cydcAYT5fLY92x0UAwnxcpFEOCIAACIBAzglAmM+5CVABEAABEIicAIR5CPN2neybb75RZ79yN7SWS4YPH866detmiQ87YvPmzYzEwR9//NGSNfc5z4YOHWqJlz1CBq5xMopKmKeZ2cuXL7c0pWjRouyvv/6yxLuNWLVqFePrLAkvr127Nvvuu++E58KIlK2/exXmicHll1/Opk2bxgoXLhwGksTlAWE+cSZDhRNGAMJ8wgyG6oIACIAACPgnAGHePzukBAEQAIGkEIAwD2Heqa+2atVKOEO4R48ejFx+RBl++eUXVZQnodQuPProo2zIkCF2p6WNzyXXuKFEJcw3bdqUzZkzR9icAwcOMPI17yfs2bOHlSpVSpi0ZMmSbPfu3cJzQSNl7O92wvxdd93F+CKnbP369cJm9+nTx9YFlTBBiiIhzKfImGiKlAQgzEtpFlQKBEAABEAgCgIQ5qOgijxBAARAQC4CEOYhzDv1yDfeeIPddNNNlkt69uzJBg0aZIkPK4JEyiZNmrCVK1dmsjz22GPZBRdcYFhwkk4mUZzPFdcMzBh3ohLmb7jhBvbmm28KW0L9p3LlysJzbiLLly/Ptm/fbrm0dOnSbNeuXZb4oBGy9nc7YZ6+VmndujWj+4fo64RChQqx9957j7Vs2TIomsSlhzCfOJOhwkkjwD/jQwABEAABEACBvCDAP+Olb9eVGjVq5EV70UgQAAEQyEcCXOhUx3ouOKW++dxPudpWurfRvzJlyqS+zUEbuHDhQgMzjR2fpR40a9v0XKRU+OQAQ7lclFe4Oxs1DZ+NazhHdeLivG1+Mp7IBddccTj++OMt9iKbcXE3UJWeeuopYb6U96effhoo79NOO02Yd7Vq1QLlK0osc39v3LixkINmuyeffFJ4nmzAff0r/GsXUZNTHce/JLAwKSgoSHWb0TgQiJ
MAi7MwlAUCIAACIAACuSQAYT6X9FE2CIAACMRDAMJ8PJyTWsq2bdssIhOJbhMnToykSVu2bFFOOeUUQ5l6UV4rVCTOd+/eXTst/TZurrkEEpUwb/dyg/rnq6++GqjJDRs2NPRBypP+UXyYQfb+nk2YP3TokMJnxQtZES96wcFd/4SJTPq8IMxLbyJUMOEE4MqGj64IIAACIAAC+UEArmzyw85oJQiAQH4TgCsbuLJx+gvgMz0Z/7LAcgktxnrSSSdZ4oNE/Prrr6r7mu+//z6TDbmv+eSTT4RlPfbYY2zAgAGZa2knKQvCxsnVACgHB1G5siEXKnxWNtu3b5+lVbQwMS1Q7DeceOKJ7Oeff7Ykp7UByEVLGCEJ/d3JlY228DJ/ycTq169v62/++uuvt7ifCoOfrHnAlY2slkG9UkMg4S8WUH0QAAEQAAEQcE0AM+Zdo8KFIAACIJBYApgxn1jTxVLxBQsWWGbDlitXTqGZsmEGLlIqtWvXNpQlmilvLrN3796GNFx4UB5++GHzZdIdx8VVhoZHNWOe2mY3o/vMM88M1PTDDz/c0q+ob3ExOlC+WuKk9Hc7vporG6098+fPV4oWLSpkRtzI5U2+BMyYzxdLo525IgBXNrkij3JBAARAAARiJwBhPnbkKBAEQAAEYicAYT525Ikq8JVXXrGIbffee2+obSCR8tRTTzWU40aU1yohEuf5jGnttJTbOLjK0vAohflRo0YZ+g2JwPSPLz6q/Pbbb74QbN26VZgn5fvll1/6ylOfKEn93a0wT+0bNmyYLbciRYoos2fP1mNI7T6E+dSaFg2ThABc2fC7EQIIgEAyCaxfv55t3rxZ/UefftInmjVr1mTly5dPZoNQ68gJwJVN5IgjKYD+zunza/6DlO3Zs4dVqlSJ0Wfk/IcxO+KIIyIp05wpxhszkfCOZbAvtQY2Ds+muc4JrmzgysapD3bu3Jk9++yzhkuWLl3K6tWrZ4jze/D777+zpk2bsm+//TaThZP7msxFph3uc54NHDjQENu1a1fGZ+oa4mQ5iJqrLO2kekTlyoby3rt3L6tSpQrbudP6d/zaa6+xW2+9lS7zFKgv1qlTx5LmqKOOYlxUZ1xktpxzG5G0/u7GlY3Wdq7ZMe5vnk2fPl2LMmwrVqzIvv76a3bccccZ4tN2AFc2abMo2iMbAQjzslkE9QEBEHAkwD8rZG+++abqC3HDhg3Ca/nnyOycc85hDz30ELv44ouF1yAyPwlAmE+O3detW8eee+459sEHHzC9b15zC/jMWNa+fXt27bXXspIlS5pPBzrGeBMIn2NiGexLFYSNHc2U2JMQ5q2CXmKNGXLF6QVvrVq1DKLnFVdcwaZNmxZKSSRSNmvWjH3zzTeZ/PyI8lpikThPz7f/+te/tEuk2IbBlc/qZnwGMlu5cqX6Ap4EVLKVjCFKYZ7ae99997Hnn3/e0vQGDRqwxYsXW+KzRZDv9EcffdRyGfWv/v37W+LdRiSxv3sR5okD9UvyN2/3u/Oss85in376KeOugtxiS9x1EObdmYz+NvkXKO4u5leVLVuWtWvXzvX1ThfSGimzZs1yuiTrOaoL1QkhBwQkmbmPaoAACICAI4EffvhB4YsT2X5OyIdP4Tn+IKW88847jnnjZP4QgCsb+W3NF9xSuNDu6NdT9PfOZ30pU6ZMCaWBGG9CwSjMRAb7UsVgY6F5UhMJVzapMWXoDbnjjjsMz4ulSpVS+BdZoZTDRUqFz0o25O/FfY1dJURubR588EG7y3MSH5Qrn5Gs0H3cfH8nH/3Lli3LSZucCuULtFrqSnU3+yl3ysPpHH85odj5hP/444+dklrO0doJfFFjS335F4e+XeNQIUnt715c2WgwP//8c8fn0ttuu027NJVbuLJxZ9ZBgwZZ/s7MY5r5mL+Ac5d5lqveeOMNz2Wb60LjDkJuCMDHfG64o1QQAAEPBMaNG6cUK1Ys0M3mgQceUA4ePOihVFyaRgIQ5uW2Ki20xWehBfpb5zPCAjUS400gfI6JZbAvVRA2djRTKk5CmE+FGTONmDNnjnLjjTcq1apVU5o3b67069dPFQUzF7jcWbRokVK4cGHDPYb8eYcR/vjjD6Vu3bqGvMMQ5bW6icT5Ll26aKd9bWXhSv7JafFds0ikHZcpU0bhs0F9tTGKROTnXaubeRvUJvr62ol89PJn3759+ksd98nO5nrSMfUpvyGJ/V1r69lnny3kke2lCi2SK+KoxQV9/tTqJ+MWwrw7q9DfBfWja665RilRooRjf9H6DWkcn332mbsCHK7iX2kpNP60bt1asVsDQytTv+Vus9TJj/Syl9aiQMgNAQjzueGOUkEABFwSGDBggKubmv4GY7fPP1VWuC96lyXjsjQSgDAvr1VnzJhhOzvM7m/aLv7VV1/11VCMN76wuUokg32porCxK3Ml/iII84k3YaYBu3fvVkSzk2mmO3fDoezatStzrdPOJ598ohxzzDGGZ0oS2sIIJMZw//SGvMMU5bU6hinOy8TVzUxPmuUsSyARze754+qrrw6tmvSbhbuuEZbFXd24KofyuPTSSy15nHvuub5/EyWxv+thnXLKKRYeZM/bb79df5lln748aNGihTCt1h+GDBliSZeGCAjz3q1I96ZbbrnFsb9o/YbuTZs2bfJeiE0K6qtks9NOO822fPo7oC+N6VqE3BOAMJ97G6AGIAACNgT+/e9/295MtBuZ122PHj1sSkN0PhCAMC+nlb/66iuFPqk2/z2TuHbvvfcqTz31lPL0008rHTt2VM477zylUKFClmv1aenzb6+uCTDeRNc3ZLAvtQ42js7GsuUMYV42i/ivz1tvveU43pOg8frrr9uKCyQ6DB48WOGLW2byOeywwxS+gKr/SulSFhQUWMTTKER5rUiRON+rVy/ttOutTFz79u2bsY3+Xm7eX7Fihev2RXXh33//rX61Ya6bdkwzVd2+LHJTR2rz0UcfLeRDM1ypPnbhwIEDylVXXWVJS38z69ats0vmGJ/U/q41itzYmb+a0WxHbpOceFIeNKM422xkEmP5Ar5akanYQpj3b0a3E0LoZRn9zYYZSOwXfY1UuXJl5ZdffgmzKOQVkACE+YAAkRwEQCAaAjQbpWjRopaHSb5gj8IXvVJeeOEFZfz48epsKXrorF69uuVa7UFLv6WHMa++GaNpIXLNBQEI87mg7lwm/Xjhi/Ia/n5pRhJfPMk2IV9ky5JG/3dO+15ewmG8sUUd+IQM9qVGwMaBTZmoDCDMJ8pcjpWlr23M47vomO7vJIDQDMDly5erf/PDhw9XTj/9dEN6cjezdOlSxzK9nJw0aZIh/yhFea1eZnG+YsWK2inXW5m40osVkU3NcWPHjnXdviguXLt2rXLnnXdmrSs9w4TZx8iNzwknnCAst2nTpsrcuXMNzaWXUeTW4vLLL7ekob8Tv6I8FZLU/r59+3bllVdeEfra1/cz+hIh24uVESNGWLjq86D9mjVrKi+//LKyf/9+g22SegBhPpjlyMWMuY+Iju++++5gBQlS33///ZayH
3vsMcGViMolAQjzuaSPskEABIQEaLaC+ZNgEu7oR4RTmDx5skJ+0kQ3On0c+RZEyE8CEOblszt99qv9fdKMxtdee81VJenHjmgmmJYXzQhzEzDeuKHk/5pc25dqDhv7t19SU0KYT6rlrPXes2ePUqNGjcx9QhvjvW5p8cuRI0eGPiNx9erVmdn4JJ7G5QtdL85Tf/caZOJKPtv1XzTY2XbgwIFemxno+sWLFyvt2rVTyBUmuYTI9rWeud40s7pRo0Zq+qCLg9LM11atWtn+HdSqVUu58sorVVcr5cuXF17Xpk0bhdzQBAlJ6++dO3dWiI0X25UsWVI544wzFOKlX4OCfIdfd911ih1fs/3pmPyMN2zYUKHfqEkOEOaDWY9elpUtW1b4d2nuN37dcdrVcObMmZZy6eU1glwEIMzLZQ/UBgRAgBMYM2aM4QZCMz7+/PNPV2x27NjhajYLbkiucKbuIgjzcpmUZiUdddRR6t87LX40ceJETxWkH/M0U9D8UKsd79y5M2t+GG+yIvJ9gQz2pcrDxr5NmNiEEOYTazphxWmG72WXXeZJXKP7QOnSpVUxc/r06bauboQFeoz8+uuvFZrNTYuYxhnoSyCabU7Pvn6CTFxnzZqlVKhQwfZ+TvZ85pln/DTTdxpasFx7ngi6LV68uO966BOSQEqisZf6kIsM86x6fZ5e95PU3518bLth2LNnzwweu0Vj3eQTttiaqVRMOxDmg4PWfu9k6y/kjnPhwoXBC/y/HDZu3GgYL+j3lltdJbRKIKOsBArRFbxzIIAACICANAT4gw/jbizU+vBZUmzRokWM+0fzVL9rr72W8c8tbdNwv4yMf4poex4n0kmAf3nBVq5cyahf8Vk/6Wxkglr1/PPPM/7ZsFpjLmow7pfTc+25SyvbdPxTcsa/vnHME+ONI55AJ2WwLzUANg5kxkQmvvDCC9m8efMYF2YZf0GUyDa4rTR/8cTat2+fubxMmTKMv5TMHKdpZ82aNYzau2rVKrZ582b1H/f5zPhCsIy/pFX/VapUiXHhklEfqF+/PuMzsdOEIJK2yMKV+y9nfOKM+pzG1wFgfLYxu+mmm9i+ffvUdnNxWbVrJBASlil/qcL4LGzGX86ofwfcXzTjL2jU30v8i0HGvyBm3P0n47PoWdWqVRPWOlRXNgJ81jXjiwgbqkV/rzT2IrgjwF88MrpfuQnHHXcc41/sMLqfhRH4i0HG/derWXE3S+o9NIx8kUd4BCDMh8cSOYEACIRA4LvvvmN8doOaEz2Qf/HFF6xOnTqec+azlhhfbZxxn4LCtBdffDH76KOPhOcQmV4CEOblsi33/cuWLVvGuEsa9Qemn9rxWfOML4wmTErjB4mydgHjjR2ZcOJzbV9qBWwcji2TlguE+XQK80nrh6hvMALc/QPj7h8YCYAU6JmejhFAAATiJQBhPjhvL8I8lUbPMXxdPMbX3AtcuL7sM888U530GDhTZBAugaxz6nEBCIAACMRIoF+/fpnPrYIuTGJ2X8BHz0ze5GsUIf8IwJWNPDbngrz690ifdm7ZsiVQxSpXrpz529b/nW/bts0xX4w3jngCnZTBvtQA2DiQGRObGK5sEms6VBwEMgSmTp2aubf78aOfyQg7IAACgQjAlU0gfGpivSsb+j1K6w/of7OI9mnh1jACrXeh5d+4ceMwskQeIRPAjHneQxFAAATkIcAfvNXPMunN7tq1a9XP0P3WjvtPY3wBH3bw4EFLFvSJrPZJl+UkIlJLADPm5TEtuVqYMGEC4w+njD8kBqoY5UGuDfTh2GOPZdyvoj7Kso/xxoIktAgZ7EuNgY1DM2miMsKMecyYT1SHRWUtBOgZnb6g1dwOcn/vjPtKt1yHCBAAgegJYMZ8cMb6Wett27ZlfDFndsMNN2TNeNy4cezmm2/Oep3TBeS+RhtLL7nkEsZftDhdjnO5IBCy0I/sQAAEQMA3AVqokX+upb7Rffzxx33no0946qmnZt4Q8zE2s09vqRHyjwBmzKfP5nv37lW4D+HM37b2d37FFVc4NjYp481ff/2l3H777UrdunUV7rPdsU1hnuQuAxT+w0Etl/uxDTNrT3n5tS8VkhQbewKCi10RoNm1NBbQ4p9pD6NHjzaMf9zHfNqbjPalnMDff/+tXH/99Zl+3aZNm5S3GM0DAbkJ5GrGfJqegfUz5un5mkK3bt0y45z2+8W8PeKIIxTubz5QByFPAVq+fK2AQHkhcTQEWDTZIlcQAAEQ8E6Az3hVTjjhBIVWIw/q2kIrnW582o1Iv61du7Z2CbZ5RADCfPqMzWfRCf/G+ewex8YmZbyZPn26oX0DBgxwbFcYJ//44w+FL5qbKbd69ephZOsrD7/2pcKSYmNfYJDIkQCEeUc8OAkC0hIgUf6OO+7I3H9q1Kih8IV+pa0vKgYC+UAgV8J8mp6BRcI8/6pfad68eWa802sV+n3SR37//XffXQ3CvG90sSWEKxve4xGSR4BWov/ggw/YypUrGa1Cz/0IMz5DiFWsWJGdeOKJ6ir0DRs2DGWxDD0dKo8WDF2/fj3jwjGjRQdpQQ5ymUD/uP8uRouK0jbM8O2337I333yTffXVV+zll19mfHC2ZE/npkyZkqkbLYZIDGiBjwYNGjBajdsukMsBr4FcxDgtRkJuZPbt25c12yOPPNJyDaXjb4ct8X4ievbsyYYMGWJJymfTsmnTplniEZFuAnBlkz77Dho0iPXu3dvQMD67jk2aNMkQZ3cg+3izcOFC9fN9/mSYaUL//v1Znz59Msdh7mzdupU1a9ZMXZRXy5cW4F6+fLl2GOs2qH2psrLbOFageVIYXNl4f67Lk66BZkpMgO4/5NqB3GZQoN9T8+bNE/7ukbgZqBoIpI5ArlzZpOkZ2OzKhrQdCqRjkWZDLnydAj2bkwsa/pWw02XCc3pXNnzGPJsxY4bwOkTmkEBsrwBQEAgEJECfMr366quKnWsS/mdkeNtYrlw55dFHHw08y4I+5+/Vq5eif9NoLkt/XKhQIYVmav373/9Wdu/e7bvVa9asUbggoXD/ioZ2mT9lmjt3rnLGGWcYrtHXh/b5ywplyZIlwrrQm1rz9W6OW7RoIcxPi3z22Wdd5Us2ijLceeedwnrcd999URaLvCUlgBnzkhrGZ7X4Q6zCXxIa/sa5iKwUFBT4zDFYsqjGm5EjRxraSGM0LWoadjDPlKdy+A8JhYvyYRflKj/Z7EuVjsrGroDgItcEMGPeNSpcCAJSEKDfPVWrVs3c62rVqqV+9SRF5VAJEMhzArmaMU/Y0/IMLJoxr3Ures42/54R6TFdu3bVknja6nUsuLLxhC62i+HKJjbUKCgIgalTpyr0Kb1ogCpfvrxSuXJloY9hur5UqVLK2LFjPRdPLwKefPJJhfIXlUt+S+kBUuTbWLueBsHPP//cVdmHDh1SFi1apDz22GOqT18tD/NWE+b379+vPPzww0rhwoWF9TOnI/9kb7zxhqUu9Mkon52i0Kei5jSiY2ov/eDlM9Eteekj5syZo1x++eWONxk+k195/fXX9clC
32/SpImwXePHjw+9LGQoPwEI8/LbyG0Nacw0/32fddZZConLuQrm+mhjaBjjzb/+9S/LWNa3b9/QmiqbKC+jfQl2lDYOzZjISH1Oob8/+JhHZwCBZBD47rvvFJrcRGtN9ejRQ+FfOiWj4qglCOQBgVwK84Q3Dc/ATsI8tXHixImW53ztd4R+y2faT5ZzuAAAQABJREFU0+WeAoR5T7hycjGE+ZxgR6FuCRw4cEDp0qWLZZA6/fTTFVrsimaza4G7lVFuvfVWy7XaQDZw4EDt0qxbWjCOr1htyYu7q1Fnwi9btkwhQZsCPTiOGjVK4S4yLNdT2SSck4BOQrpd+P7775VjjjlGmF6rv7YlYZ7aevbZZ7u6Xkun1WXFihXCatCLiE6dOjnmST9wqa5eAtnooosusuTLXTF4ycb3tfwzWEvZ3H2OQgsKIuQfAQjz6bH5E088YfjbpvUkgnylFAaZqMeb4cOHG9pM4/rjISyULZsoT7aQ0b5Ur6htTGUgBCeAGfPBGSIHEIibwFNPPWX7hW/cdUF5IAAC/yOQa2GeapL0Z+Bswjy1kV5K6rUb0T7NrCctykuAMO+FVm6uhTCfG+4o1QUB7qNcueyyyyyD04033mg7i4LE5fr161vSaKI03VSyBVpgSOQuh2Z3b9q0yTY5vURwejFAbdHEfHMm9PlSiRIlFBINua98Yf21gXnChAkWtzq0IMhVV12lvkzIJvATP6cgeiGhlU318xPGjBljaNN5552n0GzIqMOvv/4q/KKAXkAg5CcBCPPJtzuNo4888ohhTOG+ExV6oZrLENd4M2zYMEPbaXymL638BtlEeVntS3zjsrFfWyLd/whAmP8fC+yBAAiAAAiAQBACMgjzVP8kPwO7EebpGZg8Dmjai92WPB1w3/SuTQph3jWqnF0IYT5n6FGwEwESbdu1a2cYlOjzxsGDBzslU8+98MILhnT6AY1m2jsFEvbPP/98S3paLdvtTMzOnTtb0mt1oJnzdkET7akO11xzjW0eWl60pUH2/ffft4jczz33nFKsWDFhHuSKhi9ia1cNdUa8nXseivezIjj56NfX+7PPPrMtP8wT5IpIX662r7kDCrMs5JUMAhDmk2Enu1rSp+7nnnuu8O+avoThCz3bvri1yzOs+DjHm6FDh1oY8MVgPTdFNlFeZvsS3Dht7NmYSGAgAGHegAMHIAACIAACIOCbgCzCPDUgqc/AboR5ah95G6AJR5puYbd1mvRJ+egDhHk9DTn3IczLaZe8r9WAAQMsg9FDDz3kigvNaCMRXzSIkVsZvfsbc4ZUhjkdpXESss15kLBOM8LN+WjHIj/v5jzo8yTtertt+/btlT179piTZo5feeUV2zyyLRroNPOfFrX1GvRud2g/rnDKKadYGNBLFoT8JQBhPpm2//bbb5Wbb75Z9T1rNyZq8fTV1MaNG2NvaNzjDb2E0NqsbeklqNsgkyifBPsS17ht7NaWuM5KAMK8lQliQAAEQAAEQMAPAZmEeap/Ep+B3Qrz1D6aqELrJGrP93bbbGv+UV4UIMz/l4PM/0OYl9k6eVq3devWKbRQqX4AIl9aJLi7DXohWJ8P7dstxkpiuEjQJ7/FXsOPP/5oKyCRyxoSRLIFWvzIXHft2M2XA+Rax86tDQnvToH8yNstKksufbwEurFo9aYtueKJI8yfP99QLpVN7NesWRNH8ShDUgIQ5iU1jKBa9OJx3Lhx6mKborFZP66Y96tUqRKrn9pcjTd0LzC33c1DugyifJLsS90zVzYW/GkgygUBCPMuIOESEAABEAABEHBBQDZhnqqctGdgL8I8te/dd98ValP65376ffTOO+/Q5Y4BwrwjHilOQpiXwgyohJ4A+UrXDzi0T/6EvYSZM2da8tDypIVTReG6666zpKHBjvy/+wm33XabJT+tDuQfLVvQD95aOtr27t07W9LM+fvuu09Yh8aNG2eusdsR2UGrx6JFi+ySWeK7du2aqQO9KKAvCuIIt99+e6Zcrd60aAxCfhOAMJ8c+9u9WNT+nrNtSZy3G+/DppDL8WbQoEGWsY4Wj7ILMojyVLck2Zfqm0sbU/kI3ghAmPfGC1eDAAiAAAiAgB0BGYV5qmuSnoH12o7biZ/kpjLb753SpUsrK1assDOdGg9h3hGPFCchzEthBlRCI0Czmc2DD81y9uPX/J///KclLzuxYtWqVcIZ4rSwht/www8/WMrX2latWjXbhWC18o4//nhh+pdeekm7JOtW5ION6lC1atWsackPvFZf8/auu+7Kmp4uIBG+UqVKmXzczOR0lXGWi5YuXaqY/eSfccYZysGDB7OkxOm0E4AwnxwL05ocV199tdKqVSt11ny9evVcfdapH69atmwZeYNlGG9E7t8effRRS9tlEeWpYkmxL9VVBhtTPRDcE4Aw754VrgQBEAABEAABJwKyCvNU56Q8A/sR5mndRfoto/9tI9qvVauWsmPHDlsTQpi3RSPNCQjz0pgCFSEC/fv3tww8tNCfn7B//351oTaaCX/PPfeoi6Ta5UMzqUWD3MUXX2yXxFU8CUmifClu6tSpjnmEIcy//vrrwvJJtHYTzjrrLGF68nlWUFCQNYspU6Zk0tPXB2vXrs2aJugFdAMz+/gnV0jkqggBBCDMJ7sP0N83iaQPPvigcvjhh2fGF7txluK9fOHjlY5M443o/tm9e/dMk+xEeZnGRtnsS/BksnHGmNjJSgDCfFZEuAAEQAAEQAAEXBGQWZinBiThGdiPME9t27lzp6L9fnX6vUMTmeiZVRQgzIuoyBUHYV4ue+R9behtn3nAoQX/og6XXHKJpVyqR4cOHQIV3bdvX2G+lHfHjh0d8w5DmJ8+fbpt+Y6F/9/JsWPH2qZ3M3OfbhCaPYlxHOHVV1/NlEllk6/89957L46iUUYCCGgPNkG+hklAM/Oiij/99JO6IGw2//Pt2rWLjIds4w0t7K2NudqWXMElQZQ3G0kG+1KdZLOxmROOxQQgzIu5IBYEQAAEQAAEvBKQXZin9sj+DOxXmKe20fp/ZcqUsTzja8/62pa0J1GAMC+iIlcchHm57JHXtSFfwNqgot/aDTBhwaJFUs2LzWrl04rfQcLs2bOFbaL8L730UseswxDm33//fdvyHQv/v5P01UGFChWEeZx55pmOWfzyyy+GBXAnTZrkeH0YJ7du3Wqp78iRI8PIGnmkhACE+ZQYUteMF154QThGaeP4kUceGYkbK1nHm8cff9yRB3GhcV2mmfI6c1p2c2VfqoisNrZAQoSFAIR5CxJEgAAIgAAIgIAvAkkQ5qlhMj8DBxHmqW3/+c9/XC0GK/LKAGGeCModIMzLbZ+8qp2diE3uWKIMP//8s62I8eKLLwYq+scff7TNm74OcAoyCPNUP3KFoAlc5u3XX39t2wRa4Fa7/uijj1b+/PNP22vDOmFewJcWv0UAAT0BCPN6GunZt1tPQxuDvvrqq9AbK/N489hjj2XGX42Btk2SKK8ZLRf2pbJltrHGBlsxAQjzYi6IBQEQAAEQAAGvBJIizFO7ZH0GDir
MU9tEXwVoz/faliYkrVy5ki7PBAjzGRTS7kCYl9Y0+VexZ599VigkzJo1K1IY5H9YG8jM2yeffDJQ2fv27bPNm/wj2/kBo0JlEebXrVsnXBiXWDm54znllFMybRctQhgIrCCx2bccLf6LxV4FoPI8CsJ8ejvA9ddfnxlzzGP5mDFjQm14Esab9u3bC3kkZaa82WBx2pfKToKNzYxw/D8CEOb/xwJ7IAACIAACIBCEQJKEeWqnjM/AYQjzpB21bt1a+Hyv/+1Tu3Ztw3qAEOaD9P540kKYj4czSnFBoFu3bsJBhhYQjTI4uXuhN65Bg50rGBo8t2zZYpu9LMI8VdBuNXDydbZ7925LGxYsWJCxJfl/XrNmjeWaMCPITY7ezzT9IN+zZ0+YRSCvlBCAMJ8SQwqasXz5csM4oH9ApQW+wwpJGG82bdqk1KxZMzMO61mQz/kkhrjsS2ySYOMk2jDOOkOYj5M2ygIBEAABEEgzgSQJ87I+A4chzFMf27Vrl0LCu/7ZXrTfpk2bTJeEMJ9BIe0OhHlpTZN/FevUqZNwgBk1alSkMN5++21huTTA3XPPPYHLrl+/vm3+v//+u23+MgnzM2bMsG0DLYxnDnfddVfm+ubNm5tPh3pM7nRKlCiRKa9hw4bq6uWhFoLMUkMAwnxqTClsiN1LxB49egiv9xqZhPFm8+bNimghdf1DexxfMXll6+b6qO1LdUiCjd2wyvdrIMznew9A+0EABEAABMIikBRhXuZn4LCEebLpqlWrlLJly2b0D/0zvn5/8ODBaheAMB/WX0J0+UCYj44tcvZI4O677xYOLl26dPGYk7fL582bJyyXBrUmTZp4y0xwtZ2QUKRIEeXvv/8WpPhvlEzCPH02ZTf78qyzzjK0gWaq61cNpxcfUQVaYFbPqU6dOupifVGVh3yTTwDCfPJt6NQCWpNE/0Cq7ffs2dMpmatzSRhv6AeJ1se1th977LFK27ZtLVySKM5HaV/qBEmwsavOiosUCPPoBCAAAiAAAiAQDoEkCPOyPwOHKcyTVcnrQ+HChS3P99rzP23pPE2whDAfzt9BlLlAmI+SLvL2RMDsz1UbVM4++2xP+Xi9eMOGDbYDWqVKlbxmZ7nezi8uLYjqFPSCs8aCti+99JJTMsM5Jzc9hgtdHIwYMcKW09KlSzM5jB07NnMd8Ytq0ddt27YpdevWzZRFYpSTa6BMBbGT1wQ00bJGjRp5zSGtjV+4cGFmTNCPm0OGDAnU5CSMNyQqn3zyyYb2kyhPi5BT6NOnj+Ec8UmaOB+VfYlPEmxM9URwRwDCvDtOuAoEQAAEQAAEshGQXZhPwjNw2MI82YxmxOt/74j2y5cvbxDwL7300mzmxvkcEIAwnwPoKFJMYPz48cKBpVixYuosNnGq4LE0a/2www4Tlk2D22+//RaokMaNGwvzbtCggWO+sgnz27dvN7iM0Q/8epc/+vZ2797dsY1+TxYUFCjnnHNOhuuJJ56o0AsWBBDIRgDCfDZCyT5P4qp+bNL2J06c6LthSRhv6KWkfsFtardelNcaLxLnoxqntTLD3EZhX6pfEmwcJsd8yAvCfD5YGW0EARAAARCIg4DMwnxSnoGjEObJ9tdcc43wt4/2G8i8hTAfx1+M9zIgzHtnhhQREfjss89sB5X7778/lFLpR/3OnTsteek/7zEPXhMmTLBc7yXCLu8HH3zQMRvZhHmq7J133im00ZFHHqkutrp69erM4ou0GCsdhx3279+vNG3aNFOPKlWqRL64bNhtQH65IwBhPnfs4yiZFkQyj+F0rM0a91qHJIw3bn+QaG0XifNJWRA2bPsSkyTYWLMdtu4JQJh3zwpXggAIgAAIgIATAVmF+SQ9A0clzO/evVshd76i3z+iOAjzTj09d+cgzOeOPUo2Edi4caPtgEIz2n/++WdTCm+HNGideuqpSsWKFZX169cbErdu3dq27FtuucVwrdeDI444Qpj3Bx984JiVjML8kiVLhG2hQX/06NFK7969M+ebNWvm2D4/J//66y+lVatWmTLIlt9//72frDJpyNXO8uXLM8fYSTcBCPPptu+CBQsy44P2MFquXDmF1snwGpIw3vz6669K7dq1DW0WzZQ3t10/VmucHn74YfNl0h2HaV9qXBJsLJ0RElIhCPMJMRSqCQIgAAIgID0BGYX5pD0Dk0sZ7ZmbXB2HGdasWaPo89fKEW0hzIdJPry8IMyHxxI5BSRAwknJkiUzA5Z5ILnjjjsClXDDDTeoeR9zzDHq7G59ZtOmTbMtl/ykOy3Sqs/HvP/HH38I8yWxnhZJdQoyCvNU3/POO0/YJloEVl/noF8amNmQDTQbUt8gsU3v2958vdtj8rFM+T3//PNuk+C6BBOAMJ9g47mo+iuvvGIZn+69914XKY2XJGG8oR8k9LJZf690I8prLRWJ8926ddNOS7kNy77UuCTYWEojJKRSEOYTYqiQq7lp0yZl/vz5ypQpU5Q33nhD+eijj5SVK1cqe/fuDbmk/MoOXLPb+8CBA8o777wTeMJQ9pK8X3Hw4EHl66+/Vv8mBg0apNx1111K8+bNlZo1ayq0xls+B/otOWDAAOX8889XnnnmmXxG4dh22YT5JD4Dly5dOvPMTpNCww5koyJFimTK0P8+0O9DmA+bfDj5QZgPhyNyCYlAx44dbQeTokWLKosXL/ZVEt1otQHphRdesORBP9DJT7l2jXk7ffp0Sxo3EZMnTxbm2blz56zJ9SK3vj65WvxVqzD90NHXR7RPM9npATXMoO8bpUqVUr744ovA2c+dO1ddDIXc7vh1dRG4EsggVgIQ5mPFHXth5PbMPCb5eYEn+3hDa5+cdtpphrZ6EeU1w4jE+a5du2qnpduGZV9qmOw2lg5+wioEYT5hBgtQ3bVr1yoPPfSQZZ0N872A+sSYMWMU+oIWITsBcM3OiK6gL29ffPFFpWrVquo9WRbXcCTEP/HEE0qLFi2UMmXKGJ4XtL+NwoULJ24ReHdWsb+KXOLRizt6BjD/1pbFdva1z90ZmYT5pD4DH3744Zm/Q/q7jCIMGzYsU4b2d27eQpiPgnzwPCHMB2eIHEIksGLFCsfBhHxzeXU7MmLEiIzf81q1aqmfrouqTA8v5oFLO7744otFSbLG0WwELQ9tS2553CxUan5Y0NJ7EebphYKWzrzNWnmbC0hwp68IzPnpj8N2iUAPSlr+9LXBJ598YlM799HkAoe+nqB8o3C7474muDJOAhDm46Qdb1k0e4bWu9DGCtpeccUVnish+3hDP0jMviT9iPIaGJE4TyKXbCEs+1K7ZLexbOyTWB8I80m0mrc607pR7du3V2jijn7cz7ZPvyVImEMQEwBXMRdzLLlCo6+4zBO7cinu0popY8eOVRo2bGj7N1GvXj2lR48eyvvvv6/s2LHD3KzUHdPkuy+//FKhLwUuvPBCpVixYrZscmk72cHLIswn9RmYXgjr701nnnlmZCZv166doSx9ubQPYT4y9IEyhjAfCB8SR0HgkksucRxMypYtq4waNSpr0fRw0qVLF0Nes2
bNsk33+++/q+5RzIOXdrxs2TLbtKIT9OkszcTW0mtbN7PlKb/KlStb0lIeXoT5N998U5gH5RMk9OrVyzZfavOqVauCZG9ISw9SGjt6qZHNN78hseCAHqTJzc7RRx+dyXfixImCKxGVRgIQ5tNo1f+2idydaWMFbenLGq9rk8g+3tB9KkxRXusNInE+2wLlWtq4tmHYl+oqu43j4pn2ciDMp9vC5K5Gm6GsH/e97JMrQwQjAXA18hAd0e8IWlerevXqhmcOre/lQtylOj355JNKhQoVhHUiV7F33nmnKlCL2pTmOHK1qtkm2zYXtksKexmE+SQ/A5M2ou9/pENEFchd8umnn24oT182hPmoyAfLtxAl54ZCAAFpCPA3+IzPcsxan3PPPZfxmTLs6quvZnyxC/V66s6//PILGz9+PHv66afVfS2jW2+9lb322mvaoXDLbzqMf1rEuL97y3nuW53NmzeP8c/+LOdEEfylgFoH/bm6deuyhQsXsuLFi+ujhft0DZ+dbjnHhXnWoUMHS7wo4t///jfj/pVFp1hBQQHjwpXwXLZIPuOfVatWjfFZCJZLmzRpwmbPnm2J9xPx3HPPMe66wJCU8vcTyKZ8FhDjiwyz7du3Z7Lgs//VOD6DIhOHnfQSOPnkkxl/acZq1KjBVq9end6GJqRl/OsX9uqrrzIuBrCTTjqJNWrUiN1zzz2M/7j01IKvvvqKnX322Yaxm7/AVe8RbjOSfbzZunUra9q0KeNfjWWaxGfKM2JI7IKGPn36sIEDBxqyofvYU089ZYjzciCTfanestvYC1tc60yAz4xUn9m4T1fGXRc4X5zws9w9i2Gs424r2M6dOxPeKvvq07P6lVdeKXxGtk8lPkP3H/7ST3wyz2LB1dng9JuHu/Nk3B+54/MjF3fZ0KFDnTML8Sw9P3Xq1MnwbKBlT79t+IQwRvd3/kWhFp1XW2LzzTffMO5GhP3000+Mu2iybX/ctrOtiIQnZs6cybiga6hZEC3BkJGLgyQ+A+ubRWPHTTfdpI9ifMFWxl/wGeLCOqC+zr+cYXy9Q0uWZMcZM2ZY4hGRYwLBdH2kBoFoCHBhxvYtH/+TsZyjhUDpU0IuZlvO0fW1a9dW+M3DVWWHDBkizIPy6devn6s8+IsBy+IbNHvzhx9+cJWePi0UtZPiaIEat8H8xYA+T1q9O0igRUv0+Wn7NEs/jECfYoq+ONDKCWvbvXv3MKqLPBJCADPm5TEUfdZpdj1Df9c0VvIfkQr5AXUTuPibcUuljQv8R7GbpJlrZB9vaCFx+vxcax9tg7ivyTTctCOaOU/3ET9BJvtS/WW3sR/GSGNPADPm7dkk+Qx/CauQS0P9WEj7ZG9a6Ju/SFT4xBzVfzSfUJP1OZJ8/nr9sirJ/OzqDq52ZP67UPjrr7+ukDtUc78THcc163rfvn3K3XffbdvH6Qt0ctuJ8D8CfJKW0rNnT1s7xmW7/9UoOXu5nDGfxGdgs2XJp7x5vHCrK5nzcnv88ccfW/QoqgNmzLslGO91wfxZxFtXlJZHBMgfnJ3wax7Ush2TeMFnxnqid91111kGT60c8p9ON3a7QIvLmgVlPqNf+fzzz+2SWOJffvll2/LpQctNoMWInHzB04+XIIHcAmlMtC19QhnGoq/vvvuu8EailRPWluzktW8EYYa0uScAYT73NtBq8NZbb1nGEP3fNq0BQT+G7cZbih88eLBhrCB3V/Q5t5cg+3hDL5UbNGhgYBWFKK8xE4nz5L7Ma5DFvlRv2W3slS2uz04Awnx2Rkm7Yu/evQr/6s0wFpLYQf6j7cKnn35qSaO/z9A++dvO5wCu9tafPHmyof8UKVLEdhKY1q/iEHe3bNminHPOOYa/Ba18WtDV63OQPYH0naFnKtHLPeIXh+2SSjRXwnxSn4E1O5ObKRLgtb9P/ZYmIr3zzjvapZFsSe/Rl0n7EOYjQR04UwjzgREig6gI0EyA888/3zKYmAcXp2OaYfjjjz96riK9GKDFYEnkEeVfv359hfyS08OsFpYuXapwFzOW68kHJi1q6zbQYjzaoqSisimuW7duCr09tgu0MAotemiXnuJp8augvtXNP5CoXkED/1TOlrtTe/yco8V5EfKLAIR5eezNP6N0HKO0v2myGX0pRIv10eLfn332mTJ8+HCL/0TuKkyhcdhLSMJ4M2nSJAOnKEV5jZ1ZnK9YsaJ2yvVWBvtSZZNgY9dQcaFrAhDmXaNKzIX6L1pJIOXuKV3VndacuuqqqwzjqHZ/oS09c+dzAFd765N/clos9JprrlHoZTN90UyTAmgNFn0f0u9HLe7Sc9AJJ5wgLJ++Qgy6Fpc9jfScMf9+1ewXte2STDBXwnwSn4FJR6IvsWjMIM1I619221tuuUVdjHndunUKpQ073HzzzYY6QJgPm3A4+UGYD4cjcomIALkyoJl6tGiN3WAmiqc3kCNGjFAOHjwYqGa04CuJPaIyKI5mXNPiHSVKlLBcQw9yXbt2Vbifz6x1oHZedtllynHHHWfJx65s+vyW6taxY8dM/vSpPsUVLVrUdT40q/7iiy9WaEEVr4HEMa1+NEMjqHscWjBXxFIrI+zt22+/7bXJuD7hBCDMy2NAWpyI+/rPjCF+/765f3Vl5MiRnr/WScp4Q1/1kAhFfOjHuJ+XzX6srhfnSeT0GnJtX6pvUmzslS2uz04Awnx2Rkm6gp6TaUIJjYP0fO11YglNWKEXjHb3GTfP6kni5bau4OpM6qOPPlJodro5cH/Xlq+jtb4VpbhLrin4uhnCfkxfZ3/77bfmquJYQICvUydkGKXtBNVIVFSuhPmkPQP/61//8qTDaOOGtqWvOWiWfZiBJrueccYZmT4PYT5MuuHlBWE+PJbIKUICfEFXhfzO08O4NnCJtiQO06wGvsBnaLUh1yx8EUGlWbNmCuUvKlcfV7lyZbWuXvz6kSiuz8PLPl80JNNW8qvsJa3+WnpL6zXQDx3Nrz+5HgoaaGajvk5R7tMLFXL3g5BfBCDMy2VvGnfopaTZ/Ve2v336YdqqVStl+vTptq5usrU0SePN119/rfpI//XXX7M1K9Tz9HUCuROiWYJ+Qi7tS/VNko398EUaewIQ5u3ZJPEMX7g583xIk1D8hHHjxmXyMN9jvH5t5ad8GdOAq3+rkBBu7kd0HJW4+9133wnX5aEy6bfYvHnz/Dcmz1I2btw4VtulAW+uhHlil6RnYCd3xKLxwhxHboHtXHgG6Ufr16/PuDiGMB+EZHRpC1HWvEMggEAiCHABW11Znc+CY/Rv06ZNjAs0jIvh7LTTTlNXC+czaiJrC39BwObOncu48M82bNjA+CwKdZV37lqA0T/uOodx9zuMC/iR1UG2jAcOHMj4gz2bM2cOO+WUU2SrHuoDAgYC/PNVdezgM7UZn4VhOIeD3BHgX9uwMWPGsFWrVrHNmzer//iMNMa/fmJ8lqP6j3/dw/iMD3bhhRcy/mko47PIc1dhlOyJAOzrCRcuDoEAjRNcqFKfEfms4BBylDcLGjvbt2+fqWCZMmUYnwGeOU7Dzumnn
874V6yMu6Rh3O+3rybxySSMT8oQpv3iiy/Y2WefLTyX5khw9W9d7qpU/S1ozoEL84xPlDJHBzqm37/UP/nLbks+fGID418AszZt2ljOIUJM4KKLLlJ/z5vPRmE7cxlJPeYTHVSdRV9/7v9dfU7Xx2FfXgJ8zRXWt29f1rBhw9DHKHlbnZyaQZhPjq1QUxAAARAAgYAEIMwHBIjkIAACIJAAAhDm0yPMc5/a6sQXmnjDZw3biutuuiX3J69OqjFfu23bNlauXDlzdKqPwTWYeeMS5vmX26xp06Zs/vz5wgrfd9997NlnnxWeQ6SYAIR5MRenWAjzTnRwDgSCE4AwH5whcgABEAABEEgIAQjzCTEUqgkCIAACAQhAmE+PME+z/ydMmMC4KzrGXVAE6BVMzYO+zNIH+uKVvoTNtwCuwSwelzDP13thgwYNElaWr03GVqxYoX4ZJLwAkUICEOaFWBwjIcw74sFJEAhMAMJ8YITIAARAAARAICkEIMwnxVKoJwiAAAj4JwBhPj3CvP9eYEzJF8BTBcy///7bcOKKK65g06ZNM8R5OTh48CDr0KEDW7x4Mbv77rsZXxPLS3Lf1/K1P1inTp1UYbZfv36qmx/fmQVIGBXXAFWKJWkcwvzatWtZ7dq1Gc2aF4X33nuP8fV2RKcii0tDf4cw7717QJj3zgwpQMALAQjzXmjhWhAAARAAgUQTgDCfaPOh8iAAAiDgigCEeQjz5o6yYMECdt5555mjGQlOzZs3t8S7jXj//fcZiftaGDBgAKNZzlEGWoOlWbNmqt99Kqd69eqM1vLIRYiKay7a4qXMOIT51q1bsylTpgirde6559q6txEmCCkyDf0dwrz3zgBh3jszpAABLwQgzHuhhWtBAARAAAQSTQDCfKLNh8qDAAiAgCsCEOYhzJs7CrkDMQvmtGDmpEmTzJd6Ol64cCEjkVRRlEy6/v37sz59+mSOw9wxi/KUd506dRj5jM9FiIprLtripcyohflZs2Y5vjB68803Wdu2bb1UOZRr09Df/Qjz9GXIn3/+6ZlhqVKlWJEiRTynky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK/vzOvWrVPF6z179mSiScymBTVJOAsann76adalSxdDNuRe5rHHHjPEBT0QifIVKlRgs2fPVtsXNH+v6aPm6rU+cV4ftTDfokUL9sEHHwibVKVKFfbTTz+xYsWKCc9HHZn0/u5HmG/fvj0bM2aMZ7Q33ngje/311z2nky0BhHnZLIL6pI0AhPm0WRTtAQEQAAEQsCUAYd4WDU6AAAiAQGoIQJiHMK91ZprJTm5f5syZo0Wxs846i5FLjqOOOioTF3RnxIgRrGvXroZs+vbtyx5//HFDnN8D2UT5uLj65RV1uiiF+S1btjBa2NW8HoLWpp49e2YWhD106BD79ttv2eeff84o3a5du9QZ2pUqVWJHH300a9SoEatRo4aWNLRtkvu7H2GeXArNmzeP0cuoDz/8kO3du9eWZcWKFVX3UuRiisYa80s724QSn4AwL7FxULV0EOA3VQQQAAEQAAEQyAsC//jHP+hbc4X/SMmL9qKRIAACIJCPBC644AJ1rC9dunTqmz969Gi1rXRvo39lypRJfZu9NPCJJ54w8OHuP5Tdu3d7ycL1tcOHDzeURfbgwrzr9HYX/vHHH0q9evUMefOZ8gp3X2OXJPL4OLlG3hgfBRx//PEGe2h/f4888oiP3IxJnnzySWHeWhmffPKJUlBQoIwcOVKpVq2a47WFChVSGjdurLz22msK/2LEWFDAo6T2d+KhsdRv3dpu8ODBwvR8rQqFf70SkKqcyfnLCEubqQ8igAAIhEOA/OEhgAAIgAAIgEBeEIAwnxdmRiNBAATynACE+TzvALz5fLaxQkKbXnirWbOmsnPnzkjhDBs2zFAmlc9d2vguUzZRPldcfQOMKGGUwnzdunUtfUjrx4cddpjy3HPPKXw2vO012rXmLb2ofP7550MlksT+HkSY51+uKLVq1TKwL1++vMJ9/ofKVbbMIMzLZhHUJ20EIMynzaJoDwiAAAiAgC0BCPO2aHACBEAABFJDAMJ8akzpqyHfffedwhdkNYhnmkh55JFHKkOGDFH4Yo6+8naTaOjQoZay+WKwbpIarpFNlM81VwOcHB9EJcyvX7/e0ne0vhvWduDAgaHSS1p/9yvMb9++XTnnnHMM9rnsssuUX375JVSeMmYGYV5Gq6BOaSIAYT5N1kRbQAAEQAAEHAlAmHfEg5MgAAIgkAoCEOZTYUbPjeC+tpWbb75ZKVq0qEE8Ewma9evXVzZu3Oi5DLcJSPw3l9urVy+3yRWZRHmZuLoGGPGFUQnzkydPtvQbcz+iY3JZ1bFjR4UvxKpMnTpV4QvFKtOmTVP4ugbKSSedlDWPbt26hUooSf3djzBPL0xOPfXUDFdyEUR/z/QFST4ECPP5YGW0MZcEIMznkj7KBgEQAAEQiJUAhPlYcaMwEAABEMgJAQjzOcGek0LJb/a4ceOUJk2aKCSWiURMu7gqVaooS5YsiazeIl/UfOHOrOXJIMrLzDUrwBguiEqY7927t2MfJkG+f//+Cs3edgovvfSSUrJkSce8SNQPMySlv3sV5mkth2OPPTbDklwCvfvuu2Gikz4vCPPSmwgVTDgBCPMJNyCqDwIgAAIg4J4AhHn3rHAlCIAACCSVAIT5pFrOe72POeaYjGBmJ8A7xZM4/9tvv3kv2GWKQYMGWerXo0cP29QyiPJUOdm52gKM6URUwnyLFi0s/UXrv8WLF1cWLFjguoVLly5VZ9Zr6c1bEpsPHDjgOj83Fyahv3sR5ufMmaOQ+yuNHY0Xy5Ytc4MiVddAmE+VOdEYCQkUojrxgQYBBEAABEAABFJP4OSTT2YrV65kNWrUYKtXr059e9FAEAABEMhHAhdeeCGbN28e4zMb2a5du1KNYMyYMax9+/aZNvIZtYwvcJo5TvsOX+CVrVmzhh08eJAVFBSwbdu2qce7d+923fSWLVuy//znP66v93oh9+nNuI95Q7JHH32Ucfcfhji+sCRr1qwZ48JfJr5ChQps9uzZrE6dOpm4OHaSwDUODnZlVK1alW3YsMFymrhxn+uWeLcRDRo0YPwrDuHlHTp0YHwmvPCcXeSMGTPY5Zdfbneavfzyy+yuu+6yPe/nhOz9/aKLLmJz5861NM1suwkTJrBbb72V8ZcX6rW1a9dmxJO/lLGkTXvEzJkz2aWXXmpoJo23pUqVMsThAARAwCcBCV8WoEogAAIgAAIgEAkBzJiPBCsyBQEQAAGpCGDGvFTmiL0yhw4dUmi28IMPPqgcfvjhmdmu/Oey7f6iRYsirSe5HzGX371790yZdjPlZZqdKyPXDMCYd6KaMV+tWjVLP9H6zaxZs3y1slGjRrZ58okqCn+p5Stfp0Qy93c3M+ZHjBhhcI1F9xT+0s+pyak+hxnzqTYvGicBAbiykcAIqAIIgAAIgEA8BCDMx8MZpYAACIBALglAmM8lfbnK/umnn9QF
YbP5n2/Xrl3kFe/Xr59FIOWzdG0XepVJlDfDkYmruW5xHEclzJctW9bSRzRhnl7e+AkTJ060zZPypoVjowiy9ncnYZ5ePj300EMGXldeeaWyf//+KBAlJk8I84kxFSqaUAJwZcPvRgggAAIgAAL5QQCubPLDzmglCIBAfhOAK5v8cWXjtqe/+OKLrFOnTraXcz/SjFzJFClSxPaaME707duXccHSMStyX/Pxxx+zunXrOl4nw0lZuMbNIipXNtT/uDhsac5hhx2WcaliOZklgtw88Zn4bOPGjcIrR44cyR544AHhuaCRMvZ3O1c2Xbp0YVu2bGFvvfVWptnkQopcC0U9LmQKlHQHrmwkNQyqlRoCEOZTY0o0BARAAARAIBsBCPPZCOE8CIAACCSfAIR5CPOiXjxs2DDG3ceITqlxX331FTvjjDNsz4d14vHHH2fc1YcwuySJ8loDZOGq1SeObVTCPF/gVSjAlytXTl0/wW/bBg8ezHr16iVM3rlzZ/b0008Lz4URKVt/txPmRW0tUaIEmz9/PqtXr57odN7EQZjPG1OjoTkiAGE+R+BRLAiAAAiAQPwEIMzHzxwlggAIgEDcBCDMQ5i363Nt27ZltKijKNBCurfddpvoVOhxd9xxBxs9erQlX1r4NQkz5c0Vl4WruV5RHUclzB933HFs06ZNlmrTFx07duywxLuNoAWEaWFhUfjnP//Jpk6dKjoVWpxM/d2LME8AqlevzuilHb0cydcAYT5fLY92x0UAwnxcpFEOCIAACIBAzglAmM+5CVABEAABEIicAIR5CPN2neybb75RZ79yN7SWS4YPH866detmiQ87YvPmzYzEwR9//NGSNfc5z4YOHWqJlz1CBq5xMopKmKeZ2cuXL7c0pWjRouyvv/6yxLuN+P/s3Q2QFMX5x/EHDuTtuAQQjFIo5IDIq6BRkSoUUAuREAmUJRGiAgYLhVIU5UUgoLyphFCppFJGCCpBIEUiEIxAwruAQF14E5R3BQOCHggHIgjMf3qT3f/O7e7d7uzs7vT0d6uo252d6en+PH1n8tvZnr1794p9n6W4uzdv3lx27doV9z0vNvptvqcazCuDrl27ypIlS6RixYpekGjXBsG8diWjw5oJEMxrVjC6iwACCCDgXoBg3r0dRyKAAAK6CBDME8yXNVd/+tOfxr1CeOTIkaKW/Mjk49ixY6FQXgWliR4jRoyQyZMnJ3rbt9tz6ZptlEwF8507d5ZVq1bFHc6FCxdErTXv5nHu3DnJz8+Pe2iNGjXk7Nmzcd9Ld6Mf53uiYP6Xv/yl2Dc5lcOHD8cd9pgxYxIuQRX3gABtJJgPUDEZii8FCOZ9WRY6hQACCCCQCQGC+Uyo0iYCCCDgLwGCeYL5smbknDlzpG/fvjG7jBo1SiZOnBiz3asNKqTs1KmT7NmzJ9Jk/fr1pUOHDo4bTqo3dQznc+Uawczik0wF8w8//LDMnTs37kjU/PnBD34Q971kNtauXVtOnToVs2vNmjXlzJkzMdvT3eDX+Z4omFffVvnZz34m6r8f8b6dUKFCBVm0aJF07949XRrtjieY165kdFg3AftrfDwQQAABBBAwQsD+Gq/67rpVWFhoxHgZJAIIIGCigB10hv7W24FT4Idvr1MeGqv6b5v6V1BQEPgxpzvATZs2OczCdvZV6uk2nfB4O6S07IsDHOe1Q3nLXs4mdIx9Na7jPdUnO5xP2J4f38iFa64cGjRoEFMvVTM73E2rS7/5zW/itqvaXrt2bVptt2zZMm7bjRo1SqvdeAf7eb7fddddcR3CtZs6dWrc91UN7LX+LfvbLvGGHOht9jcJYkxKSkoCPWYGh0A2BSSbJ+NcCCCAAAII5FKAYD6X+pwbAQQQyI4AwXx2nHU9y8mTJ2NCJhW6/eUvf8nIkL744gurWbNmjnNGh/Lhk8YL54cPHx5+2/c/s+2aS5BMBfOJPtxQ83PGjBlpDfnWW291zEHVpvqntnv58Pt8Ly+Yv3LlimVfFR/XSnmpDzjspX+8JPN9WwTzvi8RHdRcgKVs7L+uPBBAAAEEzBBgKRsz6swoEUDAbAGWsmEpm7J+A+wrPcX+ZkHMLupmrI0bN47Zns6G48ePh5av+fjjjyPNqOVrVq9eHfdcY8eOlZdffjmyr3qiyw1hs+nqAMrBi0wtZaOWULGvypbz58/HjErdmFjdoNjto2HDhvLZZ5/FHK7uDaCWaPHiocN8L2spm/CNl+0PmaRt27YJ15t/6KGHYpaf8sLPr22wlI1fK0O/AiOg+QcLdB8BBBBAAIGkBbhiPmkqdkQAAQS0FeCKeW1Ll5WOb9y4MeZq2Fq1alnqSlkvH3ZIaTVv3txxrnhXypc+5+jRox3H2MGD9fzzz5fezXevs+Xqh4Fn6op5NbZEV3T/+Mc/TmvoVapUiZlXam7ZYXRa7YYP1mW+J/INL2UTHs+GDRusSpUqxTVTbmrJG1MeXDFvSqUZZ64EWMomV/KcFwEEEEAg6wIE81kn54QIIIBA1gUI5rNOrtUJ33jjjZiw7amnnvJ0DCqkbNGiheM8yYTy4U7EC+ftK6bDb/vyZzZc/TLwTAbzM2fOdMwbFQKrf/bNR60TJ064IiguLo7bpmp38+bNrtqMPkin+Z5sMK/G9+qrryZ0y8vLs1auXBnNENjnBPOBLS0D84kAwbxPCkE3EEAAAQQyL0Awn3ljzoAAAgjkWoBgPtcV8Pf5hwwZEhO2bdu2zbNOq/C09I02Uwnlwx2JF84/99xz4bd99zPTrn4acCaD+XPnzoVuMhoO5KN/vvnmm64Ydu7cGTPnVbt16tSxLl265KrN8EG6zfdUgnn1LZpu3brFtVN+devWtY4cORKmCOxPgvnAlpaB+USAYN4nhaAbCCCAAAKZFyCYz7wxZ0AAAQRyLUAwn+sK+Pf86speew1vR9CmgjevHiqkbNWqlaN9N6F8uD/xwvlnn302/LZvfnrh+tVXX4VuwGuvsW+9/vrr1p49e3wzvtIdyWQwr86lvsERHciHn998882lu5LU6ylTpsRtT91wOJ2HjvM9lWBe2ah5majeqi633Xab9e2336bD6PtjCeZ9XyI6qLkAwbzmBaT7CCCAAALJCxDMJ2/FnggggICuAgTzulYu8/0eMGCAI6DMz8+37BtienLiL7/80tNQPtypeOH80KFDw2/74me6ru+9917o6u1wAB3+qdbo3759uy/GGN2J0h/uhPtbep3y6GNSea4+lEi0JvyKFStSaSp07wT7psaOea/6W61aNddL46gO6DrfUw3m1VjXr19f5nrzjz32mNotsA+C+cCWloH5RIBg3ieFoBsIIIAAApkXIJjPvDFnQAABBHItQDCf6wp4e/5Vq1ZZffr0sRo1amTde++91vjx40OhYKpn2bJli1WxYkVHQKnW8/bioa6qbd26taPtdK6UL92neOH8M888U3q3lF77xVVdba9uvhsOt0v/LCgosPbt25fS2DK5s7pKvHQfw6/TrUl0vydOnBj3POobGefPn4/etcznqs7h/kX/VHPK7UPH+R4e6+2
33x7Xo7wPVdRNcqP9Sj8fMWJE+BSB+0kwH7iSMiCfCRDM+6wgdAcBBBBAIHMCBPOZs6VlBBBAwC8CBPN+qUT6/Th79mzM0jMqEFNXuqtlOM6cOZPUSVavXm1de+21jmBNBW1ePFRIedNNNzna9jKUD/fRy3DeT65z5sxx2JUOPNVrdZWzXx4ffPBBwv727NnTs25+9913llq6Jp7H4MGDkzqPaqNLly4xbdxxxx2Wes/NQ8f5Hj3OZs2axXgo4379+kXvFvNcrTd///33xz02XKPJkyfHHBeEDQTzQagiY/CzAMG8n6tD3xBAAAEEPBUgmPeUk8YQQAABXwoQzPuyLK46NW/evDKDMBW2//nPfw4t1xHvBCpMmzRpkpWXlxdp56qrrrKmTp0ab/eUt5WUlMSEp5kI5cMdixfOv/jii+G3k/7pJ9dx48ZFahMOOOP93L17d9Ljy9SOly9fDn1rI17/1Da1FnmyHxYl00c15muuuSauj1rOSPUn0ePChQtWjx49Yo5VvzOHDh1KdFiZ23Wd7+FBffLJJzHfmgnXUi2bVJanaqO4uLjM9eZVW4888oj1zTffhE8ZiJ8E84EoI4PwsQDBvI+LQ9cQQAABBLwVIJj31pPWEEAAAT8KEMz7sSru+rR06dKYYDEcpEX/VP99VzcNXbhwobVjxw5LXdX82muvWW3atHEcr5ab2bZtm7vOxDlqwYIFjvYzGcqHT186nK9bt274raR/+slVfbASXctEz996662kx5eJHQ8ePGg9/vjj5fZVXVXt5RxTy/jccMMNcc/buXNna82aNY7hqg+jdu7caXXt2jXmGPV74jaUVyfRdb6fOnXKeuONN6x4a+1Hzzf1TYTyPliZNm1ajGt0G+p5kyZNrD/+8Y+BuSkswbzjV4wXCHguUEG1aP/x4IEAAggggEDgBW688Uaxb6glhYWFsn///sCPlwEigAACJgrceeedsm7dOqlZs6bYIUugCWbNmiX9+/ePjNFej1tOnz4dea37E/vKU7HDdDlw4EBaQ7EDObFDNxk0aJDYV8yn1Vb0wapfdtgp9pW2Yoen8q9//UvUuTL9sJfxkQkTJoROY38QJWvXrk3plH5ytW8iKvZV3CHDsgahxmt/O6CsXTx979///rfY36wI/Q2xbxAsu3btUhc1Jn0O++r50JywbxQr9ocnon5X3T6OHj0amruLFy+O20TTpk3FXqJF7OVp5MMPP5STJ0/G7NerVy95/fXXpU6dOjHvJbtBt/n+9NNPi/0hlNgfbiRduxo1aoj6/wsNGzYU+0OWyN/X4cOHy6effhr6HY/nG8+wevXq0qJFCxk1apTY316It4sW25YvXy72kkiOvtrfnhB7STHHNl4ggIA7AYJ5d24chQACCCCgoQDBvIZFo8sIIIBAigIE88EJ5lXpVRimAnX7qs2kwzV1nPpgplOnTvLEE0+IffWwVKhQQW32/LF161axr1CW++67T+rVq+d5+4kaXL9+fcjmJz/5iajwN9WHn1xXrFghvXv3Fnv98oTD+O1vfytDhgxJ+L7Xb2zcuFHat2/vSbNVq1YV+4atabelAlIV8hYVFSXdlr2evEyZMkXU30UvHjrNd/tGufLRRx+5Hraytm/CGzq+Xbt2smnTJldtzZgxQwYMGODqWD8cRDDvhyrQhyALEMwHubqMDQEEEEDAIUAw7+DgBQIIIBBIAYL5YAXz4UmqrtZVVx3v3btX1BXE6p+95nPoqk11RbL6p4LxW265JRRCtm3bVuy15cOH8zOBgF9c1RW49jJEoW82qm81qKuN+/btGwm07SVbPAuXE1Bos9lejkbeffddsZdsCv0eHDt2TL7++mupVatW6NsH1113nXTs2FEeeOABuf7667UZFx31pwDBvD/rQq+CI0AwH5xaMhIEEEAAgXIECObLAeJtBBBAIAACBPPBDOYDMDUZQgoC9lrp8v3vf19UYK8e9jrhodcpNMGuCCDggQDBvAeINIFAGQIE82Xg8BYCCCCAQLAECOaDVU9GgwACCMQTIJgnmI83L9iml8CSJUuke/fuoU67WUdfr9HSWwT8K0Aw79/a0LNgCBDMB6OOjAIBBBBAIAkBgvkkkNgFAQQQ0FyAYJ5gXvMpbHz3L1y4IC1btpT9+/eHLDZs2CBqrXQeCCCQfQGC+eybc0azBAjmzao3o0UAAQSMFiCYN7r8DB4BBAwRIJgnmDdkqgdymGoJm4cffljmz58fGl+vXr1kwYIFgRwrg0JABwGCeR2qRB91FiCY17l69B0BBBBAICUBgvmUuNgZAQQQ0FKAYJ5gXsuJS6dFhfIDBw6UmTNnhjQKCwtl3bp1oRuawoMAArkRIJjPjTtnNUeAYN6cWjNSBBBAwHgBgnnjpwAACCBggADBPMG8AdM8cEMsLi4OXSmvQkD1aNCgQSiUv+GGGwI3VgaEgE4CBPM6VYu+6ihAMK9j1egzAggggIArAYJ5V2wchAACCGglQDBPMK/VhKWzcvDgQenUqZMcPnw4pNG0aVNRN39t0qQJOgggkGMBgvkcF4DTB16AYD7wJWaACCCAAAJhAYL5sAQ/EUAAgeAKEMwTzAd3dgdzZLt37w7d7DUvL0+ef/55GTt2rFStWjWYg2VUCGgmQDCvWcHornYCBPPalYwOI4AAAgi4FSCYdyvHcQgggIA+AgTzBPP6zFZ6GhaYPn26dOzYUdq0aRPexE8EEPCBAMG8D4pAFwItQDAf6PIyOAQQQACBaAGC+WgNniOAAALBFCCYJ5gP5sxmVAgggED2BQjms2/OGc0SIJg3q96MFgEEEDBagGDe6PIzeAQQMESAYJ5g3pCpzjARQACBjAsQzGecmBMYLkAwb/gEYPgIIICASQIE8yZVm7EigICpAgTzBPOmzn3GjQACCHgtQDDvtSjtIeAUIJh3evAKAQQQQCDAAgTzAS4uQ0MAAQT+J0AwTzDPLwMCCCCAgDcCBPPeONIKAokECOYTybAdAQQQQCBwAgTzgSspA0IAAQRiBAjmCeZjJgUbEEAAAQRcCRDMu2LjIASSFiCYT5qKHRFAAAEEdBcgmNe9gvQfAQQQKF+AYJ5gvvxZwh4IIIAAAskIEMwno8Q+CLgXIJh3b8eRCCCAAAKaCRDMa1YwuosAAgi4ECCYJ5h3MW04BAEEEEAgjgDBfBwUNiHgoQDBvIeYNIUAAggg4G8Bgnl/14feIYAAAl4IEMwTzHsxj2gDAQQQQECEYJ5ZgEBmBQjmM+tL6wgggAACPhIgmPdRMegKAgggkCEBgnmC+QxNLZpFAAEEjBMgmDeu5Aw4ywIE81kG53QIIIAAArkTIJjPnT1nRgABBLIlQDBPMJ+tucZ5EEAAgaALEMwHvcKML9cCBPO5rgDnRwABBBDImgDBfNaoORECCCCQMwGCeYL5nE0+TowAAggETIBgPmAFZTi+EyCY911J6BACCCCAQKYECOYzJUu7CCCAgH8ECOYJ5v0zG+kJAgggoLcAwbze9aP3/hcgmPd/jeghAggggIBHAgTzHkHSDAIIIOBjAYJ5gnkfT0
+6hgACCGglQDCvVbnorIYCBPMaFo0uI4AAAgi4EyCYd+fGUQgggIBOAiYH85UrV5bevXu7KtesWbMkLy/P1bEchAACCCDgb4Hi4mIZOXKkfPvttyl1dPfu3VJUVOQ4pqSkRPLz8x3beIEAAu4ECObduXEUAggggICGAgTzGhaNLiOAAAIpCpgczKdI5dj94sWLooJ9HggggAACwRN45513pE+fPp4MjGDeE0YaQSAkQDDPREAAAQQQMEaAYN6YUjNQBBAwWIBg3l3xCebduXEUAgggoIMAwbwOVaKPJgoQzJtYdcaMAAIIGCpAMG9o4Rk2AggYJUAw767cBPPu3DgKAQQQ0EGAYF6HKtFHEwUI5k2sOmNGAAEEDBUgmDe08AwbAQSMEjApmN+xY4eMHz9eLl++nHaN//rXv7LGfNqKNIAAAgj4U2Dt2rUybdo0Tzo3b948qVq1qidt0QgCpgsQzJs+Axg/AgggYJAAwbxBxWaoCCBgrIBJwbyxRWbgCCCAAAIIIIBAAAQI5gNQRIaAAAIIIJCcAMF8ck7shQACCOgsQDCvc/XoOwIIIIAAAgggYI4Awbw5tWakCCCAgPECBPPGTwEAEEDAAAGCeQOKzBARQAABBBBAAIEACBDMB6CIDAEBBBBAIDkBgvnknNgLAQQQ0FmAYF7n6tF3BBBAAAEEEEDAHAGCeXNqzUgRQAAB4wUI5o2fAoRIdIkAABWuSURBVAAggIABAgTzBhSZISKAAAIIIIAAAgEQIJgPQBEZAgIIIIBAcgIE88k5sRcCCCCgswDBvM7Vo+8IIIAAAggggIA5AgTz5tSakSKAAALGCxDMGz8FAEAAAQMECOYNKDJDRAABBBBAAAEEAiBAMB+AIjIEBBBAAIHkBAjmk3NiLwQQQEBnAYJ5natH3xFAAAEEEEAAAXMECObNqTUjRQABBIwXIJg3fgoAgAACBggQzBtQZIaIAAIIIIAAAggEQIBgPgBFZAgIIIAAAskJEMwn58ReCCCAgM4CBPM6V4++I4AAAggggAAC5ggQzJtTa0aKAAIIGC9AMG/8FAAAAQQMECCYN6DIDBEBBBBAAAEEEAiAAMF8AIrIEBBAAAEEkhMgmE/Oib0QQAABnQUI5nWuHn1HAAEEEEAAAQTMESCYN6fWjBQBBBAwXoBg3vgpAAACCBggQDBvQJEZIgIIIIAAAgggEAABgvkAFJEhIIAAAggkJ0Awn5wTeyGAAAI6CxDM61w9+o4AAggggAACCJgjQDBvTq0ZKQIIIGC8AMG88VMAAAQQMECAYN6AIjNEBBBAAAEEEEAgAAIE8wEoIkNAAAEEEEhOgGA+OSf2QgABBHQWIJjXuXr0HQEEEEAAAQQQMEeAYN6cWjNSBBBAwHgBgnnjpwAACCBggADBvAFFZogIIIAAAggggEAABAjmA1BEhoAAAgggkJwAwXxyTuyFAAII6CxAMK9z9eg7AggggAACCCBgjgDBvDm1ZqQIIICA8QIE88ZPAQAQQMAAAYJ5A4rMEBFAAAEEEEAAgQAIEMwHoIgMAQEEEEAgOQGC+eSc2AsBBBDQWYBgXufq0XcEEEAAAQQQQMAcAYJ5c2rNSBFAAAHjBQjmjZ8CACCAgAECBPMGFJkhIoAAAggggAACARAgmA9AERkCAggggEByAgTzyTmxFwIIIKCzAMG8ztWj7wgggAACCCCAgDkCBPPm1JqRIoAAAsYLEMwbPwUAQAABAwQI5g0oMkNEAAEEEEAAAQQCIEAwH4AiMgQEEEAAgeQECOaTc2IvBBBAQGcBgnmdq0ffEUAAAQQQQAABcwQI5s2pNSNFAAEEjBfo2bOnHD9+XOrWrSsLFy403gMABBBAIIgCjz32mOzbt0+qVq0qK1asCOIQI2M6cuSIbN26NfI6nSfdu3eXChUqpNMExyKAAAII+FTgP//5jxQVFXnSu27dukleXp4nbdEIAqYLEMybPgMYPwIIIGCQAFfMG1RshooAAsYKmHTF/KxZs6R///6e1PrixYtSuXJlT9qiEQQQQAABfwm888470qdPH086VVJSIvn5+Z60RSMImC5AMG/6DGD8CCCAgEECBPMGFZuhIoCAsQIE8+5KTzDvzo2jEEAAAR0ECOZ1qBJ9NFGAYN7EqjNmBBBAwFABgnlDC8+wEUDAKAGCeXflJph358ZRCCCAgA4CBPM6VIk+mihAMG9i1RkzAgggYKgAwbyhhWfYCCBglADBvLtyE8y7c+MoBBBAQAcBgnkdqkQfTRQgmDex6owZAQQQMFSAYN7QwjNsBBAwSsDkYL6goEBOnjzpqt7cyM8VGwchgAACWghYliVXrlxJua/Lli0TdbPX6AdrzEdr8ByB9AQI5tPz42gEEEAAAY0ECOY1KhZdRQABBFwKmB7Mnz592qUchyGAAAIIIOAUWL58uXTp0sWxkWDewcELBNISIJhPi4+DEUAAAQR0EiCY16la9BUBBBBwJ0AwTzDvbuZwFAIIIIBAaQGC+dIivEbAWwGCeW89aQ0BBBBAwMcCBPM+Lg5dQwABBDwSIJgnmPdoKtEMAgggYLwAwbzxUwCADAsQzGcYmOYRQAABBPwjQDDvn1rQEwQQQCBTAgTzBPOZmlu0iwACCJgmQDBvWsUZb7YFCOazLc75EEAAAQRyJkAwnzN6TowAAghkTYBgnmA+a5ONEyGAAAIBFyCYD3iBGV7OBQjmc14COoAAAgggkC0BgvlsSXMeBBBAIHcCBPME87mbfZwZAQQQCJYAwXyw6slo/CdAMO+/mtAjBBBAAIEMCRDMZwiWZhFAAAEfCRDME8z7aDrSFQQQQEBrAYJ5rctH5zUQIJjXoEh0EQEEEEDAGwGCeW8caQUBBBDwswDBPMG8n+cnfUMAAQR0EiCY16la9FVHAYJ5HatGnxFAAAEEXAkQzLti4yAEEEBAKwGCeTOD+aNHj8pnn30mJ06ckHPnzkm9evXk+uuvlwYNGki1atW0msN+6iyuzmr4xePw4cOi+qL+fffdd9KwYUNp0qSJ1K5d29lhXiUU2L59u/z973+XpUuXykMPPSRDhgxJuK/JbxDMm1x9xp4NAYL5bChzDgQQQAABXwgQzPuiDHQCAQQQyKgAwbw5wfyhQ4fkd7/7nbz//vvy8ccfJ5xXHTp0kP79+8uDDz4oNWrUSLgfb/xXAFfnTPCLx4YNG2Tu3LmyaNEiOXLkiLOT/3tVq1YtadeunTz77LNyzz33xN3H1I0lJSWycuXKUBD/3nvvOQxfeOEFeeWVV0ylKXPcBPNl8vAmAmkLEMynTUgDCCCAAAK6CBDM61Ip+okAAgi4FyCYD34wf+rUKRk2bJi8/fbbcunSpaQnS506dWTmzJnywAMPJH2MSTvi6qy2Xzz27NkjKjhevHixs4PlvGrbtq2MHj1aevbsWc6ewXz7ypUrUlRUJP/85z9l2bJlsnHjxtC3C+KNlmA+nsp/txHMJ7bhHQS8ECCY90KRNhBAAAEEtBAgmNeiTHQSAQQQSEuAYD7YwbwK13r37i1qKQ+3jxEjRsjkyZPdHh7I43B1l
tUvHrNnz5YBAwYkDJSdvY7/6umnn5Zf//rXkpeXF3+HgG69/fbbZfPmzUmNjmA+MRPBfGIb3kHACwGCeS8UaQMBBBBAQAsBgnktykQnEUAAgbQECOaDG8yrq17V1e4XLlxIa46og2fMmBEKPNNuKAAN4Oosol88JkyYIGPGjHF2zuWrbt26ycKFC6VSpUouW9DvsEGDBsnOnTulSpUq8umnn8rBgwcTDoJgPiGNEMwntuEdBLwQIJj3QpE2EEAAAQS0ECCY16JMdBIBBBBIS4BgPpjBvFqSQq0Vf/78ecf8UNtat24tjRs3looVK4bWmt+xY0do2QrLshz7Rr9QYd3evXtDN4iN3m7ac1ydFfeLxx/+8Ad58sknnZ1L89XIkSNl0qRJabai5+Hqb4Fa1ifR+AnmE9eVYD6xDe8g4IUAwbwXirSBAAIIIKCFAMG8FmWikwgggEBaAgTzwQvmVRh/8803yyeffBKZG/fff7+MGzdObr311si26Cfr1q2TgQMHOo6Jfl89NzmoVOPHVSn8/8MvHuvXr5eOHTvG3D9BbVO/B02aNJH8/PzQB0vqinD1QVRZV4OHR6g+uFLrrXfu3Dm8yaifZ8+elXr16sV8uKcQCOYTTwWC+cQ2vIOAFwIE814o0gYCCCCAgBYCBPNalIlOIoAAAmkJEMwHL5ifMmVKKERXE0Otk61u4Proo4+WO0/UkjdqPXq1hEe8x7XXXitHjx6N95YR23B1ltkPHuqGpSp83759e6Rz6n+/Tp8+Xbp06RLZVvqJmuNPPfVUufNZrbv+4Ycflj7cmNfNmjWL+2EdwXziKUAwn9iGdxDwQoBg3gtF2kAAAQQQ0EKAYF6LMtFJBBBAIC0BgvlgBfMlJSXSqFEjKS4ulsqVK8ucOXPkwQcfTHqOfPnll9KiRQtRP+M9Tp8+LQUFBfHeCvQ2XJ3l9YvHm2++Kf369Yt0rmvXrrJo0aLQ3I9sTPBEzeVhw4aF7p+QYJfQZnWFfatWrcraJbDvtW/fPrTMVekBEsyXFvn/1wTz/2/BMwQyIUAwnwlV2kQAAQQQ8KUAwbwvy0KnEEAAAU8FCOaDFcz//ve/l8GDB4fmyFtvvSWPPPJIyvNl9uzZCY/btm2b3HTTTSm3qfsBuDor6BcPdUX75s2bQ50rLCyULVu2SK1atZydLeeV+uBqwYIFCfcaOnSoTJs2LeH7QX5DLQe0Zs2amCESzMeQRDYQzEcoeIJARgQI5jPCSqMIIIAAAn4UIJj3Y1XoEwIIIOCtAMF8sIL5Nm3ahJb16NGjh7z77ruuJsuJEyfkmmuuiXusWtZDhaGmPXB1VtwPHrt27ZKWLVuGOla9evXQkjNurmw/fvy4qCVbTp065Rzk/17dc889obXm474Z8I0E86kXmGA+dTOOQCAVAYL5VLTYFwEEEEBAawGCea3LR+cRQACBpAQI5oMTzKslN9TV7HXq1BEVWiYK15OZGGo9+S+++CJm15MnT6Z8RXJMI5ptwNVZML94vPTSS/KrX/0q1LmxY8fK+PHjnR1N4VXpJXGiD23cuLHs27cvepMxzwnmUy81wXzqZhyBQCoCBPOpaLEvAggggIDWAgTzWpePziOAAAJJCRDMByeYV2tmz58/X370ox/JXXfdlVT9E+2k2ti7d6/j7fr168vnn3/u2GbCC1ydVfaLR4cOHeSDDz6Qq6++Wg4ePCg1a9Z0djSFVxcvXpQaNWrIpUuXYo666qqrRN0Y2cQHwXzqVSeYT92MIxBIRYBgPhUt9kUAAQQQ0FqAYF7r8tF5BBBAICkBgvngBPNJFTyJnc6fPx8KOS9fvuzYu1u3brJkyRLHtlReqNBz4MCBUlRUJE888YQ8+eSTqRzuet+vv/5aBg0aJLt37w5dVa2W+cnFI1OuuRiLF+dMx+PMmTOhb4aoOaWumh83blzaXVLL4qhvmpR+qGVyzp07V3pzua+DMN8J5sstc8wOBPMxJGxAwFMBgnlPOWkMAQQQQMDPAgTzfq4OfUMAAQS8ESCYJ5gvPZM2btwo7du3L71ZVOB07733xmxPdsM//vEPUeF++PHyyy/L6NGjwy8z8rO4uFjuvvvu0Lr76gQ//OEP5cCBAxk5V3mNZsq1vPP69f10PNTSMmouquWWPvvss7SWbQr7/PznP5d58+aFX0Z+Nm/ePG5gH9khwZMgzHeC+QTFLWMzwXwZOLyFgAcCBPMeINIEAggggIAeAgTzetSJXiKAAALpCBDME8yXnj8TJ06MCcx79eolCxYsKL1rSq83bdokd9xxh1iWFTlOrRM+ZsyYyGsvn5QO5VXb6uagao30XDwy5ZqLsXhxTi881FX31apV86I7MmrUKJk8eXJMW26/KRKE+e4mmFc1UUsDpfrIz8+XvLy8VA/z3f4E874rCR0KmoD9PyJ4IIAAAgggYISAvb6s+n/OVmFhoRHjZZAIIICAiQL2Os2hv/X2+syBH/6f/vSn0FjVf9vUv4KCgsCPOdUB2mt1W/Za2w4nO8y2SkpKUm0q7v7Tp093tK3qYN+0M+6+6Wz86quvLPtGuI5z2WuRW3Yon06zro/NtKvrjuXoQD96PP744475Ev47MXjwYNdKus93+14VcU1eeOGFhCb9+vWLe0zYM9HPPn36JGxTpzeWLVsWM36v/n7q5EBfEciUAFfM239FeSCAAAIImCHAFfNm1JlRIoCA2QJcMc8V8+HfAPv/RIeWfVm1alV4k9x2222iluSoU6dOZFu6T6ZNmybPPfecoxm1RrhaK9yLR7wr5dUNQleuXBm6Yt6Lc6TSRrZcU+lTLvf1q0fnzp0leu6HjWbPni19+/YNv0z5p87z3c0V8wsXLpR169bJoUOHxA6p5ZtvvkloVrdu3dDyUmqJKfW35plnnkm4ry5vcMW8LpWin9oKZCrxp10EEEAAAQT8JsAV836rCP1BAAEEvBfginnvTXVtccqUKY4rPXv37m2dPXs2I8N57bXXHOeyAwLLDubTPpffrpRXA8qma9qAWWjArx4NGjSImZPf+973LDtYTltF1/nu5or5aKxJkybFmKrfdfv+AJb9QVn0roF5zhXzgSklA/GpgFoPjwcCCCCAAAJGCBDMG1FmBokAAoYLEMwbPgHs4V++fNlSS1OowCz8r0mTJtbp06czivPqq69Gzhc+79ixY12f02+hfK5cXQNm+EA/exw/ftyqWLFizHwcNGiQZyo6zvd0gnn7mytW06ZNHaa1a9e25s6d65mpHxsimPdjVehTkAQI5oNUTcaCAAIIIFCmAMF8mTy8iQACCARCgGA+EGV0PYhdu3ZZ9g1ZHeFZOCRXVwvbN8O07Js5um6/vANfeeWVmHPbN4Mt77CY9/0WyufaNQYoxxv87jF16tSYeah+D4qKijyV022+uw3mT506ZbVr185het9991nHjh3z1NOPjRHM+7Eq9ClIAgTzQaomY0EAAQQQKFOAYL5MHt5EAAEEAiFAMB+IMqY8iI8++sj6xS9+YVWqVMkRnoVD+eifbdu2tT7//POUz5HsASr8jz6fev7iiy8m
e7jlp1DeT65JA2ZwR108mjVrFjMH1XIrmXjoNN/dBPOHDx+2WrRoEfGsUKFC6PdZfWPChAfBvAlVZoy5FCCYz6U+50YAAQQQyKoAwXxWuTkZAgggkBMBgvmcsOfkpOfOnbPefvttq1OnTpYKy0qH4WW9vu6666ytW7dmrN/x1qIeNWpUuefzQyjvZ9dyATOwg24eGzZsiPldqF69unXgwIEM6Py3SV3me6rB/I4dO6z69etHPGvWrGn97W9/y5ijHxsmmPdjVehTkAQqqMHY/4OFBwIIIIAAAoEXuPHGG2XPnj1SWFgo+/fvD/x4GSACCCBgosCdd94p69atEztAkTNnzgSaYNasWdK/f//IGAsKCsReRz3yOuhP7HBd7KUkXA9THb9t2zapW7eu6zbKOtAOK8W+Ut6xy8iRI0Vtj/ew17CWu+++W7Zv3x55++qrrxb7ppLSqlWryLZMP/G7a6bHX7p93TzU3wT1tyH6Yd+sVYYNGxa9yfPnOsz3jh07ypo1a2LGbt+TQuxleRzbV69eLT169Ij8TVXz4P3335fWrVs79gv6i+XLl0uXLl0cwywpKZH8/HzHNl4ggIA7AYJ5d24chQACCCCgoQDBvIZFo8sIIIBAigIE8+YE8ypMs68ClkuXLokKik6ePBl6ffbs2aRnTffu3WXx4sVJ75/qjhMmTBB7jXnHYSNGjBB7+Q/HNr+E8qpTOrg68DL8QicP9aHOLbfcIvYyKxEV9XrTpk2Sl5cX2ZapJ36f78kG8/Pnz5dHH31ULly4EKJq3ry5LF26VBo0aJApOt+2SzDv29LQsaAIBOnyf8aCAAIIIIBAWQIsZVOWDu8hgAACwRBgKZtg1NHtKK5cuWLZV8FbQ4cOtapUqRJZgsL+/+8Jn2/ZssXt6ZI67qWXXoo59/DhwyPHJlq+xg5ZI/vk+okfXXNp4kcP1af27ds75lqNGjWsbM8jP8/3ZJaymTZtmmNpLPXfFPtDv1xOt5yem6VscsrPyQ0Q+D8AAAD//14ccRYAAEAASURBVOydC9wV0/rHV70l3UkXSqIIRUmnyOlIEopyCZVEioh0ougiiW66SA6d41Ai+qsUoRSKLipJuhGF0kW5dFFvpajWfz1zzJjZM7P37LnsWTPzW5/P+5nZM+vyrO+z3tmzf7PmWYwjgQAIgAAIgEBCCJx55pmcMcarVauWkB6jmyAAAiCQPAL/+Mc/lGt9yZIlY9/5F198UekrfbfRX6lSpWLf52w6+P333/P27dvzAgUKGDipvNRt27Zts6nWVd7HHnvMZMNDDz3Ed+zYwWvXrm04V7ZsWb5q1SpX7eSikExcc9HfTG3IwmPs2LGGcVSwYEH+1ltvZTI/kPOyjvdGjRoZGKnXAPpfPHr0KH/ggQcM56+55hp+8ODBQBhFpdL33nvPwISY5efnR8V82AkC0hMoQBaKfywkEAABEAABEIg9gbPOOoutW7eOCWGeffvtt7HvLzoIAiAAAkkkcPHFF7OFCxcyIcyzvXv3xhrB+PHjWceOHbU+CmGe7dmzR/uMnf8ReO6551iXLl1scZQuXZrt3LmT5eXl2ebx48SAAQOYECzTViVEeTZ37lxWq1attPlkOCkLVxlYkA1h8ti1axcTE1CYeNCj4Rg9ejT75z//qX3O9Y6M4/2SSy5h8+fPN6Ho3r07+/HHH9mkSZO0c+eeey5bsWJF4NcFrUFJd95//312xRVXGKwTwjwrUaKE4Rg+gAAIuCMAYd4dN5QCARAAARCIIAEI8xF0GkwGARAAgSwJQJiHMG81ZIYPH8569epldUo59tlnn7G6devanvfrxKOPPsoef/xxy+qiJMqrHZCFq2pP2NuweLRu3ZpNmTJF637Xrl3ZM888o30Oa0e28W4nzFvxKVasGFu8eDETb7RYnU7MMQjziXE1OhoSAQjzIYFHsyAAAiAAArknAGE+98zRIgiAAAjkmgCEeQjzdmOuTZs2bPLkyZan6e2DDh06WJ7z+2CnTp2YCENkqlaEr4nETPlUw2XhmmpXWJ9zzWPgwIGsf//+WnevvvpqNn36dGlmess03rMR5glo1apVGT20O/744zW+SduBMJ80j6O/uSYAYT7XxNEeCIAACIBAaAQgzIeGHg2DAAiAQM4IQJiHMG832NasWaPMfrWK5jpixAjWs2dPu6K+Hd+2bRsjcfCbb74x1SniXLNhw4aZjst+QAauMjHKJY9p06axG2+8kaljWqyxwWbPns1otrcMSbbxnq0wTwybNWvGZsyYwUTMfhmQ5twGCPM5R44GE0YAwnzCHI7uggAIgECSCUCYT7L30XcQAIGkEIAwD2E+3Vhv2bIle+edd0xZ+vTpw4YMGWI67ueB7du3K6L8+vXrbavt3bs3Gzp0qO15WU+EyVVGJrngQfHPGzZsyA4cOKAgqFevHpszZw6jtSZkSDKOdzth/s4772RikVO2efNmS3SPPPKIbQgqywIxOghhPkbORFekJABhXkq3wCgQAAEQAIEgCECYD4Iq6gQBEAABuQhAmIcwn25ETpw4kd1yyy2mLH379mWDBw82HffrAImUjRs3VhahV+usVKkSoxnO+gUn6VwUxfmwuKosZdsGzYMWKq1fvz7bsmWL0nVaqHTevHmsTJkyUqCQdbzbCfP0tsp1113H6Pvjjz/+MDEsUKAAe+utt1iLFi1M5+J+AMJ83D2M/oVOQLzyhAQCIAACIAACiSBw5plncvHFy6tVq5aI/qKTIAACIJBEAkLoVK71JUuWjH33RZxypa/03UZ/YqZs7PvstYNLly41MFPZiVnqXqu2LS9ESi4mBxjaFaI8F+FslDJiNq7hHNkkxHnb+mQ8EQZXGTmoNgXJY9euXbxWrVramKH7WyHUq02HvpV5vDdq1Ejjpv7v01YI8wq3kSNHWp6nPKVLl+bibZfQ+ebaAPEmgYlJfn5+rs1AeyAQWwKYMS+usEggAAIgAALJIIAZ88nwM3oJAiCQbAKYMY8Z8+n+A3bv3m05q3jKlClKrO50Zd2c++mnn5SZ8l999ZVWnGbK0+zm008/XTtGi3fSIp761KtXL/bEE0/oD0m7n2uu0oL407CgeOzbt481bdqUffLJJ0pLp556Klu4cCE7+eSTpUAi+3hPN2Oe1ncQyh+75pprLMNdEeBzzjlHYV+8eHEpeOfCCMyYzwVltJFkAhDmk+x99B0EQAAEEkYAwnzCHI7uggAIJJIAhHkI8+kGvpjpaRmDmxZj1Qvl6epwes6pSKnWZyXOR2VB2FxyVXnJvA2Cx6FDh1jz5s3Zhx9+qHS9YsWKiihftWpVKVBEYbxnEuYJpHgjgdWpU8c23nzr1q1N4aekcEBARkCYDwgsqgWBPwlAmMdQAAEQAAEQSAwBCPOJcTU6CgIgkGACEOYhzKcb/jTTuEGDBoYsxx9/PNu5cyejONJ+pZ9//lmZKb927VqtSquZ8trJP3dokclBgwYZDj/44INs+PDhhmOyfcgVV9n6bWeP3zwOHz7MWrVqxd5++22lyXLlyrEFCxYwurd1myiW+tdff80
oPr3XFJXx7kSYJxZLlixR4s0Td6skQt6wHj16WJ2K3TEI87FzKTokG4HYBulBx0AABEAABEAghQBizKcAwUcQAAEQiCEBxJiPoVN97NILL7xgipd87733+tgC52LmMK9Zs6ahHX1M+UyN9evXz1BWaAi8Z8+emYqFej4XXEPtYJaN+8njyJEj/Oabb9bGhHiQxFeuXJmlRebstI4Bja0xY8aYT2ZxJErjPVOMeX23xcMwjTlx0v/l5eVx8eaCPnts9xFjPrauRcckIYAZ8+LqigQCIBBNAtu2bWObNm1iNENj//79rHz58uyUU05hlStXZkWLFo1mp2B1oAQwYz5QvDmt/Pfff2czZsxgNWrU8DRbzG+jxY9ntnr1akaxhL///nvDX/v27RnNhETKTEBW/5Ll8HFm/4WdAzPmMWM+3Rjs1q0be+aZZwxZhMjJateubTjm9sMvv/zCLr30UvbFF19oVTiZKa9l/nPHauY8zdClmboypqC5ytjndDb5yaNLly7sueeeU5orUaIEmzNnDrvgggvSNZ/xHM22b9y4sRJTXSxo6jqMU9TGu9MZ8wRQaHasRYsWbObMmZY86a2Fzz//XJr4/pZG+nAQM+Z9gIgqQCANAQjzaeDgFAiAgHwENm7cyJ599lk2a9YsRfiys1DMlmMdO3ZUFvFK0uI8djxw/H8EIMxHfyTQa9cvvvgiGzJkiBL7U4bYuytWrGD0o4V+5H788cds7969JtAFCxZkZOvQoUNN53DgLwIy+pesg4//8lEU9iDMQ5i3G6c0maN69epsz56/GF111VXKg167MtkcJ5GySZMmbM2aNVoxN6K8WthKnH/ggQfYk08+qWaRYusHVwolRLHT161bp0y2IQGVfBXF5AcPtd+0ALAaxogmHtFvIDHrWz3takvha+jh0fbt25XxSkK/mxTF8Z6NME9MaFxSvPktW7ZYIqpfv75y/1ekSBHL83E4CGHemReXL1/OPv30U2eZRa7jjjuOtW3b1nH+dBlpjRS3/8dqvWQL2YQUAgFJZu7DDBAAARBIS0AswsOF0M4LFSpkeI1QXDbTfj7hhBP49OnT09aNk8khgFA20fW1EGw5vRZ+6qmnGv7nhdgdSqcOHjzIX375ZV6vXj2DPfprkph9yfv06cPfffdd/uuvv4ZiZ1Qalc2/xA0+jsroMduJUDZmJjjyPwKdOnUyXLPF7GMu3r70BY8QKbmI1W2oP5vwNXZGWIW1uf/+++2yh3LcK1cxI5nTPbv+O5T2xVtxfNWqVaH0yUujXnmobQ8ePFhjcswxx3AhyqunXG3pu3by5Mm8QoUKWr1TpkxxVVdUx3s2oWxUMIsWLUr7G7RDhw5q1lhuEcrGmVv1/6+p1zK7z17DSKmWTZw4Ufuftmsr03HxUFStDtscE6DXc5BAAARAQGoCixcv5iJEjacvG4qhiAQCEOajNwboR6SYIc+rVq1qeQ3ItTBP9ogwArxs2bKW9og3dPgdd9zBxYyZ6MEOwWLZ/EsI4OMQBoLPTUKY9xloyNV99NFHvF27dvy0007jTZs25Y899hgnUTDbtGzZMi7eXjJcu8eNG5dtNZb5d+zYwWvVqmWo2w9RXm3MSpzv3r27etrVVhauFJ+cYqbbiUalSpXiYjaoqz5mU0gWHqrNItySiYkIPcPd/JEYTQ+NUjmLMKBchI5Tm3S8jeJ4VzsnQgCZuNLYy3Q/OWzYMMty6riN829NCPPq6Em/pf8LGkc33HADL1asWNrxoo6bwoULc/G2bfqKHZwVb2lx+k647rrruAjr66htsqFixYq8ZcuWnB72irdDHLSELEEQgDAfBFXUCQIg4BuB2bNnc/FqoOMvF/VLzmo7duxY3+xCRdEkAGE+On47fPiwMiP99NNPT/v/n+mHlJ89phlTqcKLeq2hG2sR+xcz4x0Cl9G/ZDp87NCBkmeDMC+5g7Iwb9++fbx06dKm7wGa6S7CvHAROsxRbfPmzeMnnXSSoR4S2vxIJMbQG1Lq9wFt/RTlVRv9FOdl4upkpicJy0EmmXhQP+mNvAIFChjGlH58+bUvwuRkjTWK413fybPPPtuS6+23367PZto/evQob968uWVZ1R8iXKGpXBwOQJjP3ov03XTrrbemHS/quKHvph9++CH7RmxK0Fgln51zzjm27dP/AUUVoLxI4ROAMB++D2ABCICADYHPPvuMi1iKpi8U+sF977338qeeeoo//fTT/O677+YXXXRRxhtYEvj9el3ZxmQclpwAhHnJHSTMEwtr8ldffZWLuLKm/331Bla/zYUw/9tvv/G77rrL9hpz+eWXc7HYq/xwJbBQRv8SFvhYgsHhowkQ5n2EGXJVkyZNSvtdQIIGfWfYiQt0XKxJwvPy8rR6KCQIvfnkR8rPz+fnn3++Vjd9PwUhyqu2WonzDz/8sHra8VYmrgMGDDDw03/H6/fXrl3ruH/ZZpSJxxtvvGEYr3oGfu6T8P/tt99mhSqq413tpIitb3prRmVKYZPoHiVdohnFmWYjkxh74MCBdNVE7hyEefcuGzhwoKPrW4MGDfihQ4fcN2RRksT+1LdkaLyfeOKJXKwvYVECh8IiAGE+LPJoFwRAIC0BuqERC3UavsholkK68BBi4UVTGfVmS91SvGek5BKAMC+37998803D/zAJKccee6zhOqD+L6vboIX5H3/8kV944YWWNlBIBL/EHbk94491MvqXegYf++NfmWqBMC+TN7zZQm9Oqtf7dFv6ficBhGYArl69WgkNMGLECH7eeecZytNbTytXrvRmlK701KlTDfUHKcqrzaaK8+XKlVNPOd7KxJUerKTzrXqOZpEHlWThIRbZ5PTgSO1zkFsKC5Vtiup43717t7JOUaa3MLt27ZrxLZxRo0Zl9M8ZZ5zBn3/+eWWtmmwZy5gfwrw3r1CIGSf/yzQJyO903333mdru37+/382gPo8EIMx7BIjiIAACwRCgVwHVLzAS51566SVHDdFifddee61WVq1D3dLMKqTkEoAwL7fv69evzykkDMVmpNlrtGAqzXakuIfq/3DqNkhhnsSdKlWqWLZNoRW8LsImtzf8t042/1IP4WP//SxDjRDmZfCCPzbs37+fV6tWzfI6nPp9kO4zCXKjR4/2fUYizThWZ+PT90UuYqETWb04T+M92yQT159//lljmM6HgwYNyrabjvPLwIMWXnQalzodJ6fnXn/9dcd81IxRG+/dunVT3sDMJiwQrRVUt25d3qpVK65fg4LuN2+66SZepkwZx9cj8me9evU4TUyIcoIw78179FvmuOOOczRu/A69Sw/7Uq8JdO+LJBcBCPNy+QPWgAAICAIUk+2EE05QvkRIpJsyZUpWXOgGn2YPpX4JqZ/37NmTVX3IHB8CEObl9uUHH3ygzF5OtZJeHbb7URWUMD937lxesmRJy+sI/Sj74osvUs3E5wwEZPIvmQofZ3BYhE9DmI+w8yxM37hxI7/yyittvwfU+7vULV3DaVG7mTNn2oa6sWgu60Off/65EhOcFjHNZaIFA2m2OT3EdpNk4jpnzh
BEAABEDAHQEI8+44oRQIgAAIhJkAhPkwew99BwEQAAEQAAEQAIH4EIAwHx9fY6QgAAIgEHsCEOZjfwoAAAiAQAwIQJiPgZMxRBAAARAAARAAARCIAAEI8xFwIoYAAiAAAiDgjgCEeXecUAoEQAAEwkwAwnyYvYe+gwAIgAAIgAAIgEB8CECYj4+vMVIQAAEQiD0BCPOxPwUAAARAIAYEIMzHwMkYIgiAAAiAAAiAAAhEgACE+Qg4EUMAARAAARBwRwDCvDtOKAUCIAACYSYAYT7M3kPfQQAEQAAEQAAEQCA+BCDMx8fXGCkIgICCBFatWsU+/PBDNn36dNasWTP28MMPK9jL6HQJwnx0fBn0SI4cOcKmTZvGatasyei8QVKPAHyknk9U6RGEeVU8gX6AQDgIHDp0iK1Zs4b99ttvbNeuXSxPnjzs9NNPZ1WrVmUnnnhiOAYRg17CTzFwMoYIAjEkAGE+hk7HkEEABHJH4PDhw+zzzz9nH330UUKQX7t2rdaZ2267jb344ovaZ+z4TwDCvP9Mo2bx6NGjib/DYcOGsU2bNrEHHniADR8+PGrDDPV44KNQuy8rnYcwnxXMaAQEfCdA98WffvqpJ7v16tVj559/vuu6Bw8eZE8//TT74IMP2OLFixk97LVLJUqUYLVr12b33nsva9u2LcufP79dMeQFRCAbfnruuec89b5YsWKsQ4cOnuqiEgiAAAhIAhDmJQlsQQAEQCAAAnST//XXX7N58+ax2bNns0WLFjES5+0ShHk7Kv7mQZj3l2eUrB07doxNmjSJDR06lG3YsEEbGoR5DUXOd+CjnLsgNB2AMB8aV6GjIGAg8Prrr7Prr7/ekOf2g9vv6+PHj7OXXnqJ9evXj23bts2t+US5ypUrs27durEuXbqwfPnypVUXhdMjkE0/0RsSXtJpp53GNm7c6KUq6oAACICARgDCvIYCOyAAAiDgLwG6oaRZNjTTw02CMO+GUmZlIMxnxi+KtUnsfeWVV9iQIUPYunXrLEN0+0PfUhEZvhGAj3xDGRtDEOZj42oMNGIEUgnzNFu9fPnyFlGchNVnnnmGXXHFFUmJ0EQZmvn+ww8/GMrVqFGDXX311ezss89mlSpVYps3b06Uodn7S5YsMZSlD5dffjmjvpYsWdJyDBmZE8i2nyhsEefc0vFffvmF/fnnn5Z8mQFhXpLAFgRAIBMCEOYzoYe6IAACIJCEAAnzdevWZUWKFGHFixdnW7ZsYatXr3asAWHeEY1vByDM+4Yy9Ibo73Py5Mls8ODBTB9SyjwwCPNmItn7DB9lj3XUWoIwHzWPYjxxIeAkzNPsdAr3eOONN1pEebdsPv74Y9amTRuD0Er357S+U9euXVmBAgUspv76669EuJs+ffqwAwcOGI43atQoEZ6yYMGChnx8yIyASn4i/8+fP5/RbzT925RyhBDmJQlsQQAEMiEAYT4TeqgLAiAAAmkSGD16NOvRo4dtLQjztlh8zYQw7yvOUBqjH1lTpkxhgwYNSiz0lmoQEOZTEfL/OHzkP9O4WYQwHzePY7xRIeAkzN98882J8DNex7ly5UrWoEEDQzhJEvtJdG3SpElKsytWrEjUN8+eprA2jz/+eMr6KOCOgKp+mjVrFmvZsqVlEBDmLUiQAQIg4IEAhHkP0FAFBEAABLwSoJAMFStWZL/++qvFBIR5CxLfMyDM+440VAanTp3KevfuzVatWpXoN/0opxlyTus+UCEI89l1MXyUXd5RbQ3CfFQ9i3FFnYCTMP/YY4+x7t27exo+hSi58MIL2cKFCw31BwwYkJgtb8hM8mHgwIGM6uhT3rx5E2HwKPY8UmYEVPYTTRigN6DND2YgzGfmc9QGARD4mwCEeZwJIAACIJBlAlIwMDcLYd5MxP/PEOb9Zxomi+eddx777rvvEq+yX3vttYkYsbQOBL3F4jTjDcJ8dj0MH2WXd1Rbk9+zFKZi7969UR0mxgUCkSPgJMyPHTuWderUydN4Z8yYwVq3bm2oSyLrnj17GMWsd5tIlK1evTrbtGmToUrfvn0T69QYMvEhbQKq+6lKlSqWcDYQ5tN2MyqAAAjYEIAwbwMFWSAAAiAQJAFaXOr999+3NAFh3oLE9wwI874jDZVBehW5Vq1aiYXj9B3//fff2UknnWS78BeEeT2p4Pfho+AZx6EFCPNx8DLGGEUCQQjzHTp0YK+++qoBF4WvMc+gNxRw+HD99dcnFn3VH27cuDFbtGiRPgv7Hgio7qczzjiD/fzzz4aRQZg34MAHEAABjwQgzHsEh2ogAAIg4JVAu3bt2FtvvWWpDmHegsT3DAjzviONjMETTzyRkUBvThDmzURy9xk+yh37sLUMYT5sHkN/QeBvAkEI86eccgrbtm2bAfEtt9zCJk2aZMhz82HIkCGsX79+hqIVKlRg27dvN+ThQ/oEVPcThPn0fYoaIAAC7ghAmHfHCaVAAARAwDcCEOZ9Q5m2IQjzaSOLTQWa9bR582bLeCHMW5DkLAM+yhn60DUMYT50LkOHQSBBwG9h/ujRo6xQoUKMYoTr06WXXso+/vhjfZar/ZdffpmRqK9PefLksdjXH8d+agJh8BOE+dR+RAkQAAFvBCDMe+OGWiAAAiDgmQCEec/oMq4IYT5jhJE1ANFXfdfCR+r7SJUeQphXxRPoBwikR8BvYX7Hjh3s5JNPtnSiYsWKbOvWrZb8VBm0Ho15Edry5cszagfJO4Ew+AnCvHf/oiYIgEByAhDmk/PBURAAARDwnQCEed+RujYIYd41qtgVhOirvsvhI/V9pEoPIcyr4gn0AwTSI+C3MH/gwAFWrFgx207MnDmTXXbZZbbHnDK7dOnCnn76acNhr/HqDUZi/iEMfoIwH/OTFMMHgQAJQJgPEC5MgwAIgIAdAQjzdlSykwdhPjucw9iKqqLvsWPHWKdOndiSJUvYnXfeye6+++6s4P3jjz9Y586d2cqVK9nAgQMZLVqd66Sqj3LNBe1bCUCYtzJBDgiEgYDfwjyNuVSpUmzPnj2W4desWZMtW7aM5cuXz3LMLoPC4Zx55pls3bp1hsNe49UbjKT4EId7AdX9BGE+xUmKwyAAAt4JcCQQAAEQAIGsEvj3v//NxVXb8k8s/prVfsSxsf/5n/9JcK9WrVoch48xJyFQqVIly98k/Z2KGPNJagV/aPr06YZ+DR48OPBGf/vtN16nTh2t3apVqwbeppsGVPWRm76jTHYJXHjhhYnzt3jx4tltGK2BAAhkRGDKlCnad4/+Xnns2LGe7Z599tm2Nsn+qFGjXNudOnWqrZ3PP//ctQ2vBeNwL6C6n+i3g/6cpH0xYcCrS1EPBEAABDQCmDEvrqhIVgJr165lq1atYuvXr0/MCqDtwYMHWYkSJVjZsmUZLZhDr/4VKVLEWtmUYzdDwVTE8rFo0aIsf/78lnyZceTIEXbo0CH50XFbsmRJx2NOB2hGwoIFC9js2bMTCwH+8ssvjBb1ofiBFSpUYPXr12eXXHIJK126tJMJJfPJpy+++CL79ttv2XvvvccKFy5s6Cf56d1
33034fdOmTezXX39lZcqUYaeccgpr3Lgxu+KKKxj5xc+0evVq9umnnzJqj2IL7ty5M+F3apP+CREmwZq2QaZs9wMz5oP0ZnLbmDGfnE+cj6o6G/urr75KXIPFnZvmnkGDBrF+/fppn/3c2bVrF2vRogX7/vvvNbO1atVKzCrUMnK0o6qPcoQDzSYhgBnzSeDgEAgoTCCIGfP9+/dn4qG27ajz5s3LJkyYwG699Vbb4zKTFif9xz/+wRYvXiyzEtvLL7+czZgxw5AXxIc43Auo7ifMmA/izIZNEACBBAFNosdO7AkcP36cv//++/ziiy+2PA0WJ4slT4jy/F//+hefM2eOIzshclvq2dky57Vu3drRJh146qmnXNnt1atXUjvmg+PHj+dCjE5pW7zyyEU8Qf7kk0/ybdu2mc0o81mI7ZzGJGeOSc4iRIHWx40bN3JxM8rJn/K43VYI+fzBBx/ke/fu1ep62dm9ezfv27cvFzc3SduTfRAPRRL9f/bZZ/n+/fu9NGlbJ5f9wIx5W5dkJRMz5rOCOZSNqDwbe8yYMZbrpQgv4ztn80x5ug6fdNJJXLzq73tbXgyq7CMv40Gd4AjI+x7MmA+OMSyDQBAEgpgxLyZZ8UKFClm+R+VvDSHO85dffjnpcEQYOUt9MWEpq9+PUb8XUN1PmDGf9E8EB0EABDIgwDKoi6oRIvDxxx9zETPPcsMhZpfyDh06cBHXlosZdFwsnmMpU6BAAf7WW2/Z0iCx/4YbbuB2X2TyZki/JcGbfkz16dPH1p7MpIcBrVq14nRDpK+v3xcz2/mrr74qqyTdijcCuJgFb7FFN3H0Cn+ymzlqk175F/EFec+ePbl42m/5RwJ5tpJ4k4C/8847nMRfEtP1TOS+FObfeOMNLuL52ZaRZc3bypUr8w0bNqQ9HDHTJPG6qNODD/rxTK8D0jlgblN+JjF/4cKFabetr6BCPyDM6z2S3X0I89nlHabWVBd9H3vsMcu1ccCAAb4hVl2Up4Gq7iPfnAFDGROAMJ8xQhgAgZwQCEKYp4F069bN8h0qf1/QlsT5ESNGcBFH3jJusdirpW7BggW5mClvKRt0RtTvBVT2k52egVA2QZ/xsA8C8SAAYT4efk46SopZd8IJJxhuOES4Gj5t2jRLPRFyhLds2dJQlm5mSEx95ZVXLOVlBomhYhE5Sz39DREJsz/++KOs4mpLs56bNm1qsSte83dVnwqJsC4WgV+EUeGzZs3i1G9KYqV4TrMU6CZM32e3+5s3b07YCfI/erhCgm+yhxWyvyTM0+x3+TndLT3EEaFnXA+HHkyI8EeW9ogzzYQXIRM4PcShRA8WxCulnB4K2fWLbpzpAcjhw4ddty8LqtIPCPPSI9nfQpjPPvOwtBgG0XfkyJGW6+LDDz+cMeIwiPI0yDD4KGNnwIAvBCDM+4IRRkAg6wSCEubp90XdunUt36Hm3xr0e0X+xqHfgXZCMf1OzkZceSf4Ub4XUNlPEOadzkjkgwAIZEoAwnymBENe306Up9Ah3333nePIRHx32zAkJFqL2LSO9eiAnTgrb4hIMPOSJk6caLjJohAzdrMd7GyTYG6eMU4zwuUNmbkO3SwSH9lnt9tsCPNt2rRx3a/u3bu7Lus0xtq1a3N6MJIqUagfu8V86I2GrVu3Olb/888/E28hOLUvYjpqYr6jEd0BVfpBXYIwr3NMlnchzGcZeIiaC4voSzP6zNdFelPLawqLKE/jC4uPvPoC9fwjAGHeP5awBALZJBCUME9joN9jNMPZ/B1q/izWFeP0+7JZs2aWsjRBjSaq5TpF+V5AVT9BmM/1WY/2QSC6BCDMR9e3KUdGwmfFihUtNxw0UzlVohAo5psY+vzMM88krUoz4p1ClVC+WHA0aX27gxSvXN8XtzMYSLy3C18zduxYu2a0PLtZCvr2KeSP+Z+XcWkNutxZunQpnzRpEh83blwi9BDNLNf3y2m/Ro0a/IUXXuDLly/ndE7QwxVaa6B58+Yp69MaA8kSzTQRCyVZ7NBNrdt48V27drXUl2OhmfNukir9kH2FMC9JZH8LYT77zMPSYphE3+HDh1uui2Ix2LRRh0mUp8GFyUdpOwMVfCUAYd5XnDAGAlkjEKQwT4NYvXo1r1ChguU7VP62cNqeeOKJid9LWQPhoqEo3wuo6CcI8y5OShQBARDwRADCvCds0aj04Ycf2t6UVKlSJeUAafax3Y3Lueeem7IuxWK3q0t5FNYk3XTeeedp9mjfbbJbQLZkyZKc3ghIlezi8VP/afY9idsqpE6dOmlc7HhT6CC6+U32dgE9aKE1BOzqUx6J/2vXrnUcrt3MfKpDN1tuE4nq9BaEUx8mT56c0pQq/ZAdhTAvSWR/C2E++8zD0mLYRN9HHnnEcl2kB9VuU9hEeRpX2Hzk1hco5z8BCPP+M4VFEMgGgaCFeRrDmjVrHL9P7H5v0KQ1N28JZ4OPuY0o3wuo5icI8+azD59BAAT8IgBh3i+SIbRj9woc3Yzkz5+fHzt2LOWI7GKZU10ZK9zJAM2ad5rNTeFN0kkrVqwwCBM0k99Noj7Sa4rmm6+GDRu6qc6HDBliqSttffDBB65sBF2IFuSVfTJvaezJwhXp+0az8M319Z9pRrtdorjxdmF/2rdvb1c8ad5PP/2UOC/17cr9IkWKcBKYnJIq/dD3D8K8nkZ29yHMZ5d3mFoLo+g7bNgwy/U51eLp5JMwivLU7zD6iPqNlH0CEOazzxwtgoAfBLIhzFM/N2zYwOU9ofxN4bSl360Pi/VcUv3G9WP8XmxE+V5AJT9BmPdydqIOCICAGwIQ5t1QimgZuyfsdENCNx9uFtakWOx2NzBbtmxJSezqq6+2rUv2Fi9enLK+LNCjRw/Nzsknn6wt1iqPO20XLlyo1dOPoW3btk5VDPkbN260FZ3J1gMPPGAom6sPtHitfmz6/SeffDKtbiV7y6FYsWKcFpM1p3bt2lnaJ6F+2bJl5qKuPt96660We3JM9JDJKanSD33/IMzraWR3X/4Io5trJBDQEwir6Dt06FDLtbF37976oRn2wyrK0yDC6iODA/AhKwQgzGcFMxoBAd8JZEuYp47T9+H5559v+Q6Vvy/MW1rfyu43j+8QPBiM8r2AKn6CMO/hxEQVEAABVwQgzLvCFM1CTuJ048aNXQ3YSZhfuXJlyvoUB958syM/33HHHSnrUwEKcVKuXDnNjptZgtLwY489ptWT7dI2Vcx0WZ+2TjHYb7zxRn2xnO0vWrTIdow0TqfFbZ06S/Hn9ZzM+88995yhKr16aPdWRCZi6KpVqxz7QOGX7GaxqNIPAxzxAcK8mUj2PkOYzx7rsLUUZtF38ODBlutjr169LC4IsyhPgwmzjyzOQEagBCDMB4oXxkEgMALZFOZpEIMGDbJ8f5p/5+g/16xZk69fvz6w8WdiOMr3Air4CcJ8Jmcn6oIACCQjAGE+GZ0YHOvfvz8vVKiQdkNCXzgUasZNchLm3YZIoXj0+hsduU8zsPft25eyC1OnTtXq00zsdevWpawjC+hn2st2aduoUS
NZJOWWFtrT15X7zZo1S1k3GwWSCfNeZnvIH7lynPrtvffeaxiS0wK5tNhuJqlOnTq2zKkvtGaCOanSD3O/IMybiWTvM4T57LEOW0thF33tfrQ++OCDmhucRHkK9xWWFHYfhYVzFPop71loPR0kEACB8BDIljB/4MABTuE19b9n3O7TxDC3v3ezTT5q9wIq+QnCfLbPZrQHAvEhAGE+Pr52HCmJtG+//TafM2cO/7/2zgboiqr+4ycQEEhA5BFSx4CUUSFKEMJw8gXRCQbwLShqLAkVLM2X7A0lHKEkNW1EhBAJSlMQDCttGoZi1JEwhKhMFIx8fMO3R1BUQt3//u6/c91779m7L3d37758zswzu3f3vPzO57fP3bPfe/Z33nnnHdd8zhMS762lpcU4mNmwYYMzq+v+0qVLjeVlULRw4ULXcvrE+PHjy+VPO+00fdjX1i00S58+fXyVl0xLliwpt+8cyMlCpWlIUQvzd911l7G/0vexY8dWdFn84WSi92VB2kbSrFmzjPVK/dOmTaupOi12VBuGMF9NJLnPCPPJsc5aS3kQfa+55pqa70gJr5YHUV6upzz4KGv/F1m1F2E+q57D7qITSEKYl8lcgwcPrrhfiuh65513Wm4Tz/SzjN727NnT2rhxYyrdlZexQNr8hDCfyssdoyCQCwII87lwYzKdkAVhV69ebY0ZM8YYpkQPVPwK8xLHvlevXhWDIl2H1yKsL774YsVioPLDQpB02WWXGduVmfd79+71VdW6deuMdQT9kcBXYyEyRS3MC5fu3bsb+zxw4MCyhZKvc+fOxnyyrkEjae3atcZ65bo5/fTTK6pOix0VRv3vA8K8iUoyxxDmk+GcxVbyIvrKAnX6Xuq2lXtvlmbK6+spLz7S/WEbHwGE+fjYUjME4iQQtzD/j3/8w5J1yZz3x0GDBlkvvPBCqVttbW2l0KbO8277Is5L2Mw0pqyPBdLoJ4T5NF7p2ASBfBBAmM+HH2PthYjgErPO7YG4erDiV5gXo+U1++ry+vPjjz/u2i9Z7FPn6927t/Xf//7XNa/pxM0331wur+vR24ceeshUpOaYxBfUZZzbtCz+GrUwLwCGDBli7HOXLl3KfGRhXCcP5/6CBQvK+cLsPP300651DxgwoKLKtNhRYdT/PiDMm6gkcwxhPhnOWWzF7R6Xlu/0IEwlTJ3zu9e5n1VRXvqfJx8F8Sd5gxNAmA/OjBIQSAOBOIV5eUYVMd15TzzssMNKb5VV933+/PkV4V6dZZz7Rx99tLVr167q4qn4nNWxQFr9hDCfissaIyCQSwII87l0azSdevTRR0ux9zp06FAewMiCnjIzefny5TWzDfQgJYgwL+K2aZFQqcsUmkT3TAZBuj3TAnc6n9v2/vvvL5fX9ejt5MmT3YpVHN+yZYuxDpnV7ZUmTZpkLKtt8LOVsC71UhzCvDN8ULWNr7zySsmcxx57zLVvN9xwQz2TPc9JqKXqdvXnTp06WR988EG5jrTYUTbIsYMw74CR8C7CfMLAM9Rc3kTfKVOmGL8vszhTXl9GefOR7hfb6AkgzEfPlBohkASBuIT5bdu21bz5K29Kr1mzxrVbf/3rX63DDz/ceC/Vzx+yda7n4lpZk05kbSyQZj8hzDfpIqZZCBSAAMJ8AZwctIsSrmbEiBEVgxBZIFaEcufrem4x+III82LbuHHjKtrSA51u3bpZb731Vo358oOBziMDqu3bt9fk8TogAq8s3KPrcW47duxovfzyy15VWPfdd19NeVm41s/s/YkTJ9aUddrgZ3/69Ol1bYxDmP/GN77havdLL71UsueBBx5wzSMzNxpNbuGPhJm2QdpIix2m/iLMm6gkcwxhPhnOWWwlT6Lv888/bx155JHG7+IsvgGgr6c8+Uj3iW08BBDm4+FKrRCIm0Acwrw89x177LE190R5rvFK8kx48skn15R1Pqt17drV0hOUvOpL8nzWxgJp9xPCfJJXL21BoFgEEOaL5e+6vZXQMSNHjqwYeMhs9vPOO68cd89ZQVTC/B/+8IeKNp0Dndtvv93ZZGn//PPPL+cfPXp0zXm/B2bPnl2ux9mm7F911VWe1XzlK1+pKX/RRRd5lpMMWRXmr7vuupo+Cy/5gWTfvn2lvq9YscKYR/L55VMPomlgrf3nHBSnxQ5TXxDmTVSSOYYwnwznLLaSF9FX4uRKaC/9vWjahnnTLA0+zYuP0sAy7zYgzOfdw/QvrwTiEObnzJljvCc++eSTvjDKM47puc95f03bj95ZHAuk3U8I877+XcgEAQiEIIAwHwJa3orIoq7f/va3a0LKiPAus67dUlTCvIQfcZvZN3z48Irm9+zZY8lMej0QEvE1bNq9e7c1dOjQcl26TtmK0Pzzn//ctWpZkKY6BI+8ZSAL2vpJWRXmFy9ebOTVo0ePcrclRr+TpXNfZpw0mtzesGjfvr31/vvvl6tPix1lgxw7CPMOGAnvIswnDDxDzeVB9JUHcX2N6+/eQw89tBSWTn/W2yyK83nwUYb+JTJtKsJ8pt2H8QUmELUwL8+5pnvHcccdF4iyvBF96qmnuj7jyPNpWlIWxwJZ8BPCfFqucOyAQP4IIMznz6eBevTaa69Zo0aNqhlkHHHEEZacq5eiEualjZ/+9Kc1NmjxYPPmzWUzli5dWs4noWj8hI0pFzbsSB8HDx5crlO3KVs3cX79+vU1A7xDDjnE+FaBocnSoUWLFln9+vVr6O/qq692q750PI5QNjfeeKORlfywolNra6sxjzAVnzWa3OLzyyLAzpQWO5w26X2EeU0i+a0WLWVwTYKAk4DpwV2+t9I2C85ps3NfFmo/6qijKr5/RZSXRbMlyT3DeY+T/ayJ81n3kdNf7MdLAGE+Xr7UDoG4CEQtzMtzW/W9Tz57PUeZ+ieLvFbfZ511y3242SmrY4Es+AlhvtlXN+1DIL8EEObz61vPnsnsYlPMvC5duliysKlXilKYb2trs6Rd5+BG7zvDn5x44onlPFEttPPqq6/WfT3xU5/6lDVjxozS39ixYy2JQa9tk+2ECROsHTt2eOFK/Hwcwvyll15a0XfNQYRmneS6qmak88nWT/x+XZdp67wGnPUOGTKkInta7Kgw6n8fEOZNVJI5hjCfDOcstpJl0VfW13Auii7fjU5RXvvDJM5HdS/VbcS5zbKP4uRC3bUEEOZrmXAEAlkgELUw7xbacsGCBaFw3HXXXcZnIbnvPvLII6HqjKpQlscCWfATwnxUVyr1QAAC1QQQ5quJFOjz3LlzjQOL+fPn+6IQpTAvDU6dOtVoT/fu3S0JYSOrtMssdhn4yFY+R5nWrl1rbN8p/jr35eb8+9//PkoTIq0rDmH+rLPOMjKaN29ehe3yxoWTlXP/nnvuqcgb9INb3ZdddllNVW55xZ4k7ag2DGG+mkhynxHmk2OdtZayKvr6fRDX/jCJ81l5KyCrPtLs2SZHAGE+Oda0BIEoCUQtzN98883GZ5JVq1aFMltCrvTv399Y57Jly0LVGUWhrI8FsuAnhPkorlTqgAAETAQQ5k1UCnBMF
sl0m9Xsd/Z31ML8pk2bjIMcEVHvuOOO0oKsWuCV8DtRJxEmdP0SJ3D8+PGWzM4WIa9Pnz6lOPgiqP7oRz+yZMFaWTk+zSkOYV7iMWpGzq3E3HemM88805hPypx77rnOrIH3O3fubKz7wQcfrKkrLXZUG4YwX00kuc8I88mxzlpLWRR9d+7caR1zzDEV34mmmfLVvpAFzp3f4bJ/5ZVXVmdL3ecs+ih1EAtiEMJ8QRxNN3NHIGphfubMmTX3O7nnSXjUsGnatGnGOm+77bawVTZULg9jgSz4CWG+ocuUwhCAQB0CCPN14OT5lNvrYi0tLb67HbUwLw1/9rOfNQ50ZBFY5wN5o7OdnZ2UmQ/nn39+uV0RTV9//XVnlkzuRy3MSzz/Aw88sMxJizoDBgywZAFfZ/rd735Xk0/nlzjzzkVaneW89iXskK7HuRWxXt6qqE5psaPaLoT5aiLJfUaYT4511lpy3mOc3y9pnU0uD+IDBw6s+E70I8prv5jEeVkIPs0paz5KM8u824Ywn3cP07+8EohamL/11lsr7pP6/v6DH/wgNMLrrrvOWOfKlStD1xm2YF7GAlnwE8J82KuUchCAgBcBhHkvQjk9P336dOOAQm44ftPHPvYxYx0bNmzwW0VNvjvvvNNYpx5EyVZ+PNi7d29N2TAHRCCePHlyuc0LLrigRmQOU28aykQtzN99991lTk5/VIexkb4L1759+xrzS9mwIYDuu+8+Y52XXHKJEXla7Kg2DmG+mkhynxHmk2OdtZayJPrKWh2DBg2q+D4MIspr35jE+SuuuEKfTt02Sz5KHbyCGYQwXzCH093cEIhamJdnDudzi94/++yzQzO76aabjHUmHWM+T2OBLPgJYT70vwwFIQABDwII8x6A8np6xIgRxgFF+/btrXfffdez25KnU6dOxjoeffRRz/JuGURwlxnVetBk2kb1ur3M8v76179ebksWeY1K8HfrX5LHoxbmTzjhhDIr7ZcePXpYb731lrFbbrNJpOypp55qLON1cPTo0TU2SEim1tZW16JpscNpIMK8k0ay+wjzyfLOUmtZEX3lQfyTn/xkxXdhGFFe+8Ykzl9++eX6dKq2WfFRqqAV1BiE+YI6nm5nnkDUwvxTTz1Vcb/UzzBHHnlkaFYXXnhhTZ3yXLx79+7QdQYtmLexQBb8hDAf9ColPwQg4JcAwrxfUjnLd/zxx9cMKPRAReK5e6VFixa5lv/jH//oVbzu+RkzZrjWLYu+yo07ijRnzpyKdtL+Cn/QPtcT5v/9738Hqm7z5s0VrPS1cuONN7rWI+sYmELf6LJ/+9vfXMuaTmzdurW8+K+uQ7Zus+V1HWmxQ9sjW7dFdL/2ta85s7EfAwGE+Rig5qRKWWjc+d2i99MUyka+z6IU5bXrTOK8aUFtnb9Z2yz4qFlsaLeSAMJ8JQ8+QSArBKIW5qXfw4YNM97fJbRrmGR6jh4zZkyYqkKVyetYIO1+QpgPdblSCAIQ8EEAYd4HpDxmmTRpknGAIkKEzLyTX+FNSWKvywP8fvvt51p+4cKFFUWDxhN/9tlnLZm5r0UR5/bkk0+uqDvsB5nl3bNnz4o2ZOb1xRdfXAqz8txzz1kSUz3LqZ4wLwvsBEmmGd4jR470jBUvi+S2a9eugrP2p6wnEOTa+Na3vlVTz+DBg30twpsWOzRzLRhoFnp7xhln6CxsYyKAMB8T2IxXK/c8/X9Yvb300ktT0TtZY0O+85z2NTJTvrpTJnE+LX0XW7Pgo2qmfG4eAX2fPeCAA5pnBC1DAAKBCcQhzC9ZsqTi3qnvo/369fP1prizE48//rjxOXjVqlXObLHt53kskHY/IczHdllTMQQKTwBhvqCXgNvK53qgIgu7rl27tiROi4gt4Wm++c1vWl27djUObHQ52UqYHAl1I6FiZHAlsei3bNkSiPSZZ55pbEfqiyLde++9xvqd/ZD9bt26WYcddph11FFHWUOHDrVOPPFEa+zYsZb8sCFhcEQsnjVrlrVs2TJr/fr1gYTmKPpRr456wvz+++9vLV++vF7x8jlTHEXh4vfNhR//+MeurK+55ppyO/V2fvnLX9b8WPPRj37UevLJJ+sVqziXFjvEKHkQqL7W5LP8WEGKlwDCfLx8s1r7ww8/bPyflP9LecOl2UkexCXcmvN7I0pRXvcvzeJ82n2kGbJNBwGE+XT4ASsgEJRAHMK8PJd++tOfrriH6vvptdde69vEffv2GeuJauKYlyF5Hwuk3U8I815XKOchAIGwBBDmw5LLeLmnn37aGBZED1L0VmauS/gY/VlvTznlFEsGIfpz9VaEWxFO9XERrYOkNWvWlMvqOnr16hVZDPgHH3ywpn7dTiPbQw45xPre976XaIxBN671hHnpo/hVQjTI65CmJLPZJaxQtf/Ft0H9OXHiRFfesmaA/Ijjlm677bYaG+RthzALLKXBDrfFjcQnEp8yaJghN24cNxNAmDdzKfJR+a4zrV+h7wUS1zzJuLHVvnjzzTetIUOGVHyHxiHK63ZN4rzcC5qZ0u6jZrKhbTMBhHkzF45CIO0E4hDmpc/bt293DbE5ffp0zzdw9+zZY02dOrXiXizjhN69e1vbtm2LHWtRxgJp9hPCfOyXOQ1AoLAEEOYL63rLmjx5cs3gQgsR9bYipL733ntWvRnIzvIy+37nzp2BScssdWc9UcaAl0Ve68U/d7YbZl/6/Je//CVwn6Ms4CXM63517tzZmjZtmiWz0jds2GBJOZklP2DAgAr+kl8E8TCL+4qoIouwSrgg3a5ze+yxx5Zm8L/99ttlBBLX/oILLqjJf/jhh1tPPPFEOV+QnWbaIaGRZG0GeYPE2ffqfZk1H/SHjyAMip4XYb7oV0Bl/5955hnjg3b1/6XEjpXvpGak6je84hTldf+qxfmWlhZ9KvFtFnyUOBQa9CSAMO+JiAwQSCWBuIR56eyf//xnSyZ6Vd/j5bOEilu3bl2NQC+Th2T9tP79+9eUO+igg6yga2aFhV6ksUBa/YQwH/bqpRwEIOBFAGHei1COz0uIGrfX+kwDFhHKJbyNTvKw7Ca06vJ9+/a1duzYoYsE2l5//fXlAZDEKZdf0KNMctM/+OCDy21om6PaysBPGDUr1RPmx40bF7jfEs7hxRdfbKg7MnitjpHs5C2z82XmSZcuXWrs69Chg3XFFVdYu3btasgGKZykHfLjgizWKD+AOPvqtS8hlE444QTrJz/5ScP9pYIPCSDMf8iiiHsbN260vvSlL5VCkg0aNKjmbRyv/0uZPS//lxLSLKnFmmUmnl53RX70lTfekkhOcV5EzqRSFn2UFBva8U8AYd4/K3JCIE0E4hTmpZ+tra2l0JFu93tZR01Cx51zzjmWLPLqfANcl5E3XC+//HLrtddeSwxd0cYCafQTwnxilzsNQaBwBBDmC+fyyg7L6/kyW7o6XIkeeMhWxHURB2WWeXWSWfNuZSdMmGBJLLywSRZ6k1joYoPEnI8jtbW1WZ///OcDCaZONl77It40K9UT5qXfv/jFLywJvVOvDzI4lT789re/
jawbch0tXrzYGjVqlOvCsE6b+vTpY1100UXWv/71r8hskIqSsmP48OF1GTv7atofP358pP0uemUI88W+Aup9L5r+/+odk/tTUkkWm1u6dGmot88asVHiuv/qV7+y3njjjUaqCVQ2qz4K1Ekyx04AYT52xDQAgVgIxC3Mi9Hy5vdvfvOb0jOgTP6qd693npNnXnnjvFlhJ4s0FkijnxDmY/mXp1IIQMAm8BGhYN9wSAUnYAvoyg5jouzwK8oWQFWPHj2ULcgre6aAOumkk5Q9EHElZMf6VnbMdmXPaFe2kKvsUCPKnl2t7MVSXcv4PTF79mw1b9489ac//UkdffTRfov5zrdixQplL+Kq7Lh9yn4dUc2fP18JC3uWv3rppZdKf/LZnqVd+rNDrShb0FX2a42+2hBu9gwHZb/+6Ct/lJnskDPKDotirNIWWVT37t2VvchOia396qayfwhR9g81yg7xo+xZmaW/0047Tdkz2I11RHHQnoGvpO3nnntO2TMjSrztWSjKDtVQ+rNnrKiRI0cqe9AcRXOudaTFDlcDOREZAfvNH7V161ZlD65L/5uRVUxFEIAABCCQGgKf+9zn1EMPPaQOOOCA0tgmNYZhCAQgUJfA3Xffrew322ryLFy4UNlvodYcb/SAPPPZYWJKz7HPPvus+s9//qOef/750jOtHcJTyZ+MHeV5RJ6JjzjiiEabpHwIAmnwk/he9A5nEt1DrhkSBCAAgUYIIMw3Qo+ymSYwc+ZMde2115b6YIdOUfYsPSVCsJ9kxypX9iJA5T97BrqyZzEoOza4uv/++9Xrr79ersae+aG++MUvlj8nteNHmE/KFtqBQFoIIMynxRPYAQEIQCA+Agjz8bGlZgjESSBpYT7OvlB3vgggzOfLn/QGAmkigDCfJm9gSyIEZLb7hRdeqG6//fZye3aIAHXuueeWPzey88ILL5SEeJmpJcle9FR997vfbaTKUGUR5kNho1DOCSDM59zBdA8CEICATQBhnssAAtkkgDCfTb8VwWqE+SJ4mT5CoDkEEOabw51Wm0jAXqxH3XTTTWUL7BjeavXq1eXPUezYcXmVHd+0VJWE45kxY0YU1QaqA2E+EC4yF4QAwnxBHE03IQCBQhNAmC+0++l8hgkgzGfYeTk3HWE+5w6mexBoIgGE+SbCp+nkCSxfvlxNmjSpomGJrT9s2LCKY41+eOaZZ0oxrKUee6FTNWXKlEarDFweYT4wMgoUgADCfAGcTBchAIHCE0CYL/wlAICMEnAT5hcsWFB64zmj3cLsHBBAmM+BE+kCBFJKAGE+pY7BrOgJSFz4fv36lRYZddYuseIlxnyU6ZZbblGXXHKJ6tixY2lh05aWliir91UXwrwvTGQqGAGE+YI5nO5CAAKFJIAwX0i30+kcEHAT5mfNmqV++MMf5qCHdCGrBA499FAlIWudicVfnTTYhwAEwhJAmA9LjnKZI/D3v/9dDR48uMbuFStWqHPOOafmeNgDsjK7zMB/5ZVXSrHmZfHXZiSE+WZQp820E0CYT7uHsA8CEIBA4wQQ5htnSA0QaAYBN2F+1KhRas2aNc0wiTYhoHbt2qUOPPBAZVlWBQ2E+QocfIAABEISQJgPCY5i2SOwbNky9dWvfrXG8M985jPqgQceUD179qw5F/TAli1b1IQJE9SOHTtUhw4dlMSaHz58eNBqIslfT5hva2tTPXr0iKQdKoFAlgggzGfJW9gKAQhAIBwBhPlw3CgFgWYTcBPmxa7rr79eyVph7dq1a7aZtF8gAm+//baaPn26Ei2hOiHMVxPhMwQgEIYAwnwYapTJJIFVq1aps88+22h737591b333quGDh1qPO91cPPmzWrp0qVq3rx56r333isNGGWm/MSJE72KxnZ+3bp16qSTTjLWL7P6ZSBBgkDRCCDMF83j9BcCECgiAYT5InqdPueBQD1hXvrXq1cv9e6776oRI0bUdFfCiI4bN67mOAcg4EVg9OjRxixvvvmm2rp1q3rjjTeM5xHmjVg4CAEIBCSAMB8QGNmzS0BCy/Tu3bvmFTTdo4985CNKZs+fddZZ6owzzlCywIscq04ffPBBKUyNzIpfvXq1klA427Ztq8i2aNEiNXXq1IpjSX9YuXKla4iexx57TB133HFJm0R7EGg6AYT5prsAAyAAAQjETgBhPnbENACBWAjcc889pVCgYSr/zne+o+bOnRumKGUKTsD0zO8HCcK8H0rkgQAEvAggzHsR4nyuCFx55ZXqhhtu8NWn/fbbrzQr4+CDD1bdu3cvxZbbuXNnSZQXcd6UJO+tt97qKoibysR1bM6cOeqqq64yVn/HHXeo8847z3iOgxDIMwGE+Tx7l75BAAIQ+H8CCPNcCRDIJoE9e/YoeYZ55513AnfglFNOYcZ8YGoUEAJLlixREpI2aJK48zNnzgxajPwQgAAEKgggzFfg4EPeCUiYGVnoVWa6R52+/OUvq5/97GfqoIMOirrqwPW9+uqrauDAgerll182lv34xz+unnjiCdWlSxfjeQ5CIK8EEObz6ln6BQEIQOBDAgjzH7JgDwIQgAAEIAABCEAgvQQQ5tPrGyyLkcAtt9yivv/97yuZldFI6tixo/rCF76gLr744lIYnEbqarTs9u3b1T//+U+1fv360qz93bt3162ypaVFTZkyRR1//PGlhW/lB4VjjjmmbhlOQiDrBBDms+5B7IcABCDgTQBh3psROSAAAQhAAAIQgAAEmk8AYb75PsCCJhFoa2tTixcvVgsXLqyJEV/PpP33378kwp9++uklYVvi1qchyez3MK99Om3ft2+fkhA+JAjklQDCfF49S78gAAEIfEgAYf5DFuxBAAIQgAAEIAABCKSXAMJ8en2DZQkSaG1tVQ8//LDatGmTkjAw8iczzrt27aokbrzMLhcBftiwYSVRvlOnTgla56+pSZMmqb179/rL7JJLFoxt3769y1kOQyD7BBDms+9DegABCEDAiwDCvBchzkMAAhCAAAQgAAEIpIEAwnwavIANEIAABCCQCAGE+UQw0wgEIACBphJAmG8qfhqHAAQgAAEIQAACEPBJAGHeJyiyQQACEIBA9gkgzGffh/QAAhCAgBcBhHkvQpyHAAQgAAEIQAACEEgDAYT5NHgLsQpXAAAE5klEQVQBGyAAAQhAIBECCPOJYKYRCEAAAk0lgDDfVPw0DgEIQAACEIAABCDgkwDCvE9QZIMABCAAgewTQJjPvg/pAQQgAAEvAgjzXoQ4DwEIQAACEIAABCCQBgII82nwAjZAAAIQgEAiBBDmE8FMIxCAAASaSgBhvqn4aRwCEIAABCAAAQhAwCcBhHmfoMgGAQhAAALZJ4Awn30f0gMIQAACXgQQ5r0IcR4CEIAABCAAAQhAIA0EEObT4AVsgAAEIACBRAggzCeCmUYgAAEINJUAwnxT8dM4BCAAAQhAAAIQgIBPAgjzPkGRDQIQgAAEsk8AYT77PqQHEIAABLwIIMx7EeI8BCAAAQhAAAIQgEAaCGRGmH///ffVvn370sAMGyAAAQhAIKMEhg0
bpp566inVv39/tWnTpoz2ArMhAAEIQKAegTFjxqhHHnlEdevWTbW2ttbLyjkIQAACEIAABCAAgYIQaNeunerYsWOqepsZYX727Nnq6quvThU8jIEABCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAIF0ExgyZIjauHFjqoxEmE+VOzAGAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEoiSAMN8AzZUrV6pf//rXDdRAUQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQAACECgagU984hNq7ty5qep2ZmbMp4oaxkAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEQhJAmA8JjmIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAgTAEEObDUKMMBCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQCAkAYT5kOAoBgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhAIQwBhPgw1ykAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEQhJAmA8JjmIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAgTAEEObDUKMMBCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQCAkAYT5kOAoBgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhAIQwBhPgw1ykAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEQhJAmA8JjmIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAgTAEEObDUKMMBCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQCAkAYT5kOAoBgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhAIQwBhPgw1ykAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEQhJAmA8JjmIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAgTAEEObDUKMMBCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQCAkAYT5kOAoBgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhAIQwBhPgw1ykAAAhCAAAQgAAEIQAACEIAABCAAAQhAAAIQgAAEQhJAmA8JjmIQgAAEIAABCEAAAhCAAAQgAAEIQAACEIAABCAAgTAEEObDUKMMBCAAAQhAAAIQgAAEIAABCEAAAhCAAAQgAAEIQCAkAYT5kOAoBgEIQAACEIAABCAAAQhAAAIQgAAEIAABCEAAAhAIQ+D/APXo3my2x2iHAAAAAElFTkSuQmCC)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.029146, - "end_time": "2020-12-01T01:33:27.827298", - "exception": false, - "start_time": "2020-12-01T01:33:27.798152", - "status": "completed" - }, - "tags": [], - "id": "CQvqUPVWj6ew" - }, - "source": [ - "This is the basic structure of Wide Residual Networks. In the papers the size of `conv1` was fixed in all the experiments, while the \"widening\" factor `k` was experimented with in the next three groups. Here `k` is the. widening factor which multiplies the number of features in convolutional layers\n", - "\n", - "Let B(M) denote various residual block structures, where M is a list with the kernel sizes of the convoutional layers in a block.\n", - "The following architectures were used in experimentation:-\n", - "\n", - "* B(3,3) - The Original \"basic\" block. 
(Figure 1(a))\n", - "* B(3,1,3) - Same as basic but with a extra 1x1 layer in between\n", - "* B(1,3,1) - For Bottleneck (Figure 1(b))\n", - "* B(1,3) - Having Alternative 1x1-3x3 convolutions\n", - "* B(3,1) - Having Alternative 3x3-1x1 convolutions\n", - "* B(3,1,1) - A Network-in-Network style block" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.029591, - "end_time": "2020-12-01T01:33:27.886716", - "exception": false, - "start_time": "2020-12-01T01:33:27.857125", - "status": "completed" - }, - "tags": [], - "id": "DyOIi4wmj6ew" - }, - "source": [ - "# Experimental Results" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RA8b4sDzlbJk" - }, - "source": [ - "![Screenshot 2020-12-01 at 10.05.33 AM.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABkAAAAJcCAYAAAChV2hUAAAMbGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkJCEEghFSuhNEOlFSggtgoBUwUZIAgklxoSgYkcXFVy7iGJFV0VcdHUFZFERda2LYnctiwUVZV0sKIrKmxTQdV/53vm+ufPfM2f+U+7MvXcA0OnjSaX5qC4ABZJCWWJUGGtcegaL9AhoAyOgA2gA4fHlUnZCQiyAMtj/Xd5eB4iyv+Kq5Prn+H8VfYFQzgcAmQBxlkDOL4C4BQB8I18qKwSAqNTbTCuUKvE8iA1kMECI1yhxjhrvVuIsNW5W2SQnciC+BIAWlceT5QBAvwP1rCJ+DuShf4TYXSIQSwDQGQ5xMF/EE0CsjH14QcEUJa6E2BHaSyGG8QC/rK84c/7GnzXEz+PlDGF1XirRChfLpfm8Gf9naf63FOQrBn3Yw0YVyaITlfnDGt7MmxKjxFSIuyVZcfHKWkPcJxao6w4AShEpolPU9qgZX86B9QNMiN0FvPAYiM0gjpTkx8Vq9FnZ4kguxHC1oNPFhdxkiI0hXiyURyRpbLbKpiRqfKH12TIOW6M/w5Op/Cp93VPkpbA1/K9EQq6GH6MXi5LTIKZAbFskTo2DmA6xmzwvKUZjM6pYxIkbtJEpEpXx20KcKJREhan5saJsWWSixr6sQD6YL7ZVJObGafCBQlFytLo+2Ek+TxU/zAW7JJSwUwZ5hPJxsYO5CIThEercsadCSUqShqdPWhiWqJ6LU6T5CRp73FqYH6XUW0PsJS9K0szFUwvh4lTz49nSwoRkdZx4cS5vdII6HnwFiAUcEA5YQAFbFpgCcoG4rbuhG96pRyIBD8hADhACV41mcEaaakQCr0mgGPwJkRDIh+aFqUaFoAjqPw1p1VdXkK0aLVLNyAOPIS4AMSAf3itUsyRD3lLBI6gR/8M7DzY+jDcfNuX4v9cPar9o2FATq9EoBj2ydAYtiRHEcGI0MZLohJviwXggHguvobB54H64/2AeX+wJjwnthAeEa4QOwq3J4hLZN1GOAR2QP1JTi6yva4HbQ05vPAwPguyQGWfipsAV94J+2HgI9OwNtRxN3MqqsL7h/lsGXz0NjR3ZnYySjcihZMdvZ9Kd6d5DLMpaf10fdaxZQ/XmDI1865/zVfUFsI/51hJbjB3ETmPHsbNYM9YAWNgxrBG7gB1R4qHV9Ui1uga9JariyYM84n/442l8Kispd69173L/qB4rFE4vVG48zhTpDJk4R1TIYsOvg5DFlfDdhrM83D3cAVB+a9Svr9dM1TcEYZ77oispAiDIaWBgoPmLLtYfgJ8b4fbv+qJzhO8+uiUAZxbzFbIitQ5XXgjwLaEDd5oJsAA2wBHm4wF8QCAIBRFgNIgHySAdTIJVFsF1LgPTwCwwH5SCcrACrAUbwBawHewGP4IDoAE0g+PgV3AeXALXwG24ejrBc9AD3oJ+BEFICA1hICaIJWKHuCAeiB8SjEQgsUgiko5kIjmIBFEgs5AFSDmyCtmAbENqkJ+Qw8hx5CzSjtxC7iNdyCvkA4qhVNQANUft0RGoH8pGY9BkdCKag05Fi9GF6DK0Eq1G96L16HH0PHoN7UCfo70YwLQxJmaFuWJ+GAeLxzKwbEyGzcHKsAqsGqvDmuBzvoJ1YN3Ye5yIM3AW7gpXcDSegvPxqfgcfCm+Ad+N1+Mn8Sv4fbwH/0ygEcwILoQAApcwjpBDmEYoJVQQdhIOEU7BvdRJeEskEplEB6Iv3IvpxFziTOJS4ibiPmILsZ34kNhLIpFMSC6kIFI8iUcqJJWS1pP2ko6RLpM6SX1a2lqWWh5akVoZWhKtEq0KrT1aR7Uuaz3R6ifrku3IAeR4soA8g7ycvIPcRL5I7iT3U/QoDpQgSjIllzKfUkmpo5yi3KG81tbWttb21x6rLdaep12pvV/7jPZ97fdUfaozlUOdQFVQl1F3UVuot6ivaTSaPS2UlkErpC2j1dBO0O7R+ugMuhudSxfQ59Kr6PX0y/QXOmQdOx22ziSdYp0KnYM6F3W6dcm69rocXZ7uHN0q3cO6N3R79Rh6I/Xi9Qr0lurt0Tur91SfpG+vH6Ev0F+ov13/hP5DBsawYXAYfMYCxg7GKUanAdHAwYBrkGtQbvCjQZtBj6G+oZdhquF0wyrDI4YdTIxpz+Qy85nLmQeY15kfjMyN2EZCoyVGdUaXjd4ZDzMONRYalxnvM75m/MGEZRJhkmey0qTB5K4pbupsOtZ0mulm01Om3cMMhgUO4w8rG3Zg2O9mqJmzWaLZTLPtZhfMes0tzKPMpebrzU+Yd1swLUItci3WWBy16LJkWAZbii3XWB6zfMYyZLFZ+axK1klWj5WZVbSVwmqbVZtVv7WDdYp1ifU+67s2FBs/m2ybNTatNj22lrZjbGfZ1tr+bke287MT2a2zO233zt7BPs1+kX2D/VMHYweuQ7FDrcMdR5pjiONUx2rHq05EJz+nPKdNTpecUWdvZ5FzlfNFF9TFx0XsssmlfThhuP9wyfDq4Tdcqa5s1yLXWtf7bky3WLcStwa3FyNsR2SMWDni9IjP7t7u+e473G+P1B85emTJyKaRrzycPfgeVR5XPWmekZ5zPRs9X3q5eAm9Nnvd9GZ4j/Fe5N3q/cnH10fmU+fT5Wvrm+m70feGn4Ffgt9SvzP+BP8w/7n+zf7vA3wCCgMOBPwV6BqYF7gn8Okoh1HCUTtGPQyyDuIFbQvqCG
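To make the B(M) notation above concrete, the following is a minimal, hypothetical sketch of the B(3,3) "basic" block together with the widening factor `k`. It is not taken from the notebook being removed here, nor from any reference implementation; PyTorch is used purely for illustration, and the names (`WideBasicBlock`, `k`) are made up.

```python
# Illustrative sketch only, assuming PyTorch; not the notebook's code.
import torch
from torch import nn
import torch.nn.functional as F


class WideBasicBlock(nn.Module):
    """B(3,3): two 3x3 convolutions with pre-activation (BN-ReLU-Conv)."""

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        # 1x1 projection on the shortcut when the shape changes,
        # identity otherwise.
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Conv2d(in_planes, planes, kernel_size=1,
                                      stride=stride, bias=False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return out + self.shortcut(x)


# Widening factor k = 4: the first group uses 16 * k = 64 feature maps
# where a plain (k = 1) ResNet would use 16.
k = 4
block = WideBasicBlock(in_planes=16, planes=16 * k)
x = torch.randn(1, 16, 32, 32)  # CIFAR-sized activations after conv1
print(block(x).shape)           # torch.Size([1, 64, 32, 32])
```

The other entries in the list above only change the kernel-size sequence M; B(3,1,3), for example, would insert a 1x1 convolution between the two 3x3 convolutions.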
XnjhBTHrrLPmfjhVHfBjHohAHgiMGzdO/OMf//CjOuGEE8SAAQPyiJZxEAEiQAQqg8Djjz8uDj30UD+9w4cPFxtuuGFl0s6EEgEiEIzAiBEjxOTJk8Xss88urr322uCP+SsRIAJEgAiUBoHu3buL+eefP9f05HYIeq65YmREgAgQASLQFgKrrbaaePDBB/0FxRdffNFWWPRMBIhANgicdNJJYuTIkX7gEydOFIMHD84mIoZKBIgAEagoArfeeqvYYIMN/NSPHTtWDBs2rKI5YbKJABEwEVhllVXEI488Iuaaay7x6aefmj/zmQgQASJABIhAAwESIA0oeEMEiAARIAIKARIgCgleiUB5ESABUt6yYcqIABEoBwIkQMpRDkwFEcgCARIgWaDKMIkAESAC9USABEg9y5W5IgJEgAi0hQAJkLbgo2cikAsCJEBygZmREAEiUGEESIBUuPCYdCIQggAJkBCA+DMRIAJEgAg0ECAB0oCCN0SACBABIqAQIAGikOCVCJQXARIg5S0bpowIEIFyIEACpBzlwFQQgSwQIAGSBaoMkwgQASJQTwRIgNSzXJkrIkAEiEBbCJAAaQs+eiYCuSBAAiQXmBkJESACFUaABEiFC49JJwIhCJAACQGIPxMBIkAEiEADARIgDSg68+b7778XN998sxg4cKBYeOGFOxME5roUCHz++ecCi9QhQ4aIXr16lSJNnZwIEiCdXPrMe1UQIAFSlZJiOokAESgKARIgRSHPeIlA9giQAMkeY8ZABIgAEagLAiRA6lKSCfLx6quvis0331w899xzYssttxTXXHNNglDohQikg8AZZ5whhg8fLnr06CEuuugisfHGG6cTMENJhAAJkESw0RMRyBUBEiC5ws3IiAARqCACJEAqWGhMMhGIiAAJkIhA8TMiQASIABEQJEA6tBI8+OCD/gbzZ599JqabbjpxxRVXiG222aZD0WC2y4DAK6+8In7961+LqVOnip/97GfilFNOEfvuu28ZktaRaSAB0pHFzkxXDAESIBUrMCaXCBCB3BEgAZI75IyQCOSGAAmQ3KBmRESACBCByiNAAqTyRRg/A/fdd5/YYIMNxFdffeV7Pvnkk8WIESPiB0QfRCBlBJ566ikxaNAg8c033/ghH3/88eKvf/1ryrEwuCgIkACJghK/IQLFIkACpFj8GTsRIALlR4AESPnLiCkkAkkRIAGSFDn6IwJEgAh0HgIkQDqszJ9++mmx+uqrN8iPYcOGibFjx3YYCsxumRG47LLLxA477NBI4iWXXNL03PiBN5kiQAIkU3gZOBFIBQESIKnAyECIABGoMQIkQGpcuBGyNm3aNPHjjz+Knj17Rvian1QNARIgVSsxppcIEAEiUBwCJECKwz73mN966y2BScL777/vx73++uuLm266Sfz85z/PLC0//PCDuOWWW0T//v3FkksumVk8RQT83nvvCWD60Ucfia+//lrMO++8onfv3v5h8jPPPHMRSco9zqzKd6+99hJnn322n5+ZZppJ3HvvvWLgwIG556+TIyQBUv7Sf+ONNwTMGSoH7alFFllEPaZ2/e9//ytgog593ccffyx++uknv69bbLHFxDzzzJNaPAwoPgIkQOJjViYfWY2hRecxr76p6Hwy/mogQAIkeTl9++23/vj/ySefiE8//dQ3m7zooouKPn36VIJQ+PLLL8Xaa68tsGZ7++23kwORos+qY5oiFKkERQIkFRgZiIbAf/7zH/Hss8+KF198Ubz55ptNfxDSPOyww7Svi7mt2tqsCpgWU5KMNW8ESIDkjXhB8X3//ffiN7/5jXjyySf9FPTr10/A3NBss82WSYogaXPBBReI4447zp9w/uUvfxGjR4/OJK48A8WiHod1T5gwwR8UXXFjI3LnnXf2D5efddZZXZ9V9n3W5Yv6ig34J554wsdo/vnnF48//rhYcMEFK4tZ1RJOAqTcJYY2svLKK/sTdJXSq666Smy99dbqse3rbbfdJi6++GJxxx13+BsftgB79Ogh/vCHP/jn9WBDhC5fBEiA5It3WrFlPYamlc4k4STtm7A4hsbnd999FytazAs23njjWH7S+PiZZ54RDz30UOygICgDM7R0+SFAAiQe1jBDi7UOhOQw9wZRa3Ozzz67WHbZZcU+++wjNt98czH99NPbPmvr3d133y1eeuml2GFAUGPy5MkC8xhsYA4dOlTccMMNscNJy0OZME0rT2UJhwRIWUqi2umAlZTbb79dwFT8Aw88IKA5ZjqcUYo9rVGjRpk/5fZcpbVZVTDNrfAYUTkQ8OgiISCl0b3ll1++7b8VVljBGzJkiLflllt6e+yxh3f44Yd7cjPdkxIqkdKR9KO9997bkzXO/5OHnnv3339/0qAC/clFvXfuued6UjqoER/ilYNFoL+y/ygPi/ckoeHJyX1TvhSmrqtUt/bkhLvs2YucvjzLV5JNntxcbeC94oorelJqKnJa+WF7CEjC1MdeLnDbC4i+M0Fg//33b7QN1f9IAiSVuCRR7kkCrCV8FY/tKhcF3jbbbONNnTo1lTQwkGgI/P3vf2+U08SJE6N54leFIZDnGFpUJpP2Tddcc02jLtv6GNc7qSXqSc3m3LMrBV0SpVdqX+ee1k6P8J///GejrKTZ306Hw5l/SRp4559/vrfAAgs08HK1O/O91D71TjnlFA9hpOUw55fa9bHTYqYNz3feeWdayYoVTtkwjZX4inz861//2q8jc801V0VSzGSWBQEpcOFJQS9vpZVWcvYzAwYM8A466CBPEune559/XljSq7I2qxKmhRUmIy4UAVFo7BWK/MQTT/SkqShn52ibbMV5N8MMM3irrrqqJ88/8KRKW6rISDa7Kd1SUifV8BEYFvVS48OTEsBNcSkMqkyASAk/T0rsWfOl8hd2PfDAA1PHPM8AiypfaT7NA2Gn8JUqp3lmu6PjIgFS3uLHQl5vF6p9pEGAoB/v1q1bo82psKWEnTd+/Hjv1Vdf9V5++WUPcS2zzDIt30ntQu/5558vL3g1SxkJkGoUaFFjaN7otNM3Pffcc97IkSO97bff3t+MiCNwgs2JPB02IlTfGOU6yyyzeBCA2nbbbb2LLrooz6QyLokACZDwanDPPfdYx/SlllrK3/zD+lSao/XXqVjToD7b6v56662X2iYhhA9tccR9J80whwOQwRdlxDSDbBYeJAmQwougcgnAnAzz57nnntvax0jrHd6uu+7qPfbYY6XIWxXWZlXDtBQFy0QUggAJkBiwQ4pDnvngbwJttdVW1g7TNilDJwrJGGgD2H4330l1Yg+kRRoOndHSSy/diBeaGV999VUaQfthABMw59IWfCMOMz94rioBItUMPUgX2vIU9915552XGu55BVSG8t1ll10a+GNj9vXXX88r+x0dDwmQcha/tMHtSZMvjTah90PtEiCnnXaaNdwxY8ZYiXmML9gI0dOAe2la0ZMq5OUEsGapIgFS7gItwxiaF0Jp903QApXngbX0L2Z/g2doi2atSa3jKG2AR0oXhJtA6sjzE3TvvM8ZARI
gwYDb1jrdu3f3Nwil+SurZ2myzsOcAWtcs01CmlqawrP6i/oS4UOgwgw7yfNZZ50VNdrUvisjpqllrmQBkQApWYGUPDny7EQPe222vgRjNrRYi9T0MOGrwtqsapiaGPO5sxAgAdJGee++++7WzlN1qPIgbA8q/dgkUg4LNEjHYjIWpGIMbZMrr7xSeUt8lTZcm9J4xRVXJA5L94iJKaSBFl988abwVd7NaxUJEHn+hIcyNPMCswcwKQZVbwxKe+65p6+9Y5PI1v2CSAGBVgVXpvKVtn+bpN032mijKkBY+TSSAClnEcJ8ot6v6PftECDQdMPEXw8P98cee2woEDvttFOLP5iteOedd0L98oP2ECAB0h5+Wfku0xiaVR7NcLPqmzDmm/2S7RlzsjwczG3NOOOMkdI0bNiwPJLEOEIQIAHiBuiFF15o0frEGhQbWlEc1rQ2QbE///nPUbw7vxk3blykNmbrC/R3c845Z6qCf84Eaz+UFVMtibW6JQFSq+LMLDMwqQfz8679mnXWWceTh55nFn+SgMu+NqsipknKgX7qhQAJkDbK88ILLwycnEFyLch9/fXX3qGHHuoMAxPQ66+/PiiIwN9wbgXsYaqJ4BJLLOFhUd6uQ5qWXHLJRrhIp81kiooX16oRIPKwuqY8Ig+w2xykCgmJZx0XPf/qPm8zDUnKuozlKw8vbNQ3YInFLF22CJAAyRbfJKHDdIrqS2zXpATIxx9/7C200EItYeO8qigOEqJ9+/Zt8S8PafdgC5YuOwRIgGSHbdKQyziGJs1LVH9Z9U2IHxsAtv7OfAdNa13gKGra434HU5xm3K7nKVOmxA2e32eAAAkQO6gwuazmenodPvLII+0eHG/xve4f9zgXTB4+7vAR/nrgwIEtYZpxRHneb7/9wiNL8YsyY5piNksVFAmQUhVHKRPzwQcfeKqemP0G+irMpcvmyr42qyKmZStjpqcYBEiAtIH7qaeeGjg5u+OOOyKFHiTdhrMnXOrHYYEPHz68KX2XXnppmJdIv2NTC5LCW2yxhW8HHmqCmPCNGDGiKT59gKkaATJq1KhGXkDwRLXZjM2+TTbZpOFXxwD3888/fySMi/yojOV7//33N2EKk2vcWM22lqhFMQ9BzxbnqKHD9BtMUph9iv6clABxjUGPPPJI1OT5faSeFnXPc3siQ5joQxIgiWDL1FMZx9AsM5xl34R0w3yU6k/CrpdffnmWWfUg7TjPPPNESg+0ROjKgQAJEHs54FBfs03hvJq4RCLm47azEg855BB7xCFvcdaIni4IY/z1r3+N/YfzO/M0jYdslRXTEMgr/bPa2OYh6JUuxswS/+yzz3oQkND7FHU/xxxzeBMmTMgs7nYCLvParKqYtlMe9FsfBEiAtFGWQSawMIGMukELMyGwma46Y/OKg4/iusmTJ3v6AZLYMIYt6jQciB2wvqaD/WeXWmGVCJBp06Y1zmsB0QMzZnHcRx99FLhA/uKLL+IEl/u3ZS1fNcFV7QMEJF12CJAAyQ7buCGj71bloeq/7ZqEAHEd5rvuuuvGSibS2KdPn5ZxDGMb+kS6bBAgAZINru2EWtYxtJ08ufxm2TepOKG5DAlNW59nvvvVr36lvGVyPf/88yOlA+mqgsBLJiCVMFASIPZC2X777Vvq86qrrmr/OOTtNtts0xLWKqusEuLL/vOGG27YCAvmtd599137hyV8W1ZMSwhVaklS60MSIKlBWpuA7rrrLqfwGOoLTPiV0ZV5bVZVTMtYzkxTMQiQAGkDd/1wcXMRtsEGG8QKGXYHzTDU81prrRUrLHyMcymUf1xHjx4dO4wkHnSTW3r8VSJA9HNTcMB7EnfJJZc04a9jMWnSpCRBlsJPkeV73XXXNWEKc2N02SGgNtypAZIdxlFDPuaYYxp1H+S6aRJO9S9JCJDtttuuEbYKB9eoWm96HmwHoiOsvE1Q6Gmq+z0JkGqVcJFjaBZIZdk36emFHX+9fwq6j6p9rYcf9d51cKotPf37948aLL/LGAESIHaAbWdR/vGPf7R/HPJW7wtUe5hvvvlCfLX+jA1JXZgOZy5WyZUR0yrhlyStJECSoFZ/PziLBxoeqj/SrzDdDusSZXVlXZtVGdOyljXTlT8CJEASYg6VfH2CpnequD/99NNjhYwJnhmGeobaXhwHFX19sQjJuX//+99xgkj87cILL2zNR5UIkAEDBvh5gCmrpO7DDz+04oAyjWNWJmn8WfkrsnwhBQpNJtUucH3ggQeyymrHh0sCpBxV4NFHH23S5jvzzDP9M5X0dqDu4xIgb731VlPYKhxoD0KjL65D36bC0K9YaKBPpEsfARIg6WOaZYhFjqFp5yvLvslMa8+ePa19i97PqHsIFGXhIPWo4sAVc3OXiQr8/stf/jKLZDDMBAiQAGkFDeaVbZpVSdsPBMb09oF7rJPjuh133LERDszIwUpCVVxZMa0KfknTSQIkKXL19QfN81/84heNvkTvm9AvQaiyrK6sa7MqY1rWsma6ikGABEhC3K+88kprp6o62FdeeSVWyCeffLIzPJxBEeccENhAVunANepBtrES7Pi46ov7Z555xscOi22bmS9Htq2vIfmkl4O6x+H0VXVFl6/ZTpJKqlUV/zzTTQIkT7TtcX311Vdev379Gv3I+uuv738IQln1J/o1LgGy//77W8MZPHiwPUEhb3EWlKvfA3FDlz4CJEDSxzTLEIseQ9PKW9Z9k5lOnQCBFkaQABL6xCw0bU2y46STTvI23XRTax+KNJAAMUuxuGcSIK3Yv//++9a6Cw2GJM6cn6MN9OrVK1ZQENYD6aHmNbBmUCVXRkyrhF/StJIASYpcPf3BBD1M+al+xLzus88+pc54GddmVce01AXOxOWOAAmQhJDrEipmxwo76HEdFlJmOPozJlVRHQgP3W8SUyZR4zK/q/riHge6n3POOd4999xjZi328+KLL95UDiiTBRdcMHY4ZfJQdPkqgkrVb5gDKvuZKmUqvzhpIQESB61svt1tt90afcjcc8/tqXEgLQLERVbgsNGkbuONN26kWbVTXJOYckyahk7yRwKkWqVd9BiaFlpZ901mOnUCBGcNuPoZ1efADn+a7tVXX22SlodpSMw9SICkiXJ2YZEAacUWJKZqL+b1tttua/UQ8gabimY4cc8TOeCAA5rCuPPOO0NiLdfPZcS0XAhlkxoSINngWtVQDznkkKZ+RO+XFlpoIQ9nvZbZlXFtVnVMy1zeTFv+CJAASYg5DjfUO1T9fq+99oodKmyk62GY91E3ed98880Wybi8zF8h03VZ3McuQMPDN99840FzxyzHuGfDGMEW/lh0+cIMlm7eDfhSsjybakECJBtco4Z6/fXXN/UfeFYuDQJk8uTJTeHrfRU0HJO6ww8/3Bou+kMehp4UVbc/EiBubMr4S9FjaBqYZN032dJoEiD33XeftZ9R/RjM+L399tu2oBK9Gz58eFN86lwjEiCJ4MzdEwkQO+Qu+/g4v+ann36ye7K8xdwcwn+q/alrHC1tCKCBWFR+1RXnJq277rreYY
cd5t18882lF3oqE6aWoqrlKxIgtSzWRJl6/fXXvZlmmqmlH1H9yY033pgo3Lw8lXFtVnVM8yo7xlMdBEiAJCgrUwpddarqmqRzNbU2VFi4wgRKVDdq1KimTj+pKnPU+Mzv6rC4N/OU5Pmhhx5qKgdVnrfffnuS4ErjpwzlCxJJ4YnrmmuuWRp86pQQEiDFlSY0PaDxoer5zjvv3JSYNAiQM844oxG+ikddX3zxxab44jyMHz/eGe75558fJyh+GwEBEiARQCrRJ2UYQ9uBI4++yZY+kwDBNwMHDnT2NejLRowYYQsq9jtszM4222yNuECuQNgIjgRIbDgL8UACxA770ksv3ajXavxXV4wtUd0NN9xgDSfOOX0nnHCCNQyVHnWFiazf/e533rnnnlvKs8XKhGnU8qv6dyRAql6C6aUf57eqvsK8rrLKKulFlFFIZVybVR3TjIqKwVYYARIgCQovaJI2wwwzeF9++WWsUHHAm+0gOtVxx1nE4fA65Q/Xdg7yjpWJ/31c9cV9kjzb/Bx77LFN5YCy2HzzzW2fVupdGcr3+OOPb8J25pln9r7//vtK4ViFxJIAKaaUcI7Geuut16jjkKo0x5Q0CBD0R/pYoe6xuQBpzqQO51+psMxrEu3IpOnoFH8kQKpV0mUYQ5MillffZEufjQDBIaZmH6M/d+/e3Zs6daotuFjvTBO1W221VcM/CZAGFKW+IQFiLx5oVehtRr/HuvTCCy+0e9Te4ozKlVZaqSUczGOiOszhIbCnxx/lHmlEe4TUdFlcWTAtCx55pIMESB4olz+OO+64I7APaUe7Pa/cl21tVgdM8yo7xlMdBEiAJCirIG2NJIfHHnHEEYEd9l133RUplVBX1qXUMHk87rjjIvlN66MqL+7TwmDKlCnerLPO2lSmOAzT3MRMK748wylD+T744INN2KKe4x1dugiQAEkXz6ihjRkzplG/YTbKVrfbJUCwkalvKOobDb17946aVOt3X3/9dSP9eri4j2sP3BoBXzYhQAKkCY7SP5RhDE0KUh59kytten+FM0DgMOft27evs79BnwOt6HYc4lh00UWb4nj00UcbQZIAaUBR6hsSIPbi+fDDD71u3bo11W993AbBcMkll9g9/+/tsGHDWvxjDfTss88G+tN/vOCCC1rC0NMRdo904tyfd999Vw+2kPuyYFpI5guKlARIQcCXLFpohrn6ChCsIGvL7Mq4Nqs6pmUub6atOARIgMTEHgecBdkWhHR6HPfEE094kLh1ddibbbaZhw4xinv88cdbwsn7ALkqL+6jYBz2DcoKJJheniuvvLL3ySefhHmtxO9lKF9IipkLtnY3OSoBfs6JJAGSM+AyuhdeeKGpbuPQOZtrlwB57bXXmvoos7+yxRnnnc2ON+KARHbU8SxOfJ38LQmQapV+GcbQJIjl1Te50mYjQPDt2LFjnX0Z+hyc19eOhui4ceOawse4qDsSIDoa5b0nAeIumz//+c9NdVyfD+Ae5AIsH9jGbpu5FqxpJ0yY4I7Q8guExMx4kzyjvT/88MOWGPJ9VQZM881xsbGRACkW/zLEDvOctrNXVT9y8MEHN5IJLXeYs8cZoji3EO11//3390aPHu1ddNFFHtZIRbiyrc3qgGkR5cg4y48ACZCYZaRPolWnql8nTZoUOUQc0LjYYos5J31LLbWUN23atMjhmWr60003Xe6HxVV1cR8Z5JAPTfNMkFQEaVYXV5byXX311Zvazfrrr18XiEuTDxIg+RYFNumWW265Rr1eYYUVnNJK7RIgQQcIDx06tO2ML7744o186OMj7qEhR5ceAiRA0sMyj5DKMobGyWuefZMrXS4CBBpn+nlJZn+D53bOHho0aFBTXwZCRHckQHQ0ynuvr91AmtF1IfDtt982zT1sbQjvYGL5gw8+8D3++OOP/qah+e0888zjxTn3Q6Vivvnma2pnZrhxniGkeNlll6mgC7mWAdNCMl5QpCRACgK+RNHqc2Fbf3HPPff4ljhOPfVU7xe/+EVgf4P9szXWWMMnQzDHyMuVbW1WB0zzKjvGUy0ESIDELK/hw4c7O01InkR1t956q9MECTpubCC9+uqrUYPzvzMPKQKBkrer4uI+DYwgTWBuSuLw+i+++CKN4EsTRlnKF5Lx+gRnjjnmaOvcgtIAXKKEkADJtzAOOOCARp3GuTZBB5GbfY1qC1dddVWkRF999dWNuJRfdd1jjz0ihRH0ERYOKjzzGtWkY1D4/K0LAX2BMnHixK4feFdKBMoyhsYBJ8++yZUuFwGC78PMyPbv398qve6KS71/8sknm/oxmNsyz0ciAaLQKveVBEhw+eAsSpi/NMdr87lXr17ehfJcEFPTHd+tvfbaHgT7kjgI+z333HPe7bff/n/sfQkcVsPe/yCJsu8VyR6RFL36K1mylbiylVT2Snnlute1F9lfu4QuudmJpN6UZCncsiayZbshWSt7Xjn/+Z575+nsz9mfOef5zufzfM4285vffOeZ35kz35nfmC63LrnkEuOAAw5wuXV26uN33aBBA2PKlClxVEktTa0xTa0gBRBEAqQAlZSxijvvvLOv/cKqNKxWg/3ysxl+97FyPS/SXLdvszJgmvHfjuILigAJkIgVFzSztV+/foHS4HsQy4KdG5VbjS5YZ8iJsvJDZerUrZo+Kl2axyJ+3CctP1xD7LHHHp4vVQzMwz0TZgOVIehSv9jIzNpucI6PDYb0ECABkh6W1SRh4BpuJtR/Gh31oJCUAMEMKJWX83jWWWcFZR3qWffu3X3ljx8/PpQMRgqHAAmQcDjpEkuXd2hYPPK2TX56BREgX331lQHS2GnLrNcTJ070E+17/7jjjrPJvPHGG11xSYC4INHyBgmQ6tXy3nvvGXFWYqBt3n777dUziBEDK01mzZplXHzxxUbz5s1t7dHavr3O4YoTpEotg46Y1hKPrPImAZIVssWQC+LVywakeW/EiBGZg6HTt1lZMM280phBIREgARKh2uC6I8iYYpYcfI8+99xzxtSpU427777b9JsK34Jw2RP0gYZnRx99tOmTMIJKlaiYlebcmwQDI3mHon3cJ8HnrbfeMvCBjJlGQf8LPGvbtq3x2WefJclOi7S61C9miTkxnzlzphYYlUUJEiD51OR3331nWNvVgQceWDXjpATI2Wef7Wo/qj357TtSVSlLhJ49e/rKh39dhvQQIAGSHpZ5SLK2ddXmcESb1i3Uwjb5YRBEgCDNgAEDfG0O8EUfPEqA72vr/nzrrruupztTEiBRUK1dXBIg4bB///33bf0Rq43yOt9+++2NxYsXhxOeMBbIEKxy9Ztw5qUfJgbWeuNjnTFNWCXaJCcBok1V1EQRTKzyav/OeyBF0Ve44YYbDEyKwKTkSZMmGcOGDQt0Sa/kpDFBLAggnb7NyoJpEN58Vr8IkACJUPe33nprKAOrDGW1I5bi9e7d21zuG2fFh1V1DK4780vi99gqO8p5kT7uo5RLxYUvyLFjx5rLv7Fax4l50HXRYictAABAAElEQVTTpk2N119/XYkq5FGX+
n3llVdc2KNeGNJDgARIelgGSQLxrewGBvkWLlwYFN18lpQA6du3byVPlbc6YqZl0oD3mpLnPHrNok6aXz2nJwFSrNrX5R0aBrVa2CY/vaoRIBhktK6ic9odXM+ePdtPvOv+BRdcYLNhGJjwCiRAvFDR7x4JkPB18sknnxjbbbed7f/v1Z5wD20OLuicruHC5xYvJtoy9ivz08t6X4c+RxEwjVcTeqQiAaJHPdRKi/PPPz/QFoD4wLdNNbL2tttuMxo3bhwoC+RJVkGnb7OyYJpVXVFusREgARKh/oI+dKydrbDnWBUAVz5//PFHBC28o2L2uzPfcePGeUfO8G6RPu7jwIB9Xpw4R7kGCQJ3DUUNutSv12qsNAZui1ovWehNAiQLVO0ysUrQaj+cG+zaY6+4SkqABK3QuOqqq1ZkFPPs+OOPt5XLWkb49mZIDwESIOlhmYckXd6h1cpaK9vkp1c1AgTpDj/8cF+7Axt0xBFH+Im33YfLUmzmrOzWqquu6ruCN+i7YKeddrLJ5UXtECABEg37b775xlCDyqodBB2xcnXJkiXRMkkh9tVXX12V+ITtSDrJMAVVjaJgmkZZ85ah/qvrrbde3lkzPw0QOPjggyvva6edatSokemdJayac+bMMUCYOOWo62bNmhnLli0LKy5SPJ2+zcqCaaQKYOS6QYAESMiqxtLbIIOITV/79+9v9OnTx8Bm5F26dDGw8WKQ2ytlTNu0aWMkdd+D2e9Knjo+9dRTIUuXXrSifNzHLTHcnOEju0ePHuYqENRdkyZNXNirOvA6HnLIIXGzr3k6XeoXH1pObDHoypAeAiRA0sPSS9LHH39se6fg/RE2JCVAjjnmGFf7Ue3puuuuC6uGb7xTTjnFV34aLrZ8M67DByRAilXpurxDg1CrpW3y0ysMAfLCCy/42h3YN8xW/+CDD/yyqNzH6mllD3E89thjK8+cJyRAnIjoeU0CJHq9YFKRtR1UO8c3L2xH3mHy5MkG9lsM0g+bC+sQioKpDlhF0YEESBS0yhcXE4r92j++R6IGuMbyk4f7We15pNO3WVkwjVr3jF8fCJAACVnPM2bMCDSGc+fO9ZSE1R3vvPOO6V8wiAxZZZVVDMxkiRuuvfZal34vv/xyXHGx0xXh4z524XwSoo4xY2Do0KGufVj8XqC1qBsf9SPd1qV+gTnajBXfbt26RSoLIwcjQAIkGJ8kT3///Xdjzz33rPx/W7ZsGWmGYlICZODAgZW8rW0I50neQwqToBUgcJfBkB4CJEDSwzIPSbq8Q/3KWmvb5KdXGAIEadV7y2nX1PWgQYP8sqjc33nnnW328dVXX608c56QAHEiouc1CZDw9QJXv0EDcaoteR032mijmrj6xXc28vbSCfdqPUGqiJiG/8fUPiYJkNrXQS01wDeUX9uPOxnY+o3mlL3VVlsZ6CulHXT6NisLpmnXEeWVAwESICHrEbNWnQZQXWOpfBg3Vph51q5dO185kBfXt+CIESNccufPnx+ydOlF0/3jPr2SekuCn1dsjF5tf5BevXp5C9D8rk71u8EGG9j+83vvvbfm6BVLPTWQhJVvDOkicOmll1b+u5iV/Pzzz0fKICkBEuTbNQ0XVUF7gFx//fWRysrIwQiQAAnGR7enOr1DvbCptW3y0gn3whIgjz32WMW2qj669YiJSF9//bVfNsb06dNt6bGaOyiQAAlCR59nJEDC1QXcyzoJQAz23XvvvUaLFi1sbcParqzncEMURBqG0yR6rNGjR/vq17x58+gCU0pRZExTgiBzMSRAModY6wzWWWcd37YP13NxwkMPPeQrE/YOq0TSDjp9m5UF07TriPLKgQAJkJD12L59e19DiI0iw4bPP//c2GSTTXxl4ePs4xhLiM8991yXzFrsNaH7x33Yekoab9SoUa76sH4gYLl2FrMHkupdLb1O9bvtttvaMO7QoUM19fk8AgIkQCKAFSEqVn/Bp7yyB7DdUUNSAsRrxaDSBx3wpCHIj+3DDz+cVDzTWxAgAWIBowCnOr1DnXDpYJucOqnrsAQINmOutoHzsGHDlFjXES5KlS3E8fHHH3fFsd4gAWJFQ99zEiDV6+att94ynPsctm7d2li4cKGZGBsIV9tnR7UdkCDvv/9+9UxTjIG27yRvlD6YlFaLb66iY5pi9WQqigRIpvBqLxwTyVRbtx4bNmwYW3e4vgdxapVnPc9iMpdO32ZlwTT2H4AJS40ACZAQ1QsiIWhGf1RfgNaOuNWYqvPu3buH0Moe5YwzznAZ6d9++80eKYcrnT/ucyi+LYsrr7zSVSeqjnF85ZVXbPGLcKFT/aoOr8KUG46m+w8iAZIunpAGNwjWwbldd93ViGOnkxIgd911l69twj5HSQPeYapdOo8vvvhiUvFMb0GABIgFjAKc6vQOtcKli22y6mQ9D0uAIA365E67Y73G6tGff/7ZKt48x6pp60c/bHW11d0kQFwwannD+t01cuRILXWspVIvvfSSAdLC2k4w+Oc1e/qWW24xsLGwNa7XeatWrYylS5fmWqxp06b56pX3pMCyYJprBcbMTH0PchP0mAAWPNlqq63m2e7XXXfdRCWzroh12rjTTz89kWyvxDp9m5UFUy+ceY8IkAAJ8R/A0l+n4bNeR12xgQ+qLbfcMlBm1MHxU0891SYPhqsWQdeP+1pggTyxOsj6X7GejxkzplZqxc5Xp/o98MADbdhiwJ4hPQRIgKSHpZJktdNY7ff222+rR5GOSQmQiRMn2tqO1S4NGDAgki5ekffaay9f+XATyJAeAiRA0sMyD0k6vUOt5dXFNll1sp5HIUB++eWXwP0AYO8wiOsMQ4YMsdktrOStFkiAVENIj+ckQPzrAe6ZnZuIY9JfkO98fKNuvvnmtvZi7Ueo87PPPts/44ye+G3eG7e/FUfNsmEaB4M805AAyRNt/fJq1qyZpy2CXUsSnC4xlV3DMc5k5Wq66PRtVhZMq2HO5/WJwEootmzIDAEI9OvXT4wdO9YzxjbbbCPkMl/PZ0E3hw8fLuQyfN8of//738WJJ57o+9z54LTTThPyg65yu0GDBkIu36tc53UiO8Ti008/dWUnB+yEXBHhul/2G2+++aZo06YNiEZXUeVmw+Kss85y3df5hk71u//++ws526sC1/bbby/kRoiVa54kQ0BuACdeeOEFIfcAEXIWXzJhTC2mTJkiDjrooAoSRx11lJADbpXrKCdyYE7cd999riR4p+y7776u+3K2ppBuHCv3pU9qIf16V66tJ4cddpgYP3689Vbkczlz2vO9KAdVxLJly4R0ARZZJhN4I3DNNddU3iNPP/20kHsheUfkXS0Q0OkdqgDRyTYpnZxHuWpDfPvtt+ZtuUGzuP/++51RbNdyXzxxwQUX2O5ZL2D/0HeXKz7M23jHyRnv4scffzSvJeFi9mUlUW1N5jqXLoF87aVclSrmzp3rSsMb+SMwefJk0a1bNzNjuQJEDBo0KH8lNMzx119/FR07dhSvv/66TTt8U9588822e84LuZeOkJO8xDPPPON8VLlu3Lix
kJMeBNpvXkHuwyjuueceV3aStBFyH07X/bRvlBHTtDFKW94ee+whZs2aJeQKkMp7Iu08KE9fBDDO4vWuTToWhj4Cvme8wg477CDmzZvn9Sj2PZ2+zcqCaezKYMJyI1CfvE/4UmO1RtCeHXGXwMEPuvxn+f6GDh0aXkkZUw6ku2TJD7lIMtKIrOvsxjTKFleG06e0qvdzzjknrsiapdOpfnfffXfbfx7uhBjSQ4ArQNLDEpKwubhq+7U4Pvfcc7YCWWdUW/VJYy8dSZp5lrVp06Y2HXiRHAGuAEmOYZ4SdHqHqnLrZpuUXtaj1V5JAsT6yPMcrnvk4KunHVL2zrofkSQSbXHPO+88T7nOm1wB4kREz2uuAPGuFz8XL++++653Asdd+Mnv06ePre2o9qWOWLGaZ7j44os99XnvvfdyUaOMmOYCXIJMuAIkAXglSCon/ni2edggOekqdgkxjqbsmPOI/kUWwdrXseaZ97dZmTDNop4os9gI0AVWlfqbM2eOr/GDYZKziqpI8H78z3/+M1DuAQcc4J3Q5+6FF17okqc2rvNJksltHT/uMyloBKFyJpKrbvDfibP5cYRsM4mqU/1a91IAnnLFQiZlrlehJEDSrflaDzI+++yztgI5XcipjnaLFi1s8aJeYC8BJct5PPLII6OKY/wqCJAAqQKQZo91eocqaHSzTUov69E6KBCGAEFap0srpz1SAwrYHHmLLbao2C24kP3iiy+s2fuekwDxhUarByRA3NWB/72XPZKrRd2RA+5gH7P99tuv0n6c7QwTIvIMDzzwgKcu2MQ961BWTLPGLal8EiBJESx2+l69enm2ediisO9yPwSwj4jTpuF6zTXX9EuS6L4u32ZlwjRRhTBxKREgAVKlWoM2ssYmcF4bKVYRaT6ePXu2p0FVRjbqYO4VV1zhkifdAYVRJdU4Xp1plCnvGUCpFiqhML+6vvzyyxNKzj+5TvW78cYb2/7z0iVW/oCUOEcSIOlWrm6DjF6kOWx1w4YNjeXLl8cuvFwybmuX6p2G46233hpbLhN6I0ACxBsXXe/q9A5VGOlmm5Re1mMcAuTDDz80VlllFV97BJuElXGPPPKILU7//v2tWQeekwAJhEebhyRA3FUhXQbZ/vfqXS1dx7kjV7mDzc6lG1pPeZCbdBCySva2xzNnznTpgX5NHqGsmOaBXZI8SIAkQa/4aa+77jpXm1f2bMaMGYkK2Lp1a0/ZLVu2TCTXL7Eu32ZlwtQPa96vXwRIgFSp+3322cfT8MGwRl2lYc0qaKMjyMbm2VGC9Gnr0hMdsbyDjh/3eWPgzO+7775z1Q3q+KGHHnJG1f5ap/oFAak6ODhiIIIhPQRIgKSHJSTpNsgY9A5KslnouHHjbO3S2kZBjjCkiwAJkHTxzFqaTu9QVVbdbJPSy3qMQ4AgvdxrydcewTZhI9NOnTrZ4khf4tasA89JgATCo81DEiDuqvBzxRx3ooLcl8zWjqzvfrmfnFuBjO7g28qaN8633HLLjHKziy0rpvZS6ndFAkS/OslTI7+Jpmj7ck/dRKrstttuLnsCubifRdDl26xMmGZRT5RZbAS4Cbq0Yn4BmyHKjy4hl/d6RpHsqDjjjDM8n1W7KQctxF/+8hffaNgcG5tkhw3SYIoePXrYok+dOlVgo+g8g44bfOZZfq+8fvjhB3Mjaeez+fPni6233tp5W+trXeoXGyljY2drkPvxiBtuuMF6i+cJEOAm6AnA80gqSQUh3d5h0oHH02i3YDvkCj9XImzw2axZM9d93LjsssvEjjvuWHmGTX/l/lYCG3Y6AzZYl8ufnbdDXWPjYWxA7Axy4FcsWLDAeZvXCRHgJugJAcw5uS7vUGuxdbNNVt3UedRN0FW6l19+Wcj9wtRl1WPXrl3Fk08+WTWeisBN0BUSeh+5Cbq7ftBf9vqGffTRR4Uk9twJqtyRK0fFtttuK7CRrzOMHTtWYHPyPILcA0RcdNFFtqxOPPFEIQdCbfeyuCgrpllglaZMboKeJprFkyX3IhJrr722+OWXX1zKRx1PcwqQ7jHFv/71L+dtc8xtwoQJrvtJb+jybVYmTJPWCdOXEIFi8zfZah/Ewsq/gjFv3rzYCgRtLgTZ//jHPyLJfvPNN10MdS1WGOg4uzESkBlE9trvBT4l//jjjwxyy1akLvX75Zdfuv7vWK7JkB4CXAGSHpZpS4L7PLwnnD/4vo4S/GZHJ3FZiBnVTr1wffzxx0dRjXFDIsAVICGB0iSaLu/QrOBIyzY59VtvvfUqdiXqCukuXbpU0nrZJuu9J554wpl14DVXgATCo81DrgBxV4Wfq5Wo359WyQMGDPBsa6NGjbJGy/QcewRZ2zTOH3zwwUzzVMLLiqkqn65HrgDRtWby02uvvfZytXu0/ah7Gjk1xp5gTnuCa7jIzyro8m1WJkyzqivKLSYCdIEVUG+DBw/2NHowfE2bNg1IGfwIg7errrqqr+w11ljD+P7774OFOJ7KVQYueXGXMTtER7os+8d9JDD+E3n06NGuujnttNPiiKp5Gl3q12ufgccee6zm+JRJARIg+tZmWoOMkyZNctkmvN86duwYq/DYAFTO1PaUCSKYIX0ESICkj2mWEnV5h2ZVxrRsk1M/bDiqBiGiurv0s3NKnjrKFXKRJ6aQAHHWlJ7XJEDc9eLlOhltQa5UdUcOecdrP0rIxD47eYWddtqpYiuQNwYwv/3221yyLyumuYCXIBMSIAnAK0nSO+64w9bu1Xt9pZVWMr766qtYpYTdUHKcx5deeimWzDCJ/PoseX+blQnTMLgzTv0gQAIkoK632WYbX8Mnl/IGpAx+NGjQIF+5MLB9+vQJFuDz1DnwBAIn71D2j/s4eA4ZMsRV33PmzIkjquZpdKnf6dOnuzB94403ao5PmRQgAaJvbaY1yCiXOBsbbbSRqy3hgyHOpqXPPPOMSxbeadKdmr5gFlwzEiDFqkBd3qFZoZaWbXLqZ52FefDBBzsfB15jte0OO+zgaZusgxpxfIUfeuihvnIxEMugBwIkQNz1YMXE2g569uzpjhzyjt/GuXntAeK1Abp0jxtS++TRyohpclSyl0ACJHuMdc/hp59+MqQbLM/38V133RVLfS/vKrCV2JMME76yCrp8m5UJ06zqinKLiQAJEJ96+/DDDz2NqOokghWNE15//XWjQYMGvrKx+uODDz6II9rckEnph6P0iRlLTpJEeX7cY5k2XCFgcE36sTV0nGGM1T7OF3K3bt1iQaxDefOs3yCQvAZZoq6aCpLPZ4ZBAkTff4HX/x82P6oLLJRw6NChnu+jW265JTIAXmQv9OLqrMhQhk5AAiQ0VFpEzOsdKvf2MTDoePfddxszZsww8EGdR0jTNil95X58NhsVx6XFnXfeaZNh7SvjHEQwMIsaDjroIF+5co+3qOIYPyMErAPTmKXPYBheK6nRFjD5L2449dRTXe0B5GVe/XN891rbNr6pFy1aFLk4ce1nGTGNDF4NEpAAqQHoGmYJ7xrW9q/Od91111ja+q1ok3sdhpIX145AuC7fZrp
hGgp4RiICVRAgAeIDkNzIzNOIKmP67rvv+qT0v/3JJ5+YrrOUDK/jzTff7C+gyhOnkULHL0uG2ksd52C/KmMSn/Je+Zx88smu+oFbMezbolOQG+/Z9GzSpIkhN9OKrKIu5c2rfqsBJDceteHaqlWrakn4PCICJEAiApZj9DQHGUHSYk8iZavVMeogIwYpN9xwQ5ccuSlqZLcyOUJZ+KxIgBSrCvN4h7766qsGBt9VW8YR7RCDc1mHNG2T0tU5qLjxxhurR6GPGIjYdNNNbZhY8Rk2bFhoWdaImIBjlWM9X2eddaxReV5DBEiAeIO/2267ef5/H374Ye8EVe46CQi0h2ortuTGxca0adPMiRJxVp4qleBmy9r+cH711Verx6GPSe2nDpiGLmxJIpIAKUlFJizGe++9Z7q8c9oBXMNzRJSAlaPOfhTkrL766qFcaiW1I7p8m+mEaZT6Y1wiEIQACRAfdHbffXdXR0oZ1FVWWSXygA5mwG6xxRa+MiG7X79+keVa1cfsX6WjOmL5Xl4BPhZVvs4jVmikFfBSccpX1yB9li5dmigruHE59thjjZYtWxpdu3Y1hg8fbnz99deRZb788svGyiuvbNM1zsqhrMsbtmB51W8YfZo3b27D9ZRTTgmTjHEiIEACJAJYOUdNe5Dxtttus7UnZU+j+O0eMWKESwZcaU2ePDlndOorOxIgxanvPN6hmGntt8oEJAj2i8sypG2boOs999zjsi1YpR01+OnWqFGjUAMaXvk5+yLKduII+5fXyhsv3XhvBQIkQFZgYT0bM2aMq23hv4vvn6grol577TVPDwePPvqoNUvb+fXXX28OKFrbzfbbb2+MGzfOFq/aBYgT52BlHFfVadjPWmNaDasyPicBUsZajVemSy+91NOmwSUlyNawwc+l7/nnn19VRBp2BJno8m2mA6ZVQWcEIhABARIgHmDNmzfP03haO2hhZ/Fj8HqvvfaqKg8+SsE2Jwmff/65Kx+4TcorPP/88678FWaYsZ9WwCoZJdfrOGrUqNhZYRaz1wxNrNzAksewy7ifffZZ12zDK6+8MpZeWZY3ikJ51W81nRYuXOiqfwyQMKSLAAmQdPFMU5rfQF4cF1jQC+8er5mb2223nQGbWC18+umnnnYz7DLxavL53B8BEiD+2Oj2JI936NixY13vR2s/acKECZnCkrZtgrKYQW4tA84xMSVqWLx4sYG+nFMWVtjGCX4DJFb5dP8XB9n005AA8cYUJMcuu+ziahP4D19yySXeiTzugujzkrP33nt7xP73reXLlxstWrTwzBv5YwJaGG8L+M52EpGdO3eOTOBAqzTsZy0x9QW75A9IgJS8giMUD7YILq+s72F1HnZvXMg44IADXDLwnRRmUkMadgRF1uXbTAdMI/wFGJUIVEWABIgDInTIgpa0KyParl0745133nGk/vclPrLQ2cbsE8wAU2m8jnA9AoY3rbDVVlvZ8str8zfghs6qVxlxDzMSw5IH1bDAjCG/fHB/v/32qybC97nXKhprXnChgMF2P7IK9y+77DIDq4RUuoYNGxoYpIobsixvWJ3yrN9qOmFAQWGrjmEJyWqy+XwFAiRAVmCh21kWg4x4n3m5sIKP+6AOP953rVu3drVJvA9gNxiyRYAESLb4piU9r3co3I2q96LX8dxzz02rSJ5y0rRNsDsgOrzKASIjygo1pazTrzb66G+//bZ6HPqImaReAyROXfGtgI1EGWqLAAkQf/yxmsrLDSb+ywMHDqw6axr/75NOOsnVTuGqLmhPS7iFdrYX5zW+n2Czfv75Z88C4JsNLmms6U444QRj2bJlnvGr3UzLftYK02rlK+tzEiBlrdl45cI7HfbHahfUOfoAQd8msB2HHXaYKy3Gfz7++ONQCqVlR5CZLt9mtcY0FPCMRARCIkAC5D9A/fbbb6Z/wI4dO7qMnjKaXsedd97ZnJ12yCGHmMQJXAxUIz0gBwPkcHkFH39phuOPP96mPwYxsw4fffSRZ+fXiRdm8c2ZMyexOnAt5ZRtvU6ygd+UKVMCZat8MDMas6MwGD937lwDMzvha9Y5Awr/j6RlzrK8YSoj7/qtptN5551nqyPMIGNIHwESIOljmpbENAcZrTqhg+vlJ3///ff3nImJTZadNg82sm3btrHcBlp14Xk4BEiAhMOplrHyfIdipanqp3gdr7rqqkyhSGqbMDCBCQ0Y2IQd8SqD9V7fvn1NN3sYmAga1FCFxqBrgwYNKnJB8IYN2FMP+5Fg0pKfmzGrbup8k002MbCvIFwEJXXRGlZXxrMjQALEjofzCqvWN9hgg0q7UP9dHPEd89xzz7mIEEz4evLJJ40tt9zSlW799dc33njjDWc2tmu01/XWW8+V1pq3OocLabjZRD0+8cQT5qQyp31Au77xxhtteUS9SNN+1gLTqOUtS3wSIGWpyfTKMX/+fN8VZvvss49p06y5wZ7BbTz6BMruqCPGfNDHCBvStCPIU5dvs1piGhZ7xiMCYRAgASJROuaYY1wzSJTRS/uIFRrwpQd3VVmEe++912a4GzduHOqjMIouWG7cq1cvo1u3bubM3zCEjxVHfDhilQ3S9+/fP0rWZlx0mpHWKtN6jjLHDZjJ5FxFY5Ud9hy+aLFyI+4sJKv+WZbXmo86r3X9Kj38js6VRoMGDfKLyvsJECABkgC8jJMmHWQMUg8zF73cNq666qpG9+7djbPOOsvAnk5wbeFlD/E+5YznIITTfUYCJF0805BWy3coBtn9+mTYkyzphIxq+CSxTddcc42NnPCyL0H3MBM8aLWa0r13794V2/XUU0+p24FHDLA593QL0iXo2ZFHHhmYFx+mjwAJkOqYwpVl0CRAEAxt2rQxjjjiCNNlppc7udVWW80488wzjW+//bZ6hjKG9f0V1GaCnmFCIYhQDM4lDWnbz1pgmhSDIqYnAVLEWsteZ4y19ejRo/K+d9oRTFo+9NBDzYnMfmRsz549jW+++SaSsmnbEWSuy7dZrTCNVAGMTASqIEACRALUvn17X+PoNJZhrjEAD3+k2HCpS5cuBtxQwW1SGp2zKvVpLhNeZ511bOWZOXNmtWSRnr/44os2+WEw8YuDzSfjBLgfQCe7VatW5kcpZiBhUA75YCk3ZurFDWD5DzzwQN9BBL+yrLnmmuaLFh9afi6y4uqUZXmdOulQv06d1DXq1fn/xgoZhvQRIAGSPqZpSfQbZHzmmWfSysLA/lFB/rmddhArQWD7GPJFwDqA9PTTT+ebOXPzRKDW79ALL7zQ1UcDKQL3nFmHJLbp9ttvd+nttDNB15jBHqbvhcEJyMFgbtiAjZmD8o7yDKuhGfJFgARIOLzRx8bKdsyCjkL4wb6AWIwyS1ppNHr0aAMzrKO0IcTFZLpTTz3VXJWlZKVxTNt+1gLTNHAokgwSIEWqrfx1nTp1qgF3lFFsDPb7wMq3uCFtO6L00OXbrBaYKgx4JAJJEVgJAqRBYCgRAqeddpq45ZZbKiXCtdxIu3Jd1hPpikrID1ohZwMJOQtQyA55oqJKtl2MGTNGSJ
cHQm68bf7krCYhZz0J6Svf/G200UZCvlSF3HBPyOXYZt6JMo2QOO3yRsi6ZlGlizIhP8wq+cu9B4Rcslq55kl6CMhVWkK6OBJrrbWWkG470hNMSYVBAN2D2bNni3HjxpntTM78EV988YWQAyOm/ZPusgT+J2iTHTp0SGxzCwOMRorKWfNCrsoxNZIEiJArczTSjqrUCoHHH39cPPzww2LRokVCznIUcmWW6NSpU63UYb5EoKYITJ48WciV46YOI0eOFHLlcE31KULm0l2c+e7Ht9CCBQuEdE0n0AeQK0GEnC1t/iQxKORkGSEn+wm58j12sdDXkKuxxKRJk8w8kA++u5YsWSLWXnvtyvcWvrnkpEUhV4IL5J1VyMp+5olpVtjoKFcOVotZs2aZ/0l8pzMQAS8EJEErxo8fL6TbctO+4HsGNkZOnBX4nmnatKlpy+SqELH55pt7iYh0Lys7otO3Wd6YRqoARiYCPgiQAPEBpsi3pYsDczBelUFuBGV2KEEMlDnIGV5CumgR6CDLvVXKXFSzbPVWXhS6T58+Qrp5q9TttddeK+SGZpVrnqSHAAmQ9LCkJCKQFQIkQLJClnKJABEoCwIkQMpSkywHEXAjQALEjQnvEAEiQASIgDcCJEC8cSn8XcyQkX6oK+WYPn26kJs+Va7LeDJq1ChzVpd0PSawOqLsod7K++OPPwq5maiQ+wuYVSv3JDCJPazGYUgfARIg6WNKiUQgbQRIgKSNKOURASJQNgRIgJStRlkeIrACARIgK7DgGREgAkSACAQjQAIkGJ/CPr3tttvEgAEDKvrD/cH9999fuS7jCcr44IMPin79+om77rqrjEW0laneyiv9Xor+/ftXMJAbk5nL8ys3eJIqAiRAUoWTwohAJgiQAMkEVgolAkSgRAiQAClRZbIoRMCBAAkQByC8JAJEgAgQAV8ESID4QlPsB5gtD7/P8G+IAPdX2MtCbhZe7IL5aP/zzz+brq+wOuCOO+4QJ5xwgk/Mctyut/LC3yVW9sybN8+sQPggfuONN8QOO+xQjgrVsBQkQDSsFKpEBBwIkABxAMJLIkAEiIADARIgDkB4SQRKhAAJkBJVJotCBIgAEcgYARIgGQNcS/HYKwF7JqgwcOBA2+bo6n4Zjtjs86ijjjI358Mmc82aNStDsXzLUG/lfeSRR8QRRxxRwWPw4MHipptuqlzzJH0ESICkjyklEoG0ESABkjailEcEiEDZECABUrYaZXmIwAoESICswIJnRIAIEAEiEIwACZBgfAr/tHPnzmLmzJlmORo1aiRADmBT9DIFrA7o2LGjmDVrlujbt6+Aq6Qyh3osb9u2bc0VH6jX9dZbT8yfP988lrmea102EiC1rgHmTwSqI0ACpDpGjEEEiEB9I0ACpL7rn6UvNwIkQMpdvywdESACRCBNBEiApImmhrLgJqhdu3Zi+fLlpnannHKKwP4gZQrY26R3795ipZVWEm+99Vbp3SLVW3mdqz+w8gMrQBiyRYAESLb4UjoRSAMBEiBpoEgZRIAIlBkBEiBlrl2Wrd4RIAFS7/8Alp8IEAEiEB4BEiDhsSpszCFDhoibb77Z1B8kwYwZMwQGN8sQfvjhB9G6dWuxYMEC0aNHDzFhwoQyFMu3DPVW3qVLl5qE1sKFC01MsOcHSD3sAcKQLQIkQLLFl9KJQBoIkABJA0XKIAJEoMwIkAApc+2ybPWOAAmQev8HsPxEgAgQgfAIkAAJj1VhYy5ZskS0atVKLFq0yCwDzufMmSMaNmxY2DJB8T/++EMceuihYtKkSeYm73CB1b59+0KXKUj5eisvsDjppJPMTe1xvvLKK4vp06eLLl264JIhYwRIgGQMMMUTgRQQIAGSAogUQQSIQKkRIAFS6upl4eocARIgdf4HYPGJABEgAhEQIAESAawiR33hhRfEPvvsI3777TezGBdddJEYNmxYkYsk/vrXv4qrr77aLMMFF1wgLr744kKXp5ry9Vbep556SnTt2rUCy4UXXiiGDx9eueZJtgiQAMkWX0onAmkgQAIkDRQpgwgQgTIjQAKkzLXLstU7AiRA6v0fwPITASJABMIjQAIkPFaFjzlmzBhxwgknmOXAbHq4i+revXshyzVlyhRx0EEHmbofcsghYvz48eYqkEIWJoTS9VZeuLzq0KGD+Oyzz0x09ttvPzF16lRzFUgIuBglBQRIgKQAIkUQgYwRIAGSMcAUTwSIQOERIAFS+CpkAYiALwIkQHyh4QMiQASIABFwIEACxAFI2S/PPPNMcd1115nFbNKkiZg5c6bYZZddCldstTrg8MMPF/fcc49YffXVC1eGKArXU3l/+ukn0blzZ/Haa6+ZEO24444CK5jWXnvtKJAxbkIESIAkBJDJiUAOCJAAyQFkZkEEiEChESABUujqo/JEIBABEiCB8PAhESACRIAIWBAgAWIBox5Oly9fLrBi4oknnjCL27x5czF79mzRtGnTwhUfG4KvueaahdM7rsL1UF7sc3LYYYeJiRMnmjBtttlmJvmBI0O+CJAAyRdv5kYE4iBAAiQOakxDBIhAPSFAAqSeaptlrTcESIDUW42zvESACBCB+AiQAImPXWFT/vrrrwIrJxQJ0qZNG/HMM8+Iddddt7BlouLlQGDw4MFi5MiRZmFatGhh/i9btmxZjsIVrBQkQApWYVS3LhEgAVKX1c5CEwEiEAEBEiARwGJUIlAwBEiAFKzCqC4RIAJEoIYIkACpIfi1zBqboZ988sli7NixphpwjYWBFAYiUCsE7r//ftG7d28z+7Zt25qrQJo1a1Yrdeo+XxIgdf8XIAAFQIAESAEqiSoSASJQUwRIgNQUfmZOBDJFgARIpvBSOBEgAkSgVAiQAClVdUYvzK233iouu+wyc3P0YcOGRRfAFEQgJQSefvpp0atXL3H00UeLK664QqyxxhopSaaYOAiQAImDGtMQgXwRIAGSL97MjQgQgeIhQAKkeHVGjYlAWARIgIRFivGIABEgAkSABAj/A0SACBABIuBCgASICxLeIALaIUACRLsqoUJEgAhohgAJEM0qhOoQgRQRIAGSIpgURQSIABEoOQIkQEpewSweESACRCAOAiRA4qDGNEQgXwRIgOSLN3MjAkSgeAiQAClenVFjIhAWARIgYZFiPCJABIgAESABwv8AESACRIAIuBAgAeKChDeIgHYIkADRrkqoEBEgApohQAJEswqhOkQgRQRIgKQIJkURASJABEqOAAmQklcwi0cEiAARiIMACZA4qDENEcgXARIg+eLN3IgAESgeAiRAildn1JgIhEWABEhYpBiPCBABIkAESIDwP0AEiAARIAIuBEiAuCDhDSKgHQIkQLSrEipEBIiAZgiQANGsQqgOEUgRARIgKYJJUUSACBCBkiNAAqTkFcziEQEiQATiIEACJA5qTEME8kWABEi+eDM3IkAEiocACZDi1Rk1JgJhESABEhYpxiMCRIAIEAESIPwPEAEiQASIgAuB/v37iw8++ECsvvrqYtq0aa7nvEEEiEDtEbj99tvF2LFjTUUuv
/xy0alTp9orRQ2IABEgAhohgD7M8OHDTY1OPvlk0a9fP420oypEgAgkQaBv377io48+Eo0bNxZTp05NIoppiQARIAJEoOQIkAApeQWzeESACBCBOAhwBUgc1JiGCOSLAFeA5Is3cyMCRKB4CHAFSPHqjBoTgbAIcAVIWKQYjwgQASJABEiA8D9ABIgAESACLgRIgLgg4Q0ioB0CJEC0qxIqRASIgGYIkADRrEKoDhFIEQESICmCSVFEgAgQgZIjQAKk5BXM4hEBIkAE4iBAAiQOakxDBPJFgARIvngzNyJABIqHAAmQ4tUZNSYCYREgARIWKcYjAkSACBABEiD8DxABIkAEiIALARIgLkh4gwhohwAJEO2qhAoRASKgGQIkQDSrEKpDBFJEgARIimBSFBEgAkSg5AiQACl5BbN4RIAIEIE4CJAAiYMa0xCBfBEgAZIv3syNCBCB4iFAAqR4dUaNiUBYBEiAhEWK8YgAESACRIAESJ3/B5YtWyYmTpwoOnToIDbbbLM6R0Ov4i9ZskTgo23fffcVG2+8sV7KUZvSI0ACpPRVzAKWAAESICWoRBaBCBCBTBEgAZIpvBROBGqKAAmQmsLPzIkAESAChUKABEihqitdZefPny969uwp3nzzTXHkkUeKhx56KN0MKC0RAjfffLMYMmSIWHfddcVdd90levTokUgeExOBKAiQAImCFuMSgdogQAKkNrgzVyJABIqDAAmQ4tQVNSUCUREgARIVMcYnAkSACNQvAiRA6rTuX3jhBXNA/bvvvhMrrbSSuO+++8QxxxxTp2joWez3339f/Nd//ZdYvHixWHnllcV1110nTj/9dD2VpValQ4AESOmqlAUqIQIkQEpYqSwSESACqSJAAiRVOCmMCGiFAAkQraqDyhABIkAEtEaABIjW1ZONcjNmzBDdunUTP/74o5nBtddeK4YOHZpNZpSaCIHXXntNdOrUSfz888+mnCuuuEKcffbZiWQyMREIgwAJkDAoMQ4RqC0CJEBqiz9zJwJEQH8ESIDoX0fUkAjERYAESFzkmI4IEAEiUH8IkACpszp//fXXRefOnSvkx6BBg8TIkSPrDIViFfeee+4Rxx13XEXpsWPH2q4rD3hCBFJEgARIimBSFBHICAESIBkBS7FEgAiUBgESIKWpyrovyPfffy/+7//+T6y//vp1j4UCgASIQoLHLBFg28sSXcomAvkhQAIkP6xrntO//vUvgU7CF198Yepy8MEHi8cff1ysssoqqei2cOFCgTy++uor8dNPP4mNNtpIbL755ubm6quvvnoqeegu5LfffhOTJk0SO+ywg9h+++1TU3fgwIHi1ltvNeWtttpq4rnnnjM3rk8tAwoiAg4ESIA4ANHw8uOPPxZwZ6gCVou1aNFCXaZ2/OOPPwRc8sG2f/311+L33383bfvWW28tNtxww9TyoaDoCJAAiY6ZTimy6jNkXcYFCxYI9Pnww2DcFltsIbbZZhux3nrrZZ015ROByAiQAIkMWeQEy5cvF3PnzhXvvPOO+OSTT2w/TOK64IILIstkAjsCP/zwg+jatatpd2GDGf6NAAkQ/hOyRoBtzx/hX375xfxG/Oabb8S3335rutZHn3DLLbckUesPG5/UEAESIDUEP8+sly1bJv7f//t/4tVXXzWzxYcq3Cs1adIkkRoYgMNm3U888YTZ6fUThoG5E044wdxsvXHjxn7RCnsfAwB33nmnuOyyywQ6pX/961/FlVdemVp5UH8YkH7llVdMmZtuuql4+eWXRbNmzVLLg4KIgBUBEiBWNPQ7h03YfffdzQEHpd0DDzwgjj76aHWZ+DhlyhTxj3/8Q0ybNs3s1HoJXHfddUXfvn3N/YnQ2WXIFwESIPninVZuWfcZ0tLTKufFF18U999/v5gwYYL49NNPrY8q57AH2LvszDPPFPvtt1/lfpgTDKBiheuvv/4aJnolDvpBPXr0qFzndfLGG28IYBI1YGIQ3NAy5IcACZBssIZXgSeffFLAtfLzzz8vMEPaGbCHIb6JLr/8cuejyjUmVYwePbpyndVJy5YtxYEHHpiV+Kpyn3nmGfHuu+9WjeeMAHzefvttgT4ZyKVDDz1UPPbYY85odXtNAqRuqz50wYvW9nS3iXDNjvE/TKTGeBQm83iFtdZaS+y8885i8ODBomfPnqJBgwZe0XiPCOSLgMEQCgE5+97YddddE//atWtn7LvvvsaRRx5pnHrqqcaFF15oSPLAkMxyKD3iRjrttNMM+c8yf3LTc2PmzJlxRZnp5ObphiQ0DGnIKnKV/KCjXLJryE5borx1SiwHMQzZaTck023DQXb2U1dTkk2GHFyo5NO+fXtDsu6p50OBRAAISMLU/K/JzgsB0RCBP//5zxVboGyuJEBS0VQS5YYkwFzyVT5eRznIYRxzzDHG4sWLU9GBQsIh8D//8z+Venr66afDJWKsmiGQZ58hrULKATtDEgyV/5lX+/e617ZtW+ORRx4JrcZDDz0UOQ/kK1fFGnJlc+h80oooJ/bE0leuvk5LBcoJicD//u//VupKuv0NmYrRvBCQBKUhJ0YYu+22WwVTZ/tv06aNcc455xiSeDKWLFniJcZ2D3bCKSOLa0kc2PLN8wLfa9IzQirlfOqpp/JUXfu8JOlu4ipXIGqvKxXMH4Eitj1dbaIkZow77rjDaNq0aWRbJj0UGNddd50BGQxEoJYIiFpmXqS8r776akO6iorc2MN24FZddVWjY8eOhtzvwZDuRlKFRs7OsektWdhE8uWMN0POYLPJDFtOFe9vf/tbIh1qnRiDGHLFhyFnPHvikAUBgjJL91oGCCyFo1xSXmsomH9JESABom/F4uPXageUPUiDAIFda9SoUcXGKNlyhp3x6KOPGvPnzzfee+89A3m1bt3aFU+uLjTeeustfcErmWYkQIpRobXqMyRFR67IMNA/VXYgzvG///u/Q33wvvnmm8ZZZ51l9OnTxxxcjTLBBoOteQaQxFGwWGONNQxMgOrVq5dx11135akq85IIkABJ/jeADcP7ZoMNNvD878vV/cZJJ51kvPTSS5EzO+CAAzxlRmljYeKecsopkXVLKwEmUobRsVoc6WI5LZVKI4cESGmqMpOCFLHt6WgTn332Wc/vvlatWpmEN8YwpYt2cywT43zo83jZM7kKLxQxnsmfgUKJgESABEiEvwEYS7nHhTkIdNRRR3k2aq+Gjk4hWE+sfvB67rwnl4oZIC3SCOiw7rjjjpV8sVLhxx9/jC1aLr81Z9s5dY5z/fe//z22HrVKiP8AZj5J3/cVTL3KnhUBgnKfeOKJlbwxUPnhhx/WCg7mW2IESIDoWbnSv6ohXb5UbIDV/iQlQG644QZPuTfddJMnMY/3Czq5Vh1wLl0rGtIlhp4AlkwrEiB6V6gOfYa4CF1yySWutu1s62GvpcsnA/YiSsCqV7n/WSgdsDo265XUVt3lngah9AJ5BFJH+sa2Jud5zgiQAEkGuNxrzMC3qVd7x38cK1LDrPTw0uKjjz7ynNDhlVfSe1iRUosg3fsZmBySVH+k
HzVqVC2KoHWeJEC0rp6aKlfEtqejTfQa/1tzzTVNUly6v/KsY2CP70qMgzptH1YQSlfOnul4kwhkjQAJkAQIYyaJs0Fbr+XG3waW9Fs/+vCBhtmx6MAELR/DahPpazmBdv9OKv3z2XS87777YsuU+08YKJO1jDiHGwC42MKyNhi6AQMGmKtZvGYoW9PCbQEIpSIEGHEw29tuu62r/NYyqfMsCRC4o7Bie8ghhxQBQupYMARIgOhZYXCfqOyM85iEAMHKPq+Z3iNGjKgKxPHHH+/SCa4e5D4BVdMyQjIESIAkwy+r1Dr1GeKU8ZZbbnG1aae9iXodd5UG+jhh8kIfNI8Ad1sNGzYMpdOgQYPyUIl5VEGABEgVgHwew3UM3DVbvzmsbXH//fc35KbnPqnD3YZdsMrM6nzttdeu2YBbWu5s1llnnUSTGMPVSPFikQApXp3lpXER255uNnHevHkuzwAYpwQxHiZg3BNjfk7bfsYZZ4RJzjhEIHUESIAkgHTMmDGuxmxt3Ji5FhR++ukn4/zzz/eVAeMyfvz4IBGBz7BPB/xhKp222247Ax/lcYLc7MjYfvvtK7IgE36Mg5Y6YwawM43SRR3jfhDHKUPcNKgDazlQL14uYlSZcMySAEE54MfWmh8+7hiIQJoIkABJE810ZMF1irXdO8/jEiBff/210bx5c5ds7FcVJmD2z1ZbbeVKLzdpN+AvnCE7BEiAZIdtXMk69hmilEVuZuy5v1uXLl0Mubm5OYHn7rvvNuCC87DDDvN1Beq0T9gnaPr06VFUMeOCnHXK8rrGSmvrhKPIGYVMgHJ75e91DzM5GWqPAAmQ6HWwaNEiQw0sO//baMt49yQN6DtssskmoduTU48o13CtV6vQoUOHVMoI+8vgRkD9T7kHiBuber9TtLanm02EW341HmC1t8OGDYv010J8a3qc4z3yySefRJLDyEQgDQRWghD5J2SIgYBc7SAke+mbctq0aWK//fbzfa4eyM0lxcSJE9Wl7Sj32hAffPCBkDNzbffDXJx++ulCui6pRJUfrEJ2ACvXUU6uuOIKIckKM4kkAITcAEn069evqgi5vE3IjXGF3PjcM+6mm24qFi5c6PlMl5vy5Slef/11IUkHccQRRwjpu1DIjaGFXPIt5IxDTzUlASKuvPJKz2dp3JQDFEKuvKmIki65hGTYhWTYK/d4QgSSICA3wRZydof5X1+6dGkSUUybAgJyIE3ssssuQq4i9JUmCRBx9NFH+z73e+D3Dpo1a5aA/QsTpGtA0b9/f1dUOVgoLr74Ytd93kgHgWuuuUZIFzumMLkJuth7773TEUwpsRHQsc8QtjDyY1fsuuuu4o033qgkkRNAxPXXXy+kT+rKPecJ+nhyJXDV/hywgV2JEqTbPyH3HQiV5N577xW9e/cOFTdOJEnoCvTLJWlcNblcJSLQB2aoPQLS9ZGQbthMReQm6EKuzKm9UhprIPfjEXLllZCr9F1aypUUAn0NfAslDXJ2tvldZZWzzz77mP0Y6SVBbLjhhrG+f4cOHSrkJDyrWCGJaSEJW9u9PC6gx1577VXJSk4sEe3bt69chz2BDZQeFoR0MRo2Sd3Ek3vUme8VSYAIvC8YiAAQKGLb080mPvHEE0JOeLb9oeSeZgLjAnKfNtv9oAv0haQXFbFgwQJbtPPOO09ITwO2e7wgApkjkAaLUq8yglxgYcPDsDNf4SYEPtNlZXv+sClt1PD222/bZvBhzwr4oo4Tvv/++8r+JXCRArdeUcJXX31lyE6sZ9lQZmlEo4jLPa4ksgzMhHIG+OL3Wxae9QoQ6KJmvKj/jRygcKrIayIQGwE140OSfbFlMGE6CMB2q/pQ7d3rGGcFiN9mvtiAL0qAjltuuaXLzuPdhncAQzYIcAVINrgmkaprnyFMmZwrmw866CDDz7+zUx72AMAmyF62yXpv7ty5zqSB11i5jJmCVhl+523btg2UlfShnPwTSg/oJyf4JM2O6VNCgCtAwgOJVVrw7e7VxjDDHu5M0gpdu3at5IN9fLD6LGnAjGXnilb4oIcng1qE7t27V8oINzCff/55LdQodZ7qe5grQEpdzZELV8S2p5tNxMo557ugY8eOkesCCeSEaJcsSV7GksVERCAJAnSBlQA96+biTuOADR+jBPhRdcpQ13I2TBRRZlzsw6HS4yhXI0SWoRJY9xHBBuBxwtixY236WHWbM2dOHJFapEFny1oWdZ4HATJu3Dhb3nDTxUAE0kJADbiTAEkL0fhyrJsRg1x3usBTdicOASJnS9vsiJIFd1tRg9eG6JBHtw1RkQwfnwRIeKx0iFnLPkOY8sNtnbIBcGsHV6pRg1wpW5GhZFmPcnZ2VJEGfN9bZQSdg4DKKvhtBO2lzw477JCVGpQbEQESIOEAg6937JXh9X+G69+ZM2eGExQi1ocffliZRAa7+Nprr4VIVT2Kl8s82KRaBJBF1oly2C+TIX0ESICkj2nRJRax7eloE732K5YeYGL9PazfsuodAxeIDEQgbwRIgMRE/JtvvrF1alRDVscbb7wxkmR0ilRa5xF+jaMEbFpn/VjEzLnPPvssighb3DZt2pi6wddz3PDll1/6lk+6Q4grtubpNttsM89y5UGAYFYkVvZY/y9pzJ6qOahUQAsESIBoUQ3G7Nmzbav5sDkx7Iu13avzqASIdG9hk63kyGXNBla4RQ2w5UqG9YiBE7wDGNJHgARI+phmKbGWfYZq5cKAgWq3IFqjrtRQ8rFiFrO5lSznUbqGVVFDH9dff31feU75mFCURcDMeGte6JsHbdC+0047ZaEGZcZAgARIddCwUrNly5a2/7j6v2MQH5Ou0gzWCRPShV5qokGwKr3V8f77709NfhRB0i1oRRfpEs+AxweG9BEgAZI+pkWXWMS2p5tNxOpfr9W3cftYmEStbLI64t3CQATyRoAESEzE0ZlSjdfr+P7770eSfO211/rKw6bbYV0QIFPpA9kmK+xGtl4KSz/Qpix8fHq5gfJK43fPb6O7ODMM/fLI+36tBzOc/5u4rHzeuDE//REgAVL7Ovrxxx+NbbbZpmLPpR9WU6m0CBC5j1FFtvU9JveRiFV4uJ7ws/MgbhjSR4AESPqYZimx1n2GoLINHz68Yg8uvPDCoKhVnzldaVntCyZuRA1WAgSrMKyzqq2y1XkWK4udZIfcf8f405/+VMFM5a2OJECi1nJ28UmABGMLl81wa6L+u87j4MGDgwXEeKrcVB1//PExUnsnQR/EaWMxAQOunPMOmHgI0kNhCc8MDNkgQAIkG1yLKrWobU83m/jFF19U7JeyYzhiVUic4ByzgqyNN944jiimIQKJECABEhM+K7NsNQo4hx/0qAEfUk451msYobABhIc1bRxXJiov+HS+7bbbjGeffVbdin2Umx/Z9IKOzZo1iy1Ph4TOjrbCPY8VICi/IqhUvpi1qfueKjrUG3WojgAJkOoYZR3j5JNPrthMuQGmod4DaREgfmTF2WefHbtockP1is7KLuEYx5VjbCXqKCEJkGJVdq37DEFo7bnnnmb
bha1JOmAoN7z0XF0GW4BBwajBSoDAj7SfnVE2B36r0wzz58+3zYSEa0j0tUiApIlydrJIgARjKzei9Xxvoz1hUC6pPfDKHStWjz322FRl//Of/3SVA8RlLcJf/vIXmy5PPfVULdSoizxJgNRFNYcuZFHbnm42EZPwVJ/KeZwyZUro+lARQaQ75cTdT0TJ5JEIxEGABEgc1GQabG7obMTqeuDAgZGlwke6Su91DDuo/cknn7hmxiVxfxW5ID4JsPkcVrI4yxZ1rxQf8TW7XevBDLjBsro7A76caV2zv0OpMiYBUtvqHD9+vM1e4lqFNAiQt99+2ybfapuTuIvAzHGrLHUO+8/N0FUNpnckAZIelnlIqnWfwa+M6GPC9R3a60UXXeQXLdJ9v33yMFEjanASIDNmzPC0M8reoCwLFiyImo1v/CFDhtjyU/sakQDxhUyrByRA/KsDfuexObdqO87jhAkT/BNr9sTrWzrJJMC4xcPkQZCklP+F8QAAQABJREFUTiyx18kBBxxgXHDBBcbEiRM5YS0uwI50JEAcgNTxJduevfKT2kS/PaGwx9nvv/9uzyzgCuNVmCDutIn0XBIAGh9lhgAJkBjQOmfdOxtznM6ic9WGVSZcoIQNl19+uc24xF2mFja/sPG8NqVDGZ988smwIrSMp8NgBkgk6/+lS5cuWmJFpYqFAAmQ2tUXVnpgFrZq1yeccIJNmTQIkJtvvrkiX+Wjju+8844tvygXjz76qK/cO+64I4ooxg2BAAmQECBpFEWHPoMXHHDbij0tMBCa1N2pko+VGsqmWI9xNgd3EiDIo0OHDp7yVV5xNltXuluPGFBp0qRJJS+QK5hshEACxIqUvuckQPzrBvs7qjbjPO6xxx7+CTV7AvdXm2++ua0sq666qlELN8tXXXWVTQ8nruoaq+EOOuggY/To0dwnLcH/iQRIAvBKlpRtb0WFpmET/SaywIbh+yNswD5Pyu5Zj9y7NiyCjJcmAiRAYqAZZFzR2frhhx8iScWmaF6bDCkDEeUjDhsTqXQ4Jtm4PFIhqkQeMWKETS/o1rNnzyqp9H+sw2DGFVdcYcN29dVXN+B+goEIJEGABEgS9OKnRYf1wAMPrLRpzJhxvlPSIEBgf63vCnWOD3LM1IkbMJCqZDmPcVZHxtWjXtKRAClWTevQZwhCDKt10wrnnHOOpy2Is/LXiwDBpsxOG2O9XnPNNY3FixcnLo7TRe1RRx1VkUkCpAKF1ickQLyrZ9q0aYFtKMlqUO8cs7s7a9YsV1nibtabREt8f2HyodUWhTnHOABsC1bnMkRDgARINLzKGpttz16zadhErFbzs1+wWdjvrVrAPsa77babSw6+dRmIQC0QIAESA/Wg1RpxNo+FuwE/44L706dPD6UllqJZZ6kh7WWXXRYqbZaRPvroI6Nx48a2MmJzSOegXpY6ZCVbh8GMF154wYYt6h33GIhAEgRIgCRBL37am266qdKe4TbKqy0nJUBAslgHFK3vH8ygTBJ++umniv5WuTinr9ckyHqnJQHijYuud3XoM+SFzUknneRpC+JsqGy1V1hZgoA+71ZbbeWZh7I9WBWdJCCPLbbYwpbH7NmzKyJJgFSg0PqEBIh39WD1gWorziMG8TFwVZTw5z//2VUW7GGZd7jzzjtdejixDbrGoCL2MPr888/zVr2w+ZEAKWzVpao4254dzjRs4pdffmk0atTI16bBXo0dO9aeseNq0KBBrvQYF5w7d64jJi+JQD4IkACJiDM2BArylYrZ+FHCK6+8Ym4I6dcZOvzwww0MVoUJL7/8ssvA1HrTNegOUshavt1339345ptvwhRJ+zg6DGZgxoPz5ZT0o1974Klg5giQAMkcYlcG8+bNs7VlbEzqFZISIB988IHNJjvts1eeUe55+b5GHpiRHfZ9FiW/eo5LAqRYta9DnyEvxJx9P2Vn7r777sgqeBEgEDJy5EhfW4b8sF9fkhWxjzzyiE0+3ovWQALEioa+5yRA3HUDV5teezOqdnruuedWEmFVKNw/Y49B7PN1xhlnGBhcu/LKKw3ssYE+RS0D+hVw4ad0xxFlw+Bd3gET/Kx6xD2H7cKm7gzVESABUh2jeojBtreiltO0ibD3QXYMJAi843h933m5W4angSeeeGKFsjwjAjkjQAIkIuDWTrSXMZgzZ05oidigceutt/Y1Kq1atTK+//770PKcy/RXWmmlmm+w5nTPhJl7IJHKEnQZzOjcubPtf3TwwQeXBWKWo0YIkADJF3gM0u2yyy6VdtyuXTvf2ZdJCZCgDYQPPfTQxAXfdtttK+VwviexIpAhPQRIgKSHZR6SdOkz1Kqs2FAzjpstPwIEK86s+yU57Q2uk+w91KlTJ5stAyFiDSRArGjoe279dgNpxmCY/tu92ou69+yzz5or9a+//nqjZcuWtnag4qgjvjf32msvkwxBm8w7YFWW0kUdoU8twiabbOLSRekU9YgJl/fcc08tilGoPEmAFKq6MlOWbW8FtGnaxF9++cX2fepnx+ByUO0h93//938mUe6Mu+GGGxrc92NFPfGsNgiQAImI+5AhQ3w7NpitETZMnjzZ1wUJjAUGkObPnx9WnBnPuZEdCJRaBcwWcg7SYTP3pUuX1kqlTPLVZTADM8WtLxkMMiTx458JWBRaKARIgORbXX/5y18qbRj7+ARtRO60rartP/DAA6GUfvDBByt5qbTqeOqpp4aSERQJAw9KnvMY1qVjkHw+W4EACZAVWBThTJc+Q9ZYYeY1ZgU623/cfYD8CBCUo5obWWy67jUzsRoGr776qk1/uNty9qtIgFRDUY/nJEDc9bDzzjvb/t/WtooZupi5u/HGG/vGsca3nmOlZ94k01lnneXS88Ybb3QXOoc7mLj45ptvGk8++aTpGuaSSy4xDjjgAJeLaitmQecNGjQwpkyZkoPmxc2CBEhx6y5Nzdn2VqCZtk3EfsVwkRxkq/AM74wxcl8QrxXAXbt2NTD5m4EI1BoBEiARayBoZmu/fv0CpcGXKpZ8OTcqtxoTzKKBnCgrP1SmTt2q6aPSpX2EG5c99tjD00hiYB7umcAmlyHoMpiBjQqt/yOc42XFQATiIkACJC5y0dM9/fTTtsFCDDwEhaQECGZ0Ou2FukanOWno3r27r/zx48cnFc/0FgRIgFjAKMCpLn2GrKGy/i+VbcERpEKcEESAfPXVVwZIY2s+zvOJEydGzva4446zyfQaUCUBEhnWmiQgAWKHHYNQzjaS9vWIESPsmWZ45dynB9/Sun0DYUY0NiW++OKLjebNm0fCH25FQaoweCNAAsQbF979NwL12PaysInvvfeeEWeVDfpvt99+O/+OREAbBEiARKgKuO4I6iBiBi/8dT733HPG1KlTDfg5hk88+M6Di6KgDzQ8O/roo00fqxFUqkTFrDTn3iT4AM0zvPXWWwY+GDFbJQgnPGvbtq3x2Wef5aleJnnpMpiBmUZOzGfOnJlJmSm0PhAgAZJPPX/33XeG1Y4ceOCBVTNOSoCcffbZLnuh7IffviNVlbJE6Nmzp6
98+AtnSA8B60AziDQGvRGwtnXV5nBEmy5TwApka/lwjtl/cUMQAQKZAwYMcOVnzR998CgB+yNgFrySse6663q6byUBEgXV2sUlAWLHHhMR1H876IiBd7StG264wQCJiEl8kyZNMoYNGxbowlnJTGNChV1z99VLL73kKgsGxHUOGJDFil2/yYIKP+sRkxyLtCl9nviTAMkT7WLnVQ9tL0ub+P7779u+Wa02yut8++23NxYvXlzsPw21Lx0CJEAiVOmtt97q6mR5Nfaw97BMrHfv3uYS2TgrPqyqg0xw5pvE77FVdtA5fL2OHTvWXOqGGTdOHYKumzZtarz++utB4rV/pstgxiuvvOLCHvXCQATiIkACJC5y0dKB+FZ2EoN8CxcurCogKQHSt2/fSp4qb3XE7MSkAe81Jc959JpFnTS/ek5PAqRYta9LnyFL1F588UVX+19jjTWMDz/8MHa21QgQfJR7udyy2h/4xA4bLrjgAlsZQBp7BRIgXqjod48EiL1Ozj//fNv/29pOcA7iA32BagNXt912m9G4ceNAWSBPsgxW96GqHFdffXWWWaYqG3YJe68p3YOO7D95Q08CxBsX3g1GoKxtL2ub+MknnxjbbbddKJuFfhnclDrdhwbXDJ8SgWwRIAESAd+gD52gDovfM6yCgOuiOL6JnWpjtr8zn3HjxjmjpX6NfU+c+Ua5BgkC9wVFDboMZnitTkpjILOo9UK9kyNAAiQ5htUkYJWg1V46N9j1S5+UAAlaoYFVi0nD8ccfbyuXtYzwh82QHgIkQNLDMg9JuvQZsiyrV/tPOiBZjQBBeQ4//HBfuwMbdMQRR4QqNly0YqNOZbdWXXVV3xXLQd8FO+20U6j8GCl7BEiA2DE++OCDK/9v9T9Xx0aNGpneDOwp/K/mzJljEiYqvfPYrFkzY9myZf4CEj7x2qA9CdmaUJ3YyWEjq5G4sINJJ0zGVlDjhCRANK6cAqhWtraXh0385ptvDNXunDbf6xreDZYsWVKAfwNVrAcESICErGUsmcOMGK9GjXvY9LV///5Gnz59DGxG3qVLFwMbLwa5vVKy2rRpYyR1V4TZ/kqeOj711FMhSxc/GlhmfHT26NHDXAWCsjRp0sSli9LJ63jIIYfEV6DGKXUZzMBLxYktBiEYiEBcBEiAxEUuXLqPP/7Y9k7B+yNsSEqAHHPMMS57oezHddddF1YN33innHKKr/w0XGz5ZlyHD0iAFKvSdekzZIUaBkNXWWUVW/tv166d8fvvvyfKMgwB8sILL9jyVTZNHTG4+MEHH1TVA6unVRocjz32WN80JEB8odHqAQkQe3VgAp71P249x/s7aoBrLKsM53lW/t9ffvllV7677LJLVPW1iT958mQDe2U68bNeP/jgg9roq4siaiB2vfXW00Ul6lEwBMrS9vK0iZhoa7VN1c4xLopvXwYiUGsESICErIEZM2YENvK5c+d6SsLqjnfeecf0lxpEhuCDMckMuWuvvdalH4xgLQLKjI/goUOHuvYl8TOOtdI1KT66DGYAc+egQ7du3ZIWj+nrGAESINlVPgYD99xzz4rNxmydKLP6khIgAwcOrOTttMlJ3kMKMa8Z4CofLIVmSA8BEiDpYZmHJF36DFmUFf2Qjh072mwL3OO88cYbibMLQ4AgE/XeUvbGeRw0aFBVXXbeeWdbGYI2bicBUhVOLSKQALFXg9cMYdVW4k6es/ZplCx13GqrrRKToPYS/PvKqy9U9FWmGDPYaKONbDZI4YgjJ7e5/wkkQNyY8E50BMrQ9vKwiXB/HzSRzmqvnOewbUV3fx/9n8UUuiFAAiRkjWDWqrMRq2sslceHX7WAmWeYCafSeR3j+kodMWKES+78+fOrqZT5c/gJxMbo1fYH6dWrV+a6ZJGBToMZG2ywge0/sPfee2dRZMqsEwTUQBJWvjGki8Cll15aaauYlfz8889HysCrg4v3CTbVDBOC/H+nMXgQtAfI9ddfH0ZFxgmJAAmQkEBpEk2nPkPakPz973+v2DXYI9i2CRMmpJJNWALkscces+ng7GdjItLXX3/tq9P06dNt6bGaOyiQAAlCR59nJEDsdbHOOuvY/ufWdgLXJnHCQw895CsT8rFKJO2w5ZZbuvJ8++23084md3mjR492lUvVUfPmzXPXR/cMSYDoXkPF0a/obS9rmwiX685JIiC47733XqNFixa+dkvZLxyxUitoYklx/i3UtKgIkAAJWXPt27f3bdTYxDZs+Pzzz41NNtnEVxY+zj6OsTzs3HPPdcnUaW+NUaNGufSzGkMs+U3qIiFsHaQZT6fBjG233daGcYcOHdIsKmXVGQIkQLKpcKx2g095Zf9gu6OGpASI14pBpQ/IkaQhaI+Rhx9+OKl4prcgQALEAkYBTnXqM6QJ17fffms4J2GkSXaGJUCw0Wa1zTmHDRvmW3S4ZFW2EMfHH3/cNy4ekAAJhEebhyRA7FXht9dEw4YN7REjXMFVNAbnre3Hep6mPYBar7zyiisvuFgpQ4Adcw4yKiwxobCI38tZ1gsJkCzRrS/ZRW57WdvEt956y3Du/du6dWtj4cKF5p9k8eLFVfdiU3YMJMj7779fX38ullYbBEiAhKgKEAlBKxii+ja1dsSVIbAeu3fvHkIre5QzzjjD1RH87bff7JFqfHXllVe6dLSWG4a7aEGnwQzVAVSYcgPOov2b9NKXBEj69YFlw9bBuV133dWIY6eTEiB33XWXry3Gvk5JA95hyg45jy+++GJS8UxvQYAEiAWMApzq1GdIE66jjjrK1uYHDx6cpngjLAGCTNEnd9od6zWImp9//tmlH1ZNWweGYaurre4mAeKCUcsb1u+ukSNHaqljnkqtttpqnm1k3XXXTaSGdXWrtc3h/PTTT08k25n47LPPdpUhjQkcznxqdT1t2jRX+RSmOk1wrBU+1nzV9y/3ALGiwvO4CBS17WVpE1966SVz5YayQTiC8PZaMXjLLbcYjRo18rVfSkarVq2MpUuXxq0mpiMCsREgARICOizrUo3V6xh1xQY+qLyWqFllRyUDTj31VJuO6NzqGLBaxlpO6/mYMWN0VDlQJ50GMw488EAbthjAZiACcREgARIXOf90VjuN1X5xXTUkJUAmTpxosxVWOzxgwAD/AoR8stdee/nKh1tEhvQQIAGSHpZ5SNKpz5BWeZ0bYYIATXuGchQC5Jdffgn0oQ97hw90ZxgyZIjNbmHlcrVAAqQaQno8JwFir4dmzZrZ/uuqD4DV+EmC04WckotjnMl9Qbp4fUeXzbe832b1cfuOQXgW+RkJkCLXnp66F7HtZWUT4cIf7warPcfE8KD9ojCOufnmm9vSWNOrc5A2DEQgbwRIgIRAvG/fvr4NeJtttgkhwR0FS/BV4/c6wpdylICNHa1yGjRoECV5bnGxWbzfapo0Nt/NrSD/yUinwYyuXbva/gPbb7993nAwvxIhQAIk3cqE/2urjcaM6ZkzZ8b6+e2xgfeKl0y43bKGDz/80KaLVa/DDjvMGjXWudMdn5IP2x9nxUssJeokEQmQYlW0T
n2GNJAbN26crU/XqVMnAyvd0g5RCBDkjb2MlN3xOsJnNVxdqLBkyRKjSZMmlTTIz2uViIqvjiRAFBJ6H0mA2OvHz71S0m/H9957r9KGnO0uTfdU8B/vlI82XbbQp08fVzlR7qiTJMuGi7M8JECciPA6KQJFa3tZ2URMKPEig0477bSqEGOlGvajddpq63Xjxo0D92WrmgkjEIEYCJAAqQIaVmsE7dkRd0kv/KBbDYDzfOjQoVU0sz8+66yzXPJ+/PFHeyRNrpw+llXZzznnHE00DK+GToMZu+++u+0/APc6DEQgLgIkQOIi552u2oCcsoNZHZ977jmbYtYBRWueaewdtNZaa9lskZLftGlTmw68SI4ACZDkGOYpQac+Q9Jyv/baa8Yaa6xRaeu77bZbZu4MrPbqmGOOqao63DLgw1rZHq+jdT+ia665xhb3vPPOq5oHIpAACQVTzSORALFXQdCg1LJly+yRI1zhu9OrreEe2mNa4W9/+5srnzTcd6alX1pynKvrFLYgmhhWIEACZAUWPEsHgaK1vaxsop9bw3fffTcU0Ngbyo9MUvYMXg0YiECeCJAAqYL2nDlzXJ0s1WBxnDx5chUJ3o//+c9/Bso94IADvBP63L3wwgtd8tSmRD5Janb7nnvucekKLONsBlyzQvwnY50GM6x7CwDPPffcs9bwMP8CI0ACJN3KqzUB8uyzz9oK5HSZp95rLVq0sMWLeoHZ30qW83jkkUdGFcf4VRAgAVIFIM0e69RnSALNF198YVjLgj3HsBF6ViEqAQI9nC6tnPZIkb1w17XFFltU7BZcyKJ8YQIJkDAo1T4OCRB7HfTq1avyf3e2i7D/fbvEFVfYR8QpE9drrrnmikgJz7Daw5nHrFmzEkrVL/kDDzzgKifKjc2GGVYgQAJkBRY8SweBorW9LGwi+kbWfp6yue3bt48EMlb+77fffp62DDIxaY6BCOSJAAmQKmgHbdyNDX7CLJH3ymL27Nm+hgDGIOrg9RVXXOGS984773hlXfN7fmW//PLLa65bVAW8Xgyov1qw2RtvvLHtP7D//vtHLQ7jE4EKAiRAKlCkcqIbAeJFmsN2NWzY0OYaJmrh33//fZsdgkz1u/XWW6OKY/wqCJAAqQKQZo916jPEhea7774zrC50MPli0aJFccWFSheHAIGrv1VWWaVif5Qdsh6xMu6RRx6xxenfv38onRCJBEhoqGoakQSIHf7rrrvO9p+3tokZM2bYI0e8at26tafsli1bRpTkHR0rz6z64hx2FR4byhbg0tRZVvTRGOwIkACx48Gr5AgUqe1lZRNBKjvtD64vuOCCyABjs3O4ZveSh3tJiffICjFBXSNAAqRK9e+zzz6+jTXqKg1rVkGb0MIQYLPwKGHkyJEuPXWdDYOPZy8D+NBDD0UpshZxdRrMACFnxRUf5gxEIC4CJEDiIuedTjcCJOgdlGSDTewJYLVD1nOQIwzpIkACJF08s5amU58hTll/+OEHQw02oW1j5cSnn34aR1SkNHEIEGSAvZasNsh5jo2ZsW+J9T72qgsbSICERaq28UiA2PH3m4iGdhB1D0q7ZMOAKzxre1LnuJ9GgLtkJVMd47qjTkOfLGXgu1iVUR2x0TGDHQH1TlpvvfXsD3hFBGIiUKS2l5VN9HPXH3cy23333eeyZ8quvfDCCzFrismIQHQEVkIS+edj8EBA+jIV8qNLyKVbHk+FkDNoxBlnnOH5rNpNOWghpL9S32hyTw8hNwX3fe58IAezRI8ePWy3p06dKuQqANs9HS7kB7SQy91cqsyfP19svfXWrvs639h8882F/Ph3qShXgAi5esh1P6sb0mevkASITbz8IBA33HCD7R4viEBYBOQqNCE7JGZblTM3wiZjPB8EJKkgpJs/TDrwiRH+NmylXOHnStCuXTvRrFkz133cuOyyy8SOO+5YeYY6lftbiV9//bVyT53ITqqQLjLUZaSjnBkkRowY4UojB37FggULXPd5IxkCcu8Cgf4CwtNPPy2kb/dkApk6UwR06TPEKST6GQcffLD5P0N6uaePkLMkhRyQiyMuUpoNNthASBdbZhq5B4i4//77Q6V/+eWXhdwfLVRcROratat48sknQ8c//PDDxfjx4z3jS7dgQpIpns94M18EpLti0a1bNzNTOWFMDBo0KF8FNMtN+mUXa6+9tpAb3Lo0i/r96RQgSVHxr3/9y3nb/EadMGGC637UG9tss4344IMPbMnkSi7RuXNn270yXMh9CMRFF11kK8qJJ54oJEllu1fvF3vssYeQkz6FJEAq74l6x4TlT4ZAkdpeVjYRY0he45yPPvqokJM/IgO8fPlyse2224qPPvrIlXbs2LHiuOOOc93nDSKQCQLROZP6SRE0Q1ZWhjFv3rzYYARtQAfZ//jHPyLJfvPNN12sqq4rKrz2P4HP2CIuX9ZlNueXX37pqn8scWcgAnER4AqQuMhlnw7uAvGecP7gszZK8JsdncSFH2ZUO/XC9fHHHx9FNcYNiQBXgIQESpNouvQZosKBjSzlJJtK295www2NpG5W4Rc67GoLzOxVdiXqCukuXbpU0ioZfscnnngiEjRcARIJrppF5goQN/R77bWXZ7uI6t/dKRl76Hi1L7iUThpef/11l2y4/5UDa0lFa5lekr2u8j744INa6lpLpbgCpJbolzPvorS9LG2in6vkqGOU1n/IgAEDXDYN74tRo0ZZo/GcCGSKAF1gBcA7ePBgz0aKhipnvgWkDH6EwepVV13VV/Yaa6xhfP/998FCHE/hlsDZ4Yy7RM0hOvXL0aNHu3Q97bTTUs8nD4G6DGZ4+d1/7LHH8oCAeZQUARIg+lZsWgTIpEmTXLYY75GOHTvGKjw2zJMztT1lgvhmSB8BEiDpY5qlRF36DFHKiMHF3r17V9o1JqzMmTMnigjPuH/7299MmXDhWi1gA2XVx43q3tPPzil56ihXyEWeiEMCpFrN6fGcBIi7Hu64445Km1JtAMeVVlrJ+Oqrr9wJQtyRq7Q8ZULuSy+9FEJCcBS5itYl/9RTTw1OVOCnchWZrbwgl4Axgx0BEiB2PHiVHIGitL0sbaKXe33YcuQZN3jtWQyZ2IuNgQjkhQAJkACk5ZIyW8cDDVT95DKtgJTBj+TS64ocJc967NOnT7AAn6fOgScQODqGIUOGuMqfxsd0Lcqqy2DG9OnTXZi+8cYbtYCEeZYEARIg+lZkWgQIZnVvtNFGLtuBAZA4G9I988wzLll4t0l3avqCWXDNSIAUqwJ16TNEQc06Y69JkyZGGvvLYfPxlVde2RxslS79qqpjnVUu3XBVjW+NgNXFO+ywg6dtsva94+x9cOihh/rKxQAKgx4IkABx18NPP/1kSDdYnv/fu+66y50gxB0vbwRoY9jDBxMkkgav73Lpsi6pWC3Te23CXNa9TpJWAAmQpAgyvRWBIrW9LG2i9b1p7Sv17NnTClekc3gnscpS59wDJBKMjJwQARIgPgB++OGHng1UNVTMnIkTsFStQYMGvrKx+kP6No0j2rXxnPSJGUtOlomw
+sXZ4ZZ+eWNliSV4cIWAwTXpo9CoxQxjXQYzvAZEo64iilUJTFRaBEiA6Fu1Xu0d76aoLrBQwqFDh3q+j2655ZbIAHiR29CLq9EiQxk6AQmQ0FBpETGvPoPc28fAB+Xdd99tzJgxwwDZGSfAHZ7q966++urGs88+G0eMLQ1cZ2266aam3H333df2zOtC7sdX0QG6xHHRc+edd9pkqDKpI4hgYBY1HHTQQb5y5Z52UcUxfkYIWAdywqw4ykgN7cRi9b1qA9bjrrvuGktXv9m9cm+wWPKsiTBRzqojzuEaL65ts8p2nqdlP51yo1zjG95aXowPLFq0KIqIuolLAqRuqjqXgqbV9rK2I1nbRC/vIrBJIF3iBqzYs9o1nGOCC8es4iLKdHEQIAHig5rc+MfVQK0N9t133/VJ6X/7k08+MV1nWeU4z2+++WZ/AVWeODuy6CylMeOmSraRHsvN22y4Yjah3CwvkgxEPvnkk21ygCPcimHfljyDk8xR9ZnEh34c/eVGnDY8WrVqFUcM0xCBCgIkQCpQaHeSJgECUhoubZTtUseog4wYpMS+ACq9OsoN7yK7ldEOcI0VIgGiceV4qJZHn+HVV181MPiu2iCOaIf4mI0SLr300oqMhg0bGlH3x3DmhYFK+K+Hz36lW5i96pwf4UgfNWAgQpEuKm/rcdiwYVFFmvExAccqx3q+zjrrxJLJROkjQALEG9P33nvPHHyy/m/VOVaWRwlYaeW0O5AF4jSuSy1r/uedd56rrfXr188aJZXzJPZTbipvTJs2zZz0EWcVrSoA3MGoelDHq6++Wj3m0YEACRAHIHV4qVvbS2JHwlZfHjZxt912c9ki2KSHH344rJq2eE5yCbKiruq1CeQFEYiBAAkQH9B23313zwaPhrrKKqtEHtDBDNgtttjCVybkoiOXZCNwzP6FHOsPy5GTBrg1OfbYY42WLVsaXbt2NYYPH258/fXXkcW+/PLLpssDq35xVtLgpWKVYT0H6bN06dLIusVJgA69NW/rOVak5BmaN29u0+WUU07JM3vmVUIESIDoW6lpEiAo5W233WazH8qWRfHJOmLECJcMuNKaPHmyvkCWQDMSIMWpxDz6DJhF57fKBCQI9osLE2666SZXe957772NOD9stgx3UE6iFasusBF6tXDPPfe4dMEq7ajBz242atQo9gCts++lbCeOsH9ZzE6PWm7GNwwSIP7/AivRaf3/os1iUDFs8HOBef7554cVERgP9suqH84ff/zxwDRRHyaxn9dff71J9lh13H777Y1x48ZFUgPEiZNISuJ2O1LmBY1MAqSgFZeS2rq1vSR2JAokedjEMWPGuOwubBzGBKOumn3ttdc8veA8+uijUYrNuEQgMQIkQDwgnDdvnmdjt3Zqwq5awGA9Pv6sab3O4dczCfmBYnz++eeufOAmKknArF6vGYtYuYElzWGXrMFtgnP23ZVXXhlLNayS8cJQ3Rs1alQsuVETPf/88756YEVGXmHhwoUuPTBgwEAEkiBAAiQJetmm9RvIi+MCC5ri3eM1K2e77bYz8A6oFj799FPP90Qabi+q5V3vz0mAFOcfkEefYezYsa7+gOob4ThhwoSqgKHfiMF7a7oszs8+++yquiACZgc688dEnKhh8eLFBvquTllYURwn+A34WuXT/V8cZNNPQwLEH1OQdHB5Zf3fqvOwe0lCxgEHHOCSgX5FGiQg9jRUOqnjmmuuGXkAzh+Ffz+Jaz+XL19utGjRwqWj0hWTB8N4jsCYgZNU7dy5c+rlrIZD0Z6TAClajaWnr45tL64diYJKXjYRJMcuu+ziadsuueSS0CrjPeAlB5NqGIhA3giQAHEgDkMatKRddWbatWtnwI+xV8BHFjrbmLFR7SMSM+Iw+zatsNVWW9mMVNIN07xWlSgMcASpgcF2P/IG9y+77DJz1YxKB1cKGLSJG8D0K1lex/322y+u6NDp8D9Bh9Yrf9zDDMyw5FDoTH0i4gPbqUdYgs5HJG8TAYMEiL5/grQJEJQU7zMvF1bwcR80gIH3XevWrV02CPYRdpIhWwRIgGSLb1rS8+ozWPfscPYLcH3uuecGFgkz8bDK2SttmvfQN6623x3sDogOr3xBZERZoaYK7dzzCHq8/fbb6nHoI2bGew34OnXFtwI2m2aoLQIkQILxRxuwuqaz/o/RZoLe5cuWLTMOO+wwVzvF9+HHH38cnHHIp1hFYtUJ58ccc0zI1OGjxbWfcHHt1M95jW9f2N+ff/7ZUyF8b8NdmDXdCSecYABfhmAESIAE41Pmpzq2vbh2JEo95WUToRNW3DpX8Co7NXDgwKorBdEHOumkk2y2DenxzqnWD4yCCeMSgbAIkAD5D1JYhg9/px07dnQ1UNXIvY4777yzOTvtkEMOMYkTLEerRnpADj4w4fIK/tfTDMcff7xNfwxiJglTpkyxyfPCAPcwUxhMMAbj586da2CmI/yVOtle4IVNm5IEuNLy0wP3k2zOFEavjz76yNOQO3XCrMWkZQ2jj9MHJGYhMRCBpAiQAEmKYHbpsyBAoC0GQZwr9WDX9t9/f8/Zi9hk2WnjEb9t27ax3CRmh1h5JZMA0b9u8+wzYGWtsy9ivb7qqqt8AXvyyScNDNJZ42d1DoLUK2CgFRM4MBgIO1It/759+5pu9jDQGjRIq/LCYEmDBg0qckHwhg3YUw/7kWDSkp+bMS99N9lkEwP7CsL9Q14uWsOWqV7ikQCpXtPz58/3XcWwzz77GM8995xNCCa4wc0y2pDzf49vwrTID2QKec484vqgtxXCcRHXfsL2YEN2p45e13CHDZeh+E9iXyW8w522DjbqxhtvdGjHSz8ESID4IVP++zq2vbh2JEpt5WUTlU7w5LLBBht42jiM7eH94HSZiHcE+pVbbrmlK936669vYBULAxGoBQIkQCTqmEXinHXh1WlJ4x5WaMDfKtxVZRHuvfdem5Fp3LhxqI9CP13A2jpXlcTBAf5MsXIjjZkseNl169bNVk6rTihzmgFLknv16mXmiZnOYQguqz74UMaqIujcv3//NFUzZTlXogwaNCj1PCiw/hAgAaJvnWdFgKDEmOnj5bZx1VVXNbp3726cddZZBvY4wrJlq51T53ifcsZzfv8dEiD5YR02p1r2GTDI7tdHWXnllX0nZWAzZOyhptpx1kevwctrrrnGRk5E1QH9+KDVaqr+evfuXSnnU089pW4HHjHABvyi6uQV/8gjjwzMiw/TR4AESDhM8W3ao0cP3/85Jvkdeuih5sQ/vwH/nj17Gt988024DEPEwqQ6ZzuCrQrjnjOEeFuUuPYTQqzvYqe+Ya8xORKkLsgohvAIkAAJj1UZY+rW9pLYkTD1k6dNtOoDd8dBE8VB3LZp08Y44ogjTLfKXi5HV1ttNePMM880vv32W6tonhOBXBEgASLhbt++vatzFbaz4hUPA/Dw4YkN5Lp06WLADRXcROXRocHS2nXWWcdWnpkzZyb6U2EWz4EHHuj7Ue2FAe7BPys60vjw8HORFVcxsMwwoK1atTI/SsEuY1AO+WK
ZHmbqpRVefPFFG55+5Q1zH5ttphlQTmd9Y4UMAxFIigAJkKQIZpfejwB55plnUssU+wBgNVkYu4Y4WAkCW8+QLwLWD7+nn34638yZmycCte4zXHjhha52C1IE7kj9AmbphW3rSePB7YHX5ue33357Ih0wOzFMXxODEygDPtTDBmxmnLTcKj1WBzPkiwAJkGh4T5061YD7NvWfDXPEfh/OVSLRcvWO7WXP/vSnP3lHTuGuV37V7KfKdvTo0Z6rVarhh4lyp556qrnCTMniMTwCJEDCY1XWmLq1vSR2pFodecnO0iZa9cG4E7y9YOVflEkhsKGYfJLmykCrXjwnAlEQWAmR5YuZoUQInHbaaeKWW26plAjXcuPwynXcEzkzWIwZM0ZIFwBCbrxt/iSDKyTDK6TvePO30UYbCdlpFnLTNiGX9Ao5myVudpHTSUZcyA9aM085C1BIYxtZRtESSBdlQr6EKmrLFSpCLkmvXPOECMRFQK5aEtLFkVhrrbWEdNsRVwzTFRgBdA9mz54txo0bZ9oVOTtUfPHFF0J2ek17L91lCfxPYIM6dOhQFzZXt+qUs+aFXJVjqiUJECFX5uimIvWpAQKPP/64kKssxKJFi4SctS3kyizRqVOnGmjCLIlA7RGYPHmykKuwTUVGjhwp5Erp2itVAA3kYJUYP368kG6NzW8+vP+XLFki5EQzgfd/06ZNhZzoJ+SqELH55ptnUiJ809x999022chPTtKx3UvzIon9RL9JriwTkyZNEugz4YdvZuC29tprV76V8b0sJ2AKuYpfSII1TfXrTpYk38SsWbOEXJUkMC7BUJ8I6Nb2ktiRoBqshU300ke6FDW/DzE+uGDBAiHdl5r2Tq4EMdsi2iNsG2w13hPSG4yXGN4jArkjQAIkd8izz1DuO2GSDyonOdvONEh5khEq7zyPcoaXkC5aBDqVcm+VPLOuWV59+vQR0u1ZJf9rr71WyA0LK9c8IQJxESABEhc5piMC+SFAAiQ/rJkTESACxUSABEgx641aE4EwCJAACYMS4xABIkAEiAAQIAFS0v8BZpVIP9SV0skN3oXcxK5yXcaTUaNGmbO6pOsxgdUgZQ/S/62Qm2sK6W/fLKr00W8SXViNw0AEkiJAAiQpgkxPBLJHgARI9hgzByJABIqNAAmQYtcftScCQQiQAAlCh8+IABEgAkTAigAJECsaJTq/7bbbxIABAyolgvuD+++/v3JdxhOU8cEHHxT9+vUTd911VxmLaCuT9NEv5KbqlXty40FzKWLlBk+IQAIESIAkAI9JiUBOCJAAyQloZkMEiEBhESABUtiqo+JEoCoCJECqQsQIRIAIEAEi8B8ESICU9K+A1QHw+wx/rQhwf4W9O+Rm4aUssdz83XR9hdUQd9xxhzjhhBNKWU5VKPi5xEqXefPmmbfgb/GNN94QO+ywg4rCIxFIhAAJkETwMTERyAUBEiC5wMxMiAARKDACJEAKXHlUnQhUQYAESBWA+JgIEAEiQAQqCJAAqUBRvhPsDYE9IlQYOHCgbXN0db8MR2z2edRRRwkQAdiUqVmzZmUolm8ZHnnkEXHEEUdUng8ePFjcdNNNlWueEIGkCJAASYog0xOB7BEgAZI9xsyBCBCBYiNAAqTY9UftiUAQAiRAgtDhMyJABIgAEbAiQALEikYJzzt37ixmzpxplqxRo0YmOYBN0csUsBqiY8eOYtasWaJv374CrqHKHFDetm3bmis+UM711ltPzJ8/3zyWudwsW74IkADJF2/mRgTiIEACJA5qTEMEiEA9IUACpJ5qm2WtNwRIgNRbjbO8RIAIEIH4CJAAiY9dIVLCLVK7du3E8uXLTX1POeUUgf1ByhSwt0nv3r3FSiutJN56663Su4Fyrv7Ayg+sAGEgAmkiQAIkTTQpiwhkgwAJkGxwpVQiQATKgwAJkPLUJUtCBJwIkABxIsJrIkAEiAAR8EOABIgfMiW6P2TIEHHzzTebJQJJMGPGDIHBzTKEH374QbRu3VosWLBA9OjRQ0yYMKEMxfItw9KlS02CZ+HChWYc7PkBkguuvxiIQJoIkABJE03KIgLZIEACJBtcKZUIEIHyIEACpDx1yZIQAScCJECciPCaCBABIkAE/BAgAeKHTInuL1myRLRq1UosWrTILBXO58yZIxo2bFjoUv7xxx/i0EMPFZMmTTI3eYcLrPbt2xe6TNWUP+mkk8xN3hFv5ZVXFtOnTxddunSplozPiUBkBEiARIaMCYhA7giQAMkdcmZIBIhAwRAgAVKwCqO6RCACAiRAIoDFqESACBCBOkeABEid/AFeeOEFsc8++4jffvvNLPFFF10khg0bVujS//WvfxVXX321WYYLLrhAXHzxxYUuTzXln3rqKdG1a9dKtAsvvFAMHz68cs0TIpAmAiRA0kSTsohANgiQAMkGV0olAkSgPAiQAClPXbIkRMCJAAkQJyK8JgJEgAgQAT8ESID4IVPC+2PGjBEnnHCCWTKsHoC7qO7duxeypFOmTBEHHXSQqfshhxwixo8fb64CKWRhQigNl1cdOnQQn332mRl7v/32E1OnTjVXgYRIzihEIDICJEAiQ8YERCB3BEiA5A45MyQCRKBgCJAAKViFUV0iEAEBEiARwGJUIkAEiECdI0ACpM7+AGeeeaa47rrrzFI3adJEzJw5U+yyyy6FQ0Gthjj88MPFPffcI1ZfffXClSGswj/99JPo3LmzeO2118wkO+64o8CKnrXXXjusCMYjApERIAESGTImIAK5I0ACJHfImSERIAIFQ4AESMEqjOoSgQgIkACJABajEgEiQATqHAESIHX2B1i+fLnAioknnnjCLHnz5s3F7NmzRdOmTQuHBDZAX3PNNQundxSFsc/JYYcdJiZOnGgm22yzzUzyA0cGIpAlAiRAskSXsolAOgiQAEkHR0ohAkSgvAiQAClv3bJkRIAECP8DRIAIEAEiEBYBEiBhkSpRvF9//VVg5YQiQdq0aSOeeeYZse6665aolOUoyuDBg8XIkSPNwrRo0cKsp5YtW5ajcCyF1giQANG6eqgcETARIAHCPwIRIAJEIBgBEiDB+PApESgyAiRAilx71J0IEAEikC8CJEDyxVub3LAZ+sknnyzGjh1r6gTXWBhIYdAHgfvvv1/07t3bVKht27bmKpBmzZrpoyA1KTUCJEBKXb0sXEkQIAFSkopkMYgAEcgMARIgmUFLwUSg5giQAKl5FVABIkAEiEBhECABUpiqykbRW2+9VVx22WXm5ujDhg3LJhNKjYXA008/LXr16iWOPvpoccUVV4g11lgjlhwmIgJxECABEgc1piEC+SJAAiRfvJkbESACxUOABEjx6owaE4GwCJAACYsU4xEBIkAEiAAJEP4HiAARIAJEwIUACRAXJLxBBLRDgASIdlVChYgAEdAMARIgmlUI1SECKSJAAiRFMCmKCBABIlByBEiAlLyCWTwiQASIQBwESIDEQY1piEC+CJAAyRdv5kYEiEDxECABUrw6o8ZEICwCJEDCIsV4RIAIEAEiQAKE/wEiQASIAB
L7+YMM3rB0GSdVeS8yICJJ3aR1uWifcqVaowzGSOIwwbNsxkN2CL33jjjUDZwZUMBjfl/zJxPnr06EAyKZEzAkSAOOOj2tOSRoDcc889JntwzTXX+K6WoATIggULTPkLm2R1bNWqlS/diADxBVdqkYkA0UMPt5Vly5a1bBt+vz/1khmrVq2apdy2bdsaowa6PuaYY0zysZIrjyEq+5lHbOQyEQEio0HnUSCQpbaXFZuYxHhYFHVPMvKPABEgDnXstkpj2bJlDqmdHzltQIePshdeeMFZgOHpJ598YuoQqrqiwmr/E/iMzePyZVRTEuX97rvvTPUfhqAzvF50WQIRIAIk+UrHhuIYJBADc5UqVWJhXRlioMNuRrPV3lFB/n9kpK6//vqi/qIcOI4YMUKORucRIUAESERAJiSmpBEgHTt2NNmDV155xTfamNkr7InfFdLNmzcvphUy7I5Tp071pRsRIL7gSi0yESBm6Js1a2bZLvzuAWaUjD10rNoXXEqHDYsWLTLJhvvfXbt2hRWtZPqo7KeShYtQKSJAIgSTRBUQyErby5JNTGI8jF5fQsALAkSAOKDUs2dPU0dLdOowKzdowGB16dKlbWWXK1eOgSX1E7BaRegmjkHdmPjJN0jckSNHmnS94YYbgojKRJokyrtixQoTpm+99VYm8CEl1USACJBk6wUf8J07dy62Y5DCixcvDq1Ev379CjLhJtEYrFwn4v/jjjvuMEb1fG21HxVkws8+hegRIAIkekzjlFjSCJATTjihaNNgBzA4unnzZt8QYwNl0bf1695z8uTJxbRChtXx+OOP9z0RhwgQ31WZSgIiQMywP/vss5btYo899mDff/+9OYGHO2jbVm0L97AaK2xA38Qo/7rrrgsrVtn0UdlPZQsYkWJEgEQEJIkpIpCVtpclm5jEeFixAumEEHBAgAgQB3COPfZYU0dLdLwuv/xyh5TOj3r06GErF/K7dOniLMDmacWKFXVyQeCoGHr16qXTE2WOYqBPxbJCpyTKO2vWLBOmS5YsURUS0isDCBABkmwlySsnypcvz6LYwwluIfbcc0+GAQ0rF4PyoJD4b8Oxffv2gQtvt7kq7QESGFLHhESAOMKj3MOSRIBYbSIa1DWnPKv8nHPO8VWvWF1cp04dUx9Jtnk4D7L3wQUXXGArFwMoFNRAQP6vs5oMoIaWyWqxY8cOtv/++1u+v88//3wgZay8EaBtwYUdXGSGDVbf5dOnTw8rVsn0UdpPJQsYoVJEgEQIJoliWWp7WbKJSYyH0etLCHhBgAgQG5Tgc934cSRfY+ZMkIClaqVKlbKVjdUfq1atCiLatPHcKaecEkhOnImw+sXY4T733HMDZQk3YXCFcOqpp7LevXsXXE0FEhRjoijL66TmfffdZ3qn/K4icpJPz0oeAkSAJFfnffv2LbZf+OWePXt26MzhOuuwww4ryD3zzDMt5VmtHMP/HDrUQQNmY8r/lTjH4CXZo6CIOqcjAsQZH9WeJkWAYC8hkI5jxoxhc+bMYXCvl3RAH1S2Bejfbty40bca27dv18kJ4qLnueee08mQ9cL5IYccwoCZ33D22WfbyoVfbgpqIEAEiHU9YPW9sS3g+sQTT7RO4HLXbgXogAEDXFK6P8ZEOaOucI0Xh23Lk/10Rzb7MYgAyX4dqlSCqPoucduRJG1i2PpJajwsrJ6UvmQgQASITT0/8sgjpo6W3PH6/PPPbVLa316zZg2D6yxZjvH88ccftxfg8sTYkcXHZhQzblyy9fUYm1/KZcZM56+//tqXDET+61//qpMDmXArhn1bVApRldetTO3atdPhUbt2bbck9JwQcESACBBHeCJ7+Pe//73Ydvfee2/m1we9UREMBsDHPvxiC1vrtB9U48aNi/FEfBxfe+01o2hP18YPB8jyO2PbU0YUqYAAESDZehGME0BEmwMJGlX4+OOPmXFTzJo1azIQnm5h586dbMaMGQwuNDds2OAW3fY5XN6JsonjAw88YBvf6YGRqIVt8xswECEIYaGPfBw4cKBfkYX4mIAjy5HPDzjggEAyKVH0CBABYo3pF198UZigIL+34hwry/0ErLQy2h3IwqSOoC615Pz79+9vamtXXnmlHCWS87zZz0hAUVwIESCKV1AC6qnWdwljR7zClZRN9KqPU7ykxsOcdKBnhIBAgAgQgYThePLJJ5s6WqJTuNdee/n2E4yPyaOOOspWJmSjIxdmI/Bx48aZ5GM5ctjw7rvvsssuu4xVr16dtWrVit1zzz3shx9+8C32ww8/LLhjETjiGGQlDf5UZBnyOUifbdu2+dZNTqBaeWXd7M6rVq2qw+Taa6+1i0r3CQFPCBAB4gmmUJEee+wxXbuFLWvRokWgHzY0hcsV7B0i20TMbMZG6HZh1KhRuvgiLey93xnRCxcutFzh+MYbb9hlT/dDIkAESEgAE0yOQUDRvoxHrGKNImClld0qE5Ag2C/OLgwbNqwwWCnrVqtWLTZ+/Hi7JJb3QZwYB0LDuI0dO3asCTes0vYbrFbKoqxlypQJPEBr7HvJ2MH1YByz0/2Wm+IzRgSI/VsgT8KQ31/0JzCo6DXg20lOL87vvPNOryIc48F+CZniOHHiRMc0fh/m0X76xSCL8YkAyWKtRaezan2XMHbEDypx28Qsjof5wY/illwEiACxqPtly5aZOlmisyWOXlctYLAeA1Mind0RfpHDkB8oxrp160z5wE1UmADXA1YzFrFyA0uavbo1gUsX4+y7oUOHBlINq2TscMT9ESNGBJKLRCqW160w69evN+GBAQMKhEAYBIgACYOee1rYZgyQOdmyKJ7ddtttjsqA5GjQoIGlHoMGDXJMKz/EQJ+VHBA6FOJDgAiQ+LCNWvK//vUvy3aGdo5VnFGE0aNH2+aBfCZMmGCZza5du1i1atVs02Lyi5eVz+jzGkmB008/3TeZKiuJFWRGW4iJOH7Dli1bGPquRllYURwk2A34yvIx+YlC+ggQAWJfB/jvhssr+b0V5173koSM1q1bm2RgRWgUJCD2NBQ6iWOFChVC2RUrRPJoP63Kmbd7RIDkrUa9l0fFvktQO+K91IzFbROzOB7mBz+KW7IRIALEUP8wpE5L2kXH66STTmLwsW4V8JGFzjZmvLkNcGG27lNPPWUlJtC9GjVq6DqJQTecFJlbrSoRGOAIUgOD7XbkDe7fe++9DKtmRDq4ecGgTdAApl/Isjq2bNkyqGimYnndCoMPbCMOXgk6N9n0vOQiQARIfHWPFRGyTTS236iu8f/jZU8pzKY2rhwROnTv3t11Fig2U+3WrZvJDsFVjZf840M6/5KJAMlGHaNvCRJBtCvjEas2vE4ocSqxvJ+QMQ9c33HHHZbJ4aLVKr58D303pP/5558tZaD/BHc3cpqrr76a/frrr5bx3W5i4BREhyxPnIPIgJstv6FPnz46ebCRy5cv9yumYBOtBnyFfuKIbwXYRwrpIkAEiDP+aAOy20zx/uKINgP7ZRfQvi+88EJdu0I6fB9+9dVXdsl83ccqElknnHfs2NGXDC+R82Q
/vZQ3L3GIAMlLTfovh4p9l6B2xE/p47aJWRwP84MfxS3ZCBAB8r/6h4sQ+Dtt2rSpqZNl7HTJ1/Xq1Sv4Nz///PMLxAmWo7mRHkiPwS+4vMKmQFGGq666Sqc/BjHDhGnTpunkyWWXz4877jiG2cIYjF+6dCnDTEf4ezbOCAZe2LQpTIArLTlv43mYDXxVLK8bVkYfkJjFSYEQCIsAESBhEbROP336dIaBRKPdiuMaA65eA1bpVaxY0VIv2O333nvPRISA4EZ5jj76aFO6gw8+uDBDyWv+FC8YAkSABMMtyVRffvmlJUFobPNY6RC2f4SVtUa58vX9999vWXQMcGJDYTmu3TncuQ4ePLgw0Qd7FuEdbNiwoS5tqVKl2KOPPmqZl9NN6IEJHPj4Nsq00ueKK65gU6ZMKQy0Og3SijwxWALdhCxsYu41YE897EeCSUt2bsaEXPlYuXJlhn0F4SIwrItWr7pSPD0CRIDo8bC6Wrlype0qsDPOOKPQB5DT4f8fbpbRhuT3Hef4JvwqIvIDeUKeMY+g+5TJZTCeZ91+GstTUq6JACkpNW0upyp9F1mzoHZEluF2HrdNzOJ4mBtm9JwQEAgQAcKRwCwS46w1Y0crqmus0IC/VbiriiO8+OKLuk7ivvvu6zhzx00HzFwzrioJggX8QWPlRtCZgLKe+LM799xzdeWUdUKZgwYVy+tWFuOs0h49ergloeeEgCsCRIC4QuQ7AjYcxT5Fsr2K89zvAME333zjOAkAA4f169dnF198MYNrCyt3Mvvssw+76aab2ObNm33jQwn8I0AEiH/M4k4BN1CdOnUq9FPq1q3raVKMbAcwuI6VyOjndO3a1Ze6GGS3m4Sz5557OhIs8rsk6+PnHJN7QEpgMNVvePDBB3XkhJ98ERf9eC/udjp37ly0wTNnzvSkJgbYgJ9fnaziX3LJJZ7ypEjRIUAEiDcs8W3atm1b2/cck/wuuOCCwsQ/O8K0ffv2bNOmTd4y9BALk+qM7Qj9KLhoiTpk2X5GjUWW5BEBkqXail7XtPsuxhKFsSNGWVbXSdjELI6HWWFF9wgBKwSIAOGoNGrUyNS5Mna2/FxjAB4+kLGBXPPmzRncUMFNVJAPQqtKc7oH1wQHHHCArjxz5851SuL6DLN42rRpY/tRbYcN/LOiI40PDzsXWa6Z20TAxnwYZKtdu3bhoxQzkLGJKHSBKxfM1AsaVCyvXVlQTmN9Y4UMBUIgLAJEgIRF0JweqyXs7GXU9+HOwmnzc7N2/70Dm4KVfJjV6WfAD4OuGFiMctannY50fzcC8offO++8s/sBnaWGwL///e/I2jk26PYb7rrrLlP+aJ9wR+oWRo4caTnb2s0+gbS57rrrCisk3PKwe/7000+b9HbLV36OFWxe+poYnEA6kLleAzaDl/MKc46VPhSSRYAIEH94v/322wzu2/y855gUgZWiUQcre3bRRRdFnU1RnlV+WbCfxQKUwBMiQEpgpRuKnGbfxaBK4TKMHbGSJ9+zkh2HTczSeJiMD50TAm4I7IEIvINDIUcI3HDDDdoTTzxRLBGu+cbhxeugJ9xPvDZq1CiNuwDQ+MbbhR+f5avxWcBapUqVCr9DDjlE451mjW96qXH3BRqfDRg0O9/pOCOu8Q/aQp58FqDGO6y+ZcgJVC8vdOVLFDU+UFlUm8821fiS9OI1nRACQRHgM5C1efPmafvtt5/G3XYEFUPpMowAdxejjR8/XoMtXLt2rcZd02h8hqjGV4JofPZn4ccHBjVOlmmc7Nf4Sr8MlzabqvNZ89ott9xSUJ4TIBrfdD6bBSGtI0Vg4sSJGl8Bpm3cuFHjs7Y1vtJZO+200zzlgc8CvjJCmzx5cqG9o82jz7d161Zt//33L/b10N/jE4g0vgpVgx2gQAioigB3k6bxFVUF9YYPH67xldKqqqqUXnwATHvzzTc17ta4YAM2bNhQsAN8opnG9/jQqlSpUvjv56tCtCOPPDIW3fFNM2bMGJ1s5Id+R1yB7GdcyMYjl5Nv2vz58wt9UoxLUCiZCKjWdwljR5xqMGmbmIXxMCe86BkhYESACBAjIjm45j6kC+SDKAqfCVz4iE2SjBB5J3nkM7y08847T8NHOd9bJcmsU8urS5cuGnd7Vsz/oYce0viGhcVrOiEEgiJABEhQ5CgdIZAcAkSAJIc15UQIEALZRIAIkGzWG2lNCHhBgAgQLyhRHEKAECAECAEgQARITt8DzMrjfqiLpeMbvGt8E7vidR5PRowYUZjVxV2PaVgNkvfA/d9qfHNNjftpLBS1dOnSBaILq3EoEAJhESACJCyClJ4QiB8BIkDix5hyIAQIgWwjQARItuuPtCcEnBAgAsQJHXpGCBAChAAhICNABIiMRo7On3rqKe36668vlgjuD15++eXidR5PUMZXXnlFu/LKK7Xnn38+j0XUlemFF17Q+AapxXt848GCu5riDTohBEIgQARICPAoKSGQEAJEgCQENGVDCBACmUWACJDMVh0pTgi4IkAEiCtEFIEQIAQIAULgfwgQAZLTVwGrA+D3Gf5aEeD+Cnt38M3Cc1livvl7wfUVVkM8++yz2tVXX53LcopCwc8lVrosW7ascAs++ZcsWaLVqVNHRKEjIRAKASJAQsFHiQmBRBAgAiQRmCkTQoAQyDACRIBkuPJIdULABQEiQFwAoseEACFACBACRQSIAClCkb8T7A2BPSJE6N69u25zdHE/D0ds9tmhQ4fC5rzYuPfwww/PQ7Fsy/D6669rF198cfF5z549tccee6x4TSeEQFgEiAAJiyClJwTiR4AIkPgxphwIAUIg2wgQAZLt+iPtCQEnBIgAcUKHnhEChAAhQAjICBABIqORw/PTTz9dmzt3bqFkZcqU0UAOYFP0PAWshmjatKk2f/587YorrtDgGirPAeVt2LBhYcUHynnQQQdpK1euLBzzXG4qW7IIEAGSLN6UGyEQBAEiQIKgRmkIAUKgJCFABEhJqm0qa0lDgAiQklbjVF5CgBAgBIIjQARIcOwykRJukU466SRt165dBX2vvfZaDfuD5Clgb5POnTtre+yxh/bpp5/m3g2UcfUHVn5gBQgFQiBKBIgAiRJNkkUIxIMAESDx4EpSCQFCID8IEAGSn7qkkhACRgSIADEiQteEACFACBACdggQAWKHTI7u9+rVS3v88ccLJQJJMGfOHA2Dm3kIP/30k1a3bl1t7dq1Wtu2bbUJEybkoVi2Zdi2bVuB4Fm/fn0hDvb8AMmFPUAoEAJRIkAESJRokixCIB4EiACJB1eSSggQAvlBgAiQ/NQllYQQMCJABIgREbomBAgBQoAQsEOACBA7ZHJ0f+vWrVrt2rW1jRs3FkqF88WLF2t77713pkv5559/ahdccIE2efLkwibvcIHVqFGjTJfJTflu3boVNnlHvD333FObNWuW1rx5c7dk9JwQ8I0AESC+IaMEhEDiCBABkjjklCEhQAhkDAEiQDJWYaQuIeADASJAfIBFUQkBQoAQKOEIEAFSQl6AefPmaWeccYb222+/FUp89913aw
MHDsx06fv27as98MADhTIMGDBA+9vf/pbp8rgpP3PmTK1Vq1bFaHfddZd2zz33FK/phBCIEgEiQKJEk2QRAvEgQARIPLiSVEKAEMgPAkSA5KcuqSSEgBEBIkCMiNA1IUAIEAKEgB0CRIDYIZPD+6NGjdKuvvrqQsmwegDuos4777xMlnTatGna2WefXdD9/PPP1958883CKpBMFsaD0nB51aRJE+3bb78txG7ZsqX29ttvF1aBeEhOUQgB3wgQAeIbMkpACCSOABEgiUNOGRIChEDGECACJGMVRuoSAj4QIALEB1gUlRAgBAiBEo4AESAl7AW46aabtIcffrhQ6vLly2tz587VGjRokDkUxGqIdu3aaWPHjtXKli2buTJ4VXjHjh3a6aefri1cuLCQ5Pjjj9ewomf//ff3KoLiEQK+ESACxDdklIAQSBwBIkASh5wyJAQIgYwhQARIxiqM1CUEfCBABIgPsCgqIUAIEAIlHAEiQErYC7Br1y4NKyamTp1aKHnVqlW1Dz74QKtSpUrmkMAG6BUqVMic3n4Uxj4nF154oTZp0qRCsiOOOKJAfuBIgRCIEwEiQOJEl2QTAtEgQARINDiSFEKAEMgvAkSA5LduqWSEABEg9A4QAoQAIUAIeEWACBCvSOUo3i+//KJh5YQgQerXr6+9++672oEHHpijUuajKD179tSGDx9eKEy1atUK9VS9evV8FI5KoTQCRIAoXT2kHCFQQIAIEHoRCAFCgBBwRoAIEGd86CkhkGUEiADJcu2R7oQAIUAIJIsAESDJ4q1MbtgM/a9//as2evTogk5wjYWBFArqIPDyyy9rnTt3LijUsGHDwiqQww8/XB0FSZNcI0AESK6rlwqXEwSIAMlJRVIxCAFCIDYEiACJDVoSTAikjgARIKlXASlACBAChEBmECACJDNVFY+iTz75pHbvvfcWNkcfOHBgPJmQ1EAIvPPOO1qnTp20Sy+9VBsyZIhWrly5QHIoESEQBAEiQIKgRmkIgWQRIAIkWbwpN0KAEMgeAkSAZK/OSGNCwCsCRIB4RYriEQKEACFACBABQu8AIUAIEAKEgAkBIkBMkNANQkA5BIgAUa5KSCFCgBBQDAEiQBSrEFKHEIgQASJAIgSTRBEChAAhkHMEiADJeQVT8QgBQoAQCIIAESBBUKM0hECyCBABkizelBshQAhkDwEiQLJXZ6QxIeAVASJAvCJF8QgBQoAQIASIAKF3gBAgBAgBQsCEABEgJkjoBiGgHAJEgChXJaQQIUAIKIYAESCKVQipQwhEiAARIBGCSaIIAUKAEMg5AkSA5LyCqXiEACFACARBgAiQIKhRGkIgWQSIAEkWb8qNECAEsocAESDZqzPSmBDwigARIF6RoniEACFACBACRIDQO0AIEAKEACFgQoAIEBMkdIMQUA4BIkCUqxJSiBAgBBRDgAgQxSqE1CEEIkSACJAIwSRRhAAhQAjkHAEiQHJewVQ8QoAQIASCIEAESBDUKA0hkCwCRIAkizflRggQAtlDgAiQ7NUZaUwIeEWACBCvSFE8QoAQIAQIASJA6B0gBAgBQoAQMCFABIgJErpBCCiHABEgylUJKUQIEAKKIUAEiGIVQuoQAhEiQARIhGCSKEKAECAEco4AESA5r2AqHiFACBACQRC4/PLLtRUrVmhly5bVZs+eHUQEpSEECIGYEXjyySe1UaNGFXL5xz/+oZ122mkx50jiCQFCgBDIFgIzZszQ7rzzzoLS3bt317p27ZqtApC2hAAhYIvAZZddpq1atUorX768NmvWLNt49IAQIAQIAUKAECAChN4BQoAQIAQIARMCtALEBAndIASUQ4BWgChXJaQQIUAIKIYArQBRrEJIHUIgQgRoBUiEYJIoQoAQIARyjgARIDmvYCoeIUAIEAJBECACJAhqlIYQSBYBIkCSxZtyIwQIgewhQARI9uqMNCYEvCJABIhXpCgeIUAIEAKEABEg9A4QAoQAIUAImBAgAsQECd0gBJRDgAgQ5aqEFCIECAHFECACRLEKIXUIgQgRIAIkQjBJFCFACBACOUeACJCcVzAVjxAgBAiBIAgQARIENUpDCCSLABEgyeJNuREChED2ECACJHt1RhoTAl4RIALEK1IUjxAgBAgBQoAIEHoHCAFCgBAgBEwIEAFigoRuEALKIUAEiHJVQgoRAoSAYggQAaJYhZA6hECECBABEiGYJIoQIAQIgZwjQARIzivYrXi//vqrNmnSJK1JkybaEUcc4RadnhMCsSGwdetWDR+pZ555pnbooYfGlg8J9oYAESDecKJYhECaCBABkib6lDchQAhkAQEiQLJQS6QjIRAMASJAguFGqQgBQoAQKIkIEAFSEmv9f2VeuXKl1r59e+2TTz7RLrnkEu3VV18twWhQ0dNG4PHHH9d69eqlHXjggdrzzz+vtW3bNm2VSnT+RICU6OqnwmcEASJAMlJRpCYhQAikhgARIKlBTxkTArEjQARI7BBTBoQAIUAI5AYBIkByU5X+CjJv3rzCAPN//vMfbY899tBeeuklrWPHjv6EUGxCIEIEVqxYof3lL3/RtmzZou25557aww8/rN14440R5kCi/CBABIgftCguIZAOAkSApIM75UoIEALZQYAIkOzUFWlKCPhFgAgQv4hRfEKAECAESi4CRICUwLqfM2eOdu6552rbt28vlP6hhx7S+vTpUwKRoCKrhsDChQu10047Tfv5558Lqg0ZMkS77bbbVFOzROhDBEiJqGYqZMYRIAIk4xVI6hMChEDsCBABEjvElAEhkBoCRICkBj1lTAgQAoRA5hAgAiRzVRZO4UWLFmmnn356kfzo0aOHNnz48HBCKTUhECECY8eO1S6//PKixNGjR+uuiw/oJFYEiACJFV4STghEggARIJHASEIIAUIgxwgQAZJO5f7444/a77//rh188MHpKEC5lggEiAApEdWceiHJnqVeBaQAIRAJAkSARAJjNoR8/fXXGjoJGzZsKCh8zjnnaBMnTtT22muvSAqwfv16DXl8//332o4dO7RDDjlEO/LIIwubq5ctWzaSPFQX8ttvv2mTJ0/W6tSpo9WqVUt1dX3pl2T9du/eXXvyyScL+u2zzz7ae++9pzVp0sSXvhQ5HAJEgITDL4rUa9eu1dDu8MMgwlFHHaUde+yx2kEHHRSFeM8yVNHDs8IlKCIRINmu7Dz1GfJUlmy/VaS9EQEiQIyIxH/9008/aa1atSr0X9CHCBP+/PNPbd26dcUf+kP4xjzssMO0GjVqaHvvvXcY8YmmzVNZEgXOITMiQBzAoUeRIBClPYtEIYWE7Ny5U4Mb802bNmmbN28uuNbH9+rRRx9N5LdC9USq7EaACJDdWOT67Ndff9X+3//7f9rHH39cKCcG0eBuqHz58qHK/dVXX2nYvHrq1KnaZ599ZisLbo2uvvrqwmbr++67r228rD5AZ/y5557T7r33Xg0d/b59+2pDhw7NanGKeqdVv3hfMQD/0UcfFXTBR86HH36oHX744UXd6CReBIgAiRdfO+n//ve/tZdff
lmbMGGC9s0331hGO/DAAwv75dx0001ay5YtLeOEvamKHmHLkff0RIBks4bz1GeIqiy7du3SsOLzl19+8VWp6Be0bdvWV5ooIi9ZskSDnfQbMDEIbmgpJIcAESD+sX733Xe1zz//3HfCP/74Q1u+fLk2bdo0bc2aNdoFF1ygvfXWW77lIMHcuXO1cePGaa+//rr23XffWcrYb7/9tDZt2mjt2rXTLr744sgm9VlmFuJmnsoSAoZYkhIBEgusuRKqgj3zAyjs6MiRI/0kCRS3evXqBfvpNzFclWP8DxOpMT6DCTBWAfa5Xr16Ws+ePbX27dtrpUqVsopG9wiBZBFgFDwhwGejsxNPPDH076STTmJnnnkmu+SSS9h1113H7rrrLsbJA8aZZU96BI10ww03MP5mFX5803PGO2JBRRXS8c3TGSc0GDdkRblCvtORL4NmvCMcKm+VEvMPf8b/oBhnunU4cAJEJTV966JC/XLyhfGB3iKujRo1YnyWge+yUIJgCHDCtIA977wEE0CpfCHABxoYH8Qrvu9OdlR+1rBhQ8YHB3zl5RRZFT2cdKRnuxH4xz/+UXxn3nnnnd0P6ExJBPLUZ4i6LK+++mrxXZZtnNs5XyXK+MrmxOubT+wJpC9ffZ24riU9w3/+85/FuuJuf0s6HK7lR1+br7AoYubWBp2ez5w50zU/Y4QvvviCnXXWWb7zr1u3buGb2igvzes8lSVNHJ3y/stf/lJ4V/jqaKdo9KyEIpC2PQsCO77rnOxqVM84Qe1LPU7MsGeffZZVqVLFt37VqlVjDz/8MIMMCoRAmghoaWaepbwfeOABxl1F+W7sXg1U6dKlWdOmTRnf/4Dx5bGRQjN9+nSd3pyFDSWfz3hjfAabTqbXcop4/fr1C6VD2onx4c9XfDC+vM8ShywTICrVL3cnxkDYifdmwIABaVd9icmfCJDkqprPemb4DxDveZDj//3f/4XuVKqiR3LIZz8nIkCyUYd56jPEVZZPPvmE3XLLLaxLly6scePGvibY3H777Ym+CHw1tS97Xa5cOYYJUJ06dWLPP/98orpSZowRAeLvLcCkvyD9EGMa7g7YX8Y8Np9VzLhLK13+IDkx6W7EiBGMu8Rl+DYYOHAgq1y5si6eyJ+70Q3dH/KtuEWCPJXFonjK3CICRJmqUFKRNO1ZUEBat25taduEjYvqeO2113pWcfbs2QwkszHv2rVrM/TBMIYJ+4wjxvnQ5zHGxTVfsce2bt3qOV+KSAhEjQARID4QBWPJ97hgb7zxBuvQoYNlo7Zq6NzlEwPridUPVs+N9/hSMQbSIoqAD9Xjjz++mC9WKmzfvj2waL6kmaEjatQ5yPUzzzwTWI+0EuIdeOGFF9gxxxzjiEFWCRAV6/eaa64pYl2mTBm2evXqtKq/ROVLBEgy1T1o0KDi+x3EjsppuFsVBpsfJKiiRxDdS3IaIkDUrv089RmSLgtWgWIgU7ZxdudYLRr3Smr5Tbv88ss96QViG6QO940tJ6fzhBEgAsQ74NwVHeNukj2933btUdwHYeEngNQQacUR31vcZbOlGO4ul+F7S8SVj2effXaiZlGDDQAAPgRJREFUNsGoYJ7KYiybatdEgKhWI+rok6Y9C4rCl19+qZv8Kdu1qM+5e0hPalqND1WoUIHhG4S7v7KUAewfeeQRhnFQo96Y5AL7TYEQSAMBIkBCoA7W1Nig5Wu+8TfDkn55QAofaJ9++mlhFovT8jGsNuF+4ENo99+kmH0i6/TSSy8Flsn3Y2AokywP53ADABdbWNYGQ3f99dcXVrPIM/eNaXANIgWEUhYCjDgY7Zo1a5rKb1W2LBIgqtYvXPLI79L555+fhVcm8zoSARJ/FT7xxBOe7ImVjbG7F2QmtCp6xI94/nIgAkTNOs1TnyHtsuA/387eyffRB00iwN2WcYa6rId83qNHjyRUojxcECACxAUg6XFUrlcOOOAAXxPu7r//flM7x6S9H3/8UdLO+pTvhWVKi3Z40UUXRe5VwVoD/d08lUVfMjWviABRs15U0Cotexam7PiOk/sRcZ3vv//+nkiIZcuWMUxAlfXAOOW8efM8FRPjnlaTp3v37u0pPUUiBKJGgAiQEIiOGjVKZwxkw4BzzFxzCjt27GB33nmnrQwYlzfffNNJhOMz7OMAf5hCr+OOO47hQzZI4JsdsVq1ahVlQSb8GC9YsMBW3Jw5c0xphC7iGGSwzjbDmB6gDuSyo16MfwSiPOKYNQJE9fqFj0qBLY74mKUQLwJEgMSL77/+9S9LFy/NmzdnfHPzAkk+ZswYBrdvF154oa27Pbld4HzPPfdks2bN8qy8Knp4Vpgi6hAgAkQHhxIXeeozqFAWuOU02jmra6y0liccxfUywCZb5W91DzM5KaSPABEg3uugSZMmnt9vq3de3EM/xmvAt6SVG9DXXnvNqwh23nnnWeo9ePBgzzKiiJinskSBRxIyiABJAuVs5pGGPQuDFFZT2Ln2E7Y1qiPcjboFuOUX4wFyvljh5idYrYjD9+qaNWv8iKG4hEAkCOwBKfyFphAAAb7aQePspW3KGTNmaC1btrR9Lh7wjW+1SZMmiUvdke+1oa1atUrjHUPdfS8XN954o/bYY48Vo/LBNI0bu+K1n5MhQ4ZonKwoJOEEgMY3QNKuvPJKVxF8eZvWsWNHjW98bhn3sMMO09avX2/5TJWb/M9TW7RokcYH4bWLL75Y474LNb4xtHbzzTdrfMahpZqcANGGDh1q+UzFm6rXLx+k1fhKoyJ0fEm8xmcUaHxGQfEenUSLwKmnnqrx2R2Fd33btm3RCi/h0niHUjvxxBO1JUuWFJHgJKs2bNgwjft9Ld4znsCO8tV2rjYTNmv+/PnG5KZrVfQwKUY3PCPAZ75q3MVOIT7fBF1r0aKF57QUMR4E8tRnUKEsmzdv1ipWrOipsl588UWtc+fOnuIGifTLL79o6Jf/8MMPrsn5KhENfWAK6SPA3Xxo3EVkQRG+CbrGV+akr5SCGvCJa1qzZs2Kmp155plao0aNitdeT9BeuTcArXz58q5J+ICfxl01F7515cic0NS4GzyNrwCXb9ue41uS78toanN8kK3wDcfdS9umjepBnsoSFSZJyDnllFMKfV4+6VPD/wUFQgAIpGHPwiLPV6wUxppkOWeccYZ26aWXatxzjFapUqVAY4J9+vQp4CHL5RNcND7BTr5lOp86darGJzzr7vM9zTSMC5QqVUp33+kCfSHuRUVbu3atLlr//v01TlLr7tEFIRA7ApHQKCVUiJMLLGx4yD+UPCHzzTffMN5JtJy5wl+AwmbbngRJkZYvX66bXQwfqvDfHCRg+bHYvwQzdODWy0/4/vvvGTfYtuXjRtSPuMTjciKLbdy40ZQv72TpXDOhrsQvSytAslK/YoaPwJgPFpvqhG5Eh4CY8cHJvuiEkqQCAsbVg/BVbedD1QgZNo7r1q1b0daI9mA8Ll261JjUdK2KHibF6IZnBGgFiGeoEouYpz6DCmXBymXMFDTaOKvrhg0bxlrPfPKPJz2gG5/gE6suJNw7
ArQCxBtW8ioKuCxZt26dt4QhYtm1Keyd4zdgRrOVXcCmu0mEPJUlCbyiykN8H8LrBQVCQCCQhj0TeQc9tmrVqmjDsLcZVumHDVjFUbVq1aJc2EjsywHvH27ByqY2bdrULZnlcz4hWqcD9ODkpWVcukkIxIkAucAKga68ubixw4XNaP2Es846y2QUhEzO/PoRVYiLfThEehz5agTfMkQCeR8RbAAeJIwePVqnj6zb4sWLg4hUIo3sYkwuU5YIkKzU7/jx43XvENySUYgPASJA4sP25JNPLr7LNWrUYHBX6Dfw1WhFGbLtEed8to+rSFX0cFWUItgiQASILTRKPshDn0EAm2RZsJ+AsG1uR5A2cQU+i9yzHnXq1IlLDZLrEwEiQNwBg592eb897O0Yd8DAnOxiWG7bcKPsN2CwUJYhn/txDeo3X8TPU1mClD/NNESApIm+mnmnYc/CIrF69eqiDUb/auHChWFFFtJbuRHFN6SXYLVfMfcA4yWpKc6gQYNM9hnuvigQAkkjQARIQMQ3bdpUNFJyB0ucP/roo74ko6Mp0hqP8GvsJ+zcuZPJH4uYOfftt9/6EaGLW79+/YJu8EMfNHz33Xe25eOuWoKKTT3dEUccYVmuLBEgWalfzALFSia5fUQxMyL1l0hRBYgAiadi0CkX7zBWCnpZqWGlCValYXaQkGU8cveLVsmK91TRo6gQnQRCgAiQQLClligPfQYBXpJlEauQjXbO6hoTiuIIGECV80Pf3GmD9hNOOCEONUhmAASIAHEHrWvXrsX3m7tvY/BOEHfgbhuLecptC0TMTz/95Dt7kBB2tuKSSy7xLc9PgjyVxU+5VYhLBIgKtaCWDmnYs7AI9OvXr2gPucvjsOKK6TEhTravOH/55ZeLz+1O4JnAavVt0D4WJlEb9YCtp0AIJI0AESABEYfhMDZi+XrFihW+JD/00EO28rDptlf3KMiU+0DWyeI+XH3pIkfmPuoLstChtHIDJcd1O7fb1CnI7Ge3vJJ6nuQAQBxlylr9GttJ0FkIcWCZN5lEgMRTo/fcc0/RPt91112hMjG6sJL/g0AWOgVV9HDSkZ65I0AEiDtGKsXIep9BxjLJssiDmliFIc9Ul+2eOI9jZbGR7OD777CLLrqoaM9F3uJIBIj8tqR7TgSIM/6YJAfSQ7y78CKQROB7SxbzFHnjyH3FB86e79doKbNMmTIsTpfLeSpLYPBTSkgESErAK5ptWvYsLBzCTdVVV10VVlQxPUhhY18NthDuz93Chg0bLG0pVoUECcYxHNj6Qw89NIgoSkMIhEKACJCA8MnMstxpwznfhM23VHxIGeXI1zBCXgMIDznt888/7zWpKR78zT/11FNs9uzZpmd+b6BDK+uF88MPP9yvGKXiG/9URPmysgIka/UrCBuBM2bQx/lBo9TLlrAyRIDEAzjfXL5gB/kmoZ46oE5a8E3ldHs9iXaBIwYznIIqejjpSM/cESACxB0jlWJkvc8gY5lkWWQCBH6k27Zta+pPyvYPfqujDCtXrtTNhMTeWOh7EAESJcrxySICxBnbW2+9VdeeZs6c6ZwgoqeyG065/YbZs0P+T5Rl4jzM97BbkfNUFreyqvacCBDVaiRdfdKyZ2FLPW7cOHbZZZeF/jaU9Xj//fd1th12EJM5vITt27eb0gqbOm3aNC8idHF69uxpkhd0PxGdYLogBHwiQASIT8BEdGxuKIyA8di9e3cRzfPxpptuspUH+V4HedesWWOaGRfG/ZXnArhExEZLWMlixMrvXiku2ST+OMkBgMQL5yPDpOoXbrBk9254n5544gkfmlJUrwgQAeIVKe/xYMdLlSpVsIN3332394QOMe32ogI5aBdU0cNOP7rvHQF5sAcuOCiojUCe+gxJlsVIgGB/AGN/Ur6GnV27dm1kL0OvXr10+aHPjkAESGQQxyqICBB7eDERCoSe3H5wDh/0rVu3ZgMGDGCTJk3y/B1qn5P+ye+//275XYi8O3XqpI/s42rKlCmmsoiyRTmzWlYpT2WRy5WVcyJAslJT8euZlj2Lv2TBcrAaX/RDBO+///6W9hR7nP3xxx+elcL4DSaIC1ssjuTJwzOEFDFCBIgACQCmcRa6aMTiOGHCBN9Sjas2hCwcjz32WM/y7rvvPp1xCbpMzXOGHiNabcCEsk2fPt2jBDWjJTkAoCYC/9UqyfoFaSa3j+bNm6sMTWZ1IwIk+qqDa0T4jd9nn31CuxQU2mE2tNwexLnTBryq6CHKQMfgCBABEhy7NFLmqc+QZFmMBAjqrkmTJpa2T9hA+L2OImBApXz58sW8QK5gshECESBRIBy/DCJA7DG+//77i++2aDtWR6wqPfvss9nIkSMZ9nUMG+zcqyDva665JrD4ZcuW2ZYH7vPiCHkqSxz4xC2TCJC4Ec6O/LTsmYoIwf3VkUceqbOHpUuXZn5cz9tNsoOdxveH14A9Taz+V2gvV68IUrwoESACJACaTsYVhsXvxm3YaM5qkyFhKPx8xGFjIpEOxzAblweAxjbJ4MGDdXpBt/bt29vGz8qDJAcAVMYkyfodMmSI7l0qW7YsgysgCtEiQARItHjK0rBiKqpg53fay+o6VfSICouSKIcIkGzVep76DEmWxYoAGT9+vK4vIPd9cV6hQgW2ZcuW0C+I0UVthw4dijKJAClCofQJESDW1YO+MybKGduO2zW+WdEOli9fbi3Yw92lS5fa5hvm29XJbQvIy507d3rQzl+UPJXFX8nViE0EiBr1kLYWadqztMtulf/8+fNNNtbvBuZYAWj3f4D/AexF6Rawj3Hjxo1NcsK4OnTLk54TAk4IEAHihI7NM6fVGi1atLBJZX8brlDsjAvuz5o1yz6x9ARL0eRZakh77733SjHSOf3yyy/ZvvvuqysjNof0SxSlo71zrkkOADhrkt7TpOt33rx5uncJ7znuUYgWASJAosUzLmndunUztQe0CfhaTTKookeSZVYhLyJAVKgF7zrkqc+QZFmsCBD0eWvUqGFp/2AD8cOq6DABeRx11FG6PD744IOiSCJAilAofUIEiHX1PPfcc7p3W7Qbr0cMgGG/nXXr1lln4HAXLhvt8mnUqJFDSvdHdm5bkF8Y0sYu5zyVxa6MKt8nAkTl2klOtzTtWXKl9J7TzTffbLKx2NfXT8BqP2yabmer8R8wevRoR5E9evQwpce4IIhjCoRAGggQAeITdcwsgfsSO0OA2el+wkcffVTYrNZOXrt27RiWsHkJH374oUmvpDays9MPuoMUksuHjeI2bdpklyRT95McAFARmDTqFzM8jH/GYQc5VMQ2bZ2IAEm7Brzlb7SvwtaOGTPGm4CIYqmiR0TFyYwYIkAyU1UFRfPUZ0iyLFYECAAdPny4rn8p7J84Yr++MCtEX3/9dZ18/C/KgQgQGQ11z4kAsa4bTEYTbSXMEe0Mm+36CVaTmYQOkBcmOO3TCZe9UYc8lSVqbJKQRwRIEiirn0ea9kw1dDA+A5fLwqbiiL14g7gv7N27t06OLBPnIEHgHcdqvPLxxx83pYU7xalTp6o
GGelTghAgAsRnZcudaKMBwPXixYs9S8QGjcccc4zJMAi5tWvXZj/++KNnecZl+nvssUfkm9Z5VuZ/EY3uiuCvHiRSXkKSAwAqYpZW/Z5++um6dnPOOeeoCE+mdSICJBvVZ2WDMPsxSvdWXpBQRQ8vuuYpDhEg2apNq3aCPl/fvn2zVRCubZJlsSNAduzYwSpWrKjrD4g+tDg+++yzgbE97bTTdLJBiMiBCBAZDXXP5W83kGYU/otA5cqVde+3aDNBjpgcOHbsWM/QYtWIXT4YUMPG4kGD1Wa7Ii9skh51yFNZosYmCXlEgCSBsvp5pGnPVEMHK1WFzRPHZs2aBVITbgMbNGhgkifkiiPca23cuLGQB+y3FXFSqVIlRvt+BKoGShQhAkSA+ASzV69etgbAz4wVdMDkDzphPMSxZs2abOXKlb60g89UkR5HEChphV27dhU+6GV9sJn7tm3b0lIplnyTHACIpQABhaZdv/3799e96xjwhU4UokOACJDosIxLEmbyYKBAtrM47969e1xZWspVRQ9L5XJ+kwiQbFVwnvoMSZZF7i9jIo0c3NzI1qlTx3JmoizD6vzjjz/W2Va42zL2M4gAsUJOvXtEgFjXCSbZffLJJ2z69OkFNyaDBg1irVu3NrlTNvYx7K6xx8a0adOsMzPcxWxh42puWe5XX31lSOH90mnj3pdeesm7II8x81QWj0VWKhoRIEpVR2rKpGnPUiu0Tca33HKLrv8C2/roo4/axHa/jf2KjRuqy/ZanB966KFsFN8XxMorQKtWrRgmf1MgBNJGgAgQnzUAYkI0cuPxyiuvdJSGTYCw5Mu4UbksB6s2IMfPyg+RqVE3N31EuqiPy5YtY6eccoolThiohruiODahi7ocXuQlOQDgRZ8k4qhQvy+//LLp/cKfM4XoECACJDos45IkD37L/yMYuEsyqKJHkmVWJS8Ze/ghp6A2AnnqMyRZFicC5Pvvv2dly5Y19Qlkmzhp0iTfL8bll1+uk2k1eEAEiG9YU0lABIg/2DF7Fxvo/u1vf2NVq1bVtQO5XVmd77fffgVSxUuOtWrVspX90EMPeRFhGefEE0+0lfviiy9apgl7M09lCYtF0umJAEka8Wzll5Q9UwkV495lGF8MO07yxRdfsCCrbNB/e/rpp1WCh3Qp4QgQAeLjBcBmz1adPXHv1ltvLfhAfe+999jbb7/N4IMdPvGwBAwue5w+0PDs0ksvZUuWLPGh0e6omJVm3JsEAyNJhk8//ZThgxEzgAQmdseGDRuyb7/9Nkn1YskryQGAWArgQ6hK9YvZasZ3a+7cuT5KQ1HdECACxA2h9J9jlZ+xHWCGTdJBFT2SLrcK+REBokIteNchT32GJMviRIAA/euvv95kC2XbiD64n7Bhwwbd/nwHHnigpftWIkD8oJpeXCJAgmOPwcNx48bZTmyT25k4x4Q8TPpzC3BfK9IYj8cdd5xbctvnJ510kq3cyZMn26YL8yBPZQmDQxppiQBJA/Vs5hmnPVMFkQULFpjsH9pIFGHFihW27k+NNhzXIIa3bNkSRdYkgxCIDAEiQHxA+eSTT5oMilVj93oPy8Q6d+5cWHYcZMWHrDrIBGO+Yfwey7KdzuF/efTo0YWlbmCXjTo4XVepUoUtWrTISbzyz5IcAEgDDFXr96OPPjK9a3gPKUSHABEg0WEZhyRs5Gm0r+XKlWOrV6+OIztbmaroYatgzh8QAZKtCs5TnyHJsrgRIPgot3IHKNtI+MT2GgYMGKCzr7fddptlUiJALGFR7iYRINFUCdrQBRdcoGsbchuTz61WTBm1GDZsmKOsoKsanVxgxeV/Pk9lMdaT6tdEgKheQ2rqF7U9U6WUmJAt22KcP/DAA5Gpt2bNGgaC2piH1TX6ZXBTanQfGpkyJIgQCIAAESA+QHP60LFq9G73sAoCrnzgOzRswOx3Y37jx48PK9Y1PfY9Mebr5xokCNwXZDUkOQCQBkaq1q/Vaiws1acQHQJEgESHZRySrrrqKpPtjbKD61VnVfTwqm/e4hEBkq0azVOfIcmyuBEgeAvatWtnsolyf/Tiiy/29LLARSs26hRpS5cubbti2em74IQTTvCUH0WKHwEiQKLFGH0NN8IRbdZtct9PP/3EsLpKtDXjsUOHDr4Vxzc1JoMYZYlr7HkSR8hTWeLAJ06ZRIDEiW7+ZUdlz1RBqnr16ib7F/XkuE2bNjHR7oRtdTq2adOGbd26VRWISI8SjgARIB5fACyZg19Tu8bdrFkz1rVrV9alSxeGzcibN2/OsPGik9srIat+/fosrPsezH4X8sRx5syZHksXPBpYZnx0tm3btrAKBGUpX768SRehk9Xx/PPPD65AyimTHABIo6iq1i/+RI3vEgZiKUSHABEg0WEZtaTFixezvfbaS9cG4PLhjz/+iDorR3mq6OGoZM4fEgGSrQrOU58hybJ4IUDmzZuns4nGPgIGbFetWuX6wmD1tJz2sssus01DBIgtNEo9IAIk+uqYMmUKw76Oclsxnr/yyiuuGd9xxx22MuBZ4K233nKVIUfAPpNGPcQ15LmRMrIsv+d5KovfsqcZXwzEHnTQQWmqQXlnGIGo7FnaEHz44Ycm+9egQYNY1MLEU2FbvRwxLvrVV1/FogsJJQT8IEAEiEe05syZ49jIly5daikJM1E+++wzNnDgQEcyBINZYWbvYrM4o/GBEUwjoMwYGOvTp49pXxKjjuI6LV3D4pPkAEBYXaNKr0L9QgfjAPC5554bVRFJDkeACBA1XwO8+02bNtXZ+3333Tfw/lFBS6mKHkH1z0s6IkCyVZN56jMkWRYvBAjeBPG/JfqWxmOPHj1cX5h69erp7OvHH39sm4YIEFtolHpABEg81YHv20MOOUTXXuQ252ViEvbbMe5hKcvA5EN8g7uFr7/+mmHFiJzWeF6jRg03MaGe56ksoYBIODERIAkDntPsorBnaUPTt29fkw0cNGhQpGrBPXrHjh1N+RjtrdU1/i+y7v4+UjBJWCoIEAHiEfb+/fvbNnQslcdgkFvAzDOnjdlgKB555BE3MZbPBw8ebNJv5cqVlnGTvAk/gdgY3W1/kE6dOiWpVmR5JTkAEJnSEQpKs34rVqyoe+dbtGgRYclIlBhIwscnBXUQeOaZZ3TvPWY1T5gwIXEFVdEj8YIrliERIIpViIs6eeozJFkWrwQIZotbfXSLe1iV/cMPP9jW0qxZs3TpsZrbKRAB4oSOOs+IAImvLkaOHKlrM6Kt4Vi1alVPGY8ZM8bRpVapUqUKEwnhdsUY1q5dW/Ax78XjAjwWxB3yVJa4sYpKPhEgUSFJcqKwZ2miePTRR5vs8fLlyyNTCS7IjZNEQCy/+OKLrFq1aqa85f8DcY6VWk4TSyJTlgQRAjYIEAFiA4zxdqNGjWwb9aWXXmqMbnu9bt06VrlyZVtZ6MB9FWB5mNWyW5X21hgxYoRtmWEQsYw6afcttpXk40GSAwA+1Eo8ahr1W7NmTd071aRJk8TLnecMiQBRr3Y3b97MjMQfNt5MOqiiR9LlVjE/IkBUrB
V7nfLUZ0iyLF4JEGy06bY5J1Zk2wW4ZBUf6ThOnDjRLmrhPhEgjvAo85AIkPiqAm3OOCAm2hAmv3n9ths1apTrZLkyZcqwM844g2HS3GmnncYOOOAAXXsV+dod77333viAkCTnqSxSsZQ9JQJE2arJnGJR2bM0Cv7RRx+Z7CHcTkUVPv30U2bcG7Zu3bps/fr1hSy2bNniuhebsM0gQVasWBGVaiSHEPCFABEgHuACkeC0guHpp5/2IGV3FLkjLgyBfDzvvPN2R/Z41rt3b5PR++233zymTiba0KFDTTrK5YbhzlpIcgBAdWySrl/R4RXvEG04Gu0bQgRItHhGIc3o3qFnz55RiPUtQxU9fCuewwREgGSrUvPUZ0iyLF4JELwN6JOLfoHVESTyzz//bHpxsGpa3tgZRIrb6m4iQEwwKnlD/u4aPny4kjpmWakZM2bYtjk/k/Ew+1pug1bt1+4evtO7d+/Orr76altdvvjii8RgzlNZEgMtYEbie5D2AAkIICXTIRCVPdMJTeDitttuM9m+O++8M5KcFyxYwNC+ZPuLFX5Wq/KeeOIJBrJajmt1Xrt2bbZt27ZI9CMhhIAfBIgA8YAWlnVZNVxxz++KDXxQWS1RE/Jw9EsGXHfddTod4U9VxYDVMnI55XPMmMlaSHIAIAvYJFm/bdq00b1LGLCnEB0CRIBEh2UUkoybzYEo9zqzMor8hQxV9BD6lPQjESDZegPy1GdIsix+CJCdO3c67kuAfic+0I2hV69euj4FVra6BSJA3BBS4zkRIPHXQ8OGDXXtR3zf+XW/snr1aoa2iL3NhAynIwgTrAhZtmxZgbC0s0txbQTshGyeyuJUzrSfEQGSdg3kL/+o7FmSyFiNLUax3wZc+MNTi2yHQTjPnDnTtngYxzzyyCN1aeT04hykDQVCIGkEiADxgPgVV1xh24CPPfZYDxLMUbAEXzR+qyP8q/sJ2NhRlgN/qSoGbBZvt5omzCbwaZXVrqONTahKYkiyflu1aqV752vVqlUSIY+tzESAxAatb8Hjx4/X2U24fsAmdEkHVfRIutwq50cEiMq1Y9YtT32GJMvihwAB6tj0U+4TG8/hsxquLkTYunUrK1++fDEN8rNaJSLiiyMRIAIJtY9EgMRfP126dCm2H7m9+Z3QJzT9z3/+w7CyvH379qxx48bs0EMPZRUqVCj4mcd+miA9nnvuOQbX0iLMmzfPUgfoc99994loiR/zVJbEwfOQIREgHkCiKL4QiNqe+co8QGTsqSHbXZyjnxM2YEKJFRl0ww03uIrG6j/sz2rUS74G0e20L5trJhSBEAiAABEgLqBhtYbTnh033nijiwTrx6+99pqjQejTp491Qpu7t9xyi0ne9u3bbWKne9voY1kYwttvvz1dxQLknuQAQAD1UkmSVP2efPLJunf+xBNPTKW8ec2UCBA1anbhwoWsXLlyxXcdAwFpLBlWRQ81akUdLYgAUacuvGiSpz5DkmXxS4DALYPbDHL0w0V48MEHizYWfdL+/fuLR45HIkAc4VHmIREg8VeFcXWo+LZL0u3UhRdeqGvHQgccsXlvlkKeyhI37kSAxI1wyZOvgj3zg3q/fv1Mtu/WW2/1I8Iy7t///neTXNjTzz//3DK+8ebvv//O7MgkYZ9L6qRhI1Z0nRwCRIC4YL148WLLhi8a7ZQpU1wkWD9+//33HeW2bt3aOqHN3bvuusskT2xKZJMktdtjx4416Qo8sZF71kKSAwBZwSap+jVudHrqqadmBaJM6EkESPrVtGHDBibbGOxzgw3Ikw6q6JF0ubOQHxEgWail3TrK7Vn0I3HM4gdgkmXxS4AAcaNLKxlvnDdp0qRQMXAleNRRRxX7pXAhC5vnJRAB4gWl9OMQARJ/HYwbN67YhuS2ho1xkwifffaZbqWsrEPbtm2TUCGyPPJUlshAcRBEBIgDOPQoEAJp2zO/SmO1h2zzcD5//ny/YnTx0Tey6uc1atRIF8/tAnsSt2zZ0qSf0He//fZzE0HPCYFIESACxAVOp42dscGPlyXyVll88MEHtoYABsHvYO6QIUNM8tCBUjHYlT3N5clBcbL6Y0D9ZXEwIygGxnRJ1S+Ww4s/TxzPOussoyp0HQIBIkBCgBdBUrhMqFevXvEdB+G3cePGCCT7E6GKHv60LjmxiQDJVl3nqc+QZFmCECDwv7/XXnsVbajcXxDn7733Hnv99dd1cbp27er5pSICxDNUqUYkAiR++OfOnatrR2hje++9d/wZ/y8HtFvRruUjbIDffUgSU9omozyVxaaIkd4mAiRSOEkYRyBte+anErBCX7Z5OEf/DF5swgQQKEa5uB4wYIBvsfBcAFflVvJwz+ukE98ZUwJCwAIBIkAsQJFvnXHGGbaN1e8qDVnupEmTbOXCEGAzaT9h+PDhJnlhmV8/+fuJiwE1KwP46quv+hGjRNwkBwCUKLAHJZKqXxCQ8nuEgQgK0SFABEh0WPqV9NNPPzHxQYd3HLOTv/nmG79iQsdXRY/QBcmxACJAslW5eeozJFmWIAQI3owOHTro+glynwHn5513HsOeSvJ97GXmNRAB4hWpdOMRARI//viGk9sRzrEpbxIBA4DY+9KYP667deuWhAqR5ZGnskQGiosg0V8+6KCDXGLSY0LAGwJp2jNvGu6OBRfyRtsX1EX/bqmM2bnrf/LJJ+Vons9feuklk55Cb+zfRIEQSAoBIkAckMbgD2aviMZpPD788MMOqZ0fYcNvozz5Gnt6+AkTJ040yXv77bf9iEgs7o8//mjSFWVfuXJlYjpElVGSAwBR6Ry3nCTq95dffjG9Q1H82ceNTZbkEwGSTm3h3ZaJ9ypVqjDMZE46qKJH0uXOWn5EgGSrxvLUZ0iyLEEJkAULFpj6CnJf23jeqlUrXy8UESC+4EotMhEg8UN/zz33mNraNddcE3vG8DHfoEEDU95o29g/Td4kPXZlQmaQp7KEhMJXciJAfMFFkT0gkJY986CaKcoxxxxjsn9Y3Ro2DBs2zCQXdvWNN94IJBoutUCKG/tduB49enQgmZSIEAiCABEgDqi5rdJYtmyZQ2rnRy1atLA0AMIovPDCC84CDE8/+eQTkzxVV1RY7X9y4IEHhl6qZ4AkkcskBwASKVAEmSRRv999953pfQ9DSEZQ7NyJIAIk+SrFxy98VYv/gUqVKrGwrgzhe9XPjGaUWhU9kq+B7OVIBEi26ixPfYYky4KZvcIu+l0h3bx582JaIcPuOHXqVF8vFBEgvuBKLTIRIPFD37FjR1M7e+WVV2LP2G6TXrRxDOBlKeSpLEniTgRIkmiXjLzSsmd+0V20aJHJ7sJF+K5du/yKMsW32l8YdtXvGKUs+PrrrzfpC5kjRoyQo9E5IRArAkSAOMDbs2dPy0aKhopZuUEDBm9Lly5tKxszVjCL3k/AahXoJf+CLlHzk2+QuCNHjtTpCZ1vuOGGIKJST5PkAEDqhfWoQBL1u2LFCtM79NZbb3nUkKJ5QYAIEC8oRRcHndXOnTsX32uQwosXLw6dQb9+/Qoy4SbRS1BFD
y+6UhzGiADJ1luQpz5DkmWpUKFC0Tb6dXc5efLkYlq5j2w8P/74431PxCECJBvtjwiQ+OvphBNO0LWzffbZh23evDnWjP/1r38x5GNsy7hGfypLIU9lSRp3IkCSRjz/+aVhz4Kgescdd5js33XXXRdElCmNlXt92FbkGTRY7VkMmdiLjQIhkBQCRIA4IH3ssceajIroZF1++eUOKZ0f9ejRw1Yu5Hfp0sVZgM3TihUr6uSCwFEx9OrVS6cnyhzFQF8aZU1yACCN8gXJM4n6nTVrlukdWrJkSRB1KY0NAkSA2AAT0215Vkz58uVZFHs4YQn0nnvuyfbYYw/PLgZV0SMmmHMnlgiQbFVpnvoMSZZFHuQ855xzfFU6NgKtU6eOqc8g+vPi+Mwzz/iSi8gXXHCBrVwMoFBQAwEiQOKtB6sNg+N2S/v5558zeWWYaMc41q9fn/3888/xFjpC6XkqS4SweBZFBIhnqCiiBwTSsGce1LKMYjVWOX36dMu4fm/K/5uyfW3fvr1fUcX48NYhyxLntAdIESI6SQABIkBsQIbPddEorY7PPvusTUrn21iqZrdRG/LB6o9Vq1Y5C7F52rhxY53Op5xyik3M9G5j9cv++++v0/Pcc88NpBCW4MEVwqmnnsp69+7N4Hop6ZDkAIAK5XXDN8r6dcrrvvvu071DaDt+V005yadnjBEBktxb0Ldv3+L7XLZsWTZ79uzQmcN11mGHHVaQe+aZZ3qSp4oenpSlSAUEiADJ1ouQVJ8Be/jgg3LMmDFszpw5Bbd2USOVVFm2b99etI/4r2/UqJHvojz33HM6GcZ+/SGHHMKAmd9w9tln28qFX24KaiAgD+R4XQ2phubZ0ALfm3Kbwrfsxo0bY1N+w4YN7KijjtLlKfLHfkFffvllqLyTsJ9CwbjLIvLJ85EIkDzXbvJli8qexW1HMHlY2D1xBCkMN8ZRBCtvG8gHpEvQgNUpQldxxAQXGsMJiiilC4IAESA2qD3yyCOmBioaKo6YreE3rFmzpuA6S5ZjPH/88cf9ii3GhxspWR46oNhwSKWADfFkHTHT+euvv/at4l//+ledHMiEWzHs25JkMJI5omwYSIwyqFJetzJFVb9u+bRr105X/7Vr13ZLQs99IkAEiE/AAkaXfT7vvffezK8PemO26PjC7zZ8wAp75GU/KFX0MJaHrp0RIALEGR/VnibRZ/j444+ZcVPMmjVrMnzMRhmSKAv0NX6Ew7b5DRiIEISwsIvyceDAgX5FFuJjAo4sRz4/4IADAsmkRNEjQASIGdOdO3eyGTNmMLiPxSB80ADXJfJ7j/MHHnggqDjXdJjcYbRvIv8jjzyShdmfE5knZT+RV9xlQR4lIRABUhJq2bmMqtmzJOxI//79Tbb3yiuvdAbK51Pj5Gpha1977TWfkv4b3UguQZ7fVb2BMqZEhICEABEgEhjy6cknn2wyKqLR77XXXr79BKODaTdbRciF0cJS/aBh3LhxJp2xOXrY8O6777LLLruMVa9enbVq1Yrdc8897IcffvAt9sMPPyy4YxHlxTHIShr8qcgy5HOQPtu2bfOtW5AE33//va0eWJESVYi7vKrVrxfcqlatqsP+2muv9ZKM4vhAgAgQH2AFjPrYY4/p3mPYshYtWgT6NWvWjMHlCvYOkW0iZjZjI3SnoIoeTjrSM2sEiACxxkXFu0n0GTCLzm5lBkgQ7BcXRUiiLELPsWPH6mwa7BtWafsNVitHIatMmTIM5QkSjH0R2fbC9WBUMzGD6EZpdiNABMhuLHCGzcGx2lR+X2vVqsXGjx+vj+hyBeLESEaEcRHtkh17++23GYhFWW9x3qBBA7Z+/Xo3EY7Pk7KfUCLusjgWNGcPiQDJWYX6LI5q9iwpO4I+nbB/4jhx4kSf6DlHHzVqlCkP5IUxQb+rZhcuXGjpBeeNN95wVoKeEgIRI0AEiAWgmD0iDInd0euqBQxeY2DKTo64D1+pYcgPFGPdunWmfOA2KUyA6wGrWX5YuTFgwADPS9bg0sU4+27o0KGBVMMqGYGb1XHEiBGB5PpNhA3rrPLHPaxQiCrEWV4V69cNN3zgGHHHAAmFaBEgAiRaPI3SYJsxQGZ8l6O+vu2224xZ665V0UOnFF14RoAIEM9QpR4xiT7D6NGjHW3KhAkTIsEhibIIRTE70GgXMRHHb9iyZQtD39UoCytsgwRMHjHKMl5j8hOF9BEgAmR3HezatYtVq1bN9t3FRDcvXg7wfWskAE8//XTfg2K7NbM/w+zuu+++m2ECorGN4fqss87y/D1qnwtjSdjPpMriVM68PSMCJG816r08KtqzJOwI9j012sIKFSpEbn9BcoBcNuaF60GDBnmuKEwGsZKDSX8UCIGkESACxIA4DKnTknZhAE466aTC0lVD8sIlPrLQ2cYsGLcBLszWfeqpp6zEBLpXo0YNnZEKuwmd1aoSgQGOIDUw+GxH3uD+vffeq+u0ws0LBm2CBjD9sg7G85YtWwYV7Tkd3hN8JBjzFteYgRmVP8M4y6ti/bpVAgYUBM7i6JWQdJNNz3cjQATIbiyiPsNsF7sPefFOR3HE/4/TnlKq6BE1viVJHhEg2ajtpPoM8j4+VjbkjjvuCA1YUmXBBzOIDqtygMiA6x2/oU+fPjp5sJHLly/3K4ZhELN169Y6WVZ64lthx44dvuVTgmgRIAJkN55wx2z1rsr38J0GW2G3kTi+HYwrSK6++mr266+/7s4oorPJkycz47et0BWr/vEfGJW757jtZ5JliQj+TIghAiQT1RSLkiras7jtCIC88847TXa8Y8eOsWCMFbdG7wLCBnfv3r3QH3LKGH2gbt26mfSFO1Onb1QnmfSMEAiDABEg/0MPLkJmzZrFmjZtamqgopFbHevVq1fwXXf++ecXiBMsR3MjPSAHg19weYVNo6MMV111lU5/DGKGCdOmTdPJs8IA94477rgCE4zB6aVLlzLMDoQPWCPbC7ywaVOYAFdadnrgfpjNmbzohc31rAy5USfMWgxbVugTZ3lVrF+3OjD6vMRMNgrRI0AESPSYQuL06dMZBheM9iKOa5C0dkEVPez0o/veECACxBtOacZKss+AlbVOtuT+++8PBUXcZQG5ggkNGGBt2LChY1lQziuuuIJNmTKFffXVVwxp3QIGS0qVKlWUi03MvQYMsmI/EkxasnMzZoV95cqVGfYVhPuHpFy0ei1TSYlHBMjumkY7wUa5Vu+q8R5cNw8ePLgwqQ/7k+H/xtgu0Z4effTR3RlEcIa29uabb7IzzzzTVs82bdoU2n0E2RVFxGE/0ypLsVAl4IQIkBJQyTZFVNGexWFHjMXHuJvRXgfdl8Mo2+oanlwqVqxoyhM6YGzvvffeMxEhmASNb82jjz7alO7ggw9mWMVCgRBIAwEiQDjqYEyNM1mMRiWqa8xiwWazcFcVR3jxxRd1Rmbffff19FFopwtYW7uZN34wgY9YrGSIYnYQ/uzOPfdcXTllXVDmKAOWeXfq1KmQZ926dT0RXLI++FDGqiLo3LVrV9+qxVleFevXDSDjypse
PXq4JaHnARAgAiQAaC5JvvjiC4YZi7J9iPPcrjOsih4ucNFjDwgQAeIBpISjpNlnwCD7/2/v3kJsev84jn8pxxxC48KQ3JBTiFKKJjmkMM4yiZJjxo1wZwwXUsKFM8mxpBxyKORUphEuROSCcoGQwviJn/z+zP/5PjXbHOxt7z3PWnutZ71XafZh7Wc939ezrZm9P2s9K91BOC1btsz5oIwwa9m2bVuDcCLXfaP+HZ/NNTfKyspS++Dr169n9e7QL9jUL9c+/Wn92bNnZ7VNVnInQADS0LL+740/vUezeUwP5NMA8vnz5w0bz/OeBgV6wFdlZWWTqbXq90enu9IwJojF1f4zCrUE4RPVNglAojoy4fQravszV/uRdHp6oHH9faLe1s+WOq15kMurV68yHiiuYfiQIUNqZ82aVasXO//TlKNt2rSpXb16de2HDx+C7CptI5BRgADE8IwYMaLJjqTxjiWX+/oFvM6LqhekLSkpqdVpqHSaKFd/JGYaUT1dufEF4qqqqjK95K/P6ZF1eqRNug/V6Wx0LsKpU6faI4fSTZH1142nWUGnH9AdaP/+/e2HUk2X9cLj2hc9TU//+HS13Llzx9n7Qy+2mc8SZL1RHN90Rjqujd/f+oGJxb0AAYh7Uz0SJt3+0vXjempxuoufR6Uf7oWT12L9D343b95MHkAEKy703wwVFRVN9jP695tOR5rrEmYtBw4caNLvXPaLenRiNn9r6pcT2q5+UM920QtE59KXTOvq2cEs4QoQgDT1PnjwoD17P9N79U/P6UFdy5Yts2dDNW0190euXbtWq2diderUKe3/Mf1MvXz58lq9RmfQS3P2n1GrJWirqLRPABKVkShcP6KyP6sTaM5+pK6NdD//1Pb06dPTre70cf0eRmd70X12LgeF6N+gevCJfufEgkChBVpoB8wfOCweCaxcuVL27NmTqkjvmwtpp+7ne8PMASiHDx8WMwWAmAtR238mwRWT8EpRUZH91717dzFzHou5EJ6Y06TFHCGU7+Zyfp1JxMV8oLXbNEcBitnZ5txGnF7gut6oj6+OjZmyS8wv3dQwmTNy5PHjx6n73HAnYM5akurqajEfSsVM2+GuYVpCAAFnAuaoeVmzZo1tzwQgYi4o6KxtGoqvwIULF8ScASbv3r0TMzWrmDOdZfTo0fEtiJ4j0AwBM02amLOwbQu7d+8Wc+ZwM1rz56X6FYA5C0rMtSnEzExg/+nnu5qaGuncuXPqc51+tjMHC4o5A1tMGOgUwIT4snbt2lSb5ihi6dOnj5iDCGXs2LFirusoZrqX1PNh3Mh3/xnFWsLwKvQ2zNHmcvfuXTFTu4l+L8GSTIEo7M/qy+e7H6nfxp9u6/cex48fb/BUaWmpmAMXGzwW9B0zpaicPn1a9Pujly9fipm+1P4O0X24/l/Uf/r7QvtlDggXMxtM0F2ifQSyEiAAyYopXiuZ607Y8KGu1+ZIYLtDCjOMqNt2mD/NEV4yefJk0T/UzbVVwtx0QbaVtHoVef78+WKmeUt5b9++XcxFTVP3ueFOgADEnSUtIRCUAAFIULK0iwACvggQgER3JM1ZHfbgJnP9Rht0mDP6pVWrVtHtcIae+VRLhjIj9xQBSOSGhA4hgAACkRUgAIns0DSvY3qkjpm7OdWIucC7PZIm9YCHN/bu3WuP6tKjhvTsCN+XpNVr5rYUczFRMdctsUOrH5D0iDU9+4jFvQABiHtTWkTAtQABiGtR2kMAAd8ECEB8G1HqQeC3AAHIbwtuIYAAAghkFiAAyewT22f3798vZr7UVP91+oOTJ0+m7vt4Q2s8deqULFy4UI4cOeJjiQ1qSlq9R48eFXMR+ZTBzJkz7amXqQe44VSAAMQpJ40hEIgAAUggrDSKAAIeCRCAeDSYlIJAIwECkEYg3EUAAQQQSCtAAJKWJt5P6NHyOu/z27dvbSE6/ZVeu0NPLfZxMRd/t1Nf6dkBhw4dkkWLFvlYZqqmpNWr83rqmT16erkuOr/ko0ePZMCAASkTbrgVIABx60lrCAQhQAAShCptIoCATwIEID6NJrUg0FCAAKShB/cQQAABBNILEICkt4n9M3qtBL1mQt2yYsWKBhdHr3vch596sc85c+bYL8b1okzFxcU+lJW2hqTVe+bMGZk1a1bKo7y8XHbu3Jm6zw33AgQg7k1pEQHXAgQgrkVpDwEEfBMgAPFtRKkHgd8CBCC/LbiFAAIIIJBZgAAks0/snx0zZoxUVVXZOtq2bSsaDuhF0X1a9OyAUaNGyd27d2XBggWiUyX5vCSx3mHDhtkzPnRcu3btKs+fP7c/fR7nQtdGAFLoEWD7CPxdgADk70asgQACyRYgAEn2+FO93wIEIH6PL9UhgAACLgUIQFxqRrAtnSZo+PDh8vPnT9u7pUuXil4fxKdFr21SVlYmLVq0kCdPnng/LVLS6m189oee+aFngLAEK0AAEqwvrSPgQoAAxIUibSCAgM8CBCA+jy61JV2AACTp7wDqRwABBLIXIADJ3iq2a65atUp27dpl+68hwe3bt0W/3PRh+fLliwwaNEhevnwpU6dOlfPnz/tQVtoaklbv58+fbaD15s0ba6LX/NBQT68BwhKsAAFIsL60joALAQIQF4q0gQACPgsQgPg8utSWdAECkKS/A6gfAQQQyF6AACR7q9iuWVNTI/3795d3797ZGvT2w4cPpXXr1rGtSTv+69cvKS0tlUuXLole5F2nwBoxYkSsa8rU+aTVqxaLFy+2F7XX2y1btpQbN25ISUmJ3mUJWIAAJGBgmkfAgQABiANEmkAAAa8FCEC8Hl6KS7gAAUjC3wCUjwACCOQgQACSA1acV62urpaxY8fKjx8/bBkbNmyQysrKOJck69atk61bt9oa1q9fL5s2bYp1PX/rfNLqvX79uowfPz7FUlFRIRs3bkzd50awAgQgwfrSOgIuBAhAXCjSBgII+CxAAOLz6FJb0gUIQJL+DqB+BBBAIHsBApDsrWK/5uHDh2XRokW2Dj2aXqeLmjx5cizrunLlikyaNMn2fcqUKXLu3Dl7Fkgsi8mi00mrV6e8GjlypLx+/drqjBs3Tq5evWrPAsmCi1UcCBCAOECkCQQCFiAACRiY5hFAIPYCBCCxH0IKQCCtAAFIWhqeQAABBBBoJEAA0gjE97urV6+WHTt22DI7dOggVVVVMnTo0NiVXXd2wIwZM+TEiRPSrl272NWQS4eTVO/Xr19lzJgx8uDBA0s0cOBA0TOYOnfunAsZ6zZTgACkmYC8HIEQBAhAQkBmEwggEGsBApBYDx+dRyCjAAFIRh6eRAABBBCoJ0AAUg8jCTd//vwpesbE5cuXbbk9e/aUe/fuSY8ePWJXvl4QvGPHjrHrd74dTkK9ep2TadOmycWLFy1Tr169bPihP1nCFSAACdebrSGQjwABSD5qvAYBBJIkQACSpNGm1qQJEIAkbcSpFwEEEMhfgAAkf7vYvvL79++iZ07UhSBDhgyRW7duSZcuXWJbEx33Q6C8vFx2795ti+ndu7d9X/bp08eP4mJWBQFIzAa
M7iZSgAAkkcNO0QggkIMAAUgOWKyKQMwECEBiNmB0FwEEECigAAFIAfELuWm9GPqSJUvk2LFjths6NZZ+kcKCQKEETp48KWVlZXbzw4YNs2eBFBcXF6o7id8uAUji3wIAxECAACQGg0QXEUCgoAIEIAXlZ+MIBCpAABIoL40jgAACXgkQgHg1nLkXs2/fPtm8ebO9OHplZWXuDfAKBBwJ3Lx5U+bNmydz586VLVu2SPv27R21TDP5CBCA5KPGaxAIV4AAJFxvtoYAAvETIACJ35jRYwSyFSAAyVaK9RBAAAEECEB4DyCAAAIINBEgAGlCwgMIRE6AACRyQ0KHEEAgYgIEIBEbELqDgEMBAhCHmDSFAAIIeC5AAOL5AFMeAgggkI8AAUg+arwGgXAFCEDC9WZrCCAQPwECkPiNGT1GIFsBApBspVgPAQQQQIAAhPcAAggggEATAQKQJiQ8gEDkBAhAIjckdAgBBCImQAASsQGhOwg4FCAAcYhJUwgggIDnAgQgng8w5SGAAAL5CBCA5KPGaxAIV4AAJFxvtoYAAvETIACJ35jRYwSyFSAAyVaK9RBAAAEECEB4DyCAAAIINBEgAGlCwgMIRE6AACRyQ0KHEEAgYgIEIBEbELqDgEMBAhCHmDSFAAIIeC5AAOL5AFMeAgggkI8AAUg+arwGgXAFCEDC9WZrCCAQPwECkPiNGT1GIFsBApBspVgPAQQQQCC0AOTQoUPy8eNHxBFAAAEEYiBw584du89u1aqVTJw4MQY9posIJE/gxYsX8vTpU1u4fgnQrVu35CFQMQIIIJBB4P3793L//n27xuDBg6V3794Z1uYpBBCIk0B1dbV8+vRJWrduLRMmTIhT1+krAgggkGiBvn37SmlpaagGoQUg/fr1k2fPnoVaHBtDAAEEEEAAAQQQQAABBBBAAAEEEEAAAQQQQACBwgtMnz5dzp49G2pHCEBC5WZjCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggkT8DrAERPPf727VvyRpWKEUAAAQQQQAABBBBAAAEEEEAAAQQQQAABBBBIuEBRUZEMHDgwVIXQzgAJtSo2hgACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAokWIABJ9PBTPAIIIIAAAggggAACCCCAAAIIIIAAAggggAACfgoQgPg5rlSFAAIIIIAAAggggAACCCCAAAIIIIAAAggggECiBQhAEj38FI8AAggggAACCCCAAAIIIIAAAggggAACCCCAgJ8CBCB+jitVIYAAAggggAACCCCAAAIIIIAAAggggAACCCCQaAECkEQPP8UjgAACCCCAAAIIIIAAAggggAACCCCAAAIIIOCnAAGIn+NKVQgggAACCCCAAAIIIIAAAgikBP7551/577//SbduHVOPcQMBBBBAAAEEEPBd4P+HWdRHqMBkAAAAAABJRU5ErkJggg==)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.032181, - "end_time": "2020-12-01T01:33:28.007010", - "exception": false, - "start_time": "2020-12-01T01:33:27.974829", - "status": "completed" - }, - "tags": [], - "id": "5YyR95dEj6ex" - }, - "source": [ - "*Test error (%, median over 5 runs) on CIFAR-10 of residual networks with k = 1 and different block types. Time represents one training epoch*" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.029054, - "end_time": "2020-12-01T01:33:28.070359", - "exception": false, - "start_time": "2020-12-01T01:33:28.041305", - "status": "completed" - }, - "tags": [], - "id": "VsUwjYPsj6ex" - }, - "source": [ - "The paper highlights that the block structure B(3,3) beats B(3,1) and B(3,1,3) by a little margin. " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.028867, - "end_time": "2020-12-01T01:33:28.128575", - "exception": false, - "start_time": "2020-12-01T01:33:28.099708", - "status": "completed" - }, - "tags": [], - "id": "rlc-MxfZj6ex" - }, - "source": [ - "# Key Takeaways" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.030514, - "end_time": "2020-12-01T01:33:28.188471", - "exception": false, - "start_time": "2020-12-01T01:33:28.157957", - "status": "completed" - }, - "tags": [], - "id": "G0Ix8SDaj6ex" - }, - "source": [ - "The paper highlights a method, giving a total improvement of 4.4% over ResNet-1001 and showing that:-\n", - "\n", - "* widening consistently improves performance across residual networks of different depth\n", - "\n", - "* incresing both depth and width helps until the number of parameters becomes too high and stronger regularization is required\n", - "\n", - "* there doesn't seem to be a regularization effect from very high depth in residual networks as wide networks with the same number of parameters as thin ones can learn same or better representations. 
Furthermore, wide networks can successfully learn with a 2 or more times larger number of parameters than thin ones, which would require doubling the depth of thin networks, making them infeasibly expensive to train." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.028903, - "end_time": "2020-12-01T01:33:28.247192", - "exception": false, - "start_time": "2020-12-01T01:33:28.218289", - "status": "completed" - }, - "tags": [], - "id": "dXm6AEpdj6ex" - }, - "source": [ - "# Importing Libraries" - ] - }, - { - "cell_type": "code", - "metadata": { - "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", - "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", - "execution": { - "iopub.execute_input": "2020-12-01T01:33:28.314745Z", - "iopub.status.busy": "2020-12-01T01:33:28.313774Z", - "iopub.status.idle": "2020-12-01T01:34:08.826788Z", - "shell.execute_reply": "2020-12-01T01:34:08.825868Z" - }, - "papermill": { - "duration": 40.550443, - "end_time": "2020-12-01T01:34:08.826937", - "exception": false, - "start_time": "2020-12-01T01:33:28.276494", - "status": "completed" - }, - "tags": [], - "id": "lJ_OiL_wj6ex" - }, - "source": [ - "import trax\n", - "from trax import layers as tl\n", - "from trax.supervised import training\n", - "\n", - "# Trax offers the WideResnet architecture in it's models module\n", - "from trax.models.resnet import WideResnet\n", - "\n", - "trax.fastmath.set_backend('tensorflow-numpy')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.029394, - "end_time": "2020-12-01T01:34:08.888184", - "exception": false, - "start_time": "2020-12-01T01:34:08.858790", - "status": "completed" - }, - "tags": [], - "id": "P9PPQOMOj6ex" - }, - "source": [ - "# Downloading Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.02981, - "end_time": "2020-12-01T01:34:08.947487", - "exception": false, - "start_time": "2020-12-01T01:34:08.917677", - "status": "completed" - }, - "tags": [], - "id": "9Uto6Pgej6ex" - }, - "source": [ - "Trax offers a rich collection of [.data](https://trax-ml.readthedocs.io/en/latest/trax.data.html) API's to create input pipelines. One of which is the [`trax.data.TFDS()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.TFDS) which returns an iterator of numpy arrays representing the dataset. 
\n", - "\n", - "If you'd like to learn more about the trax.data API's please checkout the notebook [here](https://www.kaggle.com/sauravmaheshkar/trax-data-explained) where I explain the most common API's in a in-depth manner" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-01T01:34:09.017896Z", - "iopub.status.busy": "2020-12-01T01:34:09.017050Z", - "iopub.status.idle": "2020-12-01T01:35:05.274951Z", - "shell.execute_reply": "2020-12-01T01:35:05.275644Z" - }, - "papermill": { - "duration": 56.298163, - "end_time": "2020-12-01T01:35:05.275849", - "exception": false, - "start_time": "2020-12-01T01:34:08.977686", - "status": "completed" - }, - "tags": [], - "id": "ihYJyhJoj6ex" - }, - "source": [ - "%%capture\n", - "train_stream = trax.data.TFDS('cifar10', keys=('image', 'label'), train=True)()\n", - "eval_stream = trax.data.TFDS('cifar10', keys=('image', 'label'), train=False)()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.031813, - "end_time": "2020-12-01T01:35:05.346382", - "exception": false, - "start_time": "2020-12-01T01:35:05.314569", - "status": "completed" - }, - "tags": [], - "id": "tqEE8bLXj6ex" - }, - "source": [ - "# Batch Generator" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.029693, - "end_time": "2020-12-01T01:35:05.405910", - "exception": false, - "start_time": "2020-12-01T01:35:05.376217", - "status": "completed" - }, - "tags": [], - "id": "3X4Yy6P9j6ex" - }, - "source": [ - "Here, we create pre-processing pipelines, by using the [`Shuffle()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.inputs.Shuffle), [`Batch()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.inputs.Batch) and [`AddLossWeights()`](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.inputs.AddLossWeights) functions from the trax.data API" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-01T01:35:05.475968Z", - "iopub.status.busy": "2020-12-01T01:35:05.474997Z", - "iopub.status.idle": "2020-12-01T01:35:05.477690Z", - "shell.execute_reply": "2020-12-01T01:35:05.478353Z" - }, - "papermill": { - "duration": 0.042864, - "end_time": "2020-12-01T01:35:05.478534", - "exception": false, - "start_time": "2020-12-01T01:35:05.435670", - "status": "completed" - }, - "tags": [], - "id": "BvR6FwLxj6ex" - }, - "source": [ - "train_data_pipeline = trax.data.Serial(\n", - " trax.data.Shuffle(),\n", - " trax.data.Batch(64),\n", - " trax.data.AddLossWeights(),\n", - ")\n", - "\n", - "train_batches_stream = train_data_pipeline(train_stream)\n", - "\n", - "eval_data_pipeline = trax.data.Serial(\n", - " trax.data.Batch(64),\n", - " trax.data.AddLossWeights(),\n", - ")\n", - "\n", - "eval_batches_stream = eval_data_pipeline(eval_stream)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.030008, - "end_time": "2020-12-01T01:35:05.539520", - "exception": false, - "start_time": "2020-12-01T01:35:05.509512", - "status": "completed" - }, - "tags": [], - "id": "ZFSkOQIGj6ex" - }, - "source": [ - "# Model Architecture" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.030691, - "end_time": "2020-12-01T01:35:05.601093", - "exception": false, - "start_time": "2020-12-01T01:35:05.570402", - "status": "completed" - }, - "tags": 
[], - "id": "m3GvLNa1j6ex" - }, - "source": [ - "We use the `WideResnet` architecture defined in `trax.models.resnet` module. By Default the \"widening factor\" is set to 1, thus we experiment with two values, 1 and 2. The Architecture doesn't contain a [`tl.LogSoftmax()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.LogSoftmax) function so we add it to our model using the [`tl.Serial()`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Serial) combinator" - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-01T01:35:05.678498Z", - "iopub.status.busy": "2020-12-01T01:35:05.670632Z", - "iopub.status.idle": "2020-12-01T01:35:05.682035Z", - "shell.execute_reply": "2020-12-01T01:35:05.681344Z" - }, - "papermill": { - "duration": 0.050465, - "end_time": "2020-12-01T01:35:05.682174", - "exception": false, - "start_time": "2020-12-01T01:35:05.631709", - "status": "completed" - }, - "tags": [], - "id": "ZYMPoH0yj6ex" - }, - "source": [ - "thin_model = tl.Serial(\n", - " WideResnet(widen_factor = 1),\n", - " tl.LogSoftmax()\n", - ")\n", - "\n", - "wide_model = tl.Serial(\n", - " WideResnet(widen_factor = 2),\n", - " tl.LogSoftmax()\n", - ")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.030998, - "end_time": "2020-12-01T01:35:05.744169", - "exception": false, - "start_time": "2020-12-01T01:35:05.713171", - "status": "completed" - }, - "tags": [], - "id": "7_6akNEVj6ex" - }, - "source": [ - "When we have our model and the data, we use [`trax.supervised.training`](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#module-trax.supervised.training) to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-01T01:35:05.819134Z", - "iopub.status.busy": "2020-12-01T01:35:05.818314Z", - "iopub.status.idle": "2020-12-01T01:35:06.454470Z", - "shell.execute_reply": "2020-12-01T01:35:06.453640Z" - }, - "papermill": { - "duration": 0.678771, - "end_time": "2020-12-01T01:35:06.454617", - "exception": false, - "start_time": "2020-12-01T01:35:05.775846", - "status": "completed" - }, - "tags": [], - "id": "HPzQ5xJHj6ex" - }, - "source": [ - "train_task = training.TrainTask(\n", - " labeled_data=train_batches_stream,\n", - " loss_layer=tl.CrossEntropyLoss(),\n", - " optimizer=trax.optimizers.Adam(0.01),\n", - " n_steps_per_checkpoint=1000,\n", - ")\n", - "\n", - "eval_task = training.EvalTask(\n", - " labeled_data=eval_batches_stream,\n", - " metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],\n", - " n_eval_batches=20,\n", - ")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-01T01:35:06.524108Z", - "iopub.status.busy": "2020-12-01T01:35:06.522607Z", - "iopub.status.idle": "2020-12-01T02:27:48.982071Z", - "shell.execute_reply": "2020-12-01T02:27:48.981188Z" - }, - "papermill": { - "duration": 3162.496721, - "end_time": "2020-12-01T02:27:48.982225", - "exception": false, - "start_time": "2020-12-01T01:35:06.485504", - "status": "completed" - }, - "tags": [], - "id": "eaj_Y4FPj6ex", - "outputId": "55396574-ad00-4112-f560-06268d7efe21" - }, - "source": [ - "training_loop = training.Loop(thin_model, \n", - " train_task, \n", - " eval_tasks=[eval_task], \n", - " output_dir='./thin_model')\n", - "\n", - "training_loop.run(5000)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "Step 1: Total number of trainable weights: 295866\n", - "Step 1: Ran 1 train steps in 18.14 secs\n", - "Step 1: train CrossEntropyLoss | 2.49250388\n", - "Step 1: eval CrossEntropyLoss | 2.38415594\n", - "Step 1: eval Accuracy | 0.14687500\n", - "\n", - "Step 1000: Ran 999 train steps in 648.60 secs\n", - "Step 1000: train CrossEntropyLoss | 1.56840193\n", - "Step 1000: eval CrossEntropyLoss | 1.32664271\n", - "Step 1000: eval Accuracy | 0.51484375\n", - "\n", - "Step 2000: Ran 1000 train steps in 616.66 secs\n", - "Step 2000: train CrossEntropyLoss | 1.17271507\n", - "Step 2000: eval CrossEntropyLoss | 1.11862110\n", - "Step 2000: eval Accuracy | 0.59843750\n", - "\n", - "Step 3000: Ran 1000 train steps in 612.61 secs\n", - "Step 3000: train CrossEntropyLoss | 1.00170410\n", - "Step 3000: eval CrossEntropyLoss | 0.99056525\n", - "Step 3000: eval Accuracy | 0.63593750\n", - "\n", - "Step 4000: Ran 1000 train steps in 606.28 secs\n", - "Step 4000: train CrossEntropyLoss | 0.89905792\n", - "Step 4000: eval CrossEntropyLoss | 0.88028392\n", - "Step 4000: eval Accuracy | 0.69140625\n", - "\n", - "Step 5000: Ran 1000 train steps in 608.31 secs\n", - "Step 5000: train CrossEntropyLoss | 0.82710254\n", - "Step 5000: eval CrossEntropyLoss | 0.94539436\n", - "Step 5000: eval Accuracy | 0.66640625\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-01T02:27:49.199831Z", - "iopub.status.busy": "2020-12-01T02:27:49.199007Z", - "iopub.status.idle": "2020-12-01T04:22:46.209990Z", - "shell.execute_reply": "2020-12-01T04:22:46.209229Z" - }, - "papermill": { - "duration": 6897.182439, - "end_time": "2020-12-01T04:22:46.210173", 
- "exception": false, - "start_time": "2020-12-01T02:27:49.027734", - "status": "completed" - }, - "tags": [], - "id": "3NvZ7a1Kj6ez", - "outputId": "84ea1d39-0fb6-4892-85fc-d340f562de3c" - }, - "source": [ - "training_loop = training.Loop(wide_model, \n", - " train_task, \n", - " eval_tasks=[eval_task], \n", - " output_dir='./wide_model')\n", - "\n", - "training_loop.run(5000)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "Step 1: Total number of trainable weights: 1167242\n", - "Step 1: Ran 1 train steps in 16.27 secs\n", - "Step 1: train CrossEntropyLoss | 2.34743023\n", - "Step 1: eval CrossEntropyLoss | 2.30782272\n", - "Step 1: eval Accuracy | 0.15703125\n", - "\n", - "Step 1000: Ran 999 train steps in 1385.02 secs\n", - "Step 1000: train CrossEntropyLoss | 1.55927908\n", - "Step 1000: eval CrossEntropyLoss | 1.37345315\n", - "Step 1000: eval Accuracy | 0.48046875\n", - "\n", - "Step 2000: Ran 1000 train steps in 1348.40 secs\n", - "Step 2000: train CrossEntropyLoss | 1.20171380\n", - "Step 2000: eval CrossEntropyLoss | 1.00589269\n", - "Step 2000: eval Accuracy | 0.64375000\n", - "\n", - "Step 3000: Ran 1000 train steps in 1361.62 secs\n", - "Step 3000: train CrossEntropyLoss | 0.98751819\n", - "Step 3000: eval CrossEntropyLoss | 0.92049764\n", - "Step 3000: eval Accuracy | 0.66640625\n", - "\n", - "Step 4000: Ran 1000 train steps in 1355.99 secs\n", - "Step 4000: train CrossEntropyLoss | 0.86016709\n", - "Step 4000: eval CrossEntropyLoss | 0.88372944\n", - "Step 4000: eval Accuracy | 0.69062500\n", - "\n", - "Step 5000: Ran 1000 train steps in 1356.59 secs\n", - "Step 5000: train CrossEntropyLoss | 0.76069421\n", - "Step 5000: eval CrossEntropyLoss | 0.76499336\n", - "Step 5000: eval Accuracy | 0.72968750\n" - ], - "name": "stdout" - } - ] - } - ] -} \ No newline at end of file diff --git a/trax/examples/semantic_segmentation.ipynb b/trax/examples/semantic_segmentation.ipynb deleted file mode 100644 index fdf2d9066..000000000 --- a/trax/examples/semantic_segmentation.ipynb +++ /dev/null @@ -1,836 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Author- [@yashkhasbage25](https://github.com/yashkhasbage25 \"Yash Khasbage\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AZWS_qfMw1B3" - }, - "source": [ - "# Semantic Segmentation \n", - "Semantic Segmentation is a computer vision task that divides an image into segments, identifying what parts of image belong to what object. \n", - "\n", - "In this tutorial, we will train a Convolutional neural network to segment images. \n", - "\n", - "Briefly, we will discuss\n", - "1. downloading an image segmentation dataset from kaggle\n", - "2. 
processing the dataset according to our need\n", - "3. Create a dataloader\n", - "4. Create a custom loss function\n", - "5. Create a TrainTask and an EvalTask\n", - "6. Create a neural network and train it\n", - "\n", - "(You need a Kaggle account to download the dataset.)" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "0AjBi0zHE4pv" - }, - "source": [ - "Assuming that you already have a Kaggle account, we will first create a Kaggle API token. \n", - "If you don't have an API token, follow these steps to create a new one:\n", - "1. Log in to Kaggle and go to the Account section of the website. \n", - "2. Click \"Expire API Token\" and then \"Create New API Token\". A file \"kaggle.json\" will be downloaded. \n", - "3. Upload the kaggle.json file using the \"Choose files\" button. The API token in this file lets us download the dataset directly from Kaggle. " - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 72 - }, - "id": "dzXwMVFPf2qR", - "outputId": "0d776c36-8dbd-4242-e933-0b73abe243b0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Saving kaggle.json to kaggle.json\n" - ] - } - ], - "source": [ - "! pip install -q kaggle\n", - "from google.colab import files\n", - "files.upload()  # upload kaggle.json\n", - "! mkdir ~/.kaggle" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "sskHBrFsM4Yl" - }, - "source": [ - "We need to place kaggle.json at ~/.kaggle and also change its file permissions. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TT61H-y8gg4E", - "outputId": "333b528f-e768-496d-9593-09e4036703c0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Warning: Looks like you're using an outdated API Version, please consider updating (server 1.5.10 / client 1.5.4)\n", - "ref title size lastUpdated downloadCount \n", - "---------------------------------------------------------- ------------------------------------------------ ----- ------------------- ------------- \n", - "gpreda/reddit-vaccine-myths Reddit Vaccine Myths 215KB 2021-03-13 10:04:34 315 \n", - "dhruvildave/wikibooks-dataset Wikibooks Dataset 1GB 2021-02-18 10:08:27 362 \n", - "crowww/a-large-scale-fish-dataset A Large Scale Fish Dataset 3GB 2021-02-17 16:10:44 280 \n", - "imsparsh/musicnet-dataset MusicNet Dataset 22GB 2021-02-18 14:12:19 133 \n", - "fatiimaezzahra/famous-iconic-women Famous Iconic Women 838MB 2021-02-28 14:56:00 93 \n", - "nickuzmenkov/nih-chest-xrays-tfrecords NIH Chest X-rays TFRecords 11GB 2021-03-09 04:49:23 59 \n", - "nickuzmenkov/ranzcr-clip-kfold-tfrecords RANZCR CLiP KFold TFRecords 2GB 2021-02-21 13:29:51 23 \n", - "mathurinache/the-lj-speech-dataset The LJ Speech Dataset 3GB 2021-02-15 09:19:54 26 \n", - "imsparsh/accentdb-core-extended AccentDB - Core & Extended 6GB 2021-02-17 14:22:54 14 \n", - "coloradokb/dandelionimages DandelionImages 4GB 2021-02-19 20:03:47 25 \n", - "stuartjames/lights LightS: Light Specularity Dataset 18GB 2021-02-18 14:32:26 14 \n", - "landrykezebou/lvzhdr-tone-mapping-benchmark-dataset-tmonet LVZ-HDR Tone Mapping Benchmark Dataset (TMO-Net) 24GB 2021-03-01 05:03:40 16 \n", - "shivamb/netflix-shows Netflix Movies and TV Shows 1MB 2021-01-18 16:20:26 109680 \n", - "gpreda/covid-world-vaccination-progress COVID-19 World Vaccination Progress 134KB 2021-03-13 10:03:58 26651 \n", - "arashnic/hr-analytics-job-change-of-data-scientists HR Analytics: Job Change of Data Scientists 295KB 2020-12-07 00:25:10 13389 \n", - "michau96/restaurant-business-rankings-2020 Restaurant Business Rankings 2020 16KB 2021-01-30 14:20:45 7401 \n", - "jsphyg/weather-dataset-rattle-package Rain in Australia 4MB 2020-12-11 10:26:12 37626 \n", - "gpreda/reddit-wallstreetsbets-posts Reddit WallStreetBets Posts 10MB 2021-03-13 10:05:08 2920 \n", - "datasnaek/youtube-new Trending YouTube Video Statistics 201MB 2019-06-03 00:56:47 130671 \n", - "google/tinyquickdraw QuickDraw Sketches 11GB 2018-04-18 19:38:04 2900 \n" - ] - } - ], - "source": [ - "! cp kaggle.json ~/.kaggle/\n", - "! chmod 600 ~/.kaggle/kaggle.json\n", - "! kaggle datasets list" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S867tktZNHqD" - }, - "source": [ - "Now with this command, we actually download the dataset. This may take some time, depending on internet speed. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ZQ_96E2ngwvJ", - "outputId": "0aba3d27-0698-4abd-8057-c9615518e7f2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading cityscapes-image-pairs.zip to /content\n", - " 92% 185M/202M [00:03<00:00, 36.9MB/s]\n", - "100% 202M/202M [00:03<00:00, 53.9MB/s]\n" - ] - } - ], - "source": [ - "! 
kaggle datasets download -d dansbecker/cityscapes-image-pairs" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "OKF2tpAAPHBN" - }, - "source": [ - "The download has to be uncompressed. " - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DQtpMD67hbAO" - }, - "outputs": [], - "source": [ - "! unzip -q cityscapes-image-pairs.zip" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "zKH-76ZJPMeR" - }, - "source": [ - "Install trax.\n" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mh05_t3Phy2h", - "outputId": "21c0cd27-8c13-49d5-b30f-65533d9a8084" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "  Building wheel for sacremoses (setup.py) ... done\n" - ] - } - ], - "source": [ - "! pip install -q -U trax" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6HGMKVu1kfYh" - }, - "outputs": [], - "source": [ - "# several imports from trax\n", - "\n", - "import trax\n", - "import numpy as np\n", - "import trax.layers as tl\n", - "from trax.fastmath import numpy as jnp" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "o-3g4wi1leJy" - }, - "outputs": [], - "source": [ - "# several imports from outside trax\n", - "\n", - "import os\n", - "import os.path as osp\n", - "from PIL import Image\n", - "from itertools import cycle\n", - "from sklearn.cluster import KMeans\n", - "import matplotlib.pyplot as plt\n", - "% matplotlib inline" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZTt5oh_QjdcI" - }, - "outputs": [], - "source": [ - "# let's fix the batch size\n", - "batch_size = 32" - ] - },
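- { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Before processing anything, it is worth confirming the raw layout on disk. A minimal sanity check (the directory comes from the unzipped archive above; the expected shape is explained in the next section):" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# peek at one raw sample; each file stores input (left) and label (right) side by side\n", - "sample_dir = 'cityscapes_data/train'\n", - "sample = np.asarray(Image.open(osp.join(sample_dir, os.listdir(sample_dir)[0])))\n", - "print(sample.shape)  # expected: (256, 512, 3)" - ] - },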
- { - "cell_type": "markdown", - "metadata": { - "id": "KmqJJqs8PnEN" - }, - "source": [ - "Some details of the dataset in its original form: \n", - "The original images have shape 256x512x3. The left half of each image is the input and the right half is the label. In a typical segmentation label, each pixel is allotted a class, so the label is a 2D matrix of class indices. Here we are not given class indices directly; instead, each class is represented by a specific color, so we need to map colors to class labels to get a usable format. \n", - "\n", - "We know that there are 13 classes in total, so the labels contain 13 different colors. To recover class indices from colors, we will use the K-Means utility from sklearn: we fit 13 clusters to the label colors and use the cluster index as the class label.\n", - "\n", - "We do the processing in the following manner:" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tIBSJ3gpkmf9" - }, - "outputs": [], - "source": [ - "def color_kmean(root):\n", - "    \"\"\" creates a k-means object that recognizes all 13 colors of the dataset. \"\"\"\n", - "\n", - "    # take the first 10 images\n", - "    files = os.listdir(root)[:10]\n", - "    colors = list()\n", - "    for f in files:\n", - "        img = load_image(osp.join(root, f))\n", - "        # half of the total width (input is left, label is right)\n", - "        w = img.shape[1] // 2\n", - "        # get the right half of the image, which is the label image\n", - "        img = img[:, w:, :]\n", - "        # collect all the colors present in the label image\n", - "        colors.append(img.reshape(-1, 3))\n", - "\n", - "    colors = np.array(colors)\n", - "    colors = colors.reshape(-1, 3)\n", - "\n", - "    # finally, fit all the colors with KMeans\n", - "    kmeans = KMeans(13)\n", - "    kmeans.fit(colors)\n", - "\n", - "    return kmeans\n", - "\n", - "def load_image(path):\n", - "    \"\"\" load an image as a numpy array. \"\"\"\n", - "\n", - "    assert osp.exists(path), path + \" not found\"\n", - "    image = Image.open(path)\n", - "    image = np.asarray(image)\n", - "    return image\n", - "\n", - "def color2class(segs, km):\n", - "    \"\"\"\n", - "    given a label image, convert it to a class matrix,\n", - "    which is a 2D matrix of class labels (scalars).\n", - "    \"\"\"\n", - "\n", - "    h, w, c = segs.shape\n", - "    segs = segs.reshape((-1, 3))\n", - "    segs = km.predict(segs)\n", - "    segs = segs.reshape((h, w, 1))\n", - "    return segs\n", - "\n", - "def load_dataset(root, km):\n", - "    \"\"\" generator that yields (image, flattened class matrix) pairs forever. \"\"\"\n", - "    index = 0\n", - "    imgs_path = [osp.join(root, f) for f in os.listdir(root)]\n", - "\n", - "    # load images one by one; split each into input and label halves,\n", - "    # convert the label colors to classes, and yield the pair\n", - "    while True:\n", - "        img = load_image(imgs_path[index])\n", - "        w = img.shape[1] // 2\n", - "        img, seg = img[:, :w, :], img[:, w:, :]\n", - "\n", - "        seg = color2class(seg, km)\n", - "\n", - "        seg = seg.reshape(-1)\n", - "        assert img.shape == (256, 256, 3), img.shape\n", - "        assert seg.shape == (256 * 256,), seg.shape\n", - "        yield img, seg\n", - "\n", - "        index = (index + 1) % len(imgs_path)\n" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "udqueyxmA6Pc" - }, - "source": [ - "Uncomment to try another backend. " - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "DJq1biuLxeFa", - "outputId": "f95918ee-413a-4ecb-9982-a34c3d3e6177" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "jax\n" - ] - } - ], - "source": [ - "# trax.fastmath.set_backend('tensorflow-numpy')\n", - "print(trax.fastmath.backend_name())" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "Ce_KBGtTBB50" - }, - "source": [ - "Set the dataset paths, and fit the k-means color mapper." - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HLysKwN0Xy5t" - }, - "outputs": [], - "source": [ - "root = 'cityscapes_data'\n", - "\n", - "trainset_path = osp.join(root, 'train')\n", - "valset_path = osp.join(root, 'val')\n", - "\n", - "km = color_kmean(trainset_path)" - ] - },
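- { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As a side note, the fitted k-means object also gives us the inverse mapping: km.cluster_centers_ holds one mean RGB color per class, so a 2D matrix of predicted class labels can be turned back into a color image for visualization. A minimal sketch (the helper name class2color is ours, not part of any library):" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def class2color(classes, km):\n", - "    \"\"\" map a 2D matrix of class labels back to an RGB image. \"\"\"\n", - "    h, w = classes.shape\n", - "    # each cluster center is the mean RGB color of one class\n", - "    colors = km.cluster_centers_.astype(np.uint8)\n", - "    return colors[classes.reshape(-1)].reshape((h, w, 3))" - ] - },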
- { - "cell_type": "markdown", - "metadata": { - "id": "lex2Tm72BrFf" - }, - "source": [ - "Create the dataset loaders and data transforms." - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ngHMyZBbjfft" - }, - "outputs": [], - "source": [ - "train_dataset = load_dataset(trainset_path, km)\n", - "val_dataset = load_dataset(valset_path, km)\n", - "\n", - "train_transforms = trax.data.Serial(\n", - "    trax.data.Shuffle(),\n", - "    trax.data.Batch(batch_size),\n", - "    # cast images to float32 for the model; labels stay integer\n", - "    lambda g: map(lambda p: (p[0].astype(np.float32), p[1]), g),\n", - ")\n", - "val_transforms = trax.data.Serial(\n", - "    trax.data.Batch(batch_size),\n", - "    lambda g: map(lambda p: (p[0].astype(np.float32), p[1]), g),\n", - ")\n", - "\n", - "train_dataset = train_transforms(train_dataset)\n", - "val_dataset = val_transforms(val_dataset)" - ] - },
- { - "cell_type": "markdown", - "metadata": { - "id": "HURVJcElB9et" - }, - "source": [ - "Create a custom loss. In semantic segmentation we need to apply cross entropy at every pixel of the image. Hence, we reduce the number of dimensions of the tensors so that we can use the ordinary 2D cross entropy, while maintaining the order of the elements. \n", - "\n", - "Concretely, we reshape the 4D network output (batch, height, width, classes) into a 2D array of per-pixel class scores, and the label matrix into a 1D array of per-pixel class labels." - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ZEdJXM9g8rif", - "outputId": "6b78ca76-db43-44c6-b618-435cbd8c8f3e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(32, 256, 256, 3) (32, 65536)\n", - "float32 int32\n" - ] - } - ], - "source": [ - "def CrossEntropy3d(criterion_2d):\n", - "    \"\"\" returns a pixel-wise loss that flattens inputs before applying criterion_2d \"\"\"\n", - "    def _loss_fn(output, target):\n", - "        output = output.reshape(-1, 13)\n", - "        target = target.reshape(-1,)\n", - "        loss = criterion_2d((output, target))\n", - "        return loss\n", - "    return _loss_fn\n", - "\n", - "# check the dataset\n", - "x, y = next(train_dataset)\n", - "print(x.shape, y.shape)\n", - "print(x.dtype, y.dtype)" - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VWmhQZDElSo6" - }, - "outputs": [], - "source": [ - "# set the learning rate\n", - "lr = 1e-2\n", - "\n", - "# wrap the new loss fn in a trax Fn layer, and give it a name\n", - "criterion = trax.layers.base.Fn(\"CrossEntropy3d\",\n", - "                                CrossEntropy3d(tl.CategoryCrossEntropy())\n", - "                                )\n", - "\n", - "# create the TrainTask\n", - "train_task = trax.supervised.training.TrainTask(\n", - "    labeled_data=train_dataset,\n", - "    loss_layer=criterion,\n", - "    optimizer=trax.optimizers.Momentum(lr),\n", - "    n_steps_per_checkpoint=50\n", - ")\n", - "\n", - "# create the EvalTask\n", - "eval_task = trax.supervised.training.EvalTask(\n", - "    labeled_data=val_dataset,\n", - "    metrics=[criterion]\n", - ")" - ] - },
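- { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Besides the loss, it can help to track a human-readable metric during evaluation. The same flattening trick works for accuracy; a minimal sketch, assuming trax's tl.CategoryAccuracy (shipped alongside CategoryCrossEntropy). The run below keeps the loss-only eval_task; pass eval_task_with_acc to the Loop instead if you want per-pixel accuracy reports." - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# reuse the reshape wrapper to report the fraction of correctly classified pixels\n", - "pixel_accuracy = trax.layers.base.Fn(\"PixelAccuracy\",\n", - "                                     CrossEntropy3d(tl.CategoryAccuracy())\n", - "                                     )\n", - "\n", - "eval_task_with_acc = trax.supervised.training.EvalTask(\n", - "    labeled_data=val_dataset,\n", - "    metrics=[criterion, pixel_accuracy]\n", - ")" - ] - },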
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LgWQmYCVoBXU" - }, - "outputs": [], - "source": [ - "model = tl.Serial(\n", - " tl.Conv(13, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(64, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(128, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(64, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(32, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer()),\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - " tl.Conv(13, (3, 3), (1, 1), padding='SAME', kernel_initializer=tl.KaimingNormalInitializer())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6Z5SsOVNE6KJ" - }, - "source": [ - "Crete a training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TE2Rfdafv5xl", - "outputId": "3cc3fc96-f812-470b-d058-b07b7d67f339" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Will not write evaluation metrics, because output_dir is None.\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 1: Total number of trainable weights: 211795\n", - "Step 1: Ran 1 train steps in 58.45 secs\n", - "Step 1: train CrossEntropy3d | 4.64949989\n", - "Step 1: eval CrossEntropy3d | 5.10474443\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 50: Ran 49 train steps in 45.91 secs\n", - "Step 50: train CrossEntropy3d | 2.21896791\n", - "Step 50: eval CrossEntropy3d | 1.99541283\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 100: Ran 50 train steps in 48.97 secs\n", - "Step 100: train CrossEntropy3d | 1.97824812\n", - "Step 100: eval CrossEntropy3d | 1.94622588\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 150: Ran 50 train steps in 48.53 secs\n", - "Step 150: train CrossEntropy3d | 1.96946120\n", - "Step 150: eval CrossEntropy3d | 1.95210052\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 200: Ran 50 train steps in 48.86 secs\n", - "Step 200: train CrossEntropy3d | 1.95508432\n", - "Step 200: eval CrossEntropy3d | 1.93703401\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 250: Ran 50 train steps in 49.42 secs\n", - "Step 250: train CrossEntropy3d | 1.95142782\n", - "Step 250: eval CrossEntropy3d | 1.85742092\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 300: Ran 50 train steps in 49.70 secs\n", - "Step 300: train CrossEntropy3d | 1.94090188\n", - "Step 300: eval CrossEntropy3d | 1.97651076\n", - "Did not save checkpoint as output_dir is None\n", - "\n", - "Step 350: Ran 50 train 
- { - "cell_type": "markdown", - "metadata": { - "id": "F_eQXlgAJQd8" - }, - "source": [ - "Let's look at some examples." - ] - },
- { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CcR_gzqsJUom", - "outputId": "ea1e1457-b4d1-4499-f7da-c791163eb740" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 17, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "
TiuYYw9mtklt3e2aXRatJpNhr0hk/E+F89fREjBcDwgiiLiIMJYQ5KkHO7u0Kw3UUVKrKBMp7hII4VCOIECAiUIfM3167eZjQckqaHRaOCkz+3be+zuHFIkGQ4I45Ag8Hjplddodupo7ciLnKDRwAmBJySxkwgdYo1lPpsSRxHz2RSlFFp69IsBUadNN52Qu4y15RZx6NMbTz722N4/rvHqGWc98XDL90Pg+PvxfoZ3p/fxLvYoGEt0ef8DneqVhmAUkbUFeg5r7RFfffEnvNHfYpRVq0VPWcYtOPf7Q4afW8IfO175Xwb3LlKUoBUv/Z1Drv67S6z82c4D95m/ssb0dPXGRj1DtJs8cM7uNxrMT50kfX8k6Lx976H3P6fer0sfalxrtZqrzVcYqCEt02KohguO1FbTLDYZ+bfJcYSuSSKSxfHCgefqSCSpzLg4P0/TNBEf1wI4gsUijv58mGPHxy0WefTnUXBH3JHIhNBW43183ffe4/jc+z+///7Hxx/VJodDe/sop+iUHfDqXA+u44Sl3skf+yyeBrl/H3hJCHGB6gX5m8DfftwXnLOUJqcsS/r9LnleUq/XmU4nxHHMa6+9hlKK8+fPU6/X2dzc4vLly4BAKc2pUxtYZ7n8zjuc27zIcme1WiNVpgQIh9UScPce7XuNIOeQR8tMJaopQSnF1uYmQhVHyyLB2XNnUVqTpimB57MzGKHOCTztsb+zR7PVpHluiTzNuHXjBsvLy0S1kNLl9KYDpkXC+qqlFjcJrI8QDuEcZV4SehHZ3JKXJUVueffGdba395kMhwhr0FLRbNc4f/4MtcijLDOckKAkTlVLSYFASUGjFjOdThDWoIRACY0pDEEsMCYjK3Myl7G80iRLZwy6PUbp41+WDz22ApwUCPNk/BMy+5izxBOCSDOWbuQUPU28Peew3ORfbGwRfaVH5BdYJ/CkRRrIOxHe3FK7NkaMZyeu4wKPt7vreAYoH+xb/NYu8VuPb8uZ/3Pw+BMAYTff75QPOa6OcTAAA0M94P6lmRWWru4uSCuwAVJIZnJWvQ/CIZ3EeTkKtyC9J4XHkfLjjh0ff79z4B4RH1v397f/vd8/PqZ4+AT7Qfq+Uqyw7W+zF+1Rt3ViGzNTU8T7NPWJk7tzrhRC/EPgXwEK+F+dc499RZ0FKTTNZo3haEAyn1MaQ61WJwgC/vm/+OeY0nD12lXu3r3LP/qH/w1pmvGv/83vY4zhypXL1GtNvvylL9HvHpAlM8J4qXKkispGL6hIWx3xjBNHPrKjNrzXf+gApRS1eo15MsQYQ1kaNjdOo7Um8Hy01ER+WH3XObY2N4/6Y/GUpshyOp0OQS3i7XfeZDodEdcShqMur136IkUpuXtwQFz3iT2H9jVZUXJ4MOLd69vcurPPeDxFS7Au48zpLX7xl75GXAuYTAZADSEkWZ4fxSQsRVZQBh55lpMLS+gLPFkt8KS0GJNSmoSlKKBR95n0esz6IyaThIPu490dH3ZshQVZPBlidwqceHIk8LFgHdHlfSKtmF1ao3ZgiQ9hNF8m/YUxm+0RN1/fonnoyFsaYZ/MM/ioCA+zxx7/sOPqgGbZpC9HeIGhSO/RiMVihFnEIcZqjHYnaSYVKfW8Ti5ycpnjjHuiBP+s8Cza7HBM1ZQ9f49SlIzcCB0YpBHwPuGKp+Jzd859E/jmh/qOlaRpjtISqcDXkl/7tV/jypWrbG9v89WvfhWtNZ3lDt/93ve5euUq1jqcMxweHlKrNRhPBvzmX/sPiSIfR+VGOX7LBJWrVRz9zu7xxMLEf7BNQBzFSFWSpila+5gsp0xS6iurGGPxg6DyqRU5jUYDKSV5niOl5POf/zy+72Ozkk6tTk3DoN8lTQp6ezcYz0fc3tvh3PkzDIqUyTTnoLfDm2/eZmdvwCwt8bwAJTKCKGTz1DIXzmxwemuN6WyE5wfMZyna85hMJwRewHg0IQpC3vjRG8RrHSYzn+k0OZo4h0RxwMsvnWf31m2uX3mbg+27uMRiZECqH+b3+/hj+3EhnMMbW3Rv+ixv+3hoRb7ZxgbieHFI47Zl92LMuzs1Nr9jK1L/lHDWhxpXB9oEIMQJYj9GYAPWi3VGakQiH3QlIcC3Pm3bZrlY/kQT+6PcJg638Lk/rfY7HIlMOPQOMRyt7ITDObBGvO+79TMLqN4PpRR5moBwSAFh4FUPVQpm2ZyoFvCFL77Gq5cuMRunXL35LutbG1y9cZUyL8mzgt7hgN/6zd/i1Kk1bu/expQCz/MAS6fdRjkPrRVKqipCThUll1TZJNWjFAvydzgkUOYF3f0unU4L3/dxFtIkw9oqcOr7HmmaobUkCHysFTgrUUpVQdwyIU9GYFLaS0sM+10OezvowFC6gjwZsnu3pFFfJkklBztDbu90KYoSTytqNc362iqDwSFIqDXqCKGRIkDJECUtWkskDk8ItHN4WlNrNYnqMS+2Wuxsb9NsLHHr+g0Odne5cfk2cehTc5KXN8/QOxxyOExR5WMG6SlCOIcsHNmSQqcPWrn+qMTrzcGeNFVc6GNbNdTBEJ6xdZxvdRi9EBIO77lTVOY49UeViwwc7iieIz5jeyYI6VBhDve9L8odZZlgiW1M3dRZK9a4Edyg63VPXsDBUA9pZk30J4OCHokqeFoSPBCIhW1vm7Zp0zKtE58bzIkg6Ee9r8VyI7hxIqYBLFzKRfL4WMon4slaa6uUx6M0QxmFFEVBvdEgjiIOuwe8c/ltXrjwMl/7xV/kxt2bXL/xLsYZLr5wkf29XV586QW63UOmsxkb621AYUrLYXcP7SmkU6RJekTkVfqgpz2W4gbGGsIgxPM8lFSAWLhppJRcuXyZL37pCwRBiFKaGzfeqbJ4wogsz+j3e6yuLlehW1d9RyuFtXD93Xex5ZiiMKR5He1LsmLOeDpEKk2RZwwHA2q1ZYIwZufuLsaU+IHHUqPG2uoygpI8yzCmwGEQUty39HAgJSiFUAorBE5KGq0mzjl8rfE9nzs3b/LOW2/RP+wym06ZKoUqc9r1OiutDtNpl+n4IVbWs4ADf5BjveChrpes7aFmHnpw0nIvNppc+y8Vr/6PAWKePqvWAiCsIxyaB9qr8s8WkT8MVeDwZD+lk4QuZCIn7Hl7RDaiZVqczc8ykzMSlSwWyQLBSrnyRAOpTwvmvqyYkRrRMZ3FsY1i46Erk0IU7Hg7tMv2h+6fwzGXc24Ft5DuXqzi+GhgA0xhQRQo//F+mU8EuR/nskP1f601UgqS2QxTGibjCXt7+2S55fyFl2m2mvi+xvc19XrEgXRkeUKaJtQadeqNBtYKlFRoXzIcDrCFY2VlFVOWNFstrDH0+332DvcYjAYLcl9dWWU6nnBq4zTeUS7v5tYW+iiI6umA8XjMfD7H930mkwk3btyg3W6SjyfcubNDs9lkc/M0pcm5fecGWmb4YUBhM/a7BwxGPRrNGtJa5vMZjWaLwaBHp11Da0kYalZXVtBaYU2OwGKKktJU/x3ncAoJyMoXbYWklAIjqnTP+WyOVpJht8+/f/1HXHn7bdJ5gjMW
awyT0ZC1doNQ+9VKwIH6GfmGsyVF48cDnLdM2vEfOC6sQ86Lkx9KQbLuE2yrKs/1GcLFIbPN8InFEz59cPhULjzlFAZDIQoKUVRBU+fY9XYJXEDDNBbnec7DOUcmMoZqyKp4VNr9JwMGw7a/TWQjIhuhqPqqUAgEoQsJTfjA947TGjOREbkPnubpcKQi5cA7YKAGlffivrlBeRZnSuzRYnHae/De9+MTQe5CVMUPeZ5y5uwWQeAhhOB73/0uxlr80OONn/yYdmeNlfVNBqMha2vrZFnG/t4+WZpz88ZNLr18CakUWVHgrCQIKms2iiM6zWXKsqQ0BcYaijJnns5ptZfwax6+71PkBdpXNJebKK2qIiel2NjYQOuqkMrzfF577TWCIMAYQxiGrK2tkWU52qsmpcFgwNpaB4Qhyebk2Yh6vYYTlvFkSJZnDIYj1lY3SJKUqFYyGPTo9uYEgWJraw0c5FmKJz3A0VyqcvaLsmCeJEymU0bjMVmeEYQxw9GUgRuTzOYopenu7THo9ti+fovdnR2yLENLReD7TGcpOIkzgv29A1qtZZaWluiN5z+T8Q8HhsmXNxBl5cJ4rzXshMB5J5eg5ak209OKF/7pTpVW+IyhMrtwu3xkCEGnNmdC+8k06hnBAZP4EMYS6eQigOpZj5KyyohB4ttqoj4mRSPNPZeCKB6R2/2zg6OaeEIXLv49lVNGaoTvfApRsGSWFuc/yioXCDaKDQZ6QFR8MHIvRMH14Dp93T+RNmmdXRC8KRSGe/9W3qfAcgcoihxxFIxcXm4TxxGXL18jrMXM5jNsF85efIlpOuflS6+yvX0HnOCHr79OvzdmbXWFwWDE9Zs3WO7M6ffHxHGN6XTIqY0N0nlaZbkEAWWZU5qCMArQvgYVIJWsvFzSkiQpjaiJtYaiyInjGGMKiqLA90KWlpZQR1VjWmvOnj1LlqXg4Pz58wwGfbI8AVGwtFTnYH/MaDKldJYky+gPxlinCfxq4A8O9gnCOkWZs7a+TIlPmuY4G+GMpSwKms0mgR+yt9fl+z/4EaY0DAZ9PN8nmWdEUUxW5HRaTWb9IXeuvsu1y1eZTWZIIfC1h6c9yqLEOUmaW2qNFv29bepLlqXmEnKv+z6j9OHg3je9uoL1BdmSpHVlTrr28PLT8SsN2t3RvX9fjGleL58tsSuJiwKKlfrHJ3YA57Du/QNjnzQ4BxkFEFDIeysq3/koFKlIMcIQuhDpJHVTJ5EJxpn7a4Y+cbBYxmqMM1WGihHViiRRCWM1Jhc5sYlRKGq29sjrSCQ1W+Ou90B5z0PvOVRDtv1txmpcTYrO53x2npqtcTW8ykQ9vP6k1np8FtQngtylEKyvr5NlCUWRk6YpBwcHhEFIa6lJalJaK8usrq+BVjgkYVDHW4kQeCjpEfgxOzt7ZFlGluccHB4iZY/ZbMxyp808m6CUIkkShBQkSYJzjjNbm8yTOUoprLWsLK8ym85ohS3iOEBKRVnYyh+vKreRtZVEQhWwPeqDrGQCgiCg02mTpGMsKY2lBj9+45Bmu05/0CcvDUtLHabTlGazxWw+YT4ZgUgpSouU0FyqEyY5SiqSeYItDUWeE8Q1nNAgPIxzWDykCvG1AAO+8jjcO+CN73yXpD/AJhlaVCXLWmju+XMESV5w2B8Q1WpMkoR0lpIUj39ZPixM+MF/wSvfPsAuPdrKifdP5uDPNiSNm882RlCcbjM/FaLTJ1cy70vD7BNIdI+FE7hhE+R9cQ4HpSjJRIZ0krVibZECWYqS0IbM5XyRfeJw9HSPWl77xPjdJRIrLDeDm7RMi7EaV3UjrpqwIhtxM7hJp+wQ2eiROfEOh3KKrXzrsfezWAZ6wOXwMhaL5zwuZBcIXUjDNHA41ov1R5L7p8ItU2808MOAvMwRyiNJC6azBN+T7O5OeeHSy7z86ueIoiqvOwwjyqJg2B+w0ukwH43pHRxSZinZPEWv6srfrjWtRpM4iGhvbJLME+7u7jBPEsqyIup5UnD1+k0KUyCc4Jd/fpl8Yig6DmMtaZGQpFP6/T5rqxukRYGUgjzPabc7lRaNKXHGEoUhzhWk2YzZfIgTKaNxj8KWFMYxT0vCMEIAna01nJCMpzMcjtl8TrvVYhYUZFmKcQ7f8xESJJbR4SGrniZUCk9KMuvQvk9ndY3pYMSo12N/Z4fB4T5kOYEQlA6kVEilEVJSFAVpmjKdzsjLglu7+4R+RfrGOYqHFNI8CwwvKupXPNL1CKsE8qjo6X73TNb28Jdqi0IgYUEmz9Ydk64FCAvGf/iPWjjHfEVRRoLG9vs/y9FX1jl4XSBfgtreBvFP9550k58aGqZBKlLEUQ04QC5yEBCbmJVyZfG5EYa5nNM2bfqqv7AxDr1DNvPNT1zGTClKbgY3UU6xmW8yUzMslk7ZoRQlQz3kdHH6kd+fyilXwitcSi89dIXijoq3Dr1DbgQ3sEcJ6zVbo1N20GhykfNu8C5TdTKJQDhB4AJSmbzviu8T8VTzoqS1vM7G6bO0Ox2MMbRXJuxu3yBP51y7dpOXXv084/GIVnuZNEvY2b1Lp9Xm7/+9v8f/9Lu/y7vXrlGr1cizyr0jjgSAlKqyX5T08AKH0BqLwIrqDZNSUzrIyhJXWEI/5oXPvUBuDdYVTCYjCjPnnctv43kB2vOo12u8887bvHLpFbT2SJOEfq/H2a2tyt/d3Wd7512iWKG0oLPS4fDwkF5vQKOxxM9/5WuUpaEoHatrG9y6fZ04auAMZGmBroUI6ZinCcYYBt0u/cMuzpRsrK1y49YN0B5hGDMbT9i9s831d97BFiUKi4fDOsjLEuEHWGdJk5TZbMZ8Pq/EjoTCAvO8IqEnLf38QSCcQxhovWsQRUntzT3mr26QrCjKUBAfGpwUGL8i/PdWeD5TSEFel3iPMbOtFvS/bKlf/wD+KClI24LOmzC+KAjvfoJy+N8H0kmWzBKrxWpFNCJlx99hrMZA5T+WrpoAJRLPeSinOJOdwQSGka7ca82y+cjKzZ8FBIL1Yh2DqQhdWO76d7mQXSCTGdfCa5wuTmOEWUxcD0MuchKZsO/tczG7+MDxRCZcC67RKTt8cf5FRmrE9fA6YzVmpEfkIudQH1bP8z43lue8E6uFeufxGWKfCHKPopivf+M/wPOOVA2LkjRN6LRXuXrtCoNxn1rcZGN9HeFKpICVlTXazRZf+OKX+c//i/+K3/md36G9ssKFF1/ACoEVx+pzitIaZlmCsVVp/zxLK7U1rTDGkuc5WZ7hCldl7iiJdkeRauco85yzW1vkaYYSgvlkgi1LRv0BnU6HPE3Z29ul2WwSRSGH/S794Qhv7phOJwShf6ReKeh02vT6XU5tbPLO5St0+/s4V1Lmc5KZJk0KluqKvMiBSoAsTTKiMGI2nJLMU3QY4LQmmc4Y3LzLYP8Qm1biZUIIpOehpGSW9jFJTlEUZFmGc7Yq/Lov1RN+NsR+jGh7hhpOq/L7Iz92vFegUlP53x3kDUHcPekK8aYOYZ7hSsM6osOSMn40Gc02FAjzgWIN2cVVsrYgPjDITDH
4covlP/m0EHwVMDx2p9So0U7a7Hl73PXvIp08QUTCCTSa2MYsmSVGqiL3x/mtnzWOrenjNpeUeM5jrVjDYpnLOXM556Z/k81i85GupEIUKBTL5TKlKE9cHyriNxguZhep2colpZ3mwBwwVVPeCd95IEsGBxezi3TKDhLJzeDmp8dyR0iU36AwBoGgdBLte5zaeona0io3bl4l8hto4UFhEaXDlx6RHyGcII5qSKmp1xvEcY39XiUXYBH4qpLgTJIEJ46kPI8kha2p/OzmiCS0rh7HsYCnNYYsSXDWEHo+RZbhKYnTkmajgS1L+t0uSilObZzCOkt3cMg8mTKeTBgMe+T5nFarxWQ6RkpBWZY4Z7mzfRutA6TQhHFEvbbEzt0BQnhkWU5eFFjrCLVPkZekSQrGkSY5zsK4P2S0s0/ZHSNKi++q3HqpNcYZRtM503lKnr8nhdBVPfyEuDnJl0PCeYY4cglZLXChqkS8xCrpik8wdhx+SeMP1wjePai+6KDoxASDZ0eITj3+oUVdy1A5yvdLkJCCwUsBzXerPoc9RxmIanL7GUsVfFRIJKeKUzRNk1zkJyzM1XK1Cq4iWS/WOfQOSWVl7a+UK/juwfTXZwmHY9vfpq/7XEouYUVlSIQuJJMZvvNZMkt0dZe2aTOTsxMTwf0oKReBUe00iUiqQkkcqUw51Ic44XgpfWnh0kpkssiXd+Lh4+85j9CFi0kIwBSPfx8/EeQuEJSFQwiNdQ6crqxJJYnrbV599ecIggYXz7xMHHh0lpbwfvXXUFpTD2KkFdSjGu9eucZf/OmfsnX2BYSpNNqF0mAcZZqBENi8IJlMKUyJkgpTlmAdge8TqGAhHiYQZGnG7s4uUpVUCbyKLEvQuhrU6u8aYyRCikobfT7DOItUir29Q5SCCxfOEQQe16/f4vTpDQ4PuzQaLcIwZDTxCALNbDZhY2ON3bsDnBNkWUGRF6hIUJZHUgpacPv2HQaTEaN+n3I6IxKKAoM1lcrgbDpjNB4zmyeUpUUcvYCVdX4stVBp4SAqfe2nZbnL/P1nEJWaewVI1qFyixpWFo+eZIhlD5U7Om9DuuIRvFudmqwLgvG9fz8LqMxiQvnITA+VOZo/1eRLDz++gHWEQ8dsQ9G4a0iXBef+7/1PLbEfQyCo2Ro1TlrkTdNcFCwFLqBZNkn9lFSkVb57ufpEgqrHujbSyQ/l7nE4JmrCRE4YqzFjNaZlWpSiZKPYYCqnrBarHOpDYhvjW/+RwVSJxGC4kF1AukqI7Ep4hUxm1EyNuq1z17/LXM6RTjLQA+74d6p00kdAIE5IEB9Phs59CsgdQCEXuuvHgrylVKArbeS9nR6vvSR5+dKrBMJyZnljoRP9pdc+z3//j/87vvWtf8lqs8NXvvAlhsMps9kU6SydRrNSCpOCrJ5QJBlZnqG1Ry2OqdVqlK6A8p6OhEAwmUx4/fXXkcoQhRFS+vi+AmGJohhjShqNqkjDj2Kk52FdATg8P2A2y4gir0qTBMIwYDQakmUZjYajVg9pNmOGwx5KBezubjNPLI24sZAZzvKceqPB4WiK36wzGI4ox1PEaE6sFEgonGE8m2JGY5IkxZYOKRQKRcn9L82DL4MQ4gHRtCcF6z+erJwQ5G0f7yBAJFWmTvT2vaCimCYYr3F0siPauy+bx/FQqYKniWBnTNGoctI/rohZ67t3Kf7KFtu/YWj8VCGy91Xk/NTifuIWCE4Xp7HC0tO9hZX8JDBREy6Hl/GdT83UqNkapajcKzVTTThzVZFqfKTZfuw/P85WOc7sKUTBcrkMQGxjJJJLySVymXMzuEkhClbL1QdkCbTTbBQbFKIgchECQbtss+/t0y7bjPSIkpI9b4+u7i4Kv3DVfZRTLJtlpnJKJu9lHsU2Xjy/JbPEDnffdzr8RJB7mmXsHVTSofV6SJrNKQuDkyCQlAWYouT/+eYfgoOfe+0SwpDYquoAACAASURBVBkODg/otNucWjvNxvI6Z0+d4sKFc3h+zOmVo51irEUIB64qmD69toG1rnIBiUr0abnVIckSsjSjWVsCHEII4rjBuXMvM5n2mE+maL+SNDg43Dva1CNHSl25PpRC+yESR70RMJkO6R7s8+rnLjEaDEnzAqkV8ywhiKoCqCD0QUBhciazhN4wIQqXcVbghCYzhvFwRDZNmeQZapijihK/hFB7WCyzLGW/12c+T7HmSD7VVct7Yy32eGOKYx4U9xmervq7s0/uB3YSgvF5RfNG+UgydAqKtQb+rQfTMEVREvZLBi/5rP5wit69J28bHTj8YfHAd54mRJoT7mek6x9zKyjAeZr4oEQGR4U9n3Kr/cMgtjEvpi+yqlafqARB3dTZKDbo6i773j51W6cUZSULglu4i45dKqENF+Sei2py3fF2cMJRN3V2vB0uZhcJbMBMztj2twltyFqxRi5yJmpCUJ58FxSKwAXc8e7QMA0CFxDbmNPFacZqTCEKjDDserv4zsd3PsIJzuRnWClWPpA0cGADlFMU2adAWyYMfUajIefOvcDq2hJJOmE8mnH3zh3q9WX29kaUxrG3f0jv9/4Pfvrzr9FuLbG9c5etU6f51V/6OuutDp975bUj1TRBteevAKmoyLrEwVHF57FaZIVmXK2jj90TQgikk7Ray/yn/8nfxpiMPJly89pVzr94keF4SF4kpGlKmmaEQcytO7dI05zZaEZzyecHb/wpl148R7NebZRRb7YZJzOSPCOOPPKiJJlnjMcTSmPxw4j6UkiRaQaDKQfjMfMkRyY5Nssg0Nj5mHKWYIRGKMVBr894PCHJChziPtfLEWSlUeCce7S6nXO8rzD0R4WF5bdyirpC2JOVp8eZMsI61Ow+kpYCF4dQmopMr+yzvl9Hzk7mtJexIF318W8+naY/FKVBpiXw8cldZAXx1S6vbGZc2z738dv2TOEwmI+cwigQKNQJnZYnAYXiTH6GU/kp9v19tNNENjraLamgr/uUoiQVKU5UuehjOa5WD0c/m0QlRDZirMasFWvUbA3PeczFnL7uI53kfH6e9WL9kUQskZzJz7Dj7xAVEYfeISM9upev7qBTdlgv18lFTs3UaNjGB57kjvsU1B5v3HwiyD0KA37jr/8ltk6fJ82n9Ptd1lZOI5SmtJp/+s/+L/b3e0hg0J/y777zbc6cXSeIa1y/c5v9wx6/8HNf40uXLlHz1ZGmlrsXPORIv925hd7W8Z6Ox7if2OEotiUEnufjaYmvJcPhEK18ms1lDg52ufTKa/z4xz/hzJkLvPjSJcrSYgvD9777B/zcz32BWrPBwWGP4UgR1mv0JyO2zpxBojlz6iz1Wovvfu87FDajs9xGypKDwz5JrplkKa4sUfMEijlhJDFCUlAVJR0cHNIfT7DWIqXEHfXluP2VL70yCJ+W2+WDwD9MgAgTSFCuyhF3Dm/uEKXDhJL0VIxX8xDGYUNFvqSRhSO+OUJM5sjRg0FTlblnrpNuNtpMzscP7L/6/zdYYclFjnZ6ITfwMCmB99tl6GlAIPDw2Mw3FzowiawI+0J2oWo/lr
7u0y7bDPWQXW+XTtnBCsu2v01sKpfNvrdPy7RomiaRjajZGhM5oad7LBfLj21HJjLueneJbEQucybyHrFvFBuczc8ClX6N7/wTmjUfFNZ+CnzuRVmwt7/H3m6fZqvOu+9e48J5Q63mk5cwn41QylFmOf3hgKWozun1NQ6GI7Z39lnurPFv/+zPGfQG/NKXv0SzWUcdCyceOeatNZU/X4iKDB0cT9fH5HeCBAVIcbSJLwIpJFEc4nkBr3/ve7z+wx/wt/7mfwZO4nsR3//+97lx4yb/8d/4j0iSOaury0hf0h9KDg4OeKHdYWNjg9lsThzGZFnGcivC90Nc6ZiMp9y502c6Badq+FojhcOUKcLktBvLTAvDNB2RuZx5kqG0rgjRnbDXFxlBVX/lAwHTRwVQn8YkYGMPlRjG5zzqOxURII4KgY6SJEwgMKsnrWHrw/SlFnrWQCcGvT9CZPcslbhrn71wl3NPnNiNlR97G8L3g9NPdmUmnVoIYh1vLTeTs4X2ynGwcaRGGGE+lvTtR4WgSr90OApRcNe7y2q5Ss3U8PDIRMZMzYhtzOeSzy1cJdJJ+rpP4AJeSV9hrMbUTb2SHDA1hBMM1IC+7tOwDSIbPbRvqUwpRcnV4Cqe89BO0zANLJb1Yv2BDKFc5AvZ5MfBYOjrPrkoyCaPzzL6RJD7cDDkW9/8A6zRtDsNJpMR3/7zH6JIUcpD6AiHYDoYsHP7JmfWPk8cxlw42+byleuYEs5uRbz+xhtkWcqXv/QqzUaMrz3ioH6k0S6qXcuPoxdHboyHErtzOHG82YdECgVC4HmKsizwPI+LFy+SZjm+H3AsERwEmsl0iOdpsjxlf/cAz4/41V/9VfZ7fVqtFvsHB4xGI2wuUCqi0aiTDxNu3LiOEEuEQYBWPoXNwaY4T5AWlrIo6fcG9Ht9rIXSWoxzOOsqKYL7sl4W1jsn/en3yxgLca//4kggzTzhvHFZVM/SeorWtYys5REeZpQ1/dh88bCbI7OSbDWkaGiKhiaiiXf7nvbNbF1Se1hB56csnTApPWT2dIlv8EoAf/Qkr1gJajkcPa9HJjJWyhWkk3R1d5HZse/tM9ADLqYXaZpm5V/+GeTg+tbnheyFRaIEVNbzsf/d4dj39unrPqeKU2inqZs6gQuQTrLtb3O6OE3TNJnLOYiqb6I4mcWSiIQD72AhlHa8ocdx0LgUJZGNOPAOGNsxqUwJbUhoww/sospFvthDNYg/BW6Z+WzG63/2F4BGyBwoEXg4U+llW2S1Y541UBZceUuRzmcsb6xTZjmJGZFNR9Qbbf78Bz9gbqZsnlqhWauzvrxJp9nGlRlBEFAUx4R8jwDut2QXZH+8qBRUCoo4wihECCjLAmcNrdYSeZ7hnCHNpsQ1n3o9olaLyW3OfJ5hpjllaej1+pUFJQS9Xh/Z1szTKaXNGQ671BsRp9bPsb87Ix0XuDJjNurDdM7e3h6RKTns9o82BReUxz1wR9k94iF9qTp0om/HxH4/uTvnFu6dJwmVOYolH5lbZGoAD5WWlLXHv3ZFXSNqGuvdJz/Q8SnqG8jSEexOiA8sYfdehkm52WHnG3WEhca2pfWd7SfaF6Rg/EL9U6nX3rr2ZDNxrLBcCa+c0DLv6i41WyOyEUM9JLIRhSgoRcmV8AoN0+BMfmaxscVxyuDwaJPtp1WpKhAPld09jhccrzKOC5U6dBaTkHKKgTcgkxkOx4F3QCYzcFWVaV/3WS6XF2R+K7hFV3fvv/ni2oENqJs6B15Vp7FklpjJGU3TZKzGdL0uEsnZ7OzCTQOVa8s7kld2OOZqfpQFV5DOHq+q+Ykgd6zDZbNqL1VZ4MgRKKzzcUJiBbiK4pHOcPP6Dfb2Dlg7vc766Q1GpuDLr73KZD7jJz99i+u33uYLr73E+a1zNOu7tBotPvfSBdpKk+UJhclwThBHdaQQOKpdkxRUm3k47m2u6jgiP4fSHgiI4ohW2ST0fWphBM4RRxFaOTxPoH3NdFzS7w7oD8c0201u7+zTWm7grKXbG9BorDLLU4xwnDl3msBrMBkWpJMRIgXSGaGUjNKUdJYQS4X2fAqZ4lzVTodAuqqpxxnsi+pTByhZlTu4SpxNSllZ+KJyUx1b7EEQEIQBSipGo9HDRugjQZYOmVdWizvSY7GBfl9FRRM+OMk4dU9mN1+vE++m6MGcwde3mK9JRl/OISupbczovrVE8ycRLqhe/rLu49/6+IqXZSjen9zfE9M+/uyZVtO+Bzu/EsAfPLnrCQQXsgvM1Iye7jGVU2Ibcz47T9M0SWSCcoqu7rLtb2OEqdIUo8s0TIOVcoW1Yq0qww+vcTG7yEq58jOx6o8R2WhRzKScInRhtTGGMIv2z+SMXOT4zudsfpZtf5uu7tIpO1wLr9HX/UVao3b6RIVqJjJGakRgA+ZyzliNsVTpoEc14wuFyCXbwDpHbGNaprWIbxz7/o+LmPzo8T7CTwS5OwQGENJiRRUScbgj/2y1Q1NVNg9WaqSVFPOMO9ducLi7y+rGGjdu3ODc+ZcxSU5/OsNdusTGqbP8q2/9EZ32CnvdA7ZOb3Jmq40QCUVasr56nv4w4/b+AOE0p9sRobJoT1BvtIiiGKU0pjSUBlbWz+C0z+aZc2zfvIVwHlr6CCGRzkOYKgdeKEjzEt8PiEKPn16+RYFPYftIY6g1G4StC9Q6Fylu7JPNr9NaXiWoOXZNRpYXNIzjYJ7RHQxRpcWjxPMEpThOgpEkRUGgq1SqzJQIqZFKVZY5cOkLrzHO5ly/fJU4CPF8n9FsgvAUuGorwDiOieMqBpBmT15l0e/NyVdisnZFtOnqR69GzFv3v64aHesqXdaCGmpqtyXjKMTTcOc3V0nWHGa5IL7mc/6ffcyOwELQ7HHwx45w8J4YhxI4Xz+Uuma5R7ZsqwwhAK3Y+fVlyhqsvlEQv7X7wHfSl9YxkaT27x889l64RszWX77Dlfc984NDOEnbtKst5soWAz1Y+JGPC5kANotNVsoVul53IX87VEPGaoyh8sX7zl/s3HRcjv/IvnBvI+2ZnC2yV5bLZQIXfKzJoWEaC0IuRUkiEqZyikJxNjuLw3EmP8Ot4BbaaWZyhuc8rofXue1uk4nsRBlJVWl6JIHiW0wumatqvwTlqsDpguOoDMima5BEI8beBDP3GZgBeZGzWWwykzPu+ndpl+0q+4ecdPwp8LkLB8qVVezTHRWICIF0Rxaoc4udvoU8cj3IynecZTm7Ozv88R/+ES9fOsSWOf1el/3dHcbDEXmasrt7B8OE8XhKr9fm1ZfPIKzlJz/5MZPyDDf3hgz7Y3Zufpvd22/TWvb4r//OP+b8mYt4XnDkxpDUak1Sa8gcXLz0KjKICJccTnu0VlbIUp8ky9je22VnZ4cwqoGEU6ng5u6ANHUMDka8snyRwcAjmeYUyToHe5dxsz3OnVoibkoOh1PsaM7ewSEmmVNTmmZQ48WX17g2fxMHTIVl6Cx1rZkXBq29qrpXstCZ3zp3ht3uHlYYSkpazRaZyyqLvnRkWcZ8PieKIvK80qB5uuP84
EYcHxQP/Z4UNK9OafiK+k5IMEgJhgHDS2B8WHvdkaz46Ln76L54T4O1OK0I+6bK+nkUBJgA3qvQajxB2Qjw+0cHlCR9YZXppk/N32Pekxx8496uRMKCN2Wx6jkBa7n1Gx5uLePSm+/fp9FrbS7Vdj5obz80arZGLa/I/L3kerxb0en8NCvFCgfeAbf921hR7Q16XNo/UAPeid7hQnrhoS6a46ybuZzjWx+PqrDxrn+XkpJMZotMmA+L492PJFWK423/9iIIi4DIRAv//I3gBsKJBUkvmaVFsVHV4SqzQagqBiYAHRiUtpjiXmWzwbBklhZa8VIbnJWMzRSXV+4goQ3WOgZ6QOhCLJY7/h0manLPpaQ/FZt13EtRtICQCgdoUc1wxhjEUUqfOCL/3JZIBwiHsY7ewSE/mnyfRqvDYXeHP//jP+WtH79Fp73KytoKK2sXefvtq7zzlsMVOa2Gz/LyGmUueeGFTfqtBp49S80vaHUitjZPE4QeSuqjatFK0MxhcWlJ3IqxDqI4RknF2sYGb/zkFoc/2eOtd95G4jh37hyerzF3+oyHU/J8xnSYUdse81KryQsXX+b737tBmpb0ihGBTtk4tUHZzYlUjdc2t9i7u48oJKkf8etf+SqnSsvdbpe9dIZxJb6TeLUa+7MpSmnmaVUNayz84Iffw4s9fvFXfoEvfOELXHr1En/yJ3/Cv/yX38STVXaKEIJarUYcx2xvP2E/dXWDhSaLPywXFvyTgAkkZd1HlJbazSmqO6JZruBURL4Eeb26b9QzH43YtaL3yxuLTbDvjwE8CrJ8sHJWZ5ZkI2B+ahOEYLYusR7oueNge5nmsHI7OVUFoVXmcBKkefiP14YOpT5Yf6SBP/r+ax/o3A8M4ZjJGYfeIcIJUpnicLRMi8hGC30Zz3kLd8Nx0dBauUZP9zAYdrydakIQkIqUd6J3WDJLdMoOq+XqIr2yr/top1kySwuiPy4mSmRC3dQ/UjcMhrv+Xfa8PU7np9FoarbGWI0JXUhKSiYy+rq/cCWN1ZjCFczVfEHyAEI6ltZmjPZqxK2MMlPkiUYqS1m8xyAQMFXThXvFWYmzLCYHk0t8GxA7zVzN2fV2q3NFlYF0vMNVrfNp2KxDaT73pa8yGPaZzaYoJVlZXWH71m2yNMNasOJoowxnFxVnztnKZ24Mwjlm0zGz2QyE5XDvgN5BDymvs7K6DKokUG3m0znd3oDuYc5XvvLLrGUaozyKU6v89V96kVBKEIYg0ggUHPnjrS3IMot2lnh5BWHtQjUSaanHIUk6562f/pi7d+8Seh5nTp9h1B/T643AKdI0J4piyrLk4oWzvPraGlffhXkWE6sIZMlgmPONS1+kNsv4i+9+G38yw4kQ6QR/+K1v4qcpSZFz7qWL9C6nqFLw+a99lT/47ndI0pSoUePUqVNceuUVpvMpZ8+f4e//9m8zHA05e/Ys3f6A3//9/7d67rJKk+z3+xRFQVk++QRu6ytMIJG5Q8+KJ0ruTlGJpt3p4eoxZq2JygzhwJKsKrwJYGHwiqbxxkew3EvD0vWEZOPxmyLcaxCEgwcJ2QmBk4LhCxJvWm2irTKqIq6pXqhI5k1B3qx+4LVtgQkeEmSUkvaPJfJvzDn49S1Wv314cnNwT2M69UrNdClApRbskw1WOqCne3R1l1KUi+BfT/cq9UcTE7qQ2MQELmCqpgstleVymVP5Ke76d9Gu0i3PZSWRa7GM9IiRGrHv7XO6OL2QAdBOn8h2kcgT1vp7Vw25yLHYE6sA4QTy6M/xOYf6kFzk7Pg7+M4nlSkWSyqqZ3r/92Mbs5VvcS28RkkJrnLnzNT/R917B0mWXWd+v3ufS+/LV3VV2+meGcz0DMZhgAExMPQEQBILEqGlAZYmqGUopA2SYqxC2oiNpZZ/7MpESMEgVxJ3Sa0IkBS5oAiSIgnvZoDx095Vd1d1+fTuuXuv/niZ2V09PQZAIwSeiI6uzMp69TJfvXPP/c53vq+P1or2VrKDGbY9LFfhpmPiwELFyTHkiJhhMMl7l0mj24odtNBooXHTMeHQToa84goDOZjIIcQm+axfT1zs9vieSO6ZXIF3vO+HKRez7GysMTdT49G3n+Sbz7/Aiy++yOc+9zkGg0HiLSjHQzqJEYXRMRKBhUApnbBqhEDIpMoXRrO3tcPf/eXfszB/EGEZlleqGCN5+dQl7j9+H/VeByk8Um6GlOXA2LVo3JhEE0UB//ef/TFRp4MtJYM4pN3v4Tg2Rw8fIV1M88qrL7G7fQO/16c0Nc3GtTW2d7dptHpsbTfxvBjXlrh2yPxMFiliSsUc166HXN+oI3SER5Hzuz3s3TrBoIvxfWKjqG8rsn6bkmNjZ9Nstpt4lTJzlVnq/S6VmRrGGH7+4x/nB3/gh3Acl0998o/42te/zv/5R59i9fo1fuPXfwM9+rRSqRSWZTEYDOh2u8RxjP5uyBCMm6cGiO/u8Y0QxDkHDtRQaQsjBd7ekPyFFn65gl9Nrl9u7ZbfKwX+4Wm86419vPnXC2HMtwYnjZm2t4VywB4mVf3tek/dFU35tGCwpHjyofM8++UTDGegcs7QemKB/KVkVwKgZkrwY3Xur20S/eIOX3/gOMf/zXUYMZ22n54j3dTJ/IMjvisceoOmb/XxtEdgBVTjKn0rMZ0Z88eHZkjLahGJaJ8jU8/qsRQucSg4xFgp0TFJj61ltdh1dpPFwBpwyboEJlk0FsNFqnF1XxJ/I4x9KIecTZ9F3PJhW1jkVA7bJPLDvvTxpQ8i8X8dyME+rRtb25MkD8kO4rJ3GYPBNS4hIVokJh5jLftQhGgl0EMx+lsQCKmxbE1+MEVP9iZm4pajMEagopv68OEgKX4c7STQEIkiZDWusuVsEcpwspiGgzdO398TyV0DL546x73HDtMf+vzkO5/k4bfdx8OPPMJLL7/CXqPOmTNnmZ2dodVqMzVVQ2tNyrE5d/oUKI0lBForLCm4yRtJvrIthzgwrF1dxciIP/qj68zMLjM7e4Sp6RKFag0pDUaGKCIMDlJbCTxkDLGKCeMAx5M8+/JzhFFIvlrhhVdfplQucXX9KpdXL1KpFchnPZr1OlF/yAYW6VyOVruHMRrXtfAcKBQs9nav4vdC1q/u8OpLF/AsD6N94qCBKpc5MVuh4IDZrjPs6wRncR1i12Gz02Gn26K6vEyqUmbg+0zPzKC15szZM+w19ogjxd/+zd/RbvdoDYb0hgN+67//bTauXSeMFOlUolMfK5VQIqXAfBf44SLWCG0Ss41vIVEKY5CRSSSAb2HXCDPCNTUYCXFaEqddnG5M+kaCaeuMi4yh/zaf9NkUdjBqWqVcrn1kFn9G4zbnOPgne4juLabgjv0aT1ar42Pm3prJcZwWNO4VFC+C17lNf35g6LtAjwSbT42KFEchQ4nQQD5CCk1mQ7D4Z9fRlTydf6bY/Ysqs3+dJPe1DxR4e+1Uckyp0BkFtoV/ZJregpscx4CR353EDklSrcQVcirHprvJcrCcsEHsNl3ZJRIRHbszeX0kIhztkNZpBtaAq95V
mlZzou0yrqQNhqIq0pf9m9XpCMK4mLqI9vVbVpCUJjHovhU6iUiSqmtctsTW5PgAfeu1RjDjc9h0NsmrPAM5IJIhQhosz4dh0lgNZEA5LpPSKXadXQICQJBSKYbWEKMFcWjRtJpYnkLGFkYlhtdjiuTAGjAUCaGhqIqJSqXVBZGc22w0S0mV2JW7k8/rHwRbJvR7dDYuIQ7OU6jlub6zRvhMlwuXL/PHf/onXLl8haA3YPv6JqlUijiOyOTSxIMBtXKJOAoTSGEYIbAwMsRITTqTZXHxKAsLy7z4/NdR0QAVGfTAZ2ftGp/72z+nlOkyOzebLATSYI3YJpZxsdNZhhp26rt4liFXyPCu73s3Fy9fQVmCyvQUA39INxqSLxVZWVmh092jNldjbmaWQd9HGpfsegPhgCMc8qUKlpPi/Lkv8s5Hvo9L577GsBnQCjpksmkG/TbDfpeLVzZJKQuLACfjEEoXp1BlI4zY6g2xUxn6nYBL/lV0FOFaNr3hgMj3OXv6DIsHDhBEik7HJ1fvonXAqedfoN8bgEhc1CMdIxwLrRRaaeI7lZzfYVh7HZysQ5y2CSvphCY2VKiUfF1KpNCG9LaPvdVCFzKEtSwyVNjdAOPaqJSN3fERsb45falMkpgdG9kZUvtiC7+yRKpuKLyUyOmKKMZrGYIa+DMx2nMmzU//2Ax2N9wnTjYO5b6WAtk4YZG/rvc5Mw2mJMWHdmkfTVP6TBa3uz/BF1YVQUliLAjf3UF+s8BTD53j66uHcHspqpUevnIoXktu2mAmy9HyZb7+jjyzf50sTqkn97Bu3ZZbhmihgvPfbNHeqzD1ycy3eom+5TDCsJq5hB0nQz7nU+exjU0gA3y3B1qQiZPKeFwJa8/Hj2I87RGKkLpd38cuwST66TPRDHmV37c4QGLVd827lrBFsN80wUskK8EKkYhY9VaJR6PFSih8/MkAEyQLgTTypsPS6NAKRV4nsMuF1IWJLK9la8LhCF4UEJuYXXsX17jMh/NkdZZtZzvhtxsPX/ooFJGMUKEko9NgBAM5mMwDRCKauFU17MYEY8fc7EcshotU4gq79i4A4fAfQOUOmq3NS3zus212uk3+5JOCdCxQIk7oe506jmUzDLoM/S719jYGjSsEJ9/2NhYW54mNplFvkXLTpDIWjVadp979Xj74oZ/CH/h87rOLrF46R+AnRhqxFszP57m+do3hsI/jOPjBAEY0QqUFMZKFlYMoFZLNZLnn4CH6zSbVap7dbpuh6qORpNNpiCOKeY+ZmTx7e2mKhSJg0e8EVCp10mGKQeiTy+XQsWLQ6zEzPcXjjz3KhXPnULGi2+7ieR656RJmKIn6IZ1+l4WpaRwnR7GYIdrZBVvSbneQ3R6WFEgBlUoFtKHX7VEqlxn2BrSbLbRStJoNPNei024hsCjnC/jxEKM0mWyGTqdDHMVYd3mICcA/PMXGUx7ls5rceoTdC7EaPUwuTTCdwUiB0w4THrwGGcSISGM1OqANstkjdZshhwW0H1tAeYLKl/c3gXXapfVAmcpX1pls3cbVuNLMfK1F9UyK3QfS6LSNJQWD4zPceI/Dkd9vveb8RXdAZidHWLT37TiGc4rSpf0JpnBd4f9pDQ6BvsOdJePE0KO7ZKFPFyivagp2gNrzkl2KgZQVsfaDcOKMR+uQw7wwOOkI4zmIWCHF/gPnqgPWPlDi0fQqV6mgLV4rkWDf5UXbgDIGNWKJDLlJoRUjlshQDidQA4x1UAwH1AItuvSsXvJzY60nBHmVJxBBkryNeA22HIiAy6nLHPYPv2mCH+vKt6020/E0G07CGLJIjKuzKkvdrrPtbuNpb0LrvJK6gi/G7yfZoViWxdhUAwRxeNvFHZ1GSILdH/GPsBQuUbfr1O065bh8y3CTYCj8yd9mx+pQjstUoyrbzjZ9K9kJjKmSjnEmXPs1dy0x/hh9Lv8gKvdMJs3P/syHsZ0Uf/n3XyJWsFSr8egTJ0HA5z/3BXa394jD5M0IY3CkhWNLmq09ajNlsCW5UoadzW0WC3M0mnWmpioszE5hScHXviLZ29tEmhijNZ6bolYtsrO3y+7eLmEU0Wq3cB2bTCZDoVpmaXkFzzXEYUCnOaCZ8/CkIhg20LGPbRuGYYAhhesIBBG97oCNjXX8wKdSnqLb76JVjBSCZr1BsVwinU4zU53ilZde5itf+gqDXh8pJbbtIBGUKyW21/cYRn2GashGfZdUBrA0zV6PGV5SbgAAIABJREFUvh8QRjEmUkgBM7MzRGFMqVim2WqxvbFDajlNHERorQkDn4yXxUZgS4kjJKERaAPdVjuZXB0xke52rL/H5dH3neHF4F7KX7s5SCSCiHRzVDZ/G3BQmJMEFcHtQ9txKcXuw4LKVxLmye3MFdnoYklJ7VSCxxvH5toHLazykPbJqTtOtnqre+gjU8Q5a5LgS2ckCcVhf/QXBI+//xSX/qd7X3vSAmJP0H4oYHmhTndjjg+WX+DK26psnlthudDGEoanHjzHsz9zH8ZKoJdHltY4/SMnKF2JWSqs7jvkgzMb3HhXf1LNO339GrrmW23AfSuRjjOTCVXXuIQixDMeAcFosvs2rr9KRu12vW0OdA+jheZS6tIE0zbCsGvvshwuMxvMEnsxEpkM+4xxcJFMwpKCkioxE828aQVfVEVsY7PpbJLWaabiKaSRDORgQmEcyiFKKKaiKSpxhQ33RtLYHDhsOpvkVA7XuBN7wNcNkWDuZ9NnqcU1+rLPQA7wpY+rXSKT0CvNLbi+lQ7pD/qkdArXuNTtOpaxyOosPatHJG/2hbTQk93G+Pe9UXxPJHdjDO94/CSD/oDnX64hnCxPvP1+3v7YA1y4cIEHTt5PY69BHCk8x8NC0K7X6XU73LixhiUNbsol1hrLlSDB9wN+73d/l1dfPs1UbZqvfe1LNPd2sYTBsmwyeQvHzqFNlzAMR8qKCVIvbQeDwLIERoUcO7jM9to6zWaLjKuZX5gl3tkml0khrYhKMY9nGYa9FhjwvDT9QYimQyFf4ug9HmfPncO2LPa2d0mnMjz52JM0dpPqFA2u6+I6Lqm0R1F6DKRNqVohm5thd6+N3x0y9CS5QoH1tR20Atfx0HFEsVii0+1w5coq9xy7h8uXL9Nvd1HhSKogjEm7ZexCnkF3QNDroaXBKD3REnekhWVZ3E3TOmMJwukYRyrK37dF/6U5sqduGbz5bmD8kZ5gzdpKGpm3h7XXBruMtiUiVsx+WbD5tEPx9GshGQCUJrXWpn+swqg3iNcyIARxGuxh8j6CgmS4EPOVb55g4XZRMwFbj0tWHlnnUW+AKxWNn07YEscKOww+7FJwk2RpCYNzX4fhtcSoxLNiOicDvJZLztlPf3OkYiXfmDweTNn78H5jcXPa+i7GVDzFhrMxMbdYipdI6zR1uz6hPUYiom7XJ6YToQwZxBFDOZyIePnyFqaPgHV3nd3MBn6ssIy1P5mRVPi+9NmSW9Si2luSHXa1y3w0TykuTSR8FSpZnEaXKRQh2842ACmdxh/4k2SdiRKVyC3nTmJGo/M
aMXHGbJZdJ4FOMDAbzpLXeTadzYQ1ZGsiGaFjgRCJ/kwggyShyx5GGKS+xfHrNvjqNc+9TnxPJHcVG/a2d3nH44+wsnSWl85cpu/3uLx6GaU1rVYLx/WQQpFKZSjks4TBAN/vsby0wPzMNKlMmiih0JD20qRSWTZvbLG3vUsuk+fa+g6xPwSjEUIw4xYZBIZ0JkuxlLjrWLaDbdukUilymTxZN4NUhlI+T/H4cS6ev0Qq7TIIIyzL5cTxExgNxVyBtOugVEhv0KPbG2KETb5Y4fixE8T9IUeOHePMmbOcOXuO+ZkFFqbnMZHFxtoGjuVQKVXAQL/fY/PKDcIwZvHAHG5G0O8NiAaaQbuHtCxUpBHSxk2lEaSIlGYw8Bn0h/S6fYQRtOvtZAgMjQTSrkstn6drdxOdG2nQcYxrO6Q8b6KpU+d1Ety3EWEZllaSav1QcY+v/sA0x0+9yQ85NjqTQvYGSRP5Wwx7p4PTTShpb0nBUWnKX13H9ucRndc21cYh/CDhuY9uLr8iyd+ICYoWzkCjLUFmL0b6Ftl1iREK7YjJBC0CDrz9BovZm9DPTLrLV/vHyFkB95a3GN6yEi2WWlxYy08eHzuwzbA1z9VOlcPF5DMNlM2zV1bI5AL67TT2hou2k0WmPy+ICgl+/JHHnuF/eOsf4ZuGNJK8ynNQH+Ry6jJDOZzAKnmVRyInsMJ8OM+6uz7RVMmrPKEMb+LxI6x9nOQ1Gss4pLXLQA7I6ix92Z9g5HmVpxpXE5640CgzwsHfQJvGxmYlWJnw7QUCJRSOcSaerhpNz+olrkdhhiveleSHTdLQHCtbGgzCiJtG37fASqW4RNfqTow/Rt+gbyXTrAM5YC6aQ0UJvx4B4dChoJL+xFiSAJKdzPi97/vskVTiCnvjxeMN4k2TuxDi/wB+FNgxxtw/eq4CfApYAa4CHzXGNEWSIf5n4IeBAfDzxpgX3ux3xLFm2A8w2vBr/9Wv8Tu//x+5vPoirU6dTCZHbWqafndAaGJsx0MLuLG9Sc7zmJmeYtDt0m23SRUKFApFjBK4ThrPSbG7s0uj1caPAoJwiGtZeF4Ky7VRaBwvheclEgJWr49l28mQcywJ+wHNMOCqd5UwUqxv7xCERZqtOkEcEsUhloGOs0McxQz8gGE4ZHurwV6zjbRTrF/fIm0nNM12u0vgh6hYcWN9A6XAtiwK+QJGJwbd/jBg+vBBIhVzbXON6fkCg3iAtjyGnR6u42FihbBsQqVwXJsgjgijCCkt1q6vYQnJsD9EIjBao4Vhc3ODlGUjYrClRavZIPADpLSYnZ9KBMSSSvqoEOLi3biutUxvkogsYTDytmrWkkxE50ePz/6zGeaO7JL9l2Xsnc5boiu+Xkw/30M7b8LxlgKdz6JvNb9+nWlWI0lkhgVEeciudsk/3yVcrrH7cIbiavIzwxnD4l838RcLtA47NE/GzHzJQt4BHlFGYgnNjp8jf1tVfiu0MZPpcLayhKVuvp9dP0fqTBpNmlwAvRVF/e2G+h/+McM/O4tbTnPf7/wiq/0qgCWE+Dvuwj1rRtVmJa7QiBvU7TqXvcvs2rt4xmMxWCRlUggETbvJnrM3SXi+9IlERNNq0rOSfeKt2LwRhjA2ZEwCY5RUKUniKEIREouYht1gIAcTHZixXrvBMBvNMpADCqqAFnoyQZrWaW64Nyayur70J7DSmHM+5uPfCvVYWEgjadmtkWBX4uw05qiPY2wbeKcYSx9robGMRUmVkuQ+ijEF81YaZld2mY6nJ9RNBJNhrfFn8Gb73rdSuf974H8B/uCW534T+Kwx5reFEL85evxfAz8EHB39exz4ndH/bxjaKE488BS2O8t8bZ6P/NiP8Jv/3ZfYur7FgcVFlEpci9KpLIcWV2h2euSLNarlHGfPXCQaxjzxjreTy2RpNLtUqlWE1kgp6PtDXK3RQx9LS6RMLpYOIoJun5mVxUTG17ZoG0klX8Roje2maHWGrF45T/7MeTCG++57gJ2tJhqLVG6aeDDg9KlXGHZabG1u0+n0GfrDhGtvWSweWOKF517FoJJGaTaHH2vuP/kQrX6f6swsjz12nHp9i0YnolScxxY2lvQoT9W4eP0iWs8Q+D3iWBEPFWnLwkEgpSabshFSosIIKWRiHxjHGNsh0BHIhEFiOx7a2PQjjdGJEmQul6FYzFPfa06UItudDkDXGHP0blxXdRseMHdsF13JIxtdTMrl4i/MIBQc+NsA73qDrffP8fCDF0lZMef/uU2nn2fl3xrsjcY+dcvE7EOjHTl5DOx/DSAihbDugElIgXYt0Ibt71+k8c4Qa0vSOnqA3JpGxlD+xibEt3AJtaH4/BbGthCxQluz9A4VKDS6uOsNoqcyCGVI7Ur8aY3oDkifHaCdOcwPdukcqrCS2l+FKSOo2H0eTl/l+nB/96CW6mM/cH3fc87AsH2uxpEn9nCkYhg5HPjrFqs/UUQ82kFeLuAtdHA/eBz3Y/ez+m/+krwXcCy3AzAHfOpu3bN5lZ9UxH2rjy992nYbTMIHd7XLTDTDUA73QS9jTRbPeJMqNRD7F7VABGTIkNIpmlYTg5kwVTzjJV6nJklwR/2j7Dq7k4ZlKEM87U246ONKeiy+NWaajAeixqbaY+W922EgY8UYpanFNTadzUkjdCjGFMtb/uZug08m0r/C0JVdPOOx5WyR07lJn0KaZCdRVEXaVhuNpqAKpEyKqWhqMiOQ1dkJBbJhNTBCE/lvnL7fFI0zxnwJaNz29IeA/zD6+j8AH77l+T8wSTwDlIQQc2/2O5RSrK1vs7R4CCEEK0sHWFpYIZvOk05leO6bL/CNZ7/JubPneM/3PU0Y+Dz/3LPYjqTZbtHq9jj50NsJwpiZmTkefOAhbNshCmP2dutsbmziui7pdBpp2SAEYRgxHAw4cfx+PvDeH+ATP/sLPPnEU/z4B3+CX/nlX+XeYyc4ML9IJp3Bsh0cL4W0bR567F2cfPQpVo6+jSeefC+Hjt5L348IY02sDMlEqySXK3LPsRMsLh7gXU+9mw9/+Cf46E/9NLbj8c1vPs/s3DxahTzyyEn+0U98kMX5GfrdNmkvMdTe3dnFGEEUagI/wh+GiXKlZWPZFrZlUS4WcR2bXrdLFIZAwlf3Ui6O51AsFUmnUuTzhYQFJK1k2EVKUulkiGl0jTHGMEis7Mblx3d8Xa3bbpRqOhERA2g/NM3jT53liadPU/iXa1z/yDytE4a8E+BIxf21TR4/cJW13zBc/9jyROHRpFxWP77Mxg8o+vPJ3XT9Y8vJcM9tEVbTDGZfa4lnUh7r78mgUxbTX62Tf8lDxokeTWdF0j4k6T44c3MAaxyjZB8ulMlsBRjJZPdRPROTvtEjs2movHzrQgRSaha+MGR7cBNmcWXM6b1Znmsv49/BxciRiqnU/g7IsCopnxaEo4nT6UyXOO+x/Fd9LKlxuoKMF1F+cBE7n0zV9kOXepQFKHGX7lmDoW/1GWvHLIaLN4eFREIhHMgB2872JCFNQkAko6RqF5DX+YkRNSbBrm0SWzyJnEzAZlWS3B
bCBQ77h6nFNQIRYJlkknMsdSBGFMOm3WQoh8xEM8xH8/Ss3mRBco1LJa5MBM5u9XEVCMqqnDw24KQj+nZvQluURk4WgnQxWZTEWAPrtpgL5yir8uT1gQzwpc9l7/JkQZmL5lgOlzk2PEZBFcjpHCeGJzjqH6WsyuR1nrzO7/scLSyEkUT+d8dDdcYYM+6MbQEzo68XgLVbXrc+eu418nVCiF8CfgnAchyMTtQfBYJ0OsN99z7ApXOn2XVc/GGI74dEcYKaRVGMJS12t/dwnQxPPPI4X3/mBZ5/4Rk++tGf4srqVYIgRGmNKyQIsFyLKIpw0m5Ce4yHYEG1XOXA0jLZbJ5Cvozn5cjny2TcFN12k36vz85wgO26xFzgqy9eotHuY4CU63BwaYbLV9foNhrJhKcQTM/O8cCDJ3nhxZe5eOkCjz3+KPPziywuLqK14ZlnnuX9H/gAftBldqrKdLXM4ZWDXF99huXZBYIoZPbAIpeuXWFnu86gHxAMI1xpE6kYpTUqDOn3ekRxjD8coFQyvaJUovQupSAMQ4wxk4axlHLi0gQjATZuaryPzDrGOMh3fF2Ltw3/VL0+Vx71KEwt0vzJPoeFxhKGouvTP6C4bS3AEobHF68RzNlcaB4nTgvcjuHhHzqDIxWnZufQn84zmNNsv6PI/Kf3sxka97h47TvAK56NShu0Izn/SxUyGyAjcROjF7D5pEXuQnr/kBOw9qFZhIHFT29iF2pc+JUFjvxhMxnzF8kxOocE05bc1zOweiGDEaQSaYsvPncvxXMW31ip8M4fvvSac7xTFK9GpK934eOjx67PXsUh85lX6V57hNzDTSz5un0K+zu5Z2+9ro7rkNKp0UclEmkAK/n7M1riGneirli365PkJoxgIVxgz9mbVPPCCHpWL9lNoycVbSQi+rJPTuXoWl2MNKR0ipzKIZGJEJm9g2UsloNlalGNVzOvsu1uTyR3felPNOYP+4e57iU7oVCG1EV9ktDHTeGhNcQIM4F1JJLWsI5jkkGmUIYJ3m3vJfzzrgsYKnF1AslYJqnWEclk7XgBuXkR7Em/QRiBox3yKj+RG0jr9Fvi8QOkC2+s0/8d99HNa1yZ3/LP/Z4x5hFjzCPpdIZiqYA2IUlu0VQqFTZvbPHNbzxHrzekUp1hdmGJL3z5KwS+IpcqcW11g+UDR/inv/pf8uij76TRbnNtY5XzV87gZiS5QgppaRYPzHDg8CxH71uhWE1Tnc2TykuEE2FMSKtdJ4x8EAajNSpWWJZk+cAyJ46fwHFcyuUpypVptBYUihWq1WkWFxaZm53noZMPUygWKJWLFMslalNTXLh4iSCMcd00QthcvHiZUrFMNpvHtl1WV6+xu71Nyk2RyxZYXDgwgkwMjz3+CB/84I9RyJfQ2kJKF4wkCEMG/hDbtomiiO3tbaIgwHHsxGVqJJEsRKK7k1gLJsldKbXPoMOyLIRM6Gljs467fV3jXIFA3cbLfvcO5ud2eWLp6r5hnEdOXuLY/XcWLrOlonXCoN7bIs6AHGGTU9ke7eNFqi8LRGwwKReT8pAKdDFHblPRXRLUv29x3/FEEJO/mqgu6kKMNeQ105x2T3D9x2qY/C1DQZbE2JDeMehyju6SjcpphNa0jrqISKE8cB9qsvfUAtc/dgC/YtFs5Aim02ScZN0cxA7Za9ZIb0fy5eZRNvrFN/1st56wka0ua91klxJpCxkZZCaDySgy7lvrT3w71/bW62rbiSbMGCvPquwEPoAE2hi7LxVVcd8EqmtcpqPppAIGOnZnMvE6hkUkcmJN17W6Cbwxwqzh5oIy5n9bWJNq3DY2KZMiq7MsB8u42kUi8YzHVDRFMS4mlba4SREdY97SyEQ+WwQM5TCRRVCSkIi21cYXfpLYRXIWRguy5cTIw8YGkyhFzkVzYBKz7T1n/PokIhntw9aLqnhz14EgEtGbJvaUTiW9gDcRj/t2K/dtIcScMWZztIXbGT1/A1i65XWLo+feMDzPpVTOYQgYBiGRsbh0+TxhFFEpTxHrOkvLB8nlS1y6chUdDikXM1guSCsiV3DJVzyeet9jRLLLxt469508zEp7ljiIeO/7n+ba9jUsy6Lb7VKrVen1+mQzWYZRi288d47HHnkcacX0By26vSzCEtRmakSx4tiRYxw+cpwHHnyYXKGCsGwsy0YKUOGAdz/2IP/6t/8VGsP29i6NZoMw1iwvH0qmZBeWaLXaXLu2xszMLOVSmXPnznHv/Uf4w//4SRxpsbbdpd7scPTgIYZ+n729OouLK5w+cwGMhTECN5Vi7AFbKBRAG9JeCmlb6Jxhc3OTdCqNtARCgmVbk2p9bKMXxzG2baN1whpilOghgXy0jhyAu3Fdo9jiUqvGfdWbFLITle07vjZrh2TtEGUE8Qh2kMKw2qmwcbVG6bLEXC1hDwxfPnuMp05c4NJ2jenY4HUMbg/aJ6dQnqB4WXH+l4rMfUVQvKIpnb5tOElrCtcj/JpD+RuSKDditNwS4yo+mM2TGglzbf7gAmHJIJ9oEXy2TPveOIEhCimC93RYS1dRKXhgaptP/Is/5ev9o/z755+kUBxy4+MOR6RCGcGpG/NUNjRRViA0vLo1R/YvCkz9cg9H3lkzoBd5aAfiGxtsbz/M8fIOl9tVKi9cZ/cn7qM2V2fnYo3akfodG7dAfLfuWUFC+zMkdns9q0esE0gSzCRBLQVL2CT6LEM5JJIRO84O9w3vY8/em1jWxSJpUkojsY3NfDiPJjGrzus8BoNGTxK6NHKSMCMRYTATz9Ou1aVhN5gNZydN3fE5l1UiEWCEoWE1Jvi5wZCP80xFU1xOXZ6cT8fqYGPjGHsywHR7DFoesdWf6L1MR9OkdIq+7N+csr3lcjjGmTBnDIaG3WApXJqwbXpWYl/4Zrr2AHH4xrX5t5vc/wL4OeC3R/9/+pbnf1UI8UmSpkz7lq3g64ZlSc6cfRnPGvClL32BGJfPfu7vOXLkKL/1L3+Lv/3sZ3nxlVdBJM5Cbkrznvc9xNLyLM12n9bgCplSnxMPzBEqxfRMjmqpSilXZG9rB2W6ZHKSOI6o1jIEYZtyNU+v22FzexXbdTh7/kX8YUgu53L23C52Ok+MRhvNfffdz1R1lgNzS4m8t9EIYSWuRm4aT07x+GOP8NKpUyweWGR9fQvtR4RxzKFDhxkMfNrtDv/u3/1vDIdJ5W2ExEllEXaKxx9/kvgbp1jbaNLu9dna3kI4KRqNJnFsMEYCAiEljueRzeVYXlxiY/0GcRQS6RjblhSLeaSUWJZECINlSWzbJo5jpJTEcTyx2Et6j8kfyTjRZ7Jp2q2Jrft3fF2FL2mertF/skHW3r+F/MIrx8E2PHXiwqSCV0bwzLUVcl/Moh2BXzOUzxlmSSQLINFvmf8bm6sLVfJZHzHysBQ6+Sfj5FiLnzP0pyWZvURES82W6S1nKL6QLC7aEgxqEq9tCCp3vpFkBFd+WlJ9dgFhICoIrAFEL5WhAt6uxcqfdxgcyBKsScyywjgaV8ZsxGUu9qeRjqJ7rYg9M+Dcq0t88J3PE
[... remainder of base64-encoded PNG image data elided: the matplotlib output of the cell below, showing the input image, the ground-truth segmentation mask, and the model's predicted mask ...]\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light", - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "x, y = next(val_dataset)\n", - "\n", - "fig, axs = plt.subplots(nrows=1, ncols=3)\n", - "\n", - "x = x[0]\n", - "y = y[0]\n", - "\n", - "y = np.reshape(y, (256, 256))\n", - "axs[0].imshow(x.astype(np.int32))\n", - "axs[1].imshow(y)\n", - "fig.show()\n", - "\n", - "x = np.expand_dims(x, 0)\n", - "y_hat = model(x)\n", - "y_hat = y_hat[0]\n", - "\n", - "y_hat = np.argmax(y_hat, 2)\n", - "y_hat = np.reshape(y_hat, (-1,))\n", - "y_hat = km.cluster_centers_[y_hat]\n", - "y_hat = np.reshape(y_hat, (256, 256, 3))\n", - "y_hat = np.round_(y_hat).astype(np.int32)\n", - "\n", - "axs[2].imshow(y_hat)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Q-TYBBWHk1v6" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "semantic_segmentation.ipynb", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.7" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} diff --git a/trax/examples/trax_data_Explained.ipynb b/trax/examples/trax_data_Explained.ipynb deleted file mode 100644 index 39a7567a9..000000000 --- a/trax/examples/trax_data_Explained.ipynb +++ /dev/null @@ -1,890 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "trax.data Explained", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "authorship_tag": "ABX9TyMN9H/craeNOTmFImALz3Uk", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "6NWA5uxOmBVz" - }, - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." 
- ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "u6IGlnMDLf6M" - }, - "source": [ - "## Install the Latest Version of Trax\n", - "!pip install --upgrade trax" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zOPgYEe2i7Cg" - }, - "source": [ - "Notebook Author: [@SauravMaheshkar](https://github.com/SauravMaheshkar)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jtMr8yxvM2m3" - }, - "source": [ - "# Introduction" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "yD3A2vRGSDwy" - }, - "source": [ - "import trax" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "v5VsWct1QjPz" - }, - "source": [ - "# Serial Fn" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gEa5pT6FQuta" - }, - "source": [ - "In Trax, we use combinators to build input pipelines, much like building deep learning models. The `Serial` combinator applies layers serially using function composition and uses stack semantics to manage data. \n", - "\n", - "Trax has the following definition for a `Serial` combinator.\n", - "\n", - "> ```\n", - "def Serial(*fns):\n", - " def composed_fns(generator=None):\n", - " for f in fastmath.tree_flatten(fns):\n", - " generator = f(generator)\n", - " return generator\n", - " return composed_fns\n", - " ```\n", - "\n", - "The `Serial` function has the following structure:\n", - "\n", - "* It takes an arbitrary number of functions as **input**\n", - "* It flattens that structure into a list\n", - "* It iterates through the list and applies the functions serially\n", - "\n", - "---\n", - "\n", - "The [`fastmath.tree_flatten()`](https://github.com/google/trax/blob/c38a5b1e4c5cfe13d156b3fc0bfdb83554c8f799/trax/fastmath/numpy.py#L195) function takes a tree as input and returns a flattened list. This way we can use various generator functions, like Tokenize and Shuffle, and apply them serially by *iterating* through the list. \n", - "\n", - "Initially, `generator` is set to `None`. Thus, in the first iteration there is no input yet, and the first function in our flattened structure is executed on its own. In each following iteration, `generator` is updated to the output of the previous function, which is then fed to the next function in the list.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1rkCvxscXtvk" - }, - "source": [ - "# Log Function" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oodQFyHDYJHF" - }, - "source": [ - "> ```\n", - "def Log(n_steps_per_example=1, only_shapes=True):\n", - " def log(stream):\n", - " counter = 0\n", - " for example in stream:\n", - " item_to_log = example\n", - " if only_shapes:\n", - " item_to_log = fastmath.nested_map(shapes.signature, example)\n", - " if counter % n_steps_per_example == 0:\n", - " logging.info(str(item_to_log))\n", - " print(item_to_log)\n", - " counter += 1\n", - " yield example\n", - " return log\n", - "\n", - "Every deep learning framework needs a logging component for efficient debugging. \n", - "\n", - "The `trax.data.Log` generator uses the `absl` package for logging. It relies on the [`fastmath.nested_map`](https://github.com/google/trax/blob/c38a5b1e4c5cfe13d156b3fc0bfdb83554c8f799/trax/fastmath/numpy.py#L80) function, which applies a given function recursively to every element of a nested object.
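To make that recursion concrete, here is a minimal pure-Python sketch of the idea behind `nested_map` (an illustrative toy, not Trax's actual implementation, which supports additional container types):

```python
# Toy sketch of nested_map: apply `f` to every leaf of a nested
# structure built from lists, tuples and dicts (an assumption about
# the container types, chosen for illustration only).
def nested_map(f, obj):
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_map(f, x) for x in obj)
    if isinstance(obj, dict):
        return {k: nested_map(f, v) for k, v in obj.items()}
    return f(obj)  # a leaf: apply the function directly

print(nested_map(lambda x: x * 10, ((1, 2), {"a": 3})))
# -> ((10, 20), {'a': 30})
```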
In the case depicted below, the function maps the `shapes.signature` recursively inside the input stream, thus giving us the shapes of the various objects in our stream.\n", - "\n", - "--\n", - "\n", - "The following two cells show the difference between when we set the `only_shapes` variable to `False`" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PqZZAYC4YlIt", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 663 - }, - "outputId": "aa36ceb1-65b1-4c65-83ae-c93d197759b7" - }, - "source": [ - "data_pipeline = trax.data.Serial(\n", - " trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n", - " trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n", - " trax.data.Log(only_shapes=False)\n", - " )\n", - "example = data_pipeline()\n", - "print(next(example))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(array([ 182, 31, 43, 5981, 67, 6322, 243, 3898, 22, 8, 2138,\n", - " 2, 36, 47, 66, 597, 300, 10, 34, 3986, 2613, 64,\n", - " 5281, 2367, 2, 46, 1902, 4713, 2942, 3461, 8, 4797, 55,\n", - " 1466, 1351, 409, 3, 121, 114, 1622, 5622, 66, 124, 4106,\n", - " 47, 1972, 10, 536, 8, 4533, 2, 124, 1466, 3207, 93,\n", - " 449, 90, 407, 4860, 76, 114, 3898, 22, 36, 6, 2339,\n", - " 5160, 275, 2395, 6293, 181, 8, 182, 3898, 22, 25, 43,\n", - " 402, 4423, 794, 995, 3040, 2420, 2128, 2, 5116, 2, 8,\n", - " 28, 180, 3166, 3171, 3839, 44, 80, 668, 232, 4, 1743,\n", - " 3661, 239, 3082, 4076, 80, 2067, 124, 2700, 35, 3854, 1052,\n", - " 221, 8, 6149, 5481, 4607, 12, 547, 2942, 75, 4445, 3054,\n", - " 29, 3, 7, 245, 5372, 1135, 75, 14, 3304, 2, 4935,\n", - " 1197, 39, 5281, 2367, 2, 31, 5032, 2528, 121, 12, 3166,\n", - " 3171, 5888, 5403, 2, 2305, 93, 10, 12, 3898, 22, 37,\n", - " 31, 3060, 2558, 2, 5, 345, 2715, 2213, 8, 139, 907,\n", - " 2133, 1051, 2390, 200, 37, 266, 55, 3898, 44, 461, 114,\n", - " 3, 4269, 1264, 617, 36, 6, 461, 3986, 2613, 64, 5281,\n", - " 2367, 2, 36, 6, 2730, 177, 8, 139, 449, 1120, 839,\n", - " 4198, 2, 340, 71, 21]), 0)\n", - "(array([ 182, 31, 43, 5981, 67, 6322, 243, 3898, 22, 8, 2138,\n", - " 2, 36, 47, 66, 597, 300, 10, 34, 3986, 2613, 64,\n", - " 5281, 2367, 2, 46, 1902, 4713, 2942, 3461, 8, 4797, 55,\n", - " 1466, 1351, 409, 3, 121, 114, 1622, 5622, 66, 124, 4106,\n", - " 47, 1972, 10, 536, 8, 4533, 2, 124, 1466, 3207, 93,\n", - " 449, 90, 407, 4860, 76, 114, 3898, 22, 36, 6, 2339,\n", - " 5160, 275, 2395, 6293, 181, 8, 182, 3898, 22, 25, 43,\n", - " 402, 4423, 794, 995, 3040, 2420, 2128, 2, 5116, 2, 8,\n", - " 28, 180, 3166, 3171, 3839, 44, 80, 668, 232, 4, 1743,\n", - " 3661, 239, 3082, 4076, 80, 2067, 124, 2700, 35, 3854, 1052,\n", - " 221, 8, 6149, 5481, 4607, 12, 547, 2942, 75, 4445, 3054,\n", - " 29, 3, 7, 245, 5372, 1135, 75, 14, 3304, 2, 4935,\n", - " 1197, 39, 5281, 2367, 2, 31, 5032, 2528, 121, 12, 3166,\n", - " 3171, 5888, 5403, 2, 2305, 93, 10, 12, 3898, 22, 37,\n", - " 31, 3060, 2558, 2, 5, 345, 2715, 2213, 8, 139, 907,\n", - " 2133, 1051, 2390, 200, 37, 266, 55, 3898, 44, 461, 114,\n", - " 3, 4269, 1264, 617, 36, 6, 461, 3986, 2613, 64, 5281,\n", - " 2367, 2, 36, 6, 2730, 177, 8, 139, 449, 1120, 839,\n", - " 4198, 2, 340, 71, 21]), 0)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "uyqL-JMCaGn0", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 357 - }, - "outputId": "dfd51b28-159c-41b7-ba2a-39e95b1e3964" - }, - "source": [ - "data_pipeline = trax.data.Serial(\n", - " 
trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n", - " trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n", - " trax.data.Log(only_shapes=True)\n", - " )\n", - "example = data_pipeline()\n", - "print(next(example))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(ShapeDtype{shape:(203,), dtype:int64}, ShapeDtype{shape:(), dtype:int64})\n", - "(array([ 182, 31, 43, 5981, 67, 6322, 243, 3898, 22, 8, 2138,\n", - " 2, 36, 47, 66, 597, 300, 10, 34, 3986, 2613, 64,\n", - " 5281, 2367, 2, 46, 1902, 4713, 2942, 3461, 8, 4797, 55,\n", - " 1466, 1351, 409, 3, 121, 114, 1622, 5622, 66, 124, 4106,\n", - " 47, 1972, 10, 536, 8, 4533, 2, 124, 1466, 3207, 93,\n", - " 449, 90, 407, 4860, 76, 114, 3898, 22, 36, 6, 2339,\n", - " 5160, 275, 2395, 6293, 181, 8, 182, 3898, 22, 25, 43,\n", - " 402, 4423, 794, 995, 3040, 2420, 2128, 2, 5116, 2, 8,\n", - " 28, 180, 3166, 3171, 3839, 44, 80, 668, 232, 4, 1743,\n", - " 3661, 239, 3082, 4076, 80, 2067, 124, 2700, 35, 3854, 1052,\n", - " 221, 8, 6149, 5481, 4607, 12, 547, 2942, 75, 4445, 3054,\n", - " 29, 3, 7, 245, 5372, 1135, 75, 14, 3304, 2, 4935,\n", - " 1197, 39, 5281, 2367, 2, 31, 5032, 2528, 121, 12, 3166,\n", - " 3171, 5888, 5403, 2, 2305, 93, 10, 12, 3898, 22, 37,\n", - " 31, 3060, 2558, 2, 5, 345, 2715, 2213, 8, 139, 907,\n", - " 2133, 1051, 2390, 200, 37, 266, 55, 3898, 44, 461, 114,\n", - " 3, 4269, 1264, 617, 36, 6, 461, 3986, 2613, 64, 5281,\n", - " 2367, 2, 36, 6, 2730, 177, 8, 139, 449, 1120, 839,\n", - " 4198, 2, 340, 71, 21]), 0)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Wy8L-e9qcRY4" - }, - "source": [ - "# Shuffling our datasets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-cfg48KgcrlM" - }, - "source": [ - "Trax offers two generator functions to add shuffle functionality in our input pipelines. \n", - "\n", - "1. The `shuffle` function shuffles a given stream\n", - "2. The `Shuffle` function returns a shuffle function instead" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4iD21oiycWf4" - }, - "source": [ - "## `shuffle`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bVgN1yYAcaKM" - }, - "source": [ - "> ```\n", - "def shuffle(samples, queue_size):\n", - " if queue_size < 1:\n", - " raise ValueError(f'Arg queue_size ({queue_size}) is less than 1.')\n", - " if queue_size == 1:\n", - " logging.warning('Queue size of 1 results in no shuffling.')\n", - " queue = []\n", - " try:\n", - " queue.append(next(samples))\n", - " i = np.random.randint(queue_size)\n", - " yield queue[i]\n", - " queue[i] = sample\n", - " except StopIteration:\n", - " logging.warning(\n", - " 'Not enough samples (%d) to fill initial queue (size %d).',\n", - " len(queue), queue_size)\n", - " np.random.shuffle(queue)\n", - " for sample in queue:\n", - " yield sample\n", - "\n", - "\n", - "The `shuffle` function takes two inputs, the data stream and the queue size (minimum number of samples within which the shuffling takes place). 
Apart from the usual warnings, for negative and unity queue sizes, this generator function shuffles the given stream using [`np.random.randint()`](https://docs.python.org/3/library/random.html#random.randint) by randomly picks out integers using the `queue_size` as a range and then shuffle this new stream again using the [`np.random.shuffle()`](https://docs.python.org/3/library/random.html#random.shuffle)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "-kdz2fNIfn2l", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 105 - }, - "outputId": "110aa969-dab0-4e7a-e75f-41a6ab2fe0c4" - }, - "source": [ - "sentence = ['Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem. Ut enim ad minima veniam, quis nostrum exercitationem ullam corporis suscipit laboriosam, nisi ut aliquid ex ea commodi consequatur? Quis autem vel eum iure reprehenderit qui in ea voluptate velit esse quam nihil molestiae consequatur, vel illum qui dolorem eum fugiat quo voluptas nulla pariatur?',\n", - " 'But I must explain to you how all this mistaken idea of denouncing pleasure and praising pain was born and I will give you a complete account of the system, and expound the actual teachings of the great explorer of the truth, the master-builder of human happiness. No one rejects, dislikes, or avoids pleasure itself, because it is pleasure, but because those who do not know how to pursue pleasure rationally encounter consequences that are extremely painful. Nor again is there anyone who loves or pursues or desires to obtain pain of itself, because it is pain, but because occasionally circumstances occur in which toil and pain can procure him some great pleasure. To take a trivial example, which of us ever undertakes laborious physical exercise, except to obtain some advantage from it? But who has any right to find fault with a man who chooses to enjoy a pleasure that has no annoying consequences, or one who avoids a pain that produces no resultant pleasure?',\n", - " 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum',\n", - " 'At vero eos et accusamus et iusto odio dignissimos ducimus qui blanditiis praesentium voluptatum deleniti atque corrupti quos dolores et quas molestias excepturi sint occaecati cupiditate non provident, similique sunt in culpa qui officia deserunt mollitia animi, id est laborum et dolorum fuga. Et harum quidem rerum facilis est et expedita distinctio. Nam libero tempore, cum soluta nobis est eligendi optio cumque nihil impedit quo minus id quod maxime placeat facere possimus, omnis voluptas assumenda est, omnis dolor repellendus. 
Temporibus autem quibusdam et aut officiis debitis aut rerum necessitatibus saepe eveniet ut et voluptates repudiandae sint et molestiae non recusandae. Itaque earum rerum hic tenetur a sapiente delectus, ut aut reiciendis voluptatibus maiores alias consequatur aut perferendis doloribus asperiores repellat.']\n", - "\n", - "def sample_generator(x):\n", - " for i in x:\n", - " yield i\n", - "\n", - "example_shuffle = list(trax.data.inputs.shuffle(sample_generator(sentence), queue_size = 2))\n", - "example_shuffle" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "['Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem. Ut enim ad minima veniam, quis nostrum exercitationem ullam corporis suscipit laboriosam, nisi ut aliquid ex ea commodi consequatur? Quis autem vel eum iure reprehenderit qui in ea voluptate velit esse quam nihil molestiae consequatur, vel illum qui dolorem eum fugiat quo voluptas nulla pariatur?',\n", - " 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum',\n", - " 'But I must explain to you how all this mistaken idea of denouncing pleasure and praising pain was born and I will give you a complete account of the system, and expound the actual teachings of the great explorer of the truth, the master-builder of human happiness. No one rejects, dislikes, or avoids pleasure itself, because it is pleasure, but because those who do not know how to pursue pleasure rationally encounter consequences that are extremely painful. Nor again is there anyone who loves or pursues or desires to obtain pain of itself, because it is pain, but because occasionally circumstances occur in which toil and pain can procure him some great pleasure. To take a trivial example, which of us ever undertakes laborious physical exercise, except to obtain some advantage from it? But who has any right to find fault with a man who chooses to enjoy a pleasure that has no annoying consequences, or one who avoids a pain that produces no resultant pleasure?',\n", - " 'At vero eos et accusamus et iusto odio dignissimos ducimus qui blanditiis praesentium voluptatum deleniti atque corrupti quos dolores et quas molestias excepturi sint occaecati cupiditate non provident, similique sunt in culpa qui officia deserunt mollitia animi, id est laborum et dolorum fuga. Et harum quidem rerum facilis est et expedita distinctio. Nam libero tempore, cum soluta nobis est eligendi optio cumque nihil impedit quo minus id quod maxime placeat facere possimus, omnis voluptas assumenda est, omnis dolor repellendus. 
Temporibus autem quibusdam et aut officiis debitis aut rerum necessitatibus saepe eveniet ut et voluptates repudiandae sint et molestiae non recusandae. Itaque earum rerum hic tenetur a sapiente delectus, ut aut reiciendis voluptatibus maiores alias consequatur aut perferendis doloribus asperiores repellat.']" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 5 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "k-kTDkF-e7Vn" - }, - "source": [ - "## `Shuffle`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I5Djvqw2e9Jg" - }, - "source": [ - "> ```\n", - "def Shuffle(queue_size=1024): \n", - " return lambda g: shuffle(g, queue_size)\n", - "\n", - "This function returns the aforementioned `shuffle` function and is mostly used in input pipelines.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AA-Z4Sipkq98" - }, - "source": [ - "# Batch Generators" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yzwONDulksbd" - }, - "source": [ - "## `batch`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-DCABkndkudF" - }, - "source": [ - "This function, creates batches for the input generator function.\n", - "\n", - "> ```\n", - "def batch(generator, batch_size):\n", - " if batch_size <= 0:\n", - " raise ValueError(f'Batch size must be positive, but is {batch_size}.')\n", - " buf = []\n", - " for example in generator:\n", - " buf.append(example) \n", - " if len(buf) == batch_size:\n", - " batched_example = tuple(np.stack(x) for x in zip(*buf))\n", - " yield batched_example\n", - " buf = []\n", - "\n", - "It keeps adding objects from the generator into a list until the size becomes equal to the `batch_size` and then creates batches using the `np.stack()` function.\n", - "\n", - "It also raises an error for non-positive batch_sizes.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BZMKY6VUpD3M" - }, - "source": [ - "## `Batch`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g6pYJHgOpIG4" - }, - "source": [ - "> ```\n", - " def Batch(batch_size): \n", - " return lambda g: batch(g, batch_size)\n", - "\n", - "This Function returns the aforementioned `batch` function with given batch size." 
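As a quick usage sketch (the toy `xy_stream` generator below is hypothetical, invented for illustration; `trax.data.inputs.batch` is the generator-level function quoted above):

```python
import numpy as np
import trax

def xy_stream():  # hypothetical toy stream of (features, label) pairs
    for i in range(10):
        yield (np.arange(3) + i, np.array(i % 2))

batched = trax.data.inputs.batch(xy_stream(), batch_size=4)
x, y = next(batched)
print(x.shape, y.shape)  # (4, 3) (4,) -- each field stacked across 4 examples
```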
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cmQzaXw9vrbW" - }, - "source": [ - "# Pad to Maximum Dimensions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iL3MuKQIvt-Q" - }, - "source": [ - "This function is used to pad a tuple of tensors to a joint dimension and return their batch.\n", - "\n", - "For example, in this case a pair of tensors (1,2) and ( (3,4) , (5,6) ) is changed to (1,2,0) and ( (3,4) , (5,6) , 0)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lvbBDuq4p4qW", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "outputId": "ed69c541-3219-4a23-cf73-4568e3e2882f" - }, - "source": [ - "import numpy as np\n", - "\n", - "tensors = np.array([(1.,2.),\n", - " ((3.,4.),(5.,6.))])\n", - "padded_tensors = trax.data.inputs.pad_to_max_dims(tensors=tensors, boundary=3)\n", - "padded_tensors" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "array([[1.0, 2.0, 0],\n", - " [(3.0, 4.0), (5.0, 6.0), 0]], dtype=object)" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 6 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PDQQYCdLOkl1" - }, - "source": [ - "# Creating Buckets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RjGD3YKJWj58" - }, - "source": [ - "When training Recurrent Neural Networks on data with a large vocabulary, a method called Bucketing is usually applied. \n", - "\n", - "The usual technique of padding ensures that all sequences within a mini-batch are of the same length. But this reduces the inter-batch variability and intuitively puts similar sentences into the same batch, therefore reducing the overall robustness of the system. \n", - "\n", - "Thus, we use Bucketing: multiple buckets are created depending on sentence length, and each sequence is assigned to the bucket that corresponds to its length. We need to ensure that the bucket sizes are large enough to add some variability to the system." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "17z3ASA-OrSF" - }, - "source": [ - "## `bucket_by_length`\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rf5trhANYpy5" - }, - "source": [ - "> ```\n", - "def bucket_by_length(generator, length_fn, boundaries, batch_sizes,strict_pad_on_len=False):\n", - " buckets = [[] for _ in range(len(batch_sizes))]\n", - " boundaries = boundaries + [math.inf] \n", - " for example in generator:\n", - " length = length_fn(example)\n", - " bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b])\n", - " buckets[bucket_idx].append(example)\n", - " if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]:\n", - " batched = zip(*buckets[bucket_idx])\n", - " boundary = boundaries[bucket_idx]\n", - " boundary = None if boundary == math.inf else boundary\n", - " padded_batch = tuple(\n", - " pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched)\n", - " yield padded_batch\n", - " buckets[bucket_idx] = []\n", - "\n", - "---\n", - "\n", - "This function can be summarised as:\n", - "\n", - "* Create one bucket per entry in `batch_sizes`, with `boundaries` defining the length range of each bucket\n", - "\n", - "* Assign each example to the first bucket whose boundary is at least the example's length\n", - "\n", - "* Once a bucket holds its batch size worth of examples, pad them with `pad_to_max_dims` and yield the batch\n", - "\n", - "---\n", - "\n", - "### Parameters\n", - "\n", - "1. **generator:** The input generator function\n", - "2. 
**length_fn:** A custom length function for determining the length of an example, not necessarily `len()`\n", - "3. **boundaries:** A python list containing the corresponding bucket boundaries\n", - "4. **batch_sizes:** A python list containing batch sizes\n", - "5. **strict_pad_on_len:** A python boolean variable (`True` or `False`). If set to true then the function pads on the length dimension, so that dim[0] is strictly a multiple of the boundary.\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "c0uQZaaPVyF_" - }, - "source": [ - "## `BucketByLength`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qhh21q71aX3l" - }, - "source": [ - "> ```\n", - "def BucketByLength(boundaries, batch_sizes,length_keys=None, length_axis=0, strict_pad_on_len=False):\n", - " length_keys = length_keys or [0, 1]\n", - " length_fn = lambda x: _length_fn(x, length_axis, length_keys)\n", - " return lambda g: bucket_by_length(g, length_fn, boundaries, batch_sizes, strict_pad_on_len)\n", - "\n", - "---\n", - "\n", - "This function is usually used inside input pipelines (*combinators*) and wraps the aforementioned `bucket_by_length`. It applies a predefined `length_fn` which chooses the maximum shape on `length_axis` over `length_keys`.\n", - "\n", - "Its use is illustrated below" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PFeqDQNsV0PV", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "outputId": "ab9139c1-de56-4570-bcb6-731c1b475b12" - }, - "source": [ - "data_pipeline = trax.data.Serial(\n", - " trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n", - " trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n", - " trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],\n", - " batch_sizes=[512, 128, 32, 8, 1],\n", - " length_keys=[0]),\n", - " trax.data.Log(only_shapes=True)\n", - " )\n", - "example = data_pipeline()\n", - "print(next(example))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(ShapeDtype{shape:(8, 2048), dtype:int64}, ShapeDtype{shape:(8,), dtype:int64})\n", - "(array([[ 155, 452, 29, ..., 0, 0, 0],\n", - " [ 182, 1989, 1826, ..., 0, 0, 0],\n", - " [1389, 2597, 5378, ..., 0, 0, 0],\n", - " ...,\n", - " [4846, 1008, 2, ..., 0, 0, 0],\n", - " [ 68, 12, 173, ..., 0, 0, 0],\n", - " [ 186, 3817, 2064, ..., 0, 0, 0]]), array([0, 1, 1, 1, 1, 0, 1, 0]))\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9D0YdAT_ceSN" - }, - "source": [ - "# Filter by Length" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YLvi4Wu-eFAF" - }, - "source": [ - "> ```\n", - "def FilterByLength(max_length,length_keys=None, length_axis=0):\n", - " length_keys = length_keys or [0, 1]\n", - " length_fn = lambda x: _length_fn(x, length_axis, length_keys)\n", - " def filtered(gen):\n", - " for example in gen:\n", - " if length_fn(example) <= max_length:\n", - " yield example\n", - " return filtered\n", - "\n", - "---\n", - "\n", - "This function uses the same predefined `length_fn` to keep only those examples whose length is at most the given `max_length` parameter.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qyueQ1z-cg2p", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "outputId": "da007ab0-e719-4044-e6a4-6bba5f43131e" - }, - "source": [ - "Filtered = trax.data.Serial(\n", - " trax.data.TFDS('imdb_reviews', keys=('text', 'label'),
train=True),\n", - " trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n", - " trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],\n", - " batch_sizes=[512, 128, 32, 8, 1],\n", - " length_keys=[0]),\n", - " trax.data.FilterByLength(max_length=2048, length_keys=[0]),\n", - " trax.data.Log(only_shapes=True)\n", - " )\n", - "filtered_example = Filtered()\n", - "print(next(filtered_example))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(ShapeDtype{shape:(8, 2048), dtype:int64}, ShapeDtype{shape:(8,), dtype:int64})\n", - "(array([[ 155, 452, 29, ..., 0, 0, 0],\n", - " [ 182, 1989, 1826, ..., 0, 0, 0],\n", - " [1389, 2597, 5378, ..., 0, 0, 0],\n", - " ...,\n", - " [4846, 1008, 2, ..., 0, 0, 0],\n", - " [ 68, 12, 173, ..., 0, 0, 0],\n", - " [ 186, 3817, 2064, ..., 0, 0, 0]]), array([0, 1, 1, 1, 1, 0, 1, 0]))\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1XRrJSsUeZX-" - }, - "source": [ - "# Adding Loss Weights" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P3ySYhnpejy4" - }, - "source": [ - "## `add_loss_weights`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QgaXAlhgeuQv" - }, - "source": [ - "> ```\n", - "def add_loss_weights(generator, id_to_mask=None):\n", - " for example in generator:\n", - " if len(example) > 3 or len(example) < 2:\n", - " assert id_to_mask is None, 'Cannot automatically mask this stream.'\n", - " yield example\n", - " else:\n", - " if len(example) == 2:\n", - " weights = np.ones_like(example[1]).astype(np.float32)\n", - " else:\n", - " weights = example[2].astype(np.float32)\n", - " mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32)\n", - " weights *= mask\n", - " yield (example[0], example[1], weights)\n", - "\n", - "---\n", - "\n", - "This function essentially adds a loss mask (tensor of ones of the same shape) to the input stream. \n", - "\n", - "**Masking** is essentially a way to tell sequence-processing layers that certain timesteps in an input are missing, and thus should be skipped when processing the data.\n", - "\n", - "Thus, it adds 'weights' to the system. \n", - "\n", - "---\n", - "\n", - "### Parameters\n", - "\n", - "1. **generator:** The input data generator\n", - "2. **id_to_mask:** The value with which to mask. Can be used as `` in NLP." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hZPWc6a9hk_u" - }, - "source": [ - "```\n", - "\n", - "train_generator = trax.data.inputs.add_loss_weights(\n", - " data_generator(batch_size, x_train, y_train,vocab[''], True),\n", - " id_to_mask=vocab[''])\n", - "\n", - "\n", - "```\n", - "\n", - "For example, in this case I used the `add_loss_weights` function to add padding while implementing Named Entity Recogntion using the Reformer Architecture. You can read more about the project [here](https://www.kaggle.com/sauravmaheshkar/trax-ner-using-reformer)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GL31NErOgL3u" - }, - "source": [ - "## `AddLossWeights`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mBLf6iuXgPp2" - }, - "source": [ - "This function performs the afforementioned `add_loss_weights` to the data stream. 
\n", - "\n", - "> ```\n", - "def AddLossWeights(id_to_mask=None):\n", - " return lambda g: add_loss_weights(g,id_to_mask=id_to_mask)\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Jwtt-k_2iHEy", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 173 - }, - "outputId": "52295b0e-ff9c-415e-9ba6-1d5c1359b508" - }, - "source": [ - "data_pipeline = trax.data.Serial(\n", - " trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),\n", - " trax.data.Tokenize(vocab_dir='gs://trax-ml/vocabs/', vocab_file='en_8k.subword', keys=[0]),\n", - " trax.data.Shuffle(),\n", - " trax.data.FilterByLength(max_length=2048, length_keys=[0]),\n", - " trax.data.BucketByLength(boundaries=[ 32, 128, 512, 2048],\n", - " batch_sizes=[512, 128, 32, 8, 1],\n", - " length_keys=[0]),\n", - " trax.data.AddLossWeights(),\n", - " trax.data.Log(only_shapes=True)\n", - " )\n", - "\n", - "example = data_pipeline()\n", - "print(next(example))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(ShapeDtype{shape:(8, 2048), dtype:int64}, ShapeDtype{shape:(8,), dtype:int64}, ShapeDtype{shape:(8,), dtype:float32})\n", - "(array([[4176, 570, 636, ..., 0, 0, 0],\n", - " [3030, 2, 7, ..., 0, 0, 0],\n", - " [ 28, 3898, 22, ..., 0, 0, 0],\n", - " ...,\n", - " [ 139, 36, 76, ..., 0, 0, 0],\n", - " [2275, 2, 4198, ..., 0, 0, 0],\n", - " [ 182, 103, 151, ..., 0, 0, 0]]), array([0, 1, 1, 0, 0, 0, 1, 0]), array([1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))\n" - ], - "name": "stdout" - } - ] - } - ] -} \ No newline at end of file diff --git a/trax/fastmath/jax.py b/trax/fastmath/jax.py index df838c708..e942278f4 100644 --- a/trax/fastmath/jax.py +++ b/trax/fastmath/jax.py @@ -16,204 +16,236 @@ """Trax fast math: JAX backend.""" import functools + import jax -from jax import lax -from jax import random as jax_random import jax.numpy as jnp import jax.scipy.special as jax_special import numpy as np import tensorflow as tf import tensorflow_datasets as tfds +from jax import lax +from jax import random as jax_random + from trax.fastmath import numpy as tnp -from trax.shapes import signature - - -def jax_conv(inp, fltr, window_strides, padding, dimension_numbers, - filter_dilation=None): - """A wrapper around `lax.conv_general_dilated`. - - It requires `dimension_numbers` and disallows `inp_dilation`. - - Args: - inp: an (N+2)-D array. The input of the convolution. - fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. - window_strides: the strides for moving the convolution window. - padding: a string, either 'VALID' or 'SAME'. The padding algorithm. - dimension_numbers: a tuple of three strings encoding the data format of - input, filter and output. 'I' means input; 'O' means output; 'C' means - channel; other characters such as 'W', 'H' and 'D' means spatial - dimensions. - filter_dilation: the dilation rates for the filter. Dilating the filter - means adding "holes" to the filter. - - Returns: - An (N+2)-D array. The convolution result. 
- """ - return lax.conv_general_dilated(inp, fltr, window_strides, padding, - lhs_dilation=None, - rhs_dilation=filter_dilation, - dimension_numbers=dimension_numbers) - - -def _pooling_general(inputs, reducer, init_val, rescaler=None, - pool_size=(2, 2), strides=None, padding='VALID'): - """Helper: general pooling computation used in pooling layers later.""" - spatial_strides = strides or (1,) * len(pool_size) - rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None - dims = (1,) + pool_size + (1,) # NHWC - strides = (1,) + spatial_strides + (1,) - out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding) - return rescale(out, inputs) if rescale else out # pylint: disable=not-callable +from trax.utils.shapes import signature + + +def jax_conv( + inp, fltr, window_strides, padding, dimension_numbers, filter_dilation=None +): + """A wrapper around `lax.conv_general_dilated`. + + It requires `dimension_numbers` and disallows `inp_dilation`. + + Args: + inp: an (N+2)-D array. The input of the convolution. + fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. + window_strides: the strides for moving the convolution window. + padding: a string, either 'VALID' or 'SAME'. The padding algorithm. + dimension_numbers: a tuple of three strings encoding the data format of + input, filter and output. 'I' means input; 'O' means output; 'C' means + channel; other characters such as 'W', 'H' and 'D' means spatial + dimensions. + filter_dilation: the dilation rates for the filter. Dilating the filter + means adding "holes" to the filter. + + Returns: + An (N+2)-D array. The convolution result. + """ + return lax.conv_general_dilated( + inp, + fltr, + window_strides, + padding, + lhs_dilation=None, + rhs_dilation=filter_dilation, + dimension_numbers=dimension_numbers, + ) + + +def _pooling_general( + inputs, + reducer, + init_val, + rescaler=None, + pool_size=(2, 2), + strides=None, + padding="VALID", +): + """Helper: general pooling computation used in pooling layers later.""" + spatial_strides = strides or (1,) * len(pool_size) + rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None + dims = (1,) + pool_size + (1,) # NHWC + strides = (1,) + spatial_strides + (1,) + out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding) + return rescale(out, inputs) if rescale else out # pylint: disable=not-callable def jax_max_pool(x, pool_size, strides, padding): - return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size, - strides=strides, padding=padding) + return _pooling_general( + x, lax.max, -jnp.inf, pool_size=pool_size, strides=strides, padding=padding + ) def jax_sum_pool(x, pool_size, strides, padding): - return _pooling_general(x, lax.add, 0., pool_size=pool_size, - strides=strides, padding=padding) + return _pooling_general( + x, lax.add, 0.0, pool_size=pool_size, strides=strides, padding=padding + ) + +def _normalize_by_window_size( + dims, spatial_strides, padding +): # pylint: disable=invalid-name + def rescale(outputs, inputs): + one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype) + window_sizes = lax.reduce_window( + one, 0.0, lax.add, dims, spatial_strides, padding + ) + return outputs / window_sizes[..., jnp.newaxis] -def _normalize_by_window_size(dims, spatial_strides, padding): # pylint: disable=invalid-name - def rescale(outputs, inputs): - one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype) - window_sizes = lax.reduce_window( - one, 0., lax.add, dims, spatial_strides, padding) - 
return outputs / window_sizes[..., jnp.newaxis] - return rescale + return rescale def jax_avg_pool(x, pool_size, strides, padding): - return _pooling_general(x, lax.add, 0., _normalize_by_window_size, - pool_size, strides=strides, padding=padding) + return _pooling_general( + x, + lax.add, + 0.0, + _normalize_by_window_size, + pool_size, + strides=strides, + padding=padding, + ) def jax_abstract_eval(f): - """Returns a function that evaluates `f` given input shapes and dtypes. + """Returns a function that evaluates `f` given input shapes and dtypes. + + It transforms function `f` to a function that performs the same computation as + `f` but only on shapes and dtypes (a.k.a. shape inference). - It transforms function `f` to a function that performs the same computation as - `f` but only on shapes and dtypes (a.k.a. shape inference). + Args: + f: the function to be transformed. - Args: - f: the function to be transformed. + Returns: + A function whose input arguments can be either the same as `f`'s or only + their shapes/dtypes represented by `ShapeDtype`, and whose return values are + `ShapeDtype`s with the same nested structure as `f`'s return values. + """ - Returns: - A function whose input arguments can be either the same as `f`'s or only - their shapes/dtypes represented by `ShapeDtype`, and whose return values are - `ShapeDtype`s with the same nested structure as `f`'s return values. - """ - def shape_fun(*args, **kwargs): - jax_shapes = jax.eval_shape(f, *args, **kwargs) - return tnp.nested_map(signature, jax_shapes) - return shape_fun + def shape_fun(*args, **kwargs): + jax_shapes = jax.eval_shape(f, *args, **kwargs) + return tnp.nested_map(signature, jax_shapes) + + return shape_fun # The default value of dtype is different from jax_random.randint def jax_randint(key, shape, minval, maxval, dtype=np.int32): - """Sample uniform random values in [minval, maxval) with given shape/dtype. + """Sample uniform random values in [minval, maxval) with given shape/dtype. - Args: - key: a PRNGKey used as the random key. - shape: a tuple of nonnegative integers representing the shape. - minval: int or array of ints broadcast-compatible with ``shape``, a minimum - (inclusive) value for the range. - maxval: int or array of ints broadcast-compatible with ``shape``, a maximum - (exclusive) value for the range. - dtype: optional, an int dtype for the returned values (default int32). + Args: + key: a PRNGKey used as the random key. + shape: a tuple of nonnegative integers representing the shape. + minval: int or array of ints broadcast-compatible with ``shape``, a minimum + (inclusive) value for the range. + maxval: int or array of ints broadcast-compatible with ``shape``, a maximum + (exclusive) value for the range. + dtype: optional, an int dtype for the returned values (default int32). - Returns: - A random array with the specified shape and dtype. - """ - return jax_random.randint(key, shape, minval=minval, maxval=maxval, - dtype=dtype) + Returns: + A random array with the specified shape and dtype. 
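A minimal usage sketch for `jax_randint` (key and bounds are illustrative); relative to `jax.random.randint`, the wrapper only changes the default dtype:

```python
import numpy as np
from jax import random as jax_random

key = jax_random.PRNGKey(0)
sample = jax_random.randint(key, (2, 3), minval=0, maxval=10, dtype=np.int32)
print(sample.shape, sample.dtype)  # (2, 3) int32
```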
+ """ + return jax_random.randint(key, shape, minval=minval, maxval=maxval, dtype=dtype) def _to_numpy(x): - """Converts non-NumPy tensors to NumPy arrays.""" - return x if isinstance(x, np.ndarray) else x.numpy() + """Converts non-NumPy tensors to NumPy arrays.""" + return x if isinstance(x, np.ndarray) else x.numpy() def _dataset_as_numpy(ds, batch_size=None): - """Speed up tfds.as_numpy by batching and then iterating over the batches.""" - batch_size = batch_size or 1 - try: # Check that dense_to_ragged_batch exists. - if batch_size < 2: # Fall back to default if no batching requested. - raise AttributeError - ds_batch = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size)) - for example in tfds.as_numpy(ds_batch): - flat_example = tnp.tree_flatten(example) - np_flat_example = [_to_numpy(x) for x in flat_example] - for single_example_flat in zip(*np_flat_example): - single_example, _ = tnp.tree_unflatten(single_example_flat, example) - yield single_example - except AttributeError: - # In TF 1.X there is not dense_to_ragged_batch: fallback. - for example in tfds.as_numpy(ds): - yield example + """Speed up tfds.as_numpy by batching and then iterating over the batches.""" + batch_size = batch_size or 1 + try: # Check that dense_to_ragged_batch exists. + if batch_size < 2: # Fall back to default if no batching requested. + raise AttributeError + ds_batch = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size)) + for example in tfds.as_numpy(ds_batch): + flat_example = tnp.tree_flatten(example) + np_flat_example = [_to_numpy(x) for x in flat_example] + for single_example_flat in zip(*np_flat_example): + single_example, _ = tnp.tree_unflatten(single_example_flat, example) + yield single_example + except AttributeError: + # In TF 1.X there is not dense_to_ragged_batch: fallback. 
+ for example in tfds.as_numpy(ds): + yield example def _custom_grad(f_vjp, f_original): - f_ = jax.custom_transforms(f_original) - jax.defvjp_all(f_, f_vjp) - return f_ + f_ = jax.custom_transforms(f_original) + jax.defvjp_all(f_, f_vjp) + return f_ def _custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=()): - @functools.partial(jax.custom_vjp, nondiff_argnums=nondiff_argnums) - def _f(*args, **kwargs): - return f(*args, **kwargs) - _f.defvjp(f_fwd, f_bwd) - return _f + @functools.partial(jax.custom_vjp, nondiff_argnums=nondiff_argnums) + def _f(*args, **kwargs): + return f(*args, **kwargs) + + _f.defvjp(f_fwd, f_bwd) + return _f JAX_BACKEND = { - 'name': 'jax', - 'np': jnp, - 'abstract_eval': jax_abstract_eval, - 'avg_pool': jax_avg_pool, - 'cond': lax.cond, - 'conv': jax_conv, - 'custom_vjp': _custom_vjp, - 'custom_grad': _custom_grad, - 'dataset_as_numpy': _dataset_as_numpy, - 'dynamic_slice': jax.lax.dynamic_slice, - 'dynamic_slice_in_dim': jax.lax.dynamic_slice_in_dim, - 'dynamic_update_slice': jax.lax.dynamic_update_slice, - 'dynamic_update_slice_in_dim': jax.lax.dynamic_update_slice_in_dim, - 'erf': jax_special.erf, - 'expit': jax_special.expit, - 'fori_loop': lax.fori_loop, - 'global_device_count': jax.device_count, - 'grad': jax.grad, - 'value_and_grad': jax.value_and_grad, - 'index_add': lambda x, idx, y: jnp.asarray(x).at[idx].add(y), - 'index_max': lambda x, idx, y: jnp.asarray(x).at[idx].max(y), - 'index_min': lambda x, idx, y: jnp.asarray(x).at[idx].min(y), - 'index_update': lambda x, idx, y: jnp.asarray(x).at[idx].set(y), - 'jit': jax.jit, - 'local_device_count': jax.local_device_count, - 'logsumexp': jax_special.logsumexp, - 'lt': lax.lt, - 'map': lax.map, - 'max_pool': jax_max_pool, - 'pmap': jax.pmap, - 'psum': lax.psum, - 'random_bernoulli': jax_random.bernoulli, - 'random_get_prng': jax.jit(jax_random.PRNGKey), - 'random_normal': jax_random.normal, - 'random_randint': jax_randint, - 'random_split': jax_random.split, - 'random_fold_in': jax_random.fold_in, - 'random_uniform': jax_random.uniform, - 'remat': jax.remat, - 'scan': lax.scan, - 'sort_key_val': jax.lax.sort_key_val, - 'stop_gradient': lax.stop_gradient, - 'sum_pool': jax_sum_pool, - 'top_k': lax.top_k, - 'vjp': jax.vjp, - 'vmap': jax.vmap, + "name": "jax", + "np": jnp, + "abstract_eval": jax_abstract_eval, + "avg_pool": jax_avg_pool, + "cond": lax.cond, + "cond": lax.cond, + "conv": jax_conv, + "custom_vjp": _custom_vjp, + "custom_grad": _custom_grad, + "dataset_as_numpy": _dataset_as_numpy, + "dynamic_slice": jax.lax.dynamic_slice, + "dynamic_slice_in_dim": jax.lax.dynamic_slice_in_dim, + "dynamic_update_slice": jax.lax.dynamic_update_slice, + "dynamic_update_slice_in_dim": jax.lax.dynamic_update_slice_in_dim, + "erf": jax_special.erf, + "expit": jax_special.expit, + "fori_loop": lax.fori_loop, + "global_device_count": jax.device_count, + "grad": jax.grad, + "value_and_grad": jax.value_and_grad, + "index_add": lambda x, idx, y: jnp.asarray(x).at[idx].add(y), + "index_max": lambda x, idx, y: jnp.asarray(x).at[idx].max(y), + "index_min": lambda x, idx, y: jnp.asarray(x).at[idx].min(y), + "index_update": lambda x, idx, y: jnp.asarray(x).at[idx].set(y), + "jit": jax.jit, + "local_device_count": jax.local_device_count, + "logsumexp": jax_special.logsumexp, + "lt": lax.lt, + "map": lax.map, + "max_pool": jax_max_pool, + "pmap": jax.pmap, + "psum": lax.psum, + "random_bernoulli": jax_random.bernoulli, + "random_get_prng": jax.jit(jax_random.PRNGKey), + "random_normal": jax_random.normal, + "random_randint": jax_randint, + 
"random_split": jax_random.split, + "random_fold_in": jax_random.fold_in, + "random_uniform": jax_random.uniform, + "remat": jax.remat, + "scan": lax.scan, + "sort_key_val": jax.lax.sort_key_val, + "stop_gradient": lax.stop_gradient, + "sum_pool": jax_sum_pool, + "top_k": lax.top_k, + "vjp": jax.vjp, + "vmap": jax.vmap, + "devices": jax.devices, } diff --git a/trax/fastmath/numpy.py b/trax/fastmath/numpy.py index 0826fa416..4825868d1 100644 --- a/trax/fastmath/numpy.py +++ b/trax/fastmath/numpy.py @@ -16,274 +16,278 @@ """Trax fast math: pure numpy backend.""" import numpy as np + from scipy.special import logsumexp -from trax.shapes import signature + +from trax.utils.shapes import signature def get_prng(seed): - """JAX-compatible way of getting PRNG seeds.""" - if np.shape(seed): - raise TypeError('PRNGKey seed must be a scalar.') - convert = lambda k: np.reshape(np.asarray(k, np.uint32), [1]) - k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF)) - k2 = convert(np.bitwise_and(seed, 0xFFFFFFFF)) - return np.concatenate([k1, k2], 0) + """JAX-compatible way of getting PRNG seeds.""" + if np.shape(seed): + raise TypeError("PRNGKey seed must be a scalar.") + convert = lambda k: np.reshape(np.asarray(k, np.uint32), [1]) + k1 = convert(np.bitwise_and(np.right_shift(seed, 32), 0xFFFFFFFF)) + k2 = convert(np.bitwise_and(seed, 0xFFFFFFFF)) + return np.concatenate([k1, k2], 0) def random_uniform(rng, shape=(), dtype=np.float64, minval=0.0, maxval=1.0): - del rng - return np.random.uniform(minval, maxval, size=shape).astype(dtype) + del rng + return np.random.uniform(minval, maxval, size=shape).astype(dtype) def random_normal(rng, shape=(), dtype=np.float64): - del rng - return np.random.normal(size=shape).astype(dtype) + del rng + return np.random.normal(size=shape).astype(dtype) def random_randint(rng, shape, minval, maxval, dtype=np.int64): - del rng - return np.random.randint(minval, maxval, size=shape).astype(dtype) + del rng + return np.random.randint(minval, maxval, size=shape).astype(dtype) def random_bernoulli(rng, p=0.5, shape=()): - del rng - return np.random.binomial(1, p, size=shape) + del rng + return np.random.binomial(1, p, size=shape) def np_abstract_eval(f): - """Abstract evaluation in numpy by running the real function on 0s.""" - def abstract_f(*args, **kwargs): - real_args = [nested_map(lambda x: np.zeros(x.shape, x.dtype), a) - for a in args] - real_res = f(*real_args, **kwargs) - return signature(real_res) - return abstract_f + """Abstract evaluation in numpy by running the real function on 0s.""" + + def abstract_f(*args, **kwargs): + real_args = [nested_map(lambda x: np.zeros(x.shape, x.dtype), a) for a in args] + real_res = f(*real_args, **kwargs) + return signature(real_res) + + return abstract_f NUMPY_BACKEND = { - 'abstract_eval': np_abstract_eval, - 'local_device_count': lambda: 1, - 'global_device_count': lambda: 1, - 'jit': lambda f: f, - 'logsumexp': logsumexp, - 'name': 'numpy', - 'np': np, - 'random_bernoulli': random_bernoulli, - 'random_get_prng': get_prng, - 'random_normal': random_normal, - 'random_randint': random_randint, - 'random_split': lambda prng, num=2: (None,) * num, - 'random_uniform': random_uniform, - 'expit': lambda x: 1. / (1. 
+ np.exp(-x)), + "abstract_eval": np_abstract_eval, + "local_device_count": lambda: 1, + "global_device_count": lambda: 1, + "jit": lambda f: f, + "logsumexp": logsumexp, + "name": "numpy", + "np": np, + "random_bernoulli": random_bernoulli, + "random_get_prng": get_prng, + "random_normal": random_normal, + "random_randint": random_randint, + "random_split": lambda prng, num=2: (None,) * num, + "random_uniform": random_uniform, + "expit": lambda x: 1.0 / (1.0 + np.exp(-x)), } def nested_map(f, obj, level=0, ignore_nones=True): - """Maps `f` recursively inside any dicts/lists/tuples in `obj`. - - Args: - f: A function taking a single object as input. f's input must NOT be a - dict, list, or tuple, or any subclass of those. - obj: Either an input object to f or some nested structure of collections - of (collections of ...) input objects to f. - level: Level in the nested structure to stop at, counted from the leaves - - so level 0 is the leaf, level 1 is such that all of its children are at - level 0 etc. - ignore_nones: Whether to ignore Nones in the structure, i.e. return None - without calling `f`. - - Returns: - An object with the same nested structure as `obj`, but with each input - object `x` replaced by `f(x)`. - """ - if _is_at_level(obj, level): - if ignore_nones and _is_made_of_nones(obj): - return None - else: - return f(obj) - - if _is_namedtuple_instance(obj): - return type(obj)(*nested_map(f, list(obj), level=level)) - if isinstance(obj, list): - return [nested_map(f, y, level=level) for y in obj] - if isinstance(obj, tuple): - return tuple([nested_map(f, y, level=level) for y in obj]) - if isinstance(obj, dict): - return {k: nested_map(f, v, level=level) for (k, v) in obj.items()} - - raise ValueError('Non-exhaustive pattern match for {}.'.format(obj)) + """Maps `f` recursively inside any dicts/lists/tuples in `obj`. + + Args: + f: A function taking a single object as input. f's input must NOT be a + dict, list, or tuple, or any subclass of those. + obj: Either an input object to f or some nested structure of collections + of (collections of ...) input objects to f. + level: Level in the nested structure to stop at, counted from the leaves - + so level 0 is the leaf, level 1 is such that all of its children are at + level 0 etc. + ignore_nones: Whether to ignore Nones in the structure, i.e. return None + without calling `f`. + + Returns: + An object with the same nested structure as `obj`, but with each input + object `x` replaced by `f(x)`. + """ + if _is_at_level(obj, level): + if ignore_nones and _is_made_of_nones(obj): + return None + else: + return f(obj) + + if _is_namedtuple_instance(obj): + return type(obj)(*nested_map(f, list(obj), level=level)) + if isinstance(obj, list): + return [nested_map(f, y, level=level) for y in obj] + if isinstance(obj, tuple): + return tuple([nested_map(f, y, level=level) for y in obj]) + if isinstance(obj, dict): + return {k: nested_map(f, v, level=level) for (k, v) in obj.items()} + + raise ValueError("Non-exhaustive pattern match for {}.".format(obj)) def nested_map_multiarg(f, *objs, ignore_nones=True): - """Maps multi-arg `f` recursively inside any dicts/lists/tuples in `objs`. - - Args: - f: A function taking len(objs) inputs. f's input must NOT be a - dict, list, or tuple, or any subclass of those. - *objs: Either input objects to f or some nested structure of collections - of (collections of ...) input objects to f. - ignore_nones: Whether to ignore Nones in the structure, i.e. return None - without calling `f`. 
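A quick check of `nested_map` as documented above (assuming this module is importable as `trax.fastmath.numpy`, the path in this diff):

```python
from trax.fastmath.numpy import nested_map

inp = {'a': ([0, 1], 2), 'b': (3,)}
print(nested_map(lambda x: x + 1, inp))  # {'a': ([1, 2], 3), 'b': (4,)}
```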
- - Returns: - An object with the same nested structure as `objs[0]`, but with each input - object `x` replaced by `f(*xs)`. - """ - if isinstance(objs[0], list): - return [nested_map_multiarg(f, *[o[i] for o in objs]) - for i in range(len(objs[0]))] - if isinstance(objs[0], tuple): - return tuple([nested_map_multiarg(f, *[o[i] for o in objs]) - for i in range(len(objs[0]))]) - if isinstance(objs[0], dict): - return {k: nested_map_multiarg(f, *[o[k] for o in objs]) - for k in objs[0]} - if ignore_nones and _is_made_of_nones(objs): - return None - return f(*objs) + """Maps multi-arg `f` recursively inside any dicts/lists/tuples in `objs`. + + Args: + f: A function taking len(objs) inputs. f's input must NOT be a + dict, list, or tuple, or any subclass of those. + *objs: Either input objects to f or some nested structure of collections + of (collections of ...) input objects to f. + ignore_nones: Whether to ignore Nones in the structure, i.e. return None + without calling `f`. + + Returns: + An object with the same nested structure as `objs[0]`, but with each input + object `x` replaced by `f(*xs)`. + """ + if isinstance(objs[0], list): + return [ + nested_map_multiarg(f, *[o[i] for o in objs]) for i in range(len(objs[0])) + ] + if isinstance(objs[0], tuple): + return tuple( + [nested_map_multiarg(f, *[o[i] for o in objs]) for i in range(len(objs[0]))] + ) + if isinstance(objs[0], dict): + return {k: nested_map_multiarg(f, *[o[k] for o in objs]) for k in objs[0]} + if ignore_nones and _is_made_of_nones(objs): + return None + return f(*objs) def nested_zip(objs): - """Zips the leaves of each nested structure in `objs`. + """Zips the leaves of each nested structure in `objs`. - Args: - objs: List of nested structures to zip. + Args: + objs: List of nested structures to zip. - Returns: - An object with the same nested structure as each element of `objs`, with - leaves zipped together into tuples. - """ - assert isinstance(objs, (list, tuple)) - assert objs, 'Cannot zip an empty sequence.' + Returns: + An object with the same nested structure as each element of `objs`, with + leaves zipped together into tuples. + """ + assert isinstance(objs, (list, tuple)) + assert objs, "Cannot zip an empty sequence." - if _is_at_level(objs, 1): - return tuple(objs) + if _is_at_level(objs, 1): + return tuple(objs) - if _is_namedtuple_instance(objs[0]): - return type(objs[0])(*nested_zip(list(map(list, objs)))) - if isinstance(objs[0], list): - return [nested_zip([obj[i] for obj in objs]) for i in range(len(objs[0]))] - if isinstance(objs[0], tuple): - return nested_zip(list(map(list, objs))) - if isinstance(objs[0], dict): - return {k: nested_zip([obj[k] for obj in objs]) for k in objs[0]} + if _is_namedtuple_instance(objs[0]): + return type(objs[0])(*nested_zip(list(map(list, objs)))) + if isinstance(objs[0], list): + return [nested_zip([obj[i] for obj in objs]) for i in range(len(objs[0]))] + if isinstance(objs[0], tuple): + return nested_zip(list(map(list, objs))) + if isinstance(objs[0], dict): + return {k: nested_zip([obj[k] for obj in objs]) for k in objs[0]} - raise ValueError('Non-exhaustive pattern match for {}.'.format(objs[0])) + raise ValueError("Non-exhaustive pattern match for {}.".format(objs[0])) def nested_stack(objs, axis=0, np_module=np): - """Stacks the numpy arrays inside any dicts/lists/tuples in `objs`. - - Args: - objs: List of nested structures to stack. - axis: Axis to stack along. - np_module: numpy module to use - typically numpy or jax.numpy. 
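`nested_zip` and `nested_stack` combine as follows: structures of identical shape are zipped leaf-wise into tuples, then each tuple of leaves is stacked into one array (import path assumed from this diff; data is illustrative):

```python
import numpy as np

from trax.fastmath.numpy import nested_stack

objs = [
    {'obs': np.array([1, 2]), 'reward': np.array(0.0)},
    {'obs': np.array([3, 4]), 'reward': np.array(1.0)},
]
stacked = nested_stack(objs)
print(stacked['obs'].shape)     # (2, 2) -- two leaves stacked along axis 0
print(stacked['reward'].shape)  # (2,)
```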
- - Returns: - An object with the same nested structure as each element of `objs`, with - leaves stacked together into numpy arrays. Nones are propagated, i.e. if - each element of the stacked sequence is None, the output will be None. - """ - # nested_map the stacking operation, but stopping at level 1 so at tuples of - # numpy arrays. - return nested_map( - lambda x: np_module.stack(x, axis=axis), - nested_zip(objs), - level=1, - ) + """Stacks the numpy arrays inside any dicts/lists/tuples in `objs`. + + Args: + objs: List of nested structures to stack. + axis: Axis to stack along. + np_module: numpy module to use - typically numpy or jax.numpy. + + Returns: + An object with the same nested structure as each element of `objs`, with + leaves stacked together into numpy arrays. Nones are propagated, i.e. if + each element of the stacked sequence is None, the output will be None. + """ + # nested_map the stacking operation, but stopping at level 1 so at tuples of + # numpy arrays. + return nested_map( + lambda x: np_module.stack(x, axis=axis), + nested_zip(objs), + level=1, + ) def tree_flatten(tree): - """Flatten a tree into a list.""" - if isinstance(tree, (list, tuple)): - # In python, sum of lists starting from [] is the concatenation. - return sum([tree_flatten(t) for t in tree], []) - if isinstance(tree, dict): - # Only use the values in case of a dictionary node. - return sum([tree_flatten(v) for v in tree.values()], []) - return [tree] + """Flatten a tree into a list.""" + if isinstance(tree, (list, tuple)): + # In python, sum of lists starting from [] is the concatenation. + return sum([tree_flatten(t) for t in tree], []) + if isinstance(tree, dict): + # Only use the values in case of a dictionary node. + return sum([tree_flatten(v) for v in tree.values()], []) + return [tree] def tree_leaves(tree, ignore_nones=True): - """Gets the leaves of a tree.""" + """Gets the leaves of a tree.""" - # Right now this is just `tree_flatten`, but we keep this separate since - # JAX's tree_flatten returns the structure of the tree as well. - flattened = tree_flatten(tree) - return [flat for flat in flattened if (not ignore_nones) or flat is not None] + # Right now this is just `tree_flatten`, but we keep this separate since + # JAX's tree_flatten returns the structure of the tree as well. + flattened = tree_flatten(tree) + return [flat for flat in flattened if (not ignore_nones) or flat is not None] def tree_unflatten(flat, tree, copy_from_tree=None): - """Unflatten a list into a tree given the tree shape as second argument. - - Args: - flat: a flat list of elements to be assembled into a tree. - tree: a tree with the structure we want to have in the new tree. - copy_from_tree: optional list of elements that we just copy from tree. - This argument is used when the flat version does not contain all elements - of the expected tree but just a subset, while the rest are filled from - the tree itself. It allows to omit "unnecessary" elements. For example, - consider trees (A, (B, X), X) and (X, (A, X), B) where X is some element - we do not care about. Flattening the first tree and removing X will yield - a flat list [A, B] and the second tree can then be reconstructed from this - list and the tree (X, (E, X), E) with copy_from_tree=[X]. One example - where this is used is the weights-tree of a model, where layers with no - weights have () in the tree and we use copy_from_tree=[()] to restore - a model from a file that only has a list of trainable weights. 
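A round-trip sketch for `tree_flatten` / `tree_unflatten` (import path assumed from this diff): flatten a nested structure, then rebuild the same shape from a fresh flat list:

```python
from trax.fastmath.numpy import tree_flatten, tree_unflatten

tree = {'a': (1, 2), 'b': [3]}
print(tree_flatten(tree))  # [1, 2, 3]

rebuilt, rest = tree_unflatten(['A', 'B', 'C'], tree)
print(rebuilt)  # {'a': ('A', 'B'), 'b': ['C']} -- same shape, new leaves
print(rest)     # [] -- no leftover flat elements
```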
- - Returns: - A pair (new_tree, rest_of_flat) where the new tree that has the structure - of tree but with leaves from flat, and the remaining elements of flat if - more were provided than the number of leaves of tree (useful for recursion). - """ - if copy_from_tree is not None: - for el in copy_from_tree: - # Equality checks comparing a DeviceArray with other Python objects - # may legitimately raise a TypeError. - try: - if tree == el: - return tree, flat - except TypeError: - continue - - if isinstance(tree, (list, tuple)): - new_tree, rest = [], flat - for t in tree: - new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) - new_tree.append(new_t) - new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree - return new_tree, rest - if isinstance(tree, dict): - new_tree, rest = {}, flat - for k in tree: - new_v, rest = tree_unflatten(rest, tree[k], copy_from_tree=copy_from_tree) - new_tree[k] = new_v - return new_tree, rest - return flat[0], flat[1:] + """Unflatten a list into a tree given the tree shape as second argument. + + Args: + flat: a flat list of elements to be assembled into a tree. + tree: a tree with the structure we want to have in the new tree. + copy_from_tree: optional list of elements that we just copy from tree. + This argument is used when the flat version does not contain all elements + of the expected tree but just a subset, while the rest are filled from + the tree itself. It allows to omit "unnecessary" elements. For example, + consider trees (A, (B, X), X) and (X, (A, X), B) where X is some element + we do not care about. Flattening the first tree and removing X will yield + a flat list [A, B] and the second tree can then be reconstructed from this + list and the tree (X, (E, X), E) with copy_from_tree=[X]. One example + where this is used is the weights-tree of a model, where layers with no + weights have () in the tree and we use copy_from_tree=[()] to restore + a model from a file that only has a list of trainable weights. + + Returns: + A pair (new_tree, rest_of_flat) where the new tree that has the structure + of tree but with leaves from flat, and the remaining elements of flat if + more were provided than the number of leaves of tree (useful for recursion). + """ + if copy_from_tree is not None: + for el in copy_from_tree: + # Equality checks comparing a DeviceArray with other Python objects + # may legitimately raise a TypeError. 
+ try: + if tree == el: + return tree, flat + except TypeError: + continue + + if isinstance(tree, (list, tuple)): + new_tree, rest = [], flat + for t in tree: + new_t, rest = tree_unflatten(rest, t, copy_from_tree=copy_from_tree) + new_tree.append(new_t) + new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree + return new_tree, rest + if isinstance(tree, dict): + new_tree, rest = {}, flat + for k in tree: + new_v, rest = tree_unflatten(rest, tree[k], copy_from_tree=copy_from_tree) + new_tree[k] = new_v + return new_tree, rest + return flat[0], flat[1:] def _is_namedtuple_instance(x): - """Checks if `x` is an instance of a `namedtuple` type.""" - if not isinstance(x, tuple): - return False - return hasattr(x, '_fields') + """Checks if `x` is an instance of a `namedtuple` type.""" + if not isinstance(x, tuple): + return False + return hasattr(x, "_fields") def _is_at_level(obj, level): - """Checks if `obj` is an at level `level`.""" - is_leaf = not isinstance(obj, (list, tuple, dict)) - if level == 0 or is_leaf: - return (level == 0) == is_leaf + """Checks if `obj` is an at level `level`.""" + is_leaf = not isinstance(obj, (list, tuple, dict)) + if level == 0 or is_leaf: + return (level == 0) == is_leaf - if isinstance(obj, dict): - elems = obj.values() - else: - elems = obj - return elems and all(_is_at_level(x, level - 1) for x in elems) + if isinstance(obj, dict): + elems = obj.values() + else: + elems = obj + return elems and all(_is_at_level(x, level - 1) for x in elems) def _is_made_of_nones(obj): - """Checks if `obj` is a nested structure of `None`s.""" - elems = tree_flatten(obj) - # Returning False for an empty list, because it doesn't have any Nones inside. - return elems and all(x is None for x in elems) + """Checks if `obj` is a nested structure of `None`s.""" + elems = tree_flatten(obj) + # Returning False for an empty list, because it doesn't have any Nones inside. + return elems and all(x is None for x in elems) diff --git a/trax/fastmath/ops.py b/trax/fastmath/ops.py index dbd6dfb83..3cc620927 100644 --- a/trax/fastmath/ops.py +++ b/trax/fastmath/ops.py @@ -32,6 +32,7 @@ import enum import gin + from trax.fastmath.jax import JAX_BACKEND from trax.fastmath.numpy import NUMPY_BACKEND from trax.fastmath.tf import TF_BACKEND @@ -39,9 +40,9 @@ @enum.unique class Backend(enum.Enum): - JAX = 'jax' - TFNP = 'tensorflow-numpy' - NUMPY = 'numpy' + JAX = "jax" + TFNP = "tensorflow-numpy" + NUMPY = "numpy" # For numpy and random modules, we need to call "backend()" lazily, only when @@ -52,320 +53,343 @@ class Backend(enum.Enum): # A class that just forwards attribute accesses to backend's numpy object. class NumpyBackend: - """Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.""" + """Numpy functions accelerated to run on GPUs and TPUs. 
Use like numpy.""" + + def __getattr__(self, attr): + return getattr(backend()["np"], attr) - def __getattr__(self, attr): - return getattr(backend()['np'], attr) numpy = NumpyBackend() class RandomBackend: - """Backend providing random functions.""" + """Backend providing random functions.""" - def get_prng(self, seed): - return backend()['random_get_prng'](seed) + def get_prng(self, seed): + return backend()["random_get_prng"](seed) - def split(self, prng, num=2): - return backend()['random_split'](prng, num) + def split(self, prng, num=2): + return backend()["random_split"](prng, num) - def fold_in(self, rng, data): - return backend()['random_fold_in'](rng, data) + def fold_in(self, rng, data): + return backend()["random_fold_in"](rng, data) - def uniform(self, *args, **kwargs): - return backend()['random_uniform'](*args, **kwargs) + def uniform(self, *args, **kwargs): + return backend()["random_uniform"](*args, **kwargs) - def randint(self, *args, **kwargs): - return backend()['random_randint'](*args, **kwargs) + def randint(self, *args, **kwargs): + return backend()["random_randint"](*args, **kwargs) - def normal(self, *args, **kwargs): - return backend()['random_normal'](*args, **kwargs) + def normal(self, *args, **kwargs): + return backend()["random_normal"](*args, **kwargs) - def bernoulli(self, *args, **kwargs): - return backend()['random_bernoulli'](*args, **kwargs) + def bernoulli(self, *args, **kwargs): + return backend()["random_bernoulli"](*args, **kwargs) random = RandomBackend() def logsumexp(*args, **kwargs): - """Computes the log of the sum of exponentials of input elements.""" - return backend()['logsumexp'](*args, **kwargs) + """Computes the log of the sum of exponentials of input elements.""" + return backend()["logsumexp"](*args, **kwargs) def expit(*args, **kwargs): - """Computes the expit (sigmoid) function.""" - return backend()['expit'](*args, **kwargs) + """Computes the expit (sigmoid) function.""" + return backend()["expit"](*args, **kwargs) def sigmoid(*args, **kwargs): - """Computes the sigmoid (expit) function.""" - return backend()['expit'](*args, **kwargs) + """Computes the sigmoid (expit) function.""" + return backend()["expit"](*args, **kwargs) def erf(*args, **kwargs): - """Computes the erf function.""" - return backend()['erf'](*args, **kwargs) + """Computes the erf function.""" + return backend()["erf"](*args, **kwargs) def conv(*args, **kwargs): - """Computes a generalized convolution.""" - return backend()['conv'](*args, **kwargs) + """Computes a generalized convolution.""" + return backend()["conv"](*args, **kwargs) def avg_pool(*args, **kwargs): - """Average pooling.""" - return backend()['avg_pool'](*args, **kwargs) + """Average pooling.""" + return backend()["avg_pool"](*args, **kwargs) def max_pool(*args, **kwargs): - """Max pooling.""" - return backend()['max_pool'](*args, **kwargs) + """Max pooling.""" + return backend()["max_pool"](*args, **kwargs) def sum_pool(*args, **kwargs): - """Sum pooling.""" - return backend()['sum_pool'](*args, **kwargs) + """Sum pooling.""" + return backend()["sum_pool"](*args, **kwargs) def top_k(*args, **kwargs): - """Top k.""" - return backend()['top_k'](*args, **kwargs) + """Top k.""" + return backend()["top_k"](*args, **kwargs) def sort_key_val(*args, **kwargs): - """Sorts keys along dimension and applies same permutation to values.""" - return backend()['sort_key_val'](*args, **kwargs) + """Sorts keys along dimension and applies same permutation to values.""" + return backend()["sort_key_val"](*args, **kwargs) 
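Every wrapper in this module follows the same dispatch pattern: look the op up by string key in the active backend's dict and forward the call. A hypothetical new op (the name is illustrative, not part of the API) would be exposed like this:

```python
def median_pool(*args, **kwargs):
    """Dispatches to the active backend's 'median_pool' entry, if registered."""
    # Assumes this module's backend() accessor, as used by the wrappers above.
    return backend()["median_pool"](*args, **kwargs)
```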
def scan(*args, **kwargs): - """Scan to make recurrent functions run faster on accelerators.""" - return backend()['scan'](*args, **kwargs) + """Scan to make recurrent functions run faster on accelerators.""" + return backend()["scan"](*args, **kwargs) def map(*args, **kwargs): # pylint: disable=redefined-builtin - """Map a function over leading array axes.""" - return backend()['map'](*args, **kwargs) + """Map a function over leading array axes.""" + return backend()["map"](*args, **kwargs) def fori_loop(lower, upper, body_fn, init_val): - """Loop from `lower` to `upper` running `body_fn` starting from `init_val`. - - The semantics of `fori_loop` is as follows:: - - def fori_loop(lower, upper, body_fn, init_val): - val = init_val - for i in range(lower, upper): - val = body_fn(i, val) - return val - - Args: - lower: an integer representing the loop index lower bound (inclusive) - upper: an integer representing the loop index upper bound (exclusive) - body_fn: function of type `(int, a) -> a`. - init_val: initial loop carry value of type `a`. - - Returns: - Loop value from the final iteration. - """ - if 'fori_loop' in backend(): - return backend()['fori_loop'](lower, upper, body_fn, init_val) - # Use scan otherwise. - def scanned_fn(loop_carry, _): - i, x = loop_carry - return (i + 1, body_fn(i, x)), None - (_, result), _ = scan( - scanned_fn, (lower, init_val), None, length=upper - lower) - return result + """Loop from `lower` to `upper` running `body_fn` starting from `init_val`. + + The semantics of `fori_loop` is as follows:: + + def fori_loop(lower, upper, body_fn, init_val): + val = init_val + for i in range(lower, upper): + val = body_fn(i, val) + return val + + Args: + lower: an integer representing the loop index lower bound (inclusive) + upper: an integer representing the loop index upper bound (exclusive) + body_fn: function of type `(int, a) -> a`. + init_val: initial loop carry value of type `a`. + + Returns: + Loop value from the final iteration. + """ + if "fori_loop" in backend(): + return backend()["fori_loop"](lower, upper, body_fn, init_val) + # Use scan otherwise. 
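The fallback below threads `(i, val)` through the scan carry; a standalone sketch of the same semantics, with a plain Python loop standing in for `lax.scan`:

```python
def fori_loop_via_scan(lower, upper, body_fn, init_val):
    carry = (lower, init_val)
    for _ in range(upper - lower):      # lax.scan over `upper - lower` dummy steps
        i, x = carry
        carry = (i + 1, body_fn(i, x))  # mirrors scanned_fn below
    return carry[1]

print(fori_loop_via_scan(2, 5, lambda i, x: x + i, 1))  # 1 + 2 + 3 + 4 = 10
```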
+ def scanned_fn(loop_carry, _): + i, x = loop_carry + return (i + 1, body_fn(i, x)), None + + (_, result), _ = scan(scanned_fn, (lower, init_val), None, length=upper - lower) + return result def remat(*args, **kwargs): - """Recompute everything in the backward pass to same memory.""" - return backend()['remat'](*args, **kwargs) + """Recompute everything in the backward pass to same memory.""" + return backend()["remat"](*args, **kwargs) def cond(*args, **kwargs): - """Conditional computation to run on accelerators.""" - return backend()['cond'](*args, **kwargs) + """Conditional computation to run on accelerators.""" + return backend()["cond"](*args, **kwargs) def lt(*args, **kwargs): - """Less-than function for backends that do not override <.""" - return backend()['lt'](*args, **kwargs) + """Less-than function for backends that do not override <.""" + return backend()["lt"](*args, **kwargs) def index_update(*args, **kwargs): - return backend()['index_update'](*args, **kwargs) + return backend()["index_update"](*args, **kwargs) def index_add(*args, **kwargs): - return backend()['index_add'](*args, **kwargs) + return backend()["index_add"](*args, **kwargs) def index_min(*args, **kwargs): - return backend()['index_min'](*args, **kwargs) + return backend()["index_min"](*args, **kwargs) def index_max(*args, **kwargs): - return backend()['index_max'](*args, **kwargs) + return backend()["index_max"](*args, **kwargs) def dynamic_slice(*args, **kwargs): - return backend()['dynamic_slice'](*args, **kwargs) + return backend()["dynamic_slice"](*args, **kwargs) def dynamic_slice_in_dim(*args, **kwargs): - return backend()['dynamic_slice_in_dim'](*args, **kwargs) + return backend()["dynamic_slice_in_dim"](*args, **kwargs) def dynamic_update_slice(*args, **kwargs): - return backend()['dynamic_update_slice'](*args, **kwargs) + return backend()["dynamic_update_slice"](*args, **kwargs) def dynamic_update_slice_in_dim(*args, **kwargs): - return backend()['dynamic_update_slice_in_dim'](*args, **kwargs) + return backend()["dynamic_update_slice_in_dim"](*args, **kwargs) def stop_gradient(*args, **kwargs): - """Identity on the forward pass but 0 (no gradient) on the backward pass.""" - return backend()['stop_gradient'](*args, **kwargs) + """Identity on the forward pass but 0 (no gradient) on the backward pass.""" + return backend()["stop_gradient"](*args, **kwargs) _disable_jit = False def jit(*args, **kwargs): - """Just-In-Time compiles the given function for use on accelerators.""" - global _disable_jit - if _disable_jit: - return args[0] # jit(f, **unused_now_jit_kwargs) = f - return backend()['jit'](*args, **kwargs) + """Just-In-Time compiles the given function for use on accelerators.""" + global _disable_jit + if _disable_jit: + return args[0] # jit(f, **unused_now_jit_kwargs) = f + return backend()["jit"](*args, **kwargs) def disable_jit(): - """Disables JIT-compilation; helpful for debugging.""" - global _disable_jit - _disable_jit = True + """Disables JIT-compilation; helpful for debugging.""" + global _disable_jit + _disable_jit = True def vmap(*args, **kwargs): - """Vectorizes the specified function (returns a function).""" - return backend()['vmap'](*args, **kwargs) + """Vectorizes the specified function (returns a function).""" + return backend()["vmap"](*args, **kwargs) def grad(*args, **kwargs): - """Computes the gradient of the specified function (returns a function).""" - return backend()['grad'](*args, **kwargs) + """Computes the gradient of the specified function (returns a function).""" + 
return backend()["grad"](*args, **kwargs) def value_and_grad(*args, **kwargs): - """Computes the gradient of the specified function together with the value.""" - if 'value_and_grad' in backend(): - return backend()['value_and_grad'](*args, **kwargs) - grad_fn = grad(*args, **kwargs) - fn = args[0] - has_aux = False - if has_aux in kwargs: - has_aux = kwargs['has_aux'] - if not has_aux: - def val_and_grad(*fn_args, **fn_kwargs): - return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs) - return val_and_grad - def val_and_grad_aux(*fn_args, **fn_kwargs): - g, aux = grad_fn(*fn_args, **fn_kwargs) - res, _ = fn(*fn_args, **fn_kwargs) - return (res, aux), g - return val_and_grad_aux + """Computes the gradient of the specified function together with the value.""" + if "value_and_grad" in backend(): + return backend()["value_and_grad"](*args, **kwargs) + + grad_fn = grad(*args, **kwargs) + fn = args[0] + has_aux = False + if has_aux in kwargs: + has_aux = kwargs["has_aux"] + if not has_aux: + + def val_and_grad(*fn_args, **fn_kwargs): + return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs) + + return val_and_grad + + def val_and_grad_aux(*fn_args, **fn_kwargs): + g, aux = grad_fn(*fn_args, **fn_kwargs) + res, _ = fn(*fn_args, **fn_kwargs) + return (res, aux), g + + return val_and_grad_aux def vjp(*args, **kwargs): - """Computes the vector-Jacobian product for the specified function.""" - return backend()['vjp'](*args, **kwargs) + """Computes the vector-Jacobian product for the specified function.""" + return backend()["vjp"](*args, **kwargs) def custom_grad(*args, **kwargs): - """Set a custom gradient computation (override the default) for a function.""" - return backend()['custom_grad'](*args, **kwargs) + """Set a custom gradient computation (override the default) for a function.""" + return backend()["custom_grad"](*args, **kwargs) def custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=()): - """Set a custom vjp computation (override the default) for a function.""" - # Call backend custom_vjp if it exists. - # TODO(lukaszkaiser): unify the APIs and remove nondiff_argnums altogether. - if 'custom_vjp' in backend(): - return backend()['custom_vjp'](f, f_fwd, f_bwd) - - # Check that nondiff_argnums is (0, 1, ..., N) for some N. - # Currently we only support nondiff_argnums at the front. - counter = -1 - for i in nondiff_argnums: - counter += 1 - if i != counter: - raise ValueError('Currently we only support custom_vjps with all nondiff' - '_argnums up front, like (0,) or (0, 1) but not (1,) or' - ' (1, 2). Found: %s' % str(nondiff_argnums)) - - # Use custom_grad. - if counter == -1: # no non-diff args - def f_vjp(*args): - out, residual = f_fwd(*args) - def vjpfn(g): - return f_bwd(residual, g) - return out, vjpfn - return backend()['custom_grad'](f_vjp, f) - - # Handle non-diff args by closure. - def f_joint(*args): - """This function takes all args, first counter+1 are non-diff ones.""" - nondiff_args = list(args[:counter+1]) - def f_diff(*diff_args): # Takes only diff args, will define custom grad. - args = nondiff_args + list(diff_args) - return f(*args) - def f_vjp(*diff_args): # Custom VJP for diff args. - args = nondiff_args + list(diff_args) - out, residual = f_fwd(*args) - def vjpfn(g): - bwd_args = [residual, g] - res = f_bwd(*bwd_args) - return res[counter+1:] - return out, vjpfn - # This is the function taking only diff args with custom vjp. - f_diff_vjp = backend()['custom_grad'](f_vjp, f_diff) - # Call it on the diff args. 
- return f_diff_vjp(*args[counter+1:]) - return f_joint + """Set a custom vjp computation (override the default) for a function.""" + # Call backend custom_vjp if it exists. + # TODO(lukaszkaiser): unify the APIs and remove nondiff_argnums altogether. + if "custom_vjp" in backend(): + return backend()["custom_vjp"](f, f_fwd, f_bwd) + + # Check that nondiff_argnums is (0, 1, ..., N) for some N. + # Currently we only support nondiff_argnums at the front. + counter = -1 + for i in nondiff_argnums: + counter += 1 + if i != counter: + raise ValueError( + "Currently we only support custom_vjps with all nondiff" + "_argnums up front, like (0,) or (0, 1) but not (1,) or" + " (1, 2). Found: %s" % str(nondiff_argnums) + ) + + # Use custom_grad. + if counter == -1: # no non-diff args + + def f_vjp(*args): + out, residual = f_fwd(*args) + + def vjpfn(g): + return f_bwd(residual, g) + + return out, vjpfn + + return backend()["custom_grad"](f_vjp, f) + + # Handle non-diff args by closure. + def f_joint(*args): + """This function takes all args, first counter+1 are non-diff ones.""" + nondiff_args = list(args[: counter + 1]) + + def f_diff(*diff_args): # Takes only diff args, will define custom grad. + args = nondiff_args + list(diff_args) + return f(*args) + + def f_vjp(*diff_args): # Custom VJP for diff args. + args = nondiff_args + list(diff_args) + out, residual = f_fwd(*args) + + def vjpfn(g): + bwd_args = [residual, g] + res = f_bwd(*bwd_args) + return res[counter + 1 :] + + return out, vjpfn + + # This is the function taking only diff args with custom vjp. + f_diff_vjp = backend()["custom_grad"](f_vjp, f_diff) + # Call it on the diff args. + return f_diff_vjp(*args[counter + 1 :]) + + return f_joint def pmap(*args, **kwargs): - """Parallel-map to apply a function on multiple accelerators in parallel.""" - return backend()['pmap'](*args, **kwargs) + """Parallel-map to apply a function on multiple accelerators in parallel.""" + return backend()["pmap"](*args, **kwargs) def psum(*args, **kwargs): - """Parallel-sum to use within a pmap'd function for aggregation.""" - return backend()['psum'](*args, **kwargs) + """Parallel-sum to use within a pmap'd function for aggregation.""" + return backend()["psum"](*args, **kwargs) def abstract_eval(*args, **kwargs): - """Evaluates function just on signatures of parameters, return signatures.""" - return backend()['abstract_eval'](*args, **kwargs) + """Evaluates function just on signatures of parameters, return signatures.""" + return backend()["abstract_eval"](*args, **kwargs) def dataset_as_numpy(*args, **kwargs): - """Convert a tf.data.Dataset to a stream of numpy arrays.""" - if 'dataset_as_numpy' in backend(): - return backend()['dataset_as_numpy'](*args, **kwargs) - return JAX_BACKEND['dataset_as_numpy'](*args, **kwargs) + """Convert a tf.data.Dataset to a stream of numpy arrays.""" + if "dataset_as_numpy" in backend(): + return backend()["dataset_as_numpy"](*args, **kwargs) + return JAX_BACKEND["dataset_as_numpy"](*args, **kwargs) def global_device_count(*args, **kwargs): - """Return the number of accelerators (GPUs or TPUs) in all hosts.""" - return backend()['global_device_count'](*args, **kwargs) + """Return the number of accelerators (GPUs or TPUs) in all hosts.""" + return backend()["global_device_count"](*args, **kwargs) + + +def devices(*args, **kwargs): + """Return the number of accelerators (GPUs or TPUs) in all hosts.""" + return backend()["devices"](*args, **kwargs) def local_device_count(*args, **kwargs): - """Return the number of accelerators 
(GPUs or TPUs) available on this host.""" - return backend()['local_device_count'](*args, **kwargs) + """Return the number of accelerators (GPUs or TPUs) available on this host.""" + return backend()["local_device_count"](*args, **kwargs) # Backend selection functions. @@ -380,65 +404,65 @@ def local_device_count(*args, **kwargs): def _assert_valid_backend_name(name): - for backend_ in Backend: - if backend_.value == name: - return - raise ValueError(f'No backend with name {name}') + for backend_ in Backend: + if backend_.value == name: + return + raise ValueError(f"No backend with name {name}") def set_backend(name): - """Sets the default backend to use in Trax.""" - if name: - _assert_valid_backend_name(name) - global default_backend - default_backend = name + """Sets the default backend to use in Trax.""" + if name: + _assert_valid_backend_name(name) + global default_backend + default_backend = name def _get_backend_from_string(name_str): - # name is a string. - for backend_ in Backend: - if backend_.value == name_str: - return _backend_dict[backend_] - return JAX_BACKEND + # name is a string. + for backend_ in Backend: + if backend_.value == name_str: + return _backend_dict[backend_] + return JAX_BACKEND @gin.configurable -def backend(name='jax'): - """Returns the backend used to provide fastmath ops ('tf' or 'jax').""" - if override_backend: - return _get_backend_from_string(override_backend) +def backend(name="jax"): + """Returns the backend used to provide fastmath ops ('tf' or 'jax').""" + if override_backend: + return _get_backend_from_string(override_backend) - if default_backend: - return _get_backend_from_string(default_backend) + if default_backend: + return _get_backend_from_string(default_backend) - if isinstance(name, Backend): - return _backend_dict[name] + if isinstance(name, Backend): + return _backend_dict[name] - # name is a string. - return _get_backend_from_string(name) + # name is a string. + return _get_backend_from_string(name) @contextlib.contextmanager def use_backend(name): - """Call fastmath functions with a specified backend.""" - if isinstance(name, Backend): - name = name.value + """Call fastmath functions with a specified backend.""" + if isinstance(name, Backend): + name = name.value - _assert_valid_backend_name(name) - global override_backend - prev_name_or_backend = override_backend - override_backend = name - # Run the decorated function in try-finally in case it throws, e.g. for tests. - try: - yield - finally: - override_backend = prev_name_or_backend + _assert_valid_backend_name(name) + global override_backend + prev_name_or_backend = override_backend + override_backend = name + # Run the decorated function in try-finally in case it throws, e.g. for tests. + try: + yield + finally: + override_backend = prev_name_or_backend def backend_name(): - """Returns the name of the backend currently in use ('tf' or 'jax').""" - return backend()['name'] + """Returns the name of the backend currently in use ('tf' or 'jax').""" + return backend()["name"] def is_backend(backend_): - return backend()['name'] == backend_.value + return backend()["name"] == backend_.value diff --git a/trax/fastmath/ops_test.py b/trax/fastmath/ops_test.py deleted file mode 100644 index 2e22b91b6..000000000 --- a/trax/fastmath/ops_test.py +++ /dev/null @@ -1,123 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.fastmath.ops.""" - -import collections -from absl.testing import parameterized - -import gin -import jax.numpy as jnp -import numpy as onp -from tensorflow import test -from trax import fastmath - - -_TestNamedtuple = collections.namedtuple('_TestNamedtuple', ['x']) - - -class BackendTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def override_gin(self, bindings): - gin.parse_config_files_and_bindings(None, bindings) - - def test_backend_imports_correctly(self): - backend = fastmath.backend() - self.assertEqual(jnp, backend['np']) - self.assertNotEqual(onp, backend['np']) - - self.override_gin("backend.name = 'numpy'") - - backend = fastmath.backend() - self.assertNotEqual(jnp, backend['np']) - self.assertEqual(onp, backend['np']) - - def test_backend_can_be_set(self): - self.assertEqual(fastmath.backend_name(), 'jax') - fastmath.set_backend('tensorflow-numpy') - self.assertEqual(fastmath.backend_name(), 'tensorflow-numpy') - fastmath.set_backend(None) - self.assertEqual(fastmath.backend_name(), 'jax') - - def test_numpy_backend_delegation(self): - # Assert that we are getting JAX's numpy backend. - backend = fastmath.backend() - numpy = fastmath.numpy - self.assertEqual(jnp, backend['np']) - - # Assert that `numpy` calls the appropriate gin configured functions and - # properties. - self.assertTrue(numpy.isinf(numpy.inf)) - self.assertEqual(jnp.isinf, numpy.isinf) - self.assertEqual(jnp.inf, numpy.inf) - - # Assert that we will now get the pure numpy backend. - - self.override_gin("backend.name = 'numpy'") - - backend = fastmath.backend() - numpy = fastmath.numpy - self.assertEqual(onp, backend['np']) - - # Assert that `numpy` calls the appropriate gin configured functions and - # properties. - self.assertTrue(numpy.isinf(numpy.inf)) - self.assertEqual(onp.isinf, numpy.isinf) - self.assertEqual(onp.inf, numpy.inf) - - @parameterized.named_parameters( - ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP)) - def test_fori_loop(self, backend): - with fastmath.use_backend(backend): - res = fastmath.fori_loop(2, 5, lambda i, x: x + i, 1) - self.assertEqual(res, 1 + 2 + 3 + 4) - - def test_nested_map(self): - inp = {'a': ([0, 1], 2), 'b': _TestNamedtuple(3)} - out = {'a': ([1, 2], 3), 'b': _TestNamedtuple(4)} - self.assertEqual(fastmath.nested_map(lambda x: x + 1, inp), out) - - def test_nested_stack(self): - inp = [ - {'a': ([0, 1], 2), 'b': _TestNamedtuple(3)}, - {'a': ([1, 2], 3), 'b': _TestNamedtuple(4)}, - ] - out = {'a': ([[0, 1], [1, 2]], [2, 3]), 'b': _TestNamedtuple([3, 4])} - onp.testing.assert_equal(fastmath.nested_stack(inp), out) - - def test_names_match(self): - # Names match up. - for backend_enum, backend_obj in fastmath.ops._backend_dict.items(): - self.assertEqual(backend_enum.value, backend_obj['name']) - - # Every backend appears in the dictionary. 
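The deleted test above checked two registry invariants that still hold for the new code; a minimal pytest-style sketch preserving them (assuming `trax.fastmath` remains importable):

```python
from trax import fastmath

def test_backend_registry_consistent():
    # Enum values match the 'name' entry of each backend dict.
    for backend_enum, backend_obj in fastmath.ops._backend_dict.items():
        assert backend_enum.value == backend_obj['name']
    # Every Backend enum member has a registered dict.
    for backend_enum in fastmath.ops.Backend:
        assert backend_enum in fastmath.ops._backend_dict
```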
- for backend_enum in fastmath.ops.Backend: - self.assertIn(backend_enum, fastmath.ops._backend_dict) - - def test_use_backend_str(self): - with fastmath.use_backend('tensorflow-numpy'): - self.assertEqual(fastmath.backend_name(), 'tensorflow-numpy') - - def test_use_backend_enum(self): - with fastmath.use_backend(fastmath.Backend.NUMPY): - self.assertEqual(fastmath.backend_name(), 'numpy') - - -if __name__ == '__main__': - test.main() diff --git a/trax/fastmath/tf.py b/trax/fastmath/tf.py index e02ba40a7..827fb1fa4 100644 --- a/trax/fastmath/tf.py +++ b/trax/fastmath/tf.py @@ -18,164 +18,177 @@ import numpy as np import tensorflow.compat.v2 as tf -from trax.shapes import ShapeDtype -from trax.tf_numpy import extensions as tf_np_extensions -from trax.tf_numpy import numpy as tf_np +from trax.tf import extensions as tf_np_extensions +from trax.tf import numpy as tf_np +from trax.utils.shapes import ShapeDtype def tf_abstract_eval(f): - """Returns a function that evaluates `f` given input shapes and dtypes. - - It transforms function `f` to a function that performs the same computation as - `f` but only on shapes and dtypes (a.k.a. shape inference). - - Args: - f: the function to be transformed. - - Returns: - A function whose input arguments can be either the same as `f`'s or only - their shapes/dtypes represented by `ShapeDtype`, and whose return values are - `ShapeDtype`s with the same nested structure as `f`'s return values. - """ - f_shape = tf_np_extensions.eval_on_shapes(f) - def from_shape_type(x): - if isinstance(x, ShapeDtype): - return tf.TensorSpec(x.shape, x.dtype) - else: - return x - def to_shape_type(x): # pylint: disable=missing-docstring - # TODO(wangpeng): handle partial output shapes using `tf.shape`. - def to_numpy_shape(s): - if s.is_fully_defined(): - return tuple(s.as_list()) - else: - raise ValueError("The output shapes (%s) of the dry-run'ed function are" - ' not fully defined.' % s) - def to_numpy_dtype(t): - return np.dtype(t.as_numpy_dtype) - if isinstance(x, tf.TensorSpec): - return ShapeDtype(to_numpy_shape(x.shape), to_numpy_dtype(x.dtype)) - else: - return x - def f_return(*args): - args = tf.nest.map_structure(from_shape_type, args) - res = f_shape(*args) - return tf.nest.map_structure(to_shape_type, res) - return f_return + """Returns a function that evaluates `f` given input shapes and dtypes. + + It transforms function `f` to a function that performs the same computation as + `f` but only on shapes and dtypes (a.k.a. shape inference). + + Args: + f: the function to be transformed. + + Returns: + A function whose input arguments can be either the same as `f`'s or only + their shapes/dtypes represented by `ShapeDtype`, and whose return values are + `ShapeDtype`s with the same nested structure as `f`'s return values. + """ + f_shape = tf_np_extensions.eval_on_shapes(f) + + def from_shape_type(x): + if isinstance(x, ShapeDtype): + return tf.TensorSpec(x.shape, x.dtype) + else: + return x + + def to_shape_type(x): # pylint: disable=missing-docstring + # TODO(wangpeng): handle partial output shapes using `tf.shape`. + def to_numpy_shape(s): + if s.is_fully_defined(): + return tuple(s.as_list()) + else: + raise ValueError( + "The output shapes (%s) of the dry-run'ed function are" + " not fully defined." 
% s + ) + + def to_numpy_dtype(t): + return np.dtype(t.as_numpy_dtype) + + if isinstance(x, tf.TensorSpec): + return ShapeDtype(to_numpy_shape(x.shape), to_numpy_dtype(x.dtype)) + else: + return x + + def f_return(*args): + args = tf.nest.map_structure(from_shape_type, args) + res = f_shape(*args) + return tf.nest.map_structure(to_shape_type, res) + + return f_return # The arguments order is different from tf_np_extensions.uniform def tf_randint(key, shape, minval, maxval, dtype=np.int32): - """Sample uniform random values in [minval, maxval) with given shape/dtype. + """Sample uniform random values in [minval, maxval) with given shape/dtype. - Args: - key: a PRNGKey used as the random key. - shape: a tuple of nonnegative integers representing the shape. - minval: int or array of ints broadcast-compatible with ``shape``, a minimum - (inclusive) value for the range. - maxval: int or array of ints broadcast-compatible with ``shape``, a maximum - (exclusive) value for the range. - dtype: optional, an int dtype for the returned values (default int32). + Args: + key: a PRNGKey used as the random key. + shape: a tuple of nonnegative integers representing the shape. + minval: int or array of ints broadcast-compatible with ``shape``, a minimum + (inclusive) value for the range. + maxval: int or array of ints broadcast-compatible with ``shape``, a maximum + (exclusive) value for the range. + dtype: optional, an int dtype for the returned values (default int32). - Returns: - A random array with the specified shape and dtype. - """ - return tf_np_extensions.uniform(key, shape, minval=minval, maxval=maxval, - dtype=dtype) + Returns: + A random array with the specified shape and dtype. + """ + return tf_np_extensions.uniform( + key, shape, minval=minval, maxval=maxval, dtype=dtype + ) _tf_xla_forced_compile_enabled = False def tf_xla_forced_compile_enabled(): - return _tf_xla_forced_compile_enabled + return _tf_xla_forced_compile_enabled def set_tf_xla_forced_compile(b): - global _tf_xla_forced_compile_enabled - _tf_xla_forced_compile_enabled = b + global _tf_xla_forced_compile_enabled + _tf_xla_forced_compile_enabled = b def _tf_jit(*args, **kwargs): - kwargs['xla_forced_compile'] = tf_xla_forced_compile_enabled() - kwargs.pop('donate_argnums', None) # donate_argnums not used in TF - return tf_np_extensions.jit(*args, **kwargs) + kwargs["xla_forced_compile"] = tf_xla_forced_compile_enabled() + kwargs.pop("donate_argnums", None) # donate_argnums not used in TF + return tf_np_extensions.jit(*args, **kwargs) def _tf_pmap(*args, **kwargs): - kwargs.pop('donate_argnums', None) # donate_argnums not used in TF - return tf_np_extensions.pmap(*args, **kwargs) + kwargs.pop("donate_argnums", None) # donate_argnums not used in TF + return tf_np_extensions.pmap(*args, **kwargs) def _tf_grad(f, **kwargs): - """Grad with support for argnums.""" - argnums = kwargs.pop('argnums', 0) - if argnums != 0: - def g(*args, **kwargs): - args = list(args) - args[0], args[argnums] = args[argnums], args[0] - return f(*args, **kwargs) - else: - g = f - grad_g = tf_np_extensions.grad(g, **kwargs) - if argnums == 0: - return grad_g - def grad_f(*args, **kwargs): - args = list(args) - args[0], args[argnums] = args[argnums], args[0] - return grad_g(*args, **kwargs) - return grad_f + """Grad with support for argnums.""" + argnums = kwargs.pop("argnums", 0) + if argnums != 0: + + def g(*args, **kwargs): + args = list(args) + args[0], args[argnums] = args[argnums], args[0] + return f(*args, **kwargs) + + else: + g = f + grad_g = 
tf_np_extensions.grad(g, **kwargs) + if argnums == 0: + return grad_g + + def grad_f(*args, **kwargs): + args = list(args) + args[0], args[argnums] = args[argnums], args[0] + return grad_g(*args, **kwargs) + + return grad_f def _fold_in(rng, d): - """Equivalent of jax.random.fold_in.""" - # TODO(lukaszkaiser): verify that this function has good randomness - # properties or switch to an implementation equivalent to JAX. - _, rng = tf_np_extensions.split(rng + tf_np.sum(d).astype(tf_np.int64), 2) - return rng + """Equivalent of jax.random.fold_in.""" + # TODO(lukaszkaiser): verify that this function has good randomness + # properties or switch to an implementation equivalent to JAX. + _, rng = tf_np_extensions.split(rng + tf_np.sum(d).astype(tf_np.int64), 2) + return rng TF_BACKEND = { - 'name': 'tensorflow-numpy', - 'np': tf_np, - 'jit': _tf_jit, - 'stop_gradient': tf_np_extensions.stop_gradient, - 'grad': _tf_grad, - 'vjp': tf_np_extensions.vjp, - 'custom_grad': tf_np_extensions.custom_grad, - 'abstract_eval': tf_abstract_eval, - 'expit': tf_np_extensions.expit, - 'erf': tf_np_extensions.erf, - 'index_update': tf_np_extensions.index_update, - 'index_add': tf_np_extensions.index_add, - 'index_min': tf_np_extensions.index_min, - 'index_max': tf_np_extensions.index_max, - 'dynamic_slice': tf_np_extensions.dynamic_slice, - 'dynamic_slice_in_dim': tf_np_extensions.dynamic_slice_in_dim, - 'dynamic_update_slice': tf_np_extensions.dynamic_update_slice, - 'dynamic_update_slice_in_dim': tf_np_extensions.dynamic_update_slice_in_dim, - 'logsumexp': tf_np_extensions.logsumexp, - 'conv': tf_np_extensions.conv, - 'lt': lambda x, y: x < y, - 'avg_pool': tf_np_extensions.avg_pool, - 'max_pool': tf_np_extensions.max_pool, - 'sort_key_val': tf_np_extensions.sort_key_val, - 'random_uniform': tf_np_extensions.uniform, - 'random_randint': tf_randint, - 'random_normal': tf_np_extensions.normal, - 'random_bernoulli': tf_np_extensions.bernoulli, - 'random_get_prng': tf_np_extensions.prng, - 'random_split': tf_np_extensions.split, - 'random_fold_in': _fold_in, + "name": "tensorflow-numpy", + "np": tf_np, + "jit": _tf_jit, + "stop_gradient": tf_np_extensions.stop_gradient, + "grad": _tf_grad, + "vjp": tf_np_extensions.vjp, + "custom_grad": tf_np_extensions.custom_grad, + "abstract_eval": tf_abstract_eval, + "expit": tf_np_extensions.expit, + "erf": tf_np_extensions.erf, + "index_update": tf_np_extensions.index_update, + "index_add": tf_np_extensions.index_add, + "index_min": tf_np_extensions.index_min, + "index_max": tf_np_extensions.index_max, + "dynamic_slice": tf_np_extensions.dynamic_slice, + "dynamic_slice_in_dim": tf_np_extensions.dynamic_slice_in_dim, + "dynamic_update_slice": tf_np_extensions.dynamic_update_slice, + "dynamic_update_slice_in_dim": tf_np_extensions.dynamic_update_slice_in_dim, + "logsumexp": tf_np_extensions.logsumexp, + "conv": tf_np_extensions.conv, + "lt": lambda x, y: x < y, + "avg_pool": tf_np_extensions.avg_pool, + "max_pool": tf_np_extensions.max_pool, + "sort_key_val": tf_np_extensions.sort_key_val, + "random_uniform": tf_np_extensions.uniform, + "random_randint": tf_randint, + "random_normal": tf_np_extensions.normal, + "random_bernoulli": tf_np_extensions.bernoulli, + "random_get_prng": tf_np_extensions.prng, + "random_split": tf_np_extensions.split, + "random_fold_in": _fold_in, # TODO(wangpeng): See whether and how to support `remat` - 'remat': lambda f: f, - 'scan': tf_np_extensions.scan, - 'map': tf_np_extensions.tf_map, + "remat": lambda f: f, + "scan": tf_np_extensions.scan, + 
"map": tf_np_extensions.tf_map, # TODO(wangpeng): can we make extensions ds_as_numpy compatible with data? # 'dataset_as_numpy': tf_np_extensions.dataset_as_numpy, - 'global_device_count': lambda: max(len(tf_np_extensions.accelerators()), 1), - 'local_device_count': lambda: max(len(tf_np_extensions.accelerators()), 1), - 'pmap': _tf_pmap, - 'psum': tf_np_extensions.psum, - 'vmap': tf_np_extensions.vmap, + "global_device_count": lambda: max(len(tf_np_extensions.accelerators()), 1), + "local_device_count": lambda: max(len(tf_np_extensions.accelerators()), 1), + "pmap": _tf_pmap, + "psum": tf_np_extensions.psum, + "vmap": tf_np_extensions.vmap, } diff --git a/trax/intro.ipynb b/trax/intro.ipynb deleted file mode 100644 index 6641e295c..000000000 --- a/trax/intro.ipynb +++ /dev/null @@ -1,757 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "7yuytuIllsv1" - }, - "source": [ - "# Trax Quick Intro\n", - "\n", - "[Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google.com/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information.\n", - "\n", - " 1. **Run a pre-trained Transformer**: create a translator in a few lines of code\n", - " 1. **Features and resources**: [API docs](https://trax-ml.readthedocs.io/en/latest/trax.html), where to [talk to us](https://gitter.im/trax-ml/community), how to [open an issue](https://github.com/google/trax/issues) and more\n", - " 1. **Walkthrough**: how Trax works, how to make new models and train on your own data\n", - "\n", - "We welcome **contributions** to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love **notebooks** that explain how models work and show how to use them to solve problems!\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BIl27504La0G" - }, - "source": [ - "**General Setup**\n", - "\n", - "Execute the following few cells (once) before running any of the code samples." 
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {
-    "executionInfo": {
-     "elapsed": 36794,
-     "status": "ok",
-     "timestamp": 1607149386661,
-     "user": {
-      "displayName": "",
-      "photoUrl": "",
-      "userId": ""
-     },
-     "user_tz": 480
-    },
-    "id": "oILRLCWN_16u"
-   },
-   "outputs": [],
-   "source": [
-    "#@title\n",
-    "# Copyright 2020 Google LLC.\n",
-    "\n",
-    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
-    "# you may not use this file except in compliance with the License.\n",
-    "# You may obtain a copy of the License at\n",
-    "\n",
-    "# https://www.apache.org/licenses/LICENSE-2.0\n",
-    "\n",
-    "# Unless required by applicable law or agreed to in writing, software\n",
-    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
-    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
-    "# See the License for the specific language governing permissions and\n",
-    "# limitations under the License.\n",
-    "\n",
-    "import os\n",
-    "import numpy as np\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {
-    "executionInfo": {
-     "elapsed": 463,
-     "status": "ok",
-     "timestamp": 1607149387132,
-     "user": {
-      "displayName": "",
-      "photoUrl": "",
-      "userId": ""
-     },
-     "user_tz": 480
-    },
-    "id": "vlGjGoGMTt-D",
-    "outputId": "3076e638-695d-4017-e757-98d929630e17"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "/bin/sh: pip: command not found\n"
-     ]
-    }
-   ],
-   "source": [
-    "#@title\n",
-    "# Import Trax\n",
-    "\n",
-    "!pip install -q -U trax\n",
-    "import trax"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "-LQ89rFFsEdk"
-   },
-   "source": [
-    "## 1. Run a pre-trained Transformer\n",
-    "\n",
-    "Here is how you create an English-German translator in a few lines of code:\n",
-    "\n",
-    "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n",
-    "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n",
-    "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n",
-    "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n",
-    "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {
-    "executionInfo": {
-     "elapsed": 46373,
-     "status": "ok",
-     "timestamp": 1607149433512,
-     "user": {
-      "displayName": "",
-      "photoUrl": "",
-      "userId": ""
-     },
-     "user_tz": 480
-    },
-    "id": "djTiSLcaNFGa",
-    "outputId": "a7917337-0a77-4064-8a6e-4e44e4a9c7c7"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Es ist schön, heute neue Dinge zu lernen!\n"
-     ]
-    }
-   ],
-   "source": [
-    "\n",
-    "# Create a Transformer model.\n",
-    "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n",
-    "model = trax.models.Transformer(\n",
-    "    input_vocab_size=33300,\n",
-    "    d_model=512, d_ff=2048,\n",
-    "    n_heads=8, n_encoder_layers=6,
n_decoder_layers=6,\n", - " max_len=2048, mode='predict')\n", - "\n", - "# Initialize using pre-trained weights.\n", - "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", - " weights_only=True)\n", - "\n", - "# Tokenize a sentence.\n", - "sentence = 'It is nice to learn new things today!'\n", - "tokenized = list(trax.data.tokenize(iter([sentence]), # Operates on streams.\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword'))[0]\n", - "\n", - "# Decode from the Transformer.\n", - "tokenized = tokenized[None, :] # Add batch dimension.\n", - "tokenized_translation = trax.supervised.decoding.autoregressive_sample(\n", - " model, tokenized, temperature=0.0) # Higher temperature: more diverse results.\n", - "\n", - "# De-tokenize,\n", - "tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.\n", - "translation = trax.data.detokenize(tokenized_translation,\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword')\n", - "print(translation)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QMo3OnsGgLNK" - }, - "source": [ - "## 2. Features and resources\n", - "\n", - "Trax includes basic models (like [ResNet](https://github.com/google/trax/blob/master/trax/models/resnet.py#L70), [LSTM](https://github.com/google/trax/blob/master/trax/models/rnn.py#L100), [Transformer](https://github.com/google/trax/blob/master/trax/models/transformer.py#L189) and RL algorithms\n", - "(like [REINFORCE](https://github.com/google/trax/blob/master/trax/rl/training.py#L244), [A2C](https://github.com/google/trax/blob/master/trax/rl/actor_critic_joint.py#L458), [PPO](https://github.com/google/trax/blob/master/trax/rl/actor_critic_joint.py#L209)). It is also actively used for research and includes\n", - "new models like the [Reformer](https://github.com/google/trax/tree/master/trax/models/reformer) and new RL algorithms like [AWR](https://arxiv.org/abs/1910.00177). Trax has bindings to a large number of deep learning datasets, including\n", - "[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor) and [TensorFlow datasets](https://www.tensorflow.org/datasets/catalog/overview).\n", - "\n", - "\n", - "You can use Trax either as a library from your own python scripts and notebooks\n", - "or as a binary from the shell, which can be more convenient for training large models.\n", - "It runs without any changes on CPUs, GPUs and TPUs.\n", - "\n", - "* [API docs](https://trax-ml.readthedocs.io/en/latest/)\n", - "* [chat with us](https://gitter.im/trax-ml/community)\n", - "* [open an issue](https://github.com/google/trax/issues)\n", - "* subscribe to [trax-discuss](https://groups.google.com/u/1/g/trax-discuss) for news\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8wgfJyhdihfR" - }, - "source": [ - "## 3. Walkthrough\n", - "\n", - "You can learn here how Trax works, how to create new models and how to train them on your own data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yM12hgQnp4qo" - }, - "source": [ - "### Tensors and Fast Math\n", - "\n", - "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. 
You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n", - "\n", - "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "executionInfo": { - "elapsed": 667, - "status": "ok", - "timestamp": 1607149434186, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "kSauPt0NUl_o", - "outputId": "c7288312-767d-4344-91ae-95ebf386ce57" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "matrix =\n", - "[[1 2 3]\n", - " [4 5 6]\n", - " [7 8 9]]\n", - "vector = [1. 1. 1.]\n", - "product = [12. 15. 18.]\n", - "tanh(product) = [0.99999994 0.99999994 0.99999994]\n" - ] - } - ], - "source": [ - "from trax.fastmath import numpy as fastnp\n", - "trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.\n", - "\n", - "matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n", - "print(f'matrix =\\n{matrix}')\n", - "vector = fastnp.ones(3)\n", - "print(f'vector = {vector}')\n", - "product = fastnp.dot(vector, matrix)\n", - "print(f'product = {product}')\n", - "tanh = fastnp.tanh(product)\n", - "print(f'tanh(product) = {tanh}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "snLYtU6OsKU2" - }, - "source": [ - "Gradients can be calculated using `trax.fastmath.grad`." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "executionInfo": { - "elapsed": 545, - "status": "ok", - "timestamp": 1607149434742, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "cqjYoxPEu8PG", - "outputId": "04739509-9d3a-446d-d088-84882b8917bc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "grad(2x^2) at 1 = 4.0\n", - "grad(2x^2) at -2 = -8.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " return 2.0 * x * x\n", - "\n", - "grad_f = trax.fastmath.grad(f)\n", - "\n", - "print(f'grad(2x^2) at 1 = {grad_f(1.0)}')\n", - "print(f'grad(2x^2) at -2 = {grad_f(-2.0)}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p-wtgiWNseWw" - }, - "source": [ - "### Layers\n", - "\n", - "Layers are basic building blocks of Trax models. You will learn all about them in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) but for now, just take a look at the implementation of one core Trax layer, `Embedding`:\n", - "\n", - "```\n", - "class Embedding(base.Layer):\n", - " \"\"\"Trainable layer that maps discrete tokens/IDs to vectors.\"\"\"\n", - "\n", - " def __init__(self,\n", - " vocab_size,\n", - " d_feature,\n", - " kernel_initializer=init.RandomNormalInitializer(1.0)):\n", - " \"\"\"Returns an embedding layer with given vocabulary size and vector size.\n", - "\n", - " Args:\n", - " vocab_size: Size of the input vocabulary. 
The layer will assign a unique\n", - " vector to each id in `range(vocab_size)`.\n", - " d_feature: Dimensionality/depth of the output vectors.\n", - " kernel_initializer: Function that creates (random) initial vectors for\n", - " the embedding.\n", - " \"\"\"\n", - " super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')\n", - " self._d_feature = d_feature # feature dimensionality\n", - " self._vocab_size = vocab_size\n", - " self._kernel_initializer = kernel_initializer\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Returns embedding vectors corresponding to input token IDs.\n", - "\n", - " Args:\n", - " x: Tensor of token IDs.\n", - "\n", - " Returns:\n", - " Tensor of embedding vectors.\n", - " \"\"\"\n", - " return jnp.take(self.weights, x, axis=0, mode='clip')\n", - "\n", - " def init_weights_and_state(self, input_signature):\n", - " \"\"\"Randomly initializes this layer's weights.\"\"\"\n", - " del input_signature\n", - " shape_w = (self._vocab_size, self._d_feature)\n", - " w = self._kernel_initializer(shape_w, self.rng)\n", - " self.weights = w\n", - "```\n", - "\n", - "Layers with trainable weights like `Embedding` need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "executionInfo": { - "elapsed": 598, - "status": "ok", - "timestamp": 1607149436202, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "4MLSQsIiw9Aw", - "outputId": "394efc9d-9e3c-4f8c-80c2-ce3b5a935e38" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x = [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]\n", - "shape of y = (15, 32)\n" - ] - } - ], - "source": [ - "from trax import layers as tl\n", - "\n", - "# Create an input tensor x.\n", - "x = np.arange(15)\n", - "print(f'x = {x}')\n", - "\n", - "# Create the embedding layer.\n", - "embedding = tl.Embedding(vocab_size=20, d_feature=32)\n", - "embedding.init(trax.shapes.signature(x))\n", - "\n", - "# Run the layer -- y = embedding(x).\n", - "y = embedding(x)\n", - "print(f'shape of y = {y.shape}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MgCPl9ZOyCJw" - }, - "source": [ - "### Models\n", - "\n", - "Models in Trax are built from layers most often using the `Serial` and `Branch` combinators. You can read more about those combinators in the [layers intro](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html) and\n", - "see the code for many models in `trax/models/`, e.g., this is how the [Transformer Language Model](https://github.com/google/trax/blob/master/trax/models/transformer.py#L167) is implemented. Below is an example of how to build a sentiment classification model." 
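[Editor's note] A shape walk-through of the classifier built in the next cell; this runnable sketch mirrors that cell's model (same layers and sizes) and only adds an init plus a shape check, so the shapes are verifiable rather than asserted:

```python
import numpy as np
import trax
from trax import layers as tl

# Same architecture as the notebook's sentiment classifier below.
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),  # (batch, len) -> (batch, len, 256)
    tl.Mean(axis=1),                               # average over length -> (batch, 256)
    tl.Dense(2),                                   # one score per class -> (batch, 2)
)

x = np.zeros((4, 17), dtype=np.int32)  # (batch, length) token IDs
model.init(trax.shapes.signature(x))
print(model(x).shape)                  # -> (4, 2)
```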
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "executionInfo": { - "elapsed": 473, - "status": "ok", - "timestamp": 1607149436685, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "WoSz5plIyXOU", - "outputId": "f94c84c4-3224-4231-8879-4a68f328b89e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serial[\n", - " Embedding_8192_256\n", - " Mean\n", - " Dense_2\n", - "]\n" - ] - } - ], - "source": [ - "model = tl.Serial(\n", - " tl.Embedding(vocab_size=8192, d_feature=256),\n", - " tl.Mean(axis=1), # Average on axis 1 (length of sentence).\n", - " tl.Dense(2), # Classify 2 classes.\n", - ")\n", - "\n", - "# You can print model structure.\n", - "print(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FcnIjFLD0Ju1" - }, - "source": [ - "### Data\n", - "\n", - "To train your model, you need data. In Trax, data streams are represented as python iterators, so you can call `next(data_stream)` and get a tuple, e.g., `(inputs, targets)`. Trax allows you to use [TensorFlow Datasets](https://www.tensorflow.org/datasets) easily and you can also get an iterator from your own text file using the standard `open('my_file.txt')`." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "executionInfo": { - "elapsed": 19863, - "status": "ok", - "timestamp": 1607149456555, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "pKITF1jR0_Of", - "outputId": "44a73b25-668d-4f85-9133-ebb0f5edd191" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(b\"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.\", 0)\n" - ] - } - ], - "source": [ - "train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()\n", - "eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()\n", - "print(next(train_stream)) # See one example." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fRGj4Skm1kL4" - }, - "source": [ - "Using the `trax.data` module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using `trax.data.Serial` and they are functions that you apply to streams to create processed streams." 
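[Editor's note] Since a pipeline is just a function from iterator to iterator, it also composes with hand-written generators. A minimal sketch; the `lowercase` step is hypothetical, added here for illustration, and assumes (text, label) examples as in this notebook:

```python
import trax

def lowercase(stream):
    # Hypothetical extra step: examples are (text, label) pairs here.
    for text, label in stream:
        yield text.lower(), label

pipeline = trax.data.Serial(
    lowercase,
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
)
# pipeline(train_stream) is again a python iterator over processed examples.
```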
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "metadata": {
-    "executionInfo": {
-     "elapsed": 1746,
-     "status": "ok",
-     "timestamp": 1607149458319,
-     "user": {
-      "displayName": "",
-      "photoUrl": "",
-      "userId": ""
-     },
-     "user_tz": 480
-    },
-    "id": "AV5wrgjZ10yU",
-    "outputId": "82b8e3bc-7812-4cd3-a669-401fef29f1c0"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "shapes = [(8, 2048), (8,), (8,)]\n"
-     ]
-    }
-   ],
-   "source": [
-    "data_pipeline = trax.data.Serial(\n",
-    "    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),\n",
-    "    trax.data.Shuffle(),\n",
-    "    trax.data.FilterByLength(max_length=2048, length_keys=[0]),\n",
-    "    trax.data.BucketByLength(boundaries=[ 32, 128, 512, 2048],\n",
-    "                             batch_sizes=[512, 128, 32, 8, 1],\n",
-    "                             length_keys=[0]),\n",
-    "    trax.data.AddLossWeights()\n",
-    "  )\n",
-    "train_batches_stream = data_pipeline(train_stream)\n",
-    "eval_batches_stream = data_pipeline(eval_stream)\n",
-    "example_batch = next(train_batches_stream)\n",
-    "print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "l25krioP2twf"
-   },
-   "source": [
-    "### Supervised training\n",
-    "\n",
-    "When you have the model and the data, use `trax.supervised.training` to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {
-    "executionInfo": {
-     "elapsed": 43631,
-     "status": "ok",
-     "timestamp": 1607149504226,
-     "user": {
-      "displayName": "",
-      "photoUrl": "",
-      "userId": ""
-     },
-     "user_tz": 480
-    },
-    "id": "d6bIKUO-3Cw8",
-    "outputId": "038e6ad5-0d2f-442b-ffa1-ed431dc1d2e0"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "Step 1: Total number of trainable weights: 2097666\n",
-      "Step 1: Ran 1 train steps in 1.15 secs\n",
-      "Step 1: train WeightedCategoryCrossEntropy | 0.69192106\n",
-      "Step 1: eval WeightedCategoryCrossEntropy | 0.69349981\n",
-      "Step 1: eval WeightedCategoryAccuracy | 0.50312500\n",
-      "\n",
-      "Step 500: Ran 499 train steps in 10.62 secs\n",
-      "Step 500: train WeightedCategoryCrossEntropy | 0.50712883\n",
-      "Step 500: eval WeightedCategoryCrossEntropy | 0.42969493\n",
-      "Step 500: eval WeightedCategoryAccuracy | 0.81406250\n",
-      "\n",
-      "Step 1000: Ran 500 train steps in 8.89 secs\n",
-      "Step 1000: train WeightedCategoryCrossEntropy | 0.35916388\n",
-      "Step 1000: eval WeightedCategoryCrossEntropy | 0.41775789\n",
-      "Step 1000: eval WeightedCategoryAccuracy | 0.79531250\n",
-      "\n",
-      "Step 1500: Ran 500 train steps in 9.13 secs\n",
-      "Step 1500: train WeightedCategoryCrossEntropy | 0.35241464\n",
-      "Step 1500: eval WeightedCategoryCrossEntropy | 0.35194683\n",
-      "Step 1500: eval WeightedCategoryAccuracy | 0.85117188\n",
-      "\n",
-      "Step 2000: Ran 500 train steps in 8.54 secs\n",
-      "Step 2000: train WeightedCategoryCrossEntropy | 0.29129386\n",
-      "Step 2000: eval WeightedCategoryCrossEntropy | 0.37591279\n",
-      "Step 2000: eval WeightedCategoryAccuracy | 0.84062500\n"
-     ]
-    }
-   ],
-   "source": [
-    "from trax.supervised import training\n",
-    "\n",
-    "# Training task.\n",
-    "train_task = training.TrainTask(\n",
-    "    labeled_data=train_batches_stream,\n",
-    "    loss_layer=tl.WeightedCategoryCrossEntropy(),\n",
-    "    optimizer=trax.optimizers.Adam(0.01),\n",
-    "    n_steps_per_checkpoint=500,\n",
-    ")\n",
-    "\n",
-    "# Evaluation
task.\n", - "eval_task = training.EvalTask(\n", - " labeled_data=eval_batches_stream,\n", - " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n", - " n_eval_batches=20 # For less variance in eval numbers.\n", - ")\n", - "\n", - "# Training loop saves checkpoints to output_dir.\n", - "output_dir = os.path.expanduser('~/output_dir/')\n", - "!rm -rf {output_dir}\n", - "training_loop = training.Loop(model,\n", - " train_task,\n", - " eval_tasks=[eval_task],\n", - " output_dir=output_dir)\n", - "\n", - "# Run 2000 steps (batches).\n", - "training_loop.run(2000)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-aCkIu3x686C" - }, - "source": [ - "After training the model, run it like any layer to get results." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "executionInfo": { - "elapsed": 1683, - "status": "ok", - "timestamp": 1607149514303, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "yuPu37Lp7GST", - "outputId": "fdc4d832-2f1d-4aee-87b5-9c9dc1238503" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "example input_str: There are a few aspects to Park's movies, and in particular Wallace \u0026 Gromit, that I would say make them so great. The first is subtlety and observation, the flagship of which is the character of Gromit. He doesn't speak, he doesn't make any noise, all he has are his eyes, brow, and body posture, and with these he commands the film. Park manages to give us everything we need from this silent character through his expression. The comedy and the emotion is conveyed through the subtlest of movements and it works superbly well.\u003cbr /\u003e\u003cbr /\u003eWatching the movie you have to be aware of the entire screen. Normally you'll be guided to things in the movies, the screen won't be cluttered too much, there won't be many things to take your eyes away from the main clue or action. Park seems to need to look the other way with his movies. He throws extra content at his audience, there's action in the background, to the side of the screen, even off screen, and there's just about always something in the foreground to catch your eye. His movies are about multiple viewing and discovery, they're layered with jokes and ancillary action.\u003cbr /\u003e\u003cbr /\u003eThroughout this film there are layers of things happening on screen, jokes in the foreground maybe on a jar label and background shadows that give away action. You can imagine that for Park the movies has always been an event, and the movies he loves are ones which he wants to watch again and again. This is what shows in his movies, and in through his most beloved characters.\u003cbr /\u003e\u003cbr /\u003eThen there are the bizarre and wacky inventions which Wallace make, something which is reflected in the storyline and the twists and turns of the plot, everything is bizarre and off the wall, yet it seems so perfectly normal in this world. You can imagine that inside Park is the mind of Wallace.\u003cbr /\u003e\u003cbr /\u003eThere's also one more thing that make these movies so unique, and that's the modelling and precise hand animation. I must admit I was concerned when I knew Dreamworks was involved in the making of this movie, and I thought that they would bring their computer animation experience to the forefront. 
What I was scared of was Wallace \u0026 Gromit becoming CGI entities, or at the smallest, CGI being used to clean up the feel that the modelling brought to the movie.\u003cbr /\u003e\u003cbr /\u003eNot so. You can still see thumbprints and toolmarks on the characters, and far from distracting from the movie, this just adds so much real feeling to it and a feeling of physical depth to the characters and the scene on screen.\u003cbr /\u003e\u003cbr /\u003eSo what of the movie? Well I must say that the plot twist was something I had thought about well before the film was in the cinema and it came as no surprise, but that did not affect my enjoyment one little bit. Actually watching the twist unfold and the comic timing of the discovery and reactions was everything, and it had me just as sucked in as if it was a thriller, yet all the time I was laughing.\u003cbr /\u003e\u003cbr /\u003eWatching the movie was fascinating in various ways. To see the animation completed, how wild the inventions are, how Wallace is going to get into trouble and Gromit get him out, where all the cross references are in the movie, and where all the jokes are! I must admit afterwards talking with my friends I couldn't believe how much I had missed.\u003cbr /\u003e\u003cbr /\u003eThere's something different in this movie than with the others, there's a new level of adult humour in here, and I don't mean rude jokes (although there are a couple that are just so British you can't help laughing), I mean jokes that simply fly over kids heads but slap adults in the face. The kind you are used to seeing come out of somewhere like Pixar. This just adds even more appeal to the movie.\u003cbr /\u003e\u003cbr /\u003eOkay though, let me try and be a bit negative here. I didn't notice the voices in this movie, you know how you usually listen to the actors and see if you can recognise them? Well I was just too wrapped up in the movie to care or to notice who they were...okay, that's not negative. Let me try again. The main plot wasn't as strong and gripping as I'd expected, and I found myself being caught up in the side stories and the characters themselves...again...that's not a bad thing, the film was just so much rich entertainment.\u003cbr /\u003e\u003cbr /\u003eI honestly can't think of a bad thing to say about this movie, probably the worst thing I could say is that the title sequence at the end is quite repetitive...until the final title! Really, that's the worst I can say.\u003cbr /\u003e\u003cbr /\u003eThe story is a lot of fun, well set-up, well written, well executed. There's lot's of fantastic characters in here, not just Wallace \u0026 Gromit. There's so much happening on screen, so many references and jokes (check out the dresses of Lady Tottingham), cheese jokes everywhere, jokes for all the family. The characters are superbly absorbing and you'll find that you've taken to them before you realise. There's just so much in this movie for everyone.\u003cbr /\u003e\u003cbr /\u003eThere's so much I could say and write about, but I know it will quickly turn into a backslapping exercise for Park and Aardman, it would also just turn into a series of \"this bit was really funny\" and \"there's a bit when...\", and what I would rather do is tell you that this is a superb movie, to go see it, and to experience the whole thing for yourselves. 
I will say though that the bunnies are excellent!\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u
003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\u003cpad\u003e\n", - "Model returned sentiment probabilities: [[0.36765265 2.7904649 ]]\n" - ] - } - ], - "source": [ - "example_input = next(eval_batches_stream)[0][0]\n", - "example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')\n", - "print(f'example input_str: {example_input_str}')\n", - "sentiment_log_probs = model(example_input[None, :]) # Add batch dimension.\n", - "print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "last_runtime": { - "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Trax Quick Intro", - "provenance": [ - { - "file_id": "trax/intro.ipynb", - "timestamp": 1595931762204 - }, - { - "file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV", - "timestamp": 1578964243645 - }, - { - "file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07", - "timestamp": 
1572044421118 - }, - { - "file_id": "intro.ipynb", - "timestamp": 1571858674399 - }, - { - "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", - "timestamp": 1569980697572 - }, - { - "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", - "timestamp": 1563927451951 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/jaxboard.py b/trax/jaxboard.py deleted file mode 100644 index c160c63fa..000000000 --- a/trax/jaxboard.py +++ /dev/null @@ -1,360 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Write Summaries from JAX for use with Tensorboard. - -See jaxboard_demo.py for example usage. -""" -import io -import struct -import time -import warnings -import wave -import matplotlib as mpl -# Necessary to prevent attempted Tk import: -with warnings.catch_warnings(): - warnings.simplefilter('ignore') - mpl.use('Agg') -# pylint: disable=g-import-not-at-top -import matplotlib.pyplot as plt -import numpy as np -import tensorflow as tf - -# pylint: disable=g-direct-tensorflow-import -from tensorflow.core.util import event_pb2 -from tensorflow.python.summary.writer.event_file_writer import EventFileWriter -# pylint: enable=g-direct-tensorflow-import - - -def _pack_images(images, rows, cols): - """Helper utility to make a tiled field of images from numpy arrays. - - Args: - images: Image tensor in shape [N, W, H, C]. - rows: Number of images per row in tiled image. - cols: Number of images per column in tiled image. - - Returns: - A tiled image of shape [W * rows, H * cols, C]. - Truncates incomplete rows. - """ - shape = np.shape(images) - width, height, depth = shape[-3:] - images = np.reshape(images, (-1, width, height, depth)) - batch = np.shape(images)[0] - rows = np.minimum(rows, batch) - cols = np.minimum(batch // rows, cols) - images = images[:rows * cols] - images = np.reshape(images, (rows, cols, width, height, depth)) - images = np.transpose(images, [0, 2, 1, 3, 4]) - images = np.reshape(images, [rows * width, cols * height, depth]) - return images - - -class SummaryWriter: - """Saves data in event and summary protos for tensorboard.""" - - def __init__(self, log_dir, enable=True): - """Create a new SummaryWriter. - - Args: - log_dir: path to record tfevents files in. - enable: bool: if False don't actually write or flush data. Used in - multihost training. - """ - # If needed, create log_dir directory as well as missing parent directories. - if not tf.io.gfile.isdir(log_dir): - tf.io.gfile.makedirs(log_dir) - - self._event_writer = EventFileWriter(log_dir, 10, 120, None) - self._step = 0 - self._closed = False - self._enabled = enable - - def add_summary(self, summary, step): - if not self._enabled: - return - event = event_pb2.Event(summary=summary) - event.wall_time = time.time() - if step is not None: - event.step = int(step) - self._event_writer.add_event(event) - - def close(self): - """Close SummaryWriter. 
Final!""" - if not self._closed: - self._event_writer.close() - self._closed = True - del self._event_writer - - def __del__(self): # safe? - # TODO(afrozm): Sometimes this complains with - # `TypeError: 'NoneType' object is not callable` - try: - self.close() - except Exception: # pylint: disable=broad-except - pass - - def flush(self): - if not self._enabled: - return - self._event_writer.flush() - - def scalar(self, tag, value, step=None): - """Saves scalar value. - - Args: - tag: str: label for this data - value: int/float: number to log - step: int: training step - """ - value = float(np.array(value)) - if step is None: - step = self._step - else: - self._step = step - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]) - self.add_summary(summary, step) - - def image(self, tag, image, step=None): - """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3]. - - Args: - tag: str: label for this data - image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors/ - step: int: training step - """ - image = np.array(image) - if step is None: - step = self._step - else: - self._step = step - if len(np.shape(image)) == 2: - image = image[:, :, np.newaxis] - if np.shape(image)[-1] == 1: - image = np.repeat(image, 3, axis=-1) - image_strio = io.BytesIO() - plt.imsave(image_strio, image, format='png') - image_summary = tf.compat.v1.Summary.Image( - encoded_image_string=image_strio.getvalue(), - colorspace=3, - height=image.shape[0], - width=image.shape[1]) - summary = tf.compat.v1.Summary( - value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)]) - self.add_summary(summary, step) - - def images(self, tag, images, step=None, rows=None, cols=None): - """Saves (rows, cols) tiled images from np.ndarray. - - If either rows or cols aren't given, they are determined automatically - from the size of the image batch, if neither are given a long column - of images is produced. This truncates the image batch rather than padding - if it doesn't fill the final row. - - Args: - tag: str: label for this data - images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d - step: int: training step - rows: int: number of rows in tile - cols: int: number of columns in tile - """ - images = np.array(images) - if step is None: - step = self._step - else: - self._step = step - n_images = np.shape(images)[0] - if rows is None and cols is None: - rows = 1 - cols = n_images - elif rows is None: - rows = n_images // cols - elif cols is None: - cols = n_images // rows - tiled_images = _pack_images(images, rows, cols) - self.image(tag, tiled_images, step=step) - - def plot(self, tag, mpl_plt, step=None, close_plot=True): - """Saves matplotlib plot output to summary image. 
-
-    Args:
-      tag: str: label for this data
-      mpl_plt: matplotlib stateful pyplot object with prepared plotting state
-      step: int: training step
-      close_plot: bool: automatically closes plot
-    """
-    if step is None:
-      step = self._step
-    else:
-      self._step = step
-    fig = mpl_plt.get_current_fig_manager()
-    img_w, img_h = fig.canvas.get_width_height()
-    image_buf = io.BytesIO()
-    mpl_plt.savefig(image_buf, format='png')
-    image_summary = tf.compat.v1.Summary.Image(
-        encoded_image_string=image_buf.getvalue(),
-        colorspace=4,  # RGBA
-        height=img_h,
-        width=img_w)
-    summary = tf.compat.v1.Summary(
-        value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
-    self.add_summary(summary, step)
-    if close_plot:
-      mpl_plt.close()
-
-  def audio(self, tag, audiodata, step=None, sample_rate=44100):
-    """Saves audio.
-
-    NB: single channel only right now.
-
-    Args:
-      tag: str: label for this data
-      audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave
-      step: int: training step
-      sample_rate: sample rate of passed in audio buffer
-    """
-    audiodata = np.array(audiodata)
-    if step is None:
-      step = self._step
-    else:
-      self._step = step
-    audiodata = np.clip(np.squeeze(audiodata), -1, 1)
-    if audiodata.ndim != 1:
-      raise ValueError('Audio data must be 1D.')
-    sample_list = (32767.0 * audiodata).astype(int).tolist()
-    wio = io.BytesIO()
-    wav_buf = wave.open(wio, 'wb')
-    wav_buf.setnchannels(1)
-    wav_buf.setsampwidth(2)
-    wav_buf.setframerate(sample_rate)
-    enc = b''.join([struct.pack('<h', v) for v in sample_list])
-    wav_buf.writeframes(enc)
-    wav_buf.close()
-    encoded_audio_bytes = wio.getvalue()
-    wio.close()
-    audio = tf.compat.v1.Summary.Audio(
-        sample_rate=sample_rate,
-        num_channels=1,
-        length_frames=len(sample_list),
-        encoded_audio_string=encoded_audio_bytes,
-        content_type='audio/wav')
-    summary = tf.compat.v1.Summary(
-        value=[tf.compat.v1.Summary.Value(tag=tag, audio=audio)])
-    self.add_summary(summary, step)
-
-  def histogram(self, tag, values, bins, step=None):
-    """Saves histogram of values.
-
-    Args:
-      tag: str: label for this data
-      values: ndarray: will be flattened by this routine
-      bins: number of bins in histogram, or array of bins for np.histogram
-      step: int: training step
-    """
-    if step is None:
-      step = self._step
-    else:
-      self._step = step
-    values = np.array(values)
-    bins = np.array(bins)
-    values = np.reshape(values, -1)
-    counts, limits = np.histogram(values, bins=bins)
-    # boundary logic
-    cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
-    start, end = np.searchsorted(
-        cum_counts, [0, cum_counts[-1] - 1], side='right')
-    start, end = int(start), int(end) + 1
-    counts = (
-        counts[start - 1:end]
-        if start > 0 else np.concatenate([[0], counts[:end]]))
-    limits = limits[start:end + 1]
-    sum_sq = values.dot(values)
-    histo = tf.compat.v1.HistogramProto(
-        min=values.min(),
-        max=values.max(),
-        num=len(values),
-        sum=values.sum(),
-        sum_squares=sum_sq,
-        bucket_limit=limits.tolist(),
-        bucket=counts.tolist())
-    summary = tf.compat.v1.Summary(
-        value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)])
-    self.add_summary(summary, step)
-
-  def text(self, tag, textdata, step=None):
-    """Saves a text summary.
-
-    Args:
-      tag: str: label for this data
-      textdata: string, or 1D/2D list/numpy array of strings
-      step: int: training step
-    Note: markdown formatting is rendered by tensorboard.
-    """
-    if step is None:
-      step = self._step
-    else:
-      self._step = step
-    smd = tf.compat.v1.SummaryMetadata(
-        plugin_data=tf.compat.v1.SummaryMetadata.PluginData(plugin_name='text'))
-    if isinstance(textdata, (str, bytes)):
-      tensor = tf.make_tensor_proto(
-          values=[textdata.encode(encoding='utf_8')], shape=(1,))
-    else:
-      textdata = np.array(textdata)  # convert lists, jax arrays, etc.
-      datashape = np.shape(textdata)
-      if len(datashape) == 1:
-        tensor = tf.make_tensor_proto(
-            values=[td.encode(encoding='utf_8') for td in textdata],
-            shape=(datashape[0],))
-      elif len(datashape) == 2:
-        tensor = tf.make_tensor_proto(
-            values=[
-                td.encode(encoding='utf_8') for td in np.reshape(textdata, -1)
-            ],
-            shape=(datashape[0], datashape[1]))
-    summary = tf.compat.v1.Summary(
-        value=[tf.compat.v1.Summary.Value(
-            tag=tag, metadata=smd, tensor=tensor)])
-    self.add_summary(summary, step)
-
-
-# Copied from gin/tf/utils.py:GinConfigSaverHook
-def markdownify_operative_config_str(string):
-  """Convert an operative config string to markdown format."""
-
-  # TODO(b/37527917): Total hack below. Implement more principled formatting.
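[Editor's note] For reference while reading `process()` just below, a sketch of the mapping it implements on a typical gin operative-config fragment; the input string is illustrative, the rules come from the function body:

```python
# Illustrative input for markdownify_operative_config_str (editor's sketch).
config = ("# Parameters for Adam:\n"
          "# ==============================\n"
          "Adam.b1 = 0.9")
# The function would return, line by line:
#   "#### Parameters for Adam:"   # '# X:' lines become h4 headings
#   ""                            # '# ====' separator lines become empty
#   "    Adam.b1 = 0.9"           # plain lines gain a 4-space code indent
```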
- def process(line): - """Convert a single line to markdown format.""" - if not line.startswith('#'): - return ' ' + line - - line = line[2:] - if line.startswith('===='): - return '' - if line.startswith('None'): - return ' # None.' - if line.endswith(':'): - return '#### ' + line - return line - - output_lines = [] - for line in string.splitlines(): - procd_line = process(line) - if procd_line is not None: - output_lines.append(procd_line) - - return '\n'.join(output_lines) diff --git a/trax/layers/__init__.py b/trax/layers/__init__.py index 3913fdafc..ce8f2f83e 100644 --- a/trax/layers/__init__.py +++ b/trax/layers/__init__.py @@ -16,6 +16,7 @@ """Layers: trainable functions as neural network building blocks.""" import gin + # We create a flat layers.* namespace for uniform calling conventions as we # upstream changes. # pylint: disable=wildcard-import @@ -44,8 +45,9 @@ # Ginify def layer_configure(*args, **kwargs): - kwargs['module'] = 'trax.layers' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.layers" + return gin.external_configurable(*args, **kwargs) + # pylint: disable=used-before-assignment # pylint: disable=invalid-name @@ -69,41 +71,44 @@ def layer_configure(*args, **kwargs): FilterResponseNorm = layer_configure(FilterResponseNorm) ThresholdedLinearUnit = layer_configure(ThresholdedLinearUnit) -Attention = layer_configure(Attention, denylist=['mode']) -CausalAttention = layer_configure(CausalAttention, denylist=['mode']) -FavorAttention = layer_configure(FavorAttention, denylist=['mode']) -Favor = layer_configure(Favor, denylist=['mode']) -CausalFavor = layer_configure(CausalFavor, denylist=['mode']) -CausalFavorAttention = layer_configure(CausalFavorAttention, denylist=['mode']) +Attention = layer_configure(Attention, denylist=["mode"]) +CausalAttention = layer_configure(CausalAttention, denylist=["mode"]) +FavorAttention = layer_configure(FavorAttention, denylist=["mode"]) +Favor = layer_configure(Favor, denylist=["mode"]) +CausalFavor = layer_configure(CausalFavor, denylist=["mode"]) +CausalFavorAttention = layer_configure(CausalFavorAttention, denylist=["mode"]) DotProductCausalAttention = layer_configure( - DotProductCausalAttention, denylist=['mode']) -SelfAttention = layer_configure(SelfAttention, denylist=['mode']) -ModularCausalAttention = layer_configure(ModularCausalAttention, - denylist=['mode']) -LowRankCausalAttention = layer_configure(LowRankCausalAttention, - denylist=['mode']) -MultiplicativeCausalAttention = layer_configure(MultiplicativeCausalAttention, - denylist=['mode']) + DotProductCausalAttention, denylist=["mode"] +) +SelfAttention = layer_configure(SelfAttention, denylist=["mode"]) +ModularCausalAttention = layer_configure(ModularCausalAttention, denylist=["mode"]) +LowRankCausalAttention = layer_configure(LowRankCausalAttention, denylist=["mode"]) +MultiplicativeCausalAttention = layer_configure( + MultiplicativeCausalAttention, denylist=["mode"] +) MultiplicativeModularCausalAttention = layer_configure( - MultiplicativeModularCausalAttention, denylist=['mode']) -ConvCausalAttention = layer_configure(ConvCausalAttention, denylist=['mode']) + MultiplicativeModularCausalAttention, denylist=["mode"] +) +ConvCausalAttention = layer_configure(ConvCausalAttention, denylist=["mode"]) MultiplicativeConvCausalAttention = layer_configure( - MultiplicativeConvCausalAttention, denylist=['mode']) + MultiplicativeConvCausalAttention, denylist=["mode"] +) ConvTranspose = layer_configure(ConvTranspose) -LSHSelfAttention = 
layer_configure(LSHSelfAttention, denylist=['mode']) -PureLSHSelfAttention = layer_configure(PureLSHSelfAttention, denylist=['mode']) -MixedLSHSelfAttention = layer_configure( - MixedLSHSelfAttention, denylist=['mode']) +LSHSelfAttention = layer_configure(LSHSelfAttention, denylist=["mode"]) +PureLSHSelfAttention = layer_configure(PureLSHSelfAttention, denylist=["mode"]) +MixedLSHSelfAttention = layer_configure(MixedLSHSelfAttention, denylist=["mode"]) PureLSHSelfAttentionWrapper = layer_configure( - PureLSHSelfAttentionWrapper, denylist=['mode']) -EncDecAttention = layer_configure(EncDecAttention, denylist=['mode']) + PureLSHSelfAttentionWrapper, denylist=["mode"] +) +EncDecAttention = layer_configure(EncDecAttention, denylist=["mode"]) -PositionalEncoding = layer_configure( - PositionalEncoding, denylist=['mode']) +PositionalEncoding = layer_configure(PositionalEncoding, denylist=["mode"]) InfinitePositionalEncoding = layer_configure( - InfinitePositionalEncoding, denylist=['mode']) + InfinitePositionalEncoding, denylist=["mode"] +) TimeBinPositionalEncoding = layer_configure( - TimeBinPositionalEncoding, denylist=['mode']) + TimeBinPositionalEncoding, denylist=["mode"] +) AtariConvInit = layer_configure(AtariConvInit) CrossEntropyLossWithLogSoftmax = layer_configure(CrossEntropyLossWithLogSoftmax) diff --git a/trax/layers/acceleration.py b/trax/layers/acceleration.py index 57fd7ffe5..7228976a2 100644 --- a/trax/layers/acceleration.py +++ b/trax/layers/acceleration.py @@ -17,253 +17,267 @@ import jax import numpy as np + from trax import fastmath from trax.fastmath import numpy as jnp from trax.layers import base class Accelerate(base.Layer): - """Accelerates a layer, running in data-parallel way on multiple devices. - - By default it uses all available accelerators, splits the input on the - first (batch) axis, and runs each part on the corresponding accelerator. - If only one accelerator is available, this layer JIT-compiles the underlying - layer and in this way makes it run faster. - - The output is guaranteed to be the same as the output of the original layer - if the batch dimension is divisible by the number of devices. If it is not, - then 0-padding is added to make it divisible and the output may be affected - if it relies on layers like batch normalization. - - This layer does not require calling ``init`` if the underlying layer has - already been initialized, so it can be used as follows:: - - layer = tl.Serial(...) - layer.init(...) - fast_layer = tl.Accelerate(layer) - y = fast_layer(x) # Split x on batch and run data-parallel - - In case the weights of this layer need to be set using the weights of - the sublayer, use the ``replicate_weights`` function:: - - # Instead of layer.weights = new_weights: - fast_layer.replicate_weights(new_weights) - - """ - - def __init__(self, layer, n_devices=None): - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - self._n_devices = n_devices or fastmath.local_device_count() - self._jit_pure_fn = jit_forward( - layer.pure_fn, self._n_devices, do_mean=False) - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - def pure_fn(self, x, weights, state, rng, use_cache=False): - """Calls ``self.sublayer.pure_fn`` in an accelerated way.""" - # Check if we can divide x evenly across devices. - # Note: x can be a list/tuple because the underlying layer may take - # its input as a list/tuple, ex: (inputs, targets, weight). 
- if isinstance(x, (list, tuple)): - remainder = x[0].shape[0] % self._n_devices - else: - remainder = x.shape[0] % self._n_devices - if remainder == 0: # If yes, run the accelerated sublayer.pure_fn. - return self._jit_pure_fn(x, weights, state, rng) - # If not, pad first. - def pad(z): - pad_widths = [(0, 0)] * len(z.shape) - pad_widths[0] = (0, self._n_devices - remainder) - return jnp.pad(z, pad_widths, mode='constant', - constant_values=z.dtype.type(0)) - padded_x = [pad(z) for z in x] if isinstance(x, (list, tuple)) else pad(x) - # Run and un-pad. - padded_y, state = self._jit_pure_fn(padded_x, weights, state, rng) - if isinstance(x, (list, tuple)): - y = tuple(padded_z[:z.shape[0]] for (padded_z, z) in zip(padded_y, x)) - y = list(y) if isinstance(x, list) else y - else: - y = padded_y[:x.shape[0]] - return y, state - - def _prepare_weights(self, weights): - """Replicate or shard weights for the number of devices requested.""" - if base.N_WEIGHTS_SHARDS > 1: - if base.N_WEIGHTS_SHARDS % self._n_devices != 0: - raise ValueError(f'Number of shards ({base.N_WEIGHTS_SHARDS}) must ' - f'be a multiple of n_devices ({self._n_devices}).') - return base.shard(weights, base.N_WEIGHTS_SHARDS) + """Accelerates a layer, running in a data-parallel way on multiple devices. + + By default it uses all available accelerators, splits the input on the + first (batch) axis, and runs each part on the corresponding accelerator. + If only one accelerator is available, this layer JIT-compiles the underlying + layer and in this way makes it run faster. + + The output is guaranteed to be the same as the output of the original layer + if the batch dimension is divisible by the number of devices. If it is not, + then 0-padding is added to make it divisible and the output may be affected + if it relies on layers like batch normalization. + + This layer does not require calling ``init`` if the underlying layer has + already been initialized, so it can be used as follows:: + + layer = tl.Serial(...) + layer.init(...) + fast_layer = tl.Accelerate(layer) + y = fast_layer(x) # Split x on batch and run data-parallel + + In case the weights of this layer need to be set using the weights of + the sublayer, use the ``replicate_weights`` function:: + + # Instead of layer.weights = new_weights: + fast_layer.replicate_weights(new_weights) + + """ + + def __init__(self, layer, n_devices=None): + super().__init__(n_in=layer.n_in, n_out=layer.n_out) + self._sublayers = [layer] + self._n_devices = n_devices or fastmath.local_device_count() + self._jit_pure_fn = jit_forward(layer.pure_fn, self._n_devices, do_mean=False) + + @property + def sublayer(self): + """Returns the unique sublayer managed by this layer.""" + return self._sublayers[0] + + def pure_fn(self, x, weights, state, rng, use_cache=False): + """Calls ``self.sublayer.pure_fn`` in an accelerated way.""" + # Check if we can divide x evenly across devices. + # Note: x can be a list/tuple because the underlying layer may take + # its input as a list/tuple, ex: (inputs, targets, weight). + if isinstance(x, (list, tuple)): + remainder = x[0].shape[0] % self._n_devices + else: + remainder = x.shape[0] % self._n_devices + if remainder == 0: # If yes, run the accelerated sublayer.pure_fn. + return self._jit_pure_fn(x, weights, state, rng) + # If not, pad first. 
+ def pad(z): + pad_widths = [(0, 0)] * len(z.shape) + pad_widths[0] = (0, self._n_devices - remainder) + return jnp.pad( + z, pad_widths, mode="constant", constant_values=z.dtype.type(0) + ) + + padded_x = [pad(z) for z in x] if isinstance(x, (list, tuple)) else pad(x) + # Run and un-pad. + padded_y, state = self._jit_pure_fn(padded_x, weights, state, rng) + if isinstance(x, (list, tuple)): + y = tuple(padded_z[: z.shape[0]] for (padded_z, z) in zip(padded_y, x)) + y = list(y) if isinstance(x, list) else y + else: + y = padded_y[: x.shape[0]] + return y, state + + def _prepare_weights(self, weights): + """Replicate or shard weights for the number of devices requested.""" + if base.N_WEIGHTS_SHARDS > 1: + if base.N_WEIGHTS_SHARDS % self._n_devices != 0: + raise ValueError( + f"Number of shards ({base.N_WEIGHTS_SHARDS}) must " + f"be a multiple of n_devices ({self._n_devices})." + ) + return base.shard(weights, base.N_WEIGHTS_SHARDS) + else: + return for_n_devices(weights, self._n_devices) + + def init(self, input_signature): + """Calls ``self.sublayer.init`` and replicates its values onto devices.""" + weights, state = self.sublayer.init(input_signature, use_cache=True) + self._weights = self._prepare_weights(weights) + self._state = for_n_devices(state, self._n_devices) + return (self.weights, self.state) + + def replicate_weights(self, weights): + """Sets the weights of the sublayer and replicates them for this layer.""" + self.sublayer.weights = weights + self._weights = self._prepare_weights(weights) + + def replicate_state(self, state): + """Sets the state of the sublayer and replicates it for this layer.""" + self.sublayer.state = state + self._state = for_n_devices(state, self._n_devices) + + def _unreplicate(self, x): + """Returns a single-device version of ``x``.""" + if self._n_devices < 2: + return x + return fastmath.nested_map(lambda y: y[0], x) + + @property + def weights(self): + # Override the getter so it works even if only sublayer is initialized. + if self._weights is base.EMPTY_WEIGHTS: + self._weights = self._prepare_weights(self.sublayer.weights) + return self._weights + + @weights.setter + def weights(self, weights): + self._weights = weights + self.sublayer.weights = self._unreplicate(weights) + + @property + def state(self): + # Override the getter so it works even if only sublayer is initialized. + if self._state is base.EMPTY_STATE: + self._state = for_n_devices(self.sublayer.state, self._n_devices) + return self._state + + @state.setter + def state(self, state): + self._state = state + self.sublayer.state = self._unreplicate(state) + + +def mean(n_devices, x, axis=None): + """Computes the mean of a distributed value ``x``. + + Args: + n_devices: Number of devices. + x: Distributed array. + axis: Axis along which to compute means; can only be ``0`` or ``None``. + + Returns: + A local array. 
+ """ + if fastmath.backend_name() == "tensorflow-numpy" and n_devices > 1: + if axis not in (None, 0): + raise ValueError("axis can only be None or 0") + x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices + if axis is None: + x = jnp.mean(x) + return x else: - return for_n_devices(weights, self._n_devices) - - def init(self, input_signature): - """Calls ``self.sublayer.init`` and replicates its values onto devices.""" - weights, state = self.sublayer.init(input_signature, use_cache=True) - self._weights = self._prepare_weights(weights) - self._state = for_n_devices(state, self._n_devices) - return (self.weights, self.state) - - def replicate_weights(self, weights): - """Sets the weights of the sublayer and replicates them for this layer.""" - self.sublayer.weights = weights - self._weights = self._prepare_weights(weights) - - def replicate_state(self, state): - """Sets the state of the sublayer and replicates it for this layer.""" - self.sublayer.state = state - self._state = for_n_devices(state, self._n_devices) - - def _unreplicate(self, x): - """Returns a single-device version of ``x``.""" - if self._n_devices < 2: - return x - return fastmath.nested_map(lambda y: y[0], x) - - @property - def weights(self): - # Override the getter so it works even if only sublayer is initialized. - if self._weights is base.EMPTY_WEIGHTS: - self._weights = self._prepare_weights(self.sublayer.weights) - return self._weights - - @weights.setter - def weights(self, weights): - self._weights = weights - self.sublayer.weights = self._unreplicate(weights) - - @property - def state(self): - # Override the getter so it works even if only sublayer is initialized. - if self._state is base.EMPTY_STATE: - self._state = for_n_devices(self.sublayer.state, self._n_devices) - return self._state - - @state.setter - def state(self, state): - self._state = state - self.sublayer.state = self._unreplicate(state) - - -# TODO(jonni): Rename, since implementation does not use pmean. -def mean_or_pmean(n_devices, x, axis=None): - """Computes the mean of a distributed value ``x``. - - Args: - n_devices: Number of devices. - x: Distributed array. - axis: Axis along which to compute means; can only be ``0`` or ``None``. - - Returns: - A local array. 
- """ - if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1: - if axis not in (None, 0): - raise ValueError('axis can only be None or 0') - x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices - if axis is None: - x = jnp.mean(x) - return x - else: - return jnp.mean(x, axis=axis) + return jnp.mean(x, axis=axis) def jit_forward(forward, n_devices, do_mean=True): - """Returns a JIT-compiled forward function running on ``n_devices``.""" - model_predict = _accelerate(forward, n_devices) - # n_devices == 0 => CPU - if n_devices < 2: - return model_predict - - def predict(x, weights, state, rng): - """Predict function JIT-compiled and parallelized as requested.""" - res, state = model_predict( - reshape_by_device(x, n_devices), - weights, - state, - jnp.stack(fastmath.random.split(rng, n_devices))) - res = _combine_devices(res) - if do_mean: - return fastmath.nested_map( - lambda y: mean_or_pmean(n_devices, y, axis=0), res), state - else: - return res, state - - return predict + """Returns a JIT-compiled forward function running on ``n_devices``.""" + model_predict = _accelerate(forward, n_devices) + # n_devices == 0 => CPU + if n_devices < 2: + return model_predict + + def predict(x, weights, state, rng): + """Predict function JIT-compiled and parallelized as requested.""" + res, state = model_predict( + reshape_by_device(x, n_devices), + weights, + state, + jnp.stack(fastmath.random.split(rng, n_devices)), + ) + res = _combine_devices(res) + if do_mean: + return ( + fastmath.nested_map(lambda y: mean(n_devices, y, axis=0), res), + state, + ) + else: + return res, state + + return predict def _combine_devices(x_tuple): - """Combines multi-device tensors into a single batch.""" - def f(x): - if len(x.shape) < 2: - return x # No extra batch dimension: use devices as batch, so return. - batch_size = x.shape[0] * x.shape[1] - return jnp.reshape(x, [batch_size] + list(x.shape[2:])) - return fastmath.nested_map(f, x_tuple) + """Combines multi-device tensors into a single batch.""" + + def f(x): + if len(x.shape) < 2: + return x # No extra batch dimension: use devices as batch, so return. 
+ batch_size = x.shape[0] * x.shape[1] + return jnp.reshape(x, [batch_size] + list(x.shape[2:])) + + return fastmath.nested_map(f, x_tuple) def _accelerate(f, n_devices): - """Returns an accelerated version of ``f`` running on ``n_devices``.""" - if n_devices == 0: # no accelerators - run on CPU - return fastmath.jit(f, device=jax.devices('cpu')[0]) + """Returns an accelerated version of ``f`` running on ``n_devices``.""" + if n_devices == 0: # no accelerators - run on CPU + return fastmath.jit(f, device=jax.devices("cpu")[0]) - if n_devices == 1: - return fastmath.jit(f) + if n_devices == 1: + return fastmath.jit(f) - return fastmath.pmap(f, axis_name='batch') + return fastmath.pmap(f, axis_name="batch") def reshape_by_device(x, n_devices, pure_np=False): - """Reshapes possibly nested ``x`` into a shape ``(n_devices, ...)``.""" - def f(x): - x_shape = list(x.shape) - batch_size = x_shape[0] - batch_size_per_device = batch_size // n_devices - if batch_size_per_device * n_devices != batch_size: - raise ValueError(f'Number of devices ({n_devices}) does not evenly ' - f'divide batch size ({batch_size}).') - new_shape_prefix = [n_devices, batch_size_per_device] - if pure_np: - return np.reshape(x, new_shape_prefix + x_shape[1:]) - else: - return jnp.reshape(x, new_shape_prefix + x_shape[1:]) - return fastmath.nested_map(f, x) + """Reshapes possibly nested ``x`` into a shape ``(n_devices, ...)``.""" + + def f(x): + x_shape = list(x.shape) + batch_size = x_shape[0] + batch_size_per_device = batch_size // n_devices + if batch_size_per_device * n_devices != batch_size: + raise ValueError( + f"Number of devices ({n_devices}) does not evenly " + f"divide batch size ({batch_size})." + ) + new_shape_prefix = [n_devices, batch_size_per_device] + if pure_np: + return np.reshape(x, new_shape_prefix + x_shape[1:]) + else: + return jnp.reshape(x, new_shape_prefix + x_shape[1:]) + + return fastmath.nested_map(f, x) def for_n_devices(x, n_devices): - """Replicates/broadcasts ``x`` for ``n_devices``.""" - def f(x): - if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX): - return jax.device_put_replicated(x, jax.local_devices()) - elif n_devices > 1: - return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape) - else: - return x - return fastmath.nested_map(f, x) + """Replicates/broadcasts ``x`` for ``n_devices``.""" + + def f(x): + if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX): + return jax.device_put_replicated(x, jax.local_devices()) + elif n_devices > 1: + return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape) + else: + return x + + return fastmath.nested_map(f, x) def on_cpu(x): - """Puts ``x`` in CPU memory in JAX.""" - if fastmath.is_backend(fastmath.Backend.JAX): - return jax.device_put(x, jax.devices('cpu')[0]) - else: - return x + """Puts ``x`` in CPU memory in JAX.""" + if fastmath.is_backend(fastmath.Backend.JAX): + return jax.device_put(x, jax.devices("cpu")[0]) + else: + return x def on_accelerator(x): - """Puts ``x`` in (single) accelerator memory in JAX.""" - try: - accelerator_devices = jax.devices('gpu') - except RuntimeError: + """Puts ``x`` in (single) accelerator memory in JAX.""" try: - accelerator_devices = jax.devices('tpu') + accelerator_devices = jax.devices("gpu") except RuntimeError: - accelerator_devices = [] - if not accelerator_devices: - return x - if len(accelerator_devices) != 1: - return x - return jax.device_put(x, accelerator_devices[0]) + try: + accelerator_devices = jax.devices("tpu") + except RuntimeError: + accelerator_devices = 
[] + if not accelerator_devices: + return x + if len(accelerator_devices) != 1: + return x + return jax.device_put(x, accelerator_devices[0]) diff --git a/trax/layers/acceleration_test.py b/trax/layers/acceleration_test.py deleted file mode 100644 index 57002474d..000000000 --- a/trax/layers/acceleration_test.py +++ /dev/null @@ -1,102 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for acceleration.""" - -from absl.testing import absltest - -from jax.config import config -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes - - -class AccelerationTest(absltest.TestCase): - - def test_accelerated_same_result(self): - layer = tl.Dense(2) - x = np.random.uniform(size=(8, 7)) - layer.init(shapes.signature(x)) - y = layer(x) - z = tl.Accelerate(layer)(x) - for i in range(8): - self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) - self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) - - def test_accelerated_pad(self): - layer = tl.Dense(2) - x = np.random.uniform(size=(3, 7)) - layer.init(shapes.signature(x)) - y = layer(x) - z = tl.Accelerate(layer)(x) - self.assertEqual(z.shape, y.shape) - for i in range(3): - self.assertAlmostEqual(float(y[i, 0]), float(z[i, 0]), places=4) - self.assertAlmostEqual(float(y[i, 1]), float(z[i, 1]), places=4) - - def test_accelerated_weighted_category_accuracy(self): - """Test multi-device aggregation of weights.""" - layer = tl.Accelerate(tl.WeightedCategoryAccuracy()) - weights = np.array([1., 1., 1., 0.]) - targets = np.array([0, 1, 2, 3]) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(np.mean(accuracy), 1 / 3) - - def test_chunk_memory(self): - """Test chunking here to exercise accelerator memory usage.""" - layer = tl.Serial(tl.Dense(1024*1024), tl.Dense(128)) - chunked = tl.Chunk(layer, 256) - x = np.random.uniform(size=(16*1024, 16)) - chunked.init(shapes.signature(x)) - y = chunked(x) - z = tl.Accelerate(chunked)(x) - self.assertEqual(y.shape, (16*1024, 128)) - self.assertEqual(z.shape, (16*1024, 128)) - - def test_chunk_grad_memory(self): - """Test chunking gradient here to exercise accelerator memory usage.""" - layer = tl.Serial(tl.Dense(1024*1024), tl.Dense(24)) - chunked = tl.Chunk(layer, 256) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits, new_state = chunked.pure_fn(x, weights, state, rng) - loss = fastmath.numpy.mean(logits) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - x = np.random.uniform(size=(32*1024, 16)) - chunked.init(shapes.signature(x)) - weights, _, logits = mock_training_step( - 
x, chunked.weights, chunked.state, fastmath.random.get_prng(0)) - self.assertEqual(logits.shape, (32*1024, 24)) - self.assertEqual(weights[1][0][0][0].shape, (16, 1024*1024)) - - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() diff --git a/trax/layers/activation_fns.py b/trax/layers/activation_fns.py index 625ff87ab..133273d2f 100644 --- a/trax/layers/activation_fns.py +++ b/trax/layers/activation_fns.py @@ -29,9 +29,9 @@ from trax.layers.base import Fn -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Relu(): - r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function. + r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -39,12 +39,12 @@ def Relu(): x & \text{otherwise}. \end{array} \right. """ - return Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)) + return Fn("Relu", lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)) -@assert_shape('...->...') # The output and input shapes are the same. -def ParametricRelu(a=1.): - r"""Returns a layer that computes a ReLU function with the given slope. +@assert_shape("...->...") # The output and input shapes are the same. +def ParametricRelu(a=1.0): + r"""Returns a layer that computes a ReLU function with the given slope. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -55,12 +55,12 @@ def ParametricRelu(a=1.): Args: a: Slope of line for positive inputs. """ - return Fn('ParametricRelu', lambda x: jnp.maximum(a * x, jnp.zeros_like(x))) + return Fn("ParametricRelu", lambda x: jnp.maximum(a * x, jnp.zeros_like(x))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def LeakyRelu(a=0.01): - r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs. + r"""Returns a ReLU-like layer with linear nonzero outputs for negative inputs. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -71,12 +71,12 @@ def LeakyRelu(a=0.01): Args: a: Slope of line for negative inputs. """ - return Fn('LeakyRelu', lambda x: jnp.where(x >= 0, x, a * x)) + return Fn("LeakyRelu", lambda x: jnp.where(x >= 0, x, a * x)) -@assert_shape('...->...') # The output and input shapes are the same. -def Elu(a=1.): - r"""Returns a ReLU-like layer with exponential outputs for negative inputs. +@assert_shape("...->...") # The output and input shapes are the same. +def Elu(a=1.0): + r"""Returns a ReLU-like layer with exponential outputs for negative inputs. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -89,13 +89,14 @@ def Elu(a=1.): Args: a: Coefficient multiplying the exponential, for negative inputs. """ - return Fn('Elu', lambda x: jnp.where(x > 0, x, a * jnp.expm1(x))) + return Fn("Elu", lambda x: jnp.where(x > 0, x, a * jnp.expm1(x))) -@assert_shape('...->...') # The output and input shapes are the same. -def Selu(alpha=1.6732632423543772848170429916717, - lmbda=1.0507009873554804934193349852946): - r"""Returns an `Elu`-like layer with an additional scaling/slope parameter. +@assert_shape("...->...") # The output and input shapes are the same. +def Selu( + alpha=1.6732632423543772848170429916717, lmbda=1.0507009873554804934193349852946 +): + r"""Returns an `Elu`-like layer with an additional scaling/slope parameter. .. 
math:: f(x) = \left\{ \begin{array}{cl} @@ -107,58 +108,62 @@ def Selu(alpha=1.6732632423543772848170429916717, alpha: Coefficient multiplying the exponential, for negative inputs. lmbda: Coefficient scaling the whole function. """ - return Fn('Selu', lambda x: lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))) + return Fn("Selu", lambda x: lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Gelu(): - r"""Returns a layer that computes the Gaussian Error Linear Unit function. + r"""Returns a layer that computes the Gaussian Error Linear Unit function. - .. math:: - f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}})) - """ - return Fn('Gelu', lambda x: x * 0.5 * (1.0 + fastmath.erf(x / jnp.sqrt(2.0)))) + .. math:: + f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}})) + """ + return Fn("Gelu", lambda x: x * 0.5 * (1.0 + fastmath.erf(x / jnp.sqrt(2.0)))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def FastGelu(): - r"""Returns a layer that computes a fast approximation to `Gelu`. + r"""Returns a layer that computes a fast approximation to `Gelu`. - .. math:: - f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3)) + .. math:: + f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3)) - where :math:`a = 0.7978845608` and :math:`b = 0.044715`. - """ - def f(x): # pylint: disable=invalid-name - return 0.5 * x * (1 + jnp.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) - return Fn('FastGelu', f) + where :math:`a = 0.7978845608` and :math:`b = 0.044715`. + """ + + def f(x): # pylint: disable=invalid-name + return 0.5 * x * (1 + jnp.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) + + return Fn("FastGelu", f) # pylint: disable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Sigmoid(): - r"""Returns a layer that computes the sigmoid function. + r"""Returns a layer that computes the sigmoid function. - .. math:: - f(x) = \frac{1}{1 + e^{-x}} - """ - return Fn('Sigmoid', lambda x: fastmath.expit(x)) + .. math:: + f(x) = \frac{1}{1 + e^{-x}} + """ + return Fn("Sigmoid", lambda x: fastmath.expit(x)) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Tanh(): - r"""Returns a layer that computes the hyperbolic tangent function. + r"""Returns a layer that computes the hyperbolic tangent function. + + .. math:: + f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} + """ + return Fn("Tanh", lambda x: jnp.tanh(x)) + - .. math:: - f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} - """ - return Fn('Tanh', lambda x: jnp.tanh(x)) # pylint: enable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def HardSigmoid(): - r"""Returns a layer that computes a linear approximation to `Sigmoid`. + r"""Returns a layer that computes a linear approximation to `Sigmoid`. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -167,12 +172,12 @@ def HardSigmoid(): 1 & \text{otherwise}. \end{array} \right. 
""" - return Fn('HardSigmoid', lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x)))) + return Fn("HardSigmoid", lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x)))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def HardTanh(): - r"""Returns a layer that computes a linear approximation to `Tanh`. + r"""Returns a layer that computes a linear approximation to `Tanh`. .. math:: f(x) = \left\{ \begin{array}{cl} @@ -181,76 +186,76 @@ def HardTanh(): 1 & \text{otherwise}. \end{array} \right. """ - return Fn('HardTanh', lambda x: jnp.maximum(-1, jnp.minimum(1, x))) + return Fn("HardTanh", lambda x: jnp.maximum(-1, jnp.minimum(1, x))) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Softplus(): - r"""Returns a layer that computes the softplus function. + r"""Returns a layer that computes the softplus function. - .. math:: - f(x) = \ln(e^x + 1) - """ - return Fn('Softplus', lambda x: jnp.logaddexp(x, 0.)) + .. math:: + f(x) = \ln(e^x + 1) + """ + return Fn("Softplus", lambda x: jnp.logaddexp(x, 0.0)) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Exp(): - """Returns a layer that computes the element-wise exponential of a tensor.""" - return Fn('Exp', lambda x: jnp.exp(x)) # pylint: disable=unnecessary-lambda + """Returns a layer that computes the element-wise exponential of a tensor.""" + return Fn("Exp", lambda x: jnp.exp(x)) # pylint: disable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Log(): - """Returns a layer that computes the element-wise logarithm of a tensor.""" - return Fn('Log', lambda x: jnp.log(x)) # pylint: disable=unnecessary-lambda + """Returns a layer that computes the element-wise logarithm of a tensor.""" + return Fn("Log", lambda x: jnp.log(x)) # pylint: disable=unnecessary-lambda -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Swish(): - r"""Returns a layer that computes the Swish function. + r"""Returns a layer that computes the Swish function. - .. math:: - f(x) = x \cdot \text{sigmoid}(x) - """ - return Fn('Swish', lambda x: x * fastmath.expit(x)) + .. math:: + f(x) = x \cdot \text{sigmoid}(x) + """ + return Fn("Swish", lambda x: x * fastmath.expit(x)) -@assert_shape('...a->...b') # The output and input shapes are not the same. +@assert_shape("...a->...b") # The output and input shapes are not the same. def Glu(): - r"""Returns a layer that computes the Gated Linear Unit function. + r"""Returns a layer that computes the Gated Linear Unit function. - .. math:: - f(x) = a \cdot \text{sigmoid}(b) - where a and b are formed by splitting input in half along axis + .. 
math:: + f(x) = a \cdot \text{sigmoid}(b) + where a and b are formed by splitting the input in half along the last axis - - """ + """ - def _f(x, axis=-1): # pylint: disable=invalid-name - size = x.shape[axis] - assert size % 2 == 0, f'axis {axis} of size {size} is not be divisible by 2' - a, b = jnp.split(x, 2, axis) - return a * fastmath.expit(b) + def _f(x, axis=-1): # pylint: disable=invalid-name + size = x.shape[axis] + assert size % 2 == 0, f"axis {axis} of size {size} is not divisible by 2" + a, b = jnp.split(x, 2, axis) + return a * fastmath.expit(b) - return Fn('Glu', _f) + return Fn("Glu", _f) class ThresholdedLinearUnit(base.Layer): - """Thresholded Linear Unit, c.f. https://arxiv.org/pdf/1911.09737.pdf .""" + """Thresholded Linear Unit, cf. https://arxiv.org/pdf/1911.09737.pdf.""" - def init_weights_and_state(self, input_signature): - """Initializes this layer's single weight to zero.""" - del input_signature - self.weights = jnp.zeros((), dtype=jnp.float32) + def init_weights_and_state(self, input_signature): + """Initializes this layer's single weight to zero.""" + del input_signature + self.weights = jnp.zeros((), dtype=jnp.float32) - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model. + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model. - Args: - inputs: Tensor. + Args: + inputs: Tensor. - Returns: - Tensor of same shape and dtype as the input. - """ - threshold = self.weights - return jnp.maximum(inputs, threshold) + Returns: + Tensor of same shape and dtype as the input. + """ + threshold = self.weights + return jnp.maximum(inputs, threshold) diff --git a/trax/layers/activation_fns_test.py b/trax/layers/activation_fns_test.py deleted file mode 100644 index 2f128bd47..000000000 --- a/trax/layers/activation_fns_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
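The tests deleted below exercised these activation layers one by one. For reference, a minimal sketch of the ``Glu`` contract in the same style as those tests (a sketch only, assuming just the public ``trax.layers`` namespace and NumPy)::

    import numpy as np
    import trax.layers as tl

    # Glu splits its input in half along the last axis and gates the first
    # half with the sigmoid of the second, so the feature dimension halves.
    x = np.random.uniform(size=(2, 8)).astype(np.float32)
    y = tl.Glu()(x)
    assert y.shape == (2, 4)
    # An odd last dimension, e.g. (2, 7), trips the divisibility assert.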
- -"""Tests for activation function layers.""" - -from absl.testing import absltest -import numpy as np - -import trax.layers as tl - - -class ActivationFnsTest(absltest.TestCase): - - def test_relu(self): - layer = tl.Relu() - x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) - y = layer(x) - self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]) - - def test_parametric_relu(self): - layer = tl.ParametricRelu(a=.25) - x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) - y = layer(x) - self.assertEqual(tl.to_list(y), [0.0, 0.0, 0.0, .5, .75, 1.25]) - - def test_leaky_relu(self): - layer = tl.LeakyRelu(a=.125) - x = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0]) - y = layer(x) - self.assertEqual(tl.to_list(y), [-.25, -.125, 0.0, 2.0, 3.0, 5.0]) - - def test_hard_sigmoid(self): - layer = tl.HardSigmoid() - x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5]) - y = layer(x) - self.assertEqual(tl.to_list(y), [0.0, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]) - - def test_hard_tanh(self): - layer = tl.HardTanh() - x = np.array([-1.5, -.5, -.25, 0.0, .25, .5, 1.5]) - y = layer(x) - self.assertEqual(tl.to_list(y), [-1.0, -.5, -.25, 0.0, .25, .5, 1.0]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/assert_shape.py b/trax/layers/assert_shape.py index dffa85392..95bf7c9e9 100644 --- a/trax/layers/assert_shape.py +++ b/trax/layers/assert_shape.py @@ -18,274 +18,283 @@ import functools import inspect import string + from absl import logging -from trax.layers import base -from trax.layers import combinators + +from trax.layers import base, combinators def assert_shape(specification): - """Decorator for checking the input and output shapes of Layer. - - Decorator can be applied on trax.base.Layer class, or a function returning - a trax.base.Layer class. It uses notation similar to einsum (Einstein - summation convention), achieving concise and simple representation of tensors. - For example 'ij,jh->ih' is a valid representation of a function taking two - 2D matrices as input, and returning a single output, also a 2D matrix. - - It improves readability and puts puts three levels of asserts on the function: - first level is the number of input tensors and output tensors; second level is - the rank of each tensor; third level is the size of each dimension of each - tensor. The decorator inserts those asserts right before and right after - 'forward' call. - - First level, assert on number of inputs and outputs. In the representation - input tensors are separated from output tensors by an arrow '->'. For layers - taking multiple input tensors or returning multiple output tensors, those - tensors will be separated by a comma ','. - For example, specification 'bsd,df->bsf' asserts that there will be two - input tensors, with shapes represented by 'bsd' and 'df' respectively; and - a single output tensor with shape represented by 'bsf'. - - Second level, asserts on possible rank of each tensor. Most commonly, - each letter represents a single dimension. For example,the tensor with shapes - represented by 'bsd' has rank three; with 'df' it has rank two. The special - case is an ellipsis ('...'), which expand to arbitrary number of dimensions, - including zero. For example, the tensor with specification '...sf' has at - least two dimensions. Each tensor may have in its representation one ellipsis. - - Third level, asserts the size of each dimension. If two dimensions in any - of input or output tensors have the same letter in the representation then - they must have the same size. 
For example, with a tensor A represented by 'df' - and a tensor B represented by 'bsf', the size of the second dimension of A - must equal the size of the third dimension of B. Another example: with a - tensor C represented by '...dv' and a tensor D represented by 'd', the size of - the first and only dimension of D must be equal to the size of the second to - last dimension of tensor C. - - If two distinct tensors have an ellipsis in their representation then all of - dimensions covered by those ellipses must match. For example, with a tensor E - represented by '...d' and tensor F represented by '...x' then E and F must - have the same rank, and the sizes of all but the last dimensions must match. - - Examples: - # In Dense layer there is a single input and single output; the last dimension - # may change in size, while the sizes of all previous dimensions, marked by - # an ellipsis, will stay the same. - @assert_shape('...a->...b') - class Dense(base.Layer): - (...) - - # DotProductCausalAttention takes three tensors as input: Queries, Keys, and - # Values, and outputs a single tensor. Sizes of the first two dimensions in - # all those tensors must match, while the last dimension must match only - # between Queries and Keys, and separately between Values and output tensor. - @assert_shape('blk,blk,bld->bld') - class DotProductCausalAttention(base.Layer): - (...) - - # assert_shape can also be placed before the function returning base.Layer. - @assert_shape('...d->...') - def ReduceSum(): - return Fn('ReduceSum', lambda x: jnp.sum(x, axis=-1, keepdims=False)) - - Args: - specification: A text specification for the input/output tensors. - - Returns: - The decorator changing the class or function. - """ - caller = inspect.getframeinfo(inspect.stack()[1][0]) - message = f'Defined at {caller.filename}:{caller.lineno}' - - def wrap_cls(cls): - forward = getattr(cls, 'forward') - init = getattr(cls, '__init__') - - before_spec, after_spec = specification.split('->') - - @functools.wraps(init) - def init_wrapper(self, *args, **kwargs): - before_assert = AssertShape(before_spec, - message=message + ' function input') - after_assert = AssertShape(after_spec, - message=message + ' function output') - after_assert._create_link(before_assert) # pylint: disable=protected-access - out = init(self, *args, **kwargs) - self._before_assert_fun = before_assert # pylint: disable=protected-access - self._after_assert_fun = after_assert # pylint: disable=protected-access - return out - - @functools.wraps(forward) - def forward_wrapper(self, x, *args, **kwargs): - x = self._before_assert_fun.forward(x) # pylint: disable=protected-access - y = forward(self, x, *args, **kwargs) - y = self._after_assert_fun.forward(y) # pylint: disable=protected-access - return y - - setattr(cls, 'forward', forward_wrapper) - setattr(cls, '__init__', init_wrapper) - return cls - - # TODO(jaszczur): replace this with forward/init override. - def wrap_fun(fun): - @functools.wraps(fun) - def fun_wrapper(*args, **kwargs): - layer = fun(*args, **kwargs) - return AssertFunction(specification, layer, message) - return fun_wrapper - - def wrap_fun_or_cls(fun_or_cls): - return (wrap_cls(fun_or_cls) if inspect.isclass(fun_or_cls) else - wrap_fun(fun_or_cls)) - - return wrap_fun_or_cls + """Decorator for checking the input and output shapes of Layer. + + The decorator can be applied to a trax.base.Layer class, or to a function returning + a trax.base.Layer class. 
It uses notation similar to einsum (Einstein + summation convention), achieving a concise and simple representation of tensors. + For example, 'ij,jh->ih' is a valid representation of a function taking two + 2D matrices as input, and returning a single output, also a 2D matrix. + + It improves readability and puts three levels of asserts on the function: + first level is the number of input tensors and output tensors; second level is + the rank of each tensor; third level is the size of each dimension of each + tensor. The decorator inserts those asserts right before and right after the + 'forward' call. + + First level, asserts on the number of inputs and outputs. In the representation + input tensors are separated from output tensors by an arrow '->'. For layers + taking multiple input tensors or returning multiple output tensors, those + tensors will be separated by a comma ','. + For example, specification 'bsd,df->bsf' asserts that there will be two + input tensors, with shapes represented by 'bsd' and 'df' respectively; and + a single output tensor with shape represented by 'bsf'. + + Second level, asserts on possible rank of each tensor. Most commonly, + each letter represents a single dimension. For example, the tensor with shape + represented by 'bsd' has rank three; with 'df' it has rank two. The special + case is an ellipsis ('...'), which expands to an arbitrary number of dimensions, + including zero. For example, the tensor with specification '...sf' has at + least two dimensions. Each tensor may have in its representation one ellipsis. + + Third level, asserts the size of each dimension. If two dimensions in any + of the input or output tensors have the same letter in the representation then + they must have the same size. For example, with a tensor A represented by 'df' + and a tensor B represented by 'bsf', the size of the second dimension of A + must equal the size of the third dimension of B. Another example: with a + tensor C represented by '...dv' and a tensor D represented by 'd', the size of + the first and only dimension of D must be equal to the size of the second to + last dimension of tensor C. + + If two distinct tensors have an ellipsis in their representation then all of + the dimensions covered by those ellipses must match. For example, with a tensor E + represented by '...d' and tensor F represented by '...x' then E and F must + have the same rank, and the sizes of all but the last dimensions must match. + + Examples: + # In Dense layer there is a single input and single output; the last dimension + # may change in size, while the sizes of all previous dimensions, marked by + # an ellipsis, will stay the same. + @assert_shape('...a->...b') + class Dense(base.Layer): + (...) + + # DotProductCausalAttention takes three tensors as input: Queries, Keys, and + # Values, and outputs a single tensor. Sizes of the first two dimensions in + # all those tensors must match, while the last dimension must match only + # between Queries and Keys, and separately between Values and output tensor. + @assert_shape('blk,blk,bld->bld') + class DotProductCausalAttention(base.Layer): + (...) + + # assert_shape can also be placed before the function returning base.Layer. + @assert_shape('...d->...') + def ReduceSum(): + return Fn('ReduceSum', lambda x: jnp.sum(x, axis=-1, keepdims=False)) + + Args: + specification: A text specification for the input/output tensors. 
-def AssertFunction(specification, layer, message=None): # pylint: disable=invalid-name - """AssertFunction asserts shapes on the input/output tensors of a layer. - - It passes all inputs to the layer, and returns all outputs of the layer - unchanged. - - Args: - specification: A specification. See assert_shape decorator for a full - documentation. - layer: A base.Layer to wrap around. - message: An optional message to print if an assert fails. By default it will - print the filename and the line number where AssertFunction was called. - - Returns: - The given layer wrapped in asserts on its inputs and outputs. - """ - if message is None: + Returns: + The decorator changing the class or function. + """ caller = inspect.getframeinfo(inspect.stack()[1][0]) - message = f'Defined at {caller.filename}:{caller.lineno}' - before_spec, after_spec = specification.split('->') - before_assert = AssertShape(before_spec, message=message + ' function input') - after_assert = AssertShape(after_spec, message=message + ' function output') - after_assert._create_link(before_assert) # pylint: disable=protected-access - return combinators.Serial( - before_assert, layer, after_assert) - + message = f"Defined at {caller.filename}:{caller.lineno}" + + def wrap_cls(cls): + forward = getattr(cls, "forward") + init = getattr(cls, "__init__") + + before_spec, after_spec = specification.split("->") + + @functools.wraps(init) + def init_wrapper(self, *args, **kwargs): + before_assert = AssertShape( + before_spec, message=message + " function input" + ) + after_assert = AssertShape(after_spec, message=message + " function output") + after_assert._create_link(before_assert) # pylint: disable=protected-access + out = init(self, *args, **kwargs) + self._before_assert_fun = before_assert # pylint: disable=protected-access + self._after_assert_fun = after_assert # pylint: disable=protected-access + return out + + @functools.wraps(forward) + def forward_wrapper(self, x, *args, **kwargs): + x = self._before_assert_fun.forward(x) # pylint: disable=protected-access + y = forward(self, x, *args, **kwargs) + y = self._after_assert_fun.forward(y) # pylint: disable=protected-access + return y + + setattr(cls, "forward", forward_wrapper) + setattr(cls, "__init__", init_wrapper) + return cls + + # TODO(jaszczur): replace this with forward/init override. + def wrap_fun(fun): + @functools.wraps(fun) + def fun_wrapper(*args, **kwargs): + layer = fun(*args, **kwargs) + return AssertFunction(specification, layer, message) + + return fun_wrapper + + def wrap_fun_or_cls(fun_or_cls): + return ( + wrap_cls(fun_or_cls) + if inspect.isclass(fun_or_cls) + else wrap_fun(fun_or_cls) + ) + + return wrap_fun_or_cls -class AssertShape(base.Layer): - """Layer which put asserts on shapes of tensors, and returns them unchanged. - It borrows the notation from assert_shape decorator, except it doesn't have - the arrow '->' special character, as the input tensors are the same as output. - """ +def AssertFunction(specification, layer, message=None): # pylint: disable=invalid-name + """AssertFunction asserts shapes on the input/output tensors of a layer. - def __init__(self, spec, message=None, visible_layer=False): - """Creates AssertShape layer. + It passes all inputs to the layer, and returns all outputs of the layer + unchanged. Args: - spec: Specification for input tensors. See assert_shape decorator for the - full documentation. - message: An optional message to include when assert fails. 
By default it - includes the filename and line number where this function was called. - visible_layer: If true, print this layer inside the model (default: False) + specification: A specification. See assert_shape decorator for full + documentation. + layer: A base.Layer to wrap around. + message: An optional message to print if an assert fails. By default it will + print the filename and the line number where AssertFunction was called. + + Returns: + The given layer wrapped in asserts on its inputs and outputs. """ - name = 'AssertShape' if visible_layer else '' - super().__init__(name=name) - spec = spec.replace('...', '.') - for letter in spec: - assert letter in string.ascii_letters + string.digits + '.' + ',' - self._specs = spec.split(',') - self._n_in = self._n_out = len(self._specs) + if message is None: + caller = inspect.getframeinfo(inspect.stack()[1][0]) + message = f"Defined at {caller.filename}:{caller.lineno}" + before_spec, after_spec = specification.split("->") + before_assert = AssertShape(before_spec, message=message + " function input") + after_assert = AssertShape(after_spec, message=message + " function output") + after_assert._create_link(before_assert) # pylint: disable=protected-access + return combinators.Serial(before_assert, layer, after_assert) - self._defined_shapes = {str(i): i for i in range(10)} - self._linked = False - if message is None: - caller = inspect.getframeinfo(inspect.stack()[1][0]) - self._message = f'Defined at {caller.filename}:{caller.lineno}' - else: - self._message = message - - def forward(self, xs): - if not self._linked: - for k in list(self._defined_shapes.keys()): - if not k.isdigit(): - del self._defined_shapes[k] - - if not isinstance(xs, (list, tuple)): - xs = [xs] - - # Try-except below checks if something is wrong with shapes. It can happen - # e.g. when using trax2keras. If this is the case we cannot check if shapes - # are correct or not - try: - for x in xs: - for i in range(len(x.shape)): - if x.shape[i] != x.shape[i]: - raise TypeError() - except TypeError: - message = ('AssertShape cannot check shapes. This often happens when' - ' using trax2keras. Shape asserts are skipped.') - print(message) - logging.warning(message) - if len(xs) == 1: - return xs[0] - else: - return xs - - # helper functions - def assert_true(cond): - if not cond: - shapes = [x.shape for x in xs] - defined_shapes_dict_without_digits = { - k: v for k, v in self._defined_shapes.items() if not k.isdigit()} - raise ValueError( - f'AssertShape Error. Expected {self._specs}, got {shapes} with dict' - f' {defined_shapes_dict_without_digits}. {self._message}') - - def assert_equal(a, b): - assert_true(a == b) - return a - - def check_shape(shape, spec): - assert_equal(len(shape), len(spec)) - for shape_dim, letter in zip(shape, spec): - if letter in self._defined_shapes: - self._defined_shapes[letter] = assert_equal( - self._defined_shapes[letter], shape_dim) +class AssertShape(base.Layer): + """Layer which puts asserts on shapes of tensors, and returns them unchanged. + + It borrows the notation from assert_shape decorator, except it doesn't have + the arrow '->' special character, as the output tensors are the same as the input. + """ + + def __init__(self, spec, message=None, visible_layer=False): + """Creates AssertShape layer. + + Args: + spec: Specification for input tensors. See assert_shape decorator for the + full documentation. + message: An optional message to include when assert fails. 
By default it + includes the filename and line number where this function was called. + visible_layer: If true, print this layer inside the model (default: False) + """ + name = "AssertShape" if visible_layer else "" + super().__init__(name=name) + spec = spec.replace("...", ".") + for letter in spec: + assert letter in string.ascii_letters + string.digits + "." + "," + self._specs = spec.split(",") + self._n_in = self._n_out = len(self._specs) + + self._defined_shapes = {str(i): i for i in range(10)} + self._linked = False + + if message is None: + caller = inspect.getframeinfo(inspect.stack()[1][0]) + self._message = f"Defined at {caller.filename}:{caller.lineno}" else: - self._defined_shapes[letter] = shape_dim - - def check_ellipsys(shape): - if '.' not in self._defined_shapes: - self._defined_shapes['.'] = shape - else: - assert_equal(len(shape), len(self._defined_shapes['.'])) - for s1, s2 in zip(shape, self._defined_shapes['.']): - assert_equal(s1, s2) - - # actual asserts - assert_equal(len(xs), len(self._specs)) - - for x, spec in zip(xs, self._specs): - if '.' in spec: - assert_true(len(x.shape) >= (len(spec) - 1)) - - before, after = spec.split('.') - check_shape(x.shape[:len(before)], before) - if after: - check_shape(x.shape[-len(after):], after) - check_ellipsys(x.shape[len(before):-len(after)]) + self._message = message + + def forward(self, xs): + if not self._linked: + for k in list(self._defined_shapes.keys()): + if not k.isdigit(): + del self._defined_shapes[k] + + if not isinstance(xs, (list, tuple)): + xs = (xs,) + + # Try-except below checks if something is wrong with shapes. It can happen + # e.g. when using trax2keras. If this is the case we cannot check if shapes + # are correct or not + try: + for x in xs: + for i in range(len(x.shape)): + if x.shape[i] != x.shape[i]: + raise TypeError() + except TypeError: + message = ( + "AssertShape cannot check shapes. This often happens when" + " using trax2keras. Shape asserts are skipped." + ) + print(message) + logging.warning(message) + if len(xs) == 1: + return xs[0] + else: + return xs + + # helper functions + def assert_true(cond): + if not cond: + shapes = [x.shape for x in xs] + defined_shapes_dict_without_digits = { + k: v for k, v in self._defined_shapes.items() if not k.isdigit() + } + raise ValueError( + f"AssertShape Error. Expected {self._specs}, got {shapes} with dict" + f" {defined_shapes_dict_without_digits}. {self._message}" + ) + + def assert_equal(a, b): + assert_true(a == b) + return a + + def check_shape(shape, spec): + assert_equal(len(shape), len(spec)) + for shape_dim, letter in zip(shape, spec): + if letter in self._defined_shapes: + self._defined_shapes[letter] = assert_equal( + self._defined_shapes[letter], shape_dim + ) + else: + self._defined_shapes[letter] = shape_dim + + def check_ellipsys(shape): + if "." not in self._defined_shapes: + self._defined_shapes["."] = shape + else: + assert_equal(len(shape), len(self._defined_shapes["."])) + for s1, s2 in zip(shape, self._defined_shapes["."]): + assert_equal(s1, s2) + + # actual asserts + assert_equal(len(xs), len(self._specs)) + + for x, spec in zip(xs, self._specs): + if "." in spec: + assert_true(len(x.shape) >= (len(spec) - 1)) + + before, after = spec.split(".") + check_shape(x.shape[: len(before)], before) + if after: + check_shape(x.shape[-len(after) :], after) + check_ellipsys(x.shape[len(before) : -len(after)]) + else: + # if len(after) == 0 then -len(after) in indices evaluates badly. 
+ check_ellipsys(x.shape[len(before) :]) + else: + check_shape(x.shape, spec) + + if len(xs) == 1: + return xs[0] else: - # if len(after) == 0 then -len(after) in indices evaluates badly. - check_ellipsys(x.shape[len(before):]) - else: - check_shape(x.shape, spec) - - if len(xs) == 1: - return xs[0] - else: - return xs - - def _create_link(self, other): - """Internal. Used to create a shared dictionary.""" - # This works well for assert_shape and AssertFunction; but it can break - # easily if the order of calls to forward() is not known in advance. - self._linked = True - self._defined_shapes = other._defined_shapes # pylint: disable=protected-access + return xs + + def _create_link(self, other): + """Internal. Used to create a shared dictionary.""" + # This works well for assert_shape and AssertFunction; but it can break + # easily if the order of calls to forward() is not known in advance. + self._linked = True + self._defined_shapes = other._defined_shapes # pylint: disable=protected-access diff --git a/trax/layers/assert_shape_test.py b/trax/layers/assert_shape_test.py deleted file mode 100644 index 5c3995ba6..000000000 --- a/trax/layers/assert_shape_test.py +++ /dev/null @@ -1,275 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
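The tests deleted below covered this shape-assertion machinery. A minimal sketch of the core ``AssertFunction`` behavior, mirroring the removed ``test_simple_pass``/``test_simple_fail`` cases (a sketch only, assuming just the public ``trax.layers`` namespace)::

    import numpy as np
    import trax.layers as tl

    # 'abc->abc' asserts one input and one output tensor of identical shape;
    # Dropout preserves shape, so this call passes.
    layer = tl.AssertFunction('abc->abc', tl.Dropout(rate=0.1))
    y = layer(np.ones((2, 5, 20)))

    # A contradictory spec such as 'abc->cba' raises tl.LayerError at call time.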
- -"""Tests for assert shape layers.""" - -from absl.testing import absltest -import numpy as np - -import trax.layers as tl - - -class AssertFunctionTest(absltest.TestCase): - """Test AssertFunction layer.""" - - def test_simple_pass(self): - layer = tl.AssertFunction('abc->abc', tl.Dropout(rate=0.1)) - x = np.ones((2, 5, 20)) - layer(x) - - def test_simple_fail(self): - layer = tl.AssertFunction('abc->cba', tl.Dropout(rate=0.1)) - x = np.ones((2, 5, 20)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_reduce_rank_ellipsis_pass(self): - layer = tl.AssertFunction('...ab->...c', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - layer(x) - - def test_reduce_rank_explicit_pass(self): - layer = tl.AssertFunction('xyzab->xyzc', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - layer(x) - - def test_reduce_rank_to_one_pass(self): - layer = tl.AssertFunction('abcde->x', tl.Flatten(n_axes_to_keep=0)) - x = np.ones((1, 2, 3, 4, 5)) - layer(x) - - def test_reduce_rank_explicit_fail1(self): - layer = tl.AssertFunction('abcde->abcde', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_reduce_rank_explicit_fail2(self): - layer = tl.AssertFunction('abcde->abcd', tl.Flatten(n_axes_to_keep=3)) - x = np.ones((1, 2, 3, 4, 5)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_two_outputs_pass(self): - layer = tl.AssertFunction( - '...cd->...x,...cd', - tl.Branch( - tl.Flatten(n_axes_to_keep=2), - tl.Dropout(rate=0.1), - )) - x = np.ones((1, 2, 3, 4)) - layer(x) - - def test_numeric_dimensions_pass(self): - layer = tl.AssertFunction( - '...34->1234,...34', - tl.Branch( - tl.Dropout(rate=0.1), - tl.Serial(), - )) - x = np.ones((1, 2, 3, 4)) - layer(x) - - def test_too_many_outputs_fail(self): - layer = tl.AssertFunction( - '...cd->...x,...cd,...cd,...cd', - tl.Branch( - tl.Flatten(n_axes_to_keep=2), - tl.Dropout(rate=0.1), - tl.Serial(), - )) - x = np.ones((1, 2, 3, 4)) - with self.assertRaises(tl.LayerError): - layer(x) - - def test_multi_output_rank_fail(self): - layer = tl.AssertFunction( - '...34->...x,...y', - tl.Branch( - tl.Flatten(n_axes_to_keep=3), - tl.Serial(), - )) - x = np.ones((1, 2, 3, 4)) - with self.assertRaises(tl.LayerError): - layer(x) - - -class AssertShapeTest(absltest.TestCase): - """Test AssertShape layer.""" - - def test_simple_pass(self): - layer = tl.AssertShape('aba,ba') - x = [np.ones((10, 5, 10)), np.zeros((5, 10))] - y = layer(x) - self.assertEqual(y, x) - - def test_same_shapes_pass(self): - layer = tl.AssertShape('aba,ba') - x = [np.ones((5, 5, 5)), np.zeros((5, 5))] - y = layer(x) - self.assertEqual(y, x) - - def test_single_arg_pass(self): - layer = tl.AssertShape('a') - x = np.ones((5,)) - y = layer(x) - self.assertEqual(y.tolist(), x.tolist()) - - def test_scalar_pass(self): - layer = tl.AssertShape('') - x = np.ones(()) - y = layer(x) - self.assertEqual(y.tolist(), x.tolist()) - - def test_square_matrix_pass(self): - layer = tl.AssertShape('aa') - x = np.ones((3, 3)) - y = layer(x) - self.assertEqual(y.tolist(), x.tolist()) - - def test_vector_scalar_pass(self): - layer = tl.AssertShape('a,') - x = [np.ones((5,)), np.zeros(())] - y = layer(x) - self.assertEqual(y, x) - - def test_three_args_pass(self): - layer = tl.AssertShape('a,b,a') - x = [np.ones((5,)), np.zeros((2)), np.zeros((5))] - y = layer(x) - self.assertEqual(y, x) - - def test_multiple_matching_dims_pass(self): - layer = tl.AssertShape('a,b,a,ab') - x = [np.ones((5,)), 
np.zeros((2)), np.zeros((5)), np.zeros((5, 2))] - y = layer(x) - self.assertEqual(y, x) - - def test_numeric_dims_pass(self): - layer = tl.AssertShape('23,1,93') - x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_numeric_dims_fail(self): - layer = tl.AssertShape('24,1,93') - x = [np.ones((2, 3)), np.zeros((1)), np.zeros((9, 3))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_middle_pass(self): - layer = tl.AssertShape('a...bc,abc') - x = [np.ones((1, 5, 5, 2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_ellipsis_prefix_pass(self): - layer = tl.AssertShape('...bc,abc') - x = [np.ones((5, 5, 2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_ellipsis_matching_zero_dims_pass(self): - layer = tl.AssertShape('...bc,abc') - x = [np.ones((2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_ellipsis_matching_ellipsis_pass(self): - layer = tl.AssertShape('...bc,...bc') - x = [np.ones((1, 2, 3)), np.zeros((1, 2, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_prefix_ellipsis_matching_sufix_ellipsis_pass(self): - layer = tl.AssertShape('bb...,...bb') - x = [np.ones((2, 2, 5, 6)), np.zeros((5, 6, 2, 2))] - y = layer(x) - self.assertEqual(y, x) - - def test_middle_ellipsis_fail(self): - layer = tl.AssertShape('ab...cde,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_short_middle_ellipsis_fail(self): - layer = tl.AssertShape('b...c,2') - x = [np.ones((2)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_double_ellipsis_fail(self): - layer = tl.AssertShape('b......c,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_typo_ellipsis_fail(self): - layer = tl.AssertShape('b..c,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_matching_ellipsis_fail(self): - layer = tl.AssertShape('...a,...b') - x = [np.ones((1, 2, 3, 7)), np.zeros((1, 2, 8))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_numeric_pass(self): - layer = tl.AssertShape('...22,...3') - x = [np.ones((1, 2, 3, 2, 2)), np.zeros((1, 2, 3, 3))] - y = layer(x) - self.assertEqual(y, x) - - def test_prefix_and_sufix_ellipsis_fail(self): - layer = tl.AssertShape('...c...,2') - x = [np.ones((2, 3, 4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipsis_too_few_dims_fail(self): - layer = tl.AssertShape('...abc,2') - x = [np.ones((4, 5)), np.zeros((2))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_ellipses_matching_dims_fail(self): - layer = tl.AssertShape('...2,...8') - x = [np.ones((1, 2, 3, 9)), np.zeros((1, 3, 3, 8))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_dims_matching_fail(self): - layer = tl.AssertShape('aba,ab') - x = [np.ones((10, 5, 10)), np.ones((5, 8))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_rank_fail(self): - layer = tl.AssertShape('aba,ab') - x = [np.ones((10, 5, 10)), np.ones((5, 10, 4))] - with self.assertRaises(tl.LayerError): - layer(x) - - def test_square_matrix_fail(self): - layer = tl.AssertShape('aa') - x = np.ones((10, 5)) - with self.assertRaises(tl.LayerError): - layer(x) - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/attention.py 
b/trax/layers/attention.py index 231c1eb8f..1419d445a 100644 --- a/trax/layers/attention.py +++ b/trax/layers/attention.py @@ -48,153 +48,45 @@ from trax import fastmath from trax.fastmath import numpy as jnp -from trax.layers import base +from trax.layers import base, convolution, core from trax.layers import combinators as cb -from trax.layers import convolution -from trax.layers import core from trax.layers import initializers as init from trax.layers.assert_shape import assert_shape from trax.layers.base import Fn from trax.layers.research import sparsity - # Layers are always CamelCase, but functions in general are snake_case # pylint: disable=invalid-name # inputs are [batch, length, depth], [batch, 1, 1 length] -@assert_shape('bld,b11l->bld,b11l') -def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'): - """Returns a layer that maps `(vectors, mask)` to `(new_vectors, mask)`. - - This layer type represents one pass of multi-head self-attention, from vector - set to vector set, using masks to represent out-of-bound (e.g., padding) - positions. It: - - - makes three copies of incoming activations and maps these to multi-head - query (Q) vectors, key (K) vectors, and value (V) vectors, respectively; - - for each head, computes the scaled dot product of each Q-K pair; - - applies mask to screen out positions that come from padding tokens - (indicated by 0 value); - - [in ``'train'`` mode] applies dropout to Q-K dot products; - - for each head, computes Q-K attention strengths using a per-query softmax - of the Q-K dot products; - - for each head, for each query position, combines V vectors according - to the Q-K attention strengths; and - - concatenates and fuses resulting per-head vectors into outgoing - activations matching original input activation shapes. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - """ - return cb.Serial( - cb.Select([0, 0, 0]), - AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode), - ) - - -@assert_shape('bSq,blk,blv,b1xl->bSd,b1xl') -def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train', - cache_KV_in_predict=False, q_sparsity=None, - result_sparsity=None): - """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`. - - Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the - incoming activations (`AQ`, `AK`, and `AV`) to come from different sources. - This is used, for instance, in encoder-decoder attention (Q-related - activations `AQ` from the decoder, K- and V-related activations -- `AK` and - `AV` -- from the encoder). Otherwise, see the :py:class:`Attention` - description for further context/details. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. 
Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode. - q_sparsity: Sparsity with which to process queries. If ``None``, - :py:class:`Dense` is used; if ``'noop'``, no processing is used. - result_sparsity: Sparsity with which to process result of the attention. - If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is - used. - """ - def _SparsifiableDense(layer_sparsity): - if layer_sparsity is None: - return core.Dense(d_feature) - elif layer_sparsity == 'noop': - return cb.Serial() # No-op layer. - else: - d_module = d_feature // layer_sparsity - return cb.Serial( - sparsity.FactoredDense(layer_sparsity, d_feature, d_feature), - sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode, - kernel_size=3, length_kernel_size=3) - ) - - def _CacheableDense(): - if cache_KV_in_predict and mode == 'predict': - return cb.Cache(core.Dense(d_feature)) - else: - return core.Dense(d_feature) - - def _PureAttention(): - return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode) - - return cb.Serial( - cb.Parallel(_SparsifiableDense(q_sparsity), - _CacheableDense(), - _CacheableDense()), - _PureAttention(), - _SparsifiableDense(result_sparsity), - ) - - -# 'k' is number of keys/values, while 'l' is number of queries. Typically they -# will be the same, but it is not necessary. -@assert_shape('blq,bkq,bkd,b1xk->bld,b1xk') -class PureAttention(base.Layer): - """Returns a layer that maps `(Q, K, V, mask)` to `(activations, mask)`. - - This layer type performs the inner workings of one pass of multi-head - self-attention. It: - - - subdivides incoming Q/K/V activations into multi-head versions; - - for each head, computes the scaled dot product of each Q-K pair; - - applies mask to screen out positions that come from padding tokens - (indicated by 0 value); - - [in ``'train'`` mode] applies dropout to Q-K dot products; - - for each head, computes Q-K attention strengths using a per-query softmax - of the Q-K dot products; - - for each head, for each query position, combines V vectors according - to the Q-K attention strengths; and - - concatenates and fuses resulting per-head vectors into outgoing - activations matching original input activation shapes. - """ - - def __init__(self, n_heads=1, dropout=0.0, mode='train'): - """Returns a new :py:class:`PureAttention` instance. +@assert_shape("bld,b11l->bld,b11l") +def Attention(d_feature, n_heads=1, dropout=0.0, mode="train"): + """Returns a layer that maps `(vectors, mask)` to `(new_vectors, mask)`. + + This layer type represents one pass of multi-head self-attention, from vector + set to vector set, using masks to represent out-of-bound (e.g., padding) + positions. 
It: + + - makes three copies of incoming activations and maps these to multi-head + query (Q) vectors, key (K) vectors, and value (V) vectors, respectively; + - for each head, computes the scaled dot product of each Q-K pair; + - applies mask to screen out positions that come from padding tokens + (indicated by 0 value); + - [in ``'train'`` mode] applies dropout to Q-K dot products; + - for each head, computes Q-K attention strengths using a per-query softmax + of the Q-K dot products; + - for each head, for each query position, combines V vectors according + to the Q-K attention strengths; and + - concatenates and fuses resulting per-head vectors into outgoing + activations matching original input activation shapes. Args: - n_heads: Number of attention heads. + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. dropout: Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors @@ -203,305 +95,384 @@ def __init__(self, n_heads=1, dropout=0.0, mode='train'): created in ``'train'`` mode. mode: One of ``'train'``, ``'eval'``, or ``'predict'``. """ - super().__init__(n_in=4, n_out=2) - self._n_heads = n_heads - self._dropout = dropout - self._mode = mode - - def forward(self, inputs): - """Returns attention-computed activations and unmodified mask. + return cb.Serial( + cb.Select([0, 0, 0]), + AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode), + ) + + +@assert_shape("bSq,blk,blv,b1xl->bSd,b1xl") +def AttentionQKV( + d_feature, + n_heads=1, + dropout=0.0, + mode="train", + cache_KV_in_predict=False, + q_sparsity=None, + result_sparsity=None, +): + """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`. + + Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the + incoming activations (`AQ`, `AK`, and `AV`) to come from different sources. + This is used, for instance, in encoder-decoder attention (Q-related + activations `AQ` from the decoder, K- and V-related activations -- `AK` and + `AV` -- from the encoder). Otherwise, see the :py:class:`Attention` + description for further context/details. Args: - inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value - activations have not yet been subdivided into heads. + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. + dropout: Probabilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode. + q_sparsity: Sparsity with which to process queries. If ``None``, + :py:class:`Dense` is used; if ``'noop'``, no processing is used. + result_sparsity: Sparsity with which to process result of the attention. 
+ If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is + used. """ - q, k, v, mask = inputs - d_feature = q.shape[-1] - n_heads = self._n_heads - if d_feature % n_heads != 0: - raise ValueError( - f'Dimensionality of feature embedding ({d_feature}) is not a ' - f'multiple of the requested number of attention heads ({n_heads}).') - - per_head_results, dots = _per_head_attention( - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), - SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), - mask, - dropout=self._dropout, - mode=self._mode, - rng=self.rng) - if self._mode == 'viz': - self.state = dots - merged_results = MergeHeads(n_heads, merged_batch_and_head=False).forward( - per_head_results) - return (merged_results, mask) + def _SparsifiableDense(layer_sparsity): + if layer_sparsity is None: + return core.Dense(d_feature) + elif layer_sparsity == "noop": + return cb.Serial() # No-op layer. + else: + d_module = d_feature // layer_sparsity + return cb.Serial( + sparsity.FactoredDense(layer_sparsity, d_feature, d_feature), + sparsity.LocallyConvDense( + layer_sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=3, + ), + ) + + def _CacheableDense(): + if cache_KV_in_predict and mode == "predict": + return cb.Cache(core.Dense(d_feature)) + else: + return core.Dense(d_feature) + + def _PureAttention(): + return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode) + + return cb.Serial( + cb.Parallel( + _SparsifiableDense(q_sparsity), _CacheableDense(), _CacheableDense() + ), + _PureAttention(), + _SparsifiableDense(result_sparsity), + ) -def _per_head_attention(queries, keys, values, mask, dropout, mode, rng): - """Computes new per-head activations via scaled dot-product attention. - - This function is the core of the attention mechanism. Given per-head - ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it: - - - computes the scaled dot product of each Q-K pair; - - applies ``mask`` to screen out positions that come from padding tokens - (indicated by 0 value); - - [in ``'train'`` mode] applies dropout to Q-K dot products; - - computes Q-K attention strengths using a per-query softmax of the Q-K dot - products; and - - for each query position, combines V vectors according to the Q-K - attention strengths. - - Args: - queries: Per-head activations representing attention queries. - keys: Per-head activations representing attention keys. - values: Per-head activations to be combined by computed attention strengths. - mask: Mask that distinguishes positions with real content vs. padding. - dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only in ``'train'`` - mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - rng: Single-use random number generator (JAX PRNG key). - - Returns: - Tuple of (activations, attn_strengths), where activations are new per-head - activation vectors and attn_strengths is a matrix of per-head attention - strengths. 
- """ - if dropout >= 1.0: - raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.') - - d_feature = queries.shape[-1] - - dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature) - if mask is not None: - dots = jnp.where(mask, - dots, - jnp.full_like(dots, -1e9)) - attn_strengths = ( - jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))) - if dropout is not None and dropout > 0.0 and mode == 'train': - keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape) - attn_strengths = jnp.where(keep, - attn_strengths / (1.0 - dropout), - jnp.zeros_like(attn_strengths)) - activations = jnp.matmul(attn_strengths, values).astype(jnp.float32) - attn_strengths = attn_strengths.astype(jnp.float32) - return activations, attn_strengths +# 'k' is number of keys/values, while 'l' is number of queries. Typically they +# will be the same, but it is not necessary. +@assert_shape("blq,bkq,bkd,b1xk->bld,b1xk") +class PureAttention(base.Layer): + """Returns a layer that maps `(Q, K, V, mask)` to `(activations, mask)`. + + This layer type performs the inner workings of one pass of multi-head + self-attention. It: + + - subdivides incoming Q/K/V activations into multi-head versions; + - for each head, computes the scaled dot product of each Q-K pair; + - applies mask to screen out positions that come from padding tokens + (indicated by 0 value); + - [in ``'train'`` mode] applies dropout to Q-K dot products; + - for each head, computes Q-K attention strengths using a per-query softmax + of the Q-K dot products; + - for each head, for each query position, combines V vectors according + to the Q-K attention strengths; and + - concatenates and fuses resulting per-head vectors into outgoing + activations matching original input activation shapes. + """ + def __init__(self, n_heads=1, dropout=0.0, mode="train"): + """Returns a new :py:class:`PureAttention` instance. + + Args: + n_heads: Number of attention heads. + dropout: Probababilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ + super().__init__(n_in=4, n_out=2) + self._n_heads = n_heads + self._dropout = dropout + self._mode = mode + + def forward(self, inputs): + """Returns attention-computed activations and unmodified mask. + + Args: + inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value + activations have not yet been subdivided into heads. + """ + q, k, v, mask = inputs + + d_feature = q.shape[-1] + n_heads = self._n_heads + if d_feature % n_heads != 0: + raise ValueError( + f"Dimensionality of feature embedding ({d_feature}) is not a " + f"multiple of the requested number of attention heads ({n_heads})." 
+ ) + + per_head_results, dots = _per_head_attention( + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), + mask, + dropout=self._dropout, + mode=self._mode, + rng=self.rng, + ) + if self._mode == "viz": + self.state = dots + merged_results = MergeHeads(n_heads, merged_batch_and_head=False).forward( + per_head_results + ) + return (merged_results, mask) -class DotProductAttention(base.Layer): - """Returns a layer that computes per-head attention (via scaled dot-product). - This layer computes the core of the attention mechanism. Given per-head - queries (Q), keys (K), values (V), and mask, it: +def _per_head_attention(queries, keys, values, mask, dropout, mode, rng): + """Computes new per-head activations via scaled dot-product attention. - - computes the scaled dot product of each Q-K pair; - - applies mask to screen out positions that come from padding tokens - (indicated by 0 value); - - [if created in ``'train'`` mode] applies dropout to Q-K dot products; - - computes Q-K attention strengths using a per-query softmax of the Q-K dot - products; and - - for each query position, combines V vectors according to the Q-K - attention strengths. - """ + This function is the core of the attention mechanism. Given per-head + ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it: - def __init__(self, dropout=0.0, mode='train'): - """Creates a :py:class:`DotProductAttention` instance in a specific mode. + - computes the scaled dot product of each Q-K pair; + - applies ``mask`` to screen out positions that come from padding tokens + (indicated by 0 value); + - [in ``'train'`` mode] applies dropout to Q-K dot products; + - computes Q-K attention strengths using a per-query softmax of the Q-K dot + products; and + - for each query position, combines V vectors according to the Q-K + attention strengths. Args: + queries: Per-head activations representing attention queries. + keys: Per-head activations representing attention keys. + values: Per-head activations to be combined by computed attention strengths. + mask: Mask that distinguishes positions with real content vs. padding. dropout: Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - mode: One of ``'train'``, ``'eval'``, ``'predict'`` or ``'viz'``. - """ - super().__init__(n_in=4, n_out=1) - self._dropout = dropout - self._mode = mode - - def forward(self, inputs): - """Returns attention-computed per-head activations and unchanged mask. + cause some node activations to be ignored. Applies only in ``'train'`` + mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + rng: Single-use random number generator (JAX PRNG key). - Args: - inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value - activations have been subdivided into heads. + Returns: + Tuple of (activations, attn_strengths), where activations are new per-head + activation vectors and attn_strengths is a matrix of per-head attention + strengths. 
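For orientation while reading this hunk: the computation that `_per_head_attention` documents is ordinary masked scaled dot-product attention. A minimal NumPy sketch with made-up shapes (illustrative only; the real function runs on `trax.fastmath`, uses a logsumexp-based softmax, and threads an RNG for dropout):

    import numpy as np

    def per_head_attention_sketch(q, k, v, mask):
        # q, k, v: [n_heads, seq_len, d_head]; mask broadcasts against the dots.
        d_head = q.shape[-1]
        dots = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head)
        dots = np.where(mask, dots, -1e9)          # screen out masked positions
        dots -= dots.max(axis=-1, keepdims=True)   # numerically stable softmax
        weights = np.exp(dots)
        weights /= weights.sum(axis=-1, keepdims=True)
        return weights @ v                         # combine V per query position

    q = k = v = np.ones((2, 3, 4))                 # 2 heads, length 3, d_head 4
    out = per_head_attention_sketch(q, k, v, np.ones((1, 3, 3), dtype=bool))
    assert out.shape == (2, 3, 4)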
""" - q, k, v, mask = inputs - activations, attn_strengths = _per_head_attention( - q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng) - if self._mode == 'viz': - self.state = attn_strengths - return activations + if dropout >= 1.0: + raise ValueError(f"Dropout rate ({dropout}) must be lower than 1.") + d_feature = queries.shape[-1] -# (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) -@assert_shape('bld->...lh') -def SplitIntoHeads(n_heads, merged_batch_and_head=True): - """Returns a layer that reshapes an array for multi-head computation.""" - def f(x): - batch_size, seq_len, d_feature = x.shape - if d_feature % n_heads != 0: - raise ValueError( - f'Feature embedding dimensionality ({d_feature}) is not a multiple' - f' of the requested number of attention heads ({n_heads}).') + dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature) + if mask is not None: + dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) + attn_strengths = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)) + if dropout is not None and dropout > 0.0 and mode == "train": + keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape) + attn_strengths = jnp.where( + keep, attn_strengths / (1.0 - dropout), jnp.zeros_like(attn_strengths) + ) + activations = jnp.matmul(attn_strengths, values).astype(jnp.float32) + attn_strengths = attn_strengths.astype(jnp.float32) + return activations, attn_strengths - d_head = d_feature // n_heads - # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) - x = x.reshape((batch_size, seq_len, n_heads, d_head)) - x = x.transpose((0, 2, 1, 3)) - if merged_batch_and_head: - x = x.reshape((batch_size * n_heads, seq_len, d_head)) - return x - return Fn('SplitIntoHeads', f) +class DotProductAttention(base.Layer): + """Returns a layer that computes per-head attention (via scaled dot-product). + + This layer computes the core of the attention mechanism. Given per-head + queries (Q), keys (K), values (V), and mask, it: + + - computes the scaled dot product of each Q-K pair; + - applies mask to screen out positions that come from padding tokens + (indicated by 0 value); + - [if created in ``'train'`` mode] applies dropout to Q-K dot products; + - computes Q-K attention strengths using a per-query softmax of the Q-K dot + products; and + - for each query position, combines V vectors according to the Q-K + attention strengths. + """ + def __init__(self, dropout=0.0, mode="train"): + """Creates a :py:class:`DotProductAttention` instance in a specific mode. + + Args: + dropout: Probababilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + mode: One of ``'train'``, ``'eval'``, ``'predict'`` or ``'viz'``. + """ + super().__init__(n_in=4, n_out=1) + self._dropout = dropout + self._mode = mode + + def forward(self, inputs): + """Returns attention-computed per-head activations and unchanged mask. + + Args: + inputs: A `(Q, K, V, mask)` tuple, whose query, key, and value + activations have been subdivided into heads. 
+ """ + q, k, v, mask = inputs + activations, attn_strengths = _per_head_attention( + q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng + ) + if self._mode == "viz": + self.state = attn_strengths + return activations -# (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) -@assert_shape('...lh->bld') -def MergeHeads(n_heads, merged_batch_and_head=True): - """Returns a layer that rejoins heads, after multi-head computation.""" - def f(x): - if merged_batch_and_head: - dim_0, seq_len, d_head = x.shape - if dim_0 % n_heads != 0: - raise ValueError( - f"Array's leading dimension ({dim_0}) is not a multiple of the" - f" number of attention heads ({n_heads}).") - batch_size = dim_0 // n_heads - x = x.reshape((batch_size, n_heads, seq_len, d_head)) - else: - batch_size, _, seq_len, d_head = x.shape - - # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) - x = x.transpose((0, 2, 1, 3)) - x = x.reshape((batch_size, seq_len, n_heads * d_head)) - return x - return Fn('MergeHeads', f) - - -@assert_shape('bld->bld') -def ConfigurableAttention(q_layer, k_layer, v_layer, final_layer, # pylint: disable=invalid-name - qkv_attention_layer, n_heads=1): - """Returns a configured multi-head self-attention layer. - - A :py:class:`ConfigurableAttention` layer acts similarly to - :py:class:`Attention` layers, but with configurable components. It - - - makes three copies of incoming activations and uses ``q_layer``, - ``k_layer``, and ``v_layer`` to map activations to multi-head query (Q) - vectors, key (K) vectors, and value (V) vectors, respectively; - - uses ``qkv_attention_layer`` to compute per-head attention, similar to - :py:class:`DotProductAttention` or :py:class:`DotProductCausalAttention`; - - concatenates and fuses resulting per-head vectors into activations - matching original input activation shapes; and - - applies a final layer, ``final_layer``, mapping activations to - activations (with shape matching the original input activations). - - Args: - q_layer: Layer that maps input activations to per-head query activations. - k_layer: Layer that maps input activations to per-head key activations. - v_layer: Layer that maps input activations to per-head value activations. - final_layer: After main multi-head computation and rejoining of heads, - layer that maps activations to activations (with shape matching the - original input activations). - qkv_attention_layer: Layer the does the core multi-head self-attention - computation. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. - """ - return cb.Serial( - cb.Branch( - [q_layer, SplitIntoHeads(n_heads)], - [k_layer, SplitIntoHeads(n_heads)], - [v_layer, SplitIntoHeads(n_heads)], - ), - qkv_attention_layer, - MergeHeads(n_heads), - final_layer - ) - - -@assert_shape('bld->bld') -def CausalAttention(d_feature, - n_heads=1, - dropout=0.0, - max_inference_length=2048, - use_dconv=False, - mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like :py:class:`Attention`, this layer type represents one pass of multi-head - self-attention, but with causal masking rather than padding-based masking. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. 
- dropout: Probababilistic rate for attention dropout, which overrides - (sets to zero) some attention strengths derived from query-key - matching. As a result, on a given forward pass, some value vectors - don't contribute to the output, analogous to how regular dropout can - cause some node activations to be ignored. Applies only if layer is - created in ``'train'`` mode. - max_inference_length: Maximum sequence length allowed in non-training - modes. - use_dconv: if True, use depthwise convolutions on top of dense layers - for Q, K and V. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - """ - if d_feature % n_heads != 0: - raise ValueError( - f'Dimensionality of feature embedding ({d_feature}) is not a multiple ' - f'of the requested number of attention heads ({n_heads}).') - - def QKVLayer(): - """Function returning the Q, K and V layer.""" - if use_dconv: - return cb.Serial(core.Dense(d_feature), convolution.CausalDepthwiseConv()) - else: - return core.Dense(d_feature) +# (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) +@assert_shape("bld->...lh") +def SplitIntoHeads(n_heads, merged_batch_and_head=True): + """Returns a layer that reshapes an array for multi-head computation.""" - return ConfigurableAttention( - QKVLayer(), - QKVLayer(), - QKVLayer(), - core.Dense(d_feature), - n_heads=n_heads, - qkv_attention_layer=DotProductCausalAttention( - dropout=dropout, max_inference_length=max_inference_length, - mode=mode)) + def f(x): + batch_size, seq_len, d_feature = x.shape + if d_feature % n_heads != 0: + raise ValueError( + f"Feature embedding dimensionality ({d_feature}) is not a multiple" + f" of the requested number of attention heads ({n_heads})." + ) + d_head = d_feature // n_heads -@assert_shape('bld,bld,bld->bld') -class DotProductCausalAttention(base.Layer): - """Layer that computes attention strengths by masking out the "future". + # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head) + x = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head)) + x = x.transpose((0, 2, 1, 3)) + if merged_batch_and_head: + x = jnp.reshape(x, (batch_size * n_heads, seq_len, d_head)) + return x + + return Fn("SplitIntoHeads", f) - Causal attention uses masking to prevent a given sequence position from - attending to positions greater than / following it. This is used, for - example, when training autoregressive sequence models, or when decoding a - sequence symbol by symbol. - This layer performs the core per-head attention calculation. The layer - assumes that any splitting into attention heads precedes it, and that any - merging of attention heads will follow it. - """ +# (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) +@assert_shape("...lh->bld") +def MergeHeads(n_heads, merged_batch_and_head=True): + """Returns a layer that rejoins heads, after multi-head computation.""" + + def f(x): + if merged_batch_and_head: + dim_0, seq_len, d_head = x.shape + if dim_0 % n_heads != 0: + raise ValueError( + f"Array's leading dimension ({dim_0}) is not a multiple of the" + f" number of attention heads ({n_heads})." 
+ ) + + batch_size = dim_0 // n_heads + x = x.reshape((batch_size, n_heads, seq_len, d_head)) + else: + batch_size, _, seq_len, d_head = x.shape + + # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature) + x = x.transpose((0, 2, 1, 3)) + x = x.reshape((batch_size, seq_len, n_heads * d_head)) + return x + + return Fn("MergeHeads", f) + + +@assert_shape("bld->bld") +def ConfigurableAttention( + q_layer, + k_layer, + v_layer, + final_layer, # pylint: disable=invalid-name + qkv_attention_layer, + n_heads=1, +): + """Returns a configured multi-head self-attention layer. + + A :py:class:`ConfigurableAttention` layer acts similarly to + :py:class:`Attention` layers, but with configurable components. It + + - makes three copies of incoming activations and uses ``q_layer``, + ``k_layer``, and ``v_layer`` to map activations to multi-head query (Q) + vectors, key (K) vectors, and value (V) vectors, respectively; + - uses ``qkv_attention_layer`` to compute per-head attention, similar to + :py:class:`DotProductAttention` or :py:class:`DotProductCausalAttention`; + - concatenates and fuses resulting per-head vectors into activations + matching original input activation shapes; and + - applies a final layer, ``final_layer``, mapping activations to + activations (with shape matching the original input activations). Args: - q_layer: Layer that maps input activations to per-head query activations. - k_layer: Layer that maps input activations to per-head key activations. - v_layer: Layer that maps input activations to per-head value activations. - final_layer: After main multi-head computation and rejoining of heads, - layer that maps activations to activations (with shape matching the - original input activations). - qkv_attention_layer: Layer the does the core multi-head self-attention - computation. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size - ``d_feature / n_heads``. + q_layer: Layer that maps input activations to per-head query activations. + k_layer: Layer that maps input activations to per-head key activations. + v_layer: Layer that maps input activations to per-head value activations. + final_layer: After main multi-head computation and rejoining of heads, + layer that maps activations to activations (with shape matching the + original input activations). + qkv_attention_layer: Layer that does the core multi-head self-attention + computation. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. + """ + return cb.Serial( + cb.Branch( + [q_layer, SplitIntoHeads(n_heads)], + [k_layer, SplitIntoHeads(n_heads)], + [v_layer, SplitIntoHeads(n_heads)], + ), + qkv_attention_layer, + MergeHeads(n_heads), + final_layer, + ) + + +@assert_shape("bld->bld") +def CausalAttention( + d_feature, + n_heads=1, + dropout=0.0, + max_inference_length=2048, + use_dconv=False, + mode="train", +): + """Returns a layer that maps activations to activations, with causal masking. + + Like :py:class:`Attention`, this layer type represents one pass of multi-head + self-attention, but with causal masking rather than padding-based masking. Args: + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size + ``d_feature / n_heads``. dropout: Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors @@ -510,345 +481,436 @@ def __init__(self, dropout=0.0, max_inference_length=2048, mode='train'): created in ``'train'`` mode. max_inference_length: Maximum sequence length allowed in non-training modes. + use_dconv: if True, use depthwise convolutions on top of dense layers + for Q, K and V. 
mode: One of ``'train'``, ``'eval'``, or ``'predict'``. """ - super().__init__(n_in=3, n_out=1) - self._dropout = dropout - self._mode = mode - self._max_len = max_inference_length - self._portal_mask = self.monkey_patched_mask() # pylint: disable=assignment-from-none - - def monkey_patched_mask(self): - # This is necessary for Terraformer model. See comments there. - # The mask will only be used in Terraformer in predict mode. - return None + if d_feature % n_heads != 0: + raise ValueError( + f"Dimensionality of feature embedding ({d_feature}) is not a multiple " + f"of the requested number of attention heads ({n_heads})." + ) + + def QKVLayer(): + """Function returning the Q, K and V layer.""" + if use_dconv: + return cb.Serial(core.Dense(d_feature), convolution.CausalDepthwiseConv()) + else: + return core.Dense(d_feature) + + return ConfigurableAttention( + QKVLayer(), + QKVLayer(), + QKVLayer(), + core.Dense(d_feature), + n_heads=n_heads, + qkv_attention_layer=DotProductCausalAttention( + dropout=dropout, max_inference_length=max_inference_length, mode=mode + ), + ) + + +@assert_shape("bld,bld,bld->bld") +class DotProductCausalAttention(base.Layer): + """Layer that computes attention strengths by masking out the "future". - def forward(self, inputs): - """Returns attention-computed activations. + Causal attention uses masking to prevent a given sequence position from + attending to positions greater than / following it. This is used, for + example, when training autoregressive sequence models, or when decoding a + sequence symbol by symbol. - Args: - inputs: A (queries, keys, values) tuple. + This layer performs the core per-head attention calculation. The layer + assumes that any splitting into attention heads precedes it, and that any + merging of attention heads will follow it. """ - q, k, v = inputs - if self._portal_mask is not None: - mask_for_predict = self._portal_mask.get_value() - else: - mask_for_predict = None - - if self._mode == 'predict': - self.state, mask = _fast_inference_update_state( - inputs, self.state, - mask_for_predict=mask_for_predict) - if self._portal_mask is not None: - (_, k, v, _) = self.state - else: - (k, v, _) = self.state + def __init__(self, dropout=0.0, max_inference_length=2048, mode="train"): + """Creates a :py:class:`DotProductCausalAttention` instance. + + Args: + dropout: Probabilistic rate for attention dropout, which overrides + (sets to zero) some attention strengths derived from query-key + matching. As a result, on a given forward pass, some value vectors + don't contribute to the output, analogous to how regular dropout can + cause some node activations to be ignored. Applies only if layer is + created in ``'train'`` mode. + max_inference_length: Maximum sequence length allowed in non-training + modes. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ + super().__init__(n_in=3, n_out=1) + self._dropout = dropout + self._mode = mode + self._max_len = max_inference_length + self._portal_mask = ( + self.monkey_patched_mask() + ) # pylint: disable=assignment-from-none + + def monkey_patched_mask(self): + # This is necessary for Terraformer model. See comments there. + # The mask will only be used in Terraformer in predict mode. + return None + + def forward(self, inputs): + """Returns attention-computed activations. + + Args: + inputs: A (queries, keys, values) tuple. 
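As context for the `forward` body that follows: the `_causal_mask` helper defined further down is just a lower-triangular boolean array. In plain NumPy, with a toy length (this mirrors the np.tril fallback visible in the diff):

    import numpy as np

    length = 4
    mask = np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)
    # mask[0, i, j] is True iff j <= i, so each query position attends only
    # to itself and to earlier positions:
    # [[ True False False False]
    #  [ True  True False False]
    #  [ True  True  True False]
    #  [ True  True  True  True]]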
+ """ + q, k, v = inputs + + if self._portal_mask is not None: + mask_for_predict = self._portal_mask.get_value() + else: + mask_for_predict = None + + if self._mode == "predict": + self.state, mask = _fast_inference_update_state( + inputs, self.state, mask_for_predict=mask_for_predict + ) + if self._portal_mask is not None: + (_, k, v, _) = self.state + else: + (k, v, _) = self.state + else: + sequence_length = q.shape[-2] + mask = _causal_mask(sequence_length) + + activations, attn_strengths = _per_head_attention( + q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng + ) + if self._mode == "viz": + self.state = attn_strengths + return activations + + def init_weights_and_state(self, input_signature): + """Initializes this layer for fast inference, if in ``'predict'`` mode.""" + if self._mode == "predict": + self.state = _fast_inference_init_state( + input_signature, self._max_len, predict_mask=self._portal_mask + ) + + +def _causal_mask(length): + # Not all backends define jnp.tril. However, using np.tril is inefficient + # in that it creates a large global constant. TODO(kitaev): try to find an + # alternative that works across all backends. + if fastmath.is_backend(fastmath.Backend.JAX): + return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0) else: - sequence_length = q.shape[-2] - mask = _causal_mask(sequence_length) + return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0) - activations, attn_strengths = _per_head_attention( - q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng) - if self._mode == 'viz': - self.state = attn_strengths - return activations - def init_weights_and_state(self, input_signature): - """Initializes this layer for fast inference, if in ``'predict'`` mode.""" - if self._mode == 'predict': - self.state = _fast_inference_init_state( - input_signature, self._max_len, - predict_mask=self._portal_mask) +@assert_shape("...d->...d") +def ShiftRight(n_positions=1, mode="train"): + """Returns a layer that can insert padding to shift the input sequence. + Args: + n_positions: Number of positions to shift the input sequence rightward; + initial positions freed by the shift get padded with zeros. Applies + only if layer is created in a non-``'eval'`` mode. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + """ -def _causal_mask(length): - # Not all backends define jnp.tril. However, using np.tril is inefficient - # in that it creates a large global constant. TODO(kitaev): try to find an - # alternative that works across all backends. - if fastmath.is_backend(fastmath.Backend.JAX): - return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0) - else: - return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0) - - -@assert_shape('...d->...d') -def ShiftRight(n_positions=1, mode='train'): - """Returns a layer that can insert padding to shift the input sequence. - - Args: - n_positions: Number of positions to shift the input sequence rightward; - initial positions freed by the shift get padded with zeros. Applies - only if layer is created in a non-``'eval'`` mode. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - """ - # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads? - def f(x): - if mode == 'predict': - return x - padded = _zero_pad(x, (n_positions, 0), 1) - return padded[:, :-n_positions] - return Fn(f'ShiftRight({n_positions})', f) - - -@assert_shape('bs->b11l') + # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads? 
+ def f(x): + if mode == "predict": + return x + padded = _zero_pad(x, (n_positions, 0), 1) + return padded[:, :-n_positions] + + return Fn(f"ShiftRight({n_positions})", f) + + +@assert_shape("bs->b11l") def PaddingMask(pad=0): - """Returns a layer that maps integer sequences to padding masks. - - The layer expects as input a batch of integer sequences. The layer output is - an N-D array that marks for each sequence position whether the integer (e.g., - a token ID) in that position represents padding -- value ``pad`` -- versus - text/content -- all other values. The padding mask shape is - (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast - to cover any number of attention heads and axis 2 will broadcast to cover - decoder sequence positions. - - Args: - pad: Integer that represents padding rather than a token/content ID. - """ - def f(x): - if len(x.shape) != 2: - raise ValueError( - f'Input to PaddingMask must be a 2-D array with shape ' - f'(batch_size, sequence_length); instead got shape {x.shape}.') - batch_size = x.shape[0] - sequence_length = x.shape[1] - content_positions = (x != pad) - return content_positions.reshape((batch_size, 1, 1, sequence_length)) - return Fn(f'PaddingMask({pad})', f) + """Returns a layer that maps integer sequences to padding masks. + + The layer expects as input a batch of integer sequences. The layer output is + an N-D array that marks for each sequence position whether the integer (e.g., + a token ID) in that position represents padding -- value ``pad`` -- versus + text/content -- all other values. The padding mask shape is + (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast + to cover any number of attention heads and axis 2 will broadcast to cover + decoder sequence positions. + + Args: + pad: Integer that represents padding rather than a token/content ID. + """ + + def f(x): + if len(x.shape) != 2: + raise ValueError( + f"Input to PaddingMask must be a 2-D array with shape " + f"(batch_size, sequence_length); instead got shape {x.shape}." + ) + batch_size = x.shape[0] + sequence_length = x.shape[1] + content_positions = x != pad + return content_positions.reshape((batch_size, 1, 1, sequence_length)) + + return Fn(f"PaddingMask({pad})", f) def EncoderDecoderMask(): - """Returns a layer that creates a mask for encoder-decoder cross attention. - - The layer expects two inputs: - - - decoder_input: batch of integer (e.g., token ID) sequences - - mask: padding mask from the encoder - - The layer output is a mask that marks for each sequence position (for both - encoder and decoder) whether that position can be attended to or not. The - encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, - encoder_sequence_length), such that axis 1 will automatically broadcast to - cover any number of attention heads. - """ - def f(decoder_input, mask): - if len(decoder_input.shape) != 3: - raise ValueError( - f'Decoder input to EncoderDecoderMask must be a 3-D array with ' - f'shape (batch_size, decoder_sequence_length, d_model); instead got ' - f'shape {decoder_input.shape}.') - batch_size = mask.shape[0] - encoder_sequence_length = mask.shape[-1] - decoder_sequence_length = decoder_input.shape[1] - mask = mask.reshape((batch_size, 1, 1, encoder_sequence_length)) - return mask + jnp.zeros((1, 1, decoder_sequence_length, 1)) - return Fn('EncoderDecoderMask', f) - - -@assert_shape('...d->...d') -class PositionalEncoding(base.Layer): - """Implements bare positional encoding. 
+ """Returns a layer that creates a mask for encoder-decoder cross attention. - Positional encoding includes a kind of dropout, if the layer is created in - ``'train'`` mode with a nonzero ``dropout`` value. For such a layer, on each - forward pass a subset of sequence positions selected at random will *not* - receive positional marking. - """ + The layer expects two inputs: - def __init__(self, max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2,), - use_bfloat16=False, start_from_zero_prob=1.0, - max_offset_to_add=0, d_feature=None, mode='train'): - """Creates a :py:class:`PositionalEncoding` instance in a given mode. + - decoder_input: batch of integer (e.g., token ID) sequences + - mask: padding mask from the encoder - Args: - max_len: Maximum input sequence length. - dropout: Probability of *not* adding positional encoding to a sequence - position. Applies only if layer is created in ``'train'`` mode. - dropout_broadcast_dims: Axes along which dropout mask values are - broadcast rather than individually set at random. - use_bfloat16: If ``True``, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - max_offset_to_add: maximum offset to add to the positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - d_feature: int or None; have this dimension for embeddings + shared FF if - not None. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + The layer output is a mask that marks for each sequence position (for both + encoder and decoder) whether that position can be attended to or not. The + encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, + encoder_sequence_length), such that axis 1 will automatically broadcast to + cover any number of attention heads. 
""" - super().__init__() - self._max_len = max_len - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if mode == 'train': - self._dropout = dropout - else: - self._dropout = 0.0 - self._dropout_broadcast_dims = dropout_broadcast_dims - self._use_bfloat16 = use_bfloat16 - self._start_from_zero_prob = start_from_zero_prob - self._max_offset_to_add = max_offset_to_add - self._mode = mode - self._d_feature = d_feature - - def forward(self, inputs): - """Returns the input activations, with added positional information.""" - weights = self.weights - if self._d_feature is not None: - weights, ff = weights - weights = jnp.dot(weights[:inputs.shape[1], :], ff) - if len(weights.shape) < 3: # old checkpoints have 1 in first dim already - weights = weights[None, :, :] # [1, self._max_len, d_feature] - if self._mode != 'predict': - x = inputs - symbol_size = jnp.shape(x)[1] - if self._mode != 'train' or self._start_from_zero_prob >= 1.0: - px = weights[:, :symbol_size, :] - else: - rng1, rng2 = fastmath.random.split(self.rng, 2) - start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add) - start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1) - start = jnp.where(start_from_zero < self._start_from_zero_prob, - jnp.zeros((), dtype=jnp.int32), start) - px = fastmath.dynamic_slice_in_dim(weights, start, symbol_size, - axis=1) - if self._dropout == 0: - return x + px - else: - noise_shape = list(px.shape) - for dim in self._dropout_broadcast_dims: - noise_shape[dim] = 1 - keep_prob = 1.0 - self._dropout - keep = fastmath.random.bernoulli(self.rng, keep_prob, - tuple(noise_shape)) - multiplier = keep.astype(x.dtype) / keep_prob - return x + px * multiplier - else: - if self._dropout != 0: - raise ValueError(f'In predict mode, but dropout rate ' - f'({self._dropout}) is not zero.') - - # State in this class is only used for fast inference. In that case, - # the model is called with consecutive elements position-by-position. - # This positional encoding layer stores the index of the current - # position and increments it on each call. - emb = fastmath.dynamic_slice_in_dim( - weights, self.state, inputs.shape[1], axis=1) - self.state += inputs.shape[1] - return inputs + emb - - def init_weights_and_state(self, input_signature): - """Randomly initializes the positional encoding vectors. - Args: - input_signature: :py:class:`ShapeDtype` instance characterizing the input - this layer should compute on. + def f(decoder_input, mask): + if len(decoder_input.shape) != 3: + raise ValueError( + f"Decoder input to EncoderDecoderMask must be a 3-D array with " + f"shape (batch_size, decoder_sequence_length, d_model); instead got " + f"shape {decoder_input.shape}." + ) + batch_size = mask.shape[0] + encoder_sequence_length = mask.shape[-1] + decoder_sequence_length = decoder_input.shape[1] + mask = mask.reshape((batch_size, 1, 1, encoder_sequence_length)) + return mask + jnp.zeros((1, 1, decoder_sequence_length, 1)) + + return Fn("EncoderDecoderMask", f) + + +@assert_shape("...d->...d") +class PositionalEncoding(base.Layer): + """Implements bare positional encoding. + + Positional encoding includes a kind of dropout, if the layer is created in + ``'train'`` mode with a nonzero ``dropout`` value. For such a layer, on each + forward pass a subset of sequence positions selected at random will *not* + receive positional marking. 
""" - d_feature = input_signature.shape[-1] - if self._d_feature is not None: - d_feature = self._d_feature - pe = np.zeros((self._max_len, d_feature), dtype=np.float32) - position = np.arange(0, self._max_len)[:, np.newaxis] - div_term = np.exp( - np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) - pe[:, 0::2] = np.sin(position * div_term) - pe[:, 1::2] = np.cos(position * div_term) # [self._max_len, d_feature] - if self._use_bfloat16: - pe = pe.astype(jnp.bfloat16) - w = jnp.array(pe) # Trainable parameters, initialized above. - if self._d_feature is not None: - ff = init.GlorotUniformInitializer()( - (d_feature, input_signature.shape[-1]), self.rng) - self.weights = w, ff - else: - self.weights = w - if self._mode == 'predict': - self.state = jnp.zeros((), dtype=jnp.int32) + + def __init__( + self, + max_len=2048, + dropout=0.0, + dropout_broadcast_dims=(-2,), + use_bfloat16=False, + start_from_zero_prob=1.0, + max_offset_to_add=0, + d_feature=None, + mode="train", + ): + """Creates a :py:class:`PositionalEncoding` instance in a given mode. + + Args: + max_len: Maximum input sequence length. + dropout: Probability of *not* adding positional encoding to a sequence + position. Applies only if layer is created in ``'train'`` mode. + dropout_broadcast_dims: Axes along which dropout mask values are + broadcast rather than individually set at random. + use_bfloat16: If ``True``, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + max_offset_to_add: maximum offset to add to the positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + d_feature: int or None; have this dimension for embeddings + shared FF if + not None. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. 
+ """ + super().__init__() + self._max_len = max_len + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if mode == "train": + self._dropout = dropout + else: + self._dropout = 0.0 + self._dropout_broadcast_dims = dropout_broadcast_dims + self._use_bfloat16 = use_bfloat16 + self._start_from_zero_prob = start_from_zero_prob + self._max_offset_to_add = max_offset_to_add + self._mode = mode + self._d_feature = d_feature + + def forward(self, inputs): + """Returns the input activations, with added positional information.""" + weights = self.weights + if self._d_feature is not None: + weights, ff = weights + weights = jnp.dot(weights[: inputs.shape[1], :], ff) + if len(weights.shape) < 3: # old checkpoints have 1 in first dim already + weights = weights[None, :, :] # [1, self._max_len, d_feature] + if self._mode != "predict": + x = inputs + symbol_size = jnp.shape(x)[1] + if self._mode != "train" or self._start_from_zero_prob >= 1.0: + px = weights[:, :symbol_size, :] + else: + rng1, rng2 = fastmath.random.split(self.rng, 2) + start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add) + start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1) + start = jnp.where( + start_from_zero < self._start_from_zero_prob, + jnp.zeros((), dtype=jnp.int32), + start, + ) + px = fastmath.dynamic_slice_in_dim(weights, start, symbol_size, axis=1) + if self._dropout == 0: + return x + px + else: + noise_shape = list(px.shape) + for dim in self._dropout_broadcast_dims: + noise_shape[dim] = 1 + keep_prob = 1.0 - self._dropout + keep = fastmath.random.bernoulli( + self.rng, keep_prob, tuple(noise_shape) + ) + multiplier = keep.astype(x.dtype) / keep_prob + return x + px * multiplier + else: + if self._dropout != 0: + raise ValueError( + f"In predict mode, but dropout rate " + f"({self._dropout}) is not zero." + ) + + # State in this class is only used for fast inference. In that case, + # the model is called with consecutive elements position-by-position. + # This positional encoding layer stores the index of the current + # position and increments it on each call. + emb = fastmath.dynamic_slice_in_dim( + weights, self.state, inputs.shape[1], axis=1 + ) + self.state += inputs.shape[1] + return inputs + emb + + def init_weights_and_state(self, input_signature): + """Randomly initializes the positional encoding vectors. + + Args: + input_signature: :py:class:`ShapeDtype` instance characterizing the input + this layer should compute on. + """ + d_feature = input_signature.shape[-1] + if self._d_feature is not None: + d_feature = self._d_feature + pe = np.zeros((self._max_len, d_feature), dtype=np.float32) + position = np.arange(0, self._max_len)[:, np.newaxis] + div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + pe[:, 0::2] = np.sin(position * div_term) + pe[:, 1::2] = np.cos(position * div_term) # [self._max_len, d_feature] + if self._use_bfloat16: + pe = pe.astype(jnp.bfloat16) + w = jnp.array(pe) # Trainable parameters, initialized above. + if self._d_feature is not None: + ff = init.GlorotUniformInitializer()( + (d_feature, input_signature.shape[-1]), self.rng + ) + self.weights = w, ff + else: + self.weights = w + if self._mode == "predict": + self.state = jnp.zeros((), dtype=jnp.int32) def _zero_pad(x, pad, axis): - """Helper for jnp.pad with 0s for single-axis case.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = pad # Padding on axis. 
- return jnp.pad(x, pad_widths, mode='constant') - - -def _fast_inference_init_state(input_signature, buffer_length, - predict_mask=None): - """Returns an initial state for causal attention layer fast inference.""" - def zeros_for(batch_size, shape_dtype): - shape, dtype = shape_dtype.as_tuple() - d_feature = shape[-1] - return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype) - - batch_size = input_signature[0].shape[0] - k = zeros_for(batch_size, input_signature[1]) - v = zeros_for(batch_size, input_signature[2]) - if predict_mask is not None: - mask_for_predict = jnp.zeros((buffer_length,)) != 0 - return (mask_for_predict, k, v, jnp.array(0)) - else: - return (k, v, jnp.array(0)) + """Helper for jnp.pad with 0s for single-axis case.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = pad # Padding on axis. + return jnp.pad(x, pad_widths, mode="constant") + + +def _fast_inference_init_state(input_signature, buffer_length, predict_mask=None): + """Returns an initial state for causal attention layer fast inference.""" + + def zeros_for(batch_size, shape_dtype): + shape, dtype = shape_dtype.as_tuple() + d_feature = shape[-1] + return jnp.zeros((batch_size, buffer_length, d_feature), dtype=dtype) + + batch_size = input_signature[0].shape[0] + k = zeros_for(batch_size, input_signature[1]) + v = zeros_for(batch_size, input_signature[2]) + if predict_mask is not None: + mask_for_predict = jnp.zeros((buffer_length,)) != 0 + return (mask_for_predict, k, v, jnp.array(0)) + else: + return (k, v, jnp.array(0)) def _fast_inference_update_state(inputs, state, mask_for_predict=None): - """Updates state of a causal attention layer for fast inference. - - The layer state stores arrays with cached values of keys and values, - as well as an index. To make shapes static, keys and values in the state are - long, and the index indicates where the new keys and values from inputs need - to be appended. - - During update, we append new_keys and new_values to keys and values at - position given by index. And we increment index by length of new keys. - We also create a mask to be 1 at appropriate positions (causal mask). - - Args: - inputs: a triple (new_queries, new_keys, new_values) - state: layer state with (keys, values, index) - mask_for_predict: mask used for predict mode. This is used only in - Terraformer. - - Returns: - Updated state and mask to be used. - """ - # Fast inference: run step-by-step, storing the sequence - # of keys and values calculated so far in state. - (_, new_k, new_v) = inputs - if mask_for_predict is not None: - (state_mask_for_predict, ks, vs, idx) = state - else: - (ks, vs, idx) = state - length = new_k.shape[1] - # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path - # with index_update when length == 1 is worth it. - # Keys and values are of shape [batch_size, length, d_kv]. - ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1) - vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1) - k_length = ks.shape[1] - - # Mask is of shape [1, q_length, k_length]. - # Mask should be true for every pair of (query_token, key_token) such that - # index of query_token is equal or larger to index of key_token. 
-  mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length))
-          <= jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))
-  if mask_for_predict is None:
-    return (ks, vs, idx + length), mask
-  else:
-    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
-        state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0,
-        axis=0)
-
-    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
-        state_mask_for_predict != 0, jnp.ones((1,)) != 0,
-        jnp.sum(mask_for_predict, dtype=jnp.int32), axis=0)
-
-    state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
-        state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0)
-    placeholder = jnp.reshape(state_mask_for_predict != 0,
-                              (1, 1, mask.shape[2],))
-    mask = mask * placeholder
-
-    return (state_mask_for_predict, ks, vs, idx + length), mask
+    """Updates state of a causal attention layer for fast inference.
+
+    The layer state stores arrays with cached values of keys and values,
+    as well as an index. To make shapes static, keys and values in the state
+    are long, and the index indicates where the new keys and values from
+    inputs need to be appended.
+
+    During the update, we append new_keys and new_values to keys and values at
+    the position given by index, and we increment the index by the length of
+    the new keys. We also create a mask to be 1 at appropriate positions
+    (causal mask).
+
+    Args:
+      inputs: a triple (new_queries, new_keys, new_values)
+      state: layer state with (keys, values, index)
+      mask_for_predict: mask used for predict mode. This is used only in
+        Terraformer.
+
+    Returns:
+      Updated state and mask to be used.
+    """
+    # Fast inference: run step-by-step, storing the sequence
+    # of keys and values calculated so far in state.
+    (_, new_k, new_v) = inputs
+    if mask_for_predict is not None:
+        (state_mask_for_predict, ks, vs, idx) = state
+    else:
+        (ks, vs, idx) = state
+    length = new_k.shape[1]
+    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
+    # with index_update when length == 1 is worth it.
+    # Keys and values are of shape [batch_size, length, d_kv].
+    ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
+    vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
+    k_length = ks.shape[1]
+
+    # Mask is of shape [1, q_length, k_length].
+    # Mask should be true for every pair of (query_token, key_token) such that
+    # the index of query_token is equal to or larger than the index of key_token.
+    mask = jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) <= jnp.reshape(
+        jnp.arange(length) + idx, (1, length, 1)
+    )
+    if mask_for_predict is None:
+        return (ks, vs, idx + length), mask
+    else:
+        state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
+            state_mask_for_predict != 0, mask_for_predict.reshape((-1)) != 0, 0, axis=0
+        )
+
+        state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
+            state_mask_for_predict != 0,
+            jnp.ones((1,)) != 0,
+            jnp.sum(mask_for_predict, dtype=jnp.int32),
+            axis=0,
+        )
+
+        state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
+            state_mask_for_predict != 0, jnp.ones((1,)) != 0, idx, axis=0
+        )
+        placeholder = jnp.reshape(
+            state_mask_for_predict != 0, (1, 1, mask.shape[2])
+        )
+        mask = mask * placeholder
+
+        return (state_mask_for_predict, ks, vs, idx + length), mask
diff --git a/trax/layers/attention_test.py b/trax/layers/attention_test.py
deleted file mode 100644
index 165866d62..000000000
--- a/trax/layers/attention_test.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.layers.attention.""" - -import functools -from absl.testing import absltest -import numpy as np - -from trax import shapes -import trax.layers as tl -from trax.layers import test_utils - - -class AttentionTest(absltest.TestCase): - - def test_simple_call(self): - layer = tl.CausalAttention(d_feature=4, n_heads=2) - x = [np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]), - np.array([[[[1, 0, 1]]]])] - _, _ = layer.init(shapes.signature(x)) - - y, mask = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - self.assertEqual(mask.shape, (1, 1, 1, 3)) - - def test_shift_right(self): - # Test shifts right on axis=1 - layer = tl.ShiftRight() - x = np.array([[[9, 9, 9], - [8, 8, 8], - [7, 7, 7], - [6, 6, 6]], - [[99, 98, 97], - [96, 95, 94], - [93, 92, 91], - [90, 89, 88]]]) - y = layer(x) - self.assertEqual(x.shape, y.shape) - self.assertEqual(tl.to_list(y), [[[0, 0, 0], - [9, 9, 9], - [8, 8, 8], - [7, 7, 7]], - [[0, 0, 0], - [99, 98, 97], - [96, 95, 94], - [93, 92, 91]]]) - - def test_shift_right_float(self): - layer = tl.ShiftRight() - x = np.array([[[9, 9, 9], - [8, 8, 8], - [7, 7, 7], - [6, 6, 6]], - [[99, 98, 97], - [96, 95, 94], - [93, 92, 91], - [90, 89, 88]]]).astype(np.float32) - x /= 2.0 - self.assertEqual(x.dtype, np.float32) - - y = layer(x) - self.assertEqual(y.dtype, np.float32) - self.assertEqual(tl.to_list(y), [[[0.0, 0.0, 0.0], - [4.5, 4.5, 4.5], - [4.0, 4.0, 4.0], - [3.5, 3.5, 3.5]], - [[0.0, 0.0, 0.0], - [49.5, 49.0, 48.5], - [48.0, 47.5, 47.0], - [46.5, 46.0, 45.5]]]) - - def test_padding_mask(self): - layer = tl.PaddingMask() - x = np.array([ - [1., 2., 3., 4., 0.], - [1., 2., 3., 0., 0.], - [1., 2., 0., 0., 0.], - ]) - y = layer(x) - self.assertEqual(x.shape, (3, 5)) - self.assertEqual(y.shape, (3, 1, 1, 5)) - np.testing.assert_equal(y, [[[[True, True, True, True, False]]], - [[[True, True, True, False, False]]], - [[[True, True, False, False, False]]]]) - - -class CausalAttentionTest(absltest.TestCase): - - def test_simple_call(self): - layer = tl.CausalAttention(d_feature=4, n_heads=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_deterministic_eval(self): - d_model = 32 - seq_len = 3 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - tl.CausalAttention, - d_feature=d_model, - n_heads=4, - ) - - test_utils.test_eval_is_deterministic(inp, model_fn) - - def test_predict_equals_eval(self): - d_model = 32 - seq_len = 10 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - tl.CausalAttention, - d_feature=d_model, - n_heads=4, - ) - - test_utils.test_eval_equals_predict(inp, model_fn) - - -class PositionalEncodingTest(absltest.TestCase): - - def test_simple_call(self): - layer = tl.PositionalEncoding(max_len=8) - x = np.array([[[2.0, 3.0, 4.0, 5.0], [1.0, 
 2.0, 3.0, 4.0]]])
-    layer.init(shapes.signature(x))
-    y = layer(x)
-    self.assertEqual(y.shape, (1, 2, 4))
-
-  def test_predict(self):
-    layer = tl.PositionalEncoding(max_len=8)
-    x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]])
-    self.assertEqual(x.shape, (1, 4, 2))
-    layer.init(shapes.signature(x))
-    y = layer(x)
-    self.assertEqual(y.shape, (1, 4, 2))
-    layer = tl.PositionalEncoding(max_len=8, mode='predict')
-    layer.init(shapes.signature(x[:, :1, :]))
-    y0 = layer(x[:, :1, :])  # just the first token
-    self.assertEqual(y0.shape, (1, 1, 2))
-    self.assertTrue(np.array_equal(y0, y[:, :1, :]))
-    y1 = layer(x[:, 1:3, :])  # now the next 2 tokens
-    self.assertEqual(y1.shape, (1, 2, 2))
-    self.assertTrue(np.array_equal(y1, y[:, 1:3, :]))
-    y2 = layer(x[:, 3:4, :])  # final one token
-    self.assertEqual(y2.shape, (1, 1, 2))
-    self.assertTrue(np.array_equal(y2, y[:, 3:4, :]))
-
-  def test_predict_equals_eval(self):
-    x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]])
-    self.assertEqual(x.shape, (1, 4, 2))
-
-    layer_eval = tl.PositionalEncoding(max_len=8, d_feature=4, mode='eval')
-    layer_eval.init(shapes.signature(x))
-
-    output_eval = layer_eval(x)
-
-    layer_predict = tl.PositionalEncoding(max_len=8, d_feature=4,
-                                          mode='predict')
-    layer_predict.init(shapes.signature(x))
-    layer_predict.weights = layer_eval.weights
-
-    output_predict = layer_predict(x)
-    self.assertTrue(np.array_equal(output_eval, output_predict))
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/layers/base.py b/trax/layers/base.py
index 5c80f1521..0455710c4 100644
--- a/trax/layers/base.py
+++ b/trax/layers/base.py
@@ -21,1035 +21,1119 @@ import inspect
 import pickle
 import random
+import sys
 import traceback
 
 import jax
 import numpy as np
 import tensorflow as tf
 
+import trax.learning.supervised.history
+import trax.utils.shapes
+
 from trax import fastmath
 from trax.fastmath import nested_map
 from trax.fastmath import numpy as jnp
-from trax.shapes import ShapeDtype
-from trax.shapes import signature
-
+from trax.utils.shapes import ShapeDtype, signature
+
+# Register the old module paths so that older pickled objects (e.g., test
+# checkpoints), which reference `trax.shapes` and `trax.supervised.history`,
+# keep loading after the move to `trax.utils` and `trax.learning`.
+sys.modules["trax.shapes"] = trax.utils.shapes
+sys.modules["trax.supervised.history"] = trax.learning.supervised.history
 
 # TODO(lukaszkaiser): should we use special objects for these for clarity?
-EMPTY_WEIGHTS = ()  # Used for layers that have no trainable weights.
-EMPTY_STATE = ()  # Used for layers that have no non-trainable state.
-GET_WEIGHTS_FROM_CACHE = {'__marker_for_cached_weights_': ()}
-GET_STATE_FROM_CACHE = {'__marker_for_cached_state_': ()}
+EMPTY_WEIGHTS = ()  # Used for layers that have no trainable weights.
+EMPTY_STATE = ()  # Used for layers that have no non-trainable state.
+GET_WEIGHTS_FROM_CACHE = {"__marker_for_cached_weights_": ()}
+GET_STATE_FROM_CACHE = {"__marker_for_cached_state_": ()}
 
 N_WEIGHTS_SHARDS = 1  # TODO(lukaszkaiser): make weight-sharding non-global
 
 
 class Layer:
-  """Base class for composable layers in a deep learning network.
-
-  Layers are the basic building blocks for deep learning models. A layer
-  computes a function from zero or more inputs to zero or more outputs,
-  optionally using trainable weights (common) and non-parameter state (not
-  common).
- - Layer subclasses typically override at most two methods of the base `Layer` - class: - - `forward(inputs)`: - Computes the layer's output as part of a forward pass through the model. + """Base class for composable layers in a deep learning network. - `init_weights_and_state(self, input_signature)`: - Initializes the layer's weights and state to handle input with the given - signature (number, shapes and dtypes of input arguments). + Layers are the basic building blocks for deep learning models. A layer + computes a function from zero or more inputs to zero or more outputs, + optionally using trainable weights (common) and non-parameter state (not + common). - A small number of layer types are combinators -- they organize the computation - of their sublayers, e.g., applying their sublayers in series or in parallel. + Layer subclasses typically override at most two methods of the base `Layer` + class: - All layers have the following properties, with default values implemented - in the base `Layer` class: + `forward(inputs)`: + Computes the layer's output as part of a forward pass through the model. - - `n_in`: int (default 1) - - `n_out`: int (default 1) - - `weights`: tuple (default empty -- the layer has no weights) - - `state`: tuple (default empty -- the layer has no non-parameter state) - - `sublayers`: tuple (default empty -- the layer has no sublayers) + `init_weights_and_state(self, input_signature)`: + Initializes the layer's weights and state to handle input with the given + signature (number, shapes and dtypes of input arguments). - The inputs to a layer are tensors, packaged according to how many there are: + A small number of layer types are combinators -- they organize the computation + of their sublayers, e.g., applying their sublayers in series or in parallel. - - `n_in = 0`: an empty tuple - - `n_in = 1`: one tensor (NOT wrapped in a tuple) - - `n_in > 1`: a tuple of tensors + All layers have the following properties, with default values implemented + in the base `Layer` class: - (The special treatment of the single-input case is meant to simplify the - work of layer writers; this design choice may be revisited in the future.) + - `n_in`: int (default 1) + - `n_out`: int (default 1) + - `weights`: tuple (default empty -- the layer has no weights) + - `state`: tuple (default empty -- the layer has no non-parameter state) + - `sublayers`: tuple (default empty -- the layer has no sublayers) - The outputs from a layer are also tensors, packaged the same as layer inputs: + The inputs to a layer are tensors, packaged according to how many there are: - - `n_out = 0`: an empty tuple - - `n_out = 1`: the tensor (NOT wrapped in a tuple) - - `n_out > 1`: a tuple of tensors + - `n_in = 0`: an empty tuple + - `n_in = 1`: one tensor (NOT wrapped in a tuple) + - `n_in > 1`: a tuple of tensors - The Trax runtime maintains a data stack with which layer calls are composed. - For more complex data network architectures, possibly involving multiple data - flows, one can view each layer as a function from stack state to stack state, - where the function's inputs are a slice from the stack, and the function's - outputs are spliced back into the stack. - """ + (The special treatment of the single-input case is meant to simplify the + work of layer writers; this design choice may be revisited in the future.) - def __init__(self, n_in=1, n_out=1, name=None, sublayers_to_print=None): - """Creates a partially initialized, unconnected layer instance. 
+ The outputs from a layer are also tensors, packaged the same as layer inputs: - Args: - n_in: Number of inputs expected by this layer. - n_out: Number of outputs promised by this layer. - name: Class-like name for this layer; for use when printing this layer. - sublayers_to_print: Sublayers to display when printing out this layer; - if None (the default), display all sublayers. - """ - self._n_in = n_in - self._n_out = n_out - self._name = self.__class__.__name__ if name is None else name - self._sublayers_to_print = sublayers_to_print - self._sublayers = () # Default is no sublayers. - - # The actual rng value/shape depends on the backend, which may not yet be - # initialized at the point this method is run. Hence, at first initialize - # only a seed random integer, in a backend-neutral way. - self._rng = None - self._rng_seed_int = random.randint(0, 2**31 - 1) - - # The private fields _weights and _state store the private part of - # layer weights and state. When a layer has no sublayers, these are - # the same as layer.weights and layer.state. For layers with sublayers - # (i.e., combinators), these just mark which weights are cached -- see - # the getter and setter for weights and state for details. - # There is no need to use these fields in most user-implemented classes. - self._weights = EMPTY_WEIGHTS # By default no trainable weights. - self._state = EMPTY_STATE # By default no non-trainable state. - - # Record layer creation site for use in LayerError messages. - # The frame can mutate, so copy relevant values out of it. - frame = _find_frame(inspect.currentframe()) - self._caller = {'filename': copy.copy(frame.f_code.co_filename), - 'lineno': int(frame.f_lineno)} - del frame # Just in case. - - self._init_cached = False - self._jit_cache = {} - - def __repr__(self): - """Renders this layer as a medium-detailed string, to help in debugging. - - Subclasses should aim for high-signal/low-noise when overriding this - method. + - `n_out = 0`: an empty tuple + - `n_out = 1`: the tensor (NOT wrapped in a tuple) + - `n_out > 1`: a tuple of tensors - Returns: - A high signal-to-noise string representing this layer. + The Trax runtime maintains a data stack with which layer calls are composed. + For more complex data network architectures, possibly involving multiple data + flows, one can view each layer as a function from stack state to stack state, + where the function's inputs are a slice from the stack, and the function's + outputs are spliced back into the stack. """ - def indent_string(x): - return ' ' + x.replace('\n', '\n ') - name_str = self._name - n_in, n_out = self.n_in, self.n_out - if n_in != 1: name_str += f'_in{n_in}' - if n_out != 1: name_str += f'_out{n_out}' + def __init__(self, n_in=1, n_out=1, name=None, sublayers_to_print=None): + """Creates a partially initialized, unconnected layer instance. + + Args: + n_in: Number of inputs expected by this layer. + n_out: Number of outputs promised by this layer. + name: Class-like name for this layer; for use when printing this layer. + sublayers_to_print: Sublayers to display when printing out this layer; + if None (the default), display all sublayers. + """ + self._n_in = n_in + self._n_out = n_out + self._name = self.__class__.__name__ if name is None else name + self._sublayers_to_print = sublayers_to_print + self._sublayers = () # Default is no sublayers. + + # The actual rng value/shape depends on the backend, which may not yet be + # initialized at the point this method is run. 
Hence, at first initialize + # only a seed random integer, in a backend-neutral way. + self._rng = None + self._rng_seed_int = random.randint(0, 2**31 - 1) + + # The private fields _weights and _state store the private part of + # layer weights and state. When a layer has no sublayers, these are + # the same as layer.weights and layer.state. For layers with sublayers + # (i.e., combinators), these just mark which weights are cached -- see + # the getter and setter for weights and state for details. + # There is no need to use these fields in most user-implemented classes. + self._weights = EMPTY_WEIGHTS # By default no trainable weights. + self._state = EMPTY_STATE # By default no non-trainable state. + + # Record layer creation site for use in LayerError messages. + # The frame can mutate, so copy relevant values out of it. + frame = _find_frame(inspect.currentframe()) + self._caller = { + "filename": copy.copy(frame.f_code.co_filename), + "lineno": int(frame.f_lineno), + } + del frame # Just in case. + + self._init_cached = False + self._jit_cache = {} + + def __repr__(self): + """Renders this layer as a medium-detailed string, to help in debugging. + + Subclasses should aim for high-signal/low-noise when overriding this + method. + + Returns: + A high signal-to-noise string representing this layer. + """ + + def indent_string(x): + return " " + x.replace("\n", "\n ") + + name_str = self._name + n_in, n_out = self.n_in, self.n_out + if n_in != 1: + name_str += f"_in{n_in}" + if n_out != 1: + name_str += f"_out{n_out}" + + if self._sublayers_to_print is not None: + substructure = self._sublayers_to_print + else: + substructure = self.sublayers + if substructure: + substructure_strs = [str(x) for x in substructure if str(x)] + substructure_str = "\n".join(indent_string(s) for s in substructure_strs) + return f"{name_str}[\n{substructure_str}\n]" + else: + return name_str + + def __call__(self, x, weights=None, state=None, rng=None): + """Makes layers callable; for use in tests or interactive settings. + + This convenience method helps library users play with, test, or otherwise + probe the behavior of layers outside of a full training environment. It + presents the layer as callable function from inputs to outputs, with the + option of manually specifying weights and non-parameter state per individual + call. For convenience, weights and non-parameter state are cached per layer + instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`, + and acquiring non-empty values either by initialization or from values + explicitly provided via the weights and state keyword arguments, in which + case the old weights will be preserved, and the state will be updated. + + Args: + x: Zero or more input tensors, packaged as described in the `Layer` class + docstring. + weights: Weights or `None`; if `None`, use self's cached weights value. + state: State or `None`; if `None`, use self's cached state value. + rng: Single-use random number generator (JAX PRNG key), or `None`; + if `None`, use a default computed from an integer 0 seed. + + Returns: + Zero or more output tensors, packaged as described in the `Layer` class + docstring. + """ + weights = self.weights if weights is None else weights + rng = self.rng if rng is None else rng + if state is not None: + self.state = state # Needed if the model wasn't fully initialized. 
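+        # Read back the cached state (possibly just updated above) for this call.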
+ state = self.state + outputs, new_state = self.pure_fn(x, weights, state, rng) + self.state = new_state + return outputs - if self._sublayers_to_print is not None: - substructure = self._sublayers_to_print - else: - substructure = self.sublayers - if substructure: - substructure_strs = [str(x) for x in substructure if str(x)] - substructure_str = '\n'.join(indent_string(s) for s in substructure_strs) - return f'{name_str}[\n{substructure_str}\n]' - else: - return name_str + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. - def __call__(self, x, weights=None, state=None, rng=None): - """Makes layers callable; for use in tests or interactive settings. + A layer subclass overrides this method to define how the layer computes + outputs from inputs. If the layer depends on weights, state, or randomness + as part of the computation, the needed information can be accessed as + properties of the layer object: `self.weights`, `self.state`, and + `self.rng`. (See numerous examples in `trax.layers.core`.) - This convenience method helps library users play with, test, or otherwise - probe the behavior of layers outside of a full training environment. It - presents the layer as callable function from inputs to outputs, with the - option of manually specifying weights and non-parameter state per individual - call. For convenience, weights and non-parameter state are cached per layer - instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`, - and acquiring non-empty values either by initialization or from values - explicitly provided via the weights and state keyword arguments, in which - case the old weights will be preserved, and the state will be updated. + Args: + inputs: Zero or more input tensors, packaged as described in the `Layer` + class docstring. - Args: - x: Zero or more input tensors, packaged as described in the `Layer` class + Returns: + Zero or more output tensors, packaged as described in the `Layer` class docstring. - weights: Weights or `None`; if `None`, use self's cached weights value. - state: State or `None`; if `None`, use self's cached state value. - rng: Single-use random number generator (JAX PRNG key), or `None`; - if `None`, use a default computed from an integer 0 seed. + """ + raise NotImplementedError + + def init_weights_and_state(self, input_signature): + """Initializes weights and state, to handle input with the given signature. + + A layer subclass must override this method if the layer uses weights or + state. To initialize weights, set `self.weights` to desired (typically + random) values. To initialize state (uncommon), set `self.state` to desired + starting values. + + Args: + input_signature: A `ShapeDtype` instance (if this layer takes one input) + or a list/tuple of `ShapeDtype` instances. + """ + del input_signature + + @property + def has_backward(self): + """Returns `True` if this layer provides its own custom backward pass code. + + A layer subclass that provides custom backward pass code (for custom + gradients) must override this method to return `True`. + """ + return False + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + """Custom backward pass to propagate gradients in a custom way. + + Args: + inputs: Input tensors; can be a (possibly nested) tuple. + output: The result of running this layer on inputs. + grad: Gradient signal computed based on subsequent layers; its structure + and shape must match output. + weights: This layer's weights. 
+            state: This layer's state prior to the current forward pass.
+            new_state: This layer's state after the current forward pass.
+            rng: Single-use random number generator (JAX PRNG key).
+
+        Returns:
+            The custom gradient signal for the input. Note that we need to return
+            a gradient for each argument of forward, so it will usually be a tuple
+            of signals: the gradient for inputs and weights.
+        """
+        raise NotImplementedError
+
+    # End of public subclassing interface.
+    # Begin public callable interface.
+
+    def init(self, input_signature, rng=None, use_cache=False):
+        """Initializes weights/state of this layer and its sublayers recursively.
+
+        Initialization creates layer weights and state, for layers that use them.
+        It derives the necessary array shapes and data types from the layer's input
+        signature, which is itself just shape and data type information.
+
+        For layers without weights or state, this method safely does nothing.
+
+        This method is designed to create weights/state only once for each layer
+        instance, even if the same layer instance occurs in multiple places in the
+        network. This enables weight sharing to be implemented as layer sharing.
+
+        Args:
+            input_signature: `ShapeDtype` instance (if this layer takes one input)
+                or list/tuple of `ShapeDtype` instances.
+            rng: Single-use random number generator (JAX PRNG key), or `None`;
+                if `None`, use a default computed from an integer 0 seed.
+            use_cache: If `True`, and if this layer instance has already been
+                initialized elsewhere in the network, then return special marker
+                values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
+                Else return this layer's newly initialized weights and state.
+
+        Returns:
+            A `(weights, state)` tuple.
+        """
+        try:
+            if self._init_cached and use_cache:
+                return (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)
+
+            if rng is not None:
+                self.rng = rng
+            self.init_weights_and_state(input_signature)
+
+            if use_cache:
+                self._init_cached = True
+            else:
+                self._clear_init_cache()
+
+            return (self.weights, self.state)
+
+        except Exception:
+            # Skipping 3 lines as it's always the uninteresting internal call.
+            name, trace = self._name, _short_traceback(skip=3)
+            raise LayerError(
+                name, "init", self._caller, input_signature, trace
+            ) from None
+
+    def init_from_file(self, file_name, weights_only=False, input_signature=None):
+        """Initializes this layer and its sublayers from a pickled checkpoint.
+
+        In the common case (`weights_only=False`), the file must be a gzipped
+        pickled dictionary containing items with keys `'flat_weights'`,
+        `'flat_state'` and `'input_signature'`, which are used to initialize
+        this layer. If `input_signature` is specified, it's used instead of the
+        one in the file. If `weights_only` is `True`, the dictionary does not
+        need to have the `'flat_state'` item and the state is not restored either.
+
+        Args:
+            file_name: Name/path of the pickled weights/state file.
+            weights_only: If `True`, initialize only the layer's weights. Else
+                initialize both weights and state.
+            input_signature: Input signature to be used instead of the one from file.
+
+        Returns:
+            A `(weights, state)` tuple.
+        """
+        with tf.io.gfile.GFile(file_name, "rb") as f:
+            with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf:
+                dictionary = pickle.load(gzipf)
+        # In the current checkpoint format, we store weights in a separate
+        # non-pickled file with the same name but added ".npy".
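+        # When "flat_weights" holds an int rather than the weights themselves,
+        # that int is the gzip compresslevel of the side file; it is passed on
+        # to np_from_file below.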
+ if isinstance(dictionary["flat_weights"], int): + if file_name.endswith(".pkl.gz"): + weights_path = file_name[:-6] + "weights.npy.gz" + else: + weights_path = file_name + ".npy" + if not tf.io.gfile.exists(weights_path): # old format compatibility + weights_path = file_name + ".npy" + dictionary["flat_weights"] = np_from_file( + weights_path, compresslevel=dictionary["flat_weights"] + ) + if input_signature is None: + input_signature = dictionary["input_signature"] + if weights_only and input_signature is not None: + self.init(input_signature) + weights_and_state_sig = self.weights_and_state_signature(input_signature) + weights, state = unflatten_weights_and_state( + dictionary["flat_weights"], + dictionary["flat_state"], + weights_and_state_sig, + weights_only=weights_only, + ) + if not weights_only: + self.state = state + self.weights = weights + return (self.weights, self.state) + + def save_to_file(self, file_name, weights_only=False, input_signature=None): + """Saves this layer and its sublayers to a pickled checkpoint. + + Args: + file_name: Name/path of the pickled weights/state file. + weights_only: If `True`, save only the layer's weights. Else + save both weights and state. + input_signature: Input signature to be used. + """ + flat_weights, flat_state = flatten_weights_and_state(self.weights, self.state) + dictionary = { + "flat_weights": flat_weights, + } + if not weights_only: + dictionary["flat_state"] = flat_state + if input_signature is not None: + dictionary["input_signature"] = input_signature + + tmp_file_path = file_name + "._tmp_" + with tf.io.gfile.GFile(tmp_file_path, "wb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: + pickle.dump(dictionary, gzipf, protocol=pickle.HIGHEST_PROTOCOL) + # Moving a file is much less error-prone than pickling large files. + tf.io.gfile.rename(tmp_file_path, file_name, overwrite=True) + + def flatten_tuple(self, inputs): + flat_tuple = () + for _input in inputs: + if isinstance(_input, tuple): + flat_tuple += self.flatten_tuple(_input) + else: + flat_tuple += (_input,) + return flat_tuple + + # End of public callable methods. + # Methods and properties below are reserved for internal use. + + @property + def name(self): + """Returns the name of this layer.""" + return self._name + + @property + def n_in(self): + """Returns how many tensors this layer expects as input.""" + return self._n_in + + @property + def n_out(self): + """Returns how many tensors this layer promises as output.""" + return self._n_out + + @property + def sublayers(self): + """Returns a tuple containing this layer's sublayers; may be empty.""" + return self._sublayers + + @property + def weights(self): + """Returns this layer's weights. + + Depending on the layer, the weights can be in the form of: + + - an empty tuple + - a tensor (ndarray) + - a nested structure of tuples and tensors + + If the layer has sublayers, the weights by convention will be + a tuple of length `len(sublayers)` containing the weights of sublayers. + Note that in this case self._weights only marks which ones are shared. + """ + if not self.sublayers: + return self._weights + else: + return tuple( + layer.weights if w is None else w + for (layer, w) in zip(self.sublayers, self._weights) + ) + + @weights.setter + def weights(self, weights): + """Sets the weights of this layer and its sublayers. 
+ + Args: + weights: the weights to set; if layer has sublayers, weights should be + either a list or a tuple of the same length as `len(self.sublayers)` + and it will be used to set the weights of all sublayers. + """ + if isinstance(weights, dict) and weights == GET_WEIGHTS_FROM_CACHE: + return + if not self.sublayers: + self._weights = weights + else: + # When having sublayers, self._weights just marks which are cached, + # the actual weights are stored by sublayers. + self._weights = [] + for w in weights: + if isinstance(w, dict) and w == GET_WEIGHTS_FROM_CACHE: + self._weights.append(w) + else: + self._weights.append(None) + # Set sublayer weights. + n_layers = len(self.sublayers) + if len(weights) != n_layers: + raise ValueError( + f"Number of weight elements ({len(weights)}) does not equal the " + f"number of sublayers ({n_layers}) in: {str(self)}." + ) + for sublayer, sublayer_weights in zip(self.sublayers, weights): + sublayer.weights = sublayer_weights + + @property + def state(self): + """Returns a tuple containing this layer's state; may be empty. + + If the layer has sublayers, the state by convention will be + a tuple of length `len(sublayers)` containing sublayer states. + Note that in this case self._state only marks which ones are shared. + """ + if not self.sublayers: + return self._state + else: + return tuple( + layer.state if s is None else s + for (layer, s) in zip(self.sublayers, self._state) + ) + + @state.setter + def state(self, state): + """Sets the state of this layer and its sublayers. + + Args: + state: the state to set; if layer has sublayers, state should be + either a list or a tuple of the same length as `len(self.sublayers)` + and it will be used to set the state of all sublayers. + """ + if isinstance(state, dict) and state == GET_STATE_FROM_CACHE: + return + if not self._sublayers: + self._state = state + else: + # When having sublayers, self._state just marks which are cached, + # the actual weights are stored by sublayers. + self._state = [] + for s in state: + if isinstance(s, dict) and s == GET_STATE_FROM_CACHE: + self._state.append(s) + else: + self._state.append(None) + # Set sublayer states. + n_layers = len(self.sublayers) + if len(state) != n_layers: + raise ValueError( + f"Number of state elements ({len(state)}) does not equal the " + f"number of sublayers ({n_layers}) in: {str(self)}." + ) + for sublayer, sublayer_state in zip(self.sublayers, state): + sublayer.state = sublayer_state + + def weights_and_state_signature(self, input_signature, unsafe=False): + """Return a pair containing the signatures of weights and state.""" + rng, state, weights = self.rng, self.state, self.weights + abstract_init = fastmath.abstract_eval(self.init) + sig = abstract_init(input_signature) + self.rng = rng + if not unsafe: + self.state, self.weights = state, weights + return sig + + @property + def rng(self): + """Returns this layer's current single-use random number generator. + + Code that wants to base random samples on this generator must explicitly + split off new generators from it. (See, for example, the `rng` setter code + below.) + """ + if self._rng is None: + # One-time initialization from backend-neutral seed int. 
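+            # get_prng turns the plain int seed into the active backend's PRNG key.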
+ self._rng = fastmath.random.get_prng(self._rng_seed_int) + return self._rng + + @rng.setter + def rng(self, rng): + """Sets the rng (JAX PRNG key) for this layer and sublayers, recursively.""" + self._rng = rng + sublayers = self.sublayers + if sublayers: + rngs = fastmath.random.split(rng, len(sublayers)) + for sublayer, rng in zip(sublayers, rngs): + sublayer.rng = rng + + def _clear_init_cache(self): + self._init_cached = False + for sublayer in self.sublayers: + sublayer._clear_init_cache() # pylint: disable=protected-access + + def pure_fn(self, x, weights, state, rng, use_cache=False): + """Applies this layer as a pure function with no optional args. + + This method exposes the layer's computation as a pure function. This is + especially useful for JIT compilation. Do not override, use `forward` + instead. + + Args: + x: Zero or more input tensors, packaged as described in the `Layer` class + docstring. + weights: A tuple or list of trainable weights, with one element for this + layer if this layer has no sublayers, or one for each sublayer if + this layer has sublayers. If a layer (or sublayer) has no trainable + weights, the corresponding weights element is an empty tuple. + state: Layer-specific non-parameter state that can update between batches. + rng: Single-use random number generator (JAX PRNG key). + use_cache: if `True`, cache weights and state in the layer object; used + to implement layer sharing in combinators. + + Returns: + A tuple of `(tensors, state)`. The tensors match the number (`n_out`) + promised by this layer, and are packaged as described in the `Layer` + class docstring. + """ + try: + old_weights, old_state, old_rng = self.weights, self.state, self.rng + self._rng = rng + # The isinstance check is only needed when == is overloaded, as in TF. + if ( + isinstance(weights, dict) + and isinstance(state, dict) + and weights == GET_WEIGHTS_FROM_CACHE + and state == GET_STATE_FROM_CACHE + ): + was_cached = True + weights = self.weights + state = self.state + else: + # In this case, we're called for the first time: cache weights. + was_cached = False + self.weights, self.state = weights, state + + # If weights are sharded across multiple devices, unshard before forward. + sharded_weights, weights_were_unsharded = weights, False + if N_WEIGHTS_SHARDS > 1 and not self.sublayers: + self.weights, weights_were_unsharded = unshard_in_pmap( + weights, N_WEIGHTS_SHARDS + ) + + if not self.has_backward: + outputs = self.forward(x) + s = self.state + else: + outputs, s = self._do_custom_gradients(x) + self.state = s + self._rng = old_rng + if weights_were_unsharded: # only store a shard of weights if sharded + self.weights = sharded_weights + + if not use_cache: + self.weights, self.state = old_weights, old_state + if was_cached: # If the layer was shared, return a state marking this. + s = GET_STATE_FROM_CACHE + return outputs, s + + except Exception: + # Skipping 3 lines as it's always the uninteresting internal call. + name, trace = self._name, _short_traceback(skip=3) + raise LayerError( + name, "pure_fn", self._caller, signature(x), trace + ) from None + + def output_signature(self, input_signature): + """Returns output signature this layer would give for `input_signature`.""" + return self._forward_abstract(input_signature)[0] # output only, not state + + def _forward_abstract(self, input_signature): + """Computes shapes and dtypes this layer would produce in a forward pass. 
+
+        Args:
+            input_signature: `ShapeDtype` instance (if this layer takes one input)
+                or list/tuple of `ShapeDtype` instances.
+
+        Returns:
+            Tuple of (output, state).
+
+            The output part of the tuple is a `ShapeDtype` instance representing the
+            shape and type of the output (if this layer has one output) or a tuple
+            of `ShapeDtype` instances (if this layer has more than one output).
+        """
+        try:
+            # Note: By using rng_signature in place of an rng, we avoid computing and
+            # permanently storing in global memory a large number of dropout masks.
+            # TODO(jonni): Check if using an rng still carries this cost.
+            dummy_rng = fastmath.random.get_prng(0)
+            rng_signature = ShapeDtype(dummy_rng.shape, dummy_rng.dtype)
+            weights_signature = nested_map(signature, self.weights)
+            state_signature = nested_map(signature, self.state)
+            forward_infer_shapes = fastmath.abstract_eval(self.pure_fn)
+            return forward_infer_shapes(
+                input_signature, weights_signature, state_signature, rng_signature
+            )
+        except Exception:
+            # TODO(lukaszkaiser): the choice of 7 is a heuristic, can we automate it?
+            # Skipping 7 lines which are all JAX abstract'ifying wrappers.
+            name, trace = self._name, _short_traceback(skip=7)
+            raise LayerError(
+                name, "_forward_abstract", self._caller, input_signature, trace
+            ) from None
+
+    # pylint: disable=protected-access
+    def _do_custom_gradients(self, x):
+        """Calls this layer for a forward pass, but with custom gradients."""
+
+        def _f(state, rng, y, weights):
+            old_weights, old_state, old_rng = self.weights, self.state, self._rng
+            self.weights, self.state, self._rng = weights, state, rng
+            res = self.forward(y)
+            s = self.state
+            self.weights, self.state, self._rng = old_weights, old_state, old_rng
+            return res, s
+
+        def _f_fwd(state, rng, y, weights):
+            old_weights, old_state, old_rng = self.weights, self.state, self._rng
+            self.weights, self.state, self._rng = weights, state, rng
+            res = self.forward(y)
+            s = self.state
+            self.weights, self.state, self._rng = old_weights, old_state, old_rng
+            return (res, s), (state, rng, y, res, weights, s)
+
+        def _f_bwd(residual, grad):
+            """Custom gradient function."""
+            state, rng, y, output, weights, new_state = residual
+            grad = grad[0]  # Ignore dummy gradient wrt state.
+            out = self.backward(y, output, grad, weights, state, new_state, rng)
+            return (None, None, *out)
+
+        do_forward = fastmath.custom_vjp(_f, _f_fwd, _f_bwd, nondiff_argnums=(0, 1))
+
+        output, state = do_forward(self.state, self._rng, x, self.weights)
+        return output, state
+
+    def _settable_attrs(self):
+        """We only allow setting these attributes in Trax layers, to prevent typos."""
+        return ("weights", "state", "rng")
+
+    def __setattr__(self, attr, value):
+        """Sets class attributes and protects from typos.
+
+        In Trax layers, we only allow setting the following public attributes::
+
+            - weights
+            - state
+            - rng
+
+        This function prevents setting any other public attribute, to catch
+        typos; for example, the following is rejected here but would silently
+        succeed without this check::
+
+            [typo] layer.weighs = some_tensor
+
+        If you need to set other public attributes in a derived class (which we
+        do not recommend as in almost all cases it suffices to use a private
+        attribute), override self._settable_attrs to include the attribute name.
+
+        Args:
+            attr: Name of the attribute to be set.
+            value: Value to be assigned to the attribute.
+        """
+        if attr[0] != "_" and attr not in self._settable_attrs():
+            raise ValueError(
+                f"Trax layers only allow setting {self._settable_attrs()} as public "
+                f"attributes, not {attr}."
+            )
+        else:
+            super().__setattr__(attr, value)
 
-    Returns:
-      Zero or more output tensors, packaged as described in the `Layer` class
-      docstring.
-    """
-    weights = self.weights if weights is None else weights
-    rng = self.rng if rng is None else rng
-    if state is not None:
-      self.state = state  # Needed if the model wasn't fully initialized.
-    state = self.state
-    outputs, new_state = self.pure_fn(x, weights, state, rng)
-    self.state = new_state
-    return outputs
-
-  def forward(self, inputs):
-    """Computes this layer's output as part of a forward pass through the model.
-
-    A layer subclass overrides this method to define how the layer computes
-    outputs from inputs. If the layer depends on weights, state, or randomness
-    as part of the computation, the needed information can be accessed as
-    properties of the layer object: `self.weights`, `self.state`, and
-    `self.rng`. (See numerous examples in `trax.layers.core`.)
-
-    Args:
-      inputs: Zero or more input tensors, packaged as described in the `Layer`
-        class docstring.
+
+class PureLayer(Layer):
+    """Pure function from inputs to outputs, packaged as neural network layer.
 
-    Returns:
-      Zero or more output tensors, packaged as described in the `Layer` class
-      docstring.
-    """
-    raise NotImplementedError
+    The `PureLayer` class represents the simplest kinds of layers: layers with
+    no trainable weights and no randomness, hence pure functions from inputs to
+    outputs.
+    """
 
-  def init_weights_and_state(self, input_signature):
-    """Initializes weights and state, to handle input with the given signature.
+    def __init__(self, forward_fn, n_in=1, n_out=1, name="PureLayer"):
+        """Creates an unconnected `PureLayer` instance.
 
-    A layer subclass must override this method if the layer uses weights or
-    state. To initialize weights, set `self.weights` to desired (typically
-    random) values. To initialize state (uncommon), set `self.state` to desired
-    starting values.
+        Args:
+            forward_fn: Pure function from input tensors to output tensors, where
+                inputs and outputs are packaged as specified for `forward`.
+            n_in: Number of inputs expected by this layer.
+            n_out: Number of outputs promised by this layer.
+            name: Class-like name for this layer; for use only in debugging.
+        """
+        super().__init__(n_in, n_out, name)
+        self._forward_fn = forward_fn
 
-    Args:
-      input_signature: A `ShapeDtype` instance (if this layer takes one input)
-          or a list/tuple of `ShapeDtype` instances.
-    """
-    del input_signature
+    def forward(self, inputs):
+        """Overrides `Layer.forward`.
 
-  @property
-  def has_backward(self):
-    """Returns `True` if this layer provides its own custom backward pass code.
+        Args:
+            inputs: Zero or more input tensors, packaged as described in the `Layer`
+                class docstring.
 
-    A layer subclass that provides custom backward pass code (for custom
-    gradients) must override this method to return `True`.
-    """
-    return False
+        Returns:
+            Zero or more output tensors, packaged as described in the `Layer` class
+            docstring.
+        """
+        if not isinstance(inputs, (tuple, list)):
+            inputs = (inputs,)
 
-  def backward(self, inputs, output, grad, weights, state, new_state, rng):
-    """Custom backward pass to propagate gradients in a custom way.
+ # The input should be a flat single tuple without nested tuples + inputs = self.flatten_tuple(inputs) - Args: - inputs: Input tensors; can be a (possibly nested) tuple. - output: The result of running this layer on inputs. - grad: Gradient signal computed based on subsequent layers; its structure - and shape must match output. - weights: This layer's weights. - state: This layer's state prior to the current forward pass. - new_state: This layer's state after the current forward pass. - rng: Single-use random number generator (JAX PRNG key). + _validate_forward_input(inputs, self.n_in) - Returns: - The custom gradient signal for the input. Note that we need to return - a gradient for each argument of forward, so it will usually be a tuple - of signals: the gradient for inputs and weights. - """ - raise NotImplementedError + raw_output = self._forward_fn(inputs) + output = () if _is_empty(raw_output) else raw_output + return output - # End of public subclassing interface. - # Begin public callable interface. - def init(self, input_signature, rng=None, use_cache=False): - """Initializes weights/state of this layer and its sublayers recursively. +def Fn(name, f, n_out=1): # pylint: disable=invalid-name + """Returns a layer with no weights that applies the function `f`. - Initialization creates layer weights and state, for layers that use them. - It derives the necessary array shapes and data types from the layer's input - signature, which is itself just shape and data type information. + `f` can take and return any number of arguments, and takes only positional + arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`). + The following, for example, would create a layer that takes two inputs and + returns two outputs -- element-wise sums and maxima: - For layers without weights or state, this method safely does nothing. + `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)` - This method is designed to create weights/state only once for each layer - instance, even if the same layer instance occurs in multiple places in the - network. This enables weight sharing to be implemented as layer sharing. + The layer's number of inputs (`n_in`) is automatically set to number of + positional arguments in `f`, but you must explicitly set the number of + outputs (`n_out`) whenever it's not the default value 1. Args: - input_signature: `ShapeDtype` instance (if this layer takes one input) - or list/tuple of `ShapeDtype` instances. - rng: Single-use random number generator (JAX PRNG key), or `None`; - if `None`, use a default computed from an integer 0 seed. - use_cache: If `True`, and if this layer instance has already been - initialized elsewhere in the network, then return special marker - values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`. - Else return this layer's newly initialized weights and state. + name: Class-like name for the resulting layer; for use in debugging. + f: Pure function from input tensors to output tensors, where each input + tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`. + Output tensors must be packaged as specified in the `Layer` class + docstring. + n_out: Number of outputs promised by the layer; default value 1. Returns: - A `(weights, state)` tuple. + Layer executing the function `f`. 
""" - try: - if self._init_cached and use_cache: - return (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE) - - if rng is not None: - self.rng = rng - self.init_weights_and_state(input_signature) + argspec = inspect.getfullargspec(f) + if argspec.defaults is not None: + raise ValueError("Function has default arguments (not allowed).") + if argspec.varkw is not None: + raise ValueError("Function has keyword arguments (not allowed).") + if argspec.varargs is not None: + raise ValueError("Function has variable args (not allowed).") - if use_cache: - self._init_cached = True - else: - self._clear_init_cache() + def _forward(xs): # pylint: disable=invalid-name + if not isinstance(xs, (tuple, list)): + xs = (xs,) + return f(*xs) - return (self.weights, self.state) + n_in = len(argspec.args) + name = name or "Fn" + return PureLayer(_forward, n_in=n_in, n_out=n_out, name=name) - except Exception: - # Skipping 3 lines as it's always the uninteresting internal call. - name, trace = self._name, _short_traceback(skip=3) - raise LayerError(name, 'init', self._caller, - input_signature, trace) from None - def init_from_file(self, file_name, weights_only=False, input_signature=None): - """Initializes this layer and its sublayers from a pickled checkpoint. +class LayerError(Exception): + """Exception raised in the layer stack.""" + + def __init__( + self, layer_name, function_name, caller, input_signature, traceback_string + ): + self._layer_name = layer_name + self._function_name = function_name + self._caller = caller # Python inspect object with init caller info. + self._traceback = traceback_string + self._input_signature = input_signature + super().__init__(self.message) + + @property + def message(self): + """Assembles current layer context into an error message.""" + prefix = "Exception passing through layer " + prefix += "%s (in %s):\n" % (self._layer_name, self._function_name) + short_path = "[...]/" + "/".join(self._caller["filename"].split("/")[-3:]) + caller = " layer created in file %s, line %d\n" % ( + short_path, + self._caller["lineno"], + ) + shapes_str = " layer input shapes: %s\n\n" % str(self._input_signature) + return prefix + caller + shapes_str + self._traceback - In the common case (`weights_only=False`), the file must be a gziped pickled - dictionary containing items with keys `'flat_weights', `'flat_state'` and - `'input_signature'`, which are used to initialize this layer. - If `input_signature` is specified, it's used instead of the one in the file. - If `weights_only` is `True`, the dictionary does not need to have the - `'flat_state'` item and the state it not restored either. - Args: - file_name: Name/path of the pickled weights/state file. - weights_only: If `True`, initialize only the layer's weights. Else - initialize both weights and state. - input_signature: Input signature to be used instead of the one from file. - - Returns: - A `(weights, state)` tuple. - """ - with tf.io.gfile.GFile(file_name, 'rb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: - dictionary = pickle.load(gzipf) - # In the current checkpoint format, we store weights in a separate - # non-pickled file with the same name but added ".npy". 
- if isinstance(dictionary['flat_weights'], int): - if file_name.endswith('.pkl.gz'): - weights_path = file_name[:-6] + 'weights.npy.gz' - else: - weights_path = file_name + '.npy' - if not tf.io.gfile.exists(weights_path): # old format compatibility - weights_path = file_name + '.npy' - dictionary['flat_weights'] = np_from_file( - weights_path, compresslevel=dictionary['flat_weights']) - if input_signature is None: - input_signature = dictionary['input_signature'] - if weights_only and input_signature is not None: - self.init(input_signature) - weights_and_state_sig = self.weights_and_state_signature(input_signature) - weights, state = unflatten_weights_and_state( - dictionary['flat_weights'], dictionary['flat_state'], - weights_and_state_sig, weights_only=weights_only) - if not weights_only: - self.state = state - self.weights = weights - return (self.weights, self.state) +def flatten_weights_and_state(weights, state): + """Flatten weights and state into lists, excluding empty and cached ones.""" - def save_to_file(self, file_name, weights_only=False, input_signature=None): - """Saves this layer and its sublayers to a pickled checkpoint. + def _is_empty_weight(x): + return x is EMPTY_WEIGHTS or ( + isinstance(x, dict) and x == GET_WEIGHTS_FROM_CACHE + ) - Args: - file_name: Name/path of the pickled weights/state file. - weights_only: If `True`, save only the layer's weights. Else - save both weights and state. - input_signature: Input signature to be used. - """ - flat_weights, flat_state = flatten_weights_and_state( - self.weights, self.state) - dictionary = { - 'flat_weights': flat_weights, - } - if not weights_only: - dictionary['flat_state'] = flat_state - if input_signature is not None: - dictionary['input_signature'] = input_signature - - tmp_file_path = file_name + '._tmp_' - with tf.io.gfile.GFile(tmp_file_path, 'wb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=2) as gzipf: - pickle.dump(dictionary, gzipf, protocol=pickle.HIGHEST_PROTOCOL) - # Moving a file is much less error-prone than pickling large files. - tf.io.gfile.rename(tmp_file_path, file_name, overwrite=True) + flat_weights = [ + w for w in fastmath.tree_flatten(weights) if not _is_empty_weight(w) + ] - # End of public callable methods. - # Methods and properties below are reserved for internal use. 
+ def _is_empty_state(x): + return x is EMPTY_STATE or (isinstance(x, dict) and x == GET_STATE_FROM_CACHE) - @property - def name(self): - """Returns the name of this layer.""" - return self._name + flat_state = [s for s in fastmath.tree_flatten(state) if not _is_empty_state(s)] + return flat_weights, flat_state - @property - def n_in(self): - """Returns how many tensors this layer expects as input.""" - return self._n_in - @property - def n_out(self): - """Returns how many tensors this layer promises as output.""" - return self._n_out +def unflatten_weights_and_state( + flat_weights, flat_state, weights_and_state_signature, weights_only=False +): + """Unflatten weights and state given their signatures.""" + weights_tree, state_tree = weights_and_state_signature + weights_to_copy = [EMPTY_WEIGHTS, GET_WEIGHTS_FROM_CACHE] + weights, _ = fastmath.tree_unflatten( + flat_weights, weights_tree, copy_from_tree=weights_to_copy + ) + state = None + if not weights_only: + states_to_copy = [EMPTY_STATE, GET_STATE_FROM_CACHE] + state, _ = fastmath.tree_unflatten( + flat_state, state_tree, copy_from_tree=states_to_copy + ) + return weights, state - @property - def sublayers(self): - """Returns a tuple containing this layer's sublayers; may be empty.""" - return self._sublayers - @property - def weights(self): - """Returns this layer's weights. +def np_to_file(list_of_nparrays, file_path, compresslevel): + """Save numpy arrays to file_path with gzipping and failure protection.""" + # Pickle to tmp file and overwrite to prevent writing partial files. + tmp_file_path = file_path + "._tmp_" + with tf.io.gfile.GFile(tmp_file_path, "wb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: + for x in list_of_nparrays: + np.save(gzipf, x, allow_pickle=False) + # Moving a file is much less error-prone than pickling large files. + tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) - Depending on the layer, the weights can be in the form of: - - an empty tuple - - a tensor (ndarray) - - a nested structure of tuples and tensors +def np_from_file(file_path, compresslevel): + """Load numpy arrays from file_path with gzipping.""" + if not tf.io.gfile.exists(file_path): + raise FileNotFoundError(file_path) + res = [] + with tf.io.gfile.GFile(file_path, "rb") as f: + with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: + while True: + try: + res.append(np.load(gzipf, allow_pickle=False)) + except Exception: # pylint: disable=broad-except + break + return res - If the layer has sublayers, the weights by convention will be - a tuple of length `len(sublayers)` containing the weights of sublayers. - Note that in this case self._weights only marks which ones are shared. - """ - if not self.sublayers: - return self._weights - else: - return tuple(layer.weights if w is None else w - for (layer, w) in zip(self.sublayers, self._weights)) - @weights.setter - def weights(self, weights): - """Sets the weights of this layer and its sublayers. +def to_list(outputs): + """Converts layer outputs to a nested list, for easier equality testing. Args: - weights: the weights to set; if layer has sublayers, weights should be - either a list or a tuple of the same length as `len(self.sublayers)` - and it will be used to set the weights of all sublayers. 
- """ - if isinstance(weights, dict) and weights == GET_WEIGHTS_FROM_CACHE: - return - if not self.sublayers: - self._weights = weights - else: - # When having sublayers, self._weights just marks which are cached, - # the actual weights are stored by sublayers. - self._weights = [] - for w in weights: - if isinstance(w, dict) and w == GET_WEIGHTS_FROM_CACHE: - self._weights.append(w) - else: - self._weights.append(None) - # Set sublayer weights. - n_layers = len(self.sublayers) - if len(weights) != n_layers: - raise ValueError( - f'Number of weight elements ({len(weights)}) does not equal the ' - f'number of sublayers ({n_layers}) in: {str(self)}.') - for sublayer, sublayer_weights in zip(self.sublayers, weights): - sublayer.weights = sublayer_weights - - @property - def state(self): - """Returns a tuple containing this layer's state; may be empty. - - If the layer has sublayers, the state by convention will be - a tuple of length `len(sublayers)` containing sublayer states. - Note that in this case self._state only marks which ones are shared. + outputs: A tensor or tuple/list of tensors coming from the forward + application of a layer. Each tensor is NumPy ndarray-like, which + complicates simple equality testing (e.g., via `assertEquals`): + such tensors require equality testing to use either `all` (all + elements match) or `any` (at least one element matches), which is not + directly supported in `absltest`. + + Returns: + A nested list structure containing all the output values, but now directly + testable using `assertEquals`. """ - if not self.sublayers: - return self._state + if isinstance(outputs, (list, tuple)): + return [y.tolist() for y in outputs] else: - return tuple(layer.state if s is None else s - for (layer, s) in zip(self.sublayers, self._state)) + return outputs.tolist() - @state.setter - def state(self, state): - """Sets the state of this layer and its sublayers. - Args: - state: the state to set; if layer has sublayers, state should be - either a list or a tuple of the same length as `len(self.sublayers)` - and it will be used to set the state of all sublayers. - """ - if isinstance(state, dict) and state == GET_STATE_FROM_CACHE: - return - if not self._sublayers: - self._state = state - else: - # When having sublayers, self._state just marks which are cached, - # the actual weights are stored by sublayers. - self._state = [] - for s in state: - if isinstance(s, dict) and s == GET_STATE_FROM_CACHE: - self._state.append(s) - else: - self._state.append(None) - # Set sublayer states. - n_layers = len(self.sublayers) - if len(state) != n_layers: +def _validate_forward_input(x, n_in): + if n_in != 1: + if not isinstance(x, (tuple, list)): + raise TypeError( + f"Expected input to be a tuple or list; instead got {type(x)}." + ) + + if len(x) != n_in: raise ValueError( - f'Number of state elements ({len(state)}) does not equal the ' - f'number of sublayers ({n_layers}) in: {str(self)}.') - for sublayer, sublayer_state in zip(self.sublayers, state): - sublayer.state = sublayer_state - - def weights_and_state_signature(self, input_signature, unsafe=False): - """Return a pair containing the signatures of weights and state.""" - rng, state, weights = self.rng, self.state, self.weights - abstract_init = fastmath.abstract_eval(self.init) - sig = abstract_init(input_signature) - self.rng = rng - if not unsafe: - self.state, self.weights = state, weights - return sig - - @property - def rng(self): - """Returns this layer's current single-use random number generator. 
- - Code that wants to base random samples on this generator must explicitly - split off new generators from it. (See, for example, the `rng` setter code - below.) - """ - if self._rng is None: - # One-time initialization from backend-neutral seed int. - self._rng = fastmath.random.get_prng(self._rng_seed_int) - return self._rng - - @rng.setter - def rng(self, rng): - """Sets the rng (JAX PRNG key) for this layer and sublayers, recursively.""" - self._rng = rng - sublayers = self.sublayers - if sublayers: - rngs = fastmath.random.split(rng, len(sublayers)) - for sublayer, rng in zip(sublayers, rngs): - sublayer.rng = rng - - def _clear_init_cache(self): - self._init_cached = False - for sublayer in self.sublayers: - sublayer._clear_init_cache() # pylint: disable=protected-access - - def pure_fn(self, x, weights, state, rng, use_cache=False): - """Applies this layer as a pure function with no optional args. - - This method exposes the layer's computation as a pure function. This is - especially useful for JIT compilation. Do not override, use `forward` - instead. + f"Input tuple length ({len(x)}) does not equal required " + f"number of inputs ({n_in})." + ) - Args: - x: Zero or more input tensors, packaged as described in the `Layer` class - docstring. - weights: A tuple or list of trainable weights, with one element for this - layer if this layer has no sublayers, or one for each sublayer if - this layer has sublayers. If a layer (or sublayer) has no trainable - weights, the corresponding weights element is an empty tuple. - state: Layer-specific non-parameter state that can update between batches. - rng: Single-use random number generator (JAX PRNG key). - use_cache: if `True`, cache weights and state in the layer object; used - to implement layer sharing in combinators. - Returns: - A tuple of `(tensors, state)`. The tensors match the number (`n_out`) - promised by this layer, and are packaged as described in the `Layer` - class docstring. - """ - try: - old_weights, old_state, old_rng = self.weights, self.state, self.rng - self._rng = rng - # The isinstance check is only needed when == is overloaded, as in TF. - if (isinstance(weights, dict) and isinstance(state, dict) and - weights == GET_WEIGHTS_FROM_CACHE and state == GET_STATE_FROM_CACHE): - was_cached = True - weights = self.weights - state = self.state - else: - # In this case, we're called for the first time: cache weights. - was_cached = False - self.weights, self.state = weights, state - - # If weights are sharded across multiple devices, unshard before forward. - sharded_weights, weights_were_unsharded = weights, False - if N_WEIGHTS_SHARDS > 1 and not self.sublayers: - self.weights, weights_were_unsharded = unshard_in_pmap( - weights, N_WEIGHTS_SHARDS) - - if not self.has_backward: - outputs = self.forward(x) - s = self.state - else: - outputs, s = self._do_custom_gradients(x) - self.state = s - self._rng = old_rng - if weights_were_unsharded: # only store a shard of weights if sharded - self.weights = sharded_weights - - if not use_cache: - self.weights, self.state = old_weights, old_state - if was_cached: # If the layer was shared, return a state marking this. - s = GET_STATE_FROM_CACHE - return outputs, s - - except Exception: - # Skipping 3 lines as it's always the uninteresting internal call. 
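Editor's note: to_list and _validate_forward_input above are small test/plumbing helpers: the first turns ndarray outputs into plain Python lists so assertEqual works, the second enforces the n_in contract before forward runs. An illustrative sketch (values made up):

    import numpy as np
    from trax.layers import base

    outputs = (np.array([1, 2]), np.array([[3], [4]]))
    assert base.to_list(outputs) == [[1, 2], [[3], [4]]]

    base._validate_forward_input((np.zeros(3), np.zeros(3)), n_in=2)  # passes
    try:
        base._validate_forward_input(np.zeros(3), n_in=2)  # not a tuple/list
    except TypeError:
        pass  # raised as expected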
- name, trace = self._name, _short_traceback(skip=3) - raise LayerError(name, 'pure_fn', - self._caller, signature(x), trace) from None - - def output_signature(self, input_signature): - """Returns output signature this layer would give for `input_signature`.""" - return self._forward_abstract(input_signature)[0] # output only, not state - - def _forward_abstract(self, input_signature): - """Computes shapes and dtypes this layer would produce in a forward pass. +def _is_empty(container): + if container is None: + raise ValueError('Argument "container" is None.') + return ( + isinstance(container, (list, tuple)) and len(container) == 0 + ) # pylint: disable=g-explicit-length-test - Args: - input_signature: `ShapeDtype` instance (if this layer takes one input) - or list/tuple of `ShapeDtype` instances. - Returns: - Tuple of (output, state). +def _find_frame(frame): + """Find the frame with the caller on the stack.""" - The output part of the tuple is a `ShapeDtype` instance representing the - shape and type of the output (if this layer has one output) or a tuple - of `ShapeDtype` instances (if this layer has more than one output). - """ - try: - # Note: By using rng_signature in place of an rng, we avoid computing and - # permanently storing in global memory a large number of dropout masks. - # TODO(jonni): Check if using an rng still carries this cost. - dummy_rng = fastmath.random.get_prng(0) - rng_signature = ShapeDtype(dummy_rng.shape, dummy_rng.dtype) - weights_signature = nested_map(signature, self.weights) - state_signature = nested_map(signature, self.state) - forward_infer_shapes = fastmath.abstract_eval(self.pure_fn) - return forward_infer_shapes( - input_signature, weights_signature, state_signature, rng_signature) - except Exception: - # TODO(lukaszkaiser): the choice of 7 is a heuristic, can we automate it? - # Skipping 7 lines which are all JAX abstract'ifying wrappers. - name, trace = self._name, _short_traceback(skip=7) - raise LayerError(name, '_forward_abstract', self._caller, input_signature, - trace) from None - - # pylint: disable=protected-access - def _do_custom_gradients(self, x): - """Calls this layer for a forward pass, but with custom gradients.""" - - def _f(state, rng, y, weights): - old_weights, old_state, old_rng = self.weights, self.state, self._rng - self.weights, self.state, self._rng = weights, state, rng - res = self.forward(y) - s = self.state - self.weights, self.state, self._rng = old_weights, old_state, old_rng - return res, s - - def _f_fwd(state, rng, y, weights): - old_weights, old_state, old_rng = self.weights, self.state, self._rng - self.weights, self.state, self._rng = weights, state, rng - res = self.forward(y) - s = self.state - self.weights, self.state, self._rng = old_weights, old_state, old_rng - return (res, s), (state, rng, y, res, weights, s) - - def _f_bwd(residual, grad): - """Custom gradient function.""" - state, rng, y, output, weights, new_state = residual - grad = grad[0] # Ignore dummy gradient wrt state. - out = self.backward(y, output, grad, weights, state, new_state, rng) - return (None, None, *out) - - do_forward = fastmath.custom_vjp(_f, _f_fwd, _f_bwd, nondiff_argnums=(0, 1)) - - output, state = do_forward(self.state, self._rng, x, self.weights) - return output, state - - def _settable_attrs(self): - """We only allow to set these attributes in Trax layers to prevent typos.""" - return ('weights', 'state', 'rng') - - def __setattr__(self, attr, value): - """Sets class attributes and protects from typos. 
- - In Trax layers, we only allow to set the following public attributes:: - - - weights - - state - - rng - - This function prevents from setting other public attributes to avoid typos, - for example, this is not possible and would be without this function:: - - [typo] layer.weighs = some_tensor - - If you need to set other public attributes in a derived class (which we - do not recommend as in almost all cases it suffices to use a private - attribute), override self._settable_attrs to include the attribute name. + def _dirname_is_trax_layers_or_gin(frame): + """Skip frames coming from trax/layers or .../gin.""" + try: + dirname1 = frame.f_code.co_filename.split("/")[-3] + dirname2 = frame.f_code.co_filename.split("/")[-2] + return (dirname1 == "trax" and dirname2 == "layers") or dirname2 == "gin" + except IndexError: + return False - Args: - attr: Name of the attribute to be set. - value: Value to be assigned to the attribute. - """ - if attr[0] != '_' and attr not in self._settable_attrs(): - raise ValueError( - f'Trax layers only allow to set {self._settable_attrs()} as public ' - f'attribues, not {attr}.') - else: - super().__setattr__(attr, value) + while _dirname_is_trax_layers_or_gin(frame): + frame = frame.f_back + return frame -class PureLayer(Layer): - """Pure function from inputs to outputs, packaged as neural network layer. +def _shorten_file_path(line): + """Shorten file path in error lines for more readable tracebacks.""" + start = line.lower().find("file") + if start < 0: + return line + first_quote = line.find('"', start) + if first_quote < 0: + return line + second_quote = line.find('"', first_quote + 1) + if second_quote < 0: + return line + path = line[first_quote + 1 : second_quote] + new_path = "/".join(path.split("/")[-3:]) + return line[:first_quote] + "[...]/" + new_path + line[second_quote + 1 :] - The `PureLayer` class represents the simplest kinds of layers: layers with - no trainable weights and no randomness, hence pure functions from inputs to - outputs. - """ - def __init__(self, forward_fn, n_in=1, n_out=1, name='PureLayer'): - """Creates an unconnected `PureLayer` instance. +def _short_traceback(skip=3): + """Cleaned-up form of traceback.""" + counter, res = 0, [] + # Skipping 3 lines by default: the top (useless) and self-call. + # In python 3, we need to set chain to False (it doesn't exist in python 2). + lines = traceback.format_exc(chain=False).splitlines()[ + skip: + ] # pylint: disable=unexpected-keyword-arg + for l in lines: + if l.startswith("trax.layers.base.LayerError"): + l = l[len("trax.layers.base.") :] # Remove the trax.layers.base prefix. + res.append(_shorten_file_path(l)) + if counter % 2 == 1: + res.append("") + counter += 1 + # If we see a LayerError, the traceback has already been processed. + if l.startswith("LayerError"): + # Skip 4 back except last as these are internal base-layer calls. + res = res[:-4] + [res[-1]] + res += lines[counter:] + break + return "\n".join(res) - Args: - forward_fn: Pure function from input tensors to output tensors, where - inputs and outputs are packaged as specified for `forward`. - n_in: Number of inputs expected by this layer. - n_out: Number of outputs promised by this layer. - name: Class-like name for this layer; for use only in debugging. - """ - super().__init__(n_in, n_out, name) - self._forward_fn = forward_fn - def forward(self, inputs): - """Overrides `Layer.forward`. +def _random_values(input_signature, rng): + """Creates random floats or ints of the given shape. 
Args: - inputs: Zero or more input tensors, packaged as described in the `Layer` - class docstring. + input_signature: A `ShapeDtype` instance (if `layer_obj` takes one input) + or a list/tuple of ShapeDtype instances. + rng: Single-use random number generator (JAX PRNG key). Returns: - Zero or more output tensors, packaged as described in the `Layer` class - docstring. + Random values with the shape and type specified. """ - _validate_forward_input(inputs, self.n_in) - raw_output = self._forward_fn(inputs) - output = () if _is_empty(raw_output) else raw_output - return output - - -def Fn(name, f, n_out=1): # pylint: disable=invalid-name - """Returns a layer with no weights that applies the function `f`. - - `f` can take and return any number of arguments, and takes only positional - arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`). - The following, for example, would create a layer that takes two inputs and - returns two outputs -- element-wise sums and maxima: - - `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)` - - The layer's number of inputs (`n_in`) is automatically set to number of - positional arguments in `f`, but you must explicitly set the number of - outputs (`n_out`) whenever it's not the default value 1. - - Args: - name: Class-like name for the resulting layer; for use in debugging. - f: Pure function from input tensors to output tensors, where each input - tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`. - Output tensors must be packaged as specified in the `Layer` class - docstring. - n_out: Number of outputs promised by the layer; default value 1. - - Returns: - Layer executing the function `f`. - """ - argspec = inspect.getfullargspec(f) - if argspec.defaults is not None: - raise ValueError('Function has default arguments (not allowed).') - if argspec.varkw is not None: - raise ValueError('Function has keyword arguments (not allowed).') - if argspec.varargs is not None: - raise ValueError('Function has variable args (not allowed).') - - def _forward(xs): # pylint: disable=invalid-name - if not isinstance(xs, (tuple, list)): - xs = (xs,) - return f(*xs) - - n_in = len(argspec.args) - name = name or 'Fn' - return PureLayer(_forward, n_in=n_in, n_out=n_out, name=name) - - -class LayerError(Exception): - """Exception raised in the layer stack.""" - - def __init__(self, layer_name, function_name, caller, - input_signature, traceback_string): - self._layer_name = layer_name - self._function_name = function_name - self._caller = caller # Python inspect object with init caller info. 
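Editor's note: _shorten_file_path (shown above) is what turns absolute paths in tracebacks into the compact '[...]/a/b/c.py' form; note that it also drops the surrounding quotes. A quick illustration (the path is made up):

    from trax.layers import base

    line = '  File "/usr/lib/python3.10/site-packages/trax/layers/core.py", line 42'
    print(base._shorten_file_path(line))
    # Prints:   File [...]/trax/layers/core.py, line 42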
- self._traceback = traceback_string - self._input_signature = input_signature - super().__init__(self.message) - - @property - def message(self): - """Assembles current layer context into an error message.""" - prefix = 'Exception passing through layer ' - prefix += '%s (in %s):\n' % (self._layer_name, self._function_name) - short_path = '[...]/' + '/'.join( - self._caller['filename'].split('/')[-3:]) - caller = ' layer created in file %s, line %d\n' % (short_path, - self._caller['lineno']) - shapes_str = ' layer input shapes: %s\n\n' % str(self._input_signature) - return prefix + caller + shapes_str + self._traceback - - -def flatten_weights_and_state(weights, state): - """Flatten weights and state into lists, excluding empty and cached ones.""" - def _is_empty_weight(x): - return (x is EMPTY_WEIGHTS or - (isinstance(x, dict) and x == GET_WEIGHTS_FROM_CACHE)) - flat_weights = [w for w in fastmath.tree_flatten(weights) - if not _is_empty_weight(w)] - def _is_empty_state(x): - return (x is EMPTY_STATE or - (isinstance(x, dict) and x == GET_STATE_FROM_CACHE)) - flat_state = [s for s in fastmath.tree_flatten(state) - if not _is_empty_state(s)] - return flat_weights, flat_state - - -def unflatten_weights_and_state( - flat_weights, flat_state, weights_and_state_signature, weights_only=False): - """Unflatten weights and state given their signatures.""" - weights_tree, state_tree = weights_and_state_signature - weights_to_copy = [EMPTY_WEIGHTS, GET_WEIGHTS_FROM_CACHE] - weights, _ = fastmath.tree_unflatten(flat_weights, weights_tree, - copy_from_tree=weights_to_copy) - state = None - if not weights_only: - states_to_copy = [EMPTY_STATE, GET_STATE_FROM_CACHE] - state, _ = fastmath.tree_unflatten(flat_state, state_tree, - copy_from_tree=states_to_copy) - return weights, state - + if isinstance(input_signature, ShapeDtype): + shape, dtype = input_signature.shape, input_signature.dtype + if np.issubdtype(dtype, np.integer): + return fastmath.random.bernoulli(rng, 0.5, shape).astype(np.int32) + else: + return fastmath.random.uniform(rng, shape, minval=-1.0, maxval=1.0) + elif isinstance(input_signature, (list, tuple)): + return tuple(_random_values(x, rng) for x in input_signature) + else: + raise TypeError(type(input_signature)) -def np_to_file(list_of_nparrays, file_path, compresslevel): - """Save numpy arrays to file_path with gzipping and failure protection.""" - # Pickle to tmp file and overwrite to prevent writing partial files. - tmp_file_path = file_path + '._tmp_' - with tf.io.gfile.GFile(tmp_file_path, 'wb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: - for x in list_of_nparrays: - np.save(gzipf, x, allow_pickle=False) - # Moving a file is much less error-prone than pickling large files. 
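Editor's note: _random_values mirrors the structure of the signature it receives: a single ShapeDtype yields one tensor (bernoulli 0/1 ints for integer dtypes, uniform floats in [-1, 1) otherwise), and a tuple of signatures yields a tuple. A sketch (shapes illustrative; ShapeDtype may import from trax.utils.shapes after this refactor):

    import numpy as np
    from trax import fastmath
    from trax.layers import base
    from trax.shapes import ShapeDtype

    rng = fastmath.random.get_prng(0)
    sig = (ShapeDtype((2, 3), np.float32), ShapeDtype((4,), np.int32))
    x_float, x_int = base._random_values(sig, rng)
    assert x_float.shape == (2, 3) and x_int.shape == (4,)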
- tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) +def _shapes(x): + """Gets a structure of shapes for a structure of nested arrays.""" -def np_from_file(file_path, compresslevel): - """Load numpy arrays from file_path with gzipping.""" - if not tf.io.gfile.exists(file_path): - raise FileNotFoundError(file_path) - res = [] - with tf.io.gfile.GFile(file_path, 'rb') as f: - with gzip.GzipFile(fileobj=f, compresslevel=compresslevel) as gzipf: - while True: + def shape(x): try: - res.append(np.load(gzipf, allow_pickle=False)) + return tuple([int(i) for i in x.shape]) except Exception: # pylint: disable=broad-except - break - return res + return () + return tuple(nested_map(shape, x)) -def to_list(outputs): - """Converts layer outputs to a nested list, for easier equality testing. - - Args: - outputs: A tensor or tuple/list of tensors coming from the forward - application of a layer. Each tensor is NumPy ndarray-like, which - complicates simple equality testing (e.g., via `assertEquals`): - such tensors require equality testing to use either `all` (all - elements match) or `any` (at least one element matches), which is not - directly supported in `absltest`. - - Returns: - A nested list structure containing all the output values, but now directly - testable using `assertEquals`. - """ - if isinstance(outputs, (list, tuple)): - return [y.tolist() for y in outputs] - else: - return outputs.tolist() +@functools.partial(fastmath.pmap, axis_name="batch") +def _axis_index(unused_x): + """Return the axis indices.""" + return jax.lax.axis_index("batch") -def _validate_forward_input(x, n_in): - if n_in != 1: - if not isinstance(x, (tuple, list)): - raise TypeError( - f'Expected input to be a tuple or list; instead got {type(x)}.') - if len(x) != n_in: - raise ValueError(f'Input tuple length ({len(x)}) does not equal required ' - f'number of inputs ({n_in}).') - - -def _is_empty(container): - if container is None: - raise ValueError('Argument "container" is None.') - return isinstance(container, (list, tuple)) and len(container) == 0 # pylint: disable=g-explicit-length-test - - -def _find_frame(frame): - """Find the frame with the caller on the stack.""" - def _dirname_is_trax_layers_or_gin(frame): - """Skip frames coming from trax/layers or .../gin.""" - try: - dirname1 = frame.f_code.co_filename.split('/')[-3] - dirname2 = frame.f_code.co_filename.split('/')[-2] - return (dirname1 == 'trax' and dirname2 == 'layers') or dirname2 == 'gin' - except IndexError: - return False - - while _dirname_is_trax_layers_or_gin(frame): - frame = frame.f_back - return frame +def _axis_to_shard_heuristic(shape): + """Chooses an axis to shard on - a simple heuristic to be revisited.""" + axis = 0 if len(shape) < 3 else -1 + return axis -def _shorten_file_path(line): - """Shorten file path in error lines for more readable tracebacks.""" - start = line.lower().find('file') - if start < 0: - return line - first_quote = line.find('"', start) - if first_quote < 0: - return line - second_quote = line.find('"', first_quote + 1) - if second_quote < 0: - return line - path = line[first_quote + 1:second_quote] - new_path = '/'.join(path.split('/')[-3:]) - return line[:first_quote] + '[...]/' + new_path + line[second_quote + 1:] +def shard(tensors, n_shards=None): + """Shard tensors across n_shards.""" + n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards + indices = _axis_index(np.zeros(fastmath.local_device_count())) -def _short_traceback(skip=3): - """Cleaned-up form of traceback.""" - counter, res = 0, [] 
-  # Skipping 3 lines by default: the top (useless) and self-call.
-  # In python 3, we need to set chain to False (it doesn't exist in python 2).
-  lines = traceback.format_exc(chain=False).splitlines()[skip:]  # pylint: disable=unexpected-keyword-arg
-  for l in lines:
-    if l.startswith('trax.layers.base.LayerError'):
-      l = l[len('trax.layers.base.'):]  # Remove the trax.layers.base prefix.
-    res.append(_shorten_file_path(l))
-    if counter % 2 == 1:
-      res.append('')
-    counter += 1
-    # If we see a LayerError, the traceback has already been processed.
-    if l.startswith('LayerError'):
-      # Skip 4 back except last as these are internal base-layer calls.
-      res = res[:-4] + [res[-1]]
-      res += lines[counter:]
-      break
-  return '\n'.join(res)
+    def _shard_fn(x):
+        axis = _axis_to_shard_heuristic(x.shape)
+        if int(x.shape[axis]) % n_shards != 0:
+            raise ValueError(f"Cannot split x with shape {x.shape} into {n_shards}.")
+        split_x = jnp.split(x, n_shards, axis=axis)
+        split_x = [split_x[i % n_shards] for i in indices]
+        return np.stack(split_x, axis=0)
+
+    return fastmath.nested_map(_shard_fn, tensors)


-def _random_values(input_signature, rng):
-  """Creates random floats or ints of the given shape.
-
-  Args:
-    input_signature: A `ShapeDtype` instance (if `layer_obj` takes one input)
-        or a list/tuple of ShapeDtype instances.
-    rng: Single-use random number generator (JAX PRNG key).
-
-  Returns:
-    Random values with the shape and type specified.
-  """
-  if isinstance(input_signature, ShapeDtype):
-    shape, dtype = input_signature.shape, input_signature.dtype
-    if np.issubdtype(dtype, np.integer):
-      return fastmath.random.bernoulli(rng, 0.5, shape).astype(np.int32)
-    else:
-      return fastmath.random.uniform(rng, shape, minval=-1.0, maxval=1.0)
-  elif isinstance(input_signature, (list, tuple)):
-    return tuple(_random_values(x, rng) for x in input_signature)
-  else:
-    raise TypeError(type(input_signature))
+def unshard_in_pmap(tensors, n_shards):
+    """Unshard tensors that were sharded into n_shards (call inside pmap)."""
+    groups = [
+        [n_shards * i + d for d in range(n_shards)]
+        for i in range(fastmath.global_device_count() // n_shards)
+    ]
+
+    def _unshard_fn(x):
+        y = jax.lax.all_gather(x, "batch", axis_index_groups=groups)
+        split_y = jnp.split(y, n_shards, axis=0)
+        split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
+        axis = _axis_to_shard_heuristic(split_y[0].shape)
+        return jnp.concatenate(split_y, axis=axis)

-def _shapes(x):
-  """Gets a structure of shapes for a structure of nested arrays."""
-  def shape(x):
-    try:
-      return tuple([int(i) for i in x.shape])
-    except Exception:  # pylint: disable=broad-except
-      return ()
-  return tuple(nested_map(shape, x))
-
-
-@functools.partial(fastmath.pmap, axis_name='batch')
-def _axis_index(unused_x):
-  """Return the axis indices."""
-  return jax.lax.axis_index('batch')
-
-
-def _axis_to_shard_heuristic(shape):
-  """Chooses an axis to shard on - a simple heuristic to be revisited."""
-  axis = 0 if len(shape) < 3 else -1
-  return axis
+    try:
+        jax.lax.axis_index("batch")  # will throw if not in pmap, e.g., on init
+        res = fastmath.nested_map(_unshard_fn, tensors)
+        return res, True
+    except NameError:  # thrown from axis_index above
+        return tensors, False

-def shard(tensors, n_shards=None):
-  """Shard tensors across n_shards."""
-  n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards
-  indices = _axis_index(np.zeros(fastmath.local_device_count()))
-  def _shard_fn(x):
-    axis = _axis_to_shard_heuristic(x.shape)
-    if int(x.shape[axis]) % n_shards != 0:
-      raise ValueError(f'Cannot split x with shape {x.shape} into {n_shards}.')
-    split_x = jnp.split(x, n_shards, axis=axis)
-    split_x = [split_x[i % n_shards] for i in indices]
-    return np.stack(split_x, axis=0)
-  return fastmath.nested_map(_shard_fn, tensors)
-
-
-def unshard_in_pmap(tensors, n_shards):
-  """Unshard tensors that were sharded into n_shards (call inside pmap)."""
-  groups = [[n_shards * i + d for d in range(n_shards)]
-            for i in range(fastmath.global_device_count() // n_shards)]
-  def _unshard_fn(x):
-    y = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
-    split_y = jnp.split(y, n_shards, axis=0)
-    split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
-    axis = _axis_to_shard_heuristic(split_y[0].shape)
-    return jnp.concatenate(split_y, axis=axis)
-  try:
-    jax.lax.axis_index('batch')  # will throw if not in pmap, e.g., on init
-    res = fastmath.nested_map(_unshard_fn, tensors)
-    return res, True
-  except NameError:  # thrown from axis_index above
-    return tensors, False
-
-
-@functools.partial(fastmath.pmap, axis_name='batch')
+@functools.partial(fastmath.pmap, axis_name="batch")
 def _all_gather(x, groups):
-  return jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
+    return jax.lax.all_gather(x, "batch", axis_index_groups=groups)


 def unshard(tensors, n_shards=None):
-  """Unshard tensors that were sharded into n_shards (outside of pmap)."""
-  n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards
-  def _unshard_fn(x):
-    # We use numpy here to put the large un-sharded arrays in CPU memory.
-    # For unsharding on accelerators use ushard_in_pmap above and pmap it.
-    split_y = np.split(np.asarray(x), n_shards, axis=0)
-    split_y = [np.squeeze(sy, axis=0) for sy in split_y]
-    axis = _axis_to_shard_heuristic(split_y[0].shape)
-    return np.concatenate(split_y, axis=axis)
-  return fastmath.nested_map(_unshard_fn, tensors)
+    """Unshard tensors that were sharded into n_shards (outside of pmap)."""
+    n_shards = N_WEIGHTS_SHARDS if n_shards is None else n_shards
+
+    def _unshard_fn(x):
+        # We use numpy here to put the large un-sharded arrays in CPU memory.
+        # For unsharding on accelerators use unshard_in_pmap above and pmap it.
+        split_y = np.split(np.asarray(x), n_shards, axis=0)
+        split_y = [np.squeeze(sy, axis=0) for sy in split_y]
+        axis = _axis_to_shard_heuristic(split_y[0].shape)
+        return np.concatenate(split_y, axis=axis)
+
+    return fastmath.nested_map(_unshard_fn, tensors)
diff --git a/trax/layers/base_test.py b/trax/layers/base_test.py
deleted file mode 100644
index b1abf2786..000000000
--- a/trax/layers/base_test.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
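Editor's note: the sharding helpers above are shape-driven: _axis_to_shard_heuristic picks axis 0 for low-rank tensors and the last axis for rank 3+, and outside pmap unshard is plain numpy (split off the leading shard axis, squeeze it, concatenate on the heuristic axis), so the large unsharded arrays stay in host memory. A shape-level sketch with n_shards=2 (values illustrative):

    import numpy as np
    from trax.layers import base

    assert base._axis_to_shard_heuristic((512, 8)) == 0       # rank < 3
    assert base._axis_to_shard_heuristic((8, 64, 512)) == -1  # rank >= 3

    # Two shards of a (6, 4) matrix, stacked on a leading shard axis.
    sharded = np.arange(48).astype(np.float32).reshape(2, 3, 4)
    full = base.unshard(sharded, n_shards=2)
    # Each (1, 3, 4) shard squeezes to (3, 4); rank < 3 picks axis 0, so the
    # shards concatenate back into the original (6, 4) matrix.
    assert full.shape == (6, 4)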
- -"""Tests for Trax base layer classes and generic layer-creating functions.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np - -from trax import fastmath -from trax import shapes -from trax.fastmath import numpy as jnp -import trax.layers as tl - -BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] -CUSTOM_GRAD_BACKENDS = [fastmath.Backend.JAX] # TODO(afrozm): del after TF 2.3 - - -class BaseLayerTest(parameterized.TestCase): - - def test_call_raises_error(self): - layer = tl.Layer() - x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) - with self.assertRaisesRegex(tl.LayerError, 'NotImplementedError'): - _ = layer(x) - - def test_set_weighs_raises_error(self): - layer = tl.Layer() - layer.weights = 1.0 # can assign weights - with self.assertRaisesRegex(ValueError, 'weighs'): - layer.weighs = 1.0 # cannot assign weighs - - def test_forward_raises_error(self): - layer = tl.Layer() - x = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]) - with self.assertRaises(NotImplementedError): - _ = layer.forward(x) - - def test_init_returns_empty_weights_and_state(self): - layer = tl.Layer() - input_signature = shapes.ShapeDtype((2, 5)) - weights, state = layer.init(input_signature) - self.assertEmpty(weights) - self.assertEmpty(state) - - def test_output_signature_no_weights(self): - shape_2_3_5 = shapes.ShapeDtype((2, 3, 5)) - input_signature = (shape_2_3_5, shape_2_3_5) - layer = tl.Fn('2in1out', lambda x, y: x + y) - output_signature = layer.output_signature(input_signature) - self.assertEqual(output_signature, shape_2_3_5) - - shape_5_7 = shapes.ShapeDtype((5, 7)) - input_signature = shape_5_7 - layer = tl.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3) - output_signature = layer.output_signature(input_signature) - self.assertEqual(output_signature, (shape_5_7, shape_5_7, shape_5_7)) - - # TODO(jonni): Define/test behavior of output signature for layers w/weights. - - @parameterized.named_parameters( - [('_' + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) - def test_custom_zero_grad(self, backend): - - class IdWithZeroGrad(tl.Layer): - - def forward(self, x): - return x - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng): - return (jnp.zeros_like(grad), ()) - - with fastmath.use_backend(backend): - layer = IdWithZeroGrad() - rng = fastmath.random.get_prng(0) - input_signature = shapes.ShapeDtype((9, 17)) - random_input = fastmath.random.uniform( - rng, input_signature.shape, minval=-1.0, maxval=1.0) - layer.init(input_signature) - f = lambda x: jnp.mean(layer(x)) - grad = fastmath.grad(f)(random_input) - self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. - self.assertEqual(sum(sum(grad * grad)), 0.0) # Each one is 0. 
- - @parameterized.named_parameters( - [('_' + b.value, b) for b in CUSTOM_GRAD_BACKENDS]) - def test_custom_id_grad(self, backend): - - class IdWithIdGrad(tl.Layer): - - def forward(self, x): - return x - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng): - return (inputs, ()) - - with fastmath.use_backend(backend): - layer = IdWithIdGrad() - rng = fastmath.random.get_prng(0) - input_signature = shapes.ShapeDtype((9, 17)) - random_input = fastmath.random.uniform( - rng, input_signature.shape, minval=-1.0, maxval=1.0) - layer.init(input_signature) - f = lambda x: jnp.mean(layer(x)) - grad = fastmath.grad(f)(random_input) - self.assertEqual(grad.shape, (9, 17)) # Gradient for each input. - self.assertEqual(sum(sum(grad)), sum(sum(random_input))) # Same as input. - - def test_weights_and_state_signature(self): - - class MyLayer(tl.Layer): - - def init_weights_and_state(self, input_signature): - self.weights = jnp.zeros((2, 3)) - self.state = jnp.ones(input_signature.shape) - - def forward(self, inputs): - return self.weights + self.state - - layer = MyLayer() - w, s = layer.weights_and_state_signature(jnp.zeros((3, 4))) - self.assertEqual(w.shape, (2, 3)) - self.assertEqual(s.shape, (3, 4)) - - def test_custom_name(self): - layer = tl.Layer() - self.assertIn('Layer', str(layer)) - self.assertNotIn('CustomLayer', str(layer)) - - layer = tl.Layer(name='CustomLayer') - self.assertIn('CustomLayer', str(layer)) - - -class PureLayerTest(absltest.TestCase): - - def test_forward(self): - layer = tl.PureLayer(lambda x: 2 * x) - - # Use Layer.__call__. - in_0 = np.array([1, 2]) - out_0 = layer(in_0, weights=jnp.zeros((2, 3))) - self.assertEqual(out_0.tolist(), [2, 4]) - self.assertEmpty(layer.weights) - - # Use PureLayer.forward. 
- in_1 = np.array([3, 4]) - out_1 = layer.forward(in_1) - self.assertEqual(out_1.tolist(), [6, 8]) - - # Use Layer.pure_fn - in_2 = np.array([5, 6]) - out_2, _ = layer.pure_fn(in_2, tl.EMPTY_WEIGHTS, tl.EMPTY_WEIGHTS, None) - self.assertEqual(out_2.tolist(), [10, 12]) - - -class FnTest(absltest.TestCase): - - def test_bad_f_has_default_arg(self): - with self.assertRaisesRegex(ValueError, 'default arg'): - _ = tl.Fn('', lambda x, sth=None: x) - - def test_bad_f_has_keyword_arg(self): - with self.assertRaisesRegex(ValueError, 'keyword arg'): - _ = tl.Fn('', lambda x, **kwargs: x) - - def test_bad_f_has_variable_arg(self): - with self.assertRaisesRegex(ValueError, 'variable arg'): - _ = tl.Fn('', lambda *args: args[0]) - - def test_forward(self): - layer = tl.Fn( - 'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2) - - x0 = np.array([1, 2, 3, 4, 5]) - x1 = np.array([10, 20, 30, 40, 50]) - - y0, y1 = layer((x0, x1)) - self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55]) - self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50]) - - y2, y3 = layer.forward((x0, x1)) - self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55]) - self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50]) - - (y4, y5), state = layer.pure_fn((x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE, - None) - self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55]) - self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50]) - self.assertEqual(state, tl.EMPTY_STATE) - - def test_weights_state(self): - layer = tl.Fn( - '2in2out', - lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)), - n_out=2) - layer.init_weights_and_state(None) - self.assertEmpty(layer.weights) - self.assertEmpty(layer.state) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/combinators.py b/trax/layers/combinators.py index ddde61f8f..762993122 100644 --- a/trax/layers/combinators.py +++ b/trax/layers/combinators.py @@ -21,991 +21,1083 @@ from trax.fastmath import numpy as jnp from trax.layers import base from trax.layers.base import Fn -from trax.shapes import ShapeDtype +from trax.utils.shapes import ShapeDtype class Serial(base.Layer): - """Combinator that applies layers serially (by function composition). - - This combinator is commonly used to construct deep networks, e.g., like this:: - - mlp = tl.Serial( - tl.Dense(128), - tl.Relu(), - tl.Dense(10), - ) - - A Serial combinator uses stack semantics to manage data for its sublayers. - Each sublayer sees only the inputs it needs and returns only the outputs it - has generated. The sublayers interact via the data stack. For instance, a - sublayer k, following sublayer j, gets called with the data stack in the - state left after layer j has applied. The Serial combinator then: - - - takes n_in items off the top of the stack (n_in = k.n_in) and calls - layer k, passing those items as arguments; and - - - takes layer k's n_out return values (n_out = k.n_out) and pushes - them onto the data stack. - - A Serial instance with no sublayers acts as a special-case (but useful) - 1-input 1-output no-op. 
- """ - - def __init__(self, *sublayers, name=None, sublayers_to_print=None): - super().__init__( - name=name, sublayers_to_print=sublayers_to_print) - - sublayers = _ensure_flat(sublayers) - self._sublayers = sublayers - self._n_layers = len(sublayers) - - if sublayers: - self._n_in, self._n_out = self._n_inputs_n_outputs(sublayers) - self._weights = tuple(None for l in sublayers) - self._state = tuple(None for l in sublayers) - - def forward(self, xs): - """Executes this layer as part of a forward pass through the model.""" - self._validate_forward_inputs(xs) - if not self.sublayers: # No-op: outputs = inputs - return xs - - state, weights = self.state, self.weights - rngs = _split_rngs(self.rng, self._n_layers) - stack = xs - new_state = [] - n_layers = self._n_layers - if len(weights) != n_layers: - raise ValueError( - f'Number of weight elements ({len(weights)}) does not equal ' - f'number of sublayers ({n_layers}).') - if len(state) != n_layers: - raise ValueError( - f'Number of state elements ({len(state)}) does not equal ' - f'number of sublayers ({n_layers}).') - - for layer, w, s, rng in zip(self.sublayers, weights, state, rngs): - inputs = inputs_from_stack(stack, layer.n_in) - outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True) - stack = outputs_onto_stack(outputs, stack, layer.n_in) - new_state.append(s) - self.state = tuple(new_state) - return stack - - # pylint: disable=protected-access - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - weights = [] - states = [] - # In the code below, stack, inputs, and outputs are abstract (shapes and - # dtypes), but weights and states are non-abstract actual values. - stack = input_signature - for sublayer in self.sublayers: - inputs = inputs_from_stack(stack, sublayer.n_in) - weights_or_cache_marker, state_or_cache_marker = ( - sublayer.init(inputs, use_cache=True)) - outputs, _ = sublayer._forward_abstract(inputs) - stack = outputs_onto_stack(outputs, stack, sublayer.n_in) - - weights.append(weights_or_cache_marker) - states.append(state_or_cache_marker) - self.state = tuple(states) - self.weights = tuple(weights) - # pylint: enable=protected-access - - def _n_inputs_n_outputs(self, layers): - del self - running_max = 0 - running_total = 0 - for layer in layers: - running_total += layer.n_in - running_max = max(running_max, running_total) - running_total -= layer.n_out - return running_max, (running_max - running_total) - - def _validate_forward_inputs(self, xs): - if not isinstance(xs, (tuple, list)) and self._n_in != 1: - raise TypeError(f'Serial.forward input must be a tuple or list; ' - f'instead got {type(xs)}.') - # TODO(jonni): Include full xs (or shape) in error message? - len_xs = 1 if isinstance(xs, jnp.ndarray) else len(xs) - if len_xs < self.n_in: - raise ValueError( - f'Number of inputs ({len(xs)}) to Serial.forward less than n_in ' - f'({self.n_in}).') + """Combinator that applies layers serially (by function composition). + This combinator is commonly used to construct deep networks, e.g., like this:: -class Parallel(base.Layer): - """Combinator that applies a list of layers in parallel to its inputs. + mlp = tl.Serial( + tl.Dense(128), + tl.Relu(), + tl.Dense(10), + ) - Layers in the list apply to successive spans of inputs, where the spans are - determined how many inputs each layer takes. The resulting output is the - (flattened) concatenation of the respective layer outputs. 
+ A Serial combinator uses stack semantics to manage data for its sublayers. + Each sublayer sees only the inputs it needs and returns only the outputs it + has generated. The sublayers interact via the data stack. For instance, a + sublayer k, following sublayer j, gets called with the data stack in the + state left after layer j has applied. The Serial combinator then: - For example, suppose one has three layers: + - takes n_in items off the top of the stack (n_in = k.n_in) and calls + layer k, passing those items as arguments; and - - F: 1 input, 1 output - - G: 3 inputs, 1 output - - H: 2 inputs, 2 outputs (h1, h2) + - takes layer k's n_out return values (n_out = k.n_out) and pushes + them onto the data stack. - Then Parallel(F, G, H) will take 6 inputs and give 4 outputs: + A Serial instance with no sublayers acts as a special-case (but useful) + 1-input 1-output no-op. + """ - - inputs: a, b, c, d, e, f - - outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f) + def __init__(self, *sublayers, name=None, sublayers_to_print=None): + super().__init__(name=name, sublayers_to_print=sublayers_to_print) + + sublayers = _ensure_flat(sublayers) + self._sublayers = sublayers + self._n_layers = len(sublayers) + + if sublayers: + self._n_in, self._n_out = self._n_inputs_n_outputs(sublayers) + self._weights = tuple(None for l in sublayers) + self._state = tuple(None for l in sublayers) + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model.""" + if not isinstance(xs, (tuple, list)): + xs = (xs,) + + # The input should be a flat single tuple without nested tuples + xs = self.flatten_tuple(xs) + + self._validate_forward_inputs(xs) + + if not self.sublayers: # No-op: outputs = inputs + return xs + + state, weights = self.state, self.weights + rngs = _split_rngs(self.rng, self._n_layers) + stack = xs + new_state = [] + n_layers = self._n_layers + if len(weights) != n_layers: + raise ValueError( + f"Number of weight elements ({len(weights)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(state) != n_layers: + raise ValueError( + f"Number of state elements ({len(state)}) does not equal " + f"number of sublayers ({n_layers})." + ) + + for layer, w, s, rng in zip(self.sublayers, weights, state, rngs): + inputs = inputs_from_stack(stack, layer.n_in) + outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True) + stack = outputs_onto_stack(outputs, stack, layer.n_in) + new_state.append(s) + self.state = tuple(new_state) + return stack + + # pylint: disable=protected-access + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + weights = [] + states = [] + # In the code below, stack, inputs, and outputs are abstract (shapes and + # dtypes), but weights and states are non-abstract actual values. + stack = input_signature + for sublayer in self.sublayers: + inputs = inputs_from_stack(stack, sublayer.n_in) + weights_or_cache_marker, state_or_cache_marker = sublayer.init( + inputs, use_cache=True + ) + outputs, _ = sublayer._forward_abstract(inputs) + stack = outputs_onto_stack(outputs, stack, sublayer.n_in) + + weights.append(weights_or_cache_marker) + states.append(state_or_cache_marker) + self.state = tuple(states) + self.weights = tuple(weights) - As an important special case, a None argument to Parallel acts as if it takes - one argument, which it leaves unchanged. (It acts as a one-arg no-op.) 
For
-  example:
-
-      Parallel(None, F)
+    # pylint: enable=protected-access

-  creates a layer that passes its first input unchanged and applies F to the
-  following input(s).
-  """
+    def _n_inputs_n_outputs(self, layers):
+        del self
+        running_max = 0
+        running_total = 0
+        for layer in layers:
+            running_total += layer.n_in
+            running_max = max(running_max, running_total)
+            running_total -= layer.n_out
+        return running_max, (running_max - running_total)
+
+    def _validate_forward_inputs(self, xs):
+        if not isinstance(xs, (tuple, list)) and self._n_in != 1:
+            raise TypeError(
+                f"Serial.forward input must be a tuple or list; "
+                f"instead got {type(xs)}."
+            )
+        # TODO(jonni): Include full xs (or shape) in error message?
+
+        len_xs = 1 if isinstance(xs, jnp.ndarray) else len(xs)
+        if len_xs < self.n_in:
+            raise ValueError(
+                f"Number of inputs ({len(xs)}) to Serial.forward less than n_in "
+                f"({self.n_in})."
+            )

-  def __init__(self, *sublayers, name=None):
-    """The constructor.

 class Parallel(base.Layer):
+    """Combinator that applies a list of layers in parallel to its inputs.

-    Args:
-      *sublayers: A list of sublayers.
-      name: Descriptive name for this layer.
+    Layers in the list apply to successive spans of inputs, where the spans are
+    determined by how many inputs each layer takes. The resulting output is the
+    (flattened) concatenation of the respective layer outputs.

-    Returns:
-      A new layer in which each of the given sublayers applies to its
-      corresponding span of elements in the dataflow stack.
+    For example, suppose one has three layers:
-    """
-    super().__init__(name=name)
-    sublayers = self._validate(sublayers)
-    self._n_layers = len(sublayers)
-    self._sublayers = sublayers
-    self._n_in = sum(l.n_in for l in sublayers)
-    self._n_out = sum(l.n_out for l in sublayers)
-    self._weights = tuple(None for l in sublayers)
-    self._state = tuple(None for l in sublayers)
-
-  def forward(self, inputs):
-    """Executes this layer as part of a forward pass through the model."""
-    n_layers, layers = self._n_layers, self.sublayers
-    sublayer_inputs = self._allot_to_sublayers(inputs)
-    state, weights = self.state, self.weights
-    rngs = _split_rngs(self.rng, n_layers)
-    if len(sublayer_inputs) != n_layers:
-      raise ValueError(
-          f'Number of inputs for sublayers ({len(sublayer_inputs)}) does not equal '
-          f'number of sublayers ({n_layers}).')
-    if len(weights) != n_layers:
-      raise ValueError(
-          f'Number of weight elements ({len(weights)}) does not equal '
-          f'number of sublayers ({n_layers}).')
-    if len(state) != n_layers:
-      raise ValueError(
-          f'Number of state elements ({len(state)}) does not equal '
-          f'number of sublayers ({n_layers}).')
-    if len(rngs) != n_layers:
-      raise ValueError(
-          f'Number of rngs ({len(rngs)}) does not equal '
-          f'number of sublayers ({n_layers}).')
-    outputs = []
-    new_state = []
-    for layer, x, w, s, r in zip(layers, sublayer_inputs, weights, state, rngs):
-      # Note that zip silently truncates its result if lengths don't match.
- sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True) - if layer.n_out == 1: - outputs.append(sub_outputs) - else: - outputs.extend(sub_outputs) - new_state.append(sub_state) - output = outputs[0] if self.n_out == 1 else tuple(outputs) - self.state = tuple(new_state) - return output - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - sublayer_signatures = self._allot_to_sublayers(input_signature) - inits = [layer.init(signature, use_cache=True) - for layer, signature - in zip(self.sublayers, sublayer_signatures)] - if inits: - weights, state = tuple(zip(*inits)) - self.state = state - self.weights = weights - - def _validate(self, layers): - if not layers or len(layers) < 2: - raise ValueError( - f'layers ({layers}) must be a list with at least two elements') - layers = list(layers) # Ensure we can modify layers. - for i, obj in enumerate(layers): - if obj is None or obj == []: # pylint: disable=g-explicit-bool-comparison - layers[i] = Serial(None) - elif isinstance(obj, (list, tuple)): - layers[i] = Serial(obj) - else: - if not isinstance(obj, base.Layer): - raise ValueError( - f'Found nonlayer object ({obj}) in layers list: [{layers}]') - if layers[i].n_in == 0: - raise ValueError( - f'Sublayer with n_in = 0 not allowed in Parallel: {layers[i]}') - return layers + For example, suppose one has three layers: - def _allot_to_sublayers(self, inputs): - """Divides Parallel's inputs for use by the sublayers. + - F: 1 input, 1 output + - G: 3 inputs, 1 output + - H: 2 inputs, 2 outputs (h1, h2) - Args: - inputs: Tuple of ndarrays or ShapeDtype instances. + Then Parallel(F, G, H) will take 6 inputs and give 4 outputs: - Returns: - A tuple that partitions this layer's inputs among its sublayers. - Sublayers that take one argument get that argument directly. All other - sublayers get a tuple of items. + - inputs: a, b, c, d, e, f + - outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f) + + As an important special case, a None argument to Parallel acts as if it takes + one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For + example: + + Parallel(None, F) + + creates a layer that passes its first input unchanged and applies F to the + following input(s). """ - start, end = 0, 0 - sub_inputs = [] - for layer in self.sublayers: - n_in = layer.n_in - end = start + n_in - if n_in == 1: - sub_inputs.append(inputs[start]) - else: - sub_inputs.append(inputs[start:end]) - start = end - return tuple(sub_inputs) + + def __init__(self, *sublayers, name=None): + """The constructor. + + Args: + *sublayers: A list of sublayers. + name: Descriptive name for this layer. + + Returns: + A new layer in which each of the given sublayers applies to its + corresponding span of elements in the dataflow stack. 
+ """ + super().__init__(name=name) + sublayers = self._validate(sublayers) + self._n_layers = len(sublayers) + self._sublayers = sublayers + self._n_in = sum(l.n_in for l in sublayers) + self._n_out = sum(l.n_out for l in sublayers) + self._weights = tuple(None for l in sublayers) + self._state = tuple(None for l in sublayers) + + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model.""" + n_layers, layers = self._n_layers, self.sublayers + sublayer_inputs = self._allot_to_sublayers(inputs) + state, weights = self.state, self.weights + rngs = _split_rngs(self.rng, n_layers) + if len(sublayer_inputs) != n_layers: + raise ValueError( + f"Number of inputs for sublayers ({len(sublayer_inputs)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(weights) != n_layers: + raise ValueError( + f"Number of weight elements ({len(weights)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(state) != n_layers: + raise ValueError( + f"Number of state elements ({len(state)}) does not equal " + f"number of sublayers ({n_layers})." + ) + if len(rngs) != n_layers: + raise ValueError( + f"Number of rngs ({len(rngs)}) does not equal " + f"number of sublayers ({n_layers})." + ) + outputs = [] + new_state = [] + for layer, x, w, s, r in zip(layers, sublayer_inputs, weights, state, rngs): + # Note that zip silently truncates its result if lengths don't match. + sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True) + if layer.n_out == 1: + outputs.append(sub_outputs) + else: + outputs.extend(sub_outputs) + new_state.append(sub_state) + output = outputs[0] if self.n_out == 1 else tuple(outputs) + self.state = tuple(new_state) + + if not isinstance(output, (tuple, list)): + output = (output,) + + # The input should be a flat single tuple without nested tuples + output = self.flatten_tuple(output) + + return output + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + sublayer_signatures = self._allot_to_sublayers(input_signature) + inits = [ + layer.init(signature, use_cache=True) + for layer, signature in zip(self.sublayers, sublayer_signatures) + ] + if inits: + weights, state = tuple(zip(*inits)) + self.state = state + self.weights = weights + + def _validate(self, layers): + if not layers or len(layers) < 2: + raise ValueError( + f"layers ({layers}) must be a list with at least two elements" + ) + layers = list(layers) # Ensure we can modify layers. + for i, obj in enumerate(layers): + if obj is None or obj == []: # pylint: disable=g-explicit-bool-comparison + layers[i] = Serial(None) + elif isinstance(obj, (list, tuple)): + layers[i] = Serial(obj) + else: + if not isinstance(obj, base.Layer): + raise ValueError( + f"Found nonlayer object ({obj}) in layers list: [{layers}]" + ) + if layers[i].n_in == 0: + raise ValueError( + f"Sublayer with n_in = 0 not allowed in Parallel: {layers[i]}" + ) + return layers + + def _allot_to_sublayers(self, inputs): + """Divides Parallel's inputs for use by the sublayers. + + Args: + inputs: Tuple of ndarrays or ShapeDtype instances. + + Returns: + A tuple that partitions this layer's inputs among its sublayers. + Sublayers that take one argument get that argument directly. All other + sublayers get a tuple of items. 
+ """ + start, end = 0, 0 + sub_inputs = [] + for layer in self.sublayers: + n_in = layer.n_in + end = start + n_in + if n_in == 1: + sub_inputs.append(inputs[start]) + else: + sub_inputs.append(inputs[start:end]) + start = end + return tuple(sub_inputs) class Concatenate(base.Layer): - """Concatenates a number of tensors into a single tensor. + """Concatenates a number of tensors into a single tensor. - For example:: + For example:: - x = np.array([1, 2]) - y = np.array([3, 4]) - z = np.array([5, 6]) - concat3 = tl.Concatenate(n_items=3) - z = concat3((x, y, z)) # z = [1, 2, 3, 4, 5, 6] + x = np.array([1, 2]) + y = np.array([3, 4]) + z = np.array([5, 6]) + concat3 = tl.Concatenate(n_items=3) + z = concat3((x, y, z)) # z = [1, 2, 3, 4, 5, 6] - Use the `axis` argument to specify on which axis to concatenate the tensors. - By default it's the last axis, `axis=-1`, and `n_items=2`. - """ + Use the `axis` argument to specify on which axis to concatenate the tensors. + By default it's the last axis, `axis=-1`, and `n_items=2`. + """ - def __init__(self, n_items=2, axis=-1): - name = 'Concatenate' if axis == -1 else f'Concatenate_axis{axis}' - super().__init__(n_in=n_items, name=name) - self._n_items = n_items - self._axis = axis + def __init__(self, n_items=2, axis=-1): + name = "Concatenate" if axis == -1 else f"Concatenate_axis{axis}" + super().__init__(n_in=n_items, name=name) + self._n_items = n_items + self._axis = axis - def forward(self, xs): - """Executes this layer as part of a forward pass through the model.""" - return jnp.concatenate(xs, self._axis) + def forward(self, xs): + """Executes this layer as part of a forward pass through the model.""" + return jnp.concatenate(xs, self._axis) class Split(base.Layer): - """Splits the input into n items along an axis.""" + """Splits the input into n items along an axis.""" - def __init__(self, n_items=2, axis=-1): - super().__init__(n_out=n_items) - self._n_items = n_items - self._axis = axis + def __init__(self, n_items=2, axis=-1): + super().__init__(n_out=n_items) + self._n_items = n_items + self._axis = axis - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - return tuple(jnp.split(inputs, self._n_items, self._axis)) + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model.""" + return tuple(jnp.split(inputs, self._n_items, self._axis)) def _scan(f, xs, init_value, axis=0, remat=False): - """Scans the f over the given axis of xs. - - In pseudo-python, the scan function would look as follows: - - def scan(f, xs, init_value, axis): - xs = [xs[..., i, ...] for i in range(xs.shape[axis])] - cur_value = init_value - ys = [] - for x in xs: - y, cur_value = f(x, cur_value) - ys.append(y) - return np.stack(ys, axis), cur_value - - Args: - f: function (x, carry) -> (y, new_carry) - xs: tensor, x will be xs slices on axis - init_value: tensor, initial value of the carry-over - axis: int, the axis on which to slice xs - remat: whether to re-materialize f - - Returns: - A pair (ys, last_value) as described above. 
- """ - def swapaxes(x): - transposed_axes = list(range(len(x.shape))) - transposed_axes[axis] = 0 - transposed_axes[0] = axis - return jnp.transpose(x, axes=transposed_axes) - if axis != 0: - xs = fastmath.nested_map(swapaxes, xs) - def transposed_f(c, x): - y, d = f(x, c) - return d, y - if remat: - transposed_f = fastmath.remat(transposed_f) - last_value, ys = fastmath.scan(transposed_f, init_value, xs) - if axis != 0: - ys = fastmath.nested_map(swapaxes, ys) - return ys, last_value + """Scans the f over the given axis of xs. + + In pseudo-python, the scan function would look as follows: + + def scan(f, xs, init_value, axis): + xs = [xs[..., i, ...] for i in range(xs.shape[axis])] + cur_value = init_value + ys = [] + for x in xs: + y, cur_value = f(x, cur_value) + ys.append(y) + return np.stack(ys, axis), cur_value + + Args: + f: function (x, carry) -> (y, new_carry) + xs: tensor, x will be xs slices on axis + init_value: tensor, initial value of the carry-over + axis: int, the axis on which to slice xs + remat: whether to re-materialize f + + Returns: + A pair (ys, last_value) as described above. + """ + + def swapaxes(x): + transposed_axes = list(range(len(x.shape))) + transposed_axes[axis] = 0 + transposed_axes[0] = axis + return jnp.transpose(x, axes=transposed_axes) + + if axis != 0: + xs = fastmath.nested_map(swapaxes, xs) + + def transposed_f(c, x): + y, d = f(x, c) + return d, y + + if remat: + transposed_f = fastmath.remat(transposed_f) + last_value, ys = fastmath.scan(transposed_f, init_value, xs) + if axis != 0: + ys = fastmath.nested_map(swapaxes, ys) + return ys, last_value class Scan(base.Layer): - """Applies a layer progressively/cumulatively to an axis-derived sequence. - - Conceptually, this is a function from a list to a same-length list of partial - (cumulative) results. For instance, a list of values (`[1, 2, 3, 4, 5]`) can - transform to a list of cumulative sums (`[1, 3, 6, 10, 15]`). Functions for - the same concept are called `scan` in Scala, `scanl` in Haskell, and - `accumulate*` in Factor. - - In more detail, we assume the layer takes a tuple of inputs of the following - form: - - (input1, ..., inputN, carry1, ..., carryM) - - and returns: - - (output1, ..., outputK, new_carry1, ..., new_carryM) - - The scanned version applies the layer iteratively to a tensor treating values - at the given axis as if they were a list. 
For example, to calculate all - sums of prefixes of a tensor, we can do this:: - - def add(x, carry): - def f(input, carry): - res = input + carry - return res, res # output and carry are the same - return tl.Fn('add', f, n_out=2) - - Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6 - """ - - def __init__(self, layer, axis=0, n_carry=1, remat=False, mode='train'): - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - self._n_carry = n_carry - self._axis = axis - self._remat = remat - self._weights = (None,) - self._state = (None, ()) - self._mode = mode - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - @property - def state(self): - """Returns a tuple containing this layer's state.""" - return (self.sublayer.state, self._state[1]) - - @state.setter - def state(self, state): - """Recursively sets state on this layer the sublayer.""" - if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: - return - self._state = (None, state[1]) - self.sublayer.state = state[0] - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - weights = self.weights[0] - if isinstance(inputs, list): - inputs = tuple(inputs) # so that inputs structure matches outputs - n_carry = self._n_carry - def scannable_fn(x, carry_and_state): # pylint: disable=invalid-name - carry, state, i = carry_and_state - x_and_carry = x + carry if n_carry > 0 else x - rng = fastmath.random.fold_in(self.rng, i) - res, new_state = self.sublayer.pure_fn( - x_and_carry, weights, state, rng, use_cache=True) - if n_carry > 0: - return (res[:-n_carry], (res[-n_carry:], new_state, i+1)) - else: - return (res, ([], new_state, i+1)) - - if n_carry > 0: - xs = inputs[:-n_carry] # Split input stack into inputs and carry. - xs_carry = inputs[-n_carry:] - if self._mode == 'predict' and self._state[1] is not (): # pylint: disable=literal-comparison - xs_carry = self._state[1] - init = (xs_carry, self.state[0], jnp.array(0, dtype=jnp.int32)) - else: - xs_carry = () - xs, init = inputs, ([], self.state[0], jnp.array(0, dtype=jnp.int32)) - ys, (carry, new_state, _) = _scan(scannable_fn, xs, init, - axis=self._axis, remat=self._remat) - res = ys + carry if n_carry > 0 else ys - state_carry = carry if self._mode == 'predict' and n_carry > 0 else () - self.state = (new_state, state_carry) - return res # Put outputs and carry back on stack. - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - n_carry = self._n_carry - if n_carry == 0: - if isinstance(input_signature, (list, tuple)): - layer_sig = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) - for x in input_signature] - layer_sig = tuple(layer_sig) - else: - layer_sig = ShapeDtype(_shape_without_axis(input_signature, self._axis), - input_signature.dtype) - weights, state = self.sublayer.init(layer_sig) - self.state = (state, ()) - self.weights = (weights,) - else: - xs = input_signature[:-n_carry] - init = input_signature[-n_carry:] - xs_slices = [ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) - for x in xs] - layer_signature = tuple(xs_slices + list(init)) - weights, state = self.sublayer.init(layer_signature, use_cache=True) - self.state = (state, ()) - self.weights = (weights,) + """Applies a layer progressively/cumulatively to an axis-derived sequence. 
+ Conceptually, this is a function from a list to a same-length list of partial + (cumulative) results. For instance, a list of values (`[1, 2, 3, 4, 5]`) can + transform to a list of cumulative sums (`[1, 3, 6, 10, 15]`). Functions for + the same concept are called `scan` in Scala, `scanl` in Haskell, and + `accumulate*` in Factor. -class Cond(base.Layer): - """Applies layers conditionally. - - For parameters `cond`, `true`, and `false` runs the equivalent of `true(y) - if cond(x) else false(y)`, where `x` is `cond.n_in` elements from front of the - stack and `y` is the rest of the stack. - Exactly one of `true` and `false` functions is executed, so it can be used to - conditionally run long computations. The state of non-executed function is not - updated. Note that different branches may be executed on different devices - if `cond` returns different values on them. - By default 'false' function is an identity. - - `cond` must return exactly one element: a Boolean value. - `true` and `false` must have the same n_in, and the same n_out. - """ - - def __init__(self, cond, true, false=None, name=None): - super(Cond, self).__init__(name=name) - - if false is None: - self._identity_false_fun = True - # We don't need this function, but it will be useful for checking if - # 'true' has proper n_in/n_out. - false = Serial() - self._false = false - else: - self._identity_false_fun = False - self._false = false - - sublayers = [cond, true, false] - self._sublayers = sublayers - self._n_layers = len(sublayers) - self._cond = cond - self._true = true - - if cond.n_out != 1: - raise ValueError( - 'cond.n_out must be 1: cond:{}->{}'.format(cond.n_in, cond.n_out)) - if true.n_in != false.n_in: - raise ValueError( - 'true.n_in and false.n_in must be equal: true:{}->{} ; false:{}->{}' - .format(true.n_in, true.n_out, false.n_in, false.n_out)) - if true.n_out != false.n_out: - raise ValueError( - 'true.n_out and false.n_out must be equal: true:{}->{} ; false:{}->{}' - .format(true.n_in, true.n_out, false.n_in, false.n_out)) - - self._n_in = cond.n_in + true.n_in - self._n_out = true.n_out - self._weights = tuple(None for l in sublayers) - self._state = tuple(None for l in sublayers) - - # pylint: disable=protected-access - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - weights = [] - states = [] - # In the code below, stack, inputs, and outputs are abstract (shapes and - # dtypes), but weights and states are non-abstract actual values. - stack = _make_tuple(input_signature) - - # Inputs/outputs of `cond`. - inputs = inputs_from_stack(stack, self._cond.n_in) - weights_or_cache_marker, state_or_cache_marker = ( - self._cond.init(inputs, use_cache=True)) - weights.append(weights_or_cache_marker) - states.append(state_or_cache_marker) - self._cond._forward_abstract(inputs) - stack = _make_tuple(outputs_onto_stack([], stack, self._cond.n_in)) - - # Inputs/outputs of `true` and `false`. 
- for sublayer in [self._true, self._false]: - inputs = inputs_from_stack(stack, sublayer.n_in) - weights_or_cache_marker, state_or_cache_marker = ( - sublayer.init(inputs, use_cache=True)) - weights.append(weights_or_cache_marker) - states.append(state_or_cache_marker) - - self.state = states - self.weights = weights - # pylint: enable=protected-access + In more detail, we assume the layer takes a tuple of inputs of the following + form: - def _validate_forward_inputs(self, xs): - xs = _make_tuple(xs) - if len(xs) < self.n_in: - raise ValueError( - f'Number of inputs ({len(xs)}) to Cond.forward less than n_in ' - f'({self.n_in}).') + (input1, ..., inputN, carry1, ..., carryM) - def forward(self, xs): - """Executes this layer as part of a forward pass through the model. + and returns: - Args: - xs: Tensors of as required by the branches of this conditional. + (output1, ..., outputK, new_carry1, ..., new_carryM) - Returns: - Tensors resulting from running the chosen branch. + The scanned version applies the layer iteratively to a tensor treating values + at the given axis as if they were a list. For example, to calculate all + sums of prefixes of a tensor, we can do this:: + + def add(x, carry): + def f(input, carry): + res = input + carry + return res, res # output and carry are the same + return tl.Fn('add', f, n_out=2) + + Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6 """ - # TODO(jaszczur): modify; it's a copy from SkippingSerial - self._validate_forward_inputs(xs) - layers_state = self.state - # Get 3 rngs, one for each layer. - rngs = _split_rngs(self.rng, 3) - - # Prepare the stack and do some safety checks as in the parent class. - stack = _make_tuple(xs) - weights = self.weights - if len(weights) != 3: - raise ValueError('number of weights ({}) not equal to 3' - .format(len(weights))) - if len(layers_state) != 3: - raise ValueError('length of state ({}) not equal to 3' - .format(len(layers_state))) - - def true_func(t): - outputs, new_true_state = self._true.pure_fn( - t[0][0], t[1][0], t[2][0], t[3][0]) - # t[2][1] is old_false_state which is not changing if true is executed. - return outputs, (new_true_state, t[2][1]) - - def false_func(t): - if self._identity_false_fun: - # Memory optimization: we don't need pure_fn call. - return t[0][1], t[2] - outputs, new_false_state = self._false.pure_fn( - t[0][1], t[1][1], t[2][1], t[3][1]) - # t[2][1] is old_true_state, which is not changing if false is executed. - return outputs, (t[2][0], new_false_state) - - cond_inputs = inputs_from_stack(xs, self._cond.n_in) - cond_output, s = self._cond.pure_fn(cond_inputs, self.weights[0], - self.state[0], rngs[0], use_cache=True) - stack = outputs_onto_stack([], stack, self._cond.n_in) - self._cond.state = s - - outputs, both_states = fastmath.cond( - cond_output, - true_func, - false_func, - [(stack, stack), - (self.weights[1], self.weights[2]), - (self.state[1], self.state[2]), - (rngs[1], rngs[2])] - ) - stack = outputs_onto_stack([], stack, self._cond.n_in) - # We don't know which (`true` or `false`) branch was run, but both of them - # are adding (n_out) and removing (n_in) the same number of elements of the - # stack (this was checked in __init__). outputs_onto_stack just uses the - # layer's n_in, so we can pass either `true` or `false` to it. - # Note that `outputs` is the actual output of `true` or `false` branch, - # whichever was run, and we add it to the stack in any case. 
-    stack = outputs_onto_stack(outputs, stack, self._true.n_in)
-    self._true.state = both_states[0]
-    self._false.state = both_states[1]
-    return _make_singleitem_or_original(stack)
+    def __init__(self, layer, axis=0, n_carry=1, remat=False, mode="train"):
+        super().__init__(n_in=layer.n_in, n_out=layer.n_out)
+        self._sublayers = [layer]
+        self._n_carry = n_carry
+        self._axis = axis
+        self._remat = remat
+        self._weights = (None,)
+        self._state = (None, ())
+        self._mode = mode
+
+    @property
+    def sublayer(self):
+        """Returns the unique sublayer managed by this layer."""
+        return self._sublayers[0]
+
+    @property
+    def state(self):
+        """Returns a tuple containing this layer's state."""
+        return (self.sublayer.state, self._state[1])
+
+    @state.setter
+    def state(self, state):
+        """Recursively sets state on this layer and its sublayer."""
+        if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE:
+            return
+        self._state = (None, state[1])
+        self.sublayer.state = state[0]
+
+    def forward(self, inputs):
+        """Executes this layer as part of a forward pass through the model."""
+        weights = self.weights[0]
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)  # so that inputs structure matches outputs
+        n_carry = self._n_carry
+
+        def scannable_fn(x, carry_and_state):  # pylint: disable=invalid-name
+            carry, state, i = carry_and_state
+            x_and_carry = x + carry if n_carry > 0 else x
+            rng = fastmath.random.fold_in(self.rng, i)
+            res, new_state = self.sublayer.pure_fn(
+                x_and_carry, weights, state, rng, use_cache=True
+            )
+            if n_carry > 0:
+                return (res[:-n_carry], (res[-n_carry:], new_state, i + 1))
+            else:
+                return (res, ([], new_state, i + 1))
+
+        if n_carry > 0:
+            xs = inputs[:-n_carry]  # Split input stack into inputs and carry.
+            xs_carry = inputs[-n_carry:]
+            if (
+                self._mode == "predict" and self._state[1] is not ()
+            ):  # pylint: disable=literal-comparison
+                xs_carry = self._state[1]
+            init = (xs_carry, self.state[0], jnp.array(0, dtype=jnp.int32))
+        else:
+            xs_carry = ()
+            xs, init = inputs, ([], self.state[0], jnp.array(0, dtype=jnp.int32))
+        ys, (carry, new_state, _) = _scan(
+            scannable_fn, xs, init, axis=self._axis, remat=self._remat
+        )
+        res = ys + carry if n_carry > 0 else ys
+        state_carry = carry if self._mode == "predict" and n_carry > 0 else ()
+        self.state = (new_state, state_carry)
+        return res  # Put outputs and carry back on stack.
+
+    def init_weights_and_state(self, input_signature):
+        """Initializes weights and state for inputs with the given signature."""
+        n_carry = self._n_carry
+        if n_carry == 0:
+            if isinstance(input_signature, (list, tuple)):
+                layer_sig = [
+                    ShapeDtype(_shape_without_axis(x, self._axis), x.dtype)
+                    for x in input_signature
+                ]
+                layer_sig = tuple(layer_sig)
+            else:
+                layer_sig = ShapeDtype(
+                    _shape_without_axis(input_signature, self._axis),
+                    input_signature.dtype,
+                )
+            weights, state = self.sublayer.init(layer_sig)
+            self.state = (state, ())
+            self.weights = (weights,)
+        else:
+            xs = input_signature[:-n_carry]
+            init = input_signature[-n_carry:]
+            xs_slices = [
+                ShapeDtype(_shape_without_axis(x, self._axis), x.dtype) for x in xs
+            ]
+            layer_signature = tuple(xs_slices + list(init))
+            weights, state = self.sublayer.init(layer_signature, use_cache=True)
+            self.state = (state, ())
+            self.weights = (weights,)
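
As a quick illustration of the carry contract described in the `Scan` docstring above — a minimal sketch whose data mirrors the `ScanTest` cases removed from `combinators_test.py` later in this diff::

    import numpy as np
    import trax.layers as tl

    def AddWithCarry():  # helper named as in the deleted tests
        def f(x, carry):
            res = x + carry
            return res, res  # output and new carry are the same
        return tl.Fn("AddWithCarry", f, n_out=2)

    layer = tl.Scan(AddWithCarry(), axis=1)
    xs = [np.array([[0, 1, 2, 3]]),  # values scanned along axis 1
          np.array([9000])]          # initial carry, one per batch row
    ys, carry = layer(xs)
    # ys == [[9000, 9001, 9003, 9006]] and carry == [9006]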
+
+
+class Cond(base.Layer):
+    """Applies layers conditionally.
+
+    Given `cond`, `true`, and `false`, runs the equivalent of `true(y)
+    if cond(x) else false(y)`, where `x` is `cond.n_in` elements from the front
+    of the stack and `y` is the rest of the stack.
+    Exactly one of the `true` and `false` functions is executed, so this layer
+    can be used to conditionally run long computations. The state of the
+    non-executed function is not updated. Note that different branches may be
+    executed on different devices if `cond` returns different values on them.
+    By default, the `false` function is an identity.
+
+    `cond` must return exactly one element: a Boolean value.
+    `true` and `false` must have the same n_in, and the same n_out.
+    """
+
+    def __init__(self, cond, true, false=None, name=None):
+        super(Cond, self).__init__(name=name)
+
+        if false is None:
+            self._identity_false_fun = True
+            # We don't need this function, but it will be useful for checking if
+            # 'true' has proper n_in/n_out.
+            false = Serial()
+            self._false = false
+        else:
+            self._identity_false_fun = False
+            self._false = false
+
+        sublayers = [cond, true, false]
+        self._sublayers = sublayers
+        self._n_layers = len(sublayers)
+        self._cond = cond
+        self._true = true
+
+        if cond.n_out != 1:
+            raise ValueError(
+                "cond.n_out must be 1: cond:{}->{}".format(cond.n_in, cond.n_out)
+            )
+        if true.n_in != false.n_in:
+            raise ValueError(
+                "true.n_in and false.n_in must be equal: true:{}->{} ; false:{}->{}".format(
+                    true.n_in, true.n_out, false.n_in, false.n_out
+                )
+            )
+        if true.n_out != false.n_out:
+            raise ValueError(
+                "true.n_out and false.n_out must be equal: true:{}->{} ; false:{}->{}".format(
+                    true.n_in, true.n_out, false.n_in, false.n_out
+                )
+            )
+
+        self._n_in = cond.n_in + true.n_in
+        self._n_out = true.n_out
+        self._weights = tuple(None for l in sublayers)
+        self._state = tuple(None for l in sublayers)
+
+    # pylint: disable=protected-access
+    def init_weights_and_state(self, input_signature):
+        """Initializes weights and state for inputs with the given signature."""
+        weights = []
+        states = []
+        # In the code below, stack, inputs, and outputs are abstract (shapes and
+        # dtypes), but weights and states are non-abstract actual values.
+        stack = _make_tuple(input_signature)
+
+        # Inputs/outputs of `cond`.
+        inputs = inputs_from_stack(stack, self._cond.n_in)
+        weights_or_cache_marker, state_or_cache_marker = self._cond.init(
+            inputs, use_cache=True
+        )
+        weights.append(weights_or_cache_marker)
+        states.append(state_or_cache_marker)
+        self._cond._forward_abstract(inputs)
+        stack = _make_tuple(outputs_onto_stack([], stack, self._cond.n_in))
+
+        # Inputs/outputs of `true` and `false`.
+        for sublayer in [self._true, self._false]:
+            inputs = inputs_from_stack(stack, sublayer.n_in)
+            weights_or_cache_marker, state_or_cache_marker = sublayer.init(
+                inputs, use_cache=True
+            )
+            weights.append(weights_or_cache_marker)
+            states.append(state_or_cache_marker)
+
+        self.state = states
+        self.weights = weights
+    # pylint: enable=protected-access
+
+    def _validate_forward_inputs(self, xs):
+        xs = _make_tuple(xs)
+        if len(xs) < self.n_in:
+            raise ValueError(
+                f"Number of inputs ({len(xs)}) to Cond.forward less than n_in "
+                f"({self.n_in})."
+            )
+
+    def forward(self, xs):
+        """Executes this layer as part of a forward pass through the model.
+
+        Args:
+            xs: Tensors as required by the branches of this conditional.
+
+        Returns:
+            Tensors resulting from running the chosen branch.
+ """ + + # TODO(jaszczur): modify; it's a copy from SkippingSerial + self._validate_forward_inputs(xs) + xs = _make_tuple(xs) + + # The input should be a flat single tuple without nested tuples + xs = self.flatten_tuple(xs) + + layers_state = self.state + # Get 3 rngs, one for each layer. + rngs = _split_rngs(self.rng, 3) + + # Prepare the stack and do some safety checks as in the parent class. + stack = _make_tuple(xs) + stack = self.flatten_tuple(stack) + weights = self.weights + if len(weights) != 3: + raise ValueError( + "number of weights ({}) not equal to 3".format(len(weights)) + ) + if len(layers_state) != 3: + raise ValueError( + "length of state ({}) not equal to 3".format(len(layers_state)) + ) + + def true_func(t): + outputs, new_true_state = self._true.pure_fn( + t[0][0], t[1][0], t[2][0], t[3][0] + ) + # t[2][1] is old_false_state which is not changing if true is executed. + outputs = _make_tuple(outputs) + return outputs, (new_true_state, t[2][1]) + + def false_func(t): + if self._identity_false_fun: + # Memory optimization: we don't need pure_fn call. + return t[0][1], t[2] + outputs, new_false_state = self._false.pure_fn( + t[0][1], t[1][1], t[2][1], t[3][1] + ) + # t[2][1] is old_true_state, which is not changing if false is executed. + outputs = _make_tuple(outputs) + return outputs, (t[2][0], new_false_state) + + cond_inputs = inputs_from_stack(xs, self._cond.n_in) + cond_inputs = _make_tuple(cond_inputs) + cond_output, s = self._cond.pure_fn( + cond_inputs, self.weights[0], self.state[0], rngs[0], use_cache=True + ) + stack = outputs_onto_stack([], stack, self._cond.n_in) + stack = _make_tuple(stack) + self._cond.state = s + + outputs, both_states = fastmath.cond( + cond_output, + true_func, + false_func, + [ + (stack, stack), + (self.weights[1], self.weights[2]), + (self.state[1], self.state[2]), + (rngs[1], rngs[2]), + ], + ) + stack = outputs_onto_stack([], stack, self._cond.n_in) + + # We don't know which (`true` or `false`) branch was run, but both of them + # are adding (n_out) and removing (n_in) the same number of elements of the + # stack (this was checked in __init__). outputs_onto_stack just uses the + # layer's n_in, so we can pass either `true` or `false` to it. + # Note that `outputs` is the actual output of `true` or `false` branch, + # whichever was run, and we add it to the stack in any case. 
+        stack = outputs_onto_stack(outputs, stack, self._true.n_in)
+        self._true.state = both_states[0]
+        self._false.state = both_states[1]
+        return _make_singleitem_or_original(stack)
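
A sketch of the `Cond` contract, assuming simple `Fn` branches in the spirit of the `SmallerThan`/`DivideBy` helpers from the test file deleted below; exact eager behavior depends on the active `fastmath` backend::

    import numpy as np
    import trax.layers as tl

    cond = tl.Fn("SmallerThan5", lambda x: x < 5)    # consumes 1 input, returns a Boolean
    true = tl.Fn("DivideBy2", lambda x: x / 2.0)     # n_in=1, n_out=1
    false = tl.Fn("MultiplyBy3", lambda x: x * 3.0)  # must match true's n_in and n_out

    layer = tl.Cond(cond, true, false)
    # n_in == cond.n_in + true.n_in == 2: the first input feeds cond, the rest the branch.
    y = layer((np.array(3.0), np.array(10.0)))  # 3 < 5, so `true` runs: y == 5.0
    y = layer((np.array(7.0), np.array(10.0)))  # 7 >= 5, so `false` runs: y == 30.0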
- """ - if len(layers) == 1: - return layers[0] - parallel_layer = Parallel(*layers) - indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers] - return Serial(Select(_deep_flatten(indices)), parallel_layer, - name=name, sublayers_to_print=layers) + """Executes `layer` using batch chunks of size `chunk_size` to save memory.""" + if chunk_size < 1: + return layer + + def reshape_to_chunks(x): + chunk_batch = x.shape[0] + size = chunk_size + n_chunks = chunk_batch // size + if chunk_batch % size != 0: + if pass_unchunkable: + n_chunks = 1 + size = chunk_batch + else: + raise ValueError( + f"Chunk size {size} must divide batch " f"size {chunk_batch}" + ) + return jnp.reshape(x, [n_chunks, size] + list(x.shape[1:])) + + reshape_to_chunks_layer = base.PureLayer( + lambda xs: fastmath.nested_map(reshape_to_chunks, xs), + n_in=layer.n_in, + n_out=layer.n_in, + name="ReshapeToChunks", + ) + + def reshape_from_chunks(x): + batch_size = x.shape[0] * x.shape[1] + return jnp.reshape(x, [batch_size] + list(x.shape[2:])) + + reshape_from_chunks_layer = base.PureLayer( + lambda xs: fastmath.nested_map(reshape_from_chunks, xs), + n_in=layer.n_out, + n_out=layer.n_out, + name="ReshapeFromChunks", + ) + return Serial( + reshape_to_chunks_layer, + Scan(layer, axis=0, n_carry=0, remat=True), + reshape_from_chunks_layer, + ) + + +def Branch(*layers, name="Branch"): + """Combinator that applies a list of layers in parallel to copies of inputs. + + Each layer in the input list is applied to as many inputs from the stack + as it needs, and their outputs are successively combined on stack. + + For example, suppose one has three layers: + + - F: 1 input, 1 output + - G: 3 inputs, 1 output + - H: 2 inputs, 2 outputs (h1, h2) + + Then Branch(F, G, H) will take 3 inputs and give 4 outputs: + + - inputs: a, b, c + - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b) + + As an important special case, a None argument to Branch acts as if it takes + one argument, which it leaves unchanged. (It acts as a one-arg no-op.) + + Args: + *layers: List of layers. + name: Descriptive name for this layer. + + Returns: + A branch layer built from the given sublayers. + """ + if len(layers) == 1: + return layers[0] + parallel_layer = Parallel(*layers) + indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers] + return Serial( + Select(_deep_flatten(indices)), + parallel_layer, + name=name, + sublayers_to_print=layers, + ) def Residual(*layers, shortcut=None): - """Wraps a series of layers with a residual connection. - - Args: - *layers: One or more layers, to be applied in series. - shortcut: If None (the usual case), the Residual layer computes the - element-wise sum of the stack-top input with the output of the layer - series. If specified, the `shortcut` layer applies to a copy of the - inputs and (elementwise) adds its output to the output from the main - layer series. - - Returns: - A layer representing a residual connection paired with a layer series. - """ - layers = _ensure_flat(layers) - layer = layers[0] if len(layers) == 1 else Serial(layers) - # TODO(jonni): Should we require layer.n_out = 1 and shortcut.n_out = 1? - return Serial( - Branch(shortcut, layer), - Add(), # pylint: disable=no-value-for-parameter - ) + """Wraps a series of layers with a residual connection. + + Args: + *layers: One or more layers, to be applied in series. + shortcut: If None (the usual case), the Residual layer computes the + element-wise sum of the stack-top input with the output of the layer + series. 
+
+
+def Branch(*layers, name="Branch"):
+    """Combinator that applies a list of layers in parallel to copies of inputs.
+
+    Each layer in the input list is applied to as many inputs from the stack
+    as it needs, and their outputs are successively combined on stack.
+
+    For example, suppose one has three layers:
+
+      - F: 1 input, 1 output
+      - G: 3 inputs, 1 output
+      - H: 2 inputs, 2 outputs (h1, h2)
+
+    Then Branch(F, G, H) will take 3 inputs and give 4 outputs:
+
+      - inputs: a, b, c
+      - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)
+
+    As an important special case, a None argument to Branch acts as if it takes
+    one argument, which it leaves unchanged. (It acts as a one-arg no-op.)
+
+    Args:
+      *layers: List of layers.
+      name: Descriptive name for this layer.
+
+    Returns:
+      A branch layer built from the given sublayers.
+    """
+    if len(layers) == 1:
+        return layers[0]
+    parallel_layer = Parallel(*layers)
+    indices = [list(range(layer.n_in)) for layer in parallel_layer.sublayers]
+    return Serial(
+        Select(_deep_flatten(indices)),
+        parallel_layer,
+        name=name,
+        sublayers_to_print=layers,
+    )


 def Residual(*layers, shortcut=None):
-  """Wraps a series of layers with a residual connection.
-
-  Args:
-    *layers: One or more layers, to be applied in series.
-    shortcut: If None (the usual case), the Residual layer computes the
-      element-wise sum of the stack-top input with the output of the layer
-      series. If specified, the `shortcut` layer applies to a copy of the
-      inputs and (elementwise) adds its output to the output from the main
-      layer series.
-
-  Returns:
-    A layer representing a residual connection paired with a layer series.
-  """
-  layers = _ensure_flat(layers)
-  layer = layers[0] if len(layers) == 1 else Serial(layers)
-  # TODO(jonni): Should we require layer.n_out = 1 and shortcut.n_out = 1?
-  return Serial(
-      Branch(shortcut, layer),
-      Add(),  # pylint: disable=no-value-for-parameter
-  )
+    """Wraps a series of layers with a residual connection.
+
+    Args:
+      *layers: One or more layers, to be applied in series.
+      shortcut: If None (the usual case), the Residual layer computes the
+        element-wise sum of the stack-top input with the output of the layer
+        series. If specified, the `shortcut` layer applies to a copy of the
+        inputs and (elementwise) adds its output to the output from the main
+        layer series.
+
+    Returns:
+      A layer representing a residual connection paired with a layer series.
+    """
+    layers = _ensure_flat(layers)
+    layer = layers[0] if len(layers) == 1 else Serial(layers)
+    # TODO(jonni): Should we require layer.n_out = 1 and shortcut.n_out = 1?
+    return Serial(
+        Branch(shortcut, layer),
+        Add(),  # pylint: disable=no-value-for-parameter
+    )


 def Select(indices, n_in=None, name=None):
-  """Copies, reorders, or deletes stack elements according to `indices`.
-
-  Args:
-    indices: A list or tuple of 0-based indices to select elements relative to
-      the top of the stack.
-    n_in: Number of input elements to pop from the stack, and replace with
-      those specified by `indices`. If not specified, its value will be
-      calculated as `max(indices) + 1`.
-    name: Descriptive name for this layer.
-
-  Returns:
-    Tensors, matching the number selected (`n_out = len(indices)`).
-    Specifically:
-
-      - n_out = 0: an empty tuple
-      - n_out = 1: one tensor (NOT wrapped in a tuple)
-      - n_out > 1: a tuple of tensors, with n_out items
-  """
-  if n_in is None:
-    n_in = max(indices) + 1
-  if name is None:
-    name = f'Select{indices}'.replace(' ', '')
-
-  def select(xs):  # pylint: disable=invalid-name
-    if not isinstance(xs, (tuple, list)):
-      xs = (xs,)
-    selected = tuple(xs[i] for i in indices)
-    return selected[0] if len(selected) == 1 else selected
-
-  return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name)
+    """Copies, reorders, or deletes stack elements according to `indices`.
+
+    Args:
+      indices: A list or tuple of 0-based indices to select elements relative to
+        the top of the stack.
+      n_in: Number of input elements to pop from the stack, and replace with
+        those specified by `indices`. If not specified, its value will be
+        calculated as `max(indices) + 1`.
+      name: Descriptive name for this layer.
+
+    Returns:
+      Tensors, matching the number selected (`n_out = len(indices)`).
+      Specifically:
+
+        - n_out = 0: an empty tuple
+        - n_out = 1: one tensor (NOT wrapped in a tuple)
+        - n_out > 1: a tuple of tensors, with n_out items
+    """
+    if n_in is None:
+        n_in = max(indices) + 1
+    if name is None:
+        name = f"Select{indices}".replace(" ", "")
+
+    def select(xs):  # pylint: disable=invalid-name
+        if not isinstance(xs, (tuple, list)):
+            xs = (xs,)
+        selected = tuple(xs[i] for i in indices)
+        return selected[0] if len(selected) == 1 else selected
+
+    return base.PureLayer(select, n_in=n_in, n_out=len(indices), name=name)


 def Drop():
-  """Drops the top stack element."""
-  return Fn('Drop', lambda x: (), n_out=0)
+    """Drops the top stack element."""
+    return Fn("Drop", lambda x: (), n_out=0)


 def Dup():
-  """Duplicates (copies) the top element on the data stack."""
-  return Fn('Dup', lambda x: (x, x), n_out=2)
+    """Duplicates (copies) the top element on the data stack."""
+    return Fn("Dup", lambda x: (x, x), n_out=2)


 def Swap():
-  """Swaps the top two stack elements."""
-  return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2)
+    """Swaps the top two stack elements."""
+    return Fn("Swap", lambda x0, x1: (x1, x0), n_out=2)
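
A few condensed one-liners showing how these stack combinators compose; the `Add` values come from the `BranchTest` cases in the deleted test file, and the `DivideBy2`/`Negate` helpers are illustrative::

    import numpy as np
    import trax.layers as tl

    xs = (np.array([1, 2, 3]), np.array([10, 20, 30]))

    branch = tl.Branch(tl.Add(), tl.Fn("DivideBy2", lambda x: x / 2.0))
    ys = branch(xs)            # ([11, 22, 33], [0.5, 1.0, 1.5])

    swap = tl.Select([1, 0])   # same effect as tl.Swap()
    ys = swap(xs)              # ([10, 20, 30], [1, 2, 3])

    residual = tl.Residual(tl.Fn("Negate", lambda x: -x))
    y = residual(np.array([1, 2, 3]))  # x + (-x) == [0, 0, 0]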


 def SerialWithSideOutputs(layers, n_side_outputs=1):
-  """Serial layer with side outputs.
-
-  This layer makes it easier to manage the stack when layers have side outputs.
-
-  In the simplest case of layers with n_in=1, n_out=2 and with
-  n_side_outputs=1, this layer runs the following computation on x::
-
-      side_outputs = []
-      for i in range(len(layers)):
-        x, side_output = layers[i](x)
-        side_outputs.append(side_output)
-      return [x] + side_outputs
-
-  In the general case of layers with variable n_in and n_out and
-  n_side_outputs being a list of N integers, it does the following::
-
-      side_outputs = []
-      for i in range(N):
-        res = layer[i](cur_stack)  # remove n_in from stack
-        cur_stack.append(res[:n_side_outputs[i]])  # put back some on stack
-        side_outputs.extend(res[n_side_outputs:])
-      return cur_stack + side_outputs
-
-  Args:
-    layers: a list of layers to execute
-    n_side_outputs: an int or a list of ints, how many outputs of each layer
-      to put aside
-
-  Returns:
-    A layer that performs the above computation.
-  """
-  if isinstance(n_side_outputs, int):
-    n_side_outputs = [n_side_outputs] * len(layers)
-
-  # Calculate the n_in for this layer.
-  running_max = 0
-  running_total = 0
-  for layer, n_side_output in zip(layers, n_side_outputs):
-    running_total += layer.n_in
-    running_max = max(running_max, running_total)
-    running_total -= layer.n_out - n_side_output
-  n_in = running_max
-
-  # Create the list of layers to run serially.
-  cur_stack_size = n_in
-  serial_layers = []
-  for layer, n_side_output in zip(layers, n_side_outputs):
-    serial_layers.append(layer)
-    cur_stack_size += layer.n_out - layer.n_in
-    # Indices to move n_side_outputs to the back of the stack.
-    # Don't touch first n_out - n_side_outputs.
-    move_back_indices = list(range(layer.n_out - n_side_output))
-    # Then comes the rest of the stack that we're not moving.
-    move_back_indices += [i + layer.n_out
-                          for i in range(cur_stack_size - layer.n_out)]
-    # Finally the indices we move.
-    move_back_indices += [i + layer.n_out - n_side_output
-                          for i in range(n_side_output)]
-    # Swap them on stack.
-    serial_layers.append(Select(move_back_indices))
-
-  return Serial(serial_layers)
+    """Serial layer with side outputs.
+
+    This layer makes it easier to manage the stack when layers have side outputs.
+
+    In the simplest case of layers with n_in=1, n_out=2 and with
+    n_side_outputs=1, this layer runs the following computation on x::
+
+        side_outputs = []
+        for i in range(len(layers)):
+            x, side_output = layers[i](x)
+            side_outputs.append(side_output)
+        return [x] + side_outputs
+
+    In the general case of layers with variable n_in and n_out and
+    n_side_outputs being a list of N integers, it does the following::
+
+        side_outputs = []
+        for i in range(N):
+            res = layers[i](cur_stack)  # remove n_in from stack
+            cur_stack.append(res[: n_side_outputs[i]])  # put back some on stack
+            side_outputs.extend(res[n_side_outputs[i] :])
+        return cur_stack + side_outputs
+
+    Args:
+      layers: a list of layers to execute
+      n_side_outputs: an int or a list of ints, how many outputs of each layer
+        to put aside
+
+    Returns:
+      A layer that performs the above computation.
+    """
+    if isinstance(n_side_outputs, int):
+        n_side_outputs = [n_side_outputs] * len(layers)
+
+    # Calculate the n_in for this layer.
+    running_max = 0
+    running_total = 0
+    for layer, n_side_output in zip(layers, n_side_outputs):
+        running_total += layer.n_in
+        running_max = max(running_max, running_total)
+        running_total -= layer.n_out - n_side_output
+    n_in = running_max
+
+    # Create the list of layers to run serially.
+ cur_stack_size = n_in + serial_layers = [] + for layer, n_side_output in zip(layers, n_side_outputs): + serial_layers.append(layer) + cur_stack_size += layer.n_out - layer.n_in + # Indices to move n_side_outputs to the back of the stack. + # Don't touch first n_out - n_side_outputs. + move_back_indices = list(range(layer.n_out - n_side_output)) + # Then comes the rest of the stack that we're not moving. + move_back_indices += [ + i + layer.n_out for i in range(cur_stack_size - layer.n_out) + ] + # Finally the indices we move. + move_back_indices += [ + i + layer.n_out - n_side_output for i in range(n_side_output) + ] + # Swap them on stack. + serial_layers.append(Select(move_back_indices)) + + return Serial(serial_layers) def FlattenList(): - """Flatten lists.""" - # TODO(jonni): Consider renaming layer to DeepFlatten. - return Fn('FlattenList', lambda x: tuple(_deep_flatten(x))) + """Flatten lists.""" + # TODO(jonni): Consider renaming layer to DeepFlatten. + return Fn("FlattenList", lambda x: tuple(_deep_flatten(x))) def Add(): - """Adds two tensors.""" - return Fn('Add', lambda x0, x1: x0 + x1) + """Adds two tensors.""" + return Fn("Add", lambda x0, x1: jnp.add(x0, x1)) def SubtractTop(): - """Subtracts the first tensor from the second.""" - return Fn('SubtractTop', lambda x0, x1: x1 - x0) + """Subtracts the first tensor from the second.""" + return Fn("SubtractTop", lambda x0, x1: jnp.subtract(x1, x0)) def Multiply(): - """Multiplies two tensors.""" - return Fn('Multiply', lambda x0, x1: x0 * x1) + """Multiplies two tensors.""" + return Fn("Multiply", lambda x0, x1: jnp.multiply(x0, x1)) def Gate(): - """Returns a gating layer on a (memory, gate, candidate) tuple. + """Returns a gating layer on a (memory, gate, candidate) tuple. - Final update is memory * gate + (1 - gate) * candidate + Final update is memory * gate + (1 - gate) * candidate - This gating equation may also be referred to as Highway Network. - Highway Networks: https://arxiv.org/abs/1505.00387 - """ - return Fn('Gate', lambda m, g, c: g * m + (1.0 - g) * c) + This gating equation may also be referred to as Highway Network. + Highway Networks: https://arxiv.org/abs/1505.00387 + """ + return Fn("Gate", lambda m, g, c: g * m + (1.0 - g) * c) class Cache(base.Layer): - """Applies a layer on the first run and returns the outputs on next calls.""" - - def __init__(self, layer): - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - @property - def state(self): - """Returns a tuple containing this layer's state; may be empty.""" - return self._state - - @state.setter - def state(self, state): - """Recursively sets state on this layer and all sublayers.""" - if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: - return - self._state = state - self.sublayer.state = state[1] - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - weights, layer_state = self.sublayer.init(input_signature, use_cache=True) - self.state = ((), layer_state) - self._weights = (weights,) - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model. 
+ """Applies a layer on the first run and returns the outputs on next calls.""" + + def __init__(self, layer): + super().__init__(n_in=layer.n_in, n_out=layer.n_out) + self._sublayers = [layer] + + @property + def sublayer(self): + """Returns the unique sublayer managed by this layer.""" + return self._sublayers[0] + + @property + def state(self): + """Returns a tuple containing this layer's state; may be empty.""" + return self._state + + @state.setter + def state(self, state): + """Recursively sets state on this layer and all sublayers.""" + if isinstance(state, dict) and state == base.GET_STATE_FROM_CACHE: + return + self._state = state + self.sublayer.state = state[1] + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + weights, layer_state = self.sublayer.init(input_signature, use_cache=True) + self.state = ((), layer_state) + self._weights = (weights,) + + def forward(self, inputs): + """Executes this layer as part of a forward pass through the model. + + Args: + inputs: Tensors required by the sublayer. + + Returns: + Tensors resulting from running the sublayer the first time. + """ + state, weights = self.state, self.weights[0] + if state[0] is (): # pylint: disable=literal-comparison + res, layer_state = self.sublayer.pure_fn( + inputs, weights, state[1], self.rng + ) + self.state = (res, layer_state) + return res + else: + return state[0] - Args: - inputs: Tensors required by the sublayer. - Returns: - Tensors resulting from running the sublayer the first time. - """ - state, weights = self.state, self.weights[0] - if state[0] is (): # pylint: disable=literal-comparison - res, layer_state = self.sublayer.pure_fn( - inputs, weights, state[1], self.rng) - self.state = (res, layer_state) - return res - else: - return state[0] +class BatchLeadingAxes(base.Layer): + """Applies a layer after flattening all but n_last_axes_to_keep to batch. + This can be used to make layers accept an arbitrary number of leading + axes (dimensions) as batch. For example, a Convolution layer may normally + only operate on tensors of shape [B, W, H, C]. In this case, the layer -class BatchLeadingAxes(base.Layer): - """Applies a layer after flattening all but n_last_axes_to_keep to batch. - - This can be used to make layers accept an arbitrary number of leading - axes (dimensions) as batch. For example, a Convolution layer may normally - only operate on tensors of shape [B, W, H, C]. In this case, the layer - - BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3) - - will operate on any tensor [..., W, H, C] and treat the leading axes as batch. 
- """ - - def __init__(self, layer, n_last_axes_to_keep=1): - if layer.n_out != 1: - raise ValueError('BatchLeadingAxes currently only works for layers with ' - f'n_out = 1, got {layer.n_out}.') - super().__init__(n_in=layer.n_in, n_out=layer.n_out) - self._sublayers = [layer] - self._n_last_axes_to_keep = n_last_axes_to_keep - self._weights = (None,) - self._state = (None,) - - @property - def sublayer(self): - """Returns the unique sublayer managed by this layer.""" - return self._sublayers[0] - - def forward(self, inputs): - """Executes this layer as part of a forward pass through the model.""" - if self._n_in == 1: - inputs = [inputs] - new_inputs = [] - for old_input in inputs: - batched_axes_shape = list(old_input.shape[:-self._n_last_axes_to_keep]) - batched_shape = [-1] + list(old_input.shape[-self._n_last_axes_to_keep:]) - new_inputs.append(jnp.reshape(old_input, batched_shape)) - new_inputs = tuple(new_inputs) - if self._n_in == 1: - new_inputs = new_inputs[0] - res, layer_state = self.sublayer.pure_fn( - new_inputs, self.weights[0], self.state[0], self.rng) - self.state = (layer_state,) - return jnp.reshape(res, batched_axes_shape + list(res.shape[1:])) - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - if self._n_in == 1 and not isinstance(input_signature, (list, tuple)): - input_signature = (input_signature,) - batched_signature = [] - for sub_input_signature in input_signature: - batched_size = 1 - for d in sub_input_signature.shape[:-self._n_last_axes_to_keep]: - batched_size *= d - batched_shape = [batched_size] + list( - sub_input_signature.shape[-self._n_last_axes_to_keep:]) - batched_signature.append(ShapeDtype(batched_shape, - sub_input_signature.dtype)) - if self._n_in == 1: - batched_signature = batched_signature[0] - weights, layer_state = self.sublayer.init(batched_signature, use_cache=True) - self.state = (layer_state,) - self.weights = (weights,) + BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3) + + will operate on any tensor [..., W, H, C] and treat the leading axes as batch. + """ + + def __init__(self, layer, n_last_axes_to_keep=1): + if layer.n_out != 1: + raise ValueError( + "BatchLeadingAxes currently only works for layers with " + f"n_out = 1, got {layer.n_out}." 
+            )
+        super().__init__(n_in=layer.n_in, n_out=layer.n_out)
+        self._sublayers = [layer]
+        self._n_last_axes_to_keep = n_last_axes_to_keep
+        self._weights = (None,)
+        self._state = (None,)
+
+    @property
+    def sublayer(self):
+        """Returns the unique sublayer managed by this layer."""
+        return self._sublayers[0]
+
+    def forward(self, inputs):
+        """Executes this layer as part of a forward pass through the model."""
+        if self._n_in == 1:
+            inputs = [inputs]
+        new_inputs = []
+        for old_input in inputs:
+            batched_axes_shape = list(old_input.shape[: -self._n_last_axes_to_keep])
+            batched_shape = [-1] + list(old_input.shape[-self._n_last_axes_to_keep :])
+            new_inputs.append(jnp.reshape(old_input, batched_shape))
+        new_inputs = tuple(new_inputs)
+        if self._n_in == 1:
+            new_inputs = new_inputs[0]
+        res, layer_state = self.sublayer.pure_fn(
+            new_inputs, self.weights[0], self.state[0], self.rng
+        )
+        self.state = (layer_state,)
+        return jnp.reshape(res, batched_axes_shape + list(res.shape[1:]))
+
+    def init_weights_and_state(self, input_signature):
+        """Initializes weights and state for inputs with the given signature."""
+        if self._n_in == 1 and not isinstance(input_signature, (list, tuple)):
+            input_signature = (input_signature,)
+        batched_signature = []
+        for sub_input_signature in input_signature:
+            batched_size = 1
+            for d in sub_input_signature.shape[: -self._n_last_axes_to_keep]:
+                batched_size *= d
+            batched_shape = [batched_size] + list(
+                sub_input_signature.shape[-self._n_last_axes_to_keep :]
+            )
+            batched_signature.append(
+                ShapeDtype(batched_shape, sub_input_signature.dtype)
+            )
+        if self._n_in == 1:
+            batched_signature = batched_signature[0]
+        weights, layer_state = self.sublayer.init(batched_signature, use_cache=True)
+        self.state = (layer_state,)
+        self.weights = (weights,)


 def Bidirectional(forward_layer, axis=1, merge_layer=Concatenate()):
-  """Bidirectional combinator for RNNs.
-
-  Args:
-    forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`.
-    axis: a time axis of the inputs. Default value is `1`.
-    merge_layer: A combinator used to combine outputs of the forward
-      and backward RNNs. Default value is 'trax.layers.Concatenate'.
-
-  Example:
-    Bidirectional(RNN(n_units=8))
-
-  Returns:
-    The Bidirectional combinator for RNNs.
-  """
-  backward_layer = copy.deepcopy(forward_layer)
-  flip = base.Fn('_FlipAlongTimeAxis', lambda x: jnp.flip(x, axis=axis))
-  backward = Serial(
-      flip,
-      backward_layer,
-      flip,
-  )
-
-  return Serial(
-      Branch(forward_layer, backward),
-      merge_layer,
-  )
+    """Bidirectional combinator for RNNs.
+
+    Args:
+      forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`.
+      axis: a time axis of the inputs. Default value is `1`.
+      merge_layer: A combinator used to combine outputs of the forward
+        and backward RNNs. Default value is 'trax.layers.Concatenate'.
+
+    Example:
+      Bidirectional(RNN(n_units=8))
+
+    Returns:
+      The Bidirectional combinator for RNNs.
+    """
+    backward_layer = copy.deepcopy(forward_layer)
+    flip = base.Fn("_FlipAlongTimeAxis", lambda x: jnp.flip(x, axis=axis))
+    backward = Serial(
+        flip,
+        backward_layer,
+        flip,
+    )
+
+    return Serial(
+        Branch(forward_layer, backward),
+        merge_layer,
+    )
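
A usage sketch for `Bidirectional`, assuming `tl.LSTM` from this library; the output width doubles because the default `Concatenate` merges the two directions on the last axis::

    import numpy as np
    from trax import shapes
    import trax.layers as tl

    birnn = tl.Bidirectional(tl.LSTM(n_units=8))
    x = np.zeros((2, 5, 4), dtype=np.float32)  # [batch, time, features]
    birnn.init(shapes.signature(x))
    y = birnn(x)  # y.shape == (2, 5, 16): forward and time-reversed outputs, concatenated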


 # All module-private helper functions are below.

@@ -1013,99 +1105,103 @@ def Bidirectional(forward_layer, axis=1, merge_layer=Concatenate()):

 def _deep_flatten(items):
-  """Returns a list of objects, flattening sublists/subtuples along the way.
+    """Returns a list of objects, flattening sublists/subtuples along the way.

-  Example: _deep_flatten([1, (2, 3, (4, 5), [6, 7]), [[[8]]]]) would return
-  the list [1, 2, 3, 4, 5, 6, 7, 8].
+    Example: _deep_flatten([1, (2, 3, (4, 5), [6, 7]), [[[8]]]]) would return
+    the list [1, 2, 3, 4, 5, 6, 7, 8].

-  Args:
-    items: An iterable. If elements of this iterable are lists or tuples, they
-      will be (recursively) flattened until non-list non-tuple objects are
-      reached.
+    Args:
+      items: An iterable. If elements of this iterable are lists or tuples, they
+        will be (recursively) flattened until non-list non-tuple objects are
+        reached.
+
+    Returns:
+      A list of non-list, non-tuple objects.
+    """

-  Returns:
-    A list of non-list, non-tuple objects.
-  """
-  def _flat_gen(xs):
-    for x in xs:
-      if isinstance(x, (list, tuple)):
-        for y in _flat_gen(x):
-          yield y
-      else:
-        yield x
-  return list(_flat_gen(items))
+    def _flat_gen(xs):
+        for x in xs:
+            if isinstance(x, (list, tuple)):
+                for y in _flat_gen(x):
+                    yield y
+            else:
+                yield x
+
+    return list(_flat_gen(items))


 def _ensure_sublayers(layers):
-  """Ensures that elements in a layer list are layers.
-
-  Args:
-    layers: A tuple or list whose elements can each be a layer, tuple, or list,
-      and so on recursively.
-
-  Returns:
-    An analogous collection of layers in which embedded layer lists are
-    wrapped in Serial layer instances.
-  """
-  if not layers:  # None or an empty list can signal a no-op.
-    return Serial(None)  # no-op, but still handles shapes and initialization
-  elif isinstance(layers, (list, tuple)):
-    sublayers_not_lists = []
-    for layer in layers:
-      sublayers_not_lists.append(
-          Serial(layer) if isinstance(layer, (list, tuple)) else layer)
-    return sublayers_not_lists
-  else:
-    raise TypeError(type(layers))
+    """Ensures that elements in a layer list are layers.
+
+    Args:
+      layers: A tuple or list whose elements can each be a layer, tuple, or list,
+        and so on recursively.
+
+    Returns:
+      An analogous collection of layers in which embedded layer lists are
+      wrapped in Serial layer instances.
+    """
+    if not layers:  # None or an empty list can signal a no-op.
+        return Serial(None)  # no-op, but still handles shapes and initialization
+    elif isinstance(layers, (list, tuple)):
+        sublayers_not_lists = []
+        for layer in layers:
+            sublayers_not_lists.append(
+                Serial(layer) if isinstance(layer, (list, tuple)) else layer
+            )
+        return sublayers_not_lists
+    else:
+        raise TypeError(type(layers))


 def _split_rngs(rng, n_copies):
-  if rng is None:
-    return (None,) * n_copies
-  return fastmath.random.split(rng, n_copies)
+    if rng is None:
+        return (None,) * n_copies
+    return fastmath.random.split(rng, n_copies)


 def inputs_from_stack(stack, n):
-  """Returns n inputs from stack."""
-  stack = _make_tuple(stack)
-  return _make_singleitem_or_original(stack[:n])
+    """Returns n inputs from stack."""
+    stack = _make_tuple(stack)
+    return _make_singleitem_or_original(stack[:n])


 def outputs_onto_stack(outputs, stack, n):
-  """"Returns the new stack after removing n items and pushing outputs there."""
-  outputs = _make_tuple(outputs)
-  stack = _make_tuple(stack)
-  return _make_singleitem_or_original(outputs + stack[n:])
+    """Returns the new stack after removing n items and pushing outputs there."""
+    outputs = _make_tuple(outputs)
+    stack = _make_tuple(stack)
+    return _make_singleitem_or_original(outputs + stack[n:])


 def _make_tuple(xs):
-  """Returns a tuple from a list, a tuple, or a single element."""
-  if isinstance(xs, (list, tuple)):
-    return tuple(xs)
-  else:
-    return (xs,)
+    """Returns a tuple from a list, a tuple, or a single element."""
+    if isinstance(xs, (list, tuple)):
+        return tuple(xs)
+    else:
+        return (xs,)


 def _make_singleitem_or_original(xs):
-  """Returns a single element if possible, or the original list/tuple if not."""
-  if isinstance(xs, (list, tuple)) and len(xs) == 1:
-    return xs[0]
-  else:
-    return xs
+    """Returns a single element if possible, or the original list/tuple if not."""
+    if isinstance(xs, (list, tuple)) and len(xs) == 1:
+        return xs[0]
+    else:
+        return xs


 def _shape_without_axis(x, axis):
-  return x.shape[:axis] + x.shape[axis + 1:]
+    return x.shape[:axis] + x.shape[axis + 1 :]


 def _ensure_flat(layers):
-  """Ensures that layers is a single flat list of Layer instances."""
-  if len(layers) == 1 and layers[0] is None:
-    layers = ()
-  else:
-    layers = _deep_flatten(layers)
-  for obj in layers:
-    if not isinstance(obj, base.Layer):
-      raise ValueError(
-          f'Found nonlayer object ({obj}) in layers: {layers}')
-  return layers
+    """Ensures that layers is a single flat list of Layer instances."""
+    if len(layers) == 1 and layers[0] is None:
+        layers = ()
+    else:
+        layers = _deep_flatten(layers)
+    for obj in layers:
+        if not isinstance(obj, base.Layer):
+            raise ValueError(
+                f"Found non-layer object ({obj}) of type ({type(obj)}) in layers: {layers}"
+            )
+    return layers
diff --git a/trax/layers/combinators_test.py b/trax/layers/combinators_test.py
deleted file mode 100644
index 4f6ba40b8..000000000
--- a/trax/layers/combinators_test.py
+++ /dev/null
@@ -1,802 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for combinator layers.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np - -from trax import fastmath -from trax import shapes -import trax.layers as tl - - -def DivideBy(val): # pylint: disable=invalid-name - """Returns a simple division layer with n_in == 1 and n_out == 1.""" - return tl.Fn('DivideBy', lambda x: x / val) - - -def ReturnConst(val): # pylint: disable=invalid-name - """Returns a simple const layer with n_in == 0 and n_out == 1.""" - return tl.Fn('ReturnConst', lambda: val) - - -def SmallerThan(val): # pylint: disable=invalid-name - """Checks if the input is smaller than certain value.""" - return tl.Fn('SmallerThan', lambda x: x < val) - - -# TODO(jonni): Consider a more generic home for this utiliity function. -def as_list(outputs): - """Converts layer outputs to a nested list, for easier equality testing. - - Args: - outputs: A tensor or tuple/list of tensors coming from the forward - application of a layer. Each tensor is NumPy ndarray-like, which - complicates simple equality testing (e.g., via `assertEquals`): - such tensors require equality testing to use either `all` (all - elements match) or `any` (at least one element matches), which is not - directly supported in absltest. - - Returns: - A nested list structure containing all the output values, but now directly - testable using `assertEquals`. - """ - if isinstance(outputs, (list, tuple)): - return [as_list(y) for y in outputs] - else: - return outputs.tolist() - - -class SerialTest(absltest.TestCase): - - def test_none_is_no_op(self): - layer = tl.Serial(None) - xs = [np.array([1, 2, 3, 4]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3, 4], - [10, 20, 30]]) - - def test_empty_list_is_no_op(self): - layer = tl.Serial([]) - xs = [np.array([1, 2, 3, 4]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3, 4], - [10, 20, 30]]) - - def test_one_in_one_out(self): - layer = tl.Serial(DivideBy(3)) - x = np.array([3, 6, 9, 12]) - y = layer(x) - self.assertEqual(as_list(y), [1, 2, 3, 4]) - - def test_zero_in_one_out(self): - layer = tl.Serial(ReturnConst(np.array([3, 4, 5, 6]))) - y = layer(()) - self.assertEqual(as_list(y), [3, 4, 5, 6]) - - def test_one_in_two_out(self): - layer = tl.Serial(DivideBy(3), - ReturnConst(np.array([3, 4, 5, 6]))) - x = np.array([3, 6, 9, 12]) - y = layer(x) - self.assertEqual(as_list(y), [[3, 4, 5, 6], - [1, 2, 3, 4]]) - - def test_const_div(self): - layer = tl.Serial(ReturnConst(np.array([3, 6, 9, 12])), - DivideBy(3)) - y = layer(()) - self.assertEqual(as_list(y), [1, 2, 3, 4]) - - def test_div_div(self): - layer = tl.Serial(DivideBy(2.0), DivideBy(5.0)) - x = np.array([10, 20, 30]) - y = layer(x) - self.assertEqual(as_list(y), [1, 2, 3]) - - def test_dup_dup(self): - layer = tl.Serial(tl.Dup(), tl.Dup()) - x = np.array([1, 2, 3]) - ys = layer(x) - self.assertEqual(as_list(ys), [[1, 2, 3], - [1, 2, 3], - [1, 2, 3]]) - - def test_default_name(self): - layer = tl.Serial(tl.Dup(), tl.Dup()) - self.assertIn('Serial', str(layer)) - - def test_custom_name(self): - layer = tl.Serial(tl.Dup(), tl.Dup(), name='Branch') - self.assertIn('Branch', str(layer)) - - def test_weights(self): - model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) - self.assertIsInstance(model.weights, tuple) - self.assertLen(model.weights, 3) - - def 
test_flat_weights_and_state(self): - model = tl.Serial(tl.Dup(), tl.Dense(5), tl.Serial(tl.Dense(7), tl.Dup())) - sample_input_signature = shapes.signature(np.zeros((2, 3))) - model.init(sample_input_signature) - flat_weights, flat_state = tl.flatten_weights_and_state( - model.weights, model.state) - # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers. - # So after making them flat, there are 4 trainable weights. - self.assertLen(flat_weights, 4) - self.assertEmpty(flat_state) - model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(7)) - sig = model2.weights_and_state_signature(sample_input_signature) - weights2, state2 = tl.unflatten_weights_and_state( - flat_weights, flat_state, sig) - model2.weights = weights2 - model2.state = state2 - self.assertLen(model2.weights, 3) - self.assertEqual(model.weights[1], model2.weights[0]) - self.assertEqual(model.weights[2][0], model2.weights[2]) - - def test_flat_weights_and_state_shared(self): - shared = tl.Dense(5) - model = tl.Serial(tl.Dense(5), shared, tl.Serial(shared, tl.Dup())) - sample_input_signature = shapes.signature(np.zeros((2, 3))) - model.init(sample_input_signature) - flat_weights, flat_state = tl.flatten_weights_and_state( - model.weights, model.state) - # Model has 2 pairs of trainable weights: (w, b) for the 2 dense layers. - # So after making them flat, there are 4 trainable weights. - self.assertLen(flat_weights, 4) - self.assertEmpty(flat_state) - model2 = tl.Serial(tl.Dense(5), tl.Dup(), tl.Dense(5)) - sig = model2.weights_and_state_signature(sample_input_signature) - weights2, state2 = tl.unflatten_weights_and_state( - flat_weights, flat_state, sig) - model2.weights = weights2 - model2.state = state2 - self.assertLen(model2.weights, 3) - self.assertEqual(model.weights[0], model2.weights[0]) - self.assertEqual(model.weights[1], model2.weights[2]) - - def test_assign_sublayer_weights(self): - layer = tl.Dense(5, use_bias=False) - model = tl.Serial(tl.Serial(layer, tl.Dense(6)), tl.Dense(7)) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - new_layer_weights = np.random.uniform(weights[0][0].shape) - layer.weights = new_layer_weights - self.assertIs(model.weights[0][0], new_layer_weights) - - def test_shared_weights(self): - layer = tl.Dense(5) - model = tl.Serial(layer, layer) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_nested(self): - layer = tl.Dense(5) - model = tl.Serial(layer, tl.Serial(layer)) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_double_nested(self): - layer = tl.Dense(5) - model = tl.Serial(tl.Serial(layer), tl.Serial(layer)) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_for_shared_serial(self): - layer = tl.Serial(tl.Dense(5), tl.Dense(5)) - model = tl.Serial(layer, layer) - sample_input = np.array([1, 2, 3, 4, 5]) - # Init gives weights reflecting weight sharing. - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - # Forward pass runs successfully. 
- y = model(sample_input) - self.assertEqual(y.shape, (5,)) - - def test_state(self): - model = tl.Serial(tl.Dense(4), tl.Dense(5), tl.Dense(7)) - self.assertIsInstance(model.state, tuple) - self.assertLen(model.state, 3) - - def test_set_rng_recurse_two_levels(self): - dense_00 = tl.Dense(2) - dense_01 = tl.Dense(2) - dense_10 = tl.Dense(2) - dense_11 = tl.Dense(2) - layer = tl.Serial( - tl.Serial(dense_00, dense_01), - tl.Serial(dense_10, dense_11), - ) - input_signature = shapes.ShapeDtype((1, 2)) - - _, _ = layer.init(input_signature) - weights = layer.weights - dense_00_w, dense_00_b = weights[0][0] - dense_01_w, dense_01_b = weights[0][1] - dense_10_w, dense_10_b = weights[1][0] - dense_11_w, dense_11_b = weights[1][1] - - # Setting rng's recursively during init should yield differing weights. - self.assertFalse(np.array_equal(dense_00_w, dense_01_w)) - self.assertFalse(np.array_equal(dense_00_b, dense_01_b)) - self.assertFalse(np.array_equal(dense_10_w, dense_11_w)) - self.assertFalse(np.array_equal(dense_10_b, dense_11_b)) - - -class ParallelTest(absltest.TestCase): - - def test_dup_dup(self): - layer = tl.Parallel(tl.Dup(), tl.Dup()) - xs = [np.array([1, 2, 3]), - np.array([10, 20])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3], - [1, 2, 3], - [10, 20], - [10, 20]]) - - def test_div_div(self): - layer = tl.Parallel(DivideBy(0.5), DivideBy(3.0)) - xs = [np.array([1, 2, 3]), - np.array([30, 60])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[2, 4, 6], - [10, 20]]) - - def test_two_no_ops(self): - layer = tl.Parallel([], None) - xs = [np.array([1, 2, 3]), - np.array([10, 20])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3], - [10, 20]]) - - def test_default_name(self): - layer = tl.Parallel(tl.Dup(), tl.Dup()) - self.assertIn('Parallel', str(layer)) - - def test_custom_name(self): - layer = tl.Parallel(tl.Dup(), tl.Dup(), name='DupDup') - self.assertIn('DupDup', str(layer)) - - def test_weights(self): - model = tl.Parallel(tl.Dense(3), tl.Dense(5)) - self.assertIsInstance(model.weights, tuple) - self.assertLen(model.weights, 2) - - def test_shared_weights(self): - layer = tl.Dense(5) - model = tl.Parallel(layer, layer) - sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_nested(self): - layer = tl.Dense(5) - model = tl.Parallel([layer, tl.Dense(2)], - [layer, tl.Dense(2)]) - sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5])) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE) - - def test_shared_weights_for_shared_parallel(self): - layer = tl.Parallel(tl.Dense(5), tl.Dense(7)) - model = tl.Parallel(layer, layer) - sample_input = [ - np.array([1, 2, 3]), - np.array([10, 20, 30]), - np.array([100, 200, 300]), - np.array([1000, 2000, 3000]), - ] - # Init gives weights reflecting weight sharing. - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIsNot(weights[0], tl.GET_WEIGHTS_FROM_CACHE) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - # Forward pass runs successfully. 
- y0, y1, y2, y3 = model(sample_input) - self.assertEqual(y0.shape, (5,)) - self.assertEqual(y1.shape, (7,)) - self.assertEqual(y2.shape, (5,)) - self.assertEqual(y3.shape, (7,)) - - def test_state(self): - model = tl.Parallel(tl.Dense(3), tl.Dense(5)) - self.assertIsInstance(model.state, tuple) - self.assertLen(model.state, 2) - - -class ConcatenateTest(absltest.TestCase): - - def test_n_in_n_out(self): - layer = tl.Concatenate() - self.assertEqual(layer.n_in, 2) - self.assertEqual(layer.n_out, 1) - - def test_with_defaults(self): - layer = tl.Concatenate() # Default n_items=2, axis=-1 - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[1, 2, 3, 10, 20, 30], - [4, 5, 6, 40, 50, 60]]) - - def test_axis_0(self): - layer = tl.Concatenate(axis=0) - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]])] - y = layer(xs) - self.assertEqual(as_list(y), [[1, 2, 3], - [4, 5, 6], - [10, 20, 30], - [40, 50, 60]]) - - def test_axis_1(self): - layer = tl.Concatenate(axis=1) - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]])] - y = layer(xs) - self.assertEqual(as_list(y), [[1, 2, 3, 10, 20, 30], - [4, 5, 6, 40, 50, 60]]) - - def test_n_items_is_not_default(self): - layer = tl.Concatenate(n_items=3) - xs = [np.array([[1, 2, 3], - [4, 5, 6]]), - np.array([[10, 20, 30], - [40, 50, 60]]), - np.array([[100, 200, 300], - [400, 500, 600]])] - y = layer(xs) - self.assertEqual(y.shape, (2, 9)) - self.assertEqual(as_list(y), [[1, 2, 3, 10, 20, 30, 100, 200, 300], - [4, 5, 6, 40, 50, 60, 400, 500, 600]]) - - def test_repr(self): - layer = tl.Concatenate() - self.assertEqual(repr(layer), 'Concatenate_in2') - - layer = tl.Concatenate(axis=0) - self.assertEqual(repr(layer), 'Concatenate_axis0_in2') - - layer = tl.Concatenate(axis=1) - self.assertEqual(repr(layer), 'Concatenate_axis1_in2') - - layer = tl.Concatenate(n_items=3) - self.assertEqual(repr(layer), 'Concatenate_in3') - - -class BranchTest(absltest.TestCase): - - def test_noop_dup(self): - layer = tl.Branch([], tl.Dup()) - x = np.array([1, 2, 3]) - ys = layer(x) - self.assertEqual(as_list(ys), [[1, 2, 3], - [1, 2, 3], - [1, 2, 3]]) - - def test_add_div(self): - layer = tl.Branch(tl.Add(), DivideBy(0.5)) - xs = [np.array([1, 2, 3]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[11, 22, 33], - [2, 4, 6]]) - - def test_one_sublayer(self): - layer = tl.Branch(DivideBy(0.5)) - x = np.array([1, 2, 3]) - ys = layer(x) - self.assertEqual(as_list(ys), [2, 4, 6]) - - def test_default_name(self): - layer = tl.Branch(tl.Add(), DivideBy(0.5)) - self.assertIn('Branch', str(layer)) - - def test_printing_sublayers(self): - layer = tl.Branch(tl.Add(), tl.Add()) - expected_result = 'Branch_in2_out2[\n Add_in2\n Add_in2\n]' - self.assertEqual(expected_result, str(layer)) - - -class SelectTest(absltest.TestCase): - - def test_computes_n_in(self): - layer = tl.Select([0, 0]) - self.assertEqual(layer.n_in, 1) - - layer = tl.Select([1, 0]) - self.assertEqual(layer.n_in, 2) - - layer = tl.Select([2]) - self.assertEqual(layer.n_in, 3) - - def test_given_n_in(self): - layer = tl.Select([0], n_in=2) - self.assertEqual(layer.n_in, 2) - - layer = tl.Select([0], n_in=3) - self.assertEqual(layer.n_in, 3) - - def test_first_of_3(self): - layer = tl.Select([0], n_in=3) - xs = [np.array([1, 2, 3]), - np.array([10, 20]), - np.array([100])] - y = layer(xs) - self.assertEqual(as_list(y), [1, 2, 3]) - - def 
test_second_of_3(self): - layer = tl.Select([1], n_in=3) - xs = [np.array([1, 2, 3]), - np.array([10, 20]), - np.array([100])] - y = layer(xs) - self.assertEqual(as_list(y), [10, 20]) - - -class DropTest(absltest.TestCase): - - def test_drop(self): - layer = tl.Drop() - x = np.array([1, 2, 3]) - y = layer(x) - self.assertEqual(as_list(y), []) - - -class SwapTest(absltest.TestCase): - - def test_swap(self): - layer = tl.Swap() - xs = [np.array([1, 2, 3]), - np.array([10, 20, 30])] - ys = layer(xs) - self.assertEqual(as_list(ys), [[10, 20, 30], - [1, 2, 3]]) - - -class ChunkTest(absltest.TestCase): - - def test_chunk(self): - layer = tl.Dense(4) - x = np.array([[1, 2, 3], [4, 5, 6]]) - layer.init(x) - y = layer(x) - z = tl.Chunk(layer, 1)(x) - self.assertLess(np.sum((y - z)**2), 1e-5) # y == z upto numerics - - def test_chunk_uneven_numbers(self): - layer = tl.Dense(4) - x = np.array([[1, 2, 3], [4, 5, 6]]) - layer.init(x) - y = layer(x) - z = tl.Chunk(layer, 3)(x) # By default it should just pass - self.assertLess(np.sum((y - z)**2), 1e-5) # y == z upto numerics - chunk_with_test = tl.Chunk(layer, 3, pass_unchunkable=False) - self.assertRaises(tl.LayerError, lambda: chunk_with_test(x)) - - -class SerialWithSideOutputsTest(absltest.TestCase): - - def test_serial_with_side_outputs_div_div(self): - def some_layer(): - return tl.Parallel(DivideBy(2.0), DivideBy(5.0)) - layer = tl.SerialWithSideOutputs([some_layer(), some_layer()]) - xs = (np.array([1, 2, 3]), - np.array([10, 20, 30, 40, 50]), - np.array([100, 200])) - ys = layer(xs) - output_shapes = [y.shape for y in ys] - self.assertEqual(output_shapes, [(3,), (5,), (2,)]) - - -BACKENDS = [fastmath.Backend.JAX] - - -@parameterized.named_parameters( - ('_' + b.value, b) for b in BACKENDS) -class ScanTest(parameterized.TestCase): - - def _AddWithCarry(self): # pylint: disable=invalid-name - del self - def f(x, carry): - res = x + carry - return res, res # output and carry are the same - return tl.Fn('AddWithCarry', f, n_out=2) - - def test_default_axis(self, backend): - with fastmath.use_backend(backend): - layer = tl.Scan(self._AddWithCarry()) - xs = [ - np.array([[0, 1, 2, 3], - [0, 10, 20, 30], - [0, 100, 200, 300]]), - np.array([9000, 8000, 7000, 6000]) - ] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[9000, 8001, 7002, 6003], - [9000, 8011, 7022, 6033], - [9000, 8111, 7222, 6333] - ], - [9000, 8111, 7222, 6333] - ]) - - def test_axis_1(self, backend): - with fastmath.use_backend(backend): - layer = tl.Scan(self._AddWithCarry(), axis=1) - xs = [ - np.array([[0, 1, 2, 3], - [0, 10, 20, 30], - [0, 100, 200, 300]]), - np.array([9000, - 8000, - 7000]) - ] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[9000, 9001, 9003, 9006], - [8000, 8010, 8030, 8060], - [7000, 7100, 7300, 7600] - ], - [9006, - 8060, - 7600] - ]) - - def test_predict(self, backend): - with fastmath.use_backend(backend): - layer = tl.Scan(self._AddWithCarry(), axis=1, mode='predict') - xs = [np.array([[0, 1, 2]]), - np.array([90])] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[90, 91, 93]], - [93]]) - xs = [np.array([[3, 4]]), - np.array([90])] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[96, 100]], - [100]]) - - def test_multi_input(self, backend): - def _MultiInputFn(): # pylint: disable=invalid-name - def f(a, b, carry): - return a + b, b, carry + 1 - return tl.Fn('MultiInputFn', f, n_out=2) - - with fastmath.use_backend(backend): - layer = tl.Scan(_MultiInputFn(), axis=1) - xs = [ - np.array([[0, 1, 2], - [0, 10, 20]]), - np.array([[4, 5, 6], 
- [40, 50, 60]]), - np.array([9000, - 8000]) - ] - ys = layer(xs) - self.assertEqual(as_list(ys), - [[[4, 6, 8], - [40, 60, 80]], - [[4, 5, 6], - [40, 50, 60]], - [9003, - 8003] - ]) - - def test_no_carry(self, backend): - def _AddOne(): # pylint: disable=invalid-name - return tl.Fn('AddOne', lambda x: x + 1) - - with fastmath.use_backend(backend): - layer = tl.Scan(_AddOne(), n_carry=0) - x = np.array([[1, 3, 7], - [10, 30, 70]]) - y = layer(x) - self.assertEqual(as_list(y), [[2, 4, 8], - [11, 31, 71]]) - - -class CondTest(absltest.TestCase): - - def test_basic_true(self): - cond = ReturnConst(True) - true = ReturnConst([2]) - false = ReturnConst([5]) - layer = tl.Cond(cond, true, false) - layer.init(()) - xs = tuple() - ys = layer(xs) - self.assertEqual(as_list(ys), 2) - - def test_basic_false(self): - cond = ReturnConst(False) - true = ReturnConst([2]) - false = ReturnConst([5]) - layer = tl.Cond(cond, true, false) - layer.init(()) - xs = tuple() - ys = layer(xs) - self.assertEqual(as_list(ys), 5) - - def test_complex_blocks(self): - cond = ReturnConst(True) - true = DivideBy(2.) - false = DivideBy(4.) - layer = tl.Cond(cond, true, false) - xs = [np.arange(5).astype(np.float32)] - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [0., 0.5, 1.0, 1.5, 2.0]) - - def test_condition_func_true(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = DivideBy(4.) - layer = tl.Cond(cond, true, false) - xs = (np.array(2.), np.array([4., 12.])) - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [2., 6.]) - - def test_condition_func_false(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = DivideBy(4.) - layer = tl.Cond(cond, true, false) - xs = (np.array(4.), np.array([4., 12.])) - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [1., 3.]) - - def test_condition_func_default_false(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - layer = tl.Cond(cond, true) - xs = (np.array(4.), np.array([4., 12.])) - layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual(as_list(ys), [4., 12.]) - - def test_exception_n_out(self): - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = tl.Dup() - self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) - - def test_exception_n_in(self): - cond = SmallerThan(3.0) - true = ReturnConst(2.) - false = DivideBy(2.) - self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false)) - - def test_exception_run1(self): - # We expect exactly one input. - cond = SmallerThan(3.0) - true = ReturnConst(2.) - false = ReturnConst(5.) - def init_and_run(layer, xs): - layer.init(shapes.signature(xs)) - layer(xs) - # It will pass with one input. - xs = np.array(4.) - layer = tl.Cond(cond, true, false) - init_and_run(layer, xs) - # It will fail with zero or two inputs. - for xs in ((), (np.array(4.), np.array([4., 12.]))): - layer = tl.Cond(cond, true, false) - # pylint: disable=cell-var-from-loop - self.assertRaises(Exception, lambda: init_and_run(layer, xs)) - - def test_exception_run2(self): - # We expect exactly two inputs. - cond = SmallerThan(3.0) - true = DivideBy(2.) - false = DivideBy(5.) - def init_and_run(layer, xs): - layer.init(shapes.signature(xs)) - layer(xs) - # It will pass with two inputs. - xs = (np.array(4.), np.array([4., 12.])) - layer = tl.Cond(cond, true, false) - init_and_run(layer, xs) - # It will fail with zero or one input. 
- for xs in ((), (np.array(4.))): - # pylint: disable=cell-var-from-loop - self.assertRaises(Exception, lambda: init_and_run(layer, xs)) - - def test_weights_and_state(self): - cond = SmallerThan(3.0) - true = tl.Dense(5) - false = tl.Dense(5) - different = tl.Dense(5) - layer = tl.Cond(cond, true, false) - xs = (np.array(2.), np.array([0., 1., 2.])) - layer.init(shapes.signature(xs)) - - # weights - self.assertEqual(as_list(layer.weights), - as_list((cond.weights, true.weights, false.weights))) - self.assertNotEqual(as_list(true.weights), as_list(false.weights)) - self.assertNotEqual(as_list(true.weights), as_list(different.weights)) - - false.weights = true.weights - self.assertEqual(as_list(layer.weights), - as_list((cond.weights, true.weights, true.weights))) - - layer.weights = (cond.weights, true.weights, different.weights) - self.assertEqual(as_list(layer.weights), - as_list((cond.weights, true.weights, different.weights))) - # state - self.assertEqual(as_list(layer.state), - as_list((cond.state, true.state, false.state))) - # just check if simple assignments (setter from base.Layer) work correctly - # with Cond.init_weights_and_state ; all states are empty so there is no - # point in checking equality - false.state = true.state - layer.state = (cond.state, true.state, different.state) - - -class BatchLeadingAxesTest(absltest.TestCase): - - def _Id3Dim(self): # pylint: disable=invalid-name - del self - def f(x): - assert len(x.shape) == 3 - return x - return tl.Fn('Id3Dim', f, n_out=1) - - def test_2axes(self): - layer = tl.BatchLeadingAxes(self._Id3Dim(), n_last_axes_to_keep=2) - ys = layer(np.zeros((3, 4, 5))) - self.assertEqual(ys.shape, (3, 4, 5)) - ys = layer(np.zeros((2, 3, 4, 5))) - self.assertEqual(ys.shape, (2, 3, 4, 5)) - ys = layer(np.zeros((1, 2, 3, 4, 5))) - self.assertEqual(ys.shape, (1, 2, 3, 4, 5)) - - -class BidirectionalTest(absltest.TestCase): - - def test_dimensionality(self): - x = np.ones((2, 3, 8)) - layer = tl.Bidirectional(tl.GRU(n_units=8)) - input_signature = shapes.signature(x) - _, _ = layer.init(input_signature) - yhat = layer(x) - - self.assertEqual(yhat.shape, (2, 3, 8 + 8)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/convolution.py b/trax/layers/convolution.py index d0658f679..54e44ff8b 100644 --- a/trax/layers/convolution.py +++ b/trax/layers/convolution.py @@ -26,167 +26,193 @@ class Conv(base.Layer): - """Layer constructor function for a general convolution layer.""" - - def __init__(self, filters, kernel_size, strides=None, padding='VALID', - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - kernel_initializer=None, - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - super().__init__() - self._filters = filters - self._kernel_size = kernel_size - self._padding = padding - self._dimension_numbers = dimension_numbers - self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers - self._one = (1,) * len(kernel_size) - self._strides = strides or self._one - self._bias_initializer = bias_initializer - self._use_bias = use_bias - rhs_spec = self._rhs_spec - self._kernel_initializer = kernel_initializer - if kernel_initializer is None: - self._kernel_initializer = init.GlorotNormalInitializer( - rhs_spec.index('O'), rhs_spec.index('I')) - - def _check_nhwc(self): - msg = 'Convolutions on more than 4 dimensions only supported in NHWC.' 
- assert self._lhs_spec == self._out_spec == 'NHWC', msg - - def forward(self, x): - if self._use_bias: - w, b = self.weights - else: - w = self.weights - x_shape = list(x.shape) - if len(x_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, x_shape[:-3]) - x = jnp.reshape(x, [new_batch_dim] + x_shape[-3:]) - res = fastmath.conv( - x, w, self._strides, self._padding, self._dimension_numbers, - self._one) - if self._use_bias: - res = res + b - if len(x_shape) > 4: - res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) - return res - - def _kernel_shape(self, input_shape): - """Helper to calculate the kernel shape.""" - kernel_size_iter = iter(self._kernel_size) - return [self._filters if c == 'O' else - input_shape[self._lhs_spec.index('C')] if c == 'I' else - next(kernel_size_iter) for c in self._rhs_spec] - - def init_weights_and_state(self, input_signature): - input_shape = input_signature.shape - if len(input_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) - input_shape = [new_batch_dim] + list(input_shape[-3:]) - kernel_shape = self._kernel_shape(input_shape) - rng1, rng2 = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(kernel_shape, rng1) - if self._use_bias: - bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec] - bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) - b = self._bias_initializer(bias_shape, rng2) - self.weights = (w, b) - else: - self.weights = w + """Layer constructor function for a general convolution layer.""" + + def __init__( + self, + filters, + kernel_size, + strides=None, + padding="VALID", + dimension_numbers=("NHWC", "HWIO", "NHWC"), + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + ): + super().__init__() + self._filters = filters + self._kernel_size = kernel_size + self._padding = padding + self._dimension_numbers = dimension_numbers + self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers + self._one = (1,) * len(kernel_size) + self._strides = strides or self._one + self._bias_initializer = bias_initializer + self._use_bias = use_bias + rhs_spec = self._rhs_spec + self._kernel_initializer = kernel_initializer + if kernel_initializer is None: + self._kernel_initializer = init.GlorotNormalInitializer( + rhs_spec.index("O"), rhs_spec.index("I") + ) + + def _check_nhwc(self): + msg = "Convolutions on more than 4 dimensions only supported in NHWC." 
+ assert self._lhs_spec == self._out_spec == "NHWC", msg + + def forward(self, x): + if self._use_bias: + w, b = self.weights + else: + w = self.weights + x_shape = list(x.shape) + if len(x_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, x_shape[:-3]) + x = jnp.reshape(x, [new_batch_dim] + x_shape[-3:]) + res = fastmath.conv( + x, w, self._strides, self._padding, self._dimension_numbers, self._one + ) + if self._use_bias: + res = res + b + if len(x_shape) > 4: + res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) + return res + + def _kernel_shape(self, input_shape): + """Helper to calculate the kernel shape.""" + kernel_size_iter = iter(self._kernel_size) + return [ + self._filters + if c == "O" + else input_shape[self._lhs_spec.index("C")] + if c == "I" + else next(kernel_size_iter) + for c in self._rhs_spec + ] + + def init_weights_and_state(self, input_signature): + input_shape = input_signature.shape + if len(input_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) + input_shape = [new_batch_dim] + list(input_shape[-3:]) + kernel_shape = self._kernel_shape(input_shape) + rng1, rng2 = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(kernel_shape, rng1) + if self._use_bias: + bias_shape = [self._filters if c == "C" else 1 for c in self._out_spec] + bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) + b = self._bias_initializer(bias_shape, rng2) + self.weights = (w, b) + else: + self.weights = w class CausalConv(Conv): - """Causal (masked) convolution for [batch x time x depth] sequences. - - Maintains causality along time axis. Used in language modeling tasks. - """ - - def __init__(self, - filters, - kernel_width=3, - kernel_initializer=None, - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - super().__init__( - filters=filters, - kernel_size=(kernel_width,), - strides=None, - padding='VALID', - dimension_numbers=('NWC', 'WIO', 'NWC'), + """Causal (masked) convolution for [batch x time x depth] sequences. + + Maintains causality along time axis. Used in language modeling tasks. + """ + + def __init__( + self, + filters, + kernel_width=3, + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + ): + super().__init__( + filters=filters, + kernel_size=(kernel_width,), + strides=None, + padding="VALID", + dimension_numbers=("NWC", "WIO", "NWC"), + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + use_bias=use_bias, + ) + + def forward(self, x): + assert self._padding == "VALID" + # Left pad with 0s. Applying an unmasked valid convolution on top of this + # yields a causal convolution. + # TODO(ddohan): Support strided and dilated convolutions. + rate = 1 + effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1) + pad = effective_kernel_size - 1 + x_leftpad = jnp.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode="constant") + return super().forward(x_leftpad) + + +def Conv1d( + filters, + kernel_size, + stride=1, + padding="VALID", + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, +): + return Conv( + filters, + (kernel_size,), + strides=(stride,), + padding=padding, + dimension_numbers=("NWC", "WIO", "NWC"), kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, - use_bias=use_bias) - - def forward(self, x): - assert self._padding == 'VALID' - # Left pad with 0s. 
Applying an unmasked valid convolution on top of this - # yields a causal convolution. - # TODO(ddohan): Support strided and dilated convolutions. - rate = 1 - effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1) - pad = effective_kernel_size - 1 - x_leftpad = ( - jnp.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant')) - return super().forward(x_leftpad) - - -def Conv1d(filters, kernel_size, stride=1, padding='VALID', - kernel_initializer=None, - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - return Conv(filters, (kernel_size,), strides=(stride,), padding=padding, - dimension_numbers=('NWC', 'WIO', 'NWC'), - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - use_bias=use_bias) + use_bias=use_bias, + ) def _zero_pad(x, pad, axis): # pylint: disable = invalid-name - """Helper for jnp.pad with 0s for single-axis case.""" - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[axis] = pad # Padding on axis. - return jnp.pad(x, pad_widths, mode='constant') + """Helper for jnp.pad with 0s for single-axis case.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = pad # Padding on axis. + return jnp.pad(x, pad_widths, mode="constant") # @assert_shape('bld->bld') class CausalDepthwiseConv(base.Layer): - """A causal depthwise convolution layer.""" - - def __init__(self, - kernel_size=3, - kernel_initializer=init.GlorotUniformInitializer(), - use_bfloat16=False): - """Returns a causal depthwise convolution layer.""" - super().__init__(n_in=1, n_out=1) - self._kernel_size = kernel_size - self._kernel_initializer = kernel_initializer - self._use_bfloat16 = use_bfloat16 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input. - """ - w = self.weights - res = x * w[0, :][None, None, :] - for i in range(1, self._kernel_size): - x = _zero_pad(x, (1, 0), 1) - x = x[:, :-1, :] - res += x * w[i, :][None, None, :] - return res - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - shape_w = (self._kernel_size, input_signature.shape[-1]) - rng_w, _ = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(shape_w, rng_w) - if self._use_bfloat16: - w = w.astype(jnp.bfloat16) - self.weights = w + """A causal depthwise convolution layer.""" + + def __init__( + self, + kernel_size=3, + kernel_initializer=init.GlorotUniformInitializer(), + use_bfloat16=False, + ): + """Returns a causal depthwise convolution layer.""" + super().__init__(n_in=1, n_out=1) + self._kernel_size = kernel_size + self._kernel_initializer = kernel_initializer + self._use_bfloat16 = use_bfloat16 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input. 
+ """ + w = self.weights + res = x * w[0, :][None, None, :] + for i in range(1, self._kernel_size): + x = _zero_pad(x, (1, 0), 1) + x = x[:, :-1, :] + res += x * w[i, :][None, None, :] + return res + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + shape_w = (self._kernel_size, input_signature.shape[-1]) + rng_w, _ = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(shape_w, rng_w) + if self._use_bfloat16: + w = w.astype(jnp.bfloat16) + self.weights = w diff --git a/trax/layers/convolution_test.py b/trax/layers/convolution_test.py deleted file mode 100644 index 7d7c69d30..000000000 --- a/trax/layers/convolution_test.py +++ /dev/null @@ -1,91 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for convolution layers.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -import trax.layers as tl - - -class ConvolutionTest(absltest.TestCase): - - def test_call(self): - layer = tl.Conv(30, (3, 3)) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 3, 3, 30)) - - def test_use_bias_true(self): - layer = tl.Conv(30, (3, 3), use_bias=True) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 3, 3, 30)) - - self.assertIsInstance(layer.weights, tuple) - self.assertLen(layer.weights, 2) - self.assertEqual(layer.weights[0].shape, (3, 3, 20, 30)) - self.assertEqual(layer.weights[1].shape, (30,)) - - def test_use_bias_false(self): - layer = tl.Conv(30, (3, 3), use_bias=False) - x = np.ones((9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 3, 3, 30)) - # With use_bias=False, layer.weights is just 'w' and there is no 'b'. - self.assertEqual(layer.weights.shape, (3, 3, 20, 30)) - - def test_call_rebatch(self): - layer = tl.Conv(30, (3, 3)) - x = np.ones((2, 9, 5, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (2, 9, 3, 3, 30)) - - -class CausalConvolutionTest(absltest.TestCase): - - def test_causal_conv(self): - layer = tl.CausalConv(filters=30, kernel_width=3) - x = np.ones((9, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 5, 30)) - - # TODO(ddohan): How to test for causality? Gradient check between positions? 
- - def test_causal_conv_use_bias_false(self): - layer = tl.CausalConv(filters=30, kernel_width=3, use_bias=False) - x = np.ones((9, 5, 20)) - layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (9, 5, 30)) - - self.assertEqual(layer.weights.shape, (3, 20, 30)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/core.py b/trax/layers/core.py index b7fd6fc31..57ac10410 100644 --- a/trax/layers/core.py +++ b/trax/layers/core.py @@ -15,10 +15,11 @@ """Core layer types and key functions used by various layers.""" -from absl import logging import numpy as np import tensorflow as tf +from absl import logging + from trax import fastmath from trax.fastmath import numpy as jnp from trax.layers import base @@ -29,841 +30,892 @@ # The output tensor has the same shape as the input tensor, except for the size # of the last dimension. -@assert_shape('...a->...b') +@assert_shape("...a->...b") class Dense(base.Layer): - """A dense (a.k.a. fully-connected, affine) layer. - - Dense layers are the prototypical example of a trainable layer, i.e., a layer - with trainable weights. Each node in a dense layer computes a weighted sum of - all node values from the preceding layer and adds to that sum a node-specific - bias term. The full layer computation is expressed compactly in linear - algebra as an affine map `y = Wx + b`, where `W` is a matrix and `y`, `x`, - and `b` are vectors. The layer is trained, or "learns", by updating the - values in `W` and `b`. - - Less commonly, a dense layer can omit the bias term and be a pure linear map: - `y = Wx`. - """ - - def __init__(self, - n_units, - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True, - use_bfloat16=False): - """Returns a dense (fully connected) layer of width `n_units`. - - A dense layer maps collections of `R^m` vectors to `R^n`, where `n` - (`= n_units`) is fixed at layer creation time, and `m` is set at layer - initialization time. - - Args: - n_units: Number of nodes in the layer, also known as the width of the - layer. - kernel_initializer: Function that creates a matrix of (random) initial - connection weights `W` for the layer. - bias_initializer: Function that creates a vector of (random) initial - bias weights `b` for the layer. - use_bias: If `True`, compute an affine map `y = Wx + b`; else compute - a linear map `y = Wx`. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - """ - super().__init__(name=f'Dense_{n_units}') - self._n_units = n_units - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - self._use_bias = use_bias - self._use_bfloat16 = use_bfloat16 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input, except the final dimension - is the layer's `n_units` value. + """A dense (a.k.a. fully-connected, affine) layer. + + Dense layers are the prototypical example of a trainable layer, i.e., a layer + with trainable weights. Each node in a dense layer computes a weighted sum of + all node values from the preceding layer and adds to that sum a node-specific + bias term. 
The full layer computation is expressed compactly in linear + algebra as an affine map `y = Wx + b`, where `W` is a matrix and `y`, `x`, + and `b` are vectors. The layer is trained, or "learns", by updating the + values in `W` and `b`. + + Less commonly, a dense layer can omit the bias term and be a pure linear map: + `y = Wx`. """ - if self._use_bias: - if not isinstance(self.weights, (tuple, list)): - raise ValueError(f'Weights should be a (w, b) tuple or list; ' - f'instead got: {self.weights}') - w, b = self.weights - return jnp.dot(x, w) + b # Affine map. - else: - w = self.weights - return jnp.dot(x, w) # Linear map. - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights. - - Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the - default case), or a `w` tensor for layers created with `use_bias=False`. - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. - """ - shape_w = (input_signature.shape[-1], self._n_units) - shape_b = (self._n_units,) - rng_w, rng_b = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(shape_w, rng_w) - if self._use_bfloat16: - w = w.astype(jnp.bfloat16) - - if self._use_bias: - b = self._bias_initializer(shape_b, rng_b) - if self._use_bfloat16: - b = b.astype(jnp.bfloat16) - self.weights = (w, b) - else: - self.weights = w + def __init__( + self, + n_units, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + use_bfloat16=False, + ): + """Returns a dense (fully connected) layer of width `n_units`. + + A dense layer maps collections of `R^m` vectors to `R^n`, where `n` + (`= n_units`) is fixed at layer creation time, and `m` is set at layer + initialization time. + + Args: + n_units: Number of nodes in the layer, also known as the width of the + layer. + kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. + bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. + use_bias: If `True`, compute an affine map `y = Wx + b`; else compute + a linear map `y = Wx`. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + """ + super().__init__(name=f"Dense_{n_units}") + self._n_units = n_units + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + self._use_bias = use_bias + self._use_bfloat16 = use_bfloat16 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input, except the final dimension + is the layer's `n_units` value. + """ + if self._use_bias: + if not isinstance(self.weights, (tuple, list)): + raise ValueError( + f"Weights should be a (w, b) tuple or list; " + f"instead got: {self.weights}" + ) + w, b = self.weights + return jnp.dot(x, w) + b # Affine map. + else: + w = self.weights + return jnp.dot(x, w) # Linear map. + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights. + + Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the + default case), or a `w` tensor for layers created with `use_bias=False`. 
+ + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. + """ + shape_w = (input_signature.shape[-1], self._n_units) + shape_b = (self._n_units,) + rng_w, rng_b = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(shape_w, rng_w) + if self._use_bfloat16: + w = w.astype(jnp.bfloat16) + + if self._use_bias: + b = self._bias_initializer(shape_b, rng_b) + if self._use_bfloat16: + b = b.astype(jnp.bfloat16) + self.weights = (w, b) + else: + self.weights = w # The output tensor has the same shape as the input tensor, but with added # dimension at the end. This dimension size corresponds to embedding depth. -@assert_shape('...->...d') +@assert_shape("...->...d") class Embedding(base.Layer): - """Trainable layer that maps discrete tokens/IDs to vectors. - - Embedding layers are commonly used to map discrete data, like words in NLP, - into vectors. Here is a canonical example:: - - vocab_size = 5 - word_ids = np.array([1, 2, 3, 4], dtype=np.int32) # word_ids < vocab_size - embedding_layer = tl.Embedding(vocab_size, 32) - embedding_layer.init(trax.shapes.signature(word_ids)) - embedded = embedding_layer(word_ids) # embedded.shape = (4, 32) - """ - - def __init__(self, - vocab_size, - d_feature, - use_bfloat16=False, - kernel_initializer=init.ScaledInitializer( - out_dim=-1, in_dim=-2, scale=1., mode='fan_out', - distribution='uniform')): - """Returns an embedding layer with given vocabulary size and vector size. - - The layer clips input values (token IDs) to the range `[0, vocab_size)`. - That is, negative token IDs all clip to `0` before being mapped to a - vector, and token IDs with value `vocab_size` or greater all clip to - `vocab_size - 1` before being mapped to a vector. + """Trainable layer that maps discrete tokens/IDs to vectors. - Args: - vocab_size: Size of the input vocabulary. The layer will assign a unique - vector to each id in `range(vocab_size)`. - d_feature: Dimensionality/depth of the output vectors. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - kernel_initializer: Function that creates (random) initial vectors for - the embedding. - """ - # TODO(jonni): is the clipping behavior what we want going forward? - super().__init__(name=f'Embedding_{vocab_size}_{d_feature}') - self._d_feature = d_feature # feature dimensionality - self._vocab_size = vocab_size - self._use_bfloat16 = use_bfloat16 - self._kernel_initializer = kernel_initializer - - def forward(self, x): - """Returns embedding vectors corresponding to input token IDs. - - Args: - x: Tensor of token IDs. + Embedding layers are commonly used to map discrete data, like words in NLP, + into vectors. Here is a canonical example:: - Returns: - Tensor of embedding vectors. + vocab_size = 5 + word_ids = np.array([1, 2, 3, 4], dtype=np.int32) # word_ids < vocab_size + embedding_layer = tl.Embedding(vocab_size, 32) + embedding_layer.init(trax.shapes.signature(word_ids)) + embedded = embedding_layer(word_ids) # embedded.shape = (4, 32) """ - embedded = jnp.take(self.weights, x, axis=0, mode='clip') - if self._use_bfloat16: # Return float32 activations w/ bfloat16 weights. - embedded = embedded.astype(jnp.float32) - return embedded - - def init_weights_and_state(self, input_signature): - """Randomly initializes this layer's weights.""" - del input_signature - shape_w = (self._vocab_size, self._d_feature) - # TODO(lukaszkaiser): do we split self.rng for consistency? 
Add a method? - w = self._kernel_initializer(shape_w, self.rng) - if self._use_bfloat16: - w = w.astype(jnp.bfloat16) - self.weights = w - - -@assert_shape('...->...') # The output and input shapes are the same. -class Dropout(base.Layer): - """A layer that stochastically ignores a subset of inputs each training step. - In training, to compensate for the fraction of input values dropped (`rate`), - all surviving values are multiplied by `1 / (1 - rate)`. - - The parameter `shared_axes` allows to specify a list of axes on which - the mask will be shared: we will use size 1 on those axes for dropout mask - and broadcast it. Sharing reduces randomness, but can save memory. + def __init__( + self, + vocab_size, + d_feature, + use_bfloat16=False, + kernel_initializer=init.ScaledInitializer( + out_dim=-1, in_dim=-2, scale=1.0, mode="fan_out", distribution="uniform" + ), + ): + """Returns an embedding layer with given vocabulary size and vector size. + + The layer clips input values (token IDs) to the range `[0, vocab_size)`. + That is, negative token IDs all clip to `0` before being mapped to a + vector, and token IDs with value `vocab_size` or greater all clip to + `vocab_size - 1` before being mapped to a vector. + + Args: + vocab_size: Size of the input vocabulary. The layer will assign a unique + vector to each id in `range(vocab_size)`. + d_feature: Dimensionality/depth of the output vectors. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + kernel_initializer: Function that creates (random) initial vectors for + the embedding. + """ + # TODO(jonni): is the clipping behavior what we want going forward? + super().__init__(name=f"Embedding_{vocab_size}_{d_feature}") + + self._d_feature = d_feature # feature dimensionality + self._vocab_size = vocab_size + self._use_bfloat16 = use_bfloat16 + self._kernel_initializer = kernel_initializer + + def forward(self, x): + """Returns embedding vectors corresponding to input token IDs. + + Args: + x: Tensor of token IDs. + + Returns: + Tensor of embedding vectors. + """ + embedded = jnp.take(self.weights, x, axis=0, mode="clip") + if self._use_bfloat16: # Return float32 activations w/ bfloat16 weights. + embedded = embedded.astype(jnp.float32) + return embedded + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + del input_signature + shape_w = (self._vocab_size, self._d_feature) + # TODO(lukaszkaiser): do we split self.rng for consistency? Add a method? + w = self._kernel_initializer(shape_w, self.rng) + if self._use_bfloat16: + w = w.astype(jnp.bfloat16) + self.weights = w + + +@assert_shape("...->...") # The output and input shapes are the same. +class Dropout(base.Layer): + """A layer that stochastically ignores a subset of inputs each training step. - This layer is active only during training (`mode='train'`). In other - circumstances it is a no-op. + In training, to compensate for the fraction of input values dropped (`rate`), + all surviving values are multiplied by `1 / (1 - rate)`. - Originally introduced in the paper "Dropout: A Simple Way to Prevent Neural - Networks from Overfitting" available under the following link: - https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf - """ + The parameter `shared_axes` allows to specify a list of axes on which + the mask will be shared: we will use size 1 on those axes for dropout mask + and broadcast it. 
Sharing reduces randomness, but can save memory. - def __init__(self, rate=0.0, shared_axes=None, mode='train'): - """Creates a dropout layer with the given target drop rate. + This layer is active only during training (`mode='train'`). In other + circumstances it is a no-op. - Args: - rate: Stochastic rate (probability) for dropping an activation value - from the preceding layer (setting it to zero). - shared_axes: List of axes on which the mask is shared. - mode: If `'train'`, this layer will perform dropout; else, it will pass - all values through unaltered. + Originally introduced in the paper "Dropout: A Simple Way to Prevent Neural + Networks from Overfitting" available under the following link: + https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf """ - super().__init__() - self._initial_rate = rate - self._shared_axes = [] if shared_axes is None else shared_axes - self._mode = mode - - def init_weights_and_state(self, input_signature): - """Sets layer-specific internal state.""" - del input_signature - self.state = jnp.array(self._initial_rate) - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of activations. - - Returns: - Tensor of same shape and dtype as the input. - """ - if self._mode != 'train': - return x - state, rng = self.state, self.rng - rate = self._initial_rate - if isinstance(state, dict) and self._name in state: - rate = state[self._name] - if rate == 0.0: - return x - mask_shape = list(x.shape) - for axis in self._shared_axes: - mask_shape[axis] = 1 - keep_prob = 1.0 - rate - keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape)) - mask = keep.astype(x.dtype) / keep_prob - return x * mask + def __init__(self, rate=0.0, shared_axes=None, mode="train"): + """Creates a dropout layer with the given target drop rate. + + Args: + rate: Stochastic rate (probability) for dropping an activation value + from the preceding layer (setting it to zero). + shared_axes: List of axes on which the mask is shared. + mode: If `'train'`, this layer will perform dropout; else, it will pass + all values through unaltered. + """ + super().__init__() + self._initial_rate = rate + self._shared_axes = [] if shared_axes is None else shared_axes + self._mode = mode + + def init_weights_and_state(self, input_signature): + """Sets layer-specific internal state.""" + del input_signature + self.state = jnp.array(self._initial_rate) + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of activations. + + Returns: + Tensor of same shape and dtype as the input. + """ + if self._mode != "train": + return x + state, rng = self.state, self.rng + rate = self._initial_rate + if isinstance(state, dict) and self._name in state: + rate = state[self._name] + if rate == 0.0: + return x + mask_shape = list(jnp.shape(x)) + + for axis in self._shared_axes: + mask_shape[axis] = 1 + keep_prob = 1.0 - rate + keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape)) + mask = keep.astype(x.dtype) / keep_prob + return x * mask class Weights(base.Layer): - """Learnable weights as a layer. + """Learnable weights as a layer. - It takes no input and returns a single tensor: weights. - """ - - def __init__(self, initializer, shape=tuple(), use_bfloat16=False): - """Returns a learnable tensor of shape `shape`. - - Args: - initializer: Function taking shape and rng as arguments. - shape: Shape of the learnable weights. 
- use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. + It takes no input and returns a single tensor: weights. """ - super().__init__(name=f'Weights_{shape}', n_in=0, n_out=1) - self._shape = shape - self._initializer = initializer - self._use_bfloat16 = use_bfloat16 - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor with previously specified shape and dtype. - """ - del x # Unused. There is no input to this layer. - return self.weights - - def init_weights_and_state(self, input_signature): - """Returns newly initialized weights for this layer. + def __init__(self, initializer, shape=tuple(), use_bfloat16=False): + """Returns a learnable tensor of shape `shape`. + + Args: + initializer: Function taking shape and rng as arguments. + shape: Shape of the learnable weights. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + """ + super().__init__(name=f"Weights_{shape}", n_in=0, n_out=1) + self._shape = shape + self._initializer = initializer + self._use_bfloat16 = use_bfloat16 + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor with previously specified shape and dtype. + """ + del x # Unused. There is no input to this layer. + return self.weights + + def init_weights_and_state(self, input_signature): + """Returns newly initialized weights for this layer. + + Weights is a single `w` tensor with previously specified shape. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. Unused. + """ + del input_signature # Unused. There is no input to this layer. + self.weights = self._initializer(self._shape, self.rng) + if self._use_bfloat16: + self.weights = self.weights.astype(jnp.bfloat16) + + +def PrintShape(n_in=1, msg=""): + """Prints the shapes of `n_in` inputs and returns them unchanged.""" + + def Fwd(xs): + def format_shape(x): # pylint: disable = invalid-name + return str(jnp.shape(x)) + f"[{x.dtype}]" + + if n_in > 1: + shapes_and_dtypes = ", ".join([format_shape(x) for x in xs]) + else: + shapes_and_dtypes = format_shape(xs) + info = f"PrintShape: {msg}: [{shapes_and_dtypes}]" + print(info) + logging.info(info) + return xs - Weights is a single `w` tensor with previously specified shape. - - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. Unused. - """ - del input_signature # Unused. There is no input to this layer. 
- self.weights = self._initializer(self._shape, self.rng) - if self._use_bfloat16: - self.weights = self.weights.astype(jnp.bfloat16) - - -def PrintShape(n_in=1, msg=''): - """Prints the shapes of `n_in` inputs and returns then unchanged.""" - def Fwd(xs): - def format_shape(x): # pylint: disable = invalid-name - return str(x.shape) + f'[{x.dtype}]' - if n_in > 1: - shapes_and_dtypes = ', '.join([format_shape(x) for x in xs]) - else: - shapes_and_dtypes = format_shape(xs) - info = f'PrintShape: {msg}: [{shapes_and_dtypes}]' - print(info) - logging.info(info) - return xs - return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f'PrintShape_{n_in}') + return base.PureLayer(Fwd, n_in=n_in, n_out=n_in, name=f"PrintShape_{n_in}") class SummaryImage(base.Layer): - """A layer receiving a tensor, and adding it to TensorBoard as an image. - - It takes an input and returns it unchanged. It stores this input as a state to - be used as a metric in TensorBoard. - It converts a tensor to a scalar by running a given aggregation function (mean - by default). On TensorBoard, results for each device will be reported - separately. - """ - - def __init__(self, name, n_in, num_summaries=5, - recover_fn=None): - """Takes a tensor and returns it. + """A layer receiving a tensor, and adding it to TensorBoard as an image. - Args: - name: Name of the metric to be reported. - n_in: Number of inputs. - num_summaries: Number of images to show. - recover_fn: the function for converting a tensor to a dipslayable image. + It takes an input and returns it unchanged. It stores this input as a state to + be used as a metric in TensorBoard. + It shows up to `num_summaries` images from the batch, each optionally + converted with `recover_fn`. On TensorBoard, results for each device will be reported + separately. """ - super().__init__(name=f'Summary_{name}', n_in=n_in, n_out=n_in) - name = 'summary_' + name - self._name = name - self._num_summaries = num_summaries - self._recover_fn = recover_fn - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. + def __init__(self, name, n_in, num_summaries=5, recover_fn=None): + """Takes a tensor and returns it. + + Args: + name: Name of the metric to be reported. + n_in: Number of inputs. + num_summaries: Number of images to show. + recover_fn: The function for converting a tensor to a displayable image. + """ + super().__init__(name=f"Summary_{name}", n_in=n_in, n_out=n_in) + name = "summary_" + name + self._name = name + self._num_summaries = num_summaries + self._recover_fn = recover_fn + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor with previously specified shape and dtype. + """ + self.state = {} + batch_size = x[0].shape[0] + num_images = min(self._num_summaries, batch_size) + for s in range(num_images): + images = [] + for i in range(self._n_in): + images.append( + self._recover_fn(x[i][s]) if self._recover_fn else x[i][s] + ) + self.state[self._name + str(s)] = jnp.concatenate(images, axis=0) + return x[: self._n_in] + + def init_weights_and_state(self, input_signature): + """Returns newly initialized weights for this layer. + + Weights is a single `w` tensor with previously specified shape. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. Unused. + """ + del input_signature # Unused. 
+ self.weights = () + self.state = {self._name: jnp.array(0.0)} - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - Returns: - Tensor with previously specified shape and dtype. - """ - self.state = {} - batch_size = x[0].shape[0] - num_images = min(self._num_summaries, batch_size) - for s in range(num_images): - images = [] - for i in range(self._n_in): - images.append( - self._recover_fn(x[i][s]) if self._recover_fn else x[i][s]) - self.state[self._name + str(s)] = jnp.concatenate(images, axis=0) - return x[:self._n_in] - - def init_weights_and_state(self, input_signature): - """Returns newly initialized weights for this layer. - - Weights is a single `w` tensor with previously specified shape. +class SummaryScalar(base.Layer): + """A layer receiving a tensor, and adding it to TensorBoard as a scalar. - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. Unused. + It takes an input and returns it unchanged. It stores this input as a state to + be used as a metric in TensorBoard. + It converts a tensor to a scalar by running a given aggregation function (mean + by default). On TensorBoard, results for each device will be reported + separately. """ - del input_signature # Unused. - self.weights = () - self.state = {self._name: jnp.array(0.)} + def __init__(self, name, aggregation_fun=jnp.mean): + """Takes a tensor and returns it. -class SummaryScalar(base.Layer): - """A layer receiving a tensor, and adding it to TensorBoard as a scalar. + Args: + name: Name of the metric to be reported. + aggregation_fun: Aggregation function to be used. + """ + super().__init__(name=f"Summary_{name}", n_in=1, n_out=1) + name = "summary_" + name + self._name = name + self._aggregation_fun = aggregation_fun - It takes an input and returns it unchanged. It stores this input as a state to - be used as a metric in TensorBoard. - It converts a tensor to a scalar by running a given aggregation function (mean - by default). On TensorBoard, results for each device will be reported - separately. - """ + def forward(self, x): + """Executes this layer as part of a forward pass through the model. - def __init__(self, name, aggregation_fun=jnp.mean): - """Takes a tensor and returns it. + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. - Args: - name: Name of the metric to be reported. - aggregation_fun: Aggregation function to be used. - """ - super().__init__(name=f'Summary_{name}', n_in=1, n_out=1) - name = 'summary_' + name - self._name = name - self._aggregation_fun = aggregation_fun + Returns: + Tensor with previously specified shape and dtype. + """ + self.state = {self._name: self._aggregation_fun(x)} + return x - def forward(self, x): - """Executes this layer as part of a forward pass through the model. + def init_weights_and_state(self, input_signature): + """Returns newly initialized weights for this layer. - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + Weights is a single `w` tensor with previously specified shape. - Returns: - Tensor with previously specified shape and dtype. - """ - self.state = {self._name: self._aggregation_fun(x)} - return x - - def init_weights_and_state(self, input_signature): - """Returns newly initialized weights for this layer. - - Weights is a single `w` tensor with previously specified shape. 
- - Args: - input_signature: `ShapeDtype` instance characterizing the input this layer - should compute on. Unused. - """ - del input_signature # Unused. - self.weights = () - self.state = {self._name: jnp.array(0.)} + Args: + input_signature: `ShapeDtype` instance characterizing the input this layer + should compute on. Unused. + """ + del input_signature # Unused. + self.weights = () + self.state = {self._name: jnp.array(0.0)} class RandomUniform(base.Layer): - """Layer returning a tensor with random values distributed uniformly.""" + """Layer returning a tensor with random values distributed uniformly.""" + + def __init__( + self, min_val=0.0, max_val=1.0, shape=(), dtype=jnp.float32, sync=False + ): + """Layer returning a tensor with random values distributed uniformly. + + Args: + min_val: Lower end of uniform distribution. + max_val: Upper end of uniform distribution. + shape: Shape of the tensor to return. Values are sampled independently. + dtype: Type of value to return. + sync: Whether to synchronise `rng` across devices. + """ + super().__init__(n_in=0, n_out=1) + self._min_val = min_val + self._max_val = max_val + self._shape = shape + self._dtype = dtype + self._sync = sync + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model. + + Args: + xs: Unused tensors. + + Returns: + Random uniform tensor of the shape and type specified in constructor. + """ + rng = self._get_conditionally_synced_rng() + result = fastmath.random.uniform( + rng, self._shape, self._dtype, self._min_val, self._max_val + ) + return result + + def _get_conditionally_synced_rng(self): + if self._sync and fastmath.global_device_count() > 1: + return fastmath.psum(self.rng, "batch") + else: + return self.rng - def __init__(self, min_val=0.0, max_val=1.0, shape=(), dtype=jnp.float32, - sync=False): - """Layer returning a tensor with random values distributed uniformly. - Args: - min_val: Lower end of uniform distribution. - max_val: Upper end of uniform distribution. - shape: Shape of the tensor to return. Values are sampled independently. - dtype: Type of value to return. - sync: Whether to synchronise `rng` across devices. - """ - super().__init__(n_in=0, n_out=1) - self._min_val = min_val - self._max_val = max_val - self._shape = shape - self._dtype = dtype - self._sync = sync +class LocallyConnected1d(base.Layer): + """Locally-connected layer for 1D inputs. - def forward(self, xs): - """Executes this layer as part of a forward pass through the model. + The LocallyConnected1d layer applies a different set of filters to each patch + of the input. This is similar to applying a convolution layer, except that + locally-connected layer uses a different set of weights for each patch. - Args: - xs: Unused tensors. + The size of patch is determined by the kernel size. The stride is currently + not modifiable and set to one. This means for the input of shape (..., L, D) + the output shape for paddings 'SAME' and 'WRAP' will be (..., L, filters) and + for padding 'VALID' (..., L-kernel_size+1, filters); where L is the number of + "pixels" or "steps" in the input, D is the size of the embedding. - Returns: - Random uniform tensor of the shape and type specified in constructor. + Note that, since the weights for different patches are not shared, the number + of "pixels" or "steps" cannot change after calling init_weights_and_state. + This is because each "pixel" is assigned its own set of weights. 
""" - rng = self._get_conditionally_synced_rng() - result = fastmath.random.uniform( - rng, self._shape, self._dtype, self._min_val, self._max_val) - return result - def _get_conditionally_synced_rng(self): - if self._sync and fastmath.global_device_count() > 1: - return fastmath.psum(self.rng, 'batch') - else: - return self.rng - - -class LocallyConnected1d(base.Layer): - """Locally-connected layer for 1D inputs. - - The LocallyConnected1d layer applies a different set of filters to each patch - of the input. This is similar to applying a convolution layer, except that - locally-connected layer uses a different set of weights for each patch. - - The size of patch is determined by the kernel size. The stride is currently - not modifiable and set to one. This means for the input of shape (..., L, D) - the output shape for paddings 'SAME' and 'WRAP' will be (..., L, filters) and - for padding 'VALID' (..., L-kernel_size+1, filters); where L is the number of - "pixels" or "steps" in the input, D is the size of the embedding. - - Note that, since the weights for different patches are not shared, the number - of "pixels" or "steps" cannot change after calling init_weights_and_state. - This is because each "pixel" is assigned its own set of weights. - """ + def __init__( + self, + filters, + kernel_size, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, + padding="VALID", + ): + """Returns a locally-connected conv-like layer. + + Args: + filters: Number of output filters in the convolution. + kernel_size: A length of the convolution window. Must be an odd number. + kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. + bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. + use_bias: If `True`, the layer uses a bias vector. + padding: The type of padding to use; must be 'VALID', 'SAME', or 'WRAP'. + """ + super().__init__(name=f"LocallyConnected1d_{filters}_{kernel_size}") + self._filters = filters + self._kernel_size = kernel_size + assert self._kernel_size % 2 == 1 # kernel size has to be odd + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + self._use_bias = use_bias + self._padding = padding + + def forward(self, x): + """Executes this layer as part of a forward pass through the model. + + Args: + x: Tensor of same shape and dtype as the input signature used to + initialize this layer. + + Returns: + Tensor of same shape and dtype as the input, except the final dimension + is the layer's `filters` value, and the second to last dimension is + shrinked if 'VALID' padding is used with kernel_size bigger than one. + """ + if self._use_bias: + if not isinstance(self.weights, (tuple, list)): + raise ValueError( + f"Weights should be a (w, b) tuple or list; " + f"instead got: {self.weights}" + ) + w, b = self.weights + else: + w = self.weights - def __init__(self, filters, kernel_size, - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True, padding='VALID'): - """Returns a locally-connected conv-like layer. + linear_results_before_shifting = jnp.einsum("...lp,lkpd->...lkd", x, w) + # TODO(jaszczur): this could be run after padding for better efficiency - Args: - filters: Number of output filters in the convolution. - kernel_size: A length of the convolution window. Must be an odd number. 
-  def forward(self, x):
-    """Executes this layer as part of a forward pass through the model.
-
-    Args:
-      x: Tensor of same shape and dtype as the input signature used to
-        initialize this layer.
-
-    Returns:
-      Tensor of same shape and dtype as the input, except the final dimension
-      is the layer's `filters` value, and the second to last dimension is
-      shrinked if 'VALID' padding is used with kernel_size bigger than one.
-    """
-    if self._use_bias:
-      if not isinstance(self.weights, (tuple, list)):
-        raise ValueError(f'Weights should be a (w, b) tuple or list; '
-                         f'instead got: {self.weights}')
-      w, b = self.weights
-    else:
-      w = self.weights
-
-    linear_results_before_shifting = jnp.einsum(
-        '...lp,lkpd->...lkd', x, w)
-    # TODO(jaszczur): this could be run after padding for better efficiency
-
-    if self._kernel_size == 1:
-      # With kernel size 1 we don't have to split or shift anything.
-      linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
-    else:
-      # We computed a result for every "pixel", but each direction from the
-      # receptive field (there are 'self._kernel_size' such directions) must be
-      # shifted by a different amount. The easiest way to do it is to split
-      # the tensor to 'self._kernel_size' smaller tensors, shift each one
-      # appropriately, and then sum them together.
-      split_shifting_linear_results = jnp.split(
-          linear_results_before_shifting, self._kernel_size, axis=-2)
-
-      for i in range(self._kernel_size):
-        # Each tensor has to be shifted a different amount.
-        if self._padding == 'WRAP':
-          # We can shift by padding and cutting. With 'wrap' padding we
-          # essentially have a torus.
-          padding = [(0, 0) for i in split_shifting_linear_results[i].shape]
-          padding[-3] = ((self._kernel_size - 1) - i, i)
-          split_shifting_linear_results[i] = jnp.pad(
-              split_shifting_linear_results[i], padding, mode='wrap')
-          split_shifting_linear_results[i] = split_shifting_linear_results[i][
-              ..., (self._kernel_size-1)//2:-(self._kernel_size-1)//2, :, :]
-        elif self._padding == 'SAME':
-          # We can shift by padding and cutting.
-          padding = [(0, 0) for i in split_shifting_linear_results[i].shape]
-          padding[-3] = ((self._kernel_size - 1) - i, i)
-          split_shifting_linear_results[i] = jnp.pad(
-              split_shifting_linear_results[i], padding)
-          split_shifting_linear_results[i] = split_shifting_linear_results[i][
-              ..., (self._kernel_size-1)//2:-(self._kernel_size-1)//2, :, :]
-          # TODO(jaszczur): improve efficiency by not padding things to cut
-        elif self._padding == 'VALID':
-          # We don't need to shift - just cut the leftmost and rightmost values.
-          cut_left = (self._kernel_size - 1) - i
-          cut_right = split_shifting_linear_results[i].shape[-3] - i
-          split_shifting_linear_results[i] = split_shifting_linear_results[i][
-              ..., cut_left:cut_right, :, :]
-        else:
-          raise ValueError(f'Invalid padding {self._padding}')
-      # After shifting.
-      shifted_linear_results = jnp.concatenate(split_shifting_linear_results,
-                                               axis=-2)
-      linear_result = jnp.sum(shifted_linear_results, axis=-2)
-
-    if self._use_bias:
-      return linear_result + b
-    else:
-      return linear_result

+    def forward(self, x):
+        """Executes this layer as part of a forward pass through the model.
+
+        Args:
+            x: Tensor of same shape and dtype as the input signature used to
+                initialize this layer.
+
+        Returns:
+            Tensor of same shape and dtype as the input, except the final dimension
+            is the layer's `filters` value, and the second to last dimension is
+            shrunk if 'VALID' padding is used with a kernel_size bigger than one.
+        """
+        if self._use_bias:
+            if not isinstance(self.weights, (tuple, list)):
+                raise ValueError(
+                    f"Weights should be a (w, b) tuple or list; "
+                    f"instead got: {self.weights}"
+                )
+            w, b = self.weights
+        else:
+            w = self.weights
+
+        linear_results_before_shifting = jnp.einsum("...lp,lkpd->...lkd", x, w)
+        # TODO(jaszczur): this could be run after padding for better efficiency
+
+        if self._kernel_size == 1:
+            # With kernel size 1 we don't have to split or shift anything.
+            linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
+        else:
+            # We computed a result for every "pixel", but each direction from the
+            # receptive field (there are 'self._kernel_size' such directions) must be
+            # shifted by a different amount. The easiest way to do it is to split
+            # the tensor to 'self._kernel_size' smaller tensors, shift each one
+            # appropriately, and then sum them together.
+            split_shifting_linear_results = jnp.split(
+                linear_results_before_shifting, self._kernel_size, axis=-2
+            )
+
+            for i in range(self._kernel_size):
+                # Each tensor has to be shifted a different amount.
+                if self._padding == "WRAP":
+                    # We can shift by padding and cutting. With 'wrap' padding we
+                    # essentially have a torus.
+                    padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
+                    padding[-3] = ((self._kernel_size - 1) - i, i)
+                    split_shifting_linear_results[i] = jnp.pad(
+                        split_shifting_linear_results[i], padding, mode="wrap"
+                    )
+                    split_shifting_linear_results[i] = split_shifting_linear_results[i][
+                        ...,
+                        (self._kernel_size - 1) // 2 : -(self._kernel_size - 1) // 2,
+                        :,
+                        :,
+                    ]
+                elif self._padding == "SAME":
+                    # We can shift by padding and cutting.
+                    padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
+                    padding[-3] = ((self._kernel_size - 1) - i, i)
+                    split_shifting_linear_results[i] = jnp.pad(
+                        split_shifting_linear_results[i], padding
+                    )
+                    split_shifting_linear_results[i] = split_shifting_linear_results[i][
+                        ...,
+                        (self._kernel_size - 1) // 2 : -(self._kernel_size - 1) // 2,
+                        :,
+                        :,
+                    ]
+                    # TODO(jaszczur): improve efficiency by not padding things to cut
+                elif self._padding == "VALID":
+                    # We don't need to shift - just cut the leftmost and rightmost values.
+                    cut_left = (self._kernel_size - 1) - i
+                    cut_right = split_shifting_linear_results[i].shape[-3] - i
+                    split_shifting_linear_results[i] = split_shifting_linear_results[i][
+                        ..., cut_left:cut_right, :, :
+                    ]
+                else:
+                    raise ValueError(f"Invalid padding {self._padding}")
+            # After shifting, sum the per-offset contributions.
+            shifted_linear_results = jnp.concatenate(
+                split_shifting_linear_results, axis=-2
+            )
+            linear_result = jnp.sum(shifted_linear_results, axis=-2)
+
+        if self._use_bias:
+            return linear_result + b
+        else:
+            return linear_result
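
# Pure-NumPy sketch (an editorial illustration, not library code) of the
# split-shift-sum trick in forward() for 'VALID' padding: each output step l
# gathers the K = kernel_size input steps l..l+K-1, and input step l+j
# contributes through its own weight slice w[l+j, K-1-j]. Index conventions
# follow the einsum '...lp,lkpd->...lkd'.
import numpy as np

L, D, K, F = 5, 2, 3, 4                        # steps, in-dim, kernel, filters
rng = np.random.default_rng(0)
x = rng.normal(size=(L, D))
w = rng.normal(size=(L, K, D, F))

r = np.einsum('lp,lkpd->lkd', x, w)            # per-step, per-offset results
# Shift-and-sum, as in forward() with 'VALID' padding.
out = sum(r[(K - 1) - i : L - i, i, :] for i in range(K))

# Direct computation of the same locally-connected layer.
direct = np.zeros((L - K + 1, F))
for l in range(L - K + 1):
    for j in range(K):
        direct[l] += x[l + j] @ w[l + j, (K - 1) - j]

assert np.allclose(out, direct)
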
+ """ + shape_w = ( + input_signature.shape[-2], + self._kernel_size, + input_signature.shape[-1], + self._filters, + ) + if self._padding == "VALID": + shape_b = ( + input_signature.shape[-2] - self._kernel_size + 1, + self._filters, + ) + else: + shape_b = ( + input_signature.shape[-2], + self._filters, + ) - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. + rng_w, rng_b = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(shape_w, rng_w, nonreceptive_dims=[0]) - Returns: - Tensor of same shape and dtype as the input, except the final dimension - is the layer's `filters` value, and the second to last dimension is - shrinked if 'VALID' padding is used with kernel_size bigger than one. - """ - if self._use_bias: - if not isinstance(self.weights, (tuple, list)): - raise ValueError(f'Weights should be a (w, b) tuple or list; ' - f'instead got: {self.weights}') - w, b = self.weights - else: - w = self.weights - - linear_results_before_shifting = jnp.einsum( - '...lp,lkpd->...lkd', x, w) - # TODO(jaszczur): this could be run after padding for better efficiency - - if self._kernel_size == 1: - # With kernel size 1 we don't have to split or shift anything. - linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2) - else: - # We computed a result for every "pixel", but each direction from the - # receptive field (there are 'self._kernel_size' such directions) must be - # shifted by a different amount. The easiest way to do it is to split - # the tensor to 'self._kernel_size' smaller tensors, shift each one - # appropriately, and then sum them together. - split_shifting_linear_results = jnp.split( - linear_results_before_shifting, self._kernel_size, axis=-2) - - for i in range(self._kernel_size): - # Each tensor has to be shifted a different amount. - if self._padding == 'WRAP': - # We can shift by padding and cutting. With 'wrap' padding we - # essentially have a torus. - padding = [(0, 0) for i in split_shifting_linear_results[i].shape] - padding[-3] = ((self._kernel_size - 1) - i, i) - split_shifting_linear_results[i] = jnp.pad( - split_shifting_linear_results[i], padding, mode='wrap') - split_shifting_linear_results[i] = split_shifting_linear_results[i][ - ..., (self._kernel_size-1)//2:-(self._kernel_size-1)//2, :, :] - elif self._padding == 'SAME': - # We can shift by padding and cutting. - padding = [(0, 0) for i in split_shifting_linear_results[i].shape] - padding[-3] = ((self._kernel_size - 1) - i, i) - split_shifting_linear_results[i] = jnp.pad( - split_shifting_linear_results[i], padding) - split_shifting_linear_results[i] = split_shifting_linear_results[i][ - ..., (self._kernel_size-1)//2:-(self._kernel_size-1)//2, :, :] - # TODO(jaszczur): improve efficiency by not padding things to cut - elif self._padding == 'VALID': - # We don't need to shift - just cut the leftmost and rightmost values. - cut_left = (self._kernel_size - 1) - i - cut_right = split_shifting_linear_results[i].shape[-3] - i - split_shifting_linear_results[i] = split_shifting_linear_results[i][ - ..., cut_left:cut_right, :, :] + if self._use_bias: + b = self._bias_initializer(shape_b, rng_b) + self.weights = (w, b) else: - raise ValueError(f'Invalid padding {self._padding}') - # After shifting. 
-def Flatten(n_axes_to_keep=1):
-  """Returns a layer that combines one or more trailing axes of a tensor.
-
-  Flattening keeps all the values of the input tensor, but reshapes it by
-  collapsing one or more trailing axes into a single axis. For example, a
-  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
-  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.
-
-  Args:
-    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
-      collapse only the axes after these.
-  """
-  layer_name = f'Flatten_keep{n_axes_to_keep}'
-  def f(x):  # pylint: disable=invalid-name
-    in_rank = len(x.shape)
-    if in_rank <= n_axes_to_keep:
-      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
-                       f'axes to keep ({n_axes_to_keep}) after flattening.')
-    shape = x.shape
-    if isinstance(shape, tf.TensorShape):
-      shape = tuple(shape.as_list())
-    return jnp.reshape(x, (shape[:n_axes_to_keep] + (-1,)))
-  return Fn(layer_name, f)

+def Flatten(n_axes_to_keep=1):
+    """Returns a layer that combines one or more trailing axes of a tensor.
+
+    Flattening keeps all the values of the input tensor, but reshapes it by
+    collapsing one or more trailing axes into a single axis. For example, a
+    `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
+    `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.
+
+    Args:
+        n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
+            collapse only the axes after these.
+    """
+    layer_name = f"Flatten_keep{n_axes_to_keep}"
+
+    def f(x):  # pylint: disable=invalid-name
+        in_rank = len(jnp.shape(x))
+        if in_rank <= n_axes_to_keep:
+            raise ValueError(
+                f"Input rank ({in_rank}) must exceed the number of "
+                f"axes to keep ({n_axes_to_keep}) after flattening."
+            )
+        shape = jnp.shape(x)
+        if isinstance(shape, tf.TensorShape):
+            shape = tuple(shape.as_list())
+        return jnp.reshape(x, (shape[:n_axes_to_keep] + (-1,)))
+
+    return Fn(layer_name, f)
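
# Quick demonstration of Flatten, with values mirroring the deleted unit
# tests: the leading n_axes_to_keep axes stay, the rest collapse into one.
import numpy as np
import trax.layers as tl

x = np.ones((2, 3, 5, 7, 11))
assert tl.Flatten()(x).shape == (2, 3 * 5 * 7 * 11)
assert tl.Flatten(n_axes_to_keep=2)(x).shape == (2, 3, 385)   # 5 * 7 * 11
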
def LogSoftmax(axis=-1):
-  """Returns a layer that applies log softmax along one tensor axis.
-
-  Note that the implementation actually computes x - LogSumExp(x),
-  which is mathematically equal to LogSoftmax(x).
-
-  `LogSoftmax` acts on a group of values and normalizes them to look like a set
-  of log probability values. (Probability values must be non-negative, and as
-  a set must sum to 1. A group of log probability values can be seen as the
-  natural logarithm function applied to a set of probability values.)
-
-  Args:
-    axis: Axis along which values are grouped for computing log softmax.
-  """
-  return Fn('LogSoftmax', lambda x: log_softmax(x, axis=axis))
+    """Returns a layer that applies log softmax along one tensor axis.
+
+    Note that the implementation actually computes x - LogSumExp(x),
+    which is mathematically equal to LogSoftmax(x).
+
+    `LogSoftmax` acts on a group of values and normalizes them to look like a set
+    of log probability values. (Probability values must be non-negative, and as
+    a set must sum to 1. A group of log probability values can be seen as the
+    natural logarithm function applied to a set of probability values.)
+
+    Args:
+        axis: Axis along which values are grouped for computing log softmax.
+    """
+    return Fn("LogSoftmax", lambda x: log_softmax(x, axis=axis))


def LogSumExp(axis=-1):
-  """Returns a layer that computes log(sum(exp(x))) along one tensor axis.
-
-  Args:
-    axis: Axis along which values are grouped for computing log-sum-exp.
-  """
-  return Fn('LogSumExp',
-            lambda x: fastmath.logsumexp(x, axis=axis, keepdims=True))
+    """Returns a layer that computes log(sum(exp(x))) along one tensor axis.
+
+    Args:
+        axis: Axis along which values are grouped for computing log-sum-exp.
+    """
+    return Fn("LogSumExp", lambda x: fastmath.logsumexp(x, axis=axis, keepdims=True))


def Softmax(axis=-1):
-  """Returns a layer that applies softmax along one tensor axis.
-
-  `Softmax` acts on a group of values and normalizes them to look like a set
-  of probability values. (Probability values must be non-negative, and as a
-  set must sum to 1.)
-
-  Args:
-    axis: Axis along which values are grouped for computing softmax.
-  """
-  return Fn('Softmax',
-            lambda x: jnp.exp(log_softmax(x, axis=axis)))
+    """Returns a layer that applies softmax along one tensor axis.
+
+    `Softmax` acts on a group of values and normalizes them to look like a set
+    of probability values. (Probability values must be non-negative, and as a
+    set must sum to 1.)
+
+    Args:
+        axis: Axis along which values are grouped for computing softmax.
+    """
+    return Fn("Softmax", lambda x: jnp.exp(log_softmax(x, axis=axis)))


def ToFloat():
-  """Returns a layer that changes the dtype of a tensor to `float32`."""
-  return Fn('ToFloat', lambda x: x.astype(np.float32))
+    """Returns a layer that changes the dtype of a tensor to `float32`."""
+    return Fn("ToFloat", lambda x: x.astype(np.float32))
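
# Numerical sketch of the identity behind LogSoftmax above: log(softmax(x))
# equals x - logsumexp(x), which stays finite even where a naive softmax
# would overflow.
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])   # np.exp(x) would overflow here
log_probs = x - np.logaddexp.reduce(x)   # stable log-softmax
assert np.isclose(np.exp(log_probs).sum(), 1.0)
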
- """ - return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing a mean. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Mean", lambda x: jnp.mean(x, axis=axis, keepdims=keepdims)) def Min(axis=-1, keepdims=False): - """Returns a layer that applies min along one tensor axis. + """Returns a layer that applies min along one tensor axis. - Args: - axis: Axis along which values are grouped for computing minimum. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. - """ - return Fn('Min', lambda x: jnp.min(x, axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing minimum. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Min", lambda x: jnp.min(x, axis, keepdims=keepdims)) def Max(axis=-1, keepdims=False): - """Returns a layer that applies max along one tensor axis. + """Returns a layer that applies max along one tensor axis. - Args: - axis: Axis along which values are grouped for computing maximum. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. - """ - return Fn('Max', lambda x: jnp.max(x, axis, keepdims=keepdims)) + Args: + axis: Axis along which values are grouped for computing maximum. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Max", lambda x: jnp.max(x, axis, keepdims=keepdims)) def Sum(axis=None, keepdims=False): - """Returns a layer that computes sums using one tensor axis. + """Returns a layer that computes sums using one tensor axis. + + `Sum` uses one tensor axis to form groups of values and replaces each group + with the sum of that group. The resulting sum values can either remain in + their own size 1 axis (`keepdims=True`), or that axis can be removed from the + overall tensor (default `keepdims=False`), lowering the rank of the tensor by + one. - `Sum` uses one tensor axis to form groups of values and replaces each group - with the sum of that group. The resulting sum values can either remain in - their own size 1 axis (`keepdims=True`), or that axis can be removed from the - overall tensor (default `keepdims=False`), lowering the rank of the tensor by - one. + Args: + axis: Axis along which values are grouped for computing a sum; if None, + compute sum over all elements in tensor. + keepdims: If `True`, keep the resulting size 1 axis as a separate tensor + axis; else, remove that axis. + """ + return Fn("Sum", lambda x: jnp.sum(x, axis=axis, keepdims=keepdims)) - Args: - axis: Axis along which values are grouped for computing a sum; if None, - compute sum over all elements in tensor. - keepdims: If `True`, keep the resulting size 1 axis as a separate tensor - axis; else, remove that axis. 
- """ - return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims)) +def ThresholdToBinary(threshold=0.5): + """Returns a layer that thresholds inputs to yield outputs in {0, 1}.""" -def ThresholdToBinary(threshold=.5): - """Returns a layer that thresholds inputs to yield outputs in {0, 1}.""" - def f(model_output): # pylint: disable=invalid-name - return (model_output > threshold).astype(jnp.int32) - return Fn('ThresholdToBinary', f) + def f(model_output): # pylint: disable=invalid-name + return (model_output > threshold).astype(jnp.int32) + + return Fn("ThresholdToBinary", f) def ArgMax(axis=-1): - """Returns a layer that calculates argmax along the given axis.""" - def f(model_output): # pylint: disable=invalid-name - return jnp.argmax(model_output, axis=axis) - return Fn('ArgMax', f) + """Returns a layer that calculates argmax along the given axis.""" + + def f(model_output): # pylint: disable=invalid-name + return jnp.argmax(model_output, axis=axis) + + return Fn("ArgMax", f) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def Negate(): - """Returns a layer that computes the element-wise negation of a tensor.""" - return Fn('Negate', lambda x: -x) + """Returns a layer that computes the element-wise negation of a tensor.""" + return Fn("Negate", lambda x: -x) -@assert_shape('...->...') # The output and input shapes are the same. +@assert_shape("...->...") # The output and input shapes are the same. def StopGradient(): - """Returns an identity layer with a stop gradient.""" - return Fn('StopGradient', lambda x: fastmath.stop_gradient(x)) # pylint: disable=unnecessary-lambda + """Returns an identity layer with a stop gradient.""" + return Fn( + "StopGradient", lambda x: fastmath.stop_gradient(x) + ) # pylint: disable=unnecessary-lambda def one_hot(x, n_categories, dtype=jnp.float32): # pylint: disable=invalid-name - """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).""" - indices_less_than_n = jnp.arange(n_categories) - return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype) + """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).""" + indices_less_than_n = jnp.arange(n_categories) + mask = jnp.equal(x[..., jnp.newaxis], indices_less_than_n) + return jnp.array(mask, dtype) def log_softmax(x, axis=-1): # pylint: disable=invalid-name - """Transforms activation vectors to log-probability vectors. + """Transforms activation vectors to log-probability vectors. - Log probability vectors are derived by, in effect, applying softmax to raw - activation vectors and then applying log element-wise. The actual - implementation uses a mathematically valid simplification of this. + Log probability vectors are derived by, in effect, applying softmax to raw + activation vectors and then applying log element-wise. The actual + implementation uses a mathematically valid simplification of this. - Args: - x: An ndarray with activation vectors along the given axis. - axis: Axis along which values are grouped for computing log softmax. + Args: + x: An ndarray with activation vectors along the given axis. + axis: Axis along which values are grouped for computing log softmax. - Returns: - An ndarray containing log-probability vectors derived from the raw - activation vectors in `x`. - """ - return x - fastmath.logsumexp(x, axis=axis, keepdims=True) + Returns: + An ndarray containing log-probability vectors derived from the raw + activation vectors in `x`. 
+ """ + return x - fastmath.logsumexp(x, axis=axis, keepdims=True) def log_gaussian_pdf(x, mu, sigma): # pylint: disable=invalid-name - """Returns `log N(x | mu, sigma)`. - - Args: - x: - mu: - sigma: - """ - a = mu.shape[-1] * jnp.log(2 * jnp.pi) - _, b = jnp.linalg.slogdet(sigma) - y = jnp.linalg.solve(sigma, x - mu) - y = jnp.expand_dims(y, axis=-1) - xm = jnp.expand_dims(x - mu, axis=-2) - c = jnp.matmul(xm, y) - c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) - return -0.5 * (a + b + c) + """Returns `log N(x | mu, sigma)`. + + Args: + x: + mu: + sigma: + """ + a = mu.shape[-1] * jnp.log(2 * jnp.pi) + _, b = jnp.linalg.slogdet(sigma) + y = jnp.linalg.solve(sigma, jnp.expand_dims(x - mu, axis=-1)) + y = y.squeeze(-1) + y = jnp.expand_dims(y, axis=-1) + xm = jnp.expand_dims(x - mu, axis=-2) + c = jnp.matmul(xm, y) + c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) + return -0.5 * (a + b + c) def log_gaussian_diag_pdf(x, mu, diag_sigma): # pylint: disable=invalid-name - """Returns `log N(x | mu, eye(diag_sigma))`. - - Args: - x: - mu: - diag_sigma: - """ - a = mu.shape[-1] * jnp.log(2 * jnp.pi) - b = jnp.sum(jnp.log(diag_sigma), axis=-1) - y = x - mu / diag_sigma - y = jnp.expand_dims(y, axis=-1) - xm = jnp.expand_dims(x - mu, axis=-2) - c = jnp.matmul(xm, y) - c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) - return -0.5 * (a + b + c) + """Returns `log N(x | mu, eye(diag_sigma))`. + + Args: + x: + mu: + diag_sigma: + """ + a = mu.shape[-1] * jnp.log(2 * jnp.pi) + b = jnp.sum(jnp.log(diag_sigma), axis=-1) + y = x - mu / diag_sigma + y = jnp.expand_dims(y, axis=-1) + xm = jnp.expand_dims(x - mu, axis=-2) + c = jnp.matmul(xm, y) + c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1) + return -0.5 * (a + b + c) def multigaussian_loss(preds, targets, ngauss=1): # pylint: disable=invalid-name - """Returns a mixture of gaussians loss. - - Args: - preds: - targets: - ngauss: - """ - ndims = targets.shape[-1] - logits = preds[:, :ngauss] - mus = preds[:, ngauss:ngauss*(ndims + 1)] - sigmas = preds[:, ngauss(ndims + 1):] - sigmas = sigmas * sigmas + 1e-6 # Make positive. - loglogits = logits - fastmath.logsumexp(logits, axis=-1, keepdims=True) - mus = jnp.reshape(mus, [-1, ngauss, ndims]) - sigmas = jnp.reshape(sigmas, [-1, ngauss, ndims]) - targets = jnp.reshape(targets, [-1, 1, ndims]) - glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas) - return fastmath.logsumexp(loglogits + glogprobs, axis=-1) + """Returns a mixture of gaussians loss. + + Args: + preds: + targets: + ngauss: + """ + ndims = targets.shape[-1] + logits = preds[:, :ngauss] + mus = preds[:, ngauss : ngauss * (ndims + 1)] + sigmas = preds[:, ngauss(ndims + 1) :] + sigmas = sigmas * sigmas + 1e-6 # Make positive. + loglogits = logits - fastmath.logsumexp(logits, axis=-1, keepdims=True) + mus = jnp.reshape(mus, [-1, ngauss, ndims]) + sigmas = jnp.reshape(sigmas, [-1, ngauss, ndims]) + targets = jnp.reshape(targets, [-1, 1, ndims]) + glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas) + return fastmath.logsumexp(loglogits + glogprobs, axis=-1) # TODO(jonni): Rename to log_softmax_sample. def logsoftmax_sample(log_probs, temperature=1.0): # pylint: disable=invalid-name - """Returns a sample from a log-softmax output, with temperature. - - Args: - log_probs: Logarithms of probabilities (often coming from LogSoftmax) - temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax) - """ - # This is equivalent to sampling from a softmax with temperature. 
# TODO(jonni): Rename to log_softmax_sample.
def logsoftmax_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
-  """Returns a sample from a log-softmax output, with temperature.
-
-  Args:
-    log_probs: Logarithms of probabilities (often coming from LogSoftmax)
-    temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax)
-  """
-  # This is equivalent to sampling from a softmax with temperature.
-  u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
-  g = -np.log(-np.log(u))
-  return np.argmax(log_probs + g * temperature, axis=-1)
+    """Returns a sample from a log-softmax output, with temperature.
+
+    Args:
+        log_probs: Logarithms of probabilities (often coming from LogSoftmax)
+        temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax)
+    """
+    # This is equivalent to sampling from a softmax with temperature.
+    u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
+    g = -np.log(-np.log(u))
+    return np.argmax(log_probs + g * temperature, axis=-1)
diff --git a/trax/layers/core_test.py b/trax/layers/core_test.py
deleted file mode 100644
index 85143f0ca..000000000
--- a/trax/layers/core_test.py
+++ /dev/null
@@ -1,492 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for core layers."""
-
-from absl.testing import absltest
-import numpy as np
-
-from trax import shapes
-from trax.fastmath import numpy as jnp
-import trax.layers as tl
-import trax.layers.initializers as init
-
-
-class DenseTest(absltest.TestCase):
-  """Test Dense layer per se and as a key example of trainable layers."""
-
-  def test_call_before_init_raises_error(self):
-    layer = tl.Dense(5)
-    x = np.array([1, 2, 3])
-
-    # Without init, layer lacks the weights it needs for forward computation.
-    with self.assertRaises(tl.LayerError):
-      _ = layer(x)
-
-  def test_call_uses_and_caches_supplied_weights(self):
-    layer = tl.Dense(4)
-    x = np.array([2, 3])
-
-    # Weights from random initialization are cached in the layer.
-    _, _ = layer.init(shapes.signature(x))
-    w_init, b_init = layer.weights
-
-    # Call the layer with externally specified weights.
-    w = np.array([[10000, 20000, 30000, 40000], [100, 200, 100, 200]])
-    b = np.array([9, 8, 7, 6])
-    y = layer(x, weights=(w, b))
-
-    # Using weights keyword arg overrides any previous cached weights ...
-    self.assertEqual(y.tolist(), [20309, 40608, 60307, 80606])
-    self.assertNotEqual(w.tolist(), w_init.tolist())
-    self.assertNotEqual(b.tolist(), b_init.tolist())
-
-    # ... and do not over-write the old weights.
-    w_cached, b_cached = layer.weights
-    self.assertNotEqual(w.tolist(), w_cached.tolist())
-    self.assertNotEqual(b.tolist(), b_cached.tolist())
-
-  def test_separate_instances_have_separate_weights(self):
-    # Two dense layer instances: each will get its own initial weights (w, b).
-    model = tl.Serial(tl.Dense(5), tl.Dense(5))
-
-    sample_input = np.array([1, 2, 3, 4, 5])
-    _, _ = model.init(shapes.signature(sample_input))
-    weights_0 = model.sublayers[0].weights
-    weights_1 = model.sublayers[1].weights
-
-    w0, b0 = weights_0
-    w1, b1 = weights_1
-    self.assertNotEqual(w0.tolist(), w1.tolist())
-    self.assertNotEqual(b0.tolist(), b1.tolist())
-
-  def test_shared_instance_means_shared_weights(self):
-    # Same dense layer instance in two places --> shared weights.
- layer = tl.Dense(5) - model = tl.Serial(layer, layer) - sample_input = np.array([1, 2, 3, 4, 5]) - weights, _ = model.init(shapes.signature(sample_input)) - self.assertIs(weights[1], tl.GET_WEIGHTS_FROM_CACHE) - - def test_call_no_bias(self): - layer = tl.Dense(4, use_bias=False) - x = np.array([2, 5, 3]) - _, _ = layer.init(shapes.signature(x)) - - w = np.array([[100, 200, 300, 400], [10, 10, 10, 10], [1, 2, 1, 2]]) - y = layer(x, weights=w) - self.assertEqual(y.tolist(), [253, 456, 653, 856]) - - def test_new_weights_use_bias(self): - layer = tl.Dense(4) - x = np.array([1, 2]) - _, _ = layer.init(shapes.signature(x)) - self.assertLen(layer.weights, 2) - self.assertEqual(layer.weights[0].shape, (2, 4)) - self.assertEqual(layer.weights[1].shape, (4,)) - - def test_new_weights_no_bias(self): - layer = tl.Dense(4, use_bias=False) - x = np.array([1, 2]) - _, _ = layer.init(shapes.signature(x)) - self.assertEqual(layer.weights.shape, (2, 4)) - - def test_init_twice_weights_same_shape(self): - layer = tl.Dense(4, use_bias=False) - x = np.array([1, 2]) - w1, _ = layer.init(shapes.signature(x)) - w2, _ = layer.init(shapes.signature(x)) - self.assertEqual(w1.shape, (2, 4)) - self.assertEqual(w2.shape, (2, 4)) - - def test_save_to_file_and_init_to_file(self): - layer1 = tl.Dense(4, use_bias=False) - layer2 = tl.Dense(4, use_bias=False) - x = np.array([1, 2]) - w1, _ = layer1.init(shapes.signature(x)) - layer1.save_to_file('/tmp/dense_weights', - input_signature=shapes.signature(x)) - w2, _ = layer2.init_from_file('/tmp/dense_weights') - self.assertEqual(w1.shape, (2, 4)) - self.assertEqual(w2.shape, (2, 4)) - self.assertEqual(w1.tolist(), w2.tolist()) - - -class EmbeddingTest(absltest.TestCase): - - def test_forward(self): - layer = tl.Embedding(10, 3) # vocab_size=10, d_feature=3 - _, _ = layer.init(None) # Embedding init doesn't use input signature. - x = np.array([2, 3, 5, 3, 2]) - y = layer(x) - self.assertEqual(y.shape, (5, 3)) - - # For distinct in-domain token IDs, resulting vectors should be distinct. - self.assertNotEqual(y[0].tolist(), y[1].tolist()) - self.assertNotEqual(y[0].tolist(), y[2].tolist()) - self.assertNotEqual(y[1].tolist(), y[2].tolist()) - - # For repeats of a token id, resulting vectors should match. - self.assertEqual(y[0].tolist(), y[4].tolist()) - self.assertEqual(y[1].tolist(), y[3].tolist()) - - def test_negative_inputs_clip_to_zero(self): - layer = tl.Embedding(10, 3) - _, _ = layer.init(None) - x = np.array([0, 2, 3, -2, -3]) - y = layer(x) - self.assertNotEqual(y[0].tolist(), y[1].tolist()) - self.assertNotEqual(y[0].tolist(), y[2].tolist()) - self.assertEqual(y[0].tolist(), y[3].tolist()) - self.assertEqual(y[0].tolist(), y[4].tolist()) - - def test_large_inputs_clip_to_upper_bound(self): - layer = tl.Embedding(10, 3) - _, _ = layer.init(None) - x = np.array([2, 3, 9, 10, 20]) - y = layer(x) - - # vocab_size of 10 means max valid token id is 9. - self.assertNotEqual(y[2].tolist(), y[0].tolist()) - self.assertNotEqual(y[2].tolist(), y[1].tolist()) - self.assertEqual(y[2].tolist(), y[3].tolist()) - self.assertEqual(y[2].tolist(), y[4].tolist()) - - def test_new_weights(self): - layer = tl.Embedding(20, 5) - _, _ = layer.init(None) - - # Default weights sampled from Gaussian, mu = 0, sigma = 1. 
- w = layer.weights - self.assertEqual(w.shape, (20, 5)) - self.assertLess(np.abs(np.mean(w)), .4) # .4 is 4 sigma deviation - - def test_explicit_kernel_initializer(self): - - def f(shape, rng): - del rng - n_elements = np.prod(shape) - return np.arange(n_elements).reshape(shape) - - layer = tl.Embedding(5, 2, kernel_initializer=f) - _, _ = layer.init(None) - x = np.array([0, 1, 2, 3, 4]) - y = layer(x) - self.assertEqual(y.tolist(), [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) - - -class DropoutTest(absltest.TestCase): - - def test_call_in_train_mode(self): - layer = tl.Dropout(rate=0.1, mode='train') - x = np.ones((2, 5, 1000)) # 10,000 values - y = layer(x) - self.assertEqual(y.shape, (2, 5, 1000)) - - # Dropout is stochastic; test it nonflakily at 4 sigmas (.99994). - n_remaining = np.count_nonzero(y) - mu_of_remaining = 9000 # N * q: 10000 * .9 - sigma_of_remaining = 30 # sqrt(N * p * q): sqrt(10000 * .1 * .9) - self.assertLess( - np.abs(n_remaining - mu_of_remaining), 4 * sigma_of_remaining) - - def test_call_in_eval_mode_does_no_dropout(self): - layer = tl.Dropout(rate=0.1, mode='eval') - x = np.ones((2, 5, 1000)) - y = layer(x) - self.assertEqual(np.count_nonzero(y), 10_000) - - def test_new_weights(self): - layer = tl.Dropout(rate=0.1, mode='train') - layer.init(None) - self.assertEmpty(layer.weights) - - -class WeightsTest(absltest.TestCase): - """Test Weights layer.""" - - def test_simple(self): - layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32)) - layer.init(()) - y = layer(()) - self.assertEqual(y.tolist(), 0.) - - def test_shape(self): - layer = tl.Weights(init.RandomNormalInitializer(), (5, 10, 3)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (5, 10, 3)) - - def test_simple_custom_initializer(self): - layer = tl.Weights(init.RandomNormalInitializer()) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, ()) - self.assertNotEqual(y.tolist(), 0.) - - def test_custom_initializer_shape(self): - layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), - (2, 2)) - layer.init(()) - y = layer(()) - self.assertEqual(y.tolist(), [[0., 0.], - [0., 0.]]) - - layer = tl.Weights(init.RandomNormalInitializer(), (2, 2)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (2, 2)) - self.assertNotEqual(y.tolist(), [[0., 0.], - [0., 0.]]) - - -class SummaryScalarTest(absltest.TestCase): - - def test_passes(self): - layer = tl.SummaryScalar('test') - x = np.array([[3., 5.], [2., 6.]]) # 10,000 values - y = layer(x) - self.assertEqual(y.tolist(), [[3., 5.], [2., 6.]]) - self.assertEqual(layer.state['summary_test'].tolist(), 4.0) - - -class RandomUniformTest(absltest.TestCase): - """Test Weights layer.""" - - def test_simple(self): - layer = tl.RandomUniform() - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, ()) - self.assertBetween(y, 0.0, 1.0) - - def test_shape(self): - layer = tl.RandomUniform(shape=(5, 10, 3)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (5, 10, 3)) - - def test_simple_range(self): - layer = tl.RandomUniform(1., 2., shape=(1000,)) - layer.init(()) - y = layer(()) - self.assertEqual(y.shape, (1000,)) - self.assertBetween(min(y.tolist()), 1., 2.) - self.assertBetween(max(y.tolist()), 1., 2.) 
- self.assertBetween(1.5, min(y.tolist()), max(y.tolist())) - - -class LocallyConnected1dTest(absltest.TestCase): - - def test_shape_kernel1(self): - for padding in ['WRAP', 'SAME', 'VALID']: - layer = tl.LocallyConnected1d(6, 1, padding=padding) - x = np.array([[0, 1], [2, 3], [4, 5]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (3, 6)) - - def test_shape_kernel3(self): - for padding in ['WRAP', 'SAME']: - layer = tl.LocallyConnected1d(6, 3, padding=padding) - x = np.array([[0, 1], [2, 3], [4, 5]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (3, 6)) - - for padding in ['VALID']: - layer = tl.LocallyConnected1d(6, 3, padding=padding) - x = np.array([[0, 1], [2, 3], [4, 5]]) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (1, 6)) - - -class FlattenTest(absltest.TestCase): - - def test_keep_default(self): - layer = tl.Flatten() - x = np.ones((1, 2, 3, 4, 5)) - y = layer(x) - # Default is leave first axis untouched, flatten the rest. - self.assertEqual(y.shape, (1, 2 * 3 * 4 * 5)) - - def test_keep_3(self): - layer = tl.Flatten(n_axes_to_keep=3) - x = np.ones((1, 2, 3, 4, 5)) - y = layer(x) - self.assertEqual(y.shape, (1, 2, 3, 4 * 5)) - - def test_keep_max_number(self): - layer = tl.Flatten(n_axes_to_keep=4) - x = np.ones((1, 2, 3, 4, 5)) - y = layer(x) - self.assertEqual(y.shape, (1, 2, 3, 4, 5)) - - def test_keep_too_many_raises_error(self): - layer = tl.Flatten(n_axes_to_keep=5) - with self.assertRaises(tl.LayerError): - x = np.ones((1, 2, 3, 4, 5)) - _ = layer(x) - - -class LogSoftmaxTest(absltest.TestCase): - - def test_call(self): - layer = tl.LogSoftmax() - x = np.array([[2., 1., -10.], - [1., 1., -10.]]) - y = layer(x) - np.testing.assert_allclose(y, - [[-0.313, -1.313, -12.313], - [-0.693, -0.693, -11.693]], - atol=.001) - - -class SoftmaxTest(absltest.TestCase): - - def test_call(self): - layer = tl.Softmax() - x = np.array([[2., 1., -10.], - [1., 1., -10.]]) - y = layer(x) - np.testing.assert_allclose(y, - [[.731, .269, .00000449], - [.500, .500, .00000835]], - atol=.001) - - -class CoreFunctionsTest(absltest.TestCase): - - def test_one_hot(self): - targets = np.array([2, 0, 1]) - n_categories = 5 - target_distributions = tl.one_hot(targets, n_categories) - self.assertEqual(tl.to_list(target_distributions), - [[0., 0., 1., 0., 0.], - [1., 0., 0., 0., 0.], - [0., 1., 0., 0., 0.]]) - - def test_log_softmax(self): - activations = np.array([[2., 1., -10.], - [1., 1., -10.]]) - log_probabilities = tl.log_softmax(activations) - np.testing.assert_allclose(log_probabilities, - [[-0.313, -1.313, -12.313], - [-0.693, -0.693, -11.693]], - atol=.001) - - def test_log_gaussian_pdf(self): - x = np.zeros((2, 5), dtype=np.float32) - mu = x - dsigma = np.eye(5)[None, :, :] - sigma = np.concatenate([dsigma, 2 * dsigma], axis=0) - prob = tl.log_gaussian_pdf(x, mu, sigma) - self.assertEqual(prob.shape, (2,)) - self.assertEqual(int(prob[0]), -4) - self.assertEqual(int(prob[1]), -6) - - def test_log_gaussian_diag_pdf(self): - x = np.zeros((2, 5), dtype=np.float32) - mu = x - sigma = np.ones((5,))[None, :] - sigma = np.concatenate([sigma, 2 * sigma], axis=0) - prob = tl.log_gaussian_diag_pdf(x, mu, sigma) - self.assertEqual(prob.shape, (2,)) - self.assertEqual(int(prob[0]), -4) - self.assertEqual(int(prob[1]), -6) - - -class StopGradientTest(absltest.TestCase): - - def test_passes(self): - layer = tl.StopGradient() - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2, 2)) - 
self.assertEqual(y.tolist(), [[3., 5.], [2., 6.]]) - - -class MinMaxTest(absltest.TestCase): - - def test_min(self): - layer = tl.Min() - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [3., 2.]) - - layer = tl.Min(axis=0) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [2., 5.]) - - layer = tl.Min(axis=None) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, ()) - self.assertEqual(y.tolist(), 2.) - - layer = tl.Min(keepdims=True) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2, 1)) - self.assertEqual(y.tolist(), [[3.], [2.]]) - - def test_max(self): - layer = tl.Max() - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [5., 6.]) - - layer = tl.Max(axis=0) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (2,)) - self.assertEqual(y.tolist(), [3., 6.]) - - layer = tl.Max(axis=None) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, ()) - self.assertEqual(y.tolist(), 6.) - - layer = tl.Max(axis=0, keepdims=True) - x = np.array([[3., 5.], [2., 6.]]) - y = layer(x) - self.assertEqual(y.shape, (1, 2)) - self.assertEqual(y.tolist(), [[3., 6.]]) - - -class ClassifierLayersTest(absltest.TestCase): - - def test_threshold_to_binary(self): - layer = tl.ThresholdToBinary() - x = np.array([.30, .49, .50, .51, .70]) - y = layer(x) - self.assertEqual(y.tolist(), [0, 0, 0, 1, 1]) - - def test_arg_max(self): - layer = tl.ArgMax() - x = np.array([[.10, .90, .20, .80], - [.22, .88, .11, .99]]) - y = layer(x) - self.assertEqual(y.tolist(), [1, 3]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/deconvolution.py b/trax/layers/deconvolution.py index 05c5a26de..645e8d7a6 100644 --- a/trax/layers/deconvolution.py +++ b/trax/layers/deconvolution.py @@ -20,6 +20,7 @@ import operator from jax import lax + from trax import fastmath from trax.fastmath import numpy as jnp from trax.layers import base @@ -27,68 +28,84 @@ class ConvTranspose(base.Layer): - """Layer constructor function for a general Transpose Convolutional Layer.""" + """Layer constructor function for a general Transpose Convolutional Layer.""" - def __init__(self, - filters, - kernel_size, - strides=None, - padding='VALID', - rhs_dilation=None, - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - kernel_initialzer=None, - bias_initializer=init.RandomNormalInitializer(1e-6)): - super(ConvTranspose, self).__init__() - self._filters = filters - self._kernel_size = kernel_size - self._padding = padding - self._rhs_dilation = rhs_dilation - self._dimension_numbers = dimension_numbers - self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers - self._one = (1,) * len(kernel_size) - self._strides = strides or self._one - self._bias_initializer = bias_initializer - rhs_spec = self._rhs_spec - self._kernel_initializer = kernel_initialzer - if kernel_initialzer is None: - self._kernel_initializer = init.GlorotNormalInitializer( - rhs_spec.index('O'), rhs_spec.index('I')) + def __init__( + self, + filters, + kernel_size, + strides=None, + padding="VALID", + rhs_dilation=None, + dimension_numbers=("NHWC", "HWIO", "NHWC"), + kernel_initialzer=None, + bias_initializer=init.RandomNormalInitializer(1e-6), + ): + super(ConvTranspose, self).__init__() + self._filters = filters + self._kernel_size = kernel_size + 
self._padding = padding + self._rhs_dilation = rhs_dilation + self._dimension_numbers = dimension_numbers + self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers + self._one = (1,) * len(kernel_size) + self._strides = strides or self._one + self._bias_initializer = bias_initializer + rhs_spec = self._rhs_spec + self._kernel_initializer = kernel_initialzer + if kernel_initialzer is None: + self._kernel_initializer = init.GlorotNormalInitializer( + rhs_spec.index("O"), rhs_spec.index("I") + ) - def _check_nhwc(self): - msg = 'Deconvolutions on more than 4 dimensions only supported in NHWC.' - assert self._lhs_spec == self._out_spec == 'NHWC', msg + def _check_nhwc(self): + msg = "Deconvolutions on more than 4 dimensions only supported in NHWC." + assert self._lhs_spec == self._out_spec == "NHWC", msg - def forward(self, x): - w, b = self.weights - x_shape = list(x.shape) - if len(x_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, x.shape[:-3]) - x = jnp.reshape(x, [new_batch_dim] + list(x.shape[-3:])) - res = lax.conv_transpose(x, w, self._strides, self._padding, - self._rhs_dilation, self._dimension_numbers) + b - if len(x_shape) > 4: - res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) - return res + def forward(self, x): + w, b = self.weights + x_shape = list(x.shape) + if len(x_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, x.shape[:-3]) + x = jnp.reshape(x, [new_batch_dim] + list(x.shape[-3:])) + res = ( + lax.conv_transpose( + x, + w, + self._strides, + self._padding, + self._rhs_dilation, + self._dimension_numbers, + ) + + b + ) + if len(x_shape) > 4: + res = jnp.reshape(res, x_shape[:-3] + list(res.shape[-3:])) + return res - def _kernel_shape(self, input_shape): - """Helper to calculate the kernel shape.""" - kernel_size_iter = iter(self._kernel_size) - return [ - self._filters if c == 'O' else input_shape[self._lhs_spec.index('C')] - if c == 'I' else next(kernel_size_iter) for c in self._rhs_spec - ] + def _kernel_shape(self, input_shape): + """Helper to calculate the kernel shape.""" + kernel_size_iter = iter(self._kernel_size) + return [ + self._filters + if c == "O" + else input_shape[self._lhs_spec.index("C")] + if c == "I" + else next(kernel_size_iter) + for c in self._rhs_spec + ] - def init_weights_and_state(self, input_signature): - input_shape = input_signature.shape - if len(input_shape) > 4: - self._check_nhwc() - new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) - input_shape = [new_batch_dim] + list(input_shape[-3:]) - kernel_shape = self._kernel_shape(input_shape) - bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec] - bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) - rng1, rng2 = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer(kernel_shape, rng1) - b = self._bias_initializer(bias_shape, rng2) - self.weights = (w, b) + def init_weights_and_state(self, input_signature): + input_shape = input_signature.shape + if len(input_shape) > 4: + self._check_nhwc() + new_batch_dim = functools.reduce(operator.mul, input_shape[:-3]) + input_shape = [new_batch_dim] + list(input_shape[-3:]) + kernel_shape = self._kernel_shape(input_shape) + bias_shape = [self._filters if c == "C" else 1 for c in self._out_spec] + bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) + rng1, rng2 = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer(kernel_shape, rng1) + b = self._bias_initializer(bias_shape, 
rng2) + self.weights = (w, b) diff --git a/trax/layers/initializers.py b/trax/layers/initializers.py index 42e9220d5..f4a8c89c4 100644 --- a/trax/layers/initializers.py +++ b/trax/layers/initializers.py @@ -15,183 +15,195 @@ """Trax initializers.""" -from absl import logging - import numpy as np import tensorflow.compat.v2 as tf + +from absl import logging + from trax.fastmath import numpy as jnp from trax.fastmath import random def _GetFans(shape, out_dim=-1, in_dim=-2, nonreceptive_dims=None): - """Get the fan-in and fan-out sizes for the given shape and dims.""" - # Temporary fix until numpy.delete supports negative indices. - if out_dim < 0: - out_dim += len(shape) - if in_dim < 0: - in_dim += len(shape) - - if nonreceptive_dims is None: - nonreceptive_dims = [] - if not isinstance(nonreceptive_dims, (list, tuple)): - nonreceptive_dims = [nonreceptive_dims] - - receptive_field = jnp.prod(np.delete(shape, [in_dim, out_dim, - *nonreceptive_dims])) - if len(shape) >= 2: - fan_in, fan_out = shape[in_dim], shape[out_dim] - elif len(shape) == 1: - fan_in = shape[0] - fan_out = shape[0] - else: - fan_in = 1. - fan_out = 1. - fan_in *= receptive_field - fan_out *= receptive_field - return fan_in, fan_out + """Get the fan-in and fan-out sizes for the given shape and dims.""" + # Temporary fix until numpy.delete supports negative indices. + if out_dim < 0: + out_dim += len(shape) + if in_dim < 0: + in_dim += len(shape) + + if nonreceptive_dims is None: + nonreceptive_dims = [] + if not isinstance(nonreceptive_dims, (list, tuple)): + nonreceptive_dims = [nonreceptive_dims] + + receptive_field = jnp.prod(np.delete(shape, [in_dim, out_dim, *nonreceptive_dims])) + if len(shape) >= 2: + fan_in, fan_out = shape[in_dim], shape[out_dim] + elif len(shape) == 1: + fan_in = shape[0] + fan_out = shape[0] + else: + fan_in = 1.0 + fan_out = 1.0 + fan_in *= receptive_field + fan_out *= receptive_field + return fan_in, fan_out def InitializerFromFile(path): - """Loads parameters from .npy file.""" + """Loads parameters from .npy file.""" - def Initializer(shape, rng): - del rng - logging.info('Loading pretrained embeddings from %s', path) - with tf.io.gfile.GFile(path, 'rb') as f: - parameters = jnp.load(f) - assert jnp.shape(parameters) == shape, ( - 'Expected shape %s, got %s' % (shape, jnp.shape(parameters))) - return parameters + def Initializer(shape, rng): + del rng + logging.info("Loading pretrained embeddings from %s", path) + with tf.io.gfile.GFile(path, "rb") as f: + parameters = jnp.load(f) + assert jnp.shape(parameters) == shape, "Expected shape %s, got %s" % ( + shape, + jnp.shape(parameters), + ) + return parameters - return Initializer + return Initializer def _PureShape(shape): - """Make sure shape does not contain int tensors by calling int().""" - return [int(x) for x in shape] + """Make sure shape does not contain int tensors by calling int().""" + return [int(x) for x in shape] def RandomNormalInitializer(stddev=1e-2): - """Returns an initializer for random normal coefficients.""" - return lambda shape, rng: (stddev * random.normal( # pylint: disable=g-long-lambda - rng, _PureShape(shape)).astype('float32')) + """Returns an initializer for random normal coefficients.""" + return lambda shape, rng: ( + stddev + * random.normal(rng, _PureShape(shape)).astype( # pylint: disable=g-long-lambda + "float32" + ) + ) def RandomUniformInitializer(lim=1.0): - """Returns an initializer for random uniform coefficients.""" - # Make sure shape does not contain int tensors by calling int() below. 
-  return lambda shape, rng: random.uniform(  # pylint: disable=g-long-lambda
-      rng, _PureShape(shape), jnp.float32, -lim, lim)
+    """Returns an initializer for random uniform coefficients."""
+    # Make sure shape does not contain int tensors by calling int() below.
+    return lambda shape, rng: random.uniform(  # pylint: disable=g-long-lambda
+        rng, _PureShape(shape), jnp.float32, -lim, lim
+    )


def ScaledInitializer(out_dim, in_dim, scale, mode, distribution):
-  """Returns an initializer that adjusts its scale based on weight shapes."""
-  if scale <= 0.:
-    raise ValueError('scale must be positive float, {} given'.format(scale))
-  if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
-    raise ValueError(
-        'Invalid mode argument:, {}, must be either fan_in, fan_out or fan_avg'
-        .format(mode))
-
-  def Init(shape, rng, nonreceptive_dims=None):
-    """Returns random values for initializing weights of the given `shape`."""
-    shape = _PureShape(shape)
-    fan_in, fan_out = _GetFans(shape, out_dim, in_dim, nonreceptive_dims)
-    gain = scale
-    if mode == 'fan_in':
-      gain /= fan_in
-    elif mode == 'fan_out':
-      gain /= fan_out
-    elif mode == 'fan_avg':
-      gain /= (fan_in + fan_out) / 2
-    if distribution == 'truncated_normal':
-      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
-      stddev = jnp.sqrt(gain) / .87962566103423978
-      new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
-      return new_weights.astype('float32')
-    elif distribution == 'normal':
-      new_weights = random.normal(rng, shape) * jnp.sqrt(gain)
-      return new_weights.astype('float32')
-    elif distribution == 'uniform':
-      lim = jnp.sqrt(3. * gain)
-      return random.uniform(rng, shape, jnp.float32, -lim, lim)
-    else:
-      raise ValueError('invalid distribution for ScaleInitializer')
-
-  return Init
-
-
-def GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.):
-  """Returns an initializer for random Glorot-scaled coefficients."""
-  return ScaledInitializer(out_dim, in_dim, scale, 'fan_avg', 'normal')
-
-
-def GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.):
-  """Returns an initializer for random uniform Glorot-scaled coefficients."""
-  return ScaledInitializer(out_dim, in_dim, scale, 'fan_avg', 'uniform')
-
-
-def LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.):
-  """Returns an initializer for random LeCun-scaled coefficients."""
-  return ScaledInitializer(out_dim, in_dim, scale, 'fan_in', 'normal')
-
-
-def LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.):
-  """Returns an initializer for random uniform LeCun-scaled coefficients."""
-  return ScaledInitializer(out_dim, in_dim, scale, 'fan_in', 'uniform')
-
-
-def KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.):
-  """Returns an initializer for random Kaiming-scaled coefficients."""
-  return ScaledInitializer(
-      out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), 'fan_in', 'normal')
-
-
-def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.):
-  """Returns an initializer for random uniform Kaiming-scaled coefficients."""
-  return ScaledInitializer(
-      out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), 'fan_in', 'uniform')

+    """Returns an initializer that adjusts its scale based on weight shapes."""
+    if scale <= 0.0:
+        raise ValueError("scale must be a positive float; got {}".format(scale))
+    if mode not in {"fan_in", "fan_out", "fan_avg"}:
+        raise ValueError(
+            "Invalid mode argument: {}; must be fan_in, fan_out or fan_avg".format(
+                mode
+            )
+        )
+
+    def Init(shape, rng, nonreceptive_dims=None):
+        """Returns random values for initializing weights of the given `shape`."""
+        shape = _PureShape(shape)
+        fan_in, fan_out = _GetFans(shape, out_dim, in_dim, nonreceptive_dims)
+        gain = scale
+        if mode == "fan_in":
+            gain /= fan_in
+        elif mode == "fan_out":
+            gain /= fan_out
+        elif mode == "fan_avg":
+            gain /= (fan_in + fan_out) / 2
+        if distribution == "truncated_normal":
+            # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+            stddev = jnp.sqrt(gain) / 0.87962566103423978
+            new_weights = random.truncated_normal(rng, -2, 2, shape) * stddev
+            return new_weights.astype("float32")
+        elif distribution == "normal":
+            new_weights = random.normal(rng, shape) * jnp.sqrt(gain)
+            return new_weights.astype("float32")
+        elif distribution == "uniform":
+            lim = jnp.sqrt(3.0 * gain)
+            return random.uniform(rng, shape, jnp.float32, -lim, lim)
+        else:
+            raise ValueError("invalid distribution for ScaledInitializer")
+
+    return Init
+
+
+def GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0):
+    """Returns an initializer for random Glorot-scaled coefficients."""
+    return ScaledInitializer(out_dim, in_dim, scale, "fan_avg", "normal")
+
+
+def GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0):
+    """Returns an initializer for random uniform Glorot-scaled coefficients."""
+    return ScaledInitializer(out_dim, in_dim, scale, "fan_avg", "uniform")
+
+
+def LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0):
+    """Returns an initializer for random LeCun-scaled coefficients."""
+    return ScaledInitializer(out_dim, in_dim, scale, "fan_in", "normal")
+
+
+def LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0):
+    """Returns an initializer for random uniform LeCun-scaled coefficients."""
+    return ScaledInitializer(out_dim, in_dim, scale, "fan_in", "uniform")
+
+
+def KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.0):
+    """Returns an initializer for random Kaiming-scaled coefficients."""
+    return ScaledInitializer(
+        out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), "fan_in", "normal"
+    )
+
+
+def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.0):
+    """Returns an initializer for random uniform Kaiming-scaled coefficients."""
+    return ScaledInitializer(
+        out_dim, in_dim, 2.0 / jnp.sqrt(1 + param**2), "fan_in", "uniform"
+    )
+        n_rows = 1
+        for dim in cur_shape[:-1]:
+            n_rows *= dim
+        n_cols = cur_shape[-1]
+        flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

-    # Generate a random matrix
-    a = random.normal(rng, flat_shape, dtype=jnp.float32)
+        # Generate a random matrix
+        a = random.normal(rng, flat_shape, dtype=jnp.float32)

-    # Compute the qr factorization
-    q, r = jnp.linalg.qr(a)
+        # Compute the qr factorization
+        q, r = jnp.linalg.qr(a)

-    # Make Q uniform
-    d = jnp.diag(r)
-    q *= jnp.sign(d)
+        # Make Q uniform
+        d = jnp.diag(r)
+        q *= jnp.sign(d)

-    # Transpose and reshape back q if needed.
-    if n_rows < n_cols:
-      q = jnp.transpose(q)
-      q = jnp.reshape(q, shape)
+        # Transpose and reshape back q if needed.
+        if n_rows < n_cols:
+            q = jnp.transpose(q)
+            q = jnp.reshape(q, shape)

-    # Return scaled as requested.
-    return stddev * q
+        # Return scaled as requested.
+        return stddev * q

-  return Init
+    return Init


 def AtariConvInit(kernel_shape, rng, dtype=jnp.float32):
-  """The standard init for Conv laters and Atari."""
-  filter_height, filter_width, fan_in, _ = kernel_shape
-  std = 1 / jnp.sqrt(fan_in * filter_height * filter_width)
-  return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std)
+    """The standard init for Conv layers and Atari."""
+    filter_height, filter_width, fan_in, _ = kernel_shape
+    std = 1 / jnp.sqrt(fan_in * filter_height * filter_width)
+    return random.uniform(rng, kernel_shape, dtype, minval=-std, maxval=std)
diff --git a/trax/layers/initializers_test.py b/trax/layers/initializers_test.py
deleted file mode 100644
index 921452c58..000000000
--- a/trax/layers/initializers_test.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for initializers."""
-
-from absl.testing import absltest
-import numpy as np
-
-from trax import fastmath
-from trax import test_utils
-import trax.layers as tl
-
-
-INPUT_SHAPE = (5, 7, 20)
-
-
-def rng():  # Can't be a constant, because JAX has to init itself in main first. 
- return fastmath.random.get_prng(0) - - -class InitializersTest(absltest.TestCase): - - def test_random_normal(self): - f = tl.RandomNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_lecun_uniform(self): - f = tl.LeCunUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_random_uniform(self): - f = tl.RandomUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_glorot_normal(self): - f = tl.GlorotNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_glorot_uniform(self): - f = tl.GlorotUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_lecun_normal(self): - f = tl.LeCunNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_kaiming_normal(self): - f = tl.KaimingNormalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_kaiming_uniform(self): - f = tl.KaimingUniformInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_orthogonal(self): - f = tl.OrthogonalInitializer() - init_value = f(INPUT_SHAPE, rng()) - self.assertEqual(init_value.shape, INPUT_SHAPE) - - def test_from_file(self): - params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]]) - # `create_tempfile` needs access to --test_tmpdir, however in the OSS world - # pytest doesn't run `absltest.main`, so we need to manually parse the flags - test_utils.ensure_flag('test_tmpdir') - filename = self.create_tempfile('params.npy').full_path - with open(filename, 'wb') as f: - np.save(f, params) - f = tl.InitializerFromFile(filename) - init_value = f(params.shape, rng()) - np.testing.assert_almost_equal( - tl.to_list(init_value), tl.to_list(params), decimal=4) - # self.assertEqual('%s' % init_value, '%s' % params) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/intro.ipynb b/trax/layers/intro.ipynb deleted file mode 100644 index d8ec91fbd..000000000 --- a/trax/layers/intro.ipynb +++ /dev/null @@ -1,1400 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "7yuytuIllsv1" - }, - "source": [ - "# Trax Layers Intro\n", - "\n", - "This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in following sections are:\n", - "\n", - " 1. **Layers**: the basic building blocks and how to combine them\n", - " 1. **Inputs and Outputs**: how data streams flow through layers\n", - " 1. **Defining New Layer Classes** (if combining existing layers isn't enough)\n", - " 1. **Testing and Debugging Layer Classes**\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BIl27504La0G" - }, - "source": [ - "**General Setup**\n", - "\n", - "Execute the following few cells (once) before running any of the code samples in this notebook." 
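Since the diff above deletes `initializers_test.py` outright, the shape smoke tests for the initializer family disappear from the tree. A minimal pytest-style sketch that preserves equivalent coverage could look like the following (the test name is illustrative; it assumes `trax.layers` keeps exporting these initializer factories):

```python
# Standalone smoke test mirroring the deleted initializers_test.py above.
# Assumes trax is installed and trax.layers still exports these factories.
import numpy as np

from trax import fastmath
import trax.layers as tl

INPUT_SHAPE = (5, 7, 20)


def test_initializers_produce_requested_shape():
    rng = fastmath.random.get_prng(0)
    factories = (
        tl.RandomNormalInitializer(),
        tl.RandomUniformInitializer(),
        tl.GlorotNormalInitializer(),
        tl.GlorotUniformInitializer(),
        tl.LeCunNormalInitializer(),
        tl.LeCunUniformInitializer(),
        tl.KaimingNormalInitializer(),
        tl.KaimingUniformInitializer(),
        tl.OrthogonalInitializer(),
    )
    for init in factories:
        value = init(INPUT_SHAPE, rng)
        # Every initializer should honor the requested shape and stay finite.
        assert value.shape == INPUT_SHAPE
        assert np.isfinite(np.asarray(value)).all()
```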
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oILRLCWN_16u" - }, - "outputs": [], - "source": [ - "# Copyright 2018 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", - "import numpy as np\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": { - "height": 51 - }, - "colab_type": "code", - "id": "vlGjGoGMTt-D", - "outputId": "76b95a37-3f1b-4748-bef0-646858f33e25" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/bin/sh: pip: command not found\n", - "/bin/sh: pip: command not found\n" - ] - } - ], - "source": [ - "# Import Trax\n", - "\n", - "! pip install -q -U trax\n", - "! pip install -q tensorflow\n", - "\n", - "from trax import fastmath\n", - "from trax import layers as tl\n", - "from trax import shapes\n", - "from trax.fastmath import numpy as jnp # For use in defining new layer types.\n", - "from trax.shapes import ShapeDtype\n", - "from trax.shapes import signature" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "bYWNWL9MJHv9" - }, - "outputs": [], - "source": [ - "# Settings and utilities for handling inputs, outputs, and object properties.\n", - "\n", - "np.set_printoptions(precision=3) # Reduce visual noise from extra digits.\n", - "\n", - "def show_layer_properties(layer_obj, layer_name):\n", - " template = ('{}.n_in: {}\\n'\n", - " '{}.n_out: {}\\n'\n", - " '{}.sublayers: {}\\n'\n", - " '{}.weights: {}\\n')\n", - " print(template.format(layer_name, layer_obj.n_in,\n", - " layer_name, layer_obj.n_out,\n", - " layer_name, layer_obj.sublayers,\n", - " layer_name, layer_obj.weights))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-LQ89rFFsEdk" - }, - "source": [ - "## 1. Layers\n", - "\n", - "The Layer class represents Trax's basic building blocks:\n", - "```\n", - "class Layer:\n", - " \"\"\"Base class for composable layers in a deep learning network.\n", - "\n", - " Layers are the basic building blocks for deep learning models. A Trax layer\n", - " computes a function from zero or more inputs to zero or more outputs,\n", - " optionally using trainable weights (common) and non-parameter state (not\n", - " common). ...\n", - "\n", - " ...\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "LyLVtdxorDPO" - }, - "source": [ - "### Layers compute functions." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ntZ4_eNQldzL" - }, - "source": [ - "A layer computes a function from zero or more inputs to zero or more outputs.\n", - "The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.\n", - "\n", - "The simplest layers, those with no weights or sublayers, can be used without\n", - "initialization. You can think of them as (pure) mathematical functions that can\n", - "be plugged into neural networks.\n", - "\n", - "For ease of testing and interactive exploration, layer objects implement the\n", - "`__call__ ` method, so you can call them directly on input data:\n", - "```\n", - "y = my_layer(x)\n", - "```\n", - "\n", - "Layers are also objects, so you can inspect their properties. For example:\n", - "```\n", - "print(f'Number of inputs expected by this layer: {my_layer.n_in}')\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "hCoapc5le8B7" - }, - "source": [ - "**Example 1.** tl.Relu $[n_{in} = 1, n_{out} = 1]$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 224 - }, - "colab_type": "code", - "id": "V09viOSEQvQe", - "outputId": "a0134cee-0db8-4396-825e-93e695a42ca5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x:\n", - "[[ -2 -1 0 1 2]\n", - " [-20 -10 0 10 20]]\n", - "\n", - "relu(x):\n", - "[[ 0 0 0 1 2]\n", - " [ 0 0 0 10 20]]\n", - "\n", - "Number of inputs expected by this layer: 1\n", - "Number of outputs promised by this layer: 1\n" - ] - } - ], - "source": [ - "relu = tl.Relu()\n", - "\n", - "x = np.array([[-2, -1, 0, 1, 2],\n", - " [-20, -10, 0, 10, 20]])\n", - "y = relu(x)\n", - "\n", - "# Show input, output, and two layer properties.\n", - "print(f'x:\\n{x}\\n\\n'\n", - " f'relu(x):\\n{y}\\n\\n'\n", - " f'Number of inputs expected by this layer: {relu.n_in}\\n'\n", - " f'Number of outputs promised by this layer: {relu.n_out}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "7sYxIT8crFVE" - }, - "source": [ - "**Example 2.** tl.Concatenate $[n_{in} = 2, n_{out} = 1]$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 255 - }, - "colab_type": "code", - "id": "LMPPNWXLoOZI", - "outputId": "42f595b1-4014-429a-a0b3-2c12d630cd32" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x0:\n", - "[[1 2 3]\n", - " [4 5 6]]\n", - "\n", - "x1:\n", - "[[10 20 30]\n", - " [40 50 60]]\n", - "\n", - "concat([x1, x2]):\n", - "[[ 1 2 3 10 20 30]\n", - " [ 4 5 6 40 50 60]]\n", - "\n", - "Number of inputs expected by this layer: 2\n", - "Number of outputs promised by this layer: 1\n" - ] - } - ], - "source": [ - "concat = tl.Concatenate()\n", - "\n", - "x0 = np.array([[1, 2, 3],\n", - " [4, 5, 6]])\n", - "x1 = np.array([[10, 20, 30],\n", - " [40, 50, 60]])\n", - "y = concat([x0, x1])\n", - "\n", - "print(f'x0:\\n{x0}\\n\\n'\n", - " f'x1:\\n{x1}\\n\\n'\n", - " f'concat([x1, x2]):\\n{y}\\n\\n'\n", - " f'Number of inputs expected by this layer: {concat.n_in}\\n'\n", - " f'Number of outputs promised by this layer: {concat.n_out}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "z7N1qe91eYyM" - }, - "source": [ - "### Layers are configurable.\n", - "\n", - "Many layer types have creation-time parameters for flexibility. 
The \n", - "`Concatenate` layer type, for instance, has two optional parameters:\n", - "\n", - "* `axis`: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.\n", - "* `n_items`: number of tensors to join into one by concatenation; default value is 2.\n", - "\n", - "The following example shows `Concatenate` configured for **3** input tensors,\n", - "and concatenation along the initial $(0^{th})$ axis." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "l53Jw23pZ4s6" - }, - "source": [ - "**Example 3.** tl.Concatenate(n_items=3, axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 340 - }, - "colab_type": "code", - "id": "bhhWlVLffZtf", - "outputId": "5a8afaa1-66c8-47fe-abcc-e7cfa33bb28c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x0:\n", - "[[1 2 3]\n", - " [4 5 6]]\n", - "\n", - "x1:\n", - "[[10 20 30]\n", - " [40 50 60]]\n", - "\n", - "x2:\n", - "[[100 200 300]\n", - " [400 500 600]]\n", - "\n", - "concat3([x0, x1, x2]):\n", - "[[ 1 2 3]\n", - " [ 4 5 6]\n", - " [ 10 20 30]\n", - " [ 40 50 60]\n", - " [100 200 300]\n", - " [400 500 600]]\n" - ] - } - ], - "source": [ - "concat3 = tl.Concatenate(n_items=3, axis=0)\n", - "\n", - "x0 = np.array([[1, 2, 3],\n", - " [4, 5, 6]])\n", - "x1 = np.array([[10, 20, 30],\n", - " [40, 50, 60]])\n", - "x2 = np.array([[100, 200, 300],\n", - " [400, 500, 600]])\n", - "\n", - "y = concat3([x0, x1, x2])\n", - "\n", - "print(f'x0:\\n{x0}\\n\\n'\n", - " f'x1:\\n{x1}\\n\\n'\n", - " f'x2:\\n{x2}\\n\\n'\n", - " f'concat3([x0, x1, x2]):\\n{y}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1oZv3R8bRMvF" - }, - "source": [ - "### Layers are trainable.\n", - "\n", - "Many layer types include weights that affect the computation of outputs from\n", - "inputs, and they use back-progagated gradients to update those weights.\n", - "\n", - "🚧🚧 *A very small subset of layer types, such as `BatchNorm`, also include\n", - "modifiable weights (called `state`) that are updated based on forward-pass\n", - "inputs/computation rather than back-propagated gradients.*" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3d64M7wLryji" - }, - "source": [ - "**Initialization**\n", - "\n", - "Trainable layers must be initialized before use. Trax can take care of this\n", - "as part of the overall training process. In other settings (e.g., in tests or\n", - "interactively in a Colab notebook), you need to initialize the\n", - "*outermost/topmost* layer explicitly. For this, use `init`:\n", - "\n", - "```\n", - " def init(self, input_signature, rng=None, use_cache=False):\n", - " \"\"\"Initializes weights/state of this layer and its sublayers recursively.\n", - "\n", - " Initialization creates layer weights and state, for layers that use them.\n", - " It derives the necessary array shapes and data types from the layer's input\n", - " signature, which is itself just shape and data type information.\n", - "\n", - " For layers without weights or state, this method safely does nothing.\n", - "\n", - " This method is designed to create weights/state only once for each layer\n", - " instance, even if the same layer instance occurs in multiple places in the\n", - " network. 
This enables weight sharing to be implemented as layer sharing.\n", - "\n", - " Args:\n", - " input_signature: `ShapeDtype` instance (if this layer takes one input)\n", - " or list/tuple of `ShapeDtype` instances.\n", - " rng: Single-use random number generator (JAX PRNG key), or `None`;\n", - " if `None`, use a default computed from an integer 0 seed.\n", - " use_cache: If `True`, and if this layer instance has already been\n", - " initialized elsewhere in the network, then return special marker\n", - " values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.\n", - " Else return this layer's newly initialized weights and state.\n", - "\n", - " Returns:\n", - " A `(weights, state)` tuple.\n", - " \"\"\"\n", - "```\n", - "\n", - "Input signatures can be built from scratch using `ShapeDType` objects, or can\n", - "be derived from data via the `signature` function (in module `shapes`):\n", - "```\n", - "def signature(obj):\n", - " \"\"\"Returns a `ShapeDtype` signature for the given `obj`.\n", - "\n", - " A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`\n", - " instances. Note that this function is permissive with respect to its inputs\n", - " (accepts lists or tuples or dicts, and underlying objects can be any type\n", - " as long as they have shape and dtype attributes) and returns the corresponding\n", - " nested structure of `ShapeDtype`.\n", - "\n", - " Args:\n", - " obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict\n", - " of such objects.\n", - "\n", - " Returns:\n", - " A corresponding nested structure of `ShapeDtype` instances.\n", - " \"\"\"\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "yL8HAj6GEAp1" - }, - "source": [ - "**Example 4.** tl.LayerNorm $[n_{in} = 1, n_{out} = 1]$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 221 - }, - "colab_type": "code", - "id": "Ie7iyX91qAx2", - "outputId": "0efecdf5-c0a4-4304-f442-d12fc1a51253" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x:\n", - "[[-2. -1. 0. 1. 2.]\n", - " [ 1. 2. 3. 4. 5.]\n", - " [10. 20. 30. 40. 50.]]\n", - "\n", - "layer_norm(x):\n", - "[[-1.414 -0.707 0. 0.707 1.414]\n", - " [-1.414 -0.707 0. 0.707 1.414]\n", - " [-1.414 -0.707 0. 0.707 1.414]]\n", - "\n", - "layer_norm.weights:\n", - "(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))\n" - ] - } - ], - "source": [ - "layer_norm = tl.LayerNorm()\n", - "\n", - "x = np.array([[-2, -1, 0, 1, 2],\n", - " [1, 2, 3, 4, 5],\n", - " [10, 20, 30, 40, 50]]).astype(np.float32)\n", - "layer_norm.init(shapes.signature(x))\n", - "\n", - "y = layer_norm(x)\n", - "\n", - "print(f'x:\\n{x}\\n\\n'\n", - " f'layer_norm(x):\\n{y}\\n')\n", - "print(f'layer_norm.weights:\\n{layer_norm.weights}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "d47gVdGV1vWw" - }, - "source": [ - "### Layers combine into layers.\n", - "\n", - "The Trax library authors encourage users to build networks and network\n", - "components as combinations of existing layers, by means of a small set of\n", - "_combinator_ layers. A combinator makes a list of layers behave as a single\n", - "layer -- by combining the sublayer computations yet looking from the outside\n", - "like any other layer. 
The combined layer, like other layers, can:\n", - "\n", - "* compute outputs from inputs,\n", - "* update parameters from gradients, and\n", - "* combine with yet more layers." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vC1ymG2j0iyp" - }, - "source": [ - "**Combine with `Serial`**\n", - "\n", - "The most common way to combine layers is with the `Serial` combinator:\n", - "```\n", - "class Serial(base.Layer):\n", - " \"\"\"Combinator that applies layers serially (by function composition).\n", - "\n", - " This combinator is commonly used to construct deep networks, e.g., like this::\n", - "\n", - " mlp = tl.Serial(\n", - " tl.Dense(128),\n", - " tl.Relu(),\n", - " tl.Dense(10),\n", - " )\n", - "\n", - " A Serial combinator uses stack semantics to manage data for its sublayers.\n", - " Each sublayer sees only the inputs it needs and returns only the outputs it\n", - " has generated. The sublayers interact via the data stack. For instance, a\n", - " sublayer k, following sublayer j, gets called with the data stack in the\n", - " state left after layer j has applied. The Serial combinator then:\n", - "\n", - " - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n", - " layer k, passing those items as arguments; and\n", - "\n", - " - takes layer k's n_out return values (n_out = k.n_out) and pushes\n", - " them onto the data stack.\n", - "\n", - " A Serial instance with no sublayers acts as a special-case (but useful)\n", - " 1-input 1-output no-op.\n", - " \"\"\"\n", - "```\n", - "If one layer has the same number of outputs as the next layer has inputs (which\n", - "is the usual case), the successive layers behave like function composition:\n", - "\n", - "```\n", - "# h(.) = g(f(.))\n", - "layer_h = Serial(\n", - " layer_f,\n", - " layer_g,\n", - ")\n", - "```\n", - "Note how, inside `Serial`, function composition is expressed naturally as a\n", - "succession of operations, so that no nested parentheses are needed.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "uPOnrDa9ViPi" - }, - "source": [ - "**Example 5.** y = layer_norm(relu(x)) $[n_{in} = 1, n_{out} = 1]$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 136 - }, - "colab_type": "code", - "id": "dW5fpusjvjmh", - "outputId": "acdcffe7-23d5-4ecd-df9b-32f48ae77959" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x:\n", - "[[ -2. -1. 0. 1. 2.]\n", - " [-20. -10. 0. 10. 20.]]\n", - "\n", - "layer_block(x):\n", - "[[-0.75 -0.75 -0.75 0.5 1.75]\n", - " [-0.75 -0.75 -0.75 0.5 1.75]]\n" - ] - } - ], - "source": [ - "layer_block = tl.Serial(\n", - " tl.Relu(),\n", - " tl.LayerNorm(),\n", - ")\n", - "\n", - "x = np.array([[-2, -1, 0, 1, 2],\n", - " [-20, -10, 0, 10, 20]]).astype(np.float32)\n", - "layer_block.init(shapes.signature(x))\n", - "y = layer_block(x)\n", - "\n", - "print(f'x:\\n{x}\\n\\n'\n", - " f'layer_block(x):\\n{y}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "bRtmN6ckQO1q" - }, - "source": [ - "And we can inspect the block as a whole, as if it were just another layer:\n", - "\n", - "**Example 5'.** Inspecting a `Serial` layer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 68 - }, - "colab_type": "code", - "id": "D6BpYddZQ1eu", - "outputId": "1a00c9f2-63a0-450c-d902-c9baf06dc917" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "layer_block: Serial[\n", - " Relu\n", - " LayerNorm\n", - "]\n", - "\n", - "layer_block.weights: ((), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))\n" - ] - } - ], - "source": [ - "print(f'layer_block: {layer_block}\\n\\n'\n", - " f'layer_block.weights: {layer_block.weights}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "kJ8bpYZtE66x" - }, - "source": [ - "**Combine with `Branch`**\n", - "\n", - "The `Branch` combinator arranges layers into parallel computational channels:\n", - "```\n", - "def Branch(*layers, name='Branch'):\n", - " \"\"\"Combinator that applies a list of layers in parallel to copies of inputs.\n", - "\n", - " Each layer in the input list is applied to as many inputs from the stack\n", - " as it needs, and their outputs are successively combined on stack.\n", - "\n", - " For example, suppose one has three layers:\n", - "\n", - " - F: 1 input, 1 output\n", - " - G: 3 inputs, 1 output\n", - " - H: 2 inputs, 2 outputs (h1, h2)\n", - "\n", - " Then Branch(F, G, H) will take 3 inputs and give 4 outputs:\n", - "\n", - " - inputs: a, b, c\n", - " - outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)\n", - "\n", - " As an important special case, a None argument to Branch acts as if it takes\n", - " one argument, which it leaves unchanged. (It acts as a one-arg no-op.)\n", - "\n", - " Args:\n", - " *layers: List of layers.\n", - " name: Descriptive name for this layer.\n", - "\n", - " Returns:\n", - " A branch layer built from the given sublayers.\n", - " \"\"\"\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "RlPcnRtdIVgq" - }, - "source": [ - "Residual blocks, for example, are implemented using `Branch`:\n", - "```\n", - "def Residual(*layers, shortcut=None):\n", - " \"\"\"Wraps a series of layers with a residual connection.\n", - "\n", - " Args:\n", - " *layers: One or more layers, to be applied in series.\n", - " shortcut: If None (the usual case), the Residual layer computes the\n", - " element-wise sum of the stack-top input with the output of the layer\n", - " series. If specified, the `shortcut` layer applies to a copy of the\n", - " inputs and (elementwise) adds its output to the output from the main\n", - " layer series.\n", - "\n", - " Returns:\n", - " A layer representing a residual connection paired with a layer series.\n", - " \"\"\"\n", - " layers = _ensure_flat(layers)\n", - " layer = layers[0] if len(layers) == 1 else Serial(layers)\n", - " return Serial(\n", - " Branch(shortcut, layer),\n", - " Add(),\n", - " )\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ruX4aFMdUOwS" - }, - "source": [ - "Here's a simple code example to highlight the mechanics." 
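To make the input/output arithmetic in the `Branch` docstring concrete, here is a small sketch in the spirit of the notebook's examples. The layers `F`, `G`, and `H` and the functions behind them are made up for illustration; only the n_in/n_out bookkeeping is the point:

```python
# Sketch of Branch(F, G, H) from the docstring above: F is 1->1, G is 3->1,
# H is 2->2, so the branch takes 3 inputs and yields 4 outputs.
import numpy as np

from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp

F = tl.Fn('F', lambda a: a + 1.0)
G = tl.Fn('G', lambda a, b, c: a + b + c)
H = tl.Fn('H', lambda a, b: (a * b, jnp.maximum(a, b)), n_out=2)

branch = tl.Branch(F, G, H)
a, b, c = np.array([1.0]), np.array([10.0]), np.array([100.0])
branch.init(shapes.signature((a, b, c)))

outputs = branch((a, b, c))  # (F(a), G(a, b, c), h1, h2) with h1, h2 = H(a, b)
print(len(outputs))  # 4
```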
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "JGGnKjg4ESIg" - }, - "source": [ - "**Example 6.** `Branch`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 204 - }, - "colab_type": "code", - "id": "lw6A2YwuW-Ul", - "outputId": "a07ef350-bafa-4fa7-a083-19e6f725b3ce" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x:\n", - "[[ -2 -1 0 1 2]\n", - " [-20 -10 0 10 20]]\n", - "\n", - "y0:\n", - "[[ 0 0 0 1 2]\n", - " [ 0 0 0 10 20]]\n", - "\n", - "y1:\n", - "[[ -200. -100. 0. 100. 200.]\n", - " [-2000. -1000. 0. 1000. 2000.]]\n" - ] - } - ], - "source": [ - "relu = tl.Relu()\n", - "times_100 = tl.Fn(\"Times100\", lambda x: x * 100.0)\n", - "branch_relu_t100 = tl.Branch(relu, times_100)\n", - "\n", - "x = np.array([[-2, -1, 0, 1, 2],\n", - " [-20, -10, 0, 10, 20]])\n", - "branch_relu_t100.init(shapes.signature(x))\n", - "\n", - "y0, y1 = branch_relu_t100(x)\n", - "\n", - "print(f'x:\\n{x}\\n\\n'\n", - " f'y0:\\n{y0}\\n\\n'\n", - " f'y1:\\n{y1}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "zr2ZZ1vO8T8V" - }, - "source": [ - "## 2. Inputs and Outputs\n", - "\n", - "Trax allows layers to have multiple input streams and output streams. When\n", - "designing a network, you have the flexibility to use layers that:\n", - "\n", - " - process a single data stream ($n_{in} = n_{out} = 1$),\n", - " - process multiple parallel data streams ($n_{in} = n_{out} = 2, 3, ... $),\n", - " - split or inject data streams ($n_{in} \u003c n_{out}$), or\n", - " - merge or remove data streams ($n_{in} \u003e n_{out}$).\n", - "\n", - "We saw in section 1 the example of `Residual`, which involves both a split and a merge:\n", - "```\n", - " ...\n", - " return Serial(\n", - " Branch(shortcut, layer),\n", - " Add(),\n", - " )\n", - "```\n", - "In other words, layer by layer:\n", - "\n", - " - `Branch(shortcut, layers)`: makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [$n_{in} = 1$, $n_{out} = 2$]\n", - " - `Add()`: combines the two streams back into one by adding two tensors elementwise. [$n_{in} = 2$, $n_{out} = 1$]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1FEttSCVVM3T" - }, - "source": [ - "### Data Stack\n", - "\n", - "Trax supports flexible data flows through a network via a data stack, which is\n", - "managed by the `Serial` combinator:\n", - "```\n", - "class Serial(base.Layer):\n", - " \"\"\"Combinator that applies layers serially (by function composition).\n", - "\n", - " ...\n", - "\n", - " A Serial combinator uses stack semantics to manage data for its sublayers.\n", - " Each sublayer sees only the inputs it needs and returns only the outputs it\n", - " has generated. The sublayers interact via the data stack. For instance, a\n", - " sublayer k, following sublayer j, gets called with the data stack in the\n", - " state left after layer j has applied. 
The Serial combinator then:\n", - "\n", - " - takes n_in items off the top of the stack (n_in = k.n_in) and calls\n", - " layer k, passing those items as arguments; and\n", - "\n", - " - takes layer k's n_out return values (n_out = k.n_out) and pushes\n", - " them onto the data stack.\n", - "\n", - " ...\n", - "\n", - " \"\"\"\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "5DAiajI-Gzk4" - }, - "source": [ - "**Simple Case 1 -- Each layer takes one input and has one output.**\n", - "\n", - "This is in effect a single data stream pipeline, and the successive layers\n", - "behave like function composition:\n", - "\n", - "```\n", - "# s(.) = h(g(f(.)))\n", - "layer_s = Serial(\n", - " layer_f,\n", - " layer_g,\n", - " layer_h,\n", - ")\n", - "```\n", - "Note how, inside `Serial`, function composition is expressed naturally as a\n", - "succession of operations, so that no nested parentheses are needed and the\n", - "order of operations matches the textual order of layers.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WR8bh64tIzIY" - }, - "source": [ - "**Simple Case 2 -- Each layer consumes all outputs of the preceding layer.**\n", - "\n", - "This is still a single pipeline, but data streams internal to it can split and\n", - "merge. The `Residual` example above illustrates this kind.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "ACG88RdtLbvG" - }, - "source": [ - "**General Case -- Successive layers interact via the data stack.**\n", - "\n", - "As described in the `Serial` class docstring, each layer gets its inputs from\n", - "the data stack after the preceding layer has put its outputs onto the stack.\n", - "This covers the simple cases above, but also allows for more flexible data\n", - "interactions between non-adjacent layers. The following example is schematic:\n", - "```\n", - "x, y_target = get_batch_of_labeled_data()\n", - "\n", - "model_plus_eval = Serial(\n", - " my_fancy_deep_model(), # Takes one arg (x) and has one output (y_hat)\n", - " my_eval(), # Takes two args (y_hat, y_target) and has one output (score)\n", - ")\n", - "\n", - "eval_score = model_plus_eval((x, y_target))\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "66hUOOYRQqej" - }, - "source": [ - "Here is the corresponding progression of stack states:\n", - "\n", - "0. At start: _--empty--_\n", - "0. After `get_batch_of_labeled_data()`: *x*, *y_target*\n", - "0. After `my_fancy_deep_model()`: *y_hat*, *y_target*\n", - "0. After `my_eval()`: *score*\n", - "\n", - "Note in particular how the application of the model (between stack states 1\n", - "and 2) only uses and affects the top element on the stack: `x` --\u003e `y_hat`.\n", - "The rest of the data stack (`y_target`) comes in use only later, for the\n", - "eval function." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "65ite-671cTT" - }, - "source": [ - "## 3. Defining New Layer Classes\n", - "\n", - "If you need a layer type that is not easily defined as a combination of\n", - "existing layer types, you can define your own layer classes in a couple\n", - "different ways." 
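Before looking at those ways, the stack discipline from section 2 can be modeled in a few lines of plain Python. The sketch below is a toy illustration of the bookkeeping only, not Trax's actual `Serial` implementation:

```python
# Toy model of Serial's data-stack semantics (illustration only).
# Each layer is modeled as (n_in, n_out, fn); the stack top is index 0.
def run_serial(layers, stack):
    for n_in, n_out, fn in layers:
        args, stack = stack[:n_in], stack[n_in:]   # pop n_in items
        outputs = fn(*args)
        assert len(outputs) == n_out, 'layer must return n_out values'
        stack = list(outputs) + stack              # push n_out items
    return stack


# Mirrors the schematic above: (x, y_target) -> (y_hat, y_target) -> (score,)
model = (1, 1, lambda x: (2.0 * x,))                           # stand-in model
evaluate = (2, 1, lambda y_hat, y_tgt: (abs(y_hat - y_tgt),))  # stand-in eval

print(run_serial([model, evaluate], [3.0, 5.0]))  # [1.0]
```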
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "hHSaD9H6hDTf" - }, - "source": [ - "### With the `Fn` layer-creating function.\n", - "\n", - "Many layer types needed in deep learning compute pure functions from inputs to\n", - "outputs, using neither weights nor randomness. You can use Trax's `Fn` function\n", - "to define your own pure layer types:\n", - "```\n", - "def Fn(name, f, n_out=1): # pylint: disable=invalid-name\n", - " \"\"\"Returns a layer with no weights that applies the function `f`.\n", - "\n", - " `f` can take and return any number of arguments, and takes only positional\n", - " arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).\n", - " The following, for example, would create a layer that takes two inputs and\n", - " returns two outputs -- element-wise sums and maxima:\n", - "\n", - " `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`\n", - "\n", - " The layer's number of inputs (`n_in`) is automatically set to number of\n", - " positional arguments in `f`, but you must explicitly set the number of\n", - " outputs (`n_out`) whenever it's not the default value 1.\n", - "\n", - " Args:\n", - " name: Class-like name for the resulting layer; for use in debugging.\n", - " f: Pure function from input tensors to output tensors, where each input\n", - " tensor is a separate positional arg, e.g., `f(x0, x1) --\u003e x0 + x1`.\n", - " Output tensors must be packaged as specified in the `Layer` class\n", - " docstring.\n", - " n_out: Number of outputs promised by the layer; default value 1.\n", - "\n", - " Returns:\n", - " Layer executing the function `f`.\n", - " \"\"\"\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "TX30lGLXcjB1" - }, - "source": [ - "**Example 7.** Use `Fn` to define a new layer type:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 153 - }, - "colab_type": "code", - "id": "vKrc6XMV9ErS", - "outputId": "13f74094-e43e-4267-9055-f3d55d58ae53" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x0:\n", - "[ 1 2 3 4 5 6 7 8 9 10]\n", - "\n", - "x1:\n", - "[11 12 13 14 15 16 17 18 19 20]\n", - "\n", - "gcd((x0, x1)):\n", - "[ 1 2 1 2 5 2 1 2 1 10]\n" - ] - } - ], - "source": [ - "# Define new layer type.\n", - "def Gcd():\n", - " \"\"\"Returns a layer to compute the greatest common divisor, elementwise.\"\"\"\n", - " return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))\n", - "\n", - "# Use it.\n", - "gcd = Gcd()\n", - "\n", - "x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n", - "x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])\n", - "\n", - "y = gcd((x0, x1))\n", - "\n", - "print(f'x0:\\n{x0}\\n\\n'\n", - " f'x1:\\n{x1}\\n\\n'\n", - " f'gcd((x0, x1)):\\n{y}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "W74Eehgp5A57" - }, - "source": [ - "The `Fn` function infers `n_in` (number of inputs) as the length of `f`'s arg\n", - "list. `Fn` does not infer `n_out` (number out outputs) though. If your `f` has\n", - "more than one output, you need to give an explicit value using the `n_out`\n", - "keyword arg." 
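The `n_in` inference just described can be pictured with the standard `inspect` module. The helper below is a sketch of the idea, not the exact code behind `Fn` (which lives in `trax.layers.base`):

```python
# Sketch: counting a function's positional args to infer n_in, as Fn does.
import inspect


def infer_n_in(f):
    params = list(inspect.signature(f).parameters.values())
    for p in params:
        # Fn requires purely positional args with no defaults or *args/**kwargs.
        if p.default is not p.empty or p.kind not in (
                p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD):
            raise ValueError('f must take only positional args, no defaults')
    return len(params)


print(infer_n_in(lambda x0, x1: x0 + x1))  # 2
print(infer_n_in(lambda q, k, v: q))       # 3
```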
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "2lCjml7SCR-u" - }, - "source": [ - "**Example 8.** `Fn` with multiple outputs:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 204 - }, - "colab_type": "code", - "id": "rfnA2B9ZczWK", - "outputId": "9ffd7648-ffda-453e-b88b-4aa4ba8ea482" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x0:\n", - "[1 2 3 4 5]\n", - "\n", - "x1:\n", - "[ 10 -20 30 -40 50]\n", - "\n", - "y0:\n", - "[ 11 -18 33 -36 55]\n", - "\n", - "y1:\n", - "[10 2 30 4 50]\n" - ] - } - ], - "source": [ - "# Define new layer type.\n", - "def SumAndMax():\n", - " \"\"\"Returns a layer to compute sums and maxima of two input tensors.\"\"\"\n", - " return tl.Fn('SumAndMax',\n", - " lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),\n", - " n_out=2)\n", - "\n", - "# Use it.\n", - "sum_and_max = SumAndMax()\n", - "\n", - "x0 = np.array([1, 2, 3, 4, 5])\n", - "x1 = np.array([10, -20, 30, -40, 50])\n", - "\n", - "y0, y1 = sum_and_max([x0, x1])\n", - "\n", - "print(f'x0:\\n{x0}\\n\\n'\n", - " f'x1:\\n{x1}\\n\\n'\n", - " f'y0:\\n{y0}\\n\\n'\n", - " f'y1:\\n{y1}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "GrXQUSbKDs41" - }, - "source": [ - "**Example 9.** Use `Fn` to define a configurable layer:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "height": 374 - }, - "colab_type": "code", - "id": "h1KwpmFpEIK3", - "outputId": "9f6e7009-04a0-46c9-b005-35c091f720eb" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "x:\n", - "[[[ 1 2 3]\n", - " [ 10 20 30]\n", - " [100 200 300]]\n", - "\n", - " [[ 4 5 6]\n", - " [ 40 50 60]\n", - " [400 500 600]]]\n", - "\n", - "flatten_keep_1_axis(x):\n", - "[[ 1 2 3 10 20 30 100 200 300]\n", - " [ 4 5 6 40 50 60 400 500 600]]\n", - "\n", - "flatten_keep_2_axes(x):\n", - "[[[ 1 2 3]\n", - " [ 10 20 30]\n", - " [100 200 300]]\n", - "\n", - " [[ 4 5 6]\n", - " [ 40 50 60]\n", - " [400 500 600]]]\n" - ] - } - ], - "source": [ - "# Function defined in trax/layers/core.py:\n", - "def Flatten(n_axes_to_keep=1):\n", - " \"\"\"Returns a layer that combines one or more trailing axes of a tensor.\n", - "\n", - " Flattening keeps all the values of the input tensor, but reshapes it by\n", - " collapsing one or more trailing axes into a single axis. 
For example, a\n", - " `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape\n", - " `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.\n", - "\n", - " Args:\n", - " n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;\n", - " collapse only the axes after these.\n", - " \"\"\"\n", - " layer_name = f'Flatten_keep{n_axes_to_keep}'\n", - " def f(x):\n", - " in_rank = len(x.shape)\n", - " if in_rank \u003c= n_axes_to_keep:\n", - " raise ValueError(f'Input rank ({in_rank}) must exceed the number of '\n", - " f'axes to keep ({n_axes_to_keep}) after flattening.')\n", - " return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))\n", - " return tl.Fn(layer_name, f)\n", - "\n", - "flatten_keep_1_axis = Flatten(n_axes_to_keep=1)\n", - "flatten_keep_2_axes = Flatten(n_axes_to_keep=2)\n", - "\n", - "x = np.array([[[1, 2, 3],\n", - " [10, 20, 30],\n", - " [100, 200, 300]],\n", - " [[4, 5, 6],\n", - " [40, 50, 60],\n", - " [400, 500, 600]]])\n", - "\n", - "y1 = flatten_keep_1_axis(x)\n", - "y2 = flatten_keep_2_axes(x)\n", - "\n", - "print(f'x:\\n{x}\\n\\n'\n", - " f'flatten_keep_1_axis(x):\\n{y1}\\n\\n'\n", - " f'flatten_keep_2_axes(x):\\n{y2}')\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "cqM6WJwNhoHI" - }, - "source": [ - "### By defining a `Layer` subclass\n", - "\n", - "If you need a layer type that uses trainable weights (or state), you can extend\n", - "the base `Layer` class:\n", - "```\n", - "class Layer:\n", - " \"\"\"Base class for composable layers in a deep learning network.\n", - "\n", - " ...\n", - "\n", - " Authors of new layer subclasses typically override at most two methods of\n", - " the base `Layer` class:\n", - "\n", - " `forward(inputs)`:\n", - " Computes this layer's output as part of a forward pass through the model.\n", - "\n", - " `init_weights_and_state(self, input_signature)`:\n", - " Initializes weights and state for inputs with the given signature.\n", - "\n", - " ...\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "tZlzxNUigD_4" - }, - "source": [ - "The `forward` method uses *weights stored in the layer object* (`self.weights`)\n", - "to compute outputs from inputs. For example, here is the definition of\n", - "`forward` for Trax's `Dense` layer:\n", - "```\n", - " def forward(self, x):\n", - " \"\"\"Executes this layer as part of a forward pass through the model.\n", - "\n", - " Args:\n", - " x: Tensor of same shape and dtype as the input signature used to\n", - " initialize this layer.\n", - "\n", - " Returns:\n", - " Tensor of same shape and dtype as the input, except the final dimension\n", - " is the layer's `n_units` value.\n", - " \"\"\"\n", - " if self._use_bias:\n", - " if not isinstance(self.weights, (tuple, list)):\n", - " raise ValueError(f'Weights should be a (w, b) tuple or list; '\n", - " f'instead got: {self.weights}')\n", - " w, b = self.weights\n", - " return jnp.dot(x, w) + b # Affine map.\n", - " else:\n", - " w = self.weights\n", - " return jnp.dot(x, w) # Linear map.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PJEEyX9_iPbk" - }, - "source": [ - "Layer weights must be initialized before the layer can be used; the\n", - "`init_weights_and_state` method specifies how. 
Continuing the `Dense` example,\n",
-    "here is the corresponding initialization code:\n",
-    "```\n",
-    "  def init_weights_and_state(self, input_signature):\n",
-    "    \"\"\"Randomly initializes this layer's weights.\n",
-    "\n",
-    "    Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the\n",
-    "    default case), or a `w` tensor for layers created with `use_bias=False`.\n",
-    "\n",
-    "    Args:\n",
-    "      input_signature: `ShapeDtype` instance characterizing the input this layer\n",
-    "          should compute on.\n",
-    "    \"\"\"\n",
-    "    shape_w = (input_signature.shape[-1], self._n_units)\n",
-    "    shape_b = (self._n_units,)\n",
-    "    rng_w, rng_b = fastmath.random.split(self.rng, 2)\n",
-    "    w = self._kernel_initializer(shape_w, rng_w)\n",
-    "\n",
-    "    if self._use_bias:\n",
-    "      b = self._bias_initializer(shape_b, rng_b)\n",
-    "      self.weights = (w, b)\n",
-    "    else:\n",
-    "      self.weights = w\n",
-    "\n",
-    "```"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "D77mYZZD41QO"
-   },
-   "source": [
-    "### By defining a `Combinator` subclass\n",
-    "\n",
-    "*TBD*"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "PgdQvZ5G6Aei"
-   },
-   "source": [
-    "## 4. Testing and Debugging Layer Classes\n",
-    "\n",
-    "*TBD*"
-   ]
-  }
- ],
- "metadata": {
-  "colab": {
-   "collapsed_sections": [],
-   "last_runtime": {
-    "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
-    "kind": "private"
-   },
-   "name": "Trax Layers Intro",
-   "provenance": [
-    {
-     "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
-     "timestamp": 1569980697572
-    },
-    {
-     "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
-     "timestamp": 1563927451951
-    }
-   ]
-  },
-  "kernelspec": {
-   "display_name": "Python 3",
-   "name": "python3"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/trax/layers/metrics.py b/trax/layers/metrics.py
index 9cea73922..fe80da1cc 100644
--- a/trax/layers/metrics.py
+++ b/trax/layers/metrics.py
@@ -45,239 +45,243 @@
 """

 from trax import fastmath
-from trax import shapes
 from trax.fastmath import numpy as jnp
-from trax.layers import base
+from trax.layers import base, core
 from trax.layers import combinators as cb
-from trax.layers import core
+from trax.utils import shapes


 def CategoryAccuracy():
-  r"""Returns a layer that computes category prediction accuracy.
+    r"""Returns a layer that computes category prediction accuracy.

-  The layer takes two inputs:
+    The layer takes two inputs:

-    - A batch of activation vectors. The components in a given vector should
-      be mappable to a probability distribution in the following loose sense:
-      within a vector, a higher component value corresponds to a higher
-      probability, such that argmax within a vector (``axis=-1``) picks the
-      index (category) having the highest probablity.
+      - A batch of activation vectors. The components in a given vector should
+        be mappable to a probability distribution in the following loose sense:
+        within a vector, a higher component value corresponds to a higher
+        probability, such that argmax within a vector (``axis=-1``) picks the
+        index (category) having the highest probability.

-    - A batch of target categories; each target is an integer in
-      :math:`\{0, ..., N-1\}`.
+      - A batch of target categories; each target is an integer in
+        :math:`\{0, ..., N-1\}`.

-  The predicted category from each vector is the index of the highest-valued
-  vector component. The layer returns the accuracy of these predictions
-  averaged over the batch. 
-  """
-  def f(model_output, targets):  # pylint: disable=invalid-name
-    predictions = jnp.argmax(model_output, axis=-1)
-    shapes.assert_same_shape(predictions, targets)
-    n_total = predictions.size
-    n_correct = jnp.sum(jnp.equal(predictions, targets))
-    return n_correct / n_total
+    The predicted category from each vector is the index of the highest-valued
+    vector component. The layer returns the accuracy of these predictions
+    averaged over the batch.
+    """
+
+    def f(model_output, targets):  # pylint: disable=invalid-name
+        predictions = jnp.argmax(model_output, axis=-1)
+        shapes.assert_same_shape(predictions, targets)
+        n_total = predictions.size
+        n_correct = jnp.sum(jnp.equal(predictions, targets))
+        return n_correct / n_total

-  return base.Fn('CategoryAccuracy', f)
+    return base.Fn("CategoryAccuracy", f)


 def _n_weights_per_core(weights):  # pylint: disable=invalid-name
-  """Calculates the number of weights per core.
-
-  In multi-device settings, gradients and losses are averaged over all devices.
-  When loss is weighted and the number of weights can differ by device, e.g.,
-  when the weights represent the number of tokens in a batch of sentences (which
-  can differ from device to device), we want to make sure each token on each
-  device is weighted in the same way. This function ensures that by reporting
-  the number of weights per core in multi-core settings (and simply
-  np.sum(weights) in a single-core setting).
-
-  Args:
-    weights: tensor with arbitrary shape
-
-  Returns:
-    a scalar equal to np.sum(weights) in 1-machine settings and to the sum
-    of weights over all cores divided by the number of cores otherwise
-  """
-  weights_sum = jnp.sum(weights)
-  if fastmath.global_device_count() < 2:
-    return weights_sum
-  else:
-    try:
-      n_devices_total = fastmath.psum(1, 'batch')
-      return fastmath.psum(weights_sum, 'batch') / n_devices_total
-    except (NameError, ValueError):  # running outside of pmap, e.g., on init
-      return weights_sum  # fall back to the sum
+    """Calculates the number of weights per core.
+
+    In multi-device settings, gradients and losses are averaged over all devices.
+    When loss is weighted and the number of weights can differ by device, e.g.,
+    when the weights represent the number of tokens in a batch of sentences (which
+    can differ from device to device), we want to make sure each token on each
+    device is weighted in the same way. This function ensures that by reporting
+    the number of weights per core in multi-core settings (and simply
+    np.sum(weights) in a single-core setting).
+
+    Args:
+      weights: tensor with arbitrary shape
+
+    Returns:
+      a scalar equal to np.sum(weights) in 1-machine settings and to the sum
+      of weights over all cores divided by the number of cores otherwise
+    """
+    weights_sum = jnp.sum(weights)
+    if fastmath.global_device_count() < 2:
+        return weights_sum
+    else:
+        try:
+            n_devices_total = fastmath.psum(1, "batch")
+            return fastmath.psum(weights_sum, "batch") / n_devices_total
+        except (NameError, ValueError):  # running outside of pmap, e.g., on init
+            return weights_sum  # fall back to the sum


 def _non_nan(x):  # pylint: disable=invalid-name
-  """Replaces NaN values with zeros.
+    """Replaces NaN values with zeros.

-  A support function replaces NaN values with zeros to escape
-  the undefined behavior of the division by zero.
+    A helper that replaces NaN values with zeros to avoid the
+    undefined results of division by zero.

-  Args:
-    x: tensor with arbitrary shape.
+    Args:
+      x: tensor with arbitrary shape. 
- Returns: - Array with NaNs replaced with 0. - """ - return jnp.where(jnp.isnan(x), 0., x) + Returns: + Array with NaNs replaced with 0. + """ + return jnp.where(jnp.isnan(x), 0.0, x) def _precision_recall(predictions, targets, k): # pylint: disable=invalid-name - """Returns precision, recall, and intermediate values for the category `k`. - - A support function for calculating precision, recall, - and intermediate values for the single category `k` - for future use in metric layers. - - Args: - predictions: predicted categories. - targets: target categories. - k: a category number. - - Returns a tuple: - n_correct: a number of correct (or true) examples. - n_k_predictions: a number of predictions of the `k` category. - n_k_targets: a number of targets for the `k` category. - precision: a precision score. - recall: a recall score. - """ - n_correct = sum((predictions == k) & (targets == k)) - n_k_predictions = sum(predictions == k) - precision = _non_nan(n_correct / n_k_predictions) - n_k_targets = sum(targets == k) - recall = _non_nan(n_correct / n_k_targets) - return (n_correct, n_k_predictions, n_k_targets, precision, recall) + """Returns precision, recall, and intermediate values for the category `k`. + + A support function for calculating precision, recall, + and intermediate values for the single category `k` + for future use in metric layers. + + Args: + predictions: predicted categories. + targets: target categories. + k: a category number. + + Returns a tuple: + n_correct: a number of correct (or true) examples. + n_k_predictions: a number of predictions of the `k` category. + n_k_targets: a number of targets for the `k` category. + precision: a precision score. + recall: a recall score. + """ + n_correct = sum((predictions == k) & (targets == k)) + n_k_predictions = sum(predictions == k) + precision = _non_nan(n_correct / n_k_predictions) + n_k_targets = sum(targets == k) + recall = _non_nan(n_correct / n_k_targets) + return (n_correct, n_k_predictions, n_k_targets, precision, recall) def _f_score(precision, recall, beta2): # pylint: disable=invalid-name - """Returns F-score. + """Returns F-score. - Args: - precision: a precision score. - recall: a recall score. - beta2: a square of the parameter that determines the weight of recall. + Args: + precision: a precision score. + recall: a recall score. + beta2: a square of the parameter that determines the weight of recall. - A support function to calculate F-score for the single category. - """ - return _non_nan( - (beta2 + 1) * (precision * recall) / ((beta2 * precision) + recall)) + A support function to calculate F-score for the single category. + """ + return _non_nan((beta2 + 1) * (precision * recall) / ((beta2 * precision) + recall)) def WeightedCategoryAccuracy(): - r"""Returns a layer that computes a weighted category prediction accuracy. + r"""Returns a layer that computes a weighted category prediction accuracy. - The layer takes three inputs: + The layer takes three inputs: - - A batch of activation vectors. The components in a given vector should - be mappable to a probability distribution in the following loose sense: - within a vector, a higher component value corresponds to a higher - probability, such that argmax within a vector (``axis=-1``) picks the - index (category) having the highest probablity. + - A batch of activation vectors. 
The components in a given vector should
+        be mappable to a probability distribution in the following loose sense:
+        within a vector, a higher component value corresponds to a higher
+        probability, such that argmax within a vector (``axis=-1``) picks the
+        index (category) having the highest probability.

-    - A batch of target categories; each target is an integer in
-      :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
-      depth/dimensionality.
+      - A batch of target categories; each target is an integer in
+        :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
+        depth/dimensionality.

-    - A batch of weights, which matches or can be broadcast to match the shape
-      of the target ndarray. This arg can give uneven weighting to different
-      items in the batch (depending, for instance, on the item's target
-      category).
+      - A batch of weights, which matches or can be broadcast to match the shape
+        of the target ndarray. This arg can give uneven weighting to different
+        items in the batch (depending, for instance, on the item's target
+        category).

-  The predicted category from each vector is the index of the highest-valued
-  vector component. The layer returns a weighted average accuracy of these
-  predictions.
-  """
-  def f(model_output, targets, weights):  # pylint: disable=invalid-name
-    predictions = jnp.argmax(model_output, axis=-1)
-    shapes.assert_same_shape(predictions, targets)
-    ones_and_zeros = jnp.equal(predictions, targets)
-    return jnp.sum(ones_and_zeros * weights) / _n_weights_per_core(weights)
+    The predicted category from each vector is the index of the highest-valued
+    vector component. The layer returns a weighted average accuracy of these
+    predictions.
+    """

-  return base.Fn('WeightedCategoryAccuracy', f)
+    def f(model_output, targets, weights):  # pylint: disable=invalid-name
+        predictions = jnp.argmax(model_output, axis=-1)
+        shapes.assert_same_shape(predictions, targets)
+        ones_and_zeros = jnp.equal(predictions, targets)
+        return jnp.sum(ones_and_zeros * weights) / _n_weights_per_core(weights)
+
+    return base.Fn("WeightedCategoryAccuracy", f)


 def CategoryCrossEntropy(label_smoothing=None):
-  r"""Returns a layer that computes cross-entropy from activations and integers.
+    r"""Returns a layer that computes cross-entropy from activations and integers.

-  The layer takes two inputs:
+    The layer takes two inputs:

-    - A batch of activation vectors. The components in a given vector should
-      be pre-softmax activations (mappable to a probability distribution via
-      softmax). For performance reasons, the softmax and cross-entropy
-      computations are combined inside the layer.
+      - A batch of activation vectors. The components in a given vector should
+        be pre-softmax activations (mappable to a probability distribution via
+        softmax). For performance reasons, the softmax and cross-entropy
+        computations are combined inside the layer.

-    - A batch of target categories; each target is an integer in
-      :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
-      depth/dimensionality. 
- To compute cross-entropy per batch item, the layer derives probability - distributions: + To compute cross-entropy per batch item, the layer derives probability + distributions: - - from model output (vectors): :math:`\ q = \text{softmax}(v)` + - from model output (vectors): :math:`\ q = \text{softmax}(v)` - - from target categories (integers): :math:`\ p = \text{one_hot}(n)` or - :math:`p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}`, - where :math:`\varepsilon` is the label smoothing factor. + - from target categories (integers): :math:`\ p = \text{one_hot}(n)` or + :math:`p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}`, + where :math:`\varepsilon` is the label smoothing factor. - (The conversion of integer category targets to one-hot vectors amounts to - assigning all the probability mass to the target category.) Cross-entropy - per batch item is computed between the resulting distributions: + (The conversion of integer category targets to one-hot vectors amounts to + assigning all the probability mass to the target category.) Cross-entropy + per batch item is computed between the resulting distributions: - .. math:: - \text{cross_entropy} = - \sum_{i=0}^{N-1} p_i \log q_i + .. math:: + \text{cross_entropy} = - \sum_{i=0}^{N-1} p_i \log q_i - The layer returns the average of these cross-entropy values over all items in - the batch. + The layer returns the average of these cross-entropy values over all items in + the batch. - Args: - label_smoothing: Creates soft targets if provided. Must be between 0 and 1. - """ - def f(model_output, targets): # pylint: disable=invalid-name - cross_entropies = _category_cross_entropy( - model_output, targets, label_smoothing, 0.0) - return jnp.average(cross_entropies) + Args: + label_smoothing: Creates soft targets if provided. Must be between 0 and 1. + """ + + def f(model_output, targets): # pylint: disable=invalid-name + cross_entropies = _category_cross_entropy( + model_output, targets, label_smoothing, 0.0 + ) + return jnp.average(cross_entropies) - return base.Fn('CategoryCrossEntropy', f) + return base.Fn("CategoryCrossEntropy", f) def WeightedCategoryCrossEntropy(label_smoothing=None, cutoff=0.0): - r"""Returns a layer like ``CategoryCrossEntropy``, with weights as third input. + r"""Returns a layer like ``CategoryCrossEntropy``, with weights as third input. - The layer takes three inputs: + The layer takes three inputs: - - A batch of activation vectors. The components in a given vector should - be pre-softmax activations (mappable to a probability distribution via - softmax). For performance reasons, the softmax and cross-entropy - computations are combined inside the layer. + - A batch of activation vectors. The components in a given vector should + be pre-softmax activations (mappable to a probability distribution via + softmax). For performance reasons, the softmax and cross-entropy + computations are combined inside the layer. - - A batch of target categories; each target is an integer in - :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector - depth/dimensionality. + - A batch of target categories; each target is an integer in + :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector + depth/dimensionality. - - A batch of weights, which matches or can be broadcast to match the shape - of the target ndarray. This arg can give uneven weighting to different - items in the batch (depending, for instance, on the item's target - category). 
+ - A batch of weights, which matches or can be broadcast to match the shape + of the target ndarray. This arg can give uneven weighting to different + items in the batch (depending, for instance, on the item's target + category). - The layer returns the weighted average of these cross-entropy values over all - items in the batch. + The layer returns the weighted average of these cross-entropy values over all + items in the batch. - Args: - label_smoothing: Creates soft targets if provided. Must be between 0 and 1. - cutoff: Prevent loss lower than this cutoff (0.0 meaning none by default). - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - cross_entropies = _category_cross_entropy( - model_output, targets, label_smoothing, cutoff) - return jnp.sum(cross_entropies * weights) / _n_weights_per_core(weights) + Args: + label_smoothing: Creates soft targets if provided. Must be between 0 and 1. + cutoff: Prevent loss lower than this cutoff (0.0 meaning none by default). + """ - return base.Fn('WeightedCategoryCrossEntropy', f) + def f(model_output, targets, weights): # pylint: disable=invalid-name + cross_entropies = _category_cross_entropy( + model_output, targets, label_smoothing, cutoff + ) + return jnp.sum(cross_entropies * weights) / _n_weights_per_core(weights) + + return base.Fn("WeightedCategoryCrossEntropy", f) def BinaryCrossEntropy(): - r"""Returns a layer that computes cross-entropy for binary classification. + r"""Returns a layer that computes cross-entropy for binary classification. The layer takes two inputs: @@ -305,156 +309,168 @@ def BinaryCrossEntropy(): The layer returns the average of these cross-entropy values over all items in the batch. """ - def f(model_output, targets): # pylint: disable=invalid-name - probabilities = fastmath.expit(model_output) - binary_entropies = - (targets * jnp.log(probabilities) + - (1 - targets) * (jnp.log(1 - probabilities))) - return jnp.average(binary_entropies) - - return base.Fn('BinaryCrossEntropy', f) + def f(model_output, targets): # pylint: disable=invalid-name + probabilities = fastmath.expit(model_output) + binary_entropies = -( + targets * jnp.log(probabilities) + + (1 - targets) * (jnp.log(1 - probabilities)) + ) + return jnp.average(binary_entropies) -def MaskedSequenceAccuracy(): - r"""Returns a layer that computes sequence prediction accuracy with masking. + return base.Fn("BinaryCrossEntropy", f) - This layer type is intended for variable length sequences, especially text, - represented as a batch of fixed-length sequences via padding for unused - positions. - - The layer takes three inputs: - - A batch of sequences of activation vectors. The components in a given - vector should be mappable to a probability distribution in the following - loose sense: within a vector, a higher component value corresponds to a - higher probability, such that argmax within a vector (``axis=-1``) picks - the index having the highest probablity. In text modeling, the index - represents a token id from a predetermined token vocabulary (or padding). - - - A batch of target integer sequences, with values in - :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector - depth/dimensionality. In text modeling, these sequences typically - represent token ids from a predetermined token vocabulary (or padding). - - - A batch of weights/masks, which matches or can be broadcast to match the - shape of the target ndarray. 
This arg is used to give weight 0 to padding - positions, which masks those positions out of the calculation. Only the - zero/non-zero distinction matters; all non-zero values are treated alike - as signaling non-masked (i.e., valid/in-use) positions. - - The predicted integer value for each sequence position is the index of the - highest-valued component of the position's vector. A predicted integer - sequence is judged correct if it matches the target integer sequence in all - non-zero-weighted positions. The layer returns the accuracy of predicted - sequences averaged over the batch. - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - predictions = jnp.argmax(model_output, axis=-1) - shapes.assert_same_shape(predictions, targets) - position_is_padding = jnp.equal(weights, 0) - position_is_accurate = jnp.logical_or(jnp.equal(predictions, targets), - position_is_padding) - sequence_is_accurate = jnp.all(position_is_accurate, axis=-1) - return jnp.average(sequence_is_accurate) - - return base.Fn('MaskedSequenceAccuracy', f) +def MaskedSequenceAccuracy(): + r"""Returns a layer that computes sequence prediction accuracy with masking. + + This layer type is intended for variable-length sequences, especially text, + represented as a batch of fixed-length sequences via padding for unused + positions. + + The layer takes three inputs: + + - A batch of sequences of activation vectors. The components in a given + vector should be mappable to a probability distribution in the following + loose sense: within a vector, a higher component value corresponds to a + higher probability, such that argmax within a vector (``axis=-1``) picks + the index having the highest probability. In text modeling, the index + represents a token id from a predetermined token vocabulary (or padding). + + - A batch of target integer sequences, with values in + :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector + depth/dimensionality. In text modeling, these sequences typically + represent token ids from a predetermined token vocabulary (or padding). + + - A batch of weights/masks, which matches or can be broadcast to match the + shape of the target ndarray. This arg is used to give weight 0 to padding + positions, which masks those positions out of the calculation. Only the + zero/non-zero distinction matters; all non-zero values are treated alike + as signaling non-masked (i.e., valid/in-use) positions. + + The predicted integer value for each sequence position is the index of the + highest-valued component of the position's vector. A predicted integer + sequence is judged correct if it matches the target integer sequence in all + non-zero-weighted positions. The layer returns the accuracy of predicted + sequences averaged over the batch. + """ + + def f(model_output, targets, weights): # pylint: disable=invalid-name + predictions = jnp.argmax(model_output, axis=-1) + shapes.assert_same_shape(predictions, targets) + position_is_padding = jnp.equal(weights, 0) + position_is_accurate = jnp.logical_or( + jnp.equal(predictions, targets), position_is_padding + ) + sequence_is_accurate = jnp.all(position_is_accurate, axis=-1) + return jnp.average(sequence_is_accurate) + + return base.Fn("MaskedSequenceAccuracy", f) def Accuracy(classifier=core.ArgMax()): - """Returns a layer that computes mean category prediction accuracy. + """Returns a layer that computes mean category prediction accuracy. - DEPRECATED; use ``WeightedCategoryAccuracy`` instead.
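A short sketch of MaskedSequenceAccuracy showing how zero weights mask out padding positions; the data and the expected result come from the deleted metrics_test.py:

import numpy as np
import trax.layers as tl

layer = tl.MaskedSequenceAccuracy()
targets = np.array([[0, 1, 0, 0],
                    [1, 0, 1, 0]])
weights = np.array([[1., 1., 1., 0.],   # final position is padding
                    [1., 1., 1., 0.]])
# The masked final position would be predicted wrong but is ignored,
# so both sequences count as correct and accuracy is 1.0.
model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]],
                          [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]])
accuracy = layer([model_outputs, targets, weights])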
+ DEPRECATED; use ``WeightedCategoryAccuracy`` instead. - Args: - classifier: Layer that transforms activation vectors into category - predictions. - """ - return cb.Serial(classifier, - _Accuracy(), - _WeightedMean(), - name='Accuracy', - sublayers_to_print=[]) + Args: + classifier: Layer that transforms activation vectors into category + predictions. + """ + return cb.Serial( + classifier, _Accuracy(), _WeightedMean(), name="Accuracy", sublayers_to_print=[] + ) def SequenceAccuracy(classifier=core.ArgMax()): - """Returns a layer that computes mean sequence prediction accuracy. + """Returns a layer that computes mean sequence prediction accuracy. - DEPRECATED; use ``MaskedSequenceAccuracy`` instead. + DEPRECATED; use ``MaskedSequenceAccuracy`` instead. - Args: - classifier: Layer that transforms activation vectors into category - predictions. - """ - return cb.Serial(classifier, - _Accuracy(), - _WeightedSequenceMean(), - name='SequenceAccuracy', - sublayers_to_print=[]) + Args: + classifier: Layer that transforms activation vectors into category + predictions. + """ + return cb.Serial( + classifier, + _Accuracy(), + _WeightedSequenceMean(), + name="SequenceAccuracy", + sublayers_to_print=[], + ) def CrossEntropyLoss(): - """Returns a layer that outputs multiclass prediction-target cross-entropy. + """Returns a layer that outputs multiclass prediction-target cross-entropy. - DEPRECATED; refactor to use ``WeightedCategoryCrossEntropy`` or - ``CategoryCrossEntropy`` instead. + DEPRECATED; refactor to use ``WeightedCategoryCrossEntropy`` or + ``CategoryCrossEntropy`` instead. - (``CrossEntropyLoss`` by itself does not compute cross-entropy. In older - code, this layer had to be preceded by ``LogSoftmax``, and the two layers - together did the work of converting category information to probability - distributions and computing the cross-entropy between those distributions. - All this is now done by ``WeightedCategoryCrossEntropy``.) - """ - return cb.Serial(_CrossEntropy(), - _WeightedMean(), - name='CrossEntropyLoss', - sublayers_to_print=[]) + (``CrossEntropyLoss`` by itself does not compute cross-entropy. In older + code, this layer had to be preceded by ``LogSoftmax``, and the two layers + together did the work of converting category information to probability + distributions and computing the cross-entropy between those distributions. + All this is now done by ``WeightedCategoryCrossEntropy``.) + """ + return cb.Serial( + _CrossEntropy(), _WeightedMean(), name="CrossEntropyLoss", sublayers_to_print=[] + ) def CrossEntropyLossWithLogSoftmax(): - """Mean prediction-target cross-entropy for multiclass classification.""" - return cb.Serial(core.LogSoftmax(), _CrossEntropy(), _WeightedMean(), - name='CrossEntropyLossWithLogSoftmax', - sublayers_to_print=[]) + """Mean prediction-target cross-entropy for multiclass classification.""" + return cb.Serial( + core.LogSoftmax(), + _CrossEntropy(), + _WeightedMean(), + name="CrossEntropyLossWithLogSoftmax", + sublayers_to_print=[], + ) def BinaryCrossEntropyLoss(): - """Returns a layer that outputs binary prediction-target cross-entropy. + """Returns a layer that outputs binary prediction-target cross-entropy. - DEPRECATED; refactor to use ``BinaryCrossEntropy`` instead. (The newer - ``BinaryCrossEntropy`` does not use weights, so refactor accordingly. Unless - and until clear motivating use cases arise, the library will not include a - binary cross-entropy function with weights.) 
- """ - return cb.Serial(_BinaryCrossEntropy(), - _WeightedMean(), - name='BinaryCrossEntropyLoss', - sublayers_to_print=[]) + DEPRECATED; refactor to use ``BinaryCrossEntropy`` instead. (The newer + ``BinaryCrossEntropy`` does not use weights, so refactor accordingly. Unless + and until clear motivating use cases arise, the library will not include a + binary cross-entropy function with weights.) + """ + return cb.Serial( + _BinaryCrossEntropy(), + _WeightedMean(), + name="BinaryCrossEntropyLoss", + sublayers_to_print=[], + ) def L2Loss(): - r"""Returns a layer that computes an L2-like loss for one batch. + r"""Returns a layer that computes an L2-like loss for one batch. - The layer takes three inputs: + The layer takes three inputs: - - Model output from one batch, an ndarray of float-valued elements. + - Model output from one batch, an ndarray of float-valued elements. - - A batch of element-wise target values, which matches the shape of the - model output. + - A batch of element-wise target values, which matches the shape of the + model output. - - A batch of weights, which matches the shape of the model output. + - A batch of weights, which matches the shape of the model output. - The layer returns a weighted average of element-wise squared error terms - :math:`(y_i - t_i)^2`. - """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - shapes.assert_same_shape(model_output, targets) - shapes.assert_same_shape(model_output, weights) - weighted_sse = weights * (model_output - targets)**2 - return jnp.sum(weighted_sse) / jnp.sum(weights) - return base.Fn('L2Loss', f) + The layer returns a weighted average of element-wise squared error terms + :math:`(y_i - t_i)^2`. + """ + + def f(model_output, targets, weights): # pylint: disable=invalid-name + shapes.assert_same_shape(model_output, targets) + shapes.assert_same_shape(model_output, weights) + weighted_sse = weights * (model_output - targets) ** 2 + return jnp.sum(weighted_sse) / jnp.sum(weights) + + return base.Fn("L2Loss", f) def SmoothL1Loss(): - r"""Returns a layer that computes a weighted, smoothed L1 loss for one batch. + r"""Returns a layer that computes a weighted, smoothed L1 loss for one batch. The layer takes three inputs: @@ -476,178 +492,202 @@ def SmoothL1Loss(): The layer returns a weighted average of these element-wise values. """ - def f(model_output, targets, weights): # pylint: disable=invalid-name - shapes.assert_same_shape(model_output, targets) - shapes.assert_same_shape(model_output, weights) - l1_dist = jnp.abs(model_output - targets) - smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5) - weighted_smooth_dist = weights * smooth_dist - return jnp.sum(weighted_smooth_dist) / jnp.sum(weights) - return base.Fn('SmoothL1Loss', f) + def f(model_output, targets, weights): # pylint: disable=invalid-name + shapes.assert_same_shape(model_output, targets) + shapes.assert_same_shape(model_output, weights) + l1_dist = jnp.abs(model_output - targets) + smooth_dist = jnp.where(l1_dist < 1, 0.5 * l1_dist**2, l1_dist - 0.5) + weighted_smooth_dist = weights * smooth_dist + return jnp.sum(weighted_smooth_dist) / jnp.sum(weights) -def MacroAveragedFScore(beta=1., initial_category_index=0): - r"""Returns a layer that computes a macro-averaged F-score. + return base.Fn("SmoothL1Loss", f) - The macro-averaged F-score summarize how well the classifier's `k` predictions - align with the observed/gold instances of `k`. It additionally cares about - all the classes equally regardless of their size. 
- Args: - beta: a parameter that determines the weight of recall in the F-score. - initial_category_index: an index of the initial category. +def MacroAveragedFScore(beta=1.0, initial_category_index=0): + r"""Returns a layer that computes a macro-averaged F-score. - The layer takes two inputs: + The macro-averaged F-score summarizes how well the classifier's `k` predictions + align with the observed/gold instances of `k`. It additionally treats + all the classes equally regardless of their size. - - Model output from one batch, an ndarray of float-valued elements. + Args: + beta: a parameter that determines the weight of recall in the F-score. + initial_category_index: an index of the initial category. - - A batch of element-wise target values, which matches the shape of the - model output. + The layer takes two inputs: - The layer returns an macro-averaged F-score across all the classes. - """ - def f(model_output, targets): # pylint: disable=invalid-name - beta2 = beta ** 2 - predictions = jnp.argmax(model_output, axis=-1) - n_categories = model_output.shape[-1] - f_scores = jnp.empty(0) - for k in range(initial_category_index, n_categories): - _, _, _, precision, recall = _precision_recall(predictions, targets, k) - f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) - return jnp.mean(f_scores) + - Model output from one batch, an ndarray of float-valued elements. - return base.Fn('MacroAveragedFScore', f) + - A batch of element-wise target values, which matches the shape of the + model output. + The layer returns a macro-averaged F-score across all the classes. + """ -def WeightedFScore(beta=1., initial_category_index=0): - """Returns a layer that computes a weighted F-score. + def f(model_output, targets): # pylint: disable=invalid-name + beta2 = beta**2 + predictions = jnp.argmax(model_output, axis=-1) + n_categories = model_output.shape[-1] + f_scores = jnp.empty(0) + for k in range(initial_category_index, n_categories): + _, _, _, precision, recall = _precision_recall(predictions, targets, k) + f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) + return jnp.mean(f_scores) - The weighted F-score summarize how well the classifier's `k` predictions - align with the observed/gold instances of `k`. It additionally - weights the summary by the number of observed/gold and predicted examples - in each class. + return base.Fn("MacroAveragedFScore", f) - Args: - beta: a parameter that determines the weight of recall in the F-score. - initial_category_index: an index of the initial category. - The layer takes two inputs: +def WeightedFScore(beta=1.0, initial_category_index=0): + """Returns a layer that computes a weighted F-score. - - Model output from one batch, an ndarray of float-valued elements. + The weighted F-score summarizes how well the classifier's `k` predictions + align with the observed/gold instances of `k`. It additionally + weights the summary by the number of observed/gold and predicted examples + in each class. - - A batch of element-wise target values, which matches the shape of the - model output. + Args: + beta: a parameter that determines the weight of recall in the F-score. + initial_category_index: an index of the initial category. - The layer returns a weighted F-score across all the classes.
- """ - def f(model_output, targets): # pylint: disable=invalid-name - beta2 = beta ** 2 - predictions = jnp.argmax(model_output, axis=-1) - n_categories = model_output.shape[-1] - f_scores = jnp.empty(0) - weights = jnp.empty(0) - for k in range(initial_category_index, n_categories): - _, _, n_k_targets, precision, recall = _precision_recall( - predictions, targets, k) - f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) - weights = jnp.append(weights, n_k_targets) - return jnp.average(f_scores, weights=weights) + The layer takes two inputs: + + - Model output from one batch, an ndarray of float-valued elements. + + - A batch of element-wise target values, which matches the shape of the + model output. - return base.Fn('WeightedFScore', f) + The layer returns a weighted F-score across all the classes. + """ + + def f(model_output, targets): # pylint: disable=invalid-name + beta2 = beta**2 + predictions = jnp.argmax(model_output, axis=-1) + n_categories = model_output.shape[-1] + f_scores = jnp.empty(0) + weights = jnp.empty(0) + for k in range(initial_category_index, n_categories): + _, _, n_k_targets, precision, recall = _precision_recall( + predictions, targets, k + ) + f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2)) + weights = jnp.append(weights, n_k_targets) + return jnp.average(f_scores, weights=weights) + + return base.Fn("WeightedFScore", f) def WeightedSum(): - """Returns a layer that computes a weighted sum of the given values.""" - def f(values, weights): # pylint: disable=invalid-name - return jnp.sum(values * weights) - return base.Fn('WeightedSum', f) + """Returns a layer that computes a weighted sum of the given values.""" + + def f(values, weights): # pylint: disable=invalid-name + return jnp.sum(values * weights) + + return base.Fn("WeightedSum", f) def _Accuracy(): - """Returns a layer that scores predicted versus target category.""" - def f(predicted_category, target_category): # pylint: disable=invalid-name - # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. - # shapes.assert_same_shape(predicted_category, target_category) - return jnp.equal(predicted_category, target_category).astype(jnp.float32) - return base.Fn('_Accuracy', f) + """Returns a layer that scores predicted versus target category.""" + + def f(predicted_category, target_category): # pylint: disable=invalid-name + # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. + # shapes.assert_same_shape(predicted_category, target_category) + return jnp.equal(predicted_category, target_category).astype(jnp.float32) + + return base.Fn("_Accuracy", f) def _CrossEntropy(): - """Returns a layer that computes prediction-target cross entropies.""" - def f(model_output, target_category): # pylint: disable=invalid-name - # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. 
- # shapes.assert_shape_equals(target_category, model_output.shape[:-1]) - target_distribution = core.one_hot(target_category, model_output.shape[-1]) - return -1.0 * jnp.sum(model_output * target_distribution, axis=-1) - return base.Fn('_CrossEntropy', f) + """Returns a layer that computes prediction-target cross entropies.""" + + def f(model_output, target_category): # pylint: disable=invalid-name + target_distribution = core.one_hot(target_category, model_output.shape[-1]) + return jnp.negative( + jnp.sum(jnp.multiply(model_output, target_distribution), axis=-1) + ) + + return base.Fn("_CrossEntropy", f) def _BinaryCrossEntropy(): - """Returns a layer that computes prediction-target cross entropies.""" - def f(model_output, target_category): # pylint: disable=invalid-name - shapes.assert_same_shape(model_output, target_category) - batch_size = model_output.shape[0] - j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output)) - j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output)) - j = -1.0/batch_size * jnp.squeeze(j) - return j - return base.Fn('_BinaryCrossEntropy', f) + """Returns a layer that computes prediction-target cross entropies.""" + + def f(model_output, target_category): # pylint: disable=invalid-name + shapes.assert_same_shape(model_output, target_category) + batch_size = model_output.shape[0] + j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output)) + j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output)) + j = -1.0 / batch_size * jnp.squeeze(j) + return j + + return base.Fn("_BinaryCrossEntropy", f) def CrossEntropySum(): - """Sum of prediction-target cross entropies for multiclass classification.""" - return cb.Serial(_CrossEntropy(), - WeightedSum(), - name='CrossEntropySum', - sublayers_to_print=[]) + """Sum of prediction-target cross entropies for multiclass classification.""" + return cb.Serial( + _CrossEntropy(), WeightedSum(), name="CrossEntropySum", sublayers_to_print=[] + ) def BinaryCrossEntropySum(): - """Sum of prediction-target cross entropies for binary classification.""" - return cb.Serial(_BinaryCrossEntropy(), - WeightedSum(), - name='BinaryCrossEntropySum', - sublayers_to_print=[]) + """Sum of prediction-target cross entropies for binary classification.""" + return cb.Serial( + _BinaryCrossEntropy(), + WeightedSum(), + name="BinaryCrossEntropySum", + sublayers_to_print=[], + ) + + # pylint: enable=no-value-for-parameter def _WeightedMean(): - """Returns a layer that computes a weighted mean of the given values.""" - def f(values, weights): # pylint: disable=invalid-name - return jnp.sum(values * weights) / _n_weights_per_core(weights) - return base.Fn('_WeightedMean', f) + """Returns a layer that computes a weighted mean of the given values.""" + + def f(values, weights): # pylint: disable=invalid-name + return jnp.divide( + jnp.sum(jnp.multiply(values, weights)), _n_weights_per_core(weights) + ) + + return base.Fn("_WeightedMean", f) def _WeightedSequenceMean(): - """Returns a layer that computes a weighted sequence accuracy mean.""" - def f(values, weights): # pylint: disable=invalid-name - # This function assumes weights are 0 or 1. - # Then compute 1: not-correct, 0: correct or masked - not_correct = (1.0 - values) * weights - axis_to_sum = list(range(1, len(not_correct.shape))) - # Summing not-correct on all axes but batch. We're summing 0s and 1s, - # so the sum is 0 if it's all 0 and >=1 in all other cases. 
- not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum) - # Sequence is correct if not_correct_seq is 0, reverting here. - correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq) - return jnp.mean(correct_seq) # Mean over batch. - return base.Fn('_WeightedSequenceMean', f) + """Returns a layer that computes a weighted sequence accuracy mean.""" + + def f(values, weights): # pylint: disable=invalid-name + # This function assumes weights are 0 or 1. + # Then compute 1: not-correct, 0: correct or masked + not_correct = (1.0 - values) * weights + axis_to_sum = list(range(1, len(not_correct.shape))) + # Summing not-correct on all axes but batch. We're summing 0s and 1s, + # so the sum is 0 if it's all 0 and >=1 in all other cases. + not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum) + # Sequence is correct if not_correct_seq is 0, reverting here. + correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq) + return jnp.mean(correct_seq) # Mean over batch. + + return base.Fn("_WeightedSequenceMean", f) def _category_cross_entropy( # pylint: disable=invalid-name - model_output, targets, label_smoothing, cutoff): - """Computes category cross entropy with label smoothing.""" - n_categories = model_output.shape[-1] - target_distributions = core.one_hot(targets, n_categories) - if label_smoothing: - if label_smoothing < 0. or label_smoothing > 1.: - raise ValueError( - f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.') - target_distributions *= (1. - label_smoothing) - target_distributions += label_smoothing / n_categories - model_log_distributions = core.log_softmax(model_output) - cross_ent = - jnp.sum(target_distributions * model_log_distributions, axis=-1) - if cutoff > 0.0: - return jnp.maximum(cross_ent, cutoff) - cutoff - else: - return cross_ent + model_output, targets, label_smoothing, cutoff +): + """Computes category cross entropy with label smoothing.""" + n_categories = model_output.shape[-1] + target_distributions = core.one_hot(targets, n_categories) + if label_smoothing: + if label_smoothing < 0.0 or label_smoothing > 1.0: + raise ValueError( + f"Arg label_smoothing ({label_smoothing}) must be between 0 and 1." + ) + target_distributions *= 1.0 - label_smoothing + target_distributions += label_smoothing / n_categories + model_log_distributions = core.log_softmax(model_output) + cross_ent = -jnp.sum(target_distributions * model_log_distributions, axis=-1) + if cutoff > 0.0: + return jnp.maximum(cross_ent, cutoff) - cutoff + else: + return cross_ent diff --git a/trax/layers/metrics_test.py b/trax/layers/metrics_test.py deleted file mode 100644 index 3c59da790..000000000 --- a/trax/layers/metrics_test.py +++ /dev/null @@ -1,430 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
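The target-distribution construction in _category_cross_entropy above can be checked in plain NumPy; this sketch is illustrative only and not part of the patch:

import numpy as np

n_categories = 4
epsilon = 0.1
one_hot = np.eye(n_categories)[np.array([0, 1])]  # hard targets for classes 0, 1
smoothed = one_hot * (1.0 - epsilon) + epsilon / n_categories
# The row for target 0 becomes [0.925, 0.025, 0.025, 0.025] and still sums to 1.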
- -"""Tests for metrics layers.""" - -from absl.testing import absltest -import numpy as np -import trax.layers as tl - - -class MetricsTest(absltest.TestCase): - - def test_category_accuracy(self): - layer = tl.CategoryAccuracy() - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets]) - self.assertEqual(accuracy, 1 / 3) - - def test_weighted_category_accuracy_even_weights(self): - layer = tl.WeightedCategoryAccuracy() - weights = np.array([1., 1., 1.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1 / 3) - - def test_weighted_category_accuracy_uneven_weights(self): - layer = tl.WeightedCategoryAccuracy() - weights = np.array([1., 5., 2.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .7, .1, 0.], - [.2, .7, .1, 0.], - [.2, .7, .1, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .625) - - def test_category_cross_entropy(self): - layer = tl.CategoryCrossEntropy() - targets = np.array([0, 1]) - - # Near-perfect prediction (for both items in batch). - model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .001, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .665, places=3) - - # First item near perfect, second item more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .333, places=3) - - def test_category_cross_entropy_with_label_smoothing(self): - epsilon = 0.01 - layer = tl.CategoryCrossEntropy(label_smoothing=epsilon) - targets = np.array([0, 1]) - - # Near-perfect prediction (for both items in batch). - model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .069, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .682, places=3) - - # First item near perfect, second item more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .375, places=3) - - def test_weighted_category_cross_entropy(self): - layer = tl.WeightedCategoryCrossEntropy() - targets = np.array([0, 1]) - weights = np.array([30, 10]) - - # Near-perfect prediction (for both items in batch). 
- model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .001, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .665, places=3) - - # First item (with 75% weight) near perfect, second more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .167, places=3) - - def test_weighted_category_cross_entropy_with_label_smoothing(self): - epsilon = 0.01 - layer = tl.WeightedCategoryCrossEntropy(label_smoothing=epsilon) - targets = np.array([0, 1]) - weights = np.array([30, 10]) - - # Near-perfect prediction (for both items in batch). - model_outputs = np.array([[9., 2., 0., -2.], - [2., 9., 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .069, places=3) - - # More right than wrong (for both items in batch). - model_outputs = np.array([[2.2, 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .682, places=3) - - # First item (with 75% weight) near perfect, second more right than wrong. - model_outputs = np.array([[9., 2., 0., -2.], - [2., 2.2, 0., -2.]]) - loss = layer([model_outputs, targets, weights]) - self.assertAlmostEqual(loss, .222, places=3) - - def test_masked_sequence_accuracy(self): - layer = tl.MaskedSequenceAccuracy() - targets = np.array([[0, 1, 0, 0], - [1, 0, 1, 0]]) - weights = np.array([[1., 1., 1., 0.], - [1., 1., 1., 0.]]) - - # Model gets both sequences right; output in final position would give - # wrong category but is ignored. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]], - [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.) - - # Model gets the first element of the first sequence barely wrong. - model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.6, .4]], - [[.3, .7], [.8, .2], [.1, .9], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .5) - - # Model gets second-to-last element of each sequence barely wrong. - model_outputs = np.array([[[.9, .1], [.2, .8], [.48, .52], [.6, .4]], - [[.3, .7], [.8, .2], [.51, .49], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 0.) - - def test_binary_cross_entropy(self): - layer = tl.BinaryCrossEntropy() - targets = np.array([1, 1, 0, 0, 0]) - - # Near-perfect prediction for all five items in batch. - model_outputs = np.array([9., 9., -9., -9., -9.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 0.000123, places=6) - - # More right than wrong for all five items in batch. - model_outputs = np.array([1., 1., -1., -1., -1.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 0.313, places=3) - - # Near-perfect for 2, more right than wrong for 3. - model_outputs = np.array([9., 1., -1., -1., -9.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 0.188, places=3) - - # More wrong than right for all five. 
- model_outputs = np.array([-1., -1., 1., 1., 1.]) - metric_output = layer([model_outputs, targets]) - self.assertAlmostEqual(metric_output, 1.313, places=3) - - def test_accuracy_even_weights(self): - layer = tl.Accuracy() - weights = np.array([1., 1., 1.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .1, .7, 0.], - [.2, .1, .7, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1 / 3) - - def test_accuracy_uneven_weights(self): - layer = tl.Accuracy() - weights = np.array([1., 5., 2.]) - targets = np.array([0, 1, 2]) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.2, .7, .1, 0.], - [.2, .1, .7, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.2, .7, .1, 0.], - [.2, .7, .1, 0.], - [.2, .7, .1, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .625) - - model_outputs = np.array([[.7, .2, .1, 0.], - [.7, .2, .1, 0.], - [.7, .2, .1, 0.]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .125) - - def test_accuracy_binary_classifier(self): - layer = tl.Accuracy(classifier=tl.ThresholdToBinary()) - targets = np.array([[0, 0, 1, 1], - [1, 1, 1, 0]]) - weights = np.ones_like(targets) - - model_outputs = np.array([[.499, .500, .501, .502], - [.503, .502, .501, .500]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.0) - - model_outputs = np.array([[.498, .499, .500, .501], - [.502, .501, .500, .499]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .75) - - def test_sequence_accuracy_weights_all_ones(self): - layer = tl.SequenceAccuracy() - targets = np.array([[0, 1, 0, 1], - [1, 0, 1, 1]]) - weights = np.ones_like(targets) - - # Model gets both sequences right; for each position in each sequence, the - # category (integer ID) selected by argmax matches the target category. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.4, .6]], - [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.) - - # Model gets the first element of the first sequence barely wrong. - model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.4, .6]], - [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .5) - - # Model gets the last element of each sequence barely wrong. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.55, .45]], - [[.3, .7], [.8, .2], [.1, .9], [.52, .48]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 0.) - - def test_sequence_accuracy_last_position_zero_weight(self): - layer = tl.SequenceAccuracy() - targets = np.array([[0, 1, 0, 0], - [1, 0, 1, 0]]) - weights = np.array([[1., 1., 1., 0.], - [1., 1., 1., 0.]]) - - # Model gets both sequences right; output in final position would give - # wrong category but is ignored. - model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]], - [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 1.) - - # Model gets the first element of the first sequence barely wrong. 
- model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.6, .4]], - [[.3, .7], [.8, .2], [.1, .9], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, .5) - - # Model gets second-to-last element of each sequence barely wrong. - model_outputs = np.array([[[.9, .1], [.2, .8], [.48, .52], [.6, .4]], - [[.3, .7], [.8, .2], [.51, .49], [.6, .4]]]) - accuracy = layer([model_outputs, targets, weights]) - self.assertEqual(accuracy, 0.) - - def test_binary_cross_entropy_loss(self): - # TODO(jonni): Clarify desired semantics/naming, then test it. - layer = tl.BinaryCrossEntropyLoss() - xs = [np.ones((9, 1)), - np.ones((9, 1)), - np.ones((9, 1))] - y = layer(xs) - self.assertEqual(y.shape, ()) - - def test_cross_entropy_loss(self): - # TODO(jonni): Clarify desired semantics/naming, then test it. - layer = tl.CrossEntropyLoss() - xs = [np.ones((9, 4, 4, 20)), - np.ones((9, 4, 4)), - np.ones((9, 4, 4))] - y = layer(xs) - self.assertEqual(y.shape, ()) - - def test_l2_loss(self): - layer = tl.L2Loss() - - model_outputs = np.array([[1., 1.], [1., 1.]]) - targets = np.array([[1., 1.], [1., 0.]]) - weights = np.array([[1., 1.], [1., 0.]]) - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.0) - - weights = np.array([[1., 0.], [0., 1.]]) - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.5) - - def test_smooth_l1_loss(self): - layer = tl.SmoothL1Loss() - - model_outputs = np.array([[1., 1.], [1., 2.]]) - targets = np.array([[1., 1.], [1., 0.]]) - l1_dist = 2 - - weights = np.array([[1., 1.], [1., 0.]]) - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.0) - - weights = np.array([[1., 0.], [0., 1.]]) - sum_weights = 2 - - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, (l1_dist-0.5) / sum_weights) - - model_outputs = np.array([[1., 1.], [1., 1.5]]) - targets = np.array([[1., 1.], [1., 1.]]) - l1_dist = 0.5 - loss = layer([model_outputs, targets, weights]) - np.testing.assert_allclose(loss, 0.5 * l1_dist**2 / sum_weights) - - def test_macro_averaged_f_score(self): - # predictions = [1, 1, 2, 1, 1]. - model_outputs = np.array([[0, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 1, 0, 0], - [0, 1, 0, 0]]) - targets = np.array([1, 2, 2, 3, 1]) - # Category indices starting with `0`. - layer = tl.MacroAveragedFScore() - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .333, places=3) - # Excluding the padding index `0`. - layer = tl.MacroAveragedFScore(initial_category_index=1) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .444, places=3) - - def test_weighted_f_score(self): - # predictions = [1, 1, 2, 1, 1]. - model_outputs = np.array([[0, 1, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 1, 0, 0], - [0, 1, 0, 0]]) - targets = np.array([1, 2, 2, 3, 1]) - # Category indices starting with `0`. - layer = tl.WeightedFScore() - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .533, places=3) - # Excluding the padding index `0`. 
- layer = tl.WeightedFScore(initial_category_index=1) - loss = layer([model_outputs, targets]) - self.assertAlmostEqual(loss, .533, places=3) - - def test_names(self): - layer = tl.L2Loss() - self.assertEqual('L2Loss_in3', str(layer)) - layer = tl.Accuracy() - self.assertEqual('Accuracy_in3', str(layer)) - layer = tl.SequenceAccuracy() - self.assertEqual('SequenceAccuracy_in3', str(layer)) - layer = tl.BinaryCrossEntropyLoss() - self.assertEqual('BinaryCrossEntropyLoss_in3', str(layer)) - layer = tl.CrossEntropyLoss() - self.assertEqual('CrossEntropyLoss_in3', str(layer)) - layer = tl.BinaryCrossEntropySum() - self.assertEqual('BinaryCrossEntropySum_in3', str(layer)) - layer = tl.CrossEntropySum() - self.assertEqual('CrossEntropySum_in3', str(layer)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/normalization.py b/trax/layers/normalization.py index 8c3e5c138..ad2d91407 100644 --- a/trax/layers/normalization.py +++ b/trax/layers/normalization.py @@ -20,189 +20,200 @@ class BatchNorm(base.Layer): - """Layer that performs batch normalization. - - In training, batch normalization keeps smoothed cumulative statistics across - batches of input data and modifies each new batch so that its components are - normally distributed. In eval or inference, a `BatchNorm` instance uses its - stored mean and variance to approximately normalize each new batch of data. - - See https://arxiv.org/abs/1502.03167 for original presentation and motivation - of batch normalization). - """ - - def __init__(self, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True, - momentum=0.999, mode='train'): - super().__init__() - self._axis = axis - self._epsilon = epsilon - self._center = center - self._scale = scale - self._momentum = momentum - self._mode = mode - - def forward(self, x): - """Computes batch normalization as part of a forward pass in the model.""" - running_mean, running_var, n_batches = self.state - if self._mode == 'train': - n_batches += 1 - mean, var = self._fast_mean_and_variance(x) - # Gather smoothed input statistics for later use in evals or inference. - running_mean = _exponentially_smoothed(self._momentum, running_mean, mean) - running_var = _exponentially_smoothed(self._momentum, running_var, var) - self.state = (running_mean, running_var, n_batches) - else: - mean = running_mean - var = running_var - - z = self._z_score(x, mean, var) - beta, gamma = self._beta_gamma_with_correct_axes(x, self.weights) - - # Return the z rescaled by the parameters if requested. - if self._center and self._scale: - output = gamma * z + beta - elif self._center: - output = z + beta - elif self._scale: - output = gamma * z - else: - output = z - if output.dtype != x.dtype: - raise TypeError(f'The dtype of the output ({output.dtype}) of batch ' - f'norm is not the same as the input ({x.dtype}). ' - f'Batch norm should not change the dtype.') - return output - - def init_weights_and_state(self, input_signature): - """Helper to initialize batch norm weights and state.""" - axis = self._axis - axis = (axis,) if jnp.isscalar(axis) else axis - input_shape = input_signature.shape - shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) - # TODO(jonni): Should beta and gamma match the dtype in the input signature? 
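For context on BatchNorm's training-mode state update, a one-line numeric sketch; the momentum blend matches _exponentially_smoothed at the bottom of normalization.py, and the numbers mirror the momentum test in the normalization_test.py file deleted below:

momentum = 0.999
batch_mean = 11.5  # mean of np.arange(24), as in the deleted test
# One training batch blends old and new statistics:
running_mean = momentum * 0.0 + (1.0 - momentum) * batch_mean  # 0.0115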
- beta = jnp.zeros(shape, dtype='float32') if self._center else () - gamma = jnp.ones(shape, dtype='float32') if self._scale else () - def get_stats_axis(i, d): - if i in axis: - return 1 - else: - return d - stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape)) - running_mean = jnp.zeros(stats_shape, dtype=jnp.float32) - running_var = jnp.ones(stats_shape, dtype=jnp.float32) - n_batches = jnp.zeros((), dtype=jnp.int64) - self.weights = (beta, gamma) - self.state = (running_mean, running_var, n_batches) - - def _fast_mean_and_variance(self, x): - mean = jnp.mean(x, self._axis, keepdims=True) - # Fast but less numerically-stable variance calculation than jnp.var. - m1 = jnp.mean(x**2, self._axis, keepdims=True) - variance = m1 - mean**2 - return mean, variance - - def _z_score(self, x, mean, variance): - mu = mean.astype(x.dtype) - sigma = jnp.sqrt(variance + self._epsilon).astype(x.dtype) - return (x - mu) / sigma - - def _beta_gamma_with_correct_axes(self, x, weights): - # Expand the parameters to have the right axes. - beta, gamma = weights - # TODO(phawkins): jnp.expand_dims should accept an axis tuple. - # (https://github.com/numpy/numpy/issues/12290) - ed = tuple(None if i in self._axis else slice(None) - for i in range(jnp.ndim(x))) - beta = beta[ed] - gamma = gamma[ed] - return beta, gamma + """Layer that performs batch normalization. + + In training, batch normalization keeps smoothed cumulative statistics across + batches of input data and modifies each new batch so that its components are + normally distributed. In eval or inference, a `BatchNorm` instance uses its + stored mean and variance to approximately normalize each new batch of data. + + See https://arxiv.org/abs/1502.03167 for original presentation and motivation + of batch normalization. + """ + + def __init__( + self, + axis=(0, 1, 2), + epsilon=1e-5, + center=True, + scale=True, + momentum=0.999, + mode="train", + ): + super().__init__() + self._axis = axis + self._epsilon = epsilon + self._center = center + self._scale = scale + self._momentum = momentum + self._mode = mode + + def forward(self, x): + """Computes batch normalization as part of a forward pass in the model.""" + running_mean, running_var, n_batches = self.state + if self._mode == "train": + n_batches += 1 + mean, var = self._fast_mean_and_variance(x) + # Gather smoothed input statistics for later use in evals or inference. + running_mean = _exponentially_smoothed(self._momentum, running_mean, mean) + running_var = _exponentially_smoothed(self._momentum, running_var, var) + self.state = (running_mean, running_var, n_batches) + else: + mean = running_mean + var = running_var + + z = self._z_score(x, mean, var) + beta, gamma = self._beta_gamma_with_correct_axes(x, self.weights) + + # Return the z rescaled by the parameters if requested. + if self._center and self._scale: + output = gamma * z + beta + elif self._center: + output = z + beta + elif self._scale: + output = gamma * z + else: + output = z + if output.dtype != x.dtype: + raise TypeError( + f"The dtype of the output ({output.dtype}) of batch " + f"norm is not the same as the input ({x.dtype}). " + f"Batch norm should not change the dtype."
+ ) + return output + + def init_weights_and_state(self, input_signature): + """Helper to initialize batch norm weights and state.""" + axis = self._axis + axis = (axis,) if jnp.isscalar(axis) else axis + input_shape = input_signature.shape + shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) + # TODO(jonni): Should beta and gamma match the dtype in the input signature? + beta = jnp.zeros(shape, dtype="float32") if self._center else () + gamma = jnp.ones(shape, dtype="float32") if self._scale else () + + def get_stats_axis(i, d): + if i in axis: + return 1 + else: + return d + + stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape)) + running_mean = jnp.zeros(stats_shape, dtype=jnp.float32) + running_var = jnp.ones(stats_shape, dtype=jnp.float32) + n_batches = jnp.zeros((), dtype=jnp.int64) + self.weights = (beta, gamma) + self.state = (running_mean, running_var, n_batches) + + def _fast_mean_and_variance(self, x): + mean = jnp.mean(x, self._axis, keepdims=True) + # Fast but less numerically-stable variance calculation than jnp.var. + m1 = jnp.mean(x**2, self._axis, keepdims=True) + variance = m1 - mean**2 + return mean, variance + + def _z_score(self, x, mean, variance): + mu = mean.astype(x.dtype) + sigma = jnp.sqrt(variance + self._epsilon).astype(x.dtype) + return (x - mu) / sigma + + def _beta_gamma_with_correct_axes(self, x, weights): + # Expand the parameters to have the right axes. + beta, gamma = weights + # TODO(phawkins): jnp.expand_dims should accept an axis tuple. + # (https://github.com/numpy/numpy/issues/12290) + ed = tuple(None if i in self._axis else slice(None) for i in range(jnp.ndim(x))) + beta = beta[ed] + gamma = gamma[ed] + return beta, gamma class LayerNorm(base.Layer): - """Layer normalization.""" + """Layer normalization.""" - def __init__(self, center=True, epsilon=1e-6): - super().__init__() - self._epsilon = epsilon - self._center = center + def __init__(self, center=True, epsilon=1e-6): + super().__init__() + self._epsilon = epsilon + self._center = center - def forward(self, x): - scale, bias = self.weights - mean = jnp.mean(x, axis=-1, keepdims=True) - centered = x - mean if self._center else x - variance = jnp.mean(centered * centered, axis=-1, keepdims=True) - norm_inputs = centered / jnp.sqrt(variance + self._epsilon) - scaled = norm_inputs * scale - return scaled + bias if self._center else scaled + def forward(self, x): + scale, bias = self.weights + mean = jnp.mean(x, axis=-1, keepdims=True) + centered = x - mean if self._center else x + variance = jnp.mean(centered * centered, axis=-1, keepdims=True) + norm_inputs = centered / jnp.sqrt(variance + self._epsilon) + scaled = norm_inputs * scale + return scaled + bias if self._center else scaled - def init_weights_and_state(self, input_signature): - features = input_signature.shape[-1] - scale = jnp.ones(features, dtype=input_signature.dtype) - bias = jnp.zeros(features, dtype=input_signature.dtype) - self.weights = scale, bias + def init_weights_and_state(self, input_signature): + features = input_signature.shape[-1] + scale = jnp.ones(features, dtype=input_signature.dtype) + bias = jnp.zeros(features, dtype=input_signature.dtype) + self.weights = scale, bias class FilterResponseNorm(base.Layer): - """Filter Response Normalization layer without Threshold Linear Unit. + """Filter Response Normalization layer without Threshold Linear Unit. - c.f. https://arxiv.org/pdf/1911.09737.pdf - """ + c.f. 
https://arxiv.org/pdf/1911.09737.pdf + """ - def __init__(self, - mode=None, - learn_epsilon=False, - init_epsilon=1e-6, - init_learnt_epsilon=1e-4): - super().__init__() + def __init__( + self, + mode=None, + learn_epsilon=False, + init_epsilon=1e-6, + init_learnt_epsilon=1e-4, + ): + super().__init__() - del mode + del mode - # If we learn epsilon then epsilon = init_epsilon + |learnt_value| - # where learnt_value is initialized to init_learnt_epsilon. - # If learn_epsilon is false then epsilon is just init_epsilon. - # - # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`. - self._learn_epsilon = learn_epsilon + # If we learn epsilon then epsilon = init_epsilon + |learnt_value| + # where learnt_value is initialized to init_learnt_epsilon. + # If learn_epsilon is false then epsilon is just init_epsilon. + # + # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`. + self._learn_epsilon = learn_epsilon - # TODO(jonni): Replace asserts with ValueError. - assert init_epsilon > 0 - assert init_learnt_epsilon > 0 + # TODO(jonni): Replace asserts with ValueError. + assert init_epsilon > 0 + assert init_learnt_epsilon > 0 - self._init_epsilon = jnp.array(init_epsilon, dtype=jnp.float32) - self._init_learnt_epsilon = jnp.array(init_learnt_epsilon, - dtype=jnp.float32) + self._init_epsilon = jnp.array(init_epsilon, dtype=jnp.float32) + self._init_learnt_epsilon = jnp.array(init_learnt_epsilon, dtype=jnp.float32) - def forward(self, inputs): - gamma, beta, epsilon_l = self.weights + def forward(self, inputs): + gamma, beta, epsilon_l = self.weights - epsilon = self._init_epsilon - if epsilon_l is not base.EMPTY_WEIGHTS: - epsilon += jnp.abs(epsilon_l[0]) + epsilon = self._init_epsilon + if epsilon_l is not base.EMPTY_WEIGHTS: + epsilon += jnp.abs(epsilon_l[0]) - # Omit B and C - axis = tuple(range(1, len(jnp.shape(inputs)) - 1)) - # (B, 1, 1, C) - nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True) - # (B, W, H, C) - xhat = inputs / jnp.sqrt(nu2 + epsilon) + # Omit B and C + axis = tuple(range(1, len(jnp.shape(inputs)) - 1)) + # (B, 1, 1, C) + nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True) + # (B, W, H, C) + xhat = inputs / jnp.sqrt(nu2 + epsilon) - return gamma * xhat + beta + return gamma * xhat + beta - def init_weights_and_state(self, input_signature): - # Usually (B, W, H, C) - shape = input_signature.shape - num_channels = shape[-1] + def init_weights_and_state(self, input_signature): + # Usually (B, W, H, C) + shape = input_signature.shape + num_channels = shape[-1] - gamma = jnp.ones((num_channels,), dtype=jnp.float32) - beta = jnp.zeros((num_channels,), dtype=jnp.float32) + gamma = jnp.ones((num_channels,), dtype=jnp.float32) + beta = jnp.zeros((num_channels,), dtype=jnp.float32) - epsilon_l = base.EMPTY_WEIGHTS - if self._learn_epsilon: - epsilon_l = (self._init_learnt_epsilon,) + epsilon_l = base.EMPTY_WEIGHTS + if self._learn_epsilon: + epsilon_l = (self._init_learnt_epsilon,) - self.weights = gamma, beta, epsilon_l + self.weights = gamma, beta, epsilon_l def _exponentially_smoothed(momentum, old, new): - smoothed_value = momentum * old + (1 - momentum) * new - return smoothed_value.astype(old.dtype) + smoothed_value = momentum * old + (1 - momentum) * new + return smoothed_value.astype(old.dtype) diff --git a/trax/layers/normalization_test.py b/trax/layers/normalization_test.py deleted file mode 100644 index b844c0d1c..000000000 --- a/trax/layers/normalization_test.py +++ /dev/null @@ -1,130 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The 
Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for normalization layers.""" - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -import trax.layers as tl - - -class BatchNormTest(parameterized.TestCase): - - def test_forward_shape(self): - layer = tl.BatchNorm() - x = np.ones((30, 20, 70)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - @parameterized.named_parameters( - ('jax32', fastmath.Backend.JAX, np.float32), - ('tf32', fastmath.Backend.TFNP, np.float32), - ('tf64', fastmath.Backend.TFNP, np.float64), - ) - def test_forward_dtype(self, backend, dtype): - with fastmath.use_backend(backend): - layer = tl.BatchNorm() - x = np.ones((3, 2, 7)).astype(dtype) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.dtype, dtype) - - @parameterized.named_parameters( - ('momentum_999', .999), - ('momentum_900', .900), - ('momentum_800', .800), - ) - def test_forward(self, momentum): - layer = tl.BatchNorm(momentum=momentum) - x = np.array([[[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]], - [[12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23]]]).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - running_mean, running_var, n_batches = layer.state - - fraction_old = momentum - fraction_new = 1.0 - momentum - mean_of_x = 11.5 # mean of range(24) - var_of_x = 47.9167 # variance of range(24) - np.testing.assert_allclose( - running_mean, 0.0 * fraction_old + mean_of_x * fraction_new) - np.testing.assert_allclose( - running_var, 1.0 * fraction_old + var_of_x * fraction_new, rtol=1e-6) - self.assertEqual(n_batches, 1) - eps = 1e-5 - np.testing.assert_allclose( - y, (x - mean_of_x) / np.sqrt(var_of_x + eps), rtol=1e-6) - - def test_new_weights_and_state(self): - layer = tl.BatchNorm() - x = np.ones((3, 2, 7)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - - running_mean, running_var, n_batches = layer.state - np.testing.assert_allclose(running_mean, 0.0) - np.testing.assert_allclose(running_var, 1.0) - self.assertEqual(n_batches, 0) - - -class LayerNormTest(parameterized.TestCase): - - def test_forward_shape(self): - layer = tl.LayerNorm() - x = np.ones((3, 2, 7)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - @parameterized.named_parameters( - ('jax32', fastmath.Backend.JAX, np.float32), - ('tf32', fastmath.Backend.TFNP, np.float32), - ('tf64', fastmath.Backend.TFNP, np.float64), - ) - def test_forward_dtype(self, backend, dtype): - with fastmath.use_backend(backend): - layer = tl.LayerNorm() - x = np.ones((3, 2, 7)).astype(dtype) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.dtype, dtype) - - -class FilterResponseNormTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('learn_epsilon_false', False), - ('learn_epsilon_true', True), - 
) - def test_forward_shape(self, learn_epsilon): - layer = tl.FilterResponseNorm(learn_epsilon=learn_epsilon) - - B, H, W, C = 64, 5, 7, 3 # pylint: disable=invalid-name - x = np.ones((B, H, W, C)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/pooling.py b/trax/layers/pooling.py index f10241d3c..00e3d5abd 100644 --- a/trax/layers/pooling.py +++ b/trax/layers/pooling.py @@ -20,100 +20,109 @@ # pylint: disable=invalid-name -def MaxPool(pool_size=(2, 2), strides=None, padding='VALID'): - """Reduces each multi-dimensional window to the max of the window's values. - - Windows, as specified by `pool_size` and `strides`, involve all axes of an - n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` - from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. - - Args: - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only - full windows get reduced; partial windows are discarded. If 'SAME', - padding is added at array edges as needed to avoid partial windows - but does not otherwise affect the selection of max values. - - Returns: - N-dimensional array in which each valid (or padded-valid) window position - in the input is reduced to / replaced by the max value from that window. - An output array has the same number of dimensions as its input, but has - fewer elements. - """ - layer_name = f'MaxPool{pool_size}'.replace(' ', '') - def f(x): - return fastmath.max_pool( - x, pool_size=pool_size, strides=strides, padding=padding) - return Fn(layer_name, f) - - -def SumPool(pool_size=(2, 2), strides=None, padding='VALID'): - """Reduces each multi-dimensional window to the sum of the window's values. - - Windows, as specified by `pool_size` and `strides`, involve all axes of an - n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` - from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. - - Args: - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only - full windows get reduced; partial windows are discarded. If 'SAME', - padding is added at array edges as needed to avoid partial - windows but does not otherwise affect the computation of sums. - - Returns: - N-dimensional array in which each valid (or padded-valid) window position - in the input is reduced to / replaced by the sum of values in that window. - An output array has the same number of dimensions as its input, but has - fewer elements. 
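# Aside: a minimal numpy sketch (not part of the patch) of the running-stat
# update asserted in the BatchNorm momentum test above, assuming exponential
# averaging from the initial state (running_mean=0, running_var=1).
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
momentum = 0.999
mean_of_x = x.mean()                     # 11.5, mean of range(24)
var_of_x = x.var()                       # ~47.9167, variance of range(24)
running_mean = 0.0 * momentum + mean_of_x * (1.0 - momentum)
running_var = 1.0 * momentum + var_of_x * (1.0 - momentum)
y = (x - mean_of_x) / np.sqrt(var_of_x + 1e-5)   # normalized output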
- """ - layer_name = f'SumPool{pool_size}'.replace(' ', '') - def f(x): - return fastmath.sum_pool( - x, pool_size=pool_size, strides=strides, padding=padding) - return Fn(layer_name, f) - - -def AvgPool(pool_size=(2, 2), strides=None, padding='VALID'): - """Reduces each multi-dimensional window to the mean of the window's values. - - Windows, as specified by `pool_size` and `strides`, involve all axes of an - n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` - from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. - - Args: - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only - full windows get reduced; partial windows are discarded. If 'SAME', - padding is added at array edges as needed but is not counted in the - computation of averages. - - Returns: - N-dimensional array in which each valid (or padded-valid) window position - in the input is reduced to / replaced by the mean of values in that window. - An output array has the same number of dimensions as its input, but has - fewer elements. - """ - layer_name = f'AvgPool{pool_size}'.replace(' ', '') - def f(x): - return fastmath.avg_pool( - x, pool_size=pool_size, strides=strides, padding=padding) - return Fn(layer_name, f) +def MaxPool(pool_size=(2, 2), strides=None, padding="VALID"): + """Reduces each multi-dimensional window to the max of the window's values. + + Windows, as specified by `pool_size` and `strides`, involve all axes of an + n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` + from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. + + Args: + pool_size: Shape of window that gets reduced to a single vector value. + If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` + must be a tuple of length :math:`n-2`. + strides: Offsets from the location of one window to the locations of + neighboring windows along each axis. If specified, must be a tuple of + the same length as `pool_size`. If None, then offsets of 1 along each + window axis, :math:`(1, ..., 1)`, will be used. + padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only + full windows get reduced; partial windows are discarded. If 'SAME', + padding is added at array edges as needed to avoid partial windows + but does not otherwise affect the selection of max values. + + Returns: + N-dimensional array in which each valid (or padded-valid) window position + in the input is reduced to / replaced by the max value from that window. + An output array has the same number of dimensions as its input, but has + fewer elements. + """ + layer_name = f"MaxPool{pool_size}".replace(" ", "") + + def f(x): + return fastmath.max_pool( + x, pool_size=pool_size, strides=strides, padding=padding + ) + + return Fn(layer_name, f) + + +def SumPool(pool_size=(2, 2), strides=None, padding="VALID"): + """Reduces each multi-dimensional window to the sum of the window's values. 
+ + Windows, as specified by `pool_size` and `strides`, involve all axes of an + n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` + from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. + + Args: + pool_size: Shape of window that gets reduced to a single vector value. + If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` + must be a tuple of length :math:`n-2`. + strides: Offsets from the location of one window to the locations of + neighboring windows along each axis. If specified, must be a tuple of + the same length as `pool_size`. If None, then offsets of 1 along each + window axis, :math:`(1, ..., 1)`, will be used. + padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only + full windows get reduced; partial windows are discarded. If 'SAME', + padding is added at array edges as needed to avoid partial + windows but does not otherwise affect the computation of sums. + + Returns: + N-dimensional array in which each valid (or padded-valid) window position + in the input is reduced to / replaced by the sum of values in that window. + An output array has the same number of dimensions as its input, but has + fewer elements. + """ + layer_name = f"SumPool{pool_size}".replace(" ", "") + + def f(x): + return fastmath.sum_pool( + x, pool_size=pool_size, strides=strides, padding=padding + ) + + return Fn(layer_name, f) + + +def AvgPool(pool_size=(2, 2), strides=None, padding="VALID"): + """Reduces each multi-dimensional window to the mean of the window's values. + + Windows, as specified by `pool_size` and `strides`, involve all axes of an + n-dimensional array except the first and last: :math:`(d_1, ..., d_{n-2})` + from shape :math:`(d_0, d_1, ..., d_{n-2}, d_{n-1})`. + + Args: + pool_size: Shape of window that gets reduced to a single vector value. + If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` + must be a tuple of length :math:`n-2`. + strides: Offsets from the location of one window to the locations of + neighboring windows along each axis. If specified, must be a tuple of + the same length as `pool_size`. If None, then offsets of 1 along each + window axis, :math:`(1, ..., 1)`, will be used. + padding: 'VALID' or 'SAME'. If 'VALID', no padding is done, and only + full windows get reduced; partial windows are discarded. If 'SAME', + padding is added at array edges as needed but is not counted in the + computation of averages. + + Returns: + N-dimensional array in which each valid (or padded-valid) window position + in the input is reduced to / replaced by the mean of values in that window. + An output array has the same number of dimensions as its input, but has + fewer elements. + """ + layer_name = f"AvgPool{pool_size}".replace(" ", "") + + def f(x): + return fastmath.avg_pool( + x, pool_size=pool_size, strides=strides, padding=padding + ) + + return Fn(layer_name, f) diff --git a/trax/layers/pooling_test.py b/trax/layers/pooling_test.py deleted file mode 100644 index 7c858cd8b..000000000 --- a/trax/layers/pooling_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
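# Aside: a shape sketch (not from the patch) for the 'VALID' padding rule that
# the pooling docstrings above describe; the values mirror the shape test in
# the pooling_test.py file being deleted below. Assumes trax is installed.
import numpy as np
import trax.layers as tl

layer = tl.MaxPool(pool_size=(2, 2), strides=(1, 2))
x = np.ones((11, 6, 4, 17))          # (batch, height, width, channels)
y = layer(x)
# VALID: out = (in - pool) // stride + 1 -> (6-2)//1+1 = 5 and (4-2)//2+1 = 2
assert y.shape == (11, 5, 2, 17)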
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for conv layers.""" - -from absl.testing import absltest -import numpy as np - -import trax.layers as tl - - -class MaxPoolTest(absltest.TestCase): - - def test_forward_shape(self): - layer = tl.MaxPool(pool_size=(2, 2), strides=(1, 2)) - x = np.ones((11, 6, 4, 17)) - y = layer(x) - self.assertEqual(y.shape, (11, 5, 2, 17)) - - def test_forward(self): - layer = tl.MaxPool(pool_size=(2, 2), strides=(2, 2)) - x = np.array([[ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[[7, 5, 6], [70, 50, 60]]]]) - - def test_padding_default(self): - layer = tl.MaxPool(pool_size=(3,), strides=(3,)) - - # Discard incomplete window at end: [[3, 6], [4, 5]]. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[2, 9]]]) - - def test_padding_same(self): - layer = tl.MaxPool(pool_size=(3,), strides=(3,), padding='SAME') - - # One padding position needed; add at end. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[2, 9], [4, 6]]]) - - # Two padding positions needed; add one at end and one at start. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[1, 9], [3, 7]]]) - - -class SumPoolTest(absltest.TestCase): - - def test_forward_shape(self): - layer = tl.SumPool(pool_size=(2, 2), strides=(1, 2)) - x = np.ones((11, 6, 4, 17)) - y = layer(x) - self.assertEqual(y.shape, (11, 5, 2, 17)) - - def test_forward(self): - layer = tl.SumPool(pool_size=(2, 2), strides=(2, 2)) - x = np.array([[ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[[16, 10, 14], [160, 100, 140]]]]) - - def test_padding_same(self): - layer = tl.SumPool(pool_size=(3,), strides=(3,), padding='SAME') - - # One padding position needed; add at end. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[3, 24], [7, 11]]]) - - # Two padding positions needed; add one at end and one at start. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[1, 17], [5, 13]]]) - - -class AvgPoolTest(absltest.TestCase): - - def test_forward_shape(self): - layer = tl.AvgPool(pool_size=(2, 2), strides=(1, 2)) - x = np.ones((11, 6, 4, 17)) - y = layer(x) - self.assertEqual(y.shape, (11, 5, 2, 17)) - - def test_forward(self): - layer = tl.AvgPool(pool_size=(2, 2), strides=(2, 2)) - x = np.array([[ - [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]], - [[4, 2, 3], [7, 1, 2], [40, 20, 30], [70, 10, 20]], - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[[4.0, 2.5, 3.5], [40, 25, 35]]]]) - - def test_padding_same(self): - layer = tl.AvgPool(pool_size=(3,), strides=(3,), padding='SAME') - - # One padding position needed; add at end. 
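# Aside: plain arithmetic (not from the patch) behind the 'SAME' padding
# comments in the tests above: with pool_size=3 and strides=3, a length-5 axis
# needs one pad position (added at the end), and a length-4 axis needs two
# (one at each edge).
seq_len, pool, stride = 5, 3, 3
n_windows = -(-seq_len // stride)                               # ceil -> 2
pad_total = max((n_windows - 1) * stride + pool - seq_len, 0)   # -> 1
pad_start, pad_end = pad_total // 2, pad_total - pad_total // 2  # 0 and 1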
- x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6], [4, 5] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[1, 8], [3.5, 5.5]]]) - - # Two padding positions needed; add one at end and one at start. - x = np.array([[ - [0, 9], [1, 8], [2, 7], [3, 6] - ]]) - y = layer(x) - self.assertEqual(tl.to_list(y), [[[.5, 8.5], [2.5, 6.5]]]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/research/__init__.py b/trax/layers/research/__init__.py index a4ee92161..dfc0f987b 100644 --- a/trax/layers/research/__init__.py +++ b/trax/layers/research/__init__.py @@ -13,3 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Research layers.""" + +from . import flash_attention # noqa: F401 diff --git a/trax/layers/research/efficient_attention.py b/trax/layers/research/efficient_attention.py index ecc52e523..4f3d19e98 100644 --- a/trax/layers/research/efficient_attention.py +++ b/trax/layers/research/efficient_attention.py @@ -37,3232 +37,3681 @@ import math import jax + from trax import fastmath from trax.fastmath import numpy as np -from trax.layers import attention -from trax.layers import base +from trax.layers import attention, base, core from trax.layers import combinators as cb -from trax.layers import core from trax.layers import initializers as init from trax.layers import sparsity as sp from trax.layers.research import rotary_positional_embedding as rotary_pe - ####################################################### Functions def length_normalized(x, epsilon=1e-6): - variance = np.mean(x**2, axis=-1, keepdims=True) - norm_inputs = x / np.sqrt(variance + epsilon) - return norm_inputs + variance = np.mean(x**2, axis=-1, keepdims=True) + norm_inputs = x / np.sqrt(variance + epsilon) + return norm_inputs def hash_vecs(vecs, n_buckets_in, n_hashes, rng): - """Hash vectors into buckets. - - Args: - vecs: vectors to hash, a tensor of shape [batch_size, depth] - n_buckets_in: an int or a list of ints, number of hash buckets; - if it is a list, we do hierarchical hashing as specified by the list - n_hashes: number of hashes - rng: random generator to use for hashing - - Returns: - A pair (buckets, n_buckets) where buckets is a tensor of shape - [n_hashes, batch_size] of integers -- the hash bucket IDs, and - n_buckets is an int, the total number of hash buckets, equal to - the product of all items in n_buckets_in. - """ - # See https://arxiv.org/pdf/1509.02897.pdf - # We sample a different random rotation for each round of hashing to - # decrease the probability of hash misses. 
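# Aside: length_normalized above is RMS normalization over the last axis; a
# tiny worked example (not from the patch).
import numpy as np

x = np.array([[3.0, 4.0]])
variance = np.mean(x**2, axis=-1, keepdims=True)   # (9 + 16) / 2 = 12.5
out = x / np.sqrt(variance + 1e-6)                 # RMS of out is ~1.0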
- if isinstance(n_buckets_in, int): - assert n_buckets_in % 2 == 0 - rot_size = n_buckets_in - n_buckets = n_buckets_in - else: - # Factorize the hash if n_buckets_in is a list or tuple - rot_size, n_buckets = 0, 1 - for factor in n_buckets_in: - assert factor % 2 == 0 - rot_size += factor - n_buckets *= factor - - rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2) - random_rotations = fastmath.random.normal(rng, rotations_shape).astype( - np.float32) - if fastmath.is_backend(fastmath.Backend.JAX): - rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations) - else: - random_rotations = np.reshape(random_rotations, - [-1, n_hashes * (rot_size // 2)]) - rotated_vecs = np.dot(vecs, random_rotations) - rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size//2]) - rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2)) - - if isinstance(n_buckets_in, int) or len(n_buckets_in) == 1: - rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1) - buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32) - else: - # Get the buckets for them and combine. - buckets, cur_sum, cur_product = None, 0, 1 - for factor in n_buckets_in: - rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)] - cur_sum += factor // 2 - rv = np.concatenate([rv, -rv], axis=-1) - if buckets is None: - buckets = np.argmax(rv, axis=-1).astype(np.int32) - else: - buckets += cur_product * np.argmax(rv, axis=-1).astype(np.int32) - cur_product *= factor - - return buckets, n_buckets # buckets is now (n_hashes, batch_size) + """Hash vectors into buckets. + Args: + vecs: vectors to hash, a tensor of shape [batch_size, depth] + n_buckets_in: an int or a list of ints, number of hash buckets; + if it is a list, we do hierarchical hashing as specified by the list + n_hashes: number of hashes + rng: random generator to use for hashing -def look_adjacent(x, n_chunks_before, n_chunks_after): - """Used to implement attention between consecutive chunks. - - Args: - x: array of shape [n_chunks, chunk_len, ...] - n_chunks_before: Number of previous chunks to attend to. - n_chunks_after: Number of subsequent chunks to attend to. - Returns: - array of shape [n_chunks, N * chunk_len, ...], where - N = (1 + n_chunks_before + n_chunks_after). - """ - if n_chunks_before == 0 and n_chunks_after == 0: - return x - - slices = [] - for i in range(-n_chunks_before, n_chunks_after + 1): - if i == 0: - slices.append(x) + Returns: + A pair (buckets, n_buckets) where buckets is a tensor of shape + [n_hashes, batch_size] of integers -- the hash bucket IDs, and + n_buckets is an int, the total number of hash buckets, equal to + the product of all items in n_buckets_in. + """ + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. 
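# Aside: the single-factor LSH case in plain numpy (illustrative, not from the
# patch); bucket ids come from argmax over [rotated, -rotated], as in
# hash_vecs below.
import numpy as np

rng = np.random.default_rng(0)
vecs = rng.normal(size=(8, 16)).astype(np.float32)      # [batch, depth]
n_buckets, n_hashes = 4, 2                              # n_buckets must be even
rotations = rng.normal(size=(16, n_hashes, n_buckets // 2)).astype(np.float32)
rotated = np.einsum("tf,fhb->htb", vecs, rotations)     # [n_hashes, batch, 2]
rotated = np.concatenate([rotated, -rotated], axis=-1)  # [n_hashes, batch, 4]
buckets = np.argmax(rotated, axis=-1).astype(np.int32)  # one id per hash round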
+ if isinstance(n_buckets_in, int): + assert n_buckets_in % 2 == 0 + rot_size = n_buckets_in + n_buckets = n_buckets_in else: - slices.append(np.concatenate([x[i:, ...], x[:i, ...]], axis=0)) - return np.concatenate(slices, axis=1) + # Factorize the hash if n_buckets_in is a list or tuple + rot_size, n_buckets = 0, 1 + for factor in n_buckets_in: + assert factor % 2 == 0 + rot_size += factor + n_buckets *= factor + + rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2) + random_rotations = fastmath.random.normal(rng, rotations_shape).astype(np.float32) + if fastmath.is_backend(fastmath.Backend.JAX): + rotated_vecs = np.einsum("tf,fhb->htb", vecs, random_rotations) + else: + random_rotations = np.reshape( + random_rotations, [-1, n_hashes * (rot_size // 2)] + ) + rotated_vecs = np.dot(vecs, random_rotations) + rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size // 2]) + rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2)) + + if isinstance(n_buckets_in, int) or len(n_buckets_in) == 1: + rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1) + buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32) + else: + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for factor in n_buckets_in: + rv = rotated_vecs[..., cur_sum : cur_sum + (factor // 2)] + cur_sum += factor // 2 + rv = np.concatenate([rv, -rv], axis=-1) + if buckets is None: + buckets = np.argmax(rv, axis=-1).astype(np.int32) + else: + buckets += cur_product * np.argmax(rv, axis=-1).astype(np.int32) + cur_product *= factor + + return buckets, n_buckets # buckets is now (n_hashes, batch_size) + + +def look_adjacent(x, n_chunks_before, n_chunks_after): + """Used to implement attention between consecutive chunks. + + Args: + x: array of shape [n_chunks, chunk_len, ...] + n_chunks_before: Number of previous chunks to attend to. + n_chunks_after: Number of subsequent chunks to attend to. + Returns: + array of shape [n_chunks, N * chunk_len, ...], where + N = (1 + n_chunks_before + n_chunks_after). 
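# Aside: how the factorized branch above combines per-factor bucket ids, like
# mixed-radix digits (illustrative numbers, not from the patch).
factors = [4, 8]                  # n_buckets_in as a list -> 4 * 8 = 32 buckets
digit0, digit1 = 3, 5             # per-factor argmax results
bucket = digit0 * 1 + digit1 * 4  # cur_product starts at 1, then grows to 4
assert bucket == 23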
+ """ + if n_chunks_before == 0 and n_chunks_after == 0: + return x + + slices = [] + for i in range(-n_chunks_before, n_chunks_after + 1): + if i == 0: + slices.append(x) + else: + slices.append(np.concatenate([x[i:, ...], x[:i, ...]], axis=0)) + return np.concatenate(slices, axis=1) def mask_self_attention( - dots, q_info, kv_info, causal=True, exclude_self=True, masked=False): - """Performs masking for self-attention.""" - q_info = q_info.astype(np.float32) - kv_info = kv_info.astype(np.float32) - if causal: - mask = fastmath.lt(q_info, kv_info) - dots = dots - 1e9 * mask - if exclude_self: - mask = np.equal(q_info, kv_info) - dots = dots - 1e5 * mask - if masked: - zeros_like_kv_info = np.zeros_like(kv_info) - mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32) - dots = dots - 1e9 * mask - return dots + dots, q_info, kv_info, causal=True, exclude_self=True, masked=False +): + """Performs masking for self-attention.""" + q_info = q_info.astype(np.float32) + kv_info = kv_info.astype(np.float32) + if causal: + mask = fastmath.lt(q_info, kv_info) + dots = dots - 1e9 * mask + if exclude_self: + mask = np.equal(q_info, kv_info) + dots = dots - 1e5 * mask + if masked: + zeros_like_kv_info = np.zeros_like(kv_info) + mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32) + dots = dots - 1e9 * mask + return dots def attend( - q, k=None, v=None, - q_chunk_len=None, kv_chunk_len=None, - n_chunks_before=0, n_chunks_after=0, - mask_fn=None, q_info=None, kv_info=None, - dropout=0.0, rng=None, - ): - """Dot-product attention, with optional chunking and/or masking. - - Args: - q: Query vectors, shape [q_len, d_qk] - k: Key vectors, shape [kv_len, d_qk]; or None - v: Value vectors, shape [kv_len, d_v] - q_chunk_len: Set to non-zero to enable chunking for query vectors - kv_chunk_len: Set to non-zero to enable chunking for key/value vectors - n_chunks_before: Number of adjacent previous chunks to attend to - n_chunks_after: Number of adjacent subsequent chunks to attend to - mask_fn: TODO(kitaev) doc - q_info: Query-associated metadata for masking - kv_info: Key-associated metadata for masking - dropout: Dropout rate - rng: RNG for dropout - - Returns: - A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and - dots_logsumexp has shape [q_len]. The logsumexp of the attention - probabilities is useful for combining multiple rounds of attention (as in - LSH attention). - """ - assert v is not None - share_qk = (k is None) - - # `q_info` and `kv_info` if supplied are 0 indexed, we want them to be 1 - # indexed instead so that we can mask position 0 as well - see Github #820 - - if q_info is None: - q_info = np.arange(1, q.shape[-2] + 1, dtype=np.int32) - else: - q_info += 1 - - if kv_info is None and not share_qk: - kv_info = np.arange(1, v.shape[-2] + 1, dtype=np.int32) - elif kv_info is not None: - kv_info += 1 - - # Split q/k/v into chunks along the time axis, if desired. - if q_chunk_len is not None: - q = np.reshape(q, (-1, q_chunk_len, q.shape[-1])) - q_info = np.reshape(q_info, (-1, q_chunk_len)) - - if share_qk: - assert kv_chunk_len is None or kv_chunk_len == q_chunk_len - k = q - kv_chunk_len = q_chunk_len - if kv_info is None: - kv_info = q_info - elif kv_chunk_len is not None: - # kv_info is not None, but reshape as required. 
- kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) - elif kv_chunk_len is not None: - k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1])) - kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) - - if kv_chunk_len is not None: - v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1])) - - if share_qk: - k = length_normalized(k) - k = k / np.sqrt(k.shape[-1]) - - # Optionally include adjacent chunks. - if q_chunk_len is not None or kv_chunk_len is not None: - assert q_chunk_len is not None and kv_chunk_len is not None - else: - assert n_chunks_before == 0 and n_chunks_after == 0 - - k = look_adjacent(k, n_chunks_before, n_chunks_after) - v = look_adjacent(v, n_chunks_before, n_chunks_after) - kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after) - - # Dot-product attention. - dots = np.matmul(q, np.swapaxes(k, -1, -2)) - - # Masking - if mask_fn is not None: - dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :]) - - # Softmax. - dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True) - dots = np.exp(dots - dots_logsumexp) - - if dropout > 0.0: - assert rng is not None - # Dropout is broadcast across the bin dimension - dropout_shape = (dots.shape[-2], dots.shape[-1]) - # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix) - keep_prob = 1.0 - dropout - keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape) - multiplier = keep.astype(dots.dtype) / keep_prob - dots = dots * multiplier - - # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn. - out = np.matmul(dots, v) - out = np.reshape(out, (-1, out.shape[-1])) - dots_logsumexp = np.reshape(dots_logsumexp, (-1,)) - return out, dots_logsumexp + q, + k=None, + v=None, + q_chunk_len=None, + kv_chunk_len=None, + n_chunks_before=0, + n_chunks_after=0, + mask_fn=None, + q_info=None, + kv_info=None, + dropout=0.0, + rng=None, +): + """Dot-product attention, with optional chunking and/or masking. + Args: + q: Query vectors, shape [q_len, d_qk] + k: Key vectors, shape [kv_len, d_qk]; or None + v: Value vectors, shape [kv_len, d_v] + q_chunk_len: Set to non-zero to enable chunking for query vectors + kv_chunk_len: Set to non-zero to enable chunking for key/value vectors + n_chunks_before: Number of adjacent previous chunks to attend to + n_chunks_after: Number of adjacent subsequent chunks to attend to + mask_fn: TODO(kitaev) doc + q_info: Query-associated metadata for masking + kv_info: Key-associated metadata for masking + dropout: Dropout rate + rng: RNG for dropout -def apply_broadcasted_dropout(vecs, dropout_rate, rng): - """Apply dropout, broadcasted across all but the last dimension of `vecs`.""" - if dropout_rate > 0.0: - assert rng is not None - keep_prob = 1.0 - dropout_rate - keep = fastmath.random.bernoulli(rng, keep_prob, (vecs.shape[-1],)) - multiplier = keep.astype(vecs.dtype) / keep_prob - return vecs * multiplier - else: - return vecs + Returns: + A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and + dots_logsumexp has shape [q_len]. The logsumexp of the attention + probabilities is useful for combining multiple rounds of attention (as in + LSH attention). + """ + assert v is not None + share_qk = k is None + # `q_info` and `kv_info` if supplied are 0 indexed, we want them to be 1 + # indexed instead so that we can mask position 0 as well - see Github #820 -# The new implementations below don't use custom_transforms in JAX but -# do cause Tracer errors, so we don't use them for now. 
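# Aside: in attend, masking is additive (a large negative logit) and softmax is
# taken via logsumexp so multiple LSH rounds can later be recombined with the
# returned normalizers; small numpy sketch (not from the patch).
import numpy as np

dots = np.array([[0.5, 0.1, 0.9]])
forbidden = np.array([[0.0, 0.0, 1.0]])         # 1 marks masked positions
dots = dots - 1e9 * forbidden
m = dots.max(axis=-1, keepdims=True)            # stabilize before exp
lse = m + np.log(np.exp(dots - m).sum(axis=-1, keepdims=True))
probs = np.exp(dots - lse)                      # masked entry becomes ~0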
+ if q_info is None: + q_info = np.arange(1, q.shape[-2] + 1, dtype=np.int32) + else: + q_info += 1 + + if kv_info is None and not share_qk: + kv_info = np.arange(1, v.shape[-2] + 1, dtype=np.int32) + elif kv_info is not None: + kv_info += 1 + + # Split q/k/v into chunks along the time axis, if desired. + if q_chunk_len is not None: + q = np.reshape(q, (-1, q_chunk_len, q.shape[-1])) + q_info = np.reshape(q_info, (-1, q_chunk_len)) + + if share_qk: + assert kv_chunk_len is None or kv_chunk_len == q_chunk_len + k = q + kv_chunk_len = q_chunk_len + if kv_info is None: + kv_info = q_info + elif kv_chunk_len is not None: + # kv_info is not None, but reshape as required. + kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) + elif kv_chunk_len is not None: + k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1])) + kv_info = np.reshape(kv_info, (-1, kv_chunk_len)) + if kv_chunk_len is not None: + v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1])) -def permute_via_gather(val, permutation, inverse_permutation, axis=0): - """Permutation helper for LSH attention.""" - def permute_impl(p, unused_ip, val): - return np.take(val, p, axis=axis) - def permute_fwd(p, ip, val): - return np.take(val, p, axis=axis), ip - def permute_bwd(ip, permuted_grad): - # JAX autodiff would synthesize a scatter operation because it doesn't - # know that the indices are a permutation. However on TPU, gathers are - # faster than scatters (at least in the regime the LSH attention uses). - return (None, None, np.take(permuted_grad, ip, axis=axis)) - permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) - return permute(permutation, inverse_permutation, val) + if share_qk: + k = length_normalized(k) + k = k / np.sqrt(k.shape[-1]) + # Optionally include adjacent chunks. + if q_chunk_len is not None or kv_chunk_len is not None: + assert q_chunk_len is not None and kv_chunk_len is not None + else: + assert n_chunks_before == 0 and n_chunks_after == 0 -def permute_via_sort(val, keys, inverse_keys, axis=0): - """Permutation helper for LSH attention.""" - def permute_impl(k, unused_ik, val): - # On TPU, sorting scalars by key is faster than a gather. - _, permuted = fastmath.sort_key_val(k, val, dimension=axis) - return permuted - def permute_fwd(k, ik, val): - # On TPU, sorting scalars by key is faster than a gather. - _, permuted = fastmath.sort_key_val(k, val, dimension=axis) - return permuted, ik - def permute_bwd(ik, permuted_grad): - _, val_grad = fastmath.sort_key_val( - ik, permuted_grad, dimension=axis) - return (None, None, val_grad) - permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) - return permute(keys, inverse_keys, val) + k = look_adjacent(k, n_chunks_before, n_chunks_after) + v = look_adjacent(v, n_chunks_before, n_chunks_after) + kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after) + # Dot-product attention. + dots = np.matmul(q, np.swapaxes(k, -1, -2)) -####################################################### Classes + # Masking + if mask_fn is not None: + dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :]) + # Softmax. + dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True) + dots = np.exp(dots - dots_logsumexp) -class EfficientAttentionBase(base.Layer): - """Base class for efficient attention. 
+ if dropout > 0.0: + assert rng is not None + # Dropout is broadcast across the bin dimension + dropout_shape = (dots.shape[-2], dots.shape[-1]) + # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix) + keep_prob = 1.0 - dropout + keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape) + multiplier = keep.astype(dots.dtype) / keep_prob + dots = dots * multiplier - This is a base class that implements memory-efficient batching for both the - forward and backward passes. Subclasses should override - `create_weights_unbatched`, `create_state_unbatched`, `forward_unbatched`, and - optionally `incremental_forward_unbatched` to define the actual attention - mechanism. - """ + # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn. + out = np.matmul(dots, v) + out = np.reshape(out, (-1, out.shape[-1])) + dots_logsumexp = np.reshape(dots_logsumexp, (-1,)) + return out, dots_logsumexp - def __init__(self, n_heads, n_in=1, n_parallel_heads=None, - incremental=False, predict_mem_len=None, predict_drop_len=None, - use_python_loop=False, use_reference_code=False): - """Constructs an EfficientAttentionBase instance. - Args: - n_heads: Number of attention heads. - n_in: Number of inputs to the layer (default 1). - n_parallel_heads: Number of attention heads to compute in parallel. - - - If `n_parallel_heads` is None (default), the entire layer is - computed with maximum parallelism. This mode is the fastest, but - also uses the most memory. Start with this mode, but switch to one - of the others if memory runs out. - - If `n_parallel_heads` is 1, attention is computed one head at a - time, and one example at a time. This mode uses the least memory - but is not as fast as batched attention. Use this mode when working - with very long sequences, such that any amount of parallelism won't - fit in memory. - - If `n_parallel_heads` is a multiple of `n_heads`, attention is - computed for sub-batches of (`n_parallel_heads // n_heads`) - examples at a time. - - If `1 < n_parallel_heads < n_heads`, attention is computed for - several heads at a time, but only within a single example. It must - be the case that `n_heads` is a multiple of `n_parallel_heads`. Use - this mode for long sequences, to strike a balance between - parallelism and memory usage. - incremental: If `True`, enable fast inference for self-attention types. - Note that this flag should *not* be set when doing encoder-decoder - attention, but only when doing self-attention. - predict_mem_len: Number of input positions to remember in a cache - when doing fast inference. Whenever the cache fills up, some input - elements will be forgotten. - predict_drop_len: Number of input elements to drop once the fast - inference input cache fills up. - use_python_loop: Set to True to use a Python loop when iterating over - sub-batches of examples/heads (as opposed to a JAX/XLA loop). - This option will increase compilation time and jitted code size, - potentially drastically. Using it is not recommended except for - testing/debugging. In particular, note that enabling this option on - TPU can decrease the maximum model size that will fit in memory. - use_reference_code: Set to True to fall back to the reference - implementation of batched attention. This option will increase - compilation time and jitted code size, potentially drastically. Using - it is not recommended except for testing/debugging. 
- """ - super().__init__(n_in=n_in, n_out=1) - self._n_heads = n_heads - self._incremental = incremental - if self._incremental: - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads +def apply_broadcasted_dropout(vecs, dropout_rate, rng): + """Apply dropout, broadcasted across all but the last dimension of `vecs`.""" + if dropout_rate > 0.0: + assert rng is not None + keep_prob = 1.0 - dropout_rate + keep = fastmath.random.bernoulli(rng, keep_prob, (vecs.shape[-1],)) + multiplier = keep.astype(vecs.dtype) / keep_prob + return vecs * multiplier else: - self._n_parallel_heads = None - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - def init_weights_and_state(self, input_signature): - if not isinstance(input_signature, (tuple, list)): - input_signature = (input_signature,) - input_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - input_signature) - batch_size = int(input_signature[0].shape[0]) - - weights = [] - weight_rngs = fastmath.random.split(self.rng, self._n_heads) - for i in range(self._n_heads): - weights.append(self.create_weights_unbatched(input_signature_unbatched, - weight_rngs[i])) - state = [] - state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) - for i in range(self._n_heads * batch_size): - state.append(self.create_state_unbatched(input_signature_unbatched, - state_rngs[i])) - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - input_signature) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = tuple(weights) - - def create_weights_unbatched(self, input_signature, rng): - raise NotImplementedError( - 'Subclasses should override create_weights_unbatched') + return vecs - def create_state_unbatched(self, input_signature, rng): - return () - def forward_unbatched(self, *inputs, weights, state): - """Perform attention for a single batch element and head. +# The new implementations below don't use custom_transforms in JAX but +# do cause Tracer errors, so we don't use them for now. - Subclasses should override this method. - Args: - *inputs: Inputs for a single example (subclasses may use different inputs) - weights: Weights for a single attention head - state: State for a single example & attention head pair. +def permute_via_gather(val, permutation, inverse_permutation, axis=0): + """Permutation helper for LSH attention.""" - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. 
- """ - raise NotImplementedError('Subclasses should override forward_unbatched') + def permute_impl(p, unused_ip, val): + return np.take(val, p, axis=axis) - def _incremental_forward_unbatched(self, *inputs, q_start, q_len, - weights, state): - """Perform fast inference for a single batch element and head. + def permute_fwd(p, ip, val): + return np.take(val, p, axis=axis), ip - Subclasses should override this method. + def permute_bwd(ip, permuted_grad): + # JAX autodiff would synthesize a scatter operation because it doesn't + # know that the indices are a permutation. However on TPU, gathers are + # faster than scatters (at least in the regime the LSH attention uses). + return (None, None, np.take(permuted_grad, ip, axis=axis)) - Args: - *inputs: Inputs for a single example (subclasses may use different inputs) - q_start: Index along the sequence-length dimension that points to the - first input element that should be used as a query (and not just a key). - q_len: Number of new query elements in this call to the attention - mechanism. This is typically 1 for autoregressive decoding, but may be - longer if initializing a language model with a prefix. - weights: Weights for a single attention head - state: State for a single example & attention head pair. - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. - """ - raise NotImplementedError( - 'Fast inference is not implemented for this attention type.') + permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) + return permute(permutation, inverse_permutation, val) - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. - Args: - inputs: Layer inputs (subclasses may use different inputs) +def permute_via_sort(val, keys, inverse_keys, axis=0): + """Permutation helper for LSH attention.""" - Returns: - A tuple (output, new_state). + def permute_impl(k, unused_ik, val): + # On TPU, sorting scalars by key is faster than a gather. + _, permuted = fastmath.sort_key_val(k, val, dimension=axis) + return permuted + + def permute_fwd(k, ik, val): + # On TPU, sorting scalars by key is faster than a gather. + _, permuted = fastmath.sort_key_val(k, val, dimension=axis) + return permuted, ik + + def permute_bwd(ik, permuted_grad): + _, val_grad = fastmath.sort_key_val(ik, permuted_grad, dimension=axis) + return (None, None, val_grad) + + permute = fastmath.custom_vjp(permute_impl, permute_fwd, permute_bwd) + return permute(keys, inverse_keys, val) + + +####################################################### Classes + + +class EfficientAttentionBase(base.Layer): + """Base class for efficient attention. + + This is a base class that implements memory-efficient batching for both the + forward and backward passes. Subclasses should override + `create_weights_unbatched`, `create_state_unbatched`, `forward_unbatched`, and + optionally `incremental_forward_unbatched` to define the actual attention + mechanism. """ - weights, state, rng = self.weights, self.state, self.rng - if not self._use_reference_code: - # By default, an efficient, batched implementation is used. - output, new_state, _, _ = self.forward_and_or_backward( - inputs, weights, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - # The reference implementation below provides a more readable overview of - # what this class does. It's not optimized, however, and should only be used - # when testing this class for correctness. 
- if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - if self._incremental: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - - output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] - new_state = [] - for example_idx in range(batch_size): - for head_idx in range(self._n_heads): - # pylint: disable=cell-var-from-loop - single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) - single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) - single_state = fastmath.nested_map( - lambda s: s[example_idx * self._n_heads + head_idx], state) - # pylint: enable=cell-var-from-loop + + def __init__( + self, + n_heads, + n_in=1, + n_parallel_heads=None, + incremental=False, + predict_mem_len=None, + predict_drop_len=None, + use_python_loop=False, + use_reference_code=False, + ): + """Constructs an EfficientAttentionBase instance. + + Args: + n_heads: Number of attention heads. + n_in: Number of inputs to the layer (default 1). + n_parallel_heads: Number of attention heads to compute in parallel. + + - If `n_parallel_heads` is None (default), the entire layer is + computed with maximum parallelism. This mode is the fastest, but + also uses the most memory. Start with this mode, but switch to one + of the others if memory runs out. + - If `n_parallel_heads` is 1, attention is computed one head at a + time, and one example at a time. This mode uses the least memory + but is not as fast as batched attention. Use this mode when working + with very long sequences, such that any amount of parallelism won't + fit in memory. + - If `n_parallel_heads` is a multiple of `n_heads`, attention is + computed for sub-batches of (`n_parallel_heads // n_heads`) + examples at a time. + - If `1 < n_parallel_heads < n_heads`, attention is computed for + several heads at a time, but only within a single example. It must + be the case that `n_heads` is a multiple of `n_parallel_heads`. Use + this mode for long sequences, to strike a balance between + parallelism and memory usage. + incremental: If `True`, enable fast inference for self-attention types. + Note that this flag should *not* be set when doing encoder-decoder + attention, but only when doing self-attention. + predict_mem_len: Number of input positions to remember in a cache + when doing fast inference. Whenever the cache fills up, some input + elements will be forgotten. + predict_drop_len: Number of input elements to drop once the fast + inference input cache fills up. + use_python_loop: Set to True to use a Python loop when iterating over + sub-batches of examples/heads (as opposed to a JAX/XLA loop). + This option will increase compilation time and jitted code size, + potentially drastically. Using it is not recommended except for + testing/debugging. In particular, note that enabling this option on + TPU can decrease the maximum model size that will fit in memory. + use_reference_code: Set to True to fall back to the reference + implementation of batched attention. This option will increase + compilation time and jitted code size, potentially drastically. Using + it is not recommended except for testing/debugging. 
+ """ + super().__init__(n_in=n_in, n_out=1) + self._n_heads = n_heads + self._incremental = incremental if self._incremental: - single_out, single_new_state = self._incremental_forward_unbatched( - *single_inputs, q_start=q_start, q_len=seqlen, - weights=single_weights, rng=rng, - state=single_state, update_state=True) + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads else: - single_out, single_new_state = self.forward_unbatched( - *single_inputs, weights=single_weights, rng=rng, - state=single_state, update_state=True) - new_state.append(single_new_state) - output_accum[example_idx] = output_accum[example_idx] + single_out - - output = np.stack(output_accum, 0) - if new_state and fastmath.tree_leaves(new_state[0]): - new_state = fastmath.nested_map_multiarg( - lambda *s: np.stack(s, 0), *new_state) - else: - new_state = state - if self._incremental: - new_state = (new_mem_end, new_mem, new_state) - self.state = tuple(new_state) - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. - def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) - else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. 
- new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. - def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - # Use an efficient backward pass, unless we're running the reference code. - return not self._use_reference_code - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - assert not self._use_reference_code - del output, state, kwargs - _, _, inputs_grad, weights_grad = self.forward_and_or_backward( - inputs, weights, new_state, rng, output_grad=grad, - compute_output=False, update_state=False) - return inputs_grad, weights_grad - - def forward_and_or_backward( - self, inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. + self._n_parallel_heads = None + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + def init_weights_and_state(self, input_signature): + if not isinstance(input_signature, (tuple, list)): + input_signature = (input_signature,) + input_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature + ) + batch_size = int(input_signature[0].shape[0]) + + weights = [] + weight_rngs = fastmath.random.split(self.rng, self._n_heads) + for i in range(self._n_heads): + weights.append( + self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i]) + ) + state = [] + state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) + for i in range(self._n_heads * batch_size): + state.append( + self.create_state_unbatched(input_signature_unbatched, state_rngs[i]) + ) + + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. 
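# Aside: the NaN-signaling idea above, sketched in numpy (not from the patch):
# jitted code has no assert primitive, so a violated precondition is surfaced
# by poisoning float outputs with NaN.
import numpy as np

def check_is_seq_start(x, mem_end):
    return np.where(mem_end == 0, x, x * np.nan)   # floats only, as above

poisoned = check_is_seq_start(np.ones(3), mem_end=1)   # -> array of NaN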
- compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + input_signature, + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = tuple(weights) + + def create_weights_unbatched(self, input_signature, rng): + raise NotImplementedError("Subclasses should override create_weights_unbatched") + + def create_state_unbatched(self, input_signature, rng): + return () + + def forward_unbatched(self, *inputs, weights, state): + """Perform attention for a single batch element and head. + + Subclasses should override this method. + + Args: + *inputs: Inputs for a single example (subclasses may use different inputs) + weights: Weights for a single attention head + state: State for a single example & attention head pair. + + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. + """ + raise NotImplementedError("Subclasses should override forward_unbatched") + + def _incremental_forward_unbatched(self, *inputs, q_start, q_len, weights, state): + """Perform fast inference for a single batch element and head. + + Subclasses should override this method. + + Args: + *inputs: Inputs for a single example (subclasses may use different inputs) + q_start: Index along the sequence-length dimension that points to the + first input element that should be used as a query (and not just a key). + q_len: Number of new query elements in this call to the attention + mechanism. This is typically 1 for autoregressive decoding, but may be + longer if initializing a language model with a prefix. + weights: Weights for a single attention head + state: State for a single example & attention head pair. + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. + """ + raise NotImplementedError( + "Fast inference is not implemented for this attention type." + ) - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. + + Args: + inputs: Layer inputs (subclasses may use different inputs) + + Returns: + A tuple (output, new_state). + """ + weights, state, rng = self.weights, self.state, self.rng + if not self._use_reference_code: + # By default, an efficient, batched implementation is used. + output, new_state, _, _ = self.forward_and_or_backward( + inputs, weights, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + # The reference implementation below provides a more readable overview of + # what this class does. It's not optimized, however, and should only be used + # when testing this class for correctness. + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(kitaev): profile ~4% speed drop compared to previous implementation - # in some conditions. 
Other conditions (e.g. the enwik8 model) appear - # to have the same overall training speed. - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. - # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - have_single_input = not isinstance(inputs, (tuple, list)) - if have_single_input: - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' - - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. 
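# Aside: a toy contrast (not from the patch) of the [FW-BW-1] schedule described
# above: forward and backward are interleaved per (example, head) pair, so the
# recomputed intermediates are consumed immediately rather than held for a
# second loop as in [FW-BW-2].
def fwd(x):                 # stand-in for one head's forward pass
    return x * x

def bwd(x, g):              # stand-in backward using recomputed intermediates
    return 2.0 * x * g

grads = []
for x in [1.0, 2.0, 3.0]:   # [FW-BW-1]: one example-head pair at a time
    _ = fwd(x)              # rematerialize; nothing cached across iterations
    grads.append(bwd(x, 1.0))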
- n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, new_values) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads == 1: - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - example_idx = idx // self._n_heads - head_idx = idx % self._n_heads - - i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_h = fastmath.nested_map(lambda w: w[head_idx], weights) - s_h = fastmath.nested_map(lambda s: s[idx], state) - - def forward_fn(i_h, w_h): - return forward_unbatched( - *i_h, weights=w_h, state=fastmath.stop_gradient(s_h)) + if self._incremental: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + + output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] + new_state = [] + for example_idx in range(batch_size): + for head_idx in range(self._n_heads): + # pylint: disable=cell-var-from-loop + single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) + single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) + single_state = fastmath.nested_map( + lambda s: s[example_idx * self._n_heads + head_idx], state + ) + # pylint: enable=cell-var-from-loop + if self._incremental: + single_out, single_new_state = self._incremental_forward_unbatched( + *single_inputs, + q_start=q_start, + q_len=seqlen, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + else: + single_out, single_new_state = self.forward_unbatched( + *single_inputs, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + new_state.append(single_new_state) + output_accum[example_idx] = 
output_accum[example_idx] + single_out
+
+        output = np.stack(output_accum, 0)
+        if new_state and fastmath.tree_leaves(new_state[0]):
+            new_state = fastmath.nested_map_multiarg(
+                lambda *s: np.stack(s, 0), *new_state
+            )
+        else:
+            new_state = state
+        if self._incremental:
+            new_state = (new_mem_end, new_mem, new_state)
+        self.state = tuple(new_state)
+        return output
+
+    def _use_predict_mem(self, inputs, state):
+        """Update input cache for fast inference."""
+        mem_end, mem, state = state
+        seqlen = inputs[0].shape[-2]
+
+        if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len:
+            # This branch is called when only a small number of tokens are appended to
+            # the sequence, e.g. when generating one token at a time. A fixed number
+            # of tokens (self._predict_drop_len) will be dropped from memory if
+            # needed, and then new values will be inserted into the memory.
+            def roll_mem(buf):
+                return np.concatenate(
+                    [
+                        buf[:, self._predict_drop_len :],
+                        np.zeros_like(buf[:, : self._predict_drop_len]),
+                    ],
+                    axis=1,
+                )
+
+            do_roll_mem = mem_end + seqlen > self._predict_mem_len
+            mem = fastmath.cond(
+                pred=do_roll_mem,
+                true_operand=mem,
+                true_fun=lambda x: fastmath.nested_map(roll_mem, x),
+                false_operand=mem,
+                false_fun=lambda x: x,
+            )
+            mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end)
+
+            def update_mem(mem_element, new_vals):
+                assert new_vals.shape[1] == seqlen
+                if seqlen == 1:
+                    return fastmath.index_update(
+                        mem_element,
+                        jax.numpy.index_exp[:, mem_end],
+                        new_vals[:, 0, ...],
+                    )
+                else:
+                    return fastmath.dynamic_update_slice_in_dim(
+                        mem_element, new_vals, mem_end, axis=1
+                    )
+
+            inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs)
+            return inputs, state, mem_end, inputs, mem_end + seqlen
+        else:
+            assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len
+            # This branch handles the case where a large number of tokens are being
+            # introduced all at once. The code here assumes that we are at the start
+            # of the sequence, which matches the typical use case of decoding from a
+            # language model given a long prefix. Note that if we're not at the start
+            # of the sequence, the code here won't work.
+            new_flat_mem = []
+            for inp in fastmath.tree_leaves(inputs):
+                assert inp.shape[1] == seqlen
+                if seqlen == self._predict_mem_len:
+                    new_mem_val = inp
+                elif seqlen > self._predict_mem_len:
+                    new_mem_val = inp[
+                        :, -self._predict_mem_len :
+                    ]  # pylint: disable=invalid-unary-operand-type
+                else:
+                    new_mem_val = np.concatenate(
+                        [
+                            inp,
+                            np.zeros(
+                                inp.shape[:1]
+                                + (self._predict_mem_len - inp.shape[1],)
+                                + inp.shape[2:],
+                                dtype=inp.dtype,
+                            ),
+                        ],
+                        axis=1,
+                    )
+                new_flat_mem.append(new_mem_val)
+            mem, _ = fastmath.tree_unflatten(new_flat_mem, mem)
+
+            # This code only works at the start of the sequence. There's no "assert"
+            # primitive we can use to signal an error, so we instead signal the error
+            # by introducing NaNs into the computation.
+            def replace_with_nan_if_not_seq_start(x):
+                if x.dtype != np.float32:
+                    return x
+                return fastmath.cond(
+                    pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)),
+                    true_operand=x,
+                    true_fun=lambda x: x,
+                    false_operand=x,
+                    false_fun=lambda x: x * np.nan,
+                )
+
+            inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs)
+            return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len)
+
+    @property
+    def has_backward(self):
+        # Use an efficient backward pass, unless we're running the reference code.
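+        # As a rough sketch of how this hook is used (simplified; the real
+        # dispatch lives in the Trax layer machinery): when `has_backward` is
+        # True, gradients come from calling `backward` below rather than from
+        # differentiating `forward` with JAX autodiff, roughly:
+        #
+        #   if layer.has_backward:
+        #       inputs_grad, weights_grad = layer.backward(
+        #           inputs, output, grad, weights, state, new_state, rng)
+        #
+        # `backward` then re-runs the forward computation slice by slice
+        # ("rematerialization"), trading extra compute for lower peak memory.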
+ return not self._use_reference_code + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + assert not self._use_reference_code + del output, state, kwargs + _, _, inputs_grad, weights_grad = self.forward_and_or_backward( + inputs, + weights, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) + return inputs_grad, weights_grad - if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) - ct_h = output_grad[example_idx] - assert o_h.shape == ct_h.shape - i_ct_h, w_ct_h = backward_fn(ct_h) + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(kitaev): profile ~4% speed drop compared to previous implementation + # in some conditions. Other conditions (e.g. the enwik8 model) appear + # to have the same overall training speed. + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. 
+ # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + have_single_input = not isinstance(inputs, (tuple, list)) + if have_single_input: + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" + + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) else: - o_h, s_h = forward_fn(i_h, w_h) + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. + n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_h) - if update_state: - s_all = tree_update(s_all, idx, s_h) if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) - w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) - return (o_all, s_all, i_ct_all, w_ct_all) - elif n_parallel_heads < self._n_heads: - assert self._n_heads % n_parallel_heads == 0 - def run_inner(idx, loop_val): - """Runs one slice of attention (multiple heads, but one example).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * self._n_parallel_heads - example_idx = idx // self._n_heads - head_idx_lo = idx % self._n_heads - head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_mh = fastmath.nested_map(lambda w: w[head_range], weights) - s_mh = fastmath.nested_map(lambda s: s[state_range], state) - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - def forward_fn(i_mh, w_mh): - o_mh, new_s_mh = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0)( - i_mh, w_mh, s_mh) - o_mh = np.sum(o_mh, axis=0) - return o_mh, new_s_mh + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = 
fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads == 1: + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + example_idx = idx // self._n_heads + head_idx = idx % self._n_heads + + i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_h = fastmath.nested_map(lambda w: w[head_idx], weights) + s_h = fastmath.nested_map(lambda s: s[idx], state) + + def forward_fn(i_h, w_h): + return forward_unbatched( + *i_h, weights=w_h, state=fastmath.stop_gradient(s_h) + ) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) + ct_h = output_grad[example_idx] + assert o_h.shape == ct_h.shape + i_ct_h, w_ct_h = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h, w_h) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) + w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) + return (o_all, s_all, i_ct_all, w_ct_all) + + elif n_parallel_heads < self._n_heads: + assert self._n_heads % n_parallel_heads == 0 + + def run_inner(idx, loop_val): + """Runs one slice of attention (multiple heads, but one example).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * self._n_parallel_heads + example_idx = idx // self._n_heads + head_idx_lo = idx % self._n_heads + head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_mh = fastmath.nested_map(lambda w: w[head_range], weights) + s_mh = fastmath.nested_map(lambda s: s[state_range], state) + + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + def forward_fn(i_mh, w_mh): + o_mh, new_s_mh = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0 + )(i_mh, w_mh, s_mh) + o_mh = np.sum(o_mh, axis=0) + return o_mh, new_s_mh + + if compute_grad: + o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) + ct_mh = output_grad[example_idx] + assert o_mh.shape == ct_mh.shape + i_ct_mh, w_ct_mh = backward_fn(ct_mh) + else: + o_mh, s_mh = forward_fn(i_mh, w_mh) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_mh) + if update_state: + s_all = tree_update(s_all, state_range, s_mh) + if 
compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) + w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) + return (o_all, s_all, i_ct_all, w_ct_all) - if compute_grad: - o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) - ct_mh = output_grad[example_idx] - assert o_mh.shape == ct_mh.shape - i_ct_mh, w_ct_mh = backward_fn(ct_mh) else: - o_mh, s_mh = forward_fn(i_mh, w_mh) - + assert n_parallel_heads % self._n_heads == 0 + + def forward_single_example(i_x, w_all, s_x): + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + o_x, s_x = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0) + )(i_x, w_all, s_x) + o_x = np.sum(o_x, axis=0) + return o_x, s_x + + def run_inner(idx, loop_val): + """Runs one slice of attention (all heads for one or more examples).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * n_parallel_heads + example_idx_lo = idx // self._n_heads + example_range = example_idx_lo + np.arange( + n_parallel_heads // self._n_heads, dtype=np.int32 + ) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) + s_mex = fastmath.nested_map( + lambda s: np.reshape( + s[state_range], # pylint: disable=g-long-lambda + (-1, self._n_heads) + s.shape[1:], + ), + state, + ) + + def forward_fn(i_mex, w_all): + o_mex, new_s_mex = fastmath.vmap( + forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0) + )(i_mex, w_all, s_mex) + new_s_mex = fastmath.nested_map( + lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), + new_s_mex, + ) + return o_mex.astype(i_mex[0].dtype), new_s_mex + + if compute_grad: + o_mex, backward_fn, s_mex = vjp( + forward_fn, i_mex, weights, has_aux=True + ) + ct_mex = output_grad[example_range] + assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) + assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) + i_ct_mex, w_ct_mex = backward_fn(ct_mex) + else: + o_mex, s_mex = forward_fn(i_mex, weights) + + if compute_output: + o_all = fastmath.index_add( + o_all, jax.numpy.index_exp[example_range], o_mex + ) + if update_state: + s_all = tree_update(s_all, state_range, s_mex) + if compute_grad: + i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) + w_ct_all = fastmath.nested_map_multiarg( + lambda old_all, delta_all: old_all + delta_all, + w_ct_all, + w_ct_mex, + ) + return (o_all, s_all, i_ct_all, w_ct_all) + + o_all = s_all = i_ct_all = w_ct_all = None if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_mh) + o_all = np.zeros((batch_size, seqlen, d_model), dtype=inputs[0].dtype) if update_state: - s_all = tree_update(s_all, state_range, s_mh) + s_all = state if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) - w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) - return (o_all, s_all, i_ct_all, w_ct_all) - else: - assert n_parallel_heads % self._n_heads == 0 - def forward_single_example(i_x, w_all, s_x): - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - o_x, s_x = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0))( - i_x, w_all, s_x) - o_x = np.sum(o_x, axis=0) - return o_x, s_x - def run_inner(idx, loop_val): - """Runs one slice of attention (all heads for one or more examples).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * n_parallel_heads - example_idx_lo = idx // self._n_heads - example_range = example_idx_lo + np.arange( - 
n_parallel_heads // self._n_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) - s_mex = fastmath.nested_map( - lambda s: np.reshape(s[state_range], # pylint: disable=g-long-lambda - (-1, self._n_heads) + s.shape[1:]), - state) - def forward_fn(i_mex, w_all): - o_mex, new_s_mex = fastmath.vmap( - forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0))( - i_mex, w_all, s_mex) - new_s_mex = fastmath.nested_map( - lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), - new_s_mex) - return o_mex.astype(i_mex[0].dtype), new_s_mex + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) + w_ct_all = fastmath.nested_map(np.zeros_like, weights) - if compute_grad: - o_mex, backward_fn, s_mex = vjp(forward_fn, i_mex, weights, - has_aux=True) - ct_mex = output_grad[example_range] - assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) - assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) - i_ct_mex, w_ct_mex = backward_fn(ct_mex) - else: - o_mex, s_mex = forward_fn(i_mex, weights) + loop_val = (o_all, s_all, i_ct_all, w_ct_all) - if compute_output: - o_all = fastmath.index_add(o_all, jax.numpy.index_exp[example_range], - o_mex) - if update_state: - s_all = tree_update(s_all, state_range, s_mex) - if compute_grad: - i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) - w_ct_all = fastmath.nested_map_multiarg( - lambda old_all, delta_all: old_all + delta_all, - w_ct_all, w_ct_mex) - return (o_all, s_all, i_ct_all, w_ct_all) - - o_all = s_all = i_ct_all = w_ct_all = None - if compute_output: - o_all = np.zeros( - (batch_size, seqlen, d_model), dtype=inputs[0].dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - w_ct_all = fastmath.nested_map(np.zeros_like, weights) - - loop_val = (o_all, s_all, i_ct_all, w_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) + else: + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - (o_all, s_all, i_ct_all, w_ct_all) = loop_val + (o_all, s_all, i_ct_all, w_ct_all) = loop_val - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + if compute_grad: + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) - if have_single_input and compute_grad: - assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 - return (o_all, s_all, i_ct_all[0], w_ct_all) - else: - return (o_all, s_all, i_ct_all, w_ct_all) + if have_single_input and compute_grad: + assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 + return (o_all, s_all, i_ct_all[0], w_ct_all) + else: + return (o_all, s_all, i_ct_all, w_ct_all) class 
SelfAttention(base.Layer):
-  """Memory-efficient self-attention (second attempt)."""
-
-  def __init__(self,
-               n_heads=2, d_qk=64, d_v=64, share_qk=False,
-               causal=False, masked=False,
-               chunk_len=None, n_chunks_before=0, n_chunks_after=0,
-               bias=False,
-               mode='train',
-               predict_mem_len=None, predict_drop_len=None,
-               attention_dropout=0.0,
-               output_dropout=0.0,
-               n_parallel_heads=None,
-               use_python_loop=False,
-               use_reference_code=False,
-               ):
-    """Construct a self-attention layer.
+    """Memory-efficient self-attention (second attempt)."""
+
+    def __init__(
+        self,
+        n_heads=2,
+        d_qk=64,
+        d_v=64,
+        share_qk=False,
+        causal=False,
+        masked=False,
+        chunk_len=None,
+        n_chunks_before=0,
+        n_chunks_after=0,
+        bias=False,
+        mode="train",
+        predict_mem_len=None,
+        predict_drop_len=None,
+        attention_dropout=0.0,
+        output_dropout=0.0,
+        n_parallel_heads=None,
+        use_python_loop=False,
+        use_reference_code=False,
+    ):
+        """Construct a self-attention layer.
+
+        Args:
+          n_heads: int: Number of attention heads
+          d_qk: int: Depth of query and key vectors
+          d_v: int: Depth of value vectors
+          share_qk: bool: Set to True to share query and key projection weights
+          causal: bool: Set to True to mask out attention to future items
+          masked: bool: Set to True to accept an additional mask argument that
+            allows masking out attention to padding tokens.
+          chunk_len (optional): Number of tokens per chunk. Setting this option will
+            enable chunked attention.
+          n_chunks_before: Number of previous chunks to attend to, when using
+            chunked attention.
+          n_chunks_after: Number of subsequent chunks to attend to, when using
+            chunked attention. Don't use this option for causal attention, because
+            attention to future tokens will be masked out anyway. However, note that
+            cross-chunk attention "wraps around" in both directions, so this option
+            is never a strict no-op.
+          bias: bool: Set to True to add bias vectors when computing query/key/value
+          mode: 'train', 'eval', or 'predict'
+          predict_mem_len: int: Number of input positions to remember in a cache
+            when doing fast inference. Whenever the cache fills up, some input
+            elements will be forgotten. When chunking is enabled, the default is to
+            store chunk_len * (1 + n_chunks_before) elements.
+          predict_drop_len: int: Number of input elements to drop once the fast
+            inference input cache fills up. When chunking is enabled, the default is
+            to drop exactly chunk_len elements.
+          attention_dropout: Dropout probability for attention mask.
+          output_dropout: Dropout probability for the layer output.
+          n_parallel_heads: Number of attention heads to compute in parallel.
+
+            - If `n_parallel_heads` is None (default), the entire layer is
+              computed with maximum parallelism. This mode is the fastest, but
+              also uses the most memory. Start with this mode, but switch to one
+              of the others if memory runs out.
+            - If `n_parallel_heads` is 1, attention is computed one head at a
+              time, and one example at a time. This mode uses the least memory
+              but is not as fast as batched attention. Use this mode when working
+              with very long sequences, such that any amount of parallelism won't
+              fit in memory.
+            - If `n_parallel_heads` is a multiple of `n_heads`, attention is
+              computed for sub-batches of (`n_parallel_heads // n_heads`)
+              examples at a time.
+            - If `1 < n_parallel_heads < n_heads`, attention is computed for
+              several heads at a time, but only within a single example. It must
+              be the case that `n_heads` is a multiple of `n_parallel_heads`. Use
+              this mode for long sequences, to strike a balance between
+              parallelism and memory usage.
+          use_python_loop: Set to True to use a Python loop when iterating over
+            sub-batches of examples/heads (as opposed to a JAX/XLA loop).
+            This option will increase compilation time and jitted code size,
+            potentially drastically. Using it is not recommended except for
+            testing/debugging. In particular, note that enabling this option on
+            TPU can decrease the maximum model size that will fit in memory.
+          use_reference_code: Set to True to fall back to the reference
+            implementation of batched attention. This option will increase
+            compilation time and jitted code size, potentially drastically. Using
+            it is not recommended except for testing/debugging.
+
+        """
+        super().__init__(n_in=(2 if masked else 1), n_out=1)
+
+        self._incremental = mode == "predict"
+        if self._incremental:
+            assert causal, "Only causal attention supports fast inference"
+            assert chunk_len is not None or (predict_mem_len and predict_drop_len)
+            predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before))
+            predict_drop_len = predict_drop_len or chunk_len
+            if predict_mem_len is None or predict_drop_len is None:
+                raise ValueError("This configuration does not support fast inference.")
+            if not 0 < predict_drop_len <= predict_mem_len:
+                raise ValueError(
+                    "Bad parameter values: (predict_mem_len, predict_drop_len) = ",
+                    predict_mem_len,
+                    predict_drop_len,
+                )
+            self._predict_mem_len = predict_mem_len
+            self._predict_drop_len = predict_drop_len
+
+        self._n_heads = n_heads
+        if n_parallel_heads:
+            if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or (
+                n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0
+            ):
+                raise ValueError(
+                    "n_parallel_heads must be a multiple or fraction of n_heads"
+                )
+            self._n_parallel_heads = n_parallel_heads
+        else:
+            self._n_parallel_heads = None
+        self._use_python_loop = use_python_loop
+        self._use_reference_code = use_reference_code
+
+        self._d_qk = d_qk
+        self._d_v = d_v
+        self._share_qk = share_qk
+        self._causal = causal
+        self._masked = masked
+        self._chunk_len = chunk_len
+        self._n_chunks_before = n_chunks_before
+        self._n_chunks_after = n_chunks_after
+        self._bias = bias
+        self._mode = mode
+        if mode == "train":
+            self._attention_dropout = attention_dropout
+            self._output_dropout = output_dropout
+        else:
+            self._attention_dropout = 0.0
+            self._output_dropout = 0.0
+
+    def _kernel_initializer(self, shape, rng):
+        # Attention uses Glorot uniform initialization with respect to the *total*
+        # dimension of queries/key/values across all heads. We initialize one head
+        # at a time in this class, so init.GlorotUniformInitializer won't work.
+        # This initialization type is for parity with previous Trax & tensor2tensor
+        # Transformers; it's not clear if it's strictly needed for model accuracy.
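+        # As a worked example (illustrative numbers, not from the original
+        # code): for a per-head query projection of shape (512, 64) with
+        # n_heads=8, the limit below is sqrt(6 / (512 + 64 * 8)) =
+        # sqrt(6 / 1024) ≈ 0.0765 -- the same limit Glorot uniform would
+        # assign a single fused (512, 512) projection.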
+ lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def init_weights_and_state(self, input_signature): + if not isinstance(input_signature, (tuple, list)): + input_signature = (input_signature,) + else: + input_signature = (input_signature[0],) + input_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature + ) + batch_size = int(input_signature[0].shape[0]) + + weights = [] + weight_rngs = fastmath.random.split(self.rng, self._n_heads) + for i in range(self._n_heads): + weights.append( + self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i]) + ) + state = [] + state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) + for i in range(self._n_heads * batch_size): + state.append( + self.create_state_unbatched(input_signature_unbatched, state_rngs[i]) + ) + + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - Args: - n_heads: int: Number of attention heads - d_qk: int: Depth of query ond key vectors - d_v: int: Depth of value vectors - share_qk: bool: Set to True to share query and key projection weights - causal: bool: Set to True to mask out attention to future items - masked: bool: Set to True to accept an additional mask argument, that - allows masking out attention to padding tokens. - chunk_len (optional): Number of tokens per chunk. Setting this option will - enable chunked attention. - n_chunks_before: Number of previous chunks to attend to, when using - chunked attention. - n_chunks_after: Number of subsequent chunks to attend to, when using - chunked attention. Don't use this option for causal attention, because - attention to future tokens will be masked out anyway. However, note that - cross-chunk attention "wraps around" in both directions, so this option - is never a strict no-op. - bias: bool: Set to True to add bias vectors when computing query/key/value - mode: 'train', 'eval', or 'predict' - predict_mem_len: int: Number of input positions to remember in a cache - when doing fast inference. Whenever the cache fills up, some input - elements will be forgotten. When chunking is enabled, the default is to - store chunk_len * (1 + n_chunks_before) elements. - predict_drop_len: int: Number of input elements to drop once the fast - inference input cache fills up. When chunking is enabled, the default is - to drop exactly chunk_len elements. - attention_dropout: Dropout probability for attention mask. - output_dropout: Dropout probability for the layer output. - n_parallel_heads: Number of attention heads to compute in parallel. - - - If `n_parallel_heads` is None (default), the entire layer is - computed with maximum parallelism. This mode is the fastest, but - also uses the most memory. Start with this mode, but switch to one - of the others if memory runs out. - - If `n_parallel_heads` is 1, attention is computed one head at a - time, and one example at a time. This mode uses the least memory - but is not as fast as batched attention. Use this mode when working - with very long sequences, such that any amount of parallelism won't - fit in memory. - - If `n_parallel_heads` is a multiple of `n_heads`, attention is - computed for sub-batches of (`n_parallel_heads // n_heads`) - examples at a time. 
- - If `1 < n_parallel_heads < n_heads`, attention is computed for - several heads at a time, but only within a single example. It must - be the case that `n_heads` is a multiple of `n_parallel_heads`. Use - this mode for long sequences, to strike a balance between - parallelism and memory usage. - use_python_loop: Set to True to use a Python loop when iterating over - sub-batches of examples/heads (as opposed to a JAX/XLA loop). - This option will increase compilation time and jitted code size, - potentially drastically. Using it is not recommended except for - testing/debugging. In particular, note that enabling this option on - TPU can decrease the maximum model size that will fit in memory. - use_reference_code: Set to True to fall back to the reference - implementation of batched attention. This option will increase - compilation time and jitted code size, potentially drastically. Using - it is not recommended except for testing/debugging. + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + input_signature, + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = tuple(weights) + + def create_weights_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + d_model = input_signature.shape[-1] + rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) + w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) + if not self._share_qk: + w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) + w_v = self._kernel_initializer((d_model, self._d_v), rng_v) + w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) + + if self._bias: + b_q = np.zeros(self._d_qk) + b_v = np.zeros(self._d_v) + if self._share_qk: + return (w_q, w_v, w_o, b_q, b_v) + else: + b_k = np.zeros(self._d_qk) + return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) + + if self._share_qk: + return (w_q, w_v, w_o) + else: + return (w_q, w_k, w_v, w_o) + + def create_state_unbatched(self, input_signature, rng): + return () + + def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state): + """Perform attention for a single batch element and head. + + Args: + x: Inputs for a single example (subclasses may use different inputs) + mask: Mask for the inputs. + weights: Weights for a single attention head + state: State for a single example & attention head pair. + rng: PRNG key for the layer (shared across all examples and heads) + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. 
+ """ + + del update_state + attend_rng, output_rng = fastmath.random.split(rng) + if self._bias: + if self._share_qk: + w_q, w_v, w_o, b_q, b_v = weights + else: + w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights + else: + if self._share_qk: + w_q, w_v, w_o = weights + else: + w_q, w_k, w_v, w_o = weights + + q = np.matmul(x, w_q) + k = None + if not self._share_qk: + k = np.matmul(x, w_k) + v = np.matmul(x, w_v) + + if self._bias: + q = q + b_q + if not self._share_qk: + k = k + b_k + v = v + b_v + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=self._share_qk, + masked=self._masked, + ) + q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32) + + assert (mask is not None) == self._masked + if self._masked: + # mask is a boolean array (True means "is valid token") + ones_like_mask = np.ones_like(mask, dtype=np.int32) + kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask) + + o, _ = attend( + q, + k, + v, + q_chunk_len=self._chunk_len, + kv_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - """ - super().__init__(n_in=(2 if masked else 1), n_out=1) - - self._incremental = (mode == 'predict') - if self._incremental: - assert causal, 'Only causal attention supports fast inference' - assert chunk_len is not None or (predict_mem_len and predict_drop_len) - predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) - predict_drop_len = predict_drop_len or chunk_len - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - self._n_heads = n_heads - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads - else: - self._n_parallel_heads = None - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - self._d_qk = d_qk - self._d_v = d_v - self._share_qk = share_qk - self._causal = causal - self._masked = masked - self._chunk_len = chunk_len - self._n_chunks_before = n_chunks_before - self._n_chunks_after = n_chunks_after - self._bias = bias - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def init_weights_and_state(self, input_signature): - if not isinstance(input_signature, (tuple, list)): - input_signature = (input_signature,) - else: - input_signature = (input_signature[0],) - input_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - input_signature) - batch_size = int(input_signature[0].shape[0]) - - weights = [] - weight_rngs = fastmath.random.split(self.rng, self._n_heads) - for i in range(self._n_heads): - weights.append(self.create_weights_unbatched(input_signature_unbatched, - weight_rngs[i])) - state = [] - state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) - for i in range(self._n_heads * batch_size): - state.append(self.create_state_unbatched(input_signature_unbatched, - state_rngs[i])) - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - input_signature) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = tuple(weights) - - def create_weights_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - d_model = input_signature.shape[-1] - rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) - w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) - if not self._share_qk: - w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) - w_v = self._kernel_initializer((d_model, self._d_v), rng_v) - w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) - - if self._bias: - b_q = np.zeros(self._d_qk) - b_v = np.zeros(self._d_v) - if self._share_qk: - return (w_q, w_v, w_o, b_q, b_v) - else: - b_k = np.zeros(self._d_qk) - return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) - - if self._share_qk: - return (w_q, w_v, w_o) - else: - return (w_q, w_k, w_v, w_o) + out = np.matmul(o, w_o) + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state - def create_state_unbatched(self, input_signature, rng): - return () + def _incremental_forward_unbatched( + self, x, mask=None, *, q_start, q_len, weights, state, rng, update_state + ): + """Perform fast inference for a single batch element and head. + + Args: + x: Inputs for a single example (subclasses may use different inputs) + mask: inputs mask. + q_start: Index along the sequence-length dimension that points to the + first input element that should be used as a query (and not just a key). + q_len: Number of new query elements in this call to the attention + mechanism. This is typically 1 for autoregressive decoding, but may be + longer if initializing a language model with a prefix. + weights: Weights for a single attention head + state: State for a single example & attention head pair. + rng: PRNG key for the layer (shared across all examples and heads) + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state) -- output and new state for a single example + and attention head. 
+ """ + del update_state + attend_rng, output_rng = fastmath.random.split(rng) + if self._share_qk: + w_q, w_v, w_o = weights + else: + w_q, w_k, w_v, w_o = weights + + q_range = q_start + np.arange(q_len, dtype=np.int32) + if q_len == 1: + # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not + # floating-point equivalent, at least in non-jitted code. We correct the + # discrepancy by duplicating the slice. Floating-point noise may not be + # an issue when using models, but it makes it harder to write tests that + # compare fast and slow inference code for equivalence. + q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) + else: + q = np.matmul(x[q_range], w_q) + if self._share_qk: + k = length_normalized(np.matmul(x, w_q)) + else: + k = np.matmul(x, w_k) + v = np.matmul(x, w_v) + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=self._share_qk, + masked=self._masked, + ) + q_info = q_range + kv_info = np.arange(k.shape[-2], dtype=np.int32) + + if self._chunk_len is not None and q_len > self._chunk_len: + assert q_start == 0 + assert q_len % self._chunk_len == 0 + o, _ = attend( + q, + k, + v, + q_chunk_len=self._chunk_len, + kv_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) + else: + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) + + out = np.matmul(o, w_o) + if q_len == 1: + out = out[:1] + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state + + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. + + Args: + inputs: Layer inputs (subclasses may use different inputs) + + Returns: + A tuple (output, new_state). + """ + weights, state, rng = self.weights, self.state, self.rng + if not self._use_reference_code: + # By default, an efficient, batched implementation is used. + output, new_state, _, _ = self.forward_and_or_backward( + inputs, weights, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + # The reference implementation below provides a more readable overview of + # what this class does. It's not optimized, however, and should only be used + # when testing this class for correctness. + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] - def forward_unbatched(self, x, mask=None, *, - weights, state, rng, update_state): - """Perform attention for a single batch element and head. 
+        if self._incremental:
+            inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem(
+                inputs, state
+            )
+
+        output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)]
+        new_state = []
+        for example_idx in range(batch_size):
+            for head_idx in range(self._n_heads):
+                # pylint: disable=cell-var-from-loop
+                single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs)
+                single_weights = fastmath.nested_map(lambda w: w[head_idx], weights)
+                single_state = fastmath.nested_map(
+                    lambda s: s[example_idx * self._n_heads + head_idx], state
+                )
+                # pylint: enable=cell-var-from-loop
+                if self._incremental:
+                    single_out, single_new_state = self._incremental_forward_unbatched(
+                        *single_inputs,
+                        q_start=q_start,
+                        q_len=seqlen,
+                        weights=single_weights,
+                        rng=rng,
+                        state=single_state,
+                        update_state=True,
+                    )
+                else:
+                    single_out, single_new_state = self.forward_unbatched(
+                        *single_inputs,
+                        weights=single_weights,
+                        rng=rng,
+                        state=single_state,
+                        update_state=True,
+                    )
+                new_state.append(single_new_state)
+                output_accum[example_idx] = output_accum[example_idx] + single_out
+
+        output = np.stack(output_accum, 0)
+        if new_state and fastmath.tree_leaves(new_state[0]):
+            new_state = fastmath.nested_map_multiarg(
+                lambda *s: np.stack(s, 0), *new_state
+            )
+        else:
+            new_state = state
+        if self._incremental:
+            new_state = (new_mem_end, new_mem, new_state)
+        self.state = tuple(new_state)
+        return output
+
+    def _use_predict_mem(self, inputs, state):
+        """Update input cache for fast inference."""
+        mem_end, mem, state = state
+        seqlen = inputs[0].shape[-2]
+
+        if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len:
+            # This branch is called when only a small number of tokens are appended to
+            # the sequence, e.g. when generating one token at a time. A fixed number
+            # of tokens (self._predict_drop_len) will be dropped from memory if
+            # needed, and then new values will be inserted into the memory.
+            def roll_mem(buf):
+                return np.concatenate(
+                    [
+                        buf[:, self._predict_drop_len :],
+                        np.zeros_like(buf[:, : self._predict_drop_len]),
+                    ],
+                    axis=1,
+                )
+
+            do_roll_mem = mem_end + seqlen > self._predict_mem_len
+            mem = fastmath.cond(
+                pred=do_roll_mem,
+                true_operand=mem,
+                true_fun=lambda x: fastmath.nested_map(roll_mem, x),
+                false_operand=mem,
+                false_fun=lambda x: x,
+            )
+            mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end)
+
+            def update_mem(mem_element, new_vals):
+                assert new_vals.shape[1] == seqlen
+                if seqlen == 1:
+                    return fastmath.index_update(
+                        mem_element,
+                        jax.numpy.index_exp[:, mem_end],
+                        new_vals[:, 0, ...],
+                    )
+                else:
+                    return fastmath.dynamic_update_slice_in_dim(
+                        mem_element, new_vals, mem_end, axis=1
+                    )
+
+            inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs)
+            return inputs, state, mem_end, inputs, mem_end + seqlen
+        else:
+            assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len
+            # This branch handles the case where a large number of tokens are being
+            # introduced all at once. The code here assumes that we are at the start
+            # of the sequence, which matches the typical use case of decoding from a
+            # language model given a long prefix. Note that if we're not at the start
+            # of the sequence, the code here won't work.
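+            # For example (illustrative numbers): with predict_mem_len=128 and
+            # a 100-token prefix, each memory buffer below becomes the prefix
+            # padded with zeros out to length 128, and the mem_end returned at
+            # the bottom of this branch is min(100, 128) = 100, so the next
+            # decoding step starts writing at position 100.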
+ new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + # Use an efficient backward pass, unless we're running the reference code. + return not self._use_reference_code + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + assert not self._use_reference_code + del output, state, kwargs + _, _, inputs_grad, weights_grad = self.forward_and_or_backward( + inputs, + weights, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) + return inputs_grad, weights_grad - Args: - x: Inputs for a single example (subclasses may use different inputs) - mask: Mask for the inputs. - weights: Weights for a single attention head - state: State for a single example & attention head pair. - rng: PRNG key for the layer (shared across all examples and heads) - update_state: bool: whether to return an updated layer state. + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(kitaev): profile ~4% speed drop compared to previous implementation + # in some conditions. Other conditions (e.g. the enwik8 model) appear + # to have the same overall training speed. 
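+        # As a concrete, purely illustrative sizing example for the three
+        # parallelism regimes described in the notes below: with batch_size=2
+        # and n_heads=8 there are 16 (example, head) pairs per step.
+        # n_parallel_heads=None processes all 16 at once; n_parallel_heads=4
+        # selects the heads-within-one-example branch and loops 4 times;
+        # n_parallel_heads=16 selects the whole-examples branch (2 examples
+        # per slice) and finishes in a single loop iteration.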
+ # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + have_single_input = not isinstance(inputs, (tuple, list)) + if have_single_input: + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" + + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) + else: + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. + n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. 
- """ + if compute_grad: + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads == 1: + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + example_idx = idx // self._n_heads + head_idx = idx % self._n_heads + + i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_h = fastmath.nested_map(lambda w: w[head_idx], weights) + s_h = fastmath.nested_map(lambda s: s[idx], state) + + def forward_fn(i_h, w_h): + return forward_unbatched( + *i_h, weights=w_h, state=fastmath.stop_gradient(s_h) + ) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) + ct_h = output_grad[example_idx] + assert o_h.shape == ct_h.shape + i_ct_h, w_ct_h = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h, w_h) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) + w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) + return (o_all, s_all, i_ct_all, w_ct_all) + + elif n_parallel_heads < self._n_heads: + assert self._n_heads % n_parallel_heads == 0 + + def run_inner(idx, loop_val): + """Runs one slice of attention (multiple heads, but one example).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * self._n_parallel_heads + example_idx = idx // self._n_heads + head_idx_lo = idx % self._n_heads + head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_mh = fastmath.nested_map(lambda w: w[head_range], weights) + s_mh = fastmath.nested_map(lambda s: s[state_range], state) + + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + def forward_fn(i_mh, w_mh): + o_mh, new_s_mh = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0 + )(i_mh, w_mh, s_mh) + o_mh = np.sum(o_mh, axis=0) + return o_mh, new_s_mh + + if compute_grad: + o_mh, backward_fn, s_mh = 
vjp(forward_fn, i_mh, w_mh, has_aux=True) + ct_mh = output_grad[example_idx] + assert o_mh.shape == ct_mh.shape + i_ct_mh, w_ct_mh = backward_fn(ct_mh) + else: + o_mh, s_mh = forward_fn(i_mh, w_mh) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_mh) + if update_state: + s_all = tree_update(s_all, state_range, s_mh) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) + w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) + return (o_all, s_all, i_ct_all, w_ct_all) - del update_state - attend_rng, output_rng = fastmath.random.split(rng) - if self._bias: - if self._share_qk: - w_q, w_v, w_o, b_q, b_v = weights - else: - w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights - else: - if self._share_qk: - w_q, w_v, w_o = weights - else: - w_q, w_k, w_v, w_o = weights + else: + assert n_parallel_heads % self._n_heads == 0 + + def forward_single_example(i_x, w_all, s_x): + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + o_x, s_x = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0) + )(i_x, w_all, s_x) + o_x = np.sum(o_x, axis=0) + return o_x, s_x + + def run_inner(idx, loop_val): + """Runs one slice of attention (all heads for one or more examples).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * n_parallel_heads + example_idx_lo = idx // self._n_heads + example_range = example_idx_lo + np.arange( + n_parallel_heads // self._n_heads, dtype=np.int32 + ) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) + s_mex = fastmath.nested_map( + lambda s: np.reshape( + s[state_range], # pylint: disable=g-long-lambda + (-1, self._n_heads) + s.shape[1:], + ), + state, + ) + + def forward_fn(i_mex, w_all): + o_mex, new_s_mex = fastmath.vmap( + forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0) + )(i_mex, w_all, s_mex) + new_s_mex = fastmath.nested_map( + lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), + new_s_mex, + ) + return o_mex.astype(i_mex[0].dtype), new_s_mex + + if compute_grad: + o_mex, backward_fn, s_mex = vjp( + forward_fn, i_mex, weights, has_aux=True + ) + ct_mex = output_grad[example_range] + assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) + assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) + i_ct_mex, w_ct_mex = backward_fn(ct_mex) + else: + o_mex, s_mex = forward_fn(i_mex, weights) + + if compute_output: + o_all = fastmath.index_add( + o_all, jax.numpy.index_exp[example_range], o_mex + ) + if update_state: + s_all = tree_update(s_all, state_range, s_mex) + if compute_grad: + i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) + w_ct_all = fastmath.nested_map_multiarg( + lambda old_all, delta_all: old_all + delta_all, + w_ct_all, + w_ct_mex, + ) + return (o_all, s_all, i_ct_all, w_ct_all) + + o_all = s_all = i_ct_all = w_ct_all = None + if compute_output: + o_all = np.zeros((batch_size, seqlen, d_model), dtype=inputs[0].dtype) + if update_state: + s_all = state + if compute_grad: + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) + w_ct_all = fastmath.nested_map(np.zeros_like, weights) - q = np.matmul(x, w_q) - k = None - if not self._share_qk: - k = np.matmul(x, w_k) - v = np.matmul(x, w_v) - - if self._bias: - q = q + b_q - if not self._share_qk: - k = k + b_k - v = v + b_v - - mask_fn = functools.partial( - mask_self_attention, - causal=self._causal, 
exclude_self=self._share_qk, masked=self._masked) - q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32) - - assert (mask is not None) == self._masked - if self._masked: - # mask is a boolean array (True means "is valid token") - ones_like_mask = np.ones_like(mask, dtype=np.int32) - kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask) - - o, _ = attend( - q, k, v, - q_chunk_len=self._chunk_len, - kv_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) + loop_val = (o_all, s_all, i_ct_all, w_ct_all) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) + else: + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - def _incremental_forward_unbatched(self, x, mask=None, *, - q_start, q_len, - weights, state, rng, update_state): - """Perform fast inference for a single batch element and head. + (o_all, s_all, i_ct_all, w_ct_all) = loop_val - Args: - x: Inputs for a single example (subclasses may use different inputs) - mask: inputs mask. - q_start: Index along the sequence-length dimension that points to the - first input element that should be used as a query (and not just a key). - q_len: Number of new query elements in this call to the attention - mechanism. This is typically 1 for autoregressive decoding, but may be - longer if initializing a language model with a prefix. - weights: Weights for a single attention head - state: State for a single example & attention head pair. - rng: PRNG key for the layer (shared across all examples and heads) - update_state: bool: whether to return an updated layer state. + if compute_grad: + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) - Returns: - A tuple (output, new_state) -- output and new state for a single example - and attention head. - """ - del update_state - attend_rng, output_rng = fastmath.random.split(rng) - if self._share_qk: - w_q, w_v, w_o = weights - else: - w_q, w_k, w_v, w_o = weights - - q_range = q_start + np.arange(q_len, dtype=np.int32) - if q_len == 1: - # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not - # floating-point equivalent, at least in non-jitted code. We correct the - # discrepancy by duplicating the slice. Floating-point noise may not be - # an issue when using models, but it makes it harder to write tests that - # compare fast and slow inference code for equivalence. 
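The `split_differentiable`/`join_differentiable`/`vjp` helpers added above exist because `fastmath.vjp` (like `jax.vjp`) builds a cotangent for every primal, which fails on integer-typed leaves such as masks; the fix is to close the non-differentiable leaves over the function before differentiating. A toy version of the same idea, assuming one float input and one integer mask (names invented here):

```python
import jax
import jax.numpy as jnp

def vjp_float_input_only(fn, float_inp, int_inp, out_grad):
    """Differentiate `fn` w.r.t. its float input only.

    The integer input is captured by a closure, so jax.vjp never sees it
    as a primal and never tries to build an integer cotangent for it.
    """
    def fn_closed_over_ints(f):
        return fn(f, int_inp)

    out, backward_fn = jax.vjp(fn_closed_over_ints, float_inp)
    (float_grad,) = backward_fn(out_grad)
    return out, float_grad

def masked_sum(x, mask):
    return jnp.where(mask.astype(bool), x, 0.0).sum()

x = jnp.arange(4.0)
mask = jnp.array([1, 0, 1, 0], dtype=jnp.int32)
out, dx = vjp_float_input_only(masked_sum, x, mask, jnp.float32(1.0))
print(out, dx)  # 2.0 [1. 0. 1. 0.]
```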
- q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) - else: - q = np.matmul(x[q_range], w_q) - if self._share_qk: - k = length_normalized(np.matmul(x, w_q)) - else: - k = np.matmul(x, w_k) - v = np.matmul(x, w_v) - - mask_fn = functools.partial( - mask_self_attention, - causal=self._causal, exclude_self=self._share_qk, masked=self._masked) - q_info = q_range - kv_info = np.arange(k.shape[-2], dtype=np.int32) - - if self._chunk_len is not None and q_len > self._chunk_len: - assert q_start == 0 - assert q_len % self._chunk_len == 0 - o, _ = attend( - q, k, v, - q_chunk_len=self._chunk_len, - kv_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) - else: - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) - out = np.matmul(o, w_o) - if q_len == 1: - out = out[:1] - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state + if have_single_input and compute_grad: + assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 + return (o_all, s_all, i_ct_all[0], w_ct_all) + else: + return (o_all, s_all, i_ct_all, w_ct_all) - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. - Args: - inputs: Layer inputs (subclasses may use different inputs) +class LSHSelfAttention(base.Layer): + """LSH self-attention (second implementation).""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + share_qk="unused", + causal=False, + masked=False, + chunk_len=128, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=1, + n_buckets=None, + mode="train", + predict_mem_len=2048, + predict_drop_len=256, + attention_dropout=0.0, + output_dropout=0.0, + max_length_for_buckets=None, + bias=False, + n_parallel_heads=1, + use_python_loop=False, + use_reference_code=False, + ): + """Construct an LSH self-attention layer.""" + super().__init__(n_in=(2 if masked else 1), n_out=1) + + self._n_heads = n_heads + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads + else: + self._n_parallel_heads = None - Returns: - A tuple (output, new_state). - """ - weights, state, rng = self.weights, self.state, self.rng - if not self._use_reference_code: - # By default, an efficient, batched implementation is used. - output, new_state, _, _ = self.forward_and_or_backward( - inputs, weights, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - # The reference implementation below provides a more readable overview of - # what this class does. It's not optimized, however, and should only be used - # when testing this class for correctness. 
- if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - if self._incremental: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - - output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] - new_state = [] - for example_idx in range(batch_size): - for head_idx in range(self._n_heads): - # pylint: disable=cell-var-from-loop - single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) - single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) - single_state = fastmath.nested_map( - lambda s: s[example_idx * self._n_heads + head_idx], state) - # pylint: enable=cell-var-from-loop + self._incremental = mode == "predict" if self._incremental: - single_out, single_new_state = self._incremental_forward_unbatched( - *single_inputs, q_start=q_start, q_len=seqlen, - weights=single_weights, rng=rng, - state=single_state, update_state=True) + assert causal, "Only causal attention supports fast inference" + assert chunk_len is not None or (predict_mem_len and predict_drop_len) + predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) + predict_drop_len = predict_drop_len or chunk_len + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + self._d_qk = d_qk + self._d_v = d_v + self._share_qk = True + self._causal = causal + self._masked = masked + self._chunk_len = chunk_len + self._n_chunks_before = n_chunks_before + self._n_chunks_after = n_chunks_after + self._bias = bias + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout else: - single_out, single_new_state = self.forward_unbatched( - *single_inputs, weights=single_weights, rng=rng, - state=single_state, update_state=True) - new_state.append(single_new_state) - output_accum[example_idx] = output_accum[example_idx] + single_out - - output = np.stack(output_accum, 0) - if new_state and fastmath.tree_leaves(new_state[0]): - new_state = fastmath.nested_map_multiarg( - lambda *s: np.stack(s, 0), *new_state) - else: - new_state = state - if self._incremental: - new_state = (new_mem_end, new_mem, new_state) - self.state = tuple(new_state) - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. 
- def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) - else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. - new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. - def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - # Use an efficient backward pass, unless we're running the reference code. - return not self._use_reference_code - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - assert not self._use_reference_code - del output, state, kwargs - _, _, inputs_grad, weights_grad = self.forward_and_or_backward( - inputs, weights, new_state, rng, output_grad=grad, - compute_output=False, update_state=False) - return inputs_grad, weights_grad - - def forward_and_or_backward( - self, inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. 
- - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. - - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). - - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(kitaev): profile ~4% speed drop compared to previous implementation - # in some conditions. Other conditions (e.g. the enwik8 model) appear - # to have the same overall training speed. - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. - # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - have_single_input = not isinstance(inputs, (tuple, list)) - if have_single_input: - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' - - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. 
- new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. - n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, new_values) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads == 1: - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - example_idx = idx // self._n_heads - head_idx = idx % self._n_heads - - i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_h = fastmath.nested_map(lambda w: w[head_idx], weights) - s_h = fastmath.nested_map(lambda s: s[idx], state) - - def forward_fn(i_h, w_h): - return forward_unbatched( - *i_h, weights=w_h, state=fastmath.stop_gradient(s_h)) + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + self._n_hashes = n_hashes + self._n_buckets = n_buckets + self._max_length_for_buckets = max_length_for_buckets + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. + # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. 
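The `_kernel_initializer` comment above refers to computing the Glorot-uniform limit against the fan-out of the *concatenated* projection rather than of one head's matrix. A sketch of just that limit calculation, showing it equals the limit Glorot would assign the fused `(d_model, d_head * n_heads)` matrix (numbers are illustrative):

```python
import numpy as np

def per_head_glorot_limit(d_model, d_head, n_heads):
    """Glorot-uniform limit computed against the concatenated projection.

    Drawing each (d_model, d_head) head from U(-lim, lim) with this limit
    is equivalent to Glorot-initializing one (d_model, d_head * n_heads)
    matrix and slicing it into heads.
    """
    return np.sqrt(6.0 / (d_model + d_head * n_heads))

d_model, d_head, n_heads = 512, 64, 8
fused_limit = np.sqrt(6.0 / (d_model + d_head * n_heads))
assert np.isclose(per_head_glorot_limit(d_model, d_head, n_heads), fused_limit)
print(per_head_glorot_limit(d_model, d_head, n_heads))  # ~0.0765
```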
+ lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def init_weights_and_state(self, input_signature): + if not isinstance(input_signature, (tuple, list)): + input_signature = (input_signature,) + input_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), input_signature + ) + batch_size = int(input_signature[0].shape[0]) + + weights = [] + weight_rngs = fastmath.random.split(self.rng, self._n_heads) + for i in range(self._n_heads): + weights.append( + self.create_weights_unbatched(input_signature_unbatched, weight_rngs[i]) + ) + state = [] + state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) + for i in range(self._n_heads * batch_size): + state.append( + self.create_state_unbatched(input_signature_unbatched, state_rngs[i]) + ) + + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) - ct_h = output_grad[example_idx] - assert o_h.shape == ct_h.shape - i_ct_h, w_ct_h = backward_fn(ct_h) + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + input_signature, + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = tuple(weights) + + def create_weights_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + d_model = input_signature.shape[-1] + rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) + w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) + if not self._share_qk: + w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) + w_v = self._kernel_initializer((d_model, self._d_v), rng_v) + w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) + + if self._bias: + b_q = np.zeros(self._d_qk) + b_v = np.zeros(self._d_v) + if self._share_qk: + return (w_q, w_v, w_o, b_q, b_v) + else: + b_k = np.zeros(self._d_qk) + return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) + + if self._share_qk: + return (w_q, w_v, w_o) else: - o_h, s_h = forward_fn(i_h, w_h) - - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_h) - if update_state: - s_all = tree_update(s_all, idx, s_h) - if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) - w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) - return (o_all, s_all, i_ct_all, w_ct_all) - elif n_parallel_heads < self._n_heads: - assert self._n_heads % n_parallel_heads == 0 - def run_inner(idx, loop_val): - """Runs one slice of attention (multiple heads, but one example).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * self._n_parallel_heads - example_idx = idx // self._n_heads - head_idx_lo = idx % self._n_heads - head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_mh = fastmath.nested_map(lambda w: w[head_range], weights) - s_mh = fastmath.nested_map(lambda s: s[state_range], state) - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - def forward_fn(i_mh, w_mh): 
- o_mh, new_s_mh = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0)( - i_mh, w_mh, s_mh) - o_mh = np.sum(o_mh, axis=0) - return o_mh, new_s_mh - - if compute_grad: - o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) - ct_mh = output_grad[example_idx] - assert o_mh.shape == ct_mh.shape - i_ct_mh, w_ct_mh = backward_fn(ct_mh) + return (w_q, w_k, w_v, w_o) + + def create_state_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + # The `rng` argument passed to forward_unbatched is shared across all + # examples and heads. This facilitates using broadcasted dropout, which + # saves memory and hasn't been shown to hurt model quality. Even though the + # same sharing is likely to be safe when selecting random hash functions + # for LSH, we haven't run experiments to demonstrate this. To be on the safe + # side we include a per-head RNG in the state for the purpose of doing LSH. + if not self._incremental: + length = self._max_length_for_buckets or input_signature.shape[0] + buckets = np.zeros(self._n_hashes * length, dtype=np.int32) + return (buckets, rng) else: - o_mh, s_mh = forward_fn(i_mh, w_mh) + buckets = np.zeros(self._n_hashes * self._predict_mem_len, dtype=np.int32) + buckets_idx = np.zeros((), dtype=np.int32) + return (buckets, buckets_idx, rng) + + def hash_vectors(self, vecs, rng, mask=None): + n_buckets_list = self._n_buckets + + # Determine the number of buckets needed from input length if not set. + if n_buckets_list is None: + length = vecs.shape[0] + n_buckets = 2 * max(1, length // self._chunk_len) + if n_buckets <= 128: + n_buckets_list = n_buckets + else: # Factorize n_buckets. + n_buckets_div = 2 ** math.ceil(math.log2(math.sqrt(n_buckets))) + # Both factors must be even. + n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) + n_buckets_list = [n_buckets_div, n_buckets_rest] + + # Hash vectors. + buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) + + if mask is not None: + n_buckets += 1 # Create an extra bucket for padding tokens only + buckets = np.where(mask[None, :], buckets, n_buckets - 1) + + # buckets is now (n_hashes, seqlen). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. 
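When `n_buckets` is unspecified, `hash_vectors` above derives it from the sequence length: roughly two buckets per chunk, factorized into two even factors once the count exceeds 128 (multi-stage hashing). The heuristic in isolation, as a sketch (`max_flat` is a name invented here):

```python
import math

def derive_n_buckets(length, chunk_len, max_flat=128):
    """Choose LSH bucket counts from the sequence length.

    Returns an int (single hashing stage) or an [a, b] factorization with
    both factors even, targeting roughly two buckets per chunk overall.
    """
    n_buckets = 2 * max(1, length // chunk_len)
    if n_buckets <= max_flat:
        return n_buckets
    # Factorize: first factor is a power of two near sqrt(n_buckets),
    # second absorbs the rest while staying even.
    n_buckets_div = 2 ** math.ceil(math.log2(math.sqrt(n_buckets)))
    n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div))
    return [n_buckets_div, n_buckets_rest]

print(derive_n_buckets(2048, 128))   # 32
print(derive_n_buckets(65536, 128))  # [32, 32]
```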
+ offsets = np.arange(self._n_hashes, dtype=np.int32) + offsets = np.reshape(offsets * n_buckets, (-1, 1)) + buckets = np.reshape(buckets + offsets, (-1,)) + return buckets + + def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state): + attend_rng, output_rng = fastmath.random.split(rng) + w_q, w_v, w_o = weights - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_mh) - if update_state: - s_all = tree_update(s_all, state_range, s_mh) - if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) - w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) - return (o_all, s_all, i_ct_all, w_ct_all) - else: - assert n_parallel_heads % self._n_heads == 0 - def forward_single_example(i_x, w_all, s_x): - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - o_x, s_x = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0))( - i_x, w_all, s_x) - o_x = np.sum(o_x, axis=0) - return o_x, s_x - def run_inner(idx, loop_val): - """Runs one slice of attention (all heads for one or more examples).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * n_parallel_heads - example_idx_lo = idx // self._n_heads - example_range = example_idx_lo + np.arange( - n_parallel_heads // self._n_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) - s_mex = fastmath.nested_map( - lambda s: np.reshape(s[state_range], # pylint: disable=g-long-lambda - (-1, self._n_heads) + s.shape[1:]), - state) - def forward_fn(i_mex, w_all): - o_mex, new_s_mex = fastmath.vmap( - forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0))( - i_mex, w_all, s_mex) - new_s_mex = fastmath.nested_map( - lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), - new_s_mex) - return o_mex.astype(i_mex[0].dtype), new_s_mex + q = np.matmul(x, w_q) + v = np.matmul(x, w_v) - if compute_grad: - o_mex, backward_fn, s_mex = vjp(forward_fn, i_mex, weights, - has_aux=True) - ct_mex = output_grad[example_range] - assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) - assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) - i_ct_mex, w_ct_mex = backward_fn(ct_mex) + if update_state: + _, old_hash_rng = state + hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) + buckets = self.hash_vectors(q, hash_subrng, mask) + s_buckets = buckets + if self._max_length_for_buckets: + length = self._n_hashes * self._max_length_for_buckets + if buckets.shape[0] < length: + s_buckets = np.concatenate( + [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], + axis=0, + ) + state = (s_buckets, hash_rng) else: - o_mex, s_mex = forward_fn(i_mex, weights) + buckets, _ = state + if self._max_length_for_buckets: + buckets = buckets[: self._n_hashes * x.shape[0]] - if compute_output: - o_all = fastmath.index_add(o_all, jax.numpy.index_exp[example_range], - o_mex) - if update_state: - s_all = tree_update(s_all, state_range, s_mex) - if compute_grad: - i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) - w_ct_all = fastmath.nested_map_multiarg( - lambda old_all, delta_all: old_all + delta_all, - w_ct_all, w_ct_mex) - return (o_all, s_all, i_ct_all, w_ct_all) - - o_all = s_all = i_ct_all = w_ct_all = None - if compute_output: - o_all = np.zeros( - (batch_size, seqlen, d_model), dtype=inputs[0].dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, 
i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - w_ct_all = fastmath.nested_map(np.zeros_like, weights) - - loop_val = (o_all, s_all, i_ct_all, w_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) + seqlen = x.shape[0] + assert int(buckets.shape[0]) == self._n_hashes * seqlen - (o_all, s_all, i_ct_all, w_ct_all) = loop_val + ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) + buckets_and_t = seqlen * buckets + (ticker % seqlen) + buckets_and_t = fastmath.stop_gradient(buckets_and_t) - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + # Hash-based sort ("s" at the start of variable names means "sorted") + sbuckets_and_t, sticker = fastmath.sort_key_val( + buckets_and_t, ticker, dimension=-1 + ) + _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) + sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) + sticker = fastmath.stop_gradient(sticker) + undo_sort = fastmath.stop_gradient(undo_sort) + + st = sticker % seqlen + sq = np.take(q, st, axis=0) + sv = np.take(v, st, axis=0) + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=True, + masked=self._masked, + ) + q_info = st + + assert (mask is not None) == self._masked + kv_info = None + if self._masked: + # mask is a boolean array (True means "is valid token") + smask = np.take(mask, st, axis=0) + ones_like_mask = np.ones_like(smask, dtype=np.int32) + kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) + + so, slogits = attend( + sq, + k=None, + v=sv, + q_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would + # also work, but these helpers include performance optimizations for TPU. 
+ o = permute_via_gather(so, undo_sort, sticker, axis=0) + logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - if have_single_input and compute_grad: - assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 - return (o_all, s_all, i_ct_all[0], w_ct_all) - else: - return (o_all, s_all, i_ct_all, w_ct_all) + if self._n_hashes > 1: + o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) + logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) + probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) + o = np.sum(o * probs, axis=0) + assert o.shape == (seqlen, w_v.shape[-1]) + out = np.matmul(o, w_o) + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state -class LSHSelfAttention(base.Layer): - """LSH self-attention (second implementation).""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, share_qk='unused', - causal=False, - masked=False, - chunk_len=128, n_chunks_before=1, n_chunks_after=0, - n_hashes=1, - n_buckets=None, - mode='train', - predict_mem_len=2048, predict_drop_len=256, - attention_dropout=0.0, - output_dropout=0.0, - max_length_for_buckets=None, - bias=False, - n_parallel_heads=1, - use_python_loop=False, - use_reference_code=False, - ): - """Construct an LSH self-attention layer.""" - super().__init__(n_in=(2 if masked else 1), n_out=1) - - self._n_heads = n_heads - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads - else: - self._n_parallel_heads = None - - self._incremental = (mode == 'predict') - if self._incremental: - assert causal, 'Only causal attention supports fast inference' - assert chunk_len is not None or (predict_mem_len and predict_drop_len) - predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) - predict_drop_len = predict_drop_len or chunk_len - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - self._d_qk = d_qk - self._d_v = d_v - self._share_qk = True - self._causal = causal - self._masked = masked - self._chunk_len = chunk_len - self._n_chunks_before = n_chunks_before - self._n_chunks_after = n_chunks_after - self._bias = bias - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - self._n_hashes = n_hashes - self._n_buckets = n_buckets - self._max_length_for_buckets = max_length_for_buckets - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
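With `n_hashes > 1`, the added code a few lines up averages the hashing rounds by weighting each round's output with the softmax of its attention log-normalizer (`slogits`), so rounds whose buckets contained closer matches dominate. A NumPy-only sketch of that combination step (helper names invented here):

```python
import numpy as np

def logsumexp(a, axis):
    m = np.max(a, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))

def combine_hash_rounds(o, logits):
    """Weight per-round outputs by their attention log-normalizers.

    o:      (n_hashes, seqlen, d_v) attention output of each round
    logits: (n_hashes, seqlen, 1)   log softmax-denominator per round
    """
    probs = np.exp(logits - logsumexp(logits, axis=0))  # sums to 1 over rounds
    return np.sum(o * probs, axis=0)                    # (seqlen, d_v)

o = np.random.randn(4, 16, 8)       # 4 hashing rounds
logits = np.random.randn(4, 16, 1)
print(combine_hash_rounds(o, logits).shape)  # (16, 8)
```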
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def init_weights_and_state(self, input_signature): - if not isinstance(input_signature, (tuple, list)): - input_signature = (input_signature,) - input_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - input_signature) - batch_size = int(input_signature[0].shape[0]) - - weights = [] - weight_rngs = fastmath.random.split(self.rng, self._n_heads) - for i in range(self._n_heads): - weights.append(self.create_weights_unbatched(input_signature_unbatched, - weight_rngs[i])) - state = [] - state_rngs = fastmath.random.split(self.rng, self._n_heads * batch_size) - for i in range(self._n_heads * batch_size): - state.append(self.create_state_unbatched(input_signature_unbatched, - state_rngs[i])) - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - weights = fastmath.nested_map_multiarg(stack_along_axis_0, *weights) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - input_signature) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = tuple(weights) - - def create_weights_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - d_model = input_signature.shape[-1] - rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) - w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) - if not self._share_qk: - w_k = self._kernel_initializer((d_model, self._d_qk), rng_k) - w_v = self._kernel_initializer((d_model, self._d_v), rng_v) - w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) - - if self._bias: - b_q = np.zeros(self._d_qk) - b_v = np.zeros(self._d_v) - if self._share_qk: - return (w_q, w_v, w_o, b_q, b_v) - else: - b_k = np.zeros(self._d_qk) - return (w_q, w_k, w_v, w_o, b_q, b_k, b_v) - - if self._share_qk: - return (w_q, w_v, w_o) - else: - return (w_q, w_k, w_v, w_o) - - def create_state_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - # The `rng` argument passed to forward_unbatched is shared across all - # examples and heads. This facilitates using broadcasted dropout, which - # saves memory and hasn't been shown to hurt model quality. Even though the - # same sharing is likely to be safe when selecting random hash functions - # for LSH, we haven't run experiments to demonstrate this. To be on the safe - # side we include a per-head RNG in the state for the purpose of doing LSH. - if not self._incremental: - length = self._max_length_for_buckets or input_signature.shape[0] - buckets = np.zeros(self._n_hashes * length, dtype=np.int32) - return (buckets, rng) - else: - buckets = np.zeros( - self._n_hashes * self._predict_mem_len, dtype=np.int32) - buckets_idx = np.zeros((), dtype=np.int32) - return (buckets, buckets_idx, rng) - - def hash_vectors(self, vecs, rng, mask=None): - n_buckets_list = self._n_buckets - - # Determine the number of buckets needed from input length if not set. - if n_buckets_list is None: - length = vecs.shape[0] - n_buckets = 2 * max(1, length // self._chunk_len) - if n_buckets <= 128: - n_buckets_list = n_buckets - else: # Factorize n_buckets. 
- n_buckets_div = 2**math.ceil(math.log2(math.sqrt(n_buckets))) - # Both factors must be even. - n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) - n_buckets_list = [n_buckets_div, n_buckets_rest] - - # Hash vectors. - buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) - - if mask is not None: - n_buckets += 1 # Create an extra bucket for padding tokens only - buckets = np.where(mask[None, :], buckets, n_buckets - 1) - - # buckets is now (n_hashes, seqlen). Next we add offsets so that - # bucket numbers from different hashing rounds don't overlap. - offsets = np.arange(self._n_hashes, dtype=np.int32) - offsets = np.reshape(offsets * n_buckets, (-1, 1)) - buckets = np.reshape(buckets + offsets, (-1,)) - return buckets - - def forward_unbatched(self, x, mask=None, *, weights, state, rng, - update_state): - attend_rng, output_rng = fastmath.random.split(rng) - w_q, w_v, w_o = weights - - q = np.matmul(x, w_q) - v = np.matmul(x, w_v) - - if update_state: - _, old_hash_rng = state - hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) - buckets = self.hash_vectors(q, hash_subrng, mask) - s_buckets = buckets - if self._max_length_for_buckets: - length = self._n_hashes * self._max_length_for_buckets - if buckets.shape[0] < length: - s_buckets = np.concatenate( - [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], - axis=0) - state = (s_buckets, hash_rng) - else: - buckets, _ = state - if self._max_length_for_buckets: - buckets = buckets[:self._n_hashes * x.shape[0]] - - seqlen = x.shape[0] - assert int(buckets.shape[0]) == self._n_hashes * seqlen - - ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) - buckets_and_t = seqlen * buckets + (ticker % seqlen) - buckets_and_t = fastmath.stop_gradient(buckets_and_t) - - # Hash-based sort ("s" at the start of variable names means "sorted") - sbuckets_and_t, sticker = fastmath.sort_key_val( - buckets_and_t, ticker, dimension=-1) - _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) - sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) - sticker = fastmath.stop_gradient(sticker) - undo_sort = fastmath.stop_gradient(undo_sort) - - st = (sticker % seqlen) - sq = np.take(q, st, axis=0) - sv = np.take(v, st, axis=0) - - mask_fn = functools.partial(mask_self_attention, causal=self._causal, - exclude_self=True, masked=self._masked) - q_info = st - - assert (mask is not None) == self._masked - kv_info = None - if self._masked: - # mask is a boolean array (True means "is valid token") - smask = np.take(mask, st, axis=0) - ones_like_mask = np.ones_like(smask, dtype=np.int32) - kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) - - so, slogits = attend( - sq, k=None, v=sv, - q_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + def _incremental_forward_unbatched( + self, x, *, q_start, q_len, weights, state, rng, update_state + ): + assert ( + update_state + ), "This setting not supported (e.g. no backprop for fast inference)" + if q_len > 1: + if isinstance(q_start, int): + assert q_start == 0, "Chunks larger than 1 only work at start for now." 
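The chunked branch that follows pads the input up to a multiple of `chunk_len` before hashing and slices the result back to `q_len` afterwards. The padding step as a standalone helper (a sketch mirroring the diff, not a library API):

```python
import numpy as np

def pad_to_chunk_multiple(x, chunk_len):
    """Zero-pad the time axis of a (seqlen, d) array to a chunk multiple.

    Chunked attention requires seqlen % chunk_len == 0; the caller slices
    the output back to the original length afterwards.
    """
    if x.shape[0] % chunk_len == 0:
        return x
    pad_amount = chunk_len - (x.shape[0] % chunk_len)
    return np.pad(x, ((0, pad_amount), (0, 0)), mode="constant")

x = np.ones((70, 8))
print(pad_to_chunk_multiple(x, 32).shape)  # (96, 8)
```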
+ if x.shape[0] % self._chunk_len == 0: + x_padded = x + else: + pad_amount = self._chunk_len - (x.shape[0] % self._chunk_len) + x_padded = np.pad(x, ((0, pad_amount), (0, 0)), mode="constant") + buckets, buckets_idx, hash_rng = state + q = np.matmul(x_padded, weights[0]) + buckets_update = self.hash_vectors(q, hash_rng) + + out, _ = self.forward_unbatched( + x_padded, + weights=weights, + state=(buckets_update, hash_rng), + rng=rng, + update_state=False, + ) + + out = out[:q_len] + buckets = np.reshape(buckets, (self._n_hashes, -1)) + buckets_update = np.reshape(buckets_update, (self._n_hashes, -1))[:, :q_len] + if q_len > self._predict_mem_len: + buckets_update = buckets_update[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + buckets = fastmath.dynamic_update_slice_in_dim( + buckets, buckets_update, q_start, axis=1 + ) + buckets = np.reshape(buckets, (-1,)) + + return out, (buckets, buckets_idx + q_len, hash_rng) + + # This codepath is for handling one token at a time. + assert q_len == 1 + buckets, buckets_idx, hash_rng = state + + def roll_buckets(buckets): + buckets = np.reshape(buckets, (self._n_hashes, -1)) + new_buckets = np.concatenate( + [ + buckets, + np.zeros( + (self._n_hashes, self._predict_drop_len), dtype=buckets.dtype + ), + ], + axis=1, + ) + new_buckets = fastmath.dynamic_slice_in_dim( + new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1 + ) + new_buckets = np.reshape(new_buckets, (-1,)) + return new_buckets + + buckets = fastmath.cond( + pred=buckets_idx > q_start, + true_operand=buckets, + true_fun=roll_buckets, + false_operand=buckets, + false_fun=lambda x: x, ) - # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would - # also work, but these helpers include performance optimizations for TPU. - o = permute_via_gather(so, undo_sort, sticker, axis=0) - logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - - if self._n_hashes > 1: - o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) - logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) - probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) - o = np.sum(o * probs, axis=0) - - assert o.shape == (seqlen, w_v.shape[-1]) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state - - def _incremental_forward_unbatched(self, x, *, - q_start, q_len, - weights, state, rng, update_state): - assert update_state, ( - 'This setting not supported (e.g. no backprop for fast inference)') - if q_len > 1: - if isinstance(q_start, int): - assert q_start == 0, 'Chunks larger than 1 only work at start for now.' 
- if x.shape[0] % self._chunk_len == 0: - x_padded = x - else: - pad_amount = self._chunk_len - (x.shape[0] % self._chunk_len) - x_padded = np.pad(x, ((0, pad_amount), (0, 0)), mode='constant') - buckets, buckets_idx, hash_rng = state - q = np.matmul(x_padded, weights[0]) - buckets_update = self.hash_vectors(q, hash_rng) - - out, _ = self.forward_unbatched( - x_padded, weights=weights, state=(buckets_update, hash_rng), - rng=rng, update_state=False) - - out = out[:q_len] - buckets = np.reshape(buckets, (self._n_hashes, -1)) - buckets_update = np.reshape( - buckets_update, (self._n_hashes, -1))[:, :q_len] - if q_len > self._predict_mem_len: - buckets_update = buckets_update[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - buckets = fastmath.dynamic_update_slice_in_dim( - buckets, buckets_update, q_start, axis=1) - buckets = np.reshape(buckets, (-1,)) - - return out, (buckets, buckets_idx + q_len, hash_rng) - - # This codepath is for handling one token at a time. - assert q_len == 1 - buckets, buckets_idx, hash_rng = state - - def roll_buckets(buckets): - buckets = np.reshape(buckets, (self._n_hashes, -1)) - new_buckets = np.concatenate( - [buckets, np.zeros((self._n_hashes, self._predict_drop_len), - dtype=buckets.dtype) - ], axis=1) - new_buckets = fastmath.dynamic_slice_in_dim( - new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1) - new_buckets = np.reshape(new_buckets, (-1,)) - return new_buckets - - buckets = fastmath.cond( - pred=buckets_idx > q_start, - true_operand=buckets, - true_fun=roll_buckets, - false_operand=buckets, - false_fun=lambda x: x, - ) + attend_rng, output_rng = fastmath.random.split(rng) + w_q, w_v, w_o = weights - attend_rng, output_rng = fastmath.random.split(rng) - w_q, w_v, w_o = weights - - q_range = q_start + np.arange(q_len, dtype=np.int32) - # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not - # floating-point equivalent, at least in non-jitted code. We correct the - # discrepancy by duplicating the slice. Floating-point noise may not be - # an issue when using models, but it makes it harder to write tests that - # compare fast and slow inference code for equivalence. - q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) - - q_buckets = self.hash_vectors(q, hash_rng) - q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - - unflattened_buckets = fastmath.dynamic_update_slice_in_dim( - np.reshape(buckets, (self._n_hashes, -1)), - q_buckets, q_start, axis=1) - buckets = np.reshape(unflattened_buckets, (-1,)) - is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) - - assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 - seqlen = x.shape[0] - arange_seqlen = np.arange(seqlen, dtype=np.int32) - kv_priorities = np.where( - arange_seqlen > (q_start + q_len), - -(seqlen + arange_seqlen), arange_seqlen) - kv_priorities = kv_priorities + seqlen * is_valid_target.astype(np.int32) - _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) - kv_indices = kv_indices[ - -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before):] - assert self._n_chunks_after == 0 - - x_attend_to = x[kv_indices] - k = length_normalized(np.matmul(x_attend_to, w_q)) - v = np.matmul(x_attend_to, w_v) - - mask_fn = functools.partial( - mask_self_attention, causal=True, masked=True, exclude_self=True) - q_info = q_start + np.arange(q_len, dtype=np.int32) - kv_info = kv_indices.astype(np.int32) - q_info = q_info.astype(np.int32) - # TODO(kitaev): is it better to mask out attention across buckets? 
- # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, - ) + q_range = q_start + np.arange(q_len, dtype=np.int32) + # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not + # floating-point equivalent, at least in non-jitted code. We correct the + # discrepancy by duplicating the slice. Floating-point noise may not be + # an issue when using models, but it makes it harder to write tests that + # compare fast and slow inference code for equivalence. + q = np.matmul(np.concatenate([x[q_range]] * 2, 0), w_q) - out = np.matmul(o, w_o) - if q_len == 1: - out = out[:1] - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) - return out, (buckets, buckets_idx, hash_rng) + q_buckets = self.hash_vectors(q, hash_rng) + q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. + unflattened_buckets = fastmath.dynamic_update_slice_in_dim( + np.reshape(buckets, (self._n_hashes, -1)), q_buckets, q_start, axis=1 + ) + buckets = np.reshape(unflattened_buckets, (-1,)) + is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) + + assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 + seqlen = x.shape[0] + arange_seqlen = np.arange(seqlen, dtype=np.int32) + kv_priorities = np.where( + arange_seqlen > (q_start + q_len), -(seqlen + arange_seqlen), arange_seqlen + ) + kv_priorities = kv_priorities + seqlen * is_valid_target.astype(np.int32) + _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) + kv_indices = kv_indices[ + -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before) : + ] + assert self._n_chunks_after == 0 + + x_attend_to = x[kv_indices] + k = length_normalized(np.matmul(x_attend_to, w_q)) + v = np.matmul(x_attend_to, w_v) + + mask_fn = functools.partial( + mask_self_attention, causal=True, masked=True, exclude_self=True + ) + q_info = q_start + np.arange(q_len, dtype=np.int32) + kv_info = kv_indices.astype(np.int32) + q_info = q_info.astype(np.int32) + # TODO(kitaev): is it better to mask out attention across buckets? + # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - Args: - inputs: Layer inputs (subclasses may use different inputs) + out = np.matmul(o, w_o) + if q_len == 1: + out = out[:1] + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) + return out, (buckets, buckets_idx, hash_rng) + + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. + + Args: + inputs: Layer inputs (subclasses may use different inputs) + + Returns: + A tuple (output, new_state). + """ + weights, state, rng = self.weights, self.state, self.rng + if not self._use_reference_code: + # By default, an efficient, batched implementation is used. + output, new_state, _, _ = self.forward_and_or_backward( + inputs, weights, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + # The reference implementation below provides a more readable overview of + # what this class does. 
It's not optimized, however, and should only be used + # when testing this class for correctness. + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] - Returns: - A tuple (output, new_state). - """ - weights, state, rng = self.weights, self.state, self.rng - if not self._use_reference_code: - # By default, an efficient, batched implementation is used. - output, new_state, _, _ = self.forward_and_or_backward( - inputs, weights, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - # The reference implementation below provides a more readable overview of - # what this class does. It's not optimized, however, and should only be used - # when testing this class for correctness. - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - if self._incremental: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - - output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] - new_state = [] - for example_idx in range(batch_size): - for head_idx in range(self._n_heads): - # pylint: disable=cell-var-from-loop - single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) - single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) - single_state = fastmath.nested_map( - lambda s: s[example_idx * self._n_heads + head_idx], state) - # pylint: enable=cell-var-from-loop if self._incremental: - single_out, single_new_state = self._incremental_forward_unbatched( - *single_inputs, q_start=q_start, q_len=seqlen, - weights=single_weights, rng=rng, - state=single_state, update_state=True) - else: - single_out, single_new_state = self.forward_unbatched( - *single_inputs, weights=single_weights, rng=rng, - state=single_state, update_state=True) - new_state.append(single_new_state) - output_accum[example_idx] = output_accum[example_idx] + single_out - - output = np.stack(output_accum, 0) - if new_state and fastmath.tree_leaves(new_state[0]): - new_state = fastmath.nested_map_multiarg( - lambda *s: np.stack(s, 0), *new_state) - else: - new_state = state - if self._incremental: - new_state = (new_mem_end, new_mem, new_state) - self.state = tuple(new_state) - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. 
- def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + + output_accum = [np.zeros((seqlen, d_model)) for _ in range(batch_size)] + new_state = [] + for example_idx in range(batch_size): + for head_idx in range(self._n_heads): + # pylint: disable=cell-var-from-loop + single_inputs = fastmath.nested_map(lambda x: x[example_idx], inputs) + single_weights = fastmath.nested_map(lambda w: w[head_idx], weights) + single_state = fastmath.nested_map( + lambda s: s[example_idx * self._n_heads + head_idx], state + ) + # pylint: enable=cell-var-from-loop + if self._incremental: + single_out, single_new_state = self._incremental_forward_unbatched( + *single_inputs, + q_start=q_start, + q_len=seqlen, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + else: + single_out, single_new_state = self.forward_unbatched( + *single_inputs, + weights=single_weights, + rng=rng, + state=single_state, + update_state=True, + ) + new_state.append(single_new_state) + output_accum[example_idx] = output_accum[example_idx] + single_out + + output = np.stack(output_accum, 0) + if new_state and fastmath.tree_leaves(new_state[0]): + new_state = fastmath.nested_map_multiarg( + lambda *s: np.stack(s, 0), *new_state + ) else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. - new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type + new_state = state + if self._incremental: + new_state = (new_mem_end, new_mem, new_state) + self.state = tuple(new_state) + return output + + def _use_predict_mem(self, inputs, state): + """Update input cache for fast inference.""" + mem_end, mem, state = state + seqlen = inputs[0].shape[-2] + + if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: + # This branch is called when only a small number of tokens are appended to + # the sequence, e.g. when generating one token at a time. A fixed number + # of tokens (self._predict_drop_tokens) will be dropped from memory if + # needed, and then new values will be inserted into the memory. 
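(Aside: read on its own, the small-append branch described above is a shift-and-write ring buffer. Below is a minimal NumPy sketch of that mechanism, placed before the roll_mem helper it mirrors; append_tokens and onp are invented names for illustration, and the real layer routes the shift through fastmath.cond rather than a Python if so it stays jittable.)

import numpy as onp  # plain-NumPy stand-in for trax.fastmath numpy

def append_tokens(mem, mem_end, new_vals, drop_len):
    # Toy model of the small-append branch: if writing seqlen new tokens would
    # overflow the cache, first shift out the oldest drop_len positions, then
    # write the new values starting at mem_end.
    seqlen = new_vals.shape[1]
    mem = mem.copy()  # functional update, as in the jittable original
    if mem_end + seqlen > mem.shape[1]:
        mem = onp.concatenate(
            [mem[:, drop_len:], onp.zeros_like(mem[:, :drop_len])], axis=1)
        mem_end -= drop_len
    mem[:, mem_end:mem_end + seqlen] = new_vals
    return mem, mem_end + seqlen

mem = onp.zeros((1, 6, 2), dtype=onp.float32)
mem, end = append_tokens(mem, 5, onp.ones((1, 2, 2)), drop_len=2)
assert end == 5  # the buffer rolled back to 3, then two tokens were written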
+ def roll_mem(buf): + return np.concatenate( + [ + buf[:, self._predict_drop_len :], + np.zeros_like(buf[:, : self._predict_drop_len]), + ], + axis=1, + ) + + do_roll_mem = mem_end + seqlen > self._predict_mem_len + mem = fastmath.cond( + pred=do_roll_mem, + true_operand=mem, + true_fun=lambda x: fastmath.nested_map(roll_mem, x), + false_operand=mem, + false_fun=lambda x: x, + ) + mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) + + def update_mem(mem_element, new_vals): + assert new_vals.shape[1] == seqlen + if seqlen == 1: + return fastmath.index_update( + mem_element, + jax.numpy.index_exp[:, mem_end], + new_vals[:, 0, ...], + ) + else: + return fastmath.dynamic_update_slice_in_dim( + mem_element, new_vals, mem_end, axis=1 + ) + + inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) + return inputs, state, mem_end, inputs, mem_end + seqlen else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. - def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - # Use an efficient backward pass, unless we're running the reference code. - return not self._use_reference_code - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - assert not self._use_reference_code - del output, state, kwargs - _, _, inputs_grad, weights_grad = self.forward_and_or_backward( - inputs, weights, new_state, rng, output_grad=grad, - compute_output=False, update_state=False) - return inputs_grad, weights_grad - - def forward_and_or_backward( - self, inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. - - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. - - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). + assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len + # This branch handles the case where a large number of tokens are being + # introduced all at once. 
The code here assumes that we are at the start + # of the sequence, which matches the typical use case of decoding from a + # language model given a long prefix. Note that if we're not at the start + # of the sequence, the code here won't work. + new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + # Use an efficient backward pass, unless we're running the reference code. + return not self._use_reference_code + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + assert not self._use_reference_code + del output, state, kwargs + _, _, inputs_grad, weights_grad = self.forward_and_or_backward( + inputs, + weights, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) + return inputs_grad, weights_grad - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(kitaev): profile ~4% speed drop compared to previous implementation - # in some conditions. Other conditions (e.g. the enwik8 model) appear - # to have the same overall training speed. - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". 
When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. - # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - have_single_input = not isinstance(inputs, (tuple, list)) - if have_single_input: - inputs = (inputs,) - batch_size = int(inputs[0].shape[0]) - seqlen = inputs[0].shape[-2] - d_model = inputs[0].shape[-1] - - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' - - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. - n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, new_values) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads == 1: - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - example_idx = idx // self._n_heads - head_idx = idx % self._n_heads - - i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_h = fastmath.nested_map(lambda w: w[head_idx], weights) - s_h = fastmath.nested_map(lambda s: s[idx], state) - - def forward_fn(i_h, w_h): - return forward_unbatched( - *i_h, weights=w_h, state=fastmath.stop_gradient(s_h)) + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(kitaev): profile ~4% speed drop compared to previous implementation + # in some conditions. Other conditions (e.g. the enwik8 model) appear + # to have the same overall training speed. + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. 
+ # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + have_single_input = not isinstance(inputs, (tuple, list)) + if have_single_input: + inputs = (inputs,) + batch_size = int(inputs[0].shape[0]) + seqlen = inputs[0].shape[-2] + d_model = inputs[0].shape[-1] + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" + + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) + else: + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. + n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) - ct_h = output_grad[example_idx] - assert o_h.shape == ct_h.shape - i_ct_h, w_ct_h = backward_fn(ct_h) - else: - o_h, s_h = forward_fn(i_h, w_h) + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads == 1: + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + example_idx = idx // self._n_heads + head_idx = idx % self._n_heads + + i_h = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_h = fastmath.nested_map(lambda w: w[head_idx], weights) + s_h = fastmath.nested_map(lambda s: s[idx], state) + + def forward_fn(i_h, w_h): + return forward_unbatched( + *i_h, weights=w_h, state=fastmath.stop_gradient(s_h) + ) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, w_h, has_aux=True) + ct_h = output_grad[example_idx] + assert o_h.shape == ct_h.shape + i_ct_h, w_ct_h = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h, w_h) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) + w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) + return (o_all, s_all, i_ct_all, w_ct_all) + + elif n_parallel_heads < self._n_heads: + assert self._n_heads % n_parallel_heads == 0 + + def run_inner(idx, loop_val): + """Runs one slice of attention (multiple heads, but one example).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * self._n_parallel_heads + example_idx = idx // self._n_heads + head_idx_lo = idx % self._n_heads + head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) + w_mh = fastmath.nested_map(lambda w: w[head_range], weights) + s_mh = fastmath.nested_map(lambda s: s[state_range], state) + + def forward_unbatched_h(i_h, w_h, s_h): + return forward_unbatched(*i_h, weights=w_h, state=s_h) + + def forward_fn(i_mh, w_mh): + o_mh, new_s_mh = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0 + )(i_mh, w_mh, s_mh) + o_mh = np.sum(o_mh, axis=0) + return o_mh, new_s_mh + + if compute_grad: + o_mh, backward_fn, s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) + ct_mh = output_grad[example_idx] + assert o_mh.shape == ct_mh.shape + i_ct_mh, w_ct_mh = backward_fn(ct_mh) + else: + o_mh, s_mh = forward_fn(i_mh, w_mh) + + if compute_output: + o_all = fastmath.index_add(o_all, example_idx, o_mh) + if update_state: + s_all = tree_update(s_all, state_range, s_mh) + if compute_grad: + i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) + w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) + return (o_all, s_all, i_ct_all, w_ct_all) + else: + assert n_parallel_heads % self._n_heads == 0 + + def forward_single_example(i_x, w_all, s_x): + def forward_unbatched_h(i_h, w_h, s_h): + 
return forward_unbatched(*i_h, weights=w_h, state=s_h) + + o_x, s_x = fastmath.vmap( + forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0) + )(i_x, w_all, s_x) + o_x = np.sum(o_x, axis=0) + return o_x, s_x + + def run_inner(idx, loop_val): + """Runs one slice of attention (all heads for one or more examples).""" + o_all, s_all, i_ct_all, w_ct_all = loop_val + idx = idx * n_parallel_heads + example_idx_lo = idx // self._n_heads + example_range = example_idx_lo + np.arange( + n_parallel_heads // self._n_heads, dtype=np.int32 + ) + state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) + + i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) + s_mex = fastmath.nested_map( + lambda s: np.reshape( + s[state_range], # pylint: disable=g-long-lambda + (-1, self._n_heads) + s.shape[1:], + ), + state, + ) + + def forward_fn(i_mex, w_all): + o_mex, new_s_mex = fastmath.vmap( + forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0) + )(i_mex, w_all, s_mex) + new_s_mex = fastmath.nested_map( + lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), + new_s_mex, + ) + return o_mex.astype(i_mex[0].dtype), new_s_mex + + if compute_grad: + o_mex, backward_fn, s_mex = vjp( + forward_fn, i_mex, weights, has_aux=True + ) + ct_mex = output_grad[example_range] + assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) + assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) + i_ct_mex, w_ct_mex = backward_fn(ct_mex) + else: + o_mex, s_mex = forward_fn(i_mex, weights) + + if compute_output: + o_all = fastmath.index_add( + o_all, jax.numpy.index_exp[example_range], o_mex + ) + if update_state: + s_all = tree_update(s_all, state_range, s_mex) + if compute_grad: + i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) + w_ct_all = fastmath.nested_map_multiarg( + lambda old_all, delta_all: old_all + delta_all, + w_ct_all, + w_ct_mex, + ) + return (o_all, s_all, i_ct_all, w_ct_all) + + o_all = s_all = i_ct_all = w_ct_all = None if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_h) + o_all = np.zeros((batch_size, seqlen, d_model), dtype=inputs[0].dtype) if update_state: - s_all = tree_update(s_all, idx, s_h) + s_all = state if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_h) - w_ct_all = tree_add(w_ct_all, head_idx, w_ct_h) - return (o_all, s_all, i_ct_all, w_ct_all) - elif n_parallel_heads < self._n_heads: - assert self._n_heads % n_parallel_heads == 0 - def run_inner(idx, loop_val): - """Runs one slice of attention (multiple heads, but one example).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * self._n_parallel_heads - example_idx = idx // self._n_heads - head_idx_lo = idx % self._n_heads - head_range = head_idx_lo + np.arange(n_parallel_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mh = fastmath.nested_map(lambda x: x[example_idx], inputs) - w_mh = fastmath.nested_map(lambda w: w[head_range], weights) - s_mh = fastmath.nested_map(lambda s: s[state_range], state) - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - def forward_fn(i_mh, w_mh): - o_mh, new_s_mh = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=0)( - i_mh, w_mh, s_mh) - o_mh = np.sum(o_mh, axis=0) - return o_mh, new_s_mh + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) + w_ct_all = fastmath.nested_map(np.zeros_like, weights) - if compute_grad: - o_mh, backward_fn, 
s_mh = vjp(forward_fn, i_mh, w_mh, has_aux=True) - ct_mh = output_grad[example_idx] - assert o_mh.shape == ct_mh.shape - i_ct_mh, w_ct_mh = backward_fn(ct_mh) + loop_val = (o_all, s_all, i_ct_all, w_ct_all) + + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) else: - o_mh, s_mh = forward_fn(i_mh, w_mh) + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - if compute_output: - o_all = fastmath.index_add(o_all, example_idx, o_mh) - if update_state: - s_all = tree_update(s_all, state_range, s_mh) - if compute_grad: - i_ct_all = tree_add(i_ct_all, example_idx, i_ct_mh) - w_ct_all = tree_add(w_ct_all, head_range, w_ct_mh) - return (o_all, s_all, i_ct_all, w_ct_all) - else: - assert n_parallel_heads % self._n_heads == 0 - def forward_single_example(i_x, w_all, s_x): - def forward_unbatched_h(i_h, w_h, s_h): - return forward_unbatched(*i_h, weights=w_h, state=s_h) - o_x, s_x = fastmath.vmap( - forward_unbatched_h, in_axes=(None, 0, 0), out_axes=(0, 0))( - i_x, w_all, s_x) - o_x = np.sum(o_x, axis=0) - return o_x, s_x - def run_inner(idx, loop_val): - """Runs one slice of attention (all heads for one or more examples).""" - o_all, s_all, i_ct_all, w_ct_all = loop_val - idx = idx * n_parallel_heads - example_idx_lo = idx // self._n_heads - example_range = example_idx_lo + np.arange( - n_parallel_heads // self._n_heads, dtype=np.int32) - state_range = idx + np.arange(n_parallel_heads, dtype=np.int32) - - i_mex = fastmath.nested_map(lambda x: x[example_range], inputs) - s_mex = fastmath.nested_map( - lambda s: np.reshape(s[state_range], # pylint: disable=g-long-lambda - (-1, self._n_heads) + s.shape[1:]), - state) - def forward_fn(i_mex, w_all): - o_mex, new_s_mex = fastmath.vmap( - forward_single_example, in_axes=(0, None, 0), out_axes=(0, 0))( - i_mex, w_all, s_mex) - new_s_mex = fastmath.nested_map( - lambda s: np.reshape(s, (n_parallel_heads,) + s.shape[2:]), - new_s_mex) - return o_mex.astype(i_mex[0].dtype), new_s_mex + (o_all, s_all, i_ct_all, w_ct_all) = loop_val if compute_grad: - o_mex, backward_fn, s_mex = vjp(forward_fn, i_mex, weights, - has_aux=True) - ct_mex = output_grad[example_range] - assert o_mex.shape == ct_mex.shape, str(ct_mex.shape) - assert o_mex.dtype == ct_mex.dtype, str(ct_mex.dtype) - i_ct_mex, w_ct_mex = backward_fn(ct_mex) + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) + + if have_single_input and compute_grad: + assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 + return (o_all, s_all, i_ct_all[0], w_ct_all) else: - o_mex, s_mex = forward_fn(i_mex, weights) + return (o_all, s_all, i_ct_all, w_ct_all) - if compute_output: - o_all = fastmath.index_add(o_all, jax.numpy.index_exp[example_range], - o_mex) - if update_state: - s_all = tree_update(s_all, state_range, s_mex) - if compute_grad: - i_ct_all = tree_update(i_ct_all, example_range, i_ct_mex) - w_ct_all = fastmath.nested_map_multiarg( - lambda old_all, delta_all: old_all + delta_all, - w_ct_all, w_ct_mex) - return (o_all, s_all, i_ct_all, w_ct_all) - - o_all = s_all = i_ct_all = w_ct_all = None - if compute_output: - o_all = np.zeros( - (batch_size, seqlen, d_model), dtype=inputs[0].dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, 
i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - w_ct_all = fastmath.nested_map(np.zeros_like, weights) - - loop_val = (o_all, s_all, i_ct_all, w_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) - (o_all, s_all, i_ct_all, w_ct_all) = loop_val +class PureLSHSelfAttention(base.Layer): + """LSH self-attention without weights.""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + share_qk="unused", + causal=False, + masked=False, + chunk_len=128, + n_chunks_before=1, + n_chunks_after=0, + n_hashes=1, + n_buckets=None, + mode="train", + predict_mem_len=2048, + predict_drop_len=256, + attention_dropout=0.0, + output_dropout=0.0, + max_length_for_buckets=None, + bias=False, + n_parallel_heads=1, + use_python_loop=False, + use_reference_code=False, + ): + """Construct an LSH self-attention layer.""" + # (qk, v, mask) -> (o) if masked + # (qk, v) -> (o) otherwise + super().__init__(n_in=(3 if masked else 2), n_out=1) + + self._n_heads = n_heads + if n_parallel_heads: + if (n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) or ( + n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0 + ): + raise ValueError( + "n_parallel_heads must be a multiple or fraction of n_heads" + ) + self._n_parallel_heads = n_parallel_heads + else: + self._n_parallel_heads = None - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + self._incremental = mode == "predict" + if self._incremental: + assert causal, "Only causal attention supports fast inference" + assert chunk_len is not None or (predict_mem_len and predict_drop_len) + predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) + predict_drop_len = predict_drop_len or chunk_len + if predict_mem_len is None or predict_drop_len is None: + raise ValueError("This configuration does not support fast inference.") + if not 0 < predict_drop_len <= predict_mem_len: + raise ValueError( + "Bad parameter values: (predict_mem_len, predict_drop_len) = ", + predict_mem_len, + predict_drop_len, + ) + self._predict_mem_len = predict_mem_len + self._predict_drop_len = predict_drop_len + + self._use_python_loop = use_python_loop + self._use_reference_code = use_reference_code + + self._d_qk = d_qk + self._d_v = d_v + self._share_qk = True + self._causal = causal + self._masked = masked + self._chunk_len = chunk_len + self._n_chunks_before = n_chunks_before + self._n_chunks_after = n_chunks_after + self._bias = bias + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout + else: + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + self._n_hashes = n_hashes + self._n_buckets = n_buckets + self._max_length_for_buckets = max_length_for_buckets + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. + # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. 
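(Aside: a standalone check of the limit computed on the next line, under the assumption that shape == (d_model, d_head). Initializing each head separately while counting fan-out across all heads matches, in scale, Glorot-initializing one concatenated (d_model, d_head * n_heads) matrix and slicing it per head. per_head_glorot_limit is an invented name, and plain NumPy stands in for fastmath.random.uniform.)

import numpy as onp

def per_head_glorot_limit(d_model, d_head, n_heads):
    # Glorot uniform limit sqrt(6 / (fan_in + fan_out)), with fan_out counted
    # across all heads even though each head is initialized on its own.
    return onp.sqrt(6.0 / (d_model + d_head * n_heads))

rng = onp.random.default_rng(0)
lim = per_head_glorot_limit(d_model=512, d_head=64, n_heads=8)
w_head = rng.uniform(-lim, lim, size=(512, 64))  # one head's (d_model, d_head) kernel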
+ lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def init_weights_and_state(self, input_signature): + # input_signature should be the type signature of (qk, v, mask) or (qk, v) + expected_inputs = 3 if self._masked else 2 + if not ( + isinstance(input_signature, (tuple, list)) + and len(input_signature) == expected_inputs + ): + raise ValueError( + f"input_signature should be {expected_inputs}-tuple, " + f"but is: {input_signature}" + ) + + # Each of qk, v are shaped - (batch * heads, length, d_head) + # mask is shaped: (batch, length) + qk_signature = input_signature[0] + v_signature = input_signature[1] + # mask_signature = input_signature[2] + # batch = mask_signature.shape[0] + batch_x_heads = qk_signature.shape[0] + + assert batch_x_heads % self._n_heads == 0 + batch = batch_x_heads // self._n_heads + + query_signature_unbatched = fastmath.nested_map( + lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), qk_signature + ) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + state_rngs = fastmath.random.split(self.rng, batch_x_heads) + state = [ + self.create_state_unbatched(query_signature_unbatched, rng) + for rng in state_rngs + ] - if have_single_input and compute_grad: - assert isinstance(i_ct_all, tuple) and len(i_ct_all) == 1 - return (o_all, s_all, i_ct_all[0], w_ct_all) - else: - return (o_all, s_all, i_ct_all, w_ct_all) + stack_along_axis_0 = lambda *x: np.stack(x, axis=0) + state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) + if self._incremental: + mem = fastmath.nested_map( + lambda x: np.zeros( # pylint: disable=g-long-lambda + x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], dtype=x.dtype + ), + (qk_signature, v_signature), + ) + mem_end = np.zeros((), dtype=np.int32) + state = (mem_end, mem, state) + + self.state = tuple(state) + self.weights = () + + def create_state_unbatched(self, input_signature, rng): + if isinstance(input_signature, (tuple, list)): + input_signature = input_signature[0] + # The `rng` argument passed to forward_unbatched is shared across all + # examples and heads. This facilitates using broadcasted dropout, which + # saves memory and hasn't been shown to hurt model quality. Even though the + # same sharing is likely to be safe when selecting random hash functions + # for LSH, we haven't run experiments to demonstrate this. To be on the safe + # side we include a per-head RNG in the state for the purpose of doing LSH. + if not self._incremental: + length = self._max_length_for_buckets or input_signature.shape[0] + buckets = np.zeros(self._n_hashes * length, dtype=np.int32) + return (buckets, rng) + else: + buckets = np.zeros(self._n_hashes * self._predict_mem_len, dtype=np.int32) + buckets_idx = np.zeros((), dtype=np.int32) + return (buckets, buckets_idx, rng) + + def hash_vectors(self, vecs, rng, mask=None): + n_buckets_list = self._n_buckets + + # Determine the number of buckets needed from input length if not set. + if n_buckets_list is None: + length = vecs.shape[0] + n_buckets = 2 * max(1, length // self._chunk_len) + if n_buckets <= 128: + n_buckets_list = n_buckets + else: # Factorize n_buckets. + n_buckets_div = 2 ** math.ceil(math.log2(math.sqrt(n_buckets))) + # Both factors must be even. + n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) + n_buckets_list = [n_buckets_div, n_buckets_rest] + + # Hash vectors. 
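(Aside: a worked version of the bucket-count heuristic above, before the hash_vecs call that follows: roughly two buckets per chunk, returned as a scalar when small and factorized into two even factors when large. bucket_spec is an invented name.)

import math

def bucket_spec(length, chunk_len):
    n_buckets = 2 * max(1, length // chunk_len)
    if n_buckets <= 128:
        return n_buckets
    n_buckets_div = 2 ** math.ceil(math.log2(math.sqrt(n_buckets)))
    n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div))
    return [n_buckets_div, n_buckets_rest]

assert bucket_spec(8192, 128) == 128        # 2 * 64 chunks, still a scalar
assert bucket_spec(65536, 128) == [32, 32]  # 1024 buckets, factorized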
+ buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) + + if mask is not None: + n_buckets += 1 # Create an extra bucket for padding tokens only + buckets = np.where(mask[None, :], buckets, n_buckets - 1) + + # buckets is now (n_hashes, seqlen). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. + offsets = np.arange(self._n_hashes, dtype=np.int32) + offsets = np.reshape(offsets * n_buckets, (-1, 1)) + buckets = np.reshape(buckets + offsets, (-1,)) + return buckets + + def forward_unbatched(self, qk, v, mask=None, *, state, rng, update_state): + attend_rng, output_rng = fastmath.random.split( + rng + ) # pylint: disable=unused-variable + + # Since these are unbatched: + # q, v are shaped (seqlen, d_head) + # mask is shaped (seqlen,) + q = qk + seqlen = q.shape[0] -class PureLSHSelfAttention(base.Layer): - """LSH self-attention without weights.""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, share_qk='unused', - causal=False, - masked=False, - chunk_len=128, - n_chunks_before=1, n_chunks_after=0, - n_hashes=1, - n_buckets=None, - mode='train', - predict_mem_len=2048, predict_drop_len=256, - attention_dropout=0.0, - output_dropout=0.0, - max_length_for_buckets=None, - bias=False, - n_parallel_heads=1, - use_python_loop=False, - use_reference_code=False, - ): - """Construct an LSH self-attention layer.""" - # (qk, v, mask) -> (o) if masked - # (qk, v) -> (o) otherwise - super().__init__(n_in=(3 if masked else 2), n_out=1) - - self._n_heads = n_heads - if n_parallel_heads: - if ((n_parallel_heads > n_heads and n_parallel_heads % n_heads != 0) - or (n_parallel_heads < n_heads and n_heads % n_parallel_heads != 0)): - raise ValueError( - 'n_parallel_heads must be a multiple or fraction of n_heads') - self._n_parallel_heads = n_parallel_heads - else: - self._n_parallel_heads = None - - self._incremental = (mode == 'predict') - if self._incremental: - assert causal, 'Only causal attention supports fast inference' - assert chunk_len is not None or (predict_mem_len and predict_drop_len) - predict_mem_len = predict_mem_len or (chunk_len * (1 + n_chunks_before)) - predict_drop_len = predict_drop_len or chunk_len - if predict_mem_len is None or predict_drop_len is None: - raise ValueError('This configuration does not support fast inference.') - if not 0 < predict_drop_len <= predict_mem_len: - raise ValueError( - 'Bad parameter values: (predict_mem_len, predict_drop_len) = ', - predict_mem_len, predict_drop_len) - self._predict_mem_len = predict_mem_len - self._predict_drop_len = predict_drop_len - - self._use_python_loop = use_python_loop - self._use_reference_code = use_reference_code - - self._d_qk = d_qk - self._d_v = d_v - self._share_qk = True - self._causal = causal - self._masked = masked - self._chunk_len = chunk_len - self._n_chunks_before = n_chunks_before - self._n_chunks_after = n_chunks_after - self._bias = bias - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - self._n_hashes = n_hashes - self._n_buckets = n_buckets - self._max_length_for_buckets = max_length_for_buckets - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. 
- # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. - lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def init_weights_and_state(self, input_signature): - # input_signature should be the type signature of (qk, v, mask) or (qk, v) - expected_inputs = 3 if self._masked else 2 - if not (isinstance(input_signature, (tuple, list)) and - len(input_signature) == expected_inputs): - raise ValueError( - f'input_signature should be {expected_inputs}-tuple, ' - f'but is: {input_signature}') - - # Each of qk, v are shaped - (batch * heads, length, d_head) - # mask is shaped: (batch, length) - qk_signature = input_signature[0] - v_signature = input_signature[1] - # mask_signature = input_signature[2] - # batch = mask_signature.shape[0] - batch_x_heads = qk_signature.shape[0] - - assert batch_x_heads % self._n_heads == 0 - batch = batch_x_heads // self._n_heads - - query_signature_unbatched = fastmath.nested_map( - lambda x: type(x)(shape=x.shape[1:], dtype=x.dtype), - qk_signature) - - state_rngs = fastmath.random.split(self.rng, batch_x_heads) - state = [self.create_state_unbatched(query_signature_unbatched, rng) - for rng in state_rngs] - - stack_along_axis_0 = lambda *x: np.stack(x, axis=0) - state = fastmath.nested_map_multiarg(stack_along_axis_0, *state) - - if self._incremental: - mem = fastmath.nested_map( - lambda x: np.zeros( # pylint: disable=g-long-lambda - x.shape[:1] + (self._predict_mem_len,) + x.shape[2:], - dtype=x.dtype), - (qk_signature, v_signature)) - mem_end = np.zeros((), dtype=np.int32) - state = (mem_end, mem, state) - - self.state = tuple(state) - self.weights = () - - def create_state_unbatched(self, input_signature, rng): - if isinstance(input_signature, (tuple, list)): - input_signature = input_signature[0] - # The `rng` argument passed to forward_unbatched is shared across all - # examples and heads. This facilitates using broadcasted dropout, which - # saves memory and hasn't been shown to hurt model quality. Even though the - # same sharing is likely to be safe when selecting random hash functions - # for LSH, we haven't run experiments to demonstrate this. To be on the safe - # side we include a per-head RNG in the state for the purpose of doing LSH. - if not self._incremental: - length = self._max_length_for_buckets or input_signature.shape[0] - buckets = np.zeros(self._n_hashes * length, dtype=np.int32) - return (buckets, rng) - else: - buckets = np.zeros( - self._n_hashes * self._predict_mem_len, dtype=np.int32) - buckets_idx = np.zeros((), dtype=np.int32) - return (buckets, buckets_idx, rng) - - def hash_vectors(self, vecs, rng, mask=None): - n_buckets_list = self._n_buckets - - # Determine the number of buckets needed from input length if not set. - if n_buckets_list is None: - length = vecs.shape[0] - n_buckets = 2 * max(1, length // self._chunk_len) - if n_buckets <= 128: - n_buckets_list = n_buckets - else: # Factorize n_buckets. - n_buckets_div = 2**math.ceil(math.log2(math.sqrt(n_buckets))) - # Both factors must be even. - n_buckets_rest = 2 * (n_buckets // (2 * n_buckets_div)) - n_buckets_list = [n_buckets_div, n_buckets_rest] - - # Hash vectors. 
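(Aside: the tail of hash_vectors, in both the old and new versions, relies on one flattening trick: adding round * n_buckets to each hash round's bucket ids keeps rounds disjoint after flattening, so a single sort groups positions by (round, bucket). A small NumPy illustration with invented toy values:)

import numpy as onp

n_hashes, n_buckets = 2, 4
buckets = onp.array([[0, 3, 3, 1],   # round-0 bucket ids, one per position
                     [2, 2, 0, 1]])  # round-1 bucket ids
offsets = (onp.arange(n_hashes) * n_buckets)[:, None]
flat = onp.reshape(buckets + offsets, (-1,))
# flat == [0, 3, 3, 1, 6, 6, 4, 5]: round-1 ids now occupy [4, 8), so sorting
# flat can never mix positions from different hash rounds.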
- buckets, n_buckets = hash_vecs(vecs, n_buckets_list, self._n_hashes, rng) - - if mask is not None: - n_buckets += 1 # Create an extra bucket for padding tokens only - buckets = np.where(mask[None, :], buckets, n_buckets - 1) - - # buckets is now (n_hashes, seqlen). Next we add offsets so that - # bucket numbers from different hashing rounds don't overlap. - offsets = np.arange(self._n_hashes, dtype=np.int32) - offsets = np.reshape(offsets * n_buckets, (-1, 1)) - buckets = np.reshape(buckets + offsets, (-1,)) - return buckets - - def forward_unbatched(self, qk, v, mask=None, *, state, rng, - update_state): - attend_rng, output_rng = fastmath.random.split(rng) # pylint: disable=unused-variable - - # Since these are unbatched: - # q, v are shaped (seqlen, d_head) - # mask is shaped (seqlen,) - q = qk - seqlen = q.shape[0] - - if update_state: - _, old_hash_rng = state - hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) - buckets = self.hash_vectors(q, hash_subrng, mask) - s_buckets = buckets - if self._max_length_for_buckets: - length = self._n_hashes * self._max_length_for_buckets - if buckets.shape[0] < length: - s_buckets = np.concatenate( - [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], - axis=0) - state = (s_buckets, hash_rng) - else: - buckets, _ = state - if self._max_length_for_buckets: - buckets = buckets[:self._n_hashes * seqlen] - - assert int(buckets.shape[0]) == self._n_hashes * seqlen - - ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) - buckets_and_t = seqlen * buckets + (ticker % seqlen) - buckets_and_t = fastmath.stop_gradient(buckets_and_t) - - # Hash-based sort ("s" at the start of variable names means "sorted") - sbuckets_and_t, sticker = fastmath.sort_key_val( - buckets_and_t, ticker, dimension=-1) - _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) - sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) - sticker = fastmath.stop_gradient(sticker) - undo_sort = fastmath.stop_gradient(undo_sort) - - st = (sticker % seqlen) - sq = np.take(q, st, axis=0) - sv = np.take(v, st, axis=0) - - mask_fn = functools.partial(mask_self_attention, causal=self._causal, - exclude_self=True, masked=self._masked) - q_info = st - - assert (mask is not None) == self._masked - kv_info = None - if self._masked: - # mask is a boolean array (True means "is valid token") - smask = np.take(mask, st, axis=0) - ones_like_mask = np.ones_like(smask, dtype=np.int32) - kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) - - so, slogits = attend( - sq, k=None, v=sv, - q_chunk_len=self._chunk_len, - n_chunks_before=self._n_chunks_before, - n_chunks_after=self._n_chunks_after, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + if update_state: + _, old_hash_rng = state + hash_rng, hash_subrng = fastmath.random.split(old_hash_rng) + buckets = self.hash_vectors(q, hash_subrng, mask) + s_buckets = buckets + if self._max_length_for_buckets: + length = self._n_hashes * self._max_length_for_buckets + if buckets.shape[0] < length: + s_buckets = np.concatenate( + [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)], + axis=0, + ) + state = (s_buckets, hash_rng) + else: + buckets, _ = state + if self._max_length_for_buckets: + buckets = buckets[: self._n_hashes * seqlen] + + assert int(buckets.shape[0]) == self._n_hashes * seqlen + + ticker = np.arange(self._n_hashes * seqlen, dtype=np.int32) + buckets_and_t = seqlen * buckets + (ticker % seqlen) + buckets_and_t = 
fastmath.stop_gradient(buckets_and_t) + + # Hash-based sort ("s" at the start of variable names means "sorted") + sbuckets_and_t, sticker = fastmath.sort_key_val( + buckets_and_t, ticker, dimension=-1 + ) + _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1) + sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t) + sticker = fastmath.stop_gradient(sticker) + undo_sort = fastmath.stop_gradient(undo_sort) + + st = sticker % seqlen + sq = np.take(q, st, axis=0) + sv = np.take(v, st, axis=0) + + mask_fn = functools.partial( + mask_self_attention, + causal=self._causal, + exclude_self=True, + masked=self._masked, + ) + q_info = st + + assert (mask is not None) == self._masked + kv_info = None + if self._masked: + # mask is a boolean array (True means "is valid token") + smask = np.take(mask, st, axis=0) + ones_like_mask = np.ones_like(smask, dtype=np.int32) + kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask) + + so, slogits = attend( + sq, + k=None, + v=sv, + q_chunk_len=self._chunk_len, + n_chunks_before=self._n_chunks_before, + n_chunks_after=self._n_chunks_after, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, ) - # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would - # also work, but these helpers include performance optimizations for TPU. - o = permute_via_gather(so, undo_sort, sticker, axis=0) - logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) - - if self._n_hashes > 1: - o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) - logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) - probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) - o = np.sum(o * probs, axis=0) - - # assert o.shape == (seqlen, w_v.shape[-1]) - assert o.shape == v.shape - - # TODO(afrozm): Unlike LSHSelfAttention we don't apply output dropout here. - out = o - return out, state - - def _incremental_forward_unbatched(self, qk, v, mask=None, *, - q_start, q_len, - state, rng, update_state): - x = (qk, v) - length = x[0].shape[0] - assert update_state, ( - 'This setting not supported (e.g. no backprop for fast inference)') - if q_len > 1: - if isinstance(q_start, int): - assert q_start == 0, 'Chunks larger than 1 only work at start for now.' - if length % self._chunk_len == 0: - x_padded = x - else: - pad_amount = self._chunk_len - (length % self._chunk_len) - x_padded = fastmath.nested_map( - lambda x: np.pad(x, ((0, pad_amount), (0, 0)), mode='constant'), x) - buckets, buckets_idx, hash_rng = state - qk, v = x_padded - buckets_update = self.hash_vectors(qk, hash_rng) - - out, _ = self.forward_unbatched( - qk, v, mask=mask, state=(buckets_update, hash_rng), rng=rng, - update_state=False) - - out = out[:q_len] - buckets = np.reshape(buckets, (self._n_hashes, -1)) - buckets_update = np.reshape( - buckets_update, (self._n_hashes, -1))[:, :q_len] - if q_len > self._predict_mem_len: - buckets_update = buckets_update[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - buckets = fastmath.dynamic_update_slice_in_dim( - buckets, buckets_update, q_start, axis=1) - buckets = np.reshape(buckets, (-1,)) - - return out, (buckets, buckets_idx + q_len, hash_rng) - - # This codepath is for handling one token at a time. 
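(Aside: before the single-token codepath that follows, a toy rerun of its kv-priority trick: positions after the query get a large negative priority, positions sharing a bucket with the query get a +seqlen boost, and one ascending sort then leaves bucket-mates and the most recent tokens at the tail. Plain-NumPy argsort stands in for fastmath.sort_key_val, and the values are invented.)

import numpy as onp

seqlen, q_start, q_len = 8, 5, 1
is_valid_target = onp.array([1, 0, 0, 1, 0, 1, 0, 0])  # bucket match per position
pos = onp.arange(seqlen)
prio = onp.where(pos > (q_start + q_len), -(seqlen + pos), pos)
prio = prio + seqlen * is_valid_target
kv_indices = onp.argsort(prio)[-4:]  # keep the 4 highest-priority key/value slots
assert list(kv_indices) == [6, 0, 3, 5]  # recent token 6 plus bucket-mates 0, 3, 5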
- assert q_len == 1 - buckets, buckets_idx, hash_rng = state - - def roll_buckets(buckets): - buckets = np.reshape(buckets, (self._n_hashes, -1)) - new_buckets = np.concatenate( - [buckets, np.zeros((self._n_hashes, self._predict_drop_len), - dtype=buckets.dtype) - ], axis=1) - new_buckets = fastmath.dynamic_slice_in_dim( - new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1) - new_buckets = np.reshape(new_buckets, (-1,)) - return new_buckets - - buckets = fastmath.cond( - pred=buckets_idx > q_start, - true_operand=buckets, - true_fun=roll_buckets, - false_operand=buckets, - false_fun=lambda x: x, - ) + # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would + # also work, but these helpers include performance optimizations for TPU. + o = permute_via_gather(so, undo_sort, sticker, axis=0) + logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1) + + if self._n_hashes > 1: + o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1])) + logits = np.reshape(logits, (self._n_hashes, seqlen, 1)) + probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True)) + o = np.sum(o * probs, axis=0) + + # assert o.shape == (seqlen, w_v.shape[-1]) + assert o.shape == v.shape - attend_rng, unused_output_rng = fastmath.random.split(rng) - - q_range = q_start + np.arange(q_len, dtype=np.int32) - # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not - # floating-point equivalent, at least in non-jitted code. We correct the - # discrepancy by duplicating the slice. Floating-point noise may not be - # an issue when using models, but it makes it harder to write tests that - # compare fast and slow inference code for equivalence. - q = np.concatenate([qk[q_range]] * 2, 0) - - q_buckets = self.hash_vectors(q, hash_rng) - q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - - unflattened_buckets = fastmath.dynamic_update_slice_in_dim( - np.reshape(buckets, (self._n_hashes, -1)), - q_buckets, q_start, axis=1) - buckets = np.reshape(unflattened_buckets, (-1,)) - is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) - - assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 - length = qk.shape[0] - arange_seqlen = np.arange(length, dtype=np.int32) - kv_priorities = np.where( - arange_seqlen > (q_start + q_len), - -(length + arange_seqlen), arange_seqlen) - kv_priorities = kv_priorities + length * is_valid_target.astype(np.int32) - _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) - kv_indices = kv_indices[ - -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before):] - assert self._n_chunks_after == 0 - - k = length_normalized(qk[kv_indices]) - v = v[kv_indices] - - mask_fn = functools.partial( - mask_self_attention, causal=True, masked=True, exclude_self=True) - q_info = q_start + np.arange(q_len, dtype=np.int32) - kv_info = kv_indices.astype(np.int32) - q_info = q_info.astype(np.int32) - # TODO(kitaev): is it better to mask out attention across buckets? - # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + # TODO(afrozm): Unlike LSHSelfAttention we don't apply output dropout here. + out = o + return out, state + + def _incremental_forward_unbatched( + self, qk, v, mask=None, *, q_start, q_len, state, rng, update_state + ): + x = (qk, v) + length = x[0].shape[0] + assert ( + update_state + ), "This setting not supported (e.g. 
no backprop for fast inference)" + if q_len > 1: + if isinstance(q_start, int): + assert q_start == 0, "Chunks larger than 1 only work at start for now." + if length % self._chunk_len == 0: + x_padded = x + else: + pad_amount = self._chunk_len - (length % self._chunk_len) + x_padded = fastmath.nested_map( + lambda x: np.pad(x, ((0, pad_amount), (0, 0)), mode="constant"), x + ) + buckets, buckets_idx, hash_rng = state + qk, v = x_padded + buckets_update = self.hash_vectors(qk, hash_rng) + + out, _ = self.forward_unbatched( + qk, + v, + mask=mask, + state=(buckets_update, hash_rng), + rng=rng, + update_state=False, + ) + + out = out[:q_len] + buckets = np.reshape(buckets, (self._n_hashes, -1)) + buckets_update = np.reshape(buckets_update, (self._n_hashes, -1))[:, :q_len] + if q_len > self._predict_mem_len: + buckets_update = buckets_update[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + buckets = fastmath.dynamic_update_slice_in_dim( + buckets, buckets_update, q_start, axis=1 + ) + buckets = np.reshape(buckets, (-1,)) + + return out, (buckets, buckets_idx + q_len, hash_rng) + + # This codepath is for handling one token at a time. + assert q_len == 1 + buckets, buckets_idx, hash_rng = state + + def roll_buckets(buckets): + buckets = np.reshape(buckets, (self._n_hashes, -1)) + new_buckets = np.concatenate( + [ + buckets, + np.zeros( + (self._n_hashes, self._predict_drop_len), dtype=buckets.dtype + ), + ], + axis=1, + ) + new_buckets = fastmath.dynamic_slice_in_dim( + new_buckets, buckets_idx - q_start, buckets.shape[-1], axis=1 + ) + new_buckets = np.reshape(new_buckets, (-1,)) + return new_buckets + + buckets = fastmath.cond( + pred=buckets_idx > q_start, + true_operand=buckets, + true_fun=roll_buckets, + false_operand=buckets, + false_fun=lambda x: x, ) - out = o - if q_len == 1: - out = out[:1] - buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) - return out, (buckets, buckets_idx, hash_rng) + attend_rng, unused_output_rng = fastmath.random.split(rng) - def forward(self, inputs): - """Computes this layer's output as part of a forward pass through the model. + q_range = q_start + np.arange(q_len, dtype=np.int32) + # On TPU, np.matmul(a[:1], b) and np.matmul(a, b)[:1] are not + # floating-point equivalent, at least in non-jitted code. We correct the + # discrepancy by duplicating the slice. Floating-point noise may not be + # an issue when using models, but it makes it harder to write tests that + # compare fast and slow inference code for equivalence. + q = np.concatenate([qk[q_range]] * 2, 0) - Args: - inputs: Layer inputs (subclasses may use different inputs) + q_buckets = self.hash_vectors(q, hash_rng) + q_buckets = np.reshape(q_buckets, (self._n_hashes, 2))[:, :q_len] - Returns: - A tuple (output, new_state). - """ - state, rng = self.state, self.rng - - if self._use_reference_code: - raise NotImplementedError( - 'Reference code not implemented for PureLSHSelfAttention') - - output, new_state, unused_input_cotangents = self.forward_and_or_backward( - inputs, state, rng, compute_output=True, update_state=True) - self.state = new_state - return output - - def _use_predict_mem(self, inputs, state): - """Update input cache for fast inference.""" - - # inputs is (qk, v). mask isn't passed in. 
- # where qk/v are shaped - (batch * n_heads, seq_len, d_head) - - mem_end, mem, state = state - seqlen = inputs[0].shape[-2] - - if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: - # This branch is called when only a small number of tokens are appended to - # the sequence, e.g. when generating one token at a time. A fixed number - # of tokens (self._predict_drop_tokens) will be dropped from memory if - # needed, and then new values will be inserted into the memory. - def roll_mem(buf): - return np.concatenate( - [buf[:, self._predict_drop_len:], - np.zeros_like(buf[:, :self._predict_drop_len])], axis=1) - - do_roll_mem = (mem_end + seqlen > self._predict_mem_len) - mem = fastmath.cond( - pred=do_roll_mem, - true_operand=mem, - true_fun=lambda x: fastmath.nested_map(roll_mem, x), - false_operand=mem, - false_fun=lambda x: x, - ) - mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) - def update_mem(mem_element, new_vals): - assert new_vals.shape[1] == seqlen - if seqlen == 1: - return fastmath.index_update( - mem_element, jax.numpy.index_exp[:, mem_end], new_vals[:, 0, ...]) - else: - return fastmath.dynamic_update_slice_in_dim( - mem_element, new_vals, mem_end, axis=1) - inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) - return inputs, state, mem_end, inputs, mem_end + seqlen - else: - assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len - # This branch handles the case where a large number of tokens are being - # introduced all at once. The code here assumes that we are at the start - # of the sequence, which matches the typical use case of decoding from a - # language model given a long prefix. Note that if we're not at the start - # of the sequence, the code here won't work. - new_flat_mem = [] - for inp in fastmath.tree_leaves(inputs): - assert inp.shape[1] == seqlen - if seqlen == self._predict_mem_len: - new_mem_val = inp - elif seqlen > self._predict_mem_len: - new_mem_val = inp[:, -self._predict_mem_len:] # pylint: disable=invalid-unary-operand-type - else: - new_mem_val = np.concatenate([ - inp, - np.zeros(inp.shape[:1] - + (self._predict_mem_len - inp.shape[1],) - + inp.shape[2:], - dtype=inp.dtype) - ], axis=1) - new_flat_mem.append(new_mem_val) - mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) - - # This code only works at the start of the sequence. There's no "assert" - # primitive we can use to signal an error, so we instead signal the error - # by introducing NaNs into the computation. 
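# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the NaN-signalling trick described
# above. Jitted code has no assert, so a violated precondition is surfaced
# by poisoning float activations with NaN. Hypothetical names; plain
# jax.numpy stands in for trax.fastmath.
import jax.numpy as jnp

def poison_if_not_seq_start(x, mem_end):
    # Pass activations through unchanged at the start of the sequence
    # (mem_end == 0); otherwise turn them all into NaN so the error is loud.
    if x.dtype != jnp.float32:
        return x
    return jnp.where(jnp.equal(mem_end, 0), x, x * jnp.nan)

x = jnp.ones((2, 3))
assert not jnp.isnan(poison_if_not_seq_start(x, jnp.array(0))).any()
assert jnp.isnan(poison_if_not_seq_start(x, jnp.array(5))).all()
# --------------------------------------------------------------------------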
- def replace_with_nan_if_not_seq_start(x): - if x.dtype != np.float32: - return x - return fastmath.cond( - pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), - true_operand=x, true_fun=lambda x: x, - false_operand=x, false_fun=lambda x: x * np.nan) - inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) - return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng=None, - **kwargs): - """Custom backward pass, for efficiency (see forward_and_or_backward).""" - del output - del state - del kwargs - unused_output, unused_new_state, inputs_grad = self.forward_and_or_backward( - inputs, - new_state, - rng, - output_grad=grad, - compute_output=False, - update_state=False) + unflattened_buckets = fastmath.dynamic_update_slice_in_dim( + np.reshape(buckets, (self._n_hashes, -1)), q_buckets, q_start, axis=1 + ) + buckets = np.reshape(unflattened_buckets, (-1,)) + is_valid_target = np.any(unflattened_buckets == q_buckets, axis=0) + + assert q_buckets.shape[-1] == 1 # Is true when q_len == 1 + length = qk.shape[0] + arange_seqlen = np.arange(length, dtype=np.int32) + kv_priorities = np.where( + arange_seqlen > (q_start + q_len), -(length + arange_seqlen), arange_seqlen + ) + kv_priorities = kv_priorities + length * is_valid_target.astype(np.int32) + _, kv_indices = fastmath.sort_key_val(kv_priorities, arange_seqlen) + kv_indices = kv_indices[ + -self._n_hashes * self._chunk_len * (1 + self._n_chunks_before) : + ] + assert self._n_chunks_after == 0 + + k = length_normalized(qk[kv_indices]) + v = v[kv_indices] + + mask_fn = functools.partial( + mask_self_attention, causal=True, masked=True, exclude_self=True + ) + q_info = q_start + np.arange(q_len, dtype=np.int32) + kv_info = kv_indices.astype(np.int32) + q_info = q_info.astype(np.int32) + # TODO(kitaev): is it better to mask out attention across buckets? + # kv_info = np.where(is_valid_target[kv_indices], kv_indices, -kv_indices) + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, + ) - weights_grad = fastmath.nested_map(np.zeros_like, weights) - return inputs_grad, weights_grad + out = o + if q_len == 1: + out = out[:1] + buckets_idx = np.array(q_start + q_len, dtype=buckets_idx.dtype) + return out, (buckets, buckets_idx, hash_rng) - def forward_and_or_backward( - self, inputs, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. + def forward(self, inputs): + """Computes this layer's output as part of a forward pass through the model. - See `forward` for a reference implementation of what this layer does. The - reference implementation is not very efficient, however, and this method - provides a more performant version. + Args: + inputs: Layer inputs (subclasses may use different inputs) - Args: - inputs: inputs to the attention layer tuple (qk, v, mask) - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. 
+ Returns: + A tuple (output, new_state). + """ + state, rng = self.state, self.rng - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). + if self._use_reference_code: + raise NotImplementedError( + "Reference code not implemented for PureLSHSelfAttention" + ) - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # TODO(b/148460708): reduce memory usage further - # TODO(kitaev): there should be a higher-level API (like vmap) that does - # batching, instead of needing 3 separate manual implementations here. - - # Notes regarding the implementation: - # (a) Multiple heads or examples are batched together. There are three - # different regimes possible: one head at a time (for long sequences and - # expensive attention types), several attention heads at a time (for - # long sequences but less-expensive attention types), and several - # examples at a time (for large batches of shorter sequences). For the - # time being, each of these regimes has its own code. - # (b) Python loops produce large computation graphs when jitted, so the - # default is to use a JAX loop instead. - # (c) No intermediate quantities are cached for the backward pass. Instead, - # the forward pass is re-computed when doing backprop. This approach is - # often called "checkpointing" or "rematerialization". When not all - # examples or heads fit in memory simultaneously, the implementation - # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse - # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] - # automatically, so the looping for the backward pass is done manually. - # - # [FW-BW-1] for example, head in zip(examples, heads): - # forward(example, head) - # backward(example, head) # uses intermediates from forward - # - # [FW-BW-2] for example, head in zip(examples, heads): - # forward(example, head) - # for example, head in zip(examples, heads): - # backward(example, head) - - if self._masked: - qk, v, mask = inputs - batch_size = mask.shape[0] - else: - qk, v = inputs - mask = None - batch_size = qk.shape[0] // self._n_heads - batch_x_heads, seqlen, d_model = qk.shape + output, new_state, unused_input_cotangents = self.forward_and_or_backward( + inputs, state, rng, compute_output=True, update_state=True + ) + self.state = new_state + return output + + def _use_predict_mem(self, inputs, state): + """Update input cache for fast inference.""" + + # inputs is (qk, v). mask isn't passed in. + # where qk/v are shaped - (batch * n_heads, seq_len, d_head) + + mem_end, mem, state = state + seqlen = inputs[0].shape[-2] + + if seqlen <= self._predict_drop_len and seqlen < self._predict_mem_len: + # This branch is called when only a small number of tokens are appended to + # the sequence, e.g. when generating one token at a time. A fixed number + # of tokens (self._predict_drop_tokens) will be dropped from memory if + # needed, and then new values will be inserted into the memory. 
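# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the rolling memory described in
# the comment above, in plain jax (names hypothetical). When the write
# position would overflow the buffer, the oldest drop_len entries are
# shifted out and zero-padded before the new values are written.
import jax
import jax.numpy as jnp

def roll_and_write(mem, mem_end, new_vals, drop_len):
    # mem: (batch, mem_len, d); new_vals: (batch, n_new, d); mem_end: int32.
    n_new = new_vals.shape[1]
    need_roll = mem_end + n_new > mem.shape[1]
    rolled = jnp.concatenate(
        [mem[:, drop_len:], jnp.zeros_like(mem[:, :drop_len])], axis=1)
    mem = jnp.where(need_roll, rolled, mem)
    mem_end = jnp.where(need_roll, mem_end - drop_len, mem_end)
    mem = jax.lax.dynamic_update_slice_in_dim(mem, new_vals, mem_end, axis=1)
    return mem, mem_end + n_new

# Writing 2 tokens at position 7 of an 8-slot buffer overflows, so the
# buffer rolls by drop_len=2 first and the write lands at position 5.
mem, end = roll_and_write(jnp.zeros((1, 8, 4)), jnp.array(7),
                          jnp.ones((1, 2, 4)), drop_len=2)
# --------------------------------------------------------------------------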
+ def roll_mem(buf): + return np.concatenate( + [ + buf[:, self._predict_drop_len :], + np.zeros_like(buf[:, : self._predict_drop_len]), + ], + axis=1, + ) + + do_roll_mem = mem_end + seqlen > self._predict_mem_len + mem = fastmath.cond( + pred=do_roll_mem, + true_operand=mem, + true_fun=lambda x: fastmath.nested_map(roll_mem, x), + false_operand=mem, + false_fun=lambda x: x, + ) + mem_end = np.where(do_roll_mem, mem_end - self._predict_drop_len, mem_end) + + def update_mem(mem_element, new_vals): + assert new_vals.shape[1] == seqlen + if seqlen == 1: + return fastmath.index_update( + mem_element, + jax.numpy.index_exp[:, mem_end], + new_vals[:, 0, ...], + ) + else: + return fastmath.dynamic_update_slice_in_dim( + mem_element, new_vals, mem_end, axis=1 + ) + + inputs = fastmath.nested_map_multiarg(update_mem, mem, inputs) + return inputs, state, mem_end, inputs, mem_end + seqlen + else: + assert seqlen > self._predict_drop_len or seqlen == self._predict_mem_len + # This branch handles the case where a large number of tokens are being + # introduced all at once. The code here assumes that we are at the start + # of the sequence, which matches the typical use case of decoding from a + # language model given a long prefix. Note that if we're not at the start + # of the sequence, the code here won't work. + new_flat_mem = [] + for inp in fastmath.tree_leaves(inputs): + assert inp.shape[1] == seqlen + if seqlen == self._predict_mem_len: + new_mem_val = inp + elif seqlen > self._predict_mem_len: + new_mem_val = inp[ + :, -self._predict_mem_len : + ] # pylint: disable=invalid-unary-operand-type + else: + new_mem_val = np.concatenate( + [ + inp, + np.zeros( + inp.shape[:1] + + (self._predict_mem_len - inp.shape[1],) + + inp.shape[2:], + dtype=inp.dtype, + ), + ], + axis=1, + ) + new_flat_mem.append(new_mem_val) + mem, _ = fastmath.tree_unflatten(new_flat_mem, mem) + + # This code only works at the start of the sequence. There's no "assert" + # primitive we can use to signal an error, so we instead signal the error + # by introducing NaNs into the computation. + def replace_with_nan_if_not_seq_start(x): + if x.dtype != np.float32: + return x + return fastmath.cond( + pred=np.equal(mem_end, np.array(0, dtype=mem_end.dtype)), + true_operand=x, + true_fun=lambda x: x, + false_operand=x, + false_fun=lambda x: x * np.nan, + ) + + inputs = fastmath.nested_map(replace_with_nan_if_not_seq_start, inputs) + return inputs, state, 0, mem, np.minimum(seqlen, self._predict_mem_len) + + @property + def has_backward(self): + return True + + def backward( + self, inputs, output, grad, weights, state, new_state, rng=None, **kwargs + ): + """Custom backward pass, for efficiency (see forward_and_or_backward).""" + del output + del state + del kwargs + unused_output, unused_new_state, inputs_grad = self.forward_and_or_backward( + inputs, + new_state, + rng, + output_grad=grad, + compute_output=False, + update_state=False, + ) - compute_grad = (output_grad is not None) - assert compute_output or compute_grad, 'No work to perform!' 
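# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the shape of this custom backward.
# The layer re-runs its forward pass (rematerialization) to get input
# cotangents and, having no trainable weights of its own, returns
# zero/empty weight gradients. Plain jax stands in for trax.fastmath;
# attention_like_fn is a hypothetical stand-in for the real forward.
import jax
import jax.numpy as jnp

def attention_like_fn(inputs):
    qk, v = inputs
    return jnp.matmul(jax.nn.softmax(jnp.matmul(qk, qk.T)), v)

def custom_backward(inputs, grad):
    _, vjp_fn = jax.vjp(attention_like_fn, inputs)  # forward recomputed here
    (inputs_grad,) = vjp_fn(grad)
    weights_grad = ()                 # no weights to differentiate
    return inputs_grad, weights_grad

qk = v = jnp.ones((4, 8))
(d_qk, d_v), _ = custom_backward((qk, v), jnp.ones((4, 8)))
# --------------------------------------------------------------------------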
+ weights_grad = fastmath.nested_map(np.zeros_like, weights) + return inputs_grad, weights_grad - if not self._incremental: - forward_unbatched = functools.partial( - self.forward_unbatched, rng=rng, update_state=update_state) - else: - assert not compute_grad - - # The input to use_predict_mem is (qk, v) - inputs = (qk, v) - - if update_state: - inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( - inputs, state) - else: - # This assumes that the memory stores all of the inputs, which would not - # be valid if doing backprop in mode 'predict' with long lengths. - new_mem_end, inputs, state = state - q_start = new_mem_end - seqlen - - # Reset qk and v to what use_predict_mem/state gave us. - qk, v = inputs - - forward_unbatched = functools.partial( - self._incremental_forward_unbatched, - q_start=fastmath.stop_gradient(q_start), - q_len=fastmath.stop_gradient(seqlen), - rng=rng, update_state=update_state) - - # Adjust degree of parallelism based on the batch size. - n_parallel_heads = batch_size * self._n_heads - if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: - n_parallel_heads = self._n_parallel_heads - - def tree_update(tree, indices, new_values): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], - y), - tree, new_values) - - def tree_add(tree, indices, addends): - return fastmath.nested_map_multiarg( - lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), - tree, addends) - - if compute_grad: - inputs_is_differentiable = fastmath.nested_map( - lambda x: np.issubdtype(x.dtype, np.inexact), inputs) - def split_differentiable(xs): - differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: x if is_differentiable else None, - xs, inputs_is_differentiable) - non_differentiable_xs = fastmath.nested_map_multiarg( - lambda x, is_differentiable: None if is_differentiable else x, - xs, inputs_is_differentiable) - return differentiable_xs, non_differentiable_xs - def join_differentiable(differentiable_xs, non_differentiable_xs): - """Reconstitute inputs pytree from differentiable/non-d. 
partitions.""" - differentiable_leaves = fastmath.tree_leaves(differentiable_xs) - non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) - leaves = [] - for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): - if is_differentiable: - leaves.append(differentiable_leaves.pop(0)) - else: - leaves.append(non_differentiable_leaves.pop(0)) - assert not differentiable_leaves - assert not non_differentiable_leaves - tree, _ = fastmath.tree_unflatten(leaves, inputs) - return tree - - def vjp(fn, inp, *args, has_aux=False): - d_inp, nd_inp = split_differentiable(inp) - - def fn_closed_over_nd_inp(d_inp, *args): - inp = join_differentiable(d_inp, nd_inp) - return fn(inp, *args) - return fastmath.vjp(fn_closed_over_nd_inp, d_inp, *args, - has_aux=has_aux) - - if n_parallel_heads != 1: - raise NotImplementedError( - 'PureLSHSelfAttention is not implemented for n_parallel_heads != 1.') - - def run_inner(idx, loop_val): - """Runs one slice of attention (for a single head).""" - o_all, s_all, i_ct_all = loop_val - example_idx = idx // self._n_heads - unused_head_idx = idx % self._n_heads - - s_h = fastmath.nested_map(lambda s: s[idx], state) - - if self._masked: - i_h = (qk[idx], v[idx], mask[example_idx]) - else: - i_h = (qk[idx], v[idx]) - - def forward_fn(i_h): - return forward_unbatched( - *i_h, state=fastmath.stop_gradient(s_h)) - - if compute_grad: - o_h, backward_fn, s_h = vjp(forward_fn, i_h, has_aux=True) - ct_h = output_grad[idx] - assert o_h.shape == ct_h.shape - i_ct_h, = backward_fn(ct_h) - else: - o_h, s_h = forward_fn(i_h) - - if compute_output: - o_all = tree_update(o_all, idx, o_h) - if update_state: - s_all = tree_update(s_all, idx, s_h) - if compute_grad: - i_ct_all = tree_add(i_ct_all, idx, i_ct_h) - return (o_all, s_all, i_ct_all) - - o_all = s_all = i_ct_all = None - if compute_output: - o_all = np.zeros((batch_x_heads, seqlen, d_model), dtype=v.dtype) - if update_state: - s_all = state - if compute_grad: - i_ct_all = fastmath.nested_map(np.zeros_like, inputs) - i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - - loop_val = (o_all, s_all, i_ct_all) - - assert (batch_size * self._n_heads) % n_parallel_heads == 0 - loop_hi = (batch_size * self._n_heads) // n_parallel_heads - if self._use_python_loop or loop_hi == 1: - for idx in range(loop_hi): - loop_val = run_inner(idx, loop_val) - else: - loop_val = fastmath.fori_loop( - 0, loop_hi, run_inner, loop_val) + def forward_and_or_backward( + self, + inputs, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + See `forward` for a reference implementation of what this layer does. The + reference implementation is not very efficient, however, and this method + provides a more performant version. + + Args: + inputs: inputs to the attention layer tuple (qk, v, mask) + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). 
+ + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # TODO(b/148460708): reduce memory usage further + # TODO(kitaev): there should be a higher-level API (like vmap) that does + # batching, instead of needing 3 separate manual implementations here. + + # Notes regarding the implementation: + # (a) Multiple heads or examples are batched together. There are three + # different regimes possible: one head at a time (for long sequences and + # expensive attention types), several attention heads at a time (for + # long sequences but less-expensive attention types), and several + # examples at a time (for large batches of shorter sequences). For the + # time being, each of these regimes has its own code. + # (b) Python loops produce large computation graphs when jitted, so the + # default is to use a JAX loop instead. + # (c) No intermediate quantities are cached for the backward pass. Instead, + # the forward pass is re-computed when doing backprop. This approach is + # often called "checkpointing" or "rematerialization". When not all + # examples or heads fit in memory simultaneously, the implementation + # should be [FW-BW-1] and NOT [FW-BW-2], because the latter has worse + # memory locality. I don't think JAX autodiff can synthesize [FW-BW-1] + # automatically, so the looping for the backward pass is done manually. + # + # [FW-BW-1] for example, head in zip(examples, heads): + # forward(example, head) + # backward(example, head) # uses intermediates from forward + # + # [FW-BW-2] for example, head in zip(examples, heads): + # forward(example, head) + # for example, head in zip(examples, heads): + # backward(example, head) + + if self._masked: + qk, v, mask = inputs + batch_size = mask.shape[0] + else: + qk, v = inputs + mask = None + batch_size = qk.shape[0] // self._n_heads + batch_x_heads, seqlen, d_model = qk.shape + + compute_grad = output_grad is not None + assert compute_output or compute_grad, "No work to perform!" + + if not self._incremental: + forward_unbatched = functools.partial( + self.forward_unbatched, rng=rng, update_state=update_state + ) + else: + assert not compute_grad + + # The input to use_predict_mem is (qk, v) + inputs = (qk, v) + + if update_state: + inputs, state, q_start, new_mem, new_mem_end = self._use_predict_mem( + inputs, state + ) + else: + # This assumes that the memory stores all of the inputs, which would not + # be valid if doing backprop in mode 'predict' with long lengths. + new_mem_end, inputs, state = state + q_start = new_mem_end - seqlen + + # Reset qk and v to what use_predict_mem/state gave us. + qk, v = inputs + + forward_unbatched = functools.partial( + self._incremental_forward_unbatched, + q_start=fastmath.stop_gradient(q_start), + q_len=fastmath.stop_gradient(seqlen), + rng=rng, + update_state=update_state, + ) + + # Adjust degree of parallelism based on the batch size. 
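# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the [FW-BW-1] schedule from the
# note above. Each head's forward pass is recomputed and backpropagated
# immediately, so only one head's intermediates are live at a time (better
# memory locality than running all forwards, then all backwards).
import jax
import jax.numpy as jnp

def per_head_fn(x):                   # hypothetical one-head forward
    return jnp.tanh(x) * 2.0

def fw_bw_1(xs, cts):
    grads = []
    for x, ct in zip(xs, cts):        # for example, head in zip(...)
        _, vjp_fn = jax.vjp(per_head_fn, x)   # forward(example, head)
        grads.append(vjp_fn(ct)[0])           # backward(example, head)
    return grads

grads = fw_bw_1([jnp.ones(16)] * 3, [jnp.ones(16)] * 3)
# --------------------------------------------------------------------------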
+ n_parallel_heads = batch_size * self._n_heads + if self._n_parallel_heads and self._n_parallel_heads < n_parallel_heads: + n_parallel_heads = self._n_parallel_heads + + def tree_update(tree, indices, new_values): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_update(x, jax.numpy.index_exp[indices], y), + tree, + new_values, + ) + + def tree_add(tree, indices, addends): + return fastmath.nested_map_multiarg( + lambda x, y: fastmath.index_add(x, jax.numpy.index_exp[indices], y), + tree, + addends, + ) - (o_all, s_all, i_ct_all) = loop_val + if compute_grad: + inputs_is_differentiable = fastmath.nested_map( + lambda x: np.issubdtype(x.dtype, np.inexact), inputs + ) + + def split_differentiable(xs): + differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: x if is_differentiable else None, + xs, + inputs_is_differentiable, + ) + non_differentiable_xs = fastmath.nested_map_multiarg( + lambda x, is_differentiable: None if is_differentiable else x, + xs, + inputs_is_differentiable, + ) + return differentiable_xs, non_differentiable_xs + + def join_differentiable(differentiable_xs, non_differentiable_xs): + """Reconstitute inputs pytree from differentiable/non-d. partitions.""" + differentiable_leaves = fastmath.tree_leaves(differentiable_xs) + non_differentiable_leaves = fastmath.tree_leaves(non_differentiable_xs) + leaves = [] + for is_differentiable in fastmath.tree_leaves(inputs_is_differentiable): + if is_differentiable: + leaves.append(differentiable_leaves.pop(0)) + else: + leaves.append(non_differentiable_leaves.pop(0)) + assert not differentiable_leaves + assert not non_differentiable_leaves + tree, _ = fastmath.tree_unflatten(leaves, inputs) + return tree + + def vjp(fn, inp, *args, has_aux=False): + d_inp, nd_inp = split_differentiable(inp) + + def fn_closed_over_nd_inp(d_inp, *args): + inp = join_differentiable(d_inp, nd_inp) + return fn(inp, *args) + + return fastmath.vjp( + fn_closed_over_nd_inp, d_inp, *args, has_aux=has_aux + ) + + if n_parallel_heads != 1: + raise NotImplementedError( + "PureLSHSelfAttention is not implemented for n_parallel_heads != 1." 
+ ) + + def run_inner(idx, loop_val): + """Runs one slice of attention (for a single head).""" + o_all, s_all, i_ct_all = loop_val + example_idx = idx // self._n_heads + unused_head_idx = idx % self._n_heads + + s_h = fastmath.nested_map(lambda s: s[idx], state) + + if self._masked: + i_h = (qk[idx], v[idx], mask[example_idx]) + else: + i_h = (qk[idx], v[idx]) + + def forward_fn(i_h): + return forward_unbatched(*i_h, state=fastmath.stop_gradient(s_h)) + + if compute_grad: + o_h, backward_fn, s_h = vjp(forward_fn, i_h, has_aux=True) + ct_h = output_grad[idx] + assert o_h.shape == ct_h.shape + (i_ct_h,) = backward_fn(ct_h) + else: + o_h, s_h = forward_fn(i_h) + + if compute_output: + o_all = tree_update(o_all, idx, o_h) + if update_state: + s_all = tree_update(s_all, idx, s_h) + if compute_grad: + i_ct_all = tree_add(i_ct_all, idx, i_ct_h) + return (o_all, s_all, i_ct_all) + + o_all = s_all = i_ct_all = None + if compute_output: + o_all = np.zeros((batch_x_heads, seqlen, d_model), dtype=v.dtype) + if update_state: + s_all = state + if compute_grad: + i_ct_all = fastmath.nested_map(np.zeros_like, inputs) + i_ct_all, i_nondifferentiable_dummy_ct = split_differentiable(i_ct_all) - if compute_grad: - i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + loop_val = (o_all, s_all, i_ct_all) - if self._incremental and update_state: - s_all = (new_mem_end, new_mem, s_all) + assert (batch_size * self._n_heads) % n_parallel_heads == 0 + loop_hi = (batch_size * self._n_heads) // n_parallel_heads + if self._use_python_loop or loop_hi == 1: + for idx in range(loop_hi): + loop_val = run_inner(idx, loop_val) + else: + loop_val = fastmath.fori_loop(0, loop_hi, run_inner, loop_val) - return (o_all, s_all, i_ct_all) + (o_all, s_all, i_ct_all) = loop_val + + if compute_grad: + i_ct_all = join_differentiable(i_ct_all, i_nondifferentiable_dummy_ct) + + if self._incremental and update_state: + s_all = (new_mem_end, new_mem, s_all) + + return (o_all, s_all, i_ct_all) def _ProjectAndSplitHeads( # pylint: disable=invalid-name @@ -3272,522 +3721,632 @@ def _ProjectAndSplitHeads( # pylint: disable=invalid-name num_weights=2, sparsity=16, length_kernel_size=3, - weights_format='sparse', + weights_format="sparse", rotary_position_emb=False, - mode='train'): - """Creates the QK and V activations from input.""" - # There can be either two or three weights: - # two - qk and v or three - q, k, v - # If there are three, we want to average q and k and use that. - - # Weights can also be in 'heads' major format - (n_heads, d_model, d_head) - # this is used by efficient_attention.LSHSelfAttention and - # efficient_attention.SelfAttention - - # Or they can be in 'model' major format - (d_model, d_model), which is what - # tl._attention/CausalAttention etc use -- so use this format if we pretrain a - # model trained with those and finetuning with PureLSHSelfAttention. - - assert weights_format in ('heads', 'model', 'sparse') - - # When an earlier model was trained with 3 separate weights for Q, K, V - # projections with tl._attention/tl._causalAttention etc. - if weights_format == 'model' and num_weights == 3: - return cb.Serial( - # Create the raw Q, K, V projections. - cb.Branch( - core.Dense(d_model, use_bias=use_bias), - core.Dense(d_model, use_bias=use_bias), - core.Dense(d_model, use_bias=use_bias)), # q, k, v - # Optionally, rotate Q and K vectors if rotary embeddings are used. 
- cb.Parallel(rotary_pe.Rotate(), rotary_pe.Rotate(), None) - if rotary_position_emb else [], - # Average Q and K into one single QK tensor. - core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1), # qk, v - # Split heads and combine with batch dimension to get two tensors of - # (batch * n_heads, seq_len, d_head) shape. - cb.Parallel( - attention.SplitIntoHeads(n_heads), - attention.SplitIntoHeads(n_heads)) # qk, v - ) + mode="train", +): + """Creates the QK and V activations from input.""" + # There can be either two or three weights: + # two - qk and v or three - q, k, v + # If there are three, we want to average q and k and use that. + + # Weights can also be in 'heads' major format - (n_heads, d_model, d_head) + # this is used by efficient_attention.LSHSelfAttention and + # efficient_attention.SelfAttention + + # Or they can be in 'model' major format - (d_model, d_model), which is what + # tl._attention/CausalAttention etc use -- so use this format if we pretrain a + # model trained with those and finetuning with PureLSHSelfAttention. + + assert weights_format in ("heads", "model", "sparse") + + # When an earlier model was trained with 3 separate weights for Q, K, V + # projections with tl._attention/tl._causalAttention etc. + if weights_format == "model" and num_weights == 3: + return cb.Serial( + # Create the raw Q, K, V projections. + cb.Branch( + core.Dense(d_model, use_bias=use_bias), + core.Dense(d_model, use_bias=use_bias), + core.Dense(d_model, use_bias=use_bias), + ), # q, k, v + # Optionally, rotate Q and K vectors if rotary embeddings are used. + cb.Parallel(rotary_pe.Rotate(), rotary_pe.Rotate(), None) + if rotary_position_emb + else [], + # Average Q and K into one single QK tensor. + core.Fn("QKAvg", lambda x, y: (x + y) / 2.0, n_out=1), # qk, v + # Split heads and combine with batch dimension to get two tensors of + # (batch * n_heads, seq_len, d_head) shape. + cb.Parallel( + attention.SplitIntoHeads(n_heads), attention.SplitIntoHeads(n_heads) + ), # qk, v + ) - if weights_format == 'sparse' and num_weights == 3: - d_module = d_model // sparsity - # This layer matches sparsity.MultiplicativeConvCausalAttention, - # see there for more explanation. - # TODO(lukaszkaiser): unify code so that we don't duplicate so much. - return cb.Serial( - cb.Select([0, 0]), # duplicate activations - sp.FactoredDense(sparsity, d_model, d_model), - cb.Select([0, 0, 0]), # use for q, k, v - cb.Parallel( - [sp.LocallyConvDense(sparsity, d_module, mode=mode, - kernel_size=3, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], - [sp.LocallyConvDense(sparsity, d_module, mode=mode, - kernel_size=3, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], - [cb.Select([0], n_in=2), - sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], - ), - core.Fn('QKAvg', lambda x, y: (x + y) / 2.0, n_out=1), - ) + if weights_format == "sparse" and num_weights == 3: + d_module = d_model // sparsity + # This layer matches sparsity.MultiplicativeConvCausalAttention, + # see there for more explanation. + # TODO(lukaszkaiser): unify code so that we don't duplicate so much. 
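# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: what the 'model'-format,
# three-weight branch above computes, written out in plain jax.numpy
# (rotary rotation omitted). Q and K are averaged into a single shared-QK
# tensor, then heads are split and folded into the batch dimension.
import jax.numpy as jnp

def project_and_split(x, w_q, w_k, w_v, n_heads):
    # x: (batch, seq, d_model); each w: (d_model, d_model).
    q, k, v = jnp.dot(x, w_q), jnp.dot(x, w_k), jnp.dot(x, w_v)
    qk = (q + k) / 2.0                           # QKAvg

    def split(t):                                # roughly SplitIntoHeads
        b, s, d = t.shape
        t = jnp.reshape(t, (b, s, n_heads, d // n_heads))
        t = jnp.transpose(t, (0, 2, 1, 3))
        return jnp.reshape(t, (b * n_heads, s, d // n_heads))

    return split(qk), split(v)                   # (batch*n_heads, seq, d_head)

qk, v = project_and_split(jnp.ones((2, 10, 16)),
                          jnp.eye(16), jnp.eye(16), jnp.eye(16), n_heads=4)
# --------------------------------------------------------------------------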
+ return cb.Serial( + cb.Select([0, 0]), # duplicate activations + sp.FactoredDense(sparsity, d_model, d_model), + cb.Select([0, 0, 0]), # use for q, k, v + cb.Parallel( + [ + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + [ + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + [ + cb.Select([0], n_in=2), + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=1, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + ), + core.Fn("QKAvg", lambda x, y: (x + y) / 2.0, n_out=1), + ) - if weights_format == 'sparse' and num_weights == 2: - d_module = d_model // sparsity - # This layer matches sparsity.MultiplicativeConvCausalAttention, - # see there for more explanation. - # TODO(lukaszkaiser): unify code so that we don't duplicate so much. - return cb.Serial( - cb.Select([0, 0]), # pre-qkv, pre-v-for-concat - sp.FactoredDense(sparsity, d_model, d_model), # shared q k - cb.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat - sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads), - cb.Parallel( - [], - [cb.Select([0], n_in=2), - sp.LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1, - length_kernel_size=length_kernel_size), - attention.SplitIntoHeads(n_heads)], + if weights_format == "sparse" and num_weights == 2: + d_module = d_model // sparsity + # This layer matches sparsity.MultiplicativeConvCausalAttention, + # see there for more explanation. + # TODO(lukaszkaiser): unify code so that we don't duplicate so much. + return cb.Serial( + cb.Select([0, 0]), # pre-qkv, pre-v-for-concat + sp.FactoredDense(sparsity, d_model, d_model), # shared q k + cb.Select([0, 0]), # pre-qk, pre-v, pre-v-for-concat + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=3, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + cb.Parallel( + [], + [ + cb.Select([0], n_in=2), + sp.LocallyConvDense( + sparsity, + d_module, + mode=mode, + kernel_size=1, + length_kernel_size=length_kernel_size, + ), + attention.SplitIntoHeads(n_heads), + ], + ), ) - ) - # We want to train from scratch and have only two weights, w_qk and w_v. - if weights_format == 'model' and num_weights == 2: - return cb.Branch( - [ - core.Dense(d_model, use_bias=use_bias), - rotary_pe.Rotate() if rotary_position_emb else [], - attention.SplitIntoHeads(n_heads) - ], - [ - core.Dense(d_model, use_bias=use_bias), - attention.SplitIntoHeads(n_heads) - ], - ) + # We want to train from scratch and have only two weights, w_qk and w_v. 
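# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: a standard rotary-position
# rotation, for intuition about the optional rotary_pe.Rotate() in the
# two-weight branch below. Whether Rotate() uses exactly this pairing and
# base is an assumption of this sketch, not something the patch states.
import jax.numpy as jnp

def rotate(x, base=10000.0):
    # x: (batch, seq, d) with d even; rotates (even, odd) feature pairs by
    # position-dependent angles so dot products depend on relative offsets.
    b, s, d = x.shape
    inv_freq = 1.0 / (base ** (jnp.arange(0, d, 2) / d))
    ang = jnp.arange(s)[:, None] * inv_freq[None, :]       # (seq, d/2)
    sin, cos = jnp.sin(ang), jnp.cos(ang)
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return jnp.reshape(out, (b, s, d))

y = rotate(jnp.ones((2, 10, 16)))
# --------------------------------------------------------------------------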
+ if weights_format == "model" and num_weights == 2: + return cb.Branch( + [ + core.Dense(d_model, use_bias=use_bias), + rotary_pe.Rotate() if rotary_position_emb else [], + attention.SplitIntoHeads(n_heads), + ], + [core.Dense(d_model, use_bias=use_bias), attention.SplitIntoHeads(n_heads)], + ) - assert weights_format == 'head' + assert weights_format == "head" - raise NotImplementedError('TODO(afrozm): Implement this when we want to use ' - 'checkpoints trained with LSHSelfAttention or ' - 'SelfAttention') + raise NotImplementedError( + "TODO(afrozm): Implement this when we want to use " + "checkpoints trained with LSHSelfAttention or " + "SelfAttention" + ) class MixedLSHSelfAttention(base.Layer): - """LSH attention mixed with standard attention used until std_length.""" - - def __init__(self, - n_heads=1, - d_qk=64, - d_v=64, - causal=False, - masked=False, - std_length=None, - mode='train', - output_dropout=0.0, - attention_dropout=0.0, - force_no_dropout=False, - **pure_lsh_implementation_kwargs): - # This class could be replaced with a Branch and tl.Fn(..) selecting - # one of the arguments based on the class. But, similarly to the Wrapper - # below, we need forward_and_backward currently to pass remembered state - # back to the PureLSH layer. We should switch that to the other Remember - # mechanism used for the SparseFF layer (and clarify and document that too). - # Once this is done, we can remove this and the Wrapper class. - attention_dropout = 0.0 if force_no_dropout else attention_dropout - output_dropout = 0.0 if force_no_dropout else output_dropout - self._lsha = PureLSHSelfAttention(n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - masked=masked, - mode=mode, - output_dropout=output_dropout, - attention_dropout=attention_dropout, - **pure_lsh_implementation_kwargs) - if causal: - pure_attn = attention.DotProductCausalAttention - preprocess = core.Fn('dup_shared_qk', lambda q, v: (q, q, v), n_out=3) - else: - pure_attn = attention.DotProductAttention - def _add_heads_to_mask(m): - m_with_heads = np.reshape(m, (m.shape[0], 1, m.shape[1])) - m_with_heads = np.broadcast_to(m_with_heads, - (m.shape[0], n_heads, m.shape[1])) - return np.reshape(m_with_heads, (-1, 1, m.shape[1])) - preprocess = core.Fn('dup_shared_qk_and_make_mask', - lambda q, v, m: (q, q, v, _add_heads_to_mask(m)), - n_out=4) - self._stda = cb.Serial( - preprocess, - pure_attn(dropout=attention_dropout, mode=mode) - ) - self._std_length = std_length - self._sublayers = [self._lsha, self._stda] - - if self._stda.n_in != self._lsha.n_in: - raise ValueError(f'n_in diff: {self._stda.n_in} != {self._lsha.n_in}') - if self._stda.n_out != self._lsha.n_out: - raise ValueError(f'n_out diff: {self._stda.n_out} != {self._lsha.n_out}') - super().__init__(n_in=self._stda.n_in, n_out=self._stda.n_out) - - def init_weights_and_state(self, input_signature): - """Initializes weights and state for inputs with the given signature.""" - states = [] - for sublayer in [self._lsha, self._stda]: - unused_weights_or_cache_marker, state_or_cache_marker = sublayer.init( - input_signature, use_cache=False) - states.append(state_or_cache_marker) - self.state = tuple(states) - self.weights = () # Wrapper forward_and_backward assumes this is () - - def forward(self, xs): - """Executes this layer as part of a forward pass through the model.""" - rng1, rng2 = fastmath.random.split(self.rng, 2) - l = xs[0].shape[1] if isinstance(xs, tuple) else xs.shape[1] - if self._std_length is None or l > self._std_length: - s = self.state[0] - 
outputs, s = self._lsha.pure_fn(xs, (), s, rng1, use_cache=True) - self.state = (s, self.state[1]) - else: - s = self.state[1] - w = ((), ()) # std attention is a Serial(Dup, DotProduct), needs 2 () - outputs, s = self._stda.pure_fn(xs, w, s, rng2, use_cache=True) - self.state = (self.state[0], s) - return outputs - - def forward_and_or_backward(self, inputs, state, rng, - output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes.""" - assert compute_output - assert not update_state - - l = inputs[0].shape[1] if isinstance(inputs, tuple) else inputs.shape[1] - rng1, rng2 = fastmath.random.split(rng, 2) - if self._std_length is None or l > self._std_length: - # Run the LSH layer - s = state[0] - (out, unused_new_s, grads_inputs) = self._lsha.forward_and_or_backward( - inputs, s, rng1, output_grad=output_grad, - compute_output=True, update_state=False) - else: - # Run the standard layer - s, w = state[1], ((), ()) - out, std_vjp_fn, unused_new_s = fastmath.vjp( - self._stda.pure_fn, inputs, w, s, rng2, has_aux=True) - if output_grad is not None: - grads_inputs, _, _, _ = std_vjp_fn(output_grad) - else: - grads_inputs = None + """LSH attention mixed with standard attention used until std_length.""" + + def __init__( + self, + n_heads=1, + d_qk=64, + d_v=64, + causal=False, + masked=False, + std_length=None, + mode="train", + output_dropout=0.0, + attention_dropout=0.0, + force_no_dropout=False, + **pure_lsh_implementation_kwargs, + ): + # This class could be replaced with a Branch and tl.Fn(..) selecting + # one of the arguments based on the class. But, similarly to the Wrapper + # below, we need forward_and_backward currently to pass remembered state + # back to the PureLSH layer. We should switch that to the other Remember + # mechanism used for the SparseFF layer (and clarify and document that too). + # Once this is done, we can remove this and the Wrapper class. 
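# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the routing rule this class
# implements. Short inputs (up to std_length) go through exact softmax
# attention; longer ones go through LSH attention. The branch is taken in
# Python, so each input length compiles its own graph. Names hypothetical.
import jax.numpy as jnp

def mixed_attention(xs, std_length, lsh_fn, std_fn):
    seq_len = xs[0].shape[1] if isinstance(xs, tuple) else xs.shape[1]
    if std_length is None or seq_len > std_length:
        return lsh_fn(xs)          # approximate, memory-efficient path
    return std_fn(xs)              # exact path for short sequences

out = mixed_attention(jnp.ones((2, 8, 16)), 64,
                      lsh_fn=lambda x: x, std_fn=lambda x: x * 1.0)
# --------------------------------------------------------------------------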
+ attention_dropout = 0.0 if force_no_dropout else attention_dropout + output_dropout = 0.0 if force_no_dropout else output_dropout + self._lsha = PureLSHSelfAttention( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + masked=masked, + mode=mode, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + **pure_lsh_implementation_kwargs, + ) + if causal: + pure_attn = attention.DotProductCausalAttention + preprocess = core.Fn("dup_shared_qk", lambda q, v: (q, q, v), n_out=3) + else: + pure_attn = attention.DotProductAttention + + def _add_heads_to_mask(m): + m_with_heads = np.reshape(m, (m.shape[0], 1, m.shape[1])) + m_with_heads = np.broadcast_to( + m_with_heads, (m.shape[0], n_heads, m.shape[1]) + ) + return np.reshape(m_with_heads, (-1, 1, m.shape[1])) + + preprocess = core.Fn( + "dup_shared_qk_and_make_mask", + lambda q, v, m: (q, q, v, _add_heads_to_mask(m)), + n_out=4, + ) + self._stda = cb.Serial( + preprocess, pure_attn(dropout=attention_dropout, mode=mode) + ) + self._std_length = std_length + self._sublayers = [self._lsha, self._stda] + + if self._stda.n_in != self._lsha.n_in: + raise ValueError(f"n_in diff: {self._stda.n_in} != {self._lsha.n_in}") + if self._stda.n_out != self._lsha.n_out: + raise ValueError(f"n_out diff: {self._stda.n_out} != {self._lsha.n_out}") + super().__init__(n_in=self._stda.n_in, n_out=self._stda.n_out) + + def init_weights_and_state(self, input_signature): + """Initializes weights and state for inputs with the given signature.""" + states = [] + for sublayer in [self._lsha, self._stda]: + unused_weights_or_cache_marker, state_or_cache_marker = sublayer.init( + input_signature, use_cache=False + ) + states.append(state_or_cache_marker) + self.state = tuple(states) + self.weights = () # Wrapper forward_and_backward assumes this is () + + def forward(self, xs): + """Executes this layer as part of a forward pass through the model.""" + rng1, rng2 = fastmath.random.split(self.rng, 2) + l = xs[0].shape[1] if isinstance(xs, tuple) else xs.shape[1] + if self._std_length is None or l > self._std_length: + s = self.state[0] + outputs, s = self._lsha.pure_fn(xs, (), s, rng1, use_cache=True) + self.state = (s, self.state[1]) + else: + s = self.state[1] + w = ((), ()) # std attention is a Serial(Dup, DotProduct), needs 2 () + outputs, s = self._stda.pure_fn(xs, w, s, rng2, use_cache=True) + self.state = (self.state[0], s) + return outputs + + def forward_and_or_backward( + self, + inputs, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes.""" + assert compute_output + assert not update_state + + l = inputs[0].shape[1] if isinstance(inputs, tuple) else inputs.shape[1] + rng1, rng2 = fastmath.random.split(rng, 2) + if self._std_length is None or l > self._std_length: + # Run the LSH layer + s = state[0] + (out, unused_new_s, grads_inputs) = self._lsha.forward_and_or_backward( + inputs, + s, + rng1, + output_grad=output_grad, + compute_output=True, + update_state=False, + ) + else: + # Run the standard layer + s, w = state[1], ((), ()) + out, std_vjp_fn, unused_new_s = fastmath.vjp( + self._stda.pure_fn, inputs, w, s, rng2, has_aux=True + ) + if output_grad is not None: + grads_inputs, _, _, _ = std_vjp_fn(output_grad) + else: + grads_inputs = None - return (out, None, grads_inputs) + return (out, None, grads_inputs) class PureLSHSelfAttentionWrapper(cb.Serial): - """Pure LSH serial.""" - - def __init__(self, - n_heads=1, - d_qk=64, - d_v=64, - 
causal=False, - masked=False, - output_dropout=0.0, - attention_dropout=0.0, - pure_lsh_implementation=None, - bias=True, - mode='train', - num_weights=3, - sparsity=16, - weights_format='model', - rotary_position_emb=False, - **pure_lsh_implementation_kwargs): - d_model = d_qk * n_heads - self._qkv = _ProjectAndSplitHeads( - d_model, - n_heads, - bias, - num_weights=num_weights, - sparsity=sparsity, - weights_format=weights_format, - rotary_position_emb=rotary_position_emb, - mode=mode) - self._attn = pure_lsh_implementation(n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - masked=masked, - mode=mode, - output_dropout=output_dropout, - attention_dropout=attention_dropout, - **pure_lsh_implementation_kwargs) - self._merge = attention.MergeHeads(n_heads) - if weights_format != 'sparse': - self._dense = core.Dense(d_model, use_bias=bias) - super().__init__(self._qkv, self._attn, self._merge, self._dense) - else: - self._dense = None - super().__init__(self._qkv, self._attn, self._merge) + """Pure LSH serial.""" + + def __init__( + self, + n_heads=1, + d_qk=64, + d_v=64, + causal=False, + masked=False, + output_dropout=0.0, + attention_dropout=0.0, + pure_lsh_implementation=None, + bias=True, + mode="train", + num_weights=3, + sparsity=16, + weights_format="model", + rotary_position_emb=False, + **pure_lsh_implementation_kwargs, + ): + d_model = d_qk * n_heads + self._qkv = _ProjectAndSplitHeads( + d_model, + n_heads, + bias, + num_weights=num_weights, + sparsity=sparsity, + weights_format=weights_format, + rotary_position_emb=rotary_position_emb, + mode=mode, + ) + self._attn = pure_lsh_implementation( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + masked=masked, + mode=mode, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + **pure_lsh_implementation_kwargs, + ) + self._merge = attention.MergeHeads(n_heads) + if weights_format != "sparse": + self._dense = core.Dense(d_model, use_bias=bias) + super().__init__(self._qkv, self._attn, self._merge, self._dense) + else: + self._dense = None + super().__init__(self._qkv, self._attn, self._merge) - def forward_and_or_backward(self, inputs, weights, state, rng, - output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. + def forward_and_or_backward( + self, + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). 
+ - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + assert compute_output + assert not update_state + assert output_grad is not None + + rngs = fastmath.random.split(rng, 4) + # Layer order forward: self._qkv, self._attn, self._merge, self._dense + # Use forward_and_or_backward for attn. + + qkv_output, qkv_vjp_fn, unused_qkv_new_state = fastmath.vjp( + self._qkv.pure_fn, inputs, weights[0], state[0], rngs[0], has_aux=True + ) - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. + attn_output, _, _ = self._attn.forward_and_or_backward( + qkv_output, + state[1], + rngs[1], + output_grad=None, + compute_output=True, + update_state=False, + ) - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - assert compute_output - assert not update_state - assert output_grad is not None - - rngs = fastmath.random.split(rng, 4) - # Layer order forward: self._qkv, self._attn, self._merge, self._dense - # Use forward_and_or_backward for attn. - - qkv_output, qkv_vjp_fn, unused_qkv_new_state = fastmath.vjp( - self._qkv.pure_fn, inputs, weights[0], state[0], - rngs[0], has_aux=True) - - attn_output, _, _ = self._attn.forward_and_or_backward( - qkv_output, state[1], rngs[1], output_grad=None, - compute_output=True, update_state=False) - - merge_output, merge_vjp_fn, unused_merge_new_state = fastmath.vjp( - self._merge.pure_fn, attn_output, weights[2], state[2], rngs[2], - has_aux=True) - - if self._dense is not None: - dense_output, dense_vjp_fn, unused_dense_new_state = fastmath.vjp( - self._dense.pure_fn, merge_output, weights[3], state[3], - rngs[3], has_aux=True) - - # Now backward. - if self._dense is not None: - dense_grads_inputs, dense_grads_weights, _, _ = dense_vjp_fn( - output_grad) - else: - dense_grads_inputs = output_grad - merge_grads_inputs, merge_grads_weights, _, _ = merge_vjp_fn( - dense_grads_inputs) - - # Use forward_and_or_backward for attn. - (attn_output, _, attn_grads_inputs) = self._attn.forward_and_or_backward( - qkv_output, state[1], rngs[1], output_grad=merge_grads_inputs, - compute_output=True, update_state=False) - - # Backward for qkv layer. - qkv_grad_inputs, qkv_grads_weights, _, _ = qkv_vjp_fn(attn_grads_inputs) - - if self._dense is None: - grads_weights = (qkv_grads_weights, - (), - merge_grads_weights) - else: - grads_weights = (qkv_grads_weights, - (), - merge_grads_weights, - dense_grads_weights) - - # Output is (output, new_state, inputs_grad, weights_grad). - # new_state is None because update_state is False. 
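# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the manual VJP chaining this
# wrapper performs. Forward runs qkv -> attn -> merge -> dense keeping each
# vjp closure; backward pulls the cotangent through in reverse order (the
# real attn step recomputes its forward via forward_and_or_backward).
import jax
import jax.numpy as jnp

def chain_forward_backward(x, qkv_fn, attn_fn, merge_fn, dense_fn, out_ct):
    h1, vjp_qkv = jax.vjp(qkv_fn, x)
    h2, vjp_attn = jax.vjp(attn_fn, h1)
    h3, vjp_merge = jax.vjp(merge_fn, h2)
    y, vjp_dense = jax.vjp(dense_fn, h3)
    (ct3,) = vjp_dense(out_ct)
    (ct2,) = vjp_merge(ct3)
    (ct1,) = vjp_attn(ct2)
    (ct0,) = vjp_qkv(ct1)
    return y, ct0

y, dx = chain_forward_backward(jnp.ones((2, 8)), jnp.tanh, jnp.sin,
                               jnp.cos, jnp.exp, jnp.ones((2, 8)))
# --------------------------------------------------------------------------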
- if self._dense is None: - return (merge_output, None, qkv_grad_inputs, grads_weights) - else: - return (dense_output, None, qkv_grad_inputs, grads_weights) + merge_output, merge_vjp_fn, unused_merge_new_state = fastmath.vjp( + self._merge.pure_fn, + attn_output, + weights[2], + state[2], + rngs[2], + has_aux=True, + ) + + if self._dense is not None: + dense_output, dense_vjp_fn, unused_dense_new_state = fastmath.vjp( + self._dense.pure_fn, + merge_output, + weights[3], + state[3], + rngs[3], + has_aux=True, + ) + + # Now backward. + if self._dense is not None: + dense_grads_inputs, dense_grads_weights, _, _ = dense_vjp_fn(output_grad) + else: + dense_grads_inputs = output_grad + merge_grads_inputs, merge_grads_weights, _, _ = merge_vjp_fn(dense_grads_inputs) + + # Use forward_and_or_backward for attn. + (attn_output, _, attn_grads_inputs) = self._attn.forward_and_or_backward( + qkv_output, + state[1], + rngs[1], + output_grad=merge_grads_inputs, + compute_output=True, + update_state=False, + ) + + # Backward for qkv layer. + qkv_grad_inputs, qkv_grads_weights, _, _ = qkv_vjp_fn(attn_grads_inputs) + + if self._dense is None: + grads_weights = (qkv_grads_weights, (), merge_grads_weights) + else: + grads_weights = ( + qkv_grads_weights, + (), + merge_grads_weights, + dense_grads_weights, + ) + + # Output is (output, new_state, inputs_grad, weights_grad). + # new_state is None because update_state is False. + if self._dense is None: + return (merge_output, None, qkv_grad_inputs, grads_weights) + else: + return (dense_output, None, qkv_grad_inputs, grads_weights) class EncDecAttention(EfficientAttentionBase): - """Memory-efficient encoder-decoder attention.""" - - def __init__(self, - n_heads=2, d_qk=64, d_v=64, - masked=True, - mode='train', - attention_dropout=0.0, - output_dropout=0.0, - n_parallel_heads=None, - use_python_loop=False, - use_reference_code=False, - ): - super().__init__( - n_heads=n_heads, - n_in=(3 if masked else 2), - n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop, - use_reference_code=use_reference_code, + """Memory-efficient encoder-decoder attention.""" + + def __init__( + self, + n_heads=2, + d_qk=64, + d_v=64, + masked=True, + mode="train", + attention_dropout=0.0, + output_dropout=0.0, + n_parallel_heads=None, + use_python_loop=False, + use_reference_code=False, + ): + super().__init__( + n_heads=n_heads, + n_in=(3 if masked else 2), + n_parallel_heads=n_parallel_heads, + use_python_loop=use_python_loop, + use_reference_code=use_reference_code, ) - self._d_qk = d_qk - self._d_v = d_v - self._masked = masked - self._mode = mode - if mode == 'train': - self._attention_dropout = attention_dropout - self._output_dropout = output_dropout - else: - self._attention_dropout = 0.0 - self._output_dropout = 0.0 - - def _kernel_initializer(self, shape, rng): - # Attention uses Glorot uniform initalization with respect to the *total* - # dimension of queries/key/values across all heads. We initialize one head - # at a time in this class, so init.GlorotUniformInitializer won't work. - # This initialization type is for parity with previous Trax & tensor2tensor - # Transformers; it's not clear if it's strictly needed for model accuracy. 
- lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) - return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) - - def create_weights_unbatched(self, input_signature, rng): - d_model = input_signature[0].shape[-1] - d_kv_antecedent = input_signature[1].shape[-1] - rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) - w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) - w_k = self._kernel_initializer((d_kv_antecedent, self._d_qk), rng_k) - w_v = self._kernel_initializer((d_kv_antecedent, self._d_v), rng_v) - w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) - return (w_q, w_k, w_v, w_o) - - def forward_unbatched(self, q_antecedent, kv_antecedent, mask=None, *, - weights, state, rng, update_state): - del update_state - attend_rng, output_rng = fastmath.random.split(rng) - w_q, w_k, w_v, w_o = weights - - q = np.matmul(q_antecedent, w_q) - k = np.matmul(kv_antecedent, w_k) - v = np.matmul(kv_antecedent, w_v) - - if not self._masked: - assert mask is None - q_info = kv_info = mask_fn = None - else: - # mask is a boolean array (True means "is valid token") - assert mask is not None - q_info = None - kv_info = (~mask).astype(np.int32) # pylint: disable=invalid-unary-operand-type - def mask_fn(dots, q_info, kv_info): - del q_info - mask = kv_info.astype(np.float32) - dots = dots - 1e9 * mask - return dots + self._d_qk = d_qk + self._d_v = d_v + self._masked = masked + self._mode = mode + if mode == "train": + self._attention_dropout = attention_dropout + self._output_dropout = output_dropout + else: + self._attention_dropout = 0.0 + self._output_dropout = 0.0 + + def _kernel_initializer(self, shape, rng): + # Attention uses Glorot uniform initalization with respect to the *total* + # dimension of queries/key/values across all heads. We initialize one head + # at a time in this class, so init.GlorotUniformInitializer won't work. + # This initialization type is for parity with previous Trax & tensor2tensor + # Transformers; it's not clear if it's strictly needed for model accuracy. 
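# --------------------------------------------------------------------------
# Illustrative sketch, outside the patch: the per-head variant of Glorot
# uniform computed by the `lim` formula just below. Standard Glorot for one
# head would use fan_out = d_head; here fan_out counts all heads, matching
# an initializer applied once to the full d_model = d_head * n_heads
# projection. Plain jax stands in for trax.fastmath.
import jax
import jax.numpy as jnp

def per_head_glorot_uniform(shape, n_heads, rng):
    fan_in, d_head = shape                 # one head's projection matrix
    lim = jnp.sqrt(6.0 / (fan_in + d_head * n_heads))
    return jax.random.uniform(rng, shape, jnp.float32, -lim, lim)

w_q = per_head_glorot_uniform((512, 64), n_heads=8,
                              rng=jax.random.PRNGKey(0))
# --------------------------------------------------------------------------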
+ lim = np.sqrt(6.0 / (shape[0] + shape[1] * self._n_heads)) + return fastmath.random.uniform(rng, shape, np.float32, -lim, lim) + + def create_weights_unbatched(self, input_signature, rng): + d_model = input_signature[0].shape[-1] + d_kv_antecedent = input_signature[1].shape[-1] + rng_q, rng_k, rng_v, rng_o = fastmath.random.split(rng, 4) + w_q = self._kernel_initializer((d_model, self._d_qk), rng_q) + w_k = self._kernel_initializer((d_kv_antecedent, self._d_qk), rng_k) + w_v = self._kernel_initializer((d_kv_antecedent, self._d_v), rng_v) + w_o = np.transpose(self._kernel_initializer((d_model, self._d_v), rng_o)) + return (w_q, w_k, w_v, w_o) + + def forward_unbatched( + self, + q_antecedent, + kv_antecedent, + mask=None, + *, + weights, + state, + rng, + update_state, + ): + del update_state + attend_rng, output_rng = fastmath.random.split(rng) + w_q, w_k, w_v, w_o = weights - o, _ = attend( - q, k, v, - mask_fn=mask_fn, q_info=q_info, kv_info=kv_info, - dropout=self._attention_dropout, rng=attend_rng, + q = np.matmul(q_antecedent, w_q) + k = np.matmul(kv_antecedent, w_k) + v = np.matmul(kv_antecedent, w_v) + + if not self._masked: + assert mask is None + q_info = kv_info = mask_fn = None + else: + # mask is a boolean array (True means "is valid token") + assert mask is not None + q_info = None + kv_info = (~mask).astype( + np.int32 + ) # pylint: disable=invalid-unary-operand-type + + def mask_fn(dots, q_info, kv_info): + del q_info + mask = kv_info.astype(np.float32) + dots = dots - 1e9 * mask + return dots + + o, _ = attend( + q, + k, + v, + mask_fn=mask_fn, + q_info=q_info, + kv_info=kv_info, + dropout=self._attention_dropout, + rng=attend_rng, ) - out = np.matmul(o, w_o) - out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) - return out, state + out = np.matmul(o, w_o) + out = apply_broadcasted_dropout(out, self._output_dropout, output_rng) + return out, state class LSHFF(base.Layer): - """Feed-forward block with LSH. - - The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense - that takes an input, makes it of size d_ff (usually larger than it was) and - then brings it back to the original size after Relu. It is commonly used in - Transformer models where it often accounts for most of the trainable weights. - - The original block can be slow in decoding due to the need to fetch a lot of - weights from memory. The LSH block aims to exploit this sparsity. So in the - first Dense(d_ff) layer, instead of making a full matrix multiplication, - this block only multiplies by the parts of the weights matrix that have - the highest chance to give non-0 after Relu. This is determined by taking - a number of locality-sensitive hashes and masking to only include weights - that have one hash identical to the multiplied element. - """ - - def __init__(self, d_ff, n_buckets, n_hashes=4, mode='train', - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6)): - """Returns a LSH feed-forward block.""" - super().__init__(name=f'LSHFF_{d_ff}') - self._mode = mode - self._d_ff = d_ff - self._n_buckets = n_buckets - self._n_hashes = n_hashes - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - - def forward(self, x): - """Executes this layer as part of a forward pass through the model. - - Args: - x: Tensor of same shape and dtype as the input signature used to - initialize this layer. - - Returns: - Tensor of same shape and dtype as the input. 
+ """Feed-forward block with LSH. + + The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense + that takes an input, makes it of size d_ff (usually larger than it was) and + then brings it back to the original size after Relu. It is commonly used in + Transformer models where it often accounts for most of the trainable weights. + + The original block can be slow in decoding due to the need to fetch a lot of + weights from memory. The LSH block aims to exploit this sparsity. So in the + first Dense(d_ff) layer, instead of making a full matrix multiplication, + this block only multiplies by the parts of the weights matrix that have + the highest chance to give non-0 after Relu. This is determined by taking + a number of locality-sensitive hashes and masking to only include weights + that have one hash identical to the multiplied element. """ - w1, w2, b2 = self.weights - x_shape = x.shape - x = np.reshape(x, [-1, x_shape[-1]]) # Easier to operate on flattened x. - - # Hash x into hash buckets; x_buckets is [n_hashes, joint_batch]. - x_buckets, _ = hash_vecs(x, self._n_buckets, self._n_hashes, self.rng) - - # Hash w1 into hash buckets; w1_buckets is [n_hashes, d_ff]. - # Note that we use the same self.rng - so the same hash vectors as for x. - w1_buckets, _ = hash_vecs(w1, self._n_buckets, self._n_hashes, self.rng) - - # Create a mask to determine which x's have the same hash as which w1's. - # First: just subtract the hashes and make them non-negative. - hash_mask = (x_buckets[:, :, None] - w1_buckets[:, None, :])**2 - hash_mask = fastmath.stop_gradient(hash_mask) # make sure no gradients here - # hash_mask is [n_hashes, joint_batch, d_ff], 0 iff hashes were equal - hash_mask = 1 - np.minimum(hash_mask, 1) # now 1 if equal, 0 otherwise - # we now sum over n_hashes and use min, it's 1 iff any of n_hashes was equal - hash_mask = np.minimum(np.sum(hash_mask, axis=0), 1) - hash_mask = hash_mask.astype(np.float32) # convert to float to use mask - - # First dense layer of the block, with hash masking. - mid = np.dot(x, w1.T) * hash_mask # [joint_batch, d_ff] - - # Relu and the second dense layer, as in a standard feed-forward block. - # Note: we merge the second block into this layer because of future plans, - # not anything implemented yet. The potential gain would be as follows: - # in predict mode, we would pre-hash (once) both w1 and w2 and only do - # matmuls (and memory copies) for the parts that correspond to the hash - # of the input. The hash of w1 determines which parts of Relu are 0, so - # it also determines which parts of w2 can be skipped. 
- relu = np.where(mid <= 0, np.zeros_like(mid), mid)
- res = np.dot(relu, w2) + b2
- return np.reshape(res, x_shape)  # un-flatten if needed
-
- def init_weights_and_state(self, input_signature):
- """Randomly initializes this layer's weights."""
- d_model = input_signature.shape[-1]
- shape_w1 = (self._d_ff, d_model)
- shape_w2 = (self._d_ff, d_model)
- shape_b2 = (d_model,)
-
- rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)
- w1 = self._kernel_initializer(shape_w1, rng_w1)
- w2 = self._kernel_initializer(shape_w2, rng_w2)
- b2 = self._bias_initializer(shape_b2, rng_b2)
- self.weights = (w1, w2, b2)
+
+ def __init__(
+ self,
+ d_ff,
+ n_buckets,
+ n_hashes=4,
+ mode="train",
+ kernel_initializer=init.GlorotUniformInitializer(),
+ bias_initializer=init.RandomNormalInitializer(1e-6),
+ ):
+ """Returns an LSH feed-forward block."""
+ super().__init__(name=f"LSHFF_{d_ff}")
+ self._mode = mode
+ self._d_ff = d_ff
+ self._n_buckets = n_buckets
+ self._n_hashes = n_hashes
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ def forward(self, x):
+ """Executes this layer as part of a forward pass through the model.
+
+ Args:
+ x: Tensor of same shape and dtype as the input signature used to
+ initialize this layer.
+
+ Returns:
+ Tensor of same shape and dtype as the input.
+ """
+ w1, w2, b2 = self.weights
+ x_shape = x.shape
+ x = np.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+
+ # Hash x into hash buckets; x_buckets is [n_hashes, joint_batch].
+ x_buckets, _ = hash_vecs(x, self._n_buckets, self._n_hashes, self.rng)
+
+ # Hash w1 into hash buckets; w1_buckets is [n_hashes, d_ff].
+ # Note that we use the same self.rng - so the same hash vectors as for x.
+ w1_buckets, _ = hash_vecs(w1, self._n_buckets, self._n_hashes, self.rng)
+
+ # Create a mask to determine which x's have the same hash as which w1's.
+ # First: subtract the hashes and square to make the differences
+ # non-negative.
+ hash_mask = (x_buckets[:, :, None] - w1_buckets[:, None, :]) ** 2
+ hash_mask = fastmath.stop_gradient(hash_mask)  # make sure no gradients here
+ # hash_mask is [n_hashes, joint_batch, d_ff], 0 iff hashes were equal
+ hash_mask = 1 - np.minimum(hash_mask, 1)  # now 1 if equal, 0 otherwise
+ # Sum over n_hashes and clip with min: the result is 1 iff any of the
+ # n_hashes was equal.
+ hash_mask = np.minimum(np.sum(hash_mask, axis=0), 1)
+ hash_mask = hash_mask.astype(np.float32)  # convert to float to use mask
+
+ # First dense layer of the block, with hash masking.
+ mid = np.dot(x, w1.T) * hash_mask  # [joint_batch, d_ff]
+
+ # Relu and the second dense layer, as in a standard feed-forward block.
+ # Note: we merge the second block into this layer because of future plans,
+ # not anything implemented yet. The potential gain would be as follows:
+ # in predict mode, we would pre-hash (once) both w1 and w2 and only do
+ # matmuls (and memory copies) for the parts that correspond to the hash
+ # of the input. The hash of w1 determines which parts of Relu are 0, so
+ # it also determines which parts of w2 can be skipped.
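+ #
+ # Worked example of the hash masking above (hypothetical bucket values):
+ # with n_hashes=2, suppose a row of x hashes to buckets [3, 7] and a row
+ # of w1 to buckets [5, 7]. The squared differences are [4, 0], the
+ # per-hash equality indicators are [0, 1], and summing then clipping to 1
+ # yields mask = 1, so that row of w1 participates in the matmul. Rows that
+ # share no bucket under any hash get mask = 0 and are zeroed out in `mid`.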
+ relu = np.where(mid <= 0, np.zeros_like(mid), mid) + res = np.dot(relu, w2) + b2 + return np.reshape(res, x_shape) # un-flatten if needed + + def init_weights_and_state(self, input_signature): + """Randomly initializes this layer's weights.""" + d_model = input_signature.shape[-1] + shape_w1 = (self._d_ff, d_model) + shape_w2 = (self._d_ff, d_model) + shape_b2 = (d_model,) + + rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3) + w1 = self._kernel_initializer(shape_w1, rng_w1) + w2 = self._kernel_initializer(shape_w2, rng_w2) + b2 = self._bias_initializer(shape_b2, rng_b2) + self.weights = (w1, w2, b2) diff --git a/trax/layers/research/efficient_attention_test.py b/trax/layers/research/efficient_attention_test.py deleted file mode 100644 index fd4cbc9e3..000000000 --- a/trax/layers/research/efficient_attention_test.py +++ /dev/null @@ -1,441 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.layers.research.efficient_attention.""" - -from absl.testing import parameterized -import jax -import numpy as np -from tensorflow import test - -from trax import fastmath -from trax import shapes -from trax.fastmath import numpy as jnp -from trax.layers.research import efficient_attention - - -class EfficientAttentionTest(test.TestCase, parameterized.TestCase): - - def test_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.SelfAttention( - n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_lsh_ff(self): - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHFF(d_ff=1024*8, n_buckets=[16, 8]) - x = np.ones((3, 7, 1024)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_self_attention_tf(self): - with fastmath.use_backend(fastmath.Backend.TFNP): - layer = efficient_attention.SelfAttention( - n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_lsh_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - - def _run_forward_and_backward(self, model, inp, 
weights, state): - def forward(inp, weights): - return model.pure_fn( - inp, weights, state, rng=jax.random.PRNGKey(0)) - out, vjpfun, new_state = jax.vjp(forward, inp, weights, has_aux=True) - inp_grad, weights_grad = vjpfun(fastmath.numpy.ones_like(inp)) - return out, new_state, inp_grad, weights_grad - - def _test_equivalence_to_reference_code( - self, model_cls, inp, input_signature, common_kwargs, *test_kwargs): - ref_model = model_cls(use_reference_code=True, **common_kwargs) - rng = fastmath.random.get_prng(123) - weights, state = ref_model.init(input_signature, rng) - - ref_all = self._run_forward_and_backward(ref_model, inp, weights, state) - ref_out, ref_state, ref_inp_grad, ref_weights_grad = ref_all - - for kwargs in test_kwargs: - test_model = model_cls(**common_kwargs, **kwargs) - state = test_model.init(input_signature, rng)[1] - test_all = self._run_forward_and_backward(test_model, inp, weights, state) - test_out, test_state, test_inp_grad, test_weights_grad = test_all - - self.assertEqual(jax.tree_structure(ref_out), - jax.tree_structure(test_out)) - self.assertEqual(jax.tree_structure(ref_state), - jax.tree_structure(test_state)) - self.assertEqual(jax.tree_structure(ref_inp_grad), - jax.tree_structure(test_inp_grad)) - self.assertEqual(jax.tree_structure(ref_weights_grad), - jax.tree_structure(test_weights_grad)) - - check_close = lambda x, y: self.assertAllClose(x, y, rtol=2e-3, atol=2e-3) - fastmath.nested_map_multiarg(check_close, ref_out, test_out) - fastmath.nested_map_multiarg(check_close, ref_state, test_state) - fastmath.nested_map_multiarg(check_close, ref_inp_grad, test_inp_grad) - fastmath.nested_map_multiarg(check_close, ref_weights_grad, - test_weights_grad) - - def test_batching_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - common_kwargs = dict( - n_heads=6, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=5, n_chunks_before=1, n_chunks_after=0, - attention_dropout=0.2, output_dropout=0.1, mode='train', - ) - test_kwargs = [] - for n_parallel_heads in [1, 3, 6, 12]: - for use_python_loop in [True, False]: - test_kwargs.append(dict(n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop)) - - x = jax.random.uniform( - jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32) - input_signature = shapes.signature(x) - self._test_equivalence_to_reference_code( - efficient_attention.SelfAttention, - x, input_signature, - common_kwargs, *test_kwargs) - - def test_batching_lsh_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - common_kwargs = dict( - n_heads=6, d_qk=7, d_v=17, causal=True, - chunk_len=5, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - attention_dropout=0.2, output_dropout=0.1, mode='train', - ) - test_kwargs = [] - for n_parallel_heads in [1, 3, 6, 12]: - for use_python_loop in [True, False]: - test_kwargs.append(dict(n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop)) - - x = jax.random.uniform( - jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32) - input_signature = shapes.signature(x) - self._test_equivalence_to_reference_code( - efficient_attention.LSHSelfAttention, - x, input_signature, - common_kwargs, *test_kwargs) - - def _test_fast_inference( - self, model_cls, x, input_signature, common_kwargs, *test_kwargs): - ref_model = model_cls(use_reference_code=True, mode='eval', **common_kwargs) - weights, state = ref_model.init(input_signature) - - ref_out, _ = ref_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - - def 
get_slice(pytree, i): - def get_slice_for_val(x): - if isinstance(x, shapes.ShapeDtype): - return shapes.ShapeDtype(shape=x.shape[:1] + (1,) + x.shape[2:], - dtype=x.dtype) - else: - return x[:, i:i+1] - return jax.tree_map(get_slice_for_val, pytree) - - seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1] - - for kwargs in test_kwargs: - test_model = model_cls(mode='predict', **common_kwargs, **kwargs) - cur_state = test_model.init(get_slice(input_signature, 0))[1] - out = [] - for i in range(seqlen): - cur_out, cur_state = test_model.pure_fn( - get_slice(x, i), weights, cur_state, jax.random.PRNGKey(0)) - out.append(cur_out) - out = jnp.concatenate(out, axis=1) - - self.assertAllClose(out, ref_out, rtol=1e-3, atol=1e-3) - - def test_fast_inference_self_attention(self): - with fastmath.use_backend(fastmath.Backend.JAX): - common_kwargs = dict( - n_heads=6, d_qk=7, d_v=17, share_qk=False, causal=True, - chunk_len=5, n_chunks_before=1, n_chunks_after=0, - attention_dropout=0.0, output_dropout=0.0, - ) - test_kwargs = [] - for n_parallel_heads in [1, 3, 6, 12]: - for use_python_loop in [True, False]: - test_kwargs.append(dict(n_parallel_heads=n_parallel_heads, - use_python_loop=use_python_loop)) - - x = jax.random.uniform( - jax.random.PRNGKey(0), (2, 10, 13), dtype=jnp.float32) - input_signature = shapes.signature(x) - self._test_fast_inference( - efficient_attention.SelfAttention, - x, input_signature, - common_kwargs, *test_kwargs) - - def _test_lsh_self_attention_deterministic_given_seed(self, causal=False): - # Once the initialization and the call seeds are pinned down we have - # deterministic output. - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=causal, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - - def get_output(): - _, _ = layer.init(shapes.signature(x), jax.random.PRNGKey(0)) - return layer(x, rng=jax.random.PRNGKey(1)) - - ys = [get_output() for _ in range(10)] - - self.assertEqual(ys[0].shape, x.shape) - - for y in ys[1:]: - np.testing.assert_array_almost_equal(ys[0], y, decimal=6) - - def test_lsh_determinism_causal(self): - self._test_lsh_self_attention_deterministic_given_seed(causal=True) - - def test_lsh_determinism_non_causal(self): - self._test_lsh_self_attention_deterministic_given_seed(causal=False) - - def test_lsh_self_attention_masked_non_causal(self): - # Test that when the input that is in the masked area changes the attention - # for the un-masked outputs doesn't change, but the masked region does - # change. - with fastmath.use_backend(fastmath.Backend.JAX): - layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=False, masked=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - - batch = 5 - max_len = 32 - hidden = 8 - - x = np.random.uniform(size=(batch, max_len, hidden)) - mask = np.ones((batch, max_len)).astype(bool) - rngs = jax.random.randint( - jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1) - - # Set some suffix of each mask[b] to 0. - for i in range(batch): - mask[i, rngs[i]:] = 0 - - # Fix rngs and get the output for the LSH layer. 
- def get_output(x, mask): - xs = [x, mask] - _, _ = layer.init(shapes.signature(xs), jax.random.PRNGKey(0)) - return layer(xs, rng=jax.random.PRNGKey(1)) - - # Get the attention output for masked x. - y = get_output(x, mask) - - # Change x, but only in the masked regions. - for i in range(batch): - x[i, rngs[i]:] = np.random.uniform(size=(max_len - rngs[i], hidden)) - - y2 = get_output(x, mask) - - for i in range(batch): - # y and y2 should be identical in the non-masked part. - np.testing.assert_array_almost_equal(y[i, :rngs[i]], y2[i, :rngs[i]], - decimal=6) - - # In the masked out part, they should be different. - self.assertGreater( - np.mean(np.abs(y[i, rngs[i]:] - y2[i, rngs[i]:])), 1e-5) - - @parameterized.named_parameters(('_weights_2', 2), ('_weights_3', 3)) - def test_pure_lsh_wrapper_causal_non_masked(self, num_weights): - with fastmath.use_backend(fastmath.Backend.JAX): - n_heads = 5 - batch, seqlen, d_head = 3, 32, 8 - n_hashes = 2 - d_model = n_heads * d_head - layer = efficient_attention.PureLSHSelfAttentionWrapper( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=True, masked=False, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=n_hashes, n_buckets=4, bias=False, - pure_lsh_implementation=efficient_attention.PureLSHSelfAttention, - mode='train', num_weights=num_weights) - - rng = jax.random.PRNGKey(0) - rng, x_rng = jax.random.split(rng) - - input_shape = (batch, seqlen, d_model) - x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32) - - inp = x - w, s = layer.init(shapes.signature(inp)) - o = layer(inp) - - # Get the actual weights. - weights = fastmath.tree_leaves(w) - # Assert number of weights is as expected, the extra 1 is for output. - self.assertLen(weights, num_weights + 1) - - # Assert each weight is of the expected shape. - for i in range(num_weights + 1): - self.assertEqual(weights[i].shape, (d_model, d_model)) - - # Test that the output and the input shape match. - self.assertEqual(inp.shape, o.shape) - - # Assert state is the shape expected. - state = fastmath.tree_leaves(s) - self.assertLen(state, 2) - # buckets - self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen)) - # rngs - self.assertEqual(state[1].shape, (batch * n_heads, 2)) - - @parameterized.named_parameters(('_weights_2', 2), ('_weights_3', 3)) - def test_pure_lsh_wrapper_non_causal_masked(self, num_weights): - with fastmath.use_backend(fastmath.Backend.JAX): - n_heads = 5 - batch, seqlen, d_head = 3, 32, 8 - num_weights = 2 - n_hashes = 2 - d_model = n_heads * d_head - layer = efficient_attention.PureLSHSelfAttentionWrapper( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=False, masked=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=n_hashes, n_buckets=4, bias=False, - pure_lsh_implementation=efficient_attention.PureLSHSelfAttention, - mode='train', num_weights=num_weights) - - rng = jax.random.PRNGKey(0) - rng, x_rng = jax.random.split(rng) - - input_shape = (batch, seqlen, d_model) - x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32) - mask = jnp.ones((batch, seqlen), dtype=jnp.int32) - - inp = (x, mask) - w, s = layer.init(shapes.signature(inp)) - o = layer(inp) - - # Get the actual weights. - weights = fastmath.tree_leaves(w) - # Assert number of weights is as expected, the extra 1 is for output. - self.assertLen(weights, num_weights + 1) - - # Assert each weight is of the expected shape. 
- for i in range(num_weights + 1): - self.assertEqual(weights[i].shape, (d_model, d_model)) - - # Test that the output and the x's shape match. - self.assertEqual(x.shape, o.shape) - - # Assert state is the shape expected. - state = fastmath.tree_leaves(s) - self.assertLen(state, 2) - # buckets - self.assertEqual(state[0].shape, (batch * n_heads, n_hashes * seqlen)) - # rngs - self.assertEqual(state[1].shape, (batch * n_heads, 2)) - - def test_lsh_and_pure_lsh_self_attention_equivalence(self): - # Given the same weight matrices and random numbers, do these produce the - # same output. - with fastmath.use_backend(fastmath.Backend.JAX): - n_heads = 4 - d_head = 4 - d_model = n_heads * d_head - pure_lsh_layer = efficient_attention.PureLSHSelfAttention( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=True, masked=False, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=4, n_buckets=8, - use_reference_code=False, - attention_dropout=0.0, - use_python_loop=True, - bias=False, mode='train') - lsh_layer = efficient_attention.LSHSelfAttention( - n_heads=n_heads, d_qk=d_head, d_v=d_head, causal=True, masked=False, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=4, n_buckets=8, - use_reference_code=False, - attention_dropout=0.0, - use_python_loop=True, - mode='train') - - batch, seqlen = 3, 32 - input_shape = (batch, seqlen, d_model) - - x = jax.random.uniform(jax.random.PRNGKey(0), input_shape, - dtype=jnp.float32) - lsh_layer_input = x - - call_rng = jax.random.PRNGKey(42) - - lsh_layer_weights, lsh_layer_state = lsh_layer.init( - shapes.signature(lsh_layer_input)) - lsh_layer.rng = call_rng - lsh_layer_output = lsh_layer(lsh_layer_input) - - # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head), - # (n_heads, d_head, d_model) - # Abbreviated as - hmn, hmn, hnm - w_qk, w_v, w_o = lsh_layer_weights - - qk = jnp.einsum('blm,hmn->bhln', x, w_qk) - qk = qk.reshape((-1, qk.shape[2], qk.shape[3])) - - v = jnp.einsum('blm,hmn->bhln', x, w_v) - v = v.reshape((-1, v.shape[2], v.shape[3])) - - pure_lsh_layer_input = (qk, v) - _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input)) - pure_lsh_layer.rng = call_rng - pure_lsh_layer.state = lsh_layer_state - pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input) - - # b*h,l,n - pure_lsh_layer_output = pure_lsh_layer_output.reshape( - (batch, -1) + pure_lsh_layer_output.shape[1:]) - pure_lsh_layer_output_projected = ( - jnp.einsum('bhld,hdm->blm', pure_lsh_layer_output, w_o)) - - diff = pure_lsh_layer_output_projected - lsh_layer_output - avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff)) - - self.assertLess(avg_diff, 1e-5) - -if __name__ == '__main__': - test.main() diff --git a/trax/layers/research/flash_attention.py b/trax/layers/research/flash_attention.py new file mode 100644 index 000000000..336875930 --- /dev/null +++ b/trax/layers/research/flash_attention.py @@ -0,0 +1,53 @@ +# coding=utf-8 +"""Flash attention implementation for Trax.""" + +import jax +from jax import lax +import jax.numpy as jnp + + +def flash_attention(q, k, v, *, block_size, mask=None): + """Memory efficient dot-product attention. + + Args: + q: Queries array of shape [batch, len, depth]. + k: Keys array of shape [batch, len, depth]. + v: Values array of shape [batch, len, depth_v]. + block_size: Integer block size used for computation. + mask: Optional boolean mask of shape [batch, len] where ``True`` values + indicate positions that should be masked out. 
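+
+ Note: this sketch blocks over queries only and applies no 1/sqrt(depth)
+ scaling, so scale ``q`` beforehand for conventional scaled dot-product
+ attention. Example (shapes illustrative):
+
+ q = k = v = jnp.ones((2, 10, 4))
+ out = flash_attention(q, k, v, block_size=8)  # shape (2, 10, 4)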
+
+ Returns:
+ Array of shape [batch, len, depth_v] with the attention outputs.
+ """
+ seqlen = q.shape[1]
+ pad_len = (block_size - seqlen % block_size) % block_size
+ if pad_len:
+ pad = ((0, 0), (0, pad_len), (0, 0))
+ q = jnp.pad(q, pad)
+ k = jnp.pad(k, pad)
+ v = jnp.pad(v, pad)
+ if mask is None:
+ # The zero-padded key positions must not receive attention weight, so
+ # build a mask for them even when the caller passed none.
+ mask = jnp.zeros((q.shape[0], seqlen), dtype=bool)
+ mask = jnp.pad(mask, ((0, 0), (0, pad_len)), constant_values=True)
+ total_len = q.shape[1]
+
+ if mask is not None:
+ mask_b = mask[:, None, :]
+ else:
+ mask_b = None
+
+ outputs = []
+ for start in range(0, total_len, block_size):
+ q_block = lax.dynamic_slice(
+ q, (0, start, 0), (q.shape[0], block_size, q.shape[2])
+ )
+ logits = jnp.einsum("bqd,bkd->bqk", q_block, k)
+ if mask_b is not None:
+ logits = jnp.where(mask_b, -1e9, logits)
+ weights = jax.nn.softmax(logits, axis=-1)
+ out_block = jnp.einsum("bqk,bkd->bqd", weights, v)
+ outputs.append(out_block)
+ output = jnp.concatenate(outputs, axis=1)
+ if pad_len:
+ output = output[:, :seqlen, :]
+ return output
diff --git a/trax/layers/research/position_encodings.py b/trax/layers/research/position_encodings.py
index b483a8650..9e657d8e9 100644
--- a/trax/layers/research/position_encodings.py
+++ b/trax/layers/research/position_encodings.py
@@ -16,9 +16,13 @@
"""Experimenting with position encodings."""

import logging
+
import jax
+import jax.extend as jex
import numpy as np
+
import trax
+
from trax import fastmath
from trax.fastmath import numpy as jnp
from trax.layers import base as layer_base
@@ -26,528 +30,582 @@


class AxialPositionalEncoding(layer_base.Layer):
- """Axial positional encoding."""
- # TODO(kitaev): support variable-length sequences.
-
- def __init__(self, shape=(64, 64, 3), d_embs=(384, 384, 256),
- kernel_initializer=init.RandomNormalInitializer(1.0),
- dropout=0.0, dropout_broadcast_dims=(), mode='train'):
- super().__init__()
- self._kernel_initializer = kernel_initializer
- assert len(shape) == len(d_embs)
- self._shape = shape
- self._d_embs = d_embs
-
- if dropout >= 1.0:
- raise ValueError('Dropout rates must be lower than 1.')
- if mode == 'train':
- self._dropout = dropout
- else:
- self._dropout = 0.0
- self._dropout_broadcast_dims = dropout_broadcast_dims
- self._mode = mode
-
- def forward(self, inputs):
- rng, state = self.rng, self.state
- embs = []
- for ax_emb in self.weights:
- ax_emb = jnp.broadcast_to(
- ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
- embs.append(ax_emb)
-
- if self._mode == 'predict':
- assert self._dropout == 0.0
- emb = jnp.concatenate(embs, -1)
- emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
- emb = fastmath.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1)
- self.state = state + inputs.shape[1]
- return inputs + emb
- elif self._dropout == 0:
- # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
- # leads to memory blow-up on TPU.
- # emb = jnp.concatenate(embs, -1) - # return inputs + jnp.reshape(emb, inputs.shape), state - return inputs + jnp.concatenate( - [jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],)) - for emb in embs - ], -1) - else: - emb = jnp.concatenate(embs, -1) - noise_shape = list(emb.shape) - for dim in self._dropout_broadcast_dims: - noise_shape[dim] = 1 - keep_prob = 1.0 - self._dropout - keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape)) - multiplier = keep.astype(inputs.dtype) / keep_prob - return inputs + jnp.reshape(emb * multiplier, inputs.shape) - - def init_weights_and_state(self, input_signature): - d_feature = input_signature.shape[-1] - if sum(self._d_embs) != d_feature: - raise ValueError( - f'sum(self._d_embs) != d_feature: ' - f'sum({self._d_embs}) vs d_feature: {d_feature}') - - rngs = fastmath.random.split(self.rng, len(self._d_embs)) - weights = [] - for ax, (ax_rng, d_emb) in enumerate(zip(rngs, self._d_embs)): - ax_shape = [1] * len(self._shape) - ax_shape[ax] = self._shape[ax] - ax_shape = (1,) + tuple(ax_shape) + (d_emb,) - ax_emb = self._kernel_initializer(ax_shape, ax_rng) - weights.append(ax_emb) - - # State is EMPTY_STATE by default, stays so except for predict mode. - if self._mode == 'predict': - self.state = np.array(0, dtype=np.int32) - self.weights = tuple(weights) + """Axial positional encoding.""" + + # TODO(kitaev): support variable-length sequences. + + def __init__( + self, + shape=(64, 64, 3), + d_embs=(384, 384, 256), + kernel_initializer=init.RandomNormalInitializer(1.0), + dropout=0.0, + dropout_broadcast_dims=(), + mode="train", + ): + super().__init__() + self._kernel_initializer = kernel_initializer + assert len(shape) == len(d_embs) + self._shape = shape + self._d_embs = d_embs + + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if mode == "train": + self._dropout = dropout + else: + self._dropout = 0.0 + self._dropout_broadcast_dims = dropout_broadcast_dims + self._mode = mode + + def forward(self, inputs): + rng, state = self.rng, self.state + embs = [] + for ax_emb in self.weights: + ax_emb = jnp.broadcast_to( + ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],) + ) + embs.append(ax_emb) + + if self._mode == "predict": + assert self._dropout == 0.0 + emb = jnp.concatenate(embs, -1) + emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1])) + emb = fastmath.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1) + self.state = state + inputs.shape[1] + return inputs + emb + elif self._dropout == 0: + # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled) + # leads to memory blow-up on TPU. 
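+ # Both formulations are equivalent: concatenation along the last axis
+ # commutes with merging the leading spatial axes. E.g. (hypothetical
+ # sizes) for shape=(2, 2) and d_embs=(3, 5), reshaping the broadcast
+ # embeddings to (batch, 4, 3) and (batch, 4, 5) and then concatenating
+ # gives the same array as concatenating to (batch, 2, 2, 8) first and
+ # reshaping to (batch, 4, 8):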
+ # emb = jnp.concatenate(embs, -1) + # return inputs + jnp.reshape(emb, inputs.shape), state + return inputs + jnp.concatenate( + [ + jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],)) + for emb in embs + ], + -1, + ) + else: + emb = jnp.concatenate(embs, -1) + noise_shape = list(emb.shape) + for dim in self._dropout_broadcast_dims: + noise_shape[dim] = 1 + keep_prob = 1.0 - self._dropout + keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape)) + multiplier = keep.astype(inputs.dtype) / keep_prob + return inputs + jnp.reshape(emb * multiplier, inputs.shape) + + def init_weights_and_state(self, input_signature): + d_feature = input_signature.shape[-1] + if sum(self._d_embs) != d_feature: + raise ValueError( + f"sum(self._d_embs) != d_feature: " + f"sum({self._d_embs}) vs d_feature: {d_feature}" + ) + + rngs = fastmath.random.split(self.rng, len(self._d_embs)) + weights = [] + for ax, (ax_rng, d_emb) in enumerate(zip(rngs, self._d_embs)): + ax_shape = [1] * len(self._shape) + ax_shape[ax] = self._shape[ax] + ax_shape = (1,) + tuple(ax_shape) + (d_emb,) + ax_emb = self._kernel_initializer(ax_shape, ax_rng) + weights.append(ax_emb) + + # State is EMPTY_STATE by default, stays so except for predict mode. + if self._mode == "predict": + self.state = np.array(0, dtype=np.int32) + self.weights = tuple(weights) class SinCosPositionalEncoding(layer_base.Layer): - """Implements the sin-cos positional encoding.""" - - def __init__(self, add_offset=2048, dropout=0.0, dropout_broadcast_dims=(-2,), - start_from_zero_one_in=2, mode='train'): - """Creates a SinCosPositionalEncoding instance. - - Args: - add_offset: Maximumnumber to add to positions during training. - dropout: Probability of *not* adding positional encoding to a sequence - position. - dropout_broadcast_dims: Axes along which dropout mask values are - broadcast rather than individually set at random. - start_from_zero_one_in: how often to start from 0 during training, - every one in that many times (e.g., if 4, then it's 25% of the time). - mode: One of `'train'`, `'eval'`, or `'predict'`. 
- """ - super().__init__() - self._add_offset = add_offset - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if mode == 'train': - self._dropout = dropout - else: - self._dropout = 0.0 - self._dropout_broadcast_dims = dropout_broadcast_dims - self._start_from_zero_one_in = start_from_zero_one_in - self._mode = mode - - def _sincos(self, start, length, d_feature): - """Create the sin-cos tensor of shape [1, length, d_feature].""" - position = jnp.arange(0, length)[:, None] + start - div_term = jnp.exp( - jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature)) - sin = jnp.sin(position * div_term) - cos = jnp.cos(position * div_term) - pe = jnp.concatenate([sin, cos], axis=1) - return pe[None, :, :] # [1, length, d_feature] - - def forward(self, inputs): - """Returns the input activations, with added positional information.""" - if self._mode != 'predict': - x = inputs - length = jnp.shape(x)[1] - if self._mode != 'train': - start = 0 - else: - rng1, rng2 = fastmath.random.split(self.rng, 2) - start = fastmath.random.randint(rng1, (), 0, self._add_offset) - start_from_nonzero = fastmath.random.randint( - rng2, (), 0, self._start_from_zero_one_in) - start_from_nonzero = jnp.minimum(1, start_from_nonzero) - start *= start_from_nonzero - px = self._sincos(start, length, inputs.shape[2]) - if self._dropout == 0: - return x + px - else: - noise_shape = list(px.shape) - for dim in self._dropout_broadcast_dims: - noise_shape[dim] = 1 - keep_prob = 1.0 - self._dropout - keep = fastmath.random.bernoulli(self.rng, keep_prob, - tuple(noise_shape)) - multiplier = keep.astype(x.dtype) / keep_prob - return x + px * multiplier - else: - if self._dropout != 0: - raise ValueError(f'In predict mode, but dropout rate ' - f'({self._dropout}) is not zero.') - - # State in this class is only used for fast inference. In that case, - # the model is called with consecutive elements position-by-position. - # This positional encoding layer needs to store the index of the current - # position then and increment it on each call -- that's how state is used - # and updated below. - pe = self._sincos(self.state, inputs.shape[1], inputs.shape[2]) - self.state += inputs.shape[1] - return inputs + pe - - def init_weights_and_state(self, input_signature): - """Randomly initializes the positional encoding vectors. - - Args: - input_signature: `ShapeDtype` instance characterizing the input this - layer should compute on. - """ - if self._mode == 'predict': - self.state = jnp.zeros((), dtype=jnp.int32) + """Implements the sin-cos positional encoding.""" + + def __init__( + self, + add_offset=2048, + dropout=0.0, + dropout_broadcast_dims=(-2,), + start_from_zero_one_in=2, + mode="train", + ): + """Creates a SinCosPositionalEncoding instance. + + Args: + add_offset: Maximumnumber to add to positions during training. + dropout: Probability of *not* adding positional encoding to a sequence + position. + dropout_broadcast_dims: Axes along which dropout mask values are + broadcast rather than individually set at random. + start_from_zero_one_in: how often to start from 0 during training, + every one in that many times (e.g., if 4, then it's 25% of the time). + mode: One of `'train'`, `'eval'`, or `'predict'`. 
+ """ + super().__init__() + self._add_offset = add_offset + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if mode == "train": + self._dropout = dropout + else: + self._dropout = 0.0 + self._dropout_broadcast_dims = dropout_broadcast_dims + self._start_from_zero_one_in = start_from_zero_one_in + self._mode = mode + + def _sincos(self, start, length, d_feature): + """Create the sin-cos tensor of shape [1, length, d_feature].""" + position = jnp.arange(0, length)[:, None] + start + div_term = jnp.exp( + jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature) + ) + sin = jnp.sin(position * div_term) + cos = jnp.cos(position * div_term) + pe = jnp.concatenate([sin, cos], axis=1) + return pe[None, :, :] # [1, length, d_feature] + + def forward(self, inputs): + """Returns the input activations, with added positional information.""" + if self._mode != "predict": + x = inputs + length = jnp.shape(x)[1] + if self._mode != "train": + start = 0 + else: + rng1, rng2 = fastmath.random.split(self.rng, 2) + start = fastmath.random.randint(rng1, (), 0, self._add_offset) + start_from_nonzero = fastmath.random.randint( + rng2, (), 0, self._start_from_zero_one_in + ) + start_from_nonzero = jnp.minimum(1, start_from_nonzero) + start *= start_from_nonzero + px = self._sincos(start, length, inputs.shape[2]) + if self._dropout == 0: + return x + px + else: + noise_shape = list(px.shape) + for dim in self._dropout_broadcast_dims: + noise_shape[dim] = 1 + keep_prob = 1.0 - self._dropout + keep = fastmath.random.bernoulli( + self.rng, keep_prob, tuple(noise_shape) + ) + multiplier = keep.astype(x.dtype) / keep_prob + return x + px * multiplier + else: + if self._dropout != 0: + raise ValueError( + f"In predict mode, but dropout rate " + f"({self._dropout}) is not zero." + ) + + # State in this class is only used for fast inference. In that case, + # the model is called with consecutive elements position-by-position. + # This positional encoding layer needs to store the index of the current + # position then and increment it on each call -- that's how state is used + # and updated below. + pe = self._sincos(self.state, inputs.shape[1], inputs.shape[2]) + self.state += inputs.shape[1] + return inputs + pe + + def init_weights_and_state(self, input_signature): + """Randomly initializes the positional encoding vectors. + + Args: + input_signature: `ShapeDtype` instance characterizing the input this + layer should compute on. + """ + if self._mode == "predict": + self.state = jnp.zeros((), dtype=jnp.int32) class FixedBasePositionalEncoding(layer_base.Layer): - """Implements fixed-base positional encoding.""" - - def __init__(self, bases=[11, 13, 14, 15], n_digits=8, # pylint: disable=dangerous-default-value - start_from_zero_one_in=2, base_dropout_one_in=100, - mode='train', initializer=init.RandomUniformInitializer(1e-4)): - super().__init__() - self._bases = bases - self._n_digits = n_digits - self._mode = mode - self._initializer = initializer - self._start_from_zero_one_in = start_from_zero_one_in - self._base_dropout_one_in = base_dropout_one_in - - def forward(self, x): - rng = self.rng - base_weights, start_vec = self.weights - batch_size, length = x.shape[0], x.shape[1] - max_pos = min(self._bases)**self._n_digits - rng1, rng2, rng3 = fastmath.random.split(rng, 3) - assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos) - positions = jnp.arange(0, length)[None, :] - # In training we'll randomize starts for better generalization. 
- # We use the trainable start_vec to compensate and give model a way
- # to learn what is the starting position in a sequence.
- if self._mode == 'train':
- # In 1% of training cases still start from 0 to be exactly as in eval.
- start_from_nonzero = fastmath.random.randint(
- rng1, (batch_size,), 0, self._start_from_zero_one_in)
- start_from_nonzero = jnp.minimum(1, start_from_nonzero)
- random_start = fastmath.random.randint(
- rng2, (batch_size,), 0, max_pos-length)
- random_start *= start_from_nonzero
- positions += random_start[:, None]
- if self._mode == 'predict':
- positions += self.state
- res = []
- for bn, base in enumerate(self._bases):
- pos_embeddings = []
- cur_positions = positions
- for i in range(self._n_digits):
- cur_indices = jnp.mod(cur_positions, base)
- cur_positions = cur_positions // base
- s = base_weights[bn][i]
- pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
- embeddings = jnp.concatenate(pos_embeddings, axis=-1)
- if self._mode == 'train':
- base_dropout = fastmath.random.randint(
- rng3, (batch_size,), 0, self._base_dropout_one_in)
- base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
- embeddings *= base_dropout[:, None, None]
- res.append(embeddings)
- res = sum(res)  # Sum embeddings from all bases.
- # Add start_vec to the first position only to mark it as starting.
- res0 = res[:, 0, :][:, None, :]
- start_pos = res0 + start_vec
- if self._mode == 'predict':
- start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0)
- self.state += length  # Add input length to state.
- res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1)
- return x + res
-
- def init_weights_and_state(self, input_signature):
- d_feature = input_signature.shape[-1]
- if d_feature % self._n_digits != 0:
- raise ValueError(
- f'd_feature({d_feature}) % self._n_digits({self._n_digits}) != 0')
- d_weight = d_feature // self._n_digits
- rng1, rng2 = fastmath.random.split(self.rng, 2)
- base_weights = [[self._initializer((1, d_weight), rng)
- for rng in fastmath.random.split(rng1, self._n_digits)]
- for _ in self._bases]
- # Special vector to mark the starting position.
- start_vec = self._initializer((1, 1, d_feature), rng2)
- self.weights = (base_weights, start_vec)
- if self._mode == 'predict':
- self.state = jnp.zeros((), dtype=jnp.int32)
+ """Implements fixed-base positional encoding."""
+
+ def __init__(
+ self,
+ bases=[11, 13, 14, 15],  # pylint: disable=dangerous-default-value
+ n_digits=8,
+ start_from_zero_one_in=2,
+ base_dropout_one_in=100,
+ mode="train",
+ initializer=init.RandomUniformInitializer(1e-4),
+ ):
+ super().__init__()
+ self._bases = bases
+ self._n_digits = n_digits
+ self._mode = mode
+ self._initializer = initializer
+ self._start_from_zero_one_in = start_from_zero_one_in
+ self._base_dropout_one_in = base_dropout_one_in
+
+ def forward(self, x):
+ rng = self.rng
+ base_weights, start_vec = self.weights
+ batch_size, length = x.shape[0], x.shape[1]
+ max_pos = min(self._bases) ** self._n_digits
+ rng1, rng2, rng3 = fastmath.random.split(rng, 3)
+ assert length < max_pos, "length (%d) >= max_pos (%d)" % (length, max_pos)
+ positions = jnp.arange(0, length)[None, :]
+ # In training we'll randomize starts for better generalization.
+ # We use the trainable start_vec to compensate and give the model a way
+ # to learn what is the starting position in a sequence.
+ if self._mode == "train":
+ # In one out of `start_from_zero_one_in` training cases (one in two
+ # with the default), still start from 0 to be exactly as in eval.
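+ # (`randint` draws from [0, start_from_zero_one_in); a draw of 0 zeroes
+ # `random_start` for that batch row, while `minimum(1, ...)` clamps every
+ # nonzero draw to 1 so the random start is kept for the other rows.)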
+ start_from_nonzero = fastmath.random.randint( + rng1, (batch_size,), 0, self._start_from_zero_one_in + ) + start_from_nonzero = jnp.minimum(1, start_from_nonzero) + random_start = fastmath.random.randint( + rng2, (batch_size,), 0, max_pos - length + ) + random_start *= start_from_nonzero + positions += random_start[:, None] + if self._mode == "predict": + positions += self.state + res = [] + for bn, base in enumerate(self._bases): + pos_embeddings = [] + cur_positions = positions + for i in range(self._n_digits): + cur_indices = jnp.mod(cur_positions, base) + cur_positions = cur_positions // base + s = base_weights[bn][i] + pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s) + embeddings = jnp.concatenate(pos_embeddings, axis=-1) + if self._mode == "train": + base_dropout = fastmath.random.randint( + rng3, (batch_size,), 0, self._base_dropout_one_in + ) + base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32) + embeddings *= base_dropout[:, None, None] + res.append(embeddings) + res = sum(res) # Sum embeddings from all bases. + # Add start_vec to the first position only to mark it as starting. + res0 = res[:, 0, :][:, None, :] + start_pos = res0 + start_vec + if self._mode == "predict": + start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0) + self.state += length # Add input length to state. + res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1) + return x + res + + def init_weights_and_state(self, input_signature): + d_feature = input_signature.shape[-1] + if d_feature % self._n_digits != 0: + raise ValueError( + f"d_feature({d_feature}) % self._n_digits({self._n_digits}) != 0" + ) + d_weight = d_feature // self._n_digits + rng1, rng2 = fastmath.random.split(self.rng, 2) + base_weights = [ + [ + self._initializer((1, d_weight), rng) + for rng in fastmath.random.split(rng1, self._n_digits) + ] + for _ in self._bases + ] + # Special vector to mark the starting position. + start_vec = self._initializer((1, 1, d_feature), rng2) + self.weights = (base_weights, start_vec) + if self._mode == "predict": + self.state = jnp.zeros((), dtype=jnp.int32) def threefry_2x32_prf(key, x: jnp.ndarray) -> jnp.ndarray: - """Apply the threefry PRF to an array of inputs. - - This function is vectorized over x. - For threefry_2x32: K = X = uint32[2] - - Args: - key: uint32[2] the key of the PRF - x: uint32[..., 2] the inputs - - Returns: - y: uint32[..., 2] the outputs - """ - if not (key.shape == (2,) and key.dtype == jnp.uint32): - raise TypeError('key must be uint32[2]', key) - if not (x.shape[-1:] == (2,) and x.dtype == jnp.uint32): - raise TypeError('x must be uint32[..., 2]', x) - # Threefry-2x32 expects this weird format: - x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten() - y_3f = jax.random.threefry_2x32(key, x_3f) - y = jnp.moveaxis( - jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1) - return y + """Apply the threefry PRF to an array of inputs. + This function is vectorized over x. + For threefry_2x32: K = X = uint32[2] -def threefry_2x32_prange(key, lo: int = 0, hi: int = 2): - """Splits a key into a stream of random keys. - - This uses the little-endian counter mode. 
- - Args: - key: uint32[2] the key to split - lo: the range to start extracting from - hi: the range to stop extracting from - - Returns: - keys: uint32[hi - lo, 2] the split keys - """ - if not (key.shape == (2,) and key.dtype == jnp.uint32): - raise ValueError('key must be uint32[2]') - if not hi < 2**32: - # You shouldn't really be using more than half the key size anyways. - raise NotImplementedError('only 32-bit sizes are supported') - # Create a 64-bit counter: - i_lo = jnp.arange(lo, hi, dtype=jnp.uint32) - i_hi = jnp.zeros_like(i_lo) - i = jnp.stack([i_lo, i_hi], axis=-1) - return threefry_2x32_prf(key, i) - - -class InfinitePositionalEncoding(layer_base.Layer): - """Infinite positional encoding.""" + Args: + key: uint32[2] the key of the PRF + x: uint32[..., 2] the inputs - def __init__( - self, drift=.03, affine=True, transform='any', - time_bin_length=None, - mode='train'): - """Initializes the encoding. + Returns: + y: uint32[..., 2] the outputs + """ + if not (key.shape == (2,) and key.dtype == jnp.uint32): + raise TypeError("key must be uint32[2]", key) + if not (x.shape[-1:] == (2,) and x.dtype == jnp.uint32): + raise TypeError("x must be uint32[..., 2]", x) + # Threefry-2x32 expects this weird format: + x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten() + y_3f = jex.random.threefry_2x32(key, x_3f) + y = jnp.moveaxis(jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1) + return y - The encoding tries to roughly evenly traverse the latent space. - The recurrence time is dependent on how many bits per dimension you use. - There are two parameters to control randomization: - - randomizing the origin every 1/drift steps by letting it drift - - randomizing the origin per call +def threefry_2x32_prange(key, lo: int = 0, hi: int = 2): + """Splits a key into a stream of random keys. - Args: - drift: variance in position difference per unit of difference - affine: whether to randomize the origin every call - transform: learnable transform after encoding (any/diag/none) - time_bin_length: Add features AxialPositionalEncoding learns if - TimeBinCausalAttention is the first layer. - bin_length should match TBCA.bin_length - If you set transform='diag', this flag increases your model capacity to - close to transform='any', though it will still train slower. - mode: if 'predict', allow evaluating one token at a time - """ - super().__init__() - if transform not in ('any', 'diag', 'none'): - raise ValueError(transform) - self._noise_rng = jax.random.split(jax.random.PRNGKey(234234535))[0] - assert self._noise_rng is not None - self._noise = None - self._drift = drift - self._affine = affine - self._transform = transform - self._time_bin_length = time_bin_length - self._mode = mode - - def _get_noise(self, lo: int, hi: int, depth: int): - """Return pseudorandom noise with shape float[length, depth]. + This uses the little-endian counter mode. Args: - lo: where to start sampling - hi: where to stop sampling - depth: noise depth + key: uint32[2] the key to split + lo: the range to start extracting from + hi: the range to stop extracting from Returns: - noise[lo:hi, :]: the noise, where noise.diff(axis=0) is i.i.d. 
U(-1,1) + keys: uint32[hi - lo, 2] the split keys """ - if self._noise is None or self._noise.shape[0] < hi: - # Resize the noise: - new_length = 1 - while new_length < hi: - new_length *= 2 - noise = threefry_2x32_prange(self._noise_rng, 0, new_length * depth) - noise = noise.reshape((new_length, depth, 2))[:, :, 0] - # Normalize to [-sqrt(3), sqrt(3)]: - noise = noise.astype(jnp.float32) / np.float32(2**31 - 1) - noise = noise * 3**.5 - # TODO(tying): use multiscale noise for memory-efficient sampling - noise = noise.cumsum(axis=0) - self._noise = noise - assert self._noise.shape[0] >= hi - assert self._noise.shape[1] == depth - return self._noise[lo:hi, :] - - def _get_embeddings(self, lo: int, hi: int, depth, rng=None): - """Get embeddings float[length, depth]. + if not (key.shape == (2,) and key.dtype == jnp.uint32): + raise ValueError("key must be uint32[2]") + if not hi < 2**32: + # You shouldn't really be using more than half the key size anyways. + raise NotImplementedError("only 32-bit sizes are supported") + # Create a 64-bit counter: + i_lo = jnp.arange(lo, hi, dtype=jnp.uint32) + i_hi = jnp.zeros_like(i_lo) + i = jnp.stack([i_lo, i_hi], axis=-1) + return threefry_2x32_prf(key, i) - Args: - lo: where to start sampling - hi: where to stop sampling - depth: embedding depth - rng: rng for random phase - Returns: - embeddings: float[length, depth] - """ - noise = self._get_noise(lo, hi, (depth + 1) // 2) - # Make the stddev around 1 after 1/drift. - noise = noise * self._drift**.5 - - t, c = np.mgrid[lo:hi, :depth] - # Make even channels cos, odd channels sin: - c_div_2, c_mod_2 = divmod(c, 2) - # Off-by-one correction for odd depth: - drift = self._drift - if depth > 2: - drift = drift**(((depth+1)//2)/(depth//2)) - # Spend roughly half the frequencies on noise: - freq = jnp.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2] - cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4 - assert cycles.shape == (hi - lo, depth), cycles.shape - - # Get random phases: - if self._affine: - assert rng is not None - cycles = cycles + trax.fastmath.random.uniform( - rng, (1, depth,), minval=0, maxval=1) - - # Convert from cycles to radians: - embeddings = jnp.cos(jnp.pi * 2 * cycles) - - # Set the last channels to the time bin features: - if self._time_bin_length is not None: - inter_bin_idx, intra_bin_idx = divmod(t[:, -1:], self._time_bin_length) - bin_parity = inter_bin_idx % 2 - bin_fraction = intra_bin_idx / self._time_bin_length - embeddings = jnp.concatenate( - [ - embeddings[:, :-3], - 1 / (1 + inter_bin_idx), - bin_fraction, - bin_parity.astype(jnp.float32), - ], -1) - - assert embeddings.shape == (hi - lo, depth), embeddings.shape - return embeddings - - def forward(self, inputs): - rng, state = self.rng, self.state - d_feature = inputs.shape[-1] - input_len = inputs.shape[-2] - - if self._mode == 'predict': - # Assume all the positions are pretty close to each other. 
- index, predict_rng = state - lo = index.min() - hi = index.max() + 1 - emb = self._get_embeddings(lo=lo, hi=hi, depth=d_feature, rng=predict_rng) - emb = emb[index - lo, jnp.newaxis, :] - index = index + 1 - state = index, predict_rng - else: - emb = self._get_embeddings(lo=0, hi=input_len, depth=d_feature, rng=rng) - emb = emb[jnp.newaxis, :input_len, :] - # TODO(tying): check that XLA swaps matmul(slice(x)) -> slice(matmul(x)), - # or inline this code into get_embeddings/get_noise - if self._transform == 'diag': - emb = emb * jax.nn.softplus(self.weights) - elif self._transform == 'any': - emb = emb @ self.weights - self.state = state - return inputs + emb - - def init_weights_and_state(self, input_signature): - d_feature = input_signature.shape[-1] - if self._transform == 'diag': - # Initialize it to a small value because JAX has a bug in softplus. - scale_isoftplus = jnp.zeros((d_feature,), dtype=jnp.float32) + 1e-4 - weights = scale_isoftplus - elif self._transform == 'any': - ortho = trax.layers.initializers.OrthogonalInitializer() - weights = ortho((d_feature, d_feature), self.rng) - else: - weights = layer_base.EMPTY_WEIGHTS - if self._mode == 'predict': - batch_size = input_signature.shape[0] - self.state = jnp.zeros((batch_size,), dtype=jnp.int32), self.rng - self.weights = weights +class InfinitePositionalEncoding(layer_base.Layer): + """Infinite positional encoding.""" + + def __init__( + self, + drift=0.03, + affine=True, + transform="any", + time_bin_length=None, + mode="train", + ): + """Initializes the encoding. + + The encoding tries to roughly evenly traverse the latent space. + The recurrence time is dependent on how many bits per dimension you use. + + There are two parameters to control randomization: + - randomizing the origin every 1/drift steps by letting it drift + - randomizing the origin per call + + Args: + drift: variance in position difference per unit of difference + affine: whether to randomize the origin every call + transform: learnable transform after encoding (any/diag/none) + time_bin_length: Add features AxialPositionalEncoding learns if + TimeBinCausalAttention is the first layer. + bin_length should match TBCA.bin_length + If you set transform='diag', this flag increases your model capacity to + close to transform='any', though it will still train slower. + mode: if 'predict', allow evaluating one token at a time + """ + super().__init__() + if transform not in ("any", "diag", "none"): + raise ValueError(transform) + self._noise_rng = jax.random.split(jax.random.PRNGKey(234234535))[0] + assert self._noise_rng is not None + self._noise = None + self._drift = drift + self._affine = affine + self._transform = transform + self._time_bin_length = time_bin_length + self._mode = mode + + def _get_noise(self, lo: int, hi: int, depth: int): + """Return pseudorandom noise with shape float[length, depth]. + + Args: + lo: where to start sampling + hi: where to stop sampling + depth: noise depth + + Returns: + noise[lo:hi, :]: the noise, where noise.diff(axis=0) is i.i.d. 
U(-1,1) + """ + if self._noise is None or self._noise.shape[0] < hi: + # Resize the noise: + new_length = 1 + while new_length < hi: + new_length *= 2 + noise = threefry_2x32_prange(self._noise_rng, 0, new_length * depth) + noise = noise.reshape((new_length, depth, 2))[:, :, 0] + # Normalize to [-sqrt(3), sqrt(3)]: + noise = noise.astype(jnp.float32) / np.float32(2**31 - 1) + noise = noise * 3**0.5 + # TODO(tying): use multiscale noise for memory-efficient sampling + noise = noise.cumsum(axis=0) + self._noise = noise + assert self._noise.shape[0] >= hi + assert self._noise.shape[1] == depth + return self._noise[lo:hi, :] + + def _get_embeddings(self, lo: int, hi: int, depth, rng=None): + """Get embeddings float[length, depth]. + + Args: + lo: where to start sampling + hi: where to stop sampling + depth: embedding depth + rng: rng for random phase + + Returns: + embeddings: float[length, depth] + """ + noise = self._get_noise(lo, hi, (depth + 1) // 2) + # Make the stddev around 1 after 1/drift. + noise = noise * self._drift**0.5 + + t, c = np.mgrid[lo:hi, :depth] + # Make even channels cos, odd channels sin: + c_div_2, c_mod_2 = divmod(c, 2) + # Off-by-one correction for odd depth: + drift = self._drift + if depth > 2: + drift = drift ** (((depth + 1) // 2) / (depth // 2)) + # Spend roughly half the frequencies on noise: + freq = jnp.geomspace(0.5, 0.5 * drift**2, num=(depth + 1) // 2)[c_div_2] + cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4 + assert cycles.shape == (hi - lo, depth), cycles.shape + + # Get random phases: + if self._affine: + assert rng is not None + cycles = cycles + trax.fastmath.random.uniform( + rng, + ( + 1, + depth, + ), + minval=0, + maxval=1, + ) + + # Convert from cycles to radians: + embeddings = jnp.cos(jnp.pi * 2 * cycles) + + # Set the last channels to the time bin features: + if self._time_bin_length is not None: + inter_bin_idx, intra_bin_idx = divmod(t[:, -1:], self._time_bin_length) + bin_parity = inter_bin_idx % 2 + bin_fraction = intra_bin_idx / self._time_bin_length + embeddings = jnp.concatenate( + [ + embeddings[:, :-3], + 1 / (1 + inter_bin_idx), + bin_fraction, + bin_parity.astype(jnp.float32), + ], + -1, + ) + + assert embeddings.shape == (hi - lo, depth), embeddings.shape + return embeddings + + def forward(self, inputs): + rng, state = self.rng, self.state + d_feature = inputs.shape[-1] + input_len = inputs.shape[-2] + + if self._mode == "predict": + # Assume all the positions are pretty close to each other. + index, predict_rng = state + lo = index.min() + hi = index.max() + 1 + emb = self._get_embeddings(lo=lo, hi=hi, depth=d_feature, rng=predict_rng) + emb = emb[index - lo, jnp.newaxis, :] + index = index + 1 + state = index, predict_rng + else: + emb = self._get_embeddings(lo=0, hi=input_len, depth=d_feature, rng=rng) + emb = emb[jnp.newaxis, :input_len, :] + # TODO(tying): check that XLA swaps matmul(slice(x)) -> slice(matmul(x)), + # or inline this code into get_embeddings/get_noise + if self._transform == "diag": + emb = emb * jax.nn.softplus(self.weights) + elif self._transform == "any": + emb = emb @ self.weights + self.state = state + return inputs + emb + + def init_weights_and_state(self, input_signature): + d_feature = input_signature.shape[-1] + if self._transform == "diag": + # Initialize it to a small value because JAX has a bug in softplus. 
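+ # (This is the *pre-softplus* parameter; the effective initial scale used
+ # in forward is softplus(1e-4) ~= ln(2) ~= 0.693 per feature.)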
+ scale_isoftplus = jnp.zeros((d_feature,), dtype=jnp.float32) + 1e-4 + weights = scale_isoftplus + elif self._transform == "any": + ortho = trax.layers.initializers.OrthogonalInitializer() + weights = ortho((d_feature, d_feature), self.rng) + else: + weights = layer_base.EMPTY_WEIGHTS + if self._mode == "predict": + batch_size = input_signature.shape[0] + self.state = jnp.zeros((batch_size,), dtype=jnp.int32), self.rng + self.weights = weights class TimeBinPositionalEncoding(layer_base.Layer): - """Just the engineered features from InfinitePositionalEncoding.""" - num_features = 3 - - def __init__(self, time_bin_length, mode='train'): - """Initializes the encoding. - - Args: - time_bin_length: TimeBinCausalAttention.bin_length of the first layer. - mode: if 'predict', allow evaluating one token at a time - """ - super().__init__() - self._time_bin_length = time_bin_length - self._mode = mode - - def _get_embeddings(self, t): - """Get embeddings float[..., num_features]. - - Args: - t: int[...] position (i.e. jnp.arange(..., jnp.int32)) - - Returns: - embeddings: float[..., num_features] - """ - inter_bin_idx, intra_bin_idx = divmod(t, self._time_bin_length) - bin_parity = inter_bin_idx % 2 - bin_fraction = intra_bin_idx / self._time_bin_length - embeddings = jnp.stack([ - 1 / (1 + inter_bin_idx), - bin_fraction, - bin_parity.astype(jnp.float32), - ], -1) - - assert embeddings.shape == t.shape + (self.num_features,), embeddings.shape - return embeddings - - def forward(self, inputs): - state = self.state - depth = inputs.shape[-1] - - if self._mode == 'predict': - emb = self._get_embeddings(t=state) - emb = emb[:, jnp.newaxis, :] - state = state + 1 - else: - input_len = inputs.shape[-2] - emb = self._get_embeddings(t=jnp.arange(input_len, dtype=jnp.int32)) - # Leave batch axis as 1 for broadcasting: - emb = emb[jnp.newaxis, :, :] - emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,)) - - # Replace the last num_features channels of input. - inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1) - if inputs.shape[-1] > depth: - logging.warning( - 'dropping feature(s): %d down to %d', inputs.shape[-1], depth) - inputs = inputs[..., -depth:] - - assert inputs.shape[-1] == depth, inputs.shape - self.state = state - return inputs - - def init_weights_and_state(self, input_signature): - if self._mode == 'predict': - batch_size = input_signature.shape[0] - self.state = jnp.zeros((batch_size,), dtype=jnp.int32) + """Just the engineered features from InfinitePositionalEncoding.""" + + num_features = 3 + + def __init__(self, time_bin_length, mode="train"): + """Initializes the encoding. + + Args: + time_bin_length: TimeBinCausalAttention.bin_length of the first layer. + mode: if 'predict', allow evaluating one token at a time + """ + super().__init__() + self._time_bin_length = time_bin_length + self._mode = mode + + def _get_embeddings(self, t): + """Get embeddings float[..., num_features]. + + Args: + t: int[...] position (i.e. 
jnp.arange(..., jnp.int32)) + + Returns: + embeddings: float[..., num_features] + """ + inter_bin_idx, intra_bin_idx = divmod(t, self._time_bin_length) + bin_parity = inter_bin_idx % 2 + bin_fraction = intra_bin_idx / self._time_bin_length + embeddings = jnp.stack( + [ + 1 / (1 + inter_bin_idx), + bin_fraction, + bin_parity.astype(jnp.float32), + ], + -1, + ) + + assert embeddings.shape == t.shape + (self.num_features,), embeddings.shape + return embeddings + + def forward(self, inputs): + state = self.state + depth = inputs.shape[-1] + + if self._mode == "predict": + emb = self._get_embeddings(t=state) + emb = emb[:, jnp.newaxis, :] + state = state + 1 + else: + input_len = inputs.shape[-2] + emb = self._get_embeddings(t=jnp.arange(input_len, dtype=jnp.int32)) + # Leave batch axis as 1 for broadcasting: + emb = emb[jnp.newaxis, :, :] + emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,)) + + # Replace the last num_features channels of input. + inputs = jnp.concatenate([inputs[..., : -self.num_features], emb], -1) + if inputs.shape[-1] > depth: + logging.warning( + "dropping feature(s): %d down to %d", inputs.shape[-1], depth + ) + inputs = inputs[..., -depth:] + + assert inputs.shape[-1] == depth, inputs.shape + self.state = state + return inputs + + def init_weights_and_state(self, input_signature): + if self._mode == "predict": + batch_size = input_signature.shape[0] + self.state = jnp.zeros((batch_size,), dtype=jnp.int32) diff --git a/trax/layers/research/position_encodings_test.py b/trax/layers/research/position_encodings_test.py deleted file mode 100644 index f59cbc592..000000000 --- a/trax/layers/research/position_encodings_test.py +++ /dev/null @@ -1,100 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.layers.research.position_encodings.""" - -import functools -import absl.testing.absltest as unittest -import numpy as np -import parameterized - -from trax import fastmath -import trax.layers.research.position_encodings as pe - - -@parameterized.parameterized_class([ - # {'Encoding': pe.FixedBasePositionalEncoding}, - {'Encoding': pe.InfinitePositionalEncoding}, - {'Encoding': functools.partial( - pe.InfinitePositionalEncoding, affine=False)}, - {'Encoding': functools.partial( - pe.TimeBinPositionalEncoding, time_bin_length=5)}, -]) -class PositionEncodingsTest(unittest.TestCase): - """Position encodings conform to the position encodings protocol.""" - - @parameterized.parameterized.expand([ - (1, 100, 8), # typical - (1, 1, 8), # short - (1, 100, 1), # narrow - (2, 100, 8), # batched - ]) - def test_training(self, n, t, c): - encoding = self.Encoding() - input_ntc = np.random.randn(n, t, c) - encoding.init(input_ntc) - output_ntc = encoding(input_ntc) - self.assertEqual(output_ntc.shape, input_ntc.shape) - self.assertTrue(np.not_equal(output_ntc, input_ntc).any()) - - @parameterized.parameterized.expand([ - (1, 100, 8), # typical - (1, 100, 1), # narrow - (2, 100, 8), # batched - ]) - def test_inference(self, n, t, c): - # Get the eval mode outputs: - encoding = self.Encoding(mode='eval') - input_ntc = np.random.randn(n, t, c) - rng = fastmath.random.get_prng(1234) - encoding.init(input_ntc, rng=rng) - output_ntc = encoding(input_ntc) - - is_random = self.Encoding == pe.InfinitePositionalEncoding - - # Get the predict mode outputs: - encoding_pred = self.Encoding(mode='predict') - encoding_pred.init(input_ntc[:, 0:1, :], rng=rng) - output_ntc0 = encoding_pred(input_ntc[:, 0:1, :]) - if not is_random: - np.testing.assert_allclose(output_ntc0, output_ntc[:, 0:1, :], atol=1e-4) - - output_ntc1 = encoding_pred(input_ntc[:, 1:2, :]) - if not is_random: - np.testing.assert_allclose(output_ntc1, output_ntc[:, 1:2, :], atol=1e-4) - - output_ntc2 = encoding_pred(input_ntc[:, 2:3, :]) - if not is_random: - np.testing.assert_allclose(output_ntc2, output_ntc[:, 2:3, :], atol=1e-4) - - -class SinCosEncodingsTest(unittest.TestCase): - """Position encodings conform to the position encodings protocol.""" - - @parameterized.parameterized.expand([ - (1, 100, 8), # typical - (1, 1, 8), # short - (2, 100, 8), # batched - ]) - def test_training(self, n, t, c): - encoding = pe.SinCosPositionalEncoding() - input_ntc = np.random.randn(n, t, c) - encoding.init(input_ntc) - output_ntc = encoding(input_ntc) - self.assertEqual(output_ntc.shape, input_ntc.shape) - - -if __name__ == '__main__': - unittest.main() diff --git a/trax/layers/research/rel_attention.py b/trax/layers/research/rel_attention.py index 19b25240d..e986c3609 100644 --- a/trax/layers/research/rel_attention.py +++ b/trax/layers/research/rel_attention.py @@ -25,483 +25,505 @@ from trax import fastmath from trax.fastmath import numpy as jnp -from trax.layers import base +from trax.layers import base, core from trax.layers import combinators as cb -from trax.layers import core from trax.layers import initializers as init from trax.layers.assert_shape import assert_shape -from trax.layers.attention import MergeHeads -from trax.layers.attention import SplitIntoHeads - +from trax.layers.attention import MergeHeads, SplitIntoHeads # Layers are always CamelCase, but functions in general are snake_case # pylint: disable=invalid-name -def RelativeAttentionWrapper(d_feature, - n_heads=1, - dropout=0.0, - max_inference_length=2048, - 
mode='train', - context_bias_layer=None, - location_bias_layer=None, - total_pooling=None): - """Relative attention wrapper. - - Args: - d_feature: Last/innermost dimension of activations in the input to and - output from this layer. - n_heads: Number of attention heads. Attention heads effectively split - activation vectors into ``n_heads`` subvectors, of size ``d_feature / - n_heads``. - dropout: dropout rate. - max_inference_length: max inference length. - mode: One of ``'train'``, ``'eval'``, or ``'predict'``. - context_bias_layer: context bias layer. - location_bias_layer: location bias layer. - total_pooling: total pooling. - - Returns: - relative attention layer. - - Relative attention wrapper for compatibility with configurable attention, - so that it can be called by `ApplyAttentionLayer`. - """ - del max_inference_length - - attention = RelativeAttentionLMLayer( - d_feature, - context_bias_layer, - location_bias_layer, - total_pooling, - n_heads=n_heads, - dropout=dropout, - mode=mode) - - return cb.Serial(cb.Select([0, 0, 0]), attention) +def RelativeAttentionWrapper( + d_feature, + n_heads=1, + dropout=0.0, + max_inference_length=2048, + mode="train", + context_bias_layer=None, + location_bias_layer=None, + total_pooling=None, +): + """Relative attention wrapper. + Args: + d_feature: Last/innermost dimension of activations in the input to and + output from this layer. + n_heads: Number of attention heads. Attention heads effectively split + activation vectors into ``n_heads`` subvectors, of size ``d_feature / + n_heads``. + dropout: dropout rate. + max_inference_length: max inference length. + mode: One of ``'train'``, ``'eval'``, or ``'predict'``. + context_bias_layer: context bias layer. + location_bias_layer: location bias layer. + total_pooling: total pooling. -def get_rel_att_inputs(d_model, n_heads): - """Global relative attentions bias initialization shared across layers.""" - assert d_model % n_heads == 0 and d_model % 2 == 0 - d_head = d_model // n_heads - - bias_initializer = init.RandomNormalInitializer(1e-6) - context_bias_layer = core.Weights( - bias_initializer, shape=(1, n_heads, 1, d_head)) - location_bias_layer = core.Weights( - bias_initializer, shape=(1, n_heads, 1, d_head)) - return context_bias_layer, location_bias_layer - - -@assert_shape('bSq,blk,blv,b1xl->bSd,b1xl') -def RelativeAttentionLayer(d_feature, - context_bias_layer, - location_bias_layer, - total_kv_pooling, - separate_cls, - n_heads=1, - dropout=0.0, - mode='train'): - """Returns a layer that maps (q, k, v, masks) to (activations, masks). - - When number of keys is smaller than number of queries layer works in O(q^2*d). - Otherwise it is O(q*k*d). That is because we need to shift relative distances - by current_pooling. When we upsample this is current pooling is a fraction < 1 - Visual explanation: - [01][23][45][67] -> [0][1][2][3][4][5][6][7] - For token [0] we calculate relative distances as follows: - * 0 2 4 6 - However for token [1] we need relative distances changed by 1, specifically: - * -1 1 3 5 - So we not only need to calculate the distances that corresponds to spacing - between the keys but also for the ones in between because there are more than - one query tokens (on different positions which means different relative - distances) for single key token. - - Args: - d_feature: Depth/dimensionality of feature embedding. - context_bias_layer: Global context bias from Transformer XL's attention. 
- There should be one such layer shared for all relative attention layers - location_bias_layer: Global location bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers. - total_kv_pooling: Accumulated pool size of keys/values used at this layer - separate_cls: True/False if we separate_cls in calculations. - - n_heads: Number of attention heads. - dropout: Probabilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - - return cb.Serial( - cb.Branch( - PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling), - cb.Select([0]), cb.Select([1])), - cb.Parallel( - core.Dense(d_feature), - core.Dense(d_feature), - core.Dense(d_feature), - core.Dense(d_feature), - ), - context_bias_layer, - location_bias_layer, - RelativeAttention( # pylint: disable=no-value-for-parameter - separate_cls=separate_cls, - n_heads=n_heads, - dropout=dropout, - mode=mode), - core.Dense(d_feature), - ) - - -@assert_shape('bSq,blk,blv->bSd') -def RelativeAttentionLMLayer(d_feature, - context_bias_layer, - location_bias_layer, - total_kv_pooling, - separate_cls=False, - n_heads=1, - dropout=0.0, - mode='train'): - """Returns a layer that maps (q, k, v) to (activations). - - Same as standard Relative attention layer but additionally based on sizes - of queries and keys prepares a mask that masks out the future. - Masking the future is the concept primarily used for Language Modelling. - Args: - d_feature: Depth/dimensionality of feature embedding. - context_bias_layer: Global context bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers - location_bias_layer: Global location bias from Transformer XL's attention. - There should be one such layer shared for all relative attention layers. - total_kv_pooling: Accumulated pool size of keys/values used at this layer. - separate_cls: True/False if we separate_cls in calculations. - n_heads: Number of attention heads. - dropout: Probabilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - mode: One of `'train'`, `'eval'`, or `'predict'`. - """ - - attention = RelativeAttentionLayer( - d_feature, - context_bias_layer, - location_bias_layer, - total_kv_pooling, - separate_cls, - n_heads=n_heads, - dropout=dropout, - mode=mode) - - return cb.Serial( - CreateAttentionMaskLayer(), # q, k, v, mask - attention, # vecs, mask - cb.Select([0], n_in=2), # vecs - ) + Returns: + relative attention layer. + + Relative attention wrapper for compatibility with configurable attention, + so that it can be called by `ApplyAttentionLayer`. + """ + del max_inference_length + attention = RelativeAttentionLMLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_pooling, + n_heads=n_heads, + dropout=dropout, + mode=mode, + ) -class RelativeAttention(base.Layer): - """Relative attention layer. - - Layer that maps (location_bias, context_bias, pos_emb, q, k, v, mask) - to (activations, mask). - This layer type performs the inner workings of one pass of multi-head - self-attention. 
It: - - splits queries, keys, and values into multiple 'heads', - - splits positional embeddings into multiple 'heads', - - computes per-head attention weights from per-head (queries, keys), - - applies mask to screen out positions that come from padding tokens, - - [in `'train'` mode] applies dropout to attention weights, - - uses attention weights to combine per-head values vectors, and - - merges per-head results into outgoing activations matching original input - activation vector shapes. - """ - - def __init__(self, separate_cls, n_heads=1, dropout=0.0, mode='train'): - """Returns a new PureAttention instance. + return cb.Serial(cb.Select([0, 0, 0]), attention) + + +def get_rel_att_inputs(d_model, n_heads): + """Global relative attentions bias initialization shared across layers.""" + assert d_model % n_heads == 0 and d_model % 2 == 0 + d_head = d_model // n_heads + + bias_initializer = init.RandomNormalInitializer(1e-6) + context_bias_layer = core.Weights(bias_initializer, shape=(1, n_heads, 1, d_head)) + location_bias_layer = core.Weights(bias_initializer, shape=(1, n_heads, 1, d_head)) + return context_bias_layer, location_bias_layer + + +@assert_shape("bSq,blk,blv,b1xl->bSd,b1xl") +def RelativeAttentionLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_kv_pooling, + separate_cls, + n_heads=1, + dropout=0.0, + mode="train", +): + """Returns a layer that maps (q, k, v, masks) to (activations, masks). + + When the number of keys is smaller than the number of queries, the layer works + in O(q^2*d); otherwise it is O(q*k*d). That is because we need to shift relative + distances by current_pooling, and when we upsample, this current pooling is a + fraction < 1. + Visual explanation: + [01][23][45][67] -> [0][1][2][3][4][5][6][7] + For token [0] we calculate relative distances as follows: + * 0 2 4 6 + However for token [1] we need relative distances changed by 1, specifically: + * -1 1 3 5 + So we need to calculate not only the distances that correspond to the spacing + between the keys, but also the ones in between, because there is more than + one query token (at different positions, which means different relative + distances) per single key token. Args: + d_feature: Depth/dimensionality of feature embedding. + context_bias_layer: Global context bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers. + location_bias_layer: Global location bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers. + total_kv_pooling: Accumulated pool size of keys/values used at this layer separate_cls: True/False if we separate_cls in calculations. + n_heads: Number of attention heads. - dropout: Probabilistic rate for dropout applied to attention strengths - (based on query-key pairs) before applying them to values. + dropout: Probabilistic rate for internal dropout applied to attention + activations (based on query-key pairs) before dotting them with values. mode: One of `'train'`, `'eval'`, or `'predict'`. """ - super().__init__(n_in=7, n_out=2) - self._separate_cls = separate_cls - self._n_heads = n_heads - self._dropout = dropout - self._mode = mode - - def forward(self, inputs): - """Returns attention-computed activations and unmodified mask.
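+    # Stack flow in the Serial below: cb.Branch maps (q, k, v, mask) to
+    # (pos_emb, q, k, v, mask) by computing positional embeddings from (q, k);
+    # cb.Parallel applies a Dense projection to pos_emb, q, k and v; the two
+    # Weights layers push the shared context/location biases onto the stack;
+    # RelativeAttention then consumes all seven inputs and returns
+    # (activations, mask).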
+ return cb.Serial( + cb.Branch( + PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling), + cb.Select([0]), + cb.Select([1]), + ), + cb.Parallel( + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + ), + context_bias_layer, + location_bias_layer, + RelativeAttention( # pylint: disable=no-value-for-parameter + separate_cls=separate_cls, n_heads=n_heads, dropout=dropout, mode=mode + ), + core.Dense(d_feature), + ) + + +@assert_shape("bSq,blk,blv->bSd") +def RelativeAttentionLMLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_kv_pooling, + separate_cls=False, + n_heads=1, + dropout=0.0, + mode="train", +): + """Returns a layer that maps (q, k, v) to (activations). + + Same as the standard relative attention layer, but it additionally prepares a + mask, based on the sizes of the queries and keys, that masks out the future. + Masking the future is the concept primarily used for Language Modelling. Args: + d_feature: Depth/dimensionality of feature embedding. + context_bias_layer: Global context bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers. + location_bias_layer: Global location bias from Transformer XL's attention. + There should be one such layer shared for all relative attention layers. + total_kv_pooling: Accumulated pool size of keys/values used at this layer. + separate_cls: True/False if we separate_cls in calculations.
- dropout: Probabilistic rate for dropout applied to attention strengths - (based on query-key pairs) before applying them to values. - mode: One of `'train'`, `'eval'`, or `'predict'`. - rng: Single-use random number generator (JAX PRNG key). - - Returns: - Per-head activations resulting from masked per-head attention-weighted - sum of per-head values. - - This function is the core of the attention mechanism. It: - - computes per-head attention weights from per-head `queries` and `keys`, - - applies `mask` to screen out positions that come from padding tokens, - - optionally applies dropout to attention weights, and - - uses attention weights to combine per-head `values` vectors. - """ - d_feature = queries.shape[-1] - keys_len, queries_len = keys.shape[-2], queries.shape[-2] - funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - - ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys) - bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb) - bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling) - - if separate_cls: - # Masking out location part of attention for cls token - bd = bd.at[:, :, :, 0].set(0) - bd = bd.at[:, :, 0, :].set(0) - - dots = (ac + bd) / jnp.sqrt(d_feature) - if mask is not None: - dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) - # Softmax. - dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)) - if dropout >= 1.0: - raise ValueError('Dropout rates must be lower than 1.') - if dropout is not None and dropout > 0.0 and mode == 'train': - keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape) - dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots)) - out = jnp.matmul(dots, values) - out = out.astype(jnp.float32) - dots = dots.astype(jnp.float32) - return out, dots + attention = RelativeAttentionLayer( + d_feature, + context_bias_layer, + location_bias_layer, + total_kv_pooling, + separate_cls, + n_heads=n_heads, + dropout=dropout, + mode=mode, + ) -def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling): - """Positional embeddings. + return cb.Serial( + CreateAttentionMaskLayer(), # q, k, v, mask + attention, # vecs, mask + cb.Select([0], n_in=2), # vecs + ) - Args: - d_feature: Depth/dimensionality of feature embedding. - separate_cls: True/False if we separate_cls in calculations. - total_kv_pooling: Accumulated pool size of keys/values until this layer. - Returns: - a layer that based on queries, keys and accumulated pool size of - keys/values until this layer calculates sinusoidal positional embeddings - for relative attention calculations. - """ +class RelativeAttention(base.Layer): + """Relative attention layer. + + Layer that maps (location_bias, context_bias, pos_emb, q, k, v, mask) + to (activations, mask). + This layer type performs the inner workings of one pass of multi-head + self-attention. It: + - splits queries, keys, and values into multiple 'heads', + - splits positional embeddings into multiple 'heads', + - computes per-head attention weights from per-head (queries, keys), + - applies mask to screen out positions that come from padding tokens, + - [in `'train'` mode] applies dropout to attention weights, + - uses attention weights to combine per-head values vectors, and + - merges per-head results into outgoing activations matching original input + activation vector shapes. 
+ """ - def PositionsVectors(queries, keys): - assert not separate_cls + def __init__(self, separate_cls, n_heads=1, dropout=0.0, mode="train"): + """Returns a new PureAttention instance. + + Args: + separate_cls: True/False if we separate_cls in calculations. + n_heads: Number of attention heads. + dropout: Probabilistic rate for dropout applied to attention strengths + (based on query-key pairs) before applying them to values. + mode: One of `'train'`, `'eval'`, or `'predict'`. + """ + super().__init__(n_in=7, n_out=2) + self._separate_cls = separate_cls + self._n_heads = n_heads + self._dropout = dropout + self._mode = mode + + def forward(self, inputs): + """Returns attention-computed activations and unmodified mask. + + Args: + inputs: A (location_bias, context_bias, pos_emb, q, k, v, mask) tuple. + """ + location_bias, context_bias, pos_emb, q, k, v, mask = inputs + + d_feature = q.shape[-1] + n_heads = self._n_heads + if d_feature % n_heads != 0: + raise ValueError( + f"Dimensionality of feature embedding ({d_feature}) is not a " + f"multiple of the requested number of attention heads ({n_heads})." + ) + + per_head_results, dots = DotProductAttention( + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(q), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(k), + SplitIntoHeads(n_heads, merged_batch_and_head=False).forward(v), + pos_emb.reshape((-1, n_heads, d_feature // n_heads)), + context_bias, + location_bias, + mask, + separate_cls=self._separate_cls, + dropout=self._dropout, + mode=self._mode, + rng=self.rng, + ) + if self._mode == "viz": + self.state = dots + merged_results = MergeHeads(n_heads, merged_batch_and_head=False).forward( + per_head_results + ) + return merged_results, mask + + +def DotProductAttention( + queries, + keys, + values, + pos_emb, + context_bias, + location_bias, + mask, + separate_cls, + dropout, + mode, + rng, +): + """Computes new activations via masked attention-weighted sum of values. + Args: + queries: Per-head activations representing attention queries. + keys: Per-head activations representing attention keys. + values: Per-head activations to be combined by computed attention weights. + pos_emb: Per-head activations representing positional embeddings. + context_bias: Global context bias from Transformer XL's attention. + location_bias: Global location bias from Transformer XL's attention. + mask: Mask that distinguishes positions with real content vs. padding. + separate_cls: True/False if we separate_cls in calculations. + dropout: Probabilistic rate for dropout applied to attention strengths + (based on query-key pairs) before applying them to values. + mode: One of `'train'`, `'eval'`, or `'predict'`. + rng: Single-use random number generator (JAX PRNG key). + + Returns: + Per-head activations resulting from masked per-head attention-weighted + sum of per-head values. + + This function is the core of the attention mechanism. It: + - computes per-head attention weights from per-head `queries` and `keys`, + - applies `mask` to screen out positions that come from padding tokens, + - optionally applies dropout to attention weights, and + - uses attention weights to combine per-head `values` vectors. 
+ """ + d_feature = queries.shape[-1] keys_len, queries_len = keys.shape[-2], queries.shape[-2] funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - if funnel_factor == 1: - offset = keys_len - 1 - positions = (jnp.arange(keys_len) - offset) * total_kv_pooling - else: - if is_upsampling: - positions = jnp.arange(-queries_len + 1, queries_len, 1.0) - else: - positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling + ac = jnp.einsum("bnid,bnjd->bnij", queries + context_bias, keys) + bd = jnp.einsum("bnid,jnd->bnij", queries + location_bias, pos_emb) + bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling) + + if separate_cls: + # Masking out location part of attention for cls token + bd = bd.at[:, :, :, 0].set(0) + bd = bd.at[:, :, 0, :].set(0) + + dots = (ac + bd) / jnp.sqrt(d_feature) + if mask is not None: + dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) + # Softmax. + dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)) + if dropout >= 1.0: + raise ValueError("Dropout rates must be lower than 1.") + if dropout is not None and dropout > 0.0 and mode == "train": + keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape) + dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots)) + out = jnp.matmul(dots, values) + out = out.astype(jnp.float32) + dots = dots.astype(jnp.float32) + return out, dots - return positions - def Sinusoidal_Embeddings(positions): - inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature)) - sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq) - pos_emb = jnp.concatenate( - [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1) - return pos_emb +def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling): + """Positional embeddings. - return cb.Serial( - cb.Fn('Generate positions vectors', PositionsVectors, n_out=1), - cb.Fn( - 'Transform to sinusoidal encodings', Sinusoidal_Embeddings, n_out=1)) + Args: + d_feature: Depth/dimensionality of feature embedding. + separate_cls: True/False if we separate_cls in calculations. + total_kv_pooling: Accumulated pool size of keys/values until this layer. + Returns: + a layer that based on queries, keys and accumulated pool size of + keys/values until this layer calculates sinusoidal positional embeddings + for relative attention calculations. 
+ """ -def calc_funnel_ratio(keys_len, queries_len): - """Calculate funnel ratio.""" + def PositionsVectors(queries, keys): + assert not separate_cls - if queries_len > keys_len: # Upsampling - assert queries_len % keys_len == 0 - funnel_factor = queries_len // keys_len - is_upsampling = True - else: # Downsampling - assert keys_len % queries_len == 0 - funnel_factor = keys_len // queries_len - is_upsampling = False + keys_len, queries_len = keys.shape[-2], queries.shape[-2] + funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - return funnel_factor, is_upsampling + if funnel_factor == 1: + offset = keys_len - 1 + positions = (jnp.arange(keys_len) - offset) * total_kv_pooling + else: + if is_upsampling: + positions = jnp.arange(-queries_len + 1, queries_len, 1.0) + else: + positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling + return positions -def _fast_matrix_shift(x, funnel_factor=1, is_upsampling=False): - """Fast matrix shift.""" + def Sinusoidal_Embeddings(positions): + inv_freq = 1 / (10000 ** (jnp.arange(0.0, d_feature, 2.0) / d_feature)) + sinusoid_freq = jnp.einsum("i,j->ij", positions, inv_freq) + pos_emb = jnp.concatenate( + [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1 + ) + return pos_emb - if funnel_factor == 1 and not is_upsampling: - shift = 1 - batch_size, n_head = x.shape[0], x.shape[1] - queries_len, keys_len = x.shape[2], x.shape[3] - zero_pad = jnp.zeros((batch_size, n_head, queries_len, shift)) - x = jnp.concatenate([zero_pad, x], axis=3) - x = x.reshape(batch_size, n_head, keys_len + shift, queries_len) - x = x[:, :, shift:, :] - return x + return cb.Serial( + cb.Fn("Generate positions vectors", PositionsVectors, n_out=1), + cb.Fn("Transform to sinusoidal encodings", Sinusoidal_Embeddings, n_out=1), + ) - if is_upsampling: - k = funnel_factor - shift = 1 - else: - k = 1 - shift = funnel_factor - bsz, n_head = x.shape[0], x.shape[1] - qlen, klen = x.shape[2], (x.shape[3] + 1) // 2 +def calc_funnel_ratio(keys_len, queries_len): + """Calculate funnel ratio.""" - zero_pad = jnp.zeros((bsz, n_head, qlen, shift)) - x = jnp.concatenate([zero_pad, x], axis=3) - x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen) - x = x[:, :, shift:, :] - x = x.reshape(bsz, n_head, qlen, klen * 2 - 1) - x = x[:, :, :, shift - 1:shift - 1 + klen:k] - return x + if queries_len > keys_len: # Upsampling + assert queries_len % keys_len == 0 + funnel_factor = queries_len // keys_len + is_upsampling = True + else: # Downsampling + assert keys_len % queries_len == 0 + funnel_factor = keys_len // queries_len + is_upsampling = False + return funnel_factor, is_upsampling -@assert_shape('bqd,bkd,bvd->bqd,bkd,bvd,b1qk') -def CreateAttentionMaskLayer(): - """Creates attention mask layer. - Returns a layer that based on queries, keys and accumulated pool size of - keys/values until this layer calculates positional embeddings for - causal relative attention calculations. 
+def _fast_matrix_shift(x, funnel_factor=1, is_upsampling=False): + """Fast matrix shift.""" + + if funnel_factor == 1 and not is_upsampling: + shift = 1 + batch_size, n_head = x.shape[0], x.shape[1] + queries_len, keys_len = x.shape[2], x.shape[3] + zero_pad = jnp.zeros((batch_size, n_head, queries_len, shift)) + x = jnp.concatenate([zero_pad, x], axis=3) + x = x.reshape(batch_size, n_head, keys_len + shift, queries_len) + x = x[:, :, shift:, :] + return x + + if is_upsampling: + k = funnel_factor + shift = 1 + else: + k = 1 + shift = funnel_factor - Takes as input q, k, v and appends proper mask in the end. - Causal attention uses masking to prevent a given sequence position from - attending to positions greater than / following it. This is used, for - example, when training autoregressive sequence models, or when decoding a - sequence symbol by symbol. + bsz, n_head = x.shape[0], x.shape[1] + qlen, klen = x.shape[2], (x.shape[3] + 1) // 2 - Returns: - an attention mask layer. - """ + zero_pad = jnp.zeros((bsz, n_head, qlen, shift)) + x = jnp.concatenate([zero_pad, x], axis=3) + x = x.reshape(bsz, n_head, 2 * klen - 1 + shift, qlen) + x = x[:, :, shift:, :] + x = x.reshape(bsz, n_head, qlen, klen * 2 - 1) + x = x[:, :, :, shift - 1 : shift - 1 + klen : k] + return x - def calculate_mask(queries, keys): - batch_size = queries.shape[0] - keys_len, queries_len = keys.shape[-2], queries.shape[-2] - funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) - return _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, - is_upsampling) +@assert_shape("bqd,bkd,bvd->bqd,bkd,bvd,b1qk") +def CreateAttentionMaskLayer(): + """Creates attention mask layer. - def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, - is_upsampling): - """Funnel mask. + Returns a layer that based on queries, keys and accumulated pool size of + keys/values until this layer calculates positional embeddings for + causal relative attention calculations. - Args: - batch_size: batch size. - keys_len: keys length. - queries_len: queries length. - funnel_factor: funnel factor. - is_upsampling: True or False. + Takes as input q, k, v and appends proper mask in the end. + Causal attention uses masking to prevent a given sequence position from + attending to positions greater than / following it. This is used, for + example, when training autoregressive sequence models, or when decoding a + sequence symbol by symbol. Returns: - funnel mask. - - This function based on keys/queries lengths creates a triangle mask - that prevents tokens from attending to positions following it. - - If funnel_factor is not equal to 1 due to funnel upsampling or - downsampling it adjusts created mask for funnel attention - by repeating each element funnel_factor times. - - This is because after funnel layer one token attends to funnel_factor - different tokens in downsampling. During upsampling on the other hand - funnel_factor tokens are attending to single token before upsampling. + an attention mask layer. 
""" - if funnel_factor != 1: - if not is_upsampling: - mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) - mask = jnp.repeat(mask, funnel_factor, axis=-1) - else: - mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_)) - mask = jnp.repeat(mask, funnel_factor, axis=-2) - else: - mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) - - return jnp.repeat(mask[None, None, :, :], batch_size, axis=0) - - return cb.Branch( - cb.Select([0]), cb.Select([1]), cb.Select([2]), - cb.Fn('create attention mask layer', calculate_mask, n_out=1)) - - -@assert_shape('...d->...d') + def calculate_mask(queries, keys): + batch_size = queries.shape[0] + keys_len, queries_len = keys.shape[-2], queries.shape[-2] + funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len) + + return _funnel_mask( + batch_size, keys_len, queries_len, funnel_factor, is_upsampling + ) + + def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor, is_upsampling): + """Funnel mask. + + Args: + batch_size: batch size. + keys_len: keys length. + queries_len: queries length. + funnel_factor: funnel factor. + is_upsampling: True or False. + + Returns: + funnel mask. + + This function based on keys/queries lengths creates a triangle mask + that prevents tokens from attending to positions following it. + + If funnel_factor is not equal to 1 due to funnel upsampling or + downsampling it adjusts created mask for funnel attention + by repeating each element funnel_factor times. + + This is because after funnel layer one token attends to funnel_factor + different tokens in downsampling. During upsampling on the other hand + funnel_factor tokens are attending to single token before upsampling. + """ + + if funnel_factor != 1: + if not is_upsampling: + mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) + mask = jnp.repeat(mask, funnel_factor, axis=-1) + else: + mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_)) + mask = jnp.repeat(mask, funnel_factor, axis=-2) + else: + mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)) + + return jnp.repeat(mask[None, None, :, :], batch_size, axis=0) + + return cb.Branch( + cb.Select([0]), + cb.Select([1]), + cb.Select([2]), + cb.Fn("create attention mask layer", calculate_mask, n_out=1), + ) + + +@assert_shape("...d->...d") def ShiftRightCls(cls_id): - """Shifts right and insert cls. + """Shifts right and insert cls. - Args: - cls_id: id of the cls token in embedding dictionary. Returns a layer that - shifts input tokens to the right by one and inserts an cls token to the - beginning like in BERT paper. + Args: + cls_id: id of the cls token in embedding dictionary. Returns a layer that + shifts input tokens to the right by one and inserts an cls token to the + beginning like in BERT paper. - Returns: - layer shifting to right and inserting cls. - """ + Returns: + layer shifting to right and inserting cls. 
+ """ - def shift_right(x): - pad_widths = [(0, 0)] * len(x.shape) - pad_widths[1] = (1, 0) - padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(cls_id)) - return padded[:, :-1] + def shift_right(x): + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[1] = (1, 0) + padded = jnp.pad( + x, pad_widths, mode="constant", constant_values=x.dtype.type(cls_id) + ) + return padded[:, :-1] - return cb.Fn('ShiftRightCls()', shift_right) + return cb.Fn("ShiftRightCls()", shift_right) diff --git a/trax/layers/research/resampling.py b/trax/layers/research/resampling.py index e1866fd8c..2fc8b4ea3 100644 --- a/trax/layers/research/resampling.py +++ b/trax/layers/research/resampling.py @@ -24,109 +24,135 @@ def AveragePooling(shorten_factor, *args, **kwargs): - del args, kwargs - - return AvgPool(pool_size=(shorten_factor,), strides=(shorten_factor,)) - - -def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode='train', - **kwargs): - del args, kwargs - - return cb.Serial( - core.Fn( - 'Shorten', - lambda x: jnp.reshape( # pylint: disable=g-long-lambda - # Shorten -- move to depth. # pylint: disable=g-long-lambda - x, (x.shape[0], x.shape[1] // shorten_factor, -1)), - n_out=1), - core.Dense(d_model), - core.Dropout(rate=dropout, mode=mode) - ) - - -def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train', - **kwargs): - del args, kwargs - - return cb.Serial( - core.Dense(shorten_factor * d_model), - core.Dropout(rate=dropout, mode=mode), - core.Fn( - 'ProlongBack', - lambda x: jnp.reshape( # pylint: disable=g-long-lambda - # Prolong back. # pylint: disable=g-long-lambda - x, (x.shape[0], x.shape[1] * shorten_factor, -1)), - n_out=1) - ) - - -def NaiveUpsampling(shorten_factor, d_model, *args, **kwargs): # pylint: disable = unused-argument - return core.Fn('Repeat', lambda x: jnp.repeat(x, shorten_factor, axis=1)) + del args, kwargs + + return AvgPool(pool_size=(shorten_factor,), strides=(shorten_factor,)) + + +def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode="train", **kwargs): + del args, kwargs + + return cb.Serial( + core.Fn( + "Shorten", + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + # Shorten -- move to depth. # pylint: disable=g-long-lambda + x, + (x.shape[0], x.shape[1] // shorten_factor, -1), + ), + n_out=1, + ), + core.Dense(d_model), + core.Dropout(rate=dropout, mode=mode), + ) + + +def LinearUpsampling( + shorten_factor, d_model, *args, dropout=0.0, mode="train", **kwargs +): + del args, kwargs + + return cb.Serial( + core.Dense(shorten_factor * d_model), + core.Dropout(rate=dropout, mode=mode), + core.Fn( + "ProlongBack", + lambda x: jnp.reshape( # pylint: disable=g-long-lambda + # Prolong back. 
# pylint: disable=g-long-lambda + x, + (x.shape[0], x.shape[1] * shorten_factor, -1), + ), + n_out=1, + ), + ) + + +def NaiveUpsampling( + shorten_factor, d_model, *args, **kwargs +): # pylint: disable = unused-argument + return core.Fn("Repeat", lambda x: jnp.repeat(x, shorten_factor, axis=1)) def NoUpsampling(shorten_factor, d_model, *args, **kwargs): - del d_model, args, kwargs - - return core.Fn('ReturnZero', lambda x: jnp.zeros( # pylint: disable=g-long-lambda - (x.shape[0], x.shape[1] * shorten_factor, x.shape[2]), dtype=x.dtype)) - - -def FeedForwardBlock(d_model, - d_ff, - dropout, - dropout_shared_axes, - mode, - activation): - # We copy the ff block function because we cannot import it from models - return [ - core.Dense(d_ff), - activation(), - core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, - mode=mode), - core.Dense(d_model), - ] - - -def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads, - dropout, dropout_shared_axes, mode, ff_activation, - context_bias_layer, location_bias_layer, total_pooling, - resampling_fn): - """Attention resampling.""" - - attention = RelativeAttentionLMLayer( - d_model, context_bias_layer, location_bias_layer, - total_pooling, n_heads=n_heads, dropout=dropout, - mode=mode) - - feed_forward = FeedForwardBlock( - d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) - - resampling = resampling_fn(shorten_factor, d_model, - mode=mode) - - def _Dropout(): - return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, - mode=mode) - - return [ - LayerNorm(), # h - cb.Branch(cb.Serial( - resampling, - LayerNorm(), - ), None), # h', h - cb.Serial( # pylint: disable=g-long-ternary - cb.Select([0, 2, 1, 2]), - cb.Add(), - ) if is_upsampling else [], - cb.Residual( - cb.Select([0, 1, 1]), # h', h, h - attention, - _Dropout(), - ), - cb.Residual( - LayerNorm(), - feed_forward, - _Dropout(), - ), - ] + del d_model, args, kwargs + + return core.Fn( + "ReturnZero", + lambda x: jnp.zeros( # pylint: disable=g-long-lambda + (x.shape[0], x.shape[1] * shorten_factor, x.shape[2]), dtype=x.dtype + ), + ) + + +def FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, activation): + # We copy the ff block function because we cannot import it from models + return [ + core.Dense(d_ff), + activation(), + core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + core.Dense(d_model), + ] + + +def AttentionResampling( + shorten_factor, + d_model, + is_upsampling, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + context_bias_layer, + location_bias_layer, + total_pooling, + resampling_fn, +): + """Attention resampling.""" + + attention = RelativeAttentionLMLayer( + d_model, + context_bias_layer, + location_bias_layer, + total_pooling, + n_heads=n_heads, + dropout=dropout, + mode=mode, + ) + + feed_forward = FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation + ) + + resampling = resampling_fn(shorten_factor, d_model, mode=mode) + + def _Dropout(): + return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + LayerNorm(), # h + cb.Branch( + cb.Serial( + resampling, + LayerNorm(), + ), + None, + ), # h', h + cb.Serial( # pylint: disable=g-long-ternary + cb.Select([0, 2, 1, 2]), + cb.Add(), + ) + if is_upsampling + else [], + cb.Residual( + cb.Select([0, 1, 1]), # h', h, h + attention, + _Dropout(), + ), + cb.Residual( + LayerNorm(), + feed_forward, + _Dropout(), + ), + ] diff --git 
a/trax/layers/research/rotary_positional_embedding.py b/trax/layers/research/rotary_positional_embedding.py index dbb08fcea..8e9b1b3bd 100644 --- a/trax/layers/research/rotary_positional_embedding.py +++ b/trax/layers/research/rotary_positional_embedding.py @@ -19,30 +19,29 @@ https://arxiv.org/pdf/2104.09864.pdf """ -# from trax import layers as tl from trax.fastmath import numpy as jnp from trax.layers import core def rotate(x): - """Rotate function.""" - _, l, d = x.shape - inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d)) - positions = jnp.arange(l) - freqs = jnp.einsum('i,j->ij', positions, inv_freq) - emb = jnp.concatenate((freqs, freqs), axis=-1) - cos = jnp.cos(emb) - sin = jnp.sin(emb) + """Rotate function.""" + _, l, d = x.shape + inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d)) + positions = jnp.arange(l) + freqs = jnp.einsum("i,j->ij", positions, inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + cos = jnp.cos(emb) + sin = jnp.sin(emb) - def mul(vecs, pos_emb): - return jnp.einsum('bld,ld->bld', vecs, pos_emb) + def mul(vecs, pos_emb): + return jnp.einsum("bld,ld->bld", vecs, pos_emb) - def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return jnp.concatenate((-x2, x1), axis=x1.ndim - 1) + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return jnp.concatenate((-x2, x1), axis=x1.ndim - 1) - return mul(x, cos) + mul(rotate_half(x), sin) + return mul(x, cos) + mul(rotate_half(x), sin) def Rotate(): # pylint: disable=invalid-name - return core.Fn('Rotate', rotate) + return core.Fn("Rotate", rotate) diff --git a/trax/layers/research/sparsity.py b/trax/layers/research/sparsity.py index 1ac1c8ca4..463e01f4f 100644 --- a/trax/layers/research/sparsity.py +++ b/trax/layers/research/sparsity.py @@ -18,1716 +18,1996 @@ import functools import math import random as pyrandom + import numpy as np from trax import fastmath from trax import layers as tl from trax.fastmath import numpy as jnp from trax.fastmath import random -from trax.layers import base -from trax.layers import core +from trax.layers import base, core, reversible from trax.layers import initializers as init -from trax.layers import reversible from trax.layers.assert_shape import assert_shape - # We use mixed CamelCase and snake_case names in this file. # pylint: disable=invalid-name -@assert_shape('...->...') +@assert_shape("...->...") class ReversibleReshapePermute(reversible.ReversibleLayer): - """Simple and fast, reversible, random-looking permutation layer. - - This layer permutates the last dimension (usually the embedding dimension) - with simple reshapes. It uses the same permutation for every embedding, and - permutation never changes. - The layer works only when the last dimension is a power of 2. The - permutation is not truly random, as it just uses reshapes to get a fast - random-looking permutation. It has, however, a permutation cycle length - of just log2(dimension_size). - """ - - def forward(self, x): - shape = x.shape - x = x.reshape(shape[:-1]+(-1, self._get_multiplier(x))) - t_x = jnp.einsum('...ab->...ba', x) # transpose - return t_x.reshape(shape) - - def reverse(self, x, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng - shape = x.shape - x = x.reshape(shape[:-1]+(self._get_multiplier(x), -1)) - t_x = jnp.einsum('...ab->...ba', x) # transpose - return t_x.reshape(shape) - - def _get_multiplier(self, x): - """Return a size of the new dimension for reshaping. 
- - We want to split the last dimension into two using approximately equal - dimensions, we could split a dimension of size 512 into 16 * 32. - However, not all numbers will work equally well, because we have a different - cycle length for permutations for different numbers. For example, for - dimension size 1024 and multiplier 32 we would get the same permutation - already after applying permutation twice (cycle length is 2), but with - multiplier 8 we would get the same permutation after appling permutation 10 - times (cycle length is 10). - - For powers of two the cycle length is limited by log2(dimension_size). - This function returns the biggest multiplier smaller than - sqrt(dimension_size) that keeps the longest possible cycle lenght of the - permutation. + """Simple and fast, reversible, random-looking permutation layer. + + This layer permutes the last dimension (usually the embedding dimension) + with simple reshapes. It uses the same permutation for every embedding, and + the permutation never changes. + The layer works only when the last dimension is a power of 2. The + permutation is not truly random, as it just uses reshapes to get a fast + random-looking permutation. It has, however, a permutation cycle length + of just log2(dimension_size). + """ - Args: - x: The input tensor. + def forward(self, x): + shape = x.shape + x = x.reshape(shape[:-1] + (-1, self._get_multiplier(x))) + t_x = jnp.einsum("...ab->...ba", x) # transpose + return t_x.reshape(shape) + + def reverse(self, x, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng + shape = x.shape + x = x.reshape(shape[:-1] + (self._get_multiplier(x), -1)) + t_x = jnp.einsum("...ab->...ba", x) # transpose + return t_x.reshape(shape) + + def _get_multiplier(self, x): + """Return a size of the new dimension for reshaping. + + We want to split the last dimension into two using approximately equal + dimensions; for example, we could split a dimension of size 512 into 16 * 32. + However, not all numbers will work equally well, because we have a different + cycle length for permutations for different numbers. For example, for + dimension size 1024 and multiplier 32 we would get the same permutation + already after applying the permutation twice (cycle length is 2), but with + multiplier 8 we would get the same permutation after applying the permutation + 10 times (cycle length is 10). + + For powers of two the cycle length is limited by log2(dimension_size). + This function returns the biggest multiplier smaller than + sqrt(dimension_size) that keeps the longest possible cycle length of the + permutation. + + Args: + x: The input tensor. + + Returns: + An appropriate multiplier for the permutation reshape. + """ + last_dim = x.shape[-1] + + def big_relatively_prime(n): + # The longest possible cycle is achieved iff log2(multiplier) and + # log2(dimension_size) are relatively prime. We choose the biggest such + # number smaller than sqrt(dimension_size). + for i in range(n // 2, 0, -1): + if n % i != 0: + return i + return 1 + + max_cycle_len = int(math.log(last_dim, 2)) + assert 2**max_cycle_len == last_dim + + return 2 ** big_relatively_prime(max_cycle_len) + + +@assert_shape("...->...") +class ReversibleRandomPermute(reversible.ReversibleLayer): + """Reversible, random permutation layer. - This layer permutates the last dimension (usually the embedding dimension) - by indexing and slicing.
It uses the same random permutation for every + embedding, and this permutation never changes. """ - last_dim = x.shape[-1] - def big_relatively_prime(n): - # The longest possible cycle is achieved iff log2(multiplier) and - # log2(dimension_size) are relatively prime. We choose the biggest such - # number smaller than sqrt(dimension_size). - for i in range(n//2, 0, -1): - if n%i != 0: - return i - return 1 + def forward(self, x): + permutation, _ = self._get_permutation_and_reverse_permutation(x) + return x[..., permutation] - max_cycle_len = int(math.log(last_dim, 2)) - assert 2 ** max_cycle_len == last_dim + def reverse(self, x, weights=(), state=(), new_state=(), rng=None): + _, rev_permutation = self._get_permutation_and_reverse_permutation(x) + return x[..., rev_permutation] - return 2 ** big_relatively_prime(max_cycle_len) + def _get_permutation_and_reverse_permutation(self, x): + # TODO(jaszczur): random seed should be stored in state. + # Currently there is no way of doing it reliably. + last_dim = x.shape[-1] + permutation = list(range(last_dim)) + rand = pyrandom.Random(42) + rand.shuffle(permutation) + rev_permutation = [permutation.index(i) for i in range(last_dim)] + return permutation, rev_permutation -@assert_shape('...->...') -class ReversibleRandomPermute(reversible.ReversibleLayer): - """Reversible, random permutation layer. +@assert_shape("...a->...bc") +def SplitLastAxis(num_splits): + return tl.Fn( + f"SplitLastAxis_{num_splits}", + lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1)), + ) - This layer permutates the last dimension (usually the embedding dimension) - by indexing and slicing. It uses the same random permutation for every - embedding, and this permutation never changes. - """ - def forward(self, x): - permutation, _ = self._get_permutation_and_reverse_permutation(x) - return x[..., permutation] +@assert_shape("...ab->...c") +def MergeLastTwoAxes(): + return tl.Fn( + "MergeLastTwoAxes", lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,)) + ) + - def reverse(self, x, weights=(), state=(), new_state=(), rng=None): - _, rev_permutation = self._get_permutation_and_reverse_permutation(x) - return x[..., rev_permutation] +@assert_shape("...a->...b") +def LocallyConnectedDense( + n_modules, + n_units, + kernel_size=1, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6), + use_bias=True, +): + """Layer using LocallyConnected1d for approximation of Dense layer. - def _get_permutation_and_reverse_permutation(self, x): - # TODO(jaszczur): random seed should be stored in state. - # Currently there is no way of doing it reliably. - last_dim = x.shape[-1] - permutation = list(range(last_dim)) - rand = pyrandom.Random(42) - rand.shuffle(permutation) - rev_permutation = [permutation.index(i) for i in range(last_dim)] - return permutation, rev_permutation + The layer splits the last axis of a tensor into `n_modules`, then runs + LocallyConnected1d (grouped convolution) on all those modules, and + concatenates their results. It is essentially a locally-sensitive + approximation of Dense layer, with number of parameters smaller by the factor + of `n_modules / kernel_size`. + Args: + n_modules: Indicates how many modules (pixels) should be input and output + split into for processing. + n_units: how many outputs (filters) should each module generate. + kernel_size: The size of the kernel to be used. 
+ kernel_initializer: Function that creates a matrix of (random) initial + connection weights `W` for the layer. + bias_initializer: Function that creates a vector of (random) initial + bias weights `b` for the layer. + use_bias: If `True`, compute an affine map `y = Wx + b`; else compute + a linear map `y = Wx`. -@assert_shape('...a->...bc') -def SplitLastAxis(num_splits): - return tl.Fn(f'SplitLastAxis_{num_splits}', - lambda x: jnp.reshape(x, tuple(x.shape)[:-1] + (num_splits, -1))) + Returns: + LocallyConnectedDense base.Layer. + """ + if n_modules == 1: + return tl.Dense( + n_units, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + use_bias=use_bias, + ) + return tl.Serial( + tl.SplitLastAxis(n_modules), + tl.LocallyConnected1d( + n_units, + kernel_size, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + use_bias=use_bias, + padding="WRAP", + ), + tl.MergeLastTwoAxes(), + ) -@assert_shape('...ab->...c') -def MergeLastTwoAxes(): - return tl.Fn('MergeLastTwoAxes', - lambda x: jnp.reshape(x, tuple(x.shape)[:-2] + (-1,))) - - -@assert_shape('...a->...b') -def LocallyConnectedDense(n_modules, n_units, kernel_size=1, - kernel_initializer=init.GlorotUniformInitializer(), - bias_initializer=init.RandomNormalInitializer(1e-6), - use_bias=True): - """Layer using LocallyConnected1d for approximation of Dense layer. - - The layer splits the last axis of a tensor into `n_modules`, then runs - LocallyConnected1d (grouped convolution) on all those modules, and - concatenates their results. It is essentially a locally-sensitive - approximation of Dense layer, with number of parameters smaller by the factor - of `n_modules / kernel_size`. - - Args: - n_modules: Indicates how many modules (pixels) should be input and output - split into for processing. - n_units: how many outputs (filters) should each module generate. - kernel_size: The size of the kernel to be used. - kernel_initializer: Function that creates a matrix of (random) initial - connection weights `W` for the layer. - bias_initializer: Function that creates a vector of (random) initial - bias weights `b` for the layer. - use_bias: If `True`, compute an affine map `y = Wx + b`; else compute - a linear map `y = Wx`. - - Returns: - LocallyConnectedDense base.Layer. - """ - if n_modules == 1: - return tl.Dense(n_units, kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, use_bias=use_bias) - return tl.Serial( - tl.SplitLastAxis(n_modules), - tl.LocallyConnected1d( - n_units, kernel_size, kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, use_bias=use_bias, padding='WRAP'), - tl.MergeLastTwoAxes()) - - -@assert_shape('bld->bld') -def ModularCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0, - max_inference_length=2048, - kernel_size=1, mode='train'): - """Returns a layer that maps activations to activations, with causal masking. - - Like `CausalAttention`, this layer type represents one pass of multi-head - self-attention with causal masking rather than padding-based masking. However, - it uses LocallyConnectedDense instead of Dense layer for computing Q/K/V. - - Args: - d_feature: Depth/dimensionality of feature embedding. - n_heads: Number of attention heads. - sparsity: Number of modules used in LocallyConnectedDense. - dropout: Probababilistic rate for internal dropout applied to attention - activations (based on query-key pairs) before dotting them with values. - max_inference_length: maximum length for inference. 
-@assert_shape('bld->bld')
-def ModularCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0,
-                           max_inference_length=2048,
-                           kernel_size=1, mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  self-attention with causal masking rather than padding-based masking. However,
-  it uses LocallyConnectedDense instead of Dense layer for computing Q/K/V.
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    sparsity: Number of modules used in LocallyConnectedDense.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    max_inference_length: maximum length for inference.
-    kernel_size: Kernel size used in LocallyConnectedDense.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  n_modules = n_heads if sparsity is None else sparsity
-  @assert_shape('...a->...b')
-  def ProcessingLayer():
-    assert d_feature % n_modules == 0
-    return LocallyConnectedDense(n_modules, d_feature // n_modules,
-                                 kernel_size=kernel_size)
-
-  return tl.ConfigurableAttention(
-      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
-      ProcessingLayer(), n_heads=n_heads,
-      qkv_attention_layer=tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode))
+@assert_shape("bld->bld")
+def ModularCausalAttention(
+    d_feature,
+    n_heads=1,
+    sparsity=None,
+    dropout=0.0,
+    max_inference_length=2048,
+    kernel_size=1,
+    mode="train",
+):
+    """Returns a layer that maps activations to activations, with causal masking.
+
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    self-attention with causal masking rather than padding-based masking. However,
+    it uses LocallyConnectedDense instead of Dense layer for computing Q/K/V.
+
+    Args:
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        sparsity: Number of modules used in LocallyConnectedDense.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        max_inference_length: maximum length for inference.
+        kernel_size: Kernel size used in LocallyConnectedDense.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
+    """
+    n_modules = n_heads if sparsity is None else sparsity
+
+    @assert_shape("...a->...b")
+    def ProcessingLayer():
+        assert d_feature % n_modules == 0
+        return LocallyConnectedDense(
+            n_modules, d_feature // n_modules, kernel_size=kernel_size
+        )
+
+    return tl.ConfigurableAttention(
+        ProcessingLayer(),
+        ProcessingLayer(),
+        ProcessingLayer(),
+        ProcessingLayer(),
+        n_heads=n_heads,
+        qkv_attention_layer=tl.DotProductCausalAttention(
+            dropout=dropout, max_inference_length=max_inference_length, mode=mode
+        ),
+    )
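
A usage sketch for the reformatted layer, assuming this file keeps its import path trax.layers.research.sparsity and trax's usual init-then-call API:

import numpy as np
from trax import shapes
from trax.layers.research import sparsity

layer = sparsity.ModularCausalAttention(d_feature=64, n_heads=4, sparsity=4)
x = np.zeros((2, 10, 64), dtype=np.float32)    # (batch, length, d_feature)
layer.init(shapes.signature(x))
y = layer(x)                                   # same shape: (2, 10, 64)
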
 class _RememberPad(base.Layer):
-  """Layer which remembers last N elements in predict mode."""
+    """Layer which remembers last N elements in predict mode."""
+
+    def __init__(self, n_items_to_remember, mode):
+        """Returns a layer which remembers last N elements in predict mode.
+
+        For predict mode, the layer remembers last N elements and pads with them.
+        For other modes, it pads with zeros. The layer pads/remembers elements from
+        the second axis.
+
+        Args:
+            n_items_to_remember: Number of items to remember/pad with.
+            mode: One of `'train'`, `'eval'`, or `'predict'`.
+        """
+        super().__init__(name="_RememberPad")
+        self._n_items_to_remember = n_items_to_remember
+        self._mode = mode
+        self._portal_mask = (
+            self.monkey_patched_mask()
+        )  # pylint: disable=assignment-from-none
+
+    def monkey_patched_mask(self):
+        # This is necessary for Terraformer model. See comments there.
+        # The mask will only be used in Terraformer in predict mode.
+        return None
+
+    def forward(self, x):
+        if self._n_items_to_remember == 0:
+            return x
+        if self._mode == "predict":
+            x = jnp.concatenate([self.state[0], x], axis=1)
+            if self._portal_mask is not None and "init" in self.state[1]:
+                # TODO(jaszczur): In predict mode with monkey-patched mask, we
+                # currently assume that batch size is 1.
+                assert x.shape[0] == 1
+                mask = self._portal_mask.get_value()
+                count_padding = jnp.sum(mask == 0, dtype=jnp.int32)
+                self.state = (
+                    fastmath.dynamic_slice_in_dim(
+                        x,
+                        x.shape[1] - (self._n_items_to_remember + count_padding),
+                        self._n_items_to_remember,
+                        axis=1,
+                    ),
+                    {"forward": ()},
+                )
+            else:
+                self.state = (x[:, -self._n_items_to_remember :, ...], {"forward": ()})
+        else:
+            pad_widths = [[0, 0] for _ in range(len(x.shape))]
+            pad_widths[1][0] = self._n_items_to_remember
+            x = jnp.pad(x, pad_width=pad_widths, mode="constant")
+        return x
+
+    def init_weights_and_state(self, input_signature):
+        """Initializes this layer's weights."""
+        if isinstance(input_signature, (list, tuple)):
+            input_signature = input_signature[0]
+        self.weights = ()
+        if self._mode == "predict":
+            shape = list(input_signature.shape)
+            shape[1] = self._n_items_to_remember
+            self.state = (jnp.zeros(shape, dtype=jnp.float32), {"init": ()})
+        else:
+            self.state = ()
+
+
+@assert_shape("...a->...b")
+def LocallyConvDense(n_modules, n_units, mode, kernel_size=1, length_kernel_size=1):
+    """Layer using local convolutions for approximation of Dense layer.
+
+    The layer splits the last axis of a tensor into `n_modules`, then runs
+    a convolution on all those modules, and concatenates their results.
+    It is similar to LocallyConnectedDense above, but shares weights.
+
+    Args:
+        n_modules: Indicates how many modules (pixels) the input and output
+            should be split into for processing.
+        n_units: how many outputs (filters) should each module generate.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
+        kernel_size: The size of the kernel to be used.
+        length_kernel_size: If > 1, also do causal convolution on the previous axis,
+            which is often the sentence length in sequence models.
+
+    Returns:
+        LocallyConvDense base.Layer.
+    """
+    if n_modules == 1:
+        return tl.Dense(n_units)
+    if kernel_size % 2 != 1:
+        raise ValueError("Currently we only handle odd kernel sizes.")
+    half = (kernel_size - 1) // 2
+    pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]]
+    return tl.Serial(
+        tl.SplitLastAxis(n_modules),
+        tl.Fn("Pad", lambda x: jnp.pad(x, pad_width=pad_widths, mode="constant")),
+        _RememberPad(length_kernel_size - 1, mode=mode),
+        tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
+        tl.MergeLastTwoAxes(),
+    )
+
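
The padding arithmetic in LocallyConvDense keeps shapes stable; a sketch of the intended shape flow, assuming inputs of shape (batch, length, d_feature):

# SplitLastAxis(n_modules):  (b, l, d) -> (b, l, n_modules, d // n_modules)
# Pad (odd kernel_size):     pads the module axis by (kernel_size - 1) // 2 on
#                            each side, i.e. a 'SAME'-style convolution window.
# _RememberPad:              pads the length axis at the front by
#                            length_kernel_size - 1 (zeros in train/eval,
#                            remembered activations in predict mode).
# Conv((length_kernel_size, kernel_size)): consumes exactly that padding,
#                            giving (b, l, n_modules, n_units).
# MergeLastTwoAxes():        (b, l, n_modules, n_units) -> (b, l, n_modules * n_units)
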
-  def __init__(self, n_items_to_remember, mode):
-    """Returns a layer which remembers last N elements in predict mode.
+@assert_shape("bld->bld")
+def ConvCausalAttention(
+    d_feature,
+    n_heads=1,
+    sparsity=None,
+    dropout=0.0,
+    max_inference_length=2048,
+    kernel_size=1,
+    mode="train",
+):
+    """Returns a layer that maps activations to activations, with causal masking.
-    For predict mode, the layer remembers last N elements and pads with them.
-    For other modes, it pads with zeros. The layer pads/remembers elements from
-    the second axis.
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    self-attention with causal masking rather than padding-based masking. However,
+    it uses LocallyConvDense instead of Dense layer for computing Q/K/V.
     Args:
-      n_items_to_remember: Number of items to remember/pad with.
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        sparsity: Number of modules used in LocallyConvDense.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        max_inference_length: maximum length for inference.
+        kernel_size: Kernel size used in LocallyConvDense.
       mode: One of `'train'`, `'eval'`, or `'predict'`.
     """
-    super().__init__(name='_RememberPad')
-    self._n_items_to_remember = n_items_to_remember
-    self._mode = mode
-    self._portal_mask = self.monkey_patched_mask()  # pylint: disable=assignment-from-none
-
-  def monkey_patched_mask(self):
-    # This is necessary for Terraformer model. See comments there.
-    # The mask will only be used in Terraformer in predict mode.
-    return None
-
-  def forward(self, x):
-    if self._n_items_to_remember == 0:
-      return x
-    if self._mode == 'predict':
-      x = jnp.concatenate([self.state[0], x], axis=1)
-      if self._portal_mask is not None and 'init' in self.state[1]:
-        # TODO(jaszczur): In predict mode with monkey-patched mask, we
-        # currently assume that batch size is 1.
-        assert x.shape[0] == 1
-        mask = self._portal_mask.get_value()
-        count_padding = jnp.sum(mask == 0, dtype=jnp.int32)
-        self.state = (fastmath.dynamic_slice_in_dim(
-            x, x.shape[1] - (self._n_items_to_remember + count_padding),
-            self._n_items_to_remember, axis=1), {'forward': ()})
-      else:
-        self.state = (x[:, -self._n_items_to_remember:, ...], {'forward': ()})
-    else:
-      pad_widths = [[0, 0] for _ in range(len(x.shape))]
-      pad_widths[1][0] = self._n_items_to_remember
-      x = jnp.pad(x, pad_width=pad_widths, mode='constant')
-    return x
-
-  def init_weights_and_state(self, input_signature):
-    """Initializes this layer's weights."""
-    if isinstance(input_signature, (list, tuple)):
-      input_signature = input_signature[0]
-    self.weights = ()
-    if self._mode == 'predict':
-      shape = list(input_signature.shape)
-      shape[1] = self._n_items_to_remember
-      self.state = (jnp.zeros(shape, dtype=jnp.float32), {'init': ()})
-    else:
-      self.state = ()
-
-
-@assert_shape('...a->...b')
-def LocallyConvDense(n_modules, n_units, mode, kernel_size=1,
-                     length_kernel_size=1):
-  """Layer using local convolutions for approximation of Dense layer.
-
-  The layer splits the last axis of a tensor into `n_modules`, then runs
-  a convolution on all those modules, and concatenates their results.
-  It is similar to LocallyConnectedDense above, but shares weights.
-
-  Args:
-    n_modules: Indicates how many modules (pixels) should be input and output
-      split into for processing.
-    n_units: how many outputs (filters) should each module generate.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-    kernel_size: The size of the kernel to be used.
-    length_kernel_size: If > 1, also do causal convolution on the previous axis,
-      which is often the sentence length in sequence models.
-
-  Returns:
-    LocallyConvDense base.Layer.
-  """
-  if n_modules == 1:
-    return tl.Dense(n_units)
-  if kernel_size % 2 != 1:
-    raise ValueError('Currently we only handle odd kernel sizes.')
-  half = (kernel_size - 1) // 2
-  pad_widths = [[0, 0], [0, 0], [half, half], [0, 0]]
-  return tl.Serial(
-      tl.SplitLastAxis(n_modules),
-      tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths, mode='constant')),
-      _RememberPad(length_kernel_size-1, mode=mode),
-      tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
-      tl.MergeLastTwoAxes()
-  )
-
-
-@assert_shape('bld->bld')
-def ConvCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0,
-                        max_inference_length=2048,
-                        kernel_size=1, mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  self-attention with causal masking rather than padding-based masking. However,
-  it uses LocallyConvDense instead of Dense layer for computing Q/K/V.
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    sparsity: Number of modules used in LocallyConvDense.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    max_inference_length: maximum length for inference.
-    kernel_size: Kernel size used in LocallyConnectedDense.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  n_modules = n_heads if sparsity is None else sparsity
-  @assert_shape('...a->...b')
-  def ProcessingLayer():
-    assert d_feature % n_modules == 0
-    return LocallyConvDense(n_modules, d_feature // n_modules, mode=mode,
-                            kernel_size=kernel_size)
-
-  return tl.ConfigurableAttention(
-      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
-      ProcessingLayer(), n_heads=n_heads,
-      qkv_attention_layer=tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode))
-
-
-@assert_shape('...a->...b')
+    n_modules = n_heads if sparsity is None else sparsity
+
+    @assert_shape("...a->...b")
+    def ProcessingLayer():
+        assert d_feature % n_modules == 0
+        return LocallyConvDense(
+            n_modules, d_feature // n_modules, mode=mode, kernel_size=kernel_size
+        )
+
+    return tl.ConfigurableAttention(
+        ProcessingLayer(),
+        ProcessingLayer(),
+        ProcessingLayer(),
+        ProcessingLayer(),
+        n_heads=n_heads,
+        qkv_attention_layer=tl.DotProductCausalAttention(
+            dropout=dropout, max_inference_length=max_inference_length, mode=mode
+        ),
+    )
+
+
+@assert_shape("...a->...b")
 def LowRankDense(n_units, d_lowrank):
-  return tl.Serial(
-      tl.Dense(d_lowrank),
-      tl.Dense(n_units)
-  )
+    return tl.Serial(tl.Dense(d_lowrank), tl.Dense(n_units))

-@assert_shape('...a->...b')
+@assert_shape("...a->...b")
 def EinsumDense(d_input, d_output, use_bias):
-  """Returns a reimplementation of Dense layer, using einsum.
-
-  While this is an equivalent of a Dense layer, it seems to be faster when used
-  in decoding if used with bias (see decoding_timing_test.py ).
-  This layer can be removed when we understand better the reason for the
-  difference in decoding speed.
-
-  Args:
-    d_input: Dimensionality of the input tensor.
-    d_output: Dimensionality of the output tensor.
-    use_bias: Whether to use bias.
-  """
-  layers = [
-      tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
-      tl.Fn('EinsumDense',
-            (lambda kernel, embeds:  # pylint: disable=g-long-lambda
-             jnp.einsum('xd,...d->...x', kernel, embeds)))
-  ]
-  if use_bias:
-    layers.extend([
-        tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
-        tl.Add()
-    ])
-  return tl.Serial(layers)
+    """Returns a reimplementation of Dense layer, using einsum.
+
+    While this is an equivalent of a Dense layer, it seems to be faster when used
+    in decoding if used with bias (see decoding_timing_test.py).
+    This layer can be removed when we understand better the reason for the
+    difference in decoding speed.
+
+    Args:
+        d_input: Dimensionality of the input tensor.
+        d_output: Dimensionality of the output tensor.
+        use_bias: Whether to use bias.
+    """
+    layers = [
+        tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
+        tl.Fn(
+            "EinsumDense",
+            (
+                lambda kernel, embeds: jnp.einsum(  # pylint: disable=g-long-lambda
+                    "xd,...d->...x", kernel, embeds
+                )
+            ),
+        ),
+    ]
+    if use_bias:
+        layers.extend(
+            [tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]), tl.Add()]
+        )
+    return tl.Serial(layers)
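
The einsum in EinsumDense is just a transposed-kernel matrix multiply; a small numpy check of the equivalence (illustrative shapes, numpy standing in for jnp):

import numpy as np

d_input, d_output = 8, 16
kernel = np.random.rand(d_output, d_input)     # [d_output, d_input], as above
x = np.random.rand(2, 3, d_input)
y = np.einsum('xd,...d->...x', kernel, x)      # EinsumDense without bias
assert np.allclose(y, x @ kernel.T)            # same map as a Dense kernel
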

 def RandomLayer(layer_a, layer_b, prob_a):
-  """Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`."""
-  condition = tl.Serial(
-      tl.RandomUniform(),
-      tl.Fn('SmallerThan', lambda x: x < prob_a)
-  )
-  return tl.Cond(condition, layer_a, layer_b)
-
-
-@assert_shape('...a->...b')
-def SparseDenseWithOptions(n_units, d_input=None, sparsity_type=None,
-                           sparsity=0, d_lowrank=None, prob_sparse=None,
-                           mode=None, use_bias=True, use_bfloat16=False):
-  """Configurable sparse version of Dense layer."""
-  if prob_sparse is not None:
-    if mode is not None and mode != 'train':
-      # For non-training modes, we want to use a sparse variant.
-      # This is different than simply prob_sparse being None, as the weights of
-      # the model are different.
-      prob_sparse = 1.0
-    return RandomLayer(
-        SparseDenseWithOptions(n_units, d_input, sparsity_type, sparsity,
-                               d_lowrank, use_bias=use_bias,
-                               use_bfloat16=use_bfloat16),
-        tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16),
-        prob_sparse)
-
-  if sparsity_type is None or sparsity_type == 'None' or sparsity == 0:
-    return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16)
-  if sparsity_type == 'mult':
-    return FactoredDense(sparsity, d_input, n_units, use_bias=use_bias,
-                         use_bfloat16=use_bfloat16)
-
-  assert not use_bfloat16  # use_bfloat16 is unsupported for other variants
-  if sparsity_type == 'lowrank':
-    assert use_bias  # use_bias=False is unsupported
-    return LowRankDense(n_units, d_lowrank)
-  if sparsity_type == 'einsum':
-    return EinsumDense(d_input, n_units, use_bias=use_bias)
-  if sparsity_type == 'local':
-    assert use_bias  # use_bias = False is unsupported
-    assert n_units % sparsity == 0
-    return LocallyConnectedDense(sparsity, n_units/sparsity)
-  if sparsity_type == 'local3':
-    assert use_bias  # use_bias = False is unsupported
-    assert n_units % sparsity == 0
-    return LocallyConnectedDense(sparsity, n_units/sparsity, kernel_size=3)
-
-  raise ValueError('Unknown sparsity type: {}'.format(sparsity_type))
-
-
-@assert_shape('bld->bld')
-def LowRankCausalAttention(d_feature, n_heads=1, dropout=0.0,
-                           max_inference_length=2048, lowrank=64,
-                           mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  self-attention with causal masking rather than padding-based masking. However,
-  it uses low-rank approximation of kernel in Dense layer for computing Q/K/V.
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    max_inference_length: maximum length for inference.
-    lowrank: The rank of low-rank approximation.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-
-  return tl.ConfigurableAttention(
-      LowRankDense(d_feature, lowrank), LowRankDense(d_feature, lowrank),
-      LowRankDense(d_feature, lowrank), LowRankDense(d_feature, lowrank),
-      n_heads=n_heads, qkv_attention_layer=tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode))
-
-
-@assert_shape('...a->...b')
+    """Runs `layer_a` with probability `prob_a`, otherwise runs `layer_b`."""
+    condition = tl.Serial(
+        tl.RandomUniform(), tl.Fn("SmallerThan", lambda x: x < prob_a)
+    )
+    return tl.Cond(condition, layer_a, layer_b)
+
+
+@assert_shape("...a->...b")
+def SparseDenseWithOptions(
+    n_units,
+    d_input=None,
+    sparsity_type=None,
+    sparsity=0,
+    d_lowrank=None,
+    prob_sparse=None,
+    mode=None,
+    use_bias=True,
+    use_bfloat16=False,
+):
+    """Configurable sparse version of Dense layer."""
+    if prob_sparse is not None:
+        if mode is not None and mode != "train":
+            # For non-training modes, we want to use a sparse variant.
+            # This is different than simply prob_sparse being None, as the weights of
+            # the model are different.
+            prob_sparse = 1.0
+        return RandomLayer(
+            SparseDenseWithOptions(
+                n_units,
+                d_input,
+                sparsity_type,
+                sparsity,
+                d_lowrank,
+                use_bias=use_bias,
+                use_bfloat16=use_bfloat16,
+            ),
+            tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16),
+            prob_sparse,
+        )
+
+    if sparsity_type is None or sparsity_type == "None" or sparsity == 0:
+        return tl.Dense(n_units, use_bias=use_bias, use_bfloat16=use_bfloat16)
+    if sparsity_type == "mult":
+        return FactoredDense(
+            sparsity, d_input, n_units, use_bias=use_bias, use_bfloat16=use_bfloat16
+        )
+
+    assert not use_bfloat16  # use_bfloat16 is unsupported for other variants
+    if sparsity_type == "lowrank":
+        assert use_bias  # use_bias=False is unsupported
+        return LowRankDense(n_units, d_lowrank)
+    if sparsity_type == "einsum":
+        return EinsumDense(d_input, n_units, use_bias=use_bias)
+    if sparsity_type == "local":
+        assert use_bias  # use_bias = False is unsupported
+        assert n_units % sparsity == 0
+        return LocallyConnectedDense(sparsity, n_units // sparsity)
+    if sparsity_type == "local3":
+        assert use_bias  # use_bias = False is unsupported
+        assert n_units % sparsity == 0
+        return LocallyConnectedDense(sparsity, n_units // sparsity, kernel_size=3)
+
+    raise ValueError("Unknown sparsity type: {}".format(sparsity_type))
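
For readers scanning the dispatch above, the mapping from sparsity_type to the replacement layer, plus one hypothetical call using only names defined in this file:

# sparsity_type None / 'None' / sparsity == 0 -> tl.Dense(n_units)
# 'mult'    -> FactoredDense(sparsity, d_input, n_units)
# 'lowrank' -> LowRankDense(n_units, d_lowrank)
# 'einsum'  -> EinsumDense(d_input, n_units)
# 'local', 'local3' -> LocallyConnectedDense(sparsity, n_units // sparsity, ...)
layer = SparseDenseWithOptions(512, d_input=512, sparsity_type='mult', sparsity=16)
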
+
+
+@assert_shape("bld->bld")
+def LowRankCausalAttention(
+    d_feature,
+    n_heads=1,
+    dropout=0.0,
+    max_inference_length=2048,
+    lowrank=64,
+    mode="train",
+):
+    """Returns a layer that maps activations to activations, with causal masking.
+
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    self-attention with causal masking rather than padding-based masking. However,
+    it uses low-rank approximation of kernel in Dense layer for computing Q/K/V.
+
+    Args:
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        max_inference_length: maximum length for inference.
+        lowrank: The rank of low-rank approximation.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
+    """
+
+    return tl.ConfigurableAttention(
+        LowRankDense(d_feature, lowrank),
+        LowRankDense(d_feature, lowrank),
+        LowRankDense(d_feature, lowrank),
+        LowRankDense(d_feature, lowrank),
+        n_heads=n_heads,
+        qkv_attention_layer=tl.DotProductCausalAttention(
+            dropout=dropout, max_inference_length=max_inference_length, mode=mode
+        ),
+    )
+
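
The low-rank factorization trades one d x d kernel for two skinny ones; illustrative arithmetic for the defaults used here:

d, r = 512, 64                 # d_feature and `lowrank`
full_rank = d * d              # single Dense kernel: 262144 weights
low_rank = d * r + r * d       # Dense(d_lowrank) then Dense(n_units): 65536
assert low_rank < full_rank    # 4x fewer weights per Q/K/V projection
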
+
+@assert_shape("...a->...b")
 def FactoredDense(n_modules, d_in, d_out, use_bias=True, use_bfloat16=False):
-  r"""Returns a Dense-like layer, internally factored to use fewer parameters.
-
-  This layer treats an activation vector as if divided into :math:`M`
-  subvectors (``n_modules`` 'modules'). It uses this factored view to compute
-  a :py:class:`Dense`-like mapping with high mixing/connectivity, but using
-  approximately :math:`1/M` the number of weights of a similarly dimensioned
-  :py:class:`Dense` layer.
-
-  More specifically, each activation vector of dimensionality ``n_in`` is
-  multiplied element-wise (a generalized form of gating) with ``n_modules``
-  vectors also of dimensionality ``n_in``. The resulting vectors are projected
-  to the subvector/module dimensionality ``d_out / n_modules`` via a matrix
-  multiply, and finally reshaped back to a single vector of dimensionality
-  ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at
-  the end. All the above-mentioned non-input objects -- gating vectors,
-  projection matrix, and optional bias -- are trainable weights.
-
-  Args:
-    n_modules: Number by which an activation vector is divided into subvectors
-      (modules) for the factored computation.
-    d_in: Last/innermost dimension of input array.
-    d_out: Last/innermost dimension of output array.
-    use_bias: If True, add bias vectors at the end of the layer; else end the
-      layer with the matrix multiply.
-    use_bfloat16: If True, use bfloat16 weights; else use float32 weights.
-  """
-  if d_out % n_modules != 0:
-    raise ValueError(f'Value d_out ({d_out}) must be a multiple of arg '
-                     f'n_modules ({n_modules}).')
-  d_module = d_out // n_modules
-
-  def GatingVectors():
-    return tl.Weights(init.RandomNormalInitializer(stddev=0.5),
-                      shape=[n_modules, d_in],
-                      use_bfloat16=use_bfloat16)
-
-  def ProjectionMatrix():
-    return tl.Weights(init.GlorotUniformInitializer(),
-                      shape=[d_in, d_module],
-                      use_bfloat16=use_bfloat16),
-
-  def Bias():
-    return tl.Weights(init.RandomNormalInitializer(1e-6),
-                      shape=[d_out],
-                      use_bfloat16=use_bfloat16),
-
-  layers = [
-      GatingVectors(),
-      ProjectionMatrix(),
-      _GateAndProject(),
-      MergeLastTwoAxes(),
-  ]
-  if use_bias:
-    layers += [Bias(), tl.Add()]
-
-  return tl.Serial(layers)
+    r"""Returns a Dense-like layer, internally factored to use fewer parameters.
+
+    This layer treats an activation vector as if divided into :math:`M`
+    subvectors (``n_modules`` 'modules'). It uses this factored view to compute
+    a :py:class:`Dense`-like mapping with high mixing/connectivity, but using
+    approximately :math:`1/M` the number of weights of a similarly dimensioned
+    :py:class:`Dense` layer.
+
+    More specifically, each activation vector of dimensionality ``n_in`` is
+    multiplied element-wise (a generalized form of gating) with ``n_modules``
+    vectors also of dimensionality ``n_in``. The resulting vectors are projected
+    to the subvector/module dimensionality ``d_out / n_modules`` via a matrix
+    multiply, and finally reshaped back to a single vector of dimensionality
+    ``d_out``. Optionally, a bias vector of dimensionality ``d_out`` is added at
+    the end. All the above-mentioned non-input objects -- gating vectors,
+    projection matrix, and optional bias -- are trainable weights.
+
+    Args:
+        n_modules: Number by which an activation vector is divided into subvectors
+            (modules) for the factored computation.
+        d_in: Last/innermost dimension of input array.
+        d_out: Last/innermost dimension of output array.
+        use_bias: If True, add bias vectors at the end of the layer; else end the
+            layer with the matrix multiply.
+        use_bfloat16: If True, use bfloat16 weights; else use float32 weights.
+    """
+    if d_out % n_modules != 0:
+        raise ValueError(
+            f"Value d_out ({d_out}) must be a multiple of arg "
+            f"n_modules ({n_modules})."
+        )
+    d_module = d_out // n_modules
+
+    def GatingVectors():
+        return tl.Weights(
+            init.RandomNormalInitializer(stddev=0.5),
+            shape=[n_modules, d_in],
+            use_bfloat16=use_bfloat16,
+        )
+
+    def ProjectionMatrix():
+        return (
+            tl.Weights(
+                init.GlorotUniformInitializer(),
+                shape=[d_in, d_module],
+                use_bfloat16=use_bfloat16,
+            ),
+        )
+
+    def Bias():
+        return (
+            tl.Weights(
+                init.RandomNormalInitializer(1e-6),
+                shape=[d_out],
+                use_bfloat16=use_bfloat16,
+            ),
+        )
+
+    layers = [
+        GatingVectors(),
+        ProjectionMatrix(),
+        _GateAndProject(),
+        MergeLastTwoAxes(),
+    ]
+    if use_bias:
+        layers += [Bias(), tl.Add()]
+
+    return tl.Serial(layers)

 def _GateAndProject():
-  """Returns a combined gating+projection layer that saves on memory."""
+    """Returns a combined gating+projection layer that saves on memory."""
-  def f(projection, gating, x):
-    # Args arrive in reverse order because of how they were put on the stack.
-    # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules)
-    return jnp.einsum('...d,nd,dm->...nm', x, gating, projection)
+    def f(projection, gating, x):
+        # Args arrive in reverse order because of how they were put on the stack.
+        # Einsum indices: d (d_in), n (n_modules), m (d_module = d_out/n_modules)
+        return jnp.einsum("...d,nd,dm->...nm", x, gating, projection)
-  return tl.Fn('_GateAndProject', f)
+    return tl.Fn("_GateAndProject", f)
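
A numpy sketch of the combined einsum in _GateAndProject, checking it against gating first and projecting afterwards (shapes are illustrative):

import numpy as np

batch, d_in, n_modules, d_module = 2, 8, 4, 3
x = np.random.rand(batch, d_in)
gating = np.random.rand(n_modules, d_in)
projection = np.random.rand(d_in, d_module)
y = np.einsum('...d,nd,dm->...nm', x, gating, projection)
ref = np.stack([(x * g) @ projection for g in gating], axis=1)
assert y.shape == (batch, n_modules, d_module) and np.allclose(y, ref)
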
-@assert_shape('...a->...a')
+@assert_shape("...a->...a")
 def MultiplicativeModularSparseDense(sparsity, d_feature):
-  """Returns a replacement of Dense layer which uses less parameters.
-
-  The layer uses number of modules equal to `sparsity`. It is a combination of
-  multiplicative dense and locally connected dense layers.
-
-  Args:
-    sparsity: The sparsity of the layer; the output vector is divided into this
-      number of modules.
-    d_feature: Dimensionality of input and output tensor.
-  """
-
-  assert d_feature % sparsity == 0
-  d_module = d_feature // sparsity
-
-  return tl.Serial(
-      # Weight below is used for per-head preprocessing of an embedding.
-      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
-                 shape=[sparsity, d_feature]),
-      # Weight below is a kernel of multiplicative dense, shared across heads.
-      tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
-      # Weight below is a kernel of modular dense.
-      tl.Weights(functools.partial(init.GlorotUniformInitializer(),
-                                   nonreceptive_dims=[0]),
-                 [sparsity, d_module, d_module]),
-      # To save memory the per-head preprocessing and multiplying by
-      # kernels is done in a single einsum.
-      tl.Fn('SparseDenseEinsum',
-            (lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
-             jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier, embeds
-                        ))),
-      MergeLastTwoAxes(),
-      # Weight below is bias after dense, per-head.
-      tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
-      tl.Add(),
-  )
-
-
-@assert_shape('bld->bld')
-def MultiplicativeCausalAttention(d_feature, n_heads=1, sparsity=None,
-                                  dropout=0.0, max_inference_length=2048,
-                                  mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  self-attention with causal masking rather than padding-based masking. However,
-  for computing Q/K/V instead of a Dense layer it multiplies each embedding
-  dimension by a scalar specific to each dimension and each head; then it
-  produces Q/K/V by applying the same dense layer to each head. In comparison
-  to standard dense layer for computing Q/K/V, this layer uses less parameters
-  while still being able to express many functions, like a permutation.
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    max_inference_length: maximum length for inference.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  sparsity = n_heads if sparsity is None else sparsity
-  return tl.ConfigurableAttention(
-      FactoredDense(sparsity, d_feature, d_feature),
-      FactoredDense(sparsity, d_feature, d_feature),
-      FactoredDense(sparsity, d_feature, d_feature),
-      FactoredDense(sparsity, d_feature, d_feature),
-      n_heads=n_heads, qkv_attention_layer=tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode))
-
-
-@assert_shape('bld->bld')
+    """Returns a replacement of Dense layer which uses fewer parameters.
+
+    The layer uses number of modules equal to `sparsity`. It is a combination of
+    multiplicative dense and locally connected dense layers.
+
+    Args:
+        sparsity: The sparsity of the layer; the output vector is divided into this
+            number of modules.
+        d_feature: Dimensionality of input and output tensor.
+    """
+
+    assert d_feature % sparsity == 0
+    d_module = d_feature // sparsity
+
+    return tl.Serial(
+        # Weight below is used for per-head preprocessing of an embedding.
+        tl.Weights(
+            init.RandomNormalInitializer(stddev=0.5), shape=[sparsity, d_feature]
+        ),
+        # Weight below is a kernel of multiplicative dense, shared across heads.
+        tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
+        # Weight below is a kernel of modular dense.
+        tl.Weights(
+            functools.partial(init.GlorotUniformInitializer(), nonreceptive_dims=[0]),
+            [sparsity, d_module, d_module],
+        ),
+        # To save memory the per-head preprocessing and multiplying by
+        # kernels is done in a single einsum.
+        tl.Fn(
+            "SparseDenseEinsum",
+            (
+                lambda kmod, kmult, multiplier, embeds: jnp.einsum(  # pylint: disable=g-long-lambda
+                    "hxo,dx,hd,...d->...ho", kmod, kmult, multiplier, embeds
+                )
+            ),
+        ),
+        MergeLastTwoAxes(),
+        # Weight below is bias after dense, per-head.
+        tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
+        tl.Add(),
+    )
+
+
+@assert_shape("bld->bld")
+def MultiplicativeCausalAttention(
+    d_feature,
+    n_heads=1,
+    sparsity=None,
+    dropout=0.0,
+    max_inference_length=2048,
+    mode="train",
+):
+    """Returns a layer that maps activations to activations, with causal masking.
+
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    self-attention with causal masking rather than padding-based masking. However,
+    for computing Q/K/V instead of a Dense layer it multiplies each embedding
+    dimension by a scalar specific to each dimension and each head; then it
+    produces Q/K/V by applying the same dense layer to each head. In comparison
+    to a standard dense layer for computing Q/K/V, this layer uses fewer
+    parameters while still being able to express many functions, like a
+    permutation.
+
+    Args:
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        sparsity: The sparsity of the layer; usually it should be equal to n_heads.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        max_inference_length: maximum length for inference.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
+    """
+    sparsity = n_heads if sparsity is None else sparsity
+    return tl.ConfigurableAttention(
+        FactoredDense(sparsity, d_feature, d_feature),
+        FactoredDense(sparsity, d_feature, d_feature),
+        FactoredDense(sparsity, d_feature, d_feature),
+        FactoredDense(sparsity, d_feature, d_feature),
+        n_heads=n_heads,
+        qkv_attention_layer=tl.DotProductCausalAttention(
+            dropout=dropout, max_inference_length=max_inference_length, mode=mode
+        ),
+    )
+
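
Illustrative weight counts behind the docstring's 'fewer parameters' claim, per Q/K/V projection (FactoredDense as defined above, bias included):

d, M = 512, 8                        # d_feature and sparsity (n_modules)
dense_qkv = d * d                    # tl.Dense kernel: 262144 weights
factored = M * d + d * (d // M) + d  # gating + shared projection + bias
assert factored == 4096 + 32768 + 512  # ~7x fewer than dense_qkv
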
+
+@assert_shape("bld->bld")
 def MultiplicativeModularCausalAttention(
-    d_feature, n_heads=1, sparsity=None, dropout=0.0, max_inference_length=2048,
-    mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  self-attention with causal masking rather than padding-based masking. However,
-  for computing Q/K/V instead of a Dense layer it combines
-  FactoredDense layer with LocallyConnectedLayer.
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    max_inference_length: maximum length for inference.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  sparsity = n_heads if sparsity is None else sparsity
-  return tl.ConfigurableAttention(
-      MultiplicativeModularSparseDense(sparsity, d_feature),
-      MultiplicativeModularSparseDense(sparsity, d_feature),
-      MultiplicativeModularSparseDense(sparsity, d_feature),
-      MultiplicativeModularSparseDense(sparsity, d_feature), n_heads=n_heads,
-      qkv_attention_layer=tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode))
+    d_feature,
+    n_heads=1,
+    sparsity=None,
+    dropout=0.0,
+    max_inference_length=2048,
+    mode="train",
+):
+    """Returns a layer that maps activations to activations, with causal masking.
+
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    self-attention with causal masking rather than padding-based masking. However,
+    for computing Q/K/V instead of a Dense layer it combines
+    FactoredDense layer with LocallyConnectedLayer.
+
+    Args:
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        sparsity: The sparsity of the layer; usually it should be equal to n_heads.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        max_inference_length: maximum length for inference.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
+    """
+    sparsity = n_heads if sparsity is None else sparsity
+    return tl.ConfigurableAttention(
+        MultiplicativeModularSparseDense(sparsity, d_feature),
+        MultiplicativeModularSparseDense(sparsity, d_feature),
+        MultiplicativeModularSparseDense(sparsity, d_feature),
+        MultiplicativeModularSparseDense(sparsity, d_feature),
+        n_heads=n_heads,
+        qkv_attention_layer=tl.DotProductCausalAttention(
+            dropout=dropout, max_inference_length=max_inference_length, mode=mode
+        ),
+    )
+
+
+@assert_shape("bld->bld")
 def MultiplicativeConvCausalAttention(
-    d_feature, n_heads=1, sparsity=None, length_kernel_size=3, dropout=0.0,
-    force_no_dropout=False, max_inference_length=2048, share_qk=False,
-    output_layer_type='none', v_concat_type='none', mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  self-attention with causal masking rather than padding-based masking. However,
-  for computing Q/K/V instead of a Dense layer it combines
-  FactoredDense layer with LocallyConvLayer.
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
-    length_kernel_size: Size of convolution kernel on the length dimension.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    force_no_dropout: If True, force dropout to be 0.0 independent of the above
-      value; used to override some configurations.
-    max_inference_length: maximum length for inference.
-    share_qk: if True, average Q and K embeddings and share for both Q and K.
-    output_layer_type: Which sparse layers to use for processing output from the
-      attention mechanism. One of `'none'`, `'mult'`, `'conv'`,
-      or `'multconv'`.
-    v_concat_type: What kind of concatenation to use when computing V tensor.
-      One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just
-      output from mutliplicative layer shared by Q, K, V. `'fixed'` means
-      using output from multiplicative layer concatenated, for each module,
-      with the layer input. `'original'` means using concatenation without
-      properly taking modules into account; this method was used in
-      experiments previously, so it is included for backwards-compatibility.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  assert output_layer_type in ['none', 'mult', 'conv', 'multconv']
-  assert v_concat_type in ['original', 'fixed', 'none']
-
-  dropout = 0.0 if force_no_dropout else dropout
-  sparsity = n_heads if sparsity is None else sparsity
-  d_module = d_feature // sparsity
-
-  output_layers = []
-  if 'mult' in output_layer_type:
-    output_layers.append(FactoredDense(
-        sparsity, d_feature, d_feature))
-  if 'conv' in output_layer_type:
-    output_layers.append(LocallyConvDense(
-        sparsity, d_module, mode=mode, kernel_size=3,
-        length_kernel_size=length_kernel_size))
-
-  if v_concat_type == 'original':
-    # 'original'` uses concatenation without properly taking modules into
-    # account; this method was used in experiments previously, so it is included
-    # for backwards-compatibility.
-    concat_layers = [tl.Concatenate()]  # use permuted and original for v
-  elif v_concat_type == 'fixed':
-    # `'fixed'` uses the output from multiplicative layer concatenated, for each
-    # module, with the layer input. This means that every module in Conv layer
-    # has access both to parts of embeddings which were used to compute Q/K of
-    # this particular module, and it ha access to parts of the embedding which
-    # will be modified by this module.
-    concat_layers = [
-        tl.Parallel(
-            tl.Fn('Reshape1', lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
-                x, (x.shape[0], x.shape[1], sparsity, d_module))),
-            tl.Fn('Reshape2', lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
-                x, (x.shape[0], x.shape[1], sparsity, d_module)))),
-        tl.Concatenate(),
-        tl.Fn('Reshape3',
-              lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2*d_feature))),
-    ]
-  elif v_concat_type == 'none':
-    # `'none'` doesn't use concatenation: we throw away the original layer
-    # input and pass to Conv only output of shared Multiplicative layer.
-    concat_layers = [tl.Select([0], n_in=2)]
+    d_feature,
+    n_heads=1,
+    sparsity=None,
+    length_kernel_size=3,
+    dropout=0.0,
+    force_no_dropout=False,
+    max_inference_length=2048,
+    share_qk=False,
+    output_layer_type="none",
+    v_concat_type="none",
+    mode="train",
+):
+    """Returns a layer that maps activations to activations, with causal masking.
+
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    self-attention with causal masking rather than padding-based masking. However,
+    for computing Q/K/V instead of a Dense layer it combines
+    FactoredDense layer with LocallyConvLayer.
-  if share_qk:
+    Args:
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        sparsity: The sparsity of the layer; usually it should be equal to n_heads.
+        length_kernel_size: Size of convolution kernel on the length dimension.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        force_no_dropout: If True, force dropout to be 0.0 independent of the above
+            value; used to override some configurations.
+        max_inference_length: maximum length for inference.
+        share_qk: if True, average Q and K embeddings and share for both Q and K.
+        output_layer_type: Which sparse layers to use for processing output from the
+            attention mechanism. One of `'none'`, `'mult'`, `'conv'`,
+            or `'multconv'`.
+        v_concat_type: What kind of concatenation to use when computing V tensor.
+            One of `'original'`, `'fixed'`, or `'none'`. `'none'` means using just
+            output from multiplicative layer shared by Q, K, V. `'fixed'` means
+            using output from multiplicative layer concatenated, for each module,
+            with the layer input. `'original'` means using concatenation without
+            properly taking modules into account; this method was used in
+            experiments previously, so it is included for backwards-compatibility.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
+    """
+    assert output_layer_type in ["none", "mult", "conv", "multconv"]
+    assert v_concat_type in ["original", "fixed", "none"]
+
+    dropout = 0.0 if force_no_dropout else dropout
+    sparsity = n_heads if sparsity is None else sparsity
+    d_module = d_feature // sparsity
+
+    output_layers = []
+    if "mult" in output_layer_type:
+        output_layers.append(FactoredDense(sparsity, d_feature, d_feature))
+    if "conv" in output_layer_type:
+        output_layers.append(
+            LocallyConvDense(
+                sparsity,
+                d_module,
+                mode=mode,
+                kernel_size=3,
+                length_kernel_size=length_kernel_size,
+            )
+        )
+
+    if v_concat_type == "original":
+        # `'original'` uses concatenation without properly taking modules into
+        # account; this method was used in experiments previously, so it is included
+        # for backwards-compatibility.
+        concat_layers = [tl.Concatenate()]  # use permuted and original for v
+    elif v_concat_type == "fixed":
+        # `'fixed'` uses the output from multiplicative layer concatenated, for each
+        # module, with the layer input. This means that every module in Conv layer
+        # has access both to parts of embeddings which were used to compute Q/K of
+        # this particular module, and it has access to parts of the embedding which
+        # will be modified by this module.
+        concat_layers = [
+            tl.Parallel(
+                tl.Fn(
+                    "Reshape1",
+                    lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
+                        x, (x.shape[0], x.shape[1], sparsity, d_module)
+                    ),
+                ),
+                tl.Fn(
+                    "Reshape2",
+                    lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
+                        x, (x.shape[0], x.shape[1], sparsity, d_module)
+                    ),
+                ),
+            ),
+            tl.Concatenate(),
+            tl.Fn(
+                "Reshape3",
+                lambda x: jnp.reshape(x, (x.shape[0], x.shape[1], 2 * d_feature)),
+            ),
+        ]
+    elif v_concat_type == "none":
+        # `'none'` doesn't use concatenation: we throw away the original layer
+        # input and pass to Conv only output of shared Multiplicative layer.
+        concat_layers = [tl.Select([0], n_in=2)]
+
+    if share_qk:
+        return tl.Serial(
+            tl.Select([0, 0]),  # pre-qkv, pre-v-for-concat
+            FactoredDense(sparsity, d_feature, d_feature),  # shared q k
+            tl.Select([0, 0]),  # pre-qk, pre-v, pre-v-for-concat
+            LocallyConvDense(
+                sparsity,
+                d_module,
+                mode=mode,
+                kernel_size=3,
+                length_kernel_size=length_kernel_size,
+            ),
+            tl.SplitIntoHeads(n_heads),
+            tl.Select([0, 0]),  # use for q and k
+            tl.Parallel(
+                [],
+                [],
+                [
+                    concat_layers,
+                    LocallyConvDense(
+                        sparsity,
+                        d_module,
+                        mode=mode,
+                        kernel_size=1,
+                        length_kernel_size=length_kernel_size,
+                    ),
+                    tl.SplitIntoHeads(n_heads),
+                ],
+            ),
+            tl.DotProductCausalAttention(
+                dropout=dropout, max_inference_length=max_inference_length, mode=mode
+            ),
+            tl.MergeHeads(n_heads),
+            output_layers,
+        )
     return tl.Serial(
-      tl.Select([0, 0]),  # pre-qkv, pre-v-for-concat
-      FactoredDense(sparsity, d_feature, d_feature),  # shared q k
-      tl.Select([0, 0]),  # pre-qk, pre-v, pre-v-for-concat
-      LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,
-                       length_kernel_size=length_kernel_size),
-      tl.SplitIntoHeads(n_heads),
-      tl.Select([0, 0]),  # use for q and k
+        tl.Select([0, 0]),  # duplicate activations
+        FactoredDense(sparsity, d_feature, d_feature),  # shared q, k
+        tl.Select([0, 0, 0]),  # use for q, k, v
         tl.Parallel(
-          [],
-          [],
-          [concat_layers,
-           LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,
-                            length_kernel_size=length_kernel_size),
-           tl.SplitIntoHeads(n_heads)],
+            [
+                LocallyConvDense(
+                    sparsity,
+                    d_module,
+                    mode=mode,
+                    kernel_size=3,
+                    length_kernel_size=length_kernel_size,
+                ),
+                tl.SplitIntoHeads(n_heads),
+            ],
+            [
+                LocallyConvDense(
+                    sparsity,
+                    d_module,
+                    mode=mode,
+                    kernel_size=3,
+                    length_kernel_size=length_kernel_size,
+                ),
+                tl.SplitIntoHeads(n_heads),
+            ],
+            [
+                concat_layers,
+                LocallyConvDense(
+                    sparsity,
+                    d_module,
+                    mode=mode,
+                    kernel_size=1,
+                    length_kernel_size=length_kernel_size,
+                ),
+                tl.SplitIntoHeads(n_heads),
+            ],
         ),
         tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode),
+            dropout=dropout, max_inference_length=max_inference_length, mode=mode
+        ),
         tl.MergeHeads(n_heads),
         output_layers,
     )
-  return tl.Serial(
-      tl.Select([0, 0]),  # duplicate activations
-      FactoredDense(sparsity, d_feature, d_feature),  # shared q, k
-      tl.Select([0, 0, 0]),  # use for q, k, v
-      tl.Parallel(
-          [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,
-                            length_kernel_size=length_kernel_size),
-           tl.SplitIntoHeads(n_heads)],
-          [LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=3,
-                            length_kernel_size=length_kernel_size),
-           tl.SplitIntoHeads(n_heads)],
-          [concat_layers,
-           LocallyConvDense(sparsity, d_module, mode=mode, kernel_size=1,
-                            length_kernel_size=length_kernel_size),
-           tl.SplitIntoHeads(n_heads)],
-      ),
-      tl.DotProductCausalAttention(
-          dropout=dropout, max_inference_length=max_inference_length,
-          mode=mode),
-      tl.MergeHeads(n_heads),
-      output_layers,
-  )
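
A usage sketch for the variant above, assuming the module import path trax.layers.research.sparsity and trax's init-then-call API (defaults for output_layer_type and v_concat_type):

import numpy as np
from trax import shapes
from trax.layers.research import sparsity

layer = sparsity.MultiplicativeConvCausalAttention(
    d_feature=64, n_heads=4, sparsity=4, mode='train')
x = np.zeros((2, 10, 64), dtype=np.float32)
layer.init(shapes.signature(x))
y = layer(x)   # (2, 10, 64): Q/K from the shared multiplicative projection,
               # V from the kernel-size-1 LocallyConvDense branch
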

 class FavorAttention(base.Layer):
-  """Implements FAVOR+ attention.
-
-  Original paper: https://arxiv.org/abs/2006.03555
-  The layer expects 4 inputs: (Q, K, V, MASK), and returns two outputs:
-  (RENORMALIZED_ATTENTION, MASK).
-
-  Attributes:
-
-    d_feature: Dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    n_random_features: Free dimension size for the orthogonal random matrix.
-    numerical_stabilizer: float, small number used for numerical stability.
-    use_approximate_softmax: Bool, if True uses approximate softmax, otherwise
-      Relu.
-    scale_by_norm: Boolean; whether to scale orthogonal random matrix.
-    normalize_data: predicate indicating whether data should be normalized.
-    epsilon: numerical stabilizer.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-
-  def __init__(self, d_feature=4, n_heads=1, n_random_features=256,
-               numerical_stabilizer=0.001,
-               use_approximate_softmax=False, scale_by_norm=True,
-               normalize_data=False,
-               epsilon=0.0001, mode='train'):
-    super().__init__(n_in=4, n_out=2)
-    self._d_feature = d_feature
-    self._n_heads = n_heads
-    self._n_random_features = n_random_features
-    self._numerical_stabilizer = numerical_stabilizer
-    self._mode = mode
-    self._use_approximate_softmax = use_approximate_softmax
-    self._normalize_data = normalize_data
-    self._epsilon = epsilon
-    if self._use_approximate_softmax:
-      rng = random.get_prng(0)
-      self._projection_matrix = self.get_2d_array(
-          rng=rng, n_rows=self._n_random_features,
-          n_columns=(self._d_feature // self._n_heads),
-          scale_by_norm=scale_by_norm,
-          normalize_data=normalize_data, epsilon=epsilon)
-    else:
-      self._projection_matrix = None
+    """Implements FAVOR+ attention.
-  def nonnegative_softmax_kernel_feature_creator(self, x, is_query):
-    """Constructs nonnegative kernel features for fast softmax attention.
+    Original paper: https://arxiv.org/abs/2006.03555
+    The layer expects 4 inputs: (Q, K, V, MASK), and returns two outputs:
+    (RENORMALIZED_ATTENTION, MASK).
-    Args:
-      x: input for which features are computed.
-      is_query: predicate indicating whether input data corresponds to
-        queries or keys.
+    Attributes:
-    Returns:
-      Random features for fast softmax attention.
+        d_feature: Dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        n_random_features: Free dimension size for the orthogonal random matrix.
+        numerical_stabilizer: float, small number used for numerical stability.
+        use_approximate_softmax: Bool, if True uses approximate softmax, otherwise
+            Relu.
+        scale_by_norm: Boolean; whether to scale orthogonal random matrix.
+        normalize_data: predicate indicating whether data should be normalized.
+        epsilon: numerical stabilizer.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
     """
-    if self._normalize_data:
-      # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
-      # w_norm = w * data_normalizer for w in {q,k}.
-      data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(x.shape[-1])))
-    else:
-      data_normalizer = 1.0
-    ratio = 1.0 / jnp.sqrt(self._projection_matrix.shape[0])
-    # TODO(wgaj): Double-check... Should there be only one batch dimension...?
-    data_mod_shape = x.shape[0:1] + self._projection_matrix.shape
-    data_thick_random_matrix = (jnp.zeros(data_mod_shape) +
-                                self._projection_matrix)
-
-    data_dash = jnp.einsum('Bij, Bkj -> Bik',
-                           data_normalizer * x,
-                           data_thick_random_matrix)
-    diag_data = jnp.square(x)
-    diag_data = jnp.sum(diag_data, axis=x.ndim - 1)
-    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
-    diag_data = jnp.expand_dims(diag_data, axis=x.ndim - 1)
-
-    last_dims_t = (len(data_dash.shape) - 1,)
-    attention_dims_t = (1,)
-    if is_query:
-      data_dash = ratio * (
-          jnp.exp(data_dash - diag_data -
-                  jnp.max(data_dash, axis=last_dims_t, keepdims=True)) +
-          self._epsilon)
-    else:
-      data_dash = ratio * (
-          jnp.exp(data_dash - diag_data - jnp.max(
-              data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
-          self._epsilon)
-    return data_dash
-
-  @staticmethod
-  def get_2d_array(rng, n_rows=256, n_columns=0, scale_by_norm=True,
-                   normalize_data=False, epsilon=0.0001):
-    """Generator for approximate softmax orthogonal kernel feature matrix.
+    def __init__(
+        self,
+        d_feature=4,
+        n_heads=1,
+        n_random_features=256,
+        numerical_stabilizer=0.001,
+        use_approximate_softmax=False,
+        scale_by_norm=True,
+        normalize_data=False,
+        epsilon=0.0001,
+        mode="train",
+    ):
+        super().__init__(n_in=4, n_out=2)
+        self._d_feature = d_feature
+        self._n_heads = n_heads
+        self._n_random_features = n_random_features
+        self._numerical_stabilizer = numerical_stabilizer
+        self._mode = mode
+        self._use_approximate_softmax = use_approximate_softmax
+        self._normalize_data = normalize_data
+        self._epsilon = epsilon
+        if self._use_approximate_softmax:
+            rng = random.get_prng(0)
+            self._projection_matrix = self.get_2d_array(
+                rng=rng,
+                n_rows=self._n_random_features,
+                n_columns=(self._d_feature // self._n_heads),
+                scale_by_norm=scale_by_norm,
+                normalize_data=normalize_data,
+                epsilon=epsilon,
+            )
+        else:
+            self._projection_matrix = None
+
+    def nonnegative_softmax_kernel_feature_creator(self, x, is_query):
+        """Constructs nonnegative kernel features for fast softmax attention.
+
+        Args:
+            x: input for which features are computed.
+            is_query: predicate indicating whether input data corresponds to
+                queries or keys.
+
+        Returns:
+            Random features for fast softmax attention.
+        """
+        if self._normalize_data:
+            # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
+            # w_norm = w * data_normalizer for w in {q,k}.
+            data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(x.shape[-1])))
+        else:
+            data_normalizer = 1.0
+        ratio = 1.0 / jnp.sqrt(self._projection_matrix.shape[0])
+        # TODO(wgaj): Double-check... Should there be only one batch dimension...?
+        data_mod_shape = x.shape[0:1] + self._projection_matrix.shape
+        data_thick_random_matrix = jnp.zeros(data_mod_shape) + self._projection_matrix
+
+        data_dash = jnp.einsum(
+            "Bij, Bkj -> Bik", data_normalizer * x, data_thick_random_matrix
+        )
+        diag_data = jnp.square(x)
+        diag_data = jnp.sum(diag_data, axis=x.ndim - 1)
+        diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
+        diag_data = jnp.expand_dims(diag_data, axis=x.ndim - 1)
+
+        last_dims_t = (len(data_dash.shape) - 1,)
+        attention_dims_t = (1,)
+        if is_query:
+            data_dash = ratio * (
+                jnp.exp(
+                    data_dash
+                    - diag_data
+                    - jnp.max(data_dash, axis=last_dims_t, keepdims=True)
+                )
+                + self._epsilon
+            )
+        else:
+            data_dash = ratio * (
+                jnp.exp(
+                    data_dash
+                    - diag_data
+                    - jnp.max(
+                        data_dash, axis=last_dims_t + attention_dims_t, keepdims=True
+                    )
+                )
+                + self._epsilon
+            )
+
+        return data_dash
+
+    @staticmethod
+    def get_2d_array(
+        rng,
+        n_rows=256,
+        n_columns=0,
+        scale_by_norm=True,
+        normalize_data=False,
+        epsilon=0.0001,
+    ):
+        """Generator for approximate softmax orthogonal kernel feature matrix.
+
+        Args:
+            rng: Random number generator.
+            n_rows: Number of rows.
+            n_columns: Number of columns.
+            scale_by_norm: Boolean; whether to scale orthogonal random matrix.
+            normalize_data: predicate indicating whether data should be normalized.
+            epsilon: numerical stabilizer.
+
+        Returns:
+            Orthogonal kernel feature matrix.
+        """
+        n_full_blocks = int(n_rows / n_columns)
+        block_list = []
+        rng_key = rng
+        for _ in range(n_full_blocks):
+            rng, rng_input = random.split(rng)
+            unstructured_block = random.normal(rng_input, (n_columns, n_columns))
+            q, _ = jnp.linalg.qr(unstructured_block)
+            q = jnp.transpose(q)
+            block_list.append(q)
+        remaining_rows = n_rows - n_full_blocks * n_columns
+        if remaining_rows > 0:
+            rng, rng_input = random.split(rng)
+            unstructured_block = random.normal(rng_input, (n_columns, n_columns))
+            q, _ = jnp.linalg.qr(unstructured_block)
+            q = jnp.transpose(q)
+            block_list.append(q[0:remaining_rows])
+        final_matrix = jnp.vstack(block_list)
+
+        if scale_by_norm:
+            multiplier = jnp.linalg.norm(
+                random.normal(rng_key, (n_rows, n_columns)), axis=1
+            )
+        else:
+            multiplier = jnp.sqrt(float(n_columns)) * jnp.ones((n_rows))
+
+        return jnp.matmul(jnp.diag(multiplier), final_matrix)
+
+    @staticmethod
+    def bidirectional_numerator(query_prime, key_prime, value):
+        kvs = jnp.einsum("lbm,lbd->bmd", key_prime, value)
+        return jnp.einsum("lbm,bmd->lbd", query_prime, kvs)
+
+    @staticmethod
+    def bidirectional_denominator(query_prime, key_prime):
+        all_ones = jnp.ones([query_prime.shape[0]])
+        ks_sum = jnp.einsum("lbm,l->bm", key_prime, all_ones)
+        return jnp.einsum("lbm,bm->lb", query_prime, ks_sum)
+
+    @staticmethod
+    def relu(x):
+        return jnp.where(x <= 0, jnp.zeros_like(x), x)
+
+    def forward(self, inputs):
+        query, key, value, mask = inputs
+        if self._use_approximate_softmax:
+            query_prime = self.nonnegative_softmax_kernel_feature_creator(query, True)
+            key_prime = self.nonnegative_softmax_kernel_feature_creator(key, False)
+        else:
+            query_prime = self.relu(query) + self._numerical_stabilizer
+            key_prime = self.relu(key) + self._numerical_stabilizer
+        mask_batch_1_length = jnp.reshape(
+            mask, [key.shape[0] // self._n_heads, 1, key.shape[1]]
+        ).astype(jnp.float32)
+        mask_heads = mask_batch_1_length + jnp.zeros((1, self._n_heads, 1))
+        key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])
+
+        w = self.bidirectional_numerator(
+            jnp.moveaxis(query_prime, 1, 0),
+            jnp.moveaxis(key_prime, 1, 0),
+            jnp.moveaxis(value, 1, 0),
+        )
+        r = self.bidirectional_denominator(
+            jnp.moveaxis(query_prime, 1, 0), jnp.moveaxis(key_prime, 1, 0)
+        )
+        w = jnp.moveaxis(w, 0, 1)
+        r = jnp.moveaxis(r, 0, 1)
+        r = jnp.reciprocal(r)
+        r = jnp.expand_dims(r, len(r.shape))
+        renormalized_attention = w * r
+        return renormalized_attention, mask
+
+
+def Favor(
+    d_feature,
+    n_heads=1,
+    n_random_features=256,
+    dropout=0.0,
+    numerical_stabilizer=0.001,
+    use_approximate_softmax=False,
+    scale_by_norm=0,
+    normalize_data=False,
+    epsilon=0.0001,
+    mode="train",
+):
+    """Returns a layer that maps (activations, mask) to (new_activations, mask).
+
+    See the FAVOR paper for details: https://arxiv.org/abs/2006.03555
     Args:
-      rng: Random number generator.
-      n_rows: Number of rows.
-      n_columns: Number of columns.
+        d_feature: Depth/dimensionality of feature embedding.
+        n_heads: Number of attention heads.
+        n_random_features: Free dimension size for the orthogonal random matrix.
+        dropout: Probabilistic rate for internal dropout applied to attention
+            activations (based on query-key pairs) before dotting them with values.
+        numerical_stabilizer: float, small number used for numerical stability.
+        use_approximate_softmax: Bool, if True uses approximate softmax, otherwise
+            Relu.
       scale_by_norm: Boolean; whether to scale orthogonal random matrix.
       normalize_data: predicate indicating whether data should be normalized.
       epsilon: numerical stabilizer.
-
-    Returns:
-      Orthogonal kernel feature matrix.
+        mode: One of `'train'`, `'eval'`, or `'predict'`.
     """
-    n_full_blocks = int(n_rows / n_columns)
-    block_list = []
-    rng_key = rng
-    for _ in range(n_full_blocks):
-      rng, rng_input = random.split(rng)
-      unstructured_block = random.normal(rng_input, (n_columns, n_columns))
-      q, _ = jnp.linalg.qr(unstructured_block)
-      q = jnp.transpose(q)
-      block_list.append(q)
-    remaining_rows = n_rows - n_full_blocks * n_columns
-    if remaining_rows > 0:
-      rng, rng_input = random.split(rng)
-      unstructured_block = random.normal(rng_input, (n_columns, n_columns))
-      q, _ = jnp.linalg.qr(unstructured_block)
-      q = jnp.transpose(q)
-      block_list.append(q[0:remaining_rows])
-    final_matrix = jnp.vstack(block_list)
-
-    if scale_by_norm:
-      multiplier = jnp.linalg.norm(
-          random.normal(rng_key, (n_rows, n_columns)), axis=1)
-    else:
-      multiplier = jnp.sqrt(float(n_columns)) * jnp.ones((n_rows))
-
-    return jnp.matmul(jnp.diag(multiplier), final_matrix)
-
-  @staticmethod
-  def bidirectional_numerator(query_prime, key_prime, value):
-    kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
-    return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)
-
-  @staticmethod
-  def bidirectional_denominator(query_prime, key_prime):
-    all_ones = jnp.ones([query_prime.shape[0]])
-    ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
-    return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)
-
-  @staticmethod
-  def relu(x):
-    return jnp.where(x <= 0, jnp.zeros_like(x), x)
-
-  def forward(self, inputs):
-    query, key, value, mask = inputs
-    if self._use_approximate_softmax:
-      query_prime = self.nonnegative_softmax_kernel_feature_creator(query, True)
-      key_prime = self.nonnegative_softmax_kernel_feature_creator(key, False)
-    else:
-      query_prime = self.relu(query) + self._numerical_stabilizer
-      key_prime = self.relu(key) + self._numerical_stabilizer
-    mask_batch_1_length = jnp.reshape(
-        mask, [key.shape[0] // self._n_heads, 1, key.shape[1]]).astype(
-            jnp.float32)
-    mask_heads = mask_batch_1_length + jnp.zeros((1, self._n_heads, 1))
-    key_prime *= jnp.reshape(mask_heads, [key.shape[0],
-
-    w = self.bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
-                                     jnp.moveaxis(key_prime, 1, 0),
-                                     jnp.moveaxis(value, 1, 0))
-    r = self.bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
-                                       jnp.moveaxis(key_prime, 1, 0))
-    w = jnp.moveaxis(w, 0, 1)
-    r = jnp.moveaxis(r, 0, 1)
-    r = jnp.reciprocal(r)
-    r = jnp.expand_dims(r, len(r.shape))
-    renormalized_attention = w * r
-    return renormalized_attention, mask
-
-
-def Favor(d_feature, n_heads=1, n_random_features=256, dropout=0.0,
-          numerical_stabilizer=0.001, use_approximate_softmax=False,
-          scale_by_norm=0, normalize_data=False, epsilon=0.0001, mode='train'):
-  """Returns a layer that maps (activations, mask) to (new_activations, mask).
-
-  See the FAVOR paper for details: https://arxiv.org/abs/2006.03555
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    n_random_features: Free dimension size for the orthogonal random matrix.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    numerical_stabilizer: float, small number used for numerical stability.
-    use_approximate_softmax: Bool, if True uses approximate softmax, otherwise
-      Relu.
-    scale_by_norm: Boolean; whether to scale orthogonal random matrix.
-    normalize_data: predicate indicating whether data should be normalized.
-    epsilon: numerical stabilizer.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  del dropout  # not implemented yet but needed in the API
-
-  return tl.ConfigurableAttention(
-      tl.Dense(d_feature), tl.Dense(d_feature), tl.Dense(d_feature),
-      tl.Dense(d_feature),
-      tl.FavorAttention(d_feature, n_heads, n_random_features,
-                        numerical_stabilizer, use_approximate_softmax,
-                        scale_by_norm, normalize_data, epsilon, mode),
-      n_heads=n_heads)
+    del dropout  # not implemented yet but needed in the API
+
+    return tl.ConfigurableAttention(
+        tl.Dense(d_feature),
+        tl.Dense(d_feature),
+        tl.Dense(d_feature),
+        tl.Dense(d_feature),
+        tl.FavorAttention(
+            d_feature,
+            n_heads,
+            n_random_features,
+            numerical_stabilizer,
+            use_approximate_softmax,
+            scale_by_norm,
+            normalize_data,
+            epsilon,
+            mode,
+        ),
+        n_heads=n_heads,
+    )
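The bidirectional numerator/denominator pair above is the heart of non-causal FAVOR: keys and values are first contracted into a single [batch, features, depth] summary, after which each query touches only that summary, so the cost is linear rather than quadratic in sequence length. A minimal NumPy sketch of the same computation, using relu features and the [length, batch, ...] layout implied by the moveaxis calls (an illustration under those assumptions, not the trax layer itself):

import numpy as np

L, B, M, D = 6, 2, 4, 8  # length, batch, feature dim, value depth
rng = np.random.default_rng(0)
q_prime = np.maximum(rng.normal(size=(L, B, M)), 0.0) + 1e-3  # relu features
k_prime = np.maximum(rng.normal(size=(L, B, M)), 0.0) + 1e-3
value = rng.normal(size=(L, B, D))

# Numerator: one pass over keys/values, then one pass over queries.
kvs = np.einsum('lbm,lbd->bmd', k_prime, value)
w = np.einsum('lbm,bmd->lbd', q_prime, kvs)
# Denominator: each query against the sum of all key features.
r = np.einsum('lbm,bm->lb', q_prime, k_prime.sum(axis=0))
renormalized = w / r[..., None]  # never materializes the L x L matrix
print(renormalized.shape)  # (6, 2, 8)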


 class CausalFavorAttention(base.Layer):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  causal attention, but using FAVOR fast attention as in the following paper:
-  https://arxiv.org/abs/2006.03555
-
-  Layer expects three inputs (Q, K, V), and returns one output
-  RENORMALIZED_ATTENTION.
-
-  Attributes:
-    numerical_stabilizer: float, small number used for numerical stability.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-
-  def __init__(self, numerical_stabilizer=0.001, mode='train'):
-    super().__init__(n_in=3, n_out=1)
-    self._numerical_stabilizer = numerical_stabilizer
-    self._mode = mode
-
-  def forward(self, inputs):
-    def favor_numerator_fwd(init_prefix_sum_value,
-                            query_prime, key_prime, value):
-      def body(p, qkv):
-        (q, k, v) = qkv
-        p += jnp.einsum('...m,...d->...md', k, v)
-        x_slice = jnp.einsum('...m,...md->...d', q, p)
-        return p, x_slice
-      p, w = fastmath.scan(body, init_prefix_sum_value,
-                           (query_prime, key_prime, value))
-      return w, (p, query_prime, key_prime, value)
-
-    def favor_numerator_bwd(pqkv, w_ct):
-      p, qs, ks, vs = pqkv
-
-      def body(carry, qkv_xct):
-        p, p_ct = carry
-        q, k, v, x_ct = qkv_xct
-        q_ct = jnp.einsum('...d,...md->...m', x_ct, p)
-        p_ct += jnp.einsum('...d,...m->...md', x_ct, q)
-        k_ct = jnp.einsum('...md,...d->...m', p_ct, v)
-        v_ct = jnp.einsum('...md,...m->...d', p_ct, k)
-        p -= jnp.einsum('...m,...d->...md', k, v)
-        return (p, p_ct), (q_ct, k_ct, v_ct)
-
-      _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
-          body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
-      return (None, qs_ct, ks_ct, vs_ct)
-
-    def favor_numerator(init_prefix_sum_value, query_prime,
-                        key_prime, value):
-      w, _ = favor_numerator_fwd(init_prefix_sum_value,
-                                 query_prime, key_prime, value)
-      return w
-
-    favor_numerator = fastmath.custom_vjp(
-        favor_numerator, favor_numerator_fwd, favor_numerator_bwd)
-
-    def favor_denominator_fwd(init_prefix_sum_value,
-                              query_prime, key_prime):
-      def body(p, qk):
-        q, k = qk
-        p += k
-        x = jnp.einsum('...m,...m->...', q, p)
-        return p, x
-
-      p, r = fastmath.scan(body, init_prefix_sum_value, (query_prime,
-                                                         key_prime))
-      return r, (query_prime, key_prime, p)
-
-    def favor_denominator_bwd(qkp, r_ct):
-      qs, ks, p = qkp
-
-      def body(carry, qkx):
-        p, p_ct = carry
-        q, k, x_ct = qkx
-        q_ct = jnp.einsum('...,...m->...m', x_ct, p)
-        p_ct += jnp.einsum('...,...m->...m', x_ct, q)
-        k_ct = p_ct
-        p -= k
-        return (p, p_ct), (q_ct, k_ct)
-
-      _, (qs_ct, ks_ct) = fastmath.scan(
-          body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True)
-      return (None, qs_ct, ks_ct)
-
-    def favor_denominator(init_prefix_sum_value, query_prime,
-                          key_prime):
-      r, _ = favor_denominator_fwd(init_prefix_sum_value,
-                                   query_prime, key_prime)
-      return r
-
-    favor_denominator = fastmath.custom_vjp(
-        favor_denominator, favor_denominator_fwd, favor_denominator_bwd)
-
-    favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)
-
-    def relu(x):
-      return jnp.where(x <= 0, jnp.zeros_like(x), x)
-
-    query, key, value = inputs
-    query_prime = relu(query) + self._numerical_stabilizer
-    key_prime = relu(key) + self._numerical_stabilizer
-    prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
-    t_slice_shape = (key.shape[0], key.shape[-1])
-    init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
-    init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)
-
-    w = favor_numerator(init_prefix_sum_value_numerator,
-                        jnp.moveaxis(query_prime, 1, 0),
-                        jnp.moveaxis(key_prime, 1, 0),
-                        jnp.moveaxis(value, 1, 0))
-    r = favor_denominator(init_prefix_sum_value_denominator,
-                          jnp.moveaxis(query_prime, 1, 0),
-                          jnp.moveaxis(key_prime, 1, 0))
-    w = jnp.moveaxis(w, 0, 1)
-    r = jnp.moveaxis(r, 0, 1)
-    r = jnp.reciprocal(r)
-    r = jnp.expand_dims(r, len(r.shape))
-    renormalized_attention = w * r
-    return renormalized_attention
-
-
-def CausalFavor(d_feature, n_heads=1, dropout=0.0,
-                numerical_stabilizer=0.001, mode='train'):
-  """Returns a layer that maps activations to activations, with causal masking.
-
-  Like `CausalAttention`, this layer type represents one pass of multi-head
-  causal attention, but using FAVOR fast attention as in the following paper:
-  https://arxiv.org/abs/2006.03555
-
-  Args:
-    d_feature: Depth/dimensionality of feature embedding.
-    n_heads: Number of attention heads.
-    dropout: Probababilistic rate for internal dropout applied to attention
-      activations (based on query-key pairs) before dotting them with values.
-    numerical_stabilizer: float, small number used for numerical stability.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-  """
-  del dropout
-  return tl.ConfigurableAttention(
-      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
-      core.Dense(d_feature), n_heads=n_heads,
-      qkv_attention_layer=tl.CausalFavorAttention(numerical_stabilizer,
-                                                  mode))
+    """Returns a layer that maps activations to activations, with causal masking.
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    causal attention, but using FAVOR fast attention as in the following paper:
+    https://arxiv.org/abs/2006.03555
 
-class _RememberInReverse(base.Layer):
-  """Layer remembering the input in forward pass. For reversible models."""
+    Layer expects three inputs (Q, K, V), and returns one output
+    RENORMALIZED_ATTENTION.
 
-  def __init__(self, output=True):
-    """Layer remembering the input in forward pass. For reversible models.
+    Attributes:
+      numerical_stabilizer: float, small number used for numerical stability.
+      mode: One of `'train'`, `'eval'`, or `'predict'`.
+    """
 
-    During the first pass through the model this layer saves the input as
-    state, and returns the input unmodified. During the second pass through the
-    model the layer outputs the input from the first pass. This is used to
-    combat numerical stability problems in Terraformer. It doesn't do anything
-    in non-reversible models.
+    def __init__(self, numerical_stabilizer=0.001, mode="train"):
+        super().__init__(n_in=3, n_out=1)
+        self._numerical_stabilizer = numerical_stabilizer
+        self._mode = mode
+
+    def forward(self, inputs):
+        def favor_numerator_fwd(init_prefix_sum_value, query_prime, key_prime, value):
+            def body(p, qkv):
+                (q, k, v) = qkv
+                p += jnp.einsum("...m,...d->...md", k, v)
+                x_slice = jnp.einsum("...m,...md->...d", q, p)
+                return p, x_slice
+
+            p, w = fastmath.scan(
+                body, init_prefix_sum_value, (query_prime, key_prime, value)
+            )
+            return w, (p, query_prime, key_prime, value)
+
+        def favor_numerator_bwd(pqkv, w_ct):
+            p, qs, ks, vs = pqkv
+
+            def body(carry, qkv_xct):
+                p, p_ct = carry
+                q, k, v, x_ct = qkv_xct
+                q_ct = jnp.einsum("...d,...md->...m", x_ct, p)
+                p_ct += jnp.einsum("...d,...m->...md", x_ct, q)
+                k_ct = jnp.einsum("...md,...d->...m", p_ct, v)
+                v_ct = jnp.einsum("...md,...m->...d", p_ct, k)
+                p -= jnp.einsum("...m,...d->...md", k, v)
+                return (p, p_ct), (q_ct, k_ct, v_ct)
+
+            _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
+                body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True
+            )
+            return (None, qs_ct, ks_ct, vs_ct)
+
+        def favor_numerator(init_prefix_sum_value, query_prime, key_prime, value):
+            w, _ = favor_numerator_fwd(
+                init_prefix_sum_value, query_prime, key_prime, value
+            )
+            return w
+
+        favor_numerator = fastmath.custom_vjp(
+            favor_numerator, favor_numerator_fwd, favor_numerator_bwd
+        )
+
+        def favor_denominator_fwd(init_prefix_sum_value, query_prime, key_prime):
+            def body(p, qk):
+                q, k = qk
+                p += k
+                x = jnp.einsum("...m,...m->...", q, p)
+                return p, x
+
+            p, r = fastmath.scan(body, init_prefix_sum_value, (query_prime, key_prime))
+            return r, (query_prime, key_prime, p)
+
+        def favor_denominator_bwd(qkp, r_ct):
+            qs, ks, p = qkp
+
+            def body(carry, qkx):
+                p, p_ct = carry
+                q, k, x_ct = qkx
+                q_ct = jnp.einsum("...,...m->...m", x_ct, p)
+                p_ct += jnp.einsum("...,...m->...m", x_ct, q)
+                k_ct = p_ct
+                p -= k
+                return (p, p_ct), (q_ct, k_ct)
+
+            _, (qs_ct, ks_ct) = fastmath.scan(
+                body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True
+            )
+            return (None, qs_ct, ks_ct)
+
+        def favor_denominator(init_prefix_sum_value, query_prime, key_prime):
+            r, _ = favor_denominator_fwd(init_prefix_sum_value, query_prime, key_prime)
+            return r
+
+        favor_denominator = fastmath.custom_vjp(
+            favor_denominator, favor_denominator_fwd, favor_denominator_bwd
+        )
+
+        favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)
+
+        def relu(x):
+            return jnp.where(x <= 0, jnp.zeros_like(x), x)
+
+        query, key, value = inputs
+        query_prime = relu(query) + self._numerical_stabilizer
+        key_prime = relu(key) + self._numerical_stabilizer
+        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
+        t_slice_shape = (key.shape[0], key.shape[-1])
+        init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
+        init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)
+
+        w = favor_numerator(
+            init_prefix_sum_value_numerator,
+            jnp.moveaxis(query_prime, 1, 0),
+            jnp.moveaxis(key_prime, 1, 0),
+            jnp.moveaxis(value, 1, 0),
+        )
+        r = favor_denominator(
+            init_prefix_sum_value_denominator,
+            jnp.moveaxis(query_prime, 1, 0),
+            jnp.moveaxis(key_prime, 1, 0),
+        )
+        w = jnp.moveaxis(w, 0, 1)
+        r = jnp.moveaxis(r, 0, 1)
+        r = jnp.reciprocal(r)
+        r = jnp.expand_dims(r, len(r.shape))
+        renormalized_attention = w * r
+        return renormalized_attention
+
+
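The scans above carry a running prefix sum p of key-value outer products, so position i attends only to positions j <= i, and the custom VJPs rebuild p step by step in reverse instead of storing it for every position. A loop-level sketch of just the forward recurrence, per single example, without batching or the custom gradient (plain NumPy, illustrative only):

import numpy as np

def causal_favor_forward(q_prime, k_prime, v):
    """q_prime, k_prime: [L, M]; v: [L, D]; returns [L, D]."""
    m, d = q_prime.shape[1], v.shape[1]
    p = np.zeros((m, d))   # prefix sum of key-value outer products
    s = np.zeros(m)        # prefix sum of key features
    out = np.zeros((v.shape[0], d))
    for i in range(v.shape[0]):
        p += np.outer(k_prime[i], v[i])
        s += k_prime[i]
        out[i] = (q_prime[i] @ p) / (q_prime[i] @ s)  # numerator / denominator
    return out

rng = np.random.default_rng(0)
q = np.maximum(rng.normal(size=(5, 3)), 0.0) + 1e-3
k = np.maximum(rng.normal(size=(5, 3)), 0.0) + 1e-3
v = rng.normal(size=(5, 2))
print(causal_favor_forward(q, k, v).shape)  # (5, 2)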
+def CausalFavor(
+    d_feature, n_heads=1, dropout=0.0, numerical_stabilizer=0.001, mode="train"
+):
+    """Returns a layer that maps activations to activations, with causal masking.
+
+    Like `CausalAttention`, this layer type represents one pass of multi-head
+    causal attention, but using FAVOR fast attention as in the following paper:
+    https://arxiv.org/abs/2006.03555
 
     Args:
-      output: Whether to pass the input or not.
+      d_feature: Depth/dimensionality of feature embedding.
+      n_heads: Number of attention heads.
+      dropout: Probabilistic rate for internal dropout applied to attention
+        activations (based on query-key pairs) before dotting them with values.
+      numerical_stabilizer: float, small number used for numerical stability.
+      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
-    n_out = 1 if output else 0
-    self._output = output
-    super().__init__(name='_RememberInReverse', n_out=n_out)
-
-  def forward(self, x):
-    if 'running_second_time_yes' in self.state[1]:
-      result = self.state[0]
-    else:
-      result = x
-      self.state = (x, {'running_second_time': ()})
+    del dropout
+    return tl.ConfigurableAttention(
+        core.Dense(d_feature),
+        core.Dense(d_feature),
+        core.Dense(d_feature),
+        core.Dense(d_feature),
+        n_heads=n_heads,
+        qkv_attention_layer=tl.CausalFavorAttention(numerical_stabilizer, mode),
+    )
 
-    if self._output:
-      return result
-    else:
-      return tuple()
 
-  def init_weights_and_state(self, input_signature):
-    """Initializes this layer's weights."""
-    if isinstance(input_signature, (list, tuple)):
-      input_signature = input_signature[0]
-    self.weights = ()
-    self.state = (jnp.zeros(input_signature.shape, dtype=jnp.int32),
-                  {'running_second_time': ()})
+class _RememberInReverse(base.Layer):
+    """Layer remembering the input in forward pass. For reversible models."""
+
+    def __init__(self, output=True):
+        """Layer remembering the input in forward pass. For reversible models.
+
+        During the first pass through the model this layer saves the input as
+        state, and returns the input unmodified. During the second pass through the
+        model the layer outputs the input from the first pass. This is used to
+        combat numerical stability problems in Terraformer. It doesn't do anything
+        in non-reversible models.
+
+        Args:
+          output: Whether to pass the input or not.
+        """
+        n_out = 1 if output else 0
+        self._output = output
+        super().__init__(name="_RememberInReverse", n_out=n_out)
+
+    def forward(self, x):
+        if "running_second_time_yes" in self.state[1]:
+            result = self.state[0]
+        else:
+            result = x
+            self.state = (x, {"running_second_time": ()})
+
+        if self._output:
+            return result
+        else:
+            return tuple()
+
+    def init_weights_and_state(self, input_signature):
+        """Initializes this layer's weights."""
+        if isinstance(input_signature, (list, tuple)):
+            input_signature = input_signature[0]
+        self.weights = ()
+        self.state = (
+            jnp.zeros(input_signature.shape, dtype=jnp.int32),
+            {"running_second_time": ()},
+        )
 
 
 class _RecallQuantMaskInReverse(base.Layer):
-  """Layer recalling quant mask from specific _RememberInReverse.
-
-  This layer is needed for memory-efficient training of reversible model with
-  ff chunking. During forward pass it simply returns minus ones, which are
-  ignored in the controller. During reverse_and_grad it returns a quant_mask
-  which was memorized (saved to state) by a RememberInReverse layer.
-
-  This enable us to save quant_mask right after chunking, and load it again
-  (when reversing) right before chunking.
-  """
-
-  def __init__(self, remember_layer, elements):
-    self._remember_layer = remember_layer
-    self._elements = elements
-    super().__init__(name='_RecallQuantMaskInReverse', n_in=1, n_out=2)
-
-  def forward(self, x):
-    if (self._remember_layer.state and
-        'running_second_time_yes' in self._remember_layer.state[1]):
-      # It's reverse_and_grad, so we pull the quant_mask from remembering layer.
-      result = self._remember_layer.state[0]
-    else:
-      result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32)
-    return (x, result)
+    """Layer recalling quant mask from specific _RememberInReverse.
+    This layer is needed for memory-efficient training of reversible model with
+    ff chunking. During forward pass it simply returns minus ones, which are
+    ignored in the controller. During reverse_and_grad it returns a quant_mask
+    which was memorized (saved to state) by a RememberInReverse layer.
 
-class _SparseFFController(base.Layer):
-  """The controller part of Sparse Feed-Forward layer."""
-
-  def __init__(self, d_ff, n_elements_in_block, d_lowrank, temperature,
-               use_bfloat16, mode, kernel_initializer, bias_initializer,
-               also_return_nondiscrete_output):
-    """Returns a sparse feed-forward block."""
-    n_out = 2 if also_return_nondiscrete_output else 1
-    super().__init__(name=f'_SparseFFController_{d_ff}', n_in=2, n_out=n_out)
-    self._use_bfloat16 = use_bfloat16
-    self._d_ff = d_ff
-    self._d_lowrank = d_lowrank
-    # Q: what temperature is actually most useful in training?
-    self._temperature = temperature if mode == 'train' else 0.0
-    self._mode = mode
-    self._n_elements_in_block = n_elements_in_block
-    self._kernel_initializer = kernel_initializer
-    self._bias_initializer = bias_initializer
-    # Helper numbers as d_ff will be divided by n_elements_in_block.
-    assert self._d_ff % self._n_elements_in_block == 0
-    self._d1 = self._d_ff // self._n_elements_in_block
-    self._d2 = self._n_elements_in_block
-    self._also_return_nondiscrete_output = also_return_nondiscrete_output
-
-  def forward(self, x):
-    """Executes this layer as part of a forward pass through the model.
-
-    Args:
-      x: Tensor of same shape and dtype as the input signature used to
-          initialize this layer.
-
-    Returns:
-      Tensor of same shape and dtype as the input.
+    This enables us to save quant_mask right after chunking, and load it again
+    (when reversing) right before chunking.
    """
-    x, recalled_quant_mask = x
-    m1, m2, mb = self.weights
-
-    x_shape = x.shape
-    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
-
-    # Q: should we add bias and/or put relu after the low-rank m1 dot?
-    # Replacing multiplication and reshape by this einsum brings training speed
-    # improvement (see also reshape in initialization).
-    mask_logits = jnp.einsum('bd,dl,lxy->bxy', x, m1, m2) + mb
-
-    if self._also_return_nondiscrete_output:
-      # Softmax.
-      mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
-      log_mask = mask_logits - mask_logsumexp
-      mask = jnp.exp(log_mask)
-      # Gumbel-softmax with straight-through discretization.
-      if self._temperature == 0.0:
-        quant_mask = jnp.argmax(log_mask, axis=-1)
-      else:
-        u = fastmath.random.uniform(self.rng, mask.shape, jnp.float32, 1e-6,
-                                    1.0 - 1e-6)
-        g = -jnp.log(-jnp.log(u))
-        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
-    else:
-      quant_mask = jnp.argmax(mask_logits, axis=-1)
-
-    if self._mode == 'train':
-      # We use recalled_quant_mask if it's different than -1; otherwise
-      # we use a quant_mask which we have just computed.
-      quant_mask = jnp.where(recalled_quant_mask == -1,
-                             quant_mask, recalled_quant_mask)
+
+    def __init__(self, remember_layer, elements):
+        self._remember_layer = remember_layer
+        self._elements = elements
+        super().__init__(name="_RecallQuantMaskInReverse", n_in=1, n_out=2)
 
-    if self._also_return_nondiscrete_output:
-      return quant_mask, mask
-    else:
-      return quant_mask
+    def forward(self, x):
+        if (
+            self._remember_layer.state
+            and "running_second_time_yes" in self._remember_layer.state[1]
+        ):
+            # It's reverse_and_grad, so we pull the quant_mask from remembering layer.
+            result = self._remember_layer.state[0]
+        else:
+            result = -jnp.ones((x.shape[0], self._elements), dtype=jnp.int32)
+        return (x, result)
 
-  def init_weights_and_state(self, input_signature):
-    """Randomly initializes this layer's weights."""
-    x_input_signature = input_signature[0]
-    d_model = x_input_signature.shape[-1]
-    shape_m1 = (d_model, self._d_lowrank)
-    shape_m2 = (self._d_lowrank, self._d_ff)
-    shape_mb = (self._d_ff,)
-
-    rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3)
-    m1 = self._kernel_initializer(shape_m1, rng_m1)
-    m2 = self._kernel_initializer(shape_m2, rng_m2)
-    mb = self._bias_initializer(shape_mb, rng_mb)
-    if self._use_bfloat16:
-      m1 = m1.astype(jnp.bfloat16)
-      m2 = m2.astype(jnp.bfloat16)
-      mb = mb.astype(jnp.bfloat16)
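The remember/recall pair is a two-pass protocol: on the forward pass the rememberer stores its input and downstream code sees placeholder -1s; on the reverse pass a 'running_second_time_yes' flag, written into the layer state by the reversible trainer, makes the stored value come back. A toy sketch of that contract with the trax machinery stripped away (here the flag is flipped by hand; in trax this happens inside reverse_and_grad):

class RememberToy:
    """Saves its input on pass one, replays it on pass two."""
    def __init__(self):
        self.state = (None, {})

    def __call__(self, x):
        if 'running_second_time_yes' in self.state[1]:
            return self.state[0]  # second pass: replay remembered value
        self.state = (x, {'running_second_time': ()})
        return x                  # first pass: store and pass through

layer = RememberToy()
print(layer(42))  # 42: first pass stores and forwards
layer.state = (layer.state[0], {'running_second_time_yes': ()})
print(layer(0))   # 42 again: the remembered value, not the new input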
+class _SparseFFController(base.Layer):
+    """The controller part of Sparse Feed-Forward layer."""
+
+    def __init__(
+        self,
+        d_ff,
+        n_elements_in_block,
+        d_lowrank,
+        temperature,
+        use_bfloat16,
+        mode,
+        kernel_initializer,
+        bias_initializer,
+        also_return_nondiscrete_output,
+    ):
+        """Returns a sparse feed-forward block."""
+        n_out = 2 if also_return_nondiscrete_output else 1
+        super().__init__(name=f"_SparseFFController_{d_ff}", n_in=2, n_out=n_out)
+        self._use_bfloat16 = use_bfloat16
+        self._d_ff = d_ff
+        self._d_lowrank = d_lowrank
+        # Q: what temperature is actually most useful in training?
+        self._temperature = temperature if mode == "train" else 0.0
+        self._mode = mode
+        self._n_elements_in_block = n_elements_in_block
+        self._kernel_initializer = kernel_initializer
+        self._bias_initializer = bias_initializer
+        # Helper numbers as d_ff will be divided by n_elements_in_block.
+        assert self._d_ff % self._n_elements_in_block == 0
+        self._d1 = self._d_ff // self._n_elements_in_block
+        self._d2 = self._n_elements_in_block
+        self._also_return_nondiscrete_output = also_return_nondiscrete_output
+
+    def forward(self, x):
+        """Executes this layer as part of a forward pass through the model.
+
+        Args:
+          x: Tensor of same shape and dtype as the input signature used to
+              initialize this layer.
+
+        Returns:
+          Tensor of same shape and dtype as the input.
+        """
+        x, recalled_quant_mask = x
+        m1, m2, mb = self.weights
+
+        x_shape = x.shape
+        x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+
+        # Q: should we add bias and/or put relu after the low-rank m1 dot?
+        # Replacing multiplication and reshape by this einsum brings training speed
+        # improvement (see also reshape in initialization).
+        mask_logits = jnp.einsum("bd,dl,lxy->bxy", x, m1, m2) + mb
+
+        if self._also_return_nondiscrete_output:
+            # Softmax.
+            mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
+            log_mask = mask_logits - mask_logsumexp
+            mask = jnp.exp(log_mask)
+            # Gumbel-softmax with straight-through discretization.
+            if self._temperature == 0.0:
+                quant_mask = jnp.argmax(log_mask, axis=-1)
+            else:
+                u = fastmath.random.uniform(
+                    self.rng, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6
+                )
+                g = -jnp.log(-jnp.log(u))
+                quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
+        else:
+            quant_mask = jnp.argmax(mask_logits, axis=-1)
+
+        if self._mode == "train":
+            # We use recalled_quant_mask if it's different than -1; otherwise
+            # we use a quant_mask which we have just computed.
+            quant_mask = jnp.where(
+                recalled_quant_mask == -1, quant_mask, recalled_quant_mask
+            )
+
+        if self._also_return_nondiscrete_output:
+            return quant_mask, mask
+        else:
+            return quant_mask
+
+    def init_weights_and_state(self, input_signature):
+        """Randomly initializes this layer's weights."""
+        x_input_signature = input_signature[0]
+        d_model = x_input_signature.shape[-1]
+        shape_m1 = (d_model, self._d_lowrank)
+        shape_m2 = (self._d_lowrank, self._d_ff)
+        shape_mb = (self._d_ff,)
+
+        rng_m1, rng_m2, rng_mb = fastmath.random.split(self.rng, 3)
+        m1 = self._kernel_initializer(shape_m1, rng_m1)
+        m2 = self._kernel_initializer(shape_m2, rng_m2)
+        mb = self._bias_initializer(shape_mb, rng_mb)
+        if self._use_bfloat16:
+            m1 = m1.astype(jnp.bfloat16)
+            m2 = m2.astype(jnp.bfloat16)
+            mb = mb.astype(jnp.bfloat16)
+
+        # Reshapes below, with einsum in feedforward, improve the training speed.
+        m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2])
+        mb = jnp.reshape(mb, [self._d1, self._d2])
+
+        self.weights = (m1, m2, mb)
-    # Reshapes below, with einsum in feedforward, improve the training speed.
-    m2 = jnp.reshape(m2, [self._d_lowrank, self._d1, self._d2])
-    mb = jnp.reshape(mb, [self._d1, self._d2])
-    self.weights = (m1, m2, mb)
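The controller's discretization is the straight-through Gumbel-softmax trick: the forward value of hard + mask - stop_gradient(mask) is the hard one-hot sample, while its gradient is the soft mask's. A standalone JAX sketch of the same trick; note that, as in the code above, the noise is multiplied by the temperature rather than the logits being divided by it:

import jax
import jax.numpy as jnp

def straight_through_gumbel(logits, rng, temperature=0.1):
    log_mask = jax.nn.log_softmax(logits, axis=-1)
    mask = jnp.exp(log_mask)
    u = jax.random.uniform(rng, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))  # Gumbel(0, 1) noise
    hard = jax.nn.one_hot(
        jnp.argmax(log_mask + g * temperature, axis=-1), logits.shape[-1])
    # Value of `hard` in the forward pass, gradient of `mask` in the backward.
    return jax.lax.stop_gradient(hard) + mask - jax.lax.stop_gradient(mask)

sample = straight_through_gumbel(
    jnp.array([[2.0, 0.5, -1.0]]), jax.random.PRNGKey(0))
print(sample)  # one-hot valued, yet differentiable w.r.t. the logits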
+class _SparseFFMain(base.Layer):
+    """The main (non-controller) part of Sparse Feed-Forward layer."""
+
+    def __init__(
+        self,
+        d_ff,
+        n_elements_in_block,
+        d_lowrank,
+        quant_prob,
+        use_bfloat16,
+        big_weights_in_bfloat16,
+        mode,
+        kernel_initializer,
+        bias_initializer,
+        multiply_by_controller_output,
+        kernel_scaling,
+    ):
+        """Returns a sparse feed-forward block."""
+        n_in = 3 if mode == "train" or multiply_by_controller_output else 2
+        super().__init__(name=f"_SparseFFMain_{d_ff}", n_in=n_in, n_out=2)
+        self._mode = mode
+        self._use_bfloat16 = use_bfloat16
+        self._big_weights_in_bfloat16 = big_weights_in_bfloat16
+        self._d_ff = d_ff
+        self._d_lowrank = d_lowrank
+        self._quant_prob = quant_prob
+        self._n_elements_in_block = n_elements_in_block
+        self._kernel_initializer = kernel_initializer
+        self._bias_initializer = bias_initializer
+        # Helper numbers as d_ff will be divided by n_elements_in_block.
+        assert self._d_ff % self._n_elements_in_block == 0
+        self._d1 = self._d_ff // self._n_elements_in_block
+        self._d2 = self._n_elements_in_block
+        self._multiply_by_controller_output = multiply_by_controller_output
+        self._kernel_scaling = kernel_scaling
+
+    def forward(self, x):
+        """Executes this layer as part of a forward pass through the model.
+
+        Args:
+          x: Tensor of same shape and dtype as the input signature used to
+              initialize this layer.
+
+        Returns:
+          Tensor of same shape and dtype as the input.
+        """
+        if self._mode == "train" or self._multiply_by_controller_output:
+            quant_mask, mask, x = x
+        else:
+            quant_mask, x = x
+        original_quant_mask = quant_mask
+
+        w1, w2, b2 = self.weights
+
+        if self._mode == "predict":
+            w1 = jnp.transpose(w1, (1, 2, 0))  # dm, d1, d2 -> d1, d2, dm
+            w2 = jnp.transpose(w2, (1, 0, 2))  # d2, d1, dm -> d1, d2, dm
+        x_shape = x.shape
+        x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+
+        if self._mode == "train":
+            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
+            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
+            quant_mask = fastmath.stop_gradient(quant_mask)
+            quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
+            # We will sometimes (quant_prob of the batches) use the soft-mask instead
+            # of the quantized mask to improve training stability (see paper above).
+            select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0)
+            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
+
+            # In training, run full matmul to get benefits from the above tricks.
+            mid = jnp.einsum("bd,dxy->bxy", x, w1) * quant_mask
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            if self._multiply_by_controller_output:
+                # We multiply only for quantized decisions, since for non-quantized
+                # decisions we've already multiplied the output.
+                mask_mult = jnp.where(
+                    select < self._quant_prob, mask, jnp.ones_like(mask)
+                )
+                # Stop-gradient is here, because we already have a pass-through gradient
+                # (for quantized decisions).
+                mask_mult = fastmath.stop_gradient(mask_mult)
+                relu = relu * mask_mult
+            res = jnp.einsum("bxy,yxd->bd", relu, w2) + b2
+        elif self._mode == "predict":
+            # This implementation mimics inference. It's not efficient for large
+            # size of joint_batch, but at inference that will be 1 most of the time.
+            # Shapes:
+            # quant_mask is [joint_batch, self._d1]
+            # w1 is [d_model, self._d1, self._d2]
+            # we'll index w1 with advanced numpy indexing, first range over
+            # self._d1 times the batch size, second range being quant_mask
+            batch_size = quant_mask.shape[0]
+            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
+            # flatten indices and select from w1
+            idx1 = jnp.reshape(idx1, [-1])
+            idx2 = jnp.reshape(quant_mask, [-1])
+            w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
+            w = jnp.reshape(w, [batch_size, self._d1, -1])
+            mid = jnp.einsum("ai,aji->aj", x, w)
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            if self._multiply_by_controller_output:
+                mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0]
+                relu = relu * mask_mult
+            # w2 is [self._d1, self._d2, d_model]
+            v = w2[idx1, idx2, :]
+            v = jnp.reshape(v, [batch_size, self._d1, -1])
+            res = jnp.einsum("ai,aij->aj", relu, v) + b2
+        else:
+            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
+            mid = jnp.einsum("bd,dxy->bxy", x, w1) * quant_mask
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            if self._multiply_by_controller_output:
+                relu = relu * mask
+            res = jnp.einsum("bxy,yxd->bd", relu, w2) + b2
+
+        return original_quant_mask, jnp.reshape(res, x_shape)
+
+    def init_weights_and_state(self, input_signature):
+        """Randomly initializes this layer's weights."""
+        d_model = input_signature[-1].shape[-1]
+        shape_w1 = (d_model, self._d_ff)
+        shape_w2 = (self._d_ff, d_model)
+        shape_b2 = (d_model,)
+
+        rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)
+        if base.N_WEIGHTS_SHARDS > 1:
+            # In sharded-weights mode, put the weights on CPU on init
+            # as they will be sharded later.
+            w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1))
+            w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2))
+        else:
+            w1 = self._kernel_initializer(shape_w1, rng_w1)
+            w2 = self._kernel_initializer(shape_w2, rng_w2)
+
+        b2 = self._bias_initializer(shape_b2, rng_b2)
+        if self._use_bfloat16:
+            b2 = b2.astype(jnp.bfloat16)
+        if self._use_bfloat16 or self._big_weights_in_bfloat16:
+            w1 = w1.astype(jnp.bfloat16)
+            w2 = w2.astype(jnp.bfloat16)
+
+        w1 = jnp.reshape(w1, (-1, self._d1, self._d2))
+        w2 = jnp.reshape(w2, (self._d2, self._d1, -1))
+
+        if self._kernel_scaling:
+            # This keeps expected variance of the output regardless of N.
+            w2 = w2 * (self._n_elements_in_block**0.5)
+
+        self.weights = (w1, w2, b2)
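In predict mode the full [d_model, d_ff] matmul is skipped entirely: with one active element per block, advanced indexing gathers exactly one d_model-sized weight vector per block. A small NumPy sketch of that gather, with made-up sizes:

import numpy as np

batch, d1, d2, d_model = 2, 3, 4, 5  # d1 blocks of d2 elements each
w1 = np.arange(d_model * d1 * d2, dtype=np.float32).reshape(d_model, d1, d2)
quant_mask = np.array([[0, 3, 1],    # chosen element per block, example 0
                       [2, 2, 0]])   # chosen element per block, example 1

idx1 = np.tile(np.arange(d1), batch)  # block index, repeated per example
idx2 = quant_mask.reshape(-1)         # chosen element inside each block
w1_t = np.transpose(w1, (1, 2, 0))    # dm, d1, d2 -> d1, d2, dm
w = w1_t[idx1, idx2, :].reshape(batch, d1, d_model)
print(w.shape)  # (2, 3, 5): only d1 weight vectors per example are fetched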
-class _SparseFFMain(base.Layer):
-  """The main (non-controller) part of Sparse Feed-Forward layer."""
-
-  def __init__(self, d_ff, n_elements_in_block, d_lowrank, quant_prob,
-               use_bfloat16, big_weights_in_bfloat16, mode, kernel_initializer,
-               bias_initializer, multiply_by_controller_output, kernel_scaling):
-    """Returns a sparse feed-forward block."""
-    n_in = 3 if mode == 'train' or multiply_by_controller_output else 2
-    super().__init__(name=f'_SparseFFMain_{d_ff}', n_in=n_in, n_out=2)
-    self._mode = mode
-    self._use_bfloat16 = use_bfloat16
-    self._big_weights_in_bfloat16 = big_weights_in_bfloat16
-    self._d_ff = d_ff
-    self._d_lowrank = d_lowrank
-    self._quant_prob = quant_prob
-    self._n_elements_in_block = n_elements_in_block
-    self._kernel_initializer = kernel_initializer
-    self._bias_initializer = bias_initializer
-    # Helper numbers as d_ff will be divided by n_elements_in_block.
-    assert self._d_ff % self._n_elements_in_block == 0
-    self._d1 = self._d_ff // self._n_elements_in_block
-    self._d2 = self._n_elements_in_block
-    self._multiply_by_controller_output = multiply_by_controller_output
-    self._kernel_scaling = kernel_scaling
-
-  def forward(self, x):
-    """Executes this layer as part of a forward pass through the model.
+def SparseFF(
+    d_ff,
+    n_elements_in_block=32,
+    d_lowrank=64,
+    temperature=0.1,
+    quant_prob=0.3,
+    use_bfloat16=False,
+    big_weights_in_bfloat16=False,
+    mode="train",
+    kernel_initializer=init.GlorotUniformInitializer(),
+    bias_initializer=init.RandomNormalInitializer(1e-6),
+    dropout_rate=0.0,
+    dropout_shared_axes=None,
+    ff_chunk_size=0,
+    multiply_by_controller_output=False,
+    kernel_scaling=False,
+):
+    """Returns Feed-forward block with sparsity.
+
+    The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
+    that takes an input, makes it of size d_ff (usually larger than it was) and
+    then brings it back to the original size after Relu. It is commonly used in
+    Transformer models where it often accounts for most of the trainable weights.
+
+    The original block can be slow in decoding due to the need to fetch a lot of
+    weights from memory. This sparse block only allows one non-zero element
+    in a block of a specified size. This is trained with straight-through Gumbel
+    softmax trick.
 
     Args:
-      x: Tensor of same shape and dtype as the input signature used to
-          initialize this layer.
-
-    Returns:
-      Tensor of same shape and dtype as the input.
+      d_ff: Depth/dimensionality of FeedForward layer.
+      n_elements_in_block: The sparsity level. The layer is divided into blocks of
+        this size, and each block has only a single element active.
+      d_lowrank: The dimensionality of low-rank controller.
+      temperature: The temperature of the controller during training.
+      quant_prob: During training this proportion of blocks will have quantized
+        mask (i.e. a single element active). The rest will use a soft mask.
+      use_bfloat16: Whether to use bfloat16 for weights.
+      big_weights_in_bfloat16: Whether to use bfloat16 for main weights of the
+        FeedForward layer.
+      mode: One of `'train'`, `'eval'`, or `'predict'`.
+      kernel_initializer: Function that creates a matrix of (random) initial
+        connection weights `W` for the layer.
+      bias_initializer: Function that creates a vector of (random) initial
+        bias weights `b` for the layer.
+      dropout_rate: Probability for dropping an activation value.
+      dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
+        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
+        way to save memory and apply consistent masks to activation vectors at
+        different sequence positions.
+      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
+      multiply_by_controller_output: whether to multiply the middle activation
+        layer of FF by controller output (i.e. softmax).
+      kernel_scaling: Whether to scale the kernel matrix (during init) to keep the
+        variance of the layer output regardless of n_elements_in_block.
    """
-    if self._mode == 'train' or self._multiply_by_controller_output:
-      quant_mask, mask, x = x
-    else:
-      quant_mask, x = x
-    original_quant_mask = quant_mask
-
-    w1, w2, b2 = self.weights
-
-    if self._mode == 'predict':
-      w1 = jnp.transpose(w1, (1, 2, 0))  # dm, d1, d2 -> d1, d2, dm
-      w2 = jnp.transpose(w2, (1, 0, 2))  # d2, d1, dm -> d1, d2, dm
-    x_shape = x.shape
-    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
-
-    if self._mode == 'train':
-      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
-      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
-      quant_mask = fastmath.stop_gradient(quant_mask)
-      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
-      # We will sometimes (quant_prob of the batches) use the soft-mask instead
-      # of the quantized mask to improve training stability (see paper above).
-      select = fastmath.random.uniform(self.rng, (), jnp.float32, 0.0, 1.0)
-      quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
-
-      # In training, run full matmul to get benefits from the above tricks.
-      mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      if self._multiply_by_controller_output:
-        # We multiply only for quantized decisions, since for non-quantized
-        # decisions we've already multiplied the output.
-        mask_mult = jnp.where(select < self._quant_prob,
-                              mask, jnp.ones_like(mask))
-        # Stop-gradient is here, because we already have a pass-through gradient
-        # (for quantized decisions).
-        mask_mult = fastmath.stop_gradient(mask_mult)
-        relu = relu * mask_mult
-      res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2
-    elif self._mode == 'predict':
-      # This implementation mimicks inference. It's not efficient for large
-      # size of joint_batch, but at inference that will be 1 most of the time.
-      # Shapes:
-      # quant_mask is [joint_batch, self._d1]
-      # w1 is [d_model, self._d1, self._d2]
-      # we'll index w1 with advanced numpy indexing, first range over
-      # self._d1 times the batch size, second range being quant_mask
-      batch_size = quant_mask.shape[0]
-      idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
-      # flatten indices and select from w1
-      idx1 = jnp.reshape(idx1, [-1])
-      idx2 = jnp.reshape(quant_mask, [-1])
-      w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
-      w = jnp.reshape(w, [batch_size, self._d1, -1])
-      mid = jnp.einsum('ai,aji->aj', x, w)
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      if self._multiply_by_controller_output:
-        mask_mult = jnp.take_along_axis(mask, quant_mask[..., None], -1)[..., 0]
-        relu = relu * mask_mult
-      # w2 is [self._d1, self._d2, d_model]
-      v = w2[idx1, idx2, :]
-      v = jnp.reshape(v, [batch_size, self._d1, -1])
-      res = jnp.einsum('ai,aij->aj', relu, v) + b2
-    else:
-      quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
-      mid = jnp.einsum('bd,dxy->bxy', x, w1) * quant_mask
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      if self._multiply_by_controller_output:
-        relu = relu * mask
-      res = jnp.einsum('bxy,yxd->bd', relu, w2) + b2
-
-    return original_quant_mask, jnp.reshape(res, x_shape)
-
-  def init_weights_and_state(self, input_signature):
-    """Randomly initializes this layer's weights."""
-    d_model = input_signature[-1].shape[-1]
-    shape_w1 = (d_model, self._d_ff)
-    shape_w2 = (self._d_ff, d_model)
-    shape_b2 = (d_model,)
-
-    rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 3)
-    if base.N_WEIGHTS_SHARDS > 1:
-      # In sharded-weights mode, put the weights on CPU on init
-      # as they will be sharded later.
-      w1 = tl.on_cpu(self._kernel_initializer(shape_w1, rng_w1))
-      w2 = tl.on_cpu(self._kernel_initializer(shape_w2, rng_w2))
-    else:
-      w1 = self._kernel_initializer(shape_w1, rng_w1)
-      w2 = self._kernel_initializer(shape_w2, rng_w2)
-
-    b2 = self._bias_initializer(shape_b2, rng_b2)
-    if self._use_bfloat16:
-      b2 = b2.astype(jnp.bfloat16)
-    if self._use_bfloat16 or self._big_weights_in_bfloat16:
-      w1 = w1.astype(jnp.bfloat16)
-      w2 = w2.astype(jnp.bfloat16)
-    w1 = jnp.reshape(w1, (-1, self._d1, self._d2))
-    w2 = jnp.reshape(w2, (self._d2, self._d1, -1))
-
-    if self._kernel_scaling:
-      # This keeps expected variance of the output regardless of N.
-      w2 = w2 * (self._n_elements_in_block ** 0.5)
+    if mode == "train" or multiply_by_controller_output:
+        also_return_nondiscrete_output = True
+    else:
+        also_return_nondiscrete_output = False
+    controller = _SparseFFController(
+        d_ff=d_ff,
+        n_elements_in_block=n_elements_in_block,
+        d_lowrank=d_lowrank,
+        temperature=temperature,
+        use_bfloat16=use_bfloat16,
+        mode=mode,
+        kernel_initializer=kernel_initializer,
+        bias_initializer=bias_initializer,
+        also_return_nondiscrete_output=also_return_nondiscrete_output,
+    )
-    self.weights = (w1, w2, b2)
+    main = [
+        _SparseFFMain(
+            d_ff=d_ff,
+            n_elements_in_block=n_elements_in_block,
+            d_lowrank=d_lowrank,
+            quant_prob=quant_prob,
+            use_bfloat16=use_bfloat16,
+            big_weights_in_bfloat16=big_weights_in_bfloat16,
+            mode=mode,
+            kernel_initializer=kernel_initializer,
+            bias_initializer=bias_initializer,
+            multiply_by_controller_output=multiply_by_controller_output,
+            kernel_scaling=kernel_scaling,
+        ),
+        # quant_mask, emb
+        tl.Select([1, 0]),
+        # emb, quant_mask
+        tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode),
+        tl.Select([1, 0]),
+        # quant_mask, emb
+    ]
+    # We will "remember" quant_mask _after_ chunking, and "recall" this same
+    # quant_mask during reverse_and_grad _before_ chunking.
+    remembering = _RememberInReverse(output=False)
+    recalling = _RecallQuantMaskInReverse(
+        remember_layer=remembering, elements=d_ff // n_elements_in_block
+    )
-def SparseFF(
-    d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.1, quant_prob=0.3,
-    use_bfloat16=False, big_weights_in_bfloat16=False, mode='train',
-    kernel_initializer=init.GlorotUniformInitializer(),
-    bias_initializer=init.RandomNormalInitializer(1e-6),
-    dropout_rate=0.0, dropout_shared_axes=None, ff_chunk_size=0,
-    multiply_by_controller_output=False, kernel_scaling=False):
-  """Returns Feed-forward block with sparsity.
-
-  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
-  that takes an input, makes it of size d_ff (usually larger than it was) and
-  then brings it back to the original size after Relu. It is commonly used in
-  Transformer models where it often accounts for most of the trainable weights.
-
-  The original block can be slow in decoding due to the need to fetch a lot of
-  weights from memory. This sparse block only allows one non-zero element
-  in a block of a specified size. This is trained with straight-through Gumbel
-  softmax trick.
-
-  Args:
-    d_ff: Depth/dimensionality of FeedForward layer.
-    n_elements_in_block: The sparsity level. The layer is divided into blocks of
-      this size, and each block has only a single element active.
-    d_lowrank: The dimensionality of low-rank controller.
-    temperature: The temperature of the controller during training.
-    quant_prob: During training this proportion of blocks will have quantized
-      mask (i.e. a single element active). The rest will use a soft mask.
-    use_bfloat16: Whether to use bfloat16 for weights.
-    big_weights_in_bfloat16: : Whether to use bfloat16 for main weights of the
-      FeedForward layer.
-    mode: One of `'train'`, `'eval'`, or `'predict'`.
-    kernel_initializer: Function that creates a matrix of (random) initial
-      connection weights `W` for the layer.
-    bias_initializer: Function that creates a vector of (random) initial
-      bias weights `b` for the layer.
-    dropout_rate: Probability for dropping an activation value.
-    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
-      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
-      way to save memory and apply consistent masks to activation vectors at
-      different sequence positions.
-    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
-    multiply_by_controller_output: whether to multiply the middle activation
-      layer of FF by controller output (i.e. softmax).
-    kernel_scaling: Whether to scale the kernel matrix (during init) to keep the
-      variance of the layer output regardless of n_elements_in_block.
-  """
-
-  if mode == 'train' or multiply_by_controller_output:
-    also_return_nondiscrete_output = True
-  else:
-    also_return_nondiscrete_output = False
-  controller = _SparseFFController(
-      d_ff=d_ff, n_elements_in_block=n_elements_in_block,
-      d_lowrank=d_lowrank, temperature=temperature,
-      use_bfloat16=use_bfloat16, mode=mode,
-      kernel_initializer=kernel_initializer,
-      bias_initializer=bias_initializer,
-      also_return_nondiscrete_output=also_return_nondiscrete_output)
-
-  main = [
-      _SparseFFMain(
-          d_ff=d_ff, n_elements_in_block=n_elements_in_block,
-          d_lowrank=d_lowrank, quant_prob=quant_prob, use_bfloat16=use_bfloat16,
-          big_weights_in_bfloat16=big_weights_in_bfloat16, mode=mode,
-          kernel_initializer=kernel_initializer,
-          bias_initializer=bias_initializer,
-          multiply_by_controller_output=multiply_by_controller_output,
-          kernel_scaling=kernel_scaling),
-      # quant_mask, emb
-      tl.Select([1, 0]),
-      # emb, quant_mask
-      tl.Dropout(rate=dropout_rate, shared_axes=dropout_shared_axes, mode=mode),
-      tl.Select([1, 0]),
-      # quant_mask, emb
-  ]
-
-  # We will "remember" quant_mask _after_ chunking, and "recall" this same
-  # quant_mask during reverse_and_grad _before_ chunking.
-  remembering = _RememberInReverse(output=False)
-  recalling = _RecallQuantMaskInReverse(
-      remember_layer=remembering, elements=d_ff//n_elements_in_block)
-
-  return tl.BatchLeadingAxes(tl.Serial(
-      recalling,  # emb, quant_mask
-      tl.Chunk(chunk_size=ff_chunk_size, layer=tl.Serial(
-          # emb, quant_mask
-          tl.Select((0, 1, 0)),  # emb, quant_mask, emb
-          controller,  # quant_mask, mask, emb
-          main,  # quant_mask, emb/output
-      )),
-      remembering,  # emb/output
-  ))
+    return tl.BatchLeadingAxes(
+        tl.Serial(
+            recalling,  # emb, quant_mask
+            tl.Chunk(
+                chunk_size=ff_chunk_size,
+                layer=tl.Serial(
+                    # emb, quant_mask
+                    tl.Select((0, 1, 0)),  # emb, quant_mask, emb
+                    controller,  # quant_mask, mask, emb
+                    main,  # quant_mask, emb/output
+                ),
+            ),
+            remembering,  # emb/output
+        )
+    )
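Assembled this way, SparseFF is a drop-in replacement for a dense FF block: same shape in, same shape out, with the remember/recall pair bracketing the chunked controller+main pipeline. A usage sketch in the style of the tests removed further below (assuming that public trax API):

import numpy as np
from trax import shapes
from trax.layers.research import sparsity

layer = sparsity.SparseFF(
    d_ff=16, n_elements_in_block=2, temperature=0.7,
    ff_chunk_size=4, mode='train')
x = np.ones((2, 8, 8), dtype=np.float32)  # [batch, length, d_model]
layer.init(shapes.signature(x))
y = layer(x)
print(y.shape)  # (2, 8, 8)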


 class BlockSparseFF(base.Layer):
-  """Feed-forward block with block sparsity.
-
-  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
-  that takes an input, makes it of size d_ff (usually larger than it was) and
-  then brings it back to the original size after Relu. It is commonly used in
-  Transformer models where it often accounts for most of the trainable weights.
-
-  This block sparse layer mimics mixture of experts architecture.
-  It divides the dimension of d_ff in each weight matrix to # of blocks equal to
-  n_experts and activates only one non-zero block from the weights matrix.
-  This is trained with straight-through Gumbel softmax trick.
-  """
-
-  def __init__(self,
-               d_ff,
-               n_experts=64,
-               temperature=0.7,
-               mode='train',
-               kernel_initializer=init.GlorotUniformInitializer(),
-               bias_initializer=init.RandomNormalInitializer(1e-6)):
-    """Returns a block sparse feed-forward block."""
-    super().__init__(name=f'BlockSparseFF_{d_ff}')
-    self._mode = mode
-    self._d_ff = d_ff
-    self._n_experts = n_experts
-    self._temperature = temperature if mode == 'train' else 0.0
-    self._n_elements_in_block = d_ff // n_experts
-    self._kernel_initializer = kernel_initializer
-    self._bias_initializer = bias_initializer
-    assert self._d_ff % self._n_experts == 0
-
-  def forward(self, x):
-    """Executes this layer as part of a forward pass through the model.
+    """Feed-forward block with block sparsity.
 
-    Args:
-      x: Tensor of same shape and dtype as the input signature used to
-          initialize this layer.
+    The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
+    that takes an input, makes it of size d_ff (usually larger than it was) and
+    then brings it back to the original size after Relu. It is commonly used in
+    Transformer models where it often accounts for most of the trainable weights.
 
-    Returns:
-      Tensor of same shape and dtype as the input.
+    This block sparse layer mimics mixture of experts architecture.
+    It divides the dimension of d_ff in each weight matrix to # of blocks equal to
+    n_experts and activates only one non-zero block from the weights matrix.
+    This is trained with straight-through Gumbel softmax trick.
    """
-    m1, w1, w2, b2 = self.weights
-    x_shape = x.shape
-    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
-
-    # Q: check if we need bias and/or put relu after the m1 dot?
-    mask_logits = jnp.dot(x, m1)
-    # Softmax.
-    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
-    log_mask = mask_logits - mask_logsumexp
-    mask = jnp.exp(log_mask)
-    # Gumbel-softmax with straight-through discretization.
-    # TODO(lukaszkaiser, chowdhery): Extract this block and share
-    rng1, rng2 = fastmath.random.split(self.rng, 2)
-    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
-    g = -jnp.log(-jnp.log(u))
-    selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
-    if self._mode == 'train':
-      # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
-      quant_mask = tl.one_hot(selected_experts, self._n_experts)
-      quant_mask = fastmath.stop_gradient(quant_mask)
-      quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
-      # We will sometimes (50% of the batches) use the soft-mask instead of
-      # the quantized mask to improve training stability (see the paper above).
-      # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
-      select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
-      quant_mask = jnp.where(select > 0.0, quant_mask, mask)
-    else:
-      quant_mask = tl.one_hot(selected_experts, self._n_experts)
-    quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])
-    batch_size = quant_mask.shape[0]
-
-    if self._mode == 'predict' and batch_size == 1:
-      # This implementation mimicks inference for batch_size 1.
-      start_idx = selected_experts[0] * self._n_elements_in_block
-      # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
-      w = fastmath.dynamic_slice(w1, [0, start_idx],
-                                 [w1.shape[0], self._n_elements_in_block])
-      mid = jnp.dot(x, w)
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
-      v = fastmath.dynamic_slice(w2, [start_idx, 0],
-                                 [self._n_elements_in_block, w2.shape[-1]])
-      v = jnp.reshape(v, [self._n_elements_in_block, -1])
-      res = jnp.dot(relu, v) + b2
-    else:
-      expanded_mask = jnp.broadcast_to(
-          quant_mask,
-          (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))
-      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
-      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      res = jnp.dot(relu, w2) + b2
-
-    return jnp.reshape(res, x_shape)  # un-flatten if needed
-
-  def init_weights_and_state(self, input_signature):
-    """Randomly initializes this layer's weights."""
-    d_model = input_signature.shape[-1]
-    shape_m1 = (d_model, self._n_experts)
-    shape_w1 = (d_model, self._d_ff)
-    shape_w2 = (self._d_ff, d_model)
-    shape_b2 = (d_model,)
-    rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)
-    m1 = self._kernel_initializer(shape_m1, rng_m1)
-    w1 = self._kernel_initializer(shape_w1, rng_w1)
-    w2 = self._kernel_initializer(shape_w2, rng_w2)
-    b2 = self._bias_initializer(shape_b2, rng_b2)
-
-    self.weights = (m1, w1, w2, b2)
+    def __init__(
+        self,
+        d_ff,
+        n_experts=64,
+        temperature=0.7,
+        mode="train",
+        kernel_initializer=init.GlorotUniformInitializer(),
+        bias_initializer=init.RandomNormalInitializer(1e-6),
+    ):
+        """Returns a block sparse feed-forward block."""
+        super().__init__(name=f"BlockSparseFF_{d_ff}")
+        self._mode = mode
+        self._d_ff = d_ff
+        self._n_experts = n_experts
+        self._temperature = temperature if mode == "train" else 0.0
+        self._n_elements_in_block = d_ff // n_experts
+        self._kernel_initializer = kernel_initializer
+        self._bias_initializer = bias_initializer
+        assert self._d_ff % self._n_experts == 0
+
+    def forward(self, x):
+        """Executes this layer as part of a forward pass through the model.
+
+        Args:
+          x: Tensor of same shape and dtype as the input signature used to
+              initialize this layer.
+
+        Returns:
+          Tensor of same shape and dtype as the input.
+        """
+        m1, w1, w2, b2 = self.weights
+        x_shape = x.shape
+        x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+
+        # Q: check if we need bias and/or put relu after the m1 dot?
+        mask_logits = jnp.dot(x, m1)
+        # Softmax.
+        mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
+        log_mask = mask_logits - mask_logsumexp
+        mask = jnp.exp(log_mask)
+        # Gumbel-softmax with straight-through discretization.
+        # TODO(lukaszkaiser, chowdhery): Extract this block and share
+        rng1, rng2 = fastmath.random.split(self.rng, 2)
+        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
+        g = -jnp.log(-jnp.log(u))
+        selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
+        if self._mode == "train":
+            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
+            quant_mask = tl.one_hot(selected_experts, self._n_experts)
+            quant_mask = fastmath.stop_gradient(quant_mask)
+            quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
+            # We will sometimes (50% of the batches) use the soft-mask instead of
+            # the quantized mask to improve training stability (see the paper above).
+            # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
+            select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
+            quant_mask = jnp.where(select > 0.0, quant_mask, mask)
+        else:
+            quant_mask = tl.one_hot(selected_experts, self._n_experts)
+        quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])
+        batch_size = quant_mask.shape[0]
+
+        if self._mode == "predict" and batch_size == 1:
+            # This implementation mimics inference for batch_size 1.
+            start_idx = selected_experts[0] * self._n_elements_in_block
+            # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
+            w = fastmath.dynamic_slice(
+                w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block]
+            )
+            mid = jnp.dot(x, w)
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
+            v = fastmath.dynamic_slice(
+                w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]]
+            )
+            v = jnp.reshape(v, [self._n_elements_in_block, -1])
+            res = jnp.dot(relu, v) + b2
+        else:
+            expanded_mask = jnp.broadcast_to(
+                quant_mask,
+                (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block),
+            )
+            expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
+            mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            res = jnp.dot(relu, w2) + b2
+
+        return jnp.reshape(res, x_shape)  # un-flatten if needed
+
+    def init_weights_and_state(self, input_signature):
+        """Randomly initializes this layer's weights."""
+        d_model = input_signature.shape[-1]
+        shape_m1 = (d_model, self._n_experts)
+        shape_w1 = (d_model, self._d_ff)
+        shape_w2 = (self._d_ff, d_model)
+        shape_b2 = (d_model,)
+
+        rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)
+        m1 = self._kernel_initializer(shape_m1, rng_m1)
+        w1 = self._kernel_initializer(shape_w1, rng_w1)
+        w2 = self._kernel_initializer(shape_w2, rng_w2)
+        b2 = self._bias_initializer(shape_b2, rng_b2)
+
+        self.weights = (m1, w1, w2, b2)
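At batch size 1, prediction slices the selected expert's block of columns out of w1 (and the matching rows of w2) with dynamic_slice; the sliced matmul equals the corresponding block of the dense computation. A NumPy sketch of that equivalence, with illustrative sizes:

import numpy as np

d_model, d_ff, n_experts = 4, 12, 3
block = d_ff // n_experts  # columns owned by each expert
rng = np.random.default_rng(0)
w1 = rng.normal(size=(d_model, d_ff)).astype(np.float32)
x = rng.normal(size=(1, d_model)).astype(np.float32)

selected_expert = 2
start = selected_expert * block
w = w1[:, start:start + block]  # what dynamic_slice extracts above
full = x @ w1                   # what the masked full matmul computes
print(np.allclose(x @ w, full[:, start:start + block]))  # True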


 class SwitchSparseFF(base.Layer):
-  """Feed-forward block with switch-style block sparsity.
-
-  The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
-  that takes an input, makes it of size d_ff (usually larger than it was) and
-  then brings it back to the original size after Relu. It is commonly used in
-  Transformer models where it often accounts for most of the trainable weights.
-
-  This block sparse layer mimics mixture of experts architecture.
-  It divides the dimension of d_ff in each weight matrix to # of blocks equal to
-  n_experts and activates only one non-zero block from the weights matrix.
-  This is trained with methods following the Switch Transformer.
-  """
-
-  def __init__(self,
-               d_ff,
-               n_experts=64,
-               temperature=0.1,
-               mode='train',
-               kernel_initializer=init.GlorotUniformInitializer(),
-               bias_initializer=init.RandomNormalInitializer(1e-6)):
-    """Returns a switch-style training block sparse feed-forward block."""
-    super().__init__(name=f'SwitchSparseFF_{d_ff}')
-    self._mode = mode
-    self._d_ff = d_ff
-    self._n_experts = n_experts
-    self._temperature = temperature if mode == 'train' else 0.0
-    self._n_elements_in_block = d_ff // n_experts
-    self._kernel_initializer = kernel_initializer
-    self._bias_initializer = bias_initializer
-    assert self._d_ff % self._n_experts == 0
-
-  def forward(self, x):
-    """Executes this layer as part of a forward pass through the model.
+    """Feed-forward block with switch-style block sparsity.
 
-    Args:
-      x: Tensor of same shape and dtype as the input signature used to
-          initialize this layer.
+    The original (non-sparse) FF block is a triple Dense(d_ff)-Relu-Dense
+    that takes an input, makes it of size d_ff (usually larger than it was) and
+    then brings it back to the original size after Relu. It is commonly used in
+    Transformer models where it often accounts for most of the trainable weights.
 
-    Returns:
-      Tensor of same shape and dtype as the input.
+    This block sparse layer mimics mixture of experts architecture.
+    It divides the dimension of d_ff in each weight matrix to # of blocks equal to
+    n_experts and activates only one non-zero block from the weights matrix.
+    This is trained with methods following the Switch Transformer.
    """
-    m1, w1, w2, b2 = self.weights
-    x_shape = x.shape
-    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
-
-    # Q: check if we need bias and/or put relu after the m1 dot?
-    mask_logits = jnp.dot(x, m1)
-    # Softmax.
-    mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
-    log_mask = mask_logits - mask_logsumexp
-    mask = jnp.exp(log_mask)
-    # Gumbel noise to allow sampling from the softmax.
-    rng1, _ = fastmath.random.split(self.rng, 2)
-    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
-    g = -jnp.log(-jnp.log(u))
-    selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
-    quant_mask = tl.one_hot(selected_experts, self._n_experts)
-    quant_mask = fastmath.stop_gradient(quant_mask)
-    quant_mask *= mask  # go to just the selected expert
-    quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])
-    batch_size = quant_mask.shape[0]
-
-    if self._mode == 'predict' and batch_size == 1:
-      mask_flat = jnp.reshape(mask, [-1, self._n_experts])
-      selected_flat = jnp.reshape(selected_experts, [-1])
-      selected_mask_flat = mask_flat[np.arange(selected_flat.size),
-                                     selected_flat]
-      # This implementation mimicks inference for batch_size 1.
-      start_idx = selected_experts[0] * self._n_elements_in_block
-      # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
-      w = fastmath.dynamic_slice(w1, [0, start_idx],
-                                 [w1.shape[0], self._n_elements_in_block])
-      mid = jnp.dot(x, w)
-      mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None]
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
-      v = fastmath.dynamic_slice(w2, [start_idx, 0],
-                                 [self._n_elements_in_block, w2.shape[-1]])
-      v = jnp.reshape(v, [self._n_elements_in_block, -1])
-      res = jnp.dot(relu, v) + b2
-    else:
-      expanded_mask = jnp.broadcast_to(
-          quant_mask,
-          (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block))
-      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
-      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
-      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
-      res = jnp.dot(relu, w2) + b2
-
-    return jnp.reshape(res, x_shape)  # un-flatten if needed
-
-  def init_weights_and_state(self, input_signature):
-    """Randomly initializes this layer's weights."""
-    d_model = input_signature.shape[-1]
-    shape_m1 = (d_model, self._n_experts)
-    shape_w1 = (d_model, self._d_ff)
-    shape_w2 = (self._d_ff, d_model)
-    shape_b2 = (d_model,)
-
-    rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)
-    m1 = self._kernel_initializer(shape_m1, rng_m1)
-    w1 = self._kernel_initializer(shape_w1, rng_w1)
-    w2 = self._kernel_initializer(shape_w2, rng_w2)
-    b2 = self._bias_initializer(shape_b2, rng_b2)
-
-    self.weights = (m1, w1, w2, b2)
+
+    def __init__(
+        self,
+        d_ff,
+        n_experts=64,
+        temperature=0.1,
+        mode="train",
+        kernel_initializer=init.GlorotUniformInitializer(),
+        bias_initializer=init.RandomNormalInitializer(1e-6),
+    ):
+        """Returns a switch-style training block sparse feed-forward block."""
+        super().__init__(name=f"SwitchSparseFF_{d_ff}")
+        self._mode = mode
+        self._d_ff = d_ff
+        self._n_experts = n_experts
+        self._temperature = temperature if mode == "train" else 0.0
+        self._n_elements_in_block = d_ff // n_experts
+        self._kernel_initializer = kernel_initializer
+        self._bias_initializer = bias_initializer
+        assert self._d_ff % self._n_experts == 0
+
+    def forward(self, x):
+        """Executes this layer as part of a forward pass through the model.
+
+        Args:
+          x: Tensor of same shape and dtype as the input signature used to
+              initialize this layer.
+
+        Returns:
+          Tensor of same shape and dtype as the input.
+        """
+        m1, w1, w2, b2 = self.weights
+        x_shape = x.shape
+        x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+
+        # Q: check if we need bias and/or put relu after the m1 dot?
+        mask_logits = jnp.dot(x, m1)
+        # Softmax.
+        mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
+        log_mask = mask_logits - mask_logsumexp
+        mask = jnp.exp(log_mask)
+        # Gumbel noise to allow sampling from the softmax.
+        rng1, _ = fastmath.random.split(self.rng, 2)
+        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
+        g = -jnp.log(-jnp.log(u))
+        selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
+        quant_mask = tl.one_hot(selected_experts, self._n_experts)
+        quant_mask = fastmath.stop_gradient(quant_mask)
+        quant_mask *= mask  # go to just the selected expert
+        quant_mask = jnp.reshape(quant_mask, [-1, self._n_experts, 1])
+        batch_size = quant_mask.shape[0]
+
+        if self._mode == "predict" and batch_size == 1:
+            mask_flat = jnp.reshape(mask, [-1, self._n_experts])
+            selected_flat = jnp.reshape(selected_experts, [-1])
+            selected_mask_flat = mask_flat[np.arange(selected_flat.size), selected_flat]
+            # This implementation mimics inference for batch_size 1.
+            start_idx = selected_experts[0] * self._n_elements_in_block
+            # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
+            w = fastmath.dynamic_slice(
+                w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block]
+            )
+            mid = jnp.dot(x, w)
+            mid *= jnp.reshape(selected_mask_flat, mid.shape[:-1])[..., None]
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
+            v = fastmath.dynamic_slice(
+                w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]]
+            )
+            v = jnp.reshape(v, [self._n_elements_in_block, -1])
+            res = jnp.dot(relu, v) + b2
+        else:
+            expanded_mask = jnp.broadcast_to(
+                quant_mask,
+                (quant_mask.shape[0], quant_mask.shape[1], self._n_elements_in_block),
+            )
+            expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
+            mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
+            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+            res = jnp.dot(relu, w2) + b2
+
+        return jnp.reshape(res, x_shape)  # un-flatten if needed
+
+    def init_weights_and_state(self, input_signature):
+        """Randomly initializes this layer's weights."""
+        d_model = input_signature.shape[-1]
+        shape_m1 = (d_model, self._n_experts)
+        shape_w1 = (d_model, self._d_ff)
+        shape_w2 = (self._d_ff, d_model)
+        shape_b2 = (d_model,)
+
+        rng_m1, rng_w1, rng_w2, rng_b2 = fastmath.random.split(self.rng, 4)
+        m1 = self._kernel_initializer(shape_m1, rng_m1)
+        w1 = self._kernel_initializer(shape_w1, rng_w1)
+        w2 = self._kernel_initializer(shape_w2, rng_w2)
+        b2 = self._bias_initializer(shape_b2, rng_b2)
+
+        self.weights = (m1, w1, w2, b2)
diff --git a/trax/layers/research/sparsity_test.py b/trax/layers/research/sparsity_test.py
deleted file mode 100644
index dd39091aa..000000000
--- a/trax/layers/research/sparsity_test.py
+++ /dev/null
@@ -1,466 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
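# A minimal sketch, in plain NumPy with illustrative names, of the Gumbel-max
# routing used in SwitchSparseFF.forward above. Scaling the Gumbel noise by a
# temperature interpolates between a hard argmax (temperature 0, as used in
# eval/predict) and an exact sample from the softmax (temperature 1); masking
# the softmax probabilities with a stop-gradient one-hot choice is the
# straight-through trick that keeps routing differentiable in training.
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 8))                           # [batch, n_experts]
log_mask = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
u = rng.uniform(1e-6, 1.0 - 1e-6, size=log_mask.shape)
g = -np.log(-np.log(u))                                    # Gumbel(0, 1) noise
temperature = 0.1
selected = np.argmax(log_mask + g * temperature, axis=-1)  # one expert per row
quant_mask = np.eye(log_mask.shape[-1])[selected]          # hard one-hot choice
quant_mask = quant_mask * np.exp(log_mask)                 # scale by expert prob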
- -"""Tests for trax.layers.research.efficient_attention.""" - -import functools -from absl.testing import parameterized -import jax -import numpy as np -from tensorflow import test - -from trax import fastmath -from trax import shapes -import trax.layers as tl -from trax.layers import test_utils -from trax.layers.research import sparsity - - -class EfficientFeedForwardTest(test.TestCase, parameterized.TestCase): - - def test_blocksparse_ff_train(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (3, 7, d_model) - with fastmath.use_backend(fastmath.Backend.JAX): - layer = sparsity.BlockSparseFF( - d_ff=d_ff, n_experts=n_experts, temperature=0.7, mode='train') - x = np.ones(x_shape).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_blocksparse_ff_predict_equals_eval(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (1, 1, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - common_kwargs = dict( - d_ff=d_ff, - n_experts=n_experts, - temperature=temperature, - ) - eval_model = sparsity.BlockSparseFF( - mode='eval', **common_kwargs) - weights, state = eval_model.init(input_signature) - eval_out, _ = eval_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - pred_model = sparsity.BlockSparseFF( - mode='predict', **common_kwargs) - _, _ = pred_model.init(input_signature) - pred_out, _ = pred_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(eval_out.shape, x.shape) - # eval_out and pred_out should be identical. - np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :]) - - def test_sparse_ff_predict_equals_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - d_model = 64 - seq_len = 6 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - sparsity.SparseFF, - d_ff=256, - temperature=0.7, - n_elements_in_block=8, - ) - - configs = [ - {'multiply_by_controller_output': True}, - {'multiply_by_controller_output': False}, - {'ff_chunk_size': 2}, - ] - - test_utils.test_eval_equals_predict_configs(inp, model_fn, configs) - - @parameterized.named_parameters(('_mode_train', 'train'), - ('_mode_eval', 'eval'), - ('_mode_predict', 'predict')) - def test_sparse_ff_with_chunking(self, mode): - d_model = 8 - n_elements_in_block = 2 - d_ff = 16 - x_shape = (2, 8, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - model = sparsity.SparseFF( - d_ff=d_ff, - n_elements_in_block=n_elements_in_block, - temperature=temperature, - ff_chunk_size=4, - mode=mode) - weights, state = model.init(input_signature) - out, _ = model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(out.shape, x.shape) - - @parameterized.named_parameters(('_mode_train', 'train'), - ('_mode_eval', 'eval'), - ('_mode_predict', 'predict')) - def test_sparse_ff_multiply(self, mode): - d_model = 8 - n_elements_in_block = 2 - d_ff = 16 - x_shape = (2, 8, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - model = sparsity.SparseFF( - d_ff=d_ff, - n_elements_in_block=n_elements_in_block, - temperature=temperature, - ff_chunk_size=4, - mode=mode, 
- multiply_by_controller_output=True) - weights, state = model.init(input_signature) - out, _ = model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(out.shape, x.shape) - - def test_sparse_ff_kernel_scaling(self): - d_model = 8 - n_elements_in_block = 2 - d_ff = 16 - x_shape = (2, 8, d_model) - temperature = 0.7 - with fastmath.use_backend(fastmath.Backend.JAX): - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - model = sparsity.SparseFF( - d_ff=d_ff, - n_elements_in_block=n_elements_in_block, - temperature=temperature, - ff_chunk_size=4, - mode='train', - kernel_scaling=True) - weights, state = model.init(input_signature) - out, _ = model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(out.shape, x.shape) - - def test_switchsparse_ff_train(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (3, 7, d_model) - layer = sparsity.SwitchSparseFF( - d_ff=d_ff, n_experts=n_experts, mode='train') - x = np.ones(x_shape).astype(np.float32) - layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_switchsparse_ff_predict_equals_eval(self): - d_model = 1024 - n_experts = 64 - d_ff = d_model * 8 - x_shape = (1, 1, d_model) - x = np.ones(x_shape).astype(np.float32) - input_signature = shapes.signature(x) - eval_model = sparsity.SwitchSparseFF( - mode='eval', d_ff=d_ff, n_experts=n_experts) - weights, state = eval_model.init(input_signature) - eval_out, _ = eval_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - pred_model = sparsity.SwitchSparseFF( - mode='predict', d_ff=d_ff, n_experts=n_experts) - pred_model.init(input_signature) - pred_out, _ = pred_model.pure_fn( - x, weights, state, rng=jax.random.PRNGKey(0)) - self.assertEqual(eval_out.shape, x.shape) - # eval_out and pred_out should be identical. 
- np.testing.assert_array_almost_equal(eval_out[0, 0, :], pred_out[0, 0, :]) - - -class ReversibleReshapePermuteTest(test.TestCase): - - def test_reversible_permute(self): - layer = sparsity.ReversibleReshapePermute() - x = np.array([[1, 2, 3, 4, 5, 6, 7, 8], - [0, 1, 2, 3, 4, 5, 6, 7]]) - layer.init(shapes.signature(x)) - ys = layer(x) - self.assertEqual(tl.to_list(ys), [ - [1, 3, 5, 7, 2, 4, 6, 8], - [0, 2, 4, 6, 1, 3, 5, 7]]) - rev_x = layer.reverse(ys, weights=layer.weights) - self.assertEqual(tl.to_list(x), tl.to_list(rev_x)) - - -class ReversibleRandomPermuteTest(test.TestCase): - - def test_reversible_permute(self): - layer = sparsity.ReversibleRandomPermute() - x = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - [0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 11, 12, 13], - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - ]) - layer.init(shapes.signature(x)) - ys = layer(x) - # this assert will fail once per ~87B runs, but it's okay - self.assertNotEqual(tl.to_list(ys), tl.to_list(x)) - - self.assertEqual(tl.to_list(ys[0]), tl.to_list(ys[2])) - self.assertNotEqual(tl.to_list(ys[0]), tl.to_list(ys[1])) - rev_x = layer.reverse(ys, weights=layer.weights) - self.assertEqual(tl.to_list(x), tl.to_list(rev_x)) - - -class LocallyConnectedDenseTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.LocallyConnectedDense(2, 8) - x = np.array([[2, 5, 3, 4], - [0, 1, 2, 3]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (2, 16)) - - -class SparseDenseWithOptionsTest(test.TestCase): - - def test_simple_call(self): - d_input, d_output = 16, 32 - settings = [ - (None, 0, 0, False), - (None, 0, 0, True), - ('einsum', 0, 0, False), - ('lowrank', 0, 8, False), - ('mult', 2, 0, False), - ('mult', 2, 0, True), - ('local', 2, 0, False), - ('local3', 2, 0, False), - ] - for stype, sparsity_level, d_lowrank, use_bfloat16 in settings: - layer = sparsity.SparseDenseWithOptions( - d_output, d_input=d_input, sparsity_type=stype, - sparsity=sparsity_level, d_lowrank=d_lowrank, - use_bfloat16=use_bfloat16) - x = np.ones((1, 1, d_input)) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, (1, 1, d_output), - msg='[{}->{}] {} - {} - {} - {}'.format( - d_input, d_output, stype, sparsity_level, d_lowrank, - use_bfloat16)) - - -class ModularCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.ModularCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class LowRankCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.LowRankCausalAttention( - d_feature=4, n_heads=2, lowrank=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class MultiplicativeCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.MultiplicativeCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class MultiplicativeModularCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.MultiplicativeModularCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - 
[0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - -class MultiplicativeConvCausalAttentionTest(test.TestCase): - - def test_simple_call(self): - layer = sparsity.MultiplicativeConvCausalAttention( - d_feature=4, n_heads=2, sparsity=2) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_various_calls(self): - list_kwargs = [] - for share_qk in [True, False]: - for output in ['none', 'mult', 'conv', 'multconv']: - for concat in ['original', 'fixed', 'none']: - kwargs = {'share_qk': share_qk, 'output_layer_type': output, - 'v_concat_type': concat} - list_kwargs.append(kwargs) - for kwargs in list_kwargs: - layer = sparsity.MultiplicativeConvCausalAttention( - d_feature=4, n_heads=2, sparsity=2, **kwargs) - x = np.array([[[2, 5, 3, 4], - [0, 1, 2, 3], - [0, 1, 2, 3],]]) - _, _ = layer.init(shapes.signature(x)) - - y = layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_predict_equals_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - d_model = 32 - seq_len = 5 - x_shape = (1, seq_len, d_model) - inp = np.ones(x_shape).astype(np.float32) - - model_fn = functools.partial( - sparsity.MultiplicativeConvCausalAttention, - d_feature=d_model, - n_heads=4, - sparsity=4, - ) - - list_kwargs = [] - for share_qk in [True, False]: - for output in ['none', 'mult', 'conv', 'multconv']: - for concat in ['original', 'fixed', 'none']: - kwargs = {'share_qk': share_qk, 'output_layer_type': output, - 'v_concat_type': concat} - list_kwargs.append(kwargs) - - test_utils.test_eval_equals_predict_configs(inp, model_fn, list_kwargs) - - -class FavorTest(test.TestCase): - - def test_call_and_grad(self): - layer_partial = tl.Serial( - tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()), - sparsity.Favor(d_feature=4, n_heads=2), - tl.Select([0], n_in=2), - ) - layer = tl.Serial( - tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()), - sparsity.Favor(d_feature=4, n_heads=2), - tl.Select([0], n_in=2), - tl.WeightedCategoryCrossEntropy(), - ) - x = np.ones((1, 2), dtype=np.int32) - w = np.ones_like(x).astype(np.float32) - x_sig = shapes.signature(x) - w_sig = shapes.signature(w) - layer_partial.init(x_sig) - y = layer_partial(x) - self.assertEqual(y.shape, (1, 2, 4)) - layer.init((x_sig, x_sig, w_sig)) - y = layer((x, x, w)) - self.assertEqual(y.shape, ()) - state = layer.state - rng = fastmath.random.get_prng(0) - fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] - g = fastmath.grad(fwd)(layer.weights, (x, x, w)) - self.assertEqual(g[0][1][0].shape, (3, 4)) - - def test_call_and_grad_approximate_softmax(self): - layer_partial = tl.Serial( - tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()), - sparsity.Favor(d_feature=12, n_heads=3, n_random_features=128, - use_approximate_softmax=True), - tl.Select([0], n_in=2), - ) - layer = tl.Serial( - tl.Branch(tl.Embedding(11, 12), tl.PaddingMask()), - sparsity.Favor(d_feature=12, n_heads=3, n_random_features=128, - use_approximate_softmax=True), - tl.Select([0], n_in=2), - tl.WeightedCategoryCrossEntropy(), - ) - x = np.ones((3, 5), dtype=np.int32) - w = np.ones_like(x).astype(np.float32) - x_sig = shapes.signature(x) - w_sig = shapes.signature(w) - layer_partial.init(x_sig) - y = layer_partial(x) - self.assertEqual(y.shape, (3, 5, 12)) - layer.init((x_sig, x_sig, w_sig)) - y = layer((x, x, w)) - self.assertEqual(y.shape, ()) - state = 
layer.state - rng = fastmath.random.get_prng(0) - fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] - g = fastmath.grad(fwd)(layer.weights, (x, x, w)) - self.assertEqual(g[0][1][0].shape, (11, 12)) - - def test_causal_call_and_grad(self): - layer = tl.Serial( - tl.Dense(4), - sparsity.CausalFavor(d_feature=4, n_heads=2), - tl.L2Loss() - ) - x = np.random.uniform(size=(1, 2, 4)).astype(np.float32) - w = np.ones_like(x) - x_sig = shapes.signature(x) - w_sig = shapes.signature(w) - layer.init((x_sig, x_sig, w_sig)) - y = layer((x, x, w)) - self.assertEqual(y.shape, ()) - state = layer.state - rng = fastmath.random.get_prng(0) - fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0] - g = fastmath.grad(fwd)(layer.weights, (x, x, w)) - self.assertEqual(g[0][0].shape, (4, 4)) - - -if __name__ == '__main__': - test.main() diff --git a/trax/layers/reversible.py b/trax/layers/reversible.py index 438e2bf52..d61f3c477 100644 --- a/trax/layers/reversible.py +++ b/trax/layers/reversible.py @@ -25,451 +25,501 @@ Transformer](https://arxiv.org/abs/2001.04451). """ -from absl import logging import jax + +from absl import logging + from trax import fastmath from trax.layers import base from trax.layers import combinators as cb - _split_rngs = cb._split_rngs # pylint: disable=protected-access class ReversibleLayer(base.Layer): - """Reversible Layer.""" - - def reverse(self, output, weights=(), state=(), new_state=(), rng=None): - """Reverse this layer: compute input given output.""" - raise NotImplementedError - - def _pure_forward(self, x, weights, state, rng): - """Call self.forward in a pure way.""" - old_weights, old_state, old_rng = self.weights, self.state, self._rng - self.weights, self.state, self._rng = weights, state, rng - res = self.forward(x) - self.weights, self.state, self._rng = old_weights, old_state, old_rng - return res - - def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(), - rng=None): - """Backward pass: computes the inverse of a layer and propagates gradients. - - While you may choose to only implement reverse, some layers implement this - function directly as computation may be shared between reversing and - computing gradients. - - Args: - output: Output activations; can be a (possibly nested) tuple. - grad: gradient signal (cotangent) computed based on subsequent layers. - The structure and shape must match the output. - weights: layer weights - state: start state - new_state: updated state computed by the forward pass - rng: Single-use random number generator (JAX PRNG key). - - Returns: - A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, - x_grad is the gradient signal for the input, and weights_grad is the - gradient signal for the weights. 
- """ - reconstructed_x = self.reverse(output, weights, state, new_state, rng) - _, vjpfun = fastmath.vjp( - self._pure_forward, reconstructed_x, weights, state, rng) - x_grad, weights_grad, _, _ = vjpfun(grad) - return reconstructed_x, (x_grad, weights_grad) - - @property - def has_backward(self): - return True - - def backward(self, inputs, output, grad, weights, state, new_state, rng): - del inputs - _, inputs_weights_grad = ( - self.reverse_and_grad(output, grad, weights, state, new_state, rng)) - return inputs_weights_grad + """Reversible Layer.""" + + def reverse(self, output, weights=(), state=(), new_state=(), rng=None): + """Reverse this layer: compute input given output.""" + raise NotImplementedError + + def _pure_forward(self, x, weights, state, rng): + """Call self.forward in a pure way.""" + old_weights, old_state, old_rng = self.weights, self.state, self._rng + self.weights, self.state, self._rng = weights, state, rng + res = self.forward(x) + self.weights, self.state, self._rng = old_weights, old_state, old_rng + return res + + def reverse_and_grad( + self, output, grad, weights=(), state=(), new_state=(), rng=None + ): + """Backward pass: computes the inverse of a layer and propagates gradients. + + While you may choose to only implement reverse, some layers implement this + function directly as computation may be shared between reversing and + computing gradients. + + Args: + output: Output activations; can be a (possibly nested) tuple. + grad: gradient signal (cotangent) computed based on subsequent layers. + The structure and shape must match the output. + weights: layer weights + state: start state + new_state: updated state computed by the forward pass + rng: Single-use random number generator (JAX PRNG key). + + Returns: + A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, + x_grad is the gradient signal for the input, and weights_grad is the + gradient signal for the weights. 
+ """ + reconstructed_x = self.reverse(output, weights, state, new_state, rng) + _, vjpfun = fastmath.vjp( + self._pure_forward, reconstructed_x, weights, state, rng + ) + x_grad, weights_grad, _, _ = vjpfun(grad) + return reconstructed_x, (x_grad, weights_grad) + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, grad, weights, state, new_state, rng): + del inputs + _, inputs_weights_grad = self.reverse_and_grad( + output, grad, weights, state, new_state, rng + ) + return inputs_weights_grad class ReversibleConcatenatePair(ReversibleLayer): - """Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on last axis.""" + """Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on last axis.""" - def __init__(self): - super().__init__(n_in=2, n_out=2) + def __init__(self): + super().__init__(n_in=2, n_out=2) - def forward(self, inputs): - x, y = inputs - r = fastmath.numpy.concatenate((x, y), axis=-1) - return r, r + def forward(self, inputs): + x, y = inputs + r = fastmath.numpy.concatenate((x, y), axis=-1) + return r, r - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - pair, _ = outputs - x, y = fastmath.numpy.split(pair, 2, axis=-1) - return x, y + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + pair, _ = outputs + x, y = fastmath.numpy.split(pair, 2, axis=-1) + return x, y class ReversibleSelect(ReversibleLayer): - """Reversible version of the Select combinator.""" - - def __init__(self, indices, n_in=None, name=None): - if n_in is None: - n_in = max(indices) + 1 - if name is None: - name = f'ReversibleSelect{indices}'.replace(' ', '') - super().__init__(n_in=n_in, n_out=len(indices), name=name) - self._indices = indices - - # Calculate reverse indices. - self._reverse_indices = [] - for i in range(n_in): - if i not in indices: - raise ValueError('To be reversible, all inputs to Select must be in ' - 'indices. Did not find %d in indices.' % i) - else: - self._reverse_indices.append(indices.index(i)) - - def forward(self, inputs): - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - selected = tuple(inputs[i] for i in self._indices) - return selected[0] if len(selected) == 1 else selected - - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - if not isinstance(outputs, (tuple, list)): - outputs = (outputs,) - selected = tuple(outputs[i] for i in self._reverse_indices) - return selected[0] if len(selected) == 1 else selected + """Reversible version of the Select combinator.""" + + def __init__(self, indices, n_in=None, name=None): + if n_in is None: + n_in = max(indices) + 1 + if name is None: + name = f"ReversibleSelect{indices}".replace(" ", "") + super().__init__(n_in=n_in, n_out=len(indices), name=name) + self._indices = indices + + # Calculate reverse indices. + self._reverse_indices = [] + for i in range(n_in): + if i not in indices: + raise ValueError( + "To be reversible, all inputs to Select must be in " + "indices. Did not find %d in indices." 
% i + ) + else: + self._reverse_indices.append(indices.index(i)) + + def forward(self, inputs): + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + selected = tuple(inputs[i] for i in self._indices) + return selected[0] if len(selected) == 1 else selected + + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + if not isinstance(outputs, (tuple, list)): + outputs = (outputs,) + selected = tuple(outputs[i] for i in self._reverse_indices) + return selected[0] if len(selected) == 1 else selected def ReversibleSwap(): # pylint: disable=invalid-name - return ReversibleSelect([1, 0], name='ReversibleSwap') + return ReversibleSelect([1, 0], name="ReversibleSwap") class ReversibleReshape(ReversibleLayer): - """Reversible reshaping layer.""" - - def __init__(self, shape1, shape2, n_in=1): - self._shape1 = list(shape1) - self._shape2 = list(shape2) - name = 'ReversibleReshape_%s_%s' % (str(shape1), str(shape2)) - super().__init__(n_in=n_in, n_out=n_in, name=name) - - def forward(self, inputs): - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - res = [] - for x in inputs: - new_shape = self._shape2 + list(x.shape)[len(self._shape1):] - res.append(fastmath.numpy.reshape(x, new_shape)) - if len(res) == 1: - return res[0] - return tuple(res) - - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, new_state, rng, weights - if not isinstance(outputs, (tuple, list)): - outputs = (outputs,) - res = [] - for x in outputs: - new_shape = self._shape1 + list(x.shape)[len(self._shape2):] - res.append(fastmath.numpy.reshape(x, new_shape)) - if len(res) == 1: - return res[0] - return tuple(res) + """Reversible reshaping layer.""" + + def __init__(self, shape1, shape2, n_in=1): + self._shape1 = list(shape1) + self._shape2 = list(shape2) + name = "ReversibleReshape_%s_%s" % (str(shape1), str(shape2)) + super().__init__(n_in=n_in, n_out=n_in, name=name) + + def forward(self, inputs): + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + res = [] + for x in inputs: + new_shape = self._shape2 + list(x.shape)[len(self._shape1) :] + res.append(fastmath.numpy.reshape(x, new_shape)) + if len(res) == 1: + return res[0] + return tuple(res) + + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + if not isinstance(outputs, (tuple, list)): + outputs = (outputs,) + res = [] + for x in outputs: + new_shape = self._shape1 + list(x.shape)[len(self._shape2) :] + res.append(fastmath.numpy.reshape(x, new_shape)) + if len(res) == 1: + return res[0] + return tuple(res) class ReversiblePrintShape(ReversibleLayer): - """Reversible PrintShape for debugging reversible serial layers.""" + """Reversible PrintShape for debugging reversible serial layers.""" - def __init__(self, n_in=1, msg=''): - super().__init__(n_in=n_in, n_out=n_in) - self._msg = msg + def __init__(self, n_in=1, msg=""): + super().__init__(n_in=n_in, n_out=n_in) + self._msg = msg - def forward(self, xs): - shapes_and_dtypes = ', '.join([str(x.shape) + f'[{x.dtype}]' for x in xs]) - info = f'PrintShape: {self._msg}: [{shapes_and_dtypes}]' - print(info) - logging.info(info) - return xs + def forward(self, xs): + shapes_and_dtypes = ", ".join([str(x.shape) + f"[{x.dtype}]" for x in xs]) + info = f"PrintShape: {self._msg}: [{shapes_and_dtypes}]" + print(info) + logging.info(info) + return xs - def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): - del state, 
new_state, rng, weights - return outputs + def reverse(self, outputs, weights=(), state=(), new_state=(), rng=None): + del state, new_state, rng, weights + return outputs class ReversibleSerial(ReversibleLayer, cb.Serial): - """A reversible version of tl.Serial (requires reversible sub-layers).""" - - def __init__(self, *layers): - super().__init__(*layers) - # def __init__(self, *layers): # pylint: disable=super-init-not-called - # cb.Serial.__init__(self, layers) - - # Note that sublayers has already been flattened to remove nested lists. - for i, layer in enumerate(self.sublayers): - if not isinstance(layer, ReversibleLayer): - raise ValueError( - 'Sub-layer {} of ReversibleSerial is not reversible: {}'.format( - i, layer)) - - def reverse(self, output, weights=(), state=(), new_state=(), rng=None): - rngs = (None,) * self._n_layers - if rng is not None: - rngs = fastmath.random.split(rng, self._n_layers) - - stack = output - for layer, p, s, ns, rng in reversed(list(zip( - self.sublayers, weights, state, new_state, rngs))): - layer_val = cb.inputs_from_stack(stack, layer.n_out) - layer_val = layer.reverse(layer_val, p, s, ns, rng=rng) - stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) - - return stack - - def reverse_and_grad(self, output, grad, weights=(), state=(), new_state=(), - rng=None): - rngs = (None,) * self._n_layers - if rng is not None: - rngs = fastmath.random.split(rng, self._n_layers) - - stack = output - stack_grad = grad - weights_grad = [] - for layer, p, s, ns, rng in reversed(list(zip( - self.sublayers, weights, state, new_state, rngs))): - layer_val = cb.inputs_from_stack(stack, layer.n_out) - layer_ct = cb.inputs_from_stack(stack_grad, layer.n_out) - layer_val, layer_ct = layer.reverse_and_grad( - layer_val, layer_ct, p, s, ns, rng=rng) - layer_ct, p_ct = layer_ct - weights_grad.insert(0, p_ct) - stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) - stack_grad = cb.outputs_onto_stack(layer_ct, stack_grad, layer.n_out) - - return stack, (stack_grad, tuple(weights_grad)) + """A reversible version of tl.Serial (requires reversible sub-layers).""" + + def __init__(self, *layers): + super().__init__(*layers) + # def __init__(self, *layers): # pylint: disable=super-init-not-called + # cb.Serial.__init__(self, layers) + + # Note that sublayers has already been flattened to remove nested lists. 
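# A quick round-trip sketch, assuming the trax API shown in this diff, of the
# contract all the reversible layers above satisfy: reverse() applied to a
# layer's outputs reconstructs its inputs exactly, so activations need not be
# stored for the backward pass.
import numpy as np
from trax import shapes
from trax.layers import reversible

x = np.ones((2, 8), dtype=np.float32)
y = np.zeros((2, 8), dtype=np.float32)
layer = reversible.ReversibleConcatenatePair()
layer.init(shapes.signature((x, y)))
outputs = layer((x, y))            # ([x, y], [x, y]), concatenated on last axis
rx, ry = layer.reverse(outputs)    # exact reconstruction of (x, y)
np.testing.assert_array_equal(rx, x)
np.testing.assert_array_equal(ry, y)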
+ for i, layer in enumerate(self.sublayers): + if not isinstance(layer, ReversibleLayer): + raise ValueError( + "Sub-layer {} of ReversibleSerial is not reversible: {}".format( + i, layer + ) + ) + + def reverse(self, output, weights=(), state=(), new_state=(), rng=None): + rngs = (None,) * self._n_layers + if rng is not None: + rngs = fastmath.random.split(rng, self._n_layers) + + stack = output + for layer, p, s, ns, rng in reversed( + list(zip(self.sublayers, weights, state, new_state, rngs)) + ): + layer_val = cb.inputs_from_stack(stack, layer.n_out) + layer_val = layer.reverse(layer_val, p, s, ns, rng=rng) + stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) + + return stack + + def reverse_and_grad( + self, output, grad, weights=(), state=(), new_state=(), rng=None + ): + rngs = (None,) * self._n_layers + if rng is not None: + rngs = fastmath.random.split(rng, self._n_layers) + + stack = output + stack_grad = grad + weights_grad = [] + for layer, p, s, ns, rng in reversed( + list(zip(self.sublayers, weights, state, new_state, rngs)) + ): + layer_val = cb.inputs_from_stack(stack, layer.n_out) + layer_ct = cb.inputs_from_stack(stack_grad, layer.n_out) + layer_val, layer_ct = layer.reverse_and_grad( + layer_val, layer_ct, p, s, ns, rng=rng + ) + layer_ct, p_ct = layer_ct + weights_grad.insert(0, p_ct) + stack = cb.outputs_onto_stack(layer_val, stack, layer.n_out) + stack_grad = cb.outputs_onto_stack(layer_ct, stack_grad, layer.n_out) + + return stack, (stack_grad, tuple(weights_grad)) class ReversibleHalfResidual(ReversibleLayer): - """Half of a RevNet-style residual that optionally performs attention. - - When attention_layer is None, this layer has the signature :: - - [accumulator, *context] -> [accumulator + f(context), *context] - - The attention_layer must be an instance of EfficientAttentionBase or one of - its subclasses (see efficient_attention.py), or None. - - Attention is special-cased for the following two reasons: - - - LSH attention needs to save bucket assignments from the forward pass to the - backward pass, for training stability. This requires special-casing it. - - We can call attention_layer.forward_and_or_backward to compute its output - (needed for inverting a reversible residual layer) while simultaneously - performing the backward pass. Sharing computation between these two - operations improves training speed. 
- """ - - def __init__(self, *residual_layers, attention_layer=None, name=None): - super().__init__(name=name) - - self._compute_residual = cb.Serial(*residual_layers) - self._attention_layer = attention_layer - - if self._attention_layer is None: - self._sublayers = (self._compute_residual,) - else: - if hasattr(attention_layer, 'forward_and_or_backward'): - self._forward_and_or_backward = attention_layer.forward_and_or_backward - else: - self._forward_and_or_backward = _forward_and_or_backward( - attention_layer) - self._sublayers = (self._compute_residual, self._attention_layer) - - running_max = 0 - running_total = 0 - for layer in self._sublayers: - running_total += layer.n_in - running_max = max(running_max, running_total) - running_total -= layer.n_out - self._n_in = self._n_out = running_max + 1 - - def forward(self, xs): - rngs = _split_rngs(self.rng, len(self.sublayers)) - accumulator, *context = xs - stack = context = tuple(context) - new_state = [] - for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs): - inputs = cb.inputs_from_stack(stack, layer.n_in) - if base.N_WEIGHTS_SHARDS > 1: - # With sharded weights, make sure we don't keep them concatenated - # in memory on each device by using remat. - outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng) - else: - outputs, s = layer.pure_fn(inputs, w, s, rng) - stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) - new_state.append(s) - residual = stack[0] if isinstance(stack, (tuple, list)) else stack - - output = accumulator + residual - stack = (output,) + context - self.state = tuple(new_state) - return stack - - def reverse(self, output, weights=(), state=(), new_state=(), rng=None): - raise NotImplementedError('Only reverse_and_grad is actually used.') - - def reverse_and_grad(self, output, ct, weights=(), state=(), new_state=(), - rng=None): - rngs = _split_rngs(rng, len(self.sublayers)) - - accumulator_output, *context = output - context = tuple(context) - accumulator_output_ct, *context_ct = ct - context_ct = tuple(context_ct) - - # Forward pass through self._compute_residual. Outputs that will not receive - # a gradient signal from subsequent layers are moved to aux. - def call_compute_residual(x, weights): - state_to_pass = state[0] # old_state - - # _replace_second_time is currently used exclusively in _RememberInReverse - # layer to combat numerical instability in Terraformer when quantizing - # the mask in SparseFF. - def _replace_second_time(stt, nstt): - if (isinstance(stt, tuple) and len(stt) == 2 and - isinstance(stt[1], dict) and 'running_second_time' in stt[1]): - return (nstt[0], {'running_second_time_yes': ()}) - elif isinstance(stt, (tuple, list)): - assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt) - return type(stt)([ - _replace_second_time(s, ns) for s, ns in zip(stt, nstt)]) + """Half of a RevNet-style residual that optionally performs attention. + + When attention_layer is None, this layer has the signature :: + + [accumulator, *context] -> [accumulator + f(context), *context] + + The attention_layer must be an instance of EfficientAttentionBase or one of + its subclasses (see efficient_attention.py), or None. + + Attention is special-cased for the following two reasons: + + - LSH attention needs to save bucket assignments from the forward pass to the + backward pass, for training stability. This requires special-casing it. 
+ - We can call attention_layer.forward_and_or_backward to compute its output + (needed for inverting a reversible residual layer) while simultaneously + performing the backward pass. Sharing computation between these two + operations improves training speed. + """ + + def __init__(self, *residual_layers, attention_layer=None, name=None): + super().__init__(name=name) + + self._compute_residual = cb.Serial(*residual_layers) + self._attention_layer = attention_layer + + if self._attention_layer is None: + self._sublayers = (self._compute_residual,) + else: + if hasattr(attention_layer, "forward_and_or_backward"): + self._forward_and_or_backward = attention_layer.forward_and_or_backward + else: + self._forward_and_or_backward = _forward_and_or_backward( + attention_layer + ) + self._sublayers = (self._compute_residual, self._attention_layer) + + running_max = 0 + running_total = 0 + for layer in self._sublayers: + running_total += layer.n_in + running_max = max(running_max, running_total) + running_total -= layer.n_out + self._n_in = self._n_out = running_max + 1 + + def forward(self, xs): + rngs = _split_rngs(self.rng, len(self.sublayers)) + accumulator, *context = xs + stack = context = tuple(context) + new_state = [] + for layer, w, s, rng in zip(self.sublayers, self.weights, self.state, rngs): + inputs = cb.inputs_from_stack(stack, layer.n_in) + if base.N_WEIGHTS_SHARDS > 1: + # With sharded weights, make sure we don't keep them concatenated + # in memory on each device by using remat. + outputs, s = jax.remat(layer.pure_fn)(inputs, w, s, rng) + else: + outputs, s = layer.pure_fn(inputs, w, s, rng) + stack = cb.outputs_onto_stack(outputs, stack, layer.n_in) + new_state.append(s) + residual = stack[0] if isinstance(stack, (tuple, list)) else stack + + output = accumulator + residual + stack = (output,) + context + self.state = tuple(new_state) + return stack + + def reverse(self, output, weights=(), state=(), new_state=(), rng=None): + raise NotImplementedError("Only reverse_and_grad is actually used.") + + def reverse_and_grad( + self, output, ct, weights=(), state=(), new_state=(), rng=None + ): + rngs = _split_rngs(rng, len(self.sublayers)) + + accumulator_output, *context = output + context = tuple(context) + accumulator_output_ct, *context_ct = ct + context_ct = tuple(context_ct) + + # Forward pass through self._compute_residual. Outputs that will not receive + # a gradient signal from subsequent layers are moved to aux. + def call_compute_residual(x, weights): + state_to_pass = state[0] # old_state + + # _replace_second_time is currently used exclusively in _RememberInReverse + # layer to combat numerical instability in Terraformer when quantizing + # the mask in SparseFF. 
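# The arithmetic at the heart of ReversibleHalfResidual, reduced to a NumPy
# sketch with stand-in names: the forward pass adds f(context) to the
# accumulator, and the exact inverse recomputes f(context) and subtracts it,
# which is why reverse_and_grad in this class never needs stored activations.
import numpy as np

def f(context):                           # stand-in for the residual sublayers
    return np.tanh(context)

def half_residual_forward(acc, context):  # [acc, ctx] -> [acc + f(ctx), ctx]
    return acc + f(context), context

def half_residual_reverse(out, context):  # invert by subtracting f(ctx) again
    return out - f(context), context

acc = np.ones((2, 4))
ctx = np.full((2, 4), 0.5)
out, _ = half_residual_forward(acc, ctx)
rec, _ = half_residual_reverse(out, ctx)
assert np.allclose(rec, acc)              # input reconstructed exactly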
+ def _replace_second_time(stt, nstt): + if ( + isinstance(stt, tuple) + and len(stt) == 2 + and isinstance(stt[1], dict) + and "running_second_time" in stt[1] + ): + return (nstt[0], {"running_second_time_yes": ()}) + elif isinstance(stt, (tuple, list)): + assert isinstance(nstt, (tuple, list)) and len(nstt) == len(stt) + return type(stt)( + [_replace_second_time(s, ns) for s, ns in zip(stt, nstt)] + ) + else: + return stt + + state_to_pass = _replace_second_time(state_to_pass, new_state[0]) + res, _ = self._compute_residual.pure_fn( + x, weights=weights, state=state_to_pass, rng=rngs[0] + ) + if not isinstance(res, (tuple, list)): + return res, None + else: + n_differentiable = 1 + if self._attention_layer is not None: + n_differentiable = min(len(res), self._attention_layer.n_in) + return res[:n_differentiable], res[n_differentiable:] + + stack = context + inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) + outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp( + call_compute_residual, inputs, weights[0], has_aux=True + ) + if outputs_aux is not None: + n_differentiable_outputs = len(outputs) + outputs = outputs + outputs_aux + stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) + + stack_ct = accumulator_output_ct + if self._attention_layer is None: + residual = stack[0] if isinstance(stack, (tuple, list)) else stack + else: + inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) + ( + residual, + _, + attn_inputs_ct, + attn_weights_ct, + ) = self._forward_and_or_backward( + inputs, + weights[1], + new_state[1], + rngs[1], + output_grad=accumulator_output_ct, + compute_output=True, + update_state=False, + ) + stack_ct = cb.outputs_onto_stack( + attn_inputs_ct, stack_ct, self._attention_layer.n_out + ) + + compute_residual_ct = cb.inputs_from_stack( + stack_ct, self._compute_residual.n_out + ) + if outputs_aux is not None: + if not isinstance(compute_residual_ct, (tuple, list)): + compute_residual_ct = (compute_residual_ct,) + compute_residual_ct = compute_residual_ct[:n_differentiable_outputs] + assert len(compute_residual_ct) == n_differentiable_outputs + ( + compute_residual_inputs_ct, + compute_residual_weights_ct, + ) = compute_residual_vjpfun(compute_residual_ct) + stack_ct = cb.outputs_onto_stack( + compute_residual_inputs_ct, stack_ct, self._compute_residual.n_out + ) + if not isinstance(stack_ct, (tuple, list)): + stack_ct = (stack_ct,) + + def _add(x, y): + # `None` is for TFNP backend, which uses `None` as the gradient of + # int/bool instead of an array of dtype `float0`. 
+ if x is None or x.dtype == jax.float0: + return y + if y is None or y.dtype == jax.float0: + return x + return x + y + + stack_ct = ( + (accumulator_output_ct,) + + fastmath.nested_map_multiarg(_add, context_ct[: len(stack_ct)], stack_ct) + + context_ct[len(stack_ct) :] + ) + + reconstructed_x = accumulator_output - residual + stack = (reconstructed_x,) + context + if self._attention_layer is None: + weights_ct = (compute_residual_weights_ct,) else: - return stt - - state_to_pass = _replace_second_time(state_to_pass, new_state[0]) - res, _ = self._compute_residual.pure_fn( - x, weights=weights, state=state_to_pass, rng=rngs[0]) - if not isinstance(res, (tuple, list)): - return res, None - else: - n_differentiable = 1 - if self._attention_layer is not None: - n_differentiable = min(len(res), self._attention_layer.n_in) - return res[:n_differentiable], res[n_differentiable:] - - stack = context - inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) - outputs, compute_residual_vjpfun, outputs_aux = fastmath.vjp( - call_compute_residual, inputs, weights[0], has_aux=True) - if outputs_aux is not None: - n_differentiable_outputs = len(outputs) - outputs = outputs + outputs_aux - stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) - - stack_ct = accumulator_output_ct - if self._attention_layer is None: - residual = stack[0] if isinstance(stack, (tuple, list)) else stack - else: - inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) - (residual, _, attn_inputs_ct, attn_weights_ct - ) = self._forward_and_or_backward( - inputs, weights[1], new_state[1], rngs[1], - output_grad=accumulator_output_ct, - compute_output=True, update_state=False) - stack_ct = cb.outputs_onto_stack( - attn_inputs_ct, stack_ct, self._attention_layer.n_out) - - compute_residual_ct = cb.inputs_from_stack( - stack_ct, self._compute_residual.n_out) - if outputs_aux is not None: - if not isinstance(compute_residual_ct, (tuple, list)): - compute_residual_ct = (compute_residual_ct,) - compute_residual_ct = compute_residual_ct[:n_differentiable_outputs] - assert len(compute_residual_ct) == n_differentiable_outputs - (compute_residual_inputs_ct, compute_residual_weights_ct - ) = compute_residual_vjpfun(compute_residual_ct) - stack_ct = cb.outputs_onto_stack( - compute_residual_inputs_ct, stack_ct, self._compute_residual.n_out) - if not isinstance(stack_ct, (tuple, list)): - stack_ct = (stack_ct,) - def _add(x, y): - # `None` is for TFNP backend, which uses `None` as the gradient of - # int/bool instead of an array of dtype `float0`. 
- if x is None or x.dtype == jax.float0: - return y - if y is None or y.dtype == jax.float0: - return x - return x + y - stack_ct = (accumulator_output_ct,) + fastmath.nested_map_multiarg( - _add, context_ct[:len(stack_ct)], stack_ct) + context_ct[len(stack_ct):] - - reconstructed_x = accumulator_output - residual - stack = (reconstructed_x,) + context - if self._attention_layer is None: - weights_ct = (compute_residual_weights_ct,) - else: - weights_ct = (compute_residual_weights_ct, attn_weights_ct) - return stack, (stack_ct, weights_ct) - - # pylint: disable=protected-access - def init_weights_and_state(self, input_signature): - stack = input_signature[1:] - if len(stack) == 1: - stack = stack[0] - - inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) - weights, state = self._compute_residual.init(inputs) - outputs, _ = self._compute_residual._forward_abstract(inputs) - stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) - - if self._attention_layer is None: - self.state = (state,) - self.weights = (weights,) - else: - inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) - attn_weights, attn_state = self._attention_layer.init(inputs) - self.state = (state, attn_state) - self.weights = (weights, attn_weights) - # pylint: enable=protected-access + weights_ct = (compute_residual_weights_ct, attn_weights_ct) + return stack, (stack_ct, weights_ct) + + # pylint: disable=protected-access + def init_weights_and_state(self, input_signature): + stack = input_signature[1:] + if len(stack) == 1: + stack = stack[0] + + inputs = cb.inputs_from_stack(stack, self._compute_residual.n_in) + weights, state = self._compute_residual.init(inputs) + outputs, _ = self._compute_residual._forward_abstract(inputs) + stack = cb.outputs_onto_stack(outputs, stack, self._compute_residual.n_in) + + if self._attention_layer is None: + self.state = (state,) + self.weights = (weights,) + else: + inputs = cb.inputs_from_stack(stack, self._attention_layer.n_in) + attn_weights, attn_state = self._attention_layer.init(inputs) + self.state = (state, attn_state) + self.weights = (weights, attn_weights) + + # pylint: enable=protected-access def _forward_and_or_backward(layer): - """Create forward_and_or_backward for layers that don't define it.""" - - def forward_and_or_backward(inputs, weights, state, rng, output_grad=None, - compute_output=True, update_state=True): - """Performs batched forward and/or backward passes. - - Args: - inputs: inputs to the attention layer - weights: weights for the attention layer - state: state of the attention layer - rng: PRNG key for the layer (shared across all examples and heads) - output_grad: gradient of the loss wrt the output of the layer, or None. - This function performs the backward pass iff `output_grad` is not - None. - compute_output: bool: whether to return the output of the forward pass - (for example, a pure backwards pass does not need to return the - output). - update_state: bool: whether to return an updated layer state. - - Returns: - A tuple (output, new_state, inputs_grad, weights_grad). - - output is not None iff compute_output is True - - new_state is not None iff update_state is True - - inputs_grad & weights_grad are not None iff output_grad is not None - """ - # Calculate the vector-Jacobian product of the layer pure_fn. 
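# A minimal JAX sketch of the pattern forward_and_or_backward relies on: a
# single jax.vjp call returns both the forward output and a function mapping
# an output cotangent to input/weight cotangents, so the forward computation
# is shared between the two passes. layer_fn here is a hypothetical stand-in
# for a layer's pure_fn.
import jax
import jax.numpy as jnp

def layer_fn(x, w):
    return jnp.tanh(x @ w)

x = jnp.ones((2, 4))
w = 0.1 * jnp.ones((4, 3))
output, vjp_fn = jax.vjp(layer_fn, x, w)        # forward pass
x_grad, w_grad = vjp_fn(jnp.ones_like(output))  # backward pass, no second forward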
- output, vjp_fn, new_state = fastmath.vjp( - layer.pure_fn, inputs, weights, state, rng, has_aux=True) - output = output if compute_output else None - new_state = new_state if update_state else None - - # The vjp function returns gradients with respect to inputs and weights. - if output_grad is not None: - grads_inputs, grads_weights, _, _ = vjp_fn(output_grad) - else: - grads_inputs, grads_weights = None, None - - return (output, new_state, grads_inputs, grads_weights) - return forward_and_or_backward + """Create forward_and_or_backward for layers that don't define it.""" + + def forward_and_or_backward( + inputs, + weights, + state, + rng, + output_grad=None, + compute_output=True, + update_state=True, + ): + """Performs batched forward and/or backward passes. + + Args: + inputs: inputs to the attention layer + weights: weights for the attention layer + state: state of the attention layer + rng: PRNG key for the layer (shared across all examples and heads) + output_grad: gradient of the loss wrt the output of the layer, or None. + This function performs the backward pass iff `output_grad` is not + None. + compute_output: bool: whether to return the output of the forward pass + (for example, a pure backwards pass does not need to return the + output). + update_state: bool: whether to return an updated layer state. + + Returns: + A tuple (output, new_state, inputs_grad, weights_grad). + - output is not None iff compute_output is True + - new_state is not None iff update_state is True + - inputs_grad & weights_grad are not None iff output_grad is not None + """ + # Calculate the vector-Jacobian product of the layer pure_fn. + output, vjp_fn, new_state = fastmath.vjp( + layer.pure_fn, inputs, weights, state, rng, has_aux=True + ) + output = output if compute_output else None + new_state = new_state if update_state else None + + # The vjp function returns gradients with respect to inputs and weights. + if output_grad is not None: + grads_inputs, grads_weights, _, _ = vjp_fn(output_grad) + else: + grads_inputs, grads_weights = None, None + + return (output, new_state, grads_inputs, grads_weights) + + return forward_and_or_backward diff --git a/trax/layers/rnn.py b/trax/layers/rnn.py index 3d80cfdda..bce4cd866 100644 --- a/trax/layers/rnn.py +++ b/trax/layers/rnn.py @@ -17,212 +17,235 @@ from trax import fastmath from trax.fastmath import numpy as jnp -from trax.layers import activation_fns -from trax.layers import base +from trax.layers import activation_fns, base, convolution, core, initializers from trax.layers import combinators as cb -from trax.layers import convolution -from trax.layers import core -from trax.layers import initializers class LSTMCell(base.Layer): - """LSTM Cell. - - For a nice overview of the motivation and (i, o, f) gates, see this tutorial: - https://colah.github.io/posts/2015-08-Understanding-LSTMs/ - - See this paper for a description and detailed study of all gate types: - https://arxiv.org/pdf/1503.04069.pdf - """ - - def __init__(self, - n_units, - forget_bias=1.0, - kernel_initializer=initializers.GlorotUniformInitializer(), - bias_initializer=initializers.RandomNormalInitializer(1e-6)): - super().__init__(n_in=2, n_out=2) - self._n_units = n_units - self._forget_bias = forget_bias - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - - def forward(self, inputs): - x, lstm_state = inputs - - # LSTM state consists of c and h. - c, h = jnp.split(lstm_state, 2, axis=-1) - - # Dense layer on the concatenation of x and h. 
- w, b = self.weights - y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = jnp.split(y, 4, axis=-1) - - new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j) - new_h = jnp.tanh(new_c) * fastmath.sigmoid(o) - return new_h, jnp.concatenate([new_c, new_h], axis=-1) - - def init_weights_and_state(self, input_signature): - # LSTM state last dimension must be twice n_units. - if input_signature[1].shape[-1] != 2 * self._n_units: - raise ValueError( - f'Last dimension of state (shape: {str(input_signature[1].shape)}) ' - f'must be equal to 2*n_units ({2 * self._n_units})') - # The dense layer input is the input and half of the lstm state. - input_shape = input_signature[0].shape[-1] + self._n_units - rng1, rng2 = fastmath.random.split(self.rng, 2) - w = self._kernel_initializer((input_shape, 4 * self._n_units), rng1) - b = self._bias_initializer((4 * self._n_units,), rng2) + self._forget_bias - self.weights = (w, b) + """LSTM Cell. + + For a nice overview of the motivation and (i, o, f) gates, see this tutorial: + https://colah.github.io/posts/2015-08-Understanding-LSTMs/ + + See this paper for a description and detailed study of all gate types: + https://arxiv.org/pdf/1503.04069.pdf + """ + + def __init__( + self, + n_units, + forget_bias=1.0, + kernel_initializer=initializers.GlorotUniformInitializer(), + bias_initializer=initializers.RandomNormalInitializer(1e-6), + ): + super().__init__(n_in=2, n_out=2) + self._n_units = n_units + self._forget_bias = forget_bias + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + def forward(self, inputs): + x, lstm_state = inputs + + # LSTM state consists of c and h. + c, h = jnp.split(lstm_state, 2, axis=-1) + + # Dense layer on the concatenation of x and h. + w, b = self.weights + y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = jnp.split(y, 4, axis=-1) + + new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j) + new_h = jnp.tanh(new_c) * fastmath.sigmoid(o) + return new_h, jnp.concatenate([new_c, new_h], axis=-1) + + def init_weights_and_state(self, input_signature): + # LSTM state last dimension must be twice n_units. + if input_signature[1].shape[-1] != 2 * self._n_units: + raise ValueError( + f"Last dimension of state (shape: {str(input_signature[1].shape)}) " + f"must be equal to 2*n_units ({2 * self._n_units})" + ) + # The dense layer input is the input and half of the lstm state. + input_shape = input_signature[0].shape[-1] + self._n_units + rng1, rng2 = fastmath.random.split(self.rng, 2) + w = self._kernel_initializer((input_shape, 4 * self._n_units), rng1) + b = self._bias_initializer((4 * self._n_units,), rng2) + self._forget_bias + self.weights = (w, b) def MakeZeroState(depth_multiplier=1): - """Makes zeros of shape like x but removing the length (axis 1).""" - def f(x): # pylint: disable=invalid-name - if len(x.shape) != 3: - raise ValueError(f'Layer input should be a rank 3 tensor representing' - f' (batch_size, sequence_length, feature_depth); ' - f'instead got shape {x.shape}.') - return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]), - dtype=jnp.float32) - return base.Fn('MakeZeroState', f) - - -def LSTM(n_units, mode='train', return_state=False, initial_state=False): - """LSTM running on axis 1. - - Args: - n_units: `n_units` for the `LSTMCell`. 
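# An illustrative NumPy version of the LSTMCell step defined above (toy
# shapes, zero weights, hypothetical helper names): one dense layer over the
# concatenation [x, h] produces the i/j/f/o gates that update the carry c and
# hidden state h.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c, h, w, b):
    y = np.concatenate([x, h], axis=-1) @ w + b
    i, j, f, o = np.split(y, 4, axis=-1)  # input, new-input, forget, output gates
    new_c = c * sigmoid(f) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_c, new_h

d_in, n_units = 3, 4
x = np.ones((2, d_in))
c = np.zeros((2, n_units))
h = np.zeros((2, n_units))
w = np.zeros((d_in + n_units, 4 * n_units))
b = np.zeros(4 * n_units)
new_c, new_h = lstm_step(x, c, h, w, b)   # both of shape (2, 4)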
-    mode: if 'predict' then we save the previous state for one-by-one inference.
-    return_state: Boolean. Whether to return the latest status in addition to
-      the output. Default: False.
-    initial_state: Boolean. If the state RNN (c, h) is to be obtained from the
-      stack. Default: False.
-
-  Returns:
-    A LSTM layer.
-  """
-
-  if not initial_state:
-    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
-    if return_state:
-      return cb.Serial(
-          cb.Branch([], zero_state),
-          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
-          name=f'LSTM_{n_units}',
-          sublayers_to_print=[])
+    """Makes zeros of shape like x but removing the length (axis 1)."""
+
+    def f(x):  # pylint: disable=invalid-name
+        if len(x.shape) != 3:
+            raise ValueError(
+                f"Layer input should be a rank 3 tensor representing"
+                f" (batch_size, sequence_length, feature_depth); "
+                f"instead got shape {x.shape}."
+            )
+        return jnp.zeros(
+            (x.shape[0], depth_multiplier * x.shape[-1]), dtype=jnp.float32
+        )
+
+    return base.Fn("MakeZeroState", f)
+
+
+def LSTM(n_units, mode="train", return_state=False, initial_state=False):
+    """LSTM running on axis 1.
+
+    Args:
+      n_units: `n_units` for the `LSTMCell`.
+      mode: if 'predict' then we save the previous state for one-by-one inference.
+      return_state: Boolean. Whether to return the latest state in addition to
+        the output. Default: False.
+      initial_state: Boolean. Whether the initial state (c, h) should be taken
+        from the stack. Default: False.
+
+    Returns:
+      A LSTM layer.
+    """
+
+    if not initial_state:
+        zero_state = MakeZeroState(
+            depth_multiplier=2
+        )  # pylint: disable=no-value-for-parameter
+        if return_state:
+            return cb.Serial(
+                cb.Branch([], zero_state),
+                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
+                name=f"LSTM_{n_units}",
+                sublayers_to_print=[],
+            )
+        else:
+            return cb.Serial(
+                cb.Branch([], zero_state),  # fill state RNN with zero.
+                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
+                cb.Select([0], n_in=2),  # Drop RNN state.
+                # Set the name to LSTM and don't print sublayers.
+                name=f"LSTM_{n_units}",
+                sublayers_to_print=[],
+            )
    else:
-      return cb.Serial(
-          cb.Branch([], zero_state),  # fill state RNN with zero.
-          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
-          cb.Select([0], n_in=2),  # Drop RNN state.
-          # Set the name to LSTM and don't print sublayers.
-          name=f'LSTM_{n_units}', sublayers_to_print=[])
-  else:
-    if return_state:
-      return cb.Serial(
-          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
-          name=f'LSTM_{n_units}', sublayers_to_print=[])
-    else:
-      return cb.Serial(
-          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
-          cb.Select([0], n_in=2),  # Drop RNN state.
-          name=f'LSTM_{n_units}', sublayers_to_print=[])
+        if return_state:
+            return cb.Serial(
+                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
+                name=f"LSTM_{n_units}",
+                sublayers_to_print=[],
+            )
+        else:
+            return cb.Serial(
+                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
+                cb.Select([0], n_in=2),  # Drop RNN state.
+                name=f"LSTM_{n_units}",
+                sublayers_to_print=[],
+            )


class GRUCell(base.Layer):
-  """Builds a traditional GRU cell with dense internal transformations.
-
-  Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555
-  """
-
-  def __init__(self,
-               n_units,
-               forget_bias=0.0,
-               kernel_initializer=initializers.RandomUniformInitializer(0.01),
-               bias_initializer=initializers.RandomNormalInitializer(1e-6)):
-    super().__init__(n_in=2, n_out=2)
-    self._n_units = n_units
-    self._forget_bias = forget_bias
-    self._kernel_initializer = kernel_initializer
-    self._bias_initializer = bias_initializer
-
-  def forward(self, inputs):
-    x, gru_state = inputs
-
-    # Dense layer on the concatenation of x and h.
-    w1, b1, w2, b2 = self.weights
-    y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1
-
-    # Update and reset gates.
-    u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1)
-
-    # Candidate.
-    c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2
-
-    new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
-    return new_gru_state, new_gru_state
-
-  def init_weights_and_state(self, input_signature):
-    if input_signature[1].shape[-1] != self._n_units:
-      raise ValueError(
-          f'Second argument in input signature should have a final dimension of'
-          f' {self._n_units}; instead got {input_signature[1].shape[-1]}.')
-
-    # The dense layer input is the input and half of the GRU state.
-    input_shape = input_signature[0].shape[-1] + self._n_units
-    rng1, rng2, rng3, rng4 = fastmath.random.split(self.rng, 4)
-    w1 = self._kernel_initializer((input_shape, 2 * self._n_units), rng1)
-    b1 = self._bias_initializer((2 * self._n_units,), rng2) + self._forget_bias
-    w2 = self._kernel_initializer((input_shape, self._n_units), rng3)
-    b2 = self._bias_initializer((self._n_units,), rng4)
-    self.weights = (w1, b1, w2, b2)
-
-
-def GRU(n_units, mode='train'):
-  """GRU running on axis 1."""
-  zero_state = MakeZeroState(depth_multiplier=1)  # pylint: disable=no-value-for-parameter
-  return cb.Serial(
-      cb.Branch([], zero_state),
-      cb.Scan(GRUCell(n_units=n_units), axis=1, mode=mode),
-      cb.Select([0], n_in=2),  # Drop RNN state.
-      # Set the name to GRU and don't print sublayers.
-      name=f'GRU_{n_units}', sublayers_to_print=[]
-  )
+    """Builds a traditional GRU cell with dense internal transformations.
+
+    Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555
+    """
+
+    def __init__(
+        self,
+        n_units,
+        forget_bias=0.0,
+        kernel_initializer=initializers.RandomUniformInitializer(0.01),
+        bias_initializer=initializers.RandomNormalInitializer(1e-6),
+    ):
+        super().__init__(n_in=2, n_out=2)
+        self._n_units = n_units
+        self._forget_bias = forget_bias
+        self._kernel_initializer = kernel_initializer
+        self._bias_initializer = bias_initializer
+
+    def forward(self, inputs):
+        x, gru_state = inputs
+
+        # Dense layer on the concatenation of x and h.
+        w1, b1, w2, b2 = self.weights
+        y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1
+
+        # Update and reset gates.
+        u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1)
+
+        # Candidate.
+        c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2
+
+        new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
+        return new_gru_state, new_gru_state
+
+    def init_weights_and_state(self, input_signature):
+        if input_signature[1].shape[-1] != self._n_units:
+            raise ValueError(
+                f"Second argument in input signature should have a final dimension of"
+                f" {self._n_units}; instead got {input_signature[1].shape[-1]}."
+            )
+
+        # The dense layer input is the concatenation of the input and the GRU state.
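+        # Editorial note on shapes (not in the original source): with input
+        # depth d and state size n, w1 is (d + n, 2n) and yields the update and
+        # reset gates, while w2 is (d + n, n) and yields the candidate state.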
+ input_shape = input_signature[0].shape[-1] + self._n_units + rng1, rng2, rng3, rng4 = fastmath.random.split(self.rng, 4) + w1 = self._kernel_initializer((input_shape, 2 * self._n_units), rng1) + b1 = self._bias_initializer((2 * self._n_units,), rng2) + self._forget_bias + w2 = self._kernel_initializer((input_shape, self._n_units), rng3) + b2 = self._bias_initializer((self._n_units,), rng4) + self.weights = (w1, b1, w2, b2) + + +def GRU(n_units, mode="train"): + """GRU running on axis 1.""" + zero_state = MakeZeroState( + depth_multiplier=1 + ) # pylint: disable=no-value-for-parameter + return cb.Serial( + cb.Branch([], zero_state), + cb.Scan(GRUCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), # Drop RNN state. + # Set the name to GRU and don't print sublayers. + name=f"GRU_{n_units}", + sublayers_to_print=[], + ) def ConvGRUCell(n_units, kernel_size=(3, 3)): - """Builds a convolutional GRU. + """Builds a convolutional GRU. - Paper: https://arxiv.org/abs/1511.06432. + Paper: https://arxiv.org/abs/1511.06432. - Args: - n_units: Number of hidden units - kernel_size: Kernel size for convolution + Args: + n_units: Number of hidden units + kernel_size: Kernel size for convolution - Returns: - A Stax model representing a GRU cell with convolution transforms. - """ + Returns: + A Stax model representing a GRU cell with convolution transforms. + """ - def BuildConv(): - return convolution.Conv( - filters=n_units, kernel_size=kernel_size, padding='SAME') + def BuildConv(): + return convolution.Conv( + filters=n_units, kernel_size=kernel_size, padding="SAME" + ) - return GeneralGRUCell( - candidate_transform=BuildConv, - memory_transform_fn=None, - gate_nonlinearity=activation_fns.Sigmoid, - candidate_nonlinearity=activation_fns.Tanh) + return GeneralGRUCell( + candidate_transform=BuildConv, + memory_transform_fn=None, + gate_nonlinearity=activation_fns.Sigmoid, + candidate_nonlinearity=activation_fns.Tanh, + ) -def GeneralGRUCell(candidate_transform, - memory_transform_fn=None, - gate_nonlinearity=activation_fns.Sigmoid, - candidate_nonlinearity=activation_fns.Tanh, - dropout_rate_c=0.1, - sigmoid_bias=0.5): - r"""Parametrized Gated Recurrent Unit (GRU) cell construction. +def GeneralGRUCell( + candidate_transform, + memory_transform_fn=None, + gate_nonlinearity=activation_fns.Sigmoid, + candidate_nonlinearity=activation_fns.Tanh, + dropout_rate_c=0.1, + sigmoid_bias=0.5, +): + r"""Parametrized Gated Recurrent Unit (GRU) cell construction. GRU update equations for update gate, reset gate, candidate memory, and new state: @@ -252,75 +275,85 @@ def GeneralGRUCell(candidate_transform, Returns: A model representing a GRU cell with specified transforms. """ - gate_block = [ # u_t - candidate_transform(), - _AddSigmoidBias(sigmoid_bias), - gate_nonlinearity(), - ] - reset_block = [ # r_t - candidate_transform(), - _AddSigmoidBias(sigmoid_bias), # Want bias to start positive. - gate_nonlinearity(), - ] - candidate_block = [ - cb.Dup(), - reset_block, - cb.Multiply(), # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})) - candidate_transform(), # Final projection + tanh to get Ct - candidate_nonlinearity(), # Candidate gate - - # Only apply dropout on the C gate. Paper reports 0.1 as a good default. 
-    core.Dropout(rate=dropout_rate_c)
-  ]
-  memory_transform = memory_transform_fn() if memory_transform_fn else []
-  return cb.Serial(
-      cb.Branch(memory_transform, gate_block, candidate_block),
-      cb.Gate(),
-  )
+    gate_block = [  # u_t
+        candidate_transform(),
+        _AddSigmoidBias(sigmoid_bias),
+        gate_nonlinearity(),
+    ]
+    reset_block = [  # r_t
+        candidate_transform(),
+        _AddSigmoidBias(sigmoid_bias),  # Want bias to start positive.
+        gate_nonlinearity(),
+    ]
+    candidate_block = [
+        cb.Dup(),
+        reset_block,
+        cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
+        candidate_transform(),  # Final projection + tanh to get Ct
+        candidate_nonlinearity(),  # Candidate gate
+        # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
+        core.Dropout(rate=dropout_rate_c),
+    ]
+    memory_transform = memory_transform_fn() if memory_transform_fn else []
+    return cb.Serial(
+        cb.Branch(memory_transform, gate_block, candidate_block),
+        cb.Gate(),
+    )
 
 
 def InnerSRUCell():
-  """The inner (non-parallel) computation of an SRU."""
-  def f(cur_x_times_one_minus_f, cur_f, cur_state):  # pylint: disable=invalid-name
-    res = cur_f * cur_state + cur_x_times_one_minus_f
-    return res, res
-  return base.Fn('InnerSRUCell', f, n_out=2)
-
-
-def ScanSRUCell(mode, monkey_patched_mask=None):
-  """The inner (non-parallel) computation of an SRU."""
-  if monkey_patched_mask is None:
-    return cb.Scan(InnerSRUCell(), axis=1, mode=mode)
-
-  # This is necessary for Terraformer model. See comments there.
-  # The mask will only be used in Terraformer in predict mode.
-  assert mode == 'predict'
-
-  def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
-    initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
-    if initial.shape[1] > 1:
-      updated_mask = fastmath.dynamic_update_slice_in_dim(
-          initial != 0, mask != 0, 1, axis=1)
-    else:
-      updated_mask = initial
-    return updated_mask, x_times_one_minus_f
+    """The inner (non-parallel) computation of an SRU."""
 
-  def masked_inner_sru_cell(cur_mask, cur_x_times_one_minus_f, cur_f,  # pylint: disable=invalid-name
-                            cur_state):
-    res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask
-           + (1 - cur_mask) * cur_state)
-    return res, res
+    def f(cur_x_times_one_minus_f, cur_f, cur_state):  # pylint: disable=invalid-name
+        res = cur_f * cur_state + cur_x_times_one_minus_f
+        return res, res
 
-  return cb.Serial(
-      monkey_patched_mask.get_layer(),
-      base.Fn('update_mask', update_mask, n_out=2),
-      cb.Scan(base.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2),
-              axis=1, mode=mode),
-  )
+    return base.Fn("InnerSRUCell", f, n_out=2)
 
 
-def SRU(n_units, activation=None, mode='train'):
-  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.
+def ScanSRUCell(mode, monkey_patched_mask=None):
+    """Scans the inner SRU cell along the time axis (axis 1)."""
+    if monkey_patched_mask is None:
+        return cb.Scan(InnerSRUCell(), axis=1, mode=mode)
+
+    # This is necessary for the Terraformer model. See comments there.
+    # The mask will only be used in Terraformer in predict mode.
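+    # Editorial note (added; not in the original source): the masked cell below
+    # computes c_t = mask * (f_t * c_{t-1} + x_t * (1 - f_t)) + (1 - mask) * c_{t-1},
+    # so masked-out positions simply carry the previous state forward.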
+    assert mode == "predict"
+
+    def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
+        initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
+        if initial.shape[1] > 1:
+            updated_mask = fastmath.dynamic_update_slice_in_dim(
+                initial != 0, mask != 0, 1, axis=1
+            )
+        else:
+            updated_mask = initial
+        return updated_mask, x_times_one_minus_f
+
+    def masked_inner_sru_cell(
+        cur_mask,
+        cur_x_times_one_minus_f,
+        cur_f,  # pylint: disable=invalid-name
+        cur_state,
+    ):
+        res = (cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask + (
+            1 - cur_mask
+        ) * cur_state
+        return res, res
+
+    return cb.Serial(
+        monkey_patched_mask.get_layer(),
+        base.Fn("update_mask", update_mask, n_out=2),
+        cb.Scan(
+            base.Fn("MaskedInnerSRUCell", masked_inner_sru_cell, n_out=2),
+            axis=1,
+            mode=mode,
+        ),
+    )
+
+
+def SRU(n_units, activation=None, mode="train"):
+    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.
 
   As defined in the paper:
 
@@ -343,24 +376,26 @@ def SRU(n_units, activation=None, mode='train'):
   Returns:
     The SRU layer.
   """
-  sigmoid_activation = activation_fns.Sigmoid()
-  return cb.Serial(                      # x
-      cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
-      cb.Split(n_items=3),  # r, f, y, x
-      cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
-      base.Fn('',
-              lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
-              n_out=3),
-      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
-      ScanSRUCell(mode=mode),
-      cb.Select([0], n_in=2),  # act(c), r, x
-      activation if activation is not None else [],
-      base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
-      # Set the name to SRU and don't print sublayers.
-      name=f'SRU_{n_units}', sublayers_to_print=[]
-  )
+    sigmoid_activation = activation_fns.Sigmoid()
+    return cb.Serial(  # x
+        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
+        cb.Split(n_items=3),  # r, f, y, x
+        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
+        base.Fn(
+            "",
+            lambda r, f, y: (y * (1.0 - f), f, r),
+            n_out=3,  # y * (1 - f), f, r, x
+        ),
+        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
+        ScanSRUCell(mode=mode),
+        cb.Select([0], n_in=2),  # act(c), r, x
+        activation if activation is not None else [],
+        base.Fn("FinalSRUGate", lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
+        # Set the name to SRU and don't print sublayers.
+        name=f"SRU_{n_units}",
+        sublayers_to_print=[],
+    )
 
 
 def _AddSigmoidBias(sigmoid_bias):
-  return base.Fn('AddSigmoidBias({sigmoid_bias})',
-                 lambda x: x + sigmoid_bias)
+    return base.Fn(f"AddSigmoidBias({sigmoid_bias})", lambda x: x + sigmoid_bias)
diff --git a/trax/layers/rnn_test.py b/trax/layers/rnn_test.py
deleted file mode 100644
index 40e128785..000000000
--- a/trax/layers/rnn_test.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Tests for rnn layers.""" - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -import trax.layers as tl - - -BACKENDS = [fastmath.Backend.JAX] - - -@parameterized.named_parameters( - ('_' + b.value, b) for b in BACKENDS) -class RnnTest(parameterized.TestCase): - - def test_conv_gru_cell(self, backend): - with fastmath.use_backend(backend): - layer = tl.ConvGRUCell(9, kernel_size=(3, 3)) - x = np.ones((8, 1, 7, 9)) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_gru_cell(self, backend): - with fastmath.use_backend(backend): - layer = tl.GRUCell(9) - xs = [np.ones((8, 7, 9)), np.ones((8, 7, 9))] - _, _ = layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual([y.shape for y in ys], [(8, 7, 9), (8, 7, 9)]) - - def test_lstm_cell(self, backend): - with fastmath.use_backend(backend): - layer = tl.LSTMCell(9) - xs = [np.ones((8, 9)), np.ones((8, 18))] - _, _ = layer.init(shapes.signature(xs)) - ys = layer(xs) - self.assertEqual([y.shape for y in ys], [(8, 9), (8, 18)]) - - def test_sru(self, backend): - with fastmath.use_backend(backend): - layer = tl.SRU(7) - x = np.ones((8, 9, 7), np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) - - def test_names(self, backend): - with fastmath.use_backend(backend): - layer = tl.LSTM(3) - self.assertEqual('LSTM_3', str(layer)) - layer = tl.GRU(5) - self.assertEqual('GRU_5', str(layer)) - layer = tl.SRU(7) - self.assertEqual('SRU_7', str(layer)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/layers/test_utils.py b/trax/layers/test_utils.py deleted file mode 100644 index 156220314..000000000 --- a/trax/layers/test_utils.py +++ /dev/null @@ -1,283 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for testing.""" - -import copy -import functools - -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes - - -def test_eval_is_deterministic(inp, model_fn, message=''): - """Utility method for testing if eval mode is deterministic. - - Args: - inp: input fed to the model. It can be a tensor, or a tuple of tensors. - model_fn: function creating a model after calling with `mode` argument. - message: Optional message to show when outputs of eval/predict mode don't - match. 
- """ - with fastmath.use_backend(fastmath.Backend.JAX): - model_eval1 = model_fn(mode='eval') - model_eval2 = model_fn(mode='eval') - - input_signature = shapes.signature(inp) - model_eval1.init(input_signature) - model_eval2.init(input_signature) - model_eval1.save_to_file('/tmp/unique_weights') - model_eval2.init_from_file('/tmp/unique_weights', weights_only=True, - input_signature=input_signature) - - rng = fastmath.random.get_prng(0) - output_eval1 = model_eval1(inp, rng=rng) - if not isinstance(output_eval1, (tuple, list)): - # We will automatically check each and every tensor returned. - output_eval1 = [output_eval1] - - output_eval2 = model_eval2(inp, rng=rng) - if not isinstance(output_eval2, (tuple, list)): - # We will automatically check each and every tensor returned. - output_eval2 = [output_eval2] - - np.testing.assert_equal(len(output_eval1), len(output_eval2)) - for out1, out2 in zip(output_eval1, output_eval2): - np.testing.assert_array_almost_equal( - out1, - out2, - decimal=5, - err_msg='Non-deterministic.{}'.format(message)) - - -def test_eval_equals_predict(inp, model_fn, seq_axis=1, seq_tensor=None, - init_tokens=3, message=''): - """Utility method for testing equivalence of predict and eval modes. - - Args: - inp: input fed to the model. It can be a tensor, or a tuple of tensors. - model_fn: function creating a model after calling with `mode` argument. - seq_axis: axis of sequence_length. In predict mode we iterate over this - axis. By default `1`, which is 2nd dimension. - seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor - in this tuple on which we iterate the sequence. - init_tokens: how many tokens should be passed to the first `predict` call. - message: Optional message to show when outputs of eval/predict mode don't - match. - """ - with fastmath.use_backend(fastmath.Backend.JAX): - model_eval = model_fn(mode='eval') - model_predict = model_fn(mode='predict') - - input_signature = shapes.signature(inp) - model_eval.init(input_signature) - model_predict.init(input_signature) - model_eval.save_to_file('/tmp/unique_weights') - model_predict.init_from_file('/tmp/unique_weights', weights_only=True, - input_signature=input_signature) - - rng = fastmath.random.get_prng(0) - output_eval = model_eval(inp, rng=rng) - if not isinstance(output_eval, (tuple, list)): - # We will automatically check each and every tensor returned. - output_eval = [output_eval] - - if seq_tensor is None: - length = inp.shape[seq_axis] - else: - length = inp[seq_tensor].shape[seq_axis] - - assert length >= init_tokens + 2 # Required to properly test predict mode. - indices_list = [(0, init_tokens)] + [(i, i+1) - for i in range(init_tokens, length)] - - for indices in indices_list: - start, end = indices - if seq_tensor is None: - new_inp = inp.take(indices=np.arange(start, end), axis=seq_axis) - else: - new_inp = list(inp) - new_inp[seq_tensor] = new_inp[seq_tensor].take( - indices=np.arange(start, end), axis=seq_axis) - - output_predict = model_predict(new_inp, rng=rng) - if not isinstance(output_predict, (tuple, list)): - # We will automatically check each and every tensor returned. 
- output_predict = [output_predict] - - np.testing.assert_equal(len(output_predict), len(output_eval)) - for outp, oute in zip(output_predict, output_eval): - np.testing.assert_array_almost_equal( - oute.take(indices=np.arange(start, end), axis=seq_axis), - outp.take(indices=np.arange(0, end-start), axis=seq_axis), - decimal=5, - err_msg='Error on element {} out of {}.{}'.format(indices, length, - message)) - - -def test_eval_equals_predict_configs(inp, model_fn, configs, seq_axis=1, - seq_tensor=None, message=''): - """Utility method for testing equivalence of predict and eval modes. - - This function iterates over a list of dictionaries `confis`, and runs the test - on models with each configuration. - - Args: - inp: input fed to the model. It can be a tensor, or a tuple of tensors. - model_fn: function creating a model after calling with `mode` argument. - configs: List of dictionaries, which contain configs to be fed into - `model_fn`. - seq_axis: axis of sequence_length. In predict mode we iterate over this - axis. By default `1`, which is 2nd dimension. - seq_tensor: if `inp` is a tuple, `seq_tensor` is an index of an input tensor - in this tuple on which we iterate the sequence. - message: Optional message to show when outputs of eval/predict mode don't - match. - """ - for config in configs: - model_fn_configured = functools.partial(model_fn, **config) - test_eval_equals_predict(inp, model_fn_configured, seq_axis=seq_axis, - seq_tensor=seq_tensor, - message=' Config: {}.{}'.format(config, message)) - - -def test_eval_equals_predict_discrete( - model_fn, vocab_size=10, length=5, batch_size=3 -): - """Tests the equivalence of eval and predict modes for discrete models.""" - with fastmath.use_backend(fastmath.Backend.JAX): - model_slow = model_fn(mode='eval', vocab_size=vocab_size) - model_fast = model_fn(mode='predict', vocab_size=vocab_size) - rng = fastmath.random.get_prng(0) - input_signature = shapes.ShapeDtype((batch_size, 1), np.int32) - # Given the same rng, both models initialize with the same parameters. - model_slow.init(input_signature, rng) - model_fast.init(input_signature, rng) - - buf = np.zeros((batch_size, length), dtype=np.int32) - next_sym = np.zeros((batch_size, 1), dtype=np.int32) - - for index in range(length): - logits_slow = model_slow(buf, rng=rng) - logits_fast = model_fast(next_sym, rng=rng) - np.testing.assert_array_almost_equal( - logits_slow[:, index, :], logits_fast[:, 0, :], - decimal=5, - ) - next_sym = np.random.randint(vocab_size, size=(batch_size, 1)) - buf[:, index] = next_sym[:, 0] - - -class MockTransformerLM(tl.Layer): - r"""Mock TransformerLM for testing autoregressive sampling routines. - - Mimics the behavior of a perfectly-trained, deterministic TransformerLM. - Allows to specify the \sigma^* -> \sigma function implemented by the model - and to make assertions about the input sequence passed to the model. - - Supports two modes: stateful "predict" for fast inference, and stateless - non-"predict" ("train", "eval" etc). - - Useful for testing any logic that relies on autoregressive sampling, as it - removes the additional layer of complexity related to training a model or - maintaining a pretrained one. Makes the tests run MUCH faster. - - Does not support acceleration. Do not wrap in tl.Accelerate(). 
- """ - - def __init__(self, sequence_fn, mode, vocab_size): - super().__init__() - - self._sequence_fn = sequence_fn - self._mode = mode - self._vocab_size = vocab_size - - self._prediction_buffers = None - - @property - def state(self): - return copy.deepcopy(self._prediction_buffers) - - @state.setter - def state(self, state): - self._prediction_buffers = copy.deepcopy(state) - - def _output_symbol_predict(self, input_symbols, prediction_buffer): - prediction_buffer.extend(input_symbols) - output_symbol = self._sequence_fn(np.array(prediction_buffer)) - return np.array([output_symbol]) - - def _output_symbols_eval(self, input_symbols, prediction_buffer): - del prediction_buffer - - # Add a leading 0 token to imitate ShiftRight. - input_symbols = np.concatenate(([0], input_symbols)) - - # Call sequence_fn repeatedly along the input sequence. - return np.array([ - self._sequence_fn(input_symbols[:end]) - for end in range(1, len(input_symbols)) - ]) - - def _symbols_to_logits(self, symbols): - # Assert that symbols are discrete. - assert np.issubdtype(symbols.dtype, np.integer) - # Assert that 0 <= symbols < vocab_size. - np.testing.assert_array_less(-1, symbols) - np.testing.assert_array_less(symbols, self._vocab_size) - - # Return almost-determinisitc logits: - # e^1000 / (e^1000 + vocab_size) ~= 1 - return tl.one_hot(symbols, n_categories=self._vocab_size) * 1000.0 - - def __call__(self, inputs, rng=None): - del rng - - assert inputs.ndim == 2, ( - 'The input sequences should have exactly two axes.' - ) - - if self._prediction_buffers is None: - # Initialize the buffer. - batch_size = inputs.shape[0] - # [[]] * batch_size would create multiple references to the same - # list, and we want separate lists. - self._prediction_buffers = [[] for _ in range(batch_size)] - - if self._mode == 'predict': - output_fn = self._output_symbol_predict - else: - output_fn = self._output_symbols_eval - - # Calculate the output separately for each sequence in the batch. - output_symbols = np.array([ - output_fn(input_seq, pred_buffer) - for (input_seq, pred_buffer) in zip( - inputs, self._prediction_buffers - ) - ]) - return self._symbols_to_logits(output_symbols) - - def assert_prediction_buffers_equal(self, expected_buffers): - if self._prediction_buffers is None: - batch_size = expected_buffers.shape[0] - actual_buffers = np.empty((batch_size, 0)) - else: - actual_buffers = np.array(self._prediction_buffers) - - np.testing.assert_array_equal(actual_buffers, expected_buffers) diff --git a/trax/layers/test_utils_test.py b/trax/layers/test_utils_test.py deleted file mode 100644 index e21a50dbd..000000000 --- a/trax/layers/test_utils_test.py +++ /dev/null @@ -1,91 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for trax.layers.test_utils.""" - -import functools - -from absl.testing import absltest -import numpy as np - -from trax.layers import test_utils -from trax.supervised import decoding - - -def arithmetic_sequence(input_seq, limit=10): - # Increment the last symbol. Wrap to [0, 10). - return (input_seq[-1] + 1) % limit - - -class TestUtilsTest(absltest.TestCase): - - def test_mock_transformer_lm_eval_equals_predict(self): - model_fn = functools.partial( - test_utils.MockTransformerLM, - sequence_fn=arithmetic_sequence, - vocab_size=10, - ) - test_utils.test_eval_equals_predict_discrete(model_fn, vocab_size=10) - - def test_mock_transformer_lm_decodes_arithmetic_sequence(self): - model = test_utils.MockTransformerLM( - sequence_fn=arithmetic_sequence, - vocab_size=10, - mode='predict', - ) - output = decoding.autoregressive_sample( - model, max_length=5, start_id=0, eos_id=-1, accelerate=False - ) - - # Sequence including the leading 0 and the last predicted symbol. - full_seq = list(range(6)) - # decoding.autoregressive_sample doesn't return the leading 0. - np.testing.assert_array_equal(output, [full_seq[1:]]) - # The prediction buffers don't include the last predicted symbol. - model.assert_prediction_buffers_equal([full_seq[:-1]]) - - def test_mock_transformer_lm_rewinds(self): - model = test_utils.MockTransformerLM( - sequence_fn=arithmetic_sequence, - vocab_size=10, - mode='predict', - ) - sample_3 = functools.partial( - decoding.autoregressive_sample, - max_length=3, - eos_id=-1, - accelerate=False, - ) - - # Generate the 3 initial symbols. - init_output = sample_3(model, start_id=0) - np.testing.assert_array_equal(init_output, [[1, 2, 3]]) - state = model.state - - # Generate the next 3 symbols. - next_output = sample_3(model, start_id=init_output[0, -1]) - np.testing.assert_array_equal(next_output, [[4, 5, 6]]) - - # Rewind and generate the last 3 symbols again. - model.state = state - next_output = sample_3(model, start_id=init_output[0, -1]) - np.testing.assert_array_equal(next_output, [[4, 5, 6]]) - - # Check the buffers. 
-    model.assert_prediction_buffers_equal([[0, 1, 2, 3, 4, 5]])
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/learning/__init__.py b/trax/learning/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/trax/rl/__init__.py b/trax/learning/reinforcement/__init__.py
similarity index 68%
rename from trax/rl/__init__.py
rename to trax/learning/reinforcement/__init__.py
index e26fba012..3ac91831d 100644
--- a/trax/rl/__init__.py
+++ b/trax/learning/reinforcement/__init__.py
@@ -17,17 +17,27 @@
 
 import gin
 
-from trax.rl import actor_critic
-from trax.rl import actor_critic_joint
-from trax.rl import envs
-from trax.rl import serialization_utils
-from trax.rl import training
+from trax.learning.reinforcement import (
+    actor_critic,
+    actor_critic_joint,
+    serialization_utils,
+    training,
+)
 
 
 def configure_rl(*args, **kwargs):
-  kwargs['module'] = 'trax.rl'
-  kwargs['denylist'] = ['task', 'output_dir']
-  return gin.external_configurable(*args, **kwargs)
+    kwargs["module"] = "trax.reinforcement"
+    kwargs["denylist"] = ["task", "output_dir"]
+    return gin.external_configurable(*args, **kwargs)
+
+
+gin.enter_interactive_mode()
+
+
+@gin.configurable(module="trax.reinforcement")
+def every(n_steps):
+    """Returns True every n_steps, for use as *_at functions in various places."""
+    return lambda step: step % n_steps == 0
 
 
 A2C = configure_rl(actor_critic.A2C)
@@ -45,5 +55,5 @@ def configure_rl(*args, **kwargs):
 DQN = configure_rl(training.DQN)
 
 TimeSeriesModel = gin.external_configurable(
-    serialization_utils.TimeSeriesModel, module='trax.rl'
+    serialization_utils.TimeSeriesModel, module="trax.reinforcement"
 )
diff --git a/trax/learning/reinforcement/actor_critic.py b/trax/learning/reinforcement/actor_critic.py
new file mode 100644
index 000000000..349d41f23
--- /dev/null
+++ b/trax/learning/reinforcement/actor_critic.py
@@ -0,0 +1,1286 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classes for RL training in Trax."""
+
+import functools
+import os
+
+import gym
+import numpy as np
+import tensorflow as tf
+
+from trax.learning.reinforcement import advantages as rl_advantages
+from trax.learning.reinforcement import distributions, policy_tasks, value_tasks
+from trax.learning.reinforcement import training as rl_training
+
+from trax import data, fastmath
+from trax import layers as tl
+from trax.fastmath import numpy as jnp
+from trax.learning import supervised
+from trax.learning.supervised import lr_schedules as lr
+from trax.optimizers import adam
+from trax.utils import shapes
+
+
+class ActorCriticAgent(rl_training.PolicyAgent):
+    """Trains policy and value models using actor-critic methods.
+
+    Attributes:
+      on_policy (bool): Whether the algorithm is on-policy. Used in the data
+        generators. Should be set in derived classes.
+    """
+
+    on_policy = None
+
+    def __init__(
+        self,
+        task,
+        value_model=None,
+        value_optimizer=None,
+        value_lr_schedule=lr.multifactor,
+        value_batch_size=64,
+        value_train_steps_per_epoch=500,
+        value_evals_per_epoch=1,
+        value_eval_steps=1,
+        n_shared_layers=0,
+        added_policy_slice_length=0,
+        n_replay_epochs=1,
+        scale_value_targets=False,
+        q_value=False,
+        q_value_aggregate="logsumexp",
+        q_value_temperature=1.0,
+        q_value_n_samples=1,
+        q_value_normalization=False,
+        offline=False,
+        **kwargs,
+    ):  # Arguments of PolicyAgent come here.
+        """Configures the actor-critic trainers.
+
+        Args:
+          task: `RLTask` instance to use.
+          value_model: Model to use for the value function.
+          value_optimizer: Optimizer to train the value model.
+          value_lr_schedule: lr schedule for value model training.
+          value_batch_size: Batch size for value model training.
+          value_train_steps_per_epoch: Number of steps used to train the value
+            model in each epoch.
+          value_evals_per_epoch: Number of value trainer evaluations per RL
+            epoch. At every evaluation, we also synchronize the weights of the
+            target network.
+          value_eval_steps: Number of value trainer steps per evaluation; only
+            affects metric reporting.
+          n_shared_layers: Number of layers to share between value and policy
+            models.
+          added_policy_slice_length: How much longer should slices of
+            trajectories be for policy than for value training; this
+            is useful for TD calculations and only affects the length
+            of elements produced for policy batches; value batches
+            have maximum length set by `max_slice_length` in `**kwargs`.
+          n_replay_epochs: Number of last epochs to take into the replay buffer;
+            only makes sense for off-policy algorithms.
+          scale_value_targets: If `True`, scale value function targets by
+            `1 / (1 - gamma)`.
+          q_value: If `True`, use Q-values as baselines.
+          q_value_aggregate: How to aggregate Q-values. Options: 'mean', 'max',
+            'softmax', 'logsumexp'.
+          q_value_temperature: Temperature parameter for the 'softmax' and
+            'logsumexp' aggregation methods.
+          q_value_n_samples: Number of samples to average over when calculating
+            baselines based on Q-values.
+          q_value_normalization: How to normalize Q-values before aggregation.
+            Allowed values: 'std', 'abs', `None`. If `None`, don't normalize.
+          offline: Whether to train in offline mode. This matters for some
+            algorithms, e.g. QWR.
+          **kwargs: Arguments for `PolicyAgent` superclass.
+        """
+        self._n_shared_layers = n_shared_layers
+        self._value_batch_size = value_batch_size
+        self._value_train_steps_per_epoch = value_train_steps_per_epoch
+        self._value_evals_per_epoch = value_evals_per_epoch
+        self._value_eval_steps = value_eval_steps
+
+        # The two attributes below will be initialized in super().__init__
+        # anyway, but they are needed to construct value batches, which are
+        # needed before PolicyAgent init since policy input creation calls the
+        # value model -- hence this code.
+ self._task = task + self._max_slice_length = kwargs.get("max_slice_length", 1) + self._added_policy_slice_length = added_policy_slice_length + self._n_replay_epochs = n_replay_epochs + task.set_n_replay_epochs(n_replay_epochs) + + if scale_value_targets: + self._value_network_scale = 1 / (1 - self._task.gamma) + else: + self._value_network_scale = 1 + + self._q_value = q_value + self._q_value_aggregate = q_value_aggregate + self._q_value_temperature = q_value_temperature + self._q_value_n_samples = q_value_n_samples + self._q_value_normalization = q_value_normalization + + is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete) + self._is_discrete = is_discrete + self._vocab_size = None + self._sample_all_discrete_actions = False + if q_value and is_discrete: + self._vocab_size = self.task.action_space.n + # TODO(lukaszkaiser): the code below is specific to AWR, move it. + # If n_samples = n_actions, we'll take them all in actor and reweight. + if self._q_value_n_samples == self._vocab_size: + # TODO(lukaszkaiser): set this explicitly once it's in AWR Trainer. + self._sample_all_discrete_actions = True + if offline and is_discrete: + raise NotImplementedError( + "Offline training is only supported for continuous action spaces for " + "now." + ) + self._offline = offline + + if q_value: + value_model = functools.partial( + value_model, + inject_actions=True, + is_discrete=is_discrete, + vocab_size=self._vocab_size, + ) + self._value_eval_model = value_model(mode="eval") + self._value_eval_model.init(self._value_model_signature) + self._value_eval_jit = tl.jit_forward( + self._value_eval_model.pure_fn, fastmath.local_device_count(), do_mean=False + ) + + # Initialize policy training. + super().__init__(task, **kwargs) + + # Initialize training of the value function. + value_output_dir = kwargs.get("output_dir", None) + if value_output_dir is not None: + value_output_dir = os.path.join(value_output_dir, "value") + # If needed, create value_output_dir and missing parent directories. + if not tf.io.gfile.isdir(value_output_dir): + tf.io.gfile.makedirs(value_output_dir) + self._value_inputs = data.inputs.Inputs( + train_stream=lambda _: self.value_batches_stream() + ) + self._value_trainer = supervised.Trainer( + model=value_model, + optimizer=value_optimizer, + lr_schedule=value_lr_schedule(), + loss_fn=tl.L2Loss(), + inputs=self._value_inputs, + output_dir=value_output_dir, + metrics={"value_loss": tl.L2Loss(), "value_mean": self.value_mean}, + ) + + @property + def value_mean(self): + """The mean value of the value function.""" + + # TODO(henrykm): A better solution would take into account the masks + def f(values): + return jnp.mean(values) + + return tl.Fn("ValueMean", f) + + @property + def _value_model_signature(self): + obs_sig = shapes.signature(self._task.observation_space) + target_sig = mask_sig = shapes.ShapeDtype( + shape=(1, 1, 1), + ) + inputs_sig = (obs_sig.replace(shape=(1, 1) + obs_sig.shape),) + if self._q_value: + act_sig = shapes.signature(self._task.action_space) + inputs_sig += (act_sig.replace(shape=(1, 1) + act_sig.shape),) + return (*inputs_sig, target_sig, mask_sig) + + @property + def _replay_epochs(self): + if self.on_policy: + assert self._n_replay_epochs == 1, ( + "Non-unit replay buffer size only makes sense for off-policy " + "algorithms." 
+            )
+        return [-(ep + 1) for ep in range(self._n_replay_epochs)]
+
+    def _run_value_model(self, observations, dist_inputs):
+        if dist_inputs is None:
+            dist_inputs = jnp.zeros(
+                observations.shape[:2] + (self._policy_dist.n_inputs,)
+            )
+
+        actions = None
+        if self._q_value:
+            if self._sample_all_discrete_actions:
+                # Since we want to sample all actions, start by creating their list.
+                act = np.arange(self._vocab_size)
+                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
+                # Add extra dimensions so it's the same dimensionality as dist_inputs.
+                act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
+                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
+            dist_inputs = jnp.broadcast_to(
+                dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape
+            )
+            if self._sample_all_discrete_actions:
+                actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
+                actions = jnp.swapaxes(actions, 0, 1)
+            # Swapping the n_samples and batch_size axes, so the input is split
+            # between accelerators along the batch_size axis.
+            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
+            if not self._sample_all_discrete_actions:
+                actions = self._policy_dist.sample(dist_inputs)
+            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
+            obs = observations
+            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
+            inputs = (obs, actions)
+        else:
+            log_probs = None
+            inputs = (observations,)
+
+        n_devices = fastmath.local_device_count()
+        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
+        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
+        rng = self._value_eval_model.rng
+        values, _ = self._value_eval_jit(inputs, weights, state, rng)
+        values *= self._value_network_scale
+        values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
+        return (values, actions, log_probs)
+
+    def _aggregate_values(self, values, aggregate, act_log_probs):
+        # Normalize the Q-values before aggregation, so it can adapt to the scale
+        # of the returns. This does not affect mean and max aggregation.
+        scale = 1
+        epsilon = 1e-5
+        if self._q_value_normalization == "std":
+            scale = jnp.std(values) + epsilon
+        elif self._q_value_normalization == "abs":
+            scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
+        values /= scale
+
+        temp = self._q_value_temperature
+        if self._q_value:
+            assert values.shape[:2] == (self._value_batch_size, self._q_value_n_samples)
+            if aggregate == "max":
+                # max_a Q(s, a)
+                values = jnp.max(values, axis=1)
+            elif aggregate == "softmax":
+                # sum_a (Q(s, a) * w(s, a))
+                # where w(s, .) = softmax (Q(s, .) / T)
+                weights = tl.Softmax(axis=1)(values / temp)
+                values = jnp.sum(values * weights, axis=1)
+            elif aggregate == "logsumexp":
+                # log(mean_a exp(Q(s, a) / T)) * T
+                n = values.shape[1]
+                values = (fastmath.logsumexp(values / temp, axis=1) - jnp.log(n)) * temp
+            else:
+                assert aggregate == "mean"
+                # mean_a Q(s, a)
+                if self._sample_all_discrete_actions:
+                    values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
+                else:
+                    values = jnp.mean(values, axis=1)
+
+        # Re-scale the Q-values after aggregation.
+        values *= scale
+        return np.array(values)  # Move the values to CPU.
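+
+    # Editorial note (not in the original file): for per-sample Q-values
+    # q_1..q_n and temperature T, the 'logsumexp' branch above computes
+    #   T * (logsumexp(q / T) - log(n)) = T * log(mean_i exp(q_i / T)),
+    # which tends to max_i q_i as T -> 0 and to mean_i q_i as T -> infinity.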
+
+    def _get_dist_inputs(self, trajectory):
+        if not self._offline:
+            return trajectory.dist_inputs
+        else:
+            return trajectory.action
+
+    def value_batches_stream(self):
+        """Use the RLTask self._task to create inputs to the value model."""
+        max_slice_length = self._max_slice_length + self._added_policy_slice_length
+        for np_trajectory in self._task.trajectory_batch_stream(
+            self._value_batch_size,
+            max_slice_length=max_slice_length,
+            min_slice_length=(1 + self._added_policy_slice_length),
+            margin=self._added_policy_slice_length,
+            epochs=self._replay_epochs,
+        ):
+            dist_inputs = self._get_dist_inputs(np_trajectory)
+            (values, _, act_log_probs) = self._run_value_model(
+                np_trajectory.observation, dist_inputs
+            )
+            values = self._aggregate_values(
+                values, self._q_value_aggregate, act_log_probs
+            )
+
+            # TODO(pkozakowski): Add some shape assertions and docs.
+            # Calculate targets based on the advantages over the target network -
+            # this allows TD learning for value networks.
+            advantages = self._advantage_estimator(
+                rewards=np_trajectory.reward,
+                returns=np_trajectory.return_,
+                values=values,
+                dones=np_trajectory.done,
+                discount_mask=np_trajectory.env_info.discount_mask,
+            )
+            length = advantages.shape[1]
+            values = values[:, :length]
+            target_returns = values + advantages
+
+            inputs = (np_trajectory.observation[:, :length],)
+            if self._q_value:
+                inputs += (np_trajectory.action[:, :length],)
+
+            # Insert an extra depth dimension, so the target shape is consistent
+            # with the network output shape.
+            yield (
+                # Inputs: observations and maybe actions.
+                *inputs,
+                # Targets: computed returns.
+                target_returns[:, :, None] / self._value_network_scale,
+                # Mask to zero-out padding.
+                np_trajectory.mask[:, :length, None],
+            )
+
+    def policy_inputs(self, trajectory, values):
+        """Create inputs to policy model from a TimeStepBatch and values.
+
+        Args:
+          trajectory: a TimeStepBatch, the trajectory to create inputs from
+          values: a numpy array: value function computed on trajectory
+
+        Returns:
+          a tuple of numpy arrays of the form (inputs, x1, x2, ...) that will be
+          passed to the policy model; policy model will compute outputs from
+          inputs and (outputs, x1, x2, ...) will be passed to self.policy_loss
+          which should be overridden accordingly.
+        """
+        raise NotImplementedError
+
+    def policy_batches_stream(self):
+        """Use the RLTask self._task to create inputs to the policy model."""
+        # Maximum slice length for policy is max_slice_len + the added policy len.
+        max_slice_length = self._max_slice_length + self._added_policy_slice_length
+        for np_trajectory in self._task.trajectory_batch_stream(
+            self._policy_batch_size,
+            epochs=self._replay_epochs,
+            max_slice_length=max_slice_length,
+            margin=self._added_policy_slice_length,
+        ):
+            dist_inputs = self._get_dist_inputs(np_trajectory)
+            (values, _, act_log_probs) = self._run_value_model(
+                np_trajectory.observation, dist_inputs
+            )
+            values = self._aggregate_values(values, "mean", act_log_probs)
+            if len(values.shape) != 2:
+                raise ValueError(
+                    "Values are expected to have shape "
+                    + "[batch_size, length], got: %s" % str(values.shape)
+                )
+            if values.shape[0] != self._policy_batch_size:
+                raise ValueError(
+                    "Values first dimension should = policy batch size, "
+                    + "%d != %d" % (values.shape[0], self._policy_batch_size)
+                )
+            yield self.policy_inputs(np_trajectory, values)
+
+    def train_epoch(self):
+        """Trains RL for one epoch."""
+        # Copy policy state accumulated during data collection to the trainers.
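+        # Epoch outline (editorial summary of the code below): (1) sync the
+        # collect-time state into the trainers, (2) train the value network,
+        # refreshing the target network after each evaluation round, (3) copy
+        # any shared layers, then train the policy network.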
+        self._policy_trainer.model_state = self._policy_collect_model.state
+
+        # Copy policy weights and state to value trainers.
+        if self._n_shared_layers > 0:
+            _copy_model_weights_and_state(
+                0, self._n_shared_layers, self._policy_trainer, self._value_trainer
+            )
+
+        # Update the target value network.
+        self._value_eval_model.weights = self._value_trainer.model_weights
+        self._value_eval_model.state = self._value_trainer.model_state
+
+        n_value_evals = rl_training.remaining_evals(
+            self._value_trainer.step,
+            self._epoch,
+            self._value_train_steps_per_epoch,
+            self._value_evals_per_epoch,
+        )
+        for _ in range(n_value_evals):
+            self._value_trainer.train_epoch(
+                self._value_train_steps_per_epoch // self._value_evals_per_epoch,
+                self._value_eval_steps,
+            )
+            # Update the target value network.
+            self._value_eval_model.weights = self._value_trainer.model_weights
+            self._value_eval_model.state = self._value_trainer.model_state
+
+        # Copy value weights and state to policy trainers.
+        if self._n_shared_layers > 0:
+            _copy_model_weights_and_state(
+                0, self._n_shared_layers, self._value_trainer, self._policy_trainer
+            )
+        n_policy_evals = rl_training.remaining_evals(
+            self._policy_trainer.step,
+            self._epoch,
+            self._policy_train_steps_per_epoch,
+            self._policy_evals_per_epoch,
+        )
+        # Check whether training was restarted after value training finished but
+        # before policy training completed.
+        stopped_after_value = (
+            n_value_evals == 0 and n_policy_evals < self._policy_evals_per_epoch
+        )
+        should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value
+        if should_copy_weights:
+            _copy_model_weights_and_state(
+                0, self._n_shared_layers, self._value_trainer, self._policy_trainer
+            )
+
+        # Update the target value network.
+        self._value_eval_model.weights = self._value_trainer.model_weights
+        self._value_eval_model.state = self._value_trainer.model_state
+
+        for _ in range(n_policy_evals):
+            self._policy_trainer.train_epoch(
+                self._policy_train_steps_per_epoch // self._policy_evals_per_epoch,
+                self._policy_eval_steps,
+            )
+
+    def close(self):
+        self._value_trainer.close()
+        super().close()
+
+
+def _copy_model_weights_and_state(  # pylint: disable=invalid-name
+    start, end, from_trainer, to_trainer, copy_optimizer_slots=False
+):
+    """Copy model weights[start:end] from from_trainer to to_trainer."""
+    from_weights = from_trainer.model_weights
+    to_weights = list(to_trainer.model_weights)
+    shared_weights = from_weights[start:end]
+    to_weights[start:end] = shared_weights
+    to_trainer.model_weights = to_weights
+
+    from_state = from_trainer.model_state
+    to_state = list(to_trainer.model_state)
+    shared_state = from_state[start:end]
+    to_state[start:end] = shared_state
+    to_trainer.model_state = to_state
+
+    if copy_optimizer_slots:
+        # TODO(lukaszkaiser): make a nicer API in Trainer to support this.
+        # Currently we use the hack below. Note [0] since that's the model w/o loss.
+        # pylint: disable=protected-access
+        from_slots = from_trainer._opt_state.slots[0][start:end]
+        to_slots = to_trainer._opt_state.slots[0]
+        # The lines below do to_slots[start:end] = from_slots, but on tuples.
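+        # E.g. (illustrative): with start=1, end=3, to_slots=(a, b, c, d) and
+        # from_slots=(x, y), the first line below produces (a, x, y, d).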
+        new_slots = to_slots[:start] + from_slots + to_slots[end:]
+        new_slots = tuple([new_slots] + list(to_trainer._opt_state.slots[1:]))
+        to_trainer._opt_state = to_trainer._opt_state._replace(slots=new_slots)
+        # pylint: enable=protected-access
+
+
+class AdvantageBasedActorCriticAgent(ActorCriticAgent):
+    """Base class for advantage-based actor-critic algorithms."""
+
+    def __init__(
+        self,
+        task,
+        advantage_estimator=rl_advantages.td_lambda,
+        advantage_normalization=True,
+        advantage_normalization_epsilon=1e-5,
+        advantage_normalization_factor=1.0,
+        added_policy_slice_length=0,
+        **kwargs,
+    ):
+        self._advantage_estimator = advantage_estimator(
+            gamma=task.gamma, margin=added_policy_slice_length
+        )
+        self._advantage_normalization = advantage_normalization
+        self._advantage_normalization_epsilon = advantage_normalization_epsilon
+        self._advantage_normalization_factor = advantage_normalization_factor
+        super().__init__(
+            task, added_policy_slice_length=added_policy_slice_length, **kwargs
+        )
+
+    def policy_inputs(self, trajectory, values):
+        """Create inputs to policy model from a TimeStepBatch and values."""
+        # How much TD to use is determined by the added policy slice length,
+        # as the policy batches need to be this much longer to calculate TD.
+        advantages = self._advantage_estimator(
+            rewards=trajectory.reward,
+            returns=trajectory.return_,
+            values=values,
+            dones=trajectory.done,
+            discount_mask=trajectory.env_info.discount_mask,
+        )
+        # Observations should be the same length as advantages - so if we are
+        # using a margin (`added_policy_slice_length`), we need to trim the
+        # length to match.
+        obs = trajectory.observation[:, : advantages.shape[1]]
+        act = trajectory.action[:, : advantages.shape[1]]
+        mask = trajectory.mask[:, : advantages.shape[1]]  # Mask to zero-out padding.
+        if trajectory.dist_inputs is not None:
+            dist_inputs = self._get_dist_inputs(trajectory)
+            dist_inputs = dist_inputs[:, : advantages.shape[1]]
+        else:
+            dist_inputs = jnp.zeros(advantages.shape + (self._policy_dist.n_inputs,))
+        # Shape checks to help debugging.
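+        # Expected shapes (editorial summary): advantages and mask are
+        # [batch, length]; obs, act and dist_inputs share those two leading axes.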
+ if len(advantages.shape) != 2: + raise ValueError( + "Advantages are expected to have shape " + + "[batch_size, length], got: %s" % str(advantages.shape) + ) + if act.shape[0:2] != advantages.shape: + raise ValueError( + "First 2 dimensions of actions should be the same as in " + "advantages, %s != %s" % (act.shape[0:2], advantages.shape) + ) + if obs.shape[0:2] != advantages.shape: + raise ValueError( + "First 2 dimensions of observations should be the same " + "as in advantages, %s != %s" % (obs.shape[0:2], advantages.shape) + ) + if dist_inputs.shape[:2] != advantages.shape: + raise ValueError( + "First 2 dimensions of dist_inputs should be the same " + "as in advantages, %s != %s" % (dist_inputs.shape[:2], advantages.shape) + ) + if mask.shape != advantages.shape: + raise ValueError( + "Mask and advantages shapes should be the same" + ", %s != %s" % (mask.shape, advantages.shape) + ) + return (obs, act, advantages, dist_inputs, mask) + + @property + def policy_loss_given_log_probs(self): + """Policy loss given action log-probabilities.""" + raise NotImplementedError + + def _preprocess_advantages(self, advantages): + if self._advantage_normalization: + advantages = self._advantage_normalization_factor * ( + (advantages - jnp.mean(advantages)) + / (jnp.std(advantages) + self._advantage_normalization_epsilon) + ) + return advantages + + @property + def policy_loss(self, **unused_kwargs): + """Policy loss.""" + + def LossInput( + dist_inputs, actions, advantages, old_dist_inputs + ): # pylint: disable=invalid-name + """Calculates action log probabilities and normalizes advantages.""" + advantages = self._preprocess_advantages(advantages) + log_probs = self._policy_dist.log_prob(dist_inputs, actions) + old_log_probs = self._policy_dist.log_prob(old_dist_inputs, actions) + return (log_probs, advantages, old_log_probs) + + return tl.Serial( + tl.Fn("LossInput", LossInput, n_out=3), + # Policy loss is expected to consume + # (log_probs, advantages, old_log_probs, mask). + self.policy_loss_given_log_probs, + ) + + @property + def policy_metrics(self): + metrics = super().policy_metrics + metrics.update( + { + "advantage_mean": self.advantage_mean, + "advantage_std": self.advantage_std, + } + ) + return metrics + + @property + def advantage_mean(self): + return tl.Serial( + [ + # (dist_inputs, advantages, old_dist_inputs, mask) + tl.Select([1]), # Select just the advantages. + tl.Fn( + "AdvantageMean", lambda x: jnp.mean(x) + ), # pylint: disable=unnecessary-lambda + ] + ) + + @property + def advantage_std(self): + return tl.Serial( + [ + # (dist_inputs, advantages, old_dist_inputs, mask) + tl.Select([1]), # Select just the advantages. + tl.Fn( + "AdvantageStd", lambda x: jnp.std(x) + ), # pylint: disable=unnecessary-lambda + ] + ) + + +# TODO(pkozakowski): Rewrite all interleaved actor-critic algos to subclass +# this, then rename to ActorCriticAgent and remove the other base classes. 
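+
+# Editorial usage sketch (illustrative only; `my_task` and `joint_model_fn` are
+# hypothetical placeholders, not names defined in this module):
+#
+#     agent = LoopActorCriticAgent(my_task, model_fn=joint_model_fn,
+#                                  batch_size=32)
+#     agent.train_epoch()  # one epoch of interleaved value and policy training
+#
+# Each epoch runs `value_n_steps_per_epoch` value-head steps followed by
+# `policy_n_steps_per_epoch` policy-head steps on a single shared Loop.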
+class LoopActorCriticAgent(rl_training.Agent):
+    """Base class for actor-critic algorithms based on `Loop`."""
+
+    on_policy = None
+
+    def __init__(
+        self,
+        task,
+        model_fn,
+        optimizer=adam.Adam,
+        policy_lr_schedule=lr.multifactor,
+        policy_n_steps_per_epoch=1000,
+        policy_weight_fn=(lambda x: x),
+        value_lr_schedule=lr.multifactor,
+        value_n_steps_per_epoch=1000,
+        value_sync_at=(lambda x: x % 100 == 0),
+        advantage_estimator=rl_advantages.monte_carlo,
+        batch_size=64,
+        network_eval_at=None,
+        n_eval_batches=1,
+        max_slice_length=1,
+        margin=0,
+        n_replay_epochs=1,
+        **kwargs,
+    ):
+        """Initializes LoopActorCriticAgent.
+
+        Args:
+          task: `RLTask` instance to use.
+          model_fn: Function mode -> Trax model, building a joint policy and
+            value network.
+          optimizer: Optimizer for the policy and value networks.
+          policy_lr_schedule: Learning rate schedule for the policy network.
+          policy_n_steps_per_epoch: Number of steps to train the policy network
+            for in each epoch.
+          policy_weight_fn: Function advantages -> weights for calculating the
+            log probability weights in policy training.
+          value_lr_schedule: Learning rate schedule for the value network.
+          value_n_steps_per_epoch: Number of steps to train the value network
+            for in each epoch.
+          value_sync_at: Function step -> bool indicating when to synchronize the
+            target network with the trained network in value training.
+          advantage_estimator: Advantage estimator to use in policy and value
+            training.
+          batch_size: Batch size for training the networks.
+          network_eval_at: Function step -> bool indicating when to evaluate the
+            networks.
+          n_eval_batches: Number of batches to compute the network evaluation
+            metrics on.
+          max_slice_length: Maximum length of a trajectory slice to train on.
+          margin: Number of timesteps to add at the end of each trajectory slice
+            for better advantage estimation.
+          n_replay_epochs: Number of epochs of trajectories to store in the
+            replay buffer.
+          **kwargs: Keyword arguments forwarded to Agent.
+ """ + super().__init__(task, **kwargs) + + self._policy_dist = distributions.create_distribution(self.task.action_space) + model_fn = functools.partial( + model_fn, + policy_distribution=self._policy_dist, + ) + train_model = model_fn(mode="train") + eval_model = model_fn(mode="eval") + + trajectory_batch_stream = self._init_trajectory_batch_stream( + batch_size, max_slice_length, margin, n_replay_epochs + ) + advantage_estimator = advantage_estimator(task.gamma, margin=margin) + (value_train_task, value_eval_task) = self._init_value_tasks( + trajectory_batch_stream, + optimizer=optimizer(), + lr_schedule=value_lr_schedule(), + advantage_estimator=advantage_estimator, + train_model=train_model, + eval_model=eval_model, + sync_at=value_sync_at, + n_steps_per_epoch=value_n_steps_per_epoch, + n_eval_batches=n_eval_batches, + ) + (policy_train_task, policy_eval_task) = self._init_policy_tasks( + trajectory_batch_stream, + optimizer=optimizer(), + lr_schedule=policy_lr_schedule(), + advantage_estimator=advantage_estimator, + value_train_task=value_train_task, + weight_fn=policy_weight_fn, + n_eval_batches=n_eval_batches, + ) + self._init_loop( + train_model=train_model, + eval_model=eval_model, + policy_train_and_eval_task=(policy_train_task, policy_eval_task), + value_train_and_eval_task=(value_train_task, value_eval_task), + eval_at=network_eval_at, + policy_n_steps_per_epoch=policy_n_steps_per_epoch, + value_n_steps_per_epoch=value_n_steps_per_epoch, + ) + self._init_collection(model_fn, policy_train_task.sample_batch) + + def _init_trajectory_batch_stream( + self, batch_size, max_slice_length, margin, n_replay_epochs + ): + assert self.on_policy is not None, 'Attribute "on_policy" not set.' + if self.on_policy: + assert n_replay_epochs == 1, ( + "Non-unit replay buffer size only makes sense for off-policy " + "algorithms." + ) + self._task.set_n_replay_epochs(n_replay_epochs) + self._max_slice_length = max_slice_length + return self._task.trajectory_batch_stream( + batch_size, + epochs=[-(ep + 1) for ep in range(n_replay_epochs)], + min_slice_length=(1 + margin), + max_slice_length=(self._max_slice_length + margin), + margin=margin, + ) + + def _init_value_tasks( + self, + trajectory_batch_stream, + optimizer, + lr_schedule, + advantage_estimator, + train_model, + eval_model, + sync_at, + n_steps_per_epoch, + n_eval_batches, + ): + def sync_also_at_epoch_boundaries(step): + return sync_at(step) or ( + # 0 - end of the epoch, 1 - beginning of the next. 
+ step % n_steps_per_epoch + in (0, 1) + ) + + head_selector = tl.Select([1]) + value_train_task = value_tasks.ValueTrainTask( + trajectory_batch_stream, + optimizer, + lr_schedule, + advantage_estimator=advantage_estimator, + model=train_model, + target_model=eval_model, + target_scale=(1 - self.task.gamma), + sync_at=sync_also_at_epoch_boundaries, + head_selector=head_selector, + ) + value_eval_task = value_tasks.ValueEvalTask( + value_train_task, n_eval_batches, head_selector + ) + return (value_train_task, value_eval_task) + + def _init_policy_tasks( + self, + trajectory_batch_stream, + optimizer, + lr_schedule, + advantage_estimator, + value_train_task, + weight_fn, + n_eval_batches, + ): + head_selector = tl.Select([0], n_in=2) + policy_train_task = policy_tasks.PolicyTrainTask( + trajectory_batch_stream, + optimizer, + lr_schedule, + self._policy_dist, + advantage_estimator=advantage_estimator, + value_fn=value_train_task.value, + weight_fn=weight_fn, + head_selector=head_selector, + ) + policy_eval_task = policy_tasks.PolicyEvalTask( + policy_train_task, n_eval_batches, head_selector + ) + return (policy_train_task, policy_eval_task) + + def _init_loop( + self, + train_model, + eval_model, + policy_train_and_eval_task, + value_train_and_eval_task, + eval_at, + policy_n_steps_per_epoch, + value_n_steps_per_epoch, + ): + (policy_train_task, policy_eval_task) = policy_train_and_eval_task + (value_train_task, value_eval_task) = value_train_and_eval_task + + if self._output_dir is not None: + model_output_dir = os.path.join(self._output_dir, "model") + else: + model_output_dir = None + + self._n_train_steps_per_epoch = ( + policy_n_steps_per_epoch + value_n_steps_per_epoch + ) + + checkpoint_at = lambda step: step % self._n_train_steps_per_epoch == 0 + + def which_task(step): + if step % self._n_train_steps_per_epoch < value_n_steps_per_epoch: + return 1 + else: + return 0 + + self._loop = supervised.training.Loop( + model=train_model, + tasks=(policy_train_task, value_train_task), + eval_model=eval_model, + eval_tasks=(policy_eval_task, value_eval_task), + output_dir=model_output_dir, + eval_at=eval_at, + checkpoint_at=checkpoint_at, + which_task=which_task, + ) + + # Validate the restored checkpoints. + # TODO(pkozakowski): Move this to the base class once all Agents use Loop. + if self._loop.step != self._epoch * self._n_train_steps_per_epoch: + raise ValueError( + "The number of Loop steps must equal the number of Agent epochs " + "times the number of steps per epoch, got {}, {} and {}.".format( + self._loop.step, self._epoch, self._n_train_steps_per_epoch + ) + ) + + def _init_collection(self, model_fn, sample_batch): + self._collect_model = model_fn(mode="collect") + self._collect_model.init(shapes.signature(sample_batch)) + + @property + def loop(self): + """Loop exposed for testing.""" + return self._loop + + def policy(self, trajectory, temperature=1.0): + """Policy function that allows to play using this agent.""" + tr_slice = trajectory.suffix(self._max_slice_length) + trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np) + return rl_training.network_policy( + collect_model=self._collect_model, + policy_distribution=self._policy_dist, + loop=self.loop, + trajectory_np=trajectory_np, + head_index=0, + temperature=temperature, + ) + + def train_epoch(self): + """Trains RL for one epoch.""" + # Copy policy state accumulated during data collection to the trainers. 
+        self._loop.update_weights_and_state(state=self._collect_model.state)
+        # Run this epoch's scheduled number of policy and value training steps.
+        self._loop.run(n_steps=self._n_train_steps_per_epoch)
+
+
+### Implementations of common actor-critic algorithms.
+
+
+class A2C(AdvantageBasedActorCriticAgent):
+    """Trains policy and value models using the A2C algorithm."""
+
+    on_policy = True
+
+    def __init__(self, task, entropy_coeff=0.01, **kwargs):
+        """Configures the A2C Trainer."""
+        self._entropy_coeff = entropy_coeff
+        super().__init__(task, **kwargs)
+
+    @property
+    def policy_loss_given_log_probs(self):
+        """Definition of the Advantage Actor Critic (A2C) loss."""
+
+        # A2C is one of the most basic actor-critic RL algorithms.
+        # TODO(henrykm) refactor f into rl_layers and finally share code between
+        # actor_critic.py and actor_critic_joint.py - requires change of inputs
+        # in actor_critic_joint.py from dist_inputs to log_probs.
+        def f(log_probs, advantages, old_log_probs, mask):
+            del old_log_probs  # Not used in A2C.
+            # log_probs of the shape float32[128,1]
+            # advantages of the shape int32[128,1]
+            # mask of the shape int32[128,1]
+            if log_probs.shape != advantages.shape:
+                raise ValueError(
+                    "New log-probs and advantages shapes "
+                    "should be the same, %s != %s" % (log_probs.shape, advantages.shape)
+                )
+            if log_probs.shape != mask.shape:
+                raise ValueError(
+                    "New log-probs and mask shapes should be the same"
+                    ", %s != %s" % (log_probs.shape, mask.shape)
+                )
+
+            a2c_objective = -jnp.sum(log_probs * advantages * mask) / jnp.sum(mask)
+
+            entropy_vec = self._policy_dist.entropy(log_probs) * self._entropy_coeff
+            entropy_loss = jnp.mean(entropy_vec)
+
+            combined_loss = a2c_objective - entropy_loss
+
+            return combined_loss
+
+        return tl.Fn("A2CLoss", f)
+
+
+class PPO(AdvantageBasedActorCriticAgent):
+    """The Proximal Policy Optimization Algorithm aka PPO.
+
+    Trains policy and value models using the PPO algorithm.
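+
+    As a sketch of the standard formulation (writing r_t for the probability
+    ratio exp(new_log_prob - old_log_prob) and A_t for the advantage), the
+    per-timestep objective maximized below is:
+
+        min(r_t * A_t, clip(r_t, 1 - epsilon, 1 + epsilon) * A_t)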
+ """ + + on_policy = True + + def __init__(self, task, epsilon=0.2, entropy_coeff=0.01, **kwargs): + """Configures the PPO Trainer.""" + self._entropy_coeff = entropy_coeff + self._epsilon = epsilon + super().__init__(task, **kwargs) + + @property + def policy_loss_given_log_probs(self): + """Definition of the Proximal Policy Optimization loss.""" + + def f(new_log_probs, advantages, old_log_probs, mask): + # new_log_probs of the shape float32[128,1] + # advantages of the shape int32[128,1] + # old_log_probs of the shape int32[128,1] + # mask of the shape int32[128,1] + if new_log_probs.shape != advantages.shape: + raise ValueError( + "New log-probs and advantages shapes " + "should be the same, %s != %s" + % (new_log_probs.shape, advantages.shape) + ) + if new_log_probs.shape != old_log_probs.shape: + raise ValueError( + "New log-probs and old log-probs shapes " + "should be the same, %s != %s" + % (new_log_probs.shape, old_log_probs.shape) + ) + if new_log_probs.shape != mask.shape: + raise ValueError( + "New log-probs and mask shapes should be the same" + ", %s != %s" % (new_log_probs.shape, mask.shape) + ) + + # The ratio between new_probs and old_probs expressed + # using log_probs and exponentiation + probs_ratio = jnp.exp(new_log_probs - old_log_probs) + if advantages.shape != probs_ratio.shape: + raise ValueError( + "New log-probs and old log probs shapes " + "should be the same, %s != %s" + % (advantages.shape, probs_ratio.shape) + ) + unclipped_objective = probs_ratio * advantages + clipped_objective = ( + jnp.clip(probs_ratio, 1 - self._epsilon, 1 + self._epsilon) * advantages + ) + + if unclipped_objective.shape != probs_ratio.shape: + raise ValueError( + "unclipped_objective and clipped_objective shapes " + "should be the same, %s != %s" + % (unclipped_objective.shape, clipped_objective.shape) + ) + + ppo_objective = jnp.minimum(unclipped_objective, clipped_objective) + + if ppo_objective.shape != mask.shape: + raise ValueError( + "ppo_objective and mask shapes " + "should be the same, %s != %s" % (ppo_objective.shape, mask.shape) + ) + + ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask) + entropy_vec = self._policy_dist.entropy(new_log_probs) * self._entropy_coeff + entropy_loss = jnp.mean(entropy_vec) + combined_loss = ppo_loss - entropy_loss + + return combined_loss + + return tl.Fn("PPOLoss", f) + + +def _weighted_percentiles(x, thresholds): + """Calculate weights for x by percentile-and-weights given in thresholds. + + Thresholds contain a list of (p, weight, minumum). For each threshold, + all elements of x that are above the p-th percentile *and* above minimum + get the weight weight, and all other get the weight 0. + The result is the sum over all thresholds. + + Args: + x: tensor to calculate the weights for + thresholds: list of triples (percentile, weight, minimum) used to + calculate the weights (see above how) + + Returns: + weights, a tensor of the same shape as x + """ + res = [] + for percentile, weight, minimum in thresholds: + threshold = jnp.percentile(x, percentile) + if minimum is not None: + threshold = jnp.maximum(minimum, threshold) + zero_ones = jnp.where(x < threshold, jnp.zeros_like(x), jnp.ones_like(x)) + res.append(weight * zero_ones) + return sum(res) + + +# AWR is an off-policy actor-critic RL algorithm. +def awr_weights(advantages, beta, thresholds): + if thresholds: + return _weighted_percentiles(advantages, thresholds) + return jnp.exp(advantages / beta) + + +# Helper functions for computing AWR metrics. 
+def awr_metrics(beta, thresholds, preprocess_layer=None): + return { # pylint: disable=g-complex-comprehension + "awr_weight_" + + name: awr_weight_stat(name, fn, beta, thresholds, preprocess_layer) + for (name, fn) in [ + ("mean", jnp.mean), + ("std", jnp.std), + ("min", jnp.min), + ("max", jnp.max), + ] + } + + +def awr_weight_stat(stat_name, stat_fn, beta, thresholds, preprocess_layer): + # Select just the advantages if preprocess layer is not given. + preprocess = tl.Select([1]) if preprocess_layer is None else preprocess_layer + return tl.Serial( + [ + preprocess, + tl.Fn( + "AWRWeight" + stat_name.capitalize(), + lambda x: stat_fn(awr_weights(x, beta, thresholds)), + ), + ] + ) + + +def AWRLoss(beta, w_max, thresholds): # pylint: disable=invalid-name + """Definition of the Advantage Weighted Regression (AWR) loss.""" + + def f(log_probs, advantages, old_log_probs, mask): + del old_log_probs # Not used in AWR. + weights = jnp.minimum(awr_weights(advantages, beta, thresholds), w_max) + return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask) + + return tl.Fn("AWRLoss", f) + + +class AWR(AdvantageBasedActorCriticAgent): + """Trains policy and value models using AWR.""" + + on_policy = False + + def __init__(self, task, beta=1.0, w_max=20.0, thresholds=None, **kwargs): + """Configures the AWR Trainer.""" + self._beta = beta + self._w_max = w_max + self._thresholds = thresholds + super().__init__(task, **kwargs) + + @property + def policy_loss_given_log_probs(self): + """Policy loss.""" + return AWRLoss( + beta=self._beta, w_max=self._w_max, thresholds=self._thresholds + ) # pylint: disable=no-value-for-parameter + + +class LoopAWR(LoopActorCriticAgent): + """Advantage Weighted Regression.""" + + on_policy = False + + def __init__(self, task, model_fn, beta=1.0, w_max=20, **kwargs): + def policy_weight_fn(advantages): + return jnp.minimum(jnp.exp(advantages / beta), w_max) + + super().__init__(task, model_fn, policy_weight_fn=policy_weight_fn, **kwargs) + + +def SamplingAWRLoss( + beta, + w_max, + thresholds, # pylint: disable=invalid-name + reweight=False, + sampled_all_discrete=False, +): + """Definition of the Advantage Weighted Regression (AWR) loss.""" + + def f(log_probs, advantages, old_log_probs, mask): + if reweight: # Use new policy weights for sampled actions instead. + mask *= jnp.exp(fastmath.stop_gradient(log_probs) - old_log_probs) + if sampled_all_discrete: # Actions were sampled uniformly; weight them. 
+ mask *= jnp.exp(old_log_probs) + weights = jnp.minimum(awr_weights(advantages, beta, thresholds), w_max) + return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask) + + return tl.Fn("SamplingAWRLoss", f) + + +class SamplingAWR(AdvantageBasedActorCriticAgent): + """Trains policy and value models using Sampling AWR.""" + + on_policy = False + + def __init__( + self, task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs + ): + """Configures the AWR Trainer.""" + self._beta = beta + self._w_max = w_max + self._thresholds = thresholds + self._reweight = reweight + super().__init__(task, q_value=True, **kwargs) + + def _policy_inputs_to_advantages(self, preprocess): + """A layer that computes advantages from policy inputs.""" + + def fn(dist_inputs, actions, q_values, act_log_probs, mask): + del dist_inputs, actions, mask + q_values = jnp.swapaxes(q_values, 0, 1) + act_log_probs = jnp.swapaxes(act_log_probs, 0, 1) + if self._sample_all_discrete_actions: + values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0) + else: + values = jnp.mean(q_values, axis=0) + advantages = q_values - values # Broadcasting values over n_samples + if preprocess: + advantages = self._preprocess_advantages(advantages) + return advantages + + return tl.Fn("PolicyInputsToAdvantages", fn) + + @property + def policy_metrics(self): + metrics = { + "policy_loss": self.policy_loss, + "advantage_mean": tl.Serial( + self._policy_inputs_to_advantages(False), + tl.Fn( + "Mean", lambda x: jnp.mean(x) + ), # pylint: disable=unnecessary-lambda + ), + "advantage_std": tl.Serial( + self._policy_inputs_to_advantages(False), + tl.Fn( + "Std", lambda x: jnp.std(x) + ), # pylint: disable=unnecessary-lambda + ), + } + metrics.update( + awr_metrics( + self._beta, + self._thresholds, + preprocess_layer=self._policy_inputs_to_advantages(True), + ) + ) + return metrics + + @property + def policy_loss(self, **unused_kwargs): + """Policy loss.""" + + def LossInput( + dist_inputs, actions, q_values, act_log_probs, mask + ): # pylint: disable=invalid-name + """Calculates action log probabilities and normalizes advantages.""" + # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...) + q_values = jnp.swapaxes(q_values, 0, 1) + mask = jnp.swapaxes(mask, 0, 1) + actions = jnp.swapaxes(actions, 0, 1) + act_log_probs = jnp.swapaxes(act_log_probs, 0, 1) + + # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting? + if self._sample_all_discrete_actions: + values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0) + else: + values = jnp.mean(q_values, axis=0) + advantages = q_values - values # Broadcasting values over n_samples + advantages = self._preprocess_advantages(advantages) + + # Broadcast inputs and calculate log-probs + dist_inputs = jnp.broadcast_to( + dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape + ) + log_probs = self._policy_dist.log_prob(dist_inputs, actions) + return (log_probs, advantages, act_log_probs, mask) + + return tl.Serial( + tl.Fn("LossInput", LossInput, n_out=4), + # Policy loss is expected to consume + # (log_probs, advantages, old_log_probs, mask). + SamplingAWRLoss( + beta=self._beta, + w_max=self._w_max, + thresholds=self._thresholds, + reweight=self._reweight, + sampled_all_discrete=self._sample_all_discrete_actions, + ), + ) + + def policy_batches_stream(self): + """Use the RLTask self._task to create inputs to the policy model.""" + # For now TD-0 estimation of the value. TODO(pkozakowski): Support others? 
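+        # Each yielded batch is a tuple of
+        # (observations, actions, q_values, act_log_probs, mask), where
+        # q_values and act_log_probs have shape (batch_size, n_samples, length)
+        # and mask is broadcast to the same shape; these invariants are
+        # checked below.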
+        for np_trajectory in self._task.trajectory_batch_stream(
+            self._policy_batch_size,
+            epochs=self._replay_epochs,
+            max_slice_length=self._max_slice_length,
+        ):
+            dist_inputs = self._get_dist_inputs(np_trajectory)
+            (q_values, actions, act_log_probs) = self._run_value_model(
+                np_trajectory.observation, dist_inputs
+            )
+            shapes.assert_same_shape(q_values, act_log_probs)
+
+            # q_values shape: (batch_size, n_samples, length)
+            if len(q_values.shape) != 3:
+                raise ValueError(
+                    "Q-values are expected to have shape [batch_size, "
+                    + "n_samples, length], got: %s" % str(q_values.shape)
+                )
+            if q_values.shape[1] != self._q_value_n_samples:
+                raise ValueError(
+                    "Q-values dimension 1 should = n_samples, %d != %d"
+                    % (q_values.shape[1], self._q_value_n_samples)
+                )
+            if q_values.shape[0] != self._policy_batch_size:
+                raise ValueError(
+                    "Q-values dimension 0 should = policy batch size, "
+                    + "%d != %d" % (q_values.shape[0], self._policy_batch_size)
+                )
+
+            mask = np_trajectory.mask
+            mask = np.reshape(mask, [mask.shape[0], 1] + list(mask.shape[1:]))
+            mask = jnp.broadcast_to(mask, q_values.shape)
+            shapes.assert_same_shape(mask, q_values)
+            yield (np_trajectory.observation, actions, q_values, act_log_probs, mask)
diff --git a/trax/learning/reinforcement/actor_critic_joint.py b/trax/learning/reinforcement/actor_critic_joint.py
new file mode 100644
index 000000000..654ff28a6
--- /dev/null
+++ b/trax/learning/reinforcement/actor_critic_joint.py
@@ -0,0 +1,773 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classes for RL training in Trax."""
+
+import functools
+
+from trax import data
+from trax import layers as tl
+from trax.fastmath import numpy as jnp
+from trax.fastmath import stop_gradient
+from trax.learning import supervised
+from trax.learning.reinforcement import actor_critic, distributions, rl_layers
+from trax.learning.reinforcement import training as rl_training
+from trax.learning.supervised import lr_schedules as lr
+
+
+# pylint: disable=g-long-lambda
+class ActorCriticJointAgent(rl_training.Agent):
+    """Trains a joint policy-and-value model using actor-critic methods."""
+
+    def __init__(
+        self,
+        task,
+        joint_model=None,
+        optimizer=None,
+        lr_schedule=lr.multifactor,
+        batch_size=64,
+        train_steps_per_epoch=500,
+        supervised_evals_per_epoch=1,
+        supervised_eval_steps=1,
+        n_trajectories_per_epoch=50,
+        max_slice_length=1,
+        normalize_advantages=True,
+        output_dir=None,
+        n_replay_epochs=1,
+    ):
+        """Configures the joint trainer.
+
+        Args:
+          task: RLTask instance, which defines the environment to train on.
+          joint_model: Trax layer, representing the joint policy and value model.
+          optimizer: the optimizer to use to train the joint model.
+          lr_schedule: learning rate schedule to use to train the joint model.
+          batch_size: batch size used to train the joint model.
+          train_steps_per_epoch: how long to train the joint model in each RL
+            epoch.
+          supervised_evals_per_epoch: number of supervised evaluations per RL
+            epoch - only affects metric reporting.
+          supervised_eval_steps: number of supervised steps per evaluation -
+            only affects metric reporting.
+          n_trajectories_per_epoch: how many trajectories to collect per epoch.
+          max_slice_length: the maximum length of trajectory slices to use.
+          normalize_advantages: if True, then normalize advantages - currently
+            implemented only in PPO.
+          output_dir: Path telling where to save outputs (evals and checkpoints).
+          n_replay_epochs: how many last epochs to take into the replay buffer;
+            > 1 only makes sense for off-policy algorithms.
+        """
+        super().__init__(
+            task,
+            n_trajectories_per_epoch=n_trajectories_per_epoch,
+            output_dir=output_dir,
+        )
+        self._batch_size = batch_size
+        self._train_steps_per_epoch = train_steps_per_epoch
+        self._supervised_evals_per_epoch = supervised_evals_per_epoch
+        self._supervised_eval_steps = supervised_eval_steps
+        self._n_trajectories_per_epoch = n_trajectories_per_epoch
+        self._max_slice_length = max_slice_length
+        self._policy_dist = distributions.create_distribution(task.action_space)
+        self._lr_schedule = lr_schedule()
+        self._optimizer = optimizer
+        self._normalize_advantages = normalize_advantages
+        self._n_replay_epochs = n_replay_epochs
+        self._task.set_n_replay_epochs(n_replay_epochs)
+
+        # Inputs to the joint model are produced by self.batches_stream.
+        self._inputs = data.inputs.Inputs(train_stream=lambda _: self.batches_stream())
+
+        self._joint_model = functools.partial(
+            joint_model,
+            policy_distribution=self._policy_dist,
+        )
+
+        # This is the joint Trainer that will be used to train the policy model.
+        # * inputs to the trainer come from self.batches_stream
+        # * outputs are passed to self._joint_loss
+        self._trainer = supervised.Trainer(
+            model=self._joint_model,
+            optimizer=self._optimizer,
+            lr_schedule=self._lr_schedule,
+            loss_fn=self.joint_loss,
+            inputs=self._inputs,
+            output_dir=output_dir,
+            metrics={
+                "joint_loss": self.joint_loss,
+                "advantage_mean": self.advantage_mean,
+                "advantage_norm": self.advantage_norm,
+                "value_loss": self.value_loss,
+                "explained_variance": self.explained_variance,
+                "log_probs_mean": self.log_probs_mean,
+                "preferred_move": self.preferred_move,
+            },
+        )
+        self._eval_model = tl.Accelerate(self._joint_model(mode="eval"), n_devices=1)
+        example_batch = next(self.batches_stream())
+        self._eval_model.init(example_batch)
+
+    def close(self):
+        self._trainer.close()
+        super().close()
+
+    def batches_stream(self):
+        """Use self.task to create inputs to the policy model."""
+        raise NotImplementedError
+
+    @property
+    def joint_loss(self):
+        """Joint policy and value loss layer."""
+        raise NotImplementedError
+
+    @property
+    def advantage_mean(self):
+        """Mean of advantages."""
+
+        def f(dist_inputs, values, returns):
+            del dist_inputs
+            return jnp.mean(returns - values)
+
+        return tl.Fn("AdvantageMean", f)
+
+    @property
+    def advantage_norm(self):
+        """Norm of advantages."""
+
+        def f(dist_inputs, values, returns):
+            del dist_inputs
+            return jnp.linalg.norm(returns - values)
+
+        return tl.Fn("AdvantageNorm", f)
+
+    @property
+    def value_loss(self):
+        """Value loss - so far generic for all A2C."""
+
+        def f(dist_inputs, values, returns):
+            del dist_inputs
+            return rl_layers.ValueLoss(values, returns, self._value_loss_coeff)
+
+        return tl.Fn("ValueLoss", f)
+
+    @property
+    def explained_variance(self):
+        """Explained variance metric."""
+
+        def f(dist_inputs, values, returns):
+            del dist_inputs
+            return rl_layers.ExplainedVariance(values, returns)
+
+        return tl.Fn("ExplainedVariance", f)
+
+    @property
+    def log_probs_mean(self):
+        """Mean of log_probs aka dist_inputs."""
+
+        def f(dist_inputs, values):
+            del values
+            return jnp.mean(dist_inputs)
+
+        return tl.Fn("LogProbsMean", f)
+
+    @property
+    def preferred_move(self):
+        """Preferred move - the mean of selected moves."""
+
+        def f(dist_inputs, values):
+            del values
+            return rl_layers.PreferredMove(dist_inputs, self._policy_dist.sample)
+
+        return tl.Fn("PreferredMove", f)
+
+    def policy(self, trajectory, temperature=1.0):
+        """Chooses an action to play after a trajectory."""
+        model = self._eval_model
+        model.replicate_weights(self._trainer.model_weights)
+        # The two lines below, along with copying before the return,
+        # keep the TPU happy.
+        tr_slice = trajectory.suffix(self._max_slice_length)
+        trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
+        # Add batch dimension to trajectory_np and run the model.
+        pred = model(trajectory_np.observation[None, ...])[0]
+        # Pick element 0 from the batch (the only one), last (current) timestep.
+        pred = pred[0, -1, :]
+        sample = self._policy_dist.sample(pred, temperature=temperature)
+        return (sample.copy(), pred.copy())
+
+    def train_epoch(self):
+        """Trains RL for one epoch."""
+        n_evals = rl_training.remaining_evals(
+            self._trainer.step,
+            self._epoch,
+            self._train_steps_per_epoch,
+            self._supervised_evals_per_epoch,
+        )
+        for _ in range(n_evals):
+            self._trainer.train_epoch(
+                self._train_steps_per_epoch // self._supervised_evals_per_epoch,
+                self._supervised_eval_steps,
+            )
+
+
+class PPOJoint(ActorCriticJointAgent):
+    """The Proximal Policy Optimization Algorithm aka PPO.
+
+    Trains policy and value models using the PPO algorithm.
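+
+    The joint model is expected to output (dist_inputs, values); the loss in
+    `joint_loss` below combines the clipped PPO objective with an L2 value
+    loss and an entropy bonus.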
+ """ + + # TODO(henrykm): make on_policy more generic + # (currently epochs are passed manually) + on_policy = True + + def __init__( + self, task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs + ): + """Configures the PPO Trainer.""" + self._epsilon = epsilon + self._value_loss_coeff = value_loss_coeff + self._entropy_coeff = entropy_coeff + super().__init__(task, **kwargs) + self._trainer = supervised.Trainer( + model=self._joint_model, + optimizer=self._optimizer, + lr_schedule=self._lr_schedule, + loss_fn=self.joint_loss, + inputs=self._inputs, + output_dir=self._output_dir, + metrics={ + "joint_loss": self.joint_loss, + "advantage_mean": self.advantage_mean, + "advantage_norm": self.advantage_norm, + "value_loss": self.value_loss, + "explained_variance": self.explained_variance, + "log_probs_mean": self.log_probs_mean, + "entropy_loss": self.entropy_loss, + "probs_ratio_mean": self.probs_ratio_mean, + "unclipped_objective_mean": self.unclipped_objective_mean, + "clipped_objective_mean": self.clipped_objective_mean, + "ppo_objective_mean": self.ppo_objective_mean, + "clip_fraction": self.clip_fraction, + "preferred_move": self.preferred_move, + "approximate_kl_divergence": self.approximate_kl_divergence, + }, + ) + + def batches_stream(self): + """Use the RLTask self._task to create inputs to the value model.""" + for np_trajectory in self._task.trajectory_batch_stream( + self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1] + ): + if np_trajectory.dist_inputs is not None: + old_dist_inputs = np_trajectory.dist_inputs + else: + old_dist_inputs = jnp.zeros( + np_trajectory.reward.shape + (self._policy_dist.n_inputs,) + ) + old_log_probs = self._policy_dist.log_prob( + old_dist_inputs, np_trajectory.action + ) + # Insert an extra depth dimension, so the target shape is consistent with + # the network output shape. + yield ( + np_trajectory.observation, # Inputs to the value model. 
+ np_trajectory.return_[:, :, None], + np_trajectory.done[:, :, None], + np_trajectory.reward[:, :, None], + np_trajectory.action, + old_log_probs, + np_trajectory.mask, + ) + + @property + def joint_loss(self): + """Joint policy and value loss.""" + + def f( + dist_inputs, values, returns, dones, rewards, actions, old_log_probs, mask + ): + """Definition of the Proximal Policy Optimization loss.""" + del mask # TODO(lukaszkaiser): make PPO work with Transformer + # We have dist_inputs of the shape float32[128,1,18] + assert len(dist_inputs.shape) == 3, ( + f"dist_inputs.shape was {dist_inputs.shape}" + f"but expected length of the tensor shape is 3" + ) + # values of the shape float32[128,1,1] + # returns of the shape float32[128,1,1] + # dones of the shape int32[128,1,1] + # rewards of the shape float32[128,1,1] + # and old_log_probs of the shape float32[128,1] + assert values.shape == returns.shape, ( + f"values.shape was {values.shape}" f"returns.shape was {returns.shape}" + ) + assert values.shape == dones.shape, ( + f"values.shape was {values.shape}" f"returns.shape was {dones.shape}" + ) + assert rewards.shape == dones.shape, ( + f"values.shape was {values.shape}" f"returns.shape was {dones.shape}" + ) + assert returns.shape[0:2] == old_log_probs.shape, ( + f"returns.shape was {returns.shape}" + f"old_log_probs.shape was {old_log_probs.shape}" + ) + + # actions is a tensor of the shape int32[128,1] in the case + # of discrete actions and float32[128,1,6] in the case of + # half-cheetah and other continuous actions + # actions agree with returns/values on the first two coordinates + # meaning batch and time + assert actions.shape[0:2] == returns.shape[0:2], ( + f"actions.shape was {actions.shape} and " + f"returns.shape was {returns.shape}" + ) + + ppo_objective = rl_layers.PPOObjective( + dist_inputs, + stop_gradient(values), + returns, + dones, + rewards, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + epsilon=self._epsilon, + normalize_advantages=self._normalize_advantages, + ) + + # we insist that ppo_objective is a vector of shape [128,1] + assert len(ppo_objective.shape) == 2, f"ppo_objective was {ppo_objective}" + # which agrees with returns/values/actions on the first two coordinates + assert ppo_objective.shape[0:2] == values.shape[0:2], ( + f"ppo_objective.shape was {ppo_objective.shape} and " + f"values.shape was {values.shape}" + ) + + entropy_loss = rl_layers.EntropyLoss( + dist_inputs, + distribution=self._policy_dist, + coeff=self._entropy_coeff, + ) + + assert jnp.ndim(entropy_loss) == 0, f"entropy_loss was {entropy_loss}" + + l2_value_loss = rl_layers.ValueLoss( + values, returns, value_loss_coeff=self._value_loss_coeff + ) + + assert jnp.ndim(l2_value_loss) == 0, f"l2_value_loss was {l2_value_loss}" + + return -ppo_objective.mean() + l2_value_loss - entropy_loss + + return tl.Fn("PPOJointLoss", f) + + # pylint: disable=invalid-name + @property + def probs_ratio_mean(self): + """Joint policy and value loss layer.""" + + def ProbsRatioMean(dist_inputs, actions, old_log_probs): + """Probability Ratio Mean from the PPO algorithm.""" + probs_ratio = rl_layers.ProbsRatio( + dist_inputs, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + ) + return jnp.mean(probs_ratio) + + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + del values, returns, dones, rewards + return ProbsRatioMean(dist_inputs, actions, old_log_probs) + + return tl.Fn("ProbsRatioMean", f) + + @property + def clip_fraction(self): + 
"""Joint policy and value loss layer.""" + + def ClipFraction(dist_inputs, actions, old_log_probs): + """Probability Ratio Mean from the PPO algorithm.""" + probs_ratio = rl_layers.ProbsRatio( + dist_inputs, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + ) + return jnp.mean(jnp.abs(probs_ratio - 1) > self._epsilon) + + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + del values, returns, dones, rewards + return ClipFraction(dist_inputs, actions, old_log_probs) + + return tl.Fn("ClipFraction", f) + + # pylint: enable=invalid-name + + @property + def entropy_loss(self): + """Entropy layer.""" + + def f(dist_inputs, values, returns, dones, rewards, actions): + del values, returns, dones, rewards, actions + return rl_layers.EntropyLoss( + dist_inputs, + distribution=self._policy_dist, + coeff=self._entropy_coeff, + ) + + return tl.Fn("EntropyLoss", f) + + @property + def approximate_kl_divergence(self): + """Approximate KL divergence.""" + + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + del values, returns, dones, rewards + return rl_layers.ApproximateKLDivergence( + dist_inputs, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + ) + + return tl.Fn("ApproximateKLDivergence", f) + + @property + def unclipped_objective_mean(self): + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + """Unclipped objective Mean from the PPO algorithm.""" + del dones, rewards + advantages = returns - values + probs_ratio = rl_layers.ProbsRatio( + dist_inputs, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + ) + # advantages are of the shape [128,1,1] + # and probs_ratio are of the shape [128,1] + advantages = advantages.squeeze(axis=2) + unclipped_objective = rl_layers.UnclippedObjective(probs_ratio, advantages) + return jnp.mean(unclipped_objective) + + return tl.Fn("UnclippedObjectiveMean", f) + + @property + def clipped_objective_mean(self): + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + """Clipped objective from the PPO algorithm.""" + del dones, rewards + advantages = returns - values + probs_ratio = rl_layers.ProbsRatio( + dist_inputs, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + ) + # advantages are of the shape [128,1,1] + # and probs_ratio are of the shape [128,1] + advantages = advantages.squeeze(axis=2) + clipped_objective = rl_layers.ClippedObjective( + probs_ratio, advantages, epsilon=self._epsilon + ) + return jnp.mean(clipped_objective) + + return tl.Fn("ClippedObjectiveMean", f) + + @property + def ppo_objective(self): + """PPO objective with local parameters.""" + + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + return rl_layers.PPOObjective( + dist_inputs, + values, + returns, + dones, + rewards, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + epsilon=self._epsilon, + normalize_advantages=self._normalize_advantages, + ) + + return tl.Fn("PPOObjective", f) + + @property + def ppo_objective_mean(self): + """PPO objective mean.""" + + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + """Clipped objective from the PPO algorithm.""" + ppo_objective = rl_layers.PPOObjective( + dist_inputs, + values, + returns, + dones, + rewards, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + epsilon=self._epsilon, + normalize_advantages=self._normalize_advantages, + ) + return 
jnp.mean(ppo_objective)
+
+        return tl.Fn("PPOObjectiveMean", f)
+
+
+class A2CJoint(ActorCriticJointAgent):
+    """The A2C algorithm.
+
+    Trains policy and value models using the A2C algorithm.
+    """
+
+    on_policy = True
+
+    def __init__(self, task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs):
+        """Configures the A2C Trainer."""
+        self._value_loss_coeff = value_loss_coeff
+        self._entropy_coeff = entropy_coeff
+        super().__init__(task, **kwargs)
+        self._trainer = supervised.Trainer(
+            model=self._joint_model,
+            optimizer=self._optimizer,
+            lr_schedule=self._lr_schedule,
+            loss_fn=self.joint_loss,
+            inputs=self._inputs,
+            output_dir=self._output_dir,
+            metrics={
+                "joint_loss": self.joint_loss,
+                "advantage_mean": self.advantage_mean,
+                "advantage_norm": self.advantage_norm,
+                "value_loss": self.value_loss,
+                "explained_variance": self.explained_variance,
+                "log_probs_mean": self.log_probs_mean,
+                "entropy_loss": self.entropy_loss,
+                "a2c_objective_mean": self.a2c_objective_mean,
+                "approximate_kl_divergence": self.approximate_kl_divergence,
+                "preferred_move": self.preferred_move,
+            },
+        )
+
+    def batches_stream(self):
+        """Use the RLTask self._task to create inputs to the value model."""
+        for np_trajectory in self._task.trajectory_batch_stream(
+            self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1]
+        ):
+            # Insert an extra depth dimension, so the target shape is consistent
+            # with the network output shape.
+            yield (
+                np_trajectory.observation,  # Inputs to the value model.
+                np_trajectory.return_[:, :, None],
+                np_trajectory.done[:, :, None],
+                np_trajectory.reward[:, :, None],
+                np_trajectory.action,
+                jnp.zeros_like(np_trajectory.mask),
+                np_trajectory.mask,
+            )
+
+    @property
+    def joint_loss(self):
+        """Joint policy and value loss."""
+
+        def f(
+            dist_inputs, values, returns, dones, rewards, actions, old_log_probs, mask
+        ):
+            """Definition of the A2C loss."""
+            del old_log_probs
+
+            # Typically we have dist_inputs of the shape float32[128,1,18]
+            assert len(dist_inputs.shape) == 3, (
+                f"dist_inputs.shape was {dist_inputs.shape} "
+                f"but expected length of the tensor shape is 3"
+            )
+            # values of the shape float32[128,1,1]
+            # returns of the shape float32[128,1,1]
+            assert values.shape == returns.shape, (
+                f"values.shape was {values.shape}, "
+                f"returns.shape was {returns.shape}"
+            )
+            # actions of the shape int32[128,1] in the case of discrete actions
+            # and float32[128,1,6] in the case of half-cheetah;
+            # actions agree with returns/values on the first two coordinates
+            assert actions.shape[0:2] == returns.shape[0:2], (
+                f"actions.shape was {actions.shape}, "
+                f"returns.shape was {returns.shape}"
+            )
+            # and mask of the shape float32[128,1]
+            assert len(mask.shape) == 2, f"mask.shape was {mask.shape}"
+            # which agrees with returns/values/actions on the first two coordinates
+            assert mask.shape[0:2] == returns.shape[0:2], (
+                f"mask.shape was {mask.shape}, "
+                f"returns.shape was {returns.shape}"
+            )
+
+            a2c_objective = rl_layers.A2CObjective(
+                dist_inputs,
+                stop_gradient(values),
+                returns,
+                dones,
+                rewards,
+                actions,
+                mask,
+                log_prob_fun=self._policy_dist.log_prob,
+                normalize_advantages=self._normalize_advantages,
+            )
+
+            # we insist that a2c_objective is a scalar
+            assert jnp.ndim(a2c_objective) == 0, f"a2c_objective was {a2c_objective}"
+
+            entropy_loss = rl_layers.EntropyLoss(
+                dist_inputs,
+                distribution=self._policy_dist,
+                coeff=self._entropy_coeff,
+            )
+
+            assert jnp.ndim(entropy_loss) == 0, f"entropy_loss was {entropy_loss}"
+
l2_value_loss = rl_layers.ValueLoss( + values, returns, value_loss_coeff=self._value_loss_coeff + ) + + assert jnp.ndim(l2_value_loss) == 0, f"l2_value_loss was {l2_value_loss}" + + combined_loss = a2c_objective + l2_value_loss - entropy_loss + + return combined_loss + + return tl.Fn("A2CJointLoss", f, n_out=1) + + @property + def entropy_loss(self): + """Entropy layer.""" + + def f(dist_inputs, values, returns, dones, rewards, actions): + del values, returns, dones, rewards, actions + return rl_layers.EntropyLoss( + dist_inputs, + distribution=self._policy_dist, + coeff=self._entropy_coeff, + ) + + return tl.Fn("EntropyLoss", f) + + @property + def approximate_kl_divergence(self): + """Approximate KL divergence.""" + + def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs): + del values, returns, dones, rewards + return rl_layers.ApproximateKLDivergence( + dist_inputs, + actions, + old_log_probs, + log_prob_fun=self._policy_dist.log_prob, + ) + + return tl.Fn("ApproximateKLDivergence", f) + + @property + def a2c_objective(self): + """A2C objective with local parameters.""" + return tl.Fn( + "A2CObjective", + lambda dist_inputs, values, returns, dones, rewards, actions, old_log_probs, mask: rl_layers.A2CObjective( + dist_inputs, + values, + returns, + dones, + rewards, + actions, + mask, + log_prob_fun=self._policy_dist.log_prob, + normalize_advantages=self._normalize_advantages, + ), + n_out=1, + ) + + @property + def a2c_objective_mean(self): + """A2C objective mean.""" + + def f( + dist_inputs, values, returns, dones, rewards, actions, old_log_probs, mask + ): + """A2C objective mean.""" + # TODO(henrykm): include dones, rewards + del old_log_probs + a2c_objective = rl_layers.A2CObjective( + dist_inputs, + values, + returns, + dones, + rewards, + actions, + mask, + log_prob_fun=self._policy_dist.log_prob, + normalize_advantages=self._normalize_advantages, + ) + return jnp.mean(a2c_objective) + + return tl.Fn("A2CObjectiveMean", f, n_out=1) + + +class AWRJoint(ActorCriticJointAgent): + """Trains a joint policy-and-value model using AWR.""" + + # TODO(henrykm): value_loss_coeff looks like a common parameter + def __init__( + self, + task, + value_loss_coeff=0.1, + beta=1.0, + w_max=20.0, + thresholds=None, + **kwargs, + ): + """Configures the joint AWR Trainer.""" + self._beta = beta + self._w_max = w_max + self._thresholds = thresholds + self._value_loss_coeff = value_loss_coeff + super().__init__(task, **kwargs) + + def batches_stream(self): + """Use the RLTask self._task to create inputs to the value model.""" + for np_trajectory in self._task.trajectory_batch_stream( + self._batch_size, max_slice_length=self._max_slice_length + ): + # Insert an extra depth dimension, so the target shape is consistent with + # the network output shape. + yield ( + np_trajectory.observation, # Inputs to the value model. + np_trajectory.return_[:, :, None], # Targets: regress to returns. + np_trajectory.action, # Policy targets: actions. + np_trajectory.mask, + ) # Padding mask. 
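+    # Note: as an illustrative sketch (not part of the library), the AWR
+    # policy weights applied inside `joint_loss` below reduce to
+    # min(exp(advantage / beta), w_max) when no thresholds are configured,
+    # mirroring `awr_weights` in actor_critic.py.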
+
+    @property
+    def joint_loss(self):
+        """Joint policy and value loss."""
+
+        def f(preds, values, returns, actions, mask):
+            advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
+            logps = self._policy_dist.log_prob(preds, actions)
+            awr_loss = actor_critic.AWRLoss(
+                beta=self._beta, w_max=self._w_max, thresholds=self._thresholds
+            )((logps, advantages, jnp.zeros_like(logps), mask))
+            l2_value_loss = jnp.mean((returns - values) ** 2) * self._value_loss_coeff
+            return awr_loss + l2_value_loss
+
+        return tl.Fn("AWRJointLoss", f)
diff --git a/trax/learning/reinforcement/advantages.py b/trax/learning/reinforcement/advantages.py
new file mode 100644
index 000000000..7ebf5a414
--- /dev/null
+++ b/trax/learning/reinforcement/advantages.py
@@ -0,0 +1,184 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""RL advantage estimators."""
+
+import gin
+import numpy as np
+
+from trax import fastmath
+
+common_args = ["gamma", "margin"]
+
+
+def mask_discount(discount, discount_mask):
+    """Computes a discount to apply at a given timestep, based on the mask."""
+    return fastmath.numpy.where(discount_mask, discount, 1.0)
+
+
+def discounted_returns(rewards, gammas):
+    """Computes discounted returns for a trajectory or a batch of them."""
+    returns = np.zeros_like(rewards)
+    ret = 0.0
+    for i in reversed(range(rewards.shape[-1])):
+        ret = rewards[..., i] + gammas[..., i] * ret
+        returns[..., i] = ret
+    return returns
+
+
+@gin.configurable(denylist=common_args)
+def monte_carlo(gamma, margin):
+    """Calculate Monte Carlo advantage.
+
+    We assume the values are a tensor of shape [batch_size, length] and this
+    is the same shape as rewards and returns.
+
+    Args:
+      gamma: float, gamma parameter for TD from the underlying task
+      margin: number of extra steps in the sequence
+
+    Returns:
+      Function (rewards, returns, values, dones, discount_mask) -> advantages,
+      where advantages is an array of shape [batch_size, length - margin].
+    """
+    del gamma
+
+    def estimator(rewards, returns, values, dones, discount_mask):
+        del discount_mask
+        (_, length) = returns.shape
+        # Make sure that the future returns and values at "done" states are zero.
+        returns[dones] = rewards[dones]
+        values[dones] = 0
+        return (returns - values)[:, : (length - margin)]
+
+    return estimator
+
+
+@gin.configurable(denylist=common_args)
+def td_k(gamma, margin):
+    """Calculate TD-k advantage.
+
+    The k parameter is assumed to be the same as margin.
+
+    We calculate advantage(s_i) as:
+
+      gamma^n_steps * value(s_{i + n_steps}) - value(s_i) + discounted_rewards
+
+    where discounted_rewards is the sum of rewards in these steps with
+    discounting by powers of gamma.
+
+    Args:
+      gamma: float, gamma parameter for TD from the underlying task
+      margin: number of extra steps in the sequence
+
+    Returns:
+      Function (rewards, returns, values, dones, discount_mask) -> advantages,
+      where advantages is an array of shape [batch_size, length - margin].
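+
+    For example, with margin = k = 1 this reduces to the one-step TD
+    advantage: rewards[:, i] + gamma * values[:, i + 1] - values[:, i].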
+ """ + + def estimator(rewards, returns, values, dones, discount_mask): + del returns + gammas = mask_discount(gamma, discount_mask) + # Here we calculate advantage with TD-k, where k=margin. + k = margin + assert k > 0 + advantages = np.zeros_like(values[:, k:]) + discount = 1.0 + for i in range(margin): + advantages += discount * rewards[:, i : -(margin - i)] + discount *= gammas[:, i : -(margin - i)] + advantages += discount * values[:, k:] + # Zero out the future returns at "done" states. + dones = dones[:, :-k] + # TPU friendly version of the formula + # advantages[dones] = rewards[:, :-k][dones] + advantages = fastmath.index_update(advantages, dones, rewards[:, :-k][dones]) + # Subtract the baseline (value). + advantages -= values[:, :-k] + return advantages + + return estimator + + +@gin.configurable(denylist=common_args) +def td_lambda(gamma, margin, lambda_=0.95): + """Calculate TD-lambda advantage. + + The estimated return is an exponentially-weighted average of different TD-k + returns. + + Args: + gamma: float, gamma parameter for TD from the underlying task + margin: number of extra steps in the sequence + lambda_: float, the lambda parameter of TD-lambda + + Returns: + Function (rewards, returns, values, dones) -> advantages, where advantages + advantages is an array of shape [batch_size, length - margin]. + """ + + def estimator(rewards, returns, values, dones, discount_mask): + gammas = mask_discount(gamma, discount_mask) + lambdas = mask_discount(lambda_, discount_mask) + td_returns = np.zeros_like(returns) + (_, length) = returns.shape + td_returns[:, -1] = values[:, -1] + for i in reversed(range(length - 1)): + lambda_i = lambdas[:, i] + td_returns[:, i] = rewards[:, i] + (1 - dones[:, i]) * gammas[:, i] * ( + (1 - lambda_i) * values[:, i + 1] + lambda_i * td_returns[:, i + 1] + ) + return (td_returns - values)[:, : (returns.shape[1] - margin)] + + return estimator + + +@gin.configurable(denylist=common_args) +def gae(gamma, margin, lambda_=0.95): + """Calculate Generalized Advantage Estimation. + + Calculate state values bootstrapping off the following state values - + Generalized Advantage Estimation https://arxiv.org/abs/1506.02438 + + Args: + gamma: float, gamma parameter for TD from the underlying task + margin: number of extra steps in the sequence + lambda_: float, the lambda parameter of GAE + + Returns: + Function (rewards, returns, values, dones) -> advantages, where advantages + advantages is an array of shape [batch_size, length - margin]. 
+ """ + + def estimator(rewards, returns, values, dones, discount_mask): + del returns + gammas = mask_discount(gamma, discount_mask) + lambdas = mask_discount(lambda_, discount_mask) + advantages = np.zeros_like(rewards) + (_, length) = rewards.shape + + for i in reversed(range(length - 1)): + bellman_delta = ( + rewards[:, i] + - values[:, i] + + (1 - dones[:, i]) * (gammas[:, i] * values[:, i + 1]) + ) + advantages[:, i] = bellman_delta + (1 - dones[:, i]) * ( + gammas[:, i] * lambdas[:, i] * advantages[:, i + 1] + ) + + return advantages[:, : (rewards.shape[1] - margin)] + + return estimator diff --git a/trax/rl/atari_test.py b/trax/learning/reinforcement/atari_test.py similarity index 71% rename from trax/rl/atari_test.py rename to trax/learning/reinforcement/atari_test.py index 2bf552b90..3f8f4fee3 100644 --- a/trax/rl/atari_test.py +++ b/trax/learning/reinforcement/atari_test.py @@ -15,19 +15,7 @@ """Tests for RL training.""" -import functools - from absl.testing import absltest -from trax import models -from trax import optimizers as opt -from trax.models import atari_cnn -from trax.rl import actor_critic -from trax.rl import task as rl_task -from trax.supervised import lr_schedules - - - - -if __name__ == '__main__': - absltest.main() +if __name__ == "__main__": + absltest.main() diff --git a/trax/rl/configs/dqn_cartpole_regression.gin b/trax/learning/reinforcement/configs/dqn_cartpole_regression.gin similarity index 100% rename from trax/rl/configs/dqn_cartpole_regression.gin rename to trax/learning/reinforcement/configs/dqn_cartpole_regression.gin diff --git a/trax/rl/configs/light_a2c_joint_atari_sweep.yaml b/trax/learning/reinforcement/configs/light_a2c_joint_atari_sweep.yaml similarity index 100% rename from trax/rl/configs/light_a2c_joint_atari_sweep.yaml rename to trax/learning/reinforcement/configs/light_a2c_joint_atari_sweep.yaml diff --git a/trax/rl/configs/light_atari.gin b/trax/learning/reinforcement/configs/light_atari.gin similarity index 100% rename from trax/rl/configs/light_atari.gin rename to trax/learning/reinforcement/configs/light_atari.gin diff --git a/trax/rl/configs/light_atari_sweep.yaml b/trax/learning/reinforcement/configs/light_atari_sweep.yaml similarity index 100% rename from trax/rl/configs/light_atari_sweep.yaml rename to trax/learning/reinforcement/configs/light_atari_sweep.yaml diff --git a/trax/rl/configs/light_awr_cartpole_sweep.yaml b/trax/learning/reinforcement/configs/light_awr_cartpole_sweep.yaml similarity index 100% rename from trax/rl/configs/light_awr_cartpole_sweep.yaml rename to trax/learning/reinforcement/configs/light_awr_cartpole_sweep.yaml diff --git a/trax/rl/configs/light_awr_joint_atari_sweep.yaml b/trax/learning/reinforcement/configs/light_awr_joint_atari_sweep.yaml similarity index 100% rename from trax/rl/configs/light_awr_joint_atari_sweep.yaml rename to trax/learning/reinforcement/configs/light_awr_joint_atari_sweep.yaml diff --git a/trax/rl/configs/light_awr_joint_cartpole.gin b/trax/learning/reinforcement/configs/light_awr_joint_cartpole.gin similarity index 100% rename from trax/rl/configs/light_awr_joint_cartpole.gin rename to trax/learning/reinforcement/configs/light_awr_joint_cartpole.gin diff --git a/trax/rl/configs/light_awr_joint_cartpole_sweep.yaml b/trax/learning/reinforcement/configs/light_awr_joint_cartpole_sweep.yaml similarity index 100% rename from trax/rl/configs/light_awr_joint_cartpole_sweep.yaml rename to trax/learning/reinforcement/configs/light_awr_joint_cartpole_sweep.yaml diff --git 
a/trax/rl/configs/light_cartpole.gin b/trax/learning/reinforcement/configs/light_cartpole.gin similarity index 100% rename from trax/rl/configs/light_cartpole.gin rename to trax/learning/reinforcement/configs/light_cartpole.gin diff --git a/trax/rl/configs/light_copy.gin b/trax/learning/reinforcement/configs/light_copy.gin similarity index 100% rename from trax/rl/configs/light_copy.gin rename to trax/learning/reinforcement/configs/light_copy.gin diff --git a/trax/rl/configs/light_copy_sweep.yaml b/trax/learning/reinforcement/configs/light_copy_sweep.yaml similarity index 100% rename from trax/rl/configs/light_copy_sweep.yaml rename to trax/learning/reinforcement/configs/light_copy_sweep.yaml diff --git a/trax/rl/configs/light_joint_atari.gin b/trax/learning/reinforcement/configs/light_joint_atari.gin similarity index 100% rename from trax/rl/configs/light_joint_atari.gin rename to trax/learning/reinforcement/configs/light_joint_atari.gin diff --git a/trax/rl/configs/light_joint_cartpole.gin b/trax/learning/reinforcement/configs/light_joint_cartpole.gin similarity index 100% rename from trax/rl/configs/light_joint_cartpole.gin rename to trax/learning/reinforcement/configs/light_joint_cartpole.gin diff --git a/trax/rl/configs/light_lunarlander.gin b/trax/learning/reinforcement/configs/light_lunarlander.gin similarity index 100% rename from trax/rl/configs/light_lunarlander.gin rename to trax/learning/reinforcement/configs/light_lunarlander.gin diff --git a/trax/rl/configs/light_mujoco.gin b/trax/learning/reinforcement/configs/light_mujoco.gin similarity index 100% rename from trax/rl/configs/light_mujoco.gin rename to trax/learning/reinforcement/configs/light_mujoco.gin diff --git a/trax/rl/configs/light_mujoco_regression_test.gin b/trax/learning/reinforcement/configs/light_mujoco_regression_test.gin similarity index 100% rename from trax/rl/configs/light_mujoco_regression_test.gin rename to trax/learning/reinforcement/configs/light_mujoco_regression_test.gin diff --git a/trax/rl/configs/light_mujoco_sweep.yaml b/trax/learning/reinforcement/configs/light_mujoco_sweep.yaml similarity index 100% rename from trax/rl/configs/light_mujoco_sweep.yaml rename to trax/learning/reinforcement/configs/light_mujoco_sweep.yaml diff --git a/trax/rl/configs/light_ppo_atari.gin b/trax/learning/reinforcement/configs/light_ppo_atari.gin similarity index 100% rename from trax/rl/configs/light_ppo_atari.gin rename to trax/learning/reinforcement/configs/light_ppo_atari.gin diff --git a/trax/rl/configs/light_ppo_boxing_regression_test.gin b/trax/learning/reinforcement/configs/light_ppo_boxing_regression_test.gin similarity index 100% rename from trax/rl/configs/light_ppo_boxing_regression_test.gin rename to trax/learning/reinforcement/configs/light_ppo_boxing_regression_test.gin diff --git a/trax/rl/configs/light_ppo_cartpole_regression_test.gin b/trax/learning/reinforcement/configs/light_ppo_cartpole_regression_test.gin similarity index 100% rename from trax/rl/configs/light_ppo_cartpole_regression_test.gin rename to trax/learning/reinforcement/configs/light_ppo_cartpole_regression_test.gin diff --git a/trax/rl/configs/light_ppo_half_cheetah_regression_test.gin b/trax/learning/reinforcement/configs/light_ppo_half_cheetah_regression_test.gin similarity index 100% rename from trax/rl/configs/light_ppo_half_cheetah_regression_test.gin rename to trax/learning/reinforcement/configs/light_ppo_half_cheetah_regression_test.gin diff --git a/trax/rl/configs/light_ppo_joint_atari.gin 
b/trax/learning/reinforcement/configs/light_ppo_joint_atari.gin similarity index 100% rename from trax/rl/configs/light_ppo_joint_atari.gin rename to trax/learning/reinforcement/configs/light_ppo_joint_atari.gin diff --git a/trax/rl/configs/light_ppo_joint_atari_sweep.yaml b/trax/learning/reinforcement/configs/light_ppo_joint_atari_sweep.yaml similarity index 100% rename from trax/rl/configs/light_ppo_joint_atari_sweep.yaml rename to trax/learning/reinforcement/configs/light_ppo_joint_atari_sweep.yaml diff --git a/trax/rl/configs/light_ppo_lunar_lander_regression_test.gin b/trax/learning/reinforcement/configs/light_ppo_lunar_lander_regression_test.gin similarity index 100% rename from trax/rl/configs/light_ppo_lunar_lander_regression_test.gin rename to trax/learning/reinforcement/configs/light_ppo_lunar_lander_regression_test.gin diff --git a/trax/rl/configs/light_ppo_pong_regression_test.gin b/trax/learning/reinforcement/configs/light_ppo_pong_regression_test.gin similarity index 100% rename from trax/rl/configs/light_ppo_pong_regression_test.gin rename to trax/learning/reinforcement/configs/light_ppo_pong_regression_test.gin diff --git a/trax/rl/configs/ppo_atari_sweep.yaml b/trax/learning/reinforcement/configs/ppo_atari_sweep.yaml similarity index 100% rename from trax/rl/configs/ppo_atari_sweep.yaml rename to trax/learning/reinforcement/configs/ppo_atari_sweep.yaml diff --git a/trax/rl/configs/ppo_cartpole_sweep.yaml b/trax/learning/reinforcement/configs/ppo_cartpole_sweep.yaml similarity index 100% rename from trax/rl/configs/ppo_cartpole_sweep.yaml rename to trax/learning/reinforcement/configs/ppo_cartpole_sweep.yaml diff --git a/trax/rl/configs/transformer_srl_sine.gin b/trax/learning/reinforcement/configs/transformer_srl_sine.gin similarity index 100% rename from trax/rl/configs/transformer_srl_sine.gin rename to trax/learning/reinforcement/configs/transformer_srl_sine.gin diff --git a/trax/learning/reinforcement/distributions.py b/trax/learning/reinforcement/distributions.py new file mode 100644 index 000000000..ca6beb58a --- /dev/null +++ b/trax/learning/reinforcement/distributions.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Probability distributions for RL training in Trax.""" + +import gin +import gym +import numpy as np + +from trax import layers as tl +from trax.fastmath import numpy as jnp + + +class Distribution: + """Abstract class for parametrized probability distributions.""" + + @property + def n_inputs(self): + """Returns the number of inputs to the distribution (i.e. parameters).""" + raise NotImplementedError + + def sample(self, inputs, temperature=1.0): + """Samples a point from the distribution. + + Args: + inputs (jnp.ndarray): Distribution inputs. Shape is subclass-specific. + Broadcasts along the first dimensions. For example, in the categorical + distribution parameter shape is (C,), where C is the number of + categories. 
If (B, C) is passed, the object will represent a batch of B + categorical distributions with different parameters. + temperature: sampling temperature; 1.0 is default, at 0.0 chooses + the most probable (preferred) action. + + Returns: + Sampled point of shape dependent on the subclass and on the shape of + inputs. + """ + raise NotImplementedError + + def log_prob(self, inputs, point): + """Retrieves log probability (or log probability density) of a point. + + Args: + inputs (jnp.ndarray): Distribution parameters. + point (jnp.ndarray): Point from the distribution. Shape should be + consistent with inputs. + + Returns: + Array of log probabilities of points in the distribution. + """ + raise NotImplementedError + + def LogProb(self): # pylint: disable=invalid-name + """Builds a log probability layer for this distribution.""" + return tl.Fn( + "LogProb", lambda inputs, point: self.log_prob(inputs, point) + ) # pylint: disable=unnecessary-lambda + + +@gin.configurable(denylist=["n_categories", "shape"]) +class Categorical(Distribution): + """Categorical distribution parametrized by logits.""" + + def __init__(self, n_categories, shape=()): + """Initializes Categorical distribution. + + Args: + n_categories (int): Number of categories. + shape (tuple): Shape of the sample. + """ + self._n_categories = n_categories + self._shape = shape + + @property + def n_inputs(self): + return np.prod(self._shape, dtype=jnp.int32) * self._n_categories + + def _unflatten_inputs(self, inputs): + return jnp.reshape( + inputs, inputs.shape[:-1] + self._shape + (self._n_categories,) + ) + + def sample(self, inputs, temperature=1.0): + # No need for LogSoftmax with sampling - softmax normalization is + # subtracting a constant from every logit, and sampling is taking + # a max over logits plus noise, so invariant to adding a constant. + if temperature == 0.0: + return jnp.argmax(self._unflatten_inputs(inputs), axis=-1) + return tl.logsoftmax_sample(self._unflatten_inputs(inputs), temperature) + + def log_prob(self, inputs, point): + inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs)) + return jnp.sum( + # Select the logits specified by point. + inputs * tl.one_hot(point, self._n_categories), + # Sum over the parameter dimensions. + axis=[-a for a in range(1, len(self._shape) + 2)], + ) + + def entropy(self, inputs): + log_probs = tl.LogSoftmax()(inputs) + probs = jnp.exp(log_probs) + return -jnp.sum(probs * log_probs, axis=-1) + + +@gin.configurable(denylist=["shape"]) +class Gaussian(Distribution): + """Independent multivariate Gaussian distribution parametrized by mean.""" + + def __init__(self, shape=(), std=1.0, learn_std=None): + """Initializes Gaussian distribution. + + Args: + shape (tuple): Shape of the sample. + std (float): Standard deviation, shared across the whole sample. + learn_std (str or None): How to learn the standard deviation - 'shared' + to have a single, shared std parameter, or 'separate' to have separate + parameters for each dimension. + """ + self._shape = shape + self._std = std + self._learn_std = learn_std + + @property + def _n_dims(self): + return np.prod(self._shape, dtype=jnp.int32) + + def _params(self, inputs): + """Extracts the mean and std parameters from the inputs.""" + if inputs.shape[-1] != self.n_inputs: + raise ValueError( + "Invalid distribution parametrization - expected {} parameters, " + "got {}. 
Input shape: {}.".format(
+                    self.n_inputs, inputs.shape[-1], inputs.shape
+                )
+            )
+        n_dims = self._n_dims
+        # Split the distribution inputs into two parts: mean and std.
+        mean = inputs[..., :n_dims]
+        if self._learn_std is not None:
+            std = inputs[..., n_dims:]
+            # Std is non-negative, so let's softplus it.
+            std = tl.Softplus()(std + self._std)
+        else:
+            std = self._std
+        # In case of constant or shared std, upsample it to the same dimensionality
+        # as the means.
+        std = jnp.broadcast_to(std, mean.shape)
+        return (mean, std)
+
+    @property
+    def n_inputs(self):
+        n_dims = self._n_dims
+        return {
+            None: n_dims,
+            "shared": n_dims + 1,
+            "separate": n_dims * 2,
+        }[self._learn_std]
+
+    def sample(self, inputs, temperature=1.0):
+        (mean, std) = self._params(inputs)
+        mean = jnp.reshape(mean, mean.shape[:-1] + self._shape)
+        std = jnp.reshape(std, std.shape[:-1] + self._shape)
+        if temperature == 0:
+            # This seemingly strange if solves the problem of calling
+            # np/jnp.random in the metric PreferredMove.
+            return mean
+        else:
+            return np.random.normal(loc=mean, scale=(std * temperature))
+
+    def log_prob(self, inputs, point):
+        point = point.reshape(inputs.shape[:-1] + (-1,))
+        (mean, std) = self._params(inputs)
+        return -jnp.sum(
+            # Scaled distance.
+            (point - mean) ** 2 / (2 * std**2)
+            +
+            # Normalizing constant.
+            (jnp.log(std) + jnp.log(jnp.sqrt(2 * jnp.pi))),
+            axis=-1,
+        )
+
+    def entropy(self, inputs):
+        (_, std) = self._params(inputs)
+        # Differential entropy of an independent Gaussian is
+        # log(std) + 0.5 * log(2 * pi * e) per dimension.
+        return jnp.sum(jnp.log(std) + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e), axis=-1)
+
+
+# TODO(pkozakowski): Implement GaussianMixture.
+
+
+def create_distribution(space):
+    """Creates a Distribution for the given Gym space."""
+    if isinstance(space, gym.spaces.Discrete):
+        return Categorical(shape=(), n_categories=space.n)
+    elif isinstance(space, gym.spaces.MultiDiscrete):
+        assert space.nvec.size
+        assert min(space.nvec) == max(space.nvec), (
+            "Every dimension must have the same number of categories, got "
+            "{}.".format(space.nvec)
+        )
+        return Categorical(shape=(len(space.nvec),), n_categories=space.nvec[0])
+    elif isinstance(space, gym.spaces.Box):
+        return Gaussian(shape=space.shape)
+    else:
+        raise TypeError(
+            "Space {} unavailable as a distribution support.".format(space)
+        )
+
+
+def LogLoss(distribution, **unused_kwargs):  # pylint: disable=invalid-name
+    """Builds a log loss layer for a Distribution."""
+    return tl.Serial(distribution.LogProb(), tl.Negate(), tl.WeightedSum())
diff --git a/trax/rl/envs/__init__.py b/trax/learning/reinforcement/envs/__init__.py
similarity index 86%
rename from trax/rl/envs/__init__.py
rename to trax/learning/reinforcement/envs/__init__.py
index 6cf568ba3..cdb2ae89f 100644
--- a/trax/rl/envs/__init__.py
+++ b/trax/learning/reinforcement/envs/__init__.py
@@ -16,12 +16,13 @@
 """Trax RL environments library."""
 
 import gin
-from trax.rl.envs import data_envs
+
+from . import data_envs
 
 
 def configure_rl_env(*args, **kwargs):
-    kwargs['module'] = 'trax.rl.envs'
-    return gin.external_configurable(*args, **kwargs)
+    kwargs["module"] = "trax.reinforcement.envs"
+    return gin.external_configurable(*args, **kwargs)
 
 
 copy_stream = configure_rl_env(data_envs.copy_stream)
diff --git a/trax/learning/reinforcement/envs/data_envs.py b/trax/learning/reinforcement/envs/data_envs.py
new file mode 100644
index 000000000..f6839907f
--- /dev/null
+++ b/trax/learning/reinforcement/envs/data_envs.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
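To make the Distribution interface above concrete, here is a minimal usage sketch (not part of the patch; shapes and values are assumed for illustration, and the import path follows the new layout introduced by this change). create_distribution returns a Categorical for Discrete spaces and a Gaussian for Box spaces; sample and log_prob broadcast over leading batch dimensions:

    import gym
    import numpy as np

    from trax.learning.reinforcement import distributions

    # Discrete(4) -> Categorical over 4 categories, one logit per category.
    dist = distributions.create_distribution(gym.spaces.Discrete(4))
    assert dist.n_inputs == 4

    logits = np.zeros((2, dist.n_inputs), dtype=np.float32)  # batch of 2
    actions = dist.sample(logits, temperature=1.0)           # shape (2,)
    log_p = dist.log_prob(logits, actions)                   # shape (2,), each ~log(1/4)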
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""RL environments created from supervised datasets."""
+
+import gym
+import numpy as np
+
+
+class SequenceDataEnv(object):
+    """RL environment created from a generator of sequential data.
+
+    This class allows creating RL environments from supervised sequential data,
+    such as tokenized natural language processing tasks. The data comes as:
+    (input1, output1, input2, output2, ...)
+    where inputs and outputs are all sequences of integers.
+
+    For example, with input (2, 3) and output (4, 5), so data = [(2, 3), (4, 5)],
+    the sequence of (observations, rewards, actions) will look like:
+    2 = env.reset()  # first observation
+    3, 0.0, _, _ = env.step(ignored_action)
+    eos, 0.0, _, _ = env.step(ignored_action)
+    act1, 0.0, _, _ = env.step(act1)  # observation = action
+    act2, 0.0, _, _ = env.step(act2)  # observation = action
+    eos, score, _, _ = env.step(eos)
+
+    where score = metric((4, 5), (act1, act2)) is the reward obtained by
+    comparing the two actions to the actual output from the data.
+
+    The environment first presents the input as observations, doing this
+    sequentially, token-by-token, and ignoring all actions taken by the policy.
+    Then, the policy is asked to generate the response, again, token-by-token,
+    until it generates EOS. Generated tokens are repeated as observations.
+    When EOS is encountered, a metric is computed between the generated
+    output and the output from data, and this metric is returned as reward.
+    """
+
+    def __init__(self, data_stream, vocab_size, metric=None, eos_id=1, max_length=1000):
+        """The constructor.
+
+        Args:
+          data_stream: A python generator creating lists or tuples of
+            sequences (list, tuples or numpy arrays) of integers.
+          vocab_size: Integer, the size of the vocabulary. All integers in the
+            data stream must be positive and smaller than this value.
+          metric: A function taking two lists of integers and returning a float.
+            If None, we use per-token accuracy as the default metric.
+          eos_id: Integer, the id of the EOS symbol.
+          max_length: Integer, maximum length of the policy reply, to avoid
+            infinite episodes if the policy never produces EOS.
+
+        Returns:
+          A new environment which presents the data and compares the policy
+          response with the expected data, returning metric as reward.
+ """ + self._data = data_stream + self._vocab_size = vocab_size + self._eos = eos_id + self._max_length = max_length + self._metric = _accuracy if metric is None else metric + self.reset() + + @property + def _on_input(self): + """Return True if we're currently processing input, False if output.""" + cur_sequence_id, _ = self._cur_position + return cur_sequence_id % 2 == 0 + + @property + def observation(self): + cur_sequence_id, cur_token_id = self._cur_position + if cur_sequence_id >= len(self._cur_sequence): + obs = self._eos + elif self._on_input: + obs = self._cur_sequence[cur_sequence_id][cur_token_id] + else: + obs = self._response[-1] if self._response else self._eos + return np.array(int(obs), dtype=np.int32) + + @property + def action_space(self): + return gym.spaces.Discrete(self._vocab_size) + + @property + def observation_space(self): + return gym.spaces.Discrete(self._vocab_size) + + def reset(self): + """Reset this environment.""" + self._cur_sequence = next(self._data) + # Position contains 2 indices: which sequnece are we in? (input1, output1, + # input2, output2 and so on) and which token in the sequence are we in? + self._cur_position = (0, 0) + self._response = [] + return self.observation + + def step(self, action): + """Single step of the environment when policy took `action`.""" + cur_sequence_id, cur_token_id = self._cur_position + if cur_sequence_id >= len(self._cur_sequence): + return np.array(self._eos, dtype=np.int32), 0.0, True, None + + # Emit the control mask on the output. + control_mask = int(not self._on_input) + + if self._on_input: + self._response = [] + if cur_token_id + 1 < len(self._cur_sequence[cur_sequence_id]): + self._cur_position = (cur_sequence_id, cur_token_id + 1) + done = False + else: + self._cur_position = (cur_sequence_id + 1, 0) + done = cur_sequence_id + 1 >= len(self._cur_sequence) + reward = 0.0 + discount_mask = 0 + + else: + self._response.append(action) + if action == self._eos or len(self._response) > self._max_length: + self._cur_position = (cur_sequence_id + 1, 0) + reward = self._metric( + self._response[:-1], self._cur_sequence[cur_sequence_id] + ) + done = cur_sequence_id + 1 >= len(self._cur_sequence) + # Emit the discount mask on the last token of each action. + discount_mask = 1 + else: + reward = 0.0 + done = False + discount_mask = 0 + + info = {"control_mask": control_mask, "discount_mask": discount_mask} + return self.observation, reward, done, info + + +def copy_stream(length, low=2, high=15, n=1): + """Generate `n` random sequences of length `length` and yield with copies.""" + while True: + res = [] + for _ in range(n): + seq = np.random.randint(low, high, size=(length,), dtype=np.int32) + res.extend([seq, seq]) + yield res + + +def _accuracy(seq1, seq2): + """Token-level accuracy.""" + seq1, seq2 = np.array(seq1), np.array(seq2) + max_length = max(seq1.shape[-1], seq2.shape[-1]) + min_length = min(seq1.shape[-1], seq2.shape[-1]) + seq1s, seq2s = seq1[..., :min_length], seq2[..., :min_length] + return np.sum(np.equal(seq1s, seq2s)) / max_length diff --git a/trax/learning/reinforcement/normalization.py b/trax/learning/reinforcement/normalization.py new file mode 100644 index 000000000..5b192d9d1 --- /dev/null +++ b/trax/learning/reinforcement/normalization.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Normalization helpers.""" + +import gin +import numpy as np + +from trax import fastmath +from trax import layers as tl + + +def running_mean_init(shape, fill_value=0): + return (np.full(shape, fill_value), np.array(0)) + + +def running_mean_update(x, state): + (mean, n) = state + mean = n.astype(np.float32) / (n + 1) * mean + x / (n + 1) + return (mean, n + 1) + + +def running_mean_get_mean(state): + (mean, _) = state + return mean + + +def running_mean_get_count(state): + (_, count) = state + return count + + +def running_mean_and_variance_init(shape): + mean_state = running_mean_init(shape, fill_value=0.0) + var_state = running_mean_init(shape, fill_value=1.0) + return (mean_state, var_state) + + +def running_mean_and_variance_update(x, state): + (mean_state, var_state) = state + old_mean = running_mean_get_mean(mean_state) + mean_state = running_mean_update(x, mean_state) + new_mean = running_mean_get_mean(mean_state) + + var_state = running_mean_update((x - new_mean) * (x - old_mean), var_state) + + return (mean_state, var_state) + + +def running_mean_and_variance_get_mean(state): + (mean_state, _) = state + return running_mean_get_mean(mean_state) + + +def running_mean_and_variance_get_count(state): + (mean_state, _) = state + return running_mean_get_count(mean_state) + + +def running_mean_and_variance_get_variance(state): + (_, var_state) = state + return running_mean_get_mean(var_state) + + +@gin.configurable(denylist=["mode"]) +class Normalize(tl.Layer): + """Numerically stable normalization layer.""" + + def __init__(self, sample_limit=float("+inf"), epsilon=1e-5, mode="train"): + super().__init__() + self._sample_limit = sample_limit + self._epsilon = epsilon + self._mode = mode + + def init_weights_and_state(self, input_signature): + self.state = running_mean_and_variance_init(input_signature.shape[2:]) + + def forward(self, inputs): + state = self.state + observations = inputs + if self._mode == "collect": + # Accumulate statistics only in the collect mode, i.e. when collecting + # data using the agent. + for observation in observations[:, -1]: # (batch_size, time, ...) + # Update statistics for each observation separately for simplicity. + # Currently during data collection the batch size is 1 anyway. 
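+            # Worked example of running_mean_update (assumed values): with
+            # count n = 2 and mean 1.5, observing x = 3.0 gives
+            # mean' = (2 / 3) * 1.5 + 3.0 / 3 = 2.0 - the mean of all three
+            # samples.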
+ count = running_mean_and_variance_get_count(state) + state = fastmath.cond( + count < self._sample_limit, + true_operand=(observation, state), + true_fun=lambda args: running_mean_and_variance_update(*args), + false_operand=None, + false_fun=lambda _: state, + ) + + mean = running_mean_and_variance_get_mean(state) + var = running_mean_and_variance_get_variance(state) + norm_observations = (observations - mean) / (var**0.5 + self._epsilon) + self.state = state + return norm_observations + + +@gin.configurable(denylist=["mode"]) +def LayerNormSquash(mode, width=128): # pylint: disable=invalid-name + """Dense-LayerNorm-Tanh normalizer inspired by ACME.""" + # https://github.com/deepmind/acme/blob/master/acme/jax/networks/continuous.py#L34 + del mode + return tl.Serial( + [ + tl.Dense(width), + tl.LayerNorm(), + tl.Tanh(), + ] + ) diff --git a/trax/learning/reinforcement/policy_tasks.py b/trax/learning/reinforcement/policy_tasks.py new file mode 100644 index 000000000..64076e7ac --- /dev/null +++ b/trax/learning/reinforcement/policy_tasks.py @@ -0,0 +1,266 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy network training tasks. + +Policy tasks encapsulate the training process of a policy network into a simple, +replaceable component. To implement a policy-based Agent using policy tasks: + + 1. Subclass the base Agent class. + 2. In __init__(), initialize the policy training and evaluation tasks, and + a trax.supervised.training.Loop instance using them. + 3. In train_epoch(), call the Loop to train the network. + 4. In policy(), call network_policy() defined in this module. +""" + +import numpy as np + +from trax import layers as tl +from trax.fastmath import numpy as jnp +from trax.learning.reinforcement import distributions +from trax.learning.supervised import training + + +class PolicyTrainTask(training.TrainTask): + """Task for policy training. + + Trains the policy based on action advantages. + """ + + def __init__( + self, + trajectory_batch_stream, + optimizer, + lr_schedule, + policy_distribution, + advantage_estimator, + value_fn, + weight_fn=(lambda x: x), + advantage_normalization=True, + advantage_normalization_epsilon=1e-5, + head_selector=(), + ): + """Initializes PolicyTrainTask. + + Args: + trajectory_batch_stream: Generator of trax.reinforcement.task.TimeStepBatch. + optimizer: Optimizer for network training. + lr_schedule: Learning rate schedule for network training. + policy_distribution: Distribution over actions. + advantage_estimator: Function + (rewards, returns, values, dones) -> advantages, created by one of the + functions from trax.reinforcement.advantages. + value_fn: Function TimeStepBatch -> array (batch_size, seq_len) + calculating the baseline for advantage calculation. Can be used to + implement actor-critic algorithms, by substituting a call to the value + network as value_fn. + weight_fn: Function float -> float to apply to advantages. 
Examples:
+            - A2C: weight_fn = id
+            - AWR: weight_fn = exp
+            - behavioral cloning: weight_fn(_) = 1
+          advantage_normalization: Whether to normalize advantages.
+          advantage_normalization_epsilon: Epsilon to use when normalizing
+            advantages.
+          head_selector: Layer to apply to the network output to select the
+            policy head. Only needed in multitask training. By default, use a
+            no-op layer, signified by an empty sequence of layers, ().
+        """
+        self.trajectory_batch_stream = trajectory_batch_stream
+        self._value_fn = value_fn
+        self._advantage_estimator = advantage_estimator
+        self._weight_fn = weight_fn
+        self._advantage_normalization = advantage_normalization
+        self._advantage_normalization_epsilon = advantage_normalization_epsilon
+        self.policy_distribution = policy_distribution
+
+        labeled_data = map(self.policy_batch, trajectory_batch_stream)
+        sample_batch = self.policy_batch(next(trajectory_batch_stream), shape_only=True)
+        loss_layer = distributions.LogLoss(distribution=policy_distribution)
+        loss_layer = tl.Serial(head_selector, loss_layer)
+        super().__init__(
+            labeled_data,
+            loss_layer,
+            optimizer,
+            sample_batch=sample_batch,
+            lr_schedule=lr_schedule,
+            loss_name="policy_loss",
+        )
+
+    def calculate_advantages(self, trajectory_batch, shape_only=False):
+        (batch_size, seq_len) = trajectory_batch.observation.shape[:2]
+        assert trajectory_batch.action.shape[:2] == (batch_size, seq_len)
+        assert trajectory_batch.mask.shape == (batch_size, seq_len)
+        if shape_only:
+            values = np.zeros((batch_size, seq_len))
+        else:
+            # Compute the value, i.e. baseline in advantage computation.
+            values = np.array(self._value_fn(trajectory_batch))
+        assert values.shape == (batch_size, seq_len)
+        # Compute the advantages using the chosen advantage estimator.
+        return self._advantage_estimator(
+            rewards=trajectory_batch.reward,
+            returns=trajectory_batch.return_,
+            dones=trajectory_batch.done,
+            values=values,
+            discount_mask=trajectory_batch.env_info.discount_mask,
+        )
+
+    def calculate_weights(self, advantages):
+        """Calculates advantage-based weights for log loss in policy training."""
+        if self._advantage_normalization:
+            # Normalize advantages.
+            advantages -= jnp.mean(advantages)
+            advantage_std = jnp.std(advantages)
+            advantages /= advantage_std + self._advantage_normalization_epsilon
+        weights = self._weight_fn(advantages)
+        assert weights.shape == advantages.shape
+        return weights
+
+    def trim_and_mask_batch(self, trajectory_batch, advantages):
+        (batch_size, seq_len) = trajectory_batch.observation.shape[:2]
+        adv_seq_len = advantages.shape[1]
+        # The advantage sequence should be shorter by the margin. Margin is the
+        # number of timesteps added to the trajectory slice, to make the advantage
+        # estimation more accurate. adv_seq_len determines the length of the target
+        # sequence, and is later used to trim the inputs and targets in the training
+        # batch. Example for margin 2:
+        #   observations.shape == (4, 5, 6)
+        #   rewards.shape == values.shape == (4, 5)
+        #   advantages.shape == (4, 3)
+        assert adv_seq_len <= seq_len
+        assert advantages.shape == (batch_size, adv_seq_len)
+        # Trim observations, actions and mask to match the target length.
+        observations = trajectory_batch.observation[:, :adv_seq_len]
+        actions = trajectory_batch.action[:, :adv_seq_len]
+        mask = trajectory_batch.mask[:, :adv_seq_len]
+        # Apply the control mask, so we only compute policy loss for controllable
+        # timesteps.
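+        # (In SequenceDataEnv, control_mask is 0 while the environment is
+        # emitting input tokens, so those positions drop out of the loss here.)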
+ mask *= trajectory_batch.env_info.control_mask[:, :adv_seq_len] + return (observations, actions, mask) + + def policy_batch(self, trajectory_batch, shape_only=False): + """Computes a policy training batch based on a trajectory batch. + + Args: + trajectory_batch: trax.reinforcement.task.TimeStepBatch with a batch of trajectory + slices. Elements should have shape (batch_size, seq_len, ...). + shape_only: Whether to return dummy zero arrays of correct shape. Useful + for initializing models. + + Returns: + Triple (observations, actions, weights), where weights are the + advantage-based weights for the policy loss. Shapes: + - observations: (batch_size, seq_len) + observation_shape + - actions: (batch_size, seq_len) + action_shape + - weights: (batch_size, seq_len) + """ + advantages = self.calculate_advantages(trajectory_batch, shape_only=shape_only) + (observations, actions, mask) = self.trim_and_mask_batch( + trajectory_batch, advantages + ) + weights = self.calculate_weights(advantages) * mask / jnp.sum(mask) + return (observations, actions, weights) + + +class PolicyEvalTask(training.EvalTask): + """Task for policy evaluation.""" + + def __init__(self, train_task, n_eval_batches=1, head_selector=()): + """Initializes PolicyEvalTask. + + Args: + train_task: PolicyTrainTask used to train the policy network. + n_eval_batches: Number of batches per evaluation. + head_selector: Layer to apply to the network output to select the value + head. Only needed in multitask training. + """ + self._train_task = train_task + self._policy_dist = train_task.policy_distribution + labeled_data = map(self._eval_batch, train_task.trajectory_batch_stream) + sample_batch = self._eval_batch( + next(train_task.trajectory_batch_stream), shape_only=True + ) + # TODO(pkozakowski): Implement more metrics. + metrics = { + "policy_entropy": self.entropy_metric, + } + metrics.update(self.advantage_metrics) + metrics.update(self.weight_metrics) + metrics = { + name: tl.Serial(head_selector, metric) for (name, metric) in metrics.items() + } + (metric_names, metric_layers) = zip(*metrics.items()) + # Select the appropriate head for evaluation. 
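+        # Each metric layer runs on the model's policy outputs together with
+        # the (actions, advantages, mask) triple produced by _eval_batch below.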
+ super().__init__( + labeled_data, + metric_layers, + sample_batch=sample_batch, + metric_names=metric_names, + n_eval_batches=n_eval_batches, + ) + + def _eval_batch(self, trajectory_batch, shape_only=False): + advantages = self._train_task.calculate_advantages( + trajectory_batch, shape_only=shape_only + ) + (observations, actions, mask) = self._train_task.trim_and_mask_batch( + trajectory_batch, advantages + ) + return (observations, actions, advantages, mask) + + @property + def entropy_metric(self): + def Entropy(policy_inputs, actions, advantages, mask): + del actions, advantages, mask + return jnp.mean(self._policy_dist.entropy(policy_inputs)) + + return tl.Fn("Entropy", Entropy) + + @property + def advantage_metrics(self): + def make_metric(aggregate_fn): # pylint: disable=invalid-name + def AdvantageMetric(policy_inputs, actions, advantages, mask): + del policy_inputs, actions, mask + return aggregate_fn(advantages) + + return tl.Fn("AdvantageMetric", AdvantageMetric) + + return { + "advantage_" + name: make_metric(fn) + for (name, fn) in [ + ("mean", jnp.mean), + ("std", jnp.std), + ] + } + + @property + def weight_metrics(self): + def make_metric(aggregate_fn): # pylint: disable=invalid-name + def WeightMetric(policy_inputs, actions, advantages, mask): + del policy_inputs, actions, mask + weights = self._train_task.calculate_weights(advantages) + return aggregate_fn(weights) + + return tl.Fn("WeightMetric", WeightMetric) + + return { # pylint: disable=g-complex-comprehension + "weight_" + name: make_metric(fn) + for (name, fn) in [ + ("mean", jnp.mean), + ("std", jnp.std), + ("min", jnp.min), + ("max", jnp.max), + ] + } diff --git a/trax/learning/reinforcement/rl_layers.py b/trax/learning/reinforcement/rl_layers.py new file mode 100644 index 000000000..15afa8ef4 --- /dev/null +++ b/trax/learning/reinforcement/rl_layers.py @@ -0,0 +1,257 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A number of RL functions intended to be later wrapped as Trax layers. + + Wrapping happens with help of the function tl.Fn. +""" + +from trax.fastmath import numpy as jnp + + +def ValueLoss(values, returns, value_loss_coeff): + """Definition of the loss of the value function.""" + advantages = returns - values + l2_value_loss = jnp.mean(advantages**2) * value_loss_coeff + return l2_value_loss + + +def ExplainedVariance(values, returns): + """Definition of explained variance - an approach from OpenAI baselines.""" + assert ( + returns.shape == values.shape + ), f"returns.shape was {returns.shape} and values.shape was {values.shape}" + # TODO(henrykm): it would be good to explain the relation with the time dim. 
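+    # Illustration (assumed numbers): if jnp.var(returns) == 4.0 and
+    # jnp.var(returns - values) == 1.0, the explained variance is
+    # 1 - 1 / 4 = 0.75; 1.0 means perfect value predictions, 0.0 means no
+    # better than predicting the mean return.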
+    returns_variance = jnp.var(returns)
+    explained_variance = 1 - jnp.var(returns - values) / returns_variance
+    return explained_variance
+
+
+def PreferredMove(dist_inputs, sample):
+    """Definition of the preferred move."""
+    preferred_moves = sample(dist_inputs, temperature=0.0)
+    return jnp.mean(preferred_moves)
+
+
+def NewLogProbs(dist_inputs, actions, log_prob_fun):
+    """Given distribution and actions calculate log probs."""
+    new_log_probs = log_prob_fun(dist_inputs, actions)
+    return new_log_probs
+
+
+# TODO(henrykm): Clarify how jnp.mean is applied.
+def EntropyLoss(dist_inputs, distribution, coeff):
+    """Definition of the Entropy Layer."""
+    entropy_loss = distribution.entropy(dist_inputs) * coeff
+    return jnp.mean(entropy_loss)
+
+
+def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
+    """Probability Ratio from the PPO algorithm."""
+    # dist_inputs of the shape float32[128,1,18]
+    # actions of the shape int32[128,1]
+    # and old_log_probs of the shape float32[128,1]
+    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
+    assert new_log_probs.shape == old_log_probs.shape, (
+        f"new_log_probs.shape was {new_log_probs.shape} and "
+        f"old_log_probs.shape was {old_log_probs.shape}"
+    )
+    # The ratio between new_probs and old_probs expressed
+    # using log_probs and exponentiation
+    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
+    return probs_ratio
+
+
+def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
+    """Approximate KL divergence between the new and the old policy."""
+    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
+    assert new_log_probs.shape == old_log_probs.shape, (
+        f"new_log_probs.shape was {new_log_probs.shape} and "
+        f"old_log_probs.shape was {old_log_probs.shape}"
+    )
+    approximate_kl_divergence = 0.5 * jnp.mean(new_log_probs - old_log_probs) ** 2
+    return approximate_kl_divergence
+
+
+def UnclippedObjective(probs_ratio, advantages):
+    """Unclipped Objective from the PPO algorithm."""
+    assert probs_ratio.shape == advantages.shape, (
+        f"probs_ratio.shape was {probs_ratio.shape} and "
+        f"advantages.shape was {advantages.shape}"
+    )
+    unclipped_objective = probs_ratio * advantages
+    return unclipped_objective
+
+
+def ClippedObjective(probs_ratio, advantages, epsilon):
+    """Clipped Objective from the PPO algorithm."""
+    assert probs_ratio.shape == advantages.shape, (
+        f"probs_ratio.shape was {probs_ratio.shape} and "
+        f"advantages.shape was {advantages.shape}"
+    )
+    clipped_objective = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
+    assert probs_ratio.shape == clipped_objective.shape, (
+        f"probs_ratio.shape was {probs_ratio.shape} and "
+        f"clipped_objective.shape was {clipped_objective.shape}"
+    )
+    return clipped_objective
+
+
+def PPOObjective(
+    dist_inputs,
+    values,
+    returns,
+    dones,
+    rewards,
+    actions,
+    old_log_probs,
+    log_prob_fun,
+    epsilon,
+    normalize_advantages,
+):
+    """PPO Objective."""
+    # dist_inputs of the shape float32[128,1,18]
+    # values of the shape float32[128,1,1]
+    # returns of the shape float32[128,1,1]
+    # dones of the shape float32[128,1,1]
+    # rewards of the shape int32[128,1,1]
+    # actions of the shape int32[128,1]
+    # and old_log_probs of the shape float32[128,1]
+    returns = returns.squeeze(axis=2)
+    values = values.squeeze(axis=2)
+    dones = dones.squeeze(axis=2)
+    rewards = rewards.squeeze(axis=2)
+    assert (
+        rewards.shape == dones.shape
+    ), f"rewards.shape was {rewards.shape} and dones.shape was {dones.shape}"
+    assert (
+        dones.shape == values.shape
+    ), 
f"dones.shape was {dones.shape} and values.shape was {values.shape}" + assert ( + returns.shape == values.shape + ), f"returns.shape was {returns.shape} and values.shape was {values.shape}" + assert returns.shape == old_log_probs.shape, ( + f"returns.shape was {returns.shape} and" + f"old_log_probs.shape was {old_log_probs.shape}" + ) + + probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun) + assert probs_ratio.shape == old_log_probs.shape, ( + f"probs_ratio.shape was {probs_ratio.shape} and" + f"old_log_probs.shape was {old_log_probs.shape}" + ) + + # jaxified versions of + # returns[dones] = rewards[dones] + # values[dones] = 0 + returns = jnp.where(dones, rewards, returns) + values = jnp.where(dones, jnp.zeros_like(values), values) + advantages = returns - values + if normalize_advantages: + advantages = advantages - jnp.mean(advantages) + advantages /= jnp.std(advantages) + 1e-8 + assert old_log_probs.shape == advantages.shape, ( + f"old_log_probs.shape was {old_log_probs.shape} and advantages.shape was " + f"{advantages.shape}" + ) + + unclipped_objective = UnclippedObjective(probs_ratio, advantages) + assert unclipped_objective.shape == advantages.shape, ( + f"old_log_probs.shape was {old_log_probs.shape} and" + f"unclipped_objective.shape was {unclipped_objective.shape}" + ) + + clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon) + assert clipped_objective.shape == advantages.shape, ( + f"clipped_objective.shape was {clipped_objective.shape} and" + f"advantages.shape was {advantages.shape}" + ) + + ppo_objective = jnp.minimum(unclipped_objective, clipped_objective) + assert ppo_objective.shape == advantages.shape, ( + f"ppo_objective.shape was {ppo_objective.shape} and" + f"advantages.shape was {advantages.shape}" + ) + + return ppo_objective + + +def A2CObjective( + dist_inputs, + values, + returns, + dones, + rewards, + actions, + mask, + log_prob_fun, + normalize_advantages, +): + """Definition of the Advantage Actor Critic (A2C) loss.""" + # dist_inputs of the shape float32[128,1,18] + # values of the shape float32[128,1,1] + # returns of the shape float32[128,1,1] + # dones of the shape int32[128,1,1] + # actions of the shape int32[128,1] + # and mask of the shape float32[128,1] + # We have to squeeze values and returns, because we + # are planning to compute (return - values) * new_log_probs * mask + # and all of them should be of the same dimension + values = values.squeeze(axis=2) + returns = returns.squeeze(axis=2) + dones = dones.squeeze(axis=2) + rewards = rewards.squeeze(axis=2) + assert ( + rewards.shape == dones.shape + ), f"rewards.shape was {rewards.shape} and dones.shape was {dones.shape}" + assert ( + dones.shape == values.shape + ), f"dones.shape was {dones.shape} and values.shape was {values.shape}" + assert ( + returns.shape == values.shape + ), f"returns.shape was {returns.shape} and values.shape was {values.shape}" + assert ( + values.shape == mask.shape + ), f"values.shape was {values.shape} and mask.shape was {mask.shape}" + assert returns.shape[0] == dist_inputs.shape[0], ( + f"returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was " + f"{dist_inputs.shape[0]}" + ) + + new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun) + assert new_log_probs.shape == mask.shape, ( + f"new_log_probs.shape was {new_log_probs.shape} and mask.shape was " + f"{mask.shape}" + ) + + # jaxified versions of + # returns[dones] = rewards[dones] + # values[dones] = 0 + returns = jnp.where(dones, rewards, returns) + 
values = jnp.where(dones, jnp.zeros_like(values), values) + advantages = returns - values + if normalize_advantages: + advantages = advantages - jnp.mean(advantages) + advantages /= jnp.std(advantages) + 1e-8 + assert new_log_probs.shape == advantages.shape, ( + f"new_log_probs.shape was {new_log_probs.shape} and advantages.shape was " + f"{advantages.shape}" + ) + + # One of the motivation to the squeezes and assertions is to + # avoid [128,1] * [128,1,1] * [128] multiplications in the definition + # of the a2c objective - we insist on the same shapes + a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask) + return a2c_objective diff --git a/trax/learning/reinforcement/serialization_utils.py b/trax/learning/reinforcement/serialization_utils.py new file mode 100644 index 000000000..80138d29a --- /dev/null +++ b/trax/learning/reinforcement/serialization_utils.py @@ -0,0 +1,472 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for serializing trajectories into discrete sequences.""" + +import functools + +import gym +import numpy as np + +from trax import layers as tl +from trax.fastmath import numpy as jnp +from trax.learning.reinforcement import space_serializer + + +# pylint: disable=invalid-name +# TODO(pkozakowski): Move the layers to trax.layers and remove this module. +def Serialize(serializer): + """Layer that serializes a given array.""" + + def serialize(x): + (batch_size, length) = x.shape[:2] + shape_suffix = x.shape[2:] + x = jnp.reshape(x, (batch_size * length,) + shape_suffix) + x = serializer.serialize(x) + return jnp.reshape( + x, + ( + batch_size, + -1, + serializer.representation_length, + ), + ) + + return tl.Fn("Serialize", serialize) + + +def Interleave(): + """Layer that interleaves and flattens two serialized sequences. + + The first sequence can be longer by 1 than the second one. This is so we can + interleave sequences of observations and actions, when there's 1 extra + observation at the end. + + For serialized sequences [[x_1_1, ..., x_1_R1], ..., [x_L1_1, ..., x_L1_R1]] + and [[y_1_1, ..., y_1_R2], ..., [y_L2_1, ..., y_L2_R2]], where L1 = L2 + 1, + the result is [x_1_1, ..., x_1_R1, y_1_1, ..., y_1_R2, ..., x_L2_1, ..., + x_L2_R1, y_L2_1, ..., y_L2_R2, x_L1_1, ..., x_L1_R1] (batch dimension omitted + for clarity). + + The layer inputs are a sequence pair of shapes (B, L1, R1) and (B, L2, R2), + where B is batch size, L* is the length of the sequence and R* is the + representation length of each element in the sequence. + + Returns: + Layer that interleaves sequence of shape (B, L1 * R1 + L2 * R2). 
+ """ + + def interleave(x, y): + (batch_size, _, _) = x.shape + (_, length, _) = y.shape + assert x.shape[1] in (length, length + 1) + + reprs = jnp.concatenate((x[:, :length], y), axis=2) + reprs = jnp.reshape(reprs, (batch_size, -1)) + remainder = jnp.reshape(x[:, length:], (batch_size, -1)) + return jnp.concatenate((reprs, remainder), axis=1) + + return tl.Fn("Interleave", interleave) + + +def Deinterleave(x_size, y_size): + """Layer that does the inverse of Interleave.""" + + def deinterleave(inputs): + reprs = inputs + (batch_size, length) = reprs.shape[:2] + shape_suffix = reprs.shape[2:] + remainder_length = length % (x_size + y_size) + if remainder_length > 0: + remainder = reprs[:, None, -remainder_length:] + reprs = reprs[:, :-remainder_length] + reprs = jnp.reshape(reprs, (batch_size, -1, x_size + y_size) + shape_suffix) + x_reprs = reprs[:, :, :x_size] + y_reprs = reprs[:, :, x_size:] + if remainder_length > 0: + x_reprs = jnp.concatenate((x_reprs, remainder), axis=1) + return (x_reprs, y_reprs) + + return tl.Fn("Deinterleave", deinterleave, n_out=2) + + +def RepresentationMask(serializer): + """Upsamples a mask to cover the serialized representation.""" + + # Trax enforces the mask to be of the same size as the target. Get rid of the + # extra dimensions. + def representation_mask(mask): + # mask shape (batch_size,4) + mask = jnp.amax(mask, axis=tuple(range(2, mask.ndim))) + # mask shape (batch_size,4) + mask = jnp.repeat( + mask[..., jnp.newaxis], repeats=serializer.representation_length, axis=2 + ) + # mask shape (batch_size,4,representation_length) + return mask + + return tl.Fn("RepresentationMask", representation_mask) + + +def SignificanceWeights(serializer, decay): + """Multiplies a binary mask with a symbol significance mask.""" + + def significance_weights(mask): + # (repr,) -> (batch, length, repr) + # significance = [0, 1, 2] + significance = serializer.significance_map + assert significance.shape[0] == mask.shape[2] + # significance = batch_size * [0, 1, 2] + significance = jnp.repeat( + significance[np.newaxis, ...], repeats=mask.shape[0], axis=0 + ) + # significance = batch_size * [0, 1, 2] * mask.shape[1] + significance = jnp.repeat( + significance[..., jnp.newaxis], repeats=mask.shape[1], axis=2 + ) + # significance = batch_size * mask.shape[1] * [0, 1, 2] + significance = jnp.swapaxes(significance, 1, 2) + assert significance.shape == mask.shape + sig_weights = mask * decay**significance + return sig_weights + + return tl.Fn("SignificanceWeights", significance_weights) + + +class SerializedModel(tl.Serial): + """Wraps a world model in serialization machinery for training. + + The resulting model takes as input the observation and action sequences, + serializes them and interleaves into one sequence, which is fed into a given + autoregressive model. The resulting logit sequence is deinterleaved into + observations and actions, and the observation logits are returned together + with computed symbol significance weights. + + The model has a signature + (obs, act, obs, mask) -> (obs_logits, obs_repr, weights), where obs are + observations (the second occurrence is the target), act are actions, mask is + the observation mask, obs_logits are logits of the output observation + representation, obs_repr is the target observation representation and weights + are the target weights. + """ + + def __init__( + self, + seq_model, + observation_serializer, + action_serializer, + significance_decay, + mode="train", + ): + """Initializes SerializedModel. 
+
+        Args:
+          seq_model: Trax autoregressive model taking as input a sequence of
+            symbols and outputting a sequence of symbol logits.
+          observation_serializer: Serializer to use for observations.
+          action_serializer: Serializer to use for actions.
+          significance_decay: Float from (0, 1) for exponential weighting of
+            symbols in the representation.
+          mode: 'train' or 'eval'.
+        """
+        assert mode in ("train", "eval")
+        weigh_by_significance = [
+            # (mask,)
+            RepresentationMask(serializer=observation_serializer),
+            # (repr_mask)
+            SignificanceWeights(
+                serializer=observation_serializer, decay=significance_decay
+            ),
+            # (mask, sig_weights)
+        ]
+        super().__init__(
+            # (obs, act, obs, mask)
+            tl.Parallel(
+                Serialize(serializer=observation_serializer),
+                Serialize(serializer=action_serializer),
+                Serialize(serializer=observation_serializer),
+            ),
+            # (obs_repr, act_repr, obs_repr, mask)
+            Interleave(),
+            # (obs_act_repr, obs_repr, mask)
+            seq_model(mode=mode),
+            # (obs_act_logits, obs_repr, mask)
+            Deinterleave(
+                x_size=observation_serializer.representation_length,
+                y_size=action_serializer.representation_length,
+            ),
+            # (obs_logits, act_logits, obs_repr, mask)
+            tl.Parallel(None, tl.Drop(), None, weigh_by_significance),
+            # (obs_logits, obs_repr, weights)
+        )
+
+        self._seq_model = seq_model
+        self._observation_serializer = observation_serializer
+        self._action_serializer = action_serializer
+
+    @property
+    def observation_serializer(self):
+        return self._observation_serializer
+
+    @property
+    def action_serializer(self):
+        return self._action_serializer
+
+    def make_predict_model(self):
+        """Returns a predict-mode model of the same architecture."""
+        return self._seq_model(mode="predict")
+
+    @property
+    def seq_model_weights(self):
+        """Extracts the weights of the underlying sequence model."""
+        return self.weights[2]
+
+    @property
+    def seq_model_state(self):
+        """Extracts the state of the underlying sequence model."""
+        return self.state[2]
+
+
+def TimeSeriesModel(
+    seq_model,
+    low=0.0,
+    high=1.0,
+    precision=2,
+    vocab_size=64,
+    significance_decay=0.7,
+    mode="train",
+):
+    """Simplified constructor for SerializedModel, for time series prediction."""
+    # Model scalar time series.
+    obs_srl = space_serializer.BoxSpaceSerializer(
+        space=gym.spaces.Box(shape=(), low=low, high=high),
+        vocab_size=vocab_size,
+        precision=precision,
+    )
+    # Artifact of the fact that we must provide some actions.
+    # TODO(pkozakowski): Remove this requirement.
+    act_srl = space_serializer.DiscreteSpaceSerializer(
+        space=gym.spaces.Discrete(n=1), vocab_size=1
+    )
+    seq_model = functools.partial(seq_model, vocab_size=vocab_size)
+    return SerializedModel(seq_model, obs_srl, act_srl, significance_decay, mode)
+
+
+def RawPolicy(seq_model, n_controls, n_actions):
+    """Wraps a sequence model in a policy interface.
+
+    The resulting model takes as input observation and action sequences, but
+    only uses the observations. Adds output heads for action logits and value
+    predictions.
+
+    Args:
+      seq_model: Trax sequence model taking as input and outputting a sequence of
+        continuous vectors.
+      n_controls: Number of controls.
+      n_actions: Number of action categories in each control.
+ + Returns: + A model of signature (obs, act) -> (act_logits, values), with shapes: + obs: (batch_size, length + 1, obs_depth) + act: (batch_size, length, n_controls) + act_logits: (batch_size, length, n_controls, n_actions) + values: (batch_size, length) + """ + + def SplitControls(): # pylint: disable=invalid-name + """Splits logits for actions in different controls.""" + + def f(x): + return jnp.reshape(x, x.shape[:2] + (n_controls, n_actions)) + + return tl.Fn("SplitControls", f) + + action_head = [ + # Predict all action logits at the same time. + tl.Dense(n_controls * n_actions), + # Then group them into separate controls, adding a new dimension. + SplitControls(), + tl.LogSoftmax(), + ] + return tl.Serial( # (obs, act) + tl.Select([0], n_in=2), # (obs,) + seq_model, # (obs_hidden,) + tl.Dup(), # (obs_hidden, obs_hidden) + tl.Parallel(action_head, [tl.Dense(1), tl.Flatten()]), # (act_logits, values) + ) + + +def substitute_inner_policy_raw( + raw_policy, inner_policy +): # pylint: disable=invalid-name + """Substitutes the weights/state of the inner model in a RawPolicy.""" + return raw_policy[:1] + [inner_policy] + raw_policy[2:] + + +def SerializedPolicy( + seq_model, n_controls, n_actions, observation_serializer, action_serializer +): + """Wraps a policy in serialization machinery for training. + + The resulting model takes as input observation and action sequences, and + serializes them into one sequence similar to SerializedModel, before passing + to the given sequence model. Adds output heads for action logits and value + predictions. + + Args: + seq_model: Trax sequence model taking as input a sequence of symbols and + outputting a sequence of continuous vectors. + n_controls: Number of controls. + n_actions: Number of action categories in each control. + observation_serializer: Serializer to use for observations. + action_serializer: Serializer to use for actions. + + Returns: + A model of signature (obs, act) -> (act_logits, values), same as in + RawPolicy. + """ + if action_serializer.representation_length != n_controls: + raise ValueError( + "Action symbols should correspond 1-1 to controls, but got {} " + "controls and {} symbols.".format( + n_controls, action_serializer.representation_length + ) + ) + + def FirstSymbol(): + return tl.Fn("FirstSymbol", lambda x: x[:, :, 0]) + + def PadRight(n_to_pad): + def pad_right(x): + pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2) + return jnp.pad( + x, pad_widths, mode="constant", constant_values=x.dtype.type(0) + ) + + return tl.Fn(f"PadRight({n_to_pad})", pad_right) + + action_head = [ + tl.Dense(n_actions), + tl.LogSoftmax(), + ] + value_head = [ + # Take just the vectors corresponding to the first action symbol. + FirstSymbol(), + # Predict values. + tl.Dense(1), + # Get rid of the singleton dimension. + tl.Flatten(), + ] + return tl.Serial( + # (obs, act) + tl.Parallel(Serialize(observation_serializer), Serialize(action_serializer)), + # (obs_repr, act_repr) + Interleave(), + # (obs_act_repr,) + # Add one dummy action to the right - we'll use the output at its first + # symbol to predict the value for the last observation. + PadRight(action_serializer.representation_length), + # Shift one symbol to the right, so we predict the n-th action symbol + # based on action symbols 1..n-1 instead of 1..n. 
+            tl.ShiftRight(),
+            seq_model,
+            # (obs_act_hidden,)
+            Deinterleave(
+                observation_serializer.representation_length,
+                action_serializer.representation_length,
+            ),
+            # (obs_hidden, act_hidden)
+            tl.Select([1, 1]),
+            # (act_hidden, act_hidden)
+            tl.Parallel(action_head, value_head),
+            # (act_logits, values)
+        )
+
+
+def substitute_inner_policy_serialized(
+    serialized_policy, inner_policy
+):  # pylint: disable=invalid-name
+    """Substitutes the weights/state of the inner model in a SerializedPolicy."""
+    return serialized_policy[:4] + [inner_policy] + serialized_policy[5:]
+
+
+def analyze_action_space(action_space):  # pylint: disable=invalid-name
+    """Returns the number of controls and actions for an action space."""
+    assert isinstance(
+        action_space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)
+    ), "Action space expected to be Discrete or MultiDiscrete, got {}.".format(
+        type(action_space)
+    )
+    if isinstance(action_space, gym.spaces.Discrete):
+        n_actions = action_space.n
+        n_controls = 1
+    else:
+        (n_controls,) = action_space.nvec.shape
+        assert n_controls > 0
+        assert np.min(action_space.nvec) == np.max(
+            action_space.nvec
+        ), "Every control must have the same number of actions."
+        n_actions = action_space.nvec[0]
+    return (n_controls, n_actions)
+
+
+def wrap_policy(
+    seq_model, observation_space, action_space, vocab_size
+):  # pylint: disable=invalid-name
+    """Wraps a sequence model in either RawPolicy or SerializedPolicy.
+
+    Args:
+      seq_model: Trax sequence model.
+      observation_space: Gym observation space.
+      action_space: Gym action space.
+      vocab_size: Either the number of symbols for a serialized policy, or None.
+
+    Returns:
+      RawPolicy if vocab_size is None, else SerializedPolicy.
+    """
+    (n_controls, n_actions) = analyze_action_space(action_space)
+    if vocab_size is None:
+        policy_wrapper = RawPolicy
+    else:
+        obs_serializer = space_serializer.create(observation_space, vocab_size)
+        act_serializer = space_serializer.create(action_space, vocab_size)
+        policy_wrapper = functools.partial(
+            SerializedPolicy,
+            observation_serializer=obs_serializer,
+            action_serializer=act_serializer,
+        )
+    return policy_wrapper(seq_model, n_controls, n_actions)
+
+
+def substitute_inner_policy(
+    wrapped_policy, inner_policy, vocab_size
+):  # pylint: disable=invalid-name
+    """Substitutes the inner weights/state in a {Raw,Serialized}Policy.
+
+    Args:
+      wrapped_policy (pytree): Weights or state of a wrapped policy.
+      inner_policy (pytree): Weights or state of an inner policy.
+      vocab_size (int or None): Vocabulary size of a serialized policy, or None
+        in case of a raw policy.
+
+    Returns:
+      New weights or state of wrapped_policy, with the inner weights/state
+      copied from inner_policy.
+    """
+    if vocab_size is None:
+        substitute_fn = substitute_inner_policy_raw
+    else:
+        substitute_fn = substitute_inner_policy_serialized
+    return substitute_fn(wrapped_policy, inner_policy)
diff --git a/trax/learning/reinforcement/space_serializer.py b/trax/learning/reinforcement/space_serializer.py
new file mode 100644
index 000000000..fe9cced9d
--- /dev/null
+++ b/trax/learning/reinforcement/space_serializer.py
@@ -0,0 +1,226 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
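A quick shape sketch of the interleaving round trip defined above (illustrative, not part of the patch; arrays and sizes are assumed). With 3 observations of 2 symbols each and 2 actions of 1 symbol each, Interleave produces the flat pattern o o a o o a o o, and Deinterleave inverts it losslessly:

    import numpy as np

    from trax.learning.reinforcement import serialization_utils

    obs_repr = np.arange(6).reshape((1, 3, 2))  # (B, L1, R1) = (1, 3, 2)
    act_repr = np.arange(2).reshape((1, 2, 1))  # (B, L2, R2) = (1, 2, 1)

    interleaved = serialization_utils.Interleave()((obs_repr, act_repr))
    # interleaved.shape == (1, 3 * 2 + 2 * 1) == (1, 8)

    obs_back, act_back = serialization_utils.Deinterleave(x_size=2, y_size=1)(
        interleaved
    )
    # obs_back.shape == (1, 3, 2) and act_back.shape == (1, 2, 1).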
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Serialization of elements of Gym spaces into discrete sequences.""" +import copy + +import gin +import gym + +from absl import logging +from jax import numpy as np + + +class SpaceSerializer: + """Base class for Gym space serializers. + + Attrs: + space_type: (type) Gym space class that this SpaceSerializer corresponds + to. Should be defined in subclasses. + representation_length: (int) Number of symbols in the representation of + every element of the space. + significance_map: (np.ndarray) Integer array of the same size as the + discrete representation, where elements describe the significance of + symbols, e.g. in fixed-precision encoding. 0 is the most significant + symbol, 1 the second most significant etc. + """ + + space_type = None + representation_length = None + significance_map = None + + def __init__(self, space, vocab_size): + """Creates a SpaceSerializer. + + Subclasses should retain the signature. + + Args: + space: (gym.Space) Gym space of type self.space_type. + vocab_size: (int) Number of symbols in the vocabulary. + """ + assert isinstance(space, self.space_type) + self._space = space + self._vocab_size = vocab_size + + @property + def vocab_size(self): + return self._vocab_size + + def serialize(self, data): + """Serializes a batch of space elements into discrete sequences. + + Should be defined in subclasses. + + Args: + data: A batch of batch_size elements of the Gym space to be serialized. + + Returns: + int32 array of shape (batch_size, self.representation_length). + """ + raise NotImplementedError + + def deserialize(self, representation): + """Deserializes a batch of discrete sequences into space elements. + + Should be defined in subclasses. + + Args: + representation: int32 Numpy array of shape + (batch_size, self.representation_length) to be deserialized. + + Returns: + A batch of batch_size deserialized elements of the Gym space. + """ + raise NotImplementedError + + +def create(space, vocab_size): + """Creates a SpaceSerializer for the given Gym space.""" + return { + gym.spaces.Box: BoxSpaceSerializer, + gym.spaces.Discrete: DiscreteSpaceSerializer, + gym.spaces.MultiDiscrete: MultiDiscreteSpaceSerializer, + }[type(space)](space, vocab_size) + + +@gin.configurable(denylist=["space", "vocab_size"]) +class BoxSpaceSerializer(SpaceSerializer): + """Serializer for gym.spaces.Box. + + Assumes that the space is bounded. Internally rescales it to the [0, 1] + interval and uses a fixed-precision encoding. + """ + + space_type = gym.spaces.Box + + def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)): + self._precision = precision + + # Some gym envs (e.g. CartPole) have unreasonably high bounds for + # observations. We clip so we can represent them. + bounded_space = copy.copy(space) + (min_low, max_high) = max_range + bounded_space.low = np.maximum(space.low, min_low) + bounded_space.high = np.minimum(space.high, max_high) + if not np.allclose(bounded_space.low, space.low) or not np.allclose( + bounded_space.high, space.high + ): + logging.warning( + "Space limits %s, %s out of bounds %s. 
Clipping to %s, %s.", + str(space.low), + str(space.high), + str(max_range), + str(bounded_space.low), + str(bounded_space.high), + ) + + super().__init__(bounded_space, vocab_size) + + def serialize(self, data): + array = data + batch_size = array.shape[0] + array = (array - self._space.low) / (self._space.high - self._space.low) + array = np.clip(array, 0, 1) + digits = [] + for digit_index in range(-1, -self._precision - 1, -1): + threshold = self._vocab_size**digit_index + digit = np.array(array / threshold).astype(np.int32) + # For the corner case of x == high. + digit = np.where(digit == self._vocab_size, digit - 1, digit) + digits.append(digit) + array -= digit * threshold + digits = np.stack(digits, axis=-1) + return np.reshape(digits, (batch_size, -1)) + + def deserialize(self, representation): + digits = representation + batch_size = digits.shape[0] + digits = np.reshape(digits, (batch_size, -1, self._precision)) + array = np.zeros(digits.shape[:-1]) + for digit_index_in_seq in range(self._precision): + digit_index = -digit_index_in_seq - 1 + array += self._vocab_size**digit_index * digits[..., digit_index_in_seq] + array = np.reshape(array, (batch_size,) + self._space.shape) + return array * (self._space.high - self._space.low) + self._space.low + + @property + def representation_length(self): + return self._precision * self._space.low.size + + @property + def significance_map(self): + return np.reshape( + np.broadcast_to( + np.arange(self._precision), self._space.shape + (self._precision,) + ), + -1, + ) + + +class DiscreteSpaceSerializer(SpaceSerializer): + """Serializer for gym.spaces.Discrete. + + Assumes that the size of the space fits in the number of symbols. + """ + + space_type = gym.spaces.Discrete + representation_length = 1 + + def __init__(self, space, vocab_size): + super().__init__(space, vocab_size) + assert ( + space.n <= vocab_size + ), "Discrete space size should fit in the number of symbols." + + def serialize(self, data): + return np.reshape(data, (-1, 1)).astype(np.int32) + + def deserialize(self, representation): + return np.reshape(representation, -1) + + @property + def significance_map(self): + return np.zeros(1, dtype=np.int32) + + +class MultiDiscreteSpaceSerializer(SpaceSerializer): + """Serializer for gym.spaces.MultiDiscrete. + + Assumes that the number of categories in each dimension fits in the number of + symbols. + """ + + space_type = gym.spaces.MultiDiscrete + + def __init__(self, space, vocab_size): + super().__init__(space, vocab_size) + assert np.max(space.nvec) <= vocab_size, ( + "MultiDiscrete maximum number of categories should fit in the number " + "of symbols." + ) + + def serialize(self, data): + return data.astype(np.int32) + + def deserialize(self, representation): + return representation + + @property + def representation_length(self): + return len(self._space.nvec) + + @property + def significance_map(self): + return np.zeros(self.representation_length, dtype=np.int32) diff --git a/trax/learning/reinforcement/task.py b/trax/learning/reinforcement/task.py new file mode 100644 index 000000000..fe73a9ada --- /dev/null +++ b/trax/learning/reinforcement/task.py @@ -0,0 +1,921 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
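The fixed-precision encoding implemented by BoxSpaceSerializer above can be checked with a small round trip (a sketch, not part of the patch; the space and values are assumed). With vocab_size=10 and the default precision=2, each scalar becomes its two leading base-10 digits after rescaling to [0, 1]:

    import gym
    import numpy as np

    from trax.learning.reinforcement import space_serializer

    space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,))
    srl = space_serializer.create(space, vocab_size=10)  # BoxSpaceSerializer

    tokens = srl.serialize(np.array([[0.5]]))  # [[5, 0]]: the digits of 0.50
    value = srl.deserialize(tokens)            # [[0.5]], up to the precision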
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classes for defining RL tasks in Trax."""
+
+import collections
+import os
+
+import gin
+import gym
+import numpy as np
+
+from trax import fastmath
+from trax.learning.reinforcement import advantages
+from trax.learning.supervised import training
+
+# TimeStepBatch stores a single step in the trajectory of an RL run, or
+# a sequence of timesteps (trajectory slice), or a batch of such sequences.
+# Fields:
+# * `observation` at the beginning of the step
+# * `action` that was taken
+# * `reward` received when the action was taken (or None if no action was taken)
+# * `done` - whether the trajectory has finished in this step
+# * `mask` - padding mask
+# * `return_` - discounted return from this state (includes the current reward);
+#   `None` if it hasn't been computed yet
+# * `dist_inputs` - parameters of the policy distribution, stored by some
+#   RL algorithms
+# TODO(pkozakowski): Generalize `dist_inputs` to `agent_info` - a namedtuple
+# storing agent-specific data.
+TimeStepBatch = collections.namedtuple(
+    "TimeStepBatch",
+    [
+        "observation",
+        "action",
+        "reward",
+        "done",
+        "mask",
+        "dist_inputs",
+        "env_info",
+        "return_",
+    ],
+)
+
+
+# EnvInfo stores additional information returned by
+# `trax.reinforcement.envs.SequenceDataEnv`. In those environments, one timestep
+# corresponds to one token in the sequence. While the environment is emitting
+# observation tokens, actions taken by the agent don't matter. Actions can also
+# span multiple tokens, but the discount should only be applied once.
+# Fields:
+# * `control_mask` - mask determining whether the last interaction was
+#   controlled, so whether the action performed by the agent mattered;
+#   can be used to mask policy and value loss; negation can be used to mask
+#   world model observation loss; defaults to 1 - all actions matter
+# * `discount_mask` - mask determining whether the discount should be applied to
+#   the current reward; defaults to 1 - all rewards are discounted
+EnvInfo = collections.namedtuple("EnvInfo", ["control_mask", "discount_mask"])
+EnvInfo.__new__.__defaults__ = (1, 1)
+
+
+# `env_info` and `return_` can be omitted in `TimeStepBatch`.
+TimeStepBatch.__new__.__defaults__ = (
+    EnvInfo(),
+    None,
+)
+
+
+class Trajectory:
+    """A trajectory of interactions with an RL environment.
+
+    Trajectories are created when interacting with an RL environment. They can
+    be prolonged and sliced, and when completed, allow recalculating returns.
+    """
+
+    def __init__(self, observation):
+        # TODO(lukaszkaiser): add support for saving and loading trajectories,
+        # reuse code from base_trainer.dump_trajectories and related functions.
+        self._last_observation = observation
+        self._timesteps = []
+        self._timestep_batch = None
+        self._cached_to_np_args = None
+
+    def __len__(self):
+        """Returns the number of observations in the trajectory."""
+        # We always have 1 more of observations than of everything else.
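+        # (For example, a Trajectory created from one initial observation and
+        # extended twice holds 2 timesteps and 3 observations, so its length
+        # is 3.)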
+ return len(self._timesteps) + 1 + + def __repr__(self): + return repr( + { + "timesteps": self._timesteps, + "last_observation": self._last_observation, + } + ) + + def suffix(self, length): + """Returns a `Trajectory` with the last `length` observations.""" + assert length >= 1 + t = Trajectory(self._last_observation) + t._timesteps = self._timesteps[ + -(length - 1) : + ] # pylint: disable=protected-access + return t + + @property + def timesteps(self): + return self._timesteps + + @property + def total_return(self): + """Sum of all rewards in this trajectory.""" + return sum([t.reward or 0.0 for t in self._timesteps]) + + @property + def last_observation(self): + """Return the last observation in this trajectory.""" + return self._last_observation + + @property + def done(self): + """Returns whether the trajectory is finished.""" + if not self._timesteps: + return False + return self._timesteps[-1].done + + @done.setter + def done(self, done): + """Sets the `done` flag in the last timestep.""" + if not self._timesteps: + raise ValueError("No interactions yet in the trajectory.") + self._timesteps[-1] = self._timesteps[-1]._replace(done=done) + + def extend(self, new_observation, mask=1, **kwargs): + """Take action in the last state, getting reward and going to new state.""" + self._timesteps.append( + TimeStepBatch(observation=self._last_observation, mask=mask, **kwargs) + ) + self._last_observation = new_observation + + def calculate_returns(self, gamma): + """Calculate discounted returns.""" + rewards = np.array([ts.reward for ts in self._timesteps]) + discount_mask = np.array([ts.env_info.discount_mask for ts in self._timesteps]) + gammas = advantages.mask_discount(gamma, discount_mask) + returns = advantages.discounted_returns(rewards, gammas) + for i, return_ in enumerate(returns): + self._timesteps[i] = self._timesteps[i]._replace(return_=return_) + + def _default_timestep_to_np(self, ts): + """Default way to convert timestep to numpy.""" + return fastmath.nested_map(np.array, ts) + + def to_np(self, margin=1, timestep_to_np=None): + """Create a tuple of numpy arrays from a given trajectory. + + Args: + margin (int): Number of dummy timesteps past the trajectory end to + include. By default we include 1, which contains the last + observation. + timestep_to_np (callable or None): Optional function + TimeStepBatch[Any] -> TimeStepBatch[np.array], converting the + timestep data into numpy arrays. + + Returns: + TimeStepBatch, where all fields have shape + (len(self) + margin - 1, ...). + """ + timestep_to_np = timestep_to_np or self._default_timestep_to_np + args = (margin, timestep_to_np) + + # Return the cached result if the arguments agree and the trajectory has not + # grown. + if self._timestep_batch: + result_length = len(self) + margin - 1 + length_ok = self._timestep_batch.observation.shape[0] == result_length + if args == self._cached_to_np_args and length_ok: + return self._timestep_batch + + # observation, action, reward, etc. + fields = TimeStepBatch._fields + # List of timestep data for each field. + data_lists = TimeStepBatch(**{field: [] for field in fields}) + for timestep in self._timesteps: + timestep_np = timestep_to_np(timestep) + # Append each field of timestep_np to the appropriate list. + for field in fields: + getattr(data_lists, field).append(getattr(timestep_np, field)) + # Append the last observation. 
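+        # (The loop above appended one entry per timestep to every field;
+        # adding self._last_observation makes the observation list one element
+        # longer than the others, consistent with __len__.)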
+        data_lists.observation.append(self._last_observation)
+
+        # TODO(pkozakowski): The case len(obs) == 1 is for handling
+        # "dummy trajectories" that are only there to determine data shapes. Check
+        # if they're still required.
+        if len(data_lists.observation) > 1:
+            # Extend the trajectory with a given margin - this is to make sure that
+            # the networks always "see" the "done" states in the training data, even
+            # when a suffix is added to the trajectory slice for better estimation of
+            # returns.
+            # We set `mask` to 0, so the added timesteps don't influence the loss. We
+            # set `done` to True for easier implementation of advantage estimators.
+            # The rest of the fields don't matter, so we set them to 0 for easy
+            # debugging (unless they're None). The list of observations is longer, so
+            # we pad it with margin - 1.
+            data_lists.mask.extend([0] * margin)
+            data_lists.done.extend([True] * margin)
+            data_lists.observation.extend(
+                [np.zeros_like(data_lists.observation[-1])] * (margin - 1)
+            )
+            for field in set(fields) - {"mask", "done", "observation"}:
+                l = getattr(data_lists, field)
+                filler = None if l[-1] is None else np.zeros_like(l[-1])
+                l.extend([filler] * margin)
+
+        # Trim the observations to have the same length as the rest of the fields.
+        # This is not the case when margin=0.
+        if margin == 0:
+            data_lists.observation.pop()
+
+        def stack(x):
+            if not x:
+                return None
+            return fastmath.nested_stack(x)
+
+        # Stack the data_lists into numpy arrays.
+        timestep_batch = TimeStepBatch(*map(stack, data_lists))
+
+        self._timestep_batch = timestep_batch
+        self._cached_to_np_args = args
+
+        return timestep_batch
+
+
+def play(env, policy, dm_suite=False, max_steps=None, last_observation=None):
+    """Play an episode in env taking actions according to the given policy.
+
+    The environment is first reset and, from then on, a game proceeds. At each
+    step, the policy is asked to choose an action and the environment moves
+    forward. A Trajectory is created that way and returned when the episode
+    finishes, which happens either when env returns `done` or when max_steps
+    is reached.
+
+    Args:
+      env: the environment to play in, conforming to gym.Env or
+        DeepMind suite interfaces.
+      policy: a function taking a Trajectory and returning a pair consisting
+        of an action (int or float) and the confidence in that action (float,
+        defined as the log of the probability of taking that action).
+      dm_suite: whether we are using the DeepMind suite or the gym interface
+      max_steps: for how many steps to play.
+      last_observation: last observation from a previous trajectory slice, used to
+        begin a new one. Controls whether we reset the environment at the
+        beginning - if `None`, resets the env and starts the slice from the
+        observation obtained from reset().
+
+    Returns:
+      a completed trajectory slice that was just played.
+    """
+    done = False
+    cur_step = 0
+    if last_observation is None:
+        # TODO(pkozakowski): Make a Gym wrapper over DM envs to get rid of branches
+        # like that.
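+        # Note: with Gym >= 0.26 / Gymnasium, reset() returns an
+        # (observation, info) pair rather than a bare observation; targeting
+        # those APIs would require e.g. `last_observation, _ = env.reset()`.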
+        last_observation = env.reset().observation if dm_suite else env.reset()
+    cur_trajectory = Trajectory(last_observation)
+    while not done and (max_steps is None or cur_step < max_steps):
+        action, dist_inputs = policy(cur_trajectory)
+        action = np.asarray(action)
+        step = env.step(action)
+        if dm_suite:
+            (observation, reward, done) = (
+                step.observation,
+                step.reward,
+                step.step_type.last(),
+            )
+            info = {}
+        else:
+            if isinstance(step, tuple) and len(step) == 5:
+                observation, reward, terminated, truncated, info = step
+                done = bool(terminated) or bool(truncated)
+                info = info or {}
+                # Surface termination flags so agents may inspect them if needed.
+                info = {
+                    **info,
+                    "terminated": terminated,
+                    "truncated": truncated,
+                }
+            else:
+                observation, reward, done, info = step
+                info = info or {}
+
+        # Make an EnvInfo out of the supported keys in the info dict.
+        env_info = EnvInfo(
+            **{key: value for (key, value) in info.items() if key in EnvInfo._fields}
+        )
+        cur_trajectory.extend(
+            action=action,
+            dist_inputs=dist_inputs,
+            reward=reward,
+            done=done,
+            new_observation=observation,
+            env_info=env_info,
+        )
+        cur_step += 1
+    return cur_trajectory
+
+
+def _zero_pad(x, pad, axis):
+    """Helper for np.pad with 0s for single-axis case."""
+    pad_widths = [(0, 0)] * len(x.shape)
+    pad_widths[axis] = pad  # Padding on axis.
+    return np.pad(x, pad_widths, mode="constant", constant_values=x.dtype.type(0))
+
+
+def _random_policy(action_space):
+    return lambda _: (action_space.sample(), None)
+
+
+def _sample_proportionally(inputs, weights):
+    """Sample an element from the inputs list proportionally to weights.
+
+    Args:
+      inputs: a list, we will return one element of this list.
+      weights: a sequence of numbers of the same length as inputs; we will sample
+        the k-th input with probability weights[k] / sum(weights).
+
+    Returns:
+      an element from inputs.
+    """
+    l = len(inputs)
+    weights = np.array(weights)
+    if l != len(weights):
+        raise ValueError(
+            f"Inputs and weights must have the same length, but do not: "
+            f"{l} != {len(weights)}"
+        )
+    norm_weights = weights / np.sum(weights)
+    # TODO(pkozakowski): Currently this is O(n). It can be sped up to O(log n) by
+    # storing the CDF and binsearching on it.
+    idx = np.random.choice(l, p=norm_weights)
+    return inputs[int(idx)]
+
+
+def _n_slices(trajectory, max_slice_length, margin):
+    """How many slices of length up to max_slice_length fit in a trajectory."""
+    # TODO(lukaszkaiser): add option to sample from n last trajectories.
+    if not max_slice_length:
+        return 1
+    # A trajectory [a, b, c, end_state] will have 2 slices of length 2:
+    # the slice [a, b] and the one [b, c], with margin=0; 3 with margin=1.
+    return max(1, len(trajectory) + margin - max_slice_length)
+
+
+@gin.configurable
+class RLTask:
+    """An RL task: environment and a collection of trajectories."""
+
+    def __init__(
+        self,
+        env=gin.REQUIRED,
+        initial_trajectories=1,
+        gamma=0.99,
+        dm_suite=False,
+        random_starts=True,
+        max_steps=None,
+        time_limit=None,
+        timestep_to_np=None,
+        num_stacked_frames=1,
+        n_replay_epochs=1,
+    ):
+        r"""Configures an RL task.
+
+        Args:
+          env: Environment conforming to the gym.Env interface or a string,
+            in which case `gym.make` will be called on this string to create an env.
+          initial_trajectories: either a dict or list of Trajectories to use
+            at start or an int, in which case that many trajectories are
+            collected using a random policy to play in env.
+            It can also be a string, in which case it should point to the
+            location where previously recorded trajectories are stored.
+          gamma: float: discount factor for calculating returns.
+          dm_suite: whether we are using the DeepMind suite or the gym interface
+          random_starts: use random starts for training of Atari agents.
+          max_steps: optional int: cut all trajectory slices at that many steps.
+            The trajectory will be continued in the next epochs, up to `time_limit`.
+          time_limit: optional int: stop all trajectories after that many steps (or
+            after getting "done"). If `None`, use the same value as `max_steps`.
+          timestep_to_np: a function that turns a timestep into a numpy array
+            (i.e., a tensor); if None, we just use the state of the timestep to
+            represent it, but other representations (such as embeddings that include
+            actions or serialized representations) can be passed here.
+          num_stacked_frames: the number of stacked frames for Atari.
+          n_replay_epochs: the size of the replay buffer expressed in epochs.
+        """
+        if isinstance(env, str):
+            self._env_name = env
+            if dm_suite:
+                eval_env = None
+                env = None
+            else:
+                env = gym.make(self._env_name)
+                eval_env = gym.make(self._env_name)
+        else:
+            self._env_name = type(env).__name__
+            eval_env = env
+        self._env = env
+        self._eval_env = eval_env
+        self._dm_suite = dm_suite
+        self._max_steps = max_steps
+        if time_limit is None:
+            time_limit = max_steps
+        self._time_limit = time_limit
+        self._gamma = gamma
+        self._initial_trajectories = initial_trajectories
+        self._last_observation = None
+        self._n_steps_left = time_limit
+        # Example trajectory for determining input/output shapes of the networks.
+        self._example_trajectory = self.play(
+            _random_policy(self.action_space), only_eval=True
+        )
+        # TODO(lukaszkaiser): find a better way to pass initial trajectories,
+        # whether they are an explicit list, a file, or a number of random ones.
+        if isinstance(initial_trajectories, int):
+            initial_trajectories = [
+                self.play(_random_policy(self.action_space))
+                for _ in range(initial_trajectories)
+            ]
+        if isinstance(initial_trajectories, str):
+            initial_trajectories = self.load_initial_trajectories_from_path(
+                initial_trajectories_path=initial_trajectories
+            )
+        if isinstance(initial_trajectories, list):
+            if initial_trajectories:
+                initial_trajectories = {0: initial_trajectories}
+            else:
+                initial_trajectories = {}
+        self._timestep_to_np = timestep_to_np
+        # Stored trajectories are indexed by epoch and within each epoch they
+        # are stored in the order of generation so we can implement replay buffers.
+        # TODO(lukaszkaiser): use dump_trajectories from BaseTrainer to allow
+        # saving and reading trajectories from disk.
+        self._trajectories = collections.defaultdict(list)
+        self._trajectories.update(initial_trajectories)
+        # When we repeatedly save, trajectories for many epochs do not change, so
+        # we don't need to save them again. This keeps track of which epochs are
+        # unchanged.
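+        # (For example: after save_to_file() marks epochs [1, 2] as saved,
+        # collecting new trajectories for epoch 2 removes it from this list,
+        # so only epoch 2 is written again on the next save.)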
+ self._saved_epochs_unchanged = [] + self._n_replay_epochs = n_replay_epochs + self._n_trajectories = 0 + self._n_interactions = 0 + + @property + def env(self): + return self._env + + @property + def env_name(self): + return self._env_name + + @property + def max_steps(self): + return self._max_steps + + @property + def gamma(self): + return self._gamma + + @property + def action_space(self): + if self._dm_suite: + return gym.spaces.Discrete(self._env.action_spec().num_values) + else: + return self._env.action_space + + @property + def observation_space(self): + """Returns the env's observation space in a Gym interface.""" + if self._dm_suite: + return gym.spaces.Box( + shape=self._env.observation_spec().shape, + dtype=self._env.observation_spec().dtype, + low=float("-inf"), + high=float("+inf"), + ) + else: + return self._env.observation_space + + @property + def trajectories(self): + return self._trajectories + + @property + def timestep_to_np(self): + return self._timestep_to_np + + @timestep_to_np.setter + def timestep_to_np(self, ts): + self._timestep_to_np = ts + + def _epoch_filename(self, base_filename, epoch): + """Helper function: file name for saving the given epoch.""" + # If base is /foo/task.pkl, we save epoch 1 under /foo/task_epoch1.pkl. + filename, ext = os.path.splitext(base_filename) + return filename + "_epoch" + str(epoch) + ext + + def set_n_replay_epochs(self, n_replay_epochs): + self._n_replay_epochs = n_replay_epochs + + def load_initial_trajectories_from_path( + self, + initial_trajectories_path, + dictionary_file="trajectories.pkl", + start_epoch_to_load=0, + ): + """Initialize trajectories task from file.""" + # We assume that this is a dump generated by Trax + dictionary_file = os.path.join(initial_trajectories_path, dictionary_file) + dictionary = training.unpickle_from_file(dictionary_file, gzip=False) + # TODO(henrykm): as currently implemented this accesses only + # at most the last n_replay_epochs - this should be more flexible + epochs_to_load = dictionary["all_epochs"][start_epoch_to_load:] + + all_trajectories = [] + for epoch in epochs_to_load: + trajectories = training.unpickle_from_file( + self._epoch_filename(dictionary_file, epoch), gzip=True + ) + all_trajectories += trajectories + return all_trajectories + + def init_from_file(self, file_name): + """Initialize this task from file.""" + dictionary = training.unpickle_from_file(file_name, gzip=False) + self._n_trajectories = dictionary["n_trajectories"] + self._n_interactions = dictionary["n_interactions"] + self._max_steps = dictionary["max_steps"] + self._gamma = dictionary["gamma"] + epochs_to_load = dictionary["all_epochs"][-self._n_replay_epochs :] + + for epoch in epochs_to_load: + trajectories = training.unpickle_from_file( + self._epoch_filename(file_name, epoch), gzip=True + ) + self._trajectories[epoch] = trajectories + self._saved_epochs_unchanged = epochs_to_load + + def save_to_file(self, file_name): + """Save this task to file.""" + # Save trajectories from new epochs first. + epochs_to_save = [ + e for e in self._trajectories if e not in self._saved_epochs_unchanged + ] + for epoch in epochs_to_save: + training.pickle_to_file( + self._trajectories[epoch], + self._epoch_filename(file_name, epoch), + gzip=True, + ) + # Now save the list of epochs (so the trajectories are already there, + # even in case of preemption). 
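+        # The resulting on-disk layout (with a hypothetical base name) is:
+        #   task.pkl         - the metadata dictionary written below
+        #   task_epoch1.pkl  - gzipped trajectory list for epoch 1, named by
+        #   task_epoch2.pkl    _epoch_filename() above, one file per epoch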
+        dictionary = {
+            "n_interactions": self._n_interactions,
+            "n_trajectories": self._n_trajectories,
+            "max_steps": self._max_steps,
+            "gamma": self._gamma,
+            "all_epochs": list(self._trajectories.keys()),
+        }
+        training.pickle_to_file(dictionary, file_name, gzip=False)
+
+    def play(self, policy, max_steps=None, only_eval=False):
+        """Play an episode in env taking actions according to the given policy."""
+        if max_steps is None:
+            max_steps = self._max_steps
+        if only_eval:
+            cur_trajectory = play(
+                self._eval_env,
+                policy,
+                self._dm_suite,
+                # Only step up to the time limit.
+                max_steps=min(max_steps, self._time_limit),
+                # Always reset at the beginning of an eval episode.
+                last_observation=None,
+            )
+        else:
+            cur_trajectory = play(
+                self._env,
+                policy,
+                self._dm_suite,
+                # Only step up to the time limit, used up by all slices of the same
+                # trajectory.
+                max_steps=min(max_steps, self._n_steps_left),
+                # Pass the environment state between play() calls, so one episode can
+                # span multiple training epochs.
+                # NOTE: Cutting episodes between epochs may hurt e.g. with
+                # Transformers.
+                # TODO(pkozakowski): Join slices together if this becomes a problem.
+                last_observation=self._last_observation,
+            )
+        # Update the number of steps left to reach the time limit.
+        # NOTE: This should really be done as an env wrapper.
+        # TODO(pkozakowski): Do that once we wrap the DM envs in Gym interface.
+        # The initial reset doesn't count.
+        self._n_steps_left -= len(cur_trajectory) - 1
+        assert self._n_steps_left >= 0
+        if self._n_steps_left == 0:
+            cur_trajectory.done = True
+        # Pass the last observation between trajectory slices.
+        if cur_trajectory.done:
+            self._last_observation = None
+            # Reset the time limit.
+            self._n_steps_left = self._time_limit
+        else:
+            self._last_observation = cur_trajectory.last_observation
+
+        cur_trajectory.calculate_returns(self._gamma)
+        return cur_trajectory
+
+    def collect_trajectories(
+        self,
+        policy,
+        n_trajectories=None,
+        n_interactions=None,
+        only_eval=False,
+        max_steps=None,
+        epoch_id=1,
+    ):
+        """Collect experience in env playing the given policy."""
+        max_steps = max_steps or self.max_steps
+        if n_trajectories is not None:
+            new_trajectories = [
+                self.play(policy, max_steps=max_steps, only_eval=only_eval)
+                for _ in range(n_trajectories)
+            ]
+        elif n_interactions is not None:
+            new_trajectories = []
+            while n_interactions > 0:
+                traj = self.play(policy, max_steps=min(n_interactions, max_steps))
+                new_trajectories.append(traj)
+                n_interactions -= len(traj) - 1  # The initial reset does not count.
+        else:
+            raise ValueError("Either n_trajectories or n_interactions must be defined.")
+
+        # Calculate returns.
+        returns = [t.total_return for t in new_trajectories]
+        if returns:
+            mean_returns = sum(returns) / float(len(returns))
+        else:
+            mean_returns = 0
+
+        # If we're only evaluating, we're done; return the average.
+        if only_eval:
+            return mean_returns
+        # Store new trajectories.
+        if new_trajectories:
+            self._trajectories[epoch_id].extend(new_trajectories)
+
+        # Mark that epoch epoch_id has changed.
+ if epoch_id in self._saved_epochs_unchanged: + self._saved_epochs_unchanged = [ + e for e in self._saved_epochs_unchanged if e != epoch_id + ] + + # Remove epochs not intended to be in the buffer + current_trajectories = { + key: value + for key, value in self._trajectories.items() + if key >= epoch_id - self._n_replay_epochs + } + self._trajectories = collections.defaultdict(list) + self._trajectories.update(current_trajectories) + + # Update statistics. + self._n_trajectories += len(new_trajectories) + self._n_interactions += sum([len(traj) for traj in new_trajectories]) + + return mean_returns + + def n_trajectories(self, epochs=None): + # TODO(henrykm) support selection of epochs if really necessary (will + # require a dump of a list of lengths in save_to_file + del epochs + return self._n_trajectories + + def n_interactions(self, epochs=None): + # TODO(henrykm) support selection of epochs if really necessary (will + # require a dump of a list of lengths in save_to_file + del epochs + return self._n_interactions + + def _random_slice(self, trajectory, max_slice_length, margin): + """Returns a random TimeStepBatch slice from a given trajectory.""" + # Sample a slice from the trajectory. + slice_start = np.random.randint(_n_slices(trajectory, max_slice_length, margin)) + + # Convert the whole trajectory to Numpy while adding the margin. The + # result is cached, so we don't have to repeat this for every sample. + timestep_batch = trajectory.to_np(margin, self._timestep_to_np) + + # Slice and yield the result. + slice_end = slice_start + ( + max_slice_length or timestep_batch.observation.shape[0] + ) + return fastmath.nested_map(lambda x: x[slice_start:slice_end], timestep_batch) + + def _trajectory_stream( + self, + epochs=None, + max_slice_length=None, + sample_trajectories_uniformly=False, + margin=0, + ): + """Return a stream of random trajectory slices from the specified epochs. + + Args: + epochs: a list of epochs to use; we use all epochs if None + max_slice_length: maximum length of the slices of trajectories to return + sample_trajectories_uniformly: whether to sample trajectories uniformly, + or proportionally to the number of slices in each trajectory (default) + margin: number of extra steps after "done" that should be included in + slices, so that networks see the terminal states in the training data + + Yields: + random trajectory slices sampled uniformly from all slices of length + up to max_slice_length in all specified epochs + """ + # {int: array[int]}; + # epoch_to_ns_slices[epoch][i] = n_slices(self._trajectories[epoch][i]) + # It stores arrays for faster sampling. + epoch_to_ns_slices = {} + # {int: int}; + # epoch_to_total_n_slices[epoch] = sum(epoch_to_ns_slices[epoch]) + epoch_to_total_n_slices = {} + # [int]: list of epoch indices to sample from. + epoch_indices = [] + # epoch_to_total_n_slices filtered using epoch_indices. It's an array for + # faster sampling. + sampling_epoch_weights = None + + def new_epoch(epoch_id): + """Updates the lists defined above to include the new epoch.""" + all_epochs = list(self._trajectories.keys()) + max_epoch = max(all_epochs) + 1 + + # Calculate the numbers of slices for the new epoch. + epoch_to_ns_slices[epoch_id] = np.array( + [ + _n_slices(trajectory, max_slice_length, margin) + for trajectory in self._trajectories[epoch_id] + ] + ) + epoch_to_total_n_slices[epoch_id] = np.sum(epoch_to_ns_slices[epoch_id]) + + # Update the indices of epochs to sample from. 
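+            # Example (hypothetical): with epochs=[-1] and recorded epochs
+            # {0, 1, 2}, max_epoch == 3 and -1 % 3 == 2, so only the most
+            # recent epoch is sampled (assuming it contains trajectories).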
+ new_epoch_indices = epochs or all_epochs + new_epoch_indices = [ + # So -1 means "last". + ep % max_epoch + for ep in new_epoch_indices + ] + # Remove duplicates and consider only epochs where some trajectories + # were recorded and that we have processed in new_epoch. + new_epoch_indices = [ + epoch_id + for epoch_id in set(new_epoch_indices) + if self._trajectories[epoch_id] and epoch_id in epoch_to_ns_slices + ] + epoch_indices[:] = new_epoch_indices + + nonlocal sampling_epoch_weights + sampling_epoch_weights = np.array( + list(epoch_to_total_n_slices[ep] for ep in epoch_indices) + ) + + while True: + # If we haven't collected any trajectories yet, yield an example + # trajectory. It's needed to determine the input/output shapes of + # networks. + if not self._trajectories: + yield self._example_trajectory + continue + + # Catch up if we have a new epoch or we've restarted the experiment. + for epoch_id in ( + self._trajectories.keys() - epoch_to_ns_slices.keys() + ): # pylint:disable=g-builtin-op + new_epoch(epoch_id) + + # Sample an epoch proportionally to number of slices in each epoch. + epoch_id = _sample_proportionally(epoch_indices, sampling_epoch_weights) + epoch = self._trajectories[epoch_id] + + # Sample a trajectory proportionally to number of slices in each one. + if sample_trajectories_uniformly: + slices_per_trajectory = np.ones((len(epoch),)) + else: + slices_per_trajectory = epoch_to_ns_slices[epoch_id] + trajectory = _sample_proportionally(epoch, slices_per_trajectory) + + yield trajectory + + def trajectory_slice_stream( + self, + epochs=None, + max_slice_length=None, + sample_trajectories_uniformly=False, + margin=0, + trajectory_stream_preprocessing_fn=None, + ): + """Return a stream of random trajectory slices from the specified epochs. + + Args: + epochs: a list of epochs to use; we use all epochs if None + max_slice_length: maximum length of the slices of trajectories to return + sample_trajectories_uniformly: whether to sample trajectories uniformly, + or proportionally to the number of slices in each trajectory (default) + margin: number of extra steps after "done" that should be included in + slices, so that networks see the terminal states in the training data + trajectory_stream_preprocessing_fn: function to apply to the trajectory + stream before batching; can be used e.g. to filter trajectories + + Yields: + random trajectory slices sampled uniformly from all slices of length + up to max_slice_length in all specified epochs + """ + trajectory_stream = self._trajectory_stream( + epochs=epochs, + max_slice_length=max_slice_length, + sample_trajectories_uniformly=sample_trajectories_uniformly, + margin=margin, + ) + + if trajectory_stream_preprocessing_fn is not None: + trajectory_stream = trajectory_stream_preprocessing_fn(trajectory_stream) + + for trajectory in trajectory_stream: + yield self._random_slice(trajectory, max_slice_length, margin) + + def trajectory_batch_stream( + self, + batch_size, + epochs=None, + max_slice_length=None, + min_slice_length=None, + margin=0, + sample_trajectories_uniformly=False, + trajectory_stream_preprocessing_fn=None, + ): + """Return a stream of trajectory batches from the specified epochs. + + This function returns a stream of tuples of numpy arrays (tensors). + If tensors have different lengths, they will be padded by 0. 
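+        For example, slices of lengths 3 and 5 in one batch are both
+        zero-padded to length 8, the smallest power of 2 that fits the longest
+        slice (see `pad` below).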
+ + Args: + batch_size: the size of the batches to return + epochs: a list of epochs to use; we use all epochs if None + max_slice_length: maximum length of the slices of trajectories to return + min_slice_length: minimum length of the slices of trajectories to return + margin: number of extra steps after "done" that should be included in + slices, so that networks see the terminal states in the training data + sample_trajectories_uniformly: whether to sample trajectories uniformly, + or proportionally to the number of slices in each trajectory (default) + trajectory_stream_preprocessing_fn: function to apply to the trajectory + stream before batching; can be used e.g. to filter trajectories + + Yields: + batches of trajectory slices sampled uniformly from all slices of length + at least min_slice_length and up to max_slice_length in all specified + epochs + """ + + def pad(tensor_list): + # Replace Nones with valid tensors. + not_none_tensors = [t for t in tensor_list if t is not None] + assert not_none_tensors, "All tensors to pad are None." + prototype = np.zeros_like(not_none_tensors[0]) + tensor_list = [t if t is not None else prototype for t in tensor_list] + + max_len = max([t.shape[0] for t in tensor_list]) + if min_slice_length is not None: + max_len = max(max_len, min_slice_length) + min_len = min([t.shape[0] for t in tensor_list]) + if max_len == min_len: # No padding needed. + return np.array(tensor_list) + + pad_len = 2 ** int(np.ceil(np.log2(max_len))) + return np.array( + [_zero_pad(t, (0, pad_len - t.shape[0]), axis=0) for t in tensor_list] + ) + + trajectory_slice_stream = self.trajectory_slice_stream( + epochs=epochs, + max_slice_length=max_slice_length, + sample_trajectories_uniformly=sample_trajectories_uniformly, + margin=margin, + trajectory_stream_preprocessing_fn=trajectory_stream_preprocessing_fn, + ) + + cur_batch = [] + for t in trajectory_slice_stream: + cur_batch.append(t) + if len(cur_batch) == batch_size: + # Make a nested TimeStepBatch of lists out of a list of TimeStepBatches. + timestep_batch = fastmath.nested_zip(cur_batch) + # Actions, rewards and returns in the trajectory slice have shape + # [batch_size, trajectory_length], which we denote as [B, L]. + # Observations are more complex: [B, L] + S, where S is the shape of the + # observation space (self.observation_space.shape). + # We stop the recursion at level 1, so we pass lists of arrays into + # pad(). + yield fastmath.nested_map(pad, timestep_batch, level=1) + cur_batch = [] diff --git a/trax/learning/reinforcement/training.py b/trax/learning/reinforcement/training.py new file mode 100644 index 000000000..da8cbd136 --- /dev/null +++ b/trax/learning/reinforcement/training.py @@ -0,0 +1,1216 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Classes for RL training in Trax."""
+
+import contextlib
+import functools
+import os
+import pickle
+import time
+
+import gin
+import numpy as np
+import tensorflow as tf
+
+from trax import data, fastmath, models
+from trax import layers as tl
+from trax.fastmath import numpy as jnp
+from trax.learning import supervised
+from trax.learning.reinforcement import (
+    advantages,
+    distributions,
+    policy_tasks,
+)
+from trax.learning.reinforcement import task as rl_task
+from trax.learning.supervised import lr_schedules as lr
+from trax.optimizers import adam
+from trax.utils import jaxboard, shapes
+
+
+class Agent:
+    """Abstract class for RL agents, presenting the required API."""
+
+    def __init__(
+        self,
+        task: rl_task.RLTask,
+        n_trajectories_per_epoch=None,
+        n_interactions_per_epoch=None,
+        n_eval_episodes=0,
+        eval_steps=None,
+        eval_temperatures=(0.0,),
+        only_eval=False,
+        output_dir=None,
+        timestep_to_np=None,
+    ):
+        """Configures the Agent.
+
+        Note that subclasses can have many more arguments, which will be configured
+        using defaults and gin. But task and output_dir are passed explicitly.
+
+        Args:
+          task: RLTask instance, which defines the environment to train on.
+          n_trajectories_per_epoch: How many new trajectories to collect in each
+            epoch.
+          n_interactions_per_epoch: How many interactions to collect in each epoch.
+          n_eval_episodes: Number of episodes to play with policy at
+            temperature 0 in each epoch -- used for evaluation only.
+          eval_steps: an optional list of max_steps to use for evaluation
+            (defaults to task.max_steps).
+          eval_temperatures: we always train with temperature 1 and evaluate with
+            the temperatures specified in the eval_temperatures list
+            (defaults to `(0.0,)`).
+          only_eval: If set to True, then trajectories are collected only for
+            evaluation purposes, but they are not recorded.
+          output_dir: Path telling where to save outputs such as checkpoints.
+          timestep_to_np: Timestep-to-numpy function to override in the task.
+        """
+        if (n_trajectories_per_epoch is None) == (n_interactions_per_epoch is None):
+            raise ValueError(
+                "Exactly one of n_trajectories_per_epoch or "
+                "n_interactions_per_epoch should be specified."
+            )
+        self._epoch = 0
+        self._task = task
+        self._eval_steps = eval_steps or [task.max_steps]
+        if timestep_to_np is not None:
+            self._task.timestep_to_np = timestep_to_np
+        self._n_trajectories_per_epoch = n_trajectories_per_epoch
+        self._n_interactions_per_epoch = n_interactions_per_epoch
+        self._only_eval = only_eval
+        self._output_dir = output_dir
+        self._avg_returns = []
+        self._n_eval_episodes = n_eval_episodes
+        self._eval_temperatures = eval_temperatures
+        self._avg_returns_temperatures = {
+            eval_t: {step: [] for step in self._eval_steps}
+            for eval_t in eval_temperatures
+        }
+        if self._output_dir is not None:
+            self.init_from_file()
+
+    @property
+    def current_epoch(self):
+        """Returns the current epoch number in this training session."""
+        return self._epoch
+
+    @property
+    def task(self):
+        """Returns the task."""
+        return self._task
+
+    @property
+    def avg_returns(self):
+        return self._avg_returns
+
+    def save_gin(self, summary_writer=None):
+        assert self._output_dir is not None
+        config_path = os.path.join(self._output_dir, "config.gin")
+        config_str = gin.operative_config_str()
+        with tf.io.gfile.GFile(config_path, "w") as f:
+            f.write(config_str)
+        if summary_writer is not None:
+            summary_writer.text(
+                "gin_config", jaxboard.markdownify_operative_config_str(config_str)
+            )
+
+    def save_to_file(
+        self, file_name="reinforcement.pkl", task_file_name="trajectories.pkl"
+    ):
+        """Save current epoch number and average returns to file."""
+        assert self._output_dir is not None
+        task_path = os.path.join(self._output_dir, task_file_name)
+        self._task.save_to_file(task_path)
+        file_path = os.path.join(self._output_dir, file_name)
+        dictionary = {"epoch": self._epoch, "avg_returns": self._avg_returns}
+        for eval_t in self._eval_temperatures:
+            dictionary[
+                "avg_returns_temperature_{}".format(eval_t)
+            ] = self._avg_returns_temperatures[eval_t]
+        with tf.io.gfile.GFile(file_path, "wb") as f:
+            pickle.dump(dictionary, f)
+
+    def init_from_file(
+        self, file_name="reinforcement.pkl", task_file_name="trajectories.pkl"
+    ):
+        """Initialize epoch number and average returns from file."""
+        assert self._output_dir is not None
+        task_path = os.path.join(self._output_dir, task_file_name)
+        if tf.io.gfile.exists(task_path):
+            self._task.init_from_file(task_path)
+        file_path = os.path.join(self._output_dir, file_name)
+        if not tf.io.gfile.exists(file_path):
+            return
+        with tf.io.gfile.GFile(file_path, "rb") as f:
+            dictionary = pickle.load(f)
+        self._epoch = dictionary["epoch"]
+        self._avg_returns = dictionary["avg_returns"]
+        for eval_t in self._eval_temperatures:
+            self._avg_returns_temperatures[eval_t] = dictionary[
+                "avg_returns_temperature_{}".format(eval_t)
+            ]
+
+    def _collect_trajectories(self):
+        return self.task.collect_trajectories(
+            self.policy,
+            n_trajectories=self._n_trajectories_per_epoch,
+            n_interactions=self._n_interactions_per_epoch,
+            only_eval=self._only_eval,
+            epoch_id=self._epoch,
+        )
+
+    def policy(self, trajectory, temperature=1.0):
+        """Policy function that allows playing with this trainer.
+
+        Args:
+          trajectory: an instance of trax.reinforcement.task.Trajectory
+          temperature: temperature used to sample from the policy (default=1.0)
+
+        Returns:
+          a pair (action, dist_inputs) where action is the action taken and
+          dist_inputs are the parameters of the policy distribution that will
+          later be used for training.
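+
+        Example (sketch, assuming a concrete subclass that implements this
+        method):
+
+          action, dist_inputs = agent.policy(trajectory, temperature=0.5)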
+ """ + raise NotImplementedError + + def train_epoch(self): + """Trains this Agent for one epoch -- main RL logic goes here.""" + raise NotImplementedError + + @contextlib.contextmanager + def _open_summary_writer(self): + """Opens the Jaxboard summary writer wrapped by a context manager. + + Yields: + A Jaxboard summary writer wrapped in a GeneratorContextManager object. + Elements of the lists correspond to the training and evaluation task + directories created during initialization. If there is no output_dir + provided, yields None. + """ + if self._output_dir is not None: + writer = jaxboard.SummaryWriter( + os.path.join(self._output_dir, "reinforcement") + ) + try: + yield writer + finally: + writer.close() + else: + yield None + + def run(self, n_epochs=1, n_epochs_is_total_epochs=False): + """Runs this loop for n epochs. + + Args: + n_epochs: Stop training after completing n steps. + n_epochs_is_total_epochs: if True, consider n_epochs as the total + number of epochs to train, including previously trained ones + """ + with self._open_summary_writer() as sw: + n_epochs_to_run = n_epochs + if n_epochs_is_total_epochs: + n_epochs_to_run -= self._epoch + cur_n_interactions = 0 + for _ in range(n_epochs_to_run): + self._epoch += 1 + cur_time = time.time() + avg_return = self._collect_trajectories() + self._avg_returns.append(avg_return) + if self._n_trajectories_per_epoch: + supervised.trainer_lib.log( + "Collecting %d episodes took %.2f seconds." + % (self._n_trajectories_per_epoch, time.time() - cur_time) + ) + else: + supervised.trainer_lib.log( + "Collecting %d interactions took %.2f seconds." + % (self._n_interactions_per_epoch, time.time() - cur_time) + ) + supervised.trainer_lib.log( + "Average return in epoch %d was %.2f." % (self._epoch, avg_return) + ) + if self._n_eval_episodes > 0: + for steps in self._eval_steps: + for eval_t in self._eval_temperatures: + avg_return_temperature = self.task.collect_trajectories( + functools.partial(self.policy, temperature=eval_t), + n_trajectories=self._n_eval_episodes, + max_steps=steps, + only_eval=True, + ) + supervised.trainer_lib.log( + "Eval return in epoch %d with temperature %.2f was %.2f." + % (self._epoch, eval_t, avg_return_temperature) + ) + self._avg_returns_temperatures[eval_t][steps].append( + avg_return_temperature + ) + + if sw is not None: + sw.scalar( + "timing/collect", time.time() - cur_time, step=self._epoch + ) + sw.scalar("reinforcement/avg_return", avg_return, step=self._epoch) + if self._n_eval_episodes > 0: + for steps in self._eval_steps: + for eval_t in self._eval_temperatures: + sw.scalar( + "reinforcement/avg_return_temperature%.2f_steps%d" + % (eval_t, steps), + self._avg_returns_temperatures[eval_t][steps][-1], + step=self._epoch, + ) + sw.scalar( + "reinforcement/n_interactions", + self.task.n_interactions(), + step=self._epoch, + ) + sw.scalar( + "reinforcement/n_interactions_per_second", + (self.task.n_interactions() - cur_n_interactions) + / (time.time() - cur_time), + step=self._epoch, + ) + cur_n_interactions = self.task.n_interactions() + sw.scalar( + "reinforcement/n_trajectories", + self.task.n_trajectories(), + step=self._epoch, + ) + sw.flush() + + cur_time = time.time() + self.train_epoch() + supervised.trainer_lib.log( + "RL training took %.2f seconds." 
+                % (time.time() - cur_time)
+            )
+
+            if self._output_dir is not None and self._epoch == 1:
+                self.save_gin(sw)
+            if self._output_dir is not None:
+                self.save_to_file()
+
+    def close(self):
+        pass
+
+
+class PolicyAgent(Agent):
+    """Agent that uses a deep learning model for policy.
+
+    Many deep RL methods, such as policy gradient (REINFORCE) or actor-critic,
+    fall into this category, so a lot of classes will be subclasses of this one.
+    But some methods only have a value or Q function; these are different.
+    """
+
+    def __init__(
+        self,
+        task,
+        policy_model=None,
+        policy_optimizer=None,
+        policy_lr_schedule=lr.multifactor,
+        policy_batch_size=64,
+        policy_train_steps_per_epoch=500,
+        policy_evals_per_epoch=1,
+        policy_eval_steps=1,
+        n_eval_episodes=0,
+        only_eval=False,
+        max_slice_length=1,
+        output_dir=None,
+        **kwargs,
+    ):
+        """Configures the policy trainer.
+
+        Args:
+          task: RLTask instance, which defines the environment to train on.
+          policy_model: Trax layer, representing the policy model. Loss
+            functions and eval functions (a.k.a. metrics) are considered to be
+            outside the core model, taking core model output and data labels as
+            their two inputs.
+          policy_optimizer: the optimizer to use to train the policy model.
+          policy_lr_schedule: learning rate schedule to use to train the policy.
+          policy_batch_size: batch size used to train the policy model.
+          policy_train_steps_per_epoch: how long to train the policy in each RL
+            epoch.
+          policy_evals_per_epoch: number of policy trainer evaluations per RL epoch
+            - only affects metric reporting.
+          policy_eval_steps: number of policy trainer steps per evaluation - only
+            affects metric reporting.
+          n_eval_episodes: number of episodes to play with policy at
+            temperature 0 in each epoch -- used for evaluation only
+          only_eval: If set to True, then trajectories are collected only for
+            evaluation purposes, but they are not recorded.
+          max_slice_length: the maximum length of trajectory slices to use.
+          output_dir: Path telling where to save outputs (evals and checkpoints).
+          **kwargs: arguments for the superclass Agent.
+        """
+        super().__init__(
+            task, n_eval_episodes=n_eval_episodes, output_dir=output_dir, **kwargs
+        )
+        self._policy_batch_size = policy_batch_size
+        self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
+        self._policy_evals_per_epoch = policy_evals_per_epoch
+        self._policy_eval_steps = policy_eval_steps
+        self._only_eval = only_eval
+        self._max_slice_length = max_slice_length
+        self._policy_dist = distributions.create_distribution(task.action_space)
+
+        # Inputs to the policy model are produced by self.policy_batches_stream.
+        self._policy_inputs = data.inputs.Inputs(
+            train_stream=lambda _: self.policy_batches_stream()
+        )
+
+        policy_model = functools.partial(
+            policy_model,
+            policy_distribution=self._policy_dist,
+        )
+
+        # This is the policy Trainer that will be used to train the policy model.
+        # * inputs to the trainer come from self.policy_batches_stream
+        # * outputs, targets and weights are passed to self.policy_loss
+        self._policy_trainer = supervised.Trainer(
+            model=policy_model,
+            optimizer=policy_optimizer,
+            lr_schedule=policy_lr_schedule(),
+            loss_fn=self.policy_loss,
+            inputs=self._policy_inputs,
+            output_dir=output_dir,
+            metrics=self.policy_metrics,
+        )
+        self._policy_collect_model = tl.Accelerate(
+            policy_model(mode="collect"), n_devices=1
+        )
+        policy_batch = next(self.policy_batches_stream())
+        self._policy_collect_model.init(shapes.signature(policy_batch))
+        self._policy_eval_model = tl.Accelerate(
+            policy_model(mode="eval"), n_devices=1
+        )  # Not collecting stats
+        self._policy_eval_model.init(shapes.signature(policy_batch))
+
+    @property
+    def policy_loss(self):
+        """Policy loss."""
+        raise NotImplementedError
+
+    @property
+    def policy_metrics(self):
+        return {"policy_loss": self.policy_loss}
+
+    def policy_batches_stream(self):
+        """Use self.task to create inputs to the policy model."""
+        raise NotImplementedError
+
+    def policy(self, trajectory, temperature=1.0):
+        """Chooses an action to play after a trajectory."""
+        model = self._policy_collect_model
+        if temperature != 1.0:  # When evaluating (t != 1.0), don't collect stats
+            model = self._policy_eval_model
+            model.state = self._policy_collect_model.state
+        model.replicate_weights(self._policy_trainer.model_weights)
+        tr_slice = trajectory.suffix(self._max_slice_length)
+        trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
+        # Add batch dimension to trajectory_np and run the model.
+        pred = model(trajectory_np.observation[None, ...])
+        # Pick element 0 from the batch (the only one), last (current) timestep.
+        pred = pred[0, -1, :]
+        sample = self._policy_dist.sample(pred, temperature=temperature)
+        result = (sample, pred)
+        if fastmath.is_backend(fastmath.Backend.JAX):
+            result = fastmath.nested_map(lambda x: x.copy(), result)
+        return result
+
+    def train_epoch(self):
+        """Trains RL for one epoch."""
+        # When restoring, calculate how many evals are remaining.
+        n_evals = remaining_evals(
+            self._policy_trainer.step,
+            self._epoch,
+            self._policy_train_steps_per_epoch,
+            self._policy_evals_per_epoch,
+        )
+        for _ in range(n_evals):
+            self._policy_trainer.train_epoch(
+                self._policy_train_steps_per_epoch // self._policy_evals_per_epoch,
+                self._policy_eval_steps,
+            )
+
+    def close(self):
+        self._policy_trainer.close()
+        super().close()
+
+
+def remaining_evals(cur_step, epoch, train_steps_per_epoch, evals_per_epoch):
+    """Helper function to calculate remaining evaluations for a trainer.
+
+    Args:
+      cur_step: current step of the supervised trainer
+      epoch: current epoch of the RL trainer
+      train_steps_per_epoch: supervised trainer steps per RL epoch
+      evals_per_epoch: supervised trainer evals per RL epoch
+
+    Returns:
+      number of remaining evals to do this epoch
+
+    Raises:
+      ValueError if the provided numbers indicate a step mismatch
+    """
+    if epoch < 1:
+        raise ValueError("Epoch must be at least 1, got %d" % epoch)
+    prev_steps = (epoch - 1) * train_steps_per_epoch
+    done_steps_this_epoch = cur_step - prev_steps
+    if done_steps_this_epoch < 0:
+        raise ValueError(
+            "Current step (%d) < previously done steps (%d)." % (cur_step, prev_steps)
+        )
+    train_steps_per_eval = train_steps_per_epoch // evals_per_epoch
+    if done_steps_this_epoch % train_steps_per_eval != 0:
+        raise ValueError(
+            "Done steps (%d) must be divisible by train steps per eval (%d)."
+ % (done_steps_this_epoch, train_steps_per_eval) + ) + return evals_per_epoch - (done_steps_this_epoch // train_steps_per_eval) + + +class LoopPolicyAgent(Agent): + """Base class for policy-only Agents based on Loop.""" + + def __init__( + self, + task, + model_fn, + value_fn, + weight_fn, + n_replay_epochs, + n_train_steps_per_epoch, + advantage_normalization, + optimizer=adam.Adam, + lr_schedule=lr.multifactor, + batch_size=64, + network_eval_at=None, + n_eval_batches=1, + max_slice_length=1, + trajectory_stream_preprocessing_fn=None, + **kwargs, + ): + """Initializes LoopPolicyAgent. + + Args: + task: Instance of trax.reinforcement.task.RLTask. + model_fn: Function (policy_distribution, mode) -> policy_model. + value_fn: Function TimeStepBatch -> array (batch_size, seq_len) + calculating the baseline for advantage calculation. + weight_fn: Function float -> float to apply to advantages when calculating + policy loss. + n_replay_epochs: Number of last epochs to take into the replay buffer; + only makes sense for off-policy algorithms. + n_train_steps_per_epoch: Number of steps to train the policy network for + in each epoch. + advantage_normalization: Whether to normalize the advantages before + passing them to weight_fn. + optimizer: Optimizer for network training. + lr_schedule: Learning rate schedule for network training. + batch_size: Batch size for network training. + network_eval_at: Function step -> bool indicating the training steps, when + network evaluation should be performed. + n_eval_batches: Number of batches to run during network evaluation. + max_slice_length: The length of trajectory slices to run the network on. + trajectory_stream_preprocessing_fn: Function to apply to the trajectory + stream before batching. Can be used e.g. to filter trajectories. + **kwargs: Keyword arguments passed to the superclass. + """ + self._n_train_steps_per_epoch = n_train_steps_per_epoch + super().__init__(task, **kwargs) + + task.set_n_replay_epochs(n_replay_epochs) + self._max_slice_length = max_slice_length + trajectory_batch_stream = task.trajectory_batch_stream( + batch_size, + epochs=[-(ep + 1) for ep in range(n_replay_epochs)], + max_slice_length=self._max_slice_length, + sample_trajectories_uniformly=True, + trajectory_stream_preprocessing_fn=trajectory_stream_preprocessing_fn, + ) + self._policy_dist = distributions.create_distribution(task.action_space) + train_task = policy_tasks.PolicyTrainTask( + trajectory_batch_stream, + optimizer(), + lr_schedule(), + self._policy_dist, + # Without a value network it doesn't make a lot of sense to use + # a better advantage estimator than MC. + advantage_estimator=advantages.monte_carlo(task.gamma, margin=0), + advantage_normalization=advantage_normalization, + value_fn=value_fn, + weight_fn=weight_fn, + ) + eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches) + model_fn = functools.partial( + model_fn, + policy_distribution=self._policy_dist, + ) + + if self._output_dir is not None: + policy_output_dir = os.path.join(self._output_dir, "policy") + else: + policy_output_dir = None + # Checkpoint every epoch. 
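+        # (For example, with n_train_steps_per_epoch=1000 the lambda below
+        # fires at every Loop step that is a multiple of 1000, i.e. once per
+        # Agent epoch.)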
+ checkpoint_at = lambda step: step % n_train_steps_per_epoch == 0 + self._loop = supervised.training.Loop( + model=model_fn(mode="train"), + tasks=[train_task], + eval_model=model_fn(mode="eval"), + eval_tasks=[eval_task], + output_dir=policy_output_dir, + eval_at=network_eval_at, + checkpoint_at=checkpoint_at, + ) + self._collect_model = model_fn(mode="collect") + self._collect_model.init(shapes.signature(train_task.sample_batch)) + + # Validate the restored checkpoints. + # TODO(pkozakowski): Move this to the base class once all Agents use Loop. + if self._loop.step != self._epoch * self._n_train_steps_per_epoch: + raise ValueError( + "The number of Loop steps must equal the number of Agent epochs " + "times the number of steps per epoch, got {}, {} and {}.".format( + self._loop.step, self._epoch, self._n_train_steps_per_epoch + ) + ) + + @property + def loop(self): + """Loop exposed for testing.""" + return self._loop + + def train_epoch(self): + """Trains RL for one epoch.""" + # Copy policy state accumulated during data collection to the trainers. + self._loop.update_weights_and_state(state=self._collect_model.state) + # Train for the specified number of steps. + self._loop.run(n_steps=self._n_train_steps_per_epoch) + + +class PolicyGradient(LoopPolicyAgent): + """Trains a policy model using policy gradient on the given RLTask.""" + + def __init__(self, task, model_fn, **kwargs): + """Initializes PolicyGradient. + + Args: + task: Instance of trax.reinforcement.task.RLTask. + model_fn: Function (policy_distribution, mode) -> policy_model. + **kwargs: Keyword arguments passed to the superclass. + """ + super().__init__( + task, + model_fn, + # We're on-policy, so we can only use data from the last epoch. + n_replay_epochs=1, + # Each gradient computation needs a new data sample, so we do 1 step + # per epoch. + n_train_steps_per_epoch=1, + # Very simple baseline: mean return across trajectories. + value_fn=self._value_fn, + # Weights are just advantages. + weight_fn=(lambda x: x), + # Normalize advantages, because this makes optimization nicer. + advantage_normalization=True, + **kwargs, + ) + + def policy(self, trajectory, temperature=1.0): + """Policy function that samples from the trained network.""" + tr_slice = trajectory.suffix(self._max_slice_length) + trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np) + return network_policy( + collect_model=self._collect_model, + policy_distribution=self._policy_dist, + loop=self.loop, + trajectory_np=trajectory_np, + temperature=temperature, + ) + + @staticmethod + def _value_fn(trajectory_batch): + # Estimate the value of every state as the mean return across trajectories + # and timesteps in a batch. + value = np.mean(trajectory_batch.return_) + return np.broadcast_to(value, trajectory_batch.return_.shape) + + +@gin.configurable +def sharpened_network_policy(temperature, temperature_multiplier=1.0, **kwargs): + """Expert function that runs a policy network with lower temperature. + + Args: + temperature: Temperature passed from the Agent. + temperature_multiplier: Multiplier to apply to the temperature to "sharpen" + the policy distribution. Should be <= 1, but this is not a requirement. + **kwargs: Keyword arguments passed to network_policy. + + Returns: + Pair (action, dist_inputs) where action is the action taken and dist_inputs + is the parameters of the policy distribution, that will later be used for + training. 
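+
+    Example (sketch): with temperature_multiplier=0.5, an evaluation call made
+    with temperature=1.0 samples from the policy at an effective temperature
+    of 0.5.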
+ """ + return network_policy(temperature=(temperature_multiplier * temperature), **kwargs) + + +class ExpertIteration(LoopPolicyAgent): + """Trains a policy model using expert iteration with a given expert.""" + + def __init__( + self, + task, + model_fn, + expert_policy_fn=sharpened_network_policy, + quantile=0.9, + n_replay_epochs=10, + n_train_steps_per_epoch=1000, + filter_buffer_size=256, + **kwargs, + ): + """Initializes ExpertIteration. + + Args: + task: Instance of trax.reinforcement.task.RLTask. + model_fn: Function (policy_distribution, mode) -> policy_model. + expert_policy_fn: Function of the same signature as `network_policy`, to + be used as an expert. The policy will be trained to mimic the expert on + the "solved" trajectories. + quantile: Quantile of best trajectories to be marked as "solved". They + will be used to train the policy. + n_replay_epochs: Number of last epochs to include in the replay buffer. + n_train_steps_per_epoch: Number of policy training steps to run in each + epoch. + filter_buffer_size: Number of trajectories in the trajectory filter + buffer, used to select the best trajectories based on the quantile. + **kwargs: Keyword arguments passed to the superclass. + """ + self._expert_policy_fn = expert_policy_fn + self._quantile = quantile + self._filter_buffer_size = filter_buffer_size + super().__init__( + task, + model_fn, + # Don't use a baseline - it's not useful in our weights. + value_fn=(lambda batch: jnp.zeros_like(batch.return_)), + # Don't weight trajectories - the training signal is provided by + # filtering trajectories. + weight_fn=jnp.ones_like, + # Filter trajectories based on the quantile. + trajectory_stream_preprocessing_fn=self._filter_trajectories, + # Advantage normalization is a no-op here. + advantage_normalization=False, + n_replay_epochs=n_replay_epochs, + n_train_steps_per_epoch=n_train_steps_per_epoch, + **kwargs, + ) + + def policy(self, trajectory, temperature=1.0): + """Policy function that runs the expert.""" + tr_slice = trajectory.suffix(self._max_slice_length) + trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np) + return self._expert_policy_fn( + collect_model=self._collect_model, + policy_distribution=self._policy_dist, + loop=self.loop, + trajectory_np=trajectory_np, + temperature=temperature, + ) + + def _filter_trajectories(self, trajectory_stream): + """Filter trajectories based on the quantile.""" + + def trajectory_return(trajectory): + return trajectory.timesteps[0].return_ + + trajectory_buffer = [] + for trajectory in trajectory_stream: + trajectory_buffer.append(trajectory) + if len(trajectory_buffer) == self._filter_buffer_size: + n_best = int((1 - self._quantile) * self._filter_buffer_size) or 1 + trajectory_buffer.sort(key=trajectory_return, reverse=True) + yield from trajectory_buffer[:n_best] + trajectory_buffer.clear() + + +def network_policy( + collect_model, + policy_distribution, + loop, + trajectory_np, + head_index=0, + temperature=1.0, +): + """Policy function powered by a neural network. + + Used to implement Agent.policy() in policy-based agents. + + Args: + collect_model: the model used for collecting trajectories + policy_distribution: an instance of trax.reinforcement.distributions.Distribution + loop: trax.supervised.training.Loop used to train the policy network + trajectory_np: an instance of trax.reinforcement.task.TimeStepBatch + head_index: index of the policy head a multihead model. 
+      temperature: temperature used to sample from the policy (default=1.0)
+
+    Returns:
+      a pair (action, dist_inputs) where action is the action taken and
+      dist_inputs are the parameters of the policy distribution that will later
+      be used for training.
+    """
+    if temperature == 1.0:
+        model = collect_model
+    else:
+        # When evaluating (t != 1.0), use the evaluation model instead of the
+        # collection model - some models accumulate normalization statistics
+        # during data collection, and we don't want to do it in eval to avoid data
+        # leakage.
+        model = loop.eval_model
+        model.state = collect_model.state
+        # Copying weights from loop.model should work, because the raw model's
+        # weights should be updated automatically during training, but it doesn't.
+        # TODO(pkozakowski): Debug.
+        acc = loop._trainer_per_task[
+            0
+        ].accelerated_model_with_loss  # pylint: disable=protected-access
+        model.weights = acc._unreplicate(acc.weights[0])  # pylint: disable=protected-access
+    # Add batch dimension to trajectory_np and run the model.
+    pred = model(trajectory_np.observation[None, ...])
+    if isinstance(pred, (tuple, list)):
+        # For multihead models, extract the policy head output.
+        pred = pred[head_index]
+    assert pred.shape == (
+        1,
+        trajectory_np.observation.shape[0],
+        policy_distribution.n_inputs,
+    )
+    # Pick element 0 from the batch (the only one), last (current) timestep.
+    pred = pred[0, -1, :]
+    sample = policy_distribution.sample(pred, temperature=temperature)
+    result = (sample, pred)
+    if fastmath.is_backend(fastmath.Backend.JAX):
+        # The result is composed of mutable numpy arrays. We copy them to avoid
+        # accidental modification.
+        result = fastmath.nested_map(lambda x: x.copy(), result)
+    return result
+
+
+class ValueAgent(Agent):
+    """Trainer that uses a deep learning model for the value function.
+
+    Computes the loss using variants of the Bellman equation.
+    """
+
+    def __init__(
+        self,
+        task,
+        value_body=None,
+        value_optimizer=None,
+        value_lr_schedule=lr.multifactor,
+        value_batch_size=64,
+        value_train_steps_per_epoch=500,
+        value_evals_per_epoch=1,
+        value_eval_steps=1,
+        exploration_rate=functools.partial(
+            lr.multifactor,
+            factors="constant * decay_every",
+            constant=1.0,  # pylint: disable=redefined-outer-name
+            decay_factor=0.99,
+            steps_per_decay=1,
+            minimum=0.1,
+        ),
+        n_eval_episodes=0,
+        only_eval=False,
+        n_replay_epochs=1,
+        max_slice_length=1,
+        sync_freq=1000,
+        scale_value_targets=True,
+        output_dir=None,
+        **kwargs,
+    ):
+        """Configures the value trainer.
+
+        Args:
+          task: RLTask instance, which defines the environment to train on.
+          value_body: Trax layer, representing the body of the value model. Loss
+            functions and eval functions (a.k.a. metrics) are considered to be
+            outside the core model, taking core model output and data labels as
+            their two inputs.
+          value_optimizer: the optimizer to use to train the value model.
+          value_lr_schedule: learning rate schedule to use to train the value model.
+          value_batch_size: batch size used to train the value model.
+          value_train_steps_per_epoch: how long to train the value model in each RL
+            epoch.
+          value_evals_per_epoch: number of value trainer evaluations per RL epoch
+            - only affects metric reporting.
+          value_eval_steps: number of value trainer steps per evaluation - only
+            affects metric reporting.
+          exploration_rate: exploration rate schedule - used in the policy method.
+ n_eval_episodes: number of episodes to play with policy at
+ temperature 0 in each epoch -- used for evaluation only
+ only_eval: If set to True, then trajectories are collected only for
+ evaluation purposes, but they are not recorded.
+ n_replay_epochs: Number of last epochs to take into the replay buffer;
+ only makes sense for off-policy algorithms.
+ max_slice_length: the maximum length of trajectory slices to use; it is
+ the second dimension of the value network output:
+ (batch, max_slice_length, number of actions)
+ Higher max_slice_length implies that the network has to predict more
+ values into the future.
+ sync_freq: frequency at which to synchronize the target
+ network with the trained network. This is necessary for training the
+ network on bootstrapped targets, e.g. using n-step returns.
+ scale_value_targets: If `True`, scale value function targets by
+ `1 / (1 - gamma)`. We are trying to fix the problem with very large
+ returns in some games in a way which does not introduce an additional
+ hyperparameter.
+ output_dir: Path telling where to save outputs (evals and checkpoints).
+ **kwargs: arguments for the superclass RLTrainer.
+ """
+ super(ValueAgent, self).__init__(
+ task, n_eval_episodes=n_eval_episodes, output_dir=output_dir, **kwargs
+ )
+ self._value_batch_size = value_batch_size
+ self._value_train_steps_per_epoch = value_train_steps_per_epoch
+ self._value_evals_per_epoch = value_evals_per_epoch
+ self._value_eval_steps = value_eval_steps
+ self._only_eval = only_eval
+ self._max_slice_length = max_slice_length
+ self._policy_dist = distributions.create_distribution(task.action_space)
+ self._n_replay_epochs = n_replay_epochs
+
+ self._exploration_rate = exploration_rate()
+ self._sync_at = lambda step: step % sync_freq == 0
+
+ if scale_value_targets:
+ self._value_network_scale = 1 / (1 - self._task.gamma)
+ else:
+ self._value_network_scale = 1
+
+ value_model = functools.partial(
+ models.Quality, body=value_body, n_actions=self.task.action_space.n
+ )
+
+ self._value_eval_model = value_model(mode="eval")
+ self._value_eval_model.init(self._value_model_signature)
+ self._value_eval_jit = tl.jit_forward(
+ self._value_eval_model.pure_fn, fastmath.local_device_count(), do_mean=False
+ )
+
+ # Inputs to the value model are produced by self.value_batches_stream.
+ self._inputs = data.inputs.Inputs(
+ train_stream=lambda _: self.value_batches_stream()
+ )
+
+ # This is the value Trainer that will be used to train the value model.
+ # * inputs to the trainer come from self.value_batches_stream
+ # * outputs, targets and weights are passed to self.value_loss
+ self._value_trainer = supervised.Trainer(
+ model=value_model,
+ optimizer=value_optimizer,
+ lr_schedule=value_lr_schedule(),
+ loss_fn=self.value_loss,
+ inputs=self._inputs,
+ output_dir=output_dir,
+ metrics={
+ "value_loss": self.value_loss,
+ "value_mean": self.value_mean,
+ "returns_mean": self.returns_mean,
+ },
+ )
+ value_batch = next(self.value_batches_stream())
+ self._eval_model = tl.Accelerate(value_model(mode="collect"), n_devices=1)
+ self._eval_model.init(shapes.signature(value_batch))
+
+ @property
+ def _value_model_signature(self):
+ obs_sig = shapes.signature(self._task.observation_space)
+ target_sig = mask_sig = shapes.ShapeDtype(
+ shape=(1, 1, self._task.action_space.n),
+ )
+ inputs_sig = obs_sig.replace(shape=(1, 1) + obs_sig.shape)
+ return (inputs_sig, target_sig, mask_sig)
+
+ def value_batches_stream(self):
+ """Use self.task to create inputs to the value model."""
+ raise NotImplementedError
+
+ def policy(self, trajectory, temperature=1):
+ """Chooses an action to play after a trajectory."""
+ raise NotImplementedError
+
+ def train_epoch(self):
+ """Trains RL for one epoch."""
+ # Update the target value network.
+ self._value_eval_model.weights = self._value_trainer.model_weights
+ self._value_eval_model.state = self._value_trainer.model_state
+
+ # When restoring, calculate how many evals are remaining.
+ n_evals = remaining_evals(
+ self._value_trainer.step,
+ self._epoch,
+ self._value_train_steps_per_epoch,
+ self._value_evals_per_epoch,
+ )
+ for _ in range(n_evals):
+ self._value_trainer.train_epoch(
+ self._value_train_steps_per_epoch // self._value_evals_per_epoch,
+ self._value_eval_steps,
+ )
+ value_metrics = {"exploration_rate": self._exploration_rate(self._epoch)}
+ self._value_trainer.log_metrics(
+ value_metrics, self._value_trainer._train_sw, "dqn"
+ ) # pylint: disable=protected-access
+ # Update the target value network.
+ # TODO(henrykm) a bit tricky if sync_at does not coincide with epochs
+ if self._sync_at(self._value_trainer.step):
+ self._value_eval_model.weights = self._value_trainer.model_weights
+ self._value_eval_model.state = self._value_trainer.model_state
+
+ def close(self):
+ self._value_trainer.close()
+ super().close()
+
+ @property
+ def value_mean(self):
+ """The mean value of actions selected by the behavioral policy."""
+ raise NotImplementedError
+
+ @property
+ def returns_mean(self):
+ """The mean of returns of actions selected by the behavioral policy."""
+
+ def f(values, index_max, returns, mask):
+ del values, index_max
+ return jnp.sum(returns) / jnp.sum(mask)
+
+ return tl.Fn("ReturnsMean", f)
+
+
+class DQN(ValueAgent):
+ r"""Trains a value model using DQN on the given RLTask.
+
+ Notice that the algorithm and the parameters significantly diverge from
+ the original DQN paper. In particular we have separated learning and data
+ collection.
+
+ The Bellman loss is computed in the value_loss method. The formula takes
+ the state-action value tensor Q and n-step returns R:
+
+ .. math::
+ L(s,a) = Q(s,a) - R(s,a)
+
+ where R is computed in value_batches_stream. In the simplest case of the
+ 1-step returns we are getting
+
+ .. math::
+ L(s,a) = Q(s,a) - r(s,a) - gamma * \max_{a'} Q'(s',a')
+
+ where s' is the state reached after taking action a in state s, Q' is
+ the target network, gamma is the discount factor and the maximum is taken
+ with respect to all actions available in the state s'. The tensor Q' is
+ updated using the sync_freq parameter.
+
+ In code the maximum is visible in the policy method where we take
+ sample = jnp.argmax(values). The epsilon-greedy policy takes a random
+ move with probability epsilon; otherwise, in state s it takes the
+ action argmax_a Q(s,a).
+ """
+
+ def __init__(
+ self,
+ task,
+ advantage_estimator=advantages.monte_carlo,
+ max_slice_length=1,
+ smoothl1loss=True,
+ double_dqn=False,
+ **kwargs,
+ ):
+ self._max_slice_length = max_slice_length
+ self._margin = max_slice_length - 1
+ # Our default choice of learning targets for DQN are n-step targets
+ # implemented in the method td_k. We set the slice used for computation
+ # of td_k to max_slice_length and we set the "margin" in td_k
+ # to self._max_slice_length - 1; in turn it implies that the shape of the
+ # returned tensor of n-step targets is
+ # values[:, :-(self.margin)] = values[:, :1]
+ self._advantage_estimator = advantage_estimator(
+ gamma=task.gamma, margin=self._margin
+ )
+ self._smoothl1loss = smoothl1loss
+ self._double_dqn = double_dqn
+ super(DQN, self).__init__(
+ task=task, max_slice_length=max_slice_length, **kwargs
+ )
+
+ @property
+ def value_loss(self):
+ """Value loss computed using smooth L1 loss or L2 loss."""
+
+ def f(values, actions, returns, mask):
+ ind_0, ind_1 = np.indices(actions.shape)
+ # We calculate length using the shape of returns
+ # and adequately remove a superfluous slice of values.
+ # An analogous operation is done in value_batches_stream.
+ length = returns.shape[1]
+ values = values[:, :length, :]
+ selected_values = values[ind_0, ind_1, actions]
+ shapes.assert_same_shape(selected_values, returns)
+ shapes.assert_same_shape(selected_values, mask)
+ if self._smoothl1loss:
+ return tl.SmoothL1Loss().forward((selected_values, returns, mask))
+ else:
+ return tl.L2Loss().forward((selected_values, returns, mask))
+
+ return tl.Fn("ValueLoss", f)
+
+ @property
+ def _replay_epochs(self):
+ return [-(ep + 1) for ep in range(self._n_replay_epochs)]
+
+ def value_batches_stream(self):
+ """Use the RLTask self._task to create inputs to the value model."""
+ max_slice_length = self._max_slice_length
+ min_slice_length = 1
+ for np_trajectory in self._task.trajectory_batch_stream(
+ self._value_batch_size,
+ max_slice_length=max_slice_length,
+ min_slice_length=min_slice_length,
+ margin=0,
+ epochs=self._replay_epochs,
+ ):
+ values_target = self._run_value_model(
+ np_trajectory.observation, use_eval_model=True
+ )
+ if self._double_dqn:
+ values = self._run_value_model(
+ np_trajectory.observation, use_eval_model=False
+ )
+ index_max = np.argmax(values, axis=-1)
+ ind_0, ind_1 = np.indices(index_max.shape)
+ values_max = values_target[ind_0, ind_1, index_max]
+ else:
+ values_max = np.array(jnp.max(values_target, axis=-1))
+
+ # The advantage_estimator returns
+ # gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
+ # - values_max(s_i)
+ # hence we have to add values_max(s_i) in order to get n-step returns:
+ # gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
+ # Notice that in DQN the tensor values_max[:, :-self._margin]
+ # is the same as values_max[:, :-1].
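+ # Illustrative sketch (an assumption, not part of the original comments):
+ # with a TD-style estimator, margin 1 and no terminal states inside the
+ # slice, the sum below reduces to the familiar 1-step DQN target
+ # n_step_returns[:, t] = reward[:, t] + gamma * values_max[:, t + 1],
+ # i.e. the r(s,a) + gamma * max_a' Q'(s',a') quantity from the class
+ # docstring.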
+ n_step_returns = values_max[:, : -self._margin] + self._advantage_estimator(
+ rewards=np_trajectory.reward,
+ returns=np_trajectory.return_,
+ values=values_max,
+ dones=np_trajectory.done,
+ discount_mask=np_trajectory.env_info.discount_mask,
+ )
+
+ length = n_step_returns.shape[1]
+ target_returns = n_step_returns[:, :length]
+ inputs = np_trajectory.observation[:, :length, :]
+
+ yield (
+ # Inputs are observations
+ # (batch, length, obs)
+ inputs,
+ # the max indices will be needed to compute the loss
+ np_trajectory.action[:, :length], # index_max,
+ # Targets: computed returns.
+ # target_returns, we expect here shapes such as
+ # (batch, length, num_actions)
+ target_returns / self._value_network_scale,
+ # TODO(henrykm): mask has the shape (batch, max_slice_length),
+ # that is, it misses the action dimension; the preferred format
+ # would be np_trajectory.mask[:, :length, :] but for now we pass:
+ np.ones(shape=target_returns.shape),
+ )
+
+ def policy(self, trajectory, temperature=1):
+ """Chooses an action to play after a trajectory."""
+ tr_slice = trajectory.suffix(self._max_slice_length)
+ trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
+ # Add batch dimension to trajectory_np and run the model.
+ obs = trajectory_np.observation[None, ...]
+ values = self._run_value_model(obs, use_eval_model=False)
+ # We insist that values and observations have the shape
+ # (batch, length, ...), where the length is the number of subsequent
+ # observations on a given trajectory
+ assert values.shape[:1] == obs.shape[:1]
+ # We select the last element in the batch and the value
+ # related to the last (current) observation
+ values = values[0, -1, :]
+ # temperature == 0 is used in another place in order to trigger eval
+ if (
+ np.random.random_sample() < self._exploration_rate(self._epoch)
+ and temperature == 1
+ ):
+ sample = np.array(self.task.action_space.sample())
+ else:
+ # this is our way of doing the argmax
+ sample = jnp.argmax(values)
+ result = (sample, values)
+ if fastmath.backend_name() == "jax":
+ result = fastmath.nested_map(lambda x: x.copy(), result)
+ return result
+
+ def _run_value_model(self, obs, use_eval_model=True):
+ """Runs the value model."""
+ n_devices = fastmath.local_device_count()
+ if use_eval_model:
+ weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
+ state = tl.for_n_devices(self._value_eval_model.state, n_devices)
+ rng = self._value_eval_model.rng
+ else:
+ # TODO(henrykm): this strange-looking solution addresses the problem
+ # that value_batches_stream calls _run_value_model _once_ before
+ # the trainer is initialized.
+ try:
+ weights = tl.for_n_devices(self._value_trainer.model_weights, n_devices)
+ state = tl.for_n_devices(self._value_trainer.model_state, n_devices)
+ rng = self._value_trainer._rng # pylint: disable=protected-access
+ except AttributeError:
+ weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
+ state = tl.for_n_devices(self._value_eval_model.state, n_devices)
+ rng = self._value_eval_model.rng
+ # TODO(henrykm): the line below fails on TPU with the error
+ # ValueError: Number of devices (8) does not evenly divide batch size (1).
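+ # When there are more devices than examples, the batch is repeated below
+ # so that the device count divides the batch size; the jitted model runs
+ # on the repeated batch and the outputs are trimmed back afterwards.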
+ obs_batch = obs.shape[0]
+ if n_devices > obs_batch:
+ obs = jnp.repeat(obs, int(n_devices / obs_batch), axis=0)
+ values, _ = self._value_eval_jit(obs, weights, state, rng)
+ values = values[:obs_batch]
+ values *= self._value_network_scale
+ return values
+
+ @property
+ def value_mean(self):
+ """The mean value of actions selected by the behavioral policy."""
+
+ def f(values, actions, returns, mask):
+ ind_0, ind_1 = np.indices(actions.shape)
+ # We calculate length using the shape of returns
+ # and adequately remove a superfluous slice of values.
+ # An analogous operation is done in value_batches_stream.
+ length = returns.shape[1]
+ values = values[:, :length, :]
+ selected_values = values[ind_0, ind_1, actions]
+ shapes.assert_same_shape(selected_values, returns)
+ shapes.assert_same_shape(selected_values, mask)
+ return jnp.sum(selected_values) / jnp.sum(mask)
+
+ return tl.Fn("ValueMean", f)
diff --git a/trax/learning/reinforcement/value_tasks.py b/trax/learning/reinforcement/value_tasks.py
new file mode 100644
index 000000000..aa6ea81bb
--- /dev/null
+++ b/trax/learning/reinforcement/value_tasks.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Value network training tasks."""
+
+import copy
+
+import numpy as np
+
+from trax import layers as tl
+from trax.fastmath import numpy as jnp
+from trax.learning.supervised import training
+
+
+class ValueTrainTask(training.TrainTask):
+ """Task for value training."""
+
+ def __init__(
+ self,
+ trajectory_batch_stream,
+ optimizer,
+ lr_schedule,
+ advantage_estimator,
+ model,
+ target_model=None,
+ target_scale=1.0,
+ sync_at=(lambda step: step % 100 == 0),
+ loss_layer=None,
+ head_selector=(),
+ ):
+ """Initializes ValueTrainTask.
+
+ Args:
+ trajectory_batch_stream: Generator of trax.reinforcement.task.TimeStepBatch.
+ optimizer: Optimizer for network training.
+ lr_schedule: Learning rate schedule for network training.
+ advantage_estimator: Function
+ (rewards, returns, values, dones) -> advantages, created by one of the
+ functions from trax.reinforcement.advantages.
+ model: Model being trained, used to synchronize weights of the target
+ model.
+ target_model: Model for calculating TD targets. If `None`, use `model`.
+ target_scale: Multiplier for the targets. Useful for rescaling the targets
+ to a more reasonable range for model training.
+ sync_at: Function step -> bool, indicating when to synchronize the target
+ network with the trained network. This is necessary for training the
+ network on bootstrapped targets, e.g. using TD-k.
+ loss_layer: The value loss layer. The default is L2 loss.
+ head_selector: Layer to apply to the network output to select the value
+ head. Only needed in multitask training.
+ """ + self._trajectory_batch_stream = trajectory_batch_stream + self._advantage_estimator = advantage_estimator + self._target_scale = target_scale + + self._synced = False + + def sync_also_on_initial_batch(step): + return sync_at(step) or not self._synced + + self._sync_at = sync_also_on_initial_batch + + self._head_selector = head_selector + + def attach_head(model): + return tl.Serial(model, self._head_selector) + + self._train_model = attach_head(model) + if target_model is None: + target_model = model + # TODO(pkozakowski): Use target_model.clone() once it's implemented. + self._target_model = attach_head(copy.deepcopy(target_model)) + + # Count the steps, so we know when to synchronize the target network. + self._step = 0 + + def labeled_data(): + for trajectory_batch in self._trajectory_batch_stream: + self._step += 1 + yield self.value_batch(trajectory_batch) + + sample_batch = self.value_batch(next(trajectory_batch_stream), shape_only=True) + if loss_layer is None: + loss_layer = tl.L2Loss() + loss_layer = tl.Serial(head_selector, loss_layer) + super().__init__( + labeled_data(), + loss_layer, + optimizer, + sample_batch=sample_batch, + lr_schedule=lr_schedule, + loss_name="value_loss", + ) + + @property + def trajectory_batch_stream(self): + return self._trajectory_batch_stream + + def _sync_target_model(self): + self._target_model.weights = self._train_model.weights + self._target_model.state = self._train_model.state + self._synced = True + + def value_batch(self, trajectory_batch, shape_only=False): + """Computes a value training batch based on a trajectory batch. + + Args: + trajectory_batch: trax.reinforcement.task.TimeStepBatch with a batch of trajectory + slices. Elements should have shape (batch_size, seq_len, ...). + shape_only: Whether to return dummy zero arrays of correct shape. Useful + for initializing models. + + Returns: + Triple (observations, targets, weights), where targets are the target + values for network training and weights are used for masking in loss + computation. Shapes: + - observations: (batch_size, seq_len) + observation_shape + - targets: (batch_size, seq_len) + - weights: (batch_size, seq_len) + """ + if self._sync_at(self._step) and not shape_only: + self._sync_target_model() + + (batch_size, seq_len) = trajectory_batch.observation.shape[:2] + assert trajectory_batch.action.shape[:2] == (batch_size, seq_len) + assert trajectory_batch.mask.shape == (batch_size, seq_len) + # Compute the value from the target network. + values = np.array(self.value(trajectory_batch, shape_only=shape_only)) + assert values.shape == (batch_size, seq_len) + # Compute the advantages - the TD errors of the target network. + advantages = self._advantage_estimator( + rewards=trajectory_batch.reward, + returns=trajectory_batch.return_, + dones=trajectory_batch.done, + values=values, + discount_mask=trajectory_batch.env_info.discount_mask, + ) + adv_seq_len = advantages.shape[1] + # The advantage sequence should be shorter by the margin. For more details, + # see the comment in policy_tasks.PolicyTrainTask.policy_batch. + assert adv_seq_len <= seq_len + assert advantages.shape == (batch_size, adv_seq_len) + # Compute the targets based on the target values and their TD errors. The + # network gives perfect predictions when targets == values, so the + # advantages are zero. + targets = (values[:, :adv_seq_len] + advantages) * self._target_scale + # Trim observations and the mask to match the target length. 
+ observations = trajectory_batch.observation[:, :adv_seq_len]
+ mask = trajectory_batch.mask[:, :adv_seq_len]
+ # Add a singleton depth dimension to the targets and the mask.
+ targets = targets[:, :, None]
+ mask = mask[:, :, None]
+ return (observations, targets, mask)
+
+ def value(self, trajectory_batch, shape_only=False):
+ """Computes values of states in a given batch of trajectory slices.
+
+ Can be passed as value_fn to PolicyTrainTask to implement a critic baseline
+ for advantage calculation.
+
+ Args:
+ trajectory_batch: Batch of trajectory slices to compute values for.
+ shape_only: Whether to return dummy zero arrays of correct shape. Useful
+ for initializing models.
+
+ Returns:
+ Array of values of all states in `trajectory_batch`.
+ """
+ if shape_only:
+ # The target model hasn't been initialized yet, and we are asked for the
+ # initial, sample batch. Only shape matters here, so just return zeros.
+ return np.zeros(trajectory_batch.observation.shape[:2])
+
+ if not self._synced:
+ self._sync_target_model()
+
+ values = self._target_model(trajectory_batch.observation)
+ # Squeeze the singleton depth axis.
+ return np.squeeze(values, axis=-1) / self._target_scale
+
+
+class ValueEvalTask(training.EvalTask):
+ """Task for value evaluation."""
+
+ def __init__(self, train_task, n_eval_batches=1, head_selector=()):
+ """Initializes ValueEvalTask.
+
+ Args:
+ train_task: ValueTrainTask used to train the value network.
+ n_eval_batches: Number of batches per evaluation.
+ head_selector: Layer to apply to the network output to select the value
+ head. Only needed in multitask training.
+ """
+ labeled_data = map(train_task.value_batch, train_task.trajectory_batch_stream)
+ metrics = [tl.L2Loss(), self.l1_loss]
+ metric_names = ["value_l2", "value_l1"]
+ # Select the appropriate head for evaluation.
+ metrics = [tl.Serial(head_selector, metric) for metric in metrics]
+ super().__init__(
+ labeled_data,
+ metrics,
+ sample_batch=train_task.sample_batch,
+ metric_names=metric_names,
+ n_eval_batches=n_eval_batches,
+ )
+
+ @property
+ def l1_loss(self):
+ def loss(values, targets, weights):
+ return jnp.sum(jnp.abs(values - targets) * weights) / jnp.sum(weights)
+
+ return tl.Fn("L1Loss", loss)
diff --git a/trax/tf_numpy/numpy_impl/__init__.py b/trax/learning/supervised/__init__.py
similarity index 93%
rename from trax/tf_numpy/numpy_impl/__init__.py
rename to trax/learning/supervised/__init__.py
index 4d1d79a2d..dcd4e3919 100644
--- a/trax/tf_numpy/numpy_impl/__init__.py
+++ b/trax/learning/supervised/__init__.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""NumPy API. Deprecated."""
+"""Supervised learning imports in Trax."""
diff --git a/trax/learning/supervised/callbacks.py b/trax/learning/supervised/callbacks.py
new file mode 100644
index 000000000..3fe8d03f0
--- /dev/null
+++ b/trax/learning/supervised/callbacks.py
@@ -0,0 +1,249 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loop callbacks.
+
+Callbacks can be used to customize the behavior of `supervised.training.Loop`
+to accommodate a variety of use-cases.
+
+Examples include:
+ - custom evaluation schemes
+ - logging metrics to external servers
+ - sending model checkpoints to external servers
+ - updating the target network in RL algorithms and other non-stationary
+ problems
+"""
+
+import collections
+import os
+
+import gin
+import numpy as np
+
+from trax import layers as tl
+from trax.learning.reinforcement import serialization_utils
+from trax.learning.supervised import decoding
+from trax.utils import jaxboard, shapes
+
+
+class TrainingStepCallback:
+ """Callback triggered before and after a training step."""
+
+ def __init__(self, loop):
+ """Initializes the callback with a `supervised.training.Loop` instance."""
+ self._loop = loop
+
+ def call_at(self, step):
+ """Returns whether the callback should be called at a given step."""
+ raise NotImplementedError
+
+ def on_step_begin(self, step):
+ """Called by Loop before training steps, when call_at returned True."""
+ raise NotImplementedError
+
+ def on_step_end(self, step):
+ """Called by Loop after training steps, when call_at returned True."""
+ raise NotImplementedError
+
+
+@gin.configurable
+class SerializedModelEvaluation(TrainingStepCallback):
+ """Evaluates serialized sequence prediction models.
+
+ Example: time series prediction. We can serialize a time series into
+ a sequence of discrete tokens and model this sequence using an autoregressive
+ sequence model, such as Transformer - see
+ `trax.reinforcement.serialization_utils.SerializedModel`. Then we can use this callback
+ to evaluate long-horizon predictions of such a model.
+ """
+
+ def __init__(
+ self,
+ loop,
+ model=None,
+ eval_at=1000,
+ eval_task=None,
+ context_lengths=(1,),
+ horizon_lengths=(1,),
+ n_steps=1,
+ accelerate_model=True,
+ ):
+ """Initializes SerializedModelEvaluation.
+
+ Args:
+ loop: Instance of `trax.supervised.training.Loop` or `None`. Can be set to
+ `None` for testing - in such a case, `model` and `eval_task` must be
+ provided.
+ model: Instance of `trax.reinforcement.serialization_utils.SerializedModel`. Not
+ required if `loop` is provided.
+ eval_at: When to evaluate. Either int (every how many steps to evaluate),
+ or a list of ints (step numbers), or a function int -> bool (step
+ predicate).
+ eval_task: Instance of `trax.supervised.training.EvalTask` with the
+ evaluation data, or None. If not provided, the task will be taken from
+ `loop`.
+ context_lengths: List of lengths of the context sequence fed into the
+ model before starting prediction.
+ horizon_lengths: List of lengths of the predicted sequence.
+ n_steps: Number of batches to run evaluation for.
+ accelerate_model: Whether to wrap the model in `tl.Accelerate`.
+ """ + super().__init__(loop) + + if model is None: + model = loop.model + + observation_serializer = model.observation_serializer + action_serializer = model.action_serializer + + predict_model = model.make_predict_model() + if accelerate_model: + predict_model = tl.Accelerate(predict_model) + self._predict_model = predict_model + self._obs_serializer = observation_serializer + self._act_serializer = action_serializer + + if isinstance(eval_at, int): + self._eval_at = lambda step: step % eval_at == 1 + elif hasattr(eval_at, "__in__"): + self._eval_at = lambda step: step in eval_at + elif callable(eval_at): + self._eval_at = eval_at + else: + raise TypeError(f"Unsupported type for eval_at: {type(eval_at)}.") + + if eval_task is None: + if len(loop.eval_tasks) != 1: + raise ValueError( + "If eval_task is not provided, the number of eval_tasks registered " + "in Loop must be exactly 1." + ) + eval_task = loop.eval_tasks[0] + self._eval_task = eval_task + + self._context_lengths = list(sorted(context_lengths)) + self._horizon_lengths = list(sorted(horizon_lengths)) + self._n_steps = n_steps + + self._batch_size = eval_task.sample_batch[0].shape[0] + (_, self._init_state) = predict_model.init( + shapes.ShapeDtype((self._batch_size, 1), dtype=np.int32) + ) + + @property + def predict_model(self): + return self._predict_model + + def call_at(self, step): + return self._eval_at(step) + + def on_step_begin(self, step): + pass + + def on_step_end(self, step): + summary_writer = jaxboard.SummaryWriter( + os.path.join(self._loop.output_dir, "srl_eval") + ) + try: + weights = self._loop.eval_model.seq_model_weights + metrics = self.evaluate(weights) + self._loop.log_summary(metrics, summary_writer, "", "srl_eval") + finally: + summary_writer.close() + + def evaluate(self, weights): + """Evaluates the model and returns the metrics.""" + self._predict_model.weights = weights + + metrics = collections.defaultdict(list) + for _ in range(self._n_steps): + batch = self._eval_task.next_batch() + step_metrics = self._evaluate_batch(batch) + for key, value in step_metrics.items(): + metrics[key].append(value) + + metrics = {k: np.array(v) for (k, v) in metrics.items()} + + def metric_name(context, horizon): + return f"pred_error/context_{context}/horizon_{horizon}" + + return { + metric_name(context, horizon): np.sum(errors) / (np.sum(errors != 0) + 1e-6) + for ((context, horizon), errors) in metrics.items() + } + + def _evaluate_batch(self, batch): + """Performs evaluation on a single batch.""" + (obs, act, _, mask) = batch + obs_repr = serialization_utils.Serialize(self._obs_serializer)(obs) + act_repr = serialization_utils.Serialize(self._act_serializer)(act) + + errors = {} + last_context = 0 + last_state = self._init_state + last_start_id = 0 + for context in self._context_lengths: + self._predict_model.state = last_state + start_id = last_start_id + + if context > last_context: + context_seq = serialization_utils.Interleave()( + ( + obs_repr[:, last_context:context], + act_repr[:, last_context:context], + ) + ) + consume_sequence(self._predict_model, start_id, context_seq[:, :-1]) + last_start_id = start_id = context_seq[:, -1:] + last_state = self._predict_model.state + last_context = context + + for timestep in range(max(self._horizon_lengths)): + pred_repr = decoding.autoregressive_sample( + self._predict_model, + start_id=start_id, + eos_id=-1, + batch_size=self._batch_size, + max_length=self._obs_serializer.representation_length, + accelerate=False, + ) + horizon = timestep + 1 + if horizon in 
+ pred = self._obs_serializer.deserialize(pred_repr)
+ error = self._calculate_error(pred, obs[:, context + timestep])
+ errors[context, horizon] = error * mask[:, context + timestep]
+
+ start_id = pred_repr[:, -1:]
+ consume_sequence(
+ self._predict_model, start_id, act_repr[:, context + timestep, :-1]
+ )
+ start_id = act_repr[:, context + timestep, -1:]
+
+ return errors
+
+ def _calculate_error(self, prediction, ground_truth):
+ return (prediction - ground_truth) ** 2
+
+
+def consume_sequence(model, start_id, sequence):
+ decoding.autoregressive_sample(
+ model,
+ start_id=start_id,
+ eos_id=-1,
+ inputs=sequence,
+ batch_size=sequence.shape[0],
+ max_length=1,
+ accelerate=False,
+ )
diff --git a/trax/learning/supervised/decoding.py b/trax/learning/supervised/decoding.py
new file mode 100644
index 000000000..be476a63f
--- /dev/null
+++ b/trax/learning/supervised/decoding.py
@@ -0,0 +1,299 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Decoding with Trax models."""
+
+import numpy as np
+
+from trax import fastmath
+from trax import layers as tl
+
+
+def autoregressive_sample_stream(
+ model,
+ inputs=None,
+ batch_size=1,
+ temperature=1.0,
+ start_id=0,
+ accelerate=True,
+ eval_mode=False,
+ eval_min_length=1,
+):
+ """Yields samples from `model`, in autoregressive language model fashion.
+
+ This function uses `model` to generate outputs one position at a time, with
+ access to inputs for the current position and all preceding positions. The
+ new output becomes the next position's input, and further calls to
+ `autoregressive_sample_stream` repeat the process for successive positions
+ indefinitely.
+
+ Inputs and outputs always come in batches, even if size 1. If `inputs` is
+ present, it must have shape (`batch_size`, inputs_sequence_length), and each
+ output in the stream has shape (`batch_size`, 1).
+
+ Args:
+ model: A layer object (subclass of `trax.layers.Layer`) created in
+ `'predict'` mode and initialized from trained weights. The model
+ must have a structure that allows it to run as an autoregressive
+ one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`),
+ except if `eval_mode` is set -- any model can be sampled then,
+ but the sampling process may be much slower.
+ inputs: Sequence of symbols the model sees as input the first time it
+ generates an output. If None, the model generates the first output
+ based on just the start symbol.
+ batch_size: Number of sequences to generate in parallel as a batch.
+ temperature: Parameter that controls the sharpness of the softmax that
+ feeds the sampling process. Values range from 0.0 (all probability mass
+ goes to one candidate; like an argmax) to positive infinity (all
+ candidates have equal probability).
+ start_id: Integer representing the start symbol for the autoregressive
+ process, or array of shape (`batch_size`, 1) of such integers.
+ accelerate: If True, create an accelerated version of `model` and use it
+ for generating outputs.
+ eval_mode: If True, assume the model is created in `eval` mode and sample
+ by collecting all previous outputs and passing the whole tensor.
+ eval_min_length: If set, the minimum length to pad to in eval mode.
+
+ Yields:
+ Tensor of integers with shape (`batch_size`, 1), representing the batch of
+ outputs for the next position in the stream.
+ """
+ if inputs is not None and inputs.shape[0] != batch_size:
+ raise ValueError(
+ f"Inputs batch size ({inputs.shape[0]}) does not match "
+ f"batch_size arg ({batch_size})."
+ )
+
+ fast_model = tl.Accelerate(model) if accelerate else model
+ if np.isscalar(start_id):
+ start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
+ else:
+ start_symbol = start_id
+ if model.n_in == 1 and inputs is not None:
+ current_symbols = np.concatenate([start_symbol, inputs], axis=1)
+ else:
+ current_symbols = start_symbol
+
+ if eval_mode:
+ # no start symbol needed in eval mode
+ current_symbols = current_symbols[:, 1:]
+
+ while True:
+ # Pad inputs to power-of-2 length if needed.
+ if eval_mode:
+ # one extra symbol, as an initial one will be added below
+ l = max(eval_min_length, current_symbols.shape[1] + 1)
+ pad_len = int(2 ** np.ceil(np.log2(l))) - current_symbols.shape[1]
+ unpadded_symbols = current_symbols
+ current_symbols = np.pad(
+ current_symbols, [[0, 0], [0, pad_len]], mode="constant"
+ )
+ last_index = -pad_len # no -1, as the starting symbol will be added
+ else:
+ last_index = -1
+ # Run the model.
+ if model.n_in > 1 and inputs is not None:
+ logits = fast_model((inputs, current_symbols))[0]
+ else:
+ logits = fast_model(current_symbols)
+ logits = tl.log_softmax(logits[:, last_index, :])
+ sample = tl.logsoftmax_sample(logits, temperature=temperature)
+ yield sample
+ if eval_mode:
+ current_symbols = np.concatenate(
+ [unpadded_symbols, sample[:, None]], axis=1
+ )
+ else:
+ # NOTE: Because the model is autoregressive and in 'predict' mode, its
+ # history is cached in the model state and the next input is the single
+ # symbol just sampled.
+ current_symbols = sample[:, None]
+
+
+def autoregressive_sample(
+ model,
+ inputs=None,
+ batch_size=1,
+ temperature=1.0,
+ start_id=0,
+ eos_id=1,
+ max_length=100,
+ accelerate=True,
+ eval_mode=False,
+ eval_min_length=1,
+):
+ """Returns a batch of sequences created by autoregressive sampling.
+
+ This function uses `model` to generate outputs one position at a time, with
+ access to inputs for the current position and all preceding positions. The
+ new output becomes the next position's input, and this loop repeats until
+ either the model outputs the `eos_id` value or the output sequence reaches
+ `max_length` items.
+
+ Args:
+ model: A layer object (subclass of `trax.layers.Layer`) created in
+ `'predict'` mode and initialized from trained weights. The model
+ must have a structure that allows it to run as autoregressive
+ one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`),
+ except if `eval_mode` is set -- any model can be sampled then,
+ but the sampling process may be much slower.
+ inputs: Sequence of symbols the model sees as input the first time it
+ generates an output. If None, the model must generate the first output
+ with no input to guide it.
+ batch_size: Number of sequences to generate in parallel as a batch.
+ temperature: Parameter that controls the sharpness of the softmax that
+ feeds the sampling process. Values range from 0.0 (all probability mass
+ goes to one candidate; like an argmax) to positive infinity (all
+ candidates have equal probability).
+ start_id: The start symbol (ID/integer) for the autoregressive process,
+ or array of shape (`batch_size`, 1) of such integers.
+ eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive
+ process.
+ max_length: Maximum length for generated sequences.
+ accelerate: If True, create an accelerated version of `model` and use it
+ for generating outputs.
+ eval_mode: If True, assume the model is created in `eval` mode and sample
+ by collecting all previous outputs and passing the whole tensor.
+ eval_min_length: If set, the minimum length to pad to in eval mode.
+
+ Returns:
+ Tensor of integers with shape (`batch_size`, output_length) representing
+ a batch of output sequences. output_length is the maximum length of the
+ output sequences, where each sequence can be no longer than `max_length`.
+ """
+ result = []
+ eos_seen = []
+ counter = 0
+ for sample in autoregressive_sample_stream(
+ model,
+ inputs,
+ batch_size=batch_size,
+ temperature=temperature,
+ start_id=start_id,
+ accelerate=accelerate,
+ eval_mode=eval_mode,
+ eval_min_length=eval_min_length,
+ ):
+ sample = sample[:, None]
+ result.append(sample)
+ counter += 1
+ if counter >= max_length:
+ return np.concatenate(result, axis=1)
+ # Check at which batch positions we have already encountered EOS.
+ for j in range(batch_size):
+ if int(sample[j, 0]) == eos_id:
+ eos_seen.append(j)
+ # If EOS has been seen on all positions, stop.
+ if all(j in eos_seen for j in range(batch_size)):
+ return np.concatenate(result, axis=1)
+ return np.concatenate(result, axis=1)
+
+
+def beam_search(
+ model,
+ inputs=None,
+ batch_size=1,
+ n_beams=2,
+ start_id=0,
+ eos_id=1,
+ max_length=100,
+ length_penalty=1.0,
+ accelerate=True,
+):
+ """Returns a batch of n_beams-sequences created by beam search.
+
+ This function uses `model` to generate outputs one position at a time, with
+ access to inputs for the current position and all preceding positions. The
+ new output becomes the next position's input, and this loop repeats until
+ either the model outputs the `eos_id` value or the output sequence reaches
+ `max_length` items -- but keeping n_beams top beams.
+
+ Args:
+ model: A layer object (subclass of `trax.layers.Layer`) created in
+ `'predict'` mode and initialized from trained weights. The model
+ must have a structure that allows it to run as autoregressive
+ one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`).
+ inputs: Sequence of symbols the model sees as input the first time it
+ generates an output. If None, the model must generate the first output
+ with no input to guide it.
+ batch_size: Number of sequences to generate in parallel as a batch.
+ n_beams: How many beams to consider at the same time.
+ start_id: The start symbol (ID/integer) for the autoregressive process,
+ or array of shape (`batch_size`, 1) of such integers.
+ eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive
+ process.
+ max_length: Maximum length for generated sequences.
+ length_penalty: Factor alpha in calculating the length penalty for beams.
+ accelerate: If True, create an accelerated version of `model` and use it
+ for generating outputs.
+
+ Returns:
+ Tensor of integers with shape (`batch_size`, n_beams, output_length) with
+ a batch of output sequences. output_length is the maximum length of the
+ output sequences, where each sequence can be no longer than `max_length`.
+ """
+ del eos_id, length_penalty # TODO(lukaszkaiser): add length penalty, eos
+ assert batch_size == 1, "Batch size > 1 not supported yet"
+ if inputs is not None and inputs.shape[0] != batch_size:
+ raise ValueError(
+ f"Inputs batch size ({inputs.shape[0]}) does not match "
+ f"batch_size arg ({batch_size})."
+ )
+
+ fast_model = tl.Accelerate(model) if accelerate else model
+ if np.isscalar(start_id):
+ start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
+ else:
+ start_symbol = start_id
+ if model.n_in == 1 and inputs is not None:
+ current_symbols = np.concatenate([start_symbol, inputs], axis=1)
+ else:
+ current_symbols = start_symbol
+
+ beams = [current_symbols for _ in range(n_beams)]
+ results = [([], 0.0) for _ in range(n_beams)]
+ states = [fast_model.state for _ in range(n_beams)]
+ top_k = [None] * n_beams
+ counter = 0
+ while counter < max_length:
+ counter += 1
+ # Run the model on all beams, collect states and top_k for each beam.
+ for beam_id in range(n_beams if counter > 1 else 1):
+ fast_model.state = states[beam_id]
+ if model.n_in > 1 and inputs is not None:
+ logits = fast_model((inputs, beams[beam_id]))[0]
+ else:
+ logits = fast_model(beams[beam_id])
+ logits = tl.log_softmax(logits[:, -1, :])
+ states[beam_id] = fast_model.state
+ top_k[beam_id] = fastmath.top_k(logits, k=n_beams)
+
+ # Select new beams.
+ cur_values = [] # will hold triples (sum-of-logprobs, beam-id, symbol)
+ for beam_id in range(n_beams if counter > 1 else 1):
+ for k in range(n_beams):
+ values, symbols = top_k[beam_id]
+ value, symbol = values[:, k], symbols[:, k]
+ cur_values.append((results[beam_id][1] + value, beam_id, symbol))
+ cur_values.sort(key=lambda x: -x[0][0]) # x[0][0] as batch_size=1
+ # Collect top beams to the new states and results.
+ new_results, new_states, new_beams = [], [], []
+ for value, beam_id, symbol in cur_values[:n_beams]:
+ new_results.append((results[beam_id][0] + [symbol], value))
+ new_states.append(states[beam_id]) # copy?
+ new_beams.append(symbol[:, None])
+ results, states, beams = new_results, new_states, new_beams
+
+ return [(np.stack(r, axis=-1), v) for (r, v) in results]
diff --git a/trax/learning/supervised/history.py b/trax/learning/supervised/history.py
new file mode 100644
index 000000000..5f97833ee
--- /dev/null
+++ b/trax/learning/supervised/history.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Trax history."""
+import collections
+import copy
+
+import six
+
+from absl import logging
+
+
+class History:
+ """History of metrics.
+
+ History contains the metrics recorded during training and evaluation.
+ Save data with history.append and get a sequence of data by calling
+ history.get.
+ + For example: + history.append('train', 'metrics/accuracy', 1, 0.04) + history.append('train', 'metrics/accuracy', 1000, 0.31) + history.get('train', 'metrics/accuracy') + # returns [(1, 0.04), (1000, 0.31)] + """ + + def __init__(self): + # Structure is + # values = { + # 'mode1': { + # 'metric1': [val1, val2], + # ... + # }, + # 'mode2': ... + # } + self._values = {} + + @classmethod + def from_dict(cls, json_object): + """Constructs a `History` from a Python dictionary of parameters.""" + history = History() + for key, value in six.iteritems(json_object): + history.__dict__[key] = value + return history + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def append(self, mode, metric, step, value): + """Append (step, value) pair to history for the given mode and metric.""" + if mode not in self._values: + self._values[mode] = collections.defaultdict(list) + self._values[mode][metric].append((step, value)) + + def get(self, mode, metric): + """Get the history for the given metric and mode.""" + if mode not in self._values: + logging.info("Metric %s not found for mode %s", metric, mode) + return [] + return list(self._values[mode][metric]) + + @property + def modes(self): + """Current tracked modes.""" + return sorted(list(self._values.keys())) + + def metrics_for_mode(self, mode): + """Metrics available for a given mode.""" + if mode not in self._values: + logging.info("Mode %s not found", mode) + return [] + return sorted(list(self._values[mode].keys())) + + def __str__(self): + return str(self._values) diff --git a/trax/learning/supervised/lr_schedules.py b/trax/learning/supervised/lr_schedules.py new file mode 100644 index 000000000..4838df7fa --- /dev/null +++ b/trax/learning/supervised/lr_schedules.py @@ -0,0 +1,233 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Learning rate (LR) schedules. + +In Trax a learning rate schedule is a function: +:math:`\text{step} \mapsto \text{learning_rate}`. +This module provides helpers for constructing such functions. For example:: + + constant(0.001) + +returns a function that always returns `0.001`. +""" + +import math + +import gin + +from trax.fastmath import numpy as jnp + + +@gin.configurable +def constant(value): + """Returns an LR schedule that is constant from time (step) 1 to infinity.""" + return _BodyAndTail(value, body_start=1) + + +@gin.configurable +def warmup(n_warmup_steps, max_value): + """Returns an LR schedule with linear warm-up followed by constant value. + + Args: + n_warmup_steps: Number of steps during which the learning rate rises on + a line connecting (0, 0) and (n_warmup_steps, max_value). + max_value: Value for learning rate after warm-up has finished. 
+ """ + return _BodyAndTail(max_value, body_start=n_warmup_steps + 1) + + +@gin.configurable +def warmup_and_rsqrt_decay(n_warmup_steps, max_value): + """Returns an LR schedule with warm-up + reciprocal square root decay.""" + return _BodyAndTail(max_value, tail_start=n_warmup_steps + 1, tail_fn=_rsqrt) + + +@gin.configurable +def multifactor( + factors="constant * linear_warmup * rsqrt_decay", + constant=0.1, # pylint: disable=redefined-outer-name + warmup_steps=400, + decay_factor=0.5, + steps_per_decay=20000, + steps_per_cycle=100000, + second_constant=0.01, + second_constant_step=10000, + minimum=0, +): + """Factor-based learning rate schedule. + + Interprets factors in the factors string which can consist of: + * constant: interpreted as the constant value, + * linear_warmup: interpreted as linear warmup until warmup_steps, + * rsqrt_decay: divide by square root of max(step, warmup_steps) + * decay_every: Every k steps decay the learning rate by decay_factor. + * cosine_deay: Cyclic cosine decay, uses steps_per_cycle parameter. + * two_constants: constant until second_constant_step, then switch to + second_constant. + + Args: + factors: a string with factors separated by '*' that defines the schedule. + constant: float, the starting constant for the learning rate schedule. + warmup_steps: how many steps to warm up for in the warmup schedule. + decay_factor: The amount to decay the learning rate by. + steps_per_decay: How often to decay the learning rate. + steps_per_cycle: Steps per cycle when using cosine decay. + second_constant: float, the second constant for the learning rate schedule. + second_constant_step: the step when the second_constant is triggered. + minimum: if the computed rate is below the minimum, then return the minimum. + + Returns: + a function learning_rate(step): float -> {'learning_rate': float}, the + step-dependent lr. + """ + factors = [n.strip() for n in factors.split("*")] + + def learning_rate(step): + """Step to learning rate function.""" + ret = 1.0 + for name in factors: + if name == "constant": + ret *= constant + elif name == "two_constants": + if step < second_constant_step: + ret *= constant + else: + ret *= second_constant + elif name == "linear_warmup": + ret *= jnp.minimum(1.0, step / warmup_steps) + elif name == "rsqrt_decay": + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == "rsqrt_normalized_decay": + ret *= jnp.sqrt(warmup_steps) + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == "decay_every": + ret *= decay_factor ** (step // steps_per_decay) + elif name == "cosine_decay": + progress = jnp.maximum( + 0.0, (step - warmup_steps) / float(steps_per_cycle) + ) + ret *= 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))) + else: + raise ValueError("Unknown factor %s." % name) + # TODO(henrykm): return float(jnp.max(minimum, ret)) would be + # better but causes TypeError: 'numpy.float64' object cannot + # be interpreted as an integer + if ret <= minimum: + return minimum + return ret + + return learning_rate + + +class _BodyAndTail: + """Defines a curve over time as a linear ramp + constant body + curvy tail. + + The body is a span of constant learning rate, and can be the entire curve. + The warm-up, if present, is based on the line connecting points (0, 0) and + (body_start, body_value). The tail, if defined, is a function from time to + learning rate that is used for all training steps from tail_start on. 
+ """ + + def __init__(self, body_value, body_start=None, tail_start=None, tail_fn=None): + """Specifies a body-and-tail time curve. + + Args: + body_value: Constant learning rate for the body of the curve (after + warm-up and before tail). Also is the reference (maximum) value for + calculating warm-up values and tail values. + body_start: Training step number at which the body starts. If None, takes + its value from tail_start, which amounts to there being no body. All + steps from 1 to body_start - 1 are computed using a linear warm-up. + tail_start: Training step number at which the tail starts. If None, the + body value remains until the end of training. + tail_fn: Function returning a floating point learning rate, given inputs: + - step_number (absolute step number from the start of training) + - tail_start (step number at which the tail starts) + - body_value (value relative to which the tail should be computed) + """ + if body_start is None and tail_start is None: + raise ValueError("Both body start and tail start are None.") + if tail_start is not None and tail_fn is None: + raise ValueError( + f"Tail start has value ({tail_start}) but tail_fn is None." + ) + if body_start is None: + body_start = tail_start if tail_start is not None else 1 + + self._body_value = body_value + self._body_start = body_start + self._tail_start = tail_start + self._tail_fn = tail_fn + + def __call__(self, step_number): + """Returns the learning rate for the given step number.""" + if step_number < self._body_start: + return (step_number / self._body_start) * self._body_value + elif self._tail_start is not None and step_number >= self._tail_start: + return self._tail_fn(step_number, self._tail_start, self._body_value) + else: + return self._body_value + + +def _rsqrt(step_number, tail_start, body_value): + """Computes a tail using a scaled reciprocal square root of step number. + + Args: + step_number: Absolute step number from the start of training. + tail_start: Step number at which the tail of the curve starts. + body_value: Value relative to which the tail should be computed. + + Returns: + A learning rate value that falls as the reciprocal square root of the step + number, scaled so that it joins smoothly with the body of a BodyAndTail + instance. + """ + return body_value * (math.sqrt(tail_start) / math.sqrt(step_number)) + + +class _CosineSawtoothTail: + """Cosine-sawtooth-shaped tail that simulates warm restarts. + + Creates a cyclic learning rate curve; each cycle is half of a cosine, falling + from maximum value to minimum value. For motivation and further details, see + Loshchilov & Hutter (2017) [https://arxiv.org/abs/1608.03983]. + """ + + def __init__(self, steps_per_cycle, min_value=1e-5): + """Configures the periodic behavior of this learning rate function. + + Args: + steps_per_cycle: Number of training steps per sawtooth cycle. The + learning rate will be highest at the start of each cycle, and lowest + at the end. + min_value: Minimum value, reached at the end of each cycle. + """ + self._steps_per_cycle = steps_per_cycle + self._min_value = min_value + + def __call__(self, step_number, tail_start, body_value): + """Returns the learning rate for the given step number, when in the tail. + + Args: + step_number: Absolute step number from the start of training. + tail_start: Step number at which the tail of the curve starts. + body_value: Value relative to which the tail should be computed. 
+ """ + max_value = body_value + min_value = self._min_value + position_in_cycle = ((step_number - tail_start) / self._steps_per_cycle) % 1.0 + theta = math.pi * position_in_cycle + return min_value + (max_value - min_value) * 0.5 * (1 + math.cos(theta)) diff --git a/trax/learning/supervised/pretrain_finetune.py b/trax/learning/supervised/pretrain_finetune.py new file mode 100644 index 000000000..890736b92 --- /dev/null +++ b/trax/learning/supervised/pretrain_finetune.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""data processing for BERT. + +For now, this file only supports fine-tuning bert-base-uncased on GLUE. + +TODO(afrozm): Move this into data/ +""" +import functools + +import gin +import numpy as onp +import tensorflow_datasets as tfds + +from trax.data.preprocessing.inputs import Inputs + + +def _tfds_stream( + n_devices, + dataset_name, + split, + batch_size, + data_dir, + shuffle_files, + shuffle_buffer_size, + batch_shuffle_size, + preprocess_fun, + repeat=True, +): + """Streams batches of examples from tfds, with pure-python preprocessing.""" + # TODO(piotrekp1): delete if switched to data_streams + if batch_size % n_devices != 0: + raise ValueError( + f"Batch size ({batch_size}) not divisible" + " by number of devices ({n_devices})" + ) + ds = tfds.load( + name=dataset_name, split=split, data_dir=data_dir, shuffle_files=shuffle_files + ) + if repeat: + ds = ds.repeat() + if shuffle_buffer_size is not None: + ds = ds.shuffle(shuffle_buffer_size) + ds = ds.batch(batch_size) + if batch_shuffle_size is not None: + ds = ds.shuffle(batch_shuffle_size) + + for batch in tfds.as_numpy(ds): + if preprocess_fun is not None: + yield preprocess_fun(batch) + else: + yield batch + + +@gin.configurable +def tfds_inputs( + dataset_name, + preprocess_fun, + batch_size, + eval_batch_size=None, + data_dir=None, + train_split=tfds.Split.TRAIN, + eval_split=tfds.Split.VALIDATION, + shuffle_buffer_size=1024, + batch_shuffle_size=128, +): + """Tensorflow Datasets input pipeline, with pure-python preprocessing.""" + if eval_batch_size is None: + eval_batch_size = batch_size + return Inputs( + train_stream=functools.partial( + _tfds_stream, + dataset_name=dataset_name, + split=train_split, + batch_size=batch_size, + data_dir=data_dir, + shuffle_files=True, + shuffle_buffer_size=shuffle_buffer_size, + batch_shuffle_size=batch_shuffle_size, + preprocess_fun=preprocess_fun, + ), + eval_stream=functools.partial( + _tfds_stream, + dataset_name=dataset_name, + split=eval_split, + batch_size=eval_batch_size, + data_dir=data_dir, + shuffle_files=False, + shuffle_buffer_size=None, + batch_shuffle_size=None, + preprocess_fun=preprocess_fun, + ), + ) + + +@gin.configurable +def bert_tokenizer(vocab_path=None): + """Constructs a BERT tokenizer.""" + # This import is from https://github.com/google-research/bert which is not + # listed as a dependency in trax. 
+ # TODO(piotrekp1): using SubwordTextEncoder instead after fixing the
+ # differences
+ from bert.tokenization.bert_tokenization import (
+ FullTokenizer, # pylint: disable=g-import-not-at-top
+ )
+
+ if vocab_path is None:
+ raise ValueError("vocab_path is required to construct the BERT tokenizer.")
+ tokenizer = FullTokenizer(vocab_path, do_lower_case=True)
+ return tokenizer
+
+
+def bert_preprocess(batch, tokenizer, key_a, key_b=None, max_len=128):
+ """Tokenize and convert text to model inputs in a BERT format."""
+ batch_size = batch["idx"].shape[0]
+ input_ids = onp.zeros((batch_size, max_len), dtype=onp.int32)
+ type_ids = onp.zeros((batch_size, max_len), dtype=onp.int32)
+ for i in range(batch_size):
+ sentence_a = batch[key_a][i]
+ tokens_a = (
+ [101]
+ + tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence_a))
+ + [102]
+ )
+
+ if key_b is not None:
+ sentence_b = batch[key_b][i]
+ tokens_b = tokenizer.convert_tokens_to_ids(
+ tokenizer.tokenize(sentence_b)
+ ) + [102]
+ else:
+ tokens_b = []
+
+ ex_input_ids = (tokens_a + tokens_b)[:max_len]
+ ex_type_ids = ([0] * len(tokens_a) + [1] * len(tokens_b))[:max_len]
+
+ input_ids[i, : len(ex_input_ids)] = ex_input_ids
+ type_ids[i, : len(ex_type_ids)] = ex_type_ids
+ return input_ids, type_ids, input_ids > 0, batch["label"], onp.ones(batch_size)
+
+
+@gin.configurable
+def glue_inputs(
+ dataset_name=gin.REQUIRED,
+ batch_size=16,
+ eval_batch_size=None,
+ data_dir=None,
+ max_len=128,
+ tokenizer=bert_tokenizer,
+):
+ """Input pipeline for fine-tuning BERT on GLUE tasks."""
+ if callable(tokenizer): # If we pass a function, e.g., through gin, call it.
+ tokenizer = tokenizer()
+
+ eval_split = tfds.Split.VALIDATION
+ if dataset_name == "glue/mnli":
+ eval_split = "validation_matched"
+ # TODO(kitaev): Support diagnostic dataset (AX)
+
+ keys_lookup = {
+ "glue/cola": ("sentence", None),
+ "glue/sst2": ("sentence", None),
+ "glue/mrpc": ("sentence1", "sentence2"),
+ "glue/qqp": ("question1", "question2"),
+ "glue/stsb": ("sentence1", "sentence2"),
+ "glue/mnli": ("premise", "hypothesis"), # TODO(kitaev): swap the two?
+ "glue/qnli": ("question", "sentence"), # TODO(kitaev) swap the two?
+ "glue/rte": ("sentence1", "sentence2"),
+ "glue/wnli": ("sentence1", "sentence2"),
+ }
+
+ key_a, key_b = keys_lookup[dataset_name]
+
+ preprocess_fn = functools.partial(
+ bert_preprocess, tokenizer=tokenizer, key_a=key_a, key_b=key_b, max_len=max_len
+ )
+ return tfds_inputs( # TODO(piotrekp1): use data_streams instead
+ dataset_name=dataset_name,
+ preprocess_fun=preprocess_fn,
+ batch_size=batch_size,
+ eval_batch_size=eval_batch_size,
+ data_dir=data_dir,
+ train_split=tfds.Split.TRAIN,
+ eval_split=eval_split,
+ )
+
+
+# TODO(piotrekp1): add glue evaluation
diff --git a/trax/learning/supervised/trainer_lib.py b/trax/learning/supervised/trainer_lib.py
new file mode 100644
index 000000000..1de42b551
--- /dev/null
+++ b/trax/learning/supervised/trainer_lib.py
@@ -0,0 +1,1030 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Original API for supervised learning/training in Trax.
+
+Trax authors expect that the `supervised.training` module (under development)
+will replace `trainer_lib`.
+"""
+
+import collections
+import functools
+import itertools
+import os
+import sys
+import time
+
+import gin
+import jax
+import tensorflow.compat.v2 as tf
+
+from absl import logging
+
+from trax import fastmath
+from trax import layers as tl
+from trax import optimizers as trax_opt
+from trax.data.preprocessing import inputs as trax_inputs
+from trax.fastmath import numpy as np
+from trax.fastmath import random as jax_random
+from trax.layers import base
+from trax.learning.supervised import history as trax_history
+from trax.learning.supervised import lr_schedules as lr
+from trax.learning.supervised import training
+from trax.utils import jaxboard
+from trax.utils.shapes import ShapeDtype
+
+# TODO(afrozm): Maybe flatten everything from OptState into TrainerState.
+TrainerState = collections.namedtuple(
+    "_TrainerState",
+    [
+        "step",  # Current training step number.
+        "opt_state",  # OptState.
+        "history",  # trax.history.History.
+        "model_state",  # Auxiliary state of the model.
+    ],
+)
+
+
+OptState = collections.namedtuple(
+    "_OptState",
+    [
+        "weights",  # Model weights.
+        "slots",  # Per-parameter optimizer state, e.g. gradient moments.
+        "opt_params",  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
+    ],
+)
+
+
+_DEFAULT_METRICS = {
+    "loss": tl.WeightedCategoryCrossEntropy(),
+    "accuracy": tl.WeightedCategoryAccuracy(),
+    "sequence_accuracy": tl.MaskedSequenceAccuracy(),
+    "neg_log_perplexity": tl.Serial(tl.WeightedCategoryCrossEntropy(), tl.Negate()),
+    "weights_per_batch_per_core": tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
+}
+
+
+NamedStream = collections.namedtuple("NamedStream", ["name", "stream"])
+
+
+@gin.configurable
+def named_stream(name=gin.REQUIRED, stream=gin.REQUIRED):
+    return NamedStream(name=name, stream=stream)
+
+
+class Trainer:
+    """Trax trainer.
+
+    A trainer runs training steps, trains for full epochs, saves the training
+    state, and gives access to evaluation data.
+    """
+
+    def __init__(
+        self,
+        model,
+        loss_fn,
+        optimizer,
+        lr_schedule,
+        inputs,
+        output_dir=None,
+        random_seed=None,
+        n_devices=None,
+        checkpoints_at=None,
+        should_save_checkpoints=True,
+        should_write_summaries=True,
+        metrics=None,
+        checkpoint_highest=None,
+        checkpoint_lowest=None,
+        init_checkpoint=None,
+    ):
+        self._is_chief, _, self._n_devices, rng = training.init_host_and_devices(
+            n_devices, random_seed
+        )
+        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
+        self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
+        self._should_write_summaries = should_write_summaries
+        if not output_dir:
+            self._should_save_checkpoints = False
+            self._should_write_summaries = False
+        self._checkpoint_highest = checkpoint_highest
+        self._checkpoint_lowest = checkpoint_lowest
+        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
+        # Inputs is either an Inputs instance or a function that returns it.
+        self._inputs = inputs
+        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
+            self._inputs = inputs()
+        # Initialize the learning rate to a dummy value. It will be set in reset().
+        opt = optimizer(learning_rate=0.0)
+
+        # Setup the model.
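+        # `model` is a callable building a fresh layer for the given mode; the
+        # "train" and "eval" variants have the same architecture but may behave
+        # differently (e.g., dropout is typically active only in "train" mode).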
+ model_train = model(mode="train") + model_predict_eval = model(mode="eval") + # Should work for fine-tuning of T5. + if init_checkpoint: + model_train.init_from_file(init_checkpoint, weights_only=True) + model_predict_eval.init_from_file(init_checkpoint, weights_only=True) + self._model_with_loss = tl.Serial(model_train, loss_fn) + + # Setup state. + rng, init_rng = jax_random.split(rng) + self._rngs = np.stack(jax_random.split(rng, self._n_devices)) + shapes, dtypes = self._inputs.example_shape_dtype + input_signature = tuple(ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes)) + + def new_opt_state_and_model_state(rng): + """Returns optimizer and model states suitable for training a model.""" + weights, state = self._model_with_loss.init(input_signature, rng=rng) + (slots, opt_params) = opt.tree_init(weights) + return (OptState(weights, slots, opt_params), state) + + if fastmath.is_backend(fastmath.Backend.JAX): + # JIT parameter initialization to avoid memory fragmentation + new_opt_state_and_model_state = fastmath.jit(new_opt_state_and_model_state) + self._new_opt_state_and_model_state = lambda: new_opt_state_and_model_state( + init_rng + ) + + # Arrange and initialize metrics layers. + self._metrics = list(sorted(self._metrics_dict.keys())) + metrics_layers = [self._metrics_dict[m] for m in self._metrics] + metrics_in_parallel = tl.Branch(*metrics_layers) + metrics_in_parallel.rng = init_rng + example_signature = tuple( + ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype) + ) + model_predict_eval.init(example_signature) + self._input_signature = example_signature + output_signature = model_predict_eval.output_signature(example_signature) + m_weights, m_state = metrics_in_parallel.init(output_signature) + self._metrics_weights = self._for_n_devices(m_weights) + self._metrics_state = self._for_n_devices(m_state) + + # Jit model_predict and update so they're fast. + self._jit_eval = _jit_predict_fn( + model_predict_eval, metrics_in_parallel, self._n_devices + ) + self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, self._n_devices) + + self._model_train = model_train + self._model_predict_eval = model_predict_eval + self._loss_fn = loss_fn + self._lr_schedule = lr_schedule + + # Those fields will be set in reset(). + self._output_dir = None + self._train_sw = None + self._eval_sw = None + self._history = None + self._opt_state = None + self._step = None + self._model_state = None + self.reset(output_dir) + + @property + def n_devices(self): + return self._n_devices + + @property + def step(self): + return self._step + + @property + def model_weights(self): + # Currently we need to pick [0] as we ignore loss weights (empty). + weights = self._opt_state.weights[0] + if self.n_devices > 1: + unreplicate = lambda x: x[0] + weights = fastmath.nested_map(unreplicate, weights) + return weights + + @model_weights.setter + def model_weights(self, weights): + new_model_weights = self._for_n_devices(weights) + if isinstance(self._opt_state.weights, list): + self._opt_state.weights[0] = new_model_weights + else: # weights are a tuple, need to re-create + new_weights = [new_model_weights] + list(self._opt_state.weights[1:]) + self._opt_state = self._opt_state._replace(weights=new_weights) + + @property + def model_state(self): + # Currently we need to pick [0] as we ignore loss state (empty). 
+        state = self._model_state[0]
+        if self.n_devices > 1:
+            unreplicate = lambda x: x[0]
+            state = fastmath.nested_map(unreplicate, state)
+        return state
+
+    @model_state.setter
+    def model_state(self, state):
+        new_model_state = self._for_n_devices(state)
+        if isinstance(self._model_state, list):
+            self._model_state[0] = new_model_state
+        else:  # state is a tuple, need to re-create
+            self._model_state = [new_model_state] + list(self._model_state[1:])
+
+    @property
+    def state(self):
+        return TrainerState(
+            opt_state=self._opt_state,
+            step=self._step,
+            history=self._history,
+            model_state=self._model_state,
+        )
+
+    @property
+    def learning_rate(self):
+        with fastmath.use_backend(fastmath.Backend.NUMPY):
+            return self._lr_schedule(self._step)
+
+    def reset(self, output_dir, init_checkpoint=None):
+        """Resets the model parameters.
+
+        Restores the parameters from the given output_dir if a checkpoint exists,
+        otherwise randomly initializes them.
+
+        Does not re-jit the model.
+
+        Args:
+            output_dir: Output directory.
+            init_checkpoint: Initial checkpoint (default: $output_dir/model.pkl.gz)
+        """
+        self.close()
+        self._output_dir = output_dir
+        if output_dir is not None:
+            tf.io.gfile.makedirs(output_dir)
+        else:
+            assert not self._should_save_checkpoints
+            assert not self._should_write_summaries
+
+        # Create summary writers and history.
+        if self._should_write_summaries:
+            self._train_sw = jaxboard.SummaryWriter(
+                os.path.join(output_dir, "train"), enable=self._is_chief
+            )
+            self._eval_sw = jaxboard.SummaryWriter(
+                os.path.join(output_dir, "eval"), enable=self._is_chief
+            )
+
+        # Reset the train and eval streams.
+        self._train_stream = _repeat_stream(self._inputs.train_stream, self._n_devices)
+        # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
+        # set by adding a padding and stopping the stream when too large.
+        self._eval_stream = _repeat_stream(self._inputs.eval_stream, self._n_devices)
+        self._train_eval_stream = _repeat_stream(
+            self._inputs.train_eval_stream, self._n_devices
+        )
+
+        # Restore the training state.
+        if output_dir is not None:
+            state = load_trainer_state(
+                output_dir, self._model_with_loss, init_checkpoint
+            )
+        else:
+            state = TrainerState(
+                step=None,
+                opt_state=None,
+                history=trax_history.History(),
+                model_state=None,
+            )
+        self._step = state.step or 0
+        history = state.history
+        self._history = history
+        if state.opt_state:
+            opt_state = state.opt_state
+            model_state = state.model_state
+        else:
+            opt_state, model_state = self._new_opt_state_and_model_state()
+            model_state = self._for_n_devices(model_state)
+        self._opt_state = OptState(*self._for_n_devices(opt_state))
+        self._model_state = model_state
+        if not state.opt_state and self._should_save_checkpoints:
+            self.save_state(keep=False)
+
+    def train_epoch(self, n_steps, n_eval_steps):
+        """Runs `n_steps` of training, with periodic logging, saving, and evals."""
+        # TODO(jonni): Clarify how this method relates to the stricter notion of
+        # epoch (training for as many steps as needed for a full pass through the
+        # training data).
+        print()  # Add visual separator in logs for start of training epoch.
+        start_time = time.time()
+
+        for _ in range(n_steps):
+            batch = next(self._train_stream)
+            if self.n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
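+                # Split the leading batch axis into (n_devices,
+                # batch_size // n_devices, ...) so pmap can send one shard to
+                # each accelerator.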
+ batch = _reshape_by_device(batch, self.n_devices) + self.train_step(batch) + if self._should_save_now(): + self.save_state(keep=True) + if self._should_log_now(): + self._train_sw.scalar("training/learning_rate", self.learning_rate) + + # At end of n_steps, do bookkeeping, run evals, and save state. + elapsed_time = time.time() - start_time + self.log_step("Ran %d train steps in %0.2f secs" % (n_steps, elapsed_time)) + if self._train_sw and n_steps > 1: + self._train_sw.scalar( + "training/steps per second", n_steps / elapsed_time, step=self._step + ) + self._train_sw.flush() + self.evaluate(n_eval_steps) + if self._eval_sw: + self._eval_sw.flush() + if self._should_save_checkpoints: + self.save_state(keep=False) + if self._should_save_checkpoints and self._current_step_is_best(high=True): + self.save_state(keep=False, prefix="highest_" + self._checkpoint_highest) + if self._should_save_checkpoints and self._current_step_is_best(high=False): + self.save_state(keep=False, prefix="lowest_" + self._checkpoint_lowest) + + def train_step(self, batch): + """Run one training step and update self._opt_state.""" + # Calculate the current optimizer parameters. + opt_param_updates = self._for_n_devices( + {"learning_rate": np.array(self.learning_rate)} + ) + opt_state = self._opt_state + opt_state.opt_params.update(opt_param_updates) + + # Run the update. + weights, slots, opt_params = opt_state + (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn( + (weights, slots), + self._step, + opt_params, + batch, + self._model_state, + self._rngs, + ) + self._opt_state = opt_state._replace(weights=weights, slots=slots) + if self._should_log_now(): + for name, value in stat.items(): + # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here + # with a device put array error complaining that it should be an array. + # On multiple devices, take the mean. + scalar_value = np.mean(np.array(value)) + self._train_sw.scalar("training/" + name, scalar_value, step=self._step) + self._step += 1 + + def evaluate(self, n_eval_steps): + """Evaluate the model and log metrics.""" + _, rng = jax_random.split(self._rngs[0]) + # TODO(lukaszkaiser): both model state and parameters by default include + # the loss layer. Currently, we access the pure-model parameters by just + # indexing, [0] here. But we should make it more explicit in a better API. + weights = (self._opt_state.weights[0], self._metrics_weights) + state = (self._model_state[0], self._metrics_state) + self.log_step("Evaluation") + train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps) + train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state, rng) + self.log_metrics(train_metrics, self._train_sw, "train") + eval_slice = itertools.islice(self._eval_stream, n_eval_steps) + eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng) + self.log_metrics(eval_metrics, self._eval_sw, "eval") + self.log_step("Finished evaluation") + + # Save the learning rate in history. + self._history.append( + "train", "training/learning_rate", self._step, self.learning_rate + ) + + def evaluation_round(self, inputs_stream, weights, state, rng): + """Evaluate. + + Args: + inputs_stream: Iterable of inputs to evaluate on. + weights: Weights for each f in eval_fns. + state: State for each f in eval_fns. + rng: Single-use random number generator (JAX PRNG key). + + Returns: + Tuple of `(metrics, state)`. 
`metrics` is a dict from metric name to
+            metric value averaged over the number of inputs, and `state` is the end
+            state returned by this trainer's `predict_fn`.
+        """
+        metrics = collections.defaultdict(float)
+        count = 0
+        for inp in inputs_stream:
+            count += 1
+            rng, subrng = jax_random.split(rng)
+            metric_values, _ = self._jit_eval(inp, weights, state, subrng)
+            try:
+                metric_values = list(metric_values)
+            except (TypeError, IndexError):
+                metric_values = [float(metric_values)]
+            for m, v in zip(self._metrics, metric_values):
+                metrics[m] += v
+        return {m: v / count for (m, v) in metrics.items()}, state
+
+    def save_gin(self):
+        """Saves the operative gin config, only if it is the chief."""
+        if not self._is_chief:
+            return
+        assert self._output_dir is not None
+        config_path = os.path.join(self._output_dir, "config.gin")
+        config_str = gin.operative_config_str()
+        with tf.io.gfile.GFile(config_path, "w") as f:
+            f.write(config_str)
+        sw = self._train_sw
+        if sw:
+            sw.text("gin_config", jaxboard.markdownify_operative_config_str(config_str))
+
+    def _save_state_dict(self, trainer_state_dict, weights_file):
+        training.pickle_to_file(trainer_state_dict, weights_file, gzip=True)
+        log("Model saved to %s" % weights_file, stdout=False)
+
+    def save_state(self, keep, prefix="model"):
+        """Save trainer state given a possibly replicated opt_state."""
+        opt_state = self._opt_state
+        if self.n_devices > 1:
+            first_replica = lambda x: x[0]
+            opt_state = OptState(*fastmath.nested_map(first_replica, opt_state))
+        # This line, while optional, allows JAX to transfer arrays from the device
+        # to the host in parallel, which is particularly important for cloud TPU.
+        if fastmath.is_backend(fastmath.Backend.JAX):
+            opt_state = jax.device_get(opt_state)
+        step, history, model_state = self._step, self._history, self._model_state
+        output_dir = self._output_dir
+
+        weights_file = os.path.join(output_dir, prefix + ".pkl.gz")
+
+        # This dict will be stored as the model.
+        trainer_state_dict = make_trainer_state_dict(
+            step, opt_state, history, model_state, self._input_signature
+        )
+        self._save_state_dict(trainer_state_dict, weights_file)
+
+        if keep:
+            weights_file = os.path.join(output_dir, "{}_{}.pkl.gz".format(prefix, step))
+            self._save_state_dict(trainer_state_dict, weights_file)
+
+    def save_computation_graphs(self):
+        """Dump computation graphs to files."""
+        if self.n_devices != 1:
+            return  # TODO(lukaszkaiser): make this work with more devices.
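+        # Below we lower (without executing) the eval-mode forward pass and dump
+        # its HLO representation both as text and as a dot graph.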
+ batch = next(self._train_stream) + output_dir = self._output_dir + if self.n_devices > 1: + batch = _reshape_by_device(batch, self.n_devices) + weights = self._opt_state.weights[0] + forward_computation = ( + jax.jit(self._model_predict_eval) + .lower( + batch, weights=weights, state=self._model_state[0], rng=self._rngs[0] + ) + .compiler_ir(dialect="hlo") + ) + with tf.io.gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f: + f.write(forward_computation.as_hlo_text()) + with tf.io.gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f: + f.write(forward_computation.as_hlo_dot_graph()) + + def log_step(self, step_message): + log("Step % 6d: %s" % (self.step, step_message)) + + def log_metrics(self, metrics, summ_writer, log_prefix): + """Log metrics to summary writer and history.""" + history = self._history + rjust_len = max([0] + [len(name) for name in metrics]) + for name, value in metrics.items(): + self.log_step( + "%s %s | % .8f" % (log_prefix.ljust(5), name.rjust(rjust_len), value) + ) + full_name = "metrics/" + name + if history: + history.append(log_prefix, full_name, self.step, value) + if summ_writer: + summ_writer.scalar(full_name, value, self.step) + + def print_n_weights(self): + """Prints the total count of trainable weights.""" + opt_state = self._opt_state + sizes = _sizes(opt_state.weights) + if self.n_devices > 1: + unreplicate = lambda x: x[0] + single_weights = fastmath.nested_map(unreplicate, opt_state.weights) + sizes = _sizes(single_weights) + total_size = _nested_reduce(sum, sizes) + self.log_step("Total number of trainable weights: %d" % total_size) + + def _should_save_now(self): + return self._should_save_checkpoints and self._step in self._checkpoints_at + + def _current_step_is_best(self, high): + """Is the current step the best (highest if high, else lowest).""" + metric = self._checkpoint_highest if high else self._checkpoint_lowest + if metric is None: + return False + # History is a list of pairs (step, value). + history = self._history.get("eval", "metrics/" + metric) + sequence = [float(i[1]) for i in history] # Just the values. + best = max(sequence) if high else min(sequence) # Best value. + last_is_best = float(history[-1][1]) == best # Is last the best? + cur_step = history[-1][0] == self._step # Is last the current step? 
+        return cur_step and last_is_best
+
+    def _should_log_now(self):
+        return self._train_sw is not None and (self._step == 1 or self._step % 10 == 0)
+
+    def _for_n_devices(self, x):
+        """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`."""
+        return tl.for_n_devices(x, self.n_devices)  # pylint: disable=protected-access
+
+    def close(self):
+        if self._train_sw is not None:
+            self._train_sw.close()
+            self._train_sw = None
+        if self._eval_sw is not None:
+            self._eval_sw.close()
+            self._eval_sw = None
+
+
+@gin.configurable(denylist=["output_dir"])
+def train(
+    output_dir,
+    model=gin.REQUIRED,
+    loss_fn=tl.WeightedCategoryCrossEntropy(),
+    inputs=trax_inputs.batcher,
+    optimizer=trax_opt.Adafactor,
+    lr_schedule_fn=lr.multifactor,
+    trainer_class=Trainer,
+    steps=1000,
+    checkpoints_at=None,
+    permanent_checkpoints_at=None,
+    eval_steps=10,
+    eval_frequency=100,
+    permanent_checkpoint_frequency=None,
+    random_seed=None,
+    save_graphs=True,
+    metrics=None,
+    checkpoint_highest=None,
+    checkpoint_lowest=None,
+    use_loop=True,
+    loss_chunk_size=0,
+    use_memory_efficient_trainer=False,
+    adasum=False,
+    init_checkpoint=None,
+    callbacks=None,
+    n_weights_shards=1,
+    additional_train_tasks=None,
+    additional_eval_tasks=None,
+    additional_eval_streams=None,
+):
+    """Train the model on the inputs.
+
+    Args:
+        output_dir: Directory where to put the logs and checkpoints.
+        model: The model to train as a callable returning 2 callables, an init_fn
+            and apply_fn.
+        loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
+            rng -> loss.
+        inputs: callable returning trax.inputs.Inputs.
+        optimizer: The optimizer (see optimizers/base.py for signature).
+        lr_schedule_fn: A learning rate schedule function, that when called returns
+            a function from step to learning rate (a float).
+        trainer_class: The trainer class to use.
+        steps: int, total number of training steps.
+        checkpoints_at: list of integers. Save a checkpoint for each training step
+            in the list.
+        permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
+            each training step in the list.
+        eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
+        eval_frequency: int, how often to run evaluation (every eval_frequency
+            steps). If None or 0, eval disabled.
+        permanent_checkpoint_frequency: int, how often to save permanent checkpoints
+            (every permanent_checkpoint_frequency steps).
+        random_seed: the random seed to use; time/os dependent if None (default).
+        save_graphs: bool, if True, save computation graph to file.
+        metrics: optionally override the default metrics dictionary.
+        checkpoint_highest: save a separate checkpoint whenever this metric reaches
+            a new highest value.
+        checkpoint_lowest: save a separate checkpoint whenever this metric reaches
+            a new lowest value.
+        use_loop: whether to use training.Loop instead of Trainer.
+        loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
+        use_memory_efficient_trainer: whether to use the memory-efficient trainer.
+        adasum: if True, use adaptive summation for multi-device gradients.
+        init_checkpoint: a checkpoint for fine tuning.
+        callbacks: a list of callbacks to call during training.
+        n_weights_shards: shard weights into this many devices.
+        additional_train_tasks: additional tasks which should be performed during
+            training.
+        additional_eval_tasks: additional tasks which should be performed during
+            evaluation.
+        additional_eval_streams: List[NamedStream], additional data streams that
+            should be used during evaluation.
Can be provided independently of + additional_eval_tasks. + + Returns: + trax.TrainerState or training.Loop if use_loop is True + """ + base.N_WEIGHTS_SHARDS = n_weights_shards + if ( + permanent_checkpoint_frequency is not None + and permanent_checkpoints_at is not None + ): + raise ValueError( + 'Only one of ["permanent_checkpoint_frequency", ' + '"permanent_checkpoints_at"] should be set.' + ) + if use_loop: + n_devices = num_devices() or fastmath.local_device_count() + + # Prepare the training task. + # Inputs is either an Inputs instance or a function that returns it. + if callable(inputs): # If we pass a function, e.g., through gin, call it. + inputs = inputs() + opt = optimizer if use_memory_efficient_trainer else optimizer() + train_task = training.TrainTask( + inputs.train_stream(n_devices), + loss_layer=loss_fn, + optimizer=opt, + lr_schedule=lr_schedule_fn(), + n_steps_per_checkpoint=eval_frequency, + n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency, + ) + + if additional_train_tasks is None: + additional_train_tasks = [] + + # Prepare the evaluation. + metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS + names, metrics = zip(*metrics_dict.items()) + eval_task = training.EvalTask( + inputs.eval_stream(n_devices), + metrics, + metric_names=names, + n_eval_batches=eval_steps, + ) + + if additional_eval_tasks is None: + additional_eval_tasks = [] + + additional_eval_tasks_from_streams = [] + if additional_eval_streams is not None: + for stream in additional_eval_streams: + additional_eval_tasks_from_streams.append( + training.EvalTask( + stream.stream, + metrics, + metric_names=names, + n_eval_batches=eval_steps, + export_prefix=stream.name, + ) + ) + + # Prepare the training loop. + checkpoint_at = None + if checkpoints_at is not None: + checkpoint_at = lambda step: step in checkpoints_at + permanent_checkpoint_at = None + if permanent_checkpoints_at is not None: + permanent_checkpoint_at = lambda step: step in permanent_checkpoints_at + + # Setup the model. + model_train = model(mode="train") + model_predict_eval = model(mode="eval") + if init_checkpoint: + model_train.init_from_file(init_checkpoint, weights_only=True) + model_predict_eval.init_from_file(init_checkpoint, weights_only=True) + loop = training.Loop( + model_train, + [train_task] + additional_train_tasks, + eval_model=model_predict_eval, + eval_tasks=[eval_task] + + additional_eval_tasks + + additional_eval_tasks_from_streams, + output_dir=output_dir, + checkpoint_at=checkpoint_at, + checkpoint_low_metric=checkpoint_lowest, + checkpoint_high_metric=checkpoint_highest, + permanent_checkpoint_at=permanent_checkpoint_at, + n_devices=n_devices, + loss_chunk_size=loss_chunk_size, + use_memory_efficient_trainer=use_memory_efficient_trainer, + adasum=adasum, + random_seed=random_seed, + callbacks=callbacks, + ) + + steps_to_go = steps - loop.step + if steps_to_go <= 0: + log("Stop training, already reached the total training steps %d" % steps) + return loop + + # Train and return the loop. 
+ loop.run(steps_to_go) + return loop + + n_devices = num_devices() + trainer = trainer_class( + model, + loss_fn, + optimizer, + lr_schedule_fn(), + inputs, + output_dir, + random_seed=random_seed, + n_devices=n_devices, + checkpoints_at=checkpoints_at, + metrics=metrics, + checkpoint_lowest=checkpoint_lowest, + checkpoint_highest=checkpoint_highest, + init_checkpoint=init_checkpoint, + ) + + epoch_steps = [steps] # Only training if eval_frequency is 0 or None + if eval_frequency and eval_steps > 0: + epoch_steps = itertools.chain( + [1, eval_frequency - 1], # first epoch only 1 step + itertools.repeat(eval_frequency), + ) + trainer.log_step("Starting training using %d devices" % trainer.n_devices) + trainer.print_n_weights() + + try: + for epoch_steps in epochs(steps, trainer.step, epoch_steps): + trainer.train_epoch(epoch_steps, eval_steps) + + # Bookkeeping we do at the first step + if trainer.step == 1: + # Save computation graph (single-device only for now) + if save_graphs and fastmath.is_backend(fastmath.Backend.JAX): + trainer.save_computation_graphs() + + # Save Gin config + trainer.save_gin() + + trainer.log_step("Training done") + except Exception as e: + raise e + finally: + trainer.close() + return trainer.state + + +@gin.configurable +def num_devices(value=None): + """Returns how many devices to use (if None, default, use all available).""" + return value + + +@gin.configurable +def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True): + """Returns a (JIT-compiled) function that computes updates for one step.""" + model_and_loss = tl.Serial(predict_fn, loss_fn) + + # Gradients are always wrt. the first argument, so putting weights first. + def model_and_loss_call(weights, batch, state, rng): + res = model_and_loss(batch, weights=weights, state=state, rng=rng) + return res, model_and_loss.state + + if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. + + def single_update(weights_and_slots, i, opt_params, batch, state, rng): + weights, slots = weights_and_slots + rng, subrng = jax_random.split(rng[0]) + grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(weights, batch, state, rng) + new_weights, new_slots, stats = optimizer.tree_update( + i, grads, weights, slots, opt_params + ) + return (new_weights, new_slots), stats, state, [subrng] + + if jit: + # TODO(lukaszkaiser): donate_argnums=(0,) when XLA supports it on GPU + return fastmath.jit(single_update) + else: + return single_update + + # Else, for n_devices > 1: + @functools.partial(fastmath.pmap, axis_name="batch") # donate_argnums=(0,)) + def mapped_update(weights_and_slots, i, opt_params, batch, state, rng): + """This is a multi-device version of the update function above.""" + # We assume all tensors have the first dimension = n_devices. + weights, slots = weights_and_slots + rng, subrng = jax_random.split(rng) + grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(weights, batch, state, rng) + # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just + # the number of devices on this host machine, however psum goes over all + # devices of all hosts (ex: a TPU pod) and we need to be averaging over all + # of them. + # + # Collect all gradients. + grads = fastmath.psum(grads, "batch") + n_devices_total = fastmath.psum(np.array(1.0), "batch") + # Average across hosts. 
+ grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads) + + new_weights, new_slots, stats = optimizer.tree_update( + i, grads, weights, slots, opt_params + ) + return (new_weights, new_slots), stats, state, subrng + + def update(weights_and_slots, i, opt_params, batch, state, rng): + return mapped_update( + weights_and_slots, np.repeat(i, n_devices), opt_params, batch, state, rng + ) + + return update + + +@gin.configurable +def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True): + """Returns a JIT-compiled predict function (unless jit=False).""" + model = tl.Serial(model_predict, metric_fn) + if not jit: + return model.pure_fn + + return tl.jit_forward(model.pure_fn, n_devices) + + +@gin.configurable +def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True): + """Returns a (JIT-compiled) function that computes the loss for one step.""" + if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. + + def single_compute_loss(opt_state, batch, state, rng): + rng, subrng = jax_random.split(rng[0]) + loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) + return loss_val, state, [subrng] + + return fastmath.jit(single_compute_loss) if jit else single_compute_loss + + # Else, for n_devices > 1: + @functools.partial(fastmath.pmap, axis_name="batch") + def mapped_compute_loss(opt_state, batch, state, rng): + """This is a multi-device version of the update function above.""" + # We assume all tensors have the first dimension = n_devices. + rng, subrng = jax_random.split(rng) + loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) + return loss_val, state, subrng + + def compute_loss(opt_state, batch, state, rng): + return mapped_compute_loss( + opt_state, _reshape_by_device(batch, n_devices), state, rng + ) + + return compute_loss + + +def log(s, stdout=True): + logging.info(s) + if stdout: + print(s) + sys.stdout.flush() + + +def epochs(total_steps, steps_to_skip, epoch_steps): + """Generates the number of steps in each epoch before reaching total_steps. + + Args: + total_steps: int, total number of steps. + steps_to_skip: int, number of steps to skip because of a restart. + epoch_steps: iterable of int, numbers of steps in each epoch. + + Yields: + epoch_steps: int, number of steps in this epoch + """ + steps_to_go = total_steps - steps_to_skip + epoch_steps = iter(epoch_steps) + + # Remove the desired number of steps from the stream. + for steps_this_epoch in epoch_steps: + if steps_this_epoch > steps_to_skip: + # Put back the number of steps left in the unfinished epoch. + epoch_steps = itertools.chain( + [steps_this_epoch - steps_to_skip], epoch_steps + ) + if steps_this_epoch >= steps_to_skip: + break + steps_to_skip -= steps_this_epoch + + # Yield the remaining steps per epoch up to total_steps. + for steps_this_epoch in epoch_steps: + steps_this_epoch = min(steps_this_epoch, steps_to_go) + yield steps_this_epoch + steps_to_go -= steps_this_epoch + if steps_to_go == 0: + break + + +def make_trainer_state_dict(step, opt_state, history, model_state, input_signature): + """Creates a trainers state dictionary to save to disk. + + Args: + step: int, a step number + opt_state: OptState namedtuple + history: `trax.history.History`, the history object. + model_state: A nested structure of the model state. + input_signature: signature of model inputs. + + Returns: + A dictionary with the fields of TrainerState and OptState flattened. 
+ """ + flat_weights, flat_state = tl.flatten_weights_and_state( + opt_state.weights, model_state + ) + return { + "step": step, + "flat_weights": flat_weights, + "slots": opt_state.slots, + "opt_params": opt_state.opt_params, + "history": history, + "flat_state": flat_state, + "input_signature": input_signature, + "version_timestamp": "Jun-18-2020", # To update in the future if needed. + } + + +def trainer_state_from_dict(trainer_state_dict, model): + """Given the trainers state dictionary, returns `TrainerState`.""" + # TODO(afrozm): This becomes simpler if OptState is flattened into + # TrainerState. + step = trainer_state_dict["step"] + history = trainer_state_dict["history"] + input_signature = trainer_state_dict["input_signature"] + weights_and_state_sig = model.weights_and_state_signature(input_signature) + weights, model_state = tl.unflatten_weights_and_state( + trainer_state_dict["flat_weights"], + trainer_state_dict["flat_state"], + weights_and_state_sig, + ) + opt_state = OptState( + weights=weights, + slots=trainer_state_dict["slots"], + opt_params=trainer_state_dict["opt_params"], + ) + return TrainerState( + step=step, + opt_state=OptState(*opt_state), + history=history, + model_state=model_state, + ) + + +def load_trainer_state(output_dir, model, weights_file=None): + """Returns a TrainerState instance loaded from the given `output_dir`.""" + if weights_file is None: + weights_file = os.path.join(output_dir, "model.pkl.gz") + if not tf.io.gfile.exists(weights_file): + return TrainerState( + step=None, + opt_state=None, + history=trax_history.History(), + model_state=None, + ) + elif not tf.io.gfile.exists(weights_file): + raise ValueError("File not found: %s" % weights_file) + + trainer_state_dict = training.unpickle_from_file(weights_file, gzip=True) + trainer_state = trainer_state_from_dict(trainer_state_dict, model) + log("Model loaded from %s at step %d" % (weights_file, trainer_state.step)) + logging.debug("From loaded model : history = %s", trainer_state.history) + return trainer_state + + +def _reshape_by_device(x, n_devices): + """Reshapes possibly nested x into a shape (n_devices, ...).""" + return tl.reshape_by_device(x, n_devices) # pylint: disable=protected-access + + +def _nested_reduce(f, x): + """Fold the function f to the nested structure x (dicts, tuples, lists).""" + if isinstance(x, list): + return f([_nested_reduce(f, y) for y in x]) + if isinstance(x, tuple): + return f([_nested_reduce(f, y) for y in x]) + if isinstance(x, dict): + return f([_nested_reduce(f, v) for (_, v) in x.items()]) + return x + + +def _sizes(x): + """Get a structure of sizes for a structure of nested arrays.""" + + def size(x): + try: + return x.size + except Exception: # pylint: disable=broad-except + return 0 + + return fastmath.nested_map(size, x) + + +def _repeat_stream(stream, n_devices): + """Repeat a stream indefinitely.""" + while True: + for example in stream(n_devices): + yield example diff --git a/trax/learning/supervised/training.py b/trax/learning/supervised/training.py new file mode 100644 index 000000000..e3f808dde --- /dev/null +++ b/trax/learning/supervised/training.py @@ -0,0 +1,1486 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simplified API (under development) for supervised learning/training in Trax. + +This module will eventually replace :py:class:`trainer_lib.Trainer`. + +Key classes: + + - :py:class:`Loop`: Core training loop for an n-step training session, + starting from random initialization. + + - :py:class:`TrainTask`: Labeled data + feedback mechanism (loss function w/ + optimizer) for modifying a model's weights. + + - :py:class:`Optimizer`: How to compute model weight updates using + loss-derived gradients. May contain state ("slots", 1-1 with model weights) + that accumulates across training steps. (This class is defined in the + :py:class:`trax.optimizers`.) + + - :py:class:`EvalTask`: How and when to measure model performance as a + function of training step number. +""" +import collections +import contextlib +import functools +import gzip as gzip_lib +import os +import pickle +import random +import sys +import time + +import gin +import jax +import numpy as np +import psutil +import tensorflow as tf + +from absl import logging + +from trax import fastmath +from trax import layers as tl +from trax.data.preprocessing import inputs +from trax.fastmath import numpy as jnp +from trax.fastmath import random as jax_random +from trax.layers import base +from trax.learning.supervised import history as trax_history +from trax.trainers import base as trainer +from trax.utils import jaxboard, shapes + +_Evaluator = collections.namedtuple("_Evaluator", ["weights", "state", "metrics_fn"]) + + +class Loop: + """Loop that can run for a given number of steps to train a supervised model. + + Can train the model on multiple tasks by interleaving updates according to the + ``which_task`` argument. + + The typical supervised training process randomly initializes a model and + updates its weights via feedback (loss-derived gradients) from a training + task, by looping through batches of labeled data. A training loop can also + be configured to run periodic evals and save intermediate checkpoints. + + For speed, the implementation takes advantage of JAX's composable function + transformations (specifically, ``jit`` and ``grad``). 
It creates JIT-compiled
+    pure functions derived from variants of the core model; schematically:
+
+      - training variant: `jit(grad(pure_function(model+loss)))`
+      - evals variant: `jit(pure_function(model+evals))`
+
+    In training or during evals, these variants are called with explicit
+    arguments for all relevant input data, model weights/state, optimizer slots,
+    and random number seeds:
+
+      - batch: labeled data
+      - model weights/state: trainable weights and input-related state (e.g., as
+        used by batch norm)
+      - optimizer slots: weights in the optimizer that evolve during the training
+        process
+      - random number seeds: JAX PRNG keys that enable high-quality, distributed,
+        repeatable generation of pseudo-random numbers
+    """
+
+    def __init__(
+        self,
+        model,
+        tasks,
+        eval_model=None,
+        eval_tasks=None,
+        output_dir=None,
+        checkpoint_at=None,
+        checkpoint_low_metric=None,
+        checkpoint_high_metric=None,
+        permanent_checkpoint_at=None,
+        eval_at=None,
+        which_task=None,
+        n_devices=None,
+        random_seed=None,
+        loss_chunk_size=0,
+        use_memory_efficient_trainer=False,
+        adasum=False,
+        callbacks=None,
+    ):
+        """Configures a training ``Loop``, including a random initialization.
+
+        Args:
+            model: Trax layer, representing the core model to be trained. Loss
+                functions and eval functions (a.k.a. metrics) are considered to be
+                outside the core model, taking core model output and data labels as
+                their two inputs.
+            tasks: List of :py:class:`TrainTask` instances, which define the training
+                data, loss function, and optimizer to be used in respective tasks in
+                this training loop. It can also be a single :py:class:`TrainTask`
+                instance which is treated in the same way as a singleton list.
+            eval_model: Optional Trax layer, representing model used for evaluation,
+                e.g., with dropout turned off. If ``None``, the training model (model)
+                will be used.
+            eval_tasks: List of :py:class:`EvalTask` instances which define how to
+                evaluate the model: which validation data to use and which metrics to
+                report. Evaluation on each of the tasks will run and be reported
+                separately, which allows scoring a model on different subtasks. This
+                argument can also be ``None``, in which case no evals will be run, or
+                a single :py:class:`EvalTask`, which will be treated in the same way
+                as a singleton list.
+            output_dir: Path telling where to save outputs (evals and checkpoints).
+                Can be ``None`` if both ``eval_task`` and ``checkpoint_at`` are
+                ``None``.
+            checkpoint_at: Function (integer --> boolean) telling, for step n, whether
+                that step should have its checkpoint saved. If ``None``, the default
+                is periodic checkpointing at ``task.n_steps_per_checkpoint``.
+            checkpoint_low_metric: Name of metric, or None. The metric name must
+                be one of the metric names from the evals in ``eval_tasks``. At
+                checkpoint times determined by ``checkpoint_at``, a separate
+                specially named checkpoint will be saved (overwriting any previous
+                version) if the designated metric reaches a value less than or equal
+                to any previous recorded low value. No such checkpoint is saved if
+                arg value is `None`.
+            checkpoint_high_metric: Name of metric, or None. The metric name must
+                be one of the metric names from the evals in ``eval_tasks``. At
+                checkpoint times determined by ``checkpoint_at``, a separate
+                specially named checkpoint will be saved (overwriting any previous
+                version) if the designated metric reaches a value greater than or
+                equal to any previous recorded high value.
No such checkpoint is
+                saved if arg value is `None`.
+            permanent_checkpoint_at: Function (integer --> boolean) telling,
+                for step n, whether that step should have its checkpoint saved
+                permanently. If ``None``, the default is periodic checkpointing at
+                ``task.n_steps_per_permanent_checkpoint``.
+            eval_at: Function (integer --> boolean) that says, for training step n,
+                whether that step should run evals. If ``None``, run evals on the
+                first step and on every N'th step, as determined by the first
+                training task.
+            which_task: Function (integer --> integer) indicating which task should be
+                used at which training step. Can be set to ``None`` in single-task
+                training.
+            n_devices: integer or ``None``, the number of devices for this
+                computation.
+            random_seed: the random seed to use; time/os dependent if ``None``
+                (default).
+            loss_chunk_size: int, if > 0 use chunks of this size to make loss
+                computation more memory-efficient.
+            use_memory_efficient_trainer: whether to use a special memory-efficient
+                trainer; if set to 2, memory efficiency is very aggressive.
+            adasum: if True, use adaptive summation for multi-device gradients.
+            callbacks: List of subclasses of StepCallback to call on training
+                steps.
+        """
+        (
+            self._is_chief,
+            self._n_hosts,
+            self._n_devices,
+            self._rng,
+        ) = init_host_and_devices(n_devices, random_seed)
+        if use_memory_efficient_trainer:
+            self._rng = tl.on_cpu(self._rng)
+
+        # Handle single task case without lists too.
+        if not isinstance(tasks, (list, tuple)):
+            tasks = [tasks]
+
+        if not tasks:
+            raise ValueError("Must provide at least one training task.")
+        if eval_tasks is None:
+            eval_tasks = []
+            eval_at = _never
+        else:
+            if not isinstance(eval_tasks, (list, tuple)):
+                eval_tasks = [eval_tasks]
+
+        self._tasks = tasks
+        self._model = model
+        self._eval_model = eval_model or model
+
+        self._use_memory_efficient_trainer = use_memory_efficient_trainer
+        self._loss_chunk_size = loss_chunk_size
+        self._adasum = adasum
+        # TODO(lukaszkaiser): can we have different eval models and save memory?
+        if use_memory_efficient_trainer:
+            assert len(tasks) == 1, "only single task supported for now"
+            self._eval_model = model
+
+        default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
+        permanent_default_at = _at_step_1_and_every_nth_step(
+            tasks[0].n_steps_per_permanent_checkpoint
+        )
+        if output_dir is not None:
+            self._output_dir = os.path.expanduser(output_dir)
+            tf.io.gfile.makedirs(self._output_dir)
+            inputs.load_data_counters(self._output_dir)
+        else:
+            self._output_dir = None
+
+        # Prepare training components.
+        self._step = 0
+        self._history = trax_history.History()
+        self._checkpoint_at = checkpoint_at or default_at
+        self._checkpoint_low_metric = checkpoint_low_metric
+        self._checkpoint_high_metric = checkpoint_high_metric
+        self._permanent_checkpoint_at = permanent_checkpoint_at or permanent_default_at
+        if which_task is None:
+            # If which_task is not passed, we rotate through the tasks one by one.
+            # If len(tasks) == 1, which_task is then a constant function equal to 0.
+            which_task = lambda n: n % len(tasks)
+        self._which_task = which_task
+
+        # Initialize using the given random seed.
+        # NOTE: If random_seed is None then self._rng will be different on
+        # different hosts, leading to different weights on the different hosts.
+        self._batch_signature = shapes.signature(tasks[0].sample_batch)
+        self._model.rng = self.new_rng()
+        # In the memory-efficient case, we initialize in init_trainer.
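+        # The `_is_uninitialized` checks below let callers pass in a model whose
+        # weights already exist (e.g., restored from a checkpoint) without
+        # re-initializing and overwriting them.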
+ if not use_memory_efficient_trainer: + if _is_uninitialized(self._model): + self._model.init(self._batch_signature) + self._eval_model.rng = self.new_rng() + if _is_uninitialized(self._eval_model): + self._eval_model.init(self._batch_signature) + + # To handle the above case (i.e. random_seed = None), we psum the weights + # and state and average them. + # NOTE: This adds time (how much?) so we prefer not to do it if it is + # unnecessary, i.e. random_seed was set. + # NOTE: Averaging the weights across devices can screw up the initial weight + # statistics. + # TODO(pkozakowski): Broadcast from one of the devices instead? + if ( + random_seed is None + and self._n_hosts > 1 + and not use_memory_efficient_trainer + ): + logging.info("Syncing weights/state across %d hosts.", self._n_hosts) + # Do self._sync_weights_and_state_across_hosts() but layer-by-layer + # to save memory. + blocks, last_layer = trainer.extract_reversible_blocks([self._model]) + all_layers = [] + for std_layer, rev_layers in blocks: + all_layers.append(tl.Serial(std_layer)) + all_layers.extend(rev_layers) + all_layers.append(last_layer) + for layer in all_layers: + weights_and_state = (layer.weights, layer.state) + if not _is_empty(weights_and_state): + layer.weights, layer.state = tl.on_cpu( + self._unreplicate( + _make_weights_and_state_same_across_hosts( + self._for_n_devices(weights_and_state) + ) + ) + ) + + # Create the optimizer for the training loss function. + self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks) + + # Sync layers weights/state in memory effcient trainers layers. + if random_seed is None and self._n_hosts > 1 and use_memory_efficient_trainer: + logging.info("Syncing layers across %d hosts.", self._n_hosts) + for layer in self._trainer_per_task[0].all_layers: + weights_and_state = (layer.weights, layer.state) + if not _is_empty(weights_and_state): + layer.weights, layer.state = tl.on_cpu( + self._unreplicate( + _make_weights_and_state_same_across_hosts( + self._for_n_devices(weights_and_state) + ) + ) + ) + + # Load checkpoint if it exists. + self.load_checkpoint() + + # Prepare eval components. + self._eval_at = eval_at or default_at + self._eval_tasks = eval_tasks + loss_names = [task.loss_name for task in self._tasks] + metric_names = [ + name # pylint: disable=g-complex-comprehension + for eval_task in self._eval_tasks + for name in eval_task.metric_names + ] + self._rjust_len = max(map(len, loss_names + metric_names)) + self._evaluator_per_task = tuple( + self._init_evaluator(eval_task) for eval_task in self._eval_tasks + ) + + if self._output_dir is None: + _log("Will not write evaluation metrics, because output_dir is None.") + + def task_output_dir(task_index, task_list): + if self._output_dir is not None: + if len(task_list) < 2: + output_dir = self._output_dir + else: + output_dir = os.path.join( + self._output_dir, + task_list[task_index].export_prefix or str(task_index), + ) + tf.io.gfile.makedirs(output_dir) + return output_dir + else: + return None + + self._output_dir_per_eval_task = [ + task_output_dir(i, eval_tasks) for i in range(len(eval_tasks)) + ] + self._output_dir_per_train_task = [ + task_output_dir(i, tasks) for i in range(len(tasks)) + ] + + callbacks = callbacks or [] + self._callbacks = [callback_class(self) for callback_class in callbacks] + + def _init_trainer(self, task): + """Initializes the per-task trainers.""" + # Build the per-task model, sharing weights with other tasks. 
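+        # Sharing works because every task's model-with-loss is built around the
+        # same `self._model` instance, so an update made while training one task
+        # is visible to all the others.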
+ if not self._use_memory_efficient_trainer: + model_in_training = _model_with_ends( + self._model, [task.loss_layer], shapes.signature(task.sample_batch) + ) + if base.N_WEIGHTS_SHARDS > 1: + sharded_weights = fastmath.nested_map( + lambda x: x[0], tl.shard(model_in_training.weights) + ) + task.optimizer.tree_init(sharded_weights) + else: + task.optimizer.tree_init(model_in_training.weights) + return trainer.Trainer( + model_in_training, task.optimizer, adasum=self._adasum + ) + # In the memory-efficient path, we initialize the model here. + blocks, loss_layer = trainer.extract_reversible_blocks( + [self._model, task.loss_layer], loss_chunk_size=self._loss_chunk_size + ) + rng = self._model.rng + sig = shapes.signature(task.sample_batch) + trainer.init_reversible_blocks(blocks, loss_layer, sig, rng) + # TODO(lukaszkaiser): here optimizer is a function, revisit this. + return trainer.ReversibleSerialTrainer( + blocks, + loss_layer, + task.optimizer, + free_accelerators_on_step=(self._use_memory_efficient_trainer == 2), + adasum=self._adasum, + ) + + def _init_evaluator(self, eval_task): + """Initializes the per-task evaluator.""" + model_with_metrics = _model_with_metrics(self._eval_model, eval_task) + if self._use_memory_efficient_trainer: + return _Evaluator( + weights=tl.on_cpu(model_with_metrics.weights[1]), + state=tl.on_cpu(model_with_metrics.state[1]), + metrics_fn=_accelerate_model_with_metrics(model_with_metrics, 0), + ) + else: + return _Evaluator( + # Replicate the eval part of weights and state. + weights=self._for_n_devices(model_with_metrics.weights[1]), + state=self._for_n_devices(model_with_metrics.state[1]), + metrics_fn=_accelerate_model_with_metrics( + model_with_metrics, self.n_devices + ), + ) + + def _sync_weights_and_state_across_hosts(self): + """Sync weights and state across all the hosts in the computation.""" + + if logging.vlog_is_on(1): + logging.debug( + "Input training weights shape: %s", + fastmath.nested_map(lambda x: x.shape, self._model.weights), + ) + logging.debug("Input training weights: %s", self._model.weights) + logging.debug("Input training state: %s", self._model.state) + logging.debug("Input eval weights: %s", self._eval_model.weights) + logging.debug("Input eval state: %s", self._eval_model.state) + + ( + self._model.weights, + self._model.state, + self._eval_model.weights, + self._eval_model.state, + ) = self._unreplicate( + _make_weights_and_state_same_across_hosts( + self._for_n_devices( + ( + self._model.weights, + self._model.state, + self._eval_model.weights, + self._eval_model.state, + ) + ) + ) + ) + + if logging.vlog_is_on(1): + logging.debug( + "Output training weights shape: %s", + fastmath.nested_map(lambda x: x.shape, self._model.weights), + ) + logging.debug("Output training weights: %s", self._model.weights) + logging.debug("Output training state: %s", self._model.state) + logging.debug("Output eval weights: %s", self._eval_model.weights) + logging.debug("Output eval state: %s", self._eval_model.state) + + def run(self, n_steps=1): + """Runs this training loop for n steps. + + Optionally runs evals and saves checkpoints at specified points. + + Args: + n_steps: Stop training after completing n steps. 
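+
+        A minimal usage sketch (assuming ``model``, ``train_task`` and
+        ``eval_task`` were constructed elsewhere; the names are illustrative)::
+
+            loop = Loop(model, [train_task], eval_tasks=[eval_task],
+                        output_dir="path/to/train_dir")
+            loop.run(n_steps=1000)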
+ """ + with self._open_summary_writers() as ( + train_summary_writers, + eval_summary_writers, + ): + process = psutil.Process(os.getpid()) + loss_acc, step_acc = 0.0, 0 + start_time = time.time() + optimizer_metrics_acc = collections.defaultdict(float) + for i in range(n_steps): + prev_task_index = self._which_task(self._step) + self._step += 1 + task_index = self._which_task(self._step) + task_changed = task_index != prev_task_index + + if task_changed: + loss_acc, step_acc = 0.0, 0 + + loss, optimizer_metrics = self._run_one_step(task_index, task_changed) + + # optimizer_metrics and loss are replicated on self.n_devices, a few + # metrics are replicated (ex: gradients_l2, weights_l2) - i.e. they are + # the same across devices, whereas some (ex: loss) aren't because they + # are different on different devices (due to different data). + # Taking the average does the correct thing in both the cases. + # + # NOTE: Only the weights and gradients are synced across the hosts. This + # implies the loss here is averaged from this hosts' devices and not + # across all hosts. + optimizer_metrics, loss = fastmath.nested_map( + functools.partial(tl.mean, self._n_devices), + (optimizer_metrics, loss), + ) + + loss_acc += loss + # Log loss every 50 steps, every step in memory-efficient trainers. + if self._step % 50 == 0 or self._use_memory_efficient_trainer: + self._log_step("Loss: %.4f" % loss, stdout=False) + step_acc += 1 + for metric_name, value in optimizer_metrics.items(): + optimizer_metrics_acc[metric_name] += value + + # TODO(yuwenyan): Finds a way to log the last round eval step in + # history. + # + # Right now, the last round eval log is missing in history since the + # checkpoint is saved before it. However sometimes the eval step will + # fail for some reasons, and it's not acceptable to loose the whole + # checkpoint in this case. Stays with the old way for now, and fixes it + # when the checkpoint format is changed to storing weights separately + # from a small file with history and other data. + if self._checkpoint_at(self.step): + self.save_checkpoint("model") + if self._permanent_checkpoint_at(self.step): + self.save_checkpoint(f"model_{self.step}") + if self._eval_at(self.step): + logging.info( + "cpu memory use (MB): %.2f", + process.memory_info().rss / float(1024 * 1024), + ) + elapsed_time = time.time() - start_time + self._log_training_progress( + task=self._tasks[task_index], + total_loss=loss_acc, + n_steps=step_acc, + elapsed_time=elapsed_time, + optimizer_metrics=optimizer_metrics_acc, + summary_writer=train_summary_writers[task_index], + ) + self.run_evals(eval_summary_writers) + loss_acc, step_acc = 0.0, 0 + start_time = time.time() + optimizer_metrics_acc = collections.defaultdict(float) + + # For the current step, after all evals are run and recorded in the + # event history, check if we need to save special checkpoints because + # of a new low metric value or a new high metric value. + if self._checkpoint_at(self.step): + if self._checkpoint_low_metric is not None and self._at_lowest(): + self.save_checkpoint(f"lowest_{self._checkpoint_low_metric}") + if self._checkpoint_high_metric is not None and self._at_highest(): + self.save_checkpoint(f"highest_{self._checkpoint_high_metric}") + + # Store the final values back into their respective objects, for testing + # or other inspection/use. + # + # We keep the standard model weights/state unreplicated and + # tl.Accelerate(model) will carry the replicated weights/state. 
+ # TODO(afrozm): Try to use tl.Accelerate(model) everywhere in the Loop. + self._eval_model.weights = self._model.weights + + def _at_lowest(self): + low_items = self.history.get("eval", f"metrics/{self._checkpoint_low_metric}") + vals = [float(obj[1]) for obj in low_items] + return vals[-1] == min(vals) + + def _at_highest(self): + high_items = self.history.get("eval", f"metrics/{self._checkpoint_high_metric}") + vals = [float(obj[1]) for obj in high_items] + return vals[-1] == max(vals) + + @property + def step(self): + """Returns current step number in this training session.""" + return self._step + + @property + def history(self): + """Returns history in this training session.""" + return self._history + + @property + def n_devices(self): + """Returns the number of devices to be used in this computation.""" + return self._n_devices + + @property + def is_chief(self): + """Returns true if this Loop is the chief.""" + return self._is_chief + + @property + def model(self): + """Returns the model that is training.""" + return self._model + + @property + def tasks(self): + """Returns the training tasks.""" + return self._tasks + + @property + def eval_model(self): + """Returns the model used for evaluation.""" + return self._eval_model + + @property + def eval_tasks(self): + """Returns the evaluation tasks.""" + return self._eval_tasks + + @property + def output_dir(self): + """Returns the output directory.""" + return self._output_dir + + def new_rng(self): + """Returns a new single-use random number generator (JAX PRNG key).""" + self._rng, rng = fastmath.random.split(self._rng) + if self._use_memory_efficient_trainer: + self._rng = tl.on_cpu(self._rng) + rng = tl.on_cpu(rng) + return rng + + def _for_n_devices(self, x): + """Replicates/broadcasts ``x`` for n devices if ``self.n_devicess > 1``.""" + return tl.for_n_devices(x, self.n_devices) + + def _unreplicate(self, x): + if self.n_devices == 1: + return x + + unreplicate_fn = lambda x: x[0] + return fastmath.nested_map(unreplicate_fn, x) + + def _reshape_by_device(self, x): + if self.n_devices == 1: + return x + return tl.reshape_by_device(x, self.n_devices) + + def update_weights_and_state(self, weights=None, state=None): + """Updates the weights and state of the trained model. + + Sends this data both to the singleton model accessible via Loop.model + and to the replicated model on the accelerator. + + Useful when the weights or state are modified outside of training, e.g. + during data collection in RL agents. + + Args: + weights: Model weights or ``None``. If ``None``, don't set. + state: Model state or ``None``. If ``None``, don't set. + """ + for trainer in self._trainer_per_task: + acc_model_with_loss = trainer.accelerated_model_with_loss + if weights is not None: + self._model.weights = weights + acc_model_with_loss.replicate_weights(trainer.model_with_loss.weights) + if state is not None: + self._model.state = state + acc_model_with_loss.replicate_state(trainer.model_with_loss.state) + + def _run_one_step(self, task_index, task_changed): + """Updates model weights/state and optimizer slots by running one step. + + Args: + task_index (int): Index of the task to train on. + task_changed (bool): Whether the state has changed since the last step. + + Returns: + Tuple (loss, stats) with loss value from one step + of training and stats, the current optimizer statistics. 
+ """ + step = self.step + for callback in self._callbacks: + if callback.call_at(step): + callback.on_step_begin(step) + + learning_rate = self._tasks[task_index].learning_rate(step) + batch = self._tasks[task_index].next_batch() + rng = self.new_rng() + trainer = self._trainer_per_task[task_index] + if task_changed: + # Re-replicate weights and state to synchronize them between tasks. + self.update_weights_and_state(self._model.weights, self._model.state) + + (loss, stats) = trainer.one_step( + batch, rng, step=step, learning_rate=learning_rate + ) + + for callback in self._callbacks: + if callback.call_at(step): + callback.on_step_end(step) + + return (loss, stats) + + def _log_training_progress( + self, task, total_loss, n_steps, elapsed_time, optimizer_metrics, summary_writer + ): + """Logs training related metrics. + + Logs: + * current learning rate, + * steps per second, + * average training loss, + * average metrics returned from the optimizer + to the provided summary writer. Training loss is also logged to stdout. + + Args: + task: Current training task. + total_loss: Total training loss accumulated over n_steps training steps. + n_steps: Number of steps over which the metrics were accumulated. + elapsed_time: Time of execution of n_steps training steps. + optimizer_metrics: Dict from optimizer metric name to metric values. + summary_writer: Jaxboard summary writer for saving provided metrics. + """ + # only here do avoid potential divide-by-0 + n_steps = max(1, n_steps) + _log("") # Separator for visibility on terminals. + if self.step == 1: + self._log_n_weights() + self._log_step("Ran %d train steps in %0.2f secs" % (n_steps, elapsed_time)) + self.log_summary( + {task.loss_name: total_loss / float(n_steps)}, + summary_writer, + "metrics/", + "train", + ) + if self.step == 1: + self._save_gin(summary_writer) + train_parameters = { + "learning_rate": task.learning_rate(self.step), + "steps per second": n_steps / elapsed_time, + } + # Average optimizer_metrics over n_steps. + optimizer_metrics = {k: v / n_steps for k, v in optimizer_metrics.items()} + train_parameters.update(optimizer_metrics) + self.log_summary( + train_parameters, summary_writer, "training/", "train", stdout=False + ) + + def _save_gin(self, summary_writer): + """ "Saves the operative gin config.""" + if not self.is_chief or self._output_dir is None: + return + config_path = os.path.join(self._output_dir, "config.gin") + config_str = gin.operative_config_str() + with tf.io.gfile.GFile(config_path, "w") as f: + f.write(config_str) + if summary_writer is not None: + summary_writer.text( + "gin_config", jaxboard.markdownify_operative_config_str(config_str) + ) + + def _log_n_weights(self): + """ "Logs the number of weights in the training model.""" + + def _size(x): + try: + return x.size + except Exception: # pylint: disable=broad-except + return 0 + + sizes = fastmath.nested_map(_size, self._model.weights) + total_size = sum(fastmath.tree_flatten(sizes)) + total_size *= base.N_WEIGHTS_SHARDS + self._log_step("Total number of trainable weights: %d" % total_size) + + # TODO(afrozm): Fix multi-host evals, right now the reported numbers in the + # summary writer are only from the chief and not averaged across hosts. + def run_evals(self, summary_writers=None): + """Runs and records evals for this training session. + + Args: + summary_writers: List of per-task Jaxboard summary writers to log metrics. 
+ """ + if summary_writers is None: + summary_writers = (None,) * len(self._eval_tasks) + + self._eval_model.weights = self._model.weights + self._eval_model.state = self._model.state + + def recursively_look_for_printable_states(state): + if isinstance(state, (tuple, list)): + for substate in state: + for item in recursively_look_for_printable_states(substate): + yield item + if isinstance(state, dict): + for key, value in state.items(): + if isinstance(key, str) and key.startswith("summary_"): + for device_id, device_value in enumerate(value): + yield ( + "device{}/{}".format(device_id, key[len("summary_") :]), + device_value, + ) + + # The most recently trained weights are in this trainers, use those for eval. + cur_train_task_index = self._which_task(self._step) + trainer = self._trainer_per_task[cur_train_task_index] + + for eval_task_index in range(len(self._eval_tasks)): + eval_task = self._eval_tasks[eval_task_index] + evaluator = self._evaluator_per_task[eval_task_index] + if eval_task is None: + continue + + # Extract the actual model weights and state, excluding the loss layer. + if self._use_memory_efficient_trainer: + model_weights, model_state = self._model.weights, self._model.state + else: + model_weights = trainer.accelerated_model_with_loss.weights[0] + model_state = trainer.accelerated_model_with_loss.state[0] + + # evaluator.{weights,state} are already replicated. + metrics_weights = (model_weights, evaluator.weights) + metrics_state = (model_state, evaluator.state) + + n_batches = eval_task.n_eval_batches + n_metrics = len(eval_task.metrics) + sums = np.zeros((n_metrics,)) + for _ in range(n_batches): + rng = self.new_rng() + batch = eval_task.next_batch() + metric_values, _ = evaluator.metrics_fn( + batch, metrics_weights, metrics_state, rng + ) + sums += metric_values + averages = sums / n_batches + all_metrics = dict(zip(eval_task.metric_names, averages)) + summary_writer = summary_writers[eval_task_index] + self.log_summary(all_metrics, summary_writer, "metrics/", "eval") + summary_metrics = dict(recursively_look_for_printable_states(model_state)) + self.log_summary(summary_metrics, summary_writer, "summary_", "eval") + + def log_summary( + self, values, summary_writer, value_prefix, log_prefix, stdout=True + ): + """Logs and saves provided metrics. + + Args: + values: Dict from metric name to metric value. + summary_writer: Jaxboard summary writer. + value_prefix: String appended in front of summary_writer entries. + log_prefix: String appended in front of logs. + stdout: Boolean saying if logs should be logged to stdout as well. + """ + history = self._history + should_write_summaries = self.is_chief and summary_writer is not None + for name, value in values.items(): + full_name = value_prefix + name + s = tuple(jnp.shape(value)) + if not s: + self._log_step( + "%s %s | % .8f" + % (log_prefix.ljust(5), name.rjust(self._rjust_len), value), + stdout=stdout, + ) + if should_write_summaries: + summary_writer.scalar(full_name, value, self.step) + else: + if should_write_summaries: + summary_writer.image(full_name, value, self.step) + if history: + history.append(log_prefix, full_name, self.step, value) + if should_write_summaries: + summary_writer.flush() + + def _log_step(self, msg, stdout=True): + """Logs message, labeled with the current training step number.""" + _log("Step % 6d: %s" % (self.step, msg), stdout=stdout) + + def save_checkpoint(self, basename): + """Saves checkpoint (multiple files) to disk for the current training step. 
+
+    def save_checkpoint(self, basename):
+        """Saves checkpoint (multiple files) to disk for the current training step.
+
+        Saving a checkpoint will overwrite any previous checkpoint saved with the
+        same ``basename``. Use differing ``basename`` values to save multiple
+        checkpoints or multiple copies of the same checkpoint.
+
+        Args:
+          basename: Basename for saving a checkpoint. Full file paths for the saved
+            checkpoint will combine the output dir, basename, and relevant file
+            extensions (e.g., `.weights.npy.gz`).
+        """
+        if self._output_dir is None:
+            _log("Did not save checkpoint as output_dir is None")
+            return
+
+        inputs.save_data_counters(self._output_dir)
+        if not self.is_chief:
+            _log("Did not save checkpoint as we are not chief.")
+            return
+
+        dir_and_basename = os.path.join(self._output_dir, basename)
+        pkl_file = dir_and_basename + ".pkl.gz"
+
+        _log("Saving checkpoint to %s" % pkl_file, stdout=False)
+        weights = self._model.weights
+        if base.N_WEIGHTS_SHARDS > 1:
+            weights = self._trainer_per_task[0].accelerated_model_with_loss.weights
+            weights = tl.unshard(weights)
+        state = self._model.state
+        compresslevel = 0 if self._use_memory_efficient_trainer else 2
+        # Serialize optimizer slots.
+        for i, trainer in enumerate(self._trainer_per_task):
+            flat_slots = _flatten_and_remove_empty(trainer.slots)
+            tl.np_to_file(
+                self._to_bits(flat_slots),
+                f"{dir_and_basename}.opt_slots{i}.npy.gz",
+                compresslevel=compresslevel,
+            )
+        # We only need the input signature for the body, not for the loss layers.
+        # That part is the same across tasks - take it from the first one.
+        input_signature = self._batch_signature[: self._model.n_in]
+        flat_weights, flat_state = tl.flatten_weights_and_state(weights, state)
+        _, flat_eval_state = tl.flatten_weights_and_state(
+            weights, self._eval_model.state
+        )
+        tl.np_to_file(
+            self._to_bits(flat_weights),
+            f"{dir_and_basename}.weights.npy.gz",
+            compresslevel=compresslevel,
+        )
+        d = {
+            "step": self.step,
+            "flat_weights": compresslevel,  # for compatibility with older format
+            "flat_state": flat_state,
+            "flat_eval_state": flat_eval_state,
+            "history": self._history.to_dict(),
+            "slots_per_task": compresslevel,  # for compatibility with older format
+            "input_signature": input_signature,
+            "version_timestamp": "Mar-10-2021",  # To update in the future if needed.
+        }
+        pickle_to_file(d, pkl_file, gzip=True)
+        _log("Checkpoint saved in %s" % pkl_file, stdout=False)
+
+    def _to_bits(self, weights):
+        """Converts a list of weights to bit-cast weights and their types."""
+        # This is currently needed to pickle bfloat16 arrays from JAX.
+        # TODO(lukaszkaiser): remove once it is not needed (the following unit test
+        # checks it: training_test/test_restores_step_bfloat16).
+        if not fastmath.is_backend(fastmath.Backend.JAX):
+            return weights
+        bits = []
+        for w in weights:
+            if w.dtype == jnp.bfloat16:
+                converted = jax.lax.bitcast_convert_type(w, np.uint16)
+                bits.append(np.asarray(converted.astype(np.uint16)))
+            else:  # for non-bfloat16 weights, be compatible with earlier checkpoints
+                bits.append(np.asarray(w))
+        return bits
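+
+    # Round-trip sketch of the bit-cast above (assumes the JAX backend):
+    #
+    #   w = jnp.ones((2,), dtype=jnp.bfloat16)
+    #   b = jax.lax.bitcast_convert_type(w, np.uint16)      # picklable as uint16
+    #   w2 = jax.lax.bitcast_convert_type(b, jnp.bfloat16)  # bit-identical to w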
+
+    def _from_bits(self, bits):
+        """Converts a list of bit-cast weights back to weights."""
+        # This is the reverse of _to_bits, see above for explanation.
+        if not fastmath.is_backend(fastmath.Backend.JAX):
+            return bits
+        weights = []
+        for b in bits:
+            if b.dtype == np.uint16:  # currently all uint16 are bfloat16s
+                w = jax.lax.bitcast_convert_type(b, jnp.bfloat16)
+                weights.append(np.asarray(w))
+            else:
+                weights.append(b)
+        return weights
+
+    def load_checkpoint(self, directory=None, filename=None):
+        """Loads model weights and step from a checkpoint on disk.
+
+        Args:
+          directory: Directory with the checkpoint (self._output_dir by default).
+          filename: Checkpoint file name (model.pkl.gz by default).
+        """
+        directory = directory or self._output_dir
+        if directory is None:
+            _log("Not loading as both directory and output_dir are None.", stdout=False)
+            return
+        filename = filename or "model"
+        path = os.path.join(directory, filename)
+        pkl_path = path + ".pkl.gz"
+        if not tf.io.gfile.exists(pkl_path):
+            _log(
+                f"Not loading as checkpoint file does not exist: {pkl_path}",
+                stdout=False,
+            )
+            return
+        _log("Loading checkpoint from %s" % pkl_path, stdout=False)
+        d = unpickle_from_file(pkl_path, gzip=True)
+        # Weights are stored in a separate non-pickled file in the new checkpoint
+        # format. We support loading old checkpoints with this hack.
+        # TODO(lukaszkaiser): remove the hack when not needed any more.
+        if isinstance(d["flat_weights"], int):
+            weights = tl.np_from_file(
+                path + ".weights.npy.gz", compresslevel=d["flat_weights"]
+            )
+            d["flat_weights"] = weights
+        # The same holds for optimizer slots.
+        if "slots" in d:  # Old checkpoints had just 'slots' for one task.
+            if len(self._tasks) != 1:
+                raise ValueError(
+                    "Can't load a single-task checkpoint into a multitask Loop."
+                )
+            d["slots_per_task"] = [d["slots"]]
+        # Read from separate files if optimizer slots are in them.
+        if "slots_per_task" in d and isinstance(d["slots_per_task"], int):
+            compresslevel = d["slots_per_task"]
+            d["slots_per_task"] = []
+            for i in range(len(self._trainer_per_task)):
+                slots = tl.np_from_file(
+                    path + f".opt_slots{i}.npy.gz", compresslevel=compresslevel
+                )
+                d["slots_per_task"].append(slots)
+        for trainer, slots in zip(self._trainer_per_task, d["slots_per_task"]):
+            matched_flat_slots = _match_by_shape(
+                self._to_bits(_flatten_and_remove_empty(trainer.slots)),
+                _flatten_and_remove_empty(slots),
+            )
+            matched_slots, _ = fastmath.tree_unflatten(
+                self._from_bits(matched_flat_slots),
+                trainer.slots,
+                copy_from_tree=[None, ()],
+            )
+            trainer.slots = matched_slots
+        self._step = d["step"]
+        self._history = trax_history.History.from_dict(d["history"])
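+        # Note (sketch): `_match_by_shape` below copies checkpoint tensors only
+        # where shapes agree with the freshly initialized model, so a checkpoint
+        # holding a prefix of the model's weights can still seed the first
+        # layers while later layers keep their fresh initialization.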
+        # This is self._model.init_from_file but optimized to not re-read.
+        input_signature = d["input_signature"]
+        weights_and_state_sig = self._model.weights_and_state_signature(input_signature)
+        flat_init_weights, flat_init_state = tl.flatten_weights_and_state(
+            self._model.weights, self._model.state
+        )
+        if len(d["flat_weights"]) < len(flat_init_weights):
+            _log("Checkpoint has fewer weights than the model; loading first ones.")
+        matched_weights = _match_by_shape(
+            self._to_bits(flat_init_weights), d["flat_weights"]
+        )
+        matched_weights = self._from_bits(matched_weights)
+        try:
+            restored_state = True
+            matched_state = _match_by_shape(
+                self._to_bits(flat_init_state), d["flat_state"]
+            )
+            matched_state = self._from_bits(matched_state)
+            weights, state = tl.unflatten_weights_and_state(
+                matched_weights, matched_state, weights_and_state_sig
+            )
+            self._model.state = state
+        except IndexError:
+            _log("Failed loading model state from checkpoint, loading weights only.")
+            restored_state = False
+            weights, _ = tl.unflatten_weights_and_state(
+                matched_weights, (), weights_and_state_sig, weights_only=True
+            )
+        self._model.weights = weights
+        self._eval_model.weights = self._model.weights
+        # Restore eval model state; note: it's not always the same as train state.
+        if restored_state:
+            if "flat_eval_state" in d:
+                flat_eval_state = d["flat_eval_state"]
+            else:  # It wasn't saved in old checkpoints; remove this branch once done.
+                flat_eval_state = d["flat_state"]
+            _, eval_state = tl.unflatten_weights_and_state(
+                matched_weights, flat_eval_state, weights_and_state_sig
+            )
+            self._eval_model.state = eval_state
+        _log("Checkpoint loaded from %s" % pkl_path, stdout=False)
+
+    @contextlib.contextmanager
+    def _open_summary_writers(self):
+        """Opens the Jaxboard summary writers wrapped by context manager.
+
+        Yields:
+          A pair (train_summary_writers, eval_summary_writers) of lists of
+          Jaxboard summary writers wrapped in a GeneratorContextManager object.
+          Elements of the lists correspond to the training and evaluation task
+          directories created during initialization. If there was no output_dir
+          provided, yields lists of Nones with the appropriate length.
+        """
+        if self._output_dir is not None:
+            _log(f"Metrics will be written in {self._output_dir}.", stdout=False)
+            train_writers = [
+                jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
+                for output_dir in self._output_dir_per_train_task
+            ]
+            eval_writers = [
+                jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))
+                for output_dir in self._output_dir_per_eval_task
+            ]
+            try:
+                yield (train_writers, eval_writers)
+            finally:
+                for writer in train_writers + eval_writers:
+                    writer.close()
+                _log(f"Metrics were written in {self._output_dir}", stdout=False)
+        else:
+            yield ([None] * len(self._tasks), [None] * len(self._eval_tasks))
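+
+    # Directory-layout sketch (assumed single train/eval task): with
+    # output_dir='/tmp/run', the writers above emit TensorBoard event files
+    # under subdirectories like /tmp/run/train and /tmp/run/eval, viewable
+    # with `tensorboard --logdir /tmp/run`.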
+
+
+def _model_with_ends(model, end_layers, batch_signature):
+    """Returns a model+ends layer built on an already initialized model.
+
+    Ends can be loss or metric layers.
+
+    Args:
+      model: Layer with initialized weights and state.
+      end_layers: List of end layers.
+      batch_signature: Signature of the model input batch.
+
+    Returns:
+      An initialized, combined model+ends layer, preserving the initialization
+      of ``model``.
+    """
+    # TODO(jonni): Redo this function as part of an initialization refactor?
+    metrics_layer = tl.Branch(*end_layers)
+    metrics_input_signature = model.output_signature(batch_signature)
+    _, _ = metrics_layer.init(metrics_input_signature)
+
+    model_with_metrics = tl.Serial(model, metrics_layer)
+    return model_with_metrics
+
+
+def _model_with_metrics(model, eval_task):
+    """Returns a model+metrics layer built on an already initialized model.
+
+    Args:
+      model: Layer with initialized weights and state.
+      eval_task: :py:class:`EvalTask` instance.
+
+    Returns:
+      An initialized, combined model+metrics layer, preserving the initialization
+      of ``model``.
+    """
+    return _model_with_ends(
+        model, eval_task.metrics, shapes.signature(eval_task.sample_batch)
+    )
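+
+
+# Shape sketch for the combination above: for a model `f` and metrics
+# [m1, m2], the combined layer is Serial(f, Branch(m1, m2)); one forward pass
+# on a batch yields one scalar per metric, which `run_evals` then averages
+# over `n_eval_batches` batches.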
+ """ + self._export_prefix = export_prefix + self._labeled_data = labeled_data + self._loss_layer = loss_layer + self._optimizer = optimizer + self._lr_schedule = lr_schedule + self._sample_batch = sample_batch or next(labeled_data) + self._n_steps_per_checkpoint = n_steps_per_checkpoint + self._n_steps_per_permanent_checkpoint = n_steps_per_permanent_checkpoint + self._loss_name = loss_name or self._loss_layer.name + + @property + def labeled_data(self): + return self._labeled_data + + @property + def sample_batch(self): + return self._sample_batch + + def next_batch(self): + """Returns one batch of labeled data: a tuple of input(s) plus label.""" + return next(self._labeled_data) + + @property + def export_prefix(self): + return self._export_prefix + + @property + def loss_layer(self): + return self._loss_layer + + @property + def loss_name(self): + return self._loss_name + + @property + def n_steps_per_checkpoint(self): + return self._n_steps_per_checkpoint + + @property + def n_steps_per_permanent_checkpoint(self): + return self._n_steps_per_permanent_checkpoint + + @property + def optimizer(self): + return self._optimizer + + def learning_rate(self, step): + """Return the learning rate for the given step.""" + if self._lr_schedule is not None: + with fastmath.use_backend(fastmath.Backend.NUMPY): + return self._lr_schedule(step) + opt = self._optimizer + if callable(opt): # when optimizer is a function, like Adam, not Adam() + opt = opt() + params = opt._init_opt_params # pylint: disable=protected-access + return params["learning_rate"] + + +@gin.configurable +class EvalTask: + """Labeled data plus scalar functions for (periodically) measuring a model. + + An eval task specifies how (``labeled_data`` + ``metrics``) and with what + precision (``n_eval_batches``) to measure a model as it is training. + The variance of each scalar output is reduced by measuring over multiple + (``n_eval_batches``) batches and reporting the average from those + measurements. + """ + + def __init__( + self, + labeled_data, + metrics, + metric_names=None, + n_eval_batches=1, + sample_batch=None, + export_prefix=None, + ): + r"""Configures an eval task: named metrics run with a given data source. + + Args: + labeled_data: Iterator of batches of labeled data tuples. Each tuple has + 1+ data tensors (NumPy ndarrays) followed by 1 label (target value) + tensor. + metrics: List of layers; each computes a scalar value per batch by + comparing model output :math:`\hat{y}=f(x)` to the target :math:`y`. + metric_names: List of names, one for each item in ``metrics``, in matching + order, to be used when recording/reporting eval output. If ``None``, + generate default names using layer names from metrics. + n_eval_batches: Integer N that specifies how many eval batches to run; + the output is then the average of the outputs from the N batches. + sample_batch: Optional sample batch for model initialization. If not + provided, it will be taken from ``labeled_data``. + export_prefix: Optional task name to be used as prefix for exporting + metrics during evaluation in Loop. 
+ """ + self._export_prefix = export_prefix + self._labeled_data = labeled_data + self._metrics = metrics + self._metric_names = metric_names or self._default_names() + self._n_eval_batches = n_eval_batches # pylint: disable=invalid-name + + self._sample_batch = sample_batch or next(labeled_data) + self._check_init_values() + + @property + def labeled_data(self): + return self._labeled_data + + @property + def sample_batch(self): + return self._sample_batch + + def next_batch(self): + """Returns one batch of labeled data: a tuple of input(s) plus label.""" + return next(self._labeled_data) + + @property + def export_prefix(self): + return self._export_prefix + + @property + def metrics(self): + return self._metrics + + @property + def metric_names(self): + return self._metric_names + + @property + def n_eval_batches(self): + return self._n_eval_batches + + def _default_names(self): + return [m.name for m in self._metrics] + + def _check_init_values(self): + if len(self._metrics) != len(self._metric_names): + raise ValueError( + f"Number of metrics ({len(self._metrics)}) does not equal " + f"number of metric names ({len(self._metric_names)})." + ) + + +def _never(*args): + """Returns False for all step numbers.""" + del args + return False + + +def _at_step_1_and_every_nth_step(period): + """A function that's true at 1 and n when n % period == 0.""" + if period is None: + return lambda step_n: False + + def _at_1_and_periodically(step_n): + return (step_n == 1) or (step_n > 0 and (step_n % period == 0)) + + return _at_1_and_periodically + + +def _log(s, stdout=True): + logging.info(s) + if stdout: + print(s) + sys.stdout.flush() + + +def pickle_to_file(obj, file_path, gzip=False): + """Pickle obj to file_path with gzipping and failure protection.""" + # Pickle to tmp file and overwrite to prevent writing partial files. + tmp_file_path = file_path + "._tmp_" + with tf.io.gfile.GFile(tmp_file_path, "wb") as f: + if not gzip: + pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) + else: + with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: + pickle.dump(obj, gzipf, protocol=pickle.HIGHEST_PROTOCOL) + # Moving a file is much less error-prone than pickling large files. + tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) + + +def unpickle_from_file(file_path, gzip=False): + """Unpickle obj from file_path with gzipping.""" + with tf.io.gfile.GFile(file_path, "rb") as f: + if not gzip: + obj = pickle.load(f) + else: + with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: + obj = pickle.load(gzipf) + return obj + + +def _init_random_number_generators(seed=None): + """Initializes random generators for Python, NumPy, TensorFlow, and JAX.""" + # Seed Python random (None as seed is okay), then use it to seed the others. + random.seed(seed) + if seed is None: + seed = random.randint(0, 2**31 - 1) + logging.info("using seed %d", seed) + np.random.seed(seed) + tf.random.set_seed(seed) + return jax_random.get_prng(seed) + + +def init_host_and_devices(n_devices=None, random_seed=None): + """Initializes host and device attributes for this trainers. + + Args: + n_devices: Number of devices this trainers will use. If ``None``, get the + number from the backend. + random_seed: Random seed as the starting point for all random numbers used + by the trainers. If ``None``, calculate one from system time and host id. + + Returns: + is_chief: True if this trainers has special chief responsibilities. + host_count: Number of hosts in this computation. 
+
+
+def init_host_and_devices(n_devices=None, random_seed=None):
+    """Initializes host and device attributes for this trainer.
+
+    Args:
+      n_devices: Number of devices this trainer will use. If ``None``, get the
+        number from the backend.
+      random_seed: Random seed as the starting point for all random numbers used
+        by the trainer. If ``None``, calculate one from system time and host id.
+
+    Returns:
+      is_chief: True if this trainer has special chief responsibilities.
+      host_count: Number of hosts in this computation.
+      n_devices: The passed-in value of n_devices or a computed default (for this
+        host).
+      random_seed: The passed-in value of random_seed or a computed default.
+    """
+    if fastmath.is_backend(fastmath.Backend.JAX):
+        host_id = jax.process_index()
+        host_count = jax.host_count()
+    else:
+        host_id = 0
+        host_count = 1
+    is_chief = host_id == 0
+
+    logging.info(
+        "Initializing hosts and devices: host_id %d, host_count %d, is_chief %d",
+        host_id,
+        host_count,
+        is_chief,
+    )
+
+    device_count = fastmath.local_device_count()
+    n_devices = n_devices or device_count
+    # TODO(lukaszkaiser): remove this restriction when possible.
+    if n_devices != device_count and fastmath.is_backend(fastmath.Backend.JAX):
+        raise ValueError(
+            "JAX cannot work yet with n_devices != all devices: "
+            "%d != %d" % (n_devices, device_count)
+        )
+
+    if random_seed is None and host_count > 1:
+        random_seed = int(1e6 * (host_id + time.time())) % 2**31
+    return (
+        is_chief,
+        host_count,
+        n_devices,
+        _init_random_number_generators(random_seed),
+    )
+
+
+def _accelerate_model_with_metrics(
+    model_with_metrics, n_devices, accelerate=True, do_mean=True
+):
+    if not accelerate:
+        return model_with_metrics.pure_fn
+
+    return tl.jit_forward(model_with_metrics.pure_fn, n_devices, do_mean=do_mean)
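+
+
+# Cross-host averaging sketch for the pmapped function below: with a total of
+# 8 devices, `fastmath.psum(x, 'devices')` sums each replica's value over all
+# of them, so dividing by `n_devices_total` yields the global mean, and the
+# final `.astype(x.dtype)` keeps bfloat16 weights from being promoted by the
+# division.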
+
+
+@functools.partial(fastmath.pmap, axis_name="devices", donate_argnums=(0,))
+def _make_weights_and_state_same_across_hosts(weights_and_state):
+    """Makes train and eval model's weights and state the same across hosts."""
+
+    # We assume that weights_and_state have already been replicated, i.e., the
+    # leading axis is self._n_devices.
+
+    # This is the total number of devices across all hosts.
+    n_devices_total = fastmath.psum(jnp.array(1.0), "devices").astype(jnp.int32)
+
+    # We average the weights and state across all devices.
+    # We also make sure we don't change the type of the weights and state.
+    return fastmath.nested_map(
+        lambda x: (fastmath.psum(x, "devices") / n_devices_total).astype(x.dtype),
+        weights_and_state,
+    )
+
+
+def _is_empty(x):
+    if isinstance(x, (list, tuple)):
+        return all(_is_empty(y) for y in x)
+    else:
+        return x is None
+
+
+def _is_uninitialized(model):
+    """Checks whether no weights in the model have been initialized."""
+    if not _is_empty(model.weights):
+        return False
+    return all(_is_uninitialized(l) for l in model.sublayers)
+
+
+def _match_by_shape(full, partial):
+    """Puts partial into full matching by shape."""
+    partial_idx = 0
+    res = []
+    for w in full:
+        if partial_idx >= len(partial):
+            res.append(w)  # ran past the end of the partial list; fill from full
+        elif w is None and partial[partial_idx] is None:  # both Nones, move on
+            res.append(None)
+            partial_idx += 1
+        elif w is None or partial[partial_idx] is None:  # one None but not both
+            res.append(w)
+        elif w.shape == partial[partial_idx].shape:
+            res.append(partial[partial_idx])
+            partial_idx += 1
+        else:
+            res.append(w)
+    if partial_idx < len(partial):
+        _log("Did not manage to match shapes in model for all checkpoint weights.")
+        for w in partial[:partial_idx]:
+            _log("  Inserted tensor of shape %s" % str(w.shape))
+        for i, w in enumerate(partial[partial_idx:]):
+            _log("  Not inserted tensor of shape %s" % str(w.shape))
+            model_weight_shape = str(full[i + partial_idx].shape)
+            _log("  Tensor in that place has shape: %s" % model_weight_shape)
+        raise IndexError
+    return res
+
+
+def _flatten_and_remove_empty(x):
+    try:
+        flat = fastmath.tree_flatten(x)[0]
+        # First try with the safer type-check approach.
+        return [
+            f for f in flat if f is not None and not (isinstance(f, tuple) and len(f) == 0)
+        ]
+    except (TypeError, AttributeError, IndexError):
+        flat = fastmath.tree_flatten(x)
+        return [
+            f for f in flat if f is not None and f != ()
+        ]  # pylint: disable=literal-comparison
diff --git a/trax/models/Attention_Visualization_in_Trax.ipynb b/trax/models/Attention_Visualization_in_Trax.ipynb
deleted file mode 100644
index 040b6d676..000000000
--- a/trax/models/Attention_Visualization_in_Trax.ipynb
+++ /dev/null
@@ -1,1601 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "7yuytuIllsv1"
-   },
-   "source": [
-    "# Attention Visualization in Trax\n",
-    "\n",
-    "For more information see the [tenso2tensor](https://trax-ml.readthedocs.io/en/latest/) visualization colab. All js tools are taken from the tensor2tensor version along with attention processing methods. The \"viz\" mode for a Trax model used in this colab [was added to Trax](https://github.com/google/trax/commit/e9a171379ef206a3e351b67cef91fe40bf37589c) with the attention visualization in mind. The colab re-uses some parts of the [Intro to Trax](https://github.com/google/trax/blob/master/trax/intro.ipynb) colab.\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "BIl27504La0G"
-   },
-   "source": [
-    "**General Setup**\n",
-    "\n",
-    "Execute the following few cells (once) before running of visualization codes."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": {}, - "colab_type": "code", - "id": "oILRLCWN_16u" - }, - "outputs": [], - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", - "import json\n", - "import numpy as np\n", - "import os\n", - "import IPython.display as display\n", - "import gin" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "both", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 466 - }, - "colab_type": "code", - "id": "vlGjGoGMTt-D", - "outputId": "28f4556b-caef-47a1-bddd-7f51ecc064d8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[K |████████████████████████████████| 368kB 2.8MB/s \n", - "\u001b[K |████████████████████████████████| 1.5MB 13.0MB/s \n", - "\u001b[K |████████████████████████████████| 2.6MB 20.1MB/s \n", - "\u001b[K |████████████████████████████████| 163kB 33.1MB/s \n", - "\u001b[K |████████████████████████████████| 194kB 19.4MB/s \n", - "\u001b[K |████████████████████████████████| 983kB 30.6MB/s \n", - "\u001b[K |████████████████████████████████| 655kB 56.6MB/s \n", - "\u001b[K |████████████████████████████████| 81kB 11.7MB/s \n", - "\u001b[K |████████████████████████████████| 5.3MB 45.0MB/s \n", - "\u001b[K |████████████████████████████████| 368kB 57.1MB/s \n", - "\u001b[K |████████████████████████████████| 307kB 55.8MB/s \n", - "\u001b[K |████████████████████████████████| 358kB 58.6MB/s \n", - "\u001b[K |████████████████████████████████| 1.1MB 59.0MB/s \n", - "\u001b[K |████████████████████████████████| 3.5MB 58.4MB/s \n", - "\u001b[K |████████████████████████████████| 778kB 59.4MB/s \n", - "\u001b[K |████████████████████████████████| 51kB 8.7MB/s \n", - "\u001b[K |████████████████████████████████| 51kB 8.6MB/s \n", - "\u001b[K |████████████████████████████████| 235kB 54.2MB/s \n", - "\u001b[K |████████████████████████████████| 3.0MB 62.4MB/s \n", - "\u001b[K |████████████████████████████████| 890kB 58.2MB/s \n", - "\u001b[?25h Building wheel for bz2file (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for pypng (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for sacremoses (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[31mERROR: kfac 0.2.2 has requirement tensorflow-probability==0.8, but you'll have tensorflow-probability 0.7.0 which is incompatible.\u001b[0m\n", - "INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 \n" - ] - } - ], - "source": [ - "#@title\n", - "# Import Trax\n", - "\n", - "!pip install -q -U trax\n", - "import trax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "VCBjVMrZRS6q" - }, - "outputs": [], - "source": [ - "#@title Some cool tooling for attention (make sure that you run the cell)\n", - "def resize(att_mat, max_length=None):\n", - " \"\"\"Normalize attention matrices and reshape as necessary.\"\"\"\n", - " for i, att in enumerate(att_mat):\n", - " # Add extra batch dim for viz code to work.\n", - " if att.ndim == 3:\n", - " att = np.expand_dims(att, axis=0)\n", - " if max_length is not None:\n", - " # Sum across different attention values for each token.\n", - " att = att[:, :, :max_length, :max_length]\n", - " row_sums = np.sum(att, axis=2)\n", - " # Normalize\n", - " att /= row_sums[:, :, np.newaxis]\n", - " att_mat[i] = att\n", - " return att_mat\n", - "\n", - "\n", - "def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):\n", - " \"\"\"Compute representation of the attention ready for the d3 visualization.\n", - "\n", - " Args:\n", - " inp_text: list of strings, words to be displayed on the left of the vis\n", - " out_text: list of strings, words to be displayed on the right of the vis\n", - " enc_atts: numpy array, encoder self-attentions\n", - " [num_layers, batch_size, num_heads, enc_length, enc_length]\n", - " dec_atts: numpy array, decoder self-attentions\n", - " [num_layers, batch_size, num_heads, dec_length, dec_length]\n", - " encdec_atts: numpy array, encoder-decoder attentions\n", - " [num_layers, batch_size, num_heads, dec_length, enc_length]\n", - "\n", - " Returns:\n", - " Dictionary of attention representations with the structure:\n", - " {\n", - " 'all': Representations for showing all attentions at the same time.\n", - " 'inp_inp': Representations for showing encoder self-attentions\n", - " 'inp_out': Representations for showing encoder-decoder attentions\n", - " 'out_out': Representations for showing decoder self-attentions\n", - " }\n", - " and each sub-dictionary has structure:\n", - " {\n", - " 'att': list of inter attentions matrices, one for each attention head\n", - " 'top_text': list of strings, words to be displayed on the left of the vis\n", - " 'bot_text': list of strings, words to be displayed on the right of the vis\n", - " }\n", - " \"\"\"\n", - " def get_full_attention(layer):\n", - " \"\"\"Get the full input+output - input+output attentions.\"\"\"\n", - " enc_att = enc_atts[layer][0]\n", - " dec_att = dec_atts[layer][0]\n", - " encdec_att = encdec_atts[layer][0]\n", - " enc_att = np.transpose(enc_att, [0, 2, 1])\n", - " dec_att = np.transpose(dec_att, [0, 2, 1])\n", - " encdec_att = np.transpose(encdec_att, [0, 2, 1])\n", - " # [heads, query_length, memory_length]\n", - " enc_length = enc_att.shape[1]\n", - " dec_length = dec_att.shape[1]\n", - " num_heads = enc_att.shape[0]\n", - " first = np.concatenate([enc_att, encdec_att], axis=2)\n", - " second = np.concatenate(\n", - " [np.zeros((num_heads, dec_length, enc_length)), dec_att], axis=2)\n", - " full_att = np.concatenate([first, second], axis=1)\n", - " return 
[ha.T.tolist() for ha in full_att]\n", - "\n", - " def get_inp_inp_attention(layer):\n", - " att = np.transpose(enc_atts[layer][0], (0, 2, 1))\n", - " return [ha.T.tolist() for ha in att]\n", - "\n", - " def get_out_inp_attention(layer):\n", - " att = np.transpose(encdec_atts[layer][0], (0, 2, 1))\n", - " return [ha.T.tolist() for ha in att]\n", - "\n", - " def get_out_out_attention(layer):\n", - " att = np.transpose(dec_atts[layer][0], (0, 2, 1))\n", - " return [ha.T.tolist() for ha in att]\n", - "\n", - " def get_attentions(get_attention_fn):\n", - " num_layers = len(enc_atts)\n", - " return [get_attention_fn(i) for i in range(num_layers)]\n", - "\n", - " attentions = {\n", - " 'all': {\n", - " 'att': get_attentions(get_full_attention),\n", - " 'top_text': inp_text + out_text,\n", - " 'bot_text': inp_text + out_text,\n", - " },\n", - " 'inp_inp': {\n", - " 'att': get_attentions(get_inp_inp_attention),\n", - " 'top_text': inp_text,\n", - " 'bot_text': inp_text,\n", - " },\n", - " 'inp_out': {\n", - " 'att': get_attentions(get_out_inp_attention),\n", - " 'top_text': inp_text,\n", - " 'bot_text': out_text,\n", - " },\n", - " 'out_out': {\n", - " 'att': get_attentions(get_out_out_attention),\n", - " 'top_text': out_text,\n", - " 'bot_text': out_text,\n", - " },\n", - " }\n", - "\n", - " return attentions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "47lzWIH5THcw" - }, - "outputs": [], - "source": [ - "#@title Some cool HTML and js stuff (make sure that you run the cell)\n", - "vis_html = \"\"\"\n", - " \u003cspan style=\"user-select:none\"\u003e\n", - " Layer: \u003cselect id=\"layer\"\u003e\u003c/select\u003e\n", - " Attention: \u003cselect id=\"att_type\"\u003e\n", - " \u003coption value=\"all\"\u003eAll\u003c/option\u003e\n", - " \u003coption value=\"inp_inp\"\u003eInput - Input\u003c/option\u003e\n", - " \u003coption value=\"inp_out\"\u003eInput - Output\u003c/option\u003e\n", - " \u003coption value=\"out_out\"\u003eOutput - Output\u003c/option\u003e\n", - " \u003c/select\u003e\n", - " \u003c/span\u003e\n", - " \u003cdiv id='vis'\u003e\u003c/div\u003e\n", - "\"\"\"\n", - "def call_html():\n", - " import IPython\n", - " display.display(display.HTML('''\n", - " \u003cscript src=\"/static/components/requirejs/require.js\"\u003e\u003c/script\u003e\n", - " \u003cscript\u003e\n", - " requirejs.config({\n", - " paths: {\n", - " base: '/static/base',\n", - " \"d3\": \"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min\",\n", - " jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n", - " },\n", - " });\n", - " \u003c/script\u003e\n", - " '''))\n", - "vis_js = \"\"\"\n", - "/**\n", - " * @fileoverview Transformer Visualization D3 javascript code.\n", - " */\n", - "\n", - "requirejs(['jquery', 'd3'],\n", - "function($, d3) {\n", - "\n", - "var attention = window.attention;\n", - "\n", - "const TEXT_SIZE = 15;\n", - "const BOXWIDTH = TEXT_SIZE * 8;\n", - "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", - "const WIDTH = 2000;\n", - "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", - "const MATRIX_WIDTH = 150;\n", - "const head_colours = d3.scale.category10();\n", - "const CHECKBOX_SIZE = 20;\n", - "\n", - "function lighten(colour) {\n", - " var c = d3.hsl(colour);\n", - " var increment = (1 - c.l) * 0.6;\n", - " c.l += increment;\n", - " c.s -= increment;\n", - " return c;\n", - "}\n", - "\n", - "function transpose(mat) {\n", - " return mat[0].map(function(col, i) 
{\n", - " return mat.map(function(row) {\n", - " return row[i];\n", - " });\n", - " });\n", - "}\n", - "\n", - "function zip(a, b) {\n", - " return a.map(function (e, i) {\n", - " return [e, b[i]];\n", - " });\n", - "}\n", - "\n", - "\n", - "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", - " $(id).empty();\n", - " var svg = d3.select(id)\n", - " .append('svg')\n", - " .attr(\"width\", WIDTH)\n", - " .attr(\"height\", HEIGHT);\n", - "\n", - " var att_data = [];\n", - " for (var i=0; i \u003c attention_heads.length; i++) {\n", - " var att_trans = transpose(attention_heads[i]);\n", - " att_data.push(zip(attention_heads[i], att_trans));\n", - " }\n", - "\n", - " renderText(svg, top_text, true, att_data, 0);\n", - " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", - "\n", - " renderAttentionHighlights(svg, att_data);\n", - "\n", - " svg.append(\"g\").classed(\"attention_heads\", true);\n", - "\n", - " renderAttention(svg, attention_heads);\n", - "\n", - " draw_checkboxes(config, 0, svg, attention_heads);\n", - "}\n", - "\n", - "\n", - "function renderText(svg, text, is_top, att_data, left_pos) {\n", - " var id = is_top ? \"top\" : \"bottom\";\n", - " var textContainer = svg.append(\"svg:g\")\n", - " .attr(\"id\", id);\n", - "\n", - " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", - " .selectAll(\"g\")\n", - " .data(att_data)\n", - " .enter()\n", - " .append(\"g\")\n", - " .selectAll(\"rect\")\n", - " .data(function(d) {return d;})\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"x\", function(d, i, j) {\n", - " return left_pos + box_offset(j);\n", - " })\n", - " .attr(\"y\", function(d, i) {\n", - " return (+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .attr(\"fill\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .style(\"opacity\", 0.0);\n", - "\n", - "\n", - " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", - " .data(text)\n", - " .enter()\n", - " .append(\"g\");\n", - "\n", - " tokenContainer.append(\"rect\")\n", - " .classed(\"background\", true)\n", - " .style(\"opacity\", 0.0)\n", - " .attr(\"fill\", \"lightgray\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH)\n", - " .attr(\"height\", BOXHEIGHT);\n", - "\n", - " var theText = tokenContainer.append(\"text\")\n", - " .text(function(d) { return d; })\n", - " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", - " .style(\"cursor\", \"default\")\n", - " .style(\"-webkit-user-select\", \"none\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " });\n", - "\n", - " if (is_top) {\n", - " theText.style(\"text-anchor\", \"end\")\n", - " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " } else {\n", - " theText.style(\"text-anchor\", \"start\")\n", - " .attr(\"dx\", + TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " }\n", - "\n", - " tokenContainer.on(\"mouseover\", function(d, index) {\n", - " textContainer.selectAll(\".background\")\n", - " .style(\"opacity\", function(d, i) {\n", - " return i == index ? 
1.0 : 0.0;\n", - " });\n", - "\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", - "\n", - " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", - " .selectAll(\".att_lines\")\n", - " .attr(\"stroke-opacity\", function(d) {\n", - " return 1.0;\n", - " })\n", - " .attr(\"y1\", function(d, i) {\n", - " if (is_top) {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", function(d, i) {\n", - " if (is_top) {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .attr(\"stroke-opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j]) {\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " });\n", - "\n", - "\n", - " function updateAttentionBoxes() {\n", - " var id = is_top ? \"bottom\" : \"top\";\n", - " var the_left_pos = is_top ? MATRIX_WIDTH + BOXWIDTH : 0;\n", - " svg.select(\"#\" + id)\n", - " .selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", - " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .style(\"opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j])\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " else\n", - " return 0.0;\n", - "\n", - " });\n", - " }\n", - "\n", - " updateAttentionBoxes();\n", - " });\n", - "\n", - " textContainer.on(\"mouseleave\", function() {\n", - " d3.select(this).selectAll(\".background\")\n", - " .style(\"opacity\", 0.0);\n", - "\n", - " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", - " svg.selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .style(\"opacity\", 0.0);\n", - " });\n", - "}\n", - "\n", - "function renderAttentionHighlights(svg, attention) {\n", - " var line_container = svg.append(\"g\");\n", - " line_container.selectAll(\"g\")\n", - " .data(attention)\n", - " .enter()\n", - " .append(\"g\")\n", - " .classed(\"line_heads\", true)\n", - " .selectAll(\"line\")\n", - " .data(function(d){return d;})\n", - " .enter()\n", - " .append(\"line\").classed(\"att_lines\", true);\n", - "}\n", - "\n", - "function renderAttention(svg, attention_heads) {\n", - " var line_container = svg.selectAll(\".attention_heads\");\n", - " line_container.html(null);\n", - " for(var h=0; h\u003cattention_heads.length; h++) {\n", - " for(var a=0; a\u003cattention_heads[h].length; a++) {\n", - " for(var s=0; s\u003cattention_heads[h][a].length; s++) {\n", - " line_container.append(\"line\")\n", - " .attr(\"y1\", (s+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", (a+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " 
.attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", head_colours(h))\n", - " .attr(\"stroke-opacity\", function() {\n", - " if (config.head_vis[h]) {\n", - " return attention_heads[h][a][s]/active_heads();\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " }());\n", - " }\n", - " }\n", - " }\n", - "}\n", - "\n", - "// Checkboxes\n", - "function box_offset(i) {\n", - " var num_head_above = config.head_vis.reduce(\n", - " function(acc, val, cur) {return val \u0026\u0026 cur \u003c i ? acc + 1: acc;}, 0);\n", - " return num_head_above*(BOXWIDTH / active_heads());\n", - "}\n", - "\n", - "function active_heads() {\n", - " return config.head_vis.reduce(function(acc, val) {\n", - " return val ? acc + 1: acc;\n", - " }, 0);\n", - "}\n", - "\n", - "function draw_checkboxes(config, top, svg, attention_heads) {\n", - " var checkboxContainer = svg.append(\"g\");\n", - " var checkbox = checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"fill\", function(d, i) {\n", - " return head_colours(i);\n", - " })\n", - " .attr(\"x\", function(d, i) {\n", - " return (i+1) * CHECKBOX_SIZE;\n", - " })\n", - " .attr(\"y\", top)\n", - " .attr(\"width\", CHECKBOX_SIZE)\n", - " .attr(\"height\", CHECKBOX_SIZE);\n", - "\n", - " function update_checkboxes() {\n", - " checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .attr(\"fill\", function(d, i) {\n", - " var head_colour = head_colours(i);\n", - " var colour = d ? head_colour : lighten(head_colour);\n", - " return colour;\n", - " });\n", - " }\n", - "\n", - " update_checkboxes();\n", - "\n", - " checkbox.on(\"click\", function(d, i) {\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) return;\n", - " config.head_vis[i] = !config.head_vis[i];\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "\n", - " checkbox.on(\"dblclick\", function(d, i) {\n", - " // If we double click on the only active head then reset\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) {\n", - " config.head_vis = new Array(config.num_heads).fill(true);\n", - " } else {\n", - " config.head_vis = new Array(config.num_heads).fill(false);\n", - " config.head_vis[i] = true;\n", - " }\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "}\n", - "\n", - "var config = {\n", - " layer: 0,\n", - " att_type: 'all',\n", - "};\n", - "\n", - "function visualize() {\n", - " var num_heads = attention['all']['att'][0].length;\n", - " config.head_vis = new Array(num_heads).fill(true);\n", - " config.num_heads = num_heads;\n", - " config.attention = attention;\n", - "\n", - " render();\n", - "}\n", - "\n", - "function render() {\n", - " var conf = config.attention[config.att_type];\n", - "\n", - " var top_text = conf.top_text;\n", - " var bot_text = conf.bot_text;\n", - " var attention = conf.att[config.layer];\n", - "\n", - " $(\"#vis svg\").empty();\n", - " renderVis(\"#vis\", top_text, bot_text, attention, config);\n", - "}\n", - "\n", - "$(\"#layer\").empty();\n", - "for(var i=0; i\u003c6; i++) {\n", - " $(\"#layer\").append($(\"\u003coption /\u003e\").val(i).text(i));\n", - "}\n", - "\n", - "$(\"#layer\").on('change', function(e) {\n", - " config.layer = +e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - "$(\"#att_type\").on('change', function(e) {\n", - " config.att_type = e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - "$(\"button\").on('click', 
visualize);\n", - "\n", - "visualize();\n", - "\n", - "});\n", - "\"\"\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-LQ89rFFsEdk" - }, - "source": [ - "## 1. Run a pre-trained Transformer\n", - "\n", - "* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)\n", - "* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)\n", - "* tokenize your input sentence to input into the model with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize)\n", - "* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)\n", - "* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "djTiSLcaNFGa", - "outputId": "b5ad2955-5e1d-47aa-97bb-5d72a25ed76d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Es ist schÃļn, heute neue Dinge zu lernen!\n" - ] - } - ], - "source": [ - "# Create a Transformer model.\n", - "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", - "model = trax.models.Transformer(\n", - " input_vocab_size=33300,\n", - " d_model=512, d_ff=2048,\n", - " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", - " max_len=2048, mode='predict')\n", - "\n", - "# Initialize using pre-trained weights.\n", - "model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", - " weights_only=True)\n", - "\n", - "# Tokenize a sentence.\n", - "sentence = 'It is nice to learn new things today!'\n", - "tokenized = list(trax.data.tokenize(iter([sentence]), # Operates on streams.\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword'))[0]\n", - "\n", - "# Decode from the Transformer.\n", - "tokenized = tokenized[None, :] # Add batch dimension.\n", - "tokenized_translation = trax.supervised.decoding.autoregressive_sample(\n", - " model, tokenized, temperature=0.0) # Higher temperature: more diverse results.\n", - "\n", - "# De-tokenize,\n", - "tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.\n", - "translation = trax.data.detokenize(tokenized_translation,\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword')\n", - "print(translation)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "id": "pWDPwZfSJeD3", - "outputId": "050d40bf-f28d-49ea-b69a-af2886cf92a4" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([[ 118, 16, 1902, 9, 3197, 141, 1059, 420, 207]]),\n", - " array([ 168, 24, 9358, 2, 352, 367, 2427, 18, 3580, 207]))" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized, tokenized_translation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", 
- "id": "Lu6URNjbXIHv" - }, - "source": [ - "## 2. Prepare the tokens for visualization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "kqNWMpNdMg9z" - }, - "outputs": [], - "source": [ - "def decode(single_token):\n", - " return trax.data.detokenize(single_token,\n", - " vocab_dir='gs://trax-ml/vocabs/',\n", - " vocab_file='ende_32k.subword')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "H2fbJB_BMeRw" - }, - "outputs": [], - "source": [ - "def get_tokens_str(integers):\n", - " token_strs = []\n", - " for i in range(integers.shape[1]):\n", - " token_strs.append(decode(integers[:,i]))\n", - " return token_strs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "YkNT8rbgKM5-" - }, - "outputs": [], - "source": [ - "tokenized_translation_with_start = np.array([0]+list(tokenized_translation), dtype=np.int64)\n", - "tokenized_translation_with_start = tokenized_translation_with_start[np.newaxis, ...]\n", - "tokenized_translation = np.array(tokenized_translation, dtype=np.int64)\n", - "tokenized_translation = tokenized_translation[np.newaxis, ...]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "r-FVdSZPKQhs" - }, - "outputs": [], - "source": [ - "tokenized_str = get_tokens_str(tokenized)\n", - "tokenized_translation_str = get_tokens_str(tokenized_translation_with_start)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 223 - }, - "colab_type": "code", - "id": "Cy7edKBuKash", - "outputId": "c1e00dbe-f467-48df-eaaf-579f68ef788f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(['It', 'is', 'nice', 'to', 'learn', 'new', 'things', 'today', '!'],\n", - " ['\u003cpad\u003e',\n", - " 'Es',\n", - " 'ist',\n", - " 'schÃļn',\n", - " ', ',\n", - " 'heute',\n", - " 'neue',\n", - " 'Dinge',\n", - " 'zu',\n", - " 'lernen',\n", - " '!'])" - ] - }, - "execution_count": 11, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized_str, tokenized_translation_str" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "1XxJSqAsOTBe" - }, - "outputs": [], - "source": [ - "max_len = max(tokenized.shape[1], tokenized_translation.shape[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Qju-9pPHOV6G" - }, - "outputs": [], - "source": [ - "tokenized_translation_pad = np.zeros((1,max_len), dtype=np.int64)\n", - "tokenized_translation_pad[:,:tokenized_translation.shape[1]] = tokenized_translation\n", - "\n", - "tokenized_pad = np.zeros((1,max_len), dtype=np.int64)\n", - "tokenized_pad[:,:tokenized.shape[1]] = tokenized" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "zGxBSk0gOfYi", - "outputId": "d83328fa-eec8-4631-d2b6-4fffc3f0b933" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((1, 10), (1, 10))" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized_translation_pad.shape, 
tokenized_pad.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WqvjmRaCXign" - }, - "source": [ - "## 3. Create the same pre-trained model in the \"viz\" mode." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Qb2F4Pj_OLMZ" - }, - "outputs": [], - "source": [ - "# Create a Transformer model in the \"viz\" mode\n", - "# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin\n", - "model_viz = trax.models.Transformer(\n", - " input_vocab_size=33300,\n", - " d_model=512, d_ff=2048,\n", - " n_heads=8, n_encoder_layers=6, n_decoder_layers=6,\n", - " max_len=2048, mode='viz')\n", - "\n", - "# Initialize using pre-trained weights.\n", - "model_viz.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n", - " weights_only=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "AxcrAfprO0rD" - }, - "outputs": [], - "source": [ - "# We run the viz model because later we want to inspect its state\n", - "_ = model_viz((tokenized_pad, tokenized_translation_pad))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lVCYSQSuXw6f" - }, - "source": [ - "## 4. Find the attention weights (aka dots)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "dsGuqdgnO2Lf" - }, - "outputs": [], - "source": [ - "attention_weights = []\n", - "def attention_sublayers(layer):\n", - " if 'Attention' in layer.name:\n", - " print(\"Found layer {}\".format(layer.name))\n", - " attention_weights.append(layer.state)\n", - " if layer.sublayers:\n", - " for sublayer in layer.sublayers:\n", - " attention_sublayers(sublayer)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 326 - }, - "colab_type": "code", - "id": "FA3ba2-DO5l4", - "outputId": "f66756b1-fa86-4582-bd04-9b464ae132eb" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n", - "Found layer DotProductCausalAttention\n", - "Found layer PureAttention\n" - ] - } - ], - "source": [ - "attention_sublayers(model_viz)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "q36-o98QO7HC", - "outputId": "445fe1ce-f1fa-484a-9db4-b37f56915d7c" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "18" - ] - }, - "execution_count": 19, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "len(attention_weights)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "LahOE6q6PB1B" - }, - "outputs": 
[], - "source": [ - "# Manually identification of layers would be difficult, hence we rely on attention_sublayers function\n", - "enc_atts = attention_weights[:6]\n", - "dec_atts = attention_weights[6::2] # these are the DotProductCausalAttention layers\n", - "encdec_atts = attention_weights[7::2] # these are the PureAttention layers starting from the 6th layer on\n", - "\n", - "# Here we use a number of python utils inherited from tensor2tensor\n", - "enc_atts_res = resize(enc_atts)\n", - "dec_atts_res = resize(dec_atts)\n", - "encdec_atts_res = resize(encdec_atts)\n", - "attention_dict = _get_attention(tokenized_str, tokenized_translation_str, enc_atts_res, dec_atts_res, encdec_atts_res)\n", - "attention_json = json.dumps(attention_dict)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1DgBBfg-X6-d" - }, - "source": [ - "## 5. Display attention" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "resources": { - "http://localhost:8080/static/components/requirejs/require.js": { - "data": "/** vim: et:ts=4:sw=4:sts=4
 * @license RequireJS 2.1.22 Copyright (c) 2010-2015, The Dojo Foundation All Rights Reserved.
 * Available via the MIT or new BSD license.
 * see: http://github.com/jrburke/requirejs for details
 */
//Not using strict: uneven strict support in browsers, #392, and causes
//problems with requirejs.exec()/transpiler plugins that may not be strict.
/*jslint regexp: true, nomen: true, sloppy: true */
/*global window, navigator, document, importScripts, setTimeout, opera */

var requirejs, require, define;
(function (global) {
    var req, s, head, baseElement, dataMain, src,
        interactiveScript, currentlyAddingScript, mainScript, subPath,
        version = '2.1.22',
        commentRegExp = /(\/\*([\s\S]*?)\*\/|([^:]|^)\/\/(.*)$)/mg,
        cjsRequireRegExp = /[^.]\s*require\s*\(\s*["']([^'"\s]+)["']\s*\)/g,
        jsSuffixRegExp = /\.js$/,
        currDirRegExp = /^\.\//,
        op = Object.prototype,
        ostring = op.toString,
        hasOwn = op.hasOwnProperty,
        ap = Array.prototype,
        isBrowser = !!(typeof window !== 'undefined' && typeof navigator !== 'undefined' && window.document),
        isWebWorker = !isBrowser && typeof importScripts !== 'undefined',
        //PS3 indicates loaded and complete, but need to wait for complete
        //specifically. Sequence is 'loading', 'loaded', execution,
        // then 'complete'. The UA check is unfortunate, but not sure how
        //to feature test w/o causing perf issues.
        readyRegExp = isBrowser && navigator.platform === 'PLAYSTATION 3' ?
                      /^complete$/ : /^(complete|loaded)$/,
        defContextName = '_',
        //Oh the tragedy, detecting opera. See the usage of isOpera for reason.
        isOpera = typeof opera !== 'undefined' && opera.toString() === '[object Opera]',
        contexts = {},
        cfg = {},
        globalDefQueue = [],
        useInteractive = false;

    function isFunction(it) {
        return ostring.call(it) === '[object Function]';
    }

    function isArray(it) {
        return ostring.call(it) === '[object Array]';
    }

    /**
     * Helper function for iterating over an array. If the func returns
     * a true value, it will break out of the loop.
     */
    function each(ary, func) {
        if (ary) {
            var i;
            for (i = 0; i < ary.length; i += 1) {
                if (ary[i] && func(ary[i], i, ary)) {
                    break;
                }
            }
        }
    }

    /**
     * Helper function for iterating over an array backwards. If the func
     * returns a true value, it will break out of the loop.
     */
    function eachReverse(ary, func) {
        if (ary) {
            var i;
            for (i = ary.length - 1; i > -1; i -= 1) {
                if (ary[i] && func(ary[i], i, ary)) {
                    break;
                }
            }
        }
    }

    function hasProp(obj, prop) {
        return hasOwn.call(obj, prop);
    }

    function getOwn(obj, prop) {
        return hasProp(obj, prop) && obj[prop];
    }

    /**
     * Cycles over properties in an object and calls a function for each
     * property value. If the function returns a truthy value, then the
     * iteration is stopped.
     */
    function eachProp(obj, func) {
        var prop;
        for (prop in obj) {
            if (hasProp(obj, prop)) {
                if (func(obj[prop], prop)) {
                    break;
                }
            }
        }
    }

    /**
     * Simple function to mix in properties from source into target,
     * but only if target does not already have a property of the same name.
     */
    function mixin(target, source, force, deepStringMixin) {
        if (source) {
            eachProp(source, function (value, prop) {
                if (force || !hasProp(target, prop)) {
                    if (deepStringMixin && typeof value === 'object' && value &&
                        !isArray(value) && !isFunction(value) &&
                        !(value instanceof RegExp)) {

                        if (!target[prop]) {
                            target[prop] = {};
                        }
                        mixin(target[prop], value, force, deepStringMixin);
                    } else {
                        target[prop] = value;
                    }
                }
            });
        }
        return target;
    }

    //Similar to Function.prototype.bind, but the 'this' object is specified
    //first, since it is easier to read/figure out what 'this' will be.
    function bind(obj, fn) {
        return function () {
            return fn.apply(obj, arguments);
        };
    }

    function scripts() {
        return document.getElementsByTagName('script');
    }

    function defaultOnError(err) {
        throw err;
    }

    //Allow getting a global that is expressed in
    //dot notation, like 'a.b.c'.
    function getGlobal(value) {
        if (!value) {
            return value;
        }
        var g = global;
        each(value.split('.'), function (part) {
            g = g[part];
        });
        return g;
    }

    /**
     * Constructs an error with a pointer to an URL with more information.
     * @param {String} id the error ID that maps to an ID on a web page.
     * @param {String} message human readable error.
     * @param {Error} [err] the original error, if there is one.
     *
     * @returns {Error}
     */
    function makeError(id, msg, err, requireModules) {
        var e = new Error(msg + '\nhttp://requirejs.org/docs/errors.html#' + id);
        e.requireType = id;
        e.requireModules = requireModules;
        if (err) {
            e.originalError = err;
        }
        return e;
    }

    if (typeof define !== 'undefined') {
        //If a define is already in play via another AMD loader,
        //do not overwrite.
        return;
    }

    if (typeof requirejs !== 'undefined') {
        if (isFunction(requirejs)) {
            //Do not overwrite an existing requirejs instance.
            return;
        }
        cfg = requirejs;
        requirejs = undefined;
    }

    //Allow for a require config object
    if (typeof require !== 'undefined' && !isFunction(require)) {
        //assume it is a config object.
        cfg = require;
        require = undefined;
    }

    function newContext(contextName) {
        var inCheckLoaded, Module, context, handlers,
            checkLoadedTimeoutId,
            config = {
                //Defaults. Do not set a default for map
                //config to speed up normalize(), which
                //will run faster if there is no default.
                waitSeconds: 7,
                baseUrl: './',
                paths: {},
                bundles: {},
                pkgs: {},
                shim: {},
                config: {}
            },
            registry = {},
            //registry of just enabled modules, to speed
            //cycle breaking code when lots of modules
            //are registered, but not activated.
            enabledRegistry = {},
            undefEvents = {},
            defQueue = [],
            defined = {},
            urlFetched = {},
            bundlesMap = {},
            requireCounter = 1,
            unnormalizedCounter = 1;

        /**
         * Trims the . and .. from an array of path segments.
         * It will keep a leading path segment if a .. will become
         * the first path segment, to help with module name lookups,
         * which act like paths, but can be remapped. But the end result,
         * all paths that use this function should look normalized.
         * NOTE: this method MODIFIES the input array.
         * @param {Array} ary the array of path segments.
         */
        function trimDots(ary) {
            var i, part;
            for (i = 0; i < ary.length; i++) {
                part = ary[i];
                if (part === '.') {
                    ary.splice(i, 1);
                    i -= 1;
                } else if (part === '..') {
                    // If at the start, or previous value is still ..,
                    // keep them so that when converted to a path it may
                    // still work when converted to a path, even though
                    // as an ID it is less than ideal. In larger point
                    // releases, may be better to just kick out an error.
                    if (i === 0 || (i === 1 && ary[2] === '..') || ary[i - 1] === '..') {
                        continue;
                    } else if (i > 0) {
                        ary.splice(i - 1, 2);
                        i -= 2;
                    }
                }
            }
        }

        /**
         * Given a relative module name, like ./something, normalize it to
         * a real name that can be mapped to a path.
         * @param {String} name the relative name
         * @param {String} baseName a real name that the name arg is relative
         * to.
         * @param {Boolean} applyMap apply the map config to the value. Should
         * only be done if this normalization is for a dependency ID.
         * @returns {String} normalized name
         */
        function normalize(name, baseName, applyMap) {
            var pkgMain, mapValue, nameParts, i, j, nameSegment, lastIndex,
                foundMap, foundI, foundStarMap, starI, normalizedBaseParts,
                baseParts = (baseName && baseName.split('/')),
                map = config.map,
                starMap = map && map['*'];

            //Adjust any relative paths.
            if (name) {
                name = name.split('/');
                lastIndex = name.length - 1;

                // If wanting node ID compatibility, strip .js from end
                // of IDs. Have to do this here, and not in nameToUrl
                // because node allows either .js or non .js to map
                // to same file.
                if (config.nodeIdCompat && jsSuffixRegExp.test(name[lastIndex])) {
                    name[lastIndex] = name[lastIndex].replace(jsSuffixRegExp, '');
                }

                // Starts with a '.' so need the baseName
                if (name[0].charAt(0) === '.' && baseParts) {
                    //Convert baseName to array, and lop off the last part,
                    //so that . matches that 'directory' and not name of the baseName's
                    //module. For instance, baseName of 'one/two/three', maps to
                    //'one/two/three.js', but we want the directory, 'one/two' for
                    //this normalization.
                    normalizedBaseParts = baseParts.slice(0, baseParts.length - 1);
                    name = normalizedBaseParts.concat(name);
                }

                trimDots(name);
                name = name.join('/');
            }

            //Apply map config if available.
            if (applyMap && map && (baseParts || starMap)) {
                nameParts = name.split('/');

                outerLoop: for (i = nameParts.length; i > 0; i -= 1) {
                    nameSegment = nameParts.slice(0, i).join('/');

                    if (baseParts) {
                        //Find the longest baseName segment match in the config.
                        //So, do joins on the biggest to smallest lengths of baseParts.
                        for (j = baseParts.length; j > 0; j -= 1) {
                            mapValue = getOwn(map, baseParts.slice(0, j).join('/'));

                            //baseName segment has config, find if it has one for
                            //this name.
                            if (mapValue) {
                                mapValue = getOwn(mapValue, nameSegment);
                                if (mapValue) {
                                    //Match, update name to the new value.
                                    foundMap = mapValue;
                                    foundI = i;
                                    break outerLoop;
                                }
                            }
                        }
                    }

                    //Check for a star map match, but just hold on to it,
                    //if there is a shorter segment match later in a matching
                    //config, then favor over this star map.
                    if (!foundStarMap && starMap && getOwn(starMap, nameSegment)) {
                        foundStarMap = getOwn(starMap, nameSegment);
                        starI = i;
                    }
                }

                if (!foundMap && foundStarMap) {
                    foundMap = foundStarMap;
                    foundI = starI;
                }

                if (foundMap) {
                    nameParts.splice(0, foundI, foundMap);
                    name = nameParts.join('/');
                }
            }

            // If the name points to a package's name, use
            // the package main instead.
            pkgMain = getOwn(config.pkgs, name);

            return pkgMain ? pkgMain : name;
        }

        function removeScript(name) {
            if (isBrowser) {
                each(scripts(), function (scriptNode) {
                    if (scriptNode.getAttribute('data-requiremodule') === name &&
                            scriptNode.getAttribute('data-requirecontext') === context.contextName) {
                        scriptNode.parentNode.removeChild(scriptNode);
                        return true;
                    }
                });
            }
        }

        function hasPathFallback(id) {
            var pathConfig = getOwn(config.paths, id);
            if (pathConfig && isArray(pathConfig) && pathConfig.length > 1) {
                //Pop off the first array value, since it failed, and
                //retry
                pathConfig.shift();
                context.require.undef(id);

                //Custom require that does not do map translation, since
                //ID is "absolute", already mapped/resolved.
                context.makeRequire(null, {
                    skipMap: true
                })([id]);

                return true;
            }
        }

        //Turns a plugin!resource to [plugin, resource]
        //with the plugin being undefined if the name
        //did not have a plugin prefix.
        function splitPrefix(name) {
            var prefix,
                index = name ? name.indexOf('!') : -1;
            if (index > -1) {
                prefix = name.substring(0, index);
                name = name.substring(index + 1, name.length);
            }
            return [prefix, name];
        }

        /**
         * Creates a module mapping that includes plugin prefix, module
         * name, and path. If parentModuleMap is provided it will
         * also normalize the name via require.normalize()
         *
         * @param {String} name the module name
         * @param {String} [parentModuleMap] parent module map
         * for the module name, used to resolve relative names.
         * @param {Boolean} isNormalized: is the ID already normalized.
         * This is true if this call is done for a define() module ID.
         * @param {Boolean} applyMap: apply the map config to the ID.
         * Should only be true if this map is for a dependency.
         *
         * @returns {Object}
         */
        function makeModuleMap(name, parentModuleMap, isNormalized, applyMap) {
            var url, pluginModule, suffix, nameParts,
                prefix = null,
                parentName = parentModuleMap ? parentModuleMap.name : null,
                originalName = name,
                isDefine = true,
                normalizedName = '';

            //If no name, then it means it is a require call, generate an
            //internal name.
            if (!name) {
                isDefine = false;
                name = '_@r' + (requireCounter += 1);
            }

            nameParts = splitPrefix(name);
            prefix = nameParts[0];
            name = nameParts[1];

            if (prefix) {
                prefix = normalize(prefix, parentName, applyMap);
                pluginModule = getOwn(defined, prefix);
            }

            //Account for relative paths if there is a base name.
            if (name) {
                if (prefix) {
                    if (pluginModule && pluginModule.normalize) {
                        //Plugin is loaded, use its normalize method.
                        normalizedName = pluginModule.normalize(name, function (name) {
                            return normalize(name, parentName, applyMap);
                        });
                    } else {
                        // If nested plugin references, then do not try to
                        // normalize, as it will not normalize correctly. This
                        // places a restriction on resourceIds, and the longer
                        // term solution is not to normalize until plugins are
                        // loaded and all normalizations to allow for async
                        // loading of a loader plugin. But for now, fixes the
                        // common uses. Details in #1131
                        normalizedName = name.indexOf('!') === -1 ?
                                         normalize(name, parentName, applyMap) :
                                         name;
                    }
                } else {
                    //A regular module.
                    normalizedName = normalize(name, parentName, applyMap);

                    //Normalized name may be a plugin ID due to map config
                    //application in normalize. The map config values must
                    //already be normalized, so do not need to redo that part.
                    nameParts = splitPrefix(normalizedName);
                    prefix = nameParts[0];
                    normalizedName = nameParts[1];
                    isNormalized = true;

                    url = context.nameToUrl(normalizedName);
                }
            }

            //If the id is a plugin id that cannot be determined if it needs
            //normalization, stamp it with a unique ID so two matching relative
            //ids that may conflict can be separate.
            suffix = prefix && !pluginModule && !isNormalized ?
                     '_unnormalized' + (unnormalizedCounter += 1) :
                     '';

            return {
                prefix: prefix,
                name: normalizedName,
                parentMap: parentModuleMap,
                unnormalized: !!suffix,
                url: url,
                originalName: originalName,
                isDefine: isDefine,
                id: (prefix ?
                        prefix + '!' + normalizedName :
                        normalizedName) + suffix
            };
        }

        function getModule(depMap) {
            var id = depMap.id,
                mod = getOwn(registry, id);

            if (!mod) {
                mod = registry[id] = new context.Module(depMap);
            }

            return mod;
        }

        function on(depMap, name, fn) {
            var id = depMap.id,
                mod = getOwn(registry, id);

            if (hasProp(defined, id) &&
                    (!mod || mod.defineEmitComplete)) {
                if (name === 'defined') {
                    fn(defined[id]);
                }
            } else {
                mod = getModule(depMap);
                if (mod.error && name === 'error') {
                    fn(mod.error);
                } else {
                    mod.on(name, fn);
                }
            }
        }

        function onError(err, errback) {
            var ids = err.requireModules,
                notified = false;

            if (errback) {
                errback(err);
            } else {
                each(ids, function (id) {
                    var mod = getOwn(registry, id);
                    if (mod) {
                        //Set error on module, so it skips timeout checks.
                        mod.error = err;
                        if (mod.events.error) {
                            notified = true;
                            mod.emit('error', err);
                        }
                    }
                });

                if (!notified) {
                    req.onError(err);
                }
            }
        }

        /**
         * Internal method to transfer globalQueue items to this context's
         * defQueue.
         */
        function takeGlobalQueue() {
            //Push all the globalDefQueue items into the context's defQueue
            if (globalDefQueue.length) {
                each(globalDefQueue, function(queueItem) {
                    var id = queueItem[0];
                    if (typeof id === 'string') {
                        context.defQueueMap[id] = true;
                    }
                    defQueue.push(queueItem);
                });
                globalDefQueue = [];
            }
        }

        handlers = {
            'require': function (mod) {
                if (mod.require) {
                    return mod.require;
                } else {
                    return (mod.require = context.makeRequire(mod.map));
                }
            },
            'exports': function (mod) {
                mod.usingExports = true;
                if (mod.map.isDefine) {
                    if (mod.exports) {
                        return (defined[mod.map.id] = mod.exports);
                    } else {
                        return (mod.exports = defined[mod.map.id] = {});
                    }
                }
            },
            'module': function (mod) {
                if (mod.module) {
                    return mod.module;
                } else {
                    return (mod.module = {
                        id: mod.map.id,
                        uri: mod.map.url,
                        config: function () {
                            return getOwn(config.config, mod.map.id) || {};
                        },
                        exports: mod.exports || (mod.exports = {})
                    });
                }
            }
        };

        function cleanRegistry(id) {
            //Clean up machinery used for waiting modules.
            delete registry[id];
            delete enabledRegistry[id];
        }

        function breakCycle(mod, traced, processed) {
            var id = mod.map.id;

            if (mod.error) {
                mod.emit('error', mod.error);
            } else {
                traced[id] = true;
                each(mod.depMaps, function (depMap, i) {
                    var depId = depMap.id,
                        dep = getOwn(registry, depId);

                    //Only force things that have not completed
                    //being defined, so still in the registry,
                    //and only if it has not been matched up
                    //in the module already.
                    if (dep && !mod.depMatched[i] && !processed[depId]) {
                        if (getOwn(traced, depId)) {
                            mod.defineDep(i, defined[depId]);
                            mod.check(); //pass false?
                        } else {
                            breakCycle(dep, traced, processed);
                        }
                    }
                });
                processed[id] = true;
            }
        }

        function checkLoaded() {
            var err, usingPathFallback,
                waitInterval = config.waitSeconds * 1000,
                //It is possible to disable the wait interval by using waitSeconds of 0.
                expired = waitInterval && (context.startTime + waitInterval) < new Date().getTime(),
                noLoads = [],
                reqCalls = [],
                stillLoading = false,
                needCycleCheck = true;

            //Do not bother if this call was a result of a cycle break.
            if (inCheckLoaded) {
                return;
            }

            inCheckLoaded = true;

            //Figure out the state of all the modules.
            eachProp(enabledRegistry, function (mod) {
                var map = mod.map,
                    modId = map.id;

                //Skip things that are not enabled or in error state.
                if (!mod.enabled) {
                    return;
                }

                if (!map.isDefine) {
                    reqCalls.push(mod);
                }

                if (!mod.error) {
                    //If the module should be executed, and it has not
                    //been inited and time is up, remember it.
                    if (!mod.inited && expired) {
                        if (hasPathFallback(modId)) {
                            usingPathFallback = true;
                            stillLoading = true;
                        } else {
                            noLoads.push(modId);
                            removeScript(modId);
                        }
                    } else if (!mod.inited && mod.fetched && map.isDefine) {
                        stillLoading = true;
                        if (!map.prefix) {
                            //No reason to keep looking for unfinished
                            //loading. If the only stillLoading is a
                            //plugin resource though, keep going,
                            //because it may be that a plugin resource
                            //is waiting on a non-plugin cycle.
                            return (needCycleCheck = false);
                        }
                    }
                }
            });

            if (expired && noLoads.length) {
                //If wait time expired, throw error of unloaded modules.
                err = makeError('timeout', 'Load timeout for modules: ' + noLoads, null, noLoads);
                err.contextName = context.contextName;
                return onError(err);
            }

            //Not expired, check for a cycle.
            if (needCycleCheck) {
                each(reqCalls, function (mod) {
                    breakCycle(mod, {}, {});
                });
            }

            //If still waiting on loads, and the waiting load is something
            //other than a plugin resource, or there are still outstanding
            //scripts, then just try back later.
            if ((!expired || usingPathFallback) && stillLoading) {
                //Something is still waiting to load. Wait for it, but only
                //if a timeout is not already in effect.
                if ((isBrowser || isWebWorker) && !checkLoadedTimeoutId) {
                    checkLoadedTimeoutId = setTimeout(function () {
                        checkLoadedTimeoutId = 0;
                        checkLoaded();
                    }, 50);
                }
            }

            inCheckLoaded = false;
        }

        Module = function (map) {
            this.events = getOwn(undefEvents, map.id) || {};
            this.map = map;
            this.shim = getOwn(config.shim, map.id);
            this.depExports = [];
            this.depMaps = [];
            this.depMatched = [];
            this.pluginMaps = {};
            this.depCount = 0;

            /* this.exports this.factory
               this.depMaps = [],
               this.enabled, this.fetched
            */
        };

        Module.prototype = {
            init: function (depMaps, factory, errback, options) {
                options = options || {};

                //Do not do more inits if already done. Can happen if there
                //are multiple define calls for the same module. That is not
                //a normal, common case, but it is also not unexpected.
                if (this.inited) {
                    return;
                }

                this.factory = factory;

                if (errback) {
                    //Register for errors on this module.
                    this.on('error', errback);
                } else if (this.events.error) {
                    //If no errback already, but there are error listeners
                    //on this module, set up an errback to pass to the deps.
                    errback = bind(this, function (err) {
                        this.emit('error', err);
                    });
                }

                //Do a copy of the dependency array, so that
                //source inputs are not modified. For example
                //"shim" deps are passed in here directly, and
                //doing a direct modification of the depMaps array
                //would affect that config.
                this.depMaps = depMaps && depMaps.slice(0);

                this.errback = errback;

                //Indicate this module has be initialized
                this.inited = true;

                this.ignore = options.ignore;

                //Could have option to init this module in enabled mode,
                //or could have been previously marked as enabled. However,
                //the dependencies are not known until init is called. So
                //if enabled previously, now trigger dependencies as enabled.
                if (options.enabled || this.enabled) {
                    //Enable this module and dependencies.
                    //Will call this.check()
                    this.enable();
                } else {
                    this.check();
                }
            },

            defineDep: function (i, depExports) {
                //Because of cycles, defined callback for a given
                //export can be called more than once.
                if (!this.depMatched[i]) {
                    this.depMatched[i] = true;
                    this.depCount -= 1;
                    this.depExports[i] = depExports;
                }
            },

            fetch: function () {
                if (this.fetched) {
                    return;
                }
                this.fetched = true;

                context.startTime = (new Date()).getTime();

                var map = this.map;

                //If the manager is for a plugin managed resource,
                //ask the plugin to load it now.
                if (this.shim) {
                    context.makeRequire(this.map, {
                        enableBuildCallback: true
                    })(this.shim.deps || [], bind(this, function () {
                        return map.prefix ? this.callPlugin() : this.load();
                    }));
                } else {
                    //Regular dependency.
                    return map.prefix ? this.callPlugin() : this.load();
                }
            },

            load: function () {
                var url = this.map.url;

                //Regular dependency.
                if (!urlFetched[url]) {
                    urlFetched[url] = true;
                    context.load(this.map.id, url);
                }
            },

            /**
             * Checks if the module is ready to define itself, and if so,
             * define it.
             */
            check: function () {
                if (!this.enabled || this.enabling) {
                    return;
                }

                var err, cjsModule,
                    id = this.map.id,
                    depExports = this.depExports,
                    exports = this.exports,
                    factory = this.factory;

                if (!this.inited) {
                    // Only fetch if not already in the defQueue.
                    if (!hasProp(context.defQueueMap, id)) {
                        this.fetch();
                    }
                } else if (this.error) {
                    this.emit('error', this.error);
                } else if (!this.defining) {
                    //The factory could trigger another require call
                    //that would result in checking this module to
                    //define itself again. If already in the process
                    //of doing that, skip this work.
                    this.defining = true;

                    if (this.depCount < 1 && !this.defined) {
                        if (isFunction(factory)) {
                            try {
                                exports = context.execCb(id, factory, depExports, exports);
                            } catch (e) {
                                err = e;
                            }

                            // Favor return value over exports. If node/cjs in play,
                            // then will not have a return value anyway. Favor
                            // module.exports assignment over exports object.
                            if (this.map.isDefine && exports === undefined) {
                                cjsModule = this.module;
                                if (cjsModule) {
                                    exports = cjsModule.exports;
                                } else if (this.usingExports) {
                                    //exports already set the defined value.
                                    exports = this.exports;
                                }
                            }

                            if (err) {
                                // If there is an error listener, favor passing
                                // to that instead of throwing an error. However,
                                // only do it for define()'d  modules. require
                                // errbacks should not be called for failures in
                                // their callbacks (#699). However if a global
                                // onError is set, use that.
                                if ((this.events.error && this.map.isDefine) ||
                                    req.onError !== defaultOnError) {
                                    err.requireMap = this.map;
                                    err.requireModules = this.map.isDefine ? [this.map.id] : null;
                                    err.requireType = this.map.isDefine ? 'define' : 'require';
                                    return onError((this.error = err));
                                } else if (typeof console !== 'undefined' &&
                                           console.error) {
                                    // Log the error for debugging. If promises could be
                                    // used, this would be different, but making do.
                                    console.error(err);
                                } else {
                                    // Do not want to completely lose the error. While this
                                    // will mess up processing and lead to similar results
                                    // as bug 1440, it at least surfaces the error.
                                    req.onError(err);
                                }
                            }
                        } else {
                            //Just a literal value
                            exports = factory;
                        }

                        this.exports = exports;

                        if (this.map.isDefine && !this.ignore) {
                            defined[id] = exports;

                            if (req.onResourceLoad) {
                                var resLoadMaps = [];
                                each(this.depMaps, function (depMap) {
                                    resLoadMaps.push(depMap.normalizedMap || depMap);
                                });
                                req.onResourceLoad(context, this.map, resLoadMaps);
                            }
                        }

                        //Clean up
                        cleanRegistry(id);

                        this.defined = true;
                    }

                    //Finished the define stage. Allow calling check again
                    //to allow define notifications below in the case of a
                    //cycle.
                    this.defining = false;

                    if (this.defined && !this.defineEmitted) {
                        this.defineEmitted = true;
                        this.emit('defined', this.exports);
                        this.defineEmitComplete = true;
                    }

                }
            },

            callPlugin: function () {
                var map = this.map,
                    id = map.id,
                    //Map already normalized the prefix.
                    pluginMap = makeModuleMap(map.prefix);

                //Mark this as a dependency for this plugin, so it
                //can be traced for cycles.
                this.depMaps.push(pluginMap);

                on(pluginMap, 'defined', bind(this, function (plugin) {
                    var load, normalizedMap, normalizedMod,
                        bundleId = getOwn(bundlesMap, this.map.id),
                        name = this.map.name,
                        parentName = this.map.parentMap ? this.map.parentMap.name : null,
                        localRequire = context.makeRequire(map.parentMap, {
                            enableBuildCallback: true
                        });

                    //If current map is not normalized, wait for that
                    //normalized name to load instead of continuing.
                    if (this.map.unnormalized) {
                        //Normalize the ID if the plugin allows it.
                        if (plugin.normalize) {
                            name = plugin.normalize(name, function (name) {
                                return normalize(name, parentName, true);
                            }) || '';
                        }

                        //prefix and name should already be normalized, no need
                        //for applying map config again either.
                        normalizedMap = makeModuleMap(map.prefix + '!' + name,
                                                      this.map.parentMap);
                        on(normalizedMap,
                            'defined', bind(this, function (value) {
                                this.map.normalizedMap = normalizedMap;
                                this.init([], function () { return value; }, null, {
                                    enabled: true,
                                    ignore: true
                                });
                            }));

                        normalizedMod = getOwn(registry, normalizedMap.id);
                        if (normalizedMod) {
                            //Mark this as a dependency for this plugin, so it
                            //can be traced for cycles.
                            this.depMaps.push(normalizedMap);

                            if (this.events.error) {
                                normalizedMod.on('error', bind(this, function (err) {
                                    this.emit('error', err);
                                }));
                            }
                            normalizedMod.enable();
                        }

                        return;
                    }

                    //If a paths config, then just load that file instead to
                    //resolve the plugin, as it is built into that paths layer.
                    if (bundleId) {
                        this.map.url = context.nameToUrl(bundleId);
                        this.load();
                        return;
                    }

                    load = bind(this, function (value) {
                        this.init([], function () { return value; }, null, {
                            enabled: true
                        });
                    });

                    load.error = bind(this, function (err) {
                        this.inited = true;
                        this.error = err;
                        err.requireModules = [id];

                        //Remove temp unnormalized modules for this module,
                        //since they will never be resolved otherwise now.
                        eachProp(registry, function (mod) {
                            if (mod.map.id.indexOf(id + '_unnormalized') === 0) {
                                cleanRegistry(mod.map.id);
                            }
                        });

                        onError(err);
                    });

                    //Allow plugins to load other code without having to know the
                    //context or how to 'complete' the load.
                    load.fromText = bind(this, function (text, textAlt) {
                        /*jslint evil: true */
                        var moduleName = map.name,
                            moduleMap = makeModuleMap(moduleName),
                            hasInteractive = useInteractive;

                        //As of 2.1.0, support just passing the text, to reinforce
                        //fromText only being called once per resource. Still
                        //support old style of passing moduleName but discard
                        //that moduleName in favor of the internal ref.
                        if (textAlt) {
                            text = textAlt;
                        }

                        //Turn off interactive script matching for IE for any define
                        //calls in the text, then turn it back on at the end.
                        if (hasInteractive) {
                            useInteractive = false;
                        }

                        //Prime the system by creating a module instance for
                        //it.
                        getModule(moduleMap);

                        //Transfer any config to this other module.
                        if (hasProp(config.config, id)) {
                            config.config[moduleName] = config.config[id];
                        }

                        try {
                            req.exec(text);
                        } catch (e) {
                            return onError(makeError('fromtexteval',
                                             'fromText eval for ' + id +
                                            ' failed: ' + e,
                                             e,
                                             [id]));
                        }

                        if (hasInteractive) {
                            useInteractive = true;
                        }

                        //Mark this as a dependency for the plugin
                        //resource
                        this.depMaps.push(moduleMap);

                        //Support anonymous modules.
                        context.completeLoad(moduleName);

                        //Bind the value of that module to the value for this
                        //resource ID.
                        localRequire([moduleName], load);
                    });

                    //Use parentName here since the plugin's name is not reliable,
                    //could be some weird string with no path that actually wants to
                    //reference the parentName's path.
                    plugin.load(map.name, localRequire, load, config);
                }));

                context.enable(pluginMap, this);
                this.pluginMaps[pluginMap.id] = pluginMap;
            },

            enable: function () {
                enabledRegistry[this.map.id] = this;
                this.enabled = true;

                //Set flag mentioning that the module is enabling,
                //so that immediate calls to the defined callbacks
                //for dependencies do not trigger inadvertent load
                //with the depCount still being zero.
                this.enabling = true;

                //Enable each dependency
                each(this.depMaps, bind(this, function (depMap, i) {
                    var id, mod, handler;

                    if (typeof depMap === 'string') {
                        //Dependency needs to be converted to a depMap
                        //and wired up to this module.
                        depMap = makeModuleMap(depMap,
                                               (this.map.isDefine ? this.map : this.map.parentMap),
                                               false,
                                               !this.skipMap);
                        this.depMaps[i] = depMap;

                        handler = getOwn(handlers, depMap.id);

                        if (handler) {
                            this.depExports[i] = handler(this);
                            return;
                        }

                        this.depCount += 1;

                        on(depMap, 'defined', bind(this, function (depExports) {
                            if (this.undefed) {
                                return;
                            }
                            this.defineDep(i, depExports);
                            this.check();
                        }));

                        if (this.errback) {
                            on(depMap, 'error', bind(this, this.errback));
                        } else if (this.events.error) {
                            // No direct errback on this module, but something
                            // else is listening for errors, so be sure to
                            // propagate the error correctly.
                            on(depMap, 'error', bind(this, function(err) {
                                this.emit('error', err);
                            }));
                        }
                    }

                    id = depMap.id;
                    mod = registry[id];

                    //Skip special modules like 'require', 'exports', 'module'
                    //Also, don't call enable if it is already enabled,
                    //important in circular dependency cases.
                    if (!hasProp(handlers, id) && mod && !mod.enabled) {
                        context.enable(depMap, this);
                    }
                }));

                //Enable each plugin that is used in
                //a dependency
                eachProp(this.pluginMaps, bind(this, function (pluginMap) {
                    var mod = getOwn(registry, pluginMap.id);
                    if (mod && !mod.enabled) {
                        context.enable(pluginMap, this);
                    }
                }));

                this.enabling = false;

                this.check();
            },

            on: function (name, cb) {
                var cbs = this.events[name];
                if (!cbs) {
                    cbs = this.events[name] = [];
                }
                cbs.push(cb);
            },

            emit: function (name, evt) {
                each(this.events[name], function (cb) {
                    cb(evt);
                });
                if (name === 'error') {
                    //Now that the error handler was triggered, remove
                    //the listeners, since this broken Module instance
                    //can stay around for a while in the registry.
                    delete this.events[name];
                }
            }
        };

        function callGetModule(args) {
            //Skip modules already defined.
            if (!hasProp(defined, args[0])) {
                getModule(makeModuleMap(args[0], null, true)).init(args[1], args[2]);
            }
        }

        function removeListener(node, func, name, ieName) {
            //Favor detachEvent because of IE9
            //issue, see attachEvent/addEventListener comment elsewhere
            //in this file.
            if (node.detachEvent && !isOpera) {
                //Probably IE. If not it will throw an error, which will be
                //useful to know.
                if (ieName) {
                    node.detachEvent(ieName, func);
                }
            } else {
                node.removeEventListener(name, func, false);
            }
        }

        /**
         * Given an event from a script node, get the requirejs info from it,
         * and then removes the event listeners on the node.
         * @param {Event} evt
         * @returns {Object}
         */
        function getScriptData(evt) {
            //Using currentTarget instead of target for Firefox 2.0's sake. Not
            //all old browsers will be supported, but this one was easy enough
            //to support and still makes sense.
            var node = evt.currentTarget || evt.srcElement;

            //Remove the listeners once here.
            removeListener(node, context.onScriptLoad, 'load', 'onreadystatechange');
            removeListener(node, context.onScriptError, 'error');

            return {
                node: node,
                id: node && node.getAttribute('data-requiremodule')
            };
        }

        function intakeDefines() {
            var args;

            //Any defined modules in the global queue, intake them now.
            takeGlobalQueue();

            //Make sure any remaining defQueue items get properly processed.
            while (defQueue.length) {
                args = defQueue.shift();
                if (args[0] === null) {
                    return onError(makeError('mismatch', 'Mismatched anonymous define() module: ' +
                        args[args.length - 1]));
                } else {
                    //args are id, deps, factory. Should be normalized by the
                    //define() function.
                    callGetModule(args);
                }
            }
            context.defQueueMap = {};
        }

        context = {
            config: config,
            contextName: contextName,
            registry: registry,
            defined: defined,
            urlFetched: urlFetched,
            defQueue: defQueue,
            defQueueMap: {},
            Module: Module,
            makeModuleMap: makeModuleMap,
            nextTick: req.nextTick,
            onError: onError,

            /**
             * Set a configuration for the context.
             * @param {Object} cfg config object to integrate.
             */
            configure: function (cfg) {
                //Make sure the baseUrl ends in a slash.
                if (cfg.baseUrl) {
                    if (cfg.baseUrl.charAt(cfg.baseUrl.length - 1) !== '/') {
                        cfg.baseUrl += '/';
                    }
                }

                //Save off the paths since they require special processing,
                //they are additive.
                var shim = config.shim,
                    objs = {
                        paths: true,
                        bundles: true,
                        config: true,
                        map: true
                    };

                eachProp(cfg, function (value, prop) {
                    if (objs[prop]) {
                        if (!config[prop]) {
                            config[prop] = {};
                        }
                        mixin(config[prop], value, true, true);
                    } else {
                        config[prop] = value;
                    }
                });

                //Reverse map the bundles
                if (cfg.bundles) {
                    eachProp(cfg.bundles, function (value, prop) {
                        each(value, function (v) {
                            if (v !== prop) {
                                bundlesMap[v] = prop;
                            }
                        });
                    });
                }

                //Merge shim
                if (cfg.shim) {
                    eachProp(cfg.shim, function (value, id) {
                        //Normalize the structure
                        if (isArray(value)) {
                            value = {
                                deps: value
                            };
                        }
                        if ((value.exports || value.init) && !value.exportsFn) {
                            value.exportsFn = context.makeShimExports(value);
                        }
                        shim[id] = value;
                    });
                    config.shim = shim;
                }

                //Adjust packages if necessary.
                if (cfg.packages) {
                    each(cfg.packages, function (pkgObj) {
                        var location, name;

                        pkgObj = typeof pkgObj === 'string' ? {name: pkgObj} : pkgObj;

                        name = pkgObj.name;
                        location = pkgObj.location;
                        if (location) {
                            config.paths[name] = pkgObj.location;
                        }

                        //Save pointer to main module ID for pkg name.
                        //Remove leading dot in main, so main paths are normalized,
                        //and remove any trailing .js, since different package
                        //envs have different conventions: some use a module name,
                        //some use a file name.
                        config.pkgs[name] = pkgObj.name + '/' + (pkgObj.main || 'main')
                                     .replace(currDirRegExp, '')
                                     .replace(jsSuffixRegExp, '');
                    });
                }

                //If there are any "waiting to execute" modules in the registry,
                //update the maps for them, since their info, like URLs to load,
                //may have changed.
                eachProp(registry, function (mod, id) {
                    //If the module has already had init called, it is too
                    //late to modify it. Also ignore unnormalized modules,
                    //since they are transient.
                    if (!mod.inited && !mod.map.unnormalized) {
                        mod.map = makeModuleMap(id, null, true);
                    }
                });

                //If a deps array or a config callback is specified, then call
                //require with those args. This is useful when require is defined as a
                //config object before require.js is loaded.
                if (cfg.deps || cfg.callback) {
                    context.require(cfg.deps || [], cfg.callback);
                }
            },
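
            //Illustrative sketch (not part of the loader; names hypothetical):
            //a typical configure() payload. paths/bundles/config/map merge
            //additively across calls; deps/callback trigger a require.
            //
            //  context.configure({
            //      baseUrl: 'js/lib',
            //      paths: { app: '../app' },
            //      shim: { legacy: { deps: ['jquery'], exports: 'Legacy' } },
            //      deps: ['app/main']
            //  });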

            makeShimExports: function (value) {
                function fn() {
                    var ret;
                    if (value.init) {
                        ret = value.init.apply(global, arguments);
                    }
                    return ret || (value.exports && getGlobal(value.exports));
                }
                return fn;
            },
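
            //Illustrative sketch (hypothetical global 'Legacy'): for a shim
            //entry like
            //  { exports: 'Legacy', init: function () { return this.Legacy; } }
            //the generated fn returns init()'s result when truthy, otherwise
            //getGlobal('Legacy'); init runs with `this` bound to the global.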

            makeRequire: function (relMap, options) {
                options = options || {};

                function localRequire(deps, callback, errback) {
                    var id, map, requireMod;

                    if (options.enableBuildCallback && callback && isFunction(callback)) {
                        callback.__requireJsBuild = true;
                    }

                    if (typeof deps === 'string') {
                        if (isFunction(callback)) {
                            //Invalid call
                            return onError(makeError('requireargs', 'Invalid require call'), errback);
                        }

                        //If require|exports|module are requested, get the
                        //value for them from the special handlers. Caveat:
                        //this only works while module is being defined.
                        if (relMap && hasProp(handlers, deps)) {
                            return handlers[deps](registry[relMap.id]);
                        }

                        //Synchronous access to one module. If require.get is
                        //available (as in the Node adapter), prefer that.
                        if (req.get) {
                            return req.get(context, deps, relMap, localRequire);
                        }

                        //Normalize module name, if it contains . or ..
                        map = makeModuleMap(deps, relMap, false, true);
                        id = map.id;

                        if (!hasProp(defined, id)) {
                            return onError(makeError('notloaded', 'Module name "' +
                                        id +
                                        '" has not been loaded yet for context: ' +
                                        contextName +
                                        (relMap ? '' : '. Use require([])')));
                        }
                        return defined[id];
                    }

                    //Grab defines waiting in the global queue.
                    intakeDefines();

                    //Mark all the dependencies as needing to be loaded.
                    context.nextTick(function () {
                        //Some defines could have been added since the
                        //require call, collect them.
                        intakeDefines();

                        requireMod = getModule(makeModuleMap(null, relMap));

                        //Store if map config should be applied to this require
                        //call for dependencies.
                        requireMod.skipMap = options.skipMap;

                        requireMod.init(deps, callback, errback, {
                            enabled: true
                        });

                        checkLoaded();
                    });

                    return localRequire;
                }
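
                //Illustrative sketch of the two call shapes handled above
                //(module ids hypothetical):
                //  var util = localRequire('app/util');  //sync; must be loaded
                //  localRequire(['app/util'], function (util) { /* async */ });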

                mixin(localRequire, {
                    isBrowser: isBrowser,

                    /**
                     * Converts a module name + .extension into a URL path.
                     * *Requires* the use of a module name. It does not support using
                     * plain URLs like nameToUrl.
                     */
                    toUrl: function (moduleNamePlusExt) {
                        var ext,
                            index = moduleNamePlusExt.lastIndexOf('.'),
                            segment = moduleNamePlusExt.split('/')[0],
                            isRelative = segment === '.' || segment === '..';

                        //There is a file extension, and it is not just the
                        //dots from a relative path.
                        if (index !== -1 && (!isRelative || index > 1)) {
                            ext = moduleNamePlusExt.substring(index, moduleNamePlusExt.length);
                            moduleNamePlusExt = moduleNamePlusExt.substring(0, index);
                        }

                        return context.nameToUrl(normalize(moduleNamePlusExt,
                                                relMap && relMap.id, true), ext,  true);
                    },
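
                    //Illustrative sketch (path hypothetical): with baseUrl
                    //'js/', localRequire.toUrl('data/config.json') splits off
                    //'.json', path-maps 'data/config', then re-appends the
                    //extension, yielding 'js/data/config.json'.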

                    defined: function (id) {
                        return hasProp(defined, makeModuleMap(id, relMap, false, true).id);
                    },

                    specified: function (id) {
                        id = makeModuleMap(id, relMap, false, true).id;
                        return hasProp(defined, id) || hasProp(registry, id);
                    }
                });

                //Only allow undef on top level require calls
                if (!relMap) {
                    localRequire.undef = function (id) {
                        //Bind any waiting define() calls to this context,
                        //fix for #408
                        takeGlobalQueue();

                        var map = makeModuleMap(id, relMap, true),
                            mod = getOwn(registry, id);

                        //Guard: undef of an id that is not in the registry
                        //should not throw.
                        if (mod) {
                            mod.undefed = true;
                        }
                        removeScript(id);

                        delete defined[id];
                        delete urlFetched[map.url];
                        delete undefEvents[id];

                        //Clean queued defines too. Go backwards
                        //in array so that the splices do not
                        //mess up the iteration.
                        eachReverse(defQueue, function(args, i) {
                            if (args[0] === id) {
                                defQueue.splice(i, 1);
                            }
                        });
                        delete context.defQueueMap[id];

                        if (mod) {
                            //Hold on to listeners in case the
                            //module will be attempted to be reloaded
                            //using a different config.
                            if (mod.events.defined) {
                                undefEvents[id] = mod.events;
                            }

                            cleanRegistry(id);
                        }
                    };
                }
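
                //Illustrative sketch (module id hypothetical): a top-level
                //require can unload a module so a later call re-fetches it:
                //  require.undef('app/stale');
                //  require(['app/stale'], function (fresh) { /* reloaded */ });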

                return localRequire;
            },

            /**
             * Called to enable a module if it is still in the registry
             * awaiting enablement. A second arg, parent, the parent module,
             * is passed in for context, when this method is overridden by
             * the optimizer. Not shown here to keep code compact.
             */
            enable: function (depMap) {
                var mod = getOwn(registry, depMap.id);
                if (mod) {
                    getModule(depMap).enable();
                }
            },

            /**
             * Internal method used by environment adapters to complete a load event.
             * A load event could be a script load or just a load pass from a synchronous
             * load call.
             * @param {String} moduleName the name of the module to potentially complete.
             */
            completeLoad: function (moduleName) {
                var found, args, mod,
                    shim = getOwn(config.shim, moduleName) || {},
                    shExports = shim.exports;

                takeGlobalQueue();

                while (defQueue.length) {
                    args = defQueue.shift();
                    if (args[0] === null) {
                        args[0] = moduleName;
                        //If already found an anonymous module and bound it
                        //to this name, then this is some other anon module
                        //waiting for its completeLoad to fire.
                        if (found) {
                            break;
                        }
                        found = true;
                    } else if (args[0] === moduleName) {
                        //Found matching define call for this script!
                        found = true;
                    }

                    callGetModule(args);
                }
                context.defQueueMap = {};

                //Do this after the cycle of callGetModule in case the result
                //of those calls/init calls changes the registry.
                mod = getOwn(registry, moduleName);

                if (!found && !hasProp(defined, moduleName) && mod && !mod.inited) {
                    if (config.enforceDefine && (!shExports || !getGlobal(shExports))) {
                        if (hasPathFallback(moduleName)) {
                            return;
                        } else {
                            return onError(makeError('nodefine',
                                             'No define call for ' + moduleName,
                                             null,
                                             [moduleName]));
                        }
                    } else {
                        //A script that does not call define(), so just simulate
                        //the call for it.
                        callGetModule([moduleName, (shim.deps || []), shim.exportsFn]);
                    }
                }

                checkLoaded();
            },

            /**
             * Converts a module name to a file path. Supports cases where
             * moduleName may actually be just a URL.
             * Note that it **does not** call normalize on the moduleName,
             * it is assumed to have already been normalized. This is an
             * internal API, not a public one. Use toUrl for the public API.
             */
            nameToUrl: function (moduleName, ext, skipExt) {
                var paths, syms, i, parentModule, url,
                    parentPath, bundleId,
                    pkgMain = getOwn(config.pkgs, moduleName);

                if (pkgMain) {
                    moduleName = pkgMain;
                }

                bundleId = getOwn(bundlesMap, moduleName);

                if (bundleId) {
                    return context.nameToUrl(bundleId, ext, skipExt);
                }

                //If a colon is in the URL, it indicates a protocol is used and it is just
                //a URL to a file. If it starts with a slash, contains a query arg (i.e. ?)
                //or ends with .js, assume the user meant a URL and not a module id.
                //The slash is important for protocol-less URLs as well as full paths.
                if (req.jsExtRegExp.test(moduleName)) {
                    //Just a plain path, not module name lookup, so just return it.
                    //Add the extension if one was passed in. This is a bit wonky: only
                    //non-.js requests pass an extension, so this method probably needs
                    //to be reworked.
                    url = moduleName + (ext || '');
                } else {
                    //A module that needs to be converted to a path.
                    paths = config.paths;

                    syms = moduleName.split('/');
                    //For each module name segment, see if there is a path
                    //registered for it. Start with most specific name
                    //and work up from it.
                    for (i = syms.length; i > 0; i -= 1) {
                        parentModule = syms.slice(0, i).join('/');

                        parentPath = getOwn(paths, parentModule);
                        if (parentPath) {
                            //If an array, it means there are a few choices;
                            //choose the one that is desired.
                            if (isArray(parentPath)) {
                                parentPath = parentPath[0];
                            }
                            syms.splice(0, i, parentPath);
                            break;
                        }
                    }

                    //Join the path parts together, then figure out if baseUrl is needed.
                    url = syms.join('/');
                    url += (ext || (/^data\:|\?/.test(url) || skipExt ? '' : '.js'));
                    url = (url.charAt(0) === '/' || url.match(/^[\w\+\.\-]+:/) ? '' : config.baseUrl) + url;
                }

                return config.urlArgs ? url +
                                        ((url.indexOf('?') === -1 ? '?' : '&') +
                                         config.urlArgs) : url;
            },
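
            //Illustrative sketch of the resolution above, with hypothetical
            //config { baseUrl: 'js/', paths: { ui: 'vendor/ui' } }:
            //  nameToUrl('ui/dialog')    -> 'js/vendor/ui/dialog.js'
            //  nameToUrl('/assets/x.js') -> '/assets/x.js' (plain path branch)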

            //Delegates to req.load. Broken out as a separate function to
            //allow overriding in the optimizer.
            load: function (id, url) {
                req.load(context, id, url);
            },

            /**
             * Executes a module callback function. Broken out as a separate function
             * solely to allow the build system to place the files in the built
             * layer in the right sequence.
             *
             * @private
             */
            execCb: function (name, callback, args, exports) {
                return callback.apply(exports, args);
            },

            /**
             * Callback for script loads; used to check the status of loading.
             *
             * @param {Event} evt the event from the browser for the script
             * that was loaded.
             */
            onScriptLoad: function (evt) {
                //Using currentTarget instead of target for Firefox 2.0's sake. Not
                //all old browsers will be supported, but this one was easy enough
                //to support and still makes sense.
                if (evt.type === 'load' ||
                        (readyRegExp.test((evt.currentTarget || evt.srcElement).readyState))) {
                    //Reset interactive script so a script node is not held onto for
                    //too long.
                    interactiveScript = null;

                    //Pull out the name of the module and the context.
                    var data = getScriptData(evt);
                    context.completeLoad(data.id);
                }
            },

            /**
             * Callback for script errors.
             */
            onScriptError: function (evt) {
                var data = getScriptData(evt);
                if (!hasPathFallback(data.id)) {
                    var parents = [];
                    eachProp(registry, function(value, key) {
                        if (key.indexOf('_@r') !== 0) {
                            each(value.depMaps, function(depMap) {
                                if (depMap.id === data.id) {
                                    parents.push(key);
                                }
                                return true;
                            });
                        }
                    });
                    return onError(makeError('scripterror', 'Script error for "' + data.id +
                                             (parents.length ?
                                             '", needed by: ' + parents.join(', ') :
                                             '"'), evt, [data.id]));
                }
            }
        };

        context.require = context.makeRequire();
        return context;
    }

    /**
     * Main entry point.
     *
     * If the only argument to require is a string, then the module that
     * is represented by that string is fetched for the appropriate context.
     *
     * If the first argument is an array, then it will be treated as an array
     * of dependency string names to fetch. An optional function callback can
     * be specified to execute when all of those dependencies are available.
     *
     * Make a local req variable to help Caja compliance (it assumes things
     * on a require that are not standardized), and to give a short
     * name for minification/local scope use.
     */
    req = requirejs = function (deps, callback, errback, optional) {

        //Find the right context, use default
        var context, config,
            contextName = defContextName;

        // Determine if there is a config object in the call.
        if (!isArray(deps) && typeof deps !== 'string') {
            // deps is a config object
            config = deps;
            if (isArray(callback)) {
                // Adjust args if there are dependencies
                deps = callback;
                callback = errback;
                errback = optional;
            } else {
                deps = [];
            }
        }

        if (config && config.context) {
            contextName = config.context;
        }

        context = getOwn(contexts, contextName);
        if (!context) {
            context = contexts[contextName] = req.s.newContext(contextName);
        }

        if (config) {
            context.configure(config);
        }

        return context.require(deps, callback, errback);
    };
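
    //Illustrative sketch of the accepted call shapes (ids hypothetical):
    //  requirejs('app/util');                        //sync, loaded module
    //  requirejs(['app/util'], function (util) {});  //async + callback
    //  requirejs({context: 'v2'}, ['app/util'],
    //            function (util) {});                //config, deps, callback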

    /**
     * Support require.config() to make it easier to cooperate with other
     * AMD loaders on globally agreed names.
     */
    req.config = function (config) {
        return req(config);
    };

    /**
     * Execute something after the current tick
     * of the event loop. Override for other envs
     * that have a better solution than setTimeout.
     * @param  {Function} fn function to execute later.
     */
    req.nextTick = typeof setTimeout !== 'undefined' ? function (fn) {
        setTimeout(fn, 4);
    } : function (fn) { fn(); };

    /**
     * Export require as a global, but only if it does not already exist.
     */
    if (!require) {
        require = req;
    }

    req.version = version;

    //Used to filter out dependencies that are already paths.
    req.jsExtRegExp = /^\/|:|\?|\.js$/;
    req.isBrowser = isBrowser;
    s = req.s = {
        contexts: contexts,
        newContext: newContext
    };

    //Create default context.
    req({});

    //Exports some context-sensitive methods on global require.
    each([
        'toUrl',
        'undef',
        'defined',
        'specified'
    ], function (prop) {
        //Reference from contexts instead of early binding to default context,
        //so that during builds, the latest instance of the default context
        //with its config gets used.
        req[prop] = function () {
            var ctx = contexts[defContextName];
            return ctx.require[prop].apply(ctx, arguments);
        };
    });

    if (isBrowser) {
        head = s.head = document.getElementsByTagName('head')[0];
        //If BASE tag is in play, using appendChild is a problem for IE6.
        //When that browser dies, this can be removed. Details in this jQuery bug:
        //http://dev.jquery.com/ticket/2709
        baseElement = document.getElementsByTagName('base')[0];
        if (baseElement) {
            head = s.head = baseElement.parentNode;
        }
    }

    /**
     * Any errors that require explicitly generates will be passed to this
     * function. Intercept/override it if you want custom error handling.
     * @param {Error} err the error object.
     */
    req.onError = defaultOnError;

    /**
     * Creates the node for the load command. Only used in browser envs.
     */
    req.createNode = function (config, moduleName, url) {
        var node = config.xhtml ?
                document.createElementNS('http://www.w3.org/1999/xhtml', 'html:script') :
                document.createElement('script');
        node.type = config.scriptType || 'text/javascript';
        node.charset = 'utf-8';
        node.async = true;
        return node;
    };

    /**
     * Does the request to load a module for the browser case.
     * Make this a separate function to allow other environments
     * to override it.
     *
     * @param {Object} context the require context to find state.
     * @param {String} moduleName the name of the module.
     * @param {Object} url the URL to the module.
     */
    req.load = function (context, moduleName, url) {
        var config = (context && context.config) || {},
            node;
        if (isBrowser) {
            //In the browser so use a script tag
            node = req.createNode(config, moduleName, url);
            if (config.onNodeCreated) {
                config.onNodeCreated(node, config, moduleName, url);
            }

            node.setAttribute('data-requirecontext', context.contextName);
            node.setAttribute('data-requiremodule', moduleName);

            //Set up load listener. Test attachEvent first because IE9 has
            //a subtle issue in its addEventListener and script onload firings
            //that do not match the behavior of all other browsers with
            //addEventListener support, which fire the onload event for a
            //script right after the script execution. See:
            //https://connect.microsoft.com/IE/feedback/details/648057/script-onload-event-is-not-fired-immediately-after-script-execution
            //UNFORTUNATELY Opera implements attachEvent but does not follow the
            //script execution mode.
            if (node.attachEvent &&
                    //Check if node.attachEvent is artificially added by custom script or
                    //natively supported by browser
                    //read https://github.com/jrburke/requirejs/issues/187
                    //if we can NOT find [native code] then it must NOT be natively supported.
                    //In IE8, node.attachEvent does not have toString().
                    //Note the test for "[native code" with no closing bracket, see:
                    //https://github.com/jrburke/requirejs/issues/273
                    !(node.attachEvent.toString && node.attachEvent.toString().indexOf('[native code') < 0) &&
                    !isOpera) {
                //Probably IE. IE (at least 6-8) do not fire
                //script onload right after executing the script, so
                //we cannot tie the anonymous define call to a name.
                //However, IE reports the script as being in 'interactive'
                //readyState at the time of the define call.
                useInteractive = true;

                node.attachEvent('onreadystatechange', context.onScriptLoad);
                //It would be great to add an error handler here to catch
                //404s in IE9+. However, onreadystatechange will fire before
                //the error handler, so that does not help. If addEventListener
                //is used, then IE will fire error before load, but we cannot
                //use that pathway given the connect.microsoft.com issue
                //mentioned above about not doing the 'script execute,
                //then fire the script load event listener before execute
                //next script' that other browsers do.
                //Best hope: IE10 fixes the issues,
                //and then destroys all installs of IE 6-9.
                //node.attachEvent('onerror', context.onScriptError);
            } else {
                node.addEventListener('load', context.onScriptLoad, false);
                node.addEventListener('error', context.onScriptError, false);
            }
            node.src = url;

            //For some cache cases in IE 6-8, the script executes before the end
            //of the appendChild execution, so to tie an anonymous define
            //call to the module name (which is stored on the node), hold on
            //to a reference to this node, but clear after the DOM insertion.
            currentlyAddingScript = node;
            if (baseElement) {
                head.insertBefore(node, baseElement);
            } else {
                head.appendChild(node);
            }
            currentlyAddingScript = null;

            return node;
        } else if (isWebWorker) {
            try {
                //In a web worker, use importScripts. This is not a very
                //efficient use of importScripts; it will block until
                //its script is downloaded and evaluated. However, if web workers
                //are in play, the expectation is that a build has been done so
                //that only one script needs to be loaded anyway. This may need
                //to be reevaluated if other use cases become common.
                importScripts(url);

                //Account for anonymous modules
                context.completeLoad(moduleName);
            } catch (e) {
                context.onError(makeError('importscripts',
                                'importScripts failed for ' +
                                    moduleName + ' at ' + url,
                                e,
                                [moduleName]));
            }
        }
    };
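
    //Illustrative sketch (nonce value hypothetical): onNodeCreated lets a page
    //decorate the script node before insertion, e.g. for CSP:
    //  require.config({
    //      onNodeCreated: function (node, config, moduleName, url) {
    //          node.setAttribute('nonce', 'abc123');
    //      }
    //  });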

    function getInteractiveScript() {
        if (interactiveScript && interactiveScript.readyState === 'interactive') {
            return interactiveScript;
        }

        eachReverse(scripts(), function (script) {
            if (script.readyState === 'interactive') {
                return (interactiveScript = script);
            }
        });
        return interactiveScript;
    }

    //Look for a data-main script attribute, which could also adjust the baseUrl.
    if (isBrowser && !cfg.skipDataMain) {
        //Figure out baseUrl. Get it from the script tag with require.js in it.
        eachReverse(scripts(), function (script) {
            //Set the 'head' where we can append children by
            //using the script's parent.
            if (!head) {
                head = script.parentNode;
            }

            //Look for a data-main attribute to set main script for the page
            //to load. If it is there, the path to data main becomes the
            //baseUrl, if it is not already set.
            dataMain = script.getAttribute('data-main');
            if (dataMain) {
                //Preserve dataMain in case it is a path (i.e. contains '?')
                mainScript = dataMain;

                //Set final baseUrl if there is not already an explicit one.
                if (!cfg.baseUrl) {
                    //Pull off the directory of data-main for use as the
                    //baseUrl.
                    src = mainScript.split('/');
                    mainScript = src.pop();
                    subPath = src.length ? src.join('/')  + '/' : './';

                    cfg.baseUrl = subPath;
                }

                //Strip off any trailing .js since mainScript is now
                //like a module name.
                mainScript = mainScript.replace(jsSuffixRegExp, '');

                //If mainScript is still a path, fall back to dataMain
                if (req.jsExtRegExp.test(mainScript)) {
                    mainScript = dataMain;
                }

                //Put the data-main script in the files to load.
                cfg.deps = cfg.deps ? cfg.deps.concat(mainScript) : [mainScript];

                return true;
            }
        });
    }
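
    //Illustrative sketch of the data-main convention handled above (paths
    //hypothetical):
    //  <script data-main="scripts/main" src="scripts/require.js"></script>
    //loads scripts/main.js and, absent an explicit baseUrl, sets baseUrl
    //to 'scripts/'.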

    /**
     * The function that handles definitions of modules. Differs from
     * require() in that a string for the module should be the first argument,
     * and the function to execute after dependencies are loaded should
     * return a value to define the module corresponding to the first argument's
     * name.
     */
    define = function (name, deps, callback) {
        var node, context;

        //Allow for anonymous modules
        if (typeof name !== 'string') {
            //Adjust args appropriately
            callback = deps;
            deps = name;
            name = null;
        }

        //This module may not have dependencies
        if (!isArray(deps)) {
            callback = deps;
            deps = null;
        }

        //If no name, and callback is a function, then figure out if it is a
        //CommonJS thing with dependencies.
        if (!deps && isFunction(callback)) {
            deps = [];
            //Remove comments from the callback string,
            //look for require calls, and pull them into the dependencies,
            //but only if there are function args.
            if (callback.length) {
                callback
                    .toString()
                    .replace(commentRegExp, '')
                    .replace(cjsRequireRegExp, function (match, dep) {
                        deps.push(dep);
                    });

                //May be a CommonJS thing even without require calls; it still
                //could use exports and module. Avoid doing exports and module
                //work, though, if it just needs require.
                //REQUIRES the function to expect the CommonJS variables in the
                //order listed below.
                deps = (callback.length === 1 ? ['require'] : ['require', 'exports', 'module']).concat(deps);
            }
        }

        //If in IE 6-8 and hit an anonymous define() call, do the interactive
        //work.
        if (useInteractive) {
            node = currentlyAddingScript || getInteractiveScript();
            if (node) {
                if (!name) {
                    name = node.getAttribute('data-requiremodule');
                }
                context = contexts[node.getAttribute('data-requirecontext')];
            }
        }

        //Always save off evaluating the def call until the script onload handler.
        //This allows multiple modules to be in a file without prematurely
        //tracing dependencies, and allows for anonymous module support,
        //where the module name is not known until the script onload event
        //occurs. If no context, use the global queue, and get it processed
        //in the onscript load callback.
        if (context) {
            context.defQueue.push([name, deps, callback]);
            context.defQueueMap[name] = true;
        } else {
            globalDefQueue.push([name, deps, callback]);
        }
    };
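
    //Illustrative sketch of the define() shapes normalized above (ids
    //hypothetical):
    //  define('app/a', ['app/b'], function (b) { return {}; }); //named
    //  define(['app/b'], function (b) { return {}; });          //anonymous
    //  define(function (require, exports, module) {             //CommonJS sugar
    //      var b = require('app/b');
    //      exports.twice = function (x) { return 2 * x; };
    //  });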

    define.amd = {
        jQuery: true
    };

    /**
     * Executes the text. Normally just uses eval, but can be modified
     * to use a better, environment-specific call. Only used for transpiling
     * loader plugins, not for plain JS modules.
     * @param {String} text the text to execute/evaluate.
     */
    req.exec = function (text) {
        /*jslint evil: true */
        return eval(text);
    };

    //Set up with config info.
    req(cfg);
}(this));
", - "headers": [ - [ - "content-type", - "application/javascript" - ] - ], - "ok": true, - "status": 200, - "status_text": "" - } - } - }, - "colab_type": "code", - "id": "k0j5zzpAPSFn", - "outputId": "cb5b1d88-054b-413e-d303-428e63bce694" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \u003cscript src=\"/static/components/requirejs/require.js\"\u003e\u003c/script\u003e\n", - " \u003cscript\u003e\n", - " requirejs.config({\n", - " paths: {\n", - " base: '/static/base',\n", - " \"d3\": \"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min\",\n", - " jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n", - " },\n", - " });\n", - " \u003c/script\u003e\n", - " " - ], - "text/plain": [ - "\u003cIPython.core.display.HTML object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \u003cspan style=\"user-select:none\"\u003e\n", - " Layer: \u003cselect id=\"layer\"\u003e\u003c/select\u003e\n", - " Attention: \u003cselect id=\"att_type\"\u003e\n", - " \u003coption value=\"all\"\u003eAll\u003c/option\u003e\n", - " \u003coption value=\"inp_inp\"\u003eInput - Input\u003c/option\u003e\n", - " \u003coption value=\"inp_out\"\u003eInput - Output\u003c/option\u003e\n", - " \u003coption value=\"out_out\"\u003eOutput - Output\u003c/option\u003e\n", - " \u003c/select\u003e\n", - " \u003c/span\u003e\n", - " \u003cdiv id='vis'\u003e\u003c/div\u003e\n" - ], - "text/plain": [ - "\u003cIPython.core.display.HTML object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window.attention = {\"all\": {\"att\": [[[[0.05334341153502464, 0.025828205049037933, 0.062369391322135925, 0.043252814561128616, 0.4045393764972687, 0.06697215139865875, 0.09001608937978745, 0.14983074367046356, 0.10384786874055862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11816457659006119, 0.03106253407895565, 0.01979171112179756, 0.16624291241168976, 0.3321376442909241, 0.020051123574376106, 0.08730963617563248, 0.18211135268211365, 0.04312858730554581, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05936884880065918, 0.02174757793545723, 0.016160180792212486, 0.010601435787975788, 0.43925121426582336, 0.03876951336860657, 0.19815810024738312, 0.07065817713737488, 0.14528508484363556, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15478025376796722, 0.16446512937545776, 0.0578744001686573, 0.21637752652168274, 0.03835854306817055, 0.09130414575338364, 0.11191156506538391, 0.08360221982002258, 0.08132638782262802, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2183060646057129, 0.1704275906085968, 0.0827711746096611, 0.1202380359172821, 0.05203341320157051, 0.05958092212677002, 0.12280035018920898, 0.09366822242736816, 0.08017415553331375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05084313824772835, 0.026207493618130684, 0.13631564378738403, 0.012270472943782806, 0.16236551105976105, 0.02548854425549507, 0.03909383341670036, 0.03172134608030319, 0.5156941413879395, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03615221381187439, 0.04799472168087959, 0.04255519434809685, 0.04762651398777962, 0.5117892622947693, 0.016304347664117813, 0.005770198069512844, 0.10897397249937057, 0.18283340334892273, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03243544325232506, 0.025252558290958405, 0.11733424663543701, 0.0250592939555645, 0.20289097726345062, 
0.08240236341953278, 0.18285907804965973, 0.011341268196702003, 0.3204246759414673, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.22355543076992035, 0.1260528564453125, 0.03741241991519928, 0.16813479363918304, 0.09858733415603638, 0.035831648856401443, 0.16361697018146515, 0.07236126810312271, 0.07444748282432556, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08996112644672394, 0.0921943336725235, 0.22672457993030548, 0.12702998518943787, 0.05907799303531647, 0.10712798684835434, 0.16789256036281586, 0.055181413888931274, 0.07481010258197784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9198169708251953, 0.0801829993724823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9412446618080139, 0.05875528231263161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8846490979194641, 0.10308036208152771, 0.012270578183233738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7461972832679749, 0.18569768965244293, 0.06810508668422699, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9307316541671753, 0.03309628367424011, 0.027538668364286423, 0.008633385412395, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4299372434616089, 0.16845084726810455, 0.2029547393321991, 0.19865721464157104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9335180521011353, 0.020782457664608955, 0.008113296702504158, 0.029529055580496788, 0.008057110011577606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5215166807174683, 0.16121163964271545, 0.19463112950325012, 0.09347883611917496, 0.029161658138036728, 0.0, 0.0, 0.0, 0.0, 0.0], [0.923790454864502, 0.01269624661654234, 0.004588128533214331, 0.020286502316594124, 0.018672045320272446, 0.019966628402471542, 0.0, 0.0, 0.0, 0.0, 0.26405569911003113, 0.04358615726232529, 0.10687251389026642, 0.1710020899772644, 0.4105237126350403, 0.0039598336443305016, 0.0, 0.0, 0.0, 0.0], [0.5214514136314392, 0.051599469035863876, 0.007387364283204079, 0.04305899888277054, 0.0632161945104599, 0.07775087654590607, 0.2355356514453888, 0.0, 0.0, 0.0, 0.29189321398735046, 0.19170531630516052, 0.11295431852340698, 0.08274418860673904, 0.12850242853164673, 0.09739833325147629, 0.09480219334363937, 0.0, 0.0, 0.0], [0.9122877717018127, 0.007671441417187452, 0.0012418286642059684, 0.005250561982393265, 0.001960531808435917, 0.032091617584228516, 0.03012256510555744, 0.009373520500957966, 0.0, 0.0, 0.3496137857437134, 0.03085259348154068, 0.0195528082549572, 0.45414459705352783, 0.09152030944824219, 0.008845902979373932, 0.02992299199104309, 0.01554702315479517, 0.0, 0.0], [0.012450892478227615, 0.0001350480888504535, 0.0001820741599658504, 0.0018266986589878798, 0.00022605709091294557, 0.0032795630395412445, 0.005876350682228804, 0.012136856094002724, 0.9638864398002625, 0.0, 0.4675538241863251, 0.03941410034894943, 0.05400091037154198, 0.17985978722572327, 0.20104949176311493, 0.030323797836899757, 0.010615098290145397, 0.015154700726270676, 0.002028239192441106, 0.0], [0.907938539981842, 0.003707215888425708, 0.003004483412951231, 0.0008324749651364982, 0.0015859504928812385, 0.008079104125499725, 0.010460118763148785, 0.005838368553668261, 0.038938846439123154, 0.019614921882748604, 0.053565241396427155, 0.029699191451072693, 0.0156599972397089, 0.016939852386713028, 0.04015244543552399, 0.21933501958847046, 0.1449035257101059, 0.4037321209907532, 0.019583676010370255, 0.056428998708724976]], [[0.040477100759744644, 0.20988762378692627, 0.4869004786014557, 0.03505674749612808, 0.0558856800198555, 0.025423096492886543, 
0.12231241166591644, 0.007062799762934446, 0.016993943601846695, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8996549844741821, 0.02599872276186943, 0.049097247421741486, 0.0040262676775455475, 0.0039152717217803, 0.0049644638784229755, 0.010553319938480854, 0.001352570834569633, 0.0004369009402580559, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.33065715432167053, 0.2687782049179077, 0.03312753140926361, 0.22958999872207642, 0.01851547136902809, 0.046473052352666855, 0.053183481097221375, 0.007113412953913212, 0.012561764568090439, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1589452475309372, 0.47470128536224365, 0.12878550589084625, 0.14158962666988373, 0.04442765936255455, 0.022274963557720184, 0.013780632056295872, 0.0024951419327408075, 0.012999956496059895, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2559169828891754, 0.033451542258262634, 0.15095548331737518, 0.024318046867847443, 0.10824166238307953, 0.03234097361564636, 0.36475417017936707, 0.012823408469557762, 0.017197895795106888, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.021462664008140564, 0.010474847629666328, 0.007213775999844074, 0.02227940410375595, 0.21737068891525269, 0.4960675537586212, 0.014628118835389614, 0.20502059161663055, 0.005482145119458437, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06734316051006317, 0.09532227367162704, 0.1127309575676918, 0.009542002342641354, 0.0678786113858223, 0.12933993339538574, 0.03809814900159836, 0.44453269243240356, 0.035212237387895584, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10458365827798843, 0.02846018597483635, 0.029760979115962982, 0.014774680137634277, 0.022077379748225212, 0.1553817093372345, 0.3539015054702759, 0.19523507356643677, 0.09582491964101791, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.021077070385217667, 0.010932122357189655, 0.05088815093040466, 0.028641115874052048, 0.0881260335445404, 0.12014731019735336, 0.3900885581970215, 0.09544514119625092, 0.1946544349193573, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02552945166826248, 0.05594164505600929, 0.045791901648044586, 0.093170166015625, 0.03584437444806099, 0.0969511866569519, 0.18585819005966187, 0.17433671653270721, 0.28657644987106323, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4050312936306, 0.5949686765670776, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5249735116958618, 0.4750264883041382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2333158701658249, 0.39531010389328003, 0.37137407064437866, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3563348054885864, 0.5701623558998108, 0.07350286096334457, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.52278733253479, 0.11893566697835922, 0.28584957122802734, 0.07242746651172638, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3398579955101013, 0.23167477548122406, 0.1957632154226303, 0.23270410299301147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23179638385772705, 0.09258762001991272, 0.103512242436409, 0.19472002983093262, 0.37738385796546936, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4351256191730499, 0.09737284481525421, 0.08845506608486176, 0.06574707478284836, 0.31329941749572754, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3839746117591858, 0.05338669568300247, 0.09416119009256363, 0.09689370542764664, 0.24871283769607544, 0.12287086993455887, 0.0, 0.0, 0.0, 0.0, 0.360861599445343, 0.02136792428791523, 0.005633710417896509, 0.009215844795107841, 0.15762653946876526, 
0.4452943205833435, 0.0, 0.0, 0.0, 0.0], [0.5838866233825684, 0.02439245954155922, 0.042716383934020996, 0.03342103213071823, 0.08018141984939575, 0.15234005451202393, 0.08306187391281128, 0.0, 0.0, 0.0, 0.009015758521854877, 0.0013937305193394423, 0.00017763266805559397, 0.00016997012426145375, 0.010879353620111942, 0.0024589570239186287, 0.9759047627449036, 0.0, 0.0, 0.0], [0.639571487903595, 0.016348807141184807, 0.038869310170412064, 0.02800355665385723, 0.0377902127802372, 0.0529697984457016, 0.07620508968830109, 0.11024164408445358, 0.0, 0.0, 0.014776602387428284, 0.0001805058855097741, 1.6896785382414237e-05, 0.0003442507586441934, 0.006220621056854725, 0.0012393802171573043, 0.9433164596557617, 0.033905431628227234, 0.0, 0.0], [0.5836893320083618, 0.011862898245453835, 0.02550557814538479, 0.009363977238535881, 0.0196645837277174, 0.018125057220458984, 0.07040998339653015, 0.2077602595090866, 0.053618304431438446, 0.0, 0.005810329224914312, 0.002043980173766613, 0.0003433740057516843, 0.001522325212135911, 0.0030212807469069958, 0.00817712489515543, 0.5456522107124329, 0.10564129799604416, 0.32778817415237427, 0.0], [0.49946048855781555, 0.04904361814260483, 0.04135226085782051, 0.015084759332239628, 0.018269173800945282, 0.020069265738129616, 0.05080949887633324, 0.09452320635318756, 0.06869905441999435, 0.14268863201141357, 0.3754594326019287, 0.030579065904021263, 0.028458155691623688, 0.035943739116191864, 0.28040432929992676, 0.0202159583568573, 0.0396210215985775, 0.05075624957680702, 0.13473623991012573, 0.0038258912973105907]], [[0.18220090866088867, 0.25508272647857666, 0.2721964120864868, 0.04886331781744957, 0.010257811285555363, 0.07344724237918854, 0.08866558223962784, 0.037977367639541626, 0.0313086174428463, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5722172260284424, 0.09567929804325104, 0.1448327898979187, 0.033306267112493515, 0.0031244128476828337, 0.020944159477949142, 0.012691132724285126, 0.061001092195510864, 0.05620381608605385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049244701862335205, 0.5266616344451904, 0.27518483996391296, 0.09334208071231842, 0.005858665332198143, 0.005467486567795277, 0.02565312758088112, 0.005746132228523493, 0.012841282412409782, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13445906341075897, 0.13356590270996094, 0.6041688919067383, 0.01878039538860321, 0.06342840194702148, 0.03677675500512123, 0.008389262482523918, 0.0002739423362072557, 0.00015757972141727805, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03273050859570503, 0.0697193592786789, 0.19719526171684265, 0.41500693559646606, 0.13721567392349243, 0.05743291601538658, 0.06517775356769562, 0.010865128599107265, 0.014656689018011093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031571000814437866, 0.014337136410176754, 0.06860436499118805, 0.09357307106256485, 0.10011686384677887, 0.07827721536159515, 0.5866308212280273, 0.011440092697739601, 0.015449290163815022, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006158333271741867, 0.001533387927338481, 0.05427416041493416, 0.005477452650666237, 0.02694696933031082, 0.8134917616844177, 0.02643686905503273, 0.050265438854694366, 0.015415593050420284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008847472257912159, 0.0066053420305252075, 0.036443497985601425, 0.021455924957990646, 0.019254589453339577, 0.11543811857700348, 0.1138116791844368, 0.20307059586048126, 0.4750728905200958, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0], [0.017603449523448944, 0.008448019623756409, 0.004260394722223282, 0.006066101603209972, 0.013470137491822243, 0.01876576989889145, 0.16350960731506348, 0.1980665624141693, 0.5698099732398987, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10490093380212784, 0.014168650843203068, 0.0247807614505291, 0.018330294638872147, 0.009348674677312374, 0.02287398651242256, 0.032268356531858444, 0.10571902245283127, 0.6676092147827148, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9956012964248657, 0.00439875153824687, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9630448818206787, 0.036955028772354126, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8920916318893433, 0.017498359084129333, 0.09041006118059158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8940342664718628, 0.015322646126151085, 0.09064316004514694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8103601336479187, 0.011479738168418407, 0.14884205162525177, 0.029318034648895264, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4866876006126404, 0.028273453935980797, 0.4569007158279419, 0.028138065710663795, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9073429107666016, 0.017702236771583557, 0.0008831396116875112, 0.017153160646557808, 0.05691858008503914, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7252220511436462, 0.10817205905914307, 0.07890959084033966, 0.017715180292725563, 0.06998112797737122, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7007134556770325, 0.00013011474220547825, 0.0017889889422804117, 0.00429273396730423, 0.20973503589630127, 0.08333952724933624, 0.0, 0.0, 0.0, 0.0, 0.8598019480705261, 0.012843498960137367, 0.014502299018204212, 0.004056263715028763, 0.10580158233642578, 0.0029942472465336323, 0.0, 0.0, 0.0, 0.0], [0.8020992279052734, 0.0005838978104293346, 0.0002877263759728521, 0.000665249943267554, 0.00924165453761816, 0.10947777330875397, 0.07764454185962677, 0.0, 0.0, 0.0, 0.8686293363571167, 0.024889284744858742, 0.013860221020877361, 0.00703870365396142, 0.07120370119810104, 0.003939351066946983, 0.010439489968121052, 0.0, 0.0, 0.0], [0.936653733253479, 0.00026242269086651504, 0.0004762547614518553, 0.000683068297803402, 0.0005867508007213473, 0.008624686859548092, 0.044821251183748245, 0.00789186917245388, 0.0, 0.0, 0.8572709560394287, 0.018014011904597282, 0.008267350494861603, 0.0022140766959637403, 0.1038530021905899, 0.004275611136108637, 0.0009780752006918192, 0.005126776173710823, 0.0, 0.0], [0.638530433177948, 0.00012756754586007446, 2.6267471184837632e-05, 0.035790614783763885, 0.00038457714254036546, 0.0026843701489269733, 0.0740678533911705, 0.21536435186862946, 0.03302408382296562, 0.0, 0.35013046860694885, 0.0037752145435661077, 0.0071558705531060696, 0.01608894392848015, 0.6097922325134277, 0.002463925164192915, 0.0005387101555243134, 0.005540961865335703, 0.004513624589890242, 0.0], [0.9069857597351074, 0.0010905838571488857, 0.0003166680980939418, 0.0021527763456106186, 0.00019805191550403833, 0.0004849489778280258, 0.025774035602808, 0.02642407827079296, 0.01662513054907322, 0.01994791068136692, 0.1888049989938736, 0.12293454259634018, 0.5947631597518921, 0.009457849897444248, 0.07291270792484283, 0.008950368501245975, 0.0004109511792194098, 0.000914009811822325, 0.0006959570455364883, 0.00015547229850199074]], [[0.2071455419063568, 0.637531578540802, 0.06835082173347473, 0.011966697871685028, 0.0017193991225212812, 0.04911382868885994, 0.009478496387600899, 0.008040529675781727, 0.00665308628231287, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0], [0.07411027699708939, 0.15093472599983215, 0.2656005620956421, 0.05758262053132057, 0.05194409564137459, 0.23625947535037994, 0.019166678190231323, 0.04010465368628502, 0.10429693013429642, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1540999412536621, 0.10598444193601608, 0.22474077343940735, 0.32441702485084534, 0.1116243302822113, 0.054135363548994064, 0.008848286233842373, 0.004088098648935556, 0.012061581946909428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019440434873104095, 0.00560638727620244, 0.0035774046555161476, 0.0888679027557373, 0.7120485901832581, 0.14891275763511658, 0.011600993573665619, 0.008666431531310081, 0.0012791723711416125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08580154180526733, 0.02444172091782093, 0.08060747385025024, 0.05198557302355766, 0.2700504660606384, 0.34216371178627014, 0.11280739307403564, 0.006445358972996473, 0.02569655328989029, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0424385629594326, 0.029667967930436134, 0.006252861116081476, 0.020168066024780273, 0.03000665083527565, 0.2812231779098511, 0.49279165267944336, 0.09351769089698792, 0.003933228086680174, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006467411294579506, 0.0076894015073776245, 0.008325580507516861, 0.0010907554533332586, 0.01040297094732523, 0.19462232291698456, 0.013263629749417305, 0.24681615829467773, 0.5113216042518616, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.028696376830339432, 0.014982450753450394, 0.011884906329214573, 0.0011242942418903112, 0.01692844182252884, 0.12885364890098572, 0.028225399553775787, 0.6451764106750488, 0.12412811070680618, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16117365658283234, 0.06794824451208115, 0.06173194944858551, 0.00451233983039856, 0.05306624248623848, 0.0510348416864872, 0.04402391240000725, 0.12432018667459488, 0.4321887195110321, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1690559983253479, 0.043453093618154526, 0.036818861961364746, 0.017293656244874, 0.11775903403759003, 0.07970321178436279, 0.043801818042993546, 0.06849095970392227, 0.4236232340335846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9964158535003662, 0.0035840808413922787, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.91131192445755, 0.08868805319070816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.603236198425293, 0.29069802165031433, 0.10606581717729568, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.786292314529419, 0.09286607056856155, 0.1208416074514389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7401933073997498, 0.005742713809013367, 0.18690980970859528, 0.06715414673089981, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1722075194120407, 0.10747934877872467, 0.1462225317955017, 0.5740904808044434, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9087624549865723, 0.0078224902972579, 0.003505129599943757, 0.0673881471157074, 0.012521738186478615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1893281787633896, 0.1733204573392868, 0.06838839501142502, 0.47577211260795593, 0.09319086372852325, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7394620180130005, 0.0234938096255064, 0.009907918982207775, 0.01616108976304531, 0.1237591803073883, 0.08721596747636795, 0.0, 0.0, 0.0, 0.0, 0.08935888856649399, 0.012517428956925869, 0.017112966626882553, 0.08479276299476624, 0.7640082240104675, 0.03220977261662483, 0.0, 0.0, 0.0, 0.0], [0.9526587724685669, 0.007287254091352224, 0.0013716809917241335, 
0.0023222684394568205, 0.007607423700392246, 0.009167732670903206, 0.01958492584526539, 0.0, 0.0, 0.0, 0.824190616607666, 0.008810147643089294, 0.002143737394362688, 0.002297793049365282, 0.11996792256832123, 0.005709697026759386, 0.036880046129226685, 0.0, 0.0, 0.0], [0.9270981550216675, 0.004809631034731865, 0.0030887839384377003, 0.005205564666539431, 0.018441975116729736, 0.006030889227986336, 0.03003735840320587, 0.0052877976559102535, 0.0, 0.0, 0.1513449102640152, 0.015725232660770416, 0.02784004621207714, 0.01800909824669361, 0.6534391641616821, 0.016422629356384277, 0.09054289758205414, 0.026676079258322716, 0.0, 0.0], [0.603268563747406, 0.009098237380385399, 0.00021995518181938678, 0.07179546356201172, 0.0017328117974102497, 0.01055157370865345, 0.020978767424821854, 0.2736198902130127, 0.008734744042158127, 0.0, 0.1625923067331314, 0.016224535182118416, 0.06514906883239746, 0.003223034320399165, 0.6737184524536133, 0.014129054732620716, 0.036937959492206573, 0.023035621270537376, 0.004990031942725182, 0.0], [0.6497007608413696, 0.0906025841832161, 0.0100435521453619, 0.007925360463559628, 0.013416239991784096, 0.0018666544929146767, 0.02140365168452263, 0.08128199726343155, 0.04188578948378563, 0.08187359571456909, 0.06836045533418655, 0.01236770860850811, 0.008784784935414791, 0.014186863787472248, 0.09790214896202087, 0.046204064041376114, 0.1703491061925888, 0.1878211945295334, 0.0703599750995636, 0.32366377115249634]], [[0.03085354156792164, 0.12322185933589935, 0.13651973009109497, 0.050716523081064224, 0.2999139726161957, 0.09802427887916565, 0.06620478630065918, 0.0782310962677002, 0.11631430685520172, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06789751350879669, 0.058182138949632645, 0.3129631578922272, 0.04353875666856766, 0.09142065048217773, 0.10271093249320984, 0.026392055675387383, 0.09630800783634186, 0.2005866914987564, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07152411341667175, 0.3454192876815796, 0.11299439519643784, 0.18012462556362152, 0.07151429355144501, 0.052652161568403244, 0.0567985400557518, 0.09459780901670456, 0.014374655671417713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10420235246419907, 0.21845531463623047, 0.19832336902618408, 0.022119704633951187, 0.13572701811790466, 0.07722532749176025, 0.0508468933403492, 0.045597679913043976, 0.14750221371650696, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07030870020389557, 0.10706955939531326, 0.02791348285973072, 0.02260597050189972, 0.12725059688091278, 0.07336997240781784, 0.26662203669548035, 0.16957008838653564, 0.13528966903686523, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05156806856393814, 0.04327721148729324, 0.07664787024259567, 0.06931594759225845, 0.1889398992061615, 0.09515503793954849, 0.07227510958909988, 0.2641449272632599, 0.13867592811584473, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02184019424021244, 0.11184182018041611, 0.36672860383987427, 0.013787303119897842, 0.07600502669811249, 0.0389828234910965, 0.040494974702596664, 0.12485849112272263, 0.20546066761016846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013738485053181648, 0.05187288299202919, 0.03463537245988846, 0.03627979755401611, 0.048659998923540115, 0.02440205216407776, 0.07256433367729187, 0.024731382727622986, 0.6931155323982239, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02671198360621929, 0.4013687074184418, 0.01132842618972063, 0.14022575318813324, 0.026275552809238434, 0.08107840269804001, 
0.04189194366335869, 0.25432130694389343, 0.0167979933321476, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14228780567646027, 0.07866450399160385, 0.08390624076128006, 0.09396661072969437, 0.087954580783844, 0.14498625695705414, 0.13517630100250244, 0.1169552430510521, 0.11610251665115356, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9857779741287231, 0.014221975579857826, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.961704432964325, 0.038295578211545944, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9197340607643127, 0.07413885742425919, 0.0061270855367183685, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.37462106347084045, 0.2157517969608307, 0.40962719917297363, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8673564195632935, 0.016403868794441223, 0.1017053872346878, 0.014534366317093372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.48521965742111206, 0.031020229682326317, 0.3760664165019989, 0.10769358277320862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.044595908373594284, 0.010755550116300583, 0.002565854461863637, 0.9345642328262329, 0.007518457714468241, 0.0, 0.0, 0.0, 0.0, 0.0, 0.914044201374054, 0.004715718794614077, 0.006151301320642233, 0.005079128313809633, 0.07000966370105743, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4605148434638977, 0.007289387751370668, 0.009601963683962822, 0.08598940074443817, 0.4091304838657379, 0.027473902329802513, 0.0, 0.0, 0.0, 0.0, 0.060511741787195206, 0.006127620115876198, 0.00728148128837347, 0.013585635460913181, 0.9084653854370117, 0.004028240218758583, 0.0, 0.0, 0.0, 0.0], [0.8714936971664429, 0.002528996206820011, 0.0021269593853503466, 0.0052809687331318855, 0.02593054249882698, 0.07010670751333237, 0.022532090544700623, 0.0, 0.0, 0.0, 0.23348243534564972, 0.03748093172907829, 0.055222347378730774, 0.014132470823824406, 0.27614685893058777, 0.017582375556230545, 0.3659524619579315, 0.0, 0.0, 0.0], [0.507957398891449, 0.003823956474661827, 0.004157013725489378, 0.018131878226995468, 0.06916838884353638, 0.047881923615932465, 0.2798653542995453, 0.06901402771472931, 0.0, 0.0, 0.06461911648511887, 0.003781915409490466, 0.002705940278246999, 0.016099220141768456, 0.8774597644805908, 0.012668337672948837, 0.0088069261983037, 0.013858767226338387, 0.0, 0.0], [0.4575899839401245, 0.005646431352943182, 0.0004441867640707642, 0.03129462152719498, 0.014414624311029911, 0.0058625745587050915, 0.09207130968570709, 0.34311652183532715, 0.04955975338816643, 0.0, 0.05451222136616707, 0.014412143267691135, 0.00208102585747838, 0.011283651925623417, 0.02552390843629837, 0.02239326573908329, 0.031104939058423042, 0.20777365565299988, 0.630915105342865, 0.0], [0.8105311393737793, 0.0010255038505420089, 0.0001402802881784737, 0.0005781117943115532, 0.00122542935423553, 0.000594198820181191, 0.02804729714989662, 0.01081023644655943, 0.13665232062339783, 0.010395429097115993, 0.5451503992080688, 0.014764615334570408, 0.2503703534603119, 0.037022024393081665, 0.0935375839471817, 0.022694993764162064, 0.0037449353840202093, 0.0053339023143053055, 0.007315538357943296, 0.020065704360604286]], [[0.02165721170604229, 0.018354326486587524, 0.6383510828018188, 0.042513273656368256, 0.10956817120313644, 0.10717540234327316, 0.030344119295477867, 0.015826348215341568, 0.01621006615459919, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4647374749183655, 0.07284841686487198, 0.28081396222114563, 0.014013433828949928, 0.03169411048293114, 0.02214456908404827, 0.058711059391498566, 
0.036629818379879, 0.01840737834572792, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07372704148292542, 0.12858736515045166, 0.4501189887523651, 0.054217785596847534, 0.07096204906702042, 0.05748127028346062, 0.06541819125413895, 0.04703349620103836, 0.05245373025536537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04684445261955261, 0.019098779186606407, 0.008431704714894295, 0.0010175607167184353, 0.9129327535629272, 0.004866998642683029, 0.006678053177893162, 8.096762758214027e-05, 4.903498847852461e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08239725232124329, 0.02813413366675377, 0.16611848771572113, 0.1532817929983139, 0.07408729940652847, 0.10856874287128448, 0.047752734273672104, 0.02563621662557125, 0.31402355432510376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17959792912006378, 0.02262653037905693, 0.10724494606256485, 0.022216446697711945, 0.1862414926290512, 0.14705143868923187, 0.15912717580795288, 0.15293282270431519, 0.02296125516295433, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.038375359028577805, 0.0038853511214256287, 0.06201936677098274, 0.005828780122101307, 0.22059503197669983, 0.36631014943122864, 0.020396992564201355, 0.20976856350898743, 0.07282061129808426, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014258276671171188, 0.005652762018144131, 0.025611618533730507, 0.15294744074344635, 0.06760217249393463, 0.2498260736465454, 0.1669282466173172, 0.2265811711549759, 0.09059228003025055, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15833799540996552, 0.1228356659412384, 0.10147804021835327, 0.0284584891051054, 0.27955442667007446, 0.06763719022274017, 0.08874277770519257, 0.1152903363108635, 0.037665050476789474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09844867885112762, 0.0919492095708847, 0.028445947915315628, 0.03726689890027046, 0.035665158182382584, 0.06817072629928589, 0.29930955171585083, 0.09819743037223816, 0.2425464242696762, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8512031435966492, 0.14879685640335083, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9904667735099792, 0.009533224627375603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10041537135839462, 0.8953256011009216, 0.0042589944787323475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9818503260612488, 0.007338901981711388, 0.010810752399265766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6295948624610901, 0.2121732085943222, 0.10306572169065475, 0.055166181176900864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9738979935646057, 0.007647394668310881, 0.015154722146689892, 0.0032999368850141764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9503376483917236, 0.007425909396260977, 0.0019253676291555166, 0.025024304166436195, 0.015286784619092941, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6611008644104004, 0.04138284549117088, 0.1119912639260292, 0.0262944046407938, 0.15923058986663818, 0.0, 0.0, 0.0, 0.0, 0.0], [0.24298420548439026, 0.06981680542230606, 0.030552756041288376, 0.020666545256972313, 0.46177101135253906, 0.1742086559534073, 0.0, 0.0, 0.0, 0.0, 0.9380988478660583, 0.005562208592891693, 0.01078465860337019, 0.004562946502119303, 0.033130958676338196, 0.007860423997044563, 0.0, 0.0, 0.0, 0.0], [0.8132306933403015, 0.003601218806579709, 0.01019350253045559, 0.009439423680305481, 0.040081463754177094, 0.07570415735244751, 0.04774952307343483, 0.0, 0.0, 0.0, 0.9377894997596741, 0.003691342193633318, 
0.002771170577034354, 0.0017416415503248572, 0.04246653988957405, 0.002464305842295289, 0.009075501933693886, 0.0, 0.0, 0.0], [0.6454712152481079, 0.006356438156217337, 0.006696825381368399, 0.0020169378258287907, 0.11416922509670258, 0.11139311641454697, 0.07912010699510574, 0.03477614000439644, 0.0, 0.0, 0.9083399176597595, 0.005597027484327555, 0.02609928511083126, 0.005710097029805183, 0.017865832895040512, 0.0029857312329113483, 0.002900469582527876, 0.030501706525683403, 0.0, 0.0], [0.22032444179058075, 0.0006508066435344517, 0.006827942095696926, 0.028858821839094162, 0.0022757677361369133, 0.006474251858890057, 0.09447979182004929, 0.6212162375450134, 0.018891895189881325, 0.0, 0.8338009119033813, 0.00436164066195488, 0.006190306507050991, 0.0008050849428400397, 0.015337309800088406, 0.00863864365965128, 0.010715007781982422, 0.1143304780125618, 0.005820483900606632, 0.0], [0.03250038996338844, 0.0005526043241843581, 2.807211239996832e-05, 0.00014761221245862544, 0.00482193985953927, 7.781770545989275e-05, 0.00014718669990543276, 0.0008632297394797206, 0.959712028503418, 0.0011490467004477978, 0.9085996747016907, 0.00676243519410491, 0.02013525180518627, 0.009278967045247555, 0.02104269526898861, 0.009343095123767853, 0.0009470531367696822, 0.0018253516172990203, 0.003784958738833666, 0.018280424177646637]], [[0.02519470639526844, 0.006357265170663595, 0.14269335567951202, 0.023629529401659966, 0.3124701976776123, 0.13565225899219513, 0.2595662772655487, 0.07959114015102386, 0.014845297671854496, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04550129547715187, 0.011541971005499363, 0.1165909469127655, 0.02512240968644619, 0.01843150518834591, 0.05711649730801582, 0.44489097595214844, 0.033205363899469376, 0.24759893119335175, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13528011739253998, 0.06777236610651016, 0.14429129660129547, 0.04697401076555252, 0.1738385707139969, 0.014099549502134323, 0.38417065143585205, 0.01158357597887516, 0.02199004776775837, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.21356959640979767, 0.1638900637626648, 0.10595463216304779, 0.06925727427005768, 0.167257159948349, 0.04259340837597847, 0.10967854410409927, 0.03570139408111572, 0.09209771454334259, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20140984654426575, 0.04755665361881256, 0.15174560248851776, 0.11619894206523895, 0.21928974986076355, 0.07600340992212296, 0.05828682705760002, 0.10010629147291183, 0.029402663931250572, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.024259669706225395, 0.02116699516773224, 0.21201731264591217, 0.019622934982180595, 0.4893963038921356, 0.021304504945874214, 0.16948339343070984, 0.022949064150452614, 0.01979990489780903, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022248759865760803, 0.01183647196739912, 0.0633181631565094, 0.029095010831952095, 0.07090882211923599, 0.4614315629005432, 0.020150773227214813, 0.18720205128192902, 0.1338084638118744, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003461656626313925, 0.01603432185947895, 0.009874427691102028, 0.014947548508644104, 0.2953553795814514, 0.3502987027168274, 0.08878874033689499, 0.036094941198825836, 0.18514421582221985, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.005101516842842102, 0.022985950112342834, 0.007523353211581707, 0.026773063465952873, 0.01009095273911953, 0.014858697541058064, 0.15149906277656555, 0.028601571917533875, 0.7325656414031982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0], [0.12995873391628265, 0.07769863307476044, 0.02032659947872162, 0.13720010221004486, 0.011713794432580471, 0.054615918546915054, 0.23920413851737976, 0.13190706074237823, 0.19737498462200165, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9700191020965576, 0.029980869963765144, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.972051739692688, 0.027948210015892982, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7072298526763916, 0.2173422873020172, 0.07542789727449417, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7552067041397095, 0.17251533269882202, 0.0722779706120491, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5017270445823669, 0.10517530888319016, 0.32087045907974243, 0.07222715020179749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6455309987068176, 0.23265127837657928, 0.10187581926584244, 0.01994187943637371, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.39005738496780396, 0.2261916995048523, 0.1838584840297699, 0.10916081070899963, 0.09073163568973541, 0.0, 0.0, 0.0, 0.0, 0.0, 0.470674991607666, 0.26442891359329224, 0.14268451929092407, 0.03363766148686409, 0.08857394009828568, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11122927069664001, 0.04386316239833832, 0.023478534072637558, 0.07375308126211166, 0.5692906379699707, 0.17838534712791443, 0.0, 0.0, 0.0, 0.0, 0.6457618474960327, 0.011289404705166817, 0.008832731284201145, 0.01570025272667408, 0.2588561475276947, 0.059559762477874756, 0.0, 0.0, 0.0, 0.0], [0.16762810945510864, 0.030268238857388496, 0.015392551198601723, 0.05242612585425377, 0.21519990265369415, 0.34948840737342834, 0.16959665715694427, 0.0, 0.0, 0.0, 0.4916176497936249, 0.07200384140014648, 0.0701020285487175, 0.019148536026477814, 0.0833231583237648, 0.12199999392032623, 0.14180481433868408, 0.0, 0.0, 0.0], [0.15348000824451447, 0.03554287180304527, 0.008979924954473972, 0.07115276902914047, 0.08698276430368423, 0.24143245816230774, 0.28553345799446106, 0.11689584702253342, 0.0, 0.0, 0.11119699478149414, 0.002801541704684496, 0.0021932011004537344, 0.0016493132570758462, 0.06827285885810852, 0.22499483823776245, 0.5049597024917603, 0.08393163233995438, 0.0, 0.0], [0.09456975758075714, 0.010759694501757622, 0.0067994119599461555, 0.01042863354086876, 0.05627141892910004, 0.11228546500205994, 0.14361944794654846, 0.3204572796821594, 0.2448090761899948, 0.0, 0.13208742439746857, 0.0035411729477345943, 0.0015305017586797476, 0.002489483682438731, 0.06612236052751541, 0.213859423995018, 0.5324232578277588, 0.03503565117716789, 0.012910734862089157, 0.0], [0.057867951691150665, 0.02229062095284462, 0.016399098560214043, 0.02521427348256111, 0.047808028757572174, 0.03428687900304794, 0.05170976370573044, 0.19979508221149445, 0.41991233825683594, 0.12471600621938705, 0.20209012925624847, 0.05223073810338974, 0.03088257648050785, 0.036374326795339584, 0.014660456217825413, 0.03045688569545746, 0.03597142919898033, 0.16862399876117706, 0.022359324619174004, 0.40635016560554504]], [[0.21207179129123688, 0.11920439451932907, 0.4251355528831482, 0.014464439824223518, 0.20776884257793427, 0.01428140513598919, 0.0027938869316130877, 0.001743048895150423, 0.002536489861086011, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.046175818890333176, 0.026793524622917175, 0.8552185297012329, 0.04517081379890442, 0.010388500988483429, 0.004191457759588957, 0.0036751439329236746, 0.0013485046802088618, 0.007037981878966093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013186579570174217, 
0.020899420604109764, 0.6900137662887573, 0.0480119027197361, 0.15360434353351593, 0.02344118244946003, 0.03952033817768097, 0.0038994532078504562, 0.007422822527587414, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006273405160754919, 0.00015674144378863275, 0.000751359446439892, 0.00447711581364274, 0.9859057664871216, 0.002212332095950842, 0.00014360185014083982, 4.957199053023942e-05, 2.9913859179941937e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.001047183177433908, 0.0003636489564087242, 0.009283728897571564, 0.016805388033390045, 0.42387446761131287, 0.4776095747947693, 0.06253702938556671, 0.005590841174125671, 0.002888289513066411, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0018647151300683618, 0.0002549054042901844, 2.6050107408082113e-05, 2.586200753285084e-05, 0.0024472770746797323, 0.006814199965447187, 0.9776560664176941, 0.010138182900846004, 0.000773087958805263, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.047241877764463425, 0.006076885852962732, 0.04534892365336418, 0.00081661093281582, 0.087706059217453, 0.41394293308258057, 0.21876952052116394, 0.17005810141563416, 0.0100388890132308, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0019138919888064265, 0.006189406383782625, 0.010115097276866436, 8.508542669005692e-05, 0.008424345403909683, 0.003492203773930669, 0.13495568931102753, 0.4890870749950409, 0.34573695063591003, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.016032341867685318, 0.005025702994316816, 0.009520799852907658, 0.0008855267078615725, 0.026489384472370148, 0.0020503124687820673, 0.032939448952674866, 0.09461060166358948, 0.8124459385871887, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25683313608169556, 0.02960006147623062, 0.11211041361093521, 0.09736908972263336, 0.17546677589416504, 0.032068025320768356, 0.017857572063803673, 0.025635067373514175, 0.25305992364883423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9535994529724121, 0.04640045389533043, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9218347668647766, 0.0781652107834816, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8665578961372375, 0.09402694553136826, 0.03941517323255539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4189925193786621, 0.4865715503692627, 0.09443587809801102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8201385140419006, 0.07587680220603943, 0.05075912922620773, 0.053225547075271606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.48251789808273315, 0.34758540987968445, 0.13321316242218018, 0.036683470010757446, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6245242953300476, 0.093341164290905, 0.11281723529100418, 0.1092497780919075, 0.06006752699613571, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8504839539527893, 0.033341050148010254, 0.053517427295446396, 0.012789242900907993, 0.049868300557136536, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5755861401557922, 0.0864969864487648, 0.10001320391893387, 0.12654373049736023, 0.06871193647384644, 0.04264802858233452, 0.0, 0.0, 0.0, 0.0, 0.4515743553638458, 0.03267433121800423, 0.019386781379580498, 0.024256065487861633, 0.17900733649730682, 0.29310107231140137, 0.0, 0.0, 0.0, 0.0], [0.6500274538993835, 0.06470640748739243, 0.047299426048994064, 0.08855419605970383, 0.06197808310389519, 0.04487667977809906, 0.04255769029259682, 0.0, 0.0, 0.0, 0.5910289883613586, 0.0027754076290875673, 0.004533650353550911, 0.0023315453436225653, 0.08002334088087082, 0.06913208961486816, 0.2501751184463501, 
0.0, 0.0, 0.0], [0.5771223902702332, 0.0491044707596302, 0.09411156177520752, 0.06903567165136337, 0.04109871760010719, 0.06523709744215012, 0.06637011468410492, 0.03792000934481621, 0.0, 0.0, 0.1626552939414978, 0.0011573631782084703, 0.00017211545491591096, 0.0007665579323656857, 0.03241841867566109, 0.34369325637817383, 0.2890424132347107, 0.17009468376636505, 0.0, 0.0], [0.4695849120616913, 0.017787985503673553, 0.06290572881698608, 0.06516575813293457, 0.09894091635942459, 0.03647425398230553, 0.051347069442272186, 0.08907806128263474, 0.10871540009975433, 0.0, 0.10835989564657211, 0.0007107920246198773, 0.00030798258376307786, 0.005807099863886833, 0.04662986099720001, 0.1659584492444992, 0.3522194027900696, 0.30094781517982483, 0.019058646634221077, 0.0], [0.18501408398151398, 0.040740884840488434, 0.10466982424259186, 0.07660976052284241, 0.17033715546131134, 0.05819392204284668, 0.0898737907409668, 0.09184892475605011, 0.10470453649759293, 0.0780070349574089, 0.5449283123016357, 0.01310307253152132, 0.008020865730941296, 0.006764447782188654, 0.16009773313999176, 0.06950337439775467, 0.0024397175293415785, 0.014089844189584255, 0.013654321432113647, 0.1673980951309204]]], [[[0.10487863421440125, 0.7106320858001709, 0.1635318249464035, 0.011256101541221142, 0.0012767312582582235, 0.00310636218637228, 0.0013001860352233052, 0.0012553841806948185, 0.002762428717687726, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.021650908514857292, 0.0030605364590883255, 0.6595932245254517, 0.2987315356731415, 0.012945608235895634, 0.0028472936246544123, 7.557096250820905e-05, 0.00029089683084748685, 0.0008047237643040717, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.014272261410951614, 0.040512338280677795, 0.8595607280731201, 0.038314104080200195, 0.037397123873233795, 0.006795509252697229, 0.001303989440202713, 0.001011757180094719, 0.0008321924251504242, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031783342361450195, 0.007319662719964981, 0.7663278579711914, 0.0010118860518559813, 0.1672297865152359, 0.02513650804758072, 0.000853335193824023, 0.0002817189379129559, 5.600590884569101e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002136597875505686, 0.00037253598566167057, 0.07588302344083786, 0.2252500057220459, 0.33551687002182007, 0.35751965641975403, 0.0027331046294420958, 0.00018122239271178842, 0.0004068210837431252, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0004353485128376633, 0.0003557991876732558, 0.0003262429090682417, 0.003819868667051196, 0.33603885769844055, 0.2681770920753479, 0.3838857412338257, 0.0068349516950547695, 0.00012614508159458637, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [6.71677480568178e-05, 3.9912600186653435e-05, 0.00047830803669057786, 5.937727837590501e-05, 0.0014537296956405044, 0.6413838863372803, 0.29047340154647827, 0.06565171480178833, 0.0003929881495423615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00047039391938596964, 0.0007891620043665171, 0.0007817292353138328, 0.0010076714679598808, 0.00965806283056736, 0.003733346238732338, 0.35330116748809814, 0.5722718238830566, 0.05798657611012459, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006178696174174547, 0.009340841323137283, 0.0005589249776676297, 0.005146770738065243, 0.0033258567564189434, 0.0016933922888711095, 0.06414961069822311, 0.3291752338409424, 0.5804308652877808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006624103523790836, 0.001978900283575058, 
0.0081730792298913, 0.0030846702866256237, 0.0018904987955465913, 0.0014340116176754236, 0.005187559872865677, 0.029854312539100647, 0.9417726993560791, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10875418037176132, 0.15107707679271698, 0.07560893893241882, 0.11182637512683868, 0.051575273275375366, 0.1800614595413208, 0.13901139795780182, 0.11257244646549225, 0.06951297074556351, 0.0, 0.03246883675456047, 0.020431363955140114, 0.06294436007738113, 0.08282972872257233, 0.047490958124399185, 0.03976213559508324, 0.01868664100766182, 0.5054241418838501, 0.18996170163154602, 0.0], [0.04530828073620796, 0.11530135571956635, 0.03132164478302002, 0.12301183491945267, 0.01339547149837017, 0.009322633035480976, 0.0069213854148983955, 0.181557297706604, 0.47386014461517334, 0.0, 0.0334412157535553, 0.45350977778434753, 0.23828978836536407, 0.07703227549791336, 0.02545342594385147, 0.019935714080929756, 0.007961008697748184, 0.08864670246839523, 0.05572996661067009, 0.0], [0.08671615272760391, 0.21926835179328918, 0.11249969899654388, 0.05250205472111702, 0.044286634773015976, 0.006910341326147318, 0.004434189759194851, 0.00961831770837307, 0.4637643098831177, 0.0, 0.008816813118755817, 0.009350132197141647, 0.09488566964864731, 0.022458655759692192, 0.001578008639626205, 0.01768183708190918, 0.0012928039068356156, 0.7889453768730164, 0.05499071627855301, 0.0], [0.016148164868354797, 0.08668603748083115, 0.1414848268032074, 0.024200299754738808, 0.018711188808083534, 0.02537006139755249, 0.017450006678700447, 0.039331331849098206, 0.6306182146072388, 0.0, 0.0037117439787834883, 0.00603569345548749, 0.019362367689609528, 0.06632085889577866, 0.02251342497766018, 0.048607613891363144, 0.00711278198286891, 0.7890322804450989, 0.03730323165655136, 0.0], [0.024489276111125946, 0.03301851078867912, 0.03003605268895626, 0.03562680631875992, 0.06981870532035828, 0.022592445835471153, 0.025447512045502663, 0.03545365110039711, 0.7235170006752014, 0.0, 0.0017165049212053418, 0.0031809706706553698, 0.00569736585021019, 0.027958940714597702, 0.001130971242673695, 0.006313299294561148, 0.004051794297993183, 0.9312260150909424, 0.018723946064710617, 0.0], [0.05760658532381058, 0.08793947100639343, 0.053903114050626755, 0.0679689273238182, 0.007038408424705267, 0.007889931090176105, 0.010035911574959755, 0.019540006294846535, 0.6880777478218079, 0.0, 0.0028915719594806433, 0.007050157990306616, 0.004614752251654863, 0.0017270235111936927, 0.0016248916508629918, 0.06901240348815918, 0.005150379613041878, 0.13293159008026123, 0.7749972939491272, 0.0], [0.045610494911670685, 0.042210742831230164, 0.14248158037662506, 0.03233090415596962, 0.03048519603908062, 0.011738738045096397, 0.014284060336649418, 0.006383211817592382, 0.6744750738143921, 0.0, 0.005032604560256004, 0.005055313929915428, 0.0030569147784262896, 0.0010687477188184857, 0.012304573319852352, 0.013984610326588154, 0.3489484190940857, 0.012370014563202858, 0.5981789827346802, 0.0], [0.096277616918087, 0.030696624889969826, 0.10220203548669815, 0.04915016517043114, 0.047845132648944855, 0.05814794450998306, 0.06954183429479599, 0.028650736436247826, 0.5174878835678101, 0.0, 0.0019784842152148485, 0.009333183988928795, 0.005381024908274412, 0.0002465381403453648, 0.0013898308388888836, 0.005461550783365965, 0.0012134313583374023, 0.001065099611878395, 0.9739308953285217, 0.0], [0.009306053631007671, 0.02153283730149269, 0.009718294255435467, 0.005953253246843815, 0.011703923344612122, 0.017902903258800507, 0.011090915650129318, 
0.01645584963262081, 0.8963360786437988, 0.0, 0.005657540168613195, 0.006781480740755796, 0.00696007814258337, 0.0009338636882603168, 0.02429838851094246, 0.03842600807547569, 0.00286328443326056, 0.03579647094011307, 0.8782829642295837, 0.0], [0.009895006194710732, 0.026821313425898552, 0.16079027950763702, 0.01761648990213871, 0.01726638339459896, 0.08361288905143738, 0.039622098207473755, 0.14411716163158417, 0.5002583861351013, 0.0, 0.007395321968942881, 0.012293249368667603, 0.006963892374187708, 0.00022730379714630544, 0.0005401583621278405, 0.005707587581127882, 0.0028992195148020983, 0.0027063635643571615, 0.9612669944763184, 0.0]], [[0.17277710139751434, 0.13871003687381744, 0.020699918270111084, 0.04190761595964432, 0.17760643362998962, 0.1702892780303955, 0.16168300807476044, 0.10000763088464737, 0.01631900854408741, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9987638592720032, 0.0011447033612057567, 1.5495901607209817e-05, 2.3805538096333123e-10, 1.1166920899086108e-07, 4.81009180930414e-07, 2.3257289285538718e-05, 3.4320622944505885e-05, 1.812833215808496e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029870687052607536, 0.9668734669685364, 0.0031853404361754656, 3.7420595617732033e-06, 1.0481591772304455e-07, 4.711453893690987e-09, 4.051101996083162e-07, 1.359390239485947e-06, 6.518688314827159e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.9839180569979362e-05, 0.0008244949858635664, 0.9990562796592712, 6.778111855965108e-05, 2.14482715819031e-05, 5.3428358959273226e-11, 7.202954205309808e-11, 7.697720239008277e-11, 1.422941551254553e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [9.680035873316228e-05, 4.205659934086725e-05, 0.0021876851096749306, 0.9926192164421082, 0.0050464412197470665, 7.330636890401365e-06, 4.7689670878980905e-08, 8.238330573284713e-10, 9.979119397485192e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.136659183335723e-06, 6.750806136324172e-08, 8.17252839624416e-06, 0.008817464113235474, 0.9640147089958191, 0.027066770941019058, 8.771067950874567e-05, 3.571775764044105e-09, 3.5257423647294672e-09, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.115869043947896e-07, 1.0059281407848175e-08, 1.3136859422502312e-07, 9.641905052149013e-08, 0.001335342414677143, 0.9957214593887329, 0.0029362423811107874, 7.136273325158982e-06, 1.1521567699901425e-08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.561131961760111e-06, 2.727877870256634e-07, 8.369554507225985e-07, 1.214864764342849e-09, 4.873449597653234e-06, 0.024909861385822296, 0.9680997133255005, 0.006879042834043503, 0.00010210835171164945, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00021467455371748656, 9.040503209689632e-05, 3.369562909938395e-05, 1.9265097961351785e-08, 9.727973520057276e-07, 2.4095537810353562e-05, 0.0040859803557395935, 0.8618475794792175, 0.1337023377418518, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.289768872287823e-06, 6.284429400693625e-05, 0.0001214230724144727, 2.809870807141124e-07, 1.092972157223926e-09, 1.0671180605825725e-09, 1.2438744079190656e-06, 0.024907555431127548, 0.9749038219451904, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0543275885283947, 0.01742306910455227, 0.05347726121544838, 0.18824619054794312, 0.09003543108701706, 0.08433128148317337, 0.1953076422214508, 0.206686869263649, 0.11016455292701721, 0.0, 0.02470340207219124, 0.02512546442449093, 0.11353036016225815, 0.35132649540901184, 0.20412008464336395, 
0.027150044217705727, 0.015305055305361748, 0.05760098248720169, 0.1811380535364151, 0.0], [0.00859006680548191, 0.02184058353304863, 0.02418440766632557, 0.03131486475467682, 0.03273439407348633, 0.06774082779884338, 0.1731010377407074, 0.09275981038808823, 0.5477339029312134, 0.0, 0.009894105605781078, 0.02192404493689537, 0.3007009029388428, 0.13983333110809326, 0.03682582825422287, 0.08908118307590485, 0.27657952904701233, 0.026430398225784302, 0.09873086214065552, 0.0], [0.02145911566913128, 0.046526145190000534, 0.014734850265085697, 0.026213468983769417, 0.04904777929186821, 0.08567024767398834, 0.13810616731643677, 0.03392839804291725, 0.5843138694763184, 0.0, 0.011459765024483204, 0.044317521154880524, 0.5289616584777832, 0.19549138844013214, 0.03426412120461464, 0.017797794193029404, 0.030613277107477188, 0.0163635965436697, 0.12073105573654175, 0.0], [0.019245177507400513, 0.01515401341021061, 0.027409562841057777, 0.0068243746645748615, 0.07997982203960419, 0.0921224057674408, 0.04510754346847534, 0.04373685643076897, 0.670420229434967, 0.0, 0.011578483507037163, 0.0029169816989451647, 0.00455811433494091, 0.01625976897776127, 0.018393559381365776, 0.11749742925167084, 0.32938554883003235, 0.41049671173095703, 0.08891336619853973, 0.0], [0.04381020739674568, 0.06711422652006149, 0.07609888166189194, 0.021496189758181572, 0.05042967572808266, 0.15614424645900726, 0.11071597784757614, 0.14296749234199524, 0.3312230408191681, 0.0, 0.0033444140572100878, 0.0011373214656487107, 0.0019445078214630485, 0.02781311236321926, 0.0049105980433523655, 0.05221953243017197, 0.09222303330898285, 0.3644186854362488, 0.45198866724967957, 0.0], [0.04100082442164421, 0.030313873663544655, 0.032653506845235825, 0.0695231482386589, 0.12672685086727142, 0.12515434622764587, 0.08855390548706055, 0.05835743993520737, 0.4277162253856659, 0.0, 0.002199131529778242, 0.0006913270917721093, 0.002652444876730442, 0.017487458884716034, 0.18746966123580933, 0.39171290397644043, 0.26989367604255676, 0.017002178356051445, 0.11089123785495758, 0.0], [0.14112897217273712, 0.06592341512441635, 0.06986766308546066, 0.06311382353305817, 0.12678426504135132, 0.04950721934437752, 0.08025017380714417, 0.03467738255858421, 0.36874714493751526, 0.0, 0.01051913108676672, 0.003755246289074421, 0.0008555634994991124, 0.002675057854503393, 0.0025919810868799686, 0.02418649010360241, 0.018060903996229172, 0.003447937313467264, 0.9339075684547424, 0.0], [0.02841436117887497, 0.022568009793758392, 0.014519155025482178, 0.019271234050393105, 0.018120555207133293, 0.036434635519981384, 0.014109926298260689, 0.24622198939323425, 0.6003400683403015, 0.0, 0.029951948672533035, 0.006547479424625635, 0.030934682115912437, 0.0036260345950722694, 0.1420958936214447, 0.19529034197330475, 0.1491098254919052, 0.009723717346787453, 0.43272000551223755, 0.0], [0.05730762332677841, 0.07724729180335999, 0.030861826613545418, 0.04063780978322029, 0.08539344370365143, 0.029541905969381332, 0.02964094467461109, 0.028206804767251015, 0.6211622953414917, 0.0, 0.017757408320903778, 0.006832967512309551, 0.028906390070915222, 0.00921954121440649, 0.054915353655815125, 0.028632348403334618, 0.03646676614880562, 0.01978384144604206, 0.7974854707717896, 0.0], [0.20915710926055908, 0.193747878074646, 0.11181499063968658, 0.07680925726890564, 0.04479793831706047, 0.03787367418408394, 0.04819086939096451, 0.11330965161323547, 0.1642986238002777, 0.0, 0.06588920205831528, 0.05552517622709274, 0.18546447157859802, 0.007839588448405266, 0.020484987646341324, 
0.01699826307594776, 0.01947665773332119, 0.017759086564183235, 0.6105626821517944, 0.0]], [[0.058097392320632935, 0.00935883168131113, 0.04822169989347458, 0.0048278868198394775, 0.191309854388237, 0.28154584765434265, 0.09391050785779953, 0.24126385152339935, 0.07146408408880234, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10414423793554306, 0.027566324919462204, 0.021727869287133217, 0.033647697418928146, 0.026882247999310493, 0.17782779037952423, 0.05685214698314667, 0.45095938444137573, 0.10039239376783371, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.44215551018714905, 0.049670565873384476, 0.014098896645009518, 0.029011834412813187, 0.01834075152873993, 0.1358453929424286, 0.04072042554616928, 0.2330295443534851, 0.03712712228298187, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10425814986228943, 0.06979154050350189, 0.036334071308374405, 0.028995294123888016, 0.015532439574599266, 0.1330128014087677, 0.063407763838768, 0.23157192766666412, 0.3170958459377289, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3384562134742737, 0.055937401950359344, 0.038792647421360016, 0.00819220207631588, 0.03063569962978363, 0.09386011958122253, 0.07227522879838943, 0.30926018953323364, 0.05259038880467415, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3519401550292969, 0.1823827177286148, 0.06509842723608017, 0.030452275648713112, 0.08377533406019211, 0.09469012171030045, 0.04247477278113365, 0.11751312017440796, 0.03167306259274483, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3634622097015381, 0.14048337936401367, 0.08374395966529846, 0.038946691900491714, 0.03473563492298126, 0.06442954391241074, 0.019375532865524292, 0.22685663402080536, 0.027966352179646492, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.18070067465305328, 0.04645215719938278, 0.0992647334933281, 0.005799622740596533, 0.47514480352401733, 0.12094692885875702, 0.030788421630859375, 0.025236092507839203, 0.015666494145989418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5453059673309326, 0.10054859519004822, 0.01722547970712185, 0.06704734265804291, 0.007780902087688446, 0.07263857871294022, 0.022086072713136673, 0.1394840031862259, 0.027883058413863182, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15028028190135956, 0.17163224518299103, 0.06043723225593567, 0.10140684247016907, 0.10512865334749222, 0.06778015196323395, 0.06512691080570221, 0.23085294663906097, 0.04735487326979637, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.038908280432224274, 0.07760688662528992, 0.062413811683654785, 0.0023113787174224854, 0.0021746077109128237, 0.015095214359462261, 0.003646473865956068, 0.038165315985679626, 0.759678065776825, 0.0, 0.14391662180423737, 0.11156481504440308, 0.4162432849407196, 0.07845085859298706, 0.04067624360322952, 0.016916701570153236, 0.012291320599615574, 0.10670017451047897, 0.07323983311653137, 0.0], [0.015742339193820953, 0.029524141922593117, 0.0550379604101181, 0.16926467418670654, 0.035933610051870346, 0.03279981389641762, 0.03188418969511986, 0.5383173227310181, 0.09149592369794846, 0.0, 0.0171683169901371, 0.03512553498148918, 0.4936983287334442, 0.18945446610450745, 0.020571058616042137, 0.011469473131000996, 0.04002959281206131, 0.08968089520931244, 0.10280223935842514, 0.0], [0.022741766646504402, 0.013864121399819851, 0.06161126494407654, 0.06985131651163101, 0.03954875469207764, 0.02864447981119156, 0.036658816039562225, 0.05774570629000664, 0.6693336963653564, 0.0, 
0.2093620002269745, 0.11281707882881165, 0.25891542434692383, 0.14515942335128784, 0.0042000748217105865, 0.006485591176897287, 0.005525505635887384, 0.14364667236804962, 0.11388827115297318, 0.0], [0.06077639013528824, 0.053226571530103683, 0.05544588342308998, 0.08368532359600067, 0.04779139161109924, 0.028960514813661575, 0.03463221713900566, 0.42419588565826416, 0.21128588914871216, 0.0, 0.0109701631590724, 0.0007525839027948678, 0.011503712274134159, 0.03920656442642212, 0.2449047565460205, 0.048431187868118286, 0.12996943295001984, 0.4081973731517792, 0.10606419295072556, 0.0], [0.03320460394024849, 0.07872876524925232, 0.0791814923286438, 0.008506255224347115, 0.010383618995547295, 0.021636927500367165, 0.009444555267691612, 0.026183925569057465, 0.7327298521995544, 0.0, 0.004995591007173061, 0.0001893905719043687, 0.0009439413552172482, 0.03207648918032646, 0.08267047256231308, 0.015983520075678825, 0.02033340558409691, 0.8191123604774475, 0.023694908246397972, 0.0], [0.14095324277877808, 0.17195045948028564, 0.04960065335035324, 0.02801741287112236, 0.02789357118308544, 0.0246508177369833, 0.027228642255067825, 0.008449538610875607, 0.521255612373352, 0.0, 0.0022357299458235502, 0.000793653482105583, 0.0010144039988517761, 0.2958794832229614, 0.3394852876663208, 0.07495945692062378, 0.06856833398342133, 0.06118563562631607, 0.15587811172008514, 0.0], [0.01678302139043808, 0.02193976752460003, 0.13912786543369293, 0.05168221518397331, 0.06239692494273186, 0.008615943603217602, 0.037501659244298935, 0.02482585795223713, 0.6371266841888428, 0.0, 0.0020441634114831686, 0.00032311712857335806, 0.0006899640429764986, 0.03996479511260986, 0.38782593607902527, 0.05503879860043526, 0.24750953912734985, 0.004524962045252323, 0.26207876205444336, 0.0], [0.03396642208099365, 0.07778684049844742, 0.18657010793685913, 0.11281172931194305, 0.019890569150447845, 0.012303605675697327, 0.0494060292840004, 0.11448060721158981, 0.39278414845466614, 0.0, 0.0012333561899140477, 0.0002747838443610817, 0.0023864947725087404, 0.10253860056400299, 0.4721597135066986, 0.04103615880012512, 0.03782818093895912, 0.026908699423074722, 0.31563398241996765, 0.0], [0.02684134803712368, 0.03310805931687355, 0.163743257522583, 0.014529252424836159, 0.10077258199453354, 0.044357266277074814, 0.04152251034975052, 0.10173188894987106, 0.4733937382698059, 0.0, 0.004791810177266598, 0.0015037101693451405, 0.004669447895139456, 0.38809871673583984, 0.13379721343517303, 0.024320820346474648, 0.03647102415561676, 0.013309511356055737, 0.3930378258228302, 0.0], [0.01862592063844204, 0.022009190171957016, 0.028925148770213127, 0.006837732624262571, 0.006956242956221104, 0.010202805511653423, 0.015325144864618778, 0.11640346795320511, 0.7747144103050232, 0.0, 0.00849083997309208, 0.003579143201932311, 0.0033037925604730844, 0.006032468285411596, 0.017621049657464027, 0.0234503336250782, 0.018282314762473106, 0.02657976746559143, 0.8926602602005005, 0.0]], [[0.11086989939212799, 0.14517885446548462, 0.17419463396072388, 0.060936953872442245, 0.08783368766307831, 0.11005676537752151, 0.03251044824719429, 0.07983692735433578, 0.19858187437057495, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16660544276237488, 0.29352903366088867, 0.1008867621421814, 0.023942291736602783, 0.15022507309913635, 0.06581585109233856, 0.02344084158539772, 0.05208655819296837, 0.12346797436475754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1683349758386612, 0.22478938102722168, 0.06976605206727982, 0.1032773107290268, 
0.16255290806293488, 0.08890064060688019, 0.03925151377916336, 0.023706944659352303, 0.11942004412412643, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.19914905726909637, 0.1368866264820099, 0.178489089012146, 0.11241752654314041, 0.06187256798148155, 0.0768556222319603, 0.01627686619758606, 0.07274915277957916, 0.14530348777770996, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08000901341438293, 0.20181676745414734, 0.21235129237174988, 0.05340588092803955, 0.12758778035640717, 0.11278047412633896, 0.06906574964523315, 0.08596791326999664, 0.05701539292931557, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14153669774532318, 0.10432923585176468, 0.09881750494241714, 0.08603313565254211, 0.10391980409622192, 0.06189347058534622, 0.06772381067276001, 0.08503933250904083, 0.25070688128471375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06525713205337524, 0.07869093865156174, 0.11366366595029831, 0.044226594269275665, 0.05455174669623375, 0.23646420240402222, 0.09933798015117645, 0.1198185384273529, 0.1879890412092209, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09450254589319229, 0.027017319574952126, 0.06480545550584793, 0.10929621011018753, 0.11382008343935013, 0.17441418766975403, 0.11898359656333923, 0.06495486199855804, 0.23220552504062653, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07681684195995331, 0.0671391412615776, 0.0905177965760231, 0.06064317002892494, 0.06652072072029114, 0.09855856746435165, 0.07360702753067017, 0.13956283032894135, 0.3266339898109436, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12179998308420181, 0.07977079600095749, 0.08405954390764236, 0.1456507444381714, 0.14551174640655518, 0.07862778753042221, 0.09882251918315887, 0.14300917088985443, 0.1027478501200676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0830092504620552, 0.0839436799287796, 0.10106679797172546, 0.11154499650001526, 0.045070260763168335, 0.1284436285495758, 0.1161414161324501, 0.19574469327926636, 0.1350351870059967, 0.0, 0.8417463898658752, 0.05951714888215065, 0.012198105454444885, 0.03180553764104843, 0.02919766865670681, 0.0096508814021945, 0.003031272441148758, 0.0009100366733036935, 0.011942943558096886, 0.0], [0.0006529411766678095, 0.0018492193194106221, 0.018439743667840958, 0.004895282443612814, 0.0036929987836629152, 0.05041775107383728, 0.03271673619747162, 0.4425412714481354, 0.4447941780090332, 0.0, 0.00569154741242528, 0.979739785194397, 0.012030904181301594, 0.0001143000990850851, 9.368032624479383e-05, 0.0008171445806510746, 0.00012590458209160715, 0.0005024938145652413, 0.0008843241375871003, 0.0], [0.015919672325253487, 0.02172437310218811, 0.013682822696864605, 0.028371846303343773, 0.017258556559681892, 0.014516759663820267, 0.033475372940301895, 0.45419326424598694, 0.40085726976394653, 0.0, 0.005223963409662247, 0.005622355733066797, 0.9848889708518982, 0.002582893241196871, 0.0003334738139528781, 0.0005618981667794287, 3.256636409787461e-05, 0.00024550766102038324, 0.0005086653982289135, 0.0], [0.006064589135348797, 0.006147248670458794, 0.06902536749839783, 0.011021673679351807, 0.0062199062667787075, 0.17622654139995575, 0.00982236210256815, 0.46262383460998535, 0.25284844636917114, 0.0, 0.0032260464504361153, 0.007557107135653496, 0.0651315227150917, 0.6094849109649658, 0.008782745338976383, 0.2748804986476898, 0.015592943876981735, 0.008143502287566662, 0.007200630847364664, 0.0], [0.018328940495848656, 0.034908927977085114, 0.027539005503058434, 
0.04494883120059967, 0.03695090860128403, 0.18224696815013885, 0.04204700142145157, 0.09570277482271194, 0.5173265337944031, 0.0, 0.01683628372848034, 0.0020552987698465586, 0.00783018209040165, 0.008005303330719471, 0.0011927365558221936, 0.9284406900405884, 0.03478293865919113, 0.00030738895293325186, 0.0005490221083164215, 0.0], [0.06838149577379227, 0.025893883779644966, 0.06412170827388763, 0.11039282381534576, 0.12848982214927673, 0.09953469038009644, 0.09056522697210312, 0.12723064422607422, 0.28538966178894043, 0.0, 0.0004254023951943964, 7.111614831956103e-05, 0.0008891545585356653, 1.880968193290755e-05, 6.570573896169662e-05, 0.9941434860229492, 0.0025632327888160944, 9.733852493809536e-06, 0.0018130606040358543, 0.0], [0.07893572002649307, 0.0734885111451149, 0.06503137946128845, 0.04291535168886185, 0.08502060174942017, 0.04846649244427681, 0.07035838067531586, 0.14812934398651123, 0.38765427470207214, 0.0, 7.936867405078374e-06, 1.8136512153432705e-05, 4.5569290705316234e-06, 1.071940641850233e-05, 3.808495648627286e-06, 0.0008168917265720665, 0.9974388480186462, 1.4373016711033415e-05, 0.0016848900122568011, 0.0], [0.007445929106324911, 0.004103729501366615, 0.05411284416913986, 0.006074799690395594, 0.07146289199590683, 0.5494692921638489, 0.05009504780173302, 0.058794084936380386, 0.1984413117170334, 0.0, 0.0014213839313015342, 0.003971228376030922, 0.008488249033689499, 2.0282970581320114e-05, 8.774230809649453e-05, 0.030342059209942818, 0.010436602868139744, 0.013138609007000923, 0.9320940375328064, 0.0], [0.0037151367869228125, 0.005083263851702213, 0.02171880006790161, 0.01245985459536314, 0.012914983555674553, 0.14437292516231537, 0.026943473145365715, 0.17420484125614166, 0.5985866785049438, 0.0, 9.058997966349125e-05, 0.0009022729936987162, 0.0017266678623855114, 1.3629892237077001e-05, 0.000727150880265981, 0.002379553159698844, 0.0010508937994018197, 0.012508089654147625, 0.9806011319160461, 0.0], [0.02579679898917675, 0.0645768865942955, 0.03225725144147873, 0.044467855244874954, 0.04297630116343498, 0.06060377135872841, 0.030930038541555405, 0.03278812766075134, 0.6656030416488647, 0.0, 0.0003429521748330444, 0.001905322540551424, 0.0005013775080442429, 1.1471392099338118e-05, 0.00017356597527395934, 0.0029742273036390543, 0.003938945475965738, 0.028075864538550377, 0.9620763063430786, 0.0]], [[0.0261031873524189, 0.9575563073158264, 0.006272038444876671, 0.0037288309540599585, 0.0038619006518274546, 0.0007324732141569257, 0.0005133527447469532, 0.0003637235495261848, 0.0008679544553160667, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02134888991713524, 0.08473973721265793, 0.6753177642822266, 0.028721673414111137, 0.14432094991207123, 0.027568204328417778, 0.0057298606261610985, 0.004451636224985123, 0.007801060564815998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03883299231529236, 0.030284319072961807, 0.5620493292808533, 0.09062989801168442, 0.17362907528877258, 0.08253934979438782, 0.010801085270941257, 0.00978847872465849, 0.0014453904004767537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.002180949319154024, 0.003013473702594638, 0.16569769382476807, 0.008050205186009407, 0.7580646276473999, 0.061441101133823395, 0.001020166208036244, 0.0001067533012246713, 0.0004249440098647028, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.004150479566305876, 0.00034606645931489766, 0.3802972435951233, 0.06855826079845428, 0.29045602679252625, 0.1767650991678238, 0.06603583693504333, 0.0014808314153924584, 
0.011909942142665386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006170187145471573, 0.0012396957026794553, 0.0354800671339035, 0.0032299698796123266, 0.03240001201629639, 0.5543311238288879, 0.30418315529823303, 0.051339369267225266, 0.01162647269666195, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0035115755163133144, 0.0011483307462185621, 0.017956364899873734, 0.003783614607527852, 0.030611976981163025, 0.3673596978187561, 0.20627115666866302, 0.3506667912006378, 0.01869054324924946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0021685126703232527, 0.0006909942603670061, 0.010240452364087105, 0.01958688348531723, 0.004634156823158264, 0.11485372483730316, 0.04815557599067688, 0.7050773501396179, 0.0945921242237091, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049201104789972305, 0.02397306263446808, 0.02337191067636013, 0.31066185235977173, 0.06433572620153427, 0.12544430792331696, 0.0786852017045021, 0.25179895758628845, 0.07252778857946396, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.010841209441423416, 0.0041772774420678616, 0.01548130251467228, 0.036074474453926086, 0.033387064933776855, 0.08192819356918335, 0.04784044623374939, 0.10195028781890869, 0.668319821357727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13460709154605865, 0.15298102796077728, 0.06546170264482498, 0.14220191538333893, 0.11837887763977051, 0.09888823330402374, 0.10630416870117188, 0.08867054432630539, 0.09250646829605103, 0.0, 0.23634016513824463, 0.09021607041358948, 0.12040459364652634, 0.01354933436959982, 0.0019137230701744556, 0.009001325815916061, 0.028688833117485046, 0.2612648904323578, 0.23862121999263763, 0.0], [0.9316296577453613, 0.016095036640763283, 0.0020372711587697268, 0.0019596514757722616, 2.8437656510504894e-05, 6.708989531034604e-05, 0.0004955903859809041, 3.0113247703411616e-05, 0.047657083719968796, 0.0, 0.2307557761669159, 0.2812652289867401, 0.30346915125846863, 0.05031246319413185, 0.006193350534886122, 0.01668362505733967, 0.012607063166797161, 0.07951408624649048, 0.019199388101696968, 0.0], [0.043201129883527756, 0.9419298768043518, 0.0003410913050174713, 0.003313146298751235, 7.506452675443143e-06, 1.9570916265365668e-05, 2.5470235414104536e-05, 2.1080213628010824e-05, 0.011141069233417511, 0.0, 0.29960742592811584, 0.20819564163684845, 0.27825382351875305, 0.007396433036774397, 0.0007608149899169803, 0.0260151494294405, 0.012685009278357029, 0.12934625148773193, 0.03773954138159752, 0.0], [3.7581870856229216e-05, 0.00022979748609941453, 0.9982534646987915, 8.70372386998497e-05, 5.87535805607331e-06, 2.5239218302886002e-05, 6.597588708245894e-06, 2.193619138779468e-06, 0.001352491439320147, 0.0, 0.035675279796123505, 0.035874202847480774, 0.007117687724530697, 0.018771182745695114, 0.010206644423305988, 0.06527784466743469, 0.03775254264473915, 0.7770709991455078, 0.012253628112375736, 0.0], [0.0019612079486250877, 0.011641290038824081, 0.010358362458646297, 0.8346317410469055, 0.00641160923987627, 0.0007435380248352885, 0.0018172020791098475, 7.255822129081935e-05, 0.1323624849319458, 0.0, 0.012017791159451008, 0.0028583300299942493, 0.0024127706419676542, 0.002610970288515091, 0.001820205245167017, 0.04092223569750786, 0.016621166840195656, 0.9115477800369263, 0.009188669733703136, 0.0], [4.077299308846705e-05, 0.00016088274423964322, 3.1180113637674367e-06, 5.9685276937671006e-05, 6.661444786004722e-06, 0.0006764131248928607, 5.4107837058836594e-05, 0.9797272086143494, 0.01927126571536064, 
0.0, 0.03447290509939194, 0.013388306833803654, 0.08488336205482483, 0.015237652696669102, 0.19176845252513885, 0.3472833037376404, 0.10885429382324219, 0.192628413438797, 0.011483324691653252, 0.0], [2.7792530090664513e-06, 1.1777839063142892e-05, 1.0386434951215051e-05, 0.0006807934259995818, 0.00028749846387654543, 0.9563493728637695, 2.4335316993528977e-05, 0.001297356327995658, 0.041335828602313995, 0.0, 0.0005363536183722317, 0.0001964608090929687, 0.0017719777533784509, 0.003164003835991025, 0.27662715315818787, 0.05286016687750816, 0.648875892162323, 0.007890382781624794, 0.00807751715183258, 0.0], [0.00033864984288811684, 0.00016234541544690728, 0.00011107163300039247, 7.639558316441253e-05, 9.851753566181287e-05, 0.00046863980242051184, 0.9855522513389587, 0.00012009339843643829, 0.013071970082819462, 0.0, 0.001257028547115624, 0.00020761204359587282, 0.0024441492278128862, 0.003374723019078374, 0.9062062501907349, 0.0712839737534523, 0.0032159662805497646, 0.009974849410355091, 0.0020355340093374252, 0.0], [0.001446103909984231, 0.0026176422834396362, 0.0005430445889942348, 0.5833504796028137, 0.08298782259225845, 0.01277364045381546, 0.008405186235904694, 0.028461067005991936, 0.2794148921966553, 0.0, 0.0008205634076148272, 0.00019305139721836895, 0.002098840195685625, 0.004588909447193146, 0.9688709378242493, 0.01628950424492359, 0.0038415545132011175, 0.0016231476329267025, 0.0016735766548663378, 0.0], [8.301706202473724e-07, 1.612889263924444e-06, 3.859615389956161e-06, 0.0015496612759307027, 0.9884966611862183, 0.0003321043332107365, 1.1829011782538146e-05, 3.7258676002238644e-06, 0.00959983840584755, 0.0, 0.03610469028353691, 0.046298399567604065, 0.04650943726301193, 0.02111651562154293, 0.06683006882667542, 0.37146270275115967, 0.174205482006073, 0.15773150324821472, 0.07974111288785934, 0.0]], [[0.005738695617765188, 0.0068999892100691795, 0.4274883270263672, 0.08288666605949402, 0.1445126235485077, 0.04382907599210739, 0.10957401990890503, 0.05347184091806412, 0.1255987584590912, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0025263649877160788, 0.00471830926835537, 0.13454590737819672, 0.4177793860435486, 0.28839975595474243, 0.029358303174376488, 0.017654288560152054, 0.0047735795378685, 0.10024390369653702, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009192855097353458, 0.007133236154913902, 0.03149157017469406, 0.1856081485748291, 0.5691666603088379, 0.07386670261621475, 0.029819192364811897, 0.03683711960911751, 0.05688462406396866, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00297820963896811, 0.0015070328954607248, 0.0025649494491517544, 0.0011051844339817762, 0.04088710993528366, 0.1953955888748169, 0.34000417590141296, 0.3367410898208618, 0.07881659269332886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003951869439333677, 0.009354526177048683, 0.007010620087385178, 0.0025927696842700243, 0.09962604194879532, 0.10909298062324524, 0.4455967843532562, 0.15358439087867737, 0.16918975114822388, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0038829154800623655, 0.0036434896755963564, 0.006399825215339661, 0.000760377966798842, 0.010139851830899715, 0.038725122809410095, 0.10014155507087708, 0.48370444774627686, 0.35260239243507385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.001297087874263525, 0.0014563009608536959, 0.013839880004525185, 0.0004286184557713568, 0.012207024730741978, 0.028704902157187462, 0.046600911766290665, 0.26406532526016235, 0.6313998103141785, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0033481158316135406, 0.0038099782541394234, 0.0031049775425344706, 0.00033546099439263344, 0.0031272985506802797, 0.008788534440100193, 0.021183660253882408, 0.12157405912876129, 0.8347280025482178, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3364367187023163, 0.17456969618797302, 0.051038213074207306, 0.006790165323764086, 0.024106895551085472, 0.0694134384393692, 0.02184627763926983, 0.061508405953645706, 0.25429028272628784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10536088049411774, 0.07750789821147919, 0.0850178673863411, 0.08725376427173615, 0.2586125433444977, 0.16756391525268555, 0.054291605949401855, 0.030132828280329704, 0.13425879180431366, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03624086081981659, 0.008591840974986553, 0.01890810765326023, 0.010947922244668007, 0.5211313366889954, 0.04890615865588188, 0.13394898176193237, 0.08554741740226746, 0.13577744364738464, 0.0, 0.03425053879618645, 0.026130978018045425, 0.3080751299858093, 0.027706336230039597, 0.12989944219589233, 0.29902005195617676, 0.0305496696382761, 0.03879137709736824, 0.1055762991309166, 0.0], [0.09101090580224991, 0.15663929283618927, 0.2008313536643982, 0.13744188845157623, 0.16349081695079803, 0.01479706447571516, 0.04576689749956131, 0.05515507981181145, 0.1348666250705719, 0.0, 0.004509713500738144, 0.02305547706782818, 0.939035952091217, 0.006188178434967995, 0.020785806700587273, 0.00040150884888134897, 0.00018676061881706119, 0.00013036451127845794, 0.005706076975911856, 0.0], [0.10898119956254959, 0.19741322100162506, 0.12774543464183807, 0.07097428292036057, 0.033309608697891235, 0.016726871952414513, 0.019306309521198273, 0.09155051410198212, 0.3339925706386566, 0.0, 0.0005241778562776744, 0.009561678394675255, 0.988527774810791, 2.2495760276797228e-05, 4.7274414100684226e-05, 0.00013538387429434806, 4.543165232462343e-06, 6.27172994427383e-05, 0.001113483915105462, 0.0], [0.051247891038656235, 0.06952031701803207, 0.3243081271648407, 0.04820195212960243, 0.05462171137332916, 0.04280935227870941, 0.03801479935646057, 0.07710513472557068, 0.2941707372665405, 0.0, 0.06551901996135712, 0.0800878182053566, 0.06342226266860962, 0.00974376779049635, 0.5160938501358032, 0.02204274758696556, 0.004013149533420801, 0.0735243633389473, 0.1655530482530594, 0.0], [0.22540897130966187, 0.04426601901650429, 0.13483746349811554, 0.09052211791276932, 0.036632657051086426, 0.06078784167766571, 0.09962243586778641, 0.04597063735127449, 0.2619517743587494, 0.0, 0.0013552415184676647, 0.0004213388019707054, 0.002606122987344861, 0.0010090378345921636, 0.24638326466083527, 0.6568374633789062, 0.01604411192238331, 0.04806208983063698, 0.027281243354082108, 0.0], [0.08315062522888184, 0.10649015009403229, 0.15254046022891998, 0.0728936716914177, 0.10388997197151184, 0.04998103529214859, 0.0675109326839447, 0.17524446547031403, 0.18829864263534546, 0.0, 0.0002145337639376521, 0.00018796027870848775, 0.0008407118148170412, 0.0029629908967763186, 0.28427600860595703, 0.6725634336471558, 0.023870857432484627, 0.00339014851488173, 0.011693413369357586, 0.0], [0.09407053142786026, 0.04335644096136093, 0.04757237061858177, 0.023308007046580315, 0.14141318202018738, 0.017728488892316818, 0.02331509254872799, 0.07266414165496826, 0.5365718007087708, 0.0, 0.0009873382514342666, 0.0005485343281179667, 6.628077971981838e-05, 0.0029302756302058697, 0.23183174431324005, 0.05256076529622078, 0.5701138377189636, 0.005792138632386923, 
[Large serialized numeric block elided: nested arrays of floating-point values, apparently stored notebook cell output (attention-weight matrices — rows of 20 probabilities per query position, zero-padded beyond the attended positions) rather than human-readable content.]
0.07394693046808243, 0.046145662665367126, 0.09495888650417328, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0020061242394149303, 0.0010488562984392047, 0.0021137045696377754, 0.03403143212199211, 0.040159616619348526, 0.4656003415584564, 0.16990402340888977, 0.16164875030517578, 0.12348736822605133, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0023888982832431793, 0.0010238748509436846, 0.0031129145063459873, 0.00400560162961483, 0.005227341782301664, 0.050918273627758026, 0.28773385286331177, 0.5181463956832886, 0.12744267284870148, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0057381619699299335, 0.0037375285755842924, 0.006655727047473192, 0.0010085925459861755, 0.005980721674859524, 0.02943945676088333, 0.05893365666270256, 0.6100658774375916, 0.2784405052661896, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.003593636676669121, 0.0024473541416227818, 0.002264569513499737, 0.00914584007114172, 0.0013253247598186135, 0.010908454656600952, 0.07958614826202393, 0.12585432827472687, 0.7648744583129883, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.031058229506015778, 0.02174283377826214, 0.012145284563302994, 0.010826506651937962, 0.01352943666279316, 0.021966811269521713, 0.055832888931035995, 0.11603516340255737, 0.7168627977371216, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20383700728416443, 0.06762446463108063, 0.042199794203042984, 0.021983252838253975, 0.11625738441944122, 0.013579235412180424, 0.025292381644248962, 0.08914806693792343, 0.4200783669948578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008738831616938114, 0.010689073242247105, 0.010104849003255367, 0.025418052449822426, 0.008787600323557854, 0.018541773781180382, 0.01414045225828886, 0.009587875567376614, 0.8939914107322693, 0.0, 0.3675236701965332, 0.22013956308364868, 0.3048599064350128, 0.045011524111032486, 0.013697491027414799, 0.012050136923789978, 0.009531261399388313, 0.0020223394967615604, 0.025163909420371056, 0.0], [0.050771377980709076, 0.08173098415136337, 0.03076810948550701, 0.6816214919090271, 0.04326915368437767, 0.0030209666583687067, 0.006032166071236134, 0.007633579429239035, 0.09515213221311569, 0.0, 0.013416368514299393, 0.7244334816932678, 0.22923606634140015, 0.004823721945285797, 0.0007022434147074819, 0.0012150612892583013, 0.001360778696835041, 0.00021415007358882576, 0.024598030373454094, 0.0], [0.04749365150928497, 0.07148067653179169, 0.018722670152783394, 0.5845115184783936, 0.03816590458154678, 0.003933309111744165, 0.006466464139521122, 0.021205652505159378, 0.20802012085914612, 0.0, 0.03640636429190636, 0.024720389395952225, 0.8944843411445618, 0.0018058173591271043, 0.00014742508938070387, 0.002046161564067006, 0.0012721297098323703, 0.0010774562833830714, 0.0380399152636528, 0.0], [0.021572547033429146, 0.11727327853441238, 0.03622674569487572, 0.4274545907974243, 0.05620160698890686, 0.01161592174321413, 0.010393376462161541, 0.014363090507686138, 0.30489882826805115, 0.0, 0.032080236822366714, 0.02157183177769184, 0.017530914396047592, 0.21374234557151794, 0.5176447033882141, 0.021586988121271133, 0.06124785542488098, 0.004810539539903402, 0.10978466272354126, 0.0], [0.015270093455910683, 0.10013995319604874, 0.006727923639118671, 0.19538360834121704, 0.1119888573884964, 0.027630485594272614, 0.0700199231505394, 0.01868581771850586, 0.4541531801223755, 0.0, 0.16469916701316833, 0.0144515885040164, 0.007452514488250017, 0.029052020981907845, 0.2643658220767975, 0.1970161497592926, 
0.2818319797515869, 0.016781603917479515, 0.024349281564354897, 0.0], [0.00540963327512145, 0.07916348427534103, 0.01957465149462223, 0.49324244260787964, 0.10871188342571259, 0.02422497235238552, 0.008650544099509716, 0.16292543709278107, 0.0980970561504364, 0.0, 0.025996195152401924, 0.005627068690955639, 0.007119623012840748, 0.004898787476122379, 0.5349600911140442, 0.05678911507129669, 0.3094601333141327, 0.008422048762440681, 0.04672713205218315, 0.0], [0.027941647917032242, 0.005471521522849798, 0.006384703796356916, 0.03924928605556488, 0.22657036781311035, 0.21837352216243744, 0.3372570872306824, 0.05897291377186775, 0.07977905124425888, 0.0, 0.004280757624655962, 0.0006373892538249493, 9.946383943315595e-05, 0.00030879577388986945, 0.02805289998650551, 0.008433223702013493, 0.9252934455871582, 0.001439885818399489, 0.03145414590835571, 0.0], [0.009049936197698116, 0.005020579323172569, 0.014692768454551697, 0.15799382328987122, 0.4401932656764984, 0.1766415536403656, 0.03136269003152847, 0.12063619494438171, 0.044409021735191345, 0.0, 0.04426492750644684, 0.0032368048559874296, 0.0014763016952201724, 0.0021763627883046865, 0.5636131763458252, 0.010265699587762356, 0.08146306872367859, 0.003517861943691969, 0.289985716342926, 0.0], [0.0007816475699655712, 0.0003147682291455567, 0.0032215022947639227, 0.4467180669307709, 0.3918246924877167, 0.00227341428399086, 0.004370422102510929, 0.14414219558238983, 0.006353371310979128, 0.0, 0.012160537764430046, 0.00020874926121905446, 0.0005602578166872263, 0.0007960868533700705, 0.9389106035232544, 0.005963308271020651, 0.005384649150073528, 0.0009963578777387738, 0.035019390285015106, 0.0], [0.0005489268223755062, 0.016601460054516792, 0.01341363787651062, 0.2753817141056061, 0.13981539011001587, 0.04711242765188217, 0.08167178928852081, 0.11951272189617157, 0.30594193935394287, 0.0, 0.006462599150836468, 0.006167746149003506, 0.00141435069963336, 0.00035615835804492235, 0.0002947094908449799, 0.002378113567829132, 0.011835698038339615, 0.0024426754098385572, 0.968647837638855, 0.0]], [[0.022736268118023872, 0.02286626398563385, 0.14116300642490387, 0.13108347356319427, 0.23994718492031097, 0.1924150437116623, 0.01816762052476406, 0.04976898059248924, 0.18185211718082428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05882957577705383, 0.028569074347615242, 0.23305171728134155, 0.053790394216775894, 0.18451730906963348, 0.2002667486667633, 0.015585620887577534, 0.052768219262361526, 0.17262138426303864, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09136874228715897, 0.08459936082363129, 0.05023255571722984, 0.21660202741622925, 0.1335863471031189, 0.10654665529727936, 0.02717875875532627, 0.06888726353645325, 0.22099831700325012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04131297022104263, 0.05848437175154686, 0.3077566921710968, 0.040097035467624664, 0.16343727707862854, 0.11984208226203918, 0.06441103667020798, 0.0850440189242363, 0.11961443722248077, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06447532773017883, 0.05503746494650841, 0.11529060453176498, 0.13719302415847778, 0.0843825414776802, 0.22279226779937744, 0.11870565265417099, 0.05292103812098503, 0.14920207858085632, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.061820220202207565, 0.03663187846541405, 0.08412205427885056, 0.386857271194458, 0.1083698719739914, 0.1462787538766861, 0.03903358429670334, 0.026668915525078773, 0.11021733283996582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.08746915310621262, 0.025642354041337967, 0.16437062621116638, 0.19346435368061066, 0.10867251455783844, 0.12237238138914108, 0.06722743809223175, 0.0922309011220932, 0.13855047523975372, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10294228792190552, 0.07313423603773117, 0.18607352674007416, 0.09769721329212189, 0.1089077964425087, 0.26933327317237854, 0.06555335968732834, 0.061070602387189865, 0.03528755530714989, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12094805389642715, 0.14730192720890045, 0.09877816587686539, 0.21085986495018005, 0.06241541728377342, 0.22994481027126312, 0.04595630243420601, 0.04531335458159447, 0.0384821854531765, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11032164841890335, 0.07897982746362686, 0.08231978863477707, 0.2677886188030243, 0.1231643408536911, 0.0929633229970932, 0.08270144462585449, 0.06097007542848587, 0.10079105943441391, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.11438923329114914, 0.12380287796258926, 0.23573537170886993, 0.19010169804096222, 0.15611350536346436, 0.031749427318573, 0.02482231892645359, 0.05017237365245819, 0.07311322540044785, 0.0, 0.013161101378500462, 0.01350532379001379, 0.39494189620018005, 0.007352527230978012, 0.12711142003536224, 0.14605116844177246, 0.03487401455640793, 0.15623201429843903, 0.10677067190408707, 0.0], [0.002549531403928995, 0.03178577870130539, 0.17347589135169983, 0.2232668697834015, 0.49775105714797974, 0.018238944932818413, 0.005651220679283142, 0.03368452191352844, 0.013595964759588242, 0.0, 0.021876059472560883, 0.4906902313232422, 0.4596463143825531, 0.004091671667993069, 0.004464378114789724, 0.001156727666966617, 0.000353646173607558, 0.000146497564855963, 0.017574656754732132, 0.0], [0.0032994491048157215, 0.026504727080464363, 0.41210347414016724, 0.24245016276836395, 0.18897436559200287, 0.012874660082161427, 0.006452939473092556, 0.10089367628097534, 0.00644671730697155, 0.0, 0.005734701175242662, 0.026843877509236336, 0.9321272969245911, 0.00021884289162699133, 0.00045866103027947247, 0.0010309598874300718, 0.00017261962057091296, 0.003054215107113123, 0.030358724296092987, 0.0], [0.002998506650328636, 0.048583757132291794, 0.28224417567253113, 0.0846971943974495, 0.013445784337818623, 0.02188579924404621, 0.017656570300459862, 0.5155076384544373, 0.012980557046830654, 0.0, 0.0482722632586956, 0.14050070941448212, 0.4546079635620117, 0.0072937230579555035, 0.023873258382081985, 0.09857403486967087, 0.0516686774790287, 0.11766187101602554, 0.05754747614264488, 0.0], [0.004188622813671827, 0.028234833851456642, 0.022820167243480682, 0.058492597192525864, 0.19205521047115326, 0.08343320339918137, 0.07119973003864288, 0.4843534827232361, 0.0552222914993763, 0.0, 0.0020078516099601984, 0.002228439087048173, 0.111594557762146, 0.0033910104539245367, 0.08423032611608505, 0.17691271007061005, 0.14758752286434174, 0.4346924424171448, 0.037355244159698486, 0.0], [0.0038351663388311863, 0.015353971160948277, 0.01755588687956333, 0.06245748698711395, 0.1218588799238205, 0.07207991182804108, 0.02867230959236622, 0.5455195903778076, 0.13266700506210327, 0.0, 0.0008274781284853816, 0.0016531302826479077, 0.047970183193683624, 0.0006053023971617222, 0.22220103442668915, 0.6234129071235657, 0.05364101752638817, 0.012585645541548729, 0.03710317984223366, 0.0], [0.004144841339439154, 0.0048835063353180885, 0.0035110898315906525, 0.06276324391365051, 0.04069552943110466, 0.3603023290634155, 0.1472603678703308, 0.2116946280002594, 0.16474448144435883, 
0.0, 2.7583497285377234e-05, 1.1631378583842888e-05, 4.4259006244828925e-05, 0.0006730516324751079, 0.599366307258606, 0.006597205530852079, 0.3886081576347351, 0.0003169252013321966, 0.004354946780949831, 0.0], [0.024624889716506004, 0.016127971932291985, 0.0073340879753232, 0.023849278688430786, 0.042295511811971664, 0.5078635215759277, 0.2884303331375122, 0.011452756822109222, 0.07802165299654007, 0.0, 2.752073669398669e-06, 2.0648456029448425e-06, 8.536147106497083e-06, 6.34281532256864e-05, 0.9992840886116028, 0.00028667543665505946, 7.951273437356576e-05, 3.5721727726922836e-06, 0.00026920961681753397, 0.0], [0.00880166981369257, 0.002673782641068101, 0.001370548619888723, 0.0061265453696250916, 0.02490534819662571, 0.2073771357536316, 0.3818575143814087, 0.1663341522216797, 0.20055335760116577, 0.0, 3.3996084312093444e-06, 2.1497796751646092e-06, 7.304265182028757e-06, 0.00018760550301522017, 0.99969482421875, 2.4790026145637967e-05, 3.4293629141757265e-05, 6.942725121916737e-06, 3.892222957802005e-05, 0.0], [0.012253189459443092, 0.02221212349832058, 0.002282155444845557, 0.10455729067325592, 0.4111727774143219, 0.08308815956115723, 0.045707643032073975, 0.03711223974823952, 0.2816142141819, 0.0, 0.0005689842510037124, 0.002939490834251046, 0.019829533994197845, 0.0003717679646797478, 0.01646142266690731, 0.011912180110812187, 0.001234701368957758, 0.0013870754046365619, 0.945294976234436, 0.0]], [[0.008687321096658707, 0.012162125669419765, 0.02774685248732567, 0.0013578477082774043, 0.052177976816892624, 0.027187975123524666, 0.05590689554810524, 0.020962538197636604, 0.7938104867935181, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.005042325239628553, 0.015503124333918095, 0.010042164474725723, 0.0008876739302650094, 0.011308688670396805, 0.010491759516298771, 0.03130592033267021, 0.04934320226311684, 0.8660751581192017, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013016406446695328, 0.03886239603161812, 0.027493299916386604, 0.029101338237524033, 0.009947741404175758, 0.00769558921456337, 0.035501737147569656, 0.023772817105054855, 0.8146085143089294, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.018851714208722115, 0.05105733126401901, 0.8005384206771851, 0.01116525661200285, 0.09583853930234909, 0.0015093896072357893, 0.005055624525994062, 0.0006665397086180747, 0.015317671000957489, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01609102450311184, 0.023716216906905174, 0.5135837197303772, 0.10603100061416626, 0.26668840646743774, 0.019648341462016106, 0.01755940169095993, 0.01368130836635828, 0.023000601679086685, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01718730293214321, 0.02692273259162903, 0.05480796471238136, 0.010818017646670341, 0.7150712013244629, 0.0585104264318943, 0.04717297852039337, 0.030360547825694084, 0.039148781448602676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006439396180212498, 0.012697076424956322, 0.014188298024237156, 0.000897688849363476, 0.7481768727302551, 0.15047557651996613, 0.03333613649010658, 0.01207506563514471, 0.021714046597480774, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.009459104388952255, 0.022298788651823997, 0.013802104629576206, 0.011955137364566326, 0.03879927098751068, 0.1585427075624466, 0.07075291126966476, 0.329448938369751, 0.3449409306049347, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04810584336519241, 0.017975708469748497, 0.025123968720436096, 0.023182567209005356, 0.020010611042380333, 
0.04571577161550522, 0.1801854819059372, 0.06764508783817291, 0.5720548629760742, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.026153914630413055, 0.0356404148042202, 0.10573611408472061, 0.06201518699526787, 0.06006328761577606, 0.09286139905452728, 0.2927103638648987, 0.20419549942016602, 0.12062377482652664, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5821239352226257, 0.14550858736038208, 0.031251534819602966, 0.030760297551751137, 0.02147754468023777, 0.013665237464010715, 0.009087015874683857, 0.01557532325387001, 0.15055041015148163, 0.0, 0.00632825493812561, 0.011520092375576496, 0.08263711631298065, 0.006356080062687397, 0.022936103865504265, 0.03108564019203186, 0.013897407799959183, 0.697504997253418, 0.12773430347442627, 0.0], [0.12817564606666565, 0.33913177251815796, 0.07241326570510864, 0.41213902831077576, 0.0326012559235096, 0.0031606394331902266, 0.0006341012776829302, 0.007317711599171162, 0.0044263736344873905, 0.0, 0.008715116418898106, 0.015272715128958225, 0.10463730990886688, 0.08011683076620102, 0.13045108318328857, 0.05373600497841835, 0.015578814782202244, 0.4212273955345154, 0.1702648103237152, 0.0], [0.08047150820493698, 0.06199575960636139, 0.5555182099342346, 0.2858560383319855, 0.008700164034962654, 0.003758196486160159, 0.001155794132500887, 0.0007424709619954228, 0.0018020549323409796, 0.0, 0.004959889687597752, 0.007777809165418148, 0.14492008090019226, 0.02459821291267872, 0.014704479835927486, 0.016136664897203445, 0.008129375986754894, 0.7319321036338806, 0.0468413271009922, 0.0], [0.010044030845165253, 0.018482256680727005, 0.6269924640655518, 0.32439544796943665, 0.01023165788501501, 0.007641270756721497, 0.0008933563949540257, 0.0010311403311789036, 0.00028844154439866543, 0.0, 0.005315575283020735, 0.0021190166007727385, 0.007080279756337404, 0.006970370654016733, 0.010002117604017258, 0.007610250264406204, 0.004703941754996777, 0.8570073246955872, 0.09919113665819168, 0.0], [0.0007911038701422513, 0.0008549468475393951, 0.015090622939169407, 0.8270009160041809, 0.11969847232103348, 0.032614268362522125, 0.0024233118165284395, 0.0011481117689982057, 0.0003779604157898575, 0.0, 0.0016317280242219567, 0.0005414763581939042, 0.004523266106843948, 0.0019645043648779392, 0.010821727104485035, 0.008883371017873287, 0.00927714817225933, 0.920802652835846, 0.041554201394319534, 0.0], [0.017773190513253212, 0.008623103611171246, 0.0020072387997061014, 0.08177924901247025, 0.13816505670547485, 0.6801413297653198, 0.02186667174100876, 0.024107687175273895, 0.025536518543958664, 0.0, 0.002020488725975156, 0.0007793906843289733, 0.022791940718889236, 0.005821499973535538, 0.1932065784931183, 0.30031588673591614, 0.08197023719549179, 0.12508654594421387, 0.2680076062679291, 0.0], [0.000318053673254326, 5.6540200603194535e-05, 1.071194674295839e-05, 0.0009494975674897432, 0.0034297029487788677, 0.032661326229572296, 0.9588278532028198, 0.003185966284945607, 0.0005602877936325967, 0.0, 0.007396090775728226, 0.0032474161125719547, 0.00692824088037014, 0.007240207865834236, 0.42384257912635803, 0.04473983123898506, 0.013007782399654388, 0.007779541425406933, 0.4858182966709137, 0.0], [0.0017862697131931782, 0.0002347631088923663, 2.1297884813975543e-05, 0.0004797980946023017, 0.0018031852087005973, 0.024247879162430763, 0.45456385612487793, 0.5099425911903381, 0.006920217536389828, 0.0, 0.0026900237426161766, 0.0007204422145150602, 0.005861051380634308, 0.003422616282477975, 0.46744993329048157, 0.10402297228574753, 0.05837857723236084, 
0.0177029799669981, 0.3397515118122101, 0.0], [0.0006541880429722369, 0.0009561541373841465, 7.73017163737677e-05, 0.00942671112716198, 0.04198922589421272, 0.04971348121762276, 0.32961171865463257, 0.4513629972934723, 0.11620841920375824, 0.0, 0.005906206555664539, 0.002057044068351388, 0.0031123505905270576, 0.008901549503207207, 0.43650564551353455, 0.08504725992679596, 0.0923796221613884, 0.009556618519127369, 0.3565336763858795, 0.0], [0.017209511250257492, 0.004475452937185764, 3.128392927465029e-05, 0.00047953161993063986, 0.00448839133605361, 0.03360708802938461, 0.11509764194488525, 0.5398797988891602, 0.2847314178943634, 0.0, 0.013360978104174137, 0.04520300775766373, 0.09048072248697281, 0.012179902754724026, 0.030064363032579422, 0.023480970412492752, 0.008669134229421616, 0.03746046498417854, 0.7391002178192139, 0.0]], [[0.02415475994348526, 0.0027711745351552963, 0.003856832394376397, 0.0957413911819458, 0.02159286104142666, 0.03336814045906067, 0.009564127773046494, 0.03954486921429634, 0.7694058418273926, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9052021503448486, 0.02053658291697502, 0.0014916026266291738, 0.00022646080469712615, 4.7710393118904904e-05, 0.000383042759494856, 0.014123834669589996, 0.0205638837069273, 0.03742456063628197, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.37607336044311523, 0.6030705571174622, 0.0068079219199717045, 0.0036466827150434256, 9.876023250399157e-05, 2.0246809071977623e-05, 0.0007042856304906309, 0.002560489112511277, 0.007017510011792183, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [5.0091031880583614e-05, 0.00024915943504311144, 0.9895205497741699, 0.006273698527365923, 0.0016484790248796344, 4.1711446101544425e-05, 7.522702958340233e-07, 1.2660359971050639e-05, 0.002202932955697179, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [8.009441080503166e-05, 9.311464236816391e-05, 0.006593613885343075, 0.9913647770881653, 0.0018261962104588747, 1.6436462829005904e-05, 8.038865075832291e-07, 1.0318336762793479e-06, 2.3524326024926268e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.1561212381348014e-05, 1.8178753862230224e-06, 0.00011904581333510578, 0.027105441316962242, 0.8800897598266602, 0.09253741800785065, 0.00010895416926359758, 5.953493655397324e-06, 1.9602707368449046e-07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.7160528553716858e-09, 1.4191656530493368e-11, 3.274841375855431e-08, 2.1219284462858923e-07, 1.9925082597183064e-05, 0.9999751448631287, 3.130498271275428e-06, 1.9788064946624218e-06, 3.1215499074477293e-09, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.2861962204624433e-05, 5.737682045037218e-07, 2.0471109110076213e-06, 1.0477544492459856e-05, 6.581651632586727e-06, 0.02534269355237484, 0.16125597059726715, 0.5878354907035828, 0.22553342580795288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0009172551217488945, 7.270056084962562e-05, 2.2026280930731446e-05, 4.6261970965133514e-06, 4.921669642499182e-06, 4.060195351485163e-05, 0.027831047773361206, 0.33271971344947815, 0.6383873224258423, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.3075091374048498e-05, 6.147480598883703e-05, 4.768987855641171e-05, 2.045959490715177e-06, 1.1152823553572944e-08, 3.07468525306831e-07, 0.0007055726600810885, 0.02803119830787182, 0.9711382985115051, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20143046975135803, 0.41116827726364136, 0.09215858578681946, 0.10672477632761002, 0.06125285103917122, 
0.017610367387533188, 0.01457523088902235, 0.02514597773551941, 0.06993352621793747, 0.0, 0.023652182891964912, 0.008639940991997719, 0.08203616738319397, 0.035750582814216614, 0.050224509090185165, 0.3533262312412262, 0.03081362321972847, 0.28302860260009766, 0.1325281411409378, 0.0], [0.026864346116781235, 0.037146128714084625, 0.08411292731761932, 0.02904331497848034, 0.0955604761838913, 0.05886658653616905, 0.08584483712911606, 0.4076027572154999, 0.17495866119861603, 0.0, 0.016670020297169685, 0.1283574253320694, 0.836423397064209, 0.0042742472141981125, 0.0022883012425154448, 0.00297459471039474, 0.00022807312780059874, 0.0012588471872732043, 0.007524838205426931, 0.0], [0.073190838098526, 0.07998740673065186, 0.05594569817185402, 0.03243006020784378, 0.10037493705749512, 0.13878461718559265, 0.15250830352306366, 0.25721096992492676, 0.10956726223230362, 0.0, 0.031559381633996964, 0.02045642025768757, 0.8176267743110657, 0.006169404834508896, 0.0014412011951208115, 0.0069603933952748775, 0.0010916722239926457, 0.011522608809173107, 0.10317197442054749, 0.0], [0.0438627265393734, 0.04628896340727806, 0.4038660526275635, 0.005475929472595453, 0.03436022624373436, 0.11165640503168106, 0.02260321006178856, 0.28233063220977783, 0.04955587536096573, 0.0, 0.004598122555762529, 0.004610949195921421, 0.01865001954138279, 0.020574036985635757, 0.0137012405321002, 0.7973257303237915, 0.01646837778389454, 0.023596635088324547, 0.1004747673869133, 0.0], [0.2377929538488388, 0.08882997930049896, 0.12371516227722168, 0.08651548624038696, 0.015416872687637806, 0.04211122542619705, 0.16403844952583313, 0.11833071708679199, 0.12324906885623932, 0.0, 0.0005213705007918179, 0.00018707667186390609, 0.0016978917410597205, 0.019619440659880638, 0.009308884851634502, 0.8590161800384521, 0.024511896073818207, 0.06970686465501785, 0.015430280938744545, 0.0], [0.023254310712218285, 0.0034057339653372765, 0.036038532853126526, 0.009054891765117645, 0.0329253226518631, 0.05284882336854935, 0.15671837329864502, 0.6067742109298706, 0.07897992432117462, 0.0, 0.0001481063081882894, 2.072651477647014e-05, 0.00035672096419148147, 0.00033358228392899036, 0.00040588833508081734, 0.9861487746238708, 0.00651955883949995, 0.00443643843755126, 0.0016300288261845708, 0.0], [0.015282228589057922, 0.008608018048107624, 0.08339564502239227, 0.032651614397764206, 0.21303850412368774, 0.22661514580249786, 0.21832069754600525, 0.1323210895061493, 0.06976725161075592, 0.0, 0.0010996124474331737, 0.0011850595474243164, 0.0075045316480100155, 0.004539311397820711, 0.05570072680711746, 0.18870605528354645, 0.23963898420333862, 0.013960372656583786, 0.487665593624115, 0.0], [0.019424932077527046, 0.008587736636400223, 0.014951083809137344, 0.01159222237765789, 0.2890152633190155, 0.2543036639690399, 0.2561561167240143, 0.0882645845413208, 0.05770434811711311, 0.0, 0.0003884119214490056, 0.0004658032557927072, 0.028157439082860947, 0.0002352961164433509, 0.1278570294380188, 0.08260466903448105, 0.02582997828722, 0.022790132090449333, 0.7116712927818298, 0.0], [0.020595766603946686, 0.015824340283870697, 0.008689227513968945, 0.03796549141407013, 0.3004503846168518, 0.16956602036952972, 0.10506420582532883, 0.05004280060529709, 0.2918018400669098, 0.0, 0.0015414542285725474, 0.0007310948567464948, 0.010464987717568874, 0.0012846259633079171, 0.45206302404403687, 0.029316790401935577, 0.04706822335720062, 0.018986493349075317, 0.4385431706905365, 0.0], [0.18154361844062805, 0.0977708026766777, 0.20556335151195526, 0.05251142755150795, 
0.13640889525413513, 0.06629360467195511, 0.06030320003628731, 0.08172836154699326, 0.11787670105695724, 0.0, 0.0005072542116977274, 0.0011837932979688048, 0.01220926083624363, 8.532252832083032e-05, 0.0018606879748404026, 0.010199862532317638, 0.0016309961210936308, 0.010775143280625343, 0.9615475535392761, 0.0]], [[0.060361556708812714, 0.015829458832740784, 0.05784451961517334, 0.3351474404335022, 0.06477320939302444, 0.04427827522158623, 0.09356044977903366, 0.03362266346812248, 0.2945823669433594, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.051239900290966034, 0.0459107868373394, 0.10656695812940598, 0.4080160856246948, 0.16381530463695526, 0.044977184385061264, 0.05972094088792801, 0.009804679080843925, 0.10994797199964523, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.019088272005319595, 0.05349855497479439, 0.4389742910861969, 0.022328443825244904, 0.03395729511976242, 0.20592069625854492, 0.007582489866763353, 0.08437496423721313, 0.13427504897117615, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03275543451309204, 0.01311502419412136, 0.038520246744155884, 0.47789818048477173, 0.04586595296859741, 0.01380465179681778, 0.03337283805012703, 0.07212045043706894, 0.27254730463027954, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04071904346346855, 0.043366871774196625, 0.1190471276640892, 0.18268215656280518, 0.2763146162033081, 0.029253922402858734, 0.017268449068069458, 0.0670313611626625, 0.22431644797325134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04853136092424393, 0.0034203159157186747, 0.17822766304016113, 0.005087696481496096, 0.02670232392847538, 0.5734196305274963, 0.06478680670261383, 0.04684215411543846, 0.05298209935426712, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.016102498397231102, 0.0006646174006164074, 0.00315408268943429, 0.003398373955860734, 0.01210782676935196, 0.07864897698163986, 0.743419349193573, 0.023116787895560265, 0.11938738822937012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0031801864970475435, 0.0032259617000818253, 0.027063841000199318, 0.0018325509736314416, 0.006064774002879858, 0.017839375883340836, 0.05006564408540726, 0.8002738952636719, 0.0904538482427597, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02500138245522976, 0.016465606167912483, 0.02692888118326664, 0.01824249140918255, 0.047875918447971344, 0.06556686758995056, 0.15585453808307648, 0.21941381692886353, 0.42465049028396606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07641319185495377, 0.017753547057509422, 0.039497166872024536, 0.014236720278859138, 0.03872253745794296, 0.1210501492023468, 0.17305448651313782, 0.2333979308605194, 0.28587427735328674, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07673492282629013, 0.03585591912269592, 0.0804624855518341, 0.05707075819373131, 0.16190174221992493, 0.1288135051727295, 0.1235240250825882, 0.06807681918144226, 0.2675597667694092, 0.0, 0.29744189977645874, 0.04770943149924278, 0.09888078272342682, 0.19768767058849335, 0.048243775963783264, 0.12058595567941666, 0.05976371467113495, 0.03847452625632286, 0.09121233224868774, 0.0], [0.005086997989565134, 0.014635499566793442, 0.013461720198392868, 0.6349815726280212, 0.14714521169662476, 0.015218403190374374, 0.01605474203824997, 0.018318237736821175, 0.1350976973772049, 0.0, 0.04126456007361412, 0.6604095697402954, 0.028894882649183273, 0.20104490220546722, 0.0014044500421732664, 0.0009343607816845179, 0.00244489056058228, 0.007453228812664747, 
0.05614929273724556, 0.0], [0.03515003249049187, 0.049813926219940186, 0.04029693454504013, 0.4151618778705597, 0.24873343110084534, 0.009437951259315014, 0.008381601423025131, 0.020832136273384094, 0.17219208180904388, 0.0, 0.008357543498277664, 0.0022072584833949804, 0.9876156449317932, 8.841200906317681e-05, 1.4883004041621462e-05, 0.00011741811613319442, 2.7020510970032774e-05, 0.00016062626673374325, 0.001411277218721807, 0.0], [0.06722414493560791, 0.13528113067150116, 0.06224377825856209, 0.18915168941020966, 0.17580503225326538, 0.07229694724082947, 0.012536793015897274, 0.09137610346078873, 0.19408434629440308, 0.0, 0.06216944754123688, 0.48559242486953735, 0.042546145617961884, 0.034007471054792404, 0.047574639320373535, 0.12490913271903992, 0.07922931015491486, 0.013364763930439949, 0.11060672253370285, 0.0], [0.09099949151277542, 0.09548961371183395, 0.04829362779855728, 0.1739831268787384, 0.06667517125606537, 0.05157051607966423, 0.05465595796704292, 0.06177656352519989, 0.3565560579299927, 0.0, 0.05222959443926811, 0.025416702032089233, 0.02865077182650566, 0.17457211017608643, 0.03144511207938194, 0.3907364010810852, 0.19607771933078766, 0.05274118855595589, 0.04813018813729286, 0.0], [0.09822985529899597, 0.05441536381840706, 0.039150238037109375, 0.06369251012802124, 0.05292840674519539, 0.050128646194934845, 0.044398434460163116, 0.04042055085301399, 0.5566359758377075, 0.0, 0.0037726862356066704, 0.0031579534988850355, 0.0029440780635923147, 0.0017320584738627076, 0.060473062098026276, 0.761774480342865, 0.1523173600435257, 0.0058823637664318085, 0.007945872843265533, 0.0], [0.012019939720630646, 0.0076602306216955185, 0.02716030552983284, 0.03984800726175308, 0.09776019304990768, 0.05175628885626793, 0.08536165207624435, 0.0944109782576561, 0.5840223431587219, 0.0, 0.0020738786552101374, 0.0012752892216667533, 0.0004058163322042674, 0.020963717252016068, 0.39340031147003174, 0.012434415519237518, 0.4783190190792084, 0.011497312225401402, 0.0796302929520607, 0.0], [0.036716632544994354, 0.021969007328152657, 0.010507079772651196, 0.012404722161591053, 0.040125522762537, 0.010736462660133839, 0.018730206415057182, 0.030387653037905693, 0.8184227347373962, 0.0, 5.31752230017446e-05, 1.4492364243778866e-05, 7.312332309084013e-05, 0.0023682843893766403, 0.9866323471069336, 0.0009243910317309201, 0.0011850211303681135, 0.0017622504383325577, 0.0069872229360044, 0.0], [0.04769879952073097, 0.19333122670650482, 0.02803504839539528, 0.016029207035899162, 0.11119306832551956, 0.03845509514212608, 0.011404097080230713, 0.0836206004023552, 0.4702327847480774, 0.0, 4.074166645295918e-05, 1.823456841520965e-05, 0.0001418270985595882, 0.007263784296810627, 0.9604514241218567, 0.0001852070417953655, 0.00034164052340202034, 0.0018497714772820473, 0.029707150533795357, 0.0], [0.05245642364025116, 0.013315027579665184, 0.012056763283908367, 0.004825723823159933, 0.015483945608139038, 0.032884638756513596, 0.027794960886240005, 0.07057305425405502, 0.7706093788146973, 0.0, 0.0133396340534091, 0.03136875480413437, 0.6319980621337891, 0.0033722908701747656, 0.04728742688894272, 0.03541773557662964, 0.009523973800241947, 0.03100484237074852, 0.1966874897480011, 0.0]], [[0.15564993023872375, 0.3264511823654175, 0.08247561007738113, 0.04047680273652077, 0.04636594280600548, 0.03705644607543945, 0.05653020739555359, 0.08808662742376328, 0.16690711677074432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6047166585922241, 0.08402378112077713, 0.11650887131690979, 
0.004807815421372652, 0.02726476825773716, 0.0609126091003418, 0.02905944734811783, 0.012920884415507317, 0.059785205870866776, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5938906669616699, 0.07300958037376404, 0.08890929818153381, 0.008111076429486275, 0.04038470610976219, 0.07353192567825317, 0.03085281327366829, 0.08706387132406235, 0.004246041644364595, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2591831088066101, 0.17658700048923492, 0.44177621603012085, 0.01689036749303341, 0.0653892457485199, 0.01502177957445383, 0.02055797167122364, 0.0024378441739827394, 0.0021566858049482107, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.33400091528892517, 0.03927909955382347, 0.27614372968673706, 0.009977479465305805, 0.12025652825832367, 0.1713484674692154, 0.04292818158864975, 0.004225345328450203, 0.00184013566467911, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06147114187479019, 0.019044799730181694, 0.059415291994810104, 0.05198045074939728, 0.12181691080331802, 0.419679194688797, 0.1140735000371933, 0.14551687240600586, 0.00700181070715189, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.006845483556389809, 0.002091927919536829, 0.01196279563009739, 0.014390786178410053, 0.02692629024386406, 0.8455513715744019, 0.07174734026193619, 0.017689114436507225, 0.0027949714567512274, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.00039940490387380123, 0.00013551976007875055, 0.020663700997829437, 0.008696838282048702, 0.021915050223469734, 0.1381293535232544, 0.0347108468413353, 0.7650054097175598, 0.010343861766159534, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02615724503993988, 0.0051858089864254, 0.038734134286642075, 0.021585455164313316, 0.19684533774852753, 0.17548950016498566, 0.1665634661912918, 0.2796759307384491, 0.08976294845342636, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.043001022189855576, 0.016749290749430656, 0.04958483204245567, 0.06659381091594696, 0.0702962800860405, 0.27735820412635803, 0.14212922751903534, 0.20686522126197815, 0.12742231786251068, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05745904520153999, 0.06613133102655411, 0.11319872736930847, 0.031750500202178955, 0.0641264021396637, 0.07090476900339127, 0.053613319993019104, 0.1108509749174118, 0.4319649040699005, 0.0, 0.03367111459374428, 0.018932543694972992, 0.09506545215845108, 0.04718795791268349, 0.028798582032322884, 0.33658939599990845, 0.02586139366030693, 0.29842811822891235, 0.11546547710895538, 0.0], [0.12783250212669373, 0.16847258806228638, 0.08126984536647797, 0.10575822740793228, 0.03301985561847687, 0.2111520618200302, 0.10687874257564545, 0.06316707283258438, 0.10244929045438766, 0.0, 0.006203038617968559, 0.0906001627445221, 0.6977949738502502, 0.018352899700403214, 0.06787873804569244, 0.04403599724173546, 0.001631368650123477, 0.024296771734952927, 0.049206044524908066, 0.0], [0.1413263976573944, 0.38601601123809814, 0.16798537969589233, 0.14611834287643433, 0.015951359644532204, 0.042198505252599716, 0.016183707863092422, 0.06246974319219589, 0.021750787273049355, 0.0, 0.006243667099624872, 0.010453532449901104, 0.7879610657691956, 0.004093538969755173, 0.0008473669877275825, 0.027760563418269157, 0.0003080451278947294, 0.14831961691379547, 0.014012438245117664, 0.0], [0.020376645028591156, 0.008152640424668789, 0.04579228535294533, 0.022974595427513123, 0.007921000011265278, 0.11700868606567383, 0.010826223529875278, 0.7216546535491943, 0.04529344290494919, 0.0, 
0.004387176129966974, 0.023410169407725334, 0.17247918248176575, 0.03958609700202942, 0.023799436166882515, 0.43659475445747375, 0.014754846692085266, 0.2318120151758194, 0.05317622795701027, 0.0], [0.04728184640407562, 0.041129130870103836, 0.12847241759300232, 0.038289085030555725, 0.07389654964208603, 0.11478690057992935, 0.04442784935235977, 0.41169247031211853, 0.1000237911939621, 0.0, 0.0020952164195477962, 0.0024118656292557716, 0.028229335322976112, 0.007075420115143061, 0.019164882600307465, 0.5397294163703918, 0.034580815583467484, 0.3465326428413391, 0.020180128514766693, 0.0], [0.016180921345949173, 0.005130380857735872, 0.21081623435020447, 0.00797765702009201, 0.04691680520772934, 0.052309177815914154, 0.2947923243045807, 0.34133997559547424, 0.02453651838004589, 0.0, 0.00020744462381117046, 0.00036016973899677396, 0.004934145137667656, 0.0004664760490413755, 0.008187839761376381, 0.9661812782287598, 0.009987047873437405, 0.003882928751409054, 0.005792597308754921, 0.0], [0.006579844746738672, 0.001606129459105432, 0.206822007894516, 0.017204096540808678, 0.13898226618766785, 0.09910376369953156, 0.4235020577907562, 0.05497713387012482, 0.051222700625658035, 0.0, 3.4081476769642904e-05, 1.7181657312903553e-05, 5.4824478866066784e-05, 0.00045897584641352296, 0.0043338024988770485, 0.001544477418065071, 0.9909620881080627, 2.356152981519699e-05, 0.0025708049070090055, 0.0], [0.00896216370165348, 0.0023249718360602856, 0.0226416178047657, 0.05458173528313637, 0.07694459706544876, 0.29436299204826355, 0.36870595812797546, 0.12525610625743866, 0.046219732612371445, 0.0, 0.0001047314508468844, 0.0001599654060555622, 0.001310097286477685, 0.001540280063636601, 0.833267331123352, 0.044754061847925186, 0.0028599577490240335, 0.0006454077665694058, 0.11535807698965073, 0.0], [0.027829669415950775, 0.014619122259318829, 0.014550572261214256, 0.048137370496988297, 0.15001901984214783, 0.11716196686029434, 0.34159788489341736, 0.1513865739107132, 0.13469791412353516, 0.0, 8.819431968731806e-05, 6.364465662045404e-05, 0.00022057128080632538, 0.001112746773287654, 0.9560981392860413, 0.003599100047722459, 0.0002217600413132459, 0.0006697923527099192, 0.03792598471045494, 0.0], [0.0014273751294240355, 0.003807784290984273, 0.3760293126106262, 0.002253596903756261, 0.11343870311975479, 0.12883712351322174, 0.04242479428648949, 0.28902071714401245, 0.042760640382766724, 0.0, 0.0018130787648260593, 0.022020958364009857, 0.12822051346302032, 0.0005810249131172895, 0.03168048337101936, 0.014293116517364979, 0.002500524278730154, 0.0212943647056818, 0.7775959372520447, 0.0]]], [[[0.13086311519145966, 0.049477167427539825, 0.10100015252828598, 0.03843620419502258, 0.27287009358406067, 0.20078831911087036, 0.16546384990215302, 0.03368193656206131, 0.007419050205498934, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1137659102678299, 0.11250672489404678, 0.21935509145259857, 0.09974226355552673, 0.22245454788208008, 0.11022598296403885, 0.0977952778339386, 0.010162456892430782, 0.013991687446832657, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09118296205997467, 0.0991944894194603, 0.31555840373039246, 0.16625922918319702, 0.1399575173854828, 0.0926588773727417, 0.021735703572630882, 0.056496523320674896, 0.016956249251961708, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.35773080587387085, 0.19870112836360931, 0.026073846966028214, 0.07347559928894043, 0.09251826256513596, 0.0859094187617302, 0.06421677768230438, 0.06334269791841507, 0.0380314365029335, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02230222336947918, 0.0210218857973814, 0.024334343150258064, 0.36442241072654724, 0.2750929892063141, 0.13295342028141022, 0.06824173033237457, 0.0036951478105038404, 0.0879359245300293, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.018942566588521004, 0.011805560439825058, 0.04696377366781235, 0.09440026432275772, 0.39890599250793457, 0.17608429491519928, 0.10613365471363068, 0.10454639047384262, 0.04221746698021889, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0475851334631443, 0.008668179623782635, 0.011950161308050156, 0.0786907747387886, 0.09432563930749893, 0.07653870433568954, 0.4287588894367218, 0.13403372466564178, 0.1194487139582634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008243327029049397, 0.006908380892127752, 0.04044030234217644, 0.08380357921123505, 0.1593569815158844, 0.1858288198709488, 0.0890916958451271, 0.40247857570648193, 0.02384827472269535, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09753390401601791, 0.04787491634488106, 0.10570236295461655, 0.09989321976900101, 0.07242950052022934, 0.16000299155712128, 0.13195638358592987, 0.12870465219020844, 0.15590202808380127, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3338638246059418, 0.05386793985962868, 0.15485166013240814, 0.05483235418796539, 0.052468191832304, 0.12754301726818085, 0.13515245914459229, 0.06475869566202164, 0.022661946713924408, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9917634725570679, 0.008236419409513474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6252409815788269, 0.3747589886188507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.711856484413147, 0.20838035643100739, 0.07976315170526505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8520486354827881, 0.010580658912658691, 0.13737063109874725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6327172517776489, 0.1227935329079628, 0.21565596759319305, 0.028833283111453056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05910082906484604, 0.011589597910642624, 0.877491295337677, 0.051818281412124634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3586137592792511, 0.038762304931879044, 0.08015953004360199, 0.4233120083808899, 0.09915236383676529, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3626183867454529, 0.026959313079714775, 0.07612177729606628, 0.13077552616596222, 0.4035249352455139, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7095601558685303, 0.03453405201435089, 0.02220289036631584, 0.009008818306028843, 0.201883926987648, 0.022810086607933044, 0.0, 0.0, 0.0, 0.0, 0.21979263424873352, 0.001410112832672894, 0.007092535495758057, 0.13166557252407074, 0.626970648765564, 0.013068560510873795, 0.0, 0.0, 0.0, 0.0], [0.5828825831413269, 0.02795644849538803, 0.054448600858449936, 0.01975347101688385, 0.11504233628511429, 0.08908692002296448, 0.11082970350980759, 0.0, 0.0, 0.0, 0.08148042857646942, 0.001490423921495676, 0.004908325150609016, 0.01383854728192091, 0.7959722876548767, 0.05201547220349312, 0.05029459297657013, 0.0, 0.0, 0.0], [0.4315364956855774, 0.020537925884127617, 0.01659376546740532, 0.014654956758022308, 0.13063199818134308, 0.27319464087486267, 0.08869150280952454, 0.024158723652362823, 0.0, 0.0, 0.03934427723288536, 5.908778257435188e-05, 0.00014962907880544662, 0.005592166446149349, 0.7025003433227539, 0.1675100177526474, 0.03920353576540947, 0.04564077779650688, 0.0, 0.0], [0.26020547747612, 0.014821716584265232, 0.01224969606846571, 0.0724530965089798, 
0.10939211398363113, 0.19152909517288208, 0.10495918244123459, 0.1680101454257965, 0.06637949496507645, 0.0, 0.4660189151763916, 0.00034756408422254026, 9.701005183160305e-05, 0.008154522627592087, 0.08121690154075623, 0.15592943131923676, 0.11426379531621933, 0.17044323682785034, 0.0035288764629513025, 0.0], [0.6687084436416626, 0.04345089942216873, 0.009689688682556152, 0.0018685735994949937, 0.0738394483923912, 0.12735962867736816, 0.025320274755358696, 0.026545442640781403, 0.020931225270032883, 0.0022863498888909817, 0.3707294762134552, 0.0020887483842670918, 0.23984688520431519, 0.07748916745185852, 0.18109895288944244, 0.03584783151745796, 0.005205830093473196, 0.005058187525719404, 0.0050886403769254684, 0.0775463655591011]], [[0.011833908967673779, 0.03545977920293808, 0.03510122373700142, 0.06200635805726051, 0.09438431262969971, 0.06055876612663269, 0.053256530314683914, 0.30701303482055664, 0.3403860926628113, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03663749620318413, 0.06511621922254562, 0.05716057866811752, 0.07533077895641327, 0.10846659541130066, 0.037432827055454254, 0.04480022192001343, 0.18166707456111908, 0.39338818192481995, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06557667255401611, 0.03966936469078064, 0.008358842693269253, 0.06794404983520508, 0.05668830871582031, 0.02720261737704277, 0.07913517951965332, 0.20437636971473694, 0.45104852318763733, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.044038429856300354, 0.07477934658527374, 0.10143070667982101, 0.16204005479812622, 0.06265459954738617, 0.10170722752809525, 0.08676454424858093, 0.0699862688779831, 0.2965989410877228, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06005045771598816, 0.046840403228998184, 0.06629239022731781, 0.04125581681728363, 0.007815167307853699, 0.20412082970142365, 0.1083299070596695, 0.04942404478788376, 0.41587093472480774, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03666035085916519, 0.028792625293135643, 0.06887229532003403, 0.18481910228729248, 0.15058831870555878, 0.048441674560308456, 0.0780390277504921, 0.13469383120536804, 0.26909276843070984, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03408746421337128, 0.026394939050078392, 0.05409233644604683, 0.06951043754816055, 0.1446777582168579, 0.09970070421695709, 0.05472328141331673, 0.16119606792926788, 0.35561704635620117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.12936006486415863, 0.04621516913175583, 0.10149524360895157, 0.14774896204471588, 0.45855623483657837, 0.033130910247564316, 0.031401973217725754, 0.02012830227613449, 0.031963150948286057, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1214270144701004, 0.04088712856173515, 0.05250505730509758, 0.07924661785364151, 0.05337269604206085, 0.10527284443378448, 0.08820997178554535, 0.17732012271881104, 0.28175854682922363, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13074854016304016, 0.06475767493247986, 0.07325490564107895, 0.0625966489315033, 0.14061231911182404, 0.07830052822828293, 0.12438739091157913, 0.21453101933002472, 0.11081094294786453, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9482711553573608, 0.051728855818510056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15256483852863312, 0.8474349975585938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8711318373680115, 0.04994085431098938, 0.07892734557390213, 0.0, 0.0, 0.0, 
[... large raw numeric output omitted: nested float arrays of 20-element rows with a lower-triangular (causal) zero pattern and rows summing to ~1, apparently per-layer, per-head attention-weight matrices serialized in a notebook output cell; no human-readable caption, labels, or code in this span ...]
0.04130468890070915, 0.7068538665771484, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.049693867564201355, 0.040712278336286545, 0.011129319667816162, 0.08677691221237183, 0.24132831394672394, 0.028864668682217598, 0.04710082337260246, 0.028962818905711174, 0.46543097496032715, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0256647989153862, 0.026453843340277672, 0.1064542606472969, 0.07867259532213211, 0.03285365179181099, 0.056291256099939346, 0.026517342776060104, 0.014768523164093494, 0.6323237419128418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.07950135320425034, 0.04906205087900162, 0.09099037200212479, 0.10450085997581482, 0.06846266984939575, 0.21755923330783844, 0.14818403124809265, 0.14456483721733093, 0.0971745029091835, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0029191188514232635, 0.0026039828080683947, 0.000987313687801361, 0.027727283537387848, 0.007311245426535606, 0.0033244043588638306, 0.023969389498233795, 0.00596341909840703, 0.9251939058303833, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5123088955879211, 0.04897892847657204, 0.0054785809479653835, 0.037157464772462845, 0.033040400594472885, 0.0287709329277277, 0.020658228546380997, 0.005767439026385546, 0.3078390061855316, 0.0, 0.027314670383930206, 0.02143898233771324, 0.1116434708237648, 0.006578116212040186, 0.20446842908859253, 0.3867157995700836, 0.054494183510541916, 0.09778231382369995, 0.08956411480903625, 0.0], [0.02116605080664158, 0.45384567975997925, 0.5185156464576721, 0.002682638820260763, 0.000782136688940227, 0.0011598592391237617, 9.265153494197875e-05, 0.0013774167746305466, 0.00037764458102174103, 0.0, 0.09418193250894547, 0.7071846127510071, 0.05323847755789757, 0.0077135805040597916, 0.01789833791553974, 0.010848474688827991, 0.0020562252029776573, 0.01705808937549591, 0.08982021361589432, 0.0], [0.06292606890201569, 0.18043853342533112, 0.7547035217285156, 0.0011577572440728545, 6.746899998688605e-06, 0.00012623713701032102, 8.539375994587317e-05, 0.0005039210664108396, 5.15973697474692e-05, 0.0, 0.005751691292971373, 0.01031999196857214, 0.8884198069572449, 0.00210022390820086, 0.0066058398224413395, 0.019834432750940323, 0.002143828198313713, 0.02793465554714203, 0.036889322102069855, 0.0], [0.0030907110776752234, 0.0035500964149832726, 0.4530283808708191, 0.5221376419067383, 0.009557867422699928, 0.008033420890569687, 0.0001996932114707306, 0.0003745300055015832, 2.7415442673373036e-05, 0.0, 0.027659546583890915, 0.03931494802236557, 0.10616040229797363, 0.011142275296151638, 0.1017894372344017, 0.30847156047821045, 0.12201698124408722, 0.05519269034266472, 0.22825226187705994, 0.0], [0.00426989421248436, 0.0004077394842170179, 0.010691642761230469, 0.016011668369174004, 0.5530933737754822, 0.12423280626535416, 0.0053755901753902435, 0.28551235795021057, 0.00040478314622305334, 0.0, 0.037639543414115906, 0.062483835965394974, 0.050776157528162, 0.012697378173470497, 0.27911704778671265, 0.19993652403354645, 0.14870049059391022, 0.10304640233516693, 0.10560261458158493, 0.0], [0.00216855201870203, 0.009595326147973537, 0.007803121581673622, 0.04625817388296127, 0.24702508747577667, 0.2669595181941986, 0.024053409695625305, 0.24639348685741425, 0.14974308013916016, 0.0, 0.007995839230716228, 0.008397839032113552, 0.03270075097680092, 0.004312656354159117, 0.03775893524289131, 0.3733556568622589, 0.3424486219882965, 0.012857009656727314, 0.18017242848873138, 0.0], [1.953603486981592e-06, 5.726202516598278e-07, 
5.0084551617146644e-08, 6.896097602293594e-06, 0.0001788837107596919, 0.0027895711828023195, 0.9969833493232727, 1.1296948287053965e-05, 2.7187752493773587e-05, 0.0, 0.012142053805291653, 0.007298614829778671, 0.016982076689600945, 0.02473442070186138, 0.08738671243190765, 0.033574704080820084, 0.27830857038497925, 0.033199213445186615, 0.5063735842704773, 0.0], [0.00041725789196789265, 0.0019265476148575544, 6.523763295263052e-05, 0.00018337361689191312, 0.011946662329137325, 0.04555974155664444, 0.15744170546531677, 0.025624049827456474, 0.7568355202674866, 0.0, 0.02729739435017109, 0.05135440081357956, 0.03332214429974556, 0.02499799057841301, 0.11955489963293076, 0.020848069339990616, 0.017926985397934914, 0.01858661323785782, 0.6861116290092468, 0.0], [0.0004706757317762822, 0.0027324894908815622, 0.0007427418022416532, 0.00934627279639244, 0.17134670913219452, 0.030644211918115616, 0.08413954824209213, 0.2513456642627716, 0.4492316246032715, 0.0, 0.018110578879714012, 0.011406980454921722, 0.0018257799092680216, 0.025524618104100227, 0.3885835111141205, 0.010744227096438408, 0.008441396057605743, 0.003679890651255846, 0.5316829681396484, 0.0], [0.0002758087939582765, 0.0016877831658348441, 2.4452297111565713e-06, 0.0004533462051767856, 0.001545731327496469, 0.008134560659527779, 0.010873721912503242, 0.026235109195113182, 0.950791597366333, 0.0, 0.02325628325343132, 0.013795747421681881, 0.0823512151837349, 0.0021813653875142336, 0.03511650115251541, 0.0814405307173729, 0.02589382231235504, 0.14330172538757324, 0.5926627516746521, 0.0]], [[0.3116825819015503, 0.20666195452213287, 0.1363646388053894, 0.07141851633787155, 0.029045483097434044, 0.04730900749564171, 0.037391580641269684, 0.10128220915794373, 0.05884409323334694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.16883184015750885, 0.1785622239112854, 0.23318006098270416, 0.05258537083864212, 0.10740725696086884, 0.09185276180505753, 0.022670285776257515, 0.08943870663642883, 0.05547139048576355, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.13844002783298492, 0.030681077390909195, 0.12107612937688828, 0.007168712094426155, 0.05214103311300278, 0.10045275092124939, 0.006991118658334017, 0.2955043315887451, 0.2475447952747345, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.17306901514530182, 0.19963578879833221, 0.148520827293396, 0.2046978771686554, 0.06720028817653656, 0.019652126356959343, 0.05792365223169327, 0.07665500044822693, 0.05264541134238243, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15115797519683838, 0.06555280834436417, 0.02683498151600361, 0.20794028043746948, 0.17434173822402954, 0.11980342864990234, 0.04239796847105026, 0.03961418569087982, 0.1723567098379135, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.022970011457800865, 0.003322059055790305, 0.03333919122815132, 0.03161727264523506, 0.3007935583591461, 0.20675496757030487, 0.037206847220659256, 0.30415406823158264, 0.059842076152563095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.029802093282341957, 0.006723356898874044, 0.00844306219369173, 0.09911961853504181, 0.2257867157459259, 0.22737178206443787, 0.28318148851394653, 0.05687837302684784, 0.0626935064792633, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.008192392997443676, 0.002076511736959219, 0.010627061128616333, 0.01573592983186245, 0.01893553137779236, 0.042316026985645294, 0.02403445728123188, 0.868257999420166, 0.009823988191783428, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.08033955097198486, 0.18780502676963806, 0.02817567251622677, 0.041370414197444916, 0.02824225462973118, 0.038844767957925797, 0.11732563376426697, 0.06162348762154579, 0.4162730574607849, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0563269779086113, 0.007248359732329845, 0.008029782213270664, 0.003040844574570656, 0.007221699226647615, 0.01730128563940525, 0.050128430128097534, 0.3587413430213928, 0.491961270570755, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23423711955547333, 0.3770143389701843, 0.036408282816410065, 0.05249679461121559, 0.12246986478567123, 0.07398416101932526, 0.040104977786540985, 0.05500312149524689, 0.008280999027192593, 0.0, 0.011912941001355648, 0.006341524887830019, 0.1334817260503769, 0.017931688576936722, 0.005569889210164547, 0.7441595792770386, 0.0258712787181139, 0.034265827387571335, 0.020465616136789322, 0.0], [0.032638370990753174, 0.0975130945444107, 0.07180608063936234, 0.1075734794139862, 0.025216424837708473, 0.0218534916639328, 0.0376754030585289, 0.19710108637809753, 0.40862247347831726, 0.0, 0.03743305802345276, 0.05229289084672928, 0.3549361228942871, 0.028500670567154884, 0.01974724419414997, 0.29288655519485474, 0.08050932735204697, 0.06582070142030716, 0.06787342578172684, 0.0], [0.06665100902318954, 0.01976630836725235, 0.13041609525680542, 0.1772802174091339, 0.07561768591403961, 0.0061133550480008125, 0.022857116535305977, 0.3516288995742798, 0.14966930449008942, 0.0, 0.025082573294639587, 0.10057684034109116, 0.7856844663619995, 0.01178921852260828, 0.0010154875926673412, 0.02595749869942665, 0.008632739074528217, 0.006036050152033567, 0.0352250337600708, 0.0], [0.007724895142018795, 0.007534464355558157, 0.020593322813510895, 0.32147932052612305, 0.059838466346263885, 0.017819387838244438, 0.13181470334529877, 0.15524767339229584, 0.27794769406318665, 0.0, 0.010372502729296684, 0.023954369127750397, 0.18692812323570251, 0.03930393233895302, 0.004741673823446035, 0.46527597308158875, 0.1267295777797699, 0.048278260976076126, 0.09441567957401276, 0.0], [0.0028419073205441236, 0.0006648481939919293, 0.0018234169110655785, 0.01609039306640625, 0.0009005893371067941, 0.09726841002702713, 0.11035522073507309, 0.6978457570075989, 0.07220931351184845, 0.0, 0.0031391805969178677, 0.006868112366646528, 0.1369999200105667, 0.013019833713769913, 0.008593270555138588, 0.6626507639884949, 0.07777946442365646, 0.052107103168964386, 0.03884238004684448, 0.0], [0.0027223427314311266, 0.011157790198922157, 0.0052745467983186245, 0.00438346853479743, 0.010011360049247742, 0.38358205556869507, 0.2294219732284546, 0.057655833661556244, 0.29579076170921326, 0.0, 0.005276903510093689, 0.026354510337114334, 0.056238383054733276, 0.03191604092717171, 0.025259410962462425, 0.3898610472679138, 0.2790180444717407, 0.028249284252524376, 0.1578262895345688, 0.0], [0.0007371494430117309, 0.00023907626746222377, 8.450529276160523e-05, 0.002850313438102603, 0.002168564358726144, 0.02191324159502983, 0.9514637589454651, 0.0002505093871150166, 0.02029278129339218, 0.0, 0.001124523114413023, 0.0035109275486320257, 0.0021898215636610985, 0.043262600898742676, 0.0267842635512352, 0.03029855713248253, 0.5000982284545898, 0.0056237452663481236, 0.38710734248161316, 0.0], [0.004062887746840715, 0.007024500984698534, 0.007399421185255051, 0.011675640009343624, 0.0680280476808548, 0.02557964250445366, 0.07043837755918503, 0.01946563646197319, 0.7863259315490723, 0.0, 0.00020160828717052937, 0.0004381221951916814, 0.010444902814924717, 
0.010398894548416138, 0.10143585503101349, 0.23107938468456268, 0.09920267760753632, 0.019364865496754646, 0.5274338126182556, 0.0], [0.0007200208492577076, 0.0024253681767731905, 0.029840486124157906, 0.00014908696175552905, 0.11268167197704315, 0.03336171433329582, 0.007834793999791145, 0.08127990365028381, 0.7317068576812744, 0.0, 0.00019610628078226, 0.00041535915806889534, 0.009462974965572357, 0.005905983969569206, 0.29870083928108215, 0.24092966318130493, 0.11132201552391052, 0.05168075114488602, 0.2813863754272461, 0.0], [0.008707311004400253, 0.08642490208148956, 0.11372587829828262, 0.004973042756319046, 0.13256150484085083, 0.16557356715202332, 0.0817113071680069, 0.006879410706460476, 0.3994430899620056, 0.0, 0.004135515075176954, 0.014277483336627483, 0.15738952159881592, 0.003396045882254839, 0.01641785353422165, 0.07296160608530045, 0.02388397790491581, 0.024013018235564232, 0.683525025844574, 0.0]]]], \"top_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\", \"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"], \"bot_text\": [\"It\", \"is\", \"nice\", \"to\", \"learn\", \"new\", \"things\", \"today\", \"!\", \"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"]}, \"inp_inp\": {\"att\": [[[[0.05334341153502464, 0.025828205049037933, 0.062369391322135925, 0.043252814561128616, 0.4045393764972687, 0.06697215139865875, 0.09001608937978745, 0.14983074367046356, 0.10384786874055862, 0.0], [0.11816457659006119, 0.03106253407895565, 0.01979171112179756, 0.16624291241168976, 0.3321376442909241, 0.020051123574376106, 0.08730963617563248, 0.18211135268211365, 0.04312858730554581, 0.0], [0.05936884880065918, 0.02174757793545723, 0.016160180792212486, 0.010601435787975788, 0.43925121426582336, 0.03876951336860657, 0.19815810024738312, 0.07065817713737488, 0.14528508484363556, 0.0], [0.15478025376796722, 0.16446512937545776, 0.0578744001686573, 0.21637752652168274, 0.03835854306817055, 0.09130414575338364, 0.11191156506538391, 0.08360221982002258, 0.08132638782262802, 0.0], [0.2183060646057129, 0.1704275906085968, 0.0827711746096611, 0.1202380359172821, 0.05203341320157051, 0.05958092212677002, 0.12280035018920898, 0.09366822242736816, 0.08017415553331375, 0.0], [0.05084313824772835, 0.026207493618130684, 0.13631564378738403, 0.012270472943782806, 0.16236551105976105, 0.02548854425549507, 0.03909383341670036, 0.03172134608030319, 0.5156941413879395, 0.0], [0.03615221381187439, 0.04799472168087959, 0.04255519434809685, 0.04762651398777962, 0.5117892622947693, 0.016304347664117813, 0.005770198069512844, 0.10897397249937057, 0.18283340334892273, 0.0], [0.03243544325232506, 0.025252558290958405, 0.11733424663543701, 0.0250592939555645, 0.20289097726345062, 0.08240236341953278, 0.18285907804965973, 0.011341268196702003, 0.3204246759414673, 0.0], [0.22355543076992035, 0.1260528564453125, 0.03741241991519928, 0.16813479363918304, 0.09858733415603638, 0.035831648856401443, 0.16361697018146515, 0.07236126810312271, 0.07444748282432556, 0.0], [0.08996112644672394, 0.0921943336725235, 0.22672457993030548, 0.12702998518943787, 0.05907799303531647, 0.10712798684835434, 0.16789256036281586, 0.055181413888931274, 0.07481010258197784, 0.0]], [[0.040477100759744644, 0.20988762378692627, 0.4869004786014557, 0.03505674749612808, 0.0558856800198555, 0.025423096492886543, 0.12231241166591644, 0.007062799762934446, 
0.016993943601846695, 0.0], [0.8996549844741821, 0.02599872276186943, 0.049097247421741486, 0.0040262676775455475, 0.0039152717217803, 0.0049644638784229755, 0.010553319938480854, 0.001352570834569633, 0.0004369009402580559, 0.0], [0.33065715432167053, 0.2687782049179077, 0.03312753140926361, 0.22958999872207642, 0.01851547136902809, 0.046473052352666855, 0.053183481097221375, 0.007113412953913212, 0.012561764568090439, 0.0], [0.1589452475309372, 0.47470128536224365, 0.12878550589084625, 0.14158962666988373, 0.04442765936255455, 0.022274963557720184, 0.013780632056295872, 0.0024951419327408075, 0.012999956496059895, 0.0], [0.2559169828891754, 0.033451542258262634, 0.15095548331737518, 0.024318046867847443, 0.10824166238307953, 0.03234097361564636, 0.36475417017936707, 0.012823408469557762, 0.017197895795106888, 0.0], [0.021462664008140564, 0.010474847629666328, 0.007213775999844074, 0.02227940410375595, 0.21737068891525269, 0.4960675537586212, 0.014628118835389614, 0.20502059161663055, 0.005482145119458437, 0.0], [0.06734316051006317, 0.09532227367162704, 0.1127309575676918, 0.009542002342641354, 0.0678786113858223, 0.12933993339538574, 0.03809814900159836, 0.44453269243240356, 0.035212237387895584, 0.0], [0.10458365827798843, 0.02846018597483635, 0.029760979115962982, 0.014774680137634277, 0.022077379748225212, 0.1553817093372345, 0.3539015054702759, 0.19523507356643677, 0.09582491964101791, 0.0], [0.021077070385217667, 0.010932122357189655, 0.05088815093040466, 0.028641115874052048, 0.0881260335445404, 0.12014731019735336, 0.3900885581970215, 0.09544514119625092, 0.1946544349193573, 0.0], [0.02552945166826248, 0.05594164505600929, 0.045791901648044586, 0.093170166015625, 0.03584437444806099, 0.0969511866569519, 0.18585819005966187, 0.17433671653270721, 0.28657644987106323, 0.0]], [[0.18220090866088867, 0.25508272647857666, 0.2721964120864868, 0.04886331781744957, 0.010257811285555363, 0.07344724237918854, 0.08866558223962784, 0.037977367639541626, 0.0313086174428463, 0.0], [0.5722172260284424, 0.09567929804325104, 0.1448327898979187, 0.033306267112493515, 0.0031244128476828337, 0.020944159477949142, 0.012691132724285126, 0.061001092195510864, 0.05620381608605385, 0.0], [0.049244701862335205, 0.5266616344451904, 0.27518483996391296, 0.09334208071231842, 0.005858665332198143, 0.005467486567795277, 0.02565312758088112, 0.005746132228523493, 0.012841282412409782, 0.0], [0.13445906341075897, 0.13356590270996094, 0.6041688919067383, 0.01878039538860321, 0.06342840194702148, 0.03677675500512123, 0.008389262482523918, 0.0002739423362072557, 0.00015757972141727805, 0.0], [0.03273050859570503, 0.0697193592786789, 0.19719526171684265, 0.41500693559646606, 0.13721567392349243, 0.05743291601538658, 0.06517775356769562, 0.010865128599107265, 0.014656689018011093, 0.0], [0.031571000814437866, 0.014337136410176754, 0.06860436499118805, 0.09357307106256485, 0.10011686384677887, 0.07827721536159515, 0.5866308212280273, 0.011440092697739601, 0.015449290163815022, 0.0], [0.006158333271741867, 0.001533387927338481, 0.05427416041493416, 0.005477452650666237, 0.02694696933031082, 0.8134917616844177, 0.02643686905503273, 0.050265438854694366, 0.015415593050420284, 0.0], [0.008847472257912159, 0.0066053420305252075, 0.036443497985601425, 0.021455924957990646, 0.019254589453339577, 0.11543811857700348, 0.1138116791844368, 0.20307059586048126, 0.4750728905200958, 0.0], [0.017603449523448944, 0.008448019623756409, 0.004260394722223282, 0.006066101603209972, 0.013470137491822243, 0.01876576989889145, 
0.16350960731506348, 0.1980665624141693, 0.5698099732398987, 0.0], [0.10490093380212784, 0.014168650843203068, 0.0247807614505291, 0.018330294638872147, 0.009348674677312374, 0.02287398651242256, 0.032268356531858444, 0.10571902245283127, 0.6676092147827148, 0.0]], [[0.2071455419063568, 0.637531578540802, 0.06835082173347473, 0.011966697871685028, 0.0017193991225212812, 0.04911382868885994, 0.009478496387600899, 0.008040529675781727, 0.00665308628231287, 0.0], [0.07411027699708939, 0.15093472599983215, 0.2656005620956421, 0.05758262053132057, 0.05194409564137459, 0.23625947535037994, 0.019166678190231323, 0.04010465368628502, 0.10429693013429642, 0.0], [0.1540999412536621, 0.10598444193601608, 0.22474077343940735, 0.32441702485084534, 0.1116243302822113, 0.054135363548994064, 0.008848286233842373, 0.004088098648935556, 0.012061581946909428, 0.0], [0.019440434873104095, 0.00560638727620244, 0.0035774046555161476, 0.0888679027557373, 0.7120485901832581, 0.14891275763511658, 0.011600993573665619, 0.008666431531310081, 0.0012791723711416125, 0.0], [0.08580154180526733, 0.02444172091782093, 0.08060747385025024, 0.05198557302355766, 0.2700504660606384, 0.34216371178627014, 0.11280739307403564, 0.006445358972996473, 0.02569655328989029, 0.0], [0.0424385629594326, 0.029667967930436134, 0.006252861116081476, 0.020168066024780273, 0.03000665083527565, 0.2812231779098511, 0.49279165267944336, 0.09351769089698792, 0.003933228086680174, 0.0], [0.006467411294579506, 0.0076894015073776245, 0.008325580507516861, 0.0010907554533332586, 0.01040297094732523, 0.19462232291698456, 0.013263629749417305, 0.24681615829467773, 0.5113216042518616, 0.0], [0.028696376830339432, 0.014982450753450394, 0.011884906329214573, 0.0011242942418903112, 0.01692844182252884, 0.12885364890098572, 0.028225399553775787, 0.6451764106750488, 0.12412811070680618, 0.0], [0.16117365658283234, 0.06794824451208115, 0.06173194944858551, 0.00451233983039856, 0.05306624248623848, 0.0510348416864872, 0.04402391240000725, 0.12432018667459488, 0.4321887195110321, 0.0], [0.1690559983253479, 0.043453093618154526, 0.036818861961364746, 0.017293656244874, 0.11775903403759003, 0.07970321178436279, 0.043801818042993546, 0.06849095970392227, 0.4236232340335846, 0.0]], [[0.03085354156792164, 0.12322185933589935, 0.13651973009109497, 0.050716523081064224, 0.2999139726161957, 0.09802427887916565, 0.06620478630065918, 0.0782310962677002, 0.11631430685520172, 0.0], [0.06789751350879669, 0.058182138949632645, 0.3129631578922272, 0.04353875666856766, 0.09142065048217773, 0.10271093249320984, 0.026392055675387383, 0.09630800783634186, 0.2005866914987564, 0.0], [0.07152411341667175, 0.3454192876815796, 0.11299439519643784, 0.18012462556362152, 0.07151429355144501, 0.052652161568403244, 0.0567985400557518, 0.09459780901670456, 0.014374655671417713, 0.0], [0.10420235246419907, 0.21845531463623047, 0.19832336902618408, 0.022119704633951187, 0.13572701811790466, 0.07722532749176025, 0.0508468933403492, 0.045597679913043976, 0.14750221371650696, 0.0], [0.07030870020389557, 0.10706955939531326, 0.02791348285973072, 0.02260597050189972, 0.12725059688091278, 0.07336997240781784, 0.26662203669548035, 0.16957008838653564, 0.13528966903686523, 0.0], [0.05156806856393814, 0.04327721148729324, 0.07664787024259567, 0.06931594759225845, 0.1889398992061615, 0.09515503793954849, 0.07227510958909988, 0.2641449272632599, 0.13867592811584473, 0.0], [0.02184019424021244, 0.11184182018041611, 0.36672860383987427, 0.013787303119897842, 0.07600502669811249, 0.0389828234910965, 
0.040494974702596664, 0.12485849112272263, 0.20546066761016846, 0.0], [0.013738485053181648, 0.05187288299202919, 0.03463537245988846, 0.03627979755401611, 0.048659998923540115, 0.02440205216407776, 0.07256433367729187, 0.024731382727622986, 0.6931155323982239, 0.0], [0.02671198360621929, 0.4013687074184418, 0.01132842618972063, 0.14022575318813324, 0.026275552809238434, 0.08107840269804001, 0.04189194366335869, 0.25432130694389343, 0.0167979933321476, 0.0], [0.14228780567646027, 0.07866450399160385, 0.08390624076128006, 0.09396661072969437, 0.087954580783844, 0.14498625695705414, 0.13517630100250244, 0.1169552430510521, 0.11610251665115356, 0.0]], [[0.02165721170604229, 0.018354326486587524, 0.6383510828018188, 0.042513273656368256, 0.10956817120313644, 0.10717540234327316, 0.030344119295477867, 0.015826348215341568, 0.01621006615459919, 0.0], [0.4647374749183655, 0.07284841686487198, 0.28081396222114563, 0.014013433828949928, 0.03169411048293114, 0.02214456908404827, 0.058711059391498566, 0.036629818379879, 0.01840737834572792, 0.0], [0.07372704148292542, 0.12858736515045166, 0.4501189887523651, 0.054217785596847534, 0.07096204906702042, 0.05748127028346062, 0.06541819125413895, 0.04703349620103836, 0.05245373025536537, 0.0], [0.04684445261955261, 0.019098779186606407, 0.008431704714894295, 0.0010175607167184353, 0.9129327535629272, 0.004866998642683029, 0.006678053177893162, 8.096762758214027e-05, 4.903498847852461e-05, 0.0], [0.08239725232124329, 0.02813413366675377, 0.16611848771572113, 0.1532817929983139, 0.07408729940652847, 0.10856874287128448, 0.047752734273672104, 0.02563621662557125, 0.31402355432510376, 0.0], [0.17959792912006378, 0.02262653037905693, 0.10724494606256485, 0.022216446697711945, 0.1862414926290512, 0.14705143868923187, 0.15912717580795288, 0.15293282270431519, 0.02296125516295433, 0.0], [0.038375359028577805, 0.0038853511214256287, 0.06201936677098274, 0.005828780122101307, 0.22059503197669983, 0.36631014943122864, 0.020396992564201355, 0.20976856350898743, 0.07282061129808426, 0.0], [0.014258276671171188, 0.005652762018144131, 0.025611618533730507, 0.15294744074344635, 0.06760217249393463, 0.2498260736465454, 0.1669282466173172, 0.2265811711549759, 0.09059228003025055, 0.0], [0.15833799540996552, 0.1228356659412384, 0.10147804021835327, 0.0284584891051054, 0.27955442667007446, 0.06763719022274017, 0.08874277770519257, 0.1152903363108635, 0.037665050476789474, 0.0], [0.09844867885112762, 0.0919492095708847, 0.028445947915315628, 0.03726689890027046, 0.035665158182382584, 0.06817072629928589, 0.29930955171585083, 0.09819743037223816, 0.2425464242696762, 0.0]], [[0.02519470639526844, 0.006357265170663595, 0.14269335567951202, 0.023629529401659966, 0.3124701976776123, 0.13565225899219513, 0.2595662772655487, 0.07959114015102386, 0.014845297671854496, 0.0], [0.04550129547715187, 0.011541971005499363, 0.1165909469127655, 0.02512240968644619, 0.01843150518834591, 0.05711649730801582, 0.44489097595214844, 0.033205363899469376, 0.24759893119335175, 0.0], [0.13528011739253998, 0.06777236610651016, 0.14429129660129547, 0.04697401076555252, 0.1738385707139969, 0.014099549502134323, 0.38417065143585205, 0.01158357597887516, 0.02199004776775837, 0.0], [0.21356959640979767, 0.1638900637626648, 0.10595463216304779, 0.06925727427005768, 0.167257159948349, 0.04259340837597847, 0.10967854410409927, 0.03570139408111572, 0.09209771454334259, 0.0], [0.20140984654426575, 0.04755665361881256, 0.15174560248851776, 0.11619894206523895, 0.21928974986076355, 0.07600340992212296, 
0.05828682705760002, 0.10010629147291183, 0.029402663931250572, 0.0], [0.024259669706225395, 0.02116699516773224, 0.21201731264591217, 0.019622934982180595, 0.4893963038921356, 0.021304504945874214, 0.16948339343070984, 0.022949064150452614, 0.01979990489780903, 0.0], [0.022248759865760803, 0.01183647196739912, 0.0633181631565094, 0.029095010831952095, 0.07090882211923599, 0.4614315629005432, 0.020150773227214813, 0.18720205128192902, 0.1338084638118744, 0.0], [0.003461656626313925, 0.01603432185947895, 0.009874427691102028, 0.014947548508644104, 0.2953553795814514, 0.3502987027168274, 0.08878874033689499, 0.036094941198825836, 0.18514421582221985, 0.0], [0.005101516842842102, 0.022985950112342834, 0.007523353211581707, 0.026773063465952873, 0.01009095273911953, 0.014858697541058064, 0.15149906277656555, 0.028601571917533875, 0.7325656414031982, 0.0], [0.12995873391628265, 0.07769863307476044, 0.02032659947872162, 0.13720010221004486, 0.011713794432580471, 0.054615918546915054, 0.23920413851737976, 0.13190706074237823, 0.19737498462200165, 0.0]], [[0.21207179129123688, 0.11920439451932907, 0.4251355528831482, 0.014464439824223518, 0.20776884257793427, 0.01428140513598919, 0.0027938869316130877, 0.001743048895150423, 0.002536489861086011, 0.0], [0.046175818890333176, 0.026793524622917175, 0.8552185297012329, 0.04517081379890442, 0.010388500988483429, 0.004191457759588957, 0.0036751439329236746, 0.0013485046802088618, 0.007037981878966093, 0.0], [0.013186579570174217, 0.020899420604109764, 0.6900137662887573, 0.0480119027197361, 0.15360434353351593, 0.02344118244946003, 0.03952033817768097, 0.0038994532078504562, 0.007422822527587414, 0.0], [0.006273405160754919, 0.00015674144378863275, 0.000751359446439892, 0.00447711581364274, 0.9859057664871216, 0.002212332095950842, 0.00014360185014083982, 4.957199053023942e-05, 2.9913859179941937e-05, 0.0], [0.001047183177433908, 0.0003636489564087242, 0.009283728897571564, 0.016805388033390045, 0.42387446761131287, 0.4776095747947693, 0.06253702938556671, 0.005590841174125671, 0.002888289513066411, 0.0], [0.0018647151300683618, 0.0002549054042901844, 2.6050107408082113e-05, 2.586200753285084e-05, 0.0024472770746797323, 0.006814199965447187, 0.9776560664176941, 0.010138182900846004, 0.000773087958805263, 0.0], [0.047241877764463425, 0.006076885852962732, 0.04534892365336418, 0.00081661093281582, 0.087706059217453, 0.41394293308258057, 0.21876952052116394, 0.17005810141563416, 0.0100388890132308, 0.0], [0.0019138919888064265, 0.006189406383782625, 0.010115097276866436, 8.508542669005692e-05, 0.008424345403909683, 0.003492203773930669, 0.13495568931102753, 0.4890870749950409, 0.34573695063591003, 0.0], [0.016032341867685318, 0.005025702994316816, 0.009520799852907658, 0.0008855267078615725, 0.026489384472370148, 0.0020503124687820673, 0.032939448952674866, 0.09461060166358948, 0.8124459385871887, 0.0], [0.25683313608169556, 0.02960006147623062, 0.11211041361093521, 0.09736908972263336, 0.17546677589416504, 0.032068025320768356, 0.017857572063803673, 0.025635067373514175, 0.25305992364883423, 0.0]]], [[[0.10487863421440125, 0.7106320858001709, 0.1635318249464035, 0.011256101541221142, 0.0012767312582582235, 0.00310636218637228, 0.0013001860352233052, 0.0012553841806948185, 0.002762428717687726, 0.0], [0.021650908514857292, 0.0030605364590883255, 0.6595932245254517, 0.2987315356731415, 0.012945608235895634, 0.0028472936246544123, 7.557096250820905e-05, 0.00029089683084748685, 0.0008047237643040717, 0.0], [0.014272261410951614, 0.040512338280677795, 
0.8595607280731201, 0.038314104080200195, 0.037397123873233795, 0.006795509252697229, 0.001303989440202713, 0.001011757180094719, 0.0008321924251504242, 0.0], [0.031783342361450195, 0.007319662719964981, 0.7663278579711914, 0.0010118860518559813, 0.1672297865152359, 0.02513650804758072, 0.000853335193824023, 0.0002817189379129559, 5.600590884569101e-05, 0.0], [0.002136597875505686, 0.00037253598566167057, 0.07588302344083786, 0.2252500057220459, 0.33551687002182007, 0.35751965641975403, 0.0027331046294420958, 0.00018122239271178842, 0.0004068210837431252, 0.0], [0.0004353485128376633, 0.0003557991876732558, 0.0003262429090682417, 0.003819868667051196, 0.33603885769844055, 0.2681770920753479, 0.3838857412338257, 0.0068349516950547695, 0.00012614508159458637, 0.0], [6.71677480568178e-05, 3.9912600186653435e-05, 0.00047830803669057786, 5.937727837590501e-05, 0.0014537296956405044, 0.6413838863372803, 0.29047340154647827, 0.06565171480178833, 0.0003929881495423615, 0.0], [0.00047039391938596964, 0.0007891620043665171, 0.0007817292353138328, 0.0010076714679598808, 0.00965806283056736, 0.003733346238732338, 0.35330116748809814, 0.5722718238830566, 0.05798657611012459, 0.0], [0.006178696174174547, 0.009340841323137283, 0.0005589249776676297, 0.005146770738065243, 0.0033258567564189434, 0.0016933922888711095, 0.06414961069822311, 0.3291752338409424, 0.5804308652877808, 0.0], [0.006624103523790836, 0.001978900283575058, 0.0081730792298913, 0.0030846702866256237, 0.0018904987955465913, 0.0014340116176754236, 0.005187559872865677, 0.029854312539100647, 0.9417726993560791, 0.0]], [[0.17277710139751434, 0.13871003687381744, 0.020699918270111084, 0.04190761595964432, 0.17760643362998962, 0.1702892780303955, 0.16168300807476044, 0.10000763088464737, 0.01631900854408741, 0.0], [0.9987638592720032, 0.0011447033612057567, 1.5495901607209817e-05, 2.3805538096333123e-10, 1.1166920899086108e-07, 4.81009180930414e-07, 2.3257289285538718e-05, 3.4320622944505885e-05, 1.812833215808496e-05, 0.0], [0.029870687052607536, 0.9668734669685364, 0.0031853404361754656, 3.7420595617732033e-06, 1.0481591772304455e-07, 4.711453893690987e-09, 4.051101996083162e-07, 1.359390239485947e-06, 6.518688314827159e-05, 0.0], [2.9839180569979362e-05, 0.0008244949858635664, 0.9990562796592712, 6.778111855965108e-05, 2.14482715819031e-05, 5.3428358959273226e-11, 7.202954205309808e-11, 7.697720239008277e-11, 1.422941551254553e-07, 0.0], [9.680035873316228e-05, 4.205659934086725e-05, 0.0021876851096749306, 0.9926192164421082, 0.0050464412197470665, 7.330636890401365e-06, 4.7689670878980905e-08, 8.238330573284713e-10, 9.979119397485192e-08, 0.0], [5.136659183335723e-06, 6.750806136324172e-08, 8.17252839624416e-06, 0.008817464113235474, 0.9640147089958191, 0.027066770941019058, 8.771067950874567e-05, 3.571775764044105e-09, 3.5257423647294672e-09, 0.0], [5.115869043947896e-07, 1.0059281407848175e-08, 1.3136859422502312e-07, 9.641905052149013e-08, 0.001335342414677143, 0.9957214593887329, 0.0029362423811107874, 7.136273325158982e-06, 1.1521567699901425e-08, 0.0], [3.561131961760111e-06, 2.727877870256634e-07, 8.369554507225985e-07, 1.214864764342849e-09, 4.873449597653234e-06, 0.024909861385822296, 0.9680997133255005, 0.006879042834043503, 0.00010210835171164945, 0.0], [0.00021467455371748656, 9.040503209689632e-05, 3.369562909938395e-05, 1.9265097961351785e-08, 9.727973520057276e-07, 2.4095537810353562e-05, 0.0040859803557395935, 0.8618475794792175, 0.1337023377418518, 0.0], [2.289768872287823e-06, 6.284429400693625e-05, 
0.0001214230724144727, 2.809870807141124e-07, 1.092972157223926e-09, 1.0671180605825725e-09, 1.2438744079190656e-06, 0.024907555431127548, 0.9749038219451904, 0.0]], [[0.058097392320632935, 0.00935883168131113, 0.04822169989347458, 0.0048278868198394775, 0.191309854388237, 0.28154584765434265, 0.09391050785779953, 0.24126385152339935, 0.07146408408880234, 0.0], [0.10414423793554306, 0.027566324919462204, 0.021727869287133217, 0.033647697418928146, 0.026882247999310493, 0.17782779037952423, 0.05685214698314667, 0.45095938444137573, 0.10039239376783371, 0.0], [0.44215551018714905, 0.049670565873384476, 0.014098896645009518, 0.029011834412813187, 0.01834075152873993, 0.1358453929424286, 0.04072042554616928, 0.2330295443534851, 0.03712712228298187, 0.0], [0.10425814986228943, 0.06979154050350189, 0.036334071308374405, 0.028995294123888016, 0.015532439574599266, 0.1330128014087677, 0.063407763838768, 0.23157192766666412, 0.3170958459377289, 0.0], [0.3384562134742737, 0.055937401950359344, 0.038792647421360016, 0.00819220207631588, 0.03063569962978363, 0.09386011958122253, 0.07227522879838943, 0.30926018953323364, 0.05259038880467415, 0.0], [0.3519401550292969, 0.1823827177286148, 0.06509842723608017, 0.030452275648713112, 0.08377533406019211, 0.09469012171030045, 0.04247477278113365, 0.11751312017440796, 0.03167306259274483, 0.0], [0.3634622097015381, 0.14048337936401367, 0.08374395966529846, 0.038946691900491714, 0.03473563492298126, 0.06442954391241074, 0.019375532865524292, 0.22685663402080536, 0.027966352179646492, 0.0], [0.18070067465305328, 0.04645215719938278, 0.0992647334933281, 0.005799622740596533, 0.47514480352401733, 0.12094692885875702, 0.030788421630859375, 0.025236092507839203, 0.015666494145989418, 0.0], [0.5453059673309326, 0.10054859519004822, 0.01722547970712185, 0.06704734265804291, 0.007780902087688446, 0.07263857871294022, 0.022086072713136673, 0.1394840031862259, 0.027883058413863182, 0.0], [0.15028028190135956, 0.17163224518299103, 0.06043723225593567, 0.10140684247016907, 0.10512865334749222, 0.06778015196323395, 0.06512691080570221, 0.23085294663906097, 0.04735487326979637, 0.0]], [[0.11086989939212799, 0.14517885446548462, 0.17419463396072388, 0.060936953872442245, 0.08783368766307831, 0.11005676537752151, 0.03251044824719429, 0.07983692735433578, 0.19858187437057495, 0.0], [0.16660544276237488, 0.29352903366088867, 0.1008867621421814, 0.023942291736602783, 0.15022507309913635, 0.06581585109233856, 0.02344084158539772, 0.05208655819296837, 0.12346797436475754, 0.0], [0.1683349758386612, 0.22478938102722168, 0.06976605206727982, 0.1032773107290268, 0.16255290806293488, 0.08890064060688019, 0.03925151377916336, 0.023706944659352303, 0.11942004412412643, 0.0], [0.19914905726909637, 0.1368866264820099, 0.178489089012146, 0.11241752654314041, 0.06187256798148155, 0.0768556222319603, 0.01627686619758606, 0.07274915277957916, 0.14530348777770996, 0.0], [0.08000901341438293, 0.20181676745414734, 0.21235129237174988, 0.05340588092803955, 0.12758778035640717, 0.11278047412633896, 0.06906574964523315, 0.08596791326999664, 0.05701539292931557, 0.0], [0.14153669774532318, 0.10432923585176468, 0.09881750494241714, 0.08603313565254211, 0.10391980409622192, 0.06189347058534622, 0.06772381067276001, 0.08503933250904083, 0.25070688128471375, 0.0], [0.06525713205337524, 0.07869093865156174, 0.11366366595029831, 0.044226594269275665, 0.05455174669623375, 0.23646420240402222, 0.09933798015117645, 0.1198185384273529, 0.1879890412092209, 0.0], [0.09450254589319229, 0.027017319574952126, 
0.06480545550584793, 0.10929621011018753, 0.11382008343935013, 0.17441418766975403, 0.11898359656333923, 0.06495486199855804, 0.23220552504062653, 0.0], [0.07681684195995331, 0.0671391412615776, 0.0905177965760231, 0.06064317002892494, 0.06652072072029114, 0.09855856746435165, 0.07360702753067017, 0.13956283032894135, 0.3266339898109436, 0.0], [0.12179998308420181, 0.07977079600095749, 0.08405954390764236, 0.1456507444381714, 0.14551174640655518, 0.07862778753042221, 0.09882251918315887, 0.14300917088985443, 0.1027478501200676, 0.0]], [[0.0261031873524189, 0.9575563073158264, 0.006272038444876671, 0.0037288309540599585, 0.0038619006518274546, 0.0007324732141569257, 0.0005133527447469532, 0.0003637235495261848, 0.0008679544553160667, 0.0], [0.02134888991713524, 0.08473973721265793, 0.6753177642822266, 0.028721673414111137, 0.14432094991207123, 0.027568204328417778, 0.0057298606261610985, 0.004451636224985123, 0.007801060564815998, 0.0], [0.03883299231529236, 0.030284319072961807, 0.5620493292808533, 0.09062989801168442, 0.17362907528877258, 0.08253934979438782, 0.010801085270941257, 0.00978847872465849, 0.0014453904004767537, 0.0], [0.002180949319154024, 0.003013473702594638, 0.16569769382476807, 0.008050205186009407, 0.7580646276473999, 0.061441101133823395, 0.001020166208036244, 0.0001067533012246713, 0.0004249440098647028, 0.0], [0.004150479566305876, 0.00034606645931489766, 0.3802972435951233, 0.06855826079845428, 0.29045602679252625, 0.1767650991678238, 0.06603583693504333, 0.0014808314153924584, 0.011909942142665386, 0.0], [0.006170187145471573, 0.0012396957026794553, 0.0354800671339035, 0.0032299698796123266, 0.03240001201629639, 0.5543311238288879, 0.30418315529823303, 0.051339369267225266, 0.01162647269666195, 0.0], [0.0035115755163133144, 0.0011483307462185621, 0.017956364899873734, 0.003783614607527852, 0.030611976981163025, 0.3673596978187561, 0.20627115666866302, 0.3506667912006378, 0.01869054324924946, 0.0], [0.0021685126703232527, 0.0006909942603670061, 0.010240452364087105, 0.01958688348531723, 0.004634156823158264, 0.11485372483730316, 0.04815557599067688, 0.7050773501396179, 0.0945921242237091, 0.0], [0.049201104789972305, 0.02397306263446808, 0.02337191067636013, 0.31066185235977173, 0.06433572620153427, 0.12544430792331696, 0.0786852017045021, 0.25179895758628845, 0.07252778857946396, 0.0], [0.010841209441423416, 0.0041772774420678616, 0.01548130251467228, 0.036074474453926086, 0.033387064933776855, 0.08192819356918335, 0.04784044623374939, 0.10195028781890869, 0.668319821357727, 0.0]], [[0.005738695617765188, 0.0068999892100691795, 0.4274883270263672, 0.08288666605949402, 0.1445126235485077, 0.04382907599210739, 0.10957401990890503, 0.05347184091806412, 0.1255987584590912, 0.0], [0.0025263649877160788, 0.00471830926835537, 0.13454590737819672, 0.4177793860435486, 0.28839975595474243, 0.029358303174376488, 0.017654288560152054, 0.0047735795378685, 0.10024390369653702, 0.0], [0.009192855097353458, 0.007133236154913902, 0.03149157017469406, 0.1856081485748291, 0.5691666603088379, 0.07386670261621475, 0.029819192364811897, 0.03683711960911751, 0.05688462406396866, 0.0], [0.00297820963896811, 0.0015070328954607248, 0.0025649494491517544, 0.0011051844339817762, 0.04088710993528366, 0.1953955888748169, 0.34000417590141296, 0.3367410898208618, 0.07881659269332886, 0.0], [0.003951869439333677, 0.009354526177048683, 0.007010620087385178, 0.0025927696842700243, 0.09962604194879532, 0.10909298062324524, 0.4455967843532562, 0.15358439087867737, 0.16918975114822388, 0.0], 
[0.0038829154800623655, 0.0036434896755963564, 0.006399825215339661, 0.000760377966798842, 0.010139851830899715, 0.038725122809410095, 0.10014155507087708, 0.48370444774627686, 0.35260239243507385, 0.0], [0.001297087874263525, 0.0014563009608536959, 0.013839880004525185, 0.0004286184557713568, 0.012207024730741978, 0.028704902157187462, 0.046600911766290665, 0.26406532526016235, 0.6313998103141785, 0.0], [0.0033481158316135406, 0.0038099782541394234, 0.0031049775425344706, 0.00033546099439263344, 0.0031272985506802797, 0.008788534440100193, 0.021183660253882408, 0.12157405912876129, 0.8347280025482178, 0.0], [0.3364367187023163, 0.17456969618797302, 0.051038213074207306, 0.006790165323764086, 0.024106895551085472, 0.0694134384393692, 0.02184627763926983, 0.061508405953645706, 0.25429028272628784, 0.0], [0.10536088049411774, 0.07750789821147919, 0.0850178673863411, 0.08725376427173615, 0.2586125433444977, 0.16756391525268555, 0.054291605949401855, 0.030132828280329704, 0.13425879180431366, 0.0]], [[0.034539882093667984, 0.0018589550163596869, 0.9604092836380005, 1.3120608855388127e-05, 2.1815638319822028e-05, 0.00012517283903434873, 8.019943197723478e-05, 0.0021589084062725306, 0.0007928607519716024, 0.0], [7.048832912914804e-07, 1.7815009414334781e-06, 0.9998455047607422, 0.0001518452918389812, 4.1070780554264275e-08, 2.7954746156799715e-11, 9.231376947582692e-12, 9.901777175969073e-09, 2.5545642756696907e-07, 0.0], [6.695767496012195e-08, 2.089915795977504e-07, 0.005368041805922985, 0.9945066571235657, 0.0001248170156031847, 2.304766155702964e-09, 2.762512718579302e-10, 3.973758211373024e-09, 9.372820954922645e-07, 0.0], [5.018761014413675e-13, 1.4841802622529476e-16, 4.663825770023777e-09, 3.820862737313746e-09, 0.9999942183494568, 4.988648925063899e-06, 4.967477167452938e-13, 1.416252587396787e-16, 2.1775358895380023e-16, 0.0], [4.666895758731471e-09, 7.292542437975502e-12, 2.898993545219497e-11, 4.2817244194637283e-10, 0.00027504604076966643, 0.9995728731155396, 0.00015239788626786321, 1.9082661839586734e-10, 2.232514032581706e-13, 0.0], [1.7137297136926577e-10, 5.3312285142048665e-12, 2.2368220760327594e-14, 4.904942142678549e-17, 8.726878775178193e-09, 0.004644036293029785, 0.9953435659408569, 1.324965796811739e-05, 6.982896899598856e-12, 0.0], [4.877224735189145e-10, 1.5497924055196677e-09, 6.021576987036426e-11, 8.955144165463396e-19, 1.7180077889825118e-13, 6.163505759104737e-07, 0.001256544259376824, 0.9987285733222961, 1.4209075743565336e-05, 0.0], [3.25698863434809e-08, 7.313030323530256e-07, 1.412931510458293e-06, 1.1662047555981733e-16, 8.495708612521816e-14, 1.1933978653379251e-13, 1.3303619539328793e-07, 0.01294001005589962, 0.9870572686195374, 0.0], [1.6884889646462398e-06, 2.6281904865754768e-05, 0.001122217159718275, 6.101166945882142e-06, 4.424501298672112e-08, 5.172042264953158e-13, 5.508820136168602e-11, 5.942968346062116e-05, 0.9987838268280029, 0.0], [4.288114359951578e-05, 6.015944563841913e-06, 0.004432132933288813, 0.025997335091233253, 0.000731422973331064, 6.87844434188456e-11, 8.199346692057408e-13, 7.098316245901515e-08, 0.9687905311584473, 0.0]], [[0.02526121959090233, 0.9527671933174133, 0.014345486648380756, 0.0014051493490114808, 0.003839265089482069, 0.00014350644778460264, 0.0006356940139085054, 0.00025237957015633583, 0.0013501241337507963, 0.0], [0.004122408106923103, 0.023777475580573082, 0.9002965688705444, 0.0682864859700203, 0.0017659803852438927, 0.0001271881628781557, 0.00011044178245356306, 0.0001890352723421529, 0.0013242338318377733, 0.0], 
[8.841444650897756e-05, 0.0002895947836805135, 0.06307922303676605, 0.9069769978523254, 0.028407124802470207, 0.000558151863515377, 0.00022284295118879527, 0.00018588549573905766, 0.00019132612214889377, 0.0], [1.889026179924258e-06, 3.9712713260087185e-06, 0.001210480579175055, 0.003201226470991969, 0.8290116786956787, 0.16640713810920715, 0.00015829727635718882, 4.0429063119518105e-06, 9.256136763724498e-07, 0.0], [0.000399262469727546, 5.1438626542221755e-05, 0.0001944842515513301, 0.0007700449787080288, 0.4879837930202484, 0.4847603738307953, 0.025640420615673065, 0.00018376839580014348, 1.6383723050239496e-05, 0.0], [4.30414620495867e-05, 1.017293288896326e-05, 8.407413588429336e-06, 5.451946094581217e-07, 0.000544070964679122, 0.021075371652841568, 0.9573339819908142, 0.0208626389503479, 0.00012169074034318328, 0.0], [0.00043880229350179434, 0.0004488519043661654, 0.000600603292696178, 1.4583132212919736e-07, 3.6701523640658706e-05, 0.010162030346691608, 0.37363454699516296, 0.559087336063385, 0.0555914081633091, 0.0], [0.0010709260823205113, 0.0006920771556906402, 0.0016655249055474997, 0.00010216240480076522, 1.0821948308148421e-05, 2.6151516067329794e-05, 0.01446994487196207, 0.2987785339355469, 0.6831837296485901, 0.0], [0.0002485924051143229, 0.00016839140153024346, 0.019545644521713257, 0.016785046085715294, 0.005671702325344086, 0.00014030851889401674, 0.001185068627819419, 0.04272715002298355, 0.9135279655456543, 0.0], [0.0039028520695865154, 0.0008621322922408581, 0.02400260791182518, 0.35541704297065735, 0.048350416123867035, 0.00013779231812804937, 0.00015075977717060596, 0.0015127401566132903, 0.5656636953353882, 0.0]]], [[[0.09929531812667847, 0.3125585615634918, 0.26699960231781006, 0.036189958453178406, 0.01689508929848671, 0.05626463145017624, 0.014853590168058872, 0.021625356748700142, 0.17531771957874298, 0.0], [0.6598999500274658, 0.04883529245853424, 0.24573534727096558, 0.008949915878474712, 0.008034803904592991, 0.0058951652608811855, 0.001835338887758553, 0.0024289200082421303, 0.018385181203484535, 0.0], [0.28377673029899597, 0.4307016134262085, 0.19275489449501038, 0.05968217924237251, 0.007509235758334398, 0.00627214927226305, 0.0010254314402118325, 0.0010938378982245922, 0.017183959484100342, 0.0], [0.00751571636646986, 0.01881357654929161, 0.9318985342979431, 0.014481762424111366, 0.02105659246444702, 0.0032304797787219286, 0.00013498679618351161, 2.4857494281604886e-05, 0.0028432777617126703, 0.0], [0.08691340684890747, 0.01259385235607624, 0.21131311357021332, 0.15839329361915588, 0.3931293189525604, 0.10845079272985458, 0.004768806044012308, 0.0032348930835723877, 0.021202562376856804, 0.0], [0.029192518442869186, 0.06438057869672775, 0.033022571355104446, 0.04279496520757675, 0.6011855006217957, 0.17385539412498474, 0.03754284232854843, 0.006468524225056171, 0.011557108722627163, 0.0], [0.006125382613390684, 0.006982659921050072, 0.004575703293085098, 0.0037440320011228323, 0.36007580161094666, 0.5409486889839172, 0.0626324936747551, 0.00843171589076519, 0.006483553443104029, 0.0], [0.0017123871948570013, 0.017555760219693184, 0.012620777823030949, 0.00947127677500248, 0.08178496360778809, 0.2538650631904602, 0.19189175963401794, 0.255443274974823, 0.17565478384494781, 0.0], [0.02615528553724289, 0.002552631078287959, 0.01957615464925766, 0.021708596497774124, 0.008856788277626038, 0.021813882514834404, 0.052812058478593826, 0.19690369069576263, 0.6496209502220154, 0.0], [0.004899451043456793, 0.005663626827299595, 0.012920243665575981, 
[... attention-visualization payload elided: nested float arrays of per-layer, per-head
attention weights (each row a softmax distribution over key positions, with a trailing 0.0
padding slot) for the token sequence "It is nice to learn new things today !". The dump
carries "top_text" and "bot_text" token lists followed by an "inp_out" block whose "att"
matrices are lower-triangular (causally masked) distributions over the same tokens; the
serialized data is truncated mid-array in the source ...]
0.03545365110039711, 0.7235170006752014, 0.0], [0.05760658532381058, 0.08793947100639343, 0.053903114050626755, 0.0679689273238182, 0.007038408424705267, 0.007889931090176105, 0.010035911574959755, 0.019540006294846535, 0.6880777478218079, 0.0], [0.045610494911670685, 0.042210742831230164, 0.14248158037662506, 0.03233090415596962, 0.03048519603908062, 0.011738738045096397, 0.014284060336649418, 0.006383211817592382, 0.6744750738143921, 0.0], [0.096277616918087, 0.030696624889969826, 0.10220203548669815, 0.04915016517043114, 0.047845132648944855, 0.05814794450998306, 0.06954183429479599, 0.028650736436247826, 0.5174878835678101, 0.0], [0.009306053631007671, 0.02153283730149269, 0.009718294255435467, 0.005953253246843815, 0.011703923344612122, 0.017902903258800507, 0.011090915650129318, 0.01645584963262081, 0.8963360786437988, 0.0], [0.009895006194710732, 0.026821313425898552, 0.16079027950763702, 0.01761648990213871, 0.01726638339459896, 0.08361288905143738, 0.039622098207473755, 0.14411716163158417, 0.5002583861351013, 0.0]], [[0.0543275885283947, 0.01742306910455227, 0.05347726121544838, 0.18824619054794312, 0.09003543108701706, 0.08433128148317337, 0.1953076422214508, 0.206686869263649, 0.11016455292701721, 0.0], [0.00859006680548191, 0.02184058353304863, 0.02418440766632557, 0.03131486475467682, 0.03273439407348633, 0.06774082779884338, 0.1731010377407074, 0.09275981038808823, 0.5477339029312134, 0.0], [0.02145911566913128, 0.046526145190000534, 0.014734850265085697, 0.026213468983769417, 0.04904777929186821, 0.08567024767398834, 0.13810616731643677, 0.03392839804291725, 0.5843138694763184, 0.0], [0.019245177507400513, 0.01515401341021061, 0.027409562841057777, 0.0068243746645748615, 0.07997982203960419, 0.0921224057674408, 0.04510754346847534, 0.04373685643076897, 0.670420229434967, 0.0], [0.04381020739674568, 0.06711422652006149, 0.07609888166189194, 0.021496189758181572, 0.05042967572808266, 0.15614424645900726, 0.11071597784757614, 0.14296749234199524, 0.3312230408191681, 0.0], [0.04100082442164421, 0.030313873663544655, 0.032653506845235825, 0.0695231482386589, 0.12672685086727142, 0.12515434622764587, 0.08855390548706055, 0.05835743993520737, 0.4277162253856659, 0.0], [0.14112897217273712, 0.06592341512441635, 0.06986766308546066, 0.06311382353305817, 0.12678426504135132, 0.04950721934437752, 0.08025017380714417, 0.03467738255858421, 0.36874714493751526, 0.0], [0.02841436117887497, 0.022568009793758392, 0.014519155025482178, 0.019271234050393105, 0.018120555207133293, 0.036434635519981384, 0.014109926298260689, 0.24622198939323425, 0.6003400683403015, 0.0], [0.05730762332677841, 0.07724729180335999, 0.030861826613545418, 0.04063780978322029, 0.08539344370365143, 0.029541905969381332, 0.02964094467461109, 0.028206804767251015, 0.6211622953414917, 0.0], [0.20915710926055908, 0.193747878074646, 0.11181499063968658, 0.07680925726890564, 0.04479793831706047, 0.03787367418408394, 0.04819086939096451, 0.11330965161323547, 0.1642986238002777, 0.0]], [[0.038908280432224274, 0.07760688662528992, 0.062413811683654785, 0.0023113787174224854, 0.0021746077109128237, 0.015095214359462261, 0.003646473865956068, 0.038165315985679626, 0.759678065776825, 0.0], [0.015742339193820953, 0.029524141922593117, 0.0550379604101181, 0.16926467418670654, 0.035933610051870346, 0.03279981389641762, 0.03188418969511986, 0.5383173227310181, 0.09149592369794846, 0.0], [0.022741766646504402, 0.013864121399819851, 0.06161126494407654, 0.06985131651163101, 0.03954875469207764, 0.02864447981119156, 
0.036658816039562225, 0.05774570629000664, 0.6693336963653564, 0.0], [0.06077639013528824, 0.053226571530103683, 0.05544588342308998, 0.08368532359600067, 0.04779139161109924, 0.028960514813661575, 0.03463221713900566, 0.42419588565826416, 0.21128588914871216, 0.0], [0.03320460394024849, 0.07872876524925232, 0.0791814923286438, 0.008506255224347115, 0.010383618995547295, 0.021636927500367165, 0.009444555267691612, 0.026183925569057465, 0.7327298521995544, 0.0], [0.14095324277877808, 0.17195045948028564, 0.04960065335035324, 0.02801741287112236, 0.02789357118308544, 0.0246508177369833, 0.027228642255067825, 0.008449538610875607, 0.521255612373352, 0.0], [0.01678302139043808, 0.02193976752460003, 0.13912786543369293, 0.05168221518397331, 0.06239692494273186, 0.008615943603217602, 0.037501659244298935, 0.02482585795223713, 0.6371266841888428, 0.0], [0.03396642208099365, 0.07778684049844742, 0.18657010793685913, 0.11281172931194305, 0.019890569150447845, 0.012303605675697327, 0.0494060292840004, 0.11448060721158981, 0.39278414845466614, 0.0], [0.02684134803712368, 0.03310805931687355, 0.163743257522583, 0.014529252424836159, 0.10077258199453354, 0.044357266277074814, 0.04152251034975052, 0.10173188894987106, 0.4733937382698059, 0.0], [0.01862592063844204, 0.022009190171957016, 0.028925148770213127, 0.006837732624262571, 0.006956242956221104, 0.010202805511653423, 0.015325144864618778, 0.11640346795320511, 0.7747144103050232, 0.0]], [[0.0830092504620552, 0.0839436799287796, 0.10106679797172546, 0.11154499650001526, 0.045070260763168335, 0.1284436285495758, 0.1161414161324501, 0.19574469327926636, 0.1350351870059967, 0.0], [0.0006529411766678095, 0.0018492193194106221, 0.018439743667840958, 0.004895282443612814, 0.0036929987836629152, 0.05041775107383728, 0.03271673619747162, 0.4425412714481354, 0.4447941780090332, 0.0], [0.015919672325253487, 0.02172437310218811, 0.013682822696864605, 0.028371846303343773, 0.017258556559681892, 0.014516759663820267, 0.033475372940301895, 0.45419326424598694, 0.40085726976394653, 0.0], [0.006064589135348797, 0.006147248670458794, 0.06902536749839783, 0.011021673679351807, 0.0062199062667787075, 0.17622654139995575, 0.00982236210256815, 0.46262383460998535, 0.25284844636917114, 0.0], [0.018328940495848656, 0.034908927977085114, 0.027539005503058434, 0.04494883120059967, 0.03695090860128403, 0.18224696815013885, 0.04204700142145157, 0.09570277482271194, 0.5173265337944031, 0.0], [0.06838149577379227, 0.025893883779644966, 0.06412170827388763, 0.11039282381534576, 0.12848982214927673, 0.09953469038009644, 0.09056522697210312, 0.12723064422607422, 0.28538966178894043, 0.0], [0.07893572002649307, 0.0734885111451149, 0.06503137946128845, 0.04291535168886185, 0.08502060174942017, 0.04846649244427681, 0.07035838067531586, 0.14812934398651123, 0.38765427470207214, 0.0], [0.007445929106324911, 0.004103729501366615, 0.05411284416913986, 0.006074799690395594, 0.07146289199590683, 0.5494692921638489, 0.05009504780173302, 0.058794084936380386, 0.1984413117170334, 0.0], [0.0037151367869228125, 0.005083263851702213, 0.02171880006790161, 0.01245985459536314, 0.012914983555674553, 0.14437292516231537, 0.026943473145365715, 0.17420484125614166, 0.5985866785049438, 0.0], [0.02579679898917675, 0.0645768865942955, 0.03225725144147873, 0.044467855244874954, 0.04297630116343498, 0.06060377135872841, 0.030930038541555405, 0.03278812766075134, 0.6656030416488647, 0.0]], [[0.13460709154605865, 0.15298102796077728, 0.06546170264482498, 0.14220191538333893, 0.11837887763977051, 
0.09888823330402374, 0.10630416870117188, 0.08867054432630539, 0.09250646829605103, 0.0], [0.9316296577453613, 0.016095036640763283, 0.0020372711587697268, 0.0019596514757722616, 2.8437656510504894e-05, 6.708989531034604e-05, 0.0004955903859809041, 3.0113247703411616e-05, 0.047657083719968796, 0.0], [0.043201129883527756, 0.9419298768043518, 0.0003410913050174713, 0.003313146298751235, 7.506452675443143e-06, 1.9570916265365668e-05, 2.5470235414104536e-05, 2.1080213628010824e-05, 0.011141069233417511, 0.0], [3.7581870856229216e-05, 0.00022979748609941453, 0.9982534646987915, 8.70372386998497e-05, 5.87535805607331e-06, 2.5239218302886002e-05, 6.597588708245894e-06, 2.193619138779468e-06, 0.001352491439320147, 0.0], [0.0019612079486250877, 0.011641290038824081, 0.010358362458646297, 0.8346317410469055, 0.00641160923987627, 0.0007435380248352885, 0.0018172020791098475, 7.255822129081935e-05, 0.1323624849319458, 0.0], [4.077299308846705e-05, 0.00016088274423964322, 3.1180113637674367e-06, 5.9685276937671006e-05, 6.661444786004722e-06, 0.0006764131248928607, 5.4107837058836594e-05, 0.9797272086143494, 0.01927126571536064, 0.0], [2.7792530090664513e-06, 1.1777839063142892e-05, 1.0386434951215051e-05, 0.0006807934259995818, 0.00028749846387654543, 0.9563493728637695, 2.4335316993528977e-05, 0.001297356327995658, 0.041335828602313995, 0.0], [0.00033864984288811684, 0.00016234541544690728, 0.00011107163300039247, 7.639558316441253e-05, 9.851753566181287e-05, 0.00046863980242051184, 0.9855522513389587, 0.00012009339843643829, 0.013071970082819462, 0.0], [0.001446103909984231, 0.0026176422834396362, 0.0005430445889942348, 0.5833504796028137, 0.08298782259225845, 0.01277364045381546, 0.008405186235904694, 0.028461067005991936, 0.2794148921966553, 0.0], [8.301706202473724e-07, 1.612889263924444e-06, 3.859615389956161e-06, 0.0015496612759307027, 0.9884966611862183, 0.0003321043332107365, 1.1829011782538146e-05, 3.7258676002238644e-06, 0.00959983840584755, 0.0]], [[0.03624086081981659, 0.008591840974986553, 0.01890810765326023, 0.010947922244668007, 0.5211313366889954, 0.04890615865588188, 0.13394898176193237, 0.08554741740226746, 0.13577744364738464, 0.0], [0.09101090580224991, 0.15663929283618927, 0.2008313536643982, 0.13744188845157623, 0.16349081695079803, 0.01479706447571516, 0.04576689749956131, 0.05515507981181145, 0.1348666250705719, 0.0], [0.10898119956254959, 0.19741322100162506, 0.12774543464183807, 0.07097428292036057, 0.033309608697891235, 0.016726871952414513, 0.019306309521198273, 0.09155051410198212, 0.3339925706386566, 0.0], [0.051247891038656235, 0.06952031701803207, 0.3243081271648407, 0.04820195212960243, 0.05462171137332916, 0.04280935227870941, 0.03801479935646057, 0.07710513472557068, 0.2941707372665405, 0.0], [0.22540897130966187, 0.04426601901650429, 0.13483746349811554, 0.09052211791276932, 0.036632657051086426, 0.06078784167766571, 0.09962243586778641, 0.04597063735127449, 0.2619517743587494, 0.0], [0.08315062522888184, 0.10649015009403229, 0.15254046022891998, 0.0728936716914177, 0.10388997197151184, 0.04998103529214859, 0.0675109326839447, 0.17524446547031403, 0.18829864263534546, 0.0], [0.09407053142786026, 0.04335644096136093, 0.04757237061858177, 0.023308007046580315, 0.14141318202018738, 0.017728488892316818, 0.02331509254872799, 0.07266414165496826, 0.5365718007087708, 0.0], [0.08477651327848434, 0.026448125019669533, 0.013684368692338467, 0.1331702470779419, 0.16824185848236084, 0.007634431589394808, 0.025501158088445663, 0.035930439829826355, 0.5046128630638123, 0.0], 
[0.03296202793717384, 0.01823815330862999, 0.025750160217285156, 0.08325016498565674, 0.1596710979938507, 0.010502922348678112, 0.01792057603597641, 0.05097610503435135, 0.6007286906242371, 0.0], [0.04370357468724251, 0.02250431850552559, 0.016271278262138367, 0.019842427223920822, 0.12028838694095612, 0.03933797404170036, 0.043740611523389816, 0.08045370131731033, 0.6138576865196228, 0.0]], [[0.1783323585987091, 0.3813028037548065, 0.2072289139032364, 0.06766574084758759, 0.053963109850883484, 0.030795719474554062, 0.023536406457424164, 0.03921645134687424, 0.01795845478773117, 0.0], [0.8837893009185791, 0.07202983647584915, 0.03646722435951233, 0.0004511935112532228, 0.0007272462244145572, 0.0008432198665104806, 0.0031319037079811096, 0.0004143840924371034, 0.0021455709356814623, 0.0], [0.3973897695541382, 0.14911939203739166, 0.3486334979534149, 0.012645252980291843, 0.00675938231870532, 0.00483374297618866, 0.010028100572526455, 0.012036854401230812, 0.058554183691740036, 0.0], [0.005409032106399536, 0.005906772334128618, 0.13379110395908356, 0.15247586369514465, 0.06559418141841888, 0.15356750786304474, 0.04085409641265869, 0.029147597029805183, 0.41325387358665466, 0.0], [0.0013326199259608984, 0.0014979635598137975, 0.011986319907009602, 0.7730216383934021, 0.06901827454566956, 0.05895080044865608, 0.016383536159992218, 0.015771687030792236, 0.052037257701158524, 0.0], [0.0012038598069921136, 0.0033955213148146868, 0.025528373196721077, 0.03136582672595978, 0.10901585966348648, 0.3851255178451538, 0.0182026457041502, 0.13982580602169037, 0.2863365411758423, 0.0], [0.008065885864198208, 0.004362722393125296, 0.06363680213689804, 0.023311397060751915, 0.06106392294168472, 0.1357712298631668, 0.03965916484594345, 0.06073852628469467, 0.6033903956413269, 0.0], [0.0003142715140711516, 0.0005578870768658817, 0.0015481057344004512, 0.0887022390961647, 0.06383900344371796, 0.2639910578727722, 0.049384135752916336, 0.12241825461387634, 0.40924492478370667, 0.0], [0.0003916181158274412, 0.0003099135938100517, 0.0024421222042292356, 0.016801349818706512, 0.18835966289043427, 0.025843605399131775, 0.08458039909601212, 0.20884136855602264, 0.4724300503730774, 0.0], [5.865378989255987e-05, 7.253760122694075e-05, 0.0007906460668891668, 0.025103986263275146, 0.0753612071275711, 0.04038592055439949, 0.011871143244206905, 0.05808362737298012, 0.7882723212242126, 0.0]], [[0.01597539149224758, 0.027860743924975395, 0.08824922889471054, 0.011547067202627659, 0.02896539680659771, 0.03845160827040672, 0.011409634724259377, 0.043791815638542175, 0.7337491512298584, 0.0], [0.0371943861246109, 0.014876782894134521, 0.02253115549683571, 0.10164438933134079, 0.029471710324287415, 0.040005166083574295, 0.020577073097229004, 0.07326765358448029, 0.6604316830635071, 0.0], [0.06676606088876724, 0.1320837438106537, 0.02368331328034401, 0.09289334714412689, 0.06407851725816727, 0.007657648529857397, 0.014540987089276314, 0.018603011965751648, 0.5796933174133301, 0.0], [0.029496638104319572, 0.013616771437227726, 0.030488401651382446, 0.021259615197777748, 0.13049498200416565, 0.06418323516845703, 0.050123173743486404, 0.1609034240245819, 0.4994336664676666, 0.0], [0.010230573825538158, 0.015954630449414253, 0.007779641076922417, 0.018425902351737022, 0.021085364744067192, 0.0588817335665226, 0.013979516923427582, 0.0252523310482502, 0.828410267829895, 0.0], [0.02648993395268917, 0.0214377511292696, 0.03494586795568466, 0.05471349507570267, 0.09140968322753906, 0.04952282831072807, 0.05564551055431366, 
0.11169540882110596, 0.5541394948959351, 0.0], [0.03231878578662872, 0.018621357157826424, 0.05183127149939537, 0.03979233279824257, 0.13804322481155396, 0.03567919135093689, 0.047386858612298965, 0.13114488124847412, 0.505182147026062, 0.0], [0.04592716693878174, 0.010993612930178642, 0.01772226020693779, 0.05332585424184799, 0.15264220535755157, 0.22139224410057068, 0.048004403710365295, 0.12396018952131271, 0.3260320723056793, 0.0], [0.03168570622801781, 0.026294516399502754, 0.025469979271292686, 0.03026771917939186, 0.058515094220638275, 0.13361068069934845, 0.026259208098053932, 0.0612059161067009, 0.6066910624504089, 0.0], [0.07492455840110779, 0.06428299844264984, 0.07022737711668015, 0.0507473424077034, 0.0447908453643322, 0.060839906334877014, 0.14463475346565247, 0.054812539368867874, 0.4347396492958069, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9642227292060852, 0.035777393728494644, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9523521065711975, 0.027811188250780106, 0.019836684688925743, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.849480152130127, 0.03536543622612953, 0.019422976300120354, 0.09573143720626831, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.741925060749054, 0.05566684901714325, 0.024736514315009117, 0.08595114946365356, 0.09172046929597855, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6503966450691223, 0.0582728385925293, 0.0236701387912035, 0.0691222995519638, 0.0758395791053772, 0.12269847840070724, 0.0, 0.0, 0.0, 0.0], [0.4914315342903137, 0.11739180237054825, 0.02309434488415718, 0.07889512181282043, 0.05101678892970085, 0.12367808818817139, 0.11449223756790161, 0.0, 0.0, 0.0], [0.4262734055519104, 0.07066749036312103, 0.024391667917370796, 0.04879573732614517, 0.051445234566926956, 0.1276569813489914, 0.11843930184841156, 0.13233007490634918, 0.0, 0.0], [0.589878499507904, 0.026613032445311546, 0.020459800958633423, 0.028271155431866646, 0.03679497539997101, 0.07860217243432999, 0.08500825613737106, 0.09285575151443481, 0.04151623696088791, 0.0], [0.2743179202079773, 0.06089583784341812, 0.03565794974565506, 0.044920988380908966, 0.03933599591255188, 0.18495218455791473, 0.09192009270191193, 0.13160176575183868, 0.04121606424450874, 0.09518115967512131]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9842625260353088, 0.015737490728497505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8382691144943237, 0.11647694557905197, 0.04525385797023773, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4638526439666748, 0.1585947573184967, 0.3189436197280884, 0.0586090050637722, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2375488132238388, 0.07284080982208252, 0.20766110718250275, 0.3110494017601013, 0.1708998829126358, 0.0, 0.0, 0.0, 0.0, 0.0], [0.20615516602993011, 0.03705071657896042, 0.05929475650191307, 0.08692343533039093, 0.5564662218093872, 0.05410974845290184, 0.0, 0.0, 0.0, 0.0], [0.31913095712661743, 0.011343744583427906, 0.01675090566277504, 0.013238506391644478, 0.06746862828731537, 0.3789318799972534, 0.19313538074493408, 0.0, 0.0, 0.0], [0.4113273322582245, 0.003934106323868036, 0.003564919577911496, 0.005882325116544962, 0.018547017127275467, 0.18534934520721436, 0.3216978907585144, 0.04969710111618042, 0.0, 0.0], [0.07648876309394836, 0.0013769177021458745, 0.001890459912829101, 0.006597061175853014, 0.007926206104457378, 0.013261871412396431, 0.15683594346046448, 0.7190074324607849, 0.016615279018878937, 0.0], [0.08104224503040314, 0.00045554721145890653, 0.00038501128437928855, 0.0009405335295014083, 0.005597654264420271, 0.0034990713465958834, 0.009850292466580868, 
0.0463707260787487, 0.7366765141487122, 0.11518235504627228]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9800853133201599, 0.019914645701646805, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9159882068634033, 0.02969631738960743, 0.05431551858782768, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6467475295066833, 0.08892705291509628, 0.19796258211135864, 0.06636285036802292, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9833061099052429, 0.004010406322777271, 0.004914217162877321, 0.0015858567785471678, 0.006183335091918707, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9524497389793396, 0.0022862900514155626, 0.000848656112793833, 0.00408557103946805, 0.028177350759506226, 0.012152665294706821, 0.0, 0.0, 0.0, 0.0], [0.1907505989074707, 0.026542214676737785, 0.01945381611585617, 0.029287727549672127, 0.057166602462530136, 0.11766232550144196, 0.5591367483139038, 0.0, 0.0, 0.0], [0.4022328555583954, 0.017193131148815155, 0.01565318927168846, 0.01915702596306801, 0.01739031821489334, 0.16459040343761444, 0.18205313384532928, 0.18172988295555115, 0.0, 0.0], [0.9652498960494995, 0.0010482663055881858, 0.0012260396033525467, 0.0009098293376155198, 0.0013901795027777553, 0.0028189055155962706, 0.007343438919633627, 0.018731823191046715, 0.0012814495712518692, 0.0], [0.18471455574035645, 0.018054824322462082, 0.08812589198350906, 0.00762907462194562, 0.018057269975543022, 0.05247756093740463, 0.03497685119509697, 0.5025416612625122, 0.052323222160339355, 0.04109897091984749]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9911633133888245, 0.008836665190756321, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9641951322555542, 0.023474374786019325, 0.012330451980233192, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6152319312095642, 0.28041696548461914, 0.04906271770596504, 0.05528838559985161, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6057276725769043, 0.1235719844698906, 0.06170117110013962, 0.11151555925607681, 0.0974835753440857, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6386814713478088, 0.07927443087100983, 0.06004401296377182, 0.06398510187864304, 0.06341437995433807, 0.09460049122571945, 0.0, 0.0, 0.0, 0.0], [0.13321073353290558, 0.0565485954284668, 0.20425985753536224, 0.10307760536670685, 0.17957380414009094, 0.26328328251838684, 0.06004612147808075, 0.0, 0.0, 0.0], [0.19694660604000092, 0.027736904099583626, 0.05790374055504799, 0.10621010512113571, 0.15510229766368866, 0.2214440256357193, 0.18680275976657867, 0.04785352945327759, 0.0, 0.0], [0.08537944406270981, 0.033881768584251404, 0.03968465328216553, 0.08240006119012833, 0.15350975096225739, 0.23219235241413116, 0.22240297496318817, 0.11620921641588211, 0.034339725971221924, 0.0], [0.06051333248615265, 0.012086840346455574, 0.028373999521136284, 0.07542525231838226, 0.10199770331382751, 0.15039192140102386, 0.20426926016807556, 0.16016273200511932, 0.06537677347660065, 0.14140206575393677]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5400503277778625, 0.4599496126174927, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04321815073490143, 0.9357689023017883, 0.02101275697350502, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.48035699129104614, 0.12913382053375244, 0.27151036262512207, 0.11899882555007935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6920371055603027, 0.019891848787665367, 0.1885785609483719, 0.06273186951875687, 0.036760613322257996, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8527964949607849, 0.08059625327587128, 0.0037265238352119923, 0.008582950569689274, 0.042790722101926804, 0.01150701567530632, 0.0, 0.0, 0.0, 0.0], [0.900881826877594, 0.012710069306194782, 0.000794807099737227, 
0.00424413476139307, 0.02110898308455944, 0.01962616853415966, 0.04063420742750168, 0.0, 0.0, 0.0], [0.713775098323822, 0.003081131726503372, 0.000918463512789458, 0.009338468313217163, 0.013423318043351173, 0.019161174073815346, 0.10174864530563354, 0.13855360448360443, 0.0, 0.0], [0.4800099730491638, 0.0009553784620948136, 0.00013007478264626116, 0.020002998411655426, 0.0032414987217634916, 0.002101779682561755, 0.028948260471224785, 0.46123453974723816, 0.0033754503820091486, 0.0], [0.7501513361930847, 0.019767694175243378, 0.0020619838032871485, 0.0038300605956465006, 0.0023455689661204815, 0.023803891614079475, 0.011456847190856934, 0.045016106218099594, 0.08813992142677307, 0.05342674255371094]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.03494315221905708, 0.965056836605072, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.020348060876131058, 0.8944171071052551, 0.08523476868867874, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0015979396412149072, 0.6347042918205261, 0.09008561074733734, 0.27361196279525757, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.01025437843054533, 0.17247439920902252, 0.3664330542087555, 0.4087805449962616, 0.04205762594938278, 0.0, 0.0, 0.0, 0.0, 0.0], [0.012186901643872261, 0.3028968572616577, 0.12117700278759003, 0.3522109389305115, 0.06255244463682175, 0.14897578954696655, 0.0, 0.0, 0.0, 0.0], [0.010822800919413567, 0.2333739995956421, 0.11113002151250839, 0.15861180424690247, 0.11286703497171402, 0.2766783833503723, 0.0965159684419632, 0.0, 0.0, 0.0], [0.00965114776045084, 0.19982098042964935, 0.054301097989082336, 0.13056904077529907, 0.03828747197985649, 0.4827912747859955, 0.05511533096432686, 0.029463520273566246, 0.0, 0.0], [0.014548483304679394, 0.07520423084497452, 0.1090526208281517, 0.14237697422504425, 0.030428709462285042, 0.5021095275878906, 0.026151562109589577, 0.04390878602862358, 0.05621904134750366, 0.0], [0.000422637298470363, 0.17123113572597504, 0.04347287863492966, 0.10408183932304382, 0.013075248338282108, 0.5476951003074646, 0.020964276045560837, 0.019243689253926277, 0.0612923838198185, 0.018520813435316086]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9947329163551331, 0.005267037078738213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7284466028213501, 0.21829284727573395, 0.05326057970523834, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7024527192115784, 0.0454108789563179, 0.10381712764501572, 0.14831924438476562, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.2374107390642166, 0.04589728266000748, 0.2683154046535492, 0.3902822434902191, 0.0580943301320076, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7228419780731201, 0.007619804237037897, 0.013993922621011734, 0.04429992660880089, 0.020430808886885643, 0.19081364572048187, 0.0, 0.0, 0.0, 0.0], [0.4783930778503418, 0.005506142508238554, 0.008406496606767178, 0.012424511834979057, 0.04335693642497063, 0.17542317509651184, 0.27648961544036865, 0.0, 0.0, 0.0], [0.056768160313367844, 0.001066300319507718, 0.0015203694347292185, 0.004650356248021126, 0.004999558907002211, 0.17368057370185852, 0.7387632131576538, 0.018551528453826904, 0.0, 0.0], [0.14709600806236267, 0.007261540275067091, 0.001291902968659997, 0.012605146504938602, 0.005232691299170256, 0.08098926395177841, 0.5304067134857178, 0.207069993019104, 0.00804678164422512, 0.0], [0.15080930292606354, 0.014301316812634468, 0.002821019385010004, 0.02008463814854622, 0.004475536290556192, 0.05297520384192467, 0.27036672830581665, 0.407105028629303, 0.007729486562311649, 0.06933178007602692]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
[0.9945669174194336, 0.005433134268969297, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9554939270019531, 0.02177131362259388, 0.0227347444742918, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.19059398770332336, 0.7459079623222351, 0.05105874687433243, 0.012439398095011711, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.062025006860494614, 0.7277394533157349, 0.13110491633415222, 0.028790757060050964, 0.050339892506599426, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7678350806236267, 0.007377212401479483, 0.020054306834936142, 0.11815592646598816, 0.07254840433597565, 0.014029012061655521, 0.0, 0.0, 0.0, 0.0], [0.8187481760978699, 0.009394909255206585, 0.015446240082383156, 0.012167787179350853, 0.10175905376672745, 0.02721206098794937, 0.01527167297899723, 0.0, 0.0, 0.0], [0.7012083530426025, 0.12151088565587997, 0.03808446228504181, 0.01883355714380741, 0.0837249755859375, 0.006598148960620165, 0.006499246694147587, 0.023540453985333443, 0.0, 0.0], [0.5152325630187988, 0.054241329431533813, 0.17093418538570404, 0.020541386678814888, 0.17657014727592468, 0.012641755864024162, 0.01802964322268963, 0.023539982736110687, 0.008269038051366806, 0.0], [0.9131196141242981, 0.0010915634920820594, 0.006193474866449833, 0.006082434672862291, 0.03542511910200119, 0.006826554890722036, 0.0028478680178523064, 0.004068343434482813, 0.014553201384842396, 0.009791722521185875]]], [[[0.16448259353637695, 0.17219680547714233, 0.09987642616033554, 0.09012344479560852, 0.06534503400325775, 0.08456553518772125, 0.06690192222595215, 0.08019057661294937, 0.17631761729717255, 0.0], [0.49537378549575806, 0.03979916125535965, 0.09498286247253418, 0.0017974335933104157, 0.028368383646011353, 0.0015277893980965018, 0.014851069077849388, 0.0003722719266079366, 0.3229270279407501, 0.0], [0.0031106590759009123, 0.8318147659301758, 0.0329316072165966, 0.00014872441533952951, 0.000739947019610554, 0.0009879706194624305, 0.0012947155628353357, 0.00040531408740207553, 0.128566175699234, 0.0], [3.727031798916869e-05, 0.00033458907273598015, 0.9051278829574585, 0.014809494838118553, 0.0013665216974914074, 0.0009820980485528708, 0.0004274636448826641, 0.0006300737150013447, 0.07628484070301056, 0.0], [2.789895370369777e-05, 7.413508137688041e-05, 0.00011113573418697342, 0.9593441486358643, 0.023210706189274788, 0.00043970797560177743, 0.00011651179374894127, 0.0001221746060764417, 0.016553271561861038, 0.0], [5.518151283467887e-06, 4.040239218738861e-06, 4.706911568064243e-06, 0.0001475349417887628, 0.0011833186727017164, 0.007331210654228926, 0.0003812467912212014, 0.7072276473045349, 0.28371480107307434, 0.0], [2.1062598989374237e-06, 1.0153020184588968e-06, 9.153064297606761e-07, 2.3557351596537046e-05, 0.0019158869981765747, 0.9726926684379578, 0.0003360892878845334, 0.008161749690771103, 0.01686590164899826, 0.0], [1.876308124337811e-05, 3.1762643629917875e-05, 7.612020908709383e-06, 4.369785983726615e-06, 0.00035698129795491695, 0.006292039528489113, 0.9372867941856384, 0.0028216273058205843, 0.0531802624464035, 0.0], [0.00017082327394746244, 0.0008267413941211998, 0.0010992212919518352, 0.016357675194740295, 0.03317699581384659, 0.013446258381009102, 0.022417983040213585, 0.0993492603302002, 0.813154935836792, 0.0], [2.095436911986326e-06, 1.0510404990782263e-06, 8.745904779061675e-06, 9.465758921578526e-05, 0.9096792936325073, 0.004888555034995079, 0.00019891942793037742, 0.00012723646068479866, 0.08499950170516968, 0.0]], [[0.09510962665081024, 0.13984361290931702, 0.01835908181965351, 0.05623754486441612, 0.05484192445874214, 0.02751241996884346, 
0.023350151255726814, 0.02046714909374714, 0.5642784833908081, 0.0], [0.32246580719947815, 0.12212380021810532, 0.0033711090218275785, 0.41883695125579834, 0.0010050723794847727, 0.00026374190929345787, 0.00840060692280531, 0.0003199145139660686, 0.12321317940950394, 0.0], [0.1343918889760971, 0.42756012082099915, 0.03016146458685398, 0.27197346091270447, 0.0008738918695598841, 0.00041738885920494795, 0.0011337834876030684, 0.0017680631717666984, 0.13172008097171783, 0.0], [4.970032023265958e-05, 0.0002945268643088639, 0.9929893612861633, 0.006102537736296654, 1.304412307945313e-06, 7.552243459940655e-06, 2.0433815279830014e-06, 1.4308750905911438e-05, 0.0005390164442360401, 0.0], [0.0006735534407198429, 0.0037932321429252625, 0.014864870347082615, 0.9520841240882874, 0.0031083461362868547, 0.0014454165939241648, 0.000881638377904892, 0.00042032121564261615, 0.02272843010723591, 0.0], [1.054488166118972e-06, 5.819076250190847e-06, 3.686256491164386e-07, 5.7184315664926544e-05, 1.600286668690387e-05, 0.0002979082928504795, 5.8259040088159963e-05, 0.997514009475708, 0.0020495890639722347, 0.0], [1.2081607110303594e-06, 1.8248301785206422e-06, 3.5412674037615943e-07, 0.00017610432405490428, 0.0004308871575631201, 0.9919483065605164, 0.001251595327630639, 0.004008213523775339, 0.002181792864575982, 0.0], [1.3394396773946937e-06, 1.858925656961219e-06, 8.99223309147601e-08, 5.498410246218555e-06, 4.1167979361489415e-05, 0.003499603597447276, 0.9961592555046082, 8.322765097545926e-06, 0.0002831367892213166, 0.0], [0.0011697824811562896, 0.00207342766225338, 0.0001985222043003887, 0.24218614399433136, 0.2580603361129761, 0.03422079235315323, 0.3017951250076294, 0.0700761154294014, 0.09021952003240585, 0.0], [4.897859540164973e-08, 1.9182496657776937e-07, 1.6890984966266842e-07, 0.00012898082786705345, 0.9986647963523865, 0.0003688811557367444, 8.465539576718584e-05, 1.2611121746886056e-05, 0.0007397857843898237, 0.0]], [[0.008738831616938114, 0.010689073242247105, 0.010104849003255367, 0.025418052449822426, 0.008787600323557854, 0.018541773781180382, 0.01414045225828886, 0.009587875567376614, 0.8939914107322693, 0.0], [0.050771377980709076, 0.08173098415136337, 0.03076810948550701, 0.6816214919090271, 0.04326915368437767, 0.0030209666583687067, 0.006032166071236134, 0.007633579429239035, 0.09515213221311569, 0.0], [0.04749365150928497, 0.07148067653179169, 0.018722670152783394, 0.5845115184783936, 0.03816590458154678, 0.003933309111744165, 0.006466464139521122, 0.021205652505159378, 0.20802012085914612, 0.0], [0.021572547033429146, 0.11727327853441238, 0.03622674569487572, 0.4274545907974243, 0.05620160698890686, 0.01161592174321413, 0.010393376462161541, 0.014363090507686138, 0.30489882826805115, 0.0], [0.015270093455910683, 0.10013995319604874, 0.006727923639118671, 0.19538360834121704, 0.1119888573884964, 0.027630485594272614, 0.0700199231505394, 0.01868581771850586, 0.4541531801223755, 0.0], [0.00540963327512145, 0.07916348427534103, 0.01957465149462223, 0.49324244260787964, 0.10871188342571259, 0.02422497235238552, 0.008650544099509716, 0.16292543709278107, 0.0980970561504364, 0.0], [0.027941647917032242, 0.005471521522849798, 0.006384703796356916, 0.03924928605556488, 0.22657036781311035, 0.21837352216243744, 0.3372570872306824, 0.05897291377186775, 0.07977905124425888, 0.0], [0.009049936197698116, 0.005020579323172569, 0.014692768454551697, 0.15799382328987122, 0.4401932656764984, 0.1766415536403656, 0.03136269003152847, 0.12063619494438171, 0.044409021735191345, 0.0], 
[0.0007816475699655712, 0.0003147682291455567, 0.0032215022947639227, 0.4467180669307709, 0.3918246924877167, 0.00227341428399086, 0.004370422102510929, 0.14414219558238983, 0.006353371310979128, 0.0], [0.0005489268223755062, 0.016601460054516792, 0.01341363787651062, 0.2753817141056061, 0.13981539011001587, 0.04711242765188217, 0.08167178928852081, 0.11951272189617157, 0.30594193935394287, 0.0]], [[0.11438923329114914, 0.12380287796258926, 0.23573537170886993, 0.19010169804096222, 0.15611350536346436, 0.031749427318573, 0.02482231892645359, 0.05017237365245819, 0.07311322540044785, 0.0], [0.002549531403928995, 0.03178577870130539, 0.17347589135169983, 0.2232668697834015, 0.49775105714797974, 0.018238944932818413, 0.005651220679283142, 0.03368452191352844, 0.013595964759588242, 0.0], [0.0032994491048157215, 0.026504727080464363, 0.41210347414016724, 0.24245016276836395, 0.18897436559200287, 0.012874660082161427, 0.006452939473092556, 0.10089367628097534, 0.00644671730697155, 0.0], [0.002998506650328636, 0.048583757132291794, 0.28224417567253113, 0.0846971943974495, 0.013445784337818623, 0.02188579924404621, 0.017656570300459862, 0.5155076384544373, 0.012980557046830654, 0.0], [0.004188622813671827, 0.028234833851456642, 0.022820167243480682, 0.058492597192525864, 0.19205521047115326, 0.08343320339918137, 0.07119973003864288, 0.4843534827232361, 0.0552222914993763, 0.0], [0.0038351663388311863, 0.015353971160948277, 0.01755588687956333, 0.06245748698711395, 0.1218588799238205, 0.07207991182804108, 0.02867230959236622, 0.5455195903778076, 0.13266700506210327, 0.0], [0.004144841339439154, 0.0048835063353180885, 0.0035110898315906525, 0.06276324391365051, 0.04069552943110466, 0.3603023290634155, 0.1472603678703308, 0.2116946280002594, 0.16474448144435883, 0.0], [0.024624889716506004, 0.016127971932291985, 0.0073340879753232, 0.023849278688430786, 0.042295511811971664, 0.5078635215759277, 0.2884303331375122, 0.011452756822109222, 0.07802165299654007, 0.0], [0.00880166981369257, 0.002673782641068101, 0.001370548619888723, 0.0061265453696250916, 0.02490534819662571, 0.2073771357536316, 0.3818575143814087, 0.1663341522216797, 0.20055335760116577, 0.0], [0.012253189459443092, 0.02221212349832058, 0.002282155444845557, 0.10455729067325592, 0.4111727774143219, 0.08308815956115723, 0.045707643032073975, 0.03711223974823952, 0.2816142141819, 0.0]], [[0.5821239352226257, 0.14550858736038208, 0.031251534819602966, 0.030760297551751137, 0.02147754468023777, 0.013665237464010715, 0.009087015874683857, 0.01557532325387001, 0.15055041015148163, 0.0], [0.12817564606666565, 0.33913177251815796, 0.07241326570510864, 0.41213902831077576, 0.0326012559235096, 0.0031606394331902266, 0.0006341012776829302, 0.007317711599171162, 0.0044263736344873905, 0.0], [0.08047150820493698, 0.06199575960636139, 0.5555182099342346, 0.2858560383319855, 0.008700164034962654, 0.003758196486160159, 0.001155794132500887, 0.0007424709619954228, 0.0018020549323409796, 0.0], [0.010044030845165253, 0.018482256680727005, 0.6269924640655518, 0.32439544796943665, 0.01023165788501501, 0.007641270756721497, 0.0008933563949540257, 0.0010311403311789036, 0.00028844154439866543, 0.0], [0.0007911038701422513, 0.0008549468475393951, 0.015090622939169407, 0.8270009160041809, 0.11969847232103348, 0.032614268362522125, 0.0024233118165284395, 0.0011481117689982057, 0.0003779604157898575, 0.0], [0.017773190513253212, 0.008623103611171246, 0.0020072387997061014, 0.08177924901247025, 0.13816505670547485, 0.6801413297653198, 0.02186667174100876, 
0.024107687175273895, 0.025536518543958664, 0.0], [0.000318053673254326, 5.6540200603194535e-05, 1.071194674295839e-05, 0.0009494975674897432, 0.0034297029487788677, 0.032661326229572296, 0.9588278532028198, 0.003185966284945607, 0.0005602877936325967, 0.0], [0.0017862697131931782, 0.0002347631088923663, 2.1297884813975543e-05, 0.0004797980946023017, 0.0018031852087005973, 0.024247879162430763, 0.45456385612487793, 0.5099425911903381, 0.006920217536389828, 0.0], [0.0006541880429722369, 0.0009561541373841465, 7.73017163737677e-05, 0.00942671112716198, 0.04198922589421272, 0.04971348121762276, 0.32961171865463257, 0.4513629972934723, 0.11620841920375824, 0.0], [0.017209511250257492, 0.004475452937185764, 3.128392927465029e-05, 0.00047953161993063986, 0.00448839133605361, 0.03360708802938461, 0.11509764194488525, 0.5398797988891602, 0.2847314178943634, 0.0]], [[0.20143046975135803, 0.41116827726364136, 0.09215858578681946, 0.10672477632761002, 0.06125285103917122, 0.017610367387533188, 0.01457523088902235, 0.02514597773551941, 0.06993352621793747, 0.0], [0.026864346116781235, 0.037146128714084625, 0.08411292731761932, 0.02904331497848034, 0.0955604761838913, 0.05886658653616905, 0.08584483712911606, 0.4076027572154999, 0.17495866119861603, 0.0], [0.073190838098526, 0.07998740673065186, 0.05594569817185402, 0.03243006020784378, 0.10037493705749512, 0.13878461718559265, 0.15250830352306366, 0.25721096992492676, 0.10956726223230362, 0.0], [0.0438627265393734, 0.04628896340727806, 0.4038660526275635, 0.005475929472595453, 0.03436022624373436, 0.11165640503168106, 0.02260321006178856, 0.28233063220977783, 0.04955587536096573, 0.0], [0.2377929538488388, 0.08882997930049896, 0.12371516227722168, 0.08651548624038696, 0.015416872687637806, 0.04211122542619705, 0.16403844952583313, 0.11833071708679199, 0.12324906885623932, 0.0], [0.023254310712218285, 0.0034057339653372765, 0.036038532853126526, 0.009054891765117645, 0.0329253226518631, 0.05284882336854935, 0.15671837329864502, 0.6067742109298706, 0.07897992432117462, 0.0], [0.015282228589057922, 0.008608018048107624, 0.08339564502239227, 0.032651614397764206, 0.21303850412368774, 0.22661514580249786, 0.21832069754600525, 0.1323210895061493, 0.06976725161075592, 0.0], [0.019424932077527046, 0.008587736636400223, 0.014951083809137344, 0.01159222237765789, 0.2890152633190155, 0.2543036639690399, 0.2561561167240143, 0.0882645845413208, 0.05770434811711311, 0.0], [0.020595766603946686, 0.015824340283870697, 0.008689227513968945, 0.03796549141407013, 0.3004503846168518, 0.16956602036952972, 0.10506420582532883, 0.05004280060529709, 0.2918018400669098, 0.0], [0.18154361844062805, 0.0977708026766777, 0.20556335151195526, 0.05251142755150795, 0.13640889525413513, 0.06629360467195511, 0.06030320003628731, 0.08172836154699326, 0.11787670105695724, 0.0]], [[0.07673492282629013, 0.03585591912269592, 0.0804624855518341, 0.05707075819373131, 0.16190174221992493, 0.1288135051727295, 0.1235240250825882, 0.06807681918144226, 0.2675597667694092, 0.0], [0.005086997989565134, 0.014635499566793442, 0.013461720198392868, 0.6349815726280212, 0.14714521169662476, 0.015218403190374374, 0.01605474203824997, 0.018318237736821175, 0.1350976973772049, 0.0], [0.03515003249049187, 0.049813926219940186, 0.04029693454504013, 0.4151618778705597, 0.24873343110084534, 0.009437951259315014, 0.008381601423025131, 0.020832136273384094, 0.17219208180904388, 0.0], [0.06722414493560791, 0.13528113067150116, 0.06224377825856209, 0.18915168941020966, 0.17580503225326538, 0.07229694724082947, 
0.012536793015897274, 0.09137610346078873, 0.19408434629440308, 0.0], [0.09099949151277542, 0.09548961371183395, 0.04829362779855728, 0.1739831268787384, 0.06667517125606537, 0.05157051607966423, 0.05465595796704292, 0.06177656352519989, 0.3565560579299927, 0.0], [0.09822985529899597, 0.05441536381840706, 0.039150238037109375, 0.06369251012802124, 0.05292840674519539, 0.050128646194934845, 0.044398434460163116, 0.04042055085301399, 0.5566359758377075, 0.0], [0.012019939720630646, 0.0076602306216955185, 0.02716030552983284, 0.03984800726175308, 0.09776019304990768, 0.05175628885626793, 0.08536165207624435, 0.0944109782576561, 0.5840223431587219, 0.0], [0.036716632544994354, 0.021969007328152657, 0.010507079772651196, 0.012404722161591053, 0.040125522762537, 0.010736462660133839, 0.018730206415057182, 0.030387653037905693, 0.8184227347373962, 0.0], [0.04769879952073097, 0.19333122670650482, 0.02803504839539528, 0.016029207035899162, 0.11119306832551956, 0.03845509514212608, 0.011404097080230713, 0.0836206004023552, 0.4702327847480774, 0.0], [0.05245642364025116, 0.013315027579665184, 0.012056763283908367, 0.004825723823159933, 0.015483945608139038, 0.032884638756513596, 0.027794960886240005, 0.07057305425405502, 0.7706093788146973, 0.0]], [[0.05745904520153999, 0.06613133102655411, 0.11319872736930847, 0.031750500202178955, 0.0641264021396637, 0.07090476900339127, 0.053613319993019104, 0.1108509749174118, 0.4319649040699005, 0.0], [0.12783250212669373, 0.16847258806228638, 0.08126984536647797, 0.10575822740793228, 0.03301985561847687, 0.2111520618200302, 0.10687874257564545, 0.06316707283258438, 0.10244929045438766, 0.0], [0.1413263976573944, 0.38601601123809814, 0.16798537969589233, 0.14611834287643433, 0.015951359644532204, 0.042198505252599716, 0.016183707863092422, 0.06246974319219589, 0.021750787273049355, 0.0], [0.020376645028591156, 0.008152640424668789, 0.04579228535294533, 0.022974595427513123, 0.007921000011265278, 0.11700868606567383, 0.010826223529875278, 0.7216546535491943, 0.04529344290494919, 0.0], [0.04728184640407562, 0.041129130870103836, 0.12847241759300232, 0.038289085030555725, 0.07389654964208603, 0.11478690057992935, 0.04442784935235977, 0.41169247031211853, 0.1000237911939621, 0.0], [0.016180921345949173, 0.005130380857735872, 0.21081623435020447, 0.00797765702009201, 0.04691680520772934, 0.052309177815914154, 0.2947923243045807, 0.34133997559547424, 0.02453651838004589, 0.0], [0.006579844746738672, 0.001606129459105432, 0.206822007894516, 0.017204096540808678, 0.13898226618766785, 0.09910376369953156, 0.4235020577907562, 0.05497713387012482, 0.051222700625658035, 0.0], [0.00896216370165348, 0.0023249718360602856, 0.0226416178047657, 0.05458173528313637, 0.07694459706544876, 0.29436299204826355, 0.36870595812797546, 0.12525610625743866, 0.046219732612371445, 0.0], [0.027829669415950775, 0.014619122259318829, 0.014550572261214256, 0.048137370496988297, 0.15001901984214783, 0.11716196686029434, 0.34159788489341736, 0.1513865739107132, 0.13469791412353516, 0.0], [0.0014273751294240355, 0.003807784290984273, 0.3760293126106262, 0.002253596903756261, 0.11343870311975479, 0.12883712351322174, 0.04242479428648949, 0.28902071714401245, 0.042760640382766724, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9917634725570679, 0.008236419409513474, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.711856484413147, 0.20838035643100739, 0.07976315170526505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6327172517776489, 0.1227935329079628, 0.21565596759319305, 
0.028833283111453056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3586137592792511, 0.038762304931879044, 0.08015953004360199, 0.4233120083808899, 0.09915236383676529, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7095601558685303, 0.03453405201435089, 0.02220289036631584, 0.009008818306028843, 0.201883926987648, 0.022810086607933044, 0.0, 0.0, 0.0, 0.0], [0.5828825831413269, 0.02795644849538803, 0.054448600858449936, 0.01975347101688385, 0.11504233628511429, 0.08908692002296448, 0.11082970350980759, 0.0, 0.0, 0.0], [0.4315364956855774, 0.020537925884127617, 0.01659376546740532, 0.014654956758022308, 0.13063199818134308, 0.27319464087486267, 0.08869150280952454, 0.024158723652362823, 0.0, 0.0], [0.26020547747612, 0.014821716584265232, 0.01224969606846571, 0.0724530965089798, 0.10939211398363113, 0.19152909517288208, 0.10495918244123459, 0.1680101454257965, 0.06637949496507645, 0.0], [0.6687084436416626, 0.04345089942216873, 0.009689688682556152, 0.0018685735994949937, 0.0738394483923912, 0.12735962867736816, 0.025320274755358696, 0.026545442640781403, 0.020931225270032883, 0.0022863498888909817]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9482711553573608, 0.051728855818510056, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8711318373680115, 0.04994085431098938, 0.07892734557390213, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7221198678016663, 0.040686361491680145, 0.06532222777605057, 0.17187155783176422, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5948007702827454, 0.036634139716625214, 0.02264709398150444, 0.035541336983442307, 0.3103766441345215, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6650473475456238, 0.01644211634993553, 0.019737746566534042, 0.0375308021903038, 0.10231779515743256, 0.15892422199249268, 0.0, 0.0, 0.0, 0.0], [0.36675524711608887, 0.04118315875530243, 0.02765432558953762, 0.03228116035461426, 0.11875578761100769, 0.12892943620681763, 0.2844408452510834, 0.0, 0.0, 0.0], [0.19659309089183807, 0.015950728207826614, 0.02453998662531376, 0.039237309247255325, 0.037656329572200775, 0.34599894285202026, 0.23759640753269196, 0.10242718458175659, 0.0, 0.0], [0.3881740868091583, 0.012267092242836952, 0.01897304505109787, 0.013982790522277355, 0.030991200357675552, 0.10819684714078903, 0.20157809555530548, 0.14642520248889923, 0.07941170781850815, 0.0], [0.11410266160964966, 0.03479800745844841, 0.043540675193071365, 0.021180409938097, 0.03197954222559929, 0.2248576581478119, 0.12852585315704346, 0.2089216560125351, 0.039846520870923996, 0.1522471308708191]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.993086576461792, 0.0069133141078054905, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9852874875068665, 0.011381878517568111, 0.0033306065015494823, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4834398031234741, 0.011301998049020767, 0.48758530616760254, 0.017672834917902946, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9851425886154175, 0.0010397545993328094, 0.00470126885920763, 0.0012236799811944366, 0.007892588153481483, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6588926315307617, 0.005506628658622503, 0.021607331931591034, 0.010738613083958626, 0.07747143507003784, 0.2257833182811737, 0.0, 0.0, 0.0, 0.0], [0.13557791709899902, 0.018924091011285782, 0.02187344618141651, 0.015362304635345936, 0.11512601375579834, 0.14739760756492615, 0.5457385182380676, 0.0, 0.0, 0.0], [0.38992705941200256, 0.021535715088248253, 0.005403842777013779, 0.0032997699454426765, 0.4358868896961212, 0.06306594610214233, 0.03204012289643288, 0.04884066432714462, 0.0, 0.0], [0.81478351354599, 0.022238636389374733, 0.0008386021945625544, 0.01924033649265766, 
[Attention-visualization data omitted: nested per-layer, per-head attention weight matrices ("att"; each head a 10×10 matrix, with decoder self-attention rows causally masked so entries above the diagonal are 0.0). Recoverable metadata: "top_text": ["It", "is", "nice", "to", "learn", "new", "things", "today", "!"]; "bot_text": ["<pad>", "Es", "ist", "schön", ",", "heute", "neue", "Dinge", "zu", "lernen", "!"]; the "out_out" block holds the decoder self-attention weights.]
0.0012846259633079171, 0.45206302404403687, 0.029316790401935577, 0.04706822335720062, 0.018986493349075317, 0.4385431706905365, 0.0], [0.0005072542116977274, 0.0011837932979688048, 0.01220926083624363, 8.532252832083032e-05, 0.0018606879748404026, 0.010199862532317638, 0.0016309961210936308, 0.010775143280625343, 0.9615475535392761, 0.0]], [[0.29744189977645874, 0.04770943149924278, 0.09888078272342682, 0.19768767058849335, 0.048243775963783264, 0.12058595567941666, 0.05976371467113495, 0.03847452625632286, 0.09121233224868774, 0.0], [0.04126456007361412, 0.6604095697402954, 0.028894882649183273, 0.20104490220546722, 0.0014044500421732664, 0.0009343607816845179, 0.00244489056058228, 0.007453228812664747, 0.05614929273724556, 0.0], [0.008357543498277664, 0.0022072584833949804, 0.9876156449317932, 8.841200906317681e-05, 1.4883004041621462e-05, 0.00011741811613319442, 2.7020510970032774e-05, 0.00016062626673374325, 0.001411277218721807, 0.0], [0.06216944754123688, 0.48559242486953735, 0.042546145617961884, 0.034007471054792404, 0.047574639320373535, 0.12490913271903992, 0.07922931015491486, 0.013364763930439949, 0.11060672253370285, 0.0], [0.05222959443926811, 0.025416702032089233, 0.02865077182650566, 0.17457211017608643, 0.03144511207938194, 0.3907364010810852, 0.19607771933078766, 0.05274118855595589, 0.04813018813729286, 0.0], [0.0037726862356066704, 0.0031579534988850355, 0.0029440780635923147, 0.0017320584738627076, 0.060473062098026276, 0.761774480342865, 0.1523173600435257, 0.0058823637664318085, 0.007945872843265533, 0.0], [0.0020738786552101374, 0.0012752892216667533, 0.0004058163322042674, 0.020963717252016068, 0.39340031147003174, 0.012434415519237518, 0.4783190190792084, 0.011497312225401402, 0.0796302929520607, 0.0], [5.31752230017446e-05, 1.4492364243778866e-05, 7.312332309084013e-05, 0.0023682843893766403, 0.9866323471069336, 0.0009243910317309201, 0.0011850211303681135, 0.0017622504383325577, 0.0069872229360044, 0.0], [4.074166645295918e-05, 1.823456841520965e-05, 0.0001418270985595882, 0.007263784296810627, 0.9604514241218567, 0.0001852070417953655, 0.00034164052340202034, 0.0018497714772820473, 0.029707150533795357, 0.0], [0.0133396340534091, 0.03136875480413437, 0.6319980621337891, 0.0033722908701747656, 0.04728742688894272, 0.03541773557662964, 0.009523973800241947, 0.03100484237074852, 0.1966874897480011, 0.0]], [[0.03367111459374428, 0.018932543694972992, 0.09506545215845108, 0.04718795791268349, 0.028798582032322884, 0.33658939599990845, 0.02586139366030693, 0.29842811822891235, 0.11546547710895538, 0.0], [0.006203038617968559, 0.0906001627445221, 0.6977949738502502, 0.018352899700403214, 0.06787873804569244, 0.04403599724173546, 0.001631368650123477, 0.024296771734952927, 0.049206044524908066, 0.0], [0.006243667099624872, 0.010453532449901104, 0.7879610657691956, 0.004093538969755173, 0.0008473669877275825, 0.027760563418269157, 0.0003080451278947294, 0.14831961691379547, 0.014012438245117664, 0.0], [0.004387176129966974, 0.023410169407725334, 0.17247918248176575, 0.03958609700202942, 0.023799436166882515, 0.43659475445747375, 0.014754846692085266, 0.2318120151758194, 0.05317622795701027, 0.0], [0.0020952164195477962, 0.0024118656292557716, 0.028229335322976112, 0.007075420115143061, 0.019164882600307465, 0.5397294163703918, 0.034580815583467484, 0.3465326428413391, 0.020180128514766693, 0.0], [0.00020744462381117046, 0.00036016973899677396, 0.004934145137667656, 0.0004664760490413755, 0.008187839761376381, 0.9661812782287598, 0.009987047873437405, 
0.003882928751409054, 0.005792597308754921, 0.0], [3.4081476769642904e-05, 1.7181657312903553e-05, 5.4824478866066784e-05, 0.00045897584641352296, 0.0043338024988770485, 0.001544477418065071, 0.9909620881080627, 2.356152981519699e-05, 0.0025708049070090055, 0.0], [0.0001047314508468844, 0.0001599654060555622, 0.001310097286477685, 0.001540280063636601, 0.833267331123352, 0.044754061847925186, 0.0028599577490240335, 0.0006454077665694058, 0.11535807698965073, 0.0], [8.819431968731806e-05, 6.364465662045404e-05, 0.00022057128080632538, 0.001112746773287654, 0.9560981392860413, 0.003599100047722459, 0.0002217600413132459, 0.0006697923527099192, 0.03792598471045494, 0.0], [0.0018130787648260593, 0.022020958364009857, 0.12822051346302032, 0.0005810249131172895, 0.03168048337101936, 0.014293116517364979, 0.002500524278730154, 0.0212943647056818, 0.7775959372520447, 0.0]]], [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6252409815788269, 0.3747589886188507, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8520486354827881, 0.010580658912658691, 0.13737063109874725, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.05910082906484604, 0.011589597910642624, 0.877491295337677, 0.051818281412124634, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3626183867454529, 0.026959313079714775, 0.07612177729606628, 0.13077552616596222, 0.4035249352455139, 0.0, 0.0, 0.0, 0.0, 0.0], [0.21979263424873352, 0.001410112832672894, 0.007092535495758057, 0.13166557252407074, 0.626970648765564, 0.013068560510873795, 0.0, 0.0, 0.0, 0.0], [0.08148042857646942, 0.001490423921495676, 0.004908325150609016, 0.01383854728192091, 0.7959722876548767, 0.05201547220349312, 0.05029459297657013, 0.0, 0.0, 0.0], [0.03934427723288536, 5.908778257435188e-05, 0.00014962907880544662, 0.005592166446149349, 0.7025003433227539, 0.1675100177526474, 0.03920353576540947, 0.04564077779650688, 0.0, 0.0], [0.4660189151763916, 0.00034756408422254026, 9.701005183160305e-05, 0.008154522627592087, 0.08121690154075623, 0.15592943131923676, 0.11426379531621933, 0.17044323682785034, 0.0035288764629513025, 0.0], [0.3707294762134552, 0.0020887483842670918, 0.23984688520431519, 0.07748916745185852, 0.18109895288944244, 0.03584783151745796, 0.005205830093473196, 0.005058187525719404, 0.0050886403769254684, 0.0775463655591011]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.15256483852863312, 0.8474349975585938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.08618302643299103, 0.30268052220344543, 0.6111364364624023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6251113414764404, 0.14608541131019592, 0.21724094450473785, 0.011562197469174862, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.31851068139076233, 0.11805614084005356, 0.02926168404519558, 0.0854775682091713, 0.44869405031204224, 0.0, 0.0, 0.0, 0.0, 0.0], [0.23099647462368011, 0.015003926120698452, 0.0028121687937527895, 0.025386620312929153, 0.5829272270202637, 0.14287345111370087, 0.0, 0.0, 0.0, 0.0], [0.2648485600948334, 0.01456066407263279, 0.008421574719250202, 0.01653379574418068, 0.25845009088516235, 0.35933130979537964, 0.07785411924123764, 0.0, 0.0, 0.0], [0.21031156182289124, 0.00652333116158843, 0.005756322760134935, 0.019128819927573204, 0.2526819407939911, 0.49096593260765076, 0.008809886872768402, 0.00582215515896678, 0.0, 0.0], [0.11555754393339157, 0.00475481478497386, 0.0013921409845352173, 0.045808907598257065, 0.29882168769836426, 0.3024459183216095, 0.0483231395483017, 0.18265680968761444, 0.0002390409354120493, 0.0], [0.8451279401779175, 0.021679740399122238, 0.035543736070394516, 0.005811640061438084, 
0.04445958510041237, 0.018052000552415848, 0.0015424924204126, 0.013668404892086983, 0.012673787772655487, 0.0014405279653146863]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9927853345870972, 0.007214863318949938, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.011021426878869534, 0.007158290129154921, 0.9818204641342163, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.007071706000715494, 0.026167649775743484, 0.19316613674163818, 0.773594319820404, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.320003479719162, 0.03976304829120636, 0.22334550321102142, 0.24320250749588013, 0.17368540167808533, 0.0, 0.0, 0.0, 0.0, 0.0], [0.10932182520627975, 0.001151762087829411, 0.007792286574840546, 0.18981949985027313, 0.6517421007156372, 0.04017229378223419, 0.0, 0.0, 0.0, 0.0], [0.02538878843188286, 0.005211540497839451, 0.03069700486958027, 0.13252338767051697, 0.4279623329639435, 0.0899164006114006, 0.28830063343048096, 0.0, 0.0, 0.0], [0.010537173599004745, 0.0007831656257621944, 0.0007035965682007372, 0.015162549912929535, 0.9050821661949158, 0.05248205363750458, 0.01132790744304657, 0.00392116466537118, 0.0, 0.0], [0.005222301464527845, 0.003575690556317568, 0.0029950442258268595, 0.00018454395467415452, 0.0012630765559151769, 0.01364975143224001, 0.09376595914363861, 0.853415846824646, 0.02592780999839306, 0.0], [0.14979584515094757, 0.0004723063320852816, 0.4970340430736542, 0.03214645013213158, 0.022075939923524857, 0.006538126152008772, 0.0013381451135501266, 0.0030305178370326757, 0.0008045822032727301, 0.28676414489746094]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9691458940505981, 0.03085414692759514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9338735938072205, 0.02144204080104828, 0.04468445107340813, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.4091326594352722, 0.1788463294506073, 0.3530478775501251, 0.058973249047994614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8083640336990356, 0.0245783980935812, 0.02959858626127243, 0.02002020739018917, 0.11743883788585663, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6256738901138306, 0.03313886746764183, 0.03255102410912514, 0.015011090785264969, 0.27659764885902405, 0.017027597874403, 0.0, 0.0, 0.0, 0.0], [0.2970131039619446, 0.01776941865682602, 0.015323061496019363, 0.014444534666836262, 0.2387886643409729, 0.36828577518463135, 0.048375438898801804, 0.0, 0.0, 0.0], [0.16347570717334747, 0.01386126596480608, 0.012116431258618832, 0.006670618429780006, 0.5951986312866211, 0.1577492356300354, 0.024585027247667313, 0.02634291537106037, 0.0, 0.0], [0.1568753868341446, 0.002166055142879486, 0.0014692704426124692, 0.009539359249174595, 0.7249224781990051, 0.0696585550904274, 0.02269914373755455, 0.010646837763488293, 0.0020231890957802534, 0.0], [0.6687246561050415, 0.003988182172179222, 0.00992897991091013, 0.00877397134900093, 0.07160260528326035, 0.14080072939395905, 0.01739262230694294, 0.04941429942846298, 0.01782085746526718, 0.011553076095879078]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.497504860162735, 0.502495288848877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.028444888070225716, 0.01678420603275299, 0.9547709822654724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.02853180095553398, 0.022399114444851875, 0.7835201025009155, 0.1655489057302475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.023048963397741318, 0.055082567036151886, 0.3371332883834839, 0.25099456310272217, 0.33374062180519104, 0.0, 0.0, 0.0, 0.0, 0.0], [0.013693265616893768, 0.057373203337192535, 0.02566814236342907, 0.11711565405130386, 0.13761301338672638, 0.6485366225242615, 0.0, 0.0, 0.0, 0.0], 
[0.5831283926963806, 0.0857725590467453, 0.06227085366845131, 0.03169894590973854, 0.06183577701449394, 0.01752074435353279, 0.15777261555194855, 0.0, 0.0, 0.0], [0.0033312023151665926, 0.003545752028003335, 0.0018331086030229926, 0.05265560373663902, 0.047756411135196686, 0.045255228877067566, 0.20667387545108795, 0.6389486193656921, 0.0, 0.0], [0.02047032117843628, 0.03542931377887726, 0.01270933635532856, 0.46998995542526245, 0.035482652485370636, 0.015606570988893509, 0.1128709465265274, 0.03180817514657974, 0.26563259959220886, 0.0], [0.027955254539847374, 0.024354776367545128, 0.4609973132610321, 0.07958999276161194, 0.34062448143959045, 0.0068156360648572445, 0.000798556546214968, 0.0009541919571347535, 0.00023223790049087256, 0.05767740309238434]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.909331738948822, 0.09066825360059738, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3865880072116852, 0.017979737371206284, 0.5954321622848511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5815560817718506, 0.15706834197044373, 0.052335821092128754, 0.2090395838022232, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7268465161323547, 0.0363004133105278, 0.07873083651065826, 0.06576839834451675, 0.09235385805368423, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6550694704055786, 0.019533857703208923, 0.042362816631793976, 0.07321250438690186, 0.06519921869039536, 0.14462217688560486, 0.0, 0.0, 0.0, 0.0], [0.48597243428230286, 0.05253118649125099, 0.06572883576154709, 0.06831242144107819, 0.06681334227323532, 0.09225586801767349, 0.168385848402977, 0.0, 0.0, 0.0], [0.2283225953578949, 0.01085133571177721, 0.0076954541727900505, 0.03403906524181366, 0.05505141243338585, 0.11318682134151459, 0.23008716106414795, 0.3207661509513855, 0.0, 0.0], [0.31019407510757446, 0.01576145552098751, 0.006604246329516172, 0.1025082990527153, 0.11805430799722672, 0.0999068170785904, 0.17944715917110443, 0.09494999051094055, 0.07257375121116638, 0.0], [0.028495613485574722, 0.00728303287178278, 0.028978589922189713, 0.21746259927749634, 0.0312367994338274, 0.01134485937654972, 0.002138715935871005, 0.0005697175511159003, 0.00012198994954815134, 0.6723678112030029]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.9973775148391724, 0.0026223897002637386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.8086934685707092, 0.08078567683696747, 0.11052089184522629, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.14967022836208344, 0.05171789228916168, 0.3914002478122711, 0.40721163153648376, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.40780311822891235, 0.04434635117650032, 0.05232110992074013, 0.3448564112186432, 0.15067294239997864, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3417491614818573, 0.023165758699178696, 0.008621969260275364, 0.03819064050912857, 0.566249430179596, 0.022023199126124382, 0.0, 0.0, 0.0, 0.0], [0.13283461332321167, 0.0027981544844806194, 0.001892031985335052, 0.057958006858825684, 0.4807162284851074, 0.22431829571723938, 0.09948258846998215, 0.0, 0.0, 0.0], [0.5702553391456604, 0.005225116387009621, 0.0014312443090602756, 0.028526127338409424, 0.15899939835071564, 0.05284468084573746, 0.022491520270705223, 0.16022635996341705, 0.0, 0.0], [0.005209033377468586, 6.901475717313588e-05, 5.760595013271086e-05, 0.006149875931441784, 0.006613760255277157, 0.010193211026489735, 0.013639912940561771, 0.9578513503074646, 0.00021631908020935953, 0.0], [0.5839820504188538, 0.007275882177054882, 0.03890826180577278, 0.2169828861951828, 0.02285575494170189, 0.0033320344518870115, 0.0027764069382101297, 0.032896872609853745, 0.038299210369586945, 
0.05269058048725128]], [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.050016310065984726, 0.9499835968017578, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1501082479953766, 0.3363426625728607, 0.5135491490364075, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3278971016407013, 0.3615517318248749, 0.08450257778167725, 0.22604861855506897, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.25667694211006165, 0.40780505537986755, 0.15422186255455017, 0.10097997635602951, 0.08031607419252396, 0.0, 0.0, 0.0, 0.0, 0.0], [0.06381947547197342, 0.026328660547733307, 0.009785240516066551, 0.001955200219526887, 0.6504424810409546, 0.24766895174980164, 0.0, 0.0, 0.0, 0.0], [0.20312389731407166, 0.04875075817108154, 0.01619674637913704, 0.028123438358306885, 0.08579143136739731, 0.20417368412017822, 0.41383999586105347, 0.0, 0.0, 0.0], [0.0028631098102778196, 0.00043423930765129626, 0.00021322182146832347, 0.0004598038794938475, 0.009248310700058937, 0.010348621755838394, 0.14874719083309174, 0.8276853561401367, 0.0, 0.0], [0.0018436884274706244, 0.00014077842934057117, 0.00019285999587737024, 0.0006260477821342647, 0.010675687342882156, 0.007219970691949129, 0.05410425364971161, 0.9211149215698242, 0.0040815602988004684, 0.0], [0.0996592566370964, 0.00012745348794851452, 0.0005518114776350558, 0.00026576913660392165, 0.00016320105351042002, 9.30886235437356e-05, 0.00024295282491948456, 0.000212193961488083, 0.0026789114344865084, 0.8960053324699402]]], [[[0.11433855444192886, 0.04686390608549118, 0.05050662159919739, 0.01692698895931244, 0.014815520495176315, 0.0768260732293129, 0.017432983964681625, 0.6101956367492676, 0.052093639969825745, 0.0], [0.38972416520118713, 0.22944767773151398, 0.07904881238937378, 0.052096493542194366, 0.007011168636381626, 0.006784902419894934, 0.0014207185013219714, 0.13426420092582703, 0.10020176321268082, 0.0], [0.11186019331216812, 0.017331605777144432, 0.04338392615318298, 0.012996343895792961, 0.0037596735637634993, 0.04186789691448212, 0.010188347660005093, 0.40130615234375, 0.35730576515197754, 0.0], [0.10591340810060501, 0.20223784446716309, 0.2886512279510498, 0.06966178864240646, 0.009836114011704922, 0.0229511596262455, 0.008547846227884293, 0.14636646211147308, 0.14583423733711243, 0.0], [0.07172418385744095, 0.2936038672924042, 0.03526498004794121, 0.13891401886940002, 0.06139945611357689, 0.03925776481628418, 0.05349786579608917, 0.1478489190340042, 0.15848881006240845, 0.0], [0.06896000355482101, 0.07670743763446808, 0.04746192321181297, 0.0077508557587862015, 0.0064402432180941105, 0.0690828412771225, 0.07239065319299698, 0.3534849286079407, 0.2977212965488434, 0.0], [0.07854402810335159, 0.05315924435853958, 0.006970829796046019, 0.01197806466370821, 0.15678070485591888, 0.059328123927116394, 0.0844779834151268, 0.058106619864702225, 0.4906543493270874, 0.0], [0.03681005910038948, 0.0140389958396554, 0.007565703243017197, 0.004180380143225193, 0.03550711274147034, 0.08768045157194138, 0.04672156274318695, 0.05331620201468468, 0.7141793966293335, 0.0], [0.011276278644800186, 0.0043571279384195805, 0.0015699869254603982, 0.009309794753789902, 0.5466312766075134, 0.06633520126342773, 0.012565904296934605, 0.036605555564165115, 0.3113488256931305, 0.0], [0.06600549072027206, 0.025541657581925392, 0.15132947266101837, 0.052793603390455246, 0.07684693485498428, 0.05682613328099251, 0.01590665802359581, 0.05306769907474518, 0.501682460308075, 0.0]], [[0.02032626047730446, 0.020203230902552605, 0.6020291447639465, 0.030009541660547256, 0.018654465675354004, 0.18802858889102936, 
0.05143868178129196, 0.02303317002952099, 0.04627683013677597, 0.0], [0.0019176346249878407, 0.016540443524718285, 0.9535443782806396, 0.023291945457458496, 0.00010982116509694606, 0.0009689715225249529, 0.00013850984396412969, 0.00020587266772054136, 0.0032822038047015667, 0.0], [0.0021921033039689064, 0.01942657120525837, 0.9639623165130615, 0.007353355176746845, 5.936318120802753e-05, 0.0048356493934988976, 6.20722203166224e-05, 0.00038220148417167366, 0.0017266407376155257, 0.0], [0.009460263885557652, 0.051493529230356216, 0.3948231041431427, 0.009338779374957085, 0.006950095761567354, 0.48816555738449097, 0.01867900975048542, 0.0028053568676114082, 0.018284155055880547, 0.0], [0.003424277761951089, 0.009652957320213318, 0.10005363076925278, 0.026136387139558792, 0.09182560443878174, 0.6324647068977356, 0.07658552378416061, 0.005459657870233059, 0.05439731851220131, 0.0], [0.0026005429681390524, 0.0029943317640572786, 0.26352012157440186, 0.02426978573203087, 0.05801504850387573, 0.49633610248565674, 0.11849544942378998, 0.009708826430141926, 0.02405967190861702, 0.0], [0.0027453943621367216, 0.0021119171287864447, 0.0030521987937390804, 0.09308812767267227, 0.28145554661750793, 0.015254919417202473, 0.530491828918457, 0.007408978417515755, 0.06439103186130524, 0.0], [0.0012973180273547769, 0.002199073787778616, 0.0031004296615719795, 0.024488963186740875, 0.8535729050636292, 0.016068320721387863, 0.029179612174630165, 0.012250186875462532, 0.05784311145544052, 0.0], [0.0010010729311034083, 0.0008253253763541579, 0.0028483583591878414, 0.028342707082629204, 0.860925018787384, 0.0038871155120432377, 0.006998666562139988, 0.01413769368082285, 0.08103384077548981, 0.0], [0.007383578456938267, 0.056256651878356934, 0.5807297825813293, 0.01667044125497341, 0.03810223564505577, 0.07880110293626785, 0.009197888895869255, 0.12926581501960754, 0.08359251171350479, 0.0]], [[0.023356424644589424, 0.012650059536099434, 0.05017145350575447, 0.05590398982167244, 0.05159280076622963, 0.01602507382631302, 0.014807065948843956, 0.654244601726532, 0.12124844640493393, 0.0], [0.030835414305329323, 0.04180247709155083, 0.029645785689353943, 0.20071062445640564, 0.010328685864806175, 0.03208288922905922, 0.026780622079968452, 0.457701712846756, 0.17011170089244843, 0.0], [0.02856343612074852, 0.03459611535072327, 0.15441730618476868, 0.04662194848060608, 0.0013040672056376934, 0.017847269773483276, 0.02464178577065468, 0.5969575643539429, 0.09505032747983932, 0.0], [0.03350958973169327, 0.02514287829399109, 0.027676144614815712, 0.11052078753709793, 0.15496152639389038, 0.08862635493278503, 0.027723105624318123, 0.24766287207603455, 0.28417670726776123, 0.0], [0.014355039224028587, 0.005383878946304321, 0.002517768880352378, 0.09422861039638519, 0.06622537225484848, 0.046315327286720276, 0.08473969250917435, 0.4999735355377197, 0.18626095354557037, 0.0], [0.002460801973938942, 0.0016284199664369226, 0.005857668351382017, 0.006880565080791712, 0.7626023292541504, 0.025456121191382408, 0.021016357466578484, 0.06090177595615387, 0.11319592595100403, 0.0], [0.007633878383785486, 0.002682786202058196, 0.0008938225219026208, 0.006808742880821228, 0.17231638729572296, 0.049100711941719055, 0.32851701974868774, 0.0061601921916007996, 0.4258863925933838, 0.0], [0.003303236560896039, 0.0015338786179199815, 0.0017581325955688953, 0.0052335225045681, 0.24177710711956024, 0.09136255830526352, 0.06603478640317917, 0.0047843558713793755, 0.5842124223709106, 0.0], [0.011694137938320637, 0.0015430846251547337, 
0.00043408613419160247, 0.005433904007077217, 0.03723231703042984, 0.1666216105222702, 0.04878358170390129, 0.024785596877336502, 0.7034717798233032, 0.0], [0.028515880927443504, 0.0183264147490263, 0.011487613432109356, 0.03205259144306183, 0.06179385632276535, 0.041277043521404266, 0.014015565626323223, 0.06198226660490036, 0.7305486798286438, 0.0]], [[0.10071786493062973, 0.017111245542764664, 0.07246935367584229, 0.01480931881815195, 0.14864948391914368, 0.20273517072200775, 0.054981958121061325, 0.25890761613845825, 0.12961813807487488, 0.0], [0.049787674099206924, 0.02108882926404476, 0.20989678800106049, 0.006962155923247337, 0.21569682657718658, 0.1622857302427292, 0.016771212220191956, 0.2403237521648407, 0.07718709856271744, 0.0], [0.024618864059448242, 0.010488898493349552, 0.4834355115890503, 0.015693388879299164, 0.07393413037061691, 0.055557433515787125, 0.007495412603020668, 0.27800077199935913, 0.05077548325061798, 0.0], [0.006324393209069967, 0.0006586945382878184, 0.02188086323440075, 0.003439908614382148, 0.055277179926633835, 0.5423230528831482, 0.1656835526227951, 0.12264314293861389, 0.08176910877227783, 0.0], [0.008246216922998428, 0.000647101376671344, 0.018551276996731758, 0.0031310885678976774, 0.04379039630293846, 0.34376823902130127, 0.2999532222747803, 0.13205647468566895, 0.1498558670282364, 0.0], [0.001800144906155765, 0.00032634654780849814, 0.02560480497777462, 0.0014933178899809718, 0.04328969866037369, 0.48067817091941833, 0.22867664694786072, 0.008819987997412682, 0.20931090414524078, 0.0], [0.0023069612216204405, 0.0018136217258870602, 0.006447605788707733, 0.005140945315361023, 0.046570103615522385, 0.045606330037117004, 0.3236173987388611, 0.014286459423601627, 0.5542104840278625, 0.0], [0.0025225167628377676, 0.000774701707996428, 0.0168449804186821, 0.0014132045907899737, 0.1692919135093689, 0.21547472476959229, 0.19468647241592407, 0.00621472392231226, 0.392776757478714, 0.0], [0.006893941201269627, 0.0026040272787213326, 0.036687299609184265, 0.0016275923699140549, 0.13132861256599426, 0.15552441775798798, 0.23651301860809326, 0.023025648668408394, 0.4057953953742981, 0.0], [0.004257934633642435, 0.008543262258172035, 0.05716743320226669, 0.0024442216381430626, 0.027526315301656723, 0.08828678727149963, 0.025276461616158485, 0.2843557894229889, 0.5021417737007141, 0.0]], [[0.2613511085510254, 0.024888625368475914, 0.11462423205375671, 0.021279124543070793, 0.1065509021282196, 0.23139707744121552, 0.07117345929145813, 0.09822205454111099, 0.07051338255405426, 0.0], [0.08973123878240585, 0.07256940752267838, 0.3644520342350006, 0.09313907474279404, 0.10501276701688766, 0.026235496625304222, 0.035534195601940155, 0.05646198242902756, 0.15686386823654175, 0.0], [0.12105944007635117, 0.03531542792916298, 0.18099160492420197, 0.04576702043414116, 0.03264385089278221, 0.04934798926115036, 0.0072426870465278625, 0.2739674150943756, 0.25366437435150146, 0.0], [0.049101557582616806, 0.02436317875981331, 0.1119280532002449, 0.019082490354776382, 0.23333144187927246, 0.12024182081222534, 0.09606382250785828, 0.03866123780608177, 0.3072265088558197, 0.0], [0.011761177331209183, 0.004259902983903885, 0.019396282732486725, 0.010304590687155724, 0.5410462021827698, 0.1548439860343933, 0.1577453315258026, 0.022628072649240494, 0.07801424711942673, 0.0], [0.011012338101863861, 0.006456742994487286, 0.03514476120471954, 0.01111147552728653, 0.3646441400051117, 0.06045660004019737, 0.22725869715213776, 0.030072104185819626, 0.2538430392742157, 0.0], 
[0.0009554855059832335, 0.0010365764610469341, 0.000539954868145287, 0.013481645844876766, 0.6702913641929626, 0.013201623223721981, 0.06565960496664047, 0.008186675608158112, 0.2266470193862915, 0.0], [0.007978711277246475, 0.0019918852485716343, 0.0007363414042629302, 0.010062554851174355, 0.10717969387769699, 0.01258536335080862, 0.08278501033782959, 0.02946571074426174, 0.7472147941589355, 0.0], [0.014369996264576912, 0.00412968173623085, 0.002898097038269043, 0.0381503589451313, 0.28382056951522827, 0.03412872180342674, 0.2624143660068512, 0.04523473232984543, 0.3148534893989563, 0.0], [0.0025127469561994076, 0.0030011499766260386, 0.0036209137178957462, 0.0006047216593287885, 0.01094596553593874, 0.0023283734917640686, 0.003409643191844225, 0.009625249542295933, 0.9639512896537781, 0.0]], [[0.023785226047039032, 0.024275904521346092, 0.6168470978736877, 0.01581703871488571, 0.026939542964100838, 0.1783975064754486, 0.04853774979710579, 0.02762567065656185, 0.03777410835027695, 0.0], [0.006065902300179005, 0.04932599142193794, 0.8176359534263611, 0.018976736813783646, 0.008159944787621498, 0.011068272404372692, 0.010428683832287788, 0.014124251902103424, 0.06421414017677307, 0.0], [0.0037236923817545176, 0.007844064384698868, 0.9502744674682617, 0.0048003061674535275, 0.00022506865207105875, 0.004834793973714113, 0.0015490480000153184, 0.0026021157391369343, 0.024146683514118195, 0.0], [0.0007564057596027851, 0.0017460802337154746, 0.15768493711948395, 0.004074132069945335, 0.015430302359163761, 0.7368869781494141, 0.028010869398713112, 0.013945921324193478, 0.04146439954638481, 0.0], [0.0008445779676549137, 0.0015138997696340084, 0.17073306441307068, 0.0074179465882480145, 0.08121992647647858, 0.5853323936462402, 0.09402737021446228, 0.024092217907309532, 0.034818582236766815, 0.0], [0.0006014688406139612, 0.0016882645431905985, 0.16094569861888885, 0.003698966233059764, 0.034668561071157455, 0.5876308679580688, 0.09562253206968307, 0.05209798738360405, 0.06304588913917542, 0.0], [0.00015283364336937666, 0.0004516944463830441, 0.003205003682523966, 0.0049727726727724075, 0.10853080451488495, 0.03262018784880638, 0.6125266551971436, 0.005719948559999466, 0.2318202555179596, 0.0], [9.933842375176027e-05, 0.00020211786613799632, 0.0037883264012634754, 0.0051808832213282585, 0.6936825513839722, 0.10089477151632309, 0.023457802832126617, 0.011726793833076954, 0.16096755862236023, 0.0], [0.0002526468597352505, 0.0010056017199531198, 0.003837066935375333, 0.034950658679008484, 0.5882559418678284, 0.029549231752753258, 0.030938459560275078, 0.01461110170930624, 0.2965993583202362, 0.0], [0.0029390468262135983, 0.005815382581204176, 0.06488344818353653, 0.008705642074346542, 0.010130577720701694, 0.012970774434506893, 0.019612692296504974, 0.007819950580596924, 0.8671225309371948, 0.0]], [[0.027314670383930206, 0.02143898233771324, 0.1116434708237648, 0.006578116212040186, 0.20446842908859253, 0.3867157995700836, 0.054494183510541916, 0.09778231382369995, 0.08956411480903625, 0.0], [0.09418193250894547, 0.7071846127510071, 0.05323847755789757, 0.0077135805040597916, 0.01789833791553974, 0.010848474688827991, 0.0020562252029776573, 0.01705808937549591, 0.08982021361589432, 0.0], [0.005751691292971373, 0.01031999196857214, 0.8884198069572449, 0.00210022390820086, 0.0066058398224413395, 0.019834432750940323, 0.002143828198313713, 0.02793465554714203, 0.036889322102069855, 0.0], [0.027659546583890915, 0.03931494802236557, 0.10616040229797363, 0.011142275296151638, 0.1017894372344017, 
0.30847156047821045, 0.12201698124408722, 0.05519269034266472, 0.22825226187705994, 0.0], [0.037639543414115906, 0.062483835965394974, 0.050776157528162, 0.012697378173470497, 0.27911704778671265, 0.19993652403354645, 0.14870049059391022, 0.10304640233516693, 0.10560261458158493, 0.0], [0.007995839230716228, 0.008397839032113552, 0.03270075097680092, 0.004312656354159117, 0.03775893524289131, 0.3733556568622589, 0.3424486219882965, 0.012857009656727314, 0.18017242848873138, 0.0], [0.012142053805291653, 0.007298614829778671, 0.016982076689600945, 0.02473442070186138, 0.08738671243190765, 0.033574704080820084, 0.27830857038497925, 0.033199213445186615, 0.5063735842704773, 0.0], [0.02729739435017109, 0.05135440081357956, 0.03332214429974556, 0.02499799057841301, 0.11955489963293076, 0.020848069339990616, 0.017926985397934914, 0.01858661323785782, 0.6861116290092468, 0.0], [0.018110578879714012, 0.011406980454921722, 0.0018257799092680216, 0.025524618104100227, 0.3885835111141205, 0.010744227096438408, 0.008441396057605743, 0.003679890651255846, 0.5316829681396484, 0.0], [0.02325628325343132, 0.013795747421681881, 0.0823512151837349, 0.0021813653875142336, 0.03511650115251541, 0.0814405307173729, 0.02589382231235504, 0.14330172538757324, 0.5926627516746521, 0.0]], [[0.011912941001355648, 0.006341524887830019, 0.1334817260503769, 0.017931688576936722, 0.005569889210164547, 0.7441595792770386, 0.0258712787181139, 0.034265827387571335, 0.020465616136789322, 0.0], [0.03743305802345276, 0.05229289084672928, 0.3549361228942871, 0.028500670567154884, 0.01974724419414997, 0.29288655519485474, 0.08050932735204697, 0.06582070142030716, 0.06787342578172684, 0.0], [0.025082573294639587, 0.10057684034109116, 0.7856844663619995, 0.01178921852260828, 0.0010154875926673412, 0.02595749869942665, 0.008632739074528217, 0.006036050152033567, 0.0352250337600708, 0.0], [0.010372502729296684, 0.023954369127750397, 0.18692812323570251, 0.03930393233895302, 0.004741673823446035, 0.46527597308158875, 0.1267295777797699, 0.048278260976076126, 0.09441567957401276, 0.0], [0.0031391805969178677, 0.006868112366646528, 0.1369999200105667, 0.013019833713769913, 0.008593270555138588, 0.6626507639884949, 0.07777946442365646, 0.052107103168964386, 0.03884238004684448, 0.0], [0.005276903510093689, 0.026354510337114334, 0.056238383054733276, 0.03191604092717171, 0.025259410962462425, 0.3898610472679138, 0.2790180444717407, 0.028249284252524376, 0.1578262895345688, 0.0], [0.001124523114413023, 0.0035109275486320257, 0.0021898215636610985, 0.043262600898742676, 0.0267842635512352, 0.03029855713248253, 0.5000982284545898, 0.0056237452663481236, 0.38710734248161316, 0.0], [0.00020160828717052937, 0.0004381221951916814, 0.010444902814924717, 0.010398894548416138, 0.10143585503101349, 0.23107938468456268, 0.09920267760753632, 0.019364865496754646, 0.5274338126182556, 0.0], [0.00019610628078226, 0.00041535915806889534, 0.009462974965572357, 0.005905983969569206, 0.29870083928108215, 0.24092966318130493, 0.11132201552391052, 0.05168075114488602, 0.2813863754272461, 0.0], [0.004135515075176954, 0.014277483336627483, 0.15738952159881592, 0.003396045882254839, 0.01641785353422165, 0.07296160608530045, 0.02388397790491581, 0.024013018235564232, 0.683525025844574, 0.0]]]], \"top_text\": [\"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", \"!\"], \"bot_text\": [\"\u003cpad\u003e\", \"Es\", \"ist\", \"sch\\u00f6n\", \", \", \"heute\", \"neue\", \"Dinge\", \"zu\", \"lernen\", 
\"!\"]}}" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "\n", - "/**\n", - " * @fileoverview Transformer Visualization D3 javascript code.\n", - " */\n", - "\n", - "requirejs(['jquery', 'd3'],\n", - "function($, d3) {\n", - "\n", - "var attention = window.attention;\n", - "\n", - "const TEXT_SIZE = 15;\n", - "const BOXWIDTH = TEXT_SIZE * 8;\n", - "const BOXHEIGHT = TEXT_SIZE * 1.5;\n", - "const WIDTH = 2000;\n", - "const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;\n", - "const MATRIX_WIDTH = 150;\n", - "const head_colours = d3.scale.category10();\n", - "const CHECKBOX_SIZE = 20;\n", - "\n", - "function lighten(colour) {\n", - " var c = d3.hsl(colour);\n", - " var increment = (1 - c.l) * 0.6;\n", - " c.l += increment;\n", - " c.s -= increment;\n", - " return c;\n", - "}\n", - "\n", - "function transpose(mat) {\n", - " return mat[0].map(function(col, i) {\n", - " return mat.map(function(row) {\n", - " return row[i];\n", - " });\n", - " });\n", - "}\n", - "\n", - "function zip(a, b) {\n", - " return a.map(function (e, i) {\n", - " return [e, b[i]];\n", - " });\n", - "}\n", - "\n", - "\n", - "function renderVis(id, top_text, bot_text, attention_heads, config) {\n", - " $(id).empty();\n", - " var svg = d3.select(id)\n", - " .append('svg')\n", - " .attr(\"width\", WIDTH)\n", - " .attr(\"height\", HEIGHT);\n", - "\n", - " var att_data = [];\n", - " for (var i=0; i \u003c attention_heads.length; i++) {\n", - " var att_trans = transpose(attention_heads[i]);\n", - " att_data.push(zip(attention_heads[i], att_trans));\n", - " }\n", - "\n", - " renderText(svg, top_text, true, att_data, 0);\n", - " renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);\n", - "\n", - " renderAttentionHighlights(svg, att_data);\n", - "\n", - " svg.append(\"g\").classed(\"attention_heads\", true);\n", - "\n", - " renderAttention(svg, attention_heads);\n", - "\n", - " draw_checkboxes(config, 0, svg, attention_heads);\n", - "}\n", - "\n", - "\n", - "function renderText(svg, text, is_top, att_data, left_pos) {\n", - " var id = is_top ? 
\"top\" : \"bottom\";\n", - " var textContainer = svg.append(\"svg:g\")\n", - " .attr(\"id\", id);\n", - "\n", - " textContainer.append(\"g\").classed(\"attention_boxes\", true)\n", - " .selectAll(\"g\")\n", - " .data(att_data)\n", - " .enter()\n", - " .append(\"g\")\n", - " .selectAll(\"rect\")\n", - " .data(function(d) {return d;})\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"x\", function(d, i, j) {\n", - " return left_pos + box_offset(j);\n", - " })\n", - " .attr(\"y\", function(d, i) {\n", - " return (+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .attr(\"fill\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .style(\"opacity\", 0.0);\n", - "\n", - "\n", - " var tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n", - " .data(text)\n", - " .enter()\n", - " .append(\"g\");\n", - "\n", - " tokenContainer.append(\"rect\")\n", - " .classed(\"background\", true)\n", - " .style(\"opacity\", 0.0)\n", - " .attr(\"fill\", \"lightgray\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " })\n", - " .attr(\"width\", BOXWIDTH)\n", - " .attr(\"height\", BOXHEIGHT);\n", - "\n", - " var theText = tokenContainer.append(\"text\")\n", - " .text(function(d) { return d; })\n", - " .attr(\"font-size\", TEXT_SIZE + \"px\")\n", - " .style(\"cursor\", \"default\")\n", - " .style(\"-webkit-user-select\", \"none\")\n", - " .attr(\"x\", left_pos)\n", - " .attr(\"y\", function(d, i) {\n", - " return (i+1) * BOXHEIGHT;\n", - " });\n", - "\n", - " if (is_top) {\n", - " theText.style(\"text-anchor\", \"end\")\n", - " .attr(\"dx\", BOXWIDTH - TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " } else {\n", - " theText.style(\"text-anchor\", \"start\")\n", - " .attr(\"dx\", + TEXT_SIZE)\n", - " .attr(\"dy\", TEXT_SIZE);\n", - " }\n", - "\n", - " tokenContainer.on(\"mouseover\", function(d, index) {\n", - " textContainer.selectAll(\".background\")\n", - " .style(\"opacity\", function(d, i) {\n", - " return i == index ? 1.0 : 0.0;\n", - " });\n", - "\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"none\");\n", - "\n", - " svg.selectAll(\".line_heads\") // To get the nesting to work.\n", - " .selectAll(\".att_lines\")\n", - " .attr(\"stroke-opacity\", function(d) {\n", - " return 1.0;\n", - " })\n", - " .attr(\"y1\", function(d, i) {\n", - " if (is_top) {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", function(d, i) {\n", - " if (is_top) {\n", - " return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " } else {\n", - " return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);\n", - " }\n", - " })\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", function(d, i, j) {\n", - " return head_colours(j);\n", - " })\n", - " .attr(\"stroke-opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j]) {\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " });\n", - "\n", - "\n", - " function updateAttentionBoxes() {\n", - " var id = is_top ? \"bottom\" : \"top\";\n", - " var the_left_pos = is_top ? 
MATRIX_WIDTH + BOXWIDTH : 0;\n", - " svg.select(\"#\" + id)\n", - " .selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .attr(\"x\", function(d, i, j) { return the_left_pos + box_offset(j); })\n", - " .attr(\"y\", function(d, i) { return (i+1) * BOXHEIGHT; })\n", - " .attr(\"width\", BOXWIDTH/active_heads())\n", - " .attr(\"height\", function() { return BOXHEIGHT; })\n", - " .style(\"opacity\", function(d, i, j) {\n", - " if (is_top) {d = d[0];} else {d = d[1];}\n", - " if (config.head_vis[j])\n", - " if (d) {\n", - " return d[index];\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " else\n", - " return 0.0;\n", - "\n", - " });\n", - " }\n", - "\n", - " updateAttentionBoxes();\n", - " });\n", - "\n", - " textContainer.on(\"mouseleave\", function() {\n", - " d3.select(this).selectAll(\".background\")\n", - " .style(\"opacity\", 0.0);\n", - "\n", - " svg.selectAll(\".att_lines\").attr(\"stroke-opacity\", 0.0);\n", - " svg.selectAll(\".attention_heads\").style(\"display\", \"inline\");\n", - " svg.selectAll(\".attention_boxes\")\n", - " .selectAll(\"g\")\n", - " .selectAll(\"rect\")\n", - " .style(\"opacity\", 0.0);\n", - " });\n", - "}\n", - "\n", - "function renderAttentionHighlights(svg, attention) {\n", - " var line_container = svg.append(\"g\");\n", - " line_container.selectAll(\"g\")\n", - " .data(attention)\n", - " .enter()\n", - " .append(\"g\")\n", - " .classed(\"line_heads\", true)\n", - " .selectAll(\"line\")\n", - " .data(function(d){return d;})\n", - " .enter()\n", - " .append(\"line\").classed(\"att_lines\", true);\n", - "}\n", - "\n", - "function renderAttention(svg, attention_heads) {\n", - " var line_container = svg.selectAll(\".attention_heads\");\n", - " line_container.html(null);\n", - " for(var h=0; h\u003cattention_heads.length; h++) {\n", - " for(var a=0; a\u003cattention_heads[h].length; a++) {\n", - " for(var s=0; s\u003cattention_heads[h][a].length; s++) {\n", - " line_container.append(\"line\")\n", - " .attr(\"y1\", (s+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x1\", BOXWIDTH)\n", - " .attr(\"y2\", (a+1) * BOXHEIGHT + (BOXHEIGHT/2))\n", - " .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n", - " .attr(\"stroke-width\", 2)\n", - " .attr(\"stroke\", head_colours(h))\n", - " .attr(\"stroke-opacity\", function() {\n", - " if (config.head_vis[h]) {\n", - " return attention_heads[h][a][s]/active_heads();\n", - " } else {\n", - " return 0.0;\n", - " }\n", - " }());\n", - " }\n", - " }\n", - " }\n", - "}\n", - "\n", - "// Checkboxes\n", - "function box_offset(i) {\n", - " var num_head_above = config.head_vis.reduce(\n", - " function(acc, val, cur) {return val \u0026\u0026 cur \u003c i ? acc + 1: acc;}, 0);\n", - " return num_head_above*(BOXWIDTH / active_heads());\n", - "}\n", - "\n", - "function active_heads() {\n", - " return config.head_vis.reduce(function(acc, val) {\n", - " return val ? 
acc + 1: acc;\n", - " }, 0);\n", - "}\n", - "\n", - "function draw_checkboxes(config, top, svg, attention_heads) {\n", - " var checkboxContainer = svg.append(\"g\");\n", - " var checkbox = checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .enter()\n", - " .append(\"rect\")\n", - " .attr(\"fill\", function(d, i) {\n", - " return head_colours(i);\n", - " })\n", - " .attr(\"x\", function(d, i) {\n", - " return (i+1) * CHECKBOX_SIZE;\n", - " })\n", - " .attr(\"y\", top)\n", - " .attr(\"width\", CHECKBOX_SIZE)\n", - " .attr(\"height\", CHECKBOX_SIZE);\n", - "\n", - " function update_checkboxes() {\n", - " checkboxContainer.selectAll(\"rect\")\n", - " .data(config.head_vis)\n", - " .attr(\"fill\", function(d, i) {\n", - " var head_colour = head_colours(i);\n", - " var colour = d ? head_colour : lighten(head_colour);\n", - " return colour;\n", - " });\n", - " }\n", - "\n", - " update_checkboxes();\n", - "\n", - " checkbox.on(\"click\", function(d, i) {\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) return;\n", - " config.head_vis[i] = !config.head_vis[i];\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "\n", - " checkbox.on(\"dblclick\", function(d, i) {\n", - " // If we double click on the only active head then reset\n", - " if (config.head_vis[i] \u0026\u0026 active_heads() == 1) {\n", - " config.head_vis = new Array(config.num_heads).fill(true);\n", - " } else {\n", - " config.head_vis = new Array(config.num_heads).fill(false);\n", - " config.head_vis[i] = true;\n", - " }\n", - " update_checkboxes();\n", - " renderAttention(svg, attention_heads);\n", - " });\n", - "}\n", - "\n", - "var config = {\n", - " layer: 0,\n", - " att_type: 'all',\n", - "};\n", - "\n", - "function visualize() {\n", - " var num_heads = attention['all']['att'][0].length;\n", - " config.head_vis = new Array(num_heads).fill(true);\n", - " config.num_heads = num_heads;\n", - " config.attention = attention;\n", - "\n", - " render();\n", - "}\n", - "\n", - "function render() {\n", - " var conf = config.attention[config.att_type];\n", - "\n", - " var top_text = conf.top_text;\n", - " var bot_text = conf.bot_text;\n", - " var attention = conf.att[config.layer];\n", - "\n", - " $(\"#vis svg\").empty();\n", - " renderVis(\"#vis\", top_text, bot_text, attention, config);\n", - "}\n", - "\n", - "$(\"#layer\").empty();\n", - "for(var i=0; i\u003c6; i++) {\n", - " $(\"#layer\").append($(\"\u003coption /\u003e\").val(i).text(i));\n", - "}\n", - "\n", - "$(\"#layer\").on('change', function(e) {\n", - " config.layer = +e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - "$(\"#att_type\").on('change', function(e) {\n", - " config.att_type = e.currentTarget.value;\n", - " render();\n", - "});\n", - "\n", - "$(\"button\").on('click', visualize);\n", - "\n", - "visualize();\n", - "\n", - "});\n" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript object\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "call_html()\n", - "display.display(display.HTML(vis_html))\n", - "display.display(display.Javascript('window.attention = %s' % attention_json))\n", - "display.display(display.Javascript(vis_js))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "lydjSs3hgDVF" - }, - "outputs": [], - "source": [ - "" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - 
"name": "Attention_Visualization_in_Trax.ipynb", - "provenance": [ - { - "file_id": "1bJu3Qx37FY9UpHqVMyXCTNb64v4Iw_v7", - "timestamp": 1598692842045 - } - ], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/models/__init__.py b/trax/models/__init__.py index f827f1d94..d098a34bb 100644 --- a/trax/models/__init__.py +++ b/trax/models/__init__.py @@ -16,28 +16,24 @@ """Models defined in trax.""" import gin -from trax.models import atari_cnn -from trax.models import mlp -from trax.models import neural_gpu -from trax.models import resnet -from trax.models import rl -from trax.models import rnn -from trax.models import transformer +from trax.models import atari_cnn, gnn, mlp, neural_gpu, resnet, rl, rnn, transformer from trax.models.reformer import reformer -from trax.models.research import bert -from trax.models.research import configurable_transformer -from trax.models.research import hourglass -from trax.models.research import layerdrop_transformer -from trax.models.research import rezero -from trax.models.research import rse -from trax.models.research import terraformer -from trax.models.research import transformer2 +from trax.models.research import ( + bert, + configurable_transformer, + hourglass, + layerdrop_transformer, + rezero, + rse, + terraformer, + transformer2, +) # Ginify def model_configure(*args, **kwargs): - kwargs['module'] = 'trax.models' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.models" + return gin.external_configurable(*args, **kwargs) # pylint: disable=invalid-name @@ -49,37 +45,34 @@ def model_configure(*args, **kwargs): BERTRegressionHead = model_configure(bert.BERTRegressionHead) ConfigurableTerraformer = model_configure(terraformer.ConfigurableTerraformer) ConfigurableTransformer = model_configure( - configurable_transformer.ConfigurableTransformer) + configurable_transformer.ConfigurableTransformer +) ConfigurableTransformerEncoder = model_configure( - configurable_transformer.ConfigurableTransformerEncoder) + configurable_transformer.ConfigurableTransformerEncoder +) ConfigurableTransformerLM = model_configure( - configurable_transformer.ConfigurableTransformerLM) + configurable_transformer.ConfigurableTransformerLM +) MLP = model_configure(mlp.MLP) NeuralGPU = model_configure(neural_gpu.NeuralGPU) Reformer = model_configure(reformer.Reformer) ReformerLM = model_configure(reformer.ReformerLM) ReformerShortenLM = model_configure(reformer.ReformerShortenLM) Resnet50 = model_configure(resnet.Resnet50) -ReZeroTransformer = model_configure( - rezero.ReZeroTransformer) -ReZeroTransformerDecoder = model_configure( - rezero.ReZeroTransformerDecoder) -ReZeroTransformerEncoder = model_configure( - rezero.ReZeroTransformerEncoder) -ReZeroTransformerLM = model_configure( - rezero.ReZeroTransformerLM) -SkippingTransformerLM = model_configure( - layerdrop_transformer.SkippingTransformerLM) -LayerDropTransformerLM = model_configure( - layerdrop_transformer.LayerDropTransformerLM) +ReZeroTransformer = model_configure(rezero.ReZeroTransformer) +ReZeroTransformerDecoder = model_configure(rezero.ReZeroTransformerDecoder) +ReZeroTransformerEncoder = model_configure(rezero.ReZeroTransformerEncoder) +ReZeroTransformerLM = model_configure(rezero.ReZeroTransformerLM) +SkippingTransformerLM = model_configure(layerdrop_transformer.SkippingTransformerLM) +LayerDropTransformerLM = model_configure(layerdrop_transformer.LayerDropTransformerLM) 
EveryOtherLayerDropTransformerLM = model_configure( - layerdrop_transformer.EveryOtherLayerDropTransformerLM) + layerdrop_transformer.EveryOtherLayerDropTransformerLM +) Transformer = model_configure(transformer.Transformer) TransformerDecoder = model_configure(transformer.TransformerDecoder) TransformerEncoder = model_configure(transformer.TransformerEncoder) TransformerLM = model_configure(transformer.TransformerLM) -Transformer2 = model_configure( - transformer2.Transformer2) +Transformer2 = model_configure(transformer2.Transformer2) WideResnet = model_configure(resnet.WideResnet) Policy = model_configure(rl.Policy) PolicyAndValue = model_configure(rl.PolicyAndValue) @@ -90,3 +83,6 @@ def model_configure(*args, **kwargs): LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn) ResidualShuffleExchange = model_configure(rse.ResidualShuffleExchange) HourglassLM = model_configure(hourglass.HourglassLM) +GraphConvNet = model_configure(gnn.GraphConvNet) +GraphAttentionNet = model_configure(gnn.GraphAttentionNet) +GraphEdgeNet = model_configure(gnn.GraphEdgeNet) diff --git a/trax/models/atari_cnn.py b/trax/models/atari_cnn.py index 99464d527..2cb15e960 100644 --- a/trax/models/atari_cnn.py +++ b/trax/models/atari_cnn.py @@ -19,83 +19,102 @@ def _FrameStack(n_frames): - """Stacks successive game frames along their last dimension.""" - # Input shape: (B, T, ..., C). - # Output shape: (B, T, ..., C * n_frames). - assert n_frames >= 1 - if n_frames == 1: - return [] # No-op; just let the data flow through. - return [ - # Create copies of input sequence, shift right by [0, ..., n_frames - 1] - # frames, and concatenate along the channel dimension. - tl.Branch(*map(_shift_right, range(n_frames))), - tl.Concatenate(n_items=n_frames, axis=-1) - ] + """Stacks successive game frames along their last dimension.""" + # Input shape: (B, T, ..., C). + # Output shape: (B, T, ..., C * n_frames). + assert n_frames >= 1 + if n_frames == 1: + return [] # No-op; just let the data flow through. + return [ + # Create copies of input sequence, shift right by [0, ..., n_frames - 1] + # frames, and concatenate along the channel dimension. + tl.Branch(*map(_shift_right, range(n_frames))), + tl.Concatenate(n_items=n_frames, axis=-1), + ] def _BytesToFloats(): - """Layer that converts unsigned bytes to floats.""" - return tl.Fn('BytesToFloats', lambda x: x / 255.0) - - -def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'): - """An Atari CNN.""" - del mode - - # TODO(jonni): Include link to paper? - # Input shape: (B, T, H, W, C) - # Output shape: (B, T, output_size) - return tl.Serial( - _BytesToFloats(), - _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) - tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), - tl.Relu(), - tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), - tl.Relu(), - tl.Flatten(n_axes_to_keep=2), # B, T and rest. - tl.Dense(output_size), - tl.Relu(), - ) - - -def AtariCnnBody(n_frames=4, hidden_sizes=(32, 64, 64), - output_size=512, mode='train', - kernel_initializer=None, padding='VALID'): - """An Atari CNN.""" - del mode - - # TODO(jonni): Include link to paper? 
- # Input shape: (B, T, H, W, C) - # Output shape: (B, T, output_size) - return tl.Serial( - _BytesToFloats(), - _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) - tl.Conv(hidden_sizes[0], (8, 8), (4, 4), padding=padding, - kernel_initializer=kernel_initializer), - tl.Relu(), - tl.Conv(hidden_sizes[1], (4, 4), (2, 2), padding=padding, - kernel_initializer=kernel_initializer), - tl.Relu(), - tl.Conv(hidden_sizes[2], (3, 3), (1, 1), padding=padding, - kernel_initializer=kernel_initializer), - tl.Relu(), - tl.Flatten(n_axes_to_keep=2), # B, T and rest. - tl.Dense(output_size), - tl.Relu(), - ) - - -def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64, - mode='train'): - """MLP operating on a fixed number of last frames.""" - del mode - - return tl.Serial( - _FrameStack(n_frames=n_frames), - [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes], - tl.Dense(output_size), - ) + """Layer that converts unsigned bytes to floats.""" + return tl.Fn("BytesToFloats", lambda x: x / 255.0) + + +def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode="train"): + """An Atari CNN.""" + del mode + + # TODO(jonni): Include link to paper? + # Input shape: (B, T, H, W, C) + # Output shape: (B, T, output_size) + return tl.Serial( + _BytesToFloats(), + _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) + tl.Conv(hidden_sizes[0], (5, 5), (2, 2), "SAME"), + tl.Relu(), + tl.Conv(hidden_sizes[1], (5, 5), (2, 2), "SAME"), + tl.Relu(), + tl.Flatten(n_axes_to_keep=2), # B, T and rest. + tl.Dense(output_size), + tl.Relu(), + ) + + +def AtariCnnBody( + n_frames=4, + hidden_sizes=(32, 64, 64), + output_size=512, + mode="train", + kernel_initializer=None, + padding="VALID", +): + """An Atari CNN.""" + del mode + + # TODO(jonni): Include link to paper? + # Input shape: (B, T, H, W, C) + # Output shape: (B, T, output_size) + return tl.Serial( + _BytesToFloats(), + _FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) + tl.Conv( + hidden_sizes[0], + (8, 8), + (4, 4), + padding=padding, + kernel_initializer=kernel_initializer, + ), + tl.Relu(), + tl.Conv( + hidden_sizes[1], + (4, 4), + (2, 2), + padding=padding, + kernel_initializer=kernel_initializer, + ), + tl.Relu(), + tl.Conv( + hidden_sizes[2], + (3, 3), + (1, 1), + padding=padding, + kernel_initializer=kernel_initializer, + ), + tl.Relu(), + tl.Flatten(n_axes_to_keep=2), # B, T and rest. + tl.Dense(output_size), + tl.Relu(), + ) + + +def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64, mode="train"): + """MLP operating on a fixed number of last frames.""" + del mode + + return tl.Serial( + _FrameStack(n_frames=n_frames), + [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes], + tl.Dense(output_size), + ) def _shift_right(n): # pylint: disable=invalid-name - return [tl.ShiftRight()] * n + return [tl.ShiftRight()] * n diff --git a/trax/models/atari_cnn_test.py b/trax/models/atari_cnn_test.py deleted file mode 100644 index fe3ded66d..000000000 --- a/trax/models/atari_cnn_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.models.atari_cnn.""" - -import functools -import operator as op -import numpy as np -from tensorflow import test -from trax.models import atari_cnn -from trax.shapes import ShapeDtype - - -class AtariCnnTest(test.TestCase): - - def test_computes(self): - hidden_size = (4, 4) - output_size = 6 - model = atari_cnn.AtariCnn( - hidden_sizes=hidden_size, output_size=output_size) - B, T, OBS = 2, 2, (28, 28, 3) # pylint: disable=invalid-name - input_signature = ShapeDtype((1, 1) + OBS) - _, _ = model.init(input_signature) - x = np.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape( - B, T + 1, *OBS) - y = model(x) - self.assertEqual((B, T + 1, output_size), y.shape) - - -class FrameStackMLPTest(test.TestCase): - - def test_computes(self): - hidden_size = (4, 4) - output_size = 6 - model = atari_cnn.FrameStackMLP( - hidden_sizes=hidden_size, output_size=output_size) - B, T, OBS = 2, 2, 3 # pylint: disable=invalid-name - input_signature = ShapeDtype((1, 1, OBS)) - _, _ = model.init(input_signature) - x = np.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS) - y = model(x) - self.assertEqual((B, T + 1, output_size), y.shape) - - -if __name__ == '__main__': - test.main() diff --git a/trax/models/gnn.py b/trax/models/gnn.py new file mode 100644 index 000000000..a4f680323 --- /dev/null +++ b/trax/models/gnn.py @@ -0,0 +1,168 @@ +# coding=utf-8 +"""Simple Graph Neural Network models for Trax. + +This module provides minimal building blocks for graph neural networks with +basic functionality like adjacency normalization, optional self-loops and a +light-weight graph attention layer. +""" + +from jax import nn + +from trax import layers as tl +from trax.fastmath import numpy as jnp + + +def normalize_adjacency(adj, add_self_loops=True, eps=1e-8): + """Returns normalized ``adj`` applying ``D^-1/2 (A + I) D^-1/2``. + + Args: + adj: ``(..., N, N)`` adjacency matrices, optionally batched. + add_self_loops: Whether to add identity connections before normalizing. + eps: Small constant for numerical stability. + + Returns: + Normalized adjacency matrices with the same shape as ``adj``. + """ + if add_self_loops: + eye = jnp.eye(adj.shape[-1]) + eye = jnp.broadcast_to(eye, adj.shape) + adj = adj + eye + deg = jnp.sum(adj, axis=-1) + inv_sqrt_deg = 1.0 / jnp.sqrt(deg + eps) + norm = adj * inv_sqrt_deg[..., None] * inv_sqrt_deg[..., None, :] + return norm + + +def GraphConv(out_dim, activation=tl.Relu, add_self_loops=True): + """Returns a graph convolution layer using normalized adjacency. + + The layer expects inputs ``(node_features, adjacency_matrix)`` and + returns ``(new_features, adjacency_matrix)`` so that multiple graph + convolution layers can be chained. + + Args: + out_dim: Size of the output node representation. + activation: Activation layer constructor applied after the dense step. + + Returns: + A :class:`~trax.layers.Serial` layer implementing graph convolution. + """ + + def _conv(f, a): + a_norm = normalize_adjacency(a, add_self_loops=add_self_loops) + return jnp.matmul(a_norm, f) + + return tl.Serial( + tl.Branch( + tl.Serial( + tl.Fn("Aggregate", _conv, n_out=1), + tl.Dense(out_dim), + activation(), + ), + tl.Select([1]), # Pass adjacency unchanged. 
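+            # Note: tl.Branch gives each sub-layer its own copy of the
+            # incoming (node_features, adjacency) pair; the Serial path
+            # consumes both inputs and emits updated node features, while
+            # tl.Select([1]) re-emits the adjacency. The layer therefore
+            # maps (features, adjacency) -> (new_features, adjacency),
+            # which is what lets GraphConv layers be chained.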
+ ) + ) + + +def GraphConvNet(hidden_sizes=(16, 2), activation=tl.Relu): + """Baseline graph neural network built from :func:`GraphConv` layers.""" + layers = [] + for size in hidden_sizes[:-1]: + layers.append(GraphConv(size, activation=activation)) + layers.append(GraphConv(hidden_sizes[-1], activation=tl.Serial)) + return tl.Serial(*layers) + + +def GraphAttentionConv(out_dim, num_heads=1, activation=tl.Relu): + """Graph convolution with attention akin to GAT.""" + + def _attention(q, k, v, a): + q = q.reshape((q.shape[0], q.shape[1], num_heads, out_dim)) + k = k.reshape((k.shape[0], k.shape[1], num_heads, out_dim)) + v = v.reshape((v.shape[0], v.shape[1], num_heads, out_dim)) + logits = jnp.einsum("bnhd,bmhd->bhnm", q, k) / jnp.sqrt(out_dim) + mask = (a > 0).astype(jnp.float32) + logits = logits - 1e9 * (1.0 - mask[:, None, :, :]) + attn = nn.softmax(logits, axis=-1) + out = jnp.einsum("bhnm,bmhd->bnhd", attn, v) + out = out.reshape((out.shape[0], out.shape[1], num_heads * out_dim)) + return out + + return tl.Serial( + tl.Branch( + tl.Serial( + tl.Select([0, 0, 0, 1]), + tl.Parallel( + tl.Dense(out_dim * num_heads), + tl.Dense(out_dim * num_heads), + tl.Dense(out_dim * num_heads), + None, + ), + tl.Fn("GAT", _attention, n_out=1), + tl.Dense(out_dim), + activation(), + ), + tl.Select([1]), + ) + ) + + +def GraphAttentionNet(hidden_sizes=(16, 2), activation=tl.Relu, num_heads=1): + """Stack of :func:`GraphAttentionConv` layers for small graphs.""" + layers = [] + for size in hidden_sizes[:-1]: + layers.append( + GraphAttentionConv(size, num_heads=num_heads, activation=activation) + ) + layers.append( + GraphAttentionConv(hidden_sizes[-1], num_heads=num_heads, activation=tl.Serial) + ) + return tl.Serial(*layers) + + +def GraphEdgeConv(node_out_dim, edge_out_dim, activation=tl.Relu, add_self_loops=True): + """Graph layer updating both node and edge features.""" + + def _prep(nodes, edges, adj): + adj_norm = normalize_adjacency(adj, add_self_loops=add_self_loops) + n_i = nodes[:, :, None, :] + n_j = nodes[:, None, :, :] + n_i = jnp.broadcast_to(n_i, edges.shape[:-1] + (nodes.shape[-1],)) + n_j = jnp.broadcast_to(n_j, edges.shape[:-1] + (nodes.shape[-1],)) + msg = jnp.concatenate([n_i, n_j, edges], axis=-1) + agg = jnp.einsum("bij,bijd->bid", adj_norm, msg) + node_in = jnp.concatenate([nodes, agg], axis=-1) + return node_in, msg, adj + + return tl.Serial( + tl.Fn("Prepare", _prep, n_out=3), + tl.Parallel(tl.Dense(node_out_dim), tl.Dense(edge_out_dim), None), + tl.Parallel(activation(), activation(), None), + ) + + +def GraphEdgeNet( + node_sizes=(16, 2), edge_sizes=(4, 2), activation=tl.Relu, add_self_loops=True +): + """Stack of :func:`GraphEdgeConv` layers with edge updates.""" + if len(node_sizes) != len(edge_sizes): + raise ValueError("node_sizes and edge_sizes must match length") + layers = [] + for n_size, e_size in zip(node_sizes[:-1], edge_sizes[:-1]): + layers.append( + GraphEdgeConv( + n_size, + e_size, + activation=activation, + add_self_loops=add_self_loops, + ) + ) + layers.append( + GraphEdgeConv( + node_sizes[-1], + edge_sizes[-1], + activation=tl.Serial, + add_self_loops=add_self_loops, + ) + ) + return tl.Serial(*layers) diff --git a/trax/models/reformer/image_generation.ipynb b/trax/models/reformer/image_generation.ipynb deleted file mode 100644 index 626a99cae..000000000 --- a/trax/models/reformer/image_generation.ipynb +++ /dev/null @@ -1,414 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Reformer: Image Generation", - 
"provenance": [], - "collapsed_sections": [ - "udDs_biH0n5U" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "udDs_biH0n5U", - "colab_type": "text" - }, - "source": [ - "#### Copyright 2020 Google LLC." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WPY-OyyM0pSs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "psnUF-8c02o_", - "colab_type": "text" - }, - "source": [ - "# Reformer: Image Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1lnRd_IoERdk", - "colab_type": "text" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8PluCmWbZIpJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install JAX. 
This custom build raises the TPU timeout threshold, because the\n", - "# default limit of 2 minutes is too short for sampling very long sequences.\n", - "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", - "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", - "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", - "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", - "\n", - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yiPdBenoZwH6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", - "\n", - "from tensorflow.compat.v1.io.gfile import GFile\n", - "import gin\n", - "import os\n", - "import jax\n", - "import trax\n", - "from trax.models.beam_search import Search\n", - "from trax.supervised import inputs\n", - "\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "from scipy.special import softmax" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yyxRk75iaAap", - "colab_type": "code", - "colab": {} - }, - "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Load example data and model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qBvuw2h85WXE", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Normally we train on the full imagenet64 training set, which is quite large so\n", - "# we won't be loading it from this notebook. 
Instead, let's just load a few PNG\n",
- "# images to use in our data pipeline.\n",
- "DATA = []\n",
- "for i in range(8):\n",
- "  img = plt.imread(GFile('gs://trax-ml/reformer/img{}.png'.format(i), 'rb'))\n",
- "  # Convert from RGBA floating-point to RGB integer representation.\n",
- "  img = np.asarray(img[:, :, :3] * 255, dtype=np.int32)\n",
- "  DATA.append(img)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "oBZh0Q2UEiaB",
- "colab_type": "code",
- "outputId": "d5adcac0-6f76-4c56-e6ef-74becaca87be",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 130
- }
- },
- "source": [
- "# We can examine one of the images to make sure we've loaded it correctly.\n",
- "plt.figure(figsize=(1.5, 1.5))\n",
- "plt.axis('off')\n",
- "plt.imshow(DATA[0])"
- ],
- "execution_count": 0,
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "tags": []
- },
- "execution_count": 5
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "<base64 PNG data elided>",
- "text/plain": [
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "VXjtCPxl3I82", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# We'll be using a pre-trained 12-layer Reformer model.\n", - "# First, load the config (which sets all needed hyperparameters).\n", - "!gsutil cp gs://trax-ml/reformer/imgnet64/config.gin ./config.gin\n", - "gin.parse_config_file('./config.gin')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "NhiTshPPbvLY", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Now we construct a ReformerLM instance and load the pre-trained weights.\n", - "# The 'predict' mode configures the model to accept single tokens at a time,\n", - "# instead of feeding in a complete image all at once.\n", - "model_infer = trax.models.ReformerLM(mode='predict')\n", - "model_infer.init_from_file(\n", - " 'gs://trax-ml/reformer/imgnet64/model.pkl', weights_only=True)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zY3hpgnI5Rgn", - "colab_type": "text" - }, - "source": [ - "## Sample from the model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PnzRPCzFqIVi", - "colab_type": "text" - }, - "source": [ - "Now we're ready to sample from the pre-trained Reformer model. Unlike during training, sampling processes the images one pixel and channel value at a time. The TPU colab runtime has 8 cores so we can sample 8 images in parallel." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "W9ZetV91PujO", - "colab_type": "code", - "colab": {} - }, - "source": [ - "sampling_decoder = Search(\n", - " trax.models.ReformerLM,\n", - " model_infer.weights,\n", - " temperature=1.0,\n", - " max_decode_len=32*64*3,\n", - " )" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HOLawc5dB7QV", - "colab_type": "text" - }, - "source": [ - "Sampling is an inherently serial process and will take up to 9 minutes to run. A good chunk of that time will be spent on JIT-compiling the code, though, so the code cell below will finish faster when re-run for a second time." 
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "We9Jj9Rap3cB",
- "colab_type": "code",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 214
- },
- "outputId": "10b6142b-11f1-414d-9b63-353f721a6a82"
- },
- "source": [
- "flat_prompt = []\n",
- "for i, img in enumerate(DATA[:trax.fastmath.device_count()]):\n",
- "  img = img.reshape((-1, 64, 3))[:32, :, :]\n",
- "  flat_prompt.append(img.reshape((-1,)))\n",
- "prompt = np.stack(flat_prompt, 0)\n",
- "\n",
- "print(\"Prompt:\")\n",
- "plt.figure(figsize=(10, 10*8))\n",
- "for i in range(prompt.shape[0]):\n",
- "  plt.subplot(1, 8, i+1)\n",
- "  plt.axis('off')\n",
- "  plt.imshow(prompt[i].reshape((-1, 64, 3)), aspect='equal')\n",
- "plt.show()\n",
- "\n",
- "seqs, scores = sampling_decoder.decode(targets_prefix=prompt, batch_size=8)\n",
- "\n",
- "print(\"Sampled completions:\")\n",
- "plt.figure(figsize=(10, 10*8))\n",
- "for i in range(prompt.shape[0]):\n",
- "  plt.subplot(1, 8, i+1)\n",
- "  plt.axis('off')\n",
- "  plt.imshow(seqs[i, -1].reshape((-1, 64, 3)), aspect='equal')\n",
- "\n",
- "plt.figure(figsize=(10, 10*8))\n",
- "for i in range(prompt.shape[0]):\n",
- "  plt.subplot(1, 8, i+1)\n",
- "  plt.axis('off')\n",
- "  img = jnp.concatenate([prompt[i], seqs[i, -1]], -1)\n",
- "  plt.imshow(img.reshape((-1, 64, 3)), aspect='equal')"
- ],
- "execution_count": 12,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Prompt:\n"
- ],
- "name": "stdout"
- },
- {
- "output_type": "display_data",
- "data": {
- "image/png": "<base64 PNG data elided>
[elided: ~19 lines of base64-encoded PNG data from embedded Jupyter notebook outputs in this diff hunk. The only recoverable content is the notebook cell structure: a matplotlib figure output (`image/png` with a `text/plain` fallback), a stdout stream cell printing "Sampled completions:", and the opening of a second `display_data` figure output whose base64 payload continues beyond this hunk.]
Z7FYMJ1OWSwWrNdrdnd3mU6nbJoJHR0fRlHfekWW\n5VxfX/P8+XP6/T7jrS3eePgw6oCMwTjL9fWUdbXm4PAgltPyDOc8TWPI0oKD/UPWq4q8t01vOOK9\nDz5ApYKj40OyXNA0t3SSAIRusQ3AX30mN3u8CNAtzpt1fLFY0StGvPfOAf/hL78TSzMBtPPUOgr5\nlZSxi7oQGBcFviRQ9odsbcfwyLKfc3Cwi1LRDQICbeJi3C8yEHD26hqIDieAxsYDBEpivCNTrzeX\nzYYiNkFaX3xbt0w7jSfnL8TtbzbmjVYldsp87XzpdEbawZWHV16wdvH7BVEDnkliGi7QeFjaQGsg\nTwSViRTWYb+kUAJJA97hQhLLHNJTDg5jxEU+Jk16nbZFQdAcHX4XVI/tyZsYfcZqcUmrK7K0T54f\n0Mv7aHHIew8NptH0x0NsSEFJesXtbenWhhshsvUB5zsxtt+wOVHjIzZjttGzhE3pamOC22hGvjj4\nAReihd06iZTwf/3OP+XsbIrWhsaD1hbvHUmSsf6LP8HJTnDp4s4dulBNCch+huz3EPu70C6I0d0S\nZBobNmcZIQR0W2FFj1nleDWdMZ0uEViEul1OkejcfBsrfvw7HiqCDzgLQQe8DTcg0bvIvAgRDwu1\nsQREFzsQ4yFwG3AZEEhMNwd9iG9JSl5bF70gOHEDNm/m7V9idjZMZrixqEN8HQSHCIFcSer1kh6R\nh2tah8pSVKIQSvDq0cmXjsNXAh7bHzCoZsjJDrLM2PNrXJLTXM0xTUXrFZUfkVQXSL9g4M7R/TFS\nKQrzAy5PnnFYtqSzM+5kmsG4j0wzDvfukfZ2kNkA5mvSqoZigiwz5skIFzzBGco+BD3l0hTsDgKN\nGOFlhuj1WYcU7TJEOYJ8zEh5RldXmLNLhIoqqC/E7OCUwKkNsOlq6BtqE9GZy//mlwoBEzwy+JtA\nwFQKrAgY72i9J8OTSk8uEgrho6Auwi6cC2gfotCto6ad0zgCZQjUwZF5GV1ezsXFy3mMd6Q3AWIO\nI6KYy/n4INsQOwGnIS6AInhcAOM9XgRciIK6ICRJl02kiAhetga9XFI4B9ogvac1FmFu77bZ2uoz\n2BozHI0oxwV7O2O++/4DnlxcYVygrT39NIlx7SLw6fka4wLf+3zKTy+WrFXGJ4/nnJ03fOuNfU6X\nK2QSeO/uLmmesTfM2BaOV1cLVBop1YuziuGwZHs7xYYW4VOUEJQjhQ2KpetxfL/HpKxIK4NAUqox\nGZIffv//Qa8arNN4o2l0XDCraonwNUlW8PnLUzIMpZJcXS0YpYHhKIvHoFuNTQQow0HsdC6FQIgE\nJWM68t07d+j3e6RpRn84ZP/wEGMDz56d4JxnvV5jjCHLMtI05fT0FePxmCzLODk5YW9vj6urK169\neoVSCU+ePOHq6prZdM719ZTt7R2EFCRJitaaEALLxYJXr14xn885efESgaAsS7I8o6oqjI2hlufn\n57z77gfM5yuW8zWrVcP5VcXu/h2sN1hfMd7qs1hekme3L4WCoOhBkf/153HTJsnZqEVTgmjxJgIa\nKSXrynJ+folz8SQvBdQaqnUbGZAQN8E0lViAABfnF3z608hm/PDPfswf/b/f4+Mf/5hPf/oT/vzP\nfkRtIVjPso6Na+vVDb8ORNNA/FXx35zgJroqtpbo9Bqd1iZsdtbbDs+m7Xp4vUH4zSax0T188XNE\n1zE8RLeWDbwKgrmPTjchIuhJRMwMGyrBKBFk6nXgG8C9UUqOIRUpQUR2QvsUXbfd4VLgXEWiBhAU\nQeQ0RrNaPUEJyenFY0CQl9uMBnt4kbIzPiBLeyzWFb/yq2/x7OVnXJ0+Y5RahtKCuX1bEtEx+f5m\nHOLnzkWg6B03QDmE2K9Nbu5Jp9mJPx+63/F6PP0GJBHXVeclzvw5//O/+Mc8fX5KpaNLdzwo8PNL\n9OkpiBTXGrwx0bHlPd7FTVumCrHVjwdOY0FG/YlvKpx1YA0hJAwnE0bDHkc7Q4a9HuNBgXUB095u\nfMINmiDqlIMHfqZZAAAgAElEQVS4eZ4iewLOCrwVOBM/9y72sgvEJa7RFilidIq1vmPDo7Vqk33U\n+AgUA5uDS/hL8xU27rjXZarwBcC5GXPvYmnSuw6QOfDWYkPgsjEs5xW7dLqsEFiuWozRWOOYXn85\n+/WVR9PEe4zPEa1kK9XM7IT+AFaLFf9fe2cOY9mRpecvIu767tsy82Vm7WSR7G72Oj0aoTE9gkae\nLNkCZAiQMDJly5YlT7YMWXLkyhEgDATMoCFDmhn1MtMryWazKllVub3Mt7+7xCYj7n1ZZJOcTnpq\n5AESWZXLy3fjxo3445z//0+W9ujdn6BVRloMSPMcnVfE2RC1+Bir/pBjfcH+cMIsH6J9wqpeMTjc\nY2YLfLQim1+xQTLwFT5OcCIi0xafKMT2OZV+kyQXCHNJWR+hBhnm5Bl1b5+0OSPKPVodIOKMjavQ\n5QIhQmpQtMSsgOI9kRXomyJXuNGdQ6S/PYfHOE9sg5cPOCIZTnxKBBKX8FBah3A2HLBbByzhHMoZ\nrI8DilXB/8JYSyYlwjtCwlAS4dDeE/lQioukak8sDiOCwMwS5JOO0HYC4ZEu9NuRsltXg6LDuKBq\nkSiMCP23HJIah7Qaaw1isaEebHEqIbKB7xTsMm8XRwdH/OY3J0SpYDpbMpttGPVS8lzx/HLNHz4Y\nUxjN/b2U80XJ6WzLvKwZZBF1aYkiwfFeAd7ys+dTjvZ7/PmPT/mz779BXRuuNg2DImF7uWIw7DMs\nQNiEsoF+njM+yrm+XlMZSyE8qIr1NsaZCF/U9I4kH/38kiLqE6uYPNU03pNJIFWkiefjizWpinES\ndOlQUYxwgahXa4tznvmiRg7HtxscL+n3C4SXaNOEViHekecFqdtydnbG/v6Ys7Mz6rpmsVjibPAc\nquuGKI53Ke3Oh0dKSVVVHB0dcXV1xdOnT/nVr37FYDBkOBxSbreMhj0eP37Mh7/5Ddvtltl8jooz\nmrrm+PgQYwxXV1N6RUGepVgL66rBOUuaFvTygrSX8+GHH4GPePLmO5y9uiAfHqKyAYgSFW05+fhv\nOTocslndTr3WxWoeTm5d7Mik7TSUsjVw84AFqyBKIlRVEcWSvBeRZYpmHUpYAEophLVEKpBInbE4\nGQDR4dERaZYA/5c//v4fkfR7fPVrX6NXDBDW8MO/+F9467mcLVuZ+SfZRbJda6z1of+iaxMHHRH2\ntY0F1/HzOmLo7x6esEG717gWu9d9/Yd8i42g/SMdQJOstaeSMJKhLU+xcyK+aayZScEoDgeotYEn\nkz6xCHwnSQkuxdsty3XNPVNjRUUapxi7wgqJb5bEcRZUjmKGiipEPMGYBTio6iW4hihOiKKC737r\nmH/7g//KX7aXEEcJeVHwH//Tf77dAH2qZOOdD/xI2WZX2i7qOwJ5u/kHvkoYMNEByTZjJgShy31b\nXwm
4R+AaT6RA+J/zg7/+DxD/e/aHbzAyjuriFUhFow1GV0jnAvjUGunasY4jeOcN/HwGaRbeum5C\n2wknsLoiGubko32yNMOYmvGwoDGaB+MRLz5e3W5odvNCtNLw8EWPbDk07f7FzXUigiGhpfW9cqGF\nkRXB+LDRpi0XOpwL1ZPG3GRtnISuR5F3hPYSrS9S4LfRytJ96xXQgsuOQL1TRIb7IVpK0ONezkc4\nFMF2QdcWOYxoSk1vLIm/IOP+hYCnt7mgzB9QvXpGEjUkB2+x1bB/fJ9q/orh4QGrRpF6iynn5PGA\nl1fn7NUrXNPA0VM2xuOufsEiOmYgPGex5qj5CN/fw0rJIE+5rqb0kERx4FMoVyPSx8R7PcxyRS+O\nMXbJ9tlzhpMjit4B9eIZZZSQrF+gYk+9jrhal6RnJ8j9NNw26RFOYIVESIcQEmGDjZGTYaEKGTPR\n9l353cM5h5Wmtdh2NEaQqojSWJwNJawMSSNEILMhcFKQpBmT0ZB6u8XJkC73KnBrgqOkIzCKAkx2\ngJMSrKHxoU+WEMFUy7YlLIMLaquW1ehdIC93c8kKggpLhKnsRSv0cwInDEooMBbblJSbNVxP6fWH\nmKAtCE3kbhl+vmG51TzOhrz5tXu899ErTl7NSLxgr8gY9hJELUilIEcwHuQMUsV8U3FdNmRxzCBW\nnC5LDoY9/vidY64WS/7bj0/ZG0cMJilnqy3f+eZX6EnN2XJK1EvJnGC2XrNYzYAIZz2lSoijGCEa\nZCRwNubsrKR/HNOsNqyuFXhDESVczrYYC3vDgsN+TpZHlKVmvt1wOBiRRBpHjPKSJI1ZG8k7e7cj\nWB72Mz786CMev/lVlssFcZqyXC5pnGdvb8TZ+SlFkZEkCbPZjPv372GNYD6fs1yuONwfB8fssiSO\nY4bDIZvNBmMMe3t7DAZ9Xrx40RJbLU1dMxgc4b1nvV7xne98m0obLi6vmM8XHB4eEUUxh5MJ+3tj\nTk9PKXoFF9NZ2LSlZLPdgFdEaY80HZGmBVUJ+5NHfPA3f81+P6ZJNoyGKVlqKTdrvoS4D4Ej60uc\nFtStff6OadC+njbBzbfW7BbPzXzJRsOLV3AxC9kda9mZFfbSmMVmG4zhXGiP0E8lSSI5e/kKorAU\nPn/+krRXcDAYEcUxJ8+et414PV//2kOmrzynH84/ATJctyCLAKCi1siu61jeNRAVkiBPt6/tzbcI\n59yOjN3F6+8jgKyb/3Qwp/WzDXxFwt5yiedSQBLBN9KWACBAekiVp4gEeSxg4bjeWkqjcInBGpCq\ngUbz8sWSR9trNME7ylHTmC3Hk29R1hLjIvrFkNVqynp+ikwKirggiwpqVyHFMUl8xf6jJ7z7Rwf8\n8IdXIATGala39JkJF9Z6uxDWOteVSVrXZdEq5KQMjTSVCuoq15Fsvce27Q92GR0PkQw/61SwS9BN\n+Du1FqHEdfnn/Pf/8ef8xQ++xz/7k3/H93/5l+joXjCdrSym0SRxDJstqj8gynKKf/o9GBaB0JwV\neCK88ki9DW86ycFDXVeM7t8n/iDD+hohI6qyYr2+HeDpUo67DuTt9dEeHLR1mErio/BN5UTgG/uw\nj2hr2XiNEpKrdYWpPVaLXSbIWYnbCGrrsHXoV2YjiQqE2TCm5kYJJwjPlNOBd+bMDbBxbWmySwo5\n2u8nAZDW1lLXdRD8GIsR4DSIWKErzfjg81vafOHRPRscYOorxuMerijYOEOeZERGMMj32DQ19fKU\nsi7ZxGNeXF6iBBwPYlSWwmZJNMjY7n+N/t6EOK4Ybl7hlCTeLIL5XjMlywpsNMaVK7xe0rdL8s05\n5WlNGqfkyRjnItK9Q8q6YbGaoZ/8Ab2DN0gO3qR4+A9ZOU/y7ARl2g1dgHCKTmLoRTBJEiJMVrhp\nO4F3re/A7x7WWxQShcO70KOqthbrHco6GgSRCEaAuQBhPU4bBIZG12xaYGKcCwThcAYIKNlbEgJ/\nRwqPsSGVGON3lviyPa147wLfB79zZA5lqu6E191kt1v0jDVBlhmwO03rQSq1xtcNotLUdYU2DVrX\neHP7k/r5csv+0ZDptuRnP/mQPo48jXk8Loil5Ocv59TOURrNlbPkWcLDyYC3jwc02tLgWVUNX3uw\nT5Ipfvnskrp2PBjF5ElC5Wr0KCZLPWeLFddTzXpVYmwJtiFPU7x1DPsxHk1VO4SPcdZRl4ZYSFIk\ncZTQm6TEqeR6WTHa67E3HhApKBuNd5ZGa9I4Zb4quaoiNqahJx2rTYUSnnJ9O4Ll+cUV+0f3mF5d\nIpVivliwLUuu5teslzMePnrI4yeP2Ww2fO2rXyHv9Vitl/R6GdvthnK7YbVacXx8zGq14vr6mtFo\nxIMHD8jznKqq8Z5QEmvdkwN/Nnjs/OhHPwYheeedt0mShHJbBpn7conWmgcPHlBWVQA7IoB0FUU0\nRmOsZ7GseX5yznyx5b33PsBVS/qppZ9H7I/3cUYSqxz3JTKDRIKmhKaBT8OCHam0baHyejbDWMNw\nkPDGkwO+8nRCEon2+Qi/U9U6lDcIi7sk8Ia194wP9nn06AEADx8f8/Sdxzx68oivfvWrfP1bXwcB\nWlvqxqI+wx9GyGBwsUvZdGen7kH07R9rU/ti93DeLjpiaPfSvq23vP5599GWGHYn5LaEILrsRfte\nGi04bRu1ClrQIyCWUCjYL+BHJzMqF8bKE9YfJeCNp5PQMy0pEHiSKEZFQzbbBVV1FqwJ6jW9Yh/t\nFXl2gIzH9Hr76EaH8ojesKkF//pf/Gm4/RGoqCOE3y5cR1Jur60rR4W2ES1nxbcTprsF3e7fKog6\n8ncYs278ROt7JogVxInosALCw6aWVCVsFn/Nrz/8AddXK2rjgjrJmKCAM5ZExERRHNbmcR9fN8g4\nbRuLa/AGohRMg3cWJyRxOqReLRjvH1CkSXAZl57jR/duOTqvI2Cxu2xoidk+0Cxcy8nxLkjUnQ1z\npzaOsi3dW2txjaU9n+NaLo61Htv+23Vy847L5ggcHk/bu+zmo5vXHdHZOdG+Tlue7F7LhV1MO89g\nb0xFsF/p5r8SHhVLyurzOadfOK1kVfNANcQ+QcqcB0WEd7ClptQ15caS1zVxr2C9WsCwx0A6PuaI\n2bZi7VIq1Sf3Jb5aY4tH1HtfIUkTmsGY5XSOTA/wzRa/eUkvUuSiB/sPaUZDjvg1pGMsMdZbyAYM\nc0UalejnH7JuUiqZMRU56tFj4hcvUIlAydBYE0G7ELVgBx/cN9s7Ebx5ZJue/u2F7AunjxNUxgRk\n7CzGORwW7wwuiii8ZWks2jqstzjpkVHwY+gVPXzLOA/ydkHSnkQkDuk81jki4Ul9yMYkxlGKUPpQ\nwmK9JcUSdyuD8GgczodSS+U8ot0UhGPn6iy9bxuR+iDbdJ7Eh5Odcx5XN0G91ZjwYVp6/S2jX6R8\n/fiIP/3OY7733cc82O8xGWVstWd/lFKMM96/WPBXH1wym254
/+SCHz9fYKzk7YMel+cznjyY8ORg\nwCSJEJGisZ4oVVgL9yZjnuwPWJWOJRtc3uDWjuuZpawiXr5oyNOENEkQTmG0pXGeyjRUtUVIwXxR\nUaQFOZLZwpKnMbX2JCrYAGjrWS8bhoOMnnMcHQ4ZxZrYKl4t1tRVw6QYcXZ1u9Nov+iFurh3zJYL\naqMRkUJGEfPlHO8d77//ASpSzOZz3n/vPcpyy+HhIb084969ezx58oTVasVkMmEymSCEYDabYW0w\nEVwul/R6PZbLJefn53jvGY2G5FnG/fv32Gy3vHz5ktFwiJQCbTTb7ZbNdst6vW4dmUuUksRRRN1o\nmkZzcvIx8/kcrTUff/yC4WDI8fGIk+e/Zr1acX52RZr2SNM+St1Org+A86iI0PH6U9FxDkS7YXVn\nFGchTSSHBwPuHY3oJVkoG3kwVjDMIUpT4kiFBobhxIDWnohgFpckgQS6WFRs1iXWGrQO2THnoKoM\nVRUyhO27uHlj3oUyhQoL+o6L3BkH4UP5vP36bjO/Nejx7Sm4XcPaDWy3pH3WR0uU7T524Kgt2Xjn\nuTTwIgjudu14IGwgvVix0A2zVYM1Fd5YgoQCvvH1+2yrC5J8xLq6wtgmlPbchkH/MXGckiQD9gb3\nSOIBziqECmtULx1T19dkSZ+yrPiT778d7mVLzOaW6zHctIdwzt2AunaDdU60/ned1w67Pk3Oth2+\nO7VQe/2+LUcGQNKOswygLEk8KiYoSAVsS0ndwIuf/hfWySQcspt6VyITTY1QMgDK+0PEYAhZD5Ks\nBTseGYd2Lw4LaR+BQ6QSW9comTKfT0myHuvVmji6nVACbsajUwd2VDLvOxVbB3Jk4PPYwOWRUtAY\nS6WDglU72wJLj+kyPG2fMk/gjHkbvI66FI1v+2t1nBzXtpu4Aedt/yz8LhtkO+J5m3ELwMfSU5LG\ne06AygafNWODItpq+4WtH78YR6cp2+yArashyWhMhHcVtvbM1ltiVSLvP2a9bcgzQTp9xUrHTBJN\nlMaMfv4/mb44QZZzuHqGbirs1Xvo/IAMh7QzembDSMYUaY+Lyduo4/uopkblD5APv0MkNY0y5MrT\nMyt0sYdXA6KsoTn/CV4vUKc/Q5YLxE//DhlFrTnZjbNmR0gLh7MAcQSiBUVtku+W+XfTUs3jzu/B\nW4wFhSR2Gm09qXdYZzBe4DWtE6dHZAm4Vj3mw+ZaiVA/td4FPx8haQjKiQiPVhB5R0zosSSFxIrg\nYGRbcpdsT3fmNRWady74HHlPQgA+ljARjfekIii2jLNBrm41Zl2y2WwRukLbBqFvy3CC1abhVy9f\n4pIe/YMR/fuHXGy3vPdqSrWu2S5K3j3a47tvHHK432OYJ0hXc7IqGQwK/tE3H3N+teTkes3RaIi3\njtJY1pVlqjWnesOmJ9HJkr1Bwl6/z+i4TzGISWTKYKRIC0WjG+IkwhqJ14EnNRpmVBbyPGZdr5hu\ntvT7Cuk9faXQ2uC8Ylyk5IOUdQX7x4eMe5KPztf0Rwn9oo+1FR++OiNYht5mbEqkVIzHe8RJwmg8\nZnI4Ybw35pvf/g6Xl5f0ej3KssQ5x3A4ZH9/n5OPT3j8+DEgOD8/YzAYYFtCuzGGPM+ZTqecnp7y\n9OlTkiRhb2+PouhhTFBdxXHgjvWynEcPH7Fer4njmCSOqaqK87NzmqYhSRIePX7M9OqK9WpFvyiw\n1qK1ppdljEcjrNFMpxccjMcMih5NbTg6vI8nRhBTlbcHyqiwMRnz23Oua8XgCMT7DgA5H1pSXE7X\nPHt2wfnZjG2lcQQJdpYEg7KyttSdR5f1LTcDZtNrPvrwGQCzq2tmV0ENc352ysmzk1CecjDqZ+Rp\nd6/9628MpYIIQMpuEb/JMkDnxeNveAuvnX5/5+gyF7+laLnFS+3AUJct8xjtmRp41YjdVQkgEdBT\njqMEfvqqRGuNcR4nIqyF5mTBd0cDrK4ZZhMGvQmJKhAMUQp0U3J6/WvOph8yGh4hlWCzuqaslwjh\nSKIcIRUqGTE+GPHwUbHL3Hn/Oe//CyIcXEPWoCPR0krPOyAVOCcCZ9qms+6mJGm9D9fXZsREV3rx\nfocdRTg/E8eCOAUZudb5RLBYQ09+hTgbgANb1eE2W4NwDiEF0lnibzwCGSOcw2/XgTIQp2A0ZnWN\nUK1CzWqkSFBpjzxLONgfc3H2ijhWlOvPb5/wufe9A3rc4GUIGSzrOkDSgc72a11pyXq07UYZjLbB\n4sEEno/TnbIrjLUNhYuwt7Z2Cfhwdg6E5OCo/gmSvQ8jbl13v3gt6wPtxo7FE7eCG63dbr1YXq1b\nQcDnPw1fCHhWS1g0EhdnZEkoFSW6wixfcv/eAWNpiZWh3hg2l2eIZsmmWXJux4wPjjD/5J+jzBRl\nNc34IbGpSNMj4nqG1hVFcZ9KeLyPsVIzuvgVs+WKRiq82bLxfayPSZsNMs7QIsVePqeqNenem0wG\nI+JygRk9ZiAjVtOr19QFqn0I/K5BqBXBk8OJm6fJO4GVis84UH7x/GlXHd0aFjggtgbhLHUTFFHa\nOhKhsMYipUM6ENogjaYRAWR4OkJzm1Vomfzah0yPFC50UW9JX85bEu8xzoXNrpXtKedDH67u6fU+\nSOIJpxcA4TVVy/dRbauKxvtgh46jdjpI9psG2YRMT2YczZfI8Lx5b0JiPKJp0JXjarbhw5dz3nx0\nSB1JvvnWfYZpRG08S+M4ngzJs5RUJMzKip++mmFc6D2jpCXPUpSSTLcV+4c5w9EA7zXWGeIoZIDW\nqxJvPP39iMlen/OLNVmRYrUnU5ZBP8ZLh7GOSGk2VYNTkCtBHKWIKEZFggcH/eCuXGrWlcYIzS9+\n84ym8oz7Oe+9OAdvGOURvVQyu17eamzW2y11XWOtZW9vj6Zpdv+vq4r1es3JyQlJHKOiiOPj41Cz\nThKqquSnP/07jDFIKanrmul0ilKK6XTK/v7+DgQVRUEcx4xGI46ODsmyjOvrK1arFSfPTzg/vyDL\nMpRSXFxMkVLy7rvv4r2nqkp+8fOfs91sQQimF1M2my0PHjwKi8tyhXOOKIqx2pLlfZrGsy0dL04u\nmF7O6eW9W88bZDhRquh1WnCI7uTmndtt9kC7SDqcs6w2Fca6mzkfBZl1FqvAw9CByeLaRdUamBxO\nODgM/dAmh3s8fuMBx/cOefj4CV//5rth87Jh3vhPLZmCNjPbWvIjQumjbT20Mwa8kea2QEN+qQTP\n7nc+TVjuNo1dZgM+tZm8dsJ/7fUgvCer4bTxvGhEew3hW5GAgww+vCrRLqY2ZrcTzZZLDocZDyWk\n+UOm84+xdsOmWaCtQKohaZQjRUJZrUBAv39EGhfk6aQ1lJMoGUN8zL/5V8EwznF7QvfNmPiWLP7J\n0t3OVbnNpGtL27E+ZBA6ECm5ATqdfH83lu1nKUO/tzj2xIlARR4Rhfl0sliG/k9Vg9QGYTSqapC2\nJUQ7T/b4Md668CJ
R6yhtGrxzRHGOSHPCg6CI8xwlFcvliqZ2DEcTVJKQF7d7ttrOVLtMoG8vZqcs\ndK35X5tpcdYHMNNmapwPZrpBrm9xrWLYm/ZZsiJYAXCTpbMu9DQLcvQwx5xtQU8LfnwLfujKYC70\n49oBL9uSmYN/IwIYxBHLzTY4O/tAmm6MxWiDaTSm+fzWEl8IeJrmCmE1cdqDeJ/YWXySocbHlJdT\nqmiAc5ZYXVGMxkRPvknBhnp2wXI2Q16+JE4G9Mf3mMQW4dYkqUQm++hSU8cFfnCEzCUqHqPimHj5\niroyRFlOZM8ReokcHTJfNyh9TdI/RtqS+fUaV0yoZUE2e8a6rsCGUkVYeIIyaZfCDMWt1pMnZHYC\nx0UQAf6WfAPtglLLOYe2FuUctXdoG8wMpQ5ZFWM0CZ5GOGIZZOJFJIlVoElr64KiopWaRwKylmwa\ni+B0WRH4RlqEJqNWQGeVWDuL96GRqceTEtpFGEJX+NpbhA+GYY0TCBdKKdpZmlD03p1gIkLbAGMM\n2mhc1WCdubXsGuDjkxfEzvLyfMHZ1YJYRDw9ntAYzziPqRrNyXTD//7glMTAtx8ccjRMOB4K+nHC\nu/f3eDTO2QrPj0+v8QJW64ZvPJwQ5Z5NWVFtt1Sl5Wph2G7KQLKznmGSYMWWg72M9cIgnKQ/zEky\nSSQiitQirSCyBfE2QwlBGsHeIKPoRaysIU9ihq3k3FaGWEbMVhVFntCLUiovqFqFwDq+HYdntLfH\ner1hdn3Nq5cv+Qd/+N0d4Lm4uODdd9/l6dOnrNZr+kVB0zRorTm+d4+6rjk8OmJ/f5/tdstyueTR\no0esViu+8Y2vE0URSZJQVRVXV1dcXV3xy1/8gtVqxXq9DgDFWg4ODnDOkmUZy+WCwWDAeDRmPl/s\nMkPj8R5HR4cMBwO01hRFn9Vy1fbsquj3eyyXoadWudWoqEev2OPp02/gnGI2vx0QBPB16Omjm89I\ngbT/tdZ++ogKhFLY4aRgfy8hSlqDOR88U9I0IlKKPFOBryYgVtDLJFIJ+v2ifW2D0cE93VqLbjTW\nC2wD68U2kDBff7+wU4tDeJ7dblMJGQYhAAVCifCAd9ya244Nr/1O+49PAIMOBP09FaHPyp6EsgJM\ntees4/S0wCeL4MXlgum6Ai9p6uBSGA9zOLvmXm9FT09Jkj693pC90RtIqbi4/Dv62QH90RFlvUZ4\nzeXVS2SUMl+dI71DmyqASa350z/59u46vkyGx3c3ozWXdG0WG8ROsm9tS0xuS1y7zAa729LWA0J2\nUMiwae+AAiHDIxVEbZYnigVKCqIEzuwH1NUGtlXgbDa6Fe6HDLzcK5C9HkgZJOlKgnN4GYxBfRzt\nuHN4UGmBNQ1Zr49KU1abFeWm5oOPTm83NogbAChaANjyZ/CB7tDokPXyrYzfdmDQB0NW01ocWOva\nA0BLLjYiGBcaF36/BSq2bXTQtazYlcss0AKaTxoPBrC1K5MZ2r/XyuRNuMnzKuxRMVC2hpll4/FK\nEqUJsfr8vfwLd3mVpozHB8RCBMXO3oD1/ALlIyJvkLZCN4rh+B7IHKoVvcnbDPoJPkkxQjC0axbW\nEQ3fwMoxa7nHxXxOfviQ9bpBWE3lYpSoaRqJGgzwMiIzG66qGJMOOb+cMUxLbDHBqZh4f0LPXXCy\nsLi8h+ntIV98RNQuRh6QrlNggRCSzne560nlnQw4vi1peWFvNYGwEDmLcIZYeIzxRM5DWxpqhCDC\ntUDDIozBGxcclq3b+TIIJW68GVzYsGvnSRCB/4MPUjof5OkJoJ3FWIPz3fdcK8Vn57ysXCAldxbm\ngrb1BUEaH3lBZj3O6dAjqc0MeatxRiMaA42lKSv8FyDmz4u3H004nAxZrbdcTa8RZYl2no3WLLaW\neakRieKNeyPiLGK2XeC9ZL7VPDjqcbFYsawN0jRsF5rpuuT+pCBKPWnhOdjzCGXYbkuqckmSe5IE\n9lPJq+mSQb6PlCll3bDVTXCS1R7dwKvLksnxiPEoQg08w35EZRyrTYnxsFpWZP2UTdmQFxl5HFP0\nEi4WW3oKJqMeWRRxMbPsTybs9W8nS5+t1hjnODg4oCgKfvKTvyWJE5pGc3x8zOXFBZv1mjeePOHk\n5ITT01MePLjPfD4LHAljMMYipWS73TKdTrl//z4/+tGPWa1WHBwc7Jx++/0+WZaFBSKOybKM7bZE\nKkW/P2A+n7O/f4CUiuVqRa+XM5td45yjqkqGwwHb7YZiMMAYy8XlFBVF5HmPPE+RUlD0+xTDISpJ\nmBw+5tXpFU/efIdHj5/eet6IlusiPqPEfINvbvgVYRMSiChms9ah2asLJeGwiYXTYiRDixbXgnsc\nRFJSNo7ZdMbpyzMAnj9/xdnpJR8/P+HlyUf88me/wrbPTlB7/vaSKQinTEQg88o2FSOEuFEPuJBR\n8DbsOB2341Zj033uwJL/5MeuHPXpofOf/NzxNj4dvt3AXtWeuQmbviT4i0368PEcrGvaEoOh1vCT\nH7yC5yc8UhXWbJhe/pI0KhBScbD/JnEyYLM1KCmYraYIpanWJ0SRwDqDVD2sL2mc4lvf+irDvfT2\nSLAN9xhZ1UwAAAUjSURBVPrmGW5HS9K+KTHuwJAL/K4AgF4rdXU/R8c9F+29bJtluq60JZAqNGGO\nU1CxJ4qhVPBqeYXTJqyjpl3bfVh3kzcPYdjHV+XNjRDBjiQIbDpAbEHFCDz9yX3wJUWeYustkfL0\ne7fr3xeux7ccp84rpy03dSW9djyMuyEPGxtSBdb5nZmjdxZtQwVDtyDJttkgZz3WSLQWWCODt5Dx\nbTanuz8tX8qIXcPWbtz9a2A1lB5boLQrkbUtyRoT7EFM+JvbxqLiCNM4+oPPz359MeBxjrWOkEmK\n2syppxt8ccB2dk0zOqCO++jVlNx71nVJ1TK9dXGMJ2IbFWwl9MyWbXnN1q/J9BV721Ocrcn7MbJa\ncTAaEg/v0RsPyOOMYeEox28y0guEsBR+Rd1/BKIhtQ3GpER7b3BfnJG3KLJ6dhomoWBn8hW7YKrn\nCQuR2HlSgFBBoRTtOD236+wsZEC52oTSUmQtjQ1GhJU1OFOFmmzbOVdYT+kM3jqSKKKDV8I4VDv5\nKhd6ETkXskNRJNsJ5oJplXfUzoaMFCFbhAugKN4pCsLq571DtmRt57ufCSee8PuBK4QP91k4G1xo\nnSNyGudCc1O0xX4Gn+Lvi2a+Yr0oKRRU1nFytWQ2nzFAoVTDs8s1e2nE9x7ts9pqau05GKRYLzm7\n3pD0UoaZJJaKw37GW/sFk37GS2+JWq6SijNEbNk7GGJshiVlriXjfsJ6s6ZpSrJMYLXA1LBee+JI\nMBnn2KZCiwisZLBXMM5TojwjSyXaCUzVkAtJ1Ri086QqZdSLWJQWKRQphidHMdeLJWwvbzU2UX/C\nVguWywVNU3NwcIC1ltFoyPn5eehiHimMMRRFwf7+/o6QnCQJSZIwHA5ZLBYc
Hh7y1ltv8ezZMx4+\nfMi9e/faklSFEAJrLYNBH601m82WsqwYjUY0TcX19TVlWXF1dY2UAucc17MZh4dHWGsZj8cM+n3K\nbUnRy0OPrSjm4YMH5FnGcrlkMOiT5ikOx3A45q/+6v9QFEPef/8jfv7TX9x63uyIlJ8x5bp93BnH\nJzgwHqw29LKIzUZzPatuXHJdMJgzhMOENiHz21WYIgmTowlvvv0EgDefPuLBowc8efqEN956hz/4\no2/v/pYQnjj+DDJ1q4AJqqj237bNmgQP1PBzUgQg1yGJLzM2HWj5FNj5RPrHfcbXXvv8Otj5dJmr\nM3l7WQnq9n1HEgYxvH8NtXFY3YRNTmtq2efX769wJ88ZX00ZDd9gU53jnKbcznk5/QDvSmqzppf1\nSbIDBqN3SJIU6zOiqCBSY4yp8DLmz/7lP27H9EuMT6D5YttSlNspfTpD1puN1bSIMWz2olVc3Qxa\n1wncA1K2QmrflW/C60gliOOQ2YnjUKFKYvjh+Qd4bfGN2WVqpFJIKcnefhiaN+sKbw2uafAqDkpY\n3UCU4p0BROiY3tToasto/4gIwfG9h8Rpn6Ojv7nt4LzOc3+t5BkeONe6JduWsGw7kLHLfgUhjfOh\nW7ox4VkyBpyRGAPGORrrW26UwLTs/a6U5doS4utOzrtMTpfV6cCXuWlv4TtOT1toSKTE6AYJNCZk\neLT11JXGaAM8/9xhEP6zoP5d3MVd3MVd3MVd3MXvUXzJvr13cRd3cRd3cRd3cRf//8Qd4LmLu7iL\nu7iLu7iL3/u4Azx3cRd3cRd3cRd38Xsfd4DnLu7iLu7iLu7iLn7v4w7w3MVd3MVd3MVd3MXvfdwB\nnru4i7u4i7u4i7v4vY//B8JgNN16YgdXAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAABJCAYAAAA5f/zBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOy8ebBlV3Xm+Vt7n+GObx5zHqRMDakB\nCQSIQQgwwmAwGLtx4fJcVQ3G7cLdptyFXXa7XIZwea4qV5W7beMJGxuXzWCMoRBIIDRASmhIZSqV\nc77MfJlvvO/OZ9h79R/nvpdiDD1HdHSE430RCr28wzn3rrv23t/+1re2qCpb2MIWtrCFLWxhC/+U\nYf7//gBb2MIWtrCFLWxhC/9fY4vwbGELW9jCFrawhX/y2CI8W9jCFrawhS1s4Z88tgjPFrawhS1s\nYQtb+CePLcKzhS1sYQtb2MIW/skj+HZP/sEvvEM9Hu8NYhTvBeuV3CjiFFXweKx3OAxGhKNHT/DM\ng09QqVpEPV4ECzhAUBQhzwEFY4vrghJoQCaORuZIe7Dvzz5OdTZm/lKL3fumcXOnyemyP3mAdmOF\nkewZoILPE9JMafcy0paweMUTBoZr98ZkrkG3VWJtLWftXIRbjhEr7Ns5is0hU7jvy8/w8tv2MFQJ\n+aUPPSLPN3B3HT+rAYIRwYvBiVA1QobgTIAzjmEsDSOExiIIIoYujlCEKqAi5IAVgwgYEWIRegAC\nUoQJgyFDsRhEPFY9iCFC8WLwqqSAQfAKqoqIkiuIKqmAeI8oeCDxHgS8KhGCB/I8p4Qh9x4rgqpi\nVPFAoJ5790w/79gAPDifqQ1C/uzXf41Dh27h0D2vI+07nBMyL6DgPKhXvIL3gvOKevC+uIYdfH+A\nQARM8bwqg9wDp0VWGSliZowiIogp8s4YRQyIQGQFMYoRxRoQI4O4ewRfXEOv3lMGceA5j6tIkcc6\nuCEAyiu2h887Pq98zcu1sbLCwQP7aawsM1Sr0W01mZ4c4Y4X3Ui32aRWqbJrxx5ec/c9PP7oY5hY\nOTV3gSvLK4RqOHLqWf7m4/cShiOEpYS9B8cYGa0ionhnyHNPs9WkUqqxstqi1+1jJMR7B5JTKofE\npZB6rUI5ipjeViOsjBPYKlnqQRzGZPSThMrQHezeeTfLjWc5c+Y+Ahpk/UUqlR1E8T66vsqNN72a\nbqfN0vyDrCwfYXbna7jr7rfzxhdOoEWwnhd+9jffpquLXZyxLC9f4PSzp4gqltQr8+ebLHXgDz/w\nn7hm3xQ/8p7vh6ROUMkI8iqVekxbLlMeUzqrguZKtytEATinBEaoTUOWKGOTEXbpIBdPrCI+pdFs\nMzE5zN2vfREHrjnIrp0HkbjEV499ht/9rT8h6imv+6FX8t3f9f0sLZ/gT/7sLzn8wCWCkhCUFclg\nwpRpt1NKs0rqPaNThtVFGNkJnTVwbRibgkjA9YW1lnLqsHvesfmRH9yrVxopFsWKodns0O/18E6J\nSkIUWZbzLn90x6u49c7bcLqdLLlEd6jChf4a9d0HCGyMek+WrpH0G7TXlvjEA4+zttbj0SfOQ6j8\n/E+/E2NXGR2uEVcqWDUghlxTPvZ3D/LBP70fCUwxp+EI44g3vPluxobLhMZTCgx4B0ZQbzF4sixD\nFBRBxZPnnrfe9moiE/Jbv/dBhvunaNlhHu+UePbJM5gILiw+/7wB+J33/4QmSZPQWuIoopghhCz1\nXFy4RC9N6HS7tJIOvbQL3hEYQbVJLexSjQMqUZ3YRsVcgBbzg8akudLPHKmCCWNK5QrlaJg4rlGK\nYkQC8syR5zlZ7mi3ezTW1pi7dIFOP6GfDeY4daROafUTBAtiwViiQAhMiLVCYCzWCqG1CIr3IN6A\nFnOSFUNghQcePfG847OwcEZFFbAEgS3mAaBUHceIkPQbiIQEYQSAsRYRAVWMCbHWUsx3V28p8rW3\n//p/A/STDuePPUTz8jzDw0PsvOlOyvXJr7kOgKrHPfVhrLRg58tRDH75As0nH2Z1pUXL1vDjM0zu\nOUA0PAriiEojhKU6YkJACaMK1oSkT32E8q1v/6ax+baEx4lFfbEooIJB8QJGBacehSLlJUC8R8Xj\n8pwgAERxCOIFrGIUPAarHm+Kxc15TxSAd4bMOsLcMeRg5CMfoRH26DVS6rNVehdPMFtR2u1l4s5p\nNJ+nv9rn7NwVli530WQf45Xb2b/3RkbH4fLFeRZPpoTDj7GyepHGlToXLjbYXZlmtZVw/2OXiENl\n22SJFx/aRiW2iA2/XSi+EarkKlgDViD0HowlpCAa6g2rtiAfuXMYYzGSEWPJVekKRFCQjMGiGmHI\n1RMbg6gnw5BSkMZQBcQReiETKSYOhR5KoMUkqqpkQICQqWJVCBBC5wgN9MTgvMMNSKgB2t5TFkHF\n0lfHkBREMEcxA4pq8ZuLDQWJUJfxgz/7Xj75f/8ep3/7P3LPu95LGAmSObwarAFVwXtwqgQe1Bck\nRorkKriGAihGBDf4ngx4iFW/8cKCvBTkUQwYKYiQHRAbaxVjiuusEx2kIEDF6Ci+L3xzsiNSEMji\nFcV71kfBZjA7PUWtUiIMDPv37aGxukwUBwShodvtEoYVdu3cx2tf+2qefPwwvaSF8RFHn36aF770\nTubPXyQOQw5cs5eTJy9TqlosFudSKtUyLjc436NaNXjfJQh71Ict3baSZxnGQL+XEoQBgqMaClmv\nRxA2icKQKI5I+hl5lhMYSxiFVEZGCEb2k+YXaC8dp52v0GxdwTUzwvIYjz32YWoVz9rKZYJ4mp37\nDzIyObrpvPnOF/9LXvWy1/PRe/8f/u2v/ivGZuGtb3oP937xf3Bxocmv/KtfYGK6zL/+2e/n5bd+\nL6nv0G4us7KyQBKuUK96lhehuyrEkVAuQZbC0JRgAyXLheqw0O0m1IIm4zOWSGo0Oy2mpsYIohI2\nuUJw5Th++xtJ+jm5E2KrdFsLHD3+WZ4+MkezWfzqUQi798cICa0zKe22Y2aiTD/tkaaeXtswBaQO\nnEKvAa02tFYVl24uNiqGKBRCMbR7DmMt3nmcUwwRzkEIHFla5tbza3QmQk5Km2beZq3TJnzmKRrN\nJiM1w8OPH8Oqp9P3dFLPpQstXKb0kiap9jl/usfv//EHeeVLDrG41GV6tsbdL7uRa3aPU69FpElK\nHMLsjhluesF1TE+VGS5HQEbmDKKC94Ysz/EIYSnCOY86T6kUkSaOo4vP4pc8tnOZseGIZy4tUR7Z\nx+hYwHIj33TugCVNPBIJVjwixebH+2JTmDsl90ruDN6HpLkUc4SWMOoITEQcxHiJQATnHQEGxCIa\nEBiDkRJxq
Ua5XCculYnjmDiKMMbinCfLPEk/RV1MlgVEcY9eb40QR6oeXIBXh3OO3AkqilhIc4M1\nHmss1kJkDXFUzEiCYLzDYjfmNqebK870O6tI1iOIykipShhVQAyqrvh+YjDGIDK4rhb31YJyoYP7\nFZxGBi/Rb0pyrj4uGLHYwJL1e7hShObJYF4thI5iFRqIAbN3cvnpLzIzERLVx8nGaiztyZnPj5A5\nw4gILutipU6W54MNboC1FvUOQcEYepO3U/4Wcfi2hCdQJRc2WOb6NzZasHRVz/oxPkUcDFmeMRBt\nioWXQikAi+JJvMEYxYol8w5rBI8nUk/egx1/8jHmXIOyq1GWKwzNX2FoeIyh9Ahpw7C0lFIKAg4/\ncpkXH3wf77zndiYnhshshXMXFjh/6RL79+/h8sISe7a/gXPP/A9e/p1v5ezJk3zoz38P+imu52l2\nDGnaJxur0Eo8Q9XseaZOAYtBTUECQ69gDeocKgYxjkAMeE/f2AGp8cSqJChqIPWCGshECVBCFTIc\nI0ZIPaSDpTcWg1cQPIEqDiX0kBiHU8HgsASkWkyoRZopXj0W2SBTZYUATyxCpqB4HAVBCEQAh1Ml\n1eKzJgihKhZINrfRKqCAeHya8OZ3/q988g/+mN999/fzz/+v/8LI5FhBkLXIHDVgVcgH6o3xA9Kx\nIbUMrjeIiWihShUDx2y8RARUrhIdY8CKFP83irXFQDRGMbiBKlTsU4vzqPS5N9v4DBuURovrFzut\n4tOIymb5DmFoGB8dot1s0lpd5ZprdzFS302lUqbbXmN8qMbU1BRf+MJ9XJw7wy033UQQD/E9b3sb\n+w4c4PhTz7Bt1zYmZ3bwO8/8V1xaI0sKFcPlSr+f0OsoeRYVeSDDGCPU60oYWoIgQETIXU6vl7Br\nz276ecLKWpNKPITBgne4LEXiKr1+m7XmFUpDyuTkCGE2RGiGaTSbEDi2bTvA8uo8Z088ShRN8upX\nvJE77ryTSnnzFfOJyUkA7nnFP+OTn/8zqmNT/NVHf5uVJVi5CI8c/iqz18D266/nphfeyi37b+Q7\n7noLl5eOs+vO63jBwVHGAke71yKsCM4PSKkIS+eF+iQEEawtCSvNs6QJ+D7YUonR8SGibIldS/ey\nLYWjJx7g0eOOqKTQgUZrlWMnH+f++08zfwUCK8RlZc+ePVSijCcuXyCMHP1Gwvf96Dv5+4//Ea04\nodcUshzEQ3dB0UCQAKau2VxsYhPQ73XoeMUKpL0M5xxg8F5Rr6Q5zOy+jnO7x3l45WkuL17m+u0H\n+dX/9iHOXYGRClxZgNkxePM9L2FmqsxXnzpPmqUkWUa9VqO5coVdO7fx5JGcJ488Tq0GUQSf/LvH\nmBmDG/aMMbV9O2NjdcLhErXKCGEQ0UuzYgOBIwoMQegpl0JUDb1+ghGLsQHGGMqliMutFZoX5hmS\nZUIpEQXKdK3HLa/exhcPNzadO6qQOUeQB/gArBbKuyI4H4B4rCkRWKGXD5QTKVRxl3ucC3EuRiVE\nB+qGQ3DeIhJhTYQJYgJTJg6r1Ct1bBxjEbwqgYXAgDUR3gX0UyEOqgQ2QfMEY6GLR1PFqyX3isdg\nPFgDuTFAcZ3MQqaeyBYKu5VizrbGIFoICJuBSzr0zh7BZwml0VmGd11PaWgSXIqaMvgU7zPEGASP\nDtYQEKwN0bAgW8U8KRtER5Xn/H11szj4RUCEKCoRGGWYHnFvDg0c3uXgEvAJ4nsQjdBZWyUcn2Vo\n9gAAQVoilpSJimXNhahPUZdiggjaqyRiUfWUrCX0SaHaRxXK2v6Wcfi2hCdnfYctG3tfAbz4AdUb\naADFty4Wh9ThpShhqR+sDHgQTxG/YgUxogS2WKEEQ3rec+D3/zNnAs+QGmbr89iVJxkmp3X8FFpt\nMGI98dCjhOnr+Pl3/SrbJ3dz7uwX+blf/ATv/ol3sDSfMjE6Tm1kBHWesakxDux9JxdPPMOLX/sq\nPv+pPyIjJutktJKUTqfDpYtNjFh6WWdzGeSVgcqINw7ri2CmooXkjCPBEqsr1BIRVA2BKN5BaqFM\nQT4iKWIVidLHor4YARZAcwwBKR6DECt0jTKM0LNQ8wZRV+gNoqgaMlVCNQjFDiURpTEgExaPUaFP\nUZI0CLn35APlwgnUEYx6VAzqPTGbHF2wzoABJe+nfNe/+GGaBPzNv/+X/OQffJSk74tJenBfo4MY\nAeplwG+K96+rPIWwozhVrF69vgJmg+ys/yeDstWA/IhiDBhxhaIj3/AxB+SneGB98D73m+tzH9Dn\nUKNNHt6ZJX1KUUhohRe+6DacJpRKhnq9ytTYGKNDZbK8zdLSApOTMywtNakOweJai+OnTnH5/DzT\nOyZ52UtfxBNPvJInn3qC1mqHTC2tZsrl+WWaa13q9SHEFASrXq8SxZZSOcKYkMBEeCMEPqVqStRq\no6ysLtJLQ0Q8vaxL5voMx7NMTY9SkjamvUbg15iYqFArjRNFAc0eROUa43YnF06fYXLkWg5ddzvT\nY+XnMNbnj0ZzBYCeW+Pmg29ifHqUi5evEKTLXA6e4QsPf4J+7Rl2T13P5PAwQVAC4NzcMXaOV7hm\n/508+MVPMjE+RDNt4hGiEgShZ3JPIQMunlMa84o1QliCTqaUfEiULHGbf5Zdo9Bzhmtqy3xHBOdn\nDMl5pddMONXp0FqD0AqBCJcXlItn2txzz+v4yt9/mCiK6Kx6vvqVvyPJ+pTHQAKInWVtydFrCfUx\ny75Dihl2m4tNOyHLC0UyzR1hLPT7HjGGzOWEQUhJYOTuAyyULTO1vYhXHv7KaZaWoQy4Ptx8zSRv\necMtXFpZ5ZGnLnHq5DKdNEcDQ6nX48hjj3LXGyZ425tv4YH7n2DbDMyMx8zumGLPvl3UhuuItUBA\nY3WVTt6kUi0jYgmkUFE6DpzLKQWeeiVkqA69nqOfeKypkmdKKQx56XU38ODcIlHFc+LkGdJz59l5\nXcjuneObzh1RED9Qhz2FeoIgqgghkRiILLkXkiwvFA0DgRqsCIENgRiVABFPaArCbG0MBGBLWBsR\nxRUqlQqluIQ1IR5BdKAwIFgbgIc0ySiFAWXr8ekyNh4iyJt41ymsB4AxIWFUxxoDBhRDYACjRVne\nFiV5C4TWEInFiAXzbZfub4AVQxxWyBOPNpdJFs4RVYcRnyBhCZ+2yddWsHEJzRJc0sS3FwuFd2IP\n8ezN2DAeBLZQZdahKhuKzoYyTkF8BMjxZP0W46OrmH4Dqi/H4tAzf4zMn4K5Iyyb21i7/vsYv/Zm\nvOsiNsT7HhJXGNr/AspOWZs7TZ5mpGmXqew0bmWFXCKy6nY68QShKLVsnlL/BOy49ZvG4dtGrfgK\nvlhqRfHI1YVmUDJxPi8UHwAj5GlOYAb+DBVyp0RxMPCGGNQpaopBEdiCtdJ0jP27XyC/8QDm8gWG\n82OYJKOUryK1axnbtkLcaSL9p3nbq47jc6iUY5ZPn8FqwK037cGnfW
6++VaOPHuS8TBi547tBJHB\nKMzu3UkeRnjrGYprmKowLop3YyRpTpo5sjTZVAKpKqE3ZEXFjmxQ7xUVOuqJvEWMRxAi8QTe0hcP\nxmAQIq+0BELxOBVCoWD96lEpdifgMQq9QSmrHOQF01dP9+gxspUVZGYWDhzAu0GJEIOn8PBYLVS1\n8uA3W8ETyMArI2DXF/aByuG1KM81VQnFEKkpyNo/iu9cTXxESXspP/AvfoCP14f4ww/8Dj/w0z9F\nEILL3SCf1lUTCq8OAx6hV3cSqoJBCnKj6wrLYOIogl/UuI0O5GzdID9F7Tvf8Po8d2BuJLteLU9t\n7FqeY+rZ4EU8hwjJ1fc8X9TrQzz91BMcuuF6ev0ey6vzjI3VGarXuffez/L6e+6mUi7hvbK6usZa\no8nQ6CSNdovp7duIy2W8czz80P3cccchMrocPf4MkyN78XmfqclRtm2bIootqin9fhdjckKJQQKG\n62NMTswwMTFBGc/KuWdZXD7HyIGDmLhGJ2kXZU9jCn9IGNNpXSZtnSHpX6BWr+E1wkhMRJcgTGi1\nG4zP7OWGm17J1OxU4ZvSTdZsgA//xX/mgWf/ggfu/yCTU6/h8b++l+UluOuut3Lp/Ck6SUJ7ucvu\nWw7w+u94B48c/jgL7Tn+5jMf5od+5D187K9/n7MLsG9fm+FxQ6cF3RXP3mvHWFte4/Qxx/RUxAtu\n386Rx8+AgRt2T5Mvpbxu+lmu3wY2EIYi6BrDnYfgL08aemK4+ZZX8qWH78crxJHgUygZ4diRi6wu\nfpQkH5T9c08uDWZ3jtDvWlbO92jOdclSoT4D1eGM1hoE+eYUsMAYSnFIniWoETrtlNxBLRIyFbKs\nyzX7dpLblKynVGojjEyP8/k//SwrS2As3HFomte/9gXMLS6w1klx3hbWAwTJlcSUOPfsAg+MPMwb\nXnOA4UqT3Xu2MzY+ggmKbW8UlMjyEMFQCpq0Wx186jDiCQNLlufk3hNGESYKWEoSrCi1mrB9pk6/\nA40kZaQ8S7+xwu0vfhkHbj6EHLqfF7/oADt3jVGtxJvOndzlg/HoUF/E1mmOxxbj3hgiDHFoicLC\nH2MBGZSLAhthg0E5SwxBYMEpgYlQMQRhCDYiLkWUSiVCG2AlLFQgLcpDBsFYg68ayhoy/YJ9vODa\n2xkeGkZE6LSbHHvmCLEGJGmXS2eOcv7iV7jUPEcpqGNNVMzxRikZixlsGkIjVAIIjSl8jLI5slwf\n341UhvDtNbIrT5CtzuG2X4u4DrY0Dr0G+uzHcZlDsz6SpwQGgriCaoYfuwaxwaDUpYUCBhtE52qZ\ni431BEwxf0pAP89IfY14aQ6NToINIZ+gzxArZg/tsX2MbL+BwFqSfoswCPFZQlQuY4Iykc9ZQ3F5\nQmthjumz/wVT9YS7v4dydQjyLtpYhLlH6F1epPLCbx6H50ETbbF4qMHgUDUDT09haFVT+HnEC1Yg\ndxkiEEihEBkriPqNkkFRkjAgileDaIZWKux++R00m01GgkuMts5Sqm7D2F1MS5M0fYLZfoNr8y5f\n/sLDzEzv5szpE9TrVWan9vLWt1yPVob40iOHmZ7aQbfT5KnDj/CDP/5jLJw/zcWFJVpnztLpWTpJ\ntyAr1lIKDVEYUopDnFY3lUAGTwQIhnxAKFJyYmBIDT0csVocDhFD5j0usIj3GDGoSGEYVsEjrKgy\nbAoVB4FcHUYNmXiGVegJdNXgFhfIP/ghkjxHnaPvoHT9HsK3f29Rb9Yci5AqpJITYjEeHIoDQlWc\nEYz3eAqFJ9JC9lcp1KWSGLqqGHXk6nj+dtyr0HUJhPVBoCS9hDe//U28fjTi6cMP8Z7/+Jvs2r+N\nNEkHeQGqptB0tKjvquoG2VH1hGI2BBUd7KkMDMhOUc4rjMsUNV4pvE+GosT1jTVnHaiMAw1TrhKc\nq4NX0a99x9deYZMKz+kz5xgeG2f/tdfw+FcPc+OhA8xMjtPv9dh/zTWsNfoYmrS7XbqtNnfdfRfH\nnjnPTTfdzA23HKLdaZP3cy787d8S2piR4VF27Nhd7AQDx8R4hcwXSmwUlKhW6hitoP0AVBivzbK2\n0KXkVihXE2qjVcZnqlxsNulpRh4aQhNgel06yxc5JyNUa0M0F06g/Qad1giUKrhcMc6xcukErW6P\n4fHrsLUpKBnSDDrNzXu/Pv7xTyCPwVAAYeU8YSWkUnV86aG/Za0JWQZZnvHlIx/lqz97P9fetJvf\n+IN3cmo+Yc+OHVxoLFAJhPlzysgw2FBwCuN6CLhEnp8kjD3NtWWCCuyf2cX+eo07dh3lRQcFYws6\nG1chyQqPzo/eoPzxYcMDD3yO8+fakJsNP5ANISwZ5pdWmYpiUlX6TU+9MsPYbMrFI1OcuXiYzAnD\nu4XKpGekPMLEVInEb05Vjq3Q63nWs9U7Ty0OyX3hV8HAP/u+N+B9CZe3yExKFJf5jV98N1944GmO\nHZvjpht30W73WV3usdLq0u9DkubooMSce+Fsx7L21Hmi6jg3vfAGrFicFkZjIwHdtEscV5BgiEo1\nolaC3tpZ+naGNAvIXUZgQvrdhH6nX4w7AlaXMuaDJjcf3EY5KtHtL5GswUtetouRPQu8bqTCyTOH\nOf7kCp1eyo+8772bik8vyXBkOF94FXMF7xwejxpfzA9OcNZhwkLtxnvEFAu4RAYTmHUXHw5bPBdY\nwCJhAGKwYYgYwYjBBEGhcHiLkhHEQ2SuTtRL2F6rMDV5LYFVZofqzMyMElRqvOpFL6Lf6zN34ixy\n4+0snb6Fk1cu8cjcw5xaPUrVRogJiYwv5rhBNUFCQUkL3+MmbQZZr4HOP4peOY6mKeHEIXy3gaQN\n/Og1mMBSCnKMJig9JPRgAzRQfG8B31tBStVBZNZ1cAZ8QJ6j8rBe59qoCYkx5FGZk36M8nLMVG2a\nqDaMuf5lrI4+zbGlTzM+Mcu2oTq9hWOUK6MYiYumJiBrLkF/jbzfxqUthpcex+/7GfTMfUg3REZ3\nIfVhVBZpZTs5ywrfXN95HoTHDwxI60YkL4pgcDJYG7xgMOigvpflWWGrMBBpwfEcgBGsFJ1e3ivG\ngjEenYObfvNXeLZssc0rjGbLBPVZxFSZTT7DLd15ri+1iKoZeZrzD088SlQqs23PDjLnkcoopxaX\n+OIXPsXr3/wm8m6P+z59L69+42s5ffYMv/Rz7yNXw8EbbmBqWIv6H4pzSq+nhUzsfbEwbgKig9gY\nT8ZARcAUXVAoBui5HG8C6t6j1hCp3yiv9L0SGKWvQmYgMkLmix2jH5QQ+5JvdFZlKLF3LP3hn+J6\nXTSwhC5BA0P/6ePob/8u0f/2E2RiUCtEaklRSs7Rl8K8XFHBqZIxUJoGC7Uf1KpzcuLB5BagJAPP\nSvcfofDocweFrAshStpL+Ld/+yV++i13sPKOE/zA//lrfOf3vpqknw2Ijd+YcLy6AdG5SnzYeH79\n6ut/+g0zMgPiU1h0Cq+Oy
NeTnfVy63M8Ol9z3UHef5Ny3rrCsz7nbDJ1cLmjPFzjy1/+Ctfu3025\nVKHd7tBpdRgbGaPT7bG0tMA1+3YT2QDvHRPTo3zxS5/n3KVT3Hb7bXQaObV4hLm5S+yd2cu+Hdfy\n5WMPk3tDt5viAsVGAbkPKZk6lWiYcslQKZfJc0feX2P/3mtJu/NM1odYWVlh394dPPLkk0S1MhJ6\nKiUDSY8r5x/D2hjJVnGuS9TrURubJAjL5LYGrk3W7xPFa/ikxdL8GuVSlfNnFzcXGKAc10j6bda8\ncOTICSqVKqENWWs0CGsG4+DcqQXOnF4ABw89/hWyDMJQeObIBYwIM3uFuKK0VwXfgWbDEFUqRK0I\nMXBpLqdca3LrtfvZPTzGC+yjXD8Dx04ZbjpY/J7n55TRmiAGdk7Aj73I8een+rQQghzSHExYdBj2\nuoV6gheiwNJp5nSbwq233MYn7v8oPhCicWHiWmivQUZKJ8totjdHeFabCQNfAF6Fahzi8gT1wtBQ\nwH/9wC+z3O+y0lwkCgNEoV4bZywe4cDORYZqNS5dXiRNMoIootFosLLSLtR7n1OqjzFUCwizJbQX\no3lKKJZSHBW+L4RqZZh+WnQ82nCIlXyVPLeElW0EpkSae4xxRYen9c8xVRs0M3S7Ofc9cIo9O4YJ\nI8c9d93KleUv8A9/uMr7/9tTvPGuISYm6uSlkU3nTidvopoSIoW3zmvRpEGGJwFRcuPJSHEmAesR\nmw9iqmTGk5kcFQM4jDUYLNYIDDapGC0IFBSlLAARvBVWGzGtY3PEi2eIXQbVCZbdYdLWMskL7iA7\neIDtt99OFEVU4oRU1jj55VQFs2QAACAASURBVL9jrH2JvX6IWw++jLnOTXxs7kusZBchiArTiBau\nGkIQCUCLEtdm4JorSFbGhzuQShXfb5M9fT9GM4Ltd3BV2c7BgPgUfI44Qbsr+OYVpD6DGDvoVJUN\nKwGqWJ8U4oV6im4wj/oMkydUrTJUqbHW7LDavsyul303Wd6lVKlR0i7TtsXS3DEWhqoMj80gYjZm\nXRtEOOniky42bSFGqKcPc/FizKXuNnY3ypRrYGo5WpplzS9x7PDfc+tb3/1N4/BtCY8vfme8L9p4\nPQPDlM8xYkAdCFjHhk+l3+mQeoil6OjKHARBUfbxGMQqIh6CAHcu4+DnP82xesyICYjqe0lXu8Sm\nynZ3nLf15qiWl0m7ln4zpbFcpdHs8NCXHmBqcpJeo0Ur7XPHnXfyzp9+Dzu2z5IuL/D+X34fTx75\nEq12AnmPPbNjrB35NEO1EIMQ2GLiq9SUCkGhgMjmdqOlvGgUF2OoGMiNJ5CA7qBEFWAwRii5HLVC\nTXMCZ0hESMVhxGCdYHFEvigZOhESLNZA2WW4uTnkwhythUWSXpejXzkCwxF//W9+kGpVIO2RJwaN\nyvzeZ77E048+wfCLbiPNHIkUJEkHqkmCxyskAw+PRUlxoEJXIfAeNyAZXVVChZI6EjUkuvmdunMD\noXOd9Euhhimel7z4Jh5bcTx9/Bw/+NJX8+Ffr1KZup33vP8/cN2hHSRJMXA2Sm7qN/rE/LpreIOi\nFENjvTiw7uVZv/c3lKF0nejIxhVkcJmrnVkbEtLGhLYuK8lz61n6df9+nuj1e3Q6Fs1TWq02rXab\n0ydOMDszzYtuu43e6ip79uyhWqsTiGHhyiLV+jT79x4gKpc5/OUn6aae86uLTG3fRpamuCxlciik\nlQg5Buc8ISG7Jndy/Z59tFYXOXv6IsuLZ7j+xoPs37adxuIp4uoo8wtNsjSgjOP773k7TuCzD/5P\nzi/NUYoiRB2dZpOJyb0E/Yx3/fO388DDn+P0wjxXul3q0W5cu89q8winu01aFx4hc4YLC+c3Fxgg\nL7UJu5APSeF1cz1aaz16HYN2hE7bM1wxhALxOPSXhSzzxJOC60NzVUnWhNFxi/WOaEzZMwat5T4B\nNYIAvu+N38XN+3Zy5cjnuGfmK3R7hpf/QjDIokHXHwGFRl0sht9zC3zguz1fOCJ84GHIjeAyxQaC\njYqNXDHXGeKq4eHPnODYV0/gM6jvVmoTQtaHyohBXZ/5C0WJajNYbvSxoWViuE5jrcFa4hAtcvb9\n/+4nafWWWL24wsnTRzEBSO64eHaRobFdjI7XUe+ZmR7hvi8f5UMfOUkkMDoBr3r5frKlU8zsKLPr\n4G6i0o10ejnlMKTZ7pM7Q6lSR9WyttIlzxKsFaxpkxNjwjLRxBj4jEgzjA7KLap4dQXBzjJ6LWXu\nfIM0ddw+sYubrhniSqvBycujzOy9lu98fZ+1pfP00pxLzbVN507Dnyc0gkqZSALEGrxxOJ+Thl28\nelKXkkQdMkmxJQblbYv30DUZqTEERghNQF96VMKim1I1oUsbCHD0UZ9QokrrcomLT85R0TbToWM3\nKSMlpR4HYC6x2hXyuuLmvkQ8U8cff4DqgZu5+JUHWTv+FBPNx4hNj4NTMZ3mY1xjhJtGyizLa/hP\n6ccQbHG8CErfWqIB4dks9LO/St7LIHPYwgmJrdSQ0Qnc8nlMFILrQtZB8j74tNhcagdsj/zMg0h9\nFjMyWzScFPI54dnfhfF9yMithfv/8p9Aax5ZnkNO3wfZrQTZDPVkBLfjhey89U5UU6xR1PWozN7E\njruGmOmuIZ0V3KmnaF12lIIu1YlZxu21aBV61Vny2RuIW4v0/SF2lk+zyxtcEpGe/AiapHjnqV9+\njDtueOW3jMO379JC8N4PFqqrRiRvA3DZxtSQiWI9eHWkSZdKUJiWcw+xMawvVz6X4vwGUWQxo/pv\n3kv91huonDnMcHWaxlKPHbKfSvoUd/ROUdIFXN9A5ggkoJ1G/PhPvZc9+/fgfFG2sGGIcx7jcjpL\nl3nve3+G0UqdnJD99UuMBMKzpw9TrliUMYw6ZHCOjQDqPcGgTX4zcB5Co4gfeFA8IDmxgAssikN9\nwcht7iEwGIFIHUaEjuRYNdQEYnE4wGHQZ4+RPv4E882MaP8+7N49lEdHOfs3H+XX3/cK9s5sI3bL\nOFnGTrcJ21Vcf4R3ve0QP3eszjCGNVEyUQyWUJQEQ+ZzUl036xbkxqghGpQm84Hc6ynITuA9iAV1\n2H9El5ZHCpl9YBos+g900PboSdOEGw7s5iNffZCfesub0YVH+cV3/jg/9jO/xD1veWlR5lonJwOf\n1KDIBXA1H4ukZF2RKWx/xYuLImrhq9L1NsivPzuCr7HufEPHwdc5djauu35b0WJjsBl02l1GhuvE\ncczhRx+lXq9x3cHrSJI+88sLnH1mjhuuu5njz55kdGKIExcucv0NNdoJXDx9iiiKOX72AhKEENap\n12tE5RF2H7ieo898BckTwiAm0irHj82xNt/gxbddQ1wVoixipdFAGp7hkSEik1MZqXDxyipZFvDY\nE0+wY+dOXv+y15IlHe47/BCnzx6hXptk2+ReXNLk7z77P7npuv1Mb9/Pxz99H2vZKu
NTkwRhwMW5\n83SabcJanTzvbS4wwE/++M/xB3/4K2R4cmdIV6C1Av0e1CODNwpeKY8LUR36DQgiIXWQdZXpkZjV\nXsLSoufADTtRGoxUpkiSDtY7jEJ7dZ5LcxnnloXHEnjpQeXzP+/JnMd5yHIIopR2D/ptxalQK8Nq\nSzh6uTDkRgMTPIA6oVZXSuKQdoQIuFRoXRIm99U49JIhmskFVpcE74XhSfCpsHJlcxuJajWiXolZ\nbjTJ8pw4CsjyHPEwXp3mr/72E4yNjlMf2c5otcriwmXGZoQ0beDzDmF1hHOXE7bt2s2v/PIBQqsE\nhZGOXrKXIIpQl7HW7BEGAUmeEocxzgUkvRyvKQKEIiS9HCIoiaeTOvpKsZF1joXFFu1OwtpahzxT\nriyscnlxgYmJCX7ou1/ILdeN8+y5S8xfgbWeo9HOudg5zvY948TXTnHsqeOIb246dzRs4E1EbhzG\nRoVq7wXnUyTuI5ojLiW0fYjSYv0QC1JYLrw6VARXyME4gYwOUSBYGxVdcNojo4NzjsbCMJeeuEgl\nbzJcC5jMl6j4HqIWYyNSb6lWa0SVKkGcMzpmyHtrpKvzxPTYNbSK2xFCmmPCnKltJfJ+inVXKHV7\n3FJ+MYf9I8VxHUCGR2y2aXUHICwHmLxoQ7KiYAriTdonn38GasME7QSbtUEc5Bmow/uUHHDJeUzj\nMmZoGjXF2ikKnbUxaq2HkPSv8OFUUQ6wh0h6e8niO1jrX2GVKo36KCbtEjbO4ae3k3tHqTRCEITY\n9ln8ic/ByinKy18GzYlG9sP+UTRswcTLqFR2sNvVySUib1+BZ/6BZG2IhQaUrzuIVKagOkW+6y5K\ncf1bxuHbEp7CwlCE12/sdq8uBiKC8x4Rj0rhsdA+aBRgBge5iR0ciIcShErPOUzmsTffwvZ//W4e\nvnyJvNFjNOnTi0doh10OXFxiv/skViFnBLSNy1M+d2qCd12zhyRNiy4T1eL8EECd473/x/v4xJ//\nJTPX7OGVr3w1xw6fpzLZRU0ZG4D6DO8FEY8JBBS6/ZxKlDNe3dxuy+BR5wgG8VEMOTmKEqrijSGx\nhq5kGAwVb8hNcbaERyhjCQQ6vS6BtbjAUDlxlJX5LpXXvoFwpEyZgPTkMW546Au85c2HaLYTHn9w\nnvryFVzY44XvGOapBxdoXVzGTozx1lv6fOpYRsUaeq02XF6k2WozOjFG/opXYG2IVXBe6SKE4jdM\nycYrPVGq3mOBjgh1FKuK3WzNBnDODGr3hUHaK1jVovtAipJXmqXs3TnJ73/6M/zEW7+HbOkZ/vy3\nfoYHP/cm3vsf/ncqNUuWuecwDDYIy7oqs/6HrNObge9G18/W0avdGsUPpYNS1HPKWTIgR1/vxXmO\n0nP1sec8PCg3bhadbp8giOn129x5552g8KlPfZoDB/bTT/rs2DXF8uoK5y/Os33PXroXlnj06FGW\nl5apVYdYbnep1GsYG9BYa+C8o1QqY8wIWX+U3CVkeZ9LjUXoBSxcWSLzCbXqMJgyi8srVMsB/X6H\ndqVDmmd0M4X8MlkOmnXZN/0SJmcnecmP/y+cv/JyauM7+dCHP8vlxTUunD3JZ+99lG3bdzEyuYsL\nc6dhWWh1OywuL/Cy6/ehYUxv9Vu3h34rvOtdv0AvXeb9v/PfQQt1OK4VBLfZzhkZEWZ2lulnPYKq\nYOLiR8jXPNVgiLd83z381Yc/wsIK+GyON7/+XUyO7OCzn/57otgzNgUPPvEoyf6DaN/w0UuzGJ3n\n7tsEG0CWwF98Et70OthXVVotIXewsgwfeVC577xhxDrSQPA51OoKRhifhXxl0MZLUeLymULU4bGH\ne4zsFMQJnSYMjcHQtNJc29y4qoaWuYsNJCjOj7FisCYgcxmJgTe9+TvYs30/n3v4Ph4/fporl5e5\nMHeZkycvE8fwwlt2MDY+Sa06xNKVDqPjo9TCAFsqUSsb1Ct51qdmLTjIck+/Z1HfRYwnQummijoh\ncwlJlrG22qXR7NPrpnQ6fdqdBBNYAmMIAkOeeSYmh3jHm1/L/l01njkzz1/fe44sM1QrDfbvnEV8\nSu6EfupJMsetN+7n8Scf2XTuBHEHJMWbBCelossUwfkMTL84q8XnRGGCVYfgERMgmuM1x/lC0ROx\ng+khxImwEjmgwpgdp5QO0U1zut0KK4+fJEha1CsRU1mD2HcITEo3CcFnVOsxkzum8c5Buka94snj\nCr350/jFE/TP3EsUlchaDeKawfiY+vAklaFtlC9f4juWt0P1xZy0j+AG5zhZU3hlNwufJZCnxTEk\noSMMKUhs0iY98QBZaZTo0hU06GNNcYaaGgFicu3h9SJy4Wns+E5MfZJiRCp6w3dzZfUW8naPib03\nEZYr9HsN5o5+kQZXCCav5eLRwwSxoZakrD32WQjq2JldUE9Rn2KzJq57GRs64ijHKDgt4Vv7MXWH\n15shncStzWPCiFLSpil30vZN1moxQ/OfxI7tJm+W8GsJsrIIt739m+fItwuSDHbXbuBxKdaD9fY7\nC+QYLE6L8pXX4mwD53TdSYpb960gBFqYmd0VmPrv/56GUyZXV7iVES5sn2TbyiKV1UVmzTm6PRiO\nLC7PCHyKNyEPH23ybhGEAJ97vDrIE6Jyid/4tV/j3k99nG0HD1Ku1rlw4SwTew9x5fynyXpdnIXM\nThZSr0BAoWxUyyFpbjm3ukkTmMuxxiI+I1GD14Rs0IZv85xSEIJa+lZIrNBVj8+VTKBkhcALSo7k\nCaVulysf/TRr+w8wfst1GN8hfeIE2cQojcceozq5h6Tcpddz3GJDHm/02PfKCotPTSA7L3H+aJfZ\nkqF79jLLrReQ3fdFsnab0p23s+3mmzn35ccIKLoIKiokxpNp0aZatoZW5nACkVcyhViEsnq8Fv6k\nwG+uIwAgz4UgKBQdGRjWvS3M0tj1CnBxwvPUeI0P/sPf8zM//KMsnHoU8+Tf8MNvfJxf+p0PcNPt\ne0n6/QHRKNSWdZIySNKNsyHYOChQv0Z9MeulMZFBd8HX/daDt3+Nx0ef495Zf06fQ/af89wmxUHE\nhDxz/BSBVVrNFsNDNe6++9UsLy8ShhHepBx+6jCjozN86eHHMUEIeYczcxc4cOB6+s5RLlXZvXs3\nDz30UGHCDyMWFq4wPbWNCxfPoMaSOGF5eZkbD1zPkdMXuHjyq4yORezbPU6tZPC5I4yWmZicYHhs\nlNFymUq5ThhXWJw/y0gwjTYsUZrwDx/7KP8vZ+8dZdl11/l+9j755sqhq6pzUqtbUssKlmTLAmdj\nYeB5DMaMPfiRBh4LDO+9Ab9hmAFm8BrSY94aZswANsFmwLYMxsiSbQUrWJZaUrc651Q53Lr5xL33\n++PcasmMbSj2Wt19b1Wtvrf23Wef3/7+viGJBa5vURkdwhhwqgHPH32WydFJrly7ysLiAnfd9wAD\nw9tZbjZx/JHNTQzwo//qdl4+doKRWi4Xd
wOBFRi6DYFWBulCasWYTLC+TG47YKDbA0WLW/a9nuIH\nLd5497vZPrMDozL+34//MgduqnL27Coyy3eu4xcus3V4kkg5/MXxATxrnXtvk7gevG4//O6nBXfv\nEdx7i6Hdgi8e0zyyAL60kRasKUWaipz8auetVmU0ljRk/ZsTCjpxbjzaPg3FCpTKEEXQbYC0N7dw\nriw2GKzWiKKIJE7IRM6NsSwXVutsHRtj+exLPPzki9T8KXZN7uOeg6MMDAxQqQZIS2J7Aiw4e+4q\nFy5c4ZVjZ2h2W0xNjDIzNYjn2iQxNBqt3KTSEiRJTJoYoiQ3e0tTRZr25dKWg1+wcH2PSq1EmioS\nlZEkGUOVgAfu2MJYYPH8hTleOKUIAod2J+WOm8aYW82oN0IsCxwLAkfQ7GXoks9v/O/v3vTake46\nQrho4ZBiIUW+y2utQYYYlSCMxgHc/iFJa4kWCdLk4pqN1rUlXTIdg0lIzA4OyrsQC+s8+cyzFLbs\nxbeaDNJGyAwRhfR0QtExFF3BQEmBH9CLMmr+Os2FWZrrHeTCNfzhSaLZC6y88AXE+gpBTWI7kvXL\nhrGxJuMFRXnyIEk0yER3jfvdaeLSzczFJ/D0q4kF/3AL+8eGjpK8G2GynC7gCRxLkURtTP0qWXgG\nWmtQdHGKReTwXopbd5OszZNdPwNRSHLpWeTQJO72OxB+BYTBIuXJP/4VxgYmqQ8WcbceoLzlFqrD\nM3jdRdJjn2VpNsXeVcQKhrAdl87CNdIMBkZmkE4JWR7HsjXC9rlRWvgp69PvpLDzzRRqY6jOItbO\nQVae+jNIt7E0OY2ZgPLFhxBtl8axCyStHjW/jeHbC5C+M2m5j+bbQqL6kQMYgcL0iW4SIzS5YW3u\nqSIVaFfcMGMymcSxc65IbDRCCG79zY/RnhxjrdOhuLBKeXwGGk3qi8tsd9aQgUCkgFAEbgctJTgp\no9WILM1Ik7zYSOKIgi155mtP8V9/7TeobZ1BJxGtdpOCHGPHbZMM7P6XBF5AtVLha5/7HaLYRWYZ\niZAIR+I5Nq5vYbubW0CpUaSZIRYCmcbESiMsC0saYgmBMYjMwSs4iNTQsjcobrlsFcD3Ijh1lIsv\nXmHqPW/PJ7vdQbc7rP3pX1M5dICBm3YzumWK+vxLOKU5zq1rDu+sUhsroNMDDCYhhw6HrC/C9ORe\nBv76ZQa2jjDnTqOPnOC6sal86AeJ4xRHQSYt0LmTc4LIlTbkbQJJruxJtCDVikBInH47arNDQB9N\nM32lVH7DwhJ9f6G8IjGAUppyweL3PvVn/J8f/gnOPPclZqY1H/ngj/CBf/0LvP/Hv480TZFCYVsO\nWhsc1yZNDGkGhYJApRlK6292R+4XPUaYV3k4N3g9r+4Y5pvhov4v8CrHx/RNgG6gQRttrddWPZsY\nb3rgLTz++FdyZCYokaSaa1evc/nyZaZnJnCdCfYfOMjCcp1Gd51SqULWSdi5YxeN9SblUoUgCGi1\nWoyMjFAqlbAsSafXQKksb606DuVikajU49iJV7j91juZmd5Bq34ZIzpo5SKMjecF1Ot1EpEwWNxG\nnKZIRzM4PESpWuPy3BXm13tcnVsnEj6NdoNDt95KL01ZmLvKwYMzXLqwzmq9zo7de9i3fx+9ToIl\nfBy5ebfcNTnLzE2DLPTquKEg7uVKz/IQ2B2BV4RrJzWDBUFlGuxhQa8OsRLEqebzT3yM2sgSf/uV\nJh/8336O46e/RLt9BD/YgyUL2G6uRLLchKXwCo4aJsPnvz5bxKbHPbcLDuyHoZcMv/Jpwd/vh8fO\naJYKkxza43L06BUC16VoNJkwRD1BUDCkTRgdEyzHIVqAyUBbkKR5dwAhiHrgFaCzImjUNV6wuXUz\nOjSYKzyRSMtGZYo07tBpw7aRrfhjN+EM383v/Zv3IbKIJGyhohYq6qDSECFtpO1j2TY3v+4gzn13\nYxXKnJ5d5LGvPcsv/8rv0OrBz//4YQqFKr3Y0Gm3AQijGJ3me560JYXAzrmBcZJ3iqUmShXNTsjk\noM9dt08hLMGTRy+ywx/B84somnS6KeOjBcaHq8wur9PuhJQCm0Y7xQiFJQ2drsbbMrjptSO8bs45\ngb5ytl+/2HlcjZAbvr79IXNzWGk27rM2oFHGIEmxjEHKUe5J3kp89RRnrixQ9kFmbfYWBd1mg0Zi\nUQgcTKbo6Zh11aNclAz7BlkqcOX4S6wvZBjXxyzPY3k+jbMvMX9+jkrg4Nsab9DCH4E4lHTW6njV\nFcrVgOzcPGOp4B3BAZ4r1TiTPo1jzDftXf/Uobohdv9QqMlQrsk9i2KF7w0zcOhm7KmDJLWtCLvA\nijOI0C0q1CndZTDtM3RefJbo1FcQfhln6maEHdBaPseuXYcZL1e4+sgniJo9Dh94K36pwsKxT1Pr\nnWP3gbcy3+2AHMaqjCKGp0mNRsV1sB3wawiviOiu5GvJgNs8R+Ox3+Ern/5D7jy8F9lZwvQ6dBaX\nwC+x3E44t+5w181bcdoe4dxCvk/7NlbR/7bz8I/48PR7mTrJPXQwGCGRRuQ+vQZ0H8HRBjKdkiko\nSE3az99ykRiTS9sVAidVzN13F9nkAFsuX8HUaqz1eqS2y/Zaj/HVF9hmTmCLEaS1lt9ktCYTFY4s\nDrN0fYXayDDdboSlYsLU8Eu/+BGqU5PY2CiZYbmCuBfy9JPPcMcDb+KBg7dz5BtfJ6iUca0CUgq0\nFGilicK0f0rf3CLy0lx5pXRKIdPYRhNhKCjQgUfPaGKTYqe5LA9loWSO7LjS0HMU6omnaK2kTD34\nXWS9HlIKisogCwW2/Mj7OP3Jv8AcfZG7H3wr83PL3HawwM57xrGaTdaeT1n1ztNai7h5r4N7e5Xf\n/b1H6RmHju3g3L+d7Cd+lNLwAKYXYYQkkzkR1JAiTZ5UJYVmUCtmTY7guSZ3x3aNpGM0Ba2xrM2z\n5LTOUZw8a4y+bFbcKJIRos/Hyf8oJfBtwe//+R/z73/u3/DkZ/6Q6e1b+dtPfIwjzz7Hr/7WRykN\nVXj6ySN4QYnnvvRZXjl6EhvFvkP38+bvexsTkyMMT1TI4gwlNOKG8eXG6/TXdE7s4VWcKf/OPxz/\n0KdnA2HiNf+lRtywXf+njn/x/vfz2S/8DdOjQ8SJIgp7rK7WGR0bJ8sSlHZ55cQ5SpUqW6YmWV2r\nUy0PUKsNsG26RLvV4djJ41QqFUZHR1FK0el0KBdHqVTKKD3NxcunQCWUSj4Fr8jpU6colgrs3j6M\nbXzSXg8bl26vhzIpVatKohWVoocCrs0v0ut1STNFZtdox+BXfMZHpojamsxYqEiSxAn1+hp333M3\nUzNb6HZWqbeadNMM+c9YN1Qi0qbEdMFyQSbQXYckMphM0G0b4h7oQh7X0FmFyjAMFUF1JcZbYnQG\nRsRVwpPPMGw0TjBOHLtUqhXGuqCFz+pKzMKypsAqFX+YnnL5zUdS/h874Y5Dkh9+OzxwS97ierEL\n
fsmmEXfwyqDCnOc2XAU/AETuoxWG0FzNDQ0zBE4ZxrbA8qymFwrKRYEnIFUGy4FCZXNTk2YGIVPC\nKCZJuszPZdzxulv5+Z/5BRjezer8LHFnjTTqoFUKKs1bE1qDTrEcH8v2+62cXFThWA7ThSI/+a43\n8tMfeC/PHDvPb/3H/8yxI4+z/eZJBgaLCMvC9Rz8okMBjeu69OKENM3IUpOvvzBmpObyzu/aycp6\nm0dfuorWLraUzIoO2weqWKGkMmhx295Rnjl+BUtUcLRNmqW4UhKqXP3k+y6feuQZ3vqvNzc/wt5A\nYflm2t0GgmvxzYcYDNLuG+YagyHDGJEfprMMY2CwdDPu+Yu0eutYImZyeoY33jTOlLrIQyu5J06n\n3SGw29iyTcuGwIHlhiS1UtbmMq6teoy6EaM6giyhvrhIN4TxSq4UdkKNHxhGhzVuySbttekmhl5n\njfKgwx67wWRwgL8LbC5ET5OpLLcv2cTIOinSAmHLnLzd0rmbepgw+qPvzaX3KsZlEXSG3RWsMkVk\nuVTqL1Fc+wxO7WbCCxdJTz+G8MvYo7vwihMUxiZI05jSbW/HG99Oc/kcvYaDt+MBGL2ZQjDEeH0Z\ntbrE8Pg47vQQwmiyNCbuNvBay9iZJlu7jkkBDbFxuN6d5O43v4PJ4grm8lmyzhpjM8NknRX80KC3\nb6ex1qDWWsOzDcb1cgAm+mc6LW80EKS0+7C97MNpur9AcuRGk9uJ60zn5FSRW2wnJrf1TxEYmcdH\ndBdhKugxc2KOxu5hLl05jzu+nSScxW2G1EyI7VTQZhKVPY0lLWQRRFNTrRT5hZ//ae59wxsRls/0\n1BauXj7D/OwlbK+AbZkb4Y5b908zNjzCgG/zhU/+F6zsKsXiGMjcmA+dS8kLBQv0t5Yff6fhG0WS\nKPzMEJncusvNz14QpSg7j56wlMGWEqE1tslbKjWtCS5eor4QMXPfYawwJqgVUFqg1js4jsQp++QZ\nJxaPv3KeXTrhkU8t81PfW+Jyo8N0tUh4+iyu73DiRUitLktrLSp+wHqYUCi8jumxYVpRTCjAw2BS\ng23nn0dmTP65GWj2Q2GFyCNDUr3x+ecmke3Nd7RQOl8t9Hk8eR5M7h66YUtwA1MROUKojUFmEf/+\n93+T/zI8wOUXvky7F1K/9gwffOMbeceHfpKvP/znJN01JqZKRKtdikXFV77wLC989d9RHRrklsPv\n5V/81E8zMDZBpjKM+gcsmxtGi//gE/8WERH/qzlhDhtt8JIMYLTYdEvrngfewCOPfon3vPudGFUi\n04KgNEgUK3bvmSZJFPVGm1YnBiy279pNFMXML63Q7V6hUChw33338vjjj1OtVmg1WxQKBTKVIaTm\nzMnzTE1NkmYRwnMRSKngdQAAIABJREFUysITDo12h+PHr3P7bQdR9hq9qI2KEyxPU29EFIIWg0PD\nuI5HkkC9k7Bn2xaefP44pUoNLygSxSFGJGRhRH21zuzsPN1eC9sxrDfmSbOYdqNFmiYEQWFzEwMs\nzsakkcEuSlqzOWG5NgrFAUG3kfcoaxOGaMXQ6xqcIiQdycx+gy4bGmvw4vNwYM8pnOS32Tv9QXTH\nwx6vYYkAGQQ065pWG4JAIDJY665SK1RphAX+25ESgVdn24xgfFTz9XMCVcyRocC1GRkcZGm+haMF\nqgu91ODYArsKvWYeOSOkQCuFVxKMlgQTe33mlhPKIwbZEzQ7kMaC5urmFk7usJxRbzQZqdX4/F/9\nEXfc+1bCqy+zfv0MRmfkkX4uUkqQNloIjMoQlo3lBji2Q16m5/5gwpJE7RWiXh1r9hiHKwM89Nf/\njadfOsOv/9ovgHKx3Vz84DoWSkOr3SPLNJ0oodWK2b+tyu27xmi0Qz771AUcy0FrB9uCoYEiQcGi\nEJTZMRFg+yGLyz26XUm1ZJNkmmrRoxG3SU2GaztoLXn6yxc2vXYcJ79ElXlNdxte49L+qkjhBkpi\ncuRnQ5gjcRDGYzq4naocoyQG6YkQowS6OMi77zvErbUrRF2Pdr1DImx8K8pzJGWRRpbRrWu2DWes\nhbkasx3btJZibop6uCLi+oUrTI2UaKTguQo3VswuCjo9xZYZD0vGtNe7tLsxlYGISiGjWs54j7yH\nR/xBznb/FqE2h57qOEXZ3HCijsMYKRTO4E50888guDsnzKku6AauKDLV/DMIFXrNJlvOCM/+PXpo\nJ8y/QjayC1kZxysOUAsc3MEhyjM7kSO7kNImy1KUX8GpDFDsXmXLgd0IuRXjuWCPo+OIjjWFbc0R\nlCbR4x/FPf6nJE/9AcYGk6UU0xVKJUly/gW4chwRjGB0D6s6ggMMTO4imDtBsLCM5VXRlo1KCujk\n28eSfGfSshCARhiJkBmg+nEAffWW0RiVgTZkQpLEWX7R+7rvZJkzyyUiD027nlD90z/h7LHTiJvv\nYam1SMEusNxqoSKPc8kwQ/E0tdYsQ8WnMUGASRJUF1KtObTT4RsXG3zxb79AmkUsr6yAgqHBGtvG\nHVwElhAUA8nEQJeks8Arp5cwwmLL6Hh+YxJgdB/a7HvevDaT6Z86MqVxMoPKFJZloQU42hALg6sV\ntjJExhA7NkKluNIiFhoPwZotEMcuEtx1C36U4Rc8mt0eKjNkWUIhk6RRwp4P/RClwKf72FeIZpcJ\ne12+cXmBqe17eaqTMnzrfVy8cBm1fJmCL5FJQiwkeA6rn/oK1kuXsd/2FsKxSYpK0bPyqIhQ6xtB\nrsoYEgSZMGRoXG36ztl5Blpk8vjQzQ6t8skWhtwyXcAGFGj6Vslig3Dcd+oWxuQ+T0nIz/67/5tf\n+7kWpeWzZNhMjCU8/tDHGSpJ/JKPJwqMDVqMTli8boeg4xRpxxmt7vP89kef48CB+7n7Le9iYsdB\nHM/Ks1voFzH5y95gH4uNtX7jKNhf/2ygQK+9Mb36TPc3zs2q9rMk45ZDt/LEE0/zkZ/9GU6fPE7R\ndShXKszONykGNtu2b6PRajM6Oo5j27SSNq7jMjo6yurqKp/4xCd55zvfyaWLFzl0yy19T5U1Go0m\n+/bvo15fpFwpkXUalAtDqJJN0G0T9gzPPn+Bg3t34tkB651L1EoFRoeHCdwhOi2NMS08z2Vxfpa1\nlSVavQzXC3BtiygCzw1orLdoNlq4jodTqxInHeIMXMfBcV1sx/kWJo//+OisCnptgbQhDA06gfq8\nQFoGxxOUa+B2IHMFB+4aZ60dMr/QwPLgzQdvZdJvInDBk7hDNb701MOceK7BvrsnmBwPWLwaEiX5\ntT+9TZI1LRYXMyLRZGx8hp5b4i/PV/n+6DIN2+F/nE45sGMcOy5RERDrFraTobIerhBkXUNmeUSJ\nxvZTLAdUokGAp2B1yVAKNHYm6K2AE4AbwHDJoNPNzU+jlbC4tMyPfeBD/Ntf+ijJ+hrrJx5Da4W0\n7L6wJMtVgypXkTlOhuVoRLxKoTBD3DiNEQFaZ0i7gjEth
F1BygIaQa+9Rnz0S9y1ZTd/9VeP8tM/\n85O0WsuowKcb5n5YaZqSZBqtFd//pm1cma/zwqU6aWZwLIcs1Xi+Talg5cWO67DcaWBjs3e0zNHz\nq/iujSUkcZZgeQ6tKMRzHHzL4eVjs1zcvMAPzzH9+JwN1Cb/ujGv8uxumL+zUQDljxSmz/lLsPQo\nIQuUdRWtV4n8rTRq8JbtAzjNc/gjhri1QsXrMbeksXxBQdio0iAF2WFs1CVNG1RLNm4BGh3N6Q5I\nNyDNNK4LgZtSbxl2TtuEPU2SZVxbLSKCGiOlGRJfYHiGxfklpm4+hHAshlWbN3gHWNUniZKzm5ob\nk6YYLVEahMlwxncR7Hs92aWvIA98Hk0PGV3FNE/C+iyMv4/MWWT9xS/iXj9K9/hRro48yLi1SjFs\noZbOk219HTg2vRe+SHjxBcy2w9j3fwhnYBSkRdzrUGuepLjNQieLYFqYuIBwJunEMyRekUL9G+jk\nq9BZQKRTCHcMEy/hCdjzpu+maWzkvgdhfB+6tUK0uoxxB2DHOOO77iQoZljN55C1SYzl4/iC4Pbv\n+bbz8J0LHmMwIgOdx68LQGjQUuOXBpC2lds+FwqUysOs1xd564+XuHz6NNfPXMBzQFgWWhhMK8H9\nyM+xf3yMRqOIIGW85RAWfewsIUsbeF5KgyorYoZEncUjw0gDKUjL5oGbI5675JBKG+3aVIoDZDoF\n47FtOMNVGUoYjMlor10FaaG0wfbykDEpZE6iJnd6DqOE8QGLMJGEyeYq5ixRyEwjpUDofpQBOR8m\ncyxQCtuAZRRS57dJS1gkZBSSjDDLKDdbqNEBVqIQlRmMylGgOMtIJZh2mzOf+xxycZmG5+EJyVfP\nrDGqr1MlI5m9ws6b9/Lpl1t0Fq/TyTKqWIRxSiYtrpw4h33kJCO/8n/RGxrGpIpMbtzCNYnJ84CM\nzvO/BNAzCmNstNH4ZoOevnktkk5zObq28jgIITdg4w3bcXK0pN8qvcGrMQKlBVkv5Sc++h/49L/9\nMTwr5eJywq23Hebi2WPoTLGy3qZQ9Lh8vs2JbsrerSnDAylpy6dUtIlWniBefJEXL+9iZOaD7Dpw\nAKVTbkRFvLZFhcgX9gaCeeNv0a9/XhuTkf9eygAm9+/Qm9SlW5ZDqjRTM1v5k7/4FP/z03/Jb33s\nYwhhY6SgGFRYWFgkyTJmZxeYnJ6hVCrTarfoLvbYuXMH3/M9D3Lt6nXuuPP1NJtNLOkwNDSEZUsW\nF+fxA49aYSTP47HLONUJbpk+zOj4JF999CGunHqUHVNDOH4FIQVFx2Xfnr2UKxUWFhbIsgTHdUlx\niLMUO3DodNps27qdRqvN6VPnuHz5GsVimTvu3kuSdrAdH2kFFAo23W63H2y5udHrAZmhF0KvK1CJ\noTIEcUuQNgwiBa0gszVLSy1WOiE9BWOF2/nQD3+C9NIjJPWLICyuX5tj3wCs7TAEwsZxPMIIpAth\nU4BSpEg8T7LW1ExNeNRGtlMb9HlB7GNsyzbCr/wBWTugHcWsNyNW1tvU10NsoagU8mTrdickTsHx\noejljrNK5TfQoS2QxhmTg9DpgePmr48Um+bGnbywzH//T7/CBz74QbrzV0iTFCwHow1SCiwrw/fK\n6OYCOouxHJdw/iJZ6xxptMb1eQvLalC86UdQV57A8osUqhmy4COkhV3ZifRLGG+I3spVvF6LP/nj\nT/JjP/bDNNsdkixBG0GYJuyaKrNlyOMzX7/KgO8hpEZluWrRcWw8z8L1JFHYIwptfOmwf6bAi2fn\niRObMOzijldIlaEXxQQFhyxRZKnhsUfPMr3Jdh/QD6Kmz8nZINrRL4Jy2kKeGvDNhxgL+ganuZDF\nkR08tZtF2qBn6NR73DY1wsVLl7il0qG1pvA8i/nVNVpxkfEtk5hSmZHpCfbsGWE8OkfcgG59mV5B\noAdTzl+BzPFYWl6lZRy62jA6WCBOeyy1BRM7xpjcNoQztJ/y5G6GHJvLJ8/Raq5gUo3utWg0Esq+\n4qbyAY6bzRU8QoCWLsJkWLbP4Ht+CREkXDv1PNlvvpmB3btYfPk47pZpossdvC1dkh3j1FcdVjtb\naOz7MGNyleraVVKnDO0lVK+FKJdJpY/nDiDmzqG7bUxtGAs7b58qjWm/AIP7MMEbQJYxS5+kHDSo\nLH8aetcguIf0YpfemWMksU2t0C9Qr75EeYeDs/Yy4vJjqI6FSC0yAjSSNG2jrp/AbzUg6iCsDPEv\n/wNXJz7E9m8zD/8IadnCcjwsv4Bl20jHRdoBQuYOo0oZLNtgjI3l2BgrYPuBW9m+/yDd+jyPP/S3\n9NbrWI6kU4fvfeANLI/4TB8+yKlLzzHij7FmegSRwTQ0LR1ytaE5UFyiEdyMbh/FCyxcC3ptRRq7\n2E4Byyhild9MszQj8C2yLMv9BSBf6NJCaY3KNEVPoXUMlp8ncQsNRiBtm2uLMYWCYXJ4c/C7m8Yo\nI0kUOHa/1BECtEJbAqToE7mhb7mIURrHZKRCEhy6mfWXj7L4+AqOJVG+Q+C4pFmGrTWJ0YSz8xSi\nkFBAN4xYz1Le9LpJHn34MSaqRSYmJri+uEwS9QjbhmLRIjIGmaUIyyXKMizLRx1/BfmmBzAiT8i2\n+rEWcX9T8IVFYhSpyYP0BBotDD0DDiZXL2xyZFnO2REyRwQ3bHAQ9FugG06lpu/WudHeygvTVBnC\nRPLge99P68ojeMevc3G5iR8Uub7eZMv4KLbtstBcZX5hlfl6xMjwIAe3ZsyM23QGNX//yDIFZ457\nSzVsdqM2+Dv0i5h+i+qbacyvjcl9zaXQf/u5mRoYI1Gmb/u+Sf6X6L9mqlOCUpn3/8RPMnPodn7+\nwx9msuYSxxHVWhUSSbmco2tKKTzPY3FxkcWlRbZt3cnevfuJo5TZa3O0Oy3uvOsw7U6T5eVF9uw9\nkKuH0gwZjLLrwAPcd8/d1EbG6CRdrp39OstLTYwlKRZdjNbML15kb20fh28/yJHnX2bP3gM8f/Q4\nStiQZfi+R7FYZL3Z4cyZs3nUgIyAjCTtYrseWhmM0Hiet+nIDcg5O61WTiwWicEXEC4ZhDQoJVmd\nM9hFw46REodqw5yNZ+kIzYNv/F7W1pc5fXWZ1bPn6HRSRjzDwNgAu70hCpUhWj1JsQrdDpTKmqFS\nlZW0R1zMSJsgbIvKwADj4zV2bp0GLSl4cPTCedprCVkMrg0lActdgWNpCoFF4OREf50ZjJNv1sWC\nQCY5iFgI8rRxT8BA2abTzWi3wAo2Nze/9os/wwc+9KP05q5iTJ7ZpEzennIs6Jx8iujq4+jqvTBy\nE5W9h2icPMmVI+fIsNn97l+mc/U4dTFO1a1RX1tmfTHE0isE1SLSPUZxfA/FqduwStNkUQv7+kv8\n5m9/nMN33M2u3ZNMSof7RrcxIQpUWhX27xrjCwtnud41eDZ4jkPgS1xH0g0THCFx0HRFxHqniU4E\nb78FVr0DfP35
q1R8lzjOd4RyyeP5Z69QB0r/DHQQ+keVDSSf/sM+eruRS3Uj8qD/E/pGi9uA8UnN\nJJncT5LuIWx6jJQv4Do25aBIu7WEkDXmG10SEVAZHKGy43b2TdXYNVVloKBw2w0un7iKiDv4wmXn\npM/1+RaZyhgdrWENDjKfKK6sddk6CIdukgxuGSAqbqU2Mkh1KKXXXOf2H/pxXvjM/0BpRdJaoxiM\ncPS5Jzn0tndzwRnY1LzYgY9wy7ieRGiDHJ4iOv5JvMvPc+n9f4Aj26ze/iYqQxPsf/9+Lj31MEl3\nhfmFOWrbZph79PMM3f92ms0BSvE6Wa+JCdvINIFSBXd8DB1FmEIZ2Td9dHyfpB3SPdXCd34V6/6P\nw8zbEVMfhe482ZUX6Xz1YUT0MMvug8Qjh2FCM7DwKZCQnX4GefUpKFWQoobRMVavBypEyITwyBdw\nTIhTBKMzxPhNOPWn2Dr/G/C26FvPw3eapMrELhAWpm8qqPsOmkrlC9QiIe1JMhWy2G2jVYbvFkl0\nLmfutOrYlovuJoz91L/i+dt20Dp/EddewOt4rDcWsJXCK3qoUkotajG0c4zSskXLG4G1KlVWcQJB\nOQiZiGKUDHGdIrKb4EhJikT6RYSIEX0FmURio0lU3kuP05D6WkZQUvi2hbSdPKXWkjiVApnSXF7c\n3OYcpBlaWAjbRihNKvLWEFLgacAYMtfCxZBohTRWTqTNEkyY4Vkpwe0HybTBihVpEqKiFFuCyBTR\n8ROES2u0kph9B/ezMH+dJEz58pe/gS0Miczw0x4qVJjlLhV/lEZ0nYGKRy9NsUyGznJEY/Ghv6d8\n9BSlD/wgemgYpXLzKokgNgIpMpQ2RCL33QnIQ1dtk/d8N6+1AbTGaOubTlk33JA3kBXz6mbERkcJ\ngZAaW8BAwdBwXKanRllZj5ldPotrZ0yOj6BUwlq9zdrqKlt33kaatEiiFpdWSiw0W2zbNcjgTg+v\nVKMVfoNTL/0eO2/9yDd77tyQqfeLnX6OlvhWRU//54zue++aPFhT69ynZXMj32hty8IYUFnGXa87\nzItHX+K+Ow8i8ZhfWKA2OESpVMSReV5NGEbs3buXVrNFt9ulWqkxNzfH/pv2Y1mSufmr1Go13vKW\nd3Lq9DkqpTLVwhaGxiYhXOUbX/8b3vq2H+DQ/l18sVil2+rQazYZG55mdWmWOGkwM1GlW60wvXcn\nS9euMLu4xvTW3bQ7bcbGqrTDJuv1OnEnwyn4GFsQ9lKCUgkLBaIHAnyvhONsUvpIzv0qlQVpYtAF\nQRaD0ZKxkTKzi02QhqoD991xF3dNDXKosY+19TrdI0/yd0efYFV5iK4mnb3Cjnv28vjpOm6xQtK8\nxPTEPgrSZz2JKBgfFXvYOsOyNDJWpL1cYJAmESpLKZWqWI5NEiYMSZBlQRoU8KIeaxGEiaHTy7Ad\nKAQiL/KxyLTGlYIggGjJgAuRkxOf49Ah7mV4HmzWl/Ejv/iLRMsLCGn1UWWJlHmsz/LTf0Msh7k+\nO4JYuETvSkzjr/8/guoIyEkmRwdpnfkcl06eRdhfQqWGQiDAKjE4cA+tlWsErKLCOdJWncrWW/CG\ndhOtXWFsy8381r/7Bc7+1eeZqY6zbes4A4Uqg9v2Mmy3eN+e+/ijv/kC//nhc+wZH8LolCwFWwjS\nzJCZlFt3D3H8yjpGSNZnl9l9/wFOHvcxKgUchIKlpQb7DuzGosajRzaHYABkNwxOBVq/yt/JL13R\nV1puKJXoY819Dmr/2k/ZjWveh+pN015dob2+yl27BinGK4Rao2xJz5R54VIDhrZz7+vuYseWCpOD\nNq6ruXTxPC++coEdtmS0OorMGjhezJ4xGBktM7F9F3fcsofmygLe6ir7bt/OxctzfPGS4mr9KFnl\nOt99+CA/8MZJrMymsvUw0CDTPl6yRmCF2OvL3Dby7ds232pI18UZnMAtOajZa4DAHt6NdAVnH/9z\nWhfPc/ewiwjKnPt8D1tnuNsOMrZjmqKruefgTlZVytmVlNd5EabXQoctpEoRnQai18OoFOnkHDGT\npug0Za2yjWyiyGB4mNKf/AFseZjVWEC7QeUNP4X80I+ynIa89MRDTLsxb0g/BYffgTn9MIQuQmXY\npRFEUEZdv4ZJwtwwMnCxlYWRZaCb7+txFxPuwJp+w7edh+9Y8Cil0LIfCKAUSsUkKgEtsbDJTIaQ\ngsBzcvtuA5mO0QkkUYpSYNsJyQqYH/hhrp66zMDYMIXVOq3nn6NbX2f3zt0sqAoeLv7gGM20x7ND\n7+N10UOI0s2MFFdQUuN7p9m3bZnvvm+KuXrCsRcVjgvDgcuQ3cIjyZewEBij0AJUljdyHdsmCAQq\ng44BoTI82+DaHgiDYwmsTdJUYqVQtsZKNTG5u7ISEq0N6AzLd5ECAq2xMEQ6JRFw2LL44JjPb11u\n04xjHKVQRuBKlyyQyGaT+a8+QdhoglHce//refEbz9OKEzyjGSrtppu1cRPN0lpCktmURY2esbHj\nFB1H9HpdLNclMhrbQCgswguXWfrV/8S29z6IfMN9dDODtASOMYQaemi0FkgDWuRp9hKdy+s3f1DP\nc1iMBN3nA/drH0vQl6ObG6ewG6rwjX+FQGlJ1MvwRYLwxrn11iJTNc2Xn7vOK6cug2XRajYolkvY\nepEgsNGOjVEWvqXorHTZPV4kmA6JQolSx/ICx+jXbILmNZSdPrW6//VXc7Ty75o+61EDRufeU8YI\nMi3I1P+KCH3n0T9p9osrWwBGY9kO7/3wz/LIX/whJjE01tuMjxXwXA/X97l69Spg8DyPKIyYnCgy\nNDTE4uIiK6srKBX3nawlYyPDNNbW2LZjF+uNOllzlZVWnYUrywhtsXV6lFeOngdj0+2kzGwdZbQc\nEM/NI2sTfPXRr/H6227iwfsOc/L0NbaUBjl88HZOnj9OY20ZEfUQvoNSBt/3cGzwPAGOQqsckZFy\n8wtHCLA86CUQNnKbC6E0aWpIk3yeE2144shXycx3cebIGW6fHqKTpgwffD0lt0w5miMYzzgy2+bC\nYkQWNxgdG2DHjA/aI21E4Fc5fzahnXZRRuIVAQmWJTD91rdlS8YmanSv1RkvQAdBW0osKXCFuUH5\n0opcdZKB5UOWQS9SZF0B6wJRgKVIMzPuEXmGLDWIusAvb3J+uiFGi5wDp/OLR4cr0FPUzSS92Uu4\ncp1uYoi76+y77RCeyGhEJa5fOcVINWR0xwHOXr5G2OhyaM/dRAunabVnCXsJJi4zXtmBPH8SkXwG\nmd6HM3ov6ewr3HfPW2h9/LfZdmgCqS7QbldJzi1SPHgvekXw4Xc/wK6hcX790aNU3IAkTlBakCSa\n8SGf5UYXg8C1JWeuF9i2tEqaxJQKHpnKrU1KxRrf/Z57ePBtCSfev/mCR+s8KHZjbACMegPq2biW\n+y3oPDSk304XoMQYMvt+mvM11q4/h05DbGMQ7EZHHaSJmdy+m+qgZL2xQ
nu9wdriHHvGbXqRZGR0\nKxcsnwvL63xtrcehAcE+qdkWZ+x+4yFGx7dglQc5eP+bWTr2GAud62TC5rm5aZ5YiFmw93H39BQf\n/aNH0Mk93HdoC2FjGW/3BCunTiEHq4zWDKXAZoqJTc2NkA6qt0q8vIrIgO4K9thBKm96HzujEpZM\nKF09TppYNNbaJGmX0k4XvbBIsm0rztgUMtbMlA1JMyHrNpHddYzKsNOQrNNGWR46ClHC4HgBYWOB\noHmFhZcf50RngJ1v+Vlm9uwlWl7FlRFCtFk68kWS8y+zerqLtWOa6O7vovjGj2Aq47T/+k+oDXpI\nx8JkKabdRsQJOAaRGhwRkWmvT0+Atr+N+aVB7BPXuOnAt56H71jw9Lr5EUTK3DlTaYHjBEg03V4D\nYVwsxyWJIpRROLKA7fkkWQdL2rn3wkKX0q//Kpx6EefAbSRikLWVF3AnK4gdW1mtjmOldRKnSN21\n8KMYENSTae4c6SCT5xFxj9TxcYI13rUj4eSUh1+WrNdjJp0E68oaGRZCaqTRaCmwpI0mRTo+pYKL\n9G2KMsiBBZG7iLaj8MZze5NGPMKyMJaLSGNsITHSJlEppSRD1cooAcVMEaHxDDjCopXEvCNd4/n5\niLeOjqGilE8vtigbTSo1thac++znsaTAVgmu4/Lc40/jSShlGUJKWuFVbHsIpRwCK8hzjoRAZ4to\nI2i2uuzesZX1ZpdkrU4mExzHxjGCrusz+5m/o3zyJNUf+QA4AWE/CM8yeTEitSYVeR/cRpCiwGye\ntGzSGGwbaUmsvvmFeA2a8mr8XD7/0myQh02fSA6Vik1aGsBaP8n64nkwEfftEvzQ972PqwsNPvfw\nSzx39DSq2aboOYzUylQCzWK9y+XZFkurHm9aHWTnXZKhcA4hHIyJ87wtuMHH2XB+FjdQH/qPX/sL\n5eoPra28paUFmcq9hTK1yZbWN8VX5NJ9aeWb9fe89wPs3jLG7/7HX8O1BWtrdQrlMkppdmzfwdFj\nR7npppsYGh6g22vxwpFvMDo6wvTUFo6feIW9e/dSKhVRmUPSabC2NMfI2BitqEdzaZVaMEypVKbs\nOVT8ErHrsbraoTk4wNhIAenB0tIczU6LY6dOMz46ysCQC67F+UtXWJpf5O0P3INvNA9/6cuUaiOE\nYRcjJX6xgGVLOmEX2wpwXWfT68aWhjgUdNfB1nkx2MUwe70BrkD6IFwYnhylUA3YP1piS8VhascA\nYXKVARFQ8VOOz2uOnFul2+pgqRi3ZpGaXJ0yUoV2c52BygjadHH8hFiRy3SVzmNmTB4nEDhVipUW\nFTRxYuh0UwpuHoks+kskVeBoQeAaMpXlPj9CYFsgHEOswZOQtRO6DRdTNFQKgmSTXZs07iGtnDcm\nbINqXKI1u8jKya+TJSGrS/NMHPxeLj79EKWiJFyrs5YVsMuS0ZEywt/B6uICywsNDuyusbpcp7HQ\n4ObDtxKEXZrtmJ5dozp8mGuXTjAjrzNUWSBrDzG+9zDBFlhuPoe0PJzAp1IscPHs03RaW9m27xD3\n37qDM5eWeOjCfM7b0yBtuGl7lZcvtvEcC2nZrKuA1fML+H7E4ECVRichTntUCxW2Dgck1YBnH//Y\npteOUrzazuoXORtHkRvP+09Mvw1txIZ4AlT2LnoXLdqN4xQDSbMbgjAMyC6tXpOdkx7Dwz6mHBAm\nMdovMzQ4yJlr67znB9/OL//iz/CpE8P8+v/xLr7xl4+wtn6Nka0F9kxWmb7z9QzecivNtsYtlqnO\nniHesoVj9XGccc354+cIpjuMDipuKrVojW9naWEe2ZunYFfp9K4zOFzCKwV49Ng5sjlTTy0t6IWk\na03sLKL5d7+Pv+sAJo4xjQ614TLt0xGmu0B5YIjWehdneY7Re99Fae/tIBKqZ55AzFlkiUKHPUy3\nicySPFEhU/SEJAiJAAAgAElEQVRMgj71WWzWsLIYu9PCUw5JOkBj7hrJX/4q1nTIlkIFtMRpXmGH\nv4XErNPZsQt7tMILz1zggfJ/Ry9epdFwCdsxU2IRHaek6zFEClMtYOIMkTVRsYMZyA/U1fknWev0\n8D/wu992Hr5zlpZtIYRifWmRoFrFtl1MpomUITYuUguSTpuy52NZFhCjMg+BIEsjuhe7+B/8IeSe\n7ezcGXLBrjDQOoesJUR6iO5KwuLLJ6nesYNi6zqqPogMBI51jOHiKtnKw6RWjVKhh+7l3Jg9OxfZ\nGS1xKNhJkkQ8e9nlutZYroVlF3BtF2FbONKmqyMCE+EFFgjTb3fli9+2BKXAw6AR2iZJN9mXUAmB\nEES2QAuJkgYvTBEqw+2FYDvEvS5CWmS1MqmOcdebLBQ1w2WfYthk/uIcpWKNIAvpuFWssE3JsomT\nKE9h7nYwSiGQZErjC4vMRGTxdaRXZC2RqEzTUV2kkGgybj10kOXVFfbt3kYYxdTX18kMJFbuhI3n\nIy4v0Pr4J/B/7AP4ToGkD/dGUuGR53slymD1k2+9TQar5kMjdD7jG+0iR/AtUulzHGVDFa5Nnpdl\nWxoiaNcvwOLL+MUhyhWHlc4q2cIpJoMBfvIH7+T+O7fyub97jvn1DgvtHt1Usme0gLAh1hkdE9Or\nD5N6irjbxSu4mD4nKX9fGyq9jS+Y12RviRvvUZN7CGlD3008z19LVJ+vtAmE57XqpRsu0QBGMzZQ\nZtuD7+FrX32CY19/lpnhKZIsY2CwQL1e55ZbbiHLMq5fv0YcR9x77+s58sIRtmyZ4M4776TVajE+\nPs75c7NYwqbdi/E7ITM7d1CpjmG7ASvLi6yu1vHcImmcIV2XlWaL6y/Nc2hrFSVmaRjB2RPnuXjp\nEaamtvLWt72D7tIcabeNJxO+5+33cdst+/jT//l52q2QSnX0BkKI0IRRF9/bfEurFwviFgQFMMpg\nI5Dd3My0WoFmy7B31wQ/8ObvZ/3cy8xMDTMyVcWUXMJeQoTHsy9f5tG/eZFGDAcOVHNFaSdDZIpK\n2Sa1YLxQZMuuQa5eUowO2iym64SdPoPLSJRRGKWRlo2wNZ6QuErjeg5CaoTMeTNOfh8hTaFakGgh\n6WYZRSc3JNQasHPzRCUNFcch1TZZpLA2OT25QjaPelk/8RWuP/bnsO2d9NQg7ZWzVAYqnPna5xgZ\nn0TYMNdIEMUAa3WZbldTli+RWePccuhmWqnEUz3Gbr6XuYvH8bhMbdt3kamUE09/iVvf/cN06v8/\nZ28eJdl113l+7n1r7JGR+1pZlbWoVJKqtC+2vBtsyY1t3HQDNk1j3MeHbgaY0wxwmOmZ6WYaOG7A\nNGObZjA0jE23zWbkRbZkydZi7btUmyprzco9M/blrffe+eNlSR7G1jh9z4mT8UecjIhf3Pfub/ku\nTzIsLFSaIoWhEYDTnMIvDlMrj9Ffa3Mpusxm5xVWe3W0GOKd8zm+eNxk/lM2HDkwxMnLLYbLRaJE\nYTsuXtUmb0ZR/Q0a4gzVyggD
... [remainder of the base64-encoded PNG image data for this cell output omitted; nothing recoverable beyond the image itself] ...\n",
-      "text/plain": [
" - ] - }, - "metadata": { - "tags": [] - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "olF4PpORpCTK", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} diff --git a/trax/models/reformer/machine_translation.ipynb b/trax/models/reformer/machine_translation.ipynb deleted file mode 100644 index 55192bf21..000000000 --- a/trax/models/reformer/machine_translation.ipynb +++ /dev/null @@ -1,382 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Reformer: Machine Translation", - "provenance": [], - "collapsed_sections": [ - "udDs_biH0n5U" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "udDs_biH0n5U", - "colab_type": "text" - }, - "source": [ - "#### Copyright 2020 Google LLC." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WPY-OyyM0pSs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "psnUF-8c02o_", - "colab_type": "text" - }, - "source": [ - "# Reformer: Machine Translation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1lnRd_IoERdk", - "colab_type": "text" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8PluCmWbZIpJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install JAX.\n", - "!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .\n", - "!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .\n", - "!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl\n", - "!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl\n", - "\n", - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yiPdBenoZwH6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3\n", - "\n", - "from tensorflow.compat.v1.io.gfile import GFile\n", - "import gin\n", - "import os\n", - "import pickle\n", - "import jax\n", - "import trax\n", - "from trax.models.beam_search import Search\n", - "from trax.supervised import inputs\n", - "\n", - "from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder\n", - "\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "from scipy.special import softmax" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "uCX88z9iXB7s", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install sacreBLEU\n", - "!pip install sacrebleu\n", - "import sacrebleu" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Load WMT14 data" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8S3h28Q9b_9B", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Download the newstest2014 English-to-German translation pairs\n", - "!sacrebleu -t wmt14/full -l en-de --echo src > wmt14-en-de.src\n", - "!sacrebleu -t wmt14/full -l en-de --echo ref > wmt14-en-de.ref" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "CBv2SDnWZEI7", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Load the source text and reference translations into Python\n", - "refs = []\n", - "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):\n", - " if line.endswith('\\n'):\n", - " line = line[:-1]\n", - " refs.append(line)\n", - "srcs = []\n", - "for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):\n", - " if line.endswith('\\n'):\n", - " line = line[:-1]\n", - " srcs.append(line)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "CbYw4eMXZGKa", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Set up our sub-word tokenizer\n", - "tokenizer = SubwordTextEncoder(\n", - " 'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')" 
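As a quick aside on the tokenizer cell above, a minimal round-trip sketch (assuming tensor2tensor is installed and the GCS vocab file is readable; the sample sentence is illustrative):

# Sketch: round-trip a sentence through the SubwordTextEncoder set up above.
from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder

tokenizer = SubwordTextEncoder(
    'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')
ids = tokenizer.encode('A simple sentence to translate.')
print(ids)                    # a list of integer subword ids
print(tokenizer.decode(ids))  # recovers the original sentence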
- ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "2NbOslppZGZ0", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Encode source sentences using the tokenizer\n", - "input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)\n", - "for i, x in enumerate(srcs):\n", - " x = tokenizer.encode(x)\n", - " assert len(x) <= 127\n", - " input_ids[i, :len(x)] = x\n", - " input_ids[i, len(x)] = 1" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YwzU64GmZTb2", - "colab_type": "text" - }, - "source": [ - "## Load the pre-trained model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "VXjtCPxl3I82", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# We'll be using a pre-trained reversible transformer-base model.\n", - "# First, load the config (which sets all needed hyperparameters).\n", - "!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin\n", - "gin.parse_config_file('./config.gin')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "IediBe8MXyLf", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Now we load the pre-trained model weights.\n", - "with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:\n", - " model_weights = pickle.load(f)['weights']" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zY3hpgnI5Rgn", - "colab_type": "text" - }, - "source": [ - "## Beam search decoding" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fc_VlhrBYW0u", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Set up beam search.\n", - "beam_decoder = Search(\n", - " trax.models.Reformer, model_weights,\n", - " beam_size=4,\n", - " alpha=0.6, # For length normalization, set to 0.6 following Vaswani et al.\n", - " eos_id=1, # The stop token has id 1 in the vocabulary we use.\n", - " max_decode_len=146,\n", - " )" - ], - "execution_count": 0, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "bynTpreMYXPs", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 71 - }, - "outputId": "cfd24e01-617b-4beb-a5f2-98a7ce2e1449" - }, - "source": [ - "pred_ids = []\n", - "preds = []\n", - "BATCH_SIZE = 1024\n", - "for start in range(0, input_ids.shape[0], BATCH_SIZE):\n", - " print(start, '/', input_ids.shape[0], flush=True)\n", - " batch = input_ids[start:start+BATCH_SIZE]\n", - " seqs, scores = beam_decoder.decode(batch, batch_size=BATCH_SIZE)\n", - " # Select highest scoring output.\n", - " batch_pred_ids = seqs[:, -1]\n", - " pred_ids.append(batch_pred_ids)\n", - " preds.extend([\n", - " tokenizer.decode(pred.tolist(), strip_extraneous=True)\n", - " for pred in batch_pred_ids\n", - " ])" - ], - "execution_count": 13, - "outputs": [ - { - "output_type": "stream", - "text": [ - "0 / 3003\n", - "1024 / 3003\n", - "2048 / 3003\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "c5Gq4qF_YY2i", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 35 - }, - "outputId": "37a5e24f-9264-4d7a-dd74-065758c9a7e4" - }, - "source": [ - "bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True, tokenize='intl')\n", - "print(bleu)" - ], - "execution_count": 14, - "outputs": [ - { - "output_type": "stream", - "text": [ - "BLEU = 27.86 59.5/33.5/21.3/14.2 (BP = 1.000 ratio = 1.020 hyp_len = 65943 ref_len = 
64676)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "olF4PpORpCTK", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": 0, - "outputs": [] - } - ] -} diff --git a/trax/models/reformer/reformer.py b/trax/models/reformer/reformer.py index 73edeae7a..0101f7a77 100644 --- a/trax/models/reformer/reformer.py +++ b/trax/models/reformer/reformer.py @@ -19,624 +19,776 @@ from trax.fastmath import numpy as jnp from trax.models.research import configurable_transformer as ct - # Layers are always CamelCase, but functions in general are snake_case # pylint: disable=invalid-name -def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, - n_heads, attention_type, dropout, ff_activation, - ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity, - attention_chunk_size, n_attention_layers=1, - n_feedforward_layers=1, center_layernorm=True, - use_bfloat16=False, mode='train'): - """Reversible transformer decoder layer. - - Args: - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - d_attention_key: int: depth of key vector for each attention head - d_attention_value: int: depth of value vector for each attention head - n_heads: int: number of attention heads - attention_type: subclass of tl.BaseCausalAttention: attention class to use - dropout: float: dropout rate (how much to drop out) - ff_activation: the non-linearity in feed-forward layer - ff_dropout: the dropout rate in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - attention_chunk_size: int, if > 0 run attention chunked at this size - n_attention_layers: how many residual causal attention layers should we - have before the feed-forward block (default: 1, the standard block) - n_feedforward_layers: how many FFNN layers should we have (default 1). - center_layernorm: whether to use centering in LayerNorm (default) or if - to skip it, which is known as RMS normalization. - use_bfloat16: whether to use bfloat16 for weights (default: False). - mode: str: 'train' or 'eval' - - - Returns: - the layer. - """ - # pylint: disable=g-complex-comprehension - def _Attn(): - return ct.ApplyAttentionLayer( - attention_type, d_model, n_heads, d_attention_key, - d_attention_value, True, False, dropout, dropout, - attention_chunk_size, mode) - - def _FF(): - return ct.FeedForwardWithOptions( - d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, - ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, - mode, use_bfloat16) - - def _attention_half_residual(): - return [ - tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm), - attention_layer=_Attn(), - name='ReversibleHalfResidualDecoderAttn'), - tl.ReversibleSwap() +def DecoderBlock( + d_model, + d_ff, + d_attention_key, + d_attention_value, + n_heads, + attention_type, + dropout, + ff_activation, + ff_dropout, + ff_use_sru, + ff_chunk_size, + ff_sparsity, + attention_chunk_size, + n_attention_layers=1, + n_feedforward_layers=1, + center_layernorm=True, + use_bfloat16=False, + mode="train", +): + """Reversible transformer decoder layer. 
+ + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_heads: int: number of attention heads + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + ff_activation: the non-linearity in feed-forward layer + ff_dropout: the dropout rate in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + attention_chunk_size: int, if > 0 run attention chunked at this size + n_attention_layers: how many residual causal attention layers should we + have before the feed-forward block (default: 1, the standard block) + n_feedforward_layers: how many FFNN layers should we have (default 1). + center_layernorm: whether to use centering in LayerNorm (default) or if + to skip it, which is known as RMS normalization. + use_bfloat16: whether to use bfloat16 for weights (default: False). + mode: str: 'train' or 'eval' + + + Returns: + the layer. + """ + + # pylint: disable=g-complex-comprehension + def _Attn(): + return ct.ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_attention_key, + d_attention_value, + True, + False, + dropout, + dropout, + attention_chunk_size, + mode, + ) + + def _FF(): + return ct.FeedForwardWithOptions( + d_model, + d_ff, + dropout, + [-2], + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + center_layernorm, + mode, + use_bfloat16, + ) + + def _attention_half_residual(): + return [ + tl.ReversibleHalfResidual( + tl.LayerNorm(center=center_layernorm), + attention_layer=_Attn(), + name="ReversibleHalfResidualDecoderAttn", + ), + tl.ReversibleSwap(), + ] + + def _feed_forward(): + return [ + tl.ReversibleHalfResidual(_FF(), name="ReversibleHalfResidualDecoderFF"), + tl.ReversibleSwap(), + ] + + return [_attention_half_residual() for _ in range(n_attention_layers)] + [ + _feed_forward() for _ in range(n_feedforward_layers) ] - def _feed_forward(): - return [ - tl.ReversibleHalfResidual(_FF(), - name='ReversibleHalfResidualDecoderFF'), - tl.ReversibleSwap() + +def ReformerLM( + vocab_size, + d_model=512, + d_ff=2048, + d_attention_key=64, + d_attention_value=64, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + attention_type=tl.SelfAttention, + pos_type=None, + pos_axial_shape=(), + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + ff_activation=tl.FastGelu, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + mode="train", +): + """Reversible transformer language model (only uses a decoder, no encoder). + + Args: + vocab_size: int: vocab size + d_model: int: depth of *each half* of the two-part features + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + attention_type: class: attention class to use, such as SelfAttention. 
+ pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + ff_activation: the non-linearity in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + loss_sparsity_type: str, type of sparsity to used in loss layer. See + SparseDenseWithOptions for options. None if no sparsity should be used. + loss_sparsity: int, the sparsity for loss layer (if used) + loss_d_lowrank: int, the dimensions for intermediate layer (if used) + loss_sparsity_prob: float, the probability for sparse version of loss to be + used. If None, only sparse version is used. + attention_chunk_size: int, if > 0 run attention chunked at this size + mode: str: 'train', 'eval', or 'predict' + + Returns: + the layer. + """ + positional_encoding = ct.PositionalEncoder( + mode, + dropout, + max_len, + pos_type, + pos_axial_shape, + pos_d_axial_embs, + pos_start_from_zero_prob, + pos_max_offset_to_add, + ) + + positional_embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + positional_encoding, ] - return ([_attention_half_residual() for _ in range(n_attention_layers)] - + [_feed_forward() for _ in range(n_feedforward_layers)]) - - -def ReformerLM(vocab_size, - d_model=512, - d_ff=2048, - d_attention_key=64, - d_attention_value=64, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - attention_type=tl.SelfAttention, - pos_type=None, - pos_axial_shape=(), - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - ff_activation=tl.FastGelu, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0, - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - mode='train'): - """Reversible transformer language model (only uses a decoder, no encoder). - - Args: - vocab_size: int: vocab size - d_model: int: depth of *each half* of the two-part features - d_ff: int: depth of feed-forward layer - d_attention_key: int: depth of key vector for each attention head - d_attention_value: int: depth of value vector for each attention head - n_layers: int: number of decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - attention_type: class: attention class to use, such as SelfAttention. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. 
- pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - pos_max_offset_to_add: maximum offset to add to positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - ff_activation: the non-linearity in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - loss_sparsity_type: str, type of sparsity to used in loss layer. See - SparseDenseWithOptions for options. None if no sparsity should be used. - loss_sparsity: int, the sparsity for loss layer (if used) - loss_d_lowrank: int, the dimensions for intermediate layer (if used) - loss_sparsity_prob: float, the probability for sparse version of loss to be - used. If None, only sparse version is used. - attention_chunk_size: int, if > 0 run attention chunked at this size - mode: str: 'train', 'eval', or 'predict' - - Returns: - the layer. - """ - positional_encoding = ct.PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs, - pos_start_from_zero_prob, pos_max_offset_to_add) - - positional_embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - positional_encoding, - ] - - decoder_blocks = [] - - if isinstance(attention_type, (tuple, list)): - assert n_layers % len(attention_type) == 0 - else: - attention_type = [attention_type] - for layer_idx in range(n_layers): - layer_attention_type = attention_type[layer_idx % len(attention_type)] - decoder_block = DecoderBlock( - d_model, d_ff, d_attention_key, d_attention_value, n_heads, - attention_type=layer_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - mode=mode) - decoder_blocks.append(decoder_block) - - dense_loss_layer = tl.SparseDenseWithOptions( - vocab_size, - d_input=d_model, - sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, - d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, - mode=mode) - - return tl.Serial( - tl.ShiftRight(mode=mode), - positional_embedder, - tl.Dup(), - tl.ReversibleSerial(decoder_blocks), - tl.Concatenate(), - # TODO(kitaev): Test whether dropout should go before or after the - # LayerNorm, and whether dropout broadcasting is needed here. - tl.LayerNorm(), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - dense_loss_layer, - ) - - -def ReformerShortenLM(vocab_size, - shorten_factor=1, - d_embedding=256, - d_model=512, - d_ff=2048, - d_attention_key=64, - d_attention_value=64, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - attention_type=tl.SelfAttention, - pos_type=None, - pos_axial_shape=(), - pos_d_axial_embs=None, - ff_activation=tl.FastGelu, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0, - attention_chunk_size=0, - mode='train'): - """Reversible transformer language model with shortening. 
- - When shorten_factor is F and processing an input of shape [batch, length], - we embed the (shifted-right) input and then group each F elements (on length) - into a single vector -- so that in the end we process a tensor of shape :: - - [batch, length // F, d_model] - - almost until the end -- at the end it's un-shortend and a SRU is applied. - This reduces the length processed inside the main model body, effectively - making the model faster but possibly slightly less accurate. - - Args: - vocab_size: int: vocab size - shorten_factor: by how much to shorten, see above - d_embedding: the depth of the embedding layer and final logits - d_model: int: depth of *each half* of the two-part features - d_ff: int: depth of feed-forward layer - d_attention_key: int: depth of key vector for each attention head - d_attention_value: int: depth of value vector for each attention head - n_layers: int: number of decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - attention_type: class: attention class to use, such as SelfAttention. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, values must sum to d_embedding. - ff_activation: the non-linearity in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - attention_chunk_size: int, if > 0 run attention chunked at this size - mode: str: 'train' or 'eval' - - Returns: - the layer. - """ - assert mode != 'predict' # TODO(lukaszkaiser,kitaev): fast inference - - positional_encoding = ct.PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs) - - positional_embedder = [ - tl.Embedding(vocab_size, d_embedding), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - positional_encoding, - ] - - decoder_blocks = [] - - if isinstance(attention_type, (tuple, list)): - assert n_layers % len(attention_type) == 0 - else: - attention_type = [attention_type] - for layer_idx in range(n_layers): - layer_attention_type = attention_type[layer_idx % len(attention_type)] - decoder_block = DecoderBlock( - d_model, d_ff, d_attention_key, d_attention_value, n_heads, - attention_type=layer_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - mode=mode) - decoder_blocks.append(decoder_block) - - # pylint: disable=g-long-lambda - return tl.Serial( - tl.ShiftRight(), - positional_embedder, - tl.Dup(), # Stack has (x, x), the first will be shortened - # Before shortening, we need to pad by shorten factor so as not to leak - # information into the future. To understand why, imagine shorten factor - # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we - # would have 0ABC, which gets grouped to [0A][BC] on input, which is - # predicting ABCD as targets. 
The problem is that [0A] has access to A - # and [BC] has access to C -- it will learn to copy it, peek into - # the future. Shifting twice to [00][AB] solves the problem as the first - # "big" symbol becomes all-0 and the rest is shifted enough. - tl.ShiftRight(n_positions=shorten_factor - 1), - tl.Fn('Shorten', lambda x: jnp.reshape( # Shorten -- move to depth. - x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1), - tl.Dense(d_model), - tl.Dup(), # Stack has (short_x, short_x, x) - tl.ReversibleSerial(decoder_blocks), - tl.Select([0], n_in=2), - tl.LayerNorm(), - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), # pylint: disable=no-value-for-parameter - tl.Dense(shorten_factor * d_embedding), - tl.Fn('ProlongBack', lambda x: jnp.reshape( # Prolong back. - x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1), - tl.Concatenate(), # Concatenate with just the embeddings. - tl.CausalConv(d_embedding), - tl.Relu(), - tl.SRU(d_embedding), # One RNN layer for conditional dependence. - tl.Dense(vocab_size), - ) - # pylint: enable=g-long-lambda - - -def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation, - ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0, - attention_chunk_size=0, center_layernorm=True, - use_bfloat16=False, use_two_swaps_per_block=True, - mode='train'): - """Returns a list of layers that implements a Reformer encoder block. - - The input to the layer is a pair, (activations, mask), where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. - - Args: - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_heads: int: number of attention heads - attention_type: subclass of tl.BaseCausalAttention: attention class to use - dropout: float: dropout rate (how much to drop out) - ff_activation: the non-linearity in feed-forward layer - ff_dropout: the dropout rate in feed-forward layer - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - attention_chunk_size: int, if > 0 run attention chunked at this size - center_layernorm: whether to use centering in LayerNorm (default) or if - to skip it, which is known as RMS normalization. - use_bfloat16: whether to use bfloat16 for weights (default: False) - use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder - block, otherwise use only one swap. - mode: str: 'train' or 'eval' - - Returns: - A list of layers that maps (activations, mask) to (activations, mask). - """ - if mode == 'predict': - # Mode 'predict' means that the decoder should be run one token at a time. - # The encoder only ever runs over full sequences, which is why it's switched - # to 'eval' mode instead. 
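A small numpy sketch of the shift-and-group argument in the ReformerShortenLM comment above (a sketch only; F = shorten_factor = 2, and integer values stand in for tokens):

import numpy as np

F = 2                          # shorten_factor
x = np.arange(4)               # stands for 0 A B C: the input after the first
                               # ShiftRight(), whose targets are A B C D
naive = x.reshape(-1, F)       # [[0 A], [B C]]: group [B C] already sees C,
                               # a token it is still supposed to predict
shifted = np.concatenate([np.zeros(F - 1, dtype=x.dtype), x[:-(F - 1)]])
safe = shifted.reshape(-1, F)  # [[0 0], [A B]]: after ShiftRight(F - 1), no
                               # group contains a token it must still predict
print(naive.tolist(), safe.tolist())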
- mode = 'eval' - - def _Attn(): - return ct.ApplyAttentionLayer( - attention_type=attention_type, d_model=d_model, n_heads=n_heads, - d_qk=d_model//n_heads, d_v=d_model//n_heads, masked=True, causal=False, - attention_dropout=dropout, output_dropout=dropout, - attention_chunk_size=attention_chunk_size, mode=mode) - - def _FF(): - return ct.FeedForwardWithOptions( - d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, - ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, - mode, use_bfloat16) - - # TODO(lukaszkaiser): refactor efficient attention layers to unify the API - # If we're using standard attention, we need to pass reshaped mask and not - # return the mask to be compatible with the EfficientAttention API. - attention = _Attn() - if attention.n_out == 2: - attention = tl.Serial( - tl.Parallel([], _InsertAxes12()), - attention, - tl.Select([0], n_in=2) + decoder_blocks = [] + + if isinstance(attention_type, (tuple, list)): + assert n_layers % len(attention_type) == 0 + else: + attention_type = [attention_type] + for layer_idx in range(n_layers): + layer_attention_type = attention_type[layer_idx % len(attention_type)] + decoder_block = DecoderBlock( + d_model, + d_ff, + d_attention_key, + d_attention_value, + n_heads, + attention_type=layer_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + decoder_blocks.append(decoder_block) + + dense_loss_layer = tl.SparseDenseWithOptions( + vocab_size, + d_input=d_model, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + mode=mode, ) - def _attention_half_residual(): - return [ - tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm), - attention_layer=attention, - name='ReversibleHalfResidualEncoderAttn'), - tl.ReversibleSwap() + return tl.Serial( + tl.ShiftRight(mode=mode), + positional_embedder, + tl.Dup(), + tl.ReversibleSerial(decoder_blocks), + tl.Concatenate(), + # # TODO(kitaev): Test whether dropout should go before or after the + # LayerNorm, and whether dropout broadcasting is needed here. + tl.LayerNorm(), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + dense_loss_layer, + ) + + +def ReformerShortenLM( + vocab_size, + shorten_factor=1, + d_embedding=256, + d_model=512, + d_ff=2048, + d_attention_key=64, + d_attention_value=64, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + attention_type=tl.SelfAttention, + pos_type=None, + pos_axial_shape=(), + pos_d_axial_embs=None, + ff_activation=tl.FastGelu, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, + attention_chunk_size=0, + mode="train", +): + """Reversible transformer language model with shortening. + + When shorten_factor is F and processing an input of shape [batch, length], + we embed the (shifted-right) input and then group each F elements (on length) + into a single vector -- so that in the end we process a tensor of shape :: + + [batch, length // F, d_model] + + almost until the end -- at the end it's un-shortend and a SRU is applied. + This reduces the length processed inside the main model body, effectively + making the model faster but possibly slightly less accurate. 
+ + Args: + vocab_size: int: vocab size + shorten_factor: by how much to shorten, see above + d_embedding: the depth of the embedding layer and final logits + d_model: int: depth of *each half* of the two-part features + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + attention_type: class: attention class to use, such as SelfAttention. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, values must sum to d_embedding. + ff_activation: the non-linearity in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + attention_chunk_size: int, if > 0 run attention chunked at this size + mode: str: 'train' or 'eval' + + Returns: + the layer. + """ + assert mode != "predict" # TODO(lukaszkaiser,kitaev): fast inference + + positional_encoding = ct.PositionalEncoder( + mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs + ) + + positional_embedder = [ + tl.Embedding(vocab_size, d_embedding), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + positional_encoding, ] - def _feed_forward(): - layers = [ - tl.ReversibleHalfResidual(_FF(), - name='ReversibleHalfResidualEncoderFF') + decoder_blocks = [] + + if isinstance(attention_type, (tuple, list)): + assert n_layers % len(attention_type) == 0 + else: + attention_type = [attention_type] + for layer_idx in range(n_layers): + layer_attention_type = attention_type[layer_idx % len(attention_type)] + decoder_block = DecoderBlock( + d_model, + d_ff, + d_attention_key, + d_attention_value, + n_heads, + attention_type=layer_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + decoder_blocks.append(decoder_block) + + # pylint: disable=g-long-lambda + return tl.Serial( + tl.ShiftRight(), + positional_embedder, + tl.Dup(), # Stack has (x, x), the first will be shortened + # Before shortening, we need to pad by shorten factor so as not to leak + # information into the future. To understand why, imagine shorten factor + # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we + # would have 0ABC, which gets grouped to [0A][BC] on input, which is + # predicting ABCD as targets. The problem is that [0A] has access to A + # and [BC] has access to C -- it will learn to copy it, peek into + # the future. Shifting twice to [00][AB] solves the problem as the first + # "big" symbol becomes all-0 and the rest is shifted enough. + tl.ShiftRight(n_positions=shorten_factor - 1), + tl.Fn( + "Shorten", + lambda x: jnp.reshape( # Shorten -- move to depth. 
+ x, (x.shape[0], x.shape[1] // shorten_factor, -1) + ), + n_out=1, + ), + tl.Dense(d_model), + tl.Dup(), # Stack has (short_x, short_x, x) + tl.ReversibleSerial(decoder_blocks), + tl.Select([0], n_in=2), + tl.LayerNorm(), + tl.Dropout( + rate=dropout, shared_axes=[-2], mode=mode + ), # pylint: disable=no-value-for-parameter + tl.Dense(shorten_factor * d_embedding), + tl.Fn( + "ProlongBack", + lambda x: jnp.reshape( # Prolong back. + x, (x.shape[0], x.shape[1] * shorten_factor, -1) + ), + n_out=1, + ), + tl.Concatenate(), # Concatenate with just the embeddings. + tl.CausalConv(d_embedding), + tl.Relu(), + tl.SRU(d_embedding), # One RNN layer for conditional dependence. + tl.Dense(vocab_size), + ) + # pylint: enable=g-long-lambda + + +def EncoderBlock( + d_model, + d_ff, + n_heads, + attention_type, + dropout, + ff_activation, + ff_dropout, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, + attention_chunk_size=0, + center_layernorm=True, + use_bfloat16=False, + use_two_swaps_per_block=True, + mode="train", +): + """Returns a list of layers that implements a Reformer encoder block. + + The input to the layer is a pair, (activations, mask), where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. + + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + ff_activation: the non-linearity in feed-forward layer + ff_dropout: the dropout rate in feed-forward layer + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + attention_chunk_size: int, if > 0 run attention chunked at this size + center_layernorm: whether to use centering in LayerNorm (default) or if + to skip it, which is known as RMS normalization. + use_bfloat16: whether to use bfloat16 for weights (default: False) + use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder + block, otherwise use only one swap. + mode: str: 'train' or 'eval' + + Returns: + A list of layers that maps (activations, mask) to (activations, mask). + """ + if mode == "predict": + # Mode 'predict' means that the decoder should be run one token at a time. + # The encoder only ever runs over full sequences, which is why it's switched + # to 'eval' mode instead. + mode = "eval" + + def _Attn(): + return ct.ApplyAttentionLayer( + attention_type=attention_type, + d_model=d_model, + n_heads=n_heads, + d_qk=d_model // n_heads, + d_v=d_model // n_heads, + masked=True, + causal=False, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + + def _FF(): + return ct.FeedForwardWithOptions( + d_model, + d_ff, + dropout, + [-2], + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + center_layernorm, + mode, + use_bfloat16, + ) + + # TODO(lukaszkaiser): refactor efficient attention layers to unify the API + # If we're using standard attention, we need to pass reshaped mask and not + # return the mask to be compatible with the EfficientAttention API. 
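A shape-level sketch of what this mask reshaping amounts to (sizes illustrative; the _InsertAxes12 helper defined at the bottom of this file applies the same reshape with jnp):

import numpy as np

mask = np.ones((2, 8))  # encoder padding mask: [batch, len]
expanded = np.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
print(expanded.shape)   # (2, 1, 1, 8): broadcastable across attention heads
                        # and query positions, as standard attention expects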
+ attention = _Attn() + if attention.n_out == 2: + attention = tl.Serial( + tl.Parallel([], _InsertAxes12()), attention, tl.Select([0], n_in=2) + ) + + def _attention_half_residual(): + return [ + tl.ReversibleHalfResidual( + tl.LayerNorm(center=center_layernorm), + attention_layer=attention, + name="ReversibleHalfResidualEncoderAttn", + ), + tl.ReversibleSwap(), + ] + + def _feed_forward(): + layers = [ + tl.ReversibleHalfResidual(_FF(), name="ReversibleHalfResidualEncoderFF") + ] + if use_two_swaps_per_block: + layers.append(tl.ReversibleSwap()) + return layers + + return _attention_half_residual() + _feed_forward() + + +def EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + ff_activation, + ff_dropout, + mode, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, +): + """Reversible transformer decoder layer. + + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + ff_activation: the non-linearity in feed-forward layer + ff_dropout: float: (optional) separate dropout rate for feed-forward layer + mode: str: 'train' or 'eval' + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + + Returns: + the layer. + """ + enc_dec_attention = tl.EncDecAttention( + n_heads=n_heads, + d_qk=d_model // n_heads, + d_v=d_model // n_heads, + attention_dropout=dropout, + output_dropout=dropout, + mode=mode, + ) + enc_dec_attention_half_residual = tl.ReversibleHalfResidual( + tl.LayerNorm(), + attention_layer=enc_dec_attention, + ) + + causal_attention = tl.SelfAttention( + n_heads=n_heads, + d_qk=d_model // n_heads, + d_v=d_model // n_heads, + causal=True, + attention_dropout=dropout, + output_dropout=dropout, + mode=mode, + ) + causal_attention_half_residual = tl.ReversibleHalfResidual( + tl.LayerNorm(), + attention_layer=causal_attention, + ) + + feed_forward = ct.FeedForwardWithOptions( + d_model, + d_ff, + dropout, + [-2], + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, + ) + + return [ # vec_d1 vec_d2 vec_e masks + causal_attention_half_residual, + tl.ReversibleSwap(), + enc_dec_attention_half_residual, + tl.ReversibleSwap(), + tl.ReversibleHalfResidual(feed_forward), + tl.ReversibleSwap(), + ] + + +def Reformer( + input_vocab_size, + output_vocab_size=None, + d_model=512, + d_ff=2048, + n_encoder_layers=6, + n_decoder_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + ff_activation=tl.Relu, + ff_dropout=None, + mode="train", + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + ff_use_sru=0, + ff_chunk_size=0, + ff_sparsity=0, +): + """Reversible transformer encoder-decoder model. + + This model expects an input pair: target, source. + + At the moment, this model supports dot-product attention only. For the + attention types in the Reformer paper, see ReformerLM. + + Args: + input_vocab_size: int: vocab size of the source. + output_vocab_size: int (optional): vocab size of the target. If None, the + source and target are assumed to have the same vocab. 
+ d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_encoder_layers: int: number of encoder layers + n_decoder_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + ff_activation: the non-linearity in feed-forward layer + ff_dropout: float: (optional) separate dropout rate at feed-forward + nonlinearity. This is called relu_dropout in T2T. + mode: str: 'train' or 'eval' + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + + Returns: + A Reformer model as a layer that maps from a target, source pair to + activations over a vocab set. + """ + in_encoder, out_encoder, output_vocab_size = ct.EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + mode, + dropout, + [-2], # dropout_shared_axes + max_len, + output_vocab_size=output_vocab_size, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + ) + + # pylint: disable=g-complex-comprehension + encoder_blocks = [ + EncoderBlock( + d_model, + d_ff, + n_heads, + tl.SelfAttention, + dropout, + ff_activation, + ff_dropout, + mode=mode, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + ) + for _ in range(n_encoder_layers) ] - if use_two_swaps_per_block: - layers.append(tl.ReversibleSwap()) - return layers - - return _attention_half_residual() + _feed_forward() - - -def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, - ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0, - ff_sparsity=0): - """Reversible transformer decoder layer. - - Args: - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - ff_activation: the non-linearity in feed-forward layer - ff_dropout: float: (optional) separate dropout rate for feed-forward layer - mode: str: 'train' or 'eval' - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - - Returns: - the layer. 
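For orientation, a minimal forward pass through this encoder-decoder, patterned on the init/signature style of the deleted reformer_test.py further below; all sizes are illustrative, and the (source, target) input order follows the stack comments in the model body:

import numpy as np
from trax import shapes
from trax.models.reformer import reformer

model = reformer.Reformer(
    input_vocab_size=32, d_model=32, d_ff=64,
    n_encoder_layers=1, n_decoder_layers=1, n_heads=2, max_len=16)
xs = [np.ones((1, 8), dtype=np.int32),   # source tokens
      np.ones((1, 8), dtype=np.int32)]   # target tokens
model.init(shapes.signature(xs))
logits, _ = model(xs)                    # the target copy passes through for the loss
print(logits.shape)                      # expected: (1, 8, 32)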
- """ - enc_dec_attention = tl.EncDecAttention( - n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads, - attention_dropout=dropout, output_dropout=dropout, - mode=mode) - enc_dec_attention_half_residual = tl.ReversibleHalfResidual( - tl.LayerNorm(), - attention_layer=enc_dec_attention, - ) - - causal_attention = tl.SelfAttention( - n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads, - causal=True, - attention_dropout=dropout, output_dropout=dropout, - mode=mode) - causal_attention_half_residual = tl.ReversibleHalfResidual( - tl.LayerNorm(), - attention_layer=causal_attention, - ) - - feed_forward = ct.FeedForwardWithOptions( - d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, - ff_chunk_size, ff_use_sru, ff_sparsity, True, mode) - - return [ # vec_d1 vec_d2 vec_e masks - causal_attention_half_residual, - tl.ReversibleSwap(), - enc_dec_attention_half_residual, - tl.ReversibleSwap(), - tl.ReversibleHalfResidual(feed_forward), - tl.ReversibleSwap(), - ] - - -def Reformer(input_vocab_size, - output_vocab_size=None, - d_model=512, - d_ff=2048, - n_encoder_layers=6, - n_decoder_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - ff_activation=tl.Relu, - ff_dropout=None, - mode='train', - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0): - """Reversible transformer encoder-decoder model. - - This model expects an input pair: target, source. - - At the moment, this model supports dot-product attention only. For the - attention types in the Reformer paper, see ReformerLM. - - Args: - input_vocab_size: int: vocab size of the source. - output_vocab_size: int (optional): vocab size of the target. If None, the - source and target are assumed to have the same vocab. - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_encoder_layers: int: number of encoder layers - n_decoder_layers: int: number of decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - ff_activation: the non-linearity in feed-forward layer - ff_dropout: float: (optional) separate dropout rate at feed-forward - nonlinearity. This is called relu_dropout in T2T. - mode: str: 'train' or 'eval' - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - - Returns: - A Reformer model as a layer that maps from a target, source pair to - activations over a vocab set. 
- """ - in_encoder, out_encoder, output_vocab_size = ( - ct.EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model, - mode, - dropout, - [-2], # dropout_shared_axes - max_len, - output_vocab_size=output_vocab_size, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs) - ) - - # pylint: disable=g-complex-comprehension - encoder_blocks = [ - EncoderBlock( - d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation, - ff_dropout, mode=mode, ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity) - for _ in range(n_encoder_layers)] - # pylint: enable=g-complex-comprehension - - encoder = tl.Serial([ - in_encoder, - tl.Dup(), - tl.ReversibleSerial(encoder_blocks), - _XYAvg(), - tl.LayerNorm(), - ]) - if mode == 'predict': - encoder = tl.Cache(encoder) - - # pylint: disable=g-complex-comprehension - encoder_decoder_blocks = [ - EncoderDecoderBlock( - d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode, - ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity) - for _ in range(n_decoder_layers)] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( - # Input: encoder_side_tokens, decoder_side_tokens - # Copy decoder tokens for use in loss. - tl.Select([0, 1, 1]), # tok_e tok_d tok_d - tl.Branch([], [tl.PaddingMask(), - _RemoveAxes12()]), # tok_e mask tok_d ..... - - # Encode. - encoder, # vec_e mask tok_d ..... - - # Decode. - tl.Select([2, 0, 1]), # tok_d vec_e mask ..... - tl.ShiftRight(mode=mode), # tok_d vec_e mask ..... - out_encoder, # vec_d vec_e mask ..... - tl.Dup(), # vec_d1 vec_d2 vec_e mask ..... - tl.ReversibleSerial(encoder_decoder_blocks), - _XYAvg(), # vec_d vec_e mask ..... - tl.LayerNorm(), # vec_d vec_e mask ..... - - # Map to output vocab. - tl.Select([0], n_in=3), # vec_d ..... - tl.Dense(output_vocab_size), # vec_d ..... - ) + # pylint: enable=g-complex-comprehension + + encoder = tl.Serial( + [ + in_encoder, + tl.Dup(), + tl.ReversibleSerial(encoder_blocks), + _XYAvg(), + tl.LayerNorm(), + ] + ) + if mode == "predict": + encoder = tl.Cache(encoder) + + # pylint: disable=g-complex-comprehension + encoder_decoder_blocks = [ + EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + ff_activation, + ff_dropout, + mode, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + ) + for _ in range(n_decoder_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( + # Input: encoder_side_tokens, decoder_side_tokens + # Copy decoder tokens for use in loss. + tl.Select([0, 1, 1]), # tok_e tok_d tok_d + tl.Branch([], [tl.PaddingMask(), _RemoveAxes12()]), # tok_e mask tok_d ..... + # Encode. + encoder, # vec_e mask tok_d ..... + # Decode. + tl.Select([2, 0, 1]), # tok_d vec_e mask ..... + tl.ShiftRight(mode=mode), # tok_d vec_e mask ..... + out_encoder, # vec_d vec_e mask ..... + tl.Dup(), # vec_d1 vec_d2 vec_e mask ..... + tl.ReversibleSerial(encoder_decoder_blocks), + _XYAvg(), # vec_d vec_e mask ..... + tl.LayerNorm(), # vec_d vec_e mask ..... + # Map to output vocab. + tl.Select([0], n_in=3), # vec_d ..... + tl.Dense(output_vocab_size), # vec_d ..... 
+    )


 def _InsertAxes12():
-  """Returns a layer that inserts two internal size-1 axes into an array."""
-  return tl.Fn('InsertAxes12',
-               lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])))
+    """Returns a layer that inserts two internal size-1 axes into an array."""
+    return tl.Fn(
+        "InsertAxes12", lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+    )


 def _RemoveAxes12():
-  """Returns a layer that removes two internal size-1 axes from an array."""
-  return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2)))
+    """Returns a layer that removes two internal size-1 axes from an array."""
+    return tl.Fn("RemoveAxes12", lambda x: jnp.squeeze(x, (1, 2)))


 def _AsTokenIDs():
-  """Returns a layer that makes mask values look like token ID ints."""
-  return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32))
+    """Returns a layer that makes mask values look like token ID ints."""
+    return tl.Fn("AsTokenIDs", lambda x: x.astype(jnp.int32))


 def _XYAvg():
-  """Returns a layer that computes the element-wise average of two arrays."""
-  return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)
+    """Returns a layer that computes the element-wise average of two arrays."""
+    return tl.Fn("XYAvg", lambda x, y: (x + y) / 2.0)


 def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):
-  """ReversibleSerial but with a forgetting block every n_layers."""
-  if not n_layers or len(layers) <= n_layers + 1:
-    return tl.ReversibleSerial(layers)
-  layers1, layers2 = layers[:n_layers], layers[n_layers:]
-
-  if forget_dense:
-    forgetting_layer = tl.Serial(
-      _XYAvg(),
-      tl.Dense(d_model),
-      tl.Dup(),
+    """ReversibleSerial but with a forgetting block every n_layers."""
+    if not n_layers or len(layers) <= n_layers + 1:
+        return tl.ReversibleSerial(layers)
+    layers1, layers2 = layers[:n_layers], layers[n_layers:]
+
+    if forget_dense:
+        forgetting_layer = tl.Serial(
+            _XYAvg(),
+            tl.Dense(d_model),
+            tl.Dup(),
+        )
+    else:
+        forgetting_layer = tl.Select([0, 1])
+
+    return tl.Serial(
+        tl.ReversibleSerial(layers1),
+        forgetting_layer,
+        _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense),
     )
-  else:
-    forgetting_layer = tl.Select([0, 1])
-
-  return tl.Serial(
-      tl.ReversibleSerial(layers1),
-      forgetting_layer,
-      _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense)
-  )


 def _ConvertToNaNsOnAnyZero():
-  def _convert_to_nans(x, y):
-    # if all values in y are non-zeros, return x; otherwise return 0s
-    return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y
-  return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2)
+    def _convert_to_nans(x, y):
+        # If all values in y are nonzero, return x unchanged; otherwise divide
+        # by zero so the output becomes NaNs/Infs (not zeros).
+        return jnp.where(jnp.all(y, keepdims=False), x, x / 0.0), y
+
+    return tl.Fn("ConvertToNaNsOnAnyZero", _convert_to_nans, n_out=2)
diff --git a/trax/models/reformer/reformer_e2e_test.py b/trax/models/reformer/reformer_e2e_test.py
deleted file mode 100644
index 57b180353..000000000
--- a/trax/models/reformer/reformer_e2e_test.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""End to end test for Reformer.""" - -import os - -from absl.testing import absltest -import gin - -from trax import test_utils -from trax.models.reformer import reformer # pylint: disable=unused-import -from trax.supervised import trainer_lib -from trax.tf_numpy import numpy as tf_np # pylint: disable=unused-import - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') -_CONFIG_DIR = os.path.join(pkg_dir, '../../supervised/configs/') - - -class ReformerE2ETest(absltest.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - gin.add_config_file_search_path(_CONFIG_DIR) - test_utils.ensure_flag('test_tmpdir') - - def test_reformer_wmt_ende(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('reformer_wmt_ende.gin') - - gin.bind_parameter('data_streams.data_dir', _TESTDATA) - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('Reformer.n_encoder_layers', n_layers) - gin.bind_parameter('Reformer.n_decoder_layers', n_layers) - gin.bind_parameter('Reformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - def test_reformer_copy(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - d_model = 32 - - gin.parse_config_file('reformer_copy.gin') - - gin.bind_parameter('data_streams.data_dir', _TESTDATA) - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ReformerLM.n_layers', n_layers) - gin.bind_parameter('ReformerLM.d_ff', d_ff) - gin.bind_parameter('ReformerLM.d_model', d_model) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/reformer/reformer_test.py b/trax/models/reformer/reformer_test.py deleted file mode 100644 index 5a1fce949..000000000 --- a/trax/models/reformer/reformer_test.py +++ /dev/null @@ -1,126 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for Reformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.models.reformer import reformer - - -BACKENDS = [fastmath.Backend.JAX] - - -def short_name(b): - if b == fastmath.Backend.JAX: - return 'jax' - else: - return 'tf' - - -class ReformerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def _timebin_self_attention_fn(self, use_reference_code=False): - return functools.partial( - tl.SelfAttention, - attention_dropout=0.05, - chunk_len=64, - n_chunks_before=1, - n_parallel_heads=1, - use_reference_code=use_reference_code - ) - - def test_reformer_lm_forward_shape(self): - vocab_size = 16 - model = reformer.ReformerLM( - vocab_size, d_model=32, d_ff=64, d_attention_key=16, - d_attention_value=16, n_layers=1, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - - def test_reformer_lm_lsh(self): - lsh_self_attention = self._lsh_self_attention_fn() - timebin_self_attention = self._timebin_self_attention_fn() - - model = reformer.ReformerLM( - vocab_size=256, - d_model=256, - d_ff=512, - d_attention_key=64, - d_attention_value=64, - n_layers=2, - n_heads=2, - dropout=0.05, - max_len=65536, - attention_type=[timebin_self_attention, lsh_self_attention], - pos_axial_shape=(256, 256), - pos_d_axial_embs=(64, 192), - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=8192, - mode='train', - ) - x = np.ones((1, 65536)).astype(np.int32) - weights, state = model.init(shapes.signature(x)) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits, new_state = model.pure_fn(x, weights, state, rng) - loss = fastmath.numpy.mean(logits[..., 0]) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - weights, state, logits = mock_training_step( - x, weights, state, fastmath.random.get_prng(0)) - self.assertEqual(logits.shape, (1, 65536, 256)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 b/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 deleted file mode 100644 index 271d5aeae..000000000 Binary files a/trax/models/reformer/testdata/translate_ende_wmt32k-dev-00000-of-00001 and /dev/null differ diff --git a/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 b/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 deleted file mode 100644 index ed977fc71..000000000 Binary files a/trax/models/reformer/testdata/translate_ende_wmt32k-train-00000-of-00001 and /dev/null differ diff --git a/trax/models/reformer/text_generation.ipynb b/trax/models/reformer/text_generation.ipynb deleted file mode 100644 index 
5b67721b0..000000000 --- a/trax/models/reformer/text_generation.ipynb +++ /dev/null @@ -1,548 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Reformer: Text Generation", - "provenance": [], - "collapsed_sections": [ - "udDs_biH0n5U" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "TPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "udDs_biH0n5U", - "colab_type": "text" - }, - "source": [ - "#### Copyright 2020 Google LLC." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WPY-OyyM0pSs", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Licensed under the Apache License, Version 2.0 (the \"License\")\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "psnUF-8c02o_", - "colab_type": "text" - }, - "source": [ - "# Reformer: Text Generation [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1lnRd_IoERdk", - "colab_type": "text" - }, - "source": [ - "This notebook was designed to run on TPU.\n", - "\n", - "To use TPUs in Colab, click \"Runtime\" on the main menu bar and select Change runtime type. Set \"TPU\" as the hardware accelerator." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8PluCmWbZIpJ", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Install JAX.\n", - "!pip install --upgrade jax\n", - "!pip install --upgrade jaxlib\n", - "!pip install --upgrade trax\n", - "\n", - "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", - "import requests\n", - "import os\n", - "if 'TPU_DRIVER_MODE' not in globals():\n", - " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", - " resp = requests.post(url)\n", - " TPU_DRIVER_MODE = 1\n", - "\n", - "# The following is required to use TPU Driver as JAX's backend.\n", - "from jax.config import config\n", - "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", - "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", - "print(config.FLAGS.jax_backend_target)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yiPdBenoZwH6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "!pip install --upgrade -q sentencepiece\n", - "!pip install --upgrade -q gin \n", - "\n", - "from tensorflow.compat.v1.io.gfile import GFile\n", - "import gin\n", - "import os\n", - "import jax\n", - "import trax\n", - "from trax.data import inputs\n", - "\n", - "import numpy as np\n", - "import jax.numpy as jnp\n", - "\n", - "from scipy.special import softmax\n", - "\n", - "from sentencepiece import SentencePieceProcessor" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "FQ89jHCYfhpg" - }, - "source": [ - "## Setting up data and model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9_OCIqghSyfs", - "colab_type": "text" - }, - "source": [ - "In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of \"Crime and Punishment\" on *each* of the 8 TPU cores (over 500,000 tokens per 8GB of memory)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "tYSOVGR47LVL", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Import a copy of \"Crime and Punishment\", by Fyodor Dostoevsky\n", - "with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:\n", - " text = f.read()\n", - "\n", - "# The file read above includes metadata and licensing information.\n", - "# For training our language model, we will only use the actual novel text.\n", - "start = text.find('CRIME AND PUNISHMENT') # skip header\n", - "start = text.find('CRIME AND PUNISHMENT', start + 1) # skip header\n", - "start = text.find('CRIME AND PUNISHMENT', start + 1) # skip translator preface\n", - "end = text.rfind('End of Project') # skip extra text at the end\n", - "text = text[start:end].strip()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "mMntV3H-6OR0", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 102 - }, - "outputId": "c8d4386c-cf5d-4dc4-92d9-24391fa2f30e" - }, - "source": [ - "# Load a BPE vocabulaary with 320 types. 
This mostly consists of single letters\n", - "# and pairs of letters, but it has some common words and word pieces, too.\n", - "!gsutil cp gs://trax-ml/reformer/cp.320.* .\n", - "\n", - "TOKENIZER = SentencePieceProcessor()\n", - "TOKENIZER.load('cp.320.model')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Copying gs://trax-ml/reformer/cp.320.model...\n", - "Copying gs://trax-ml/reformer/cp.320.vocab...\n", - "/ [2 files][239.0 KiB/239.0 KiB] \n", - "Operation completed over 2 objects/239.0 KiB. \n" - ], - "name": "stdout" - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "True" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 4 - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "HnJzxSi_77zP", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17 - }, - "outputId": "f8b2050b-0233-40e4-88f1-e546a1541b31" - }, - "source": [ - "# Tokenize\n", - "IDS = TOKENIZER.EncodeAsIds(text)\n", - "IDS = np.asarray(IDS, dtype=np.int32)\n", - "PAD_AMOUNT = 512 * 1024 - len(IDS)\n", - "print(\"Number of tokens:\", IDS.shape[0])" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Number of tokens: 513812\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bzQ7G9uGSga5", - "colab_type": "text" - }, - "source": [ - "As we see above, \"Crime and Punishment\" has just over half a million tokens with the BPE vocabulary we have selected.\n", - "\n", - "Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.\n", - "\n", - "We have 8 TPU cores, so we will separately randomize the amount of padding for each core." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "PdAwmpS220ub", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "outputId": "c0919b3d-4c63-4d2f-db44-3aeccaf4d966" - }, - "source": [ - "# Set up the data pipeline.\n", - "def my_inputs(n_devices):\n", - " while True:\n", - " inputs = []\n", - " mask = []\n", - " pad_amounts = np.random.choice(PAD_AMOUNT, n_devices)\n", - " for i in range(n_devices):\n", - " inputs.append(np.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", - " mode='constant'))\n", - " mask.append(np.pad(np.ones_like(IDS, dtype=np.float32),\n", - " (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),\n", - " mode='constant'))\n", - " inputs = np.stack(inputs)\n", - " mask = np.stack(mask)\n", - " yield (inputs, inputs, mask)\n", - "\n", - "print(\"(device count, tokens per device) = \",\n", - " next(my_inputs(trax.fastmath.device_count()))[0].shape)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "(device count, tokens per device) = (8, 524288)\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ei90LdK024r_", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Configure hyperparameters.\n", - "gin.parse_config(\"\"\"\n", - "import trax.layers\n", - "import trax.models\n", - "import trax.optimizers\n", - "import trax.data.inputs\n", - "import trax.supervised.trainer_lib\n", - "\n", - "# Parameters that will vary between experiments:\n", - "# ==============================================================================\n", - "train.model = @trax.models.ReformerLM\n", - "# Our model will have 6 layers, alternating between the LSH attention proposed\n", - "# in the Reformer paper and local attention within a certain context window.\n", - "n_layers = 6\n", - "attn_type = [\n", - " @trax.layers.SelfAttention,\n", - " @LSHSelfAttention, \n", - " @trax.layers.SelfAttention,\n", - " @LSHSelfAttention,\n", - " @trax.layers.SelfAttention,\n", - " @LSHSelfAttention,\n", - " ]\n", - "share_qk = False # LSH attention ignores this flag and always shares q & k\n", - "n_heads = 2\n", - "attn_kv = 64\n", - "dropout = 0.05\n", - "n_tokens = 524288\n", - "\n", - "# Parameters for multifactor:\n", - "# ==============================================================================\n", - "multifactor.constant = 0.01\n", - "multifactor.factors = 'constant * linear_warmup * cosine_decay'\n", - "multifactor.warmup_steps = 100\n", - "multifactor.steps_per_cycle = 900\n", - "\n", - "# Parameters for Adam:\n", - "# ==============================================================================\n", - "Adam.weight_decay_rate=0.0\n", - "Adam.b1 = 0.86\n", - "Adam.b2 = 0.92\n", - "Adam.eps = 1e-9\n", - "\n", - "# Parameters for SelfAttention:\n", - "# ==============================================================================\n", - "trax.layers.SelfAttention.attention_dropout = 0.05\n", - "trax.layers.SelfAttention.chunk_len = 64\n", - "trax.layers.SelfAttention.n_chunks_before = 1\n", - "trax.layers.SelfAttention.n_parallel_heads = 1\n", - "\n", - "# Parameters for LSHSelfAttention:\n", - "# ==============================================================================\n", - "LSHSelfAttention.attention_dropout = 0.0\n", - "LSHSelfAttention.chunk_len = 64\n", - "LSHSelfAttention.n_buckets = [64, 128]\n", - "LSHSelfAttention.n_chunks_after = 0\n", - "LSHSelfAttention.n_chunks_before = 1\n", - "LSHSelfAttention.n_hashes = 1\n", - 
"LSHSelfAttention.n_parallel_heads = 1\n", - "LSHSelfAttention.predict_drop_len = 128\n", - "LSHSelfAttention.predict_mem_len = 1024\n", - "\n", - "# Parameters for ReformerLM:\n", - "# ==============================================================================\n", - "ReformerLM.attention_type = %attn_type\n", - "ReformerLM.d_attention_key = %attn_kv\n", - "ReformerLM.d_attention_value = %attn_kv\n", - "ReformerLM.d_model = 256\n", - "ReformerLM.d_ff = 512\n", - "ReformerLM.dropout = %dropout\n", - "ReformerLM.ff_activation = @trax.layers.Relu\n", - "ReformerLM.max_len = %n_tokens\n", - "ReformerLM.mode = 'train'\n", - "ReformerLM.n_heads = %n_heads\n", - "ReformerLM.n_layers = %n_layers\n", - "ReformerLM.vocab_size = 320\n", - "ReformerLM.axial_pos_shape = (512, 1024)\n", - "ReformerLM.d_axial_pos_embs= (64, 192)\n", - "\"\"\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "RGGt0WaT3a-h", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Set up a Trainer.\n", - "output_dir = os.path.expanduser('~/train_dir/')\n", - "!rm -f ~/train_dir/model.pkl.gz # Remove old model\n", - "\n", - "trainer = trax.supervised.Trainer(\n", - " model=trax.models.ReformerLM,\n", - " loss_fn=trax.layers.CrossEntropyLoss(),\n", - " optimizer=trax.optimizers.Adam,\n", - " lr_schedule=trax.lr.multifactor(),\n", - " inputs=trax.data.inputs.Inputs(my_inputs),\n", - " output_dir=output_dir)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "y6VQkmKO3a1L", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 255 - }, - "outputId": "3c933bab-b49d-4e18-caf6-3dfc3e220938" - }, - "source": [ - "# Run one training step, to make sure the model fits in memory.\n", - "# The first time trainer.train_epoch is called, it will JIT the entire network\n", - "# architecture, which takes around 2 minutes. The JIT-compiled model is saved\n", - "# so subsequent runs will be much faster than the first.\n", - "trainer.train_epoch(n_steps=1, n_eval_steps=1)" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "Step 1: Ran 1 train steps in 155.17 secs\n", - "Step 1: Evaluation\n", - "Step 1: train accuracy | 0.00343633\n", - "Step 1: train loss | 6.36618853\n", - "Step 1: train neg_log_perplexity | -6.36618853\n", - "Step 1: train sequence_accuracy | 0.00000000\n", - "Step 1: train weights_per_batch_per_core | 513812.00000000\n", - "Step 1: eval accuracy | 0.00340154\n", - "Step 1: eval loss | 6.36649418\n", - "Step 1: eval neg_log_perplexity | -6.36649418\n", - "Step 1: eval sequence_accuracy | 0.00000000\n", - "Step 1: eval weights_per_batch_per_core | 513812.00000000\n", - "Step 1: Finished evaluation\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "EFnX4G6z3asD", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Train for 600 steps total\n", - "# The first ~20 steps are slow to run, but after that it reaches steady-state\n", - "# speed. This will take at least 30 minutes to run to completion, but can safely\n", - "# be interrupted by selecting \"Runtime > Interrupt Execution\" from the menu.\n", - "# The language model won't be exceptionally good when trained for just a few\n", - "# steps and with minimal regularization. 
However, we can still sample from it to\n", - "# see what it learns.\n", - "trainer.train_epoch(n_steps=9, n_eval_steps=1)\n", - "for _ in range(59):\n", - " trainer.train_epoch(n_steps=10, n_eval_steps=1)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zY3hpgnI5Rgn", - "colab_type": "text" - }, - "source": [ - "## Sample from the model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ffeLSbJk35pv", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# As we report in the Reformer paper, increasing the number of hashing rounds\n", - "# helps with quality. We can even increase the number of hashing rounds at\n", - "# evaluation time only.\n", - "\n", - "gin.parse_config(\"\"\"LSHSelfAttention.n_hashes = 4\"\"\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "-BwIjdl6_2tX", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# Load the trained Reformer in 'predict' mode\n", - "model = trax.models.ReformerLM(mode='predict')\n", - "model.init_from_file(os.path.join(output_dir,'model.pkl.gz'),\n", - " weights_only=True)\n", - "\n", - "# Sample from ReformerLM\n", - "output_token_ids = trax.supervised.decoding.autoregressive_sample(\n", - " model, temperature=0.0)\n", - "\n", - "# Decode token IDs\n", - "# Reformer outputed a batch with one item, we access it using [0]\n", - "# tolist() converts from int64 to int, the type SentencePiece expects\n", - "TOKENIZER.DecodeIds(output_token_ids[0].tolist()) \n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "s5f5QAmZBgPj", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/trax/models/research/configurable_transformer.py b/trax/models/research/configurable_transformer.py index b0d1e3232..25d5b92ef 100644 --- a/trax/models/research/configurable_transformer.py +++ b/trax/models/research/configurable_transformer.py @@ -22,1005 +22,1163 @@ from trax import layers as tl -def _FeedForward(d_model, d_ff, dropout, activation, act_dropout, - use_bfloat16, mode): - """Feed-forward block with layer normalization at start.""" - if act_dropout is None: - act_dropout = dropout - return [ - tl.Dense(d_ff, use_bfloat16=use_bfloat16), - tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode), - activation(), - tl.Dense(d_model, use_bfloat16=use_bfloat16), - ] - - -def FeedForwardWithOptions(d_model, - d_ff, - dropout, - dropout_shared_axes, - ff_activation, - ff_dropout, - ff_chunk_size, - ff_use_sru, - ff_sparsity, - center_layernorm, - mode, - use_bfloat16=False, - ff_sparsity_type='1inN'): - """Feed-Forward block with all the options. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. 
- ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block - with this sparsity - center_layernorm: whether to use centering in LayerNorm (default) or if - to skip it, which is known as RMS normalization. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - use_bfloat16: whether to use bfloat16 for weights (default: False). - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - use SwitchSparseFF if ff_sparsity_type=`'Switch'` - - Returns: - A list of layers which maps vectors to vectors. - """ - if ff_sparsity and ff_sparsity_type == '1inN': - temperature, quant_prob = 0.1, 0.3 - if isinstance(ff_sparsity, str): - # This is hacky but used to pass ff_sparsity in yaml sweep files. - ff_sparsity = [(float(x) if '.' in x else int(x)) - for x in ff_sparsity.split()] - if isinstance(ff_sparsity, (list, tuple)): - if len(ff_sparsity) == 2: - n_elements_in_block, d_lowrank = ff_sparsity - else: - n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity - else: - assert isinstance(ff_sparsity, int) - n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity - ff = tl.SparseFF( - d_ff, - n_elements_in_block=n_elements_in_block, - d_lowrank=d_lowrank, - temperature=temperature, - quant_prob=quant_prob, - use_bfloat16=use_bfloat16, - mode=mode, - dropout_rate=dropout, - dropout_shared_axes=dropout_shared_axes, - ff_chunk_size=ff_chunk_size) - elif ff_sparsity and ff_sparsity_type == 'Block': - ff = tl.BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) - elif ff_sparsity and ff_sparsity_type == 'Switch': - ff = tl.SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) - else: - ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout, - use_bfloat16, mode) - res = [tl.LayerNorm(center=center_layernorm), ff] - if ff_sparsity_type != '1inN' or ff_sparsity == 0: - # SparseFF has Dropout and BatchLeadingAxes built-in. - res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, - mode=mode)) - if ff_chunk_size > 0: - res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size)) - if ff_use_sru: - if isinstance(ff_use_sru, (list, tuple)): - sru_n_layers, sru_n_units = ff_use_sru +def _FeedForward(d_model, d_ff, dropout, activation, act_dropout, use_bfloat16, mode): + """Feed-forward block with layer normalization at start.""" + if act_dropout is None: + act_dropout = dropout + return [ + tl.Dense(d_ff, use_bfloat16=use_bfloat16), + tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode), + activation(), + tl.Dense(d_model, use_bfloat16=use_bfloat16), + ] + + +def FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + center_layernorm, + mode, + use_bfloat16=False, + ff_sparsity_type="1inN", +): + """Feed-Forward block with all the options. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. 
+ dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, tuple or string; if not 0, use sparse feed-forward block + with this sparsity + center_layernorm: whether to use centering in LayerNorm (default) or + to skip it, which is known as RMS normalization. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + use_bfloat16: whether to use bfloat16 for weights (default: False). + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + use SwitchSparseFF if ff_sparsity_type=`'Switch'` + + Returns: + A list of layers which maps vectors to vectors. + """ + if ff_sparsity and ff_sparsity_type == "1inN": + temperature, quant_prob = 0.1, 0.3 + if isinstance(ff_sparsity, str): + # This is hacky but used to pass ff_sparsity in yaml sweep files. + ff_sparsity = [ + (float(x) if "." in x else int(x)) for x in ff_sparsity.split() + ] + if isinstance(ff_sparsity, (list, tuple)): + if len(ff_sparsity) == 2: + n_elements_in_block, d_lowrank = ff_sparsity + else: + n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity + else: + assert isinstance(ff_sparsity, int) + n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity + ff = tl.SparseFF( + d_ff, + n_elements_in_block=n_elements_in_block, + d_lowrank=d_lowrank, + temperature=temperature, + quant_prob=quant_prob, + use_bfloat16=use_bfloat16, + mode=mode, + dropout_rate=dropout, + dropout_shared_axes=dropout_shared_axes, + ff_chunk_size=ff_chunk_size, + ) + elif ff_sparsity and ff_sparsity_type == "Block": + ff = tl.BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) + elif ff_sparsity and ff_sparsity_type == "Switch": + ff = tl.SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode) else: - sru_n_layers, sru_n_units = ff_use_sru, 32 - sru = [tl.SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)] - block = [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units) - ] + sru + [tl.Dense(d_model)] - res = tl.Residual(block, shortcut=res) - return [res] + ff = _FeedForward( + d_model, d_ff, dropout, ff_activation, ff_dropout, use_bfloat16, mode + ) + res = [tl.LayerNorm(center=center_layernorm), ff] + if ff_sparsity_type != "1inN" or ff_sparsity == 0: + # SparseFF has Dropout and BatchLeadingAxes built-in.
+ res.append(tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)) + if ff_chunk_size > 0: + res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size)) + if ff_use_sru: + if isinstance(ff_use_sru, (list, tuple)): + sru_n_layers, sru_n_units = ff_use_sru + else: + sru_n_layers, sru_n_units = ff_use_sru, 32 + sru = [tl.SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)] + block = ( + [tl.LayerNorm(center=center_layernorm), tl.Dense(sru_n_units)] + + sru + + [tl.Dense(d_model)] + ) + res = tl.Residual(block, shortcut=res) + return [res] # TODO(lukaszkaiser): unify attention layers API and remove this branch -def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal, - masked, attention_dropout, output_dropout, - attention_chunk_size, mode): - """Runs the supplied attention layer.""" - try: - attention = attention_type( - n_heads=n_heads, - d_qk=d_qk, - d_v=d_v, - causal=causal, - masked=masked, - output_dropout=output_dropout, - attention_dropout=attention_dropout, - mode=mode) - except TypeError: # No d_qk arguments in less advanced layers. - attention = attention_type( - d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode) - return tl.Chunk(attention, attention_chunk_size) - - -@tl.assert_shape('...d->...d') -def PositionalEncoder(mode, - dropout=None, - max_len=None, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - use_bfloat16=False): - """Returns the positional encoding layer depending on the arguments. - - Args: - mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder - block will include dropout; else, it will pass all values through - unaltered. - dropout: Stochastic rate (probability) for dropping an activation - value when applying dropout after the embedding block. - max_len: Maximum symbol length for positional encoding. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - pos_max_offset_to_add: maximum offset to add to positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - - Returns: - A layer that will do the positional encoding. 
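Editorial aside: the `'1inN'` branch of `FeedForwardWithOptions` above accepts `ff_sparsity` as an int, a 2- or 4-tuple, or a space-separated string (the yaml-sweep workaround noted in the code). A standalone sketch of that normalization, with a hypothetical helper name:

    def normalize_ff_sparsity(ff_sparsity, d_ff, temperature=0.1, quant_prob=0.3):
        # Mirrors the '1inN' parsing in FeedForwardWithOptions (sketch only).
        if isinstance(ff_sparsity, str):
            # e.g. "64 32" or "64 32 0.1 0.3" from a yaml sweep file
            ff_sparsity = [(float(x) if "." in x else int(x)) for x in ff_sparsity.split()]
        if isinstance(ff_sparsity, (list, tuple)):
            if len(ff_sparsity) == 2:
                n_elements_in_block, d_lowrank = ff_sparsity
            else:
                n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity
        else:
            n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity
        return n_elements_in_block, d_lowrank, temperature, quant_prob

    print(normalize_ff_sparsity(64, d_ff=2048))         # (64, 32, 0.1, 0.3)
    print(normalize_ff_sparsity("64 8 0.2 0.5", 2048))  # (64, 8, 0.2, 0.5)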
- """ - if not pos_type: - positional_encoding = tl.PositionalEncoding( - max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16, - start_from_zero_prob=pos_start_from_zero_prob, - max_offset_to_add=pos_max_offset_to_add, mode=mode) - elif pos_type == 'sin-cos': - positional_encoding = tl.SinCosPositionalEncoding(mode=mode) - elif pos_type == 'fixed-base': - positional_encoding = tl.FixedBasePositionalEncoding(mode=mode) - elif pos_type == 'infinite': - positional_encoding = tl.InfinitePositionalEncoding(affine=False) - elif pos_type == 'infinite-affine': - positional_encoding = tl.InfinitePositionalEncoding() - elif pos_type == 'time-bin': - positional_encoding = tl.TimeBinPositionalEncoding() - elif pos_type == 'no': - positional_encoding = tl.Serial() # no positional encoding at all - else: # TODO(lukaszkaiser): name this type and check for the correct name - assert pos_d_axial_embs is not None - positional_encoding = tl.AxialPositionalEncoding( - shape=pos_axial_shape, d_embs=pos_d_axial_embs, - dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)), - dropout=dropout, mode=mode) - - return positional_encoding - - -def EmbeddingAndPositionalEncodings(input_vocab_size, - d_model, - mode, - embedding_dropout, - dropout_shared_axes, - max_len, - output_vocab_size=None, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - use_bfloat16=False): - """Returns the embedder and positional encoder. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder - block will include dropout; else, it will pass all values through - unaltered. - embedding_dropout: Stochastic rate (probability) for dropping an activation - value when applying dropout after the embedding block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - max_len: Maximum symbol length for positional encoding. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if None, then input and target integers (token IDs) are assumed to come - from the same vocabulary. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). 
- pos_max_offset_to_add: maximum offset to add to positions during training +def ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_qk, + d_v, + causal, + masked, + attention_dropout, + output_dropout, + attention_chunk_size, + mode, +): + """Runs the supplied attention layer.""" + try: + attention = attention_type( + n_heads=n_heads, + d_qk=d_qk, + d_v=d_v, + causal=causal, + masked=masked, + output_dropout=output_dropout, + attention_dropout=attention_dropout, + mode=mode, + ) + except TypeError: # No d_qk arguments in less advanced layers. + attention = attention_type( + d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode + ) + return tl.Chunk(attention, attention_chunk_size) + + +@tl.assert_shape("...d->...d") +def PositionalEncoder( + mode, + dropout=None, + max_len=None, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + use_bfloat16=False, +): + """Returns the positional encoding layer depending on the arguments. + + Args: + mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder + block will include dropout; else, it will pass all values through + unaltered. + dropout: Stochastic rate (probability) for dropping an activation + value when applying dropout after the embedding block. + max_len: Maximum symbol length for positional encoding. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + + Returns: + A layer that will do the positional encoding. 
+ """ + if not pos_type: + positional_encoding = tl.PositionalEncoding( + max_len=max_len, + dropout=dropout, + use_bfloat16=use_bfloat16, + start_from_zero_prob=pos_start_from_zero_prob, + max_offset_to_add=pos_max_offset_to_add, + mode=mode, + ) + elif pos_type == "sin-cos": + positional_encoding = tl.SinCosPositionalEncoding(mode=mode) + elif pos_type == "fixed-base": + positional_encoding = tl.FixedBasePositionalEncoding(mode=mode) + elif pos_type == "infinite": + positional_encoding = tl.InfinitePositionalEncoding(affine=False) + elif pos_type == "infinite-affine": + positional_encoding = tl.InfinitePositionalEncoding() + elif pos_type == "time-bin": + positional_encoding = tl.TimeBinPositionalEncoding() + elif pos_type == "no": + positional_encoding = tl.Serial() # no positional encoding at all + else: # TODO(lukaszkaiser): name this type and check for the correct name + assert pos_d_axial_embs is not None + positional_encoding = tl.AxialPositionalEncoding( + shape=pos_axial_shape, + d_embs=pos_d_axial_embs, + dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)), + dropout=dropout, + mode=mode, + ) + + return positional_encoding + + +def EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + mode, + embedding_dropout, + dropout_shared_axes, + max_len, + output_vocab_size=None, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + use_bfloat16=False, +): + """Returns the embedder and positional encoder. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder + block will include dropout; else, it will pass all values through + unaltered. + embedding_dropout: Stochastic rate (probability) for dropping an activation + value when applying dropout after the embedding block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + max_len: Maximum symbol length for positional encoding. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if None, then input and target integers (token IDs) are assumed to come + from the same vocabulary. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training + when randomizing; this offset plus input length must still be less than + max_len for all training examples. + use_bfloat16: If `True`, use bfloat16 weights instead of the default + float32; this can save memory but may (rarely) lead to numerical issues. + + Returns: + A tuple of (input encoder, output encoder, output vocab size used). 
+ """ + + # tokens --> vectors + def Embedder(vocab_size, embedding_mode): + if vocab_size is not None: + embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16) + else: + embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16) + return [ + embedding, + tl.Dropout( + rate=embedding_dropout, + shared_axes=dropout_shared_axes, + mode=embedding_mode, + ), + ] + + # NOTE: Positional encodings are not shared between encoder and decoder. + + # Since encoder doesn't run stepwise, we do not use predict mode there. + encoder_mode = "eval" if mode == "predict" else mode + in_embedder = Embedder(input_vocab_size, encoder_mode) + in_encoder = in_embedder + [ + PositionalEncoder( + encoder_mode, + dropout=embedding_dropout, + max_len=max_len, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + pos_start_from_zero_prob=pos_start_from_zero_prob, + pos_max_offset_to_add=pos_max_offset_to_add, + use_bfloat16=use_bfloat16, + ) + ] + + # If output_vocab_size is None, we reuse the same embedding matrix, otherwise + # we initialize one. + assert input_vocab_size or output_vocab_size + if output_vocab_size is None: + out_embedder = in_embedder + else: + out_embedder = Embedder(output_vocab_size, mode) + + out_encoder = out_embedder + [ + PositionalEncoder( + mode, + dropout=embedding_dropout, + max_len=max_len, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + pos_start_from_zero_prob=pos_start_from_zero_prob, + pos_max_offset_to_add=pos_max_offset_to_add, + use_bfloat16=use_bfloat16, + ) + ] + + # Set this to the value actually used. + if output_vocab_size is None: + output_vocab_size = input_vocab_size + + if input_vocab_size is None: + in_encoder = tl.AssertFunction("...a->...b", in_encoder) + else: + in_encoder = tl.AssertFunction("...->...d", in_encoder) + out_encoder = tl.AssertFunction("...->...d", out_encoder) + + return in_encoder, out_encoder, output_vocab_size + + +def ConfigurableTransformerEncoder( + vocab_size, + n_classes=10, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + max_len=2048, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.Relu, + ff_dropout=0.1, + ff_chunk_size=0, + ff_use_sru=0, + ff_sparsity=0, + ff_sparsity_type="1inN", + attention_chunk_size=0, + attention_type=tl.Attention, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, +): + """Returns a Transformer encoder merged with an N-way categorization head. + + This model performs text categorization: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 2 tensor representing a batch of log-probability + distributions over N categories; shape is (batch_size, `n_classes`). + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor should + be an integer in `range(vocab_size)`. These integers typically represent + token IDs from a vocabulary-based tokenizer. + n_classes: Final dimension of the output tensors, representing N-way + classification. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + block. + n_layers: Number of encoder blocks. 
Each block includes attention, dropout, + residual, feed-forward (`Dense`), and activation layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within an encoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each encoder block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder block; + must be an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use for the encoder part. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + + Returns: + A Transformer model that maps strings (conveyed via token IDs) to + probability-like activations over a range of output classes. + """ + positional_encoder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + PositionalEncoder( + mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs + ), + ] + + positional_encoder = tl.AssertFunction("...->...d", positional_encoder) + + # pylint: disable=g-complex-comprehension + encoder_blocks = [ + EncoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + ) + for i in range(n_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( # toks + # Encode. + tl.Branch(positional_encoder, tl.PaddingMask()), # vecs masks + encoder_blocks, # vecs masks + tl.Select([0], n_in=2), # vecs + tl.LayerNorm(), # vecs + # Map to output categories. 
+ tl.Mean(axis=1), # vecs + tl.Dense(n_classes), # vecs + ) + + def ConfigurableTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + max_len=2048, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.Relu, + ff_dropout=0.1, + ff_chunk_size=0, + ff_use_sru=0, + ff_sparsity=0, + ff_sparsity_type="1inN", + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + attention_type=tl.CausalAttention, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, +): + """Returns a Transformer language model. + + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + This model uses only the decoder part of the overall Transformer. + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor should + be an integer in `range(vocab_size)`. These integers typically represent + token IDs from a vocabulary-based tokenizer. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each decoder + block. + n_layers: Number of decoder blocks. Each block includes attention, dropout, + residual, feed-forward (`Dense`), and activation layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a decoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'predict'`, use fast inference. If `'train'`, each decoder block + will include dropout; else, it will pass all values through unaltered. + ff_activation: Type of activation function at the end of each decoder block; + must be an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + loss_sparsity_type: string, type of sparsity to use in loss layer. See + SparseDenseWithOptions for options. None if no sparsity should be used. + loss_sparsity: int, the sparsity for loss layer (if used) + loss_d_lowrank: int, the dimensions for intermediate layer (if used) + loss_sparsity_prob: float, the probability for sparse version of loss to be + used. If None, only sparse version is used.
+ attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use for the decoder part. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + pos_start_from_zero_prob: how often to start from 0 during training, + (if 1.0, we always start from position 0, if less, we randomize). + pos_max_offset_to_add: maximum offset to add to positions during training when randomizing; this offset plus input length must still be less than max_len for all training examples. - use_bfloat16: If `True`, use bfloat16 weights instead of the default - float32; this can save memory but may (rarely) lead to numerical issues. - - Returns: - A tuple of (input encoder, output encoder, output vocab size used). - """ - # tokens --> vectors - def Embedder(vocab_size, embedding_mode): - if vocab_size is not None: - embedding = tl.Embedding(vocab_size, d_model, use_bfloat16=use_bfloat16) - else: - embedding = tl.Dense(d_model, use_bfloat16=use_bfloat16) - return [ - embedding, - tl.Dropout(rate=embedding_dropout, - shared_axes=dropout_shared_axes, - mode=embedding_mode), + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + positional_encoder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + PositionalEncoder( + mode, + dropout, + max_len, + pos_type, + pos_axial_shape, + pos_d_axial_embs, + pos_start_from_zero_prob, + pos_max_offset_to_add, + ), + ] + + # pylint: disable=g-complex-comprehension + decoder_blocks = [ + DecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + ) + for i in range(n_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( # tokens (or chunked tuple of tokens) + tl.ShiftRight(mode=mode), # toks + positional_encoder, # vecs + decoder_blocks, # vecs + tl.LayerNorm(), # vecs + tl.SparseDenseWithOptions( # vecs + vocab_size, + d_input=d_model, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + mode=mode, + ), + ) + + +def ConfigurableTransformer( + input_vocab_size, + output_vocab_size=None, + d_model=512, + d_ff=2048, + n_encoder_layers=6, + n_decoder_layers=6, + n_heads=8, + max_len=2048, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.Relu, + ff_dropout=0.1, + ff_chunk_size=0, + ff_use_sru=0, + ff_sparsity=0, + ff_sparsity_type="1inN", + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + encoder_attention_type=tl.Attention, + encoder_decoder_attention_type=tl.CausalAttention, + pos_type=None, + pos_axial_shape=None, + pos_d_axial_embs=None, + enc_dec_attention_sparsity=0, +): + """Returns a full Transformer model. 
+ + This model is an encoder-decoder that performs tokenized string-to-string + ("source"-to-"target") transduction: + + - inputs (2): + + - source: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(input_vocab_size)`, and `0` + values mark padding positions. + + - target: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(output_vocab_size)`, and `0` + values mark padding positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + An example use would be to translate (tokenized) sentences from English to + German. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if None, then input and target integers (token IDs) are assumed to come + from the same vocabulary. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + and decoder block. + n_encoder_layers: Number of encoder blocks. + n_decoder_layers: Number of decoder blocks. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within an encoder/decoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder + block will include dropout; else, it will pass all values through + unaltered. + ff_activation: Type of activation function at the end of each + encoder/decoder block; must be an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + loss_sparsity_type: str, type of sparsity to use in loss layer. See + SparseDenseWithOptions for options. None if no sparsity should be used. + loss_sparsity: int, the sparsity for loss layer (if used) + loss_d_lowrank: int, the dimensions for intermediate layer (if used) + loss_sparsity_prob: float, the probability for sparse version of loss to be + used. If None, only sparse version is used. + attention_chunk_size: int, if > 0 run attention chunked at this size + encoder_attention_type: The attention layer to use for the encoder part.
+ encoder_decoder_attention_type: The attention layer to use for the + encoder-decoder attention. + pos_type: string, the type of positional embeddings to use. + pos_axial_shape: tuple of ints: input shape to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. + Tuple length must match pos_axial_shape, and values must sum to d_model. + enc_dec_attention_sparsity: int, if > 0 use this sparsity in attention. + + Returns: + A Transformer model as a layer that maps from a source-target tokenized + text pair to activations over a vocab set. + """ + in_encoder, out_encoder, output_vocab_size = EmbeddingAndPositionalEncodings( + input_vocab_size, + d_model, + mode, + dropout, + dropout_shared_axes, + max_len, + output_vocab_size=output_vocab_size, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + ) + + # pylint: disable=g-complex-comprehension + encoder_blocks = [ + EncoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + encoder_attention_type, + ) + for i in range(n_encoder_layers) + ] + # pylint: enable=g-complex-comprehension + + encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm()) + if mode == "predict": + encoder = tl.Cache(encoder) + + # pylint: disable=g-complex-comprehension + encoder_decoder_blocks = [ + EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + encoder_decoder_attention_type, + enc_dec_attention_sparsity, + ) + for i in range(n_decoder_layers) + ] + # pylint: enable=g-complex-comprehension + + # Assemble and return the model. + return tl.Serial( + # Input: encoder_side_tokens, decoder_side_tokens + # Copy decoder tokens for use in loss. + tl.Select([0, 1, 1]), # tok_e tok_d tok_d + # Encode. + tl.Branch([], tl.PaddingMask()), # tok_e masks ..... ..... + encoder, # vec_e ..... ..... ..... + # Decode. + tl.Select([2, 1, 0]), # tok_d masks vec_e ..... + tl.ShiftRight(mode=mode), # tok_d ..... ..... ..... + out_encoder, # vec_d ..... ..... ..... + tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... + encoder_decoder_blocks, # vec_d masks ..... ..... + tl.LayerNorm(), # vec_d ..... ..... ..... + # Map to output vocab. + tl.Select([0], n_in=3), # vec_d tok_d + tl.SparseDenseWithOptions( # vec_d ..... + output_vocab_size, + d_input=d_model, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + mode=mode, + ), + ) + + +def EncoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + n_attention_layers=1, + n_feedforward_layers=1, +): + """Returns a list of layers that implements a Transformer encoder block. + + The input to the block is a pair, (activations, mask), where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. 
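A forward-shape check for this encoder-decoder, mirroring the removed test: the model takes (source, target) token batches and, because the decoder tokens are copied for the loss, returns a (logits, tokens) pair.

import numpy as np

from trax import shapes
from trax.models.research import configurable_transformer as ct

model = ct.ConfigurableTransformer(
    input_vocab_size=16, output_vocab_size=50, d_model=32, d_ff=64,
    n_encoder_layers=2, n_decoder_layers=2, n_heads=2)
xs = [np.ones((3, 5)).astype(np.int32),   # source token IDs
      np.ones((3, 5)).astype(np.int32)]   # target token IDs
model.init(shapes.signature(xs))
y, _ = model(xs)                          # y: (3, 5, 50) log-probabilities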
+      d_ff: Size of special dense layer in the feed-forward part of each block.
+      n_heads: Number of attention heads.
+      dropout: Stochastic rate (probability) for dropping an activation value when
+        applying dropout within a block.
+      dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
+        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
+        way to save memory and apply consistent masks to activation vectors at
+        different sequence positions.
+      mode: If `'train'`, each block will include dropout; else, it will pass all
+        values through unaltered.
+      ff_activation: Type of activation function at the end of each block; must be
+        an activation-type subclass of `Layer`.
+      ff_dropout: Stochastic rate (probability) for dropping an activation value
+        when applying dropout after the FF dense layer.
+      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
+      ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
+        in addition to the feed-forward block (second int specifies sru size)
+      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
+      ff_sparsity_type: string, if ff_sparsity > 0,
+        use SparseFF if ff_sparsity_type=`'1inN'` and
+        use BlockSparseFF if ff_sparsity_type=`'Block'`
+      attention_chunk_size: int, if > 0 run attention chunked at this size
+      attention_type: The attention layer to use.
+      n_attention_layers: how many residual attention layers should we
+        have before the feed-forward block (default: 1, the standard block)
+      n_feedforward_layers: how many FFNN layers should we have (default 1).
+
+    Returns:
+      A list of layers that maps (activations, mask) to (activations, mask).
+    """
+    # `n_attention_layers` number of residuals of attention layer + dropout.
+    # pylint: disable=g-complex-comprehension
+    residual_attentions = [
+        tl.Residual(
+            tl.LayerNorm(),
+            ApplyAttentionLayer(
+                attention_type,
+                d_model,
+                n_heads,
+                d_model // n_heads,
+                d_model // n_heads,
+                causal=False,
+                masked=True,
+                attention_dropout=dropout,
+                output_dropout=dropout,
+                attention_chunk_size=attention_chunk_size,
+                mode=mode,
+            ),
+            tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
+        )
+        for _ in range(n_attention_layers)
+    ]
+
+    feed_forwards = [
+        tl.Residual(
+            FeedForwardWithOptions(
+                d_model,
+                d_ff,
+                dropout,
+                dropout_shared_axes,
+                ff_activation,
+                ff_dropout,
+                ff_chunk_size,
+                ff_use_sru,
+                ff_sparsity,
+                True,
+                mode,
+                False,
+                ff_sparsity_type,
+            )
+        )
+        for _ in range(n_feedforward_layers)
+    ]
+    # pylint: enable=g-complex-comprehension
+
+    return residual_attentions + feed_forwards
+
+
+def DecoderBlock(
+    d_model,
+    d_ff,
+    n_heads,
+    dropout,
+    dropout_shared_axes,
+    mode,
+    ff_activation,
+    ff_dropout,
+    ff_chunk_size,
+    ff_use_sru,
+    ff_sparsity,
+    ff_sparsity_type,
+    attention_chunk_size,
+    attention_type,
+    n_attention_layers=1,
+    n_feedforward_layers=1,
+):
+    """Returns a list of layers that implements a Transformer decoder block.
+
+    The input is an activation tensor.
+
+    Args:
+      d_model: Final dimension of tensors at most points in the model, including
+        the initial embedding output.
+      d_ff: Size of special dense layer in the feed-forward part of each block.
+      n_heads: Number of attention heads.
+      dropout: Stochastic rate (probability) for dropping an activation value when
+        applying dropout within a block.
+      dropout_shared_axes: Tensor axes on which to share a dropout mask.
Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use. + n_attention_layers: how many residual causal attention layers should we + have before the feed-forward block (default: 1, the standard block) + n_feedforward_layers: how many FFNN layers should we have (default 1). + + Returns: + A list of layers that maps an activation tensor to an activation tensor. + """ + # pylint: disable=g-complex-comprehension + causal_attentions = [ + ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=True, + masked=False, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + for _ in range(n_attention_layers) ] - # NOTE: Positional encodings are not shared between encoder and decoder. - - # Since encoder doesn't run stepwise, we do not use predict mode there. - encoder_mode = 'eval' if mode == 'predict' else mode - in_embedder = Embedder(input_vocab_size, encoder_mode) - in_encoder = in_embedder + [ - PositionalEncoder(encoder_mode, - dropout=embedding_dropout, - max_len=max_len, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs, - pos_start_from_zero_prob=pos_start_from_zero_prob, - pos_max_offset_to_add=pos_max_offset_to_add, - use_bfloat16=use_bfloat16) - ] - - # If output_vocab_size is None, we reuse the same embedding matrix, otherwise - # we initialize one. - assert input_vocab_size or output_vocab_size - if output_vocab_size is None: - out_embedder = in_embedder - else: - out_embedder = Embedder(output_vocab_size, mode) - - out_encoder = out_embedder + [ - PositionalEncoder(mode, - dropout=embedding_dropout, - max_len=max_len, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs, - pos_start_from_zero_prob=pos_start_from_zero_prob, - pos_max_offset_to_add=pos_max_offset_to_add, - use_bfloat16=use_bfloat16) - ] - - # Set this to the value actually used. 
- if output_vocab_size is None: - output_vocab_size = input_vocab_size - - if input_vocab_size is None: - in_encoder = tl.AssertFunction('...a->...b', in_encoder) - else: - in_encoder = tl.AssertFunction('...->...d', in_encoder) - out_encoder = tl.AssertFunction('...->...d', out_encoder) - - return in_encoder, out_encoder, output_vocab_size - - -def ConfigurableTransformerEncoder(vocab_size, - n_classes=10, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - ff_dropout=0.1, - ff_chunk_size=0, - ff_use_sru=0, - ff_sparsity=0, - ff_sparsity_type='1inN', - attention_chunk_size=0, - attention_type=tl.Attention, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None): - """Returns a Transformer encoder merged with an N-way categorization head. - - This model performs text categorization: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 2 tensor representing a batch of log-probability - distributions over N categories; shape is (batch_size, `n_classes`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor should - be an integer in `range(vocab_size)`. These integers typically represent - token IDs from a vocabulary-based tokenizer. - n_classes: Final dimension of the output tensors, representing N-way - classification. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, feed-forward (`Dense`), and activation layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each encoder block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder block; - must be an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use for the encoder part. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. 
- pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - - Returns: - A Transformer model that maps strings (conveyed via token IDs) to - probability-like activations over a range of output classes. - """ - positional_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), - PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs) - ] - - positional_encoder = tl.AssertFunction('...->...d', positional_encoder) - - # pylint: disable=g-complex-comprehension - encoder_blocks = [ - EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, - ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, ff_sparsity_type, - attention_chunk_size, attention_type) - for i in range(n_layers) - ] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( # toks - # Encode. - tl.Branch( - positional_encoder, tl.PaddingMask()), # vecs masks - encoder_blocks, # vecs masks - tl.Select([0], n_in=2), # vecs - tl.LayerNorm(), # vecs - - # Map to output categories. - tl.Mean(axis=1), # vecs - tl.Dense(n_classes), # vecs - ) - - -def ConfigurableTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - ff_dropout=0.1, - ff_chunk_size=0, - ff_use_sru=0, - ff_sparsity=0, - ff_sparsity_type='1inN', - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - attention_type=tl.CausalAttention, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor should - be an integer in `range(vocab_size)`. These integers typically represent - token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, feed-forward (`Dense`), and activation layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'predict'`, use fast inference. 
If `'train'`, each encoder block - will include dropout; else, it will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder block; - must be an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - loss_sparsity_type: string, type of sparsity to used in loss layer. See - SparseDenseWithOptions for options. None if no sparsity should be used. - loss_sparsity: int, the sparsity for loss layer (if used) - loss_d_lowrank: int, the dimensions for intermediate layer (if used) - loss_sparsity_prob: float, the probability for sparse version of loss to be - used. If None, only sparse version is used. - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use for the decoder part. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - pos_start_from_zero_prob: how often to start from 0 during training, - (if 1.0, we always start from position 0, if less, we randomize). - pos_max_offset_to_add: maximum offset to add to positions during training - when randomizing; this offset plus input length must still be less than - max_len for all training examples. - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - positional_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), - PositionalEncoder( - mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs, - pos_start_from_zero_prob, pos_max_offset_to_add) - ] - - # pylint: disable=g-complex-comprehension - decoder_blocks = [ - DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, - ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, ff_sparsity_type, - attention_chunk_size, attention_type) - for i in range(n_layers) - ] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. 
- return tl.Serial( # tokens (or chunked tuple of tokens) - tl.ShiftRight(mode=mode), # toks - positional_encoder, # vecs - decoder_blocks, # vecs - tl.LayerNorm(), # vecs - tl.SparseDenseWithOptions( # vecs - vocab_size, d_input=d_model, sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, mode=mode), - ) - - -def ConfigurableTransformer(input_vocab_size, - output_vocab_size=None, - d_model=512, - d_ff=2048, - n_encoder_layers=6, - n_decoder_layers=6, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - ff_dropout=0.1, - ff_chunk_size=0, - ff_use_sru=0, - ff_sparsity=0, - ff_sparsity_type='1inN', - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - pos_type=None, - pos_axial_shape=None, - pos_d_axial_embs=None, - enc_dec_attention_sparsity=0): - """Returns a full Transformer model. - - This model is an encoder-decoder that performs tokenized string-to-string - ("source"-to-"target") transduction: - - - inputs (2): - - - source: rank 2 tensor representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length). The - tensor elements are integers in `range(input_vocab_size)`, and `0` - values mark padding positions. - - - target: rank 2 tensor representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length). The - tensor elements are integers in `range(output_vocab_size)`, and `0` - values mark padding positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - An example use would be to translate (tokenized) sentences from English to - German. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if None, then input and target integers (token IDs) are assumed to come - from the same vocabulary. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - and decoder block. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder/decoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder - block will include dropout; else, it will pass all values through - unaltered. - ff_activation: Type of activation function at the end of each - encoder/decoder block; must be an activation-type subclass of `Layer`. 
- ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - loss_sparsity_type: str, type of sparsity to used in loss layer. See - SparseDenseWithOptions for options. None if no sparsity should be used. - loss_sparsity: int, the sparsity for loss layer (if used) - loss_d_lowrank: int, the dimensions for intermediate layer (if used) - loss_sparsity_prob: float, the probability for sparse version of loss to be - used. If None, only sparse version is used. - attention_chunk_size: int, if > 0 run attention chunked at this size - encoder_attention_type: The attention layer to use for the encoder part. - encoder_decoder_attention_type: The attention layer to use for the - encoder-decoder attention. - pos_type: string, the type of positional embeddings to use. - pos_axial_shape: tuple of ints: input shape to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. - Tuple length must match pos_axial_shape, and values must sum to d_model. - enc_dec_attention_sparsity: int, if > 0 use this sparsity in attention. - - Returns: - A Transformer model as a layer that maps from a source-target tokenized - text pair to activations over a vocab set. - """ - in_encoder, out_encoder, output_vocab_size = ( - EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model, - mode, - dropout, - dropout_shared_axes, - max_len, - output_vocab_size=output_vocab_size, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs) - ) - - # pylint: disable=g-complex-comprehension - encoder_blocks = [ - EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, - ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, ff_sparsity_type, - attention_chunk_size, encoder_attention_type) - for i in range(n_encoder_layers) - ] - # pylint: enable=g-complex-comprehension - - encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm()) - if mode == 'predict': - encoder = tl.Cache(encoder) - - # pylint: disable=g-complex-comprehension - encoder_decoder_blocks = [ - EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation, ff_dropout, ff_chunk_size, - ff_use_sru, ff_sparsity, ff_sparsity_type, - attention_chunk_size, encoder_decoder_attention_type, - enc_dec_attention_sparsity) - for i in range(n_decoder_layers) - ] - # pylint: enable=g-complex-comprehension - - # Assemble and return the model. - return tl.Serial( - # Input: encoder_side_tokens, decoder_side_tokens - # Copy decoder tokens for use in loss. - tl.Select([0, 1, 1]), # tok_e tok_d tok_d - - # Encode. - tl.Branch([], tl.PaddingMask()), # tok_e masks ..... ..... - encoder, # vec_e ..... ..... ..... - - # Decode. - tl.Select([2, 1, 0]), # tok_d masks vec_e ..... - tl.ShiftRight(mode=mode), # tok_d ..... ..... ..... - out_encoder, # vec_d ..... ..... ..... - tl.Branch( - [], tl.EncoderDecoderMask()), # vec_d masks ..... ..... - encoder_decoder_blocks, # vec_d masks ..... 
..... - tl.LayerNorm(), # vec_d ..... ..... ..... - - # Map to output vocab. - tl.Select([0], n_in=3), # vec_d tok_d - tl.SparseDenseWithOptions( # vec_d ..... - output_vocab_size, d_input=d_model, sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, mode=mode), - ) - - -def EncoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation, - ff_dropout, - ff_chunk_size, - ff_use_sru, - ff_sparsity, - ff_sparsity_type, - attention_chunk_size, - attention_type, - n_attention_layers=1, - n_feedforward_layers=1): - """Returns a list of layers that implements a Transformer encoder block. - - The input to the block is a pair, (activations, mask), where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use. - n_attention_layers: how many residual causal attention layers should we - have before the feed-forward block (default: 1, the standard block) - n_feedforward_layers: how many FFNN layers should we have (default 1). - - Returns: - A list of layers that maps (activations, mask) to (activations, mask). - """ - # `n_attention_layers` number of residuals of attention layer + dropout. 
- # pylint: disable=g-complex-comprehension - residual_attentions = [ - tl.Residual(tl.LayerNorm(), - ApplyAttentionLayer(attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=False, - masked=True, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=attention_chunk_size, - mode=mode), - tl.Dropout(rate=dropout, - shared_axes=dropout_shared_axes, - mode=mode) - ) - for _ in range(n_attention_layers) - ] - - feed_forwards = [ - tl.Residual( - FeedForwardWithOptions(d_model, d_ff, dropout, - dropout_shared_axes, ff_activation, - ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, True, mode, False, - ff_sparsity_type) - ) - for _ in range(n_feedforward_layers) - ] - # pylint: enable=g-complex-comprehension - - return residual_attentions + feed_forwards - - -def DecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation, - ff_dropout, - ff_chunk_size, - ff_use_sru, - ff_sparsity, - ff_sparsity_type, - attention_chunk_size, - attention_type, - n_attention_layers=1, - n_feedforward_layers=1): - """Returns a list of layers that implements a Transformer decoder block. - - The input is an activation tensor. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use. - n_attention_layers: how many residual causal attention layers should we - have before the feed-forward block (default: 1, the standard block) - n_feedforward_layers: how many FFNN layers should we have (default 1). - - Returns: - A list of layers that maps an activation tensor to an activation tensor. 
- """ - # pylint: disable=g-complex-comprehension - causal_attentions = [ApplyAttentionLayer( - attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=True, - masked=False, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=attention_chunk_size, - mode=mode) for _ in range(n_attention_layers)] - - residual_attentions = [ - tl.Residual( - tl.LayerNorm(), - causal_attentions[i], - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - ) for i in range(n_attention_layers)] - - feed_forwards = [ - tl.Residual( - FeedForwardWithOptions(d_model, d_ff, dropout, - dropout_shared_axes, ff_activation, - ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, True, mode, False, - ff_sparsity_type) - ) - for _ in range(n_feedforward_layers) - ] - # pylint: enable=g-complex-comprehension - - return residual_attentions + feed_forwards - - -def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation, ff_dropout, ff_chunk_size, - ff_use_sru, ff_sparsity, ff_sparsity_type, - attention_chunk_size, attention_type, - enc_dec_attention_sparsity=0): - """Returns a list of layers implementing a Transformer encoder-decoder block. - - The input is a triple (decoder_activations, mask, encoder_activiations) where - the mask is created from the original input token IDs to prevent attending to - the padding part of the encoder. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - ff_dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout after the FF dense layer. - ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks - ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers - in addition to the feed-forward block (second int specifies sru size) - ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity - ff_sparsity_type: string, if ff_sparsity >0, - use SparseFF if ff_sparsity_type=`'1inN'` and - use BlockSparseFF if ff_sparsity_type=`'Block'` - attention_chunk_size: int, if > 0 run attention chunked at this size - attention_type: The attention layer to use. - enc_dec_attention_sparsity: Sparsity to use in encoder-decoder attention. - - Returns: - A list of layers which maps triples (decoder_activations, mask, - encoder_activations) to triples of the same sort. 
- """ - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - # TODO(afrozm): This layer isn't configurable because: We currently don't have - # any alternative for it (LSH cannot do it fundamentally, that's why we have - # NoEncDec models, and local attention doesn't make sense in the general - # setting where we don't know what in input is local to what in output; - # some variants of FAVOR can do it, so maybe in the future, - # but we don't have them yet). - if isinstance(enc_dec_attention_sparsity, tuple): - q_sparsity, result_sparsity = enc_dec_attention_sparsity - elif enc_dec_attention_sparsity > 0: - q_sparsity = enc_dec_attention_sparsity - result_sparsity = 'noop' # We simply skip Dense layer after attention. - else: - q_sparsity = None - result_sparsity = None - attention_qkv = tl.AttentionQKV( - d_model, n_heads=n_heads, dropout=dropout, mode=mode, - cache_KV_in_predict=True, - q_sparsity=q_sparsity, result_sparsity=result_sparsity) - - causal_attention = ApplyAttentionLayer( - attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=True, - masked=True, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=attention_chunk_size, - mode=mode) - - feed_forward = FeedForwardWithOptions(d_model, d_ff, dropout, - dropout_shared_axes, ff_activation, - ff_dropout, ff_chunk_size, ff_use_sru, - ff_sparsity, True, mode, False, - ff_sparsity_type) - - return [ # vec_d masks vec_e - tl.Residual( - tl.LayerNorm(), # vec_d ..... ..... - causal_attention, # vec_d ..... ..... - _Dropout(), # vec_d ..... ..... - ), - tl.Residual( - tl.LayerNorm(), # vec_d ..... ..... - tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e - attention_qkv, # vec_d masks vec_e - _Dropout(), # vec_d masks vec_e - ), - tl.Residual( - feed_forward # vec_d masks vec_e - ), - ] + residual_attentions = [ + tl.Residual( + tl.LayerNorm(), + causal_attentions[i], + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), + ) + for i in range(n_attention_layers) + ] + + feed_forwards = [ + tl.Residual( + FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, + False, + ff_sparsity_type, + ) + ) + for _ in range(n_feedforward_layers) + ] + # pylint: enable=g-complex-comprehension + + return residual_attentions + feed_forwards + + +def EncoderDecoderBlock( + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + ff_sparsity_type, + attention_chunk_size, + attention_type, + enc_dec_attention_sparsity=0, +): + """Returns a list of layers implementing a Transformer encoder-decoder block. + + The input is a triple (decoder_activations, mask, encoder_activiations) where + the mask is created from the original input token IDs to prevent attending to + the padding part of the encoder. + + Args: + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. 
Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + ff_dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout after the FF dense layer. + ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks + ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers + in addition to the feed-forward block (second int specifies sru size) + ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity + ff_sparsity_type: string, if ff_sparsity >0, + use SparseFF if ff_sparsity_type=`'1inN'` and + use BlockSparseFF if ff_sparsity_type=`'Block'` + attention_chunk_size: int, if > 0 run attention chunked at this size + attention_type: The attention layer to use. + enc_dec_attention_sparsity: Sparsity to use in encoder-decoder attention. + + Returns: + A list of layers which maps triples (decoder_activations, mask, + encoder_activations) to triples of the same sort. + """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + # TODO(afrozm): This layer isn't configurable because: We currently don't have + # any alternative for it (LSH cannot do it fundamentally, that's why we have + # NoEncDec models, and local attention doesn't make sense in the general + # setting where we don't know what in input is local to what in output; + # some variants of FAVOR can do it, so maybe in the future, + # but we don't have them yet). + if isinstance(enc_dec_attention_sparsity, tuple): + q_sparsity, result_sparsity = enc_dec_attention_sparsity + elif enc_dec_attention_sparsity > 0: + q_sparsity = enc_dec_attention_sparsity + result_sparsity = "noop" # We simply skip Dense layer after attention. + else: + q_sparsity = None + result_sparsity = None + attention_qkv = tl.AttentionQKV( + d_model, + n_heads=n_heads, + dropout=dropout, + mode=mode, + cache_KV_in_predict=True, + q_sparsity=q_sparsity, + result_sparsity=result_sparsity, + ) + + causal_attention = ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=True, + masked=True, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=attention_chunk_size, + mode=mode, + ) + + feed_forward = FeedForwardWithOptions( + d_model, + d_ff, + dropout, + dropout_shared_axes, + ff_activation, + ff_dropout, + ff_chunk_size, + ff_use_sru, + ff_sparsity, + True, + mode, + False, + ff_sparsity_type, + ) + + return [ # vec_d masks vec_e + tl.Residual( + tl.LayerNorm(), # vec_d ..... ..... + causal_attention, # vec_d ..... ..... + _Dropout(), # vec_d ..... ..... + ), + tl.Residual( + tl.LayerNorm(), # vec_d ..... ..... + tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e + attention_qkv, # vec_d masks vec_e + _Dropout(), # vec_d masks vec_e + ), + tl.Residual(feed_forward), # vec_d masks vec_e + ] diff --git a/trax/models/research/configurable_transformer_test.py b/trax/models/research/configurable_transformer_test.py deleted file mode 100644 index 0c10f078f..000000000 --- a/trax/models/research/configurable_transformer_test.py +++ /dev/null @@ -1,188 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Transformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.layers import test_utils -from trax.models.research import configurable_transformer as ct - - -class ConfigurableTransformerTest(parameterized.TestCase): - - def test_transformer_lm_forward_shape(self): - vocab_size = 16 - model = ct.ConfigurableTransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) - x = np.ones((3, 5)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, vocab_size)) - - def _test_transformer_forward_shape(self, input_vocab_size, - output_vocab_size): - model = ct.ConfigurableTransformer( - input_vocab_size, - output_vocab_size, - d_model=32, - d_ff=64, - n_encoder_layers=2, - n_decoder_layers=2, - n_heads=2) - xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - y, _ = model(xs) - - vocab_size = output_vocab_size or input_vocab_size - self.assertEqual(y.shape, (3, 5, vocab_size)) - - @parameterized.named_parameters(('same_vocab', 16, None), - ('same_size', 16, 16), - ('different_size', 16, 50)) - def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): - """Run the Transformer forward and check output shape.""" - self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) - - - def test_dot_product_causal_attention_fast_inference(self): - self._test_fast_inference(length=5) - - def _test_fast_inference(self, length): - with fastmath.use_backend(fastmath.Backend.JAX): - model_fn = functools.partial( - ct.ConfigurableTransformerLM, - vocab_size=16, - d_model=4, - d_ff=8, - n_layers=2, - n_heads=2, - ) - batch_size = 2 - inp = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_equals_predict(inp, model_fn) - - def test_sparse_configurable_transformer_fast_inference(self): - self._test_sparse_fast_inference(length=5) - - def _test_sparse_fast_inference(self, length): - with fastmath.use_backend(fastmath.Backend.JAX): - vocab_size = 16 - d_model = 4 - batch_size = 2 - - encoder_decoder_attention_type = functools.partial( - tl.MultiplicativeConvCausalAttention, - sparsity=2, - length_kernel_size=1, - ) - - model_fn = functools.partial( - ct.ConfigurableTransformer, - input_vocab_size=vocab_size, - d_model=d_model, - d_ff=8, - n_encoder_layers=2, - n_decoder_layers=2, - n_heads=2, - loss_sparsity=2, - ff_sparsity=2, - encoder_decoder_attention_type=encoder_decoder_attention_type, - ff_use_sru=(1, 4), - ) - - inp = np.random.randint(vocab_size, size=(batch_size, length)) - out = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_equals_predict((inp, out), model_fn, seq_tensor=1) - - @parameterized.named_parameters( - ('positional_encoding', None), - 
('fixed_base_positional_encoding', 'fixed-base'), - ('infinite_positional_encoding', 'infinite'), - ('infinite_affine_positional_encoding', 'infinite-affine'), - ('axial_positional_encoding', (2, 16))) - def test_positional_encoder(self, pos_axial_shape): - # dim should divide FixedBasePositionalEncoding.n_digits - batch, length, dim = 2, 32, 8 - input_shape = (batch, length, dim) - vocab_size = 32 - x = np.random.randint(0, vocab_size - 1, input_shape) - # should sum to dim - pos_d_axial_embs = (4, 4) - - positional_encoding = ct.PositionalEncoder( - 'train', dropout=0.1, max_len=length, pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs) - _, _ = positional_encoding.init(shapes.signature(x)) - y = positional_encoding(x) - self.assertEqual(y.shape, input_shape) - - @parameterized.named_parameters( - ('input_vocab_size_only', 32, None), - ('output_vocab_size_only', None, 32), - ('same_input_output_vocab_size', 32, 32), - ('different_input_output_vocab_size', 32, 16), - ) - def test_embedding_and_positional_encodings(self, input_vocab_size, - output_vocab_size): - d_model = 16 - max_len = 32 - batch = 2 - input_shape = (batch, max_len) - output_vocab_size_expected = output_vocab_size or input_vocab_size - x_out = np.random.randint(0, output_vocab_size_expected - 1, input_shape) - if input_vocab_size is None: - x_in = np.random.uniform(size=list(input_shape) + [2]) - else: - x_in = np.random.randint(0, input_vocab_size - 1, input_shape) - - in_encoder, out_encoder, output_vocab_size_result = ( - ct.EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model, - 'train', - 0.1, - [-2], - max_len, - output_vocab_size=output_vocab_size, - pos_axial_shape=None, - pos_d_axial_embs=None)) - - self.assertEqual(output_vocab_size_result, output_vocab_size_expected) - - model_in = tl.Serial(in_encoder) - model_out = tl.Serial(out_encoder) - - model_in.init(shapes.signature(x_in)) - model_out.init(shapes.signature(x_out)) - - y = model_in(x_in) - self.assertEqual(y.shape, input_shape + (d_model,)) - - y = model_out(x_out) - self.assertEqual(y.shape, input_shape + (d_model,)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/hourglass.py b/trax/models/research/hourglass.py index 669cd7e39..b452a89ac 100644 --- a/trax/models/research/hourglass.py +++ b/trax/models/research/hourglass.py @@ -16,297 +16,341 @@ """Hourglass - a hierarchical Transformer language model.""" import trax.layers as tl -from trax.layers.research.rel_attention import get_rel_att_inputs -from trax.layers.research.rel_attention import RelativeAttentionWrapper -from trax.layers.research.resampling import AttentionResampling -from trax.layers.research.resampling import AveragePooling -from trax.layers.research.resampling import FeedForwardBlock -from trax.layers.research.resampling import LinearUpsampling + +from trax.layers.research.rel_attention import ( + RelativeAttentionWrapper, + get_rel_att_inputs, +) +from trax.layers.research.resampling import ( + AttentionResampling, + AveragePooling, + FeedForwardBlock, + LinearUpsampling, +) from trax.models.research.configurable_transformer import ApplyAttentionLayer -def _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation, - context_bias_layer, location_bias_layer, - total_pooling): - """Returns a list of layers. - - The layers implement a Transformer decoder block with relative attention - parametrization. 
- - The input to the block is a pair, (activations, mask), where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. - - Args: - attention_type: attention type. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each block; must be - an activation-type subclass of `Layer`. - context_bias_layer: context bias layer. - location_bias_layer: location bias layer. - total_pooling: The combined pool size of previously used funnel blocks. - - Returns: - A list of layers that maps (activations, att_vecs, mask) to - (activations, att_vecs, mask). - """ - if attention_type == RelativeAttentionWrapper: - attention = RelativeAttentionWrapper( - d_model, - n_heads, - dropout, - mode=mode, - context_bias_layer=context_bias_layer, - location_bias_layer=location_bias_layer, - total_pooling=total_pooling) - else: - attention = ApplyAttentionLayer( - attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=True, - masked=False, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=0, # Disables tl.Chunk in ApplyAttentionLayer. - mode=mode, +def _RelativeDecoderBlock( + attention_type, + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + context_bias_layer, + location_bias_layer, + total_pooling, +): + """Returns a list of layers. + + The layers implement a Transformer decoder block with relative attention + parametrization. + + The input to the block is a pair, (activations, mask), where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. + + Args: + attention_type: attention type. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing + along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful + way to save memory and apply consistent masks to activation vectors at + different sequence positions. + mode: If `'train'`, each block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each block; must be + an activation-type subclass of `Layer`. + context_bias_layer: context bias layer. + location_bias_layer: location bias layer. + total_pooling: The combined pool size of previously used funnel blocks. + + Returns: + A list of layers that maps (activations, att_vecs, mask) to + (activations, att_vecs, mask). 
+ """ + if attention_type == RelativeAttentionWrapper: + attention = RelativeAttentionWrapper( + d_model, + n_heads, + dropout, + mode=mode, + context_bias_layer=context_bias_layer, + location_bias_layer=location_bias_layer, + total_pooling=total_pooling, + ) + else: + attention = ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=True, + masked=False, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=0, # Disables tl.Chunk in ApplyAttentionLayer. + mode=mode, + ) + + feed_forward = FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation ) - feed_forward = FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, - mode, ff_activation) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Residual( # vecs - tl.LayerNorm(), - attention, - _Dropout(), - ), # vecs - tl.Residual( - tl.LayerNorm(), - feed_forward, - _Dropout(), - ), # vecs - ] + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + tl.Residual( # vecs + tl.LayerNorm(), + attention, + _Dropout(), + ), # vecs + tl.Residual( + tl.LayerNorm(), + feed_forward, + _Dropout(), + ), # vecs + ] def _parse_hierarchy(hierarchy_str): # pylint: disable = invalid-name - """Parse hierarchy for Hourglass definition.""" - levels = hierarchy_str.split(' ') - if levels != levels[::-1]: - raise ValueError('Hierarchy is not a palindrome') - layer_level_pairs = [(x.split('@')) for x in levels[:1 + (len(levels) // 2)]] - hierarchy_n_layers = [int(x[0]) for x in layer_level_pairs] - total_sf_per_level = [int(x[1]) for x in layer_level_pairs] - - hierarchy_shorten_factors = [] - for current_sf, prev_sf in zip(total_sf_per_level, - [1] + total_sf_per_level[:-1]): - if current_sf % prev_sf != 0: - raise ValueError( - f'Hierarchy not divisible by previous level: {current_sf}, {prev_sf}') - hierarchy_shorten_factors.append(current_sf // prev_sf) - - return hierarchy_n_layers, hierarchy_shorten_factors - - -def HourglassLM(vocab_size, - d_model=512, - d_ff=2048, - vanilla_layers=(1, 1), - hierarchy='6@3', - n_heads=8, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.FastGelu, - vanilla_attn_type=RelativeAttentionWrapper, - middle_attn_type=RelativeAttentionWrapper, - downsampling_fn=AttentionResampling, - upsampling_fn=AttentionResampling, - attention_downsampling_fn=AveragePooling, - attention_upsampling_fn=LinearUpsampling): - """Returns a hierarchical Transformer language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor should - be an integer in `range(vocab_size)`. These integers typically represent - token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. 
- d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - vanilla_layers: (pre_layers, post_layers) tuple - number of full token-level - Transformer decoder layers before and after shortening. - hierarchy: string - shortening hierarchy, as described in the paper. - Hierarchy levels must form a palindrome, e.g. '1@2 2@6 1@2'. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value when - applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing - along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful - way to save memory and apply consistent masks to activation vectors at - different sequence positions. - mode: str: 'train' or 'eval'. - ff_activation: Type of activation function at the end of each encoder block; - must be an activation-type subclass of `Layer`. - vanilla_attn_type: class: attention class such as SelfAttention to use in - the layers before and after shortening (vanilla layers). - middle_attn_type: class: attention class to use in the middle layers (these - operating on the shortened sequence). - downsampling_fn: function that takes full token-level vectors of length `l` - and transforms them into `l` / `k` vectors, where `k` denotes - `shorten_factor` parameter. - upsampling_fn: function that takes shortened representations of a sequence, - consisting of `l` / `k` vectors and transforms them into full token-level - representations of length `l`. - attention_downsampling_fn: Downsampling function that transforms token-level - vectors into query vectors with reduced length. Necessary only when - AttentionResampling is used as `downsampling_fn`. - attention_upsampling_fn: Upsampling function for AttentionResampling. Valid - only when AttentionResampling is used as a `upsampling_fn`. - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - assert mode != 'predict' # For now, 'predict' mode is unsupported. 
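A hypothetical smoke test for HourglassLM under these constraints (a single `'2@3'` level shortens by a factor of 3, so the input length of 12 divides evenly; all sizes are illustrative):

import numpy as np

from trax import shapes
from trax.models.research import hourglass

model = hourglass.HourglassLM(
    vocab_size=16, d_model=32, d_ff=64, n_heads=2,
    vanilla_layers=(1, 1), hierarchy='2@3')
x = np.ones((2, 12)).astype(np.int32)   # length 12 is divisible by 3
model.init(shapes.signature(x))
y = model(x)                            # (2, 12, 16) log-probabilities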
- hierarchy_n_layers, hierarchy_shorten_factors = _parse_hierarchy(hierarchy) - - token_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - ] - - context_bias_layer, location_bias_layer = get_rel_att_inputs(d_model, n_heads) - - n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers - - def create_decoder_blocks(n_layers, total_pooling, # pylint: disable = invalid-name - attention_type): - decoder_blocks = [ - # pylint: disable=g-complex-comprehension - _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation, - context_bias_layer, location_bias_layer, - total_pooling) for _ in range(n_layers) + """Parse hierarchy for Hourglass definition.""" + levels = hierarchy_str.split(" ") + if levels != levels[::-1]: + raise ValueError("Hierarchy is not a palindrome") + layer_level_pairs = [(x.split("@")) for x in levels[: 1 + (len(levels) // 2)]] + hierarchy_n_layers = [int(x[0]) for x in layer_level_pairs] + total_sf_per_level = [int(x[1]) for x in layer_level_pairs] + + hierarchy_shorten_factors = [] + for current_sf, prev_sf in zip(total_sf_per_level, [1] + total_sf_per_level[:-1]): + if current_sf % prev_sf != 0: + raise ValueError( + f"Hierarchy not divisible by previous level: {current_sf}, {prev_sf}" + ) + hierarchy_shorten_factors.append(current_sf // prev_sf) + + return hierarchy_n_layers, hierarchy_shorten_factors + + +def HourglassLM( + vocab_size, + d_model=512, + d_ff=2048, + vanilla_layers=(1, 1), + hierarchy="6@3", + n_heads=8, + dropout=0.1, + dropout_shared_axes=None, + mode="train", + ff_activation=tl.FastGelu, + vanilla_attn_type=RelativeAttentionWrapper, + middle_attn_type=RelativeAttentionWrapper, + downsampling_fn=AttentionResampling, + upsampling_fn=AttentionResampling, + attention_downsampling_fn=AveragePooling, + attention_upsampling_fn=LinearUpsampling, +): + """Returns a hierarchical Transformer language model. + + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + This model uses only the decoder part of the overall Transformer. + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor should + be an integer in `range(vocab_size)`. These integers typically represent + token IDs from a vocabulary-based tokenizer. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + block. + vanilla_layers: (pre_layers, post_layers) tuple - number of full token-level + Transformer decoder layers before and after shortening. + hierarchy: string - shortening hierarchy, as described in the paper. + Hierarchy levels must form a palindrome, e.g. '1@2 2@6 1@2'. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value when + applying dropout within an encoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. 
Sharing
+        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
+        way to save memory and apply consistent masks to activation vectors at
+        different sequence positions.
+      mode: str: 'train' or 'eval'.
+      ff_activation: Type of activation function at the end of each encoder block;
+        must be an activation-type subclass of `Layer`.
+      vanilla_attn_type: class: attention class such as SelfAttention to use in
+        the layers before and after shortening (vanilla layers).
+      middle_attn_type: class: attention class to use in the middle layers (these
+        operate on the shortened sequence).
+      downsampling_fn: function that takes full token-level vectors of length `l`
+        and transforms them into `l` / `k` vectors, where `k` denotes the
+        `shorten_factor` parameter.
+      upsampling_fn: function that takes shortened representations of a sequence,
+        consisting of `l` / `k` vectors and transforms them into full token-level
+        representations of length `l`.
+      attention_downsampling_fn: Downsampling function that transforms token-level
+        vectors into query vectors with reduced length. Necessary only when
+        AttentionResampling is used as `downsampling_fn`.
+      attention_upsampling_fn: Upsampling function for AttentionResampling. Valid
+        only when AttentionResampling is used as an `upsampling_fn`.
+
+    Returns:
+      A Transformer language model as a layer that maps from a tensor of tokens
+      to activations over a vocab set.
+    """
+    assert mode != "predict"  # For now, 'predict' mode is unsupported.
+    hierarchy_n_layers, hierarchy_shorten_factors = _parse_hierarchy(hierarchy)
+
+    token_encoder = [
+        tl.Embedding(vocab_size, d_model),
+        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
     ]
-  return decoder_blocks + [tl.LayerNorm()]
-
-  def create_hourglass_valley(rest_shorten_factors, rest_n_funnel_blocks,  # pylint: disable = invalid-name
-                              current_total_pooling):
-    assert rest_shorten_factors
-    assert len(rest_shorten_factors) == len(rest_n_funnel_blocks)
-
-    current_sf = rest_shorten_factors[0]
-    current_n_layers = rest_n_funnel_blocks[0]
-
-    shortening_layer = downsampling_fn(
-        current_sf,
-        d_model,
-        is_upsampling=False,
-        d_ff=d_ff,
-        n_heads=n_heads,
-        dropout=dropout,
-        dropout_shared_axes=dropout_shared_axes,
-        mode=mode,
-        ff_activation=ff_activation,
-        context_bias_layer=context_bias_layer,
-        location_bias_layer=location_bias_layer,
-        total_pooling=current_total_pooling,
-        resampling_fn=attention_downsampling_fn)
-
-    upsampling_layer = upsampling_fn(
-        current_sf,
-        d_model=d_model,
-        is_upsampling=True,
-        d_ff=d_ff,
-        n_heads=n_heads,
-        dropout=dropout,
-        dropout_shared_axes=dropout_shared_axes,
-        mode=mode,
-        ff_activation=ff_activation,
-        context_bias_layer=context_bias_layer,
-        location_bias_layer=location_bias_layer,
-        total_pooling=current_total_pooling,
-        resampling_fn=attention_upsampling_fn)
-
-    if len(rest_shorten_factors) > 1:  # we need to go deeper again
-      pre_stage_blocks = create_decoder_blocks(
-          current_n_layers, current_total_pooling * current_sf,
-          middle_attn_type)
-
-      post_stage_blocks = create_decoder_blocks(
-          current_n_layers, current_total_pooling * current_sf,
-          middle_attn_type)
-
-      return [
-          tl.Dup(),
-          tl.ShiftRight(current_sf - 1, mode=mode), shortening_layer,
-          pre_stage_blocks, *create_hourglass_valley(
-              rest_shorten_factors[1:], rest_n_funnel_blocks[1:],
-              current_total_pooling * current_sf), post_stage_blocks,
-          upsampling_layer,
-          tl.LayerNorm(),
-          tl.Add()
-      ]
-    else:
-      blocks = create_decoder_blocks(current_n_layers,
-                                     current_total_pooling * 
current_sf, - middle_attn_type) - - return [ - tl.Dup(), - tl.ShiftRight(current_sf - 1), shortening_layer, blocks, - upsampling_layer, - tl.LayerNorm(), - tl.Add() - ] - - pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks, 1, - vanilla_attn_type) - - post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks, 1, - vanilla_attn_type) - - valley = create_hourglass_valley(hierarchy_shorten_factors, - hierarchy_n_layers, 1) - - # Assemble and return the model. - return tl.Serial( # tokens (or chunked tuple of tokens) - tl.ShiftRight(mode=mode), # toks - token_encoder, # vecs - pre_decoder_blocks, # vecs - valley, # shortened vecs - post_decoder_blocks, # vecs - tl.Dense(vocab_size), # vecs - ) + + context_bias_layer, location_bias_layer = get_rel_att_inputs(d_model, n_heads) + + n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers + + def create_decoder_blocks( + n_layers, + total_pooling, # pylint: disable = invalid-name + attention_type, + ): + decoder_blocks = [ + # pylint: disable=g-complex-comprehension + _RelativeDecoderBlock( + attention_type, + d_model, + d_ff, + n_heads, + dropout, + dropout_shared_axes, + mode, + ff_activation, + context_bias_layer, + location_bias_layer, + total_pooling, + ) + for _ in range(n_layers) + ] + return decoder_blocks + [tl.LayerNorm()] + + def create_hourglass_valley( + rest_shorten_factors, + rest_n_funnel_blocks, # pylint: disable = invalid-name + current_total_pooling, + ): + assert rest_shorten_factors + assert len(rest_shorten_factors) == len(rest_n_funnel_blocks) + + current_sf = rest_shorten_factors[0] + current_n_layers = rest_n_funnel_blocks[0] + + shortening_layer = downsampling_fn( + current_sf, + d_model, + is_upsampling=False, + d_ff=d_ff, + n_heads=n_heads, + dropout=dropout, + dropout_shared_axes=dropout_shared_axes, + mode=mode, + ff_activation=ff_activation, + context_bias_layer=context_bias_layer, + location_bias_layer=location_bias_layer, + total_pooling=current_total_pooling, + resampling_fn=attention_downsampling_fn, + ) + + upsampling_layer = upsampling_fn( + current_sf, + d_model=d_model, + is_upsampling=True, + d_ff=d_ff, + n_heads=n_heads, + dropout=dropout, + dropout_shared_axes=dropout_shared_axes, + mode=mode, + ff_activation=ff_activation, + context_bias_layer=context_bias_layer, + location_bias_layer=location_bias_layer, + total_pooling=current_total_pooling, + resampling_fn=attention_upsampling_fn, + ) + + if len(rest_shorten_factors) > 1: # we need to go deeper again + pre_stage_blocks = create_decoder_blocks( + current_n_layers, current_total_pooling * current_sf, middle_attn_type + ) + + post_stage_blocks = create_decoder_blocks( + current_n_layers, current_total_pooling * current_sf, middle_attn_type + ) + + return [ + tl.Dup(), + tl.ShiftRight(current_sf - 1, mode=mode), + shortening_layer, + pre_stage_blocks, + *create_hourglass_valley( + rest_shorten_factors[1:], + rest_n_funnel_blocks[1:], + current_total_pooling * current_sf, + ), + post_stage_blocks, + upsampling_layer, + tl.LayerNorm(), + tl.Add(), + ] + else: + blocks = create_decoder_blocks( + current_n_layers, current_total_pooling * current_sf, middle_attn_type + ) + + return [ + tl.Dup(), + tl.ShiftRight(current_sf - 1), + shortening_layer, + blocks, + upsampling_layer, + tl.LayerNorm(), + tl.Add(), + ] + + pre_decoder_blocks = create_decoder_blocks( + n_pre_decoder_blocks, 1, vanilla_attn_type + ) + + post_decoder_blocks = create_decoder_blocks( + n_post_decoder_blocks, 1, vanilla_attn_type + ) + + valley = 
create_hourglass_valley(hierarchy_shorten_factors, hierarchy_n_layers, 1) + + # Assemble and return the model. + return tl.Serial( # tokens (or chunked tuple of tokens) + tl.ShiftRight(mode=mode), # toks + token_encoder, # vecs + pre_decoder_blocks, # vecs + valley, # shortened vecs + post_decoder_blocks, # vecs + tl.Dense(vocab_size), # vecs + ) diff --git a/trax/models/research/hourglass_test.py b/trax/models/research/hourglass_test.py deleted file mode 100644 index 9329c109e..000000000 --- a/trax/models/research/hourglass_test.py +++ /dev/null @@ -1,145 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Hourglass model.""" - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import jax -import numpy as np -from trax import fastmath -from trax import layers as tl -from trax import shapes -import trax.layers.research.resampling as resampling -import trax.models.research.hourglass as hourglass - - -class HourglassTest(parameterized.TestCase): - - def _check_forward_shape(self, model, input_shape, output_vocab_size): - x = np.ones(input_shape).astype(np.int32) - model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (*input_shape, output_vocab_size)) - - def test_hourglass_lm_forward_shape(self): - d_model = 16 - vocab_size = 7 - model = hourglass.HourglassLM( - vocab_size, - hierarchy='2@3 2@6 2@3', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - batch_size, seq_len = 3, 24 - self._check_forward_shape(model, - input_shape=(batch_size, seq_len), - output_vocab_size=vocab_size) - - def test_lsh_attention_in_vanilla(self): - d_model = 16 - vocab_size = 7 - - gin.bind_parameter('PureLSHSelfAttentionWrapper.pure_lsh_implementation', - tl.PureLSHSelfAttention) - gin.bind_parameter('PureLSHSelfAttention.chunk_len', 2) - - model = hourglass.HourglassLM( - vocab_size, - hierarchy='2@3', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - vanilla_attn_type=tl.PureLSHSelfAttentionWrapper, - downsampling_fn=resampling.LinearPooling, - upsampling_fn=resampling.LinearUpsampling, - ) - - batch_size, seq_len = 3, 12 - self._check_forward_shape( - model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size) - - def _test_autoregressive_property(self, model, input_shape, - output_vocab_size): - rng_1 = jax.random.PRNGKey(0) - rng_2 = jax.random.PRNGKey(1) - - def _get_output_logits(unitialized_eval_model: tl.Layer, x): - input_signature = shapes.signature(x) - unitialized_eval_model.init(input_signature, rng=rng_1, use_cache=False) - - output_logits, *_ = unitialized_eval_model(x, rng=rng_1) - return output_logits - - def check_autoregressive_property(model): - with fastmath.use_backend(fastmath.Backend.JAX): - x_1 = jax.random.randint(rng_1, input_shape, 0, output_vocab_size) - y_1 = _get_output_logits(model, x_1) - - x_2 = jax.random.randint(rng_2, input_shape, 0, output_vocab_size) - - for i in 
range(input_shape[1]): - masked_x_2 = np.concatenate((x_1[:, :i], x_2[:, i:]), axis=1) - - y_2 = _get_output_logits(model, masked_x_2) - self.assertEqual(y_2.shape[0], input_shape[1]) - np.testing.assert_array_almost_equal(y_1[:i + 1], y_2[:i + 1]) - - check_autoregressive_property(model) - - def test_hourglass_lm_autoregressive_property(self): - d_model = 8 - vocab_size = 26 - - model_single_stage = hourglass.HourglassLM( - vocab_size, - hierarchy='2@4', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - model_multi_stage = hourglass.HourglassLM( - vocab_size, - hierarchy='2@3 2@6 2@3', - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - input_shape = (1, 12) - self._test_autoregressive_property(model_single_stage, input_shape, - output_vocab_size=vocab_size) - self._test_autoregressive_property(model_multi_stage, input_shape, - output_vocab_size=vocab_size) - - def test_parse_hourglass_hierarchy(self): - self.assertEqual(hourglass._parse_hierarchy('6@3'), ([6], [3])) - self.assertEqual(hourglass._parse_hierarchy('3@2 2@6 5@24 2@6 3@2'), ( - [3, 2, 5], [2, 3, 4] - )) - self.assertRaises(ValueError, hourglass._parse_hierarchy, '1@2 2@3 1@2') - self.assertRaises(ValueError, hourglass._parse_hierarchy, '1@2 2@3') - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/layerdrop_transformer.py b/trax/models/research/layerdrop_transformer.py index 0709fad38..57364f7a6 100644 --- a/trax/models/research/layerdrop_transformer.py +++ b/trax/models/research/layerdrop_transformer.py @@ -24,278 +24,303 @@ def LargerThan(val): - """Checks if the input is larger than a certain value.""" - return tl.Fn('LargerThan', lambda x: x > val) - - -@assert_shape('...s->...sv') -def SkippingTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - mode='train', - ff_activation=tl.Relu, - skip_fraction=0.4): - """Returns a Skipping Transformer language model. - - The input to the model is a tensor of tokens. (This model uses only the - decoder part of the overall Transformer.) - - Args: - vocab_size: int: vocab size - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_layers: int: number of encoder/decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference - ff_activation: the non-linearity in feed-forward layer - skip_fraction: fraction of times to skip some layers - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. 
- """ - embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.PositionalEncoding(max_len=max_len, mode=mode), - ] - - @assert_shape('...sd,->...sd,') - def ConditionedBlock(current_layer_num): - return tl.Serial( - # stack: embedding, n_layers_to_keep - tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep - tl.Cond( - # if n_layers_to_keep > current_layer_num - LargerThan(float(current_layer_num)), - # then: run block - tl.Serial(transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access - d_model, d_ff, n_heads, dropout, [], mode, ff_activation)), - # else: run noop - tl.Serial() + """Checks if the input is larger than a certain value.""" + return tl.Fn("LargerThan", lambda x: x > val) + + +@assert_shape("...s->...sv") +def SkippingTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode="train", + ff_activation=tl.Relu, + skip_fraction=0.4, +): + """Returns a Skipping Transformer language model. + + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) + + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + ff_activation: the non-linearity in feed-forward layer + skip_fraction: fraction of times to skip some layers + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. 
+ """ + embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + @assert_shape("...sd,->...sd,") + def ConditionedBlock(current_layer_num): + return tl.Serial( + # stack: embedding, n_layers_to_keep + tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep + tl.Cond( + # if n_layers_to_keep > current_layer_num + LargerThan(float(current_layer_num)), + # then: run block + tl.Serial( + transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access + d_model, d_ff, n_heads, dropout, [], mode, ff_activation + ), + ), + # else: run noop + tl.Serial(), ) - # stack: embedding, n_layers_to_keep + # stack: embedding, n_layers_to_keep ) - if mode == 'train': - if skip_fraction == 0.0: - minimum_layers = float(n_layers) - maximum_layers = float(n_layers) + if mode == "train": + if skip_fraction == 0.0: + minimum_layers = float(n_layers) + maximum_layers = float(n_layers) + else: + minimum_layers = 0.0 + maximum_layers = float(n_layers) / skip_fraction else: - minimum_layers = 0.0 - maximum_layers = float(n_layers) / skip_fraction - else: - minimum_layers = maximum_layers = float(n_layers) - - return tl.Serial( - tl.ShiftRight(mode=mode), - embedder, - # stack: embedding - tl.RandomUniform(minimum_layers, maximum_layers, sync=True), - # stack: n_layers_to_keep, embedding - tl.Swap(), - # stack: embedding, n_layers_to_keep - [ConditionedBlock(i) for i in range(n_layers)], - # stack: embedding, n_layers_to_keep - tl.AssertShape('...sd,'), - tl.Select([0], n_in=2), # stack: embedding - tl.AssertShape('...sd'), - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -@assert_shape('...s->...sv') -def EveryOtherLayerDropTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - mode='train', - ff_activation=tl.Relu, - skip_mode='even', - skip_fraction=0.5, - eval_skip_fraction=0.0): - """Returns an "EveryOther" LayerDrop Transformer language model. - - During each training step it either runs all layers, or skips a subset of - layers. This subset is the same every time, and it is specified by - "skip_mode". - The input to the model is a tensor of tokens. (This model uses only the - decoder part of the overall Transformer.) - - Args: - vocab_size: int: vocab size - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_layers: int: number of encoder/decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference - ff_activation: the non-linearity in feed-forward layer - skip_mode: which layers to skip when skipping: even/odd/1half/2half. - skip_fraction: fraction of times to skip layers - eval_skip_fraction: fraction of times to skip layers during eval - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.PositionalEncoding(max_len=max_len, mode=mode), - ] - - if mode == 'train': - pass - else: - skip_fraction = eval_skip_fraction - - skip_mode_funs = { # which layers should be skipped? 
- 'even': (lambda num: num%2 == 0), # 0th layer is even - 'odd': (lambda num: num%2 == 1), - '1half': (lambda num: num < (n_layers/2)), - '2half': (lambda num: num >= (n_layers/2)), - } - - skip_mode_fun = skip_mode_funs[skip_mode] - - @assert_shape('...sd,->...sd,') - def ConditionedBlock(current_layer_num): + minimum_layers = maximum_layers = float(n_layers) + return tl.Serial( + tl.ShiftRight(mode=mode), + embedder, + # stack: embedding + tl.RandomUniform(minimum_layers, maximum_layers, sync=True), + # stack: n_layers_to_keep, embedding + tl.Swap(), # stack: embedding, n_layers_to_keep - tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep - tl.Cond( - # if random() > skip_fraction OR layer not in skip_mode ... - LargerThan(skip_fraction if skip_mode_fun(current_layer_num) - else 0.0), - # then: run block - tl.Serial(transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access - d_model, d_ff, n_heads, dropout, [], mode, ff_activation)) - # else: noop (implicit) - ) + [ConditionedBlock(i) for i in range(n_layers)], # stack: embedding, n_layers_to_keep + tl.AssertShape("...sd,"), + tl.Select([0], n_in=2), # stack: embedding + tl.AssertShape("...sd"), + tl.LayerNorm(), + tl.Dense(vocab_size), + ) + + +@assert_shape("...s->...sv") +def EveryOtherLayerDropTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode="train", + ff_activation=tl.Relu, + skip_mode="even", + skip_fraction=0.5, + eval_skip_fraction=0.0, +): + """Returns an "EveryOther" LayerDrop Transformer language model. + + During each training step it either runs all layers, or skips a subset of + layers. This subset is the same every time, and it is specified by + "skip_mode". + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) + + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + ff_activation: the non-linearity in feed-forward layer + skip_mode: which layers to skip when skipping: even/odd/1half/2half. + skip_fraction: fraction of times to skip layers + eval_skip_fraction: fraction of times to skip layers during eval + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + if mode == "train": + pass + else: + skip_fraction = eval_skip_fraction + + skip_mode_funs = { # which layers should be skipped? + "even": (lambda num: num % 2 == 0), # 0th layer is even + "odd": (lambda num: num % 2 == 1), + "1half": (lambda num: num < (n_layers / 2)), + "2half": (lambda num: num >= (n_layers / 2)), + } + + skip_mode_fun = skip_mode_funs[skip_mode] + + @assert_shape("...sd,->...sd,") + def ConditionedBlock(current_layer_num): + return tl.Serial( + # stack: embedding, n_layers_to_keep + tl.Select([1, 0, 1]), # n_layers_to_keep, embedding, n_layers_to_keep + tl.Cond( + # if random() > skip_fraction OR layer not in skip_mode ... 
+ LargerThan(skip_fraction if skip_mode_fun(current_layer_num) else 0.0), + # then: run block + tl.Serial( + transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access + d_model, d_ff, n_heads, dropout, [], mode, ff_activation + ) + ) + # else: noop (implicit) + ) + # stack: embedding, n_layers_to_keep ) - return tl.Serial( - tl.ShiftRight(mode=mode), - embedder, - # stack: embedding - tl.RandomUniform(0., 1., sync=True), - # stack: n_layers_to_keep, embedding - tl.Swap(), - # stack: embedding, n_layers_to_keep - [ConditionedBlock(i) for i in range(n_layers)], - # stack: embedding, n_layers_to_keep - tl.Select([0], n_in=2), # stack: embedding - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -@assert_shape('...s->...sv') -def LayerDropTransformerLM(vocab_size, - d_model=512, - d_ff=2048, - n_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - mode='train', - ff_activation=tl.Relu, - skip_fraction=0.4, - eval_skip_fraction='every_other'): - """Returns a LayerDrop Transformer language model. - - Based on Fan, Grave, Joulin 2019, https://arxiv.org/abs/1909.11556 . - - The input to the model is a tensor of tokens. (This model uses only the - decoder part of the overall Transformer.) - - Args: - vocab_size: int: vocab size - d_model: int: depth of embedding - d_ff: int: depth of feed-forward layer - n_layers: int: number of encoder/decoder layers - n_heads: int: number of attention heads - dropout: float: dropout rate (how much to drop out) - max_len: int: maximum symbol length for positional encoding - mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference - ff_activation: the non-linearity in feed-forward layer - skip_fraction: probability of skipping a layer; it can be a single - probability or a list of probabilities different for each layer - eval_skip_fraction: probability of skipping a layer during eval; it can be a - single probability, or a list of probabilities different for each layer, - or a string "every other" implementing a strategy from original paper - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - embedder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.PositionalEncoding(max_len=max_len, mode=mode), - ] - - if not isinstance(skip_fraction, (list, tuple)): - # If we don't get a list of skip_fractions we use the same skip_fraction - # for each layer. - skip_fraction = [skip_fraction for i in range(n_layers)] - if len(skip_fraction) != n_layers: - raise ValueError('n_layers ({}) must be equal to len(skip_fraction) ({})' - .format(n_layers, len(skip_fraction))) - - if eval_skip_fraction == 'every_other': - # 100% skipping for even-numbered layers; 0% for odd-numbered layers. - eval_skip_fraction = [(1.0 if i % int(1./skip_fraction[i]) == 0 else 0.0) - if skip_fraction[i] != 0 else 0.0 - for i in range(n_layers)] - if eval_skip_fraction == 'same': - # Same skip_fraction as in training. - eval_skip_fraction = skip_fraction - if not isinstance(eval_skip_fraction, (list, tuple)): - # If we don't get a list of eval_skip_fractions we use the same - # eval_skip_fraction for each layer. 
- eval_skip_fraction = [eval_skip_fraction for i in range(n_layers)] - if len(eval_skip_fraction) != n_layers: - raise ValueError( - 'n_layers ({}) must be equal to len(eval_skip_fraction) ({})' - .format(n_layers, len(eval_skip_fraction))) - - @assert_shape('...sd->...sd') - def ConditionedBlock(current_layer_num): return tl.Serial( + tl.ShiftRight(mode=mode), + embedder, # stack: embedding - tl.RandomUniform(0., 1, sync=True), - # stack: random_uniform, embedding - tl.Cond( - # if random_uniform > skip_fraction - LargerThan(skip_fraction[current_layer_num] if mode == 'train' - else eval_skip_fraction[current_layer_num]), - # then: run block - tl.Serial(transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access - d_model, d_ff, n_heads, dropout, [], mode, ff_activation)), - # else: run noop - tl.Serial() + tl.RandomUniform(0.0, 1.0, sync=True), + # stack: n_layers_to_keep, embedding + tl.Swap(), + # stack: embedding, n_layers_to_keep + [ConditionedBlock(i) for i in range(n_layers)], + # stack: embedding, n_layers_to_keep + tl.Select([0], n_in=2), # stack: embedding + tl.LayerNorm(), + tl.Dense(vocab_size), + ) + + +@assert_shape("...s->...sv") +def LayerDropTransformerLM( + vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode="train", + ff_activation=tl.Relu, + skip_fraction=0.4, + eval_skip_fraction="every_other", +): + """Returns a LayerDrop Transformer language model. + + Based on Fan, Grave, Joulin 2019, https://arxiv.org/abs/1909.11556 . + + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) + + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + ff_activation: the non-linearity in feed-forward layer + skip_fraction: probability of skipping a layer; it can be a single + probability or a list of probabilities different for each layer + eval_skip_fraction: probability of skipping a layer during eval; it can be a + single probability, or a list of probabilities different for each layer, + or a string "every other" implementing a strategy from original paper + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + embedder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + if not isinstance(skip_fraction, (list, tuple)): + # If we don't get a list of skip_fractions we use the same skip_fraction + # for each layer. + skip_fraction = [skip_fraction for i in range(n_layers)] + if len(skip_fraction) != n_layers: + raise ValueError( + "n_layers ({}) must be equal to len(skip_fraction) ({})".format( + n_layers, len(skip_fraction) + ) + ) + + if eval_skip_fraction == "every_other": + # 100% skipping for even-numbered layers; 0% for odd-numbered layers. + eval_skip_fraction = [ + (1.0 if i % int(1.0 / skip_fraction[i]) == 0 else 0.0) + if skip_fraction[i] != 0 + else 0.0 + for i in range(n_layers) + ] + if eval_skip_fraction == "same": + # Same skip_fraction as in training. 
+ eval_skip_fraction = skip_fraction + if not isinstance(eval_skip_fraction, (list, tuple)): + # If we don't get a list of eval_skip_fractions we use the same + # eval_skip_fraction for each layer. + eval_skip_fraction = [eval_skip_fraction for i in range(n_layers)] + if len(eval_skip_fraction) != n_layers: + raise ValueError( + "n_layers ({}) must be equal to len(eval_skip_fraction) ({})".format( + n_layers, len(eval_skip_fraction) ) - # stack: embedding ) - return tl.Serial( - tl.ShiftRight(mode=mode), - embedder, - [ConditionedBlock(i) for i in range(n_layers)], - tl.LayerNorm(), - tl.Dense(vocab_size), - ) + @assert_shape("...sd->...sd") + def ConditionedBlock(current_layer_num): + return tl.Serial( + # stack: embedding + tl.RandomUniform(0.0, 1, sync=True), + # stack: random_uniform, embedding + tl.Cond( + # if random_uniform > skip_fraction + LargerThan( + skip_fraction[current_layer_num] + if mode == "train" + else eval_skip_fraction[current_layer_num] + ), + # then: run block + tl.Serial( + transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access + d_model, d_ff, n_heads, dropout, [], mode, ff_activation + ) + ), + # else: run noop + tl.Serial(), + ) + # stack: embedding + ) + + return tl.Serial( + tl.ShiftRight(mode=mode), + embedder, + [ConditionedBlock(i) for i in range(n_layers)], + tl.LayerNorm(), + tl.Dense(vocab_size), + ) diff --git a/trax/models/research/layerdrop_transformer_test.py b/trax/models/research/layerdrop_transformer_test.py deleted file mode 100644 index 2fe41fe07..000000000 --- a/trax/models/research/layerdrop_transformer_test.py +++ /dev/null @@ -1,81 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
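
The `eval_skip_fraction="every_other"` branch above is easiest to check with concrete numbers. A small worked sketch with illustrative values (four layers, uniform training skip rate of 0.5):

skip_fraction = [0.5, 0.5, 0.5, 0.5]
eval_skip_fraction = [
    (1.0 if i % int(1.0 / skip_fraction[i]) == 0 else 0.0)
    if skip_fraction[i] != 0
    else 0.0
    for i in range(len(skip_fraction))
]
# -> [1.0, 0.0, 1.0, 0.0]: at eval time the even-numbered layers are always
# skipped and the odd-numbered layers always run, halving eval-time depth as
# in Fan, Grave & Joulin (2019).
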
- -"""Tests for Reformer models.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models.research import layerdrop_transformer - - -class SkippingTransformerTest(absltest.TestCase): - - def test_skipping_transformer_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.SkippingTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -class LayerDropTransformerTest(absltest.TestCase): - - def test_layerdrop_transformer_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.LayerDropTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - def test_layerdrop_layerwise_skip_fraction(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.LayerDropTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16, - skip_fraction=[0.2, 0.8]) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -class EveryOtherLayerDropTransformerTest(absltest.TestCase): - - def test_everyother_layerdrop_transformer_forward(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = layerdrop_transformer.EveryOtherLayerDropTransformerLM( - vocab_size, d_model=16, d_ff=32, n_layers=2, n_heads=2, max_len=16, - skip_mode='1half') - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/predict_terraformer.py b/trax/models/research/predict_terraformer.py index 79df866c5..2fc8edc97 100644 --- a/trax/models/research/predict_terraformer.py +++ b/trax/models/research/predict_terraformer.py @@ -40,99 +40,28 @@ """ - -import sys -import time - -import os -import random -import time -import numpy as np - -import trax -from trax import layers as tl -from trax import fastmath -from trax.fastmath import numpy as jnp -from trax.supervised import training -from trax.layers.assert_shape import assert_shape - - -import copy import functools -import gc -import os -import time -from jax.config import config -import numpy as np -import psutil -from tensorflow.compat.v2 import test -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import shapes -from trax.supervised import decoding -import gin - - -# from colabtools import adhoc_import -import json -import gc -import jax import numpy as np -import os -import time -import gin - import tensorflow_datasets as tfds +import trax -# from colabtools import adhoc_import -import functools - -from trax.data import tf_inputs -import tensorflow_datasets as tfds 
-from t5.data import preprocessors as t5_processors -import t5.data - -from trax import data from trax import layers as tl from trax import models -from trax import optimizers -from trax.data import inputs -from trax.supervised import lr_schedules -from trax.supervised import trainer_lib -from trax.rl import serialization_utils -from trax.rl import space_serializer -import math from trax.fastmath import numpy as numpy_math -import trax - - -import numpy as np - -from trax import fastmath -from trax.fastmath import numpy as jnp -from trax.layers import base -from trax.layers import combinators as cb -from trax.layers import core -from trax.layers import initializers as init -from trax.layers.assert_shape import assert_shape -from trax.layers.base import Fn -from trax.layers.research import sparsity +from trax.learning.supervised import decoding -import functools -from trax import layers as tl -from trax.fastmath import numpy as jnp -from trax.models.reformer import reformer -from trax.models.research import configurable_transformer as ct -from trax.models.research import transformer2 as t2 +# from colabtools import adhoc_import +# from colabtools import adhoc_import ##### og_PositionalEncoding = tl.PositionalEncoding -trax.layers.attention.PositionalEncoding = functools.partial(og_PositionalEncoding, d_feature=64) +trax.layers.attention.PositionalEncoding = functools.partial( + og_PositionalEncoding, d_feature=64 +) trax.layers.PositionalEncoding = functools.partial(og_PositionalEncoding, d_feature=64) tl.PositionalEncoding = functools.partial(og_PositionalEncoding, d_feature=64) @@ -141,22 +70,24 @@ import gin + gin.enter_interactive_mode() def model_configure(*args, **kwargs): - kwargs['module'] = 'trax.models' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.models" + return gin.external_configurable(*args, **kwargs) + #### -xm2a_main = '/tmp/Terraformer/model_200000.pkl.gz' -xm2a_weights = '/tmp/Terraformer/model_200000.weights.npy.gz' -xm2a_opt_slots = '/tmp/Terraformer/model_200000.opt_slots0.npy.gz' -xm2a_config = '/tmp/Terraformer/config.gin' +xm2a_main = "/tmp/Terraformer/model_200000.pkl.gz" +xm2a_weights = "/tmp/Terraformer/model_200000.weights.npy.gz" +xm2a_opt_slots = "/tmp/Terraformer/model_200000.opt_slots0.npy.gz" +xm2a_config = "/tmp/Terraformer/config.gin" -VOCAB_FILE = 'en_16k.subword' -VOCAB_DIR = '/tmp/Terraformer' +VOCAB_FILE = "en_16k.subword" +VOCAB_DIR = "/tmp/Terraformer" #### @@ -169,31 +100,35 @@ def model_configure(*args, **kwargs): # ) og_DotProductCausalAttention = trax.layers.attention.DotProductCausalAttention trax.layers.attention.DotProductCausalAttention = functools.partial( - og_DotProductCausalAttention, max_inference_length=16384, + og_DotProductCausalAttention, + max_inference_length=16384, ) # gin_config.append( # '\nMixedLSHSelfAttention.std_length=16384' # ) -gin_config = [l for l in gin_config if 'mira' not in l] -gin_config = [l for l in gin_config if 'okenize' not in l] # tokenize +gin_config = [l for l in gin_config if "mira" not in l] +gin_config = [l for l in gin_config if "okenize" not in l] # tokenize -gin_config = ''.join(gin_config) +gin_config = "".join(gin_config) gin.parse_config(gin_config) -gin.operative_config_str().split('\n') +gin.operative_config_str().split("\n") print(gin_config) #### + def model(mode): - return models.ConfigurableTerraformer(mode=mode) + return models.ConfigurableTerraformer(mode=mode) + # #### -padding_fun = trax.data.PadToLength(len_map={0: 15*1024, 1: 15*1024, 2: 15*1024}, - 
pad_value = {0: 0, 1: 0, 2:0}) +padding_fun = trax.data.PadToLength( + len_map={0: 15 * 1024, 1: 15 * 1024, 2: 15 * 1024}, pad_value={0: 0, 1: 0, 2: 0} +) # padding_fun = lambda x: x # padding_fun = trax.data.PadToLength(len_map={0: 128, 1: 128, 2:128}, pad_value={0: 0, 1: 0, 2: 0}, multiple=True) @@ -202,48 +137,67 @@ def model(mode): dataset = tfds.summarization.scientific_papers.ScientificPapers() -valid = tfds.load(name='scientific_papers/arxiv:1.1.1')['test'] +valid = tfds.load(name="scientific_papers/arxiv:1.1.1")["test"] index = 0 xarts = [] for x in valid: - xarts.append(x) - index += 1 - if index == 3: - break + xarts.append(x) + index += 1 + if index == 3: + break model_file = xm2a_main shape11 = trax.shapes.ShapeDtype((1, 1), dtype=numpy_math.int32) -shape1l = trax.shapes.ShapeDtype((1, 15*1024), dtype=numpy_math.int32) +shape1l = trax.shapes.ShapeDtype((1, 15 * 1024), dtype=numpy_math.int32) with trax.fastmath.use_backend(trax.fastmath.Backend.JAX): - model = model(mode='eval') - model.init_from_file(model_file, weights_only=True) - # in mode='predict' use input_signature=(shape1l, shape11) - old_state = model.state + model = model(mode="eval") + model.init_from_file(model_file, weights_only=True) + # in mode='predict' use input_signature=(shape1l, shape11) + old_state = model.state # Decode the first article -xart = xarts[2]['article'] +xart = xarts[2]["article"] question = xart.numpy().decode() # print(question[:512]) -tokenized = next(padding_fun(trax.data.tokenize([question,], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, n_reserved_ids=100))) +tokenized = next( + padding_fun( + trax.data.tokenize( + [ + question, + ], + vocab_file=VOCAB_FILE, + vocab_dir=VOCAB_DIR, + n_reserved_ids=100, + ) + ) +) + def detokenize(x): - return trax.data.detokenize(x, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, - n_reserved_ids=100) + return trax.data.detokenize( + x, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, n_reserved_ids=100 + ) + with trax.fastmath.use_backend(trax.fastmath.Backend.JAX): - model.state = old_state - counter, tokens, max_length = 0, [], 30 - for token in decoding.autoregressive_sample_stream( - model, tokenized[None, :15*1024], batch_size=1, temperature=0.0, - eval_mode=True, eval_min_length=1024): - print(f'Token {counter}: "{detokenize(token)}" {token}') - tokens.append(token[:, None]) - counter += 1 - if counter > max_length: - break - tokens = np.concatenate(tokens, axis=1) - print(tokens) - print(detokenize(tokens[0])) + model.state = old_state + counter, tokens, max_length = 0, [], 30 + for token in decoding.autoregressive_sample_stream( + model, + tokenized[None, : 15 * 1024], + batch_size=1, + temperature=0.0, + eval_mode=True, + eval_min_length=1024, + ): + print(f'Token {counter}: "{detokenize(token)}" {token}') + tokens.append(token[:, None]) + counter += 1 + if counter > max_length: + break + tokens = np.concatenate(tokens, axis=1) + print(tokens) + print(detokenize(tokens[0])) diff --git a/trax/models/research/rezero_test.py b/trax/models/research/rezero_test.py deleted file mode 100644 index d6be6d32e..000000000 --- a/trax/models/research/rezero_test.py +++ /dev/null @@ -1,67 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for ReZero models.""" - -from absl.testing import absltest -import numpy as np - -from trax import layers as tl -from trax import shapes -from trax.models.research import rezero - - -class ResidualZeroTest(absltest.TestCase): - - def test_residual_layer_forward(self): - """Tests that the forward pass runs and returns the expected shape.""" - model = rezero.ResidualZero(tl.Dense(5)) - x = [np.arange(5).astype(np.float32)] - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.tolist(), [0., 1., 2., 3., 4.]) - - -class ReZeroTransformerLMTest(absltest.TestCase): - - def test_rezero_lm_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = rezero.ReZeroTransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -class ReZeroTransformerTest(absltest.TestCase): - - def test_rezero_forward_shape(self): - """Tests that the forward pass runs and returns the expected shape.""" - vocab_size = 16 - model = rezero.ReZeroTransformer( - vocab_size, d_model=32, d_ff=64, n_encoder_layers=2, n_decoder_layers=2, - n_heads=2, max_len=16) - xs = [np.ones((1, 8)).astype(np.int32), - np.ones((1, 8)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - ys = model(xs) - self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/rse.py b/trax/models/research/rse.py index b5c51c247..f97cb193c 100644 --- a/trax/models/research/rse.py +++ b/trax/models/research/rse.py @@ -19,6 +19,7 @@ """ import numpy as np + from trax import fastmath from trax import layers as tl from trax.fastmath import numpy as jnp @@ -27,31 +28,30 @@ # pylint: disable=invalid-name def _inverse_sigmoid(x): - return np.log(x / (1 - x)) + return np.log(x / (1 - x)) -@assert_shape('...->...') +@assert_shape("...->...") class _ClippedScaling(tl.Layer): - """Pointwise multiplies by sigmoid(S) with a learnable vector S.""" + """Pointwise multiplies by sigmoid(S) with a learnable vector S.""" - def __init__(self, - residual_weight): - super().__init__(n_in=1, n_out=1) - self._residual_weight = residual_weight + def __init__(self, residual_weight): + super().__init__(n_in=1, n_out=1) + self._residual_weight = residual_weight - def forward(self, x): - s = self.weights - return jnp.multiply(x, fastmath.expit(s)) + def forward(self, x): + s = self.weights + return jnp.multiply(x, fastmath.expit(s)) - def init_weights_and_state(self, input_signature): - self.weights = _inverse_sigmoid(self._residual_weight) * np.ones( - (input_signature.shape[-1])).astype('float32') + def init_weights_and_state(self, input_signature): + self.weights = _inverse_sigmoid(self._residual_weight) * np.ones( + (input_signature.shape[-1]) + ).astype("float32") -@assert_shape('bld->bld') -def ResidualSwitchUnit( - d_model, dropout=0.1, mode='train', 
residual_weight=0.9): - r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf. +@assert_shape("bld->bld") +def ResidualSwitchUnit(d_model, dropout=0.1, mode="train", residual_weight=0.9): + r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf. As defined in the paper: @@ -75,145 +75,152 @@ def ResidualSwitchUnit( Returns: The RSU layer. """ - return tl.Serial( - tl.Fn( - 'Reshape2Pairs', - lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)), - n_out=1), - tl.Residual( - tl.Dense(4 * d_model, use_bias=False), - tl.LayerNorm(), - tl.Gelu(), - tl.Dense(2 * d_model), - tl.Fn('Scaling', + return tl.Serial( + tl.Fn( + "Reshape2Pairs", + lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)), + n_out=1, + ), + tl.Residual( + tl.Dense(4 * d_model, use_bias=False), + tl.LayerNorm(), + tl.Gelu(), + tl.Dense(2 * d_model), + tl.Fn( + "Scaling", lambda x: x * np.sqrt(1 - residual_weight**2) * 0.25, - n_out=1), - shortcut=_ClippedScaling(residual_weight)), - tl.Fn( - 'UnPair', - lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)), - n_out=1), - tl.Dropout(rate=dropout, mode=mode) - ) + n_out=1, + ), + shortcut=_ClippedScaling(residual_weight), + ), + tl.Fn( + "UnPair", + lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)), + n_out=1, + ), + tl.Dropout(rate=dropout, mode=mode), + ) def _ror(x, n, p=1): - """Bitwise right rotation. + """Bitwise right rotation. - Args: - x: np.array - n: Bit count to represent each value of x - p: Bit positions to shift + Args: + x: np.array + n: Bit count to represent each value of x + p: Bit positions to shift - Returns: - np.array: x with all values shifted by p positions in n bits - """ - a = np.right_shift(x, p) - b = np.left_shift(1, p) - 1 - c = np.bitwise_and(x, b) - d = np.left_shift(c, n - p) + Returns: + np.array: x with all values shifted by p positions in n bits + """ + a = np.right_shift(x, p) + b = np.left_shift(1, p) - 1 + c = np.bitwise_and(x, b) + d = np.left_shift(c, n - p) - return a + d + return a + d def _rol(x, n, p=1): - """Bitwise left rotation. + """Bitwise left rotation. - Args: - x: np.array - n: Bit count to represent each value of x - p: Bit positions to shift + Args: + x: np.array + n: Bit count to represent each value of x + p: Bit positions to shift - Returns: - np.array: x with all values shifted by p positions in n bits - """ - a = np.left_shift(x, p) - b = np.left_shift(1, n) - 1 - c = np.bitwise_and(a, b) - d = np.right_shift(x, n - p) + Returns: + np.array: x with all values shifted by p positions in n bits + """ + a = np.left_shift(x, p) + b = np.left_shift(1, n) - 1 + c = np.bitwise_and(a, b) + d = np.right_shift(x, n - p) - return np.bitwise_or(c, d) + return np.bitwise_or(c, d) def _shuffle_layer(inputs, shuffle_fn): - """Shuffles the elements according to bitwise left or right rotation. + """Shuffles the elements according to bitwise left or right rotation. 
- Args: - inputs: Tensor input from previous layer - shuffle_fn: Shift function rol or ror + Args: + inputs: Tensor input from previous layer + shuffle_fn: Shift function rol or ror - Returns: - tf.Tensor: Inputs shifted according to shuffle_fn - """ - seq_length = inputs.shape[1] - n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 + Returns: + tf.Tensor: Inputs shifted according to shuffle_fn + """ + seq_length = inputs.shape[1] + n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 - indices = np.arange(0, seq_length).astype('int32') - rev_indices = shuffle_fn(indices, n_bits) - return jnp.take(inputs, rev_indices, axis=1, mode='clip') + indices = np.arange(0, seq_length).astype("int32") + rev_indices = shuffle_fn(indices, n_bits) + return jnp.take(inputs, rev_indices, axis=1, mode="clip") -@assert_shape('bld->bld') +@assert_shape("bld->bld") def ShuffleLayer(): - return tl.Fn( - 'ShuffleLayer', lambda x: _shuffle_layer(x, _rol), n_out=1) + return tl.Fn("ShuffleLayer", lambda x: _shuffle_layer(x, _rol), n_out=1) -@assert_shape('bld->bld') +@assert_shape("bld->bld") def ReverseShuffleLayer(): - return tl.Fn( - 'ReverseShuffleLayer', lambda x: _shuffle_layer(x, _ror), n_out=1) + return tl.Fn("ReverseShuffleLayer", lambda x: _shuffle_layer(x, _ror), n_out=1) -@assert_shape('...,bld->...,bld') +@assert_shape("...,bld->...,bld") def _ForwardStep(d_model, dropout, mode): - """Takes (n_layer, state) and returns (n_layer, shuffle_layer(rsu(state))).""" - return tl.Parallel([], tl.Serial( - ResidualSwitchUnit(d_model, dropout, mode), - ShuffleLayer(), - )) + """Takes (n_layer, state) and returns (n_layer, shuffle_layer(rsu(state))).""" + return tl.Parallel( + [], + tl.Serial( + ResidualSwitchUnit(d_model, dropout, mode), + ShuffleLayer(), + ), + ) -@assert_shape('...,bld->...,bld') +@assert_shape("...,bld->...,bld") def _BackwardStep(d_model, dropout, mode): - """Takes (n_layer, state) and returns (n_layer, reverse_shuffle_layer(rsu(state))).""" - return tl.Parallel([], tl.Serial( - ResidualSwitchUnit(d_model, dropout, mode), - ReverseShuffleLayer(), - )) + """Takes (n_layer, state) and returns (n_layer, reverse_shuffle_layer(rsu(state))).""" + return tl.Parallel( + [], + tl.Serial( + ResidualSwitchUnit(d_model, dropout, mode), + ReverseShuffleLayer(), + ), + ) -@assert_shape('bld->bld') +@assert_shape("bld->bld") def BenesBlock(d_model, dropout, mode): - def bit_sequence(inputs): - seq_length = inputs.shape[1] - n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 - return jnp.arange(0, n_bits) - return tl.Serial( - tl.Dup(), - tl.Fn('BitSeq', bit_sequence, n_out=1), - tl.Scan(_ForwardStep(d_model, dropout, mode)), - tl.Scan(_BackwardStep(d_model, dropout, mode)), - tl.Select([1]), - ) - - -@assert_shape('bl->blv') -def ResidualShuffleExchange(vocab_size, - d_model, - input_dropout, - dropout, - mode='train', - n_blocks=2): - """Returns a Residual Shuffle Exchange Network model.""" - benes_blocks = [BenesBlock(d_model, dropout, mode) for _ in range(n_blocks)] - return tl.Serial( - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=input_dropout, mode=mode), - # Apply Benes Block n_blocks times. - *benes_blocks, - ResidualSwitchUnit(d_model, dropout, mode), - # Produce probabilities. 
- tl.Dense(vocab_size), - tl.LogSoftmax(), - ) + def bit_sequence(inputs): + seq_length = inputs.shape[1] + n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1 + return jnp.arange(0, n_bits) + + return tl.Serial( + tl.Dup(), + tl.Fn("BitSeq", bit_sequence, n_out=1), + tl.Scan(_ForwardStep(d_model, dropout, mode)), + tl.Scan(_BackwardStep(d_model, dropout, mode)), + tl.Select([1]), + ) + + +@assert_shape("bl->blv") +def ResidualShuffleExchange( + vocab_size, d_model, input_dropout, dropout, mode="train", n_blocks=2 +): + """Returns a Residual Shuffle Exchange Network model.""" + benes_blocks = [BenesBlock(d_model, dropout, mode) for _ in range(n_blocks)] + return tl.Serial( + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=input_dropout, mode=mode), + # Apply Benes Block n_blocks times. + *benes_blocks, + ResidualSwitchUnit(d_model, dropout, mode), + # Produce probabilities. + tl.Dense(vocab_size), + tl.LogSoftmax(), + ) diff --git a/trax/models/research/rse_test.py b/trax/models/research/rse_test.py deleted file mode 100644 index 36891dbe5..000000000 --- a/trax/models/research/rse_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Residual Shuffle-Exchange Networks.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models.research import rse - - -class RSETest(absltest.TestCase): - - def test_rsu_forward_shape(self): - batch_size = 3 - seq_len = 32 - d_model = 17 - model = rse.ResidualSwitchUnit( - d_model=d_model, dropout=0.1, mode='train') - x = np.ones((batch_size, seq_len, d_model)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (batch_size, seq_len, d_model)) - - def test_shuffle_layer(self): - shuffle_layer = rse.ShuffleLayer() - x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) - print(x.shape) - _, _ = shuffle_layer.init(shapes.signature(x)) - y = shuffle_layer(x) - expected_output = np.array([[[0], [2], [4], [6], [1], [3], [5], [7]]]) - self._assert_equal_tensors(y, expected_output) - - def test_shuffle_layer_log_times_is_identity(self): - seq_len = 8 - d_model = 17 - shuffle_layer = rse.ShuffleLayer() - x = _input_with_indice_as_values(seq_len, d_model) - _, _ = shuffle_layer.init(shapes.signature(x)) - y = x - for _ in range(int(np.log2(seq_len))): - y = shuffle_layer(y) - self._assert_equal_tensors(x, y) - - def test_reverse_shuffle_layer(self): - reverse_shuffle_layer = rse.ReverseShuffleLayer() - x = np.array([[[0], [1], [2], [3], [4], [5], [6], [7]]]) - print(x.shape) - _, _ = reverse_shuffle_layer.init(shapes.signature(x)) - y = reverse_shuffle_layer(x) - expected_output = np.array([[[0], [4], [1], [5], [2], [6], [3], [7]]]) - self._assert_equal_tensors(y, expected_output) - - def test_reverse_shuffle_layer_log_times_is_identity(self): - seq_len = 8 - d_model = 17 - reverse_shuffle_layer = rse.ReverseShuffleLayer() - x = 
_input_with_indice_as_values(seq_len, d_model) - _, _ = reverse_shuffle_layer.init(shapes.signature(x)) - y = x - for _ in range(int(np.log2(seq_len))): - y = reverse_shuffle_layer(y) - self._assert_equal_tensors(x, y) - - def test_rse_forward_shape(self): - vocab_size = 12 - seq_len = 32 - model = rse.ResidualShuffleExchange( - vocab_size=vocab_size, d_model=17, dropout=0.1, input_dropout=0.05, - mode='train') - x = np.ones((3, seq_len)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, seq_len, vocab_size)) - - def _assert_equal_tensors(self, x, y): - self.assertEqual(y.shape, x.shape) - for i in range(x.shape[0]): - for j in range(x.shape[1]): - for k in range(x.shape[2]): - self.assertEqual( - x[i][j][k], y[i][j][k], - f'Tensors differ on index [{i}][{j}][{k}].') - - -def _input_with_indice_as_values(length, dim): - """Retuns np.array of size (1, length, dim) where x[0, a, b] = a.""" - positions = [] - for i in range(length): - positions.append([i] * dim) - positions_input = np.array(positions) - positions_input = np.expand_dims(positions_input, axis=0) - return positions_input - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/terraformer.py b/trax/models/research/terraformer.py index 892c5c5d9..797770977 100644 --- a/trax/models/research/terraformer.py +++ b/trax/models/research/terraformer.py @@ -19,6 +19,7 @@ """ import functools + from trax import layers as tl from trax.fastmath import numpy as jnp from trax.models.reformer import reformer @@ -29,462 +30,478 @@ # pylint: disable=invalid-name -def ConfigurableTerraformer(input_vocab_size, - output_vocab_size=None, - d_model=512, - d_ff=2048, - d_attention_key=None, - d_attention_value=None, - n_encoder_layers=6, - n_decoder_layers=6, - n_heads=8, - dropout=0.1, - max_len=2048, - encoder_attention_type=tl.SelfAttention, - encoder_decoder_attention_type=tl.SelfAttention, - pos_type='fixed-base', - pos_axial_shape=(), - pos_d_axial_embs=None, - pos_start_from_zero_prob=1.0, - pos_max_offset_to_add=0, - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=0, - ff_dropout=None, - ff_sparsity=0, - loss_sparsity_type='mult', - loss_sparsity=0, - loss_d_lowrank=0, - loss_sparsity_prob=None, - attention_chunk_size=0, - n_layers_forget=0, - forget_dense=True, - n_decoder_attention_layers=2, - use_bfloat16=False, - reversible_encoder=False, - use_two_swaps_per_encoder_block=True, - center_layernorm=True, - half_before_layer=None, - double_after_layer=None, - mode='train'): - """Returns a highly configurable Terraformer encoder-decoder model. - - This model maps paired text sequences (source and target) to float-valued - losses. If ``input_vocab_size`` is not ``None``, the layer takes - two input sequences: - - - inputs (2): - - - source: 2-D int array representing a batch of text strings via token - IDs plus padding markers; shape is `(batch_size, sequence_length)`, - where sequence_length <= ``max_len``. Array elements are in - ``range(input_vocab_size)``, and 0 values mark padding positions. - - - target: 2-D int array representing a batch of text strings via token - IDs plus padding markers; shape is `(batch_size, sequence_length)`, - where sequence_length <= ``max_len``. Array elements are in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 1-D float array of losses; shape is `(batch_size)`. 
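# --- Illustrative aside (not part of the diff): a minimal call in the two-input
# (token IDs) mode the docstring describes above, mirroring the
# terraformer_test.py removed later in this diff; vocab_size and max_len are
# placeholder values, not API guarantees.
#
#     model = ConfigurableTerraformer(vocab_size, mode='train')
#     x = [np.ones((1, max_len)).astype(np.int32),   # source token IDs
#          np.ones((1, max_len)).astype(np.int32)]   # target token IDs
#     model.init(shapes.signature(x))
#     logits, dec_toks = model(x)  # logits: (1, max_len, vocab_size)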
- - If ``input_vocab_size`` is ``None``, the layer takes three input sequences: - - - inputs (3): - - - source: 3-D float array representing a batch of already-embedded text - strings; shape is `(batch_size, sequence_length, d_model)`, where - sequence_length <= ``max_len``. - - - mask: 2-D int array representing active versus masked positions; 0 - values mark masked (padding) positions. - - - target: 2-D int array representing a batch of text strings via token - IDs plus padding markers; shape is `(batch_size, sequence_length)`, - where sequence_length <= ``max_len``. Array elements are in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 1-D float array of losses; shape is `(batch_size)`. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if ``None``, then input and target integers (token IDs) are assumed to - come from the same vocabulary. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - d_attention_key: Depth of key vectors in each attention head. - d_attention_value: Depth of value vectors in each attention head. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. The same rate is - also used for attention dropout in encoder/decoder blocks. - max_len: Maximum symbol length for positional encoding. - encoder_attention_type: Type of attention to use in the encoder; must be - an attention-type subclass of :py:class:`trax.layers.Layer`. - encoder_decoder_attention_type: Type of attention to use in the decoder; - must be an attention-type subclass of :py:class:`trax.layers.Layer`. - pos_type: String indicating the type of positional embeddings to use. - pos_axial_shape: Shape (tuple of ints) to use for the axial position - encoding. If unset, axial position encoding is disabled. - pos_d_axial_embs: Tuple of ints specifying the depth of position embedding - for each axis. Tuple length must match ``pos_axial_shape``, and values - must sum to ``d_model``. - pos_start_from_zero_prob: Stochastic rate (probability) for starting - positional encoding at position 0 during training. If 1.0, always start - from position 0; if < 1.0, the non-zero starts will be uniformly - distributed up to ``pos_max_offset_to_add``. - pos_max_offset_to_add: Maximum offset to add to positions during training - when randomizing. This offset plus input length must be less than - ``max_len`` for all training examples. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`trax.layers.Layer`. - ff_use_sru: If > 0, use this number of SRU layers in place of feedforward - layers. - ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this - size. - ff_dropout: Stochastic rate (probability) for dropping an activation value - at feedforward nonlinearities. - ff_sparsity: If > 0, use sparse feedforward blocks with this level of - sparsity. 
- loss_sparsity_type: String indicating the type of sparsity to used in loss - layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``, - use no sparsity. - loss_sparsity: If > 0, use this level of sparsity in the loss layer. - loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this - dimension, in the loss. - loss_sparsity_prob: Stochastic rate (probability) for using the sparse - version of the loss. If ``None``, use the sparse version exclusively. - attention_chunk_size: If > 0, compute attention using chunks of this size. - n_layers_forget: How often to have a forgetting block between layers. - forget_dense: If True, use :py:class:`Dense` instances as forget layers; - else use no-ops. - n_decoder_attention_layers: Number of attention layers in a decoder block. - use_bfloat16: If True, use bfloat16 for weights; else use float32. - reversible_encoder: If True, make the encoder be reversible. - use_two_swaps_per_encoder_block: If True, ensure that there is a an even - number of swaps across the encoder. - center_layernorm: If True, use centering in :py:class:`LayerNorm` (the - default); else omit centering (which is known as RMS normalization). - half_before_layer: If not None, specifies an n'th layer such that all - layers before the n'th use half the normal values for ``d_model`` and - ``d_ff``. - double_after_layer: If not None, specifies an n'th layer such that all - layers after the n'th use double the normal values for ``d_model`` and - ``d_ff``. - mode: If ``'train'``, include dropout in each encoder/decoder block; else - dropout layers have no effect. - - Returns: - A Terraformer encoder-decoder as a layer that maps from target and source - text sequences to a scalar loss. - """ - if mode == 'predict': - portal_mask = _PortalInput() - else: - portal_mask = None - - # Set default dimensions for attention head key and value sizes. - if (d_model / 2) % n_heads != 0: - raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})') - if d_attention_key is None: - d_attention_key = d_model // n_heads - if d_attention_value is None: - d_attention_value = d_model // n_heads - - # Set values of d_model, d_ff and d_qkv for the first stage. - d_model1, d_ff1 = d_model, d_ff - d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value - if half_before_layer: - d_model1, d_ff1 = d_model / 2, d_ff / 2 - d_attention_key1 = d_attention_key / 2 - d_attention_value1 = d_attention_value / 2 - - # Set values of d_model, d_ff and d_qkv for the final stage. - d_model2, d_ff2 = d_model, d_ff - d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value - if double_after_layer: - d_model2, d_ff2 = d_model * 2, d_ff * 2 - d_attention_key2 = d_attention_key * 2 - d_attention_value2 = d_attention_value * 2 - - # Vector embeddings. 
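# --- Illustrative aside (not part of the diff): worked numbers for the
# stage-size bookkeeping above, using the signature defaults (d_model=512,
# d_ff=2048, n_heads=8). Per-head depths default to 512 // 8 = 64. With
# half_before_layer set, the first stage runs at d_model1 = 512 / 2 = 256.0 and
# d_ff1 = 1024.0 (true division, so these become floats); with
# double_after_layer set, the final stage runs at d_model2 = 1024 and
# d_ff2 = 4096. The up-front check passes since (512 / 2) % 8 == 0.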
- in_encoder, out_encoder, output_vocab_size = ( - ct.EmbeddingAndPositionalEncodings( - input_vocab_size, - d_model1, - mode, - dropout, - [-2], # dropout_shared_axes - max_len, - output_vocab_size=output_vocab_size, - pos_type=pos_type, - pos_axial_shape=pos_axial_shape, - pos_d_axial_embs=pos_d_axial_embs, - pos_start_from_zero_prob=pos_start_from_zero_prob, - pos_max_offset_to_add=pos_max_offset_to_add, - use_bfloat16=use_bfloat16) - ) - - def _EncoderBlock(): - return reformer.EncoderBlock( +def ConfigurableTerraformer( + input_vocab_size, + output_vocab_size=None, + d_model=512, + d_ff=2048, + d_attention_key=None, + d_attention_value=None, + n_encoder_layers=6, + n_decoder_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + encoder_attention_type=tl.SelfAttention, + encoder_decoder_attention_type=tl.SelfAttention, + pos_type="fixed-base", + pos_axial_shape=(), + pos_d_axial_embs=None, + pos_start_from_zero_prob=1.0, + pos_max_offset_to_add=0, + ff_activation=tl.Relu, + ff_use_sru=0, + ff_chunk_size=0, + ff_dropout=None, + ff_sparsity=0, + loss_sparsity_type="mult", + loss_sparsity=0, + loss_d_lowrank=0, + loss_sparsity_prob=None, + attention_chunk_size=0, + n_layers_forget=0, + forget_dense=True, + n_decoder_attention_layers=2, + use_bfloat16=False, + reversible_encoder=False, + use_two_swaps_per_encoder_block=True, + center_layernorm=True, + half_before_layer=None, + double_after_layer=None, + mode="train", +): + """Returns a highly configurable Terraformer encoder-decoder model. + + This model maps paired text sequences (source and target) to float-valued + losses. If ``input_vocab_size`` is not ``None``, the layer takes + two input sequences: + + - inputs (2): + + - source: 2-D int array representing a batch of text strings via token + IDs plus padding markers; shape is `(batch_size, sequence_length)`, + where sequence_length <= ``max_len``. Array elements are in + ``range(input_vocab_size)``, and 0 values mark padding positions. + + - target: 2-D int array representing a batch of text strings via token + IDs plus padding markers; shape is `(batch_size, sequence_length)`, + where sequence_length <= ``max_len``. Array elements are in + ``range(output_vocab_size)``, and 0 values mark padding positions. + + - output: 1-D float array of losses; shape is `(batch_size)`. + + If ``input_vocab_size`` is ``None``, the layer takes three input sequences: + + - inputs (3): + + - source: 3-D float array representing a batch of already-embedded text + strings; shape is `(batch_size, sequence_length, d_model)`, where + sequence_length <= ``max_len``. + + - mask: 2-D int array representing active versus masked positions; 0 + values mark masked (padding) positions. + + - target: 2-D int array representing a batch of text strings via token + IDs plus padding markers; shape is `(batch_size, sequence_length)`, + where sequence_length <= ``max_len``. Array elements are in + ``range(output_vocab_size)``, and 0 values mark padding positions. + + - output: 1-D float array of losses; shape is `(batch_size)`. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if ``None``, then input and target integers (token IDs) are assumed to + come from the same vocabulary. 
+ d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + d_attention_key: Depth of key vectors in each attention head. + d_attention_value: Depth of value vectors in each attention head. + n_encoder_layers: Number of encoder blocks. + n_decoder_layers: Number of decoder blocks. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder/decoder blocks. The same rate is + also used for attention dropout in encoder/decoder blocks. + max_len: Maximum symbol length for positional encoding. + encoder_attention_type: Type of attention to use in the encoder; must be + an attention-type subclass of :py:class:`trax.layers.Layer`. + encoder_decoder_attention_type: Type of attention to use in the decoder; + must be an attention-type subclass of :py:class:`trax.layers.Layer`. + pos_type: String indicating the type of positional embeddings to use. + pos_axial_shape: Shape (tuple of ints) to use for the axial position + encoding. If unset, axial position encoding is disabled. + pos_d_axial_embs: Tuple of ints specifying the depth of position embedding + for each axis. Tuple length must match ``pos_axial_shape``, and values + must sum to ``d_model``. + pos_start_from_zero_prob: Stochastic rate (probability) for starting + positional encoding at position 0 during training. If 1.0, always start + from position 0; if < 1.0, the non-zero starts will be uniformly + distributed up to ``pos_max_offset_to_add``. + pos_max_offset_to_add: Maximum offset to add to positions during training + when randomizing. This offset plus input length must be less than + ``max_len`` for all training examples. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`trax.layers.Layer`. + ff_use_sru: If > 0, use this number of SRU layers in place of feedforward + layers. + ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this + size. + ff_dropout: Stochastic rate (probability) for dropping an activation value + at feedforward nonlinearities. + ff_sparsity: If > 0, use sparse feedforward blocks with this level of + sparsity. + loss_sparsity_type: String indicating the type of sparsity to use in the + loss layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``, + use no sparsity. + loss_sparsity: If > 0, use this level of sparsity in the loss layer. + loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this + dimension, in the loss. + loss_sparsity_prob: Stochastic rate (probability) for using the sparse + version of the loss. If ``None``, use the sparse version exclusively. + attention_chunk_size: If > 0, compute attention using chunks of this size. + n_layers_forget: How often to have a forgetting block between layers. + forget_dense: If True, use :py:class:`Dense` instances as forget layers; + else use no-ops. + n_decoder_attention_layers: Number of attention layers in a decoder block. + use_bfloat16: If True, use bfloat16 for weights; else use float32. + reversible_encoder: If True, make the encoder reversible. + use_two_swaps_per_encoder_block: If True, ensure that there is an even + number of swaps across the encoder.
+ center_layernorm: If True, use centering in :py:class:`LayerNorm` (the + default); else omit centering (which is known as RMS normalization). + half_before_layer: If not None, specifies an n'th layer such that all + layers before the n'th use half the normal values for ``d_model`` and + ``d_ff``. + double_after_layer: If not None, specifies an n'th layer such that all + layers after the n'th use double the normal values for ``d_model`` and + ``d_ff``. + mode: If ``'train'``, include dropout in each encoder/decoder block; else + dropout layers have no effect. + + Returns: + A Terraformer encoder-decoder as a layer that maps from target and source + text sequences to a scalar loss. + """ + if mode == "predict": + portal_mask = _PortalInput() + else: + portal_mask = None + + # Set default dimensions for attention head key and value sizes. + if (d_model / 2) % n_heads != 0: + raise ValueError(f"n_heads ({n_heads}) must divide d_model/2 ({d_model/2})") + if d_attention_key is None: + d_attention_key = d_model // n_heads + if d_attention_value is None: + d_attention_value = d_model // n_heads + + # Set values of d_model, d_ff and d_qkv for the first stage. + d_model1, d_ff1 = d_model, d_ff + d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value + if half_before_layer: + d_model1, d_ff1 = d_model / 2, d_ff / 2 + d_attention_key1 = d_attention_key / 2 + d_attention_value1 = d_attention_value / 2 + + # Set values of d_model, d_ff and d_qkv for the final stage. + d_model2, d_ff2 = d_model, d_ff + d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value + if double_after_layer: + d_model2, d_ff2 = d_model * 2, d_ff * 2 + d_attention_key2 = d_attention_key * 2 + d_attention_value2 = d_attention_value * 2 + + # Vector embeddings. + in_encoder, out_encoder, output_vocab_size = ct.EmbeddingAndPositionalEncodings( + input_vocab_size, d_model1, - d_ff1, - n_heads, - encoder_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=ff_dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - center_layernorm=center_layernorm, + mode, + dropout, + [-2], # dropout_shared_axes + max_len, + output_vocab_size=output_vocab_size, + pos_type=pos_type, + pos_axial_shape=pos_axial_shape, + pos_d_axial_embs=pos_d_axial_embs, + pos_start_from_zero_prob=pos_start_from_zero_prob, + pos_max_offset_to_add=pos_max_offset_to_add, use_bfloat16=use_bfloat16, - use_two_swaps_per_block=use_two_swaps_per_encoder_block, - mode=mode) + ) - def _Encoder(): # vec_e mask_e tok_e tok_d tok_d - layers = [ - tl.ReversibleSelect([0, 0]), - _ReversibleSerialForget( - [_EncoderBlock() for _ in range(n_encoder_layers)], + def _EncoderBlock(): + return reformer.EncoderBlock( d_model1, - n_layers_forget, - forget_dense) - ] - if not reversible_encoder: - layers += [ - _XYAvg(), - tl.Dense(d_model1, use_bfloat16=use_bfloat16), - tl.LayerNorm(), - ] - if mode == 'predict': - return tl.Cache(tl.Serial(layers)) - else: - return tl.Serial(layers) - - if mode == 'predict': - # TODO(jaszczur): Remove temporary fix of Terraformer padding in predict. - # In predict mode Terraformer needs masking for merged encoder-decoder - # sequence. This monkey patches the layer with a mask to neccessary places. - # This shouldn't be a permanent solution - mask should be passed through - # the stack and all the layers. 
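# --- Illustrative aside (not part of the diff): the patch/build/restore pattern
# used here, in miniature. The names in this sketch are generic stand-ins, not
# trax API; the real code patches DotProductCausalAttention, _RememberPad and
# ScanSRUCell, builds the decoder blocks, then restores the originals.
#
#     original = SomeLayer.mask_hook           # 1. remember the original
#     SomeLayer.mask_hook = lambda x: portal   # 2. patch while constructing,
#     blocks = build_decoder_blocks()          #    so layers capture `portal`
#     SomeLayer.mask_hook = original           # 3. restore the global state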
- tl.attention.DotProductCausalAttention.monkey_patched_mask = ( - lambda x: portal_mask) - tl.research.sparsity._RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access - lambda x: portal_mask) - originalScanSRUCell = tl.rnn.ScanSRUCell - tl.rnn.ScanSRUCell = functools.partial(tl.rnn.ScanSRUCell, - monkey_patched_mask=portal_mask) - - decoder_blocks = [] - - if isinstance(encoder_decoder_attention_type, (tuple, list)): - assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 - else: - encoder_decoder_attention_type = [encoder_decoder_attention_type] - for layer_idx in range(n_decoder_layers): - layer_attention_type = encoder_decoder_attention_type[ - layer_idx % len(encoder_decoder_attention_type)] - # Grow d_model, d_ff, and d_qkv if requested. - d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1 - if half_before_layer and layer_idx >= half_before_layer: - d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value - if double_after_layer and layer_idx > double_after_layer: - d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2 - decoder_block = reformer.DecoderBlock( - d_m, d_f, d_k, d_v, n_heads, - attention_type=layer_attention_type, - dropout=dropout, - ff_activation=ff_activation, - ff_dropout=ff_dropout, - ff_use_sru=ff_use_sru, - ff_chunk_size=ff_chunk_size, - ff_sparsity=ff_sparsity, - attention_chunk_size=attention_chunk_size, - n_attention_layers=n_decoder_attention_layers, - center_layernorm=center_layernorm, - use_bfloat16=use_bfloat16, - mode=mode) - decoder_blocks.append(decoder_block) - if half_before_layer and layer_idx == half_before_layer - 1: - decoder_blocks.append(tl.ReversibleConcatenatePair()) - if double_after_layer and layer_idx == double_after_layer: - decoder_blocks.append(tl.ReversibleConcatenatePair()) - - if mode == 'predict': - # After initializing the decoder we can revert to original state of - # previously monkey-patched classes/functions. 
- tl.attention.DotProductCausalAttention.monkey_patched_mask = ( - lambda x: None) - tl.research.sparsity._RememberPad.monkey_patched_mask = (lambda x: None) # pylint: disable=protected-access - tl.rnn.ScanSRUCell = originalScanSRUCell - - def _Loss(): - return tl.SparseDenseWithOptions( - output_vocab_size, - d_input=d_model2, - sparsity_type=loss_sparsity_type, - sparsity=loss_sparsity, - d_lowrank=loss_d_lowrank, - prob_sparse=loss_sparsity_prob, - use_bfloat16=use_bfloat16, - mode=mode) - - def _enc_dec_concat(): - """Layers to merge encoder and decoder.""" - if reversible_encoder: - return [ - tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d - t2.ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d - ] - else: - return [ - tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d - t2.ConcatWithPadding(mode=mode), # v_ed tok_e tok_d - tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d - ] - - def _inp_layers(): - if input_vocab_size is not None: - return tl.AssertFunction( - 'bl,br->bld,bl,bl,br', # b: batch, l/r: enc/dec length, d: vec depth - tl.Serial( # tok_e tok_d - tl.Select([0, 0, 0, 1]), - tl.Parallel(in_encoder, [tl.PaddingMask(), - _RemoveAxes12()]) - )) # vec_e mask_e tok_e tok_d + d_ff1, + n_heads, + encoder_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=ff_dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + center_layernorm=center_layernorm, + use_bfloat16=use_bfloat16, + use_two_swaps_per_block=use_two_swaps_per_encoder_block, + mode=mode, + ) + + def _Encoder(): # vec_e mask_e tok_e tok_d tok_d + layers = [ + tl.ReversibleSelect([0, 0]), + _ReversibleSerialForget( + [_EncoderBlock() for _ in range(n_encoder_layers)], + d_model1, + n_layers_forget, + forget_dense, + ), + ] + if not reversible_encoder: + layers += [ + _XYAvg(), + tl.Dense(d_model1, use_bfloat16=use_bfloat16), + tl.LayerNorm(), + ] + if mode == "predict": + return tl.Cache(tl.Serial(layers)) + else: + return tl.Serial(layers) + + if mode == "predict": + # TODO(jaszczur): Remove temporary fix of Terraformer padding in predict. + # In predict mode Terraformer needs masking for merged encoder-decoder + # sequence. This monkey patches the layer with a mask to necessary places. + # This shouldn't be a permanent solution - mask should be passed through + # the stack and all the layers. + tl.attention.DotProductCausalAttention.monkey_patched_mask = ( + lambda x: portal_mask + ) + tl.research.sparsity._RememberPad.monkey_patched_mask = ( # pylint: disable=protected-access + lambda x: portal_mask + ) + originalScanSRUCell = tl.rnn.ScanSRUCell + tl.rnn.ScanSRUCell = functools.partial( + tl.rnn.ScanSRUCell, monkey_patched_mask=portal_mask + ) + + decoder_blocks = [] + + if isinstance(encoder_decoder_attention_type, (tuple, list)): + assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else:
- return tl.Serial( - _inp_layers(), # vec_e mask_e tok_e tok_d - tl.Parallel([], portal_mask), - - tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss. - - # Embed in and out tokens; done together as weights may be shared. - tl.Parallel([], [], [], [tl.ShiftRight(mode=mode), - out_encoder]), # vec_e mask_e tok_e vec_d tok_d - - # Encode; then concat encoder and decoder, given encoder mask. - _Encoder(), # vec_e mask_e tok_e vec_d tok_d - _enc_dec_concat(), - - # Run decoder blocks. - _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget, - forget_dense), # vec_ed1 vec_ed2 tok_e tok_d - _XYAvg(), # vec_ed tok_e tok_d - tl.LayerNorm(), # vec_ed tok_e tok_d - - # Separate out the encoder part from the concatenated vector, - # then compute loss. - tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d - t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d - _Loss(), # vec_d tok_d - ) + encoder_decoder_attention_type = [encoder_decoder_attention_type] + for layer_idx in range(n_decoder_layers): + layer_attention_type = encoder_decoder_attention_type[ + layer_idx % len(encoder_decoder_attention_type) + ] + # Grow d_model, d_ff, and d_qkv if requested. + d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1 + if half_before_layer and layer_idx >= half_before_layer: + d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value + if double_after_layer and layer_idx > double_after_layer: + d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2 + decoder_block = reformer.DecoderBlock( + d_m, + d_f, + d_k, + d_v, + n_heads, + attention_type=layer_attention_type, + dropout=dropout, + ff_activation=ff_activation, + ff_dropout=ff_dropout, + ff_use_sru=ff_use_sru, + ff_chunk_size=ff_chunk_size, + ff_sparsity=ff_sparsity, + attention_chunk_size=attention_chunk_size, + n_attention_layers=n_decoder_attention_layers, + center_layernorm=center_layernorm, + use_bfloat16=use_bfloat16, + mode=mode, + ) + decoder_blocks.append(decoder_block) + if half_before_layer and layer_idx == half_before_layer - 1: + decoder_blocks.append(tl.ReversibleConcatenatePair()) + if double_after_layer and layer_idx == double_after_layer: + decoder_blocks.append(tl.ReversibleConcatenatePair()) + + if mode == "predict": + # After initializing the decoder we can revert to original state of + # previously monkey-patched classes/functions. 
+ tl.attention.DotProductCausalAttention.monkey_patched_mask = lambda x: None + tl.research.sparsity._RememberPad.monkey_patched_mask = ( + lambda x: None + ) # pylint: disable=protected-access + tl.rnn.ScanSRUCell = originalScanSRUCell + + def _Loss(): + return tl.SparseDenseWithOptions( + output_vocab_size, + d_input=d_model2, + sparsity_type=loss_sparsity_type, + sparsity=loss_sparsity, + d_lowrank=loss_d_lowrank, + prob_sparse=loss_sparsity_prob, + use_bfloat16=use_bfloat16, + mode=mode, + ) + + def _enc_dec_concat(): + """Layers to merge encoder and decoder.""" + if reversible_encoder: + return [ + tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d + t2.ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d + ] + else: + return [ + tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d + t2.ConcatWithPadding(mode=mode), # v_ed tok_e tok_d + tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d + ] + + def _inp_layers(): + if input_vocab_size is not None: + return tl.AssertFunction( + "bl,br->bld,bl,bl,br", # b: batch, l/r: enc/dec length, d: vec depth + tl.Serial( # tok_e tok_d + tl.Select([0, 0, 0, 1]), + tl.Parallel(in_encoder, [tl.PaddingMask(), _RemoveAxes12()]), + ), + ) # vec_e mask_e tok_e tok_d + else: + # Input in this case is vec_e, mask_e, tok_d. Where all downstream + # operations expect tok_e, we give it instead mask_e, expecting that + # downstream ops only are looking for padding/not padding. + return tl.AssertFunction( + "blf,bl,br->bld,bl,bl,br", # f: in-feature depth, d: out-vector depth + tl.Serial( # vec_e mask_e tok_d + tl.Select([0, 1, 1, 2]), tl.Parallel(in_encoder, [], _AsTokenIDs()) + ), + ) # vec_e mask_e tok_e tok_d + + # Assemble and return the model. + return tl.Serial( + _inp_layers(), # vec_e mask_e tok_e tok_d + tl.Parallel(tl.Select([0]), portal_mask), + tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss. + # Embed in and out tokens; done together as weights may be shared. + tl.Parallel( + tl.Select([0]), + tl.Select([0]), + tl.Select([0]), + [tl.ShiftRight(mode=mode), out_encoder], + ), # vec_e mask_e tok_e vec_d tok_d + # Encode; then concat encoder and decoder, given encoder mask. + _Encoder(), # vec_e mask_e tok_e vec_d tok_d + _enc_dec_concat(), + # Run decoder blocks. + _ReversibleSerialForget( + decoder_blocks, d_model2, n_layers_forget, forget_dense + ), # vec_ed1 vec_ed2 tok_e tok_d + _XYAvg(), # vec_ed tok_e tok_d + tl.LayerNorm(), # vec_ed tok_e tok_d + # Separate out the encoder part from the concatenated vector, + # then compute loss. 
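# --- Illustrative aside (not part of the diff): a hedged reading of the final
# stack bookkeeping that follows. tl.Select([0, 1, 2, 2]) duplicates the decoder
# tokens so that StripFromConcatenateWithPadding can consume (vec_ed, tok_e,
# tok_d) to recover the decoder half of the merged sequence, while the second
# tok_d copy rides through as the target paired with the _Loss() output,
# matching the trailing "# vec_d tok_d" stack comments.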
+ tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d + t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d + _Loss(), # vec_d tok_d + ) def _InsertAxes12(): - """Returns a layer that inserts two internal size-1 axes into an array.""" - return tl.Fn('InsertAxes12', - lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1]))) + """Returns a layer that inserts two internal size-1 axes into an array.""" + return tl.Fn( + "InsertAxes12", lambda x: jnp.reshape(x, (x.shape[0], 1, 1, x.shape[1])) + ) def _RemoveAxes12(): - """Returns a layer that removes two internal size-1 axes from an array.""" - return tl.Fn('RemoveAxes12', lambda x: jnp.squeeze(x, (1, 2))) + """Returns a layer that removes two internal size-1 axes from an array.""" + return tl.Fn("RemoveAxes12", lambda x: jnp.squeeze(x, (1, 2))) def _AsTokenIDs(): - """Returns a layer that makes mask values look like token ID ints.""" - return tl.Fn('AsTokenIDs', lambda x: x.astype(jnp.int32)) + """Returns a layer that makes mask values look like token ID ints.""" + return tl.Fn("AsTokenIDs", lambda x: x.astype(jnp.int32)) def _XYAvg(): - """Returns a layer that computes the element-wise average of two arrays.""" - return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0) + """Returns a layer that computes the element-wise average of two arrays.""" + return tl.Fn("XYAvg", lambda x, y: (x + y) / 2.0) def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True): - """ReversibleSerial but with a forgetting block every n_layers.""" - if not n_layers or len(layers) <= n_layers + 1: - return tl.ReversibleSerial(layers) - layers1, layers2 = layers[:n_layers], layers[n_layers:] - - if forget_dense: - forgetting_layer = tl.Serial( - _XYAvg(), - tl.Dense(d_model), - tl.Dup(), - ) - else: - forgetting_layer = tl.Select([0, 1]) + """ReversibleSerial but with a forgetting block every n_layers.""" + if not n_layers or len(layers) <= n_layers + 1: + return tl.ReversibleSerial(layers) + layers1, layers2 = layers[:n_layers], layers[n_layers:] + + if forget_dense: + forgetting_layer = tl.Serial( + _XYAvg(), + tl.Dense(d_model), + tl.Dup(), + ) + else: + forgetting_layer = tl.Select([0, 1]) - return tl.Serial( - tl.ReversibleSerial(layers1), - forgetting_layer, - _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense) - ) + return tl.Serial( + tl.ReversibleSerial(layers1), + forgetting_layer, + _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense), + ) def _ConvertToNaNsOnAnyZero(): - def _convert_to_nans(x, y): - # if all values in y are non-zeros, return x; otherwise return 0s - return jnp.where(jnp.all(y, keepdims=False), x, x/0.), y - return tl.Fn('ConvertToNaNsOnAnyZero', _convert_to_nans, n_out=2) + def _convert_to_nans(x, y): + # If all values in y are non-zero, return x; otherwise return x / 0.0 (NaN/inf poison). + return jnp.where(jnp.all(y, keepdims=False), x, x / 0.0), y + + return tl.Fn("ConvertToNaNsOnAnyZero", _convert_to_nans, n_out=2) class _PortalInput(tl.Layer): - """Portal input for monkey-patching of mask in predict mode.""" + """Portal input for monkey-patching of mask in predict mode.""" - def __init__(self): - super().__init__(name='_PortalInput', n_out=1, n_in=1) - self._portal_output = _PortalOutput(self) + def __init__(self): + super().__init__(name="_PortalInput", n_out=1, n_in=1) + self._portal_output = _PortalOutput(self) - def forward(self, x): - if isinstance(x, (list, tuple)): - x = x[0] - self.state = (x,) - return x + def forward(self, x): + if isinstance(x, (list, tuple)): + x = x[0] + self.state = (x,) + return x
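# --- Illustrative aside (not part of the diff): how the portal pair works.
# _PortalInput sits in the main stack and copies the mask it sees into layer
# state; its paired _PortalOutput takes no inputs (n_in=0), so a monkey-patched
# layer can drop it in anywhere and re-emit that same mask without threading it
# through the intervening layers. Sketch, using only names from this file:
#
#     portal = _PortalInput()
#     reader = portal.get_layer()   # a _PortalOutput bound to `portal`
#     # `portal` consumes the mask in the main Serial stack;
#     # `reader` produces portal.get_value() wherever it is inserted.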
- def init_weights_and_state(self, input_signature): - """Initializes this layer's weights.""" - if isinstance(input_signature, (list, tuple)): - input_signature = input_signature[0] - self.state = (jnp.zeros(input_signature.shape),) + def init_weights_and_state(self, input_signature): + """Initializes this layer's weights.""" + if isinstance(input_signature, (list, tuple)): + input_signature = input_signature[0] + self.state = (jnp.zeros(input_signature.shape),) - def get_value(self): - return self.state[0] + def get_value(self): + return self.state[0] - def get_layer(self): - return self._portal_output + def get_layer(self): + return self._portal_output class _PortalOutput(tl.Layer): - """Portal input for monkey-patching of mask in predict mode.""" + """Portal output for monkey-patching of mask in predict mode.""" - def __init__(self, portal_input): - super().__init__(name='_PortalOutput', n_out=1, n_in=0) - self._portal_input = portal_input + def __init__(self, portal_input): + super().__init__(name="_PortalOutput", n_out=1, n_in=0) + self._portal_input = portal_input - def forward(self, x): - return self._portal_input.get_value() + def forward(self, x): + return self._portal_input.get_value() - def get_value(self): - return self._portal_input.get_value() + def get_value(self): + return self._portal_input.get_value() diff --git a/trax/models/research/terraformer_e2e_test.py b/trax/models/research/terraformer_e2e_test.py deleted file mode 100644 index 2dfd36742..000000000 --- a/trax/models/research/terraformer_e2e_test.py +++ /dev/null @@ -1,99 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -"""End to end test for Reformer.""" - -import os - -from absl.testing import absltest -import gin - -from trax import test_utils -from trax.models.research import terraformer # pylint: disable=unused-import -from trax.supervised import trainer_lib -from trax.tf_numpy import numpy as tf_np # pylint: disable=unused-import - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') -_CONFIG_DIR = os.path.join(pkg_dir, '../../supervised/configs/') - - -class TerraformerE2ETest(absltest.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - gin.add_config_file_search_path(_CONFIG_DIR) - test_utils.ensure_flag('test_tmpdir') - - def test_terraformer_wmt_ende(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('terraformer_wmt_ende.gin') - - gin.bind_parameter('data_streams.data_dir', _TESTDATA) - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', - ([512], [batch_size_per_device, batch_size_per_device])) - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - def test_terraformer_copy(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('terraformer_copy.gin') - - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', ([64], [1, 1])) # batch size 1. - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - def test_terraformer_purelsh_copy(self): - batch_size_per_device = 2 - steps = 1 - n_layers = 2 - d_ff = 32 - - gin.parse_config_file('terraformer_purelsh_copy.gin') - - gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', ([64], [1, 1])) # batch size 1. - gin.bind_parameter('train.steps', steps) - gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers) - gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff) - - output_dir = self.create_tempdir().full_path - _ = trainer_lib.train(output_dir=output_dir) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/terraformer_oom_test.py b/trax/models/research/terraformer_oom_test.py deleted file mode 100644 index 2d68819fe..000000000 --- a/trax/models/research/terraformer_oom_test.py +++ /dev/null @@ -1,129 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for OOM for Terraformer .""" - -import functools -import operator - -from absl.testing import absltest -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.models.research import terraformer - - -class TerraformerOOMTest(absltest.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def test_terraformer_one_step(self): - d_model = 1024 - vocab_size = 14041 - max_len = 16384 - pos_axial = (128, 128) # should multiply to max_len - pos_d_axial_embs = (512, 512) # sum to d model - - assert operator.mul(*pos_axial) == max_len - assert sum(pos_d_axial_embs) == d_model - - d_ff = 4096 - n_heads = 8 - d_attn = d_model // n_heads - - n_buckets = 128 - encoder_chunk_len = (2 * max_len) // n_buckets # 256 - decoder_chunk_len = 2 * encoder_chunk_len # 512 - encoder_n_chunks_after = 1 # since its not causal. - - lsh_self_attention = functools.partial(self._lsh_self_attention_fn(), - n_buckets=n_buckets) - - encoder_lsh_self_attention = functools.partial( - lsh_self_attention, n_chunks_after=encoder_n_chunks_after, - chunk_len=encoder_chunk_len) - - decoder_lsh_self_attention = functools.partial( - lsh_self_attention, n_chunks_after=0, - chunk_len=decoder_chunk_len) - - model = terraformer.ConfigurableTerraformer( - vocab_size, - d_model=d_model, - d_ff=d_ff, - d_attention_key=d_attn, - d_attention_value=d_attn, - n_encoder_layers=1, - n_decoder_layers=1, - n_heads=n_heads, - dropout=0.05, - max_len=max_len, - encoder_attention_type=encoder_lsh_self_attention, - encoder_decoder_attention_type=decoder_lsh_self_attention, - pos_axial_shape=pos_axial, - pos_d_axial_embs=pos_d_axial_embs, - ff_activation=tl.Relu, - ff_use_sru=0, - mode='train', - ) - - def random_sentence(): - return np.random.randint(low=1, high=vocab_size - 1, size=(1, max_len), - dtype=np.int32) - - x = [random_sentence(), random_sentence()] - weights, state = model.init(shapes.signature(x)) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) - # This returns [logits, decoder tokens] - logits = logits_and_dec_toks[0] - loss = fastmath.numpy.mean(logits[..., 0]) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - weights, state, logits = mock_training_step( - x, weights, state, fastmath.random.get_prng(0)) - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/terraformer_test.py b/trax/models/research/terraformer_test.py deleted file mode 100644 index b5344a2f5..000000000 --- a/trax/models/research/terraformer_test.py +++ /dev/null @@ -1,273 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Terraformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax.layers import test_utils -from trax.models.research import terraformer - - -BACKENDS = [fastmath.Backend.JAX] - - -def short_name(b): - if b == fastmath.Backend.JAX: - return 'jax' - else: - return 'tf' - - -class TerraformerTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def _timebin_self_attention_fn(self, use_reference_code=False): - return functools.partial( - tl.SelfAttention, - attention_dropout=0.05, - chunk_len=64, - n_chunks_before=1, - n_parallel_heads=1, - use_reference_code=use_reference_code - ) - - @parameterized.named_parameters( - [('_%s_efficient' % short_name(backend), backend, tl.SelfAttention, False) - for backend in BACKENDS] + - [('_%s_causal' % short_name(backend), backend, tl.CausalAttention, False) - for backend in BACKENDS] + - # NOTE: tl.SelfAttention is not currently working for this case. 
- [('_%s_preembed' % short_name(backend), backend, tl.CausalAttention, True) - for backend in BACKENDS]) - def test_terraformer_quick(self, backend, encoder_attention_type, preembed): - with fastmath.use_backend(backend): - vocab_size = 2 - input_vocab_size = None if preembed else vocab_size - output_vocab_size = vocab_size if preembed else None - max_len = 2 - - model = terraformer.ConfigurableTerraformer( - input_vocab_size, - d_model=4, - d_ff=4, - n_encoder_layers=1, - n_decoder_layers=1, - n_heads=2, - dropout=0.05, - max_len=max_len, - pos_type=None, - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=2, - mode='train', - output_vocab_size=output_vocab_size, - encoder_attention_type=encoder_attention_type, - ) - - if preembed: - model_inputs = [np.ones((1, max_len, 3)).astype(np.float32), - np.ones((1, max_len)).astype(bool)] - else: - model_inputs = [np.ones((1, max_len)).astype(np.int32)] - x = model_inputs + [np.ones((1, max_len)).astype(np.int32)] - model.init(shapes.signature(x)) - - logits, dec_toks = model(x) - del dec_toks - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - def test_terraformer_deterministic_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - vocab_size = 16 - d_model = 4 - batch_size = 2 - length = 5 - - model_fn = functools.partial( - terraformer.ConfigurableTerraformer, - vocab_size, - d_model=d_model, - d_ff=16, - n_encoder_layers=0, - n_decoder_layers=1, - n_heads=2, - dropout=0.0, - max_len=length*2, - pos_type=None, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - ) - - inp = np.random.randint(vocab_size, size=(batch_size, length)) - out = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_is_deterministic((inp, out), model_fn) - - def test_terraformer_predict_equals_eval(self): - with fastmath.use_backend(fastmath.Backend.JAX): - vocab_size = 16 - d_model = 8 - batch_size = 1 - length = 5 - - model_fn = functools.partial( - terraformer.ConfigurableTerraformer, - vocab_size, - d_model=d_model, - d_ff=16, - n_encoder_layers=1, - n_decoder_layers=1, - n_heads=2, - ff_use_sru=(1, 8), # ? is SRU working? - dropout=0.0, - max_len=(length+7)*2, - pos_type=None, - reversible_encoder=True, - n_decoder_attention_layers=1, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - ) - - # Token id of 0 indicates padding; and predict mode doesn't support it. 
- inp = np.random.randint(1, vocab_size, size=(batch_size, length)) - inp[:, -2:] = 0 - out = np.zeros((batch_size, length), dtype=np.int32) - - test_utils.test_eval_equals_predict( - (inp, out), model_fn, seq_axis=1, seq_tensor=-1, init_tokens=1) - - def test_terraformer_doubling(self): - vocab_size = 2 - max_len = 2 - - model = terraformer.ConfigurableTerraformer( - vocab_size, - d_model=8, - d_ff=16, - n_encoder_layers=1, - n_decoder_layers=6, - n_heads=2, - dropout=0.05, - max_len=max_len, - pos_type=None, - half_before_layer=2, - double_after_layer=2, - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - mode='train', - ) - - x = [np.ones((1, max_len)).astype(np.int32), - np.ones((1, max_len)).astype(np.int32)] - model.init(shapes.signature(x)) - - logits, dec_toks = model(x) - del dec_toks - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - def test_terraformer_one_step(self): - vocab_size = 32 - max_len = 256 - pos_axial = 16 - assert pos_axial * pos_axial == max_len - - chunk_len = 32 - - # Since 2 * chunk_len * n_buckets should be max_len. - n_buckets = max_len // (2 * chunk_len) - - lsh_self_attention = functools.partial(self._lsh_self_attention_fn(), - chunk_len=chunk_len, - n_buckets=n_buckets) - - timebin_self_attention = self._timebin_self_attention_fn() - - model = terraformer.ConfigurableTerraformer( - vocab_size, - d_model=32, - d_ff=64, - d_attention_key=64, - d_attention_value=64, - n_encoder_layers=2, - n_decoder_layers=2, - n_heads=2, - dropout=0.05, - max_len=max_len, - encoder_attention_type=lsh_self_attention, - encoder_decoder_attention_type=[timebin_self_attention, - lsh_self_attention], - pos_axial_shape=(pos_axial, pos_axial), - pos_d_axial_embs=(64, 192), - ff_activation=tl.Relu, - ff_use_sru=0, - ff_chunk_size=64, - ff_sparsity=8, - mode='train', - ) - - x = [np.ones((1, max_len)).astype(np.int32), - np.ones((1, max_len)).astype(np.int32)] - weights, state = model.init(shapes.signature(x)) - - @fastmath.jit - def mock_training_step(x, weights, state, rng): - def compute_mock_loss(weights): - logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng) - # This returns [logits, decoder tokens] - logits = logits_and_dec_toks[0] - loss = fastmath.numpy.mean(logits[..., 0]) - return loss, (new_state, logits) - gradients, (new_state, logits) = fastmath.grad( - compute_mock_loss, has_aux=True)(weights) - new_weights = fastmath.nested_map_multiarg( - lambda w, g: w - 1e-4 * g, weights, gradients) - return new_weights, new_state, logits - - weights, state, logits = mock_training_step( - x, weights, state, fastmath.random.get_prng(0)) - - self.assertEqual(logits.shape, (1, max_len, vocab_size)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 b/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 deleted file mode 100644 index 271d5aeae..000000000 Binary files a/trax/models/research/testdata/translate_ende_wmt32k-dev-00000-of-00001 and /dev/null differ diff --git a/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 b/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 deleted file mode 100644 index ed977fc71..000000000 Binary files a/trax/models/research/testdata/translate_ende_wmt32k-train-00000-of-00001 and /dev/null differ diff --git a/trax/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords 
b/trax/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords deleted file mode 100644 index 3b1ae32bd..000000000 --- a/trax/models/research/testdata/vocab.translate_ende_wmt32k.32768.subwords +++ /dev/null @@ -1,33288 +0,0 @@ -'_' -'the_' -'in_' -'of_' [... 33,288 deleted subword-vocabulary entries elided; the listing carries no information beyond the test-data file's deletion ...]
_' -'step_' -'share_' -'restaurant_' -'quite_' -'natÃŧrlich_' -'keep_' -'growing_' -'four_' -'former_' -'eren_' -'environmental_' -'concerned_' -'center_' -'balance_' -'actually_' -'Windows_' -'Vereinigten_' -'II_' -'Hilfe_' -'Dieser_' -'ways_' -'values_' -'unserem_' -'understand_' -'thank_' -'son_' -'sea_' -'return_' -'minutes_' -'majority_' -'legislation_' -'indeed_' -'improve_' -'ia_' -'held_' -'genau_' -'game_' -'erreicht_' -'ent_' -'date_' -'bleiben_' -'benefit_' -'accept_' -'War_' -'TV_' -'Haus_' -'Gesellschaft_' -'1999_' -'station_' -'rule_' -'recently_' -'outside_' -'na_' -'importance_' -'appropriate_' -'Zugang_' -'V_' -'Turkey_' -'Menschenrechte_' -'Beziehungen_' -'BIP_' -'text_' -'subject_' -'responsibility_' -'popular_' -'move_' -'meeting_' -'live_' -'fundamental_' -'Italy_' -'ste_' -'spending_' -'scheint_' -'potential_' -'natural_' -'media_' -'involved_' -'five_' -'einigen_' -'daran_' -'clearly_' -'bring_' -'air_' -'Strategie_' -'Rates_' -'Office_' -'Interesse_' -'Ihrer_' -'Erfolg_' -')_' -'"_' -'wide_' -'weltweit_' -'reduce_' -'project_' -'party_' -'obwohl_' -'model_' -'keyword_' -'ive_' -'internal_' -'hard_' -'handelt_' -'developed_' -'deshalb_' -'conflict_' -'car_' -'ca_' -'benefits_' -'Liste_' -'Frankreich_' -'Dienstleistungen_' -'Bezug_' -'Aus_' -'type_' -'thing_' -'sprechen_' -'size_' -'politik_' -'opinion_' -'lot_' -'jeden_' -'ions_' -'increased_' -'huge_' -'hours_' -'hen_' -'heit_' -'heart_' -'gerade_' -'forces_' -'facilities_' -'dÃŧrfen_' -'beautiful_' -'addition_' -'Russian_' -'Nähe_' -'Mitglieder_' -'M' -' : _' -'wirtschaftliche_' -'w_' -'tatsächlich_' -'principle_' -'nutzen_' -'ness_' -'ging_' -'exchange_' -'ert_' -'comes_' -'certainly_' -'anti_' -'ability_' -'UK_' -'Präsidenten_' -'Grenzen_' -'13_' -'&_' -'via_' -'verwendet_' -'verhindern_' -'verfÃŧgt_' -'pool_' -'person_' -'parts_' -'mean_' -'lange_' -'ig_' -'effective_' -'effect_' -'care_' -'bed_' -'application_' -'amerikanischen_' -'T' -'Kampf_' -'Banken_' -'17_' -'"._' -'tes_' -'response_' -'reach_' -'product_' -'poverty_' -'integration_' -'bereit_' -'bank_' -'administration_' -'Tatsache_' -'Staat_' -'Spain_' -'Party_' -'Middle_' -'La_' -'Berlin_' -'BemÃŧhungen_' -'Bedingungen_' -'B' -'vergangenen_' -'turn_' -'stark_' -'self_' -'rt_' -'provides_' -'ideal_' -'decisions_' -'civil_' -'add_' -'W_' -'Richtung_' -'Park_' -'Government_' -'Deshalb_' -'Bereichen_' -' ( _' -'weiß_' -'simple_' -'prevent_' -'living_' -'igen_' -'idea_' -'doing_' -'decades_' -'ck_' -'check_' -'Vor_' -'Von_' -'More_' -'Debatte_' -'Außerdem_' -' -_' -'wären_' -'v_' -'risks_' -'regulation_' -'oft_' -'klar_' -'erst_' -'directly_' -'direct_' -'computer_' -'besten_' -'ant_' -'along_' -'Verantwortung_' -'Terrorismus_' -'San_' -'MÃļglichkeiten_' -'Ergebnis_' -'Afghanistan_' -'standard_' -'stability_' -'require_' -'option_' -'meet_' -'job_' -'highly_' -'gemeinsamen_' -'file_' -'features_' -'cy_' -'currency_' -'complete_' -'co_' -'clean_' -'challenges_' -'body_' -'While_' -'Wert_' -'Umsetzung_' -'Tatsächlich_' -'Tag_' -'Gruppe_' -'Fortschritte_' -'December_' -'Aussprache_' -'23_' -'zeigen_' -'wahrscheinlich_' -'terrorism_' -'strategic_' -'space_' -'rest_' -'responsible_' -'nothing_' -'neu_' -'movement_' -'moment_' -'image_' -'gemeinsame_' -'gegeben_' -'domestic_' -'di_' -'develop_' -'dar_' -'contains_' -'capacity_' -'alten_' -'York_' -'Wahl_' -'Schritt_' -'Preis_' -'Hier_' -'“, _' -'tragen_' -'sub_' -'stop_' -'reviews_' -'regime_' -'post_' -'opportunities_' -'nature_' -'ke_' -'del_' -'consequences_' -'complex_' -'commitment_' -'border_' -'My_' -'How_' 
-'Booking_' -'Abstimmung_' -' ''_' -'   – _' -'zusammen_' -'ur_' -'unser_' -'table_' -'sozialen_' -'similar_' -'road_' -'reasons_' -'perhaps_' -'network_' -'lack_' -'kleinen_' -'heißt_' -'favour_' -'extremely_' -'customers_' -'comfortable_' -'brauchen_' -'Partei_' -'Mal_' -'K_' -'Irak_' -'Aufgabe_' -'500_' -'weiteren_' -'wegen_' -'unterstÃŧtzt_' -'shown_' -'red_' -'politics_' -'notwendig_' -'liegen_' -'initiative_' -'increasingly_' -'gas_' -'existing_' -'elections_' -'ding_' -'darf_' -'culture_' -'challenge_' -'built_' -'betrifft_' -'below_' -'beach_' -'ary_' -'activities_' -'Yet_' -'TÃŧrkei_' -'Punkt_' -'Prozess_' -'National_' -'January_' -'Interessen_' -'Druck_' -'August_' -'Abkommen_' -'80_' -'ziehen_' -'visit_' -'training_' -'ta_' -'protect_' -'performance_' -'night_' -'ma_' -'looking_' -'friendly_' -'effects_' -'came_' -'cal_' -'befinden_' -'became_' -'amount_' -'Zu_' -'Seit_' -'Prozent_' -'Platz_' -'O' -'Initiative_' -'Gewalt_' -'G' -'Eurozone_' -'Ergebnisse_' -'Chinas_' -'Ausschuss_' -'wichtigen_' -'walk_' -'verstehen_' -'try_' -'took_' -'takes_' -'pressure_' -'nie_' -'negotiations_' -'konnte_' -'je_' -'il_' -'grÃļßten_' -'gleichzeitig_' -'gefÃŧhrt_' -'essential_' -'einzelnen_' -'design_' -'derzeit_' -'cause_' -'asked_' -'anderer_' -'Software_' -'Presidency_' -'Ort_' -'Ihren_' -'", _' -'īŋŊ_' -'structural_' -'schließlich_' -'probably_' -'ors_' -'news_' -'lässt_' -'infrastructure_' -'included_' -'glaube_' -'ful_' -'events_' -'entwickelt_' -'election_' -'credit_' -'creating_' -'choice_' -'ally_' -'Website_' -'Verhandlungen_' -'Security_' -'Nationen_' -'H' -'Artikel_' -' ‘_' -'zeit_' -'wohl_' -'violence_' -'vier_' -'verschiedene_' -'unternehmen_' -'threat_' -'successful_' -'soziale_' -'soon_' -'receive_' -'largest_' -'isch_' -'hold_' -'gehÃļrt_' -'findet_' -'entwickeln_' -'despite_' -'deficit_' -'card_' -'answer_' -'Z' -'Vorschläge_' -'Version_' -'Verfahren_' -'Verbindung_' -'Regionen_' -'Personen_' -'Pakistan_' -'Markt_' -'Kunden_' -'Enterprise_' -'90_' -'27_' -'sustainable_' -'style_' -'politischer_' -'net_' -'ling_' -'leadership_' -'labor_' -'kaum_' -'gemeinsam_' -'fight_' -'eher_' -'effort_' -'don_' -'context_' -'committee_' -'chinesischen_' -'board_' -'behind_' -'ate_' -'although_' -'alone_' -'advanced_' -'act_' -'according_' -'Zum_' -'Während_' -'Verfassung_' -'V' -'Uhr_' -'North_' -'NATO_' -'May_' -'General_' -'Finally_' -'Entscheidung_' -'City_' -'Blick_' -'22_' -'. - (_' -'Öffentlichkeit_' -'zehn_' -'x' -'werde_' -'week_' -'video_' -'v' -'unique_' -'steps_' -'sound_' -'six_' -'purpose_' -'programs_' -'practice_' -'partner_' -'nor_' -'nice_' -'monetary_' -'mind_' -'massive_' -'lt_' -'lower_' -'limited_' -'leben_' -'internet_' -'industrial_' -'failed_' -'event_' -'distribution_' -'designed_' -'darum_' -'basic_' -'Since_' -'Hinblick_' -'Handel_' -'Hand_' -'Gebiet_' -'FÃļrderung_' -'Court_' -'Bar_' -'Armut_' -'Arab_' -'-, _' -'wichtigsten_' -'toward_' -'started_' -'spielen_' -'rund_' -'rich_' -'reality_' -'previous_' -'places_' -'music_' -'jedes_' -'instead_' -'ings_' -'includes_' -'guests_' -'goods_' -'funding_' -'front_' -'expect_' -'death_' -'consumers_' -'brought_' -'bisher_' -'bad_' -'Zahl_' -'Tage_' -'K' -'Forschung_' -'Bekämpfung_' -'Antwort_' -'Allerdings_' -'Abgeordneten_' -'..." 
_' -' / _' -'ze_' -'seines_' -'recht_' -'perfect_' -'nahe_' -'matters_' -'ken_' -'jede_' -'implementation_' -'ger_' -'failure_' -'established_' -'erster_' -'decided_' -'cut_' -'cs_' -'average_' -'aktuellen_' -'] _' -'Regeln_' -'Obwohl_' -'Now_' -'Großbritannien_' -'From_' -'BehÃļrden_' -'Anfang_' -'Am_' -'Ãļffentliche_' -'ändern_' -'weapons_' -'quickly_' -'official_' -'negative_' -'lives_' -'i' -'guest_' -'guarantee_' -'gebracht_' -'enthalten_' -'double_' -'damage_' -'consumer_' -'completely_' -'beginning_' -'Reformen_' -'March_' -'June_' -'Berichterstatter_' -'160_' -'zed_' -'weltweiten_' -'produce_' -'principles_' -'objective_' -'note_' -'mentioned_' -'latest_' -'häufig_' -'hätten_' -'helfen_' -'emissions_' -'beyond_' -'base_' -'annual_' -'alternative_' -'agreements_' -'achieved_' -'University_' -'St_' -'Ressourcen_' -'Gelegenheit_' -'FÃŧhrung_' -'Development_' -'DarÃŧber_' -'Damit_' -'D' -'Amerika_' -'1998_' -'1990_' -' " _' -'trotz_' -'thought_' -'tell_' -'request_' -'presented_' -'please_' -'month_' -'late_' -'ht_' -'generation_' -'final_' -'erforderlich_' -'effectively_' -'concern_' -'con_' -'cher_' -'ces_' -'categories_' -'bus_' -'avoid_' -'ations_' -'X' -'Wort_' -'Today_' -'Schlusselwortern_' -'Reihe_' -'Produkte_' -'Many_' -'Erklärung_' -'E' -'Dialog_' -'Datei_' -'Bildung_' -'Beginn_' -'Barcelona_' -'Amerikas_' -'200_' -'verbunden_' -'union_' -'supported_' -'study_' -'speak_' -'secure_' -'procedure_' -'opposition_' -'mus_' -'hoffe_' -'guten_' -'ground_' -'gar_' -'emerging_' -'einschließlich_' -'digital_' -'d' -'coming_' -'choose_' -'Zentrum_' -'Umwelt_' -'Qualität_' -'Politiker_' -'Nations_' -'Kontrolle_' -'IMF_' -'Here_' -'Gesundheit_' -'Gefahr_' -'Entscheidungen_' -'English_' -'70_' -'zeigt_' -'wobei_' -'voll_' -'unemployment_' -'throughout_' -'technical_' -'stärker_' -'solutions_' -'setting_' -'schaft_' -'rising_' -'programmes_' -'package_' -'p' -'mobile_' -'mal_' -'lost_' -'increasing_' -'hinsichtlich_' -'gleichen_' -'gesamten_' -'gehÃļren_' -'ft_' -'forms_' -'follow_' -'famous_' -'fair_' -'erwarten_' -'entire_' -'enthält_' -'dollar_' -'danken_' -'considered_' -'child_' -'book_' -'banking_' -'aware_' -'allowed_' -'allein_' -'Wann_' -'Schaffung_' -'Rights_' -'Osten_' -'Herren_' -'Funktion_' -'First_' -'Erweiterung_' -'Einsatz_' -'Debian_' -'Daher_' -'Ãŧbernachten_' -'works_' -'ut_' -'submitted_' -'später_' -'sometimes_' -'situated_' -'richtig_' -'restaurants_' -'reached_' -'mail_' -'len_' -'ker_' -'hinter_' -'ganzen_' -'function_' -'finance_' -'fest_' -'fen_' -'ermÃļglichen_' -'enable_' -'doubt_' -'closed_' -'cities_' -'cht_' -'charge_' -'caused_' -'campaign_' -'bekommen_' -'agricultural_' -'agenda_' -'actions_' -'Unsere_' -'Such_' -'Spanien_' -'Opfer_' -'Only_' -'Nutzung_' -'LÃļsungen_' -'Greek_' -'Einfluss_' -'Damen_' -'Dabei_' -'CD_' -'Ansatz_' -'Afrika_' -'äußerst_' -'wrong_' -'web_' -'vÃļllig_' -'verwenden_' -'status_' -'sorgen_' -'shows_' -'served_' -'productivity_' -'privaten_' -'politicians_' -'offered_' -'independent_' -'files_' -'est_' -'einzige_' -'began_' -'aspects_' -'anything_' -'ans_' -'accommodation_' -'Viele_' -'Spiel_' -'Rechte_' -'Putin_' -'Minuten_' -'Korea_' -'Jahrhunderts_' -'Economic_' -'Durch_' -'Dank_' -'DE_' -'Bild_' -'Ausschusses_' -'1791_' -'Ãŧberhaupt_' -'Änderungsantrag_' -'zudem_' -'voted_' -'treatment_' -'speed_' -'server_' -'send_' -'reduction_' -'promote_' -'plus_' -'original_' -'objectives_' -'nation_' -'n' -'lose_' -'ler_' -'larger_' -'j' -'instruments_' -'ihres_' -'globale_' -'gewährleisten_' -'genug_' -'fishing_' -'external_' 
-'ermÃļglicht_' -'enlargement_' -'eigene_' -'ds_' -'demands_' -'consumption_' -'confidence_' -'commercial_' -'build_' -'b' -'Y_' -'Wahlen_' -'W' -'Risiken_' -'NatÃŧrlich_' -'Mitglied_' -'Mehrheit_' -'Meer_' -'Integration_' -'Industrie_' -'F' -'Dezember_' -'Dazu_' -'45_' -', ' -' % _' -'worked_' -'verfÃŧgen_' -'verbessern_' -'ure_' -'tra' -'test_' -'search_' -'scale_' -'requires_' -'received_' -'prepared_' -'plans_' -'paid_' -'ks_' -'konnten_' -'jene_' -'helpful_' -'giving_' -'falls_' -'facing_' -'eben_' -'class_' -'bevor_' -'bestehen_' -'ated_' -'additional_' -'Text_' -'Services_' -'Server_' -'Schulden_' -'S' -'Republic_' -'Palestinian_' -'Märkte_' -'Kingdom_' -'Italian_' -'Gästebewertungen_' -'Furthermore_' -'Frieden_' -'Folgen_' -'Federal_' -'Fed_' -'Familie_' -'CO2_' -'Anwendung_' -'Angesichts_' -'Although_' -'1781_' -'zone_' -'zahlen_' -'won_' -'website_' -'warum_' -'trust_' -'story_' -'stable_' -'setzt_' -'requirements_' -'refugees_' -'read_' -'options_' -'nimmt_' -'minister_' -'lang_' -'item_' -'ine_' -'getting_' -'gestellt_' -'fragen_' -'fear_' -'extra_' -'expected_' -'dadurch_' -'critical_' -'bitte_' -'au_' -'apartment_' -'active_' -'Zudem_' -'Syria_' -'Kraft_' -'Jahrhundert_' -'Israeli_' -'I' -'His_' -'Fund_' -'Europeans_' -'Convention_' -'C' -'Britain_' -'Ausgaben_' -'28_' -'worldwide_' -'wealth_' -'tic_' -'thousands_' -'speech_' -'sign_' -'ses_' -'sectors_' -'sechs_' -'saying_' -'sagte_' -'players_' -'neben_' -'links_' -'legislative_' -'leader_' -'knowledge_' -'jedem_' -'influence_' -'großer_' -'geworden_' -'genommen_' -'efficient_' -'creation_' -'contact_' -'concerning_' -'carried_' -'c' -'braucht_' -'bereich_' -'ber_' -'attack_' -'airport_' -'added_' -'Verbraucher_' -'Programme_' -'Portugal_' -'Kosovo_' -'Information_' -'Dinge_' -'Center_' -'Aufmerksamkeit_' -'Affairs_' -'300_' -'ying_' -'versuchen_' -'team_' -'strengthen_' -'star_' -'raise_' -'quiet_' -'providing_' -'neuer_' -'mÃļglicherweise_' -'lines_' -'kleine_' -'kam_' -'ism_' -'ierte_' -'happen_' -'gt_' -'grÃļßte_' -'gleich_' -'gewesen_' -'frei_' -'folgt_' -'fallen_' -'exports_' -'evidence_' -'distance_' -'changed_' -'businesses_' -'borders_' -'angesichts_' -'Zeitpunkt_' -'Wasser_' -'WTO_' -'Verbesserung_' -'Rechts' -'Prime_' -'People_' -'Organisation_' -'Mediterranean_' -'Lisbon_' -'Linux_' -'Linie_' -'L' -'July_' -'Ihrem_' -'Hamas_' -'Fehler_' -'EinfÃŧhrung_' -'DVD_' -'Ausdruck_' -'African_' -'“._' -'ß_' -'zweiten_' -'ves_' -'versucht_' -'tions_' -'supply_' -'stage_' -'source_' -'solidarity_' -'sh_' -'seem_' -'school_' -'safe_' -'relationship_' -'reducing_' -'parking_' -'ngen_' -'nbsp_' -'nations_' -'morning_' -'meinen_' -'map_' -'lle_' -'leicht_' -'leave_' -'join_' -'investors_' -'inter' -'insurance_' -'ideas_' -'hohen_' -'hoch_' -'goal_' -'gesamte_' -'erklären_' -'dangerous_' -'ausgestattet_' -'applications_' -'allows_' -'[_' -'Verwendung_' -'Themen_' -'Tagen_' -'Room_' -'Raum_' -'Not_' -'John_' -'Infrastruktur_' -'Idee_' -'IT_' -'Herausforderung_' -'Given_' -'FrÃŧhstÃŧck_' -'Flughafen_' -'During_' -'Aktivitäten_' -' ) _' -'Änderungsanträge_' -'   ._' -'   . 
_' -'zweite_' -'wenige_' -'usually_' -'uses_' -'users_' -'structure_' -'stock_' -'spend_' -'solchen_' -'schÃŧtzen_' -'reports_' -'recovery_' -'provisions_' -'possibility_' -'organisation_' -'office_' -'nationale_' -'medical_' -'mainly_' -'m' -'justice_' -'ise_' -'innovation_' -'illegal_' -'ier_' -'getan_' -'geschaffen_' -'extent_' -'erklärt_' -'e' -'discussion_' -'direction_' -'connection_' -'concept_' -'communication_' -'bit_' -'anders_' -'Zeiten_' -'Werte_' -'Vertrauen_' -'Unter_' -'Systems_' -'Stunden_' -'Social_' -'So' -'Schließlich_' -'Präsidentin_' -'Please_' -'Oktober_' -'Notwendigkeit_' -'Nahen_' -'Muslim_' -'Latin_' -'Juni_' -'Its_' -'Haushalts' -'H_' -'Gaza_' -'Fällen_' -'Constitution_' -'Arbeitnehmer_' -'Anzahl_' -'29_' -'2011_' -' "..._' -'Änderungen_' -'written_' -'writing_' -'ties_' -'technologies_' -'target_' -'statement_' -'rten_' -'religious_' -'regards_' -'primary_' -'played_' -'mission_' -'material_' -'leisten_' -'joint_' -'immigration_' -'immediately_' -'hohe_' -'gs_' -'gewinnen_' -'fÃļrdern_' -'firms_' -'fe_' -'farmers_' -'f' -'everyone_' -'erneut_' -'enter_' -'elements_' -'efficiency_' -'difference_' -'circumstances_' -'buy_' -'bringt_' -'bare_' -'authority_' -'anti' -'activity_' -'Wettbewerb_' -'Web_' -'Vergangenheit_' -'Station_' -'Schritte_' -'Saudi_' -'Parteien_' -'P' -'Nur_' -'Microsoft_' -'Mai_' -'Lissabon_' -'Let_' -'Kultur_' -'J' -'George_' -'ECB_' -'Conference_' -'Club_' -'Chance_' -'Bad_' -'? _' -'4' -'35_' -'win_' -'wichtiger_' -'ver_' -'talk_' -'suggest_' -'students_' -'si_' -'sed_' -'save_' -'round_' -'regarding_' -'produced_' -'path_' -'pass_' -'partners_' -'park_' -'officials_' -'networks_' -'nach' -'mis' -'meinem_' -'lÃļsen_' -'ka_' -'implemented_' -'ien_' -'hotels_' -'historical_' -'genießen_' -'genannten_' -'equal_' -'encourage_' -'degree_' -'darstellt_' -'construction_' -'chance_' -'ch' -'cally_' -'britischen_' -'aufzu' -'attacks_' -'asset_' -'ance_' -'Waffen_' -'Vertrag_' -'Stelle_' -'Status_' -'Spanish_' -'Preise_' -'Person_' -'Mail_' -'Japanese_' -'Herzen_' -'Herausforderungen_' -'Great_' -'Erstens_' -'Design_' -'Centre_' -'Bestimmungen_' -'Article_' -'Anti' -'--_' -', “_' -'). 
_' -'“_' -'zumindest_' -'wanted_' -'variety_' -'treffen_' -'transfer_' -'train_' -'somit_' -'sieht_' -'shared_' -'section_' -'scientific_' -'regulations_' -'priority_' -'priorities_' -'powerful_' -'phone_' -'mit' -'minute_' -'message_' -'ments_' -'loss_' -'kurz_' -'ige_' -'hoch' -'head_' -'handeln_' -'h' -'grÃļßere_' -'gi' -'financing_' -'erte_' -'equipped_' -'equipment_' -'durchgefÃŧhrt_' -'dialogue_' -'denke_' -'committed_' -'außer_' -'apply_' -'aim_' -'agreed_' -'Westen_' -'Vorteile_' -'Technologien_' -'Tat_' -'Soviet_' -'She_' -'Ra' -'Position_' -'Nutzen_' -'März_' -'Just_' -'Juli_' -'Islamic_' -'Ireland_' -'Griechenland_' -'Grand_' -'Finanzierung_' -'Directive_' -'Clinton_' -'Bitte_' -'Auffassung_' -'Asien_' -'Anstieg_' -'According_' -': ' -'31_' -'zen_' -'word_' -'wants_' -'vor' -'vermeiden_' -'ue_' -'types_' -'treten_' -'travel_' -'transparent_' -'tor_' -'te' -'task_' -'ster_' -'ssen_' -'schwierig_' -'rural_' -'property_' -'professional_' -'precisely_' -'practical_' -'powers_' -'par' -'operations_' -'ning_' -'method_' -'meines_' -'mein_' -'mehrere_' -'maintain_' -'las_' -'improving_' -'gezeigt_' -'geschlossen_' -'feature_' -'express_' -'establish_' -'entspricht_' -'easily_' -'drive_' -'diejenigen_' -'developments_' -'core_' -'code_' -'bar' -'aufgenommen_' -'attempt_' -'angenommen_' -'analysis_' -'Zweitens_' -'Währung_' -'Washington_' -'Vielfalt_' -'Umgebung_' -'Street_' -'Star_' -'Site_' -'Selbst_' -'Schwierigkeiten_' -'Personal_' -'PC_' -'Management_' -'Ko' -'Italien_' -'Inflation_' -'Guest_' -'Geschäfts' -'Freiheit_' -'Energie_' -'EN_' -'Co' -'Auswahl_' -': "_' -'2015_' -'1997_' -'* _' -' [_' -'Ãŧbernehmen_' -'Änderung_' -'ÂŽ _' -'zahlreiche_' -'white_' -'weisen_' -'weder_' -'tabled_' -'res_' -'record_' -'procedures_' -'pleased_' -'participation_' -'ones_' -'nder_' -'modified_' -'moderne_' -'lichkeit_' -'ität_' -'ismus_' -'island_' -'instrument_' -'insgesamt_' -'impossible_' -'historischen_' -'hands_' -'freien_' -'expensive_' -'erzielt_' -'erwartet_' -'draft_' -'demokratischen_' -'decline_' -'comprehensive_' -'carry_' -'carbon_' -'bestimmten_' -'beide_' -'becoming_' -'ar' -'ang_' -'ad_' -'Veränderungen_' -'Vereinten_' -'Therefore_' -'Solidarität_' -'Sea_' -'Punkte_' -'Poland_' -'Plan_' -'Lebens' -'KÃŧche_' -'IWF_' -'Heute_' -'Haupt' -'Gruppen_' -'Forum_' -'Do_' -'Bereiche_' -'Augen_' -'Amsterdam_' -'Americans_' -'%._' -'zunächst_' -'zentrale_' -'z' -'worth_' -'votes_' -'vision_' -'ums_' -'turned_' -'trading_' -'starting_' -'sites_' -'reserves_' -'reflect_' -'reference_' -'ps_' -'press_' -'presidential_' -'pre_' -'pe_' -'ory_' -'opening_' -'numerous_' -'normal_' -'nce_' -'maßnahmen_' -'linked_' -'keiten_' -'kannst_' -'k' -'gesetzt_' -'export_' -'enen_' -'ell_' -'do' -'deutschen_' -'details_' -'deine_' -'corporate_' -'benÃļtigen_' -'begann_' -'appear_' -'amerikanische_' -'ale_' -'akzeptieren_' -'affected_' -'adopt_' -'Woche_' -'Verordnung_' -'Unabhängigkeit_' -'Technologie_' -'Te' -'Tagesordnung_' -'Secondly_' -'Rooms_' -'Revolution_' -'Reform_' -'October_' -'Nachfrage_' -'Most_' -'Markt' -'Le_' -'High_' -'Folge_' -'England_' -'El_' -'EZB_' -'Barack_' -'Also_' -'A' -'.)_' -'zusätzliche_' -'worse_' -'wer_' -'walking_' -'voting_' -'um' -'ton_' -'told_' -'tischen_' -'taxes_' -'suchen_' -'stets_' -'seek_' -'screen_' -'scher_' -'running_' -'review_' -'representatives_' -'relevant_' -'relativ_' -'quote_' -'president_' -'plant_' -'planet_' -'parliamentary_' -'medium_' -'managed_' -'licher_' -'ley_' -'leider_' -'introduced_' -'ice_' -'hält_' -'hour_' -'holiday_' -'heard_' -'got_' -'gives_' 
-'gefunden_' -'folgen_' -'erkennen_' -'easier_' -'durchaus_' -'described_' -'deep_' -'decade_' -'daily_' -'conference_' -'competitiveness_' -'colleagues_' -'beitragen_' -'begin_' -'becomes_' -'atmosphere_' -'aimed_' -'Wirtschafts' -'Steuer' -'Stabilität_' -'Sozial' -'Sektor_' -'Restaurants_' -'Nicht_' -'Neu' -'Museum_' -'Konferenz_' -'Human_' -'He' -'Haltung_' -'Global_' -'Free_' -'Forschungs' -'Diskussion_' -'Code_' -'Berichte_' -'Beitrag_' -'Because_' -'Airport_' -'., _' -'. ' -''' _' -'ys_' -'unten_' -'trying_' -'territory_' -'summer_' -'smaller_' -'sing_' -'sides_' -'shopping_' -'ship_' -'series_' -'saw_' -'reichen_' -'police_' -'paar_' -'offen_' -'legen_' -'kosten_' -'ju' -'jeweiligen_' -'investments_' -'increases_' -'improved_' -'il' -'green_' -'gekommen_' -'floor_' -'erfÃŧllen_' -'erfolgreich_' -'elsewhere_' -'eindeutig_' -'draw_' -'documents_' -'disease_' -'corruption_' -'continent_' -'consensus_' -'congratulate_' -'bathroom_' -'außerhalb_' -'associated_' -'assets_' -'appears_' -'advantage_' -'Why_' -'Vor' -'Vo' -'Unterkategorien_' -'Syrien_' -'Sicht_' -'Resort_' -'Reserve_' -'R' -'Pro' -'Policy_' -'Per' -'NEW_' -'N' -'Mexico_' -'Maße_' -'Kommissarin_' -'Google_' -'Einigung_' -'Du' -'Da' -'Austria_' -'Am' -'Al_' -'Ad' -'3' -'26_' -'2013_' -'2' -'zunehmend_' -'zuletzt_' -'zation_' -'wing_' -'wesentlich_' -'weise_' -'warm_' -'vorgeschlagen_' -'vital_' -'verfolgen_' -'vast_' -'useful_' -'tools_' -'terrorist_' -'t' -'sales_' -'rung_' -'ring_' -'regionale_' -'refer_' -'presence_' -'po' -'per' -'paper_' -'overall_' -'org_' -'ons' -'offering_' -'oben_' -'modernen_' -'letzte_' -'learn_' -'lag_' -'inzwischen_' -'initiatives_' -'ial_' -'historic_' -'haus_' -'greatest_' -'goes_' -'gelangen_' -'ganze_' -'gain_' -'g' -'friends_' -'fahren_' -'document_' -'district_' -'dagegen_' -'crucial_' -'controls_' -'contribution_' -'contrary_' -'compromise_' -'competitive_' -'boost_' -'bewusst_' -'assessment_' -'apartments_' -'agriculture_' -'addressed_' -'ably_' -'Weltwirtschaft_' -'Unser_' -'Trade_' -'Team_' -'Standards_' -'Spa_' -'RE' -'Natur_' -'Land' -'Kredit' -'Kolleginnen_' -'Kindern_' -'Innovation_' -'HÃļhe_' -'Hälfte_' -'Golf_' -'Egypt_' -'Den_' -'Dateien_' -'Costa_' -'Bra' -'Bewegung_' -'Besuch_' -'Beschäftigung_' -'Angebot_' -'A5_' -'9' -'.  
_' -'Ãŧberzeugt_' -'ßen_' -'www_' -'tis' -'thinking_' -'targets_' -'stimmen_' -'sse_' -'specifically_' -'so' -'seeking_' -'reading_' -'rapidly_' -'pursue_' -'prime_' -'politisch_' -'otherwise_' -'nächste_' -'nis_' -'ni_' -'nachdem_' -'muß_' -'moral_' -'minimum_' -'mass_' -'log_' -'limits_' -'laws_' -'largely_' -'ja' -'ite_' -'innen_' -'ierten_' -'ied_' -'ical_' -'hÃļren_' -'hin' -'hear_' -'glauben_' -'games_' -'freie_' -'forced_' -'finally_' -'families_' -'eventually_' -'era_' -'entweder_' -'entsprechenden_' -'endlich_' -'eigentlich_' -'dÃŧrfte_' -'derartige_' -'darstellen_' -'connected_' -'conclusion_' -'calling_' -'box_' -'bezÃŧglich_' -'betrachtet_' -'bald_' -'außerdem_' -'applied_' -'amendment_' -']], _' -'Welt' -'Versuch_' -'Urlaub_' -'USS_' -'Sorge_' -'Sicherheits' -'RÃŧck' -'Public_' -'Projekt_' -'Pa' -'Lu' -'Le' -'Januar_' -'J_' -'Is_' -'Inter' -'Installation_' -'IS' -'Health_' -'GrÃŧnden_' -'Globalisierung_' -'Generation_' -'Gegensatz_' -'Einige_' -'Benutzer_' -'Beitritt_' -'Amendment_' -': „_' -'32_' -'1980_' -'и' -'zufolge_' -'za' -'wiederum_' -'weeks_' -'verbundenen_' -'ultimately_' -'tz_' -'tu' -'transition_' -'to' -'signed_' -'serve_' -'ser_' -'separate_' -'sent_' -'sanctions_' -'respond_' -'resolve_' -'released_' -'release_' -'relation_' -'reduced_' -'raum_' -'profitieren_' -'port_' -'planning_' -'phase_' -'payments_' -'owing_' -'monitoring_' -'models_' -'migration_' -'mag_' -'lies_' -'interesting_' -'immediate_' -'icht_' -'host_' -'holding_' -'hit_' -'helped_' -'globalization_' -'ger' -'genuine_' -'gen' -'fuel_' -'franzÃļsischen_' -'followed_' -'finanzielle_' -'film_' -'expression_' -'expressed_' -'establishing_' -'erung_' -'erhÃļhen_' -'einiger_' -'economics_' -'dra' -'determined_' -'cross_' -'clients_' -'chinesische_' -'buffet_' -'aktuelle_' -'Wochen_' -'Wirtschaftswachstum_' -'Warum_' -'Vorschriften_' -'Vielleicht_' -'Tu' -'Strand_' -'Sache_' -'Regierungs' -'Red_' -'Präsidentschaft_' -'Prodi_' -'Partner_' -'Nun_' -'Musik_' -'Leute_' -'Konzept_' -'Konflikt_' -'Justice_' -'Instead_' -'Ger' -'Financial_' -'Erfahrung_' -'Brazil_' -'Bedrohung_' -'Arbeitslosigkeit_' -'Angelegenheiten_' -'75_' -'400_' -'2012_' -'150_' -'... _' -'%, _' -'„_' -'Ä' -'went_' -'wa' -'ve' -'urban_' -'une_' -'unbedingt_' -'tro' -'tool_' -'thereby_' -'struggle_' -'societies_' -'science_' -'sa_' -'ready_' -'rapid_' -'perspective_' -'output_' -'operating_' -'no' -'nearly_' -'namely_' -'multi_' -'moving_' -'mittel_' -'menu_' -'mention_' -'lo_' -'jÃŧngsten_' -'interested_' -'innovative_' -'highest_' -'hi' -'gs' -'grÃļßeren_' -'grounds_' -'governance_' -'gelten_' -'fixed_' -'exclusive_' -'estimated_' -'entsprechende_' -'else_' -'ei_' -'dringend_' -'download_' -'discussed_' -'difficulties_' -'default_' -'criticism_' -'color_' -'ck' -'che' -'candidate_' -'camera_' -'broad_' -'beste_' -'bessere_' -'behavior_' -'begrÃŧßen_' -'barkeit_' -'assistance_' -'army_' -'approved_' -'announced_' -'ahead_' -'achten_' -'Zustimmung_' -'Wohn' -'Wohlstand_' -'Unfortunately_' -'Umwelt' -'Turkish_' -'Tri' -'Su' -'San' -'Saddam_' -'Royal_' -'Risiko_' -'Produktion_' -'No' -'Monaten_' -'Mo' -'Mer' -'Meine_' -'Irland_' -'Initiativen_' -'Informations' -'Grund' -'Gesamt' -'Gegenteil_' -'Friedens' -'Einkommen_' -'Despite_' -'Bis_' -'Behandlung_' -'Atom' -'Anteil_' -'Anstrengungen_' -'Annahme_' -'.._' -'.' 
-'+_' -'" (_' -'wählen_' -'wo' -'wider_' -'whom_' -'wert_' -'welfare_' -'weg_' -'verlassen_' -'verantwortlich_' -'tung_' -'ts' -'treated_' -'track_' -'television_' -'summit_' -'spirit_' -'speaking_' -'slow_' -'skills_' -'sitting_' -'saving_' -'raised_' -'post' -'ng' -'ming_' -'middle_' -'merely_' -'lernen_' -'kaufen_' -'ium_' -'isierung_' -'introduce_' -'intended_' -'integrated_' -'inequality_' -'industries_' -'industrie_' -'helping_' -'heiten_' -'halte_' -'großes_' -'generally_' -'gelegen_' -'gave_' -'fÃŧhrte_' -'fund_' -'flight_' -'fields_' -'extreme_' -'exist_' -'evening_' -'entwickelten_' -'einander_' -'cover_' -'continued_' -'comfort_' -'click_' -'claims_' -'break_' -'bilden_' -'biggest_' -'back' -'as' -'anderem_' -'accession_' -'Zunächst_' -'Win' -'Video_' -'Vereinbarung_' -'U' -'Transparenz_' -'Sterne_' -'See' -'Secretary_' -'Se' -'Robert_' -'Online_' -'Nice_' -'Nachdem_' -'Ist_' -'In' -'Home_' -'Gäste_' -'Erfahrungen_' -'Entschließung_' -'Ent' -'El' -'Eindruck_' -'Ebenso_' -'Do' -'Depression_' -'Christian_' -'Bundes' -'Bilder_' -'Bank' -'Arbeits' -'36_' -'1996_' -')' -' -, _' -'—_' -'Äą' -'Öl' -'zo' -'ys' -'victims_' -'verÃļffentlicht_' -'unbe' -'uf' -'trägt_' -'tried_' -'tomorrow_' -'thanks_' -'ter' -'suffering_' -'su' -'strategies_' -'staatliche_' -'sierung_' -'regionalen_' -'reception_' -'purchase_' -'prosperity_' -'programm_' -'primarily_' -'pre' -'permanent_' -'parents_' -'outcome_' -'obvious_' -'nts_' -'morgen_' -'met_' -'mark_' -'manage_' -'loans_' -'listed_' -'laid_' -'laden_' -'individuals_' -'independence_' -'ier' -'hÃļher_' -'hundreds_' -'himself_' -'granted_' -'gold_' -'gleiche_' -'gesehen_' -'fordern_' -'fellow_' -'extension_' -'europäischer_' -'eu' -'ess_' -'erlaubt_' -'ere_' -'entstehen_' -'enormous_' -'ence_' -'emphasis_' -'el' -'eingesetzt_' -'dollars_' -'divided_' -'display_' -'danach_' -'cutting_' -'criteria_' -'constitutional_' -'conflicts_' -'comments_' -'collapse_' -'closely_' -'benutzt_' -'begrÃŧße_' -'bear_' -'baren_' -'arms_' -'ard_' -'absolutely_' -'Zentralbank_' -'Wer_' -'Wasser' -'Volkswirtschaften_' -'Vertreter_' -'Trotz_' -'Tra' -'Thus_' -'Strategien_' -'Stimme_' -'Steuern_' -'Standard_' -'Square_' -'Sinne_' -'Produktions' -'NA' -'Monat_' -'Mitgliedschaft_' -'Methode_' -'Me' -'Krankheiten_' -'Ju' -'Islam_' -'Insel_' -'Industrie' -'Homepage_' -'Hinsicht_' -'Falle_' -'Erachtens_' -'Einrichtungen_' -'Einrichtung_' -'Ca' -'Business_' -'Breakfast_' -'Arbeitsplätze_' -'Another_' -'Air_' -'5' -'39_' -'195' -'. 
(_' -', "_' -'Ãŧberall_' -'är' -'­_' -'za_' -'wollte_' -'willing_' -'werk_' -'waste_' -'w' -'verloren_' -'us' -'understanding_' -'technological_' -'swimming_' -'suggests_' -'sufficient_' -'street_' -'staatlichen_' -'sm_' -'sho' -'shift_' -'ser' -'seien_' -'schwer_' -'schließen_' -'root_' -'richtige_' -'rer_' -'relatively_' -'relating_' -'relate_' -'rd_' -'qua' -'practices_' -'pour_' -'placed_' -'persÃļnlichen_' -'ourselves_' -'ort_' -'operation_' -'olitik_' -'mÃļglichen_' -'musste_' -'millions_' -'mid_' -'measure_' -'lo' -'lis' -'institutional_' -'ichen_' -'har' -'happy_' -'happened_' -'gern_' -'funktioniert_' -'fort' -'formal_' -'flexible_' -'fish_' -'fine_' -'ff' -'fand_' -'fail_' -'erhÃļht_' -'erfordert_' -'enormen_' -'elected_' -'earlier_' -'dly_' -'directives_' -'deutsche_' -'dennoch_' -'deliver_' -'contain_' -'conclude_' -'compared_' -'coffee_' -'co' -'ci' -'ches_' -'charged_' -'character_' -'changing_' -'cation_' -'bonds_' -'bly_' -'betrachten_' -'besondere_' -'balanced_' -'automatisch_' -'animals_' -'actual_' -'Zweifel_' -'Zentralbanken_' -'Vorteil_' -'Volk_' -'Then_' -'Stärkung_' -'Sorgen_' -'Sonder' -'Schul' -'Ro' -'Republik_' -'Privat' -'Pe' -'Or' -'Netherlands_' -'Nachbarn_' -'MySQL_' -'Monetary_' -'Mitte_' -'Medien_' -'Macht' -'Lo' -'Liberalisierung_' -'Landwirtschaft_' -'Ku' -'Klicken_' -'Irish_' -'Instrument_' -'IP_' -'ID_' -'Hoffnung_' -'Groß' -'Euro' -'Entwicklungsländer_' -'Eastern_' -'Dr_' -'Democrats_' -'Bu' -'Basis_' -'Barroso_' -'Aspekt_' -'Angriff_' -'1995_' -'1989_' -'01_' -'0' -'а' -'äre_' -'Über' -'wor' -'war' -'vorhanden_' -'van_' -'va_' -'unto_' -'unterschiedlichen_' -'unless_' -'tät_' -'transparency_' -'theory_' -'supports_' -'supporting_' -'stress_' -'stated_' -'significantly_' -'sicherzustellen_' -'schneller_' -'says_' -'revenue_' -'resulting_' -'published_' -'provision_' -'protected_' -'processes_' -'presidency_' -'player_' -'ping_' -'oral_' -'numbers_' -'neues_' -'named_' -'membership_' -'lung_' -'looks_' -'ll' -'liberal_' -'lesen_' -'leads_' -'laufen_' -'latter_' -'lagen_' -'kennen_' -'j_' -'io_' -'introduction_' -'identity_' -'hol' -'guidelines_' -'griechischen_' -'gleich' -'format_' -'folgenden_' -'flexibility_' -'fire_' -'finanziellen_' -'factors_' -'extended_' -'expectations_' -'examples_' -'exactly_' -'email_' -'ef' -'dynamic_' -'disc' -'der' -'denken_' -'delegation_' -'decide_' -'criminal_' -'crime_' -'courses_' -'collective_' -'bi' -'berÃŧcksichtigt_' -'ben' -'bars_' -'ba' -'article_' -'ank' -'anderes_' -'and' -'agencies_' -'absolute_' -'Your_' -'Währungs' -'Wettbewerbsfähigkeit_' -'Vergleich_' -'Unternehmens' -'Teilen_' -'Spieler_' -'Rezession_' -'Regel_' -'Realität_' -'RI' -'Post' -'Po' -'Open_' -'Nachfolgend_' -'Na' -'NE' -'Modell_' -'Mitteilung_' -'Militär' -'Michael_' -'Mi' -'Mat' -'Mal' -'Konflikte_' -'Ja' -'Her' -'Handels' -'Han' -'Gesundheits' -'Firmen_' -'Film_' -'Fe' -'Emissionen_' -'EADS_' -'Dar' -'Congress_' -'Che' -'CA' -'BÃŧrgern_' -'Bio' -'Best_' -'Ausland_' -'Ar' -'Apartments_' -'Abschluss_' -'AIDS_' -'8' -'13' -' –, _' -' = _' -' ... 
_' -'ҁ' -'Êe_' -'yourself_' -'wirtschaftlicher_' -'voters_' -'vo' -'victory_' -'vertreten_' -'verringern_' -'verpflichtet_' -'vehicles_' -'unge' -'unde' -'u' -'truly_' -'tive_' -'tionen_' -'tests_' -'tasks_' -'sun_' -'substantial_' -'starken_' -'standing_' -'st' -'ss' -'spread_' -'species_' -'southern_' -'sources_' -'someone_' -'ski_' -'signal_' -'shops_' -'seven_' -'seriously_' -'selection_' -'selbstverständlich_' -'russischen_' -'restrictions_' -'relative_' -'reaching_' -'rea' -'promoting_' -'profit_' -'physical_' -'partly_' -'occur_' -'ob' -'nte_' -'nis' -'nationaler_' -'militärische_' -'meaning_' -'materials_' -'maintained_' -'länger_' -'llen_' -'link_' -'limit_' -'leaving_' -'launched_' -'launch_' -'languages_' -'langfristige_' -'l' -'komplett_' -'ked_' -'kan' -'jo' -'jener_' -'implement_' -'geändert_' -'gerecht_' -'garden_' -'gar' -'exercise_' -'euch_' -'ets_' -'et' -'entscheiden_' -'ein' -'duty_' -'dir_' -'detailed_' -'database_' -'danger_' -'cuts_' -'contrast_' -'contract_' -'contemporary_' -'combination_' -'closer_' -'claim_' -'cha' -'causes_' -'calls_' -'burden_' -'born_' -'bodies_' -'bestimmte_' -'bestand_' -'benutzen_' -'behandelt_' -'aus' -'anyone_' -'angeht_' -'alt' -'allowing_' -'ages_' -'acht_' -'accounts_' -'accepted_' -'absence_' -'Wähler_' -'Worte_' -'Wein' -'Verhältnis_' -'Time_' -'Stadt' -'Sehr_' -'SehenswÃŧrdigkeiten_' -'Rotary_' -'RA' -'Projekte_' -'Prioritäten_' -'Polen_' -'Pi' -'Perhaps_' -'NI_' -'Leistung_' -'Last_' -'Kombination_' -'Kas' -'Jo' -'III_' -'HIV_' -'God_' -'Gebieten_' -'Foreign_' -'Forderung_' -'Ereignisse_' -'Entwicklungsländern_' -'Dennoch_' -'Beschreibung_' -'Bas' -'Aufgrund_' -'Aufbau_' -'Argentina_' -'Anfrage_' -'Amerikaner_' -'Alliance_' -'Agreement_' -'AT' -'AR' -'95_' -''._' -'Ņ‚' -'Ãŗ' -'äu' -'ät_' -'ären_' -'Ökonomen_' -'zurÃŧckge' -'zimmer_' -'zero_' -'zeit' -'wrote_' -'write_' -'wodurch_' -'westlichen_' -'weiterer_' -'weg' -'vollständig_' -'voice_' -'ving_' -'up' -'unlikely_' -'ungs' -'umfassende_' -'ual_' -'traffic_' -'ten' -'survive_' -'surrounding_' -'supposed_' -'super' -'sts_' -'spent_' -'sis_' -'signs_' -'riesigen_' -'resort_' -'represents_' -'represent_' -'raising_' -'putting_' -'prÃŧfen_' -'prospects_' -'promotion_' -'producers_' -'picture_' -'nes_' -'motion_' -'läuft_' -'lu' -'king_' -'ke' -'iti' -'ionen_' -'implementing_' -'ian_' -'housing_' -'historische_' -'hardly_' -'grundlegende_' -'gone_' -'goals_' -'gli' -'gestimmt_' -'geschehen_' -'gemäß_' -'fÃŧhrenden_' -'fähigkeit_' -'functions_' -'founded_' -'focused_' -'feststellen_' -'faces_' -'faced_' -'extensive_' -'except_' -'establishment_' -'erhält_' -'erfolgen_' -'entsprechend_' -'enterprises_' -'ensuring_' -'ene_' -'ehemaligen_' -'echte_' -'du' -'discussions_' -'differences_' -'di' -'desire_' -'deeply_' -'darunter_' -'da' -'conventional_' -'continues_' -'constitution_' -'completed_' -'communities_' -'cat' -'cast_' -'bottom_' -'ber' -'beginnen_' -'ban_' -'aufzunehmen_' -'auf' -'asking_' -'arrangements_' -'angezeigt_' -'ad' -'Y' -'Wohl' -'Wege_' -'Wahl' -'Voraussetzungen_' -'Verhalten_' -'Unter' -'UNO_' -'Transport_' -'THE_' -'Straße_' -'Ste' -'Standpunkt_' -'Si' -'Schulden' -'SA' -'Regime_' -'Re' -'Pre' -'Option_' -'Old_' -'Oder_' -'Neben_' -'Mit' -'MI' -'Kunst_' -'Kompromiss_' -'Kar' -'Kapital_' -'James_' -'Hu' -'Hong_' -'Gold_' -'Gebäude_' -'Fähigkeit_' -'Foundation_' -'Familien_' -'Every_' -'Erholung_' -'Ei' -'Denn_' -'Chi' -'Brasilien_' -'Beispiele_' -'Bau' -'Ba' -'Assad_' -'Aspekte_' -'Asian_' -'Arten_' -'33_' -'30' -'1' -',' -' /_' -'Đž' -'Ãŧber' -'Ãŗn_' -'ÃĄ_' -'ß' 
-'Über_' -'Öl_' -'  _' -'zusätzlich_' -'zurÃŧck' -'ya_' -'window_' -'widely_' -'wichtigste_' -'whatever_' -'verlieren_' -'vergessen_' -'unver' -'unable_' -'umfassenden_' -'ub' -'treaty_' -'tradition_' -'terrace_' -'stärken_' -'studies_' -'stone_' -'spacious_' -'spa' -'sovereign_' -'sort_' -'shower_' -'ships_' -'select_' -'seitens_' -'sa' -'ru' -'recognize_' -'rasch_' -'r' -'producing_' -'pp' -'policymakers_' -'platz_' -'peaceful_' -'participate_' -'ordnung_' -'opposed_' -'on' -'ok_' -'offensichtlich_' -'north_' -'nisse_' -'niemand_' -'niemals_' -'neither_' -'ne' -'mussten_' -'ms_' -'militärischen_' -'maximum_' -'mar' -'machine_' -'losses_' -'lokale_' -'letztlich_' -'lessons_' -'le' -'ld_' -'langen_' -'labour_' -'jährlich_' -'ji' -'ing' -'improvements_' -'improvement_' -'iger_' -'ierungs' -'hÃļhere_' -'hundred_' -'household_' -'he' -'gezwungen_' -'getroffen_' -'gesprochen_' -'gefallen_' -'gap_' -'flows_' -'fisheries_' -'firm_' -'figures_' -'fe' -'faster_' -'eye_' -'extend_' -'erten_' -'ersch' -'ern' -'equally_' -'ep' -'entschieden_' -'entry_' -'enk' -'element_' -'elegant_' -'eight_' -'discuss_' -'demokratische_' -'deaths_' -'dealing_' -'daten_' -'dank_' -'cuisine_' -'contracts_' -'considerable_' -'compatible_' -'coal_' -'charges_' -'cas' -'car' -'britische_' -'books_' -'bestimmt_' -'beschlossen_' -'beds_' -'außerordentlich_' -'automatically_' -'aten_' -'air' -'accordance_' -']] [[_' -'Zur_' -'Ziel' -'Verwaltungs' -'Verwaltung_' -'Verpflichtung_' -'Un' -'Umständen_' -'Tre' -'Tod_' -'Teile_' -'Systeme_' -'Staats' -'Sprache_' -'Sieg_' -'See_' -'Sarkozy_' -'Q_' -'Organisationen_' -'Name_' -'Männer_' -'Mar' -'Leistungen_' -'LE' -'Kyoto_' -'Kong_' -'Jetzt_' -'Indian_' -'Haushalt_' -'Gro' -'Grenze_' -'Gra' -'Frankfurt_' -'Flug' -'Fahr' -'Entwicklungs' -'Drittens_' -'De_' -'David_' -'Copenhagen_' -'Click_' -'Bus' -'Bre' -'Both_' -'Board_' -'Berichts_' -'Ban' -'Angst_' -'Anforderungen_' -'Agenda_' -'48_' -'25' -'2014_' -'', _' -' -- _' -'Ãļ' -'ätzen_' -'Übereinstimmung_' -'zahlreichen_' -'wirtschaft_' -'wenigen_' -'weak_' -'warming_' -'wall_' -'wages_' -'vieler_' -'verstärkt_' -'verfahren_' -'usw_' -'ur' -'und' -'täglich_' -'tä' -'tz' -'turning_' -'translation_' -'touch_' -'ti_' -'ti' -'the' -'th' -'talking_' -'stands_' -'smoking_' -'sion_' -'siehe_' -'schlagen_' -'scheme_' -'s' -'ringen_' -'reservation_' -'ren' -'punkt_' -'proved_' -'por_' -'piece_' -'photos_' -'partnership_' -'para_' -'opened_' -'omi' -'older_' -'nennen_' -'ned_' -'nachhaltige_' -'methods_' -'memory_' -'mehreren_' -'manchmal_' -'machte_' -'ln_' -'legt_' -'learning_' -'konzentrieren_' -'kle' -'ker' -'keinerlei_' -'itu' -'intervention_' -'internationaler_' -'installation_' -'ins' -'initial_' -'id_' -'hot_' -'hauptsächlich_' -'gerne_' -'garantiert_' -'ft' -'franzÃļsische_' -'fle' -'felt_' -'father_' -'est' -'ens' -'ended_' -'employees_' -'ea' -'drugs_' -'driving_' -'destruction_' -'defined_' -'defense_' -'debates_' -'damals_' -'counter_' -'communications_' -'commitments_' -'combating_' -'combat_' -'bt_' -'benÃļtigt_' -'bei' -'beachten_' -'bath_' -'auszu' -'ausdrÃŧcklich_' -'au' -'at' -'anzu' -'animal_' -'Zeit' -'Za' -'Vietnam_' -'Vielzahl_' -'Um' -'Tro' -'Times_' -'Tal' -'Ta' -'TA' -'Strom' -'Stand_' -'Sta' -'Sinn_' -'Shi' -'SchlÃŧssel' -'Santa_' -'Roman_' -'Rezeption_' -'Regierungschefs_' -'Rede_' -'Reaktion_' -'Private_' -'PT_' -'PHP_' -'Other_' -'Mont' -'Mexiko_' -'Mag' -'Located_' -'Live_' -'Las_' -'Kritik_' -'Kinder' -'Jahres' -'Inseln_' -'Hotel' -'Hoch' -'Hause_' -'Handels_' -'Grad_' -'Gleichzeitig_' -'Gi' -'Gesellschaften_' 
-'Gerichts' -'Finanz' -'February_' -'Falls_' -'Fach' -'Erde_' -'Du_' -'Dra' -'Deutsche_' -'Dann_' -'DI' -'Ce' -'Bild' -'Ber' -'Aufenthalt_' -'Anerkennung_' -'Andere_' -'Altstadt_' -'Alternative_' -'Abend_' -'64_' -'10' -'Ês_' -'è' -'ähnliche_' -'Überwachung_' -'wind_' -'widespread_' -'welches_' -'verlangt_' -'urgent_' -'ures_' -'unternommen_' -'unseres_' -'uner' -'umzusetzen_' -'umgesetzt_' -'truth_' -'trouble_' -'trends_' -'tourist_' -'tische_' -'threats_' -'ständig_' -'strategischen_' -'stimulus_' -'stations_' -'sta' -'sovereignty_' -'sought_' -'solcher_' -'sieben_' -'sicherlich_' -'season_' -'scope_' -'sal' -'rt' -'roughly_' -'ries_' -'rest' -'responsibilities_' -'relaxing_' -'relax_' -'regulatory_' -'referred_' -'ref' -'reasonable_' -'proper_' -'print_' -'pon' -'pictures_' -'pi' -'phrase_' -'photo_' -'persons_' -'over' -'ou' -'organizations_' -'organisations_' -'oren_' -'ongoing_' -'oa' -'nu' -'nt' -'nge_' -'mountain_' -'mode_' -'mi' -'meetings_' -'mechanism_' -'markt_' -'manner_' -'länder_' -'lp' -'len' -'ld' -'kÃŧrzlich_' -'kt' -'kitchen_' -'ka' -'je' -'ition_' -'ish_' -'ir_' -'ig' -'ide_' -'hy' -'humanitarian_' -'ht' -'hrt_' -'honourable_' -'heutigen_' -'her' -'happens_' -'guter_' -'go' -'gezogen_' -'geschÃŧtzt_' -'ges_' -'gains_' -'frÃŧher_' -'fla' -'fighting_' -'ff_' -'fer' -'experts_' -'estate_' -'erscheint_' -'erk' -'entered_' -'ele' -'einzigen_' -'einge' -'ei' -'distributed_' -'disa' -'dinner_' -'dia' -'declared_' -'cycle_' -'crises_' -'cing_' -'cho' -'cheap_' -'buildings_' -'brings_' -'bra' -'bound_' -'block_' -'beziehen_' -'betreffen_' -'bestehenden_' -'besseren_' -'berÃŧcksichtigen_' -'berg_' -'believed_' -'bedarf_' -'ausschließlich_' -'argue_' -'ana' -'alte_' -'alliance_' -'ale' -'aims_' -'achieving_' -'a' -'Zweck_' -'YORK_' -'Worten_' -'Without_' -'Wissen_' -'West' -'Wegen_' -'Unterschied_' -'Trump_' -'Treffen_' -'Those_' -'Suche_' -'Stimmen_' -'Sol' -'She' -'Seine_' -'Sand' -'SI' -'Rome_' -'Richtlinien_' -'Rest_' -'Report_' -'Regulation_' -'Rechtsvorschriften_' -'Prä' -'Phase_' -'Pan' -'Norden_' -'Miss' -'Merkel_' -'Mas' -'Mann_' -'MA' -'Leider_' -'Kultur' -'Kriegs' -'Klimawandel_' -'Klein' -'King_' -'Karte_' -'Inhalt_' -'Hand' -'HA' -'GrÃļße_' -'Gemeinsamen_' -'Fu' -'Formen_' -'Forderungen_' -'Familien' -'Faktoren_' -'Earth_' -'EL' -'Dadurch_' -'Computer_' -'Cho' -'Chancen_' -'Car' -'Blair_' -'Binnenmarkt_' -'Bahn_' -'BO' -'BMW_' -'BA' -'Auto' -'Ausbildung_' -'Atmosphäre_' -'Armee_' -'Anwendungen_' -'Anspruch_' -'An' -'Amt_' -'Alt' -'AG_' -'AC' -'194' -'15' -'11' -' â‚Ŧ_' -' — _' -'ÃĄ' -'   . 
– _' -'zusammenge' -'zurÃŧckzu' -'zugleich_' -'ya' -'wine_' -'wei' -'ven_' -'usual_' -'tte_' -'tt_' -'trotzdem_' -'teilen_' -'systeme_' -'sugar_' -'suffered_' -'subsidies_' -'stronger_' -'statements_' -'square_' -'south_' -'sky_' -'severe_' -'selected_' -'schools_' -'sch' -'revolution_' -'reported_' -'remove_' -'remote_' -'remember_' -'records_' -'recorded_' -'recommend_' -'recognise_' -'race_' -'quarter_' -'promise_' -'pro' -'pollution_' -'perfectly_' -'payment_' -'patients_' -'pa' -'origin_' -'onen_' -'ommen_' -'offiziellen_' -'obviously_' -'obligations_' -'ni' -'mutual_' -'moved_' -'mm_' -'mittlerweile_' -'miles_' -'marked_' -'manufacturing_' -'manche_' -'love_' -'leg' -'learned_' -'lau' -'land' -'kommenden_' -'ko' -'keits' -'kation_' -'kamen_' -'ization_' -'ir' -'instance_' -'inside_' -'ingly_' -'hÃļchsten_' -'hostels_' -'hoher_' -'hn_' -'hip_' -'hau' -'gung_' -'gla' -'ged_' -'fÃŧhrende_' -'fresh_' -'frage_' -'fordert_' -'finding_' -'figure_' -'fi' -'federal_' -'fa' -'expense_' -'exists_' -'erfolgreichen_' -'equality_' -'entirely_' -'enorme_' -'electoral_' -'einfachen_' -'ee_' -'door_' -'doesn_' -'discover_' -'device_' -'destroyed_' -'den' -'definition_' -'deficits_' -'defence_' -'currencies_' -'couple_' -'convinced_' -'consultation_' -'compete_' -'combined_' -'cohesion_' -'coalition_' -'chi' -'cars_' -'carefully_' -'cards_' -'capable_' -'button_' -'budgetary_' -'bu' -'bo' -'bilateral_' -'bewegen_' -'bemÃŧht_' -'beinahe_' -'background_' -'attractive_' -'ating_' -'arrival_' -'arbeitet_' -'agency_' -'agen_' -'advance_' -'administrative_' -'ade_' -'addressing_' -'ach_' -'[[_' -'Zi' -'Zeitraum_' -'Wissenschaft_' -'Weltbank_' -'Verpflichtungen_' -'Verkehrs' -'Verbindungen_' -'Ver' -'UN' -'Tour_' -'Thank_' -'Stu' -'Standard' -'Staates_' -'Sprachen_' -'Spiel' -'Spar' -'Sha' -'Second_' -'Safety_' -'SQL_' -'Produkt' -'Pri' -'Pin' -'Patienten_' -'Ohne_' -'Niveau_' -'Mu' -'Mor' -'Monate_' -'Man' -'Madrid_' -'Kra' -'Jeder_' -'Investoren_' -'Instrumente_' -'Identität_' -'Har' -'GrÃŧnde_' -'Good_' -'Gespräche_' -'Gesetz_' -'FlÃŧchtlinge_' -'Find_' -'Februar_' -'FR_' -'Europäer_' -'Erd' -'Energie' -'En' -'Each_' -'DurchfÃŧhrung_' -'Dienste_' -'Di' -'DE' -'DA' -'Can_' -'Boden_' -'Bildungs' -'Beweis_' -'Berichterstatterin_' -'Bel' -'Before_' -'Beach_' -'Bau_' -'Aussicht_' -'Au' -'Arabia_' -'AM' -'1994_' -'192' -'.” _' -', „_' -'  _' -'ä' -'zusätzlichen_' -'ze' -'xi' -'west_' -'welcher_' -'weight_' -'wars_' -'vulnerable_' -'vous_' -'verändert_' -'vere' -'universal_' -'unit_' -'ungefähr_' -'une' -'unacceptable_' -'umfasst_' -'trend_' -'tre' -'tor' -'tin' -'testing_' -'terror_' -'telephone_' -'technischen_' -'talks_' -'sz' -'swe' -'super_' -'succeed_' -'stimmt_' -'steigen_' -'ste' -'stand' -'ssi' -'spielt_' -'sozial' -'sofort_' -'sized_' -'situations_' -'sha' -'sets_' -'session_' -'serves_' -'seemed_' -'seats_' -'scheinen_' -'sbe' -'sar' -'sage_' -'rä' -'rs' -'rn' -'rental_' -'registered_' -'rechts' -'pt_' -'productive_' -'prison_' -'precise_' -'plants_' -'passed_' -'parliaments_' -'parliament_' -'pan' -'packages_' -'out' -'ou_' -'ol' -'newly_' -'neu' -'names_' -'mÃļglichkeiten_' -'myself_' -'ms' -'mostly_' -'mo_' -'miteinander_' -'min_' -'maintaining_' -'luxury_' -'lten_' -'liquidity_' -'leisure_' -'legitimate_' -'legitimacy_' -'laut_' -'kämpfen_' -'ku' -'kor' -'kon' -'ko_' -'ip' -'investieren_' -'intellectual_' -'informiert_' -'impressive_' -'impose_' -'imports_' -'images_' -'igkeit_' -'hÃļheren_' -'houses_' -'hinweisen_' -'golf_' -'gewisse_' -'genannt_' -'gel' -'ga' -'fällt_' -'fu' -'fit_' -'falsch_' 
[diff hunk continues: several thousand further removed entries from an English/German subword-vocabulary data file — one quoted token per entry, a trailing `_` marking a word boundary, plus punctuation and numeral fragments; much of the list is mojibake-damaged in this extraction (e.g. `Ãŧ` for `ü`, `Ãļ` for `ö`, `Ê` for `é`), and the token list is omitted here]
-'gewählte_' -'gewe' -'gesunde' -'gestattet_' -'gespeichert_' -'geräumige' -'geprägt_' -'genetische' -'gend' -'gelassen_' -'gelang_' -'gefährlichen_' -'gefährdet_' -'gefa' -'gedr' -'gedacht_' -'ged' -'gation_' -'gate' -'gam' -'gabe_' -'fte_' -'frÃŧhere_' -'freundliches_' -'freely_' -'fre' -'franc' -'fought_' -'fortune_' -'forests_' -'fordere_' -'foot' -'follows_' -'folder_' -'fol' -'flug' -'fli' -'flee' -'flag' -'fine' -'findings_' -'finances_' -'fill_' -'fied_' -'fi_' -'ffic' -'fertig' -'fei' -'feelings_' -'fans_' -'fahrt' -'experiment_' -'expanding_' -'execution_' -'excluded_' -'examine_' -'ewe' -'euros_' -'euren_' -'essa' -'esc' -'erweiterten_' -'erweitert_' -'erweitern_' -'ertet_' -'erstklassige' -'ersta' -'ersetzen_' -'erneuerbare_' -'erlassen_' -'erhoben_' -'erheblichen_' -'erh' -'erforderliche_' -'eon' -'entspr' -'ents' -'entre' -'entl' -'entity_' -'entertainment_' -'ente' -'ensures_' -'ensch' -'enm' -'enjoys_' -'endgÃŧltigen_' -'endet_' -'empower' -'employed_' -'empfängt_' -'empfehlen_' -'embrace_' -'eliminated_' -'elf_' -'elegante_' -'eint' -'einschl' -'einsch' -'eingehalten_' -'einb' -'eichen_' -'ego' -'ection' -'economically_' -'eber' -'eba' -'eate' -'durchsetzen_' -'dungen_' -'duct' -'dst' -'drives_' -'drama' -'dominance_' -'diversifi' -'dist' -'diplomacy_' -'dina' -'differently_' -'df' -'designer_' -'ders_' -'deregulation_' -'denk' -'demonstrates_' -'demanded_' -'declar' -'dare' -'damaged_' -'customs_' -'cur' -'cue' -'cua' -'cred' -'countryside_' -'convince_' -'consists_' -'conservation_' -'connect_' -'confront_' -'confi' -'component_' -'competences_' -'como_' -'command' -'combine_' -'colors_' -'colo' -'codes_' -'coc' -'clock_' -'clarity_' -'ckte' -'ckl' -'cket_' -'cke' -'circle_' -'cin' -'cia_' -'chung_' -'chtlich_' -'chne' -'chau' -'chart' -'chant' -'centr' -'ced_' -'cca' -'caught_' -'cas_' -'cari' -'capt' -'capabilities_' -'bus' -'bung_' -'bubble_' -'brutal_' -'brid' -'breath' -'break' -'brain_' -'bombe' -'bod' -'boat_' -'block' -'blind' -'blick_' -'blick' -'blank' -'bisherigen_' -'bir' -'binding_' -'bin' -'bewältigen_' -'bewiesen_' -'beweist_' -'bewegung_' -'bew' -'betrieben_' -'betr' -'betont_' -'bet_' -'bestrafen_' -'beschränk' -'beschreibt_' -'berichten_' -'beri' -'berg' -'berechnet_' -'bekräftigt_' -'beispiel' -'beings_' -'behi' -'begrÃŧßt_' -'begrenzt_' -'bedroht_' -'bedenk' -'basically_' -'bags_' -'backs_' -'awa' -'avi' -'avec_' -'aux_' -'aussch' -'ausreichende_' -'ausgewogene' -'ausges' -'ausgerichtet_' -'ausfÃŧhrliche' -'aufzubauen_' -'aufget' -'auch' -'att_' -'assist_' -'ash' -'ase' -'artige' -'arrest_' -'arme_' -'argumentieren_' -'ard' -'archive_' -'archa' -'appointment_' -'anzukurbeln_' -'antin' -'anst' -'anspr' -'ansehen_' -'anlagen_' -'anischen_' -'angemessen_' -'amtierende' -'alu' -'allocated_' -'alist' -'aktivieren_' -'airports_' -'airlines_' -'ahn' -'ahl' -'agt_' -'aggregate_' -'adhere' -'adequately_' -'adapt' -'actor_' -'achtet_' -'accurate_' -'academic_' -'aca' -'abstimmen_' -'abr' -'able' -'abgeben_' -'abe' -'`_' -']], [[_' -'] ' -'Zuständigkeiten_' -'ZurÃŧck' -'Zukunfts' -'Zinssätze_' -'Zins' -'Zielen_' -'Zer' -'Yen_' -'XP_' -'Wäscheservice_' -'Wärme' -'Wirtschaftspolitik_' -'Wireless_' -'Wind_' -'Wilhelm_' -'Wie' -'Whi' -'Wes' -'Werbung_' -'Wechselkurs_' -'Websites_' -'Watch_' -'Wall' -'Wai' -'Wahr' -'Wahlkampf' -'WC_' -'Vorw' -'Vorstellungen_' -'Vorr' -'Vorhaben_' -'Vier' -'Vie' -'Vertrauens' -'Vertr' -'Verteilung_' -'Vermi' -'Verkauf_' -'Vereinigung_' -'Veranstaltung_' -'Va' -'Ursache_' -'Urlaubs' -'Unters' -'Universitäten_' -'Ungarn_' 
-'Umweltschutz_' -'UNESCO_' -'Tsa' -'Trä' -'True_' -'Trend_' -'Tourismus_' -'Tourism_' -'Torre' -'Top' -'Tod' -'Thursday_' -'Thr' -'Thai_' -'Telefon_' -'Tam' -'Tagung_' -'Tab' -'TT' -'TR' -'SÃŧdafrika_' -'Syn' -'Sy' -'Suites_' -'Such' -'Struktur' -'Streit' -'Stream' -'Stiftung_' -'Steuereinnahmen_' -'Stand' -'Stabilitäts' -'Staatsf' -'Sports_' -'Spieler' -'Smoking_' -'Situated_' -'Sicherheitspolitik_' -'Shuttle_' -'Shop_' -'Sevilla_' -'Set' -'Sei' -'Screen' -'Schre' -'Schm' -'Schiff' -'Schicksal_' -'Satelliten' -'Salo' -'Sahara_' -'STO' -'RÃŧ' -'Rub' -'Ross' -'Ros' -'Ronald_' -'Rh' -'Review_' -'Respekt_' -'Resolution_' -'Reso' -'Republican_' -'Renaissance_' -'Reisen_' -'Regen' -'Red' -'Real_' -'Radio' -'RER_' -'RC' -'RA_' -'Question_' -'Queen_' -'Quant' -'Qatar_' -'Pul' -'Prot' -'Program_' -'Professor_' -'Pop' -'Poker' -'Pod' -'Piazza_' -'Pflanzen' -'Personal' -'Pers' -'Patten_' -'Pap' -'Panorama' -'Palm' -'PS_' -'Out' -'Others_' -'Ost_' -'Orient' -'Olympic_' -'Oh' -'Og' -'OT' -'Nor' -'Nielson_' -'Nichtraucherzonen_' -'Neg' -'Natura_' -'Nachhaltigkeit_' -'Nachbar' -'Monte' -'Mond' -'Mir_' -'Mio_' -'Medizin_' -'Mechanismen_' -'Maschine' -'Mart' -'Marg' -'Marco_' -'Marc' -'Manchmal_' -'Mail' -'Magi' -'LÃŧ' -'LÃļsungs' -'Lärm' -'Los' -'List_' -'Link_' -'Light' -'Life_' -'Libya_' -'Library_' -'Les_' -'Lehrer_' -'Lebensmittel_' -'Landes' -'LO' -'KÃļrper' -'Kru' -'Kriminalität_' -'Kreuz' -'Kremlin_' -'Kosten' -'Kos' -'Kopenhagen_' -'Kont' -'Kompetenzen_' -'Kolonial' -'Kollegin_' -'Klimawandels_' -'Kin' -'Kerry_' -'Kern_' -'Kei' -'Kart' -'Kap' -'Kam' -'Kalten_' -'Kai' -'KI' -'Jur' -'Jane' -'It' -'Islamist_' -'Innerhalb_' -'Infolgedessen_' -'Industry_' -'Industrieländern_' -'Indem_' -'Inc_' -'IV_' -'ING_' -'ICA' -'Häfen_' -'Hypotheken_' -'Hunger_' -'Holiday_' -'Hohe_' -'High' -'Herzlich' -'Herz_' -'Her_' -'Heil' -'Hau' -'Hardware_' -'Handlungs' -'HD_' -'HD' -'Guan' -'GrÃļ' -'Großbritanniens_' -'Gleich' -'Gibt_' -'Gewinn_' -'Get' -'Gestatten_' -'Germans_' -'Genu' -'Geno' -'Generalsekretär_' -'Gener' -'Genau_' -'Gemeinden_' -'Geheim' -'Gegen_' -'Gat' -'Gast_' -'Gan' -'Galileo_' -'Gal' -'GM' -'Funktions' -'Fundamental_' -'FrÃŧhling_' -'FrÃŧh' -'Frank' -'Francisco_' -'Franc' -'Fraktionen_' -'Four_' -'Fortunately_' -'Flug_' -'Florida_' -'Fisch_' -'Fire' -'Finanzmärkte_' -'Few_' -'Ferien' -'Feld_' -'Feld' -'Feier' -'Farbe_' -'Fallout_' -'Fahrzeuge_' -'Fahrrad' -'FC_' -'Extrem' -'Exporte' -'Excellent_' -'Evo' -'Europol_' -'Et' -'Erz' -'Erm' -'Entwicklungshilfe_' -'Engine_' -'Enc' -'Eingabe' -'Ehe' -'Effektivität_' -'Eco' -'EU' -'ED' -'Dutch_' -'Drive' -'Drittländern_' -'Draghi_' -'Dort_' -'Donald_' -'Dollars_' -'Dokumente_' -'District_' -'Dim' -'Differenz_' -'Dienst_' -'Det' -'Demo' -'Deck' -'Deb' -'Davi' -'Danish_' -'DR' -'DL' -'DA_' -'Custom' -'Cup_' -'Culture_' -'Cru' -'Could_' -'Come_' -'Code' -'Churchill_' -'Christmas_' -'Children_' -'Chat' -'Ch' -'Castle_' -'Carlo_' -'Cand' -'CU' -'CP' -'Butt' -'Bud' -'Branchen_' -'Bolkestein_' -'Binnen' -'Beurteilung_' -'Betrag_' -'Bet' -'Berufs' -'Bereitstellung_' -'Behauptung_' -'Bea' -'Basic_' -'BarÃŗn_' -'Ban_' -'Bai' -'Badezimmer_' -'Back_' -'BM' -'Avenue_' -'Autor_' -'Autonomie_' -'Aussage_' -'Ausge' -'Ausgangspunkt_' -'Ausgaben' -'AusflÃŧge_' -'Ausf' -'Aufzug_' -'Ass' -'Arte' -'Arbeitskräfte_' -'Ara' -'Aqua' -'Anpassungen_' -'Ano' -'Ander' -'Amts' -'Alters' -'Allgemein' -'Alex' -'Agrarpolitik_' -'Adria' -'Abu_' -'Absichten_' -'Abkommens_' -'Abgesehen_' -'Abendessen_' -'Abbas_' -'= _' -'99_' -'75' -'62_' -'33' -'32' -'29' -'236' -'1960_' -'181' -'120_' -'. 
&#_' -'+' -'))._' -'( _' -'", "_' -'") _' -' {{_' -' - ' -' ,_' -'“ – _' -'א' -'Ņ€Đž' -'Đž_' -'ĐŊ' -'Đą' -'ł' -'Ãŧrlich' -'Ãŧgen_' -'Ãŧge' -'Ãŧf' -'Ãŧbrigens_' -'Ãŧbrig_' -'Ãŧberwachen_' -'Ãŧbers' -'Ãŧberr' -'Ãŧbermäßige' -'Ãļse_' -'Ãļse' -'Ãļren_' -'Ãļpf' -'Ãļkonomische_' -'Ãļhl' -'Ãļge' -'Ãļf' -'Ãą' -'ín' -'ça' -'äß' -'äume_' -'äuft_' -'äufe_' -'ässe' -'ärmsten_' -'änk' -'änden_' -'än' -'ährung_' -'ße_' -'Übertragung_' -'Übersch' -'Änderungs' -'Âģ _' -'°_' -'zÃŧgig' -'zÃļ' -'zurÃŧckzufÃŧhren_' -'zukÃŧnftige_' -'zukommen_' -'zil' -'zierungs' -'zierung_' -'zier' -'ziel' -'zial' -'zerstÃļren_' -'zerstÃļr' -'zent' -'zehnte' -'yen_' -'yard_' -'yar' -'wÃŧnscht_' -'wort' -'worker_' -'wood' -'wohn' -'wohin_' -'wn' -'witness_' -'winner_' -'wick' -'wichtiges_' -'whenever_' -'wettbewerbsfähige' -'wes' -'wert' -'werken_' -'wen_' -'weiterzu' -'weite_' -'weis' -'wed_' -'wechseln_' -'wechsel_' -'wechsel' -'wealthy_' -'weaker_' -'water' -'was' -'wart' -'warrant' -'warned_' -'wander' -'wand' -'vy_' -'vorzunehmen_' -'vorstell' -'vorliegt_' -'vorgestellt_' -'vorg' -'vore' -'vorbe' -'vorange' -'von' -'volunteers_' -'voices_' -'virtuelle' -'vin_' -'villa_' -'vid' -'verzeichnen_' -'verwandeln_' -'vertreter_' -'vertrag' -'verständlich_' -'verspricht_' -'verspr' -'verse' -'verschwende' -'verlust' -'verliert_' -'verkehr_' -'veri' -'verhältnis_' -'verhelfen_' -'verbringen_' -'verbre' -'verbrauch' -'veran' -'veil' -'vag' -'uver' -'uts_' -'utr' -'ussi' -'usa' -'uring_' -'urf' -'ural' -'upe' -'unzählige_' -'unwe' -'unwahrscheinlich_' -'untr' -'untersuchen_' -'unterste' -'unterscheidet_' -'unn' -'unmittelbare_' -'universe_' -'unin' -'ungeachtet_' -'undermine_' -'unden' -'unabhängige_' -'umgeb' -'ume_' -'ulate' -'ugh_' -'ufe' -'ues_' -'ubi' -'ub_' -'ua_' -'turno' -'tsch' -'trie' -'treasur' -'trategie_' -'transparente' -'transferred_' -'transc' -'transatlantischen_' -'trafficking_' -'tradition' -'tradable_' -'torture_' -'tolerance_' -'tle_' -'tipp' -'tight' -'tiefe_' -'tiefe' -'tied_' -'thr' -'thou' -'tho' -'thi' -'thermal_' -'territoriale' -'temperature_' -'tem' -'tei' -'technik_' -'teach_' -'tation_' -'taste' -'tape' -'tangible_' -'tan' -'sÃŧdlich_' -'systemic_' -'syn' -'sustained_' -'surre' -'surpluses_' -'surge' -'supplied_' -'suddenly_' -'substantially_' -'stände_' -'stuff_' -'studied_' -'struck_' -'strikes_' -'strictly_' -'stream_' -'straight_' -'strai' -'storm' -'stl' -'still' -'stern' -'steri' -'steam_' -'stamm' -'stal' -'staat_' -'ssystem' -'sste' -'ssl' -'ssige' -'ssel' -'sre' -'späten_' -'spät_' -'speeches_' -'specialities_' -'speaker_' -'sparen_' -'sound' -'sou' -'sophisticated_' -'songs_' -'sofa_' -'smart_' -'slight_' -'sko' -'sixt' -'situ' -'sions_' -'sierte' -'sicherstellen_' -'shocks_' -'sequence' -'senk' -'seitig' -'seemingly_' -'seas' -'script' -'schÃļnsten_' -'schnelles_' -'schnellen_' -'schne' -'schiff_' -'schafts' -'sberei' -'satisfied_' -'sat_' -'sarge' -'sani' -'sand_' -'sab' -'rÃŧckg' -'rulers_' -'ruins_' -'ruch' -'rta_' -'ror' -'rooted_' -'ron' -'romantic_' -'robe' -'rnen_' -'river_' -'rist' -'ris_' -'ript' -'rif' -'rian_' -'rhe' -'rf_' -'retro' -'rete' -'ret' -'resulted_' -'restrict_' -'restoring_' -'ression' -'resign' -'researchers_' -'repro' -'reo' -'reminded_' -'reiten_' -'reisen_' -'reine' -'reif' -'reichsten_' -'reiben_' -'reh' -'regionen_' -'refuse_' -'referring_' -'referen' -'reco' -'rechnen_' -'rechen_' -'reasonably_' -'realise_' -'readily_' -'rba' -'ration' -'ratification_' -'raph' -'rand' -'ranc' -'ran' -'raises_' -'racism_' -'rab' -'qualitative_' -'qualifications_' -'qu' -'pull_' 
-'pue' -'pto' -'provider_' -'prote' -'proportion_' -'prophe' -'promptly_' -'proliferation_' -'projekt_' -'programming_' -'profitiert_' -'professor_' -'prof' -'produktive_' -'produkte_' -'producer_' -'proceed' -'problemati' -'pricing_' -'presum' -'presse' -'presenting_' -'preferences_' -'praktischen_' -'ppen_' -'ppe_' -'pou' -'posts_' -'postponed_' -'popularity_' -'popula' -'politician_' -'politi' -'plu' -'pledged_' -'pipe' -'pig' -'pert' -'persona' -'pe' -'pay' -'path' -'patents_' -'patent' -'passes_' -'partic' -'painful_' -'overview_' -'overl' -'outs_' -'outl' -'organ' -'orfen_' -'optimale_' -'opponents_' -'ook_' -'ong' -'ommen' -'olle' -'oldest_' -'oid' -'oi' -'ohnehin_' -'ohn' -'offset_' -'officers_' -'odie' -'occupied_' -'occup' -'occ' -'oca' -'observation_' -'obliged_' -'nÃļtigen_' -'nächster_' -'nzi' -'num' -'ntion_' -'np' -'nous_' -'notable_' -'nos' -'nonetheless_' -'nom_' -'nne' -'nlage' -'nister' -'nism' -'nische_' -'nineteenth_' -'niedrige' -'nhei' -'nger' -'nenn' -'negotiating_' -'negotiate_' -'nche' -'nberg_' -'nati' -'nas' -'nai' -'nachfrage_' -'mÃŧh' -'mÃĄ' -'multinational' -'multimedia_' -'mul' -'mt' -'mouse_' -'mounting_' -'mount_' -'mortgages_' -'moralisch_' -'moderner_' -'moderate' -'modell_' -'mittelalterliche' -'minder' -'milita' -'method' -'mental' -'melt' -'mehrs' -'mehr' -'mbl' -'match' -'master' -'massage_' -'mas_' -'marke' -'mar_' -'manipulation_' -'mangel' -'man' -'magic_' -'lÃļst_' -'lÃļ' -'lys' -'lution' -'lungen_' -'lon_' -'loi' -'logy_' -'logi' -'ln' -'llte_' -'liv' -'liquid' -'linguistic_' -'lings_' -'lighting_' -'light' -'lich' -'lia' -'letter_' -'lern_' -'leis' -'leidet_' -'leichen_' -'lebenslange' -'lb_' -'lays_' -'laying_' -'langfristige' -'lan' -'labelling_' -'kurzfristigen_' -'kurzfristig_' -'kurzer_' -'kurzen_' -'ktionen_' -'ksi' -'kräfte_' -'kriege' -'kostenlose_' -'koordinierte' -'konstruktive_' -'konferenz_' -'komplexer_' -'komplexen_' -'komplexe_' -'kni' -'klar' -'kit' -'kirche_' -'kilometers_' -'kill_' -'kern_' -'junger_' -'jun' -'jahre' -'izi' -'itz' -'itor' -'ito' -'itische' -'issen_' -'isen_' -'isches_' -'isc' -'iro_' -'irk' -'iri' -'irakischen_' -'investing_' -'interessen_' -'interessante_' -'institutionellen_' -'innenpolitische' -'initiated_' -'ining_' -'infringement_' -'inen_' -'industrialis' -'inder_' -'ind' -'incite' -'inan' -'immun' -'immer' -'illegal' -'ildung_' -'iko' -'ii' -'iew' -'ierenden_' -'iere' -'ieden' -'ieben_' -'idat' -'ichtete' -'icherung' -'icate' -'ica' -'hunting_' -'hundert' -'human' -'hum' -'hro' -'hoped_' -'honor_' -'hole_' -'hol_' -'hochwertige' -'hmen_' -'hm' -'hle_' -'historically_' -'hervorgehoben_' -'herstell' -'herausstellen_' -'hend_' -'heart' -'heading_' -'hb' -'hazardous_' -'haz' -'han_' -'halb' -'hair_' -'hai' -'had' -'gur' -'grenzt_' -'gran' -'governmental_' -'globale' -'git' -'gingen_' -'gierung_' -'geäußert_' -'gezielte' -'gewählten_' -'gewinnt_' -'getÃļtet_' -'gestatten_' -'gestartet_' -'gest' -'geschw' -'geräte_' -'gers_' -'geringen_' -'geregelt_' -'gens' -'geniessen_' -'genes_' -'gemeinschaftlichen_' -'gemeinsamer_' -'gelingt_' -'gehabt_' -'geh' -'gegangen_' -'gede' -'geblieben_' -'gebieten_' -'gebi' -'gd' -'gay_' -'fÃŧhr' -'fÃŧhl' -'fällig_' -'fähig' -'fundamental' -'frag' -'fra' -'forth_' -'formi' -'formats_' -'format' -'forge_' -'forcing_' -'folgende_' -'focuses_' -'fn' -'fizier' -'fix_' -'fiskal' -'finished_' -'finanzierung_' -'fin_' -'fiel_' -'fic' -'ffer' -'ffen' -'fern_' -'feet_' -'fassen_' -'fascinating_' -'farms_' -'facilitated_' -'extending_' -'expressing_' -'exporte' -'explains_' -'explained_' 
-'expertise_' -'experiences_' -'expenditures_' -'existed_' -'exi' -'ex' -'evaluat' -'eute' -'europ' -'eto' -'ethnischen_' -'esto' -'eso' -'erwähnen_' -'erwerben_' -'erlangt_' -'erklär' -'erin' -'erhältlich_' -'erei' -'erarbeiten_' -'episode_' -'entscheidend_' -'entdecken_' -'enne' -'engagierte' -'engage_' -'eng' -'enforced_' -'enemy_' -'ends_' -'endgÃŧltig_' -'ender_' -'endanger' -'encourages_' -'ences_' -'employ' -'ella_' -'elites_' -'eits' -'eit' -'eise' -'eis' -'einladende' -'einh' -'einfl' -'einbezogen_' -'eigentlichen_' -'eibung' -'ehrgeizige' -'egt_' -'ega' -'effiziente_' -'effizient_' -'effi' -'edly_' -'edia' -'echt' -'dw' -'durchschnittlich_' -'durchg' -'dun' -'dry_' -'dru' -'drinking_' -'downtown_' -'downs' -'doubl' -'domain_' -'dog_' -'dn' -'dle' -'dition' -'distinction_' -'disorder_' -'disi' -'discretion_' -'discret' -'discovery_' -'discourse_' -'dischen_' -'disagree_' -'directed_' -'diplomats_' -'dine' -'dil' -'dignity_' -'digitalen_' -'dieselbe_' -'dge' -'detaillierte' -'desi' -'dere' -'derartiger_' -'depression_' -'depressed_' -'deposits_' -'deployment_' -'denied_' -'demnächst_' -'demi' -'delicate_' -'dei_' -'defining_' -'decent_' -'deca' -'debating_' -'deadl' -'daf' -'dac' -'crack' -'covers_' -'courage_' -'council_' -'cot' -'cost' -'coordinat' -'conventions_' -'contributing_' -'contacts_' -'consent_' -'confusion_' -'confirmation_' -'configuration_' -'condemn_' -'comput' -'compre' -'compar' -'communism_' -'commande' -'collaborat' -'codi' -'coastal_' -'closest_' -'clon' -'clo' -'cler' -'class' -'clarify_' -'citizenship_' -'cit' -'ciones_' -'cier' -'cial_' -'chtigen_' -'chose_' -'choosing_' -'chle' -'cherung_' -'chee' -'checks_' -'cheaper_' -'charit' -'characters_' -'characteristic_' -'chances_' -'cere' -'cepti' -'centrally_' -'centrali' -'cent' -'catalog' -'cali' -'bull' -'buch' -'boy_' -'boundaries_' -'borrowing_' -'bon_' -'boasts_' -'blow' -'blog_' -'blase_' -'blame_' -'birds_' -'biodiversity_' -'bilateralen_' -'bha' -'beunruhigend' -'beträchtliche' -'betriebe_' -'betrieb' -'beteiligen_' -'bern_' -'berichtet_' -'bereiche_' -'berechtigt_' -'bem' -'bele' -'behÃļrden_' -'behindert_' -'begleitet_' -'befÃŧrchten_' -'befindliche' -'bedÃŧrf' -'bedienen_' -'beantwortet_' -'baut_' -'bath' -'basket' -'base' -'ax_' -'außen' -'auszusch' -'auszur' -'ausste' -'ausreichen_' -'ausr' -'auslÃļsen_' -'ausgelÃļst_' -'ausgeh' -'aufweist_' -'aufst' -'aufrechterhalten_' -'aufn' -'auffordern_' -'auff' -'aufen_' -'auber' -'attending_' -'attached_' -'atr' -'ato' -'assuming_' -'assigned_' -'assen' -'aspirations_' -'asch' -'arises_' -'arge' -'aren_' -'arche' -'appl' -'ao' -'antiqu' -'anno' -'annex' -'anne' -'anlage_' -'anim' -'anh' -'angry_' -'anges' -'angef' -'angeblich_' -'angeb' -'anerkennen_' -'anci' -'ame_' -'ame' -'amb' -'amazing_' -'alls' -'alle' -'alität_' -'alg' -'alarming_' -'alarm' -'agte' -'agi' -'ager' -'agent_' -'age' -'advice_' -'adapted_' -'activists_' -'ace_' -'accommodate_' -'accidents_' -'acce' -'abst' -'abschl' -'abl' -'abi' -'abgeb' -'abandoning_' -'aan' -'aa' -'Zwi' -'Zw' -'Zunahme_' -'Zugeständnisse_' -'Zeitplan_' -'Zeichen' -'Zauber' -'Yugoslavia_' -'Youth_' -'Ye' -'Xi' -'WÃŧr' -'Wälder_' -'Wirtschaftsp' -'Wette' -'Wettbewerbs_' -'Wert' -'Werbe' -'Weltraum' -'Wein_' -'Warcraft_' -'WA' -'Vorgehensweise_' -'Vorbild_' -'Volksgesundheit_' -'Vladimir_' -'Vista_' -'Viertens_' -'VerÃļffentlichung_' -'Versicherungs' -'Verordnungen_' -'VermÃļgen_' -'Verm' -'Verluste_' -'Verkehrsmittel_' -'Verh' -'Veranstaltungen_' -'VP_' -'Ursprung_' -'Unterhaltungs' -'Unrecht_' -'Unm' -'Unf' -'Unbe' 
-'Una' -'Umge' -'Ukrainian_' -'UT' -'UE' -'UB' -'TÃŧr_' -'Tun' -'Tour' -'Tot' -'Top_' -'Til' -'Through_' -'Thomas_' -'Theater_' -'Terror' -'Tempo_' -'Tel_' -'Tax' -'Tan' -'TU' -'TPP_' -'SÃŧd_' -'SÃŧ' -'Systemen_' -'Swa' -'Supreme_' -'Sul' -'Strom_' -'Sto' -'Stimmung_' -'Still_' -'Steuer_' -'Ster' -'Statistik_' -'Starfleet_' -'Spur' -'Spri' -'Speisen_' -'Spare' -'Spani' -'Sozialisten_' -'Sor' -'Solutions_' -'Solar' -'Sogar_' -'Soft' -'Socialist_' -'Sk' -'Sit' -'Show' -'Shia_' -'Several_' -'Ser' -'Select_' -'Sco' -'SchÃļ' -'Schwellenländern_' -'Schweiz_' -'Schwarz' -'Schulz_' -'Schu' -'Schlafzimmer_' -'Sai' -'Saharan_' -'Sachen_' -'ST' -'SPA' -'SH' -'SER' -'Russen_' -'Russ' -'Rule_' -'Rot' -'Rhein' -'Revision_' -'Ressourcen' -'Reservierung_' -'Rental_' -'Regi' -'Rede' -'Rechtss' -'Reali' -'Ratspräsident_' -'Rap' -'Rande_' -'Railway_' -'RG' -'Quadra' -'Prozesse_' -'Provinz_' -'Propaganda_' -'Projekts_' -'Prognosen_' -'Produktivitäts' -'Privatisierung_' -'Pres' -'Praktiken_' -'Pot' -'Posten_' -'Port_' -'Political_' -'Play' -'Pilot' -'Pflicht_' -'Pfa' -'Petr' -'Petersburg_' -'Peking_' -'Patt' -'Patent' -'Parti' -'Partei' -'Parameter_' -'Pakt_' -'Pack_' -'PR' -'PM' -'Over' -'Orl' -'Opti' -'Oppositions' -'Operation_' -'On' -'OD' -'OC' -'Null' -'Normal' -'Nixon_' -'Nikon_' -'New' -'Never_' -'Netanyahu_' -'Nelson_' -'Nehmen_' -'Near_' -'Nationalen_' -'Nar' -'Nahrungsmittel' -'NICHT_' -'MÃŧnchen_' -'Mängel_' -'Mä' -'Mut_' -'Mut' -'Multi_' -'Mozilla_' -'Mot' -'Modus_' -'Mode_' -'Mitarbeitern_' -'Mir' -'Mini_' -'Minderheit_' -'Mil' -'Mid' -'Metro' -'Messen' -'Merkmale_' -'Mengen_' -'Maßstab_' -'Max' -'Marx_' -'Manage' -'Main' -'Mach' -'MH' -'LÃļhne_' -'Länge_' -'Luxus_' -'Luca' -'Long_' -'Log' -'Leone_' -'Leicht' -'Legal_' -'Lee_' -'Leave_' -'Laun' -'Lands' -'Lam' -'Lah' -'Labor_' -'KÃļ' -'Ky' -'Kuwait_' -'Kräften_' -'Kreditkarte_' -'Kop' -'Konzept' -'Kontin' -'Kontext_' -'Kontakt' -'Konsumenten_' -'Konferenz' -'Kommissionspräsident_' -'Kommissions' -'Kohlen' -'Klassen' -'Key_' -'Kenya_' -'Kenn' -'Kel' -'Katastrophen_' -'Katalog_' -'Karriere_' -'Karls' -'Kanada_' -'Kamera_' -'Kali' -'Kaffee' -'Kab' -'KON' -'Jugendlichen_' -'Jud' -'Juan_' -'Jewish_' -'Jan_' -'Isa' -'Irak' -'Interventionen_' -'Internetzugang_' -'Intera' -'Installer_' -'Ink' -'Inhalte_' -'Imp' -'Immo' -'Images_' -'Il_' -'Ideal' -'Ibiza_' -'IDE' -'Hu_' -'Hot' -'Host' -'Holland_' -'Holl' -'Him' -'Herrsch' -'Herb' -'Heizung_' -'Hauptver' -'Hat' -'Hass_' -'Hart' -'Harmonisierung_' -'Harbour_' -'Halbinsel_' -'Haben_' -'HS' -'GÃŧter' -'Gäste' -'Grundlagen_' -'Griff_' -'Graf' -'Grad' -'Governance_' -'Gla' -'Gipfeltreffen_' -'Gewährleistung_' -'Gesundheitswesen_' -'Gestaltung_' -'Gesetzes' -'Geschäft_' -'Geräte' -'Gepäckraum_' -'Gem' -'Gegenstände_' -'Gefängnis_' -'Gay_' -'Gau' -'Garantien_' -'G8_' -'Fä' -'Friday_' -'Fri' -'Freund' -'Fox_' -'Fol' -'Fluss_' -'Florence_' -'Five_' -'Finnland_' -'Finanzdienstleistungen_' -'Filme_' -'File_' -'Festplatte_' -'Fertig' -'Feinde_' -'Fatah_' -'Farb' -'Fam' -'Fall' -'Expert' -'Exp' -'Everything_' -'Etwa' -'Esta' -'Essen_' -'Esc' -'Ero' -'Erg' -'Erfolgs' -'Erf' -'Erdoğan_' -'Entwicklungsp' -'Enth' -'End_' -'Elevator_' -'Einzig' -'Einigkeit_' -'Ef' -'ENT_' -'Dy' -'Dun' -'Dru' -'Drogen' -'Dos' -'Dona' -'Dom' -'Display_' -'Disp' -'Dingen_' -'Differenzen_' -'Deutsch_' -'Desktop_' -'Der' -'Dep' -'Demokratischen_' -'Dem' -'Deep_' -'Debatten_' -'Dea' -'Danach_' -'DAT' -'Cur' -'Cri' -'Crespo_' -'Create_' -'Court' -'Cost' -'Corp' -'Consequently_' -'Computer' -'Compa' -'Chief_' -'Chaos_' -'Chancengleichheit_' 
-'Chancellor_' -'Cau' -'Carr' -'Cardassian' -'Capital_' -'Cala_' -'Cai' -'COM_' -'CM' -'CEO_' -'Bushs_' -'Brun' -'Boris_' -'Borg_' -'Bord_' -'Blue' -'Bis' -'Bezeichnung_' -'Betrug_' -'Betroffenen_' -'Betreiber' -'Beteiligten_' -'Besteuerung_' -'Besonders_' -'Besondere' -'Besatzungs' -'Bernanke_' -'Berlusconi_' -'Berliner_' -'Bele' -'Bek' -'Bei' -'BefÃŧrchtungen_' -'BefÃļrderung_' -'Befugnisse_' -'Bee' -'Bedauerlicherweise_' -'Bay' -'Bara' -'Balkon_' -'BU' -'Außen_' -'Ausweg_' -'Ausarbeitung_' -'August' -'Augen' -'Aufenthalts' -'Asiens_' -'Asi' -'Arti' -'Arc' -'Arbeitszeit' -'Arbeitsmarkt' -'Anne' -'Anleger_' -'Ank' -'Angel' -'Angebots' -'Ange' -'Anfragen_' -'Andre' -'Andererseits_' -'Among_' -'Ami' -'Alpen_' -'Allow_' -'Alexander_' -'Against_' -'Act' -'Abschaffung_' -'Abre' -'Abr' -'Able' -'Abfall' -'AVI_' -'AM_' -'AF' -'ACP_' -'== _' -'="_' -': '_' -'88_' -'78' -'69_' -'65' -'57_' -'53' -'51' -'46_' -'44' -'41' -'360_' -'2016_' -'1st_' -'1930_' -'176' -'150' -'0er_' -'.&_' -'. “_' -'. )' -'- (_' -', '_' -'!!_' -' Âģ_' -' -' -' * _' -' &_' -'“-_' -'Ņ‹' -'ĐŊĐž' -'ĐŊи' -'Đš_' -'Îą_' -'Îą' -'ÅĄe' -'č' -'Ãŧltig' -'Ãŧl' -'Ãŧcke' -'Ãŧberzeugend_' -'Ãŧbernahm_' -'Ãŧberg' -'Ãŧberf' -'ø' -'Ãļrtlichen_' -'Ãļrt_' -'Ãļkonomi' -'Ãļhn' -'Êta' -'ça_' -'ÃĨ_' -'ÃĨ' -'äußerte_' -'äst' -'ängen_' -'änderung_' -'ältig' -'ähig' -'ägyptischen_' -'ÃĄr' -'ßte' -'Überwachungs' -'Überraschung_' -'Überf' -'Ära_' -'Änderungsanträgen_' -'¡' -'ÂŽ' -'ÂĢ _' -'}} ' -'zÃŧ' -'zy_' -'zweifel' -'zwanzig_' -'zuwe' -'zuversichtlich_' -'zutr' -'zuse' -'zusammenzuarbeiten_' -'zurÃŧckz' -'zum' -'zon_' -'zogen_' -'zn' -'zip' -'zige' -'zeitige' -'zeig' -'zeichne' -'zauber' -'yu' -'ystems_' -'yellow_' -'yacht' -'xy' -'xe_' -'wÃŧrdigen_' -'wÃŧnschenswert_' -'wre' -'wr' -'word' -'wm' -'wit' -'wise_' -'wines_' -'windi' -'wild_' -'wi_' -'westlich_' -'wenigstens_' -'weltweite' -'wellness_' -'welcomed_' -'wel' -'weiß' -'weiterge' -'weise' -'warme' -'ware' -'walk' -'wahre_' -'wag' -'wachstum_' -'wach' -'vs' -'vorzulegen_' -'vornehmen_' -'vorne' -'vorhandenen_' -'volks' -'visitor_' -'visa_' -'violations_' -'view' -'vic' -'vi_' -'verwiesen_' -'verweisen_' -'verwaltung_' -'verwa' -'veru' -'verträg' -'vertrag_' -'versus_' -'versteht_' -'vernÃŧnftige' -'verlorene' -'verifi' -'verheerende_' -'verhandeln_' -'vergeben_' -'variations_' -'vacation_' -'utiliz' -'user' -'use' -'usage_' -'ursprÃŧngliche_' -'urlaub_' -'urge' -'uranium_' -'ups_' -'upl' -'upcoming_' -'unÃŧber' -'unweit_' -'unterwegs_' -'unterstr' -'unterg' -'untereinander_' -'unterbreitet_' -'unsu' -'unres' -'uno' -'unkon' -'units_' -'unilateral_' -'unification_' -'unhe' -'ungsz' -'ungsm' -'ungerecht' -'unfa' -'undurch' -'understandable_' -'undermined_' -'unclear_' -'unb' -'unanimously_' -'ulier' -'ukt' -'ui' -'ugu' -'uern_' -'uct' -'uchs' -'tÃļdliche' -'täte' -'twi' -'turmoil_' -'turbulent' -'tungs' -'ttu' -'tto' -'tsunami_' -'truck_' -'tric' -'trend' -'treat' -'transmission_' -'trans' -'traf_' -'tow' -'tou' -'tos_' -'tos' -'tolle' -'todo' -'tn' -'tlichen_' -'tionally_' -'tiny_' -'timely_' -'till_' -'tili' -'tia' -'ths_' -'thinks_' -'therapy_' -'ther' -'theoretical_' -'theoreti' -'thal' -'teuer' -'tete_' -'terminal_' -'tens' -'tene' -'tempt' -'tem_' -'tellen_' -'telecommunications_' -'tee' -'technische' -'teacher_' -'tasty_' -'tain' -'tab' -'sätze_' -'szen' -'symbolic_' -'sym' -'switch' -'survey_' -'sure' -'supranational_' -'suppress' -'suppliers_' -'suit_' -'suggestion_' -'substitute' -'subsid' -'subscribe' -'stÃŧtzen_' -'stärksten_' -'ständ' -'strukturen_' -'struggling_' -'strongest_' 
-'strip' -'string' -'strecke' -'streben_' -'strange_' -'straightforward_' -'strafrechtlich' -'stood_' -'stoffen_' -'stitut' -'sting_' -'stiegen_' -'sters_' -'stell' -'stays_' -'stat' -'startet_' -'starten_' -'stammt_' -'stabil' -'sst' -'ssen' -'spielte_' -'sphere_' -'sper' -'specify_' -'speaks_' -'spart' -'sow' -'sorte' -'sonn' -'sogenannten_' -'socio' -'slave' -'skri' -'sinken_' -'sin_' -'sichtbar_' -'sic' -'shifts_' -'shelter' -'shar' -'shame' -'sgr' -'sges' -'sge' -'sex_' -'sex' -'sevent' -'setzte_' -'settlements_' -'settle' -'sett' -'servers_' -'separate' -'semi_' -'seitdem_' -'sehe' -'seh' -'sees_' -'securities_' -'secondly_' -'schwer' -'schwedischen_' -'schung_' -'schritt' -'schrei' -'schnellere' -'schli' -'schicken_' -'scheitern_' -'sche' -'schalte' -'scen' -'scan' -'satisf' -'san_' -'salt_' -'sala' -'sagten_' -'safely_' -'safe' -'räume_' -'rät' -'räge_' -'räfte' -'rze' -'rve' -'ruin' -'rui' -'ruhe_' -'rseits_' -'roof_' -'roc' -'rne' -'rische_' -'rine' -'rim' -'rige_' -'richtet_' -'richte' -'revealed_' -'retten_' -'retr' -'retire' -'reti' -'resse' -'responding_' -'residents_' -'residen' -'requests_' -'reproduc' -'rep' -'renovated_' -'rem_' -'rela' -'reicher_' -'rei_' -'regierung_' -'regier' -'regelungen_' -'refuge_' -'redu' -'rech' -'realised_' -'reader_' -'raus' -'ration_' -'ratified_' -'rap' -'ranks_' -'raid' -'raft_' -'radical' -'radiation_' -'ract' -'quest_' -'quellen_' -'qualit' -'qualify_' -'py_' -'purely_' -'purchasing_' -'pse' -'präsent' -'proven_' -'proto' -'protest_' -'promot' -'prom' -'prohibited_' -'programmen_' -'profit' -'privatization_' -'privatis' -'prisons_' -'presentation_' -'premium_' -'prejudice_' -'preferred_' -'preference_' -'prefer' -'predictable_' -'precious_' -'precarious_' -'prach' -'ppt' -'ppo' -'pper_' -'ppe' -'possession_' -'positiven_' -'poorer_' -'polnischen_' -'pole' -'plä' -'plung' -'plug_' -'plo' -'platforms_' -'pilot_' -'pil' -'pielen_' -'pfe' -'pf_' -'pet_' -'periphery_' -'penalty_' -'pei' -'ped' -'pause_' -'patient_' -'paths_' -'past' -'pas_' -'partitions_' -'participa' -'partei' -'panels_' -'panel_' -'pane' -'palace_' -'overw' -'overt' -'overseas_' -'outs' -'outline_' -'outlets_' -'oto' -'oti' -'osten_' -'osc' -'osa' -'orts_' -'orte' -'orm_' -'orientierte' -'ordnete' -'opposing_' -'operator_' -'ood_' -'onis' -'one' -'onds_' -'olt' -'oi_' -'offene' -'offenbar_' -'ody_' -'odu' -'od' -'och' -'obsess' -'objekt' -'obi' -'obe' -'oba' -'nÃļrdlichen_' -'nut' -'nsb' -'nou' -'nominal_' -'nnt_' -'nna_' -'nj' -'nity_' -'nissen_' -'nischen_' -'nir' -'nightlife_' -'nier' -'niederländischen_' -'nh' -'newe' -'new' -'netz' -'nent_' -'nell' -'negotiated_' -'nee' -'necessity_' -'nds_' -'ndr' -'ndelte' -'nav' -'nau' -'natÃŧrliche_' -'nativ' -'national' -'nada_' -'mÃļge_' -'mäßig_' -'myth' -'muslimischen_' -'multip' -'multilateralen_' -'mption' -'mpli' -'mpe' -'movies_' -'mov' -'mouth_' -'mood_' -'momentum_' -'modules_' -'modes_' -'modernisier' -'mobili' -'mmer_' -'mitzu' -'mittlere' -'mitteln_' -'mist' -'missile_' -'mir' -'ming' -'militärischer_' -'militia' -'meti' -'messe' -'mern_' -'merge' -'ment' -'menschen' -'meister' -'medizinische_' -'md' -'mature_' -'materiali' -'master_' -'massa' -'maschine_' -'mart' -'mals_' -'male' -'magazine_' -'mac' -'lär' -'lv' -'lusion' -'lus' -'lossen' -'losigkeit_' -'logo_' -'logisti' -'locker' -'locked_' -'lla_' -'lk' -'lively_' -'link' -'liner_' -'lift_' -'liederung_' -'lieben_' -'lge' -'lg' -'lex' -'lette_' -'lengthy_' -'lend_' -'lend' -'leiten_' -'leit' -'legale' -'lear' -'laute' -'lauf' -'late' -'las' -'lant' -'langsamer' 
-'landing_' -'kÃŧmmern_' -'kä' -'kund' -'kultur_' -'kte_' -'kte' -'kritische_' -'kritisch' -'kreis' -'kratisch' -'kraten_' -'kopieren_' -'konzept' -'kontrolliert_' -'kontrollen_' -'konk' -'konf' -'komplette_' -'komplette' -'kombiniert_' -'kollektive' -'kno' -'klassischen_' -'klassische' -'klarer_' -'kka' -'kit_' -'keys_' -'keen_' -'kamp' -'kale' -'jÃŧdischen_' -'junta_' -'jug' -'judges_' -'iz' -'ives_' -'itt' -'its' -'istischen_' -'istische_' -'istan' -'iss' -'isolated_' -'ism' -'isieren_' -'ish' -'isation_' -'isa' -'irr' -'ironi' -'irischen_' -'irgend' -'ires_' -'iranische_' -'ip_' -'iona' -'invi' -'intervene_' -'internationales_' -'interfere' -'interessanten_' -'interd' -'inter_' -'integrati' -'integral' -'inta' -'insur' -'institute_' -'installer_' -'installations_' -'inspired_' -'ino' -'inner_' -'inis' -'ini_' -'inhalt' -'industriellen_' -'industri' -'induce' -'indicati' -'ind_' -'incur' -'incredibly_' -'inb' -'inad' -'imen' -'imate' -'imagin' -'ille' -'illa' -'ilen_' -'igungs' -'igli' -'igen' -'iga' -'ift' -'ifft_' -'ient_' -'ieg_' -'iefer' -'ieb_' -'idig' -'ider_' -'iden_' -'icklung_' -'ick' -'ichtung' -'ibe' -'iPod_' -'hÃļchst_' -'hôtel_' -'häl' -'hypothe' -'hybrid_' -'hunger_' -'hundert_' -'hub' -'hs' -'hren' -'hos' -'horn' -'hop_' -'hom' -'hohes_' -'hochwertige_' -'hma' -'hir' -'hinzugefÃŧgt_' -'hinweg_' -'hinnehmen_' -'hinge' -'hike' -'high' -'hersteller_' -'herstellen_' -'hers_' -'heroi' -'herausge' -'hera' -'heal' -'hause' -'harsh_' -'harm' -'hard' -'hara' -'handled_' -'halben_' -'halb_' -'haf' -'gw' -'gua' -'grÃŧnde' -'grundlegende' -'gru' -'großartig' -'grouping' -'grosse_' -'graph' -'graf' -'graduate' -'gradual_' -'grab' -'god' -'gnen_' -'gk' -'gien' -'gewährt_' -'gewä' -'gewidmet_' -'gewi' -'gewalt' -'gewachsen_' -'getr' -'geteilten_' -'get' -'gesunken_' -'gesto' -'geschr' -'geschlossene' -'geschicht' -'geru' -'gerechten_' -'gepflegt' -'generier' -'gende' -'gema' -'geltend' -'geist' -'gehei' -'geg' -'gefÃŧhrten_' -'geeigneten_' -'geeignete_' -'gather_' -'gastro' -'garage_' -'ganze' -'gangs' -'fÃŧrchte' -'fÃŧnf' -'fÃŧ' -'fäl' -'fut' -'funktion' -'fungier' -'fueled_' -'frÃŧh_' -'fru' -'friend_' -'fried' -'freilich_' -'freige' -'frame_' -'fragte' -'founder_' -'foster_' -'fortgesetzte' -'fortge' -'formuliert_' -'forme' -'forecast_' -'forder' -'fon' -'fläche_' -'flow' -'flaw' -'flags_' -'fits_' -'finanziell_' -'fin' -'fifth_' -'fie' -'festzulegen_' -'festlegen_' -'festgelegten_' -'featured_' -'favourite_' -'favour' -'fault_' -'fashion' -'fal' -'fait' -'faire_' -'eßen_' -'extremen_' -'exposure_' -'expo' -'experimental_' -'expenses_' -'expect' -'expan' -'exhibition_' -'executed_' -'execut' -'exceptions_' -'exceeds_' -'exceed' -'everyday_' -'europäisches_' -'eure_' -'eur' -'etzung_' -'etung' -'etablieren_' -'essi' -'erz' -'erv' -'erta' -'errors_' -'errichten_' -'ernsthaften_' -'erneuer' -'ermÃļglichte_' -'erische' -'eric' -'erhielten_' -'ergänzen_' -'erfu' -'erfreu' -'erfasst_' -'erbaut_' -'enve' -'entzieh' -'entra' -'ently_' -'entitlement' -'entgegenge' -'entertain' -'entdeckte' -'entails_' -'enjoying_' -'engi' -'enge' -'endi' -'encompass' -'empty_' -'employers_' -'emphasis' -'emm' -'emergence_' -'ement' -'embedded_' -'embark' -'eliminating_' -'eligible_' -'eins' -'eingeschränkt_' -'eingeb' -'eind' -'eien_' -'eichne' -'eho' -'ehe' -'egi' -'egal_' -'editorial_' -'ede' -'economist_' -'ecken_' -'ebr' -'eben' -'durchgefÃŧhrten_' -'duration_' -'dunkle' -'droht_' -'dringlich' -'drew_' -'dramatisch_' -'drafted_' -'doubts_' -'dorthin_' -'dore' -'donor_' -'donat' -'dom' -'dle_' -'diversi' 
-'diversen_' -'districts_' -'dish' -'discharge_' -'direkt' -'dioxide_' -'dimensional' -'dim' -'dig_' -'differ_' -'dienst_' -'dich_' -'diagnose' -'diag' -'dha' -'dez' -'devoted_' -'devi' -'deutsch' -'deuten_' -'deterren' -'det' -'destabilizing_' -'desired_' -'deserve_' -'ders' -'derjenigen_' -'derived_' -'deprived_' -'dependen' -'departments_' -'depart' -'dense' -'demokratisch_' -'demographic_' -'delle_' -'degradation_' -'deco' -'dauern_' -'dasselbe_' -'dangers_' -'dag' -'cyclical_' -'cyber_' -'curs' -'cula' -'creativity_' -'crat' -'count' -'coun' -'correspond_' -'correction' -'cord' -'coole' -'cook' -'converted_' -'convert' -'conversation_' -'conv' -'controversial_' -'constructed_' -'construct_' -'constrain' -'consolidation_' -'cono' -'confront' -'conform' -'conf_' -'compulsory_' -'comprises_' -'complexity_' -'complement' -'competitors_' -'competing_' -'compensate' -'compatibility_' -'communicate_' -'commission' -'collected_' -'coherence_' -'clothing_' -'clip' -'clinical_' -'clicking_' -'classifi' -'classical_' -'clarification_' -'ckung_' -'civilians_' -'cit_' -'circum' -'cia' -'chts' -'chsel' -'chs' -'chme' -'ching_' -'child' -'chi_' -'ched_' -'checke' -'champions' -'challenging_' -'chairman_' -'chair_' -'certainty_' -'cease' -'casino_' -'capacities_' -'cap' -'cano' -'cancelled_' -'cam' -'bust' -'bury_' -'brother_' -'brit' -'breaking_' -'brand' -'branche_' -'booked_' -'boo' -'bomb' -'bold_' -'boa' -'bloß_' -'bli' -'bles_' -'bl' -'bit' -'bind' -'bien_' -'bia' -'bewert' -'bewer' -'besuchte_' -'bestände_' -'beschloss' -'beschleunigt_' -'bes' -'beru' -'berge' -'bereichen_' -'bera' -'benutzer' -'benachrichtigt_' -'benachbarten_' -'bemerkt_' -'bemerk' -'beliebige' -'belegt_' -'belaufen_' -'belast' -'bekenn' -'beiträgt_' -'beit' -'beinhalten_' -'beginne' -'begeben_' -'begab_' -'befriedigen_' -'bedeutsame' -'bedanken_' -'beat' -'bearing_' -'beanspruch' -'bb' -'battery_' -'batt' -'barr' -'barely_' -'balances_' -'bailout_' -'bab' -'aß' -'azi' -'az_' -'avoid' -'ava' -'autonom' -'automatic_' -'authorise' -'aut' -'auswärtige_' -'ausreicht_' -'ausgenutzt_' -'ausgegeben_' -'ausgedrÃŧckt_' -'ausgedehnt' -'auseinander_' -'aum_' -'aufzus' -'aufgew' -'aufgestellt_' -'aufges' -'aufgeb' -'audit' -'atz' -'attribute' -'attraction_' -'attend_' -'atl' -'atis' -'ation' -'aste' -'assurance_' -'asse_' -'arts_' -'arsenal' -'arrive_' -'aro_' -'argu' -'arena_' -'archi' -'appro' -'appreciation_' -'appreciated_' -'appoint' -'apologi' -'ap_' -'ap' -'anzugehen_' -'anw' -'antrag' -'anticipated_' -'anspruchsvolle' -'anschließen_' -'anschl' -'anis' -'anhaltende_' -'angs_' -'angewandt_' -'angetrieben_' -'anger_' -'angepasst_' -'angeme' -'angelegt_' -'angefangen_' -'ane_' -'ander' -'anbelangt_' -'analyze' -'analyst' -'amended_' -'ambitions_' -'allzu_' -'allgemeinem_' -'allge' -'alleg' -'alla' -'alized_' -'aliz' -'ality_' -'alisierung_' -'aligned_' -'align' -'albeit_' -'ala_' -'aket' -'airline_' -'ahmen_' -'ahm' -'ahl_' -'agn' -'agend' -'afraid_' -'advent' -'ads_' -'administrati' -'adidas_' -'addresses_' -'actively_' -'activ' -'aco' -'accus' -'accomplish' -'accident_' -'accelerate_' -'academi' -'ac' -'abzulehnen_' -'abhängen_' -'abgestimmt_' -'abgehalten_' -'abg' -'abandon_' -']] | _' -'Zweig' -'Zustellbetten_' -'Zust' -'Zuge' -'Zuerst_' -'Zimmerbeschreibung_' -'Zh' -'Zertifikat' -'Zahlungen_' -'Zah' -'Young_' -'You' -'Yan' -'Ya' -'Xin' -'XML_' -'Would_' -'Wort' -'Wind' -'Wiederbelebung_' -'Widerspruch_' -'Wesen_' -'Werden_' -'Wende' -'Wellness' -'Weiteren_' -'Wars' -'Warnung_' -'WAR' -'Voyager_' -'Vorstands' -'Volksabstimmung' 
-'Volcker_' -'Vogel' -'Vir' -'Vin' -'Vil' -'Vid' -'Vet' -'Vertretung_' -'Vertrages_' -'Vert' -'Versi' -'Versammlung_' -'Versa' -'VergnÃŧgen_' -'Verfassungsvertrag' -'Vereinfachung_' -'Verbr' -'Ven' -'Van' -'Urteil_' -'Unterzeichnung_' -'Unterscheidung_' -'Unsere' -'Unruhen_' -'Unless_' -'Umfragen_' -'Umb' -'Tyr' -'Turb' -'Tuni' -'Tsi' -'Treasury_' -'Transit' -'Transfer' -'TragÃļdie_' -'Tower_' -'Touristen_' -'Tools_' -'Tisch_' -'Tickets_' -'Tibet' -'Thro' -'Think' -'Ther' -'Theorie_' -'Textil' -'Terroran' -'Tempora' -'Temp' -'Tap' -'Tao' -'Tag' -'Tabelle_' -'Syri' -'Swoboda_' -'Swi' -'Summen_' -'Suche' -'Subject_' -'Stärken_' -'Studium_' -'Studien_' -'Storage_' -'Stockholm_' -'Steuerung_' -'Sternen' -'Stau' -'Station' -'Standpunkte' -'Stammes' -'Stamm' -'Staatsschulden_' -'Spre' -'Spitzen' -'Spezie' -'Speci' -'Space_' -'Sonne_' -'Solarium_' -'Sof' -'Socialists_' -'Slo' -'Ska' -'Six_' -'Sisko_' -'Single_' -'Simbabwe_' -'Siege' -'Sicherheitss' -'Sicherheitsbe' -'Shops_' -'Sharon_' -'Server' -'Series_' -'Sekunde_' -'Scott_' -'SchÃļne' -'Schuld_' -'SchrÃļder_' -'Schlussfolgerung_' -'Schluss' -'Schlu' -'Schli' -'Schlaf' -'Schiffs' -'Schengen_' -'Scar' -'Saudis_' -'Saturday_' -'SanDisk_' -'Sammlung_' -'Salzburg_' -'SR' -'SCH' -'Run' -'Rock' -'Rid' -'Richter_' -'Rica_' -'Repr' -'Repo' -'Renn' -'Rena' -'Rel' -'Reisende' -'Regions_' -'Reduzierung_' -'Reduc' -'Redner_' -'Rebellen_' -'Ratifizierung_' -'Rat' -'Rand' -'Rahmens_' -'Raf' -'RU' -'RS_' -'ROM_' -'Quick' -'Quest' -'Pä' -'Putins_' -'Pun' -'Proteste_' -'Protection_' -'Pros' -'Prophet' -'Programmen_' -'Profit' -'Profil' -'Prize_' -'Power' -'Pou' -'Por' -'Polizei' -'Plu' -'Platz' -'Pir' -'Pie' -'Picard_' -'Philippines_' -'Philip' -'Phil' -'Pharma' -'Pent' -'Pel' -'Peer_' -'Patri' -'Patch' -'Partnern_' -'Part' -'Parliamentary_' -'Parlaments' -'Papier' -'Papa' -'Panel_' -'Palacio_' -'PXI_' -'PDF_' -'Ozean' -'Orts' -'Orte_' -'Original_' -'Organis' -'Organ' -'Ordner_' -'Opfern_' -'Ombudsman_' -'Olympus_' -'Offen' -'Oberha' -'OU' -'OPEC_' -'OP' -'ODER_' -'Nutz' -'Nue' -'Nuclear_' -'Nob' -'Niedergang_' -'Neue_' -'Nazis' -'Namibia_' -'Nahrungsmittel_' -'Nag' -'NT_' -'ND' -'Mär' -'Muss' -'Music_' -'Musharraf_' -'Mund' -'Mugabe_' -'Much_' -'Motto_' -'Modul_' -'Modernisierung_' -'Modelle_' -'Mittela' -'Mittag' -'Mitgliedsländer_' -'Mitentscheidung' -'Minister' -'Militär_' -'Migranten_' -'Miet' -'Michel_' -'Michel' -'Meth' -'Messe_' -'Menschenrechtsverletzungen_' -'Menschenrechten_' -'Mem' -'McKinsey_' -'Mazedonien_' -'Maus' -'Materialien_' -'Marke_' -'Map' -'Mand' -'ME_' -'MED' -'MDGs_' -'MD' -'MAR' -'MAN' -'Luxembourg_' -'Lula_' -'Love' -'Louis_' -'Lot' -'Lobby_' -'Liu_' -'Lite' -'Linien_' -'Libyen_' -'Lesung_' -'Leo' -'Leiter_' -'Leid_' -'Lehman_' -'Legi' -'Lea' -'Large_' -'Language_' -'Landschaft_' -'Lama_' -'Lager' -'LA_' -'KÃŧn' -'Kraftstoff' -'Komplex' -'Kollege_' -'Kohlendioxid' -'Know_' -'Kirchen' -'Ki_' -'Kho' -'Kein_' -'Kampa' -'Kambodscha_' -'Kabel_' -'JÃļrg_' -'Jung' -'Jos' -'Jon' -'Jesus_' -'Jas' -'Jam' -'JA' -'Irr' -'Ion' -'Intergovernmental_' -'Intel_' -'Inte' -'Installations' -'Inlands' -'Infektion' -'Individual' -'Indi' -'Importe' -'IN_' -'ICEcat_' -'IA_' -'HÃļhepunkt_' -'Händen_' -'Hyatt_' -'Horizont' -'Hohen_' -'Hinzu' -'Hinweise_' -'Hinter' -'Hierzu_' -'Hier' -'Herzego' -'Heads_' -'Hannover_' -'Ham' -'Haiti_' -'GÃļ' -'Gul_' -'Gui' -'GrÃŧnen_' -'GrÃļßen' -'Grundsätze_' -'Große_' -'Greens_' -'Gran' -'Gouverneur' -'Golf' -'Gleichstellung_' -'Gle' -'Gil' -'Gewässer' -'Gewinner_' -'Gesundheitsp' -'Gesamte' -'Gerät_' -'Gerichten_' -'Geneva_' 
-'Genehmigung_' -'Gemä' -'Gemeinschaftsp' -'Gele' -'Geis' -'Gegenwärtig_' -'Gegenden_' -'Gefangenen_' -'Geburt' -'Gazastreifen_' -'Gaz' -'Gard' -'Gall' -'GEN' -'FÃŧrs' -'FÃŧnf' -'FÃŧhrungsp' -'Fußball' -'Fukushima_' -'Fuji' -'Freihandels' -'Fore' -'Flughafen' -'Flor' -'Fleisch' -'Fl' -'Fische' -'Finanzk' -'Finanzin' -'Feed' -'Fee' -'Fast_' -'Fantas' -'Fahrer_' -'FL' -'FD' -'FARC_' -'Extra' -'Exe' -'Evans_' -'Europarat' -'Ethiopia_' -'Estonia_' -'Esp' -'Erwerb_' -'Erst_' -'Ersch' -'Erl' -'Erinnerung_' -'Er' -'Entwicklungsa' -'Entw' -'Entsteh' -'Entschuld' -'Entlastung_' -'Ens' -'Enk' -'Englisch_' -'Energieeffizienz_' -'Energiea' -'Empfänger_' -'Element' -'Elektro' -'Electronic' -'Einzelnen_' -'Einwohnern_' -'Einsp' -'Einschränkung_' -'Einsatz' -'Eing' -'Einfluss' -'Einb' -'Eigent' -'Eb' -'EX' -'EWG_' -'EV' -'ECH' -'DÊ' -'Dusche_' -'Durchbruch_' -'Dur' -'Drogen_' -'Dritten_' -'Dresden_' -'Double_' -'Donnerstag_' -'Domin' -'Dokument' -'Disc' -'Disabled_' -'Diplomatie_' -'Digi' -'Dienste' -'Dick' -'Di_' -'Deut' -'Dest' -'Deposit_' -'Def' -'Deba' -'Deal_' -'Dau' -'DD' -'Cross_' -'Crew_' -'Countries_' -'Corb' -'Const' -'Congo_' -'Computers' -'Commo' -'Colon' -'Clau' -'Citizens_' -'Cin' -'Christine_' -'Christ' -'Child_' -'Chicago_' -'Chechnya_' -'Charl' -'Chan' -'Cast' -'Cash' -'Carl_' -'Car_' -'CL' -'CDs_' -'CC_' -'CB' -'BÃŧro_' -'BÃŧro' -'Bundestag_' -'Bull' -'Buchung_' -'Browser_' -'Briten_' -'Boutique_' -'Bour' -'Bonn_' -'Bomben' -'Bin' -'Bill' -'Bilanz_' -'Bier' -'Bie' -'Bewohner_' -'Betriebssystem_' -'Betracht_' -'Bestand' -'Beschwerden_' -'Bergen_' -'Bereits_' -'Berechtigung' -'Beratung' -'Benach' -'BemÃŧhen_' -'Bekanntlich_' -'Bein' -'Begr' -'Beg' -'Befr' -'Bedien' -'Bear' -'Balk' -'Bach' -'BL' -'Autorität_' -'Automati' -'Ausw' -'Australien_' -'Ausrichtung_' -'Ausbildungs' -'Ausbeutung_' -'Augenmerk_' -'Augenblick_' -'Aufg' -'Asse' -'Arr' -'Armenia' -'Arme' -'Argument' -'Architekt' -'Archer_' -'Arbeitss' -'Arbeitslosen' -'Arbeitsgruppe' -'Arafat_' -'Arabi' -'App' -'Apart_' -'Anwe' -'Anschläge_' -'Anschl' -'Anreiz_' -'Anpassungs' -'Annan_' -'Anlage' -'Anl' -'Anges' -'Ang' -'Anbindung_' -'Amm' -'Alpe' -'Alli' -'Allah_' -'Airways_' -'Afrikanische' -'Aff' -'Add_' -'Acht_' -'Access' -'Abw' -'About_' -'Ablauf_' -'Abbau_' -'AD' -'AB_' -'A6_' -'>' -'==' -'93_' -'89_' -'84_' -'76_' -'76' -'74_' -'74' -'66_' -'47' -'215' -'1980s_' -'1962_' -'1950_' -'1948_' -'1930er_' -'184' -'0s_' -'.) _' -'.&#_' -'. 
"_' -') ._' -'#_' -' – _' -' – ' -' Âģ _' -' = [[_' -' --> _' -'â‚Ŧ' -'– ' -'Ņ…_' -'Ņ€Đ°' -'ĐŋĐž' -'вО' -'вĐĩ' -'Κ' -'ła' -'Ãŧß' -'Ãŧnstige' -'Ãŧllen_' -'Ãŧhrt_' -'Ãŧbt_' -'Ãŧberwältigende' -'Ãŧbern' -'Ãŧbergeben_' -'Ãŧberge' -'Ãŧberein' -'Ãŧberarbeitet' -'Ãŧb' -'Ãļst' -'Ãļs' -'Ãļrder' -'Ãŗr' -'Ãąa' -'ÃĒt' -'Ên' -'ère_' -'ätzliche' -'ärkte_' -'ärk' -'ändler_' -'ändige' -'ändert_' -'äler' -'äd' -'ä_' -'ßig' -'Übersetzungs' -'Überr' -'Überleben_' -'Überl' -'Öko' -'Öffentliche_' -'Ähnlich' -'Äg' -'ÂĢ' -'}} _' -'{' -'zÃŧge' -'zzi' -'zuw' -'zuverlässige' -'zuv' -'zutage_' -'zunichte_' -'zulässig' -'zugeh' -'zug_' -'zog_' -'zl' -'ziplin' -'zine' -'ziert_' -'zeita' -'zeile' -'zan' -'zahlung_' -'zahlreicher_' -'yy' -'yw' -'yth' -'yl_' -'yl' -'yan_' -'xo' -'xie' -'xes_' -'wäsche_' -'wählt' -'wunderschÃļnen_' -'wunderbare' -'worsen' -'worauf_' -'wol' -'wle' -'witnessing_' -'wished_' -'wing' -'willen' -'wiederholte' -'width_' -'wid' -'wheel_' -'whe' -'wertvollen_' -'wertung_' -'wertige' -'werks' -'wendet_' -'wen' -'weißen_' -'weig' -'weiche' -'wee' -'websites_' -'weakness_' -'weakening_' -'warnen_' -'ward_' -'wandel' -'vorzus' -'vorsitzende' -'vorschlag_' -'vorher' -'vorgeschlagene_' -'vorauss' -'vora' -'volunteer_' -'volumen_' -'vo_' -'visiting_' -'virtue_' -'ville_' -'vierte' -'vielf' -'vice_' -'vet_' -'verÃļffentlichte_' -'verzeichnet' -'verz' -'verwendete_' -'verursachten_' -'vertrauen_' -'vertr' -'versuchten_' -'versuchte_' -'verst' -'versp' -'versicherung' -'versetzen_' -'verschärfen_' -'verschmutz' -'verschlimmer' -'verschle' -'verschl' -'verschieben_' -'versamm' -'verpa' -'vermute' -'vermei' -'vermehrt' -'verletzt_' -'verla' -'verkehrs_' -'verkauft_' -'verheerend' -'verha' -'vergrÃļßert' -'vergleich' -'verge' -'verga' -'vereint_' -'vereinig' -'vereinfacht' -'verein' -'verdräng' -'verdo' -'verda' -'verbrauche' -'verbr' -'verbleibenden_' -'verbindlich_' -'verbesserten_' -'verantwortungs' -'verabschieden_' -'ventur' -'vention' -'veness_' -'vene' -'vegetables_' -'vast' -'valley_' -'validat' -'vai' -'uß' -'uu' -'uts' -'utm' -'ution_' -'uten_' -'uste' -'usses_' -'usschuss_' -'usal' -'urop' -'urh' -'ura_' -'upt_' -'ups' -'upp' -'updates_' -'unzureichend' -'unverantwortlich' -'unterw' -'unterschätz' -'unterliegt_' -'unterbrochen' -'unta' -'unsustainable_' -'unse' -'unsa' -'unprecedented_' -'unnÃļtige' -'unmi' -'unm' -'unl' -'uniti' -'unh' -'ungss' -'ungsl' -'ungsg' -'ungleich' -'ungewÃļhnliche' -'unges' -'undertaking_' -'underscore' -'underpinned_' -'underpin' -'uncon' -'unbest' -'umr' -'umk' -'umi' -'umfassend_' -'umf' -'ume' -'ult_' -'uls' -'ular' -'ula' -'ugs' -'ugh' -'ugen_' -'uga' -'uert_' -'ucht_' -'uch_' -'uali' -'tÊ' -'tätigen_' -'tär' -'typischen_' -'typische_' -'tv' -'turn' -'turb' -'tse' -'tschetschenische' -'trug' -'tril' -'triggered_' -'trigger_' -'trick' -'tremendous_' -'treatments_' -'treaties_' -'trap_' -'transporti' -'tracking_' -'totalit' -'torium_' -'topp' -'tok' -'tlin' -'tlich_' -'tisa' -'tioni' -'tiona' -'tien' -'tiefer_' -'tiate' -'thereafter_' -'theme' -'teurer_' -'testen_' -'tert' -'tensions_' -'tension_' -'teilung_' -'techno' -'techni' -'tch_' -'tariffs_' -'tare' -'tanz' -'tankers_' -'tang' -'tab_' -'sÃŧdlichen_' -'switched_' -'svors' -'sverfahren_' -'suspension_' -'suspect' -'surprising_' -'supplier_' -'supervisor' -'superb_' -'sung' -'summ' -'suche_' -'such' -'successor_' -'subway_' -'substanzielle' -'substan' -'subjects_' -'stÃŧtzt_' -'stärkste_' -'stärkeren_' -'stupid_' -'stung_' -'stunde_' -'stun' -'stum' -'strukturelle_' -'strom_' -'stro' -'strive_' -'stric' -'strand' 
-'stran' -'strahl' -'ston' -'stli' -'stimmung' [... several thousand further deleted subword-vocabulary entries elided: this hunk removes a generated vocab file whose -'token' entries run in frequency-bucketed blocks, each sorted in reverse lexicographic order (lowercase subwords, then uppercase subwords, digits and symbols), mixing German and English pieces (-'stattdessen_', -'reliability_', -'Zwischen'), punctuation-only tokens, and encoding-damaged entries such as -'schÃŧtzt_' (schützt_), -'plÃļtzliche' (plötzliche), -'â‚Ŧ _' (€) and -'īŋŊ' (U+FFFD) ...] -'ido_' -'idl' -'ideologisch'
-'ideologies_' -'identit' -'ident' -'ict' -'icon' -'ichtig' -'iches_' -'ibility_' -'ibi' -'iar' -'iant' -'hÃŧt' -'hÃļrigkeit_' -'hÃļrde' -'hÃļheres_' -'hÃļherer_' -'hÃļf' -'hÃļchste' -'hÃļ' -'häufiger_' -'händler_' -'hypoc' -'husband_' -'humiliati' -'htl' -'hti' -'hrop' -'hov' -'hotele' -'hosted_' -'hospital_' -'horses_' -'horse_' -'hors' -'horror' -'hore_' -'honest' -'homeowners_' -'holt' -'holdings_' -'hob' -'ho_' -'hnung' -'hnt_' -'hlte' -'hli' -'hinterlassen_' -'hilfsbereit_' -'highway_' -'highlighted_' -'hieß' -'hide_' -'hide' -'hid' -'heti' -'hetero' -'hesita' -'hervorzu' -'herum' -'herrschen_' -'herkÃļmmlichen_' -'hering_' -'herd' -'herbe' -'herb' -'herausf' -'hene_' -'henden_' -'helme' -'heiz' -'heimischen_' -'heim_' -'hebt_' -'heated_' -'hav' -'hauses_' -'haul_' -'hatred_' -'hath_' -'hase' -'harten_' -'hars' -'harmonisiert' -'harg' -'harbour_' -'hanging_' -'hang_' -'handic' -'handful_' -'hande' -'hamme' -'hall' -'halbe' -'hag' -'haftes_' -'hacke' -'hace' -'habits_' -'habit' -'haber_' -'haa' -'gÃŧter_' -'gÃŧnstigen_' -'gän' -'gutem_' -'guiding_' -'guide' -'guaranteeing_' -'gster' -'grundsätzliche' -'grundlage' -'grub' -'großzÃŧgig' -'grow' -'groundwater_' -'grossen_' -'grobe' -'griffen_' -'grie' -'grew_' -'greed_' -'gravierende' -'grau' -'grati' -'grasp_' -'graduate_' -'gotten_' -'gnet' -'glÃŧcklich_' -'gly_' -'glu' -'globally_' -'glim' -'glichkeiten_' -'glichen_' -'gle_' -'glaubwÃŧrdig_' -'glaubte' -'gische' -'gis' -'gion' -'gig' -'giant_' -'ghter_' -'ghte' -'gh' -'ggi' -'gewor' -'gewann_' -'gewalttätige' -'getroffene' -'getrieben' -'geti' -'getestet_' -'getau' -'gesunden_' -'gestÃŧtzt_' -'gesti' -'gesteckt_' -'gestaltung_' -'gesetzlichen_' -'gesetzlich_' -'gesetzgebung_' -'geschäft_' -'gescho' -'geschla' -'gesche' -'gescha' -'gesamt' -'gesa' -'gerufen_' -'geringste' -'geringfÃŧgig' -'gericht_' -'gerettet_' -'gerecht' -'geplanten_' -'geplante' -'geographical_' -'geographic' -'geografisch' -'gent_' -'generating_' -'genehmigt_' -'genehmig' -'genauen_' -'genaue_' -'gemÃŧtliche' -'gemi' -'gemeinsame' -'geladen' -'gela' -'geklärt_' -'geistigen_' -'geistig' -'gehÃļre' -'gehÃļr' -'geho' -'geheime' -'gegenseitige' -'gegebene' -'gefÃŧg' -'gefällt_' -'gefäl' -'gefolgt_' -'gefl' -'gefangen_' -'gefahr' -'geeinigt' -'geehrte_' -'gee' -'gedenkt_' -'gebr' -'gebot' -'gebildete' -'gebeten_' -'geben' -'gau_' -'gathered_' -'gate_' -'gare' -'gant' -'gans_' -'gan_' -'galax' -'fÃŧrs' -'fÃŧnfzig_' -'fÃŧllen_' -'fÃŧll' -'fÃŧhrender_' -'fÃļrm' -'fähig_' -'fus' -'furnishings_' -'fung' -'ful' -'fuer' -'ftet_' -'fst' -'fruitful_' -'froh' -'friend' -'friedlich' -'freundschaft' -'frequent_' -'fremden' -'freiheit' -'freight_' -'freies_' -'frankly_' -'foun' -'fotografi' -'fossile' -'fos' -'forums_' -'fortzusetzen_' -'forts' -'fors' -'formula_' -'formula' -'formation_' -'forma' -'foreseeable_' -'forecasts_' -'folgten_' -'folgender' -'fluss_' -'flus' -'flung_' -'flugzeuge' -'flotte' -'flop' -'flood' -'float' -'fließen_' -'fließ' -'fliehen_' -'fleet' -'fledged_' -'flat' -'fl' -'fixes_' -'fishe' -'fisch' -'firme' -'fing_' -'finanzier' -'financially_' -'filter_' -'filt' -'fighter' -'fifteen_' -'fg' -'ffnung' -'ffl' -'feu' -'feststellte' -'festhalten_' -'feste_' -'fes' -'fertigt_' -'fenster_' -'fence' -'felder_' -'feine' -'feind' -'feierlich' -'fehlge' -'fehler' -'fehlende_' -'feedback_' -'feared_' -'favorit' -'faszinierend' -'fastest_' -'fasst_' -'farme' -'farb' -'fam' -'falsche' -'fai' -'fahrer_' -'fade' -'facet' -'ezi' -'exze' -'expose_' -'exporting_' -'exporters_' -'export' -'exploring_' -'explore' -'exploit_' 
-'experiment' -'expects_' -'expansive' -'exklusive' -'exert_' -'exemplary_' -'exe_' -'exclude' -'exclu' -'exci' -'excellence_' -'excel' -'exceeded_' -'examined_' -'evÃļlkerung_' -'evolving_' -'evolved_' -'evolution' -'evade' -'ev' -'eue' -'eu_' -'ett_' -'etie' -'etablierte' -'eta_' -'esu' -'estr' -'esten_' -'establish' -'essor' -'esh' -'esen' -'erÃļrtert_' -'erzählt' -'erzeugung_' -'erwäh' -'erwirtschafte' -'erweiterte_' -'erw' -'erungen_' -'erti' -'erte' -'ersÃļnlich' -'erstes_' -'erstens_' -'eror' -'erneuerbaren_' -'ernen' -'ermä' -'ermordet' -'ermittelt' -'erlä' -'erly_' -'erleb' -'erkennbar_' -'erkannte' -'eriu' -'erin_' -'eries_' -'erier' -'eria' -'erge_' -'erg_' -'erfassen_' -'erfass' -'erfahrene' -'erbringen_' -'erarbeitete' -'equit' -'eping_' -'epic' -'epi' -'eous_' -'enzt' -'envi' -'entz' -'enttäuscht' -'entste' -'entstandene' -'entstand_' -'entsprechender_' -'entspannten_' -'entschä' -'entschuldigen_' -'entschlossene' -'entru' -'entries_' -'entrepreneurs_' -'ento' -'entn' -'entlich' -'entlassen_' -'entirety_' -'ential_' -'enthusiastic_' -'entfalten_' -'ensured_' -'ensh' -'enorm_' -'enien' -'enic' -'enheit_' -'enhancement' -'engines_' -'engineer' -'engere' -'engen_' -'engagieren_' -'energisch' -'energies' -'endless' -'endete' -'endemi' -'encr' -'encountered_' -'enc' -'enacted_' -'emul' -'empt' -'empire_' -'empfunden' -'empfehle' -'empe' -'emon' -'emergen' -'emer' -'embryon' -'embodies_' -'ember' -'embar' -'emails_' -'emach' -'elz' -'elu' -'elten_' -'ellte' -'ellt_' -'elit' -'elin' -'elig' -'elevat' -'elet' -'elemente' -'elektronische' -'elektro' -'eleganten_' -'electronically_' -'elder' -'eld' -'eke' -'ej_' -'ej' -'eitige' -'eiten' -'eistung' -'eist_' -'einzurichten_' -'einzuh' -'einzelner_' -'einstellungen_' -'einse' -'einnehmen_' -'einmalige_' -'eink' -'einiges_' -'einhergehen_' -'einheitlich' -'eingeschlagen' -'eingenommen' -'eingeg' -'eingef' -'einfÃŧhren_' -'einflussreiche' -'eindrucksvoll' -'eighth_' -'eigent' -'eigenständige' -'eigenem_' -'eiche' -'eiben_' -'ehung' -'ehrung' -'ehl' -'ehen' -'egia' -'egene' -'efo' -'effizientere' -'effizienter_' -'effen' -'effektivere' -'edoni' -'edo' -'editing_' -'ede_' -'ectual' -'ecti' -'ecosystem_' -'echsel' -'echo' -'echende' -'echa' -'ebung_' -'ebt_' -'ebi' -'ebenen_' -'eb_' -'eb' -'dÃŧstere' -'dänische' -'dysfunction' -'dynast' -'dynamisch' -'durchz' -'durchsch' -'durchs' -'durchl' -'durchd' -'durable_' -'dumping_' -'dual_' -'dscha' -'dsch' -'drÃŧckt' -'drucke' -'dron' -'drink' -'dress_' -'dreißig' -'drehen_' -'draws_' -'drastisch_' -'drain_' -'dozen' -'dox' -'downloading_' -'downloade' -'dorm' -'doppelte' -'doppel' -'door' -'donors_' -'dominat' -'dogmati' -'documentation_' -'diving_' -'divert' -'diver' -'disturb' -'distinctive_' -'distanzier' -'distan' -'dissident' -'disse' -'disruption_' -'disposi' -'dispens' -'disparities_' -'dispar' -'disorder' -'dismissed_' -'dismiss' -'dismantle' -'discriminat' -'discount_' -'disappointed_' -'disappear_' -'disagreements_' -'disadvantage' -'direct' -'dir' -'diplomatische_' -'diplomati' -'dingungen_' -'dif' -'dieselben_' -'diesbezÃŧglich_' -'diente' -'dienste' -'dictator' -'dicta' -'dichte' -'dica' -'diagnosis_' -'dia_' -'devastat' -'deutlichen_' -'deut' -'detected_' -'detect_' -'detai' -'destructive_' -'destinations_' -'destabilisier' -'desse' -'desperat' -'designierte' -'designation_' -'designated_' -'dert' -'derselben_' -'derogations_' -'dero' -'dern_' -'dern' -'deriv' -'derartiges_' -'depressing_' -'depreciation_' -'deprec' -'denz' -'dental' -'denkbar' -'deni' -'demonstrators_' -'demons' 
-'demograph' -'democratically_' -'democrati' -'delte' -'delic' -'delays_' -'dek' -'deinen_' -'deinem_' -'definierte_' -'defending_' -'defeated_' -'defeat' -'defaults_' -'deepening_' -'decreas' -'debtors_' -'debe' -'debattier' -'daughter_' -'dauerhaft_' -'daten' -'darstellte_' -'dankbar_' -'dank' -'dane' -'dancer' -'damaligen_' -'damages_' -'dale_' -'dad' -'dable_' -'cynical' -'customize' -'cussi' -'curtail' -'cura' -'cum' -'cult' -'culminate' -'cui' -'ctua' -'ctr' -'cting_' -'cta' -'crystal' -'crowd_' -'criticise' -'critici' -'criterion_' -'criminals_' -'creat' -'cover' -'cotton_' -'cosm' -'correlation' -'correct' -'cornerstone_' -'coral_' -'cooking_' -'cooked_' -'convi' -'contribut' -'continual' -'continents_' -'contest' -'contemp' -'contamina' -'container' -'consultati' -'consultan' -'consul' -'constitut' -'constituencies_' -'consolidate' -'consol' -'consist_' -'consequently_' -'conscious_' -'consci' -'conque' -'conjunction_' -'congratulations_' -'cong' -'confort_' -'conflicting_' -'confirm' -'confined_' -'configured_' -'confident_' -'conductor' -'condo' -'conditione' -'concessions_' -'concepts_' -'conception_' -'concentrati' -'conceivabl' -'comprehensi' -'composer' -'compo' -'complain_' -'complacen' -'competit' -'competent_' -'compare_' -'comparative' -'comparable_' -'commun' -'commonly_' -'commenta' -'comment' -'come' -'combines_' -'colorful' -'colonial_' -'college_' -'collections_' -'coin_' -'cogniti' -'cock' -'cob' -'coach_' -'climb_' -'cliente' -'classi' -'cki' -'ckelt' -'civilisation' -'circulation_' -'cip' -'ciona' -'cion' -'cio_' -'cina' -'cigarette' -'cian_' -'chur' -'chuld' -'chts_' -'chter_' -'christ' -'choi' -'chnung' -'chnitt' -'chnet' -'chliche' -'chkeit' -'chk' -'chis' -'chirm' -'chir' -'chinesisch' -'ches' -'chat_' -'chat' -'chase' -'charakter_' -'characterized_' -'chapter_' -'chaos_' -'chairs_' -'chains_' -'chaften_' -'ces' -'certification_' -'certificates_' -'cerca' -'cer_' -'cement_' -'celebration' -'celebrated_' -'cautious_' -'cater' -'caste' -'carp' -'cardi' -'card' -'capitalist' -'capability_' -'cann' -'cancell' -'cance' -'cana' -'call' -'cabl' -'bÃŧrokratische' -'burst_' -'burne' -'burge' -'bureaucrats_' -'bureaucrac' -'bunte' -'bundes' -'builds_' -'bud' -'bubbles_' -'bst' -'brutal' -'bruch' -'brot' -'broaden' -'britische' -'brilliant_' -'bridge' -'brick' -'brew' -'breites_' -'breiter_' -'breakthrough_' -'breakdown_' -'breach_' -'brauchte' -'brands_' -'branch_' -'brake' -'brachten_' -'brac' -'boys_' -'boxes_' -'boutique_' -'bottle_' -'boss' -'borrowe' -'borne_' -'boosted_' -'bonus_' -'bone_' -'bombings_' -'boli' -'bold' -'boden_' -'bn' -'blut' -'blicken' -'blic' -'blend' -'blem' -'blaue' -'biti' -'bish' -'biot' -'bindung_' -'bills_' -'billig' -'bill' -'bilie' -'bilde' -'bicycle' -'bic_' -'bezeichnete' -'bewertet_' -'bevÃļlkerung' -'beunruhigt_' -'beu' -'beträchtliche_' -'beträ' -'betriebs' -'betreffend_' -'betray' -'beton' -'beti' -'besuchten_' -'bestÃŧ' -'bestreb' -'besto' -'bester_' -'besseres_' -'besonderem_' -'beson' -'besiege' -'besie' -'besichtig' -'besi' -'besetz' -'beseitig' -'beschwer' -'beschw' -'beschuldigt' -'beschrei' -'beschleunigte' -'bescha' -'besa' -'berÃŧhmteste' -'berÃŧhmt' -'berÃŧc' -'beruh' -'berufen_' -'bers_' -'berichtete_' -'bereitstellen_' -'bereich' -'beratung' -'berater' -'berate' -'benÃļtigten_' -'beneficiar' -'bend_' -'bemÃŧh' -'bemerkenswert_' -'belä' -'beln_' -'beliefs_' -'beliebige_' -'belie' -'belgischen_' -'belegen_' -'bekämpft_' -'bekanntlich_' -'bekannter' -'beitreten_' -'beid' -'behindertenfreundliche_' -'behinderte' -'beharr' 
-'behandlung' -'begrÃŧß' -'begrÃŧnde' -'begrenzten_' -'begleiten_' -'begin' -'begehen_' -'begeh' -'begegnet_' -'begangen_' -'befÃŧrchte' -'befri' -'befrei' -'befo' -'beer_' -'bedeutete_' -'bede' -'bedauerlich_' -'beautifully_' -'beat_' -'bearbeiten_' -'beam_' -'beachtet_' -'beach' -'beabsichtigte' -'bd' -'baute' -'basierte_' -'basierenden_' -'basierend' -'bargaining_' -'barbari' -'bankruptcy_' -'bank' -'bana' -'ballot_' -'bak' -'baggage_' -'bad' -'bac' -'aym' -'axi' -'await_' -'await' -'avoiding_' -'avant' -'außergewÃļhnlichen_' -'außenpolitischen_' -'automobile_' -'automatisier' -'automatischen_' -'automat' -'autob' -'authori' -'aute' -'auszuÃŧben_' -'auszuweiten_' -'auszulÃļsen_' -'auszud' -'auszubauen_' -'ausz' -'aust' -'aussprechen_' -'aussi' -'aussetzen_' -'ausschlaggebend' -'ausländische' -'ausha' -'ausgi' -'ausgez' -'ausgewählten_' -'ausgewählte_' -'ausgeht_' -'ausgef' -'ausfÃŧhrlich_' -'ausfallen_' -'ausbreite' -'ausbauen_' -'aur' -'aufz' -'aufweisen_' -'aufw' -'aufrichtig' -'aufregende' -'aufnahme_' -'auflÃļs' -'aufgreif' -'aufeinander' -'auern_' -'audiovisual_' -'auction' -'attraktive' -'attracti' -'attitudes_' -'attachment' -'atori' -'atoren_' -'aton' -'astronom' -'assumes_' -'associate' -'assist' -'assessments_' -'assert_' -'asser' -'assembly_' -'assault_' -'aspire' -'arungen_' -'artner' -'articulat' -'arrangement_' -'arrang' -'arma' -'arische' -'aris_' -'aries_' -'argumentiert_' -'argues_' -'argentinische' -'ardin' -'arde' -'arati' -'arat' -'arabisch' -'apu' -'approx' -'approve_' -'appr' -'apie_' -'apie' -'anzubieten_' -'anybody_' -'anxious_' -'antly_' -'anstr' -'anstehende' -'ansi' -'anschlie' -'ansatz_' -'anrichten_' -'anol_' -'anna' -'ann_' -'anme' -'anleg' -'ank_' -'animation' -'angig' -'anget' -'angestrebte' -'angeschlossen_' -'angenommene' -'angefÃŧhrt' -'angebotenen_' -'angebot_' -'anfällig' -'anforderungen_' -'anfall' -'anerkannte' -'anen_' -'anen' -'andi' -'andes_' -'anden_' -'andelt_' -'andel_' -'andard' -'anbieter' -'anbet' -'anb' -'analyses_' -'amt_' -'ams' -'amplifie' -'amerikanis' -'ambassador' -'altung_' -'alth' -'altern' -'alm' -'allergi' -'allerg' -'aller' -'alit' -'alienat' -'alia_' -'algorithms_' -'alance' -'akzeptier' -'aku' -'aktivitäten_' -'aktiv' -'aktie' -'aks_' -'aki_' -'aken' -'ake' -'ais' -'agrees_' -'agree' -'agoni' -'ago' -'afrika_' -'afi' -'affluen' -'affair_' -'adviser' -'additi' -'addict' -'adapting_' -'acu' -'activate_' -'activate' -'acqu' -'acknowledged_' -'achte_' -'achse' -'achiev' -'accru' -'accord' -'accomplish_' -'acci' -'accessories_' -'access' -'accelerate' -'abzuwe' -'abzule' -'abzielt_' -'abys' -'abwer' -'abundant' -'absolut' -'absicht' -'abschließend_' -'abra' -'abnehmen' -'ablauf' -'abk' -'abilit' -'abide_' -'abhängt_' -'abha' -'abh' -'abgez' -'abgestimmte' -'abgeschnitten_' -'abgele' -'abgegeben' -'abbau_' -']]), ' -'ZÃŧrich_' -'ZÃŧge' -'Zyp' -'Zyklus_' -'Zwischenzeit_' -'Zuwe' -'Zuwanderer' -'Zusätzlich' -'Zuständ' -'Zusch' -'Zusammenst' -'Zusammensetzung_' -'Zusammenschluss_' -'Zusammens' -'Zusammenh' -'Zusa' -'Zuri' -'Zulassung_' -'Zugangs' -'Zuf' -'Zube' -'Zoom' -'Zivilisation_' -'Zin' -'Zerfall' -'Zentrala' -'Zeitschrift_' -'Zeitraums_' -'Zeita' -'Zar' -'YouTube_' -'Yemen' -'Yell' -'Yar' -'Yam' -'Yal' -'Yacht' -'WÃŧnsch' -'Wählerschaft_' -'Wut_' -'Wunder' -'Working_' -'Word' -'Wonder' -'Wohnung_' -'Wochenende_' -'Wissenschaftler' -'Wissens_' -'Wis' -'Wirtschaftsm' -'Wirtschaftsa' -'Wirk' -'Wir' -'Willkommen_' -'Wild_' -'Wiederaufnahme_' -'Wiederauf' -'Widerstands_' -'Which_' -'Wettbewerbsvorteil' -'Wett' -'Westens_' -'Werkst' 
-'Wenig' -'Wen_' -'Weltmarkt' -'Weißbuch_' -'Wednesday_' -'Wechsel' -'Weber_' -'Wave' -'Wasserk' -'Wash' -'Want_' -'Wahrscheinlich' -'Waffenstillstand_' -'Waf' -'Wachstumsraten_' -'Wachstumspakt' -'VÃļlkermord' -'Vulkan' -'Vs_' -'Vorwand_' -'Vorteilen_' -'Vorsicht_' -'Vormittag_' -'Vorla' -'Vorhersage' -'Vorgang_' -'Volle' -'Volksk' -'Voice_' -'Vize' -'Visu' -'Visa_' -'Vis' -'Violence_' -'Vinc' -'Villa' -'Vila' -'Vig' -'Verwundbarkeit' -'Verwend' -'Verwe' -'Vertriebs' -'Vertrieb' -'Vertiefung_' -'Versäumnis' -'Versuchung_' -'Versuchen_' -'VerstÃļße' -'Verstä' -'Verschuldung_' -'Verschlechterung' -'Verschl' -'Vers' -'Verp' -'Vernichtung_' -'VermÃļgenswerte_' -'Vermächtnis' -'Vermittlung_' -'Verme' -'Verlang' -'Verlagerung_' -'Verkehrsanbindung_' -'Verkaufs' -'Veridian_' -'Verhältnis' -'Verhandlungstisch_' -'VergÃŧtung' -'Vend' -'Variable' -'Value' -'Valent' -'VIP_' -'VII' -'VE' -'UnterstÃŧtz' -'Unterb' -'Untera' -'Universal_' -'UnglÃŧck' -'Under' -'Und' -'Umweltpro' -'Umweltausschuss_' -'Umk' -'Ultimate' -'Ukrain' -'Uhr' -'Ub' -'UNMI' -'UNIC' -'UNHCR_' -'UND_' -'UL' -'UA' -'TÃŧrk' -'TÃŧren_' -'TÃļtung_' -'Typi' -'Type_' -'Turni' -'Tul' -'Tud' -'Tsu' -'Ts' -'Trumps_' -'Truc' -'Trock' -'Treibhausgase' -'Trav' -'Traum' -'Trau' -'Transform' -'Transatlanti' -'Tram' -'Train_' -'Trade' -'Toyota_' -'Tow' -'Touristen' -'Toulouse_' -'Toro' -'Tony_' -'Ton_' -'Toleranz_' -'Tisch' -'Tip' -'Tiger' -'Tia' -'Thyssen' -'Thom' -'Therma' -'Thatcher_' -'Than' -'Thal' -'Tha' -'Textilien_' -'Textes_' -'Terrorismus' -'Terra' -'Tennis_' -'Tendenzen_' -'Templates_' -'Telekommunikations' -'Technologies_' -'Techno' -'Techniken_' -'Tea_' -'Tausch' -'Tauche' -'Tatsachen_' -'Tastatur' -'Tas' -'Tarif' -'Table' -'Tabellen_' -'TY_' -'TG' -'TFT' -'SÃŧdwest' -'SÃĄ' -'Systema' -'Symbol' -'Sydney_' -'Sustainable_' -'Surv' -'Surf' -'Suppo' -'Supp' -'Superior_' -'Super_' -'Suc' -'Subs' -'Subjekt' -'Sub_' -'StÃŧtz' -'StÃŧrme' -'StÃŧcke_' -'StÃļrungen_' -'Sty' -'Stufe_' -'Studierende' -'Student_' -'Stress_' -'Streit_' -'Strassen_' -'Stran' -'Stornierung' -'Stimul' -'Stil' -'Stick' -'Steuerh' -'Stern_' -'Stereo' -'Steinberg_' -'Stefan' -'Stay' -'Stag' -'Stadtteil_' -'Stable_' -'Stabilis' -'Staatsp' -'Staatsa' -'Spyware_' -'Spy' -'Sporta' -'Split' -'Spirit' -'Spin' -'Spezifikation' -'Speise' -'Spee' -'Sparpolitik_' -'Spaltung_' -'Sozialversicherungs' -'Sozialleistungen_' -'Sowjet' -'Southeast_' -'Sonntag_' -'Songs' -'Somit_' -'Some' -'Solana_' -'Sobald_' -'Sno' -'Slow' -'Sli' -'Skript' -'Ski_' -'Skan' -'Sitzungen_' -'Singapur_' -'Singapore_' -'Similar_' -'Silicon' -'Silber' -'Signal' -'Sig' -'Siedlungen_' -'Siedlung' -'Side' -'Short_' -'Sheraton_' -'Shell' -'Sharia_' -'Shari' -'Sham' -'Shaf' -'Seve' -'Seoul_' -'Sender_' -'Sende' -'Senator_' -'Select' -'Selbstver' -'Selbstmord_' -'Selbstbe' -'Sekt' -'Segen_' -'Seen_' -'Securi' -'Seba' -'SchÃŧler_' -'SchÃŧ' -'Schwi' -'Schwester_' -'Schwerpunkt' -'Schweige' -'Schwed' -'Schulungs' -'Schuldner' -'Schuldenerlass' -'Schriften_' -'Schra' -'Schock' -'Schnittstelle_' -'Schmie' -'Schlechte' -'Schlacht_' -'Schlacht' -'Schiefe' -'Schichten_' -'Schauspieler' -'Schatten' -'Schar' -'Schall' -'Schaff' -'Schaf' -'Schaden' -'Schach' -'Scal' -'Savi' -'Satellit' -'Sant_' -'Samstag_' -'Sample' -'Sale' -'Sadly_' -'Sab' -'Saatgut_' -'Saal' -'SZR_' -'SW' -'SSL' -'SP_' -'SOE' -'SMEs_' -'SI_' -'SITE_' -'SIM' -'SF_' -'SB' -'SAP_' -'Ry' -'Rus' -'Rural_' -'Ruhe' -'Rud' -'Rov' -'Route_' -'Roosevelt' -'Romulan' -'Romani' -'Rol' -'Rohstoffe_' -'Roc' -'Road' -'Risk' -'Rio' -'Rindfleisch' -'Richtig' -'Revi' -'Reu' 
-'Restr' -'Restaura' -'Reserv' -'Rese' -'Repräsentanten' -'Reparatur_' -'Rep' -'Renten_' -'Renov' -'Religionen_' -'Release' -'Reisez' -'Reichweite_' -'Reichtum_' -'RegulierungsbehÃļrden_' -'Registrierung_' -'Regent' -'Regens' -'Rega' -'Redner' -'Rechtsvorschrift_' -'Rechtsprechung_' -'Rechtsakt' -'Rechen' -'Realisierung_' -'Read_' -'Rav' -'Rauchen' -'Ratio' -'Rate' -'Rassen' -'Rass' -'Rang' -'Ramada' -'Raci' -'RP' -'RL' -'RI_' -'Quote' -'Quin' -'Qu' -'Qi' -'QU' -'Pyr' -'Py' -'Pus' -'Purvis_' -'Pur' -'Pum' -'Publikum_' -'Publikation' -'Präventiv' -'Präsidenten' -'Prämien' -'Präf' -'Prozessor_' -'Prozentpunkte' -'Provinz' -'Prom' -'Progress' -'Prognos' -'Products_' -'Product_' -'Produ' -'Proc' -'Problemati' -'Privileg_' -'Privatsphäre_' -'Princes' -'Price_' -'Price' -'Prem' -'Preiss' -'Portug' -'Polizeia' -'Poi' -'Plugin' -'Plug_' -'Plenar' -'Plattform_' -'Platte' -'Platform_' -'Plasma' -'Planet_' -'Pix' -'Pirate' -'Pipe' -'Pinochet_' -'Picture' -'Pick' -'Photo_' -'Phasen_' -'Pfei' -'Petitions_' -'Petitions' -'Petition' -'Pete' -'Peru_' -'Pentax_' -'Pennsylvania_' -'Penis' -'Pav' -'Pauls' -'Pau' -'Patent_' -'Passagiere_' -'Pass_' -'Partition_' -'Partie' -'Parameter' -'Paradox' -'Panzer' -'Panasonic_' -'Palais_' -'Paket' -'Pack' -'PU' -'POS' -'POL' -'PN' -'PM_' -'PLA' -'PE_' -'Oxford' -'Outs' -'Outdoor_' -'Ostens_' -'Oscar_' -'Osa' -'Optimismus_' -'Optimierung' -'Opp' -'Operati' -'Online' -'Omniture_' -'Olympischen_' -'Okt' -'Oil_' -'Ohren_' -'Oft' -'Offiziell' -'Oe' -'Occ' -'Objekte_' -'Objekt_' -'Objekt' -'Obgleich_' -'Oberste' -'Obasanjo_' -'ONE_' -'OE' -'NÃŧ' -'Nuklear' -'Now' -'Novo' -'Nov' -'Notw' -'Norw' -'Nordk' -'Night_' -'Nieders' -'Niederlanden_' -'Niederla' -'Nichtsdestotrotz_' -'News' -'Neuseeland_' -'Neus' -'Neug' -'Neuf' -'Neube' -'Neuau' -'Netzwerke_' -'Network_' -'Nephi_' -'Neb' -'Navigation_' -'Nav' -'Nau' -'Naturwissenschaft' -'Native_' -'Nationalstaaten_' -'Nas' -'Namens' -'Nahverkehr' -'Nachweis_' -'Nachteil_' -'Nachfolge' -'Nachdruck_' -'Nachbarländern_' -'NY' -'NU' -'NR' -'NGO_' -'NC' -'NB' -'MÃŧtter' -'MÃŧhl' -'MÊ' -'Mächt' -'MySpace_' -'Musiker_' -'Musi' -'Museen_' -'Muhammad_' -'Motor_' -'Motive' -'Mosk' -'Moschee' -'Mosa' -'Moral' -'Montp' -'Monterrey_' -'Montage' -'Mone' -'Monaco_' -'Moment' -'Moderni' -'Mitgliedstaat' -'Mitglieds' -'Misserfolg' -'Ministry_' -'Minimum_' -'Mille' -'Mili' -'Miles_' -'Mic' -'Mexikos_' -'Meta' -'Messung_' -'Mercur' -'Mercosur_' -'Memor' -'Meinungsumfragen_' -'Meinungsf' -'Meinungen_' -'Meilen' -'Mehrwert_' -'Mehrheits' -'Mehrere_' -'Megapixel' -'Meeres_' -'Medit' -'Medikamente_' -'Medicine' -'Media' -'McC' -'Maß' -'Maz' -'Maximum_' -'Massagen_' -'Massage' -'Marra' -'Marokko_' -'Marine' -'Marbella_' -'Manuel_' -'Manu' -'Mano' -'Manne' -'Mandel' -'Malware_' -'Malay' -'Mainstream' -'Main_' -'Mahm' -'Magazin' -'Made' -'Maci' -'Machthaber_' -'MID' -'Län' -'Lä' -'Lun' -'Luftverschmutzung_' -'Luftfahrt' -'Lore' -'Lodge_' -'Liquiditäts' -'Lip' -'Linu' -'Line' -'Lind' -'Limit_' -'Lig' -'Lieferanten_' -'Liebe' -'Lichte_' -'Licht' -'License' -'Letztere' -'Letter' -'Lese' -'Leistungsfähigkeit_' -'Leipzig_' -'Leid' -'Leib' -'Legen' -'Lef' -'Lebensstandard_' -'Lebensmittelsicherheit_' -'Lebensbedingungen_' -'Lebanese_' -'Learning_' -'Lead' -'Lava' -'Laufwe' -'Lastwagen' -'Laser_' -'Laos_' -'Lanka_' -'Lange_' -'Lange' -'Lane_' -'Landschaft' -'Lan' -'Lamanites_' -'Lago' -'Labor' -'Label' -'LP' -'LOS' -'KÃŧrzungen_' -'KÃŧrz' -'KÃŧnstler' -'KÃļrperschaft' -'KÃļpfe' -'KÃļnigs' -'KÃļln' -'Kurden' -'Kurd' -'Kura' -'Kunde_' -'Kumari_' -'Kulisse' -'Kuba' 
-'Kro' -'Krit' -'Kri' -'Kreislauf' -'Kreditw' -'Krediten_' -'Kreativität_' -'Krankenh' -'Kraftfahrzeug' -'Korrektur_' -'Koordinat' -'Konzert' -'Konvertier' -'Kontra' -'Kontakte_' -'Konsultation_' -'Konstruktion' -'Konsolidierung_' -'Konse' -'Konk' -'Konjunkturp' -'Konflikt' -'Konferenzen_' -'Konditionen_' -'Kompetenz' -'Kommenta' -'Komme' -'Kommand' -'Kolla' -'Koll' -'Koizumi_' -'Know' -'Kno' -'Kne' -'Klu' -'Klin' -'Kleinb' -'Klausel' -'Klassifi' -'Klarheit_' -'Kl' -'Kindes' -'Khamenei_' -'Keynesian_' -'Kette' -'Kernwaffen_' -'Kenne' -'Keep_' -'Kay' -'Katzen_' -'Kasten' -'Kasse' -'Karten' -'Karriere' -'Karl_' -'Karibik_' -'KapitalflÃŧsse_' -'Kapell' -'Kapazität_' -'Kanten_' -'Kandidatenländer' -'Kandidat' -'Kamera' -'Kalk' -'KMU_' -'KB_' -'JÃŧng' -'Jus' -'Junta_' -'Juni' -'Juncker_' -'Jugendherberge_' -'Journal' -'Johannes_' -'Joa' -'Jintao_' -'Jim_' -'Jeff' -'Jedoch_' -'Jede' -'Jazz' -'Jarzembowski_' -'Japaner_' -'Jame' -'Jahrzehnts_' -'JPEG_' -'Islamists_' -'Islamis' -'Islami' -'Iri' -'Investmentbank' -'Invasion_' -'Interpret' -'International' -'Intern' -'Interess' -'Intelligen' -'Integrität_' -'Integrationsp' -'Integrat' -'Integ' -'Int' -'Institut_' -'Institut' -'Inside' -'Innere' -'Inkrafttreten_' -'Inhaber_' -'Informations_' -'Industriestaaten_' -'Industrien_' -'Industrial_' -'Independen' -'Impfung' -'Imperi' -'Imagine_' -'Ima' -'Identitäten_' -'IX_' -'ISS' -'INS' -'INA' -'ILA' -'IK' -'IC_' -'ICT_' -'HÃļlle_' -'Hyp' -'Hunde_' -'Hubschrauber_' -'Hot_' -'Hospi' -'Hosni_' -'Horn' -'Honor' -'Home' -'Hom' -'Hold' -'Hoheit' -'Hohe' -'Hoff' -'His' -'Hintergr' -'Highlight' -'Heri' -'Herberge_' -'Heraus' -'Hell' -'Hektar_' -'Heid' -'Hege' -'Heer' -'Heating_' -'Head_' -'Haz' -'Haushaltsdefizit' -'Haushaltsaus' -'Haushalten_' -'Hauptgr' -'Haupta' -'Harbor_' -'Handt' -'Handelss' -'Handelspartner' -'Handelsk' -'Handelsb' -'Handelsabkommen_' -'Haag' -'Gän' -'Gutes_' -'Gur' -'Gun' -'Guatemala_' -'Guar' -'Gruppen' -'Grundwerte_' -'Grundw' -'Gregori' -'Greater_' -'Grant_' -'Grand' -'Granada_' -'Gran_' -'Gramm' -'Governor_' -'Good' -'Gon' -'Goldst' -'Gol' -'Goe' -'Gob' -'Goals_' -'GlÃŧckw' -'Gläubiger' -'Globale' -'Glob' -'Glo' -'Gleiche' -'Glacier_' -'Gitarren' -'Git' -'Gio' -'Gill' -'Gewicht' -'Getreide' -'Gesundheitsschutz' -'Gestern_' -'Geste' -'Gesellschafts' -'Geschäftsreise' -'GeschäftsfÃŧhr' -'Geschäft' -'Geschicht' -'Geschenk_' -'Gesamtnachfrage_' -'Gesamtb' -'Gesamta' -'Gere' -'Georgian_' -'Gent' -'Genf_' -'Geneti' -'Gene' -'GemÃŧse' -'Gemeinw' -'Gemeinschaftsrecht' -'Gemeinsam_' -'Gemeinsam' -'Gelände_' -'Gegebenheiten_' -'Gefängniss' -'Gefä' -'Gedicht' -'Gedanke' -'GebietskÃļrperschaften_' -'Gates_' -'Gastgeber' -'Ganzen_' -'Ganze_' -'Gange_' -'Gaming_' -'Galic' -'Galax' -'Gai' -'Gad' -'Gabriel_' -'GUI' -'GT' -'GO' -'GM_' -'GL' -'GIMP_' -'GBP_' -'GAP_' -'FÃŧrsten' -'FÃŧnftel_' -'FÃŧhrungen_' -'FÃļdera' -'Fäh' -'Fuss_' -'Funktionsweise_' -'Funktionieren_' -'Funktionalität_' -'Funktion' -'Fundament' -'Full_' -'Friedman_' -'Friedh' -'Frie' -'Freund_' -'Fremdenf' -'Fremde' -'Freig' -'FreeBSD_' -'Frattini_' -'Frassoni_' -'Franco' -'Fragment' -'Fou' -'Fotograf' -'Fortun' -'Fortschritts_' -'Fortschritt' -'Formulierung_' -'Forge' -'Forex_' -'Fond' -'Folgendes_' -'Folgende' -'FlÃŧge_' -'Flächen_' -'Fläche' -'Flut' -'Flucht_' -'Florenz_' -'Flie' -'Flick' -'Flasche_' -'Flam' -'Flach' -'Fitnessraum_' -'Fitnesscenter_' -'Fitness' -'Fiskal' -'Fisher' -'Finanzmarkt' -'Finanzinstitutionen_' -'Finanziellen_' -'Finanzielle_' -'Finanzen_' -'Fig' -'Field' -'Fett' -'Fests' -'Festl' -'Ferrer' -'Ferr' -'Fen' 
-'Feinden_' -'Feind_' -'Fein' -'Fea' -'Fax' -'Fav' -'Farmer' -'Fanati' -'Fahrt_' -'Face' -'FU' -'FRE' -'FIFA_' -'FF' -'Exze' -'Exten' -'Experimente' -'Expe' -'Expansion_' -'Exo' -'Existenz' -'Euros' -'Europ' -'Eure' -'Ethi' -'Establish' -'Eskalation_' -'Erwähnung_' -'Erwägung' -'Erwe' -'Erwartung_' -'Erste_' -'Erste' -'Erscheinungs' -'Erscheinung_' -'Ersatz' -'Ernährungs' -'Ernst' -'Erlebnis_' -'Erle' -'Erika_' -'Ericsson_' -'Erhalt_' -'Erfordernissen_' -'Erfindung' -'Ereignis' -'Erbe' -'Entwicklungsf' -'Entwicklungsbank_' -'Entspannen_' -'Entlassung' -'EntfÃŧhrung' -'Entdecke' -'Enr' -'Englischen_' -'Engel' -'Eng' -'Energieversorgung_' -'Energiever' -'Energieträger' -'Energietechnologie' -'En_' -'Empfang_' -'Emotion' -'Emma' -'Embryo' -'Emb' -'Eliten_' -'Elde' -'Elb' -'Einzelh' -'Eintrag_' -'Eint' -'Einsti' -'Einst' -'Einse' -'Einmischung_' -'Einlagen' -'Einl' -'Eingang_' -'Einblick_' -'Eigentumsrechte' -'Economic' -'Eck' -'East' -'Earl' -'ESM' -'EQ' -'EOS_' -'ENE' -'EME' -'EEC_' -'EB' -'DÃŧ' -'DÃļrfer_' -'Dut' -'DurchfÃŧhr' -'Duke' -'Duc' -'Dua' -'Dry_' -'Drohungen_' -'Drogenh' -'Dri' -'Drac' -'Dorn_' -'Dop' -'Domain' -'Dokumenten_' -'Dod' -'Doctor_' -'Divi' -'Dive' -'Distri' -'Disney_' -'Diskurs' -'Disco' -'Direktinvestitionen_' -'Diploma' -'Dienstleist' -'Dienstes_' -'Dienst' -'Device' -'Derivat' -'Demonstrationen_' -'Demokratis' -'Demokratie' -'Democracy_' -'Demand_' -'Definitionen_' -'Defen' -'Deep' -'Deborah_' -'Datens' -'Datenbank' -'Dat' -'Darstell' -'Danube_' -'Dai' -'DafÃŧr_' -'DVDs_' -'DSLR_' -'DP' -'DI_' -'DF' -'Cyp' -'Cyber' -'Curr' -'Cres' -'Credit' -'Cove' -'Cort' -'Cookies_' -'Cookie_' -'Contrary_' -'Constant' -'Configur' -'Concer' -'Conce' -'Compli' -'Communities_' -'Communi' -'Commen' -'Colom' -'Collection_' -'Colla' -'Cob' -'Coal' -'Cli' -'Clever' -'Clean_' -'Classic' -'Citi' -'Christi' -'Chian' -'Chemie' -'Chelsea_' -'Chechen_' -'Charakter' -'Champions' -'Chair_' -'Certainly_' -'Cate' -'Catalunya_' -'Casa' -'Cars' -'Carol_' -'Carme' -'Cap_' -'Cance' -'Campi' -'Camp_' -'Cambodia' -'Cale' -'Cala' -'Cadiz_' -'CV' -'CS_' -'CF_' -'CER' -'CEOs_' -'C6_' -'BÃŧr' -'BÃŧgelservice_' -'BÃŧcher_' -'BÃŧ' -'Byrne' -'Button_' -'Burning_' -'Bureau_' -'Bundeskanzler' -'Bund' -'Built_' -'Buen' -'Bucha' -'BrÃŧcke_' -'Bruttoinlandsprodukt' -'Brothers_' -'Bronze' -'Broad' -'Bristol_' -'Brief_' -'Brief' -'Brian' -'Brennstoffen_' -'Breit' -'Break' -'Brazilian_' -'Brand_' -'Boutique' -'Boston_' -'Born' -'Borde' -'Bombard' -'Bolivia_' -'Boden' -'Blin' -'Blick' -'Blei' -'Blau' -'Blasen_' -'Bit' -'Bisc' -'Biog' -'Billionen_' -'Billig' -'Bildschirm' -'Bil' -'Bezirk' -'Bey' -'Bewä' -'Bewu' -'Bewertungen_' -'Bevor_' -'Between_' -'Better_' -'Betriebe' -'Betra' -'Bestände_' -'Bestr' -'Bestec' -'Besichtigung' -'Beschleunigung_' -'Beschl' -'Beschaff' -'Berichterstatters_' -'Beobacht' -'Benalmadena_' -'Belohnung_' -'Belle' -'Bell' -'Belgian_' -'Beleg' -'Belange' -'Bekannt' -'Beitrittskandidaten_' -'Beis' -'Beir' -'Beihilfe' -'Behauptungen_' -'Behandlungs' -'Befe' -'Bef' -'Bedienung' -'Bavarian_' -'Bat' -'Bass' -'Barrier' -'Barre' -'Baron' -'Bari' -'Barg' -'Barc' -'Barbe' -'Barba' -'Banker' -'Bankensystem' -'Bankensektor' -'Banglades' -'Bang' -'Baker' -'BERLIN_' -'Azer' -'Autovermietung_' -'Autoren_' -'Autor' -'Ausstellung' -'Ausst' -'Ausspr' -'Ausser' -'AusschÃŧsse_' -'Ausschluss' -'Auss' -'Ausländer' -'Ausg' -'Ausfuhr' -'Auseinandersetzung_' -'Auschecken_' -'Ausbreitung_' -'Auktion' -'Aufwand_' -'Auftritt' -'Aufsichtsrat' -'Aufschwung_' -'Aufruf_' -'Aufnahme' -'Auflagen_' -'Aufl' -'Aufgaben' -'Aud' 
-'Attrakti' -'Assozi' -'Asiat' -'Asc' -'Arra' -'Army_' -'Armstrong_' -'Argent' -'Arbeitsplatz_' -'Arbeitsplatz' -'Arbeitsp' -'Arbeitsmärkte' -'Arbeitnehmern_' -'Arbeitnehmer' -'Arb' -'Arabs_' -'Application' -'Applause_' -'Appell' -'Appartement_' -'Appar' -'Apollo_' -'Apo' -'Aparthotel' -'Apartamentos_' -'Anyone_' -'Anwesenheit_' -'Antrieb' -'Anteile_' -'Anstatt_' -'Anspr' -'Anschluss' -'Anschein_' -'Anna_' -'Anklage_' -'Animat' -'Ani' -'Anh' -'Angriffs' -'Angreifer' -'Angola_' -'Angeb' -'Anderson_' -'Ande' -'Anal' -'Amerikaner' -'Ambition' -'Amb' -'Amat' -'Alternative' -'Alpi' -'Alpha' -'Alltags' -'Alkohol_' -'Alexanderplatz_' -'Alb' -'Aktivi' -'Akteuren_' -'Aid' -'Agri' -'Agenturen_' -'Again_' -'Afrikaner' -'Affi' -'Advent' -'Adv' -'Adress' -'Ada' -'Activ' -'Accord' -'Abz' -'Abwesenheit_' -'Abstände' -'Abstimmungs' -'Abso' -'Absicherung_' -'Abschreckung_' -'Abschnitt' -'Absa' -'Abraham' -'Above_' -'Abn' -'Abgr' -'Abge' -'Abend' -'Abdullah_' -'Abbildung' -'AX' -'AO' -'ANY_' -'AK' -'AH' -'ADE' -'ABAP_' -'A350_' -'A1_' -'=_' -'; â€ĸ _' -'93' -'83' -'825' -'81_' -'71_' -'450_' -'39' -'238' -'226' -'220_' -'1976_' -'1975_' -'1965_' -'1940_' -'1936_' -'1933_' -'1918_' -'190_' -'1840' -'18' -'178' -'177' -'169' -'163' -'158' -'154' -'128_' -'125_' -'116' -'110_' -'104' -'100' -'011' -'// _' -'.“' -'...] _' -'...) _' -'.. _' -'.-_' -'.  ' -'. ) _' -'. " _' -', ..._' -')|_' -'):_' -'): ÂĢ' -'), ' -'() , _' -'%) _' -'$ _' -'"-_' -'")._' -'" - _' -'!!!!' -' „ _' -' –&' -' ÂĢ _' -' ÂŖ_' -' `' -' [' -' ..' -'â„ĸ-_' -'â‚Ŧ_' -'„' -'”) _' -'ا' -'י' -'Ņ‚Đž' -'ĐŋŅ€Đž' -'ĐžŅ‚' -'ĐžŅ€' -'ĐŊĐĩ' -'Đģа' -'иĐĩ_' -'Đĩҁ' -'ĐĩŅ€' -'ĐĩĐŊ' -'да' -'΁' -'Îŋ' -'ÅŊ' -'Åŧ' -'ÅĄka_' -'ğ' -'ć_' -'ÃŊ_' -'Ãŧtlich' -'Ãŧte' -'Ãŧstung' -'Ãŧstet_' -'Ãŧrzung' -'Ãŧrze' -'Ãŧrger' -'Ãŧrg' -'Ãŧrfe_' -'Ãŧrf' -'Ãŧre' -'Ãŧrdig' -'Ãŧnstig' -'Ãŧnfte' -'Ãŧnde' -'Ãŧhren' -'Ãŧhre' -'Ãŧdisch' -'Ãŧcht' -'Ãŧbl' -'Ãŧberzu' -'Ãŧberzogen' -'Ãŧberzeugende' -'Ãŧbertrieben_' -'Ãŧbertrag' -'Ãŧbersteigt_' -'Ãŧberste' -'Ãŧberst' -'Ãŧberse' -'ÃŧberschÃŧ' -'Ãŧbernahme' -'Ãŧbermäßig_' -'Ãŧbermittelt_' -'Ãŧbermitteln_' -'Ãŧberlegt_' -'Ãŧberl' -'ÃŧberflÃŧssig' -'Ãŧberbe' -'Ãŧben' -'Ãŧbe' -'Ãļß' -'Ãļtigt' -'Ãļsterreichische_' -'Ãļsterreich' -'Ãļsser' -'Ãļsen_' -'Ãļsch' -'Ãļrtlich' -'Ãļrper' -'Ãļrig' -'Ãļre' -'Ãļr_' -'Ãļpfung_' -'Ãļni' -'Ãļlle' -'Ãļkologischen_' -'Ãļkologisch' -'Ãļhnlich_' -'Ãļhe' -'Ãļglichkeiten_' -'Ãļffentlich' -'Ãļfe_' -'ôte_' -'Ãŗn' -'Ã˛' -'ÃŽ' -'ío' -'Ên_' -'Êm' -'ège_' -'äßig_' -'äuter' -'äumen_' -'äufer_' -'ätig' -'ästinens' -'ässig' -'äss' -'ärtige' -'ängt_' -'ängst' -'ängig' -'änger' -'äne' -'änderungen_' -'ämter' -'ämpfte' -'ämpf' -'ämme' -'äm' -'ältesten_' -'äle_' -'äisch_' -'äi' -'ähnlichen_' -'ähnliche' -'ähneln' -'ählt_' -'ägige' -'äden_' -'ächtige' -'ächtig' -'ächt' -'ÃĄs' -'ßlich' -'ßges' -'ßende' -'ße' -'Überwindung_' -'Übersetze' -'Überschw' -'Überschuss' -'ÜberprÃŧfung' -'Übernacht' -'Überlegenheit_' -'Übereinkommens_' -'Ölv' -'Ölpreise_' -'Ökolog' -'Öff' -'Äthiopien_' -'Ã_' -'Âģ' -'ÂŽ_' -'ÂŽ, _' -'       ' -' %' -'}}) ==' -'}})' -'}{_' -'}, _' -'|' -'zzo_' -'zza' -'zyklus_' -'zwischenstaatliche_' -'zwingend_' -'zweite' -'zuzusch' -'zuvorkommende' -'zusätzliches_' -'zusammenzuf' -'zusammensetz' -'zurÃŧcksch' -'zurÃŧckl' -'zurecht' -'zuni' -'zung' -'zulä' -'zule' -'zukunfts' -'zukommt_' -'zugängliche' -'zugewiesen' -'zugeschnitten_' -'zugenommen_' -'zugeben_' -'zuck' -'zubereitet_' -'zte' -'zs' -'zp' -'zoo_' -'zonen' -'zona_' -'zk' -'zitiere_' -'zion' -'zinier' -'zini' -'zigste' -'ziger_' -'zifische' -'zielt' -'ziell' -'zia' 
-'zhou_' -'zeugt_' -'zeugnis' -'zero' -'zeri' -'zep' -'zentrums_' -'zentra' -'zem' -'zell' -'zeitweilig' -'zeitung_' -'zeitliche' -'zeitgenÃļssische' -'zeitg' -'zeichnung' -'zei_' -'zee' -'zak' -'yti' -'ystemen_' -'yps' -'yours' -'youn' -'yot' -'yla' -'yk' -'yer' -'year' -'yam' -'xpo' -'xon' -'xn' -'xion_' -'xic' -'wÃļhn' -'wÃļchentlich' -'wÃļ' -'wäh' -'wussten_' -'wusste_' -'wusst' -'wurs' -'wur' -'wron' -'writer_' -'wozu_' -'wouldn_' -'worsening_' -'workshop_' -'works' -'wora' -'wohnt' -'wohlhabender' -'withstand_' -'withdrawn_' -'wissenschaftlich' -'wirtschaftspolitische' -'wirtschaft' -'wirksamere' -'wirft_' -'wire' -'wins_' -'winners_' -'wings_' -'willkÃŧrliche' -'wik' -'wies_' -'wiegend' -'wiederzu' -'wiederhole' -'wiederherzustellen_' -'wiederauf' -'widerspiegelt_' -'widersp' -'widersetz' -'widerleg' -'widening_' -'wichtigste' -'wicht' -'wholesale_' -'whatsoever_' -'whale' -'wetter' -'wettbewerbsfähig_' -'wettbewerb' -'wett' -'wesens_' -'wertvolle' -'wertsteuer' -'werkzeug' -'welle_' -'welding_' -'welders_' -'weiße_' -'weitreichend' -'weithin_' -'weitergegeben_' -'weiterf' -'weitaus_' -'weigert_' -'weibliche' -'wegwe' -'weekend' -'week' -'wed' -'wecken_' -'wechselt_' -'wear' -'weaknesses_' -'weakened_' -'waters' -'waterfalls_' -'wat' -'wasn_' -'wary_' -'warnt' -'warn' -'warfare_' -'wan_' -'wall' -'wahrha' -'wahre' -'waffen' -'wachsam' -'vÃļllige' -'votre_' -'vorzube' -'vorteilhaft' -'vorläufige' -'vorkomm' -'vorhin_' -'vorherigen_' -'vorherge' -'vorhandene' -'vorgetragen_' -'vorgesehene_' -'vorgeschrieben' -'vorgelegte' -'vorgehen' -'vorgegebene' -'vorgefertigte' -'vorde' -'vorbereiten_' -'vorbehalt' -'voraussetz' -'vorausge' -'vorangeh' -'vons' -'volum' -'vollzog' -'vollendet' -'volle' -'voli' -'vole_' -'volatil' -'voice' -'vl' -'vivid' -'vität' -'vital' -'vista' -'visit' -'visib' -'viol' -'ving' -'vine' -'vina_' -'villa' -'vill' -'vigorous' -'viewer_' -'viet' -'vielversprechend' -'viels' -'vielfältigen_' -'vielfältige' -'vielfalt' -'vicious_' -'vic_' -'vibrant_' -'verÃļffentlich' -'veränder' -'verzi' -'verzeichnis_' -'verzauber' -'verwÃļhn' -'verwirr' -'verweist_' -'verweigert' -'verwei' -'verwalten_' -'vertritt_' -'vertrieben' -'vertretene' -'vertraue' -'vertief' -'verteidigte' -'verstÃļß' -'verstärk' -'verstoßen_' -'verstorben' -'verstehe_' -'versprechen_' -'versorgen_' -'versorg' -'verschwund' -'verschw' -'verschr' -'verschli' -'verschlechtert_' -'verschafft' -'vers_' -'verringerte_' -'verringer' -'verpflichtung' -'verordn' -'vernÃŧnftig_' -'vernetzt' -'vernachlässigen_' -'vernachlässig' -'vermÃļgen' -'vermis' -'verlängern_' -'verließ_' -'verletzen' -'verleih_' -'verle' -'verlaufen_' -'verlagerung' -'verkÃŧndet_' -'verkÃŧ' -'verin' -'verify' -'verhäng' -'verhältnisse_' -'verhältnis' -'verhe' -'verhaftet_' -'vergÃŧ' -'vergleichen' -'vergew' -'vergangene' -'verfÃŧgte' -'verfÃŧ' -'verfolg' -'verfe' -'verfasst' -'vereinte' -'vereinen_' -'vereinbarung' -'vereinbart_' -'vereinbar' -'verdoppeln_' -'verbundene_' -'verbleibende_' -'verbindungen_' -'verbieten_' -'verbiete' -'verbe' -'verbann' -'verband_' -'verarbeitung_' -'veranstaltungen_' -'veranlasst_' -'verankert_' -'verabschiedete_' -'ventionell' -'vent_' -'vendor' -'vec' -'vation_' -'vary_' -'var_' -'valued_' -'valent' -'vald' -'vak' -'vacu' -'vaca' -'vac' -'ußen' -'uzz' -'uum_' -'utz' -'utterly_' -'utility_' -'utilis' -'utili' -'utenant' -'utation' -'uta' -'usse' -'ussch' -'usher' -'ush' -'usgaben_' -'usen' -'useless_' -'urteile_' -'urt_' -'urse' -'urne' -'urm' -'urie' -'urgi' -'urges_' -'urban' -'upward_' -'uption' -'upte' 
-'uphold' -'upgraded_' -'unzureichende_' -'unverzichtbar_' -'unvereinbar_' -'unthink' -'unterteilt_' -'unterstÃŧtz' -'unterstreicht_' -'unterschiedlichste' -'unterschiedlicher_' -'unterschiede' -'unterr' -'unternehmer' -'unterm' -'unterhält_' -'unterhalten' -'unterhalb_' -'unterd' -'unterb' -'unst' -'unsicher_' -'unsc' -'unqu' -'unos_' -'unnecessary_' -'unmittelbarer_' -'unlängst_' -'unli' -'unle' -'unk_' -'unk' -'universally_' -'unim' -'unic' -'uni_' -'ungsvor' -'ungsver' -'ungssystem' -'ungsprogramm' -'ungso' -'ungsmaßnahmen_' -'ungsge' -'ungsaus' -'ungsanlagen_' -'ungsa' -'unglÃŧcklich' -'ungleichen_' -'unglaubliche' -'ungene' -'unga' -'unforgettable_' -'unfair' -'unexpected_' -'unexp' -'uneven' -'unerw' -'unequ' -'unen' -'underw' -'understa' -'underp' -'underline_' -'unconditional' -'uncha' -'unberÃŧhrt' -'unambiguous' -'umstritten_' -'umstritten' -'umst' -'umpft' -'umgew' -'umgest' -'umgeh' -'umfeld' -'umfassendes_' -'umfassend' -'umfangreicher' -'umfangreich' -'umen_' -'uma_' -'ulta' -'ulous_' -'ulos' -'ull' -'ulent_' -'ulden' -'ularit' -'ukrainischen_' -'uil' -'uier' -'ugn' -'ufung_' -'ufte' -'ues' -'uerung' -'uens' -'udia' -'uck' -'ubwÃŧrdig' -'uben' -'uan' -'u0027s_' -'tÃŧme' -'tÃļtet_' -'tÊ_' -'tzte_' -'tzte' -'tzlich' -'tzen' -'tyran' -'typischerweise_' -'typ_' -'twelve_' -'twe' -'twar' -'tw' -'tutt' -'tus_' -'tures_' -'tuous' -'tune_' -'tue' -'tu_' -'ttung_' -'ttlich' -'ttes' -'tteri' -'tschow_' -'trÃŧ' -'trÃļ' -'träger' -'träge' -'truste' -'truppen_' -'trupp' -'trumpe' -'trum' -'truktur' -'troubled_' -'tropi' -'tron' -'trivial' -'triumph' -'tripl' -'triebe' -'tribunal' -'tribe' -'trial' -'tria' -'treu_' -'treib' -'tre_' -'travels_' -'travelers_' -'travail_' -'traumhafte' -'trate' -'transposi' -'transported_' -'transpo' -'transp' -'transmitter' -'transmi' -'translator' -'translations_' -'translat' -'transitions_' -'transit_' -'transforming_' -'transformer_' -'transformati' -'transatlantic_' -'trans_' -'tran' -'tram_' -'trage' -'trafen_' -'traditioneller' -'traders_' -'tp' -'tournament_' -'tourismus' -'touched_' -'totalitären_' -'toren_' -'topbonus_' -'toni' -'tone' -'tolera' -'toilet_' -'toilet' -'toffe' -'tliches_' -'tland' -'tlan' -'titles_' -'tisier' -'tisches_' -'tion' -'tings_' -'timo' -'timi' -'till' -'tightening_' -'tiger_' -'tiert_' -'tiefgreifende' -'tian_' -'thun' -'thu' -'thron' -'thriving_' -'threaten_' -'thoroughly_' -'thorough' -'thick_' -'theorie' -'thek' -'theirs_' -'thei' -'teure_' -'teure' -'testif' -'tess' -'terw' -'ters_' -'terri' -'terminology_' -'terminal' -'termed' -'tener' -'tends_' -'tendenzi' -'tende' -'tend' -'tena' -'temptation_' -'template' -'temperatures_' -'temperatur_' -'temper' -'temal' -'teln' -'tellung' -'teles' -'telephones_' -'telefoni' -'tekt' -'teilzunehmen_' -'teilnehmenden_' -'teilnahm' -'tehen_' -'technologische_' -'technologie_' -'technologie' -'technisch_' -'technically_' -'technic' -'tear' -'tb' -'tav' -'tausende' -'taught_' -'tate_' -'tart' -'tarn' -'tariff_' -'tank_' -'tande' -'tance' -'taltung' -'talk' -'take' -'tak' -'tains_' -'taine' -'tai_' -'tah' -'tags_' -'tagen_' -'taf' -'tado' -'tackling_' -'tabe' -'taat' -'sÃŧdlich' -'säkulare' -'sächsische' -'szei' -'systemische' -'systematische_' -'systematische' -'systemati' -'sys' -'syrische_' -'synd' -'sync' -'swap' -'sw' -'survival_' -'surpass' -'suppose' -'superpower_' -'superi' -'sunny_' -'sunnitische' -'suicide' -'suffic' -'suff' -'successive_' -'success' -'suc' -'subventionier' -'subtr' -'subti' -'subsidize' -'subscri' -'subordinate' -'subje' -'stÃŧrz' -'stÃŧck' -'stÃļr' 
-'städtischen_' -'städte' -'styl' -'sty' -'stunden' -'stufen' -'studios_' -'studierte_' -'studiere' -'studie' -'strä' -'struktur' -'strophen' -'strong' -'stroke_' -'stricte' -'stretche' -'strengths_' -'strengere_' -'streng_' -'streng' -'streite' -'streiche' -'strebt' -'straßen_' -'strains_' -'straff' -'str' -'stom' -'stische_' -'stirbt_' -'stipulated_' -'stip' -'stina' -'stia' -'sthe' -'steuerung_' -'steuerl' -'steten_' -'stete' -'stery' -'sterie' -'sterbl' -'stems_' -'stellungen_' -'stein' -'steiger' -'steigender_' -'stehende_' -'stec' -'steadily_' -'stav' -'stattfindenden_' -'stattf' -'statistic_' -'statisti' -'statische' -'stationier' -'stating_' -'stati' -'statesm' -'starv' -'starship' -'starr' -'starker_' -'starb_' -'stande' -'stamp' -'stakeholders_' -'stadium' -'sta_' -'sspe' -'ssou' -'ssion' -'ssing_' -'sser_' -'ssene' -'ssari' -'ssad' -'sri' -'srechte' -'srech' -'squeeze' -'spÃŧre' -'späteren_' -'spätere' -'spy' -'spur_' -'sprÃŧ' -'spring' -'spri' -'spreading_' -'spots_' -'spotl' -'sporting_' -'spontaneous' -'sponsored_' -'spon' -'spokes' -'spitz' -'spiral_' -'spiegeln_' -'spiegel' -'spfl' -'spezifisch' -'spezielle' -'spezialisiert_' -'spende' -'spektakulär' -'speedy_' -'speed' -'spectac' -'specif' -'specially_' -'specialize' -'specialist' -'specialis' -'speci' -'spanische_' -'spalte' -'spac' -'sowjetischen_' -'south' -'souls_' -'sorganis' -'sorg' -'sons_' -'sons' -'sond' -'solo_' -'soli' -'sole' -'socie' -'socialism_' -'soap_' -'soa' -'snow_' -'sni' -'sneak' -'sne' -'smÃļglichkeit' -'smu' -'smoothly_' -'smoke_' -'smell_' -'smart' -'sman' -'smallest_' -'slowenisch' -'slowe' -'slogan_' -'slides_' -'slide_' -'slide' -'slich_' -'slic' -'slaver' -'slas' -'slan' -'sl_' -'sl' -'sky' -'skur' -'skra' -'skon' -'skin' -'skill' -'skie' -'sket' -'skeptisch' -'skandinavische' -'skandidat' -'skampagne' -'sitzt_' -'sitzen_' -'sity_' -'sits_' -'sition' -'sis' -'sir_' -'sip' -'sint' -'sinnvolle_' -'sinnlos_' -'sinni' -'sinkt_' -'sinkenden_' -'sink' -'singe' -'sine' -'simple' -'silver_' -'silen' -'signif' -'sierung' -'sierende' -'sier' -'sieg' -'sid' -'sichert_' -'sichergestellt_' -'sichere' -'sica_' -'sible_' -'sibi' -'shu' -'shri' -'shr' -'showers_' -'showcase' -'shortc' -'shoes_' -'shocked_' -'shock' -'shaped_' -'shape' -'shal' -'sfähig' -'sfun' -'setze' -'setz' -'setback' -'sers_' -'seria' -'separatist' -'sentiment' -'sensible' -'sensi' -'senen_' -'selten' -'sell' -'selig' -'selecting_' -'selber_' -'sek' -'seiten_' -'seeke' -'seed' -'sed' -'secure' -'sectarian_' -'secrecy_' -'seat' -'seasons_' -'sdauer_' -'sdate' -'scrutiny_' -'scre' -'scourge' -'science' -'schÃŧren_' -'schÃŧ' -'schÃļnes_' -'schÃļ' -'schätzung' -'schädliche_' -'schädliche' -'schädlich_' -'schäd' -'schwächer_' -'schwimme' -'schwi' -'schwerwiegenden_' -'schweig' -'schwarz_' -'schrä' -'schrumpfen_' -'schritten' -'schriftliche_' -'schriftliche' -'schottischen_' -'schon' -'scholarship' -'scholars_' -'schola' -'schnellstmÃļglich' -'schnellstens_' -'schneiden' -'schmerz' -'schme' -'schlug' -'schlu' -'schlimmste_' -'schlage' -'schlaf' -'schiitischen_' -'schiff' -'schere' -'scheitert' -'scheinbar_' -'scharfe' -'schar' -'scenery_' -'scatter' -'scare' -'scarce' -'scale' -'scal' -'sca' -'sbehÃļrde' -'sauberer' -'satisfactory_' -'sans_' -'sanita' -'sandy_' -'sanction' -'sanc' -'samml' -'samen_' -'sali' -'sais' -'saf' -'sadd' -'rÃŧste' -'rÃŧcks' -'rÃŧcke' -'rÃļm' -'rÊe' -'ränken_' -'räger' -'rÃĄ' -'rwe' -'rw' -'rvi' -'rush' -'rundlage' -'runde' -'rumänische' -'ruktur' -'ruk' -'ruf_' -'ruction' -'ruc' -'rubb' -'rtungen_' -'rting_' -'rteilung' 
-'rstat' -'rsion_' -'rse_' -'rsche' -'rschaft' -'rrs' -'rounde' -'rote_' -'rot_' -'rose' -'ror_' -'ropor' -'rooftop_' -'ront' -'roman_' -'rom_' -'rolle' -'rol' -'rodukti' -'rodukte' -'rocks_' -'rocke' -'rochen_' -'road' -'rno' -'rni' -'rmin' -'rmen_' -'rman_' -'rland' -'rla' -'rkung' -'rklärung' -'rix_' -'riv' -'ritte' -'ritt_' -'riskante' -'risk' -'risiko_' -'rise' -'risch_' -'riots_' -'ringung_' -'ringt' -'ril' -'rike' -'rika' -'rigid_' -'rigi' -'riffe' -'riff_' -'riesiger_' -'riesige_' -'rieb' -'richtungen_' -'richt_' -'richest_' -'ribut' -'ribe' -'riat' -'rian' -'riage' -'rgebnis' -'rfs' -'rfern_' -'revolutionäre' -'review' -'revidier' -'rever' -'revelation' -'reveals_' -'reundliche' -'reter' -'retains_' -'retained_' -'retailer' -'retail_' -'resultierende_' -'result' -'restriktive' -'restraint_' -'resses_' -'responsi' -'respekt' -'resisting_' -'resist' -'resent' -'reproductive' -'repre' -'reposi' -'reporte' -'replie' -'repetiti' -'reper' -'repeal' -'reparat' -'repan' -'renz' -'renovier' -'renovation' -'renounce' -'rengung' -'rende' -'reminds_' -'remin' -'remember' -'remark' -'reluctant' -'religions_' -'relian' -'relativen_' -'rejecting_' -'rej' -'reite' -'reinforcing_' -'reinforced_' -'reign' -'reife' -'reichlich_' -'rehabilitat' -'regulierung_' -'regulieren' -'regulati' -'regr' -'registrier' -'registr' -'regieren_' -'regener' -'refusing_' -'refurbish' -'refresh' -'refra' -'reforming_' -'reformers_' -'reform' -'reflecti' -'refine' -'reduzierte' -'redress_' -'redo' -'redit' -'redis' -'recurr' -'reconstruct' -'recons' -'recommends_' -'reckung' -'reckoning_' -'reckless' -'recken' -'reck' -'recit' -'rechnung' -'rechne' -'recher_' -'rechend' -'receiv' -'recapitaliz' -'rebellion_' -'rebalancing_' -'reasse' -'reappear' -'reap' -'realm' -'realization_' -'realiz' -'reagierte_' -'readin' -'reactor' -'reacted_' -'rds' -'rchy_' -'rbo' -'rbi' -'rberg' -'rber' -'raw' -'rauen_' -'rationali' -'ratings_' -'ratify_' -'ratifizieren_' -'ratic_' -'rasanten_' -'rapi' -'rape' -'ranking_' -'ranges_' -'rami' -'rama' -'rally' -'rall' -'rak' -'raine' -'raffin' -'raeli' -'rado' -'radioactiv' -'racial_' -'racht_' -'quisit' -'quest' -'quem' -'quee' -'quarters_' -'quantity_' -'quantit' -'qualifizierte_' -'qualifiziert' -'qi' -'puzzl' -'pursu' -'purchased_' -'punkte' -'publication_' -'pub' -'psychology_' -'psychische' -'psychiatr' -'pson_' -'präventi' -'präsidenten_' -'präsident' -'präsentieren_' -'präge' -'prudential' -'prozessor' -'proyecto' -'proxy_' -'proximitÊ_' -'provocation' -'provo' -'proving_' -'provincia' -'provinces_' -'proverb' -'prove' -'proud' -'protests_' -'proteste' -'protects_' -'protectionism_' -'protec' -'protagonist' -'prosecutor' -'prosecuti' -'pros' -'propos' -'proofed_' -'prone_' -'prominent' -'promenade_' -'prohibits_' -'progressive' -'progress' -'prognostiz' -'profliga' -'profitable_' -'profitability_' -'professionellen_' -'produktiven_' -'produktiv_' -'produktiv' -'produ' -'procurement_' -'procl' -'prochen' -'processor_' -'processed_' -'problematisch_' -'probieren_' -'probability_' -'proa' -'prize' -'privilegi' -'privati' -'priva' -'prit' -'prises_' -'principal' -'primären_' -'primäre_' -'prevents_' -'prevalen' -'prevails_' -'prevailing_' -'prevail_' -'prestig' -'presenta' -'present' -'prescription' -'prescribe' -'prer' -'prep' -'preo' -'premier_' -'prematurely_' -'premature_' -'preisg' -'preises_' -'predecessors_' -'predat' -'precondition' -'precis' -'prech' -'precedent_' -'praktischer_' -'prakti' -'praise_' -'pragmati' -'practise' -'prachigen_' -'pph' -'ppen' -'ppelt_' -'potenzi' 
[... remainder of this deleted subword-vocabulary data file elided: several thousand additional single-quoted entries (lowercase and capitalized English and German word pieces, proper nouns, numerals, punctuation, and single non-Latin characters), with umlauts mojibake-damaged in the extracted text (e.g. 'Ãļ' for 'ö', 'Ãŧ' for 'ü') ...]
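The elided entries follow the tensor2tensor-style subword-vocabulary format that Trax tokenizers read: one token per line, wrapped in single quotes, with a trailing underscore marking a word-final piece. As a minimal sketch only (assuming that format; the file name and helper below are hypothetical and not part of this patch), such a file can be inspected like this:

    # Illustrative sketch: parse a tensor2tensor-style subword vocab file.
    # Assumes one single-quoted token per line; a trailing '_' marks a
    # word-final piece (e.g. 'potent_'). Not part of the patch above.
    def load_subword_vocab(path):
        tokens = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                line = line.rstrip('\n')
                if len(line) >= 2 and line.startswith("'") and line.endswith("'"):
                    line = line[1:-1]  # strip the surrounding single quotes
                tokens.append(line)
        return tokens

    # Hypothetical usage:
    # vocab = load_subword_vocab('vocab.translate_ende_wmt32k.32768.subwords')
    # word_final_pieces = [t for t in vocab if t.endswith('_')]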
diff --git a/trax/models/research/transformer2_test.py b/trax/models/research/transformer2_test.py deleted file mode 100644 index 18a10c5d3..000000000 --- a/trax/models/research/transformer2_test.py +++ /dev/null @@ -1,377 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Transformer models.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models.research import transformer2 - - -class Transformer2Test(absltest.TestCase): - -  def test_concat_with_padding(self): -    vec_e = np.array( -        [[[7, 5, 2, 8, 8, 8, 6, 7], -          [8, 2, 6, 2, 1, 1, 4, 2], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0]], - -         [[4, 3, 1, 7, 5, 6, 2, 1], -          [6, 9, 9, 4, 1, 3, 2, 1], -          [3, 8, 2, 4, 7, 9, 4, 1], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0]]] -    ) - -    # vec_e[:,:,0] != 0 -    mask_e = np.array([[True, True, False, False, False, False], -                       [True, True, True, False, False, False]]) - -    vec_d = np.array( -        [[[4, 7, 7, 4, 8, 9, 9, 9], -          [6, 8, 2, 9, 3, 6, 6, 8], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0]], - -         [[3, 7, 5, 6, 2, 9, 3, 1], -          [4, 7, 3, 2, 1, 1, 1, 6], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0]]] -    ) - -    layer = transformer2.ConcatWithPadding(mode='train') -    inp = (vec_e, vec_d, mask_e, vec_e, vec_d)  # tok_e = vec_e, tok_d = vec_d -    layer.init(shapes.signature(inp)) -    y, _, _ = layer(inp) - -    np.testing.assert_equal( -        y, -        np.array( -            [[[7, 5, 2, 8, 8, 8, 6, 7], -              [8, 2, 6, 2, 1, 1, 4, 2], -              [4, 7, 7, 4, 8, 9, 9, 9], -              [6, 8, 2, 9, 3, 6, 6, 8], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0]], - -             [[4, 3, 1, 7, 5, 6, 2, 1], -              [6, 9, 9, 4, 1, 3, 2, 1], -              [3, 8, 2, 4, 7, 9, 4, 1], -              [3, 7, 5, 6, 2, 9, 3, 1], -              [4, 7, 3, 2, 1, 1, 1, 6], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0], -              [0, 0, 0, 0, 0, 0, 0, 0]]] -        ) -    ) - -  def test_concat_with_padding_predict(self): -    vec_e = np.array( -        [[[7, 5, 2, 8, 8, 8, 6, 7], -          [8, 2, 6, 2, 1, 1, 4, 2], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0]], - -         [[4, 3, 1, 7, 5, 6, 2, 1], -          [6, 9, 9, 4, 1, 3, 2, 1], -          [3, 8, 2, 4, 7, 9, 4, 1], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0], -          [0, 0, 0, 0, 0, 0, 0, 0]]]
- ) - - # vec_e[:,:,0] != 0 - mask_e = np.array([[True, True, False, False, False, False], - [True, True, True, False, False, False]]) - - vec_d = np.array( - [[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - layer = transformer2.ConcatWithPadding(mode='predict') - inp = (vec_e, vec_d, mask_e, vec_e, vec_d) # tok_e = vec_e, tok_d = vec_d - _, _ = layer.init(shapes.signature(inp)) - y, _, _ = layer(inp) - - np.testing.assert_equal( - y, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - ) - - # On subsequent runs however, we should get vec_d only. - for _ in range(2): - y, _, _ = layer(inp) - np.testing.assert_equal(y, vec_d) - - def test_concat_with_padding2(self): - vec_e = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - # vec_e[:,:,0] != 0 - mask_e = np.array([[True, True, False, False, False, False], - [True, True, True, False, False, False]]) - - vec_d = np.array( - [[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - layer = transformer2.ConcatWithPadding2(mode='train') - inp = (vec_e, vec_e, vec_d, mask_e, vec_e, vec_d) - layer.init(shapes.signature(inp)) - y1, y2, _, _ = layer(inp) - - np.testing.assert_equal( - y1, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - ) - np.testing.assert_equal( - y2, - np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 
0]]] - ) - ) - - def test_strip_from_concatenate_with_padding(self): - enc_dec = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) - tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) - - layer = transformer2.StripFromConcatenateWithPadding( - mode='train') - inp = (enc_dec, tok_e, tok_d) - _, _ = layer.init(shapes.signature(inp)) - y = layer(inp) - - np.testing.assert_equal( - y, - np.array([[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0]]])) - - def test_strip_from_concatenate_with_padding_predict(self): - enc_dec = np.array( - [[[7, 5, 2, 8, 8, 8, 6, 7], - [8, 2, 6, 2, 1, 1, 4, 2], - [4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - - [[4, 3, 1, 7, 5, 6, 2, 1], - [6, 9, 9, 4, 1, 3, 2, 1], - [3, 8, 2, 4, 7, 9, 4, 1], - [3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]]] - ) - - tok_e = np.array([[7, 8, 0, 0, 0, 0], [4, 6, 3, 0, 0, 0]]) - tok_d = np.array([[4, 6, 0, 0], [3, 4, 1, 0]]) - - layer = transformer2.StripFromConcatenateWithPadding( - mode='predict') - inp = (enc_dec, tok_e, tok_d) - _, _ = layer.init(shapes.signature(inp)) - y = layer(inp) - - np.testing.assert_equal( - y, - np.array([[[4, 7, 7, 4, 8, 9, 9, 9], - [6, 8, 2, 9, 3, 6, 6, 8], - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0]], - [[3, 7, 5, 6, 2, 9, 3, 1], - [4, 7, 3, 2, 1, 1, 1, 6], - [4, 7, 3, 2, 1, 1, 1, 6], - [0, 0, 0, 0, 0, 0, 0, 0]]])) - - # On subsequent runs however, we should get enc_dec only. - for _ in range(2): - y = layer(inp) - np.testing.assert_equal(y, enc_dec) - - def test_transformer_noencdec_forward_shape(self): - input_vocab_size = 16 - output_vocab_size = 16 - - model = transformer2.Transformer2( - input_vocab_size, output_vocab_size, d_model=32, d_ff=64, - n_encoder_layers=2, n_decoder_layers=2, n_heads=2) - - enc_toks = np.array( - [[6, 2, 0, 0, 0, 0], - [6, 3, 7, 0, 0, 0]]) - dec_toks = np.array( - [[4, 2, 0, 0], - [8, 5, 0, 0]]) - - xs = [enc_toks, dec_toks] - _, _ = model.init(shapes.signature(xs)) - - # decoder output, decoder mask - ys = model(xs) - - # (B, L2, H) - self.assertEqual(ys[0].shape, - (dec_toks.shape[0], dec_toks.shape[1], output_vocab_size)) - - self.assertEqual(ys[1].shape, dec_toks.shape) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/rl_test.py b/trax/models/rl_test.py deleted file mode 100644 index ac0e8b4ce..000000000 --- a/trax/models/rl_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RL.""" - -from unittest import mock -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.models import rl - - -class RLTest(absltest.TestCase): - - def test_policy_forward_shape(self): - mock_dist = mock.MagicMock() - mock_dist.n_inputs = 4 - model = rl.Policy(policy_distribution=mock_dist) - x = np.ones((2, 3)) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (2, 4)) - - def test_value_forward_shape(self): - model = rl.Value() - x = np.ones((2, 3)) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (2, 1)) - - def test_policy_and_value_forward_shape(self): - mock_dist = mock.MagicMock() - mock_dist.n_inputs = 4 - model = rl.PolicyAndValue(policy_distribution=mock_dist) - x = np.ones((2, 3)) - _, _ = model.init(shapes.signature(x)) - ys = model(x) - self.assertEqual([y.shape for y in ys], [(2, 4), (2, 1)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/rnn.py b/trax/models/rnn.py index f29414a8e..e3f841bc8 100644 --- a/trax/models/rnn.py +++ b/trax/models/rnn.py @@ -19,208 +19,217 @@ from trax.fastmath import numpy as jnp -def RNNLM(vocab_size, - d_model=512, - n_layers=2, - rnn_cell=tl.LSTMCell, - rnn_cell_d_state_multiplier=2, - dropout=0.1, - mode='train'): - """Returns an RNN language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Embedding depth throughout the model. - n_layers: Number of RNN layers. - rnn_cell: Type of RNN cell; must be a subclass of `Layer`. - rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell - state. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout. - mode: If `'predict'`, use fast inference; if `'train'` apply dropout. - - Returns: - An RNN language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - - if n_layers != 2: # TODO(jonni): Remove n_layers arg, if it can't vary? - raise ValueError(f'Number of layers must be set to 2; instead got' - f' {n_layers}.') - - def MultiRNNCell(): - """Multi-layer RNN cell.""" +def RNNLM( + vocab_size, + d_model=512, + n_layers=2, + rnn_cell=tl.LSTMCell, + rnn_cell_d_state_multiplier=2, + dropout=0.1, + mode="train", +): + """Returns an RNN language model. 
+ + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Embedding depth throughout the model. + n_layers: Number of RNN layers. + rnn_cell: Type of RNN cell; must be a subclass of `Layer`. + rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell + state. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout. + mode: If `'predict'`, use fast inference; if `'train'` apply dropout. + + Returns: + An RNN language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + + if n_layers != 2: # TODO(jonni): Remove n_layers arg, if it can't vary? + raise ValueError( + f"Number of layers must be set to 2; instead got" f" {n_layers}." + ) + + def MultiRNNCell(): + """Multi-layer RNN cell.""" + return tl.Serial( + tl.Parallel(tl.Select([0]), tl.Split(n_items=n_layers)), + tl.SerialWithSideOutputs( + [rnn_cell(n_units=d_model) for _ in range(n_layers)] + ), + tl.Parallel(tl.Select([0]), tl.Concatenate(n_items=n_layers)), + ) + + zero_state = tl.MakeZeroState( # pylint: disable=no-value-for-parameter + depth_multiplier=n_layers * rnn_cell_d_state_multiplier + ) + return tl.Serial( - tl.Parallel([], tl.Split(n_items=n_layers)), - tl.SerialWithSideOutputs( - [rnn_cell(n_units=d_model) for _ in range(n_layers)]), - tl.Parallel([], tl.Concatenate(n_items=n_layers)) + tl.ShiftRight(mode=mode), + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, mode=mode), + tl.Branch(tl.Select([0]), zero_state), + tl.Scan(MultiRNNCell(), axis=1, mode=mode), + tl.Select([0], n_in=2), # Drop RNN state. + tl.Dense(vocab_size), ) - zero_state = tl.MakeZeroState( # pylint: disable=no-value-for-parameter - depth_multiplier=n_layers * rnn_cell_d_state_multiplier - ) - - return tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, mode=mode), - tl.Branch([], zero_state), - tl.Scan(MultiRNNCell(), axis=1, mode=mode), - tl.Select([0], n_in=2), # Drop RNN state. - tl.Dense(vocab_size), - ) - - -def GRULM(vocab_size=256, - d_model=512, - n_layers=1, - mode='train'): - """Returns a GRU (gated recurrent unit) language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. 
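A quick forward-shape check for the reformatted RNNLM above. Since this patch also deletes trax/models/rnn_test.py further down, the sketch below simply mirrors the coverage that the deleted test_rnnlm_forward_shape provided; the small sizes are illustrative, not defaults.

import numpy as np

from trax import shapes
from trax.models import rnn

model = rnn.RNNLM(vocab_size=20, d_model=16)  # n_layers defaults to 2 (LSTM cells)
x = np.ones((3, 28)).astype(np.int32)         # (batch, sequence_length) token IDs
_, _ = model.init(shapes.signature(x))
y = model(x)
assert y.shape == (3, 28, 20)                 # (batch, seq_len, vocab_size)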
-    d_model: Embedding depth throughout the model.
-    n_layers: Number of GRU layers.
-    mode: If `'predict'`, use fast inference (and omit the right shift).
-
-  Returns:
-    A GRU language model as a layer that maps from a tensor of tokens
-    to activations over a vocab set.
-  """
-  return tl.Serial(
-      tl.ShiftRight(mode=mode),
-      tl.Embedding(vocab_size, d_model),
-      [tl.GRU(d_model, mode=mode) for _ in range(n_layers)],
-      tl.Dense(vocab_size),
-  )
+
+def GRULM(vocab_size=256, d_model=512, n_layers=1, mode="train"):
+    """Returns a GRU (gated recurrent unit) language model.
+
+    This model performs autoregressive language modeling:
+
+    - input: rank 2 tensor representing a batch of text strings via token IDs
+      plus padding markers; shape is (batch_size, sequence_length). The tensor
+      elements are integers in `range(vocab_size)`, and `0` values mark padding
+      positions.
+
+    - output: rank 3 tensor representing a batch of log-probability
+      distributions for each sequence position over possible token IDs;
+      shape is (batch_size, sequence_length, `vocab_size`).
+
+    Args:
+      vocab_size: Input vocabulary size -- each element of the input tensor
+        should be an integer in `range(vocab_size)`. These integers typically
+        represent token IDs from a vocabulary-based tokenizer.
+      d_model: Embedding depth throughout the model.
+      n_layers: Number of GRU layers.
+      mode: If `'predict'`, use fast inference (and omit the right shift).
+
+    Returns:
+      A GRU language model as a layer that maps from a tensor of tokens
+      to activations over a vocab set.
+    """
+    return tl.Serial(
+        tl.ShiftRight(mode=mode),
+        tl.Embedding(vocab_size, d_model),
+        [tl.GRU(d_model, mode=mode) for _ in range(n_layers)],
+        tl.Dense(vocab_size),
+    )
 
 
 # TODO(jonni): Decide names (here and Transformer): input/source, output/target
 # TODO(jonni): Align with Transformer: (attention-)dropout, n-(attention-)heads
-def LSTMSeq2SeqAttn(input_vocab_size=256,
-                    target_vocab_size=256,
-                    d_model=512,
-                    n_encoder_layers=2,
-                    n_decoder_layers=2,
-                    n_attention_heads=1,
-                    attention_dropout=0.0,
-                    mode='train'):
-  """Returns an LSTM sequence-to-sequence model with attention.
-
-  This model is an encoder-decoder that performs tokenized string-to-string
-  ("source"-to-"target") transduction:
-
-    - inputs (2):
-
-        - source: rank 2 tensor representing a batch of text strings via token
-          IDs plus padding markers; shape is (batch_size, sequence_length). The
-          tensor elements are integers in `range(input_vocab_size)`, and `0`
-          values mark padding positions.
-
-        - target: rank 2 tensor representing a batch of text strings via token
-          IDs plus padding markers; shape is (batch_size, sequence_length). The
-          tensor elements are integers in `range(output_vocab_size)`, and `0`
-          values mark padding positions.
-
-    - output: rank 3 tensor representing a batch of log-probability
-      distributions for each sequence position over possible token IDs;
-      shape is (batch_size, sequence_length, `vocab_size`).
-
-  An example use would be to translate (tokenized) sentences from English to
-  German.
-
-  The model works as follows:
-
-  * Input encoder runs on the input tokens and creates activations that
-    are used as both keys and values in attention.
-  * Pre-attention decoder runs on the targets and creates
-    activations that are used as queries in attention.
-  * Attention runs on the queries, keys and values masking out input padding.
-  * Decoder runs on the result, followed by a cross-entropy loss.
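The encoder/pre-attention-decoder/attention data flow described in the bullets above can be exercised end to end. The following sketch mirrors test_lstmseq2seqattn_forward_shape from the rnn_test.py that this patch removes below; the sizes are the deleted test's values, not defaults.

import numpy as np

from trax import shapes
from trax.models import rnn

model = rnn.LSTMSeq2SeqAttn(
    input_vocab_size=20, target_vocab_size=20, d_model=16)
x = np.ones((3, 28)).astype(np.int32)
_, _ = model.init([shapes.signature(x), shapes.signature(x)])  # (source, target)
ys = model([x, x])
# Log-probabilities over the target vocab, plus the pass-through target tokens.
assert [y.shape for y in ys] == [(3, 28, 20), (3, 28)]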
- - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - target_vocab_size: Target vocabulary size. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - n_encoder_layers: Number of LSTM layers in the encoder. - n_decoder_layers: Number of LSTM layers in the decoder after attention. - n_attention_heads: Number of attention heads. - attention_dropout: Stochastic rate (probability) for dropping an activation - value when applying dropout within an attention block. - mode: If `'predict'`, use fast inference. If `'train'`, each attention block - will include dropout; else, it will pass all values through unaltered. - - Returns: - An LSTM sequence-to-sequence model as a layer that maps from a - source-target tokenized text pair to activations over a vocab set. - """ - input_encoder = tl.Serial( - tl.Embedding(input_vocab_size, d_model), - [tl.LSTM(d_model) for _ in range(n_encoder_layers)], - ) - - pre_attention_decoder = tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(target_vocab_size, d_model), - tl.LSTM(d_model, mode=mode), - ) - - def PrepareAttentionInputs(): - """Layer that prepares queries, keys, values and mask for attention.""" - def F(encoder_activations, decoder_activations, input_tokens): - keys = values = encoder_activations - queries = decoder_activations - # Mask is 1 where inputs are not padding (0) and 0 where they are padding. - mask = (input_tokens != 0) - # We need to add axes to the mask for attention heads and decoder length. - mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1])) - # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len]. - mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1)) - mask = mask.astype(jnp.float32) - return queries, keys, values, mask - return tl.Fn('PrepareAttentionInputs', F, n_out=4) - - return tl.Serial( # in-toks, target-toks - tl.Select([0, 1, 0, 1]), # in-toks, target-toks, in-toks, target-toks - tl.Parallel(input_encoder, pre_attention_decoder), - PrepareAttentionInputs(), # q, k, v, mask, target-toks - tl.Residual( - tl.AttentionQKV(d_model, n_heads=n_attention_heads, - dropout=attention_dropout, mode=mode, - cache_KV_in_predict=True) - ), # decoder-vecs, mask, target-toks - tl.Select([0, 2]), # decoder-vecs, target-toks - [tl.LSTM(d_model, mode=mode) for _ in range(n_decoder_layers)], - tl.Dense(target_vocab_size), - tl.LogSoftmax() - ) +def LSTMSeq2SeqAttn( + input_vocab_size=256, + target_vocab_size=256, + d_model=512, + n_encoder_layers=2, + n_decoder_layers=2, + n_attention_heads=1, + attention_dropout=0.0, + mode="train", +): + """Returns an LSTM sequence-to-sequence model with attention. + + This model is an encoder-decoder that performs tokenized string-to-string + ("source"-to-"target") transduction: + + - inputs (2): + + - source: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(input_vocab_size)`, and `0` + values mark padding positions. + + - target: rank 2 tensor representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length). The + tensor elements are integers in `range(output_vocab_size)`, and `0` + values mark padding positions. 
+ + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + An example use would be to translate (tokenized) sentences from English to + German. + + The model works as follows: + + * Input encoder runs on the input tokens and creates activations that + are used as both keys and values in attention. + * Pre-attention decoder runs on the targets and creates + activations that are used as queries in attention. + * Attention runs on the queries, keys and values masking out input padding. + * Decoder runs on the result, followed by a cross-entropy loss. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + target_vocab_size: Target vocabulary size. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + n_encoder_layers: Number of LSTM layers in the encoder. + n_decoder_layers: Number of LSTM layers in the decoder after attention. + n_attention_heads: Number of attention heads. + attention_dropout: Stochastic rate (probability) for dropping an activation + value when applying dropout within an attention block. + mode: If `'predict'`, use fast inference. If `'train'`, each attention block + will include dropout; else, it will pass all values through unaltered. + + Returns: + An LSTM sequence-to-sequence model as a layer that maps from a + source-target tokenized text pair to activations over a vocab set. + """ + input_encoder = tl.Serial( + tl.Embedding(input_vocab_size, d_model), + [tl.LSTM(d_model) for _ in range(n_encoder_layers)], + ) + + pre_attention_decoder = tl.Serial( + tl.ShiftRight(mode=mode), + tl.Embedding(target_vocab_size, d_model), + tl.LSTM(d_model, mode=mode), + ) + + def PrepareAttentionInputs(): + """Layer that prepares queries, keys, values and mask for attention.""" + + def F(encoder_activations, decoder_activations, input_tokens): + keys = values = encoder_activations + queries = decoder_activations + # Mask is 1 where inputs are not padding (0) and 0 where they are padding. + mask = input_tokens != 0 + # We need to add axes to the mask for attention heads and decoder length. + mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1])) + # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len]. + mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1)) + mask = mask.astype(jnp.float32) + return queries, keys, values, mask + + return tl.Fn("PrepareAttentionInputs", F, n_out=4) + + return tl.Serial( # in-toks, target-toks + tl.Select([0, 1, 0, 1]), # in-toks, target-toks, in-toks, target-toks + tl.Parallel(input_encoder, pre_attention_decoder), + PrepareAttentionInputs(), # q, k, v, mask, target-toks + tl.Residual( + tl.AttentionQKV( + d_model, + n_heads=n_attention_heads, + dropout=attention_dropout, + mode=mode, + cache_KV_in_predict=True, + ) + ), # decoder-vecs, mask, target-toks + tl.Select([0, 2]), # decoder-vecs, target-toks + [tl.LSTM(d_model, mode=mode) for _ in range(n_decoder_layers)], + tl.Dense(target_vocab_size), + tl.LogSoftmax(), + ) diff --git a/trax/models/rnn_test.py b/trax/models/rnn_test.py deleted file mode 100644 index 6de04bea2..000000000 --- a/trax/models/rnn_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. 
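A worked example of the mask preparation inside PrepareAttentionInputs above, written in plain numpy (fastmath.numpy behaves the same here); the token values and lengths are hypothetical.

import numpy as np

input_tokens = np.array([[7, 8, 0, 0]])   # (batch=1, encoder_len=4); 0 marks padding
decoder_len = 3                           # assumed target length
mask = input_tokens != 0                  # (1, 4): True at real tokens
mask = np.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))  # (1, 1, 1, 4)
mask = mask + np.zeros((1, 1, decoder_len, 1))                 # broadcast: (1, 1, 3, 4)
mask = mask.astype(np.float32)
# Result: one mask row per decoder position, zeroed over encoder padding.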
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RNNs.""" - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -from trax.models import rnn - -BACKENDS = [fastmath.Backend.JAX] - - -@parameterized.named_parameters( - ('_' + b.value, b) for b in BACKENDS) -class RNNTest(parameterized.TestCase): - - def test_rnnlm_forward_shape(self, backend): - with fastmath.use_backend(backend): - model = rnn.RNNLM(vocab_size=20, d_model=16) - x = np.ones((3, 28)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 28, 20)) - - def test_grulm_forward_shape(self, backend): - with fastmath.use_backend(backend): - model = rnn.GRULM(vocab_size=20, d_model=16) - x = np.ones((3, 28)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 28, 20)) - - def test_lstmseq2seqattn_forward_shape(self, backend): - with fastmath.use_backend(backend): - model = rnn.LSTMSeq2SeqAttn( - input_vocab_size=20, target_vocab_size=20, d_model=16) - x = np.ones((3, 28)).astype(np.int32) - _, _ = model.init([shapes.signature(x), shapes.signature(x)]) - ys = model([x, x]) - self.assertEqual([y.shape for y in ys], [(3, 28, 20), (3, 28)]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/transformer.py b/trax/models/transformer.py index ed9917baa..be3149aee 100644 --- a/trax/models/transformer.py +++ b/trax/models/transformer.py @@ -21,596 +21,607 @@ from trax import layers as tl - # Defaults used across Transformer variants. -MODE = 'train' +MODE = "train" D_MODEL = 512 D_FF = 2048 N_LAYERS = 6 N_HEADS = 8 MAX_SEQUENCE_LENGTH = 2048 -DROPOUT_RATE = .1 +DROPOUT_RATE = 0.1 DROPOUT_SHARED_AXES = None FF_ACTIVATION_TYPE = tl.Relu -def TransformerEncoder(vocab_size, - n_classes=10, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer encoder suitable for N-way classification. - - This model maps tokenized text to N-way (``n_classes``) activations: - - - input: Array representing a batch of text strings via token IDs plus - padding markers; shape is (batch_size, sequence_length), where - sequence_length <= ``max_len``. Array elements are integers in - ``range(vocab_size)``, and 0 values mark padding positions. - - - output: Array representing a batch of raw (non-normalized) activations - over ``n_classes`` categories; shape is (batch_size, ``n_classes``). - - Args: - vocab_size: Input vocabulary size -- each element of the input array - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - n_classes: Last/innermost dimension of output arrays, suitable for N-way - classification. 
- d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of encoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder blocks. The same rate is also - used for attention dropout in encoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each encoder block will include dropout; else, it - will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. - - Returns: - A Transformer model that maps strings (conveyed by token IDs) to - raw (non-normalized) activations over a range of output classes. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _EncBlock(): - return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - tl.Branch([], tl.PaddingMask()), # Creates masks from copy of the tokens. - tl.Embedding(vocab_size, d_model), - _Dropout(), - tl.PositionalEncoding(max_len=max_len), - [_EncBlock() for _ in range(n_layers)], - tl.Select([0], n_in=2), # Drops the masks. - tl.LayerNorm(), - tl.Mean(axis=1), - tl.Dense(n_classes), - ) - - -def TransformerDecoder(vocab_size=None, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer decoder. - - This model maps sequential inputs to sequential outputs: - - - input if ``vocab_size`` is specified: array representing a batch - of text strings via token IDs plus padding markers; shape is - (batch_size, sequence_length). The tensor elements are integers in - ``range(vocab_size)``, and 0 values mark padding positions. - - - input if ``vocab_size`` is ``None``: 3-D array representing a batch of - sequences of activation vectors; shape is (batch_size, sequence_length, - ``d_model``). - - - output: 3-D array with shape (batch_size, sequence_length, ``d_model``). - - The model uses causal attention and does *not* shift the input to the right. - Thus, the output for position `t` is based on inputs up to and including - position `t`. - - Args: - vocab_size: If specified, gives the input vocabulary size -- each element - of the input tensor should be an integer in ``range(vocab_size)``. - If ``None``, indicates that the model expects as input sequences of - floating point vectors, each with ``d_model`` components. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. 
- n_layers: Number of decoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also - used for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each encoder block will include dropout; else, it - will pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. - - Returns: - If ``vocab_size`` is defined: a Transformer model that maps strings - (conveyed by token IDs) to sequences of activation vectors. - - If ``vocab_size`` is ``None``: a Transformer model that maps sequences of - activation vectors to sequences of activation vectors. - """ - def _EmbeddingOrDense(): - return (tl.Embedding(vocab_size, d_model) if vocab_size is not None - else tl.Dense(d_model)) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _DecBlock(): - return _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - _EmbeddingOrDense(), - _Dropout(), - tl.PositionalEncoding(max_len=max_len), - [_DecBlock() for _ in range(n_layers)], - tl.LayerNorm(), - ) - - -def TransformerLM(vocab_size, - d_model=D_MODEL, - d_ff=D_FF, - n_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: Array representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). Array - elements are integers in ``range(vocab_size)``, and 0 values mark padding - positions. - - - output: 3-D array of raw activations with last/innermost dimension of - ``vocab_size``, suitable for decoding into a batch of token strings; - shape is (batch_size, sequence_length, ``vocab_size``). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input array - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_layers: Number of decoder blocks. Each block includes attention, dropout, - residual, layer-norm, feedforward (:py:class:`Dense`), and activation - layers. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also - used for attention dropout in decoder blocks. 
- dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'predict'``, use fast inference. If ``'train'``, each decoder - block will include dropout; else, it will pass all values through - unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of :py:class:`Layer`. - - Returns: - A Transformer language model that maps strings (represented as token ID - sequences) to sequences of raw (non-normalized) activation vectors; each - vector in the sequence can be mapped (e.g., by `argmax`) to a token ID. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _DecBlock(): - return _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - return tl.Serial( - tl.ShiftRight(mode=mode), - tl.Embedding(vocab_size, d_model), - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=mode), - [_DecBlock() for _ in range(n_layers)], - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -def Transformer(input_vocab_size, - output_vocab_size=None, - d_model=D_MODEL, - d_ff=D_FF, - n_encoder_layers=N_LAYERS, - n_decoder_layers=N_LAYERS, - n_heads=N_HEADS, - max_len=MAX_SEQUENCE_LENGTH, - dropout=DROPOUT_RATE, - dropout_shared_axes=DROPOUT_SHARED_AXES, - mode=MODE, - ff_activation=FF_ACTIVATION_TYPE): - """Returns a full Transformer model. - - This model is an encoder-decoder that performs tokenized string-to-string - ("source"-to-"target") transduction: - - - inputs (2): - - - source: Array representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length), - where sequence_length <= ``max_len``. Array elements are integers in - ``range(input_vocab_size)``, and 0 values mark padding positions. - - - target: Array representing a batch of text strings via token - IDs plus padding markers; shape is (batch_size, sequence_length), - where sequence_length <= ``max_len``. Array elements are integers in - ``range(output_vocab_size)``, and 0 values mark padding positions. - - - output: 3-D array of raw activations with last/innermost dimension of - ``output_vocab_size``, suitable for decoding into a batch of token - strings; shape is (batch_size, sequence_length, ``vocab_size``). - - An example use would be to translate (tokenized) sentences from English to - German. - - Args: - input_vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in ``range(vocab_size)``. These integers typically - represent token IDs from a vocabulary-based tokenizer. - output_vocab_size: If specified, gives the vocabulary size for the targets; - if ``None``, then input and target integers (token IDs) are assumed to - come from the same vocabulary. - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each encoder block. - n_encoder_layers: Number of encoder blocks. - n_decoder_layers: Number of decoder blocks. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. 
- dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. The same rate is - also used for attention dropout in encoder/decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'predict'``, use fast inference. If ``'train'``, each - encoder/decoder block will include dropout; else, it will pass all - values through unaltered. - ff_activation: Type of activation function at the end of each - encoder/decoder block; must be an activation-type subclass of - :py:class:`Layer`. - - Returns: - A Transformer model as a layer that maps from a source-target tokenized - text pair to activations over a vocab set. - """ - # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise. - encoder_mode = 'eval' if mode == 'predict' else mode - - # Share embedding weights if no separate output vocab size. - in_embedder = tl.Embedding(input_vocab_size, d_model) - if output_vocab_size is None: - out_embedder = in_embedder - output_vocab_size = input_vocab_size - else: - out_embedder = tl.Embedding(output_vocab_size, d_model) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _EncBlock(): - return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation) - - def _Encoder(): - encoder = tl.Serial( - in_embedder, +def TransformerEncoder( + vocab_size, + n_classes=10, + d_model=D_MODEL, + d_ff=D_FF, + n_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a Transformer encoder suitable for N-way classification. + + This model maps tokenized text to N-way (``n_classes``) activations: + + - input: Array representing a batch of text strings via token IDs plus + padding markers; shape is (batch_size, sequence_length), where + sequence_length <= ``max_len``. Array elements are integers in + ``range(vocab_size)``, and 0 values mark padding positions. + + - output: Array representing a batch of raw (non-normalized) activations + over ``n_classes`` categories; shape is (batch_size, ``n_classes``). + + Args: + vocab_size: Input vocabulary size -- each element of the input array + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + n_classes: Last/innermost dimension of output arrays, suitable for N-way + classification. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_layers: Number of encoder blocks. Each block includes attention, dropout, + residual, layer-norm, feedforward (:py:class:`Dense`), and activation + layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder blocks. The same rate is also + used for attention dropout in encoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. 
+ Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each encoder block will include dropout; else, it + will pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of :py:class:`Layer`. + + Returns: + A Transformer model that maps strings (conveyed by token IDs) to + raw (non-normalized) activations over a range of output classes. + """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _EncBlock(): + return _EncoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + return tl.Serial( + tl.Branch([], tl.PaddingMask()), # Creates masks from copy of the tokens. + tl.Embedding(vocab_size, d_model), + _Dropout(), + tl.PositionalEncoding(max_len=max_len), + [_EncBlock() for _ in range(n_layers)], + tl.Select([0], n_in=2), # Drops the masks. + tl.LayerNorm(), + tl.Mean(axis=1), + tl.Dense(n_classes), + ) + + +def TransformerDecoder( + vocab_size=None, + d_model=D_MODEL, + d_ff=D_FF, + n_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a Transformer decoder. + + This model maps sequential inputs to sequential outputs: + + - input if ``vocab_size`` is specified: array representing a batch + of text strings via token IDs plus padding markers; shape is + (batch_size, sequence_length). The tensor elements are integers in + ``range(vocab_size)``, and 0 values mark padding positions. + + - input if ``vocab_size`` is ``None``: 3-D array representing a batch of + sequences of activation vectors; shape is (batch_size, sequence_length, + ``d_model``). + + - output: 3-D array with shape (batch_size, sequence_length, ``d_model``). + + The model uses causal attention and does *not* shift the input to the right. + Thus, the output for position `t` is based on inputs up to and including + position `t`. + + Args: + vocab_size: If specified, gives the input vocabulary size -- each element + of the input tensor should be an integer in ``range(vocab_size)``. + If ``None``, indicates that the model expects as input sequences of + floating point vectors, each with ``d_model`` components. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_layers: Number of decoder blocks. Each block includes attention, dropout, + residual, layer-norm, feedforward (:py:class:`Dense`), and activation + layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within decoder blocks. The same rate is also + used for attention dropout in decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. 
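A smoke-test sketch for the TransformerEncoder classifier wiring above. Unlike the language-model checks later in this patch, this one is not taken from a deleted test file; the sizes are assumed small values chosen only to make the shapes easy to follow.

import numpy as np

from trax import shapes
from trax.models import transformer

model = transformer.TransformerEncoder(
    vocab_size=16, n_classes=10, d_model=32, d_ff=64, n_layers=2, n_heads=2)
x = np.ones((3, 5)).astype(np.int32)   # token IDs; 0 would mark padding
_, _ = model.init(shapes.signature(x))
y = model(x)
assert y.shape == (3, 10)              # (batch, n_classes) after Mean + Dense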
+ mode: If ``'train'``, each encoder block will include dropout; else, it + will pass all values through unaltered. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of :py:class:`Layer`. + + Returns: + If ``vocab_size`` is defined: a Transformer model that maps strings + (conveyed by token IDs) to sequences of activation vectors. + + If ``vocab_size`` is ``None``: a Transformer model that maps sequences of + activation vectors to sequences of activation vectors. + """ + + def _EmbeddingOrDense(): + return ( + tl.Embedding(vocab_size, d_model) + if vocab_size is not None + else tl.Dense(d_model) + ) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _DecBlock(): + return _DecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + return tl.Serial( + _EmbeddingOrDense(), _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=encoder_mode), - [_EncBlock() for _ in range(n_encoder_layers)], + tl.PositionalEncoding(max_len=max_len), + [_DecBlock() for _ in range(n_layers)], tl.LayerNorm(), ) - return tl.Cache(encoder) if mode == 'predict' else encoder - - def _EncDecBlock(): - return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation) - - # Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e - # Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d - return tl.Serial( - tl.Select([0, 1, 1]), # Copies decoder tokens for use in loss. - - # Encode. - tl.Branch([], tl.PaddingMask()), # tok_e masks tok_d tok_d - _Encoder(), - - # Decode. - tl.Select([2, 1, 0]), # Re-orders inputs: tok_d masks vec_e ..... - tl.ShiftRight(mode=mode), - out_embedder, - _Dropout(), - tl.PositionalEncoding(max_len=max_len, mode=mode), - tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... - [_EncDecBlock() for _ in range(n_decoder_layers)], - tl.LayerNorm(), - tl.Select([0], n_in=3), # Drops masks and encoding vectors. - - # Map vectors to match output vocab size. - tl.Dense(output_vocab_size), - ) - - -def _EncoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers that implements a Transformer encoder block. - - The input to the block is a pair (activations, mask) where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. The block's outputs are the same type/shape as its inputs, - so that multiple blocks can be chained together. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder blocks. The same rate is also used - for attention dropout in encoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. 
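The vocab_size=None path of TransformerDecoder above swaps the embedding for a Dense projection, so the model maps sequences of activation vectors to sequences of activation vectors. A minimal sketch under that assumption (sizes are illustrative):

import numpy as np

from trax import shapes
from trax.models import transformer

# With vocab_size=None the first layer is Dense(d_model), so inputs are
# float vectors rather than token IDs.
model = transformer.TransformerDecoder(
    vocab_size=None, d_model=32, d_ff=64, n_layers=2, n_heads=2)
x = np.ones((3, 5, 32)).astype(np.float32)
_, _ = model.init(shapes.signature(x))
y = model(x)
assert y.shape == (3, 5, 32)           # (batch, seq_len, d_model)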
- ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that act in series as a (repeatable) encoder block. - """ - def _Attention(): - return tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - return [ - tl.Residual( - tl.LayerNorm(), - _Attention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _DecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers that implements a Transformer decoder block. - - The input to the block is a pair (activations, mask) where the mask encodes - causal connections, preventing attention to future positions in the sequence. - The block's outputs are the same type/shape as its inputs, so that multiple - blocks can be chained together. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within decoder blocks. The same rate is also used - for attention dropout in decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that act in series as a (repeatable) decoder block. - """ - def _CausalAttention(): - return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, - mode=mode), - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Residual( - tl.LayerNorm(), - _CausalAttention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _EncoderDecoderBlock(d_model, - d_ff, - n_heads, - dropout, - dropout_shared_axes, - mode, - ff_activation): - """Returns a list of layers implementing a Transformer encoder-decoder block. - - The block input is a triple (decoder_activations, mask, encoder_activations) - where the mask was created from the original input token IDs to prevent - attending to padding positions for that input. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within encoder/decoder blocks. 
The same rate is - also used for attention dropout in encoder/decoder blocks. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that act in series as a (repeatable) encoder-decoder - block. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - def _AttentionQKV(): - return tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, - mode=mode, cache_KV_in_predict=True) - - def _CausalAttention(): - return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode) - - def _FFBlock(): - return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, - ff_activation) - - return [ # vec_d masks vec_e - tl.Residual( - tl.LayerNorm(), - _CausalAttention(), - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e - _AttentionQKV(), # vec_d masks vec_e - _Dropout(), - ), - tl.Residual( - tl.LayerNorm(), - _FFBlock(), - _Dropout(), - ), - ] - - -def _FeedForwardBlock(d_model, - d_ff, - dropout, - dropout_shared_axes, - mode, - activation): - """Returns a list of layers that implements a feedforward block. - - Args: - d_model: Last/innermost dimension of activation arrays at most points in - the model, including the initial embedding output. - d_ff: Last/innermost dimension of special (typically wider) - :py:class:`Dense` layer in the feedforward part of each block. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) - is a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If ``'train'``, each block will include dropout; else, it will - pass all values through unaltered. - activation: Type of activation function at the end of each block; must - be an activation-type subclass of :py:class:`Layer`. - - Returns: - A list of layers that maps vectors to vectors. - """ - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Dense(d_ff), - activation(), - _Dropout(), - tl.Dense(d_model), - ] + + +def TransformerLM( + vocab_size, + d_model=D_MODEL, + d_ff=D_FF, + n_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a Transformer language model. + + This model performs autoregressive language modeling: + + - input: Array representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). Array + elements are integers in ``range(vocab_size)``, and 0 values mark padding + positions. + + - output: 3-D array of raw activations with last/innermost dimension of + ``vocab_size``, suitable for decoding into a batch of token strings; + shape is (batch_size, sequence_length, ``vocab_size``). 
+ + This model uses only the decoder part of the overall Transformer. + + Args: + vocab_size: Input vocabulary size -- each element of the input array + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_layers: Number of decoder blocks. Each block includes attention, dropout, + residual, layer-norm, feedforward (:py:class:`Dense`), and activation + layers. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within decoder blocks. The same rate is also + used for attention dropout in decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'predict'``, use fast inference. If ``'train'``, each decoder + block will include dropout; else, it will pass all values through + unaltered. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of :py:class:`Layer`. + + Returns: + A Transformer language model that maps strings (represented as token ID + sequences) to sequences of raw (non-normalized) activation vectors; each + vector in the sequence can be mapped (e.g., by `argmax`) to a token ID. + """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _DecBlock(): + return _DecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + return tl.Serial( + tl.ShiftRight(mode=mode), + tl.Embedding(vocab_size, d_model), + _Dropout(), + tl.PositionalEncoding(max_len=max_len, mode=mode), + [_DecBlock() for _ in range(n_layers)], + tl.LayerNorm(), + tl.Dense(vocab_size), + ) + + +def Transformer( + input_vocab_size, + output_vocab_size=None, + d_model=D_MODEL, + d_ff=D_FF, + n_encoder_layers=N_LAYERS, + n_decoder_layers=N_LAYERS, + n_heads=N_HEADS, + max_len=MAX_SEQUENCE_LENGTH, + dropout=DROPOUT_RATE, + dropout_shared_axes=DROPOUT_SHARED_AXES, + mode=MODE, + ff_activation=FF_ACTIVATION_TYPE, +): + """Returns a full Transformer model. + + This model is an encoder-decoder that performs tokenized string-to-string + ("source"-to-"target") transduction: + + - inputs (2): + + - source: Array representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length), + where sequence_length <= ``max_len``. Array elements are integers in + ``range(input_vocab_size)``, and 0 values mark padding positions. + + - target: Array representing a batch of text strings via token + IDs plus padding markers; shape is (batch_size, sequence_length), + where sequence_length <= ``max_len``. Array elements are integers in + ``range(output_vocab_size)``, and 0 values mark padding positions. + + - output: 3-D array of raw activations with last/innermost dimension of + ``output_vocab_size``, suitable for decoding into a batch of token + strings; shape is (batch_size, sequence_length, ``vocab_size``). 
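A forward-shape check for the TransformerLM above, mirroring test_transformer_lm_forward_shape from the transformer_test.py that this patch deletes near the end; values come from that deleted test.

import numpy as np

from trax import shapes
from trax.models import transformer

model = transformer.TransformerLM(16, d_model=32, d_ff=64, n_layers=2, n_heads=2)
x = np.ones((3, 5)).astype(np.int32)
_, _ = model.init(shapes.signature(x))
y = model(x)
assert y.shape == (3, 5, 16)           # (batch, seq_len, vocab_size)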
+ + An example use would be to translate (tokenized) sentences from English to + German. + + Args: + input_vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in ``range(vocab_size)``. These integers typically + represent token IDs from a vocabulary-based tokenizer. + output_vocab_size: If specified, gives the vocabulary size for the targets; + if ``None``, then input and target integers (token IDs) are assumed to + come from the same vocabulary. + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each encoder block. + n_encoder_layers: Number of encoder blocks. + n_decoder_layers: Number of decoder blocks. + n_heads: Number of attention heads. + max_len: Maximum symbol length for positional encoding. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder/decoder blocks. The same rate is + also used for attention dropout in encoder/decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'predict'``, use fast inference. If ``'train'``, each + encoder/decoder block will include dropout; else, it will pass all + values through unaltered. + ff_activation: Type of activation function at the end of each + encoder/decoder block; must be an activation-type subclass of + :py:class:`Layer`. + + Returns: + A Transformer model as a layer that maps from a source-target tokenized + text pair to activations over a vocab set. + """ + # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise. + encoder_mode = "eval" if mode == "predict" else mode + + # Share embedding weights if no separate output vocab size. + in_embedder = tl.Embedding(input_vocab_size, d_model) + if output_vocab_size is None: + out_embedder = in_embedder + output_vocab_size = input_vocab_size + else: + out_embedder = tl.Embedding(output_vocab_size, d_model) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _EncBlock(): + return _EncoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + def _Encoder(): + encoder = tl.Serial( + in_embedder, + _Dropout(), + tl.PositionalEncoding(max_len=max_len, mode=encoder_mode), + [_EncBlock() for _ in range(n_encoder_layers)], + tl.LayerNorm(), + ) + return tl.Cache(encoder) if mode == "predict" else encoder + + def _EncDecBlock(): + return _EncoderDecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation + ) + + # Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e + # Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d + return tl.Serial( + tl.Select([0, 1, 1]), # Copies decoder tokens for use in loss. + # Encode. + tl.Branch([], tl.PaddingMask()), # tok_e masks tok_d tok_d + _Encoder(), + # Decode. + tl.Select([2, 1, 0]), # Re-orders inputs: tok_d masks vec_e ..... + tl.ShiftRight(mode=mode), + out_embedder, + _Dropout(), + tl.PositionalEncoding(max_len=max_len, mode=mode), + tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... 
+ [_EncDecBlock() for _ in range(n_decoder_layers)], + tl.LayerNorm(), + tl.Select([0], n_in=3), # Drops masks and encoding vectors. + # Map vectors to match output vocab size. + tl.Dense(output_vocab_size), + ) + + +def _EncoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation +): + """Returns a list of layers that implements a Transformer encoder block. + + The input to the block is a pair (activations, mask) where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. The block's outputs are the same type/shape as its inputs, + so that multiple blocks can be chained together. + + Args: + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within encoder blocks. The same rate is also used + for attention dropout in encoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`Layer`. + + Returns: + A list of layers that act in series as a (repeatable) encoder block. + """ + + def _Attention(): + return tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + def _FFBlock(): + return _FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation + ) + + return [ + tl.Residual( + tl.LayerNorm(), + _Attention(), + _Dropout(), + ), + tl.Residual( + tl.LayerNorm(), + _FFBlock(), + _Dropout(), + ), + ] + + +def _DecoderBlock( + d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation +): + """Returns a list of layers that implements a Transformer decoder block. + + The input to the block is a pair (activations, mask) where the mask encodes + causal connections, preventing attention to future positions in the sequence. + The block's outputs are the same type/shape as its inputs, so that multiple + blocks can be chained together. + + Args: + d_model: Last/innermost dimension of activation arrays at most points in + the model, including the initial embedding output. + d_ff: Last/innermost dimension of special (typically wider) + :py:class:`Dense` layer in the feedforward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within decoder blocks. The same rate is also used + for attention dropout in decoder blocks. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. 
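The full encoder-decoder Serial above takes (source tokens, target tokens) and returns output activations plus the copied decoder-side tokens. This sketch mirrors the forward-shape test in the transformer_test.py deleted by this patch:

import numpy as np

from trax import shapes
from trax.models import transformer

model = transformer.Transformer(
    input_vocab_size=16, output_vocab_size=16, d_model=32, d_ff=64,
    n_encoder_layers=2, n_decoder_layers=2, n_heads=2)
xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)]
_, _ = model.init(shapes.signature(xs))
y, _ = model(xs)                       # activations, pass-through decoder tokens
assert y.shape == (3, 5, 16)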
+      ff_activation: Type of activation function at the end of each block; must
+        be an activation-type subclass of :py:class:`Layer`.
+
+    Returns:
+      A list of layers that act in series as a (repeatable) decoder block.
+    """
+
+    def _CausalAttention():
+        return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)
+
+    def _FFBlock():
+        return _FeedForwardBlock(
+            d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation
+        )
+
+    def _Dropout():
+        return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
+
+    return [
+        tl.Residual(
+            tl.LayerNorm(),
+            _CausalAttention(),
+            _Dropout(),
+        ),
+        tl.Residual(
+            tl.LayerNorm(),
+            _FFBlock(),
+            _Dropout(),
+        ),
+    ]
+
+
+def _EncoderDecoderBlock(
+    d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation
+):
+    """Returns a list of layers implementing a Transformer encoder-decoder block.
+
+    The block input is a triple (decoder_activations, mask, encoder_activations)
+    where the mask was created from the original input token IDs to prevent
+    attending to padding positions for that input.
+
+    Args:
+      d_model: Last/innermost dimension of activation arrays at most points in
+        the model, including the initial embedding output.
+      d_ff: Last/innermost dimension of special (typically wider)
+        :py:class:`Dense` layer in the feedforward part of each block.
+      n_heads: Number of attention heads.
+      dropout: Stochastic rate (probability) for dropping an activation value
+        when applying dropout within encoder/decoder blocks. The same rate is
+        also used for attention dropout in encoder/decoder blocks.
+      dropout_shared_axes: Tensor axes on which to share a dropout mask.
+        Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``)
+        is a useful way to save memory and apply consistent masks to activation
+        vectors at different sequence positions.
+      mode: If ``'train'``, each block will include dropout; else, it will
+        pass all values through unaltered.
+      ff_activation: Type of activation function at the end of each block; must
+        be an activation-type subclass of :py:class:`Layer`.
+
+    Returns:
+      A list of layers that act in series as a (repeatable) encoder-decoder
+      block.
+    """
+
+    def _Dropout():
+        return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
+
+    def _AttentionQKV():
+        return tl.AttentionQKV(
+            d_model,
+            n_heads=n_heads,
+            dropout=dropout,
+            mode=mode,
+            cache_KV_in_predict=True,
+        )
+
+    def _CausalAttention():
+        return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)
+
+    def _FFBlock():
+        return _FeedForwardBlock(
+            d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation
+        )
+
+    return [  # vec_d masks vec_e
+        tl.Residual(
+            tl.LayerNorm(),
+            _CausalAttention(),
+            _Dropout(),
+        ),
+        tl.Residual(
+            tl.LayerNorm(),
+            tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
+            _AttentionQKV(),  # vec_d masks vec_e
+            _Dropout(),
+        ),
+        tl.Residual(
+            tl.LayerNorm(),
+            _FFBlock(),
+            _Dropout(),
+        ),
+    ]
+
+
+def _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, activation):
+    """Returns a list of layers that implements a feedforward block.
+
+    Args:
+      d_model: Last/innermost dimension of activation arrays at most points in
+        the model, including the initial embedding output.
+      d_ff: Last/innermost dimension of special (typically wider)
+        :py:class:`Dense` layer in the feedforward part of each block.
+      dropout: Stochastic rate (probability) for dropping an activation value
+        when applying dropout within a block.
+ dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) + is a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If ``'train'``, each block will include dropout; else, it will + pass all values through unaltered. + activation: Type of activation function at the end of each block; must + be an activation-type subclass of :py:class:`Layer`. + + Returns: + A list of layers that maps vectors to vectors. + """ + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + tl.Dense(d_ff), + activation(), + _Dropout(), + tl.Dense(d_model), + ] diff --git a/trax/models/transformer_test.py b/trax/models/transformer_test.py deleted file mode 100644 index 017b1d4e0..000000000 --- a/trax/models/transformer_test.py +++ /dev/null @@ -1,70 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Transformer models.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax import fastmath -from trax import shapes -from trax.layers import test_utils -from trax.models import transformer - - -class TransformerTest(parameterized.TestCase): - - def test_transformer_lm_forward_shape(self): - vocab_size = 16 - model = transformer.TransformerLM( - vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) - x = np.ones((3, 5)).astype(np.int32) - _, _ = model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (3, 5, vocab_size)) - - def _test_transformer_forward_shape(self, input_vocab_size, - output_vocab_size): - model = transformer.Transformer( - input_vocab_size, output_vocab_size, d_model=32, d_ff=64, - n_encoder_layers=2, n_decoder_layers=2, n_heads=2) - xs = [np.ones((3, 5)).astype(np.int32), np.ones((3, 5)).astype(np.int32)] - _, _ = model.init(shapes.signature(xs)) - y, _ = model(xs) - - vocab_size = output_vocab_size or input_vocab_size - self.assertEqual(y.shape, (3, 5, vocab_size)) - - @parameterized.named_parameters( - ('same_vocab', 16, None), - ('same_size', 16, 16), - ('different_size', 16, 50)) - def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): - """Run the Transformer forward and check output shape.""" - self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) - - - def test_dot_product_causal_attention_fast_inference(self): - model_fn = functools.partial( - transformer.TransformerLM, d_model=4, d_ff=8, n_layers=2, n_heads=2 - ) - test_utils.test_eval_equals_predict_discrete(model_fn) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/optimizers/__init__.py b/trax/optimizers/__init__.py index 1ec623abe..ecefb40b5 100644 --- a/trax/optimizers/__init__.py +++ b/trax/optimizers/__init__.py @@ -17,24 +17,17 @@ import gin -from trax.optimizers import adafactor -from trax.optimizers import 
adam -from trax.optimizers import base -from trax.optimizers import momentum -from trax.optimizers import rms_prop -from trax.optimizers import sm3 -from trax.optimizers import trainer -from trax.optimizers.trainer import ReversibleSerialTrainer -from trax.optimizers.trainer import Trainer +from trax.optimizers import adafactor, adam, momentum, rms_prop, sgd, sm3 def opt_configure(*args, **kwargs): - kwargs['module'] = 'trax.optimizers' - return gin.external_configurable(*args, **kwargs) + kwargs["module"] = "trax.optimizers" + return gin.external_configurable(*args, **kwargs) + # Optimizers (using upper-case names). # pylint: disable=invalid-name -SGD = opt_configure(base.SGD) +SGD = opt_configure(sgd.SGD) Momentum = opt_configure(momentum.Momentum) RMSProp = opt_configure(rms_prop.RMSProp) Adam = opt_configure(adam.Adam) diff --git a/trax/optimizers/adafactor.py b/trax/optimizers/adafactor.py index 501290246..a1fbe5d6a 100644 --- a/trax/optimizers/adafactor.py +++ b/trax/optimizers/adafactor.py @@ -20,142 +20,148 @@ class Adafactor(opt_base.Optimizer): - """Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.""" + """Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.""" - def __init__(self, - learning_rate=0.05, - factored=True, - multiply_by_parameter_scale=True, - do_clipping=True, - do_momentum=False, - momentum_in_bfloat16=False, - beta1=0.0, - decay_rate=0.8, - clipping_threshold=1.0, - weight_decay_rate=1e-5, - weight_decay_n_steps=0, - epsilon1=1e-16, - epsilon2=1e-3): - """Create the Adafactor optimizer. + def __init__( + self, + learning_rate=0.05, + factored=True, + multiply_by_parameter_scale=True, + do_clipping=True, + do_momentum=False, + momentum_in_bfloat16=False, + beta1=0.0, + decay_rate=0.8, + clipping_threshold=1.0, + weight_decay_rate=1e-5, + weight_decay_n_steps=0, + epsilon1=1e-16, + epsilon2=1e-3, + ): + """Create the Adafactor optimizer. - Adafactor is described in https://arxiv.org/abs/1804.04235. + Adafactor is described in https://arxiv.org/abs/1804.04235. - Args: - learning_rate: float: trax-provided learning rate. - factored: boolean: whether to use factored second-moment estimator for 2d - variables. - multiply_by_parameter_scale: boolean: if True, then scale provided - learning_rate by parameter norm. if False, provided learning_rate is - absolute step size. - do_clipping: whether to clip gradients; if True, set clipping_theshold. - do_momentum: whether to use momentum; if True, set beta1. - momentum_in_bfloat16: if True, store momentum in bfloat16 to save memory. - beta1: a float value between 0 and 1, enables momentum and uses extra - memory if nonzero! Off by default. - decay_rate: float: controls second-moment exponential decay schedule. - clipping_threshold: an optional float >= 1, if None no update clipping. - weight_decay_rate: rate at which to decay weights. - weight_decay_n_steps: for how many steps to decay weights (always if None) - epsilon1: Regularization constant for squared gradient. - epsilon2: Regularization constant for parameter scale. - """ - # These 4 parameters are not configurable once the class is created. - self._factored = factored - self._multiply_by_parameter_scale = multiply_by_parameter_scale - self._do_clipping = do_clipping - self._do_momentum = do_momentum - self._momentum_in_bfloat16 = momentum_in_bfloat16 - # Dynamically configurable parameters will be passed to the update function. 
-    super().__init__(
-        learning_rate=learning_rate,
-        beta1=beta1,
-        decay_rate=decay_rate,
-        clipping_threshold=clipping_threshold,
-        weight_decay_rate=weight_decay_rate,
-        weight_decay_n_steps=weight_decay_n_steps,
-        epsilon1=epsilon1,
-        epsilon2=epsilon2,
-    )
+        Args:
+          learning_rate: float: trax-provided learning rate.
+          factored: boolean: whether to use factored second-moment estimator for 2d
+            variables.
+          multiply_by_parameter_scale: boolean: if True, then scale provided
+            learning_rate by parameter norm. if False, provided learning_rate is
+            absolute step size.
+          do_clipping: whether to clip gradients; if True, set clipping_threshold.
+          do_momentum: whether to use momentum; if True, set beta1.
+          momentum_in_bfloat16: if True, store momentum in bfloat16 to save memory.
+          beta1: a float value between 0 and 1, enables momentum and uses extra
+            memory if nonzero! Off by default.
+          decay_rate: float: controls second-moment exponential decay schedule.
+          clipping_threshold: an optional float >= 1, if None no update clipping.
+          weight_decay_rate: rate at which to decay weights.
+          weight_decay_n_steps: for how many steps to decay weights (always if 0).
+          epsilon1: Regularization constant for squared gradient.
+          epsilon2: Regularization constant for parameter scale.
+        """
+        # These 4 parameters are not configurable once the class is created.
+        self._factored = factored
+        self._multiply_by_parameter_scale = multiply_by_parameter_scale
+        self._do_clipping = do_clipping
+        self._do_momentum = do_momentum
+        self._momentum_in_bfloat16 = momentum_in_bfloat16
+        # Dynamically configurable parameters will be passed to the update function.
+        super().__init__(
+            learning_rate=learning_rate,
+            beta1=beta1,
+            decay_rate=decay_rate,
+            clipping_threshold=clipping_threshold,
+            weight_decay_rate=weight_decay_rate,
+            weight_decay_n_steps=weight_decay_n_steps,
+            epsilon1=epsilon1,
+            epsilon2=epsilon2,
+        )
 
-  @staticmethod
-  def _decay_rate_pow(i, exponent=0.8):
-    """Default Adafactor second-moment decay schedule."""
-    t = jnp.array(i, jnp.float32) + 1.0
-    return 1.0 - t**(-exponent)
+    @staticmethod
+    def _decay_rate_pow(i, exponent=0.8):
+        """Default Adafactor second-moment decay schedule."""
+        t = jnp.array(i, jnp.float32) + 1.0
+        return 1.0 - t ** (-exponent)
 
-  def init(self, weights):
-    shape = weights.shape
-    slots = []
-    if self._factored and len(shape) >= 2:
-      v_row = jnp.zeros(shape[:-1], dtype=jnp.float32)
-      v_col = jnp.zeros(shape[:-2] + shape[-1:], dtype=jnp.float32)
-      slots.extend([v_row, v_col])
-    else:
-      v = jnp.zeros_like(weights)
-      slots.append(v)
-    if self._do_momentum:
-      m = jnp.zeros_like(weights)
-      if self._momentum_in_bfloat16:
-        m = m.astype(jnp.bfloat16)
-      slots.append(m)
-    return slots
+    def init(self, weights):
+        shape = weights.shape
+        slots = []
+        if self._factored and len(shape) >= 2:
+            v_row = jnp.zeros(shape[:-1], dtype=jnp.float32)
+            v_col = jnp.zeros(shape[:-2] + shape[-1:], dtype=jnp.float32)
+            slots.extend([v_row, v_col])
+        else:
+            v = jnp.zeros_like(weights)
+            slots.append(v)
+        if self._do_momentum:
+            m = jnp.zeros_like(weights)
+            if self._momentum_in_bfloat16:
+                m = m.astype(jnp.bfloat16)
+            slots.append(m)
+        return slots
 
-  def update(self, step, grads, weights, slots, opt_params):
-    updates = []
-    learning_rate = opt_params['learning_rate']
-    beta1 = opt_params['beta1']
-    decay_rate = opt_params['decay_rate']
-    clipping_threshold = opt_params['clipping_threshold']
-    weight_decay_rate = opt_params['weight_decay_rate']
-    weight_decay_n_steps = opt_params['weight_decay_n_steps']
-    weight_decay_rate = jnp.where(
-        weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
-        weight_decay_rate,
-        (weight_decay_rate * jnp.maximum(weight_decay_n_steps - step, 0.0) /
-         jnp.maximum(weight_decay_n_steps, 0.0)))
-    epsilon1 = opt_params['epsilon1']
-    epsilon2 = opt_params['epsilon2']
-    decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
-    update_scale = learning_rate
-    if self._multiply_by_parameter_scale:
-      update_scale *= jnp.maximum(
-          jnp.sqrt(jnp.mean(weights * weights)), epsilon2)
-    mixing_rate = 1.0 - decay_rate
+    def update(self, step, grads, weights, slots, opt_params):
+        updates = []
+        learning_rate = opt_params["learning_rate"]
+        beta1 = opt_params["beta1"]
+        decay_rate = opt_params["decay_rate"]
+        clipping_threshold = opt_params["clipping_threshold"]
+        weight_decay_rate = opt_params["weight_decay_rate"]
+        weight_decay_n_steps = opt_params["weight_decay_n_steps"]
+        weight_decay_rate = jnp.where(
+            weight_decay_n_steps < 1,  # if weight_decay_n_steps == 0, ignore it
+            weight_decay_rate,
+            (
+                weight_decay_rate
+                * jnp.maximum(weight_decay_n_steps - step, 0.0)
+                / jnp.maximum(weight_decay_n_steps, 0.0)
+            ),
+        )
+        epsilon1 = opt_params["epsilon1"]
+        epsilon2 = opt_params["epsilon2"]
+        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
+        update_scale = learning_rate
+        if self._multiply_by_parameter_scale:
+            update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)), epsilon2)
+        mixing_rate = 1.0 - decay_rate
 
-    grads_sqr = grads * grads
-    if self._factored and len(weights.shape) >= 2:
-      v_row = slots[0]  # In this case, the slots are (v_row, v_col, ...).
-      v_col = slots[1]
-      new_v_row = (
-          decay_rate * v_row + mixing_rate * jnp.mean(grads_sqr, axis=-1))
-      new_v_col = (
-          decay_rate * v_col + mixing_rate * jnp.mean(grads_sqr, axis=-2))
-      updates.extend([new_v_row, new_v_col])
-      row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
-      row_factor = (row_mean / (new_v_row + epsilon1))**0.5
-      col_factor = (new_v_col + epsilon1)**-0.5
-      y = (
-          grads * jnp.expand_dims(row_factor, axis=-1) *
-          jnp.expand_dims(col_factor, axis=-2))
-    else:
-      v = slots[0]  # In this case, the slots are (v, ...)
-      new_v = decay_rate * v + mixing_rate * grads_sqr
-      updates.append(new_v)
-      y = grads * (new_v + epsilon1)**-0.5
+        grads_sqr = grads * grads
+        if self._factored and len(weights.shape) >= 2:
+            v_row = slots[0]  # In this case, the slots are (v_row, v_col, ...).
+            v_col = slots[1]
+            new_v_row = decay_rate * v_row + mixing_rate * jnp.mean(grads_sqr, axis=-1)
+            new_v_col = decay_rate * v_col + mixing_rate * jnp.mean(grads_sqr, axis=-2)
+            updates.extend([new_v_row, new_v_col])
+            row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
+            row_factor = (row_mean / (new_v_row + epsilon1)) ** 0.5
+            col_factor = (new_v_col + epsilon1) ** -0.5
+            y = (
+                grads
+                * jnp.expand_dims(row_factor, axis=-1)
+                * jnp.expand_dims(col_factor, axis=-2)
+            )
+        else:
+            v = slots[0]  # In this case, the slots are (v, ...)
+            new_v = decay_rate * v + mixing_rate * grads_sqr
+            updates.append(new_v)
+            y = grads * (new_v + epsilon1) ** -0.5
 
-    if self._do_clipping:
-      clipping_denom = (
-          jnp.maximum(1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold))
-      y /= clipping_denom
+        if self._do_clipping:
+            clipping_denom = jnp.maximum(
+                1.0, jnp.sqrt(jnp.mean(y * y)) / clipping_threshold
+            )
+            y /= clipping_denom
 
-    subtrahend = update_scale * y
-    if self._do_momentum:
-      m = slots[-1]  # Momentum is always the last slot (if used).
-      m = m.astype(subtrahend.dtype)  # Accumulate in subtrahend dtype.
-      new_m = beta1 * m + (1.0 - beta1) * subtrahend
-      subtrahend = new_m
-      updates.append(new_m.astype(slots[-1].dtype))  # Back to bfloat if needed.
+        subtrahend = update_scale * y
+        if self._do_momentum:
+            m = slots[-1]  # Momentum is always the last slot (if used).
+            m = m.astype(subtrahend.dtype)  # Accumulate in subtrahend dtype.
+            new_m = beta1 * m + (1.0 - beta1) * subtrahend
+            subtrahend = new_m
+            updates.append(new_m.astype(slots[-1].dtype))  # Back to bfloat if needed.
 
-    new_weights = (1 - weight_decay_rate) * weights - subtrahend
-    # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
-    return new_weights.astype(weights.dtype), updates
+        new_weights = (1 - weight_decay_rate) * weights - subtrahend
+        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
+        return new_weights.astype(weights.dtype), updates
diff --git a/trax/optimizers/adam.py b/trax/optimizers/adam.py
index e950eab9f..d25a6155f 100644
--- a/trax/optimizers/adam.py
+++ b/trax/optimizers/adam.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 """Adam optimizer class."""
 
 from trax.fastmath import numpy as jnp
@@ -21,62 +22,72 @@
 
 # pylint: disable=line-too-long
 class Adam(opt_base.Optimizer):
-  r"""Adam optimizer; described in https://arxiv.org/abs/1412.6980.
-
-  The update rule for time step :math:`t`, given gradients :math:`g_t` and
-  "Stepsize" :math:`\alpha`, is:
-
-  .. math::
+    r"""Adam optimizer; described in https://arxiv.org/abs/1412.6980.
+
+    The update rule for time step :math:`t`, given gradients :math:`g_t` and
+    "Stepsize" :math:`\alpha`, is:
+
+    .. math::
 
         \hat{m}_t &\leftarrow \big(\beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t\big)\ /\ (1 - \beta_1^t) \\
 
         \hat{v}_t &\leftarrow \big(\beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2\big)\ /\ (1 - \beta_2^t) \\
 
         \theta_t &\leftarrow \theta_{t-1} -\ \alpha \cdot \hat{m}_t / \big(\sqrt{\hat{v}_t} + \epsilon\big)
     """
 
-  # pylint: enable=line-too-long
-  def __init__(self, learning_rate=0.0001, weight_decay_rate=1e-5,  # pylint: disable=useless-super-delegation
-               b1=0.9, b2=0.999, eps=1e-5, clip_grad_norm=None):
-    r"""Creates an Adam optimizer.
+    # pylint: enable=line-too-long
+    def __init__(
+        self,
+        learning_rate=0.0001,
+        weight_decay_rate=1e-5,  # pylint: disable=useless-super-delegation
+        b1=0.9,
+        b2=0.999,
+        eps=1e-5,
+        clip_grad_norm=None,
+    ):
+        r"""Creates an Adam optimizer.
+
+        Args:
+          learning_rate: Initial (unadapted) learning rate :math:`\alpha`; original
+            paper calls this `Stepsize` and suggests .001 as a generally good
+            value.
+          weight_decay_rate: Fraction of prior weight values to subtract on each
+            step; equivalent to multiplying each weight element by
+            `1 - weight_decay_rate`. (This is not part of the core Adam
+            algorithm.)
+          b1: Exponential decay rate :math:`\beta_1` for first moment estimates.
+          b2: Exponential decay rate :math:`\beta_2` for second moment estimates.
+          eps: Small positive constant :math:`\epsilon` for numerical stability.
+          clip_grad_norm: Threshold value above which gradient clipping occurs.
+            (This is not part of the core Adam algorithm.)
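+
+        A minimal usage sketch (illustrative, not part of the original
+        docstring; assumes `weights` and `grads` are matching jnp arrays)::
+
+            opt = Adam(learning_rate=1e-3)
+            m, v = opt.init(weights)
+            new_weights, (m, v) = opt.update(
+                step=0, grads=grads, weights=weights, slots=(m, v),
+                opt_params=opt.opt_params)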
+ """ + super().__init__( + learning_rate=learning_rate, + weight_decay_rate=weight_decay_rate, + b1=b1, + b2=b2, + eps=eps, + clip_grad_norm=clip_grad_norm, + ) + + def init(self, weights): + m = jnp.zeros_like(weights) + v = jnp.zeros_like(weights) + return m, v + + def update(self, step, grads, weights, slots, opt_params): + m, v = slots + + learning_rate = opt_params["learning_rate"] + weight_decay_rate = opt_params["weight_decay_rate"] + b1 = opt_params["b1"] + b2 = opt_params["b2"] + eps = opt_params["eps"] - Args: - learning_rate: Initial (unadapted) learning rate :math:`\alpha`; original - paper calls this `Stepsize` and suggests .001 as a generally good - value. - weight_decay_rate: Fraction of prior weight values to subtract on each - step; equivalent to multiplying each weight element by - `1 - weight_decay_rate`. (This is not part of the core Adam - algorithm.) - b1: Exponential decay rate :math:`\beta_1` for first moment estimates. - b2: Exponential decay rate :math:`\beta_2` for second moment estimates. - eps: Small positive constant :math:`\epsilon` for numerical stability. - clip_grad_norm: Threshold value above which gradient clipping occurs. - (This is not part of the core Adam algorithm.) - """ - super().__init__( - learning_rate=learning_rate, - weight_decay_rate=weight_decay_rate, - b1=b1, - b2=b2, - eps=eps, - clip_grad_norm=clip_grad_norm - ) + m = (1 - b1) * grads + b1 * m # First moment estimate. + v = (1 - b2) * (grads**2) + b2 * v # Second moment estimate. + mhat = m / (1 - b1 ** (step + 1)) # Bias correction. + vhat = v / (1 - b2 ** (step + 1)) - def init(self, weights): - m = jnp.zeros_like(weights) - v = jnp.zeros_like(weights) - return m, v + new_weights = ( + (1 - weight_decay_rate) * weights + - (learning_rate * mhat / (jnp.sqrt(vhat) + eps)) + ).astype(weights.dtype) - def update(self, step, grads, weights, slots, opt_params): - m, v = slots - learning_rate = opt_params['learning_rate'] - weight_decay_rate = opt_params['weight_decay_rate'] - b1 = opt_params['b1'] - b2 = opt_params['b2'] - eps = opt_params['eps'] - m = (1 - b1) * grads + b1 * m # First moment estimate. - v = (1 - b2) * (grads ** 2) + b2 * v # Second moment estimate. - mhat = m / (1 - b1 ** (step + 1)) # Bias correction. - vhat = v / (1 - b2 ** (step + 1)) - new_weights = ((1 - weight_decay_rate) * weights - ( - learning_rate * mhat / (jnp.sqrt(vhat) + eps))).astype(weights.dtype) - return new_weights, (m, v) + return new_weights, (m, v) diff --git a/trax/optimizers/base.py b/trax/optimizers/base.py index 269bc0a73..52100d263 100644 --- a/trax/optimizers/base.py +++ b/trax/optimizers/base.py @@ -20,234 +20,234 @@ class Optimizer: - """Base class for optimizers that work hand in hand with Trax layers. + """Base class for optimizers that work hand in hand with Trax layers. - To define an optimizer subclass, specify its behavior with respect to a - single node in the network (e.g., a single dense layer): + To define an optimizer subclass, specify its behavior with respect to a + single node in the network (e.g., a single dense layer): - - `init`: how to create/initialize optimizer-internal parameters ("slots"), - as a function of the node's weights. - - `update`: how to use gradient information to update node weights and - optimizer slots. + - `init`: how to create/initialize optimizer-internal parameters ("slots"), + as a function of the node's weights. + - `update`: how to use gradient information to update node weights and + optimizer slots. 
- The Trax runtime combines these node-local computations into layer weight - updates and optimizer slot updates for the whole tree of layers in the model. - """ - - def __init__(self, learning_rate=0.01, clip_grad_norm=None, - **init_opt_params): - """Sets initial hyperparameter values for this optimizer. - - Takes optimizer hyperparameters as keyword arguments. These values can - change over time (training steps), e.g., for learning rate schedules. - - To expose subclass hyperparameters for gin configuration, override this - constructor and use explicitly named keyword arguments. See - `momentum.Momentum.__init__` for one such example. - - Args: - learning_rate: Learning rate for the optimizer. This can change during - training by means of a training rate schedule. - clip_grad_norm: If specified, this scalar value is used to limit gradient - size -- all gradient elements in a training step are treated as if - they belonged to a single vector and then scaled back if needed so - that such a vector's L2 norm does not exceed `clip_grad_norm`. If - None, no clipping happens. - **init_opt_params: Initial values of any additional optimizer parameters. - """ - init_opt_params['learning_rate'] = learning_rate - self._init_opt_params = { - name: jnp.array(value) for (name, value) in init_opt_params.items() - } - self._slots = None - # Gradient clipping happens with respect to the norm of the whole gradient - # tree, so it is not passed to single-slot updates, but done in this class - # for the whole gradient tree. - self._clip_grad_norm = clip_grad_norm - - def init(self, weights): - """Creates optimizer slots that fit the given weights. - - Args: - weights: Trainable weights for one layer. Optimizer slots typically match - the data shape and type of the given layer weights. - """ - raise NotImplementedError - - def update(self, step, grads, weights, slots, opt_params): - """Computes updated layer weights and optimizer slots for one training step. - - Args: - step: Training step number. - grads: Gradient values for this node (from back-propagation during a - training step). - weights: Current weight values for this node (i.e., layer weights). - slots: Current slot values for this node. - opt_params: Optimizer hyperparameters (e.g. learning rate, momentum), - same across all nodes in the model. - - Returns: - Tuple of (new_weights, new_slots), which the Trax runtime will use to - update the model and optimizer within each training step. + The Trax runtime combines these node-local computations into layer weight + updates and optimizer slot updates for the whole tree of layers in the model. """ - raise NotImplementedError - - @property - def slots(self): - return self._slots - @slots.setter - def slots(self, slots): - self._slots = slots + def __init__(self, learning_rate=0.01, clip_grad_norm=None, **init_opt_params): + """Sets initial hyperparameter values for this optimizer. + + Takes optimizer hyperparameters as keyword arguments. These values can + change over time (training steps), e.g., for learning rate schedules. + + To expose subclass hyperparameters for gin configuration, override this + constructor and use explicitly named keyword arguments. See + `momentum.Momentum.__init__` for one such example. + + Args: + learning_rate: Learning rate for the optimizer. This can change during + training by means of a training rate schedule. 
+          clip_grad_norm: If specified, this scalar value is used to limit gradient
+            size -- all gradient elements in a training step are treated as if
+            they belonged to a single vector and then scaled back if needed so
+            that such a vector's L2 norm does not exceed `clip_grad_norm`. If
+            None, no clipping happens.
+          **init_opt_params: Initial values of any additional optimizer parameters.
+        """
+        init_opt_params["learning_rate"] = learning_rate
+        self._init_opt_params = {
+            name: jnp.array(value) for (name, value) in init_opt_params.items()
+        }
+        self._slots = None
+        # Gradient clipping happens with respect to the norm of the whole gradient
+        # tree, so it is not passed to single-slot updates, but done in this class
+        # for the whole gradient tree.
+        self._clip_grad_norm = clip_grad_norm
+
+    def init(self, weights):
+        """Creates optimizer slots that fit the given weights.
+
+        Args:
+          weights: Trainable weights for one layer. Optimizer slots typically match
+            the data shape and type of the given layer weights.
+        """
+        raise NotImplementedError
+
+    def update(self, step, grads, weights, slots, opt_params):
+        """Computes updated layer weights and optimizer slots for one training step.
+
+        Args:
+          step: Training step number.
+          grads: Gradient values for this node (from back-propagation during a
+            training step).
+          weights: Current weight values for this node (i.e., layer weights).
+          slots: Current slot values for this node.
+          opt_params: Optimizer hyperparameters (e.g. learning rate, momentum),
+            same across all nodes in the model.
+
+        Returns:
+          Tuple of (new_weights, new_slots), which the Trax runtime will use to
+          update the model and optimizer within each training step.
+        """
+        raise NotImplementedError
+
+    @property
+    def slots(self):
+        return self._slots
+
+    @slots.setter
+    def slots(self, slots):
+        self._slots = slots
+
+    @property
+    def opt_params(self):
+        return self._init_opt_params
+
+    @opt_params.setter
+    def opt_params(self, opt_params):
+        self._init_opt_params = opt_params
+
+    def tree_init(self, weight_tree):
+        """Assembles node-local initializations into full-tree initialization.
+
+        Args:
+          weight_tree: Weights for an entire model, in a tree that matches the
+            model's layer structure.
+
+        Returns:
+          Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer
+          slot values and `opt_params` are optimizer hyperparameters (e.g.,
+          learning rate, momentum).
+        """
+        self._slots = tuple(
+            self.init(weight) for weight in fastmath.tree_flatten(weight_tree)
+        )
+        return (self._slots, self._init_opt_params)
+
+    def tree_update(
+        self, step, grad_tree, weight_tree, slots, opt_params, store_slots=True
+    ):
+        """Assembles node-local weight and slot updates for the full layer tree.
+
+        Args:
+          step: Current step number in the training process.
+          grad_tree: Gradients for the entire model, in a tree that matches the
+            model's layer structure.
+          weight_tree: Current weights for the entire model, in a tree that matches
+            the model's layer structure.
+          slots: Optimizer slots.
+          opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).
+          store_slots: Boolean; if True, stores resulting slots in this object;
+            when set to False, this becomes a pure function.
+
+        Returns:
+          Tuple `(weights, slots, metrics)`, where `weights` are the
+          optimizer-updated weights for the whole model (in a tree matching the
+          model's layer structure), `slots` are the updated optimizer slot values,
+          and `metrics` is a dict of diagnostics (gradient and weight L2 norms).
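+
+        A sketch of one training step using this method (illustrative; `opt`,
+        `grads`, and `weights` are assumed to already exist)::
+
+            slots, opt_params = opt.tree_init(weights)
+            weights, slots, metrics = opt.tree_update(
+                step, grads, weights, slots, opt_params)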
+ """ + grads_flat = fastmath.tree_flatten(grad_tree) + grads_norm = self._l2_norm(grads_flat) + if self._clip_grad_norm is not None: + max_norm = self._clip_grad_norm + grads_flat = [ + jnp.where( + grads_norm < max_norm, # pylint: disable=g-complex-comprehension + g, + g * (max_norm / grads_norm), + ) + for g in grads_flat + ] + weights_flat = fastmath.tree_flatten(weight_tree) + weights_norm = self._l2_norm(weights_flat) + updated_pairs = [ + self._update_and_check(step, grad, weight, slot, opt_params) + for (grad, weight, slot) in zip(grads_flat, weights_flat, slots) + ] + new_weights_flat, slots = zip(*updated_pairs) + new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree) + metrics = {"gradients_l2": grads_norm, "weights_l2": weights_norm} + slots = tuple(slots) + if store_slots: + self.slots = slots + return new_weights, slots, metrics + + def _l2_norm(self, flat_list): + """Returns an L2-like norm of all elements of all tensors in `flat_list`. + + Args: + flat_list: Collection of tensors as a flat list (rather than, e.g., a + tree). + + Returns: + A scalar value computed as if all the tensors in `flat_list` were joined + and flattened into a single vector, and then the L2 norm of that vector + was calculated. + """ + if fastmath.is_backend(fastmath.Backend.JAX): + norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list)) + else: + norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list)) + + return norm + + def _update_and_check(self, step, grads, weights, slots, opt_params): + """Updates a single weight array and checks types.""" + new_weights, new_slots = self.update(step, grads, weights, slots, opt_params) + if isinstance(weights, jnp.ndarray): + if not isinstance(new_weights, jnp.ndarray): + raise ValueError( + f"New weight values should be of type jnp.ndarray or a subclass; " + f"instead got {type(new_weights)}." + ) + if new_weights.dtype != weights.dtype: + raise ValueError( + f"New weight values dtype ({new_weights.dtype}) does not match " + f"the old one ({weights.dtype})." + ) + return new_weights, new_slots - @property - def opt_params(self): - return self._init_opt_params - @opt_params.setter - def opt_params(self, opt_params): - self._init_opt_params = opt_params - - def tree_init(self, weight_tree): - """Assembles node-local initializations into full-tree initialization. - - Args: - weight_tree: Weights for an entire model, in a tree that matches the - model's layer structure. +# Utilities. - Returns: - Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer - slot values and `opt_params` are optimizer hyperparameters (e.g., - learning rate, momentum). - """ - self._slots = tuple(self.init(weight) - for weight in fastmath.tree_flatten(weight_tree)) - return (self._slots, self._init_opt_params) - def tree_update(self, step, grad_tree, weight_tree, slots, opt_params, - store_slots=True): - """Assembles node-local weight and slot updates for the full layer tree. +def l2_norm(tree): + """Returns an L2 norm computed over all elements of all tensors in `tree`. Args: - step: Current step number in the training process. - grad_tree: Gradients for the entire model, in a tree that matches the - model's layer structure. - weight_tree: Current weights for the entire model, in a tree that matches + tree: Tree-structured collection of tensors, e.g., model weights matching the model's layer structure. - slots: Optimizer slots. - opt_params: Optimizer hyperparameters (e.g. learning rate, momentum). 
-      store_slots: Boolean; if True, stores resulting slots in this object;
-        when set to False, this becomes a pure function.
 
     Returns:
-      Tuple `(weights, slots)`, where `weights` are the optimizer-updated
-      weights for the whole model (in a tree matching the model's layer
-      structure) and `slots` are the updated optimizer slot values.
-    """
-    grads_flat = fastmath.tree_flatten(grad_tree)
-    grads_norm = self._l2_norm(grads_flat)
-    if self._clip_grad_norm is not None:
-      max_norm = self._clip_grad_norm
-      grads_flat = [jnp.where(grads_norm < max_norm,  # pylint: disable=g-complex-comprehension
-                              g,
-                              g * (max_norm / grads_norm))
-                    for g in grads_flat]
-    weights_flat = fastmath.tree_flatten(weight_tree)
-    weights_norm = self._l2_norm(weights_flat)
-    updated_pairs = [
-        self._update_and_check(step, grad, weight, slot, opt_params)
-        for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
-    ]
-    new_weights_flat, slots = zip(*updated_pairs)
-    new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree)
-    metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
-    slots = tuple(slots)
-    if store_slots:
-      self.slots = slots
-    return new_weights, slots, metrics
-
-  def _l2_norm(self, flat_list):
-    """Returns an L2-like norm of all elements of all tensors in `flat_list`.
-
-    Args:
-      flat_list: Collection of tensors as a flat list (rather than, e.g., a
-        tree).
-
-    Returns:
-      A scalar value computed as if all the tensors in `flat_list` were joined
+      A scalar value computed as if all the tensors in `tree` were combined
       and flattened into a single vector, and then the L2 norm of that vector
       was calculated.
     """
-    if fastmath.is_backend(fastmath.Backend.JAX):
-      norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
-    else:  # TODO(lukaszkaiser): add vdot to TF-numpy
-      norm = jnp.sqrt(sum(jnp.sum(x*x) for x in flat_list))
-    return norm
-
-  def _update_and_check(self, step, grads, weights, slots, opt_params):
-    """Updates a single weight array and checks types."""
-    new_weights, new_slots = self.update(
-        step, grads, weights, slots, opt_params)
-    if isinstance(weights, jnp.ndarray):
-      if not isinstance(new_weights, jnp.ndarray):
-        raise ValueError(
-            f'New weight values should be of type jnp.ndarray or a subclass; '
-            f'instead got {type(new_weights)}.')
-      if new_weights.dtype != weights.dtype:
-        raise ValueError(
-            f'New weight values dtype ({new_weights.dtype}) does not match '
-            f'the old one ({weights.dtype}).')
-    return new_weights, new_slots
-
-
-class SGD(Optimizer):
-  """Stochastic gradient descent (SGD) optimizer."""
-
-  def init(self, weights):
-    return None
-
-  def update(self, step, grads, weights, slots, opt_params):
-    del step, slots
-    lr = opt_params['learning_rate']
-    new_weights = weights - (lr * grads).astype(weights.dtype)
-    return new_weights, None
-
-
-# Utilities.
+    leaves = fastmath.tree_flatten(tree)
+    if fastmath.is_backend(fastmath.Backend.JAX):
+        norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
+    else:
+        # TF-numpy has no `vdot`; an explicit sum of squares is equivalent.
+        norm = jnp.sqrt(sum(jnp.sum(x * x) for x in leaves))
 
-def l2_norm(tree):
-  """Returns an L2 norm computed over all elements of all tensors in `tree`.
+    return norm
 
-  Args:
-    tree: Tree-structured collection of tensors, e.g., model weights matching
-      the model's layer structure.
 
     Returns:
-      A scalar value computed as if all the tensors in `tree` were combined
-      and flattened into a single vector, and then the L2 norm of that vector
-      was calculated.
- """ - leaves = fastmath.tree_flatten(tree) - return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves)) +def clip_grads(grad_tree, max_norm): + """Proportionally reduces each gradient value to respect an aggregate limit. + Args: + grad_tree: Gradient values structured as a tree of tensors matching the + model's layer structure. + max_norm: The aggregate limit on gradient values. All gradient elements in + `grad_tree` are treated as if they belonged to a single vector and + that vector is shortened if needed so that its L2 norm does not exceed + `clip_grad_norm`. -def clip_grads(grad_tree, max_norm): - """Proportionally reduces each gradient value to respect an aggregate limit. - - Args: - grad_tree: Gradient values structured as a tree of tensors matching the - model's layer structure. - max_norm: The aggregate limit on gradient values. All gradient elements in - `grad_tree` are treated as if they belonged to a single vector and - that vector is shortened if needed so that its L2 norm does not exceed - `clip_grad_norm`. - - Returns: - A new tree of tensors matching the structure of `grad_tree`, but with - element values proportionally rescaled as needed to respect the `max_norm` - limit. - """ - norm = l2_norm(grad_tree) - normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm)) - return fastmath.nested_map(grad_tree, normalize) + Returns: + A new tree of tensors matching the structure of `grad_tree`, but with + element values proportionally rescaled as needed to respect the `max_norm` + limit. + """ + norm = l2_norm(grad_tree) + normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm)) + return fastmath.nested_map(grad_tree, normalize) diff --git a/trax/optimizers/momentum.py b/trax/optimizers/momentum.py index 625e7a2b4..5318507a6 100644 --- a/trax/optimizers/momentum.py +++ b/trax/optimizers/momentum.py @@ -21,7 +21,7 @@ # TODO(jonni): Consider renaming this class to NesterovMomentum. class Momentum(base.Optimizer): - r"""A momentum optimizer. + r"""A momentum optimizer. This class implements two variants of momentum stochastic gradient descent (SGD): with and without the Nesterov correction. The implementation of the @@ -41,32 +41,32 @@ class Momentum(base.Optimizer): (:math:`\alpha`) on the parameters, independent of the Nesterov momentum. 
""" - def __init__( - self, learning_rate=0.01, mass=0.9, weight_decay_rate=1e-5, nesterov=True - ): # pylint: disable=useless-super-delegation - super().__init__( - learning_rate=learning_rate, - mass=mass, - weight_decay_rate=weight_decay_rate, - ) - self._nesterov = nesterov + def __init__( + self, learning_rate=0.01, mass=0.9, weight_decay_rate=1e-5, nesterov=True + ): # pylint: disable=useless-super-delegation + super().__init__( + learning_rate=learning_rate, + mass=mass, + weight_decay_rate=weight_decay_rate, + ) + self._nesterov = nesterov - def init(self, weights): - return jnp.zeros_like(weights) + def init(self, weights): + return jnp.zeros_like(weights) - def update(self, step, grads, weights, velocity, opt_params): - del step - v = velocity - mu = opt_params['mass'] - alpha = opt_params['weight_decay_rate'] - epsilon = opt_params['learning_rate'] + def update(self, step, grads, weights, velocity, opt_params): + del step + v = velocity + mu = opt_params["mass"] + alpha = opt_params["weight_decay_rate"] + epsilon = opt_params["learning_rate"] - new_v = mu * v + grads - if self._nesterov: - weight_update = mu * new_v + grads - else: - weight_update = new_v - new_weights = (1 - alpha) * weights - epsilon * weight_update + new_v = mu * v + grads + if self._nesterov: + weight_update = mu * new_v + grads + else: + weight_update = new_v + new_weights = (1 - alpha) * weights - epsilon * weight_update - new_weights = new_weights.astype(weights.dtype) - return (new_weights, new_v) + new_weights = new_weights.astype(weights.dtype) + return (new_weights, new_v) diff --git a/trax/optimizers/rms_prop.py b/trax/optimizers/rms_prop.py index 351d05425..40786cb5b 100644 --- a/trax/optimizers/rms_prop.py +++ b/trax/optimizers/rms_prop.py @@ -20,30 +20,32 @@ class RMSProp(opt_base.Optimizer): - """RMSProp optimizer. - - Uses optimizer weights ("slots") to maintain a root-mean-square exponentially - decaying average of gradients from prior training batches. - """ - - def __init__(self, learning_rate=0.001, gamma=0.9, - eps=1e-8, clip_grad_norm=None): # pylint: disable=useless-super-delegation - super().__init__( - learning_rate=learning_rate, - gamma=gamma, - eps=eps, - clip_grad_norm=clip_grad_norm - ) - - def init(self, weights): - return jnp.ones_like(weights) - - def update(self, step, grads, weights, avg_sq_grad, opt_params): - del step - lr = opt_params['learning_rate'] - gamma = opt_params['gamma'] - eps = opt_params['eps'] - avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma) - weights = weights - (lr * grads / - (jnp.sqrt(avg_sq_grad) + eps)).astype(weights.dtype) - return weights, avg_sq_grad + """RMSProp optimizer. + + Uses optimizer weights ("slots") to maintain a root-mean-square exponentially + decaying average of gradients from prior training batches. 
+ """ + + def __init__( + self, learning_rate=0.001, gamma=0.9, eps=1e-8, clip_grad_norm=None + ): # pylint: disable=useless-super-delegation + super().__init__( + learning_rate=learning_rate, + gamma=gamma, + eps=eps, + clip_grad_norm=clip_grad_norm, + ) + + def init(self, weights): + return jnp.ones_like(weights) + + def update(self, step, grads, weights, avg_sq_grad, opt_params): + del step + lr = opt_params["learning_rate"] + gamma = opt_params["gamma"] + eps = opt_params["eps"] + avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1.0 - gamma) + weights = weights - (lr * grads / (jnp.sqrt(avg_sq_grad) + eps)).astype( + weights.dtype + ) + return weights, avg_sq_grad diff --git a/trax/optimizers/sgd.py b/trax/optimizers/sgd.py new file mode 100644 index 000000000..f3b84bcb7 --- /dev/null +++ b/trax/optimizers/sgd.py @@ -0,0 +1,18 @@ +from trax.fastmath import numpy as jnp +from trax.optimizers import base as opt_base + + +class SGD(opt_base.Optimizer): + """Stochastic gradient descent (SGD) optimizer.""" + + def init(self, weights): + return None + + def update(self, step, grads, weights, slots, opt_params): + del step, slots + lr = opt_params["learning_rate"] + new_weights = jnp.subtract( + weights, jnp.multiply(lr, grads).astype(weights.dtype) + ) + + return new_weights, None diff --git a/trax/optimizers/sm3.py b/trax/optimizers/sm3.py index e716bd32c..2a101ffd9 100644 --- a/trax/optimizers/sm3.py +++ b/trax/optimizers/sm3.py @@ -22,173 +22,181 @@ class MomentumType(enum.IntEnum): - EMA = 1 - HEAVY_BALL = 2 - NESTEROV = 3 + EMA = 1 + HEAVY_BALL = 2 + NESTEROV = 3 class SM3(opt_base.Optimizer): - """SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.""" - - def __init__(self, - learning_rate=0.01, - momentum=0.9, - second_moment_averaging=1.0, - weight_decay=0.0, - momentum_type=MomentumType.EMA): # pylint: disable=useless-super-delegation - """Create the SM3 optimizer. - - Memory-Efficient Adaptive Optimization. - https://arxiv.org/abs/1901.11150 - - Args: - learning_rate: a postitive scalar value for the initial learning rate. - momentum: optional, a positive scalar value for momentum - second_moment_averaging: averaging of second moments (if 1.0, adds from - begining of time like AdaGrad). - weight_decay: Weight decay for regularizing the model. - momentum_type: Nestrov, Heavy-Ball or EMA (Default). - - """ - self._has_momentum = momentum > 0.0 - self._momentum_type = momentum_type - self._graft = second_moment_averaging != 1.0 - super().__init__( - learning_rate=learning_rate, - momentum=momentum, - second_moment_averaging=second_moment_averaging, - weight_decay=weight_decay, - ) - - def init(self, w): - momentum = [] - if self._has_momentum: - momentum = jnp.zeros_like(w) - v1s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape] - v2s = [] - if self._graft: - v2s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape] - return (momentum, v1s, v2s) - - def _momentum_update(self, g, m, beta1): - """Handle various types of momentum.""" - if self._momentum_type == MomentumType.EMA: - m = (1 - beta1) * g + beta1 * m - update = m - elif self._momentum_type == MomentumType.HEAVY_BALL: - m = g + beta1 * m - update = m - elif self._momentum_type == MomentumType.NESTEROV: - m = g + beta1 * m - nesterov_m = g + beta1 * m - update = nesterov_m - else: - assert False, 'Unknown momentum_type.' 
- return m, update - - def _update_diagonal(self, g, w, m, v1, v2, opt_params): - learning_rate = opt_params['learning_rate'] - beta2 = opt_params['second_moment_averaging'] - weight_decay = opt_params['weight_decay'] - - is_beta2_1 = (beta2 == 1).astype(g.dtype) - one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) - v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g - - preconditioner = jnp.where(v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16), - jnp.zeros_like(v1[0])) - - pg = preconditioner * g - if self._graft: - v2[0] += g * g - preconditioner_graft = jnp.where( - v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0])) - pg_graft = preconditioner_graft * g - pg_norm = jnp.linalg.norm(pg) - pg_graft_norm = jnp.linalg.norm(pg_graft) - pg = pg * (pg_graft_norm/(pg_norm + 1e-16)) - - pg = pg + w * weight_decay - - if self._has_momentum: - m, update = self._momentum_update(pg, m, opt_params['momentum']) - else: - update = pg - - w = w - (update * learning_rate).astype(w.dtype) - return w, (m, v1, v2) - - def _expanded_shape(self, shape, axis): - # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. - # For eg: i = 1 returns [1, N, 1]. - rank = len(shape) - return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) - - def _minimum(self, tensor_list): - minimum = tensor_list[0] - for i in range(1, len(tensor_list)): - minimum = jnp.minimum(minimum, tensor_list[i]) - return minimum - - def _update_sketched(self, g, w, m, v1, v2, opt_params): - """Update for higher-rank parameters.""" - learning_rate = opt_params['learning_rate'] - momentum = opt_params['momentum'] - beta2 = opt_params['second_moment_averaging'] - weight_decay = opt_params['weight_decay'] - - shape = w.shape - rank = len(shape) - reshaped_accumulators = [jnp.reshape(v1[i], self._expanded_shape(shape, i)) - for i in range(rank)] - acc = self._minimum(reshaped_accumulators) - - is_beta2_1 = (beta2 == 1).astype(g.dtype) - one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) - acc = beta2 * acc + one_minus_beta2_except1 * g * g - - preconditioner = jnp.where(acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16), - jnp.zeros_like(acc)) - pg = g * preconditioner - if self._graft: - v2_acc = self._minimum([ - jnp.reshape(v2[i], self._expanded_shape(shape, i)) - for i in range(rank) - ]) - v2_acc = v2_acc + g * g - preconditioner_graft = jnp.where(v2_acc > 0.0, - 1.0 / (jnp.sqrt(v2_acc) + 1e-16), - jnp.zeros_like(v2_acc)) - pg_graft = preconditioner_graft * g - pg_norm = jnp.linalg.norm(pg) - pg_graft_norm = jnp.linalg.norm(pg_graft) - pg = pg * (pg_graft_norm/(pg_norm + 1e-16)) - - pg = pg + w * weight_decay - - if self._has_momentum: - m, update = self._momentum_update(pg, m, momentum) - else: - update = pg - - w = w - (learning_rate * update).astype(w.dtype) - for i in range(len(v1)): - axes = list(range(int(i))) + list(range(int(i) + 1, rank)) - dim_accumulator = jnp.amax(acc, axis=axes) - v1[i] = dim_accumulator - - if self._graft: - for i in range(len(v2)): - axes = list(range(int(i))) + list(range(int(i) + 1, rank)) - dim_accumulator = jnp.amax(v2_acc, axis=axes) - v2[i] = dim_accumulator - return w, (m, v1, v2) - - def update(self, step, g, w, slots, opt_params): - del step - m, v1, v2 = slots - rank = len(w.shape) - if rank > 1: - return self._update_sketched(g, w, m, v1, v2, opt_params) - else: - return self._update_diagonal(g, w, m, v1, v2, opt_params) + """SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.""" + + def __init__( + self, + 
learning_rate=0.01,
+        momentum=0.9,
+        second_moment_averaging=1.0,
+        weight_decay=0.0,
+        momentum_type=MomentumType.EMA,
+    ):  # pylint: disable=useless-super-delegation
+        """Create the SM3 optimizer.
+
+        Memory-Efficient Adaptive Optimization.
+        https://arxiv.org/abs/1901.11150
+
+        Args:
+          learning_rate: a positive scalar value for the initial learning rate.
+          momentum: optional, a positive scalar value for momentum.
+          second_moment_averaging: averaging of second moments (if 1.0, adds from
+            beginning of time like AdaGrad).
+          weight_decay: Weight decay for regularizing the model.
+          momentum_type: Nesterov, Heavy-Ball or EMA (default).
+        """
+        self._has_momentum = momentum > 0.0
+        self._momentum_type = momentum_type
+        self._graft = second_moment_averaging != 1.0
+        super().__init__(
+            learning_rate=learning_rate,
+            momentum=momentum,
+            second_moment_averaging=second_moment_averaging,
+            weight_decay=weight_decay,
+        )
+
+    def init(self, w):
+        momentum = []
+        if self._has_momentum:
+            momentum = jnp.zeros_like(w)
+        v1s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape]
+        v2s = []
+        if self._graft:
+            v2s = [jnp.zeros(sz, dtype=w.dtype) for sz in w.shape]
+        return (momentum, v1s, v2s)
+
+    def _momentum_update(self, g, m, beta1):
+        """Handle various types of momentum."""
+        if self._momentum_type == MomentumType.EMA:
+            m = (1 - beta1) * g + beta1 * m
+            update = m
+        elif self._momentum_type == MomentumType.HEAVY_BALL:
+            m = g + beta1 * m
+            update = m
+        elif self._momentum_type == MomentumType.NESTEROV:
+            m = g + beta1 * m
+            nesterov_m = g + beta1 * m
+            update = nesterov_m
+        else:
+            assert False, "Unknown momentum_type."
+        return m, update
+
+    def _update_diagonal(self, g, w, m, v1, v2, opt_params):
+        learning_rate = opt_params["learning_rate"]
+        beta2 = opt_params["second_moment_averaging"]
+        weight_decay = opt_params["weight_decay"]
+
+        is_beta2_1 = (beta2 == 1).astype(g.dtype)
+        one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1)
+        v1[0] = beta2 * v1[0] + one_minus_beta2_except1 * g * g
+
+        preconditioner = jnp.where(
+            v1[0] > 0, 1.0 / (jnp.sqrt(v1[0]) + 1e-16), jnp.zeros_like(v1[0])
+        )
+
+        pg = preconditioner * g
+        if self._graft:
+            v2[0] += g * g
+            preconditioner_graft = jnp.where(
+                v2[0] > 0, 1.0 / (jnp.sqrt(v2[0]) + 1e-16), jnp.zeros_like(v2[0])
+            )
+            pg_graft = preconditioner_graft * g
+            pg_norm = jnp.linalg.norm(pg)
+            pg_graft_norm = jnp.linalg.norm(pg_graft)
+            pg = pg * (pg_graft_norm / (pg_norm + 1e-16))
+
+        pg = pg + w * weight_decay
+
+        if self._has_momentum:
+            m, update = self._momentum_update(pg, m, opt_params["momentum"])
+        else:
+            update = pg
+
+        w = w - (update * learning_rate).astype(w.dtype)
+        return w, (m, v1, v2)
+
+    def _expanded_shape(self, shape, axis):
+        # Replaces a `shape` of [M, N, K] with 1 in all dimensions except `axis`.
+        # E.g., axis=1 returns [1, N, 1].
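+        # Each per-dimension accumulator v1[i] (and v2[i]) is stored as a 1-d
+        # vector; reshaping it this way lets the accumulators broadcast against
+        # each other and against the full gradient when combined via `minimum`.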
+ rank = len(shape) + return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) + + def _minimum(self, tensor_list): + minimum = tensor_list[0] + for i in range(1, len(tensor_list)): + minimum = jnp.minimum(minimum, tensor_list[i]) + return minimum + + def _update_sketched(self, g, w, m, v1, v2, opt_params): + """Update for higher-rank parameters.""" + learning_rate = opt_params["learning_rate"] + momentum = opt_params["momentum"] + beta2 = opt_params["second_moment_averaging"] + weight_decay = opt_params["weight_decay"] + + shape = w.shape + rank = len(shape) + reshaped_accumulators = [ + jnp.reshape(v1[i], self._expanded_shape(shape, i)) for i in range(rank) + ] + acc = self._minimum(reshaped_accumulators) + + is_beta2_1 = (beta2 == 1).astype(g.dtype) + one_minus_beta2_except1 = is_beta2_1 + (1.0 - beta2) * (1.0 - is_beta2_1) + acc = beta2 * acc + one_minus_beta2_except1 * g * g + + preconditioner = jnp.where( + acc > 0.0, 1.0 / (jnp.sqrt(acc) + 1e-16), jnp.zeros_like(acc) + ) + pg = g * preconditioner + if self._graft: + v2_acc = self._minimum( + [ + jnp.reshape(v2[i], self._expanded_shape(shape, i)) + for i in range(rank) + ] + ) + v2_acc = v2_acc + g * g + preconditioner_graft = jnp.where( + v2_acc > 0.0, 1.0 / (jnp.sqrt(v2_acc) + 1e-16), jnp.zeros_like(v2_acc) + ) + pg_graft = preconditioner_graft * g + pg_norm = jnp.linalg.norm(pg) + pg_graft_norm = jnp.linalg.norm(pg_graft) + pg = pg * (pg_graft_norm / (pg_norm + 1e-16)) + + pg = pg + w * weight_decay + + if self._has_momentum: + m, update = self._momentum_update(pg, m, momentum) + else: + update = pg + + w = w - (learning_rate * update).astype(w.dtype) + for i in range(len(v1)): + axes = list(range(int(i))) + list(range(int(i) + 1, rank)) + dim_accumulator = jnp.amax(acc, axis=axes) + v1[i] = dim_accumulator + + if self._graft: + for i in range(len(v2)): + axes = list(range(int(i))) + list(range(int(i) + 1, rank)) + dim_accumulator = jnp.amax(v2_acc, axis=axes) + v2[i] = dim_accumulator + return w, (m, v1, v2) + + def update(self, step, g, w, slots, opt_params): + del step + m, v1, v2 = slots + rank = len(w.shape) + if rank > 1: + return self._update_sketched(g, w, m, v1, v2, opt_params) + else: + return self._update_diagonal(g, w, m, v1, v2, opt_params) diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py deleted file mode 100644 index c633bface..000000000 --- a/trax/optimizers/trainer.py +++ /dev/null @@ -1,905 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-device accelerated optimization.""" - -from concurrent import futures -import functools -import os -import time - -from absl import logging -import jax -import numpy as np -import psutil - -from trax import fastmath -from trax import layers as tl -from trax.fastmath import numpy as jnp -from trax.layers import base -from trax.layers import combinators as cb - - -class Trainer: - """Multi-device accelerated trainer. 
- - Given an optimizer and a composite layer containing model+loss, this class - creates a multi-device accelerated function with which it can compute one step - of updates to the model's weights/state and the optimizer slots. By default - it uses all available accelerators, via JIT compilation and parallel mapping. - - The optimizer and model must be initialized prior to use by this class. - - The key `one_step` function runs one forward-backward pass through the model, - and returns the resulting loss value and updated optimizer statistics. As a - side effect, the function also modifies the model weights and optimizer slots. - """ - - def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False): - self._model_with_loss = model_with_loss - self._optimizer = optimizer - self._n_devices = n_devices or fastmath.local_device_count() - self._adasum = adasum - - # optimizer slots and opt_params may need to be replicated - self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices( - (self._optimizer.slots, self._optimizer.opt_params), self._n_devices)) - - # accelerated version of model+loss to replicate weights and state - self._accelerated_model_with_loss = tl.Accelerate( - model_with_loss, n_devices=n_devices) - - # Signature: - # (batch, weights, state, rng) -> ((loss, state), gradients) - self._forward_and_backward_fn = ( - fastmath.value_and_grad( - model_with_loss.pure_fn, - argnums=1, # arg1 of pure_fn: weights - has_aux=True)) # return (loss, state), gradients - - # Signature: - # (weights, slots), step, opt_params, batch, state, rng -> - # (weights, slots), state, stats - self._accelerated_update_fn = ( - _accelerate_update_fn( - self._forward_and_backward_fn, - self._optimizer, - n_devices=self._n_devices, - accelerate=True, - adasum=self._adasum - ) - ) - - @property - def model_with_loss(self): - """Returns the composite model+loss for this instance.""" - return self._model_with_loss - - @property - def accelerated_model_with_loss(self): - """Returns the accelerated composite model+loss for this instance.""" - return self._accelerated_model_with_loss - - @property - def optimizer(self): - """Returns the optimizer for this instance.""" - return self._optimizer - - @property - def slots(self): - """Returns the slots of the optimizers.""" - return self._optimizer.slots - - @slots.setter - def slots(self, slots): - """Sets the slots of the optimizers and this class (replicated).""" - self._optimizer.slots = slots - self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices)) - - def one_step(self, batch, rng, step=0, learning_rate=None): - """Runs one training step, to update model and optimizer parameters. - - Args: - batch: Batch of labeled training data. - rng: Single-use random number generator (JAX PRNG key). - step: Training step number. - learning_rate: Learning rate for the optimizer; if None, use optimizer's - default learning rate. - - Returns: - Tuple of (loss, optimizer_stats), with the newly computed loss and - updated stats as reported by the optimizer. - """ - if learning_rate is not None: - self._opt_params['learning_rate'] = tl.for_n_devices( - learning_rate, self._n_devices) - - # Split the batch across devices (batch_dim --> batch_dim // n_devices) - # and create new rng's 1-1 with devices. 
- if self._n_devices > 1: - batch = tl.reshape_by_device(batch, self._n_devices) - rng = jnp.stack(fastmath.random.split(rng, self._n_devices)) - - weights = self._accelerated_model_with_loss.weights - state = self._accelerated_model_with_loss.state - if logging.vlog_is_on(1) and ((step & step - 1) == 0): - # Prints every power of two, if debugging is enabled. - logging.info('step[%d]', step) - logging.info('opt_params[%s]', self._opt_params) - logging.info('slots[%s]', self._slots) - logging.info('weights[%s]', weights) - logging.info('state[%s]', state) - - # NOTE: stats is a replicated dictionary of key to jnp arrays. - (new_weights, new_slots), new_state, stats = self._accelerated_update_fn( - (weights, self._slots), step, self._opt_params, batch, state, rng) - - if logging.vlog_is_on(1) and ((step & step - 1) == 0): - logging.info('updated weights[%s]', new_weights) - logging.info('stats[%s]', stats) - - self._accelerated_model_with_loss.weights = new_weights - self._accelerated_model_with_loss.state = new_state - self._slots = new_slots - self._optimizer.slots = self._unreplicate(self._slots) - return stats['loss'], stats - - def _unreplicate(self, x): - if self._n_devices == 1: - return x - return fastmath.nested_map(lambda x: x[0], x) - - -def _adasum_merge(g1, g2): - """Adasum gradient composition, see https://arxiv.org/pdf/2006.02924.pdf.""" - frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30) - frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30) - return (1 - frac1) * g1 + (1 - frac2) * g2 - - -def _average_multidevice_gradients(gradients, adasum=False): - """Averages gradients over all the devices across different hosts.""" - n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS - if adasum: - # This implements a version of the Adasum algorithm from the following - # paper: https://arxiv.org/pdf/2006.02924.pdf - lg = max([i for i in range(20) if 2**i <= n]) - for lg_i in range(lg): - shift = 2**lg_i - perm = [] - for i in range(n): - block_i = i % (2*shift) # we do blocks of 2*shift size - if block_i < shift: - perm.append((i, i+shift)) - else: - perm.append((i, i-shift)) - perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name='batch') - gradients = fastmath.nested_map_multiarg( - _adasum_merge, gradients, perm_grad) - if base.N_WEIGHTS_SHARDS > 1: # only sum gradients from matching shards - groups = [[base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))] - for d in range(base.N_WEIGHTS_SHARDS)] - gradients_psum = fastmath.psum(gradients, 'batch', - axis_index_groups=groups) - else: - gradients_psum = fastmath.psum(gradients, 'batch') # sum all gradients - n = jnp.array(n, dtype=jnp.float32) - return fastmath.nested_map(lambda g: g / n, gradients_psum) - - -# Returns a function with the following signature: -# (weights, slots), step, opt_params, batch, state, rng -> -# (weights, slots), state, stats -def _accelerate_update_fn(forward_and_backward_fn, - optimizer, - n_devices, - accelerate=True, - adasum=False): - """Accelerates the given forward_and_backward_fn function.""" - if n_devices == 1: - def single_device_update_fn( - weights_and_slots, step, opt_params, batch, state, rng): - step = jnp.array(step, dtype=jnp.int32) # Needed in TFNP backend. 
- weights, slots = weights_and_slots - (loss, state), gradients = forward_and_backward_fn( - batch, weights, state, rng) - weights, slots, stats = optimizer.tree_update( - step, gradients, weights, slots, opt_params, store_slots=False) - stats['loss'] = loss - return (weights, slots), state, stats - if accelerate: - # TODO(afrozm): Find out the status of buffer donation on GPUs, then do - # donate_argnums=(0,). - single_device_update_fn = fastmath.jit(single_device_update_fn) - return single_device_update_fn - - # More than one device (core), i.e. all of TPU configurations etc. - assert n_devices > 1, f'{n_devices} should be greater than 1.' - - @functools.partial(fastmath.pmap, axis_name='batch', donate_argnums=(0,)) - def _multi_device_update_fn( - weights_and_slots, step, opt_params, batch, state, rng): - # All tensors should have the first dimension = n_devices. - weights, slots = weights_and_slots - (loss, state), gradients = ( - forward_and_backward_fn(batch, weights, state, rng)) - gradients = _average_multidevice_gradients(gradients, adasum=adasum) - weights, slots, stats = optimizer.tree_update( - step, gradients, weights, slots, opt_params, store_slots=False) - stats['loss'] = loss - return (weights, slots), state, stats - - def multi_device_update_fn( - weights_and_slots, step, opt_params, batch, state, rng): - # Need to replicate step to n_devices leading dimension. - return _multi_device_update_fn(weights_and_slots, - jnp.repeat(step, n_devices), opt_params, - batch, state, rng) - - return multi_device_update_fn - - -class ReversibleSerialTrainer: - """Runs an optimizer on a series of layers, reversible and not. - - We provide layers to this trainer in blocks, each block consisting of - a list of standard layers and a list of reversible layers. They all run - in turn (like one huge Serial block) but in a more memory-efficient way. - - The main motivation for this class is to save memory: it allows to train - models that have more weights than the memory available on accelerators. - This happens by caching the weights in CPU memory and transferring only - the weights of one layer at a time. The reversible layers are used to make - the backward pass without using additional memory for storing activations. - - Note: we do not allow sharing weights between blocks for now. - """ - - def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None, - memoize_jit=True, free_accelerators_on_step=False, adasum=False): - """Creates a ReversibleSerialTrainer and the needed optimizers. - - This trainer performs updates equivalent to using the default Trainer on:: - - tl.Serial(blocks + [loss_layer]). - - It is more memory-efficient though since weights are stored on CPU and only - sent to accelerator layer-by-layer. Blocks are pairs consisting of a list - of standard (arbitrary) layers and a list of reversible layers which help - save memory thanks to being reversible. - - Args: - blocks: A list of pairs of lists of standard and reversible layers. - loss_layer: The final layer of the model; it can have trainable weights - but should end with a loss: it is required to produce a scalar output. - optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`. - n_devices: An optional integer, number of accelerator devices to use; - by default, all available accelerators will be used. 
-      memoize_jit: Whether to memoize JITed functions; this significantly speeds
-        up XLA compilation of larger models, but it uses `repr(layer)` as keys
-        to memoize so it could fail if two layers with different functionality
-        had the same string representation. We have not encountered such a case
-        yet so this is turned on by default, but consider turning it off or
-        reviewing your model if you use custom layers and encounter a problem.
-      free_accelerators_on_step: If True, frees memory on accelerators when
-        starting a step. All layers and arguments must be on host for that,
-        otherwise it can lead to failures. Can prevent memory fragmentation.
-      adasum: If True, use adaptive summation to gather multi-device gradients.
-    """
-    self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
-    self._loss_layer = loss_layer
-    self._optimizer_fn = optimizer_fn
-    self._n_devices = n_devices or fastmath.local_device_count()
-    self._adasum = adasum
-    self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
-    self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
-    self._n_async_layers = 1  # How many layers to run asynchronously.
-    self._jit_memory = {} if memoize_jit else None
-    self._do_free = free_accelerators_on_step
-    self._jit_per_device_rngs = fastmath.jit(
-        self._per_device_rngs, backend='cpu')
-
-    # Create accelerated versions of layers as pmapped/jitted pure_fn.
-    self._accelerated_layer_fns = fastmath.nested_map(
-        lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
-        self._blocks)
-
-    # Create per-layer optimizers and replicate opt_params.
-    def _make_optimizer(layer):
-      opt = optimizer_fn()
-      opt.tree_init(layer.weights)
-      opt.slots = tl.on_cpu(opt.slots)
-      return opt
-
-    self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
-    self._replicated_opt_params = fastmath.nested_map(
-        lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)
-
-    self._loss_opt = _make_optimizer(loss_layer)
-    self._replicated_loss_opt_params = self._replicate_cpu(
-        self._loss_opt.opt_params)
-
-    # Forward + backward + optimizer-update functions for all layers.
-    # We call them FBO for short ("Forward + Backward + Optimizer update").
-    # Reversible layers define a reverse_and_fbo function that also reverses.
-
-    self._fbos = []
-    for i, (std_layer, rev_layers) in enumerate(self._blocks):
-      (std_opt, rev_opts) = self._optimizers[i]
-      std_fbo = _fbo_with_layer_and_opt(
-          std_layer, std_opt, self._n_devices, adasum=self._adasum)
-      rev_and_fbos = []
-      for layer, opt in zip(rev_layers, rev_opts):
-        rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
-            layer, opt, self._n_devices, self._adasum)
-        # The donated args are (outputs, weights, grads) and we can donate
-        # them because weights and grads are immediately replaced and in
-        # case of reversible layers, the outputs are never used again.
-        rev_and_fbos.append(self._pjit(
-            rev_and_fbo, f'rev+bwd {repr(layer)}', donate_argnums=(0, 1, 2)))
-      # In standard layers, the inputs cannot be donated as they may be used
-      # as outputs for the reversible block below, but weights and grads can.
- jit_std_fbo = self._pjit( - std_fbo, f'bwd {repr(std_layer)}', donate_argnums=(1, 2)) - self._fbos.append((jit_std_fbo, rev_and_fbos)) - - loss_fbo = _fbo_with_layer_and_opt( - self._loss_layer, self._loss_opt, self._n_devices, 'loss', self._adasum) - self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2)) - - @property - def loss_layer(self): - """Returns the loss layer used to initialize this class.""" - return self._loss_layer - - @property - def all_layers(self): - """Returns all layers that compose the model and loss in this class.""" - layers = [] - for (std_layer, rev_layers) in self._blocks: - layers.append(std_layer) - layers.extend(rev_layers) - layers.append(self._loss_layer) - return layers - - @property - def optimizer_fn(self): - """Returns the optimizer function used to initialize this class.""" - return self._optimizer_fn - - @property - def slots(self): - """Returns the slots of all optimizers.""" - optimizers = list(self._optimizers) + [self._loss_opt] - return fastmath.nested_map(lambda opt: opt.slots, optimizers) - - @slots.setter - def slots(self, slots): - """Sets the slots of all optimizers.""" - for ((s_opt, r_opts), (s_slots, r_slots)) in zip( - self._optimizers, slots[:-1]): - for (opt, slot) in zip([s_opt] + r_opts, [s_slots] + r_slots): - opt.slots = slot - self._loss_opt.slots = slots[-1] - - def _pjit(self, f, memory_key=None, donate_argnums=()): - """JIT f if 1 device is available and pmap if more are available.""" - should_memoize = self._jit_memory is not None and memory_key is not None - if (should_memoize and memory_key in self._jit_memory): - logging.info('Found JITed function in memory for: %s', memory_key) - return self._jit_memory[memory_key] - if self._n_devices == 1: - res = fastmath.jit(f, donate_argnums=donate_argnums) - else: - res = fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums) - if should_memoize: - self._jit_memory[memory_key] = res - return res - - def _replicate(self, x): - if self._n_devices > 1: - return tl.for_n_devices(x, self._n_devices) - return tl.on_accelerator(x) - - def _replicate_cpu(self, x): - # TODO(lukaszkaiser): move it to layers/acceleration to be together with - # tl.for_n_devices and other functions like that, possibly refactor them. 
- def f(x): - if self._n_devices > 1: - return np.broadcast_to(x, (self._n_devices,) + np.asarray(x).shape) - else: - return x - return tl.on_cpu(fastmath.nested_map(f, x)) - - def _unreplicate(self, x): - if self._n_devices == 1: - return tl.on_cpu(x) - return tl.on_cpu(fastmath.nested_map(lambda x: x[0], x)) - - def _lazy_unreplicate(self, x): - def unreplicate_and_start_async_copy(y): - unreplicated = y if self._n_devices == 1 else y[0] - unreplicated.copy_to_host_async() - return unreplicated - return fastmath.nested_map(unreplicate_and_start_async_copy, x) - - def _collect_weights(self, layer): - layer.weights = fastmath.nested_map(np.asarray, layer.weights) - - def _free_accelerators(self, exceptions=(), keep_constants=True): - """Deletes all live buffers from accelerator with no safety guarantees.""" - backend = jax.lib.xla_bridge.get_backend() - live_buffers = backend.live_buffers() - logging.info('Deleting %d live buffers.', len(live_buffers)) - exceptions_buffers = [] - for x in fastmath.tree_flatten(exceptions): - if hasattr(x, 'device_buffer'): # DeviceArray - exceptions_buffers.append(x.device_buffer) - if hasattr(x, 'device_buffers'): # ShardedDeviceArray - exceptions_buffers.extend(x.device_buffers) - for b in live_buffers: - should_delete = True - for e in exceptions_buffers: - if b is e: - should_delete = False - if keep_constants and not b.shape: - should_delete = False - if should_delete: - b.delete() - - def _per_device_rngs(self, rng): - """Create per-device RNGs from a given rng.""" - # Splitting by device first to be identical with default trainer. - per_device_rng = fastmath.random.split(rng, self._n_devices) - per_device_rngs = [ - fastmath.random.split(r, self._n_layers) for r in per_device_rng] - rngs = [jnp.stack([r[i] for r in per_device_rngs]) - for i in range(self._n_layers)] - return rngs - - def one_step(self, batch, rng, step=0, learning_rate=None): - """Updates layers weights/state and optimizers slots by running one step. - - Args: - batch: Batch of data to use for optimization. - rng: Random number generator to use for running this step. - step: Which step of the training are we running. - learning_rate: Learning rate to use instead of the default one. - - Returns: - Tuple (loss, stats) with new values from one step - of training, where stats are all optimizer statistics. - """ - # Update the learning rate if needed. - if learning_rate is not None: - self._replicated_loss_opt_params['learning_rate'] = self._replicate_cpu( - learning_rate) - for (std_op, rev_ops) in self._replicated_opt_params: - std_op['learning_rate'] = self._replicate_cpu(learning_rate) - for op in rev_ops: - op['learning_rate'] = self._replicate_cpu(learning_rate) - - # Batch needs to be split across the local devices -- the difference - # between _for_n_devices and _reshape_by_device is that the latter splits - # the batch dim to batch // n_devices, vs _for_n_devices - # broadcasts/replicates to n_devices dimension. - step_int = step - if self._n_devices > 1: - batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True) - step = np.repeat(step, self._n_devices) - - # Create separate rng for each device and layer. - if self._n_devices == 1: - rngs = fastmath.random.split(rng, self._n_layers) - else: - # JIT the function and run it on CPU to avoid memory fragmentation. - rngs = self._jit_per_device_rngs(tl.on_cpu(rng)) - # Group rngs by layer blocks. 
- rng_blocks, rng_i = [], 0 - for _, rev_layers in self._blocks: - l = len(rev_layers) - rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1])) - rng_i += l + 1 - - # Run the layers forward upto the loss layer. - if self._do_free: - self._free_accelerators() - process = psutil.Process(os.getpid()) - if isinstance(batch, (list, tuple)): - batch_shapes = [x.shape for x in batch] - else: - batch_shapes = batch.shape - logging.info('running step %d on shapes %s', step_int, str(batch_shapes)) - if step_int % self._n_steps_per_log == 1: - logging.info('run fwd: cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - - stack = batch - block_inputs_states = [] - for i, (std_layer, rev_layers) in enumerate(self._blocks): - acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i] - std_rng, rev_rngs = rng_blocks[i] - # Run the standard layer. - stack, std_inputs, std_state = self._run_forward_standard( - stack, std_layer, acc_std_layer_fn, std_rng, step_int) - - # Run the reversible layers and collect old and new states. - stack, rev_old_states, rev_new_states = self._run_forward_reversible( - stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int) - block_inputs_states.append(tl.on_cpu( - ((std_inputs, std_state), (rev_old_states, rev_new_states)))) - - # Run the loss layer forward and backward with optimizer update. - if step_int % self._n_steps_per_log == 1: - logging.info('run loss: cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - loss_state = self._replicate(self._loss_layer.state) - loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in) - loss_stats, grad_stack = self._run_backward_standard( - None, step, self._loss_layer, loss_inputs, - loss_state, self._loss_fbo, rngs[-1], self._loss_opt, - self._replicated_loss_opt_params) - self._collect_weights(self._loss_layer) - stats = [tl.on_cpu(loss_stats)] - - # De-fragment memory. - if self._do_free: - stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack) - self._free_accelerators() - - # Run the layers backward and run optimizer updates. - if step_int % self._n_steps_per_log == 1: - logging.info('run bwd: cpu memory use (MB): %.2f', - process.memory_info().rss / float(1024 * 1024)) - for i in range(len(self._blocks) - 1, -1, -1): - std_layer, rev_layers = self._blocks[i] - (std_inputs, std_state), (rev_old_states, - rev_new_states) = block_inputs_states[i] - std_fbo, rev_fbos = self._fbos[i] - std_opt, rev_opts = self._optimizers[i] - std_rng, rev_rngs = rng_blocks[i] - repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i] - - # Run reversible layers backward with optimizer update. - stack, grad_stack, new_stats = self._run_backward_reversible( - stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states, - rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params) - stats.extend(tl.on_cpu(new_stats)) - - # Run the standard layer forward-and-backward pass and optimizer update. - std_layer_stats, grad_stack = self._run_backward_standard( - grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng, - std_opt, repl_std_opt_params) - stack = cb.outputs_onto_stack( # Put layer inputs on the stack. - std_inputs, stack, std_layer.n_out) - stats.append(tl.on_cpu(std_layer_stats)) - - # Collect lazily unreplicated layer weights. - for rev_layer_id in range(self._n_async_layers): - self._collect_weights(rev_layers[rev_layer_id]) - self._collect_weights(std_layer) - - # Join stats from different optimizers into one. 
-    joint_stats = {}
-    for i, stat in enumerate(reversed(stats)):
-      for k, v in stat.items():
-        joint_stats[f'layer{i}/' + k] = v
-    return stats[0]['loss'], joint_stats
-
-  def _run_forward_standard(self, stack, layer, accelerated_fn, rng, step):
-    """Run standard layer forward."""
-    if step % self._n_steps_per_log == 1:
-      logging.info('running forward standard layer %s', str(layer))
-    layer_inputs = cb.inputs_from_stack(stack, layer.n_in)
-    layer_weights = self._replicate(layer.weights)
-    layer_state = self._replicate(layer.state)
-    outputs, layer_new_state = accelerated_fn(
-        layer_inputs, layer_weights, layer_state, rng)
-    stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
-    return stack, layer_inputs, layer_new_state
-
-  def _run_forward_reversible(self, stack, rev_layers, accelerated_fns,
-                              rngs, step):
-    """Run reversible layers forward, collect states for backwards pass."""
-    old_states, new_states = [], []
-    for i, layer in enumerate(rev_layers):
-      if step % self._n_steps_per_log == 1:
-        logging.info('running forward reversible layer %s', str(layer))
-      weights = self._replicate(layer.weights)  # also copies cpu -> accelerator
-      state = self._replicate(layer.state)
-      old_states.append(state)
-      inputs = cb.inputs_from_stack(stack, layer.n_in)
-      outputs, new_state = accelerated_fns[i](
-          inputs, weights, state, rngs[i])
-      stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
-      new_states.append(new_state)
-    return stack, old_states, new_states
-
-  def _run_backward_standard(self, grad_stack, step, layer, inp, state,
-                             fbo_fn, rng, optimizer, replicated_opt_params):
-    """Run a standard layer backward, with an optimizer update."""
-    step_int = int(step) if self._n_devices < 2 else int(step[0])
-    if step_int % self._n_steps_per_log == 1:
-      logging.info('running backward standard layer %s', str(layer))
-    if grad_stack is not None:
-      grads = cb.inputs_from_stack(grad_stack, layer.n_out)
-    else:
-      grads = None
-    slots = self._replicate(optimizer.slots)
-    weights = self._replicate(layer.weights)
-    # Ensure all arguments are on accelerator.
- state = tl.on_accelerator(state) - replicated_opt_params = tl.on_accelerator(replicated_opt_params) - rng = tl.on_accelerator(rng) - grads = tl.on_accelerator(grads) - inp = tl.on_accelerator(inp) - new_weights, new_state, new_slots, new_grads, stats = fbo_fn( - inp, weights, grads, state, slots, replicated_opt_params, rng, step) - layer.weights = self._lazy_unreplicate(new_weights) - layer.state = self._unreplicate(new_state) - optimizer.slots = self._unreplicate(new_slots) - if grad_stack is not None: - grad_stack = cb.outputs_onto_stack(new_grads, grad_stack, layer.n_out) - else: - grad_stack = new_grads - return stats, grad_stack - - def _run_backward_reversible(self, stack, grad_stack, step, - rev_layers, rev_and_fbos, - old_states, new_states, rngs, - optimizers, replicated_opt_params): - """Run reversible layers backwards.""" - counter = 0 - stats = [] - step_int = int(step) if self._n_devices < 2 else int(step[0]) - for layer, reverse_and_fbo, old_state, new_state, rng in reversed(list(zip( - rev_layers, rev_and_fbos, - old_states, new_states, rngs))): - if step_int % self._n_steps_per_log == 1: - logging.info('running backward reversible layer %s', str(layer)) - counter -= 1 - stack, grad_stack, layer_stats = self._run_backward_one_reversible( - layer, stack, grad_stack, step, rng, optimizers[counter], - replicated_opt_params[counter], reverse_and_fbo, old_state, new_state) - stats.append(layer_stats) - if counter + self._n_async_layers < 0: - self._collect_weights(rev_layers[counter + self._n_async_layers]) - return stack, grad_stack, stats - - def _run_backward_one_reversible(self, layer, stack, grad_stack, step, rng, - optimizer, opt_params, reverse_and_fbo, - old_state, new_state): - """Run one reversible layer backwards.""" - # We are running backwards and reversing, so we get *outputs* from stack. - outputs = cb.inputs_from_stack(stack, layer.n_out) - grads = cb.inputs_from_stack(grad_stack, layer.n_out) - slots = self._replicate(optimizer.slots) - weights = self._replicate(layer.weights) # cpu -> accelerator - # Ensure all arguments are on accelerator. - outputs = tl.on_accelerator(outputs) - grads = tl.on_accelerator(grads) - old_state = tl.on_accelerator(old_state) - new_state = tl.on_accelerator(new_state) - opt_params = tl.on_accelerator(opt_params) - rng = tl.on_accelerator(rng) - new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo( - outputs, weights, grads, old_state, new_state, - slots, opt_params, rng, step) - layer.weights = self._lazy_unreplicate(new_weights) # accelerator -> cpu - layer.state = self._unreplicate(new_state) - optimizer.slots = self._unreplicate(new_slots) - stack = cb.outputs_onto_stack(inputs, stack, layer.n_out) - grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out) - return stack, grad_stack, layer_stats - - -# Forward + backward + optimizer-update functions for all layers. -# We call them in short FBO for "Forward + Backward + Optimizer update". - - -def _fbo_with_layer_and_opt(layer, optimizer, n_devices, - stats_name=None, adasum=False): - """Create the fbo function for a given layer and optimizer.""" - def fbo(inputs, weights, grads, state, slots, opt_params, rng, step): - """FBO of the layer.""" - # We need a layer pure_fn but only for inputs and weights. - def pure_fn_without_state_and_rng(x, w): - return layer.pure_fn(x, w, state, rng) - - # Calculate the vector-Jacobian product of the reduced pure fn. 
- activations, vjp_fn, new_state = fastmath.vjp( - pure_fn_without_state_and_rng, inputs, weights, has_aux=True) - - # In the loss layer, set gradients to 1 with the dtype of activations=loss. - if grads is None and stats_name is not None: - grads = jnp.ones((), dtype=activations.dtype) - - # The vjp function returns gradients with respect to inputs and weights. - grads_inputs, grads_weights = vjp_fn(grads) - - # For non-trainable layers, return the calculated arguments. - if _is_empty_tuple(weights): - stats = {} - if stats_name is not None: - stats[stats_name] = activations - return weights, new_state, slots, grads_inputs, stats - - # In multi-device setting, average gradients from multiple devices. - if n_devices > 1: - grads_weights = _average_multidevice_gradients( - grads_weights, adasum=adasum) - - # Run the optimizer. - new_weights, new_slots, stats = optimizer.tree_update( - step, grads_weights, weights, slots, opt_params, store_slots=False) - if stats_name is not None: - stats[stats_name] = activations - return new_weights, new_state, new_slots, grads_inputs, stats - - return fbo - - -# Reversible layers define a reverse_and_fbo function that both reverses -# and runs the forward-backward pass and applied the optimizer. -# This function uses the `reverse_and_grad` method of reversible layers. - - -def _reverse_and_fbo_with_layer_and_opt(layer, optimizer, n_devices, adasum): - """Create the reverse_and_fbo function for a given layer and optimizer.""" - def reverse_and_fbo(output, weights, grads, state, new_state, - slots, opt_params, rng, step): - """Reverse and FBO of the layer.""" - # Call the reverse_and_grad method of the layer. - inputs, (grads_inputs, grads_weights) = layer.reverse_and_grad( - output, grads, weights, state, new_state, rng=rng) - - # For non-trainable layers, return the calculated arguments. - if _is_empty_tuple(weights): - return weights, slots, inputs, grads_inputs, {} - - # In multi-device setting, average gradients from multiple devices. - if n_devices > 1: - grads_weights = _average_multidevice_gradients( - grads_weights, adasum=adasum) - - # Run the optimizer. - new_weights, new_slots, stats = optimizer.tree_update( - step, grads_weights, weights, slots, opt_params, store_slots=False) - - return new_weights, new_slots, inputs, grads_inputs, stats - - return reverse_and_fbo - - -def _is_empty_tuple(x): - """Check if x is either empty or a tuple of (tuples of) empty things.""" - if not isinstance(x, (list, tuple)): - return False - for y in x: - if not _is_empty_tuple(y): - return False - return True - - -def extract_reversible_blocks(layers, loss_chunk_size=0): - """Extracts blocks and loss layer for use with ReversibleSerialTrainer. - - Args: - layers: a list of layers of a single layer to extract blocks from; - should end with a loss, e.g., [model, loss] or tl.Serial(model, loss). - loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory - in models with larger vocabulary; requires the last sublayers of loss - are [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order. - - Returns: - a pair (blocks, loss_layer) to use with ReversibleSerialTrainer. - """ - def _flatten(l): - """Flatten all Serial layers and sub(sub-...) layers into a list.""" - if isinstance(l, (list, tuple)): - return [x for layer in l for x in _flatten(layer)] # pylint: disable=g-complex-comprehension - elif isinstance(l, tl.Serial): - return _flatten(l.sublayers) - else: - return [l] - - # Extract standard and reversible layer blocks. 
-  blocks, std_layers, rev_layers = [], [], []
-  for layer in _flatten(layers):
-    if isinstance(layer, tl.ReversibleLayer):
-      rev_layers.append(layer)
-    elif not rev_layers:
-      std_layers.append(layer)
-    else:
-      blocks.append((std_layers, rev_layers))
-      std_layers, rev_layers = [], []
-      std_layers.append(layer)
-  if rev_layers:
-    raise ValueError('The final layer must be a standard loss, not reversible.')
-  if loss_chunk_size > 0:
-    # For now we only do chunking of [Dense, LogSoftmax, CrossEntropy, Mean].
-    # Let's check that these are the last 4 layers.
-    border_layers = ['StripFromConcatenateWithPadding', 'Select']
-
-    loss_start = None
-    for index, layer in enumerate(std_layers):
-      if layer.name in border_layers:
-        loss_start = index + 1
-    if loss_start is None:
-      raise ValueError('Loss layer should be preceded by one of {}; got {}'
-                       .format(border_layers, [l.name for l in std_layers]))
-    if len(std_layers) - loss_start < 4:
-      raise ValueError('Loss layer too short for chunking')
-    last_3_names = ' '.join([l.name for l in std_layers[-3:]])
-    if last_3_names != 'LogSoftmax _CrossEntropy _WeightedMean':
-      raise ValueError('Loss chunking only works with last layers being "'
-                       'LogSoftmax, _CrossEntropy, _WeightedMean" but got: ' +
-                       last_3_names)
-
-    # Create a chunked dense+logsoftmax+cross-entropy loss.
-    chunked_xent = tl.Chunk(tl.Serial(std_layers[loss_start:-1]),
-                            loss_chunk_size)
-    # The chunked loss should operate on a merged batch dimension, e.g.,
-    # including both length and batch size. Need to merge and un-merge later.
-    def _reshape_to_batch_and_copy_targets(preds, targets):
-      batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]])
-      batched_targets = jnp.reshape(targets, [-1])
-      return batched_preds, batched_targets, targets
-    def _reshape_xent_back(xent, targets):
-      return jnp.reshape(xent, targets.shape)
-    batched_xent = tl.Serial(
-        tl.Fn('pre_xent_rebatch', _reshape_to_batch_and_copy_targets, n_out=3),
-        chunked_xent,
-        tl.Fn('after_xent_rebatch', _reshape_xent_back)
-    )
-    loss_layer = tl.Serial(std_layers[:loss_start] + [batched_xent],
-                           std_layers[-1])
-  else:
-    loss_layer = tl.Serial(std_layers)
-  return blocks, loss_layer
-
-
-def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
-  """Initialize reversible blocks and the loss layer and place weights on CPU.
-
-  Args:
-    blocks: List of reversible blocks (pairs of layer lists).
-    loss_layer: The final loss layer to initialize.
-    input_signature: The signature of the input to the blocks.
-    rng: Random key used to initialize the layers.
- """ - sig_stack = input_signature - process = psutil.Process(os.getpid()) - mem_use = process.memory_info().rss - for (std_layers, rev_layers) in blocks: - rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1) - rng = rngs[0] - for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]): - sig = cb.inputs_from_stack(sig_stack, layer.n_in) - layer.init(sig, rng=layer_rng) - layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory - layer.state = tl.on_cpu(layer.state) # store weights in cpu memory - logging.info('init: layer %s\nadded cpu memory (MB): %.2f', str(layer), - (process.memory_info().rss - mem_use) / float(1024 * 1024)) - mem_use = process.memory_info().rss - logging.info('init: cpu memory use (MB): %.2f', - mem_use / float(1024 * 1024)) - out_sig = layer.output_signature(sig) - sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in) - loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng) - loss_layer.weights = tl.on_cpu(loss_layer.weights) - loss_layer.state = tl.on_cpu(loss_layer.state) - - diff --git a/trax/optimizers/trainer_test.py b/trax/optimizers/trainer_test.py deleted file mode 100644 index bda6d191f..000000000 --- a/trax/optimizers/trainer_test.py +++ /dev/null @@ -1,344 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for accelerated optimization of loss layers.""" - -import time -from absl.testing import absltest - -from jax.config import config -import numpy as np - -from trax import fastmath -from trax import layers as tl -from trax import optimizers -from trax import shapes -from trax.layers import base -from trax.models.research import terraformer - - -class TrainerTest(absltest.TestCase): - - def _assert_all_equal(self, t1, t2, tol=1e-5): - def eq(x1, x2): - diff = np.maximum(np.abs(x1 - x2) - tol, 0.0) - self.assertLessEqual(np.sum(diff), 0.0, - msg=f'\n{x1}\n !=\n{x2}\n diff:\n{x1-x2}') - fastmath.nested_map_multiarg(eq, t1, t2) - - def test_run_simple_task(self): - """Runs an accelerated optimizer on a simple task.""" - inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch - targets_batch = np.pi * np.ones_like(inputs_batch) - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss()) - loss_layer.init(labeled_batch) - optimizer = optimizers.SGD(.01) - optimizer.tree_init(loss_layer.weights) - trainer = optimizers.Trainer(loss_layer, optimizer) - rng = fastmath.random.get_prng(0) - trainer.one_step(labeled_batch, rng) - - - def test_run_sharded_terraformer(self): - """Runs Terraformer with sharded weights (only on 2+-device systems).""" - if fastmath.local_device_count() == 1: - return - base.N_WEIGHTS_SHARDS = fastmath.local_device_count() - inputs_batch = np.arange(8).reshape((2, 4)) + 1 - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) - input_sig = (int_sig, int_sig, int_sig) - # We want to test rng propagation too, so adding some dropout layers. - model = terraformer.ConfigurableTerraformer( - 20, d_model=8, d_ff=32, n_heads=1, dropout=0.0, - n_encoder_layers=2, n_decoder_layers=2, - ff_sparsity=(4, 8, 0.0, 1.0), - encoder_attention_type=tl.Attention, - encoder_decoder_attention_type=tl.CausalAttention, - pos_type=None, reversible_encoder=True) - loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) - model_with_loss = tl.Serial(model, loss) - rng_init = fastmath.random.get_prng(12) - model_with_loss.init(input_sig, rng=rng_init) - - # Make a step with the trainer. - optimizer = optimizers.Adafactor(0.01) - split_w = fastmath.nested_map( - lambda x: x[0], - tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS)) - optimizer.tree_init(split_w) - trainer = optimizers.Trainer(model_with_loss, optimizer) - rng_step1 = fastmath.random.get_prng(7) - trainer.one_step(labeled_batch, rng_step1) - # Reset shards back to default. 
- base.N_WEIGHTS_SHARDS = 1 - - def test_run_reversible_slots(self): - """Tests that slots can be read and assigned in reversible trainer.""" - layers = [tl.Dense(4), tl.Dup()] - rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)), - tl.ReversibleSwap()] - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - trainer = optimizers.ReversibleSerialTrainer( - [(layers, rev_layers)], loss_layer, optimizers.Adam) - slots = trainer.slots - trainer.slots = slots - self.assertEqual(slots, trainer.slots) - - def test_run_reversible_same_as_default_basic(self): - """Runs the reversible trainer, check results are the same as default.""" - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - # We want to test rng propagation too, so adding some dropout layers. - first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) - rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), - tl.ReversibleSwap(), - tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), - tl.ReversibleSwap()] - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - model = tl.Serial([first_layer] + rev_layers + [loss_layer]) - rng_init = fastmath.random.get_prng(12) - model.init(labeled_batch, rng=rng_init) - optimizer_fn = optimizers.Adam # to test slots - - # Make 2 steps with the original trainer. - optimizer = optimizer_fn() - optimizer.tree_init(model.weights) - trainer = optimizers.Trainer(model, optimizer) - rng_step1 = fastmath.random.get_prng(7) - rng_step2 = fastmath.random.get_prng(8) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - first_layer_weights1 = first_layer.weights - rev_layer0_weights1 = rev_layers[0].weights - rev_layer2_weights1 = rev_layers[2].weights - loss_layer_weights1 = loss_layer.weights - - # Now make 2 steps with reversible trainer. - model.init(labeled_batch, rng=rng_init) - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer.sublayers, rev_layers)], loss_layer, optimizer_fn) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - - # Check that weights end up the same. - self._assert_all_equal(loss_layer_weights1, loss_layer.weights) - self._assert_all_equal(rev_layer2_weights1, rev_layers[2].weights) - self._assert_all_equal(rev_layer0_weights1, rev_layers[0].weights) - self._assert_all_equal(first_layer_weights1, first_layer.weights) - - def test_run_reversible_same_as_default_extended(self): - """Runs the reversible trainer, check results are the same as default.""" - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - # We want to test rng propagation too, so adding some dropout layers. 
- first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) - rev_layers1 = [tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), - tl.ReversibleSwap(), - tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), - tl.ReversibleSwap()] - mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup()) - rev_layers2 = [tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)), - tl.ReversibleSwap()] - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] + - rev_layers2 + [loss_layer]) - rng_init = fastmath.random.get_prng(12) - model.init(labeled_batch, rng=rng_init) - optimizer_fn = optimizers.Adam # to test slots - - # Make 3 steps with the original trainer. - optimizer = optimizer_fn() - optimizer.tree_init(model.weights) - trainer = optimizers.Trainer(model, optimizer) - rng_step1 = fastmath.random.get_prng(7) - rng_step2 = fastmath.random.get_prng(8) - rng_step3 = fastmath.random.get_prng(9) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - first_layer_weights1 = first_layer.weights - rev_layer12_weights1 = rev_layers1[2].weights - mid_layer_weights1 = mid_layer.weights - rev_layer20_weights1 = rev_layers2[0].weights - loss_layer_weights1 = loss_layer.weights - - # Now make 3 steps with reversible trainer. - model.init(labeled_batch, rng=rng_init) - # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why? - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer.sublayers, rev_layers1), - (mid_layer.sublayers, rev_layers2)], - loss_layer, optimizer_fn, memoize_jit=False) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - - # Check that weights end up the same. - self._assert_all_equal(loss_layer_weights1, loss_layer.weights) - self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights) - self._assert_all_equal(mid_layer_weights1, mid_layer.weights) - self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights) - self._assert_all_equal(first_layer_weights1, first_layer.weights) - - def test_run_reversible_same_as_default_terraformer(self): - """Runs the reversible trainer, check results are the same as default.""" - inputs_batch = np.arange(8).reshape((2, 4)) + 1 - targets_batch = 2 * inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) - input_sig = (int_sig, int_sig, int_sig) - # We want to test rng propagation too, so adding some dropout layers. - model = terraformer.ConfigurableTerraformer( - 20, d_model=8, d_ff=32, n_heads=1, dropout=0.0, n_encoder_layers=2, - n_decoder_layers=2, ff_sparsity=(4, 8, 0.0, 1.0), pos_type=None, - reversible_encoder=True) - loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) - optimizer_fn = optimizers.Adafactor - blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( - [model, loss], loss_chunk_size=4) - blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks] - model_with_loss = tl.Serial(model, loss) - rng_init = fastmath.random.get_prng(12) - model_with_loss.init(input_sig, rng=rng_init) - - # Make 3 steps with the original trainer. 
- optimizer = optimizer_fn() - optimizer.tree_init(model_with_loss.weights) - trainer = optimizers.Trainer(model_with_loss, optimizer) - rng_step1 = fastmath.random.get_prng(7) - rng_step2 = fastmath.random.get_prng(8) - rng_step3 = fastmath.random.get_prng(9) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - first_weights = blocks_serial[0][0].weights - first_rev_weights = blocks[0][1][0].weights - loss_weights = loss_layer.weights - - # Now make 3 steps with reversible trainer. - model_with_loss.init(input_sig, rng=rng_init) - trainer = optimizers.ReversibleSerialTrainer( - blocks, loss_layer, optimizer_fn) - trainer.one_step(labeled_batch, rng_step1) - trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) - trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) - - # Check that weights end up the same. - self._assert_all_equal(loss_weights, loss_layer.weights) - self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights) - self._assert_all_equal(first_weights, blocks_serial[0][0].weights) - - def test_run_reversible_large_weights(self): - """Runs the reversible trainer with a lot of weights to test memory use.""" - # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU - # and CPU when you run it locally, but it's too big for unit-testing. - ram_limited = True # Set to False to run this test locally. - if fastmath.global_device_count() == 1 and ram_limited: - return - - # Create inputs and rngs. - inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - first_layer = tl.Serial(tl.Embedding(9, 16*1024), tl.Dup()) - rng_init = fastmath.random.get_prng(12) - rng_step = fastmath.random.get_prng(13) - - # Initialize layers. - first_layer.init(labeled_batch, rng=rng_init) - n_layers = 18 # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram - rev_layers = [] - int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32) - shape = shapes.ShapeDtype((2, 4, 16*1024)) - sig = (shape, shape) - for _ in range(n_layers): - layer = tl.ReversibleHalfResidual(tl.Dense(16*1024)) - layer.init(sig, rng=rng_init) - layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory - rev_layers.append(layer) - rev_layers.append(tl.ReversibleSwap()) - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - loss_layer.init((shape, shape, int_shape, int_shape)) - optimizer_fn = optimizers.Adafactor - - # Make a step with reversible trainer. - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer, rev_layers)], loss_layer, optimizer_fn) - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - # Set to true to run again, e.g., for profiling. - run_twice = False - if run_twice: - t = time.time() - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss)) - - def test_run_reversible_weights_trainsfer_xprof(self): - """Runs the reversible trainer and profiles weight transfer stats.""" - run_this_test = False # We only run this test manually. - if not run_this_test or fastmath.global_device_count() == 1: # TPU only - return - - # Create inputs and rngs. 
- inputs_batch = np.ones((1024, 128), dtype=np.int32) - targets_batch = inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup()) - rng_init = fastmath.random.get_prng(12) - rng_step = fastmath.random.get_prng(13) - - # Initialize layers. - first_layer.init(labeled_batch, rng=rng_init) - n_layers = 6 - rev_layers = [] - int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32) - shape = shapes.ShapeDtype((1024, 128, 1024)) - sig = (shape, shape) - for _ in range(n_layers): - layer = tl.ReversibleHalfResidual(tl.Dense(1024)) - layer.init(sig, rng=rng_init) - layer.weights = tl.on_cpu(layer.weights) # store weights in cpu memory - rev_layers.append(layer) - rev_layers.append(tl.ReversibleSwap()) - loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), - tl.LogSoftmax(), tl.CrossEntropyLoss()) - loss_layer.init((shape, shape, int_shape, int_shape)) - optimizer_fn = optimizers.SGD - - # Make a step with reversible trainer. - trainer = optimizers.ReversibleSerialTrainer( - [(first_layer, rev_layers)], loss_layer, optimizer_fn) - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - # We profile here. - t = time.time() - loss, _ = trainer.one_step(labeled_batch, rng_step) - self.assertLess(float(loss.sum()), 10000.0) # Just to get the loss. - print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss)) - - -if __name__ == '__main__': - config.config_with_absl() - absltest.main() diff --git a/trax/predict_drop.py b/trax/predict_drop.py deleted file mode 100644 index 0b00a12e9..000000000 --- a/trax/predict_drop.py +++ /dev/null @@ -1,327 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Prediction binary for the Drop task. - -Binary that loads a checkpoint and runs inference on selected problems -from the Drop dataset. For more details about Drop see -https://arxiv.org/pdf/1903.00161.pdf. 
-""" - -import json -import os -import re -import time - -from absl import app as absl_app -from absl import flags -import gin -import jax -import numpy as np -from seqio import vocabularies as t5_spc_vocab -from t5 import data -import tensorflow as tf -from trax import data as trax_data -from trax import layers as tl -from trax import shapes -from trax.supervised import decoding - - -FLAGS = flags.FLAGS - -flags.DEFINE_string('checkpoint_dir', '', - 'Path to model checkpoint.') -flags.DEFINE_integer('max_answer_len', 1024, - 'Maximum length of answers to produce.') -flags.DEFINE_integer('batch_size', 1, 'Batch size for eval.') -flags.DEFINE_integer('num_examples', 1, 'Number of examples to infer.') -flags.DEFINE_integer('n_hashes', None, - 'n_hashes parameter to override in attentions.') -flags.DEFINE_integer('example_repetitions', 1, - 'How many times to infer an example.') -flags.DEFINE_bool('use_eval_mode', False, - 'If True, use the slower but easier to debug eval mode.') -flags.DEFINE_bool('use_eval_set', False, - 'If True, use eval set for evaluation.') -flags.DEFINE_bool( - 'use_beam_search', False, - 'If True, use beam search, otherwise use autoregresive sampling.') -flags.DEFINE_float('autoregressive_sample_temp', 1, - 'The temperature for autoregressive sampling.') -flags.DEFINE_integer('n_beams', 4, 'How many beams to use in beam search.') -flags.DEFINE_string( - 'output_dir', '', 'Path to the output directory where articles, abstracts, ' - 'and predictions would be stored.') -flags.DEFINE_integer('starting_example', 0, - 'Example index for starting decoding.') -flags.DEFINE_integer('reload_after', 1000, - 'Reload checkpoint after reload_after examples.') -flags.DEFINE_multi_string('config_file', None, - 'Configuration file with parameters (.gin).') - - -def _check_exists(file_path): - if not tf.io.gfile.exists(file_path): - print('No such file: %s' % file_path, flush=True) - exit(1) - - -def multiply_examples(example): - for i in range(FLAGS.example_repetitions): - yield i, example - - -def prepare_model(model_file, batch_size=1): - """Prepare the model.""" - mode = 'eval' if FLAGS.use_eval_mode else 'predict' - print('Initializing the model in %s mode.' % mode, flush=True) - - # Read the model name from the gin file - model_reference = gin.query_parameter( - 'trax.supervised.trainer_lib.train.model') - model = model_reference.scoped_configurable_fn(mode=mode) - - dec_len = 32 if FLAGS.use_eval_mode else 1 - batch_size_pd = max(1, batch_size // jax.local_device_count()) - shape11 = shapes.ShapeDtype((batch_size_pd, dec_len), dtype=np.int32) - # shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model.init_from_file( - model_file, weights_only=True, input_signature=(shape11, shape11)) - model = tl.Accelerate(model) - - initial_state = model.state - vocab = t5_spc_vocab.SentencePieceVocabulary(data.DEFAULT_SPM_PATH) - - return vocab, model, initial_state - - -def is_number(s): - try: - float(s) - return True - except ValueError: - return False - - -def main(argv): - if len(argv) > 1: - raise absl_app.UsageError('Too many command-line arguments.') - if not FLAGS.output_dir: - raise absl_app.UsageError('--output_dir needs to be provided.') - - tf.compat.v1.enable_eager_execution() - - # Check that checkpoint_dir is correct: should contain model.pkl.gz file. 
- model_file = os.path.join(FLAGS.checkpoint_dir, 'model.pkl.gz') - _check_exists(model_file) - - gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, 'config.gin')) - # Batching on our own because of possible repetitions of examples. - gin.bind_parameter('data.Batch.batch_size', 1) - if FLAGS.n_hashes is not None: - gin.bind_parameter('LSHSelfAttention.n_hashes', FLAGS.n_hashes) - gin.bind_parameter('ref2_encoder/LSHSelfAttention.n_hashes', FLAGS.n_hashes) - - vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size) - - host_id, host_count = jax.host_id(), jax.host_count() - print('Running on host %d out of %d.' % (host_id, host_count)) - - example_count = 0 - start_time = time.time() - - # Creates all intermediate directories if they do not exist - tf.io.gfile.makedirs(FLAGS.output_dir) - - json_to_write = os.path.join(FLAGS.output_dir, 'output%d.json' % host_id) - all_jsons = [] - - # In a case of a reset we have to check how much work was already done. - # We can check whether the processing of an example was finished, but - # currently we are only checking whether it was started. - done = FLAGS.starting_example - reload_count = 0 - all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir) - for filename in all_existing_files: - if 'processing' in filename: - # The definition of digits looks for a number after the infix "processing" - # in the file name. Example: tom_processing_532 will lead to - # digits = "processing_532" and number equal to "532". - digits = filename[filename.find('processing'):] - number = ''.join(d for d in digits if d.isdigit()) - if is_number( - number) and int(number) < FLAGS.num_examples + FLAGS.starting_example: - done = max(done, int(number)) - print('The done number is {}'.format(done)) - - if FLAGS.use_eval_set: - drop_gen = trax_data.CreateDropInputs(train=False)() - else: - drop_gen = trax_data.CreateDropInputs(train=True)() - padding_fun = trax_data.PadToLength() - - # TODO(henrykm): improve managment of the counters. - # example_count_total - all numeric examples - # example_count - all numeric examples above starting_example - # reload_count - if we processed FLAGS.reload_after examples, - # then the checkpoint should be reloaded. 
-  # idx - total number of examples
-  example_count_total = 0
-  reload_count += 1
-  for idx, e in enumerate(drop_gen):
-    if reload_count >= FLAGS.reload_after:
-      vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)
-      reload_count = 0
-    if example_count >= FLAGS.num_examples:
-      print('Reached the example_count {} - breaking'.format(example_count))
-      break
-    if not is_number(e[1]):
-      continue
-    target_answer = float(e[1])
-
-    # We count numeric starting examples.
-    example_count_total += 1
-    if example_count_total <= FLAGS.starting_example:
-      print('Skipping example_count_total {} because it is below {}'.format(
-          example_count_total, FLAGS.starting_example))
-      continue
-
-    if example_count % 10 == 0:
-      elapsed_time = time.time() - start_time
-      start_time = time.time()
-      print('Starting inference on example %d, %.2fs since last log' %
-            (example_count, elapsed_time), flush=True)
-
-    example_count += 1
-    if example_count <= done - FLAGS.starting_example + 1:
-      print('Skipping example_count {} because it is below {}'.format(
-          example_count, done - FLAGS.starting_example))
-      # We are increasing the example_count because the example
-      # was processed before.
-      continue
-
-    if example_count % host_count != host_id:
-      continue
-
-    # At this point we are committed to the processing of an example with
-    # index example_count.
-    processing_file = os.path.join(FLAGS.output_dir, 'processing_')
-    data_id = str(example_count + FLAGS.starting_example)
-    with tf.io.gfile.GFile(processing_file + data_id, 'w') as w:
-      w.write('Processing started.')
-    for repetition_id, example in multiply_examples(e):
-      question = example[0]
-      question_text = question[question.find(':') + 2:]
-      question_text = question_text.replace('-', ' - ')
-      question = 'infer full calculation: ' + question_text
-
-      list_num = [
-          float(num.replace(',', '').rstrip('.')) for num in re.findall(
-              r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question)
-      ]
-      for i in range(len(list_num)):
-        question += ' n{} = {}'.format(i, list_num[i])
-
-      # print('Question {}'.format(question))
-      tokenized_question = next(
-          padding_fun(
-              trax_data.tokenize([
-                  question,
-              ],
-                                 vocab_file=gin.query_parameter(
-                                     'trax.data.Tokenize.vocab_file'))))
-      state = model.state
-      if FLAGS.use_beam_search:
-        answer_beams = decoding.beam_search(
-            model,
-            tokenized_question[None, :],
-            n_beams=FLAGS.n_beams,
-            max_length=FLAGS.max_answer_len,
-            accelerate=False)
-        model.state = state
-      else:
-        answer_beams = []
-        # We recycle the n_beams flag to control the number
-        # of autoregressive samples.
- for i in range(FLAGS.n_beams): - answer = decoding.autoregressive_sample( - model, - tokenized_question[None, :], - temperature=FLAGS.autoregressive_sample_temp, - max_length=FLAGS.max_answer_len, - accelerate=False) - model.state = state - answer_beams.append(answer) - - correct_example_index = -1 - - for i in range(len(answer_beams)): - if FLAGS.use_beam_search: - answer = trax_data.detokenize( - answer_beams[i][0][0], - vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file')) - else: - answer = trax_data.detokenize( - answer_beams[i][0], - vocab_file=gin.query_parameter('trax.data.Tokenize.vocab_file')) - print('Proposed computation {}'.format(answer)) - list_op = answer.split('|') - if not list_op[-1]: - list_op = list_op[:-1] - - try: - result = trax_data.tf_inputs.compute_result(list_op, list_num) - if target_answer in result: - correct_example_index = result.index(target_answer) - break - # This is a temporary hack with "broad" exceptions - the computations - # must fail sometime, because we evaluate arbitrary sequences; I am in - # the process of checking what are possible failure modes. - except Exception as e: # pylint: disable=broad-except - print(e) - try: - result = trax_data.tf_inputs.compute_result(list_op[:-1], list_num) - if target_answer in result: - correct_example_index = result.index(target_answer) - break - except Exception as e: # pylint: disable=broad-except - print(e) - print('Infered incorrect computation.') - - if correct_example_index == -1: - continue - - json_record = { - 'question': question_text, - 'input': question, - 'calculation': '|'.join(list_op[:correct_example_index + 1]), - 'target_answer': target_answer - } - all_jsons.append(json.dumps(json_record) + '\n') - # Outputting the inferred data in JSONL format. - data_id = str(example_count + FLAGS.starting_example) - with tf.io.gfile.GFile(json_to_write + data_id, 'w') as w: - w.write(json.dumps(json_record) + '\n') - with tf.io.gfile.GFile(processing_file + data_id, 'w') as w: - w.write('Procesing finished.') - - with tf.io.gfile.GFile(json_to_write + '_' + str(FLAGS.starting_example), - 'w') as w: - for record in all_jsons: - w.write(record) - - -if __name__ == '__main__': - absl_app.run(main) diff --git a/trax/rl/actor_critic.py b/trax/rl/actor_critic.py deleted file mode 100644 index f23988359..000000000 --- a/trax/rl/actor_critic.py +++ /dev/null @@ -1,1209 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
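
The question-augmentation step in the deleted predict_drop.py above pulls every number out of the question with a regex and appends them as named constants `n0, n1, ...` so the model can refer to them in its calculation. A condensed, runnable sketch of just that step, using the same regex; the sample question is made up for illustration:

```python
import re

question_text = 'How many more points did they score in 1997 than in 1995?'
question = 'infer full calculation: ' + question_text.replace('-', ' - ')
# Same number-matching regex as in the deleted binary.
list_num = [
    float(num.replace(',', '').rstrip('.')) for num in re.findall(
        r'[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?', question)
]
for i, n in enumerate(list_num):
    question += ' n{} = {}'.format(i, n)
print(question)
# -> infer full calculation: How many more points did they score in 1997
#    than in 1995? n0 = 1997.0 n1 = 1995.0
```
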
- -"""Classes for RL training in Trax.""" - -import functools -import os - -import gin -import gym -import numpy as np -import tensorflow as tf - -from trax import data -from trax import fastmath -from trax import layers as tl -from trax import shapes -from trax import supervised -from trax.fastmath import numpy as jnp -from trax.optimizers import adam -from trax.rl import advantages as rl_advantages -from trax.rl import distributions -from trax.rl import policy_tasks -from trax.rl import training as rl_training -from trax.rl import value_tasks -from trax.supervised import lr_schedules as lr - - -class ActorCriticAgent(rl_training.PolicyAgent): - """Trains policy and value models using actor-critic methods. - - Attrs: - on_policy (bool): Whether the algorithm is on-policy. Used in the data - generators. Should be set in derived classes. - """ - - on_policy = None - - def __init__(self, task, - value_model=None, - value_optimizer=None, - value_lr_schedule=lr.multifactor, - value_batch_size=64, - value_train_steps_per_epoch=500, - value_evals_per_epoch=1, - value_eval_steps=1, - n_shared_layers=0, - added_policy_slice_length=0, - n_replay_epochs=1, - scale_value_targets=False, - q_value=False, - q_value_aggregate='logsumexp', - q_value_temperature=1.0, - q_value_n_samples=1, - q_value_normalization=False, - offline=False, - **kwargs): # Arguments of PolicyAgent come here. - """Configures the actor-critic trainer. - - Args: - task: `RLTask` instance to use. - value_model: Model to use for the value function. - value_optimizer: Optimizer to train the value model. - value_lr_schedule: lr schedule for value model training. - value_batch_size: Batch size for value model training. - value_train_steps_per_epoch: Number of steps are we using to train the - value model in each epoch. - value_evals_per_epoch: Number of value trainer evaluations per RL epoch. - Every evaluation, we also synchronize the weights of the target - network. - value_eval_steps: Number of value trainer steps per evaluation; only - affects metric reporting. - n_shared_layers: Number of layers to share between value and policy - models. - added_policy_slice_length: How much longer should slices of - trajectories be for policy than for value training; this - is useful for TD calculations and only affect the length - of elements produced for policy batches; value batches - have maximum length set by `max_slice_length` in `**kwargs`. - n_replay_epochs: Number of last epochs to take into the replay buffer; - only makes sense for off-policy algorithms. - scale_value_targets: If `True`, scale value function targets by - `1 / (1 - gamma)`. - q_value: If `True`, use Q-values as baselines. - q_value_aggregate: How to aggregate Q-values. Options: 'mean', 'max', - 'softmax', 'logsumexp'. - q_value_temperature: Temperature parameter for the 'softmax' and - 'logsumexp' aggregation methods. - q_value_n_samples: Number of samples to average over when calculating - baselines based on Q-values. - q_value_normalization: How to normalize Q-values before aggregation. - Allowed values: 'std', 'abs', `None`. If `None`, don't normalize. - offline: Whether to train in offline mode. This matters for some - algorithms, e.g. QWR. - **kwargs: Arguments for `PolicyAgent` superclass. 
- """ - self._n_shared_layers = n_shared_layers - self._value_batch_size = value_batch_size - self._value_train_steps_per_epoch = value_train_steps_per_epoch - self._value_evals_per_epoch = value_evals_per_epoch - self._value_eval_steps = value_eval_steps - - # The 2 below will be initalized in super.__init__ anyway, but are needed - # to construct value batches which are needed before PolicyAgent init - # since policy input creation calls the value model -- hence this code. - self._task = task - self._max_slice_length = kwargs.get('max_slice_length', 1) - self._added_policy_slice_length = added_policy_slice_length - self._n_replay_epochs = n_replay_epochs - task.set_n_replay_epochs(n_replay_epochs) - - if scale_value_targets: - self._value_network_scale = 1 / (1 - self._task.gamma) - else: - self._value_network_scale = 1 - - self._q_value = q_value - self._q_value_aggregate = q_value_aggregate - self._q_value_temperature = q_value_temperature - self._q_value_n_samples = q_value_n_samples - self._q_value_normalization = q_value_normalization - - is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete) - self._is_discrete = is_discrete - self._vocab_size = None - self._sample_all_discrete_actions = False - if q_value and is_discrete: - self._vocab_size = self.task.action_space.n - # TODO(lukaszkaiser): the code below is specific to AWR, move it. - # If n_samples = n_actions, we'll take them all in actor and reweight. - if self._q_value_n_samples == self._vocab_size: - # TODO(lukaszkaiser): set this explicitly once it's in AWR Trainer. - self._sample_all_discrete_actions = True - if offline and is_discrete: - raise NotImplementedError( - 'Offline training is only supported for continuous action spaces for ' - 'now.' - ) - self._offline = offline - - if q_value: - value_model = functools.partial(value_model, - inject_actions=True, - is_discrete=is_discrete, - vocab_size=self._vocab_size) - self._value_eval_model = value_model(mode='eval') - self._value_eval_model.init(self._value_model_signature) - self._value_eval_jit = tl.jit_forward( - self._value_eval_model.pure_fn, fastmath.local_device_count(), - do_mean=False) - - # Initialize policy training. - super().__init__(task, **kwargs) - - # Initialize training of the value function. - value_output_dir = kwargs.get('output_dir', None) - if value_output_dir is not None: - value_output_dir = os.path.join(value_output_dir, 'value') - # If needed, create value_output_dir and missing parent directories. 
- if not tf.io.gfile.isdir(value_output_dir):
- tf.io.gfile.makedirs(value_output_dir)
- self._value_inputs = data.inputs.Inputs(
- train_stream=lambda _: self.value_batches_stream())
- self._value_trainer = supervised.Trainer(
- model=value_model,
- optimizer=value_optimizer,
- lr_schedule=value_lr_schedule(),
- loss_fn=tl.L2Loss(),
- inputs=self._value_inputs,
- output_dir=value_output_dir,
- metrics={'value_loss': tl.L2Loss(),
- 'value_mean': self.value_mean})
-
- @property
- def value_mean(self):
- """The mean value of the value function."""
- # TODO(henrykm): A better solution would take into account the masks
- def f(values):
- return jnp.mean(values)
- return tl.Fn('ValueMean', f)
-
- @property
- def _value_model_signature(self):
- obs_sig = shapes.signature(self._task.observation_space)
- target_sig = mask_sig = shapes.ShapeDtype(
- shape=(1, 1, 1),
- )
- inputs_sig = (obs_sig.replace(shape=(1, 1) + obs_sig.shape),)
- if self._q_value:
- act_sig = shapes.signature(self._task.action_space)
- inputs_sig += (act_sig.replace(shape=(1, 1) + act_sig.shape),)
- return (*inputs_sig, target_sig, mask_sig)
-
- @property
- def _replay_epochs(self):
- if self.on_policy:
- assert self._n_replay_epochs == 1, (
- 'Non-unit replay buffer size only makes sense for off-policy '
- 'algorithms.'
- )
- return [-(ep + 1) for ep in range(self._n_replay_epochs)]
-
- def _run_value_model(self, observations, dist_inputs):
- if dist_inputs is None:
- dist_inputs = jnp.zeros(
- observations.shape[:2] + (self._policy_dist.n_inputs,)
- )
-
- actions = None
- if self._q_value:
- if self._sample_all_discrete_actions:
- # Since we want to sample all actions, start by creating their list.
- act = np.arange(self._vocab_size)
- # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
- # Add extra dimensions so it's the same dimensionality as dist_inputs.
- act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
- # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
- dist_inputs = jnp.broadcast_to(
- dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
- if self._sample_all_discrete_actions:
- actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
- actions = jnp.swapaxes(actions, 0, 1)
- # Swapping the n_samples and batch_size axes, so the input is split
- # between accelerators along the batch_size axis.
- dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
- if not self._sample_all_discrete_actions:
- actions = self._policy_dist.sample(dist_inputs)
- log_probs = self._policy_dist.log_prob(dist_inputs, actions)
- obs = observations
- obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
- inputs = (obs, actions)
- else:
- log_probs = None
- inputs = (observations,)
-
- n_devices = fastmath.local_device_count()
- weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
- state = tl.for_n_devices(self._value_eval_model.state, n_devices)
- rng = self._value_eval_model.rng
- values, _ = self._value_eval_jit(inputs, weights, state, rng)
- values *= self._value_network_scale
- values = jnp.squeeze(values, axis=-1) # Remove the singleton depth dim.
- return (values, actions, log_probs)
-
- def _aggregate_values(self, values, aggregate, act_log_probs):
- # Normalize the Q-values before aggregation, so it can adapt to the scale
- # of the returns. This does not affect mean and max aggregation.
- scale = 1 - epsilon = 1e-5 - if self._q_value_normalization == 'std': - scale = jnp.std(values) + epsilon - elif self._q_value_normalization == 'abs': - scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon - values /= scale - - temp = self._q_value_temperature - if self._q_value: - assert values.shape[:2] == ( - self._value_batch_size, self._q_value_n_samples - ) - if aggregate == 'max': - # max_a Q(s, a) - values = jnp.max(values, axis=1) - elif aggregate == 'softmax': - # sum_a (Q(s, a) * w(s, a)) - # where w(s, .) = softmax (Q(s, .) / T) - weights = tl.Softmax(axis=1)(values / temp) - values = jnp.sum(values * weights, axis=1) - elif aggregate == 'logsumexp': - # log(mean_a exp(Q(s, a) / T)) * T - n = values.shape[1] - values = (fastmath.logsumexp(values / temp, axis=1) - jnp.log(n)) * temp - else: - assert aggregate == 'mean' - # mean_a Q(s, a) - if self._sample_all_discrete_actions: - values = jnp.sum(values * jnp.exp(act_log_probs), axis=1) - else: - values = jnp.mean(values, axis=1) - - # Re-scale the Q-values after aggregation. - values *= scale - return np.array(values) # Move the values to CPU. - - def _get_dist_inputs(self, trajectory): - if not self._offline: - return trajectory.dist_inputs - else: - return trajectory.action - - def value_batches_stream(self): - """Use the RLTask self._task to create inputs to the value model.""" - max_slice_length = self._max_slice_length + self._added_policy_slice_length - for np_trajectory in self._task.trajectory_batch_stream( - self._value_batch_size, - max_slice_length=max_slice_length, - min_slice_length=(1 + self._added_policy_slice_length), - margin=self._added_policy_slice_length, - epochs=self._replay_epochs, - ): - dist_inputs = self._get_dist_inputs(np_trajectory) - (values, _, act_log_probs) = self._run_value_model( - np_trajectory.observation, dist_inputs - ) - values = self._aggregate_values( - values, self._q_value_aggregate, act_log_probs) - - # TODO(pkozakowski): Add some shape assertions and docs. - # Calculate targets based on the advantages over the target network - this - # allows TD learning for value networks. - advantages = self._advantage_estimator( - rewards=np_trajectory.reward, - returns=np_trajectory.return_, - values=values, - dones=np_trajectory.done, - discount_mask=np_trajectory.env_info.discount_mask, - ) - length = advantages.shape[1] - values = values[:, :length] - target_returns = values + advantages - - inputs = (np_trajectory.observation[:, :length],) - if self._q_value: - inputs += (np_trajectory.action[:, :length],) - - # Insert an extra depth dimension, so the target shape is consistent with - # the network output shape. - yield ( - # Inputs: observations and maybe actions. - *inputs, - # Targets: computed returns. - target_returns[:, :, None] / self._value_network_scale, - # Mask to zero-out padding. - np_trajectory.mask[:, :length, None], - ) - - def policy_inputs(self, trajectory, values): - """Create inputs to policy model from a TimeStepBatch and values. - - Args: - trajectory: a TimeStepBatch, the trajectory to create inputs from - values: a numpy array: value function computed on trajectory - - Returns: - a tuple of numpy arrays of the form (inputs, x1, x2, ...) that will be - passed to the policy model; policy model will compute outputs from - inputs and (outputs, x1, x2, ...) will be passed to self.policy_loss - which should be overridden accordingly. 
- """ - return NotImplementedError - - def policy_batches_stream(self): - """Use the RLTask self._task to create inputs to the policy model.""" - # Maximum slice length for policy is max_slice_len + the added policy len. - max_slice_length = self._max_slice_length + self._added_policy_slice_length - for np_trajectory in self._task.trajectory_batch_stream( - self._policy_batch_size, - epochs=self._replay_epochs, - max_slice_length=max_slice_length, - margin=self._added_policy_slice_length, - ): - dist_inputs = self._get_dist_inputs(np_trajectory) - (values, _, act_log_probs) = self._run_value_model( - np_trajectory.observation, dist_inputs) - values = self._aggregate_values(values, 'mean', act_log_probs) - if len(values.shape) != 2: - raise ValueError('Values are expected to have shape ' + - '[batch_size, length], got: %s' % str(values.shape)) - if values.shape[0] != self._policy_batch_size: - raise ValueError('Values first dimension should = policy batch size, ' + - '%d != %d' %(values.shape[0], self._policy_batch_size)) - yield self.policy_inputs(np_trajectory, values) - - def train_epoch(self): - """Trains RL for one epoch.""" - # Copy policy state accumulated during data collection to the trainer. - self._policy_trainer.model_state = self._policy_collect_model.state - - # Copy policy weights and state to value trainer. - if self._n_shared_layers > 0: - _copy_model_weights_and_state( - 0, self._n_shared_layers, self._policy_trainer, self._value_trainer - ) - - # Update the target value network. - self._value_eval_model.weights = self._value_trainer.model_weights - self._value_eval_model.state = self._value_trainer.model_state - - n_value_evals = rl_training.remaining_evals( - self._value_trainer.step, - self._epoch, - self._value_train_steps_per_epoch, - self._value_evals_per_epoch) - for _ in range(n_value_evals): - self._value_trainer.train_epoch( - self._value_train_steps_per_epoch // self._value_evals_per_epoch, - self._value_eval_steps, - ) - # Update the target value network. - self._value_eval_model.weights = self._value_trainer.model_weights - self._value_eval_model.state = self._value_trainer.model_state - - # Copy value weights and state to policy trainer. - if self._n_shared_layers > 0: - _copy_model_weights_and_state( - 0, self._n_shared_layers, self._value_trainer, self._policy_trainer - ) - n_policy_evals = rl_training.remaining_evals( - self._policy_trainer.step, - self._epoch, - self._policy_train_steps_per_epoch, - self._policy_evals_per_epoch) - # Check if there was a restart after value training finishes and policy not. - stopped_after_value = (n_value_evals == 0 and - n_policy_evals < self._policy_evals_per_epoch) - should_copy_weights = self._n_shared_layers > 0 and not stopped_after_value - if should_copy_weights: - _copy_model_weights_and_state( - 0, self._n_shared_layers, self._value_trainer, self._policy_trainer - ) - - # Update the target value network. 
- self._value_eval_model.weights = self._value_trainer.model_weights - self._value_eval_model.state = self._value_trainer.model_state - - for _ in range(n_policy_evals): - self._policy_trainer.train_epoch( - self._policy_train_steps_per_epoch // self._policy_evals_per_epoch, - self._policy_eval_steps, - ) - - def close(self): - self._value_trainer.close() - super().close() - - -def _copy_model_weights_and_state( # pylint: disable=invalid-name - start, end, from_trainer, to_trainer, copy_optimizer_slots=False -): - """Copy model weights[start:end] from from_trainer to to_trainer.""" - from_weights = from_trainer.model_weights - to_weights = list(to_trainer.model_weights) - shared_weights = from_weights[start:end] - to_weights[start:end] = shared_weights - to_trainer.model_weights = to_weights - - from_state = from_trainer.model_state - to_state = list(to_trainer.model_state) - shared_state = from_state[start:end] - to_state[start:end] = shared_state - to_trainer.model_state = to_state - - if copy_optimizer_slots: - # TODO(lukaszkaiser): make a nicer API in Trainer to support this. - # Currently we use the hack below. Note [0] since that's the model w/o loss. - # pylint: disable=protected-access - from_slots = from_trainer._opt_state.slots[0][start:end] - to_slots = to_trainer._opt_state.slots[0] - # The lines below do to_slots[start:end] = from_slots, but on tuples. - new_slots = to_slots[:start] + from_slots[start:end] + to_slots[end:] - new_slots = tuple([new_slots] + list(to_trainer._opt_state.slots[1:])) - to_trainer._opt_state = to_trainer._opt_state._replace(slots=new_slots) - # pylint: enable=protected-access - - -class AdvantageBasedActorCriticAgent(ActorCriticAgent): - """Base class for advantage-based actor-critic algorithms.""" - - def __init__( - self, - task, - advantage_estimator=rl_advantages.td_lambda, - advantage_normalization=True, - advantage_normalization_epsilon=1e-5, - advantage_normalization_factor=1.0, - added_policy_slice_length=0, - **kwargs - ): - self._advantage_estimator = advantage_estimator( - gamma=task.gamma, margin=added_policy_slice_length - ) - self._advantage_normalization = advantage_normalization - self._advantage_normalization_epsilon = advantage_normalization_epsilon - self._advantage_normalization_factor = advantage_normalization_factor - super().__init__( - task, added_policy_slice_length=added_policy_slice_length, **kwargs - ) - - def policy_inputs(self, trajectory, values): - """Create inputs to policy model from a TimeStepBatch and values.""" - # How much TD to use is determined by the added policy slice length, - # as the policy batches need to be this much longer to calculate TD. - advantages = self._advantage_estimator( - rewards=trajectory.reward, - returns=trajectory.return_, - values=values, - dones=trajectory.done, - discount_mask=trajectory.env_info.discount_mask, - ) - # Observations should be the same length as advantages - so if we are - # using n_extra_steps, we need to trim the length to match. - obs = trajectory.observation[:, :advantages.shape[1]] - act = trajectory.action[:, :advantages.shape[1]] - mask = trajectory.mask[:, :advantages.shape[1]] # Mask to zero-out padding. - if trajectory.dist_inputs is not None: - dist_inputs = self._get_dist_inputs(trajectory) - dist_inputs = dist_inputs[:, :advantages.shape[1]] - else: - dist_inputs = jnp.zeros(advantages.shape + (self._policy_dist.n_inputs,)) - # Shape checks to help debugging. 
- if len(advantages.shape) != 2: - raise ValueError('Advantages are expected to have shape ' + - '[batch_size, length], got: %s' % str(advantages.shape)) - if act.shape[0:2] != advantages.shape: - raise ValueError('First 2 dimensions of actions should be the same as in ' - 'advantages, %s != %s' % (act.shape[0:2], - advantages.shape)) - if obs.shape[0:2] != advantages.shape: - raise ValueError('First 2 dimensions of observations should be the same ' - 'as in advantages, %s != %s' % (obs.shape[0:2], - advantages.shape)) - if dist_inputs.shape[:2] != advantages.shape: - raise ValueError('First 2 dimensions of dist_inputs should be the same ' - 'as in advantages, %s != %s' % (dist_inputs.shape[:2], - advantages.shape)) - if mask.shape != advantages.shape: - raise ValueError('Mask and advantages shapes should be the same' - ', %s != %s' % (mask.shape, advantages.shape)) - return (obs, act, advantages, dist_inputs, mask) - - @property - def policy_loss_given_log_probs(self): - """Policy loss given action log-probabilities.""" - raise NotImplementedError - - def _preprocess_advantages(self, advantages): - if self._advantage_normalization: - advantages = self._advantage_normalization_factor * ( - (advantages - jnp.mean(advantages)) / - (jnp.std(advantages) + self._advantage_normalization_epsilon) - ) - return advantages - - @property - def policy_loss(self, **unused_kwargs): - """Policy loss.""" - def LossInput(dist_inputs, actions, advantages, old_dist_inputs): # pylint: disable=invalid-name - """Calculates action log probabilities and normalizes advantages.""" - advantages = self._preprocess_advantages(advantages) - log_probs = self._policy_dist.log_prob(dist_inputs, actions) - old_log_probs = self._policy_dist.log_prob(old_dist_inputs, actions) - return (log_probs, advantages, old_log_probs) - - return tl.Serial( - tl.Fn('LossInput', LossInput, n_out=3), - # Policy loss is expected to consume - # (log_probs, advantages, old_log_probs, mask). - self.policy_loss_given_log_probs, - ) - - @property - def policy_metrics(self): - metrics = super().policy_metrics - metrics.update({ - 'advantage_mean': self.advantage_mean, - 'advantage_std': self.advantage_std, - }) - return metrics - - @property - def advantage_mean(self): - return tl.Serial([ - # (dist_inputs, advantages, old_dist_inputs, mask) - tl.Select([1]), # Select just the advantages. - tl.Fn('AdvantageMean', lambda x: jnp.mean(x)), # pylint: disable=unnecessary-lambda - ]) - - @property - def advantage_std(self): - return tl.Serial([ - # (dist_inputs, advantages, old_dist_inputs, mask) - tl.Select([1]), # Select just the advantages. - tl.Fn('AdvantageStd', lambda x: jnp.std(x)), # pylint: disable=unnecessary-lambda - ]) - - -# TODO(pkozakowski): Move to a better place. -@gin.configurable(module='trax.rl') -def every(n_steps): - """Returns True every n_steps, for use as *_at functions in various places.""" - return lambda step: step % n_steps == 0 - - -# TODO(pkozakowski): Rewrite all interleaved actor-critic algos to subclass -# this, then rename to ActorCriticAgent and remove the other base classes. 
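As a side note, the `_preprocess_advantages` method above standardizes advantages before they enter the policy loss. A standalone numpy sketch of that normalization, mirroring the defaults (`factor=1.0`, `epsilon=1e-5`):

```python
# Sketch of the advantage normalization from _preprocess_advantages above,
# in plain numpy; constants mirror the class defaults.
import numpy as np

def normalize_advantages(advantages, factor=1.0, epsilon=1e-5):
  # Shift to zero mean, scale to roughly unit variance, then rescale.
  return factor * (advantages - np.mean(advantages)) / (
      np.std(advantages) + epsilon)

adv = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(normalize_advantages(adv))  # zero-mean, ~unit-std tensor
```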
-class LoopActorCriticAgent(rl_training.Agent):
- """Base class for actor-critic algorithms based on `Loop`."""
-
- on_policy = None
-
- def __init__(
- self, task, model_fn,
- optimizer=adam.Adam,
- policy_lr_schedule=lr.multifactor,
- policy_n_steps_per_epoch=1000,
- policy_weight_fn=(lambda x: x),
- value_lr_schedule=lr.multifactor,
- value_n_steps_per_epoch=1000,
- value_sync_at=(lambda x: x % 100 == 0),
- advantage_estimator=rl_advantages.monte_carlo,
- batch_size=64,
- network_eval_at=None,
- n_eval_batches=1,
- max_slice_length=1,
- margin=0,
- n_replay_epochs=1,
- **kwargs
- ):
- """Initializes LoopActorCriticAgent.
-
- Args:
- task: `RLTask` instance to use.
- model_fn: Function mode -> Trax model, building a joint policy and value
- network.
- optimizer: Optimizer for the policy and value networks.
- policy_lr_schedule: Learning rate schedule for the policy network.
- policy_n_steps_per_epoch: Number of steps to train the policy network for
- in each epoch.
- policy_weight_fn: Function advantages -> weights for calculating the
- log probability weights in policy training.
- value_lr_schedule: Learning rate schedule for the value network.
- value_n_steps_per_epoch: Number of steps to train the value network for
- in each epoch.
- value_sync_at: Function step -> bool indicating when to synchronize the
- target network with the trained network in value training.
- advantage_estimator: Advantage estimator to use in policy and value
- training.
- batch_size: Batch size for training the networks.
- network_eval_at: Function step -> bool indicating when to evaluate the
- networks.
- n_eval_batches: Number of batches to compute the network evaluation
- metrics on.
- max_slice_length: Maximum length of a trajectory slice to train on.
- margin: Number of timesteps to add at the end of each trajectory slice for
- better advantage estimation.
- n_replay_epochs: Number of epochs of trajectories to store in the replay
- buffer.
- **kwargs: Keyword arguments forwarded to Agent.
- """ - super().__init__(task, **kwargs) - - self._policy_dist = distributions.create_distribution( - self.task.action_space - ) - model_fn = functools.partial( - model_fn, - policy_distribution=self._policy_dist, - ) - train_model = model_fn(mode='train') - eval_model = model_fn(mode='eval') - - trajectory_batch_stream = self._init_trajectory_batch_stream( - batch_size, max_slice_length, margin, n_replay_epochs - ) - advantage_estimator = advantage_estimator(task.gamma, margin=margin) - (value_train_task, value_eval_task) = self._init_value_tasks( - trajectory_batch_stream, - optimizer=optimizer(), - lr_schedule=value_lr_schedule(), - advantage_estimator=advantage_estimator, - train_model=train_model, - eval_model=eval_model, - sync_at=value_sync_at, - n_steps_per_epoch=value_n_steps_per_epoch, - n_eval_batches=n_eval_batches, - ) - (policy_train_task, policy_eval_task) = self._init_policy_tasks( - trajectory_batch_stream, - optimizer=optimizer(), - lr_schedule=policy_lr_schedule(), - advantage_estimator=advantage_estimator, - value_train_task=value_train_task, - weight_fn=policy_weight_fn, - n_eval_batches=n_eval_batches, - ) - self._init_loop( - train_model=train_model, - eval_model=eval_model, - policy_train_and_eval_task=(policy_train_task, policy_eval_task), - value_train_and_eval_task=(value_train_task, value_eval_task), - eval_at=network_eval_at, - policy_n_steps_per_epoch=policy_n_steps_per_epoch, - value_n_steps_per_epoch=value_n_steps_per_epoch, - ) - self._init_collection(model_fn, policy_train_task.sample_batch) - - def _init_trajectory_batch_stream( - self, batch_size, max_slice_length, margin, n_replay_epochs - ): - assert self.on_policy is not None, 'Attribute "on_policy" not set.' - if self.on_policy: - assert n_replay_epochs == 1, ( - 'Non-unit replay buffer size only makes sense for off-policy ' - 'algorithms.' - ) - self._task.set_n_replay_epochs(n_replay_epochs) - self._max_slice_length = max_slice_length - return self._task.trajectory_batch_stream( - batch_size, - epochs=[-(ep + 1) for ep in range(n_replay_epochs)], - min_slice_length=(1 + margin), - max_slice_length=(self._max_slice_length + margin), - margin=margin, - ) - - def _init_value_tasks( - self, - trajectory_batch_stream, - optimizer, - lr_schedule, - advantage_estimator, - train_model, - eval_model, - sync_at, - n_steps_per_epoch, - n_eval_batches, - ): - def sync_also_at_epoch_boundaries(step): - return sync_at(step) or ( - # 0 - end of the epoch, 1 - beginning of the next. 
- step % n_steps_per_epoch in (0, 1)
- )
-
- head_selector = tl.Select([1])
- value_train_task = value_tasks.ValueTrainTask(
- trajectory_batch_stream,
- optimizer,
- lr_schedule,
- advantage_estimator=advantage_estimator,
- model=train_model,
- target_model=eval_model,
- target_scale=(1 - self.task.gamma),
- sync_at=sync_also_at_epoch_boundaries,
- head_selector=head_selector,
- )
- value_eval_task = value_tasks.ValueEvalTask(
- value_train_task, n_eval_batches, head_selector
- )
- return (value_train_task, value_eval_task)
-
- def _init_policy_tasks(
- self,
- trajectory_batch_stream,
- optimizer,
- lr_schedule,
- advantage_estimator,
- value_train_task,
- weight_fn,
- n_eval_batches,
- ):
- head_selector = tl.Select([0], n_in=2)
- policy_train_task = policy_tasks.PolicyTrainTask(
- trajectory_batch_stream,
- optimizer,
- lr_schedule,
- self._policy_dist,
- advantage_estimator=advantage_estimator,
- value_fn=value_train_task.value,
- weight_fn=weight_fn,
- head_selector=head_selector,
- )
- policy_eval_task = policy_tasks.PolicyEvalTask(
- policy_train_task, n_eval_batches, head_selector
- )
- return (policy_train_task, policy_eval_task)
-
- def _init_loop(
- self,
- train_model,
- eval_model,
- policy_train_and_eval_task,
- value_train_and_eval_task,
- eval_at,
- policy_n_steps_per_epoch,
- value_n_steps_per_epoch,
- ):
- (policy_train_task, policy_eval_task) = policy_train_and_eval_task
- (value_train_task, value_eval_task) = value_train_and_eval_task
-
- if self._output_dir is not None:
- model_output_dir = os.path.join(self._output_dir, 'model')
- else:
- model_output_dir = None
-
- self._n_train_steps_per_epoch = (
- policy_n_steps_per_epoch + value_n_steps_per_epoch
- )
-
- checkpoint_at = lambda step: step % self._n_train_steps_per_epoch == 0
-
- def which_task(step):
- if step % self._n_train_steps_per_epoch < value_n_steps_per_epoch:
- return 1
- else:
- return 0
-
- self._loop = supervised.training.Loop(
- model=train_model,
- tasks=(policy_train_task, value_train_task),
- eval_model=eval_model,
- eval_tasks=(policy_eval_task, value_eval_task),
- output_dir=model_output_dir,
- eval_at=eval_at,
- checkpoint_at=checkpoint_at,
- which_task=which_task,
- )
-
- # Validate the restored checkpoints.
- # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
- if self._loop.step != self._epoch * self._n_train_steps_per_epoch:
- raise ValueError(
- 'The number of Loop steps must equal the number of Agent epochs '
- 'times the number of steps per epoch, got {}, {} and {}.'.format(
- self._loop.step, self._epoch, self._n_train_steps_per_epoch
- )
- )
-
- def _init_collection(self, model_fn, sample_batch):
- self._collect_model = model_fn(mode='collect')
- self._collect_model.init(shapes.signature(sample_batch))
-
- @property
- def loop(self):
- """Loop exposed for testing."""
- return self._loop
-
- def policy(self, trajectory, temperature=1.0):
- """Policy function that allows playing with this agent."""
- tr_slice = trajectory.suffix(self._max_slice_length)
- trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
- return rl_training.network_policy(
- collect_model=self._collect_model,
- policy_distribution=self._policy_dist,
- loop=self.loop,
- trajectory_np=trajectory_np,
- head_index=0,
- temperature=temperature,
- )
-
- def train_epoch(self):
- """Trains RL for one epoch."""
- # Copy policy state accumulated during data collection to the trainer.
- self._loop.update_weights_and_state(state=self._collect_model.state) - # Perform one gradient step per training epoch to ensure we stay on policy. - self._loop.run(n_steps=self._n_train_steps_per_epoch) - - -### Implementations of common actor-critic algorithms. - - -class A2C(AdvantageBasedActorCriticAgent): - """Trains policy and value models using the A2C algorithm.""" - - on_policy = True - - def __init__(self, task, entropy_coeff=0.01, **kwargs): - """Configures the A2C Trainer.""" - self._entropy_coeff = entropy_coeff - super().__init__(task, **kwargs) - - @property - def policy_loss_given_log_probs(self): - """Definition of the Advantage Actor Critic (A2C) loss.""" - # A2C is one of the most basic actor-critic RL algorithms. - # TODO(henrykm) re-factor f into rl_layers and finally share code between - # actor_critic.py and actor_critic_joint.py - requires change of inputs - # in actor_critic_joint.py from dist_inputs to log_probs. - def f(log_probs, advantages, old_log_probs, mask): - del old_log_probs # Not used in A2C. - # log_probs of the shape float32[128,1] - # advantages of the shape int32[128,1] - # mask of the shape int32[128,1] - if log_probs.shape != advantages.shape: - raise ValueError('New log-probs and advantages shapes ' - 'should be the same, %s != %s' % (log_probs.shape, - advantages.shape)) - if log_probs.shape != mask.shape: - raise ValueError('New log-probs and mask shapes should be the same' - ', %s != %s' % (log_probs.shape, mask.shape)) - - a2c_objective = -jnp.sum(log_probs * advantages * mask) / jnp.sum(mask) - - entropy_vec = self._policy_dist.entropy(log_probs) * self._entropy_coeff - entropy_loss = jnp.mean(entropy_vec) - - combined_loss = a2c_objective - entropy_loss - - return combined_loss - - return tl.Fn('A2CLoss', f) - - -class PPO(AdvantageBasedActorCriticAgent): - """The Proximal Policy Optimization Algorithm aka PPO. - - Trains policy and value models using the PPO algorithm. 
- """ - - on_policy = True - - def __init__(self, task, epsilon=0.2, entropy_coeff=0.01, **kwargs): - """Configures the PPO Trainer.""" - self._entropy_coeff = entropy_coeff - self._epsilon = epsilon - super().__init__(task, **kwargs) - - @property - def policy_loss_given_log_probs(self): - """Definition of the Proximal Policy Optimization loss.""" - def f(new_log_probs, advantages, old_log_probs, mask): - # new_log_probs of the shape float32[128,1] - # advantages of the shape int32[128,1] - # old_log_probs of the shape int32[128,1] - # mask of the shape int32[128,1] - if new_log_probs.shape != advantages.shape: - raise ValueError('New log-probs and advantages shapes ' - 'should be the same, %s != %s' % (new_log_probs.shape, - advantages.shape)) - if new_log_probs.shape != old_log_probs.shape: - raise ValueError('New log-probs and old log-probs shapes ' - 'should be the same, %s != %s' % (new_log_probs.shape, - old_log_probs.shape)) - if new_log_probs.shape != mask.shape: - raise ValueError('New log-probs and mask shapes should be the same' - ', %s != %s' % (new_log_probs.shape, mask.shape)) - - # The ratio between new_probs and old_probs expressed - # using log_probs and exponentiation - probs_ratio = jnp.exp(new_log_probs - old_log_probs) - if advantages.shape != probs_ratio.shape: - raise ValueError('New log-probs and old log probs shapes ' - 'should be the same, %s != %s' % (advantages.shape, - probs_ratio.shape)) - unclipped_objective = probs_ratio * advantages - clipped_objective = jnp.clip(probs_ratio, - 1 - self._epsilon, - 1 + self._epsilon) * advantages - - if unclipped_objective.shape != probs_ratio.shape: - raise ValueError('unclipped_objective and clipped_objective shapes ' - 'should be the same, %s != %s' % ( - unclipped_objective.shape, - clipped_objective.shape)) - - ppo_objective = jnp.minimum(unclipped_objective, clipped_objective) - - if ppo_objective.shape != mask.shape: - raise ValueError('ppo_objective and mask shapes ' - 'should be the same, %s != %s' % ( - ppo_objective.shape, - mask.shape)) - - ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask) - entropy_vec = self._policy_dist.entropy( - new_log_probs) * self._entropy_coeff - entropy_loss = jnp.mean(entropy_vec) - combined_loss = ppo_loss - entropy_loss - - return combined_loss - return tl.Fn('PPOLoss', f) - - -def _weighted_percentiles(x, thresholds): - """Calculate weights for x by percentile-and-weights given in thresholds. - - Thresholds contain a list of (p, weight, minumum). For each threshold, - all elements of x that are above the p-th percentile *and* above minimum - get the weight weight, and all other get the weight 0. - The result is the sum over all thresholds. - - Args: - x: tensor to calculate the weights for - thresholds: list of triples (percentile, weight, minimum) used to - calculate the weights (see above how) - - Returns: - weights, a tensor of the same shape as x - """ - res = [] - for (percentile, weight, minimum) in thresholds: - threshold = jnp.percentile(x, percentile) - if minimum is not None: - threshold = jnp.maximum(minimum, threshold) - zero_ones = jnp.where(x < threshold, jnp.zeros_like(x), jnp.ones_like(x)) - res.append(weight * zero_ones) - return sum(res) - - -# AWR is an off-policy actor-critic RL algorithm. -def awr_weights(advantages, beta, thresholds): - if thresholds: - return _weighted_percentiles(advantages, thresholds) - return jnp.exp(advantages / beta) - - -# Helper functions for computing AWR metrics. 
-def awr_metrics(beta, thresholds, preprocess_layer=None): - return { # pylint: disable=g-complex-comprehension - 'awr_weight_' + name: awr_weight_stat(name, fn, beta, thresholds, - preprocess_layer) - for (name, fn) in [ - ('mean', jnp.mean), - ('std', jnp.std), - ('min', jnp.min), - ('max', jnp.max), - ] - } - - -def awr_weight_stat(stat_name, stat_fn, beta, thresholds, preprocess_layer): - # Select just the advantages if preprocess layer is not given. - preprocess = tl.Select([1]) if preprocess_layer is None else preprocess_layer - return tl.Serial([ - preprocess, - tl.Fn( - 'AWRWeight' + stat_name.capitalize(), - lambda x: stat_fn(awr_weights(x, beta, thresholds)), - ), - ]) - - -def AWRLoss(beta, w_max, thresholds): # pylint: disable=invalid-name - """Definition of the Advantage Weighted Regression (AWR) loss.""" - def f(log_probs, advantages, old_log_probs, mask): - del old_log_probs # Not used in AWR. - weights = jnp.minimum(awr_weights(advantages, beta, thresholds), w_max) - return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask) - return tl.Fn('AWRLoss', f) - - -class AWR(AdvantageBasedActorCriticAgent): - """Trains policy and value models using AWR.""" - - on_policy = False - - def __init__(self, task, beta=1.0, w_max=20.0, thresholds=None, **kwargs): - """Configures the AWR Trainer.""" - self._beta = beta - self._w_max = w_max - self._thresholds = thresholds - super().__init__(task, **kwargs) - - @property - def policy_loss_given_log_probs(self): - """Policy loss.""" - return AWRLoss(beta=self._beta, w_max=self._w_max, - thresholds=self._thresholds) # pylint: disable=no-value-for-parameter - - -class LoopAWR(LoopActorCriticAgent): - """Advantage Weighted Regression.""" - - on_policy = False - - def __init__(self, task, model_fn, beta=1.0, w_max=20, **kwargs): - def policy_weight_fn(advantages): - return jnp.minimum(jnp.exp(advantages / beta), w_max) - super().__init__( - task, model_fn, policy_weight_fn=policy_weight_fn, **kwargs - ) - - -def SamplingAWRLoss(beta, w_max, thresholds, # pylint: disable=invalid-name - reweight=False, sampled_all_discrete=False): - """Definition of the Advantage Weighted Regression (AWR) loss.""" - def f(log_probs, advantages, old_log_probs, mask): - if reweight: # Use new policy weights for sampled actions instead. - mask *= jnp.exp(fastmath.stop_gradient(log_probs) - old_log_probs) - if sampled_all_discrete: # Actions were sampled uniformly; weight them. 
- mask *= jnp.exp(old_log_probs) - weights = jnp.minimum(awr_weights(advantages, beta, thresholds), w_max) - return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask) - return tl.Fn('SamplingAWRLoss', f) - - -class SamplingAWR(AdvantageBasedActorCriticAgent): - """Trains policy and value models using Sampling AWR.""" - - on_policy = False - - def __init__(self, task, beta=1.0, w_max=20.0, thresholds=None, - reweight=False, **kwargs): - """Configures the AWR Trainer.""" - self._beta = beta - self._w_max = w_max - self._thresholds = thresholds - self._reweight = reweight - super().__init__(task, q_value=True, **kwargs) - - def _policy_inputs_to_advantages(self, preprocess): - """A layer that computes advantages from policy inputs.""" - def fn(dist_inputs, actions, q_values, act_log_probs, mask): - del dist_inputs, actions, mask - q_values = jnp.swapaxes(q_values, 0, 1) - act_log_probs = jnp.swapaxes(act_log_probs, 0, 1) - if self._sample_all_discrete_actions: - values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0) - else: - values = jnp.mean(q_values, axis=0) - advantages = q_values - values # Broadcasting values over n_samples - if preprocess: - advantages = self._preprocess_advantages(advantages) - return advantages - return tl.Fn('PolicyInputsToAdvantages', fn) - - @property - def policy_metrics(self): - metrics = { - 'policy_loss': self.policy_loss, - 'advantage_mean': tl.Serial( - self._policy_inputs_to_advantages(False), - tl.Fn('Mean', lambda x: jnp.mean(x)) # pylint: disable=unnecessary-lambda - ), - 'advantage_std': tl.Serial( - self._policy_inputs_to_advantages(False), - tl.Fn('Std', lambda x: jnp.std(x)) # pylint: disable=unnecessary-lambda - ) - } - metrics.update(awr_metrics( - self._beta, self._thresholds, - preprocess_layer=self._policy_inputs_to_advantages(True))) - return metrics - - @property - def policy_loss(self, **unused_kwargs): - """Policy loss.""" - def LossInput(dist_inputs, actions, q_values, act_log_probs, mask): # pylint: disable=invalid-name - """Calculates action log probabilities and normalizes advantages.""" - # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...) - q_values = jnp.swapaxes(q_values, 0, 1) - mask = jnp.swapaxes(mask, 0, 1) - actions = jnp.swapaxes(actions, 0, 1) - act_log_probs = jnp.swapaxes(act_log_probs, 0, 1) - - # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting? - if self._sample_all_discrete_actions: - values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0) - else: - values = jnp.mean(q_values, axis=0) - advantages = q_values - values # Broadcasting values over n_samples - advantages = self._preprocess_advantages(advantages) - - # Broadcast inputs and calculate log-probs - dist_inputs = jnp.broadcast_to( - dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape) - log_probs = self._policy_dist.log_prob(dist_inputs, actions) - return (log_probs, advantages, act_log_probs, mask) - - return tl.Serial( - tl.Fn('LossInput', LossInput, n_out=4), - # Policy loss is expected to consume - # (log_probs, advantages, old_log_probs, mask). - SamplingAWRLoss( - beta=self._beta, w_max=self._w_max, thresholds=self._thresholds, - reweight=self._reweight, - sampled_all_discrete=self._sample_all_discrete_actions) - ) - - def policy_batches_stream(self): - """Use the RLTask self._task to create inputs to the policy model.""" - # For now TD-0 estimation of the value. TODO(pkozakowski): Support others? 
- for np_trajectory in self._task.trajectory_batch_stream(
- self._policy_batch_size,
- epochs=self._replay_epochs,
- max_slice_length=self._max_slice_length,
- ):
- dist_inputs = self._get_dist_inputs(np_trajectory)
- (q_values, actions, act_log_probs) = self._run_value_model(
- np_trajectory.observation, dist_inputs)
- shapes.assert_same_shape(q_values, act_log_probs)
-
- # q_values shape: (batch_size, n_samples, length)
- if len(q_values.shape) != 3:
- raise ValueError('Q-values are expected to have shape [batch_size, ' +
- 'n_samples, length], got: %s' % str(q_values.shape))
- if q_values.shape[1] != self._q_value_n_samples:
- raise ValueError('Q-values dimension 1 should = n_samples, %d != %d'
- % (q_values.shape[1], self._q_value_n_samples))
- if q_values.shape[0] != self._policy_batch_size:
- raise ValueError('Q-values dimension 0 should = policy batch size, ' +
- '%d != %d' %(q_values.shape[0], self._policy_batch_size))
-
- mask = np_trajectory.mask
- mask = np.reshape(mask, [mask.shape[0], 1] + list(mask.shape[1:]))
- mask = jnp.broadcast_to(mask, q_values.shape)
- shapes.assert_same_shape(mask, q_values)
- yield (np_trajectory.observation, actions, q_values, act_log_probs, mask)
diff --git a/trax/rl/actor_critic_joint.py b/trax/rl/actor_critic_joint.py
deleted file mode 100644
index 70eeabff2..000000000
--- a/trax/rl/actor_critic_joint.py
+++ /dev/null
@@ -1,658 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Classes for RL training in Trax."""
-
-import functools
-
-from trax import data
-from trax import layers as tl
-from trax import supervised
-from trax.fastmath import numpy as jnp
-from trax.fastmath import stop_gradient
-from trax.rl import actor_critic
-from trax.rl import distributions
-from trax.rl import rl_layers
-from trax.rl import training as rl_training
-from trax.supervised import lr_schedules as lr
-
-
-# pylint: disable=g-long-lambda
-class ActorCriticJointAgent(rl_training.Agent):
- """Trains a joint policy-and-value model using actor-critic methods."""
-
- def __init__(self, task,
- joint_model=None,
- optimizer=None,
- lr_schedule=lr.multifactor,
- batch_size=64,
- train_steps_per_epoch=500,
- supervised_evals_per_epoch=1,
- supervised_eval_steps=1,
- n_trajectories_per_epoch=50,
- max_slice_length=1,
- normalize_advantages=True,
- output_dir=None,
- n_replay_epochs=1):
- """Configures the joint trainer.
-
- Args:
- task: RLTask instance, which defines the environment to train on.
- joint_model: Trax layer, representing the joint policy and value model.
- optimizer: the optimizer to use to train the joint model.
- lr_schedule: learning rate schedule to use to train the joint model.
- batch_size: batch size used to train the joint model.
- train_steps_per_epoch: how long to train the joint model in each RL epoch.
- supervised_evals_per_epoch: number of value trainer evaluations per RL
- epoch - only affects metric reporting.
- supervised_eval_steps: number of value trainer steps per evaluation -
- only affects metric reporting.
- n_trajectories_per_epoch: how many trajectories to collect per epoch.
- max_slice_length: the maximum length of trajectory slices to use.
- normalize_advantages: if True, then normalize advantages - currently
- implemented only in PPO.
- output_dir: Path telling where to save outputs (evals and checkpoints).
- n_replay_epochs: how many last epochs to take into the replay buffer;
- > 1 only makes sense for off-policy algorithms.
- """
- super().__init__(
- task,
- n_trajectories_per_epoch=n_trajectories_per_epoch,
- output_dir=output_dir,
- )
- self._batch_size = batch_size
- self._train_steps_per_epoch = train_steps_per_epoch
- self._supervised_evals_per_epoch = supervised_evals_per_epoch
- self._supervised_eval_steps = supervised_eval_steps
- self._n_trajectories_per_epoch = n_trajectories_per_epoch
- self._max_slice_length = max_slice_length
- self._policy_dist = distributions.create_distribution(task.action_space)
- self._lr_schedule = lr_schedule()
- self._optimizer = optimizer
- self._normalize_advantages = normalize_advantages
- self._n_replay_epochs = n_replay_epochs
- self._task.set_n_replay_epochs(n_replay_epochs)
-
- # Inputs to the joint model are produced by self.batches_stream.
- self._inputs = data.inputs.Inputs(
- train_stream=lambda _: self.batches_stream())
-
- self._joint_model = functools.partial(
- joint_model,
- policy_distribution=self._policy_dist,
- )
-
- # This is the joint Trainer that will be used to train the policy model.
- # * inputs to the trainer come from self.batches_stream
- # * outputs are passed to self._joint_loss
- self._trainer = supervised.Trainer(
- model=self._joint_model,
- optimizer=self._optimizer,
- lr_schedule=self._lr_schedule,
- loss_fn=self.joint_loss,
- inputs=self._inputs,
- output_dir=output_dir,
- metrics={'joint_loss': self.joint_loss,
- 'advantage_mean': self.advantage_mean,
- 'advantage_norm': self.advantage_norm,
- 'value_loss': self.value_loss,
- 'explained_variance': self.explained_variance,
- 'log_probs_mean': self.log_probs_mean,
- 'preferred_move': self.preferred_move})
- self._eval_model = tl.Accelerate(
- self._joint_model(mode='eval'), n_devices=1)
- example_batch = next(self.batches_stream())
- self._eval_model.init(example_batch)
-
- def close(self):
- self._trainer.close()
- super().close()
-
- def batches_stream(self):
- """Use self.task to create inputs to the policy model."""
- raise NotImplementedError
-
- @property
- def joint_loss(self):
- """Joint policy and value loss layer."""
- raise NotImplementedError
-
- @property
- def advantage_mean(self):
- """Mean of advantages."""
- def f(dist_inputs, values, returns):
- del dist_inputs
- return jnp.mean(returns - values)
- return tl.Fn('AdvantageMean', f)
-
- @property
- def advantage_norm(self):
- """Norm of advantages."""
- def f(dist_inputs, values, returns):
- del dist_inputs
- return jnp.linalg.norm(returns - values)
- return tl.Fn('AdvantageNorm', f)
-
- @property
- def value_loss(self):
- """Value loss - so far generic for all A2C."""
- def f(dist_inputs, values, returns):
- del dist_inputs
- return rl_layers.ValueLoss(values, returns, self._value_loss_coeff)
- return tl.Fn('ValueLoss', f)
-
- @property
- def explained_variance(self):
- """Explained variance metric."""
- def f(dist_inputs, values, returns):
- del dist_inputs
- return rl_layers.ExplainedVariance(values, returns)
- return tl.Fn('ExplainedVariance', f)
-
- @property
- def log_probs_mean(self):
- """Mean of log_probs aka dist_inputs."""
- def f(dist_inputs, values):
- del values
- return jnp.mean(dist_inputs)
- return tl.Fn('LogProbsMean', f)
-
- @property
- def preferred_move(self):
- """Preferred move - the mean of selected moves."""
- def f(dist_inputs, values):
- del values
- return rl_layers.PreferredMove(dist_inputs, self._policy_dist.sample)
- return tl.Fn('PreferredMove', f)
-
- def policy(self, trajectory, temperature=1.0):
- """Chooses an action to play after a trajectory."""
- model = self._eval_model
- model.replicate_weights(self._trainer.model_weights)
- # The two lines below, along with the copying before return,
- # make the TPU happy.
- tr_slice = trajectory.suffix(self._max_slice_length)
- trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
- # Add batch dimension to trajectory_np and run the model.
- pred = model(trajectory_np.observation[None, ...])[0]
- # Pick element 0 from the batch (the only one), last (current) timestep.
- pred = pred[0, -1, :]
- sample = self._policy_dist.sample(pred, temperature=temperature)
- return (sample.copy(), pred.copy())
-
- def train_epoch(self):
- """Trains RL for one epoch."""
- n_evals = rl_training.remaining_evals(
- self._trainer.step,
- self._epoch,
- self._train_steps_per_epoch,
- self._supervised_evals_per_epoch)
- for _ in range(n_evals):
- self._trainer.train_epoch(
- self._train_steps_per_epoch // self._supervised_evals_per_epoch,
- self._supervised_eval_steps)
-
-
-class PPOJoint(ActorCriticJointAgent):
- """The Proximal Policy Optimization Algorithm aka PPO.
-
- Trains policy and value models using the PPO algorithm.
- """
-
- # TODO(henrykm): make on_policy more generic
- # (currently epochs are passed manually)
- on_policy = True
-
- def __init__(self, task, epsilon=0.2, value_loss_coeff=0.1,
- entropy_coeff=0.01, **kwargs):
- """Configures the PPO Trainer."""
- self._epsilon = epsilon
- self._value_loss_coeff = value_loss_coeff
- self._entropy_coeff = entropy_coeff
- super().__init__(task, **kwargs)
- self._trainer = supervised.Trainer(
- model=self._joint_model,
- optimizer=self._optimizer,
- lr_schedule=self._lr_schedule,
- loss_fn=self.joint_loss,
- inputs=self._inputs,
- output_dir=self._output_dir,
- metrics={'joint_loss': self.joint_loss,
- 'advantage_mean': self.advantage_mean,
- 'advantage_norm': self.advantage_norm,
- 'value_loss': self.value_loss,
- 'explained_variance': self.explained_variance,
- 'log_probs_mean': self.log_probs_mean,
- 'entropy_loss': self.entropy_loss,
- 'probs_ratio_mean': self.probs_ratio_mean,
- 'unclipped_objective_mean': self.unclipped_objective_mean,
- 'clipped_objective_mean': self.clipped_objective_mean,
- 'ppo_objective_mean': self.ppo_objective_mean,
- 'clip_fraction': self.clip_fraction,
- 'preferred_move': self.preferred_move,
- 'approximate_kl_divergence': self.approximate_kl_divergence})
-
- def batches_stream(self):
- """Use the RLTask self._task to create inputs to the value model."""
- for np_trajectory in self._task.trajectory_batch_stream(
- self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1]):
- if np_trajectory.dist_inputs is not None:
- old_dist_inputs = np_trajectory.dist_inputs
- else:
- old_dist_inputs = jnp.zeros(
- np_trajectory.reward.shape + (self._policy_dist.n_inputs,)
- )
- old_log_probs = self._policy_dist.log_prob(
- old_dist_inputs, np_trajectory.action
- )
- # Insert an extra depth dimension, so the target shape is consistent with
- # the network output shape.
- yield (np_trajectory.observation, # Inputs to the value model.
- np_trajectory.return_[:, :, None],
- np_trajectory.done[:, :, None],
- np_trajectory.reward[:, :, None],
- np_trajectory.action,
- old_log_probs,
- np_trajectory.mask)
-
- @property
- def joint_loss(self):
- """Joint policy and value loss."""
- def f(dist_inputs, values, returns, dones, rewards,
- actions, old_log_probs, mask):
- """Definition of the Proximal Policy Optimization loss."""
- del mask # TODO(lukaszkaiser): make PPO work with Transformer
- # We have dist_inputs of the shape float32[128,1,18]
- assert len(dist_inputs.shape) == 3, (
- f'dist_inputs.shape was {dist_inputs.shape} '
- f'but expected length of the tensor shape is 3')
- # values of the shape float32[128,1,1]
- # returns of the shape float32[128,1,1]
- # dones of the shape int32[128,1,1]
- # rewards of the shape float32[128,1,1]
- # and old_log_probs of the shape float32[128,1]
- assert values.shape == returns.shape, (
- f'values.shape was {values.shape} '
- f'returns.shape was {returns.shape}')
- assert values.shape == dones.shape, (
- f'values.shape was {values.shape} '
- f'dones.shape was {dones.shape}')
- assert rewards.shape == dones.shape, (
- f'rewards.shape was {rewards.shape} '
- f'dones.shape was {dones.shape}')
- assert returns.shape[0:2] == old_log_probs.shape, (
- f'returns.shape was {returns.shape} '
- f'old_log_probs.shape was {old_log_probs.shape}')
-
- # actions is a tensor of the shape int32[128,1] in the case
- # of discrete actions and float32[128,1,6] in the case of
- # half-cheetah and other continuous actions
- # actions agree with returns/values on the first two coordinates
- # meaning batch and time
- assert actions.shape[0:2] == returns.shape[0:2], (
- f'actions.shape was {actions.shape} and '
- f'returns.shape was {returns.shape}')
-
- ppo_objective = rl_layers.PPOObjective(
- dist_inputs, stop_gradient(values), returns, dones, rewards,
- actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob,
- epsilon=self._epsilon,
- normalize_advantages=self._normalize_advantages)
-
- # we insist that ppo_objective is a vector of shape [128,1]
- assert len(ppo_objective.shape) == 2, (
- f'ppo_objective was {ppo_objective}')
- # which agrees with returns/values/actions on the first two coordinates
- assert ppo_objective.shape[0:2] == values.shape[0:2], (
- f'ppo_objective.shape was {ppo_objective.shape} and '
- f'values.shape was {values.shape}')
-
- entropy_loss = rl_layers.EntropyLoss(
- dist_inputs,
- distribution=self._policy_dist,
- coeff=self._entropy_coeff,
- )
-
- assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'
-
- l2_value_loss = rl_layers.ValueLoss(
- values, returns, value_loss_coeff=self._value_loss_coeff)
-
- assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'
-
- return -ppo_objective.mean() + l2_value_loss - entropy_loss
-
- return tl.Fn('PPOJointLoss', f)
-
- # pylint: disable=invalid-name
- @property
- def probs_ratio_mean(self):
- """Mean of the PPO probability ratio."""
- def ProbsRatioMean(dist_inputs, actions, old_log_probs):
- """Probability Ratio Mean from the PPO algorithm."""
- probs_ratio = rl_layers.ProbsRatio(
- dist_inputs, actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob)
- return jnp.mean(probs_ratio)
-
- def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
- del values, returns, dones, rewards
- return ProbsRatioMean(dist_inputs, actions, old_log_probs)
- return tl.Fn('ProbsRatioMean', f)
-
- @property
- def clip_fraction(self):
- """Fraction of probability ratios clipped by the PPO objective."""
- def ClipFraction(dist_inputs, actions, old_log_probs):
- """Fraction of clipped probability ratios from the PPO algorithm."""
- probs_ratio = rl_layers.ProbsRatio(
- dist_inputs, actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob)
- return jnp.mean(jnp.abs(probs_ratio - 1) > self._epsilon)
-
- def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
- del values, returns, dones, rewards
- return ClipFraction(dist_inputs, actions, old_log_probs)
- return tl.Fn('ClipFraction', f)
- # pylint: enable=invalid-name
-
- @property
- def entropy_loss(self):
- """Entropy layer."""
- def f(dist_inputs, values, returns, dones, rewards, actions):
- del values, returns, dones, rewards, actions
- return rl_layers.EntropyLoss(
- dist_inputs,
- distribution=self._policy_dist,
- coeff=self._entropy_coeff,
- )
- return tl.Fn('EntropyLoss', f)
-
- @property
- def approximate_kl_divergence(self):
- """Approximate KL divergence."""
- def f(dist_inputs, values, returns, dones, rewards,
- actions, old_log_probs):
- del values, returns, dones, rewards
- return rl_layers.ApproximateKLDivergence(
- dist_inputs,
- actions,
- old_log_probs,
- log_prob_fun=self._policy_dist.log_prob)
- return tl.Fn('ApproximateKLDivergence', f)
-
- @property
- def unclipped_objective_mean(self):
- def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
- """Unclipped objective mean from the PPO algorithm."""
- del dones, rewards
- advantages = returns - values
- probs_ratio = rl_layers.ProbsRatio(
- dist_inputs, actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob)
- # advantages are of the shape [128,1,1]
- # and probs_ratio are of the shape [128,1]
- advantages = advantages.squeeze(axis=2)
- unclipped_objective = rl_layers.UnclippedObjective(
- probs_ratio, advantages)
- return jnp.mean(unclipped_objective)
-
- return tl.Fn('UnclippedObjectiveMean', f)
-
- @property
- def clipped_objective_mean(self):
- def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
- """Clipped objective mean from the PPO algorithm."""
- del dones, rewards
- advantages = returns - values
- probs_ratio = rl_layers.ProbsRatio(
- dist_inputs, actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob)
- # advantages are of the shape [128,1,1]
- # and probs_ratio are of the shape [128,1]
- advantages = advantages.squeeze(axis=2)
- clipped_objective = rl_layers.ClippedObjective(
- probs_ratio, advantages, epsilon=self._epsilon)
- return jnp.mean(clipped_objective)
-
- return tl.Fn('ClippedObjectiveMean', f)
-
- @property
- def ppo_objective(self):
- """PPO objective with local parameters."""
- def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
- return rl_layers.PPOObjective(
- dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob,
- epsilon=self._epsilon,
- normalize_advantages=self._normalize_advantages)
- return tl.Fn('PPOObjective', f)
-
- @property
- def ppo_objective_mean(self):
- """PPO objective mean."""
- def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
- """Mean of the PPO objective."""
- ppo_objective = rl_layers.PPOObjective(
- dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
- log_prob_fun=self._policy_dist.log_prob,
- epsilon=self._epsilon,
- normalize_advantages=self._normalize_advantages)
- return jnp.mean(ppo_objective)
- return tl.Fn('PPOObjectiveMean', f)
-
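As an aside, the `clip_fraction` metric above reports how often the PPO trust region is active. A self-contained numpy sketch, with invented example log-probabilities:

```python
# Sketch of the clip_fraction metric above: the share of probability
# ratios that fall outside [1 - epsilon, 1 + epsilon]. Inputs are
# hypothetical example values.
import numpy as np

def clip_fraction(new_log_probs, old_log_probs, epsilon=0.2):
  probs_ratio = np.exp(new_log_probs - old_log_probs)
  return np.mean(np.abs(probs_ratio - 1) > epsilon)

new_lp = np.array([-0.5, -1.0, -2.0])
old_lp = np.array([-1.0, -1.0, -1.0])
print(clip_fraction(new_lp, old_lp))  # 2/3 of the ratios are clipped
```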
-class A2CJoint(ActorCriticJointAgent): - """The A2C algorithm. - - Trains policy and value models using the A2C algorithm. - """ - - on_policy = True - - def __init__(self, task, value_loss_coeff=0.1, - entropy_coeff=0.01, **kwargs): - """Configures the A2C Trainer.""" - self._value_loss_coeff = value_loss_coeff - self._entropy_coeff = entropy_coeff - super().__init__(task, **kwargs) - self._trainer = supervised.Trainer( - model=self._joint_model, - optimizer=self._optimizer, - lr_schedule=self._lr_schedule, - loss_fn=self.joint_loss, - inputs=self._inputs, - output_dir=self._output_dir, - metrics={'joint_loss': self.joint_loss, - 'advantage_mean': self.advantage_mean, - 'advantage_norm': self.advantage_norm, - 'value_loss': self.value_loss, - 'explained_variance': self.explained_variance, - 'log_probs_mean': self.log_probs_mean, - 'entropy_loss': self.entropy_loss, - 'a2c_objective_mean': self.a2c_objective_mean, - 'approximate_kl_divergence': self.approximate_kl_divergence, - 'preferred_move': self.preferred_move}) - - def batches_stream(self): - """Use the RLTask self._task to create inputs to the value model.""" - for np_trajectory in self._task.trajectory_batch_stream( - self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1]): - # Insert an extra depth dimension, so the target shape is consistent with - # the network output shape. - yield (np_trajectory.observation, # Inputs to the value model. - np_trajectory.return_[:, :, None], - np_trajectory.done[:, :, None], - np_trajectory.reward[:, :, None], - np_trajectory.action, - jnp.zeros_like(np_trajectory.mask), - np_trajectory.mask) - - @property - def joint_loss(self): - """Joint policy and value loss.""" - def f(dist_inputs, values, returns, dones, rewards, - actions, old_log_probs, mask): - """Definition of the A2C loss.""" - del old_log_probs - - # Typically we have dist_inputs of the shape float32[128,1,18] - assert len(dist_inputs.shape) == 3, ( - f'dist_inputs.shape was {dist_inputs.shape} ' - f'but expected length of the tensor shape is 3') - # values of the shape float32[128,1,1] - # returns of the shape float32[128,1,1] - assert values.shape == returns.shape, ( - f'values.shape was {values.shape} ' - f'but returns.shape was {returns.shape}') - # actions of the shape int32[128,1] in the case of discrete actions - # and float32[128,1,6] in the case of half-cheetah - # actions agree with returns/values on the first two coordinates - assert actions.shape[0:2] == returns.shape[0:2], ( - f'actions.shape was {actions.shape} ' - f'but returns.shape was {returns.shape}') - # and mask of the shape float32[128,1] - assert len(mask.shape) == 2, f'mask.shape was {mask.shape}' - # which agrees with returns/values/actions on the first two coordinates - assert mask.shape[0:2] == returns.shape[0:2], ( - f'mask.shape was {mask.shape} ' - f'but returns.shape was {returns.shape}') - - a2c_objective = rl_layers.A2CObjective( - dist_inputs, - stop_gradient(values), - returns, dones, rewards, actions, mask, - log_prob_fun=self._policy_dist.log_prob, - normalize_advantages=self._normalize_advantages) - - # we insist that a2c_objective is a scalar - assert jnp.ndim(a2c_objective) == 0, f'a2c_objective was {a2c_objective}' - - entropy_loss = rl_layers.EntropyLoss( - dist_inputs, - distribution=self._policy_dist, - coeff=self._entropy_coeff, - ) - - assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}' - - l2_value_loss = rl_layers.ValueLoss( - values, returns, value_loss_coeff=self._value_loss_coeff) - - assert 
jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}' - - combined_loss = a2c_objective + l2_value_loss - entropy_loss - - return combined_loss - - return tl.Fn('A2CJointLoss', f, n_out=1) - - @property - def entropy_loss(self): - """Entropy layer.""" - def f(dist_inputs, values, returns, dones, rewards, actions): - del values, returns, dones, rewards, actions - return rl_layers.EntropyLoss( - dist_inputs, - distribution=self._policy_dist, - coeff=self._entropy_coeff, - ) - return tl.Fn('EntropyLoss', f) - - @property - def approximate_kl_divergence(self): - """Approximate KL divergence.""" - def f(dist_inputs, values, returns, dones, rewards, - actions, old_log_probs): - del values, returns, dones, rewards - return rl_layers.ApproximateKLDivergence( - dist_inputs, - actions, - old_log_probs, - log_prob_fun=self._policy_dist.log_prob) - return tl.Fn('ApproximateKLDivergence', f) - - @property - def a2c_objective(self): - """A2C objective with local parameters.""" - return tl.Fn( - 'A2CObjective', - lambda dist_inputs, values, returns, dones, rewards, actions, \ - old_log_probs, mask: rl_layers.A2CObjective( - dist_inputs, - values, - returns, - dones, - rewards, - actions, - mask, - log_prob_fun=self._policy_dist.log_prob, - normalize_advantages=self._normalize_advantages), - n_out=1) - - @property - def a2c_objective_mean(self): - """A2C objective mean.""" - def f(dist_inputs, values, returns, dones, rewards, - actions, old_log_probs, mask): - """A2C objective mean.""" - # TODO(henrykm): include dones, rewards - del old_log_probs - a2c_objective = rl_layers.A2CObjective( - dist_inputs, values, returns, dones, rewards, actions, mask, - log_prob_fun=self._policy_dist.log_prob, - normalize_advantages=self._normalize_advantages) - return jnp.mean(a2c_objective) - return tl.Fn('A2CObjectiveMean', f, n_out=1) - - -class AWRJoint(ActorCriticJointAgent): - """Trains a joint policy-and-value model using AWR.""" - - # TODO(henrykm): value_loss_coeff looks like a common parameter - def __init__(self, task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, - thresholds=None, **kwargs): - """Configures the joint AWR Trainer.""" - self._beta = beta - self._w_max = w_max - self._thresholds = thresholds - self._value_loss_coeff = value_loss_coeff - super().__init__(task, **kwargs) - - def batches_stream(self): - """Use the RLTask self._task to create inputs to the value model.""" - for np_trajectory in self._task.trajectory_batch_stream( - self._batch_size, max_slice_length=self._max_slice_length): - # Insert an extra depth dimension, so the target shape is consistent with - # the network output shape. - yield (np_trajectory.observation, # Inputs to the value model. - np_trajectory.return_[:, :, None], # Targets: regress to returns. - np_trajectory.action, # Policy targets: actions. - np_trajectory.mask) # Padding mask. 
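Before the joint loss below, a minimal sketch (plain numpy, assumed names) of the advantage-weighted regression idea that AWR builds on: log-likelihoods of the taken actions are weighted by exponentiated advantages, capped at w_max, mirroring the beta and w_max constructor arguments above.

import numpy as np

def awr_policy_loss_sketch(log_probs, advantages, beta=1.0, w_max=20.0):
  # Higher-advantage actions get exponentially larger weights, capped at w_max.
  weights = np.minimum(np.exp(advantages / beta), w_max)
  # Negative weighted log-likelihood: maximizing likelihood of good actions.
  return -np.mean(weights * log_probs)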
- - @property - def joint_loss(self): - """Joint policy and value loss.""" - - def f(preds, values, returns, actions, mask): - advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1) - logps = self._policy_dist.log_prob(preds, actions) - awr_loss = actor_critic.AWRLoss( - beta=self._beta, w_max=self._w_max, thresholds=self._thresholds)( - (logps, advantages, jnp.zeros_like(logps), mask)) - l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff - return awr_loss + l2_value_loss - return tl.Fn('AWRJointLoss', f) diff --git a/trax/rl/actor_critic_joint_test.py b/trax/rl/actor_critic_joint_test.py deleted file mode 100644 index c4eddcc84..000000000 --- a/trax/rl/actor_critic_joint_test.py +++ /dev/null @@ -1,176 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RL training.""" - -import functools - -from absl.testing import absltest - -from trax import layers as tl -from trax import models -from trax import optimizers as opt -from trax import test_utils -from trax.rl import actor_critic_joint -from trax.rl import task as rl_task -from trax.supervised import lr_schedules - - - -class ActorCriticJointTest(absltest.TestCase): - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - - def test_awrjoint_save_restore(self): - """Check save and restore of joint AWR trainer.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, - max_steps=2) - joint_model = functools.partial( - models.PolicyAndValue, - body=lambda mode: tl.Serial(tl.Dense(4), tl.Relu()), - ) - tmp_dir = self.create_tempdir().full_path - trainer1 = actor_critic_joint.AWRJoint( - task, - joint_model=joint_model, - optimizer=opt.Adam, - batch_size=4, - train_steps_per_epoch=1, - n_trajectories_per_epoch=2, - output_dir=tmp_dir) - trainer1.run(2) - self.assertEqual(trainer1.current_epoch, 2) - self.assertEqual(trainer1._trainer.step, 2) - # Agent 2 starts where agent 1 stopped. 
- trainer2 = actor_critic_joint.AWRJoint( - task, - joint_model=joint_model, - optimizer=opt.Adam, - batch_size=4, - train_steps_per_epoch=1, - n_trajectories_per_epoch=2, - output_dir=tmp_dir) - trainer2.run(1) - self.assertEqual(trainer2.current_epoch, 3) - self.assertEqual(trainer2._trainer.step, 3) - trainer1.close() - trainer2.close() - - - def test_jointppotrainer_cartpole(self): - """Test-runs joint PPO on CartPole.""" - - task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, - max_steps=2) - joint_model = functools.partial( - models.PolicyAndValue, - body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()), - ) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - - trainer = actor_critic_joint.PPOJoint( - task, - joint_model=joint_model, - optimizer=opt.Adam, - lr_schedule=lr, - batch_size=4, - train_steps_per_epoch=2, - n_trajectories_per_epoch=5) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_jointawrtrainer_cartpole(self): - """Test-runs joint AWR on cartpole.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, - max_steps=2) - joint_model = functools.partial( - models.PolicyAndValue, - body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()), - ) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic_joint.AWRJoint( - task, - joint_model=joint_model, - optimizer=opt.Adam, - lr_schedule=lr, - batch_size=4, - train_steps_per_epoch=2, - n_trajectories_per_epoch=5) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_jointa2ctrainer_cartpole(self): - """Test-runs joint A2C on cartpole.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, - max_steps=2) - joint_model = functools.partial( - models.PolicyAndValue, - body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()), - ) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic_joint.A2CJoint( - task, - joint_model=joint_model, - optimizer=opt.RMSProp, - lr_schedule=lr, - batch_size=2, - train_steps_per_epoch=1, - n_trajectories_per_epoch=1) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_jointawrtrainer_cartpole_transformer(self): - """Test-runs joint AWR on cartpole with Transformer.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, - max_steps=2) - body = lambda mode: models.TransformerDecoder( # pylint: disable=g-long-lambda - d_model=4, d_ff=4, n_layers=1, n_heads=1, mode=mode) - joint_model = functools.partial(models.PolicyAndValue, body=body) - trainer = actor_critic_joint.AWRJoint( - task, - joint_model=joint_model, - optimizer=opt.Adam, - batch_size=4, - train_steps_per_epoch=2, - n_trajectories_per_epoch=2, - max_slice_length=2) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_jointa2ctrainer_cartpole_transformer(self): - """Test-runs joint A2C on cartpole with Transformer.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, - max_steps=2) - body = lambda mode: models.TransformerDecoder( # pylint: disable=g-long-lambda - d_model=4, d_ff=4, n_layers=1, n_heads=1, mode=mode) - joint_model = functools.partial(models.PolicyAndValue, body=body) - trainer = actor_critic_joint.A2CJoint( - task, - joint_model=joint_model, - optimizer=opt.RMSProp, - batch_size=4, - 
train_steps_per_epoch=2, - n_trajectories_per_epoch=2) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/rl/actor_critic_test.py b/trax/rl/actor_critic_test.py deleted file mode 100644 index eaf3dea10..000000000 --- a/trax/rl/actor_critic_test.py +++ /dev/null @@ -1,293 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RL training.""" - -import functools - - -from absl.testing import absltest -from absl.testing import parameterized - -from trax import layers as tl -from trax import models -from trax import optimizers as opt -from trax import test_utils -from trax.rl import actor_critic -from trax.rl import advantages -from trax.rl import task as rl_task -from trax.supervised import lr_schedules - - -class ActorCriticTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - - def test_a2ctrainer_save_restore(self): - """Check save and restore of A2C trainer.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, - max_steps=20) - body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - tmp_dir = self.create_tempdir().full_path - trainer1 = actor_critic.A2C( - task, - value_model=value_model, - value_optimizer=opt.Adam, - value_batch_size=2, - value_train_steps_per_epoch=1, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=2, - n_shared_layers=1, - output_dir=tmp_dir) - trainer1.run(2) - self.assertEqual(trainer1.current_epoch, 2) - self.assertEqual(trainer1._value_trainer.step, 2) - self.assertEqual(trainer1._policy_trainer.step, 4) - # Trainer 2 starts where trainer 1 stopped. 
- trainer2 = actor_critic.A2C( - task, - value_model=value_model, - value_optimizer=opt.Adam, - value_batch_size=2, - value_train_steps_per_epoch=1, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=2, - n_shared_layers=1, - output_dir=tmp_dir) - trainer2.run(1) - self.assertEqual(trainer2.current_epoch, 3) - self.assertEqual(trainer2._value_trainer.step, 3) - self.assertEqual(trainer2._policy_trainer.step, 6) - trainer1.close() - trainer2.close() - - def test_sanity_a2ctrainer_cartpole(self): - """Test-runs a2c on cartpole.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, - max_steps=2) - body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-4, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic.A2C( - task, - n_shared_layers=1, - value_model=value_model, - value_optimizer=opt.Adam, - value_lr_schedule=lr, - value_batch_size=2, - value_train_steps_per_epoch=2, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_lr_schedule=lr, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=2) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_sanity_ppo_cartpole(self): - """Run PPO and check whether it correctly runs for 2 epochs.""" - task = rl_task.RLTask( - 'CartPole-v1', initial_trajectories=0, max_steps=200) - - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-3, - warmup_steps=100, - factors='constant * linear_warmup') - - body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - trainer = actor_critic.PPO( - task, - n_shared_layers=1, - value_model=value_model, - value_optimizer=opt.Adam, - value_lr_schedule=lr, - value_batch_size=128, - value_train_steps_per_epoch=10, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_lr_schedule=lr, - policy_batch_size=128, - policy_train_steps_per_epoch=10, - n_trajectories_per_epoch=10) - - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_sanity_loopawr(self): - """Test-runs LoopAWR.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2) - body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) - model_fn = functools.partial(models.PolicyAndValue, body=body) - trainer = actor_critic.LoopAWR( - task, - model_fn, - batch_size=2, - network_eval_at=(lambda _: True), - policy_n_steps_per_epoch=2, - value_n_steps_per_epoch=2, - n_trajectories_per_epoch=1, - n_eval_episodes=1, - ) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - @parameterized.named_parameters(('default', None), - ('thresholds', ((70, 1.0, 0), (90, 4.0, 0)))) - def test_sanity_awrtrainer_transformer_cartpole(self, thresholds): - """Test-runs AWR on cartpole with Transformer.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, - max_steps=2) - body = lambda mode: models.TransformerDecoder( # pylint: disable=g-long-lambda - d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - lr = lambda: lr_schedules.multifactor( # pylint: 
disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic.AWR( - task, - thresholds=thresholds, - n_shared_layers=0, - max_slice_length=2, - added_policy_slice_length=1, - value_model=value_model, - value_optimizer=opt.Adam, - value_lr_schedule=lr, - value_batch_size=2, - value_train_steps_per_epoch=2, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_lr_schedule=lr, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=1, - n_eval_episodes=1) - trainer.run(2) - self.assertEqual(2, trainer.current_epoch) - - def test_sampling_awrtrainer_cartpole(self): - """Test-runs SamplingAWR on cartpole.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, - max_steps=20) - body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu()) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic.SamplingAWR( - task, - n_shared_layers=0, - added_policy_slice_length=1, - value_model=value_model, - value_optimizer=opt.Adam, - value_lr_schedule=lr, - value_batch_size=2, - value_train_steps_per_epoch=2, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_lr_schedule=lr, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=2, - advantage_estimator=advantages.monte_carlo, - advantage_normalization=False, - q_value_n_samples=3, - q_value_aggregate='max', - reweight=False, - ) - trainer.run(1) - self.assertEqual(1, trainer.current_epoch) - - def test_sampling_awrtrainer_cartpole_sample_all_discrete(self): - """Test-runs SamplingAWR on cartpole with n_samples equal to n_actions.""" - task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, - max_steps=20) - body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu()) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic.SamplingAWR( - task, - n_shared_layers=0, - added_policy_slice_length=1, - value_model=value_model, - value_optimizer=opt.Adam, - value_lr_schedule=lr, - value_batch_size=2, - value_train_steps_per_epoch=2, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_lr_schedule=lr, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=2, - advantage_estimator=advantages.monte_carlo, - advantage_normalization=False, - q_value_n_samples=2, - q_value_aggregate='max', - reweight=False, - ) - trainer.run(1) - self.assertEqual(1, trainer.current_epoch) - - def test_sampling_awrtrainer_mountain_car(self): - """Test-runs Sampling AWR on MountainCarContinuous.""" - task = rl_task.RLTask('MountainCarContinuous-v0', initial_trajectories=0, - max_steps=2) - body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu()) - policy_model = functools.partial(models.Policy, body=body) - value_model = functools.partial(models.Value, body=body) - lr = lambda: lr_schedules.multifactor( # pylint: disable=g-long-lambda - constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') - trainer = actor_critic.SamplingAWR( - task, - n_shared_layers=0, - added_policy_slice_length=1, - value_model=value_model, - 
value_optimizer=opt.Adam, - value_lr_schedule=lr, - value_batch_size=2, - value_train_steps_per_epoch=2, - policy_model=policy_model, - policy_optimizer=opt.Adam, - policy_lr_schedule=lr, - policy_batch_size=2, - policy_train_steps_per_epoch=2, - n_trajectories_per_epoch=2, - advantage_estimator=advantages.monte_carlo, - advantage_normalization=False, - q_value_n_samples=3, - ) - trainer.run(1) - self.assertEqual(1, trainer.current_epoch) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/rl/advantages.py b/trax/rl/advantages.py deleted file mode 100644 index 3825a856a..000000000 --- a/trax/rl/advantages.py +++ /dev/null @@ -1,176 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""RL advantage estimators.""" - -import gin -import numpy as np - -from trax import fastmath - -common_args = ['gamma', 'margin'] - - -def mask_discount(discount, discount_mask): - """Computes a discount to apply at a given timestep, based on the mask.""" - return fastmath.numpy.where(discount_mask, discount, 1.0) - - -def discounted_returns(rewards, gammas): - """Computes discounted returns for a trajectory or a batch of them.""" - returns = np.zeros_like(rewards) - ret = 0.0 - for i in reversed(range(rewards.shape[-1])): - ret = rewards[..., i] + gammas[..., i] * ret - returns[..., i] = ret - return returns - - -@gin.configurable(denylist=common_args) -def monte_carlo(gamma, margin): - """Calculate Monte Carlo advantage. - - We assume the values are a tensor of shape [batch_size, length] and this - is the same shape as rewards and returns. - - Args: - gamma: float, gamma parameter for TD from the underlying task - margin: number of extra steps in the sequence - - Returns: - Function (rewards, returns, values, dones, discount_mask) -> advantages, - where advantages is an array of shape [batch_size, length - margin]. - """ - del gamma - def estimator(rewards, returns, values, dones, discount_mask): - del discount_mask - (_, length) = returns.shape - # Make sure that the future returns and values at "done" states are zero. - returns[dones] = rewards[dones] - values[dones] = 0 - return (returns - values)[:, :(length - margin)] - return estimator - - -@gin.configurable(denylist=common_args) -def td_k(gamma, margin): - """Calculate TD-k advantage. - - The k parameter is assumed to be the same as margin. - - We calculate advantage(s_i) as: - - gamma^n_steps * value(s_{i + n_steps}) - value(s_i) + discounted_rewards - - where discounted_rewards is the sum of rewards in these steps with - discounting by powers of gamma. - - Args: - gamma: float, gamma parameter for TD from the underlying task - margin: number of extra steps in the sequence - - Returns: - Function (rewards, returns, values, dones, discount_mask) -> advantages, - where advantages is an array of shape [batch_size, length - margin]. 
- """ - def estimator(rewards, returns, values, dones, discount_mask): - del returns - gammas = mask_discount(gamma, discount_mask) - # Here we calculate advantage with TD-k, where k=margin. - k = margin - assert k > 0 - advantages = np.zeros_like(values[:, k:]) - discount = 1.0 - for i in range(margin): - advantages += discount * rewards[:, i:-(margin - i)] - discount *= gammas[:, i:-(margin - i)] - advantages += discount * values[:, k:] - # Zero out the future returns at "done" states. - dones = dones[:, :-k] - # TPU friendly version of the formula - # advantages[dones] = rewards[:, :-k][dones] - advantages = fastmath.index_update(advantages, - dones, - rewards[:, :-k][dones]) - # Subtract the baseline (value). - advantages -= values[:, :-k] - return advantages - return estimator - - -@gin.configurable(denylist=common_args) -def td_lambda(gamma, margin, lambda_=0.95): - """Calculate TD-lambda advantage. - - The estimated return is an exponentially-weighted average of different TD-k - returns. - - Args: - gamma: float, gamma parameter for TD from the underlying task - margin: number of extra steps in the sequence - lambda_: float, the lambda parameter of TD-lambda - - Returns: - Function (rewards, returns, values, dones) -> advantages, where advantages - advantages is an array of shape [batch_size, length - margin]. - """ - def estimator(rewards, returns, values, dones, discount_mask): - gammas = mask_discount(gamma, discount_mask) - lambdas = mask_discount(lambda_, discount_mask) - td_returns = np.zeros_like(returns) - (_, length) = returns.shape - td_returns[:, -1] = values[:, -1] - for i in reversed(range(length - 1)): - lambda_i = lambdas[:, i] - td_returns[:, i] = rewards[:, i] + (1 - dones[:, i]) * gammas[:, i] * ( - (1 - lambda_i) * values[:, i + 1] + lambda_i * td_returns[:, i + 1] - ) - return (td_returns - values)[:, :(returns.shape[1] - margin)] - return estimator - - -@gin.configurable(denylist=common_args) -def gae(gamma, margin, lambda_=0.95): - """Calculate Generalized Advantage Estimation. - - Calculate state values bootstrapping off the following state values - - Generalized Advantage Estimation https://arxiv.org/abs/1506.02438 - - Args: - gamma: float, gamma parameter for TD from the underlying task - margin: number of extra steps in the sequence - lambda_: float, the lambda parameter of GAE - - Returns: - Function (rewards, returns, values, dones) -> advantages, where advantages - advantages is an array of shape [batch_size, length - margin]. - """ - def estimator(rewards, returns, values, dones, discount_mask): - del returns - gammas = mask_discount(gamma, discount_mask) - lambdas = mask_discount(lambda_, discount_mask) - advantages = np.zeros_like(rewards) - (_, length) = rewards.shape - - for i in reversed(range(length - 1)): - bellman_delta = rewards[:, i] - values[:, i] + (1 - dones[:, i]) * ( - gammas[:, i] * values[:, i + 1] - ) - advantages[:, i] = bellman_delta + (1 - dones[:, i]) * ( - gammas[:, i] * lambdas[:, i] * advantages[:, i + 1] - ) - - return advantages[:, :(rewards.shape[1] - margin)] - return estimator diff --git a/trax/rl/advantages_test.py b/trax/rl/advantages_test.py deleted file mode 100644 index ba0c25be7..000000000 --- a/trax/rl/advantages_test.py +++ /dev/null @@ -1,245 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.rl.advantages.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from trax.rl import advantages - - -def calc_bias_and_variance(x, true_mean): - sample_mean = np.mean(x) - bias = np.mean(np.abs(sample_mean - true_mean)) - variance = np.mean((x - sample_mean) ** 2) - return (bias, variance) - - -def estimate_advantage_bias_and_variance( - advantage_fn, - mean_reward=1.23, - reward_noise=0.45, - discount_mask=None, - discount_true_return=True, - true_value=False, - n_samples=10000, - length=5, - gamma=0.9, - margin=1, - **advantage_kwargs -): - advantage_fn = advantage_fn(gamma, margin, **advantage_kwargs) - rewards = np.random.normal( - loc=mean_reward, scale=reward_noise, size=(n_samples, length) - ) - if discount_mask is None: - discount_mask = np.ones_like(rewards) - gammas = advantages.mask_discount(gamma, discount_mask) - returns = advantages.discounted_returns(rewards, gammas) - - true_returns = advantages.discounted_returns( - np.full(returns.shape, fill_value=mean_reward), gammas=gammas - ) - if true_value: - values = true_returns - else: - values = np.zeros_like(returns) - - dones = np.zeros_like(returns, dtype=bool) - adv = advantage_fn(rewards, returns, values, dones, discount_mask) - if discount_true_return: - mean_return = true_returns[0, 0] - else: - mean_return = mean_reward * length - return calc_bias_and_variance(adv[:, 0], mean_return - values[:, 0]) - - -class AdvantagesTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('monte_carlo', advantages.monte_carlo), - ('td_k', advantages.td_k), - ('td_lambda', advantages.td_lambda), - ('gae', advantages.gae), - ) - def test_shapes(self, advantage_fn): - rewards = np.array([[1, 1, 1]], dtype=np.float32) - returns = np.array([[3, 2, 1]], dtype=np.float32) - values = np.array([[2, 2, 2]], dtype=np.float32) - dones = np.array([[False, False, True]]) - discount_mask = np.ones_like(rewards) - adv1 = advantage_fn(gamma=1, margin=1)( - rewards, returns, values, dones, discount_mask - ) - self.assertEqual(adv1.shape, (1, 2)) - adv2 = advantage_fn(gamma=1, margin=2)( - rewards, returns, values, dones, discount_mask - ) - self.assertEqual(adv2.shape, (1, 1)) - - def test_monte_carlo_bias_is_zero(self): - (bias, _) = estimate_advantage_bias_and_variance( - advantages.monte_carlo, margin=3 - ) - np.testing.assert_allclose(bias, 0, atol=0.1) - - def test_td_k_variance_lower_than_monte_carlo(self): - (_, var_td_3) = estimate_advantage_bias_and_variance( - advantages.td_k, margin=3 - ) - (_, var_mc) = estimate_advantage_bias_and_variance(advantages.monte_carlo) - self.assertLess(var_td_3, var_mc) - - @parameterized.named_parameters(('1_2', 1, 2), ('2_3', 2, 3)) - def test_td_k_bias_decreases_with_k(self, k1, k2): - (bias1, _) = estimate_advantage_bias_and_variance( - advantages.td_k, margin=k1 - ) - (bias2, _) = estimate_advantage_bias_and_variance( - advantages.td_k, margin=k2 - ) - self.assertGreater(bias1, bias2) - - @parameterized.named_parameters(('1_2', 1, 2), ('2_3', 2, 3)) - def 
test_td_k_variance_increases_with_k(self, k1, k2): - (_, var1) = estimate_advantage_bias_and_variance( - advantages.td_k, margin=k1 - ) - (_, var2) = estimate_advantage_bias_and_variance( - advantages.td_k, margin=k2 - ) - self.assertLess(var1, var2) - - def test_td_lambda_variance_lower_than_monte_carlo(self): - (_, var_td_095) = estimate_advantage_bias_and_variance( - advantages.td_lambda, lambda_=0.95 - ) - (_, var_mc) = estimate_advantage_bias_and_variance(advantages.monte_carlo) - self.assertLess(var_td_095, var_mc) - - @parameterized.named_parameters( - ('td_lambda_0.5_0.7', advantages.td_lambda, 0.5, 0.7), - ('td_lambda_0.7_0.9', advantages.td_lambda, 0.7, 0.9), - ('gae_0.5_0.7', advantages.gae, 0.5, 0.7), - ('gae_0.7_0.9', advantages.gae, 0.7, 0.9), - ) - def test_bias_decreases_with_lambda(self, advantage_fn, lambda1, lambda2): - (bias1, _) = estimate_advantage_bias_and_variance( - advantage_fn, lambda_=lambda1 - ) - (bias2, _) = estimate_advantage_bias_and_variance( - advantage_fn, lambda_=lambda2 - ) - self.assertGreater(bias1, bias2) - - @parameterized.named_parameters(('0.5_0.7', 0.5, 0.7), ('0.7_0.9', 0.7, 0.9)) - def test_variance_increases_with_lambda(self, lambda1, lambda2): - (_, var1) = estimate_advantage_bias_and_variance( - advantages.td_lambda, lambda_=lambda1 - ) - (_, var2) = estimate_advantage_bias_and_variance( - advantages.td_lambda, lambda_=lambda2 - ) - self.assertLess(var1, var2) - - @parameterized.named_parameters( - ('monte_carlo', advantages.monte_carlo), - ('td_k', advantages.td_k), - ('td_lambda', advantages.td_lambda), - ('gae', advantages.gae), - ) - def test_advantage_future_return_is_zero_at_done(self, advantage_fn): - rewards = np.array([[1, 1, 1]], dtype=np.float32) - returns = np.array([[3, 2, 1]], dtype=np.float32) - values = np.array([[2, 2, 2]], dtype=np.float32) - dones = np.array([[False, True, False]]) - discount_mask = np.ones_like(rewards) - adv = advantage_fn(gamma=0.9, margin=1)( - rewards, returns, values, dones, discount_mask - ) - target_returns = values[:, :-1] + adv - # Assert that in the "done" state the future return in the advantage is - # zero, i.e. the return is equal to the reward. - np.testing.assert_almost_equal(target_returns[0, 1], rewards[0, 1]) - - @parameterized.named_parameters( - ('monte_carlo', advantages.monte_carlo), - # Disabled for TD-k because the differences are too small. - # ('td_k', advantages.td_k), - ('td_lambda', advantages.td_lambda), - ('gae', advantages.gae), - ) - def test_bias_and_variance_with_non_const_discount_mask(self, advantage_fn): - non_const_discount_mask = np.array([[1, 0, 1, 0, 1]]) - const_discount_mask = np.ones_like(non_const_discount_mask) - est_bias_and_variance = functools.partial( - estimate_advantage_bias_and_variance, - advantage_fn, - length=const_discount_mask.shape[1], - # Set gamma to a small value to accentuate the differences. - gamma=0.5, - # We want to measure error due to the discount, so compare with the - # undiscounted return. - discount_true_return=False, - # Use true values to remove the value estimation error. 
- true_value=True, - ) - (bias_non_const, var_non_const) = est_bias_and_variance( - discount_mask=non_const_discount_mask - ) - (bias_const, var_const) = est_bias_and_variance( - discount_mask=const_discount_mask - ) - self.assertLess(bias_non_const, bias_const) - self.assertGreater(var_non_const, var_const) - - @parameterized.named_parameters( - ('monte_carlo', advantages.monte_carlo), - ('td_k', advantages.td_k), - ('td_lambda', advantages.td_lambda), - ('gae', advantages.gae), - ) - def test_future_return_is_zero_iff_discount_mask_is_on(self, advantage_fn): - # (... when gamma=0) - rewards = np.array([[1, 2, 3, 4]], dtype=np.float32) - values = np.array([[5, 6, 7, 8]], dtype=np.float32) - dones = np.zeros_like(rewards, dtype=bool) - discount_mask = np.array([[1, 0, 1, 0]], dtype=bool) - gammas = advantages.mask_discount(0.0, discount_mask) - returns = advantages.discounted_returns(rewards, gammas) - adv = advantage_fn(gamma=0.0, margin=1)( - rewards, returns, values, dones, discount_mask - ) - target_returns = values[:, :-1] + adv - # Assert that in the states with discount_mask on the future return in the - # advantage is zero, i.e. the return is equal to the reward. - rewards = rewards[:, :-1] - discount_mask = discount_mask[:, :-1] - np.testing.assert_almost_equal( - target_returns[discount_mask], rewards[discount_mask] - ) - # Assert the converse. - with np.testing.assert_raises(AssertionError): - np.testing.assert_almost_equal( - target_returns[~discount_mask], rewards[~discount_mask] - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/rl/distributions.py b/trax/rl/distributions.py deleted file mode 100644 index 294dea379..000000000 --- a/trax/rl/distributions.py +++ /dev/null @@ -1,224 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Probability distributions for RL training in Trax.""" - -import gin -import gym -import numpy as np - -from trax import layers as tl -from trax.fastmath import numpy as jnp - - -class Distribution: - """Abstract class for parametrized probability distributions.""" - - @property - def n_inputs(self): - """Returns the number of inputs to the distribution (i.e. parameters).""" - raise NotImplementedError - - def sample(self, inputs, temperature=1.0): - """Samples a point from the distribution. - - Args: - inputs (jnp.ndarray): Distribution inputs. Shape is subclass-specific. - Broadcasts along the first dimensions. For example, in the categorical - distribution parameter shape is (C,), where C is the number of - categories. If (B, C) is passed, the object will represent a batch of B - categorical distributions with different parameters. - temperature: sampling temperature; 1.0 is default, at 0.0 chooses - the most probable (preferred) action. - - Returns: - Sampled point of shape dependent on the subclass and on the shape of - inputs. 
- """ - raise NotImplementedError - - def log_prob(self, inputs, point): - """Retrieves log probability (or log probability density) of a point. - - Args: - inputs (jnp.ndarray): Distribution parameters. - point (jnp.ndarray): Point from the distribution. Shape should be - consistent with inputs. - - Returns: - Array of log probabilities of points in the distribution. - """ - raise NotImplementedError - - def LogProb(self): # pylint: disable=invalid-name - """Builds a log probability layer for this distribution.""" - return tl.Fn('LogProb', - lambda inputs, point: self.log_prob(inputs, point)) # pylint: disable=unnecessary-lambda - - -@gin.configurable(denylist=['n_categories', 'shape']) -class Categorical(Distribution): - """Categorical distribution parametrized by logits.""" - - def __init__(self, n_categories, shape=()): - """Initializes Categorical distribution. - - Args: - n_categories (int): Number of categories. - shape (tuple): Shape of the sample. - """ - self._n_categories = n_categories - self._shape = shape - - @property - def n_inputs(self): - return np.prod(self._shape, dtype=jnp.int32) * self._n_categories - - def _unflatten_inputs(self, inputs): - return jnp.reshape( - inputs, inputs.shape[:-1] + self._shape + (self._n_categories,) - ) - - def sample(self, inputs, temperature=1.0): - # No need for LogSoftmax with sampling - softmax normalization is - # subtracting a constant from every logit, and sampling is taking - # a max over logits plus noise, so invariant to adding a constant. - if temperature == 0.0: - return jnp.argmax(self._unflatten_inputs(inputs), axis=-1) - return tl.logsoftmax_sample(self._unflatten_inputs(inputs), temperature) - - def log_prob(self, inputs, point): - inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs)) - return jnp.sum( - # Select the logits specified by point. - inputs * tl.one_hot(point, self._n_categories), - # Sum over the parameter dimensions. - axis=[-a for a in range(1, len(self._shape) + 2)], - ) - - def entropy(self, inputs): - log_probs = tl.LogSoftmax()(inputs) - probs = jnp.exp(log_probs) - return -jnp.sum(probs * log_probs, axis=-1) - - -@gin.configurable(denylist=['shape']) -class Gaussian(Distribution): - """Independent multivariate Gaussian distribution parametrized by mean.""" - - def __init__(self, shape=(), std=1.0, learn_std=None): - """Initializes Gaussian distribution. - - Args: - shape (tuple): Shape of the sample. - std (float): Standard deviation, shared across the whole sample. - learn_std (str or None): How to learn the standard deviation - 'shared' - to have a single, shared std parameter, or 'separate' to have separate - parameters for each dimension. - """ - self._shape = shape - self._std = std - self._learn_std = learn_std - - @property - def _n_dims(self): - return np.prod(self._shape, dtype=jnp.int32) - - def _params(self, inputs): - """Extracts the mean and std parameters from the inputs.""" - if inputs.shape[-1] != self.n_inputs: - raise ValueError( - 'Invalid distribution parametrization - expected {} parameters, ' - 'got {}. Input shape: {}.'.format( - self.n_inputs, inputs.shape[-1], inputs.shape - ) - ) - n_dims = self._n_dims - # Split the distribution inputs into two parts: mean and std. - mean = inputs[..., :n_dims] - if self._learn_std is not None: - std = inputs[..., n_dims:] - # Std is non-negative, so let's softplus it. - std = tl.Softplus()(std + self._std) - else: - std = self._std - # In case of constant or shared std, upsample it to the same dimensionality - # as the means. 
- std = jnp.broadcast_to(std, mean.shape) - return (mean, std) - - @property - def n_inputs(self): - n_dims = self._n_dims - return { - None: n_dims, - 'shared': n_dims + 1, - 'separate': n_dims * 2, - }[self._learn_std] - - def sample(self, inputs, temperature=1.0): - (mean, std) = self._params(inputs) - mean = jnp.reshape(mean, mean.shape[:-1] + self._shape) - std = jnp.reshape(std, std.shape[:-1] + self._shape) - if temperature == 0: - # At temperature 0 return the mean directly; this avoids calling - # np/jnp.random, which matters for the PreferredMove metric. - return mean - else: - return np.random.normal(loc=mean, scale=(std * temperature)) - - def log_prob(self, inputs, point): - point = point.reshape(inputs.shape[:-1] + (-1,)) - (mean, std) = self._params(inputs) - return -jnp.sum( - # Scaled distance. - (point - mean) ** 2 / (2 * std ** 2) + - # Normalizing constant. - (jnp.log(std) + jnp.log(jnp.sqrt(2 * jnp.pi))), - axis=-1, - ) - - def entropy(self, inputs): - (_, std) = self._params(inputs) - # Entropy of an independent Gaussian: log(std) + 0.5 * log(2 * pi * e) - # per dimension, summed over dimensions. - return jnp.sum(jnp.log(std) + .5 * jnp.log(2.0 * jnp.pi * jnp.e), axis=-1) - - -# TODO(pkozakowski): Implement GaussianMixture. - - -def create_distribution(space): - """Creates a Distribution for the given Gym space.""" - if isinstance(space, gym.spaces.Discrete): - return Categorical(shape=(), n_categories=space.n) - elif isinstance(space, gym.spaces.MultiDiscrete): - assert space.nvec.size - assert min(space.nvec) == max(space.nvec), ( - 'Every dimension must have the same number of categories, got ' - '{}.'.format(space.nvec) - ) - return Categorical(shape=(len(space.nvec),), n_categories=space.nvec[0]) - elif isinstance(space, gym.spaces.Box): - return Gaussian(shape=space.shape) - else: - raise TypeError( - 'Space {} unavailable as a distribution support.'.format(space)) - - -def LogLoss(distribution, **unused_kwargs): # pylint: disable=invalid-name - """Builds a log loss layer for a Distribution.""" - return tl.Serial( - distribution.LogProb(), - tl.Negate(), - tl.WeightedSum() - ) diff --git a/trax/rl/distributions_test.py b/trax/rl/distributions_test.py deleted file mode 100644 index df6f3d7f1..000000000 --- a/trax/rl/distributions_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
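A quick numerical check of the Gaussian log-density defined above (an illustrative sketch, not part of the test file that follows): for a standard normal evaluated at its mean, the log-probability is -0.5 * log(2 * pi), roughly -0.919.

import numpy as np

mean, std, x = 0.0, 1.0, 0.0
# Same formula as Gaussian.log_prob above: scaled distance plus the
# normalizing constant, negated.
log_prob = -((x - mean) ** 2 / (2 * std ** 2)
             + np.log(std) + np.log(np.sqrt(2 * np.pi)))
assert np.isclose(log_prob, -0.5 * np.log(2 * np.pi))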
- -"""Tests for trax.rl.distributions.""" - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import gym -import numpy as np - -from trax.rl import distributions - - -class DistributionsTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - gin.clear_config() - - @parameterized.named_parameters( - ('discrete', gym.spaces.Discrete(n=4), ''), - ('multi_discrete', gym.spaces.MultiDiscrete(nvec=[5, 5]), ''), - ( - 'gaussian_const_std', - gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)), - 'Gaussian.learn_std = None', - ), ( - 'gaussian_shared_std', - gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)), - 'Gaussian.learn_std = "shared"', - ), ( - 'gaussian_separate_std', - gym.spaces.Box(low=-np.inf, high=+np.inf, shape=(4, 5)), - 'Gaussian.learn_std = "separate"', - ), - ) - def test_shapes(self, space, gin_config): - gin.parse_config(gin_config) - - batch_shape = (2, 3) - distribution = distributions.create_distribution(space) - inputs = np.random.random(batch_shape + (distribution.n_inputs,)) - point = distribution.sample(inputs) - self.assertEqual(point.shape, batch_shape + space.shape) - # Check if the datatypes are compatible, i.e. either both floating or both - # integral. - self.assertEqual( - isinstance(point.dtype, float), isinstance(space.dtype, float) - ) - log_prob = distribution.log_prob(inputs, point) - self.assertEqual(log_prob.shape, batch_shape) - - @parameterized.named_parameters(('1d', 1), ('2d', 2)) - def test_gaussian_probability_sums_to_one(self, n_dims): - std = 1.0 - n_samples = 10000 - - distribution = distributions.Gaussian(shape=(n_dims,), std=std) - means = np.random.random((3, n_dims)) - # Monte carlo integration over [mean - 3 * std, mean + 3 * std] across - # all dimensions. - means = np.broadcast_to(means, (n_samples,) + means.shape) - probs = (6 * std) ** n_dims * np.mean( - np.exp(distribution.log_prob( - means, np.random.uniform(means - 3 * std, means + 3 * std) - )), - axis=0, - ) - # Should sum to one. High tolerance because of variance and cutting off the - # tails. - np.testing.assert_allclose(probs, np.ones_like(probs), atol=0.05) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/rl/envs/data_envs.py b/trax/rl/envs/data_envs.py deleted file mode 100644 index 3bc67ce0c..000000000 --- a/trax/rl/envs/data_envs.py +++ /dev/null @@ -1,165 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""RL environments created from supervised data-sets.""" - -import gym -import numpy as np - - -class SequenceDataEnv(object): - """RL environment created from a generator of sequential data. - - This class allows to create RL environments from supervised sequential data, - such as tokenized natural languague processing tasks. The data comes as: - (input1, output1, input2, output2, ...) - where inputs and outputs are all sequences of integers. 
- - For example, with input (2, 3) and output (4, 5), so data = [(2, 3), (4, 5)], - the sequence of (observations, rewards, actions) will look like: - 2 = env.reset() # first observation - 3, 0.0, _, _ = env.step(ignored_action) - eos, 0.0, _, _ = env.step(ignored_action) - act1, 0.0, _, _ = env.step(act1) # observation = action - act2, 0.0, _, _ = env.step(act2) # observation = action - eos, score, _, _ = env.step(eos) - - where score = metric((4, 5), (act1, act2)) is the reward obtained by - comparing the two actions to the actual output from the data. - - The environment first presents the input as observations, doing this - sequentially, token-by-token, and ignoring all actions taken by the policy. - Then, the policy is asked to generate the response, again, token-by-token, - until it generates EOS. Generated tokens are repeated as observations. - When EOS is encountered, a metric is computed between the generated - output and the output from data, and this metric is returned as reward. - """ - - def __init__(self, data_stream, vocab_size, metric=None, - eos_id=1, max_length=1000): - """The constructor. - - Args: - data_stream: A python generator creating lists or tuples of - sequences (list, tuples or numpy arrays) of integers. - vocab_size: Integer, the size of the vocabulary. All integers in the - data stream must be positive and smaller than this value. - metric: A function taking two lists of integers and returning a float. - If None, we use per-token accuracy as the default metric. - eos_id: Integer, the id of the EOS symbol. - max_length: Integer, maximum length of the policy reply to avoid - infinite episodes if policy never produces EOS. - - Returns: - A new environment which presents the data and compares the policy - response with the expected data, returning metric as reward. - """ - self._data = data_stream - self._vocab_size = vocab_size - self._eos = eos_id - self._max_length = max_length - self._metric = _accuracy if metric is None else metric - self.reset() - - @property - def _on_input(self): - """Return True if we're currently processing input, False if output.""" - cur_sequence_id, _ = self._cur_position - return cur_sequence_id % 2 == 0 - - @property - def observation(self): - cur_sequence_id, cur_token_id = self._cur_position - if cur_sequence_id >= len(self._cur_sequence): - obs = self._eos - elif self._on_input: - obs = self._cur_sequence[cur_sequence_id][cur_token_id] - else: - obs = self._response[-1] if self._response else self._eos - return np.array(int(obs), dtype=np.int32) - - @property - def action_space(self): - return gym.spaces.Discrete(self._vocab_size) - - @property - def observation_space(self): - return gym.spaces.Discrete(self._vocab_size) - - def reset(self): - """Reset this environment.""" - self._cur_sequence = next(self._data) - # Position contains 2 indices: which sequence are we in? (input1, output1, - # input2, output2 and so on) and which token in the sequence are we in? - self._cur_position = (0, 0) - self._response = [] - return self.observation - - def step(self, action): - """Single step of the environment when policy took `action`.""" - cur_sequence_id, cur_token_id = self._cur_position - if cur_sequence_id >= len(self._cur_sequence): - return np.array(self._eos, dtype=np.int32), 0.0, True, None - - # Emit the control mask on the output. 
- control_mask = int(not self._on_input) - - if self._on_input: - self._response = [] - if cur_token_id + 1 < len(self._cur_sequence[cur_sequence_id]): - self._cur_position = (cur_sequence_id, cur_token_id + 1) - done = False - else: - self._cur_position = (cur_sequence_id + 1, 0) - done = cur_sequence_id + 1 >= len(self._cur_sequence) - reward = 0.0 - discount_mask = 0 - - else: - self._response.append(action) - if action == self._eos or len(self._response) > self._max_length: - self._cur_position = (cur_sequence_id + 1, 0) - reward = self._metric( - self._response[:-1], self._cur_sequence[cur_sequence_id]) - done = cur_sequence_id + 1 >= len(self._cur_sequence) - # Emit the discount mask on the last token of each action. - discount_mask = 1 - else: - reward = 0.0 - done = False - discount_mask = 0 - - info = {'control_mask': control_mask, 'discount_mask': discount_mask} - return self.observation, reward, done, info - - -def copy_stream(length, low=2, high=15, n=1): - """Generate `n` random sequences of length `length` and yield each twice.""" - while True: - res = [] - for _ in range(n): - seq = np.random.randint(low, high, size=(length,), dtype=np.int32) - res.extend([seq, seq]) - yield res - - -def _accuracy(seq1, seq2): - """Token-level accuracy.""" - seq1, seq2 = np.array(seq1), np.array(seq2) - max_length = max(seq1.shape[-1], seq2.shape[-1]) - min_length = min(seq1.shape[-1], seq2.shape[-1]) - seq1s, seq2s = seq1[..., :min_length], seq2[..., :min_length] - return np.sum(np.equal(seq1s, seq2s)) / max_length - diff --git a/trax/rl/envs/data_envs_test.py b/trax/rl/envs/data_envs_test.py deleted file mode 100644 index 205e44f5e..000000000 --- a/trax/rl/envs/data_envs_test.py +++ /dev/null @@ -1,170 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RL environments created from supervised data-sets.""" - -from absl.testing import absltest -import numpy as np -from trax.rl.envs import data_envs - - -class SequenceDataEnvTest(absltest.TestCase): - - def _assert_masks(self, info, control, discount): - self.assertEqual(info, {'control_mask': control, 'discount_mask': discount}) - - def test_copy_task_short_sequence_correct_actions(self): - """Test sequence data env on the copying task, correct replies. 
- - With input (x1, x2) this tests for the following sequence of - (observations, rewards, dones, actions): - x1 = env.reset() - x2, 0.0, F, _ = env.step(ignored_action) - eos, 0.0, F, _ = env.step(ignored_action) - x1, 0.0, F, _ = env.step(x1) - x2, 0.0, F, _ = env.step(x2) - eos, 1.0, T, _ = env.step(eos) - """ - env = data_envs.SequenceDataEnv(data_envs.copy_stream(2, n=1), 16) - x1 = env.reset() - x2, r0, d0, i0 = env.step(0) - self.assertEqual(r0, 0.0) - self.assertEqual(d0, False) - self._assert_masks(i0, control=0, discount=0) - eos, r1, d1, i1 = env.step(0) - self.assertEqual(eos, 1) - self.assertEqual(r1, 0.0) - self.assertEqual(d1, False) - self._assert_masks(i1, control=0, discount=0) - y1, r2, d2, i2 = env.step(x1) - self.assertEqual(y1, x1) - self.assertEqual(r2, 0.0) - self.assertEqual(d2, False) - self._assert_masks(i2, control=1, discount=0) - y2, r3, d3, i3 = env.step(x2) - self.assertEqual(y2, x2) - self.assertEqual(r3, 0.0) - self.assertEqual(d3, False) - self._assert_masks(i3, control=1, discount=0) - eos2, r4, d4, i4 = env.step(1) - self.assertEqual(eos2, 1) - self.assertEqual(r4, 1.0) - self.assertEqual(d4, True) - self._assert_masks(i4, control=1, discount=1) - - def test_copy_task_longer_sequence_mixed_actions(self): - """Test sequence data env on the copying task, mixed replies. - - With input (x1, x2) and (y1, y2) this tests for the following sequence of - (observations, rewards, dones, actions): - x1 = env.reset() - x2, 0.0, F, _ = env.step(ignored_action) - eos, 0.0, F, _ = env.step(ignored_action) - x1, 0.0, F, _ = env.step(x1) - x2+1, 0.0, F, _ = env.step(x2+1) - y1, 0.5, F, _ = env.step(eos) - y2, 0.0, F, _ = env.step(ignored_action) - eos, 0.0, F, _ = env.step(ignored_action) - y1+1, 0.0, F, _ = env.step(y1+1) - y2+1, 0.0, F, _ = env.step(y2+1) - eos, 0.0, T, _ = env.step(eos) - """ - env = data_envs.SequenceDataEnv(data_envs.copy_stream(2, n=2), 16) - x1 = env.reset() - x2, _, _, _ = env.step(0) - eos, _, _, _ = env.step(0) - _, _, _, _ = env.step(x1) - _, _, _, _ = env.step(x2 + 1) # incorrect - y1, r1, d1, _ = env.step(1) - self.assertEqual(r1, 0.5) - self.assertEqual(d1, False) - y2, _, _, _ = env.step(0) - eos, _, _, _ = env.step(0) - _, _, _, _ = env.step(y1 + 1) # incorrect - _, _, _, _ = env.step(y2 + 1) # incorrect - eos, r2, d2, _ = env.step(1) - self.assertEqual(eos, 1) - self.assertEqual(r2, 0.0) - self.assertEqual(d2, True) - - def test_copy_task_action_observation_space(self): - """Test that sequence data env returns correct action/observation space.""" - env = data_envs.SequenceDataEnv(data_envs.copy_stream(2, n=1), 16) - self.assertEqual(env.action_space.n, 16) - self.assertEqual(env.observation_space.n, 16) - - def test_copy_task_max_length(self): - """Test that sequence data env respects max_length.""" - env = data_envs.SequenceDataEnv(data_envs.copy_stream(10, n=1), 16, - max_length=2) - obs = env.reset() - for _ in range(10): - obs, reward, done, _ = env.step(0) - self.assertEqual(reward, 0.0) - self.assertEqual(done, False) - self.assertEqual(obs, 1) # produces EOS - obs, reward, done, _ = env.step(7) - self.assertEqual(obs, 7) # repeats action - self.assertEqual(reward, 0.0) - self.assertEqual(done, False) - obs, reward, done, _ = env.step(8) - self.assertEqual(obs, 8) # repeats action - self.assertEqual(reward, 0.0) - self.assertEqual(done, False) - obs, reward, done, _ = env.step(9) - self.assertEqual(done, True) # exceeded max_length, stop - self.assertEqual(obs, 1) # produce EOS on done - obs, reward, done, _ = env.step(10) - 
-    self.assertEqual(done, True)  # continue producing done = True
-    self.assertEqual(obs, 1)  # continue producing EOS
-
-  def test_number_of_active_masks(self):
-    """Test that we have the correct number of control and discount masks."""
-    n_input_seqs = 3
-    n_output_seqs = 2
-    input_len = 4
-    output_len = 5
-
-    def data_stream():
-      i = 2 * np.ones(input_len)
-      o = np.zeros(output_len)
-      while True:
-        yield (i, o, i, o, i)  # 3 input, 2 output sequences.
-
-    env = data_envs.SequenceDataEnv(data_stream(), 16, max_length=output_len)
-    env.reset()
-
-    n_discount = 0
-    n_control = 0
-    n_steps = 0
-    done = False
-    while not done:
-      (_, _, done, info) = env.step(action=0)
-      n_discount += info['discount_mask']
-      n_control += info['control_mask']
-      n_steps += 1
-
-    # One discount_mask=1 per output sequence.
-    self.assertEqual(n_discount, n_output_seqs)
-    # One control_mask=1 per output token, including EOS, because it's also
-    # controlled by the agent.
-    self.assertEqual(n_control, (output_len + 1) * n_output_seqs)
-    # One control_mask=0 per input token, excluding EOS, because when the env
-    # emits it, control transfers to the agent immediately.
-    self.assertEqual(n_steps - n_control, input_len * n_input_seqs)
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/rl/normalization.py b/trax/rl/normalization.py
deleted file mode 100644
index 141eb719b..000000000
--- a/trax/rl/normalization.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
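One thing worth keeping from the deleted test above is its mask arithmetic. A standalone back-of-the-envelope check (plain Python; the numbers are the ones hard-coded in test_number_of_active_masks, nothing else is assumed):

# Numbers from test_number_of_active_masks above.
n_input_seqs, input_len = 3, 4
n_output_seqs, output_len = 2, 5

# The agent controls every output token plus the EOS closing each output,
# so control_mask=1 fires (output_len + 1) times per output sequence.
n_control = (output_len + 1) * n_output_seqs   # 12
# The env controls every input token; input EOS is not an extra agent step
# because control transfers to the agent the moment the env emits it.
n_env = input_len * n_input_seqs               # 12
# discount_mask=1 fires once per output sequence, on its final token.
n_discount = n_output_seqs                     # 2

assert (n_control, n_env, n_discount) == (12, 12, 2)
print('episode length:', n_control + n_env)    # 24 steps in total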
- -"""Normalization helpers.""" - -import gin -import numpy as np - -from trax import fastmath -from trax import layers as tl - - -def running_mean_init(shape, fill_value=0): - return (np.full(shape, fill_value), np.array(0)) - - -def running_mean_update(x, state): - (mean, n) = state - mean = n.astype(np.float32) / (n + 1) * mean + x / (n + 1) - return (mean, n + 1) - - -def running_mean_get_mean(state): - (mean, _) = state - return mean - - -def running_mean_get_count(state): - (_, count) = state - return count - - -def running_mean_and_variance_init(shape): - mean_state = running_mean_init(shape, fill_value=0.0) - var_state = running_mean_init(shape, fill_value=1.0) - return (mean_state, var_state) - - -def running_mean_and_variance_update(x, state): - (mean_state, var_state) = state - old_mean = running_mean_get_mean(mean_state) - mean_state = running_mean_update(x, mean_state) - new_mean = running_mean_get_mean(mean_state) - - var_state = running_mean_update((x - new_mean) * (x - old_mean), var_state) - - return (mean_state, var_state) - - -def running_mean_and_variance_get_mean(state): - (mean_state, _) = state - return running_mean_get_mean(mean_state) - - -def running_mean_and_variance_get_count(state): - (mean_state, _) = state - return running_mean_get_count(mean_state) - - -def running_mean_and_variance_get_variance(state): - (_, var_state) = state - return running_mean_get_mean(var_state) - - -@gin.configurable(denylist=['mode']) -class Normalize(tl.Layer): - """Numerically stable normalization layer.""" - - def __init__(self, sample_limit=float('+inf'), epsilon=1e-5, mode='train'): - super().__init__() - self._sample_limit = sample_limit - self._epsilon = epsilon - self._mode = mode - - def init_weights_and_state(self, input_signature): - self.state = running_mean_and_variance_init(input_signature.shape[2:]) - - def forward(self, inputs): - state = self.state - observations = inputs - if self._mode == 'collect': - # Accumulate statistics only in the collect mode, i.e. when collecting - # data using the agent. - for observation in observations[:, -1]: # (batch_size, time, ...) - # Update statistics for each observation separately for simplicity. - # Currently during data collection the batch size is 1 anyway. - count = running_mean_and_variance_get_count(state) - state = fastmath.cond( - count < self._sample_limit, - true_operand=(observation, state), - true_fun=lambda args: running_mean_and_variance_update(*args), - false_operand=None, - false_fun=lambda _: state, - ) - - mean = running_mean_and_variance_get_mean(state) - var = running_mean_and_variance_get_variance(state) - norm_observations = (observations - mean) / (var ** 0.5 + self._epsilon) - self.state = state - return norm_observations - - -@gin.configurable(denylist=['mode']) -def LayerNormSquash(mode, width=128): # pylint: disable=invalid-name - """Dense-LayerNorm-Tanh normalizer inspired by ACME.""" - # https://github.com/deepmind/acme/blob/master/acme/jax/networks/continuous.py#L34 - del mode - return tl.Serial([ - tl.Dense(width), - tl.LayerNorm(), - tl.Tanh(), - ]) diff --git a/trax/rl/normalization_test.py b/trax/rl/normalization_test.py deleted file mode 100644 index d24ab68a0..000000000 --- a/trax/rl/normalization_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.rl.normalization.""" - -from absl.testing import absltest -import numpy as np - -from trax import shapes -from trax.rl import normalization - - -class NormalizationTest(absltest.TestCase): - - def test_running_mean(self): - x = np.random.uniform(size=10) - state = normalization.running_mean_init(shape=()) - for i in range(len(x)): - state = normalization.running_mean_update(x[i], state) - np.testing.assert_almost_equal( - normalization.running_mean_get_mean(state), np.mean(x[:i + 1]) - ) - - def test_running_variance(self): - x = np.random.uniform(size=10) - state = normalization.running_mean_and_variance_init(shape=()) - for i in range(len(x)): - state = normalization.running_mean_and_variance_update(x[i], state) - np.testing.assert_almost_equal( - normalization.running_mean_and_variance_get_variance(state), - np.var(x[:i + 1]), - ) - - def test_normalize_collect(self): - x = np.random.uniform(size=(2, 3, 4, 5)) - normalize = normalization.Normalize(mode='collect') - normalize.init(shapes.signature(x)) - old_state = normalize.state - y = normalize(x) - with self.assertRaises(AssertionError): - np.testing.assert_equal(normalize.state, old_state) - with self.assertRaises(AssertionError): - np.testing.assert_almost_equal(x, y) - - def test_normalize_train(self): - x = np.random.uniform(size=(2, 3, 4, 5)) - normalize = normalization.Normalize(mode='train', epsilon=0.0) - normalize.init(shapes.signature(x)) - old_state = normalize.state - y = normalize(x) - np.testing.assert_equal(normalize.state, old_state) - np.testing.assert_almost_equal(x, y) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/rl/policy_tasks.py b/trax/rl/policy_tasks.py deleted file mode 100644 index 4bf1a9ba6..000000000 --- a/trax/rl/policy_tasks.py +++ /dev/null @@ -1,261 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Policy network training tasks. - -Policy tasks encapsulate the training process of a policy network into a simple, -replaceable component. To implement a policy-based Agent using policy tasks: - - 1. Subclass the base Agent class. - 2. In __init__(), initialize the policy training and evaluation tasks, and - a trax.supervised.training.Loop instance using them. - 3. In train_epoch(), call the Loop to train the network. - 4. In policy(), call network_policy() defined in this module. 
-""" - -import numpy as np - -from trax import layers as tl -from trax.fastmath import numpy as jnp -from trax.rl import distributions -from trax.supervised import training - - -class PolicyTrainTask(training.TrainTask): - """Task for policy training. - - Trains the policy based on action advantages. - """ - - def __init__( - self, - trajectory_batch_stream, - optimizer, - lr_schedule, - policy_distribution, - advantage_estimator, - value_fn, - weight_fn=(lambda x: x), - advantage_normalization=True, - advantage_normalization_epsilon=1e-5, - head_selector=(), - ): - """Initializes PolicyTrainTask. - - Args: - trajectory_batch_stream: Generator of trax.rl.task.TimeStepBatch. - optimizer: Optimizer for network training. - lr_schedule: Learning rate schedule for network training. - policy_distribution: Distribution over actions. - advantage_estimator: Function - (rewards, returns, values, dones) -> advantages, created by one of the - functions from trax.rl.advantages. - value_fn: Function TimeStepBatch -> array (batch_size, seq_len) - calculating the baseline for advantage calculation. Can be used to - implement actor-critic algorithms, by substituting a call to the value - network as value_fn. - weight_fn: Function float -> float to apply to advantages. Examples: - - A2C: weight_fn = id - - AWR: weight_fn = exp - - behavioral cloning: weight_fn(_) = 1 - advantage_normalization: Whether to normalize advantages. - advantage_normalization_epsilon: Epsilon to use then normalizing - advantages. - head_selector: Layer to apply to the network output to select the value - head. Only needed in multitask training. By default, use a no-op layer, - signified by an empty sequence of layers, (). - """ - self.trajectory_batch_stream = trajectory_batch_stream - self._value_fn = value_fn - self._advantage_estimator = advantage_estimator - self._weight_fn = weight_fn - self._advantage_normalization = advantage_normalization - self._advantage_normalization_epsilon = advantage_normalization_epsilon - self.policy_distribution = policy_distribution - - labeled_data = map(self.policy_batch, trajectory_batch_stream) - sample_batch = self.policy_batch( - next(trajectory_batch_stream), shape_only=True - ) - loss_layer = distributions.LogLoss(distribution=policy_distribution) - loss_layer = tl.Serial(head_selector, loss_layer) - super().__init__( - labeled_data, loss_layer, optimizer, - sample_batch=sample_batch, - lr_schedule=lr_schedule, - loss_name='policy_loss', - ) - - def calculate_advantages(self, trajectory_batch, shape_only=False): - (batch_size, seq_len) = trajectory_batch.observation.shape[:2] - assert trajectory_batch.action.shape[:2] == (batch_size, seq_len) - assert trajectory_batch.mask.shape == (batch_size, seq_len) - if shape_only: - values = np.zeros((batch_size, seq_len)) - else: - # Compute the value, i.e. baseline in advantage computation. - values = np.array(self._value_fn(trajectory_batch)) - assert values.shape == (batch_size, seq_len) - # Compute the advantages using the chosen advantage estimator. - return self._advantage_estimator( - rewards=trajectory_batch.reward, - returns=trajectory_batch.return_, - dones=trajectory_batch.done, - values=values, - discount_mask=trajectory_batch.env_info.discount_mask, - ) - - def calculate_weights(self, advantages): - """Calculates advantage-based weights for log loss in policy training.""" - if self._advantage_normalization: - # Normalize advantages. 
- advantages -= jnp.mean(advantages) - advantage_std = jnp.std(advantages) - advantages /= advantage_std + self._advantage_normalization_epsilon - weights = self._weight_fn(advantages) - assert weights.shape == advantages.shape - return weights - - def trim_and_mask_batch(self, trajectory_batch, advantages): - (batch_size, seq_len) = trajectory_batch.observation.shape[:2] - adv_seq_len = advantages.shape[1] - # The advantage sequence should be shorter by the margin. Margin is the - # number of timesteps added to the trajectory slice, to make the advantage - # estimation more accurate. adv_seq_len determines the length of the target - # sequence, and is later used to trim the inputs and targets in the training - # batch. Example for margin 2: - # observations.shape == (4, 5, 6) - # rewards.shape == values.shape == (4, 5) - # advantages.shape == (4, 3) - assert adv_seq_len <= seq_len - assert advantages.shape == (batch_size, adv_seq_len) - # Trim observations, actions and mask to match the target length. - observations = trajectory_batch.observation[:, :adv_seq_len] - actions = trajectory_batch.action[:, :adv_seq_len] - mask = trajectory_batch.mask[:, :adv_seq_len] - # Apply the control mask, so we only compute policy loss for controllable - # timesteps. - mask *= trajectory_batch.env_info.control_mask[:, :adv_seq_len] - return (observations, actions, mask) - - def policy_batch(self, trajectory_batch, shape_only=False): - """Computes a policy training batch based on a trajectory batch. - - Args: - trajectory_batch: trax.rl.task.TimeStepBatch with a batch of trajectory - slices. Elements should have shape (batch_size, seq_len, ...). - shape_only: Whether to return dummy zero arrays of correct shape. Useful - for initializing models. - - Returns: - Triple (observations, actions, weights), where weights are the - advantage-based weights for the policy loss. Shapes: - - observations: (batch_size, seq_len) + observation_shape - - actions: (batch_size, seq_len) + action_shape - - weights: (batch_size, seq_len) - """ - advantages = self.calculate_advantages( - trajectory_batch, shape_only=shape_only - ) - (observations, actions, mask) = self.trim_and_mask_batch( - trajectory_batch, advantages - ) - weights = self.calculate_weights(advantages) * mask / jnp.sum(mask) - return (observations, actions, weights) - - -class PolicyEvalTask(training.EvalTask): - """Task for policy evaluation.""" - - def __init__(self, train_task, n_eval_batches=1, head_selector=()): - """Initializes PolicyEvalTask. - - Args: - train_task: PolicyTrainTask used to train the policy network. - n_eval_batches: Number of batches per evaluation. - head_selector: Layer to apply to the network output to select the value - head. Only needed in multitask training. - """ - self._train_task = train_task - self._policy_dist = train_task.policy_distribution - labeled_data = map(self._eval_batch, train_task.trajectory_batch_stream) - sample_batch = self._eval_batch( - next(train_task.trajectory_batch_stream), shape_only=True - ) - # TODO(pkozakowski): Implement more metrics. - metrics = { - 'policy_entropy': self.entropy_metric, - } - metrics.update(self.advantage_metrics) - metrics.update(self.weight_metrics) - metrics = { - name: tl.Serial(head_selector, metric) - for (name, metric) in metrics.items() - } - (metric_names, metric_layers) = zip(*metrics.items()) - # Select the appropriate head for evaluation. 
- super().__init__( - labeled_data, metric_layers, - sample_batch=sample_batch, - metric_names=metric_names, - n_eval_batches=n_eval_batches, - ) - - def _eval_batch(self, trajectory_batch, shape_only=False): - advantages = self._train_task.calculate_advantages( - trajectory_batch, shape_only=shape_only - ) - (observations, actions, mask) = self._train_task.trim_and_mask_batch( - trajectory_batch, advantages - ) - return (observations, actions, advantages, mask) - - @property - def entropy_metric(self): - def Entropy(policy_inputs, actions, advantages, mask): - del actions, advantages, mask - return jnp.mean(self._policy_dist.entropy(policy_inputs)) - return tl.Fn('Entropy', Entropy) - - @property - def advantage_metrics(self): - def make_metric(aggregate_fn): # pylint: disable=invalid-name - def AdvantageMetric(policy_inputs, actions, advantages, mask): - del policy_inputs, actions, mask - return aggregate_fn(advantages) - return tl.Fn('AdvantageMetric', AdvantageMetric) - return { - 'advantage_' + name: make_metric(fn) for (name, fn) in [ - ('mean', jnp.mean), - ('std', jnp.std), - ] - } - - @property - def weight_metrics(self): - def make_metric(aggregate_fn): # pylint: disable=invalid-name - def WeightMetric(policy_inputs, actions, advantages, mask): - del policy_inputs, actions, mask - weights = self._train_task.calculate_weights(advantages) - return aggregate_fn(weights) - return tl.Fn('WeightMetric', WeightMetric) - return { # pylint: disable=g-complex-comprehension - 'weight_' + name: make_metric(fn) for (name, fn) in [ - ('mean', jnp.mean), - ('std', jnp.std), - ('min', jnp.min), - ('max', jnp.max), - ] - } diff --git a/trax/rl/rl_layers.py b/trax/rl/rl_layers.py deleted file mode 100644 index e4c320280..000000000 --- a/trax/rl/rl_layers.py +++ /dev/null @@ -1,220 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A number of RL functions intended to be later wrapped as Trax layers. - - Wrapping happens with help of the function tl.Fn. -""" - -from trax.fastmath import numpy as jnp - - -def ValueLoss(values, returns, value_loss_coeff): - """Definition of the loss of the value function.""" - advantages = returns - values - l2_value_loss = jnp.mean(advantages**2) * value_loss_coeff - return l2_value_loss - - -def ExplainedVariance(values, returns): - """Definition of explained variance - an approach from OpenAI baselines.""" - assert returns.shape == values.shape, ( - f'returns.shape was {returns.shape} and values.shape was {values.shape}') - # TODO(henrykm): it would be good to explain the relation with the time dim. 
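  # (Added note: the lines below compute EV = 1 - Var(returns - values) /
  # Var(returns). EV of 1.0 means the value head predicts returns perfectly,
  # 0.0 means no better than predicting the mean return, negative means worse.)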
-  returns_variance = jnp.var(returns)
-  explained_variance = 1 - jnp.var(returns - values) / returns_variance
-  return explained_variance
-
-
-def PreferredMove(dist_inputs, sample):
-  """Definition of the preferred move."""
-  preferred_moves = sample(dist_inputs, temperature=0.0)
-  return jnp.mean(preferred_moves)
-
-
-def NewLogProbs(dist_inputs, actions, log_prob_fun):
-  """Given distribution and actions calculate log probs."""
-  new_log_probs = log_prob_fun(dist_inputs, actions)
-  return new_log_probs
-
-
-# TODO(henrykm): Clarify how jnp.mean is applied.
-def EntropyLoss(dist_inputs, distribution, coeff):
-  """Definition of the Entropy Layer."""
-  entropy_loss = distribution.entropy(dist_inputs) * coeff
-  return jnp.mean(entropy_loss)
-
-
-def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
-  """Probability Ratio from the PPO algorithm."""
-  # dist_inputs of the shape float32[128,1,18]
-  # actions of the shape int32[128,1]
-  # and old_log_probs of the shape float32[128,1]
-  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
-  assert new_log_probs.shape == old_log_probs.shape, (
-      f'new_log_probs.shape was {new_log_probs.shape} and '
-      f'old_log_probs.shape was {old_log_probs.shape}')
-  # The ratio between new_probs and old_probs expressed
-  # using log_probs and exponentiation
-  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
-  return probs_ratio
-
-
-def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
-  """Approximate KL divergence between the old and new policy, from PPO."""
-  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
-  assert new_log_probs.shape == old_log_probs.shape, (
-      f'new_log_probs.shape was {new_log_probs.shape} and '
-      f'old_log_probs.shape was {old_log_probs.shape}')
-  approximate_kl_divergence = 0.5 * \
-      jnp.mean(new_log_probs - old_log_probs) ** 2
-  return approximate_kl_divergence
-
-
-def UnclippedObjective(probs_ratio, advantages):
-  """Unclipped Objective from the PPO algorithm."""
-  assert probs_ratio.shape == advantages.shape, (
-      f'probs_ratio.shape was {probs_ratio.shape} and '
-      f'advantages.shape was {advantages.shape}')
-  unclipped_objective = probs_ratio * advantages
-  return unclipped_objective
-
-
-def ClippedObjective(probs_ratio, advantages, epsilon):
-  """Clipped Objective from the PPO algorithm."""
-  assert probs_ratio.shape == advantages.shape, (
-      f'probs_ratio.shape was {probs_ratio.shape} and '
-      f'advantages.shape was {advantages.shape}')
-  clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
-                               1 + epsilon) * advantages
-  assert probs_ratio.shape == clipped_objective.shape, (
-      f'probs_ratio.shape was {probs_ratio.shape} and '
-      f'clipped_objective.shape was {clipped_objective.shape}')
-  return clipped_objective
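For intuition, a tiny standalone sketch of the two objectives just defined (NumPy only; epsilon and all numbers are made up for illustration):

import numpy as np

epsilon = 0.2
# Hypothetical per-timestep probability ratios pi_new / pi_old and advantages.
probs_ratio = np.array([0.5, 1.0, 1.5])
advantages = np.array([1.0, 1.0, -1.0])

unclipped = probs_ratio * advantages  # [0.5, 1.0, -1.5]
clipped = np.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
# clipped == [0.8, 1.0, -1.2]: the ratio is held inside [0.8, 1.2].
ppo = np.minimum(unclipped, clipped)  # pessimistic choice: [0.5, 1.0, -1.5]
print(ppo)

Taking the elementwise minimum is what removes the incentive to move the ratio far outside the trust region: an over-optimistic unclipped term can never win.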
-
-
-def PPOObjective(dist_inputs, values, returns, dones, rewards,
-                 actions, old_log_probs, log_prob_fun, epsilon,
-                 normalize_advantages):
-  """PPO Objective."""
-  # dist_inputs of the shape float32[128,1,18]
-  # values of the shape float32[128,1,1]
-  # returns of the shape float32[128,1,1]
-  # dones of the shape float32[128,1,1]
-  # rewards of the shape int32[128,1,1]
-  # actions of the shape int32[128,1]
-  # and old_log_probs of the shape float32[128,1]
-  returns = returns.squeeze(axis=2)
-  values = values.squeeze(axis=2)
-  dones = dones.squeeze(axis=2)
-  rewards = rewards.squeeze(axis=2)
-  assert rewards.shape == dones.shape, (
-      f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
-  assert dones.shape == values.shape, (
-      f'dones.shape was {dones.shape} and values.shape was {values.shape}')
-  assert returns.shape == values.shape, (
-      f'returns.shape was {returns.shape} and values.shape was {values.shape}')
-  assert returns.shape == old_log_probs.shape, (
-      f'returns.shape was {returns.shape} and '
-      f'old_log_probs.shape was {old_log_probs.shape}')
-
-  probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
-  assert probs_ratio.shape == old_log_probs.shape, (
-      f'probs_ratio.shape was {probs_ratio.shape} and '
-      f'old_log_probs.shape was {old_log_probs.shape}')
-
-  # jaxified versions of
-  # returns[dones] = rewards[dones]
-  # values[dones] = 0
-  returns = jnp.where(dones, rewards, returns)
-  values = jnp.where(dones, jnp.zeros_like(values), values)
-  advantages = returns - values
-  if normalize_advantages:
-    advantages = advantages - jnp.mean(advantages)
-    advantages /= jnp.std(advantages) + 1e-8
-  assert old_log_probs.shape == advantages.shape, (
-      f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape '
-      f'was {advantages.shape}')
-
-  unclipped_objective = UnclippedObjective(probs_ratio, advantages)
-  assert unclipped_objective.shape == advantages.shape, (
-      f'unclipped_objective.shape was {unclipped_objective.shape} and '
-      f'advantages.shape was {advantages.shape}')
-
-  clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
-  assert clipped_objective.shape == advantages.shape, (
-      f'clipped_objective.shape was {clipped_objective.shape} and '
-      f'advantages.shape was {advantages.shape}')
-
-  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
-  assert ppo_objective.shape == advantages.shape, (
-      f'ppo_objective.shape was {ppo_objective.shape} and '
-      f'advantages.shape was {advantages.shape}')
-
-  return ppo_objective
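Both PPOObjective above and A2CObjective below reset bootstrapped targets at episode boundaries with the same jnp.where pattern, a JAX-friendly stand-in for boolean-mask assignment. A minimal standalone illustration (NumPy, toy values):

import numpy as np

dones = np.array([False, True, False])
rewards = np.array([0.0, 1.0, 0.0])
returns = np.array([2.0, 5.0, 3.0])
values = np.array([1.5, 4.0, 2.5])

# Where an episode ended, the target collapses to the final reward and the
# baseline to zero; elsewhere both pass through unchanged.
returns = np.where(dones, rewards, returns)               # [2.0, 1.0, 3.0]
values = np.where(dones, np.zeros_like(values), values)   # [1.5, 0.0, 2.5]
print(returns - values)                       # advantages [0.5, 1.0, 0.5]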
-
-
-def A2CObjective(dist_inputs, values, returns, dones, rewards,
-                 actions, mask, log_prob_fun, normalize_advantages):
-  """Definition of the Advantage Actor Critic (A2C) loss."""
-  # dist_inputs of the shape float32[128,1,18]
-  # values of the shape float32[128,1,1]
-  # returns of the shape float32[128,1,1]
-  # dones of the shape int32[128,1,1]
-  # actions of the shape int32[128,1]
-  # and mask of the shape float32[128,1]
-  # We have to squeeze values and returns, because we
-  # are planning to compute (return - values) * new_log_probs * mask
-  # and all of them should be of the same dimension
-  values = values.squeeze(axis=2)
-  returns = returns.squeeze(axis=2)
-  dones = dones.squeeze(axis=2)
-  rewards = rewards.squeeze(axis=2)
-  assert rewards.shape == dones.shape, (
-      f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
-  assert dones.shape == values.shape, (
-      f'dones.shape was {dones.shape} and values.shape was {values.shape}')
-  assert returns.shape == values.shape, (
-      f'returns.shape was {returns.shape} and values.shape was {values.shape}')
-  assert values.shape == mask.shape, (
-      f'values.shape was {values.shape} and mask.shape was {mask.shape}')
-  assert returns.shape[0] == dist_inputs.shape[0], (
-      f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
-      f'{dist_inputs.shape[0]}')
-
-  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
-  assert new_log_probs.shape == mask.shape, (
-      f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
-      f'{mask.shape}')
-
-  # jaxified versions of
-  # returns[dones] = rewards[dones]
-  # values[dones] = 0
-  returns = jnp.where(dones, rewards, returns)
-  values = jnp.where(dones, jnp.zeros_like(values), values)
-  advantages = returns - values
-  if normalize_advantages:
-    advantages = advantages - jnp.mean(advantages)
-    advantages /= jnp.std(advantages) + 1e-8
-  assert new_log_probs.shape == advantages.shape, (
-      f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape '
-      f'was {advantages.shape}')
-
-  # One of the motivations for the squeezes and assertions is to
-  # avoid [128,1] * [128,1,1] * [128] multiplications in the definition
-  # of the a2c objective - we insist on the same shapes
-  a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
-  return a2c_objective
diff --git a/trax/rl/serialization_utils.py b/trax/rl/serialization_utils.py
deleted file mode 100644
index 778c95872..000000000
--- a/trax/rl/serialization_utils.py
+++ /dev/null
@@ -1,437 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Utilities for serializing trajectories into discrete sequences."""
-
-import functools
-
-import gym
-import numpy as np
-
-from trax import layers as tl
-from trax.fastmath import numpy as jnp
-from trax.rl import space_serializer
-
-
-# pylint: disable=invalid-name
-# TODO(pkozakowski): Move the layers to trax.layers and remove this module.
-def Serialize(serializer):
-  """Layer that serializes a given array."""
-  def serialize(x):
-    (batch_size, length) = x.shape[:2]
-    shape_suffix = x.shape[2:]
-    x = jnp.reshape(x, (batch_size * length,) + shape_suffix)
-    x = serializer.serialize(x)
-    return jnp.reshape(x, (batch_size, -1, serializer.representation_length,))
-  return tl.Fn('Serialize', serialize)
-
-
-def Interleave():
-  """Layer that interleaves and flattens two serialized sequences.
-
-  The first sequence can be longer by 1 than the second one. This is so we can
-  interleave sequences of observations and actions, when there's 1 extra
-  observation at the end.
-
-  For serialized sequences [[x_1_1, ..., x_1_R1], ..., [x_L1_1, ..., x_L1_R1]]
-  and [[y_1_1, ..., y_1_R2], ..., [y_L2_1, ..., y_L2_R2]], where L1 = L2 + 1,
-  the result is [x_1_1, ..., x_1_R1, y_1_1, ..., y_1_R2, ..., x_L2_1, ...,
-  x_L2_R1, y_L2_1, ..., y_L2_R2, x_L1_1, ..., x_L1_R1] (batch dimension omitted
-  for clarity).
-
-  The layer inputs are a sequence pair of shapes (B, L1, R1) and (B, L2, R2),
-  where B is batch size, L* is the length of the sequence and R* is the
-  representation length of each element in the sequence.
-
-  Returns:
-    Layer that outputs the interleaved sequence of shape (B, L1 * R1 + L2 * R2).
-  """
-  def interleave(x, y):
-    (batch_size, _, _) = x.shape
-    (_, length, _) = y.shape
-    assert x.shape[1] in (length, length + 1)
-
-    reprs = jnp.concatenate((x[:, :length], y), axis=2)
-    reprs = jnp.reshape(reprs, (batch_size, -1))
-    remainder = jnp.reshape(x[:, length:], (batch_size, -1))
-    return jnp.concatenate((reprs, remainder), axis=1)
-  return tl.Fn('Interleave', interleave)
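To make the symbol layout concrete, here is a standalone NumPy sketch mirroring the interleave logic above on toy shapes (Deinterleave, defined next, is its inverse):

import numpy as np

# Two serialized observations (R1=2 symbols each) and one action (R2=1):
# shapes (B, L1, R1) = (1, 2, 2) and (B, L2, R2) = (1, 1, 1), so L1 = L2 + 1.
x = np.array([[[11, 12], [21, 22]]])   # observation representations
y = np.array([[[7]]])                  # action representations

length = y.shape[1]
reprs = np.concatenate((x[:, :length], y), axis=2).reshape(1, -1)  # [11 12 7]
remainder = x[:, length:].reshape(1, -1)                           # [21 22]
interleaved = np.concatenate((reprs, remainder), axis=1)
print(interleaved)  # [[11 12  7 21 22]] -- obs, act, ..., trailing obs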
- """ - def interleave(x, y): - (batch_size, _, _) = x.shape - (_, length, _) = y.shape - assert x.shape[1] in (length, length + 1) - - reprs = jnp.concatenate((x[:, :length], y), axis=2) - reprs = jnp.reshape(reprs, (batch_size, -1)) - remainder = jnp.reshape(x[:, length:], (batch_size, -1)) - return jnp.concatenate((reprs, remainder), axis=1) - return tl.Fn('Interleave', interleave) - - -def Deinterleave(x_size, y_size): - """Layer that does the inverse of Interleave.""" - def deinterleave(inputs): - reprs = inputs - (batch_size, length) = reprs.shape[:2] - shape_suffix = reprs.shape[2:] - remainder_length = length % (x_size + y_size) - if remainder_length > 0: - remainder = reprs[:, None, -remainder_length:] - reprs = reprs[:, :-remainder_length] - reprs = jnp.reshape(reprs, (batch_size, -1, x_size + y_size) + shape_suffix) - x_reprs = reprs[:, :, :x_size] - y_reprs = reprs[:, :, x_size:] - if remainder_length > 0: - x_reprs = jnp.concatenate((x_reprs, remainder), axis=1) - return (x_reprs, y_reprs) - return tl.Fn('Deinterleave', deinterleave, n_out=2) - - -def RepresentationMask(serializer): - """Upsamples a mask to cover the serialized representation.""" - # Trax enforces the mask to be of the same size as the target. Get rid of the - # extra dimensions. - def representation_mask(mask): - # mask shape (batch_size,4) - mask = jnp.amax(mask, axis=tuple(range(2, mask.ndim))) - # mask shape (batch_size,4) - mask = jnp.repeat( - mask[..., jnp.newaxis], - repeats=serializer.representation_length, - axis=2) - # mask shape (batch_size,4,representation_length) - return mask - return tl.Fn('RepresentationMask', representation_mask) - - -def SignificanceWeights(serializer, decay): - """Multiplies a binary mask with a symbol significance mask.""" - def significance_weights(mask): - # (repr,) -> (batch, length, repr) - # significance = [0, 1, 2] - significance = serializer.significance_map - assert significance.shape[0] == mask.shape[2] - # significance = batch_size * [0, 1, 2] - significance = jnp.repeat( - significance[np.newaxis, ...], repeats=mask.shape[0], axis=0) - # significance = batch_size * [0, 1, 2] * mask.shape[1] - significance = jnp.repeat( - significance[..., jnp.newaxis], repeats=mask.shape[1], axis=2) - # significance = batch_size * mask.shape[1] * [0, 1, 2] - significance = jnp.swapaxes(significance, 1, 2) - assert significance.shape == mask.shape - sig_weights = mask * decay ** significance - return sig_weights - return tl.Fn('SignificanceWeights', significance_weights) - - -class SerializedModel(tl.Serial): - """Wraps a world model in serialization machinery for training. - - The resulting model takes as input the observation and action sequences, - serializes them and interleaves into one sequence, which is fed into a given - autoregressive model. The resulting logit sequence is deinterleaved into - observations and actions, and the observation logits are returned together - with computed symbol significance weights. - - The model has a signature - (obs, act, obs, mask) -> (obs_logits, obs_repr, weights), where obs are - observations (the second occurrence is the target), act are actions, mask is - the observation mask, obs_logits are logits of the output observation - representation, obs_repr is the target observation representation and weights - are the target weights. - """ - - def __init__( - self, - seq_model, - observation_serializer, - action_serializer, - significance_decay, - mode='train', - ): - """Initializes SerializedModel. 
-
-    Args:
-      seq_model: Trax autoregressive model taking as input a sequence of
-        symbols and outputting a sequence of symbol logits.
-      observation_serializer: Serializer to use for observations.
-      action_serializer: Serializer to use for actions.
-      significance_decay: Float from (0, 1) for exponential weighting of
-        symbols in the representation.
-      mode: 'train' or 'eval'.
-    """
-    assert mode in ('train', 'eval')
-    weigh_by_significance = [
-        # (mask,)
-        RepresentationMask(serializer=observation_serializer),
-        # (repr_mask)
-        SignificanceWeights(serializer=observation_serializer,
-                            decay=significance_decay),
-        # (mask, sig_weights)
-    ]
-    super().__init__(
-        # (obs, act, obs, mask)
-        tl.Parallel(Serialize(serializer=observation_serializer),
-                    Serialize(serializer=action_serializer),
-                    Serialize(serializer=observation_serializer)),
-        # (obs_repr, act_repr, obs_repr, mask)
-        Interleave(),
-        # (obs_act_repr, obs_repr, mask)
-        seq_model(mode=mode),
-        # (obs_act_logits, obs_repr, mask)
-        Deinterleave(x_size=observation_serializer.representation_length,
-                     y_size=action_serializer.representation_length),
-        # (obs_logits, act_logits, obs_repr, mask)
-        tl.Parallel(None, tl.Drop(), None, weigh_by_significance),
-        # (obs_logits, obs_repr, weights)
-    )
-
-    self._seq_model = seq_model
-    self._observation_serializer = observation_serializer
-    self._action_serializer = action_serializer
-
-  @property
-  def observation_serializer(self):
-    return self._observation_serializer
-
-  @property
-  def action_serializer(self):
-    return self._action_serializer
-
-  def make_predict_model(self):
-    """Returns a predict-mode model of the same architecture."""
-    return self._seq_model(mode='predict')
-
-  @property
-  def seq_model_weights(self):
-    """Extracts the weights of the underlying sequence model."""
-    return self.weights[2]
-
-  @property
-  def seq_model_state(self):
-    """Extracts the state of the underlying sequence model."""
-    return self.state[2]
-
-
-def TimeSeriesModel(
-    seq_model,
-    low=0.0,
-    high=1.0,
-    precision=2,
-    vocab_size=64,
-    significance_decay=0.7,
-    mode='train',
-):
-  """Simplified constructor for SerializedModel, for time series prediction."""
-  # Model scalar time series.
-  obs_srl = space_serializer.BoxSpaceSerializer(
-      space=gym.spaces.Box(shape=(), low=low, high=high),
-      vocab_size=vocab_size,
-      precision=precision,
-  )
-  # Artifact of the fact that we must provide some actions.
-  # TODO(pkozakowski): Remove this requirement.
-  act_srl = space_serializer.DiscreteSpaceSerializer(
-      space=gym.spaces.Discrete(n=1), vocab_size=1
-  )
-  seq_model = functools.partial(seq_model, vocab_size=vocab_size)
-  return SerializedModel(seq_model, obs_srl, act_srl, significance_decay, mode)
-
-
-def RawPolicy(seq_model, n_controls, n_actions):
-  """Wraps a sequence model in a policy interface.
-
-  The resulting model takes as input observation and action sequences, but only
-  uses the observations. Adds output heads for action logits and value
-  predictions.
-
-  Args:
-    seq_model: Trax sequence model taking as input and outputting a sequence of
-      continuous vectors.
-    n_controls: Number of controls.
-    n_actions: Number of action categories in each control.
-
-  Returns:
-    A model of signature (obs, act) -> (act_logits, values), with shapes:
-      obs: (batch_size, length + 1, obs_depth)
-      act: (batch_size, length, n_controls)
-      act_logits: (batch_size, length, n_controls, n_actions)
-      values: (batch_size, length)
-  """
-
-  def SplitControls():  # pylint: disable=invalid-name
-    """Splits logits for actions in different controls."""
-    def f(x):
-      return jnp.reshape(x, x.shape[:2] + (n_controls, n_actions))
-    return tl.Fn('SplitControls', f)
-
-  action_head = [
-      # Predict all action logits at the same time.
-      tl.Dense(n_controls * n_actions),
-      # Then group them into separate controls, adding a new dimension.
-      SplitControls(),
-      tl.LogSoftmax(),
-  ]
-  return tl.Serial(              # (obs, act)
-      tl.Select([0], n_in=2),    # (obs,)
-      seq_model,                 # (obs_hidden,)
-      tl.Dup(),                  # (obs_hidden, obs_hidden)
-      tl.Parallel(action_head, [tl.Dense(1),
-                                tl.Flatten()])  # (act_logits, values)
-  )
-
-
-def substitute_inner_policy_raw(raw_policy, inner_policy):  # pylint: disable=invalid-name
-  """Substitutes the weights/state of the inner model in a RawPolicy."""
-  return raw_policy[:1] + [inner_policy] + raw_policy[2:]
-
-
-def SerializedPolicy(
-    seq_model, n_controls, n_actions, observation_serializer, action_serializer
-):
-  """Wraps a policy in serialization machinery for training.
-
-  The resulting model takes as input observation and action sequences, and
-  serializes them into one sequence similar to SerializedModel, before passing
-  to the given sequence model. Adds output heads for action logits and value
-  predictions.
-
-  Args:
-    seq_model: Trax sequence model taking as input a sequence of symbols and
-      outputting a sequence of continuous vectors.
-    n_controls: Number of controls.
-    n_actions: Number of action categories in each control.
-    observation_serializer: Serializer to use for observations.
-    action_serializer: Serializer to use for actions.
-
-  Returns:
-    A model of signature (obs, act) -> (act_logits, values), same as in
-    RawPolicy.
-  """
-  if action_serializer.representation_length != n_controls:
-    raise ValueError(
-        'Action symbols should correspond 1-1 to controls, but got {} '
-        'controls and {} symbols.'.format(
-            n_controls, action_serializer.representation_length
-        )
-    )
-
-  def FirstSymbol():
-    return tl.Fn('FirstSymbol', lambda x: x[:, :, 0])
-
-  def PadRight(n_to_pad):
-    def pad_right(x):
-      pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2)
-      return jnp.pad(
-          x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
-    return tl.Fn(f'PadRight({n_to_pad})', pad_right)
-
-  action_head = [
-      tl.Dense(n_actions),
-      tl.LogSoftmax(),
-  ]
-  value_head = [
-      # Take just the vectors corresponding to the first action symbol.
-      FirstSymbol(),
-      # Predict values.
-      tl.Dense(1),
-      # Get rid of the singleton dimension.
-      tl.Flatten(),
-  ]
-  return tl.Serial(
-      # (obs, act)
-      tl.Parallel(Serialize(observation_serializer),
-                  Serialize(action_serializer)),
-      # (obs_repr, act_repr)
-      Interleave(),
-      # (obs_act_repr,)
-
-      # Add one dummy action to the right - we'll use the output at its first
-      # symbol to predict the value for the last observation.
-      PadRight(action_serializer.representation_length),
-
-      # Shift one symbol to the right, so we predict the n-th action symbol
-      # based on action symbols 1..n-1 instead of 1..n.
-      tl.ShiftRight(),
-      seq_model,
-      # (obs_act_hidden,)
-      Deinterleave(observation_serializer.representation_length,
-                   action_serializer.representation_length),
-      # (obs_hidden, act_hidden)
-      tl.Select([1, 1]),
-      # (act_hidden, act_hidden)
-      tl.Parallel(action_head, value_head),
-      # (act_logits, values)
-  )
-
-
-def substitute_inner_policy_serialized(serialized_policy, inner_policy):  # pylint: disable=invalid-name
-  """Substitutes the weights/state of the inner model in a SerializedPolicy."""
-  return serialized_policy[:4] + [inner_policy] + serialized_policy[5:]
-
-
-def analyze_action_space(action_space):  # pylint: disable=invalid-name
-  """Returns the number of controls and actions for an action space."""
-  assert isinstance(
-      action_space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)
-  ), 'Action space expected to be Discrete or MultiDiscrete, got {}.'.format(
-      type(action_space)
-  )
-  if isinstance(action_space, gym.spaces.Discrete):
-    n_actions = action_space.n
-    n_controls = 1
-  else:
-    (n_controls,) = action_space.nvec.shape
-    assert n_controls > 0
-    assert np.min(action_space.nvec) == np.max(action_space.nvec), (
-        'Every control must have the same number of actions.'
-    )
-    n_actions = action_space.nvec[0]
-  return (n_controls, n_actions)
-
-
-def wrap_policy(seq_model, observation_space, action_space, vocab_size):  # pylint: disable=invalid-name
-  """Wraps a sequence model in either RawPolicy or SerializedPolicy.
-
-  Args:
-    seq_model: Trax sequence model.
-    observation_space: Gym observation space.
-    action_space: Gym action space.
-    vocab_size: Either the number of symbols for a serialized policy, or None.
-
-  Returns:
-    RawPolicy if vocab_size is None, else SerializedPolicy.
-  """
-  (n_controls, n_actions) = analyze_action_space(action_space)
-  if vocab_size is None:
-    policy_wrapper = RawPolicy
-  else:
-    obs_serializer = space_serializer.create(observation_space, vocab_size)
-    act_serializer = space_serializer.create(action_space, vocab_size)
-    policy_wrapper = functools.partial(SerializedPolicy,
-                                       observation_serializer=obs_serializer,
-                                       action_serializer=act_serializer)
-  return policy_wrapper(seq_model, n_controls, n_actions)
-
-
-def substitute_inner_policy(wrapped_policy, inner_policy, vocab_size):  # pylint: disable=invalid-name
-  """Substitutes the inner weights/state in a {Raw,Serialized}Policy.
-
-  Args:
-    wrapped_policy (pytree): Weights or state of a wrapped policy.
-    inner_policy (pytree): Weights or state of an inner policy.
-    vocab_size (int or None): Vocabulary size of a serialized policy, or None
-      in case of a raw policy.
-
-  Returns:
-    New weights or state of wrapped_policy, with the inner weights/state
-    copied from inner_policy.
-  """
-  if vocab_size is None:
-    substitute_fn = substitute_inner_policy_raw
-  else:
-    substitute_fn = substitute_inner_policy_serialized
-  return substitute_fn(wrapped_policy, inner_policy)
diff --git a/trax/rl/serialization_utils_test.py b/trax/rl/serialization_utils_test.py
deleted file mode 100644
index e7fd79c8d..000000000
--- a/trax/rl/serialization_utils_test.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.rl.serialization_utils.""" -import functools - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import gym -from jax import numpy as jnp -import numpy as np -from trax import models as trax_models -from trax import shapes -from trax import test_utils -from trax.data import inputs as trax_input -from trax.layers import base as layers_base -from trax.models import transformer -from trax.rl import serialization_utils -from trax.rl import space_serializer -from trax.supervised import trainer_lib - - -# pylint: disable=invalid-name -def TestModel(extra_dim, mode='train'): - """Dummy sequence model for testing.""" - del mode - def f(inputs): - # Cast the input to float32 - this is for simulating discrete-input models. - inputs = inputs.astype(np.float32) - # Add an extra dimension if requested, e.g. the logit dimension for output - # symbols. - if extra_dim is not None: - return jnp.broadcast_to(inputs[:, :, None], inputs.shape + (extra_dim,)) - else: - return inputs - return layers_base.Fn('TestModel', f) - # pylint: enable=invalid-name - - -def signal_inputs(seq_len, batch_size, depth=1): - def stream_fn(num_devices): - del num_devices - while True: - x = np.random.uniform(size=(batch_size, seq_len, depth)) - y = np.random.uniform(size=(batch_size, seq_len, depth)) - mask = np.ones_like(x).astype(np.float32) - yield (x, y, x, mask) - - return trax_input.Inputs( - train_stream=stream_fn, - eval_stream=stream_fn, - ) - - -class SerializationTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self._serializer = space_serializer.create( - gym.spaces.Discrete(2), vocab_size=2 - ) - self._repr_length = 100 - self._serialization_utils_kwargs = { - 'observation_serializer': self._serializer, - 'action_serializer': self._serializer, - 'representation_length': self._repr_length, - } - test_utils.ensure_flag('test_tmpdir') - - def test_serialized_model_discrete(self): - vocab_size = 3 - obs = np.array([[[0, 1], [1, 1], [1, 0], [0, 0]]]) - act = np.array([[1, 0, 0]]) - mask = np.array([[1, 1, 1, 0]]) - - test_model_inputs = [] - - # pylint: disable=invalid-name - def TestModelSavingInputs(mode): - del mode - def f(inputs): - # Save the inputs for a later check. - test_model_inputs.append(inputs) - # Change type to np.float32 and add the logit dimension. 
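        # (Added note: broadcast_to copies each input symbol across a new
        # trailing axis of size vocab_size, so along the logit axis min and
        # max both equal the symbol itself - the min/max assertions further
        # down depend on exactly this property.)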
- return jnp.broadcast_to( - inputs.astype(np.float32)[:, :, None], inputs.shape + (vocab_size,) - ) - return layers_base.Fn('TestModelSavingInputs', f) - # pylint: enable=invalid-name - - obs_serializer = space_serializer.create( - gym.spaces.MultiDiscrete([2, 2]), vocab_size=vocab_size - ) - act_serializer = space_serializer.create( - gym.spaces.Discrete(2), vocab_size=vocab_size - ) - serialized_model = serialization_utils.SerializedModel( - TestModelSavingInputs, # pylint: disable=no-value-for-parameter - observation_serializer=obs_serializer, - action_serializer=act_serializer, - significance_decay=0.9, - ) - - example = (obs, act, obs, mask) - serialized_model.init(shapes.signature(example)) - - (obs_logits, obs_repr, weights) = serialized_model(example) - # Check that the model has been called with the correct input. - np.testing.assert_array_equal( - # The model is called multiple times for determining shapes etc. - # Check the last saved input - that should be the actual concrete array - # calculated during the forward pass. - test_model_inputs[-1], - # Should be serialized observations and actions interleaved. - [[0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]], - ) - # Check the output shape. - self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,)) - # Check that obs_logits are the same as obs_repr, just broadcasted over the - # logit dimension. - np.testing.assert_array_equal(np.min(obs_logits, axis=-1), obs_repr) - np.testing.assert_array_equal(np.max(obs_logits, axis=-1), obs_repr) - # Check that the observations are correct. - np.testing.assert_array_equal(obs_repr, obs) - # Check weights. - np.testing.assert_array_equal( - weights, - [[[1., 1.], [1., 1.], [1., 1.], [0., 0.]]], - ) - - def test_train_model_with_serialization(self): - # Serializer handles discretization of the data. 
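    # (Added note: with vocab_size=16 and precision=2 as set below, each float
    # in [0, 16) is rescaled to [0, 1) and written as two base-16 digits, so
    # one observation of 2 time series costs 2 * 2 = 4 symbols per step.)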
- precision = 2 - number_of_time_series = 2 - vocab_size = 16 - srl = space_serializer.BoxSpaceSerializer( - space=gym.spaces.Box(shape=(number_of_time_series,), - low=0.0, high=16.0), - vocab_size=vocab_size, - precision=precision, - ) - - def model(mode): - del mode - return serialization_utils.SerializedModel( - functools.partial( - trax_models.TransformerLM, - vocab_size=vocab_size, - d_model=16, - d_ff=8, - n_layers=1, - n_heads=1, - ), - observation_serializer=srl, - action_serializer=srl, - significance_decay=0.9, - ) - - output_dir = self.create_tempdir().full_path - state = trainer_lib.train( - output_dir=output_dir, - model=model, - inputs=functools.partial(signal_inputs, seq_len=5, - batch_size=64, depth=number_of_time_series), - steps=2) - self.assertEqual(2, state.step) - - def test_serialized_model_continuous(self): - precision = 3 - gin.bind_parameter('BoxSpaceSerializer.precision', precision) - - vocab_size = 32 - obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]]) - act = np.array([[0, 1, 0]]) - mask = np.array([[1, 1, 1, 0]]) - - obs_serializer = space_serializer.create( - gym.spaces.Box(shape=(2,), low=-2, high=2), vocab_size=vocab_size - ) - act_serializer = space_serializer.create( - gym.spaces.Discrete(2), vocab_size=vocab_size - ) - serialized_model = serialization_utils.SerializedModel( - functools.partial(TestModel, extra_dim=vocab_size), - observation_serializer=obs_serializer, - action_serializer=act_serializer, - significance_decay=0.9, - ) - - example = (obs, act, obs, mask) - serialized_model.init(shapes.signature(example)) - - (obs_logits, obs_repr, weights) = serialized_model(example) - self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,)) - self.assertEqual( - obs_repr.shape, (1, obs.shape[1], obs.shape[2] * precision) - ) - self.assertEqual(obs_repr.shape, weights.shape) - - def test_serialized_model_extracts_seq_model_weights_and_state(self): - vocab_size = 3 - - seq_model_fn = functools.partial( - transformer.TransformerLM, - vocab_size=vocab_size, - d_model=2, - d_ff=2, - n_layers=0, - ) - seq_model = seq_model_fn(mode='eval') - obs_serializer = space_serializer.create( - gym.spaces.Discrete(2), vocab_size=vocab_size - ) - act_serializer = space_serializer.create( - gym.spaces.Discrete(2), vocab_size=vocab_size - ) - serialized_model = serialization_utils.SerializedModel( - seq_model_fn, - observation_serializer=obs_serializer, - action_serializer=act_serializer, - significance_decay=0.9, - ) - - obs_sig = shapes.ShapeDtype((1, 2)) - act_sig = shapes.ShapeDtype((1, 1)) - serialized_model.init(input_signature=(obs_sig, act_sig, obs_sig, obs_sig)) - seq_model.weights = serialized_model.seq_model_weights - seq_model.state = serialized_model.seq_model_state - # Run the model to check if the weights and state have correct structure. 
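    # (Added note: seq_model_weights and seq_model_state used above resolve to
    # self.weights[2] and self.state[2], because the wrapped sequence model is
    # the third sublayer of SerializedModel's Serial: Parallel(serializers),
    # Interleave, seq_model, Deinterleave, Parallel.)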
-    seq_model(jnp.array([[0]]))
-
-  @parameterized.named_parameters(('raw', None), ('serialized', 32))
-  def test_wrapped_policy_continuous(self, vocab_size):
-    precision = 3
-    n_controls = 2
-    n_actions = 4
-    gin.bind_parameter('BoxSpaceSerializer.precision', precision)
-
-    obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
-    act = np.array([[[0, 1], [2, 0], [1, 3]]])
-
-    wrapped_policy = serialization_utils.wrap_policy(
-        TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
-        observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2),
-        action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
-        vocab_size=vocab_size,
-    )
-
-    example = (obs, act)
-    wrapped_policy.init(shapes.signature(example))
-    (act_logits, values) = wrapped_policy(example)
-    self.assertEqual(act_logits.shape, obs.shape[:2] + (n_controls, n_actions))
-    self.assertEqual(values.shape, obs.shape[:2])
-
-  def test_analyzes_discrete_action_space(self):
-    space = gym.spaces.Discrete(n=5)
-    (n_controls, n_actions) = serialization_utils.analyze_action_space(space)
-    self.assertEqual(n_controls, 1)
-    self.assertEqual(n_actions, 5)
-
-  def test_analyzes_multi_discrete_action_space_with_equal_categories(self):
-    space = gym.spaces.MultiDiscrete(nvec=(3, 3))
-    (n_controls, n_actions) = serialization_utils.analyze_action_space(space)
-    self.assertEqual(n_controls, 2)
-    self.assertEqual(n_actions, 3)
-
-  def test_doesnt_analyze_multi_discrete_action_space_with_unequal_categories(
-      self
-  ):
-    space = gym.spaces.MultiDiscrete(nvec=(2, 3))
-    with self.assertRaises(AssertionError):
-      serialization_utils.analyze_action_space(space)
-
-  def test_doesnt_analyze_box_action_space(self):
-    space = gym.spaces.Box(shape=(2, 3), low=0, high=1)
-    with self.assertRaises(AssertionError):
-      serialization_utils.analyze_action_space(space)
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/rl/space_serializer.py b/trax/rl/space_serializer.py
deleted file mode 100644
index 316060de8..000000000
--- a/trax/rl/space_serializer.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Serialization of elements of Gym spaces into discrete sequences."""
-import copy
-
-from absl import logging
-import gin
-import gym
-from jax import numpy as np
-
-
-class SpaceSerializer:
-  """Base class for Gym space serializers.
-
-  Attrs:
-    space_type: (type) Gym space class that this SpaceSerializer corresponds
-      to. Should be defined in subclasses.
-    representation_length: (int) Number of symbols in the representation of
-      every element of the space.
-    significance_map: (np.ndarray) Integer array of the same size as the
-      discrete representation, where elements describe the significance of
-      symbols, e.g. in fixed-precision encoding. 0 is the most significant
-      symbol, 1 the second most significant etc.
- """ - - space_type = None - representation_length = None - significance_map = None - - def __init__(self, space, vocab_size): - """Creates a SpaceSerializer. - - Subclasses should retain the signature. - - Args: - space: (gym.Space) Gym space of type self.space_type. - vocab_size: (int) Number of symbols in the vocabulary. - """ - assert isinstance(space, self.space_type) - self._space = space - self._vocab_size = vocab_size - - @property - def vocab_size(self): - return self._vocab_size - - def serialize(self, data): - """Serializes a batch of space elements into discrete sequences. - - Should be defined in subclasses. - - Args: - data: A batch of batch_size elements of the Gym space to be serialized. - - Returns: - int32 array of shape (batch_size, self.representation_length). - """ - raise NotImplementedError - - def deserialize(self, representation): - """Deserializes a batch of discrete sequences into space elements. - - Should be defined in subclasses. - - Args: - representation: int32 Numpy array of shape - (batch_size, self.representation_length) to be deserialized. - - Returns: - A batch of batch_size deserialized elements of the Gym space. - """ - raise NotImplementedError - - -def create(space, vocab_size): - """Creates a SpaceSerializer for the given Gym space.""" - return { - gym.spaces.Box: BoxSpaceSerializer, - gym.spaces.Discrete: DiscreteSpaceSerializer, - gym.spaces.MultiDiscrete: MultiDiscreteSpaceSerializer, - }[type(space)](space, vocab_size) - - -@gin.configurable(denylist=['space', 'vocab_size']) -class BoxSpaceSerializer(SpaceSerializer): - """Serializer for gym.spaces.Box. - - Assumes that the space is bounded. Internally rescales it to the [0, 1] - interval and uses a fixed-precision encoding. - """ - - space_type = gym.spaces.Box - - def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)): - self._precision = precision - - # Some gym envs (e.g. CartPole) have unreasonably high bounds for - # observations. We clip so we can represent them. - bounded_space = copy.copy(space) - (min_low, max_high) = max_range - bounded_space.low = np.maximum(space.low, min_low) - bounded_space.high = np.minimum(space.high, max_high) - if (not np.allclose(bounded_space.low, space.low) or - not np.allclose(bounded_space.high, space.high)): - logging.warning( - 'Space limits %s, %s out of bounds %s. Clipping to %s, %s.', - str(space.low), str(space.high), str(max_range), - str(bounded_space.low), str(bounded_space.high) - ) - - super().__init__(bounded_space, vocab_size) - - def serialize(self, data): - array = data - batch_size = array.shape[0] - array = (array - self._space.low) / (self._space.high - self._space.low) - array = np.clip(array, 0, 1) - digits = [] - for digit_index in range(-1, -self._precision - 1, -1): - threshold = self._vocab_size ** digit_index - digit = np.array(array / threshold).astype(np.int32) - # For the corner case of x == high. 
- digit = np.where(digit == self._vocab_size, digit - 1, digit) - digits.append(digit) - array -= digit * threshold - digits = np.stack(digits, axis=-1) - return np.reshape(digits, (batch_size, -1)) - - def deserialize(self, representation): - digits = representation - batch_size = digits.shape[0] - digits = np.reshape(digits, (batch_size, -1, self._precision)) - array = np.zeros(digits.shape[:-1]) - for digit_index_in_seq in range(self._precision): - digit_index = -digit_index_in_seq - 1 - array += self._vocab_size ** digit_index * digits[..., digit_index_in_seq] - array = np.reshape(array, (batch_size,) + self._space.shape) - return array * (self._space.high - self._space.low) + self._space.low - - @property - def representation_length(self): - return self._precision * self._space.low.size - - @property - def significance_map(self): - return np.reshape(np.broadcast_to( - np.arange(self._precision), self._space.shape + (self._precision,)), -1) - - -class DiscreteSpaceSerializer(SpaceSerializer): - """Serializer for gym.spaces.Discrete. - - Assumes that the size of the space fits in the number of symbols. - """ - - space_type = gym.spaces.Discrete - representation_length = 1 - - def __init__(self, space, vocab_size): - super().__init__(space, vocab_size) - assert space.n <= vocab_size, ( - 'Discrete space size should fit in the number of symbols.') - - def serialize(self, data): - return np.reshape(data, (-1, 1)).astype(np.int32) - - def deserialize(self, representation): - return np.reshape(representation, -1) - - @property - def significance_map(self): - return np.zeros(1, dtype=np.int32) - - -class MultiDiscreteSpaceSerializer(SpaceSerializer): - """Serializer for gym.spaces.MultiDiscrete. - - Assumes that the number of categories in each dimension fits in the number of - symbols. - """ - - space_type = gym.spaces.MultiDiscrete - - def __init__(self, space, vocab_size): - super().__init__(space, vocab_size) - assert np.max(space.nvec) <= vocab_size, ( - 'MultiDiscrete maximum number of categories should fit in the number ' - 'of symbols.' - ) - - def serialize(self, data): - return data.astype(np.int32) - - def deserialize(self, representation): - return representation - - @property - def representation_length(self): - return len(self._space.nvec) - - @property - def significance_map(self): - return np.zeros(self.representation_length, dtype=np.int32) diff --git a/trax/rl/space_serializer_test.py b/trax/rl/space_serializer_test.py deleted file mode 100644 index 6e46f2ba1..000000000 --- a/trax/rl/space_serializer_test.py +++ /dev/null @@ -1,157 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
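Before the tests, a standalone sketch of the fixed-precision round trip that BoxSpaceSerializer above implements (pure NumPy; vocab_size=10 and the value 0.37 are toy choices for readability):

import numpy as np

vocab_size, precision = 10, 2
low, high = 0.0, 1.0
x = np.array([0.37])

# serialize: write the rescaled value as `precision` base-`vocab_size` digits.
scaled = (x - low) / (high - low)
digits = []
for digit_index in range(-1, -precision - 1, -1):
    threshold = vocab_size ** digit_index
    digit = (scaled / threshold).astype(np.int32)
    digits.append(digit)
    scaled = scaled - digit * threshold
print(digits)  # [array([3]), array([7])] -- 0.37 becomes digits 3, 7

# deserialize: sum digit * vocab_size**(-position-1), then undo the rescaling.
decoded = sum(d * vocab_size ** (-i - 1) for i, d in enumerate(digits))
print(decoded * (high - low) + low)  # ~[0.37], up to float rounding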
- -"""Tests for trax.rl.space_serializer.""" -import gin -import gym -import numpy as np -from tensorflow import test -from trax.rl import space_serializer - - -class BoxSpaceSerializerTest(test.TestCase): - - def _make_space_and_serializer( - self, low=-10, high=10, shape=(2,), - # Weird vocab_size to test that it doesn't only work with powers of 2. - vocab_size=257, - # Enough precision to represent float32s accurately. - precision=4, - ): - gin.bind_parameter('BoxSpaceSerializer.precision', precision) - space = gym.spaces.Box(low=low, high=high, shape=shape) - serializer = space_serializer.create(space, vocab_size=vocab_size) - return (space, serializer) - - def _sample_batch(self, space): - return np.reshape(space.sample(), (1,) + space.shape) - - def test_representation_length(self): - (space, serializer) = self._make_space_and_serializer() - input_array = self._sample_batch(space) - representation = serializer.serialize(input_array) - self.assertEqual( - representation.shape, (1, serializer.representation_length)) - - def test_commutes(self): - (space, serializer) = self._make_space_and_serializer() - input_array = self._sample_batch(space) - representation = serializer.serialize(input_array) - output_array = serializer.deserialize(representation) - # Testing till 5 decimals to reduce flakyness. - np.testing.assert_array_almost_equal(input_array, output_array, decimal=5) - - def test_representation_changes(self): - (space, serializer) = self._make_space_and_serializer() - array1 = self._sample_batch(space) - array2 = -array1 - (repr1, repr2) = tuple(map(serializer.serialize, (array1, array2))) - self.assertFalse(np.array_equal(repr1, repr2)) - - def test_bounds_space(self): - gin.bind_parameter('BoxSpaceSerializer.max_range', (-10.0, 10.0)) - (_, serializer) = self._make_space_and_serializer( - # Too wide range to represent, need to clip. 
- low=-1e18, high=1e18, - shape=(1,)) - input_array = np.array([[1.2345]]) - representation = serializer.serialize(input_array) - output_array = serializer.deserialize(representation) - np.testing.assert_array_almost_equal(input_array, output_array) - - def test_significance_map(self): - (_, serializer) = self._make_space_and_serializer(shape=(2,)) - np.testing.assert_array_equal( - serializer.significance_map, [0, 1, 2, 3, 0, 1, 2, 3]) - - def test_serializes_boundaries(self): - vocab_size = 256 - precision = 4 - (_, serializer) = self._make_space_and_serializer( - low=-1, high=1, shape=(1,), vocab_size=vocab_size, precision=precision, - ) - input_array = np.array([[-1, 1]]) - representation = serializer.serialize(input_array) - np.testing.assert_array_equal( - representation, [[0] * precision + [vocab_size - 1] * precision] - ) - - -class DiscreteSpaceSerializerTest(test.TestCase): - - def setUp(self): - super().setUp() - self._space = gym.spaces.Discrete(n=2) - self._serializer = space_serializer.create(self._space, vocab_size=2) - - def _sample_batch(self): - return np.reshape(self._space.sample(), (1,) + self._space.shape) - - def test_representation_length(self): - input_array = self._sample_batch() - representation = self._serializer.serialize(input_array) - self.assertEqual( - representation.shape, (1, self._serializer.representation_length)) - - def test_commutes(self): - input_array = self._sample_batch() - representation = self._serializer.serialize(input_array) - output_array = self._serializer.deserialize(representation) - np.testing.assert_array_almost_equal(input_array, output_array) - - def test_representation_changes(self): - array1 = self._sample_batch() - array2 = 1 - array1 - (repr1, repr2) = tuple(map(self._serializer.serialize, (array1, array2))) - self.assertFalse(np.array_equal(repr1, repr2)) - - def test_significance_map(self): - np.testing.assert_array_equal(self._serializer.significance_map, [0]) - - -class MultiDiscreteSpaceSerializerTest(test.TestCase): - - def setUp(self): - super().setUp() - self._space = gym.spaces.MultiDiscrete(nvec=[2, 2]) - self._serializer = space_serializer.create(self._space, vocab_size=2) - - def _sample_batch(self): - return np.reshape(self._space.sample(), (1,) + self._space.shape) - - def test_representation_length(self): - input_array = self._sample_batch() - representation = self._serializer.serialize(input_array) - self.assertEqual( - representation.shape, (1, self._serializer.representation_length)) - - def test_commutes(self): - input_array = self._sample_batch() - representation = self._serializer.serialize(input_array) - output_array = self._serializer.deserialize(representation) - np.testing.assert_array_almost_equal(input_array, output_array) - - def test_representation_changes(self): - array1 = self._sample_batch() - array2 = 1 - array1 - (repr1, repr2) = tuple(map(self._serializer.serialize, (array1, array2))) - self.assertFalse(np.array_equal(repr1, repr2)) - - def test_significance_map(self): - np.testing.assert_array_equal(self._serializer.significance_map, [0, 0]) - - -if __name__ == '__main__': - test.main() diff --git a/trax/rl/task.py b/trax/rl/task.py deleted file mode 100644 index 5a9b2342f..000000000 --- a/trax/rl/task.py +++ /dev/null @@ -1,874 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Classes for defining RL tasks in Trax."""
-
-import collections
-import os
-
-import gin
-import gym
-import numpy as np
-
-from trax import fastmath
-from trax.rl import advantages
-from trax.supervised import training
-
-
-
-# TimeStepBatch stores a single step in the trajectory of an RL run, or
-# a sequence of timesteps (trajectory slice), or a batch of such sequences.
-# Fields:
-# * `observation` at the beginning of the step
-# * `action` that was taken
-# * `reward` gotten when the action was taken (or None if action wasn't taken)
-# * `done` - whether the trajectory has finished in this step
-# * `mask` - padding mask
-# * `return_` - discounted return from this state (includes the current
-#   reward); `None` if it hasn't been computed yet
-# * `dist_inputs` - parameters of the policy distribution, stored by some
-#   RL algorithms
-# TODO(pkozakowski): Generalize `dist_inputs` to `agent_info` - a namedtuple
-# storing agent-specific data.
-TimeStepBatch = collections.namedtuple('TimeStepBatch', [
-    'observation',
-    'action',
-    'reward',
-    'done',
-    'mask',
-    'dist_inputs',
-    'env_info',
-    'return_',
-])
-
-
-# EnvInfo stores additional information returned by
-# `trax.rl.envs.SequenceDataEnv`. In those environments, one timestep
-# corresponds to one token in the sequence. While the environment is emitting
-# observation tokens, actions taken by the agent don't matter. Actions can also
-# span multiple tokens, but the discount should only be applied once.
-# Fields:
-# * `control_mask` - mask determining whether the last interaction was
-#   controlled, so whether the action performed by the agent mattered;
-#   can be used to mask policy and value loss; negation can be used to mask
-#   world model observation loss; defaults to 1 - all actions matter
-# * `discount_mask` - mask determining whether the discount should be applied
-#   to the current reward; defaults to 1 - all rewards are discounted
-EnvInfo = collections.namedtuple('EnvInfo', ['control_mask', 'discount_mask'])
-EnvInfo.__new__.__defaults__ = (1, 1)
-
-
-# `env_info` and `return_` can be omitted in `TimeStepBatch`.
-TimeStepBatch.__new__.__defaults__ = (EnvInfo(), None,)
-
-
-class Trajectory:
-  """A trajectory of interactions with an RL environment.
-
-  Trajectories are created when interacting with an RL environment. They can
-  be prolonged and sliced, and when completed they allow returns to be
-  re-calculated.
-  """
-
-  def __init__(self, observation):
-    # TODO(lukaszkaiser): add support for saving and loading trajectories,
-    # reuse code from base_trainer.dump_trajectories and related functions.
-    self._last_observation = observation
-    self._timesteps = []
-    self._timestep_batch = None
-    self._cached_to_np_args = None
-
-  def __len__(self):
-    """Returns the number of observations in the trajectory."""
-    # We always have one more observation than any other field.
- return len(self._timesteps) + 1 - - def __repr__(self): - return repr({ - 'timesteps': self._timesteps, - 'last_observation': self._last_observation, - }) - - def suffix(self, length): - """Returns a `Trajectory` with the last `length` observations.""" - assert length >= 1 - t = Trajectory(self._last_observation) - t._timesteps = self._timesteps[-(length - 1):] # pylint: disable=protected-access - return t - - @property - def timesteps(self): - return self._timesteps - - @property - def total_return(self): - """Sum of all rewards in this trajectory.""" - return sum([t.reward or 0.0 for t in self._timesteps]) - - @property - def last_observation(self): - """Return the last observation in this trajectory.""" - return self._last_observation - - @property - def done(self): - """Returns whether the trajectory is finished.""" - if not self._timesteps: - return False - return self._timesteps[-1].done - - @done.setter - def done(self, done): - """Sets the `done` flag in the last timestep.""" - if not self._timesteps: - raise ValueError('No interactions yet in the trajectory.') - self._timesteps[-1] = self._timesteps[-1]._replace(done=done) - - def extend(self, new_observation, mask=1, **kwargs): - """Take action in the last state, getting reward and going to new state.""" - self._timesteps.append(TimeStepBatch( - observation=self._last_observation, mask=mask, **kwargs - )) - self._last_observation = new_observation - - def calculate_returns(self, gamma): - """Calculate discounted returns.""" - rewards = np.array([ts.reward for ts in self._timesteps]) - discount_mask = np.array([ - ts.env_info.discount_mask for ts in self._timesteps - ]) - gammas = advantages.mask_discount(gamma, discount_mask) - returns = advantages.discounted_returns(rewards, gammas) - for (i, return_) in enumerate(returns): - self._timesteps[i] = self._timesteps[i]._replace(return_=return_) - - def _default_timestep_to_np(self, ts): - """Default way to convert timestep to numpy.""" - return fastmath.nested_map(np.array, ts) - - def to_np(self, margin=1, timestep_to_np=None): - """Create a tuple of numpy arrays from a given trajectory. - - Args: - margin (int): Number of dummy timesteps past the trajectory end to - include. By default we include 1, which contains the last - observation. - timestep_to_np (callable or None): Optional function - TimeStepBatch[Any] -> TimeStepBatch[np.array], converting the - timestep data into numpy arrays. - - Returns: - TimeStepBatch, where all fields have shape - (len(self) + margin - 1, ...). - """ - timestep_to_np = timestep_to_np or self._default_timestep_to_np - args = (margin, timestep_to_np) - - # Return the cached result if the arguments agree and the trajectory has not - # grown. - if self._timestep_batch: - result_length = len(self) + margin - 1 - length_ok = self._timestep_batch.observation.shape[0] == result_length - if args == self._cached_to_np_args and length_ok: - return self._timestep_batch - - # observation, action, reward, etc. - fields = TimeStepBatch._fields - # List of timestep data for each field. - data_lists = TimeStepBatch(**{field: [] for field in fields}) - for timestep in self._timesteps: - timestep_np = timestep_to_np(timestep) - # Append each field of timestep_np to the appropriate list. - for field in fields: - getattr(data_lists, field).append(getattr(timestep_np, field)) - # Append the last observation. 
-      data_lists.observation.append(self._last_observation)
-
-    # TODO(pkozakowski): The case len(obs) == 1 is for handling
-    # "dummy trajectories", which are only there to determine data shapes.
-    # Check if they're still required.
-    if len(data_lists.observation) > 1:
-      # Extend the trajectory with a given margin - this is to make sure that
-      # the networks always "see" the "done" states in the training data, even
-      # when a suffix is added to the trajectory slice for better estimation of
-      # returns.
-      # We set `mask` to 0, so the added timesteps don't influence the loss. We
-      # set `done` to True for easier implementation of advantage estimators.
-      # The rest of the fields don't matter, so we set them to 0 for easy
-      # debugging (unless they're None). The list of observations is longer, so
-      # we pad it with margin - 1.
-      data_lists.mask.extend([0] * margin)
-      data_lists.done.extend([True] * margin)
-      data_lists.observation.extend(
-          [np.zeros_like(data_lists.observation[-1])] * (margin - 1)
-      )
-      for field in set(fields) - {'mask', 'done', 'observation'}:
-        l = getattr(data_lists, field)
-        filler = None if l[-1] is None else np.zeros_like(l[-1])
-        l.extend([filler] * margin)
-
-    # Trim the observations to have the same length as the rest of the fields.
-    # This is not the case when margin=0.
-    if margin == 0:
-      data_lists.observation.pop()
-
-    def stack(x):
-      if not x:
-        return None
-      return fastmath.nested_stack(x)
-
-    # Stack the data_lists into numpy arrays.
-    timestep_batch = TimeStepBatch(*map(stack, data_lists))
-
-    self._timestep_batch = timestep_batch
-    self._cached_to_np_args = args
-
-    return timestep_batch
-
-
-def play(env, policy, dm_suite=False, max_steps=None, last_observation=None):
-  """Play an episode in env taking actions according to the given policy.
-
-  The environment is first reset, and from then on a game proceeds. At each
-  step, the policy is asked to choose an action and the environment moves
-  forward. A Trajectory is created in this way and returned when the episode
-  finishes, which is either when the env returns `done` or when max_steps is
-  reached.
-
-  Args:
-    env: the environment to play in, conforming to gym.Env or
-      DeepMind suite interfaces.
-    policy: a function taking a Trajectory and returning a pair consisting
-      of an action (int or float) and the confidence in that action (float,
-      defined as the log of the probability of taking that action).
-    dm_suite: whether we are using the DeepMind suite or the gym interface
-    max_steps: for how many steps to play.
-    last_observation: last observation from a previous trajectory slice, used
-      to begin a new one. Controls whether we reset the environment at the
-      beginning - if `None`, resets the env and starts the slice from the
-      observation obtained from reset().
-
-  Returns:
-    a completed trajectory slice that was just played.
-  """
-  done = False
-  cur_step = 0
-  if last_observation is None:
-    # TODO(pkozakowski): Make a Gym wrapper over DM envs to get rid of branches
-    # like that.
-    last_observation = env.reset().observation if dm_suite else env.reset()
-  cur_trajectory = Trajectory(last_observation)
-  while not done and (max_steps is None or cur_step < max_steps):
-    action, dist_inputs = policy(cur_trajectory)
-    action = np.asarray(action)
-    step = env.step(action)
-    if dm_suite:
-      (observation, reward, done) = (
-          step.observation, step.reward, step.step_type.last()
-      )
-      info = {}
-    else:
-      (observation, reward, done, info) = step
-
-    # Make an EnvInfo out of the supported keys in the info dict.
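-    # (Illustrative aside with assumed values, not part of the original code:
-    # an info dict such as {'control_mask': 1, 'discount_mask': 0, 'foo': 42}
-    # yields EnvInfo(control_mask=1, discount_mask=0); keys outside
-    # EnvInfo._fields, such as 'foo', are dropped.)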
-    env_info = EnvInfo(**{
-        key: value for (key, value) in info.items()
-        if key in EnvInfo._fields
-    })
-    cur_trajectory.extend(
-        action=action,
-        dist_inputs=dist_inputs,
-        reward=reward,
-        done=done,
-        new_observation=observation,
-        env_info=env_info,
-    )
-    cur_step += 1
-  return cur_trajectory
-
-
-def _zero_pad(x, pad, axis):
-  """Helper for np.pad with 0s for the single-axis case."""
-  pad_widths = [(0, 0)] * len(x.shape)
-  pad_widths[axis] = pad  # Padding on axis.
-  return np.pad(x, pad_widths, mode='constant',
-                constant_values=x.dtype.type(0))
-
-
-def _random_policy(action_space):
-  return lambda _: (action_space.sample(), None)
-
-
-def _sample_proportionally(inputs, weights):
-  """Sample an element from the inputs list proportionally to weights.
-
-  Args:
-    inputs: a list, we will return one element of this list.
-    weights: a sequence of numbers of the same length as inputs; we will sample
-      the k-th input with probability weights[k] / sum(weights).
-
-  Returns:
-    an element from inputs.
-  """
-  l = len(inputs)
-  weights = np.array(weights)
-  if l != len(weights):
-    raise ValueError(f'Inputs and weights must have the same length, but do '
-                     f'not: {l} != {len(weights)}')
-  norm_weights = weights / np.sum(weights)
-  # TODO(pkozakowski): Currently this is O(n). It can be sped up to O(log n) by
-  # storing CDF and binsearching on it.
-  idx = np.random.choice(l, p=norm_weights)
-  return inputs[int(idx)]
-
-
-def _n_slices(trajectory, max_slice_length, margin):
-  """How many slices of length up to max_slice_length fit in a trajectory."""
-  # TODO(lukaszkaiser): add option to sample from n last trajectories.
-  if not max_slice_length:
-    return 1
-  # A trajectory [a, b, c, end_state] will have 2 slices of length 2:
-  # the slice [a, b] and the one [b, c], with margin=0; 3 with margin=1.
-  return max(1, len(trajectory) + margin - max_slice_length)
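-
-
-# Worked example for the helpers above (an illustrative sketch, not part of
-# the original module; values assumed):
-#
-#   traj = Trajectory(...)  # 4 observations: [a, b, c, end_state]
-#   _n_slices(traj, max_slice_length=2, margin=0)  # max(1, 4 + 0 - 2) == 2
-#   _n_slices(traj, max_slice_length=2, margin=1)  # == 3
-#   _sample_proportionally(['x', 'y'], weights=[1, 3])
-#   # returns 'y' with probability 3 / (1 + 3) == 0.75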
-
-
-@gin.configurable
-class RLTask:
-  """An RL task: environment and a collection of trajectories."""
-
-  def __init__(self, env=gin.REQUIRED,
-               initial_trajectories=1,
-               gamma=0.99,
-               dm_suite=False,
-               random_starts=True,
-               max_steps=None,
-               time_limit=None,
-               timestep_to_np=None,
-               num_stacked_frames=1,
-               n_replay_epochs=1):
-    r"""Configures an RL task.
-
-    Args:
-      env: Environment conforming to the gym.Env interface or a string, in
-        which case `gym.make` will be called on this string to create an env.
-      initial_trajectories: either a dict or list of Trajectories to use
-        at start or an int, in which case that many trajectories are
-        collected using a random policy to play in env. It can also be a
-        string, in which case it should point to the location where
-        previously recorded trajectories are stored.
-      gamma: float: discount factor for calculating returns.
-      dm_suite: whether we are using the DeepMind suite or the gym interface
-      random_starts: use random starts for training of Atari agents.
-      max_steps: optional int: cut all trajectory slices at that many steps.
-        The trajectory will be continued in the next epochs, up to
-        `time_limit`.
-      time_limit: optional int: stop all trajectories after that many steps
-        (or after getting "done"). If `None`, use the same value as
-        `max_steps`.
-      timestep_to_np: a function that turns a timestep into a numpy array
-        (i.e., a tensor); if None, we just use the state of the timestep to
-        represent it, but other representations (such as embeddings that
-        include actions or serialized representations) can be passed here.
-      num_stacked_frames: the number of stacked frames for Atari.
-      n_replay_epochs: the size of the replay buffer expressed in epochs.
-    """
-    if isinstance(env, str):
-      self._env_name = env
-      if dm_suite:
-        eval_env = None
-        env = None
-      else:
-        env = gym.make(self._env_name)
-        eval_env = gym.make(self._env_name)
-    else:
-      self._env_name = type(env).__name__
-      eval_env = env
-    self._env = env
-    self._eval_env = eval_env
-    self._dm_suite = dm_suite
-    self._max_steps = max_steps
-    if time_limit is None:
-      time_limit = max_steps
-    self._time_limit = time_limit
-    self._gamma = gamma
-    self._initial_trajectories = initial_trajectories
-    self._last_observation = None
-    self._n_steps_left = time_limit
-    # Example trajectory for determining input/output shapes of the networks.
-    self._example_trajectory = self.play(
-        _random_policy(self.action_space), only_eval=True
-    )
-    # TODO(lukaszkaiser): find a better way to pass initial trajectories,
-    # whether they are an explicit list, a file, or a number of random ones.
-    if isinstance(initial_trajectories, int):
-      initial_trajectories = [
-          self.play(_random_policy(self.action_space))
-          for _ in range(initial_trajectories)
-      ]
-    if isinstance(initial_trajectories, str):
-      initial_trajectories = self.load_initial_trajectories_from_path(
-          initial_trajectories_path=initial_trajectories)
-    if isinstance(initial_trajectories, list):
-      if initial_trajectories:
-        initial_trajectories = {0: initial_trajectories}
-      else:
-        initial_trajectories = {}
-    self._timestep_to_np = timestep_to_np
-    # Stored trajectories are indexed by epoch and within each epoch they
-    # are stored in the order of generation so we can implement replay
-    # buffers.
-    # TODO(lukaszkaiser): use dump_trajectories from BaseTrainer to allow
-    # saving and reading trajectories from disk.
-    self._trajectories = collections.defaultdict(list)
-    self._trajectories.update(initial_trajectories)
-    # When we repeatedly save, trajectories for many epochs do not change, so
-    # we don't need to save them again. This keeps track of which epochs are
-    # unchanged.
-    self._saved_epochs_unchanged = []
-    self._n_replay_epochs = n_replay_epochs
-    self._n_trajectories = 0
-    self._n_interactions = 0
-
-  @property
-  def env(self):
-    return self._env
-
-  @property
-  def env_name(self):
-    return self._env_name
-
-  @property
-  def max_steps(self):
-    return self._max_steps
-
-  @property
-  def gamma(self):
-    return self._gamma
-
-  @property
-  def action_space(self):
-    if self._dm_suite:
-      return gym.spaces.Discrete(self._env.action_spec().num_values)
-    else:
-      return self._env.action_space
-
-  @property
-  def observation_space(self):
-    """Returns the env's observation space in a Gym interface."""
-    if self._dm_suite:
-      return gym.spaces.Box(
-          shape=self._env.observation_spec().shape,
-          dtype=self._env.observation_spec().dtype,
-          low=float('-inf'),
-          high=float('+inf'),
-      )
-    else:
-      return self._env.observation_space
-
-  @property
-  def trajectories(self):
-    return self._trajectories
-
-  @property
-  def timestep_to_np(self):
-    return self._timestep_to_np
-
-  @timestep_to_np.setter
-  def timestep_to_np(self, ts):
-    self._timestep_to_np = ts
-
-  def _epoch_filename(self, base_filename, epoch):
-    """Helper function: file name for saving the given epoch."""
-    # If base is /foo/task.pkl, we save epoch 1 under /foo/task_epoch1.pkl.
-    filename, ext = os.path.splitext(base_filename)
-    return filename + '_epoch' + str(epoch) + ext
-
-  def set_n_replay_epochs(self, n_replay_epochs):
-    self._n_replay_epochs = n_replay_epochs
-
-  def load_initial_trajectories_from_path(self,
-                                          initial_trajectories_path,
-                                          dictionary_file='trajectories.pkl',
-                                          start_epoch_to_load=0):
-    """Initialize trajectories task from file."""
-    # We assume that this is a dump generated by Trax.
-    dictionary_file = os.path.join(initial_trajectories_path, dictionary_file)
-    dictionary = training.unpickle_from_file(dictionary_file, gzip=False)
-    # TODO(henrykm): as currently implemented this accesses only
-    # at most the last n_replay_epochs - this should be more flexible.
-    epochs_to_load = dictionary['all_epochs'][start_epoch_to_load:]
-
-    all_trajectories = []
-    for epoch in epochs_to_load:
-      trajectories = training.unpickle_from_file(
-          self._epoch_filename(dictionary_file, epoch), gzip=True)
-      all_trajectories += trajectories
-    return all_trajectories
-
-  def init_from_file(self, file_name):
-    """Initialize this task from file."""
-    dictionary = training.unpickle_from_file(file_name, gzip=False)
-    self._n_trajectories = dictionary['n_trajectories']
-    self._n_interactions = dictionary['n_interactions']
-    self._max_steps = dictionary['max_steps']
-    self._gamma = dictionary['gamma']
-    epochs_to_load = dictionary['all_epochs'][-self._n_replay_epochs:]
-
-    for epoch in epochs_to_load:
-      trajectories = training.unpickle_from_file(
-          self._epoch_filename(file_name, epoch), gzip=True)
-      self._trajectories[epoch] = trajectories
-    self._saved_epochs_unchanged = epochs_to_load
-
-  def save_to_file(self, file_name):
-    """Save this task to file."""
-    # Save trajectories from new epochs first.
-    epochs_to_save = [e for e in self._trajectories
-                      if e not in self._saved_epochs_unchanged]
-    for epoch in epochs_to_save:
-      training.pickle_to_file(self._trajectories[epoch],
-                              self._epoch_filename(file_name, epoch),
-                              gzip=True)
-    # Now save the list of epochs (so the trajectories are already there,
-    # even in case of preemption).
-    dictionary = {'n_interactions': self._n_interactions,
-                  'n_trajectories': self._n_trajectories,
-                  'max_steps': self._max_steps,
-                  'gamma': self._gamma,
-                  'all_epochs': list(self._trajectories.keys())}
-    training.pickle_to_file(dictionary, file_name, gzip=False)
-
-  def play(self, policy, max_steps=None, only_eval=False):
-    """Play an episode in env taking actions according to the given policy."""
-    if max_steps is None:
-      max_steps = self._max_steps
-    if only_eval:
-      cur_trajectory = play(
-          self._eval_env, policy, self._dm_suite,
-          # Only step up to the time limit.
-          max_steps=min(max_steps, self._time_limit),
-          # Always reset at the beginning of an eval episode.
-          last_observation=None,
-      )
-    else:
-      cur_trajectory = play(
-          self._env, policy, self._dm_suite,
-          # Only step up to the time limit, used up by all slices of the same
-          # trajectory.
-          max_steps=min(max_steps, self._n_steps_left),
-          # Pass the environment state between play() calls, so one episode
-          # can span multiple training epochs.
-          # NOTE: Cutting episodes between epochs may hurt e.g. with
-          # Transformers.
-          # TODO(pkozakowski): Join slices together if this becomes a problem.
-          last_observation=self._last_observation,
-      )
-    # Update the number of steps left to reach the time limit.
-    # NOTE: This should really be done as an env wrapper.
-    # TODO(pkozakowski): Do that once we wrap the DM envs in Gym interface.
-    # The initial reset doesn't count.
-    self._n_steps_left -= len(cur_trajectory) - 1
-    assert self._n_steps_left >= 0
-    if self._n_steps_left == 0:
-      cur_trajectory.done = True
-    # Pass the last observation between trajectory slices.
-    if cur_trajectory.done:
-      self._last_observation = None
-      # Reset the time limit.
-      self._n_steps_left = self._time_limit
-    else:
-      self._last_observation = cur_trajectory.last_observation
-
-    cur_trajectory.calculate_returns(self._gamma)
-    return cur_trajectory
-
-  def collect_trajectories(
-      self, policy,
-      n_trajectories=None,
-      n_interactions=None,
-      only_eval=False,
-      max_steps=None,
-      epoch_id=1,
-  ):
-    """Collect experience in env playing the given policy."""
-    max_steps = max_steps or self.max_steps
-    if n_trajectories is not None:
-      new_trajectories = [self.play(policy, max_steps=max_steps,
-                                    only_eval=only_eval)
-                          for _ in range(n_trajectories)]
-    elif n_interactions is not None:
-      new_trajectories = []
-      while n_interactions > 0:
-        traj = self.play(policy, max_steps=min(n_interactions, max_steps))
-        new_trajectories.append(traj)
-        n_interactions -= len(traj) - 1  # The initial reset does not count.
-    else:
-      raise ValueError(
-          'Either n_trajectories or n_interactions must be defined.'
-      )
-
-    # Calculate returns.
-    returns = [t.total_return for t in new_trajectories]
-    if returns:
-      mean_returns = sum(returns) / float(len(returns))
-    else:
-      mean_returns = 0
-
-    # If we're only evaluating, we're done; return the average.
-    if only_eval:
-      return mean_returns
-    # Store new trajectories.
-    if new_trajectories:
-      self._trajectories[epoch_id].extend(new_trajectories)
-
-    # Mark that epoch epoch_id has changed.
-    if epoch_id in self._saved_epochs_unchanged:
-      self._saved_epochs_unchanged = [e for e in self._saved_epochs_unchanged
-                                      if e != epoch_id]
-
-    # Remove epochs not intended to be in the buffer.
-    current_trajectories = {
-        key: value for key, value in self._trajectories.items()
-        if key >= epoch_id - self._n_replay_epochs}
-    self._trajectories = collections.defaultdict(list)
-    self._trajectories.update(current_trajectories)
-
-    # Update statistics.
-    self._n_trajectories += len(new_trajectories)
-    self._n_interactions += sum([len(traj) for traj in new_trajectories])
-
-    return mean_returns
-
-  def n_trajectories(self, epochs=None):
-    # TODO(henrykm): support selection of epochs if really necessary (will
-    # require a dump of a list of lengths in save_to_file).
-    del epochs
-    return self._n_trajectories
-
-  def n_interactions(self, epochs=None):
-    # TODO(henrykm): support selection of epochs if really necessary (will
-    # require a dump of a list of lengths in save_to_file).
-    del epochs
-    return self._n_interactions
-
-  def _random_slice(self, trajectory, max_slice_length, margin):
-    """Returns a random TimeStepBatch slice from a given trajectory."""
-    # Sample a slice from the trajectory.
-    slice_start = np.random.randint(
-        _n_slices(trajectory, max_slice_length, margin)
-    )
-
-    # Convert the whole trajectory to Numpy while adding the margin. The
-    # result is cached, so we don't have to repeat this for every sample.
-    timestep_batch = trajectory.to_np(margin, self._timestep_to_np)
-
-    # Slice and yield the result.
-    slice_end = slice_start + (
-        max_slice_length or timestep_batch.observation.shape[0]
-    )
-    return fastmath.nested_map(
-        lambda x: x[slice_start:slice_end], timestep_batch
-    )
-
-  def _trajectory_stream(self, epochs=None, max_slice_length=None,
-                         sample_trajectories_uniformly=False, margin=0):
-    """Return a stream of random trajectory slices from the specified epochs.
- - Args: - epochs: a list of epochs to use; we use all epochs if None - max_slice_length: maximum length of the slices of trajectories to return - sample_trajectories_uniformly: whether to sample trajectories uniformly, - or proportionally to the number of slices in each trajectory (default) - margin: number of extra steps after "done" that should be included in - slices, so that networks see the terminal states in the training data - - Yields: - random trajectory slices sampled uniformly from all slices of length - up to max_slice_length in all specified epochs - """ - # {int: array[int]}; - # epoch_to_ns_slices[epoch][i] = n_slices(self._trajectories[epoch][i]) - # It stores arrays for faster sampling. - epoch_to_ns_slices = {} - # {int: int}; - # epoch_to_total_n_slices[epoch] = sum(epoch_to_ns_slices[epoch]) - epoch_to_total_n_slices = {} - # [int]: list of epoch indices to sample from. - epoch_indices = [] - # epoch_to_total_n_slices filtered using epoch_indices. It's an array for - # faster sampling. - sampling_epoch_weights = None - - def new_epoch(epoch_id): - """Updates the lists defined above to include the new epoch.""" - all_epochs = list(self._trajectories.keys()) - max_epoch = max(all_epochs) + 1 - - # Calculate the numbers of slices for the new epoch. - epoch_to_ns_slices[epoch_id] = np.array([ - _n_slices(trajectory, max_slice_length, margin) - for trajectory in self._trajectories[epoch_id] - ]) - epoch_to_total_n_slices[epoch_id] = np.sum( - epoch_to_ns_slices[epoch_id] - ) - - # Update the indices of epochs to sample from. - new_epoch_indices = epochs or all_epochs - new_epoch_indices = [ - # So -1 means "last". - ep % max_epoch for ep in new_epoch_indices - ] - # Remove duplicates and consider only epochs where some trajectories - # were recorded and that we have processed in new_epoch. - new_epoch_indices = [ - epoch_id for epoch_id in set(new_epoch_indices) - if self._trajectories[epoch_id] and epoch_id in epoch_to_ns_slices - ] - epoch_indices[:] = new_epoch_indices - - nonlocal sampling_epoch_weights - sampling_epoch_weights = np.array(list( - epoch_to_total_n_slices[ep] for ep in epoch_indices - )) - - while True: - # If we haven't collected any trajectories yet, yield an example - # trajectory. It's needed to determine the input/output shapes of - # networks. - if not self._trajectories: - yield self._example_trajectory - continue - - # Catch up if we have a new epoch or we've restarted the experiment. - for epoch_id in self._trajectories.keys() - epoch_to_ns_slices.keys(): # pylint:disable=g-builtin-op - new_epoch(epoch_id) - - # Sample an epoch proportionally to number of slices in each epoch. - epoch_id = _sample_proportionally(epoch_indices, sampling_epoch_weights) - epoch = self._trajectories[epoch_id] - - # Sample a trajectory proportionally to number of slices in each one. - if sample_trajectories_uniformly: - slices_per_trajectory = np.ones((len(epoch),)) - else: - slices_per_trajectory = epoch_to_ns_slices[epoch_id] - trajectory = _sample_proportionally(epoch, slices_per_trajectory) - - yield trajectory - - def trajectory_slice_stream( - self, - epochs=None, - max_slice_length=None, - sample_trajectories_uniformly=False, - margin=0, - trajectory_stream_preprocessing_fn=None, - ): - """Return a stream of random trajectory slices from the specified epochs. 
- - Args: - epochs: a list of epochs to use; we use all epochs if None - max_slice_length: maximum length of the slices of trajectories to return - sample_trajectories_uniformly: whether to sample trajectories uniformly, - or proportionally to the number of slices in each trajectory (default) - margin: number of extra steps after "done" that should be included in - slices, so that networks see the terminal states in the training data - trajectory_stream_preprocessing_fn: function to apply to the trajectory - stream before batching; can be used e.g. to filter trajectories - - Yields: - random trajectory slices sampled uniformly from all slices of length - up to max_slice_length in all specified epochs - """ - trajectory_stream = self._trajectory_stream( - epochs=epochs, - max_slice_length=max_slice_length, - sample_trajectories_uniformly=sample_trajectories_uniformly, - margin=margin, - ) - - if trajectory_stream_preprocessing_fn is not None: - trajectory_stream = trajectory_stream_preprocessing_fn(trajectory_stream) - - for trajectory in trajectory_stream: - yield self._random_slice(trajectory, max_slice_length, margin) - - def trajectory_batch_stream( - self, - batch_size, - epochs=None, - max_slice_length=None, - min_slice_length=None, - margin=0, - sample_trajectories_uniformly=False, - trajectory_stream_preprocessing_fn=None, - ): - """Return a stream of trajectory batches from the specified epochs. - - This function returns a stream of tuples of numpy arrays (tensors). - If tensors have different lengths, they will be padded by 0. - - Args: - batch_size: the size of the batches to return - epochs: a list of epochs to use; we use all epochs if None - max_slice_length: maximum length of the slices of trajectories to return - min_slice_length: minimum length of the slices of trajectories to return - margin: number of extra steps after "done" that should be included in - slices, so that networks see the terminal states in the training data - sample_trajectories_uniformly: whether to sample trajectories uniformly, - or proportionally to the number of slices in each trajectory (default) - trajectory_stream_preprocessing_fn: function to apply to the trajectory - stream before batching; can be used e.g. to filter trajectories - - Yields: - batches of trajectory slices sampled uniformly from all slices of length - at least min_slice_length and up to max_slice_length in all specified - epochs - """ - def pad(tensor_list): - # Replace Nones with valid tensors. - not_none_tensors = [t for t in tensor_list if t is not None] - assert not_none_tensors, 'All tensors to pad are None.' - prototype = np.zeros_like(not_none_tensors[0]) - tensor_list = [t if t is not None else prototype for t in tensor_list] - - max_len = max([t.shape[0] for t in tensor_list]) - if min_slice_length is not None: - max_len = max(max_len, min_slice_length) - min_len = min([t.shape[0] for t in tensor_list]) - if max_len == min_len: # No padding needed. 
- return np.array(tensor_list) - - pad_len = 2**int(np.ceil(np.log2(max_len))) - return np.array([_zero_pad(t, (0, pad_len - t.shape[0]), axis=0) - for t in tensor_list]) - - trajectory_slice_stream = self.trajectory_slice_stream( - epochs=epochs, - max_slice_length=max_slice_length, - sample_trajectories_uniformly=sample_trajectories_uniformly, - margin=margin, - trajectory_stream_preprocessing_fn=trajectory_stream_preprocessing_fn, - ) - - cur_batch = [] - for t in trajectory_slice_stream: - cur_batch.append(t) - if len(cur_batch) == batch_size: - # Make a nested TimeStepBatch of lists out of a list of TimeStepBatches. - timestep_batch = fastmath.nested_zip(cur_batch) - # Actions, rewards and returns in the trajectory slice have shape - # [batch_size, trajectory_length], which we denote as [B, L]. - # Observations are more complex: [B, L] + S, where S is the shape of the - # observation space (self.observation_space.shape). - # We stop the recursion at level 1, so we pass lists of arrays into - # pad(). - yield fastmath.nested_map(pad, timestep_batch, level=1) - cur_batch = [] diff --git a/trax/rl/task_test.py b/trax/rl/task_test.py deleted file mode 100644 index 90630fbb4..000000000 --- a/trax/rl/task_test.py +++ /dev/null @@ -1,363 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
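-
-# Note on batch shapes (an illustrative aside, not part of the original
-# module; the numbers are assumed): trajectory_batch_stream pads all slices
-# in a batch that have unequal lengths to the next power of two, so slices of
-# lengths 3 and 5 batched together both come out with length
-# 8 = 2**ceil(log2(5)), with mask == 0 on the padded timesteps.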
- -"""Tests for RL training.""" - -import os -from absl.testing import absltest -import gym -import numpy as np -from trax import test_utils -from trax.rl import task as rl_task - - -class DummyEnv: - """Dummy Env class for testing.""" - - observation_space = gym.spaces.Box(-2, 2, shape=(2,)) - action_space = gym.spaces.Discrete(2) - - def reset(self): - self._step = 0 - return np.ones((2,)) - - def step(self, action): - del action - info = { - 'control_mask': self._step % 2 == 0, - 'discount_mask': self._step % 3 == 0, - } - self._step += 1 - return np.ones((2,)), 0.0, False, info - - -class TaskTest(absltest.TestCase): - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - - def _extend( - self, trajectory, action=0, dist_inputs=0, reward=0, done=False, - new_observation=0, - ): - trajectory.extend( - action=action, dist_inputs=dist_inputs, reward=reward, done=done, - new_observation=new_observation, - ) - - def test_trajectory_len(self): - """Test that trajectory length is equal to the number of observations.""" - tr = rl_task.Trajectory(observation=0) - for _ in range(5): - self._extend(tr) - self.assertLen(tr, 6) - - def test_empty_trajectory_last_observation(self): - """Test that last_observation is the one passed in __init__.""" - tr = rl_task.Trajectory(observation=123) - self.assertEqual(tr.last_observation, 123) - - def test_nonempty_trajectory_last_observation(self): - """Test that last_observation is the one passed in the last extend().""" - tr = rl_task.Trajectory(observation=123) - for _ in range(5): - self._extend(tr) - self._extend(tr, new_observation=321) - self.assertEqual(tr.last_observation, 321) - - def test_trajectory_done_get_and_set(self): - """Test that we can get and set the `done` flag of a trajectory.""" - tr = rl_task.Trajectory(observation=123) - self._extend(tr) - self.assertFalse(tr.done) - tr.done = True - self.assertTrue(tr.done) - - def test_trajectory_suffix_len(self): - """Test that a trajectory suffix has the correct length.""" - tr = rl_task.Trajectory(observation=0) - for _ in range(5): - self._extend(tr) - tr_suffix = tr.suffix(length=3) - self.assertLen(tr_suffix, 3) - - def test_trajectory_suffix_observations(self): - """Test that a trajectory suffix has the correct observations.""" - tr = rl_task.Trajectory(observation=0) - for obs in range(1, 6): - self._extend(tr, new_observation=obs) - tr_suffix = tr.suffix(length=4) - self.assertEqual([ts.observation for ts in tr_suffix.timesteps], [2, 3, 4]) - self.assertEqual(tr_suffix.last_observation, 5) - - def test_trajectory_to_np_shape(self): - """Test that the shape of a to_np result matches the trajectory length.""" - tr = rl_task.Trajectory(observation=np.zeros((2, 3))) - for _ in range(5): - self._extend(tr, new_observation=np.zeros((2, 3))) - tr_np = tr.to_np() - self.assertEqual(tr_np.observation.shape, (len(tr), 2, 3)) - self.assertEqual(tr_np.action.shape, (len(tr),)) - - def test_trajectory_to_np_shape_after_extend(self): - """Test that the shape of a to_np result grows after calling extend().""" - tr = rl_task.Trajectory(observation=0) - for _ in range(5): - self._extend(tr) - len_before = tr.to_np().observation.shape[0] - self._extend(tr) - len_after = tr.to_np().observation.shape[0] - self.assertEqual(len_after, len_before + 1) - - def test_trajectory_to_np_observations(self): - """Test that to_np returns correct observations.""" - tr = rl_task.Trajectory(observation=0) - for obs in range(1, 3): - self._extend(tr, new_observation=obs) - tr_np = tr.to_np() - 
    np.testing.assert_array_equal(tr_np.observation, [0, 1, 2])
-
-  def test_trajectory_to_np_adds_margin(self):
-    """Test that to_np adds a specified margin."""
-    tr = rl_task.Trajectory(observation=2)
-    for _ in range(2):
-      self._extend(tr, new_observation=2)
-    tr_np = tr.to_np(margin=2)
-    np.testing.assert_array_equal(tr_np.observation, [2, 2, 2, 0])
-    np.testing.assert_array_equal(tr_np.mask, [1, 1, 0, 0])
-
-  def test_trajectory_to_np_without_margin_cuts_last_observation(self):
-    """Test that to_np with margin=0 cuts the last observation."""
-    tr = rl_task.Trajectory(observation=0)
-    for obs in range(1, 4):
-      self._extend(tr, new_observation=obs)
-    tr_np = tr.to_np(margin=0)
-    np.testing.assert_array_equal(tr_np.observation, [0, 1, 2])
-
-  def test_task_random_initial_trajectories_and_max_steps(self):
-    """Test generating initial random trajectories, stop at max steps."""
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=9)
-    stream = task.trajectory_slice_stream(max_slice_length=1)
-    next_slice = next(stream)
-    self.assertEqual(next_slice.observation.shape, (1, 2))
-
-  def test_time_limit_terminates_episodes(self):
-    """Test that episodes are terminated upon reaching `time_limit` steps."""
-    task = rl_task.RLTask(
-        DummyEnv(), initial_trajectories=3, max_steps=10, time_limit=10
-    )
-    trajectories = task.trajectories[0]  # Get trajectories from epoch 0.
-    self.assertLen(trajectories, 3)
-    for trajectory in trajectories:
-      self.assertTrue(trajectory.done)
-      # max_steps + 1 (the initial observation doesn't count).
-      self.assertLen(trajectory, 11)
-
-  def test_max_steps_doesnt_terminate_episodes(self):
-    """Test that episodes are not terminated upon reaching `max_steps`."""
-    task = rl_task.RLTask(
-        DummyEnv(), initial_trajectories=2, max_steps=5, time_limit=10
-    )
-    trajectories = task.trajectories[0]  # Get trajectories from epoch 0.
-    self.assertLen(trajectories, 2)
-    # The trajectory should be cut in half. The first half should not be
-    # "done".
-    self.assertFalse(trajectories[0].done)
-    self.assertLen(trajectories[0], 6)  # max_steps + 1
-    # The second half should be "done".
-    self.assertTrue(trajectories[1].done)
-    self.assertLen(trajectories[1], 6)  # max_steps + 1
-
-  def test_collects_specified_number_of_interactions(self):
-    """Test that the specified number of interactions are collected."""
-    task = rl_task.RLTask(
-        DummyEnv(), initial_trajectories=0, max_steps=3, time_limit=20
-    )
-    task.collect_trajectories(policy=(lambda _: (0, 0)), n_interactions=10)
-    trajectories = task.trajectories[1]  # Get trajectories from epoch 1.
-    n_interactions = 0
-    for trajectory in trajectories:
-      n_interactions += len(trajectory) - 1
-    self.assertEqual(n_interactions, 10)
-
-  def test_collects_specified_number_of_trajectories(self):
-    """Test that the specified number of trajectories are collected."""
-    task = rl_task.RLTask(
-        DummyEnv(), initial_trajectories=0, max_steps=3, time_limit=20
-    )
-    task.collect_trajectories(policy=(lambda _: (0, 0)), n_trajectories=3)
-    trajectories = task.trajectories[1]  # Get trajectories from epoch 1.
-    self.assertLen(trajectories, 3)
-
-  def test_task_save_init(self):
-    """Test saving and re-initialization."""
-    task1 = rl_task.RLTask(DummyEnv(), initial_trajectories=13,
-                           max_steps=9, gamma=0.9)
-    self.assertLen(task1.trajectories[0], 13)
-    self.assertEqual(task1.max_steps, 9)
-    self.assertEqual(task1.gamma, 0.9)
-    temp_file = os.path.join(self.create_tempdir().full_path, 'task.pkl')
-    task1.save_to_file(temp_file)
-    task2 = rl_task.RLTask(DummyEnv(), initial_trajectories=3,
-                           max_steps=19, gamma=1.0)
-    self.assertLen(task2.trajectories[0], 3)
-    self.assertEqual(task2.max_steps, 19)
-    self.assertEqual(task2.gamma, 1.0)
-    task2.init_from_file(temp_file)
-    self.assertLen(task2.trajectories[0], 13)
-    self.assertEqual(task2.max_steps, 9)
-    self.assertEqual(task2.gamma, 0.9)
-
-  def test_task_epochs_index_minusone(self):
-    """Test that the epoch index -1 means last epoch and updates to it."""
-    obs = np.zeros((2,))
-    tr1 = rl_task.Trajectory(obs)
-    self._extend(tr1, new_observation=obs, done=True)
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
-    stream = task.trajectory_slice_stream(epochs=[-1], max_slice_length=1)
-    next_slice = next(stream)
-    np.testing.assert_equal(next_slice.observation, np.zeros((1, 2)))
-    task.collect_trajectories(policy=(lambda _: (0, 0)), n_trajectories=1)
-    next_slice = next(stream)
-    np.testing.assert_equal(next_slice.observation, np.ones((1, 2)))
-
-  def test_trajectory_slice_stream_shape(self):
-    """Test the shape yielded by the trajectory slice stream."""
-    obs = np.zeros((12, 13))
-    tr1 = rl_task.Trajectory(obs)
-    self._extend(tr1, new_observation=obs, done=True)
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
-    stream = task.trajectory_slice_stream(max_slice_length=1)
-    next_slice = next(stream)
-    self.assertEqual(next_slice.observation.shape, (1, 12, 13))
-
-  def test_trajectory_slice_stream_long_slice(self):
-    """Test the trajectory slice stream with slices of longer length."""
-    obs = np.zeros((12, 13))
-    tr1 = rl_task.Trajectory(obs)
-    self._extend(tr1, new_observation=obs)
-    self._extend(tr1, new_observation=obs, done=True)
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
-    stream = task.trajectory_slice_stream(max_slice_length=2)
-    next_slice = next(stream)
-    self.assertEqual(next_slice.observation.shape, (2, 12, 13))
-
-  def test_trajectory_slice_stream_sampling_uniform(self):
-    """Test if the trajectory slice stream samples slices uniformly."""
-    # Long trajectory of 0s.
-    tr1 = rl_task.Trajectory(0)
-    for _ in range(100):
-      self._extend(tr1)
-    self._extend(tr1, new_observation=200, done=True)
-    # Short trajectory of 101.
-    tr2 = rl_task.Trajectory(101)
-    self._extend(tr2, new_observation=200, done=True)
-    task = rl_task.RLTask(
-        DummyEnv(), initial_trajectories=[tr1, tr2], max_steps=9)
-
-    # Stream of both. Check that we're sampling by slice, not by trajectory.
-    stream = task.trajectory_slice_stream(max_slice_length=1)
-    slices = []
-    for _ in range(10):
-      next_slice = next(stream)
-      assert next_slice.observation.shape[0] == 1
-      slices.append(next_slice.observation[-1])
-    mean_obs = sum(slices) / float(len(slices))
-    # The average should be around 1 when sampling uniformly from a hundred
-    # 0s and one 101.
-    self.assertLess(mean_obs, 31)  # Sampling 101 even 3 times is unlikely.
-    self.assertLen(slices, 10)
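-
-  # (Illustrative arithmetic, not part of the original tests: with a long
-  # trajectory of ~100 zero observations and a one-step trajectory observing
-  # 101, slice-proportional sampling, tested above, picks the short
-  # trajectory with probability ~1/102, so the mean sampled observation is
-  # ~1. Uniform trajectory sampling, tested below, picks each trajectory with
-  # probability 1/2, giving a mean of ~50.)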
-  def test_trajectory_slice_stream_sampling_by_trajectory(self):
-    """Test if the trajectory stream samples by trajectory."""
-    # Long trajectory of 0s.
-    tr1 = rl_task.Trajectory(0)
-    for _ in range(100):
-      self._extend(tr1)
-    self._extend(tr1, new_observation=200, done=True)
-    # Short trajectory of 101.
-    tr2 = rl_task.Trajectory(101)
-    self._extend(tr2, new_observation=200, done=True)
-    task = rl_task.RLTask(
-        DummyEnv(), initial_trajectories=[tr1, tr2], max_steps=9)
-
-    # Stream of both. Check that we're sampling by trajectory.
-    stream = task.trajectory_slice_stream(
-        max_slice_length=1, sample_trajectories_uniformly=True)
-    slices = []
-    for _ in range(10):
-      next_slice = next(stream)
-      assert next_slice.observation.shape[0] == 1
-      slices.append(next_slice.observation[-1])
-    mean_obs = sum(slices) / float(len(slices))
-    # The average should be around 50, sampling from {0, 101} uniformly.
-    # Sampling 101 < 2 times has low probability (but it is possible, so the
-    # test can be flaky).
-    self.assertGreater(mean_obs, 20)
-    self.assertLen(slices, 10)
-
-  def test_trajectory_slice_stream_margin(self):
-    """Test trajectory stream with an added margin."""
-    tr1 = rl_task.Trajectory(0)
-    self._extend(tr1, new_observation=1)
-    self._extend(tr1, new_observation=1)
-    self._extend(
-        tr1, new_observation=1, action=1, dist_inputs=2, reward=3, done=True
-    )
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
-
-    # Stream of slices that include the margin past the final state.
-    stream1 = task.trajectory_slice_stream(max_slice_length=4, margin=3)
-    got_done = False
-    for _ in range(20):
-      next_slice = next(stream1)
-      self.assertEqual(next_slice.observation.shape, (4,))
-      if next_slice.done[0]:
-        # In the slice, first we have the last timestep in the actual
-        # trajectory, so observation = 1.
-        # Then comes the first timestep in the margin, which has the final
-        # observation from the trajectory: observation = 1.
-        # The remaining timesteps have 0 observations.
-        np.testing.assert_array_equal(next_slice.observation, [1, 1, 0, 0])
-        # In the margin, done = True and mask = 0.
-        for i in range(1, next_slice.observation.shape[0]):
-          self.assertTrue(next_slice.done[i])
-          self.assertFalse(next_slice.mask[i])
-        got_done = True
-    # Assert that we got a done somewhere, otherwise the test is not
-    # triggered. Not getting done has low probability (1/2^20) but is
-    # possible, so the test can be flaky.
-    self.assertTrue(got_done)
-
-  def test_trajectory_batch_stream_propagates_env_info(self):
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=4)
-    stream = task.trajectory_batch_stream(batch_size=1, max_slice_length=4)
-    tr_slice = next(stream)
-    # control_mask = step % 2 == 0, discount_mask = step % 3 == 0.
-    np.testing.assert_array_equal(
-        tr_slice.env_info.control_mask, [[1, 0, 1, 0]]
-    )
-    np.testing.assert_array_equal(
-        tr_slice.env_info.discount_mask, [[1, 0, 0, 1]]
-    )
-
-  def test_trajectory_batch_stream_shape(self):
-    task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=10)
-    batch_stream = task.trajectory_batch_stream(
-        batch_size=3, min_slice_length=4, max_slice_length=4
-    )
-    batch = next(batch_stream)
-    self.assertEqual(batch.observation.shape, (3, 4, 2))
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/rl/training.py b/trax/rl/training.py
deleted file mode 100644
index d5cead516..000000000
--- a/trax/rl/training.py
+++ /dev/null
@@ -1,1158 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Classes for RL training in Trax."""
-
-import contextlib
-import functools
-import os
-import pickle
-import time
-
-import gin
-import numpy as np
-import tensorflow as tf
-
-from trax import data
-from trax import fastmath
-from trax import jaxboard
-from trax import layers as tl
-from trax import models
-from trax import shapes
-from trax import supervised
-from trax.fastmath import numpy as jnp
-from trax.optimizers import adam
-from trax.rl import advantages
-from trax.rl import distributions
-from trax.rl import normalization  # So gin files see it.  # pylint: disable=unused-import
-from trax.rl import policy_tasks
-from trax.rl import task as rl_task
-from trax.supervised import lr_schedules as lr
-
-
-class Agent:
-  """Abstract class for RL agents, presenting the required API."""
-
-  def __init__(self,
-               task: rl_task.RLTask,
-               n_trajectories_per_epoch=None,
-               n_interactions_per_epoch=None,
-               n_eval_episodes=0,
-               eval_steps=None,
-               eval_temperatures=(0.0,),
-               only_eval=False,
-               output_dir=None,
-               timestep_to_np=None):
-    """Configures the Agent.
-
-    Note that subclasses can have many more arguments, which will be
-    configured using defaults and gin. But task and output_dir are passed
-    explicitly.
-
-    Args:
-      task: RLTask instance, which defines the environment to train on.
-      n_trajectories_per_epoch: How many new trajectories to collect in each
-        epoch.
-      n_interactions_per_epoch: How many interactions to collect in each
-        epoch.
-      n_eval_episodes: Number of episodes to play with policy at
-        temperature 0 in each epoch -- used for evaluation only.
-      eval_steps: an optional list of max_steps to use for evaluation
-        (defaults to task.max_steps).
-      eval_temperatures: we always train with temperature 1 and evaluate with
-        the temperatures specified in the eval_temperatures list
-        (defaults to (0.0,)).
-      only_eval: If set to True, then trajectories are collected only for
-        evaluation purposes, but they are not recorded.
-      output_dir: Path telling where to save outputs such as checkpoints.
-      timestep_to_np: Timestep-to-numpy function to override in the task.
-    """
-    if (n_trajectories_per_epoch is None) == (
-        n_interactions_per_epoch is None):
-      raise ValueError(
-          'Exactly one of n_trajectories_per_epoch or '
-          'n_interactions_per_epoch should be specified.'
-      )
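-    # (Note added for clarity, not in the original: the parentheses above
-    # matter. Written as `a is None == b is None`, Python would chain the
-    # comparison into `(a is None) and (None == b) and (b is None)`, which
-    # only fires when both arguments are None and silently accepts the case
-    # where both are given.)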
-    self._epoch = 0
-    self._task = task
-    self._eval_steps = eval_steps or [task.max_steps]
-    if timestep_to_np is not None:
-      self._task.timestep_to_np = timestep_to_np
-    self._n_trajectories_per_epoch = n_trajectories_per_epoch
-    self._n_interactions_per_epoch = n_interactions_per_epoch
-    self._only_eval = only_eval
-    self._output_dir = output_dir
-    self._avg_returns = []
-    self._n_eval_episodes = n_eval_episodes
-    self._eval_temperatures = eval_temperatures
-    self._avg_returns_temperatures = {
-        eval_t: {step: [] for step in self._eval_steps}
-        for eval_t in eval_temperatures
-    }
-    if self._output_dir is not None:
-      self.init_from_file()
-
-  @property
-  def current_epoch(self):
-    """Returns current step number in this training session."""
-    return self._epoch
-
-  @property
-  def task(self):
-    """Returns the task."""
-    return self._task
-
-  @property
-  def avg_returns(self):
-    return self._avg_returns
-
-  def save_gin(self, summary_writer=None):
-    assert self._output_dir is not None
-    config_path = os.path.join(self._output_dir, 'config.gin')
-    config_str = gin.operative_config_str()
-    with tf.io.gfile.GFile(config_path, 'w') as f:
-      f.write(config_str)
-    if summary_writer is not None:
-      summary_writer.text(
-          'gin_config', jaxboard.markdownify_operative_config_str(config_str)
-      )
-
-  def save_to_file(self, file_name='rl.pkl',
-                   task_file_name='trajectories.pkl'):
-    """Save current epoch number and average returns to file."""
-    assert self._output_dir is not None
-    task_path = os.path.join(self._output_dir, task_file_name)
-    self._task.save_to_file(task_path)
-    file_path = os.path.join(self._output_dir, file_name)
-    dictionary = {'epoch': self._epoch, 'avg_returns': self._avg_returns}
-    for eval_t in self._eval_temperatures:
-      dictionary['avg_returns_temperature_{}'.format(
-          eval_t)] = self._avg_returns_temperatures[eval_t]
-    with tf.io.gfile.GFile(file_path, 'wb') as f:
-      pickle.dump(dictionary, f)
-
-  def init_from_file(self, file_name='rl.pkl',
-                     task_file_name='trajectories.pkl'):
-    """Initialize epoch number and average returns from file."""
-    assert self._output_dir is not None
-    task_path = os.path.join(self._output_dir, task_file_name)
-    if tf.io.gfile.exists(task_path):
-      self._task.init_from_file(task_path)
-    file_path = os.path.join(self._output_dir, file_name)
-    if not tf.io.gfile.exists(file_path):
-      return
-    with tf.io.gfile.GFile(file_path, 'rb') as f:
-      dictionary = pickle.load(f)
-    self._epoch = dictionary['epoch']
-    self._avg_returns = dictionary['avg_returns']
-    for eval_t in self._eval_temperatures:
-      self._avg_returns_temperatures[eval_t] = dictionary[
-          'avg_returns_temperature_{}'.format(eval_t)]
-
-  def _collect_trajectories(self):
-    return self.task.collect_trajectories(
-        self.policy,
-        n_trajectories=self._n_trajectories_per_epoch,
-        n_interactions=self._n_interactions_per_epoch,
-        only_eval=self._only_eval,
-        epoch_id=self._epoch
-    )
-
-  def policy(self, trajectory, temperature=1.0):
-    """Policy function that allows playing using this agent.
-
-    Args:
-      trajectory: an instance of trax.rl.task.Trajectory
-      temperature: temperature used to sample from the policy (default=1.0)
-
-    Returns:
-      a pair (action, dist_inputs) where action is the action taken and
-      dist_inputs are the parameters of the policy distribution, which will
-      later be used for training.
-    """
-    raise NotImplementedError
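-
-  # An illustrative sketch, not part of the original class: a subclass
-  # implementing a uniformly random policy over a discrete action space
-  # might look like the following (the logits shape is an assumption):
-  #
-  #   def policy(self, trajectory, temperature=1.0):
-  #     n = self._task.action_space.n
-  #     action = np.random.randint(n)
-  #     dist_inputs = np.log(np.full((n,), 1.0 / n))  # uniform logits
-  #     return (action, dist_inputs)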
-    """
-    raise NotImplementedError
-
-  def train_epoch(self):
-    """Trains this Agent for one epoch -- main RL logic goes here."""
-    raise NotImplementedError
-
-  @contextlib.contextmanager
-  def _open_summary_writer(self):
-    """Opens the Jaxboard summary writer wrapped by a context manager.
-
-    Yields:
-      A Jaxboard summary writer wrapped in a GeneratorContextManager object.
-      If there is no output_dir provided, yields None.
-    """
-    if self._output_dir is not None:
-      writer = jaxboard.SummaryWriter(os.path.join(self._output_dir, 'rl'))
-      try:
-        yield writer
-      finally:
-        writer.close()
-    else:
-      yield None
-
-  def run(self, n_epochs=1, n_epochs_is_total_epochs=False):
-    """Runs this loop for n epochs.
-
-    Args:
-      n_epochs: Stop training after completing n epochs.
-      n_epochs_is_total_epochs: if True, consider n_epochs as the total
-        number of epochs to train, including previously trained ones.
-    """
-    with self._open_summary_writer() as sw:
-      n_epochs_to_run = n_epochs
-      if n_epochs_is_total_epochs:
-        n_epochs_to_run -= self._epoch
-      cur_n_interactions = 0
-      for _ in range(n_epochs_to_run):
-        self._epoch += 1
-        cur_time = time.time()
-        avg_return = self._collect_trajectories()
-        self._avg_returns.append(avg_return)
-        if self._n_trajectories_per_epoch:
-          supervised.trainer_lib.log(
-              'Collecting %d episodes took %.2f seconds.'
-              % (self._n_trajectories_per_epoch, time.time() - cur_time))
-        else:
-          supervised.trainer_lib.log(
-              'Collecting %d interactions took %.2f seconds.'
-              % (self._n_interactions_per_epoch, time.time() - cur_time))
-        supervised.trainer_lib.log(
-            'Average return in epoch %d was %.2f.' % (self._epoch, avg_return))
-        if self._n_eval_episodes > 0:
-          for steps in self._eval_steps:
-            for eval_t in self._eval_temperatures:
-              avg_return_temperature = self.task.collect_trajectories(
-                  functools.partial(self.policy, temperature=eval_t),
-                  n_trajectories=self._n_eval_episodes,
-                  max_steps=steps,
-                  only_eval=True)
-              supervised.trainer_lib.log(
-                  'Eval return in epoch %d with temperature %.2f was %.2f.'
-                  % (self._epoch, eval_t, avg_return_temperature))
-              self._avg_returns_temperatures[eval_t][steps].append(
-                  avg_return_temperature)
-
-        if sw is not None:
-          sw.scalar('timing/collect', time.time() - cur_time,
-                    step=self._epoch)
-          sw.scalar('rl/avg_return', avg_return, step=self._epoch)
-          if self._n_eval_episodes > 0:
-            for steps in self._eval_steps:
-              for eval_t in self._eval_temperatures:
-                sw.scalar(
-                    'rl/avg_return_temperature%.2f_steps%d' % (eval_t, steps),
-                    self._avg_returns_temperatures[eval_t][steps][-1],
-                    step=self._epoch)
-          sw.scalar('rl/n_interactions', self.task.n_interactions(),
-                    step=self._epoch)
-          sw.scalar('rl/n_interactions_per_second',
-                    (self.task.n_interactions() - cur_n_interactions)
-                    / (time.time() - cur_time),
-                    step=self._epoch)
-          cur_n_interactions = self.task.n_interactions()
-          sw.scalar('rl/n_trajectories', self.task.n_trajectories(),
-                    step=self._epoch)
-          sw.flush()
-
-        cur_time = time.time()
-        self.train_epoch()
-        supervised.trainer_lib.log(
-            'RL training took %.2f seconds.' % (time.time() - cur_time))
-
-        if self._output_dir is not None and self._epoch == 1:
-          self.save_gin(sw)
-        if self._output_dir is not None:
-          self.save_to_file()
-
-  def close(self):
-    pass
-
-
-class PolicyAgent(Agent):
-  """Agent that uses a deep learning model for policy.
-
-  Many deep RL methods, such as policy gradient (REINFORCE) or actor-critic,
-  fall into this category, so a lot of classes will be subclasses of this one.
-  But some methods only have a value or Q function; these are different.
-  """
-
-  def __init__(self, task, policy_model=None, policy_optimizer=None,
-               policy_lr_schedule=lr.multifactor, policy_batch_size=64,
-               policy_train_steps_per_epoch=500, policy_evals_per_epoch=1,
-               policy_eval_steps=1, n_eval_episodes=0,
-               only_eval=False, max_slice_length=1, output_dir=None, **kwargs):
-    """Configures the policy trainer.
-
-    Args:
-      task: RLTask instance, which defines the environment to train on.
-      policy_model: Trax layer, representing the policy model. Loss functions
-        and eval functions (a.k.a. metrics) are considered to be outside the
-        core model, taking core model output and data labels as their two
-        inputs.
-      policy_optimizer: the optimizer to use to train the policy model.
-      policy_lr_schedule: learning rate schedule to use to train the policy.
-      policy_batch_size: batch size used to train the policy model.
-      policy_train_steps_per_epoch: how long to train policy in each RL epoch.
-      policy_evals_per_epoch: number of policy trainer evaluations per RL
-        epoch - only affects metric reporting.
-      policy_eval_steps: number of policy trainer steps per evaluation - only
-        affects metric reporting.
-      n_eval_episodes: number of episodes to play with policy at
-        temperature 0 in each epoch -- used for evaluation only
-      only_eval: If set to True, then trajectories are collected only for
-        evaluation purposes, but they are not recorded.
-      max_slice_length: the maximum length of trajectory slices to use.
-      output_dir: Path telling where to save outputs (evals and checkpoints).
-      **kwargs: arguments for the superclass Agent.
-    """
-    super().__init__(
-        task,
-        n_eval_episodes=n_eval_episodes,
-        output_dir=output_dir,
-        **kwargs
-    )
-    self._policy_batch_size = policy_batch_size
-    self._policy_train_steps_per_epoch = policy_train_steps_per_epoch
-    self._policy_evals_per_epoch = policy_evals_per_epoch
-    self._policy_eval_steps = policy_eval_steps
-    self._only_eval = only_eval
-    self._max_slice_length = max_slice_length
-    self._policy_dist = distributions.create_distribution(task.action_space)
-
-    # Inputs to the policy model are produced by self.policy_batches_stream.
-    self._policy_inputs = data.inputs.Inputs(
-        train_stream=lambda _: self.policy_batches_stream())
-
-    policy_model = functools.partial(
-        policy_model,
-        policy_distribution=self._policy_dist,
-    )
-
-    # This is the policy Trainer that will be used to train the policy model.
-    # * inputs to the trainer come from self.policy_batches_stream
-    # * outputs, targets and weights are passed to self.policy_loss
-    self._policy_trainer = supervised.Trainer(
-        model=policy_model,
-        optimizer=policy_optimizer,
-        lr_schedule=policy_lr_schedule(),
-        loss_fn=self.policy_loss,
-        inputs=self._policy_inputs,
-        output_dir=output_dir,
-        metrics=self.policy_metrics,
-    )
-    self._policy_collect_model = tl.Accelerate(
-        policy_model(mode='collect'), n_devices=1)
-    policy_batch = next(self.policy_batches_stream())
-    self._policy_collect_model.init(shapes.signature(policy_batch))
-    self._policy_eval_model = tl.Accelerate(
-        policy_model(mode='eval'), n_devices=1)  # Not collecting stats
-    self._policy_eval_model.init(shapes.signature(policy_batch))
-
-  @property
-  def policy_loss(self):
-    """Policy loss."""
-    raise NotImplementedError
-
-  @property
-  def policy_metrics(self):
-    return {'policy_loss': self.policy_loss}
-
-  def policy_batches_stream(self):
-    """Use self.task to create inputs to the policy model."""
-    raise NotImplementedError
-
-  def policy(self, trajectory, temperature=1.0):
-    """Chooses an action to play after a trajectory."""
-    model = self._policy_collect_model
-    if temperature != 1.0:  # When evaluating (t != 1.0), don't collect stats
-      model = self._policy_eval_model
-      model.state = self._policy_collect_model.state
-    model.replicate_weights(self._policy_trainer.model_weights)
-    tr_slice = trajectory.suffix(self._max_slice_length)
-    trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
-    # Add batch dimension to trajectory_np and run the model.
-    pred = model(trajectory_np.observation[None, ...])
-    # Pick element 0 from the batch (the only one), last (current) timestep.
-    pred = pred[0, -1, :]
-    sample = self._policy_dist.sample(pred, temperature=temperature)
-    result = (sample, pred)
-    if fastmath.is_backend(fastmath.Backend.JAX):
-      result = fastmath.nested_map(lambda x: x.copy(), result)
-    return result
-
-  def train_epoch(self):
-    """Trains RL for one epoch."""
-    # When restoring, calculate how many evals are remaining.
-    n_evals = remaining_evals(
-        self._policy_trainer.step,
-        self._epoch,
-        self._policy_train_steps_per_epoch,
-        self._policy_evals_per_epoch)
-    for _ in range(n_evals):
-      self._policy_trainer.train_epoch(
-          self._policy_train_steps_per_epoch // self._policy_evals_per_epoch,
-          self._policy_eval_steps)
-
-  def close(self):
-    self._policy_trainer.close()
-    super().close()
-
-
-def remaining_evals(cur_step, epoch, train_steps_per_epoch, evals_per_epoch):
-  """Helper function to calculate remaining evaluations for a trainer.
-
-  Args:
-    cur_step: current step of the supervised trainer
-    epoch: current epoch of the RL trainer
-    train_steps_per_epoch: supervised trainer steps per RL epoch
-    evals_per_epoch: supervised trainer evals per RL epoch
-
-  Returns:
-    number of remaining evals to do this epoch
-
-  Raises:
-    ValueError if the provided numbers indicate a step mismatch
-  """
-  if epoch < 1:
-    raise ValueError('Epoch must be at least 1, got %d' % epoch)
-  prev_steps = (epoch - 1) * train_steps_per_epoch
-  done_steps_this_epoch = cur_step - prev_steps
-  if done_steps_this_epoch < 0:
-    raise ValueError('Current step (%d) < previously done steps (%d).'
-                     % (cur_step, prev_steps))
-  train_steps_per_eval = train_steps_per_epoch // evals_per_epoch
-  if done_steps_this_epoch % train_steps_per_eval != 0:
-    raise ValueError('Done steps (%d) must be divisible by train steps '
-                     'per eval (%d).'
-                     % (done_steps_this_epoch, train_steps_per_eval))
-  return evals_per_epoch - (done_steps_this_epoch // train_steps_per_eval)
-
-
-class LoopPolicyAgent(Agent):
-  """Base class for policy-only Agents based on Loop."""
-
-  def __init__(
-      self,
-      task,
-      model_fn,
-      value_fn,
-      weight_fn,
-      n_replay_epochs,
-      n_train_steps_per_epoch,
-      advantage_normalization,
-      optimizer=adam.Adam,
-      lr_schedule=lr.multifactor,
-      batch_size=64,
-      network_eval_at=None,
-      n_eval_batches=1,
-      max_slice_length=1,
-      trajectory_stream_preprocessing_fn=None,
-      **kwargs
-  ):
-    """Initializes LoopPolicyAgent.
-
-    Args:
-      task: Instance of trax.rl.task.RLTask.
-      model_fn: Function (policy_distribution, mode) -> policy_model.
-      value_fn: Function TimeStepBatch -> array (batch_size, seq_len)
-        calculating the baseline for advantage calculation.
-      weight_fn: Function float -> float to apply to advantages when
-        calculating policy loss.
-      n_replay_epochs: Number of last epochs to take into the replay buffer;
-        only makes sense for off-policy algorithms.
-      n_train_steps_per_epoch: Number of steps to train the policy network for
-        in each epoch.
-      advantage_normalization: Whether to normalize the advantages before
-        passing them to weight_fn.
-      optimizer: Optimizer for network training.
-      lr_schedule: Learning rate schedule for network training.
-      batch_size: Batch size for network training.
-      network_eval_at: Function step -> bool indicating the training steps at
-        which network evaluation should be performed.
-      n_eval_batches: Number of batches to run during network evaluation.
-      max_slice_length: The length of trajectory slices to run the network on.
-      trajectory_stream_preprocessing_fn: Function to apply to the trajectory
-        stream before batching. Can be used e.g. to filter trajectories.
-      **kwargs: Keyword arguments passed to the superclass.
-    """
-    self._n_train_steps_per_epoch = n_train_steps_per_epoch
-    super().__init__(task, **kwargs)
-
-    task.set_n_replay_epochs(n_replay_epochs)
-    self._max_slice_length = max_slice_length
-    trajectory_batch_stream = task.trajectory_batch_stream(
-        batch_size,
-        epochs=[-(ep + 1) for ep in range(n_replay_epochs)],
-        max_slice_length=self._max_slice_length,
-        sample_trajectories_uniformly=True,
-        trajectory_stream_preprocessing_fn=trajectory_stream_preprocessing_fn,
-    )
-    self._policy_dist = distributions.create_distribution(task.action_space)
-    train_task = policy_tasks.PolicyTrainTask(
-        trajectory_batch_stream,
-        optimizer(),
-        lr_schedule(),
-        self._policy_dist,
-        # Without a value network it doesn't make a lot of sense to use
-        # a better advantage estimator than MC.
-        advantage_estimator=advantages.monte_carlo(task.gamma, margin=0),
-        advantage_normalization=advantage_normalization,
-        value_fn=value_fn,
-        weight_fn=weight_fn,
-    )
-    eval_task = policy_tasks.PolicyEvalTask(train_task, n_eval_batches)
-    model_fn = functools.partial(
-        model_fn,
-        policy_distribution=self._policy_dist,
-    )
-
-    if self._output_dir is not None:
-      policy_output_dir = os.path.join(self._output_dir, 'policy')
-    else:
-      policy_output_dir = None
-    # Checkpoint every epoch.
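-    # Illustration (hypothetical numbers): with n_train_steps_per_epoch=1000,
-    # the checkpoint predicate below fires at Loop steps 1000, 2000, ...,
-    # i.e. exactly once at the end of each RL epoch. This matches the
-    # invariant checked further down: after every epoch, Loop.step should
-    # equal epoch * n_train_steps_per_epoch.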
-    checkpoint_at = lambda step: step % n_train_steps_per_epoch == 0
-    self._loop = supervised.training.Loop(
-        model=model_fn(mode='train'),
-        tasks=[train_task],
-        eval_model=model_fn(mode='eval'),
-        eval_tasks=[eval_task],
-        output_dir=policy_output_dir,
-        eval_at=network_eval_at,
-        checkpoint_at=checkpoint_at,
-    )
-    self._collect_model = model_fn(mode='collect')
-    self._collect_model.init(shapes.signature(train_task.sample_batch))
-
-    # Validate the restored checkpoints.
-    # TODO(pkozakowski): Move this to the base class once all Agents use Loop.
-    if self._loop.step != self._epoch * self._n_train_steps_per_epoch:
-      raise ValueError(
-          'The number of Loop steps must equal the number of Agent epochs '
-          'times the number of steps per epoch, got {}, {} and {}.'.format(
-              self._loop.step, self._epoch, self._n_train_steps_per_epoch
-          )
-      )
-
-  @property
-  def loop(self):
-    """Loop exposed for testing."""
-    return self._loop
-
-  def train_epoch(self):
-    """Trains RL for one epoch."""
-    # Copy policy state accumulated during data collection to the trainer.
-    self._loop.update_weights_and_state(state=self._collect_model.state)
-    # Train for the specified number of steps.
-    self._loop.run(n_steps=self._n_train_steps_per_epoch)
-
-
-class PolicyGradient(LoopPolicyAgent):
-  """Trains a policy model using policy gradient on the given RLTask."""
-
-  def __init__(self, task, model_fn, **kwargs):
-    """Initializes PolicyGradient.
-
-    Args:
-      task: Instance of trax.rl.task.RLTask.
-      model_fn: Function (policy_distribution, mode) -> policy_model.
-      **kwargs: Keyword arguments passed to the superclass.
-    """
-    super().__init__(
-        task, model_fn,
-        # We're on-policy, so we can only use data from the last epoch.
-        n_replay_epochs=1,
-        # Each gradient computation needs a new data sample, so we do 1 step
-        # per epoch.
-        n_train_steps_per_epoch=1,
-        # Very simple baseline: mean return across trajectories.
-        value_fn=self._value_fn,
-        # Weights are just advantages.
-        weight_fn=(lambda x: x),
-        # Normalize advantages, because this makes optimization nicer.
-        advantage_normalization=True,
-        **kwargs
-    )
-
-  def policy(self, trajectory, temperature=1.0):
-    """Policy function that samples from the trained network."""
-    tr_slice = trajectory.suffix(self._max_slice_length)
-    trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
-    return network_policy(
-        collect_model=self._collect_model,
-        policy_distribution=self._policy_dist,
-        loop=self.loop,
-        trajectory_np=trajectory_np,
-        temperature=temperature,
-    )
-
-  @staticmethod
-  def _value_fn(trajectory_batch):
-    # Estimate the value of every state as the mean return across trajectories
-    # and timesteps in a batch.
-    value = np.mean(trajectory_batch.return_)
-    return np.broadcast_to(value, trajectory_batch.return_.shape)
-
-
-@gin.configurable
-def sharpened_network_policy(
-    temperature,
-    temperature_multiplier=1.0,
-    **kwargs
-):
-  """Expert function that runs a policy network with lower temperature.
-
-  Args:
-    temperature: Temperature passed from the Agent.
-    temperature_multiplier: Multiplier to apply to the temperature to "sharpen"
-      the policy distribution. Should be <= 1, but this is not a requirement.
-    **kwargs: Keyword arguments passed to network_policy.
-
-  Returns:
-    Pair (action, dist_inputs) where action is the action taken and
-    dist_inputs are the parameters of the policy distribution that will later
-    be used for training.
-  """
-  return network_policy(
-      temperature=(temperature_multiplier * temperature),
-      **kwargs
-  )
-
-
-class ExpertIteration(LoopPolicyAgent):
-  """Trains a policy model using expert iteration with a given expert."""
-
-  def __init__(
-      self, task, model_fn,
-      expert_policy_fn=sharpened_network_policy,
-      quantile=0.9,
-      n_replay_epochs=10,
-      n_train_steps_per_epoch=1000,
-      filter_buffer_size=256,
-      **kwargs
-  ):
-    """Initializes ExpertIteration.
-
-    Args:
-      task: Instance of trax.rl.task.RLTask.
-      model_fn: Function (policy_distribution, mode) -> policy_model.
-      expert_policy_fn: Function of the same signature as `network_policy`, to
-        be used as an expert. The policy will be trained to mimic the expert on
-        the "solved" trajectories.
-      quantile: Quantile of best trajectories to be marked as "solved". They
-        will be used to train the policy.
-      n_replay_epochs: Number of last epochs to include in the replay buffer.
-      n_train_steps_per_epoch: Number of policy training steps to run in each
-        epoch.
-      filter_buffer_size: Number of trajectories in the trajectory filter
-        buffer, used to select the best trajectories based on the quantile.
-      **kwargs: Keyword arguments passed to the superclass.
-    """
-    self._expert_policy_fn = expert_policy_fn
-    self._quantile = quantile
-    self._filter_buffer_size = filter_buffer_size
-    super().__init__(
-        task, model_fn,
-        # Don't use a baseline -- it's not useful with our weighting scheme.
-        value_fn=(lambda batch: jnp.zeros_like(batch.return_)),
-        # Don't weight trajectories - the training signal is provided by
-        # filtering trajectories.
-        weight_fn=jnp.ones_like,
-        # Filter trajectories based on the quantile.
-        trajectory_stream_preprocessing_fn=self._filter_trajectories,
-        # Advantage normalization is a no-op here.
-        advantage_normalization=False,
-        n_replay_epochs=n_replay_epochs,
-        n_train_steps_per_epoch=n_train_steps_per_epoch,
-        **kwargs
-    )
-
-  def policy(self, trajectory, temperature=1.0):
-    """Policy function that runs the expert."""
-    tr_slice = trajectory.suffix(self._max_slice_length)
-    trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
-    return self._expert_policy_fn(
-        collect_model=self._collect_model,
-        policy_distribution=self._policy_dist,
-        loop=self.loop,
-        trajectory_np=trajectory_np,
-        temperature=temperature,
-    )
-
-  def _filter_trajectories(self, trajectory_stream):
-    """Filter trajectories based on the quantile."""
-    def trajectory_return(trajectory):
-      return trajectory.timesteps[0].return_
-
-    trajectory_buffer = []
-    for trajectory in trajectory_stream:
-      trajectory_buffer.append(trajectory)
-      if len(trajectory_buffer) == self._filter_buffer_size:
-        n_best = int((1 - self._quantile) * self._filter_buffer_size) or 1
-        trajectory_buffer.sort(key=trajectory_return, reverse=True)
-        yield from trajectory_buffer[:n_best]
-        trajectory_buffer.clear()
-
-
-def network_policy(
-    collect_model,
-    policy_distribution,
-    loop,
-    trajectory_np,
-    head_index=0,
-    temperature=1.0,
-):
-  """Policy function powered by a neural network.
-
-  Used to implement Agent.policy() in policy-based agents.
-
-  Args:
-    collect_model: the model used for collecting trajectories
-    policy_distribution: an instance of trax.rl.distributions.Distribution
-    loop: trax.supervised.training.Loop used to train the policy network
-    trajectory_np: an instance of trax.rl.task.TimeStepBatch
-    head_index: index of the policy head in a multihead model.
-    temperature: temperature used to sample from the policy (default=1.0)
-
-  Returns:
-    a pair (action, dist_inputs) where action is the action taken and
-    dist_inputs are the parameters of the policy distribution that will
-    later be used for training.
-  """
-  if temperature == 1.0:
-    model = collect_model
-  else:
-    # When evaluating (t != 1.0), use the evaluation model instead of the
-    # collection model - some models accumulate normalization statistics
-    # during data collection, and we don't want to do it in eval to avoid data
-    # leakage.
-    model = loop.eval_model
-    model.state = collect_model.state
-  # Copying weights from loop.model should work, because the raw model's
-  # weights should be updated automatically during training, but it doesn't.
-  # TODO(pkozakowski): Debug.
-  acc = loop._trainer_per_task[0].accelerated_model_with_loss  # pylint: disable=protected-access
-  model.weights = acc._unreplicate(acc.weights[0])  # pylint: disable=protected-access
-  # Add batch dimension to trajectory_np and run the model.
-  pred = model(trajectory_np.observation[None, ...])
-  if isinstance(pred, (tuple, list)):
-    # For multihead models, extract the policy head output.
-    pred = pred[head_index]
-  assert pred.shape == (
-      1, trajectory_np.observation.shape[0], policy_distribution.n_inputs
-  )
-  # Pick element 0 from the batch (the only one), last (current) timestep.
-  pred = pred[0, -1, :]
-  sample = policy_distribution.sample(pred, temperature=temperature)
-  result = (sample, pred)
-  if fastmath.is_backend(fastmath.Backend.JAX):
-    # The result is composed of mutable numpy arrays. We copy them to avoid
-    # accidental modification.
-    result = fastmath.nested_map(lambda x: x.copy(), result)
-  return result
-
-
-class ValueAgent(Agent):
-  """Trainer that uses a deep learning model for the value function.
-
-  Computes the loss using variants of the Bellman equation.
-  """
-
-  def __init__(self, task,
-               value_body=None,
-               value_optimizer=None,
-               value_lr_schedule=lr.multifactor,
-               value_batch_size=64,
-               value_train_steps_per_epoch=500,
-               value_evals_per_epoch=1,
-               value_eval_steps=1,
-               exploration_rate=functools.partial(
-                   lr.multifactor,
-                   factors='constant * decay_every',
-                   constant=1.,  # pylint: disable=redefined-outer-name
-                   decay_factor=0.99,
-                   steps_per_decay=1,
-                   minimum=0.1),
-               n_eval_episodes=0,
-               only_eval=False,
-               n_replay_epochs=1,
-               max_slice_length=1,
-               sync_freq=1000,
-               scale_value_targets=True,
-               output_dir=None,
-               **kwargs):
-    """Configures the value trainer.
-
-    Args:
-      task: RLTask instance, which defines the environment to train on.
-      value_body: Trax layer, representing the body of the value model. Loss
-        functions and eval functions (a.k.a. metrics) are considered to be
-        outside the core model, taking core model output and data labels as
-        their two inputs.
-      value_optimizer: the optimizer to use to train the value model.
-      value_lr_schedule: learning rate schedule to use to train the value
-        model.
-      value_batch_size: batch size used to train the value model.
-      value_train_steps_per_epoch: how long to train the value model in each
-        RL epoch.
-      value_evals_per_epoch: number of value trainer evaluations per RL
-        epoch - only affects metric reporting.
-      value_eval_steps: number of value trainer steps per evaluation - only
-        affects metric reporting.
-      exploration_rate: exploration rate schedule - used in the policy method.
-      n_eval_episodes: number of episodes to play with policy at
-        temperature 0 in each epoch -- used for evaluation only
-      only_eval: If set to True, then trajectories are collected only for
-        evaluation purposes, but they are not recorded.
-      n_replay_epochs: Number of last epochs to take into the replay buffer;
-        only makes sense for off-policy algorithms.
-      max_slice_length: the maximum length of trajectory slices to use; it is
-        the second dimension of the value network output:
-        (batch, max_slice_length, number of actions)
-        Higher max_slice_length implies that the network has to predict more
-        values into the future.
-      sync_freq: frequency (in steps) at which to synchronize the target
-        network with the trained network. This is necessary for training the
-        network on bootstrapped targets, e.g. using n-step returns.
-      scale_value_targets: If `True`, scale value function targets by
-        `1 / (1 - gamma)`. We are trying to fix the problem with very large
-        returns in some games in a way which does not introduce additional
-        hyperparameters.
-      output_dir: Path telling where to save outputs (evals and checkpoints).
-      **kwargs: arguments for the superclass Agent.
-    """
-    super(ValueAgent, self).__init__(
-        task,
-        n_eval_episodes=n_eval_episodes,
-        output_dir=output_dir,
-        **kwargs
-    )
-    self._value_batch_size = value_batch_size
-    self._value_train_steps_per_epoch = value_train_steps_per_epoch
-    self._value_evals_per_epoch = value_evals_per_epoch
-    self._value_eval_steps = value_eval_steps
-    self._only_eval = only_eval
-    self._max_slice_length = max_slice_length
-    self._policy_dist = distributions.create_distribution(task.action_space)
-    self._n_replay_epochs = n_replay_epochs
-
-    self._exploration_rate = exploration_rate()
-    self._sync_at = (lambda step: step % sync_freq == 0)
-
-    if scale_value_targets:
-      self._value_network_scale = 1 / (1 - self._task.gamma)
-    else:
-      self._value_network_scale = 1
-
-    value_model = functools.partial(
-        models.Quality,
-        body=value_body,
-        n_actions=self.task.action_space.n)
-
-    self._value_eval_model = value_model(mode='eval')
-    self._value_eval_model.init(self._value_model_signature)
-    self._value_eval_jit = tl.jit_forward(
-        self._value_eval_model.pure_fn, fastmath.local_device_count(),
-        do_mean=False)
-
-    # Inputs to the value model are produced by self.value_batches_stream.
-    self._inputs = data.inputs.Inputs(
-        train_stream=lambda _: self.value_batches_stream())
-
-    # This is the value Trainer that will be used to train the value model.
-    # * inputs to the trainer come from self.value_batches_stream
-    # * outputs, targets and weights are passed to self.value_loss
-    self._value_trainer = supervised.Trainer(
-        model=value_model,
-        optimizer=value_optimizer,
-        lr_schedule=value_lr_schedule(),
-        loss_fn=self.value_loss,
-        inputs=self._inputs,
-        output_dir=output_dir,
-        metrics={'value_loss': self.value_loss,
-                 'value_mean': self.value_mean,
-                 'returns_mean': self.returns_mean}
-    )
-    value_batch = next(self.value_batches_stream())
-    self._eval_model = tl.Accelerate(
-        value_model(mode='collect'), n_devices=1)
-    self._eval_model.init(shapes.signature(value_batch))
-
-  @property
-  def _value_model_signature(self):
-    obs_sig = shapes.signature(self._task.observation_space)
-    target_sig = mask_sig = shapes.ShapeDtype(
-        shape=(1, 1, self._task.action_space.n),
-    )
-    inputs_sig = obs_sig.replace(shape=(1, 1) + obs_sig.shape)
-    return (inputs_sig, target_sig, mask_sig)
-
-  def value_batches_stream(self):
-    """Use self.task to create inputs to the value model."""
-    raise NotImplementedError
-
-  def policy(self, trajectory, temperature=1):
-    """Chooses an action to play after a trajectory."""
-    raise NotImplementedError
-
-  def train_epoch(self):
-    """Trains RL for one epoch."""
-    # Update the target value network.
-    self._value_eval_model.weights = self._value_trainer.model_weights
-    self._value_eval_model.state = self._value_trainer.model_state
-
-    # When restoring, calculate how many evals are remaining.
-    n_evals = remaining_evals(
-        self._value_trainer.step,
-        self._epoch,
-        self._value_train_steps_per_epoch,
-        self._value_evals_per_epoch)
-    for _ in range(n_evals):
-      self._value_trainer.train_epoch(
-          self._value_train_steps_per_epoch // self._value_evals_per_epoch,
-          self._value_eval_steps)
-      value_metrics = dict(
-          {'exploration_rate': self._exploration_rate(self._epoch)})
-      self._value_trainer.log_metrics(value_metrics,
-                                      self._value_trainer._train_sw, 'dqn')  # pylint: disable=protected-access
-      # Update the target value network.
-      # TODO(henrykm) a bit tricky if sync_at does not coincide with epochs
-      if self._sync_at(self._value_trainer.step):
-        self._value_eval_model.weights = self._value_trainer.model_weights
-        self._value_eval_model.state = self._value_trainer.model_state
-
-  def close(self):
-    self._value_trainer.close()
-    super().close()
-
-  @property
-  def value_mean(self):
-    """The mean value of actions selected by the behavioral policy."""
-    raise NotImplementedError
-
-  @property
-  def returns_mean(self):
-    """The mean return of actions selected by the behavioral policy."""
-    def f(values, index_max, returns, mask):
-      del values, index_max
-      return jnp.sum(returns) / jnp.sum(mask)
-    return tl.Fn('ReturnsMean', f)
-
-
-class DQN(ValueAgent):
-  r"""Trains a value model using DQN on the given RLTask.
-
-  Notice that the algorithm and the parameters significantly diverge from
-  the original DQN paper. In particular we have separated learning and data
-  collection.
-
-  The Bellman loss is computed in the value_loss method. The formula takes
-  the state-action values tensor Q and n-step returns R:
-
-  .. math::
-      L(s,a) = Q(s,a) - R(s,a)
-
-  where R is computed in value_batches_stream. In the simplest case of the
-  1-step returns we are getting
-
-  .. math::
-      L(s,a) = Q(s,a) - r(s,a) - \gamma \max_{a'} Q'(s',a')
-
-  where s' is the state reached after taking action a in state s, Q' is
-  the target network, gamma is the discount factor and the maximum is taken
-  with respect to all actions available in the state s'.
-  The tensor Q' is updated using the sync_freq parameter.
-
-  In the code the maximum is visible in the policy method, where we take
-  sample = jnp.argmax(values). The epsilon-greedy policy takes a random
-  move with probability epsilon; otherwise, in state s it takes the
-  action argmax_a Q(s,a).
-  """
-
-  def __init__(self,
-               task,
-               advantage_estimator=advantages.monte_carlo,
-               max_slice_length=1,
-               smoothl1loss=True,
-               double_dqn=False,
-               **kwargs):
-
-    self._max_slice_length = max_slice_length
-    self._margin = max_slice_length - 1
-    # Our default choice of learning targets for DQN is n-step targets
-    # implemented in the method td_k. We set the slice used for computation
-    # of td_k to max_slice_length and we set the "margin" in td_k
-    # to self._max_slice_length - 1; in turn it implies that the shape of the
-    # returned tensor of n-step targets is
-    # values[:, :-(self.margin)] = values[:, :1]
-    self._advantage_estimator = advantage_estimator(
-        gamma=task.gamma, margin=self._margin)
-    self._smoothl1loss = smoothl1loss
-    self._double_dqn = double_dqn
-    super(DQN, self).__init__(task=task,
-                              max_slice_length=max_slice_length,
-                              **kwargs)
-
-  @property
-  def value_loss(self):
-    """Value loss computed using smooth L1 loss or L2 loss."""
-    def f(values, actions, returns, mask):
-      ind_0, ind_1 = np.indices(actions.shape)
-      # We calculate length using the shape of returns
-      # and adequately remove a superfluous slice of values.
-      # An analogous operation is done in value_batches_stream.
-      length = returns.shape[1]
-      values = values[:, :length, :]
-      selected_values = values[ind_0, ind_1, actions]
-      shapes.assert_same_shape(selected_values, returns)
-      shapes.assert_same_shape(selected_values, mask)
-      if self._smoothl1loss:
-        return tl.SmoothL1Loss().forward((selected_values, returns, mask))
-      else:
-        return tl.L2Loss().forward((selected_values, returns, mask))
-    return tl.Fn('ValueLoss', f)
-
-  @property
-  def _replay_epochs(self):
-    return [-(ep + 1) for ep in range(self._n_replay_epochs)]
-
-  def value_batches_stream(self):
-    """Use the RLTask self._task to create inputs to the value model."""
-    max_slice_length = self._max_slice_length
-    min_slice_length = 1
-    for np_trajectory in self._task.trajectory_batch_stream(
-        self._value_batch_size,
-        max_slice_length=max_slice_length,
-        min_slice_length=min_slice_length,
-        margin=0,
-        epochs=self._replay_epochs,
-    ):
-      values_target = self._run_value_model(
-          np_trajectory.observation, use_eval_model=True)
-      if self._double_dqn:
-        values = self._run_value_model(
-            np_trajectory.observation, use_eval_model=False
-        )
-        index_max = np.argmax(values, axis=-1)
-        ind_0, ind_1 = np.indices(index_max.shape)
-        values_max = values_target[ind_0, ind_1, index_max]
-      else:
-        values_max = np.array(jnp.max(values_target, axis=-1))
-
-      # The advantage_estimator returns
-      #   gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
-      #     - values_max(s_i)
-      # hence we have to add values_max(s_i) in order to get n-step returns:
-      #   gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
-      # Notice that in DQN the tensor values_max[:, :-self._margin]
-      # is the same as values_max[:, :-1].
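-      # Worked example (hypothetical numbers, assuming max_slice_length=2,
-      # so margin=1, and gamma=0.5): for a slice with rewards r = [1, 1] and
-      # bootstrap values values_max = [v0, v1], the estimator yields the
-      # single advantage 1 + 0.5 * v1 - v0 (length reduced by the margin),
-      # so adding values_max[:, :-1] = [v0] below recovers the 1-step target
-      # 1 + 0.5 * v1, which the Q-network is then regressed towards.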
-      n_step_returns = values_max[:, :-self._margin] + \
-          self._advantage_estimator(
-              rewards=np_trajectory.reward,
-              returns=np_trajectory.return_,
-              values=values_max,
-              dones=np_trajectory.done,
-              discount_mask=np_trajectory.env_info.discount_mask,
-          )
-
-      length = n_step_returns.shape[1]
-      target_returns = n_step_returns[:, :length]
-      inputs = np_trajectory.observation[:, :length, :]
-
-      yield (
-          # Inputs are observations
-          # (batch, length, obs)
-          inputs,
-          # the max indices will be needed to compute the loss
-          np_trajectory.action[:, :length],  # index_max,
-          # Targets: computed returns.
-          # target_returns, we expect here shapes such as
-          # (batch, length, num_actions)
-          target_returns / self._value_network_scale,
-          # TODO(henrykm): mask has the shape (batch, max_slice_length),
-          # that is, it misses the action dimension; the preferred format
-          # would be np_trajectory.mask[:, :length, :] but for now we pass:
-          np.ones(shape=target_returns.shape),
-      )
-
-  def policy(self, trajectory, temperature=1):
-    """Chooses an action to play after a trajectory."""
-    tr_slice = trajectory.suffix(self._max_slice_length)
-    trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
-    # Add batch dimension to trajectory_np and run the model.
-    obs = trajectory_np.observation[None, ...]
-    values = self._run_value_model(obs, use_eval_model=False)
-    # We insist that values and observations have the shape
-    # (batch, length, ...), where the length is the number of subsequent
-    # observations on a given trajectory.
-    assert values.shape[:1] == obs.shape[:1]
-    # We select the last element in the batch and the value
-    # related to the last (current) observation.
-    values = values[0, -1, :]
-    # temperature == 0 is used in another place in order to trigger eval.
-    if np.random.random_sample() < self._exploration_rate(self._epoch) and \
-        temperature == 1:
-      sample = np.array(self.task.action_space.sample())
-    else:
-      # This is our way of doing the argmax.
-      sample = jnp.argmax(values)
-    result = (sample, values)
-    if fastmath.backend_name() == 'jax':
-      result = fastmath.nested_map(lambda x: x.copy(), result)
-    return result
-
-  def _run_value_model(self, obs, use_eval_model=True):
-    """Runs the value model."""
-    n_devices = fastmath.local_device_count()
-    if use_eval_model:
-      weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
-      state = tl.for_n_devices(self._value_eval_model.state, n_devices)
-      rng = self._value_eval_model.rng
-    else:
-      # TODO(henrykm): this strange-looking solution addresses the problem
-      # that value_batches_stream calls _run_value_model _once_ before
-      # the trainer is initialized.
-      try:
-        weights = tl.for_n_devices(self._value_trainer.model_weights, n_devices)
-        state = tl.for_n_devices(self._value_trainer.model_state, n_devices)
-        rng = self._value_trainer._rng  # pylint: disable=protected-access
-      except AttributeError:
-        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
-        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
-        rng = self._value_eval_model.rng
-    # TODO(henrykm): the line below fails on TPU with the error
-    # ValueError: Number of devices (8) does not evenly divide batch size (1).
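-    # The workaround below is a sketch of the intended behavior: when there
-    # are more local devices than examples in the batch (e.g. n_devices=8,
-    # batch=1), the observations are tiled so every device receives a copy,
-    # and the duplicate rows are dropped again after the forward pass with
-    # values[:obs_batch].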
-    obs_batch = obs.shape[0]
-    if n_devices > obs_batch:
-      obs = jnp.repeat(obs, int(n_devices / obs_batch), axis=0)
-    values, _ = self._value_eval_jit(obs, weights, state, rng)
-    values = values[:obs_batch]
-    values *= self._value_network_scale
-    return values
-
-  @property
-  def value_mean(self):
-    """The mean value of actions selected by the behavioral policy."""
-    def f(values, actions, returns, mask):
-      ind_0, ind_1 = np.indices(actions.shape)
-      # We calculate length using the shape of returns
-      # and adequately remove a superfluous slice of values.
-      # An analogous operation is done in value_batches_stream.
-      length = returns.shape[1]
-      values = values[:, :length, :]
-      selected_values = values[ind_0, ind_1, actions]
-      shapes.assert_same_shape(selected_values, returns)
-      shapes.assert_same_shape(selected_values, mask)
-      return jnp.sum(selected_values) / jnp.sum(mask)
-    return tl.Fn('ValueMean', f)
diff --git a/trax/rl/training_test.py b/trax/rl/training_test.py
deleted file mode 100644
index 8f646eed2..000000000
--- a/trax/rl/training_test.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for RL training."""
-
-import functools
-import math
-import os
-import pickle
-
-from absl.testing import absltest
-import tensorflow as tf
-
-from trax import layers as tl
-from trax import models
-from trax import optimizers as opt
-from trax import test_utils
-from trax.rl import task as rl_task
-from trax.rl import training
-from trax.supervised import lr_schedules
-
-
-class TrainingTest(absltest.TestCase):
-
-  def setUp(self):
-    super().setUp()
-    test_utils.ensure_flag('test_tmpdir')
-    self._model_fn = functools.partial(
-        models.Policy,
-        body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
-            tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()
-        ),
-    )
-
-  def test_policy_gradient_smoke(self):
-    """Smoke test of PolicyGradient."""
-    task = rl_task.RLTask('CartPole-v0', max_steps=2)
-    tmp_dir = self.create_tempdir().full_path
-    agent = training.PolicyGradient(
-        task,
-        model_fn=self._model_fn,
-        optimizer=opt.Adam,
-        batch_size=2,
-        n_trajectories_per_epoch=2,
-        n_eval_episodes=1,
-        output_dir=tmp_dir)
-    agent.run(1)
-    self.assertEqual(agent.current_epoch, 1)
-
-  def test_expert_iteration_smoke(self):
-    """Smoke test of ExpertIteration."""
-    task = rl_task.RLTask('CartPole-v0', max_steps=2)
-    tmp_dir = self.create_tempdir().full_path
-    agent = training.ExpertIteration(
-        task,
-        model_fn=self._model_fn,
-        optimizer=opt.Adam,
-        batch_size=2,
-        n_trajectories_per_epoch=2,
-        n_train_steps_per_epoch=2,
-        n_eval_episodes=1,
-        output_dir=tmp_dir,
-    )
-    agent.run(1)
-    self.assertEqual(agent.current_epoch, 1)
-
-  def test_policy_gradient_save_restore(self):
-    """Check save and restore of policy agent."""
-    task = rl_task.RLTask('CartPole-v0', max_steps=2)
-    tmp_dir = self.create_tempdir().full_path
-    agent1 = training.PolicyGradient(
-        task,
-        model_fn=self._model_fn,
-        optimizer=opt.Adam,
-        batch_size=2,
-        n_trajectories_per_epoch=2,
-        n_eval_episodes=1,
-        output_dir=tmp_dir)
-    agent1.run(1)
-    agent1.run(1)
-    self.assertEqual(agent1.current_epoch, 2)
-    self.assertEqual(agent1.loop.step, 2)
-    # Trainer 2 starts where agent 1 stopped.
-    agent2 = training.PolicyGradient(
-        task,
-        model_fn=self._model_fn,
-        optimizer=opt.Adam,
-        batch_size=2,
-        n_trajectories_per_epoch=2,
-        n_eval_episodes=1,
-        output_dir=tmp_dir)
-    agent2.run(1)
-    self.assertEqual(agent2.current_epoch, 3)
-    self.assertEqual(agent2.loop.step, 3)
-    # Manually set saved epoch to 1.
-    dictionary = {
-        'epoch': 1,
-        'avg_returns': [0.0],
-        'avg_returns_temperature_0.0': {
-            200: [0.0]
-        }
-    }
-    with tf.io.gfile.GFile(os.path.join(tmp_dir, 'rl.pkl'), 'wb') as f:
-      pickle.dump(dictionary, f)
-    # Trainer 3 restores from a checkpoint with an Agent/Loop step mismatch,
-    # so it should fail.
-    def agent3_fn():
-      return training.PolicyGradient(
-          task,
-          model_fn=self._model_fn,
-          optimizer=opt.Adam,
-          batch_size=2,
-          n_trajectories_per_epoch=2,
-          n_eval_episodes=1,
-          output_dir=tmp_dir,
-      )
-    self.assertRaises(ValueError, agent3_fn)
-    agent1.close()
-    agent2.close()
-
-  def test_policy_gradient_cartpole(self):
-    """Trains a policy on cartpole."""
-    task = rl_task.RLTask('CartPole-v0', max_steps=200)
-    lr = lambda: lr_schedules.multifactor(constant=1e-2, factors='constant')
-    max_avg_returns = -math.inf
-    for _ in range(2):
-      agent = training.PolicyGradient(
-          task,
-          model_fn=self._model_fn,
-          optimizer=opt.Adam,
-          lr_schedule=lr,
-          batch_size=128,
-          eval_temperatures=[0.0, 0.5],
-          n_eval_episodes=1,
-          n_trajectories_per_epoch=2,
-      )
-      # Assert that we get to 200 at some point and then exit so the test is as
-      # fast as possible.
-      for ep in range(200):
-        agent.run(1)
-        self.assertEqual(agent.current_epoch, ep + 1)
-        if agent.avg_returns[-1] == 200.0:
-          for eval_t in agent._eval_temperatures:
-            self.assertEqual(
-                len(agent._avg_returns_temperatures[eval_t][200]),
-                len(agent.avg_returns))
-          return
-        max_avg_returns = max(max_avg_returns, agent.avg_returns[-1])
-    self.fail(
-        'The expected score of 200 has not been reached. '
-        'Maximum at end was {}.'.format(max_avg_returns)
-    )
-
-  def test_dqntrainer_cartpole(self):
-    """Test-runs DQN on CartPole."""
-
-    task = rl_task.RLTask('CartPole-v0', initial_trajectories=0,
-                          max_steps=2)
-    value_body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
-
-    lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
-        constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
-
-    trainer = training.DQN(
-        task,
-        value_body=value_body,
-        value_optimizer=opt.Adam,
-        value_lr_schedule=lr,
-        value_batch_size=4,
-        value_train_steps_per_epoch=2,
-        n_trajectories_per_epoch=5)
-    trainer.run(2)
-    self.assertEqual(2, trainer.current_epoch)
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/rl/value_tasks.py b/trax/rl/value_tasks.py
deleted file mode 100644
index fe5671bd8..000000000
--- a/trax/rl/value_tasks.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Value network training tasks.""" - -import copy - -import numpy as np - -from trax import layers as tl -from trax.fastmath import numpy as jnp -from trax.supervised import training - - -class ValueTrainTask(training.TrainTask): - """Task for value training.""" - - def __init__( - self, - trajectory_batch_stream, - optimizer, - lr_schedule, - advantage_estimator, - model, - target_model=None, - target_scale=1.0, - sync_at=(lambda step: step % 100 == 0), - loss_layer=None, - head_selector=(), - ): - """Initializes ValueTrainTask. - - Args: - trajectory_batch_stream: Generator of trax.rl.task.TimeStepBatch. - optimizer: Optimizer for network training. - lr_schedule: Learning rate schedule for network training. - advantage_estimator: Function - (rewards, returns, values, dones) -> advantages, created by one of the - functions from trax.rl.advantages. - model: Model being trained, used to synchronize weights of the target - model. - target_model: Model for calculating TD targets. If `None`, use `model`. - target_scale: Multiplier for the targets. Useful for rescaling the targets - to a more reasonable range for model training. - sync_at: Function step -> bool, indicating when to synchronize the target - network with the trained network. This is necessary for training the - network on bootstrapped targets, e.g. using TD-k. - loss_layer: The value loss layer. The default is L2 loss. - head_selector: Layer to apply to the network output to select the value - head. Only needed in multitask training. - """ - self._trajectory_batch_stream = trajectory_batch_stream - self._advantage_estimator = advantage_estimator - self._target_scale = target_scale - - self._synced = False - def sync_also_on_initial_batch(step): - return sync_at(step) or not self._synced - self._sync_at = sync_also_on_initial_batch - - self._head_selector = head_selector - - def attach_head(model): - return tl.Serial(model, self._head_selector) - self._train_model = attach_head(model) - if target_model is None: - target_model = model - # TODO(pkozakowski): Use target_model.clone() once it's implemented. - self._target_model = attach_head(copy.deepcopy(target_model)) - - # Count the steps, so we know when to synchronize the target network. - self._step = 0 - def labeled_data(): - for trajectory_batch in self._trajectory_batch_stream: - self._step += 1 - yield self.value_batch(trajectory_batch) - sample_batch = self.value_batch( - next(trajectory_batch_stream), shape_only=True - ) - if loss_layer is None: - loss_layer = tl.L2Loss() - loss_layer = tl.Serial(head_selector, loss_layer) - super().__init__( - labeled_data(), loss_layer, optimizer, - sample_batch=sample_batch, - lr_schedule=lr_schedule, - loss_name='value_loss', - ) - - @property - def trajectory_batch_stream(self): - return self._trajectory_batch_stream - - def _sync_target_model(self): - self._target_model.weights = self._train_model.weights - self._target_model.state = self._train_model.state - self._synced = True - - def value_batch(self, trajectory_batch, shape_only=False): - """Computes a value training batch based on a trajectory batch. - - Args: - trajectory_batch: trax.rl.task.TimeStepBatch with a batch of trajectory - slices. Elements should have shape (batch_size, seq_len, ...). - shape_only: Whether to return dummy zero arrays of correct shape. Useful - for initializing models. 
- - Returns: - Triple (observations, targets, weights), where targets are the target - values for network training and weights are used for masking in loss - computation. Shapes: - - observations: (batch_size, seq_len) + observation_shape - - targets: (batch_size, seq_len) - - weights: (batch_size, seq_len) - """ - if self._sync_at(self._step) and not shape_only: - self._sync_target_model() - - (batch_size, seq_len) = trajectory_batch.observation.shape[:2] - assert trajectory_batch.action.shape[:2] == (batch_size, seq_len) - assert trajectory_batch.mask.shape == (batch_size, seq_len) - # Compute the value from the target network. - values = np.array(self.value(trajectory_batch, shape_only=shape_only)) - assert values.shape == (batch_size, seq_len) - # Compute the advantages - the TD errors of the target network. - advantages = self._advantage_estimator( - rewards=trajectory_batch.reward, - returns=trajectory_batch.return_, - dones=trajectory_batch.done, - values=values, - discount_mask=trajectory_batch.env_info.discount_mask, - ) - adv_seq_len = advantages.shape[1] - # The advantage sequence should be shorter by the margin. For more details, - # see the comment in policy_tasks.PolicyTrainTask.policy_batch. - assert adv_seq_len <= seq_len - assert advantages.shape == (batch_size, adv_seq_len) - # Compute the targets based on the target values and their TD errors. The - # network gives perfect predictions when targets == values, so the - # advantages are zero. - targets = (values[:, :adv_seq_len] + advantages) * self._target_scale - # Trim observations and the mask to match the target length. - observations = trajectory_batch.observation[:, :adv_seq_len] - mask = trajectory_batch.mask[:, :adv_seq_len] - # Add a singleton depth dimension to the targets and the mask. - targets = targets[:, :, None] - mask = mask[:, :, None] - return (observations, targets, mask) - - def value(self, trajectory_batch, shape_only=False): - """Computes values of states in a given batch of trajectory slices. - - Can be passed as value_fn to PolicyTrainTask to implement a critic baseline - for advantage calculation. - - Args: - trajectory_batch: Batch of trajectory slices to compute values for. - shape_only: Whether to return dummy zero arrays of correct shape. Useful - for initializing models. - - Returns: - Array of values of all states in `trajectory_batch`. - """ - if shape_only: - # The target model hasn't been initialized yet, and we are asked for the - # initial, sample batch. Only shape matters here, so just return zeros. - return np.zeros(trajectory_batch.observation.shape[:2]) - - if not self._synced: - self._sync_target_model() - - values = self._target_model(trajectory_batch.observation) - # Squeeze the singleton depth axis. - return np.squeeze(values, axis=-1) / self._target_scale - - -class ValueEvalTask(training.EvalTask): - """Task for value evaluation.""" - - def __init__(self, train_task, n_eval_batches=1, head_selector=()): - """Initializes ValueEvalTask. - - Args: - train_task: ValueTrainTask used to train the policy network. - n_eval_batches: Number of batches per evaluation. - head_selector: Layer to apply to the network output to select the value - head. Only needed in multitask training. - """ - labeled_data = map( - train_task.value_batch, train_task.trajectory_batch_stream - ) - metrics = [tl.L2Loss(), self.l1_loss] - metric_names = ['value_l2', 'value_l1'] - # Select the appropriate head for evaluation. 
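-    # Illustration: in multitask training a PolicyAndValue model emits
-    # (policy_head, value_head); passing head_selector=tl.Select([1]) routes
-    # only the value head into the metrics below.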
- metrics = [tl.Serial(head_selector, metric) for metric in metrics] - super().__init__( - labeled_data, metrics, - sample_batch=train_task.sample_batch, - metric_names=metric_names, - n_eval_batches=n_eval_batches, - ) - - @property - def l1_loss(self): - def loss(values, targets, weights): - return jnp.sum(jnp.abs(values - targets) * weights) / jnp.sum(weights) - return tl.Fn('L1Loss', loss) diff --git a/trax/rl/value_tasks_test.py b/trax/rl/value_tasks_test.py deleted file mode 100644 index 1d4274476..000000000 --- a/trax/rl/value_tasks_test.py +++ /dev/null @@ -1,200 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.rl.value_tasks.""" - -from absl.testing import absltest -import numpy as np - -from trax import layers as tl -from trax import models -from trax import optimizers as opt -from trax.rl import advantages -from trax.rl import distributions -from trax.rl import policy_tasks -from trax.rl import task as rl_task -from trax.rl import value_tasks -from trax.supervised import lr_schedules -from trax.supervised import training - - -class ValueTasksTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self._model_fn = lambda mode: tl.Serial( # pylint: disable=g-long-lambda - tl.Dense(64), tl.Relu(), tl.Dense(1) - ) - self._task = rl_task.RLTask( - 'CartPole-v0', gamma=0.5, max_steps=10, initial_trajectories=100 - ) - self._trajectory_batch_stream = self._task.trajectory_batch_stream( - batch_size=256, epochs=[-1], max_slice_length=2 - ) - - def _value_error(self, value_fn): - errors = [] - for _ in range(10): - batch = next(self._trajectory_batch_stream) - values = value_fn(batch) - errors.append(np.mean((values - batch.return_) ** 2)) - return np.mean(errors) - - def test_value_tasks_smoke(self): - # Smoke test for train + eval. - train_model = self._model_fn(mode='train') - eval_model = self._model_fn(mode='eval') - train_task = value_tasks.ValueTrainTask( - self._trajectory_batch_stream, - optimizer=opt.Adam(), - lr_schedule=lr_schedules.constant(1e-3), - advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1), - model=train_model, - target_model=eval_model, - ) - eval_task = value_tasks.ValueEvalTask(train_task) - loop = training.Loop( - model=train_model, - eval_model=eval_model, - tasks=[train_task], - eval_tasks=[eval_task], - eval_at=(lambda _: True), - ) - loop.run(n_steps=1) - - def test_value_error_high_without_syncs(self): - train_model = self._model_fn(mode='train') - eval_model = self._model_fn(mode='eval') - train_task = value_tasks.ValueTrainTask( - self._trajectory_batch_stream, - optimizer=opt.Adam(), - lr_schedule=lr_schedules.constant(1e-3), - advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1), - model=train_model, - target_model=eval_model, - # Synchronize just once, at the end of training. 
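-        # With loop.run(n_steps=100) below, this predicate fires only on the
-        # final step, so the target network stays stale for the whole run;
-        # that is why the value error is expected to remain high in this test.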
- sync_at=(lambda step: step == 100), - ) - loop = training.Loop( - model=train_model, - eval_model=eval_model, - tasks=[train_task], - ) - - # Assert that before training, the error is high. - error_before = self._value_error(train_task.value) - self.assertGreater(error_before, 2.0) - - loop.run(n_steps=100) - - # Assert that after training, the error is smaller, but still high. - error_after = self._value_error(train_task.value) - - self.assertLess(error_after, 2.0) - self.assertGreater(error_after, 0.8) - - def test_value_error_low_with_syncs(self): - min_error = np.inf - for _ in range(5): - train_model = self._model_fn(mode='train') - eval_model = self._model_fn(mode='eval') - train_task = value_tasks.ValueTrainTask( - self._trajectory_batch_stream, - optimizer=opt.Adam(), - lr_schedule=lr_schedules.constant(1e-3), - advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1), - model=train_model, - target_model=eval_model, - # Synchronize often throughout training. - sync_at=(lambda step: step % 10 == 0), - ) - loop = training.Loop( - model=train_model, - eval_model=eval_model, - tasks=[train_task], - ) - - # Assert that before training, the error is high. - error_before = self._value_error(train_task.value) - self.assertGreater(error_before, 2.0) - - loop.run(n_steps=100) - - # Assert that after training, the error is small. - error_after = self._value_error(train_task.value) - - if error_after < 0.8: - return - - min_error = min(min_error, error_after) - - self.fail(f'Even after 5 trials, min error_after({min_error}) is not < 0.8') - - def test_integration_with_policy_tasks(self): - # Integration test for policy + value training and eval. - optimizer = opt.Adam() - lr_schedule = lr_schedules.constant(1e-3) - advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1) - policy_dist = distributions.create_distribution(self._task.action_space) - body = lambda mode: tl.Dense(64) - train_model = models.PolicyAndValue(policy_dist, body=body) - eval_model = models.PolicyAndValue(policy_dist, body=body) - - head_selector = tl.Select([1]) - value_train_task = value_tasks.ValueTrainTask( - self._trajectory_batch_stream, - optimizer, - lr_schedule, - advantage_estimator, - model=train_model, - target_model=eval_model, - head_selector=head_selector, - ) - value_eval_task = value_tasks.ValueEvalTask( - value_train_task, head_selector=head_selector - ) - - # Drop the value head - just tl.Select([0]) would pass it, and it would - # override the targets. - head_selector = tl.Select([0], n_in=2) - policy_train_task = policy_tasks.PolicyTrainTask( - self._trajectory_batch_stream, - optimizer, - lr_schedule, - policy_dist, - advantage_estimator, - # Plug a trained critic as our value estimate. - value_fn=value_train_task.value, - head_selector=head_selector, - ) - policy_eval_task = policy_tasks.PolicyEvalTask( - policy_train_task, head_selector=head_selector - ) - - loop = training.Loop( - model=train_model, - eval_model=eval_model, - tasks=[policy_train_task, value_train_task], - eval_tasks=[policy_eval_task, value_eval_task], - eval_at=(lambda _: True), - # Switch the task every step. - which_task=(lambda step: step % 2), - ) - # Run for a couple of steps to make sure there are a few task switches. 
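-    # With which_task=(lambda step: step % 2), successive steps alternate
-    # between the policy task (index 0) and the value task (index 1), so 10
-    # steps yield 5 updates of each.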
-    loop.run(n_steps=10)
-
-
-if __name__ == '__main__':
-  absltest.main()
diff --git a/trax/rl_trainer.py b/trax/rl_trainer.py
deleted file mode 100644
index 7487a0c46..000000000
--- a/trax/rl_trainer.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-r"""Trainer for RL environments.
-
-For now we only support PPO as the RL algorithm.
-
-Sample invocation:
-
-.. code-block:: bash
-
-    TRAIN_BATCH_SIZE=32
-    python trax/rl_trainer.py \
-      --config_file=trax/rl/configs/ppo_acrobot.gin \
-      --train_batch_size=${TRAIN_BATCH_SIZE} \
-      --output_dir=${HOME}/ppo_acrobot \
-      --alsologtostderr
-"""
-
-import faulthandler
-
-from absl import app
-from absl import flags
-from absl import logging
-import gin
-import jax
-from jax.config import config
-from trax import fastmath
-from trax import rl  # pylint: disable=unused-import
-from trax import trainer_flags  # pylint: disable=unused-import
-from trax.rl import task as rl_task
-from trax.rl import training as light_trainers
-from trax.tf_numpy import numpy as tf_np
-
-
-FLAGS = flags.FLAGS
-
-
-# Not just 'train' to avoid a conflict with trax.train in GIN files.
-@gin.configurable(denylist=['output_dir'], module='trax')
-def train_rl(
-    output_dir,
-    n_epochs=10000,
-    light_rl=True,
-    light_rl_trainer=light_trainers.PolicyGradient):
-  """Train the RL agent.
-
-  Args:
-    output_dir: Output directory.
-    n_epochs: Number of epochs to run the training for.
-    light_rl: deprecated, always True; kept for compatibility with old gin
-      configs.
-    light_rl_trainer: which light RL trainer to use (experimental).
- """ - del light_rl - tf_np.set_allow_float64(FLAGS.tf_allow_float64) - task = rl_task.RLTask() - env_name = task.env_name - - - if FLAGS.jax_debug_nans: - config.update('jax_debug_nans', True) - - if FLAGS.use_tpu: - config.update('jax_platform_name', 'tpu') - else: - config.update('jax_platform_name', '') - - - trainer = light_rl_trainer(task=task, output_dir=output_dir) - def light_training_loop(): - """Run the trainer for n_epochs and call close on it.""" - try: - logging.info('Starting RL training for %d epochs.', n_epochs) - trainer.run(n_epochs, n_epochs_is_total_epochs=True) - logging.info('Completed RL training for %d epochs.', n_epochs) - trainer.close() - logging.info('Trainer is now closed.') - except Exception as e: - raise e - finally: - logging.info('Encountered an exception, still calling trainer.close()') - trainer.close() - logging.info('Trainer is now closed.') - - if FLAGS.jax_debug_nans or FLAGS.disable_jit: - fastmath.disable_jit() - with jax.disable_jit(): - light_training_loop() - else: - light_training_loop() - - -def main(argv): - del argv - logging.info('Starting RL training.') - - gin_configs = FLAGS.config if FLAGS.config is not None else [] - gin.enter_interactive_mode() - gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs) - gin.exit_interactive_mode() - - logging.info('Gin config:') - logging.info(gin_configs) - - train_rl(output_dir=FLAGS.output_dir) - - # TODO(afrozm): This is for debugging. - logging.info('Dumping stack traces of all stacks.') - faulthandler.dump_traceback(all_threads=True) - - logging.info('Training is done, should exit.') - - -if __name__ == '__main__': - app.run(main) diff --git a/trax/shapes.py b/trax/shapes.py deleted file mode 100644 index ee58a7e7c..000000000 --- a/trax/shapes.py +++ /dev/null @@ -1,140 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Core class and functions for handling data abstractly as shapes/dtypes.""" - -import numpy as np -import tensorflow as tf - - -class ShapeDtype: - """A NumPy ndarray-like object abstracted as shape and dtype. - - Main use is for representing input and output signatures. - """ - __slots__ = ['shape', 'dtype'] - - def __init__(self, shape, dtype=np.float32): - """Creates a `ShapeDtype` instance, with canonicalized `shape` and `dtype`. - - Args: - shape: A tuple or list, each element of which is an int or, less often, - `None`. - dtype: A `dtype` object, either from NumPy or TensorFlow. - - Returns: - A `ShapeDtype` instance whose `shape` is a tuple and `dtype` is a NumPy - `dtype` object. - """ - # Canonicalize shape and dtype. 
- if isinstance(shape, tf.TensorShape): - shape = shape.as_list() - if isinstance(shape, list): - shape = tuple(shape) - if not isinstance(shape, tuple): - raise TypeError('shape must be tuple or list; got: {}'.format(shape)) - if isinstance(dtype, tf.DType): - dtype = dtype.as_numpy_dtype - - self.shape = shape - self.dtype = dtype - - def __eq__(self, other): - return (isinstance(other, self.__class__) - and self.shape == other.shape - and self.dtype == other.dtype) - - def __ne__(self, other): - return not self == other - - def __repr__(self): - return 'ShapeDtype{{shape:{}, dtype:{}}}'.format(self.shape, self.dtype) - - def __len__(self): - """Returns length of 1; relevant to input and output signatures.""" - return 1 - - def as_tuple(self): - return self.shape, self.dtype - - def replace(self, **kwargs): - """Creates a copy of the object with some parameters replaced.""" - return type(self)( - shape=kwargs.pop('shape', self.shape), - dtype=kwargs.pop('dtype', self.dtype), - ) - - -def signature(obj): - """Returns a `ShapeDtype` signature for the given `obj`. - - A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype` - instances. Note that this function is permissive with respect to its inputs - (accepts lists or tuples or dicts, and underlying objects can be any type - as long as they have shape and dtype attributes) and returns the corresponding - nested structure of `ShapeDtype`. - - Args: - obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict - of such objects. - - Returns: - A corresponding nested structure of `ShapeDtype` instances. - """ - if isinstance(obj, (list, tuple)): - output = tuple(signature(x) for x in obj) - return output if isinstance(obj, tuple) else list(output) - elif isinstance(obj, dict): - return {k: signature(v) for (k, v) in obj.items()} - else: - return ShapeDtype(obj.shape, obj.dtype) - - -def splice_signatures(*sigs): - """Creates a new signature by splicing together any number of signatures. - - The splicing effectively flattens the top level input signatures. For - instance, it would perform the following mapping: - - - `*sigs: sd1, (sd2, sd3, sd4), (), sd5` - - return: `(sd1, sd2, sd3, sd4, sd5)` - - Args: - *sigs: Any number of signatures. A signature is either a `ShapeDtype` - instance or a tuple of `ShapeDtype` instances. - - Returns: - A single `ShapeDtype` instance if the spliced signature has one element, - else a tuple of `ShapeDtype` instances. - """ - result_sigs = [] - for sig in sigs: - if isinstance(sig, (list, tuple)): - result_sigs.extend(sig) - else: - result_sigs.append(sig) - return result_sigs[0] if len(result_sigs) == 1 else tuple(result_sigs) - - -def assert_shape_equals(array, shape): - """Asserts that an array has the given shape.""" - assert array.shape == shape, ( - 'Invalid shape {}; expected {}.'.format(array.shape, shape) - ) - - -def assert_same_shape(array1, array2): - """Asserts that two arrays have the same shapes.""" - assert_shape_equals(array1, array2.shape) diff --git a/trax/shapes_test.py b/trax/shapes_test.py deleted file mode 100644 index 4266195e5..000000000 --- a/trax/shapes_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for trax.shapes."""
-from absl.testing import absltest
-import numpy as np
-
-from trax import shapes
-from trax.shapes import ShapeDtype
-
-
-class ShapesTest(absltest.TestCase):
-
-  def test_constructor_and_read_properties(self):
-    sd = ShapeDtype((2, 3), np.int32)
-    self.assertEqual(sd.shape, (2, 3))
-    self.assertEqual(sd.dtype, np.int32)
-
-  def test_default_dtype_is_float32(self):
-    sd = ShapeDtype((2, 3))
-    self.assertEqual(sd.shape, (2, 3))
-    self.assertEqual(sd.dtype, np.float32)
-
-  def test_signature_on_ndarray(self):
-    array = np.array([[2, 3, 5, 7],
-                      [11, 13, 17, 19]],
-                     dtype=np.int16)
-    sd = shapes.signature(array)
-    self.assertEqual(sd.shape, (2, 4))
-    self.assertEqual(sd.dtype, np.int16)
-
-  def test_shape_dtype_repr(self):
-    sd = ShapeDtype((2, 3))
-    repr_string = '{}'.format(sd)
-    self.assertEqual(repr_string,
-                     "ShapeDtype{shape:(2, 3), dtype:<class 'numpy.float32'>}")
-
-  def test_splice_signatures(self):
-    sd1 = ShapeDtype((1,))
-    sd2 = ShapeDtype((2,))
-    sd3 = ShapeDtype((3,))
-    sd4 = ShapeDtype((4,))
-    sd5 = ShapeDtype((5,))
-
-    # Signatures can be ShapeDtype instances, tuples of 2+ ShapeDtype instances,
-    # or empty tuples.
-    sig1 = sd1
-    sig2 = (sd2, sd3, sd4)
-    sig3 = ()
-    sig4 = sd5
-    spliced = shapes.splice_signatures(sig1, sig2, sig3, sig4)
-    self.assertEqual(spliced, (sd1, sd2, sd3, sd4, sd5))
-
-  def test_len_signature(self):
-    """Signatures of all sizes should give correct length when asked."""
-    x1 = np.array([1, 2, 3])
-    x2 = np.array([10, 20, 30])
-    inputs0 = ()
-    inputs1 = x1  # NOT in a tuple
-    inputs2 = (x1, x2)
-
-    sig0 = shapes.signature(inputs0)
-    sig1 = shapes.signature(inputs1)
-    sig2 = shapes.signature(inputs2)
-
-    # pylint: disable=g-generic-assert
-    self.assertEqual(len(sig0), 0)
-    self.assertEqual(len(sig1), 1)
-    self.assertEqual(len(sig2), 2)
-    # pylint: enable=g-generic-assert
-
-
-if __name__ == '__main__':
-  absltest.main()
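An illustrative aside on the module removed above: `shapes.signature` mirrors nested lists/tuples/dicts and turns array-like leaves into `ShapeDtype` instances. A minimal sketch, assuming a pre-deletion checkout:

```python
import numpy as np
from trax import shapes

x = np.zeros((2, 3), dtype=np.int32)
# Containers are mirrored; leaves become ShapeDtype(shape, dtype).
sig = shapes.signature({'obs': x, 'act': (x, x)})

assert sig['obs'] == shapes.ShapeDtype((2, 3), np.int32)
assert len(sig['act']) == 2
```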
- -"""Supervised learning imports in Trax.""" - -from trax.supervised import callbacks -from trax.supervised import decoding -from trax.supervised import lr_schedules -from trax.supervised import trainer_lib -from trax.supervised import training -from trax.supervised.trainer_lib import train -from trax.supervised.trainer_lib import Trainer -from trax.supervised.training import EvalTask -from trax.supervised.training import TrainTask diff --git a/trax/supervised/callbacks.py b/trax/supervised/callbacks.py deleted file mode 100644 index 9c9b826b9..000000000 --- a/trax/supervised/callbacks.py +++ /dev/null @@ -1,248 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Loop callbacks. - -Callbacks can be used to customize the behavior of `supervised.training.Loop` -to accomodate a variety of use-cases. - -Examples include: - - custom evaluation schemes - - logging metrics to external servers - - sending model checkpoints to external servers - - updating the target network in RL algorithms and other non-stationary - problems -""" - -import collections -import os - -import gin -import numpy as np - -from trax import jaxboard -from trax import layers as tl -from trax import shapes -from trax.rl import serialization_utils -from trax.supervised import decoding - - -class TrainingStepCallback: - """Callback triggered before and after a training step.""" - - def __init__(self, loop): - """Initializes the callback with a `supervised.training.Loop` instance.""" - self._loop = loop - - def call_at(self, step): - """Returns whether the callback should be called at a given step.""" - raise NotImplementedError - - def on_step_begin(self, step): - """Called by Loop before training steps, when call_at returned True.""" - raise NotImplementedError - - def on_step_end(self, step): - """Called by Loop after training steps, when call_at returned True.""" - raise NotImplementedError - - -@gin.configurable -class SerializedModelEvaluation(TrainingStepCallback): - """Evaluates serialized sequence prediction models. - - Example: time series prediction. We can serialize a time series into - a sequence of discrete tokens and model this sequence using an autoregressive - sequence model, such as Transformer - see - `trax.rl.serialization_utils.SerializedModel`. Then we can use this callback - to evaluate long-horizon predictions of such a model. - """ - - def __init__( - self, - loop, - model=None, - eval_at=1000, - eval_task=None, - context_lengths=(1,), - horizon_lengths=(1,), - n_steps=1, - accelerate_model=True, - ): - """Initializes SerializedModelEvaluation. - - Args: - loop: Instance of `trax.supervised.training.Loop` or `None`. Can be set to - `None` for testing - in such a case, `model` and `eval_task` must be - provided. - model: Instance of `trax.rl.serialization_utils.SerializedModel`. Not - required if `loop` is provided. - eval_at: When to evaluate. 
-
-
-@gin.configurable
-class SerializedModelEvaluation(TrainingStepCallback):
-  """Evaluates serialized sequence prediction models.
-
-  Example: time series prediction. We can serialize a time series into
-  a sequence of discrete tokens and model this sequence using an autoregressive
-  sequence model, such as Transformer - see
-  `trax.rl.serialization_utils.SerializedModel`. Then we can use this callback
-  to evaluate long-horizon predictions of such a model.
-  """
-
-  def __init__(
-      self,
-      loop,
-      model=None,
-      eval_at=1000,
-      eval_task=None,
-      context_lengths=(1,),
-      horizon_lengths=(1,),
-      n_steps=1,
-      accelerate_model=True,
-  ):
-    """Initializes SerializedModelEvaluation.
-
-    Args:
-      loop: Instance of `trax.supervised.training.Loop` or `None`. Can be set to
-        `None` for testing - in such a case, `model` and `eval_task` must be
-        provided.
-      model: Instance of `trax.rl.serialization_utils.SerializedModel`. Not
-        required if `loop` is provided.
-      eval_at: When to evaluate. Either int (every how many steps to evaluate),
-        or a list of ints (step numbers), or a function int -> bool (step
-        predicate).
-      eval_task: Instance of `trax.supervised.training.EvalTask` with the
-        evaluation data, or None. If not provided, the task will be taken from
-        `loop`.
-      context_lengths: List of lengths of the context sequence fed into the
-        model before starting prediction.
-      horizon_lengths: List of lengths of the predicted sequence.
-      n_steps: Number of batches to run evaluation for.
-      accelerate_model (bool): Whether to wrap the model in `tl.Accelerate`.
-    """
-    super().__init__(loop)
-
-    if model is None:
-      model = loop.model
-
-    observation_serializer = model.observation_serializer
-    action_serializer = model.action_serializer
-
-    predict_model = model.make_predict_model()
-    if accelerate_model:
-      predict_model = tl.Accelerate(predict_model)
-    self._predict_model = predict_model
-    self._obs_serializer = observation_serializer
-    self._act_serializer = action_serializer
-
-    if isinstance(eval_at, int):
-      self._eval_at = lambda step: step % eval_at == 1
-    elif hasattr(eval_at, '__contains__'):  # e.g. a list or set of steps
-      self._eval_at = lambda step: step in eval_at
-    elif callable(eval_at):
-      self._eval_at = eval_at
-    else:
-      raise TypeError(f'Unsupported type for eval_at: {type(eval_at)}.')
-
-    if eval_task is None:
-      if len(loop.eval_tasks) != 1:
-        raise ValueError(
-            'If eval_task is not provided, the number of eval_tasks registered '
-            'in Loop must be exactly 1.'
-        )
-      eval_task = loop.eval_tasks[0]
-    self._eval_task = eval_task
-
-    self._context_lengths = list(sorted(context_lengths))
-    self._horizon_lengths = list(sorted(horizon_lengths))
-    self._n_steps = n_steps
-
-    self._batch_size = eval_task.sample_batch[0].shape[0]
-    (_, self._init_state) = predict_model.init(
-        shapes.ShapeDtype((self._batch_size, 1), dtype=np.int32)
-    )
-
-  @property
-  def predict_model(self):
-    return self._predict_model
-
-  def call_at(self, step):
-    return self._eval_at(step)
-
-  def on_step_begin(self, step):
-    pass
-
-  def on_step_end(self, step):
-    summary_writer = jaxboard.SummaryWriter(
-        os.path.join(self._loop.output_dir, 'srl_eval')
-    )
-    try:
-      weights = self._loop.eval_model.seq_model_weights
-      metrics = self.evaluate(weights)
-      self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
-    finally:
-      summary_writer.close()
-
-  def evaluate(self, weights):
-    """Evaluates the model and returns the metrics."""
-    self._predict_model.weights = weights
-
-    metrics = collections.defaultdict(list)
-    for _ in range(self._n_steps):
-      batch = self._eval_task.next_batch()
-      step_metrics = self._evaluate_batch(batch)
-      for (key, value) in step_metrics.items():
-        metrics[key].append(value)
-
-    metrics = {k: np.array(v) for (k, v) in metrics.items()}
-
-    def metric_name(context, horizon):
-      return f'pred_error/context_{context}/horizon_{horizon}'
-
-    return {
-        metric_name(context, horizon):
-            np.sum(errors) / (np.sum(errors != 0) + 1e-6)
-        for ((context, horizon), errors) in metrics.items()
-    }
-
-  def _evaluate_batch(self, batch):
-    """Performs evaluation on a single batch."""
-    (obs, act, _, mask) = batch
-    obs_repr = serialization_utils.Serialize(self._obs_serializer)(obs)
-    act_repr = serialization_utils.Serialize(self._act_serializer)(act)
-
-    errors = {}
-    last_context = 0
-    last_state = self._init_state
-    last_start_id = 0
-    for context in self._context_lengths:
-      self._predict_model.state = last_state
-      start_id = last_start_id
-
-      if context > last_context:
-        context_seq = serialization_utils.Interleave()((
obs_repr[:, last_context:context], act_repr[:, last_context:context] - )) - consume_sequence(self._predict_model, start_id, context_seq[:, :-1]) - last_start_id = start_id = context_seq[:, -1:] - last_state = self._predict_model.state - last_context = context - - for timestep in range(max(self._horizon_lengths)): - pred_repr = decoding.autoregressive_sample( - self._predict_model, - start_id=start_id, - eos_id=-1, - batch_size=self._batch_size, - max_length=self._obs_serializer.representation_length, - accelerate=False, - ) - horizon = timestep + 1 - if horizon in self._horizon_lengths: - pred = self._obs_serializer.deserialize(pred_repr) - error = self._calculate_error(pred, obs[:, context + timestep]) - errors[context, horizon] = error * mask[:, context + timestep] - - start_id = pred_repr[:, -1:] - consume_sequence( - self._predict_model, start_id, act_repr[:, context + timestep, :-1] - ) - start_id = act_repr[:, context + timestep, -1:] - - return errors - - def _calculate_error(self, prediction, ground_truth): - return (prediction - ground_truth) ** 2 - - -def consume_sequence(model, start_id, sequence): - decoding.autoregressive_sample( - model, - start_id=start_id, - eos_id=-1, - inputs=sequence, - batch_size=sequence.shape[0], - max_length=1, - accelerate=False, - ) diff --git a/trax/supervised/callbacks_test.py b/trax/supervised/callbacks_test.py deleted file mode 100644 index 3eaf328f8..000000000 --- a/trax/supervised/callbacks_test.py +++ /dev/null @@ -1,226 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
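An illustrative aside on `_evaluate_batch` above: the serialized observations and actions are interleaved along the time axis before being fed to the predict model. A minimal NumPy sketch of that interleaving (not the trax `Interleave` layer itself):

```python
import numpy as np

obs_repr = np.array([[1, 2, 3]])  # (batch, time) observation tokens
act_repr = np.array([[7, 8, 9]])  # (batch, time) action tokens

# Alternate o_0, a_0, o_1, a_1, ... along the time axis.
interleaved = np.stack([obs_repr, act_repr], axis=2).reshape(1, -1)
assert interleaved.tolist() == [[1, 7, 2, 8, 3, 9]]
```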
- -"""Tests for trax.supervised.callbacks.""" - -import functools -import io -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import gym -import numpy as np - -from trax import models -from trax import test_utils -from trax.data import inputs -from trax.layers import test_utils as tl_test_utils -from trax.rl import serialization_utils -from trax.rl import space_serializer -from trax.supervised import callbacks -from trax.supervised import lr_schedules -from trax.supervised import trainer_lib -from trax.supervised import training - - -def random_inputs(seq_len, batch_size): - def stream_fn(num_devices): - del num_devices - while True: - x = np.random.uniform(size=(batch_size, seq_len)) - y = np.random.uniform(size=(batch_size, seq_len)) - mask = np.ones_like(x).astype(np.float32) - yield (x, y, x, mask) - - return inputs.Inputs( - train_stream=stream_fn, - eval_stream=stream_fn, - ) - - -def make_multibonacci_modulo(history_length, limit): - """Creates a function that generates the Multibonacci sequence modulo n.""" - def sequence_fn(seq): - return np.sum(seq[-history_length:]) % limit - return sequence_fn - - -def generate_trajectory(sequence_fn, space, n_steps): - """Generates random actions and observations that follow sequence_fn.""" - act = [space.sample() for _ in range(n_steps)] - obs = [space.sample()] - - for (o, a) in zip( - obs, - act[:-1], # Don't generate the last observation. - ): - context = list(np.array([o, a]).flatten()) - symbols = [] - for _ in range(np.array(o).size): - symbol = sequence_fn(context + symbols) - symbols.append(symbol) - obs.append(np.reshape(symbols, space.shape)) - - obs = np.array([obs]) - act = np.array([act]) - return (obs, act) - - -def make_singleton_eval_task(observations, actions): - """Creates an EvalTask with just one example.""" - mask = np.ones(observations.shape[:2]) - def data(): - while True: - yield (observations, actions, observations, mask) - - return training.EvalTask( - labeled_data=data(), - metrics=[], - ) - - -def make_serialized_model(seq_model, space, vocab_size): - srl = space_serializer.create(space, vocab_size) - return serialization_utils.SerializedModel( - functools.partial(seq_model, vocab_size=vocab_size), - observation_serializer=srl, - action_serializer=srl, - significance_decay=0.7, - ) - - -class CallbacksTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - - @mock.patch('sys.stdout', new_callable=io.StringIO) - def test_serialized_model_evaluation(self, mock_stdout): - precision = 1 - vocab_size = 2 - srl = space_serializer.BoxSpaceSerializer( - space=gym.spaces.Box(shape=(), low=0.0, high=1.0), - vocab_size=vocab_size, - precision=precision, - ) - - def inner_model(mode): - return models.TransformerLM( - mode=mode, - vocab_size=vocab_size, - d_model=2, - d_ff=4, - n_layers=1, - n_heads=1, - ) - - serialized_model_fn = functools.partial( - serialization_utils.SerializedModel, - inner_model, - observation_serializer=srl, - action_serializer=srl, - significance_decay=0.7, - ) - eval_callback = functools.partial( - callbacks.SerializedModelEvaluation, eval_at=5 - ) - - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir=output_dir, - model=serialized_model_fn, - inputs=functools.partial(random_inputs, seq_len=4, batch_size=64), - lr_schedule_fn=functools.partial(lr_schedules.constant, 0.01), - callbacks=[eval_callback], - steps=10, - ) - self.assertTrue(_has_metric('pred_error', 
mock_stdout)) - - @parameterized.product( - context_lengths=((2,), (1, 3)), - horizon_lengths=((1,), (1, 2)), - ) - def test_srl_eval_feeds_correct_sequence( - self, context_lengths, horizon_lengths - ): - vocab_size = 10 - n_steps = 5 - - multibonacci_modulo = make_multibonacci_modulo(2, vocab_size) - space = gym.spaces.Discrete(n=vocab_size) - (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) - eval_task = make_singleton_eval_task(obs, act) - seq_model = functools.partial( - tl_test_utils.MockTransformerLM, - sequence_fn=multibonacci_modulo, - ) - serialized_model = make_serialized_model(seq_model, space, vocab_size) - callback = callbacks.SerializedModelEvaluation( - loop=None, - eval_task=eval_task, - model=serialized_model, - context_lengths=context_lengths, - horizon_lengths=horizon_lengths, - accelerate_model=False, - ) - callback.evaluate(weights=None) - - expected_seq = np.zeros(2 * n_steps + 1) - expected_seq[1::2] = obs - expected_seq[2::2] = act - seen_len = (context_lengths[-1] + horizon_lengths[-1]) * 2 - callback.predict_model.assert_prediction_buffers_equal( - [expected_seq[:seen_len]] - ) - - @parameterized.named_parameters(('one_symbol', 1), ('two_symbols', 2)) - def test_srl_eval_reports_zero_error_for_perfect_model(self, precision): - vocab_size = 100 - n_steps = 5 - - multibonacci_modulo = make_multibonacci_modulo(2 * precision, vocab_size) - space = gym.spaces.MultiDiscrete(nvec=([vocab_size] * precision)) - (obs, act) = generate_trajectory(multibonacci_modulo, space, n_steps) - eval_task = make_singleton_eval_task(obs, act) - seq_model = functools.partial( - tl_test_utils.MockTransformerLM, - sequence_fn=multibonacci_modulo, - ) - serialized_model = make_serialized_model(seq_model, space, vocab_size) - callback = callbacks.SerializedModelEvaluation( - loop=None, - eval_task=eval_task, - model=serialized_model, - context_lengths=(1,), - horizon_lengths=(4,), - accelerate_model=False, - ) - metrics = callback.evaluate(weights=None) - error = next( - value for (name, value) in metrics.items() if 'pred_error' in name - ) - assert error == 0 - - -def _has_metric(metric_name, stdout): - log = stdout.getvalue() - metric_logs = [line for line in log.split('\n') if metric_name in line] - return bool(metric_logs) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/decoding.py b/trax/supervised/decoding.py deleted file mode 100644 index d8902c1bc..000000000 --- a/trax/supervised/decoding.py +++ /dev/null @@ -1,263 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Decoding with Trax models.""" - -import numpy as np -from trax import fastmath -from trax import layers as tl - - -def autoregressive_sample_stream(model, inputs=None, - batch_size=1, temperature=1.0, - start_id=0, accelerate=True, - eval_mode=False, eval_min_length=1): - """Yields samples from `model`, in autoregressive language model fashion. 
- - This function uses `model` to generate outputs one position at a time, with - access to inputs for the current position and all preceding positions. The - new output becomes the next position's input, and further calls to - `autoregressive_sample_stream` repeat the process for successive positions - indefinitely. - - Inputs and outputs always come in batches, even if size 1. If `inputs` is - present, it must have shape (`batch_size`, inputs_sequence_length), and each - output in the stream has shape (`batch_size`, 1). - - Args: - model: A layer object (subclass of `trax.layers.Layer`) created in - `'predict'` mode and initialized from trained weights. The model - must have a structure that allows it to run as an autoregressive - one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`), - except if `eval_mode` is set -- any model can be sampled then, - but the sampling process may be much slower. - inputs: Sequence of symbols the model sees as input the first time it - generates an output. If None, the model generates the first output - based on just the start symbol. - batch_size: Number of sequences to generate in parallel as a batch. - temperature: Parameter that controls the sharpness of the softmax that - feeds the sampling process. Values range from 0.0 (all probability mass - goes to one candidate; like an argmax) to positive infinity (all - candidates have equal probability). - start_id: Integer representing the start symbol for the autoregressive - process, or array of shape (`batch_size`, 1) of such integers. - accelerate: If True, create an accelerated version of `model` and use it - for generating outputs. - eval_mode: If True, assume the model is created in `eval` mode and sample - by collecting all previous outputs and passing the whole tensor. - eval_min_length: If set, the minimum length to pad to in eval mode. - - Yields: - Tensor of integers with shape (`batch_size`, 1), representing the batch of - outputs for the next position in the stream. - """ - if inputs is not None and inputs.shape[0] != batch_size: - raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match ' - f'batch_size arg ({batch_size}.') - - fast_model = tl.Accelerate(model) if accelerate else model - if np.isscalar(start_id): - start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) - else: - start_symbol = start_id - if model.n_in == 1 and inputs is not None: - current_symbols = np.concatenate([start_symbol, inputs], axis=1) - else: - current_symbols = start_symbol - - if eval_mode: - # no start symbol needed in eval mode - current_symbols = current_symbols[:, 1:] - - while True: - # Pad inputs to power-of-2 length if needed. - if eval_mode: - # one extra symbol as an initial one will be added - l = max(eval_min_length, current_symbols.shape[1] + 1) - pad_len = int(2**np.ceil(np.log2(l))) - current_symbols.shape[1] - unpadded_symbols = current_symbols - current_symbols = np.pad( - current_symbols, [[0, 0], [0, pad_len]], mode='constant') - last_index = -pad_len # no -1 as the starting one will be added - else: - last_index = -1 - # Run the model. 
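An illustrative aside on the `temperature` parameter documented above: one standard way to implement these semantics is the Gumbel-max trick, sketched below in plain NumPy (the actual `tl.logsoftmax_sample` implementation may differ in details):

```python
import numpy as np

def sample_from_logits(log_probs, temperature=1.0, rng=np.random):
  """Gumbel-max sampling with temperature (editorial sketch)."""
  if temperature == 0.0:
    return np.argmax(log_probs, axis=-1)  # temperature 0 is an argmax
  u = rng.uniform(low=1e-9, high=1.0, size=log_probs.shape)
  gumbel = -np.log(-np.log(u))
  # argmax(logits + T * g) samples from softmax(logits / T).
  return np.argmax(log_probs + gumbel * temperature, axis=-1)
```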
- if model.n_in > 1 and inputs is not None: - logits = fast_model((inputs, current_symbols))[0] - else: - logits = fast_model(current_symbols) - logits = tl.log_softmax(logits[:, last_index, :]) - sample = tl.logsoftmax_sample(logits, temperature=temperature) - yield sample - if eval_mode: - current_symbols = np.concatenate( - [unpadded_symbols, sample[:, None]], axis=1) - else: - # NOTE: Because the model is autoregressive and in 'predict' mode, its - # history is cached in the model state and the next input is the single - # symbol just sampled. - current_symbols = sample[:, None] - - -def autoregressive_sample(model, inputs=None, - batch_size=1, temperature=1.0, - start_id=0, eos_id=1, max_length=100, - accelerate=True, eval_mode=False, eval_min_length=1): - """Returns a batch of sequences created by autoregressive sampling. - - This function uses `model` to generate outputs one position at a time, with - access to inputs for the current position and all preceding positions. The - new output becomes the next position's input, and this loop repeats until - either the model outputs the `eos_id` value or the output sequence reaches - `max_length` items. - - Args: - model: A layer object (subclass of `trax.layers.Layer`) created in - `'predict'` mode and initialized from trained weights. The model - must have a structure that allows it to run as autoregressive - one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`), - except if `eval_mode` is set -- any model can be sampled then, - but the sampling process may be much slower. - inputs: Sequence of symbols the model sees as input the first time it - generates an output. If None, the model must generate the first output - with no input to guide it. - batch_size: Number of sequences to generate in parallel as a batch. - temperature: Parameter that controls the sharpness of the softmax that - feeds the sampling process. Values range from 0.0 (all probability mass - goes to one candidate; like an argmax) to positive infinity (all - candidates have equal probability). - start_id: The start symbol (ID/integer) for the autoregressive process, - or array of shape (`batch_size`, 1) of such integers. - eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive - process. - max_length: Maximum length for generated sequences. - accelerate: If True, create an accelerated version of `model` and use it - for generating outputs. - eval_mode: If True, assume the model is created in `eval` mode and sample - by collecting all previous outputs and passing the whole tensor. - eval_min_length: If set, the minimum length to pad to in eval mode. - - Returns: - Tensor of integers with shape (`batch_size`, output_length) representing - a batch of output sequences. output_length is the maximum length of the - output sequences, where each sequence can be no longer than `max_length`. - """ - result = [] - eos_seen = [] - counter = 0 - for sample in autoregressive_sample_stream( - model, inputs, batch_size=batch_size, temperature=temperature, - start_id=start_id, accelerate=accelerate, eval_mode=eval_mode, - eval_min_length=eval_min_length): - sample = sample[:, None] - result.append(sample) - counter += 1 - if counter >= max_length: - return np.concatenate(result, axis=1) - # Check at which batch positions have we already encountered EOS. - for j in range(batch_size): - if int(sample[j, 0]) == eos_id: - eos_seen.append(j) - # If EOS has been seen on all positions, stop. 
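An illustrative aside: end-to-end usage of `autoregressive_sample` above, assembled from the decoding tests deleted later in this diff (pre-deletion checkout assumed):

```python
import numpy as np
from trax import models, shapes
from trax.supervised import decoding

# A small language model in 'predict' mode, initialized from scratch.
model = models.TransformerLM(
    10, d_model=32, d_ff=64, n_layers=1, n_heads=2, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))

# eos_id=-1 disables early stopping, so exactly max_length tokens return.
tokens = decoding.autoregressive_sample(
    model, batch_size=1, eos_id=-1, max_length=10)
assert tokens.shape == (1, 10)
```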
- if all([j in eos_seen for j in range(batch_size)]): - return np.concatenate(result, axis=1) - return np.concatenate(result, axis=1) - - -def beam_search(model, inputs=None, batch_size=1, n_beams=2, start_id=0, - eos_id=1, max_length=100, length_penalty=1.0, accelerate=True): - """Returns a batch of n_beams-sequences created by beam search. - - This function uses `model` to generate outputs one position at a time, with - access to inputs for the current position and all preceding positions. The - new output becomes the next position's input, and this loop repeats until - either the model outputs the `eos_id` value or the output sequence reaches - `max_length` items -- but keeping n_beams top beams. - - Args: - model: A layer object (subclass of `trax.layers.Layer`) created in - `'predict'` mode and initialized from trained weights. The model - must have a structure that allows it to run as autoregressive - one-sample-at-a-time predictor (e.g., `trax.models.TransformerLM`). - inputs: Sequence of symbols the model sees as input the first time it - generates an output. If None, the model must generate the first output - with no input to guide it. - batch_size: Number of sequences to generate in parallel as a batch. - n_beams: How many beams to consider at the same time. - start_id: The start symbol (ID/integer) for the autoregressive process, - or array of shape (`batch_size`, 1) of such integers. - eos_id: The end-of-sequence symbol (ID/integer) for the autoregressive - process. - max_length: Maximum length for generated sequences. - length_penalty: Factor alpha in calculating the length penalty for beams. - accelerate: If True, create an accelerated version of `model` and use it - for generating outputs. - - Returns: - Tensor of integers with shape (`batch_size`, n_beams, output_length) with - a batch of output sequences. output_length is the maximum length of the - output sequences, where each sequence can be no longer than `max_length`. - """ - del eos_id, length_penalty # TODO(lukaszkaiser): add length penalty, eos - assert batch_size == 1, 'Batch size > 1 not supported yet' - if inputs is not None and inputs.shape[0] != batch_size: - raise ValueError(f'Inputs batch size ({inputs.shape[0]}) does not match ' - f'batch_size arg ({batch_size}.') - - fast_model = tl.Accelerate(model) if accelerate else model - if np.isscalar(start_id): - start_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) - else: - start_symbol = start_id - if model.n_in == 1 and inputs is not None: - current_symbols = np.concatenate([start_symbol, inputs], axis=1) - else: - current_symbols = start_symbol - - beams = [current_symbols for _ in range(n_beams)] - results = [([], 0.0) for _ in range(n_beams)] - states = [fast_model.state for _ in range(n_beams)] - top_k = [None] * n_beams - counter = 0 - while counter < max_length: - counter += 1 - # Run the model on all beams, collect states and top_k for each beam. - for beam_id in range(n_beams if counter > 1 else 1): - fast_model.state = states[beam_id] - if model.n_in > 1 and inputs is not None: - logits = fast_model((inputs, beams[beam_id]))[0] - else: - logits = fast_model(beams[beam_id]) - logits = tl.log_softmax(logits[:, -1, :]) - states[beam_id] = fast_model.state - top_k[beam_id] = fastmath.top_k(logits, k=n_beams) - - # Select new beams. 
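An illustrative aside on the beam bookkeeping that follows: every candidate is a (cumulative log-prob, beam id, symbol) triple and only the best `n_beams` survive each step. A scalar sketch (the real code sorts on `x[0][0]` because its log-probs carry a batch dimension):

```python
n_beams = 2
# (cumulative log-prob, source beam, proposed symbol)
candidates = [(-0.1, 0, 7), (-0.9, 0, 3), (-0.4, 1, 7), (-1.5, 1, 2)]
candidates.sort(key=lambda x: -x[0])  # highest total log-prob first
best = candidates[:n_beams]
assert best == [(-0.1, 0, 7), (-0.4, 1, 7)]
```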
- cur_values = [] # will hold triples (sum-of-logprobs, beam-id, symbol) - for beam_id in range(n_beams if counter > 1 else 1): - for k in range(n_beams): - values, symbols = top_k[beam_id] - value, symbol = values[:, k], symbols[:, k] - cur_values.append((results[beam_id][1] + value, beam_id, symbol)) - cur_values.sort(key=lambda x: -x[0][0]) # x[0][0] as batch_size=1 - # Collect top beams to the new states and results. - new_results, new_states, new_beams = [], [], [] - for (value, beam_id, symbol) in cur_values[:n_beams]: - new_results.append((results[beam_id][0] + [symbol], value)) - new_states.append(states[beam_id]) # copy? - new_beams.append(symbol[:, None]) - results, states, beams = new_results, new_states, new_beams - - return [(np.stack(r, axis=-1), v) for (r, v) in results] diff --git a/trax/supervised/decoding_test.py b/trax/supervised/decoding_test.py deleted file mode 100644 index afaad725c..000000000 --- a/trax/supervised/decoding_test.py +++ /dev/null @@ -1,453 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for decoding.""" - -import functools -import os - -import gin -from jax.config import config -import numpy as np -from tensorflow.compat.v2 import test - -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import shapes -from trax.supervised import decoding - - -pkg_dir, _ = os.path.split(__file__) -_TESTDATA = os.path.join(pkg_dir, 'testdata') -_CONFIG_DIR = os.path.join(pkg_dir, 'configs/') - - -class DecodingTest(test.TestCase): - - def test_autoregressive_sample_transformerlm(self): - model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1, - n_heads=2, mode='predict') - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - s1 = decoding.autoregressive_sample( - model, batch_size=1, eos_id=-1, max_length=10) - self.assertEqual(s1.shape[0], 1) - self.assertEqual(s1.shape[1], 10) - batch_per_device = 2 // fastmath.local_device_count() - model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) - s2 = decoding.autoregressive_sample( - model, batch_size=2, max_length=10) - self.assertEqual(s2.shape[0], 2) - self.assertLess(s2.shape[1], 11) - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - prefix = np.array([[1, 2, 3]]) - s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, max_length=10, - batch_size=1) - self.assertEqual(s3.shape[0], 1) - self.assertEqual(s3.shape[1], 10) - - def test_autoregressive_sample_transformerlm_tfnp(self): - with fastmath.use_backend(fastmath.Backend.TFNP): - model = models.TransformerLM(10, d_model=32, d_ff=64, n_layers=1, - n_heads=2, mode='predict') - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - s1 = decoding.autoregressive_sample( - model, batch_size=1, eos_id=-1, max_length=10) - self.assertEqual(s1.shape[0], 1) - self.assertEqual(s1.shape[1], 10) - batch_per_device = 2 // fastmath.local_device_count() - model.init(shapes.ShapeDtype((batch_per_device, 1), dtype=np.int32)) - s2 = 
decoding.autoregressive_sample( - model, batch_size=2, max_length=10) - self.assertEqual(s2.shape[0], 2) - self.assertLess(s2.shape[1], 11) - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - prefix = np.array([[1, 2, 3]]) - s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, - max_length=10, batch_size=1) - self.assertEqual(s3.shape[0], 1) - self.assertEqual(s3.shape[1], 10) - - def _lsh_self_attention_fn(self): - return functools.partial( - tl.LSHSelfAttention, - attention_dropout=0.0, - chunk_len=64, - n_buckets=[32, 32], - n_chunks_after=0, - n_chunks_before=1, - n_hashes=1, - n_parallel_heads=1, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def _pure_lsh_self_attention_fn(self, n_chunks_after=0): - return functools.partial( - tl.PureLSHSelfAttentionWrapper, - attention_dropout=0.0, - chunk_len=16, - n_buckets=[32, 32], - n_chunks_after=n_chunks_after, - n_chunks_before=1, - n_hashes=2, - n_parallel_heads=1, - max_length_for_buckets=1024, - predict_drop_len=128, - predict_mem_len=1024, - num_weights=2, - bias=False, - pure_lsh_implementation=tl.PureLSHSelfAttention, - ) - - def _timebin_self_attention_fn(self, use_reference_code=False, chunk_len=64): - return functools.partial( - tl.SelfAttention, - attention_dropout=0.05, - chunk_len=chunk_len, - n_chunks_before=1, - n_parallel_heads=1, - use_reference_code=use_reference_code, - predict_drop_len=128, - predict_mem_len=1024, - ) - - def test_autoregressive_sample_reformerlm(self): - lsh_self_attention = self._lsh_self_attention_fn() - timebin_self_attention = self._timebin_self_attention_fn() - - model = models.ReformerLM(vocab_size=256, - d_model=256, - d_ff=512, - d_attention_key=128, - d_attention_value=128, - n_layers=2, - n_heads=2, - dropout=0.05, - max_len=65536, - attention_type=[timebin_self_attention, - lsh_self_attention], - pos_axial_shape=(256, 256), - pos_d_axial_embs=(128, 128), - ff_activation=tl.Relu, - ff_use_sru=0, - mode='predict', - ) - model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) - s1 = decoding.autoregressive_sample( - model, batch_size=1, eos_id=-1, max_length=10) - self.assertEqual(s1.shape[0], 1) - self.assertEqual(s1.shape[1], 10) - - def test_autoregressive_sample_transformer(self): - model = models.Transformer(10, d_model=32, d_ff=64, n_encoder_layers=1, - n_decoder_layers=1, n_heads=2, mode='predict') - inputs = np.ones((1, 3), dtype=np.int32) - model.init((shapes.signature(inputs), - shapes.ShapeDtype((1, 1), dtype=np.int32))) - s = decoding.autoregressive_sample(model, inputs=inputs, - eos_id=-1, max_length=10) - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_transformerlm_quality(self): - pred_model = models.TransformerLM( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_layers=2, vocab_size=13, mode='predict') - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape11, shape11)) - inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.autoregressive_sample(pred_model, inputs, - max_length=6, temperature=0.0) - self.assertEqual(str(s[0]), '[3 7 5 3 2 4]') - - def test_autoregressive_sample_transformerlm_quality_eval(self): - eval_model = models.TransformerLM( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_layers=2, vocab_size=13, mode='eval') - model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz') - 
eval_model.init_from_file(model_path) - inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.autoregressive_sample(eval_model, inputs, eval_mode=True, - max_length=6, temperature=0.0) - self.assertEqual(str(s[0]), '[3 7 5 3 2 4]') - - def test_autoregressive_sample_transformerlm_quality_beam(self): - pred_model = models.TransformerLM( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_layers=2, vocab_size=13, mode='predict') - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model_path = os.path.join(_TESTDATA, 'transformerlm_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape11, shape11)) - inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.beam_search(pred_model, inputs, n_beams=3, max_length=6) - self.assertEqual(len(s), 3) # 3 beams - self.assertEqual(str(s[0][0][0]), '[3 7 5 3 2 4]') - self.assertEqual(str(s[1][0][0]), '[3 7 5 3 2 2]') # different from above - self.assertEqual(str(s[2][0][0]), '[3 7 5 3 3 2]') # different from above - - def test_autoregressive_sample_transformer_quality(self): - pred_model = models.Transformer( - d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2, - n_encoder_layers=2, n_decoder_layers=2, input_vocab_size=13, - mode='predict') - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - model_path = os.path.join(_TESTDATA, 'transformer_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape11, shape11)) - inputs = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32) - s = decoding.autoregressive_sample(pred_model, inputs=inputs, - eos_id=1, max_length=10, temperature=0.0) - self.assertEqual(str(s[0]), '[3 7 5 3 2 4 1]') - - def test_autoregressive_sample_terraformer_lsh(self): - max_len = 128 - - pred_model = models.ConfigurableTerraformer( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=max_len, - n_heads=4, - n_encoder_layers=1, - n_decoder_layers=1, - ff_use_sru=1, - d_attention_key=64, - d_attention_value=64, - encoder_attention_type=self._lsh_self_attention_fn(), - encoder_decoder_attention_type=self._lsh_self_attention_fn(), - input_vocab_size=256, - pos_axial_shape=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - pred_model.init(input_signature=(shape1l, shape11)) - - # 0w0w - inputs = np.array( - [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], - dtype=np.int32) - inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])], - mode='constant', constant_values=0) - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0) - - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_terraformer_lsh_attn_quality(self): - gin.add_config_file_search_path(_CONFIG_DIR) - max_len = 32 # 32 is the max length we trained the checkpoint for. - test_lengths = [8, 16, 32] - vocab_size = 13 - # The checkpoint is correct on ~90% sequences, set random seed to deflake. 
- np.random.seed(0) - for test_len in test_lengths: - gin.clear_config() - gin.parse_config_file('terraformer_copy.gin') - gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len) - gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len) - - pred_model = models.ConfigurableTerraformer(mode='predict') - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - - model_path = os.path.join(_TESTDATA, 'terraformer_copy_lsh_attn.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape1l, shape11)) - initial_state = pred_model.state - - for _ in range(2): # Set low to make the test run reasonably fast. - # Pick a length in [1, test_len] at random. - inp_len = np.random.randint(low=1, high=test_len + 1) - inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, max_len)) - # TODO(jaszczur): properly fix padding in terraformer predict mode, - # and add a test here. - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=inp_len, - temperature=0.0) - np.testing.assert_equal(s[0], inputs[0, :inp_len]) - pred_model.state = initial_state - gin.clear_config() # Make sure to not affect other tests. - - def test_autoregressive_sample_reformerlm_lsh(self): - max_len = 32 - - pred_model = models.ReformerLM( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=2 * max_len, - n_heads=4, - n_layers=3, - ff_use_sru=0, - d_attention_key=64, - d_attention_value=64, - attention_type=functools.partial(tl.LSHSelfAttention, - chunk_len=16, - n_hashes=2, - n_buckets=[32, 32], - predict_drop_len=max_len, - predict_mem_len=max_len, - max_length_for_buckets=1024), - vocab_size=13, - pos_type='fixed-base', - pos_d_axial_embs=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - pred_model.init(shape11) - - # 0w0 - inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) - inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])], - mode='constant', constant_values=0) - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0) - - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_reformerlm_lsh_quality(self): - max_len = 32 - - pred_model = models.ReformerLM( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=2 * max_len, - n_heads=4, - n_layers=3, - ff_use_sru=0, - d_attention_key=64, - d_attention_value=64, - attention_type=functools.partial(tl.LSHSelfAttention, - chunk_len=16, - n_hashes=2, - n_buckets=[32, 32], - predict_drop_len=max_len, - predict_mem_len=max_len, - max_length_for_buckets=1024), - vocab_size=13, - pos_type='fixed-base', - pos_d_axial_embs=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - - model_path = os.path.join( - _TESTDATA, 'reformerlm_copy_lsh_attn.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=shape11) - - # 0w0 - inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32) - inp_len = inputs.shape[1] - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=inp_len-2, - temperature=0.0) - - np.testing.assert_equal(s[0], inputs[0, 1:inp_len-1]) - # pylint: enable=unreachable - - def test_autoregressive_sample_terraformer_pure_lsh(self): - max_len = 128 - - pred_model = models.ConfigurableTerraformer( - mode='predict', - d_model=256, - d_ff=512, - dropout=0.05, - max_len=max_len, - n_heads=4, - 
n_encoder_layers=1, - n_decoder_layers=1, - ff_use_sru=1, - d_attention_key=64, - d_attention_value=64, - encoder_attention_type=self._pure_lsh_self_attention_fn( - n_chunks_after=1), - encoder_decoder_attention_type=self._pure_lsh_self_attention_fn(), - input_vocab_size=256, - pos_axial_shape=None, - ) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - pred_model.init(input_signature=(shape1l, shape11)) - - # 0w0w - inputs = np.array( - [[0, 3, 7, 5, 3, 2, 4, 1, 8, 0, 3, 7, 5, 3, 2, 4, 1, 8]], - dtype=np.int32) - inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])], - mode='constant', constant_values=0) - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=10, temperature=0.0) - - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 10) - - def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self): - gin.add_config_file_search_path(_CONFIG_DIR) - max_len = 32 # 32 is the max length we trained the checkpoint for. - test_lengths = [8, 16, 32] - vocab_size = 13 - # The checkpoint is correct on ~90% sequences, set random seed to deflake. - np.random.seed(0) - for test_len in test_lengths: - gin.clear_config() - gin.parse_config_file('terraformer_purelsh_copy.gin') - gin.bind_parameter('PureLSHSelfAttention.predict_mem_len', 2 * max_len) - gin.bind_parameter('PureLSHSelfAttention.predict_drop_len', 2 * max_len) - gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False) - gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2) - - pred_model = models.ConfigurableTerraformer(mode='predict') - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - - model_path = os.path.join(_TESTDATA, 'terraformer_purelsh_copy.pkl.gz') - pred_model.init_from_file(model_path, weights_only=True, - input_signature=(shape1l, shape11)) - initial_state = pred_model.state - - for _ in range(2): # Set low to make the test run reasonably fast. - # Pick a length in [1, test_len] at random. - inp_len = np.random.randint(low=1, high=test_len + 1) - inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, max_len)) - # TODO(jaszczur): properly fix padding in terraformer predict mode, - # and add a test here. - s = decoding.autoregressive_sample( - pred_model, inputs=inputs, eos_id=-1, max_length=inp_len, - temperature=0.0) - - np.testing.assert_equal(s[0], inputs[0, :inp_len]) - pred_model.state = initial_state - gin.clear_config() # Make sure to not affect other tests. - - -if __name__ == '__main__': - config.config_with_absl() - test.main() diff --git a/trax/supervised/decoding_timing_test.py b/trax/supervised/decoding_timing_test.py deleted file mode 100644 index 48faf156e..000000000 --- a/trax/supervised/decoding_timing_test.py +++ /dev/null @@ -1,439 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Timing tests for decoding.""" - -import copy -import functools -import gc -import os -import time -from jax.config import config -import numpy as np -import psutil -from tensorflow.compat.v2 import test - -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import shapes -from trax.supervised import decoding - - -def _size_of_model(model): - def _size(x): - try: - return x.size - except Exception: # pylint: disable=broad-except - return 0 - sizes = fastmath.nested_map(_size, model.weights) - total_size = sum(fastmath.tree_flatten(sizes)) - return total_size - - -def _recurrent_delete(w): - if 'delete' in dir(w): - # Object has a 'delete' method, so it is a DeviceArray or something similar, - # so we want to delete it. - w.delete() - elif isinstance(w, (list, tuple)): - for x in w: - _recurrent_delete(x) - elif isinstance(w, dict): - for x in w.values(): - _recurrent_delete(x) - else: - raise ValueError('Unknown type encountered in weights: {}'.format(type(w))) - - -def _memory_usage(): - gc.collect() - return psutil.Process(os.getpid()).memory_info().rss - - -class DecodingTimingTest(test.TestCase): - - def _terraformer_decoding_time(self, settings): - # Garbage collection influences the timing, so we turn it off. - gc.disable() - max_len = 16 - - def _self_attention_fn(): - return functools.partial( - tl.SelfAttention, - predict_drop_len=2 * max_len, - predict_mem_len=2 * max_len) - - def _causal_attention_fn(): - attn_layer, attn_kwargs = settings['attn'] - return functools.partial( - attn_layer, - max_inference_length=2 * max_len, **attn_kwargs) - - if settings['model'] == 'terraformer': - pred_model = models.ConfigurableTerraformer( - mode='predict', - d_model=settings['d_model'], - d_ff=settings['d_ff'], - dropout=0.1, - max_len=max_len, - n_heads=settings['n_heads'], - n_encoder_layers=settings['encoder_layers'], - n_decoder_layers=settings['decoder_layers'], - encoder_attention_type=_self_attention_fn(), - encoder_decoder_attention_type=_causal_attention_fn(), - input_vocab_size=settings['vocab'], - ff_sparsity=settings['ff_sparsity'], - ff_use_sru=settings['ff_use_sru'], - ff_dropout=0.1, - # ff_chunk_size=1024, - # attention_chunk_size=1, - n_decoder_attention_layers=settings['attention_layers'], - loss_sparsity=settings['loss_sparsity'], - pos_axial_shape=None, - use_bfloat16=True, - ) - elif settings['model'] == 'transformer': - pred_model = models.ConfigurableTransformer( - mode='predict', - d_model=settings['d_model'], - d_ff=settings['d_ff'], - dropout=0.1, - max_len=max_len, - n_heads=settings['n_heads'], - n_encoder_layers=settings['encoder_layers'], - n_decoder_layers=settings['decoder_layers'], - # encoder_attention_type=_self_attention_fn(), - encoder_decoder_attention_type=_causal_attention_fn(), - input_vocab_size=settings['vocab'], - ff_sparsity=settings['ff_sparsity'], - ff_use_sru=settings['ff_use_sru'], - # ff_dropout=0.1, - # ff_chunk_size=1024, - # attention_chunk_size=1, - # n_decoder_attention_layers=settings['attention_layers'], - loss_sparsity=settings['loss_sparsity'], - pos_axial_shape=None, - # enc_dec_attention_sparsity=settings['enc_dec_sparsity'], - # use_bfloat16=True, - ) - else: - assert False - # We put acceleration outside of autoregressive_sample_stream, because - # we want to have a separate run (separate input) for model compilation. 
- pred_model = tl.Accelerate(pred_model) - - shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) - shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32) - pred_model.init(input_signature=(shape1l, shape11)) - original_state = copy.deepcopy(pred_model.state) - - inputs_warmup = np.zeros((1, max_len), dtype=np.int32) - inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len) - - # This is a warm-up run, for compilation. - result, current_time = [], time.time() - elapsed_warmup_times = [] - for index, sample in zip(range(0, 4), decoding.autoregressive_sample_stream( - pred_model, inputs_warmup, temperature=0.0, accelerate=False)): - del index # unused - result.append(sample[:, None]) # to be sure that the result is computed - - current_time, start_time = time.time(), current_time - elapsed_warmup_times.append(current_time - start_time) - - # This is a real decoding timing run that we measure. - pred_model.state = original_state - result, current_time = [], time.time() - elapsed_times = [] - for index, sample in zip(range(12), decoding.autoregressive_sample_stream( - pred_model, inputs, temperature=0.0, accelerate=False)): - del index # unused - result.append(sample[:, None]) # to be sure that the result is computed - - current_time, start_time = time.time(), current_time - elapsed_times.append(current_time - start_time) - peak_memory = _memory_usage() - - if min(elapsed_times[2:]) * 2 < max(elapsed_times[2:]): - print('WARNING! High variance found in elapsed times! Settings: {} ; ' - 'elapsed times: {} ; Probably more warm-up steps should be used, ' - 'or model size should be increased.'.format(settings, - elapsed_times)) - # Check resulting shapes. - s = np.concatenate(result, axis=1) - self.assertEqual(s.shape[0], 1) - self.assertEqual(s.shape[1], 12) - model_size = int(_size_of_model(pred_model)) - - # We delete the model weights, because in some situations they won't be - # deleted automatically. 
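An illustrative aside on the measurement protocol above: a short warm-up absorbs one-off compilation cost, then the steady-state steps are timed individually. A generic sketch:

```python
import time

def time_steps(step_fn, n_warmup=4, n_measure=12):
  """Times step_fn per call after a warm-up, as in the test above."""
  for _ in range(n_warmup):
    step_fn()  # triggers compilation/caching; not measured
  elapsed = []
  for _ in range(n_measure):
    start = time.time()
    step_fn()
    elapsed.append(time.time() - start)
  return elapsed  # compare min/max to detect high variance
```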
- _recurrent_delete(pred_model.weights) - gc.enable() - return model_size, elapsed_times, peak_memory - - def test_autoregressive_sample_terraformer_timing(self): - template_to_use = 'medium_transformer' - - settings_templates = { - # full model - # # 54B params - # 'full_model': { - # 'encoder_layers': 6, 'decoder_layers': 36, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': 64*1024, 'd_model': 96*96, 'n_heads': 96, - # 'ff_use_sru': (1, 64), 'ff_sparsity': (256, 32), - # 'loss_sparsity': 8, - # 'attn': (tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 64})}, - - # 1/18 of model (1/6 of encoder, 1/18 of decoder, full vocab) - # 4B params - # 'big_terraformer': { - # 'model': 'terraformer', - # 'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96, - # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - # 'attn': (tl.CausalAttention, {})}, - - # 'big_transformer': { - # 'model': 'transformer', - # 'encoder_layers': 1, 'decoder_layers': 2, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': int(5/8 * 64*1024), 'd_model': 96*96, 'n_heads': 96, - # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - # 'attn': (tl.CausalAttention, {})}, - - # medium model - # 275M params (only decoder) - 'medium_transformer': { - 'model': 'transformer', - 'encoder_layers': 2, 'decoder_layers': 24, 'vocab': 32000, - 'attention_layers': 2, - 'd_ff': 4*1024, 'd_model': 1024, 'n_heads': 16, - 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - 'attn': (tl.CausalAttention, {})}, - # 'medium_terraformer': { - # 'model': 'terraformer', - # 'encoder_layers': 2, 'decoder_layers': 24, 'vocab': 32000, - # 'attention_layers': 2, - # 'd_ff': 4*1024, 'd_model': 1024, 'n_heads': 16, - # 'ff_use_sru': 0, 'ff_sparsity': 0, 'loss_sparsity': 0, - # 'attn': (tl.CausalAttention, {})}, - - } - - sweep_settings = { - # 'big_transformer': [ # for big - # dict(), # baseline - # {'ff_sparsity': (256, 32)}, # + Sparse FF - # {'attn': ( # + Sparse QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 64}), - # 'd_ff': 64*1024, - # }, - # {'ff_sparsity': (256, 32), - # 'attn': ( # + Sparse FF+QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 64}), - # 'd_ff': 64*1024, - # }, - # ], - - 'medium_transformer': [ # for medium - dict(), # baseline - - {'ff_sparsity': 64, - 'attn': ( # Sparse FF+QKV - tl.MultiplicativeConvCausalAttention, - {'length_kernel_size': 3, 'sparsity': 16}), - 'd_ff': 6*1024, - }, - - # {'ff_sparsity': 64, # Sparse FF+QKV + Loss - # 'attn': ( - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # 'loss_sparsity': 4, - # }, - - # {'attn': ( # Sparse QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # }, - # {'loss_sparsity': 4}, # Sparse Loss - # {'ff_sparsity': 64}, # Sparse FF - - # {'ff_sparsity': 128}, # + Sparse FF 128 - - # APPENDIX below - - # different loss layers - # {'loss_sparsity': 8}, - # {'loss_sparsity': 2}, - # {'loss_sparsity': 0}, - ], - - # 'big_terraformer': [ # for big terraformer - # dict(), # baseline - # {'ff_sparsity': 64}, # + Sparse FF / Sparse FF 64 - # {'ff_sparsity': 64, - # 'attn': ( # + Sparse FF+QKV - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # }, - # {'ff_sparsity': 64, # + 
Sparse FF+QKV+Loss - # 'attn': ( - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # 'loss_sparsity': 4, - # }, - - # ], - - # 'medium_terraformer': [ # for medium terraformer - # {'ff_sparsity': 64, # + Sparse FF+QKV+Loss - # 'attn': ( - # tl.MultiplicativeConvCausalAttention, - # {'length_kernel_size': 3, 'sparsity': 16}), - # 'd_ff': 6*1024, - # 'loss_sparsity': 4, - # }, - # ], - } - - encoding_times = [] - decoding_times = [] - sizes = [] - memories = [] - messages = [] - for override_settings in sweep_settings[template_to_use]: - settings = copy.deepcopy(settings_templates[template_to_use]) - settings.update(override_settings) - - init_memory = _memory_usage() - size, elapsed_times, peak_memory = ( - self._terraformer_decoding_time(settings)) - - # TODO(jaszczur): Why is elapsed_times[0] always small? - encoding_time = elapsed_times[1] - decoding_time_10 = sum(elapsed_times[2:]) - - after_memory = _memory_usage() - model_memory_gigabytes = (peak_memory-init_memory)/1024**3 - decoding_time_diff = (max(elapsed_times[2:]) - min(elapsed_times[2:])) / 2 - decoding_time_diff_percent = int( - decoding_time_diff / np.mean(elapsed_times) * 100) - message = ( - '\n\n' - 'Params: {}\n' - 'Settings: {}\n' - 'Override: {}\n' - 'Init memory: {:.1f} GiB\n' - 'Peak memory: {:.1f} GiB\n' - 'After memory: {:.1f} GiB\n' - 'Estimated model memory: {:.1f} GiB\n' - 'Times for each step: {}\n' - 'Time for encoding: {:.4f} s\n' - 'Time for decoding 10 tokens: {:.4f} s +/- {} %\n' - '\n\n' - .format(size, settings, override_settings, - init_memory/1024**3, peak_memory/1024**3, - after_memory/1024**3, model_memory_gigabytes, - elapsed_times, encoding_time, - decoding_time_10, decoding_time_diff_percent)) - print(message) - messages.append(message) - encoding_times.append(encoding_time) - decoding_times.append(decoding_time_10) - sizes.append(size) - memories.append(model_memory_gigabytes) - - print('Final results (recap):') - for message in messages: - print(message) - - # This is useful for copying results into a spreadsheet etc. - # for i in range(len(sweep_settings)): - # print('{}\t{}\t{}\t{:.1f}'.format( - # sizes[i], encoding_times[i], decoding_times[i], memories[i])) - - def test_loss_layer_timing(self): - all_settings = [ - # The first run is sometimes slower, less reliable. 
- {'output': 32000, 'input': 2048, 'prob': None, - 'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False}, - - {'output': 32000, 'input': 2048, 'prob': None, - 'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': False}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': False}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': False}, - - {'output': 32000, 'input': 2048, 'prob': None, - 'type': None, 'sparsity': 0, 'lowrank': 0, 'use_bias': True}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'einsum', 'sparsity': 0, 'lowrank': 0, 'use_bias': True}, - {'output': 32000, 'input': 2048, 'prob': None, - 'type': 'mult', 'sparsity': 2, 'lowrank': 0, 'use_bias': True}, - ] - - messages = [] - for settings in all_settings: - pred_model = tl.SparseDenseWithOptions( - n_units=settings['output'], - d_input=settings['input'], - sparsity_type=settings['type'], - sparsity=settings['sparsity'], - d_lowrank=settings['lowrank'], - prob_sparse=settings['prob'], - use_bias=settings['use_bias'], - mode='predict', - ) - pred_model = tl.Accelerate(pred_model) - - shape1l = shapes.ShapeDtype((1, settings['input'])) - pred_model.init(input_signature=shape1l) - inputs = np.ones((1, settings['input'])) - - total_time = 0.0 - for counter in range(-50, 100): - start_time = time.time() - y = pred_model(inputs) - self.assertEqual(y.shape, (1, settings['output'])) - elapsed_time = time.time() - start_time - if counter >= 0: - total_time += elapsed_time - - message = ( - '\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n' - % (_size_of_model(pred_model), settings, total_time)) - messages.append(message) - print(message) - - print('Final results (recap):') - for message in messages: - print(message) - - -if __name__ == '__main__': - config.config_with_absl() - test.main() diff --git a/trax/supervised/history.py b/trax/supervised/history.py deleted file mode 100644 index 910b2ce33..000000000 --- a/trax/supervised/history.py +++ /dev/null @@ -1,88 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Trax history.""" -import collections -import copy - -from absl import logging -import six - - -class History: - """History of metrics. - - History contains the metrics recorded during training and evaluation. - Save data with history.append and get a sequence of data by calling - history.get. - - For example: - history.append('train', 'metrics/accuracy', 1, 0.04) - history.append('train', 'metrics/accuracy', 1000, 0.31) - history.get('train', 'metrics/accuracy') - # returns [(1, 0.04), (1000, 0.31)] - """ - - def __init__(self): - # Structure is - # values = { - # 'mode1': { - # 'metric1': [val1, val2], - # ... - # }, - # 'mode2': ... 
- # } - self._values = {} - - @classmethod - def from_dict(cls, json_object): - """Constructs a `History` from a Python dictionary of parameters.""" - history = History() - for (key, value) in six.iteritems(json_object): - history.__dict__[key] = value - return history - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def append(self, mode, metric, step, value): - """Append (step, value) pair to history for the given mode and metric.""" - if mode not in self._values: - self._values[mode] = collections.defaultdict(list) - self._values[mode][metric].append((step, value)) - - def get(self, mode, metric): - """Get the history for the given metric and mode.""" - if mode not in self._values: - logging.info('Metric %s not found for mode %s', metric, mode) - return [] - return list(self._values[mode][metric]) - - @property - def modes(self): - """Current tracked modes.""" - return sorted(list(self._values.keys())) - - def metrics_for_mode(self, mode): - """Metrics available for a given mode.""" - if mode not in self._values: - logging.info('Mode %s not found', mode) - return [] - return sorted(list(self._values[mode].keys())) - - def __str__(self): - return str(self._values) diff --git a/trax/supervised/history_test.py b/trax/supervised/history_test.py deleted file mode 100644 index 3aee06a64..000000000 --- a/trax/supervised/history_test.py +++ /dev/null @@ -1,56 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
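For reference, the `History` class removed above stores `(step, value)` pairs keyed by mode and metric name. A short usage sketch of the API exactly as defined in the deleted module (importable only at revisions before this deletion):

from trax.supervised import history as trax_history

history = trax_history.History()
history.append('train', 'metrics/accuracy', 1, 0.04)
history.append('train', 'metrics/accuracy', 1000, 0.31)
assert history.get('train', 'metrics/accuracy') == [(1, 0.04), (1000, 0.31)]
assert history.modes == ['train']

# to_dict()/from_dict() round-trip, e.g. for JSON-style checkpointing.
restored = trax_history.History.from_dict(history.to_dict())
assert restored.get('train', 'metrics/accuracy') == [(1, 0.04), (1000, 0.31)]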
- -"""Tests for trax.supervised.history.""" - -from absl.testing import absltest - -from trax.supervised import history as trax_history - - -class HistoryTest(absltest.TestCase): - - def test_unknown_mode(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - self.assertEqual(history.get('unknown_mode', 'metric1'), []) - - def test_unknown_metric(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - self.assertEqual(history.get('train', 'unknown_metric'), []) - - def test_serializer_and_deserializer(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - json_object = history.to_dict() - history2 = trax_history.History.from_dict(json_object) - self.assertEqual(history2.get('train', 'metric1'), [(1, 0.1)]) - - def test_modes(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - history.append('test', 'metric2', 2, 0.2) - self.assertEqual(history.modes, ['test', 'train']) - - def test_metrics_for_mode(self): - history = trax_history.History() - history.append('train', 'metric1', 1, 0.1) - history.append('train', 'metric2', 2, 0.2) - self.assertEqual(history.metrics_for_mode('train'), ['metric1', 'metric2']) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/lr_schedules.py b/trax/supervised/lr_schedules.py deleted file mode 100644 index c58f53d26..000000000 --- a/trax/supervised/lr_schedules.py +++ /dev/null @@ -1,229 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Learning rate (LR) schedules. - -In Trax a learning rate schedule is a function: -:math:`\text{step} \mapsto \text{learning_rate}`. -This module provides helpers for constructing such functions. For example:: - - constant(0.001) - -returns a function that always returns `0.001`. -""" - -import math -import gin -from trax.fastmath import numpy as jnp - - -@gin.configurable -def constant(value): - """Returns an LR schedule that is constant from time (step) 1 to infinity.""" - return _BodyAndTail(value, body_start=1) - - -@gin.configurable -def warmup(n_warmup_steps, max_value): - """Returns an LR schedule with linear warm-up followed by constant value. - - Args: - n_warmup_steps: Number of steps during which the learning rate rises on - a line connecting (0, 0) and (n_warmup_steps, max_value). - max_value: Value for learning rate after warm-up has finished. 
- """ - return _BodyAndTail(max_value, body_start=n_warmup_steps + 1) - - -@gin.configurable -def warmup_and_rsqrt_decay(n_warmup_steps, max_value): - """Returns an LR schedule with warm-up + reciprocal square root decay.""" - return _BodyAndTail(max_value, tail_start=n_warmup_steps + 1, tail_fn=_rsqrt) - - -@gin.configurable -def multifactor(factors='constant * linear_warmup * rsqrt_decay', - constant=0.1, # pylint: disable=redefined-outer-name - warmup_steps=400, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000, - second_constant=0.01, - second_constant_step=10000, - minimum=0): - """Factor-based learning rate schedule. - - Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_deay: Cyclic cosine decay, uses steps_per_cycle parameter. - * two_constants: constant until second_constant_step, then switch to - second_constant. - - Args: - factors: a string with factors separated by '*' that defines the schedule. - constant: float, the starting constant for the learning rate schedule. - warmup_steps: how many steps to warm up for in the warmup schedule. - decay_factor: The amount to decay the learning rate by. - steps_per_decay: How often to decay the learning rate. - steps_per_cycle: Steps per cycle when using cosine decay. - second_constant: float, the second constant for the learning rate schedule. - second_constant_step: the step when the second_constant is triggered. - minimum: if the computed rate is below the minimum, then return the minimum. - - Returns: - a function learning_rate(step): float -> {'learning_rate': float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split('*')] - - def learning_rate(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= constant - elif name == 'two_constants': - if step < second_constant_step: - ret *= constant - else: - ret *= second_constant - elif name == 'linear_warmup': - ret *= jnp.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= jnp.sqrt(warmup_steps) - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor ** (step//steps_per_decay)) - elif name == 'cosine_decay': - progress = jnp.maximum( - 0.0, (step - warmup_steps) / float(steps_per_cycle)) - ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) - else: - raise ValueError('Unknown factor %s.' % name) - # TODO(henrykm): return float(jnp.max(minimum, ret)) would be - # better but causes TypeError: 'numpy.float64' object cannot - # be interpreted as an integer - if ret <= minimum: - return minimum - return ret - - return learning_rate - - -class _BodyAndTail: - """Defines a curve over time as a linear ramp + constant body + curvy tail. - - The body is a span of constant learning rate, and can be the entire curve. - The warm-up, if present, is based on the line connecting points (0, 0) and - (body_start, body_value). The tail, if defined, is a function from time to - learning rate that is used for all training steps from tail_start on. 
- """ - - def __init__( - self, body_value, body_start=None, tail_start=None, tail_fn=None): - """Specifies a body-and-tail time curve. - - Args: - body_value: Constant learning rate for the body of the curve (after - warm-up and before tail). Also is the reference (maximum) value for - calculating warm-up values and tail values. - body_start: Training step number at which the body starts. If None, takes - its value from tail_start, which amounts to there being no body. All - steps from 1 to body_start - 1 are computed using a linear warm-up. - tail_start: Training step number at which the tail starts. If None, the - body value remains until the end of training. - tail_fn: Function returning a floating point learning rate, given inputs: - - step_number (absolute step number from the start of training) - - tail_start (step number at which the tail starts) - - body_value (value relative to which the tail should be computed) - """ - if body_start is None and tail_start is None: - raise ValueError('Both body start and tail start are None.') - if tail_start is not None and tail_fn is None: - raise ValueError( - f'Tail start has value ({tail_start}) but tail_fn is None.') - if body_start is None: - body_start = tail_start if tail_start is not None else 1 - - self._body_value = body_value - self._body_start = body_start - self._tail_start = tail_start - self._tail_fn = tail_fn - - def __call__(self, step_number): - """Returns the learning rate for the given step number.""" - if step_number < self._body_start: - return (step_number / self._body_start) * self._body_value - elif self._tail_start is not None and step_number >= self._tail_start: - return self._tail_fn(step_number, self._tail_start, self._body_value) - else: - return self._body_value - - -def _rsqrt(step_number, tail_start, body_value): - """Computes a tail using a scaled reciprocal square root of step number. - - Args: - step_number: Absolute step number from the start of training. - tail_start: Step number at which the tail of the curve starts. - body_value: Value relative to which the tail should be computed. - - Returns: - A learning rate value that falls as the reciprocal square root of the step - number, scaled so that it joins smoothly with the body of a BodyAndTail - instance. - """ - return body_value * (math.sqrt(tail_start) / math.sqrt(step_number)) - - -class _CosineSawtoothTail: - """Cosine-sawtooth-shaped tail that simulates warm restarts. - - Creates a cyclic learning rate curve; each cycle is half of a cosine, falling - from maximum value to minimum value. For motivation and further details, see - Loshchilov & Hutter (2017) [https://arxiv.org/abs/1608.03983]. - """ - - def __init__(self, steps_per_cycle, min_value=1e-5): - """Configures the periodic behavior of this learning rate function. - - Args: - steps_per_cycle: Number of training steps per sawtooth cycle. The - learning rate will be highest at the start of each cycle, and lowest - at the end. - min_value: Minimum value, reached at the end of each cycle. - """ - self._steps_per_cycle = steps_per_cycle - self._min_value = min_value - - def __call__(self, step_number, tail_start, body_value): - """Returns the learning rate for the given step number, when in the tail. - - Args: - step_number: Absolute step number from the start of training. - tail_start: Step number at which the tail of the curve starts. - body_value: Value relative to which the tail should be computed. 
- """ - max_value = body_value - min_value = self._min_value - position_in_cycle = ( - ((step_number - tail_start) / self._steps_per_cycle) % 1.0) - theta = math.pi * position_in_cycle - return min_value + (max_value - min_value) * .5 * (1 + math.cos(theta)) diff --git a/trax/supervised/lr_schedules_test.py b/trax/supervised/lr_schedules_test.py deleted file mode 100644 index 1973686bf..000000000 --- a/trax/supervised/lr_schedules_test.py +++ /dev/null @@ -1,95 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests of learning rate schedules.""" - -import math - -from absl.testing import absltest - -from trax.supervised import lr_schedules - - -class LRFunctionsTest(absltest.TestCase): - - def test_warmup(self): - lr_fn = lr_schedules.warmup(9, .01) - - # Linear warm-up. - self.assertAlmostEqual(.001, lr_fn(1)) - self.assertAlmostEqual(.002, lr_fn(2)) - self.assertAlmostEqual(.005, lr_fn(5)) - self.assertAlmostEqual(.009, lr_fn(9)) - - # Constant thereafter. - self.assertAlmostEqual(.01, lr_fn(10)) - self.assertAlmostEqual(.01, lr_fn(11)) - self.assertAlmostEqual(.01, lr_fn(20)) - self.assertAlmostEqual(.01, lr_fn(300)) - self.assertAlmostEqual(.01, lr_fn(4000)) - - def test_constant(self): - lr_fn = lr_schedules.constant(.02) - self.assertEqual(.02, lr_fn(1)) - self.assertEqual(.02, lr_fn(20)) - self.assertEqual(.02, lr_fn(300)) - self.assertEqual(.02, lr_fn(4000)) - self.assertEqual(.02, lr_fn(50000)) - self.assertEqual(.02, lr_fn(600000)) - self.assertEqual(.02, lr_fn(7000000)) - self.assertEqual(.02, lr_fn(80000000)) - self.assertEqual(.02, lr_fn(900000000)) - - def test_warmup_and_rsqrt_decay(self): - lr_fn = lr_schedules.warmup_and_rsqrt_decay(24, .25) - - # Warm-up. - self.assertAlmostEqual(.01, lr_fn(1)) - self.assertAlmostEqual(.02, lr_fn(2)) - self.assertAlmostEqual(.23, lr_fn(23)) - self.assertAlmostEqual(.24, lr_fn(24)) - - # Reciprocal square-root decay. 
- self.assertAlmostEqual(.25 * (5 / math.sqrt(25)), lr_fn(25)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(26)), lr_fn(26)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(27)), lr_fn(27)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(300)), lr_fn(300)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(4000)), lr_fn(4000)) - self.assertAlmostEqual(.25 * (5 / math.sqrt(50000)), lr_fn(50000)) - - def test_cosine_sawtooth(self): - tail_fn = lr_schedules._CosineSawtoothTail(180, min_value=.1) - lr_fn = lr_schedules._BodyAndTail(.3, tail_start=0, tail_fn=tail_fn) - - # First cycle - self.assertAlmostEqual(.29998477, lr_fn(1)) - self.assertAlmostEqual(.28660254, lr_fn(30)) - self.assertAlmostEqual(.25, lr_fn(60)) - self.assertAlmostEqual(.20, lr_fn(90)) - self.assertAlmostEqual(.15, lr_fn(120)) - self.assertAlmostEqual(.10001523, lr_fn(179)) - - # Second cycle - self.assertEqual(.3, lr_fn(180)) - self.assertAlmostEqual(.29998477, lr_fn(181)) - self.assertAlmostEqual(.28660254, lr_fn(210)) - self.assertAlmostEqual(.25, lr_fn(240)) - self.assertAlmostEqual(.20, lr_fn(270)) - self.assertAlmostEqual(.15, lr_fn(300)) - self.assertAlmostEqual(.10001523, lr_fn(359)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/mnist_test.py b/trax/supervised/mnist_test.py deleted file mode 100644 index 41fed682f..000000000 --- a/trax/supervised/mnist_test.py +++ /dev/null @@ -1,164 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test training an MNIST model 100 steps (saves time vs. 2000 steps).""" - -import io -import itertools -from unittest import mock - -from absl.testing import absltest - -from trax import layers as tl -from trax.data import inputs -from trax.data import tf_inputs -from trax.optimizers import adam -from trax.supervised import training - - -class MnistTest(absltest.TestCase): - - @mock.patch('sys.stdout', new_callable=io.StringIO) - def test_train_mnist_single_task(self, mock_stdout): - """Train MNIST model a bit, to compare to other implementations.""" - mnist_model = _build_model(two_heads=False) - (task, eval_task) = _mnist_tasks() - training_session = training.Loop( - mnist_model, - tasks=[task], - eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 20 == 0, - ) - - training_session.run(n_steps=100) - self.assertEqual(training_session.step, 100) - - # Assert that we reach at least 80% eval accuracy. - self.assertGreater(_read_metric('WeightedCategoryAccuracy', mock_stdout), - 0.8) - - @mock.patch('sys.stdout', new_callable=io.StringIO) - def test_train_mnist_multitask(self, mock_stdout): - """Train two-head MNIST model a bit, to compare to other implementations.""" - mnist_model = _build_model(two_heads=True) - # MNIST classification task. - (cls_task, cls_eval_task) = _mnist_tasks(head=tl.Select([0], n_in=2)) - # Auxiliary brightness prediction task. 
- reg_task = training.TrainTask( - itertools.cycle(_mnist_brightness_dataset().train_stream(1)), - tl.Serial(tl.Select([1]), tl.L2Loss()), - adam.Adam(0.001), - ) - reg_eval_task = training.EvalTask( - itertools.cycle(_mnist_brightness_dataset().eval_stream(1)), - [tl.Serial(tl.Select([1]), tl.L2Loss())], - n_eval_batches=1, - metric_names=['L2'], - ) - training_session = training.Loop( - mnist_model, - tasks=[cls_task, reg_task], - eval_tasks=[cls_eval_task, reg_eval_task], - eval_at=lambda step_n: step_n % 20 == 0, - which_task=lambda step_n: step_n % 2, - ) - - training_session.run(n_steps=100) - self.assertEqual(training_session.step, 100) - - # Assert that we reach at least 80% eval accuracy on MNIST. - self.assertGreater(_read_metric('WeightedCategoryAccuracy', mock_stdout), - 0.8) - # Assert that we get below 0.03 brightness prediction error. - self.assertLess(_read_metric('L2', mock_stdout), 0.03) - - -def _build_model(two_heads): - cls_head = tl.Dense(10) - if two_heads: - reg_head = tl.Dense(1) - heads = tl.Branch(cls_head, reg_head) - else: - heads = cls_head - return tl.Serial( - tl.Fn('ScaleInput', lambda x: x / 255), - tl.Flatten(), - tl.Dense(512), - tl.Relu(), - tl.Dense(512), - tl.Relu(), - heads, - ) - - -def _mnist_dataset(): - """Loads (and caches) the standard MNIST data set.""" - streams = tf_inputs.data_streams('mnist') - return inputs.batcher(streams, variable_shapes=False, - batch_size_per_device=256, - eval_batch_size=256) - - -def _mnist_brightness_dataset(): - """Loads (and caches) a MNIST mean brightness data set.""" - def preprocess_stream(stream): - def new_stream(): - for (image, _) in stream(): - yield (image, (image / 255).mean()[None]) - return new_stream - - streams = tuple(map(preprocess_stream, tf_inputs.data_streams('mnist'))) - return inputs.batcher(streams, variable_shapes=False, - batch_size_per_device=256, - eval_batch_size=256) - - -def _mnist_tasks(head=None): - """Creates MNIST training and evaluation tasks. - - Args: - head: Adaptor layer to put before loss and accuracy layers in the tasks. - - Returns: - A pair (train_task, eval_task) consisting of the MNIST training task and the - MNIST evaluation task using cross-entropy as loss and accuracy as metric. - """ - loss = tl.WeightedCategoryCrossEntropy() - accuracy = tl.WeightedCategoryAccuracy() - if head is not None: - loss = tl.Serial(head, loss) - accuracy = tl.Serial(head, accuracy) - task = training.TrainTask( - itertools.cycle(_mnist_dataset().train_stream(1)), - loss, - adam.Adam(0.001), - ) - eval_task = training.EvalTask( - itertools.cycle(_mnist_dataset().eval_stream(1)), - [loss, accuracy], - n_eval_batches=10, - metric_names=['CrossEntropy', 'WeightedCategoryAccuracy'], - ) - return (task, eval_task) - - -def _read_metric(metric_name, stdout): - log = stdout.getvalue() - metric_log = [line for line in log.split('\n') if metric_name in line][-1] - return float(metric_log.strip().split(' ')[-1]) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/supervised/pretrain_finetune.py b/trax/supervised/pretrain_finetune.py deleted file mode 100644 index 3bd7b784e..000000000 --- a/trax/supervised/pretrain_finetune.py +++ /dev/null @@ -1,193 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Data processing for BERT.
-
-For now, this file only supports fine-tuning bert-base-uncased on GLUE.
-
-TODO(afrozm): Move this into data/
-"""
-import functools
-
-import gin
-import numpy as onp
-
-import tensorflow_datasets as tfds
-from trax.data.inputs import Inputs
-
-
-def _tfds_stream(n_devices,
-                 dataset_name,
-                 split,
-                 batch_size,
-                 data_dir,
-                 shuffle_files,
-                 shuffle_buffer_size,
-                 batch_shuffle_size,
-                 preprocess_fun,
-                 repeat=True):
-  """Streams batches of examples from tfds, with pure-python preprocessing."""
-  # TODO(piotrekp1): delete if switched to data_streams
-  if batch_size % n_devices != 0:
-    raise ValueError(f'Batch size ({batch_size}) not divisible'
-                     f' by number of devices ({n_devices})')
-  ds = tfds.load(
-      name=dataset_name,
-      split=split,
-      data_dir=data_dir,
-      shuffle_files=shuffle_files)
-  if repeat:
-    ds = ds.repeat()
-  if shuffle_buffer_size is not None:
-    ds = ds.shuffle(shuffle_buffer_size)
-  ds = ds.batch(batch_size)
-  if batch_shuffle_size is not None:
-    ds = ds.shuffle(batch_shuffle_size)
-
-  for batch in tfds.as_numpy(ds):
-    if preprocess_fun is not None:
-      yield preprocess_fun(batch)
-    else:
-      yield batch
-
-
-@gin.configurable
-def tfds_inputs(
-    dataset_name,
-    preprocess_fun,
-    batch_size,
-    eval_batch_size=None,
-    data_dir=None,
-    train_split=tfds.Split.TRAIN,
-    eval_split=tfds.Split.VALIDATION,
-    shuffle_buffer_size=1024,
-    batch_shuffle_size=128,
-):
-  """Tensorflow Datasets input pipeline, with pure-python preprocessing."""
-  if eval_batch_size is None:
-    eval_batch_size = batch_size
-  return Inputs(
-      train_stream=functools.partial(
-          _tfds_stream,
-          dataset_name=dataset_name,
-          split=train_split,
-          batch_size=batch_size,
-          data_dir=data_dir,
-          shuffle_files=True,
-          shuffle_buffer_size=shuffle_buffer_size,
-          batch_shuffle_size=batch_shuffle_size,
-          preprocess_fun=preprocess_fun,
-      ),
-      eval_stream=functools.partial(
-          _tfds_stream,
-          dataset_name=dataset_name,
-          split=eval_split,
-          batch_size=eval_batch_size,
-          data_dir=data_dir,
-          shuffle_files=False,
-          shuffle_buffer_size=None,
-          batch_shuffle_size=None,
-          preprocess_fun=preprocess_fun,
-      ),
-  )
-
-
-@gin.configurable
-def bert_tokenizer(vocab_path=None):
-  """Constructs a BERT tokenizer."""
-  # This import is from https://github.com/google-research/bert which is not
-  # listed as a dependency in trax.
- # TODO(piotrekp1): using SubwordTextEncoder instead after fixing the - # differences - from bert.tokenization.bert_tokenization import FullTokenizer # pylint: disable=g-import-not-at-top - if vocab_path is None: - raise ValueError('vocab_path is required to construct the BERT tokenizer.') - tokenizer = FullTokenizer(vocab_path, do_lower_case=True) - return tokenizer - - -def bert_preprocess(batch, tokenizer, key_a, key_b=None, max_len=128): - """Tokenize and convert text to model inputs in a BERT format.""" - batch_size = batch['idx'].shape[0] - input_ids = onp.zeros((batch_size, max_len), dtype=onp.int32) - type_ids = onp.zeros((batch_size, max_len), dtype=onp.int32) - for i in range(batch_size): - sentence_a = batch[key_a][i] - tokens_a = [101] + tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(sentence_a)) + [102] - - if key_b is not None: - sentence_b = batch[key_b][i] - tokens_b = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(sentence_b)) + [102] - else: - tokens_b = [] - - ex_input_ids = (tokens_a + tokens_b)[:max_len] - ex_type_ids = ([0] * len(tokens_a) + [1] * len(tokens_b))[:max_len] - - input_ids[i, :len(ex_input_ids)] = ex_input_ids - type_ids[i, :len(ex_type_ids)] = ex_type_ids - return input_ids, type_ids, input_ids > 0, batch['label'], onp.ones( - batch_size) - - -@gin.configurable -def glue_inputs(dataset_name=gin.REQUIRED, - batch_size=16, - eval_batch_size=None, - data_dir=None, - max_len=128, - tokenizer=bert_tokenizer): - """Input pipeline for fine-tuning BERT on GLUE tasks.""" - if callable(tokenizer): # If we pass a function, e.g., through gin, call it. - tokenizer = bert_tokenizer() - - eval_split = tfds.Split.VALIDATION - if dataset_name == 'glue/mnli': - eval_split = 'validation_matched' - # TODO(kitaev): Support diagnostic dataset (AX) - - keys_lookup = { - 'glue/cola': ('sentence', None), - 'glue/sst2': ('sentence', None), - 'glue/mrpc': ('sentence1', 'sentence2'), - 'glue/qqp': ('question1', 'question2'), - 'glue/stsb': ('sentence1', 'sentence2'), - 'glue/mnli': ('premise', 'hypothesis'), # TODO(kitaev): swap the two? - 'glue/qnli': ('question', 'sentence'), # TODO(kitaev) swap the two? - 'glue/rte': ('sentence1', 'sentence2'), - 'glue/wnli': ('sentence1', 'sentence2'), - } - - key_a, key_b = keys_lookup[dataset_name] - - preprocess_fn = functools.partial( - bert_preprocess, - tokenizer=tokenizer, - key_a=key_a, - key_b=key_b, - max_len=max_len) - return tfds_inputs( # TODO(piotrekp1): use data_streams instead - dataset_name=dataset_name, - preprocess_fun=preprocess_fn, - batch_size=batch_size, - eval_batch_size=eval_batch_size, - data_dir=data_dir, - train_split=tfds.Split.TRAIN, - eval_split=eval_split) - - -# TODO(piotrekp1): add glue evaluation diff --git a/trax/supervised/trainer_lib.py b/trax/supervised/trainer_lib.py deleted file mode 100644 index 4ffef14f3..000000000 --- a/trax/supervised/trainer_lib.py +++ /dev/null @@ -1,956 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
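The `bert_preprocess` function deleted above (in pretrain_finetune.py) packs an example into BERT's `[CLS] a [SEP] b [SEP]` layout, with segment type ids and a padding mask derived from the non-zero positions. A stand-alone sketch of that packing for one example; the word-piece ids are made up, while 101 and 102 are the standard `[CLS]` and `[SEP]` ids in the bert-base-uncased vocab:

import numpy as onp

max_len = 16
tokens_a = [101, 7592, 2088, 102]   # [CLS] hello world [SEP]
tokens_b = [2129, 2024, 2017, 102]  # how are you [SEP]

input_ids = onp.zeros(max_len, dtype=onp.int32)
type_ids = onp.zeros(max_len, dtype=onp.int32)
ex = (tokens_a + tokens_b)[:max_len]
input_ids[:len(ex)] = ex
type_ids[len(tokens_a):len(ex)] = 1  # segment B; segment A stays 0
mask = input_ids > 0                 # padding positions keep id 0

The same truncation to max_len applies to both the ids and the type ids, mirroring the slicing in bert_preprocess.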
-
-"""Original API for supervised learning/training in Trax.
-
-Trax authors expect that the `supervised.training` module (under development)
-will replace `trainer_lib`.
-"""
-
-import collections
-import functools
-import itertools
-import os
-import sys
-import time
-
-from absl import logging
-
-import gin
-
-import jax
-import tensorflow.compat.v2 as tf
-from trax import fastmath
-from trax import jaxboard
-from trax import layers as tl
-from trax import optimizers as trax_opt
-from trax.data import inputs as trax_inputs
-from trax.fastmath import numpy as np
-from trax.fastmath import random as jax_random
-from trax.layers import base
-from trax.shapes import ShapeDtype
-from trax.supervised import history as trax_history
-from trax.supervised import lr_schedules as lr
-from trax.supervised import training
-
-
-# TODO(afrozm): Maybe flatten everything from OptState into TrainerState.
-TrainerState = collections.namedtuple('_TrainerState', [
-    'step',         # Current training step number.
-    'opt_state',    # OptState.
-    'history',      # trax.history.History.
-    'model_state',  # Auxiliary state of the model.
-])
-
-
-OptState = collections.namedtuple('_OptState', [
-    'weights',     # Model weights.
-    'slots',       # Per-parameter optimizer state, e.g. gradient moments.
-    'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
-])
-
-
-_DEFAULT_METRICS = {
-    'loss': tl.WeightedCategoryCrossEntropy(),
-    'accuracy': tl.WeightedCategoryAccuracy(),
-    'sequence_accuracy': tl.MaskedSequenceAccuracy(),
-    'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(),
-                                    tl.Negate()),
-    'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
-}
-
-
-NamedStream = collections.namedtuple(
-    'NamedStream', ['name', 'stream']
-)
-
-
-@gin.configurable
-def named_stream(name=gin.REQUIRED, stream=gin.REQUIRED):
-  return NamedStream(name=name, stream=stream)
-
-
-class Trainer:
-  """Trax trainer.
-
-  A trainer lets you take training steps, train for full epochs,
-  save the training state, and access evaluation data.
-  """
-
-  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
-               output_dir=None, random_seed=None, n_devices=None,
-               checkpoints_at=None, should_save_checkpoints=True,
-               should_write_summaries=True,
-               metrics=None, checkpoint_highest=None,
-               checkpoint_lowest=None,
-               init_checkpoint=None):
-
-    self._is_chief, _, self._n_devices, rng = (
-        training.init_host_and_devices(n_devices, random_seed))
-    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
-    self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
-    self._should_write_summaries = should_write_summaries
-    if not output_dir:
-      self._should_save_checkpoints = False
-      self._should_write_summaries = False
-    self._checkpoint_highest = checkpoint_highest
-    self._checkpoint_lowest = checkpoint_lowest
-    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
-    # Inputs is either an Inputs instance or a function that returns it.
-    self._inputs = inputs
-    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
-      self._inputs = inputs()
-    # Initialize the learning rate to a dummy value. It will be set in reset().
-    opt = optimizer(learning_rate=0.0)
-
-    # Set up the model.
-    model_train = model(mode='train')
-    model_predict_eval = model(mode='eval')
-    # Should work for fine-tuning of T5.
- if init_checkpoint: - model_train.init_from_file(init_checkpoint, weights_only=True) - model_predict_eval.init_from_file(init_checkpoint, weights_only=True) - self._model_with_loss = tl.Serial(model_train, loss_fn) - - # Setup state. - rng, init_rng = jax_random.split(rng) - self._rngs = np.stack(jax_random.split(rng, self._n_devices)) - shapes, dtypes = self._inputs.example_shape_dtype - input_signature = tuple(ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes)) - - def new_opt_state_and_model_state(rng): - """Returns optimizer and model states suitable for training a model.""" - weights, state = self._model_with_loss.init(input_signature, rng=rng) - (slots, opt_params) = opt.tree_init(weights) - return (OptState(weights, slots, opt_params), state) - - if fastmath.is_backend(fastmath.Backend.JAX): - # JIT parameter initialization to avoid memory fragmentation - new_opt_state_and_model_state = ( - fastmath.jit(new_opt_state_and_model_state)) - self._new_opt_state_and_model_state = ( - lambda: new_opt_state_and_model_state(init_rng)) - - # Arrange and initialize metrics layers. - self._metrics = list(sorted(self._metrics_dict.keys())) - metrics_layers = [self._metrics_dict[m] for m in self._metrics] - metrics_in_parallel = tl.Branch(*metrics_layers) - metrics_in_parallel.rng = init_rng - example_signature = tuple( - ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype) - ) - model_predict_eval.init(example_signature) - self._input_signature = example_signature - output_signature = model_predict_eval.output_signature(example_signature) - m_weights, m_state = metrics_in_parallel.init(output_signature) - self._metrics_weights = self._for_n_devices(m_weights) - self._metrics_state = self._for_n_devices(m_state) - - # Jit model_predict and update so they're fast. - self._jit_eval = _jit_predict_fn( - model_predict_eval, metrics_in_parallel, self._n_devices) - self._jit_update_fn = _jit_update_fn( - model_train, loss_fn, opt, self._n_devices) - - self._model_train = model_train - self._model_predict_eval = model_predict_eval - self._loss_fn = loss_fn - self._lr_schedule = lr_schedule - - # Those fields will be set in reset(). - self._output_dir = None - self._train_sw = None - self._eval_sw = None - self._history = None - self._opt_state = None - self._step = None - self._model_state = None - self.reset(output_dir) - - @property - def n_devices(self): - return self._n_devices - - @property - def step(self): - return self._step - - @property - def model_weights(self): - # Currently we need to pick [0] as we ignore loss weights (empty). - weights = self._opt_state.weights[0] - if self.n_devices > 1: - unreplicate = lambda x: x[0] - weights = fastmath.nested_map(unreplicate, weights) - return weights - - @model_weights.setter - def model_weights(self, weights): - new_model_weights = self._for_n_devices(weights) - if isinstance(self._opt_state.weights, list): - self._opt_state.weights[0] = new_model_weights - else: # weights are a tuple, need to re-create - new_weights = [new_model_weights] + list(self._opt_state.weights[1:]) - self._opt_state = self._opt_state._replace(weights=new_weights) - - @property - def model_state(self): - # Currently we need to pick [0] as we ignore loss state (empty). 
- state = self._model_state[0] - if self.n_devices > 1: - unreplicate = lambda x: x[0] - state = fastmath.nested_map(unreplicate, state) - return state - - @model_state.setter - def model_state(self, state): - new_model_state = self._for_n_devices(state) - if isinstance(self._model_state, list): - self._model_state[0] = new_model_state - else: # weights are a tuple, need to re-create - self._model_state = [new_model_state] + list(self._model_state[1:]) - - @property - def state(self): - return TrainerState( - opt_state=self._opt_state, step=self._step, history=self._history, - model_state=self._model_state) - - @property - def learning_rate(self): - with fastmath.use_backend(fastmath.Backend.NUMPY): - return self._lr_schedule(self._step) - - def reset(self, output_dir, init_checkpoint=None): - """Reset the model parameters. - - Restores the parameters from the given output_dir if a checkpoint exists, - otherwise randomly initializes them. - - Does not re-jit the model. - - Args: - output_dir: Output directory. - init_checkpoint: Initial checkpoint (default $output_dir/model.pkl.gz) - """ - self.close() - self._output_dir = output_dir - if output_dir is not None: - tf.io.gfile.makedirs(output_dir) - else: - assert not self._should_save_checkpoints - assert not self._should_write_summaries - - # Create summary writers and history. - if self._should_write_summaries: - self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'), - enable=self._is_chief) - self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'), - enable=self._is_chief) - - # Reset the train and eval streams. - self._train_stream = _repeat_stream(self._inputs.train_stream, - self._n_devices) - # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval - # set by adding a padding and stopping the stream when too large. - self._eval_stream = _repeat_stream( - self._inputs.eval_stream, self._n_devices) - self._train_eval_stream = _repeat_stream( - self._inputs.train_eval_stream, self._n_devices) - - # Restore the training state. - if output_dir is not None: - state = load_trainer_state(output_dir, self._model_with_loss, - init_checkpoint) - else: - state = TrainerState(step=None, opt_state=None, - history=trax_history.History(), model_state=None) - self._step = state.step or 0 - history = state.history - self._history = history - if state.opt_state: - opt_state = state.opt_state - model_state = state.model_state - else: - opt_state, model_state = self._new_opt_state_and_model_state() - model_state = self._for_n_devices(model_state) - self._opt_state = OptState(*self._for_n_devices(opt_state)) - self._model_state = model_state - if not state.opt_state and self._should_save_checkpoints: - self.save_state(keep=False) - - def train_epoch(self, n_steps, n_eval_steps): - """Runs `n_steps` of training, with periodic logging, saving, and evals.""" - # TODO(jonni): Clarify how this method relates to the stricter notion of - # epoch (training for as many steps as needed for a full pass through the - # training data). - print() # Add visual separator in logs for start of training epoch. - start_time = time.time() - - for _ in range(n_steps): - batch = next(self._train_stream) - if self.n_devices > 1: # TODO(lukaszkaiser): use everywhere if possible. 
- batch = _reshape_by_device(batch, self.n_devices) - self.train_step(batch) - if self._should_save_now(): - self.save_state(keep=True) - if self._should_log_now(): - self._train_sw.scalar('training/learning_rate', self.learning_rate) - - # At end of n_steps, do bookkeeping, run evals, and save state. - elapsed_time = time.time() - start_time - self.log_step('Ran %d train steps in %0.2f secs' % (n_steps, elapsed_time)) - if self._train_sw and n_steps > 1: - self._train_sw.scalar('training/steps per second', - n_steps / elapsed_time, step=self._step) - self._train_sw.flush() - self.evaluate(n_eval_steps) - if self._eval_sw: - self._eval_sw.flush() - if self._should_save_checkpoints: - self.save_state(keep=False) - if self._should_save_checkpoints and self._current_step_is_best(high=True): - self.save_state(keep=False, prefix='highest_' + self._checkpoint_highest) - if self._should_save_checkpoints and self._current_step_is_best(high=False): - self.save_state(keep=False, prefix='lowest_' + self._checkpoint_lowest) - - def train_step(self, batch): - """Run one training step and update self._opt_state.""" - # Calculate the current optimizer parameters. - opt_param_updates = self._for_n_devices( - {'learning_rate': np.array(self.learning_rate)}) - opt_state = self._opt_state - opt_state.opt_params.update(opt_param_updates) - - # Run the update. - weights, slots, opt_params = opt_state - (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn( - (weights, slots), self._step, opt_params, batch, - self._model_state, self._rngs) - self._opt_state = opt_state._replace(weights=weights, slots=slots) - if self._should_log_now(): - for name, value in stat.items(): - # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here - # with a device put array error complaining that it should be an array. - # On multiple devices, take the mean. - scalar_value = np.mean(np.array(value)) - self._train_sw.scalar('training/' + name, scalar_value, step=self._step) - self._step += 1 - - def evaluate(self, n_eval_steps): - """Evaluate the model and log metrics.""" - _, rng = jax_random.split(self._rngs[0]) - # TODO(lukaszkaiser): both model state and parameters by default include - # the loss layer. Currently, we access the pure-model parameters by just - # indexing, [0] here. But we should make it more explicit in a better API. - weights = (self._opt_state.weights[0], self._metrics_weights) - state = (self._model_state[0], self._metrics_state) - self.log_step('Evaluation') - train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps) - train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state, - rng) - self.log_metrics(train_metrics, self._train_sw, 'train') - eval_slice = itertools.islice(self._eval_stream, n_eval_steps) - eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng) - self.log_metrics(eval_metrics, self._eval_sw, 'eval') - self.log_step('Finished evaluation') - - # Save the learning rate in history. - self._history.append('train', 'training/learning_rate', - self._step, self.learning_rate) - - def evaluation_round(self, inputs_stream, weights, state, rng): - """Evaluate. - - Args: - inputs_stream: Iterable of inputs to evaluate on. - weights: Weights for each f in eval_fns. - state: State for each f in eval_fns. - rng: Single-use random number generator (JAX PRNG key). - - Returns: - Tuple of `(metrics, state)`. 
`metrics` is a dict from metric name to
-      metric value averaged over the number of inputs, and `state` is the end
-      state returned by this trainer's `predict_fn`.
-    """
-    metrics = collections.defaultdict(float)
-    count = 0
-    for inp in inputs_stream:
-      count += 1
-      rng, subrng = jax_random.split(rng)
-      metric_values, _ = self._jit_eval(inp, weights, state, subrng)
-      try:
-        metric_values = list(metric_values)
-      except (TypeError, IndexError):
-        metric_values = [float(metric_values)]
-      for m, v in zip(self._metrics, metric_values):
-        metrics[m] += v
-    return {m: v / count for (m, v) in metrics.items()}, state
-
-  def save_gin(self):
-    """Saves the operative gin config, but only if this is the chief."""
-    if not self._is_chief:
-      return
-    assert self._output_dir is not None
-    config_path = os.path.join(self._output_dir, 'config.gin')
-    config_str = gin.operative_config_str()
-    with tf.io.gfile.GFile(config_path, 'w') as f:
-      f.write(config_str)
-    sw = self._train_sw
-    if sw:
-      sw.text('gin_config',
-              jaxboard.markdownify_operative_config_str(config_str))
-
-  def _save_state_dict(self, trainer_state_dict, weights_file):
-    training.pickle_to_file(trainer_state_dict, weights_file, gzip=True)
-    log('Model saved to %s' % weights_file, stdout=False)
-
-  def save_state(self, keep, prefix='model'):
-    """Save trainer state given a possibly replicated opt_state."""
-    opt_state = self._opt_state
-    if self.n_devices > 1:
-      first_replica = lambda x: x[0]
-      opt_state = OptState(*fastmath.nested_map(first_replica, opt_state))
-    # This line, while optional, allows JAX to transfer arrays from the device
-    # to the host in parallel, which is particularly important for cloud TPU.
-    if fastmath.is_backend(fastmath.Backend.JAX):
-      opt_state = jax.device_get(opt_state)
-    step, history, model_state = self._step, self._history, self._model_state
-    output_dir = self._output_dir
-
-    weights_file = os.path.join(output_dir, prefix + '.pkl.gz')
-
-    # This dict will be stored as the model.
-    trainer_state_dict = make_trainer_state_dict(
-        step, opt_state, history, model_state, self._input_signature)
-    self._save_state_dict(trainer_state_dict, weights_file)
-
-    if keep:
-      weights_file = os.path.join(output_dir,
-                                  '{}_{}.pkl.gz'.format(prefix, step))
-      self._save_state_dict(trainer_state_dict, weights_file)
-
-  def save_computation_graphs(self):
-    """Dump computation graphs to files."""
-    if self.n_devices != 1:
-      return  # TODO(lukaszkaiser): make this work with more devices.
- batch = next(self._train_stream) - output_dir = self._output_dir - if self.n_devices > 1: - batch = _reshape_by_device(batch, self.n_devices) - weights = self._opt_state.weights[0] - forward_computation = jax.jit(self._model_predict_eval).lower( - batch, weights=weights, state=self._model_state[0], - rng=self._rngs[0]).compiler_ir(dialect='hlo') - with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f: - f.write(forward_computation.as_hlo_text()) - with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f: - f.write(forward_computation.as_hlo_dot_graph()) - - def log_step(self, step_message): - log('Step % 6d: %s' % (self.step, step_message)) - - def log_metrics(self, metrics, summ_writer, log_prefix): - """Log metrics to summary writer and history.""" - history = self._history - rjust_len = max([0] + [len(name) for name in metrics]) - for name, value in metrics.items(): - self.log_step('%s %s | % .8f' % ( - log_prefix.ljust(5), name.rjust(rjust_len), value)) - full_name = 'metrics/' + name - if history: - history.append(log_prefix, full_name, self.step, value) - if summ_writer: - summ_writer.scalar(full_name, value, self.step) - - def print_n_weights(self): - """Prints the total count of trainable weights.""" - opt_state = self._opt_state - sizes = _sizes(opt_state.weights) - if self.n_devices > 1: - unreplicate = lambda x: x[0] - single_weights = fastmath.nested_map(unreplicate, opt_state.weights) - sizes = _sizes(single_weights) - total_size = _nested_reduce(sum, sizes) - self.log_step('Total number of trainable weights: %d' % total_size) - - def _should_save_now(self): - return self._should_save_checkpoints and self._step in self._checkpoints_at - - def _current_step_is_best(self, high): - """Is the current step the best (highest if high, else lowest).""" - metric = self._checkpoint_highest if high else self._checkpoint_lowest - if metric is None: - return False - # History is a list of pairs (step, value). - history = self._history.get('eval', 'metrics/' + metric) - sequence = [float(i[1]) for i in history] # Just the values. - best = max(sequence) if high else min(sequence) # Best value. - last_is_best = float(history[-1][1]) == best # Is last the best? - cur_step = history[-1][0] == self._step # Is last the current step? 
- return cur_step and last_is_best - - def _should_log_now(self): - return (self._train_sw is not None - and (self._step == 1 or self._step % 10 == 0)) - - def _for_n_devices(self, x): - """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`.""" - return tl.for_n_devices(x, self.n_devices) # pylint: disable=protected-access - - def close(self): - if self._train_sw is not None: - self._train_sw.close() - self._train_sw = None - if self._eval_sw is not None: - self._eval_sw.close() - self._eval_sw = None - - -@gin.configurable(denylist=['output_dir']) -def train(output_dir, - model=gin.REQUIRED, - loss_fn=tl.WeightedCategoryCrossEntropy(), - inputs=trax_inputs.batcher, - optimizer=trax_opt.Adafactor, - lr_schedule_fn=lr.multifactor, - trainer_class=Trainer, - steps=1000, - checkpoints_at=None, - permanent_checkpoints_at=None, - eval_steps=10, - eval_frequency=100, - permanent_checkpoint_frequency=None, - random_seed=None, - save_graphs=True, - metrics=None, - checkpoint_highest=None, - checkpoint_lowest=None, - use_loop=True, - loss_chunk_size=0, - use_memory_efficient_trainer=False, - adasum=False, - init_checkpoint=None, - callbacks=None, - n_weights_shards=1, - additional_train_tasks=None, - additional_eval_tasks=None, - additional_eval_streams=None): - """Train the model on the inputs. - - Args: - output_dir: Directory where to put the logs and checkpoints. - model: The model to train as a callable returning 2 callables, an init_fn - and apply_fn. - loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state, - rng -> loss. - inputs: callable returning trax.inputs.Inputs. - optimizer: The optimizer (see optimizers/base.py for signature). - lr_schedule_fn: A learning rate schedule function, that when called returns - a function from step to learning rate (a float). - trainer_class: The trainer class to use. - steps: int, total number of training steps. - checkpoints_at: list of integers. Save a checkpoint for each training step - in the list. - permanent_checkpoints_at: list of integers. Save a permanent checkpoint for - each training step in the list. - eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. - eval_frequency: int, how often to run evaluation (every eval_frequency - steps). If None or 0, eval disabled. - permanent_checkpoint_frequency: int, how often to save permanent checkpoints - (every permanent_checkpoint_frequency steps). - random_seed: the random seed to use; time/os dependent if None (default). - save_graphs: bool, if True, save computation graph to file. - metrics: optionally override the default metrics dictionary. - checkpoint_highest: save the checkpoint highest at this metric. - checkpoint_lowest: save the checkpoint lowest at this metric. - use_loop: whether to use training.Loop instead of Trainer. - loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory. - use_memory_efficient_trainer: whether to use memory-efficient trainer. - adasum: if True, use adaptive summation for multi-device gradients. - init_checkpoint: a checkpoint for fine tuning. - callbacks: a list of callbacks to call during training. - n_weights_shards: shard weights into this many devices. - additional_train_tasks: additional tasks which should be performed during - training. - additional_eval_tasks: additional tasks which should be performed during - evaluation. - additional_eval_streams: List[NamedStream], additional data streams that - should be used during evaluation. Can be provided independently of - additional_eval_tasks. 
- - Returns: - trax.TrainerState or training.Loop if use_loop is True - """ - base.N_WEIGHTS_SHARDS = n_weights_shards - if (permanent_checkpoint_frequency is not None - and permanent_checkpoints_at is not None): - raise ValueError('Only one of ["permanent_checkpoint_frequency", ' - '"permanent_checkpoints_at"] should be set.') - if use_loop: - n_devices = num_devices() or fastmath.local_device_count() - - # Prepare the training task. - # Inputs is either an Inputs instance or a function that returns it. - if callable(inputs): # If we pass a function, e.g., through gin, call it. - inputs = inputs() - opt = optimizer if use_memory_efficient_trainer else optimizer() - train_task = training.TrainTask( - inputs.train_stream(n_devices), - loss_layer=loss_fn, - optimizer=opt, - lr_schedule=lr_schedule_fn(), - n_steps_per_checkpoint=eval_frequency, - n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency) - - if additional_train_tasks is None: - additional_train_tasks = [] - - # Prepare the evaluation. - metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS - names, metrics = zip(*metrics_dict.items()) - eval_task = training.EvalTask(inputs.eval_stream(n_devices), - metrics, - metric_names=names, - n_eval_batches=eval_steps) - - if additional_eval_tasks is None: - additional_eval_tasks = [] - - additional_eval_tasks_from_streams = [] - if additional_eval_streams is not None: - for stream in additional_eval_streams: - additional_eval_tasks_from_streams.append( - training.EvalTask(stream.stream, - metrics, - metric_names=names, - n_eval_batches=eval_steps, - export_prefix=stream.name)) - - # Prepare the training loop. - checkpoint_at = None - if checkpoints_at is not None: - checkpoint_at = lambda step: step in checkpoints_at - permanent_checkpoint_at = None - if permanent_checkpoints_at is not None: - permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at) - - # Setup the model. - model_train = model(mode='train') - model_predict_eval = model(mode='eval') - if init_checkpoint: - model_train.init_from_file(init_checkpoint, weights_only=True) - model_predict_eval.init_from_file(init_checkpoint, weights_only=True) - loop = training.Loop( - model_train, [train_task] + additional_train_tasks, - eval_model=model_predict_eval, - eval_tasks=[eval_task] + - additional_eval_tasks + additional_eval_tasks_from_streams, - output_dir=output_dir, - checkpoint_at=checkpoint_at, - checkpoint_low_metric=checkpoint_lowest, - checkpoint_high_metric=checkpoint_highest, - permanent_checkpoint_at=permanent_checkpoint_at, - n_devices=n_devices, - loss_chunk_size=loss_chunk_size, - use_memory_efficient_trainer=use_memory_efficient_trainer, - adasum=adasum, - random_seed=random_seed, - callbacks=callbacks, - ) - - steps_to_go = steps - loop.step - if steps_to_go <= 0: - log('Stop training, already reached the total training steps %d' % steps) - return loop - - # Train and return the loop. 
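For orientation, here is a minimal sketch of driving the same Loop-based path through the public `training` API that this function assembles; the toy model and data stream below are illustrative stand-ins, not part of this diff:

import numpy as np
from trax import layers as tl
from trax import optimizers
from trax.supervised import training

def toy_stream():
  # Hypothetical infinite stream of (inputs, targets, loss weights) batches.
  while True:
    x = np.random.rand(8, 4).astype(np.float32)
    y = (x.sum(axis=-1) > 2.0).astype(np.int32)
    yield x, y, np.ones(8, dtype=np.float32)

model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
train_task = training.TrainTask(
    labeled_data=toy_stream(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=optimizers.Adam(0.01),
    n_steps_per_checkpoint=10)
eval_task = training.EvalTask(
    labeled_data=toy_stream(),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()])
loop = training.Loop(model, train_task, eval_tasks=[eval_task])
loop.run(20)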
- loop.run(steps_to_go) - return loop - - n_devices = num_devices() - trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs, - output_dir, - random_seed=random_seed, - n_devices=n_devices, - checkpoints_at=checkpoints_at, - metrics=metrics, - checkpoint_lowest=checkpoint_lowest, - checkpoint_highest=checkpoint_highest, - init_checkpoint=init_checkpoint) - - epoch_steps = [steps] # Only training if eval_frequency is 0 or None - if eval_frequency and eval_steps > 0: - epoch_steps = itertools.chain([1, # first epoch only 1 step - eval_frequency - 1], - itertools.repeat(eval_frequency)) - trainer.log_step('Starting training using %d devices' % trainer.n_devices) - trainer.print_n_weights() - - try: - for epoch_steps in epochs(steps, trainer.step, epoch_steps): - trainer.train_epoch(epoch_steps, eval_steps) - - # Bookkeeping we do at the first step - if trainer.step == 1: - # Save computation graph (single-device only for now) - if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)): - trainer.save_computation_graphs() - - # Save Gin config - trainer.save_gin() - - trainer.log_step('Training done') - except Exception as e: - raise e - finally: - trainer.close() - return trainer.state - - -@gin.configurable -def num_devices(value=None): - """Returns how many devices to use (if None, default, use all available).""" - return value - - -@gin.configurable -def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True): - """Returns a (JIT-compiled) function that computes updates for one step.""" - model_and_loss = tl.Serial(predict_fn, loss_fn) - # Gradients are always wrt. the first argument, so putting weights first. - def model_and_loss_call(weights, batch, state, rng): - res = model_and_loss(batch, weights=weights, state=state, rng=rng) - return res, model_and_loss.state - if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. - def single_update(weights_and_slots, i, opt_params, batch, state, rng): - weights, slots = weights_and_slots - rng, subrng = jax_random.split(rng[0]) - grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) - grads, state = grad_fn(weights, batch, state, rng) - new_weights, new_slots, stats = optimizer.tree_update( - i, grads, weights, slots, opt_params) - return (new_weights, new_slots), stats, state, [subrng] - if jit: - # TODO(lukaszkaiser): donate_argnums=(0,) when XLA supports it on GPU - return fastmath.jit(single_update) - else: - return single_update - - # Else, for n_devices > 1: - @functools.partial(fastmath.pmap, axis_name='batch') # donate_argnums=(0,)) - def mapped_update(weights_and_slots, i, opt_params, batch, state, rng): - """This is a multi-device version of the update function above.""" - # We assume all tensors have the first dimension = n_devices. - weights, slots = weights_and_slots - rng, subrng = jax_random.split(rng) - grad_fn = fastmath.grad(model_and_loss_call, has_aux=True) - grads, state = grad_fn(weights, batch, state, rng) - # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just - # the number of devices on this host machine, however psum goes over all - # devices of all hosts (ex: a TPU pod) and we need to be averaging over all - # of them. - # - # Collect all gradients. - grads = fastmath.psum(grads, 'batch') - n_devices_total = fastmath.psum(np.array(1.0), 'batch') - # Average across hosts. 
- grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads) - - new_weights, new_slots, stats = optimizer.tree_update( - i, grads, weights, slots, opt_params) - return (new_weights, new_slots), stats, state, subrng - - def update(weights_and_slots, i, opt_params, batch, state, rng): - return mapped_update(weights_and_slots, np.repeat(i, n_devices), - opt_params, batch, state, rng) - - return update - - -@gin.configurable -def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True): - """Returns a JIT-compiled predict function (unless jit=False).""" - model = tl.Serial(model_predict, metric_fn) - if not jit: - return model.pure_fn - - return tl.jit_forward(model.pure_fn, n_devices) - - -@gin.configurable -def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True): - """Returns a (JIT-compiled) function that computes the loss for one step.""" - if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. - def single_compute_loss(opt_state, batch, state, rng): - rng, subrng = jax_random.split(rng[0]) - loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) - return loss_val, state, [subrng] - return fastmath.jit(single_compute_loss) if jit else single_compute_loss - - # Else, for n_devices > 1: - @functools.partial(fastmath.pmap, axis_name='batch') - def mapped_compute_loss(opt_state, batch, state, rng): - """This is a multi-device version of the update function above.""" - # We assume all tensors have the first dimension = n_devices. - rng, subrng = jax_random.split(rng) - loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) - return loss_val, state, subrng - - def compute_loss(opt_state, batch, state, rng): - return mapped_compute_loss( - opt_state, _reshape_by_device(batch, n_devices), state, rng) - - return compute_loss - - -def log(s, stdout=True): - logging.info(s) - if stdout: - print(s) - sys.stdout.flush() - - -def epochs(total_steps, steps_to_skip, epoch_steps): - """Generates the number of steps in each epoch before reaching total_steps. - - Args: - total_steps: int, total number of steps. - steps_to_skip: int, number of steps to skip because of a restart. - epoch_steps: iterable of int, numbers of steps in each epoch. - - Yields: - epoch_steps: int, number of steps in this epoch - """ - steps_to_go = total_steps - steps_to_skip - epoch_steps = iter(epoch_steps) - - # Remove the desired number of steps from the stream. - for steps_this_epoch in epoch_steps: - if steps_this_epoch > steps_to_skip: - # Put back the number of steps left in the unfinished epoch. - epoch_steps = itertools.chain( - [steps_this_epoch - steps_to_skip], epoch_steps) - if steps_this_epoch >= steps_to_skip: - break - steps_to_skip -= steps_this_epoch - - # Yield the remaining steps per epoch up to total_steps. - for steps_this_epoch in epoch_steps: - steps_this_epoch = min(steps_this_epoch, steps_to_go) - yield steps_this_epoch - steps_to_go -= steps_this_epoch - if steps_to_go == 0: - break - - -def make_trainer_state_dict(step, - opt_state, - history, - model_state, - input_signature): - """Creates a trainer state dictionary to save to disk. - - Args: - step: int, a step number - opt_state: OptState namedtuple - history: `trax.history.History`, the history object. - model_state: A nested structure of the model state. - input_signature: signature of model inputs. - - Returns: - A dictionary with the fields of TrainerState and OptState flattened. 
- """ - flat_weights, flat_state = tl.flatten_weights_and_state( - opt_state.weights, model_state) - return { - 'step': step, - 'flat_weights': flat_weights, - 'slots': opt_state.slots, - 'opt_params': opt_state.opt_params, - 'history': history, - 'flat_state': flat_state, - 'input_signature': input_signature, - 'version_timestamp': 'Jun-18-2020' # To update in the future if needed. - } - - -def trainer_state_from_dict(trainer_state_dict, model): - """Given the trainer state dictionary, returns `TrainerState`.""" - # TODO(afrozm): This becomes simpler if OptState is flattened into - # TrainerState. - step = trainer_state_dict['step'] - history = trainer_state_dict['history'] - input_signature = trainer_state_dict['input_signature'] - weights_and_state_sig = model.weights_and_state_signature(input_signature) - weights, model_state = tl.unflatten_weights_and_state( - trainer_state_dict['flat_weights'], trainer_state_dict['flat_state'], - weights_and_state_sig) - opt_state = OptState( - weights=weights, - slots=trainer_state_dict['slots'], - opt_params=trainer_state_dict['opt_params']) - return TrainerState(step=step, opt_state=OptState(*opt_state), - history=history, model_state=model_state) - - -def load_trainer_state(output_dir, model, weights_file=None): - """Returns a TrainerState instance loaded from the given `output_dir`.""" - if weights_file is None: - weights_file = os.path.join(output_dir, 'model.pkl.gz') - if not tf.io.gfile.exists(weights_file): - return TrainerState(step=None, opt_state=None, - history=trax_history.History(), model_state=None) - elif not tf.io.gfile.exists(weights_file): - raise ValueError('File not found: %s' % weights_file) - - trainer_state_dict = training.unpickle_from_file(weights_file, gzip=True) - trainer_state = trainer_state_from_dict(trainer_state_dict, model) - log('Model loaded from %s at step %d' % (weights_file, trainer_state.step)) - logging.debug('From loaded model : history = %s', trainer_state.history) - return trainer_state - - -def _reshape_by_device(x, n_devices): - """Reshapes possibly nested x into a shape (n_devices, ...).""" - return tl.reshape_by_device(x, n_devices) # pylint: disable=protected-access - - -def _nested_reduce(f, x): - """Fold the function f to the nested structure x (dicts, tuples, lists).""" - if isinstance(x, list): - return f([_nested_reduce(f, y) for y in x]) - if isinstance(x, tuple): - return f([_nested_reduce(f, y) for y in x]) - if isinstance(x, dict): - return f([_nested_reduce(f, v) for (_, v) in x.items()]) - return x - - -def _sizes(x): - """Get a structure of sizes for a structure of nested arrays.""" - def size(x): - try: - return x.size - except Exception: # pylint: disable=broad-except - return 0 - return fastmath.nested_map(size, x) - - -def _repeat_stream(stream, n_devices): - """Repeat a stream indefinitely.""" - while True: - for example in stream(n_devices): - yield example diff --git a/trax/supervised/trainer_lib_test.py b/trax/supervised/trainer_lib_test.py deleted file mode 100644 index 6464cdf2c..000000000 --- a/trax/supervised/trainer_lib_test.py +++ /dev/null @@ -1,555 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax.supervised.trainer_lib.""" - -import functools -import os - -from absl.testing import absltest -from absl.testing import parameterized -import jax -from jax.config import config -import tensorflow.compat.v2 as tf -from trax import fastmath -from trax import layers as tl -from trax import models -from trax import optimizers as trax_opt -from trax import shapes as trax_shapes -from trax import test_utils -from trax.data import inputs as inputs_lib -from trax.fastmath import numpy as jnp -from trax.supervised import lr_schedules as lr -from trax.supervised import trainer_lib -from trax.tf_numpy import extensions as npe -from trax.tf_numpy import numpy as tf_np - - - -def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)): - """Make trainer_lib.inputs.Inputs.""" - batch_size = 2 * jax.device_count() - - def input_stream(n_devices): - del n_devices - key = fastmath.random.get_prng(0) - while True: - keys = fastmath.random.split(key, 4) - key = keys[0] - inputs = fastmath.random.uniform( - keys[1], [batch_size] + list(input_shape)) - targets = fastmath.random.randint( - keys[2], [batch_size], dtype=jnp.int32, minval=0, maxval=n_classes) - weights = fastmath.random.uniform(keys[3], [batch_size]) - if with_weights: - yield inputs, targets, weights - else: - yield inputs, targets - - def input_stream_masked(n_devices): - return inputs_lib.add_loss_weights(input_stream(n_devices)) - - return inputs_lib.Inputs(input_stream_masked) - - -def _test_inputs_lm(vocab_size, seq_len, per_device_batch_size=2): - """Make trainer_lib.inputs.Inputs for language model.""" - batch_size = per_device_batch_size * jax.device_count() - - def input_stream(_): - def make_batch(key): - return fastmath.random.randint( - key, [batch_size, seq_len], dtype=jnp.int32, minval=0, - maxval=vocab_size) - key = fastmath.random.get_prng(0) - while True: - keys = fastmath.random.split(key, 3) - key = keys[0] - inputs = make_batch(keys[1]) - targets = make_batch(keys[2]) - yield inputs, targets - - def input_stream_masked(n_devices): - return inputs_lib.add_loss_weights(input_stream(n_devices)) - - return inputs_lib.Inputs(input_stream_masked) - - - -BACKENDS = [fastmath.Backend.JAX] -BACKENDS_AND_CONFIGS = [(fastmath.Backend.JAX, [('Simple', None)])] - - -def short_name(b): - if b == fastmath.Backend.JAX: - return 'jax' - else: - return 'tf' - - -def opt_name(opt): - if opt is None: - return 'None' - return opt.__name__ - - -def _pure_lsh_self_attention_fn(n_chunks_after=0): - return functools.partial( - tl.PureLSHSelfAttentionWrapper, - attention_dropout=0.1, - chunk_len=16, - n_buckets=[32, 32], - n_chunks_after=n_chunks_after, - n_chunks_before=1, - n_hashes=2, - n_parallel_heads=1, - max_length_for_buckets=1024, - predict_drop_len=128, - predict_mem_len=1024, - num_weights=2, - bias=False, - pure_lsh_implementation=tl.PureLSHSelfAttention, - ) - - -def _mixed_lsh_self_attention_fn(n_chunks_after=0): - return functools.partial( - tl.PureLSHSelfAttentionWrapper, - attention_dropout=0.1, - chunk_len=16, - n_buckets=[32, 32], - n_chunks_after=n_chunks_after, - n_chunks_before=1, - 
n_hashes=2, - n_parallel_heads=1, - max_length_for_buckets=1024, - predict_drop_len=128, - predict_mem_len=1024, - num_weights=2, - bias=False, - pure_lsh_implementation=tl.MixedLSHSelfAttention, - ) - - -class TraxTest(parameterized.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super().__init__(methodName) - if npe.tpu_devices(): - # Initialize TPU for TF - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local') - tf.tpu.experimental.initialize_tpu_system(resolver) - - def setUp(self): - super().setUp() - test_utils.ensure_flag('test_tmpdir') - self._old_is_allow_float64 = tf_np.is_allow_float64() - tf_np.set_allow_float64(False) - - def tearDown(self): - tf_np.set_allow_float64(self._old_is_allow_float64) - super().tearDown() - - def _test_train_eval_predict(self, backend, model_name='Simple', - optimizer=None): - with fastmath.use_backend(backend): - # Prepare model and inputs - steps = 2 - eval_steps = 2 - - if model_name == 'Simple': - n_classes = 4 - # Adds Dropout and BatchNorm to test state handling. - def model_fn(mode='train'): - return tl.Serial( - tl.Dropout(mode=mode, rate=0.1), - tl.BatchNorm(mode=mode), - models.MLP(layer_widths=(16, 16, n_classes), mode=mode)) - inputs = _test_inputs(n_classes) - n_in = 1 - elif model_name == 'Resnet50': - n_classes = 4 - model_fn = models.Resnet50 - inputs = _test_inputs(n_classes, input_shape=(224, 224, 3)) - n_in = 1 - elif model_name == 'Transformer': - vocab_size = 32 - seq_len = 16 - inputs = _test_inputs_lm(vocab_size, seq_len) - model_fn = functools.partial( - models.Transformer, - input_vocab_size=vocab_size) - n_in = 2 - else: - raise ValueError('Unrecognized model name: ' + model_name) - - kwargs = {} - if optimizer is not None: - kwargs['optimizer'] = optimizer - - # Train and evaluate - output_dir = self.create_tempdir().full_path - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1, # eval at every step. - **kwargs) - - # Assert total train steps - self.assertEqual(steps, loop.step) - - inputs = inputs.train_stream(1) - - # Predict with final weights - model = model_fn() - weights = loop.model.weights - state = loop.model.state - model(next(inputs)[:n_in], weights=weights, state=state) - - # Predict with weights loaded from file. 
- model = model_fn() - model.init_from_file(os.path.join(output_dir, 'model.pkl.gz')) - model(next(inputs)[:n_in]) - - @parameterized.named_parameters( - ('_%s_%s_%s' % (short_name(backend), model_name, opt_name(opt)), # pylint: disable=g-complex-comprehension - backend, model_name, opt) - for backend, configs in BACKENDS_AND_CONFIGS - for model_name, opt in configs) - def test_train_eval_predict(self, backend, model_name, opt): - self._test_train_eval_predict(backend, model_name, opt) - - @parameterized.parameters(BACKENDS) - def test_train_eval_predict_sm3(self, backend): - self._test_train_eval_predict(backend, 'Simple', trax_opt.SM3) - - @parameterized.parameters(BACKENDS) - def test_train_restart(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Restart training - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=(2 * steps), - eval_steps=eval_steps, - eval_frequency=1) - - # Assert total train steps - self.assertEqual(loop.step, 2 * steps) - - @parameterized.parameters(BACKENDS) - def test_train_permanent_checkpoints(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 5 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - - # Steps 1 -> 5 - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1, - permanent_checkpoint_frequency=2) - - # Steps 6 -> 10 - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=(2 * steps), - eval_steps=eval_steps, - eval_frequency=1, - permanent_checkpoints_at=[7, 8, 10]) - - path = os.path.join(output_dir, 'model.pkl.gz') - self.assertTrue(tf.io.gfile.exists(path)) - - for step in range(11): - filename = 'model_{}.pkl.gz'.format(step) - path = os.path.join(output_dir, filename) - if step in [1, 2, 4, 7, 8, 10]: - self.assertTrue(tf.io.gfile.exists(path), - msg='No model for step: {} in dir {}.'.format( - step, tf.io.gfile.listdir(output_dir))) - else: - self.assertFalse(tf.io.gfile.exists(path), - msg='Model for step: {} in dir {}.'.format( - step, tf.io.gfile.listdir(output_dir))) - - # Assert total train steps - self.assertEqual(loop.step, 10) - - @parameterized.parameters(BACKENDS) - def test_train_restart_with_same_steps(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Restart training - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Assert total train steps - self.assertEqual(loop.step, steps) - - def 
test_train_with_pure_lsh_attention(self, backend=fastmath.Backend.JAX): - with fastmath.use_backend(backend): - # Prepare model and inputs - def model(mode='train'): - return models.ConfigurableTerraformer( - mode=mode, - d_model=16, - d_ff=16, - n_heads=2, - dropout=0.05, - n_decoder_layers=1, - n_encoder_layers=1, - input_vocab_size=256, - encoder_attention_type=_pure_lsh_self_attention_fn(), - encoder_decoder_attention_type=_pure_lsh_self_attention_fn(), - ) - - max_len = 128 - inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len) - - steps = 1 - eval_steps = 1 - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Read checkpoint - model_file = os.path.join(output_dir, 'model.pkl.gz') - - shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32) - shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32) - - model_predict = model(mode='predict') - model_predict.init_from_file( - model_file, weights_only=True, input_signature=(shape1l, shape11)) - - def test_train_with_mixed_lsh_attention(self, backend=fastmath.Backend.JAX): - with fastmath.use_backend(backend): - # Prepare model and inputs - - def model(mode='train'): - return models.ConfigurableTerraformer( - mode=mode, - d_model=16, - d_ff=16, - n_heads=2, - dropout=0.05, - n_decoder_layers=1, - n_encoder_layers=1, - input_vocab_size=256, - encoder_attention_type=_mixed_lsh_self_attention_fn(), - encoder_decoder_attention_type=_mixed_lsh_self_attention_fn(), - ) - - max_len = 128 - inputs = _test_inputs_lm(vocab_size=256, seq_len=max_len) - - steps = 1 - eval_steps = 1 - - # Train and evaluate - output_dir = self.create_tempdir().full_path - trainer_lib.train( - output_dir, - model=model, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1) - - # Read checkpoint - model_file = os.path.join(output_dir, 'model.pkl.gz') - - shape11 = trax_shapes.ShapeDtype((1, 1), dtype=jnp.int32) - shape1l = trax_shapes.ShapeDtype((1, max_len), dtype=jnp.int32) - - model_predict = model(mode='predict') - model_predict.init_from_file(model_file, weights_only=True, - input_signature=(shape1l, shape11)) - - @parameterized.parameters(BACKENDS) - def test_train_fills_in_missing_eval_metrics(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - additional_eval_stream = trainer_lib.NamedStream( - # deliberately duplicating eval data - stream=inputs.eval_stream(1), - name='additional_eval_task') - - # Train and evaluate - output_dir = self.create_tempdir().full_path - loop = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps, - eval_frequency=1, - additional_eval_streams=[additional_eval_stream]) - - self.assertLen(loop.eval_tasks, 2) - eval_task_1, eval_task_2 = loop.eval_tasks - self.assertCountEqual(eval_task_1.metrics, eval_task_2.metrics) - self.assertCountEqual(eval_task_1.metric_names, eval_task_2.metric_names) - - @parameterized.named_parameters( - ('_%s' % short_name(backend), backend) - for backend in BACKENDS) - def test_train_with_weights(self, backend): - with fastmath.use_backend(backend): - # Prepare model and inputs - n_classes = 4 - steps = 2 - eval_steps = 2 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 
16, n_classes)) - inputs = _test_inputs(n_classes, with_weights=True) - - # Train and evaluate - output_dir = self.create_tempdir().full_path - state = trainer_lib.train( - output_dir, - model=model_fn, - inputs=inputs, - steps=steps, - eval_steps=eval_steps) - - # Assert total train steps - self.assertEqual(state.step, steps) - - @parameterized.parameters(BACKENDS) - def test_reset_twice(self, backend): - with fastmath.use_backend(backend): - n_classes = 4 - model_fn = functools.partial(models.MLP, - layer_widths=(16, 16, n_classes)) - inputs = _test_inputs(n_classes) - - trainer = trainer_lib.Trainer( - model=model_fn, - loss_fn=tl.WeightedCategoryCrossEntropy(), - optimizer=trax_opt.SM3, - lr_schedule=lr.multifactor(), - inputs=inputs, - ) - - output_dir1 = self.create_tempdir(name='output_dir1').full_path - trainer.reset(output_dir1) - trainer.evaluate(1) - output_dir2 = self.create_tempdir(name='output_dir2').full_path - trainer.reset(output_dir2) - trainer.evaluate(1) - - def test_tf_xla_forced_compile(self): - # TODO(wangpeng): re-enable this test - self.skipTest('Needs --config=cuda to pass this test') - old_flag = fastmath.tf.tf_xla_forced_compile_enabled() - fastmath.tf.set_tf_xla_forced_compile(True) - self._test_train_eval_predict('tf') - fastmath.tf.set_tf_xla_forced_compile(old_flag) - - - -class EpochsTest(absltest.TestCase): - - def test_cuts_epoch_when_total_steps_reached(self): - epoch_steps = trainer_lib.epochs( - total_steps=5, steps_to_skip=0, epoch_steps=[1, 2, 3]) - self.assertEqual(list(epoch_steps), [1, 2, 2]) - - def test_skips_full_epoch(self): - epoch_steps = trainer_lib.epochs( - total_steps=4, steps_to_skip=2, epoch_steps=[2, 2]) - self.assertEqual(list(epoch_steps), [2]) - - def test_skips_part_of_epoch(self): - epoch_steps = trainer_lib.epochs( - total_steps=4, steps_to_skip=1, epoch_steps=[2, 2]) - self.assertEqual(list(epoch_steps), [1, 2]) - - -if __name__ == '__main__': - config.config_with_absl() - tf.compat.v1.enable_eager_execution() - absltest.main() diff --git a/trax/supervised/training.py b/trax/supervised/training.py deleted file mode 100644 index e65709ae1..000000000 --- a/trax/supervised/training.py +++ /dev/null @@ -1,1387 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simplified API (under development) for supervised learning/training in Trax. - -This module will eventually replace :py:class:`trainer_lib.Trainer`. - -Key classes: - - - :py:class:`Loop`: Core training loop for an n-step training session, - starting from random initialization. - - - :py:class:`TrainTask`: Labeled data + feedback mechanism (loss function w/ - optimizer) for modifying a model's weights. - - - :py:class:`Optimizer`: How to compute model weight updates using - loss-derived gradients. May contain state ("slots", 1-1 with model weights) - that accumulates across training steps. (This class is defined in the - :py:class:`trax.optimizers`.) 
- - - :py:class:`EvalTask`: How and when to measure model performance as a - function of training step number. -""" -import collections -import contextlib -import functools -import gzip as gzip_lib -import os -import pickle -import random -import sys -import time - -from absl import logging -import gin -import jax -import numpy as np -import psutil -import tensorflow as tf - -from trax import fastmath -from trax import jaxboard -from trax import layers as tl -from trax import optimizers -from trax import shapes -from trax.data import inputs -from trax.fastmath import numpy as jnp -from trax.fastmath import random as jax_random -from trax.layers import base -from trax.supervised import history as trax_history - - -_Evaluator = collections.namedtuple( - '_Evaluator', ['weights', 'state', 'metrics_fn'] -) - - -class Loop: - """Loop that can run for a given number of steps to train a supervised model. - - Can train the model on multiple tasks by interleaving updates according to the - ``which_task`` argument. - - The typical supervised training process randomly initializes a model and - updates its weights via feedback (loss-derived gradients) from a training - task, by looping through batches of labeled data. A training loop can also - be configured to run periodic evals and save intermediate checkpoints. - - For speed, the implementation takes advantage of JAX's composable function - transformations (specifically, ``jit`` and ``grad``). It creates JIT-compiled - pure functions derived from variants of the core model; schematically: - - - training variant: `jit(grad(pure_function(model+loss)))` - - evals variant: `jit(pure_function(model+evals))` - - In training or during evals, these variants are called with explicit - arguments for all relevant input data, model weights/state, optimizer slots, - and random number seeds: - - - batch: labeled data - - model weights/state: trainable weights and input-related state (e.g., as - used by batch norm) - - optimizer slots: weights in the optimizer that evolve during the training - process - - random number seeds: JAX PRNG keys that enable high-quality, distributed, - repeatable generation of pseudo-random numbers - """ - - def __init__( - self, - model, - tasks, - eval_model=None, - eval_tasks=None, - output_dir=None, - checkpoint_at=None, - checkpoint_low_metric=None, - checkpoint_high_metric=None, - permanent_checkpoint_at=None, - eval_at=None, - which_task=None, - n_devices=None, - random_seed=None, - loss_chunk_size=0, - use_memory_efficient_trainer=False, - adasum=False, - callbacks=None, - ): - """Configures a training ``Loop``, including a random initialization. - - Args: - model: Trax layer, representing the core model to be trained. Loss - functions and eval functions (a.k.a. metrics) are considered to be - outside the core model, taking core model output and data labels as - their two inputs. - tasks: List of :py:class:`TrainTask` instances, which define the training - data, loss function, and optimizer to be used in respective tasks in - this training loop. It can also be a single :py:class:`TrainTask` - instance which is treated in the same way as a singleton list. - eval_model: Optional Trax layer, representing model used for evaluation, - e.g., with dropout turned off. If ``None``, the training model (model) - will be used. - eval_tasks: List of :py:class:`EvalTask` instances which define how to - evaluate the model: which validation data to use and which metrics to - report. 
Evaluation on each of the tasks will run and be reported
-          separately, which allows scoring a model on different subtasks. This
-          argument can also be ``None``, in which case no evals will be run,
-          or a single :py:class:`EvalTask`, which will be treated in the same
-          way as a singleton list.
-      output_dir: Path telling where to save outputs (evals and checkpoints).
-          Can be ``None`` if both ``eval_task`` and ``checkpoint_at`` are
-          ``None``.
-      checkpoint_at: Function (integer --> boolean) telling, for step n,
-          whether that step should have its checkpoint saved. If ``None``, the
-          default is periodic checkpointing at ``task.n_steps_per_checkpoint``.
-      checkpoint_low_metric: Name of metric, or None. The metric name must
-          be one of the metric names from the evals in ``eval_tasks``. At
-          checkpoint times determined by ``checkpoint_at``, a separate
-          specially named checkpoint will be saved (overwriting any previous
-          version) if the designated metric reaches a value less than or equal
-          to any previous recorded low value. No such checkpoint is saved if
-          arg value is `None`.
-      checkpoint_high_metric: Name of metric, or None. The metric name must
-          be one of the metric names from the evals in ``eval_tasks``. At
-          checkpoint times determined by ``checkpoint_at``, a separate
-          specially named checkpoint will be saved (overwriting any previous
-          version) if the designated metric reaches a value greater than or
-          equal to any previous recorded high value. No such checkpoint is
-          saved if arg value is `None`.
-      permanent_checkpoint_at: Function (integer --> boolean) telling,
-          for step n, whether that step should have its checkpoint saved
-          permanently. If ``None``, the default is periodic checkpointing at
-          ``task.n_steps_per_permanent_checkpoint``.
-      eval_at: Function (integer --> boolean) that says, for training step n,
-          whether that step should run evals. If ``None``, run evals on the
-          first step and on every N'th step, as determined by the first
-          training task.
-      which_task: Function (integer --> integer) indicating which task should
-          be used at which training step. Can be set to ``None`` in
-          single-task training.
-      n_devices: integer or ``None``, the number of devices for this
-          computation.
-      random_seed: the random seed to use; time/os dependent if ``None``
-          (default).
-      loss_chunk_size: int, if > 0 use chunks of this size to make loss
-          computation more memory-efficient.
-      use_memory_efficient_trainer: whether to use a special memory-efficient
-          trainer; if set to 2, memory savings are very aggressive.
-      adasum: if True, use adaptive summation for multi-device gradients.
-      callbacks: List of subclasses of StepCallback to call on training
-          steps.
-    """
-    self._is_chief, self._n_hosts, self._n_devices, self._rng = (
-        init_host_and_devices(n_devices, random_seed))
-    if use_memory_efficient_trainer:
-      self._rng = tl.on_cpu(self._rng)
-
-    # Handle single task case without lists too.
-    if not isinstance(tasks, (list, tuple)):
-      tasks = [tasks]
-
-    if not tasks:
-      raise ValueError('Must provide at least one training task.')
-    if eval_tasks is None:
-      eval_tasks = []
-      eval_at = _never
-    else:
-      if not isinstance(eval_tasks, (list, tuple)):
-        eval_tasks = [eval_tasks]
-
-    self._tasks = tasks
-    self._model = model
-    self._eval_model = eval_model or model
-
-    self._use_memory_efficient_trainer = use_memory_efficient_trainer
-    self._loss_chunk_size = loss_chunk_size
-    self._adasum = adasum
-    # TODO(lukaszkaiser): can we have different eval models and save memory?
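A common pattern behind the `model` / `eval_model` pair above is a single constructor taking a `mode` argument, as `trainer_lib.train` does with `model(mode='train')` and `model(mode='eval')`; a hedged sketch:

from trax import layers as tl

def model_fn(mode='train'):
  # Same architecture in both modes; dropout is active only in 'train'.
  return tl.Serial(tl.Dense(16), tl.Dropout(rate=0.1, mode=mode), tl.Dense(2))

# e.g. Loop(model_fn(mode='train'), ..., eval_model=model_fn(mode='eval'))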
-    if use_memory_efficient_trainer:
-      assert len(tasks) == 1, 'only single task supported for now'
-      self._eval_model = model
-
-    default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
-    permanent_default_at = _at_step_1_and_every_nth_step(
-        tasks[0].n_steps_per_permanent_checkpoint)
-    if output_dir is not None:
-      self._output_dir = os.path.expanduser(output_dir)
-      tf.io.gfile.makedirs(self._output_dir)
-      inputs.load_data_counters(self._output_dir)
-    else:
-      self._output_dir = None
-
-    # Prepare training components.
-    self._step = 0
-    self._history = trax_history.History()
-    self._checkpoint_at = checkpoint_at or default_at
-    self._checkpoint_low_metric = checkpoint_low_metric
-    self._checkpoint_high_metric = checkpoint_high_metric
-    self._permanent_checkpoint_at = (
-        permanent_checkpoint_at or permanent_default_at)
-    if which_task is None:
-      # If which_task is not passed, we cycle through the tasks one by one.
-      # If len(tasks) = 1, then which_task is a constant function equal to 0.
-      which_task = lambda n: n % len(tasks)
-    self._which_task = which_task
-
-    # Initialize using the given random seed.
-    # NOTE: If random_seed is None then self._rng will be different on
-    # different hosts, leading to different weights on the different hosts.
-    self._batch_signature = shapes.signature(tasks[0].sample_batch)
-    self._model.rng = self.new_rng()
-    # In the memory-efficient case, we initialize in init_trainer.
-    if not use_memory_efficient_trainer:
-      if _is_uninitialized(self._model):
-        self._model.init(self._batch_signature)
-      self._eval_model.rng = self.new_rng()
-      if _is_uninitialized(self._eval_model):
-        self._eval_model.init(self._batch_signature)
-
-    # To handle the above case (i.e. random_seed = None), we psum the weights
-    # and state and average them.
-    # NOTE: This adds time (how much?) so we prefer not to do it if it is
-    # unnecessary, i.e. random_seed was set.
-    # NOTE: Averaging the weights across devices can screw up the initial
-    # weight statistics.
-    # TODO(pkozakowski): Broadcast from one of the devices instead?
-    if (random_seed is None and self._n_hosts > 1 and
-        not use_memory_efficient_trainer):
-      logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
-      # Do self._sync_weights_and_state_across_hosts() but layer-by-layer
-      # to save memory.
-      blocks, last_layer = optimizers.trainer.extract_reversible_blocks(
-          [self._model])
-      all_layers = []
-      for (std_layer, rev_layers) in blocks:
-        all_layers.append(tl.Serial(std_layer))
-        all_layers.extend(rev_layers)
-      all_layers.append(last_layer)
-      for layer in all_layers:
-        weights_and_state = (layer.weights, layer.state)
-        if not _is_empty(weights_and_state):
-          layer.weights, layer.state = tl.on_cpu(self._unreplicate(
-              _make_weights_and_state_same_across_hosts(
-                  self._for_n_devices(weights_and_state))))
-
-    # Create the optimizer for the training loss function.
-    self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
-
-    # Sync layer weights/state in memory-efficient trainer layers.
-    if (random_seed is None and self._n_hosts > 1 and
-        use_memory_efficient_trainer):
-      logging.info('Syncing layers across %d hosts.', self._n_hosts)
-      for layer in self._trainer_per_task[0].all_layers:
-        weights_and_state = (layer.weights, layer.state)
-        if not _is_empty(weights_and_state):
-          layer.weights, layer.state = tl.on_cpu(self._unreplicate(
-              _make_weights_and_state_same_across_hosts(
-                  self._for_n_devices(weights_and_state))))
-
-    # Load checkpoint if it exists.
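`checkpoint_at`, `permanent_checkpoint_at`, and `eval_at` above are plain `step -> bool` predicates. A sketch of the default policy (mirroring the `_at_step_1_and_every_nth_step` helper used above) next to an explicit-list policy; names and step values here are illustrative:

def at_step_1_and_every_nth_step(period):
  # Default-style policy: fire at step 1 and then every `period` steps.
  if period is None:
    return lambda step: False
  return lambda step: step == 1 or (step > 0 and step % period == 0)

checkpoint_at = at_step_1_and_every_nth_step(100)
permanent_checkpoint_at = lambda step: step in (1000, 5000)  # explicit steps
assert checkpoint_at(1) and checkpoint_at(200) and not checkpoint_at(150)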
- self.load_checkpoint() - - # Prepare eval components. - self._eval_at = eval_at or default_at - self._eval_tasks = eval_tasks - loss_names = [task.loss_name for task in self._tasks] - metric_names = [ - name # pylint: disable=g-complex-comprehension - for eval_task in self._eval_tasks - for name in eval_task.metric_names - ] - self._rjust_len = max(map(len, loss_names + metric_names)) - self._evaluator_per_task = tuple( - self._init_evaluator(eval_task) for eval_task in self._eval_tasks) - - if self._output_dir is None: - _log('Will not write evaluation metrics, because output_dir is None.') - - def task_output_dir(task_index, task_list): - if self._output_dir is not None: - if len(task_list) < 2: - output_dir = self._output_dir - else: - output_dir = os.path.join( - self._output_dir, - task_list[task_index].export_prefix or str(task_index)) - tf.io.gfile.makedirs(output_dir) - return output_dir - else: - return None - self._output_dir_per_eval_task = [ - task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))] - self._output_dir_per_train_task = [ - task_output_dir(i, tasks) for i in range(len(tasks))] - - callbacks = callbacks or [] - self._callbacks = [ - callback_class(self) for callback_class in callbacks - ] - - def _init_trainer(self, task): - """Initializes the per-task trainer.""" - # Build the per-task model, sharing weights with other tasks. - if not self._use_memory_efficient_trainer: - model_in_training = _model_with_ends( - self._model, - [task.loss_layer], - shapes.signature(task.sample_batch) - ) - if base.N_WEIGHTS_SHARDS > 1: - sharded_weights = fastmath.nested_map( - lambda x: x[0], tl.shard(model_in_training.weights)) - task.optimizer.tree_init(sharded_weights) - else: - task.optimizer.tree_init(model_in_training.weights) - return optimizers.Trainer( - model_in_training, task.optimizer, adasum=self._adasum) - # In the memory-efficient path, we initialize the model here. - blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( - [self._model, task.loss_layer], loss_chunk_size=self._loss_chunk_size) - rng = self._model.rng - sig = shapes.signature(task.sample_batch) - optimizers.trainer.init_reversible_blocks(blocks, loss_layer, sig, rng) - # TODO(lukaszkaiser): here optimizer is a function, revisit this. - return optimizers.ReversibleSerialTrainer( - blocks, loss_layer, task.optimizer, - free_accelerators_on_step=(self._use_memory_efficient_trainer == 2), - adasum=self._adasum) - - def _init_evaluator(self, eval_task): - """Initializes the per-task evaluator.""" - model_with_metrics = _model_with_metrics( - self._eval_model, eval_task) - if self._use_memory_efficient_trainer: - return _Evaluator( - weights=tl.on_cpu(model_with_metrics.weights[1]), - state=tl.on_cpu(model_with_metrics.state[1]), - metrics_fn=_accelerate_model_with_metrics(model_with_metrics, 0) - ) - else: - return _Evaluator( - # Replicate the eval part of weights and state. 
-          weights=self._for_n_devices(model_with_metrics.weights[1]),
-          state=self._for_n_devices(model_with_metrics.state[1]),
-          metrics_fn=_accelerate_model_with_metrics(
-              model_with_metrics, self.n_devices)
-      )
-
-  def _sync_weights_and_state_across_hosts(self):
-    """Syncs weights and state across all the hosts in the computation."""
-
-    if logging.vlog_is_on(1):
-      logging.debug(
-          'Input training weights shape: %s',
-          fastmath.nested_map(lambda x: x.shape,
-                              self._model.weights))
-      logging.debug('Input training weights: %s', self._model.weights)
-      logging.debug('Input training state: %s', self._model.state)
-      logging.debug('Input eval weights: %s', self._eval_model.weights)
-      logging.debug('Input eval state: %s', self._eval_model.state)
-
-    (self._model.weights, self._model.state,
-     self._eval_model.weights, self._eval_model.state) = self._unreplicate(
-         _make_weights_and_state_same_across_hosts(
-             self._for_n_devices(
-                 (self._model.weights, self._model.state,
-                  self._eval_model.weights,
-                  self._eval_model.state))))
-
-    if logging.vlog_is_on(1):
-      logging.debug(
-          'Output training weights shape: %s',
-          fastmath.nested_map(lambda x: x.shape, self._model.weights))
-      logging.debug('Output training weights: %s', self._model.weights)
-      logging.debug('Output training state: %s', self._model.state)
-      logging.debug('Output eval weights: %s', self._eval_model.weights)
-      logging.debug('Output eval state: %s', self._eval_model.state)
-
-  def run(self, n_steps=1):
-    """Runs this training loop for n steps.
-
-    Optionally runs evals and saves checkpoints at specified points.
-
-    Args:
-      n_steps: Stop training after completing n steps.
-    """
-    with self._open_summary_writers() as (
-        train_summary_writers, eval_summary_writers):
-      process = psutil.Process(os.getpid())
-      loss_acc, step_acc = 0.0, 0
-      start_time = time.time()
-      optimizer_metrics_acc = collections.defaultdict(float)
-      for i in range(n_steps):
-        prev_task_index = self._which_task(self._step)
-        self._step += 1
-        task_index = self._which_task(self._step)
-        task_changed = task_index != prev_task_index
-
-        if task_changed:
-          loss_acc, step_acc = 0.0, 0
-
-        loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
-
-        # optimizer_metrics and loss are replicated on self.n_devices; a few
-        # metrics (e.g. gradients_l2, weights_l2) are the same across devices,
-        # whereas some (e.g. loss) aren't, because they are computed from
-        # different data on different devices. Taking the average does the
-        # correct thing in both cases.
-        #
-        # NOTE: Only the weights and gradients are synced across the hosts.
-        # This implies the loss here is averaged from this host's devices and
-        # not across all hosts.
-        optimizer_metrics, loss = fastmath.nested_map(
-            functools.partial(tl.mean_or_pmean, self._n_devices),
-            (optimizer_metrics, loss))
-
-        loss_acc += loss
-        # Log loss every 50 steps, every step in memory-efficient trainer.
-        if self._step % 50 == 0 or self._use_memory_efficient_trainer:
-          self._log_step('Loss: %.4f' % loss, stdout=False)
-        step_acc += 1
-        for metric_name, value in optimizer_metrics.items():
-          optimizer_metrics_acc[metric_name] += value
-
-        # TODO(yuwenyan): Find a way to log the last round's eval step in
-        # history.
-        #
-        # Right now, the last round's eval log is missing from history since
-        # the checkpoint is saved before it. However, sometimes the eval step
-        # will fail for some reason, and it's not acceptable to lose the whole
-        # checkpoint in this case.
-        # Stay with the old way for now, and fix it when the checkpoint
-        # format is changed to storing weights separately from a small file
-        # with history and other data.
-        if self._checkpoint_at(self.step):
-          self.save_checkpoint('model')
-        if self._permanent_checkpoint_at(self.step):
-          self.save_checkpoint(f'model_{self.step}')
-        if self._eval_at(self.step):
-          logging.info('cpu memory use (MB): %.2f',
-                       process.memory_info().rss / float(1024 * 1024))
-          elapsed_time = time.time() - start_time
-          self._log_training_progress(
-              task=self._tasks[task_index],
-              total_loss=loss_acc,
-              n_steps=step_acc,
-              elapsed_time=elapsed_time,
-              optimizer_metrics=optimizer_metrics_acc,
-              summary_writer=train_summary_writers[task_index],
-          )
-          self.run_evals(eval_summary_writers)
-          loss_acc, step_acc = 0.0, 0
-          start_time = time.time()
-          optimizer_metrics_acc = collections.defaultdict(float)
-
-        # For the current step, after all evals are run and recorded in the
-        # event history, check if we need to save special checkpoints because
-        # of a new low metric value or a new high metric value.
-        if self._checkpoint_at(self.step):
-          if self._checkpoint_low_metric is not None and self._at_lowest():
-            self.save_checkpoint(f'lowest_{self._checkpoint_low_metric}')
-          if self._checkpoint_high_metric is not None and self._at_highest():
-            self.save_checkpoint(f'highest_{self._checkpoint_high_metric}')
-
-    # Store the final values back into their respective objects, for testing
-    # or other inspection/use.
-    #
-    # We keep the standard model weights/state unreplicated and
-    # tl.Accelerate(model) will carry the replicated weights/state.
-    # TODO(afrozm): Try to use tl.Accelerate(model) everywhere in the Loop.
-    self._eval_model.weights = self._model.weights
-
-  def _at_lowest(self):
-    low_items = self.history.get('eval',
-                                 f'metrics/{self._checkpoint_low_metric}')
-    vals = [float(obj[1]) for obj in low_items]
-    return vals[-1] == min(vals)
-
-  def _at_highest(self):
-    high_items = self.history.get('eval',
-                                  f'metrics/{self._checkpoint_high_metric}')
-    vals = [float(obj[1]) for obj in high_items]
-    return vals[-1] == max(vals)
-
-  @property
-  def step(self):
-    """Returns current step number in this training session."""
-    return self._step
-
-  @property
-  def history(self):
-    """Returns history in this training session."""
-    return self._history
-
-  @property
-  def n_devices(self):
-    """Returns the number of devices to be used in this computation."""
-    return self._n_devices
-
-  @property
-  def is_chief(self):
-    """Returns true if this Loop is the chief."""
-    return self._is_chief
-
-  @property
-  def model(self):
-    """Returns the model that is training."""
-    return self._model
-
-  @property
-  def tasks(self):
-    """Returns the training tasks."""
-    return self._tasks
-
-  @property
-  def eval_model(self):
-    """Returns the model used for evaluation."""
-    return self._eval_model
-
-  @property
-  def eval_tasks(self):
-    """Returns the evaluation tasks."""
-    return self._eval_tasks
-
-  @property
-  def output_dir(self):
-    """Returns the output directory."""
-    return self._output_dir
-
-  def new_rng(self):
-    """Returns a new single-use random number generator (JAX PRNG key)."""
-    self._rng, rng = fastmath.random.split(self._rng)
-    if self._use_memory_efficient_trainer:
-      self._rng = tl.on_cpu(self._rng)
-      rng = tl.on_cpu(rng)
-    return rng
-
-  def _for_n_devices(self, x):
-    """Replicates/broadcasts ``x`` for n devices if ``self.n_devices > 1``."""
-    return tl.for_n_devices(x, self.n_devices)
-
-  def _unreplicate(self, x):
-    if self.n_devices == 1:
-      return x
-
-    unreplicate_fn = lambda x: x[0]
-    return fastmath.nested_map(unreplicate_fn, x)
-
-  def _reshape_by_device(self, x):
-    if self.n_devices == 1:
-      return x
-    return tl.reshape_by_device(x, self.n_devices)
-
-  def update_weights_and_state(self, weights=None, state=None):
-    """Updates the weights and state of the trained model.
-
-    Sends this data both to the singleton model accessible via Loop.model
-    and to the replicated model on the accelerator.
-
-    Useful when the weights or state are modified outside of training, e.g.
-    during data collection in RL agents.
-
-    Args:
-      weights: Model weights or ``None``. If ``None``, don't set.
-      state: Model state or ``None``. If ``None``, don't set.
-    """
-    for trainer in self._trainer_per_task:
-      acc_model_with_loss = trainer.accelerated_model_with_loss
-      if weights is not None:
-        self._model.weights = weights
-        acc_model_with_loss.replicate_weights(trainer.model_with_loss.weights)
-      if state is not None:
-        self._model.state = state
-        acc_model_with_loss.replicate_state(trainer.model_with_loss.state)
-
-  def _run_one_step(self, task_index, task_changed):
-    """Updates model weights/state and optimizer slots by running one step.
-
-    Args:
-      task_index (int): Index of the task to train on.
-      task_changed (bool): Whether the task has changed since the last step.
-
-    Returns:
-      Tuple (loss, stats) with the loss value from one step of training and
-      stats, the current optimizer statistics.
-    """
-    step = self.step
-    for callback in self._callbacks:
-      if callback.call_at(step):
-        callback.on_step_begin(step)
-
-    learning_rate = self._tasks[task_index].learning_rate(step)
-    batch = self._tasks[task_index].next_batch()
-    rng = self.new_rng()
-    trainer = self._trainer_per_task[task_index]
-    if task_changed:
-      # Re-replicate weights and state to synchronize them between tasks.
-      self.update_weights_and_state(self._model.weights, self._model.state)
-
-    (loss, stats) = trainer.one_step(
-        batch, rng, step=step, learning_rate=learning_rate
-    )
-
-    for callback in self._callbacks:
-      if callback.call_at(step):
-        callback.on_step_end(step)
-
-    return (loss, stats)
-
-  def _log_training_progress(self, task, total_loss, n_steps, elapsed_time,
-                             optimizer_metrics, summary_writer):
-    """Logs training-related metrics.
-
-    Logs:
-      * current learning rate,
-      * steps per second,
-      * average training loss,
-      * average metrics returned from the optimizer
-    to the provided summary writer. Training loss is also logged to stdout.
-
-    Args:
-      task: Current training task.
-      total_loss: Total training loss accumulated over n_steps training steps.
-      n_steps: Number of steps over which the metrics were accumulated.
-      elapsed_time: Time of execution of n_steps training steps.
-      optimizer_metrics: Dict from optimizer metric name to metric values.
-      summary_writer: Jaxboard summary writer for saving provided metrics.
-    """
-    # Only here, to avoid a potential divide-by-0.
-    n_steps = max(1, n_steps)
-    _log('')  # Separator for visibility on terminals.
-    if self.step == 1:
-      self._log_n_weights()
-    self._log_step('Ran %d train steps in %0.2f secs' % (n_steps, elapsed_time))
-    self.log_summary(
-        {task.loss_name: total_loss / float(n_steps)},
-        summary_writer, 'metrics/', 'train')
-    if self.step == 1:
-      self._save_gin(summary_writer)
-    train_parameters = {
-        'learning_rate': task.learning_rate(self.step),
-        'steps per second': n_steps / elapsed_time,
-    }
-    # Average optimizer_metrics over n_steps.
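The callback hooks in `_run_one_step` above assume objects exposing `call_at(step)`, `on_step_begin(step)`, and `on_step_end(step)`, constructed as `callback_class(loop)`; a minimal stand-in sketch (the real base class lives elsewhere in trax, this one is illustrative):

class PrintEveryN:
  # Matches the interface _run_one_step drives; constructed as callback_class(loop).
  def __init__(self, loop, n=100):
    self._loop, self._n = loop, n

  def call_at(self, step):
    return step % self._n == 0

  def on_step_begin(self, step):
    print('starting step', step)

  def on_step_end(self, step):
    print('finished step', step)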
-    optimizer_metrics = {k: v / n_steps for k, v in optimizer_metrics.items()}
-    train_parameters.update(optimizer_metrics)
-    self.log_summary(
-        train_parameters, summary_writer, 'training/', 'train', stdout=False)
-
-  def _save_gin(self, summary_writer):
-    """Saves the operative gin config."""
-    if not self.is_chief or self._output_dir is None:
-      return
-    config_path = os.path.join(self._output_dir, 'config.gin')
-    config_str = gin.operative_config_str()
-    with tf.io.gfile.GFile(config_path, 'w') as f:
-      f.write(config_str)
-    if summary_writer is not None:
-      summary_writer.text(
-          'gin_config', jaxboard.markdownify_operative_config_str(config_str)
-      )
-
-  def _log_n_weights(self):
-    """Logs the number of weights in the training model."""
-    def _size(x):
-      try:
-        return x.size
-      except Exception:  # pylint: disable=broad-except
-        return 0
-    sizes = fastmath.nested_map(_size, self._model.weights)
-    total_size = sum(fastmath.tree_flatten(sizes))
-    total_size *= base.N_WEIGHTS_SHARDS
-    self._log_step('Total number of trainable weights: %d' % total_size)
-
-  # TODO(afrozm): Fix multi-host evals; right now the reported numbers in the
-  # summary writer are only from the chief and not averaged across hosts.
-  def run_evals(self, summary_writers=None):
-    """Runs and records evals for this training session.
-
-    Args:
-      summary_writers: List of per-task Jaxboard summary writers to log
-        metrics.
-    """
-    if summary_writers is None:
-      summary_writers = (None,) * len(self._eval_tasks)
-
-    self._eval_model.weights = self._model.weights
-    self._eval_model.state = self._model.state
-
-    def recursively_look_for_printable_states(state):
-      if isinstance(state, (tuple, list)):
-        for substate in state:
-          for item in recursively_look_for_printable_states(substate):
-            yield item
-      if isinstance(state, dict):
-        for key, value in state.items():
-          if isinstance(key, str) and key.startswith('summary_'):
-            for device_id, device_value in enumerate(value):
-              yield ('device{}/{}'.format(device_id, key[len('summary_'):]),
-                     device_value)
-
-    # The most recently trained weights are in this trainer; use those for
-    # eval.
-    cur_train_task_index = self._which_task(self._step)
-    trainer = self._trainer_per_task[cur_train_task_index]
-
-    for eval_task_index in range(len(self._eval_tasks)):
-      eval_task = self._eval_tasks[eval_task_index]
-      evaluator = self._evaluator_per_task[eval_task_index]
-      if eval_task is None:
-        continue
-
-      # Extract the actual model weights and state, excluding the loss layer.
-      if self._use_memory_efficient_trainer:
-        model_weights, model_state = self._model.weights, self._model.state
-      else:
-        model_weights = trainer.accelerated_model_with_loss.weights[0]
-        model_state = trainer.accelerated_model_with_loss.state[0]
-
-      # evaluator.{weights,state} are already replicated.
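The eval loop below accumulates metric sums over `n_eval_batches` and divides once, so each reported metric is a per-batch mean; a toy sketch of that reduction with illustrative values:

import numpy as np

metric_names = ('loss', 'accuracy')
sums = np.zeros(len(metric_names))
for metric_values in (np.array([0.9, 0.5]), np.array([0.7, 0.75])):
  sums += metric_values  # one vector of metric values per eval batch
all_metrics = dict(zip(metric_names, sums / 2))
print(all_metrics)  # -> approximately {'loss': 0.8, 'accuracy': 0.625}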
- metrics_weights = (model_weights, evaluator.weights) - metrics_state = (model_state, evaluator.state) - - n_batches = eval_task.n_eval_batches - n_metrics = len(eval_task.metrics) - sums = np.zeros((n_metrics,)) - for _ in range(n_batches): - rng = self.new_rng() - batch = eval_task.next_batch() - metric_values, _ = evaluator.metrics_fn( - batch, metrics_weights, metrics_state, rng) - sums += metric_values - averages = sums / n_batches - all_metrics = dict(zip(eval_task.metric_names, averages)) - summary_writer = summary_writers[eval_task_index] - self.log_summary(all_metrics, summary_writer, 'metrics/', 'eval') - summary_metrics = dict(recursively_look_for_printable_states( - model_state)) - self.log_summary(summary_metrics, summary_writer, 'summary_', 'eval') - - def log_summary(self, values, summary_writer, value_prefix, log_prefix, - stdout=True): - """Logs and saves provided metrics. - - Args: - values: Dict from metric name to metric value. - summary_writer: Jaxboard summary writer. - value_prefix: String appended in front of summary_writer entries. - log_prefix: String appended in front of logs. - stdout: Boolean saying if logs should be logged to stdout as well. - """ - history = self._history - should_write_summaries = self.is_chief and summary_writer is not None - for name, value in values.items(): - full_name = value_prefix + name - s = tuple(jnp.shape(value)) - if not s: - self._log_step( - '%s %s | % .8f' % - (log_prefix.ljust(5), name.rjust(self._rjust_len), value), - stdout=stdout) - if should_write_summaries: - summary_writer.scalar(full_name, value, self.step) - else: - if should_write_summaries: - summary_writer.image(full_name, value, self.step) - if history: - history.append(log_prefix, full_name, self.step, value) - if should_write_summaries: - summary_writer.flush() - - def _log_step(self, msg, stdout=True): - """Logs message, labeled with the current training step number.""" - _log('Step % 6d: %s' % (self.step, msg), stdout=stdout) - - def save_checkpoint(self, basename): - """Saves checkpoint (multiple files) to disk for the current training step. - - Saving a checkpoint will overwrite any previous checkpoint saved with the - same ``basename``. Use differing ``basename`` values to save multiple - checkpoints or multiple copies of the same checkpoint. - - Args: - basename: Basename for saving a checkpoint. Full file paths for the saved - checkpoint will combine the output dir, basename, and relevant file - extensions (e.g., `.weights.npy.gz`). - """ - if self._output_dir is None: - _log('Did not save checkpoint as output_dir is None') - return - - inputs.save_data_counters(self._output_dir) - if not self.is_chief: - _log('Did not save checkpoint as we are not chief.') - return - - dir_and_basename = os.path.join(self._output_dir, basename) - pkl_file = dir_and_basename + '.pkl.gz' - - _log('Saving checkpoint to %s' % pkl_file, stdout=False) - weights = self._model.weights - if base.N_WEIGHTS_SHARDS > 1: - weights = self._trainer_per_task[0].accelerated_model_with_loss.weights - weights = tl.unshard(weights) - state = self._model.state - compresslevel = 0 if self._use_memory_efficient_trainer else 2 - # Serialize optimizer slots. - for i, trainer in enumerate(self._trainer_per_task): - flat_slots = _flatten_and_remove_empty(trainer.slots) - tl.np_to_file(self._to_bits(flat_slots), - f'{dir_and_basename}.opt_slots{i}.npy.gz', - compresslevel=compresslevel) - # We only need the input signature for the body, not for the loss layers. 
- # That part is the same across tasks - take it from the first one. - input_signature = self._batch_signature[:self._model.n_in] - flat_weights, flat_state = tl.flatten_weights_and_state(weights, state) - _, flat_eval_state = tl.flatten_weights_and_state( - weights, self._eval_model.state) - tl.np_to_file(self._to_bits(flat_weights), - f'{dir_and_basename}.weights.npy.gz', - compresslevel=compresslevel) - d = { - 'step': self.step, - 'flat_weights': compresslevel, # for compatibility with older format - 'flat_state': flat_state, - 'flat_eval_state': flat_eval_state, - 'history': self._history.to_dict(), - 'slots_per_task': compresslevel, # for compatibility with older format - 'input_signature': input_signature, - 'version_timestamp': 'Mar-10-2021' # To update in the future if needed. - } - pickle_to_file(d, pkl_file, gzip=True) - _log('Checkpoint saved in %s' % pkl_file, stdout=False) - - def _to_bits(self, weights): - """Converts a list of weights to bit-cast weights and their types.""" - # This is currently needed to pickle bfloat16 arrays from JAX. - # TODO(lukaszkaiser): remove once it is not needed (the following unit test - # checks it: training_test/test_restores_step_bfloat16). - if not fastmath.is_backend(fastmath.Backend.JAX): - return weights - bits = [] - for w in weights: - if w.dtype == jnp.bfloat16: - converted = jax.lax.bitcast_convert_type(w, np.uint16) - bits.append(np.asarray(converted.astype(np.uint16))) - else: # for non-bfloat16 weights, be compatible with earlier checkpoints - bits.append(np.asarray(w)) - return bits - - def _from_bits(self, bits): - """Converts a list of bit-cast weights back to weights.""" - # This is the reverse of _to_bits, see above for explanation. - if not fastmath.is_backend(fastmath.Backend.JAX): - return bits - weights = [] - for b in bits: - if b.dtype == np.uint16: # currently all uint16 are bfloat16s - w = jax.lax.bitcast_convert_type(b, jnp.bfloat16) - weights.append(np.asarray(w)) - else: - weights.append(b) - return weights - - def load_checkpoint(self, directory=None, filename=None): - """Loads model weights and step from a checkpoint on disk. - - Args: - directory: Directory with the checkpoint (self._output_dir by default). - filename: Checkpoint file name (model.pkl.gz by default). - """ - directory = directory or self._output_dir - if directory is None: - _log('Not loading as both directory and output_dir are None.', - stdout=False) - return - filename = filename or 'model' - path = os.path.join(directory, filename) - pkl_path = path + '.pkl.gz' - if not tf.io.gfile.exists(pkl_path): - _log(f'Not loading as checkpoint file does not exist: {pkl_path}', - stdout=False) - return - _log('Loading checkpoint from %s' % pkl_path, stdout=False) - d = unpickle_from_file(pkl_path, gzip=True) - # Weights are stored in a separate non-pickled file in the new checkpoint - # format. We support loading old checkpoints with this hack. - # TODO(lukaszkaiser): remove the hack when not needed any more. - if isinstance(d['flat_weights'], int): - weights = tl.np_from_file(path + '.weights.npy.gz', - compresslevel=d['flat_weights']) - d['flat_weights'] = weights - else: - d['flat_weights'] = d['flat_weights'] - # The same holds for optimizer slots. - if 'slots' in d: # Old checkpoints had just 'slots' for one task. - if len(self._tasks) != 1: - raise ValueError( - 'Can\'t load a single-task checkpoint into a multitask Loop.' - ) - d['slots_per_task'] = [d['slots']] - # Read from separate files if optimizer slots are in them. 
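- # (In that newer format, save_checkpoint stores only the compression
- # level under 'slots_per_task' and writes the actual slots to separate
- # .opt_slots{i}.npy.gz files.)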
- if 'slots_per_task' in d and isinstance(d['slots_per_task'], int):
- compresslevel = d['slots_per_task']
- d['slots_per_task'] = []
- for i in range(len(self._trainer_per_task)):
- slots = tl.np_from_file(path + f'.opt_slots{i}.npy.gz',
- compresslevel=compresslevel)
- d['slots_per_task'].append(slots)
- for (trainer, slots) in zip(self._trainer_per_task, d['slots_per_task']):
- matched_flat_slots = _match_by_shape(
- self._to_bits(_flatten_and_remove_empty(trainer.slots)),
- _flatten_and_remove_empty(slots))
- matched_slots, _ = fastmath.tree_unflatten(
- self._from_bits(matched_flat_slots),
- trainer.slots, copy_from_tree=[None, ()])
- trainer.slots = matched_slots
- self._step = d['step']
- self._history = trax_history.History.from_dict(d['history'])
- # This is self._model.init_from_file but optimized to not re-read.
- input_signature = d['input_signature']
- weights_and_state_sig = self._model.weights_and_state_signature(
- input_signature)
- flat_init_weights, flat_init_state = tl.flatten_weights_and_state(
- self._model.weights, self._model.state)
- if len(d['flat_weights']) < len(flat_init_weights):
- _log('Checkpoint has fewer weights than the model; loading the first ones.')
- matched_weights = _match_by_shape(self._to_bits(flat_init_weights),
- d['flat_weights'])
- matched_weights = self._from_bits(matched_weights)
- try:
- restored_state = True
- matched_state = _match_by_shape(self._to_bits(flat_init_state),
- d['flat_state'])
- matched_state = self._from_bits(matched_state)
- weights, state = tl.unflatten_weights_and_state(
- matched_weights, matched_state, weights_and_state_sig)
- self._model.state = state
- except IndexError:
- _log('Failed to load model state from checkpoint; loading weights only.')
- restored_state = False
- weights, _ = tl.unflatten_weights_and_state(
- matched_weights, (), weights_and_state_sig, weights_only=True)
- self._model.weights = weights
- self._eval_model.weights = self._model.weights
- # Restore eval model state; note: it's not always the same as train state.
- if restored_state:
- if 'flat_eval_state' in d:
- flat_eval_state = d['flat_eval_state']
- else: # It wasn't saved in old checkpoints; remove this branch once done.
- flat_eval_state = d['flat_state']
- _, eval_state = tl.unflatten_weights_and_state(
- matched_weights, flat_eval_state, weights_and_state_sig)
- self._eval_model.state = eval_state
- _log('Checkpoint loaded from %s' % pkl_path, stdout=False)
-
- @contextlib.contextmanager
- def _open_summary_writers(self):
- """Opens the Jaxboard summary writers wrapped by context manager.
-
- Yields:
- A pair (train_summary_writers, eval_summary_writers) of lists of
- Jaxboard summary writers wrapped in a GeneratorContextManager object.
- Elements of the lists correspond to the training and evaluation task
- directories created during initialization. If there was no output_dir
- provided, yields lists of Nones with the appropriate length.
- """ - if self._output_dir is not None: - _log(f'Metrics will be written in {self._output_dir}.', stdout=False) - train_writers = [jaxboard.SummaryWriter(os.path.join(output_dir, 'train')) - for output_dir in self._output_dir_per_train_task] - eval_writers = [jaxboard.SummaryWriter(os.path.join(output_dir, 'eval')) - for output_dir in self._output_dir_per_eval_task] - try: - yield (train_writers, eval_writers) - finally: - for writer in train_writers + eval_writers: - writer.close() - _log(f'Metrics were written in {self._output_dir}', stdout=False) - else: - yield ([None] * len(self._tasks), [None] * len(self._eval_tasks)) - - -def _model_with_ends(model, end_layers, batch_signature): - """Returns a model+ends layer built on an already initialized model. - - Ends can be loss or metric layers. - - Args: - model: Layer with initialized weights and state. - end_layers: List of end layers. - batch_signature: Signature of the model input batch. - - Returns: - An initialized, combined model+ends layer, preserving the initialization - of ``model``. - """ - # TODO(jonni): Redo this function as part of an initialization refactor? - metrics_layer = tl.Branch(*end_layers) - metrics_input_signature = model.output_signature(batch_signature) - _, _ = metrics_layer.init(metrics_input_signature) - - model_with_metrics = tl.Serial(model, metrics_layer) - return model_with_metrics - - -def _model_with_metrics(model, eval_task): - """Returns a model+metrics layer built on an already initialized model. - - Args: - model: Layer with initialized weights and state. - eval_task: :py:class:`EvalTask` instance. - - Returns: - An initialized, combined model+metrics layer, preserving the initialization - of ``model``. - """ - return _model_with_ends( - model, eval_task.metrics, shapes.signature(eval_task.sample_batch) - ) - - -@gin.configurable -class TrainTask: - """A supervised task (labeled data + feedback mechanism) for training.""" - - def __init__(self, labeled_data, loss_layer, optimizer, - lr_schedule=None, n_steps_per_checkpoint=100, - n_steps_per_permanent_checkpoint=None, loss_name=None, - sample_batch=None, export_prefix=None): - r"""Configures a training task. - - Args: - labeled_data: Iterator of batches of labeled data tuples. Each tuple has - 1+ data (input value) tensors followed by 1 label (target value) - tensor. All tensors are NumPy ndarrays or their JAX counterparts. - loss_layer: Layer that computes a scalar value (the "loss") by comparing - model output :math:`\hat{y}=f(x)` to the target :math:`y`. - optimizer: Optimizer object that computes model weight updates from - loss-function gradients. - lr_schedule: Learning rate schedule, a function step -> learning_rate. - n_steps_per_checkpoint: How many steps to run between checkpoints. - n_steps_per_permanent_checkpoint: How many steps to run between permanent - checkpoints. - loss_name: Name for the loss metric. - sample_batch: Optional sample batch for model initialization. If not - provided, it will be taken from ``labeled_data``. - export_prefix: Optional task name to be used as prefix for exporting - metrics during training in Loop. 
- """ - self._export_prefix = export_prefix - self._labeled_data = labeled_data - self._loss_layer = loss_layer - self._optimizer = optimizer - self._lr_schedule = lr_schedule - self._sample_batch = sample_batch or next(labeled_data) - self._n_steps_per_checkpoint = n_steps_per_checkpoint - self._n_steps_per_permanent_checkpoint = n_steps_per_permanent_checkpoint - self._loss_name = loss_name or self._loss_layer.name - - @property - def labeled_data(self): - return self._labeled_data - - @property - def sample_batch(self): - return self._sample_batch - - def next_batch(self): - """Returns one batch of labeled data: a tuple of input(s) plus label.""" - return next(self._labeled_data) - - @property - def export_prefix(self): - return self._export_prefix - - @property - def loss_layer(self): - return self._loss_layer - - @property - def loss_name(self): - return self._loss_name - - @property - def n_steps_per_checkpoint(self): - return self._n_steps_per_checkpoint - - @property - def n_steps_per_permanent_checkpoint(self): - return self._n_steps_per_permanent_checkpoint - - @property - def optimizer(self): - return self._optimizer - - def learning_rate(self, step): - """Return the learning rate for the given step.""" - if self._lr_schedule is not None: - with fastmath.use_backend(fastmath.Backend.NUMPY): - return self._lr_schedule(step) - opt = self._optimizer - if callable(opt): # when optimizer is a function, like Adam, not Adam() - opt = opt() - params = opt._init_opt_params # pylint: disable=protected-access - return params['learning_rate'] - - -@gin.configurable -class EvalTask: - """Labeled data plus scalar functions for (periodically) measuring a model. - - An eval task specifies how (``labeled_data`` + ``metrics``) and with what - precision (``n_eval_batches``) to measure a model as it is training. - The variance of each scalar output is reduced by measuring over multiple - (``n_eval_batches``) batches and reporting the average from those - measurements. - """ - - def __init__(self, labeled_data, metrics, - metric_names=None, n_eval_batches=1, sample_batch=None, - export_prefix=None): - r"""Configures an eval task: named metrics run with a given data source. - - Args: - labeled_data: Iterator of batches of labeled data tuples. Each tuple has - 1+ data tensors (NumPy ndarrays) followed by 1 label (target value) - tensor. - metrics: List of layers; each computes a scalar value per batch by - comparing model output :math:`\hat{y}=f(x)` to the target :math:`y`. - metric_names: List of names, one for each item in ``metrics``, in matching - order, to be used when recording/reporting eval output. If ``None``, - generate default names using layer names from metrics. - n_eval_batches: Integer N that specifies how many eval batches to run; - the output is then the average of the outputs from the N batches. - sample_batch: Optional sample batch for model initialization. If not - provided, it will be taken from ``labeled_data``. - export_prefix: Optional task name to be used as prefix for exporting - metrics during evaluation in Loop. 
- """ - self._export_prefix = export_prefix - self._labeled_data = labeled_data - self._metrics = metrics - self._metric_names = metric_names or self._default_names() - self._n_eval_batches = n_eval_batches # pylint: disable=invalid-name - - self._sample_batch = sample_batch or next(labeled_data) - self._check_init_values() - - @property - def labeled_data(self): - return self._labeled_data - - @property - def sample_batch(self): - return self._sample_batch - - def next_batch(self): - """Returns one batch of labeled data: a tuple of input(s) plus label.""" - return next(self._labeled_data) - - @property - def export_prefix(self): - return self._export_prefix - - @property - def metrics(self): - return self._metrics - - @property - def metric_names(self): - return self._metric_names - - @property - def n_eval_batches(self): - return self._n_eval_batches - - def _default_names(self): - return [m.name for m in self._metrics] - - def _check_init_values(self): - if len(self._metrics) != len(self._metric_names): - raise ValueError( - f'Number of metrics ({len(self._metrics)}) does not equal ' - f'number of metric names ({len(self._metric_names)}).') - - -def _never(*args): - """Returns False for all step numbers.""" - del args - return False - - -def _at_step_1_and_every_nth_step(period): - """A function that's true at 1 and n when n % period == 0.""" - if period is None: - return lambda step_n: False - - def _at_1_and_periodically(step_n): - return (step_n == 1) or (step_n > 0 and (step_n % period == 0)) - return _at_1_and_periodically - - -def _log(s, stdout=True): - logging.info(s) - if stdout: - print(s) - sys.stdout.flush() - - -def pickle_to_file(obj, file_path, gzip=False): - """Pickle obj to file_path with gzipping and failure protection.""" - # Pickle to tmp file and overwrite to prevent writing partial files. - tmp_file_path = file_path + '._tmp_' - with tf.io.gfile.GFile(tmp_file_path, 'wb') as f: - if not gzip: - pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) - else: - with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: - pickle.dump(obj, gzipf, protocol=pickle.HIGHEST_PROTOCOL) - # Moving a file is much less error-prone than pickling large files. - tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True) - - -def unpickle_from_file(file_path, gzip=False): - """Unpickle obj from file_path with gzipping.""" - with tf.io.gfile.GFile(file_path, 'rb') as f: - if not gzip: - obj = pickle.load(f) - else: - with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf: - obj = pickle.load(gzipf) - return obj - - -def _init_random_number_generators(seed=None): - """Initializes random generators for Python, NumPy, TensorFlow, and JAX.""" - # Seed Python random (None as seed is okay), then use it to seed the others. - random.seed(seed) - if seed is None: - seed = random.randint(0, 2**31 - 1) - logging.info('using seed %d', seed) - np.random.seed(seed) - tf.random.set_seed(seed) - return jax_random.get_prng(seed) - - -def init_host_and_devices(n_devices=None, random_seed=None): - """Initializes host and device attributes for this trainer. - - Args: - n_devices: Number of devices this trainer will use. If ``None``, get the - number from the backend. - random_seed: Random seed as the starting point for all random numbers used - by the trainer. If ``None``, calculate one from system time and host id. - - Returns: - is_chief: True if this trainer has special chief responsibilities. - host_count: Number of hosts in this computation. 
- n_devices: The passed-in value of n_devices or a computed default (for this
- host).
- random_seed: The passed-in value of random_seed or a computed default.
- """
- if fastmath.is_backend(fastmath.Backend.JAX):
- host_id = jax.process_index()
- host_count = jax.host_count()
- else:
- host_id = 0
- host_count = 1
- is_chief = (host_id == 0)
-
- logging.info('Initializing hosts and devices: host_id %d, host_count %d, '
- 'is_chief %d', host_id, host_count, is_chief)
-
- device_count = fastmath.local_device_count()
- n_devices = n_devices or device_count
- # TODO(lukaszkaiser): remove this restriction when possible.
- if n_devices != device_count and fastmath.is_backend(fastmath.Backend.JAX):
- raise ValueError('JAX cannot work yet with n_devices != all devices: '
- '%d != %d' % (n_devices, device_count))
-
- if random_seed is None and host_count > 1:
- random_seed = int(1e6 * (host_id + time.time())) % 2**31
- return (is_chief, host_count, n_devices,
- _init_random_number_generators(random_seed))
-
-
-def _accelerate_model_with_metrics(model_with_metrics, n_devices,
- accelerate=True, do_mean=True):
- if not accelerate:
- return model_with_metrics.pure_fn
-
- return tl.jit_forward(model_with_metrics.pure_fn, n_devices, do_mean=do_mean)
-
-
-@functools.partial(fastmath.pmap, axis_name='devices', donate_argnums=(0,))
-def _make_weights_and_state_same_across_hosts(weights_and_state):
- """Makes train and eval model's weights and state the same across hosts."""
-
- # We assume that weights_and_state has already been replicated, i.e., the
- # leading axis is self._n_devices.
-
- # This is the total number of devices across all hosts.
- n_devices_total = fastmath.psum(jnp.array(1.0), 'devices').astype(jnp.int32)
-
- # We average the weights and state across all devices.
- # We also make sure we don't change the type of the weights and state.
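- # Concretely, `psum(x, 'devices')` sums each leaf over every device on
- # every host, so dividing by n_devices_total yields the cross-host mean;
- # the final astype keeps e.g. bfloat16 weights in bfloat16.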
- return fastmath.nested_map(
- lambda x: (fastmath.psum(x, 'devices') / n_devices_total).astype(x.dtype),
- weights_and_state)
-
-
-def _is_empty(x):
- if isinstance(x, (list, tuple)):
- return all(_is_empty(y) for y in x)
- else:
- return x is None
-
-
-def _is_uninitialized(model):
- """Checks whether no weights in the model have been initialized."""
- if not _is_empty(model.weights):
- return False
- return all(_is_uninitialized(l) for l in model.sublayers)
-
-
-def _match_by_shape(full, partial):
- """Merges `partial` into `full`, matching entries by shape."""
- partial_idx = 0
- res = []
- for w in full:
- if partial_idx >= len(partial):
- res.append(w) # already read the whole partial list; just fill from full
- elif w is None and partial[partial_idx] is None: # both Nones, move on
- res.append(None)
- partial_idx += 1
- elif w is None or partial[partial_idx] is None: # one None but not both
- res.append(w)
- elif w.shape == partial[partial_idx].shape:
- res.append(partial[partial_idx])
- partial_idx += 1
- else:
- res.append(w)
- if partial_idx < len(partial):
- _log('Did not manage to match shapes in model for all checkpoint weights.')
- for w in partial[:partial_idx]:
- _log(' Inserted tensor of shape %s' % str(w.shape))
- for i, w in enumerate(partial[partial_idx:]):
- _log(' Not inserted tensor of shape %s' % str(w.shape))
- model_weight_shape = str(full[i + partial_idx].shape)
- _log(' Tensor in that place has shape: %s' % model_weight_shape)
- raise IndexError
- return res
-
-
-def _flatten_and_remove_empty(x):
- flat = fastmath.tree_flatten(x)
- return [f for f in flat if f is not None and f is not ()] # pylint: disable=literal-comparison
diff --git a/trax/supervised/training_test.py b/trax/supervised/training_test.py
deleted file mode 100644
index 7d6d6bf5a..000000000
--- a/trax/supervised/training_test.py
+++ /dev/null
@@ -1,674 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for supervised training: core classes and flows."""
-
-import collections
-import os
-import time
-
-from absl.testing import absltest
-from jax.config import config
-import numpy as np
-
-from trax import data
-from trax import fastmath
-from trax import layers as tl
-from trax import optimizers
-from trax import shapes
-from trax import test_utils
-from trax.layers import base
-from trax.models import transformer
-from trax.supervised import callbacks
-from trax.supervised import training
-
-
-class TrainingTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- test_utils.ensure_flag('test_tmpdir')
-
- def test_loop_no_eval_task(self):
- """Runs a training loop with no eval task(s)."""
- model = tl.Serial(tl.Dense(1))
- task = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
- training_session = training.Loop(model, [task])
- # Loop should initialize and run successfully, even with no eval task.
- training_session.run(n_steps=5) - - - def test_loop_checkpoint_low_metric(self): - """Runs a training loop that saves checkpoints for low metric values.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask(_very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01)) - eval_metric = tl.L2Loss() - eval_task = training.EvalTask(_very_simple_data(), - [eval_metric], - metric_names=['l2_loss']) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, - [task], - eval_tasks=[eval_task], - output_dir=tmp_dir, - eval_at=lambda step_n: step_n % 2 == 0, - checkpoint_at=lambda step_n: step_n % 2 == 0, - checkpoint_low_metric='l2_loss') - call_counter = collections.Counter() - loop.save_checkpoint = lambda name: call_counter.update([name]) - loop.run(n_steps=10) - - # Eval metric steadily descends, so low checkpoint triggered all 5 times. - # High checkpoint not defined, so never triggered. - self.assertEqual(call_counter['model'], 5) - self.assertEqual(call_counter['lowest_l2_loss'], 5) - self.assertEqual(call_counter['highest_l2_loss'], 0) - - def test_loop_checkpoint_high_metric(self): - """Runs a training loop that saves checkpoints for high metric values.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask(_very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01)) - eval_metric = tl.L2Loss() - eval_task = training.EvalTask(_very_simple_data(), - [eval_metric], - metric_names=['l2_loss']) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, - [task], - eval_tasks=[eval_task], - output_dir=tmp_dir, - eval_at=lambda step_n: step_n % 2 == 0, - checkpoint_at=lambda step_n: step_n % 2 == 0, - checkpoint_high_metric='l2_loss') - call_counter = collections.Counter() - loop.save_checkpoint = lambda name: call_counter.update([name]) - loop.run(n_steps=10) - - # Eval metric steadily descends, so high checkpoint triggered only once. - # Low checkpoint not defined, so never triggered. 
- self.assertEqual(call_counter['model'], 5) - self.assertEqual(call_counter['lowest_l2_loss'], 0) - self.assertEqual(call_counter['highest_l2_loss'], 1) - - def test_train_dense_layer(self): - """Trains a very simple network on a very simple task.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=15) - self.assertEqual(15, training_session.step) - training_session.run(n_steps=5) - self.assertEqual(20, training_session.step) - - def test_loop_with_initialized_model(self): - """Check that loop does not re-initialize an already initialized model.""" - model = tl.Serial(tl.Dense(1)) - example_data = next(_very_simple_data()) - model.init(example_data) - w = model.weights[0][0] - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - loop = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0) - self.assertEqual(0, loop.step) - self.assertEqual(loop.model.weights[0][0], w) - - def test_train_save_restore_dense(self): - """Saves and restores a checkpoint to check for equivalence.""" - self.skipTest('Broken by https://github.com/google/jax/pull/11234') - train_data = data.Serial(lambda _: _very_simple_data(), - data.CountAndSkip('simple_data')) - task = training.TrainTask( - train_data(), tl.L2Loss(), optimizers.Adam(.0001)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - - def _make_model_and_session(): - m = tl.Serial(tl.Dense(1)) - ts = training.Loop(m, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - return m, ts - - model, training_session = _make_model_and_session() - self.assertEqual(0, training_session.step) - training_session.run(n_steps=1) - training_session.save_checkpoint('model') - self.assertEqual(data.inputs.data_counters['simple_data'], 2) - data.inputs.data_counters['simple_data'] = 0 # reset manually - self.assertEqual(data.inputs.data_counters['simple_data'], 0) # check - model2, training_session2 = _make_model_and_session() - self.assertEqual(data.inputs.data_counters['simple_data'], 2) # restored - - x = np.ones((8, 1)) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertEqual(str(y1), str(y2)) - - training_session2.run(n_steps=1) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertNotEqual(str(y1), str(y2)) - - slots1 = training_session._trainer_per_task[0].slots - slots2 = training_session2._trainer_per_task[0].slots - np.testing.assert_array_equal(slots1, slots2) - - def test_train_save_restore_sharded(self): - """Saves and restores a sharded checkpoint to check for equivalence.""" - if fastmath.local_device_count() < 2: - return # multi-accelerator only - base.N_WEIGHTS_SHARDS = fastmath.local_device_count() - train_data = data.Serial(lambda _: 
_very_simple_data(2, 2), - data.CountAndSkip('simple_data')) - task = training.TrainTask( - train_data(), tl.L2Loss(), optimizers.Adam(.0001)) - eval_task = training.EvalTask( - _very_simple_data(2, 2), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - - def _make_model_and_session(): - m = tl.Serial(tl.Dense(2)) - ts = training.Loop(m, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - return m, ts - - _, training_session = _make_model_and_session() - self.assertEqual(0, training_session.step) - training_session.run(n_steps=1) - training_session.save_checkpoint('model') - _, training_session2 = _make_model_and_session() - training_session2.run(n_steps=1) - base.N_WEIGHTS_SHARDS = 1 - - def test_train_save_restore_transformer(self): - """Saves and restores a checkpoint to check for equivalence.""" - vocab_size = 8 - task = training.TrainTask( - _very_simple_transformer_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_transformer_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['SGD.L2Loss']) - tmp_dir = self.create_tempdir().full_path - - def _make_model_and_session(): - m = transformer.TransformerLM( - vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.) - ts = training.Loop(m, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - return m, ts - - model, training_session = _make_model_and_session() - self.assertEqual(0, training_session.step) - training_session.run(n_steps=1) - training_session.save_checkpoint('model') - model2, training_session2 = _make_model_and_session() - - x = np.ones((2, 2)).astype(np.int32) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertEqual(str(y1), str(y2)) - - training_session2.run(n_steps=1) - y1 = model(x, rng=fastmath.random.get_prng(0)) - y2 = model2(x, rng=fastmath.random.get_prng(0)) - self.assertNotEqual(str(y1), str(y2)) - - def test_train_dense_layer_with_momentum(self): - """Trains with an optimizer that has slots / requires initialization.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.Momentum(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - metric_names=['Momentum.L2Loss']) - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=20) - self.assertEqual(20, training_session.step) - - def test_train_dense_layer_evals(self): - """Trains a very simple network on a very simple task, 2 epochs.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()]) - training_session = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: False) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=10) - self.assertEqual(10, training_session.step) - training_session.run_evals() - self.assertEqual(10, training_session.step) # Unchanged - - def test_summaries_are_written(self): - """Training writes down metrics when writing is turned on.""" - model = tl.Serial(tl.Dense(1)) 
- task = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
- eval_task = training.EvalTask(
- _very_simple_data(), # deliberately re-using training data
- [tl.L2Loss()],
- metric_names=['SGD.L2Loss'])
- tmp_dir = self.create_tempdir().full_path
- training_session = training.Loop(model, [task], eval_tasks=[eval_task],
- eval_at=lambda step_n: step_n % 2 == 0,
- output_dir=tmp_dir)
- expected_train_metric_dir = os.path.join(tmp_dir, 'train')
- expected_eval_metric_dir = os.path.join(tmp_dir, 'eval')
- for directory in [expected_train_metric_dir, expected_eval_metric_dir]:
- self.assertFalse(
- os.path.isdir(directory), 'Failed for directory %s.' % directory)
- training_session.run(n_steps=15)
- time.sleep(1) # wait for the files to be closed
- for directory in [expected_train_metric_dir, expected_eval_metric_dir]:
- self.assertTrue(
- os.path.isdir(directory), 'Failed for directory %s.' % directory)
- self.assertEqual(
- 1, _count_files(directory), 'Failed for directory %s.' % directory)
- training_session.run(n_steps=5)
- time.sleep(1) # wait for the files to be closed
- for directory in [expected_train_metric_dir, expected_eval_metric_dir]:
- self.assertEqual(
- 2, _count_files(directory), 'Failed for directory %s.' % directory)
-
- def test_restores_step(self):
- """Training restores step from directory where it saved it."""
- model = tl.Serial(tl.Dense(1))
- task = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
- tmp_dir = self.create_tempdir().full_path
- loop = training.Loop(model, [task],
- checkpoint_at=lambda step_n: step_n % 2 == 0,
- output_dir=tmp_dir)
- loop.run(4)
- loop2 = training.Loop(model, [task], output_dir=tmp_dir)
- self.assertEqual(4, loop2.step)
-
- def test_restores_memory_efficient_from_standard(self):
- """Memory-efficient training restores from a standard checkpoint."""
- self.skipTest('Broken by https://github.com/google/jax/pull/11234')
- model = tl.Serial(tl.Dense(4), tl.Dense(1))
- task_std = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))
- tmp_dir = self.create_tempdir().full_path
- loop = training.Loop(model, [task_std],
- checkpoint_at=lambda step_n: step_n % 2 == 0,
- output_dir=tmp_dir)
- loop.run(4)
- task_memeff = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.Adam)
- loop2 = training.Loop(model, [task_memeff], output_dir=tmp_dir,
- use_memory_efficient_trainer=True)
- loop2.run(2)
- self.assertEqual(6, loop2.step)
-
- def test_restores_from_smaller_model(self):
- """Training restores from a checkpoint created with smaller model."""
- self.skipTest('Broken by https://github.com/google/jax/pull/11234')
- model1 = tl.Serial(tl.Dense(1))
- task = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
- tmp_dir = self.create_tempdir().full_path
- loop = training.Loop(model1, [task],
- checkpoint_at=lambda step_n: step_n % 2 == 0,
- output_dir=tmp_dir)
- loop.run(2)
- model2 = tl.Serial(tl.Dense(1), tl.Dense(1))
- loop2 = training.Loop(model2, [task], output_dir=tmp_dir)
- self.assertEqual(2, loop2.step)
-
- def test_restore_fails_different_model(self):
- """Restoring a checkpoint into an incompatible model raises an error."""
- model1 = tl.Serial(tl.Dense(1))
- task = training.TrainTask(
- _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
- tmp_dir = self.create_tempdir().full_path
- loop = training.Loop(model1, [task],
- checkpoint_at=lambda step_n: step_n % 2 == 0,
- output_dir=tmp_dir)
- loop.run(2)
- model2 = 
tl.Serial(tl.Dense(2)) - with self.assertRaises(IndexError): - training.Loop(model2, [task], output_dir=tmp_dir) - - def test_restores_step_bfloat16(self): - """Training restores step from directory where it saved it, w/ bfloat16.""" - self.skipTest('Broken by https://github.com/google/jax/pull/11234') - model = tl.Serial(tl.Dense(1, use_bfloat16=True)) - # We'll also use Adafactor with bfloat16 to check restoring bfloat slots. - opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True) - task = training.TrainTask(_very_simple_data(), tl.L2Loss(), opt) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(4) - loop2 = training.Loop(model, [task], output_dir=tmp_dir) - self.assertEqual(4, loop2.step) - loop2.run(2) # check that continued training works - self.assertEqual(6, loop2.step) - - def test_restores_step_sharded(self): - """Training restores step from directory where it saved it, sharded.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir, use_memory_efficient_trainer=True) - loop.run(4) - loop2 = training.Loop(model, [task], - output_dir=tmp_dir, use_memory_efficient_trainer=True) - self.assertEqual(4, loop2.step) - - def test_restores_step_sharded_bfloat16(self): - """Training restores step from where it saved it, sharded and bfloat16.""" - model = tl.Serial(tl.Dense(1, use_bfloat16=True)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop(model, [task], - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir, use_memory_efficient_trainer=True) - loop.run(4) - loop2 = training.Loop(model, [task], - output_dir=tmp_dir, use_memory_efficient_trainer=True) - self.assertEqual(4, loop2.step) - loop2.run(2) # check that continued training works - self.assertEqual(6, loop2.step) - - def test_restores_history(self): - """Training restores history from directory where it saved it.""" - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask(_very_simple_data(), tl.L2Loss(), - optimizers.SGD(.01)) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()]) - tmp_dir = self.create_tempdir().full_path - loop = training.Loop( - model, [task], - eval_tasks=[eval_task], - eval_at=lambda step_n: step_n % 2 == 0, - checkpoint_at=lambda step_n: step_n % 2 == 0, - output_dir=tmp_dir) - loop.run(4) - loop2 = training.Loop(model, [task], output_dir=tmp_dir) - self.assertLen(loop2.history.modes, 2) - self.assertLen(loop2.history.metrics_for_mode('train'), 6) - self.assertLen(loop2.history.metrics_for_mode('eval'), 1) - for mode, metric in [ - ('train', 'metrics/L2Loss'), - ('train', 'training/learning_rate'), - ('train', 'training/steps per second'), - ('train', 'training/gradients_l2'), - ('train', 'training/loss'), - ('train', 'training/weights_l2'), - ('eval', 'metrics/L2Loss'), - ]: - self.assertLen(loop2.history.get(mode, metric), 1) - self.assertEqual(2, loop2.history.get(mode, metric)[0][0]) - - def test_trains_on_two_tasks(self): - """Trains a very simple network on two very simple tasks.""" - model = tl.Serial(tl.Dense(3), tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), - 
tl.L2Loss(), - optimizers.SGD(.01) - ) - eval_task = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - ) - training_session = training.Loop( - model, - tasks=(task, task), - eval_tasks=(eval_task, eval_task), - which_task=lambda step_n: step_n % 2, - ) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=15) - self.assertEqual(15, training_session.step) - training_session.run(n_steps=5) - self.assertEqual(20, training_session.step) - - def test_train_one_task_eval_two_tasks(self): - """Trains a very simple network on one task and evaluates on two tasks.""" - model = tl.Serial(tl.Dense(3), tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), - tl.L2Loss(), - optimizers.SGD(.01) - ) - export_prefix_1 = 'eval_1' - eval_task_1 = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - export_prefix=export_prefix_1, - ) - export_prefix_2 = 'eval_2' - eval_task_2 = training.EvalTask( - _very_simple_data(), # deliberately re-using training data - [tl.L2Loss()], - export_prefix=export_prefix_2, - ) - training_session = training.Loop( - model, - tasks=(task,), - eval_tasks=(eval_task_1, eval_task_2), - ) - self.assertEqual(0, training_session.step) - training_session.run(n_steps=5) - self.assertEqual(5, training_session.step) - export_prefixes = [task.export_prefix - for task in training_session.eval_tasks] - self.assertCountEqual([export_prefix_1, export_prefix_2], - export_prefixes) - - def test_can_predict_with_trained_model(self): - model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2))) - train_tasks, eval_tasks = [], [] - for output_dim in [1, 2]: - # The head we select from the model: 0 for output_dim 1 and 1 for 2. - head_index = output_dim - 1 - train_tasks.append(training.TrainTask( - _very_simple_data(output_dim), - tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss()), - optimizers.SGD(.01) - )) - eval_tasks.append(training.EvalTask( - _very_simple_data(output_dim), # deliberately re-use training data - [tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())] - )) - tmp_dir = self.create_tempdir().full_path - training_session = training.Loop( - model, - tasks=train_tasks, - eval_tasks=eval_tasks, - checkpoint_at=lambda step_n: step_n == 1, - output_dir=tmp_dir, - which_task=lambda step_n: step_n % 2, - ) - training_session.run(n_steps=2) - - trained_model = training_session.eval_model - inp = next(_very_simple_data())[0] - out = trained_model(inp) - self.assertEqual( - shapes.signature(out), - (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))), - ) - - def test_train_memory_efficient(self): - """Trains a large network in a memory-efficient way.""" - # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU - # and CPU when you run it locally, but it's too big for unit-testing. - ram_limited = True # Set to False to run this test locally. - if fastmath.global_device_count() == 1 and ram_limited: - return - - # Create the model. - n_layers = 16 # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram - model = tl.Serial( - tl.Embedding(9, 16*1024), - tl.Dup(), - [[tl.ReversibleHalfResidual(tl.Dense(16*1024)), tl.ReversibleSwap()] - for _ in range(n_layers)], - tl.Concatenate(), - tl.Dense(9), - ) - - # Create inputs. 
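- # Token ids 0..7 in a (2, 4) batch; targets equal inputs (in effect a
- # copy task) and the mask is all ones.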
- inputs_batch = np.arange(8).reshape((2, 4)) - targets_batch = inputs_batch - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - def _data_gen(): - while True: - yield labeled_batch - - # Run training. - loss_layer = tl.WeightedCategoryCrossEntropy() - task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor) - eval_task = training.EvalTask(_data_gen(), - [tl.WeightedCategoryCrossEntropy()]) - loop = training.Loop(model, [task], eval_tasks=[eval_task], - eval_at=lambda step_n: step_n == 2, - use_memory_efficient_trainer=True) - self.assertEqual(0, loop.step) - loop.run(n_steps=2) - self.assertEqual(2, loop.step) - - def test_initializes_step_callbacks_with_loop_instance(self): - """Runs a training loop, asserting that callbacks are initialized.""" - - class ActualLoop: - # Wrapper object to make the Loop reference mutable. - loop = None - - class TestCallback(callbacks.TrainingStepCallback): - - def __init__(self, loop): - super().__init__(loop) - ActualLoop.loop = loop - - def call_at(self, step): - return False - - def on_step_begin(self, step): - del step - - def on_step_end(self, step): - del step - - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01) - ) - expected_loop = training.Loop( - model, [task], callbacks=[TestCallback] - ) - self.assertIs(ActualLoop.loop, expected_loop) - - def test_calls_step_callbacks(self): - """Runs a training loop, asserting that callbacks are called.""" - call_at_steps = [1, 3, 4] - begin_steps = [] - end_steps = [] - test_case = self - - class TestCallback(callbacks.TrainingStepCallback): - - def call_at(self, step): - return step in call_at_steps - - def on_step_begin(self, step): - begin_steps.append(step) - - def on_step_end(self, step): - # Assert that on_step_begin() was called before. - test_case.assertIn(step, begin_steps) - end_steps.append(step) - - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01) - ) - loop = training.Loop(model, [task], callbacks=[TestCallback]) - loop.run(n_steps=5) - - # Assert that the callback has been called at the appropriate steps. 
- self.assertEqual(begin_steps, call_at_steps)
- self.assertEqual(end_steps, call_at_steps)
-
-
-def _very_simple_data(output_dim=1, input_dim=1):
- """Returns stream of labeled data that maps small integers to constant pi."""
- inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch
- inputs_batch = np.concatenate([inputs_batch] * input_dim, axis=1)
- targets_batch = np.pi * np.ones((8, output_dim))
- labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
- while True:
- yield labeled_batch
-
-
-def _very_simple_transformer_data():
- """Returns a stream of labeled data for a tiny transformer LM."""
- inputs_batch = np.ones((2, 2)).astype(np.int32)
- targets_batch = np.ones((2, 2, 8)).astype(np.int32)
- labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
- while True:
- yield labeled_batch
-
-
-def _count_files(path):
- """Returns number of files in a given directory."""
- return len([filename for filename in os.listdir(path)
- if os.path.isfile(os.path.join(path, filename))])
-
-
-if __name__ == '__main__':
- config.config_with_absl()
- absltest.main()
diff --git a/trax/tf_numpy/__init__.py b/trax/tf/__init__.py
similarity index 99%
rename from trax/tf_numpy/__init__.py
rename to trax/tf/__init__.py
index a4ee92161..7b248b4dc 100644
--- a/trax/tf_numpy/__init__.py
+++ b/trax/tf/__init__.py
@@ -12,4 +12,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/trax/tf_numpy/extensions/__init__.py b/trax/tf/extensions/__init__.py
similarity index 93%
rename from trax/tf_numpy/extensions/__init__.py
rename to trax/tf/extensions/__init__.py
index 76d81ada7..9affff137 100644
--- a/trax/tf_numpy/extensions/__init__.py
+++ b/trax/tf/extensions/__init__.py
@@ -16,5 +16,6 @@
 """JAX-like function transformations and extensions for TF-numpy."""
 # pylint: disable=wildcard-import
-from trax.tf_numpy.extensions.extensions import *
+from trax.tf.extensions.extensions import *
+
 # pylint: enable=wildcard-import
diff --git a/trax/tf/extensions/extensions.py b/trax/tf/extensions/extensions.py
new file mode 100644
index 000000000..f971974d5
--- /dev/null
+++ b/trax/tf/extensions/extensions.py
@@ -0,0 +1,2155 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Extensions such as `jit`, `grad`, `logsumexp`, etc."""
+import bisect
+import contextlib
+import copy
+import functools
+import string
+import sys
+import threading
+
+import numpy as np
+import six
+import tensorflow as tf
+
+import trax.tf.numpy as tf_np
+
+_int_dtype_lower_bounds = [
+ -(2**63),
+ -(2**31),
+ -(2**15),
+ -(2**7),
+ 0,
+ 2**7,
+ 2**15,
+ 2**31,
+ 2**64,
+]
+_int_dtypes = [
+ tf.int64,
+ tf.int32,
+ tf.int16,
+ tf.int8,
+ tf.uint8,
+ tf.uint16,
+ tf.uint32,
+ tf.uint64,
+]
+_tf_nn_APIs = {
+ 1: [tf.nn.conv1d, tf.nn.conv1d_transpose],
+ 2: [tf.nn.conv2d, tf.nn.conv2d_transpose],
+ 3: [tf.nn.conv3d, tf.nn.conv3d_transpose],
+}
+
+
+remat = tf.recompute_grad
+
+
+def most_precise_int_dtype(x):
+ if not isinstance(x, six.integer_types) or isinstance(x, bool):
+ return None
+ i = bisect.bisect_right(_int_dtype_lower_bounds, x)
+ if i in (0, len(_int_dtype_lower_bounds)):
+ raise ValueError(f"Integer {x} is out of bounds")
+ assert len(_int_dtype_lower_bounds) == len(_int_dtypes) + 1
+ return _int_dtypes[i - 1]
+
+
+def _canonicalize_jit_arg(x):
+ if isinstance(x, tf_np.ndarray):
+ return x
+ try:
+ # We need to convert `int` to the most precise dtype, otherwise the dtype
+ # of the result may be different from numpy's. For example, when a binary
+ # op takes in a Python integer 5 and an array of uint32, numpy will pick
+ # uint32 as 5's dtype, while tf.convert_to_tensor will choose int32 which
+ # will cause the two arguments to be promoted to int64. We pick uint8
+ # here, which will be promoted to uint32 by the binary op.
+ # Note that we prefer unsigned int to signed int when both are equally
+ # precise. For example, for 5, we pick uint8 instead of int8. There is no
+ # reason to prefer one to the other, because for each there is a case
+ # where the behavior diverges from numpy. If we prefer signed int,
+ # consider the case where the first operand is 5 and the second is
+ # 2**64-1. Numpy picks uint64 as the result dtype, but because we choose a
+ # signed type for 5 such as int8, the result type will be float64. On the
+ # other hand, if we prefer unsigned int, consider the case where the first
+ # operand is 2**31-1 and the second is -1. Numpy will pick int32, but
+ # because we choose uint32 for 2**31-1, the result will be int64. The root
+ # of the problem is that `jit` converts `int` to tensors (hence committing
+ # to a dtype) too early, when we don't have enough information about the
+ # jitted function (e.g. which subset of the arguments should be promoted
+ # together using np.result_type). tf.function doesn't have this problem
+ # because it doesn't convert `int` to tensors. jax.jit doesn't have this
+ # problem because it converts `int` to "int tracer" which doesn't commit
+ # to a dtype.
+ # TODO(wangpeng): Revisit this design and see whether we can improve `jit`
+ # and tf.function.
+ dtype = most_precise_int_dtype(x)
+ if dtype is None and isinstance(x, float):
+ dtype = tf_np.float32
+ return tf.convert_to_tensor(value=x, dtype=dtype)
+ except (TypeError, ValueError):
+ return x
+
+
+def _canonicalize_jit_arguments(inp):
+ """Canonicalize arguments to be used for jit.
+
+ Args:
+ inp: a nested structure of arguments to be canonicalized (i.e. to be
+ converted to Tensors). Only tf_np.ndarray and things accepted by
+ `tf.convert_to_tensor` will be converted.
+
+ Returns:
+ The canonicalized version.
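+
+ For example (a sketch of the intended behavior): a Python int ``5`` is
+ converted to a ``tf.uint8`` tensor and ``2**31 - 1`` to a ``tf.uint32``
+ one, per `most_precise_int_dtype` above; values that
+ `tf.convert_to_tensor` cannot handle are passed through unchanged.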
+ """ + return tf.nest.map_structure(_canonicalize_jit_arg, inp) + + +def _tf_to_np(inp): + def f(x): + if type(x).__name__ == "ndarray": + data = x._data + + if isinstance(data, tf.IndexedSlices): + data = tf.convert_to_tensor(data) + return tf_np.asarray(data) + + if isinstance(x, tf.IndexedSlices): + return tf_np.asarray(x) + else: + return x + + return tf.nest.map_structure(f, inp) + + +def stop_gradient(x): + def static_stop_gradient(x): + # `tf.stop_gradient` is a no-op for non-Tensor. Returning the original type + # allows it to be used in the conditional without Autograph, if static. For + # example: + # `if fastmath.stop_gradient(5) > 4:` + return tf.stop_gradient(x) if tf.is_tensor(x) else x + + return _tf_to_np(tf.nest.map_structure(static_stop_gradient, x)) + + +def custom_grad(f_vjp, f_original=None): + """Decorator to define a function with a custom gradient. + + This function is very similar to `tf.custom_gradient`. See the documentation + of `tf.custom_gradient` for detailed usage. + + The differences with `tf.custom_gradient` are: + + - All arguments and results are tf_np.ndarrays instead of tensors. + + - The `grad_fn` returned by `f_vjp` accepts and returns nested structures, + unlike that in `tf.custom_gradient` which only accepts and returns lists. + + Args: + f_vjp: the same as the `f` argument of `tf.custom_gradient`. Note that all + inputs and outputs of `f_vjp` and of the `grad_fn` function it returns can + be nested structures. + f_original: (optional) not used. + + Returns: + The same as `tf.custom_gradient`. + """ + del f_original + + @tf.custom_gradient + def tf_f(*tf_args, **tf_kwargs): + np_args = _tf_to_np(tf_args) + np_kwargs = _tf_to_np(tf_kwargs) + np_y, np_vjp = f_vjp(*np_args, **np_kwargs) + tf_y = np_y + + def tf_vjp(*flat_tf_dy): + tf_dy = tf.nest.pack_sequence_as(tf_y, flat_tf_dy) + np_dy = _tf_to_np(tf_dy) + np_dx = np_vjp(np_dy) + return tf.nest.flatten(np_dx) + + return tf_y, tf_vjp + + def np_f(*args, **kwargs): + return _tf_to_np(tf_f(*args), **kwargs) + + return np_f + + +def vjp(f, *primals, has_aux=False): + """Returns the result and the VJP function of `f`. + + This function returns the result and the vector-Jacobian-product (VJP) + function of `f`. + + Args: + f: a function from (nested structures of) tf_np.ndarrays to a (nested + structure of) tf_np.ndarray. If `has_aux` is True, it should return an + extra output. + *primals: the inputs to be fed to `f`. + has_aux: if True, the second output of `f` will be regarded as an auxiliary, + non-differentiable output that will be ignored by the VJP function. + + Returns: + A pair `(y, vjpfun)` if `has_aux` is False; a tuple `(y, vjpfun, aux)` + otherwise. `y` and `aux` are the outputs of `f`, i.e. `y, aux = + f(*primals)`. `vjpfun` is a function `dx = vjpfun(dy)`, where `dy` is the + cotengents of `y`, having the same structures, shapes and dtypes as + `y`. `dx` is the cotengents of `x`, having the same structures, shapes and + dtypes as `x`. + """ + with tf.GradientTape(persistent=True) as tape: + tape.watch(tf.nest.flatten(primals)) + outputs = f(*primals) + if has_aux: + np_out, aux = outputs + else: + np_out = outputs + + def _vjp(dy): + tf_dx = tape.gradient(np_out, primals, output_gradients=dy) + return _tf_to_np(tf_dx) + + if has_aux: + ret = (np_out, _vjp, aux) + else: + ret = (np_out, _vjp) + return ret + + +# TODO(wangpeng): match JAX's handling of kwargs and non-ndarray args +def grad(f, has_aux=False): + """Returns a function that computes gradient of f. 
+
+ Gradients can only be computed through NumPy and TensorFlow operations, not
+ through Python float operations and values.
+
+ Args:
+ f: a function of type (params, *args) -> scalar. 'params' can be a nested
+ structure (made of lists and tuples) of ndarrays and the gradient is
+ evaluated against it. `scalar` is a scalar ndarray.
+ has_aux: bool, indicates whether `f` returns a pair where the first element
+ is considered the output of the mathematical function to be differentiated
+ and the second element is auxiliary data.
+
+ Returns:
+ A gradient function of type (params, *args) -> gradients, where the result
+ 'gradients' has the same structure and shapes as 'params'.
+ """
+
+ def check_loss_shape(np_loss):
+ if not isinstance(np_loss, tf_np.ndarray):
+ raise ValueError(
+ "The result of the function to take gradient must be an ndarray."
+ )
+ # `np_loss.shape` is a plain tuple here, with no `is_compatible_with`
+ # method as on TF1-style `TensorShape`s, so we compare the shape with ()
+ # directly.
+ if not np_loss.shape == ():
+ raise ValueError(
+ "The result of the function to take gradient must be a scalar."
+ )
+
+ def _f(params, *args):
+ """The gradient function to be returned."""
+ with tf.GradientTape() as g:
+ g.watch(tf.nest.flatten(params))
+ outputs = f(params, *args)
+
+ if has_aux:
+ np_loss, aux = outputs
+ else:
+ np_loss = outputs
+
+ check_loss_shape(np_loss)
+
+ tf_grads = g.gradient(np_loss, params)
+ tf_grads = _tf_to_np(tf_grads)
+
+ if has_aux:
+ res = (tf_grads, aux)
+ else:
+ res = tf_grads
+ return _tf_to_np(res)
+
+ return _f
+
+
+def _record_result_type(recorder, f):
+ """A decorator that records some information about the function.
+
+ Args:
+ recorder: a function of signature `(args, kwargs, res) -> res`.
+ f: the original function.
+
+ Returns:
+ A transformed function that calls the original function and then the
+ recorder afterwards.
+ """
+
+ def wrapper(*args, **kwargs):
+ res = f(*args, **kwargs)
+ res = recorder(args, kwargs, res)
+ return res
+
+ return wrapper
+
+
+def jit(
+ f,
+ static_argnums=(),
+ xla_forced_compile=False,
+ input_signature=None,
+ autograph=False,
+ experimental_compile=False,
+):
+ """Returns a function that runs a trace-compiled version of `f`.
+
+ A trace-compiled version of a function `f` has the same behavior as `f` (when
+ called with the same "static arguments", see below), but runs faster because
+ the whole computation is compiled into a computation graph once which is
+ reused for subsequent executions.
+
+ The trace compilation happens lazily, when the returned function is called for
+ the first time. The compiled function may not be cached implicitly and
+ multiple calls to `jit` may not share the compiled function (see below for
+ "static" vs "dynamic" arguments).
+
+ Args:
+ f: a function that takes any positional arguments `args` and any keyword
+ arguments `kwargs`. `ndarray`s and things accepted by
+ `tf.convert_to_tensor` in `args` and `kwargs` will be treated as 'dynamic
+ arguments' in the sense that calling the function with different values
+ for these arguments will not cause retracing. In contrast, arguments of
+ other types in `args` and `kwargs` are treated as 'static arguments' and
+ calling the function with different values of them will cause
+ re-compiling. Positional arguments whose positions are in `static_argnums`
+ are always treated as static arguments.
+ static_argnums: a tuple of positions of arguments that will be treated as
+ static arguments.
Note that, as mentioned above, any arguments that were not
+ convertible to tensor will also be static.
+ xla_forced_compile: if True, it will use XLA to force-compile the graph.
+ This requires that the function only contain ops that are XLA
+ compatible. It will compile the entire function into a single XLA op.
+ input_signature: a list of `tf.TensorSpec`, as the input signature to
+ control tracing behavior. See the
+ [doc](https://www.tensorflow.org/api_docs/python/tf/function) of
+ `tf.function` for details.
+ autograph: whether to use autograph to convert Python constructs such as
+ `if` and `while` to their TensorFlow counterparts. See the
+ [doc](https://www.tensorflow.org/api_docs/python/tf/function) of
+ `tf.function` for details.
+ experimental_compile: the `experimental_compile` flag for `tf.function`. See
+ the [doc](https://www.tensorflow.org/api_docs/python/tf/function) of
+ `tf.function` for details. This is the recommended way to turn on XLA for
+ tf.function, but unlike xla_forced_compile, it doesn't force-compile the
+ entire function into a single XLA op.
+
+ Returns:
+ A trace-compiled version of f.
+ """
+
+ @tf.function(
+ input_signature=input_signature,
+ autograph=autograph,
+ experimental_compile=experimental_compile,
+ )
+ def _tf_f(*args, **kwargs):
+ """Accelerated function with tensor inputs/outputs."""
+ np_args = _tf_to_np(args)
+ kwargs = {k: _tf_to_np(v) for k, v in kwargs.items()}
+ if xla_forced_compile:
+ # Use list for mutability
+ output_is_list = [False]
+ output_is_empty = [False]
+ output_structure = [None]
+
+ def recorder(args, kwargs, res):
+ del args, kwargs
+ # Workaround b/121383831
+ output_is_list[0] = isinstance(res, list)
+ # If outputs are empty, xla.compile returns an `Operation`, which we
+ # don't want.
+ if tf.nest.flatten(res):
+ output_is_empty[0] = False
+ output_structure[0] = None
+ else:
+ output_is_empty[0] = True
+ # Without deepcopy, xla.compile will change output_structure[0] to a
+ # list of `Operation`.
+ output_structure[0] = copy.deepcopy(res)
+ return res
+
+ f_ = _record_result_type(recorder, f)
+ np_out = tf.xla.experimental.compile(lambda: f_(*np_args, **kwargs))
+ # Workaround b/121383831
+ if output_is_empty[0]:
+ np_out = output_structure[0]
+ elif (
+ isinstance(np_out, list) and len(np_out) == 1 and not output_is_list[0]
+ ):
+ np_out = np_out[0]
+ else:
+ np_out = f(*np_args, **kwargs)
+ return np_out
+
+ def _f(*args, **kwargs):
+ args = [
+ _canonicalize_jit_arguments(arg) if i not in static_argnums else arg
+ for i, arg in enumerate(args)
+ ]
+ kwargs = {k: _canonicalize_jit_arguments(v) for k, v in kwargs.items()}
+ tf_out = _tf_f(*args, **kwargs)
+ return _tf_to_np(tf_out)
+
+ _f.tf_function = _tf_f
+
+ return _f
+
+
+def eval_on_shapes(f, static_argnums=(), allow_static_outputs=False):
+ """Returns a function that evaluates `f` given input shapes and dtypes.
+
+ It transforms function `f` to a function that performs the same computation as
+ `f` but only on shapes and dtypes (a.k.a. shape inference).
+
+ Args:
+ f: the function to be transformed.
+ static_argnums: see documentation of `jit`.
+ allow_static_outputs: whether to allow non-array outputs. If True, non-array
+ outputs (e.g. Python integers) will be returned as-is; otherwise, they
+ will be converted to ndarrays, and then specs of those ndarrays will be
+ returned.
+
+    Returns:
+      A function whose input arguments can be either the same as `f`'s or
+      only their shapes/dtypes represented by `tf.TensorSpec`, and whose
+      return values are `tf.TensorSpec`s with the same nested structure as
+      `f`'s return values. If `allow_static_outputs` is True, when `f`
+      returns some non-array outputs (e.g. Python integers), the converted
+      function will return them as-is instead of returning `tf.TensorSpec`s
+      for them.
+    """
+
+    def abstractify(args):
+        def _abstractify(x):
+            x = _canonicalize_jit_arg(x)
+            if isinstance(x, (tf.Tensor, tf_np.ndarray)):
+                return tf.TensorSpec(x.shape, x.dtype)
+            else:
+                return x
+
+        new_args = []
+        for i, arg in enumerate(args):
+            if i in static_argnums:
+                new_args.append(arg)
+            else:
+                new_args.append(tf.nest.map_structure(_abstractify, arg))
+        return new_args
+
+    if allow_static_outputs:
+        # When `tf_f` below is called (via get_concrete_function) with the same
+        # arguments (after abstraction), the Python function `f` won't be run,
+        # so we need this python_outputs_map to retrieve the Python outputs
+        # we've seen before that correspond to the arguments.
+        python_outputs_map = {}
+
+        def recorder(args, kwargs, res):
+            # Since the get_concrete_function below only uses positional args,
+            # we also use only positional args here.
+            del args, kwargs
+
+            def is_tensor_like(x):
+                if hasattr(x, "_type_spec"):
+                    return True  # x is a CompositeTensor
+                return isinstance(x, (tf_np.ndarray, tf.Tensor))
+
+            py_values = tf.nest.map_structure(
+                lambda x: None if is_tensor_like(x) else x, res
+            )
+            key = id(tf.compat.v1.get_default_graph())
+            python_outputs_map[key] = py_values
+            # Set non-tensor outputs to None to avoid tf.function calling
+            # tf.convert_to_tensor on them.
+            res = tf.nest.map_structure(
+                lambda x: None if not is_tensor_like(x) else x, res
+            )
+            return res
+
+        f = _record_result_type(recorder, f)
+
+    # TODO(wangpeng): tf.function could add a knob to turn off materializing
+    # the graph, so that we don't waste computation and memory when we just
+    # want shape inference.
+    tf_f = jit(f, static_argnums=static_argnums).tf_function
+
+    # pylint: disable=missing-docstring
+    def f_return(*args):
+        def to_tensor_spec(x):
+            if isinstance(x, tf.Tensor):
+                return tf.TensorSpec(x.shape, x.dtype)
+            else:
+                return x
+
+        new_args = abstractify(args)
+        cfun = tf_f.get_concrete_function(*new_args)
+        res = cfun.structured_outputs
+        res = tf.nest.map_structure(to_tensor_spec, res)
+
+        if allow_static_outputs:
+            key = id(cfun.graph)
+            py_values = python_outputs_map[key]
+            # We can also call tf.get_static_value on structured_outputs to
+            # retrieve the Python values, but since we'll need to use
+            # python_outputs_map to record "which outputs are static?" anyway,
+            # we choose to directly store the Python values in
+            # python_outputs_map.
+            res = tf.nest.map_structure(
+                lambda x, python_value: x if python_value is None else python_value,
+                res,
+                py_values,
+            )
+
+        return res
+
+    # Provides access to `tf_f` for testing purposes.
+    f_return._tf_function = tf_f  # pylint: disable=protected-access
+    return f_return
+
+
+def _index_update_helper(updater, x, idx, y):
+    x = tf_np.asarray(x)
+    y = tf_np.asarray(y)
+    # TODO(b/164251540): Remove this expensive manual broadcasting once
+    # tf.raw_ops.tensor_strided_slice_update and tf.tensor_scatter_nd_update
+    # support broadcasting.
+    y = tf.broadcast_to(y, tf.shape(x[idx]))
+    return updater(x, idx, y)
+
+
+# pylint: disable=protected-access
+def index_update(x, idx, y):
+    """Pure equivalent of `x[idx] = y`.
+ + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] = y`. Because it's a pure function, `x` itself won't be + changed. + + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_update, x, idx, y) + + +def index_add(x, idx, y): + """Pure equivalent of `x[idx] += y`. + + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] += y`. Because it's a pure function, `x` itself won't be + changed. + + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_add, x, idx, y) + + +def index_min(x, idx, y): + """Pure equivalent of `x[idx] = minimum(x[idx], y)`. + + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] = minimum(x[idx], y)`. Because it's a pure function, `x` + itself won't be changed. + + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_min, x, idx, y) + + +def index_max(x, idx, y): + """Pure equivalent of `x[idx] = maximum(x[idx], y)`. + + Returns the value of x that would result from the NumPy-style indexed + assignment `x[idx] = maximum(x[idx], y)`. Because it's a pure function, `x` + itself won't be changed. + + Args: + x: an array with the values to be updated. + idx: a Numpy-style index, consisting of `None`, integers, slice objects, + ellipses, ndarrays with integer dtypes, or a tuple of the above. + y: the array of updates. `y` must be broadcastable to the shape of the array + that would be returned by `x[idx]`. + + Returns: + The updated version of `x`. + """ + return _index_update_helper(tf_np.ndarray._with_index_max, x, idx, y) + + +# pylint: enable=protected-access + + +def logsumexp(x, axis=None, keepdims=None): + """Computes log(sum(exp(elements across dimensions of a tensor))). + + Reduces `x` along the dimensions given in `axis`. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + entry in `axis`. If `keepdims` is true, the reduced dimensions + are retained with length 1. + If `axis` has no entries, all dimensions are reduced, and a + tensor with a single element is returned. + This function is more numerically stable than log(sum(exp(input))). It avoids + overflows caused by taking the exp of large inputs and underflows caused by + taking the log of small inputs. + + Args: + x: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(x), rank(x))`. 
+      keepdims: If true, retains reduced dimensions with length 1.
+
+    Returns:
+      The reduced tensor.
+    """
+    return tf_np.asarray(
+        tf.math.reduce_logsumexp(input_tensor=x, axis=axis, keepdims=keepdims)
+    )
+
+
+def expit(x):
+    """Computes 1 / (1 + exp(-x))."""
+    return tf_np.asarray(tf.math.sigmoid(x))
+
+
+def erf(x):
+    """Computes the Gauss error function of x element-wise."""
+    return tf_np.asarray(tf.math.erf(x))
+
+
+def _minus(a, b):
+    return [x for x in a if x not in b]
+
+
+def _compose_output_rep(
+    lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, lhs_batch, rhs_batch
+):
+    """Composes the output string representation.
+
+    e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik
+          aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik
+
+    Args:
+      lhs_rep: A string representation for the left-hand side input array.
+      rhs_rep: A string representation for the right-hand side input array.
+      lhs_contraction: Sequence[int] (the contraction dimensions of lhs)
+      rhs_contraction: Sequence[int] (the contraction dimensions of rhs)
+      lhs_batch: Sequence[int] (the batch dimensions of lhs)
+      rhs_batch: Sequence[int] (the batch dimensions of rhs)
+
+    Returns:
+      A string representation of the result array.
+    """
+    output_rep = []
+    for dim in lhs_batch:
+        output_rep.append(lhs_rep[dim])
+
+    for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction):
+        output_rep.append(lhs_rep[i])
+    for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction):
+        output_rep.append(rhs_rep[i])
+    return "".join(output_rep)
+
+
+def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction):
+    """Computes the non-batched matrix multiplication.
+
+    If it is the general non-batched/single-batched matrix multiplication,
+    use the highly optimized kernel `tf.tensordot` to handle it.
+
+    Args:
+      lhs: an array (the left-hand side matrix/vector to be multiplied)
+      rhs: an array (the right-hand side matrix/vector to be multiplied)
+      lhs_contraction: Sequence[int] (the contraction dimensions of lhs)
+      rhs_contraction: Sequence[int] (the contraction dimensions of rhs)
+
+    Returns:
+      An array that contains the result.
+    """
+    return tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction)))
+
+
+def tf_dot_general(lhs, rhs, dimension_numbers):
+    """The general dot operation for TensorFlow.
+
+    An equivalent of the general dot operation in JAX
+    (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html).
+
+    Although there is an implementation in TF XLA, avoid directly using XLA
+    when possible.
+
+    e.g., non-batched: ij,jk->ik
+          batched: ijk,ikl->ijl
+
+    Args:
+      lhs: an array (the left-hand side matrix/vector to be multiplied)
+      rhs: an array (the right-hand side matrix/vector to be multiplied)
+      dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]],
+        Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form
+        ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
+        rhs_batch_dims))
+
+    Returns:
+      An array that contains the result.
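+
+    Example (an illustrative sketch of a batched matmul, not from the
+    original docstring): with `lhs` of shape (2, 3, 4), `rhs` of shape
+    (2, 4, 5), batch dimension 0 on both sides and contraction of lhs
+    dimension 2 with rhs dimension 1,
+    `tf_dot_general(lhs, rhs, (((2,), (1,)), ((0,), (0,))))` returns an
+    array of shape (2, 3, 5).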
+ """ + char_list = list(string.ascii_lowercase) + char_list = char_list[8:] + char_list[:8] + lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) + lhs_rep = char_list[:lhs_rank] + rhs_rep = char_list[lhs_rank : lhs_rank + rhs_rank] + contraction, batch = dimension_numbers + lhs_contraction, rhs_contraction = contraction + if len(lhs_contraction) != len(rhs_contraction): + raise ValueError( + "The input matrices are required to have the same number " + "of contraction dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_contraction), len(rhs_contraction) + ) + ) + lhs_batch, rhs_batch = batch + if len(lhs_batch) != len(rhs_batch): + raise ValueError( + "The input matrices are required to have the same number " + "of batch dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_batch), len(rhs_batch) + ) + ) + + if not lhs_batch and not rhs_batch: + return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) + + if ( + lhs_rank == rhs_rank == 3 + and lhs_batch == (0,) + and rhs_batch == (0,) + and lhs_contraction == (2,) + and rhs_contraction == (1,) + ): + return tf.linalg.matmul(lhs, rhs) + + for i in range(len(lhs_contraction)): + rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] + for i in range(len(lhs_batch)): + rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] + + output_rep = _compose_output_rep( + lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, lhs_batch, rhs_batch + ) + equation = "".join(lhs_rep) + "," + "".join(rhs_rep) + "->" + output_rep + return tf.einsum(equation, lhs, rhs) + + +def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation, dim): + """Convert strides, lhs_dilation, rhs_dilation to match TF convention. + + For example, + in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2] + if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2] + + Args: + window_strides: window_strides to be converted + lhs_dilation: lhs_dilation to be converted + rhs_dilation: rhs_dilation to be converted + dim: dim to be converted + + Returns: + The updated window_strides, lhs_dilation and rhs_dilation + """ + + def _as_list_of_size(item, size): + if item is None: + return None + return [item] * size if isinstance(item, int) else list(item) + + return ( + _as_list_of_size(window_strides, dim), + _as_list_of_size(lhs_dilation, dim), + _as_list_of_size(rhs_dilation, dim), + ) + + +# pylint: disable=g-bad-todo +# TODO(DarrenZhang01): Expand the test cases of general convolution and revise +# the according bugs. +# TODO(DarrenZhang01): Support feature_group_count, batch_group_count and +# precision, and allow lhs_dilation and rhs_dilation to happen at the same time. +# pylint: enable=g-bad-todo +def tf_conv_general_dilated( + lhs, + rhs, + window_strides, + padding, + output_shape, + lhs_dilation=None, + rhs_dilation=None, + dimension_numbers=None, + feature_group_count=1, + batch_group_count=1, + precision=None, +): + """A general conv API for TensorFlow. + + According JAX version: + https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html + + Args: + lhs: a rank n+2 dimensional input array. + rhs: a rank n+2 dimensional array of kernel weights. + window_strides: a sequence of n integers, representing the inter-window + strides. + padding: either the string ‘SAME’, the string ‘VALID’, or a sequence of n + (low, high) integer pairs that give the padding to apply before and + after each spatial dimension. + output_shape: the output shape of the convolution (only required for + transpose convolution). 
+      lhs_dilation: None, or a sequence of n integers, giving the dilation
+        factor to apply in each spatial dimension of lhs. LHS dilation is also
+        known as transposed convolution.
+      rhs_dilation: None, or a sequence of n integers, giving the dilation
+        factor to apply in each spatial dimension of rhs. RHS dilation is also
+        known as atrous convolution.
+      dimension_numbers: either None, a ConvDimensionNumbers object, or a
+        3-tuple (lhs_spec, rhs_spec, out_spec), where each element is a string
+        of length n+2.
+      feature_group_count: integer, default 1. Changing this is currently not
+        supported.
+      batch_group_count: integer, default 1. Changing this is currently not
+        supported.
+      precision: Optional. Either None, which means the default precision for
+        the backend, or a Precision enum value.
+
+    Returns:
+      A TF NumPy array that contains the convolution result.
+    """
+    dim = None
+    lhs_spec, rhs_spec, out_spec = dimension_numbers
+    if lhs_spec != out_spec:
+        raise ValueError(
+            "Current implementation requires the `data_format` of the "
+            "inputs and outputs to be the same."
+        )
+    if len(lhs_spec) >= 6:
+        raise ValueError(
+            "Current implementation does not support 4 or higher "
+            "dimensional convolution, but got: ",
+            len(lhs_spec) - 2,
+        )
+    dim = len(lhs_spec) - 2
+    if lhs_dilation and rhs_dilation:
+        if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
+            lhs_dilation, rhs_dilation = None, None
+        else:
+            raise ValueError(
+                "Current implementation does not support performing "
+                "deconvolution and dilation at the same "
+                "time, but got lhs_dilation: {}, rhs_dilation: {}".format(
+                    lhs_dilation, rhs_dilation
+                )
+            )
+    if padding not in ["SAME", "VALID"]:
+        raise ValueError(
+            "Current implementation requires the padding parameter "
+            "to be either 'VALID' or 'SAME', but got: ",
+            padding,
+        )
+    if batch_group_count != 1 or feature_group_count != 1:
+        raise NotImplementedError(
+            "batch_group_count and feature_group_count "
+            "other than 1 is currently not supported, but "
+            "got feature_group_count: {}, batch_group_count: {}".format(
+                feature_group_count, batch_group_count
+            )
+        )
+    if precision is not None:
+        raise NotImplementedError(
+            "precision other than `None` is currently not "
+            "supported, but got: {}".format(precision)
+        )
+    # Convert params from int/Sequence[int] to list of ints.
+    strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
+        window_strides, lhs_dilation, rhs_dilation, dim
+    )
+    # Preprocess the shapes
+    dim_maps = {}
+    if isinstance(lhs_spec, str):
+        dim_maps["I"] = list(rhs_spec).index("I")
+        dim_maps["O"] = list(rhs_spec).index("O")
+        dim_maps["N"] = list(lhs_spec).index("N")
+        dim_maps["C"] = list(lhs_spec).index("C")
+    else:
+        dim_maps["I"] = rhs_spec[1]
+        dim_maps["O"] = rhs_spec[0]
+        dim_maps["N"] = lhs_spec[0]
+        dim_maps["C"] = lhs_spec[1]
+
+    lhs = tf_np.moveaxis(lhs, (dim_maps["N"], dim_maps["C"]), (0, dim + 1))
+    # Adjust the filters, put the dimensions 'I' and 'O' last.
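+    # For example (an illustrative case): an "OIHW" kernel of shape
+    # [out_ch, in_ch, h, w] becomes [h, w, in_ch, out_ch], i.e. the
+    # "HWIO"-style layout expected by the tf.nn convolution APIs used below.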
+    rhs = tf_np.moveaxis(rhs, (dim_maps["O"], dim_maps["I"]), (dim + 1, dim))
+    spatial_dim_maps = {1: "W", 2: "HW", 3: "DHW"}
+    data_format = "N" + spatial_dim_maps[dim] + "C"
+
+    if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
+        output = _tf_nn_APIs[dim][0](
+            lhs, rhs, strides, padding, data_format, rhs_dilation
+        )
+    else:
+        output = _tf_nn_APIs[dim][1](
+            lhs,
+            rhs,
+            tf.constant(output_shape),
+            strides,
+            padding,
+            data_format,
+            lhs_dilation,
+        )
+    output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps["N"], dim_maps["C"]))
+    return output
+
+
+def conv(inp, fltr, window_strides, padding, dimension_numbers, filter_dilation=None):
+    """Convolution over an N-D array.
+
+    See https://www.tensorflow.org/api_docs/python/tf/nn/convolution and
+    https://www.tensorflow.org/xla/operation_semantics#conv_convolution for
+    reference.
+
+    Args:
+      inp: an (N+2)-D array. The input of the convolution.
+      fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution.
+      window_strides: a sequence of N ints, the strides for moving the
+        convolution window.
+      padding: a string, either "VALID" or "SAME". The padding algorithm.
+      dimension_numbers: a tuple of three strings encoding the data format of
+        input, filter and output. "I" means input; "O" means output; "C" means
+        channel; other characters such as "W", "H" and "D" mean spatial
+        dimensions.
+      filter_dilation: the dilation rates for the filter. Dilating the filter
+        means adding "holes" to the filter.
+
+    Returns:
+      An (N+2)-D array. The convolution result.
+    """
+    input_spec, filter_spec, output_spec = dimension_numbers
+    if input_spec != output_spec:
+        raise ValueError(
+            "Input and output data formats must be the same; got %s "
+            "and %s" % (input_spec, output_spec)
+        )
+    supported_filter_spec = ["WIO", "HWIO", "DHWIO"]
+    if filter_spec not in supported_filter_spec:
+        raise ValueError(
+            "The supported data formats for the filter are %s; got %s"
+            % (supported_filter_spec, filter_spec)
+        )
+    if input_spec[1:-1] != filter_spec[:-2]:
+        raise ValueError(
+            "Input data format (%s) is not compatible with filter "
+            "data format (%s)" % (input_spec, filter_spec)
+        )
+    # No type promotion in order to prevent accidentally doing more expensive
+    # computation.
+    dtype = tf_np.result_type(inp, fltr)
+    inp = tf_np.asarray(inp, dtype)
+    fltr = tf_np.asarray(fltr, dtype)
+    return tf_np.asarray(
+        tf.nn.convolution(
+            input=inp,
+            filters=fltr,
+            padding=padding,
+            strides=window_strides,
+            dilations=filter_dilation,
+            data_format=input_spec,
+        )
+    )
+
+
+def avg_pool(x, pool_size, strides, padding):
+    """Performs an N-D average pooling.
+
+    Args:
+      x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape +
+        [num_channels]`. Pooling happens over the spatial dimensions only.
+      pool_size: sequence of N ints.
+      strides: sequence of N ints.
+      padding: a string, the padding algorithm. Must be "SAME" or "VALID".
+
+    Returns:
+      An (N+2)-D array, of shape
+        [batch_size] + output_spatial_shape + [num_channels],
+      where `output_spatial_shape` depends on the value of padding:
+        If padding = "SAME":
+          output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
+        If padding = "VALID":
+          output_spatial_shape[i] =
+            ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]).
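+
+    For example (an illustrative case, not from the original docstring):
+    with `x` of shape (1, 4, 4, 1), `pool_size=(2, 2)`, `strides=(2, 2)` and
+    `padding="VALID"`, the result has shape (1, 2, 2, 1).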
+ """ + x = tf_np.asarray(x) + return tf_np.asarray( + tf.nn.pool( + input=x, + window_shape=pool_size, + pooling_type="AVG", + strides=strides, + padding=padding, + ) + ) + + +def max_pool(x, pool_size, strides, padding): + """Performs an N-D max pooling. + + Args: + x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + + [num_channels]`. Pooling happens over the spatial dimensions only. + pool_size: sequence of N ints. + strides: sequence of N ints. + padding: a string, the padding algorithm. Must be "SAME" or "VALID". + + Returns: + An (N+2)-D array, of shape + [batch_size] + output_spatial_shape + [num_channels], + where `output_spatial_shape` depends on the value of padding: + If padding = "SAME": + output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) + If padding = "VALID": + output_spatial_shape[i] = + ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). + """ + x = tf_np.asarray(x) + return tf_np.asarray( + tf.nn.pool( + input=x, + window_shape=pool_size, + pooling_type="MAX", + strides=strides, + padding=padding, + ) + ) + + +def sort_key_val(keys, values, dimension=-1): + """Sorts keys along a dimension and applies same permutation to values. + + Args: + keys: an array. The dtype must be comparable numbers (integers and reals). + values: an array, with the same shape of `keys`. + dimension: an `int`. The dimension along which to sort. + + Returns: + Permuted keys and values. + """ + keys = tf_np.asarray(keys) + values = tf_np.asarray(values) + rank = keys.shape.ndims + if rank is None: + rank = values.shape.ndims + if rank is None: + # We need to know the rank because tf.gather requires batch_dims to be `int` + raise ValueError( + "The rank of either keys or values must be known, but " + "both are unknown (i.e. their shapes are both None)." + ) + if dimension in (-1, rank - 1): + + def maybe_swapaxes(a): + return a + + else: + + def maybe_swapaxes(a): + return tf_np.swapaxes(a, dimension, -1) + + # We need to swap axes because tf.gather (and tf.gather_nd) supports + # batch_dims on the left but not on the right. + # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis. + keys = maybe_swapaxes(keys) + values = maybe_swapaxes(values) + idxs = tf_np.argsort(keys) + + # Using tf.gather rather than np.take because the former supports batch_dims + def gather(a): + return tf_np.asarray(tf.gather(a, idxs, batch_dims=rank - 1)) + + keys = gather(keys) + values = gather(values) + keys = maybe_swapaxes(keys) + values = maybe_swapaxes(values) + return keys, values + + +def scan(f, init, xs, length=None, reverse=False): + """Scan a function over leading array axes while carrying along state. + + See the docstring of `jax.lax.scan` + (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for + details. + + Args: + f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning + that ``f`` accepts two arguments where the first is a value of the loop + carry and the second is a slice of ``xs`` along its leading axis, and that + ``f`` returns a pair where the first element represents a new value for + the loop carry and the second represents a slice of the output. Note that + the input and output carry must have the same dtype. + init: an initial loop carry value of type ``c``, which can be a scalar, + array, or any pytree (nested Python tuple/list/dict) thereof, representing + the initial loop carry value. This value must have the same structure as + the first element of the pair returned by ``f``. 
+ xs: the value of type ``[a]`` over which to scan along the leading axis, + where ``[a]`` can be an array or any pytree (nested Python + tuple/list/dict) thereof with consistent leading axis sizes. + length: optional integer specifying the number of loop iterations, which + must agree with the sizes of leading axes of the arrays in ``xs`` (but can + be used to perform scans where no input ``xs`` are needed). + reverse: optional boolean specifying whether to run the scan iteration + forward (the default) or in reverse, equivalent to reversing the leading + axes of the arrays in both ``xs`` and in ``ys``. + + Returns: + A pair of type ``(c, [b])`` where the first element represents the final + loop carry value and the second element represents the stacked outputs of + the second output of ``f`` when scanned over the leading axis of the inputs. + """ + init, xs = tf.nest.map_structure( + lambda x: tf_np.asarray(x) if x is not None else None, (init, xs) + ) + if length is not None: + length = int(length) + + def get_length(x): + if x is None: + return None + if x.shape.rank == 0: + raise ValueError("Some array in `xs` doesn't have a leading dimension") + return x.shape[0] + + lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs)) + for l in lengths: + if l is not None: + if length is None: + length = l + elif length != l: + raise ValueError( + "There are two different leading-dimension lengths: " + f"{length} and {l}" + ) + if length is None: + raise ValueError("Can't determine length. Please set the `length` argument.") + xs_ta = tf.nest.map_structure( + lambda t: ( + tf.TensorArray( + t.dtype, size=length, dynamic_size=False + ).unstack( # pylint: disable=g-long-lambda + t + ) + if t is not None + else None + ), + xs, + ) + # tf.while_loop doesn't allow None in loop_vars, so we mask them. + is_init_none = tf.nest.map_structure(lambda x: x is None, init) + + def to_safe(carry): + return tf.nest.map_structure( + lambda x, is_none: tf.zeros([]) if is_none else x, carry, is_init_none + ) + + def from_safe(safe_carry): + return tf.nest.map_structure( + lambda x, is_none: None if is_none else x, safe_carry, is_init_none + ) + + def body(i, safe_carry, ys_ta): + carry = from_safe(safe_carry) + if reverse: + i_ = length - 1 - i + else: + i_ = i + xs = tf.nest.map_structure( + lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta + ) + carry, ys = f(*_tf_to_np((carry, xs))) + ys_ta = tf.nest.map_structure( + lambda y_ta, y: (y_ta.write(i_, y) if y is not None else y_ta), ys_ta, ys + ) + i = i + 1 + safe_carry = to_safe(carry) + return i, safe_carry, ys_ta + + xs_spec = tf.nest.map_structure( + lambda t: tf.TensorSpec(t.shape[1:], t.dtype) if t is not None else None, xs + ) + _, ys_spec = eval_on_shapes(f)(init, xs_spec) + # ys_ta can't contain None because tf.while_loop doesn't allow None in + # loop_vars. 
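+    # For the None outputs we still allocate a float32 placeholder
+    # TensorArray; `body` never writes to it and `_stack` below restores
+    # None at that position.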
+    ys_ta = tf.nest.map_structure(
+        lambda y: tf.TensorArray(  # pylint: disable=g-long-lambda
+            y.dtype if y is not None else tf.float32,
+            size=length,
+            dynamic_size=False,
+        ),
+        ys_spec,
+    )
+    safe_init = to_safe(init)
+    _, safe_carry, ys_ta = tf.while_loop(
+        lambda i, *_: i < length, body, (0, safe_init, ys_ta), maximum_iterations=length
+    )
+    carry = from_safe(safe_carry)
+
+    def _stack(a, spec):
+        if spec is None:
+            return None
+        a = a.stack()
+        a.set_shape((length,) + a.shape[1:])
+        return a
+
+    ys = tf.nest.map_structure(_stack, ys_ta, ys_spec)
+    return _tf_to_np((carry, ys))
+
+
+# named "tf_map" instead of "map" as in JAX to avoid conflict with Python `map`
+def tf_map(f, xs):
+    """Maps a function over leading array axes.
+
+    See the docstring of `jax.lax.map`
+    (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html) for
+    details.
+
+    Args:
+      f: a Python function to apply element-wise over the first axis or axes
+        of `xs`.
+      xs: values over which to map along the leading axis.
+
+    Returns:
+      Mapped values.
+    """
+
+    def g(unused, x):
+        return unused, f(x)
+
+    carry = tf.nest.map_structure(lambda _: None, xs)
+    return scan(g, carry, xs)[1]
+
+
+def _get_dynamic_indices(operand, start_indices, slice_sizes):
+    """Calculates the indices for `tf.gather_nd` from slices.
+
+    Args:
+      operand: a Tensor to slice.
+      start_indices: a vector Tensor of integers, one per dimension. The
+        starts of the slice. The vector can be dynamic.
+      slice_sizes: a list of integers, one per dimension. The sizes of the
+        slice.
+
+    Returns:
+      An index array suitable for `tf.gather_nd` and `tf.scatter_nd`, or
+      `None` if `operand` is a scalar.
+    """
+    rank = len(slice_sizes)
+    operand_rank = tf.rank(operand)
+    tf.debugging.Assert(operand_rank == rank, [operand_rank, rank])
+    starts_rank = tf.rank(start_indices)
+    tf.debugging.Assert(starts_rank == 1, [starts_rank])
+    num_starts = tf.shape(start_indices)[0]
+    tf.debugging.Assert(num_starts == rank, [num_starts, rank])
+    operand_shape = tf.shape(operand)
+    tf.debugging.Assert(
+        tf.reduce_all(slice_sizes <= operand_shape), [slice_sizes, operand_shape]
+    )
+    if rank == 0:
+        return None
+    start_indices = tf.where(
+        start_indices < 0, start_indices + operand_shape, start_indices
+    )
+    idx_list = []
+    for i in range(rank):
+        start = start_indices[i]
+        size = slice_sizes[i]
+        dim = operand_shape[i]
+        start = tf.clip_by_value(start, 0, dim - size)
+        # XLA requires tf.range's `start` to be compile-time constant, so we
+        # can't do tf.range(start, ...).
+        idx = start + tf.range(size)
+        shape = [1] * rank
+        shape[i] = size
+        idx = tf.reshape(idx, shape)
+        idx_list.append(idx)
+    slice_sizes_tensor = tf.convert_to_tensor(slice_sizes)
+    # tf.stack doesn't support broadcasting, so we need to broadcast manually.
+    # TODO(wangpeng): Reduce peak memory by broadcasting one-by-one instead of
+    # all-together.
+    idx_list = [tf.broadcast_to(x, slice_sizes_tensor) for x in idx_list]
+    return tf.stack(idx_list, axis=-1)
+
+
+def dynamic_slice(operand, start_indices, slice_sizes):
+    """Slicing operation where the indices can be dynamic values.
+
+    See the docstring of `jax.lax.dynamic_slice`
+    (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html)
+    for details.
+
+    Args:
+      operand: an array to slice.
+      start_indices: a vector of integers, one per dimension. The starts of
+        the slice. The vector can be dynamic.
+      slice_sizes: a list of integers, one per dimension. The sizes of the
+        slice.
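+        For example (an illustrative case): with `operand` of shape (5, 5),
+        `start_indices=[1, 2]` and `slice_sizes=[2, 3]`, the result equals
+        `operand[1:3, 2:5]` and has shape (2, 3).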
+ + Returns: + An array containing the slice, with shape equal to `slice_sizes`. + """ + # This implementation uses tf.gather_nd to implement dynamic_slice, which is + # memory inefficient because the size of `indices` given to gather_nd is + # large. + operand = tf_np.asarray(operand).data + start_indices = tf_np.asarray(start_indices, np.int32).data + idx = _get_dynamic_indices(operand, start_indices, slice_sizes) + if idx is not None: + operand = tf.gather_nd(operand, idx) + return tf_np.asarray(operand) + + +def dynamic_update_slice(operand, update, start_indices): + """Updates a dynamic slice. + + See the docstring of `jax.lax.dynamic_update_slice` + (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_update_slice.html) + for details. + + Args: + operand: an array to slice. + update: an array containing the new values to write onto `operand`. + start_indices: a vector of integers, one per dimension. The starts of the + slice. The vector can be dynamic. + + Returns: + The updated version of `operand`. + """ + operand = tf_np.asarray(operand).data + update = tf_np.asarray(update).data + start_indices = tf_np.asarray(start_indices, np.int32).data + if not update.shape.is_fully_defined(): + raise ValueError("update's shape must be fully defined") + slice_sizes = update.shape + idx = _get_dynamic_indices(operand, start_indices, slice_sizes) + if idx is None: + # `np.zeros([])[()] = 1.0` will result in a scalar array of 1.0 + return tf_np.asarray(update) + operand = tf.tensor_scatter_nd_update(operand, idx, update) + return tf_np.asarray(operand) + + +def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0): + """Convenience wrapper around dynamic_slice applying to one dimension.""" + operand = tf_np.asarray(operand) + start_indices = [0] * operand.ndim + slice_sizes = list(operand.shape) + axis = int(axis) + start_indices[axis] = start_index + slice_sizes[axis] = int(slice_size) + return dynamic_slice(operand, start_indices, slice_sizes) + + +def dynamic_update_slice_in_dim(operand, update, start_index, axis): + """Convenience wrapper around dynamic_update_slice for one dimension.""" + operand = tf_np.asarray(operand) + axis = int(axis) + start_indices = [0] * operand.ndim + start_indices[axis] = start_index + return dynamic_update_slice(operand, update, start_indices) + + +# Use int64 instead of int32 to avoid TF's "int32 problem" +_RNG_KEY_DTYPE = np.int64 + + +def _key2seed(a): + """Converts an RNG key to an RNG seed. + + Args: + a: an RNG key, an ndarray of shape [] and dtype `np.int64`. + + Returns: + an RNG seed, a tensor of shape [2] and dtype `tf.int32`. + """ + + def int64_to_int32s(a): + """Converts an int64 tensor of shape [] to an int32 tensor of shape [2].""" + a = tf.cast(a, tf.uint64) + fst = tf.cast(a, tf.uint32) + snd = tf.cast(tf.bitwise.right_shift(a, tf.constant(32, tf.uint64)), tf.uint32) + a = [fst, snd] + a = tf.nest.map_structure(lambda x: tf.cast(x, tf.int32), a) + a = tf.stack(a) + return a + + return int64_to_int32s(a) + + +def _seed2key(a): + """Converts an RNG seed to an RNG key. + + Args: + a: an RNG seed, a tensor of shape [2] and dtype `tf.int32`. + + Returns: + an RNG key, an ndarray of shape [] and dtype `np.int64`. 
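+
+    Intended to be the inverse of `_key2seed`: illustratively,
+    `_seed2key(_key2seed(k))` recovers the key `k`.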
+ """ + + def int32s_to_int64(a): + """Converts an int32 tensor of shape [2] to an int64 tensor of shape [].""" + a = tf.bitwise.bitwise_or( + tf.cast(a[0], tf.uint64), + tf.bitwise.left_shift(tf.cast(a[1], tf.uint64), tf.constant(32, tf.uint64)), + ) + a = tf.cast(a, tf.int64) + return a + + return tf_np.asarray(int32s_to_int64(a)) + + +def prng(s): + """Creates RNG state from seed. + + Args: + s: the seed, an integer. + + Returns: + An RNG state, as a scalar array of dtype `np.int64`. + """ + # TODO(wangpeng): Become bitwise-identical to JAX when TF stateless RNGs get + # improved. + return tf_np.asarray(s, dtype=_RNG_KEY_DTYPE) + + +def stateless_split(seed, num=2): + """Splits an RNG seed into `num` new seeds by adding a leading axis. + + Example: + + >>> seed = [1, 2] + >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3) + >>> print(new_seeds) + tf.Tensor( + [[1105988140 1738052849] + [-335576002 370444179] + [ 10670227 -246211131]], shape=(3, 2), dtype=int32) + >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :]) + + + Args: + seed: an RNG seed (a tensor with shape [2] and dtype `int32` or `int64`). + (When using XLA, only `int32` is allowed.) + num: optional, a positive integer or scalar tensor indicating the number of + seeds to produce (default 2). + + Returns: + A tensor with shape [num, 2] representing `num` new seeds. It will have the + same dtype as `seed` (if `seed` doesn't have an explict dtype, the dtype + will be determined by `tf.convert_to_tensor`). + """ + seed = tf.convert_to_tensor(seed) + return tf.random.stateless_uniform( + shape=[num, 2], seed=seed, dtype=seed.dtype, minval=None, maxval=None + ) + + +def split(state, num): + """Creates new independent RNG states from an existing state. + + Args: + state: the existing state. + num: the number of the new states. + + Returns: + A tuple of new states. + """ + state = tf_np.asarray(state, dtype=_RNG_KEY_DTYPE) + state = _key2seed(state) + try: + states = tf.random.experimental.stateless_split(state, num) + except AttributeError: # pylint: disable=unused-variable + # TODO(afrozm): For TF < 2.3 we need to do this. Delete once 2.3 launches. + states = stateless_split(state, num) + states = tf.unstack(states, num) + states = tf.nest.map_structure(_seed2key, states) + return states + + +def uniform(key, shape, dtype=tf_np.float32, minval=0.0, maxval=1.0): + """Sample uniform random values in range [`minval`, `maxval`). + + Args: + key: the RNG key. + shape: the shape of the result. + dtype: the dtype of the result. + minval: the minimal value (inclusive). + maxval: the maximal value (exclusive). + + Returns: + An ndarray with shape `shape` and dtype `dtype`. Each value in the ndarray + is sampled uniformly randomly in range [`minval`, `maxval`). + """ + minval = tf.cast(minval, dtype) + maxval = tf.cast(maxval, dtype) + key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) + return tf_np.asarray( + tf.random.stateless_uniform( + shape, seed=_key2seed(key), dtype=dtype, minval=minval, maxval=maxval + ) + ) + + +def normal(key, shape, dtype=tf.float32): + """Sample standard-normal random values. + + Args: + key: the RNG key. + shape: the shape of the result. + dtype: the dtype of the result. + + Returns: + Random values in standard-normal distribution. 
+ """ + key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) + return tf_np.asarray( + tf.random.stateless_normal(shape, seed=_key2seed(key), dtype=dtype) + ) + + +def bernoulli(key, mean=np.float32(0.5), shape=None): + """Sample Bernoulli random values with given shape and mean. + + Args: + key: the RNG key. + mean: optional, an array_like broadcastable to `shape` for the mean of the + random variables (default 0.5). + shape: optional, a tuple of nonnegative integers representing the shape + (default to `mean`'s shape). + + Returns: + A random array with the specified shape and boolean dtype. + """ + mean = tf_np.asarray(mean) + if shape is None: + shape = mean.shape + return uniform(key, shape) < mean + + +def _eager_dataset_iterator(dataset): + for item in dataset: + yield tf.nest.map_structure(tf_np.asarray, item) + + +def dataset_as_numpy(dataset): + """Converts a `tf.data.Dataset` to an iterable of ndarrays. + + `dataset_as_numpy` converts a possibly nested structure of `tf.data.Dataset`s + and `tf.Tensor`s to iterables of ndarrays and ndarrays, respectively. This + function must be run in eager mode outside tf.function. + + Args: + dataset: a possibly nested structure of `tf.data.Dataset`s and/or + `tf.Tensor`s. + + Returns: + A structure matching `dataset` where `tf.data.Dataset`s are converted to + generators of ndarrays and `tf.Tensor`s are converted to ndarrays. + """ + if not tf.executing_eagerly(): + raise ValueError( + "dataset_as_numpy must be run in eager mode outside tf.function" + ) + nested_ds = dataset + del dataset + + # Flatten + flat_ds = tf.nest.flatten(nested_ds) + flat_np = [] + + # Type check for Tensors and Datasets + for ds_el in flat_ds: + if not isinstance(ds_el, (tf.Tensor, tf.data.Dataset)): + types = tf.nest.map_structure(type, nested_ds) + raise ValueError( + "Arguments to dataset_as_numpy must be (possibly nested " + "structure of) tf.Tensors or tf.data.Datasets. Got: %s" % types + ) + + for ds_el in flat_ds: + if isinstance(ds_el, tf.Tensor): + np_el = tf_np.asarray(ds_el) + elif isinstance(ds_el, tf.data.Dataset): + np_el = _eager_dataset_iterator(ds_el) + else: + assert False + flat_np.append(np_el) + + return tf.nest.pack_sequence_as(nested_ds, flat_np) + + +# TODO(nareshmodi): Group key should change based on the set of devices that we +# are mapping over. Make it so that we assign a unique group_key for every +# unique set of devices. We don't change it every time to avoid the overhead of +# discovering the full group (though may not be problematic in the local case). +_GROUP_KEY = 1 +_INSTANCE_KEY = 0 +_INSTANCE_LOCK = threading.Lock() + + +# TODO(b/142565636): Ensure that multiple concurrent calls to a tf.function +# containing a collective op run reasonably. +def _get_instance_key(): + global _INSTANCE_KEY + global _INSTANCE_LOCK + with _INSTANCE_LOCK: + _INSTANCE_KEY = _INSTANCE_KEY + 1 + return _INSTANCE_KEY + + +# Don't use a namedtuple since nest considers that a tuple and unflattens and +# flattens it. +class ShardedNdArray(object): + """Wrapper over ndarray that can contain tensors on multiple devices. + + This is returned by extensions.pmap, and contains the individual tensors on + different devices. + """ + + def __init__(self, tensors): + """Initializes the ShardedNdArray. + + Note that the tensors should be ordered in the way the pmap producing these + tensors is run. + + Args: + tensors: list or tuple of eager tensors, one for each device. 
+ """ + + if not isinstance(tensors, (list, tuple)) or not tensors: + raise ValueError( + "Unable to create a ShardedNdArray without a list of tensors." + ) + self.tensors = tensors + self.n_devices = len(tensors) + + def __getitem__(self, i): + return tf_np.asarray(self.tensors[i]) + + @property + def shape(self): + return (self.n_devices,) + self.tensors[ + 0 + ]._shape_tuple() # pylint: disable=protected-access + + @property + def dtype(self): + return self.tensors[0].dtype + + +def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs): + del args, kwargs + # TODO(nareshmodi): Consider a collective op to gather the tensors from the + # various devices for performance reasons. + return tf.stack(value.tensors) + + +tf.register_tensor_conversion_function( + ShardedNdArray, convert_sharded_tensor_to_eager_tensor +) + + +class _PmapConfig(threading.local): + """Simple config used to maintain state related to a current pmap call.""" + + def __init__(self): + super(_PmapConfig, self).__init__() + self._axis_name = None + self._devices = None + + def axis_name(self): + return self._axis_name + + def set_axis_name(self, axis_name): + self._axis_name = axis_name + + def devices(self): + return self._devices + + def set_devices(self, devices): + self._devices = devices + + +_pmap_config = _PmapConfig() + + +@contextlib.contextmanager +def pmap_config(axis_name, devices): + """Records axis_name and devices for this context.""" + old_axis_name = _pmap_config.axis_name() + old_devices = _pmap_config.devices() + _pmap_config.set_axis_name(axis_name) + _pmap_config.set_devices(devices) + try: + yield + finally: + _pmap_config.set_axis_name(old_axis_name) + _pmap_config.set_devices(old_devices) + + +def _psum(tensor, axis_name=None): + """Sum all-reduction. + + Args: + tensor: A tensor. + axis_name: The axis name to reduce. Must equal to that of the surrounding + pmap. + + Returns: + The sum of the `tensor` replicas on each participating devices. + """ + if axis_name != _pmap_config.axis_name(): + raise ValueError( + "axis_name (%s) is not equal to that of the surrounding " + "pmap (%s)" % (axis_name, _pmap_config.axis_name()) + ) + devices = _pmap_config.devices() + if devices is None: + raise ValueError("Can't retrieve the device list from the surrounding pmap") + tensor = tf_np.asarray(tensor) + if tpu_devices(devices): + # TODO(b/170895907): Remove this workaround when tpu.cross_replica_sum + # supports int64/float64. + is_int64 = False + is_float64 = False + if tensor.dtype == np.int64: + is_int64 = True + tensor = tensor.astype(np.int32) + elif tensor.dtype == np.float64: + is_float64 = True + tensor = tensor.astype(np.float32) + # TODO(wangpeng): Supply the `group_assignment` argument to + # tpu.cross_replica_sum, calculated from `devices`. + tensor = tf.compat.v1.tpu.cross_replica_sum(tensor) + if is_int64: + tensor = tf.cast(tensor, tf.int64) + elif is_float64: + tensor = tf.cast(tensor, tf.float64) + else: + tensor = tf.raw_ops.CollectiveReduce( + input=tensor, + group_size=len(devices), + group_key=_GROUP_KEY, + instance_key=_get_instance_key(), + merge_op="Add", + final_op="Id", + subdiv_offsets=(0,), + ) + return tf_np.asarray(tensor) + + +def psum(tensors, axis_name=None): + return tf.nest.map_structure(functools.partial(_psum, axis_name=axis_name), tensors) + + +# Note this is not available in the jax api, but seemed like a reasonable API +# to have. +def pmean(tensor, axis_name=None): + """Mean all-reduction. + + Args: + tensor: A tensor. + axis_name: The axis name to reduce. 
+
+    Returns:
+      The mean of the `tensor` replicas on each participating device.
+    """
+    if axis_name != _pmap_config.axis_name():
+        raise ValueError(
+            "axis_name (%s) is not equal to that of the surrounding "
+            "pmap (%s)" % (axis_name, _pmap_config.axis_name())
+        )
+    devices = _pmap_config.devices()
+    if devices is None:
+        raise ValueError("Can't retrieve the device list from the surrounding pmap")
+    if tpu_devices(devices):
+        # TODO(wangpeng): Implement this.
+        raise ValueError("pmean for TPU is not supported yet.")
+    else:
+        return tf.raw_ops.CollectiveReduce(
+            input=tensor,
+            group_size=len(devices),
+            group_key=_GROUP_KEY,
+            instance_key=_get_instance_key(),
+            merge_op="Add",
+            final_op="Div",
+            subdiv_offsets=(0,),
+        )
+
+
+def _get_pmap_impl(f, devices, has_tpu):
+    """This is a helper function to return the pmap impl.
+
+    Args:
+      f: a function that takes ndarrays and returns ndarrays.
+      devices: a list of strings; the device list.
+      has_tpu: boolean; whether `devices` contains TPU devices.
+
+    Returns:
+      A function that takes tensors and returns tensors.
+    """
+    if has_tpu:
+        # Workaround b/121383831
+        output_is_list = [False]  # Use list for mutability
+
+        def recorder(args, kwargs, res):
+            del args, kwargs
+            output_is_list[0] = isinstance(res, list)
+            return res
+
+        f = _record_result_type(recorder, f)
+
+    def tf_f(*tf_args):
+        """A wrapper for `f` that takes/returns tensors."""
+        np_args = _tf_to_np(tf_args)
+        np_out = f(*np_args)
+        return np_out
+
+    if has_tpu:
+
+        @tf.function(autograph=False)
+        def fn(inputs):
+            # TODO(wangpeng): Supply the `device_assignment` argument to
+            # tpu.replicate, calculated from `devices`.
+            res = tf.compat.v1.tpu.replicate(tf_f, inputs)
+            # Workaround b/121383831
+            if (
+                res
+                and isinstance(res[0], list)
+                and len(res[0]) == 1
+                and not output_is_list[0]
+            ):
+                res = [x[0] for x in res]
+            return res
+
+        return fn
+    else:
+        # This is run in a tf.function so that the various underlying functions
+        # can be run in parallel.
+        # The trace happens on the client, so any devices should not depend on
+        # any side effects.
+
+        jit_tf_f = tf.function(tf_f, autograph=False)
+
+        @tf.function(autograph=False)
+        def fn(all_per_device_args):
+            """Multi-device function with calls placed on the correct device."""
+
+            results = []
+            for per_device_args, device in zip(all_per_device_args, devices):
+                with tf.device(device):
+                    results.append(jit_tf_f(*per_device_args))
+            return results
+
+        return fn
+
+
+def pmap(f, axis_name=None, devices=None):
+    """Transforms a function into a multi-device function.
+
+    The semantics are similar to JAX's pmap.
+
+    Args:
+      f: The function to be converted.
+      axis_name: Used for nested pmap, which is not supported yet.
+      devices: The devices over which the returned function will run.
+
+    Returns:
+      A function that runs the underlying function `f` on `devices`. Its
+      arguments can be `ShardedNdArray`s, tensors or other Python objects, and
+      its return values are all `ShardedNdArray`s. If an input is a tensor,
+      the length of its first dimension must equal the number of devices, and
+      the tensor will be split along its first dimension among the devices.
+      If an input is an unknown Python object, it will be replicated among
+      the devices.
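+
+    Example (an illustrative sketch, assuming exactly two accelerator
+    devices are available):
+
+      f = pmap(lambda x: x + 1)
+      x = tf.zeros((2, 3))  # leading dim must equal the number of devices
+      y = f(x)  # a ShardedNdArray holding one (3,)-shaped shard per device
+
+    The leading dimension is split across the devices, so each device
+    receives one row of `x`.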
+ """ + if devices is None: + devices = accelerators() + if not isinstance(devices, (list, tuple)): + raise ValueError("Must pass a list or tuple of devices") + num_devices = len(devices) + if not num_devices: + raise ValueError("There must be at least 1 device") + has_tpu = bool(tpu_devices(devices)) + + pmap_fn = _get_pmap_impl(f, devices, has_tpu) + + def wrapper(*args): + """Wrapper that wraps/unwraps args, retvals, and runs the function.""" + if _pmap_config.devices() is not None: + raise ValueError( + "Found a surrounding pmap. Nested pmap is not supported " "yet." + ) + # TODO(wangpeng): Maybe we should use `asarray` to convert everything + # to ndarray first. + + flattened_input_args = tf.nest.flatten(args) + flattened_per_device_args = [[] for _ in devices] + for arg in flattened_input_args: + if isinstance(arg, tf.Tensor): + # TODO(nareshmodi): Try and use the dynamic shape instead. + if (not arg.shape.rank) or arg.shape[0] != len(devices): + # TODO(nareshmodi): Fix this restriction + raise ValueError( + "Input tensors need to have a first dimension equal to " + "the number of devices; got tensor of shape %s and %s devices" + % (arg.shape, len(devices)) + ) + # NOTE: Alternatively use tf.split, and place the split tensors on the + # appropriate device. The best solution for this is to have an API that + # splits a tensor across devices. + for j, device in enumerate(devices): + updated_arg = tf.gather(arg, j) + # TODO(wangpeng): Investigate whether we need a tf.identity for TPU. + if not has_tpu: + with tf.device(device): + updated_arg = tf.identity(updated_arg) + flattened_per_device_args[j].append(updated_arg) + elif isinstance(arg, ShardedNdArray): + for device_args, tensor in zip(flattened_per_device_args, arg.tensors): + device_args.append(tensor) + else: + for device_args in flattened_per_device_args: + device_args.append(arg) + + all_per_device_args = [ + tf.nest.pack_sequence_as(args, device_args) + for device_args in flattened_per_device_args + ] + + with pmap_config(axis_name, devices): + results = pmap_fn(all_per_device_args) + + # Rewrap things. This can probably be written better. + flattened_results = [tf.nest.flatten(result) for result in results] + final_tree = [] + + # TODO(nareshmodi): assert all items in flattened_results have the same + # structures + + for i in range(len(flattened_results[0])): + tensors = [] + for j, device in enumerate(devices): + assert isinstance( + flattened_results[j][i], tf.Tensor + ), "currently only tensor return items are supported" + tensors.append(flattened_results[j][i]) + final_tree.append(ShardedNdArray(tensors)) + + return tf.nest.pack_sequence_as(results[0], final_tree) + + return wrapper + + +def find_devices(device_type, devices=None): + if not devices: + devices = [d.name for d in tf.config.experimental.list_logical_devices()] + devices = [(d, tf.DeviceSpec.from_string(d)) for d in devices] + results = [name for name, d in devices if d.device_type == device_type] + return results + + +def tpu_devices(devices=None): + """Gets TPU devices out of `devices`. + + Args: + devices: A device list (as a list of strings). If None, the list of all + available devices will be used for it. + + Returns: + Those in `devices` that are TPUs. + """ + return find_devices("TPU", devices) + + +def gpu_devices(devices=None): + """Gets GPU devices out of `devices`. + + Args: + devices: A device list (as a list of strings). If None, the list of all + available devices will be used for it. + + Returns: + Those in `devices` that are GPUs. 
+ """ + return find_devices("GPU", devices) + + +def accelerators(devices=None): + return tpu_devices(devices) or gpu_devices(devices) + + +def _tree_broadcast(to, s): + """Broadcasts `s` to the nested structure `to`.""" + if not isinstance(to, (list, tuple, dict)): + if not isinstance(s, (int, type(None))): + raise ValueError + return s + if isinstance(s, (int, type(None))): + return tf.nest.map_structure(lambda x: s, to) + if isinstance(to, (list, tuple)): + if len(to) != len(s): + raise ValueError + new_s = [_tree_broadcast(x, y) for x, y in zip(to, s)] + if isinstance(to, tuple): + new_s = tuple(new_s) + return new_s + elif isinstance(to, dict): + return {k: _tree_broadcast(to[k], s[k]) for k in to} + else: + raise TypeError("Unsupported type %s" % type(to)) + + +def vmap(f, in_axes=0, out_axes=0): + """Returns a function that maps `f` over first dimension of inputs.""" + in_axes_flat = tf.nest.flatten(in_axes) + if not all(isinstance(l, (type(None), int)) for l in in_axes_flat): + raise TypeError( + "vmap in_axes must be an int, None, or (nested) container with " + "those types as leaves, but got {}.".format(in_axes) + ) + if all(isinstance(l, type(None)) for l in in_axes_flat): + raise ValueError("vmap must have at least one non-None value in in_axes") + + out_axes_flat = tf.nest.flatten(out_axes) + if not all(isinstance(l, (type(None), int)) for l in out_axes_flat): + raise TypeError( + "vmap out_axes must be an int, None, or (nested) container with " + "those types as leaves, but got {}.".format(out_axes) + ) + + def _f(*args): + flat_args = tf.nest.flatten(args) + try: + f_in_axes = _tree_broadcast(args, in_axes) + except ValueError: + six.reraise( + ValueError, + ValueError( + "vmap in_axes specification must be a tree prefix of the " + r"corresponding value, got specification %s for value tree %s" + % (in_axes, args) + ), + sys.exc_info()[2], + ) + f_in_axes_flat = tf.nest.flatten(f_in_axes) + + def tf_f(tf_args): + """Function passed to tf.vectorized_map call.""" + # Note that unbatched arguments are not passed to tf_f. Here we fill thos + # arguments back before calling `f`. + tf_flat_args = [] + j = 0 + for arg, axis in zip(flat_args, f_in_axes_flat): + if axis is None: + tf_flat_args.append(arg) + else: + tf_flat_args.append(tf_args[j]) + j += 1 + unbatched_args = tf.nest.pack_sequence_as(args, tf_flat_args) + return f(*unbatched_args) + + # Constructs arguments to pass to `tf_f`. + # Unbatch arguments are skipped. Arguments with non-zero axis are + # transposed. + tf_args = [] + for arg, axis in zip(flat_args, f_in_axes_flat): + if axis is None: + continue + arg = tf_np.asarray(arg) + if axis != 0: + arg = tf_np.moveaxis(arg, axis, 0) + tf_args.append(arg) + # TODO(agarwal): consider creating a tf.function outside of _f and reusing + # that to avoid overheads of re-vectorizing the code when running eagerly. + outputs = tf.vectorized_map(tf_f, tf_args) + try: + f_out_axes = _tree_broadcast(outputs, out_axes) + except ValueError: + six.reraise( + ValueError, + ValueError( + "vmap out_axes specification must be a tree prefix of the " + r"corresponding value, got specification %s for value tree %s" + % (out_axes, outputs) + ), + sys.exc_info()[2], + ) + + def map_output(x, axis): + """Maps output of tf.vectorized_map to the final output.""" + x = tf_np.asarray(x) + if axis is None: + # Note that `tf.vectorized_map always batches the outputs. + # Here we unbatch it again. + return x[0, ...] + elif axis == 0: + return x + else: + # Need to transpose the output. 
+                return tf_np.moveaxis(x, 0, axis)
+
+        new_outputs = [
+            map_output(output, axis)
+            for output, axis in zip(
+                tf.nest.flatten(outputs), tf.nest.flatten(f_out_axes)
+            )
+        ]
+        return tf.nest.pack_sequence_as(outputs, new_outputs)
+
+    return _f
diff --git a/trax/tf/numpy/__init__.py b/trax/tf/numpy/__init__.py
new file mode 100644
index 000000000..ecd2ceea4
--- /dev/null
+++ b/trax/tf/numpy/__init__.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""NumPy-like wrapper for TensorFlow."""
+
+
+# Enable NumPy behavior globally
+from tensorflow.python.ops.numpy_ops import np_config
+
+np_config.enable_numpy_behavior()
+
+# Make everything from tensorflow.experimental.numpy available; the wildcard
+# import below pulls in the full tf.experimental.numpy API surface.
+from tensorflow import bfloat16
+from tensorflow.experimental.numpy import random
+from tensorflow.experimental.numpy import *  # pylint: disable=wildcard-import
+from tensorflow.python.ops.numpy_ops.np_dtypes import (
+    canonicalize_dtype,
+    default_float_type,
+    is_allow_float64,
+    set_allow_float64,
+)
+
+# Define what should be accessible when someone imports from this module
+__all__ = [
+    "bfloat16",
+    "canonicalize_dtype",
+    "default_float_type",
+    "is_allow_float64",
+    "set_allow_float64",
+]
diff --git a/trax/tf_numpy/examples/mnist/dataset.py b/trax/tf_numpy/examples/mnist/dataset.py
deleted file mode 100644
index 755f724bf..000000000
--- a/trax/tf_numpy/examples/mnist/dataset.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The Trax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Load pickled MNIST data."""
-import gzip
-import os
-import pickle
-import random
-import urllib
-import numpy as np
-
-
-def load():
-  """Loads the dataset.
-
-  Looks for the dataset at /tmp/mnist.pkl.gz and downloads it if it is not there
-  already.
-
-  Note: The training data is shuffled.
-
-  Returns:
-    ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)).
- Shapes: - train_x: num_training_examples x image_size - train_y: num_training_examples x num_classes - valid_x: num_validation_examples x image_size - valid_y: num_validation_examples x num_classes - test_x: num_test_examples x image_size - test_y: num_test_examples x num_classes - """ - filepath = _maybe_download() - with gzip.open(os.path.join(filepath), 'rb') as f: - training_data, validation_data, test_data = pickle.load(f) - training_data = (training_data[0], [to_one_hot(x) for x in training_data[1]]) - validation_data = (validation_data[0], - [to_one_hot(x) for x in validation_data[1]]) - test_data = (test_data[0], [to_one_hot(x) for x in test_data[1]]) - - def shuffle(data): - zipped = zip(*data) - random.shuffle(zipped) - return zip(*zipped) - - return (shuffle(training_data), validation_data, test_data) - - -def to_one_hot(label, num_classes=10): - vec = np.zeros(num_classes, dtype=np.float32) - vec[label] = 1. - return vec - - -def _maybe_download(): - """Downloads the MNIST dataset if it is not there already.""" - data_url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' - filename = data_url.split('/')[-1] - filepath = os.path.join(_get_data_dir(), filename) - if not os.path.exists(filepath): - - def _progress(count, block_size, total_size): - print('\r>> Downloading %s %.1f%%' % - (filename, float(count * block_size) / float(total_size) * 100.0)) - - filepath, _ = urllib.urlretrieve(data_url, filepath, _progress) - statinfo = os.stat(filepath) - print('Successfully downloaded %s %d bytes.' % (filename, statinfo.st_size)) - else: - print('Data already present on disk.') - return filepath - - -def _get_data_dir(): - return '/tmp' diff --git a/trax/tf_numpy/examples/mnist/model.py b/trax/tf_numpy/examples/mnist/model.py deleted file mode 100644 index 8f5057b53..000000000 --- a/trax/tf_numpy/examples/mnist/model.py +++ /dev/null @@ -1,132 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model for training on MNIST data.""" -from numpy import float32 -from numpy import int32 - -import tensorflow.compat.v2 as tf - -from trax.tf_numpy import numpy as np - - -class Model(object): - """A simple neural network with dense layers and sigmoid non-linearity. - - The network consists of `len(hidden_layers) + 1` dense layers. The sizes of - the hidden layers are specified by the user in `hidden_layers` and the - network takes care of adding layers to match the input and output size. - - Attributes: - weights: A list of 2-d float32 arrays containing the layer weights. - biases: A list of 2-d float32 arrays containing the layer biases. - - Methods: - forward: Can be used to perform a forward pass on a batch of - flattened images. Output is returned as a batch of one-hot vectors of the - classes. - train: method performs a forward and backward pass and updates the - weights and biases. - evaluate: method can be used to evaluate the network on a batch of - examples. 
- """ - - def __init__(self, hidden_layers, input_size=784, num_classes=10): - """Initializes the neural network. - - Args: - hidden_layers: List of ints specifying the sizes of hidden layers. Could - be empty. - input_size: Length of the input array. The network receives the input - image as a flattened 1-d array. Defaults to 784(28*28), the default - image size for MNIST. - num_classes: The number of output classes. Defaults to 10. - """ - hidden_layers = [input_size] + hidden_layers + [num_classes] - self.weights = [] - self.biases = [] - for i in range(len(hidden_layers) - 1): - # TODO(srbs): This is manually cast to float32 to avoid the cast in - # np.dot since backprop fails for tf.cast op. - self.weights.append( - np.array( - np.random.randn(hidden_layers[i + 1], hidden_layers[i]), - copy=False, - dtype=float32)) - self.biases.append( - np.array( - np.random.randn(hidden_layers[i + 1]), copy=False, dtype=float32)) - - def forward(self, x): - """Performs the forward pass. - - Args: - x: 2-d array of size batch_size x image_size. - - Returns: - A 2-d array of size batch_size x num_classes. - """ - - def sigmoid(x): - return 1.0 / (1.0 + np.exp(-x)) - - for w, b in zip(self.weights, self.biases): - x = sigmoid(np.dot(w, x.T).T + b) - return x - - def train(self, x, y, learning_rate=0.01): - """Runs a single training pass. - - Args: - x: 2-d array of size batch_size x image_size. - y: 2-d array of size batch_size x num_classes in one-hot notation. - learning_rate: The learning rate. - """ - x = np.array(x, copy=False) - y = np.array(y, copy=False) - - def mean_squared_error(x, y): - diff = x - y - return np.sum(diff * diff) / len(x) - - wb_tensors = self.weights + self.biases - with tf.GradientTape() as g: - g.watch(wb_tensors) - loss = mean_squared_error(self.forward(x), y) - gradients = g.gradient(loss, wb_tensors) - gradients = [np.asarray(grad) for grad in gradients] - - new_weights_and_biases = [] - for v, dv in zip(self.weights + self.biases, gradients): - new_weights_and_biases.append(v - learning_rate * dv) - - total_len = len(new_weights_and_biases) - self.weights = new_weights_and_biases[:total_len // 2] - self.biases = new_weights_and_biases[total_len // 2:] - - def evaluate(self, x, y): - """Returns the number of correct predictions. - - Args: - x: 2-d array of size batch_size x image_size. - y: 2-d array of size batch_size x num_classes. - - Returns: - A scalar, the number of correct predictions. - """ - y_actual = np.argmax(y, axis=1) - y_predicted = np.argmax(self.forward(x), axis=1) - return int( - np.sum(np.array(y_actual == y_predicted, copy=False, dtype=int32))) diff --git a/trax/tf_numpy/examples/mnist/train.py b/trax/tf_numpy/examples/mnist/train.py deleted file mode 100644 index 766c71998..000000000 --- a/trax/tf_numpy/examples/mnist/train.py +++ /dev/null @@ -1,84 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Perform training.""" -from absl import app -from absl import flags - -from six.moves import range -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.examples.mnist import dataset -from trax.tf_numpy.examples.mnist import model as model_lib - -FLAGS = flags.FLAGS - -flags.DEFINE_integer('batch_size', 50, 'Batch size.') -flags.DEFINE_integer('num_training_iters', 10000, - 'Number of iterations to train for.') -flags.DEFINE_integer( - 'validation_steps', 100, - 'Validation is performed every these many training steps.') -flags.DEFINE_float('learning_rate', 5.0, 'Learning rate.') - - -def train(batch_size, learning_rate, num_training_iters, validation_steps): - """Runs the training.""" - print('Loading data') - training_data, validation_data, test_data = dataset.load() - print('Loaded dataset with {} training, {} validation and {} test examples.'. - format( - len(training_data[0]), len(validation_data[0]), len(test_data[0]))) - - assert len(training_data[0]) % batch_size == 0 - assert len(validation_data[0]) % batch_size == 0 - assert len(test_data[0]) % batch_size == 0 - - def build_iterator(data, infinite=True): - """Build the iterator for inputs.""" - index = 0 - size = len(data[0]) - while True: - if index + batch_size > size: - if infinite: - index = 0 - else: - return - yield data[0][index:index + batch_size], data[1][index:index + batch_size] - index += batch_size - - train_iter = build_iterator(training_data) - model = model_lib.Model([30]) - - for i in range(num_training_iters): - train_x, train_y = next(train_iter) - model.train(train_x, train_y, learning_rate) - if (i + 1) % validation_steps == 0: - validation_iter = build_iterator(validation_data, infinite=False) - correct_predictions = 0 - for valid_x, valid_y in validation_iter: - correct_predictions += model.evaluate(valid_x, valid_y) - print('{}/{} correct validation predictions.'.format( - correct_predictions, len(validation_data[0]))) - - -def main(unused_argv): - train(FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_training_iters, - FLAGS.validation_steps) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - app.run(main) diff --git a/trax/tf_numpy/examples/mnist/train_test.py b/trax/tf_numpy/examples/mnist/train_test.py deleted file mode 100644 index 55a6a5eb4..000000000 --- a/trax/tf_numpy/examples/mnist/train_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Test that the example training script works on fake data.""" -import mock -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.examples.mnist import dataset -from trax.tf_numpy.examples.mnist import train - - -class TFNumpyMnistExampleTest(tf.test.TestCase): - - def testRuns(self): - with mock.patch.object(dataset, 'load', new=fake_mnist_data): - train.train( - batch_size=1, - learning_rate=0.1, - num_training_iters=10, - validation_steps=5) - train.train( - batch_size=2, - learning_rate=0.1, - num_training_iters=5, - validation_steps=2) - train.train( - batch_size=10, - learning_rate=0.1, - num_training_iters=1, - validation_steps=1) - - -def fake_mnist_data(): - - def gen_examples(num_examples): - x = np.array( - np.random.randn(num_examples, 784), copy=False, dtype=np.float32) - y = np.zeros((num_examples, 10), dtype=np.float32) - y[:][0] = 1. - return (x, y) - - return (gen_examples(100), gen_examples(10), gen_examples(10)) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/extensions/extensions.py b/trax/tf_numpy/extensions/extensions.py deleted file mode 100644 index 46c4261e8..000000000 --- a/trax/tf_numpy/extensions/extensions.py +++ /dev/null @@ -1,1995 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Extensions such as `jit`, `grad`, `logsumexp`, etc.""" -import bisect -import contextlib -import copy -import functools -import string -import sys -import threading -import numpy as np -import six - -import tensorflow.compat.v2 as tf - -import trax.tf_numpy.numpy as tf_np - -_int_dtype_lower_bounds = [ - -2**63, -2**31, -2**15, -2**7, 0, 2**7, 2**15, 2**31, 2**64 -] -_int_dtypes = [ - tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32, - tf.uint64 -] -_tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose], - 2: [tf.nn.conv2d, tf.nn.conv2d_transpose], - 3: [tf.nn.conv3d, tf.nn.conv3d_transpose]} - - -remat = tf.recompute_grad - - -def most_precise_int_dtype(x): - if not isinstance(x, six.integer_types) or isinstance(x, bool): - return None - i = bisect.bisect_right(_int_dtype_lower_bounds, x) - if i in (0, len(_int_dtype_lower_bounds)): - raise ValueError("Integer %s is out of bounds" % x) - assert len(_int_dtype_lower_bounds) == len(_int_dtypes) + 1 - return _int_dtypes[i - 1] - - -def _canonicalize_jit_arg(x): - if isinstance(x, tf_np.ndarray): - return x - try: - # We need to convert `int` to the most precise dtype, otherwise the dtype - # of the result may be different from numpy's. For example, when a binary - # op takes in a Python integer 5 and an array of uint32, numpy will pick - # uint32 as 5's dtype, while tf.convert_to_tensor will choose int32 which - # will cause the two arguments to be promoted to int64. We pick uint8 - # here, which will be promoted to uint32 by the binary op. - # Note that we prefer unsigned int to signed int when both are equally - # precise. 
For example, for 5, we pick uint8 instead of int8. There is no - # reason to prefer one to the other, because for each there is a case - # where the behavior diverges from numpy. If we prefer signed int, - # consider the case where the first operand is 5 and the second is - # 2**64-1. Numpy picks uint64 as the result dtype, but because we choose a - # signed type for 5 such as int8, the result type will be float64. On the - # other hand, if we prefer unsigned int, consider the case where the first - # operand is 2**31-1 and the second is -1. Numpy will pick int32, but - # because we choose uint32 for 2*32-1, the result will be int64. The root - # of the problem is that `jit` converts `int` to tensors (hence committing - # to a dtype) too early, when we don't have enough information about the - # jitted function (e.g. which subset of the arguments should be promoted - # together using np.result_type). tf.function doesn't have this problem - # because it doesn't convert `int` to tensors. jax.jit doesn't have this - # problem because it converts `int` to "int tracer" which doesn't commit - # to a dtype. - # TODO(wangpeng): Revisit this design and see whether we can improve `jit` - # and tf.function. - dtype = most_precise_int_dtype(x) - if dtype is None and isinstance(x, float): - dtype = tf_np.default_float_type() - return tf.convert_to_tensor(value=x, dtype=dtype) - except (TypeError, ValueError): - return x - - -def _canonicalize_jit_arguments(inp): - """Canonicalize arguments to be used for jit. - - Args: - inp: a nested structure of arguments to be canonicalized (i.e. to be - converted to Tensors). Only tf_np.ndarray and things accepted by - `tf.convert_to_tensor` will be converted. - - Returns: - The canonicalized version. - """ - return tf.nest.map_structure(_canonicalize_jit_arg, inp) - - -def _tf_to_np(inp): - - def f(x): - if isinstance(x, tf.IndexedSlices): - return tf_np.asarray(x) - else: - return x - - return tf.nest.map_structure(f, inp) - - -def stop_gradient(x): - - def static_stop_gradient(x): - # `tf.stop_gradient` is a no-op for non-Tensor. Returning the original type - # allows it to be used in the conditional without Autograph, if static. For - # example: - # `if fastmath.stop_gradient(5) > 4:` - return tf.stop_gradient(x) if tf.is_tensor(x) else x - - return _tf_to_np(tf.nest.map_structure(static_stop_gradient, x)) - - -def custom_grad(f_vjp, f_original=None): - """Decorator to define a function with a custom gradient. - - This function is very similar to `tf.custom_gradient`. See the documentation - of `tf.custom_gradient` for detailed usage. - - The differences with `tf.custom_gradient` are: - - - All arguments and results are tf_np.ndarrays instead of tensors. - - - The `grad_fn` returned by `f_vjp` accepts and returns nested structures, - unlike that in `tf.custom_gradient` which only accepts and returns lists. - - Args: - f_vjp: the same as the `f` argument of `tf.custom_gradient`. Note that all - inputs and outputs of `f_vjp` and of the `grad_fn` function it returns can - be nested structures. - f_original: (optional) not used. - - Returns: - The same as `tf.custom_gradient`. 
- """ - del f_original - - @tf.custom_gradient - def tf_f(*tf_args, **tf_kwargs): - np_args = _tf_to_np(tf_args) - np_kwargs = _tf_to_np(tf_kwargs) - np_y, np_vjp = f_vjp(*np_args, **np_kwargs) - tf_y = np_y - - def tf_vjp(*flat_tf_dy): - tf_dy = tf.nest.pack_sequence_as(tf_y, flat_tf_dy) - np_dy = _tf_to_np(tf_dy) - np_dx = np_vjp(np_dy) - return tf.nest.flatten(np_dx) - - return tf_y, tf_vjp - - def np_f(*args, **kwargs): - return _tf_to_np(tf_f(*args), **kwargs) - - return np_f - - -def vjp(f, *primals, has_aux=False): - """Returns the result and the VJP function of `f`. - - This function returns the result and the vector-Jacobian-product (VJP) - function of `f`. - - Args: - f: a function from (nested structures of) tf_np.ndarrays to a (nested - structure of) tf_np.ndarray. If `has_aux` is True, it should return an - extra output. - *primals: the inputs to be fed to `f`. - has_aux: if True, the second output of `f` will be regarded as an auxiliary, - non-differentiable output that will be ignored by the VJP function. - - Returns: - A pair `(y, vjpfun)` if `has_aux` is False; a tuple `(y, vjpfun, aux)` - otherwise. `y` and `aux` are the outputs of `f`, i.e. `y, aux = - f(*primals)`. `vjpfun` is a function `dx = vjpfun(dy)`, where `dy` is the - cotengents of `y`, having the same structures, shapes and dtypes as - `y`. `dx` is the cotengents of `x`, having the same structures, shapes and - dtypes as `x`. - """ - with tf.GradientTape(persistent=True) as tape: - tape.watch(tf.nest.flatten(primals)) - outputs = f(*primals) - if has_aux: - np_out, aux = outputs - else: - np_out = outputs - - def _vjp(dy): - tf_dx = tape.gradient(np_out, primals, output_gradients=dy) - return _tf_to_np(tf_dx) - - if has_aux: - ret = (np_out, _vjp, aux) - else: - ret = (np_out, _vjp) - return ret - - -# TODO(wangpeng): match JAX's handling of kwargs and non-ndarray args -def grad(f, has_aux=False): - """Returns a function that computes gradient of f. - - Gradients can only be computed through numpy and tensorflow operations and not - through python float operations and values. - - Args: - f: a function of type (params, *args) -> scalar. 'params' can be a nested - structure (made of lists and tuples) of ndarrays and the gradient is - evaluated against it. `scalar` is a scalar ndarray. - has_aux: bool, indicates whether fun returns a pair where the first element - is considered the output of the mathematical function to be differentiated - and the second element is auxiliary data. - - Returns: - A gradient function of type (params, *args) -> gradients, where the result - 'gradients' has the same structure and shapes as 'params'. - """ - - def check_loss_shape(np_loss): - if not isinstance(np_loss, tf_np.ndarray): - raise ValueError( - "The result of the function to take gradient must be an ndarray.") - if not np_loss.shape.is_compatible_with([]): - raise ValueError( - "The result of the function to take gradient must be a scalar.") - - def _f(params, *args): - """The gradient function to be returned.""" - with tf.GradientTape() as g: - g.watch(tf.nest.flatten(params)) - outputs = f(params, *args) - if has_aux: - np_loss, aux = outputs - else: - np_loss = outputs - check_loss_shape(np_loss) - tf_grads = g.gradient(np_loss, params) - if has_aux: - res = (tf_grads, aux) - else: - res = tf_grads - return _tf_to_np(res) - - return _f - - -def _record_result_type(recorder, f): - """A decorator that records some information about the function. - - Args: - recorder: a function of signature `(args, kwargs, res) -> res`. 
- f: the original function. - - Returns: - A transformed function that calls the original function and then the - recorder afterwards. - """ - def wrapper(*args, **kwargs): - res = f(*args, **kwargs) - res = recorder(args, kwargs, res) - return res - - return wrapper - - -def jit(f, - static_argnums=(), - xla_forced_compile=False, - input_signature=None, - autograph=False, - experimental_compile=False): - """Returns a function that runs a trace-compiled version of `f`. - - A trace-compiled version of a function `f` has the same behavior as `f` (when - called with the same "static arguments", see below), but runs faster because - the whole computation is compiled into a computation graph once which is - reused for subsequent executions. - - The trace compilation happens lazily, when the returned function is called for - the first time. The compiled function may not be cached implicitly and - multiple calls to `jit` may not share the compiled function (see below for - "static" vs "dynamic" arguments). - - Args: - f: a function that takes any positional arguments `args` and any keyword - arguments `kwargs`. `ndarray`s and things accepted by - `tf.convert_to_tensor` in `args` and `kwargs` will be treated as 'dynamic - arguments' in the sense that calling the function with different values - for these arguments will not cause retracing. In contrast, arguments of - other types in `args` and `kwargs` are treated as 'static arguments' and - calling the function with different values of them will cause - re-compiling. Positional arguments whose positions are in `static_argnums` - are always treated as static arguments. - static_argnums: a tuple of positions of arguments that will be treated as - static arguments. Note that as aforementioned, any arguments that were not - convertible to tensor will also be static. - xla_forced_compile: if true, it will use XLA to force-compile the graph. - This requires that the function only contain ops that are XLA - compatible. It will compile the entire function into a single XLA op. - input_signature: a list of `tf.TensorSpec`, as the input signature to - control tracing behavior. See the - [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of - `tf.function` for details. - autograph: whether to use autograph to convert Python constructs such as - `if` and `while` to their TensorFlow counterparts. See the - [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of - `tf.function` for details. - experimental_compile: the `experimental_compile` flag for `tf.function`. See - the [doc](https://www.tensorflow.org/api_docs/python/tf/function]) of - `tf.function` for details. This is the recommended way to turn on XLA for - tf.function, but unlike xla_forced_compile, it doesn't force-compile the - entire function into a single XLA op. - - Returns: - A trace-compiled version of f. - """ - - @tf.function(input_signature=input_signature, autograph=autograph, - experimental_compile=experimental_compile) - def _tf_f(*args, **kwargs): - """Accelerated function with tensor inputs/outputs.""" - np_args = _tf_to_np(args) - kwargs = {k: _tf_to_np(v) for k, v in kwargs.items()} - if xla_forced_compile: - # Use list for mutability - output_is_list = [False] - output_is_empty = [False] - output_structure = [None] - def recorder(args, kwargs, res): - del args, kwargs - # Workaround b/121383831 - output_is_list[0] = isinstance(res, list) - # If outputs are empty, xla.compile returns an `Operation`, which we - # don't want. 
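# ---------------------------------------------------------------------------
# Annotation (illustrative, not part of the patch): context for the
# workaround below. tf.nest.flatten(()) and tf.nest.flatten({}) are both [],
# so a jitted function whose outputs flatten to nothing would make
# xla.compile return a tf.Operation instead of the empty structure; the
# deepcopy below keeps a pristine copy of that structure to hand back after
# compilation. A hypothetical use:
#
#   f = jit(lambda: (), xla_forced_compile=True)
#   f()  # expected: (), not a tf.Operation, thanks to this branch
# ---------------------------------------------------------------------------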
-        if tf.nest.flatten(res):
-          output_is_empty[0] = False
-          output_structure[0] = None
-        else:
-          output_is_empty[0] = True
-          # Without deepcopy, xla.compile will change output_structure[0] to a
-          # list of `Operation`.
-          output_structure[0] = copy.deepcopy(res)
-        return res
-      f_ = _record_result_type(recorder, f)
-      np_out = tf.xla.experimental.compile(lambda: f_(*np_args, **kwargs))
-      # Workaround b/121383831
-      if output_is_empty[0]:
-        np_out = output_structure[0]
-      elif (isinstance(np_out, list) and len(np_out) == 1 and
-            not output_is_list[0]):
-        np_out = np_out[0]
-    else:
-      np_out = f(*np_args, **kwargs)
-    return np_out
-
-  def _f(*args, **kwargs):
-    args = [
-        _canonicalize_jit_arguments(arg) if i not in static_argnums else arg
-        for i, arg in enumerate(args)
-    ]
-    kwargs = {k: _canonicalize_jit_arguments(v) for k, v in kwargs.items()}
-    tf_out = _tf_f(*args, **kwargs)
-    return _tf_to_np(tf_out)
-
-  _f.tf_function = _tf_f
-
-  return _f
-
-
-def eval_on_shapes(f, static_argnums=(), allow_static_outputs=False):
-  """Returns a function that evaluates `f` given input shapes and dtypes.
-
-  It transforms function `f` to a function that performs the same computation as
-  `f` but only on shapes and dtypes (a.k.a. shape inference).
-
-  Args:
-    f: the function to be transformed.
-    static_argnums: see documentation of `jit`.
-    allow_static_outputs: whether to allow non-array outputs. If True, non-array
-      outputs (e.g. Python integers) will be returned as-is; otherwise, they
-      will be converted to ndarrays, and then specs of those ndarrays will be
-      returned.
-
-  Returns:
-    A function whose input arguments can be either the same as `f`'s or only
-    their shapes/dtypes represented by `tf.TensorSpec`, and whose return values
-    are `tf.TensorSpec`s with the same nested structure as `f`'s return
-    values. If `allow_static_outputs` is True, when `f` returns some non-array
-    outputs (e.g. Python integers), the converted function will return them
-    as-is instead of returning `tf.TensorSpec`s for them.
-  """
-  def abstractify(args):
-    def _abstractify(x):
-      x = _canonicalize_jit_arg(x)
-      if isinstance(x, (tf.Tensor, tf_np.ndarray)):
-        return tf.TensorSpec(x.shape, x.dtype)
-      else:
-        return x
-    new_args = []
-    for i, arg in enumerate(args):
-      if i in static_argnums:
-        new_args.append(arg)
-      else:
-        new_args.append(tf.nest.map_structure(_abstractify, arg))
-    return new_args
-
-  if allow_static_outputs:
-    # When `tf_f` below is called (via get_concrete_function) with the same
-    # arguments (after abstraction), the Python function `f` won't be run, so we
-    # need this python_outputs_map to retrieve the Python outputs we've seen
-    # before that correspond to the arguments.
-    python_outputs_map = {}
-    def recorder(args, kwargs, res):
-      # Since the get_concrete_function below only uses positional args, we also
-      # use only positional args here.
-      del args, kwargs
-      def is_tensor_like(x):
-        if hasattr(x, "_type_spec"):
-          return True  # x is a CompositeTensor
-        return isinstance(x, (tf_np.ndarray, tf.Tensor))
-      py_values = tf.nest.map_structure(
-          lambda x: None if is_tensor_like(x) else x,
-          res)
-      key = id(tf.compat.v1.get_default_graph())
-      python_outputs_map[key] = py_values
-      # Set non-tensor outputs to None to avoid tf.function calling
-      # tf.convert_to_tensor on them.
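# ---------------------------------------------------------------------------
# Annotation (illustrative, not part of the patch): a hypothetical sketch of
# what `allow_static_outputs` buys in eval_on_shapes above. For a function
# mixing an array output with a Python scalar:
#
#   def f(x):
#     return tf_np.sum(x), 42
#
#   shape_fn = eval_on_shapes(f, allow_static_outputs=True)
#   spec, static = shape_fn(tf.TensorSpec([8], tf.float32))
#   # spec is tf.TensorSpec(shape=[], dtype=tf.float32); static is the
#   # Python int 42, returned as-is instead of as a TensorSpec.
# ---------------------------------------------------------------------------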
- res = tf.nest.map_structure( - lambda x: None if not is_tensor_like(x) else x, - res) - return res - f = _record_result_type(recorder, f) - - # TODO(wangpeng): tf.function could add a knob to turn off materializing the - # graph, so that we don't waste computation and memory when we just want - # shape inference. - tf_f = jit(f, static_argnums=static_argnums).tf_function - - # pylint: disable=missing-docstring - def f_return(*args): - def to_tensor_spec(x): - if isinstance(x, tf.Tensor): - return tf.TensorSpec(x.shape, x.dtype) - else: - return x - - new_args = abstractify(args) - cfun = tf_f.get_concrete_function(*new_args) - res = cfun.structured_outputs - res = tf.nest.map_structure(to_tensor_spec, res) - - if allow_static_outputs: - key = id(cfun.graph) - py_values = python_outputs_map[key] - # We can also call tf.get_static_value on structured_outputs to retrieve - # the Python values, but since we'll need to use python_outputs_map to - # record "which outputs are static?" anyway, we choose to directly store - # the Python values in python_outputs_map. - res = tf.nest.map_structure( - lambda x, python_value: x if python_value is None else python_value, - res, py_values) - - return res - - # Provides access to `tf_f` for testing purpose. - f_return._tf_function = tf_f # pylint: disable=protected-access - return f_return - - -def _index_update_helper(updater, x, idx, y): - x = tf_np.asarray(x) - y = tf_np.asarray(y) - # TODO(b/164251540): Remove this expensive manual broadcasting once - # tf.raw_ops.tensor_strided_slice_update and tf.tensor_scatter_nd_update - # support broadcasting. - y = tf.broadcast_to(y, tf.shape(x[idx])) - return updater(x, idx, y) - - -# pylint: disable=protected-access -def index_update(x, idx, y): - """Pure equivalent of `x[idx] = y`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] = y`. Because it's a pure function, `x` itself won't be - changed. - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. - - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_update, x, idx, y) - - -def index_add(x, idx, y): - """Pure equivalent of `x[idx] += y`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] += y`. Because it's a pure function, `x` itself won't be - changed. - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. - - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_add, x, idx, y) - - -def index_min(x, idx, y): - """Pure equivalent of `x[idx] = minimum(x[idx], y)`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] = minimum(x[idx], y)`. Because it's a pure function, `x` - itself won't be changed. - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. 
`y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. - - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_min, x, idx, y) - - -def index_max(x, idx, y): - """Pure equivalent of `x[idx] = maximum(x[idx], y)`. - - Returns the value of x that would result from the NumPy-style indexed - assignment `x[idx] = maximum(x[idx], y)`. Because it's a pure function, `x` - itself won't be changed. - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, slice objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. - y: the array of updates. `y` must be broadcastable to the shape of the array - that would be returned by `x[idx]`. - - Returns: - The updated version of `x`. - """ - return _index_update_helper(tf_np.ndarray._with_index_max, x, idx, y) -# pylint: enable=protected-access - - -def logsumexp(x, axis=None, keepdims=None): - """Computes log(sum(exp(elements across dimensions of a tensor))). - - Reduces `x` along the dimensions given in `axis`. - Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - entry in `axis`. If `keepdims` is true, the reduced dimensions - are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a - tensor with a single element is returned. - This function is more numerically stable than log(sum(exp(input))). It avoids - overflows caused by taking the exp of large inputs and underflows caused by - taking the log of small inputs. - - Args: - x: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all - dimensions. Must be in the range `[-rank(x), rank(x))`. - keepdims: If true, retains reduced dimensions with length 1. - - Returns: - The reduced tensor. - """ - return tf_np.asarray( - tf.math.reduce_logsumexp( - input_tensor=x, axis=axis, keepdims=keepdims)) - - -def expit(x): - """Compute 1 / (1 + exp(-x)).""" - return tf_np.asarray(tf.math.sigmoid(x)) - - -def erf(x): - """Computes the Gauss error function of x element-wise.""" - return tf_np.asarray(tf.math.erf(x)) - - -def _minus(a, b): - return [x for x in a if x not in b] - - -def _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, - lhs_batch, rhs_batch): - """Compose the output string representation. - - e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik - aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik - - Args: - lhs_rep: A string representation for the left-hand side input array - rhs_rep: A string representation for the right-hand side input array - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - lhs_batch: Sequence[int] (the batch dimensions of lhs) - rhs_batch: Sequence[int] (the batch dimensions of rhs) - - Returns: - A string representation of the result array. - """ - output_rep = [] - for dim in lhs_batch: - output_rep.append(lhs_rep[dim]) - - for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction): - output_rep.append(lhs_rep[i]) - for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction): - output_rep.append(rhs_rep[i]) - return "".join(output_rep) - - -def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction): - """Compute the non-batched matrix multiplication. - - If it is the general non-batched/single-batched matrix multiplication, - use the highly optimized kernel `tf.tensordot` to handle it. 
- - Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - - Returns: - An array that contains the result. - """ - return tf.tensordot( - lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction))) - - -def tf_dot_general(lhs, rhs, dimension_numbers): - """The general dot operation for TensorFlow. - - An equivalent general dot operation as that in JAX - - - Although there is an implementation in TF XLA, avoid directly using XLA when - possible. - - e.g., non-batched: ij,jk->ik - batched: ijk,ikl->ijl - - Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]], - Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form - ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, - rhs_batch_dims)) - - Returns: - An array that contains the result. - """ - char_list = list(string.ascii_lowercase) - char_list = char_list[8:] + char_list[:8] - lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) - lhs_rep = char_list[:lhs_rank] - rhs_rep = char_list[lhs_rank:lhs_rank + rhs_rank] - contraction, batch = dimension_numbers - lhs_contraction, rhs_contraction = contraction - if len(lhs_contraction) != len(rhs_contraction): - raise ValueError( - "The input matrices are required to have the same number " - "of contraction dimensions, but got: lhs {}, rhs: {}".format( - len(lhs_contraction), len(rhs_contraction))) - lhs_batch, rhs_batch = batch - if len(lhs_batch) != len(rhs_batch): - raise ValueError("The input matrices are required to have the same number " - "of batch dimensions, but got: lhs {}, rhs: {}".format( - len(lhs_batch), len(rhs_batch))) - - if not lhs_batch and not rhs_batch: - return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) - - if (lhs_rank == rhs_rank == 3 and lhs_batch == (0,) and rhs_batch == (0,) and - lhs_contraction == (2,) and rhs_contraction == (1,)): - return tf.linalg.matmul(lhs, rhs) - - for i in range(len(lhs_contraction)): - rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] - for i in range(len(lhs_batch)): - rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] - - output_rep = _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, - rhs_contraction, lhs_batch, rhs_batch) - equation = "".join(lhs_rep) + "," + "".join(rhs_rep) + "->" + output_rep - return tf.einsum(equation, lhs, rhs) - - -def _conv_general_param_type_converter(window_strides, lhs_dilation, - rhs_dilation, dim): - """Convert strides, lhs_dilation, rhs_dilation to match TF convention. 
- - For example, - in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2] - if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2] - - Args: - window_strides: window_strides to be converted - lhs_dilation: lhs_dilation to be converted - rhs_dilation: rhs_dilation to be converted - dim: dim to be converted - - Returns: - The updated window_strides, lhs_dilation and rhs_dilation - """ - def _as_list_of_size(item, size): - if item is None: - return None - return [item] * size if isinstance(item, int) else list(item) - return (_as_list_of_size(window_strides, dim), - _as_list_of_size(lhs_dilation, dim), - _as_list_of_size(rhs_dilation, dim)) - - -# pylint: disable=g-bad-todo -# TODO(DarrenZhang01): Expand the test cases of general convolution and revise -# the according bugs. -# TODO(DarrenZhang01): Support feature_group_count, batch_group_count and -# precision, and allow lhs_dilation and rhs_dilation to happen at the same time. -# pylint: enable=g-bad-todo -def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, - lhs_dilation=None, rhs_dilation=None, - dimension_numbers=None, feature_group_count=1, - batch_group_count=1, precision=None): - """A general conv API for TensorFlow. - - According JAX version: - https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html - - Args: - lhs: a rank n+2 dimensional input array. - rhs: a rank n+2 dimensional array of kernel weights. - window_strides: a sequence of n integers, representing the inter-window - strides. - padding: either the string ‘SAME’, the string ‘VALID’, or a sequence of n - (low, high) integer pairs that give the padding to apply before and - after each spatial dimension. - output_shape: the output shape of the convolution (only required for - transpose convolution). - lhs_dilation: None, or a sequence of n integers, giving the dilation factor - to apply in each spatial dimension of lhs. LHS dilation is - also known as transposed convolution. - rhs_dilation: None, or a sequence of n integers, giving the dilation factor - to apply in each spatial dimension of rhs. RHS dilation is - also known as atrous convolution. - dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple - (lhs_spec, rhs_spec, out_spec), where each element is a - string of length n+2. - feature_group_count: integer, default 1. Changing this is currently not - supported. - batch_group_count: integer, default 1. Changing this is currently not - supported. - precision: Optional. Either None, which means the default precision for the - backend, or a Precision enum value. - - Returns: - A TF NumPy array that contains the convolution result. 
- """ - dim = None - lhs_spec, rhs_spec, out_spec = dimension_numbers - if lhs_spec != out_spec: - raise ValueError("Current implementation requires the `data_format` of the " - "inputs and outputs to be the same.") - if len(lhs_spec) >= 6: - raise ValueError("Current implmentation does not support 4 or higher" - "dimensional convolution, but got: ", len(lhs_spec) - 2) - dim = len(lhs_spec) - 2 - if lhs_dilation and rhs_dilation: - if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: - lhs_dilation, rhs_dilation = None, None - else: - raise ValueError("Current implementation does not support that " - "deconvolution and dilation to be performed at the same " - "time, but got lhs_dilation: {}, rhs_dilation: {}" - .format(lhs_dilation, rhs_dilation)) - if padding not in ["SAME", "VALID"]: - raise ValueError("Current implementation requires the padding parameter" - "to be either 'VALID' or 'SAME', but got: ", padding) - if batch_group_count != 1 or feature_group_count != 1: - raise NotImplementedError("batch_group_count and feature_group_count " - "other than 1 is currently not supported, but" - " got feature_group_count: {}, batch_group_count" - ": {}".format(feature_group_count, - batch_group_count)) - if precision is not None: - raise NotImplementedError("precision other than `None` is currently not " - "supported, but got: {}".format(precision)) - # Convert params from int/Sequence[int] to list of ints. - strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( - window_strides, lhs_dilation, rhs_dilation, dim - ) - # Preprocess the shapes - dim_maps = {} - if isinstance(lhs_spec, str): - dim_maps["I"] = list(rhs_spec).index("I") - dim_maps["O"] = list(rhs_spec).index("O") - dim_maps["N"] = list(lhs_spec).index("N") - dim_maps["C"] = list(lhs_spec).index("C") - else: - dim_maps["I"] = rhs_spec[1] - dim_maps["O"] = rhs_spec[0] - dim_maps["N"] = lhs_spec[0] - dim_maps["C"] = lhs_spec[1] - - lhs = tf_np.moveaxis(lhs, (dim_maps["N"], dim_maps["C"]), (0, dim + 1)) - # Adjust the filters, put the dimension 'I' and 'O' at last. - rhs = tf_np.moveaxis(rhs, (dim_maps["O"], dim_maps["I"]), (dim + 1, dim)) - spatial_dim_maps = {1: "W", 2: "HW", 3: "DHW"} - data_format = "N" + spatial_dim_maps[dim] + "C" - - if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): - output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, - rhs_dilation) - else: - output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, - padding, data_format, lhs_dilation) - output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps["N"], dim_maps["C"])) - return output - - -def conv(inp, - fltr, - window_strides, - padding, - dimension_numbers, - filter_dilation=None): - """Convolution over an N-D array. - - See https://www.tensorflow.org/api_docs/python/tf/nn/convolution and - https://www.tensorflow.org/xla/operation_semantics#conv_convolution for - reference. - - Args: - inp: an (N+2)-D array. The input of the convolution. - fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. - window_strides: a sequence of N ints, the strides for moving the convolution - window. - padding: a string, either "VALID" or "SAME". The padding algorithm. - dimension_numbers: a tuple of three strings encoding the data format of - input, filter and output. "I" means input; "O" means output; "C" means - channel; other characters such as "W", "H" and "D" means spatial - dimensions. - filter_dilation: the dilation rates for the filter. 
Dilating the filter - means adding "holes" to the filter. - - Returns: - An (N+2)-D array. The convolution result. - """ - input_spec, filter_spec, output_spec = dimension_numbers - if input_spec != output_spec: - raise ValueError("Input and output data formats must be the same; got %s " - "and %s" % (input_spec, output_spec)) - supported_filter_spec = ["WIO", "HWIO", "DHWIO"] - if filter_spec not in supported_filter_spec: - raise ValueError("The supported data format for the filter are %s; got %s" % - (supported_filter_spec, filter_spec)) - if input_spec[1:-1] != filter_spec[:-2]: - raise ValueError("Input data format (%s) is not compatible with filter " - "data format (%s)" % (input_spec, filter_spec)) - # No type promotion in order to prevent accidentally doing more expensive - # computation. - dtype = tf_np.result_type(inp, fltr) - inp = tf_np.asarray(inp, dtype) - fltr = tf_np.asarray(fltr, dtype) - return tf_np.asarray( - tf.nn.convolution( - input=inp, - filters=fltr, - padding=padding, - strides=window_strides, - dilations=filter_dilation, - data_format=input_spec)) - - -def avg_pool(x, pool_size, strides, padding): - """Performs an N-D average pooling. - - Args: - x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + - [num_channels]`. Pooling happens over the spatial dimensions only. - pool_size: sequence of N ints. - strides: sequence of N ints. - padding: a string, the padding algorithm. Must be "SAME" or "VALID". - - Returns: - An (N+2)-D array, of shape - [batch_size] + output_spatial_shape + [num_channels], - where `output_spatial_shape` depends on the value of padding: - If padding = "SAME": - output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) - If padding = "VALID": - output_spatial_shape[i] = - ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). - """ - x = tf_np.asarray(x) - return tf_np.asarray( - tf.nn.pool( - input=x, - window_shape=pool_size, - pooling_type="AVG", - strides=strides, - padding=padding)) - - -def max_pool(x, pool_size, strides, padding): - """Performs an N-D max pooling. - - Args: - x: ndarray of rank N+2, of shape `[batch_size] + input_spatial_shape + - [num_channels]`. Pooling happens over the spatial dimensions only. - pool_size: sequence of N ints. - strides: sequence of N ints. - padding: a string, the padding algorithm. Must be "SAME" or "VALID". - - Returns: - An (N+2)-D array, of shape - [batch_size] + output_spatial_shape + [num_channels], - where `output_spatial_shape` depends on the value of padding: - If padding = "SAME": - output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) - If padding = "VALID": - output_spatial_shape[i] = - ceil((input_spatial_shape[i] - (pool_size[i] - 1)) / strides[i]). - """ - x = tf_np.asarray(x) - return tf_np.asarray( - tf.nn.pool( - input=x, - window_shape=pool_size, - pooling_type="MAX", - strides=strides, - padding=padding)) - - -def sort_key_val(keys, values, dimension=-1): - """Sorts keys along a dimension and applies same permutation to values. - - Args: - keys: an array. The dtype must be comparable numbers (integers and reals). - values: an array, with the same shape of `keys`. - dimension: an `int`. The dimension along which to sort. - - Returns: - Permuted keys and values. 
- """ - keys = tf_np.asarray(keys) - values = tf_np.asarray(values) - rank = keys.shape.ndims - if rank is None: - rank = values.shape.ndims - if rank is None: - # We need to know the rank because tf.gather requires batch_dims to be `int` - raise ValueError("The rank of either keys or values must be known, but " - "both are unknown (i.e. their shapes are both None).") - if dimension in (-1, rank - 1): - - def maybe_swapaxes(a): - return a - else: - - def maybe_swapaxes(a): - return tf_np.swapaxes(a, dimension, -1) - - # We need to swap axes because tf.gather (and tf.gather_nd) supports - # batch_dims on the left but not on the right. - # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis. - keys = maybe_swapaxes(keys) - values = maybe_swapaxes(values) - idxs = tf_np.argsort(keys) - - # Using tf.gather rather than np.take because the former supports batch_dims - def gather(a): - return tf_np.asarray(tf.gather(a, idxs, batch_dims=rank - 1)) - - keys = gather(keys) - values = gather(values) - keys = maybe_swapaxes(keys) - values = maybe_swapaxes(values) - return keys, values - - -def scan(f, init, xs, length=None, reverse=False): - """Scan a function over leading array axes while carrying along state. - - See the docstring of `jax.lax.scan` - (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for - details. - - Args: - f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning - that ``f`` accepts two arguments where the first is a value of the loop - carry and the second is a slice of ``xs`` along its leading axis, and that - ``f`` returns a pair where the first element represents a new value for - the loop carry and the second represents a slice of the output. Note that - the input and output carry must have the same dtype. - init: an initial loop carry value of type ``c``, which can be a scalar, - array, or any pytree (nested Python tuple/list/dict) thereof, representing - the initial loop carry value. This value must have the same structure as - the first element of the pair returned by ``f``. - xs: the value of type ``[a]`` over which to scan along the leading axis, - where ``[a]`` can be an array or any pytree (nested Python - tuple/list/dict) thereof with consistent leading axis sizes. - length: optional integer specifying the number of loop iterations, which - must agree with the sizes of leading axes of the arrays in ``xs`` (but can - be used to perform scans where no input ``xs`` are needed). - reverse: optional boolean specifying whether to run the scan iteration - forward (the default) or in reverse, equivalent to reversing the leading - axes of the arrays in both ``xs`` and in ``ys``. - - Returns: - A pair of type ``(c, [b])`` where the first element represents the final - loop carry value and the second element represents the stacked outputs of - the second output of ``f`` when scanned over the leading axis of the inputs. 
- """ - init, xs = tf.nest.map_structure( - lambda x: tf_np.asarray(x) if x is not None else None, (init, xs)) - if length is not None: - length = int(length) - def get_length(x): - if x is None: - return None - if x.shape.rank == 0: - raise ValueError("Some array in `xs` doesn't have a leading dimension") - return x.shape[0] - lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs)) - for l in lengths: - if l is not None: - if length is None: - length = l - elif length != l: - raise ValueError("There are two different leading-dimension lengths: " - f"{length} and {l}") - if length is None: - raise ValueError( - "Can't determine length. Please set the `length` argument.") - xs_ta = tf.nest.map_structure( - lambda t: (tf.TensorArray(t.dtype, size=length, dynamic_size=False) # pylint: disable=g-long-lambda - .unstack(t) if t is not None else None), - xs) - # tf.while_loop doesn't allow None in loop_vars, so we mask them. - is_init_none = tf.nest.map_structure(lambda x: x is None, init) - def to_safe(carry): - return tf.nest.map_structure( - lambda x, is_none: tf.zeros([]) if is_none else x, carry, is_init_none) - def from_safe(safe_carry): - return tf.nest.map_structure( - lambda x, is_none: None if is_none else x, safe_carry, is_init_none) - def body(i, safe_carry, ys_ta): - carry = from_safe(safe_carry) - if reverse: - i_ = length - 1 - i - else: - i_ = i - xs = tf.nest.map_structure( - lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta) - carry, ys = f(*_tf_to_np((carry, xs))) - ys_ta = tf.nest.map_structure( - lambda y_ta, y: (y_ta.write(i_, y) if y is not None else y_ta), - ys_ta, ys) - i = i + 1 - safe_carry = to_safe(carry) - return i, safe_carry, ys_ta - xs_spec = tf.nest.map_structure( - lambda t: tf.TensorSpec(t.shape[1:], t.dtype) if t is not None else None, - xs) - _, ys_spec = eval_on_shapes(f)(init, xs_spec) - # ys_ta can't contain None because tf.while_loop doesn't allow None in - # loop_vars. - ys_ta = tf.nest.map_structure( - lambda y: tf.TensorArray(y.dtype if y is not None else tf.float32, # pylint: disable=g-long-lambda - size=length, dynamic_size=False), - ys_spec) - safe_init = to_safe(init) - _, safe_carry, ys_ta = tf.while_loop( - lambda i, *_: i < length, body, (0, safe_init, ys_ta), - maximum_iterations=length) - carry = from_safe(safe_carry) - def _stack(a, spec): - if spec is None: - return None - a = a.stack() - a.set_shape((length,) + a.shape[1:]) - return a - ys = tf.nest.map_structure(_stack, ys_ta, ys_spec) - return _tf_to_np((carry, ys)) - - -# named "tf_map" instead of "map" as in JAX to avoid conflict with Python `map` -def tf_map(f, xs): - """Map a function over leading array axes. - - See the docstring of `jax.lax.map` - (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html) for - details. - - Args: - f: a Python function to apply element-wise over the first axis or axes of - `xs`. - xs: values over which to map along the leading axis. - - Returns: - Mapped values. - """ - def g(unused, x): - return unused, f(x) - carry = tf.nest.map_structure(lambda _: None, xs) - return scan(g, carry, xs)[1] - - -def _get_dynamic_indices(operand, start_indices, slice_sizes): - """Calcuates the indices for `tf.gather_nd` from slices. - - Args: - operand: a Tensor to slice. - start_indices: a vector Tensor of integers, one per dimension. The starts of - the slice. The vector can be dynamic. - slice_sizes: a list of integers, one per dimension. The sizes of the slice. 
-
-  Returns:
-    An index array suitable for `tf.gather_nd` and `tf.scatter_nd`, or `None` if
-    `operand` is a scalar.
-  """
-  rank = len(slice_sizes)
-  operand_rank = tf.rank(operand)
-  tf.debugging.Assert(operand_rank == rank, [operand_rank, rank])
-  starts_rank = tf.rank(start_indices)
-  tf.debugging.Assert(starts_rank == 1, [starts_rank])
-  num_starts = tf.shape(start_indices)[0]
-  tf.debugging.Assert(num_starts == rank, [num_starts, rank])
-  operand_shape = tf.shape(operand)
-  tf.debugging.Assert(tf.reduce_all(slice_sizes <= operand_shape),
-                      [slice_sizes, operand_shape])
-  if rank == 0:
-    return None
-  start_indices = tf.where(
-      start_indices < 0, start_indices + operand_shape, start_indices)
-  idx_list = []
-  for i in range(rank):
-    start = start_indices[i]
-    size = slice_sizes[i]
-    dim = operand_shape[i]
-    start = tf.clip_by_value(start, 0, dim - size)
-    # XLA requires tf.range's `start` to be compile-time constant, so we can't
-    # do tf.range(start, ...).
-    idx = start + tf.range(size)
-    shape = [1] * rank
-    shape[i] = size
-    idx = tf.reshape(idx, shape)
-    idx_list.append(idx)
-  slice_sizes_tensor = tf.convert_to_tensor(slice_sizes)
-  # tf.stack doesn't support broadcasting, so we need to broadcast manually.
-  # TODO(wangpeng): Reduce peak memory by broadcasting one-by-one instead of
-  # all-together.
-  idx_list = [tf.broadcast_to(x, slice_sizes_tensor) for x in idx_list]
-  return tf.stack(idx_list, axis=-1)
-
-
-def dynamic_slice(operand, start_indices, slice_sizes):
-  """Slicing operation where the indices can be dynamic values.
-
-  See the docstring of `jax.lax.dynamic_slice`
-  (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html)
-  for details.
-
-  Args:
-    operand: an array to slice.
-    start_indices: a vector of integers, one per dimension. The starts of the
-      slice. The vector can be dynamic.
-    slice_sizes: a list of integers, one per dimension. The sizes of the slice.
-
-  Returns:
-    An array containing the slice, with shape equal to `slice_sizes`.
-  """
-  # This implementation uses tf.gather_nd to implement dynamic_slice, which is
-  # memory inefficient because the size of `indices` given to gather_nd is
-  # large.
-  operand = tf_np.asarray(operand).data
-  start_indices = tf_np.asarray(start_indices, np.int32).data
-  idx = _get_dynamic_indices(operand, start_indices, slice_sizes)
-  if idx is not None:
-    operand = tf.gather_nd(operand, idx)
-  return tf_np.asarray(operand)
-
-
-def dynamic_update_slice(operand, update, start_indices):
-  """Updates a dynamic slice.
-
-  See the docstring of `jax.lax.dynamic_update_slice`
-  (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_update_slice.html)
-  for details.
-
-  Args:
-    operand: an array to slice.
-    update: an array containing the new values to write onto `operand`.
-    start_indices: a vector of integers, one per dimension. The starts of the
-      slice. The vector can be dynamic.
-
-  Returns:
-    The updated version of `operand`.
- """ - operand = tf_np.asarray(operand).data - update = tf_np.asarray(update).data - start_indices = tf_np.asarray(start_indices, np.int32).data - if not update.shape.is_fully_defined(): - raise ValueError("update's shape must be fully defined") - slice_sizes = update.shape - idx = _get_dynamic_indices(operand, start_indices, slice_sizes) - if idx is None: - # `np.zeros([])[()] = 1.0` will result in a scalar array of 1.0 - return tf_np.asarray(update) - operand = tf.tensor_scatter_nd_update(operand, idx, update) - return tf_np.asarray(operand) - - -def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0): - """Convenience wrapper around dynamic_slice applying to one dimension.""" - operand = tf_np.asarray(operand) - start_indices = [0] * operand.ndim - slice_sizes = list(operand.shape) - axis = int(axis) - start_indices[axis] = start_index - slice_sizes[axis] = int(slice_size) - return dynamic_slice(operand, start_indices, slice_sizes) - - -def dynamic_update_slice_in_dim(operand, update, start_index, axis): - """Convenience wrapper around dynamic_update_slice for one dimension.""" - operand = tf_np.asarray(operand) - axis = int(axis) - start_indices = [0] * operand.ndim - start_indices[axis] = start_index - return dynamic_update_slice(operand, update, start_indices) - - -# Use int64 instead of int32 to avoid TF's "int32 problem" -_RNG_KEY_DTYPE = np.int64 - - -def _key2seed(a): - """Converts an RNG key to an RNG seed. - - Args: - a: an RNG key, an ndarray of shape [] and dtype `np.int64`. - - Returns: - an RNG seed, a tensor of shape [2] and dtype `tf.int32`. - """ - - def int64_to_int32s(a): - """Converts an int64 tensor of shape [] to an int32 tensor of shape [2].""" - a = tf.cast(a, tf.uint64) - fst = tf.cast(a, tf.uint32) - snd = tf.cast( - tf.bitwise.right_shift(a, tf.constant(32, tf.uint64)), tf.uint32) - a = [fst, snd] - a = tf.nest.map_structure(lambda x: tf.cast(x, tf.int32), a) - a = tf.stack(a) - return a - - return int64_to_int32s(a) - - -def _seed2key(a): - """Converts an RNG seed to an RNG key. - - Args: - a: an RNG seed, a tensor of shape [2] and dtype `tf.int32`. - - Returns: - an RNG key, an ndarray of shape [] and dtype `np.int64`. - """ - - def int32s_to_int64(a): - """Converts an int32 tensor of shape [2] to an int64 tensor of shape [].""" - a = tf.bitwise.bitwise_or( - tf.cast(a[0], tf.uint64), - tf.bitwise.left_shift( - tf.cast(a[1], tf.uint64), tf.constant(32, tf.uint64))) - a = tf.cast(a, tf.int64) - return a - - return tf_np.asarray(int32s_to_int64(a)) - - -def prng(s): - """Creates RNG state from seed. - - Args: - s: the seed, an integer. - - Returns: - An RNG state, as a scalar array of dtype `np.int64`. - """ - # TODO(wangpeng): Become bitwise-identical to JAX when TF stateless RNGs get - # improved. - return tf_np.asarray(s, dtype=_RNG_KEY_DTYPE) - - -def stateless_split(seed, num=2): - """Splits an RNG seed into `num` new seeds by adding a leading axis. - - Example: - - >>> seed = [1, 2] - >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3) - >>> print(new_seeds) - tf.Tensor( - [[1105988140 1738052849] - [-335576002 370444179] - [ 10670227 -246211131]], shape=(3, 2), dtype=int32) - >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :]) - - - Args: - seed: an RNG seed (a tensor with shape [2] and dtype `int32` or `int64`). - (When using XLA, only `int32` is allowed.) - num: optional, a positive integer or scalar tensor indicating the number of - seeds to produce (default 2). 
- - Returns: - A tensor with shape [num, 2] representing `num` new seeds. It will have the - same dtype as `seed` (if `seed` doesn't have an explict dtype, the dtype - will be determined by `tf.convert_to_tensor`). - """ - seed = tf.convert_to_tensor(seed) - return tf.random.stateless_uniform( - shape=[num, 2], seed=seed, dtype=seed.dtype, minval=None, maxval=None) - - -def split(state, num): - """Creates new independent RNG states from an existing state. - - Args: - state: the existing state. - num: the number of the new states. - - Returns: - A tuple of new states. - """ - state = tf_np.asarray(state, dtype=_RNG_KEY_DTYPE) - state = _key2seed(state) - try: - states = tf.random.experimental.stateless_split(state, num) - except AttributeError as e: # pylint: disable=unused-variable - # TODO(afrozm): For TF < 2.3 we need to do this. Delete once 2.3 launches. - states = stateless_split(state, num) - states = tf.unstack(states, num) - states = tf.nest.map_structure(_seed2key, states) - return states - - -def uniform(key, - shape, - dtype=tf_np.random.DEFAULT_RANDN_DTYPE, - minval=0., - maxval=1.): - """Sample uniform random values in range [`minval`, `maxval`). - - Args: - key: the RNG key. - shape: the shape of the result. - dtype: the dtype of the result. - minval: the minimal value (inclusive). - maxval: the maximal value (exclusive). - - Returns: - An ndarray with shape `shape` and dtype `dtype`. Each value in the ndarray - is sampled uniformly randomly in range [`minval`, `maxval`). - """ - minval = tf.cast(minval, dtype) - maxval = tf.cast(maxval, dtype) - key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) - return tf_np.asarray( - tf.random.stateless_uniform( - shape, seed=_key2seed(key), dtype=dtype, minval=minval, - maxval=maxval)) - - -def normal(key, shape, dtype=tf.float32): - """Sample standard-normal random values. - - Args: - key: the RNG key. - shape: the shape of the result. - dtype: the dtype of the result. - - Returns: - Random values in standard-normal distribution. - """ - key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE) - return tf_np.asarray( - tf.random.stateless_normal(shape, seed=_key2seed(key), dtype=dtype)) - - -def bernoulli(key, mean=np.float32(0.5), shape=None): - """Sample Bernoulli random values with given shape and mean. - - Args: - key: the RNG key. - mean: optional, an array_like broadcastable to `shape` for the mean of the - random variables (default 0.5). - shape: optional, a tuple of nonnegative integers representing the shape - (default to `mean`'s shape). - - Returns: - A random array with the specified shape and boolean dtype. - """ - mean = tf_np.asarray(mean) - if shape is None: - shape = mean.shape - return uniform(key, shape) < mean - - -def _eager_dataset_iterator(dataset): - for item in dataset: - yield tf.nest.map_structure(tf_np.asarray, item) - - -def dataset_as_numpy(dataset): - """Converts a `tf.data.Dataset` to an iterable of ndarrays. - - `dataset_as_numpy` converts a possibly nested structure of `tf.data.Dataset`s - and `tf.Tensor`s to iterables of ndarrays and ndarrays, respectively. This - function must be run in eager mode outside tf.function. - - Args: - dataset: a possibly nested structure of `tf.data.Dataset`s and/or - `tf.Tensor`s. - - Returns: - A structure matching `dataset` where `tf.data.Dataset`s are converted to - generators of ndarrays and `tf.Tensor`s are converted to ndarrays. 
- """ - if not tf.executing_eagerly(): - raise ValueError( - "dataset_as_numpy must be run in eager mode outside tf.function") - nested_ds = dataset - del dataset - - # Flatten - flat_ds = tf.nest.flatten(nested_ds) - flat_np = [] - - # Type check for Tensors and Datasets - for ds_el in flat_ds: - if not isinstance(ds_el, (tf.Tensor, tf.data.Dataset)): - types = tf.nest.map_structure(type, nested_ds) - raise ValueError("Arguments to dataset_as_numpy must be (possibly nested " - "structure of) tf.Tensors or tf.data.Datasets. Got: %s" % - types) - - for ds_el in flat_ds: - if isinstance(ds_el, tf.Tensor): - np_el = tf_np.asarray(ds_el) - elif isinstance(ds_el, tf.data.Dataset): - np_el = _eager_dataset_iterator(ds_el) - else: - assert False - flat_np.append(np_el) - - return tf.nest.pack_sequence_as(nested_ds, flat_np) - - -# TODO(nareshmodi): Group key should change based on the set of devices that we -# are mapping over. Make it so that we assign a unique group_key for every -# unique set of devices. We don't change it every time to avoid the overhead of -# discovering the full group (though may not be problematic in the local case). -_GROUP_KEY = 1 -_INSTANCE_KEY = 0 -_INSTANCE_LOCK = threading.Lock() - - -# TODO(b/142565636): Ensure that multiple concurrent calls to a tf.function -# containing a collective op run reasonably. -def _get_instance_key(): - global _INSTANCE_KEY - global _INSTANCE_LOCK - with _INSTANCE_LOCK: - _INSTANCE_KEY = _INSTANCE_KEY + 1 - return _INSTANCE_KEY - - -# Don't use a namedtuple since nest considers that a tuple and unflattens and -# flattens it. -class ShardedNdArray(object): - """Wrapper over ndarray that can contain tensors on multiple devices. - - This is returned by extensions.pmap, and contains the individual tensors on - different devices. - """ - - def __init__(self, tensors): - """Initializes the ShardedNdArray. - - Note that the tensors should be ordered in the way the pmap producing these - tensors is run. - - Args: - tensors: list or tuple of eager tensors, one for each device. - """ - - if not isinstance(tensors, (list, tuple)) or not tensors: - raise ValueError( - "Unable to create a ShardedNdArray without a list of tensors.") - self.tensors = tensors - self.n_devices = len(tensors) - - def __getitem__(self, i): - return tf_np.asarray(self.tensors[i]) - - @property - def shape(self): - return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access - - @property - def dtype(self): - return self.tensors[0].dtype - - -def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs): - del args, kwargs - # TODO(nareshmodi): Consider a collective op to gather the tensors from the - # various devices for performance reasons. 
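A usage sketch for `ShardedNdArray` as defined above (assumes the module above is in scope; shapes are illustrative):

import tensorflow as tf

shards = [tf.ones((3, 4)), tf.zeros((3, 4))]   # one tensor per device
sharded = ShardedNdArray(shards)
assert sharded.n_devices == 2
assert sharded.shape == (2, 3, 4)              # the device axis is prepended
# Once the conversion function below is registered, the shards can be
# stacked back into a single tensor via tf.convert_to_tensor(sharded).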
-  return tf.stack(value.tensors)
-
-
-tf.register_tensor_conversion_function(ShardedNdArray,
-                                        convert_sharded_tensor_to_eager_tensor)
-
-
-class _PmapConfig(threading.local):
-  """Simple config used to maintain state related to a current pmap call."""
-
-  def __init__(self):
-    super(_PmapConfig, self).__init__()
-    self._axis_name = None
-    self._devices = None
-
-  def axis_name(self):
-    return self._axis_name
-
-  def set_axis_name(self, axis_name):
-    self._axis_name = axis_name
-
-  def devices(self):
-    return self._devices
-
-  def set_devices(self, devices):
-    self._devices = devices
-
-
-_pmap_config = _PmapConfig()
-
-
-@contextlib.contextmanager
-def pmap_config(axis_name, devices):
-  """Records axis_name and devices for this context."""
-  old_axis_name = _pmap_config.axis_name()
-  old_devices = _pmap_config.devices()
-  _pmap_config.set_axis_name(axis_name)
-  _pmap_config.set_devices(devices)
-  try:
-    yield
-  finally:
-    _pmap_config.set_axis_name(old_axis_name)
-    _pmap_config.set_devices(old_devices)
-
-
-def _psum(tensor, axis_name=None):
-  """Sum all-reduction.
-
-  Args:
-    tensor: A tensor.
-    axis_name: The axis name to reduce. Must equal that of the surrounding
-      pmap.
-
-  Returns:
-    The sum of the `tensor` replicas on each participating device.
-  """
-  if axis_name != _pmap_config.axis_name():
-    raise ValueError("axis_name (%s) is not equal to that of the surrounding "
-                     "pmap (%s)" % (axis_name, _pmap_config.axis_name()))
-  devices = _pmap_config.devices()
-  if devices is None:
-    raise ValueError("Can't retrieve the device list from the surrounding pmap")
-  tensor = tf_np.asarray(tensor)
-  if tpu_devices(devices):
-    # TODO(b/170895907): Remove this workaround when tpu.cross_replica_sum
-    # supports int64/float64.
-    is_int64 = False
-    is_float64 = False
-    if tensor.dtype == np.int64:
-      is_int64 = True
-      tensor = tensor.astype(np.int32)
-    elif tensor.dtype == np.float64:
-      is_float64 = True
-      tensor = tensor.astype(np.float32)
-    # TODO(wangpeng): Supply the `group_assignment` argument to
-    # tpu.cross_replica_sum, calculated from `devices`.
-    tensor = tf.compat.v1.tpu.cross_replica_sum(tensor)
-    if is_int64:
-      tensor = tf.cast(tensor, tf.int64)
-    elif is_float64:
-      tensor = tf.cast(tensor, tf.float64)
-  else:
-    tensor = tf.raw_ops.CollectiveReduce(
-        input=tensor,
-        group_size=len(devices),
-        group_key=_GROUP_KEY,
-        instance_key=_get_instance_key(),
-        merge_op="Add",
-        final_op="Id",
-        subdiv_offsets=(0,))
-  return tf_np.asarray(tensor)
-
-
-def psum(tensors, axis_name=None):
-  return tf.nest.map_structure(
-      functools.partial(_psum, axis_name=axis_name), tensors)
-
-
-# Note this is not available in the jax api, but seemed like a reasonable API
-# to have.
-def pmean(tensor, axis_name=None):
-  """Mean all-reduction.
-
-  Args:
-    tensor: A tensor.
-    axis_name: The axis name to reduce. Must equal that of the surrounding
-      pmap.
-
-  Returns:
-    The mean of the `tensor` replicas on each participating device.
-  """
-  if axis_name != _pmap_config.axis_name():
-    raise ValueError("axis_name (%s) is not equal to that of the surrounding "
-                     "pmap (%s)" % (axis_name, _pmap_config.axis_name()))
-  devices = _pmap_config.devices()
-  if devices is None:
-    raise ValueError("Can't retrieve the device list from the surrounding pmap")
-  if tpu_devices(devices):
-    # TODO(wangpeng): Implement this.
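The all-reduce semantics of `psum`/`pmean` in a plain-Python reference sketch (illustrative only): every participating device contributes one replica and receives the same reduced value back.

def psum_reference(replicas):
  total = sum(replicas)
  return [total] * len(replicas)

def pmean_reference(replicas):
  total = sum(replicas)
  return [total / len(replicas)] * len(replicas)

assert psum_reference([1, 3]) == [4, 4]    # cf. testPsum below
assert pmean_reference([1, 3]) == [2, 2]   # cf. testPmean below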
-    raise ValueError("pmean for TPU is not supported yet.")
-  else:
-    return tf.raw_ops.CollectiveReduce(
-        input=tensor,
-        group_size=len(devices),
-        group_key=_GROUP_KEY,
-        instance_key=_get_instance_key(),
-        merge_op="Add",
-        final_op="Div",
-        subdiv_offsets=(0,))
-
-
-def _get_pmap_impl(f, devices, has_tpu):
-  """This is a helper function to return the pmap impl.
-
-  Args:
-    f: a function that takes ndarrays and returns ndarrays.
-    devices: a list of strings; the device list.
-    has_tpu: boolean; whether `devices` contains TPU devices.
-
-  Returns:
-    A function that takes tensors and returns tensors.
-  """
-  if has_tpu:
-    # Workaround b/121383831
-    output_is_list = [False]  # Use list for mutability
-    def recorder(args, kwargs, res):
-      del args, kwargs
-      output_is_list[0] = isinstance(res, list)
-      return res
-    f = _record_result_type(recorder, f)
-
-  def tf_f(*tf_args):
-    """A wrapper for `f` that takes/returns tensors."""
-    np_args = _tf_to_np(tf_args)
-    np_out = f(*np_args)
-    return np_out
-
-  if has_tpu:
-
-    @tf.function(autograph=False)
-    def fn(inputs):
-      # TODO(wangpeng): Supply the `device_assignment` argument to
-      # tpu.replicate, calculated from `devices`.
-      res = tf.compat.v1.tpu.replicate(tf_f, inputs)
-      # Workaround b/121383831
-      if (res and isinstance(res[0], list) and len(res[0]) == 1 and
-          not output_is_list[0]):
-        res = [x[0] for x in res]
-      return res
-
-    return fn
-  else:
-    # This is run in a tf.function so that the various underlying functions can
-    # be run in parallel.
-    # The trace happens on the client, so any devices should not depend on any
-    # side effects.
-
-    jit_tf_f = tf.function(tf_f, autograph=False)
-
-    @tf.function(autograph=False)
-    def fn(all_per_device_args):
-      """Multi-device function with calls placed on the correct device."""
-
-      results = []
-      for per_device_args, device in zip(all_per_device_args, devices):
-        with tf.device(device):
-          results.append(jit_tf_f(*per_device_args))
-      return results
-
-    return fn
-
-
-def pmap(f, axis_name=None, devices=None):
-  """Transforms a function into a multi-device function.
-
-  The semantics are similar to JAX's pmap.
-
-  Args:
-    f: The function to be converted.
-    axis_name: Used for nested pmap, which is not supported yet.
-    devices: The devices over which the returned function will run.
-
-  Returns:
-    A function that runs the underlying function `f` on `devices`. Its arguments
-    can be `ShardedNdArray`s, tensors or other Python objects, and its return
-    values are all `ShardedNdArray`s. If an input is a tensor, the length of its
-    first dimension must equal the number of devices, and the tensor will be
-    split along its first dimension among the devices. If an input is an
-    unknown Python object, it will be replicated among the devices.
-  """
-  if devices is None:
-    devices = accelerators()
-  if not isinstance(devices, (list, tuple)):
-    raise ValueError("Must pass a list or tuple of devices")
-  num_devices = len(devices)
-  if not num_devices:
-    raise ValueError("There must be at least 1 device")
-  has_tpu = bool(tpu_devices(devices))
-
-  pmap_fn = _get_pmap_impl(f, devices, has_tpu)
-
-  def wrapper(*args):
-    """Wrapper that wraps/unwraps args, retvals, and runs the function."""
-    if _pmap_config.devices() is not None:
-      raise ValueError("Found a surrounding pmap. Nested pmap is not supported "
-                       "yet.")
-    # TODO(wangpeng): Maybe we should use `asarray` to convert everything
-    # to ndarray first.
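To make the input contract in the `pmap` docstring concrete, a small sketch with hypothetical shapes (this mirrors the per-device `tf.gather` scattering in the wrapper body below):

import tensorflow as tf

devices = ["CPU:0", "CPU:1"]
batch = tf.ones((2, 20))               # leading dim must equal len(devices)
per_device = [tf.gather(batch, j) for j in range(len(devices))]
assert per_device[0].shape == (20,)    # each device receives one slice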
- - flattened_input_args = tf.nest.flatten(args) - flattened_per_device_args = [[] for _ in devices] - for arg in flattened_input_args: - if isinstance(arg, tf.Tensor): - # TODO(nareshmodi): Try and use the dynamic shape instead. - if (not arg.shape.rank) or arg.shape[0] != len(devices): - # TODO(nareshmodi): Fix this restriction - raise ValueError( - "Input tensors need to have a first dimension equal to " - "the number of devices; got tensor of shape %s and %s devices" % - (arg.shape, len(devices))) - # NOTE: Alternatively use tf.split, and place the split tensors on the - # appropriate device. The best solution for this is to have an API that - # splits a tensor across devices. - for j, device in enumerate(devices): - updated_arg = tf.gather(arg, j) - # TODO(wangpeng): Investigate whether we need a tf.identity for TPU. - if not has_tpu: - with tf.device(device): - updated_arg = tf.identity(updated_arg) - flattened_per_device_args[j].append(updated_arg) - elif isinstance(arg, ShardedNdArray): - for device_args, tensor in zip(flattened_per_device_args, arg.tensors): - device_args.append(tensor) - else: - for device_args in flattened_per_device_args: - device_args.append(arg) - - all_per_device_args = [ - tf.nest.pack_sequence_as(args, device_args) - for device_args in flattened_per_device_args - ] - - with pmap_config(axis_name, devices): - results = pmap_fn(all_per_device_args) - - # Rewrap things. This can probably be written better. - flattened_results = [tf.nest.flatten(result) for result in results] - final_tree = [] - - # TODO(nareshmodi): assert all items in flattened_results have the same - # structures - - for i in range(len(flattened_results[0])): - tensors = [] - for j, device in enumerate(devices): - assert isinstance( - flattened_results[j][i], - tf.Tensor), ("currently only tensor return items are supported") - tensors.append(flattened_results[j][i]) - final_tree.append(ShardedNdArray(tensors)) - - return tf.nest.pack_sequence_as(results[0], final_tree) - - return wrapper - - -def find_devices(device_type, devices=None): - if not devices: - devices = [d.name for d in tf.config.experimental.list_logical_devices()] - devices = [(d, tf.DeviceSpec.from_string(d)) for d in devices] - results = [name for name, d in devices if d.device_type == device_type] - return results - - -def tpu_devices(devices=None): - """Gets TPU devices out of `devices`. - - Args: - devices: A device list (as a list of strings). If None, the list of all - available devices will be used for it. - - Returns: - Those in `devices` that are TPUs. - """ - return find_devices("TPU", devices) - - -def gpu_devices(devices=None): - """Gets GPU devices out of `devices`. - - Args: - devices: A device list (as a list of strings). If None, the list of all - available devices will be used for it. - - Returns: - Those in `devices` that are GPUs. 
-  """
-  return find_devices("GPU", devices)
-
-
-def accelerators(devices=None):
-  return tpu_devices(devices) or gpu_devices(devices)
-
-
-def _tree_broadcast(to, s):
-  """Broadcasts `s` to the nested structure `to`."""
-  if not isinstance(to, (list, tuple, dict)):
-    if not isinstance(s, (int, type(None))):
-      raise ValueError
-    return s
-  if isinstance(s, (int, type(None))):
-    return tf.nest.map_structure(lambda x: s, to)
-  if isinstance(to, (list, tuple)):
-    if len(to) != len(s):
-      raise ValueError
-    new_s = [_tree_broadcast(x, y) for x, y in zip(to, s)]
-    if isinstance(to, tuple):
-      new_s = tuple(new_s)
-    return new_s
-  elif isinstance(to, dict):
-    return {k: _tree_broadcast(to[k], s[k]) for k in to}
-  else:
-    raise TypeError("Unsupported type %s" % type(to))
-
-
-def vmap(f, in_axes=0, out_axes=0):
-  """Returns a function that maps `f` over the first dimension of inputs."""
-  in_axes_flat = tf.nest.flatten(in_axes)
-  if not all(isinstance(l, (type(None), int))
-             for l in in_axes_flat):
-    raise TypeError(
-        "vmap in_axes must be an int, None, or (nested) container with "
-        "those types as leaves, but got {}.".format(in_axes))
-  if all(isinstance(l, type(None)) for l in in_axes_flat):
-    raise ValueError("vmap must have at least one non-None value in in_axes")
-
-  out_axes_flat = tf.nest.flatten(out_axes)
-  if not all(isinstance(l, (type(None), int))
-             for l in out_axes_flat):
-    raise TypeError(
-        "vmap out_axes must be an int, None, or (nested) container with "
-        "those types as leaves, but got {}.".format(out_axes))
-
-  def _f(*args):
-    flat_args = tf.nest.flatten(args)
-    try:
-      f_in_axes = _tree_broadcast(args, in_axes)
-    except ValueError:
-      six.reraise(
-          ValueError,
-          ValueError(
-              "vmap in_axes specification must be a tree prefix of the "
-              r"corresponding value, got specification %s for value tree %s" % (
-                  in_axes, args)),
-          sys.exc_info()[2])
-    f_in_axes_flat = tf.nest.flatten(f_in_axes)
-
-    def tf_f(tf_args):
-      """Function passed to tf.vectorized_map call."""
-      # Note that unbatched arguments are not passed to tf_f. Here we fill those
-      # arguments back before calling `f`.
-      tf_flat_args = []
-      j = 0
-      for arg, axis in zip(flat_args, f_in_axes_flat):
-        if axis is None:
-          tf_flat_args.append(arg)
-        else:
-          tf_flat_args.append(tf_args[j])
-          j += 1
-      unbatched_args = tf.nest.pack_sequence_as(args, tf_flat_args)
-      return f(*unbatched_args)
-
-    # Constructs arguments to pass to `tf_f`.
-    # Unbatched arguments are skipped. Arguments with non-zero axis are
-    # transposed.
-    tf_args = []
-    for arg, axis in zip(flat_args, f_in_axes_flat):
-      if axis is None:
-        continue
-      arg = tf_np.asarray(arg)
-      if axis != 0:
-        arg = tf_np.moveaxis(arg, axis, 0)
-      tf_args.append(arg)
-    # TODO(agarwal): consider creating a tf.function outside of _f and reusing
-    # that to avoid overheads of re-vectorizing the code when running eagerly.
-    outputs = tf.vectorized_map(tf_f, tf_args)
-    try:
-      f_out_axes = _tree_broadcast(outputs, out_axes)
-    except ValueError:
-      six.reraise(
-          ValueError,
-          ValueError(
-              "vmap out_axes specification must be a tree prefix of the "
-              r"corresponding value, got specification %s for value tree %s" % (
-                  out_axes, outputs)),
-          sys.exc_info()[2])
-
-    def map_output(x, axis):
-      """Maps output of tf.vectorized_map to the final output."""
-      x = tf_np.asarray(x)
-      if axis is None:
-        # Note that `tf.vectorized_map` always batches the outputs.
-        # Here we unbatch it again.
-        return x[0, ...]
-      elif axis == 0:
-        return x
-      else:
-        # Need to transpose the output.
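The `in_axes` handling above leans on `_tree_broadcast`; a small sketch of its behavior (assumes the definitions above are in scope):

args = ((1, 2), [3, 4])
# A single int is replicated over the whole argument structure:
assert _tree_broadcast(args, 0) == ((0, 0), [0, 0])
# A tree prefix is broadcast leaf-wise; None marks unbatched leaves:
assert _tree_broadcast(args, (0, [1, None])) == ((0, 0), [1, None])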
- return tf_np.moveaxis(x, 0, axis) - new_outputs = [map_output(output, axis) for output, axis in zip( - tf.nest.flatten(outputs), tf.nest.flatten(f_out_axes))] - return tf.nest.pack_sequence_as(outputs, new_outputs) - - return _f diff --git a/trax/tf_numpy/extensions/extensions_test.py b/trax/tf_numpy/extensions/extensions_test.py deleted file mode 100644 index 065dbdeef..000000000 --- a/trax/tf_numpy/extensions/extensions_test.py +++ /dev/null @@ -1,1060 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tf numpy mathematical methods.""" -import functools -import itertools - -from absl import flags -from absl.testing import parameterized - -from jax import lax -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy import extensions -import trax.tf_numpy.numpy as tf_np - - -FLAGS = flags.FLAGS - -flags.DEFINE_bool("requires_tpu", False, "Requires TPU.") - - -def generate_params_inputs_targets(num_examples=1000): - params = (tf_np.asarray(tf.constant(5.)), tf_np.asarray(tf.constant(0.))) - - params_true = (tf_np.asarray(tf.constant(3.)), tf_np.asarray(tf.constant(2.))) - - inputs = tf_np.asarray(tf.random.normal(shape=[num_examples])) - noise = tf_np.asarray(tf.random.normal(shape=[num_examples])) - targets = inputs * params_true[0] + params_true[1] + noise - - return params, params_true, inputs, targets - - -def loss_fn(params, inputs, targets): - predicted = params[0] * inputs + params[1] - loss = tf.reduce_mean(input_tensor=tf.square(predicted - targets)) - return tf_np.asarray(loss) - - -def train_step(params, inputs, targets, learning_rate=0.1): - grad_fn = extensions.grad(loss_fn) - grads = grad_fn(params, inputs, targets) - new_w = params[0] - (grads[0] * learning_rate) - new_b = params[1] - (grads[1] * learning_rate) - - return new_w, new_b - - -def uniform(rng, shape, dtype): - if np.issubdtype(dtype, np.integer): - minval = None - else: - minval = 0 - return tf_np.asarray(rng.uniform(shape=shape, dtype=dtype, minval=minval)) - - -def to_np(a): - return tf.nest.map_structure(tf_np.asarray, a) - - -def to_tf_fn(f): - return lambda *args: f(*to_np(args)) - - -def scan_reference(f, init, xs): - carry = init - ys = [] - for x in xs: - (carry, y) = f(carry, x) - ys.append(tf_np.reshape(y, (1,) + y.shape)) - ys = tf_np.concatenate(ys, 0) - return carry, ys - - -def spec(*args): - return tf.TensorSpec(args, tf.float32) - - -class ExtensionsTest(tf.test.TestCase, parameterized.TestCase): - - def __init__(self, methodName="runTest"): # pylint: disable=invalid-name - super().__init__(methodName) - physical_devices = tf.config.experimental.list_physical_devices("CPU") - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], [ - tf.config.experimental.VirtualDeviceConfiguration(), - tf.config.experimental.VirtualDeviceConfiguration() - ]) - if extensions.tpu_devices(): - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local") - 
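For orientation, how the test helpers above fit together: a few gradient steps on the synthetic linear-regression data, which is exactly what `testGradSimpleModel` below exercises (illustrative, assumes the helpers above are in scope).

params, params_true, inputs, targets = generate_params_inputs_targets()
for _ in range(50):
  params = train_step(params, inputs, targets)
# params should now approximate params_true = (3., 2.) to roughly 1e-1.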
tf.tpu.experimental.initialize_tpu_system(resolver) - - def _hasGPU(self): - physical_devices = tf.config.experimental.list_physical_devices("GPU") - return physical_devices - - def testCustomGrad(self): - """Test for custom_grad.""" - x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) - y_shape = (tf.TensorShape([])) - dtype = np.float32 - scale1 = 5.0 - scale2 = 6.0 - - def fwd(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - @extensions.custom_grad - def f(a, b): - y = fwd(a, b) - - def vjp(dy): - return dy * scale1 * a, dy * scale2 * b - - return y, vjp - - rng = tf.random.Generator.from_seed(1234) - x, dy = tf.nest.map_structure(lambda shape: uniform(rng, shape, dtype), - [x_shape, y_shape]) - expected_y = fwd(*x) - expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1]) - y, vjp = extensions.vjp(f, *x) - dx = vjp(dy) - self.assertAllClose(expected_y, y) - self.assertAllClose(expected_dx, dx) - - @parameterized.named_parameters([ - ( # pylint: disable=g-complex-comprehension - ("_%s_%s_%s" % (decorator_id, x_struct, y_struct)).replace( - " ", "").replace("None", ""), decorator, x_struct, y_struct) - for y_struct in [[None, ()], (None, (), [], (None, ((), None)))] - for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))] - for decorator_id, decorator in enumerate([lambda f: f, extensions.jit]) - ]) - def testCustomGradStructure(self, decorator, x_struct, y_struct): - """Tests that custom_grad can handle structured inputs/outputs.""" - - def zeros(x): - return tf.nest.map_structure(lambda _: tf_np.zeros([], np.float32), x) - - def get_struct(x): - return tf.nest.map_structure(lambda _: None, x) - - @extensions.custom_grad - def f(*x): - del x - - def vjp(dy): - self.assertEqual(y_struct, get_struct(dy)) - return zeros(x_struct) - - return zeros(y_struct), vjp - - x, dy = zeros([x_struct, y_struct]) - - @decorator - def run(x, dy): - y, vjp = extensions.vjp(f, *x) - dx = vjp(dy) - return dx, y - - dx, y = run(x, dy) - self.assertEqual(x_struct, get_struct(dx)) - self.assertEqual(y_struct, get_struct(y)) - - @parameterized.named_parameters([ - ("_%s" % has_aux, has_aux) for has_aux in [True, False] - ]) - def testVjp(self, has_aux): - x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10])) - y_shape = (tf.TensorShape([])) - dtype = np.float32 - - def f(a, b): - y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - if has_aux: - return y, tf_np.asarray(1) - else: - return y - - rng = tf.random.Generator.from_seed(1234) - x, dy_list = tf.nest.map_structure(lambda shape: uniform(rng, shape, dtype), - [x_shape, [y_shape] * 2]) - tf_x = x - outputs = extensions.vjp(f, *x, has_aux=has_aux) - if has_aux: - y, vjp, aux = outputs - else: - y, vjp = outputs - with tf.GradientTape(persistent=True) as tape: - tape.watch(tf_x) - outputs = f(*x) - if has_aux: - expected_y, expected_aux = outputs - self.assertAllClose(expected_aux, aux) - else: - expected_y = outputs - self.assertAllClose(expected_y, y) - for dy in dy_list: - expected_dx = tape.gradient( - expected_y, tf_x, output_gradients=dy) - self.assertAllClose(expected_dx, vjp(dy)) - - def testGrad(self): - - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - g = extensions.grad(f) - - def compare(a, b): - with tf.GradientTape() as tape: - tape.watch(a) - r = f(a, b) - expected = tape.gradient(r, a) - self.assertAllEqual(expected, g(a, b)) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - compare(a, b) - - def testGradNonArrayOutput(self): - - def f(_): - return 1.0 - - g 
= extensions.grad(f) - with self.assertRaisesWithPredicateMatch(ValueError, - r"result .* must be an ndarray"): - g(tf_np.asarray(1.0)) - - def testGradNonScalarOutput(self): - - def f(a): - return a - - g = extensions.grad(f) - with self.assertRaisesWithPredicateMatch(ValueError, - r"result .* must be a scalar"): - g(tf_np.asarray([1.0, 2.0])) - - @extensions.jit - def g_jitted(a): - return extensions.grad(f)(a) - - g_jitted(tf_np.asarray(1.0)) - with self.assertRaisesWithPredicateMatch(ValueError, - r"result .* must be a scalar"): - g_jitted(tf_np.asarray([1.0, 2.0])) - - def testJit(self): - - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_jitted = extensions.jit(f) - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - self.assertAllClose(f(a, b), f_jitted(a, b)) - # Call again since the code path is different on second call - self.assertAllClose(f(a, b), f_jitted(a, b)) - - def testJitNoUnnecessaryTracing(self): - - def num_traces(f): - return len(f.tf_function._list_all_concrete_functions_for_serialization()) - - def check_trace_only_once(arg1, arg2): - - @extensions.jit - def f(a): - return a + 1 - - self.assertAllEqual(0, num_traces(f)) - f(arg1) - self.assertAllEqual(1, num_traces(f)) - f(arg2) - self.assertAllEqual(1, num_traces(f)) - - check_trace_only_once(1, 2) - check_trace_only_once(1.1, 2.1) - check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2)) - check_trace_only_once( - tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2)) - - def _testEvalOnShapes(self, transformer, allow_static_outputs): - - # A class that's not convertable to tensor - class Thing: - - def __init__(self, value): - self.value = value - - def f(a, b, reverse=False): - res = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - res = (res, 10) - if allow_static_outputs: - res = res + (Thing(20),) - if reverse: - res = tuple(reversed(res)) - return res - - f_prime = transformer( - f, static_argnums=(2,), allow_static_outputs=allow_static_outputs) - shape = [10] - dtype = np.float16 - a = tf_np.zeros(shape=shape, dtype=dtype) - b = tf_np.zeros(shape=shape, dtype=dtype) - expected, *_ = f(a, b) - got = f_prime(a, b) - def check(got): - self.assertIsInstance(got[0], (tf.TensorSpec, tf_np.ndarray)) - self.assertAllEqual(expected.shape, got[0].shape) - self.assertAllEqual(expected.dtype, got[0].dtype) - if allow_static_outputs: - self.assertIsInstance(got[1], int) - self.assertEqual(10, got[1]) - self.assertIsInstance(got[2], Thing) - self.assertEqual(20, got[2].value) - else: - self.assertIsInstance(got[1], (tf.TensorSpec, tf_np.ndarray)) - self.assertAllEqual((), got[1].shape) - check(got) - # Call again since the code path is different on second call - got = f_prime(a, b) - check(got) - # Retrace and check again - got = f_prime(a, b, True) - check(tuple(reversed(got))) - got = f_prime(a, b, True) - check(tuple(reversed(got))) - - @parameterized.named_parameters(("_%s" % b, b) for b in [False, True]) - def testEvalOnShapes(self, allow_static_outputs): - self._testEvalOnShapes(extensions.eval_on_shapes, allow_static_outputs) - - def testEvalOnShapesNested(self): - transformer = functools.partial(extensions.eval_on_shapes, - allow_static_outputs=True) - @transformer - def outer(): - @transformer - def inner(): - return 1 - return inner() + 2 - r = outer() - self.assertIsInstance(r, int) - self.assertEqual(3, r) - - def testJitOfEvalOnShapes(self): - """Tests that eval_on_shapes can be called within jit.""" - - def transformer(f, **kwargs): - def f_prime(*args): - 
res = extensions.eval_on_shapes(f, **kwargs)(*args) - return tf.nest.map_structure( - lambda x: tf_np.zeros(x.shape, x.dtype), res) - return extensions.jit(f_prime, kwargs.get("static_argnums", ())) - - self._testEvalOnShapes(transformer, False) - - def testEvalOnShapesNoUnnecessaryTracing(self): - - def num_traces(f): - return len( - f._tf_function._list_all_concrete_functions_for_serialization()) - - def check_trace_only_once(arg1, arg2): - - @extensions.eval_on_shapes - def f(a): - return a + 1 - - self.assertAllEqual(0, num_traces(f)) - f(arg1) - self.assertAllEqual(1, num_traces(f)) - f(arg2) - self.assertAllEqual(1, num_traces(f)) - - check_trace_only_once(1, 2) - check_trace_only_once(1.1, 2.1) - check_trace_only_once(tf_np.asarray(1), tf_np.asarray(2)) - check_trace_only_once( - tf.convert_to_tensor(value=1), tf.convert_to_tensor(value=2)) - - @parameterized.parameters( - { - "lhs_np": np.ones((5, 3)), - "rhs_np": np.ones((3, 2)), - "dims": (((1,), (0,)), ((), ())) - }, - { - "lhs_np": np.ones((5, 3)), - "rhs_np": np.ones((5, 3)), - "dims": (((0, 1), (0, 1)), ((), ())) - }, - { - "lhs_np": np.ones((5, 3, 2)), - "rhs_np": np.ones((2, 3, 2)), - "dims": (((1, 2), (1, 0)), ((), ())) - }, - { - "lhs_np": np.ones((6, 5, 3)), - "rhs_np": np.ones((6, 3, 2)), - "dims": (((2,), (1,)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((6, 3, 5)), - "rhs_np": np.ones((6, 3, 2)), - "dims": (((1,), (1,)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((5, 3, 2, 2)), - "rhs_np": np.ones((5, 2, 2, 6)), - "dims": (((2, 3), (1, 2)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((2, 2, 5, 3)), - "rhs_np": np.ones((2, 2, 3, 2)), - "dims": (((3,), (2,)), ((0, 1), (0, 1))) - }, - { - "lhs_np": np.ones((2, 2, 5, 2)), - "rhs_np": np.ones((2, 2, 3, 2)), - "dims": (((3,), (1,)), ((0,), (0,))) - }, - { - "lhs_np": np.ones((2, 2, 5, 3, 3)), - "rhs_np": np.ones((2, 3, 2, 3, 2)), - "dims": (((4,), (1,)), ((0,), (0,))) - }, - ) - def test_tf_dot_general(self, lhs_np, rhs_np, dims): - ans = lax.dot_general(lhs_np, rhs_np, dims) - result = extensions.tf_dot_general(lhs_np, rhs_np, dims) - self.assertAllClose(result, np.array(ans)) - - @parameterized.named_parameters([ - ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}" # pylint: disable=g-complex-comprehension - "_lhs_dilation={}_rhs_dilation={}" - "_feature_group_count={}_batch_group_count={}_dims={}" - "_perms={}".format(lhs_shape, rhs_shape, - strides, padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, ",".join( - dimension_numbers), perms), - lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, dimension_numbers, perms) - for batch_group_count, feature_group_count in [(1, 1)] - for lhs_shape, rhs_shape in [ - ((b * batch_group_count, i * feature_group_count, 9, w), - (j * feature_group_count * batch_group_count, i, 4, 5)) - for w in [0, 10] - for b, i, j in itertools.product([2, 3], repeat=3)] - for strides in [(1, 1), (2, 1)] - for padding in ["SAME"] - for lhs_dilation, rhs_dilation in [ - (None, (1, 1)) - ] - for dimension_numbers, perms in [ - (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) - ]]) - def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides, - padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, - dimension_numbers, perms): - lhs_perm, rhs_perm = perms # permute to compatible shapes - - lhs = np.transpose(np.ones(lhs_shape), lhs_perm) - rhs = np.transpose(np.ones(rhs_shape), rhs_perm) - - jax_conv = lax.conv_general_dilated(lhs, 
rhs, strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers, - feature_group_count, - batch_group_count) - - tf_conv = extensions.tf_conv_general_dilated(lhs, rhs, strides, - padding, None, - lhs_dilation, rhs_dilation, - dimension_numbers, - feature_group_count, - batch_group_count) - - self.assertAllClose(tf_conv, tf_np.asarray(jax_conv)) - - def testConv(self): - y = extensions.conv( - np.ones([5, 320, 480, 3], dtype=np.float32), - np.ones([3, 4, 3, 11], dtype=np.float32), [1, 1], "SAME", - ("NHWC", "HWIO", "NHWC")) - self.assertAllClose(y.shape, [5, 320, 480, 11]) - self.assertAllClose( - y, - tf.nn.conv2d( - input=tf.ones([5, 320, 480, 3], dtype=tf.float32), - filters=tf.ones([3, 4, 3, 11], dtype=tf.float32), - strides=1, - padding="SAME")) - - def testAvgPool(self): - y = extensions.avg_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") - self.assertAllEqual( - y, - tf.nn.pool( - input=tf.ones([5, 320, 480, 3]), - window_shape=[3, 5], - pooling_type="AVG", - padding="VALID", - strides=[2, 3], - )) - - def testMaxPool(self): - y = extensions.max_pool(np.ones([5, 320, 480, 3]), [3, 5], [2, 3], "VALID") - self.assertAllEqual( - y, - tf.nn.pool( - input=tf.ones([5, 320, 480, 3]), - window_shape=[3, 5], - pooling_type="MAX", - padding="VALID", - strides=[2, 3], - )) - - def assertDTypesEqual(self, a, b): - get_dtype = lambda t: t.dtype - self.assertEqual(tf.nest.map_structure(get_dtype, a), - tf.nest.map_structure(get_dtype, b)) - - @parameterized.named_parameters( - (f"_{jit_scan}_{jit_f}", jit_scan, jit_f) # pylint: disable=g-complex-comprehension - for jit_f in [False, True] - for jit_scan in ["no", "no_xla", "xla_forced_compile"]) - def testScanImpl(self, jit_scan, jit_f): - rng = np.random.RandomState(0) - - d = rng.randn(2) - def f(c, a): - assert a.shape == (3,) - assert c.shape == (4,) - b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) + - tf_np.sum(tf_np.tan(d))) - c = tf_np.sin(c * b) - assert b.shape == () # pylint: disable=g-explicit-bool-comparison - return c, b - - if jit_f: - f = extensions.jit(f) - - if jit_scan == "no_xla": - scan = extensions.jit(extensions.scan, static_argnums=(0,)) - elif jit_scan == "xla_forced_compile": - scan = extensions.jit(extensions.scan, static_argnums=(0,), - xla_forced_compile=True) - else: - scan = extensions.scan - - xs = rng.randn(5, 3) - c = rng.randn(4) - - ans = scan(f, c, xs) - expected = scan_reference(f, c, xs) - if jit_scan == "xla_forced_compile": - # xla.compile doesn't preserve list-vs-tuple properly for the outputs, so - # we canonicalize them to lists here. 
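For intuition about what `scan` computes, a cumulative sum via the `scan_reference` spec defined earlier (an illustrative sketch): the carry threads through `xs` and the per-step outputs are stacked.

import numpy as np

carry, ys = scan_reference(lambda c, x: (c + x, c + x), np.float32(0.),
                           np.arange(1., 4., dtype=np.float32))
# carry == 6.0; ys == [1., 3., 6.]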
- expected = list(expected) - ans = list(ans) - self.assertDTypesEqual(expected, ans) - self.assertAllClose(expected, ans) - - def testScanStruct(self): - rng = np.random.RandomState(0) - - d = rng.randn(2) - def f(c_g_i, a_e_h): - c_g, i = c_g_i - c, g = c_g - a, e_h = a_e_h - e, h = e_h - assert a.shape == (3,) - assert e.shape == () # pylint: disable=g-explicit-bool-comparison - assert c.shape == (4,) - assert g.shape == (2,) - assert i is None - assert h is None - b = tf_np.cos(tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) + - tf_np.sum(tf_np.tan(d))) - f = tf_np.cos(a) - c = tf_np.sin(c * b) - g = tf_np.sin(g * b) - assert b.shape == () # pylint: disable=g-explicit-bool-comparison - assert f.shape == (3,) - return [(c, g), i], (b, [f, h]) - - xs = (rng.randn(5, 3), [rng.randn(5), None]) - init = [(rng.randn(4), rng.randn(2)), None] - - c_g_i, b_f_h = extensions.scan(f, init, xs) - self.assertIsInstance(c_g_i, list) - self.assertIsInstance(b_f_h, tuple) - c_g, i = c_g_i - c, g = c_g - self.assertIsInstance(c_g, tuple) - self.assertEqual((4,), c.shape) - self.assertEqual((2,), g.shape) - self.assertIsNone(i) - b, f_h = b_f_h - f, h = f_h - self.assertIsInstance(f_h, list) - self.assertEqual((5,), b.shape) - self.assertEqual((5, 3), f.shape) - self.assertIsNone(h) - - @parameterized.named_parameters( - (f"_{jit_grad}_{jit_scan}_{jit_f}", jit_grad, jit_scan, jit_f) # pylint: disable=g-complex-comprehension - for jit_f in [False, True] - for jit_scan in ["no", "no_xla", "xla_forced_compile"] - for jit_grad in ["no", "no_xla", "xla_forced_compile"]) - def testScanGrad(self, jit_grad, jit_scan, jit_f): - rng = np.random.RandomState(0) - - d = rng.randn(2) - def f(c, a): - assert a.shape == (3,) - assert c.shape == (4,) - b = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.sin(c)) + - tf_np.sum(tf_np.sin(d))) - c = tf_np.sin(c * b) - assert b.shape == () # pylint: disable=g-explicit-bool-comparison - return c, b - - if jit_f: - f = extensions.jit(f) - - if jit_scan == "no_xla": - scan = extensions.jit(extensions.scan, static_argnums=(0,)) - elif jit_scan == "xla_forced_compile": - # TODO(b/187107596): Remove `skipTest` - self.skipTest( - "Taking gradients of `jit(scan, experimental_compile=True)` triggers " - "'Support for TensorList crossing the XLA/TF boundary is not " - "implemented' error") - # `xla_forced_compile=True` doesn't support gradients, so we use - # `experimental_compile=True`. 
- scan = extensions.jit(extensions.scan, static_argnums=(0,), - experimental_compile=True) - else: - scan = extensions.scan - - xs = tf_np.asarray(rng.randn(5, 3)) - c = tf_np.asarray(rng.randn(4)) - - def losses(scan, c, xs): - c, ys = scan(f, c, xs) - return tf_np.concatenate(tf.nest.flatten(tf.nest.map_structure( - lambda a: tf_np.reshape(a, [-1]), (c, ys)))) - def loss(scan, c, xs): - return tf_np.sum(losses(scan, c, xs)) - - def grad_origin(c, xs): - return extensions.grad(functools.partial(loss, scan))(c, xs) - - if jit_grad == "no_xla": - grad_jit = extensions.jit(grad_origin) - elif jit_grad == "xla_forced_compile": - grad_jit = extensions.jit(grad_origin, xla_forced_compile=True) - else: - grad_jit = grad_origin - - ans = grad_jit(c, xs) - expected = extensions.grad(functools.partial(loss, scan_reference))(c, xs) - self.assertDTypesEqual(expected, ans) - self.assertAllClose(expected, ans) - - theoretical, numerical = tf.test.compute_gradient( - to_tf_fn(functools.partial(losses, scan)), (c, xs)) - self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=3e-4) - - @parameterized.named_parameters( - (f"_{i}", *args) # pylint: disable=g-complex-comprehension - for i, args in enumerate([ - (lambda c, x: (c + 1, tf_np.sum(c + x, 0)), - [spec(2), spec(4, 3, 2)], [spec(2), spec(4, 2)]), - (lambda c, x: (c + 1, tf_np.sum(c + x, 0)), - [spec(2), spec(0, 3, 2), 0], [spec(2), spec(0, 2)]), - ])) - def testScanShape(self, f, inputs, expected_outputs): - outputs = extensions.eval_on_shapes( - functools.partial(extensions.scan, f), static_argnums=(2,))(*inputs) - self.assertAllEqual(expected_outputs, outputs) - - def testMap(self): - shape = [2, 3] - dtype = tf_np.int32 - xs1 = tf_np.zeros(shape, dtype) - xs2 = tf_np.ones(shape, dtype) - ys_expected = [xs2 + 10, xs1 + 20] - def f(x): - self.assertIsInstance(x, tuple) - for a in x: - self.assertEqual(a.shape, shape[1:]) - x1, x2 = x - return [x2 + 10, x1 + 20] - ys = extensions.tf_map(f, (xs1, xs2)) - self.assertIsInstance(ys, list) - self.assertAllClose(ys, ys_expected) - - def testPrng(self): - self.assertAllEqual(tf_np.asarray(123, np.int64), extensions.prng(123)) - - def testUniform(self): - minval = 0.43 - maxval = 3.10 - shape = [13, 34, 29] - atol = 0.1 - outputs = extensions.uniform(123, shape, minval=minval, maxval=maxval) - self.assertAllClose((minval + maxval) / 2.0, np.mean(outputs), atol=atol) - - def testNormal(self): - shape = [13, 34, 29] - atol = 0.1 - outputs = extensions.normal(123, shape) - self.assertAllClose(0, np.mean(outputs), atol=atol) - self.assertAllClose(1, np.std(outputs), atol=atol) - - def testBernoulli(self): - mean = 0.23 - shape = [13, 34, 29] - atol = 0.1 - outputs = extensions.bernoulli(123, mean, shape) - self.assertAllClose(mean, np.mean(outputs), atol=atol) - - def testBernoulliWrongShape(self): - mean = [0.1, 0.2] - shape = [3] - with self.assertRaisesIncompatibleShapesError(): - extensions.bernoulli(123, mean, shape) - - def testDatasetAsNumpy(self): - arrs = extensions.dataset_as_numpy( - [tf.constant([1, 2]), tf.constant([3, 4])]) - for a in arrs: - self.assertIsInstance(a, tf_np.ndarray) - with self.assertRaisesWithPredicateMatch( - ValueError, - r"dataset_as_numpy must be run in eager mode outside tf.function"): - - @tf.function - def f(): - return extensions.dataset_as_numpy([tf.constant([1, 2])]) - - f() - - def _get_two_devices(self, require_same_type=False): - tpus = extensions.tpu_devices() - if FLAGS.requires_tpu: - if len(tpus) == 2: - res = tpus - else: - raise ValueError("This test 
requires 2 TPU cores but %s are found" % - len(tpus)) - else: - if len(tpus) == 2: - res = tpus - elif self._hasGPU() and not require_same_type: - res = ("CPU:0", "GPU:0") - else: - res = ("CPU:0", "CPU:1") - return res - - def testPmap(self): - devices = self._get_two_devices() - - @functools.partial(extensions.pmap, devices=devices) - def return_three(f): - return f, f + 1.0, f + 2.0 - - result = return_three(tf.ones((2, 20))) - # The function returned 3 items, so we got 3 items back. - self.assertLen(result, 3) - - # Each of the items should be a ShardedNdarray that when converted to tensor - # should produce a tensor of shape (2, 20) - converted = tf.nest.map_structure(tf.convert_to_tensor, result) - - self.assertLen(result, 3) - - self.assertAllEqual(converted[0].shape, converted[1].shape) - self.assertAllEqual(converted[0].shape, converted[2].shape) - - self.assertAllEqual(converted[0], tf.ones((2, 20))) - self.assertAllEqual(converted[1], 1 + tf.ones((2, 20))) - self.assertAllEqual(converted[2], 2 + tf.ones((2, 20))) - - @functools.partial(extensions.pmap, devices=devices) - def return_one(f): - return f + 2.0 - - result = return_one(tf.ones((2, 20))) - - # Only a single item is returned, so we can convert it directly. - converted = tf.convert_to_tensor(value=result) - self.assertAllEqual(converted, 2 + tf.ones((2, 20))) - - @functools.partial(extensions.pmap, devices=devices) - def return_list(f): - return [f + 2.0] - - result = return_list(tf.ones((2, 20))) - - # A singleton list is returned. - self.assertLen(result, 1) - converted = tf.convert_to_tensor(value=result[0]) - self.assertAllEqual(converted, 2 + tf.ones((2, 20))) - - def testGradSimpleModel(self): - params, params_true, inputs, targets = generate_params_inputs_targets() - - for _ in range(50): - params = train_step(params, inputs, targets) - - # This is not trained super well, but it usually gets "close". - self.assertAllClose(params[0], params_true[0], atol=1e-1) - self.assertAllClose(params[1], params_true[1], atol=1e-1) - - # NOTE: Compare to testGradSimpleModel to see the differences when pmapping. - def testPmapSimpleModel(self): - devices = self._get_two_devices(require_same_type=True) - n_devices = len(devices) - - params, params_true, inputs, targets = generate_params_inputs_targets() - - def _train_and_reduce(params, inputs, targets, learning_rate=0.1): - new_w, new_b = train_step(params, inputs, targets, learning_rate) - - return (extensions.psum(new_w) / n_devices, - extensions.psum(new_b) / n_devices) - - train_step_pmapped = extensions.pmap(_train_and_reduce, devices=devices) - - def replicate(x, num_devices=2): - return tf_np.broadcast_to(x, (num_devices,) + x.shape) - - params = tf.nest.map_structure(replicate, params) - - def reshape(x, num_devices=2): - x_shape = list(x.shape) - batch_size = x_shape[0] - batch_size_per_device = batch_size // num_devices - - # New shape. - new_shape_prefix = [num_devices, batch_size_per_device] - return tf_np.reshape(x, new_shape_prefix + x_shape[1:]) - - inputs = tf.nest.map_structure(reshape, inputs) - targets = tf.nest.map_structure(reshape, targets) - - for _ in range(50): - params = train_step_pmapped(params, inputs, targets) - - # PMAP returns sharded tensors. - - # Since the inputs are identical, the returned tensors should be identical - self.assertAllClose(params[0][0], params[0][1]) - self.assertAllClose(params[1][0], params[1][1]) - - # This is not trained super well, but it usually gets "close". 
- self.assertAllClose(params[0][0], params_true[0], atol=1e-1) - self.assertAllClose(params[1][0], params_true[1], atol=1e-1) - - def testPsum(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(f): - return extensions.psum(f) - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - pmapped = extensions.pmap(reduce_sum, devices=devices) - result = pmapped(data) - - self.assertAllClose(result[0], 4) - self.assertAllClose(result[1], 4) - - def testPsumStruct(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(a): - a = extensions.psum(a) - tf.nest.map_structure( - lambda x: self.assertIsInstance(x, tf_np.ndarray), a) - return a - - data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)] - pmapped = extensions.pmap(reduce_sum, devices=devices) - result = pmapped(data) - - self.assertIsInstance(result[0][0], tf_np.ndarray) - self.assertIsInstance(result[0][1], tf_np.ndarray) - self.assertIsInstance(result[1][0], tf_np.ndarray) - self.assertIsInstance(result[1][1], tf_np.ndarray) - self.assertAllClose(result[0][0], 4) - self.assertAllClose(result[0][1], 4) - self.assertAllClose(result[1][0], 6) - self.assertAllClose(result[1][1], 6) - - def testPmean(self): - if extensions.tpu_devices(): - self.skipTest("pmean for TPU is not supported yet") - devices = self._get_two_devices(require_same_type=True) - - def reduce_mean(f): - return extensions.pmean(f) - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - pmapped = extensions.pmap(reduce_mean, devices=devices) - result = pmapped(data) - - self.assertAllClose(result[0], 2) - self.assertAllClose(result[1], 2) - - def testAxisName(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(f): - return extensions.psum(f, axis_name="foo") - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) - pmapped(data) - - def testWrongAxisName(self): - devices = self._get_two_devices(require_same_type=True) - - def reduce_sum(f): - return extensions.psum(f, axis_name="bar") - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - with self.assertRaisesWithPredicateMatch( - ValueError, r"axis_name (.*) is not equal to that of the surrounding"): - pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices) - pmapped(data) - - def testNoNestedPmap(self): - devices = self._get_two_devices(require_same_type=True) - - def f(x): - return x + 1.0 - - data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3])) - with self.assertRaisesWithPredicateMatch(ValueError, - r"Nested pmap is not supported"): - f = extensions.pmap(f, devices=devices) - f = extensions.pmap(f, devices=devices) - f(data) - - def testVmap(self): - fn1 = extensions.vmap(lambda z: z * z) - - x = tf_np.arange(10) - self.assertAllClose(x * x, fn1(x)) - - y = tf.range(10) - np_y = tf_np.asarray(y) - output = fn1(y) - self.assertIsInstance(output, tf_np.ndarray) - self.assertAllClose(np_y * np_y, output) - - fn2 = extensions.vmap(lambda x, y: x + y) - x = tf_np.random.randn(10, 3) - y = tf_np.random.randn(10, 2, 3) - self.assertAllClose(tf_np.expand_dims(x, 1) + y, fn2(x, y)) - - def testRemat(self): - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.grad(f_remat)(a, b) - expected = extensions.grad(f)(a, b) - self.assertAllClose(actual, expected) - - 
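On `remat`: it trades compute for memory by recomputing `f`'s intermediate activations during the backward pass instead of storing them. A rough TF-level analogue of the pattern exercised above, using `tf.recompute_grad` (an analogy, not the `extensions.remat` implementation):

import tensorflow as tf

@tf.recompute_grad
def f(a, b):
  return tf.reduce_sum(tf.sqrt(tf.exp(a)) + b)

a = tf.Variable(tf.random.normal([10]))
b = tf.Variable(tf.random.normal([10]))
with tf.GradientTape() as tape:
  loss = f(a, b)
grads = tape.gradient(loss, [a, b])  # same values as without recompute_grad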
def testRematLambdaFunction(self): - f = lambda a, b: tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.grad(f_remat)(a, b) - expected = extensions.grad(f)(a, b) - self.assertAllClose(actual, expected) - - def testRematJit(self): - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.jit(extensions.grad(f_remat))(a, b) - expected = extensions.jit(extensions.grad(f))(a, b) - self.assertAllClose(actual, expected) - - def testRematJitXla(self): - def f(a, b): - return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b) - - f_remat = extensions.remat(f) - - shape = [10] - a = tf_np.random.randn(*shape) - b = tf_np.random.randn(*shape) - - actual = extensions.jit( - extensions.grad(f_remat), xla_forced_compile=True)(a, b) - expected = extensions.jit(extensions.grad(f), xla_forced_compile=True)(a, b) - self.assertAllClose(actual, expected) - - actual = extensions.jit( - extensions.grad(f_remat), experimental_compile=True)(a, b) - expected = extensions.jit( - extensions.grad(f), experimental_compile=True)(a, b) - self.assertAllClose(actual, expected) - - def testStaticStopGradient(self): - self.assertEqual(extensions.stop_gradient(5.), 5.) - self.assertEqual(type(extensions.stop_gradient(5.)), type(5.)) - - self.assertEqual(extensions.stop_gradient(tf_np.asarray(5.)), 5.) - self.assertNotEqual( - type(extensions.stop_gradient(tf_np.asarray(5.))), type(5.)) - - -if __name__ == "__main__": - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/jax_tests/config.py b/trax/tf_numpy/jax_tests/config.py deleted file mode 100644 index 5da9f1b1e..000000000 --- a/trax/tf_numpy/jax_tests/config.py +++ /dev/null @@ -1,151 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -def bool_env(varname: str, default: bool) -> bool: - """Read an environment variable and interpret it as a boolean. - - True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1'; - false values are 'n', 'no', 'f', 'false', 'off', and '0'. 
- - Args: - varname: the name of the variable - default: the default boolean value - Raises: ValueError if the environment variable is anything else. - """ - val = os.getenv(varname, str(default)) - val = val.lower() - if val in ('y', 'yes', 't', 'true', 'on', '1'): - return True - elif val in ('n', 'no', 'f', 'false', 'off', '0'): - return False - else: - raise ValueError("invalid truth value %r for environment %r" % (val, varname)) - - -class Config(object): - def __init__(self): - self.values = {} - self.meta = {} - self.FLAGS = NameSpace(self.read) - self.use_absl = False - - def update(self, name, val): - if self.use_absl: - setattr(self.absl_flags.FLAGS, name, val) - else: - self.check_exists(name) - if name not in self.values: - raise Exception("Unrecognized config option: {}".format(name)) - self.values[name] = val - - def read(self, name): - if self.use_absl: - return getattr(self.absl_flags.FLAGS, name) - else: - self.check_exists(name) - return self.values[name] - - def add_option(self, name, default, opt_type, meta_args, meta_kwargs): - if name in self.values: - raise Exception("Config option {} already defined".format(name)) - self.values[name] = default - self.meta[name] = (opt_type, meta_args, meta_kwargs) - - def check_exists(self, name): - if name not in self.values: - raise Exception("Unrecognized config option: {}".format(name)) - - def DEFINE_bool(self, name, default, *args, **kwargs): - self.add_option(name, default, bool, args, kwargs) - - def DEFINE_integer(self, name, default, *args, **kwargs): - self.add_option(name, default, int, args, kwargs) - - def DEFINE_string(self, name, default, *args, **kwargs): - self.add_option(name, default, str, args, kwargs) - - def DEFINE_enum(self, name, default, *args, **kwargs): - self.add_option(name, default, 'enum', args, kwargs) - - def config_with_absl(self): - # Run this before calling `app.run(main)` etc - import absl.flags as absl_FLAGS - from absl import app, flags as absl_flags - - self.use_absl = True - self.absl_flags = absl_flags - absl_defs = { bool: absl_flags.DEFINE_bool, - int: absl_flags.DEFINE_integer, - str: absl_flags.DEFINE_string, - 'enum': absl_flags.DEFINE_enum } - - for name, val in self.values.items(): - flag_type, meta_args, meta_kwargs = self.meta[name] - absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) - - app.call_after_init(lambda: self.complete_absl_config(absl_flags)) - - def complete_absl_config(self, absl_flags): - for name, _ in self.values.items(): - self.update(name, getattr(absl_flags.FLAGS, name)) - - def parse_flags_with_absl(self): - global already_configured_with_absl - if not already_configured_with_absl: - import absl.flags - self.config_with_absl() - absl.flags.FLAGS(sys.argv, known_only=True) - self.complete_absl_config(absl.flags) - already_configured_with_absl = True - - -class NameSpace(object): - def __init__(self, getter): - self._getter = getter - - def __getattr__(self, name): - return self._getter(name) - - -config = Config() -flags = config -FLAGS = flags.FLAGS - -already_configured_with_absl = False - -flags.DEFINE_bool( - 'jax_enable_checks', - bool_env('JAX_ENABLE_CHECKS', False), - help='Turn on invariant checking (core.skip_checks = False)') - -flags.DEFINE_bool('tf_numpy_additional_tests', True, - 'Run tests added specifically for TF numpy') diff --git a/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py b/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py deleted file mode 100644 index cb583abae..000000000 --- a/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py 
+++ /dev/null @@ -1,359 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict # pylint: disable=g-importing-member -import itertools - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.jax_tests.config import config -import trax.tf_numpy.jax_tests.test_util as jtu -import trax.tf_numpy.numpy as jnp - - -config.parse_flags_with_absl() - - -class EinsumTest(jtu.TestCase): - - def _check(self, s, *ops): - a = np.einsum(s, *ops) - b = jnp.einsum(s, *ops) - self.assertAllClose(a, b, check_dtypes=True, atol=1e-4, rtol=1e-4) - - def test_three_operands_1(self): - r = self.rng() - x = r.randn(3) - y = r.randn(4) - z = r.randn(5) - s = 'i,j,k->ijk' - self._check(s, x, y, z) - - def test_three_operands_2(self): - r = self.rng() - x = r.randn(3) - y = r.randn(4) - z = r.randn(5) - s = 'i,j,k->ijk' - self._check(s, x, y, z) - - def test_two_operands_1(self): - r = self.rng() - x = r.randn(3, 4) - y = r.randn(4) - s = 'ij,j->i' - self._check(s, x, y) - - def test_two_operands_2(self): - r = self.rng() - x = r.randn(3, 4, 5) - y = r.randn(4) - s = 'ijk,j->i' - self._check(s, x, y) - - def test_two_operands_3(self): - r = self.rng() - x = r.randn(3, 4, 3) - y = r.randn(3) - s = 'iji,i->j' - self._check(s, x, y) - - def test_two_operands_4(self): - r = self.rng() - x = r.randn(3, 4) - y = r.randn(3, 4) - s = 'ij,ij->' - self._check(s, x, y) - - def test_two_operands_5(self): - r = self.rng() - x = r.randn(10, 2, 3) - y = r.randn(3, 4) - s = 'nij,jk->nik' - self._check(s, x, y) - - def test_two_operands_6(self): - # based on https://github.com/google/jax/issues/37#issuecomment-448572187 - r = self.rng() - x = r.randn(2, 1) - y = r.randn(2, 3, 4) - s = 'sa,shb->shab' - self._check(s, x, y) - - def test_one_operand_1(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->j' - self._check(s, x) - - def test_one_operand_2(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->kij' - self._check(s, x) - - def test_one_operand_3(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->ki' - self._check(s, x) - - def test_one_operand_4(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = 'ijk->ki' - self._check(s, x) - - def test_one_operand_5(self): - r = self.rng() - x = r.randn(2, 3, 4, 5) - s = '...ijk->...ki' - self._check(s, x) - - def test_one_operand_6(self): - r = self.rng() - x = r.randn(3, 4, 5) - s = '...ijk->ki' - self._check(s, x) - - def test_one_operand_7(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ii->' - self._check(s, x) - - def test_one_operand_8(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ij->' - self._check(s, x) - - def test_one_operand_9(self): - r = self.rng() - x = r.randn(3, 3, 3) - s = 'iii->' - self._check(s, x) - - def test_one_operand_10(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ii->i' - self._check(s, x) - - def test_one_operand_11(self): - r = self.rng() - x 
= r.randn(3, 3, 4) - s = 'iij->i' - self._check(s, x) - - def test_one_operand_12(self): - r = self.rng() - x = r.randn(3, 3, 3) - s = 'iii->i' - self._check(s, x) - - def test_one_operand_13(self): - r = self.rng() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkk->i' - self._check(s, x) - - def test_one_operand_14(self): - r = self.rng() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkk->ik' - self._check(s, x) - - def test_one_operand_15(self): - r = self.rng() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkl->il' - self._check(s, x) - - def test_one_operand_16(self): - r = self.rng() - x = r.randn(3, 3) - s = 'ij->ij' - self._check(s, x) - - def test_tf_unsupported_1(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = self.rng() - x = r.randn(2, 3, 5, 1) - y = r.randn(3, 4, 5, 1) - s = 'ij...,jk...->ik...' - self._check(s, x, y) - - def test_tf_unsupported_2(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = self.rng() - x = r.randn(2, 3, 3) - y = r.randn(4) - s = 'ijj,k->ik' - self._check(s, x, y) - - def test_tf_unsupported_3(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = self.rng() - x = r.randn(2, 3) - y = r.randn(2, 3) - z = r.randn(3, 4) - s = 'ij,ij,jk->ik' - self._check(s, x, y, z) - - # these tests are based on https://github.com/dask/dask/pull/3412/files - @parameterized.named_parameters( - {'testcase_name': '_{}_dtype={}'.format(einstr, dtype.__name__), # pylint: disable=g-complex-comprehension - 'einstr': einstr, 'dtype': dtype} - for einstr in [ - 'abc,bad->abcd', - 'abcdef,bcdfg->abcdeg', - 'ea,fb,abcd,gc,hd->efgh', - 'ab,b', - 'aa', - 'a,a->', - 'a,a->a', - 'a,a', - 'a,b', - 'a,b,c', - 'a', - 'ba,b', - 'ba,b->', - 'defab,fedbc->defac', - 'ab...,bc...->ac...', - 'a...a', - 'abc...->cba...', - '...ab->...a', - 'a...a->a...', - # Following 2 from # https://stackoverflow.com/a/19203475/1611416 - '...abc,...abcd->...d', - 'ab...,b->ab...', - # https://github.com/dask/dask/pull/3412#discussion_r182413444 - 'aa->a', - 'ab,ab,c->c', - 'aab,bc->ac', - 'aab,bcc->ac', - 'fdf,cdd,ccd,afe->ae', - 'fff,fae,bef,def->abd', - ] - # TODO(wangpeng): Add jnp.bool_ to dtype list - for dtype in [jnp.float32, jnp.int32, jnp.complex64]) - def test_from_dask(self, einstr, dtype): - r = jtu.rand_default() - if '->' in einstr: - input_str, _ = einstr.split('->') - else: - input_str = einstr - input_names = input_str.split(',') - - dims = itertools.cycle([2, 3, 4]) - shapes = defaultdict(lambda: next(dims)) - input_shapes = [tuple(shapes[c] for c in names.replace('...', '01')) - for names in input_names] - operands = [r(shape, dtype) for shape in input_shapes] - - self._check(einstr, *operands) - - def test_ordered_front_batch_dim_case(self): - x = np.ones((1, 8, 20, 4)) - y = np.ones((1, 8, 20, 4)) - s = 'ijkl,ijml->ijkm' - self._check(s, x, y) - - # pylint: disable=invalid-name - def test_einsum_path(self): - # just check examples from np.einsum_path docstring - a = self.rng().rand(2, 2) - b = self.rng().rand(2, 5) - c = self.rng().rand(5, 2) - - path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') - self.assertEqual(str(path_info[0]), "['einsum_path', (1, 2), (0, 1)]") - self.assertEqual(path_info[1].split('\n')[0], - ' Complete contraction: ij,jk,kl->il') - - # check this doesn't crash - I = self.rng().rand(10, 10, 10, 10) - C = self.rng().rand(10, 10) - np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, optimize='greedy') - - @jtu.disable - def test_einsum_kpmurphy_example(self): - # code from an email with @murphyk - N 
= 2 - C = 3 - D = 4 - K = 5 - T = 6 - r = self.rng() - S = r.randn(N, T, K) - W = r.randn(K, D) - V = r.randn(D, C) - L = np.zeros((N, C)) - for n in range(N): - for c in range(C): - s = 0 - for d in range(D): - for k in range(K): - for t in range(T): - s += S[n, t, k] * W[k, d] * V[d, c] - L[n, c] = s - - path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0] - rtol = 1e-2 if jtu.device_under_test() == 'tpu' else None - self.assertAllClose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path), - check_dtypes=False, rtol=rtol) - # pylint: enable=invalid-name - - @jtu.disable - def test_contraction_broadcasting(self): - r = self.rng() - x = r.randn(3, 4, 5) - y = r.randn(3, 1, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - @jtu.disable - def test_batch_broadcasting(self): - r = self.rng() - x = r.randn(1, 4, 5) - y = r.randn(3, 5, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - @jtu.disable - def test_batch_and_contraction_broadcasting(self): - r = self.rng() - x = r.randn(1, 4, 5) - y = r.randn(3, 1, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - @jtu.disable - def test_broadcasting_issue_2189(self): - r = self.rng() - x = r.randn(2, 1, 3, 3) - y = r.randn(2, 4, 3) - s = '...ij,...j' - self._check(s, x, y) - - -if __name__ == '__main__': - tf.enable_v2_behavior() - absltest.main() diff --git a/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py b/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py deleted file mode 100644 index 7f0a13f03..000000000 --- a/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py +++ /dev/null @@ -1,1000 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -import enum -from functools import partial -import itertools -import unittest - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as onp -import tensorflow.compat.v2 as tf - -import trax.tf_numpy.extensions as npe -import trax.tf_numpy.numpy as jnp - -from trax.tf_numpy.jax_tests.config import config -import trax.tf_numpy.jax_tests.test_util as jtu - -config.parse_flags_with_absl() - - -# We disable the whitespace continuation check in this file because otherwise it -# makes the test name formatting unwieldy. -# pylint: disable=bad-continuation -# We also disable undefined-variable till we start enabling tests. 
-# pylint: disable=undefined-variable - - -def subvals(lst, replace): - lst = list(lst) - for i, v in replace: - lst[i] = v - return tuple(lst) - - -float_dtypes = [onp.float32, onp.float64] -int_dtypes = [onp.int32, onp.int64] -bool_types = [onp.bool_] -default_dtypes = float_dtypes + int_dtypes -all_dtypes = float_dtypes + int_dtypes + bool_types - -IndexSpec = collections.namedtuple("IndexTest", ["shape", "indexer"]) - - -suppress_deprecated_indexing_warnings = partial( - jtu.ignore_warning, category=FutureWarning, - message='Using a non-tuple sequence.*') - - -STATIC_INDEXING_TESTS = [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ("OneSliceIndex", [ - IndexSpec(shape=(10,), indexer=slice(1, 3)), - IndexSpec(shape=(10,), indexer=slice(1, -1)), - IndexSpec(shape=(10,), indexer=slice(None, -1)), - IndexSpec(shape=(10,), indexer=slice(None, None, None)), - IndexSpec(shape=(10, 8), indexer=slice(1, 3)), - IndexSpec(shape=(10, 8), indexer=slice(1, None)), - IndexSpec(shape=(10, 8), indexer=slice(None, 3)), - IndexSpec(shape=(10, 8), indexer=slice(-3, None)), - ]), - ("OneSliceIndexNegativeStride", [ - IndexSpec(shape=(10,), indexer=slice(3, 1, -1)), - IndexSpec(shape=(10,), indexer=slice(1, 8, -1)), # empty result - IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10,), indexer=slice(None, None, -1)), - IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1)), - IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1)), # empty result - IndexSpec(shape=(10, 8), indexer=slice(None, None, -1)), - ]), - ("OneSliceIndexNonUnitStride", [ - IndexSpec(shape=(10,), indexer=slice(0, 8, 2)), - IndexSpec(shape=(10,), indexer=slice(0, 8, 3)), - IndexSpec(shape=(10,), indexer=slice(1, 3, 2)), - IndexSpec(shape=(10,), indexer=slice(1, None, 2)), - IndexSpec(shape=(10,), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, 2)), - IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, -2)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2))), - IndexSpec( - shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None))), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2))), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None)), - IndexSpec(shape=(3, 4), indexer=slice(None)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2))), - IndexSpec(shape=(10, 4), indexer=(1, slice(None))), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis), - IndexSpec(shape=(3, 4), indexer=Ellipsis), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 
3)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None), - IndexSpec(shape=(), indexer=(None, None)), - IndexSpec(shape=(), indexer=(Ellipsis, None)), - IndexSpec(shape=(3,), indexer=None), - IndexSpec(shape=(3, 4), indexer=None), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), - ]), - ("EmptyIndex", [ - IndexSpec(shape=(), indexer=()), - IndexSpec(shape=(3,), indexer=()), - IndexSpec(shape=(3, 4), indexer=()), - ]), -] - -STATIC_INDEXING_GRAD_TESTS = [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ("OneSliceIndex", [ - IndexSpec(shape=(5,), indexer=slice(1, 3)), - IndexSpec(shape=(5,), indexer=slice(1, -1)), - IndexSpec(shape=(5,), indexer=slice(None, -1)), - IndexSpec(shape=(5,), indexer=slice(None, None, None)), - IndexSpec(shape=(5, 4), indexer=slice(1, 3)), - IndexSpec(shape=(5, 4), indexer=slice(1, None)), - IndexSpec(shape=(5, 4), indexer=slice(None, 3)), - IndexSpec(shape=(5, 4), indexer=slice(-3, None)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4), indexer=(slice(1, None), slice(None, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, None), slice(0, 2))), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None)), - IndexSpec(shape=(3, 4), indexer=slice(None)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None))), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(5, 4), indexer=(slice(None), slice(0, 2))), - IndexSpec(shape=(5, 4), indexer=(1, slice(None))), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis), - IndexSpec(shape=(3, 4), indexer=Ellipsis), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None), - IndexSpec(shape=(), indexer=(None, None)), - IndexSpec(shape=(), indexer=(Ellipsis, None)), - IndexSpec(shape=(3,), indexer=None), - IndexSpec(shape=(3, 4), indexer=None), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis)), - ]), - # TODO(mattjj): these fail for uninteresting dtype reasons - # ("EmptyIndex", - # [IndexSpec(shape=(), indexer=()), - # IndexSpec(shape=(3,), indexer=()), - # IndexSpec(shape=(3, 4), indexer=()), - # ]), -] - -ADVANCED_INDEXING_TESTS = [ - ("One1DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([0, 1])), - IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), - IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), - IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), - ]), - ("One2DIntArrayIndex", - 
[IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])), - IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], - [0, 1, -1]])), - IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], - [-1, -2, 1, 0]])), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), - onp.array([1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, 0, 1]), - onp.array([-1, 0, -1, 2]))), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), - onp.array([1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2, 0, 1]]), - onp.array([-1, 0, -1, 2]))), - ]), - ("TupleOfListsOfPythonInts", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]])), - ]), - ("TupleOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, - onp.array([[2, 3, 0, 3]]))), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], - onp.array([[2, 3, 0, 3]]))), - ]), -] - -ADVANCED_INDEXING_TESTS_NO_REPEATS = [ - ("One1DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([0, 1])), - IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 0])), - IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), - IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), - # Fails with a TF/XLA error. - # IndexSpec(shape=(0,), indexer=onp.array([], dtype=onp.int32)), - ]), - ("One2DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=onp.array([[0, 1]])), - IndexSpec(shape=(6, 6), indexer=onp.array([[1, 2, 0], - [3, 4, -1]])), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([0, 1]), - onp.array([1, 2]))), - IndexSpec(shape=(4, 5, 6), indexer=(onp.array([0, 2, 1, 3]), - onp.array([-1, 0, -2, 1]))), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(onp.array([[0, 1]]), - onp.array([1, 2]))), - IndexSpec(shape=(4, 5, 6), indexer=(onp.array([[0, 2, -1, 1]]), - onp.array([-1, 0, -2, 2]))), - ]), - ("TupleOfListsOfPythonInts", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1])), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]])), - ]), - ("TupleOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=(0, onp.array([0, 1]))), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, - onp.array([[2, 3, 0]]))), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], onp.array([0]))), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], - onp.array([[2, 3, 0]]))), - ]), -] - -MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [ - ("SlicesAndOneIntArrayIndex", - [IndexSpec(shape=(2, 3), indexer=(onp.array([0, 1]), slice(1, 2))), - IndexSpec(shape=(2, 3), indexer=(slice(0, 2), - onp.array([0, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([0, 2]), - slice(None))), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([[0, 2], [1, 3]]), - slice(None))), - ]), - ("SlicesAndTwoIntArrayIndices", - [IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([0, 2]), - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - Ellipsis, - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - onp.array([-1, 2]), - Ellipsis)), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - 
onp.array([-1, 2]), - slice(1, 3))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - slice(1, 3), - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]), - slice(None, None, 2), - onp.array([-1, 2, 1]))), - ]), - ("NonesAndIntArrayIndices", - [IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - None, - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2]), - None, - None, - onp.array([-1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([0, 2]), - None, - None, - onp.array([-1, 2]))), - ]), - ("IntArrayWithInt32Type", - [IndexSpec(shape=(3, 4), indexer=(Ellipsis, onp.array(1, dtype=onp.int32))) - ]), -] - -MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [ - ("SlicesAndOneIntArrayIndex", - [ - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, - onp.array([[0, 2], [1, 1]]), - slice(None))), - ]), - ("SlicesAndTwoIntArrayIndices", - [IndexSpec(shape=(3, 4, 5), indexer=(onp.array([0, 2, -2]), - slice(None, None, 2), - onp.array([-1, 2, -1]))), - IndexSpec(shape=(3, 4, 5), indexer=(onp.array([[0, 2], [2, 0]]), - Ellipsis, - onp.array([[1, 0], [1, 0]]))), - ]),] - - -def dynamic_slice_reference(operand, start_indices, slice_sizes): - out = onp.zeros(slice_sizes, dtype=operand.dtype) - idx = tuple(slice(start, start+size) - for start, size in zip(start_indices, slice_sizes)) - section = operand[idx] - out[tuple(slice(None, stop) for stop in section.shape)] = section - return out - - -def dynamic_update_slice_reference(operand, update, start_indices): - slices = tuple(map( - slice, start_indices, onp.add(start_indices, update.shape))) - updated_operand = onp.copy(operand) - updated_operand[slices] = update - return updated_operand - - -class IndexingTest(jtu.TestCase): - """Tests for Numpy indexing translation rules.""" - - @parameterized.named_parameters(jtu.cases_from_list({ - "testcase_name": "{}_inshape={}_indexer={}".format( - name, jtu.format_shape_dtype_string( shape, dtype), indexer), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer - } for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testStaticIndexing(self, shape, dtype, rng_factory, indexer): - # TODO(rohanj): Revisit passing in self.rng() to this to customize further. - # This would need updating lax_numpy_test as well. 
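(An illustrative aside, not part of the deleted file: the pure-NumPy `dynamic_slice_reference` defined above models `npe.dynamic_slice` by filling a zeros buffer of the requested output size with whatever part of the requested window lies inside the operand, so a window that runs off the end of the array comes back zero-padded rather than raising. A minimal standalone check, reusing the helper exactly as defined above:

    import numpy as onp

    def dynamic_slice_reference(operand, start_indices, slice_sizes):
      # Zeros of the requested output size; the in-bounds part of the window
      # is copied over it, leaving any out-of-bounds remainder as zeros.
      out = onp.zeros(slice_sizes, dtype=operand.dtype)
      idx = tuple(slice(start, start + size)
                  for start, size in zip(start_indices, slice_sizes))
      section = operand[idx]
      out[tuple(slice(None, stop) for stop in section.shape)] = section
      return out

    x = onp.arange(5)
    print(dynamic_slice_reference(x, (4,), (3,)))  # [4 0 0]: window ran off the end

The test cases below only exercise windows that fit, or that clip at the trailing edge as shown here.)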
- rng = rng_factory() - args_maker = lambda: [rng(shape, dtype)] - onp_fun = lambda x: x[indexer] - jnp_fun = lambda x: jnp.asarray(x)[indexer] - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True) - - def _ReplaceSlicesWithTuples(self, idx): - """Helper method to replace slices with tuples for dynamic indexing args.""" - if isinstance(idx, slice): - triple = idx.start, idx.stop, idx.step - isnone = [i for i, elt in enumerate(triple) if elt is None] - zeros = itertools.repeat(0) - nones = itertools.repeat(None) - out = subvals(triple, zip(isnone, zeros)) - return out, lambda out: slice(*subvals(out, zip(isnone, nones))) - elif isinstance(idx, (tuple, list)) and idx: - t = type(idx) - elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) - return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) - else: - return idx, lambda x: x - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} - for name, index_specs in [ - ("OneSliceIndex", - [IndexSpec(shape=(5,), indexer=slice(1, 3)), - IndexSpec(shape=(5, 4), indexer=slice(1, 3))]), - ("TwoSliceIndices", - [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]), - ("NonUnitStrides", [ - IndexSpec(shape=(3,), indexer=slice(None, None, -1)), - IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), - IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) - ]), - ("OnlyStartOrStopDynamic", [ - IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) - ]), - ] - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testDynamicIndexingWithSlices(self, shape, dtype, rng_factory, indexer): - rng = rng_factory() - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - def onp_fun(x, unpacked_indexer): - indexer = pack_indexer(unpacked_indexer) - return x[indexer] - - jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) - - args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - # TODO(wangpeng): check_xla_forced_compile is turned off because some - # compile-time-constant requirements are violated. Investigate and turn it - # on. 
- self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_eval_on_shapes=False, - check_incomplete_shape=True, - check_xla_forced_compile=False) - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} - for name, index_specs in [ - ("OneIntIndex", - [IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2)]), - ("TwoIntIndices", - [IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]), - ("ThreeIntIndices", - [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ] - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer): - # TODO(rohanj): Revisit passing in self.rng() to this to customize further. - # This would need updating lax_numpy_test as well. - rng = rng_factory() - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - def onp_fun(x, unpacked_indexer): - indexer = pack_indexer(unpacked_indexer) - return x[indexer] - - jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) - - args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True) - - @parameterized.named_parameters( - {"testcase_name": "_{}_inshape={}_indexer={}" # pylint: disable=g-complex-comprehension - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "name": name, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer} - for name, index_specs in ADVANCED_INDEXING_TESTS - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testAdvancedIntegerIndexing(self, name, shape, dtype, rng_factory, - indexer): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), indexer] - onp_fun = lambda x, idx: x[idx] - jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx) - - self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) - # TODO(wangpeng): check_xla_forced_compile is turned off for - # ListOfPythonIntsAndIntArrays because it throws "The number of output - # elements has to equal to number of input elements that are sliced when - # input indices are not constant". Investigate and turn it on. 
- check_xla = (name != "ListOfPythonIntsAndIntArrays") - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, - check_xla_forced_compile=check_xla) - - @parameterized.named_parameters( - {"testcase_name": "_{}_inshape={}_indexer={}" # pylint: disable=g-complex-comprehension - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "name": name, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer} - for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS - for shape, indexer in index_specs - for dtype in all_dtypes - for rng_factory in [jtu.rand_default]) - def testMixedAdvancedIntegerIndexing(self, name, shape, dtype, rng_factory, - indexer): - rng = rng_factory() - indexer_with_dummies = [e if isinstance(e, onp.ndarray) else () - for e in indexer] - substitutes = [(i, e) for i, e in enumerate(indexer) - if not isinstance(e, onp.ndarray)] - args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] - - def np_fun(x, indexer_with_dummies): - idx = type(indexer)(subvals(indexer_with_dummies, substitutes)) - return x[idx] - - jnp_fun = lambda x, idx: np_fun(jnp.asarray(x), idx) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - # TODO(wangpeng): check_xla_forced_compile is turned off for - # IntArrayWithInt32Type because it throws "The number of output elements has - # to equal to number of input elements that are sliced when input indices - # are not constant". Investigate and turn it on. - check_xla = (name != "IntArrayWithInt32Type") - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, - check_xla_forced_compile=check_xla) - - def testAdvancedIndexingManually(self): - x = onp.random.RandomState(0).randn(3, 4, 5) - index_array = onp.array([0, 2, -1, 0]) - - op = lambda x, index_array: x[..., index_array, :] - cop = npe.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2, check_dtypes=True) - - op = lambda x, index_array: x[..., index_array, :, index_array, None] - cop = npe.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2, check_dtypes=True) - - op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] - cop = npe.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2, check_dtypes=True) - - # Note that we don't currently allow __iter__ in graph mode. So this test only - # iterates over eager tensor. 
- def testUnpacking(self): - - def foo(x): - a, b, c = x - return a + b + c - - a1 = foo(onp.arange(3)) - a2 = foo(jnp.arange(3)) - - self.assertAllClose(a1, a2, check_dtypes=True) - - def testBooleanIndexingArray1D(self): - idx = onp.array([True, True, False]) - x = jnp.asarray(onp.arange(3)) - ans = x[idx] - expected = onp.arange(3)[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingList1D(self): - idx = [True, True, False] - x = jnp.asarray(onp.arange(3)) - ans = x[idx] - expected = onp.arange(3)[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingArray2DBroadcast(self): - idx = onp.array([True, True, False, True]) - x = onp.arange(8).reshape(4, 2) - ans = jnp.asarray(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingList2DBroadcast(self): - idx = [True, True, False, True] - x = onp.arange(8).reshape(4, 2) - ans = jnp.asarray(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingArray2D(self): - idx = onp.array([[True, False], - [False, True], - [False, False], - [True, True]]) - x = onp.arange(8).reshape(4, 2) - ans = jnp.asarray(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingDynamicShape(self): - x = onp.zeros(3) - i = onp.array([True, True, False]) - ans = x[i] - expected = jnp.asarray(x)[i] - self.assertAllClose(ans, expected, check_dtypes=True) - - def testIssue187(self): - x = jnp.ones((5, 5)) - x[[0, 2, 4], [0, 2, 4]] # doesn't crash - - x = onp.arange(25).reshape((5, 5)) - ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) - expected = x[[0, 2, 4], [0, 2, 4]] - self.assertAllClose(ans, expected, check_dtypes=False) - - # TODO(agarwal): Fix this use case. - @jtu.disable - def testIndexingEmptyDimension(self): - # Issue 2671: XLA error when indexing into dimension of size 0 - x = jnp.ones((2, 0)) - # The following work, even on axis 1 of size 0 - _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] - - with self.assertRaisesRegex(IndexError, - "index .* is out of bounds for axis .* with size 0"): - _ = onp.ones((2, 0))[0, 0] # The numpy error - with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): - _ = x[0, 0] # JAX indexing - with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): - npe.jit(lambda i: x[0, i])(0) # JAX indexing under jit - - def testBooleanIndexingWithEmptyResult(self): - # based on a TensorFlow Probability test that started failing after #1623 - x = jnp.array([-1]) - mask = jnp.array([False]) - ans = x[mask] # doesn't crash - - expected = onp.array([-1])[onp.array([False])] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testFloatIndexingError(self): - error_regex = "only integers, slices.*are valid indices" - # Verify onp behavior - with self.assertRaisesRegex(IndexError, error_regex): - _ = onp.zeros((2, 2))[(0, 0.)] - # Test jnp - with self.assertRaisesRegex(IndexError, error_regex): - jnp.zeros(2)[0.] 
- with self.assertRaisesRegex(IndexError, error_regex): - jnp.zeros((2, 2))[(0, 0.)] - # Test with jit - with self.assertRaisesRegex(IndexError, error_regex): - npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)) - - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 - array = jnp.ones(5) - self.assertAllClose(array, array[:10], check_dtypes=True) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, size_indices), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "size_indices": size_indices, "rng_factory": rng_factory} - for shape, start_indices, size_indices in [ - [(3,), onp.array((1,)), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(5, 3), onp.array((1, 1)), (3, 1)], - [(7, 5, 3), onp.array((4, 1, 0)), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicSlice(self, shape, dtype, start_indices, size_indices, - rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] - op = lambda x, starts: npe.dynamic_slice(x, starts, size_indices) - self._CompileAndCheck(op, args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, size_indices), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "size_indices": size_indices, "rng_factory": rng_factory} - for shape, start_indices, size_indices in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicSliceAgainstNumpy(self, shape, dtype, start_indices, - size_indices, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), onp.array(start_indices)] - op = lambda x, s: npe.dynamic_slice(x, s, size_indices) - numpy_op = lambda x, s: dynamic_slice_reference(x, s, size_indices) - self._CheckAgainstNumpy(numpy_op, op, args_maker) - - def testDynamicSliceInDim(self): - rng = jtu.rand_default() - x = rng((6, 7), onp.int32) - self.assertAllClose(npe.dynamic_slice_in_dim(x, 2, 3), x[2:5], - check_dtypes=True) - - -def _broadcastable_shapes(shape): - """Returns all shapes that broadcast to `shape`.""" - def f(rshape): - yield [] - if rshape: - for s in f(rshape[1:]): - yield rshape[0:1] + s - if rshape[0] != 1: - for s in f(rshape[1:]): - yield [1] + s - for x in f(list(reversed(shape))): - yield list(reversed(x)) - - -def _update_shape(shape, indexer): - return onp.zeros(shape)[indexer].shape - - -class UpdateOps(enum.Enum): - UPDATE = 0 - ADD = 1 - # MUL = 2 - MIN = 3 - MAX = 4 - - def np_fn(op, indexer, x, y): # pylint: disable=no-self-argument - x = x.copy() - x[indexer] = { - UpdateOps.UPDATE: lambda: y, - UpdateOps.ADD: lambda: x[indexer] + y, - # UpdateOps.MUL: lambda: x[indexer] * y, - UpdateOps.MIN: lambda: onp.minimum(x[indexer], y), - UpdateOps.MAX: lambda: onp.maximum(x[indexer], y), - }[op]() - return x - - def tfnp_fn(op, indexer, x, y): # pylint: disable=no-self-argument - return { - UpdateOps.UPDATE: npe.index_update, - UpdateOps.ADD: npe.index_add, - # UpdateOps.MUL: npe.index_mul, - 
UpdateOps.MIN: npe.index_min, - UpdateOps.MAX: npe.index_max, - }[op](x, indexer, y) - - -# a helper to work around b/123559667 -def has_non_trivial_stride(indexer): - def has(idx): - return isinstance(idx, slice) and idx.step not in (1, -1, None) - return any(has(idx) for idx in tf.nest.flatten(indexer)) - - -class IndexedUpdateTest(jtu.TestCase): - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer in index_specs - for op in UpdateOps - for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) - # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws the error "Missing - # xla_context 0-th output from". Investigate. - check_xla = (not has_non_trivial_stride(indexer) and # b/123559667 - not (isinstance(indexer, slice) and indexer.stop == 8 and - indexer.step == -1)) - self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS - for shape, indexer in index_specs - for op in UpdateOps - for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) - self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True) - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op
- } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS - for shape, indexer in index_specs - for op in UpdateOps - for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in all_dtypes - for rng_factory in [jtu.rand_default])) - def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker) - check_xla = not has_non_trivial_stride(indexer) # b/123559667 - self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - @parameterized.named_parameters(jtu.cases_from_list({ # pylint: disable=g-complex-comprehension - "testcase_name": "_{}_{}_{}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory, - "indexer": indexer, "update_shape": update_shape, - "update_dtype": update_dtype, "op": op - } for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer in index_specs - for op in [UpdateOps.ADD, UpdateOps.UPDATE] - for dtype in float_dtypes - for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) - for update_dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, - rng_factory, indexer, op): - rng = rng_factory() - tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y) - x = rng(shape, dtype) - y = rng(update_shape, update_dtype) - self.check_grads(tfnp_fn, (x, y), rtol=1e-3, atol=1e-3, delta=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, update_shape), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "update_shape": update_shape, "rng_factory": rng_factory} - for shape, start_indices, update_shape in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape, - rng_factory): - rng = rng_factory() - def args_maker(): - return [rng(shape, dtype), rng(update_shape, dtype), - onp.array(start_indices)] - # update's shape must be fully known. - # TODO(wangpeng): Support turning off check_incomplete_shape for individual - # arguments. 
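(An illustrative aside, not part of the deleted file: `npe.dynamic_update_slice`, exercised by this test, is checked against the `dynamic_update_slice_reference` helper defined earlier in this file. A quick sketch of what that reference computes, reusing the helper as defined above:

    import numpy as onp

    def dynamic_update_slice_reference(operand, update, start_indices):
      # Write `update` into a copy of `operand`, with its corner at `start_indices`.
      slices = tuple(map(
          slice, start_indices, onp.add(start_indices, update.shape)))
      updated_operand = onp.copy(operand)
      updated_operand[slices] = update
      return updated_operand

    x = onp.zeros((4, 4), dtype=onp.int32)
    y = onp.ones((2, 3), dtype=onp.int32)
    print(dynamic_update_slice_reference(x, y, (1, 0)))  # 2x3 block of ones at row 1
)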
- self._CompileAndCheck(npe.dynamic_update_slice, args_maker, - check_incomplete_shape=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( # pylint: disable=g-complex-comprehension - jtu.format_shape_dtype_string(shape, dtype), - start_indices, update_shape), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "update_shape": update_shape, "rng_factory": rng_factory} - for shape, start_indices, update_shape in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(5, 3), (1, -2), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - [(), (), ()], - ] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices, - update_shape, rng_factory): - rng = rng_factory() - def args_maker(): - return [rng(shape, dtype), rng(update_shape, dtype), - onp.array(start_indices)] - self._CheckAgainstNumpy(dynamic_update_slice_reference, - npe.dynamic_update_slice, args_maker) - - def testDynamicUpdateSliceInDim(self): - rng = jtu.rand_default() - x = rng((6, 7), onp.int32) - y = rng((3, 7), onp.int32) - z = x.copy() - z[2:5] = y - self.assertAllClose(npe.dynamic_update_slice_in_dim(x, y, 2, 0), z, - check_dtypes=True) - - -if __name__ == "__main__": - tf.config.set_soft_device_placement(False) - jnp.enable_numpy_behavior() - absltest.main() diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py deleted file mode 100644 index e973ef79f..000000000 --- a/trax/tf_numpy/jax_tests/lax_numpy_test.py +++ /dev/null @@ -1,3085 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
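(An aside on the structure of the roughly 3,000-line file being removed below: most of `lax_numpy_test.py` is driven by `OpRecord` entries, each naming a NumPy op, its arity, and a grid of dtypes and shapes. Generic parameterized tests such as `testOp` expand the grid, prune it with filters like `_shapes_are_broadcast_compatible` (which simply attempts the broadcast on `onp.zeros` operands), and compare NumPy against the same-named `trax.tf_numpy` op. A condensed, NumPy-only sketch of that expansion; this is a simplification, not the file's actual harness, and the real records also carry rng factories, tolerances, and dtype-check flags:

    import collections
    import itertools
    import numpy as onp

    OpRecord = collections.namedtuple("OpRecord", ["name", "nargs", "dtypes", "shapes"])
    records = [OpRecord("add", 2, [onp.float32, onp.int32], [(3,), (2, 3)])]

    for rec in records:
      for dtype in rec.dtypes:
        # One case per combination of operand shapes (all broadcast-compatible here).
        for shapes in itertools.combinations_with_replacement(rec.shapes, rec.nargs):
          args = [onp.ones(s, dtype) for s in shapes]
          expected = getattr(onp, rec.name)(*args)  # NumPy reference result
          # The real `testOp` calls the same-named `trax.tf_numpy` op on these
          # args and compares against `expected` within per-dtype tolerances.
)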
- -import collections -import functools -from functools import partial -import itertools -import operator -import unittest -from unittest import SkipTest -import warnings - -from absl.testing import absltest -from absl.testing import parameterized -import six - -import numpy as onp - - -import tensorflow.compat.v2 as tf -import trax.tf_numpy.numpy as lnp -import trax.tf_numpy.extensions as npe -from trax.tf_numpy.jax_tests.config import config, FLAGS -import trax.tf_numpy.jax_tests.test_util as jtu - - -from tensorflow.python.framework import ops -from tensorflow.python.ops.numpy_ops import np_config - -config.parse_flags_with_absl() - - -nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] -nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes -empty_array_shapes = [(0,), (0, 4), (3, 0),] - -scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] -array_shapes = nonempty_array_shapes + empty_array_shapes -nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes -nonempty_shapes = scalar_shapes + nonempty_array_shapes -all_shapes = scalar_shapes + array_shapes - -# TODO(wangpeng): float_dtypes = [lnp.bfloat16, onp.float16, onp.float32, -# onp.float64] -float_dtypes = [onp.float16, onp.float32, onp.float64] -complex_dtypes = [onp.complex64, onp.complex128] -int_dtypes = [onp.int32, onp.int64] -unsigned_dtypes = [onp.uint32, onp.uint64] -bool_dtypes = [onp.bool_] -default_dtypes = float_dtypes + int_dtypes -inexact_dtypes = float_dtypes + complex_dtypes -number_dtypes = float_dtypes + complex_dtypes + int_dtypes -all_dtypes = number_dtypes + bool_dtypes - - -python_scalar_dtypes = [lnp.bool_, lnp.int_, lnp.float_, lnp.complex_] - -def _valid_dtypes_for_shape(shape, dtypes): - # Not all (shape, dtype) pairs are valid. In particular, Python scalars only - # have one type in each category (float, bool, etc.) 
- if shape is jtu.PYTHON_SCALAR_SHAPE: - return [t for t in dtypes if t in python_scalar_dtypes] - return dtypes - -def _shape_and_dtypes(shapes, dtypes): - for shape in shapes: - for dtype in _valid_dtypes_for_shape(shape, dtypes): - yield (shape, dtype) - -OpRecord = collections.namedtuple( - "OpRecord", - ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", - "test_name", "check_dtypes", "tolerance", "inexact", - "check_incomplete_shape"]) - -def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name=None, check_dtypes=True, tolerance=None, inexact=False, - check_incomplete_shape=True): - test_name = test_name or name - return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name, check_dtypes, tolerance, inexact, - check_incomplete_shape) - - -def minus(a, b): - return [x for x in a if x not in b] - - -JAX_ONE_TO_ONE_OP_RECORDS = [ - op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("add", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []), - op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("float_power", 2, inexact_dtypes, all_shapes, - partial(jtu.rand_default, scale=1), ["rev"], - tolerance={ - # TODO(wangpeng): lnp.bfloat16: 1e-2, - onp.float32: 1e-3, - onp.float64: 1e-12, onp.complex64: 2e-4, - onp.complex128: 1e-12}, check_dtypes=False), - op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("greater", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("greater_equal", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("less", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("less_equal", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_equal, []), - op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool, []), - op_record("maximum", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf, []), - op_record("minimum", 2, minus(all_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf, []), - op_record("multiply", 2, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("nextafter", 2, [f for f in float_dtypes - if f not in (lnp.bfloat16, onp.float16)], - all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), - op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), - op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("signbit", 1, default_dtypes + bool_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, 
["rev"]), - op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("tan", 1, number_dtypes, all_shapes, - partial(jtu.rand_uniform, -1.5, 1.5), ["rev"], - tolerance={onp.complex64: 3e-5, onp.complex128: 4e-14}, - inexact=True), - # TODO(wangpeng): Add float16 support - op_record("sinh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("cosh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_default, ["rev"], - inexact=True), - # TODO(b/142975473): on CPU, tanh for complex128 is only accurate to - # ~float32 precision. - # TODO(b/143135720): on GPU, tanh has only ~float32 precision. - op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={onp.float64: 1e-7, onp.complex128: 1e-7}, - inexact=True), - op_record("arcsin", 1, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arccos", 1, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arctan", 1, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arctan2", 2, minus(float_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), - op_record("arcsinh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("arccosh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("arctanh", 1, minus(number_dtypes, [onp.float16]), all_shapes, jtu.rand_small, ["rev"], - inexact=True), -] - -JAX_COMPOUND_OP_RECORDS = [ - # angle has inconsistent 32/64-bit return types across numpy versions. - op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [], - check_dtypes=False, inexact=True), - op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("cbrt", 1, default_dtypes, all_shapes, jtu.rand_default, ["rev"], - inexact=True), - op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("deg2rad", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero, ["rev"], - inexact=six.PY3), - op_record("divmod", 2, minus(int_dtypes + float_dtypes, [onp.float16]), - all_shapes, jtu.rand_nonzero, []), - op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"], - tolerance={ - # TODO(wangpeng): lnp.bfloat16: 2e-2, - onp.float16: 1e-2}, inexact=True), - # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32 - # precision. 
- op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive, [], - test_name="expm1_large", tolerance={onp.float64: 1e-8}, inexact=True), - op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive, - [], tolerance={onp.float64: 1e-8}, inexact=True), - op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("floor_divide", 2, minus(number_dtypes, complex_dtypes), - all_shapes, jtu.rand_nonzero, ["rev"]), - op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], - inexact=True), - op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], - inexact=True), - op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, [], - check_incomplete_shape=False), - op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("imag", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("isfinite", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf_and_nan, []), - op_record("isinf", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf_and_nan, []), - op_record("isnan", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_inf_and_nan, []), - op_record("isneginf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isposinf", 1, float_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), - op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("isrealobj", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive, [], - test_name="log1p_large", tolerance={onp.float64: 1e-12}, - inexact=True), - op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive, [], - tolerance={onp.float64: 1e-12}, inexact=True), - op_record("logaddexp", 2, float_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"], - tolerance={onp.float64: 1e-12}, inexact=True), - op_record("logaddexp2", 2, float_dtypes, all_shapes, - jtu.rand_some_inf_and_nan, ["rev"], - tolerance={onp.float16: 1e-2}, inexact=True), - op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes, - jtu.rand_default, [], check_dtypes=False, - tolerance={onp.float16: 1e-2, onp.float64: 1e-12}, - check_incomplete_shape=False), - op_record("positive", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("power", 2, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - tolerance={onp.complex128: 1e-14}), - op_record("rad2deg", 1, float_dtypes, all_shapes, jtu.rand_default, [], - tolerance={onp.float64: 5e-6}), - op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"]), - op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf, []), - op_record("remainder", 2, minus(default_dtypes, [onp.float16]), all_shapes, - jtu.rand_nonzero, [], tolerance={onp.float16: 1e-2}), - op_record("mod", 2, minus(default_dtypes, [onp.float16]), all_shapes, - jtu.rand_nonzero, []), - op_record("sinc", 1, [t for t in number_dtypes if t != lnp.bfloat16], - all_shapes, jtu.rand_default, ["rev"], - tolerance={onp.complex64: 1e-5}, inexact=True, - check_dtypes=False), - op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), - 
op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive, ["rev"], - inexact=True), - op_record("transpose", 1, all_dtypes, all_shapes, jtu.rand_default, ["rev"], - check_dtypes=False), - op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero, - ["rev"], inexact=True), - op_record("diff", 1, number_dtypes, nonzerodim_shapes, jtu.rand_default, - ["rev"], check_incomplete_shape=False), -] - -JAX_BITWISE_OP_RECORDS = [ - op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), - op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), - op_record("bitwise_or", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), - op_record("bitwise_xor", 2, int_dtypes + unsigned_dtypes, all_shapes, - jtu.rand_default, []), -] - -JAX_REDUCER_RECORDS = [ - op_record("mean", 1, number_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []), - op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmean", 1, minus(inexact_dtypes, complex_dtypes), - nonempty_shapes, jtu.rand_some_nan, [], inexact=True), - op_record("nanprod", 1, minus(inexact_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_nan, []), - op_record("nansum", 1, minus(number_dtypes, complex_dtypes), all_shapes, - jtu.rand_some_nan, []), -] - -JAX_REDUCER_NO_DTYPE_RECORDS = [ - op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), - op_record("max", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_default, []), - op_record("min", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_default, []), - op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), -] - -JAX_ARGMINMAX_RECORDS = [ - op_record("argmin", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_some_equal, []), - op_record("argmax", 1, minus(all_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_some_equal, []), -] - -JAX_OPERATOR_OVERLOADS = [ - op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__lt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__gt__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__ge__", 2, default_dtypes, all_shapes, jtu.rand_default, []), - op_record("__pos__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], - tolerance={onp.float32: 2e-4, onp.complex64: 2e-4, onp.complex128: 1e-14}), - op_record("__mod__", 2, minus(default_dtypes, [onp.float16]), all_shapes, jtu.rand_nonzero, [], - tolerance={onp.float16: 1e-1}), - op_record("__floordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], - inexact=True), - op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default, 
[]), - # TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2 - op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default, []), - # TODO(mattjj): investigate these failures - # op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - # op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - # TODO(mattjj): lshift, rshift -] - -JAX_RIGHT_OPERATOR_OVERLOADS = [ - op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive, [], - tolerance={onp.float32: 2e-4, onp.complex64: 1e-3}), - op_record("__rmod__", 2, minus(default_dtypes, [onp.float16]), all_shapes, jtu.rand_nonzero, [], - tolerance={onp.float16: 1e-1}), - op_record("__rfloordiv__", 2, default_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, [], - inexact=True), - # op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default, []), - # op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool, []), - # op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), -] - -numpy_version = tuple(map(int, onp.version.version.split('.'))) -if numpy_version >= (1, 15): - JAX_COMPOUND_OP_RECORDS += [ - op_record("isclose", 2, [t for t in all_dtypes if t != lnp.bfloat16], - all_shapes, jtu.rand_small_positive, []), - op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default, []), - op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default, []), - ] - JAX_REDUCER_NO_DTYPE_RECORDS += [ - op_record("ptp", 1, minus(number_dtypes, complex_dtypes), nonempty_shapes, - jtu.rand_default, []), - ] - -if six.PY2: - JAX_OPERATOR_OVERLOADS += [ - op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - ] - JAX_RIGHT_OPERATOR_OVERLOADS += [ - op_record("__rdiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero, []), - ] - - -CombosWithReplacement = itertools.combinations_with_replacement - - -def _dtypes_are_compatible_for_bitwise_ops(args): - if len(args) <= 1: - return True - is_signed = lambda dtype: lnp.issubdtype(dtype, onp.signedinteger) - width = lambda dtype: lnp.iinfo(dtype).bits - x, y = args - # `lnp.iinfo(dtype).bits` can't be called on bools, so we convert bools to - # ints. - if x == lnp.bool_: - x = lnp.int32 - if y == lnp.bool_: - y = lnp.int32 - if width(x) > width(y): - x, y = y, x - if x == lnp.uint32 and y == lnp.uint64: - return False - # The following condition seems a little ad hoc, but seems to capture what - # numpy actually implements. 
-
-
-def _dtypes_are_compatible_for_bitwise_ops(args):
-  if len(args) <= 1:
-    return True
-  is_signed = lambda dtype: lnp.issubdtype(dtype, onp.signedinteger)
-  width = lambda dtype: lnp.iinfo(dtype).bits
-  x, y = args
-  # `lnp.iinfo(dtype).bits` can't be called on bools, so we convert bools to
-  # ints.
-  if x == lnp.bool_:
-    x = lnp.int32
-  if y == lnp.bool_:
-    y = lnp.int32
-  if width(x) > width(y):
-    x, y = y, x
-  if x == lnp.uint32 and y == lnp.uint64:
-    return False
-  # The following condition seems a little ad hoc, but seems to capture what
-  # numpy actually implements.
-  return (
-      is_signed(x) == is_signed(y)
-      or (width(x) == 32 and width(y) == 32)
-      or (width(x) == 32 and width(y) == 64 and is_signed(y)))
-
-
-def _shapes_are_broadcast_compatible(shapes):
-  accumulator = onp.zeros([])
-  for shape in shapes:
-    try:
-      accumulator = accumulator + onp.zeros(shape)
-    except ValueError:
-      return False
-  return True
-
-def _shapes_are_equal_length(shapes):
-  return all(len(shape) == len(shapes[0]) for shape in shapes[1:])
-
-
-def _promote_like_lnp(fun, inexact=False):
-  """Decorator that promotes the arguments of `fun` to `lnp.result_type(*args)`.
-
-  lnp and onp have different type promotion semantics; this decorator allows
-  tests to make an onp reference implementation act more like an lnp
-  implementation.
-  """
-  def wrapper(*args, **kw):
-    flat_args = tf.nest.flatten(args)
-    if inexact and not any(
-        lnp.issubdtype(lnp.result_type(x).as_numpy_dtype, lnp.inexact)
-        for x in flat_args):
-      dtype = lnp.result_type(lnp.float_, *flat_args)
-    else:
-      dtype = lnp.result_type(*flat_args)
-    dtype = dtype.as_numpy_dtype
-    args = tf.nest.map_structure(lambda a: onp.asarray(a, dtype), args)
-    return fun(*args, **kw)
-  return wrapper
-
-
-def new_test(f):
-
-  def wrapper(self, *args, **kwargs):
-    if not FLAGS.tf_numpy_additional_tests:
-      self.skipTest("Newly added test is disabled, since flag is False.")
-    else:
-      f(self, *args, **kwargs)
-
-  return wrapper
-
-
-def named_parameters(ls):
-  """A version that allows an empty param list."""
-  def noop(_):
-    def wrapper(self, *args, **kwargs):
-      self.skipTest("Empty parameter list")
-    return wrapper
-  if isinstance(ls, (list, tuple)) and not ls:
-    return noop
-  if isinstance(ls, itertools.chain):
-    try:
-      first = next(ls)
-    except StopIteration:
-      return noop
-    else:
-      ls = itertools.chain([first], ls)
-  return parameterized.named_parameters(ls)
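The `_promote_like_lnp` helper above exists because plain NumPy keeps each argument's own dtype, while the lnp backend first promotes all arguments to a common `result_type`. A stripped-down, NumPy-only sketch of the same idea (no `tf.nest` flattening, no bfloat16 handling; `force_float` stands in for the `inexact` flag):

import functools
import numpy as np

def promote_to_common_dtype(fun, force_float=False):
  """Wraps `fun` so that every argument is cast to one promoted dtype."""
  @functools.wraps(fun)
  def wrapper(*args):
    dtype = np.result_type(*args)
    if force_float and not np.issubdtype(dtype, np.inexact):
      dtype = np.result_type(np.float64, dtype)
    return fun(*(np.asarray(a, dtype) for a in args))
  return wrapper

add = promote_to_common_dtype(np.add)
print(add(np.zeros(2, np.int8), np.zeros(2, np.float32)).dtype)  # float32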
-
-
-# TODO(wangpeng): Enable all disabled tests in this class
-class LaxBackedNumpyTests(jtu.TestCase):
-  """Tests for LAX-backed Numpy implementation."""
-
-  def _GetArgsMaker(self, rng, shapes, dtypes, onp_arrays=True):
-    def f():
-      out = [rng(shape, dtype or lnp.float_)
-             for shape, dtype in zip(shapes, dtypes)]
-      return out if onp_arrays else [lnp.asarray(a) for a in out]
-    return f
-
-  @named_parameters(itertools.chain.from_iterable(
-      jtu.cases_from_list(
-          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
-                                                        dtypes),
-           "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
-           "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
-           "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance,
-           "inexact": rec.inexact,
-           "check_incomplete_shape": rec.check_incomplete_shape}
-          for shapes in filter(
-              _shapes_are_broadcast_compatible,
-              CombosWithReplacement(rec.shapes, rec.nargs))
-          for dtypes in itertools.product(
-              *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
-      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
-                                 JAX_COMPOUND_OP_RECORDS)))
-  def testOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes, check_dtypes,
-             tolerance, inexact, check_incomplete_shape):
-    # TODO(b/147769803): Remove this skipping
-    if lnp_op.__name__ == "kron" and shapes == ((2, 3, 4), (2, 3, 4)):
-      self.skipTest("Case disabled because of b/147769803")
-    rng = rng_factory()
-    args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False)
-    tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes)
-    tol = functools.reduce(jtu.join_tolerance,
-                           [tolerance, tol, jtu.default_tolerance()])
-    self._CheckAgainstNumpy(_promote_like_lnp(onp_op, inexact), lnp_op,
-                            args_maker, check_dtypes=check_dtypes, tol=tol)
-    # tf.math.pow doesn't support int32/int64 on XLA (b/169191476).
-    check_xla = not (lnp_op.__name__ == "power" and set(dtypes).intersection(
-        (onp.int32, onp.int64)))
-    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes,
-                          atol=tol, rtol=tol,
-                          check_incomplete_shape=check_incomplete_shape,
-                          check_experimental_compile=check_xla,
-                          check_xla_forced_compile=check_xla)
-
-  @named_parameters(itertools.chain.from_iterable(
-      jtu.cases_from_list(
-          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
-                                                        dtypes),
-           "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
-           "name": rec.name, "tol": rec.tolerance}
-          for shapes in filter(
-              _shapes_are_broadcast_compatible,
-              CombosWithReplacement(rec.shapes, rec.nargs))
-          for dtypes in itertools.product(
-              *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
-      for rec in JAX_OPERATOR_OVERLOADS))
-  def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol):
-    rng = rng_factory()
-    # onp and lnp arrays have different type promotion rules; force the use of
-    # lnp arrays.
-    args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False)
-    fun = lambda *xs: getattr(operator, name.strip('_'))(*xs)
-    scalar_arg = (jtu.PYTHON_SCALAR_SHAPE in shapes or
-                  jtu.NUMPY_SCALAR_SHAPE in shapes or
-                  () in shapes)
-    empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes)
-    self._CompileAndCheck(
-        fun, args_maker, check_dtypes=True,  # not scalar_arg and not empty_shape,
-        atol=tol, rtol=tol)
-
-  @named_parameters(itertools.chain.from_iterable(
-      jtu.cases_from_list(
-          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
-                                                        dtypes),
-           "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
-           "name": rec.name, "op_tolerance": rec.tolerance}
-          for shapes in filter(
-              _shapes_are_broadcast_compatible,
-              CombosWithReplacement(rec.shapes, rec.nargs))
-          for dtypes in itertools.product(
-              *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes)))
-      for rec in JAX_RIGHT_OPERATOR_OVERLOADS))
-  def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes,
-                                op_tolerance):
-    if shapes[1] is jtu.PYTHON_SCALAR_SHAPE:
-      raise SkipTest()  # TODO(mattjj): clean up
-    rng = rng_factory()
-    args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False)
-    fun = lambda fst, snd: getattr(snd, name)(fst)
-    tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes)
-    scalar_arg = (jtu.PYTHON_SCALAR_SHAPE in shapes or
-                  jtu.NUMPY_SCALAR_SHAPE in shapes or
-                  () in shapes)
-    empty_shape = any(isinstance(s, tuple) and 0 in s for s in shapes)
-    self._CompileAndCheck(
-        fun, args_maker, check_dtypes=True,  # not scalar_arg and not empty_shape,
-        atol=tol, rtol=tol)
-
-  @named_parameters(itertools.chain.from_iterable(
-      jtu.cases_from_list(
-          {"testcase_name": jtu.format_test_name_suffix(
-              rec.test_name, shapes, dtypes),
-           "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
-           "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
-          for shapes in filter(
-              _shapes_are_broadcast_compatible,
-              CombosWithReplacement(rec.shapes, rec.nargs))
-          for dtypes in filter(
-              _dtypes_are_compatible_for_bitwise_ops,
-              CombosWithReplacement(rec.dtypes, rec.nargs)))
-      for rec in JAX_BITWISE_OP_RECORDS))
-  def testBitwiseOp(self, onp_op, lnp_op, rng_factory, shapes, dtypes):
-    rng = rng_factory()
-    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
-    has_python_scalar = jtu.PYTHON_SCALAR_SHAPE in shapes
-    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
-    if onp_op == onp.bitwise_not and has_python_scalar:
-      # For bitwise_not with a Python `int`, npe.jit may choose a different
-      # dtype for the `int` from onp's choice, which may result in a different
-      # result value, so we skip _CompileAndCheck.
-      return
-    # Numpy does value-dependent dtype promotion on Python/numpy/array scalars
-    # which `jit` can't do (when np.result_type is called inside `jit`, tensor
-    # values are not available), so we skip dtype check in this case.
-    check_dtypes = not (set(shapes) & set([jtu.NUMPY_SCALAR_SHAPE,
-                                           jtu.PYTHON_SCALAR_SHAPE, ()]))
-    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes)
-
-  @named_parameters(itertools.chain.from_iterable(
-      jtu.cases_from_list(
-          {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
-              rec.test_name.capitalize(),
-              jtu.format_shape_dtype_string(shape, dtype), axis,
-              "None" if out_dtype is None else onp.dtype(out_dtype).name,
-              keepdims),
-           "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
-           "out_dtype": out_dtype, "onp_op": getattr(onp, rec.name),
-           "lnp_op": getattr(lnp, rec.name), "axis": axis,
-           "keepdims": keepdims, "inexact": rec.inexact}
-          for shape in rec.shapes for dtype in rec.dtypes
-          for out_dtype in [None] + rec.dtypes
-          for axis in set(range(-len(shape), len(shape))) | set([None])
-          for keepdims in [False, True])
-      for rec in JAX_REDUCER_RECORDS))
-  def testReducer(self, onp_op, lnp_op, rng_factory, shape, dtype, out_dtype,
-                  axis, keepdims, inexact):
-    rng = rng_factory()
-    def onp_fun(x):
-      x_cast = x if dtype != lnp.bfloat16 else x.astype(onp.float32)
-      t = out_dtype if out_dtype != lnp.bfloat16 else onp.float32
-      return onp_op(x_cast, axis, dtype=t, keepdims=keepdims)
-    onp_fun = _promote_like_lnp(onp_fun, inexact)
-    lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
-    args_maker = lambda: [rng(shape, dtype)]
-    tol_spec = {onp.float16: 1e-2, onp.float32: 1e-3, onp.complex64: 1e-3,
-                onp.float64: 1e-5, onp.complex128: 1e-5}
-    tol = jtu.tolerance(dtype, tol_spec)
-    tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=lnp.bfloat16 not in (dtype, out_dtype),
-                            tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol,
-                          rtol=tol)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
-          rec.test_name.capitalize(),
-          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
-       "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
-       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
-       "axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
-      for rec in JAX_REDUCER_NO_DTYPE_RECORDS
-      for shape in rec.shapes for dtype in rec.dtypes
-      for axis in set(range(-len(shape), len(shape))) | set([None])
-      for keepdims in [False, True]))
-  def testReducerNoDtype(self, onp_op, lnp_op, rng_factory, shape, dtype, axis,
-                         keepdims, inexact):
-    rng = rng_factory()
-    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
-    onp_fun = _promote_like_lnp(onp_fun, inexact)
-    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
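The dtype-check escape hatches in testBitwiseOp above come down to NumPy's value-dependent scalar promotion, which a compiled function cannot reproduce because concrete values are unknown at trace time. A standalone illustration of that value dependence (this is legacy, pre-NEP 50 behavior, i.e. NumPy < 2.0; NumPy 2.x promotes by dtype alone):

import numpy as np

a = np.zeros(3, np.uint8)
# Under legacy promotion, the *value* of the Python int decides the result:
print((a + 1).dtype)     # uint8  -- 1 fits in uint8
print((a + 1000).dtype)  # uint16 -- 1000 does not, so the result widens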
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_axis={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), axis),
-       "shape": shape, "dtype": dtype, "axis": axis}
-      for shape in all_shapes for dtype in all_dtypes
-      for axis in set(range(-len(shape), len(shape))) | set([None])))
-  def testCountNonzero(self, shape, dtype, axis):
-    rng = jtu.rand_some_zero()
-    onp_fun = lambda x: onp.count_nonzero(x, axis)
-    lnp_fun = lambda x: lnp.count_nonzero(x, axis)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}".format(
-          jtu.format_shape_dtype_string(shape, dtype)),
-       "shape": shape, "dtype": dtype}
-      for shape in all_shapes for dtype in all_dtypes))
-  def testNonzero(self, shape, dtype):
-    rng = jtu.rand_some_zero()
-    onp_fun = lambda x: onp.nonzero(x)
-    lnp_fun = lambda x: lnp.nonzero(x)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
-    # The shapes of `nonzero`'s results are value-dependent, so `eval_on_shapes`
-    # won't return concrete shapes.
-    # Also, `nonzero` requires a known rank.
-    # Turns off XLA check because there are no XLA kernels for `Where`, which
-    # XLA can't support because its output shape is dynamic.
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_eval_on_shapes=False,
-        check_incomplete_shape=True, check_unknown_rank=False,
-        check_experimental_compile=False, check_xla_forced_compile=False)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "{}_inshape={}_axis={}".format(
-          rec.test_name.capitalize(),
-          jtu.format_shape_dtype_string(shape, dtype), axis),
-       "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
-       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
-       "axis": axis}
-      for rec in JAX_ARGMINMAX_RECORDS
-      for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes)
-      for axis in range(-len(shape), len(shape))))
-  def testArgMinMax(self, onp_op, lnp_op, rng_factory, shape, dtype, axis):
-    rng = rng_factory()
-    if dtype == onp.complex128 and jtu.device_under_test() == "gpu":
-      raise unittest.SkipTest("complex128 reductions not supported on GPU")
-
-    def onp_fun(array_to_reduce):
-      return onp_op(array_to_reduce, axis).astype(lnp.int_)
-
-    def lnp_fun(array_to_reduce):
-      return lnp_op(array_to_reduce, axis)
-
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
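As the comments in testNonzero note, `nonzero` has a value-dependent output shape, which is why shape-based evaluation and forced XLA compilation are skipped for it. The effect is easy to see in plain NumPy:

import numpy as np

# Two inputs of identical shape, two different result shapes -- the output
# size depends on how many entries happen to be non-zero.
print(np.nonzero(np.array([0, 3, 0, 5]))[0].shape)  # (2,)
print(np.nonzero(np.array([1, 3, 2, 5]))[0].shape)  # (4,)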
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_{}_{}".format(
-          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
-          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
-          axes),
-       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
-       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
-       "axes": axes, "rng_factory": rng_factory}
-      for rng_factory in [jtu.rand_default]
-      for lhs_shape, rhs_shape, axes in [
-          [(2,), (2,), (-1, -1, -1, None)],  # scalar output
-          [(2, 4), (2, 4), (-1, -1, -1, 0)],  # 2D vectors
-          [(3, 4), (3, 4), (-1, -1, -1, 0)],  # 3D vectors
-          [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)],  # broadcasting
-          [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)],  # different axes
-          [(6, 1, 3), (5, 3), (-1, -1, -1, None)],  # more broadcasting
-          [(6, 1, 2), (5, 3), (-1, -1, -1, None)],  # mixed 2D and 3D vectors
-          [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)],  # axes/broadcasting
-          [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)],  # axisc should do nothing
-          [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)]  # same as before
-      ]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(
-          minus(number_dtypes, complex_dtypes), 2)))
-  def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes,
-                rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
-    axisa, axisb, axisc, axis = axes
-    lnp_fun = lambda a, b: lnp.cross(a, b, axisa, axisb, axisc, axis)
-    def onp_fun(a, b):
-      a = a.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else a
-      b = b.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else b
-      out = onp.cross(a, b, axisa, axisb, axisc, axis)
-      return out.astype(lnp.promote_types(lhs_dtype, rhs_dtype))
-    tol_spec = {
-        # TODO(wangpeng): dtypes.bfloat16: 3e-1,
-        onp.float16: 0.15}
-    tol = max(jtu.tolerance(lhs_dtype, tol_spec),
-              jtu.tolerance(rhs_dtype, tol_spec))
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True,
-                            tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol,
-                          rtol=tol, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_{}_{}".format(
-          name,
-          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
-          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
-       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
-       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
-       "rng_factory": rng_factory}
-      for rng_factory in [jtu.rand_default]
-      for name, lhs_shape, rhs_shape in [
-          ("matrix-scalar", (3, 3), ()),
-          ("scalar-matrix", (), (3, 3)),
-          ("matrix-vector", (4, 5), (5,)),
-          ("vector-matrix", (6,), (6, 4)),
-          ("matrix-matrix", (3, 4), (4, 5)),
-          ("tensor-vector", (4, 3, 2), (2,)),
-          ("vector-tensor", (2,), (3, 2, 4)),
-          ("tensor-matrix", (4, 3, 2), (2, 5)),
-          ("matrix-tensor", (5, 2), (3, 2, 4)),
-          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
-  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
-    tol = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 1e-14,
-           onp.complex128: 1e-14}
-    if jtu.device_under_test() == "tpu":
-      tol[onp.float32] = tol[onp.complex64] = 2e-1
-    def onp_dot(x, y):
-      x = x.astype(onp.float32) if lhs_dtype == lnp.bfloat16 else x
-      y = y.astype(onp.float32) if rhs_dtype == lnp.bfloat16 else y
-      # `onp.dot(x, y).dtype` sometimes differs from `onp.result_type(x, y)`
-      # (e.g. when x is float64[] and y is complex64[3,3], or when x is
-      # float16[3,3] and y is int64[]). We ignore this corner case and pretend
-      # that they agree.
-      return onp.dot(x, y).astype(onp.result_type(x, y))
-    self._CheckAgainstNumpy(onp_dot, lnp.dot, args_maker,
-                            check_dtypes=True, tol=tol)
-    # We disable dtype check in the following cases because `np.dot` does
-    # value-dependent type promotion in those cases.
-    check_dtypes = () not in (lhs_shape, rhs_shape)
-    # XLA lacks int32/int64 MatMul kernels (b/168657656).
-    check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
-        (onp.int32, onp.int64))
-    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=check_dtypes,
-                          atol=tol, rtol=tol, check_incomplete_shape=True,
-                          check_experimental_compile=check_xla,
-                          check_xla_forced_compile=check_xla)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_{}_{}".format(
-          name,
-          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
-          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
-       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
-       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
-       "rng_factory": rng_factory}
-      for rng_factory in [jtu.rand_default]
-      for name, lhs_shape, rhs_shape in [
-          ("vector-vector", (3,), (3,)),
-          ("matrix-vector", (3, 3), (3,)),
-          ("vector-matrix", (3,), (3, 3)),
-          ("matrix-matrix", (3, 3), (3, 3)),
-          ("vector-tensor", (3,), (5, 3, 2)),
-          ("tensor-vector", (5, 3, 2), (2,)),
-          ("matrix-tensor", (5, 2), (3, 2, 4)),
-          ("tensor-matrix", (5, 2, 3), (3, 2)),
-          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
-          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
-  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
-    rng = rng_factory()
-    def onp_fun(x, y):
-      dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
-      return (onp.matmul(x, y).astype(dtype),
-              onp.array(x).__matmul__(y).astype(dtype),
-              onp.array(y).__rmatmul__(x).astype(dtype))
-    def lnp_fun(x, y):
-      return (lnp.matmul(x, y),
-              lnp.array(x).__matmul__(y),
-              lnp.array(y).__rmatmul__(x))
-    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
-    tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12,
-           onp.complex128: 1e-12}
-    if jtu.device_under_test() == "tpu":
-      tol[onp.float32] = tol[onp.complex64] = 4e-2
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=True, tol=tol)
-    # XLA lacks int32/int64 MatMul kernels (b/168657656).
-    check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
-        (onp.int32, onp.int64))
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol,
-                          rtol=tol, check_incomplete_shape=True,
-                          check_experimental_compile=check_xla,
-                          check_xla_forced_compile=check_xla)
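testMatmul exercises the function form and both operator-protocol hooks at once; for NumPy arrays the three spellings are equivalent, which is exactly what the triple-result comparison relies on:

import numpy as np

x = np.arange(6.0).reshape(2, 3)
y = np.arange(12.0).reshape(3, 4)

a = np.matmul(x, y)
b = x.__matmul__(y)   # what `x @ y` dispatches to
c = y.__rmatmul__(x)  # the reflected hook, same product
print(np.array_equal(a, b) and np.array_equal(a, c))  # True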
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_{}_{}".format(
-          name,
-          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
-          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
-       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
-       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
-       "rng_factory": rng_factory}
-      for rng_factory in [jtu.rand_default]
-      for name, lhs_shape, rhs_shape in [
-          ("vector-vector", (3,), (3,)),
-          ("vector-matrix", (9,), (3, 3)),
-          ("matrix-matrix", (3, 3), (3, 3)),
-          ("tensor-vector", (5, 3, 2), (30,))]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
-  @new_test
-  def testVDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
-    tol = {onp.float16: 1e-2, onp.float32: 2e-2, onp.float64: 1e-12,
-           onp.complex128: 1e-12}
-    self._CheckAgainstNumpy(onp.vdot, lnp.vdot, args_maker,
-                            check_dtypes=True, tol=tol)
-    # XLA lacks int32/int64 MatMul kernels (b/168657656).
-    check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
-        (onp.int32, onp.int64))
-    self._CompileAndCheck(lnp.vdot, args_maker, check_dtypes=True, atol=tol,
-                          rtol=tol, check_incomplete_shape=True,
-                          check_experimental_compile=check_xla,
-                          check_xla_forced_compile=check_xla)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_{}_{}".format(
-          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
-          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
-          axes),
-       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
-       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
-       "axes": axes, "rng_factory": rng_factory}
-      for rng_factory in [jtu.rand_default]
-      for lhs_shape, rhs_shape, axes in [
-          [(2, 3, 4), (5, 6, 7), 0],  # from issue #740
-          [(2, 3, 4), (3, 4, 5, 6), 2],
-          [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
-          [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
-          [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
-      ]
-      for lhs_dtype, rhs_dtype in CombosWithReplacement(number_dtypes, 2)))
-  def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes,
-                    rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
-    lnp_fun = lambda a, b: lnp.tensordot(a, b, axes)
-    def onp_fun(a, b):
-      a = a if lhs_dtype != lnp.bfloat16 else a.astype(onp.float32)
-      b = b if rhs_dtype != lnp.bfloat16 else b.astype(onp.float32)
-      dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
-      return onp.tensordot(a, b, axes).astype(dtype)
-    tol = {onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 1e-12,
-           onp.complex64: 1e-3, onp.complex128: 1e-12}
-    if jtu.device_under_test() == "tpu":
-      tol[onp.float32] = tol[onp.complex64] = 2e-1
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True,
-                            tol=tol)
-    # XLA lacks int32/int64 MatMul kernels (b/168657656).
-    check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
-        (onp.int32, onp.int64))
-    tol = {onp.float64: 1e-14, onp.float16: 0.04, onp.complex128: 6e-15}
-    tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol))
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True,
-                          check_incomplete_shape=True,
-                          check_experimental_compile=check_xla,
-                          check_xla_forced_compile=check_xla,
-                          atol=tol, rtol=tol)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_{}".format(
-          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
-          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
-       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
-       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
-       "rng_factory": jtu.rand_default}
-      # TODO(phawkins): support integer dtypes too.
-      for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
-      for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes)
-      if len(jtu._dims_of_shape(lhs_shape)) == 0
-      or len(jtu._dims_of_shape(rhs_shape)) == 0
-      or lhs_shape[-1] == rhs_shape[-1]))
-  def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
-    def onp_fun(lhs, rhs):
-      lhs = lhs if lhs_dtype != lnp.bfloat16 else lhs.astype(onp.float32)
-      rhs = rhs if rhs_dtype != lnp.bfloat16 else rhs.astype(onp.float32)
-      dtype = lnp.promote_types(lhs_dtype, rhs_dtype)
-      return onp.inner(lhs, rhs).astype(dtype)
-    lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs)
-    tol_spec = {onp.float16: 1e-2, onp.float32: 1e-5, onp.float64: 2e-6}
-    if jtu.device_under_test() == "tpu":
-      tol_spec[onp.float32] = tol_spec[onp.complex64] = 2e-1
-    tol = max(jtu.tolerance(lhs_dtype, tol_spec),
-              jtu.tolerance(rhs_dtype, tol_spec))
-    # TODO(phawkins): there are float32/float64 disagreements for some inputs.
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False,
-                            tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False, atol=tol,
-                          rtol=tol, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_amin={}_amax={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
-       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
-       "rng_factory": jtu.rand_default}
-      for shape in all_shapes for dtype in minus(number_dtypes, complex_dtypes)
-      for a_min, a_max in [(-1, None), (None, 1), (-1, 1),
-                           (-onp.ones(1), None),
-                           (None, onp.ones(1)),
-                           (-onp.ones(1), onp.ones(1))]))
-  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
-    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
-    args_maker = lambda: [rng(shape, dtype)]
-    tol_spec = {onp.float64: 2e-7}
-    tol = jtu.tolerance(dtype, tol_spec)
-    is_x32_scalar = (dtype in [onp.int32, onp.float32] and
-                     shape in [jtu.NUMPY_SCALAR_SHAPE, ()])
-    # Turns check_dtypes off if is_x32_scalar is True because there is
-    # a weird promotion inconsistency in numpy:
-    # ```
-    # print(np.result_type(np.ones([], np.int32), 1))
-    # print(np.result_type(np.ones([1], np.int32), 1))
-    # print(np.result_type(np.int32(1), 1))
-    # print(np.result_type(np.int32, 1))
-    # print(np.result_type(np.ones([], np.float32), 1))
-    # print(np.result_type(np.ones([1], np.float32), 1))
-    # print(np.result_type(np.float32(1), 1))
-    # print(np.result_type(np.float32, 1))
-    # ```
-    # >>>
-    # int64
-    # int32
-    # int64
-    # int32
-    # float64
-    # float32
-    # float64
-    # float32
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=not is_x32_scalar, tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=not is_x32_scalar,
-                          atol=tol, rtol=tol, check_incomplete_shape=True)
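The clip cases above mix scalar bounds with length-1 array bounds; NumPy broadcasts `a_min`/`a_max` against the input, so both spellings clamp elementwise:

import numpy as np

x = np.array([-3.0, -0.5, 0.5, 3.0])
print(np.clip(x, -1, 1))                    # scalar bounds
print(np.clip(x, -np.ones(1), np.ones(1)))  # length-1 array bounds, broadcast
# Both print: [-1.  -0.5  0.5  1. ]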
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_amin={}_amax={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
-       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
-       "rng_factory": jtu.rand_default}
-      for shape in array_shapes + [jtu.NUMPY_SCALAR_SHAPE]
-      for dtype in minus(number_dtypes, complex_dtypes)
-      for a_min, a_max in [(-1, None), (None, 1), (-1, 1),
-                           (-onp.ones(1), None),
-                           (None, onp.ones(1)),
-                           (-onp.ones(1), onp.ones(1))]))
-  @new_test
-  def testClipAsMethodStaticBounds(
-      self, shape, dtype, a_min, a_max, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
-    lnp_fun = lambda x: lnp.asarray(x).clip(a_min=a_min, a_max=a_max)
-    args_maker = lambda: [rng(shape, dtype)]
-    tol_spec = {onp.float64: 2e-7}
-    tol = jtu.tolerance(dtype, tol_spec)
-    is_x32_scalar = (dtype in [onp.int32, onp.float32] and
-                     shape in [jtu.NUMPY_SCALAR_SHAPE, ()])
-    # Turns check_dtypes off if is_x32_scalar is True because there is
-    # a weird promotion inconsistency in numpy:
-    # ```
-    # print(np.result_type(np.ones([], np.int32), 1))
-    # print(np.result_type(np.ones([1], np.int32), 1))
-    # print(np.result_type(np.int32(1), 1))
-    # print(np.result_type(np.int32, 1))
-    # print(np.result_type(np.ones([], np.float32), 1))
-    # print(np.result_type(np.ones([1], np.float32), 1))
-    # print(np.result_type(np.float32(1), 1))
-    # print(np.result_type(np.float32, 1))
-    # ```
-    # >>>
-    # int64
-    # int32
-    # int64
-    # int32
-    # float64
-    # float32
-    # float64
-    # float32
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=not is_x32_scalar, tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=not is_x32_scalar,
-                          atol=tol, rtol=tol, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_decimals={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), decimals),
-       "shape": shape, "dtype": dtype, "decimals": decimals,
-       "rng_factory": jtu.rand_default}
-      for shape, dtype in _shape_and_dtypes(
-          all_shapes, minus(number_dtypes, complex_dtypes))
-      for decimals in [0, 1, -2]))
-  def testRoundStaticDecimals(self, shape, dtype, decimals, rng_factory):
-    rng = rng_factory()
-    if lnp.issubdtype(dtype, onp.integer) and decimals < 0:
-      self.skipTest("Integer rounding with decimals < 0 not implemented")
-    onp_fun = lambda x: onp.round(x, decimals=decimals)
-    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
-    args_maker = lambda: [rng(shape, dtype)]
-    tol = {
-        # TODO(b/154768983): lnp.bfloat16: 5e-2,
-        onp.float16: 1e-2}
-    check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=check_dtypes, tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=check_dtypes,
-                          atol=tol, rtol=tol, check_incomplete_shape=True)
-
-  def testOperatorRound(self):
-    self.assertAllClose(round(onp.float32(7.532), 1),
-                        round(lnp.float32(7.5), 1), check_dtypes=True)
-    self.assertAllClose(round(onp.float32(1.234), 2),
-                        round(lnp.float32(1.234), 2), check_dtypes=True)
-    self.assertAllClose(round(onp.float32(1.234)),
-                        round(lnp.float32(1.234)), check_dtypes=False)
-    self.assertAllClose(round(onp.float32(7.532), 1),
-                        round(lnp.array(7.5, lnp.float32), 1),
-                        check_dtypes=True)
-    self.assertAllClose(round(onp.float32(1.234), 2),
-                        round(lnp.array(1.234, lnp.float32), 2),
-                        check_dtypes=True)
-    self.assertAllClose(round(onp.float32(1.234)),
-                        round(lnp.array(1.234, lnp.float32)),
-                        check_dtypes=False)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), mode, pad_width_rank,
-          constant_values_rank),
-       "shape": shape, "dtype": dtype, "mode": mode,
-       "pad_width_rank": pad_width_rank,
-       "constant_values_rank": constant_values_rank,
-       "rng_factory": jtu.rand_default,
-       "irng_factory": partial(jtu.rand_int, 3)}
-      for mode, constant_values_rank, shapes in [
-          ('constant', 0, all_shapes),
-          ('constant', 1, all_shapes),
-          ('constant', 2, all_shapes),
-          ('symmetric', None, nonempty_shapes),
-          ('reflect', None, nonempty_shapes),
-          ('wrap', None, nonempty_shapes),
-      ]
-      for shape, dtype in _shape_and_dtypes(shapes, all_dtypes)
-      for pad_width_rank in range(3)))
-  @jtu.disable
-  def testPad(self, shape, dtype, mode, pad_width_rank, constant_values_rank,
-              rng_factory, irng_factory):
-    rng = rng_factory()
-    irng = irng_factory()
-    pad_width = irng([len(shape), 2][2 - pad_width_rank:], onp.int32)
-    def onp_fun(x, kwargs):
-      if pad_width.size == 0:
-        return x
-      return onp.pad(x, pad_width, mode=mode, **kwargs)
-    def lnp_fun(x, kwargs):
-      return lnp.pad(x, pad_width, mode=mode, **kwargs)
-
-    def args_maker():
-      kwargs = {}
-      if constant_values_rank:
-        kwargs["constant_values"] = rng(
-            [len(shape), 2][2 - constant_values_rank:], dtype)
-      return rng(shape, dtype), kwargs
-
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape=[{}]_reps={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), reps),
-       "shape": shape, "dtype": dtype, "reps": reps,
-       "rng_factory": jtu.rand_default}
-      for reps in [(), (2,), (3, 4), (2, 3, 4)]
-      for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
-      ))
-  def testTile(self, shape, dtype, reps, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda arg: onp.tile(arg, reps)
-    lnp_fun = lambda arg: lnp.tile(arg, reps)
-    args_maker = lambda: [rng(shape, dtype)]
-    tol_spec = {onp.float64: 2e-7}
-    tol = jtu.tolerance(dtype, tol_spec)
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker,
-                            check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE,
-                            tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol,
-                          rtol=tol)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
-          axis, ",".join(str(d) for d in base_shape),
-          ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)),
-       "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes,
-       "rng_factory": jtu.rand_default}
-      for num_arrs in [3]
-      for arg_dtypes in CombosWithReplacement(default_dtypes, num_arrs)
-      for base_shape in [(4,), (3, 4), (2, 3, 4)]
-      for axis in range(-len(base_shape)+1, len(base_shape))))
-  def testConcatenate(self, axis, base_shape, arg_dtypes, rng_factory):
-    rng = rng_factory()
-    wrapped_axis = axis % len(base_shape)
-    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
-              for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]
-    def onp_fun(*args):
-      # TODO(nareshmodi): enable once bfloat16 has better support
-      # args = [x if x.dtype != bfloat16 else x.astype(onp.float32)
-      #         for x in args]
-      dtype = functools.reduce(lnp.promote_types, arg_dtypes)
-      return onp.concatenate(args, axis=axis).astype(dtype)
-    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)
-
-    def args_maker():
-      return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
-
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
-          axis, ",".join(str(d) for d in base_shape),
-          ",".join(onp.dtype(dtype).name for dtype in arg_dtypes)),
-       "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes,
-       "rng_factory": jtu.rand_default}
-      for arg_dtypes in CombosWithReplacement(default_dtypes, 2)
-      for base_shape in [(4,), (3, 4), (2, 3, 4)]
-      for axis in range(-len(base_shape)+1, len(base_shape))))
-  def testAppend(self, axis, base_shape, arg_dtypes, rng_factory):
-    rng = rng_factory()
-    wrapped_axis = axis % len(base_shape)
-    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
-              for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]
-    def onp_fun(arr, values):
-      arr = arr.astype(onp.float32) if lnp.bfloat16 == arr.dtype else arr
-      values = (
-          values.astype(onp.float32)
-          if lnp.bfloat16 == values.dtype else values)
-      out = onp.append(arr, values, axis=axis)
-      return out.astype(lnp.promote_types(*arg_dtypes))
-    lnp_fun = lambda arr, values: lnp.append(arr, values, axis=axis)
-
-    def args_maker():
-      return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)]
-
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
-       "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
-       "rng_factory": jtu.rand_default}
-      for repeats in [0, 1, 2]
-      for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)
-      for axis in [None] + list(range(-len(shape), len(shape)))))
-  def testRepeat(self, axis, shape, dtype, repeats, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis)
-    onp_fun = _promote_like_lnp(onp_fun)
-    lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis)
-
-    args_maker = lambda: [rng(shape, dtype)]
-
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False)
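`repeat` accepts either one count for every element or one count per element along the axis; the regression test below (ported from NumPy's own `test_repeat`) runs through both forms, including zero counts:

import numpy as np

m = np.array([1, 2, 3])
print(np.repeat(m, 2))          # [1 1 2 2 3 3]
print(np.repeat(m, [1, 0, 2]))  # [1 3 3] -- per-element counts, 0 drops
m_rect = m.reshape((1, 3))
print(np.repeat(m_rect, 2, axis=0))  # [[1 2 3], [1 2 3]], stacked on axis 0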
-
-  def testIssue1233(self):
-    '''
-    Following numpy test suite from `test_repeat` at
-    https://github.com/numpy/numpy/blob/master/numpy/core/tests/test_multiarray.py
-    '''
-    tol = 1e-5
-
-    def test_single(m, args_maker, repeats, axis):
-      lax_ans = lnp.repeat(m, repeats, axis)
-      numpy_ans = onp.repeat(m, repeats, axis)
-
-      self.assertAllClose(lax_ans, numpy_ans, check_dtypes=True,
-                          rtol=tol, atol=tol)
-
-      lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis)
-      # Turns off XLA check because there are no XLA kernels for `Where` used
-      # by tf.repeat (b/169192730).
-      self._CompileAndCheck(
-          lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False,
-          check_experimental_compile=False, check_xla_forced_compile=False)
-
-    m = lnp.array([1, 2, 3, 4, 5, 6])
-    args_maker = lambda: [m]
-
-    for repeats in [2, [1, 3, 2, 1, 1, 2], [1, 3, 0, 1, 1, 2], [2],
-                    lnp.array([1, 3, 2, 1, 1, 2]), lnp.array([2])]:
-      test_single(m, args_maker, repeats, None)
-
-    m_rect = m.reshape((2, 3))
-    args_maker = lambda: [m_rect]
-
-    for repeats in [2, [2, 1], [2], lnp.array([2, 1]), lnp.array([2])]:
-      test_single(m_rect, args_maker, repeats, axis=0)
-
-    for repeats in [2, [1, 3, 2], [2], lnp.array([1, 3, 2]), lnp.array([2])]:
-      test_single(m_rect, args_maker, repeats, axis=1)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format(
-          op, jtu.format_shape_dtype_string(shape, dtype), axis, out_dtype),
-       "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
-       "rng_factory": jtu.rand_default, "lnp_op": getattr(lnp, op),
-       "onp_op": getattr(onp, op)}
-      for op in ["cumsum", "cumprod"]
-      for dtype in default_dtypes
-      for out_dtype in default_dtypes
-      for shape in all_shapes
-      for axis in [None] + list(range(-len(shape), len(shape)))))
-  def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op,
-                     rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda arg: onp_op(arg, axis=axis, dtype=out_dtype)
-    lnp_fun = lambda arg: lnp_op(arg, axis=axis, dtype=out_dtype)
-
-    args_maker = lambda: [rng(shape, dtype)]
-
-    tol = max(jtu.tolerance(dtype), jtu.tolerance(out_dtype))
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True,
-                            tol=tol)
-    # XLA lacks int64 Cumsum/Cumprod kernels (b/168841378).
-    check_xla = out_dtype != onp.int64
-    rtol = None
-    if out_dtype == onp.float16:
-      rtol = 2e-3
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, rtol=rtol,
-        check_incomplete_shape=True,
-        check_experimental_compile=check_xla,
-        check_xla_forced_compile=check_xla)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
-          onp.dtype(dtype).name, m, n, k),
-       "m": m, "n": n, "k": k, "dtype": dtype, "rng_factory": jtu.rand_default}
-      for dtype in default_dtypes
-      for n in [0, 4]
-      for m in [None, 0, 1, 3, 4]
-      for k in list(range(-4, 4))))
-  def testTri(self, m, n, k, dtype, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype)
-    lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype)
-    args_maker = lambda: []
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_op={}_shape={}_k={}".format(
-          op, jtu.format_shape_dtype_string(shape, dtype), k),
-       "dtype": dtype, "shape": shape, "op": op, "k": k,
-       "rng_factory": jtu.rand_default}
-      for dtype in default_dtypes
-      for shape in [shape for shape in all_shapes if len(shape) >= 2]
-      for op in ["tril", "triu"]
-      for k in list(range(-3, 3))))
-  def testTriLU(self, dtype, shape, op, k, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda arg: getattr(onp, op)(arg, k=k)
-    lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    # Incomplete shape support is not implemented at the moment.
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=False)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_ndim={}_n={}".format(ndim, n),
-       "ndim": ndim, "n": n}
-      for ndim in [0, 1, 4]
-      for n in [0, 1, 7]))
-  def testDiagIndices(self, ndim, n):
-    onp.testing.assert_equal(onp.diag_indices(n, ndim),
-                             lnp.diag_indices(n, ndim))
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_k={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), k),
-       "dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default}
-      for dtype in default_dtypes
-      for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
-      for k in list(range(-4, 4))))
-  def testDiag(self, shape, dtype, k, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda arg: onp.diag(arg, k)
-    lnp_fun = lambda arg: lnp.diag(arg, k)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
-          jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
-       "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1,
-       "axis2": axis2, "rng_factory": jtu.rand_default}
-      for dtype in default_dtypes
-      for shape in [shape for shape in all_shapes if len(shape) >= 2]
-      for axis1 in range(-len(shape), len(shape))
-      for axis2 in [a for a in range(-len(shape), len(shape))
-                    if a % len(shape) != axis1 % len(shape)]
-      for offset in list(range(-4, 4))))
-  def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
-    lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_dtype={}_n={}".format(onp.dtype(dtype).name, n),
-       "dtype": dtype, "n": n}
-      for dtype in default_dtypes
-      for n in list(range(4))))
-  def testIdentity(self, n, dtype):
-    onp_fun = lambda: onp.identity(n, dtype)
-    lnp_fun = lambda: lnp.identity(n, dtype)
-    args_maker = lambda: []
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
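The diag/diagonal/trace family here is parameterized over an `offset` (which k-th diagonal to take) plus a pair of axes; the invariant the reference implementations lean on is that `trace` is simply the sum of the selected diagonal:

import numpy as np

a = np.arange(9).reshape(3, 3)
print(np.diagonal(a, offset=1))  # [1 5], the first superdiagonal
print(np.trace(a, offset=1))     # 6
print(np.trace(a, offset=1) == np.diagonal(a, offset=1).sum())  # True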
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format(
-          jtu.format_shape_dtype_string(shape, dtype),
-          out_dtype, offset, axis1, axis2),
-       "dtype": dtype, "out_dtype": out_dtype, "shape": shape,
-       "offset": offset, "axis1": axis1, "axis2": axis2,
-       "rng_factory": jtu.rand_default}
-      for dtype in default_dtypes
-      for out_dtype in [None] + number_dtypes
-      for shape in [shape for shape in all_shapes if len(shape) >= 2]
-      for axis1 in range(-len(shape), len(shape))
-      for axis2 in range(-len(shape), len(shape))
-      if (axis1 % len(shape)) != (axis2 % len(shape))
-      for offset in list(range(-4, 4))))
-  def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2,
-                rng_factory):
-    rng = rng_factory()
-    def onp_fun(arg):
-      if out_dtype == lnp.bfloat16:
-        return onp.trace(arg, offset, axis1, axis2,
-                         onp.float32).astype(lnp.bfloat16)
-      else:
-        return onp.trace(arg, offset, axis1, axis2, out_dtype)
-    lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_axis={}".format(
-          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis),
-       "shape": shape, "axis": axis, "dtypes": dtypes,
-       "rng_factory": rng_factory}
-      for dtypes in [
-          [onp.float32],
-          [onp.float32, onp.float32],
-          [onp.float32, onp.int32, onp.float32],
-          [onp.float32, onp.int64, onp.float32],
-          [onp.float32, onp.int32, onp.float64],
-      ]
-      for shape in [(), (2,), (3, 4), (1, 100)]
-      for axis in range(-len(shape), len(shape) + 1)
-      for rng_factory in [jtu.rand_default]))
-  def testStack(self, shape, axis, dtypes, rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
-    onp_fun = _promote_like_lnp(partial(onp.stack, axis=axis))
-    lnp_fun = partial(lnp.stack, axis=axis)
-    self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_op={}_{}".format(
-          op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
-       "shape": shape, "op": op, "dtypes": dtypes, "rng_factory": rng_factory}
-      for op in ["hstack", "vstack", "dstack"]
-      for dtypes in [
-          [onp.float32],
-          [onp.float32, onp.float32],
-          [onp.float32, onp.int32, onp.float32],
-          [onp.float32, onp.int64, onp.float32],
-          [onp.float32, onp.int32, onp.float64],
-      ]
-      for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)]
-      for rng_factory in [jtu.rand_default]))
-  def testHVDStack(self, shape, op, dtypes, rng_factory):
-    rng = rng_factory()
-    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
-    onp_fun = _promote_like_lnp(getattr(onp, op))
-    lnp_fun = getattr(lnp, op)
-    self._CheckAgainstNumpy(lnp_fun, onp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_inshape={}_outdtype={}".format(
-          jtu.format_shape_dtype_string(shape, fill_value_dtype),
-          onp.dtype(out_dtype).name if out_dtype else "None"),
-       "shape": shape, "fill_value_dtype": fill_value_dtype,
-       "out_dtype": out_dtype, "rng_factory": jtu.rand_default}
-      for shape in array_shapes + [3, onp.array(7, dtype=onp.int32)]
-      for fill_value_dtype in default_dtypes
-      for out_dtype in [None] + default_dtypes))
-  def testFull(self, shape, fill_value_dtype, out_dtype, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
-    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
-    args_maker = lambda: [rng((), fill_value_dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(
-      jtu.cases_from_list(
-          {"testcase_name": ("_op={}_shape={}_dtype={}").format(
-              op, shape, dtype),
-           "onp_op": getattr(onp, op), "lnp_op": getattr(lnp, op),
-           "shape": shape, "dtype": dtype}
-          for op in ["zeros", "ones"]
-          for shape in [2, (), (2,), (3, 0),
-                        onp.array((4, 5, 6), dtype=onp.int32),
-                        onp.array(4, dtype=onp.int32)]
-          for dtype in all_dtypes))
-  def testZerosOnes(self, onp_op, lnp_op, shape, dtype):
-    rng = jtu.rand_default()
-    def args_maker(): return []
-    onp_op = partial(onp_op, shape, dtype)
-    lnp_op = partial(lnp_op, shape, dtype)
-    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
-          jtu.format_shape_dtype_string(shape, in_dtype),
-          onp.dtype(fill_value_dtype).name,
-          onp.dtype(out_dtype).name),
-       "shape": shape, "in_dtype": in_dtype,
-       "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype,
-       "rng_factory": jtu.rand_default}
-      for shape in array_shapes
-      for in_dtype in default_dtypes
-      for fill_value_dtype in default_dtypes
-      for out_dtype in default_dtypes))
-  def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype,
-                   rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype)
-    lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype)
-    args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_axis={}_{}sections".format(
-          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
-       "shape": shape, "num_sections": num_sections, "axis": axis,
-       "dtype": dtype, "rng_factory": jtu.rand_default}
-      for shape, axis, num_sections in [
-          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
-          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
-      for dtype in default_dtypes))
-  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
-    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_{}_axis={}_{}sections".format(
-          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
-       "shape": shape, "num_sections": num_sections, "axis": axis,
-       "dtype": dtype, "rng_factory": jtu.rand_default}
-      for shape, axis, num_sections in [
-          ((12, 4), 0, 4), ((12, 4), 1, 2),
-          ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]
-      for dtype in default_dtypes))
-  def testHVDSplit(self, shape, num_sections, axis, dtype, rng_factory):
-    rng = rng_factory()
-    def fn(module, axis):
-      if axis == 0:
-        return module.vsplit
-      elif axis == 1:
-        return module.hsplit
-      else:
-        assert axis == 2
-        return module.dsplit
-
-    onp_fun = lambda x: fn(onp, axis)(x, num_sections)
-    lnp_fun = lambda x: fn(lnp, axis)(x, num_sections)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
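The `fn` helper above reflects that `vsplit`/`hsplit`/`dsplit` are just `split` pinned to axes 0, 1, and 2 respectively:

import numpy as np

a = np.arange(12).reshape(3, 4)
print([p.shape for p in np.split(a, 2, axis=1)])  # [(3, 2), (3, 2)]
print([p.shape for p in np.hsplit(a, 2)])         # same result: axis 1
print([p.shape for p in np.vsplit(a, 3)])         # [(1, 4), (1, 4), (1, 4)]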
"dtype": dtype, - "order": order, "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for order in ["C", "F"] - for arg_shape, out_shape in [ - (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), - ((), (1, 1, 1)), - ((7, 0), (0, 42, 101)), - ((3, 4), 12), - ((3, 4), (12,)), - ((3, 4), -1), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ])) - def testReshape(self, arg_shape, out_shape, dtype, order, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.reshape(x, out_shape, order=order) - lnp_fun = lambda x: lnp.reshape(x, out_shape, order=order) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype)), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, - "rng_factory": jtu.rand_default} - for dtype in default_dtypes - for arg_shape, out_shape in [ - ((7, 0), (0, 42, 101)), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ])) - def testReshapeMethod(self, arg_shape, out_shape, dtype, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.reshape(x, out_shape) - lnp_fun = lambda x: x.reshape(*out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_expanddim={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), dim), - "arg_shape": arg_shape, "dtype": dtype, "dim": dim, - "rng_factory": jtu.rand_default} - for arg_shape in [(), (3,), (3, 4)] - for dtype in default_dtypes - for dim in range(-len(arg_shape)+1, len(arg_shape)))) - def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.expand_dims(x, dim) - lnp_fun = lambda x: lnp.expand_dims(x, dim) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_axes=({},{})".format( - jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), - "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2, - "rng_factory": jtu.rand_default} - for arg_shape, ax1, ax2 in [ - ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), - ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] - for dtype in default_dtypes)) - def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng_factory): - rng = rng_factory() - onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) - lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axes=({},{})".format( - jtu.format_shape_dtype_string(arg_shape, dtype), source, destination), - "arg_shape": arg_shape, "dtype": dtype, "source": source, - "destination": destination, "rng_factory": jtu.rand_default} - for arg_shape, source, destination 
-      for arg_shape, source, destination in [
-          (tuple(range(6)), (0, 2), (3, 5)),
-          (tuple(range(6)), (0, 2), (-1, -3)),
-          (tuple(range(6)), (-6, -4), (3, 5)),
-          (tuple(range(6)), (-6, -4), (-1, -3)),
-          (tuple(range(6)), 0, 4),
-          (tuple(range(6)), -6, -2),
-          (tuple(range(6)), tuple(range(6)), tuple(range(6))),
-          (tuple(range(6)), tuple(range(6)), tuple(reversed(range(6)))),
-          (tuple(range(6)), (), ()),
-      ]
-      for dtype in default_dtypes))
-  @new_test
-  def testMoveaxisStaticAxes(self, arg_shape, dtype, source, destination,
-                             rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda x: onp.moveaxis(x, source, destination)
-    lnp_fun = lambda x: lnp.moveaxis(x, source, destination)
-    args_maker = lambda: [rng(arg_shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_inshape={}_axis={}".format(
-          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
-       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
-       "rng_factory": jtu.rand_default}
-      for arg_shape, ax in [
-          ((3, 1), None),
-          ((3, 1), 1),
-          ((1, 3, 1), (0, 2)),
-          ((1, 4, 1), (0,))]
-      for dtype in default_dtypes))
-  def testSqueeze(self, arg_shape, dtype, ax, rng_factory):
-    rng = rng_factory()
-    onp_fun = lambda x: onp.squeeze(x, ax)
-    lnp_fun = lambda x: lnp.squeeze(x, ax)
-    args_maker = lambda: [rng(arg_shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
-    self._CompileAndCheck(
-        lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
-          jtu.format_shape_dtype_string(shape, dtype),
-          axis,
-          (None if weights_shape is None
-           else jtu.format_shape_dtype_string(weights_shape, dtype)),
-          returned),
-       "rng_factory": jtu.rand_default, "shape": shape, "dtype": dtype,
-       "axis": axis, "weights_shape": weights_shape, "returned": returned}
-      for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes)
-      for axis in set(range(-len(shape), len(shape))) | set([None])
-      # `weights_shape` is either `None`, same as the averaged axis, or same as
-      # that of the input
-      for weights_shape in ([None, shape] if axis is None or len(shape) == 1
-                            else [None, (shape[axis],), shape])
-      for returned in [False, True]))
-  def testAverage(self, shape, dtype, axis, weights_shape, returned,
-                  rng_factory):
-    rng = rng_factory()
-    if weights_shape is None:
-      onp_fun = lambda x: onp.average(x, axis, returned=returned)
-      lnp_fun = lambda x: lnp.average(x, axis, returned=returned)
-      args_maker = lambda: [rng(shape, dtype)]
-    else:
-      onp_fun = lambda x, weights: onp.average(x, axis, weights, returned)
-      lnp_fun = lambda x, weights: lnp.average(x, axis, weights, returned)
-      args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)]
-    onp_fun = _promote_like_lnp(onp_fun, inexact=True)
-    tol = {
-        # TODO(b/154768983): lnp.bfloat16: 1e-1,
-        onp.float16: 1e-1, onp.float32: 1e-3, onp.float64: 2e-7,
-        onp.complex64: 1e-3, onp.complex128: 1e-10,
-    }
-    check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
-    try:
-      self._CheckAgainstNumpy(
-          onp_fun, lnp_fun, args_maker, check_dtypes=check_dtypes, tol=tol)
-    except ZeroDivisionError:
-      self.skipTest("don't support checking for ZeroDivisionError")
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=check_dtypes,
-                          rtol=tol, atol=tol, check_incomplete_shape=True)
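Unlike `mean`, `average` takes weights and can optionally return the weight sum alongside the average; the `returned` flag above toggles the tuple form:

import numpy as np

x = np.array([1.0, 2.0, 3.0])
w = np.array([3.0, 1.0, 0.0])
print(np.average(x, weights=w))  # 1.25 == (1*3 + 2*1 + 3*0) / 4
avg, wsum = np.average(x, weights=w, returned=True)
print(avg, wsum)                 # 1.25 4.0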
{"testcase_name": "_arg{}_ndmin={}".format(i, ndmin), - "arg": arg, "ndmin": ndmin, "dtype": dtype} - for i, (arg, dtype) in enumerate([ - ([True, False, True], lnp.bool_), - (3., lnp.float_), - ([1, 2, 3], lnp.int_), - ([1., 2., 3.], lnp.float_), - ([[1, 2], [3, 4], [5, 6]], lnp.int_), - ([[1, 2.], [3, 4], [5, 6]], lnp.float_), - ([[1., 2j], [3., 4.], [5., 6.]], lnp.complex_), - ([[3, onp.array(2, dtype=lnp.float_), 1], - onp.arange(3., dtype=lnp.float_)], lnp.float_), - ]) - for ndmin in [None, onp.ndim(arg), onp.ndim(arg) + 1, onp.ndim(arg) + 2])) - def testArray(self, arg, ndmin, dtype): - args_maker = lambda: [arg] - dtype = lnp.canonicalize_dtype(dtype) - if ndmin is not None: - onp_fun = partial(onp.array, ndmin=ndmin, dtype=dtype) - lnp_fun = partial(lnp.array, ndmin=ndmin) - else: - onp_fun = partial(onp.array, dtype=dtype) - lnp_fun = lnp.array - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, - check_incomplete_shape=True, static_argnums=[0]) - - def testIssue121(self): - assert not onp.isscalar(lnp.array(3)) - - @jtu.disable - def testArrayMethod(self): - class arraylike(object): - dtype = onp.float32 - def __array__(self, dtype=None): - return 3. - a = arraylike() - ans = lnp.array(a) - assert ans == 3. - - @jtu.skip_on_devices("tpu") # TODO(b/32368900): TPUs don't support uint8 yet. - @jtu.disable - def testMemoryView(self): - ans = lnp.array(bytearray(b'\x2a')) - self.assertAllClose( - ans, - onp.array([0x2a], dtype=onp.uint8), - check_dtypes=True) - - def testAllClose(self): - rng = onp.random.RandomState(0) - x = rng.randn(2, 2) - y = rng.randn(2) - - def same(list1, list2): - allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3) - elements_close = list(map(allclose, list1, list2)) - return lnp.all(lnp.array(elements_close)) - - csame = npe.jit(same) - - a1 = same((x, y), (x, y)) - a2 = csame((x, y), (x, y)) - a3 = csame((x, y), (x, 2 * y)) - - self.assertTrue(a1) - self.assertTrue(a2) - self.assertFalse(a3) - - @jtu.skip_on_devices("tpu") # TODO(mattjj): investigate this failure - @jtu.disable - def testOnesBroadcastingConstantHandler(self): - # TODO(mattjj): update this test for jax3 - self.skipTest("test needs jax3 update") - - def fun(x): - ones = lnp.ones((3, 4)) - assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0) - - # To check that the constant handler generates a Broadcast for stride-zero - # arrays, we monkey-patch the client instance. - # TODO(mattjj): once we have better HLO dumping and inspecting facilities, - # we can check the HLO more directly. - c = x._node.c - Broadcast = c.Broadcast # pylint: disable=invalid-name - was_called = [] - c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args) - out = x + ones # the ndarray constant handler should call Broadcast here - assert was_called, "Broadcast was not called." - - return out - - fun = api.jit(fun) - out_val = fun(lnp.ones(4)) - self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False) - - def testZeroStridesConstantHandler(self): - raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1) - const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6)) - - def fun(x): - return x * const - - fun = npe.jit(fun) - out_val = fun(3.) - self.assertAllClose(out_val, 3. 
* const, check_dtypes=False) - - def testIsInstanceNdarrayDuringTracing(self): - arr = onp.ones(3) - - @npe.jit - def f(x): - self.assertIsInstance(x, lnp.ndarray) - return lnp.sum(x) - - f(arr) - - @jtu.disable - def testNonArrayErrorMessage(self): - x = [1., 2.] - y = onp.array([3., 4.]) - - def g(x, y): - return lnp.add(x, y) - - def f(x, y): - return lnp.dot(x, y) - - self.assertRaises(TypeError, lambda: g(x, y)) - self.assertRaises(TypeError, lambda: f(x, y)) - self.assertRaises(TypeError, lambda: api.jit(g)(x, y)) - self.assertRaises(TypeError, lambda: api.jit(f)(x, y)) - - @jtu.disable - def testAbstractionErrorMessage(self): - - @api.jit - def f(x, n): - for _ in range(n): - x = x * x - return x - - self.assertRaises(TypeError, lambda: f(3., 3)) - - @api.jit - def g(x): - if x > 0.: - return x * 2 - else: - return x + 2 - - self.assertRaises(TypeError, lambda: g(3.)) - - @jtu.disable - def testTracingPrimitiveWithNoTranslationErrorMessage(self): - # TODO(mattjj): update this for jax3 - self.skipTest("test needs jax3 update") - foo = lnp._not_implemented(lambda x: x) - - # No error if there's no tracing. - foo(onp.arange(3)) - - cfoo = api.jit(foo) - self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3))) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis} - for shape in [(3,), (2, 3)] - for dtype in default_dtypes - for axis in list(range(-len(shape), len(shape))) + [None] # Test negative axes - for rng_factory in [jtu.rand_default])) - def testFlip(self, shape, dtype, axis, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.flip(x, axis) - onp_op = lambda x: onp.flip(x, axis) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype} - for shape in [(3,), (2, 3), (3, 2, 4)] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testFlipud(self, shape, dtype, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.flipud(x) - onp_op = lambda x: onp.flipud(x) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype} - for shape in [(3, 2), (2, 3), (3, 2, 4)] - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testFliplr(self, shape, dtype, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.fliplr(x) - onp_op = lambda x: onp.fliplr(x) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_k={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), k, axes), - "rng_factory": rng_factory, "shape": 
shape, "dtype": dtype, "k": k, "axes": axes} - for shape, axes in [ - [(2, 3), (0, 1)], - [(2, 3), (1, 0)], - [(4, 3, 2), (0, 2)], - [(4, 3, 2), (2, 1)], - ] - for k in range(-3, 4) - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testRot90(self, shape, dtype, k, axes, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.rot90(x, k, axes) - onp_op = lambda x: onp.rot90(x, k, axes) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_k={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), k, axes), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k, - "axes": axes} - for shape, axes in [ - [(2, 3), (-2, -1)], - [(2, 3), (-2, 1)], - [(4, 3, 2), (-1, -2)], - [(4, 3, 2), (2, -2)], - ] - for k in range(-3, 4) - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - @new_test - # These tests are only added as a separate test from testRot90 since we would - # like to measure coverage directly against the existing baseline. Once we - # stop measuring that, we can combine this test with the above. - def testRot90Additional(self, shape, dtype, k, axes, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - lnp_op = lambda x: lnp.rot90(x, k, axes) - onp_op = lambda x: onp.rot90(x, k, axes) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - # TODO(mattjj): test infix operator overrides - - def testRavel(self): - rng = onp.random.RandomState(0) - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True, - check_incomplete_shape=True) - - def testAstype(self): - rng = onp.random.RandomState(0) - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - op = lambda x: x.astype(lnp.int32) - self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True) - self._CompileAndCheck( - op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - # TODO(mattjj): test other ndarray-like method overrides - - def testOnpMean(self): - # from https://github.com/google/jax/issues/125 - x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.) - ans = onp.mean(x) - self.assertAllClose(ans, onp.array(1./3), check_dtypes=False) - - @jtu.disable - def testArangeOnFloats(self): - # from https://github.com/google/jax/issues/145 - expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_) - ans = lnp.arange(0.0, 1.0, 0.1) - self.assertAllClose(expected, ans, check_dtypes=True) - - def testSortManually(self): - - def _test(*args, **kwargs): - - raw_ans = lnp.sort(*args, **kwargs) - fn_ans = npe.jit(lnp.sort, static_argnums=(1,))(*args, **kwargs) - expected = onp.sort(*args, **kwargs) - - self.assertAllClose(expected, raw_ans, check_dtypes=True) - self.assertAllClose(expected, fn_ans, check_dtypes=True) - - # manual tests for sort are nice because we don't have to worry about ties. - # lax.sort is tested combinatorially. 
- _test(onp.array([16, 15, 23, 42, 8, 4])) - _test(onp.array([[1, 4], [3, 1]]), None) - _test(onp.array([[1, 4], [3, 1]])) - _test(onp.array([[1, 4], [3, 1]]), 0) - - def testArgsortManually(self): - - def _test(*args, **kwargs): - - raw_ans = lnp.argsort(*args, **kwargs) - fn_ans = npe.jit(lnp.argsort, static_argnums=(1,))(*args, **kwargs) - expected = onp.argsort(*args, **kwargs) - - self.assertAllClose(expected, raw_ans, check_dtypes=True) - self.assertAllClose(expected, fn_ans, check_dtypes=True) - - _test(onp.array([16, 15, 23, 42, 8, 4])) - _test(onp.array([[16, 15, 23], [42, 8, 4]]), 0) - _test(onp.array([[16, 15, 23], [42, 8, 4]]), 1) - _test(onp.array([[16, 15, 23], [42, 8, 4]]), None) - _test(onp.array([[16, 15, 23], [42, 8, 4]])) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_shifts={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), - shifts, axis), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "shifts": shifts, - "axis": axis} - for dtype in all_dtypes - for shape in [(3, 4), (3, 4, 5), (7, 4, 0)] - for shifts, axis in [ - (3, None), - (1, 1), - ((3,), (0,)), - ((-2,), (-2,)), - ((1, 2), (0, -1)) - ] - for rng_factory in [jtu.rand_default])) - def testRoll(self, shape, dtype, shifts, axis, rng_factory): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype), onp.array(shifts)] - lnp_op = partial(lnp.roll, axis=axis) - onp_op = partial(onp.roll, axis=axis) - self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_index={}_axis={}_mode={}".format( - jtu.format_shape_dtype_string(shape, dtype), - jtu.format_shape_dtype_string(index_shape, index_dtype), - axis, mode), - "rng_factory": rng_factory, "rng_indices_factory": rng_indices_factory, - "shape": shape, "index_shape": index_shape, "dtype": dtype, - "index_dtype": index_dtype, "axis": axis, "mode": mode} - for shape in [(3,), (3, 4), (3, 4, 5)] - for index_shape in scalar_shapes + [(3,), (2, 1, 3)] - for axis in itertools.chain(range(-len(shape), len(shape)), [None]) - for dtype in all_dtypes - for index_dtype in int_dtypes - for mode in ['wrap', 'clip'] - for rng_factory in [jtu.rand_default] - for rng_indices_factory in [partial(jtu.rand_int, -5, 5)])) - def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode, - rng_factory, rng_indices_factory): - def args_maker(): - x = rng(shape, dtype) - i = rng_indices(index_shape, index_dtype) - return x, i - - rng = rng_factory() - rng_indices = rng_indices_factory() - lnp_op = lambda x, i: lnp.take(x, i, axis=axis, mode=mode) - onp_op = lambda x, i: onp.take(x, i, axis=axis, mode=mode) - self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_ishape={}_axis={}".format( - jtu.format_shape_dtype_string(x_shape, dtype), i_shape, axis), - "rng_factory": rng_factory, "x_shape": x_shape, "i_shape": i_shape, "dtype": dtype, - "axis": axis} - for x_shape, i_shape in filter( - _shapes_are_equal_length, - filter(_shapes_are_broadcast_compatible, - CombosWithReplacement(nonempty_nonscalar_array_shapes, 2))) - for axis in itertools.chain(range(len(x_shape)), [-1], [None]) - for dtype in default_dtypes - for rng_factory in [jtu.rand_default])) - def testTakeAlongAxis(self, 
x_shape, i_shape, dtype, axis, rng_factory): - rng = rng_factory() - i_shape = onp.array(i_shape) - if axis is None: - i_shape = [onp.prod(i_shape, dtype=onp.int64)] - else: - # Test the case where the size of the axis doesn't necessarily broadcast. - i_shape[axis] *= 3 - i_shape = list(i_shape) - def args_maker(): - x = rng(x_shape, dtype) - n = onp.prod(x_shape, dtype=onp.int32) if axis is None else x_shape[axis] - i = rng(i_shape, onp.int32) % (2 * n - 1) - (n - 1) - return x, i - - lnp_op = lambda x, i: lnp.take_along_axis(x, i, axis=axis) - - if hasattr(onp, "take_along_axis"): - onp_op = lambda x, i: onp.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True) - self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True, - check_incomplete_shape=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_n={}_increasing={}".format( - jtu.format_shape_dtype_string([shape], dtype), - n, increasing), - "dtype": dtype, "shape": shape, "n": n, "increasing": increasing, - "rng_factory": jtu.rand_default} - for dtype in inexact_dtypes - for shape in [0, 5] - for n in [2, 4] - for increasing in [False, True])) - def testVander(self, shape, dtype, n, increasing, rng_factory): - rng = rng_factory() - def onp_fun(arg): - arg = arg.astype(onp.float32) if dtype == lnp.bfloat16 else arg - return onp.vander(arg, N=n, increasing=increasing) - lnp_fun = lambda arg: lnp.vander(arg, N=n, increasing=increasing) - args_maker = lambda: [rng([shape], dtype)] - # np.vander seems to return float64 for all floating types. We could obey - # those semantics, but they seem like a bug. - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False, - tol={onp.float32: 1e-3}) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=False, check_incomplete_shape=True, - rtol={onp.complex128: 2e-15}) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("nan_to_num", [shape], - [dtype]), - "rng_factory": jtu.rand_some_inf_and_nan, "shape": shape, - "dtype": dtype} - for shape in all_shapes - for dtype in inexact_dtypes)) - @jtu.disable - def testNanToNum(self, rng_factory, shape, dtype): - rng = rng_factory() - dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type - def onp_fun(x): - if dtype == lnp.bfloat16: - x = onp.where(onp.isnan(x), dtype(0), x) - x = onp.where(onp.isposinf(x), lnp.finfo(dtype).max, x) - x = onp.where(onp.isneginf(x), lnp.finfo(dtype).min, x) - return x - else: - return onp.nan_to_num(x).astype(dtype) - - args_maker = lambda: [rng(shape, dtype)] - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(onp_fun, lnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - self._CompileAndCheck(lnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes), - "rng_factory": jtu.rand_default, "shapes": shapes, "dtypes": dtypes} - for shapes, dtypes in ( - ((), ()), - (((7,),), (onp.int32,)), - (((3,), (4,)), (onp.int32, onp.int32)), - (((3,), (1,), (4,)), (onp.int32, onp.int32, onp.int32)), - ))) - def testIx_(self, rng_factory, shapes, dtypes): - rng = rng_factory() - args_maker = lambda: [rng(shape, dtype) - for shape, dtype in zip(shapes, dtypes)] - self._CheckAgainstNumpy(onp.ix_, lnp.ix_, args_maker, - check_dtypes=True) - self._CompileAndCheck( - lnp.ix_, args_maker, check_dtypes=True, check_incomplete_shape=True) - - 
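(Editorial aside: `testIx_` above exercises open-mesh indexing. A minimal sketch, in plain NumPy, of what `ix_` constructs and why the output shapes matter.)

    import numpy as np

    a = np.arange(12).reshape(3, 4)
    rows, cols = np.ix_([0, 2], [1, 3])   # shapes (2, 1) and (1, 2)
    # The open mesh broadcasts to the cross product of the index lists.
    assert np.array_equal(a[rows, cols], [[1, 3], [9, 11]])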
@named_parameters(jtu.cases_from_list(
-      {"testcase_name":
-       "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}".format(
-           op,
-           jtu.format_shape_dtype_string(a_shape, a_dtype),
-           jtu.format_shape_dtype_string(q_shape, q_dtype),
-           axis, keepdims),
-       "a_rng": jtu.rand_default(), "q_rng": q_rng, "op": op,
-       "a_shape": a_shape, "a_dtype": a_dtype,
-       "q_shape": q_shape, "q_dtype": q_dtype, "axis": axis,
-       "keepdims": keepdims}
-      for (op, q_rng) in (
-          ("percentile", jtu.rand_uniform(low=0., high=100.)),
-          ("quantile", jtu.rand_uniform(low=0., high=1.)),
-          ("median", jtu.rand_uniform(low=0., high=1.)),
-      )
-      for a_dtype in float_dtypes
-      for a_shape, axis in (
-          ((7,), None),
-          ((47, 7), 0),
-          ((4, 101), 1),
-      )
-      for q_dtype in [onp.float32]
-      for q_shape in scalar_shapes + [(4,)]
-      for keepdims in [False, True]))
-  @jtu.disable
-  def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype,
-                   axis, keepdims):
-    if op == "quantile" and numpy_version < (1, 15):
-      raise SkipTest("Numpy < 1.15 does not have np.quantile")
-    if op == "median":
-      args_maker = lambda: [a_rng(a_shape, a_dtype)]
-    else:
-      args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)]
-
-    def onp_fun(*args):
-      args = [x if lnp.result_type(x) != lnp.bfloat16 else
-              onp.asarray(x, onp.float32) for x in args]
-      return getattr(onp, op)(*args, axis=axis, keepdims=keepdims)
-    lnp_fun = partial(getattr(lnp, op), axis=axis, keepdims=keepdims)
-    # TODO(phawkins): we currently set dtype=False because we aren't as
-    # aggressive about promoting to float64. It's not clear we want to mimic
-    # Numpy here.
-    tol_spec = {onp.float32: 2e-4, onp.float64: 5e-6}
-    tol = max(jtu.tolerance(a_dtype, tol_spec),
-              jtu.tolerance(q_dtype, tol_spec))
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False,
-                            tol=tol)
-    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol)
-
-
-  @named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_shape={}".format(
-          jtu.format_shape_dtype_string(shape, dtype)),
-       "shape": shape, "dtype": dtype}
-      for shape in all_shapes for dtype in all_dtypes))
-  def testWhereOneArgument(self, shape, dtype):
-    rng = jtu.rand_some_zero()
-    onp_fun = lambda x: onp.where(x)
-    lnp_fun = lambda x: lnp.where(x)
-    args_maker = lambda: [rng(shape, dtype)]
-    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False)
-    # Turns off the XLA check because there are no XLA kernels for `Where`,
-    # which XLA can't support because its output shape is dynamic.
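(Editorial aside on the comment above: one-argument `where` is shape-dynamic because the number of nonzero elements is data-dependent, as this plain-NumPy sketch shows.)

    import numpy as np

    idxs, = np.where(np.array([0, 3, 0, 7]))
    assert np.array_equal(idxs, [1, 3])   # output length depends on the data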
- self._CompileAndCheck( - lnp.where, - args_maker, - check_dtypes=True, - check_eval_on_shapes=False, - check_incomplete_shape=True, - check_unknown_rank=False, - check_experimental_compile=False, check_xla_forced_compile=False) - - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format("_".join( - jtu.format_shape_dtype_string(shape, dtype) - for shape, dtype in zip(shapes, dtypes))), - "rng_factory": jtu.rand_default, "shapes": shapes, "dtypes": dtypes} - for shapes in filter(_shapes_are_broadcast_compatible, - CombosWithReplacement(all_shapes, 3)) - for dtypes in CombosWithReplacement(all_dtypes, 3))) - def testWhereThreeArgument(self, rng_factory, shapes, dtypes): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng_factory(), shapes, dtypes) - def onp_fun(cond, x, y): - return _promote_like_lnp(partial(onp.where, cond))(x, y) - self._CheckAgainstNumpy(onp_fun, lnp.where, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp.where, args_maker, check_dtypes=True, check_incomplete_shape=True) - - def testWhereScalarPromotion(self): - x = lnp.where(lnp.array([True, False]), 3, - lnp.ones((2,), dtype=lnp.float32)) - self.assertEqual(x.dtype, onp.dtype(onp.float32)) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, - (onp.bool_,) * n + dtypes), - "rng_factory": jtu.rand_default, "shapes": shapes, "dtypes": dtypes} - for n in range(0, 3) - for shapes in filter( - _shapes_are_broadcast_compatible, - CombosWithReplacement(all_shapes, 2 * n + 1)) - for dtypes in CombosWithReplacement(all_dtypes, n + 1))) - def testSelect(self, rng_factory, shapes, dtypes): - rng = rng_factory() - n = len(dtypes) - 1 - def args_maker(): - condlist = [rng(shape, onp.bool_) for shape in shapes[:n]] - choicelist = [rng(shape, dtype) - for shape, dtype in zip(shapes[n:-1], dtypes[:n])] - default = rng(shapes[-1], dtypes[-1]) - return condlist, choicelist, default - # TODO(phawkins): float32/float64 type mismatches - def onp_fun(condlist, choicelist, default): - choicelist = [x if lnp.bfloat16 != lnp.result_type(x) - else x.astype(onp.float32) for x in choicelist] - dtype = lnp.result_type(default, *choicelist).as_numpy_dtype - return onp.select(condlist, - [onp.asarray(x, dtype=dtype) for x in choicelist], - onp.asarray(default, dtype=dtype)) - self._CheckAgainstNumpy(onp_fun, lnp.select, args_maker, - check_dtypes=False) - self._CompileAndCheck(lnp.select, args_maker, check_dtypes=True, - check_incomplete_shape=True, - rtol={onp.float64: 1e-7, onp.complex128: 1e-7}) - - - @jtu.disable - def testIssue330(self): - x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash - self.assertEqual(x[0, 0], 1) - - @jtu.disable - def testScalarDtypePromotion(self): - orig_numpy_result = (1 + onp.eye(1, dtype=onp.float32)).dtype - jax_numpy_result = (1 + lnp.eye(1, dtype=lnp.float32)).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - @jtu.disable - def testSymmetrizeDtypePromotion(self): - x = onp.eye(3, dtype=onp.float32) - orig_numpy_result = ((x + x.T) / 2).dtype - - x = lnp.eye(3, dtype=lnp.float32) - jax_numpy_result = ((x + x.T) / 2).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - @jtu.disable - def testIssue347(self): - # https://github.com/google/jax/issues/347 - def test_fail(x): - x = lnp.sqrt(lnp.sum(x ** 2, axis=1)) - ones = lnp.ones_like(x) - x = lnp.where(x > 0.5, x, ones) - return lnp.sum(x) - - x = lnp.array([[1, 2], [3, 4], [0, 0]], dtype=lnp.float64) - result = api.grad(test_fail)(x) - assert 
not onp.any(onp.isnan(result)) - - def testIssue453(self): - # https://github.com/google/jax/issues/453 - a = onp.arange(6) + 1 - ans = lnp.reshape(a, (3, 2), order='F') - expected = onp.reshape(a, (3, 2), order='F') - self.assertAllClose(ans, expected, check_dtypes=True) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}".format(op, pytype.__name__), - "pytype": pytype, "dtype": dtype, "op": op} - for pytype, dtype in [(int, lnp.int_), (float, lnp.float_), - (bool, lnp.bool_), (complex, lnp.complex_)] - for op in ["atleast_1d", "atleast_2d", "atleast_3d"])) - def testAtLeastNdLiterals(self, pytype, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 - onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype) - lnp_fun = lambda arg: getattr(lnp, op)(arg) - args_maker = lambda: [pytype(2)] - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) - - - def testLongLong(self): - self.assertAllClose( - onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True) - - def testArange(self): - # test cases inspired by dask tests at - # https://github.com/dask/dask/blob/master/dask/array/tests/test_creation.py#L92 - self.assertAllClose(lnp.arange(77), - onp.arange(77, dtype=lnp.int_), check_dtypes=True) - self.assertAllClose(lnp.arange(2, 13), - onp.arange(2, 13, dtype=lnp.int_), check_dtypes=True) - self.assertAllClose(lnp.arange(4, 21, 9), - onp.arange(4, 21, 9, dtype=lnp.int_), check_dtypes=True) - self.assertAllClose(lnp.arange(53, 5, -3), - onp.arange(53, 5, -3, dtype=lnp.int_), - check_dtypes=True) - # TODO(mattjj): make these tests work when enable_x64=True - self.assertAllClose( - lnp.arange(77, dtype=float), - onp.arange(77, dtype=float), - check_dtypes=True) - self.assertAllClose( - lnp.arange(2, 13, dtype=int), - onp.arange(2, 13, dtype=int), - check_dtypes=True) - self.assertAllClose(lnp.arange(0, 1, -0.5), - onp.arange(0, 1, -0.5, dtype=lnp.float_), - check_dtypes=True) - - self.assertRaises(TypeError, lambda: lnp.arange()) - - # # The following have been disabled since they test JAX specific behavior - # # test that lnp.arange(N) doesn't instantiate an ndarray - # self.assertFalse(type(lnp.arange(77)) == type(onp.arange(77))) - # self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77))) - - # # test that lnp.arange(N, dtype=int32) doesn't instantiate an ndarray - # self.assertFalse(type(lnp.arange(77, dtype=lnp.int32)) == - # type(onp.arange(77, dtype=onp.int32))) - # self.assertTrue(type(lnp.arange(77, dtype=lnp.int32)) == - # type(lax.iota(onp.int32, 77))) - - def testIssue830(self): - a = lnp.arange(4, dtype=lnp.complex64) - self.assertEqual(a.dtype, lnp.complex64) - - def testIssue728(self): - assert lnp.allclose(lnp.eye(5000), onp.eye(5000)) - self.assertEqual(0, onp.sum(lnp.eye(1050) - onp.eye(1050))) - - def testIssue746(self): - lnp.arange(12).reshape(3, 4) # doesn't crash - - def testIssue764(self): - x = lnp.linspace(190, 200, 4) - f = npe.grad(lambda x: lnp.sum(lnp.tanh(x))) - # Expected values computed with autograd in float64 precision. 
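(Editorial aside: a quick order-of-magnitude check, in plain NumPy, of why the float64 reference values below are so small.)

    import numpy as np

    # d/dx sum(tanh(x)) = 1 - tanh(x)**2 = sech(x)**2, which is approximately
    # 4 * exp(-2 * x) for large x, so the gradient underflows rapidly past x ~ 190.
    g = 4.0 * np.exp(-2.0 * 190.0)
    assert 1e-166 < g < 1e-164   # ~3.72e-165, matching the first value below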
- expected = onp.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171, - 7.66067839e-174], onp.float64) - self.assertAllClose(f(x), expected, check_dtypes=False) - - @jtu.disable - def testIssue776(self): - """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" - def f(u): - y = onp.ones(10,).at[[2, 4, 5]].add(u) - # The transpose rule for lax.tie_in returns a symbolic zero for its first - # argument. - return lax.tie_in(y, 7.) - - self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)), - check_dtypes=True) - - @jtu.disable - def testIssue777(self): - x = lnp.linspace(-200, 0, 4, dtype=onp.float32) - f = npe.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x)))) - self.assertAllClose(f(x), onp.array([0., 0., 0., 0.25], dtype=onp.float32), - check_dtypes=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]), - "dtype": dtype, "op": op} - for dtype in float_dtypes - for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", - "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", - "log", "expm1", "log1p"))) - def testMathSpecialFloatValues(self, op, dtype): - onp_op = getattr(onp, op) - lnp_op = getattr(lnp, op) - dtype = onp.dtype(lnp.canonicalize_dtype(dtype)).type - for x in (onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf, - lnp.finfo(dtype).max, onp.sqrt(lnp.finfo(dtype).max), - onp.sqrt(lnp.finfo(dtype).max) * 2.): - if (op in ("sin", "cos", "tan", "arctan") and - jtu.device_under_test() == "tpu"): - continue # TODO(b/132196789, b/134175194): fix and reenable. - # TODO(b/158006398): fix and reenable. - if (op in ("cosh", "arccosh", "arcsinh", "arcsin", "sinh", "arccos", - "arctan", "arctanh") and dtype == onp.float16): - continue - x = dtype(x) - expected = onp_op(x) - actual = lnp_op(x) - tol = jtu.tolerance(dtype, {onp.float32: 1e-3, onp.float64: 1e-7}) - self.assertAllClose(expected, actual, check_dtypes=True, atol=tol, - rtol=tol) - - def testIssue883(self): - # from https://github.com/google/jax/issues/883 - - @partial(npe.jit, static_argnums=(1,)) - def f(x, v): - return x - - x = lnp.ones((10, 10)) - v = lnp.array([1, 2, 3]) - first_call = f(x, v) - second_call = f(x, v) # doesn't crash - - def testReductionOfOutOfBoundsAxis(self): # Issue 888 - x = lnp.ones((3, 4)) - self.assertRaises( - tf.errors.InvalidArgumentError, lambda: lnp.sum(x, axis=2)) - - @jtu.disable - def testIssue956(self): - self.assertRaises(TypeError, lambda: lnp.ndarray((1, 1))) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype, out_dtype, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims, "rng_factory": rng_factory} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True] - for rng_factory in [jtu.rand_default])) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - def onp_fun(x): - out = onp.var(x.astype(lnp.promote_types(onp.float32, dtype)), - axis=axis, ddof=ddof, keepdims=keepdims) - return out.astype(out_dtype) - lnp_fun = partial(lnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {onp.float16: 1e-1, onp.float32: 1e-3, - 
onp.float64: 1e-3, onp.complex128: 1e-6}) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True, - tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol, - atol=tol, check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( - shape, dtype, rowvar, ddof, bias), - "shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof, - "bias": bias, "rng_factory": rng_factory} - for shape in [(5,), (10, 5), (5, 10)] - for dtype in all_dtypes - for rowvar in [True, False] - for bias in [True, False] - for ddof in [None, 2, 3] - for rng_factory in [jtu.rand_default])) - @jtu.skip_on_devices("gpu") # TODO(b/138003641): test fails on GPU. - @jtu.disable - def testCov(self, shape, dtype, rowvar, ddof, bias, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - onp_fun = partial(onp.cov, rowvar=rowvar, ddof=ddof, bias=bias) - lnp_fun = partial(lnp.cov, rowvar=rowvar, ddof=ddof, bias=bias) - tol = {onp.float32: 1e-5, onp.float64: 1e-13, onp.complex128: 1e-13} - tol = 7e-2 if jtu.device_under_test() == "tpu" else tol - tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) - self._CheckAgainstNumpy( - onp_fun, lnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, atol=tol, - rtol=tol) - - def testIssue967(self): - self.assertRaises(TypeError, lambda: lnp.zeros(1.5)) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_rowvar={}_ddof={}_bias={}".format( - shape, dtype, rowvar, ddof, bias), - "shape": shape, "dtype": dtype, "rowvar": rowvar, "ddof": ddof, - "bias": bias, "rng_factory": rng_factory} - for shape in [(5,), (10, 5), (3, 10)] - for dtype in number_dtypes - for rowvar in [True, False] - for bias in [True, False] - for ddof in [None, 2, 3] - for rng_factory in [jtu.rand_default])) - @jtu.disable - def testCorrCoef(self, shape, dtype, rowvar, ddof, bias, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - mat = onp.asarray([rng(shape, dtype)]) - onp_fun = partial(onp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) - lnp_fun = partial(lnp.corrcoef, rowvar=rowvar, ddof=ddof, bias=bias) - if not onp.any(onp.isclose(onp.std(mat), 0.0)): - self._CheckAgainstNumpy( - onp_fun, lnp_fun, args_maker, check_dtypes=False, - tol=1e-2 if jtu.device_under_test() == "tpu" else None) - self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) - - @named_parameters( - jtu.cases_from_list( - { - "testcase_name": - "_shapes={}_dtype={}_indexing={}_sparse={}".format( - shapes, jtu.dtype_str(dtype), indexing, sparse), - "shapes": - shapes, - "dtype": - dtype, - "indexing": - indexing, - "sparse": - sparse, - "rng_factory": - rng_factory - } for shapes in [(), (5,), (5, 3)] for dtype in number_dtypes - for indexing in ["xy", "ij"] - for sparse in [False] # TODO(nareshmodi): Make sparse work - for rng_factory in [jtu.rand_default])) - def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], - [dtype] * len(shapes)) - onp_fun = partial(onp.meshgrid, indexing=indexing, sparse=sparse) - lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse) - self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True) 
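(Editorial aside: a minimal plain-NumPy sketch of the `indexing` convention covered above: 'xy' gives Cartesian layouts, 'ij' gives matrix layouts, and in 2-D the outputs are transposes of each other.)

    import numpy as np

    x, y = np.arange(3), np.arange(5)
    xy0, xy1 = np.meshgrid(x, y, indexing='xy')   # shapes (5, 3)
    ij0, ij1 = np.meshgrid(x, y, indexing='ij')   # shapes (3, 5)
    assert xy0.shape == (5, 3) and ij0.shape == (3, 5)
    assert np.array_equal(xy0, ij0.T) and np.array_equal(xy1, ij1.T)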
- - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_retstep={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, retstep, dtype), - "start_shape": start_shape, "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "retstep": retstep, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for retstep in [True, False] - for dtype in number_dtypes + [None,] - for rng_factory in [jtu.rand_default])) - def testLinspace(self, start_shape, stop_shape, num, endpoint, - retstep, dtype, rng_factory): - if not endpoint and onp.issubdtype(dtype, onp.integer): - # TODO(b/157597565): Support all dtypes when the tf op supports endpoint - # Currently, subtracting the step early leads to rounding errors for - # integers. - return - rng = rng_factory() - # relax default tolerances slightly - tol = jtu.tolerance(dtype if dtype else onp.float32) * 10 - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(onp.shape(start + stop)) - for axis in range(-ndim, ndim): - lnp_op = lambda start, stop: lnp.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - onp_op = lambda start, stop: onp.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, - check_dtypes=False, tol=tol) - # floating-point compute between jitted platforms and non-jit + rounding - # cause unavoidable variation in integer truncation for some inputs. - if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(lnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol, - check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_base={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, base, - dtype.__name__ if dtype else "None"), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "base": base, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for base in [10.0, 2, onp.e] - for dtype in inexact_dtypes + [None,] - for rng_factory in [jtu.rand_default])) - def testLogspace(self, start_shape, stop_shape, num, - endpoint, base, dtype, rng_factory): - if (dtype in int_dtypes and - jtu.device_under_test() in ("gpu", "tpu") and - not FLAGS.enable_x64): - raise unittest.SkipTest("GPUx32 truncated exponentiation" - " doesn't exactly match other platforms.") - rng = rng_factory() - # relax default tolerances slightly - tol = {onp.float16: 2e-2, onp.float32: 1e-2, onp.float64: 1e-6, - onp.complex64: 1e-3, onp.complex128: 1e-6} - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(onp.shape(start + stop)) - for axis in range(-ndim, ndim): - lnp_op = lambda start, stop: lnp.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - onp_op = lambda start, stop: onp.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes 
+ [None,]): - # Why do compiled and op-by-op float16 np.power numbers differ - # slightly more than expected? - atol = {onp.float16: 1e-2} - self._CompileAndCheck(lnp_op, args_maker, - check_dtypes=False, atol=atol, rtol=tol, - check_incomplete_shape=True) - - @named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_dtype={}").format( - start_shape, stop_shape, num, endpoint, dtype), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, - "dtype": dtype, "rng_factory": rng_factory} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - # NB: numpy's geomspace gives nonsense results on integer types - for dtype in inexact_dtypes + [None,] - for rng_factory in [jtu.rand_default])) - def testGeomspace(self, start_shape, stop_shape, num, - endpoint, dtype, rng_factory): - rng = rng_factory() - # relax default tolerances slightly - tol = {onp.float16: 4e-3, onp.float32: 2e-3, onp.complex128: 1e-14} - def args_maker(): - """Test the set of inputs onp.geomspace is well-defined on.""" - start, stop = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype])() - # onp.geomspace can't handle differently ranked tensors - # w. negative numbers! - start, stop = lnp.broadcast_arrays(start, stop) - if dtype in complex_dtypes: - return start, stop - # to avoid NaNs, non-complex start and stop cannot - # differ in sign, elementwise - start = start * lnp.sign(start) * lnp.sign(stop) - return start, stop - start, stop = args_maker() - ndim = len(onp.shape(start + stop)) - for axis in range(-ndim, ndim): - def lnp_op(start, stop): - return lnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, - axis=axis) - def onp_op(start, stop): - start = start.astype(onp.float32) if dtype == lnp.bfloat16 else start - stop = stop.astype(onp.float32) if dtype == lnp.bfloat16 else stop - return onp.geomspace( - start, stop, num, endpoint=endpoint, - dtype=dtype if dtype != lnp.bfloat16 else onp.float32, - axis=axis).astype(dtype) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(lnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol, - check_incomplete_shape=True) - - @jtu.disable - def testDisableNumpyRankPromotionBroadcasting(self): - try: - prev_flag = FLAGS.jax_numpy_rank_promotion - FLAGS.jax_numpy_rank_promotion = "allow" - lnp.ones(2) + lnp.ones((1, 2)) # works just fine - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - try: - prev_flag = FLAGS.jax_numpy_rank_promotion - FLAGS.jax_numpy_rank_promotion = "raise" - self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2))) - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - try: - prev_flag = FLAGS.jax_numpy_rank_promotion - FLAGS.jax_numpy_rank_promotion = "warn" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - lnp.ones(2) + lnp.ones((1, 2)) - assert len(w) > 0 - msg = str(w[-1].message) - expected_msg = ("Following NumPy automatic rank promotion for add on " - "shapes (2,) (1, 2).") - self.assertEqual(msg[:len(expected_msg)], expected_msg) - - prev_len = len(w) - lnp.ones(2) + 3 - self.assertEqual(len(w), prev_len) # don't want to warn for scalars - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 
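(Editorial aside: the regression is about passing a single traced array where a sequence of arrays is expected; plain NumPy accepts that too, by iterating the array, as this sketch shows.)

    import numpy as np

    # np.stack over a 1-D array iterates its scalars, so a lone array works
    # where a list of arrays is expected; the jitted wrappers below must
    # likewise accept a traced array instead of a Python sequence.
    assert np.array_equal(np.stack(np.zeros(2)), np.zeros(2))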
- @npe.jit - def foo(x): - return lnp.stack(x) - foo(onp.zeros(2)) # doesn't crash - - @npe.jit - def foo(x): - return lnp.concatenate(x) - foo(onp.zeros((2, 2))) # doesn't crash - - @jtu.disable - def testReluGradientConstants(self): - # This is a regression test that verifies that constants associated with the - # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the - # outermost jaxpr. This was producing some large materialized constants for - # every relu activation in a model. - def body(i, xy): - x, y = xy - y = y + jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.)))(x) - return x, y - - f = lambda y: lax.fori_loop(0, 5, body, (y, y)) - wrapped = linear_util.wrap_init(f) - pv = partial_eval.PartialVal( - (jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit)) - _, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv]) - self.assertFalse( - any(onp.array_equal(x, onp.full((3, 4), 2., dtype=onp.float32)) - for x in consts)) - - @named_parameters( - {"testcase_name": "_from={}_to={}".format(from_shape, to_shape), - "rng_factory": rng_factory, "from_shape": from_shape, "to_shape": to_shape} - for from_shape, to_shape in [ - [(1, 3), (4, 3)], - [(3,), (2, 1, 3)], - [(3,), (3, 3)], - [(1,), (3,)], - ] - for rng_factory in [jtu.rand_default]) - def testBroadcastTo(self, from_shape, to_shape, rng_factory): - rng = rng_factory() - args_maker = self._GetArgsMaker(rng, [from_shape], [onp.float32]) - onp_op = lambda x: onp.broadcast_to(x, to_shape) - lnp_op = lambda x: lnp.broadcast_to(x, to_shape) - self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) - self._CompileAndCheck( - lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True) - - def testBroadcastToIssue1522(self): - self.assertRaisesRegex( - Exception, "Unable to broadcast", - lambda: lnp.broadcast_to(onp.ones((2, 3)), (1, 3))) - - def testBroadcastToIntIssue1548(self): - self.assertAllClose(lnp.broadcast_to(1, (3, 2)), onp.ones((3, 2)), - check_dtypes=False) - - def testBroadcastToOnScalar(self): - self.assertIsInstance(lnp.broadcast_to(10.0, ()), lnp.ndarray) - self.assertIsInstance(onp.broadcast_to(10.0, ()), onp.ndarray) - - @jtu.disable - def testPrecision(self): - - ones_1d = onp.ones((2,)) - ones_2d = onp.ones((2, 2)) - ones_3d = onp.ones((2, 2, 2)) - HIGHEST = lax.Precision.HIGHEST - - jtu.assert_dot_precision(None, lnp.dot, ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.dot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.dot, precision=HIGHEST), - ones_3d, ones_3d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.matmul, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.vdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.tensordot, axes=2, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.tensordot, axes=(0, 0), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.einsum, 'i,i', precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.einsum, 'ij,ij', precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(lnp.inner, precision=HIGHEST), - ones_1d, ones_1d) - - @named_parameters(jtu.cases_from_list( - {"testcase_name": - "_{}_{}_{}_{}".format( - 
shape, jtu.dtype_str(key_dtype), jtu.dtype_str(value_dtype), - dimension).replace(" ", ""), - "shape": shape, "key_dtype": key_dtype, "value_dtype": value_dtype, - "dimension": dimension, "rng_factory": rng_factory} - for shape in all_shapes - for key_dtype in minus(number_dtypes, complex_dtypes) - for value_dtype in all_dtypes - for dimension in range(-len(shape), len(shape)) - for rng_factory in [jtu.rand_default])) - @new_test - def testSortKeyValue(self, shape, key_dtype, value_dtype, dimension, - rng_factory): - def onp_ref(keys, values): - idxs = list(onp.ix_(*[onp.arange(d) for d in keys.shape])) - idxs[dimension] = onp.argsort(keys, axis=dimension) - return keys[tuple(idxs)], values[tuple(idxs)] - rng = rng_factory() - args_maker = self._GetArgsMaker( - rng, [shape, shape], [key_dtype, value_dtype]) - op = partial(npe.sort_key_val, dimension=dimension) - self._CheckAgainstNumpy(onp_ref, op, args_maker, - check_dtypes=True) - # sort_key_val requires known rank. - # XLA only has TopKV2 (used by tf.argsort) kernels on those dtypes - # (b/169194137). - check_xla = key_dtype in (onp.uint32, onp.int32, onp.float32, lnp.bfloat16) - self._CompileAndCheck(op, args_maker, check_dtypes=True, - check_incomplete_shape=True, check_unknown_rank=False, - check_experimental_compile=check_xla, - check_xla_forced_compile=check_xla) - - -# Most grad tests are at the lax level (see lax_test.py), but we add some here -# as needed for e.g. particular compound ops of interest. - -GradTestSpec = collections.namedtuple( - "GradTestSpec", - ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) -def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): - return GradTestSpec( - op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) - -GRAD_TEST_RECORDS = [ - grad_test_spec(lnp.arcsinh, nargs=1, order=2, - rng_factory=jtu.rand_positive, - dtypes=[onp.float64, onp.complex64], tol=1e-4), - grad_test_spec(lnp.arccosh, nargs=1, order=2, - rng_factory=jtu.rand_positive, - dtypes=[onp.float64, onp.complex64], tol=1e-4), - grad_test_spec(lnp.arctanh, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, -0.9, 0.9), - dtypes=[onp.float64, onp.complex64], tol=1e-4), -] - -GradSpecialValuesTestSpec = collections.namedtuple( - "GradSpecialValuesTestSpec", ["op", "values", "order"]) - -GRAD_SPECIAL_VALUE_TEST_RECORDS = [ - GradSpecialValuesTestSpec(lnp.arcsinh, [0., 1000.], 2), - GradSpecialValuesTestSpec(lnp.arccosh, [1000.], 2), - GradSpecialValuesTestSpec(lnp.arctanh, [0.], 2), - # TODO(wangpeng): Add `GradSpecialValuesTestSpec(lnp.sinc, [0.], 1)` -] - -def num_float_bits(dtype): - return lnp.finfo(dtypes.canonicalize_dtype(dtype)).bits - -class NumpyGradTests(jtu.TestCase): - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.name, shapes, itertools.repeat(dtype)), - "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, - "order": rec.order, "tol": rec.tol} - for shapes in CombosWithReplacement(nonempty_shapes, rec.nargs) - for dtype in rec.dtypes) - for rec in GRAD_TEST_RECORDS)) - @jtu.disable - def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): - rng = rng_factory() - tol = {onp.float32: 1e-1, onp.complex64: 1e-1} - args = tuple(rng(shape, dtype) for shape in shapes) - check_grads(op, args, order, ["fwd", "rev"], tol, tol) - - @named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "_{}_{}".format(rec.op.__name__, 
special_value), - "op": rec.op, "special_value": special_value, "order": rec.order} - for special_value in rec.values) - for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS)) - @jtu.disable - def testOpGradSpecialValue(self, op, special_value, order): - check_grads(op, (special_value,), order, ["fwd", "rev"], - atol={onp.float32: 3e-3}) - - @jtu.disable - def testTakeAlongAxisIssue1521(self): - # https://github.com/google/jax/issues/1521 - idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1)) - - def f(x): - y = x * lnp.arange(3.).reshape((1, 3)) - return lnp.take_along_axis(y, idx, -1).sum() - - check_grads(f, (1.,), order=1) - - -if __name__ == "__main__": - tf.enable_v2_behavior() - lnp.enable_numpy_behavior() - absltest.main() diff --git a/trax/tf_numpy/jax_tests/test_util.py b/trax/tf_numpy/jax_tests/test_util.py deleted file mode 100644 index b12b04676..000000000 --- a/trax/tf_numpy/jax_tests/test_util.py +++ /dev/null @@ -1,902 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2018 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from contextlib import contextmanager -from distutils.util import strtobool -import functools -from functools import partial -import re -import itertools as it -import os -from typing import Dict, Sequence, Union -import sys -import unittest -import warnings -import zlib - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as onp -import numpy.random as npr -import scipy - -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.jax_tests.config import flags, bool_env -import trax.tf_numpy.extensions as npe -import trax.tf_numpy.numpy as tf_np - - -tree_map = tf.nest.map_structure -tree_multimap = tf.nest.map_structure - - -FLAGS = flags.FLAGS - - -# TODO(wangpeng): Remove this flag after broken tests are fixed -flags.DEFINE_bool('enable_x64', - strtobool('False'), - 'Enable 64-bit types to be used.') - - -flags.DEFINE_enum( - 'test_dut', '', - enum_values=['', 'cpu', 'gpu', 'tpu'], - help= - 'Describes the device under test in case special consideration is required.' -) - - -flags.DEFINE_integer( - 'num_generated_cases', - 10, - help='Number of generated cases to test') - - -EPS = 1e-4 - - -# Default dtypes corresponding to Python scalars. 
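(Editorial aside: this table drives `_dtype` below: a value's own `dtype` attribute wins first, then the table maps bare Python scalars, with `onp.asarray` as the final fallback.)

    import numpy as onp

    assert onp.float32(3).dtype == onp.dtype(onp.float32)  # attribute short-circuits
    assert onp.dtype(int) == onp.dtype(onp.int_)           # table entry for Python int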
-python_scalar_dtypes = { - bool: onp.dtype(onp.bool_), - int: onp.dtype(onp.int_), - float: onp.dtype(onp.float_), - complex: onp.dtype(onp.complex_), -} - - -def _dtype(x): - if isinstance(x, tf.Tensor): - return x.dtype.as_numpy_dtype - return (getattr(x, 'dtype', None) or - onp.dtype(python_scalar_dtypes.get(type(x), None)) or - onp.asarray(x).dtype) - - -def is_sequence(x): - try: - iter(x) - except TypeError: - return False - else: - return True - -_default_tolerance = { - onp.dtype(onp.bool_): 0, - onp.dtype(onp.int8): 0, - onp.dtype(onp.int16): 0, - onp.dtype(onp.int32): 0, - onp.dtype(onp.int64): 0, - onp.dtype(onp.uint8): 0, - onp.dtype(onp.uint16): 0, - onp.dtype(onp.uint32): 0, - onp.dtype(onp.uint64): 0, - # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-2, - onp.dtype(onp.float16): 1e-3, - onp.dtype(onp.float32): 1e-6, - onp.dtype(onp.float64): 1e-15, - onp.dtype(onp.complex64): 1e-6, - onp.dtype(onp.complex128): 1e-15, -} - -def default_tolerance(): - return _default_tolerance - -default_gradient_tolerance = { - # TODO(b/154768983): onp.dtype(dtypes.bfloat16): 1e-1, - onp.dtype(onp.float16): 1e-2, - onp.dtype(onp.float32): 2e-3, - onp.dtype(onp.float64): 1e-5, - onp.dtype(onp.complex64): 1e-3, - onp.dtype(onp.complex128): 1e-5, -} - -def _assert_numpy_allclose(a, b, atol=None, rtol=None): - # TODO(b/154768983): - # a = a.astype(onp.float32) if a.dtype == dtypes.bfloat16 else a - # b = b.astype(onp.float32) if b.dtype == dtypes.bfloat16 else b - kw = {} - if atol: kw["atol"] = atol - if rtol: kw["rtol"] = rtol - onp.testing.assert_allclose(a, b, **kw) - -def tolerance(dtype, tol=None): - tol = {} if tol is None else tol - if not isinstance(tol, dict): - return tol - tol = {onp.dtype(key): value for key, value in tol.items()} - dtype = onp.dtype(dtype) - return tol.get(dtype, default_tolerance()[dtype]) - -def _normalize_tolerance(tol): - tol = tol or 0 - if isinstance(tol, dict): - return {onp.dtype(k): v for k, v in tol.items()} - else: - return {k: tol for k in _default_tolerance} - -def join_tolerance(tol1, tol2): - tol1 = _normalize_tolerance(tol1) - tol2 = _normalize_tolerance(tol2) - out = tol1 - for k, v in tol2.items(): - out[k] = max(v, tol1.get(k, 0)) - return out - -def _assert_numpy_close(a, b, atol=None, rtol=None): - assert a.shape == b.shape - atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol)) - rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol)) - _assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size) - - -def check_eq(xs, ys): - tree_all(tree_multimap(_assert_numpy_allclose, xs, ys)) - - -def check_close(xs, ys, atol=None, rtol=None): - assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol) - tree_all(tree_multimap(assert_close, xs, ys)) - - -def inner_prod(xs, ys): - def contract(x, y): - return onp.real(onp.dot(onp.conj(x).reshape(-1), y.reshape(-1))) - return tree_reduce(onp.add, tree_multimap(contract, xs, ys)) - - -add = partial(tree_multimap, lambda x, y: onp.add(x, y, dtype=_dtype(x))) -sub = partial(tree_multimap, lambda x, y: onp.subtract(x, y, dtype=_dtype(x))) -conj = partial(tree_map, lambda x: onp.conj(x, dtype=_dtype(x))) - -def scalar_mul(xs, a): - return tree_map(lambda x: onp.multiply(x, a, dtype=_dtype(x)), xs) - - -def rand_like(rng, x): - shape = onp.shape(x) - dtype = _dtype(x) - randn = lambda: onp.asarray(rng.randn(*shape), dtype=dtype) - if onp.issubdtype(dtype, onp.complexfloating): - return randn() + dtype.type(1.0j) * randn() - else: - return randn() - - -def numerical_jvp(f, 
primals, tangents, eps=EPS):
-  delta = scalar_mul(tangents, eps)
-  f_pos = f(*add(primals, delta))
-  f_neg = f(*sub(primals, delta))
-  return scalar_mul(sub(f_pos, f_neg), 0.5 / eps)
-
-
-def _merge_tolerance(tol, default):
-  if tol is None:
-    return default
-  if not isinstance(tol, dict):
-    return tol
-  out = default.copy()
-  for k, v in tol.items():
-    out[onp.dtype(k)] = v
-  return out
-
-def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS):
-  atol = _merge_tolerance(atol, default_gradient_tolerance)
-  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
-  rng = onp.random.RandomState(0)
-  tangent = tree_map(partial(rand_like, rng), args)
-  v_out, t_out = f_jvp(args, tangent)
-  v_out_expected = f(*args)
-  t_out_expected = numerical_jvp(f, args, tangent, eps=eps)
-  # In principle we should expect exact equality of v_out and v_out_expected,
-  # but due to nondeterminism especially on GPU (e.g., due to convolution
-  # autotuning) we only require "close".
-  check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
-  check_close(t_out, t_out_expected, atol=atol, rtol=rtol)
-
-
-def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS):
-  atol = _merge_tolerance(atol, default_gradient_tolerance)
-  rtol = _merge_tolerance(rtol, default_gradient_tolerance)
-  _rand_like = partial(rand_like, onp.random.RandomState(0))
-  v_out, vjpfun = f_vjp(*args)
-  v_out_expected = f(*args)
-  check_close(v_out, v_out_expected, atol=atol, rtol=rtol)
-  tangent = tree_map(_rand_like, args)
-  tangent_out = numerical_jvp(f, args, tangent, eps=eps)
-  cotangent = tree_map(_rand_like, v_out)
-  cotangent_out = conj(vjpfun(conj(cotangent)))
-  ip = inner_prod(tangent, cotangent_out)
-  ip_expected = inner_prod(tangent_out, cotangent)
-  check_close(ip, ip_expected, atol=atol, rtol=rtol)
-
-
-def device_under_test():
-  return FLAGS.test_dut
-
-def if_device_under_test(device_type: Union[str, Sequence[str]],
-                         if_true, if_false):
-  """Chooses `if_true` or `if_false` based on device_under_test."""
-  if device_under_test() in ([device_type] if isinstance(device_type, str)
-                             else device_type):
-    return if_true
-  else:
-    return if_false
-
-def supported_dtypes():
-  if device_under_test() == "tpu":
-    return {onp.bool_, onp.int32, onp.uint32, dtypes.bfloat16, onp.float32,
-            onp.complex64}
-  else:
-    return {onp.bool_, onp.int8, onp.int16, onp.int32, onp.int64,
-            onp.uint8, onp.uint16, onp.uint32, onp.uint64,
-            dtypes.bfloat16, onp.float16, onp.float32, onp.float64,
-            onp.complex64, onp.complex128}
-
-def skip_if_unsupported_type(dtype):
-  if dtype not in supported_dtypes():
-    raise unittest.SkipTest(
-        f"Type {dtype} not supported on {device_under_test()}")
-
-def skip_on_devices(*disabled_devices):
-  """A decorator for test methods to skip the test on certain devices."""
-  def skip(test_method):
-    @functools.wraps(test_method)
-    def test_method_wrapper(self, *args, **kwargs):
-      device = device_under_test()
-      if device in disabled_devices:
-        test_name = getattr(test_method, '__name__', '[unknown test]')
-        raise unittest.SkipTest(
-            f"{test_name} not supported on {device.upper()}.")
-      return test_method(self, *args, **kwargs)
-    return test_method_wrapper
-  return skip
-
-
-def skip_on_flag(flag_name, skip_value):
-  """A decorator for test methods to skip the test when flags are set."""
-  def skip(test_method):  # pylint: disable=missing-docstring
-    @functools.wraps(test_method)
-    def test_method_wrapper(self, *args, **kwargs):
-      flag_value = getattr(FLAGS, flag_name)
-      if flag_value == skip_value:
-        test_name = 
getattr(test_method, '__name__', '[unknown test]') - raise unittest.SkipTest( - f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}") - return test_method(self, *args, **kwargs) - return test_method_wrapper - return skip - -# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432 -# Delete this code after the minimum jaxlib version is 0.1.46 or greater. -skip_on_mac_linalg_bug = partial( - unittest.skipIf, - (sys.platform == "darwin" and scipy.version.version > "1.1.0" and - lib.version < (0, 1, 46)), - "Test fails on Mac with new scipy (issue #432)") - - -def format_test_name_suffix(opname, shapes, dtypes): - arg_descriptions = (format_shape_dtype_string(shape, dtype) - for shape, dtype in zip(shapes, dtypes)) - return '{}_{}'.format(opname.capitalize(), '_'.join(arg_descriptions)) - - -# We use special symbols, represented as singleton objects, to distinguish -# between NumPy scalars, Python scalars, and 0-D arrays. -class ScalarShape: - def __len__(self): return 0 - def __getitem__(self, i): - raise IndexError(f'index {i} out of range.') -class _NumpyScalar(ScalarShape): pass -class _PythonScalar(ScalarShape): pass -NUMPY_SCALAR_SHAPE = _NumpyScalar() -PYTHON_SCALAR_SHAPE = _PythonScalar() - - -def _dims_of_shape(shape): - """Converts `shape` to a tuple of dimensions.""" - if type(shape) in (list, tuple): - return shape - elif isinstance(shape, ScalarShape): - return () - else: - raise TypeError(type(shape)) - - -def _cast_to_shape(value, shape, dtype): - """Casts `value` to the correct Python type for `shape` and `dtype`.""" - if shape is NUMPY_SCALAR_SHAPE: - # explicitly cast to NumPy scalar in case `value` is a Python scalar. - return onp.dtype(dtype).type(value) - elif shape is PYTHON_SCALAR_SHAPE: - # explicitly cast to Python scalar via https://stackoverflow.com/a/11389998 - return onp.asarray(value).item() - elif type(shape) in (list, tuple): - assert onp.shape(value) == tuple(shape) - return value - else: - raise TypeError(type(shape)) - - -def dtype_str(dtype): - return onp.dtype(dtype).name - - -def format_shape_dtype_string(shape, dtype): - if shape is NUMPY_SCALAR_SHAPE: - return dtype_str(dtype) - elif shape is PYTHON_SCALAR_SHAPE: - return 'py' + dtype_str(dtype) - elif type(shape) in (list, tuple): - shapestr = ','.join(str(dim) for dim in shape) - return '{}[{}]'.format(dtype_str(dtype), shapestr) - elif type(shape) is int: - return '{}[{},]'.format(dtype_str(dtype), shape) - elif isinstance(shape, onp.ndarray): - return '{}[{}]'.format(dtype_str(dtype), shape) - else: - raise TypeError(type(shape)) - - -def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x): - """Produce random values given shape, dtype, scale, and post-processor. - - Args: - rand: a function for producing random values of a given shape, e.g. a - bound version of either onp.RandomState.randn or onp.RandomState.rand. - shape: a shape value as a tuple of positive integers. - dtype: a numpy dtype. - scale: optional, a multiplicative scale for the random values (default 1). - post: optional, a callable for post-processing the random values (default - identity). - - Returns: - An ndarray of the given shape and dtype using random values based on a call - to rand but scaled, converted to the appropriate dtype, and post-processed. 
- """ - r = lambda: onp.asarray(scale * rand(*_dims_of_shape(shape)), dtype) - if onp.issubdtype(dtype, onp.complexfloating): - vals = r() + 1.0j * r() - else: - vals = r() - return _cast_to_shape(onp.asarray(post(vals), dtype), shape, dtype) - - -def rand_default(scale=3): - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=scale) - - -def rand_nonzero(): - post = lambda x: onp.where(x == 0, onp.array(1, dtype=x.dtype), x) - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=3, post=post) - - -def rand_positive(): - post = lambda x: x + 1 - rand = npr.RandomState(0).rand - return partial(_rand_dtype, rand, scale=2, post=post) - - -def rand_small(): - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=1e-3) - - -def rand_not_small(offset=10.): - post = lambda x: x + onp.where(x > 0, offset, -offset) - randn = npr.RandomState(0).randn - return partial(_rand_dtype, randn, scale=3., post=post) - - -def rand_small_positive(): - rand = npr.RandomState(0).rand - return partial(_rand_dtype, rand, scale=2e-5) - -def rand_uniform(low=0.0, high=1.0): - assert low < high - rand = npr.RandomState(0).rand - post = lambda x: x * (high - low) + low - return partial(_rand_dtype, rand, post=post) - - -def rand_some_equal(): - randn = npr.RandomState(0).randn - rng = npr.RandomState(0) - - def post(x): - x_ravel = x.ravel() - if len(x_ravel) == 0: - return x - flips = rng.rand(*onp.shape(x)) < 0.5 - return onp.where(flips, x_ravel[0], x) - - return partial(_rand_dtype, randn, scale=100., post=post) - - -def rand_some_inf(): - """Return a random sampler that produces infinities in floating types.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - """ - TODO: Complex numbers are not correctly tested - If blocks should be switched in order, and relevant tests should be fixed - """ - def rand(shape, dtype): - """The random sampler function.""" - if not onp.issubdtype(dtype, onp.floating): - # only float types have inf - return base_rand(shape, dtype) - - if onp.issubdtype(dtype, onp.complexfloating): - base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype - out = (rand(shape, base_dtype) + - onp.array(1j, dtype) * rand(shape, base_dtype)) - return _cast_to_shape(out, shape, dtype) - - dims = _dims_of_shape(shape) - posinf_flips = rng.rand(*dims) < 0.1 - neginf_flips = rng.rand(*dims) < 0.1 - - vals = base_rand(shape, dtype) - vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) - vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - -def rand_some_nan(): - """Return a random sampler that produces nans in floating types.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - def rand(shape, dtype): - """The random sampler function.""" - if onp.issubdtype(dtype, onp.complexfloating): - base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype - out = (rand(shape, base_dtype) + - onp.array(1j, dtype) * rand(shape, base_dtype)) - return _cast_to_shape(out, shape, dtype) - - if not onp.issubdtype(dtype, onp.floating): - # only float types have inf - return base_rand(shape, dtype) - - dims = _dims_of_shape(shape) - nan_flips = rng.rand(*dims) < 0.1 - - vals = base_rand(shape, dtype) - vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - -def rand_some_inf_and_nan(): - """Return a random sampler that 
produces infinities and nans in floating types.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - """ - TODO: Complex numbers are not correctly tested - If blocks should be switched in order, and relevant tests should be fixed - """ - def rand(shape, dtype): - """The random sampler function.""" - if not onp.issubdtype(dtype, onp.floating): - # only float types have inf and nan - return base_rand(shape, dtype) - - if onp.issubdtype(dtype, onp.complexfloating): - base_dtype = onp.real(onp.array(0, dtype=dtype)).dtype - out = (rand(shape, base_dtype) + - onp.array(1j, dtype) * rand(shape, base_dtype)) - return _cast_to_shape(out, shape, dtype) - - dims = _dims_of_shape(shape) - posinf_flips = rng.rand(*dims) < 0.1 - neginf_flips = rng.rand(*dims) < 0.1 - nan_flips = rng.rand(*dims) < 0.1 - - vals = base_rand(shape, dtype) - vals = onp.where(posinf_flips, onp.array(onp.inf, dtype=dtype), vals) - vals = onp.where(neginf_flips, onp.array(-onp.inf, dtype=dtype), vals) - vals = onp.where(nan_flips, onp.array(onp.nan, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - -# TODO(mattjj): doesn't handle complex types -def rand_some_zero(): - """Return a random sampler that produces some zeros.""" - rng = npr.RandomState(1) - base_rand = rand_default() - - def rand(shape, dtype): - """The random sampler function.""" - dims = _dims_of_shape(shape) - zeros = rng.rand(*dims) < 0.5 - - vals = base_rand(shape, dtype) - vals = onp.where(zeros, onp.array(0, dtype=dtype), vals) - - return _cast_to_shape(onp.asarray(vals, dtype=dtype), shape, dtype) - - return rand - - -def rand_int(low, high=None): - randint = npr.RandomState(0).randint - def fn(shape, dtype): - return randint(low, high=high, size=shape, dtype=dtype) - return fn - -def rand_unique_int(): - randchoice = npr.RandomState(0).choice - def fn(shape, dtype): - return randchoice(onp.arange(onp.prod(shape), dtype=dtype), - size=shape, replace=False) - return fn - -def rand_bool(): - rng = npr.RandomState(0) - def generator(shape, dtype): - return _cast_to_shape(rng.rand(*_dims_of_shape(shape)) < 0.5, shape, dtype) - return generator - -def check_raises(thunk, err_type, msg): - try: - thunk() - assert False - except err_type as e: - assert str(e).startswith(msg), "\n{}\n\n{}\n".format(e, msg) - -def check_raises_regexp(thunk, err_type, pattern): - try: - thunk() - assert False - except err_type as e: - assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern) - - -def _iter_eqns(jaxpr): - # TODO(necula): why doesn't this search in params? - for eqn in jaxpr.eqns: - yield eqn - for subjaxpr in core.subjaxprs(jaxpr): - yield from _iter_eqns(subjaxpr) - -def assert_dot_precision(expected_precision, fun, *args): - jaxpr = api.make_jaxpr(fun)(*args) - precisions = [eqn.params['precision'] for eqn in _iter_eqns(jaxpr.jaxpr) - if eqn.primitive == lax.dot_general_p] - for precision in precisions: - msg = "Unexpected precision: {} != {}".format(expected_precision, precision) - assert precision == expected_precision, msg - - -_CACHED_INDICES: Dict[int, Sequence[int]] = {} - -def cases_from_list(xs): - xs = list(xs) - n = len(xs) - k = min(n, FLAGS.num_generated_cases) - # Random sampling for every parameterized test is expensive. Do it once and - # cache the result.
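The check_vjp helper deleted above reduces to the identity <tangent, J^T cotangent> = <J tangent, cotangent>, with J tangent approximated by central differences. A minimal self-contained NumPy sketch of that check; the function and names below are illustrative, not part of the deleted file:

    import numpy as onp

    def f(x):                     # elementwise square, Jacobian J = diag(2x)
        return x * x

    def vjp(x, cotangent):        # J^T c for the function above
        return 2.0 * x * cotangent

    rng = onp.random.RandomState(0)
    x, tangent, cotangent = rng.randn(3), rng.randn(3), rng.randn(3)
    eps = 1e-4                    # central-difference step, like EPS above
    jvp_numerical = (f(x + eps * tangent) - f(x - eps * tangent)) / (2 * eps)
    # The two inner products must agree up to finite-difference error:
    print(onp.dot(tangent, vjp(x, cotangent)))
    print(onp.dot(jvp_numerical, cotangent))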
- indices = _CACHED_INDICES.get(n) - if indices is None: - rng = npr.RandomState(42) - _CACHED_INDICES[n] = indices = rng.permutation(n) - return [xs[i] for i in indices[:k]] - -def cases_from_gens(*gens): - sizes = [1, 3, 10] - cases_per_size = int(FLAGS.num_generated_cases / len(sizes)) + 1 - for size in sizes: - for i in range(cases_per_size): - yield ('_{}_{}'.format(size, i),) + tuple(gen(size) for gen in gens) - - -def to_np(a): - return tf.nest.map_structure(tf_np.asarray, a) - - -def to_tf_fn(f): - return lambda *args: f(*to_np(args)) - - -class TestCase(parameterized.TestCase): - """Base class for tests including numerical checks and boilerplate.""" - - # copied from jax.test_util - def setUp(self): - super().setUp() - self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - - # copied from jax.test_util - def rng(self): - return self._rng - - # TODO(mattjj): this obscures the error messages from failures, figure out how - # to re-enable it - # def tearDown(self) -> None: - # assert core.reset_trace_state() - - def assertArraysAllClose(self, x, y, check_dtypes, atol=None, rtol=None): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) - - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol) - - if check_dtypes: - self.assertDtypesMatch(x, y) - - def assertDtypesMatch(self, x, y): - if FLAGS.enable_x64: - self.assertEqual(_dtype(x), _dtype(y)) - - def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x: - self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol) - elif hasattr(x, '__array__') or onp.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or onp.isscalar(y)) - if check_dtypes: - self.assertDtypesMatch(x, y) - x = onp.asarray(x) - y = onp.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol) - elif x == y: - return - else: - raise TypeError((type(x), type(y))) - - def assertMultiLineStrippedEqual(self, expected, what): - """Asserts two strings are equal, after stripping each line.""" - ignore_space_re = re.compile(r'\s*\n\s*') - expected_clean = re.sub(ignore_space_re, '\n', expected.strip()) - what_clean = re.sub(ignore_space_re, '\n', what.strip()) - self.assertMultiLineEqual(expected_clean, what_clean, - msg="Found\n{}\nExpecting\n{}".format(what, expected)) - - def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, - check_dtypes=True, tol=None): - args = args_maker() - lax_ans = lax_op(*args) - numpy_ans = numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - def _CompileAndCheck(self, - fun, - args_maker, - check_dtypes=True, - rtol=None, - atol=None, - check_eval_on_shapes=True, - check_incomplete_shape=True, - check_unknown_rank=True, - static_argnums=(), - check_experimental_compile=True, - check_xla_forced_compile=True): - """Compiles the function and checks the results. 
- - Args: - fun: the function to be checked. - args_maker: a callable that returns a tuple which will be used as the - positional arguments. - check_dtypes: whether to check that the result dtypes from non-compiled - and compiled runs agree. - rtol: relative tolerance for allclose assertions. - atol: absolute tolerance for allclose assertions. - check_eval_on_shapes: whether to run `eval_on_shapes` on the function and - check that the result shapes and dtypes are correct. - check_incomplete_shape: whether to check that the function can handle - incomplete shapes (including those with and without a known rank). - check_unknown_rank: (only has effect when check_incomplete_shape is True) - whether to check that the function can handle unknown ranks. - static_argnums: indices of arguments to be treated as static arguments for - `jit` and `eval_on_shapes`. - check_experimental_compile: whether to check compilation with - experimental_compile=True (in addition to compilation without the flag). - check_xla_forced_compile: whether to check compilation with - forced_compile=True (in addition to compilation without the flag). This - flag is different from experimental_compile because it enforces - whole-function compilation while the latter doesn't. TPU requires - whole-function compilation. - """ - args = args_maker() - - for x in args: - if not hasattr(x, 'dtype'): - # If there is a input that doesn't have dtype info, jit and - # eval_on_shapes may pick a different dtype for it than numpy, so we - # skip the dtype check. - check_dtypes = False - - python_ans = fun(*args) - - python_shapes = tf.nest.map_structure(lambda x: onp.shape(x), python_ans) - onp_shapes = tf.nest.map_structure(lambda x: onp.shape(onp.asarray(x)), - python_ans) - self.assertEqual(python_shapes, onp_shapes) - - def check_compile(**kwargs): - # `wrapped_fun` and `python_should_be_executing` are used to check that - # when the jitted function is called the second time, the original Python - # function won't be executed. - def wrapped_fun(*args): - self.assertTrue(python_should_be_executing) - return fun(*args) - - cfun = npe.jit(wrapped_fun, static_argnums=static_argnums, **kwargs) - python_should_be_executing = True - monitored_ans = cfun(*args) - - python_should_be_executing = False - compiled_ans = cfun(*args) - - self.assertAllClose(python_ans, monitored_ans, check_dtypes, atol, rtol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) - - # Run `cfun` with a different set of arguments to check that changing - # arguments won't cause recompilation. - - new_args = args_maker() - - skip_retracing_test = False - for old, new in zip(tf.nest.flatten(args), tf.nest.flatten(new_args)): - if npe.most_precise_int_dtype(old) != npe.most_precise_int_dtype(new): - # If the old and new arguments result in different dtypes (because - # they fall into different value ranges), tf-numpy will retrace, so we - # skip the no-retrace test. - skip_retracing_test = True - - if not skip_retracing_test: - python_should_be_executing = True - new_python_ans = fun(*new_args) - python_should_be_executing = False - compiled_ans = cfun(*new_args) - self.assertAllClose(new_python_ans, compiled_ans, check_dtypes, atol, - rtol) - - check_compile() - if check_experimental_compile: - check_compile(experimental_compile=True) - if check_xla_forced_compile: - check_compile(xla_forced_compile=True) - - if check_eval_on_shapes: - # Check that npe.eval_on_shapes can get complete output shapes given - # complete input shapes. 
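The `python_should_be_executing` flag in check_compile above works because the wrapped Python body only runs while the function is being traced. A minimal sketch of the same no-retrace check, with plain `tf.function` standing in for `npe.jit` purely for illustration:

    import tensorflow as tf

    trace_count = 0

    def f(x):
        global trace_count
        trace_count += 1          # the Python body runs only during tracing
        return x * 2

    cf = tf.function(f)
    cf(tf.constant([1.0, 2.0]))   # first call traces
    cf(tf.constant([3.0, 4.0]))   # same shape/dtype signature: trace is reused
    assert trace_count == 1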
- cfun = npe.eval_on_shapes(fun, static_argnums=static_argnums) - compiled_ans = cfun(*args) - flat_python_ans = tf.nest.flatten(python_ans) - flat_compiled_ans = tf.nest.flatten(compiled_ans) - self.assertEqual(len(flat_python_ans), len(flat_compiled_ans)) - for a, b in zip(flat_python_ans, flat_compiled_ans): - if hasattr(a, 'shape'): - self.assertEqual(a.shape, b.shape) - if check_dtypes and hasattr(a, 'dtype'): - self.assertEqual(tf.as_dtype(a.dtype), b.dtype) - - # If some argument doesn't have a `dtype` attr (e.g. a Python scalar), we - # skip incomplete-shape checks, since shape specs need dtype. It's OK to - # skip since the same incomplete-shape checks will run for []-shaped arrays. - if check_incomplete_shape and all(hasattr(x, 'dtype') for x in args): - # Check partial shapes with known ranks. - # Numpy scalars (created by e.g. np.int32(5)) have `dtype` but not - # `shape`. - if all(hasattr(x, 'shape') for x in args): - specs = [tf.TensorSpec([None] * len(x.shape), x.dtype) for x in args] - cfun = npe.jit( - fun, static_argnums=static_argnums, input_signature=specs) - compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) - - if check_unknown_rank: - # Check unknown ranks. - specs = [tf.TensorSpec(None, x.dtype) for x in args] - cfun = npe.jit( - fun, static_argnums=static_argnums, input_signature=specs) - compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes, atol, rtol) - - def check_grads(self, f, args, atol=None, rtol=None, delta=None): - """Check gradients against finite differences. - - Args: - f: function to check at ``f(*args)``. - args: a list or tuple of argument values. - atol: absolute tolerance for gradient equality. - rtol: relative tolerance for gradient equality. - delta: step size used for finite differences. - """ - if delta is None: - # Optimal stepsize for central difference is O(epsilon^{1/3}). - dtype = tf_np.result_type(*args) - epsilon = onp.finfo(dtype).eps - delta = epsilon ** (1.0 / 3.0) - theoretical, numerical = tf.test.compute_gradient( - to_tf_fn(f), args, delta=delta) - self.assertAllClose(theoretical, numerical, check_dtypes=False, atol=atol, - rtol=rtol) - - -@contextmanager -def ignore_warning(**kw): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) - yield - - -def disable(_): - - def wrapper(self, *args, **kwargs): - self.skipTest('Test is disabled') - - return wrapper diff --git a/trax/tf_numpy/jax_tests/vmap_test.py b/trax/tf_numpy/jax_tests/vmap_test.py deleted file mode 100644 index b35f78808..000000000 --- a/trax/tf_numpy/jax_tests/vmap_test.py +++ /dev/null @@ -1,167 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
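`extensions.vmap`, exercised by the deleted tests below, follows `jax.vmap`'s `in_axes`/`out_axes` conventions. A small sketch of the per-argument batch-axis idea, written against `jax.vmap` only as an illustrative stand-in:

    import jax
    import jax.numpy as jnp

    def f(x, y):
        return x + y                       # sees unbatched slices of shape (3,)

    vf = jax.vmap(f, in_axes=(0, 1))       # batch x along axis 0, y along axis 1
    x = jnp.ones((4, 3))
    y = jnp.ones((3, 4))
    print(vf(x, y).shape)                  # (4, 3): the batch axis lands at 0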
- -import collections -from absl.testing import parameterized - -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy import extensions -import trax.tf_numpy.numpy as tf_np - -from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=g-direct-tensorflow-import - - -class VmapTest(tf.test.TestCase, parameterized.TestCase): - - def test_vmap_in_axes_list(self): - # https://github.com/google/jax/issues/2367 - dictionary = {'a': 5., 'b': tf_np.ones(2)} - x = tf_np.zeros(3) - y = tf_np.arange(3.) - - def f(dct, x, y): - return dct['a'] + dct['b'] + x + y - - out1 = extensions.vmap(f, (None, 0, 0))(dictionary, x, y) - out2 = extensions.vmap(f, [None, 0, 0])(dictionary, x, y) - self.assertAllClose(out1, out2) - - def test_vmap_in_axes_tree_prefix_error(self): - # https://github.com/google/jax/issues/795 - self.assertRaisesRegex( - ValueError, - 'vmap in_axes specification must be a tree prefix of the corresponding ' - r'value, got specification \(0, 0\) for value tree ', - lambda: extensions.vmap(lambda x: x, in_axes=(0, 0))(tf_np.ones(3))) - - def test_vmap_in_axes_leaf_types(self): - with self.assertRaisesRegex(TypeError, - r'vmap in_axes must be an int, None, or .*'): - extensions.vmap( - lambda x: x, in_axes=(tf_np.array([1., 2.]),))( - tf_np.array([1., 2.])) - - def test_vmap_out_axes_leaf_types(self): - with self.assertRaisesRegex(TypeError, - r'vmap out_axes must be an int, None, or .*'): - extensions.vmap( - lambda x: x, out_axes=(tf_np.array([1., 2.]),))( - tf_np.array([1., 2.])) - - def test_vmap_unbatched_object_passthrough_issue_183(self): - # https://github.com/google/jax/issues/183 - fun = lambda f, x: f(x) - vfun = extensions.vmap(fun, (None, 0)) - ans = vfun(lambda x: x + 1, tf_np.arange(3)) - self.assertAllClose(ans, np.arange(1, 4)) - - def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): - # https://github.com/google/jax/issues/705 - with self.assertRaisesRegex( - ValueError, 'vmap must have at least one non-None value in in_axes'): - # If the output is mapped, there must be a non-None in_axes - extensions.vmap(lambda x: x, in_axes=None)(tf_np.array([1., 2.])) - - # Error is: TypeError: only integer scalar arrays can be converted to a - # scalar index - with self.assertRaisesRegex( - ValueError, 'vmap out_axes specification must be a tree prefix of the ' - 'corresponding value.*'): - extensions.vmap( - lambda x: x, in_axes=0, out_axes=(2, 3))( - tf_np.array([1., 2.])) - - def test_vmap_structured_in_axes(self): - a, b, c, d = 2, 3, 4, 5 - k = 6 # batch size - x = np.ones((k, a, b)) # batch axis in different locations - y = np.ones((b, k, c)) - z = np.ones((c, d, k)) - - def foo(tree_arg): - x, (y, z) = tree_arg - return tf_np.dot(x, tf_np.dot(y, z)) - - tree = (x, (y, z)) - vfoo = extensions.vmap(foo, in_axes=((0, (1, 2)),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - Point = collections.namedtuple('Point', ['x', 'y']) - tree = (x, Point(y, z)) - vfoo = extensions.vmap(foo, in_axes=((0, Point(1, 2)),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - def foo2(tree_arg): - x, dct = tree_arg - y, z = dct['a'], dct['b'] - return tf_np.dot(x, tf_np.dot(y, z)) - - tree = (x, {'a': y, 'b': z}) - vfoo = extensions.vmap(foo2, in_axes=((0, {'a': 1, 'b': 2}),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - tree = (x, collections.OrderedDict([('a', y), ('b', z)])) - vfoo = extensions.vmap( - foo2, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),)) - self.assertEqual(vfoo(tree).shape, (6, 2, 5)) - - def 
test_vmap_out_axes(self): - f = extensions.vmap(lambda x: x, out_axes=0) - inp = tf_np.arange(6).reshape([2, 3]) - self.assertAllClose(inp, f(inp)) - self.assertAllClose([inp, inp], f((inp, inp))) - - f = extensions.vmap(lambda x: x, out_axes=-1) - self.assertAllClose(inp.T, f(inp)) - - f = extensions.vmap(lambda x: x, out_axes=None) - self.assertAllClose(inp[0], f(inp)) - - f = extensions.vmap(lambda x: x, out_axes=([0], (-1, None), {'a': 1})) - a, b, c = f(([inp], (inp, inp), {'a': inp})) - self.assertAllClose([inp], a) - self.assertAllClose((inp.T, inp[0]), b) - self.assertAllClose(inp.T, c['a']) - - def test_negative_axes(self): - x = np.arange(3 * 4 * 5).reshape(3, 4, 5) - self.assertAllClose( - extensions.vmap(tf_np.sum, in_axes=-3)(x), tf_np.sum(x, axis=(1, 2))) - self.assertAllClose( - extensions.vmap(tf_np.sum, in_axes=-2)(x), tf_np.sum(x, axis=(0, 2))) - self.assertAllClose( - extensions.vmap(tf_np.sum, in_axes=-1)(x), tf_np.sum(x, axis=(0, 1))) - - identity = lambda y: y - self.assertAllClose(x, extensions.vmap(identity, in_axes=0, out_axes=-3)(x)) - self.assertAllClose( - x.transpose(1, 0, 2), - extensions.vmap(identity, in_axes=0, out_axes=-2)(x)) - self.assertAllClose( - x.transpose(1, 2, 0), - extensions.vmap(identity, in_axes=0, out_axes=-1)(x)) - - self.assertAllClose( - np.full((5,), 7), - extensions.vmap(lambda *xs: xs, in_axes=(0, None), - out_axes=(0, -1))(np.arange(5), 7)[1]) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - np_math_ops.enable_numpy_methods_on_tensor() - tf.test.main() diff --git a/trax/tf_numpy/numpy/__init__.py b/trax/tf_numpy/numpy/__init__.py deleted file mode 100644 index 2877116d7..000000000 --- a/trax/tf_numpy/numpy/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""NumPy like wrapper for Tensorflow.""" - -# pylint: disable=wildcard-import -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import - -try: - # Note that this import will work in tf-nightly and TF versions 2.4 and - # higher. - from tensorflow.experimental.numpy import * - # TODO(agarwal): get rid of following imports. - from tensorflow.experimental.numpy import random - from tensorflow import bfloat16 - import numpy as onp - from tensorflow.python.ops.numpy_ops.np_dtypes import canonicalize_dtype - from tensorflow.python.ops.numpy_ops.np_dtypes import default_float_type - from tensorflow.python.ops.numpy_ops.np_dtypes import is_allow_float64 - from tensorflow.python.ops.numpy_ops.np_dtypes import set_allow_float64 - - random.DEFAULT_RANDN_DTYPE = onp.float32 -except ImportError: - try: - # Note that this import will work in TF 2.3 and higher. - from tensorflow.python.ops.numpy_ops import * - from tensorflow import bfloat16 - - except ImportError: - # Note that this fallback will be needed for TF 2.2. 
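The import cascade above prefers `tf.experimental.numpy` when it exists (TF 2.4 and higher). A quick smoke test of that preferred path, assuming a reasonably recent TensorFlow install:

    import tensorflow.experimental.numpy as tnp

    x = tnp.asarray([1.0, 2.0, 3.0])
    print(tnp.sum(x * 2))   # the tf.experimental.numpy API mirrors NumPy: 12.0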
- from tensorflow import newaxis - - from trax.tf_numpy.numpy_impl import random - - # pylint: disable=wildcard-import - from trax.tf_numpy.numpy_impl.array_ops import * - from trax.tf_numpy.numpy_impl.arrays import * - from trax.tf_numpy.numpy_impl.dtypes import * - from trax.tf_numpy.numpy_impl.math_ops import * - from trax.tf_numpy.numpy_impl.utils import finfo - from trax.tf_numpy.numpy_impl.utils import promote_types - from trax.tf_numpy.numpy_impl.utils import result_type - # pylint: enable=wildcard-import - - max = amax # pylint: disable=redefined-builtin,undefined-variable - min = amin # pylint: disable=redefined-builtin,undefined-variable - round = around # pylint: disable=redefined-builtin,undefined-variable - -try: - from tensorflow.python.ops.numpy_ops.np_config import enable_numpy_behavior - # TODO(b/171429739): This should be moved to every individual file/test. - enable_numpy_behavior() - -except ImportError: - pass diff --git a/trax/tf_numpy/numpy_impl/array_ops.py b/trax/tf_numpy/numpy_impl/array_ops.py deleted file mode 100644 index c47b827b3..000000000 --- a/trax/tf_numpy/numpy_impl/array_ops.py +++ /dev/null @@ -1,1545 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Common array methods.""" -import functools -import math -import numpy as np -import six -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import arrays as arrays_lib -from trax.tf_numpy.numpy_impl import dtypes -from trax.tf_numpy.numpy_impl import utils - - -def empty(shape, dtype=float): # pylint: disable=redefined-outer-name - """Returns an empty array with the specified shape and dtype. - - Args: - shape: A fully defined shape. Could be - NumPy array or a python scalar, - list or tuple of integers, - TensorFlow tensor/ndarray of integer type and - rank <=1. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. - - Returns: - An ndarray. - """ - return zeros(shape, dtype) - - -def empty_like(a, dtype=None): - """Returns an empty array with the shape and possibly type of the input array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the input array. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. - - Returns: - An ndarray. - """ - return zeros_like(a, dtype) - - -def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name - """Returns an ndarray with the given shape and type filled with zeros. - - Args: - shape: A fully defined shape. Could be - NumPy array or a python scalar, - list or tuple of integers, - TensorFlow tensor/ndarray of integer type and - rank <=1. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. - - Returns: - An ndarray. 
- """ - if dtype: - dtype = utils.result_type(dtype) - if isinstance(shape, arrays_lib.ndarray): - shape = shape.data - return arrays_lib.tensor_to_ndarray(tf.zeros(shape, dtype=dtype)) - - -def zeros_like(a, dtype=None): - """Returns an array of zeros with the shape and type of the input array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the input array. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. - - Returns: - An ndarray. - """ - if isinstance(a, arrays_lib.ndarray): - a = a.data - if dtype is None: - # We need to let utils.result_type decide the dtype, not tf.zeros_like - dtype = utils.result_type(a) - else: - # TF and numpy has different interpretations of Python types such as - # `float`, so we let `utils.result_type` decide. - dtype = utils.result_type(dtype) - dtype = tf.as_dtype(dtype) # Work around b/149877262 - return arrays_lib.tensor_to_ndarray(tf.zeros_like(a, dtype)) - - -def ones(shape, dtype=float): # pylint: disable=redefined-outer-name - """Returns an ndarray with the given shape and type filled with ones. - - Args: - shape: A fully defined shape. Could be - NumPy array or a python scalar, - list or tuple of integers, - TensorFlow tensor/ndarray of integer type and - rank <=1. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. - - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if isinstance(shape, arrays_lib.ndarray): - shape = shape.data - return arrays_lib.tensor_to_ndarray(tf.ones(shape, dtype=dtype)) - - -def ones_like(a, dtype=None): - """Returns an array of ones with the shape and type of the input array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the input array. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. - - Returns: - An ndarray. - """ - if isinstance(a, arrays_lib.ndarray): - a = a.data - if dtype is None: - dtype = utils.result_type(a) - else: - dtype = utils.result_type(dtype) - return arrays_lib.tensor_to_ndarray(tf.ones_like(a, dtype)) - - -@utils.np_doc(np.eye) -def eye(N, M=None, k=0, dtype=float): # pylint: disable=invalid-name,missing-docstring - if dtype: - dtype = utils.result_type(dtype) - if not M: - M = N - # Making sure N, M and k are `int` - N = int(N) - M = int(M) - k = int(k) - if k >= M or -k >= N: - # tf.linalg.diag will raise an error in this case - return zeros([N, M], dtype=dtype) - if k == 0: - return arrays_lib.tensor_to_ndarray(tf.eye(N, M, dtype=dtype)) - # We need the precise length, otherwise tf.linalg.diag will raise an error - diag_len = min(N, M) - if k > 0: - if N >= M: - diag_len -= k - elif N + k > M: - diag_len = M - k - elif k <= 0: - if M >= N: - diag_len += k - elif M - k > N: - diag_len = N + k - diagonal_ = tf.ones([diag_len], dtype=dtype) - return arrays_lib.tensor_to_ndarray( - tf.linalg.diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)) - - -def identity(n, dtype=float): - """Returns a square array with ones on the main diagonal and zeros elsewhere. - - Args: - n: number of rows/cols. - dtype: Optional, defaults to float. The type of the resulting ndarray. Could - be a python type, a NumPy type or a TensorFlow `DType`. 
- - Returns: - An ndarray of shape (n, n) and requested type. - """ - return eye(N=n, M=n, dtype=dtype) - - -def full(shape, fill_value, dtype=None): # pylint: disable=redefined-outer-name - """Returns an array with given shape and dtype filled with `fill_value`. - - Args: - shape: A valid shape object. Could be a native python object or an object - of type ndarray, numpy.ndarray or tf.TensorShape. - fill_value: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the `fill_value`. The type of the - resulting ndarray. Could be a python type, a NumPy type or a TensorFlow - `DType`. - - Returns: - An ndarray. - - Raises: - ValueError: if `fill_value` can not be broadcast to shape `shape`. - """ - fill_value = asarray(fill_value, dtype=dtype) - if utils.isscalar(shape): - shape = tf.reshape(shape, [1]) - return arrays_lib.tensor_to_ndarray(tf.broadcast_to(fill_value.data, shape)) - - -# Using doc only here since np full_like signature doesn't seem to have the -# shape argument (even though it exists in the documentation online). -@utils.np_doc_only(np.full_like) -def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None): # pylint: disable=missing-docstring,redefined-outer-name - """order, subok and shape arguments mustn't be changed.""" - if order != 'K': - raise ValueError('Non-standard orders are not supported.') - if not subok: - raise ValueError('subok being False is not supported.') - if shape: - raise ValueError('Overriding the shape is not supported.') - - a = asarray(a).data - dtype = dtype or utils.result_type(a) - fill_value = asarray(fill_value, dtype=dtype) - return arrays_lib.tensor_to_ndarray( - tf.broadcast_to(fill_value.data, tf.shape(a))) - - -# TODO(wangpeng): investigate whether we can make `copy` default to False. -# TODO(wangpeng): utils.np_doc can't handle np.array because np.array is a -# builtin function. Make utils.np_doc support builtin functions. -def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name - """Creates an ndarray with the contents of val. - - Args: - val: array_like. Could be an ndarray, a Tensor or any object that can be - converted to a Tensor using `tf.convert_to_tensor`. - dtype: Optional, defaults to dtype of the `val`. The type of the resulting - ndarray. Could be a python type, a NumPy type or a TensorFlow `DType`. - copy: Determines whether to create a copy of the backing buffer. Since - Tensors are immutable, a copy is made only if val is placed on a different - device than the current one. Even if `copy` is False, a new Tensor may - need to be built to satisfy `dtype` and `ndim`. This is used only if `val` - is an ndarray or a Tensor. - ndmin: The minimum rank of the returned array. - - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if isinstance(val, arrays_lib.ndarray): - result_t = val.data - else: - result_t = val - - if copy and isinstance(result_t, tf.Tensor): - # Note: In eager mode, a copy of `result_t` is made only if it is not on - # the context device. - result_t = tf.identity(result_t) - - if not isinstance(result_t, tf.Tensor): - if not dtype: - dtype = utils.result_type(result_t) - # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because - # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int) - # while np.array allows them. We need to convert-then-cast. 
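The convert-then-cast comment above is the crux of array(): `tf.convert_to_tensor` rejects lossy dtype requests that `np.array` performs silently. A short demonstration of the difference:

    import tensorflow as tf

    try:
        tf.convert_to_tensor(5.5, dtype=tf.int32)       # rejected as lossy
    except (TypeError, ValueError) as e:
        print('direct conversion fails:', type(e).__name__)

    t = tf.cast(tf.convert_to_tensor(5.5), tf.int32)    # convert, then cast
    print(t.numpy())                                    # 5, like np.array(5.5, int)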
- def maybe_data(x): - if isinstance(x, arrays_lib.ndarray): - return x.data - return x - - # Handles lists of ndarrays - result_t = tf.nest.map_structure(maybe_data, result_t) - result_t = arrays_lib.convert_to_tensor(result_t) - result_t = tf.cast(result_t, dtype=dtype) - elif dtype: - result_t = tf.cast(result_t, dtype) - ndims = tf.rank(result_t) - - def true_fn(): - old_shape = tf.shape(result_t) - new_shape = tf.concat([tf.ones(ndmin - ndims, tf.int32), old_shape], axis=0) - return tf.reshape(result_t, new_shape) - - result_t = utils.cond(utils.greater(ndmin, ndims), true_fn, lambda: result_t) - return arrays_lib.tensor_to_ndarray(result_t) - - -@utils.np_doc(np.asarray) -def asarray(a, dtype=None): - if dtype: - dtype = utils.result_type(dtype) - if isinstance(a, arrays_lib.ndarray) and (not dtype or dtype == a.dtype): - return a - return array(a, dtype, copy=False) - - -@utils.np_doc(np.asanyarray) -def asanyarray(a, dtype=None): - return asarray(a, dtype) - - -@utils.np_doc(np.ascontiguousarray) -def ascontiguousarray(a, dtype=None): - return array(a, dtype, ndmin=1) - - -# Numerical ranges. -def arange(start, stop=None, step=1, dtype=None): - """Returns `step`-separated values in the range [start, stop). - - Args: - start: Start of the interval. Included in the range. - stop: End of the interval. If not specified, `start` is treated as 0 and - `start` value is used as `stop`. If specified, it is not included in the - range if `step` is integer. When `step` is floating point, it may or may - not be included. - step: The difference between 2 consecutive values in the output range. It is - recommended to use `linspace` instead of using non-integer values for - `step`. - dtype: Optional. Type of the resulting ndarray. Could be a python type, a - NumPy type or a TensorFlow `DType`. If not provided, the largest type of - `start`, `stop`, `step` is used. - - Raises: - ValueError: If step is zero. - """ - if not step: - raise ValueError('step must be non-zero.') - if dtype: - dtype = utils.result_type(dtype) - else: - if stop is None: - dtype = utils.result_type(start, step) - else: - dtype = utils.result_type(start, step, stop) - if step > 0 and ((stop is not None and start > stop) or - (stop is None and start < 0)): - return array([], dtype=dtype) - if step < 0 and ((stop is not None and start < stop) or - (stop is None and start > 0)): - return array([], dtype=dtype) - # TODO(srbs): There are some bugs when start or stop is float type and dtype - # is integer type. - return arrays_lib.tensor_to_ndarray( - tf.cast(tf.range(start, limit=stop, delta=step), dtype=dtype)) - - -@utils.np_doc(np.geomspace) -def geomspace(start, stop, num=50, endpoint=True, dtype=float): # pylint: disable=missing-docstring - if dtype: - dtype = utils.result_type(dtype) - if num < 0: - raise ValueError('Number of samples {} must be non-negative.'.format(num)) - if not num: - return empty([0]) - step = 1. - if endpoint: - if num > 1: - step = tf.pow((stop / start), 1 / (num - 1)) - else: - step = tf.pow((stop / start), 1 / num) - result = tf.cast(tf.range(num), step.dtype) - result = tf.pow(step, result) - result = tf.multiply(result, start) - if dtype: - result = tf.cast(result, dtype=dtype) - return arrays_lib.tensor_to_ndarray(result) - - -# Building matrices. 
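The `diag` implementation below is, like `np.diag`, two operations in one: a 1-d input builds a matrix, a 2-d input extracts a diagonal, and `k` shifts which diagonal is used. The NumPy reference behavior it mirrors:

    import numpy as np

    print(np.diag([1, 2, 3], k=1))                     # builds 4x4, offset +1
    print(np.diag(np.arange(9).reshape(3, 3), k=-1))   # extracts: [3 7]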
-@utils.np_doc(np.diag) -def diag(v, k=0): # pylint: disable=missing-docstring - """Raises an error if input is not 1- or 2-d.""" - v = asarray(v).data - v_rank = tf.rank(v) - - v.shape.with_rank_at_most(2) - - # TODO(nareshmodi): Consider a utils.Assert version that will fail during - # tracing time if the shape is known. - tf.debugging.Assert( - utils.logical_or(tf.equal(v_rank, 1), tf.equal(v_rank, 2)), [v_rank]) - - def _diag(v, k): - return utils.cond( - tf.equal(tf.size(v), 0), - lambda: tf.zeros([abs(k), abs(k)], dtype=v.dtype), - lambda: tf.linalg.diag(v, k=k)) - - def _diag_part(v, k): - v_shape = tf.shape(v) - v, k = utils.cond( - utils.logical_or( - utils.less_equal(k, -1 * utils.getitem(v_shape, 0)), - utils.greater_equal(k, utils.getitem(v_shape, 1)), - ), lambda: (tf.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k)) - result = tf.linalg.diag_part(v, k=k) - return result - - result = utils.cond( - tf.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k)) - return utils.tensor_to_ndarray(result) - - -@utils.np_doc(np.diagonal) -def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstring - a = asarray(a).data - - maybe_rank = a.shape.rank - if maybe_rank is not None and offset == 0 and ( - axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or - axis2 == -1): - return utils.tensor_to_ndarray(tf.linalg.diag_part(a)) - - a = moveaxis(utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data - - a_shape = tf.shape(a) - - def _zeros(): # pylint: disable=missing-docstring - return (tf.zeros(tf.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0) - - # All zeros since diag_part doesn't handle all possible k (aka offset). - # Written this way since cond will run shape inference on both branches, - # and diag_part shape inference will fail when offset is out of bounds. - a, offset = utils.cond( - utils.logical_or( - utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)), - utils.greater_equal(offset, utils.getitem(a_shape, -1)), - ), _zeros, lambda: (a, offset)) - - a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset)) - return a - - -def diagflat(v, k=0): - """Returns a 2-d array with flattened `v` as diagonal. - - Args: - v: array_like of any rank. Gets flattened when setting as diagonal. Could be - an ndarray, a Tensor or any object that can be converted to a Tensor using - `tf.convert_to_tensor`. - k: Position of the diagonal. Defaults to 0, the main diagonal. Positive - values refer to diagonals shifted right, negative values refer to - diagonals shifted left. - - Returns: - 2-d ndarray. - """ - v = asarray(v) - return diag(tf.reshape(v.data, [-1]), k) - - -def _promote_dtype(*arrays): - dtype = utils.result_type(*arrays) - return [asarray(a, dtype=dtype) for a in arrays] - - -def all(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin - """Whether all array elements or those along an axis evaluate to true. - - Casts the array to bool type if it is not already and uses `tf.reduce_all` to - compute the result. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: Optional. Could be an int or a tuple of integers. If not specified, - the reduction is performed over all array indices. - keepdims: If true, retains reduced dimensions with length 1. - - Returns: - An ndarray. Note that unlike NumPy this does not return a scalar bool if - `axis` is None. 
- """ - a = asarray(a, dtype=bool) - return utils.tensor_to_ndarray( - tf.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims)) - - -def any(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin - """Whether any element in the entire array or in an axis evaluates to true. - - Casts the array to bool type if it is not already and uses `tf.reduce_any` to - compute the result. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: Optional. Could be an int or a tuple of integers. If not specified, - the reduction is performed over all array indices. - keepdims: If true, retains reduced dimensions with length 1. - - Returns: - An ndarray. Note that unlike NumPy this does not return a scalar bool if - `axis` is None. - """ - a = asarray(a, dtype=bool) - return utils.tensor_to_ndarray( - tf.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims)) - - -def compress(condition, a, axis=None): - """Compresses `a` by selecting values along `axis` with `condition` true. - - Uses `tf.boolean_mask`. - - Args: - condition: 1-d array of bools. If `condition` is shorter than the array - axis (or the flattened array if axis is None), it is padded with False. - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: Optional. Axis along which to select elements. If None, `condition` is - applied on flattened array. - - Returns: - An ndarray. - - Raises: - ValueError: if `condition` is not of rank 1. - """ - condition = asarray(condition, dtype=bool) - a = asarray(a) - - if condition.ndim != 1: - raise ValueError('condition must be a 1-d array.') - # `np.compress` treats scalars as 1-d arrays. - if a.ndim == 0: - a = ravel(a) - - if axis is None: - a = ravel(a) - axis = 0 - - if axis < 0: - axis += a.ndim - - assert axis >= 0 and axis < a.ndim - - # `tf.boolean_mask` requires the first dimensions of array and condition to - # match. `np.compress` pads condition with False when it is shorter. - condition_t = condition.data - a_t = a.data - if condition.shape[0] < a.shape[axis]: - padding = tf.fill([a.shape[axis] - condition.shape[0]], False) - condition_t = tf.concat([condition_t, padding], axis=0) - return utils.tensor_to_ndarray(tf.boolean_mask(tensor=a_t, mask=condition_t, - axis=axis)) - - -def copy(a): - """Returns a copy of the array.""" - return array(a, copy=True) - - -def _maybe_promote_to_int(a): - if tf.as_dtype(a.dtype).is_integer: - # If a is an integer type and its precision is less than that of `int`, - # the output type will be `int`. - output_type = np.promote_types(a.dtype, int) - if output_type != a.dtype: - a = asarray(a, dtype=output_type) - - return a - - -@utils.np_doc(np.cumprod) -def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring - a = asarray(a, dtype=dtype) - - if dtype is None: - a = _maybe_promote_to_int(a) - - # If axis is None, the input is flattened. - if axis is None: - a = ravel(a) - axis = 0 - elif axis < 0: - axis += tf.rank(a.data) - return utils.tensor_to_ndarray(tf.math.cumprod(a.data, axis)) - - -@utils.np_doc(np.cumsum) -def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring - a = asarray(a, dtype=dtype) - - if dtype is None: - a = _maybe_promote_to_int(a) - - # If axis is None, the input is flattened. 
- if axis is None: - a = ravel(a) - axis = 0 - elif axis < 0: - axis += tf.rank(a.data) - return utils.tensor_to_ndarray(tf.cumsum(a.data, axis)) - - -def imag(a): - """Returns imaginary parts of all elements in `a`. - - Uses `tf.imag`. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - - Returns: - An ndarray with the same shape as `a`. - """ - a = asarray(a) - # TODO(srbs): np.imag returns a scalar if a is a scalar, whereas we always - # return an ndarray. - return utils.tensor_to_ndarray(tf.math.imag(a.data)) - - -_TO_INT64 = 0 -_TO_FLOAT = 1 - - -def _reduce(tf_fn, a, axis=None, dtype=None, keepdims=None, - promote_int=_TO_INT64, tf_bool_fn=None, preserve_bool=False): - """A general reduction function. - - Args: - tf_fn: the TF reduction function. - a: the array to be reduced. - axis: (optional) the axis along which to do the reduction. If None, all - dimensions are reduced. - dtype: (optional) the dtype of the result. - keepdims: (optional) whether to keep the reduced dimension(s). - promote_int: how to promote integer and bool inputs. There are three - choices: (1) _TO_INT64: always promote them to int64 or uint64; (2) - _TO_FLOAT: always promote them to a float type (determined by - dtypes.default_float_type); (3) None: don't promote. - tf_bool_fn: (optional) the TF reduction function for bool inputs. It - will only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s - dtype is `np.bool_` and `preserve_bool` is True. - preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype - is `np.bool_` (some reductions such as np.sum convert bools to - integers, while others such as np.max preserve bools. - - Returns: - An ndarray. - """ - if dtype: - dtype = utils.result_type(dtype) - if keepdims is None: - keepdims = False - a = asarray(a, dtype=dtype) - if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) - and tf_bool_fn is not None): - return utils.tensor_to_ndarray( - tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)) - if dtype is None: - dtype = a.dtype - if np.issubdtype(dtype, np.integer) or dtype == np.bool_: - if promote_int == _TO_INT64: - # If a is an integer/bool type and whose bit width is less than 64, - # numpy up-casts it to 64-bit. 
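The `_TO_INT64` promotion above reproduces NumPy's reduction rules: narrow integers widen before accumulating, `sum` converts bools to integers, and `max` preserves them (output dtype shown for a 64-bit NumPy build):

    import numpy as np

    print(np.sum(np.ones(5, np.int8)).dtype)   # int64: widened before reducing
    print(np.sum(np.array([True, True])))      # 2: sum counts bools as ints
    print(np.max(np.array([True, False])))     # True: max keeps bool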
- if dtype == np.bool_: - is_signed = True - width = 8 # We can use any number here that is less than 64 - else: - is_signed = np.issubdtype(dtype, np.signedinteger) - width = np.iinfo(dtype).bits - if width < 64: - if is_signed: - dtype = np.int64 - else: - dtype = np.uint64 - a = a.astype(dtype) - elif promote_int == _TO_FLOAT: - a = a.astype(dtypes.default_float_type()) - - return utils.tensor_to_ndarray( - tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)) - - -@utils.np_doc(np.sum) -def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-builtin - return _reduce(tf.reduce_sum, a, axis=axis, dtype=dtype, keepdims=keepdims, - tf_bool_fn=tf.reduce_any) - - -@utils.np_doc(np.prod) -def prod(a, axis=None, dtype=None, keepdims=None): - return _reduce(tf.reduce_prod, a, axis=axis, dtype=dtype, keepdims=keepdims, - tf_bool_fn=tf.reduce_all) - - -@utils.np_doc(np.mean) -def mean(a, axis=None, dtype=None, keepdims=None): - return _reduce(tf.math.reduce_mean, a, axis=axis, dtype=dtype, - keepdims=keepdims, promote_int=_TO_FLOAT) - - -@utils.np_doc(np.amax) -def amax(a, axis=None, keepdims=None): - return _reduce(tf.reduce_max, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=None, tf_bool_fn=tf.reduce_any, preserve_bool=True) - - -@utils.np_doc(np.amin) -def amin(a, axis=None, keepdims=None): - return _reduce(tf.reduce_min, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=None, tf_bool_fn=tf.reduce_all, preserve_bool=True) - - -# TODO(wangpeng): Remove this workaround once b/157232284 is fixed -def _reduce_variance_complex(input_tensor, axis, keepdims): - f = functools.partial(tf.math.reduce_variance, axis=axis, keepdims=keepdims) - return f(tf.math.real(input_tensor)) + f(tf.math.imag(input_tensor)) - - -# TODO(wangpeng): Remove this workaround once b/157232284 is fixed -def _reduce_std_complex(input_tensor, axis, keepdims): - y = _reduce_variance_complex(input_tensor=input_tensor, axis=axis, - keepdims=keepdims) - return tf.math.sqrt(y) - - -@utils.np_doc(np.var) -def var(a, axis=None, keepdims=None): - def f(input_tensor, axis, keepdims): - if input_tensor.dtype in (tf.complex64, tf.complex128): - # A workaround for b/157232284 - fn = _reduce_variance_complex - else: - fn = tf.math.reduce_variance - return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims) - return _reduce(f, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=_TO_FLOAT) - - -@utils.np_doc(np.std) -def std(a, axis=None, keepdims=None): - def f(input_tensor, axis, keepdims): - if input_tensor.dtype in (tf.complex64, tf.complex128): - # A workaround for b/157232284 - fn = _reduce_std_complex - else: - fn = tf.math.reduce_std - return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims) - return _reduce(f, a, axis=axis, dtype=None, keepdims=keepdims, - promote_int=_TO_FLOAT) - - -@utils.np_doc(np.ravel) -def ravel(a): # pylint: disable=missing-docstring - a = asarray(a) - if a.ndim == 1: - return a - return utils.tensor_to_ndarray(tf.reshape(a.data, [-1])) - - -setattr(arrays_lib.ndarray, 'ravel', ravel) - - -def real(val): - """Returns real parts of all elements in `a`. - - Uses `tf.real`. - - Args: - val: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - - Returns: - An ndarray with the same shape as `a`. - """ - val = asarray(val) - # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always - # return an ndarray. 
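The b/157232284 workaround above rests on the identity that complex variance decomposes into the variances of the real and imaginary parts; a quick NumPy check:

    import numpy as np

    z = np.array([1 + 2j, 3 + 4j, 5 + 6j])
    print(np.var(z))                          # 5.333..., a real number
    print(np.var(z.real) + np.var(z.imag))    # identical, by the identity above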
- return utils.tensor_to_ndarray(tf.math.real(val.data)) - - -@utils.np_doc(np.repeat) -def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring - a = asarray(a).data - original_shape = a._shape_as_list() # pylint: disable=protected-access - # Best effort recovery of the shape. - if original_shape is not None and None not in original_shape: - if not original_shape: - original_shape = (repeats,) - else: - repeats_np = np.ravel(np.array(repeats)) - if repeats_np.size == 1: - repeats_np = repeats_np.item() - if axis is None: - original_shape = (repeats_np * np.prod(original_shape),) - else: - original_shape[axis] = repeats_np * original_shape[axis] - else: - if axis is None: - original_shape = (repeats_np.sum(),) - else: - original_shape[axis] = repeats_np.sum() - - repeats = asarray(repeats).data - result = tf.repeat(a, repeats, axis) - result.set_shape(original_shape) - - return utils.tensor_to_ndarray(result) - - -@utils.np_doc(np.around) -def around(a, decimals=0): # pylint: disable=missing-docstring - a = asarray(a) - dtype = a.dtype - factor = math.pow(10, decimals) - if np.issubdtype(dtype, np.inexact): - factor = tf.cast(factor, dtype) - else: - # Use float as the working dtype when a.dtype is exact (e.g. integer), - # because `decimals` can be negative. - float_dtype = dtypes.default_float_type() - a = a.astype(float_dtype).data - factor = tf.cast(factor, float_dtype) - a = tf.multiply(a, factor) - a = tf.round(a) - a = tf.math.divide(a, factor) - return utils.tensor_to_ndarray(a).astype(dtype) - - -round_ = around -setattr(arrays_lib.ndarray, '__round__', around) - - -@utils.np_doc(np.reshape) -def reshape(a, newshape, order='C'): - """order argument can only be 'C' or 'F'.""" - if order not in {'C', 'F'}: - raise ValueError('Unsupported order argument {}'.format(order)) - - a = asarray(a) - if isinstance(newshape, arrays_lib.ndarray): - newshape = newshape.data - if isinstance(newshape, int): - newshape = [newshape] - - if order == 'F': - r = tf.transpose(tf.reshape(tf.transpose(a.data), newshape[::-1])) - else: - r = tf.reshape(a.data, newshape) - - return utils.tensor_to_ndarray(r) - - -def _reshape_method_wrapper(a, *newshape, **kwargs): - order = kwargs.pop('order', 'C') - if kwargs: - raise ValueError('Unsupported arguments: {}'.format(kwargs.keys())) - - if len(newshape) == 1 and not isinstance(newshape[0], int): - newshape = newshape[0] - - return reshape(a, newshape, order=order) - - -def expand_dims(a, axis): - """Expand the shape of an array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: int. axis on which to expand the shape. - - Returns: - An ndarray with the contents and dtype of `a` and shape expanded on axis. - """ - a = asarray(a) - return utils.tensor_to_ndarray(tf.expand_dims(a.data, axis=axis)) - - -def squeeze(a, axis=None): - """Removes single-element axes from the array. - - Args: - a: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axis: scalar or list/tuple of ints. - - TODO(srbs): tf.squeeze throws an error when axis is a Tensor and eager - execution is enabled. So we cannot allow axis to be array_like here. Fix. - - Returns: - An ndarray. - """ - a = asarray(a) - return utils.tensor_to_ndarray(tf.squeeze(a, axis)) - - -def transpose(a, axes=None): - """Permutes dimensions of the array. - - Args: - a: array_like.
Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - axes: array_like. A list of ints with length rank(a) or None specifying the - order of permutation. The i'th dimension of the output array corresponds - to axes[i]'th dimension of the `a`. If None, the axes are reversed. - - Returns: - An ndarray. - """ - a = asarray(a) - if axes is not None: - axes = asarray(axes) - return utils.tensor_to_ndarray(tf.transpose(a=a.data, perm=axes)) - - -@utils.np_doc(np.swapaxes) -def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring - a = asarray(a) - - a_rank = tf.rank(a) - if axis1 < 0: - axis1 += a_rank - if axis2 < 0: - axis2 += a_rank - - perm = tf.range(a_rank) - perm = tf.tensor_scatter_nd_update(perm, [[axis1], [axis2]], [axis2, axis1]) - a = tf.transpose(a, perm) - - return utils.tensor_to_ndarray(a) - - -@utils.np_doc(np.moveaxis) -def moveaxis(a, source, destination): # pylint: disable=missing-docstring - """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).""" - if not source and not destination: - return a - - a = asarray(a).data - - if isinstance(source, int): - source = (source,) - if isinstance(destination, int): - destination = (destination,) - - a_rank = utils._maybe_static(tf.rank(a)) # pylint: disable=protected-access - - def _correct_axis(axis, rank): - if axis < 0: - return axis + rank - return axis - - source = tuple(_correct_axis(axis, a_rank) for axis in source) - destination = tuple(_correct_axis(axis, a_rank) for axis in destination) - - if a.shape.rank is not None: - perm = [i for i in range(a_rank) if i not in source] - for dest, src in sorted(zip(destination, source)): - assert dest <= len(perm) - perm.insert(dest, src) - else: - r = tf.range(a_rank) - - def _remove_indices(a, b): - """Remove indices (`b`) from `a`.""" - items = tf.unstack(tf.sort(tf.stack(b)), num=len(b)) - - i = 0 - result = [] - - for item in items: - result.append(a[i:item]) - i = item + 1 - - result.append(a[i:]) - - return tf.concat(result, 0) - - minus_sources = _remove_indices(r, source) - minus_dest = _remove_indices(r, destination) - - perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank]) - perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1), - source) - a = tf.transpose(a, perm) - - return utils.tensor_to_ndarray(a) - - -def _setitem(arr, index, value): - """Sets the `value` at `index` in the array `arr`. - - This works by replacing the slice at `index` in the tensor with `value`. - Since tensors are immutable, this builds a new tensor using the `tf.concat` - op. Currently, only 0-d and 1-d indices are supported. - - Note that this may break gradients e.g. - - a = tf_np.array([1, 2, 3]) - old_a_t = a.data - - with tf.GradientTape(persistent=True) as g: - g.watch(a.data) - b = a * 2 - a[0] = 5 - g.gradient(b.data, [a.data]) # [None] - g.gradient(b.data, [old_a_t]) # [[2., 2., 2.]] - - Here `d_b / d_a` is `[None]` since a.data no longer points to the same - tensor. - - Args: - arr: array_like. - index: scalar or 1-d integer array. - value: value to set at index. - - Returns: - ndarray - - Raises: - ValueError: if `index` is not a scalar or 1-d array. - """ - # TODO(srbs): Figure out a solution to the gradient problem. 
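-  # Illustrative sketch of the intended semantics (hypothetical eager
-  # session; `tf_np` is an assumed alias for this module):
-  #   a = tf_np.array([1, 2, 3])
-  #   a[0] = 5   # __setitem__ routes here and rebuilds the tensor via concat
-  #   a          # -> ndarray([5, 2, 3])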
- arr = asarray(arr) - index = asarray(index) - if index.ndim == 0: - index = ravel(index) - elif index.ndim > 1: - raise ValueError('index must be a scalar or a 1-d array.') - value = asarray(value, dtype=arr.dtype) - if arr.shape[len(index):] != value.shape: - value = full(arr.shape[len(index):], value) - prefix_t = arr.data[:index.data[0]] - postfix_t = arr.data[index.data[0] + 1:] - if len(index) == 1: - arr._data = tf.concat( # pylint: disable=protected-access - [prefix_t, tf.expand_dims(value.data, 0), postfix_t], 0) - else: - subarray = arr[index.data[0]] - _setitem(subarray, index[1:], value) - arr._data = tf.concat( # pylint: disable=protected-access - [prefix_t, tf.expand_dims(subarray.data, 0), postfix_t], 0) - - -setattr(arrays_lib.ndarray, 'transpose', transpose) -setattr(arrays_lib.ndarray, 'reshape', _reshape_method_wrapper) -setattr(arrays_lib.ndarray, '__setitem__', _setitem) - - -def pad(ary, pad_width, mode, constant_values=0): - """Pads an array. - - Args: - ary: array_like of rank N. Input array. - pad_width: {sequence, array_like, int}. - Number of values padded to the edges of each axis. - ((before_1, after_1), ... (before_N, after_N)) unique pad widths - for each axis. - ((before, after),) yields same before and after pad for each axis. - (pad,) or int is a shortcut for before = after = pad width for all - axes. - mode: string. One of the following string values: - 'constant' - Pads with a constant value. - 'reflect' - Pads with the reflection of the vector mirrored on - the first and last values of the vector along each - axis. - 'symmetric' - Pads with the reflection of the vector mirrored - along the edge of the array. - **NOTE**: The supported list of `mode` does not match that of numpy's. - constant_values: scalar with same dtype as `array`. - Used in 'constant' mode as the pad value. Default is 0. - - - Returns: - An ndarray padded array of rank equal to `array` with shape increased - according to `pad_width`. - - Raises: - ValueError if `mode` is not supported. 
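-
-  Example (illustrative):
-    pad([1, 2, 3], [[1, 1]], 'constant')  # -> ndarray([0, 1, 2, 3, 0])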
- """ - if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'): - raise ValueError('Unsupported padding mode: ' + mode) - mode = mode.upper() - ary = asarray(ary) - pad_width = asarray(pad_width, dtype=tf.int32) - return utils.tensor_to_ndarray(tf.pad( - tensor=ary.data, paddings=pad_width.data, mode=mode, - constant_values=constant_values)) - - -@utils.np_doc(np.take) -def take(a, indices, axis=None, out=None, mode='clip'): - """out argument is not supported, and default mode is clip.""" - if out is not None: - raise ValueError('out argument is not supported in take.') - - if mode not in {'raise', 'clip', 'wrap'}: - raise ValueError("Invalid mode '{}' for take".format(mode)) - - a = asarray(a).data - indices = asarray(indices).data - - if axis is None: - a = tf.reshape(a, [-1]) - axis = 0 - - axis_size = tf.shape(a, indices.dtype)[axis] - if mode == 'clip': - indices = tf.clip_by_value(indices, 0, axis_size-1) - elif mode == 'wrap': - indices = tf.math.floormod(indices, axis_size) - else: - raise ValueError("The 'raise' mode to take is not supported.") - - return utils.tensor_to_ndarray(tf.gather(a, indices, axis=axis)) - - -@utils.np_doc_only(np.where) -def where(condition, x=None, y=None): - """Raises ValueError if exactly one of x or y is not None.""" - condition = asarray(condition, dtype=np.bool_) - if x is None and y is None: - return nonzero(condition) - elif x is not None and y is not None: - x, y = _promote_dtype(x, y) - return utils.tensor_to_ndarray(tf.where(condition.data, x.data, y.data)) - raise ValueError('Both x and y must be ndarrays, or both must be None.') - - -@utils.np_doc(np.select) -def select(condlist, choicelist, default=0): # pylint: disable=missing-docstring - if len(condlist) != len(choicelist): - msg = 'condlist must have length equal to choicelist ({} vs {})' - raise ValueError(msg.format(len(condlist), len(choicelist))) - if not condlist: - raise ValueError('condlist must be non-empty') - choices = _promote_dtype(default, *choicelist) - choicelist = choices[1:] - output = choices[0] - # The traversal is in reverse order so we can return the first value in - # choicelist where condlist is True. - for cond, choice in zip(condlist[::-1], choicelist[::-1]): - output = where(cond, choice, output) - return output - - -def shape(a): - """Return the shape of an array. - - Args: - a: array_like. Input array. - - Returns: - Tuple of ints. - """ - a = asarray(a) - return a.shape - - -def ndim(a): - a = asarray(a) - return a.ndim - - -def isscalar(a): - return ndim(a) == 0 - - -def _boundaries_to_sizes(a, boundaries, axis): - """Converting boundaries of splits to sizes of splits. - - Args: - a: the array to be split. - boundaries: the boundaries, as in np.split. - axis: the axis along which to split. - - Returns: - A list of sizes of the splits, as in tf.split. 
- """ - if axis >= len(a.shape): - raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape)) - total_size = a.shape[axis] - sizes = [] - sizes_sum = 0 - prev = 0 - for i, b in enumerate(boundaries): - size = b - prev - if size < 0: - raise ValueError('The %s-th boundary %s is smaller than the previous ' - 'boundary %s' % (i, b, prev)) - size = min(size, max(0, total_size - sizes_sum)) - sizes.append(size) - sizes_sum += size - prev = b - sizes.append(max(0, total_size - sizes_sum)) - return sizes - - -@utils.np_doc(np.split) -def split(ary, indices_or_sections, axis=0): - ary = asarray(ary) - if not isinstance(indices_or_sections, six.integer_types): - indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis) - result = tf.split(ary.data, indices_or_sections, axis=axis) - return [utils.tensor_to_ndarray(a) for a in result] - - -def _split_on_axis(np_fun, axis): - @utils.np_doc(np_fun) - def f(ary, indices_or_sections): - return split(ary, indices_or_sections, axis=axis) - return f - - -vsplit = _split_on_axis(np.vsplit, axis=0) -hsplit = _split_on_axis(np.hsplit, axis=1) -dsplit = _split_on_axis(np.dsplit, axis=2) - - -@utils.np_doc(np.broadcast_to) -def broadcast_to(array, shape): # pylint: disable=redefined-outer-name - return full(shape, array) - - -@utils.np_doc(np.stack) -def stack(arrays, axis=0): - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - return asarray(tf.stack(unwrapped_arrays, axis)) - - -@utils.np_doc(np.hstack) -def hstack(tup): - arrays = [atleast_1d(a) for a in tup] - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - rank = tf.rank(unwrapped_arrays[0]) - return utils.cond(rank == 1, lambda: tf.concat(unwrapped_arrays, axis=0), - lambda: tf.concat(unwrapped_arrays, axis=1)) - - -@utils.np_doc(np.vstack) -def vstack(tup): - arrays = [atleast_2d(a) for a in tup] - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - return tf.concat(unwrapped_arrays, axis=0) - - -@utils.np_doc(np.dstack) -def dstack(tup): - arrays = [atleast_3d(a) for a in tup] - arrays = _promote_dtype(*arrays) # pylint: disable=protected-access - unwrapped_arrays = [ - a.data if isinstance(a, arrays_lib.ndarray) else a for a in arrays - ] - return tf.concat(unwrapped_arrays, axis=2) - - -def _pad_left_to(n, old_shape): - old_shape = asarray(old_shape, dtype=np.int32).data - new_shape = tf.pad( - old_shape, [[tf.math.maximum(n - tf.size(old_shape), 0), 0]], - constant_values=1) - return asarray(new_shape) - - -def _atleast_nd(n, new_shape, *arys): - """Reshape arrays to be at least `n`-dimensional. - - Args: - n: The minimal rank. - new_shape: a function that takes `n` and the old shape and returns the - desired new shape. - *arys: ndarray(s) to be reshaped. - - Returns: - The reshaped array(s). 
- """ - - def f(x): - # pylint: disable=g-long-lambda - x = asarray(x) - return asarray( - utils.cond( - utils.greater(n, tf.rank(x)), - lambda: reshape(x, new_shape(n, tf.shape(x.data))).data, - lambda: x.data)) - - arys = list(map(f, arys)) - if len(arys) == 1: - return arys[0] - else: - return arys - - -@utils.np_doc(np.atleast_1d) -def atleast_1d(*arys): - return _atleast_nd(1, _pad_left_to, *arys) - - -@utils.np_doc(np.atleast_2d) -def atleast_2d(*arys): - return _atleast_nd(2, _pad_left_to, *arys) - - -@utils.np_doc(np.atleast_3d) -def atleast_3d(*arys): # pylint: disable=missing-docstring - - def new_shape(_, old_shape): - # pylint: disable=g-long-lambda - ndim_ = tf.size(old_shape) - return utils.cond( - ndim_ == 0, lambda: tf.constant([1, 1, 1], dtype=tf.int32), - lambda: utils.cond( - ndim_ == 1, lambda: tf.pad(old_shape, [[1, 1]], constant_values=1), - lambda: tf.pad(old_shape, [[0, 1]], constant_values=1))) - - return _atleast_nd(3, new_shape, *arys) - - -@utils.np_doc(np.nonzero) -def nonzero(a): - a = atleast_1d(a).data - if a.shape.rank is None: - raise ValueError("The rank of `a` is unknown, so we can't decide how many " - "arrays to return.") - return tf.nest.map_structure( - arrays_lib.tensor_to_ndarray, - tf.unstack(tf.where(tf.cast(a, tf.bool)), a.shape.rank, axis=1)) - - -@utils.np_doc(np.diag_indices) -def diag_indices(n, ndim=2): # pylint: disable=missing-docstring,redefined-outer-name - if n < 0: - raise ValueError('n argument to diag_indices must be nonnegative, got {}' - .format(n)) - if ndim < 0: - raise ValueError('ndim argument to diag_indices must be nonnegative, got {}' - .format(ndim)) - - return (tf.range(n),) * ndim - - -@utils.np_doc(np.tri) -def tri(N, M=None, k=0, dtype=None): # pylint: disable=invalid-name,missing-docstring - M = M if M is not None else N - if dtype is not None: - dtype = utils.result_type(dtype) - else: - dtype = dtypes.default_float_type() - - if k < 0: - lower = -k - 1 - if lower > N: - r = tf.zeros([N, M], dtype) - else: - # Keep as tf bool, since we create an upper triangular matrix and invert - # it. 
- o = tf.ones([N, M], dtype=tf.bool) - r = tf.cast(tf.math.logical_not(tf.linalg.band_part(o, lower, -1)), dtype) - else: - o = tf.ones([N, M], dtype) - if k > M: - r = o - else: - r = tf.linalg.band_part(o, -1, k) - return utils.tensor_to_ndarray(r) - - -@utils.np_doc(np.tril) -def tril(m, k=0): # pylint: disable=missing-docstring - m = asarray(m).data - m_shape = m.shape.as_list() - - if len(m_shape) < 2: - raise ValueError('Argument to tril must have rank at least 2') - - if m_shape[-1] is None or m_shape[-2] is None: - raise ValueError('Currently, the last two dimensions of the input array ' - 'need to be known.') - - z = tf.constant(0, m.dtype) - - mask = tri(*m_shape[-2:], k=k, dtype=bool) - return utils.tensor_to_ndarray( - tf.where(tf.broadcast_to(mask, tf.shape(m)), m, z)) - - -@utils.np_doc(np.triu) -def triu(m, k=0): # pylint: disable=missing-docstring - m = asarray(m).data - m_shape = m.shape.as_list() - - if len(m_shape) < 2: - raise ValueError('Argument to triu must have rank at least 2') - - if m_shape[-1] is None or m_shape[-2] is None: - raise ValueError('Currently, the last two dimensions of the input array ' - 'need to be known.') - - z = tf.constant(0, m.dtype) - - mask = tri(*m_shape[-2:], k=k - 1, dtype=bool) - return utils.tensor_to_ndarray( - tf.where(tf.broadcast_to(mask, tf.shape(m)), z, m)) - - -@utils.np_doc(np.flip) -def flip(m, axis=None): # pylint: disable=missing-docstring - m = asarray(m).data - - if axis is None: - return utils.tensor_to_ndarray(tf.reverse(m, tf.range(tf.rank(m)))) - - axis = utils._canonicalize_axis(axis, tf.rank(m)) # pylint: disable=protected-access - - return utils.tensor_to_ndarray(tf.reverse(m, [axis])) - - -@utils.np_doc(np.flipud) -def flipud(m): # pylint: disable=missing-docstring - return flip(m, 0) - - -@utils.np_doc(np.fliplr) -def fliplr(m): # pylint: disable=missing-docstring - return flip(m, 1) - - -@utils.np_doc(np.roll) -def roll(a, shift, axis=None): # pylint: disable=missing-docstring - a = asarray(a).data - - if axis is not None: - return utils.tensor_to_ndarray(tf.roll(a, shift, axis)) - - # If axis is None, the roll happens as a 1-d tensor. 
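-  # E.g. (illustrative): roll([[1, 2], [3, 4]], 1) flattens to [1, 2, 3, 4],
-  # rolls it to [4, 1, 2, 3], then reshapes back to [[4, 1], [2, 3]].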
- original_shape = tf.shape(a) - a = tf.roll(tf.reshape(a, [-1]), shift, 0) - return utils.tensor_to_ndarray(tf.reshape(a, original_shape)) - - -@utils.np_doc(np.rot90) -def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring - m_rank = tf.rank(m) - ax1, ax2 = utils._canonicalize_axes(axes, m_rank) # pylint: disable=protected-access - - k = k % 4 - if k == 0: - return m - elif k == 2: - return flip(flip(m, ax1), ax2) - else: - perm = tf.range(m_rank) - perm = tf.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1]) - - if k == 1: - return transpose(flip(m, ax2), perm) - else: - return flip(transpose(m, perm), ax2) - - -@utils.np_doc(np.vander) -def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name - x = asarray(x).data - - x_shape = tf.shape(x) - N = N or x_shape[0] - - N_temp = utils.get_static_value(N) # pylint: disable=invalid-name - if N_temp is not None: - N = N_temp - if N < 0: - raise ValueError('N must be nonnegative') - else: - tf.debugging.Assert(N >= 0, [N]) - - rank = tf.rank(x) - rank_temp = utils.get_static_value(rank) - if rank_temp is not None: - rank = rank_temp - if rank != 1: - raise ValueError('x must be a one-dimensional array') - else: - tf.debugging.Assert(rank == 1, [rank]) - - if increasing: - start = 0 - limit = N - delta = 1 - else: - start = N - 1 - limit = -1 - delta = -1 - - x = tf.expand_dims(x, -1) - return utils.tensor_to_ndarray( - tf.math.pow(x, tf.cast(tf.range(start, limit, delta), dtype=x.dtype))) - - -@utils.np_doc(np.ix_) -def ix_(*args): # pylint: disable=missing-docstring - n = len(args) - output = [] - for i, a in enumerate(args): - a = asarray(a).data - a_rank = tf.rank(a) - a_rank_temp = utils.get_static_value(a_rank) - if a_rank_temp is not None: - a_rank = a_rank_temp - if a_rank != 1: - raise ValueError( - 'Arguments must be 1-d, got arg {} of rank {}'.format(i, a_rank)) - else: - tf.debugging.Assert(a_rank == 1, [a_rank]) - - new_shape = [1] * n - new_shape[i] = -1 - dtype = a.dtype - if dtype == tf.bool: - output.append( - utils.tensor_to_ndarray(tf.reshape(nonzero(a)[0].data, new_shape))) - elif dtype.is_integer: - output.append(utils.tensor_to_ndarray(tf.reshape(a, new_shape))) - else: - raise ValueError( - 'Only integer and bool dtypes are supported, got {}'.format(dtype)) - - return output diff --git a/trax/tf_numpy/numpy_impl/arrays.py b/trax/tf_numpy/numpy_impl/arrays.py deleted file mode 100644 index 0329c25d0..000000000 --- a/trax/tf_numpy/numpy_impl/arrays.py +++ /dev/null @@ -1,286 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ndarray class.""" -import numpy as np -import six - -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import dtypes - - -def convert_to_tensor(value, dtype=None): - # A safer version of `tf.convert_to_tensor` to work around b/149876037. - # TODO(wangpeng): Remove this function once the bug is fixed. 
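-  # E.g. (illustrative): convert_to_tensor(2 ** 63) is routed to uint64
-  # (it would overflow int64), and convert_to_tensor(1.5) gets the
-  # configurable default float type instead of always float64.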
- if (dtype is None and isinstance(value, six.integer_types) - and value >= 2 ** 63): - dtype = tf.uint64 - elif (dtype is None and isinstance(value, float)): - dtype = dtypes.default_float_type() - return tf.convert_to_tensor(value, dtype=dtype) - - -class ndarray(object): # pylint: disable=invalid-name - """Equivalent of numpy.ndarray backed by TensorFlow tensors. - - This does not support all features of NumPy ndarrays e.g. strides and - memory order since, unlike NumPy, the backing storage is not a raw memory - buffer. - - TODO(srbs): Clearly specify which attributes and methods are not supported - or if there are any differences in behavior. - """ - - def __init__(self, shape, dtype=float, buffer=None): # pylint: disable=redefined-builtin - """Initializes an ndarray. - - This is a low level interface for building ndarrays and should be avoided. - Users should instead use methods in array_creation.py. - - This class provides a numpy.ndarray like interface for a TF Tensor with a - fully-defined shape. Note that, unlike the backing buffer of np.ndarray, - Tensors are immutable. So, operations like `__setitem__` are performed by - replacing the Tensor. This restricts the ability to implement NumPy `view` - semantics. - - Compared to numpy.ndarray, this does not support `offset`, `strides` - and `order` arguments. - - Args: - shape: The shape of the array. Must be a scalar, an iterable of integers - or a `TensorShape` object. - dtype: Optional. The dtype of the array. Must be a python type, a numpy - type or a tensorflow `DType` object. - buffer: Optional. The backing buffer of the array. Must have shape - `shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`. - - Raises: - ValueError: If `buffer` is specified and its shape does not match - `shape`. - """ - if dtype and not isinstance(dtype, tf.DType): - dtype = tf.as_dtype(np.dtype(dtype)) - if buffer is None: - buffer = tf.zeros(shape, dtype=dtype) - else: - if isinstance(buffer, ndarray): - buffer = buffer.data - elif isinstance(buffer, np.ndarray): - # If `buffer` is a np.ndarray, the Tensor will share the underlying - # storage of the array. - buffer = convert_to_tensor(value=buffer, dtype=dtype) - elif not isinstance(buffer, tf.Tensor): - raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,' - ' Tensor or np.ndarray.'.format(type(buffer))) - - if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access - # TODO(srbs): NumPy allows this. Investigate if/how to support this. - raise ValueError('shape arg must match buffer.shape.') - - assert isinstance(buffer, tf.Tensor) - if dtype and dtype != buffer.dtype: - buffer = tf.bitcast(buffer, dtype) - self._data = buffer - self.base = None - - @property - def data(self): - """Tensor object containing the array data. - - This has a few key differences from the Python buffer object used in - NumPy arrays. - 1. Tensors are immutable. So operations requiring in-place edit, e.g. - __setitem__, are performed by replacing the underlying buffer with a new - one. - 2. Tensors do not provide access to their raw buffer. - - Returns: - A Tensor. 
- """ - return self._data - - @property - def shape(self): - """Returns a tuple of array dimensions.""" - return self.data._shape_tuple() # pylint: disable=protected-access - - @property - def dtype(self): - return np.dtype(self.data.dtype.as_numpy_dtype) - - @property - def ndim(self): - return self.data.shape.ndims - - @property - def size(self): - """Returns the number of elements in the array.""" - return np.prod(self.shape) - - @property - def T(self): # pylint: disable=invalid-name - return self.transpose() - - def __len__(self): - if self.shape: - return self.shape[0] - else: - raise TypeError('len() of unsized object.') - - def astype(self, dtype): - if self.dtype == dtype: - return self - else: - return tensor_to_ndarray(tf.cast(self.data, dtype)) - - # Unary operations - def __neg__(self): - return tensor_to_ndarray(-self.data) # pylint: disable=invalid-unary-operand-type - - def __pos__(self): - return self - - __hash__ = None - - def __int__(self): - return int(self.data) - - def __float__(self): - return float(self.data) - - def __nonzero__(self): - return bool(self.data) - - def __bool__(self): - return self.__nonzero__() - - def __getitem__(self, slice_spec): - # TODO(srbs): Need to support better indexing. - result_t = self.data.__getitem__(slice_spec) - return tensor_to_ndarray(result_t) - - def __iter__(self): - for i in range(self.shape[0]): - result_t = self.data[i] - yield tensor_to_ndarray(result_t) - return - - def __array__(self, dtype=None): - """Returns a NumPy ndarray. - - This allows instances of this class to be directly used in NumPy routines. - However, doing that may force a copy to CPU. - - Args: - dtype: A NumPy compatible type. - - Returns: - A NumPy ndarray. - """ - return np.asarray(self.data, dtype) - - __array_priority__ = 110 - - def __index__(self): - """Returns a python scalar. - - This allows using an instance of this class as an array index. - Note that only arrays of integer types with size 1 can be used as array - indices. - - Returns: - A Python scalar. - - Raises: - TypeError: If the array is not of an integer type. - ValueError: If the array does not have size 1. - """ - # TODO(wangpeng): Handle graph mode - return self.data.numpy().item() - - def tolist(self): - return self.data.numpy().tolist() - - def __str__(self): - return 'ndarray<{}>'.format(self.data.__str__()) - - def __repr__(self): - return 'ndarray<{}>'.format(self.data.__repr__()) - - -def tensor_to_ndarray(tensor): - return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access - - -def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False): - if as_ref: - raise ValueError('as_ref is not supported.') - if dtype and tf.as_dtype(arr.dtype) != dtype: - return tf.cast(arr.data, dtype) - result_t = arr.data - if name: - result_t = tf.identity(result_t, name=name) - return result_t - - -tf.register_tensor_conversion_function(ndarray, ndarray_to_tensor) - - -# Don't use a namedtuple since nest considers that a tuple and unflattens and -# flattens it. -class ShardedNdArray(object): - """Wrapper over ndarray that can contain tensors on multiple devices. - - This is returned by extensions.pmap, and contains the individual tensors on - different devices. - """ - - def __init__(self, tensors): - """Initializes the ShardedNdArray. - - Note that the tensors should be ordered in the way the pmap producing these - tensors is run. - - Args: - tensors: list or tuple of eager tensors, one for each device. 
- """ - - if not isinstance(tensors, (list, tuple)) or not tensors: - raise ValueError( - 'Unable to create a ShardedNdArray without a list of tensors.') - self.tensors = tensors - self.n_devices = len(tensors) - - def __getitem__(self, i): - return self.tensors[i] - - @property - def shape(self): - return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access - - @property - def dtype(self): - return x.tensors[0].dtype - - -def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs): - del args, kwargs - # TODO(nareshmodi): Consider a collective op to gather the tensors from the - # various devices for performance reasons. - return tf.stack(value.tensors) - -tf.register_tensor_conversion_function( - ShardedNdArray, convert_sharded_tensor_to_eager_tensor) diff --git a/trax/tf_numpy/numpy_impl/dtypes.py b/trax/tf_numpy/numpy_impl/dtypes.py deleted file mode 100644 index 6ba976313..000000000 --- a/trax/tf_numpy/numpy_impl/dtypes.py +++ /dev/null @@ -1,94 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Dtypes and dtype utilities.""" -import numpy as np - -# We use numpy's dtypes instead of TF's, because the user expects to use them -# with numpy facilities such as `np.dtype(np.int64)` and -# `if x.dtype.type is np.int64`. -# pylint: disable=unused-import -# pylint: disable=g-bad-import-order -from numpy import bool_ -from numpy import int_ -from numpy import int16 -from numpy import int32 -from numpy import int64 -from numpy import int8 -from numpy import uint16 -from numpy import uint32 -from numpy import uint64 -from numpy import uint8 -from numpy import float_ -from numpy import float16 -from numpy import float32 -from numpy import float64 -from numpy import complex_ -from numpy import complex64 -from numpy import complex128 - -from numpy import inexact - -from numpy import iinfo -from numpy import issubdtype - -from numpy import inf - -# TODO(wangpeng): Make bfloat16 a numpy dtype instead of using TF's -from tensorflow.compat.v2 import bfloat16 -# pylint: enable=g-bad-import-order -# pylint: enable=unused-import - - -_to_float32 = { - np.dtype('float64'): np.dtype('float32'), - np.dtype('complex128'): np.dtype('complex64'), -} - - -_allow_float64 = True - - -def is_allow_float64(): - return _allow_float64 - - -def set_allow_float64(b): - global _allow_float64 - _allow_float64 = b - - -def canonicalize_dtype(dtype): - if not is_allow_float64(): - return _to_float32.get(dtype, dtype) - else: - return dtype - - -def _result_type(*arrays_and_dtypes): - dtype = np.result_type(*arrays_and_dtypes) - return canonicalize_dtype(dtype) - - -def default_float_type(): - """Gets the default float type. - - Returns: - If `is_allow_float64()` is true, returns float64; otherwise returns float32. 
- """ - if is_allow_float64(): - return float64 - else: - return float32 diff --git a/trax/tf_numpy/numpy_impl/math_ops.py b/trax/tf_numpy/numpy_impl/math_ops.py deleted file mode 100644 index aac45cc2c..000000000 --- a/trax/tf_numpy/numpy_impl/math_ops.py +++ /dev/null @@ -1,1140 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mathematical operations.""" -import sys - -import numpy as np -import six - -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import array_ops -from trax.tf_numpy.numpy_impl import arrays -from trax.tf_numpy.numpy_impl import dtypes -from trax.tf_numpy.numpy_impl import utils - - -@utils.np_doc_only(np.dot) -def dot(a, b): # pylint: disable=missing-docstring - def f(a, b): # pylint: disable=missing-docstring - return utils.cond( - utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0), - lambda: a * b, - lambda: utils.cond( # pylint: disable=g-long-lambda - tf.rank(b) == 1, - lambda: tf.tensordot(a, b, axes=[[-1], [-1]]), - lambda: tf.tensordot(a, b, axes=[[-1], [-2]]))) - return _bin_op(f, a, b) - - -# TODO(wangpeng): Make element-wise ops `ufunc`s -def _bin_op(tf_fun, a, b, promote=True): - if promote: - a, b = array_ops._promote_dtype(a, b) # pylint: disable=protected-access - else: - a = array_ops.array(a) - b = array_ops.array(b) - return utils.tensor_to_ndarray(tf_fun(a.data, b.data)) - - -@utils.np_doc(np.add) -def add(x1, x2): - def add_or_or(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_or(x1, x2) - return tf.add(x1, x2) - return _bin_op(add_or_or, x1, x2) - - -@utils.np_doc(np.subtract) -def subtract(x1, x2): - return _bin_op(tf.subtract, x1, x2) - - -@utils.np_doc(np.multiply) -def multiply(x1, x2): - def mul_or_and(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_and(x1, x2) - return tf.multiply(x1, x2) - return _bin_op(mul_or_and, x1, x2) - - -@utils.np_doc(np.true_divide) -def true_divide(x1, x2): - def _avoid_float64(x1, x2): - if x1.dtype == x2.dtype and x1.dtype in (tf.int32, tf.int64): - x1 = tf.cast(x1, dtype=tf.float32) - x2 = tf.cast(x2, dtype=tf.float32) - return x1, x2 - - def f(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - float_ = dtypes.default_float_type() - x1 = tf.cast(x1, float_) - x2 = tf.cast(x2, float_) - if not dtypes.is_allow_float64(): - # tf.math.truediv in Python3 produces float64 when both inputs are int32 - # or int64. We want to avoid that when is_allow_float64() is False. 
- x1, x2 = _avoid_float64(x1, x2) - return tf.math.truediv(x1, x2) - return _bin_op(f, x1, x2) - - -divide = true_divide - - -@utils.np_doc(np.floor_divide) -def floor_divide(x1, x2): - def f(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - x1 = tf.cast(x1, tf.int8) - x2 = tf.cast(x2, tf.int8) - return tf.math.floordiv(x1, x2) - return _bin_op(f, x1, x2) - - -@utils.np_doc(np.mod) -def mod(x1, x2): - def f(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - x1 = tf.cast(x1, tf.int8) - x2 = tf.cast(x2, tf.int8) - return tf.math.mod(x1, x2) - return _bin_op(f, x1, x2) - - -remainder = mod - - -@utils.np_doc(np.divmod) -def divmod(x1, x2): - return floor_divide(x1, x2), mod(x1, x2) - - -@utils.np_doc(np.maximum) -def maximum(x1, x2): - def max_or_or(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_or(x1, x2) - return tf.math.maximum(x1, x2) - return _bin_op(max_or_or, x1, x2) - - -@utils.np_doc(np.minimum) -def minimum(x1, x2): - def min_or_and(x1, x2): - if x1.dtype == tf.bool: - assert x2.dtype == tf.bool - return tf.logical_and(x1, x2) - return tf.math.minimum(x1, x2) - return _bin_op(min_or_and, x1, x2) - - -@utils.np_doc(np.clip) -def clip(a, a_min, a_max): # pylint: disable=missing-docstring - if a_min is None and a_max is None: - raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.') - if a_min is None: - return minimum(a, a_max) - elif a_max is None: - return maximum(a, a_min) - else: - a, a_min, a_max = array_ops._promote_dtype(a, a_min, a_max) # pylint: disable=protected-access - return utils.tensor_to_ndarray( - tf.clip_by_value(*utils.tf_broadcast(a.data, a_min.data, a_max.data))) - - -@utils.np_doc(np.matmul) -def matmul(x1, x2): # pylint: disable=missing-docstring - def f(x1, x2): - try: - return utils.cond(tf.rank(x2) == 1, - lambda: tf.tensordot(x1, x2, axes=1), - lambda: utils.cond(tf.rank(x1) == 1, # pylint: disable=g-long-lambda - lambda: tf.tensordot( # pylint: disable=g-long-lambda - x1, x2, axes=[[0], [-2]]), - lambda: tf.matmul(x1, x2))) - except tf.errors.InvalidArgumentError as err: - six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2]) - return _bin_op(f, x1, x2) - - -@utils.np_doc(np.tensordot) -def tensordot(a, b, axes=2): - return _bin_op(lambda a, b: tf.tensordot(a, b, axes=axes), a, b) - - -@utils.np_doc_only(np.inner) -def inner(a, b): - def f(a, b): - return utils.cond(utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0), - lambda: a * b, - lambda: tf.tensordot(a, b, axes=[[-1], [-1]])) - return _bin_op(f, a, b) - - -@utils.np_doc(np.cross) -def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=missing-docstring - def f(a, b): # pylint: disable=missing-docstring - # We can't assign to captured variable `axisa`, so make a new variable - axis_a = axisa - axis_b = axisb - axis_c = axisc - if axis is not None: - axis_a = axis - axis_b = axis - axis_c = axis - if axis_a < 0: - axis_a = utils.add(axis_a, tf.rank(a)) - if axis_b < 0: - axis_b = utils.add(axis_b, tf.rank(b)) - def maybe_move_axis_to_last(a, axis): - def move_axis_to_last(a, axis): - return tf.transpose( - a, tf.concat( - [tf.range(axis), tf.range(axis + 1, tf.rank(a)), [axis]], - axis=0)) - return utils.cond( - axis == utils.subtract(tf.rank(a), 1), - lambda: a, - lambda: move_axis_to_last(a, axis)) - a = maybe_move_axis_to_last(a, axis_a) - b = maybe_move_axis_to_last(b, axis_b) - a_dim = utils.getitem(tf.shape(a), -1) - b_dim = utils.getitem(tf.shape(b), -1) - def maybe_pad_0(a, 
size_of_last_dim): - def pad_0(a): - return tf.pad(a, tf.concat([tf.zeros([tf.rank(a) - 1, 2], tf.int32), - tf.constant([[0, 1]], tf.int32)], axis=0)) - return utils.cond(size_of_last_dim == 2, - lambda: pad_0(a), - lambda: a) - a = maybe_pad_0(a, a_dim) - b = maybe_pad_0(b, b_dim) - c = tf.linalg.cross(*utils.tf_broadcast(a, b)) - if axis_c < 0: - axis_c = utils.add(axis_c, tf.rank(c)) - def move_last_to_axis(a, axis): - r = tf.rank(a) - return tf.transpose( - a, tf.concat( - [tf.range(axis), [r - 1], tf.range(axis, r - 1)], axis=0)) - c = utils.cond( - (a_dim == 2) & (b_dim == 2), - lambda: c[..., 2], - lambda: utils.cond( # pylint: disable=g-long-lambda - axis_c == utils.subtract(tf.rank(c), 1), - lambda: c, - lambda: move_last_to_axis(c, axis_c))) - return c - return _bin_op(f, a, b) - - -@utils.np_doc(np.power) -def power(x1, x2): - return _bin_op(tf.math.pow, x1, x2) - - -@utils.np_doc(np.float_power) -def float_power(x1, x2): - return power(x1, x2) - - -@utils.np_doc(np.arctan2) -def arctan2(x1, x2): - return _bin_op(tf.math.atan2, x1, x2) - - -@utils.np_doc(np.nextafter) -def nextafter(x1, x2): - return _bin_op(tf.math.nextafter, x1, x2) - - -@utils.np_doc(np.heaviside) -def heaviside(x1, x2): - def f(x1, x2): - return tf.where(x1 < 0, tf.constant(0, dtype=x2.dtype), - tf.where(x1 > 0, tf.constant(1, dtype=x2.dtype), x2)) - y = _bin_op(f, x1, x2) - if not np.issubdtype(y.dtype, np.inexact): - y = y.astype(dtypes.default_float_type()) - return y - - -@utils.np_doc(np.hypot) -def hypot(x1, x2): - return sqrt(square(x1) + square(x2)) - - -@utils.np_doc(np.kron) -def kron(a, b): - # pylint: disable=protected-access,g-complex-comprehension - a, b = array_ops._promote_dtype(a, b) - ndim = max(a.ndim, b.ndim) - if a.ndim < ndim: - a = array_ops.reshape(a, array_ops._pad_left_to(ndim, a.shape)) - if b.ndim < ndim: - b = array_ops.reshape(b, array_ops._pad_left_to(ndim, b.shape)) - a_reshaped = array_ops.reshape(a, [i for d in a.shape for i in (d, 1)]) - b_reshaped = array_ops.reshape(b, [i for d in b.shape for i in (1, d)]) - out_shape = tuple(np.multiply(a.shape, b.shape)) - return array_ops.reshape(a_reshaped * b_reshaped, out_shape) - - -@utils.np_doc(np.outer) -def outer(a, b): - def f(a, b): - return tf.reshape(a, [-1, 1]) * tf.reshape(b, [-1]) - return _bin_op(f, a, b) - - -# This can also be implemented via tf.reduce_logsumexp -@utils.np_doc(np.logaddexp) -def logaddexp(x1, x2): - amax = maximum(x1, x2) - delta = x1 - x2 - return array_ops.where( - isnan(delta), - x1 + x2, # NaNs or infinities of the same sign. - amax + log1p(exp(-abs(delta)))) - - -@utils.np_doc(np.logaddexp2) -def logaddexp2(x1, x2): - amax = maximum(x1, x2) - delta = x1 - x2 - return array_ops.where( - isnan(delta), - x1 + x2, # NaNs or infinities of the same sign. - amax + log1p(exp2(-abs(delta))) / np.log(2)) - - -@utils.np_doc(np.polyval) -def polyval(p, x): - def f(p, x): - if p.shape.rank == 0: - p = tf.reshape(p, [1]) - p = tf.unstack(p) - # TODO(wangpeng): Make tf version take a tensor for p instead of a list. - y = tf.math.polyval(p, x) - # If the polynomial is 0-order, numpy requires the result to be broadcast to - # `x`'s shape. 
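-    # E.g. (illustrative): polyval([3.], [1., 2., 3.]) -> [3., 3., 3.],
-    # matching numpy's broadcast of a constant polynomial.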
- if len(p) == 1: - y = tf.broadcast_to(y, x.shape) - return y - return _bin_op(f, p, x) - - -@utils.np_doc(np.isclose) -def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): # pylint: disable=missing-docstring - def f(a, b): # pylint: disable=missing-docstring - dtype = a.dtype - if np.issubdtype(dtype.as_numpy_dtype, np.inexact): - rtol_ = tf.convert_to_tensor(rtol, dtype.real_dtype) - atol_ = tf.convert_to_tensor(atol, dtype.real_dtype) - result = (tf.math.abs(a - b) <= atol_ + rtol_ * tf.math.abs(b)) - if equal_nan: - result = result | (tf.math.is_nan(a) & tf.math.is_nan(b)) - return result - else: - return a == b - return _bin_op(f, a, b) - - -@utils.np_doc(np.allclose) -def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): - return array_ops.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) - - -def _tf_gcd(x1, x2): - def _gcd_cond_fn(x1, x2): - return tf.reduce_any(x2 != 0) - def _gcd_body_fn(x1, x2): - # tf.math.mod will raise an error when any element of x2 is 0. To avoid - # that, we change those zeros to ones. Their values don't matter because - # they won't be used. - x2_safe = tf.where(x2 != 0, x2, tf.constant(1, x2.dtype)) - x1, x2 = (tf.where(x2 != 0, x2, x1), - tf.where(x2 != 0, tf.math.mod(x1, x2_safe), - tf.constant(0, x2.dtype))) - return (tf.where(x1 < x2, x2, x1), tf.where(x1 < x2, x1, x2)) - if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or - not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)): - raise ValueError("Arguments to gcd must be integers.") - shape = tf.broadcast_static_shape(x1.shape, x2.shape) - x1 = tf.broadcast_to(x1, shape) - x2 = tf.broadcast_to(x2, shape) - gcd, _ = tf.while_loop(_gcd_cond_fn, _gcd_body_fn, - (tf.math.abs(x1), tf.math.abs(x2))) - return gcd - - -@utils.np_doc(np.gcd) -def gcd(x1, x2): - return _bin_op(_tf_gcd, x1, x2) - - -@utils.np_doc(np.lcm) -def lcm(x1, x2): - def f(x1, x2): - d = _tf_gcd(x1, x2) - # Same as the `x2_safe` trick above - d_safe = tf.where(d == 0, tf.constant(1, d.dtype), d) - return tf.where(d == 0, tf.constant(0, d.dtype), - tf.math.abs(x1 * x2) // d_safe) - return _bin_op(f, x1, x2) - - -def _bitwise_binary_op(tf_fn, x1, x2): - def f(x1, x2): - is_bool = (x1.dtype == tf.bool) - if is_bool: - assert x2.dtype == tf.bool - x1 = tf.cast(x1, tf.int8) - x2 = tf.cast(x2, tf.int8) - r = tf_fn(x1, x2) - if is_bool: - r = tf.cast(r, tf.bool) - return r - return _bin_op(f, x1, x2) - - -@utils.np_doc(np.bitwise_and) -def bitwise_and(x1, x2): - return _bitwise_binary_op(tf.bitwise.bitwise_and, x1, x2) - - -@utils.np_doc(np.bitwise_or) -def bitwise_or(x1, x2): - return _bitwise_binary_op(tf.bitwise.bitwise_or, x1, x2) - - -@utils.np_doc(np.bitwise_xor) -def bitwise_xor(x1, x2): - return _bitwise_binary_op(tf.bitwise.bitwise_xor, x1, x2) - - -@utils.np_doc(np.bitwise_not) -def bitwise_not(x): - def f(x): - if x.dtype == tf.bool: - return tf.logical_not(x) - return tf.bitwise.invert(x) - return _scalar(f, x) - - -def _scalar(tf_fn, x, promote_to_float=False): - """Computes the tf_fn(x) for each element in `x`. - - Args: - tf_fn: function that takes a single Tensor argument. - x: array_like. Could be an ndarray, a Tensor or any object that can - be converted to a Tensor using `tf.convert_to_tensor`. - promote_to_float: whether to cast the argument to a float dtype - (`dtypes.default_float_type`) if it is not already. - - Returns: - An ndarray with the same shape as `x`. 
The default output dtype is - determined by `dtypes.default_float_type`, unless x is an ndarray with a - floating point type, in which case the output type is same as x.dtype. - """ - x = array_ops.asarray(x) - if promote_to_float and not np.issubdtype(x.dtype, np.inexact): - x = x.astype(dtypes.default_float_type()) - return utils.tensor_to_ndarray(tf_fn(x.data)) - - -@utils.np_doc(np.log) -def log(x): - return _scalar(tf.math.log, x, True) - - -@utils.np_doc(np.exp) -def exp(x): - return _scalar(tf.exp, x, True) - - -@utils.np_doc(np.sqrt) -def sqrt(x): - return _scalar(tf.sqrt, x, True) - - -@utils.np_doc(np.abs) -def abs(x): - return _scalar(tf.math.abs, x) - - -@utils.np_doc(np.absolute) -def absolute(x): - return abs(x) - - -@utils.np_doc(np.fabs) -def fabs(x): - return abs(x) - - -@utils.np_doc(np.ceil) -def ceil(x): - return _scalar(tf.math.ceil, x, True) - - -@utils.np_doc(np.floor) -def floor(x): - return _scalar(tf.math.floor, x, True) - - -@utils.np_doc(np.conj) -def conj(x): - return _scalar(tf.math.conj, x) - - -@utils.np_doc(np.negative) -def negative(x): - return _scalar(tf.math.negative, x) - - -@utils.np_doc(np.reciprocal) -def reciprocal(x): - return _scalar(tf.math.reciprocal, x) - - -@utils.np_doc(np.signbit) -def signbit(x): - def f(x): - if x.dtype == tf.bool: - return tf.fill(x.shape, False) - return x < 0 - return _scalar(f, x) - - -@utils.np_doc(np.sin) -def sin(x): - return _scalar(tf.math.sin, x, True) - - -@utils.np_doc(np.cos) -def cos(x): - return _scalar(tf.math.cos, x, True) - - -@utils.np_doc(np.tan) -def tan(x): - return _scalar(tf.math.tan, x, True) - - -@utils.np_doc(np.sinh) -def sinh(x): - return _scalar(tf.math.sinh, x, True) - - -@utils.np_doc(np.cosh) -def cosh(x): - return _scalar(tf.math.cosh, x, True) - - -@utils.np_doc(np.tanh) -def tanh(x): - return _scalar(tf.math.tanh, x, True) - - -@utils.np_doc(np.arcsin) -def arcsin(x): - return _scalar(tf.math.asin, x, True) - - -@utils.np_doc(np.arccos) -def arccos(x): - return _scalar(tf.math.acos, x, True) - - -@utils.np_doc(np.arctan) -def arctan(x): - return _scalar(tf.math.atan, x, True) - - -@utils.np_doc(np.arcsinh) -def arcsinh(x): - return _scalar(tf.math.asinh, x, True) - - -@utils.np_doc(np.arccosh) -def arccosh(x): - return _scalar(tf.math.acosh, x, True) - - -@utils.np_doc(np.arctanh) -def arctanh(x): - return _scalar(tf.math.atanh, x, True) - - -@utils.np_doc(np.deg2rad) -def deg2rad(x): - def f(x): - return x * (np.pi / 180.0) - return _scalar(f, x, True) - - -@utils.np_doc(np.rad2deg) -def rad2deg(x): - return x * (180.0 / np.pi) - - -_tf_float_types = [tf.bfloat16, tf.float16, tf.float32, tf.float64] - - -@utils.np_doc(np.angle) -def angle(z, deg=False): - def f(x): - if x.dtype in _tf_float_types: - # Workaround for b/147515503 - return tf.where(x < 0, np.pi, 0) - else: - return tf.math.angle(x) - y = _scalar(f, z, True) - if deg: - y = rad2deg(y) - return y - - -@utils.np_doc(np.cbrt) -def cbrt(x): - def f(x): - # __pow__ can't handle negative base, so we use `abs` here. 
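-    # E.g. (illustrative): cbrt(-8.) -> -2., computed as
-    # -(abs(-8.) ** (1. / 3)).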
- rt = tf.math.abs(x) ** (1.0 / 3) - return tf.where(x < 0, -rt, rt) - return _scalar(f, x, True) - - -@utils.np_doc(np.conjugate) -def conjugate(x): - return _scalar(tf.math.conj, x) - - -@utils.np_doc(np.exp2) -def exp2(x): - def f(x): - return 2 ** x - return _scalar(f, x, True) - - -@utils.np_doc(np.expm1) -def expm1(x): - return _scalar(tf.math.expm1, x, True) - - -@utils.np_doc(np.fix) -def fix(x): - def f(x): - return tf.where(x < 0, tf.math.ceil(x), tf.math.floor(x)) - return _scalar(f, x, True) - - -@utils.np_doc(np.iscomplex) -def iscomplex(x): - return array_ops.imag(x) != 0 - - -@utils.np_doc(np.isreal) -def isreal(x): - return array_ops.imag(x) == 0 - - -@utils.np_doc(np.iscomplexobj) -def iscomplexobj(x): - x = array_ops.array(x) - return np.issubdtype(x.dtype, np.complexfloating) - - -@utils.np_doc(np.isrealobj) -def isrealobj(x): - return not iscomplexobj(x) - - -@utils.np_doc(np.isnan) -def isnan(x): - return _scalar(tf.math.is_nan, x, True) - - -def _make_nan_reduction(onp_reduction, reduction, init_val): - """Helper to generate nan* functions.""" - @utils.np_doc(onp_reduction) - def nan_reduction(a, axis=None, dtype=None, keepdims=False): - a = array_ops.array(a) - v = array_ops.array(init_val, dtype=a.dtype) - return reduction( - array_ops.where(isnan(a), v, a), - axis=axis, - dtype=dtype, - keepdims=keepdims) - return nan_reduction - - -nansum = _make_nan_reduction(np.nansum, array_ops.sum, 0) -nanprod = _make_nan_reduction(np.nanprod, array_ops.prod, 1) - - -@utils.np_doc(np.nanmean) -def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring - a = array_ops.array(a) - if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer): - return array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) - nan_mask = logical_not(isnan(a)) - if dtype is None: - dtype = a.dtype - normalizer = array_ops.sum( - nan_mask, axis=axis, dtype=dtype, keepdims=keepdims) - return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer - - -@utils.np_doc(np.isfinite) -def isfinite(x): - return _scalar(tf.math.is_finite, x, True) - - -@utils.np_doc(np.isinf) -def isinf(x): - return _scalar(tf.math.is_inf, x, True) - - -@utils.np_doc(np.isneginf) -def isneginf(x): - return x == array_ops.full_like(x, -np.inf) - - -@utils.np_doc(np.isposinf) -def isposinf(x): - return x == array_ops.full_like(x, np.inf) - - -@utils.np_doc(np.log2) -def log2(x): - return log(x) / np.log(2) - - -@utils.np_doc(np.log10) -def log10(x): - return log(x) / np.log(10) - - -@utils.np_doc(np.log1p) -def log1p(x): - return _scalar(tf.math.log1p, x, True) - - -@utils.np_doc(np.positive) -def positive(x): - return _scalar(lambda x: x, x) - - -@utils.np_doc(np.sinc) -def sinc(x): - def f(x): - pi_x = x * np.pi - return tf.where(x == 0, tf.ones_like(x), tf.math.sin(pi_x) / pi_x) - return _scalar(f, x, True) - - -@utils.np_doc(np.square) -def square(x): - return _scalar(tf.math.square, x) - - -@utils.np_doc(np.diff) -def diff(a, n=1, axis=-1): - def f(a): - nd = a.shape.rank - if (axis + nd if axis < 0 else axis) >= nd: - raise ValueError("axis %s is out of bounds for array of dimension %s" % - (axis, nd)) - if n < 0: - raise ValueError("order must be non-negative but got %s" % n) - slice1 = [slice(None)] * nd - slice2 = [slice(None)] * nd - slice1[axis] = slice(1, None) - slice2[axis] = slice(None, -1) - slice1 = tuple(slice1) - slice2 = tuple(slice2) - op = tf.not_equal if a.dtype == tf.bool else tf.subtract - for _ in range(n): - a = op(a[slice1], a[slice2]) - 
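-    # E.g. (illustrative): diff([1, 4, 9, 16]) -> [3, 5, 7]; with n=2 the
-    # result is [2, 2].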
return a - return _scalar(f, a) - - -def _flip_args(f): - def _f(a, b): - return f(b, a) - return _f - - -setattr(arrays.ndarray, '__abs__', absolute) -setattr(arrays.ndarray, '__floordiv__', floor_divide) -setattr(arrays.ndarray, '__rfloordiv__', _flip_args(floor_divide)) -setattr(arrays.ndarray, '__mod__', mod) -setattr(arrays.ndarray, '__rmod__', _flip_args(mod)) -setattr(arrays.ndarray, '__add__', add) -setattr(arrays.ndarray, '__radd__', _flip_args(add)) -setattr(arrays.ndarray, '__sub__', subtract) -setattr(arrays.ndarray, '__rsub__', _flip_args(subtract)) -setattr(arrays.ndarray, '__mul__', multiply) -setattr(arrays.ndarray, '__rmul__', _flip_args(multiply)) -setattr(arrays.ndarray, '__pow__', power) -setattr(arrays.ndarray, '__rpow__', _flip_args(power)) -setattr(arrays.ndarray, '__truediv__', true_divide) -setattr(arrays.ndarray, '__rtruediv__', _flip_args(true_divide)) - - -def _comparison(tf_fun, x1, x2, cast_bool_to_int=False): - dtype = utils.result_type(x1, x2) - # Cast x1 and x2 to the result_type if needed. - x1 = array_ops.array(x1, dtype=dtype) - x2 = array_ops.array(x2, dtype=dtype) - x1 = x1.data - x2 = x2.data - if cast_bool_to_int and x1.dtype == tf.bool: - x1 = tf.cast(x1, tf.int32) - x2 = tf.cast(x2, tf.int32) - return utils.tensor_to_ndarray(tf_fun(x1, x2)) - - -@utils.np_doc(np.equal) -def equal(x1, x2): - return _comparison(tf.equal, x1, x2) - - -@utils.np_doc(np.not_equal) -def not_equal(x1, x2): - return _comparison(tf.not_equal, x1, x2) - - -@utils.np_doc(np.greater) -def greater(x1, x2): - return _comparison(tf.greater, x1, x2, True) - - -@utils.np_doc(np.greater_equal) -def greater_equal(x1, x2): - return _comparison(tf.greater_equal, x1, x2, True) - - -@utils.np_doc(np.less) -def less(x1, x2): - return _comparison(tf.less, x1, x2, True) - - -@utils.np_doc(np.less_equal) -def less_equal(x1, x2): - return _comparison(tf.less_equal, x1, x2, True) - - -@utils.np_doc(np.array_equal) -def array_equal(a1, a2): - def f(a1, a2): - if a1.shape != a2.shape: - return tf.constant(False) - return tf.reduce_all(tf.equal(a1, a2)) - return _comparison(f, a1, a2) - - -def _logical_binary_op(tf_fun, x1, x2): - x1 = array_ops.array(x1, dtype=np.bool_) - x2 = array_ops.array(x2, dtype=np.bool_) - return utils.tensor_to_ndarray(tf_fun(x1.data, x2.data)) - - -@utils.np_doc(np.logical_and) -def logical_and(x1, x2): - return _logical_binary_op(tf.logical_and, x1, x2) - - -@utils.np_doc(np.logical_or) -def logical_or(x1, x2): - return _logical_binary_op(tf.logical_or, x1, x2) - - -@utils.np_doc(np.logical_xor) -def logical_xor(x1, x2): - return _logical_binary_op(tf.math.logical_xor, x1, x2) - - -@utils.np_doc(np.logical_not) -def logical_not(x): - x = array_ops.array(x, dtype=np.bool_) - return utils.tensor_to_ndarray(tf.logical_not(x.data)) - -setattr(arrays.ndarray, '__invert__', logical_not) -setattr(arrays.ndarray, '__lt__', less) -setattr(arrays.ndarray, '__le__', less_equal) -setattr(arrays.ndarray, '__gt__', greater) -setattr(arrays.ndarray, '__ge__', greater_equal) -setattr(arrays.ndarray, '__eq__', equal) -setattr(arrays.ndarray, '__ne__', not_equal) - - -@utils.np_doc(np.linspace) -def linspace( # pylint: disable=missing-docstring - start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0): - if dtype: - dtype = utils.result_type(dtype) - start = array_ops.array(start, dtype=dtype).data - stop = array_ops.array(stop, dtype=dtype).data - if num < 0: - raise ValueError('Number of samples {} must be non-negative.'.format(num)) - step = 
tf.convert_to_tensor(np.nan) - if endpoint: - result = tf.linspace(start, stop, num, axis=axis) - if num > 1: - step = (stop - start) / (num - 1) - else: - # tf.linspace does not support endpoint=False so we manually handle it - # here. - if num > 1: - step = ((stop - start) / num) - new_stop = tf.cast(stop, step.dtype) - step - start = tf.cast(start, new_stop.dtype) - result = tf.linspace(start, new_stop, num, axis=axis) - else: - result = tf.linspace(start, stop, num, axis=axis) - if dtype: - result = tf.cast(result, dtype) - if retstep: - return arrays.tensor_to_ndarray(result), arrays.tensor_to_ndarray(step) - else: - return arrays.tensor_to_ndarray(result) - - -@utils.np_doc(np.logspace) -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): - dtype = utils.result_type(start, stop, dtype) - result = linspace( - start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis).data - result = tf.pow(tf.cast(base, result.dtype), result) - if dtype: - result = tf.cast(result, dtype) - return arrays.tensor_to_ndarray(result) - - -@utils.np_doc(np.ptp) -def ptp(a, axis=None, keepdims=None): - return (array_ops.amax(a, axis=axis, keepdims=keepdims) - - array_ops.amin(a, axis=axis, keepdims=keepdims)) - - -@utils.np_doc_only(np.concatenate) -def concatenate(arys, axis=0): - if not isinstance(arys, (list, tuple)): - arys = [arys] - if not arys: - raise ValueError('Need at least one array to concatenate.') - dtype = utils.result_type(*arys) - arys = [array_ops.array(array, dtype=dtype).data for array in arys] - return arrays.tensor_to_ndarray(tf.concat(arys, axis)) - - -@utils.np_doc_only(np.tile) -def tile(a, reps): - a = array_ops.array(a).data - reps = array_ops.array(reps, dtype=tf.int32).reshape([-1]).data - - a_rank = tf.rank(a) - reps_size = tf.size(reps) - reps = tf.pad( - reps, [[tf.math.maximum(a_rank - reps_size, 0), 0]], - constant_values=1) - a_shape = tf.pad( - tf.shape(a), [[tf.math.maximum(reps_size - a_rank, 0), 0]], - constant_values=1) - a = tf.reshape(a, a_shape) - - return arrays.tensor_to_ndarray(tf.tile(a, reps)) - - -@utils.np_doc(np.count_nonzero) -def count_nonzero(a, axis=None): - return arrays.tensor_to_ndarray( - tf.math.count_nonzero(array_ops.array(a).data, axis)) - - -@utils.np_doc(np.argsort) -def argsort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring - # TODO(nareshmodi): make string tensors also work. 
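-  # E.g. (illustrative): argsort([30, 10, 20]) -> ndarray([1, 2, 0]).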
- if kind not in ('quicksort', 'stable'): - raise ValueError("Only 'quicksort' and 'stable' arguments are supported.") - if order is not None: - raise ValueError("'order' argument to sort is not supported.") - stable = (kind == 'stable') - - a = array_ops.array(a).data - - def _argsort(a, axis, stable): - if axis is None: - a = tf.reshape(a, [-1]) - axis = 0 - - return tf.argsort(a, axis, stable=stable) - - tf_ans = tf.cond( - tf.rank(a) == 0, lambda: tf.constant([0]), - lambda: _argsort(a, axis, stable)) - - return array_ops.array(tf_ans, dtype=np.intp) - - -@utils.np_doc(np.sort) -def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring - if kind != 'quicksort': - raise ValueError("Only 'quicksort' is supported.") - if order is not None: - raise ValueError("'order' argument to sort is not supported.") - - a = array_ops.array(a) - - if axis is None: - result_t = tf.sort(tf.reshape(a.data, [-1]), 0) - return utils.tensor_to_ndarray(result_t) - else: - return utils.tensor_to_ndarray(tf.sort(a.data, axis)) - - -def _argminmax(fn, a, axis=None): - a = array_ops.array(a) - if axis is None: - # When axis is None numpy flattens the array. - a_t = tf.reshape(a.data, [-1]) - else: - a_t = array_ops.atleast_1d(a).data - return utils.tensor_to_ndarray(fn(input=a_t, axis=axis)) - - -@utils.np_doc(np.argmax) -def argmax(a, axis=None): - return _argminmax(tf.argmax, a, axis) - - -@utils.np_doc(np.argmin) -def argmin(a, axis=None): - return _argminmax(tf.argmin, a, axis) - - -@utils.np_doc(np.append) -def append(arr, values, axis=None): - if axis is None: - return concatenate([array_ops.ravel(arr), array_ops.ravel(values)], 0) - else: - return concatenate([arr, values], axis=axis) - - -@utils.np_doc(np.average) -def average(a, axis=None, weights=None, returned=False): # pylint: disable=missing-docstring - if axis is not None and not isinstance(axis, six.integer_types): - # TODO(wangpeng): Support tuple of ints as `axis` - raise ValueError('`axis` must be an integer. Tuple of ints is not ' - 'supported yet. 
Got type: %s' % type(axis)) - a = array_ops.array(a) - if weights is None: # Treat all weights as 1 - if not np.issubdtype(a.dtype, np.inexact): - a = a.astype(utils.result_type(a.dtype, dtypes.default_float_type())) - avg = tf.reduce_mean(a.data, axis=axis) - if returned: - if axis is None: - weights_sum = tf.size(a.data) - else: - weights_sum = tf.shape(a.data)[axis] - weights_sum = tf.cast(weights_sum, a.data.dtype) - else: - if np.issubdtype(a.dtype, np.inexact): - out_dtype = utils.result_type(a.dtype, weights) - else: - out_dtype = utils.result_type(a.dtype, weights, - dtypes.default_float_type()) - a = array_ops.array(a, out_dtype).data - weights = array_ops.array(weights, out_dtype).data - - def rank_equal_case(): - tf.debugging.Assert(tf.reduce_all(tf.shape(a) == tf.shape(weights)), - [tf.shape(a), tf.shape(weights)]) - weights_sum = tf.reduce_sum(weights, axis=axis) - avg = tf.reduce_sum(a * weights, axis=axis) / weights_sum - return avg, weights_sum - if axis is None: - avg, weights_sum = rank_equal_case() - else: - def rank_not_equal_case(): - tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)]) - weights_sum = tf.reduce_sum(weights) - axes = tf.convert_to_tensor([[axis], [0]]) - avg = tf.tensordot(a, weights, axes) / weights_sum - return avg, weights_sum - # We condition on rank rather than shape equality, because if we do the - # latter, when the shapes are partially unknown but the ranks are known - # and different, utils.cond will run shape checking on the true branch, - # which will raise a shape-checking error. - avg, weights_sum = utils.cond(tf.rank(a) == tf.rank(weights), - rank_equal_case, rank_not_equal_case) - - avg = array_ops.array(avg) - if returned: - weights_sum = array_ops.broadcast_to(weights_sum, tf.shape(avg.data)) - return avg, weights_sum - return avg - - -@utils.np_doc(np.trace) -def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring - if dtype: - dtype = utils.result_type(dtype) - a = array_ops.asarray(a, dtype).data - - if offset == 0: - a_shape = a.shape - if a_shape.rank is not None: - rank = len(a_shape) - if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or - axis2 == rank - 1): - return utils.tensor_to_ndarray(tf.linalg.trace(a)) - - a = array_ops.diagonal(a, offset, axis1, axis2) - return array_ops.sum(a, -1, dtype) - - -@utils.np_doc(np.meshgrid) -def meshgrid(*xi, **kwargs): - """This currently requires copy=True and sparse=False.""" - sparse = kwargs.get('sparse', False) - if sparse: - raise ValueError('tf.numpy doesnt support returning sparse arrays yet') - - copy = kwargs.get('copy', True) - if not copy: - raise ValueError('tf.numpy only supports copy=True') - - indexing = kwargs.get('indexing', 'xy') - - xi = [array_ops.asarray(arg).data for arg in xi] - kwargs = {'indexing': indexing} - - outputs = tf.meshgrid(*xi, **kwargs) - outputs = [utils.tensor_to_ndarray(output) for output in outputs] - - return outputs diff --git a/trax/tf_numpy/numpy_impl/random.py b/trax/tf_numpy/numpy_impl/random.py deleted file mode 100644 index 8ed3021eb..000000000 --- a/trax/tf_numpy/numpy_impl/random.py +++ /dev/null @@ -1,53 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Random functions.""" -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import utils - - -DEFAULT_RANDN_DTYPE = np.float32 - - -def randn(*args): - """Returns samples from a normal distribution. - - Uses `tf.random_normal`. - - Args: - *args: The shape of the output array. - - Returns: - An ndarray with shape `args` and dtype `float64`. - """ - # TODO(wangpeng): Use new stateful RNG - if utils.isscalar(args): - args = (args,) - return utils.tensor_to_ndarray( - tf.random.normal(args, dtype=DEFAULT_RANDN_DTYPE)) - - -def seed(s): - """Sets the seed for the random number generator. - - Uses `tf.set_random_seed`. - - Args: - s: an integer. - """ - # TODO(wangpeng): make the signature the same as numpy - tf.random.set_seed(s) diff --git a/trax/tf_numpy/numpy_impl/tests/array_ops_test.py b/trax/tf_numpy/numpy_impl/tests/array_ops_test.py deleted file mode 100644 index b74992ba2..000000000 --- a/trax/tf_numpy/numpy_impl/tests/array_ops_test.py +++ /dev/null @@ -1,1130 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tf numpy array methods.""" -import itertools -import sys -import numpy as np -from six.moves import range -from six.moves import zip -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import array_ops -from trax.tf_numpy.numpy_impl import arrays - - -class ArrayCreationTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - python_shapes = [ - 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] - ] - self.shape_transforms = [ - lambda x: x, lambda x: np.array(x, dtype=int), - lambda x: array_ops.array(x, dtype=int), tf.TensorShape - ] - - self.all_shapes = [] - for fn in self.shape_transforms: - self.all_shapes.extend([fn(s) for s in python_shapes]) - - if sys.version_info.major == 3: - # There is a bug of np.empty (and alike) in Python 3 causing a crash when - # the `shape` argument is an arrays.ndarray scalar (or tf.Tensor scalar). 
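The guard described in the comment above only needs to recognize 0-d array-like shapes. A self-contained sketch of its effect, using a plain numpy 0-d array to stand in for the wrapper scalar (an assumption for illustration only):

    import numpy as np

    def drop_array_scalars(shapes):
        # Keep ints, tuples and lists; drop 0-d array-like shapes, which
        # could crash np.empty on the versions this suite targeted.
        return [s for s in shapes if not (hasattr(s, 'ndim') and s.ndim == 0)]

    print(drop_array_scalars([0, (2,), [1, 2, 3], np.array(3)]))
    # -> [0, (2,), [1, 2, 3]]; np.array(3) is 0-d and gets filtered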
- def not_ndarray_scalar(s): - return not (isinstance(s, arrays.ndarray) and s.ndim == 0) - self.all_shapes = list(filter(not_ndarray_scalar, self.all_shapes)) - - self.all_types = [ - int, float, np.int16, np.int32, np.int64, np.float16, np.float32, - np.float64 - ] - - source_array_data = [ - 1, - 5.5, - 7, - (), - (8, 10.), - ((), ()), - ((1, 4), (2, 8)), - [], - [7], - [8, 10.], - [[], []], - [[1, 4], [2, 8]], - ([], []), - ([1, 4], [2, 8]), - [(), ()], - [(1, 4), (2, 8)], - ] - - self.array_transforms = [ - lambda x: x, - tf.convert_to_tensor, - np.array, - array_ops.array, - ] - self.all_arrays = [] - for fn in self.array_transforms: - self.all_arrays.extend([fn(s) for s in source_array_data]) - - def testEmpty(self): - for s in self.all_shapes: - actual = array_ops.empty(s) - expected = np.empty(s) - msg = 'shape: {}'.format(s) - self.match_shape(actual, expected, msg) - self.match_dtype(actual, expected, msg) - - for s, t in itertools.product(self.all_shapes, self.all_types): - actual = array_ops.empty(s, t) - expected = np.empty(s, t) - msg = 'shape: {}, dtype: {}'.format(s, t) - self.match_shape(actual, expected, msg) - self.match_dtype(actual, expected, msg) - - def testEmptyLike(self): - for a in self.all_arrays: - actual = array_ops.empty_like(a) - expected = np.empty_like(a) - msg = 'array: {}'.format(a) - self.match_shape(actual, expected, msg) - self.match_dtype(actual, expected, msg) - - for a, t in itertools.product(self.all_arrays, self.all_types): - actual = array_ops.empty_like(a, t) - expected = np.empty_like(a, t) - msg = 'array: {} type: {}'.format(a, t) - self.match_shape(actual, expected, msg) - self.match_dtype(actual, expected, msg) - - def testZeros(self): - for s in self.all_shapes: - actual = array_ops.zeros(s) - expected = np.zeros(s) - msg = 'shape: {}'.format(s) - self.match(actual, expected, msg) - - for s, t in itertools.product(self.all_shapes, self.all_types): - actual = array_ops.zeros(s, t) - expected = np.zeros(s, t) - msg = 'shape: {}, dtype: {}'.format(s, t) - self.match(actual, expected, msg) - - def testZerosLike(self): - for a in self.all_arrays: - actual = array_ops.zeros_like(a) - expected = np.zeros_like(a) - msg = 'array: {}'.format(a) - self.match(actual, expected, msg) - - for a, t in itertools.product(self.all_arrays, self.all_types): - actual = array_ops.zeros_like(a, t) - expected = np.zeros_like(a, t) - msg = 'array: {} type: {}'.format(a, t) - self.match(actual, expected, msg) - - def testOnes(self): - for s in self.all_shapes: - actual = array_ops.ones(s) - expected = np.ones(s) - msg = 'shape: {}'.format(s) - self.match(actual, expected, msg) - - for s, t in itertools.product(self.all_shapes, self.all_types): - actual = array_ops.ones(s, t) - expected = np.ones(s, t) - msg = 'shape: {}, dtype: {}'.format(s, t) - self.match(actual, expected, msg) - - def testOnesLike(self): - for a in self.all_arrays: - actual = array_ops.ones_like(a) - expected = np.ones_like(a) - msg = 'array: {}'.format(a) - self.match(actual, expected, msg) - - for a, t in itertools.product(self.all_arrays, self.all_types): - actual = array_ops.ones_like(a, t) - expected = np.ones_like(a, t) - msg = 'array: {} type: {}'.format(a, t) - self.match(actual, expected, msg) - - def testEye(self): - n_max = 3 - m_max = 3 - - for n in range(1, n_max + 1): - self.match(array_ops.eye(n), np.eye(n)) - for k in range(-n, n + 1): - self.match(array_ops.eye(n, k=k), np.eye(n, k=k)) - for m in range(1, m_max + 1): - self.match(array_ops.eye(n, m), np.eye(n, m)) - for k in 
range(-n, m): - self.match(array_ops.eye(n, k=k), np.eye(n, k=k)) - self.match(array_ops.eye(n, m, k), np.eye(n, m, k)) - - for dtype in self.all_types: - for n in range(1, n_max + 1): - self.match(array_ops.eye(n, dtype=dtype), np.eye(n, dtype=dtype)) - for k in range(-n, n + 1): - self.match( - array_ops.eye(n, k=k, dtype=dtype), np.eye(n, k=k, dtype=dtype)) - for m in range(1, m_max + 1): - self.match( - array_ops.eye(n, m, dtype=dtype), np.eye(n, m, dtype=dtype)) - for k in range(-n, m): - self.match( - array_ops.eye(n, k=k, dtype=dtype), np.eye(n, k=k, dtype=dtype)) - self.match( - array_ops.eye(n, m, k, dtype=dtype), - np.eye(n, m, k, dtype=dtype)) - - def testIdentity(self): - n_max = 3 - - for n in range(1, n_max + 1): - self.match(array_ops.identity(n), np.identity(n)) - - for dtype in self.all_types: - for n in range(1, n_max + 1): - self.match( - array_ops.identity(n, dtype=dtype), np.identity(n, dtype=dtype)) - - def testFull(self): - # List of 2-tuples of fill value and shape. - data = [ - (5, ()), - (5, (7,)), - (5., (7,)), - ([5, 8], (2,)), - ([5, 8], (3, 2)), - ([[5], [8]], (2, 3)), - ([[5], [8]], (3, 2, 5)), - ([[5.], [8.]], (3, 2, 5)), - ([[3, 4], [5, 6], [7, 8]], (3, 3, 2)), - ] - for f, s in data: - for fn1, fn2 in itertools.product(self.array_transforms, - self.shape_transforms): - fill_value = fn1(f) - shape = fn2(s) - self.match( - array_ops.full(shape, fill_value), np.full(shape, fill_value)) - for dtype in self.all_types: - self.match( - array_ops.full(shape, fill_value, dtype=dtype), - np.full(shape, fill_value, dtype=dtype)) - - def testFullLike(self): - # List of 2-tuples of fill value and shape. - data = [ - (5, ()), - (5, (7,)), - (5., (7,)), - ([5, 8], (2,)), - ([5, 8], (3, 2)), - ([[5], [8]], (2, 3)), - ([[5], [8]], (3, 2, 5)), - ([[5.], [8.]], (3, 2, 5)), - ] - zeros_builders = [array_ops.zeros, np.zeros] - for f, s in data: - for fn1, fn2, arr_dtype in itertools.product( - self.array_transforms, zeros_builders, self.all_types): - fill_value = fn1(f) - arr = fn2(s, arr_dtype) - self.match( - array_ops.full_like(arr, fill_value), np.full_like(arr, fill_value)) - for dtype in self.all_types: - self.match( - array_ops.full_like(arr, fill_value, dtype=dtype), - np.full_like(arr, fill_value, dtype=dtype)) - - def testArray(self): - ndmins = [0, 1, 2, 5] - for a, dtype, ndmin, copy in itertools.product( - self.all_arrays, self.all_types, ndmins, [True, False]): - self.match( - array_ops.array(a, dtype=dtype, ndmin=ndmin, copy=copy), - np.array(a, dtype=dtype, ndmin=ndmin, copy=copy)) - - zeros_list = array_ops.zeros(5) - - # TODO(srbs): Test that copy=True when context.device is different from - # tensor device copies the tensor. - - # Backing tensor is the same if copy=False, other attributes being None. - self.assertIs(array_ops.array(zeros_list, copy=False).data, zeros_list.data) - self.assertIs( - array_ops.array(zeros_list.data, copy=False).data, zeros_list.data) - - # Backing tensor is different if ndmin is not satisfied. - self.assertIsNot( - array_ops.array(zeros_list, copy=False, ndmin=2).data, zeros_list.data) - self.assertIsNot( - array_ops.array(zeros_list.data, copy=False, ndmin=2).data, - zeros_list.data) - self.assertIs( - array_ops.array(zeros_list, copy=False, ndmin=1).data, zeros_list.data) - self.assertIs( - array_ops.array(zeros_list.data, copy=False, ndmin=1).data, - zeros_list.data) - - # Backing tensor is different if dtype is not satisfied. 
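A minimal sketch of the copy=False contract these assertions pin down, assuming the numpy_impl modules are importable at their pre-deletion paths:

    from trax.tf_numpy.numpy_impl import array_ops

    z = array_ops.zeros(5)                    # float backing tensor
    same = array_ops.array(z, copy=False)     # dtype and ndmin already satisfied
    assert same.data is z.data                # no new tensor is materialized
    as_int = array_ops.array(z, copy=False, dtype=int)
    assert as_int.data is not z.data          # dtype change forces a new tensor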
- self.assertIsNot( - array_ops.array(zeros_list, copy=False, dtype=int).data, - zeros_list.data) - self.assertIsNot( - array_ops.array(zeros_list.data, copy=False, dtype=int).data, - zeros_list.data) - self.assertIs( - array_ops.array(zeros_list, copy=False, dtype=float).data, - zeros_list.data) - self.assertIs( - array_ops.array(zeros_list.data, copy=False, dtype=float).data, - zeros_list.data) - - def testAsArray(self): - for a, dtype in itertools.product(self.all_arrays, self.all_types): - self.match(array_ops.asarray(a, dtype=dtype), np.asarray(a, dtype=dtype)) - - zeros_list = array_ops.zeros(5) - # Same instance is returned if no dtype is specified and input is ndarray. - self.assertIs(array_ops.asarray(zeros_list), zeros_list) - # Different instance is returned if dtype is specified and input is ndarray. - self.assertIsNot(array_ops.asarray(zeros_list, dtype=int), zeros_list) - - def testAsAnyArray(self): - for a, dtype in itertools.product(self.all_arrays, self.all_types): - self.match( - array_ops.asanyarray(a, dtype=dtype), np.asanyarray(a, dtype=dtype)) - zeros_list = array_ops.zeros(5) - # Same instance is returned if no dtype is specified and input is ndarray. - self.assertIs(array_ops.asanyarray(zeros_list), zeros_list) - # Different instance is returned if dtype is specified and input is ndarray. - self.assertIsNot(array_ops.asanyarray(zeros_list, dtype=int), zeros_list) - - def testAsContiguousArray(self): - for a, dtype in itertools.product(self.all_arrays, self.all_types): - self.match( - array_ops.ascontiguousarray(a, dtype=dtype), - np.ascontiguousarray(a, dtype=dtype)) - - def testARange(self): - int_values = np.arange(-3, 3).tolist() - float_values = np.arange(-3.5, 3.5).tolist() - all_values = int_values + float_values - for dtype in self.all_types: - for start in all_values: - msg = 'dtype:{} start:{}'.format(dtype, start) - self.match(array_ops.arange(start), np.arange(start), msg=msg) - self.match( - array_ops.arange(start, dtype=dtype), - np.arange(start, dtype=dtype), - msg=msg) - for stop in all_values: - msg = 'dtype:{} start:{} stop:{}'.format(dtype, start, stop) - self.match( - array_ops.arange(start, stop), np.arange(start, stop), msg=msg) - # TODO(srbs): Investigate and remove check. - # There are some bugs when start or stop is float and dtype is int. - if not isinstance(start, float) and not isinstance(stop, float): - self.match( - array_ops.arange(start, stop, dtype=dtype), - np.arange(start, stop, dtype=dtype), - msg=msg) - # Note: We intentionally do not test with float values for step - # because numpy.arange itself returns inconsistent results. e.g. 
- # np.arange(0.5, 3, step=0.5, dtype=int) returns - # array([0, 1, 2, 3, 4]) - for step in int_values: - msg = 'dtype:{} start:{} stop:{} step:{}'.format( - dtype, start, stop, step) - if not step: - with self.assertRaises(ValueError): - self.match( - array_ops.arange(start, stop, step), - np.arange(start, stop, step), - msg=msg) - if not isinstance(start, float) and not isinstance(stop, float): - self.match( - array_ops.arange(start, stop, step, dtype=dtype), - np.arange(start, stop, step, dtype=dtype), - msg=msg) - else: - self.match( - array_ops.arange(start, stop, step), - np.arange(start, stop, step), - msg=msg) - if not isinstance(start, float) and not isinstance(stop, float): - self.match( - array_ops.arange(start, stop, step, dtype=dtype), - np.arange(start, stop, step, dtype=dtype), - msg=msg) - - def testGeomSpace(self): - - def run_test(start, stop, **kwargs): - arg1 = start - arg2 = stop - self.match( - array_ops.geomspace(arg1, arg2, **kwargs), - np.geomspace(arg1, arg2, **kwargs), - msg='geomspace({}, {})'.format(arg1, arg2), - almost=True, - decimal=4) - - run_test(1, 1000, num=5) - run_test(1, 1000, num=5, endpoint=False) - run_test(-1, -1000, num=5) - run_test(-1, -1000, num=5, endpoint=False) - - def testDiag(self): - array_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - lambda x: np.array(x, dtype=np.float32), - lambda x: np.array(x, dtype=np.float64), - array_ops.array, - lambda x: array_ops.array(x, dtype=np.float32), - lambda x: array_ops.array(x, dtype=np.float64) - ] - - def run_test(arr): - for fn in array_transforms: - arr = fn(arr) - self.match( - array_ops.diag(arr), np.diag(arr), msg='diag({})'.format(arr)) - for k in range(-3, 3): - self.match( - array_ops.diag(arr, k), - np.diag(arr, k), - msg='diag({}, k={})'.format(arr, k)) - - # 2-d arrays. - run_test(np.arange(9).reshape((3, 3)).tolist()) - run_test(np.arange(6).reshape((2, 3)).tolist()) - run_test(np.arange(6).reshape((3, 2)).tolist()) - run_test(np.arange(3).reshape((1, 3)).tolist()) - run_test(np.arange(3).reshape((3, 1)).tolist()) - run_test([[5]]) - run_test([[]]) - run_test([[], []]) - - # 1-d arrays. - run_test([]) - run_test([1]) - run_test([1, 2]) - - def testDiagFlat(self): - array_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - lambda x: np.array(x, dtype=np.float32), - lambda x: np.array(x, dtype=np.float64), - array_ops.array, - lambda x: array_ops.array(x, dtype=np.float32), - lambda x: array_ops.array(x, dtype=np.float64) - ] - - def run_test(arr): - for fn in array_transforms: - arr = fn(arr) - self.match( - array_ops.diagflat(arr), - np.diagflat(arr), - msg='diagflat({})'.format(arr)) - for k in range(-3, 3): - self.match( - array_ops.diagflat(arr, k), - np.diagflat(arr, k), - msg='diagflat({}, k={})'.format(arr, k)) - - # 1-d arrays. - run_test([]) - run_test([1]) - run_test([1, 2]) - # 2-d arrays. - run_test([[]]) - run_test([[5]]) - run_test([[], []]) - run_test(np.arange(4).reshape((2, 2)).tolist()) - run_test(np.arange(2).reshape((2, 1)).tolist()) - run_test(np.arange(2).reshape((1, 2)).tolist()) - # 3-d arrays - run_test(np.arange(8).reshape((2, 2, 2)).tolist()) - - def match_shape(self, actual, expected, msg=None): - if msg: - msg = 'Shape match failed for: {}. 
Expected: {} Actual: {}'.format( - msg, expected.shape, actual.shape) - self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) - - def match_dtype(self, actual, expected, msg=None): - if msg: - msg = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format( - msg, expected.dtype, actual.dtype) - self.assertEqual(actual.dtype, expected.dtype, msg=msg) - - def match(self, actual, expected, msg=None, almost=False, decimal=7): - msg_ = 'Expected: {} Actual: {}'.format(expected, actual) - if msg: - msg = '{} {}'.format(msg_, msg) - else: - msg = msg_ - self.assertIsInstance(actual, arrays.ndarray) - self.match_dtype(actual, expected, msg) - self.match_shape(actual, expected, msg) - if not almost: - if not actual.shape: - self.assertEqual(actual.tolist(), expected.tolist()) - else: - self.assertSequenceEqual(actual.tolist(), expected.tolist()) - else: - np.testing.assert_almost_equal( - actual.tolist(), expected.tolist(), decimal=decimal) - - def testIndexedSlices(self): - dtype = tf.int64 - iss = tf.IndexedSlices(values=tf.ones([2, 3], dtype=dtype), - indices=tf.constant([1, 9]), - dense_shape=[10, 3]) - a = array_ops.array(iss, copy=False) - expected = tf.scatter_nd([[1], [9]], tf.ones([2, 3], dtype=dtype), [10, 3]) - self.assertAllEqual(expected, a) - - -class ArrayMethodsTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self.array_transforms = [ - lambda x: x, - tf.convert_to_tensor, - np.array, - array_ops.array, - ] - - def testAllAny(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arr = fn(arr) - self.match( - array_ops.all(arr, *args, **kwargs), np.all(arr, *args, **kwargs)) - self.match( - array_ops.any(arr, *args, **kwargs), np.any(arr, *args, **kwargs)) - - run_test(0) - run_test(1) - run_test([]) - run_test([[True, False], [True, True]]) - run_test([[True, False], [True, True]], axis=0) - run_test([[True, False], [True, True]], axis=0, keepdims=True) - run_test([[True, False], [True, True]], axis=1) - run_test([[True, False], [True, True]], axis=1, keepdims=True) - run_test([[True, False], [True, True]], axis=(0, 1)) - run_test([[True, False], [True, True]], axis=(0, 1), keepdims=True) - run_test([5.2, 3.5], axis=0) - run_test([1, 0], axis=0) - - def testCompress(self): - - def run_test(condition, arr, *args, **kwargs): - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arg1 = fn1(condition) - arg2 = fn2(arr) - self.match( - array_ops.compress(arg1, arg2, *args, **kwargs), - np.compress( - np.asarray(arg1).astype(bool), arg2, *args, **kwargs)) - - run_test([True], 5) - run_test([False], 5) - run_test([], 5) - run_test([True, False, True], [1, 2, 3]) - run_test([True, False], [1, 2, 3]) - run_test([False, True], [[1, 2], [3, 4]]) - run_test([1, 0, 1], [1, 2, 3]) - run_test([1, 0], [1, 2, 3]) - run_test([0, 1], [[1, 2], [3, 4]]) - run_test([True], [[1, 2], [3, 4]]) - run_test([False, True], [[1, 2], [3, 4]], axis=1) - run_test([False, True], [[1, 2], [3, 4]], axis=0) - run_test([False, True], [[1, 2], [3, 4]], axis=-1) - run_test([False, True], [[1, 2], [3, 4]], axis=-2) - - def testCopy(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.copy(arg, *args, **kwargs), np.copy(arg, *args, **kwargs)) - - run_test([]) - run_test([1, 2, 3]) - run_test([1., 2., 3.]) - run_test([True]) - 
run_test(np.arange(9).reshape((3, 3)).tolist()) - - def testCumProdAndSum(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.cumprod(arg, *args, **kwargs), - np.cumprod(arg, *args, **kwargs)) - self.match( - array_ops.cumsum(arg, *args, **kwargs), - np.cumsum(arg, *args, **kwargs)) - - run_test([]) - run_test([1, 2, 3]) - run_test([1, 2, 3], dtype=float) - run_test([1, 2, 3], dtype=np.float32) - run_test([1, 2, 3], dtype=np.float64) - run_test([1., 2., 3.]) - run_test([1., 2., 3.], dtype=int) - run_test([1., 2., 3.], dtype=np.int32) - run_test([1., 2., 3.], dtype=np.int64) - run_test([[1, 2], [3, 4]], axis=1) - run_test([[1, 2], [3, 4]], axis=0) - run_test([[1, 2], [3, 4]], axis=-1) - run_test([[1, 2], [3, 4]], axis=-2) - - def testImag(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.imag(arg, *args, **kwargs), - # np.imag may return a scalar so we convert to a np.ndarray. - np.array(np.imag(arg, *args, **kwargs))) - - run_test(1) - run_test(5.5) - run_test(5 + 3j) - run_test(3j) - run_test([]) - run_test([1, 2, 3]) - run_test([1 + 5j, 2 + 3j]) - run_test([[1 + 5j, 2 + 3j], [1 + 7j, 2 + 8j]]) - - def testAMaxAMin(self): - - def run_test(arr, *args, **kwargs): - axis = kwargs.pop('axis', None) - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arr_arg = fn1(arr) - axis_arg = fn2(axis) if axis is not None else None - self.match( - array_ops.amax(arr_arg, axis=axis_arg, *args, **kwargs), - np.amax(arr_arg, axis=axis, *args, **kwargs)) - self.match( - array_ops.amin(arr_arg, axis=axis_arg, *args, **kwargs), - np.amin(arr_arg, axis=axis, *args, **kwargs)) - - run_test([1, 2, 3]) - run_test([1., 2., 3.]) - run_test([[1, 2], [3, 4]], axis=1) - run_test([[1, 2], [3, 4]], axis=0) - run_test([[1, 2], [3, 4]], axis=-1) - run_test([[1, 2], [3, 4]], axis=-2) - run_test([[1, 2], [3, 4]], axis=(0, 1)) - run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2)) - run_test( - np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True) - run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0)) - run_test( - np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True) - - def testMean(self): - - def run_test(arr, *args, **kwargs): - axis = kwargs.pop('axis', None) - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arr_arg = fn1(arr) - axis_arg = fn2(axis) if axis is not None else None - self.match( - array_ops.mean(arr_arg, axis=axis_arg, *args, **kwargs), - np.mean(arr_arg, axis=axis, *args, **kwargs)) - - run_test([1, 2, 1]) - run_test([1., 2., 1.]) - run_test([1., 2., 1.], dtype=int) - run_test([[1, 2], [3, 4]], axis=1) - run_test([[1, 2], [3, 4]], axis=0) - run_test([[1, 2], [3, 4]], axis=-1) - run_test([[1, 2], [3, 4]], axis=-2) - run_test([[1, 2], [3, 4]], axis=(0, 1)) - run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2)) - run_test( - np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True) - run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0)) - run_test( - np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True) - - def testProd(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.prod(arg, *args, **kwargs), np.prod(arg, *args, **kwargs)) - - run_test([1, 2, 3]) - run_test([1., 2., 3.]) - run_test(np.array([1, 2, 3], dtype=np.int16)) - run_test([[1, 2], [3, 
4]], axis=1) - run_test([[1, 2], [3, 4]], axis=0) - run_test([[1, 2], [3, 4]], axis=-1) - run_test([[1, 2], [3, 4]], axis=-2) - run_test([[1, 2], [3, 4]], axis=(0, 1)) - run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2)) - run_test( - np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True) - run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0)) - run_test( - np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True) - - def _testReduce(self, math_fun, np_fun, name): - axis_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - array_ops.array, - lambda x: array_ops.array(x, dtype=np.float32), - lambda x: array_ops.array(x, dtype=np.float64), - ] - - def run_test(a, **kwargs): - axis = kwargs.pop('axis', None) - for fn1 in self.array_transforms: - for fn2 in axis_transforms: - arg1 = fn1(a) - axis_arg = fn2(axis) if axis is not None else None - self.match( - math_fun(arg1, axis=axis_arg, **kwargs), - np_fun(arg1, axis=axis, **kwargs), - msg='{}({}, axis={}, keepdims={})'.format( - name, arg1, axis, kwargs.get('keepdims'))) - - run_test(5) - run_test([2, 3]) - run_test([[2, -3], [-6, 7]]) - run_test([[2, -3], [-6, 7]], axis=0) - run_test([[2, -3], [-6, 7]], axis=0, keepdims=True) - run_test([[2, -3], [-6, 7]], axis=1) - run_test([[2, -3], [-6, 7]], axis=1, keepdims=True) - run_test([[2, -3], [-6, 7]], axis=(0, 1)) - run_test([[2, -3], [-6, 7]], axis=(1, 0)) - - def testSum(self): - self._testReduce(array_ops.sum, np.sum, 'sum') - - def testAmax(self): - self._testReduce(array_ops.amax, np.amax, 'amax') - - def testRavel(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.ravel(arg, *args, **kwargs), - np.ravel(arg, *args, **kwargs)) - - run_test(5) - run_test(5.) 
- run_test([]) - run_test([[]]) - run_test([[], []]) - run_test([1, 2, 3]) - run_test([1., 2., 3.]) - run_test([[1, 2], [3, 4]]) - run_test(np.arange(8).reshape((2, 2, 2)).tolist()) - - def testReal(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.real(arg, *args, **kwargs), - np.array(np.real(arg, *args, **kwargs))) - - run_test(1) - run_test(5.5) - run_test(5 + 3j) - run_test(3j) - run_test([]) - run_test([1, 2, 3]) - run_test([1 + 5j, 2 + 3j]) - run_test([[1 + 5j, 2 + 3j], [1 + 7j, 2 + 8j]]) - - def testRepeat(self): - - def run_test(arr, repeats, *args, **kwargs): - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arr_arg = fn1(arr) - repeats_arg = fn2(repeats) - self.match( - array_ops.repeat(arr_arg, repeats_arg, *args, **kwargs), - np.repeat(arr_arg, repeats_arg, *args, **kwargs)) - - run_test(1, 2) - run_test([1, 2], 2) - run_test([1, 2], [2]) - run_test([1, 2], [1, 2]) - run_test([[1, 2], [3, 4]], 3, axis=0) - run_test([[1, 2], [3, 4]], 3, axis=1) - run_test([[1, 2], [3, 4]], [3], axis=0) - run_test([[1, 2], [3, 4]], [3], axis=1) - run_test([[1, 2], [3, 4]], [3, 2], axis=0) - run_test([[1, 2], [3, 4]], [3, 2], axis=1) - run_test([[1, 2], [3, 4]], [3, 2], axis=-1) - run_test([[1, 2], [3, 4]], [3, 2], axis=-2) - - def testAround(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - array_ops.around(arg, *args, **kwargs), - np.around(arg, *args, **kwargs)) - - run_test(5.5) - run_test(5.567, decimals=2) - run_test([]) - run_test([1.27, 2.49, 2.75], decimals=1) - run_test([23.6, 45.1], decimals=-1) - - def testReshape(self): - - def run_test(arr, newshape, *args, **kwargs): - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arr_arg = fn1(arr) - newshape_arg = fn2(newshape) - # If reshape is called on a Tensor, it calls out to the Tensor.reshape - # method. - np_arr_arg = arr_arg - if isinstance(np_arr_arg, tf.Tensor): - np_arr_arg = np_arr_arg.numpy() - self.match( - array_ops.reshape(arr_arg, newshape_arg, *args, **kwargs), - np.reshape(np_arr_arg, newshape, *args, **kwargs)) - - run_test(5, [-1]) - run_test([], [-1]) - run_test([1, 2, 3], [1, 3]) - run_test([1, 2, 3], [3, 1]) - run_test([1, 2, 3, 4], [2, 2]) - run_test([1, 2, 3, 4], [2, 1, 2]) - - def testExpandDims(self): - - def run_test(arr, axis): - self.match(array_ops.expand_dims(arr, axis), np.expand_dims(arr, axis)) - - run_test([1, 2, 3], 0) - run_test([1, 2, 3], 1) - - def testSqueeze(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - # Note: np.squeeze ignores the axis arg for non-ndarray objects. - # This looks like a bug: https://github.com/numpy/numpy/issues/8201 - # So we convert the arg to np.ndarray before passing to np.squeeze. 
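A small illustration of the quirk the comment above references (numpy issue 8201): on the numpy versions this suite targeted, np.squeeze could silently ignore axis for inputs that were not already np.ndarray, so converting first makes the reference behavior explicit:

    import numpy as np

    nested = [[[1], [2], [3]]]                          # shape (1, 3, 1) once converted
    print(np.squeeze(np.array(nested), axis=0).shape)   # (3, 1): axis is honored
    print(np.squeeze(np.array(nested), axis=2).shape)   # (1, 3): only the last axis dropped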
- self.match( - array_ops.squeeze(arg, *args, **kwargs), - np.squeeze(np.array(arg), *args, **kwargs)) - - run_test(5) - run_test([]) - run_test([5]) - run_test([[1, 2, 3]]) - run_test([[[1], [2], [3]]]) - run_test([[[1], [2], [3]]], axis=0) - run_test([[[1], [2], [3]]], axis=2) - run_test([[[1], [2], [3]]], axis=(0, 2)) - run_test([[[1], [2], [3]]], axis=-1) - run_test([[[1], [2], [3]]], axis=-3) - - def testTranspose(self): - - def run_test(arr, axes=None): - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arr_arg = fn1(arr) - axes_arg = fn2(axes) if axes is not None else None - # If transpose is called on a Tensor, it calls out to the - # Tensor.transpose method. - np_arr_arg = arr_arg - if isinstance(np_arr_arg, tf.Tensor): - np_arr_arg = np_arr_arg.numpy() - self.match( - array_ops.transpose(arr_arg, axes_arg), - np.transpose(np_arr_arg, axes)) - - run_test(5) - run_test([]) - run_test([5]) - run_test([5, 6, 7]) - run_test(np.arange(30).reshape(2, 3, 5).tolist()) - run_test(np.arange(30).reshape(2, 3, 5).tolist(), [0, 1, 2]) - run_test(np.arange(30).reshape(2, 3, 5).tolist(), [0, 2, 1]) - run_test(np.arange(30).reshape(2, 3, 5).tolist(), [1, 0, 2]) - run_test(np.arange(30).reshape(2, 3, 5).tolist(), [1, 2, 0]) - run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 0, 1]) - run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 1, 0]) - - def testSetItem(self): - - def run_test(arr, index, value): - for fn in self.array_transforms: - value_arg = fn(value) - tf_array = array_ops.array(arr) - np_array = np.array(arr) - tf_array[index] = value_arg - # TODO(srbs): "setting an array element with a sequence" is thrown - # if we do not wrap value_arg in a numpy array. Investigate how this can - # be avoided. - np_array[index] = np.array(value_arg) - self.match(tf_array, np_array) - - run_test([1, 2, 3], 1, 5) - run_test([[1, 2], [3, 4]], 0, [6, 7]) - run_test([[1, 2], [3, 4]], 1, [6, 7]) - run_test([[1, 2], [3, 4]], (0, 1), 6) - run_test([[1, 2], [3, 4]], 0, 6) # Value needs to broadcast. - - def match_shape(self, actual, expected, msg=None): - if msg: - msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( - msg, expected.shape, actual.shape) - self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) - - def match_dtype(self, actual, expected, msg=None): - if msg: - msg = 'Dtype match failed for: {}. 
Expected: {} Actual: {}.'.format( - msg, expected.dtype, actual.dtype) - self.assertEqual(actual.dtype, expected.dtype, msg=msg) - - def match(self, actual, expected, msg=None, check_dtype=True): - msg_ = 'Expected: {} Actual: {}'.format(expected, actual) - if msg: - msg = '{} {}'.format(msg_, msg) - else: - msg = msg_ - self.assertIsInstance(actual, arrays.ndarray) - if check_dtype: - self.match_dtype(actual, expected, msg) - self.match_shape(actual, expected, msg) - if not actual.shape: - self.assertAllClose(actual.tolist(), expected.tolist()) - else: - self.assertAllClose(actual.tolist(), expected.tolist()) - - def testPad(self): - t = [[1, 2, 3], [4, 5, 6]] - paddings = [[1, 1,], [2, 2]] - self.assertAllEqual( - array_ops.pad(t, paddings, 'constant'), - [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0, 0], [0, 0, 4, 5, 6, 0, 0], - [0, 0, 0, 0, 0, 0, 0]]) - - self.assertAllEqual( - array_ops.pad(t, paddings, 'reflect'), - [[6, 5, 4, 5, 6, 5, 4], [3, 2, 1, 2, 3, 2, 1], [6, 5, 4, 5, 6, 5, 4], - [3, 2, 1, 2, 3, 2, 1]]) - - self.assertAllEqual( - array_ops.pad(t, paddings, 'symmetric'), - [[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], [5, 4, 4, 5, 6, 6, 5], - [5, 4, 4, 5, 6, 6, 5]]) - - def testTake(self): - a = [4, 3, 5, 7, 6, 8] - indices = [0, 1, 4] - self.assertAllEqual([4, 3, 6], array_ops.take(a, indices)) - indices = [[0, 1], [2, 3]] - self.assertAllEqual([[4, 3], [5, 7]], array_ops.take(a, indices)) - a = [[4, 3, 5], [7, 6, 8]] - self.assertAllEqual([[4, 3], [5, 7]], array_ops.take(a, indices)) - a = np.random.rand(2, 16, 3) - axis = 1 - self.assertAllEqual( - np.take(a, indices, axis=axis), array_ops.take(a, indices, axis=axis)) - - def testWhere(self): - self.assertAllEqual([[1.0, 1.0], [1.0, 1.0]], - array_ops.where([True], [1.0, 1.0], [[0, 0], [0, 0]])) - - def testShape(self): - self.assertAllEqual((1, 2), array_ops.shape([[0, 0]])) - - def testSwapaxes(self): - x = [[1, 2, 3]] - self.assertAllEqual([[1], [2], [3]], array_ops.swapaxes(x, 0, 1)) - self.assertAllEqual([[1], [2], [3]], array_ops.swapaxes(x, -2, -1)) - x = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] - self.assertAllEqual([[[0, 4], [2, 6]], [[1, 5], [3, 7]]], - array_ops.swapaxes(x, 0, 2)) - self.assertAllEqual([[[0, 4], [2, 6]], [[1, 5], [3, 7]]], - array_ops.swapaxes(x, -3, -1)) - - def testMoveaxis(self): - - def _test(*args): - expected = np.moveaxis(*args) - raw_ans = array_ops.moveaxis(*args) - - self.assertAllEqual(expected, raw_ans) - - a = np.random.rand(1, 2, 3, 4, 5, 6) - - # Basic - _test(a, (0, 2), (3, 5)) - _test(a, (0, 2), (-1, -3)) - _test(a, (-6, -4), (3, 5)) - _test(a, (-6, -4), (-1, -3)) - _test(a, 0, 4) - _test(a, -6, -2) - _test(a, tuple(range(6)), tuple(range(6))) - _test(a, tuple(range(6)), tuple(reversed(range(6)))) - _test(a, (), ()) - - def testNdim(self): - self.assertAllEqual(0, array_ops.ndim(0.5)) - self.assertAllEqual(1, array_ops.ndim([1, 2])) - - def testIsscalar(self): - self.assertTrue(array_ops.isscalar(0.5)) - self.assertTrue(array_ops.isscalar(5)) - self.assertTrue(array_ops.isscalar(False)) - self.assertFalse(array_ops.isscalar([1, 2])) - - def assertListEqual(self, a, b): - self.assertAllEqual(len(a), len(b)) - for x, y in zip(a, b): - self.assertAllEqual(x, y) - - def testSplit(self): - x = array_ops.arange(9) - y = array_ops.split(x, 3) - self.assertListEqual([([0, 1, 2]), - ([3, 4, 5]), - ([6, 7, 8])], y) - - x = array_ops.arange(8) - y = array_ops.split(x, [3, 5, 6, 10]) - self.assertListEqual([([0, 1, 2]), - ([3, 4]), - ([5]), - ([6, 7]), - ([])], y) - - -class 
ArrayManipulationTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self.array_transforms = [ - lambda x: x, - tf.convert_to_tensor, - np.array, - array_ops.array, - ] - - def testBroadcastTo(self): - - def run_test(arr, shape): - for fn in self.array_transforms: - arg1 = fn(arr) - self.match( - array_ops.broadcast_to(arg1, shape), np.broadcast_to(arg1, shape)) - - run_test(1, 2) - run_test(1, (2, 2)) - run_test([1, 2], (2, 2)) - run_test([[1], [2]], (2, 2)) - run_test([[1, 2]], (3, 2)) - run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2)) - - def testIx_(self): - possible_arys = [[True, True], [True, False], [False, False], - list(range(5)), array_ops.empty(0, dtype=np.int64)] - for r in range(len(possible_arys)): - for arys in itertools.combinations_with_replacement(possible_arys, r): - tnp_ans = array_ops.ix_(*arys) - onp_ans = np.ix_(*arys) - for t, o in zip(tnp_ans, onp_ans): - self.match(t, o) - - def match_shape(self, actual, expected, msg=None): - if msg: - msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( - msg, expected.shape, actual.shape) - self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) - - def match_dtype(self, actual, expected, msg=None): - if msg: - msg = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format( - msg, expected.dtype, actual.dtype) - self.assertEqual(actual.dtype, expected.dtype, msg=msg) - - def match(self, actual, expected, msg=None): - msg_ = 'Expected: {} Actual: {}'.format(expected, actual) - if msg: - msg = '{} {}'.format(msg_, msg) - else: - msg = msg_ - self.assertIsInstance(actual, arrays.ndarray) - self.match_dtype(actual, expected, msg) - self.match_shape(actual, expected, msg) - if not actual.shape: - self.assertEqual(actual.tolist(), expected.tolist()) - else: - self.assertSequenceEqual(actual.tolist(), expected.tolist()) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/arrays_test.py b/trax/tf_numpy/numpy_impl/tests/arrays_test.py deleted file mode 100644 index c744c9623..000000000 --- a/trax/tf_numpy/numpy_impl/tests/arrays_test.py +++ /dev/null @@ -1,181 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
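For reference, the open-mesh semantics that testIx_ above checks against, in plain numpy:

    import numpy as np

    rows, cols = np.ix_([0, 2], [1, 3])
    print(rows.shape, cols.shape)   # (2, 1) (1, 2): shapes broadcast against each other
    a = np.arange(16).reshape(4, 4)
    print(a[rows, cols])            # [[ 1  3]
                                    #  [ 9 11]]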
- -"""Tests for ndarray.""" -from collections import abc - -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import arrays -# Required for operator overloads -from trax.tf_numpy.numpy_impl import math_ops # pylint: disable=unused-import - - -t2a = arrays.tensor_to_ndarray - - -class ArrayTest(tf.test.TestCase): - - def testDtype(self): - a = t2a(tf.zeros(shape=[1, 2], dtype=tf.int64)) - self.assertIs(a.dtype.type, np.int64) - self.assertAllEqual(0, a.dtype.type(0)) - - def testAstype(self): - a = t2a(tf.convert_to_tensor(value=1.1, dtype=tf.float32)).astype(np.int32) - self.assertIs(a.dtype.type, np.int32) - self.assertAllEqual(1, a) - a = t2a(tf.convert_to_tensor(value=[0.0, 1.1], - dtype=tf.float32)).astype(np.bool_) - self.assertIs(a.dtype.type, np.bool_) - self.assertAllEqual([False, True], a) - - def testNeg(self): - a = t2a(tf.convert_to_tensor(value=[1.0, 2.0])) - self.assertAllEqual([-1.0, -2.0], -a) - - def _testBinOp(self, a, b, out, f, types=None): - a = t2a(tf.convert_to_tensor(value=a, dtype=np.int32)) - b = t2a(tf.convert_to_tensor(value=b, dtype=np.int32)) - if not isinstance(out, arrays.ndarray): - out = t2a(tf.convert_to_tensor(value=out, dtype=np.int32)) - if types is None: - types = [[np.int32, np.int32, np.int32], - [np.int64, np.int32, np.int64], - [np.int32, np.int64, np.int64], - [np.float32, np.int32, np.float64], - [np.int32, np.float32, np.float64], - [np.float32, np.float32, np.float32], - [np.float64, np.float32, np.float64], - [np.float32, np.float64, np.float64]] - for a_type, b_type, out_type in types: - o = f(a.astype(a_type), b.astype(b_type)) - self.assertIs(o.dtype.type, out_type) - self.assertAllEqual(out.astype(out_type), o) - - def testAdd(self): - self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: a.__add__(b)) - - def testRadd(self): - self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: b.__radd__(a)) - - def testSub(self): - self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: a.__sub__(b)) - - def testRsub(self): - self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: b.__rsub__(a)) - - def testMul(self): - self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: a.__mul__(b)) - - def testRmul(self): - self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: b.__rmul__(a)) - - def testPow(self): - self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: a.__pow__(b)) - - def testRpow(self): - self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: b.__rpow__(a)) - - _truediv_types = [[np.int32, np.int32, np.float64], - [np.int64, np.int32, np.float64], - [np.int32, np.int64, np.float64], - [np.float32, np.int32, np.float64], - [np.int32, np.float32, np.float64], - [np.float32, np.float32, np.float32], - [np.float64, np.float32, np.float64], - [np.float32, np.float64, np.float64]] - - def testTruediv(self): - self._testBinOp([3, 5], [2, 4], - t2a(tf.convert_to_tensor(value=[1.5, 1.25])), - lambda a, b: a.__truediv__(b), types=self._truediv_types) - - def testRtruediv(self): - self._testBinOp([3, 5], [2, 4], - t2a(tf.convert_to_tensor(value=[1.5, 1.25])), - lambda a, b: b.__rtruediv__(a), types=self._truediv_types) - - def _testCmp(self, a, b, out, f): - a = t2a(tf.convert_to_tensor(value=a, dtype=np.int32)) - b = t2a(tf.convert_to_tensor(value=b, dtype=np.int32)) - types = [[np.int32, np.int32], - [np.int64, np.int32], - [np.int32, np.int64], - [np.float32, np.int32], - [np.int32, np.float32], - [np.float32, np.float32], - [np.float64, np.float32], - [np.float32, np.float64]] - for a_type, b_type in types: - o = 
f(a.astype(a_type), b.astype(b_type)) - self.assertAllEqual(out, o) - - def testLt(self): - self._testCmp([1, 2, 3], [3, 2, 1], [True, False, False], - lambda a, b: a.__lt__(b)) - - def testLe(self): - self._testCmp([1, 2, 3], [3, 2, 1], [True, True, False], - lambda a, b: a.__le__(b)) - - def testGt(self): - self._testCmp([1, 2, 3], [3, 2, 1], [False, False, True], - lambda a, b: a.__gt__(b)) - - def testGe(self): - self._testCmp([1, 2, 3], [3, 2, 1], [False, True, True], - lambda a, b: a.__ge__(b)) - - def testEq(self): - self._testCmp([1, 2, 3], [3, 2, 1], [False, True, False], - lambda a, b: a.__eq__(b)) - - def testNe(self): - self._testCmp([1, 2, 3], [3, 2, 1], [True, False, True], - lambda a, b: a.__ne__(b)) - - def testInt(self): - v = 10 - u = int(t2a(tf.convert_to_tensor(value=v))) - self.assertIsInstance(u, int) - self.assertAllEqual(v, u) - - def testFloat(self): - v = 21.32 - u = float(t2a(tf.convert_to_tensor(value=v))) - self.assertIsInstance(u, float) - self.assertAllClose(v, u) - - def testBool(self): - b = bool(t2a(tf.convert_to_tensor(value=10))) - self.assertIsInstance(b, bool) - self.assertTrue(b) - self.assertFalse(bool(t2a(tf.convert_to_tensor(value=0)))) - self.assertTrue(bool(t2a(tf.convert_to_tensor(value=0.1)))) - self.assertFalse(bool(t2a(tf.convert_to_tensor(value=0.0)))) - - def testHash(self): - a = t2a(tf.convert_to_tensor(value=10)) - self.assertNotIsInstance(a, abc.Hashable) - with self.assertRaisesWithPredicateMatch( - TypeError, r'unhashable type'): - hash(a) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/backprop_test.py b/trax/tf_numpy/numpy_impl/tests/backprop_test.py deleted file mode 100644 index 581931cd2..000000000 --- a/trax/tf_numpy/numpy_impl/tests/backprop_test.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for backpropgration on tf-numpy functions.""" -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import array_ops -# Required for operator overloads -from trax.tf_numpy.numpy_impl import math_ops # pylint: disable=unused-import - - -class BackpropTest(tf.test.TestCase): - - def test_setitem(self): - # Single integer index. - a = array_ops.array([1., 2., 3.]) - b = array_ops.array(5.) - c = array_ops.array(10.) - - tensors = [arr.data for arr in [a, b, c]] - with tf.GradientTape() as g: - g.watch(tensors) - a[1] = b + c - loss = array_ops.sum(a) - - gradients = g.gradient(loss.data, tensors) - self.assertSequenceEqual( - array_ops.array(gradients[0]).tolist(), [1., 0., 1.]) - self.assertEqual(array_ops.array(gradients[1]).tolist(), 1.) - self.assertEqual(array_ops.array(gradients[2]).tolist(), 1.) - - # Tuple index. - a = array_ops.array([[[1., 2.], [3., 4.]], [[5., 6.], - [7., 8.]]]) # 2x2x2 array. 
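A plain-TensorFlow sketch of why the single-index case above yields these gradients; tf.tensor_scatter_nd_update stands in for the a[1] = b + c overload here, which is an assumption about mechanics rather than the deleted implementation itself:

    import tensorflow as tf

    a = tf.Variable([1., 2., 3.])
    b = tf.Variable(5.)
    c = tf.Variable(10.)
    with tf.GradientTape() as g:
        updated = tf.tensor_scatter_nd_update(a, [[1]], [b + c])
        loss = tf.reduce_sum(updated)
    grads = g.gradient(loss, [a, b, c])
    # d(loss)/da is 1 everywhere except the overwritten slot: [1., 0., 1.]
    # d(loss)/db == d(loss)/dc == 1.0, since b + c feeds exactly one element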
- b = array_ops.array([10., 11.]) - - tensors = [arr.data for arr in [a, b]] - with tf.GradientTape() as g: - g.watch(tensors) - a[(1, 0)] = b - loss = array_ops.sum(a) - - gradients = g.gradient(loss.data, tensors) - self.assertSequenceEqual( - array_ops.array(gradients[0]).tolist(), - [[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]]) - self.assertEqual(array_ops.array(gradients[1]).tolist(), [1., 1.]) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/logic_test.py b/trax/tf_numpy/numpy_impl/tests/logic_test.py deleted file mode 100644 index f1a584fe0..000000000 --- a/trax/tf_numpy/numpy_impl/tests/logic_test.py +++ /dev/null @@ -1,104 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tf numpy random number methods.""" -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import array_ops -from trax.tf_numpy.numpy_impl import arrays -from trax.tf_numpy.numpy_impl import math_ops - - -class LogicTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self.array_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - lambda x: np.array(x, dtype=np.int32), - lambda x: np.array(x, dtype=np.int64), - lambda x: np.array(x, dtype=np.float32), - lambda x: np.array(x, dtype=np.float64), - array_ops.array, - lambda x: array_ops.array(x, dtype=tf.int32), - lambda x: array_ops.array(x, dtype=tf.int64), - lambda x: array_ops.array(x, dtype=tf.float32), - lambda x: array_ops.array(x, dtype=tf.float64), - ] - - def testEqual(self): - - def run_test(x1, x2=None): - if x2 is None: - x2 = x1 - for fn1 in self.array_transforms: - for fn2 in self.array_transforms: - arg1 = fn1(x1) - arg2 = fn2(x2) - self.match( - math_ops.equal(arg1, arg2), - np.equal( - make_numpy_compatible(arg1), make_numpy_compatible(arg2))) - - run_test(1) - run_test(1, 2) - run_test([1, 2]) - run_test([1, 2, 3], [2]) - run_test([[1, 2], [3, 4]], [1, 2]) - run_test([[1, 2], [1, 4]], [1, 2]) - run_test([1, 2], [[1, 2], [1, 4]]) - run_test([[1, 2], [3, 4]], [[1, 2], [3, 4]]) - run_test([[1, 2], [3, 4]], [[1, 3], [3, 4]]) - - def match_shape(self, actual, expected, msg=None): - if msg: - msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( - msg, expected.shape, actual.shape) - self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) - - def match_dtype(self, actual, expected, msg=None): - if msg: - msg = 'Dtype match failed for: {}. 
Expected: {} Actual: {}.'.format( - msg, expected.dtype, actual.dtype) - self.assertEqual(actual.dtype, expected.dtype, msg=msg) - - def match(self, actual, expected, msg=None): - msg_ = 'Expected: {} Actual: {}'.format(expected, actual) - if msg: - msg = '{} {}'.format(msg_, msg) - else: - msg = msg_ - self.assertIsInstance(actual, arrays.ndarray) - self.match_dtype(actual, expected, msg) - self.match_shape(actual, expected, msg) - if not actual.shape: - self.assertEqual(actual.tolist(), expected.tolist()) - else: - self.assertSequenceEqual(actual.tolist(), expected.tolist()) - - -def make_numpy_compatible(s): - return s if not isinstance(s, arrays.ndarray) else s.data.numpy() - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/math_ops_test.py b/trax/tf_numpy/numpy_impl/tests/math_ops_test.py deleted file mode 100644 index 4ce267bf5..000000000 --- a/trax/tf_numpy/numpy_impl/tests/math_ops_test.py +++ /dev/null @@ -1,328 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tf numpy mathematical methods.""" -import itertools -from absl.testing import parameterized -import numpy as np -from six.moves import range -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import array_ops -from trax.tf_numpy.numpy_impl import arrays -from trax.tf_numpy.numpy_impl import math_ops - - -class MathTest(tf.test.TestCase, parameterized.TestCase): - - def setUp(self): - super().setUp() - self.array_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - lambda x: np.array(x, dtype=np.float32), - lambda x: np.array(x, dtype=np.float64), - array_ops.array, - lambda x: array_ops.array(x, dtype=np.float32), - lambda x: array_ops.array(x, dtype=np.float64), - ] - self.types = [np.int32, np.int64, np.float32, np.float64] - - def _testBinaryOp(self, math_fun, np_fun, name, operands=None, - extra_operands=None, - check_promotion=True, - check_promotion_result_type=True): - - def run_test(a, b): - for fn in self.array_transforms: - arg1 = fn(a) - arg2 = fn(b) - self.match( - math_fun(arg1, arg2), - np_fun(arg1, arg2), - msg='{}({}, {})'.format(name, arg1, arg2)) - # Tests type promotion - for type_a in self.types: - for type_b in self.types: - if not check_promotion and type_a != type_b: - continue - arg1 = array_ops.array(a, dtype=type_a) - arg2 = array_ops.array(b, dtype=type_b) - self.match( - math_fun(arg1, arg2), - np_fun(arg1, arg2), - msg='{}({}, {})'.format(name, arg1, arg2), - check_dtype=check_promotion_result_type) - - if operands is None: - operands = [(5, 2), - (5, [2, 3]), - (5, [[2, 3], [6, 7]]), - ([1, 2, 3], 7), - ([1, 2, 3], [5, 6, 7])] - for operand1, operand2 in operands: - run_test(operand1, operand2) - if extra_operands is not None: - for operand1, operand2 in extra_operands: - run_test(operand1, operand2) - - def testDot(self): - extra_operands = [ - ([1, 2], [[5, 6, 7], [8, 9, 10]]), - (np.arange(2 * 3 * 
5).reshape([2, 3, 5]).tolist(), - np.arange(5 * 7 * 11).reshape([7, 5, 11]).tolist())] - return self._testBinaryOp( - math_ops.dot, np.dot, 'dot', extra_operands=extra_operands) - - def testMinimum(self): - # The numpy version has strange result type when promotion happens, - # so set check_promotion_result_type to False. - return self._testBinaryOp( - math_ops.minimum, - np.minimum, - 'minimum', - check_promotion_result_type=False) - - def testMaximum(self): - # The numpy version has strange result type when promotion happens, - # so set check_promotion_result_type to False. - return self._testBinaryOp( - math_ops.maximum, - np.maximum, - 'maximum', - check_promotion_result_type=False) - - def testMatmul(self): - operands = [([[1, 2]], [[3, 4, 5], [6, 7, 8]])] - return self._testBinaryOp( - math_ops.matmul, np.matmul, 'matmul', operands=operands) - - def testMatmulError(self): - with self.assertRaisesRegex(ValueError, r''): - math_ops.matmul( - array_ops.ones([], np.int32), array_ops.ones([2, 3], np.int32)) - with self.assertRaisesRegex(ValueError, r''): - math_ops.matmul( - array_ops.ones([2, 3], np.int32), array_ops.ones([], np.int32)) - - def _testUnaryOp(self, math_fun, np_fun, name): - - def run_test(a): - for fn in self.array_transforms: - arg1 = fn(a) - self.match(math_fun(arg1), np_fun(arg1), - msg='{}({})'.format(name, arg1)) - - run_test(5) - run_test([2, 3]) - run_test([[2, -3], [-6, 7]]) - - def testLog(self): - self._testUnaryOp(math_ops.log, np.log, 'log') - - def testExp(self): - self._testUnaryOp(math_ops.exp, np.exp, 'exp') - - def testTanh(self): - self._testUnaryOp(math_ops.tanh, np.tanh, 'tanh') - - def testSqrt(self): - self._testUnaryOp(math_ops.sqrt, np.sqrt, 'sqrt') - - def match(self, actual, expected, msg='', check_dtype=True): - self.assertIsInstance(actual, arrays.ndarray) - if check_dtype: - self.assertEqual( - actual.dtype, expected.dtype, - 'Dtype mismatch.\nActual: {}\nExpected: {}\n{}'.format( - actual.dtype, expected.dtype, msg)) - self.assertEqual( - actual.shape, expected.shape, - 'Shape mismatch.\nActual: {}\nExpected: {}\n{}'.format( - actual.shape, expected.shape, msg)) - np.testing.assert_almost_equal(actual.tolist(), expected.tolist()) - - def testArgsort(self): - self._testUnaryOp(math_ops.argsort, np.argsort, 'argsort') - - # Test stability - r = np.arange(100) - a = np.zeros(100) - np.testing.assert_equal(math_ops.argsort(a, kind='stable'), r) - - def testArgMaxArgMin(self): - data = [ - 0, - 5, - [1], - [1, 2, 3], - [[1, 2, 3]], - [[4, 6], [7, 8]], - [[[4, 6], [9, 10]], [[7, 8], [12, 34]]], - ] - for fn, d in itertools.product(self.array_transforms, data): - arr = fn(d) - self.match(math_ops.argmax(arr), np.argmax(arr)) - self.match(math_ops.argmin(arr), np.argmin(arr)) - if hasattr(arr, 'shape'): - ndims = len(arr.shape) - else: - ndims = array_ops.array(arr, copy=False).ndim - if ndims == 0: - # Numpy flattens the scalar ndarray and treats it as a 1-d array of - # size 1. 
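The flattening convention the comment above describes, in plain numpy:

    import numpy as np

    print(np.argmax(np.array(5)))   # 0: the 0-d array acts like a 1-d array of size 1
    print(np.argmin(np.array(5)))   # 0, for the same reason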
- ndims = 1 - for axis in range(-ndims, ndims): - self.match(math_ops.argmax(arr, axis=axis), np.argmax(arr, axis=axis)) - self.match(math_ops.argmin(arr, axis=axis), np.argmin(arr, axis=axis)) - - @parameterized.parameters([False, True]) - def testIsCloseEqualNan(self, equal_nan): - a = np.asarray([1, 1, np.nan, 1, np.nan], np.float32) - b = np.asarray([1, 2, 1, np.nan, np.nan], np.float32) - self.match( - math_ops.isclose(a, b, equal_nan=equal_nan), - np.isclose(a, b, equal_nan=equal_nan)) - - def testAverageWrongShape(self): - with self.assertRaisesWithPredicateMatch( - tf.errors.InvalidArgumentError, r''): - math_ops.average(np.ones([2, 3]), weights=np.ones([2, 4])) - with self.assertRaisesWithPredicateMatch( - tf.errors.InvalidArgumentError, r''): - math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([2, 4])) - with self.assertRaisesWithPredicateMatch( - tf.errors.InvalidArgumentError, r''): - math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([])) - with self.assertRaisesWithPredicateMatch( - tf.errors.InvalidArgumentError, r''): - math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([5])) - - def testClip(self): - - def run_test(arr, *args, **kwargs): - check_dtype = kwargs.pop('check_dtype', True) - for fn in self.array_transforms: - arr = fn(arr) - self.match( - math_ops.clip(arr, *args, **kwargs), - np.clip(arr, *args, **kwargs), - check_dtype=check_dtype) - - # NumPy exhibits weird typing behavior when a/a_min/a_max are scalars v/s - # lists, e.g., - # - # np.clip(np.array(0, dtype=np.int32), -5, 5).dtype == np.int64 - # np.clip(np.array([0], dtype=np.int32), -5, 5).dtype == np.int32 - # np.clip(np.array([0], dtype=np.int32), [-5], [5]).dtype == np.int64 - # - # So we skip matching type. In tf-numpy the type of the output array is - # always the same as the input array. 
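The typing behavior called out above, reproduced directly (results as observed on the numpy versions this suite targeted; newer releases may differ):

    import numpy as np

    print(np.clip(np.array(0, dtype=np.int32), -5, 5).dtype)         # int64
    print(np.clip(np.array([0], dtype=np.int32), -5, 5).dtype)       # int32
    print(np.clip(np.array([0], dtype=np.int32), [-5], [5]).dtype)   # int64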
- run_test(0, -1, 5, check_dtype=False) - run_test(-1, -1, 5, check_dtype=False) - run_test(5, -1, 5, check_dtype=False) - run_test(-10, -1, 5, check_dtype=False) - run_test(10, -1, 5, check_dtype=False) - run_test(10, None, 5, check_dtype=False) - run_test(10, -1, None, check_dtype=False) - run_test([0, 20, -5, 4], -1, 5, check_dtype=False) - run_test([0, 20, -5, 4], None, 5, check_dtype=False) - run_test([0, 20, -5, 4], -1, None, check_dtype=False) - run_test([0.5, 20.2, -5.7, 4.4], -1.5, 5.1, check_dtype=False) - - run_test([0, 20, -5, 4], [-5, 0, -5, 0], [0, 5, 0, 5], check_dtype=False) - run_test([[1, 2, 3], [4, 5, 6]], [2, 0, 2], 5, check_dtype=False) - run_test([[1, 2, 3], [4, 5, 6]], 0, [5, 3, 1], check_dtype=False) - - def testPtp(self): - - def run_test(arr, *args, **kwargs): - for fn in self.array_transforms: - arg = fn(arr) - self.match( - math_ops.ptp(arg, *args, **kwargs), np.ptp(arg, *args, **kwargs)) - - run_test([1, 2, 3]) - run_test([1., 2., 3.]) - run_test([[1, 2], [3, 4]], axis=1) - run_test([[1, 2], [3, 4]], axis=0) - run_test([[1, 2], [3, 4]], axis=-1) - run_test([[1, 2], [3, 4]], axis=-2) - - def testLinSpace(self): - array_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - lambda x: np.array(x, dtype=np.float32), - lambda x: np.array(x, dtype=np.float64), - array_ops.array, - lambda x: array_ops.array(x, dtype=np.float32), - lambda x: array_ops.array(x, dtype=np.float64) - ] - - def run_test(start, stop, **kwargs): - for fn1 in array_transforms: - for fn2 in array_transforms: - arg1 = fn1(start) - arg2 = fn2(stop) - self.match( - math_ops.linspace(arg1, arg2, **kwargs), - np.linspace(arg1, arg2, **kwargs), - msg='linspace({}, {})'.format(arg1, arg2)) - - run_test(0, 1) - run_test(0, 1, num=10) - run_test(0, 1, endpoint=False) - run_test(0, -1) - run_test(0, -1, num=10) - run_test(0, -1, endpoint=False) - - def testLogSpace(self): - array_transforms = [ - lambda x: x, # Identity, - tf.convert_to_tensor, - np.array, - lambda x: np.array(x, dtype=np.float32), - lambda x: np.array(x, dtype=np.float64), - array_ops.array, - lambda x: array_ops.array(x, dtype=np.float32), - lambda x: array_ops.array(x, dtype=np.float64) - ] - - def run_test(start, stop, **kwargs): - for fn1 in array_transforms: - for fn2 in array_transforms: - arg1 = fn1(start) - arg2 = fn2(stop) - self.match( - math_ops.logspace(arg1, arg2, **kwargs), - np.logspace(arg1, arg2, **kwargs), - msg='logspace({}, {})'.format(arg1, arg2)) - - run_test(0, 5) - run_test(0, 5, num=10) - run_test(0, 5, endpoint=False) - run_test(0, 5, base=2.0) - run_test(0, -5) - run_test(0, -5, num=10) - run_test(0, -5, endpoint=False) - run_test(0, -5, base=2.0) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/random_test.py b/trax/tf_numpy/numpy_impl/tests/random_test.py deleted file mode 100644 index 883078cfa..000000000 --- a/trax/tf_numpy/numpy_impl/tests/random_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tf numpy random number methods.""" -import numpy as np -from six.moves import range -import tensorflow.compat.v2 as tf - -# Needed for ndarray.reshape. -from trax.tf_numpy.numpy_impl import array_ops # pylint: disable=unused-import -from trax.tf_numpy.numpy_impl import random - - -class RandomTest(tf.test.TestCase): - - def assertNotAllClose(self, a, b, **kwargs): - try: - self.assertAllClose(a, b, **kwargs) - except AssertionError: - return - raise AssertionError( - 'The two values are close at all %d elements' % np.size(a)) - - def testRandN(self): - - def run_test(*args): - num_samples = 1000 - tol = 0.1 # High tolerance to keep the # of samples low else the test - # takes a long time to run. - random.seed(10) - outputs = [random.randn(*args) for _ in range(num_samples)] - - # Test output shape. - for output in outputs: - self.assertEqual(output.shape, tuple(args)) - self.assertEqual(output.dtype.type, random.DEFAULT_RANDN_DTYPE) - - if np.prod(args): # Don't bother with empty arrays. - outputs = [output.tolist() for output in outputs] - - # Test that the properties of normal distribution are satisfied. - mean = np.mean(outputs, axis=0) - stddev = np.std(outputs, axis=0) - self.assertAllClose(mean, np.zeros(args), atol=tol) - self.assertAllClose(stddev, np.ones(args), atol=tol) - - # Test that outputs are different with different seeds. - random.seed(20) - diff_seed_outputs = [ - random.randn(*args).tolist() for _ in range(num_samples) - ] - self.assertNotAllClose(outputs, diff_seed_outputs) - - # Test that outputs are the same with the same seed. - random.seed(10) - same_seed_outputs = [ - random.randn(*args).tolist() for _ in range(num_samples) - ] - self.assertAllClose(outputs, same_seed_outputs) - - run_test() - run_test(0) - run_test(1) - run_test(5) - run_test(2, 3) - run_test(0, 2, 3) - run_test(2, 0, 3) - run_test(2, 3, 0) - run_test(2, 3, 5) - - -if __name__ == '__main__': - tf.compat.v1.enable_eager_execution() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/tests/utils_test.py b/trax/tf_numpy/numpy_impl/tests/utils_test.py deleted file mode 100644 index ca27a9f21..000000000 --- a/trax/tf_numpy/numpy_impl/tests/utils_test.py +++ /dev/null @@ -1,47 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for utils.py.""" -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import utils - - -class UtilsTest(tf.test.TestCase): - - # pylint: disable=unused-argument - def testNpDoc(self): - def np_fun(x): - """np_fun docstring.""" - return - @utils.np_doc(np_fun) - def f(): - """f docstring.""" - return - expected = """TensorFlow variant of `numpy.np_fun`. - -Unsupported arguments: `x`. - -f docstring. 
- -Documentation for `numpy.np_fun`: - -np_fun docstring.""" - self.assertEqual(f.__doc__, expected) - - -if __name__ == '__main__': - tf.enable_v2_behavior() - tf.test.main() diff --git a/trax/tf_numpy/numpy_impl/utils.py b/trax/tf_numpy/numpy_impl/utils.py deleted file mode 100644 index 3ad14740a..000000000 --- a/trax/tf_numpy/numpy_impl/utils.py +++ /dev/null @@ -1,397 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for internal use.""" - -# TODO(wangpeng): Use tf_inspect once we move into TF. -import funcsigs -import numpy as np -import tensorflow.compat.v2 as tf - -from trax.tf_numpy.numpy_impl import arrays -from trax.tf_numpy.numpy_impl import dtypes - - -tensor_to_ndarray = arrays.tensor_to_ndarray - - -def _canonicalize_axis(axis, rank): - return _canonicalize_axes([axis], rank)[0] - - -def _canonicalize_axes(axes, rank): - rank = _maybe_static(rank) - - if isinstance(rank, tf.Tensor): - canonicalizer = ( - lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis)) - else: - canonicalizer = lambda axis: axis+rank if axis < 0 else axis - - return [canonicalizer(axis) for axis in axes] - - -def _to_tf_type(dtype): - """Converts a native python or numpy type to TF DType. - - Args: - dtype: Could be a python type, a numpy type or a TF DType. - - Returns: - A tensorflow `DType`. - """ - return tf.as_dtype(dtype) - - -def _to_numpy_type(dtype): - """Converts a native python or TF DType to numpy type. - - Args: - dtype: Could be a python type, a numpy type or a TF DType. - - Returns: - A NumPy `dtype`. - """ - if isinstance(dtype, tf.DType): - return dtype.as_numpy_dtype - return np.dtype(dtype) - - -def finfo(dtype): - """Returns properties of floating point types. - - Note that currently it just forwards to the numpy namesake, while tensorflow - and numpy dtypes may have different properties. - - Args: - dtype: Could be a python type, a numpy type or a TF DType. - - Returns: - A class describing properties of `dtype`, as described by - https://docs.scipy.org/doc/numpy/reference/generated/numpy.finfo.html - """ - return np.finfo(_to_numpy_type(dtype)) - - -def isscalar(val): - """Returns whether `val` is a scalar value or scalar Tensor.""" - if isinstance(val, (np.ndarray, arrays.ndarray, tf.Tensor)): - return len(val.shape) == 0 # pylint: disable=g-explicit-length-test - return np.isscalar(val) - - -# Can't use np_doc because np.result_type is a builtin function. -def result_type(*arrays_and_dtypes): - """Returns the type resulting from applying NumPy type promotion to arguments. - - Args: - *arrays_and_dtypes: A list of array_like objects or dtypes. - - Returns: - A numpy dtype. - """ - def maybe_get_dtype(x): - # Don't put np.ndarray in this list, because np.result_type looks at the - # value (not just dtype) of np.ndarray to decide the result type. 
-    if isinstance(x, (arrays.ndarray, arrays.ShardedNdArray,
-                      tf.Tensor, tf.IndexedSlices)):
-      return _to_numpy_type(x.dtype)
-    elif isinstance(x, tf.DType):
-      return _to_numpy_type(x)
-    return x
-  arrays_and_dtypes = [maybe_get_dtype(x) for x in
-                       tf.nest.flatten(arrays_and_dtypes)]
-  if not arrays_and_dtypes:
-    # If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
-    arrays_and_dtypes = [np.asarray([])]
-  return dtypes._result_type(*arrays_and_dtypes)
-
-
-def promote_types(type1, type2):
-  """Returns the type resulting from applying NumPy type promotion.
-
-  Args:
-    type1: A numpy type.
-    type2: A numpy type.
-
-  Returns:
-    A numpy type.
-  """
-  type1 = _to_numpy_type(type1)
-  type2 = _to_numpy_type(type2)
-  return dtypes.canonicalize_dtype(np.promote_types(type1, type2))
-
-
-def _has_docstring(f):
-  return hasattr(f, '__doc__') and isinstance(f.__doc__, str) and f.__doc__
-
-
-def _add_blank_line(s):
-  if s.endswith('\n'):
-    return s + '\n'
-  else:
-    return s + '\n\n'
-
-
-def _np_signature(f):
-  """An enhanced funcsigs.signature that can handle numpy.ufunc."""
-  if not isinstance(f, np.ufunc):
-    try:
-      return funcsigs.signature(f)
-    except ValueError:
-      return None
-  def names_from_num(prefix, n):
-    if n <= 0:
-      return []
-    elif n == 1:
-      return [prefix]
-    else:
-      return [prefix + str(i + 1) for i in range(n)]
-  input_names = names_from_num('x', f.nin)
-  output_names = names_from_num('out', f.nout)
-  keyword_only_params = [
-      ('where', True),
-      ('casting', 'same_kind'),
-      ('order', 'K'),
-      ('dtype', None),
-      ('subok', True),
-      ('signature', None),
-      ('extobj', None)]
-  params = []
-  params += [funcsigs.Parameter(name, funcsigs.Parameter.POSITIONAL_ONLY)
-             for name in input_names]
-  if f.nout > 1:
-    params += [funcsigs.Parameter(name, funcsigs.Parameter.POSITIONAL_ONLY,
-                                  default=None)
-               for name in output_names]
-  params += [funcsigs.Parameter(
-      'out', funcsigs.Parameter.POSITIONAL_OR_KEYWORD,
-      default=None if f.nout == 1 else (None,) * f.nout)]
-  params += [funcsigs.Parameter(name, funcsigs.Parameter.KEYWORD_ONLY,
-                                default=default)
-             for name, default in keyword_only_params]
-  return funcsigs.Signature(params)
-
-
-# Python 2 doesn't allow keyword-only arguments. Python prior to 3.8 doesn't
-# allow positional-only arguments. So we conflate positional-only, keyword-only
-# and positional-or-keyword arguments here.
-def _is_compatible_param_kind(a, b):
-  def relax(k):
-    if k in (funcsigs.Parameter.POSITIONAL_ONLY,
-             funcsigs.Parameter.KEYWORD_ONLY):
-      return funcsigs.Parameter.POSITIONAL_OR_KEYWORD
-    return k
-  return relax(a) == relax(b)
-
-
-def np_doc(np_fun):
-  """Attaches the numpy docstring to a function.
-
-  Args:
-    np_fun: the numpy function whose docstring will be used.
-
-  Returns:
-    A function decorator that attaches the docstring from `np_fun` to the
-    decorated function.
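-
-  A hedged usage sketch (names here are illustrative):
-
-    @np_doc(np.clip)
-    def clip(a, a_min=None, a_max=None):
-      ...  # the tf-numpy implementation; its __doc__ is filled in by np_doc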
- """ - np_sig = _np_signature(np_fun) - def decorator(f): - """The decorator.""" - unsupported_params = [] - if np_sig is not None: - sig = funcsigs.signature(f) - for name in np_sig.parameters: - if name not in sig.parameters: - unsupported_params.append(name) - f.__doc__ = _np_doc_helper(f, np_fun, unsupported_params) - return f - return decorator - - -def _np_doc_helper(f, np_f, unsupported_params=None): - """Helper to get docs.""" - if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f): - return np_f.__doc__ - doc = 'TensorFlow variant of `numpy.%s`.\n\n' % np_f.__name__ - if unsupported_params: - doc += 'Unsupported arguments: ' + ', '.join( - '`' + name + '`' for name in unsupported_params) + '.\n\n' - if _has_docstring(f): - doc += f.__doc__ - doc = _add_blank_line(doc) - if _has_docstring(np_f): - doc += 'Documentation for `numpy.%s`:\n\n' % np_f.__name__ - doc += np_f.__doc__ - return doc - - -def np_doc_only(np_f): - """Attachs numpy docstring to a function. - - This differs from np_doc in that it doesn't check for a match in signature. - - Args: - np_f: the numpy function whose docstring will be used. - - Returns: - A function decorator that attaches the docstring from `np_f` to the - decorated function. - """ - - def decorator(f): - f.__doc__ = _np_doc_helper(f, np_f) - return f - - return decorator - - -def tf_broadcast(*args): - """Broadcast tensors. - - Args: - *args: a list of tensors whose shapes are broadcastable against each other. - - Returns: - Tensors broadcasted to the common shape. - """ - if len(args) <= 1: - return args - sh = tf.shape(args[0]) - for arg in args[1:]: - sh = tf.broadcast_dynamic_shape(sh, tf.shape(arg)) - return [tf.broadcast_to(arg, sh) for arg in args] - - -# TODO(wangpeng): Move the following functions to a separate file and check for -# float dtypes in each of them. - - -def get_static_value(x): - """A version of tf.get_static_value that returns None on float dtypes. - - It returns None on float dtypes in order to avoid breaking gradients. - - Args: - x: a tensor. - - Returns: - Same as `tf.get_static_value`, except that it returns None when `x` has a - float dtype. - """ - if isinstance(x, tf.Tensor) and (x.dtype.is_floating or x.dtype.is_complex): - return None - return tf.get_static_value(x) - - -def _maybe_static(x): - value = get_static_value(x) - if value is None: - return x - else: - return value - - -# All the following functions exist because get_static_value can't handle -# their TF counterparts. 
- - -def cond(pred, true_fn, false_fn): - """A version of tf.cond that tries to evaluate the condition.""" - v = get_static_value(pred) - if v is None: - return tf.cond(pred, true_fn, false_fn) - if v: - return true_fn() - else: - return false_fn() - - -def add(a, b): - """A version of tf.add that eagerly evaluates if possible.""" - return _maybe_static(a) + _maybe_static(b) - - -def subtract(a, b): - """A version of tf.subtract that eagerly evaluates if possible.""" - return _maybe_static(a) - _maybe_static(b) - - -def greater(a, b): - """A version of tf.greater that eagerly evaluates if possible.""" - return _maybe_static(a) > _maybe_static(b) - - -def greater_equal(a, b): - """A version of tf.greater_equal that eagerly evaluates if possible.""" - return _maybe_static(a) >= _maybe_static(b) - - -def less_equal(a, b): - """A version of tf.less_equal that eagerly evaluates if possible.""" - return _maybe_static(a) <= _maybe_static(b) - - -def logical_and(a, b): - """A version of tf.logical_and that eagerly evaluates if possible.""" - a_value = get_static_value(a) - if a_value is not None: - if np.isscalar(a_value): - if a_value: - return _maybe_static(b) - else: - return a_value - else: - return a_value & _maybe_static(b) - else: - return a & _maybe_static(b) - - -def logical_or(a, b): - """A version of tf.logical_or that eagerly evaluates if possible.""" - a_value = get_static_value(a) - if a_value is not None: - if np.isscalar(a_value): - if a_value: - return a_value - else: - return _maybe_static(b) - else: - return a_value | _maybe_static(b) - else: - return a | _maybe_static(b) - - -def getitem(a, slice_spec): - """A version of __getitem__ that eagerly evaluates if possible.""" - return _maybe_static(a)[slice_spec] - - -def reduce_all(input_tensor, axis=None, keepdims=False): - """A version of tf.reduce_all that eagerly evaluates if possible.""" - v = get_static_value(input_tensor) - if v is None: - return tf.reduce_all(input_tensor, axis=axis, keepdims=keepdims) - else: - return v.all(axis=axis, keepdims=keepdims) - - -def reduce_any(input_tensor, axis=None, keepdims=False): - """A version of tf.reduce_any that eagerly evaluates if possible.""" - v = get_static_value(input_tensor) - if v is None: - return tf.reduce_any(input_tensor, axis=axis, keepdims=keepdims) - else: - return v.any(axis=axis, keepdims=keepdims) diff --git a/trax/tf_numpy_and_keras.ipynb b/trax/tf_numpy_and_keras.ipynb deleted file mode 100644 index 70c9e38fa..000000000 --- a/trax/tf_numpy_and_keras.ipynb +++ /dev/null @@ -1,578 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "7yuytuIllsv1" - }, - "source": [ - "# Using Trax with TensorFlow NumPy and Keras\n", - "\n", - "This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/tf_numpy_and_keras.ipynb)) shows how you can run [Trax](https://trax-ml.readthedocs.io/en/latest/) directly with [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy). You will also see how to use Trax layers and models inside [Keras](https://keras.io/) so you can use Trax in production, e.g., with [TensorFlow.js](https://www.tensorflow.org/js/) or [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving).\n", - "\n", - " 1. **Trax with TensorFlow NumPy**: use Trax with [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) without any code changes\n", - " 1. 
**Convert Trax to Keras**: how to get a [Keras](https://keras.io/) layer for your Trax model and use it\n", - " 1. **Exporting Trax Models for Deployment**: how to export Trax models to [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model)\n", - " \n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-LQ89rFFsEdk" - }, - "source": [ - "## 1. Trax with TensorFlow NumPy\n", - "\n", - "In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n", - "\n", - "The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n", - "\n", - "You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n", - "\n", - "Let's train the sentiment analysis model from the [Trax intro](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb) using TensorFlow NumPy to see how it works." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BIl27504La0G" - }, - "source": [ - "**General Setup**\n", - "\n", - "Execute the following few cells (once) before running any of the code samples." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "cellView": "form", - "executionInfo": { - "elapsed": 38104, - "status": "ok", - "timestamp": 1607390269924, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "oILRLCWN_16u" - }, - "outputs": [], - "source": [ - "#@title\n", - "# Copyright 2020 Google LLC.\n", - "\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "cellView": "both", - "executionInfo": { - "elapsed": 309, - "status": "ok", - "timestamp": 1607390270242, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "vlGjGoGMTt-D", - "outputId": "279a980e-1e71-4080-9587-d89aeb17ebc6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "# Install and import Trax\n", - "!pip install -q -U git+https://github.com/google/trax@master\n", - "\n", - "import os\n", - "import numpy as np\n", - "import trax" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O_3JcfZaT5oP" - }, - "source": [ - "Here is how you can set the fastmath backend to `tensorflow-numpy` and verify that it's been set." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "executionInfo": { - "elapsed": 286, - "status": "ok", - "timestamp": 1607390270535, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "djTiSLcaNFGa", - "outputId": "bac38e28-d1e5-41bd-9054-d85913fc2900" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensorflow-numpy\n" - ] - } - ], - "source": [ - "# Use the tensorflow-numpy backend.\n", - "trax.fastmath.set_backend('tensorflow-numpy')\n", - "print(trax.fastmath.backend_name())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "executionInfo": { - "elapsed": 15126, - "status": "ok", - "timestamp": 1607390285667, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "AV5wrgjZ10yU", - "outputId": "6385fbe2-5a8e-415c-8851-b5bef099e02f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "batch shapes = [(8, 2048), (8,), (8,)]\n" - ] - } - ], - "source": [ - "# Create data streams.\n", - "train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()\n", - "eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()\n", - "\n", - "data_pipeline = trax.data.Serial(\n", - " trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),\n", - " trax.data.Shuffle(),\n", - " trax.data.FilterByLength(max_length=2048, length_keys=[0]),\n", - " trax.data.BucketByLength(boundaries=[ 32, 128, 512, 2048],\n", - " batch_sizes=[512, 128, 32, 8, 1],\n", - " length_keys=[0]),\n", - " trax.data.AddLossWeights()\n", - " )\n", - "train_batches_stream = data_pipeline(train_stream)\n", - "eval_batches_stream = data_pipeline(eval_stream)\n", - "\n", - "# Print example shapes.\n", - "example_batch = next(train_batches_stream)\n", - "print(f'batch shapes = {[x.shape for x in example_batch]}')" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "executionInfo": { - "elapsed": 409, - "status": "ok", - "timestamp": 1607390286085, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "WoSz5plIyXOU", - "outputId": "aa1db911-96fb-430b-8360-1a6e3f764cee" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Serial[\n", - " Embedding_8192_256\n", - " Mean\n", - " Dense_2\n", - "]\n" - ] - } - ], - "source": [ - "# Create the model.\n", - "from trax import layers as tl\n", - "\n", - "model = tl.Serial(\n", - " tl.Embedding(vocab_size=8192, d_feature=256),\n", - " tl.Mean(axis=1), # Average on axis 1 (length of sentence).\n", - " tl.Dense(2), # Classify 2 classes.\n", - ")\n", - "\n", - "# You can print model structure.\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "executionInfo": { - "elapsed": 79139, - "status": "ok", - "timestamp": 1607390365232, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "d6bIKUO-3Cw8", - "outputId": "ba4199f4-cc31-459e-b46c-d14ec2f4ef68" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Step 1: Total number of trainable weights: 2097666\n", - "Step 1: Ran 1 train steps in 1.01 secs\n", - "Step 1: train WeightedCategoryCrossEntropy | 0.69292086\n", - "Step 1: eval WeightedCategoryCrossEntropy | 0.68457415\n", - "Step 1: eval WeightedCategoryAccuracy | 0.56406250\n", - "\n", - "Step 500: 
Ran 499 train steps in 19.92 secs\n",
- "Step 500: train WeightedCategoryCrossEntropy | 0.50587755\n",
- "Step 500: eval WeightedCategoryCrossEntropy | 0.46716719\n",
- "Step 500: eval WeightedCategoryAccuracy | 0.80625000\n",
- "\n",
- "Step 1000: Ran 500 train steps in 17.50 secs\n",
- "Step 1000: train WeightedCategoryCrossEntropy | 0.36375266\n",
- "Step 1000: eval WeightedCategoryCrossEntropy | 0.44373559\n",
- "Step 1000: eval WeightedCategoryAccuracy | 0.80000000\n",
- "\n",
- "Step 1500: Ran 500 train steps in 18.40 secs\n",
- "Step 1500: train WeightedCategoryCrossEntropy | 0.34449804\n",
- "Step 1500: eval WeightedCategoryCrossEntropy | 0.34941847\n",
- "Step 1500: eval WeightedCategoryAccuracy | 0.84687500\n",
- "\n",
- "Step 2000: Ran 500 train steps in 17.18 secs\n",
- "Step 2000: train WeightedCategoryCrossEntropy | 0.28685242\n",
- "Step 2000: eval WeightedCategoryCrossEntropy | 0.50030373\n",
- "Step 2000: eval WeightedCategoryAccuracy | 0.77539062\n"
- ]
- }
- ],
- "source": [
- "# Train the model.\n",
- "from trax.supervised import training\n",
- "\n",
- "# Training task.\n",
- "train_task = training.TrainTask(\n",
- " labeled_data=train_batches_stream,\n",
- " loss_layer=tl.WeightedCategoryCrossEntropy(),\n",
- " optimizer=trax.optimizers.Adam(0.01),\n",
- " n_steps_per_checkpoint=500,\n",
- ")\n",
- "\n",
- "# Evaluation task.\n",
- "eval_task = training.EvalTask(\n",
- " labeled_data=eval_batches_stream,\n",
- " metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],\n",
- " n_eval_batches=20 # For less variance in eval numbers.\n",
- ")\n",
- "\n",
- "# Training loop saves checkpoints to output_dir.\n",
- "output_dir = os.path.expanduser('~/output_dir/')\n",
- "training_loop = training.Loop(model,\n",
- " train_task,\n",
- " eval_tasks=[eval_task],\n",
- " output_dir=output_dir)\n",
- "\n",
- "# Run 2000 steps (batches).\n",
- "training_loop.run(2000)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "executionInfo": {
- "elapsed": 832,
- "status": "ok",
- "timestamp": 1607390366089,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 480
- },
- "id": "yuPu37Lp7GST",
- "outputId": "b95f944d-b5e8-44c6-829c-25c0b0b08f38"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "example input_str: The movie features another exceptional collaboration between director William Wyler and cinematographer Gregg Toland, the first after Toland worked on Citizen Kane. But the talent of both these men was focused on achieving a perfectly crafted movie, understood in the good old American sense as a great story. The technical aspects of the movie are covered so as the viewer gets absorbed into the action that takes place on the screen without submitting to the power of the image. Technique is seen as a vehicle of representation unlike in Citizen Kane where Welles' baroque style almost drew the attention from the story to the way the story was told. One of my favorite moves with deep focus in this film is the drama conveyed by the returning home welcoming of Homer and Al. If Homer's girl, Wilma comes towards him perfectly in focus, Al goes over to his wife also perfectly in focus. This is a brilliant move because it shows only through the use of the image the nature of these relationships as we will see them throughout the movie: Wilma loves Homer and she accepts him as he is, Al's wife loves him also but she feels unprepared to fully welcome him home.
Also later in the film we find out that their marriage has not always been a bed of roses.\u003cbr /\u003e\u003cbr /\u003eWyler is a director whose force lies in being true to his work without feeling the need to boast. He wanted to show his audience how hard it was for the American soldiers returning from the war to fit into a society that either didn't understand them or treated them with contempt. With a perfect cast and great dialogue Goldwin and Wyler produced a movie that will forever be the template for any other returning home movie. The three hours which coincide with the \"rough cut\" because the test audience back then never felt for a moment that the action was slow and indeed every scene from the film seems perfectly justified. The whole thing is constructed beautifully, every character gets a fair amount of exposure, nothing is left to chance and it is quite pitiful that Hollywood nowadays never manages to bring so much character conflict to the screen. TBYOOL explores the depth of the American way of life, of the American family and society to an extent that makes other movies look like \"the children's hour\".\u003cpad\u003e [... long run of \u003cpad\u003e padding tokens elided ...] \u003cpad\u003e\n",
- "Model returned sentiment activations: [[-1.6396211 1.6328843]]\n"
- ]
- }
- ],
- "source": [
- "# Run on an example.\n",
- "example_input = next(eval_batches_stream)[0][0]\n",
- "example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')\n",
- "print(f'example input_str: {example_input_str}')\n",
- "sentiment_activations = model(example_input[None, :]) # Add batch dimension.\n",
- "print(f'Model returned sentiment activations: {np.asarray(sentiment_activations)}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8wgfJyhdihfR"
- },
- "source": [
- "## 2. Convert Trax to Keras\n",
- "\n",
- "Thanks to [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) you can convert the model you just trained into a [Keras](https://keras.io/) layer using `trax.AsKeras`. This allows you to:\n",
- "\n",
- "* use Trax layers inside Keras models\n",
- "* run Trax models with existing Keras input pipelines\n",
- "* export Trax models to [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model)\n",
- "\n",
- "When creating a Keras layer from a Trax one, the Keras layer weights will get initialized to the ones the Trax layer had at the moment of creation. In this way, you can create Keras layers from pre-trained Trax models and save them as SavedModel as shown below."
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "executionInfo": { - "elapsed": 322, - "status": "ok", - "timestamp": 1607390366418, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "bxSLRyjftuxH", - "outputId": "6ec7180b-ff85-47e4-bba2-3634df913ad4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u003ctrax.trax2keras.AsKeras object at 0x7efff5a47a90\u003e\n", - "Keras returned sentiment activations: [[-1.6396211 1.6328843]]\n" - ] - } - ], - "source": [ - "# Convert the model into a Keras layer, use the weights from model.\n", - "keras_layer = trax.AsKeras(model)\n", - "print(keras_layer)\n", - "\n", - "# Run the Keras layer to verify it returns the same result.\n", - "sentiment_activations = keras_layer(example_input[None, :])\n", - "print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "executionInfo": { - "elapsed": 3983, - "status": "ok", - "timestamp": 1607390370412, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "r8C-FoFGxGE1", - "outputId": "0edfd1fa-2677-494a-f03f-2cc87324e88c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Keras returned sentiment activations: [[-1.6396211 1.6328843]]\n" - ] - } - ], - "source": [ - "import tensorflow as tf\n", - "\n", - "# Create a full Keras model using the layer from Trax.\n", - "inputs = tf.keras.Input(shape=(None,), dtype='int32')\n", - "hidden = keras_layer(inputs) \n", - "# You can add other Keras layers here operating on hidden.\n", - "outputs = hidden\n", - "keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", - "print(keras_model)\n", - "\n", - "# Run the Keras model to verify it returns the same result.\n", - "sentiment_activations = keras_model(example_input[None, :])\n", - "print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EQH1bvXwy5fE" - }, - "source": [ - "## 3. Exporting Trax Models for Deployment\n", - "\n", - "You can export the Keras model to disk as [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model). It's as simple as calling `keras_model.save` and allows you to use models with TF tools [TensorFlow.js](https://www.tensorflow.org/js/), [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) and [TensorFlow Lite](https://www.tensorflow.org/lite)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "executionInfo": { - "elapsed": 1355, - "status": "ok", - "timestamp": 1607390371776, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 480 - }, - "id": "nQIJrOUgxRfK", - "outputId": "62c028a5-da9e-40b1-d223-aa5f45b6a2aa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Keras returned sentiment activations: [[-1.6396211 1.6328843]]\n" - ] - } - ], - "source": [ - "# Save the Keras model to output_dir.\n", - "model_file = os.path.join(output_dir, \"model_checkpoint\")\n", - "keras_model.save(model_file)\n", - "\n", - "# Load the model from SavedModel.\n", - "loaded_model = tf.keras.models.load_model(model_file)\n", - "\n", - "# Run the loaded model to verify it returns the same result.\n", - "sentiment_activations = loaded_model(example_input[None, :])\n", - "print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "last_runtime": { - "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", - "kind": "private" - }, - "name": "Using Trax with Keras", - "provenance": [ - { - "file_id": "1RNbQoOuzKsp_FoDqOFQX4mA--Wzt5ofq", - "timestamp": 1596181556972 - }, - { - "file_id": "https://github.com/google/trax/blob/master/trax/intro.ipynb", - "timestamp": 1596178511100 - }, - { - "file_id": "trax/intro.ipynb", - "timestamp": 1595931762204 - }, - { - "file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV", - "timestamp": 1578964243645 - }, - { - "file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07", - "timestamp": 1572044421118 - }, - { - "file_id": "intro.ipynb", - "timestamp": 1571858674399 - }, - { - "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", - "timestamp": 1569980697572 - }, - { - "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", - "timestamp": 1563927451951 - } - ] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/trax/trainer.py b/trax/trainer.py deleted file mode 100644 index add33cb57..000000000 --- a/trax/trainer.py +++ /dev/null @@ -1,197 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Trax trainer.""" -import atexit -import datetime -import functools -import os - -from absl import app -from absl import flags -from absl import logging - -import gin -import jax -from jax.lib import xla_extension as xc -import tensorflow.compat.v2 as tf -from trax import fastmath -from trax import trainer_flags # pylint: disable=unused-import -from trax.supervised import trainer_lib -from trax.tf_numpy import numpy as tf_np - -FLAGS = flags.FLAGS -Backend = fastmath.Backend - - -# TODO(afrozm): Share between trainer.py and rl_trainer.py -def _tf_setup_from_flags(): - """Processes TensorFlow-relevant flags.""" - if FLAGS.enable_eager_execution: - tf.compat.v1.enable_eager_execution() - if FLAGS.tf_xla: - tf.config.optimizer.set_jit(True) - fastmath.tf.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile) - tf.config.optimizer.set_experimental_options({ - 'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host, - 'layout_optimizer': FLAGS.tf_opt_layout, - }) - tf_np.set_allow_float64(FLAGS.tf_allow_float64) - - -# TODO(afrozm): Share between trainer.py and rl_trainer.py -def _gin_parse_configs(): - """Initializes gin-controlled bindings.""" - # Imports for configurables - # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable - from trax import models as _trax_models - from trax import optimizers as _trax_opt - # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable - - configs = FLAGS.config if FLAGS.config is not None else [] - # Override with --dataset and --model - if FLAGS.dataset: - configs.append("data_streams.dataset_name='%s'" % FLAGS.dataset) - if FLAGS.data_dir: - configs.append("data_streams.data_dir='%s'" % FLAGS.data_dir) - if FLAGS.model: - configs.append('train.model=@trax.models.%s' % FLAGS.model) - gin.parse_config_files_and_bindings(FLAGS.config_file, configs) - - -def _output_dir_or_default(): - """Returns a path to the output directory.""" - if FLAGS.output_dir: - output_dir = FLAGS.output_dir - trainer_lib.log('Using --output_dir {}'.format(output_dir)) - return os.path.expanduser(output_dir) - - # Else, generate a default output dir (under the user's home directory). - try: - dataset_name = gin.query_parameter('data_streams.dataset_name') - except ValueError: - dataset_name = 'random' - output_name = '{model_name}_{dataset_name}_{timestamp}'.format( - model_name=gin.query_parameter('train.model').configurable.name, - dataset_name=dataset_name, - timestamp=datetime.datetime.now().strftime('%Y%m%d_%H%M'), - ) - output_dir = os.path.join('~', 'trax', output_name) - output_dir = os.path.expanduser(output_dir) - print() - trainer_lib.log('No --output_dir specified') - trainer_lib.log('Using default output_dir: {}'.format(output_dir)) - return output_dir - - -# TODO(afrozm): Share between trainer.py and rl_trainer.py -def _jax_and_tf_configure_for_devices(): # pylint: disable=missing-function-docstring - if FLAGS.use_tpu: - jax.config.update('jax_platform_name', 'tpu') - jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend) - jax.config.update('jax_backend_target', FLAGS.jax_backend_target) - if (FLAGS.enable_eager_execution and (fastmath.is_backend(Backend.NUMPY) or - fastmath.is_backend(Backend.JAX))): - # Numpy backend doesn't benefit from having the input pipeline run on GPU, - # and jax backend has GPU memory contention if TF uses the GPU. Gin must be - # set up first before determining the backend. 
- tf.config.experimental.set_visible_devices([], 'GPU') - - -def _train_using_tf(output_dir): - worker_cpu = tf_init_tpu() - with tf.device(worker_cpu): - if trainer_lib.num_devices() == 1: - # TF's device priority is GPU > CPU > TPU, so we need to explicitly make - # the TPU core the default device here. - with tf.device('/device:TPU:0'): - trainer_lib.train(output_dir=output_dir) - else: - trainer_lib.train(output_dir=output_dir) - - -@gin.configurable -def tf_init_tpu(worker='', protocol=None): - """Initializes TPU for TensorFlow. - - Args: - worker: The BNS address of the remote TPU worker. If it's empty (the default - value), TF will assume the TPU devices are connected to the local host. - protocol: The network protocol used to connect to the TPU worker. - Returns: - The device name of the TPU worker's CPU. - """ - protocol = protocol or 'grpc' - is_local = (worker in ('', 'local')) - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=worker) - if not is_local: - tf.config.experimental_connect_to_cluster(resolver, protocol=protocol) - tf.tpu.experimental.initialize_tpu_system(resolver) - if is_local: - return '' - else: - return '/job:worker' - - -def _make_jax_gpu_cluster(host_id, server_ip, n_hosts, server_port=5005): - """Make JAX GPU Cluster.""" - - addr = f'{server_ip}:{server_port}' - if host_id == 0: - logging.info('starting service on %s', addr) - service = xc.get_distributed_runtime_service(addr, n_hosts) - # We add an explicit call to shutdown the service via atexit as Python - # interpreter may not call the service destructor on process termination. - atexit.register(service.shutdown) - - logging.info('connecting to service on %s', addr) - dist_client = xc.get_distributed_runtime_client(addr, host_id) - dist_client.connect() - atexit.register(dist_client.shutdown) - - # register dist gpu backend - factory = functools.partial(jax.lib.xla_client.make_gpu_client, - dist_client, host_id) - jax.lib.xla_bridge.register_backend_factory('gpu', factory, priority=300) - - -def main(_): - logging.set_verbosity(FLAGS.log_level) - - _tf_setup_from_flags() - _gin_parse_configs() - _jax_and_tf_configure_for_devices() - - # Create a JAX GPU cluster if using JAX and given a chief IP. - if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip: - _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id, - FLAGS.gpu_cluster_chief_ip, - FLAGS.gpu_cluster_n_hosts, - FLAGS.gpu_cluster_port) - - if FLAGS.disable_jit: - fastmath.disable_jit() - - output_dir = _output_dir_or_default() - if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP): - _train_using_tf(output_dir) - else: - trainer_lib.train(output_dir=output_dir) - - trainer_lib.log('Finished training.') - - -if __name__ == '__main__': - app.run(main) diff --git a/trax/trainer_flags.py b/trax/trainer_flags.py deleted file mode 100644 index 097a8ddac..000000000 --- a/trax/trainer_flags.py +++ /dev/null @@ -1,93 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Flags for trainer.py and rl_trainer.py. - -We keep these flags in sync across the trainer and the rl_trainer binaries. -""" - -from absl import flags -from absl import logging - -# Common flags. -flags.DEFINE_string('output_dir', - None, - 'Path to the directory to save logs and checkpoints.') -flags.DEFINE_multi_string('config_file', - None, - 'Configuration file with parameters (.gin).') -flags.DEFINE_multi_string('config', - None, - 'Configuration parameters (gin string).') - -# TPU Flags -flags.DEFINE_bool('use_tpu', False, "Whether we're running on TPU.") -flags.DEFINE_string('jax_xla_backend', - '', - 'Either "xla" for the XLA service directly, or "tpu_driver"' - 'for a TPU Driver backend.') -flags.DEFINE_string('jax_backend_target', - 'local', - 'Either "local" or "rpc:address" to connect to a ' - 'remote service target.') - -# trainer.py flags. -flags.DEFINE_string('dataset', None, 'Which dataset to use.') -flags.DEFINE_string('model', None, 'Which model to train.') -flags.DEFINE_string('data_dir', None, 'Path to the directory with data.') -flags.DEFINE_integer('log_level', logging.INFO, 'Log level.') - -# JAX/XLA GPU cluster flags. -flags.DEFINE_string('gpu_cluster_chief_ip', '', 'IP of GPU cluster chief.') -flags.DEFINE_integer('gpu_cluster_n_hosts', 1, - 'Number of hosts in GPU cluster.') -flags.DEFINE_integer('gpu_cluster_host_id', 0, 'Host id inside GPU cluster.') -flags.DEFINE_integer('gpu_cluster_port', 5005, 'Port to use in GPU cluster.') - -# TensorFlow Flags -flags.DEFINE_bool('enable_eager_execution', - True, - "Whether we're running TF in eager mode.") -flags.DEFINE_bool('tf_xla', True, 'Whether to turn on XLA for TF.') -flags.DEFINE_bool('tf_opt_pin_to_host', - False, - 'Whether to turn on TF pin-to-host optimization.') -flags.DEFINE_bool('tf_opt_layout', - False, - 'Whether to turn on TF layout optimization.') -flags.DEFINE_bool('tf_xla_forced_compile', - False, - 'Use forced-compilation instead of auto-clustering for XLA.' - 'This flag only has effects when --tf_xla is on.') -flags.DEFINE_bool('tf_allow_float64', False, 'Whether to allow float64 for TF.') - -# rl_trainer.py flags. -flags.DEFINE_boolean('jax_debug_nans', - False, - 'Setting to true will help to debug nans and disable jit.') -flags.DEFINE_boolean('disable_jit', False, 'Setting to true will disable jit.') -flags.DEFINE_string('envs_output_dir', '', 'Output dir for the envs.') -flags.DEFINE_bool('xm', False, 'Copy atari roms?') -flags.DEFINE_integer('train_batch_size', - 32, - 'Number of parallel environments during training.') -flags.DEFINE_integer('eval_batch_size', 4, 'Batch size for evaluation.') -flags.DEFINE_boolean('parallelize_envs', - False, - 'If true, sets parallelism to number of cpu cores.') -flags.DEFINE_string('trajectory_dump_dir', - '', - 'Directory to dump trajectories to.') -flags.DEFINE_bool('async_mode', False, 'Async mode.') diff --git a/trax/trainers/__init__.py b/trax/trainers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/trainers/base.py b/trax/trainers/base.py new file mode 100644 index 000000000..4fce87f7a --- /dev/null +++ b/trax/trainers/base.py @@ -0,0 +1,1024 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multi-device accelerated optimization."""
+
+import functools
+import os
+
+import jax
+import numpy as np
+import psutil
+
+from absl import logging
+
+from trax import fastmath
+from trax import layers as tl
+from trax.fastmath import numpy as jnp
+from trax.layers import base
+from trax.layers import combinators as cb
+
+
+class Trainer:
+    """Multi-device accelerated trainer.
+
+    Given an optimizer and a composite layer containing model+loss, this class
+    creates a multi-device accelerated function with which it can compute one step
+    of updates to the model's weights/state and the optimizer slots. By default
+    it uses all available accelerators, via JIT compilation and parallel mapping.
+
+    The optimizer and model must be initialized prior to use by this class.
+
+    The key `one_step` function runs one forward-backward pass through the model,
+    and returns the resulting loss value and updated optimizer statistics. As a
+    side effect, the function also modifies the model weights and optimizer slots.
+    """
+
+    def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False):
+        self._model_with_loss = model_with_loss
+        self._optimizer = optimizer
+        self._n_devices = n_devices or fastmath.local_device_count()
+        self._adasum = adasum
+
+        # optimizer slots and opt_params may need to be replicated
+        self._slots, self._opt_params = tl.on_cpu(
+            tl.for_n_devices(
+                (self._optimizer.slots, self._optimizer.opt_params), self._n_devices
+            )
+        )
+
+        # accelerated version of model+loss to replicate weights and state
+        self._accelerated_model_with_loss = tl.Accelerate(
+            model_with_loss, n_devices=n_devices
+        )
+
+        # Signature:
+        # (batch, weights, state, rng) -> ((loss, state), gradients)
+        self._forward_and_backward_fn = fastmath.value_and_grad(
+            model_with_loss.pure_fn, argnums=1, has_aux=True  # arg1 of pure_fn: weights
+        )  # returns ((loss, state), gradients)
+
+        # Signature:
+        # (weights, slots), step, opt_params, batch, state, rng ->
+        # (weights, slots), state, stats
+        self._accelerated_update_fn = _accelerate_update_fn(
+            self._forward_and_backward_fn,
+            self._optimizer,
+            n_devices=self._n_devices,
+            accelerate=True,
+            adasum=self._adasum,
+        )
+
+    @property
+    def model_with_loss(self):
+        """Returns the composite model+loss for this instance."""
+        return self._model_with_loss
+
+    @property
+    def accelerated_model_with_loss(self):
+        """Returns the accelerated composite model+loss for this instance."""
+        return self._accelerated_model_with_loss
+
+    @property
+    def optimizer(self):
+        """Returns the optimizer for this instance."""
+        return self._optimizer
+
+    @property
+    def slots(self):
+        """Returns the slots of the optimizer."""
+        return self._optimizer.slots
+
+    @slots.setter
+    def slots(self, slots):
+        """Sets the slots of the optimizer and this class (replicated)."""
+        self._optimizer.slots = slots
+        self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices))
+
+    def one_step(self, batch, rng, step=0, learning_rate=None):
+        """Runs one training step, to update model and optimizer parameters.
+
+        Args:
+          batch: Batch of labeled training data.
+ rng: Single-use random number generator (JAX PRNG key). + step: Training step number. + learning_rate: Learning rate for the optimizer; if None, use optimizer's + default learning rate. + + Returns: + Tuple of (loss, optimizer_stats), with the newly computed loss and + updated stats as reported by the optimizer. + """ + if learning_rate is not None: + self._opt_params["learning_rate"] = tl.for_n_devices( + learning_rate, self._n_devices + ) + + # Split the batch across devices (batch_dim --> batch_dim // n_devices) + # and create new rng's 1-1 with devices. + if self._n_devices > 1: + batch = tl.reshape_by_device(batch, self._n_devices) + rng = jnp.stack(fastmath.random.split(rng, self._n_devices)) + + weights = self._accelerated_model_with_loss.weights + state = self._accelerated_model_with_loss.state + if logging.vlog_is_on(1) and ((step & step - 1) == 0): + # Prints every power of two, if debugging is enabled. + logging.info("step[%d]", step) + logging.info("opt_params[%s]", self._opt_params) + logging.info("slots[%s]", self._slots) + logging.info("weights[%s]", weights) + logging.info("state[%s]", state) + + # NOTE: stats is a replicated dictionary of key to jnp arrays. + (new_weights, new_slots), new_state, stats = self._accelerated_update_fn( + (weights, self._slots), step, self._opt_params, batch, state, rng + ) + + if logging.vlog_is_on(1) and ((step & step - 1) == 0): + logging.info("updated weights[%s]", new_weights) + logging.info("stats[%s]", stats) + + self._accelerated_model_with_loss.weights = new_weights + self._accelerated_model_with_loss.state = new_state + self._slots = new_slots + self._optimizer.slots = self._unreplicate(self._slots) + return stats["loss"], stats + + def _unreplicate(self, x): + if self._n_devices == 1: + return x + return fastmath.nested_map(lambda x: x[0], x) + + +def _adasum_merge(g1, g2): + """Adasum gradient composition, see https://arxiv.org/pdf/2006.02924.pdf.""" + frac1 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g1, g1) + 1e-30) + frac2 = jnp.vdot(g1, g2) / (2 * jnp.vdot(g2, g2) + 1e-30) + return (1 - frac1) * g1 + (1 - frac2) * g2 + + +def _average_multidevice_gradients(gradients, adasum=False): + """Averages gradients over all the devices across different hosts.""" + n = fastmath.global_device_count() // base.N_WEIGHTS_SHARDS + if adasum: + # This implements a version of the Adasum algorithm from the following + # paper: https://arxiv.org/pdf/2006.02924.pdf + lg = max([i for i in range(20) if 2**i <= n]) + for lg_i in range(lg): + shift = 2**lg_i + perm = [] + for i in range(n): + block_i = i % (2 * shift) # we do blocks of 2*shift size + if block_i < shift: + perm.append((i, i + shift)) + else: + perm.append((i, i - shift)) + perm_grad = jax.lax.ppermute(gradients, perm=perm, axis_name="batch") + gradients = fastmath.nested_map_multiarg( + _adasum_merge, gradients, perm_grad + ) + if base.N_WEIGHTS_SHARDS > 1: # only sum gradients from matching shards + groups = [ + [base.N_WEIGHTS_SHARDS * i + d for i in range(int(n))] + for d in range(base.N_WEIGHTS_SHARDS) + ] + gradients_psum = fastmath.psum(gradients, "batch", axis_index_groups=groups) + else: + gradients_psum = fastmath.psum(gradients, "batch") # sum all gradients + n = jnp.array(n, dtype=jnp.float32) + return fastmath.nested_map(lambda g: g / n, gradients_psum) + + +# Returns a function with the following signature: +# (weights, slots), step, opt_params, batch, state, rng -> +# (weights, slots), state, stats +def _accelerate_update_fn( + forward_and_backward_fn, optimizer, n_devices, 
accelerate=True, adasum=False
+):
+    """Accelerates the given forward_and_backward_fn function."""
+    if n_devices == 1:
+
+        def single_device_update_fn(
+            weights_and_slots, step, opt_params, batch, state, rng
+        ):
+            step = jnp.array(step, dtype=jnp.int32)  # Needed in TFNP backend.
+            weights, slots = weights_and_slots
+
+            (loss, state), gradients = forward_and_backward_fn(
+                batch, weights, state, rng
+            )
+
+            weights, slots, stats = optimizer.tree_update(
+                step, gradients, weights, slots, opt_params, store_slots=False
+            )
+            stats["loss"] = loss
+            return (weights, slots), state, stats
+
+        if accelerate:
+            # TODO(afrozm): Find out the status of buffer donation on GPUs, then do
+            # donate_argnums=(0,).
+            single_device_update_fn = fastmath.jit(single_device_update_fn)
+
+        return single_device_update_fn
+
+    # More than one device (core), i.e. all of TPU configurations etc.
+    assert n_devices > 1, f"{n_devices} should be greater than 1."
+
+    @functools.partial(fastmath.pmap, axis_name="batch", donate_argnums=(0,))
+    def _multi_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng):
+        # All tensors should have the first dimension = n_devices.
+        weights, slots = weights_and_slots
+        (loss, state), gradients = forward_and_backward_fn(batch, weights, state, rng)
+        gradients = _average_multidevice_gradients(gradients, adasum=adasum)
+        weights, slots, stats = optimizer.tree_update(
+            step, gradients, weights, slots, opt_params, store_slots=False
+        )
+        stats["loss"] = loss
+        return (weights, slots), state, stats
+
+    def multi_device_update_fn(weights_and_slots, step, opt_params, batch, state, rng):
+        # Need to replicate step to n_devices leading dimension.
+        return _multi_device_update_fn(
+            weights_and_slots,
+            jnp.repeat(step, n_devices),
+            opt_params,
+            batch,
+            state,
+            rng,
+        )
+
+    return multi_device_update_fn
+
+
+class ReversibleSerialTrainer:
+    """Runs an optimizer on a series of layers, reversible and not.
+
+    We provide layers to this trainer in blocks, each block consisting of
+    a list of standard layers and a list of reversible layers. They all run
+    in turn (like one huge Serial block) but in a more memory-efficient way.
+
+    The main motivation for this class is to save memory: it allows training
+    models that have more weights than the memory available on accelerators.
+    It does this by caching the weights in CPU memory and transferring only
+    the weights of one layer at a time. The reversible layers are used to make
+    the backward pass without using additional memory for storing activations.
+
+    Note: we do not allow sharing weights between blocks for now.
+    """
+
+    def __init__(
+        self,
+        blocks,
+        loss_layer,
+        optimizer_fn,
+        n_devices=None,
+        memoize_jit=True,
+        free_accelerators_on_step=False,
+        adasum=False,
+    ):
+        """Creates a ReversibleSerialTrainer and the needed optimizers.
+
+        This trainer performs updates equivalent to using the default Trainer on::
+
+          tl.Serial(blocks + [loss_layer]).
+
+        It is more memory-efficient though since weights are stored on CPU and only
+        sent to the accelerator layer by layer. Blocks are pairs consisting of a list
+        of standard (arbitrary) layers and a list of reversible layers which help
+        save memory thanks to being reversible.
+
+        Args:
+          blocks: A list of pairs of lists of standard and reversible layers.
+          loss_layer: The final layer of the model; it can have trainable weights
+            but should end with a loss: it is required to produce a scalar output.
+          optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
+          n_devices: An optional integer, number of accelerator devices to use;
+            by default, all available accelerators will be used.
+          memoize_jit: Whether to memoize JITed functions; this significantly speeds
+            up XLA compilation of larger models, but it uses `repr(layer)` as keys
+            to memoize so it could fail if two layers with different functionality
+            had the same string representation. We have not encountered such a case
+            yet so this is turned on by default, but consider turning it off or
+            reviewing your model if you use custom layers and encounter a problem.
+          free_accelerators_on_step: If true, frees memory on accelerators when
+            starting a step. All layers and arguments must be on host for that,
+            otherwise it can lead to failures. Can prevent memory fragmentation.
+          adasum: If True, use adaptive summation to gather multi-device gradients.
+        """
+        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
+        self._loss_layer = loss_layer
+        self._optimizer_fn = optimizer_fn
+        self._n_devices = n_devices or fastmath.local_device_count()
+        self._adasum = adasum
+        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
+        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
+        self._n_async_layers = 1  # How many layers to run asynchronously.
+        self._jit_memory = {} if memoize_jit else None
+        self._do_free = free_accelerators_on_step
+        self._jit_per_device_rngs = fastmath.jit(self._per_device_rngs, backend="cpu")
+
+        # Create accelerated versions of layers as pmapped/jitted pure_fn.
+        self._accelerated_layer_fns = fastmath.nested_map(
+            lambda layer: self._pjit(layer.pure_fn, f"fwd {repr(layer)}"), self._blocks
+        )
+
+        # Create per-layer optimizers and replicate opt_params.
+        def _make_optimizer(layer):
+            opt = optimizer_fn()
+            opt.tree_init(layer.weights)
+            opt.slots = tl.on_cpu(opt.slots)
+            return opt
+
+        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
+        self._replicated_opt_params = fastmath.nested_map(
+            lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers
+        )
+
+        self._loss_opt = _make_optimizer(loss_layer)
+        self._replicated_loss_opt_params = self._replicate_cpu(
+            self._loss_opt.opt_params
+        )
+
+        # Forward + backward + optimizer-update functions for all layers.
+        # We call them in short FBO for "Forward + Backward + Optimizer update".
+        # Reversible layers define a reverse_and_fbo function that also reverses.
+
+        self._fbos = []
+        for i, (std_layer, rev_layers) in enumerate(self._blocks):
+            (std_opt, rev_opts) = self._optimizers[i]
+            std_fbo = _fbo_with_layer_and_opt(
+                std_layer, std_opt, self._n_devices, adasum=self._adasum
+            )
+            rev_and_fbos = []
+            for layer, opt in zip(rev_layers, rev_opts):
+                rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
+                    layer, opt, self._n_devices, self._adasum
+                )
+                # The donated args are (outputs, weights, grads) and we can donate
+                # them because weights and grads are immediately replaced and in
+                # case of reversible layers, the outputs are never used again.
+                rev_and_fbos.append(
+                    self._pjit(
+                        rev_and_fbo, f"rev+bwd {repr(layer)}", donate_argnums=(0, 1, 2)
+                    )
+                )
+            # In standard layers, the inputs cannot be donated as they may be used
+            # as outputs for the reversible block below, but weights and grads can.
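+            # Note on donation: donate_argnums tells XLA it may reuse the donated
+            # input buffers for outputs, roughly halving peak memory for updates
+            # that produce same-shaped arrays. A minimal plain-JAX illustration:
+            #
+            #   f = jax.jit(lambda w, g: w - 0.1 * g, donate_argnums=(0,))
+            #   w = f(w, g)  # the old `w` buffer may be reused for the result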
+ jit_std_fbo = self._pjit( + std_fbo, f"bwd {repr(std_layer)}", donate_argnums=(1, 2) + ) + self._fbos.append((jit_std_fbo, rev_and_fbos)) + + loss_fbo = _fbo_with_layer_and_opt( + self._loss_layer, self._loss_opt, self._n_devices, "loss", self._adasum + ) + self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2)) + + @property + def loss_layer(self): + """Returns the loss layer used to initialize this class.""" + return self._loss_layer + + @property + def all_layers(self): + """Returns all layers that compose the model and loss in this class.""" + layers = [] + for (std_layer, rev_layers) in self._blocks: + layers.append(std_layer) + layers.extend(rev_layers) + layers.append(self._loss_layer) + return layers + + @property + def optimizer_fn(self): + """Returns the optimizer function used to initialize this class.""" + return self._optimizer_fn + + @property + def slots(self): + """Returns the slots of all optimizers.""" + optimizers = list(self._optimizers) + [self._loss_opt] + return fastmath.nested_map(lambda opt: opt.slots, optimizers) + + @slots.setter + def slots(self, slots): + """Sets the slots of all optimizers.""" + for ((s_opt, r_opts), (s_slots, r_slots)) in zip(self._optimizers, slots[:-1]): + for (opt, slot) in zip([s_opt] + r_opts, [s_slots] + r_slots): + opt.slots = slot + self._loss_opt.slots = slots[-1] + + def _pjit(self, f, memory_key=None, donate_argnums=()): + """JIT f if 1 device is available and pmap if more are available.""" + should_memoize = self._jit_memory is not None and memory_key is not None + if should_memoize and memory_key in self._jit_memory: + logging.info("Found JITed function in memory for: %s", memory_key) + return self._jit_memory[memory_key] + if self._n_devices == 1: + res = fastmath.jit(f, donate_argnums=donate_argnums) + else: + res = fastmath.pmap(f, axis_name="batch", donate_argnums=donate_argnums) + if should_memoize: + self._jit_memory[memory_key] = res + return res + + def _replicate(self, x): + if self._n_devices > 1: + return tl.for_n_devices(x, self._n_devices) + return tl.on_accelerator(x) + + def _replicate_cpu(self, x): + # TODO(lukaszkaiser): move it to layers/acceleration to be together with + # tl.for_n_devices and other functions like that, possibly refactor them. 
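+        # Note: replication here only prepends an n_devices axis by broadcasting,
+        # e.g. a (512, 512) weight becomes (n_devices, 512, 512) so each device
+        # later receives its own leading slice. np.broadcast_to returns a
+        # read-only view, so no extra host copy is made at this point.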
+ def f(x): + if self._n_devices > 1: + return np.broadcast_to(x, (self._n_devices,) + np.asarray(x).shape) + else: + return x + + return tl.on_cpu(fastmath.nested_map(f, x)) + + def _unreplicate(self, x): + if self._n_devices == 1: + return tl.on_cpu(x) + return tl.on_cpu(fastmath.nested_map(lambda x: x[0], x)) + + def _lazy_unreplicate(self, x): + def unreplicate_and_start_async_copy(y): + unreplicated = y if self._n_devices == 1 else y[0] + unreplicated.copy_to_host_async() + return unreplicated + + return fastmath.nested_map(unreplicate_and_start_async_copy, x) + + def _collect_weights(self, layer): + layer.weights = fastmath.nested_map(np.asarray, layer.weights) + + def _free_accelerators(self, exceptions=(), keep_constants=True): + """Deletes all live buffers from accelerator with no safety guarantees.""" + backend = jax.lib.xla_bridge.get_backend() + live_buffers = backend.live_buffers() + logging.info("Deleting %d live buffers.", len(live_buffers)) + exceptions_buffers = [] + for x in fastmath.tree_flatten(exceptions): + if hasattr(x, "device_buffer"): # DeviceArray + exceptions_buffers.append(x.device_buffer) + if hasattr(x, "device_buffers"): # ShardedDeviceArray + exceptions_buffers.extend(x.device_buffers) + for b in live_buffers: + should_delete = True + for e in exceptions_buffers: + if b is e: + should_delete = False + if keep_constants and not b.shape: + should_delete = False + if should_delete: + b.delete() + + def _per_device_rngs(self, rng): + """Create per-device RNGs from a given rng.""" + # Splitting by device first to be identical with default trainers. + per_device_rng = fastmath.random.split(rng, self._n_devices) + per_device_rngs = [ + fastmath.random.split(r, self._n_layers) for r in per_device_rng + ] + rngs = [ + jnp.stack([r[i] for r in per_device_rngs]) for i in range(self._n_layers) + ] + return rngs + + def one_step(self, batch, rng, step=0, learning_rate=None): + """Updates layers weights/state and optimizers slots by running one step. + + Args: + batch: Batch of data to use for optimization. + rng: Random number generator to use for running this step. + step: Which step of the training are we running. + learning_rate: Learning rate to use instead of the default one. + + Returns: + Tuple (loss, stats) with new values from one step + of training, where stats are all optimizer statistics. + """ + # Update the learning rate if needed. + if learning_rate is not None: + self._replicated_loss_opt_params["learning_rate"] = self._replicate_cpu( + learning_rate + ) + for (std_op, rev_ops) in self._replicated_opt_params: + std_op["learning_rate"] = self._replicate_cpu(learning_rate) + for op in rev_ops: + op["learning_rate"] = self._replicate_cpu(learning_rate) + + # Batch needs to be split across the local devices -- the difference + # between _for_n_devices and _reshape_by_device is that the latter splits + # the batch dim to batch // n_devices, vs _for_n_devices + # broadcasts/replicates to n_devices dimension. + step_int = step + if self._n_devices > 1: + batch = tl.reshape_by_device(batch, self._n_devices, pure_np=True) + step = np.repeat(step, self._n_devices) + + # Create separate rng for each device and layer. + if self._n_devices == 1: + rngs = fastmath.random.split(rng, self._n_layers) + else: + # JIT the function and run it on CPU to avoid memory fragmentation. + rngs = self._jit_per_device_rngs(tl.on_cpu(rng)) + # Group rngs by layer blocks. 
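+        # Layout of `rngs`: one rng for the standard layer of each block followed
+        # by one rng per reversible layer in it, with rngs[-1] reserved for the
+        # loss layer. E.g. blocks with 2 and 3 reversible layers consume
+        # [std0, rev0a, rev0b, std1, rev1a, rev1b, rev1c].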
+        rng_blocks, rng_i = [], 0
+        for _, rev_layers in self._blocks:
+            l = len(rev_layers)
+            rng_blocks.append((rngs[rng_i], rngs[rng_i + 1 : rng_i + l + 1]))
+            rng_i += l + 1
+
+        # Run the layers forward up to the loss layer.
+        if self._do_free:
+            self._free_accelerators()
+        process = psutil.Process(os.getpid())
+        if isinstance(batch, (list, tuple)):
+            batch_shapes = [x.shape for x in batch]
+        else:
+            batch_shapes = batch.shape
+        logging.info("running step %d on shapes %s", step_int, str(batch_shapes))
+        if step_int % self._n_steps_per_log == 1:
+            logging.info(
+                "run fwd: cpu memory use (MB): %.2f",
+                process.memory_info().rss / float(1024 * 1024),
+            )
+
+        stack = batch
+        block_inputs_states = []
+        for i, (std_layer, rev_layers) in enumerate(self._blocks):
+            acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
+            std_rng, rev_rngs = rng_blocks[i]
+            # Run the standard layer.
+            stack, std_inputs, std_state = self._run_forward_standard(
+                stack, std_layer, acc_std_layer_fn, std_rng, step_int
+            )
+
+            # Run the reversible layers and collect old and new states.
+            stack, rev_old_states, rev_new_states = self._run_forward_reversible(
+                stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int
+            )
+            block_inputs_states.append(
+                tl.on_cpu(((std_inputs, std_state), (rev_old_states, rev_new_states)))
+            )
+
+        # Run the loss layer forward and backward with optimizer update.
+        if step_int % self._n_steps_per_log == 1:
+            logging.info(
+                "run loss: cpu memory use (MB): %.2f",
+                process.memory_info().rss / float(1024 * 1024),
+            )
+        loss_state = self._replicate(self._loss_layer.state)
+        loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
+        loss_stats, grad_stack = self._run_backward_standard(
+            None,
+            step,
+            self._loss_layer,
+            loss_inputs,
+            loss_state,
+            self._loss_fbo,
+            rngs[-1],
+            self._loss_opt,
+            self._replicated_loss_opt_params,
+        )
+        self._collect_weights(self._loss_layer)
+        stats = [tl.on_cpu(loss_stats)]
+
+        # De-fragment memory.
+        if self._do_free:
+            stack, grad_stack = tl.on_cpu(stack), tl.on_cpu(grad_stack)
+            self._free_accelerators()
+
+        # Run the layers backward and run optimizer updates.
+        if step_int % self._n_steps_per_log == 1:
+            logging.info(
+                "run bwd: cpu memory use (MB): %.2f",
+                process.memory_info().rss / float(1024 * 1024),
+            )
+        for i in range(len(self._blocks) - 1, -1, -1):
+            std_layer, rev_layers = self._blocks[i]
+            (std_inputs, std_state), (
+                rev_old_states,
+                rev_new_states,
+            ) = block_inputs_states[i]
+            std_fbo, rev_fbos = self._fbos[i]
+            std_opt, rev_opts = self._optimizers[i]
+            std_rng, rev_rngs = rng_blocks[i]
+            repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]
+
+            # Run reversible layers backward with optimizer update.
+            stack, grad_stack, new_stats = self._run_backward_reversible(
+                stack,
+                grad_stack,
+                step,
+                rev_layers,
+                rev_fbos,
+                rev_old_states,
+                rev_new_states,
+                rev_rngs,
+                rev_opts,
+                repl_rev_opts_params,
+            )
+            stats.extend(tl.on_cpu(new_stats))
+
+            # Run the standard layer forward-and-backward pass and optimizer update.
+            std_layer_stats, grad_stack = self._run_backward_standard(
+                grad_stack,
+                step,
+                std_layer,
+                std_inputs,
+                std_state,
+                std_fbo,
+                std_rng,
+                std_opt,
+                repl_std_opt_params,
+            )
+            stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
+                std_inputs, stack, std_layer.n_out
+            )
+            stats.append(tl.on_cpu(std_layer_stats))
+
+            # Collect lazily unreplicated layer weights.
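+            # The `_lazy_unreplicate` calls inside the backward helpers only
+            # *start* asynchronous device-to-host copies of updated weights;
+            # `_collect_weights` below materializes them as numpy arrays, so the
+            # copies overlap with the backward pass of the next block.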
+            for rev_layer_id in range(self._n_async_layers):
+                self._collect_weights(rev_layers[rev_layer_id])
+            self._collect_weights(std_layer)
+
+        # Join stats from different optimizers into one.
+        joint_stats = {}
+        for i, stat in enumerate(reversed(stats)):
+            for k, v in stat.items():
+                joint_stats[f"layer{i}/" + k] = v
+        return stats[0]["loss"], joint_stats
+
+    def _run_forward_standard(self, stack, layer, accelerated_fn, rng, step):
+        """Run standard layer forward."""
+        if step % self._n_steps_per_log == 1:
+            logging.info("running forward standard layer %s", str(layer))
+        layer_inputs = cb.inputs_from_stack(stack, layer.n_in)
+        layer_weights = self._replicate(layer.weights)
+        layer_state = self._replicate(layer.state)
+        outputs, layer_new_state = accelerated_fn(
+            layer_inputs, layer_weights, layer_state, rng
+        )
+        stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
+        return stack, layer_inputs, layer_new_state
+
+    def _run_forward_reversible(self, stack, rev_layers, accelerated_fns, rngs, step):
+        """Run reversible layers forward, collect states for backwards pass."""
+        old_states, new_states = [], []
+        for i, layer in enumerate(rev_layers):
+            if step % self._n_steps_per_log == 1:
+                logging.info("running forward reversible layer %s", str(layer))
+            weights = self._replicate(layer.weights)  # also copies cpu -> accelerator
+            state = self._replicate(layer.state)
+            old_states.append(state)
+            inputs = cb.inputs_from_stack(stack, layer.n_in)
+            outputs, new_state = accelerated_fns[i](inputs, weights, state, rngs[i])
+            stack = cb.outputs_onto_stack(outputs, stack, layer.n_in)
+            new_states.append(new_state)
+        return stack, old_states, new_states
+
+    def _run_backward_standard(
+        self,
+        grad_stack,
+        step,
+        layer,
+        inp,
+        state,
+        fbo_fn,
+        rng,
+        optimizer,
+        replicated_opt_params,
+    ):
+        """Run a standard layer backward, with optimizer update."""
+        step_int = int(step) if self._n_devices < 2 else int(step[0])
+        if step_int % self._n_steps_per_log == 1:
+            logging.info("running backward standard layer %s", str(layer))
+        if grad_stack is not None:
+            grads = cb.inputs_from_stack(grad_stack, layer.n_out)
+        else:
+            grads = None
+        slots = self._replicate(optimizer.slots)
+        weights = self._replicate(layer.weights)
+        # Ensure all arguments are on accelerator.
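+        # Weights live in host memory between steps in this trainer, so every
+        # argument is staged onto the accelerator explicitly here instead of
+        # relying on implicit transfers when the jitted fbo_fn is dispatched.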
+ state = tl.on_accelerator(state) + replicated_opt_params = tl.on_accelerator(replicated_opt_params) + rng = tl.on_accelerator(rng) + grads = tl.on_accelerator(grads) + inp = tl.on_accelerator(inp) + new_weights, new_state, new_slots, new_grads, stats = fbo_fn( + inp, weights, grads, state, slots, replicated_opt_params, rng, step + ) + layer.weights = self._lazy_unreplicate(new_weights) + layer.state = self._unreplicate(new_state) + optimizer.slots = self._unreplicate(new_slots) + if grad_stack is not None: + grad_stack = cb.outputs_onto_stack(new_grads, grad_stack, layer.n_out) + else: + grad_stack = new_grads + return stats, grad_stack + + def _run_backward_reversible( + self, + stack, + grad_stack, + step, + rev_layers, + rev_and_fbos, + old_states, + new_states, + rngs, + optimizers, + replicated_opt_params, + ): + """Run reversible layers backwards.""" + counter = 0 + stats = [] + step_int = int(step) if self._n_devices < 2 else int(step[0]) + for layer, reverse_and_fbo, old_state, new_state, rng in reversed( + list(zip(rev_layers, rev_and_fbos, old_states, new_states, rngs)) + ): + if step_int % self._n_steps_per_log == 1: + logging.info("running backward reversible layer %s", str(layer)) + counter -= 1 + stack, grad_stack, layer_stats = self._run_backward_one_reversible( + layer, + stack, + grad_stack, + step, + rng, + optimizers[counter], + replicated_opt_params[counter], + reverse_and_fbo, + old_state, + new_state, + ) + stats.append(layer_stats) + if counter + self._n_async_layers < 0: + self._collect_weights(rev_layers[counter + self._n_async_layers]) + return stack, grad_stack, stats + + def _run_backward_one_reversible( + self, + layer, + stack, + grad_stack, + step, + rng, + optimizer, + opt_params, + reverse_and_fbo, + old_state, + new_state, + ): + """Run one reversible layer backwards.""" + # We are running backwards and reversing, so we get *outputs* from stack. + outputs = cb.inputs_from_stack(stack, layer.n_out) + grads = cb.inputs_from_stack(grad_stack, layer.n_out) + slots = self._replicate(optimizer.slots) + weights = self._replicate(layer.weights) # cpu -> accelerator + # Ensure all arguments are on accelerator. + outputs = tl.on_accelerator(outputs) + grads = tl.on_accelerator(grads) + old_state = tl.on_accelerator(old_state) + new_state = tl.on_accelerator(new_state) + opt_params = tl.on_accelerator(opt_params) + rng = tl.on_accelerator(rng) + new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo( + outputs, weights, grads, old_state, new_state, slots, opt_params, rng, step + ) + layer.weights = self._lazy_unreplicate(new_weights) # accelerator -> cpu + layer.state = self._unreplicate(new_state) + optimizer.slots = self._unreplicate(new_slots) + stack = cb.outputs_onto_stack(inputs, stack, layer.n_out) + grad_stack = cb.outputs_onto_stack(grads, grad_stack, layer.n_out) + return stack, grad_stack, layer_stats + + +# Forward + backward + optimizer-update functions for all layers. +# We call them in short FBO for "Forward + Backward + Optimizer update". + + +def _fbo_with_layer_and_opt(layer, optimizer, n_devices, stats_name=None, adasum=False): + """Create the fbo function for a given layer and optimizer.""" + + def fbo(inputs, weights, grads, state, slots, opt_params, rng, step): + """FBO of the layer.""" + # We need a layer pure_fn but only for inputs and weights. + def pure_fn_without_state_and_rng(x, w): + return layer.pure_fn(x, w, state, rng) + + # Calculate the vector-Jacobian product of the reduced pure fn. 
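+        # fastmath.vjp mirrors jax.vjp here: for f(x, w) -> (y, aux) it returns
+        # (y, vjp_fn, aux), and vjp_fn(dy) yields the cotangents (dx, dw).
+        # A minimal sketch of the same pattern in plain JAX:
+        #
+        #   def f(x, w):
+        #       return (x * w).sum(), ()          # () plays the role of state
+        #   y, f_vjp, _ = jax.vjp(f, x, w, has_aux=True)
+        #   dx, dw = f_vjp(1.0)                   # cotangent of the scalar output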
+        activations, vjp_fn, new_state = fastmath.vjp(
+            pure_fn_without_state_and_rng, inputs, weights, has_aux=True
+        )
+
+        # In the loss layer, set gradients to 1 with the dtype of activations=loss.
+        if grads is None and stats_name is not None:
+            grads = jnp.ones((), dtype=activations.dtype)
+
+        # The vjp function returns gradients with respect to inputs and weights.
+        grads_inputs, grads_weights = vjp_fn(grads)
+
+        # For non-trainable layers, return the calculated arguments.
+        if _is_empty_tuple(weights):
+            stats = {}
+            if stats_name is not None:
+                stats[stats_name] = activations
+            return weights, new_state, slots, grads_inputs, stats
+
+        # In multi-device setting, average gradients from multiple devices.
+        if n_devices > 1:
+            grads_weights = _average_multidevice_gradients(grads_weights, adasum=adasum)
+
+        # Run the optimizer.
+        new_weights, new_slots, stats = optimizer.tree_update(
+            step, grads_weights, weights, slots, opt_params, store_slots=False
+        )
+        if stats_name is not None:
+            stats[stats_name] = activations
+        return new_weights, new_state, new_slots, grads_inputs, stats
+
+    return fbo
+
+
+# Reversible layers define a reverse_and_fbo function that both reverses
+# and runs the forward-backward pass and applies the optimizer.
+# This function uses the `reverse_and_grad` method of reversible layers.
+
+
+def _reverse_and_fbo_with_layer_and_opt(layer, optimizer, n_devices, adasum):
+    """Create the reverse_and_fbo function for a given layer and optimizer."""
+
+    def reverse_and_fbo(
+        output, weights, grads, state, new_state, slots, opt_params, rng, step
+    ):
+        """Reverse and FBO of the layer."""
+        # Call the reverse_and_grad method of the layer.
+        inputs, (grads_inputs, grads_weights) = layer.reverse_and_grad(
+            output, grads, weights, state, new_state, rng=rng
+        )
+
+        # For non-trainable layers, return the calculated arguments.
+        if _is_empty_tuple(weights):
+            return weights, slots, inputs, grads_inputs, {}
+
+        # In multi-device setting, average gradients from multiple devices.
+        if n_devices > 1:
+            grads_weights = _average_multidevice_gradients(grads_weights, adasum=adasum)
+
+        # Run the optimizer.
+        new_weights, new_slots, stats = optimizer.tree_update(
+            step, grads_weights, weights, slots, opt_params, store_slots=False
+        )
+
+        return new_weights, new_slots, inputs, grads_inputs, stats
+
+    return reverse_and_fbo
+
+
+def _is_empty_tuple(x):
+    """Check if x is either empty or a tuple of (tuples of) empty things."""
+    if not isinstance(x, (list, tuple)):
+        return False
+    for y in x:
+        if not _is_empty_tuple(y):
+            return False
+    return True
+
+
+def extract_reversible_blocks(layers, loss_chunk_size=0):
+    """Extracts blocks and loss layer for use with ReversibleSerialTrainer.
+
+    Args:
+      layers: a list of layers or a single layer to extract blocks from;
+        should end with a loss, e.g., [model, loss] or tl.Serial(model, loss).
+      loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory
+        in models with larger vocabulary; requires that the last sublayers of
+        the loss be [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order.
+
+    Returns:
+      a pair (blocks, loss_layer) to use with ReversibleSerialTrainer.
+    """
+
+    def _flatten(l):
+        """Flatten all Serial layers and sub(sub-...) layers into a list."""
+        if isinstance(l, (list, tuple)):
+            return [
+                x for layer in l for x in _flatten(layer)
+            ]  # pylint: disable=g-complex-comprehension
+        elif isinstance(l, tl.Serial):
+            return _flatten(l.sublayers)
+        else:
+            return [l]
+
+    # Extract standard and reversible layer blocks.
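+    # Example of the partitioning below: a flattened list
+    # [Emb, Rev1, Rev2, Dense, Rev3, Loss] yields
+    # blocks = [([Emb], [Rev1, Rev2]), ([Dense], [Rev3])], and the trailing
+    # standard layers [Loss] become the loss layer.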
+    blocks, std_layers, rev_layers = [], [], []
+    for layer in _flatten(layers):
+        if isinstance(layer, tl.ReversibleLayer):
+            rev_layers.append(layer)
+        elif not rev_layers:
+            std_layers.append(layer)
+        else:
+            blocks.append((std_layers, rev_layers))
+            std_layers, rev_layers = [], []
+            std_layers.append(layer)
+    if rev_layers:
+        raise ValueError("The final layer must be a standard loss, not reversible.")
+    if loss_chunk_size > 0:
+        # For now we only do chunking of [Dense, LogSoftmax, CrossEntropy, Mean].
+        # Let's check that these are the last 4 layers.
+        border_layers = ["StripFromConcatenateWithPadding", "Select"]
+
+        loss_start = None
+        for index, layer in enumerate(std_layers):
+            if layer.name in border_layers:
+                loss_start = index + 1
+        if loss_start is None:
+            raise ValueError(
+                "Loss layer should be preceded by one of {}; got {}".format(
+                    border_layers, [l.name for l in std_layers]
+                )
+            )
+        if len(std_layers) - loss_start < 4:
+            raise ValueError("Loss layer too short for chunking")
+        last_3_names = " ".join([l.name for l in std_layers[-3:]])
+        if last_3_names != "LogSoftmax _CrossEntropy _WeightedMean":
+            raise ValueError(
+                'Loss chunking only works with last layers being "'
+                'LogSoftmax, _CrossEntropy, _WeightedMean" but got: ' + last_3_names
+            )
+
+        # Create chunked dense+logsoftmax+cross-entropy-loss.
+        chunked_xent = tl.Chunk(tl.Serial(std_layers[loss_start:-1]), loss_chunk_size)
+        # The chunked loss should operate on a merged batch dimension, e.g.,
+        # including both length and batch size. Need to merge and un-merge later.
+        def _reshape_to_batch_and_copy_targets(preds, targets):
+            batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]])
+            batched_targets = jnp.reshape(targets, [-1])
+            return batched_preds, batched_targets, targets
+
+        def _reshape_xent_back(xent, targets):
+            return jnp.reshape(xent, targets.shape)
+
+        batched_xent = tl.Serial(
+            tl.Fn("pre_xent_rebatch", _reshape_to_batch_and_copy_targets, n_out=3),
+            chunked_xent,
+            tl.Fn("after_xent_rebatch", _reshape_xent_back),
+        )
+        loss_layer = tl.Serial(std_layers[:loss_start] + [batched_xent], std_layers[-1])
+    else:
+        loss_layer = tl.Serial(std_layers)
+    return blocks, loss_layer
+
+
+def init_reversible_blocks(blocks, loss_layer, input_signature, rng):
+    """Initialize reversible blocks and the loss layer and place weights on CPU.
+
+    Args:
+      blocks: List of reversible blocks (pairs of layer lists).
+      loss_layer: The final loss layer to initialize.
+      input_signature: The signature of the input to the blocks.
+      rng: Random key used to initialize the layers.
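+
+    Example (illustrative sketch; `signature` is trax.shapes.signature):
+
+      blocks, loss_layer = extract_reversible_blocks([model, loss])
+      init_reversible_blocks(blocks, loss_layer, signature(batch), rng)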
+    """
+    sig_stack = input_signature
+    process = psutil.Process(os.getpid())
+    mem_use = process.memory_info().rss
+    for (std_layers, rev_layers) in blocks:
+        rngs = fastmath.random.split(rng, len(std_layers) + len(rev_layers) + 1)
+        rng = rngs[0]
+        for layer, layer_rng in zip(std_layers + rev_layers, rngs[1:]):
+            sig = cb.inputs_from_stack(sig_stack, layer.n_in)
+            layer.init(sig, rng=layer_rng)
+            layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
+            layer.state = tl.on_cpu(layer.state)  # store state in cpu memory
+            logging.info(
+                "init: layer %s\nadded cpu memory (MB): %.2f",
+                str(layer),
+                (process.memory_info().rss - mem_use) / float(1024 * 1024),
+            )
+            mem_use = process.memory_info().rss
+            logging.info(
+                "init: cpu memory use (MB): %.2f", mem_use / float(1024 * 1024)
+            )
+            out_sig = layer.output_signature(sig)
+            sig_stack = cb.outputs_onto_stack(out_sig, sig_stack, layer.n_in)
+    loss_layer.init(cb.inputs_from_stack(sig_stack, loss_layer.n_in), rng=rng)
+    loss_layer.weights = tl.on_cpu(loss_layer.weights)
+    loss_layer.state = tl.on_cpu(loss_layer.state)
diff --git a/trax/trainers/jax.py b/trax/trainers/jax.py
new file mode 100644
index 000000000..2befd1093
--- /dev/null
+++ b/trax/trainers/jax.py
@@ -0,0 +1,994 @@
+# Refactored trainers.py
+import math
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import trax.layers as tl
+
+# These helpers are defined in trax.trainers.base (added in this PR) and are
+# referenced by ReversibleSerialTrainer below.
+from trax.trainers.base import _fbo_with_layer_and_opt, extract_reversible_blocks
+
+
+def _adasum_merge(a, b):
+    """Compute the AdaSum of two vectors."""
+    dot_val = jnp.vdot(a, b)
+    a_sq = jnp.vdot(a, a)
+    b_sq = jnp.vdot(b, b)
+    # Branchless so it stays traceable under jit/pmap: fall back to a + b when
+    # either norm is zero or the gradients do not conflict (dot >= 0).
+    gamma = a_sq / (a_sq + b_sq + 1e-30)
+    merged = gamma * a + (1.0 - gamma) * b
+    return jnp.where((dot_val < 0) & (a_sq > 0) & (b_sq > 0), merged, a + b)
+
+
+def _average_multidevice_gradients(gradients, adasum=False):
+    """
+    Averages (or Adasum-reduces) 'gradients' across devices using the axis_name='batch'.
+
+    If adasum=False, we do a standard pmean.
+    If adasum=True, we do a simple all_gather & reduce approach, for demonstration.
+    """
+    if not adasum:
+        # Standard average via pmean
+        return jax.lax.pmean(gradients, axis_name="batch")
+    else:
+        # Demonstration: gather all grads to each device, then reduce them.
+        # (A real Adasum might do ring-based or hierarchical merges.)
+        gathered = jax.lax.all_gather(gradients, axis_name="batch")
+
+        # Each leaf of 'gathered' now has an extra leading [n_devices] axis,
+        # so we fold AdaSum over that axis leaf-by-leaf.
+        def adasum_reduce(x):
+            # x has shape (n_devices, ...); accumulate pairwise merges.
+            acc = x[0]
+            for i in range(1, x.shape[0]):
+                acc = _adasum_merge(acc, x[i])
+            return acc
+
+        return jax.tree_map(adasum_reduce, gathered)
+
+
+def _pad_batch_for_devices(batch, n_devices):
+    """
+    If batch_size is not divisible by n_devices, pad the leading dimension so it is.
+    Returns (padded_batch, unpad_amount).
+
+    'batch' should be a tuple/list of arrays, or a PyTree that includes arrays
+    on the leading dimension for each item in the batch.
+    """
+    batch_size = batch[0].shape[0]  # assume batch is e.g. (input, target, ...)
+    remainder = batch_size % n_devices
+    if remainder == 0:
+        return batch, 0
+
+    new_size = math.ceil(batch_size / n_devices) * n_devices
+    to_pad = new_size - batch_size
+
+    def pad_fn(x):
+        # x has shape [batch_size, ...]
+        return jnp.pad(x, [(0, to_pad)] + [(0, 0)] * (x.ndim - 1), mode="constant")
+
+    padded = jax.tree_map(pad_fn, batch)
+    return padded, to_pad
+
+
+def _unpad_batch_outputs(outputs, to_remove):
+    """
+    If we padded the batch by 'to_remove' examples, remove them from
+    the leading dimension of the returned arrays.
+    """
+    if to_remove == 0:
+        return outputs
+
+    def unpad_fn(x):
+        # x has leading dimension we want to slice off the last 'to_remove' elements
+        return x[:-to_remove] if x.shape[0] > to_remove else x[:0]
+
+    return jax.tree_map(unpad_fn, outputs)
+
+
+def _accelerate_update_fn(forward_and_backward_fn, optimizer, n_devices, adasum):
+    """
+    Returns an update_fn that:
+      - single-device => jitted function
+      - multi-device => pmapped function that also does gradient averaging or Adasum
+    """
+
+    @jax.jit
+    def single_device_update_fn(
+        weights, state, opt_state, batch, rng, step_int, opt_params
+    ):
+        # 1) Forward + backward pass -> grads, loss, updated_state
+        grads, loss, updated_state = forward_and_backward_fn(batch, weights, state, rng)
+
+        # 2) Optimizer update
+        new_weights, new_opt_state, _metrics = optimizer.tree_update(
+            step_int, grads, weights, opt_state, opt_params, store_slots=False
+        )
+        return new_weights, updated_state, new_opt_state, loss
+
+    if n_devices <= 1:
+        # Single device => just call the jitted function
+        return single_device_update_fn
+
+    # For multi-device: we pmap around single_device_update_fn
+    def multi_device_update_fn(
+        weights, state, opt_state, batch, rngs, step_int, opt_params
+    ):
+        """
+        Each device runs single_device_update_fn on a shard of the batch,
+        then we do gradient averaging (or Adasum).
+        """
+
+        def _per_device_step(w, s, o, b, r):
+            """
+            We do the forward/backward but also average grads across devices
+            inside this pmap, so each device ends up with the same update.
+            """
+            # -- forward+backward pass --
+            grads, loss, st_new = forward_and_backward_fn(b, w, s, r)
+            # -- average or Adasum the grads across devices --
+            grads = _average_multidevice_gradients(grads, adasum=adasum)
+            # -- apply optimizer update --
+            w_new, o_new, _metrics = optimizer.tree_update(
+                step_int, grads, w, o, opt_params, store_slots=False
+            )
+            return w_new, st_new, o_new, loss
+
+        # We call pmap over the per-device-step
+        w_updated, s_updated, o_updated, loss = jax.pmap(
+            _per_device_step, axis_name="batch"
+        )(weights, state, opt_state, batch, rngs)
+        return w_updated, s_updated, o_updated, loss
+
+    return multi_device_update_fn
+
+
+class Trainer:
+    """A trainer that supports single- or multi-device training, with optional Adasum, padding, etc."""
+
+    def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False):
+        """
+        Args:
+            model_with_loss: A layer that returns (loss, new_state) from pure_fn(...)
+            optimizer: An optimizer with .tree_init(...) and .tree_update(...)
methods + n_devices: Number of devices to use + adasum: Whether to do Adasum gradient reduction (instead of standard averaging) + """ + self._model_with_loss = model_with_loss + self._optimizer = optimizer + self._n_devices = n_devices or jax.local_device_count() + self._adasum = adasum + + # Initialize optimizer state from the model's initial weights + self._slots, self._opt_params = optimizer.tree_init( + self._model_with_loss.weights + ) + + # Build forward+backward function with value_and_grad(has_aux=True) + def forward_and_backward_fn(batch, weights, state, rng): + """ + Returns (gradients, loss, new_state). + """ + + def loss_fn(curr_w, curr_s): + loss_val, new_st = model_with_loss.pure_fn( + batch, curr_w, curr_s, rng, use_cache=True + ) + return loss_val, new_st + + (loss_val, new_state), grads = jax.value_and_grad( + loss_fn, argnums=0, has_aux=True + )(weights, state) + + return grads, loss_val, new_state + + self._forward_and_backward_fn = forward_and_backward_fn + + # Build an update function that does single vs. multi-device + self._accelerated_update_fn = _accelerate_update_fn( + self._forward_and_backward_fn, + self._optimizer, + self._n_devices, + self._adasum, + ) + + @property + def model_with_loss(self): + return self._model_with_loss + + @property + def optimizer(self): + return self._optimizer + + @property + def slots(self): + return self._slots + + def one_step(self, batch, rng, step=0, learning_rate=None): + """ + 1) Possibly pad the batch for multi-device + 2) Single- or multi-device forward/backward + 3) Update weights & state + 4) Unpad if needed, return loss + """ + if learning_rate is not None: + self._opt_params["learning_rate"] = learning_rate + + weights = self._model_with_loss.weights + state = self._model_with_loss.state + + if self._n_devices == 1: + # Single device => just run the function directly (already jitted). 
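+            # `single_device_update_fn` was wrapped in jax.jit at construction
+            # time, so the first call pays one-time tracing and XLA compilation;
+            # later calls with the same shapes and dtypes reuse the executable.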
+            (new_weights, new_state, new_slots, loss,) = self._accelerated_update_fn(
+                weights,
+                state,
+                self._slots,
+                batch,
+                rng,
+                step,
+                self._opt_params,
+            )
+
+            # Store
+            self._model_with_loss.weights = new_weights
+            self._model_with_loss.state = new_state
+            self._slots = new_slots
+            self._optimizer.slots = new_slots
+            return loss
+
+        #
+        # Multi-device => pad the batch if needed, replicate, call pmapped update
+        #
+        padded_batch, to_remove = _pad_batch_for_devices(batch, self._n_devices)
+        padded_size = padded_batch[0].shape[0]
+        batch_per_device = padded_size // self._n_devices
+
+        # Split rng if it's just a single key
+        if isinstance(rng, np.ndarray) and rng.shape == (2,):
+            rng = jax.random.split(rng, self._n_devices)
+
+        # Reshape batch for devices
+        padded_batch = jax.tree_map(
+            lambda x: x.reshape((self._n_devices, batch_per_device) + x.shape[1:]),
+            padded_batch,
+        )
+
+        # Replicate weights/state/slots
+        weights_rep = jax.tree_map(
+            lambda x: np.broadcast_to(x, (self._n_devices,) + x.shape), weights
+        )
+        state_rep = jax.tree_map(
+            lambda x: np.broadcast_to(x, (self._n_devices,) + x.shape), state
+        )
+        slots_rep = jax.tree_map(
+            lambda x: np.broadcast_to(x, (self._n_devices,) + x.shape), self._slots
+        )
+
+        # Run the pmapped update
+        (
+            updated_weights_rep,
+            updated_state_rep,
+            updated_slots_rep,
+            loss_rep,
+        ) = self._accelerated_update_fn(
+            weights_rep, state_rep, slots_rep, padded_batch, rng, step, self._opt_params
+        )
+
+        # Unreplicate results
+        new_weights = self._unreplicate(updated_weights_rep)
+        new_state = self._unreplicate(updated_state_rep)
+        new_slots = self._unreplicate(updated_slots_rep)
+        loss_vals = self._unreplicate(loss_rep)
+
+        # If we want a single scalar, e.g. average across devices:
+        # each device sees the same final "loss", if we've pmean'd it,
+        # so we can just do:
+        final_loss = float(loss_vals) if np.size(loss_vals) == 1 else np.mean(loss_vals)
+
+        # Update trainer state
+        self._model_with_loss.weights = new_weights
+        self._model_with_loss.state = new_state
+        self._slots = new_slots
+        self._optimizer.slots = new_slots
+
+        # If your model returns per-example losses, you might want to unpad the output
+        # after the forward pass. But here we've just got a scalar loss, so no unpadding needed
+        # for the loss. If you needed to unpad e.g. a predictions array, you'd do it here.
+
+        return final_loss
+
+    def _unreplicate(self, tree):
+        """Return the first element of a replicated array (from shape [n_devices,...] to [...])."""
+        return jax.tree_map(lambda x: x[0], tree)
+
+
+class ReversibleSerialTrainer:
+    """Trainer for a sequence of reversible layers - optimized with JAX JIT."""
+
+    def __init__(
+        self,
+        model_with_loss,
+        optimizer_fn,
+        n_devices=None,
+        adasum=False,
+        n_steps_per_log=None,
+        n_async_layers=0,
+        jit_memory=True,
+        do_free=True,
+    ):
+        """Initialize the trainer.
+
+        Args:
+            model_with_loss: Serial layer with loss at the end.
+            optimizer_fn: Function creating an optimizer for each layer.
+            n_devices: Number of accelerator devices to use in the computation.
+            adasum: Whether to use Adasum algorithm for gradient aggregation.
+            n_steps_per_log: How often to log results.
+            n_async_layers: How many layers to run asynchronously.
+            jit_memory: Whether to JIT memory cleanup operations.
+            do_free: Whether to free memory during training.
+        """
+        # First, we need to extract the model and the loss from the model_with_loss.
+        # Usually model_with_loss is a Serial of the original model and the loss.
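+        # Typical argument, for illustration:
+        #   model_with_loss = tl.Serial(reversible_model, tl.CrossEntropyLoss())
+        # so model_with_loss[-1] picks the loss below and the preceding layers
+        # are partitioned into reversible blocks.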
+        if not isinstance(model_with_loss, tl.Serial):
+            # We may already be given just the model.
+            self._loss_layer = model_with_loss
+            self._blocks = None
+            self._n_layers = 1
+        else:
+            self._loss_layer = model_with_loss[-1]
+            self._blocks, _ = extract_reversible_blocks(model_with_loss)
+            # Number of all layers (not blocks, as reversible blocks have 2 layers).
+            self._n_layers = len(model_with_loss.sublayers)
+
+        # Initialize other training parameters
+        self._optimizer_fn = optimizer_fn
+        self._n_devices = n_devices or jax.local_device_count()
+        self._adasum = adasum
+        self._n_steps_per_log = n_steps_per_log
+        self._n_async_layers = n_async_layers
+
+        # Initialize memory management parameters
+        self._jit_memory = jit_memory
+        self._do_free = do_free
+
+        # Initialize RNG handling
+        self._jit_per_device_rngs = jax.pmap(
+            lambda rng: jax.random.split(rng, jax.local_device_count()),
+            axis_name="batch",
+        )
+
+        # Initialize the accelerated layer functions - JIT compiled versions
+        if self._blocks is not None:
+            # Initialize reverse layers
+            shapes = (1, 8)  # Will be replaced by actual batch shapes
+
+            # Create JIT-compiled forward and backward functions for each layer
+            self._accelerated_layer_fns = []
+            for layer in self._blocks:
+
+                def fwd_fn(x, weights, state, rng):
+                    return layer.pure_fn(x, weights, state, rng, True)
+
+                def bwd_fn(y, weights, state, rng, grad_y):
+                    def compute_loss(y):
+                        return jnp.mean(y)  # Dummy loss for grad computation
+
+                    vjp_fn = jax.vjp(compute_loss, y)[1]
+                    return vjp_fn(grad_y)[0]
+
+                # JIT-compile these functions
+                self._accelerated_layer_fns.append((jax.jit(fwd_fn), jax.jit(bwd_fn)))
+
+        # Initialize optimizers for each block
+        if self._blocks is not None:
+            self._optimizers = []
+            self._replicated_opt_params = []
+
+            # Create optimizer for each layer
+            for i, block in enumerate(self._blocks):
+                opt = optimizer_fn(block)
+                self._optimizers.append(opt)
+
+                # Initialize optimizer state for each layer
+                if i == len(self._blocks) - 1:
+                    # Last layer includes the loss layer
+                    slots, opt_params = opt.tree_init(block.weights)
+                else:
+                    slots, opt_params = opt.tree_init(block.weights)
+
+                # Replicate optimizer parameters for multi-device training
+                self._replicated_opt_params.append(self._replicate(opt_params))
+
+        # Initialize optimizer for the loss layer
+        self._loss_opt = optimizer_fn(self._loss_layer)
+        slots, opt_params = self._loss_opt.tree_init(self._loss_layer.weights)
+        self._replicated_loss_opt_params = self._replicate(opt_params)
+
+        # Create forward-backward-optimize functions
+        if self._blocks is not None:
+            self._fbos = []
+            for i, block in enumerate(self._blocks):
+                # Create the forward-backward-optimize function for this layer
+                fbo = self._pjit(_fbo_with_layer_and_opt, static_argnums=(0, 1))
+                self._fbos.append(fbo)
+
+        # Create loss function forward-backward-optimize
+        self._loss_fbo = self._pjit(_fbo_with_layer_and_opt, static_argnums=(0, 1))
+
+    def loss_layer(self):
+        """Returns the loss layer."""
+        return self._loss_layer
+
+    def all_layers(self):
+        """Returns a list of all layers in the model."""
+        if self._blocks is None:
+            return [self._loss_layer]
+        layers = []
+        for block in self._blocks:
+            layers.extend(block.sublayers)
+        layers.append(self._loss_layer)
+        return layers
+
+    def optimizer_fn(self):
+        """Returns the optimizer function."""
+        return self._optimizer_fn
+
+    def slots(self):
+        """Returns the optimizer slots."""
+        slots = []
+        if self._blocks is not None:
+            for opt in self._optimizers:
+                slots.append(opt.slots)
slots.append(self._loss_layer.weights) + return slots + + def slots_and_params(self): + """Returns the optimizer slots and parameters.""" + slots = [] + params = [] + if self._blocks is not None: + for i, opt in enumerate(self._optimizers): + s, p = opt.slots, self._unreplicate(self._replicated_opt_params[i]) + slots.append(s) + params.append(p) + s, p = self._loss_opt.slots, self._unreplicate(self._replicated_loss_opt_params) + slots.append(s) + params.append(p) + return slots, params + + def _pjit(self, f, *args, **kwargs): + """Apply jit compilation but avoiding tl.Accelerate.""" + if self._n_devices == 1: + return jax.jit(f, *args, **kwargs) + return jax.pmap(f, axis_name="batch", *args, **kwargs) + + def _replicate(self, x): + """Replicate a tree of values for use on multiple devices.""" + if self._n_devices <= 1: + return x + return jax.tree_map( + lambda y: jnp.broadcast_to(y, (self._n_devices,) + y.shape), x + ) + + def _replicate_cpu(self, x): + """Replicate a tree of values for use on multiple devices, allowing CPU arrays.""" + if self._n_devices <= 1: + return x + + def rep(y): + if isinstance(y, np.ndarray): + return np.broadcast_to(y, (self._n_devices,) + y.shape) + elif isinstance(y, jnp.ndarray): + return jnp.broadcast_to(y, (self._n_devices,) + y.shape) + else: + return y + + return jax.tree_map(rep, x) + + def _unreplicate(self, x): + """Take the first component of a replicated tree of values.""" + return jax.tree_map(lambda y: y[0], x) + + def _lazy_unreplicate(self, x): + """Like _unreplicate but avoids data movement if possible.""" + if isinstance(x, list) and len(x) == 1: + return x[0] + if self._n_devices == 1: + return x + + def get_first(y): + if y.shape[0] == self._n_devices: + return y[0] + return y + + return jax.tree_map(get_first, x) + + def _collect_weights(self): + """Collect weights from all layers into a single list.""" + weights = [] + if self._blocks is not None: + for block in self._blocks: + weights.append(block.weights) + weights.append(self._loss_layer.weights) + return weights + + def _free_accelerators( + self, n_done_per_replica, replica_id, n_to_do_in_replica=None + ): + """Free accelerator memory not used by a replica at a given step.""" + if not self._do_free: + return + + if n_to_do_in_replica is None: + n_to_do_in_replica = len(self._blocks) * 2 + 3 + + done_rate = n_done_per_replica / n_to_do_in_replica + + # If we have done a large chunk, we can free memory + if done_rate >= 0.5: + # Apply JIT compilation to memory operations if configured + if self._jit_memory: + # Define a memory cleanup function and JIT it + @jax.jit + def cleanup(): + # Reset JAX memory allocation + jax.lax.stop_gradient(0.0) + # Add explicit synchronization + jax.lax.psum(0, axis_name="batch") + return 0 + + cleanup() + else: + # Simple memory cleanup without JIT + jax.lax.stop_gradient(0.0) + jax.lax.psum(0, axis_name="batch") + + def _per_device_rngs(self, rng): + """Create different RNG keys for different devices.""" + if isinstance(rng, np.ndarray) and rng.shape == (2,): + if self._n_devices == 1: + return rng + # Create different RNG keys for different devices + return jax.random.split(rng, self._n_devices) + + # In multi-device case, we get a precomputed set of rngs + return rng + + def one_step(self, batch, rng, step=0, learning_rate=None): + """Run one step of gradient-based training. + + Args: + batch: Batch of training data. + rng: Random number generator. + step: Current training step. + learning_rate: Optional learning rate to use. 
+
+        Returns:
+          Loss computed on the batch.
+        """
+        # Update the learning rate if needed.
+        if learning_rate is not None:
+            if self._blocks is not None:
+                for params in self._replicated_opt_params:
+                    params["learning_rate"] = learning_rate
+            self._replicated_loss_opt_params["learning_rate"] = learning_rate
+
+        # Prepare the batch for multiple devices if needed.
+        if self._n_devices > 1:
+            batch_size = batch[0].shape[0]
+            batch_per_device = batch_size // self._n_devices
+
+            batch = jax.tree_map(
+                lambda x: x.reshape(self._n_devices, batch_per_device, *x.shape[1:]),
+                batch,
+            )
+
+        # Prepare RNGs for each device.
+        device_rngs = self._per_device_rngs(rng)
+
+        if self._blocks is None:
+            # No reversible layers - direct computation through the loss layer.
+            output, updated_state = self._loss_layer.pure_fn(
+                batch,
+                self._loss_layer.weights,
+                self._loss_layer.state,
+                device_rngs,
+                True,
+            )
+
+            # Create the loss-to-weights gradient function.
+            def grad_fn(weights):
+                output, _ = self._loss_layer.pure_fn(
+                    batch, weights, self._loss_layer.state, device_rngs, True
+                )
+                return output
+
+            # Compute gradients for the loss layer.
+            gradients = jax.grad(grad_fn)(self._loss_layer.weights)
+
+            # Average gradients across devices if needed.
+            if self._n_devices > 1:
+                gradients = _average_multidevice_gradients(
+                    gradients, self._n_devices, self._adasum
+                )
+
+            # Update the weights with the optimizer.
+            updates, updated_opt_state = self._loss_opt.tree_update(
+                gradients, self._loss_opt.slots, self._loss_layer.weights, step
+            )
+
+            # Apply updates to weights.
+            updated_weights = jax.tree_map(
+                lambda w, u: w + u, self._loss_layer.weights, updates
+            )
+
+            self._loss_layer.weights = updated_weights
+            self._loss_layer.state = updated_state
+            self._loss_opt.slots = updated_opt_state
+
+            return output
+
+        # We have reversible blocks - run the full reversible computation.
+        if not self._blocks[0].sublayers[0].has_backward:
+            # Standard case - run forward and backward passes separately.
+            (output, updated_state), inputs_stack = self._run_forward_standard(
+                batch, device_rngs
+            )
+
+            # Compute loss gradients.
+            loss_gradients = jax.grad(
+                lambda w: self._loss_layer.pure_fn(
+                    output, w, self._loss_layer.state, device_rngs, True
+                )[0]
+            )(self._loss_layer.weights)
+
+            # Average gradients across devices if needed.
+            if self._n_devices > 1:
+                loss_gradients = _average_multidevice_gradients(
+                    loss_gradients, self._n_devices, self._adasum
+                )
+
+            # Update the loss layer weights and optimizer slots.
+            loss_updates, loss_updated_opt_state = self._loss_opt.tree_update(
+                loss_gradients, self._loss_opt.slots, self._loss_layer.weights, step
+            )
+
+            self._loss_layer.weights = jax.tree_map(
+                lambda w, u: w + u, self._loss_layer.weights, loss_updates
+            )
+            self._loss_layer.state = updated_state
+            self._loss_opt.slots = loss_updated_opt_state
+
+            # Run the backward pass to compute and update weights for all blocks.
+            self._run_backward_standard(output, inputs_stack, device_rngs, step)
+
+            return output
+        else:
+            # Reversible case - use the specialized forward-backward pass.
+            output, output_grad = self._run_forward_reversible(batch, device_rngs)
+
+            # Run the backward pass for all reversible blocks. Note that
+            # `_run_backward_reversible` also needs the original batch, in
+            # addition to the forward output and its gradient.
+            loss = self._run_backward_reversible(
+                batch, output, output_grad, device_rngs, step
+            )
+
+            return loss
+
+    def _run_forward_standard(self, batch, rngs):
+        """Run the forward pass in standard (non-reversible) mode."""
+        # Stack of per-block inputs, needed later for the backward pass.
+        inputs_stack = []
+
+        # Forward pass through all blocks.
+        for i, block in enumerate(self._blocks):
+            inputs_stack.append(batch)
+            # Run the actual forward pass for this
block + batch, updated_state = block.pure_fn( + batch, block.weights, block.state, rngs, True + ) + + # Update block state + if i < len(self._blocks) - 1: + block.state = updated_state + + # Final forward pass through the loss layer + output, loss_updated_state = self._loss_layer.pure_fn( + batch, self._loss_layer.weights, self._loss_layer.state, rngs, True + ) + + self._loss_layer.state = loss_updated_state + + return (output, loss_updated_state), inputs_stack + + def _run_forward_reversible(self, batch, rngs): + """Run the forward pass in reversible mode.""" + # Extract inputs and targets + # Initialize the activations list + activations = [] + + # Forward pass through all blocks + for i, block in enumerate(self._blocks): + # Add the current input to activations + activations.append(batch) + + # Run the forward pass for this block + batch, updated_state = block.pure_fn( + batch, block.weights, block.state, rngs, True + ) + + # Update the block state + block.state = updated_state + + # Final forward pass through the loss layer + output, loss_updated_state = self._loss_layer.pure_fn( + batch, self._loss_layer.weights, self._loss_layer.state, rngs, True + ) + + self._loss_layer.state = loss_updated_state + + # Compute the output gradient + def loss_fn(x): + return self._loss_layer.pure_fn( + x, self._loss_layer.weights, self._loss_layer.state, rngs, True + )[0] + + # Get the gradient with respect to the output + output_grad = jax.grad(loss_fn)(batch) + + return output, output_grad + + def _run_backward_standard(self, loss, inputs_stack, rngs, step): + """Run the backward pass in standard (non-reversible) mode.""" + # Compute gradients for all blocks + grad_fn = lambda weights, i: self._blocks[i].pure_fn( + inputs_stack[i], weights, self._blocks[i].state, rngs, True + )[0] + + # Process blocks in reverse order + for i in range(len(self._blocks) - 1, -1, -1): + # Compute gradients for this block + block_gradients = jax.grad(lambda w: grad_fn(w, i))(self._blocks[i].weights) + + # Average gradients across devices if needed + if self._n_devices > 1: + block_gradients = _average_multidevice_gradients( + block_gradients, self._n_devices, self._adasum + ) + + # Update block weights + block_updates, block_updated_opt_state = self._optimizers[i].tree_update( + block_gradients, + self._optimizers[i].slots, + self._blocks[i].weights, + step, + ) + + self._blocks[i].weights = jax.tree_map( + lambda w, u: w + u, self._blocks[i].weights, block_updates + ) + self._optimizers[i].slots = block_updated_opt_state + + # Free accelerator memory if configured + if self._do_free: + self._free_accelerators(len(self._blocks) - i, 0) + + def _run_backward_reversible(self, batch, loss, output_grads, rngs, step): + """Run the backward pass in reversible mode. 
+
+        Args:
+          batch: The input batch data.
+          loss: The loss value from the forward pass.
+          output_grads: Gradients of the loss.
+          rngs: Random number generators.
+          step: Current training step.
+
+        Returns:
+          The loss value.
+        """
+        # Initialize the gradient to be backpropagated.
+        grads = output_grads
+
+        # Process blocks in reverse order.
+        for i in range(len(self._blocks) - 1, -1, -1):
+            # Get the input for this block. Note: for i > 0 this assumes the
+            # previous block cached its activation on the `output` attribute
+            # during the forward pass.
+            if i > 0:
+                inputs = self._blocks[i - 1].output
+            else:
+                # First block - get the original input.
+                inputs = batch[0]  # Assuming batch is a tuple of (inputs, targets).
+
+            # Run the backward pass for this block.
+            block_gradients, grads = self._run_backward_one_reversible(
+                i, inputs, grads, rngs
+            )
+
+            # Average gradients across devices if needed.
+            if self._n_devices > 1:
+                block_gradients = _average_multidevice_gradients(
+                    block_gradients, self._n_devices, self._adasum
+                )
+
+            # Use the optimizer's update method to get new weights and slots.
+            block_weights = self._blocks[i].weights
+            opt_slots = self._optimizers[i].slots
+            opt_params = self._optimizers[i].opt_params
+
+            # Update weights using the optimizer's own update logic.
+            new_weights, new_slots = self._optimizers[i].tree_update(
+                block_gradients, opt_slots, block_weights, step, opt_params
+            )
+
+            # Update block weights and optimizer slots.
+            self._blocks[i].weights = new_weights
+            self._optimizers[i].slots = new_slots
+
+            # Free accelerator memory if configured.
+            if self._do_free:
+                self._free_accelerators(len(self._blocks) - i, 0)
+
+        return loss
+
+    def _run_backward_one_reversible(self, block_index, inputs, output_grads, rngs):
+        """Run the backward pass for one reversible block."""
+        # Get the block.
+        block = self._blocks[block_index]
+
+        # Define the forward function for gradient computation.
+        def forward_fn(weights, inputs):
+            output, _ = block.pure_fn(inputs, weights, block.state, rngs, True)
+            return output
+
+        # Compute gradients with reverse-mode autodiff. Taking the VJP with
+        # respect to both primals yields one cotangent for the weights and one
+        # for the inputs.
+        _, vjp_fn = jax.vjp(forward_fn, block.weights, inputs)
+        block_gradients, input_grads = vjp_fn(output_grads)
+
+        return block_gradients, input_grads
+
+
+def _fbo_with_layer_and_opt(
+    optimizer,
+    layer,
+    inputs,
+    weights,
+    state,
+    rngs,
+    opt_state,
+    opt_params,
+    grads=None,
+    step=None,
+):
+    """Forward + backward + optimize on a single layer."""
+    # JIT-compiled function for forward-backward-optimize.
+    if grads is None:
+        # Forward pass.
+        output, new_state = layer.pure_fn(inputs, weights, state, rngs, True)
+
+        # Define the gradient function (with a dummy mean loss on the output).
+        def loss_fn(weights):
+            output, _ = layer.pure_fn(inputs, weights, state, rngs, True)
+            return jnp.mean(output)
+
+        # Compute gradients.
+        gradients = jax.grad(loss_fn)(weights)
+    else:
+        # Use the provided gradients.
+        gradients = grads
+        new_state = state
+        output = None
+
+    # Optimize.
+    updates, new_opt_state = optimizer.tree_update(
+        gradients, opt_state, weights, step, opt_params
+    )
+
+    # Apply the updates.
+    new_weights = jax.tree_map(lambda w, u: w + u, weights, updates)
+
+    return output, new_weights, new_state, new_opt_state, gradients
+
+
+def _reverse_and_fbo_with_layer_and_opt(
+    optimizer,
+    reversible_layer,
+    output,
+    output_grad,
+    weights,
+    state,
+    rngs,
+    opt_slots,
+    opt_params,
+    step=None,
+):
+    """Reverse-mode computation + optimize for a reversible layer."""
+
+    # Define the backward function for gradient computation.
+    def backward_fn(weights):
+        # Define a forward pass that computes outputs for these weights.
+        def forward_fn(x):
+            y, _ = reversible_layer.pure_fn(x, weights, state, rngs, True)
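+            # `pure_fn` returns (output, new_state); only the output takes
+            # part in the VJP below, the state update is discarded here.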
+ return y + + # Use VJP to compute gradients backward + _, vjp_fn = jax.vjp(forward_fn, output) + return vjp_fn(output_grad)[0] + + # Compute input gradient and weight gradients + input_grad = backward_fn(weights) + + # Compute weight gradients using the chain rule + weight_grads = jax.grad( + lambda w: jnp.sum( + reversible_layer.pure_fn(output, w, state, rngs, True)[0] * output_grad + ) + )(weights) + + # Use optimizer to compute new weights and slots + new_weights, new_slots = optimizer.tree_update( + weight_grads, opt_slots, weights, step, opt_params + ) + + return input_grad, new_weights, state, new_slots + + +def extract_reversible_blocks(layer): + """Extract reversible blocks from a serial layer. + + Args: + layer: A layer, usually a Serial layer containing reversible blocks. + + Returns: + A tuple (reversible_blocks, loss_layer) where reversible_blocks is + a list of blocks that are reversible and loss_layer is the final + loss layer or None if not present. + """ + if not isinstance(layer, tl.Serial): + return [], layer + + blocks = [] + loss_layer = None + + # Check if the last layer is a loss layer + if hasattr(layer.sublayers[-1], "n_in") and layer.sublayers[-1].n_in == 2: + loss_layer = layer.sublayers[-1] + sublayers = layer.sublayers[:-1] + else: + sublayers = layer.sublayers + + # Group layers into reversible blocks + i = 0 + while i < len(sublayers): + if isinstance(sublayers[i], tl.ReversibleLayer) or ( + hasattr(sublayers[i], "has_backward") and sublayers[i].has_backward + ): + blocks.append(sublayers[i]) + i += 1 + elif ( + i + 1 < len(sublayers) + and isinstance(sublayers[i], tl.ReversibleHalfResidual) + and isinstance(sublayers[i + 1], tl.ReversibleHalfResidual) + ): + # Pair of ReversibleHalfResidual layers make a reversible block + blocks.append(tl.ReversibleResidual(sublayers[i], sublayers[i + 1])) + i += 2 + else: + # Non-reversible layer - wrap it in a serial block + blocks.append(tl.Serial(sublayers[i])) + i += 1 + + return blocks, loss_layer diff --git a/trax/trax2keras.py b/trax/trax2keras.py deleted file mode 100644 index df84ef6f1..000000000 --- a/trax/trax2keras.py +++ /dev/null @@ -1,189 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Trax-to-Keras converter.""" - -import functools - -import tensorflow.compat.v2 as tf - -from trax import fastmath as math_lib -from trax import shapes as shapes_lib -from trax.fastmath import numpy as jnp -from trax.layers import base - - -def _replace_none_batch(x, batch_size=None): - if batch_size is None: - return x - if isinstance(x, tf.Tensor) and x.shape[0] is None: - x.set_shape([batch_size] + x.shape[1:]) - return x - elif isinstance(x, tf.TensorShape) and x[0] is None: - return [batch_size] + x[1:] - return x - - -def tensor_shapes_to_shape_dtypes(shapes, dtype): - return math_lib.nested_map( - lambda s: shapes_lib.ShapeDtype(s.as_list(), dtype), shapes) - - -def read_values(variables): - return math_lib.nested_map(lambda v: v.read_value(), variables) - - -def to_tensors(args): - return math_lib.nested_map(tf.convert_to_tensor, args) - - -def to_arrays(args): - return math_lib.nested_map(jnp.asarray, args) - - -class AsKeras(tf.keras.layers.Layer): - """A Keras layer built from a Trax layer. - - This subclass of `tf.keras.layers.Layer` takes in a Trax layer as a - constructor argument and wraps it to be a Keras layer. It uses - `tf.Variable` to store weights and state (initialized according to the Trax - layer), and uses the Trax layer's forward function as its forward function. - - Consider this code snippet:: - - keras_layer = AsKeras(trax_layer, initializer_rng=initializer_rng, - rng=rng, rng_updater=rng_updater) - keras_layer.build(...) # optional - outputs = keras_layer(inputs) - - (Note that in Keras calling `Layer.build` is optional. If omitted, it will be - called automatically by `Layer.__call__`.) - - If `trax_layer` already has weights at `build` time, the snippet is roughly - equivalent to:: - - weights = trax_layer.weights - state = trax_layer.state - keras_layer = tf.keras.layers.Layer() - keras_layer._weights = tf.Variable(weights) - keras_layer._state = tf.Variable(state) - keras_layer._rng = tf.Variable(rng) - outputs, new_state = trax_layer(inputs, keras_layer._weights, - keras_layer._state, keras_layer._rng) - keras_layer._state.assign(new_state) - keras_layer._rng.assign(rng_updater(rng)) - - If `trax_layer` doesn't have weights at `build` time, the snippet is roughly - equivalent to:: - - weights, state = trax_layer.init(..., rng=initializer_rng) - keras_layer = ... - ... - - `AsKeras` uses `tf.Variable` to store weights, not shared with the - original Trax layer (which uses tensors to store weights), so using - `AsKeras` may double the memory footprint. This problem can be solved - by making sure that the Trax layer's weights/state are cleared whenever - `tf.Variable.assign` (and `tf.Variable.assign_add` etc.) is called, because - `tf.Variable` is copy-on-write by default. - - Mutations in those `tf.Variable`s won't affect the Trax layer's weights, but - `AsKeras`'s forward function calls the Trax layer's forward function, - which caches the weights in the Trax layer object, so a forward pass may - change the weights cached in the original Trax layer. - - Note that this class is not thread-safe. If the same `AsKeras` object - is used in multiple threads, the `tf.Variable` updates may happen in a - non-deterministic order. - """ - - def __init__(self, trax_layer, batch_size=None, initializer_rng=None, - rng=None, rng_updater=None, dtype=None): - """Creates a Keras layer wrapping around a Trax layer. - - Args: - trax_layer: an object of class `trax.layers.Layer`, the trax layer to - wrap. 
- batch_size: (optional) an integer, the batch size that this Keras layer - will be used on. Keras sometimes needs to generate a TF graph for a - layer (e.g. for acceleration or checkpointing). The inputs used to trace - the graph will have `None` as the length of their batch dimensions, so - as to generate a graph that can handle any batch size. Some Trax layers - can't handle tensors whose shapes contain `None`. If `batch_size` is set - to an integer, the graph will be traced with `batch_size` as the batch - size instead of `None`. Note that in this case the graph (and the Keras - layer) can only be used on a specific batch size. If you want to use a - different batch size, you need to create another `AsKeras` object - with a different `batch_size`. - initializer_rng: (optional) an RNG key used to create the weights and - state if `trax_layer` doesn't have them. If `None`, - `trax.fastmath.random.get_prng(0)` will be used. - rng: (optional) an RNG key for the forward function (aka the "forward - key"). If `None`, `trax.fastmath.random.get_prng(0)` will be used. - rng_updater: (optional) a function of type rng_key -> rng_key, used to - update the forward key after each forward pass. If `None`, the function - `lambda x: trax.fastmath.random.split(x, 1)[0]` will be used, which - advances the RNG key. - dtype: (optional) the dtype of the inputs. See the `dtype` argument of - `tf.keras.layers.Layer.__init__` for details. - """ - super().__init__(dtype=dtype) - with math_lib.use_backend(math_lib.Backend.TFNP): - if initializer_rng is None: - initializer_rng = math_lib.random.get_prng(0) - if rng is None: - rng = math_lib.random.get_prng(0) - if rng_updater is None: - rng_updater = lambda x: math_lib.random.split(x, 1)[0] - self._trax_layer = trax_layer - self._batch_size = batch_size - self._initializer_rng = initializer_rng - self._forward_rng_init = rng - self._rng_updater = rng_updater - - def build(self, input_shape): - with math_lib.use_backend(math_lib.Backend.TFNP): - # Using `is` instead of `==` following Trax's practice - if self._trax_layer.weights is base.EMPTY_WEIGHTS: - sanitized_input_shape = math_lib.nested_map( - functools.partial(_replace_none_batch, batch_size=self._batch_size), - input_shape) - weights, state = self._trax_layer.init( - tensor_shapes_to_shape_dtypes(sanitized_input_shape, self.dtype), - rng=self._initializer_rng) - else: - weights = self._trax_layer.weights - state = self._trax_layer.state - # Note: `weights` may contain `EMPTY_WEIGHTS` - self._weights = math_lib.nested_map( - functools.partial(tf.Variable, trainable=True), weights) - self._state = math_lib.nested_map( - functools.partial(tf.Variable, trainable=False), state) - self._rng = tf.Variable(self._forward_rng_init, trainable=False) - super().build(input_shape) - - def call(self, inputs): - with math_lib.use_backend(math_lib.Backend.TFNP): - inputs = math_lib.nested_map( - functools.partial(_replace_none_batch, batch_size=self._batch_size), - inputs) - weights, state, rng = read_values([self._weights, self._state, self._rng]) - inputs, weights, state, rng = to_arrays([inputs, weights, state, rng]) - outputs, new_state = self._trax_layer.pure_fn(inputs, weights=weights, - state=state, rng=rng) - tf.nest.map_structure(lambda v, t: v.assign(t), self._state, new_state) - self._rng.assign(self._rng_updater(rng)) - outputs = to_tensors(outputs) - return outputs diff --git a/trax/trax2keras_test.py b/trax/trax2keras_test.py deleted file mode 100644 index 9fbf86f52..000000000 --- a/trax/trax2keras_test.py 
+++ /dev/null @@ -1,192 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for trax2keras.""" - -import os - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as onp - -import tensorflow.compat.v2 as tf - -import trax -from trax import fastmath as math_lib -from trax import layers -from trax import trax2keras -from trax.fastmath import numpy as jnp -from trax.models import mlp -from trax.models import transformer -from trax.trax2keras import read_values -from trax.trax2keras import to_arrays -from trax.trax2keras import to_tensors - - -tf.enable_v2_behavior() - - -def has_gpu(): - return bool(tf.config.list_physical_devices("GPU")) - - -def dummy_inputs(rng, input_sig): - def f(sig): - shape = sig.shape - if shape and shape[0] is None: - shape = (2,) + tuple(shape[1:]) - if onp.issubdtype(sig.dtype, onp.integer): - minval = 1 - # Must specify maxval for integer dtype. - # TODO(afrozm): Revisit after TF 2.3 - maxval = 10000 - else: - minval = 0 - maxval = 1 - return rng.uniform( - shape=shape, dtype=sig.dtype, minval=minval, maxval=maxval) - return math_lib.nested_map(f, input_sig) - - -def Mod(n): # pylint: disable=invalid-name - return layers.Fn("Mod", lambda x: x % n) - - -# Format: -# (trax-layer maker, input shapes, input dtype, can handle None batch size?) -_LAYERS = [ - (lambda: layers.Dense(3), tf.TensorShape([4]), onp.float32, True), - (mlp.MLP, tf.TensorShape([4]), onp.float32, False), - (lambda: layers.Serial(Mod(8), transformer.TransformerLM(8)), - tf.TensorShape([4]), onp.int32, False), -] - - -_RNG_UPDATERS = [ - lambda x: x, - lambda rng: math_lib.random.split(rng, 1)[0], -] - - -# Needs tf.test.TestCase for `assertAllClose` and `get_temp_dir` -class Trax2KerasTest(tf.test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters( - [{"testcase_name": "_%s_%s_%s_%s_%s_%s" % ( # pylint: disable=g-complex-comprehension - layer_id, rng_updater_id, batch_size, trax_has_weights, - explicit_build, use_model), - "layer_id": layer_id, - "rng_updater_id": rng_updater_id, - "batch_size": batch_size, - "trax_has_weights": trax_has_weights, - "explicit_build": explicit_build, - "use_model": use_model,} - for use_model in [True, False] - for explicit_build in [True, False] - for trax_has_weights in [True, False] - for batch_size in [2, None] - for rng_updater_id in [1] - for layer_id in range(len(_LAYERS)) - ]) - def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights, - explicit_build, use_model): - """Tests training (forward and backward pass) for AsKeras. - - Args: - layer_id: an integer, the index into `_LAYERS`. - rng_updater_id: an integer, the index into `_RNG_UPDATERS`. - batch_size: an integer or `None`, the value for the `batch_size` argument - in `AsKeras.__init__`. - trax_has_weights: bool, whether to make the trax layer contain weights at - the time when `AsKeras.build` is called. 
- explicit_build: bool, whether to explicitly call `AsKeras.build`. - use_model: bool, whether to build a `tf.keras.Model` out of the - `AsKeras` layer and use the model to do the training instead of - the bare layer. If `True`, we will also test checkpointing and restoring - using the model. - """ - with trax.fastmath.use_backend("tensorflow-numpy"): - make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = ( - _LAYERS[layer_id]) - # We make a fresh trax layer for each test case, so that different test - # cases won't interfere with each other. - trax_layer = make_trax_layer() - if not allow_none_batch and batch_size is None: - self.skipTest("This Trax layer can't handle None batch size.") - rng_updater = _RNG_UPDATERS[rng_updater_id] - input_shapes = math_lib.nested_map( - lambda s: [batch_size] + s, input_shapes_no_batch) - input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype) - initializer_rng = math_lib.random.get_prng(765) - weights, state = trax_layer.init(input_sig, rng=initializer_rng) - generator = tf.random.Generator.from_seed(567) - def get_inputs(): - return dummy_inputs(generator, input_sig) - if trax_has_weights: - trax_layer(to_arrays(get_inputs()), weights=weights, state=state) - rng = math_lib.random.get_prng(1234) - keras_layer = trax2keras.AsKeras( - trax_layer, batch_size=batch_size, initializer_rng=initializer_rng, - rng=rng, rng_updater=rng_updater) - if explicit_build: - keras_layer.build(input_shapes) - if use_model: - x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype) - y = keras_layer(x) - keras_model = tf.keras.Model(inputs=x, outputs=y) - lr = 0.1 # learning rate - for _ in range(3): - inputs = get_inputs() - with tf.GradientTape() as trax_tape: - trax_tape.watch(tf.nest.flatten(weights)) - trax_outputs, state = trax_layer.pure_fn( - to_arrays(inputs), weights=weights, state=state, rng=rng) - trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights])) - # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor` - # before multiplication. 
- weights = tf.nest.map_structure( - lambda w, g: w + jnp.asarray(lr * tf.convert_to_tensor(g), w.dtype), - weights, trax_grads) - rng = rng_updater(rng) - with tf.GradientTape() as keras_tape: - if use_model: - keras_outputs = keras_model(inputs) - else: - keras_outputs = keras_layer(inputs) - if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1: - keras_outputs = keras_outputs[0] - self.assertAllClose(to_tensors(trax_outputs), keras_outputs, atol=1e-5) - keras_grads = keras_tape.gradient(keras_outputs, - keras_layer.trainable_variables) - tf.nest.map_structure( - lambda v, g: v.assign_add( # pylint: disable=g-long-lambda - tf.cast(lr * tf.convert_to_tensor(g), v.dtype)), - keras_layer.trainable_variables, keras_grads) - self.assertAllClose( - to_tensors(weights), read_values(keras_layer._weights), - rtol=2e-6, atol=4.5e-4 if has_gpu() else 1e-6) - self.assertAllClose(to_tensors(state), read_values(keras_layer._state)) - self.assertAllClose(to_tensors(rng), read_values(keras_layer._rng)) - if use_model: - fname = os.path.join(self.get_temp_dir(), "checkpoint") - keras_model.save(fname) - loaded_model = tf.keras.models.load_model(fname) - for _ in range(2): - inputs = get_inputs() - self.assertAllClose(keras_model(inputs), loaded_model(inputs)) - - -if __name__ == "__main__": - absltest.main() diff --git a/trax/utils/__init__.py b/trax/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trax/utils/jaxboard.py b/trax/utils/jaxboard.py new file mode 100644 index 000000000..35dd38cef --- /dev/null +++ b/trax/utils/jaxboard.py @@ -0,0 +1,378 @@ +# coding=utf-8 +# Copyright 2022 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Write Summaries from JAX for use with Tensorboard. + +See jaxboard_demo.py for example usage. +""" +import io +import struct +import time +import warnings +import wave + +import matplotlib as mpl + +# Necessary to prevent attempted Tk import: +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + mpl.use("Agg") +# pylint: disable=g-import-not-at-top +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf + + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.core.util import event_pb2 +from tensorflow.python.summary.writer.event_file_writer import EventFileWriter + +# pylint: enable=g-direct-tensorflow-import + + +def _pack_images(images, rows, cols): + """Helper utility to make a tiled field of images from numpy arrays. + + Args: + images: Image tensor in shape [N, W, H, C]. + rows: Number of images per row in tiled image. + cols: Number of images per column in tiled image. + + Returns: + A tiled image of shape [W * rows, H * cols, C]. + Truncates incomplete rows. 
+ """ + shape = np.shape(images) + width, height, depth = shape[-3:] + images = np.reshape(images, (-1, width, height, depth)) + batch = np.shape(images)[0] + rows = np.minimum(rows, batch) + cols = np.minimum(batch // rows, cols) + images = images[: rows * cols] + images = np.reshape(images, (rows, cols, width, height, depth)) + images = np.transpose(images, [0, 2, 1, 3, 4]) + images = np.reshape(images, [rows * width, cols * height, depth]) + return images + + +class SummaryWriter: + """Saves data in event and summary protos for tensorboard.""" + + def __init__(self, log_dir, enable=True): + """Create a new SummaryWriter. + + Args: + log_dir: path to record tfevents files in. + enable: bool: if False don't actually write or flush data. Used in + multihost training. + """ + # If needed, create log_dir directory as well as missing parent directories. + if not tf.io.gfile.isdir(log_dir): + tf.io.gfile.makedirs(log_dir) + + self._event_writer = EventFileWriter(log_dir, 10, 120, None) + self._step = 0 + self._closed = False + self._enabled = enable + + def add_summary(self, summary, step): + if not self._enabled: + return + event = event_pb2.Event(summary=summary) + event.wall_time = time.time() + if step is not None: + event.step = int(step) + self._event_writer.add_event(event) + + def close(self): + """Close SummaryWriter. Final!""" + if not self._closed: + self._event_writer.close() + self._closed = True + del self._event_writer + + def __del__(self): # safe? + # TODO(afrozm): Sometimes this complains with + # `TypeError: 'NoneType' object is not callable` + try: + self.close() + except Exception: # pylint: disable=broad-except + pass + + def flush(self): + if not self._enabled: + return + self._event_writer.flush() + + def scalar(self, tag, value, step=None): + """Saves scalar value. + + Args: + tag: str: label for this data + value: int/float: number to log + step: int: training step + """ + value = float(np.array(value)) + if step is None: + step = self._step + else: + self._step = step + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)] + ) + self.add_summary(summary, step) + + def image(self, tag, image, step=None): + """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3]. + + Args: + tag: str: label for this data + image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors/ + step: int: training step + """ + image = np.array(image) + if step is None: + step = self._step + else: + self._step = step + if len(np.shape(image)) == 2: + image = image[:, :, np.newaxis] + if np.shape(image)[-1] == 1: + image = np.repeat(image, 3, axis=-1) + image_strio = io.BytesIO() + plt.imsave(image_strio, image, format="png") + image_summary = tf.compat.v1.Summary.Image( + encoded_image_string=image_strio.getvalue(), + colorspace=3, + height=image.shape[0], + width=image.shape[1], + ) + summary = tf.compat.v1.Summary( + value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)] + ) + self.add_summary(summary, step) + + def images(self, tag, images, step=None, rows=None, cols=None): + """Saves (rows, cols) tiled images from np.ndarray. + + If either rows or cols aren't given, they are determined automatically + from the size of the image batch, if neither are given a long column + of images is produced. This truncates the image batch rather than padding + if it doesn't fill the final row. 
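+        For example (illustrative): a batch of 6 images with rows=2 and cols=3
+        is packed into a single 2x3 tiled image.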
+
+        Args:
+          tag: str: label for this data
+          images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
+          step: int: training step
+          rows: int: number of rows in tile
+          cols: int: number of columns in tile
+        """
+        images = np.array(images)
+        if step is None:
+            step = self._step
+        else:
+            self._step = step
+        n_images = np.shape(images)[0]
+        if rows is None and cols is None:
+            rows = 1
+            cols = n_images
+        elif rows is None:
+            rows = n_images // cols
+        elif cols is None:
+            cols = n_images // rows
+        tiled_images = _pack_images(images, rows, cols)
+        self.image(tag, tiled_images, step=step)
+
+    def plot(self, tag, mpl_plt, step=None, close_plot=True):
+        """Saves matplotlib plot output to summary image.
+
+        Args:
+          tag: str: label for this data
+          mpl_plt: matplotlib stateful pyplot object with prepared plotting state
+          step: int: training step
+          close_plot: bool: automatically closes plot
+        """
+        if step is None:
+            step = self._step
+        else:
+            self._step = step
+        fig = mpl_plt.get_current_fig_manager()
+        img_w, img_h = fig.canvas.get_width_height()
+        image_buf = io.BytesIO()
+        mpl_plt.savefig(image_buf, format="png")
+        image_summary = tf.compat.v1.Summary.Image(
+            encoded_image_string=image_buf.getvalue(),
+            colorspace=4,  # RGBA
+            height=img_h,
+            width=img_w,
+        )
+        summary = tf.compat.v1.Summary(
+            value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)]
+        )
+        self.add_summary(summary, step)
+        if close_plot:
+            mpl_plt.close()
+
+    def audio(self, tag, audiodata, step=None, sample_rate=44100):
+        """Saves audio.
+
+        NB: single channel only right now.
+
+        Args:
+          tag: str: label for this data
+          audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave
+          step: int: training step
+          sample_rate: sample rate of passed in audio buffer
+        """
+        audiodata = np.array(audiodata)
+        if step is None:
+            step = self._step
+        else:
+            self._step = step
+        audiodata = np.clip(np.squeeze(audiodata), -1, 1)
+        if audiodata.ndim != 1:
+            raise ValueError("Audio data must be 1D.")
+        # Encode the samples as 16-bit little-endian PCM in a WAV container.
+        sample_list = (32767.0 * audiodata).astype(int).tolist()
+        wio = io.BytesIO()
+        wav_buf = wave.open(wio, "wb")
+        wav_buf.setnchannels(1)
+        wav_buf.setsampwidth(2)
+        wav_buf.setframerate(sample_rate)
+        enc = b"".join([struct.pack("<h", v) for v in sample_list])
+        wav_buf.writeframes(enc)
+        wav_buf.close()
+        encoded_audio_bytes = wio.getvalue()
+        wio.close()
+        audio = tf.compat.v1.Summary.Audio(
+            sample_rate=sample_rate,
+            num_channels=1,
+            length_frames=len(sample_list),
+            encoded_audio_string=encoded_audio_bytes,
+            content_type="audio/wav",
+        )
+        summary = tf.compat.v1.Summary(
+            value=[tf.compat.v1.Summary.Value(tag=tag, audio=audio)]
+        )
+        self.add_summary(summary, step)
+
+    def histogram(self, tag, values, bins, step=None):
+        """Saves histogram of values.
+
+        Args:
+          tag: str: label for this data
+          values: ndarray: will be flattened by this routine
+          bins: number of bins in histogram, or array of bins for np.histogram
+          step: int: training step
+        """
+        if step is None:
+            step = self._step
+        else:
+            self._step = step
+        values = np.array(values)
+        bins = np.array(bins)
+        values = np.reshape(values, -1)
+        counts, limits = np.histogram(values, bins=bins)
+        # Boundary logic: trim empty buckets at both ends.
+        cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
+        start, end = np.searchsorted(
+            cum_counts, [0, cum_counts[-1] - 1], side="right"
+        )
+        start, end = int(start), int(end) + 1
+        counts = (
+            counts[start - 1 : end]
+            if start > 0
+            else np.concatenate([[0], counts[:end]])
+        )
+        limits = limits[start : end + 1]
+        sum_sq = values.dot(values)
+        histo = tf.compat.v1.HistogramProto(
+            min=values.min(),
+            max=values.max(),
+            num=len(values),
+            sum=values.sum(),
+            sum_squares=sum_sq,
+            bucket_limit=limits.tolist(),
+            bucket=counts.tolist(),
+        )
+        summary = tf.compat.v1.Summary(
+            value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)]
+        )
+        self.add_summary(summary, step)
+
+    def text(self, tag, textdata, step=None):
+        """Saves a text summary.
+
+        Args:
+          tag: str: label for this data
+          textdata: string, or 1D/2D list/numpy array of strings
+          step: int: training step
+        Note: markdown formatting is rendered by tensorboard.
+        """
+        if step is None:
+            step = self._step
+        else:
+            self._step = step
+        smd = tf.compat.v1.SummaryMetadata(
+            plugin_data=tf.compat.v1.SummaryMetadata.PluginData(plugin_name="text")
+        )
+        if isinstance(textdata, (str, bytes)):
+            tensor = tf.make_tensor_proto(
+                values=[textdata.encode(encoding="utf_8")], shape=(1,)
+            )
+        else:
+            textdata = np.array(textdata)  # convert lists, jax arrays, etc.
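+            # Only 1D (vector of strings) and 2D (grid of strings) inputs are
+            # handled below; other ranks are unsupported.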
+            datashape = np.shape(textdata)
+            if len(datashape) == 1:
+                tensor = tf.make_tensor_proto(
+                    values=[td.encode(encoding="utf_8") for td in textdata],
+                    shape=(datashape[0],),
+                )
+            elif len(datashape) == 2:
+                tensor = tf.make_tensor_proto(
+                    values=[
+                        td.encode(encoding="utf_8") for td in np.reshape(textdata, -1)
+                    ],
+                    shape=(datashape[0], datashape[1]),
+                )
+        summary = tf.compat.v1.Summary(
+            value=[tf.compat.v1.Summary.Value(tag=tag, metadata=smd, tensor=tensor)]
+        )
+        self.add_summary(summary, step)
+
+
+# Copied from gin/tf/utils.py:GinConfigSaverHook
+def markdownify_operative_config_str(string):
+    """Convert an operative config string to markdown format."""
+
+    # TODO(b/37527917): Total hack below. Implement more principled formatting.
+    def process(line):
+        """Convert a single line to markdown format."""
+        if not line.startswith("#"):
+            return "    " + line
+
+        line = line[2:]
+        if line.startswith("===="):
+            return ""
+        if line.startswith("None"):
+            return "    # None."
+        if line.endswith(":"):
+            return "#### " + line
+        return line
+
+    output_lines = []
+    for line in string.splitlines():
+        procd_line = process(line)
+        if procd_line is not None:
+            output_lines.append(procd_line)
+
+    return "\n".join(output_lines)
diff --git a/trax/utils/predict_drop.py b/trax/utils/predict_drop.py
new file mode 100644
index 000000000..a8596d713
--- /dev/null
+++ b/trax/utils/predict_drop.py
@@ -0,0 +1,354 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Prediction binary for the Drop task.
+
+Binary that loads a checkpoint and runs inference on selected problems
+from the Drop dataset. For more details about Drop see
+https://arxiv.org/pdf/1903.00161.pdf.
+"""
+
+import json
+import os
+import re
+import time
+
+import gin
+import jax
+import numpy as np
+import tensorflow as tf
+
+from absl import app as absl_app
+from absl import flags
+from seqio import vocabularies as t5_spc_vocab
+from t5 import data
+
+from trax import data as trax_data
+from trax import layers as tl
+from trax.learning.supervised import decoding
+from trax.utils import shapes
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("checkpoint_dir", "", "Path to model checkpoint.")
+flags.DEFINE_integer("max_answer_len", 1024, "Maximum length of answers to produce.")
+flags.DEFINE_integer("batch_size", 1, "Batch size for eval.")
+flags.DEFINE_integer("num_examples", 1, "Number of examples to infer.")
+flags.DEFINE_integer("n_hashes", None, "n_hashes parameter to override in attentions.")
+flags.DEFINE_integer("example_repetitions", 1, "How many times to infer an example.")
+flags.DEFINE_bool(
+    "use_eval_mode", False, "If True, use the slower but easier-to-debug eval mode."
+)
+flags.DEFINE_bool("use_eval_set", False, "If True, use eval set for evaluation.")
+flags.DEFINE_bool(
+    "use_beam_search",
+    False,
+    "If True, use beam search, otherwise use autoregressive sampling.",
+)
+flags.DEFINE_float(
+    "autoregressive_sample_temp", 1, "The temperature for autoregressive sampling."
+) +flags.DEFINE_integer("n_beams", 4, "How many beams to use in beam search.") +flags.DEFINE_string( + "output_dir", + "", + "Path to the output directory where articles, abstracts, " + "and predictions would be stored.", +) +flags.DEFINE_integer("starting_example", 0, "Example index for starting decoding.") +flags.DEFINE_integer( + "reload_after", 1000, "Reload checkpoint after reload_after examples." +) +flags.DEFINE_multi_string( + "config_file", None, "Configuration file with parameters (.gin)." +) + + +def _check_exists(file_path): + if not tf.io.gfile.exists(file_path): + print("No such file: %s" % file_path, flush=True) + exit(1) + + +def multiply_examples(example): + for i in range(FLAGS.example_repetitions): + yield i, example + + +def prepare_model(model_file, batch_size=1): + """Prepare the model.""" + mode = "eval" if FLAGS.use_eval_mode else "predict" + print("Initializing the model in %s mode." % mode, flush=True) + + # Read the model name from the gin file + model_reference = gin.query_parameter("trax.supervised.trainer_lib.train.model") + model = model_reference.scoped_configurable_fn(mode=mode) + + dec_len = 32 if FLAGS.use_eval_mode else 1 + batch_size_pd = max(1, batch_size // jax.local_device_count()) + shape11 = shapes.ShapeDtype((batch_size_pd, dec_len), dtype=np.int32) + # shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32) + model.init_from_file( + model_file, weights_only=True, input_signature=(shape11, shape11) + ) + model = tl.Accelerate(model) + + initial_state = model.state + vocab = t5_spc_vocab.SentencePieceVocabulary(data.DEFAULT_SPM_PATH) + + return vocab, model, initial_state + + +def is_number(s): + try: + float(s) + return True + except ValueError: + return False + + +def main(argv): + if len(argv) > 1: + raise absl_app.UsageError("Too many command-line arguments.") + if not FLAGS.output_dir: + raise absl_app.UsageError("--output_dir needs to be provided.") + + # Ensure eager execution in TF2 without raising when already enabled. + if not tf.executing_eagerly(): + tf.compat.v1.enable_eager_execution() + + # Check that checkpoint_dir is correct: should contain model.pkl.gz file. + model_file = os.path.join(FLAGS.checkpoint_dir, "model.pkl.gz") + _check_exists(model_file) + + gin.parse_config_file(os.path.join(FLAGS.checkpoint_dir, "config.gin")) + # Batching on our own because of possible repetitions of examples. + gin.bind_parameter("data.Batch.batch_size", 1) + if FLAGS.n_hashes is not None: + gin.bind_parameter("LSHSelfAttention.n_hashes", FLAGS.n_hashes) + gin.bind_parameter("ref2_encoder/LSHSelfAttention.n_hashes", FLAGS.n_hashes) + + vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size) + + host_id, host_count = jax.host_id(), jax.host_count() + print("Running on host %d out of %d." % (host_id, host_count)) + + example_count = 0 + start_time = time.time() + + # Creates all intermediate directories if they do not exist + tf.io.gfile.makedirs(FLAGS.output_dir) + + json_to_write = os.path.join(FLAGS.output_dir, "output%d.json" % host_id) + all_jsons = [] + + # In a case of a reset we have to check how much work was already done. + # We can check whether the processing of an example was finished, but + # currently we are only checking whether it was started. + done = FLAGS.starting_example + reload_count = 0 + all_existing_files = tf.io.gfile.listdir(FLAGS.output_dir) + for filename in all_existing_files: + if "processing" in filename: + # The definition of digits looks for a number after the infix "processing" + # in the file name. 
Example: tom_processing_532 will lead to
+            # digits = "processing_532" and number equal to "532".
+            digits = filename[filename.find("processing") :]
+            number = "".join(d for d in digits if d.isdigit())
+            if (
+                is_number(number)
+                and int(number) < FLAGS.num_examples + FLAGS.starting_example
+            ):
+                done = max(done, int(number))
+    print("The done number is {}".format(done))
+
+    if FLAGS.use_eval_set:
+        drop_gen = trax_data.CreateDropInputs(train=False)()
+    else:
+        drop_gen = trax_data.CreateDropInputs(train=True)()
+    padding_fun = trax_data.PadToLength()
+
+    # TODO(henrykm): improve management of the counters.
+    # example_count_total - all numeric examples
+    # example_count - all numeric examples above starting_example
+    # reload_count - if we processed FLAGS.reload_after examples,
+    #     then the checkpoint should be reloaded.
+    # idx - total number of examples
+    example_count_total = 0
+    reload_count += 1
+    for idx, e in enumerate(drop_gen):
+        if reload_count >= FLAGS.reload_after:
+            vocab, model, initial_state = prepare_model(model_file, FLAGS.batch_size)
+            reload_count = 0
+        if example_count >= FLAGS.num_examples:
+            print("Reached the example_count {} - breaking".format(example_count))
+            break
+        if not is_number(e[1]):
+            continue
+        target_answer = float(e[1])
+
+        # We count numeric starting examples.
+        example_count_total += 1
+        if example_count_total <= FLAGS.starting_example:
+            print(
+                "Skipping example_count_total {} because it is below {}".format(
+                    example_count_total, FLAGS.starting_example
+                )
+            )
+            continue
+
+        if example_count % 10 == 0:
+            elapsed_time = time.time() - start_time
+            start_time = time.time()
+            print(
+                "Starting inference on example %d, %.2fs since last log"
+                % (example_count, elapsed_time),
+                flush=True,
+            )
+
+        example_count += 1
+        if example_count <= done - FLAGS.starting_example + 1:
+            print(
+                "Skipping example_count {} because it is below {}".format(
+                    example_count, done - FLAGS.starting_example
+                )
+            )
+            # We are not reprocessing this example: it was already processed
+            # before the restart, so we only advance example_count.
+            continue
+
+        if example_count % host_count != host_id:
+            continue
+
+        # At this point we are committed to the processing of an example with
+        # index example_count.
+        processing_file = os.path.join(FLAGS.output_dir, "processing_")
+        data_id = str(example_count + FLAGS.starting_example)
+        with tf.io.gfile.GFile(processing_file + data_id, "w") as w:
+            w.write("Processing started.")
+        for repetition_id, example in multiply_examples(e):
+            question = example[0]
+            question_text = question[question.find(":") + 2 :]
+            question_text = question_text.replace("-", " - ")
+            question = "infer full calculation: " + question_text
+
+            list_num = [
+                float(num.replace(",", "").rstrip("."))
+                for num in re.findall(
+                    r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", question
+                )
+            ]
+            for i in range(len(list_num)):
+                question += " n{} = {}".format(i, list_num[i])
+
+            # print('Question {}'.format(question))
+            tokenized_question = next(
+                padding_fun(
+                    trax_data.tokenize(
+                        [
+                            question,
+                        ],
+                        vocab_file=gin.query_parameter("trax.data.Tokenize.vocab_file"),
+                    )
+                )
+            )
+            state = model.state
+            if FLAGS.use_beam_search:
+                answer_beams = decoding.beam_search(
+                    model,
+                    tokenized_question[None, :],
+                    n_beams=FLAGS.n_beams,
+                    max_length=FLAGS.max_answer_len,
+                    accelerate=False,
+                )
+                model.state = state
+            else:
+                answer_beams = []
+                # We recycle the n_beams flag to control the number
+                # of autoregressive samples.
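+                # E.g. --n_beams=4 with --use_beam_search=false draws four
+                # independent samples at the configured temperature.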
+                for i in range(FLAGS.n_beams):
+                    answer = decoding.autoregressive_sample(
+                        model,
+                        tokenized_question[None, :],
+                        temperature=FLAGS.autoregressive_sample_temp,
+                        max_length=FLAGS.max_answer_len,
+                        accelerate=False,
+                    )
+                    model.state = state
+                    answer_beams.append(answer)
+
+            correct_example_index = -1
+
+            for i in range(len(answer_beams)):
+                if FLAGS.use_beam_search:
+                    answer = trax_data.detokenize(
+                        answer_beams[i][0][0],
+                        vocab_file=gin.query_parameter("trax.data.Tokenize.vocab_file"),
+                    )
+                else:
+                    answer = trax_data.detokenize(
+                        answer_beams[i][0],
+                        vocab_file=gin.query_parameter("trax.data.Tokenize.vocab_file"),
+                    )
+                print("Proposed computation {}".format(answer))
+                list_op = answer.split("|")
+                if not list_op[-1]:
+                    list_op = list_op[:-1]
+
+                try:
+                    result = trax_data.tf_inputs.compute_result(list_op, list_num)
+                    if target_answer in result:
+                        correct_example_index = result.index(target_answer)
+                        break
+                # This is a temporary hack with "broad" exceptions - the
+                # computations must fail sometimes, because we evaluate
+                # arbitrary generated sequences; the possible failure modes
+                # are still being catalogued.
+                except Exception as e:  # pylint: disable=broad-except
+                    print(e)
+                    try:
+                        result = trax_data.tf_inputs.compute_result(
+                            list_op[:-1], list_num
+                        )
+                        if target_answer in result:
+                            correct_example_index = result.index(target_answer)
+                            break
+                    except Exception as e:  # pylint: disable=broad-except
+                        print(e)
+                        print("Inferred an incorrect computation.")
+
+            if correct_example_index == -1:
+                continue
+
+            json_record = {
+                "question": question_text,
+                "input": question,
+                "calculation": "|".join(list_op[: correct_example_index + 1]),
+                "target_answer": target_answer,
+            }
+            all_jsons.append(json.dumps(json_record) + "\n")
+            # Outputting the inferred data in JSONL format.
+            data_id = str(example_count + FLAGS.starting_example)
+            with tf.io.gfile.GFile(json_to_write + data_id, "w") as w:
+                w.write(json.dumps(json_record) + "\n")
+            with tf.io.gfile.GFile(processing_file + data_id, "w") as w:
+                w.write("Processing finished.")
+
+    with tf.io.gfile.GFile(json_to_write + "_" + str(FLAGS.starting_example), "w") as w:
+        for record in all_jsons:
+            w.write(record)
+
+
+if __name__ == "__main__":
+    absl_app.run(main)
diff --git a/trax/utils/rl_trainer.py b/trax/utils/rl_trainer.py
new file mode 100644
index 000000000..d9dfb803f
--- /dev/null
+++ b/trax/utils/rl_trainer.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Trainer for RL environments.
+
+For now we only support PPO as the RL algorithm.
+
+Sample invocation:
+
+.. code-block:: bash
+
+    TRAIN_BATCH_SIZE=32
+    python trax/rl_trainer.py \
+      --config_file=trax/reinforcement/configs/ppo_acrobot.gin \
+      --train_batch_size=${TRAIN_BATCH_SIZE} \
+      --output_dir=${HOME}/ppo_acrobot \
+      --alsologtostderr
+"""
+
+import faulthandler
+
+import gin
+import jax
+
+from absl import app, flags, logging
+
+from tests.fastmath.jax.config import config
+from trax import fastmath
+from trax.learning.reinforcement import task as rl_task
+from trax.learning.reinforcement import training as light_trainers
+from trax.tf import numpy as tf_np
+
+FLAGS = flags.FLAGS
+
+
+# Not just 'train' to avoid a conflict with trax.train in GIN files.
+@gin.configurable(denylist=["output_dir"], module="trax")
+def train_rl(
+    output_dir,
+    n_epochs=10000,
+    light_rl=True,
+    light_rl_trainer=light_trainers.PolicyGradient,
+):
+    """Train the RL agent.
+
+    Args:
+      output_dir: Output directory.
+      n_epochs: Number of epochs to run the training for.
+      light_rl: deprecated, always True, kept for old gin configs.
+      light_rl_trainer: which light RL trainer to use (experimental).
+    """
+    del light_rl
+    tf_np.set_allow_float64(FLAGS.tf_allow_float64)
+    task = rl_task.RLTask()
+    env_name = task.env_name
+
+    if FLAGS.jax_debug_nans:
+        config.update("jax_debug_nans", True)
+
+    if FLAGS.use_tpu:
+        config.update("jax_platform_name", "tpu")
+    else:
+        config.update("jax_platform_name", "")
+
+    trainer = light_rl_trainer(task=task, output_dir=output_dir)
+
+    def light_training_loop():
+        """Run the trainer for n_epochs and close it afterwards."""
+        try:
+            logging.info("Starting RL training for %d epochs.", n_epochs)
+            trainer.run(n_epochs, n_epochs_is_total_epochs=True)
+            logging.info("Completed RL training for %d epochs.", n_epochs)
+        finally:
+            # Close the trainer whether training finished normally or raised.
+            trainer.close()
+            logging.info("Trainer is now closed.")
+
+    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
+        fastmath.disable_jit()
+        with jax.disable_jit():
+            light_training_loop()
+    else:
+        light_training_loop()
+
+
+def main(argv):
+    del argv
+    logging.info("Starting RL training.")
+
+    gin_configs = FLAGS.config if FLAGS.config is not None else []
+    gin.enter_interactive_mode()
+    gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)
+    gin.exit_interactive_mode()
+
+    logging.info("Gin config:")
+    logging.info(gin_configs)
+
+    train_rl(output_dir=FLAGS.output_dir)
+
+    # TODO(afrozm): This is for debugging.
+    logging.info("Dumping stack traces of all stacks.")
+    faulthandler.dump_traceback(all_threads=True)
+
+    logging.info("Training is done, should exit.")
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/trax/utils/shapes.py b/trax/utils/shapes.py
new file mode 100644
index 000000000..8db9c73d6
--- /dev/null
+++ b/trax/utils/shapes.py
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Core class and functions for handling data abstractly as shapes/dtypes.""" + +import numpy as np +import tensorflow as tf + + +class ShapeDtype: + """A NumPy ndarray-like object abstracted as shape and dtype. + + Main use is for representing input and output signatures. + """ + + __slots__ = ["shape", "dtype"] + + def __init__(self, shape, dtype=np.float32): + """Creates a `ShapeDtype` instance, with canonicalized `shape` and `dtype`. + + Args: + shape: A tuple or list, each element of which is an int or, less often, + `None`. + dtype: A `dtype` object, either from NumPy or TensorFlow. + + Returns: + A `ShapeDtype` instance whose `shape` is a tuple and `dtype` is a NumPy + `dtype` object. + """ + # Canonicalize shape and dtype. + if isinstance(shape, tf.TensorShape): + shape = shape.as_list() + if isinstance(shape, list): + shape = tuple(shape) + if not isinstance(shape, tuple): + raise TypeError("shape must be tuple or list; got: {}".format(shape)) + if isinstance(dtype, tf.DType): + dtype = dtype.as_numpy_dtype + + self.shape = shape + self.dtype = dtype + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.shape == other.shape + and self.dtype == other.dtype + ) + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "ShapeDtype{{shape:{}, dtype:{}}}".format(self.shape, self.dtype) + + def __len__(self): + """Returns length of 1; relevant to input and output signatures.""" + return 1 + + def as_tuple(self): + return self.shape, self.dtype + + def replace(self, **kwargs): + """Creates a copy of the object with some parameters replaced.""" + return type(self)( + shape=kwargs.pop("shape", self.shape), + dtype=kwargs.pop("dtype", self.dtype), + ) + + +def signature(obj): + """Returns a `ShapeDtype` signature for the given `obj`. + + A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype` + instances. Note that this function is permissive with respect to its inputs + (accepts lists or tuples or dicts, and underlying objects can be any type + as long as they have shape and dtype attributes) and returns the corresponding + nested structure of `ShapeDtype`. + + Args: + obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict + of such objects. + + Returns: + A corresponding nested structure of `ShapeDtype` instances. + """ + if isinstance(obj, (list, tuple)): + output = tuple(signature(x) for x in obj) + return output if isinstance(obj, tuple) else list(output) + elif isinstance(obj, dict): + return {k: signature(v) for (k, v) in obj.items()} + else: + return ShapeDtype(obj.shape, obj.dtype) + + +def splice_signatures(*sigs): + """Creates a new signature by splicing together any number of signatures. + + The splicing effectively flattens the top level input signatures. For + instance, it would perform the following mapping: + + - `*sigs: sd1, (sd2, sd3, sd4), (), sd5` + - return: `(sd1, sd2, sd3, sd4, sd5)` + + Args: + *sigs: Any number of signatures. A signature is either a `ShapeDtype` + instance or a tuple of `ShapeDtype` instances. + + Returns: + A single `ShapeDtype` instance if the spliced signature has one element, + else a tuple of `ShapeDtype` instances. 
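+
+    Example (illustrative):
+
+      sd = ShapeDtype((2, 3))
+      splice_signatures(sd, (sd, sd))  # -> (sd, sd, sd)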
+    """
+    result_sigs = []
+    for sig in sigs:
+        if isinstance(sig, (list, tuple)):
+            result_sigs.extend(sig)
+        else:
+            result_sigs.append(sig)
+    return result_sigs[0] if len(result_sigs) == 1 else tuple(result_sigs)
+
+
+def assert_shape_equals(array, shape):
+    """Asserts that an array has the given shape."""
+    assert array.shape == shape, "Invalid shape {}; expected {}.".format(
+        array.shape, shape
+    )
+
+
+def assert_same_shape(array1, array2):
+    """Asserts that two arrays have the same shapes."""
+    assert_shape_equals(array1, array2.shape)
diff --git a/trax/test_utils.py b/trax/utils/test_utils.py
similarity index 80%
rename from trax/test_utils.py
rename to trax/utils/test_utils.py
index cca9f722d..0be848195 100644
--- a/trax/test_utils.py
+++ b/trax/utils/test_utils.py
@@ -19,17 +19,15 @@
 
 from absl import flags
 
-FLAGS = flags.FLAGS
-
 # pytest doesn't run the test as a main, so it doesn't parse the flags
 # so if flags are required in tests, this will ensure that flags are manually
 # parsed and the desired flag exists.
 def ensure_flag(flag_str):
-  try:
-    getattr(FLAGS, flag_str)
-  except flags.UnparsedFlagAccessError:
-    # Manually parse flags.
-    FLAGS(sys.argv)
-  finally:
-    assert getattr(FLAGS, flag_str)
+    try:
+        getattr(flags.FLAGS, flag_str)
+    except flags.UnparsedFlagAccessError:
+        # Manually parse flags.
+        flags.FLAGS(sys.argv)
+    finally:
+        assert hasattr(flags.FLAGS, flag_str)  # existence, not truthiness
diff --git a/trax/utils/trainer.py b/trax/utils/trainer.py
new file mode 100644
index 000000000..dec971767
--- /dev/null
+++ b/trax/utils/trainer.py
@@ -0,0 +1,202 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Trax trainers."""
+import atexit
+import datetime
+import functools
+import os
+
+import gin
+import jax
+import tensorflow.compat.v2 as tf
+
+from absl import app, flags, logging
+from jax.lib import xla_extension as xc
+
+from trax import fastmath
+from trax.learning.supervised import trainer_lib
+from trax.tf import numpy as tf_np
+
+FLAGS = flags.FLAGS
+Backend = fastmath.Backend
+
+
+# TODO(afrozm): Share between trainers.py and rl_trainer.py
+def _tf_setup_from_flags():
+    """Processes TensorFlow-relevant flags."""
+    if FLAGS.enable_eager_execution:
+        # In TF2 eager is default; guard to avoid errors if already eager.
+        if not tf.executing_eagerly():
+            tf.compat.v1.enable_eager_execution()
+    if FLAGS.tf_xla:
+        tf.config.optimizer.set_jit(True)
+        fastmath.tf.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)
+    tf.config.optimizer.set_experimental_options(
+        {
+            "pin_to_host_optimization": FLAGS.tf_opt_pin_to_host,
+            "layout_optimizer": FLAGS.tf_opt_layout,
+        }
+    )
+    tf_np.set_allow_float64(FLAGS.tf_allow_float64)
+
+
+# TODO(afrozm): Share between trainers.py and rl_trainer.py
+def _gin_parse_configs():
+    """Initializes gin-controlled bindings."""
+    configs = FLAGS.config if FLAGS.config is not None else []
+    # Override with --dataset and --model
+    if FLAGS.dataset:
+        configs.append("data_streams.dataset_name='%s'" % FLAGS.dataset)
+    if FLAGS.data_dir:
+        configs.append("data_streams.data_dir='%s'" % FLAGS.data_dir)
+    if FLAGS.model:
+        configs.append("train.model=@trax.models.%s" % FLAGS.model)
+    gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
+
+
+def _output_dir_or_default():
+    """Returns a path to the output directory."""
+    if FLAGS.output_dir:
+        output_dir = FLAGS.output_dir
+        trainer_lib.log("Using --output_dir {}".format(output_dir))
+        return os.path.expanduser(output_dir)
+
+    # Else, generate a default output dir (under the user's home directory).
+    try:
+        dataset_name = gin.query_parameter("data_streams.dataset_name")
+    except ValueError:
+        dataset_name = "random"
+    output_name = "{model_name}_{dataset_name}_{timestamp}".format(
+        model_name=gin.query_parameter("train.model").configurable.name,
+        dataset_name=dataset_name,
+        timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"),
+    )
+    output_dir = os.path.join("~", "trax", output_name)
+    output_dir = os.path.expanduser(output_dir)
+    print()
+    trainer_lib.log("No --output_dir specified")
+    trainer_lib.log("Using default output_dir: {}".format(output_dir))
+    return output_dir
+
+
+# TODO(afrozm): Share between trainers.py and rl_trainer.py
+def _jax_and_tf_configure_for_devices():  # pylint: disable=missing-function-docstring
+    if FLAGS.use_tpu:
+        jax.config.update("jax_platform_name", "tpu")
+        jax.config.update("jax_xla_backend", FLAGS.jax_xla_backend)
+        jax.config.update("jax_backend_target", FLAGS.jax_backend_target)
+    if FLAGS.enable_eager_execution and (
+        fastmath.is_backend(Backend.NUMPY) or fastmath.is_backend(Backend.JAX)
+    ):
+        # Numpy backend doesn't benefit from having the input pipeline run on
+        # GPU, and jax backend has GPU memory contention if TF uses the GPU.
+        # Gin must be set up first before determining the backend.
+        tf.config.experimental.set_visible_devices([], "GPU")
+
+
+def _train_using_tf(output_dir):
+    worker_cpu = tf_init_tpu()
+    with tf.device(worker_cpu):
+        if trainer_lib.num_devices() == 1:
+            # TF's device priority is GPU > CPU > TPU, so we need to explicitly
+            # make the TPU core the default device here.
+            with tf.device("/device:TPU:0"):
+                trainer_lib.train(output_dir=output_dir)
+        else:
+            trainer_lib.train(output_dir=output_dir)
+
+
+@gin.configurable
+def tf_init_tpu(worker="", protocol=None):
+    """Initializes TPU for TensorFlow.
+
+    Args:
+        worker: The BNS address of the remote TPU worker. If it's empty (the
+            default value), TF will assume the TPU devices are connected to the
+            local host.
+        protocol: The network protocol used to connect to the TPU worker.
+
+    Returns:
+        The device name of the TPU worker's CPU.
+    """
+    protocol = protocol or "grpc"
+    is_local = worker in ("", "local")
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=worker)
+    if not is_local:
+        tf.config.experimental_connect_to_cluster(resolver, protocol=protocol)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    if is_local:
+        return ""
+    else:
+        return "/job:worker"
+
+
+def _make_jax_gpu_cluster(host_id, server_ip, n_hosts, server_port=5005):
+    """Makes a JAX GPU cluster."""
+
+    addr = f"{server_ip}:{server_port}"
+    if host_id == 0:
+        logging.info("starting service on %s", addr)
+        service = xc.get_distributed_runtime_service(addr, n_hosts)
+        # We add an explicit call to shut down the service via atexit, as the
+        # Python interpreter may not call the service destructor on process
+        # termination.
+        atexit.register(service.shutdown)
+
+    logging.info("connecting to service on %s", addr)
+    dist_client = xc.get_distributed_runtime_client(addr, host_id)
+    dist_client.connect()
+    atexit.register(dist_client.shutdown)
+
+    # Register the distributed GPU backend factory with JAX.
+    factory = functools.partial(
+        jax.lib.xla_client.make_gpu_client, dist_client, host_id
+    )
+    jax.lib.xla_bridge.register_backend_factory("gpu", factory, priority=300)
+
+
+def main(_):
+    logging.set_verbosity(FLAGS.log_level)
+
+    _tf_setup_from_flags()
+    _gin_parse_configs()
+    _jax_and_tf_configure_for_devices()
+
+    # Create a JAX GPU cluster if using JAX and given a chief IP.
+    if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip:
+        _make_jax_gpu_cluster(
+            FLAGS.gpu_cluster_host_id,
+            FLAGS.gpu_cluster_chief_ip,
+            FLAGS.gpu_cluster_n_hosts,
+            FLAGS.gpu_cluster_port,
+        )
+
+    if FLAGS.disable_jit:
+        fastmath.disable_jit()
+
+    output_dir = _output_dir_or_default()
+    if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
+        _train_using_tf(output_dir)
+    else:
+        trainer_lib.train(output_dir=output_dir)
+
+    trainer_lib.log("Finished training.")
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/trax/utils/trainer_flags.py b/trax/utils/trainer_flags.py
new file mode 100644
index 000000000..1d95bc68a
--- /dev/null
+++ b/trax/utils/trainer_flags.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Flags for trainers.py and rl_trainer.py.
+
+We keep these flags in sync across the trainer and rl_trainer binaries.
+"""
+
+from absl import flags, logging
+
+# Common flags.
+flags.DEFINE_string(
+    "output_dir", None, "Path to the directory to save logs and checkpoints."
+)
+flags.DEFINE_multi_string(
+    "config_file", None, "Configuration file with parameters (.gin)."
+)
+flags.DEFINE_multi_string("config", None, "Configuration parameters (gin string).")
+
+# TPU Flags
+flags.DEFINE_bool("use_tpu", False, "Whether we're running on TPU.")
+flags.DEFINE_string(
+    "jax_xla_backend",
+    "",
+    'Either "xla" for the XLA service directly, or "tpu_driver" '
+    "for a TPU Driver backend.",
+)
+flags.DEFINE_string(
+    "jax_backend_target",
+    "local",
+    'Either "local" or "rpc:address" to connect to a remote service target.',
+)
+
+# trainers.py flags.
+flags.DEFINE_string("dataset", None, "Which dataset to use.")
+flags.DEFINE_string("model", None, "Which model to train.")
+flags.DEFINE_string("data_dir", None, "Path to the directory with data.")
+flags.DEFINE_integer("log_level", logging.INFO, "Log level.")
+
+# JAX/XLA GPU cluster flags.
+flags.DEFINE_string("gpu_cluster_chief_ip", "", "IP of GPU cluster chief.")
+flags.DEFINE_integer("gpu_cluster_n_hosts", 1, "Number of hosts in GPU cluster.")
+flags.DEFINE_integer("gpu_cluster_host_id", 0, "Host id inside GPU cluster.")
+flags.DEFINE_integer("gpu_cluster_port", 5005, "Port to use in GPU cluster.")
+
+# TensorFlow Flags
+flags.DEFINE_bool(
+    "enable_eager_execution", True, "Whether we're running TF in eager mode."
+)
+flags.DEFINE_bool("tf_xla", True, "Whether to turn on XLA for TF.")
+flags.DEFINE_bool(
+    "tf_opt_pin_to_host", False, "Whether to turn on TF pin-to-host optimization."
+)
+flags.DEFINE_bool("tf_opt_layout", False, "Whether to turn on TF layout optimization.")
+flags.DEFINE_bool(
+    "tf_xla_forced_compile",
+    False,
+    "Use forced-compilation instead of auto-clustering for XLA. "
+    "This flag only has an effect when --tf_xla is on.",
+)
+flags.DEFINE_bool("tf_allow_float64", False, "Whether to allow float64 for TF.")
+
+# rl_trainer.py flags.
+flags.DEFINE_boolean(
+    "jax_debug_nans",
+    False,
+    "Setting to true will help to debug NaNs and will disable jit.",
+)
+flags.DEFINE_boolean("disable_jit", False, "Setting to true will disable jit.")
+flags.DEFINE_string("envs_output_dir", "", "Output dir for the envs.")
+flags.DEFINE_bool("xm", False, "Copy atari roms?")
+flags.DEFINE_integer(
+    "train_batch_size", 32, "Number of parallel environments during training."
+)
+flags.DEFINE_integer("eval_batch_size", 4, "Batch size for evaluation.")
+flags.DEFINE_boolean(
+    "parallelize_envs", False, "If true, sets parallelism to the number of CPU cores."
+)
+flags.DEFINE_string("trajectory_dump_dir", "", "Directory to dump trajectories to.")
+flags.DEFINE_bool("async_mode", False, "Async mode.")
diff --git a/trax/utils/trax2keras.py b/trax/utils/trax2keras.py
new file mode 100644
index 000000000..2579a2ff0
--- /dev/null
+++ b/trax/utils/trax2keras.py
@@ -0,0 +1,205 @@
+# coding=utf-8
+# Copyright 2022 The Trax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
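The flag definitions in the new `trainer_flags.py` module above are registered
with absl as a side effect of importing the module, so a binary only needs to
import it before `app.run()` parses the command line. A minimal sketch
(hypothetical script, assuming the `trax.utils.trainer_flags` module path from
this diff):

.. code-block:: python

    from absl import app, flags

    import trax.utils.trainer_flags  # noqa: F401 -- importing registers the flags

    FLAGS = flags.FLAGS


    def main(_):
        # After app.run() parses argv, the flag values are available on FLAGS.
        print("output_dir:", FLAGS.output_dir)
        print("use_tpu:", FLAGS.use_tpu)


    if __name__ == "__main__":
        app.run(main)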
+ +"""Trax-to-Keras converter.""" + +import functools + +import tensorflow.compat.v2 as tf # type: ignore + +from trax import fastmath as math_lib +from trax.fastmath import numpy as jnp +from trax.layers import base +from trax.utils import shapes as shapes_lib + + +def _replace_none_batch(x, batch_size=None): + if batch_size is None: + return x + if isinstance(x, tf.Tensor) and x.shape[0] is None: + x.set_shape([batch_size] + x.shape[1:]) + return x + elif isinstance(x, tf.TensorShape) and x[0] is None: + return [batch_size] + x[1:] + return x + + +def tensor_shapes_to_shape_dtypes(shapes, dtype): + return math_lib.nested_map( + lambda s: shapes_lib.ShapeDtype(s.as_list(), dtype), shapes + ) + + +def read_values(variables): + return math_lib.nested_map(lambda v: v.read_value(), variables) + + +def to_tensors(args): + return math_lib.nested_map(tf.convert_to_tensor, args) + + +def to_arrays(args): + return math_lib.nested_map(jnp.asarray, args) + + +class AsKeras(tf.keras.layers.Layer): + """A Keras layer built from a Trax layer. + + This subclass of `tf.keras.layers.Layer` takes in a Trax layer as a + constructor argument and wraps it to be a Keras layer. It uses + `tf.Variable` to store weights and state (initialized according to the Trax + layer), and uses the Trax layer's forward function as its forward function. + + Consider this code snippet:: + + keras_layer = AsKeras( + trax_layer, initializer_rng=initializer_rng, rng=rng, rng_updater=rng_updater + ) + keras_layer.build(...) # optional + outputs = keras_layer(inputs) + + (Note that in Keras calling `Layer.build` is optional. If omitted, it will be + called automatically by `Layer.__call__`.) + + If `trax_layer` already has weights at `build` time, the snippet is roughly + equivalent to:: + + weights = trax_layer.weights + state = trax_layer.state + keras_layer = tf.keras.layers.Layer() + keras_layer._weights = tf.Variable(weights) + keras_layer._state = tf.Variable(state) + keras_layer._rng = tf.Variable(rng) + outputs, new_state = trax_layer( + inputs, keras_layer._weights, keras_layer._state, keras_layer._rng + ) + keras_layer._state.assign(new_state) + keras_layer._rng.assign(rng_updater(rng)) + + If `trax_layer` doesn't have weights at `build` time, the snippet is roughly + equivalent to:: + + weights, state = trax_layer.init(..., rng=initializer_rng) + keras_layer = ... + ... + + `AsKeras` uses `tf.Variable` to store weights, not shared with the + original Trax layer (which uses tensors to store weights), so using + `AsKeras` may double the memory footprint. This problem can be solved + by making sure that the Trax layer's weights/state are cleared whenever + `tf.Variable.assign` (and `tf.Variable.assign_add` etc.) is called, because + `tf.Variable` is copy-on-write by default. + + Mutations in those `tf.Variable`s won't affect the Trax layer's weights, but + `AsKeras`'s forward function calls the Trax layer's forward function, + which caches the weights in the Trax layer object, so a forward pass may + change the weights cached in the original Trax layer. + + Note that this class is not thread-safe. If the same `AsKeras` object + is used in multiple threads, the `tf.Variable` updates may happen in a + non-deterministic order. + """ + + def __init__( + self, + trax_layer, + batch_size=None, + initializer_rng=None, + rng=None, + rng_updater=None, + dtype=None, + ): + """Creates a Keras layer wrapping around a Trax layer. + + Args: + trax_layer: an object of class `trax.layers.Layer`, the trax layer to + wrap. 
+ batch_size: (optional) an integer, the batch size that this Keras layer + will be used on. Keras sometimes needs to generate a TF graph for a + layer (e.g. for acceleration or checkpointing). The inputs used to trace + the graph will have `None` as the length of their batch dimensions, so + as to generate a graph that can handle any batch size. Some Trax layers + can't handle tensors whose shapes contain `None`. If `batch_size` is set + to an integer, the graph will be traced with `batch_size` as the batch + size instead of `None`. Note that in this case the graph (and the Keras + layer) can only be used on a specific batch size. If you want to use a + different batch size, you need to create another `AsKeras` object + with a different `batch_size`. + initializer_rng: (optional) an RNG key used to create the weights and + state if `trax_layer` doesn't have them. If `None`, + `trax.fastmath.random.get_prng(0)` will be used. + rng: (optional) an RNG key for the forward function (aka the "forward + key"). If `None`, `trax.fastmath.random.get_prng(0)` will be used. + rng_updater: (optional) a function of type rng_key -> rng_key, used to + update the forward key after each forward pass. If `None`, the function + `lambda x: trax.fastmath.random.split(x, 1)[0]` will be used, which + advances the RNG key. + dtype: (optional) the dtype of the inputs. See the `dtype` argument of + `tf.keras.layers.Layer.__init__` for details. + """ + super().__init__(dtype=dtype) + with math_lib.use_backend(math_lib.Backend.TFNP): + if initializer_rng is None: + initializer_rng = math_lib.random.get_prng(0) + if rng is None: + rng = math_lib.random.get_prng(0) + if rng_updater is None: + rng_updater = lambda x: math_lib.random.split(x, 1)[0] + self._trax_layer = trax_layer + self._batch_size = batch_size + self._initializer_rng = initializer_rng + self._forward_rng_init = rng + self._rng_updater = rng_updater + + def build(self, input_shape): + with math_lib.use_backend(math_lib.Backend.TFNP): + # Using `is` instead of `==` following Trax's practice + if self._trax_layer.weights is base.EMPTY_WEIGHTS: + sanitized_input_shape = math_lib.nested_map( + functools.partial(_replace_none_batch, batch_size=self._batch_size), + input_shape, + ) + weights, state = self._trax_layer.init( + tensor_shapes_to_shape_dtypes(sanitized_input_shape, self.dtype), + rng=self._initializer_rng, + ) + else: + weights = self._trax_layer.weights + state = self._trax_layer.state + # Note: `weights` may contain `EMPTY_WEIGHTS` + self._weights = math_lib.nested_map( + functools.partial(tf.Variable, trainable=True), weights + ) + self._state = math_lib.nested_map( + functools.partial(tf.Variable, trainable=False), state + ) + self._rng = tf.Variable(self._forward_rng_init, trainable=False) + super().build(input_shape) + + def call(self, inputs): + with math_lib.use_backend(math_lib.Backend.TFNP): + inputs = math_lib.nested_map( + functools.partial(_replace_none_batch, batch_size=self._batch_size), + inputs, + ) + weights, state, rng = read_values([self._weights, self._state, self._rng]) + inputs, weights, state, rng = to_arrays([inputs, weights, state, rng]) + outputs, new_state = self._trax_layer.pure_fn( + inputs, weights=weights, state=state, rng=rng + ) + tf.nest.map_structure(lambda v, t: v.assign(t), self._state, new_state) + self._rng.assign(self._rng_updater(rng)) + outputs = to_tensors(outputs) + return outputs
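To make the converter concrete, here is a hedged end-to-end sketch (assuming
`trax.layers.Dense` as in upstream Trax and the `AsKeras` API defined above;
this snippet is illustrative, not part of the patch):

.. code-block:: python

    import numpy as np
    import tensorflow.compat.v2 as tf

    from trax import layers as tl
    from trax.utils.trax2keras import AsKeras

    # Wrap a Trax layer; fixing batch_size lets Keras trace a static-shape graph.
    keras_layer = AsKeras(tl.Dense(16), batch_size=8)

    x = tf.convert_to_tensor(np.zeros((8, 32), dtype=np.float32))
    y = keras_layer(x)  # build() runs on the first call and initializes weights
    assert y.shape == (8, 16)

    # The wrapped layer composes with standard Keras APIs:
    model = tf.keras.Sequential([keras_layer])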