diff --git a/dspy/__init__.py b/dspy/__init__.py
index ea4c75a862..9408c9d9ab 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -6,7 +6,7 @@
 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip
-from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code  # isort: skip
+from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, BAMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code  # isort: skip
 from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
 from dspy.utils.asyncify import asyncify
 from dspy.utils.syncify import syncify
diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py
index 1dea6da47a..11e6b1fbf6 100644
--- a/dspy/adapters/__init__.py
+++ b/dspy/adapters/__init__.py
@@ -1,4 +1,5 @@
 from dspy.adapters.base import Adapter
+from dspy.adapters.baml_adapter import BAMLAdapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.two_step_adapter import TwoStepAdapter
@@ -7,6 +8,7 @@
 __all__ = [
     "Adapter",
+    "BAMLAdapter",
     "ChatAdapter",
     "Type",
     "History",
diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py
index 9e2d7dda6b..0f8ff8bd91 100644
--- a/dspy/streaming/streaming_listener.py
+++ b/dspy/streaming/streaming_listener.py
@@ -5,6 +5,7 @@
 from litellm import ModelResponseStream
 
+from dspy.adapters.baml_adapter import BAMLAdapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.types import Type
@@ -15,7 +16,7 @@
 if TYPE_CHECKING:
     from dspy.primitives.module import Module
 
-ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter]
+ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter, BAMLAdapter]
 
 
 class StreamListener:
@@ -65,6 +66,11 @@ def __init__(
                 "end_identifier": re.compile(rf"</{self.signature_field_name}>"),
                 "start_indicator": "<",
             },
+            "BAMLAdapter": {
+                "start_identifier": f'"{self.signature_field_name}":',
+                "end_identifier": re.compile(r"\w*\"(,|\s*})"),
+                "start_indicator": '"',
+            },
         }
 
     def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> str:
@@ -145,8 +151,8 @@ def receive(self, chunk: ModelResponseStream):
             # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
             value_start_index = concat_message.find(start_identifier) + len(start_identifier)
             chunk_message = concat_message[value_start_index:].lstrip()
-            if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
-                # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
+            if isinstance(settings.adapter, (JSONAdapter, BAMLAdapter)) and chunk_message.startswith('"'):
+                # For JSONAdapter and BAMLAdapter, we need to remove the leading ". We cannot do this with the start_identifier
                 # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
                 chunk_message = chunk_message[1:]
@@ -194,7 +200,7 @@ def flush(self) -> str:
         """
         last_tokens = "".join(self.field_end_queue.queue)
         self.field_end_queue = Queue()
-        if isinstance(settings.adapter, JSONAdapter):
+        if isinstance(settings.adapter, (JSONAdapter, BAMLAdapter)):
             match = re.search(r'",|"\s*}', last_tokens)
             if match:
                 boundary_index = match.start()
@@ -206,7 +212,7 @@
             if boundary_index == -1:
                 boundary_index = len(last_tokens)
             return last_tokens[:boundary_index]
-        elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
+        elif isinstance(settings.adapter, (ChatAdapter, BAMLAdapter)) or settings.adapter is None:
             boundary_index = last_tokens.find("[[")
             return last_tokens[:boundary_index]
         else:
diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py
index 9a96820393..e467f97ae7 100644
--- a/tests/streaming/test_streaming.py
+++ b/tests/streaming/test_streaming.py
@@ -851,6 +851,81 @@ async def completion_side_effect(*args, **kwargs):
     assert all_chunks[1].chunk == "The answer is humorous."
 
 
+@pytest.mark.anyio
+async def test_stream_listener_returns_correct_chunk_baml_adapter():
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.predict1 = dspy.Predict("question->answer")
+            self.predict2 = dspy.Predict("question,answer->judgement")
+
+        def forward(self, question, **kwargs):
+            answer = self.predict1(question=question, **kwargs).answer
+            judgement = self.predict2(question=question, answer=answer, **kwargs)
+            return judgement
+
+    async def baml_stream_1(*args, **kwargs):
+        # BAML uses JSON format for responses but ChatAdapter-style field delimiters in prompts
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+
+    async def baml_stream_2(*args, **kwargs):
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="judgement"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+
+    stream_generators = [baml_stream_1, baml_stream_2]
+
+    async def completion_side_effect(*args, **kwargs):
+        return stream_generators.pop(0)()
+
+    with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
+        program = dspy.streamify(
+            MyProgram(),
+            stream_listeners=[
+                dspy.streaming.StreamListener(signature_field_name="answer"),
+                dspy.streaming.StreamListener(signature_field_name="judgement"),
+            ],
+        )
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.BAMLAdapter()):
+            output = program(question="why did a chicken cross the kitchen?")
+            all_chunks = []
+            async for value in output:
+                if isinstance(value, dspy.streaming.StreamResponse):
+                    all_chunks.append(value)
+
+    assert all_chunks[0].predict_name == "predict1"
+    assert all_chunks[0].signature_field_name == "answer"
+    assert all_chunks[0].chunk == "To get to the other side!"
+
+    assert all_chunks[1].predict_name == "predict2"
+    assert all_chunks[1].signature_field_name == "judgement"
+    assert all_chunks[1].chunk == "The answer is humorous."
+
+
 @pytest.mark.anyio
 async def test_streaming_allows_custom_chunk_types():
     @dataclass
diff --git a/uv.lock b/uv.lock
index 0e9afdf5c1..2a3bc669d7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.10, <3.14"
 resolution-markers = [
     "python_full_version >= '3.12.4' and sys_platform == 'win32'",
@@ -664,7 +664,7 @@ wheels = [
 
 [[package]]
 name = "dspy"
-version = "3.0.3"
+version = "3.0.4b1"
 source = { editable = "." }
 dependencies = [
     { name = "anyio" },
@@ -741,7 +741,7 @@ requires-dist = [
     { name = "datamodel-code-generator", marker = "extra == 'dev'", specifier = ">=0.26.3" },
     { name = "datasets", marker = "extra == 'test-extras'", specifier = ">=2.14.6" },
     { name = "diskcache", specifier = ">=5.6.0" },
-    { name = "gepa", extras = ["dspy"], specifier = "==0.0.12" },
+    { name = "gepa", extras = ["dspy"], specifier = "==0.0.17" },
     { name = "joblib", specifier = "~=1.3" },
     { name = "json-repair", specifier = ">=0.30.0" },
     { name = "langchain-core", marker = "extra == 'langchain'" },
@@ -758,6 +758,7 @@ requires-dist = [
     { name = "optuna", marker = "extra == 'test-extras'", specifier = ">=3.4.0" },
     { name = "orjson", specifier = ">=3.9.0" },
     { name = "pandas", marker = "extra == 'test-extras'", specifier = ">=2.1.1" },
+    { name = "pillow", specifier = ">=10.1.0" },
     { name = "pillow", marker = "extra == 'dev'", specifier = ">=10.1.0" },
     { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0" },
     { name = "pydantic", specifier = ">=2.0" },
@@ -957,11 +958,11 @@ wheels = [
 
 [[package]]
 name = "gepa"
-version = "0.0.12"
+version = "0.0.17"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/1c/c1/748f282ca83dea3d2a8dcefd6b1476d8780c99a2e3bdd80dfebbcb6e823b/gepa-0.0.12.tar.gz", hash = "sha256:0c725790c28399e333a37f32dc858a674dc9e748fcccac1df632acdf9f0302ef", size = 63863, upload-time = "2025-09-09T01:36:21.818Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/61/f0/fe312ed4405ddc2ca97dc1ce8915c4dd707e413503e6832910ab088fceb6/gepa-0.0.17.tar.gz", hash = "sha256:641ed46f8127618341b66ee82a87fb46a21c5d2d427a5e0b91c850a7f7f64e7f", size = 99816, upload-time = "2025-09-25T22:13:45.476Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/2f/fb/73d6b15259067248a9c574508924dc2df427b5cc1b50eccfb3ceddf4334d/gepa-0.0.12-py3-none-any.whl", hash = "sha256:bf254177d1b9056d09473273472bac96fd1272124fb20abfe8dc548e4b914c58", size = 64264, upload-time = "2025-09-09T01:36:20.469Z" },
+    { url = "https://files.pythonhosted.org/packages/88/dc/2bc81a01caa887ed58db3c725bebf1e98f37807a4d06c51ecaa85a7cabe0/gepa-0.0.17-py3-none-any.whl", hash = "sha256:0ea98f4179dbc8dd83bdf53494f302e663ee1da8300d086c4cc8ce4aefa4042c", size = 110464, upload-time = "2025-09-25T22:13:44.14Z" },
 ]
 
 [[package]]
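
Usage note (not part of the patch): the new test above doubles as a template for streaming with the new adapter. Below is a minimal sketch of the same pattern against a live model, assuming an OpenAI-compatible API key is configured; the model name and signature are illustrative placeholders, not part of this change.

    import asyncio

    import dspy

    # Wrap a predictor so field values stream incrementally. The listener watches
    # the "answer" field, which BAMLAdapter emits as a JSON key ('"answer":').
    program = dspy.streamify(
        dspy.Predict("question->answer"),
        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
    )

    async def main():
        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.BAMLAdapter()):
            output = program(question="why did the chicken cross the road?")
            async for value in output:
                # StreamResponse items carry per-field chunks; the final item
                # yielded by the stream is the completed Prediction.
                if isinstance(value, dspy.streaming.StreamResponse):
                    print(value.chunk, end="", flush=True)

    asyncio.run(main())

Because BAMLAdapter formats responses as JSON, the listener reuses the JSONAdapter-style boundary handling added in streaming_listener.py: the start identifier is the quoted field key, and the end identifier matches the closing quote followed by a comma or closing brace.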