6161log = logging .getLogger (__name__ )
6262JSON_HIGHLIGHTER = JSONHighlighter ()
6363
64- MAX_GENERATOR_TOKENS = 8092
6564
6665class OutlinesTransformersADM (ActionBasedADM ):
6766 def __init__ (self ,
@@ -193,7 +192,8 @@ def kdma_value_to_system_prompt(kdma, value):
193192 else :
194193 return None
195194
196- def _state_to_top_level_prompt (self , scenario_state , actions ):
195+ @staticmethod
196+ def _state_to_top_level_prompt (action_selection_prompt_template , scenario_description , scenario_state , actions ):
197197 """
198198 Generate prompt dialog based on given state and actions
199199 """
@@ -203,8 +203,7 @@ def _state_to_top_level_prompt(self, scenario_state, actions):
203203 scenario_state
204204 )
205205
206- scenario_description = self .scenario_description_template (scenario_state )
207- prompt = self .action_selection_prompt_template (scenario_description , choices )
206+ prompt = action_selection_prompt_template (scenario_description , choices )
208207
209208 return prompt , choices
210209
@@ -231,24 +230,25 @@ def run_in_batches(cls, inference_function, inputs, batch_size):
231230 outputs .extend (output )
232231 return outputs
233232
234- def top_level_choose_action (self ,
235- scenario_state ,
236- available_actions ,
237- alignment_target ,
238- num_positive_samples = 1 ,
239- num_negative_samples = 0 ,
240- generator_batch_size = 5 ,
241- kdma_descriptions_map = 'align_system/prompt_engineering/kdma_descriptions.yml' ,
242- reasoning_max_length = 512 ,
243- generator_seed = - 1 ,
244- shuffle_choices = True ,
245- ** kwargs ):
246- if self .baseline and num_negative_samples > 0 :
233+ @staticmethod
234+ def get_dialogs (scenario_state ,
235+ available_actions ,
236+ alignment_target ,
237+ num_positive_samples = 1 ,
238+ num_negative_samples = 0 ,
239+ kdma_descriptions_map = 'align_system/prompt_engineering/kdma_descriptions.yml' ,
240+ shuffle_choices = True ,
241+ baseline = False ,
242+ scenario_description_template = scenario_state_description_1 ,
243+ action_selection_prompt_template = action_selection_prompt ,
244+ baseline_system_prompt = baseline_system_prompt ,
245+ ** kwargs ):
246+ if baseline and num_negative_samples > 0 :
247247 raise RuntimeError ("No notion of negative samples for baseline run" )
248- if self . baseline and "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
248+ if baseline and "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
249249 raise RuntimeError ("No notion of incontext examples for baseline run" )
250250
251- scenario_description = self . scenario_description_template (scenario_state )
251+ scenario_description = scenario_description_template (scenario_state )
252252 # Important that the choices stay in the same order as the
253253 # available actions as we'll use the selected index later to
254254 # map to the corresponding action
@@ -261,12 +261,11 @@ def top_level_choose_action(self,
261261 positive_icl_examples = []
262262 negative_icl_examples = []
263263 incontext_settings = kwargs .get ("incontext" , {})
264- if not self .baseline and alignment_target is not None :
265- kdma_values = alignment_target .kdma_values
266264
265+ if not baseline and alignment_target is not None :
266+ kdma_values = alignment_target .kdma_values
267267 if len (kdma_values ) != 1 :
268268 raise RuntimeError ("This ADM assumes a single KDMA target, aborting!" )
269-
270269 kdma_value = kdma_values [0 ]
271270 if isinstance (kdma_value , KDMAValue ):
272271 kdma_value = kdma_value .to_dict ()
@@ -280,8 +279,8 @@ def top_level_choose_action(self,
280279 kdma_descriptions = yaml .load (f , Loader = yaml .FullLoader )
281280 name = kdma_descriptions [kdma ]['name' ]
282281
283- positive_system_prompt = self . __class__ .kdma_value_to_system_prompt (kdma , value )
284- negative_system_prompt = self . __class__ .kdma_value_to_system_prompt (kdma , negative_value )
282+ positive_system_prompt = OutlinesTransformersADM .kdma_value_to_system_prompt (kdma , value )
283+ negative_system_prompt = OutlinesTransformersADM .kdma_value_to_system_prompt (kdma , negative_value )
285284
286285 if positive_system_prompt is None :
287286 raise RuntimeError ("Couldn't find system prompt for kdma: {}, and "
@@ -291,8 +290,7 @@ def top_level_choose_action(self,
291290 "value: {}." .format (kdma , negative_value ))
292291
293292 if "incontext" in kwargs and "number" in incontext_settings and incontext_settings ["number" ] > 0 :
294- scenario_to_match = self .scenario_description_template (scenario_state )
295- prompt_to_match , _ = self ._state_to_top_level_prompt (scenario_state , available_actions )
293+ prompt_to_match , _ = OutlinesTransformersADM ._state_to_top_level_prompt (action_selection_prompt_template , scenario_state , available_actions )
296294
297295 # Create positive ICL example generators
298296 positive_target = {'kdma' : kdma , 'name' : name , 'value' : value }
@@ -301,7 +299,7 @@ def top_level_choose_action(self,
301299 # Get subset of relevant of examples
302300 positive_selected_icl_examples = positive_icl_example_generator .select_icl_examples (
303301 sys_kdma_name = kdma ,
304- scenario_description_to_match = scenario_to_match ,
302+ scenario_description_to_match = scenario_description ,
305303 prompt_to_match = prompt_to_match ,
306304 state_comparison = scenario_state
307305 )
@@ -321,7 +319,7 @@ def top_level_choose_action(self,
321319 # Get subset of relevant of examples
322320 negative_selected_icl_examples = negative_icl_example_generator .select_icl_examples (
323321 sys_kdma_name = kdma ,
324- scenario_description_to_match = scenario_to_match ,
322+ scenario_description_to_match = scenario_description ,
325323 prompt_to_match = prompt_to_match ,
326324 state_comparison = scenario_state
327325 )
@@ -331,17 +329,17 @@ def top_level_choose_action(self,
331329 {"role" : "assistant" , "content" : f'{ icl_sample ["response" ]} ' }
332330 ])
333331 else :
334- positive_system_prompt = self . baseline_system_prompt ()
332+ positive_system_prompt = baseline_system_prompt ()
335333 if num_negative_samples > 0 :
336334 raise RuntimeError ("No notion of negative samples for baseline run" )
337335 if "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
338336 raise RuntimeError ("No notion of incontext examples for baseline run" )
337+ negative_system_prompt = None # Not used in baseline
339338
340339 positive_dialogs = []
341340 for _ in range (num_positive_samples ):
342- shuffled_choices = random .sample (choices , len (choices )) if shuffle_choices else choices
343-
344- prompt = self .action_selection_prompt_template (scenario_description , shuffled_choices )
341+ shuf = random .sample (choices , len (choices )) if shuffle_choices else choices
342+ prompt = action_selection_prompt (scenario_description , shuf )
345343 dialog = [{'role' : 'system' , 'content' : positive_system_prompt }]
346344 dialog .extend (positive_icl_examples )
347345 dialog .append ({'role' : 'user' , 'content' : prompt })
@@ -350,15 +348,54 @@ def top_level_choose_action(self,
350348
351349 negative_dialogs = []
352350 for _ in range (num_negative_samples ):
353- shuffled_choices = random .sample (choices , len (choices )) if shuffle_choices else choices
354-
355- prompt = self .action_selection_prompt_template (scenario_description , shuffled_choices )
351+ shuf = random .sample (choices , len (choices )) if shuffle_choices else choices
352+ prompt = action_selection_prompt (scenario_description , shuf )
356353 dialog = [{'role' : 'system' , 'content' : negative_system_prompt }]
357354 dialog .extend (negative_icl_examples )
358355 dialog .append ({'role' : 'user' , 'content' : prompt })
359-
360356 negative_dialogs .append (dialog )
361357
358+ return {"scenario_description" : scenario_description ,
359+ "choices" : choices ,
360+ "positive_system_prompt" : positive_system_prompt ,
361+ "negative_system_prompt" : negative_system_prompt ,
362+ "positive_dialogs" : positive_dialogs ,
363+ "negative_dialogs" : negative_dialogs }
364+
365+ def top_level_choose_action (self ,
366+ scenario_state ,
367+ available_actions ,
368+ alignment_target ,
369+ num_positive_samples = 1 ,
370+ num_negative_samples = 0 ,
371+ generator_batch_size = 5 ,
372+ kdma_descriptions_map = 'align_system/prompt_engineering/kdma_descriptions.yml' ,
373+ reasoning_max_length = 512 ,
374+ generator_seed = - 1 ,
375+ max_generator_tokens = - 1 ,
376+ shuffle_choices = True ,
377+ ** kwargs ):
378+ if self .baseline and num_negative_samples > 0 :
379+ raise RuntimeError ("No notion of negative samples for baseline run" )
380+ if self .baseline and "incontext" in kwargs and kwargs ["incontext" ]["number" ] > 0 :
381+ raise RuntimeError ("No notion of incontext examples for baseline run" )
382+
383+ dialogs_data = OutlinesTransformersADM .get_dialogs (
384+ scenario_state ,
385+ available_actions ,
386+ alignment_target ,
387+ num_positive_samples ,
388+ num_negative_samples ,
389+ kdma_descriptions_map ,
390+ shuffle_choices ,
391+ baseline = self .baseline ,
392+ scenario_description_template = self .scenario_description_template ,
393+ action_selection_prompt_template = self .action_selection_prompt_template ,
394+ )
395+ choices = dialogs_data ["choices" ]
396+ positive_dialogs = dialogs_data ["positive_dialogs" ]
397+ negative_dialogs = dialogs_data ["negative_dialogs" ]
398+
362399 # Need to set the whitespace_pattern to prevent the state
363400 # machine from looping indefinitely in some cases, see:
364401 # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
@@ -367,12 +404,14 @@ def top_level_choose_action(self,
367404 action_choice_json_schema (json .dumps (choices ), reasoning_max_length ),
368405 sampler = self .sampler ,
369406 whitespace_pattern = r"[ ]?" )
407+
408+ if max_generator_tokens >= 0 :
409+ generator = partial (generator , max_tokens = max_generator_tokens )
370410
371411 if generator_seed >= 0 :
372412 torch .manual_seed (generator_seed )
373413 if torch .cuda .is_available ():
374414 torch .cuda .manual_seed (generator_seed )
375- generator = partial (generator , max_tokens = MAX_GENERATOR_TOKENS )
376415
377416
378417 dialog_texts = [self .dialog_to_prompt (d ) for d in
0 commit comments