dspy/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@

 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip
-from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code  # isort: skip
+from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, BAMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code  # isort: skip
 from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
 from dspy.utils.asyncify import asyncify
 from dspy.utils.syncify import syncify
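With this re-export, `BAMLAdapter` can be selected like any other adapter. A minimal usage sketch, assuming an API key is configured; the model id is illustrative:

```python
import dspy

# BAMLAdapter is now importable from the package root, like the other adapters.
# The model id below is illustrative; any LiteLLM-supported model works.
lm = dspy.LM("openai/gpt-4o-mini")
dspy.configure(lm=lm, adapter=dspy.BAMLAdapter())

predict = dspy.Predict("question -> answer")
# predict(question="Why did the chicken cross the road?").answer
```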
dspy/adapters/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -1,4 +1,5 @@
 from dspy.adapters.base import Adapter
+from dspy.adapters.baml_adapter import BAMLAdapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.two_step_adapter import TwoStepAdapter
@@ -7,6 +8,7 @@

 __all__ = [
     "Adapter",
+    "BAMLAdapter",
     "ChatAdapter",
     "Type",
     "History",
dspy/streaming/streaming_listener.py (16 changes: 11 additions & 5 deletions)
@@ -5,6 +5,7 @@

 from litellm import ModelResponseStream

+from dspy.adapters.baml_adapter import BAMLAdapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.types import Type
@@ -15,7 +16,7 @@
 if TYPE_CHECKING:
     from dspy.primitives.module import Module

-ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter]
+ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter, BAMLAdapter]


 class StreamListener:
@@ -65,6 +66,11 @@ def __init__(
                 "end_identifier": re.compile(rf"</{self.signature_field_name}>"),
                 "start_indicator": "<",
             },
+            "BAMLAdapter": {
+                "start_identifier": f'"{self.signature_field_name}":',
+                "end_identifier": re.compile(r"\w*\"(,|\s*})"),
+                "start_indicator": '"',
+            },
         }

     def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> str:
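The new `BAMLAdapter` entry treats the streamed response as JSON: the field value starts after `"<field>":` and ends at the value's closing quote followed by `,` or `}`. A standalone sketch (not library code) of how these markers carve out the value; the field name and payload are illustrative:

```python
import re

signature_field_name = "answer"
start_identifier = f'"{signature_field_name}":'
# The value's closing quote, followed by ',' or (optional whitespace and) '}'.
end_identifier = re.compile(r"\w*\"(,|\s*})")

buffered = '{"answer": "To get to the other side!"}'
value_start = buffered.find(start_identifier) + len(start_identifier)
value = buffered[value_start:].lstrip()
if value.startswith('"'):
    # receive() strips the opening quote separately, since whitespace may
    # sit between ':' and '"'.
    value = value[1:]
match = end_identifier.search(value)
print(value[: match.start()] if match else value)  # To get to the other side!
```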
@@ -145,8 +151,8 @@ def receive(self, chunk: ModelResponseStream):
             # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
             value_start_index = concat_message.find(start_identifier) + len(start_identifier)
             chunk_message = concat_message[value_start_index:].lstrip()
-            if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
-                # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
+            if isinstance(settings.adapter, (JSONAdapter, BAMLAdapter)) and chunk_message.startswith('"'):
+                # For JSONAdapter and BAMLAdapter, we need to remove the leading ". We cannot do this with the start_identifier
                 # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
                 chunk_message = chunk_message[1:]

@@ -194,7 +200,7 @@ def flush(self) -> str:
         """
         last_tokens = "".join(self.field_end_queue.queue)
         self.field_end_queue = Queue()
-        if isinstance(settings.adapter, JSONAdapter):
+        if isinstance(settings.adapter, (JSONAdapter, BAMLAdapter)):
             match = re.search(r'",|"\s*}', last_tokens)
             if match:
                 boundary_index = match.start()
@@ -206,7 +212,7 @@
             if boundary_index == -1:
                 boundary_index = len(last_tokens)
             return last_tokens[:boundary_index]
-        elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
+        elif isinstance(settings.adapter, (ChatAdapter, BAMLAdapter)) or settings.adapter is None:
             boundary_index = last_tokens.find("[[")
             return last_tokens[:boundary_index]
         else:
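An illustrative walk-through of the `flush()` branch for the JSON-style adapters: tokens left in the end-of-field queue are trimmed at the closing `",` or `"}` so that trailing punctuation never leaks into the final chunk. The input string is made up for illustration:

```python
import re

last_tokens = ' humorous."}\n'
match = re.search(r'",|"\s*}', last_tokens)
boundary_index = match.start() if match else -1
if boundary_index == -1:
    # No boundary found: emit everything that was buffered.
    boundary_index = len(last_tokens)
print(repr(last_tokens[:boundary_index]))  # ' humorous.'
```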
tests/streaming/test_streaming.py (75 changes: 75 additions & 0 deletions)
@@ -851,6 +851,81 @@ async def completion_side_effect(*args, **kwargs):
     assert all_chunks[1].chunk == "The answer is humorous."


+@pytest.mark.anyio
+async def test_stream_listener_returns_correct_chunk_baml_adapter():
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.predict1 = dspy.Predict("question->answer")
+            self.predict2 = dspy.Predict("question,answer->judgement")
+
+        def forward(self, question, **kwargs):
+            answer = self.predict1(question=question, **kwargs).answer
+            judgement = self.predict2(question=question, answer=answer, **kwargs)
+            return judgement
+
+    async def baml_stream_1(*args, **kwargs):
+        # BAML uses JSON format for responses but ChatAdapter-style field delimiters in prompts
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+
+    async def baml_stream_2(*args, **kwargs):
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="judgement"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+
+    stream_generators = [baml_stream_1, baml_stream_2]
+
+    async def completion_side_effect(*args, **kwargs):
+        return stream_generators.pop(0)()
+
+    with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
+        program = dspy.streamify(
+            MyProgram(),
+            stream_listeners=[
+                dspy.streaming.StreamListener(signature_field_name="answer"),
+                dspy.streaming.StreamListener(signature_field_name="judgement"),
+            ],
+        )
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.BAMLAdapter()):
+            output = program(question="why did a chicken cross the kitchen?")
+            all_chunks = []
+            async for value in output:
+                if isinstance(value, dspy.streaming.StreamResponse):
+                    all_chunks.append(value)
+
+    assert all_chunks[0].predict_name == "predict1"
+    assert all_chunks[0].signature_field_name == "answer"
+    assert all_chunks[0].chunk == "To get to the other side!"
+
+    assert all_chunks[1].predict_name == "predict2"
+    assert all_chunks[1].signature_field_name == "judgement"
+    assert all_chunks[1].chunk == "The answer is humorous."
+
+
 @pytest.mark.anyio
 async def test_streaming_allows_custom_chunk_types():
     @dataclass
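Outside of the mocked test, the same pattern applies end to end. A minimal consumption sketch, assuming a reachable LM; the model id is illustrative:

```python
import asyncio

import dspy

async def main():
    # Wrap a program so field values stream out as StreamResponse chunks.
    program = dspy.streamify(
        dspy.Predict("question -> answer"),
        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
    )
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.BAMLAdapter()):
        output = program(question="Why did the chicken cross the road?")
        async for value in output:
            if isinstance(value, dspy.streaming.StreamResponse):
                print(value.chunk, end="", flush=True)

asyncio.run(main())
```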
uv.lock (13 changes: 7 additions & 6 deletions; generated lockfile, diff not rendered)