From 65ee4c90fb4298dad03ac80fac9a6c6442e9ccf2 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Wed, 23 Mar 2022 16:05:29 -0500 Subject: [PATCH 1/2] deps: move asyncstdlib to dev --- poetry.lock | 88 +++++++++++++++++++++++++------------------------- pyproject.toml | 2 +- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/poetry.lock b/poetry.lock index bbd99de..9f91da5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -28,7 +28,7 @@ python-versions = ">=3.5" name = "asyncstdlib" version = "3.10.3" description = "The missing async toolbox" -category = "main" +category = "dev" optional = false python-versions = "~=3.6" @@ -105,7 +105,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "coverage" -version = "6.3.1" +version = "6.3.2" description = "Code coverage measurement for Python" category = "dev" optional = false @@ -528,7 +528,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "560cc7be4f02dcf0399bdd8f9517f148d1db21382196f6a6c7ca3bb24d749182" +content-hash = "7a7c118a1106e5cda8b73581f4d83c3af49fbbe2d8280c1f24949fa5a45e87be" [metadata.files] anyio = [ @@ -611,47 +611,47 @@ colorama = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] coverage = [ - {file = "coverage-6.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeffd96882d8c06d31b65dddcf51db7c612547babc1c4c5db6a011abe9798525"}, - {file = "coverage-6.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:621f6ea7260ea2ffdaec64fe5cb521669984f567b66f62f81445221d4754df4c"}, - {file = "coverage-6.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84f2436d6742c01136dd940ee158bfc7cf5ced3da7e4c949662b8703b5cd8145"}, - {file = "coverage-6.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de73fca6fb403dd72d4da517cfc49fcf791f74eee697d3219f6be29adf5af6ce"}, - {file = "coverage-6.3.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78fbb2be068a13a5d99dce9e1e7d168db880870f7bc73f876152130575bd6167"}, - {file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f5a4551dfd09c3bd12fca8144d47fe7745275adf3229b7223c2f9e29a975ebda"}, - {file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7bff3a98f63b47464480de1b5bdd80c8fade0ba2832c9381253c9b74c4153c27"}, - {file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a06c358f4aed05fa1099c39decc8022261bb07dfadc127c08cfbd1391b09689e"}, - {file = "coverage-6.3.1-cp310-cp310-win32.whl", hash = "sha256:9fff3ff052922cb99f9e52f63f985d4f7a54f6b94287463bc66b7cdf3eb41217"}, - {file = "coverage-6.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:276b13cc085474e482566c477c25ed66a097b44c6e77132f3304ac0b039f83eb"}, - {file = "coverage-6.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:56c4a409381ddd7bbff134e9756077860d4e8a583d310a6f38a2315b9ce301d0"}, - {file = "coverage-6.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eb494070aa060ceba6e4bbf44c1bc5fa97bfb883a0d9b0c9049415f9e944793"}, - {file = "coverage-6.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e15d424b8153756b7c903bde6d4610be0c3daca3986173c18dd5c1a1625e4cd"}, - {file = "coverage-6.3.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61d47a897c1e91f33f177c21de897267b38fbb45f2cd8e22a710bcef1df09ac1"}, - {file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:25e73d4c81efa8ea3785274a2f7f3bfbbeccb6fcba2a0bdd3be9223371c37554"}, - {file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fac0bcc5b7e8169bffa87f0dcc24435446d329cbc2b5486d155c2e0f3b493ae1"}, - {file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:72128176fea72012063200b7b395ed8a57849282b207321124d7ff14e26988e8"}, - {file = "coverage-6.3.1-cp37-cp37m-win32.whl", hash = "sha256:1bc6d709939ff262fd1432f03f080c5042dc6508b6e0d3d20e61dd045456a1a0"}, - {file = "coverage-6.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:618eeba986cea7f621d8607ee378ecc8c2504b98b3fdc4952b30fe3578304687"}, - {file = "coverage-6.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ed164af5c9078596cfc40b078c3b337911190d3faeac830c3f1274f26b8320"}, - {file = "coverage-6.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:352c68e233409c31048a3725c446a9e48bbff36e39db92774d4f2380d630d8f8"}, - {file = "coverage-6.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:448d7bde7ceb6c69e08474c2ddbc5b4cd13c9e4aa4a717467f716b5fc938a734"}, - {file = "coverage-6.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9fde6b90889522c220dd56a670102ceef24955d994ff7af2cb786b4ba8fe11e4"}, - {file = "coverage-6.3.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e647a0be741edbb529a72644e999acb09f2ad60465f80757da183528941ff975"}, - {file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a5cdc3adb4f8bb8d8f5e64c2e9e282bc12980ef055ec6da59db562ee9bdfefa"}, - {file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2dd70a167843b4b4b2630c0c56f1b586fe965b4f8ac5da05b6690344fd065c6b"}, - {file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9ad0a117b8dc2061ce9461ea4c1b4799e55edceb236522c5b8f958ce9ed8fa9a"}, - {file = "coverage-6.3.1-cp38-cp38-win32.whl", hash = "sha256:e92c7a5f7d62edff50f60a045dc9542bf939758c95b2fcd686175dd10ce0ed10"}, - {file = "coverage-6.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:482fb42eea6164894ff82abbcf33d526362de5d1a7ed25af7ecbdddd28fc124f"}, - {file = "coverage-6.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c5b81fb37db76ebea79aa963b76d96ff854e7662921ce742293463635a87a78d"}, - {file = "coverage-6.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a4f923b9ab265136e57cc14794a15b9dcea07a9c578609cd5dbbfff28a0d15e6"}, - {file = "coverage-6.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56d296cbc8254a7dffdd7bcc2eb70be5a233aae7c01856d2d936f5ac4e8ac1f1"}, - {file = "coverage-6.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1245ab82e8554fa88c4b2ab1e098ae051faac5af829efdcf2ce6b34dccd5567c"}, - {file = "coverage-6.3.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f2b05757c92ad96b33dbf8e8ec8d4ccb9af6ae3c9e9bd141c7cc44d20c6bcba"}, - {file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9e3dd806f34de38d4c01416344e98eab2437ac450b3ae39c62a0ede2f8b5e4ed"}, - {file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d651fde74a4d3122e5562705824507e2f5b2d3d57557f1916c4b27635f8fbe3f"}, - {file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:704f89b87c4f4737da2860695a18c852b78ec7279b24eedacab10b29067d3a38"}, - {file = "coverage-6.3.1-cp39-cp39-win32.whl", hash = "sha256:2aed4761809640f02e44e16b8b32c1a5dee5e80ea30a0ff0912158bde9c501f2"}, - {file = "coverage-6.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:9976fb0a5709988778ac9bc44f3d50fccd989987876dfd7716dee28beed0a9fa"}, - {file = "coverage-6.3.1-pp36.pp37.pp38-none-any.whl", hash = "sha256:463e52616ea687fd323888e86bf25e864a3cc6335a043fad6bbb037dbf49bbe2"}, - {file = "coverage-6.3.1.tar.gz", hash = "sha256:6c3f6158b02ac403868eea390930ae64e9a9a2a5bbfafefbb920d29258d9f2f8"}, + {file = "coverage-6.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9b27d894748475fa858f9597c0ee1d4829f44683f3813633aaf94b19cb5453cf"}, + {file = "coverage-6.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37d1141ad6b2466a7b53a22e08fe76994c2d35a5b6b469590424a9953155afac"}, + {file = "coverage-6.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9987b0354b06d4df0f4d3e0ec1ae76d7ce7cbca9a2f98c25041eb79eec766f1"}, + {file = "coverage-6.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:26e2deacd414fc2f97dd9f7676ee3eaecd299ca751412d89f40bc01557a6b1b4"}, + {file = "coverage-6.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dd8bafa458b5c7d061540f1ee9f18025a68e2d8471b3e858a9dad47c8d41903"}, + {file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:46191097ebc381fbf89bdce207a6c107ac4ec0890d8d20f3360345ff5976155c"}, + {file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6f89d05e028d274ce4fa1a86887b071ae1755082ef94a6740238cd7a8178804f"}, + {file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58303469e9a272b4abdb9e302a780072c0633cdcc0165db7eec0f9e32f901e05"}, + {file = "coverage-6.3.2-cp310-cp310-win32.whl", hash = "sha256:2fea046bfb455510e05be95e879f0e768d45c10c11509e20e06d8fcaa31d9e39"}, + {file = "coverage-6.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:a2a8b8bcc399edb4347a5ca8b9b87e7524c0967b335fbb08a83c8421489ddee1"}, + {file = "coverage-6.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f1555ea6d6da108e1999b2463ea1003fe03f29213e459145e70edbaf3e004aaa"}, + {file = "coverage-6.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5f4e1edcf57ce94e5475fe09e5afa3e3145081318e5fd1a43a6b4539a97e518"}, + {file = "coverage-6.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a15dc0a14008f1da3d1ebd44bdda3e357dbabdf5a0b5034d38fcde0b5c234b7"}, + {file = "coverage-6.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21b7745788866028adeb1e0eca3bf1101109e2dc58456cb49d2d9b99a8c516e6"}, + {file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8ce257cac556cb03be4a248d92ed36904a59a4a5ff55a994e92214cde15c5bad"}, + {file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b0be84e5a6209858a1d3e8d1806c46214e867ce1b0fd32e4ea03f4bd8b2e3359"}, + {file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:acf53bc2cf7282ab9b8ba346746afe703474004d9e566ad164c91a7a59f188a4"}, + {file = "coverage-6.3.2-cp37-cp37m-win32.whl", hash = "sha256:8bdde1177f2311ee552f47ae6e5aa7750c0e3291ca6b75f71f7ffe1f1dab3dca"}, + {file = "coverage-6.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b31651d018b23ec463e95cf10070d0b2c548aa950a03d0b559eaa11c7e5a6fa3"}, + {file = "coverage-6.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07e6db90cd9686c767dcc593dff16c8c09f9814f5e9c51034066cad3373b914d"}, + {file = "coverage-6.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c6dbb42f3ad25760010c45191e9757e7dce981cbfb90e42feef301d71540059"}, + {file = "coverage-6.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c76aeef1b95aff3905fb2ae2d96e319caca5b76fa41d3470b19d4e4a3a313512"}, + {file = "coverage-6.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cf5cfcb1521dc3255d845d9dca3ff204b3229401994ef8d1984b32746bb45ca"}, + {file = "coverage-6.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fbbdc8d55990eac1b0919ca69eb5a988a802b854488c34b8f37f3e2025fa90d"}, + {file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ec6bc7fe73a938933d4178c9b23c4e0568e43e220aef9472c4f6044bfc6dd0f0"}, + {file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9baff2a45ae1f17c8078452e9e5962e518eab705e50a0aa8083733ea7d45f3a6"}, + {file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd9e830e9d8d89b20ab1e5af09b32d33e1a08ef4c4e14411e559556fd788e6b2"}, + {file = "coverage-6.3.2-cp38-cp38-win32.whl", hash = "sha256:f7331dbf301b7289013175087636bbaf5b2405e57259dd2c42fdcc9fcc47325e"}, + {file = "coverage-6.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:68353fe7cdf91f109fc7d474461b46e7f1f14e533e911a2a2cbb8b0fc8613cf1"}, + {file = "coverage-6.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b78e5afb39941572209f71866aa0b206c12f0109835aa0d601e41552f9b3e620"}, + {file = "coverage-6.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e21876082ed887baed0146fe222f861b5815455ada3b33b890f4105d806128d"}, + {file = "coverage-6.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34626a7eee2a3da12af0507780bb51eb52dca0e1751fd1471d0810539cefb536"}, + {file = "coverage-6.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ebf730d2381158ecf3dfd4453fbca0613e16eaa547b4170e2450c9707665ce7"}, + {file = "coverage-6.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd6fe30bd519694b356cbfcaca9bd5c1737cddd20778c6a581ae20dc8c04def2"}, + {file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:96f8a1cb43ca1422f36492bebe63312d396491a9165ed3b9231e778d43a7fca4"}, + {file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dd035edafefee4d573140a76fdc785dc38829fe5a455c4bb12bac8c20cfc3d69"}, + {file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5ca5aeb4344b30d0bec47481536b8ba1181d50dbe783b0e4ad03c95dc1296684"}, + {file = "coverage-6.3.2-cp39-cp39-win32.whl", hash = "sha256:f5fa5803f47e095d7ad8443d28b01d48c0359484fec1b9d8606d0e3282084bc4"}, + {file = "coverage-6.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:9548f10d8be799551eb3a9c74bbf2b4934ddb330e08a73320123c07f95cc2d92"}, + {file = "coverage-6.3.2-pp36.pp37.pp38-none-any.whl", hash = "sha256:18d520c6860515a771708937d2f78f63cc47ab3b80cb78e86573b0a760161faf"}, + {file = "coverage-6.3.2.tar.gz", hash = "sha256:03e2a7826086b91ef345ff18742ee9fc47a6839ccd517061ef8fa1976e652ce9"}, ] distlib = [ {file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"}, diff --git a/pyproject.toml b/pyproject.toml index 655a22e..061ff48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ python = "^3.7" anyio = "^3.3.1" msgpack = "^1.0.2" websockets = "^10.0" -asyncstdlib = "^3.10.1" [tool.poetry.dev-dependencies] msgpack-types = "^0.2.0" @@ -34,6 +33,7 @@ trio = "^0.19.0" pytest-mock = "^3.7.0" tox = "^3.24.5" mock = {version = "^4.0.3", python = "<3.8"} +asyncstdlib = "^3.10.1" [build-system] requires = ["poetry-core>=1.0.0"] From 42ca6dc3e8000624257561cfc329dd163c110e91 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Wed, 23 Mar 2022 16:07:31 -0500 Subject: [PATCH 2/2] refactor streams into separate context --- README.md | 6 +- src/rpcx/client.py | 150 +++++++++++++++++++++------------------ tests/test_client.py | 58 ++++++++------- tests/test_functional.py | 39 +++++----- 4 files changed, 139 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index 5cff08b..59f7181 100644 --- a/README.md +++ b/README.md @@ -63,16 +63,16 @@ async def main() -> None: assert await client.request("add", 1, 2) == 3 # Streaming (server to client) example - async with client.request_stream("fibonacci", 6) as stream: + async with client.request_stream("fibonacci", 6) as request, request.stream as stream: async for num in stream: print(num) # 1, 1, 2, 3, 5, 8 # Streaming (client to server) example - async with client.request_stream("sum") as stream: + async with client.request_stream("sum") as request, request.stream as stream: for num in range(10): await stream.send(num) - assert await stream == 45 + assert await request == 45 anyio.run(main) diff --git a/src/rpcx/client.py b/src/rpcx/client.py index 87ba523..87b8089 100644 --- a/src/rpcx/client.py +++ b/src/rpcx/client.py @@ -1,19 +1,11 @@ import logging import math -import sys from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Generator, Optional - -if sys.version_info >= (3, 9): # pragma: nocover - from collections.abc import Awaitable -else: # pragma: nocover - from typing import Awaitable +from dataclasses import dataclass, field +from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Generator, Optional, Union import anyio -import asyncstdlib as astd from anyio.abc import AnyByteStream -from anyio.streams.memory import MemoryObjectReceiveStream from .message import ( Message, @@ -55,54 +47,81 @@ class InternalError(ClientError): """ +@dataclass class _RequestTask: """ Container for streams associated with a request. """ - def __init__(self) -> None: + _value: Any = object() + _lock: anyio.Lock = field(default_factory=anyio.Lock) + stream: Union["RequestStream", None] = None + + def __post_init__(self) -> None: # There will only ever be one response, buffer size of 1 is appropriate self.response_producer, self.response_consumer = anyio.create_memory_object_stream(1, item_type=Response) # There may be many stream chunks, do not block on receiving them. self.stream_producer, self.stream_consumer = anyio.create_memory_object_stream(math.inf) - @astd.lru_cache(maxsize=1) # type: ignore[misc] + def __await__(self) -> Generator[None, None, Any]: + return self.get_response().__await__() + async def get_response(self) -> Any: - with self.response_producer, self.response_consumer: - response = await self.response_consumer.receive() + async with self._lock: + if self._value is not _RequestTask._value: + return self._value - if response.status_is_ok: - return response.value - elif response.status_is_invalid: - raise InvalidValue(response.value) - elif response.status_is_internal: - raise InternalError(response.value) - else: - raise RemoteError(response.value) + with self.response_producer, self.response_consumer: + response = await self.response_consumer.receive() + self._value = response.value + + if response.status_is_ok: + return self._value + elif response.status_is_invalid: + raise InvalidValue(response.value) + elif response.status_is_internal: + raise InternalError(response.value) + else: + raise RemoteError(response.value) @dataclass(repr=False) -class RequestStream(Awaitable[Any], AsyncIterator["RequestStream"]): +class RequestStream(AsyncIterator["RequestStream"]): """ Yielded from `RPCClient.request_stream()` """ - _get_response: Callable[[], Coroutine[None, None, None]] - _stream_consumer: MemoryObjectReceiveStream[Any] - send: Callable[[Any], Coroutine[None, None, None]] + task: _RequestTask + req_id: int + _send_msg: Callable[[Union[RequestStreamChunk, RequestStreamEnd]], Coroutine[None, None, None]] + _did_send_chunk: bool = False + _did_enter_ctx: bool = False + + async def send(self, chunk: Any) -> None: + self._did_send_chunk = True + await self._send_msg(RequestStreamChunk(self.req_id, chunk)) + + async def __aenter__(self) -> "RequestStream": + self._did_enter_ctx = True + return self + + async def __aexit__(self, *args: Any) -> None: + if self._did_send_chunk: + # Sent a stream chunk, must send the stream end + with anyio.CancelScope(shield=True): + await self._send_msg(RequestStreamEnd(self.req_id)) def __aiter__(self) -> "RequestStream": + if not self._did_enter_ctx: + raise RuntimeError("Stream must have entered context to iterate") return self async def __anext__(self) -> Any: try: - return await self._stream_consumer.receive() + return await self.task.stream_consumer.receive() except anyio.EndOfStream: raise StopAsyncIteration() - def __await__(self) -> Generator[None, None, Any]: - return self._get_response().__await__() - class RPCClient: def __init__(self, stream: AnyByteStream, raise_on_error: bool = False) -> None: @@ -144,52 +163,45 @@ async def __aenter__(self) -> "RPCClient": async def __aexit__(self, *args: Any) -> Optional[bool]: return await self._ctx.__aexit__(*args) + @asynccontextmanager + async def _request_context(self, request: Request) -> AsyncIterator[_RequestTask]: + """ + Send a request, create a task, and send cancellation if cancelled. + """ + await self.send_msg(request) + task = self.tasks[request.id] = _RequestTask() + + async def send_cancel() -> None: + with anyio.CancelScope(shield=True): + await self.send_msg(RequestCancel(request.id)) + + with task.stream_producer, task.stream_consumer: + try: + yield task + except anyio.get_cancelled_exc_class(): + await send_cancel() + raise + except anyio.ExceptionGroup as exc: + if any(isinstance(e, anyio.get_cancelled_exc_class()) for e in exc.exceptions): + await send_cancel() + raise + finally: + del self.tasks[request.id] + async def request(self, method: str, *args: Any, **kwargs: Any) -> Any: req = Request(id=self.next_msg_id, method=method, args=args, kwargs=kwargs) - await self.send_msg(req) - task = self.tasks[req.id] = _RequestTask() - - try: + async with self._request_context(req) as task: return await task.get_response() - except anyio.get_cancelled_exc_class(): - with anyio.CancelScope(shield=True): - await self.send_msg(RequestCancel(req.id)) - raise - finally: - del self.tasks[req.id] @asynccontextmanager - async def request_stream(self, method: str, *args: Any, **kwargs: Any) -> AsyncIterator[RequestStream]: + async def request_stream(self, method: str, *args: Any, **kwargs: Any) -> AsyncIterator[_RequestTask]: req = Request(id=self.next_msg_id, method=method, args=args, kwargs=kwargs) - await self.send_msg(req) - task = self.tasks[req.id] = _RequestTask() - did_send_chunk = False - - async def send_stream_chunk(value: Any) -> None: - nonlocal did_send_chunk - did_send_chunk = True - await self.send_msg(RequestStreamChunk(req.id, value)) - - stream = RequestStream( - _get_response=task.get_response, - _stream_consumer=task.stream_consumer, - send=send_stream_chunk, - ) - - try: - with task.stream_producer, task.stream_consumer: - yield stream - if did_send_chunk: - # Sent a stream chunk, must send the stream end - await self.send_msg(RequestStreamEnd(req.id)) - await task.get_response() - except anyio.get_cancelled_exc_class(): - with anyio.CancelScope(shield=True): - await self.send_msg(RequestCancel(req.id)) - raise - finally: - del self.tasks[req.id] + async with self._request_context(req) as task: + task.stream = RequestStream(task=task, req_id=req.id, _send_msg=self.send_msg) + async with anyio.create_task_group() as task_group: + task_group.start_soon(task.get_response) + yield task async def receive_loop(self) -> None: """ diff --git a/tests/test_client.py b/tests/test_client.py index c533f83..c2f01ce 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -30,6 +30,14 @@ async def test_response(test_client): ) +async def test_stream_no_enter_context(test_client): + async with test_client.client as client: + with pytest.raises(RuntimeError, match="Stream must have entered context to iterate"): + async with client.request_stream("any") as request: + async for item in request.stream: + pass + + async def test_cancel(test_client): async with test_client.client as client: async with anyio.create_task_group() as task_group: @@ -55,15 +63,16 @@ async def test_stream_response(test_client): await test_client.server_stream.send(message_to_bytes(Response(id=1, status=ResponseStatus.OK, value="1"))) async with test_client.client as client: - async with client.request_stream("any") as stream: + async with client.request_stream("any") as request: assert message_from_bytes(await test_client.server_stream.receive()) == Request( id=1, method="any", args=(), kwargs={} ) items = [] - async for item in stream: - items.append(item) + async with request.stream as stream: + async for item in stream: + items.append(item) assert items == ["a", "b", "c"] - assert await stream == "1" + assert await request == "1" async def test_stream_receive_data_after_task_ends(test_client, caplog): @@ -76,12 +85,12 @@ async def test_stream_receive_data_after_task_ends(test_client, caplog): async with test_client.client as client: client.raise_on_error = False - async with client.request_stream("any") as stream: + async with client.request_stream("any") as request, request.stream as stream: items = [] async for item in stream: items.append(item) assert items == ["a", "b", "c"] - assert await stream == "1" + assert await request == "1" assert "Client receive error: ResponseStreamChunk(id=1, value='d')" in caplog.text @@ -89,7 +98,7 @@ async def test_stream_cancel(test_client): async with test_client.client as client: items = [] async with anyio.create_task_group() as task_group: - async with client.request_stream("any") as stream: + async with client.request_stream("any") as request, request.stream as stream: async def cancel_soon(): await anyio.wait_all_tasks_blocked() @@ -103,32 +112,33 @@ async def cancel_soon(): async for item in stream: items.append(item) - assert items == ["a"] - assert message_from_bytes(await test_client.server_stream.receive()) == Request( - id=1, method="any", args=(), kwargs={} - ) - assert message_from_bytes(await test_client.server_stream.receive()) == RequestCancel(id=1) + assert items == ["a"] + assert message_from_bytes(await test_client.server_stream.receive()) == Request( + id=1, method="any", args=(), kwargs={} + ) + assert message_from_bytes(await test_client.server_stream.receive()) == RequestCancel(id=1) async def test_send_stream(test_client): async with test_client.client as client: - async with client.request_stream("any") as stream: - await stream.send("a") - await stream.send("b") - await stream.send("c") + async with client.request_stream("any") as request: + async with request.stream as stream: + await stream.send("a") + await stream.send("b") + await stream.send("c") await test_client.server_stream.send(message_to_bytes(Response(id=1, status=ResponseStatus.OK, value="1"))) - assert await stream == "1" + assert await request == "1" - assert message_from_bytes(await test_client.server_stream.receive()) == Request( - id=1, method="any", args=(), kwargs={} - ) - assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamChunk(id=1, value="a") - assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamChunk(id=1, value="b") - assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamChunk(id=1, value="c") - assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamEnd(id=1) + assert message_from_bytes(await test_client.server_stream.receive()) == Request( + id=1, method="any", args=(), kwargs={} + ) + assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamChunk(id=1, value="a") + assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamChunk(id=1, value="b") + assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamChunk(id=1, value="c") + assert message_from_bytes(await test_client.server_stream.receive()) == RequestStreamEnd(id=1) async def test_receive_unhandled_message(test_client, caplog): diff --git a/tests/test_functional.py b/tests/test_functional.py index 8dbb376..0fc0a01 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -51,9 +51,10 @@ async def server_stream(stream: Stream): manager.register("server_stream", server_stream) async with test_stack(manager) as stack: - async with stack.client.request_stream("server_stream") as stream: - async for i, response in astd.enumerate(stream): - assert response == i + async with stack.client.request_stream("server_stream") as request: + async with request.stream as stream: + async for i, response in astd.enumerate(stream): + assert response == i # There is a task in this context assert stack.client.tasks @@ -62,7 +63,7 @@ async def server_stream(stream: Stream): assert not stack.client.tasks # Then we can still await a result from the stream - assert await stream == "vincent loves strings" + assert await request == "vincent loves strings" # And the server task has finished assert not stack.server.dispatcher.tasks @@ -79,11 +80,12 @@ async def client_stream(stream: Stream): manager.register("client_stream", client_stream) async with test_stack(manager) as stack: - async with stack.client.request_stream("client_stream") as stream: - for i in range(10): - await stream.send(i) + async with stack.client.request_stream("client_stream") as request: + async with request.stream as stream: + for i in range(10): + await stream.send(i) - assert await stream == sum(range(10)) + assert await request == sum(range(10)) async def test_bidirectional_stream(test_stack): @@ -97,15 +99,16 @@ async def bidirectional_stream(stream: Stream): manager.register("bidirectional_stream", bidirectional_stream) async with test_stack(manager) as stack: - async with stack.client.request_stream("bidirectional_stream") as stream: + async with stack.client.request_stream("bidirectional_stream") as request: i = 0 - await stream.send(i) - async for response in stream: - i += 1 - assert response == i + async with request.stream as stream: await stream.send(i) - if i >= 10: - break + async for response in stream: + i += 1 + assert response == i + await stream.send(i) + if i >= 10: + break async def test_invalid_name(test_stack): @@ -179,9 +182,9 @@ async def cancel_soon(): tg.cancel_scope.cancel() async with anyio.create_task_group() as tg: - tg.start_soon(cancel_soon) - async with stack.client.request_stream("simple_cancel"): - await anyio.sleep_forever() + async with stack.client.request_stream("simple_cancel") as request: + tg.start_soon(cancel_soon) + await request with anyio.fail_after(1): await cancelled_event.wait()