Skip to content

Commit ffbbcf5

Browse files
committed
Factor get_dialogs to static method in outlines_adm
1 parent 18b7cd6 commit ffbbcf5

1 file changed

Lines changed: 76 additions & 37 deletions

File tree

align_system/algorithms/outlines_adm.py

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
log = logging.getLogger(__name__)
6262
JSON_HIGHLIGHTER = JSONHighlighter()
6363

64-
MAX_GENERATOR_TOKENS = 8092
6564

6665
class OutlinesTransformersADM(ActionBasedADM):
6766
def __init__(self,
@@ -193,7 +192,8 @@ def kdma_value_to_system_prompt(kdma, value):
193192
else:
194193
return None
195194

196-
def _state_to_top_level_prompt(self, scenario_state, actions):
195+
@staticmethod
196+
def _state_to_top_level_prompt(action_selection_prompt_template, scenario_description, scenario_state, actions):
197197
"""
198198
Generate prompt dialog based on given state and actions
199199
"""
@@ -203,8 +203,7 @@ def _state_to_top_level_prompt(self, scenario_state, actions):
203203
scenario_state
204204
)
205205

206-
scenario_description = self.scenario_description_template(scenario_state)
207-
prompt = self.action_selection_prompt_template(scenario_description, choices)
206+
prompt = action_selection_prompt_template(scenario_description, choices)
208207

209208
return prompt, choices
210209

@@ -231,24 +230,25 @@ def run_in_batches(cls, inference_function, inputs, batch_size):
231230
outputs.extend(output)
232231
return outputs
233232

234-
def top_level_choose_action(self,
235-
scenario_state,
236-
available_actions,
237-
alignment_target,
238-
num_positive_samples=1,
239-
num_negative_samples=0,
240-
generator_batch_size=5,
241-
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
242-
reasoning_max_length=512,
243-
generator_seed = -1,
244-
shuffle_choices=True,
245-
**kwargs):
246-
if self.baseline and num_negative_samples > 0:
233+
@staticmethod
234+
def get_dialogs(scenario_state,
235+
available_actions,
236+
alignment_target,
237+
num_positive_samples=1,
238+
num_negative_samples=0,
239+
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
240+
shuffle_choices=True,
241+
baseline=False,
242+
scenario_description_template=scenario_state_description_1,
243+
action_selection_prompt_template=action_selection_prompt,
244+
baseline_system_prompt=baseline_system_prompt,
245+
**kwargs):
246+
if baseline and num_negative_samples > 0:
247247
raise RuntimeError("No notion of negative samples for baseline run")
248-
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
248+
if baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
249249
raise RuntimeError("No notion of incontext examples for baseline run")
250250

251-
scenario_description = self.scenario_description_template(scenario_state)
251+
scenario_description = scenario_description_template(scenario_state)
252252
# Important that the choices stay in the same order as the
253253
# available actions as we'll use the selected index later to
254254
# map to the corresponding action
@@ -261,12 +261,11 @@ def top_level_choose_action(self,
261261
positive_icl_examples = []
262262
negative_icl_examples = []
263263
incontext_settings=kwargs.get("incontext", {})
264-
if not self.baseline and alignment_target is not None:
265-
kdma_values = alignment_target.kdma_values
266264

265+
if not baseline and alignment_target is not None:
266+
kdma_values = alignment_target.kdma_values
267267
if len(kdma_values) != 1:
268268
raise RuntimeError("This ADM assumes a single KDMA target, aborting!")
269-
270269
kdma_value = kdma_values[0]
271270
if isinstance(kdma_value, KDMAValue):
272271
kdma_value = kdma_value.to_dict()
@@ -280,8 +279,8 @@ def top_level_choose_action(self,
280279
kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader)
281280
name = kdma_descriptions[kdma]['name']
282281

283-
positive_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, value)
284-
negative_system_prompt = self.__class__.kdma_value_to_system_prompt(kdma, negative_value)
282+
positive_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, value)
283+
negative_system_prompt = OutlinesTransformersADM.kdma_value_to_system_prompt(kdma, negative_value)
285284

