diff --git a/poetry.lock b/poetry.lock index 466c12fb5..d545ce1b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2633,8 +2633,6 @@ files = [ {file = "lxml-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35bc626eec405f745199200ccb5c6b36f202675d204aa29bb52e27ba2b71dea8"}, {file = "lxml-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:246b40f8a4aec341cbbf52617cad8ab7c888d944bfe12a6abd2b1f6cfb6f6082"}, {file = "lxml-6.0.0-cp310-cp310-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:2793a627e95d119e9f1e19720730472f5543a6d84c50ea33313ce328d870f2dd"}, - {file = "lxml-6.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:46b9ed911f36bfeb6338e0b482e7fe7c27d362c52fde29f221fddbc9ee2227e7"}, - {file = "lxml-6.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b4790b558bee331a933e08883c423f65bbcd07e278f91b2272489e31ab1e2b4"}, {file = "lxml-6.0.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2030956cf4886b10be9a0285c6802e078ec2391e1dd7ff3eb509c2c95a69b76"}, {file = "lxml-6.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d23854ecf381ab1facc8f353dcd9adeddef3652268ee75297c1164c987c11dc"}, {file = "lxml-6.0.0-cp310-cp310-manylinux_2_31_armv7l.whl", hash = "sha256:43fe5af2d590bf4691531b1d9a2495d7aab2090547eaacd224a3afec95706d76"}, @@ -2647,8 +2645,6 @@ files = [ {file = "lxml-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4ee56288d0df919e4aac43b539dd0e34bb55d6a12a6562038e8d6f3ed07f9e36"}, {file = "lxml-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b8dd6dd0e9c1992613ccda2bcb74fc9d49159dbe0f0ca4753f37527749885c25"}, {file = "lxml-6.0.0-cp311-cp311-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:d7ae472f74afcc47320238b5dbfd363aba111a525943c8a34a1b657c6be934c3"}, - {file = 
"lxml-6.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5592401cdf3dc682194727c1ddaa8aa0f3ddc57ca64fd03226a430b955eab6f6"}, - {file = "lxml-6.0.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:58ffd35bd5425c3c3b9692d078bf7ab851441434531a7e517c4984d5634cd65b"}, {file = "lxml-6.0.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f720a14aa102a38907c6d5030e3d66b3b680c3e6f6bc95473931ea3c00c59967"}, {file = "lxml-6.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2a5e8d207311a0170aca0eb6b160af91adc29ec121832e4ac151a57743a1e1e"}, {file = "lxml-6.0.0-cp311-cp311-manylinux_2_31_armv7l.whl", hash = "sha256:2dd1cc3ea7e60bfb31ff32cafe07e24839df573a5e7c2d33304082a5019bcd58"}, @@ -2661,15 +2657,11 @@ files = [ {file = "lxml-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:78718d8454a6e928470d511bf8ac93f469283a45c354995f7d19e77292f26108"}, {file = "lxml-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:84ef591495ffd3f9dcabffd6391db7bb70d7230b5c35ef5148354a134f56f2be"}, {file = "lxml-6.0.0-cp312-cp312-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:2930aa001a3776c3e2601cb8e0a15d21b8270528d89cc308be4843ade546b9ab"}, - {file = "lxml-6.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:219e0431ea8006e15005767f0351e3f7f9143e793e58519dc97fe9e07fae5563"}, - {file = "lxml-6.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bd5913b4972681ffc9718bc2d4c53cde39ef81415e1671ff93e9aa30b46595e7"}, {file = "lxml-6.0.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:390240baeb9f415a82eefc2e13285016f9c8b5ad71ec80574ae8fa9605093cd7"}, - {file = "lxml-6.0.0-cp312-cp312-manylinux_2_27_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d6e200909a119626744dd81bae409fc44134389e03fbf1d68ed2a55a2fb10991"}, {file = 
"lxml-6.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ca50bd612438258a91b5b3788c6621c1f05c8c478e7951899f492be42defc0da"}, {file = "lxml-6.0.0-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:c24b8efd9c0f62bad0439283c2c795ef916c5a6b75f03c17799775c7ae3c0c9e"}, {file = "lxml-6.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:afd27d8629ae94c5d863e32ab0e1d5590371d296b87dae0a751fb22bf3685741"}, {file = "lxml-6.0.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:54c4855eabd9fc29707d30141be99e5cd1102e7d2258d2892314cf4c110726c3"}, - {file = "lxml-6.0.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c907516d49f77f6cd8ead1322198bdfd902003c3c330c77a1c5f3cc32a0e4d16"}, {file = "lxml-6.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36531f81c8214e293097cd2b7873f178997dae33d3667caaae8bdfb9666b76c0"}, {file = "lxml-6.0.0-cp312-cp312-win32.whl", hash = "sha256:690b20e3388a7ec98e899fd54c924e50ba6693874aa65ef9cb53de7f7de9d64a"}, {file = "lxml-6.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:310b719b695b3dd442cdfbbe64936b2f2e231bb91d998e99e6f0daf991a3eba3"}, @@ -2677,22 +2669,17 @@ files = [ {file = "lxml-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6da7cd4f405fd7db56e51e96bff0865b9853ae70df0e6720624049da76bde2da"}, {file = "lxml-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b34339898bb556a2351a1830f88f751679f343eabf9cf05841c95b165152c9e7"}, {file = "lxml-6.0.0-cp313-cp313-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:51a5e4c61a4541bd1cd3ba74766d0c9b6c12d6a1a4964ef60026832aac8e79b3"}, - {file = "lxml-6.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d18a25b19ca7307045581b18b3ec9ead2b1db5ccd8719c291f0cd0a5cec6cb81"}, - {file = "lxml-6.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d4f0c66df4386b75d2ab1e20a489f30dc7fd9a06a896d64980541506086be1f1"}, {file = 
"lxml-6.0.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f4b481b6cc3a897adb4279216695150bbe7a44c03daba3c894f49d2037e0a24"}, - {file = "lxml-6.0.0-cp313-cp313-manylinux_2_27_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8a78d6c9168f5bcb20971bf3329c2b83078611fbe1f807baadc64afc70523b3a"}, {file = "lxml-6.0.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ae06fbab4f1bb7db4f7c8ca9897dc8db4447d1a2b9bee78474ad403437bcc29"}, {file = "lxml-6.0.0-cp313-cp313-manylinux_2_31_armv7l.whl", hash = "sha256:1fa377b827ca2023244a06554c6e7dc6828a10aaf74ca41965c5d8a4925aebb4"}, {file = "lxml-6.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1676b56d48048a62ef77a250428d1f31f610763636e0784ba67a9740823988ca"}, {file = "lxml-6.0.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:0e32698462aacc5c1cf6bdfebc9c781821b7e74c79f13e5ffc8bfe27c42b1abf"}, - {file = "lxml-6.0.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4d6036c3a296707357efb375cfc24bb64cd955b9ec731abf11ebb1e40063949f"}, {file = "lxml-6.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7488a43033c958637b1a08cddc9188eb06d3ad36582cebc7d4815980b47e27ef"}, {file = "lxml-6.0.0-cp313-cp313-win32.whl", hash = "sha256:5fcd7d3b1d8ecb91445bd71b9c88bdbeae528fefee4f379895becfc72298d181"}, {file = "lxml-6.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:2f34687222b78fff795feeb799a7d44eca2477c3d9d3a46ce17d51a4f383e32e"}, {file = "lxml-6.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:21db1ec5525780fd07251636eb5f7acb84003e9382c72c18c542a87c416ade03"}, {file = "lxml-6.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4eb114a0754fd00075c12648d991ec7a4357f9cb873042cc9a77bf3a7e30c9db"}, {file = "lxml-6.0.0-cp38-cp38-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:7da298e1659e45d151b4028ad5c7974917e108afb48731f4ed785d02b6818994"}, - {file = 
"lxml-6.0.0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7bf61bc4345c1895221357af8f3e89f8c103d93156ef326532d35c707e2fb19d"}, {file = "lxml-6.0.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63b634facdfbad421d4b61c90735688465d4ab3a8853ac22c76ccac2baf98d97"}, {file = "lxml-6.0.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e380e85b93f148ad28ac15f8117e2fd8e5437aa7732d65e260134f83ce67911b"}, {file = "lxml-6.0.0-cp38-cp38-win32.whl", hash = "sha256:185efc2fed89cdd97552585c624d3c908f0464090f4b91f7d92f8ed2f3b18f54"}, @@ -2700,8 +2687,6 @@ files = [ {file = "lxml-6.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:85b14a4689d5cff426c12eefe750738648706ea2753b20c2f973b2a000d3d261"}, {file = "lxml-6.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f64ccf593916e93b8d36ed55401bb7fe9c7d5de3180ce2e10b08f82a8f397316"}, {file = "lxml-6.0.0-cp39-cp39-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:b372d10d17a701b0945f67be58fae4664fd056b85e0ff0fbc1e6c951cdbc0512"}, - {file = "lxml-6.0.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a674c0948789e9136d69065cc28009c1b1874c6ea340253db58be7622ce6398f"}, - {file = "lxml-6.0.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:edf6e4c8fe14dfe316939711e3ece3f9a20760aabf686051b537a7562f4da91a"}, {file = "lxml-6.0.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:048a930eb4572829604982e39a0c7289ab5dc8abc7fc9f5aabd6fbc08c154e93"}, {file = "lxml-6.0.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0b5fa5eda84057a4f1bbb4bb77a8c28ff20ae7ce211588d698ae453e13c6281"}, {file = "lxml-6.0.0-cp39-cp39-manylinux_2_31_armv7l.whl", hash = "sha256:c352fc8f36f7e9727db17adbf93f82499457b3d7e5511368569b4c5bd155a922"}, @@ -2712,14 +2697,10 @@ files = [ {file = "lxml-6.0.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:e0b1520ef900e9ef62e392dd3d7ae4f5fa224d1dd62897a792cf353eb20b6cae"}, {file = "lxml-6.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:e35e8aaaf3981489f42884b59726693de32dabfc438ac10ef4eb3409961fd402"}, {file = "lxml-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:dbdd7679a6f4f08152818043dbb39491d1af3332128b3752c3ec5cebc0011a72"}, - {file = "lxml-6.0.0-pp310-pypy310_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40442e2a4456e9910875ac12951476d36c0870dcb38a68719f8c4686609897c4"}, - {file = "lxml-6.0.0-pp310-pypy310_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:db0efd6bae1c4730b9c863fc4f5f3c0fa3e8f05cae2c44ae141cb9dfc7d091dc"}, {file = "lxml-6.0.0-pp310-pypy310_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ab542c91f5a47aaa58abdd8ea84b498e8e49fe4b883d67800017757a3eb78e8"}, {file = "lxml-6.0.0-pp310-pypy310_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:013090383863b72c62a702d07678b658fa2567aa58d373d963cca245b017e065"}, {file = "lxml-6.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c86df1c9af35d903d2b52d22ea3e66db8058d21dc0f59842ca5deb0595921141"}, {file = "lxml-6.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4337e4aec93b7c011f7ee2e357b0d30562edd1955620fdd4aeab6aacd90d43c5"}, - {file = "lxml-6.0.0-pp39-pypy39_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ae74f7c762270196d2dda56f8dd7309411f08a4084ff2dfcc0b095a218df2e06"}, - {file = "lxml-6.0.0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:059c4cbf3973a621b62ea3132934ae737da2c132a788e6cfb9b08d63a0ef73f9"}, {file = "lxml-6.0.0-pp39-pypy39_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:17f090a9bc0ce8da51a5632092f98a7e7f84bca26f33d161a98b57f7fb0004ca"}, {file = "lxml-6.0.0-pp39-pypy39_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:9da022c14baeec36edfcc8daf0e281e2f55b950249a455776f0d1adeeada4734"}, {file = "lxml-6.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a55da151d0b0c6ab176b4e761670ac0e2667817a1e0dadd04a01d0561a219349"}, @@ -5810,23 +5791,22 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "starlette" -version = "0.47.2" +version = "0.27.0" description = "The little ASGI library that shines." -optional = true -python-versions = ">=3.9" -groups = ["main"] -markers = "python_version >= \"3.10\"" +optional = false +python-versions = ">=3.7" +groups = ["main", "test"] files = [ - {file = "starlette-0.47.2-py3-none-any.whl", hash = "sha256:c5847e96134e5c5371ee9fac6fdf1a67336d5815e09eb2a01fdb57a351ef915b"}, - {file = "starlette-0.47.2.tar.gz", hash = "sha256:6ae9aa5db235e4846decc1e7b79c4f346adf41e9777aebeb49dfd09bbd7023d8"}, + {file = "starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91"}, + {file = "starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75"}, ] [package.dependencies] -anyio = ">=3.6.2,<5" -typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] -full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] [[package]] name = "sympy" @@ -7167,12 +7147,13 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["crewai", "langchain-core", "openai"] +all = ["crewai", "langchain-core", "openai", "starlette"] crewai = ["crewai"] langchain = ["langchain-core"] openai = ["openai", "openai-agents", "packaging"] +starlette = ["starlette"] [metadata] lock-version = "2.1" 
python-versions = "^3.9,<3.14" -content-hash = "d1393ee2f99acbe114b94c1dc15d2d0d56a1266970de16c3d29ed6fbb1bb8f77" +content-hash = "b9079f98d4ff64391c365588c34b6cec022d75970ae71ac4008d45f143b8e02b" diff --git a/pyproject.toml b/pyproject.toml index b36e65c3f..5d7c74885 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ openai-agents = { version = "<0.2.1", optional = true } galileo-core = "~=3.67.3" backoff = "^2.2.1" crewai = { version = ">=0.152.0,<=0.201.1", optional = true, python = ">=3.10,<3.14" } +starlette = { version = "^0.27.0", optional = true } [tool.poetry.group.test.dependencies] pytest = "^8.4.0" @@ -34,6 +35,7 @@ galileo-core = { extras = ["testing"], version = "~=3.67.3" } pytest-env = "^1.1.5" langchain-core = "^0.3.68" pytest-sugar = "^1.0.0" +starlette = "^0.27.0" vcrpy = "^7.0.0" time-machine = "^2.17.0" # freezegun causes problems with pydantic model validations @@ -315,7 +317,8 @@ formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"] langchain = ["langchain-core"] openai = ["openai", "packaging (>=24.2,<25.0)", "openai-agents"] crewai = ["crewai (>=0.152.0,<=0.201.1)"] -all = ["langchain-core", "openai", "crewai"] +starlette = ["starlette"] +all = ["langchain-core", "openai", "crewai", "starlette"] [build-system] diff --git a/src/galileo/decorator.py b/src/galileo/decorator.py index eb58af226..00fa88be4 100644 --- a/src/galileo/decorator.py +++ b/src/galileo/decorator.py @@ -60,6 +60,7 @@ def call_llm(prompt, temperature=0.7): from galileo.schema.metrics import LocalMetricConfig from galileo.schema.trace import SPAN_TYPE from galileo.utils import _get_timestamp +from galileo.utils.distributed_tracing import extract_tracing_headers from galileo.utils.logging import is_concludable_span_type, is_textual_span_type from galileo.utils.serialization import EventSerializer, serialize_to_str from galileo.utils.singleton import GalileoLoggerSingleton @@ -311,7 +312,10 @@ async def async_wrapper(*args, **kwargs) -> 
Any: func_args=args, func_kwargs=kwargs, ) - self._prepare_call(span_type, span_params, dataset_record) + if span_params is None: + return await func(*args, **kwargs) + + self._prepare_call(span_type, span_params, dataset_record, func_args=args, func_kwargs=kwargs) result = None try: @@ -365,7 +369,10 @@ def sync_wrapper(*args, **kwargs) -> Any: func_args=args, func_kwargs=kwargs, ) - self._prepare_call(span_type, span_params, dataset_record) + if span_params is None: + return func(*args, **kwargs) + + self._prepare_call(span_type, span_params, dataset_record, func_args=args, func_kwargs=kwargs) result = None try: @@ -553,7 +560,12 @@ def _get_span_param_names(self, span_type: SPAN_TYPE) -> list[str]: return span_params.get(span_type, common_params) def _prepare_call( - self, span_type: Optional[SPAN_TYPE], span_params: dict[str, Any], dataset_record: Optional[DatasetRecord] + self, + span_type: Optional[SPAN_TYPE], + span_params: dict[str, Any], + dataset_record: Optional[DatasetRecord], + func_args: tuple = (), + func_kwargs: Optional[dict] = None, ) -> None: """ Prepare the call for logging by setting up trace and span contexts. 
@@ -564,23 +576,45 @@ def _prepare_call( Type of span to create span_params Parameters for the span + dataset_record + Optional dataset record + func_args + Function arguments (used to extract distributed tracing headers) + func_kwargs + Function keyword arguments (used to extract distributed tracing headers) """ - client_instance = self.get_logger_instance() + # Extract distributed tracing headers from function arguments + trace_id, span_id = extract_tracing_headers(func_args=func_args, func_kwargs=func_kwargs) + + client_instance = self.get_logger_instance(trace_id=trace_id, span_id=span_id) _logger.debug(f"client_instance {id(client_instance)} {client_instance}") input_ = span_params.get("input_serialized", "") name = span_params.get("name", "") - if not _trace_context.get(): - # If the singleton logger has an active trace, use it - if client_instance.has_active_trace(): + # If we have trace_id/span_id (distributed tracing in streaming mode), the logger should have loaded an existing trace + # Set the trace context immediately so we don't create a new trace + # In streaming mode, traces are created immediately so we can add spans to them + if trace_id or span_id: + # In streaming mode with distributed tracing, the trace should be in traces[0] after _init_trace() or _init_span() + if client_instance.traces: + # Trace is loaded in traces list - use it! + _trace_context.set(client_instance.traces[0]) + _logger.debug(f"Set trace context from distributed tracing: trace_id={client_instance.traces[0].id}") + else: + # This should not happen in streaming mode - if trace_id/span_id was provided, trace should be loaded + raise ValueError( + f"Distributed tracing trace not found in streaming mode (trace_id={trace_id}, span_id={span_id}). " + "The trace should have been loaded during logger initialization." 
+ ) + elif not _trace_context.get(): + # Normal mode: no distributed tracing, start a new trace if needed + if client_instance.has_active_trace() and client_instance.traces: trace = client_instance.traces[-1] else: - # If no trace is available, start a new one trace = client_instance.start_trace( input=input_, name=name, - # TODO: add dataset_row_id dataset_input=dataset_record.input if dataset_record else None, dataset_output=dataset_record.output if dataset_record else None, dataset_metadata=dataset_record.metadata if dataset_record else None, @@ -707,7 +741,10 @@ def _handle_call_result(self, span_type: Optional[SPAN_TYPE], span_params: dict[ span_params["created_at"] = created_at span_params["duration_ns"] = 0 - logger = self.get_logger_instance() + # Get logger instance - extract trace_id/span_id from context for nested calls + # to ensure we get the same cached logger instance (cache key includes trace_id/span_id) + trace_id, span_id = extract_tracing_headers() + logger = self.get_logger_instance(trace_id=trace_id, span_id=span_id) # If the span type is a workflow or agent, conclude it _logger.debug(f"{span_type=} {stack=} {span_params=}") @@ -829,7 +866,12 @@ async def _wrap_async_generator_result( self._handle_call_result(span_type, span_params, output) def get_logger_instance( - self, project: Optional[str] = None, log_stream: Optional[str] = None, experiment_id: Optional[str] = None + self, + project: Optional[str] = None, + log_stream: Optional[str] = None, + experiment_id: Optional[str] = None, + trace_id: Optional[str] = None, + span_id: Optional[str] = None, ) -> GalileoLogger: """ Get the Galileo Logger instance for the current decorator context. 
@@ -840,15 +882,28 @@ def get_logger_instance( Optional project name to use log_stream Optional log stream name to use + experiment_id + Optional experiment ID to use + trace_id + Optional trace ID for distributed tracing (automatically extracted from headers if not provided) + span_id + Optional span ID for distributed tracing (automatically extracted from headers if not provided) Returns ------- GalileoLogger instance configured with the specified project and log stream """ + # Get mode from context (defaults to "batch" if not set) + # Mode will be overridden to "streaming" if trace_id/span_id is provided + mode = _mode_context.get() or "batch" + return GalileoLoggerSingleton().get( project=project or _project_context.get(), log_stream=log_stream or _log_stream_context.get(), experiment_id=experiment_id or _experiment_id_context.get(), + mode=mode, + trace_id=trace_id, + span_id=span_id, ) def get_current_project(self) -> Optional[str]: @@ -976,6 +1031,7 @@ def init( log_stream: Optional[str] = None, experiment_id: Optional[str] = None, local_metrics: Optional[list[LocalMetricConfig]] = None, + mode: str = "batch", ) -> None: """ Initialize the context with a project and log stream. Optionally, it can also be used @@ -994,15 +1050,19 @@ def init( The experiment id. Defaults to None. local_metrics Local metrics configs to run on the traces/spans before submitting them for ingestion. Defaults to None. + mode + The logging mode. Use "streaming" for distributed tracing or real-time logging. + Use "batch" for batch processing. Defaults to "batch". 
""" GalileoLoggerSingleton().reset(project=project, log_stream=log_stream, experiment_id=experiment_id) GalileoLoggerSingleton().get( - project=project, log_stream=log_stream, experiment_id=experiment_id, local_metrics=local_metrics + project=project, log_stream=log_stream, experiment_id=experiment_id, local_metrics=local_metrics, mode=mode ) _project_context.set(project) _log_stream_context.set(log_stream) _experiment_id_context.set(experiment_id) + _mode_context.set(mode) _span_stack_context.set([]) _trace_context.set(None) @@ -1045,6 +1105,35 @@ def set_session(self, session_id: str) -> None: """ self.get_logger_instance().set_session(session_id) + def get_tracing_headers(self) -> dict[str, str]: + """ + Get current trace and span IDs as headers for distributed tracing. + + Similar to LangSmith's `get_current_run_tree().to_headers()`, this method + returns a dictionary of headers that can be passed to HTTP requests to + propagate distributed tracing context. + + Returns + ------- + dict[str, str] + Dictionary with X-Trace-ID and/or X-Span-ID headers if available + """ + headers = {} + trace = self.get_current_trace() + span_stack = self.get_current_span_stack() + + if trace: + headers["X-Trace-ID"] = str(trace.id) + + # Get the most recent span (top of stack) + if span_stack: + headers["X-Span-ID"] = str(span_stack[-1].id) + elif trace: + # If no span but we have a trace, use trace ID as span ID + headers["X-Span-ID"] = str(trace.id) + + return headers + galileo_context = GalileoDecorator() log = galileo_context.log diff --git a/src/galileo/middleware/__init__.py b/src/galileo/middleware/__init__.py new file mode 100644 index 000000000..dca99a775 --- /dev/null +++ b/src/galileo/middleware/__init__.py @@ -0,0 +1,10 @@ +""" +Galileo middleware for web frameworks. + +This module provides middleware for automatically extracting distributed tracing +headers from HTTP requests and making them available to the @log decorator. 
+""" + +from galileo.middleware.tracing import TracingMiddleware + +__all__ = ["TracingMiddleware"] diff --git a/src/galileo/middleware/tracing.py b/src/galileo/middleware/tracing.py new file mode 100644 index 000000000..b51bf5276 --- /dev/null +++ b/src/galileo/middleware/tracing.py @@ -0,0 +1,108 @@ +""" +Tracing middleware for FastAPI/Starlette applications. + +This middleware automatically extracts distributed tracing headers from incoming +HTTP requests and stores them in a context variable, making them available to +the @log decorator throughout the request lifecycle. + +How it works: +1. The `dispatch` method intercepts incoming HTTP requests +2. Extracts X-Trace-ID and X-Span-ID headers from the request +3. Stores them in ContextVar (thread-local context variables) +4. The @log decorator (via extract_tracing_headers in distributed_tracing.py) + reads these context variables to automatically configure distributed tracing + +This middleware is only for ASGI frameworks (FastAPI/Starlette). For Flask (WSGI), +users can manually pass request objects to decorated functions, and the decorator +will extract headers from the request object directly. 
+""" + +import logging +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from typing import Optional + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +_logger = logging.getLogger(__name__) + +# Context variables to store trace and span IDs +_trace_id_context: ContextVar[Optional[str]] = ContextVar("trace_id_context", default=None) +_span_id_context: ContextVar[Optional[str]] = ContextVar("span_id_context", default=None) + + +def get_trace_id() -> Optional[str]: + """Get the current trace ID from context.""" + return _trace_id_context.get() + + +def get_span_id() -> Optional[str]: + """Get the current span ID from context.""" + return _span_id_context.get() + + +class TracingMiddleware(BaseHTTPMiddleware): + """ + Middleware that extracts distributed tracing headers from HTTP requests. + + This middleware automatically extracts X-Trace-ID and X-Span-ID headers + from incoming requests and stores them in context variables. The @log decorator + can then read these values to automatically configure distributed tracing. + + Usage: + from fastapi import FastAPI + from galileo.middleware import TracingMiddleware + + app = FastAPI() + app.add_middleware(TracingMiddleware) + + Note: This requires starlette to be installed. Install it with: + pip install galileo[starlette] + # or + pip install starlette + # or + pip install fastapi # (which includes starlette) + """ + + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + """ + Extract tracing headers from incoming request and store in context. + + This method: + 1. Extracts X-Trace-ID and X-Span-ID headers from the HTTP request + 2. Stores them in ContextVar (context variables) that are automatically + available throughout the async request lifecycle + 3. 
The @log decorator (via extract_tracing_headers in distributed_tracing.py) + reads these context variables using get_trace_id() and get_span_id() + + The context variables are thread-local and async-safe, so they work correctly + with FastAPI/Starlette's async request handling. + + Parameters + ---------- + request + The incoming HTTP request + call_next + The next middleware or route handler in the chain + + Returns + ------- + Response + The HTTP response from the next handler + """ + # Extract X-Trace-ID and X-Span-ID headers (case-insensitive) + trace_id = request.headers.get("x-trace-id") or request.headers.get("X-Trace-ID") + span_id = request.headers.get("x-span-id") or request.headers.get("X-Span-ID") + + # Store in context variables for @log decorator to use + # These context variables are automatically available to extract_tracing_headers() + # via get_trace_id() and get_span_id() throughout the request lifecycle + if trace_id: + _trace_id_context.set(trace_id) + if span_id: + _span_id_context.set(span_id) + + # Call the next middleware/route handler + return await call_next(request) diff --git a/src/galileo/utils/distributed_tracing.py b/src/galileo/utils/distributed_tracing.py new file mode 100644 index 000000000..0dbe156c4 --- /dev/null +++ b/src/galileo/utils/distributed_tracing.py @@ -0,0 +1,110 @@ +""" +Utility functions for distributed tracing header detection. + +This module provides functions to automatically detect distributed tracing headers +from various web frameworks (FastAPI, Flask, Starlette, etc.) and extract trace/span IDs. + +How it works with TracingMiddleware: +1. TracingMiddleware.dispatch() extracts headers from HTTP requests and stores them + in ContextVar (context variables) using _trace_id_context and _span_id_context +2. extract_tracing_headers() first checks these context variables (via get_trace_id() + and get_span_id()) - this is the primary path when middleware is used +3. 
If context variables are empty, it falls back to checking function arguments + for request objects (useful when middleware isn't used or for Flask/WSGI apps) + +This two-tier approach allows the @log decorator to work seamlessly: +- With middleware (FastAPI/Starlette): Headers are automatically extracted and stored +- Without middleware (Flask or manual): Headers are extracted from request objects + passed as function arguments +""" + +import logging +from typing import Any, Optional + +from galileo.middleware.tracing import get_span_id, get_trace_id + +_logger = logging.getLogger(__name__) + + +def extract_tracing_headers( + func_args: tuple = (), func_kwargs: Optional[dict] = None +) -> tuple[Optional[str], Optional[str]]: + """ + Extract distributed tracing headers from context or function arguments. + + This function first checks context variables (set by TracingMiddleware), + then falls back to looking for web framework request objects in function arguments. + + Parameters + ---------- + func_args + Positional arguments passed to the function + func_kwargs + Keyword arguments passed to the function + + Returns + ------- + Tuple[Optional[str], Optional[str]] + A tuple of (trace_id, span_id) if found, otherwise (None, None) + """ + # First, check context variables (set by TracingMiddleware) + trace_id = get_trace_id() + span_id = get_span_id() + + if trace_id or span_id: + return trace_id, span_id + + # Fallback: check function arguments for request objects + if func_kwargs is None: + func_kwargs = {} + + # Check all arguments for request objects + all_args = list(func_args) + list(func_kwargs.values()) + + for arg in all_args: + if arg is None: + continue + + # Check if it's a web framework request object (FastAPI/Starlette or Flask) + # Both ASGI (FastAPI) and WSGI (Flask) request objects expose headers similarly + trace_id, span_id = _extract_from_request_object(arg) + if trace_id or span_id: + return trace_id, span_id + + return None, None + + +def 
_extract_from_request_object(request_obj: Any) -> tuple[Optional[str], Optional[str]]: + """ + Extract tracing headers from a web framework request object. + + This function works with both FastAPI/Starlette (ASGI) and Flask (WSGI) request objects. + Both frameworks expose headers in a similar way, so we can use a unified extraction method. + + Note: While FastAPI/Starlette uses ASGI (async) and Flask uses WSGI (sync), both expose + request headers via a `headers` attribute with a `.get()` method, making unified extraction possible. + + Parameters + ---------- + request_obj + Potential web framework request object (FastAPI/Starlette Request or Flask request) + + Returns + ------- + Tuple[Optional[str], Optional[str]] + A tuple of (trace_id, span_id) if found, otherwise (None, None) + """ + try: + # Both FastAPI/Starlette and Flask request objects have a headers attribute + if hasattr(request_obj, "headers"): + headers = request_obj.headers + # Both frameworks support .get() method on headers + if hasattr(headers, "get"): + # Try both lowercase and uppercase header names (case-insensitive) + trace_id = headers.get("x-trace-id") or headers.get("X-Trace-ID") + span_id = headers.get("x-span-id") or headers.get("X-Span-ID") + return trace_id, span_id + except Exception as e: + _logger.debug(f"Failed to extract headers from request object: {e}") + + return None, None diff --git a/src/galileo/utils/singleton.py b/src/galileo/utils/singleton.py index d2668f244..4f6ebbd78 100644 --- a/src/galileo/utils/singleton.py +++ b/src/galileo/utils/singleton.py @@ -5,11 +5,35 @@ from galileo.constants import DEFAULT_LOG_STREAM_NAME, DEFAULT_PROJECT_NAME from galileo.logger import GalileoLogger +from galileo.logger.logger import GalileoLoggerException from galileo.schema.metrics import LocalMetricConfig _logger = logging.getLogger(__name__) +def _validate_distributed_tracing_mode(trace_id: Optional[str], span_id: Optional[str], mode: str) -> None: + """ + Validate that distributed 
tracing (trace_id/span_id) is only used in streaming mode. + + Parameters + ---------- + trace_id + Optional trace ID for distributed tracing + span_id + Optional span ID for distributed tracing + mode + The logger mode (should be "streaming" if trace_id or span_id are provided) + + Raises + ------ + GalileoLoggerException + If trace_id or span_id are provided but mode is not "streaming" + """ + if trace_id or span_id: + if mode != "streaming": + raise GalileoLoggerException("trace_id or span_id can only be used in streaming mode") + + class GalileoLoggerSingleton: """ A singleton class that manages a collection of GalileoLogger instances. @@ -24,7 +48,7 @@ class GalileoLoggerSingleton: _instance = None # Class-level attribute to hold the singleton instance. _lock = threading.Lock() # Lock for thread-safe instantiation and operations. - _galileo_loggers: ClassVar[dict[tuple[str, str, str, str], GalileoLogger]] = {} # Cache for loggers. + _galileo_loggers: ClassVar[dict[tuple[str, ...], GalileoLogger]] = {} # Cache for loggers. def __new__(cls) -> "GalileoLoggerSingleton": """ @@ -39,14 +63,17 @@ def __new__(cls) -> "GalileoLoggerSingleton": with cls._lock: if not cls._instance: # Double-checked locking. cls._instance = super().__new__(cls) - # Initialize the logger dictionary in the new instance. - cls._instance._galileo_loggers = {} return cls._instance @staticmethod def _get_key( - project: Optional[str], log_stream: Optional[str], experiment_id: Optional[str] = None, mode: str = "batch" - ) -> tuple[str, str, str, str]: + project: Optional[str], + log_stream: Optional[str], + experiment_id: Optional[str] = None, + mode: str = "batch", + trace_id: Optional[str] = None, + span_id: Optional[str] = None, + ) -> tuple[str, ...]: """ Generate a key tuple based on project and log_stream parameters. @@ -64,11 +91,16 @@ def _get_key( The experiment ID. mode: (Optional[str]) The logger mode. 
+ trace_id: (Optional[str]) + Trace ID for distributed tracing (included in cache key for proper isolation). + span_id: (Optional[str]) + Span ID for distributed tracing (included in cache key for proper isolation). Returns ------- - Tuple[str, str, str, str] - A tuple key (project, log_stream, experiment_id, mode) used for caching. + Tuple[str, ...] + A tuple key used for caching. Includes trace_id/span_id when provided for distributed tracing + to ensure nested calls within the same request reuse the same logger instance. """ _logger.debug("current thread is %s", threading.current_thread().name) @@ -81,10 +113,16 @@ def _get_key( if log_stream is None: log_stream = getenv("GALILEO_LOG_STREAM", DEFAULT_LOG_STREAM_NAME) - if experiment_id is not None: - return (*key, project, experiment_id) + base_key = (*key, project, experiment_id) if experiment_id is not None else (*key, project, log_stream) - return (*key, project, log_stream) + # For distributed tracing, include trace_id/span_id in the cache key + # This allows nested calls within the same request to reuse the same logger instance + if trace_id or span_id: + # Assert that mode is streaming when distributed tracing is used + _validate_distributed_tracing_mode(trace_id, span_id, mode) + return (*base_key, trace_id or "", span_id or "") + + return base_key def get( self, @@ -94,6 +132,8 @@ def get( experiment_id: Optional[str] = None, mode: str = "batch", local_metrics: Optional[list[LocalMetricConfig]] = None, + trace_id: Optional[str] = None, + span_id: Optional[str] = None, ) -> GalileoLogger: """ Retrieve an existing GalileoLogger or create a new one if it does not exist. @@ -113,14 +153,26 @@ def get( local_metrics (Optional[list[LocalScorerConfig]], optional) Local scorers to run on traces/spans. Only used if initializing a new logger, ignored otherwise. Defaults to None. + trace_id (Optional[str], optional) + Trace ID for distributed tracing. If provided, logger will use streaming mode. + Defaults to None. 
+ span_id (Optional[str], optional) + Span ID for distributed tracing. If provided, logger will use streaming mode. + Defaults to None. Returns ------- GalileoLogger An instance of GalileoLogger corresponding to the key. """ - # Compute the key based on provided parameters or environment variables. - key = GalileoLoggerSingleton._get_key(project, log_stream, experiment_id, mode) + # If trace_id or span_id is provided, use streaming mode + if trace_id or span_id: + mode = "streaming" + # Assert that mode is streaming when distributed tracing is used + _validate_distributed_tracing_mode(trace_id, span_id, mode) + + # Compute cache key (includes trace_id/span_id for distributed tracing) + key = GalileoLoggerSingleton._get_key(project, log_stream, experiment_id, mode, trace_id, span_id) # First check without acquiring lock for performance. if key in self._galileo_loggers: @@ -139,11 +191,13 @@ def get( "experiment_id": experiment_id, "local_metrics": local_metrics, "experimental": {"mode": mode}, + "trace_id": trace_id, + "span_id": span_id, } # Create the logger with filtered kwargs. logger = GalileoLogger(**{k: v for k, v in galileo_client_init_args.items() if v is not None}) - # Cache the newly created logger. + # Cache the logger if logger: self._galileo_loggers[key] = logger return logger @@ -218,14 +272,14 @@ def flush_all(self) -> None: for logger in self._galileo_loggers.values(): logger.flush() - def get_all_loggers(self) -> dict[tuple[str, str, str, str], GalileoLogger]: + def get_all_loggers(self) -> dict[tuple[str, ...], GalileoLogger]: """ Retrieve a copy of the dictionary containing all active loggers. Returns ------- - Dict[Tuple[str, str, str], GalileoLogger]: - A dictionary mapping keys (project, log_stream) to their corresponding GalileoLogger instances. + Dict[Tuple[str, ...], GalileoLogger]: + A dictionary mapping keys (project, log_stream, ...) to their corresponding GalileoLogger instances. 
""" # Return a shallow copy of the loggers dictionary to prevent external modifications. - return dict(self._galileo_loggers) + return self._galileo_loggers.copy() diff --git a/tests/test_distributed_tracing.py b/tests/test_distributed_tracing.py new file mode 100644 index 000000000..c2a622dfd --- /dev/null +++ b/tests/test_distributed_tracing.py @@ -0,0 +1,342 @@ +"""Tests for distributed tracing functionality.""" + +import datetime +from unittest.mock import AsyncMock, Mock, patch +from uuid import UUID + +import pytest +from starlette.requests import Request +from starlette.responses import Response + +from galileo import galileo_context, log +from galileo.middleware.tracing import TracingMiddleware, get_span_id, get_trace_id +from galileo.utils.distributed_tracing import extract_tracing_headers +from tests.testutils.setup import setup_mock_logstreams_client, setup_mock_projects_client, setup_mock_traces_client + + +@pytest.fixture +def reset_context() -> None: + """Reset galileo context before each test.""" + galileo_context.reset() + + +def test_extract_tracing_headers_from_context() -> None: + """Test extracting trace/span IDs from context variables.""" + # Set context variables (as middleware would) + from galileo.middleware.tracing import _span_id_context, _trace_id_context + + _trace_id_context.set("test-trace-id-123") + _span_id_context.set("test-span-id-456") + + trace_id, span_id = extract_tracing_headers() + + assert trace_id == "test-trace-id-123" + assert span_id == "test-span-id-456" + + # Clean up + _trace_id_context.set(None) + _span_id_context.set(None) + + +def test_extract_tracing_headers_from_fastapi_request() -> None: + """Test extracting trace/span IDs from FastAPI Request object.""" + # Create a mock request with headers + mock_request = Mock(spec=Request) + mock_request.headers = {"X-Trace-ID": "trace-from-request", "X-Span-ID": "span-from-request"} + + trace_id, span_id = extract_tracing_headers(func_args=(mock_request,)) + + assert 
trace_id == "trace-from-request" + assert span_id == "span-from-request" + + +def test_extract_tracing_headers_from_fastapi_request_lowercase() -> None: + """Test extracting trace/span IDs from FastAPI Request with lowercase headers.""" + mock_request = Mock(spec=Request) + mock_request.headers = {"x-trace-id": "trace-lower", "x-span-id": "span-lower"} + + trace_id, span_id = extract_tracing_headers(func_args=(mock_request,)) + + assert trace_id == "trace-lower" + assert span_id == "span-lower" + + +def test_extract_tracing_headers_from_flask_request() -> None: + """Test extracting trace/span IDs from Flask request object.""" + # Create a mock Flask request + mock_request = Mock() + mock_request.headers = Mock() + mock_request.headers.get = Mock( + side_effect=lambda key, default=None: {"X-Trace-ID": "flask-trace", "X-Span-ID": "flask-span"}.get(key, default) + ) + + trace_id, span_id = extract_tracing_headers(func_args=(mock_request,)) + + assert trace_id == "flask-trace" + assert span_id == "flask-span" + + +def test_extract_tracing_headers_priority_context_over_request() -> None: + """Test that context variables take priority over request objects.""" + from galileo.middleware.tracing import _span_id_context, _trace_id_context + + # Set context variables + _trace_id_context.set("context-trace") + _span_id_context.set("context-span") + + # Create request with different IDs + mock_request = Mock(spec=Request) + mock_request.headers = {"X-Trace-ID": "request-trace", "X-Span-ID": "request-span"} + + trace_id, span_id = extract_tracing_headers(func_args=(mock_request,)) + + # Context should take priority + assert trace_id == "context-trace" + assert span_id == "context-span" + + # Clean up + _trace_id_context.set(None) + _span_id_context.set(None) + + +def test_extract_tracing_headers_no_headers() -> None: + """Test extracting when no headers are present.""" + trace_id, span_id = extract_tracing_headers() + + assert trace_id is None + assert span_id is None + + 
+@patch("galileo.logger.logger.LogStreams") +@patch("galileo.logger.logger.Projects") +@patch("galileo.logger.logger.Traces") +def test_tracing_middleware_extracts_headers( + mock_traces_client: Mock, mock_projects_client: Mock, mock_logstreams_client: Mock +) -> None: + """Test that TracingMiddleware extracts headers from requests.""" + setup_mock_traces_client(mock_traces_client) + setup_mock_projects_client(mock_projects_client) + setup_mock_logstreams_client(mock_logstreams_client) + + middleware = TracingMiddleware(app=None) + + # Create a mock request + mock_request = Mock(spec=Request) + mock_request.headers = {"X-Trace-ID": "middleware-trace-123", "X-Span-ID": "middleware-span-456"} + + # Create a mock call_next + async def mock_call_next(request: Request) -> Response: + # Verify context was set + assert get_trace_id() == "middleware-trace-123" + assert get_span_id() == "middleware-span-456" + return Response() + + # Run the middleware + import asyncio + + asyncio.run(middleware.dispatch(mock_request, mock_call_next)) + + # Verify context is cleared after request (or still available if needed) + # Note: Context variables are request-scoped, so they may persist + # until the next request in the same context + + +@patch("galileo.logger.logger.LogStreams") +@patch("galileo.logger.logger.Projects") +@patch("galileo.logger.logger.Traces") +def test_decorator_with_distributed_tracing_context( + mock_traces_client: Mock, mock_projects_client: Mock, mock_logstreams_client: Mock, reset_context +) -> None: + """Test that @log decorator uses distributed tracing headers from context.""" + mock_traces_client_instance = setup_mock_traces_client(mock_traces_client) + setup_mock_projects_client(mock_projects_client) + setup_mock_logstreams_client(mock_logstreams_client) + + # Set up context with distributed tracing IDs (as middleware would) + from galileo.middleware.tracing import _span_id_context, _trace_id_context + + test_trace_id = 
"6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9d" + test_span_id = "6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9e" + + _trace_id_context.set(test_trace_id) + _span_id_context.set(test_span_id) + + # Mock get_trace to return an existing trace (for distributed tracing) + # Note: get_trace is already mocked as AsyncMock in setup_mock_traces_client + # We just need to update its return value + mock_traces_client_instance.get_trace = AsyncMock( + return_value={ + "id": UUID(test_trace_id), + "name": "distributed-trace", + "type": "trace", + "input": "original input", + "output": None, + "created_at": datetime.datetime.now(), + "updated_at": datetime.datetime.now(), + "user_metadata": {}, + "spans": [], + "metrics": {}, + } + ) + + # Mock get_span to return an existing span (for distributed tracing) + # Note: get_span is already mocked as AsyncMock in setup_mock_traces_client + mock_traces_client_instance.get_span = AsyncMock( + return_value={ + "id": UUID(test_span_id), + "name": "distributed-span", + "type": "workflow", + "input": "original input", + "output": None, + "created_at": datetime.datetime.now(), + "updated_at": datetime.datetime.now(), + "user_metadata": {}, + "metrics": {}, + "parent_id": UUID(test_trace_id), + "trace_id": UUID(test_trace_id), + } + ) + + # Initialize context in streaming mode (required for distributed tracing) + galileo_context.init(project="test-project", log_stream="test-stream", mode="streaming") + + @log(span_type="retriever") + def retrieval_service(query: str) -> str: + return "retrieved data" + + @log + def retrieve_endpoint(query: str) -> str: + return retrieval_service(query=query) + + result = retrieve_endpoint(query="test query") + galileo_context.flush_all() + + # Verify get_trace was called to load the distributed trace + # In streaming mode with distributed tracing, get_trace should be called to load the existing trace + mock_traces_client_instance.get_trace.assert_called() + + # Verify the functions executed successfully + assert result == 
"retrieved data" + + # Clean up + _trace_id_context.set(None) + _span_id_context.set(None) + + +@patch("galileo.logger.logger.LogStreams") +@patch("galileo.logger.logger.Projects") +@patch("galileo.logger.logger.Traces") +def test_decorator_with_distributed_tracing_request_object( + mock_traces_client: Mock, mock_projects_client: Mock, mock_logstreams_client: Mock, reset_context +) -> None: + """Test that @log decorator extracts headers from Request object in function args.""" + mock_traces_client_instance = setup_mock_traces_client(mock_traces_client) + setup_mock_projects_client(mock_projects_client) + setup_mock_logstreams_client(mock_logstreams_client) + + # Mock get_trace for distributed tracing + test_trace_id = "6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9f" + test_span_id = "6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a90" + + # Note: get_trace is already mocked as AsyncMock in setup_mock_traces_client + # We just need to update its return value + mock_traces_client_instance.get_trace = AsyncMock( + return_value={ + "id": UUID(test_trace_id), + "name": "request-trace", + "type": "trace", + "input": "original input", + "output": None, + "created_at": datetime.datetime.now(), + "updated_at": datetime.datetime.now(), + "user_metadata": {}, + "spans": [], + "metrics": {}, + } + ) + + galileo_context.init(project="test-project", log_stream="test-stream", mode="streaming") + + # Create a mock request with headers + mock_request = Mock(spec=Request) + mock_request.headers = {"X-Trace-ID": test_trace_id, "X-Span-ID": test_span_id} + + @log(span_type="retriever") + def retrieval_service(query: str) -> str: + return "retrieved data" + + @log + def retrieve_endpoint(request: Request, query: str) -> str: + return retrieval_service(query=query) + + retrieve_endpoint(request=mock_request, query="test query") + galileo_context.flush_all() + + # Verify get_trace was called to load the distributed trace + mock_traces_client_instance.get_trace.assert_called() + + 
+@patch("galileo.logger.logger.LogStreams") +@patch("galileo.logger.logger.Projects") +@patch("galileo.logger.logger.Traces") +def test_singleton_not_caches_distributed_tracing_loggers( + mock_traces_client: Mock, mock_projects_client: Mock, mock_logstreams_client: Mock, reset_context +) -> None: + """Test that singleton does not cache loggers when trace_id/span_id are provided.""" + from galileo.utils.singleton import GalileoLoggerSingleton + + setup_mock_traces_client(mock_traces_client) + setup_mock_projects_client(mock_projects_client) + setup_mock_logstreams_client(mock_logstreams_client) + + # Clear any existing loggers + singleton = GalileoLoggerSingleton() + singleton.reset() + + galileo_context.init(project="test-project", log_stream="test-stream", mode="streaming") + + # Set up context with distributed tracing IDs + from galileo.middleware.tracing import _span_id_context, _trace_id_context + from galileo.utils.distributed_tracing import extract_tracing_headers + + _trace_id_context.set("trace-1") + _span_id_context.set("span-1") + + # Extract trace_id/span_id and get logger instance with them + trace_id_1, span_id_1 = extract_tracing_headers() + logger1 = galileo_context.get_logger_instance(trace_id=trace_id_1, span_id=span_id_1) + + @log + def func1() -> str: + return "result1" + + func1() + + # Change trace/span IDs + _trace_id_context.set("trace-2") + _span_id_context.set("span-2") + + # Extract new trace_id/span_id and get a new logger instance + trace_id_2, span_id_2 = extract_tracing_headers() + logger2 = galileo_context.get_logger_instance(trace_id=trace_id_2, span_id=span_id_2) + + @log + def func2() -> str: + return "result2" + + func2() + + # In distributed tracing mode, each request should get a new logger instance + # (not cached based on trace_id/span_id) + # The loggers should be different instances + assert logger1 is not logger2 + + # Verify they have different trace_id/span_id + assert logger1.trace_id == "trace-1" + assert 
logger1.span_id == "span-1" + assert logger2.trace_id == "trace-2" + assert logger2.span_id == "span-2" + + # Clean up + _trace_id_context.set(None) + _span_id_context.set(None)