@@ -895,6 +895,10 @@ def parse_stream_chunk(cls, chunk):
895895 def parse_lm_response (cls , response : dict ) -> "CustomType" :
896896 return CustomType (message = response .split ("\n \n " )[0 ])
897897
898+ @classmethod
899+ def is_natively_supported (cls , lm , lm_kwargs ):
900+ return True
901+
898902 class CustomSignature (dspy .Signature ):
899903 question : str = dspy .InputField ()
900904 answer : CustomType = dspy .OutputField ()
@@ -907,7 +911,6 @@ class CustomSignature(dspy.Signature):
907911 )
908912
909913 async def stream (* args , ** kwargs ):
910- yield ModelResponseStream (model = "gpt-4o-mini" , choices = [StreamingChoices (delta = Delta (content = "[[ ## answer ## ]]\n " ))])
911914 yield ModelResponseStream (model = "gpt-4o-mini" , choices = [StreamingChoices (delta = Delta (content = "Hello" ))])
912915 yield ModelResponseStream (model = "gpt-4o-mini" , choices = [StreamingChoices (delta = Delta (content = "World" ))])
913916 yield ModelResponseStream (model = "gpt-4o-mini" , choices = [StreamingChoices (delta = Delta (content = "\n \n " ))])
@@ -916,9 +919,10 @@ async def stream(*args, **kwargs):
916919 yield ModelResponseStream (model = "gpt-4o-mini" , choices = [StreamingChoices (delta = Delta (content = " ##" ))])
917920 yield ModelResponseStream (model = "gpt-4o-mini" , choices = [StreamingChoices (delta = Delta (content = " ]]" ))])
918921
919-
920922 with mock .patch ("litellm.acompletion" , side_effect = stream ):
921- with dspy .context (lm = dspy .LM ("openai/gpt-4o-mini" , cache = False ), adapter = dspy .ChatAdapter (native_response_types = [CustomType ])):
923+ with dspy .context (
924+ lm = dspy .LM ("openai/gpt-4o-mini" , cache = False ), adapter = dspy .ChatAdapter (native_response_types = [CustomType ])
925+ ):
922926 output = program (question = "why did a chicken cross the kitchen?" )
923927 all_chunks = []
924928 async for value in output :
@@ -935,6 +939,7 @@ async def stream(*args, **kwargs):
935939async def test_streaming_with_citations ():
936940 class AnswerWithSources (dspy .Signature ):
937941 """Answer questions using provided documents with citations."""
942+
938943 documents : list [Document ] = dspy .InputField ()
939944 question : str = dspy .InputField ()
940945 answer : str = dspy .OutputField ()
@@ -959,20 +964,36 @@ async def citation_stream(*args, **kwargs):
959964 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " 100°C" ))])
960965 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "." ))])
961966 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "\n \n " ))])
962- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = '[{"type": "char_location", "cited_text": "Water boils at 100°C", "document_index": 0, "document_title": "Physics Facts", "start_char_index": 0, "end_char_index": 19}]' ))])
963- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (
964- content = "" ,
965- provider_specific_fields = {
966- "citation" : {
967- "type" : "char_location" ,
968- "cited_text" : "Water boils at 100°C" ,
969- "document_index" : 0 ,
970- "document_title" : "Physics Facts" ,
971- "start_char_index" : 0 ,
972- "end_char_index" : 19
973- }
974- }
975- ))])
967+ yield ModelResponseStream (
968+ model = "claude" ,
969+ choices = [
970+ StreamingChoices (
971+ delta = Delta (
972+ content = '[{"type": "char_location", "cited_text": "Water boils at 100°C", "document_index": 0, "document_title": "Physics Facts", "start_char_index": 0, "end_char_index": 19}]'
973+ )
974+ )
975+ ],
976+ )
977+ yield ModelResponseStream (
978+ model = "claude" ,
979+ choices = [
980+ StreamingChoices (
981+ delta = Delta (
982+ content = "" ,
983+ provider_specific_fields = {
984+ "citation" : {
985+ "type" : "char_location" ,
986+ "cited_text" : "Water boils at 100°C" ,
987+ "document_index" : 0 ,
988+ "document_title" : "Physics Facts" ,
989+ "start_char_index" : 0 ,
990+ "end_char_index" : 19 ,
991+ }
992+ },
993+ )
994+ )
995+ ],
996+ )
976997 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "\n \n " ))])
977998 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "[[ ##" ))])
978999 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " completed" ))])
@@ -990,7 +1011,10 @@ async def citation_stream(*args, **kwargs):
9901011 # Create test documents
9911012 docs = [Document (data = "Water boils at 100°C at standard pressure." , title = "Physics Facts" )]
9921013
993- with dspy .context (lm = dspy .LM ("anthropic/claude-3-5-sonnet-20241022" , cache = False ), adapter = dspy .ChatAdapter (native_response_types = [Citations ])):
1014+ with dspy .context (
1015+ lm = dspy .LM ("anthropic/claude-3-5-sonnet-20241022" , cache = False ),
1016+ adapter = dspy .ChatAdapter (native_response_types = [Citations ]),
1017+ ):
9941018 output = program (documents = docs , question = "What temperature does water boil?" )
9951019 citation_chunks = []
9961020 final_prediction = None
0 commit comments