286285
if positive_system_prompt is None:
287286
raise RuntimeError("Couldn't find system prompt for kdma: {}, and "
@@ -291,8 +290,7 @@ def top_level_choose_action(self,
291290
"value: {}.".format(kdma, negative_value))
292291

293292
if "incontext" in kwargs and "number" in incontext_settings and incontext_settings["number"] > 0:
294-
scenario_to_match = self.scenario_description_template(scenario_state)
295-
prompt_to_match, _ = self._state_to_top_level_prompt(scenario_state, available_actions)
293+
prompt_to_match, _ = OutlinesTransformersADM._state_to_top_level_prompt(action_selection_prompt_template, scenario_state, available_actions)
296294

297295
# Create positive ICL example generators
298296
positive_target = {'kdma': kdma, 'name': name, 'value': value}
@@ -301,7 +299,7 @@ def top_level_choose_action(self,
301299
# Get subset of relevant of examples
302300
positive_selected_icl_examples = positive_icl_example_generator.select_icl_examples(
303301
sys_kdma_name=kdma,
304-
scenario_description_to_match=scenario_to_match,
302+
scenario_description_to_match=scenario_description,
305303
prompt_to_match=prompt_to_match,
306304
state_comparison=scenario_state
307305
)
@@ -321,7 +319,7 @@ def top_level_choose_action(self,
321319
# Get subset of relevant of examples
322320
negative_selected_icl_examples = negative_icl_example_generator.select_icl_examples(
323321
sys_kdma_name=kdma,
324-
scenario_description_to_match=scenario_to_match,
322+
scenario_description_to_match=scenario_description,
325323
prompt_to_match=prompt_to_match,
326324
state_comparison=scenario_state
327325
)
@@ -331,17 +329,17 @@ def top_level_choose_action(self,
331329
{"role": "assistant", "content": f'{icl_sample["response"]}'}
332330
])
333331
else:
334-
positive_system_prompt = self.baseline_system_prompt()
332+
positive_system_prompt = baseline_system_prompt()
335333
if num_negative_samples > 0:
336334
raise RuntimeError("No notion of negative samples for baseline run")
337335
if "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
338336
raise RuntimeError("No notion of incontext examples for baseline run")
337+
negative_system_prompt = None # Not used in baseline
339338

340339
positive_dialogs = []
341340
for _ in range(num_positive_samples):
342-
shuffled_choices = random.sample(choices, len(choices)) if shuffle_choices else choices
343-
344-
prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices)
341+
shuf = random.sample(choices, len(choices)) if shuffle_choices else choices
342+
prompt = action_selection_prompt(scenario_description, shuf)
345343
dialog = [{'role': 'system', 'content': positive_system_prompt}]
346344
dialog.extend(positive_icl_examples)
347345
dialog.append({'role': 'user', 'content': prompt})
@@ -350,15 +348,54 @@ def top_level_choose_action(self,
350348

351349
negative_dialogs = []
352350
for _ in range(num_negative_samples):
353-
shuffled_choices = random.sample(choices, len(choices)) if shuffle_choices else choices
354-
355-
prompt = self.action_selection_prompt_template(scenario_description, shuffled_choices)
351+
shuf = random.sample(choices, len(choices)) if shuffle_choices else choices
352+
prompt = action_selection_prompt(scenario_description, shuf)
356353
dialog = [{'role': 'system', 'content': negative_system_prompt}]
357354
dialog.extend(negative_icl_examples)
358355
dialog.append({'role': 'user', 'content': prompt})
359-
360356
negative_dialogs.append(dialog)
361357

358+
return {"scenario_description": scenario_description,
359+
"choices": choices,
360+
"positive_system_prompt": positive_system_prompt,
361+
"negative_system_prompt": negative_system_prompt,
362+
"positive_dialogs": positive_dialogs,
363+
"negative_dialogs": negative_dialogs}
364+
365+
def top_level_choose_action(self,
366+
scenario_state,
367+
available_actions,
368+
alignment_target,
369+
num_positive_samples=1,
370+
num_negative_samples=0,
371+
generator_batch_size=5,
372+
kdma_descriptions_map='align_system/prompt_engineering/kdma_descriptions.yml',
373+
reasoning_max_length=512,
374+
generator_seed=-1,
375+
max_generator_tokens=-1,
376+
shuffle_choices=True,
377+
**kwargs):
378+
if self.baseline and num_negative_samples > 0:
379+
raise RuntimeError("No notion of negative samples for baseline run")
380+
if self.baseline and "incontext" in kwargs and kwargs["incontext"]["number"] > 0:
381+
raise RuntimeError("No notion of incontext examples for baseline run")
382+
383+
dialogs_data = OutlinesTransformersADM.get_dialogs(
384+
scenario_state,
385+
available_actions,
386+
alignment_target,
387+
num_positive_samples,
388+
num_negative_samples,
389+
kdma_descriptions_map,
390+
shuffle_choices,
391+
baseline=self.baseline,
392+
scenario_description_template=self.scenario_description_template,
393+
action_selection_prompt_template=self.action_selection_prompt_template,
394+
)
395+
choices = dialogs_data["choices"]
396+
positive_dialogs = dialogs_data["positive_dialogs"]
397+
negative_dialogs = dialogs_data["negative_dialogs"]
398+
362399
# Need to set the whitespace_pattern to prevent the state
363400
# machine from looping indefinitely in some cases, see:
364401
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
@@ -367,12 +404,14 @@ def top_level_choose_action(self,
367404
action_choice_json_schema(json.dumps(choices), reasoning_max_length),
368405
sampler=self.sampler,
369406
whitespace_pattern=r"[ ]?")
407+
408+
if max_generator_tokens >= 0:
409+
generator = partial(generator, max_tokens=max_generator_tokens)
370410

371411
if generator_seed >= 0:
372412
torch.manual_seed(generator_seed)
373413
if torch.cuda.is_available():
374414
torch.cuda.manual_seed(generator_seed)
375-
generator = partial(generator, max_tokens=MAX_GENERATOR_TOKENS)
376415

377416

378417
dialog_texts = [self.dialog_to_prompt(d) for d in

0 commit comments

Comments (0)