# --- align_system/algorithms/lib/__init__.py ---
from importlib import reload


def reload_all():
    """Reload the algorithm-library modules in dependency order.

    Useful when developing in an interactive environment (e.g. a Jupyter
    kernel) without having to restart the kernel.
    """
    from align_system.algorithms.lib import util
    from align_system.algorithms.lib import language_model as lm
    from align_system.algorithms.lib.chat import dialog_tokenizer as dt
    from align_system.algorithms.lib.chat import chat_language_model as clm
    from align_system.algorithms import llama_2_kdma_predicting_adm as kpa

    # Reload in dependency order so each module picks up the freshly
    # reloaded versions of the modules it imports.
    for module in (util, lm, dt, clm, kpa):
        reload(module)


# --- align_system/algorithms/lib/chat/chat_language_model.py ---
from typing import List, Dict, Optional, Callable, Union, TextIO

from align_system.algorithms.lib.language_model import LanguageModel
from align_system.algorithms.lib.chat.dialog_tokenizer import dialog_tokenizers
from align_system.algorithms.lib.util import (
    read_template,
    format_template,
    dialog_from_string,
    dialog_to_string,
)


class ChatLanguageModel(LanguageModel):
    """Language model wrapper that communicates via multi-turn chat dialogs."""

    def __init__(self, model: LanguageModel, tokenizer: Callable[[str], List[str]]):
        """
        Initializes the chat language model.

        :param model: Pretrained language model.
        :param tokenizer: Tokenizer function.
        :raises AssertionError: If no dialog tokenizer is registered for the
            model's ``name_or_path``.
        """
        super().__init__(model, tokenizer)
        model_name = model.name_or_path
        assert model_name in dialog_tokenizers, f'No dialog tokenizer found for model {model_name}'
        self.dialog_tokenizer = dialog_tokenizers[model_name](tokenizer)

    def generate_responses(self,
                           dialogs: List[Dict[str, str]],
                           log_file: Optional[TextIO] = None,
                           max_new_tokens: int = 512,
                           temperature: float = 0.6) -> List[str]:
        """
        Generates one assistant response per dialog.

        If a dialog already ends with an ``assistant`` turn, that turn's
        content is used as a generation prefix (the model continues it) and
        is prepended to the returned text.

        :param dialogs: List of dialogs (each a list of role/content dicts).
        :param log_file: Optional file to log the process.
        :param max_new_tokens: Maximum number of new tokens to generate.
        :param temperature: Temperature for sampling.
        :return: Generated responses, one per input dialog.
        """
        # If logging is requested, write the dialogs into the log file
        if log_file is not None:
            log_file.write('**Dialogs:**\n')
            for i, dialog in enumerate(dialogs):
                log_file.write(f'*Dialog {i}:*\n{dialog_to_string(dialog)}\n')
            log_file.flush()

        # Split off a trailing assistant turn (if any) to use as a prefix
        # the model should continue.
        user_last_dialogs = []
        prefixes = []
        for dialog in dialogs:
            prefix = ''
            if dialog[-1]['role'] == 'assistant':
                prefix = dialog[-1]['content']
                dialog = dialog[:-1]
            user_last_dialogs.append(dialog)
            prefixes.append(prefix)

        # Tokenization step. (FIX: each token list used to be wrapped in a
        # single-element list and unwrapped again before generation, which
        # served no purpose.)
        prompt_token_lists = [
            self.dialog_tokenizer.dialog_to_tokens(dialog)
            for dialog in user_last_dialogs
        ]

        # Append the prefix tokens so the model continues from the prefix
        for prompt_tokens, prefix in zip(prompt_token_lists, prefixes):
            if len(prefix) > 0:
                prompt_tokens += self.tokenizer.encode(prefix, add_special_tokens=False)

        # Generate responses using tokens
        responses = self.generate_from_tokens(
            prompt_token_lists,
            max_new_tokens=max_new_tokens,
            temperature=temperature)
        prefixed_responses = [
            f'{prefix}{response}'
            for prefix, response in zip(prefixes, responses)
        ]

        # If logging is requested, write the generated responses into the log file
        if log_file is not None:
            log_file.write('**Generated Responses:**\n')
            for i, response in enumerate(prefixed_responses):
                log_file.write(f'*Response {i}:*\n{response}\n')
            log_file.flush()

        return prefixed_responses

    def generate_from_template(
            self,
            template_files: Union[List[str], str],
            substitution_dicts: Union[List[Dict[str, str]], Dict[str, str]],
            parse_generation_fn: Optional[Callable[[str], str]] = None,
            batch_size: int = 5,
            log_file: Optional[TextIO] = None,
            max_tokens: int = 512,
            temperature: float = 0.6,
            max_retry: int = 10,
            verbose: bool = False) -> List[str]:
        """
        Generates responses for given templates with substitutions.

        Samples whose generation fails ``parse_generation_fn`` are retried
        (up to ``max_retry`` times each) before an error is raised.

        :param template_files: Template files to use for generation.
        :param substitution_dicts: Substitution dictionaries for the templates.
        :param parse_generation_fn: Function to parse the generated responses.
        :param batch_size: Batch size for generating responses.
        :param log_file: Optional file to log the process.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Temperature for sampling.
        :param max_retry: Maximum number of attempts to generate a valid output.
        :param verbose: If True, verbose logging is enabled.
        :return: Generated (and parsed, if requested) outputs, in input order.
        :raises Exception: If a sample still fails to parse after max_retry attempts.
        """
        if isinstance(substitution_dicts, dict):
            substitution_dicts = [substitution_dicts]

        if isinstance(template_files, str):
            template_files = [template_files] * len(substitution_dicts)

        assert len(template_files) == len(substitution_dicts), 'Number of templates and substitutions do not match'

        # Create a dialog for each template/substitution pair, keyed by
        # input position so results can be returned in order.
        dialogs = {
            i: dialog_from_string(format_template(read_template(template_file), **substitutions))
            for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts))
        }

        outputs = {}
        input_counts = {}
        while len(dialogs) > 0:
            sample_ids = list(dialogs.keys())[:batch_size]
            batch = [dialogs[i] for i in sample_ids]
            generations = self.generate_responses(
                batch, log_file=log_file, max_new_tokens=max_tokens, temperature=temperature)

            # Process the generated responses
            for sample_id, generation in zip(sample_ids, generations):
                input_counts[sample_id] = input_counts.get(sample_id, 0) + 1

                # If the maximum number of retries is exceeded, throw an error
                if input_counts[sample_id] > max_retry:
                    raise Exception(f'Could not generate valid output for sample [{sample_id}]')

                if parse_generation_fn is not None:
                    # A sample is only removed from the work queue once its
                    # generation parses; otherwise it is retried next round.
                    try:
                        outputs[sample_id] = parse_generation_fn(generation)
                        del dialogs[sample_id]
                    except Exception as e:
                        if verbose:
                            print(f'Error: could not parse output for sample [{sample_id}]')
                            print(e)
                else:
                    outputs[sample_id] = generation
                    del dialogs[sample_id]

        assert len(outputs) == len(substitution_dicts), \
            'Unexpected state: number of outputs and substitutions do not match'

        return [
            outputs[i]
            for i in range(len(outputs))
        ]
# --- align_system/algorithms/lib/chat/dialog_tokenizer.py ---
from abc import abstractmethod
from typing import List, Dict


class DialogTokenizer:
    """
    Abstract base class for model-specific dialog tokenizers.
    """
    def __init__(self, tokenizer):
        """
        Initializes the dialog tokenizer.

        :param tokenizer: Pretrained (Huggingface-style) tokenizer used to
            encode message text.
        """
        self.tokenizer = tokenizer

    @abstractmethod
    def dialog_to_tokens(self, dialog_messages: List[Dict[str, str]]) -> List[int]:
        """
        Transforms a dialog to tokens.

        :param dialog_messages: List of role/content message dicts.
        :returns: List of tokens representing the dialog.
        """
        pass


class Llama2DialogTokenizer(DialogTokenizer):
    """
    Dialog tokenizer for the Llama-2 chat models.
    """

    def dialog_to_tokens(self, dialog_messages: List[Dict[str, str]]) -> List[int]:
        """
        Transforms a dialog to tokens. Llama-2 communicates using system,
        user and assistant roles.

        :param dialog_messages: List of role/content message dicts; must end
            with a 'user' message.
        :returns: List of tokens representing the dialog.
        """
        # Instruction delimiters from the Llama-2 chat format
        B_INST, E_INST = "[INST]", "[/INST]"
        # BUGFIX: the system-prompt delimiters were garbled ("<>\n" and
        # "\n<>\n\n"); the Llama-2 reference implementation uses the literal
        # <<SYS>> ... <</SYS>> markers that the chat models were trained on.
        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

        # If the role of the first message is system, fold it into the first
        # user message, as the reference implementation does.
        if dialog_messages[0]["role"] == "system":
            system_dialog = {"role": dialog_messages[1]["role"],
                             "content": B_SYS + dialog_messages[0]["content"] + E_SYS + dialog_messages[1]["content"]}
            # Update dialog to start with system_dialog followed by the rest
            dialog_messages = [system_dialog] + dialog_messages[2:]

        # Ensure strict user/assistant alternation (u/a/u/a...)
        assert all([msg["role"] == "user" for msg in dialog_messages[::2]]) and all(
            [msg["role"] == "assistant" for msg in dialog_messages[1::2]]), \
            "Model only supports 'system', 'user' and 'assistant' roles, in the sequence (s/u/a/u/a...)"

        # Each completed (user, assistant) turn is wrapped in BOS ... EOS
        dialog_tokens = []
        for prompt, answer in zip(dialog_messages[::2], dialog_messages[1::2]):
            tokenized_message = ([self.tokenizer.bos_token_id] +
                                 self.tokenizer.encode(f"{B_INST} {prompt['content'].strip()} {E_INST} {answer['content'].strip()} ",
                                                       add_special_tokens=False) +
                                 [self.tokenizer.eos_token_id])
            dialog_tokens.extend(tokenized_message)

        # Ensure the final message is from the user
        assert dialog_messages[-1]["role"] == "user", "Last message must be from the user."

        # The final user message is left open (no EOS) for the model to complete
        user_final_message_tokens = ([self.tokenizer.bos_token_id] + self.tokenizer.encode(
            f"{B_INST} {dialog_messages[-1]['content'].strip()} {E_INST}",
            add_special_tokens=False))
        dialog_tokens.extend(user_final_message_tokens)

        return dialog_tokens


# Maps Huggingface model names to dialog tokenizer implementations.
# This mapping should be updated when adding any new tokenizer classes.
dialog_tokenizers = {
    'meta-llama/Llama-2-7b-chat-hf': Llama2DialogTokenizer,
    'meta-llama/Llama-2-13b-chat-hf': Llama2DialogTokenizer,
}
# --- align_system/algorithms/lib/language_model.py ---
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Optional, TextIO


class LanguageModel:
    """
    Thin wrapper around a Huggingface causal LM and its tokenizer.
    """

    @classmethod
    def load_model(cls,
                   hf_model_name: str,
                   precision: torch.dtype = torch.float32,
                   device: str = 'cuda') -> 'LanguageModel':
        """
        Load the language model.

        :param hf_model_name: Name of the model in Huggingface.
        :param precision: Precision of the model's weights.
        :param device: Device to run the model on.
        :return: Initialized LanguageModel object.
        """
        # Load the model from Huggingface
        model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=precision)
        tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
        model = model.to(device)
        return cls(model, tokenizer)

    def __init__(self,
                 model: AutoModelForCausalLM,
                 tokenizer: AutoTokenizer) -> None:
        """
        Initializes the language model.

        :param model: Pretrained Huggingface model.
        :param tokenizer: Tokenizer from Huggingface.
        """
        self.model = model
        self.tokenizer = tokenizer

    def generate_from_tokens(self,
                             prompt_token_lists: List[List[int]],
                             log_file: Optional[TextIO] = None,
                             max_new_tokens: int = 512,
                             temperature: float = 0.6,
                             padding: str = 'left') -> List[str]:
        """
        Generates text from the given list of tokens.

        :param prompt_token_lists: List of lists of tokens to generate the text.
        :param log_file: Optional file object to log prompts/outputs.
            (FIX: was annotated ``Union[None, str, object]`` although it is
            always used as a writable file object, like the sibling methods.)
        :param max_new_tokens: Maximum number of new tokens to be generated.
        :param temperature: Temperature for probability adjustment.
        :param padding: Padding direction, either 'left' or 'right'.
        :return: Generated text, one string per prompt.
        """
        # Move to the model's device, with a leading batch dimension of 1
        prompt_token_lists = [
            torch.tensor(prompt_tokens).to(self.model.device).unsqueeze(0)
            for prompt_tokens in prompt_token_lists
        ]

        max_length = max(prompt_tokens.size(1) for prompt_tokens in prompt_token_lists)

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # BUGFIX: Llama-2 tokenizers define no pad token, so pad_token_id
            # is None and torch.nn.functional.pad(value=None) would fail.
            # Padded positions are masked out via the attention mask, so EOS
            # is a safe filler id.
            pad_token_id = self.tokenizer.eos_token_id

        # Padding function for the desired direction
        assert padding in ('left', 'right'), f"Padding must be either 'left' or 'right', got {padding}"
        pad_fn = lambda prompt_token_size: (max_length - prompt_token_size, 0) if padding == 'left' else (0, max_length - prompt_token_size)

        # Pad each sequence to the max length
        padded_prompt_token_lists = [
            torch.nn.functional.pad(prompt_tokens, pad_fn(prompt_tokens.size(1)), value=pad_token_id)
            for prompt_tokens in prompt_token_lists
        ]

        # 1 over real tokens, 0 over padding
        attention_masks = [
            torch.nn.functional.pad(torch.ones_like(prompt_tokens), pad_fn(prompt_tokens.size(1)), value=0)
            for prompt_tokens in prompt_token_lists
        ]

        position_ids = [
            torch.nn.functional.pad(torch.arange(prompt_tokens.size(1)).unsqueeze(0), pad_fn(prompt_tokens.size(1)), value=0)
            for prompt_tokens in prompt_token_lists
        ]

        # Stack the padded sequences into a single batch
        stacked_prompt_tokens = torch.cat(padded_prompt_token_lists, dim=0)
        stacked_attention_masks = torch.cat(attention_masks, dim=0)
        stacked_position_ids = torch.cat(position_ids, dim=0)

        if log_file is not None:
            prompt_texts = [
                self.tokenizer.decode(prompt_tokens.squeeze(0), skip_special_tokens=True)
                for prompt_tokens in padded_prompt_token_lists
            ]
            log_file.write('**Prompt texts:**\n')
            for i, prompt_text in enumerate(prompt_texts):
                log_file.write(f'Prompt {i}:\n{prompt_text}\n')
            log_file.flush()

        # Generate outputs for all dialogs in a batch
        # TODO ensure the batch size is not too large for the GPU
        outputs = self.model.generate(
            stacked_prompt_tokens,
            attention_mask=stacked_attention_masks,
            # position_ids=stacked_position_ids,  # TODO figure out why including the position ids breaks the model
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature
        )

        # Decode only the newly generated tokens (everything past the prompt)
        decoded_outputs = [
            self.tokenizer.decode(output_tokens[len(prompt_tokens.squeeze(0)):], skip_special_tokens=True)
            for output_tokens, prompt_tokens in zip(outputs.sequences, padded_prompt_token_lists)
        ]

        if log_file is not None:
            log_file.write('**Generated texts:**\n')
            for i, decoded_output in enumerate(decoded_outputs):
                log_file.write(f'*Generation {i}:*\n{decoded_output}\n')
            log_file.flush()

        return decoded_outputs

    def generate(self,
                 prompt_texts: List[str],
                 log_file: Optional[TextIO] = None,
                 max_new_tokens: int = 512,
                 temperature: float = 0.6) -> List[str]:
        """
        Generates text from the given list of inputs.

        :param prompt_texts: List of prompts to generate from.
        :param log_file: Optional file object to write to.
        :param max_new_tokens: Maximum number of new tokens to be generated.
        :param temperature: Temperature for probability adjustment.
        :return: Generated text, one string per prompt.
        """
        # Convert the text to tokens and generate the text
        prompt_token_lists = [self.tokenizer.encode(prompt_text) for prompt_text in prompt_texts]
        return self.generate_from_tokens(prompt_token_lists, log_file, max_new_tokens, temperature)

    def generate_with_prefixes(self,
                               prompt_texts: List[str],
                               prefixes: List[str],
                               log_file: Optional[TextIO] = None,
                               max_new_tokens: int = 512,
                               temperature: float = 0.6) -> List[str]:
        """
        Generates text from the given list of inputs with prefixes.

        :param prompt_texts: List of prompts to generate from.
        :param prefixes: List of prefixes the model should continue; each is
            prepended to the corresponding returned generation.
        :param log_file: Optional file object to write to.
        :param max_new_tokens: Maximum number of new tokens to be generated.
        :param temperature: Temperature for probability adjustment.
        :return: Prefixed generations, one string per prompt.
        """
        # Combine the inputs with prefixes and generate the text
        combined_texts = [f'{prompt}{prefix}' for prompt, prefix in zip(prompt_texts, prefixes)]
        generations = self.generate(combined_texts, log_file, max_new_tokens, temperature)
        return [f'{prefix}{generation}' for prefix, generation in zip(prefixes, generations)]
Someone with low basic knowledge might treat less urgent injuries first, might not apply an appropriate treatment, or might freeze up and not take action. + +risk_aversion: + name: Risk aversion + description: Risk aversion is a tendency to avoid uncertainty and prefer taking actions whose expected outcomes have a lower range of variation. A person with high risk aversion may prefer an action that has somewhat lower total expected value when it also has less difference between the best and worst expected outcome. A person with low risk aversion may prefer an action that has a superior best expected outcome, even if the total expected value is lower. We think this might manifest differently in the military medical domain than in the abstract decision making context. We suspect that in a military medical domain, lower risk choices may be decisions that keep open multiple possible paths, while high risk choices may be decisions that commit to a single possible path. + +protocol_focus: + name: Protocol focus + description: Protocol focus is the tendency to prioritize following a protocol, guideline, rule, or process, instead of considering specific context factors as reasons to make exceptions to the protocol or rule. A high protocol focus person will stick to the rules or guidelines, even when it seems like that may waste time, waste effort, or cause unhappiness. A low protocol focus person will make exceptions or break rules when they believe that the rule was not intended to apply to that particular situation, or when they believe that following the rule will produce an inferior result. + +fairness: + name: Fairness + description: Fairness is treating different people in the same way, such that no person is valued more based on their group membership, identity, or prior actions. 
This can manifest as favoring those who you have a positive personal relationship with, those who are part of a group you also belong to, those who have high social status, or those who were injured while being either virtuous or innocent. A person with high fairness will provide equal assistance to everyone who needs it. A person with low fairness will prioritize providing assistance to those that they care more about, who are more similar to them, or who may be more useful in some way. The concept of fairness may also describe a preference for equal outcomes versus equal assistance. + +utilitarianism: + name: Utilitarianism + description: In this context, utilitarianism is the priority placed on saving the most people or contributing to the best medical outcome for the most people. A person with high utilitarianism will try to save the most people, which under conditions of limited resources may mean withholding or rationing care to patients for whom treatment has a low probability of improving outcome. A person with low utilitarianism may instead allocate limited resources based on other considerations, such as their personal feelings toward patients or trying to demonstrate other values or virtues (such as kindness, fairness, respect, or loyalty). The priority placed on expected utility of the overall outcome. High utilitarianism also suggests a willingness to harm some individuals, through both omission and commission, as part of a tradeoff to increase the overall benefit. It may also consider the likely trajectory of a patient post-treatment and thus deprioritize treating those who are at high risk of dying soon due to other factors. 
\ No newline at end of file diff --git a/align_system/algorithms/lib/templates/pred_kdma_O.txt b/align_system/algorithms/lib/templates/pred_kdma_O.txt new file mode 100644 index 00000000..e7742b2e --- /dev/null +++ b/align_system/algorithms/lib/templates/pred_kdma_O.txt @@ -0,0 +1,30 @@ +=== system +Given the following scenario, response, and predicted outcome your job is to assess, on a scale from 0 to 10, how well the response characterizes the decision making behavior of someone who is acting with high {{kdma}}. + +{{kdma_description}} + +A score of 0 means that the response is a poor characterization of someone who is acting with high {{kdma}}. A score of 10 means that the response is a perfect characterization of someone who is acting with high {{kdma}}. + +Ensure your response strictly adheres to the following JSON format +{ + "score": int score +} +Provide no other information or text in your response. Your response ends when you reach the end of the JSON object. +=== user +Consider the following scenario: +``` +{{scenario}} +{{probe}} +``` +Review this response to the scenario and consider the corresponding predicted outcome: +``` +Response: +{{choice}} + +Predicted Outcome: +{{outcome}} +``` +Generate a JSON object that indicates how well the response characterizes the decision making behavior of someone who is acting with high {{kdma}}. Your response ends when you reach the end of the JSON object. 
+=== assistant +{ + "score": \ No newline at end of file diff --git a/align_system/algorithms/lib/templates/pred_kdma_RO.txt b/align_system/algorithms/lib/templates/pred_kdma_RO.txt new file mode 100644 index 00000000..5893afec --- /dev/null +++ b/align_system/algorithms/lib/templates/pred_kdma_RO.txt @@ -0,0 +1,31 @@ +=== system +Given the following scenario, response, and predicted outcome your job is to assess, on a scale from 0 to 10, how well the response characterizes the decision making behavior of someone who is acting with high {{kdma}}. Provide one sentence of reasoning. + +{{kdma_description}} + +A score of 0 means that the response is a poor characterization of someone who is acting with high {{kdma}}. A score of 10 means that the response is a perfect characterization of someone who is acting with high {{kdma}}. + +Ensure your response strictly adheres to the following JSON format +{ + "reasoning": "one-sentence reasoning explaining your score", + "score": int score +} +Provide no other information or text in your response. Your response ends when you reach the end of the JSON object. +=== user +Consider the following scenario: +``` +{{scenario}} +{{probe}} +``` +Review this response to the scenario and consider the corresponding predicted outcome: +``` +Response: +{{choice}} + +Predicted Outcome: +{{outcome}} +``` +Generate a JSON object that provides one sentence of reasoning and indicates how well the response characterizes the decision making behavior of someone who is acting with high {{kdma}}. Your response ends when you reach the end of the JSON object. 
# --- align_system/algorithms/lib/util.py ---
import re
import os
from typing import List, Dict


def dialog_from_string(dialog_string: str) -> List[Dict[str, str]]:
    """
    Transforms the dialog in string format to a list of dictionary format.

    Role changes are signalled by marker lines ('=== system', '=== user',
    '=== assistant'); all other lines accumulate into the current message.

    :param dialog_string: Dialog in string format.
    :return: Dialog as a list of {'role': ..., 'content': ...} dicts.
    """
    # Dictionary to map string markers to role names
    dialog_markers = {
        '=== system': 'system',
        '=== user': 'user',
        '=== assistant': 'assistant',
    }
    dialog = []
    lines = dialog_string.split('\n')

    current_role = ''
    current_content = ''
    for line in lines:
        if line.strip() in dialog_markers:  # A line indicating a role change
            if current_role and current_content:  # Save the previous role's dialog
                dialog.append({
                    'role': current_role,
                    'content': current_content.strip()
                })
            current_role = dialog_markers[line.strip()]  # Set the new role
            current_content = ''
        else:  # Continue appending content while the role hasn't changed
            current_content += f'{line}\n'
    # Append the last piece of dialog
    if current_role and current_content:
        dialog.append({
            'role': current_role,
            'content': current_content.strip()
        })
    return dialog


def dialog_to_string(dialog: List[Dict[str, str]]) -> str:
    """
    Transforms the dialog in list of dictionary format to string format.

    :param dialog: Dialog in list of dictionary format.
    :return: Dialog in string format.
    """
    output = ''
    for dialog_piece in dialog:
        role = dialog_piece['role']
        content = dialog_piece['content']
        output += f"=== {role}\n"
        output += f"{content}\n"

    return output


def format_template(template: str, **substitutions: str) -> str:
    """
    Replaces placeholders in a template with provided substitutions.

    :param template: The template with placeholders indicated as {{placeholder}}.
    :param substitutions: The substitutions to replace in the template.
    :return: The template with all placeholders substituted.
    :raises Exception: If a substitution key is absent from the template, or
        a placeholder is left unsubstituted.
    """
    for key, value in substitutions.items():
        key = '{{%s}}' % key
        if key not in template:
            raise Exception(f'Could not find key {key} in template')
        template = template.replace(key, value)

    # Ensure there are no remaining {{ }} placeholders
    matches = re.findall(r'{{.*?}}', template)
    if len(matches) > 0:
        # FIX: corrected typo in the error message ("Unsubstituited")
        raise Exception(f'Unsubstituted key(s) in template: {matches}')

    return template


def read_template(template_file_name: str, template_dir: str = 'templates') -> str:
    """
    Reads a template file from the package's template directory.

    :param template_file_name: File name of the template.
    :param template_dir: Directory (relative to this file) holding templates.
    :return: The template file's contents.
    """
    current_directory = os.path.dirname(os.path.abspath(__file__))
    full_path = os.path.join(current_directory, template_dir, template_file_name)

    with open(full_path, 'r') as template_file:
        template = template_file.read()

    return template
# --- align_system/algorithms/llama_2_kdma_predicting_adm.py ---
import json
import yaml
import os
from typing import Union, List, Dict, Tuple, Optional, TextIO

from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel


class Llama2KDMAPredictingADM(ChatLanguageModel):
    """Chat-based ADM that predicts outcomes and KDMA scores for choices."""

    def predict_outcomes(self,
                         scenario_text: str,
                         probe_text: str,
                         choices: List[str],
                         log_file: Optional[TextIO] = None,
                         max_tokens: int = 512,
                         temperature: float = 0.6,
                         outcome_template_file: str = 'pred_outcome.txt') -> List[str]:
        """
        Predicts outcomes for given scenario, probe and choices.

        :param scenario_text: Scenario text.
        :param probe_text: Probe text.
        :param choices: Choices text.
        :param log_file: Optional log file.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Temperature for sampling.
        :param outcome_template_file: Template file for outcomes.
        :return: List of generated predictions, one per choice.
        """
        return self.generate_from_template(
            outcome_template_file,
            [
                {
                    'scenario': scenario_text,
                    'probe': probe_text,
                    'choice': choice,
                }
                for choice in choices
            ],
            log_file=log_file,
            max_tokens=max_tokens,
            temperature=temperature
        )

    def predict_kdma_scores(self,
                            scenario_text: str,
                            probe_text: str,
                            choice_texts: List[str],
                            predicted_outcomes: Optional[List[str]] = None,
                            generate_reasoning: bool = True,
                            log_file: Optional[TextIO] = None,
                            max_new_tokens: int = 512,
                            temperature: float = 0.6,
                            kdma_template_file: str = 'pred_kdma_RO.txt',
                            kdma_descriptions_file: str = 'lib/templates/bbn_kdma_descriptions.yml') -> Union[List[Dict[str, float]], Tuple[List[Dict[str, float]], List[Dict[str, str]]]]:
        """
        Predicts KDMA scores for each choice text under the given scenario and probe.

        :param scenario_text: Scenario text.
        :param probe_text: Probe text.
        :param choice_texts: Choices text.
        :param predicted_outcomes: Predicted outcomes (optional, one per choice).
        :param generate_reasoning: Flag to generate reasoning.
        :param log_file: Optional log file.
        :param max_new_tokens: Maximum number of new tokens to generate.
        :param temperature: Temperature for sampling.
        :param kdma_template_file: Template file for KDMA prediction.
        :param kdma_descriptions_file: File with KDMA names/descriptions
            (path relative to this module's directory).
        :return: KDMA predictions. If generate_reasoning is True, returns
            (predictions, reasonings).
        """
        choice_ids = [f'choice_{i}' for i in range(len(choice_texts))]
        substitutions = []
        info = []

        # kdma_descriptions_file is resolved relative to this module
        relative_dir = os.path.dirname(__file__)
        kdma_descriptions_file_path = os.path.join(relative_dir, kdma_descriptions_file)

        with open(kdma_descriptions_file_path, 'r') as f:
            kdma_descriptions = yaml.load(f, Loader=yaml.FullLoader)

        if predicted_outcomes is None:
            predicted_outcomes = [None] * len(choice_texts)

        # One template substitution (and one generation) per (choice, KDMA) pair
        for choice_id, choice, outcome in zip(choice_ids, choice_texts, predicted_outcomes):
            for kdma, kdma_info in kdma_descriptions.items():
                substitution = {
                    'kdma': kdma_info['name'],
                    'kdma_description': kdma_info['description'],
                    'scenario': scenario_text,
                    'probe': probe_text,
                    'choice': choice,
                }

                if outcome is not None:
                    substitution['outcome'] = outcome

                substitutions.append(substitution)
                info.append((choice_id, kdma))

        def parse_kdma_score_response(response: str) -> Dict[str, Union[float, str]]:
            """
            Parses a KDMA score response (JSON, or a bare leading number).

            :param response: Response to parse.
            :return: Dict with 'score' (and 'reasoning' if generate_reasoning).
            """
            if generate_reasoning:
                start_idx = response.find('{')
                end_idx = response.rfind('}')
                response_json = json.loads(response[start_idx:end_idx+1])
                assert 'score' in response_json, 'score not found in response'
                assert 'reasoning' in response_json, 'reasoning not found in response'
            else:
                # Find the first numeric character
                char = None
                for c in response:
                    if c.isnumeric():
                        char = c
                        break
                assert char is not None, 'Could not find numeric character in response'
                # BUGFIX: float() was previously applied to the entire
                # remainder of the response, which raises whenever any text
                # follows the number (e.g. "7 out of 10"). Extract just the
                # leading number instead.
                num = ''
                for c in response[response.find(char):]:
                    if c.isnumeric() or (c == '.' and '.' not in num):
                        num += c
                    else:
                        break
                response_json = {
                    'score': float(num)
                }
            return response_json

        generations = self.generate_from_template(
            kdma_template_file,
            substitutions,
            parse_kdma_score_response,
            log_file=log_file,
            max_tokens=max_new_tokens,
            temperature=temperature,
        )

        # Regroup the flat (choice, KDMA) generations by choice
        predicted_kdmas = {}
        reasonings = {}
        for (choice_id, kdma), generation in zip(info, generations):
            predicted_kdmas.setdefault(choice_id, {})[kdma] = generation['score']
            if generate_reasoning:
                reasonings.setdefault(choice_id, {})[kdma] = generation['reasoning']

        predicted_kdmas = [
            predicted_kdmas[choice_id]
            for choice_id in choice_ids
        ]

        if generate_reasoning:
            reasonings = [
                reasonings[choice_id]
                for choice_id in choice_ids
            ]
            return predicted_kdmas, reasonings
        else:
            return predicted_kdmas
AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision) - self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model) - - self.model = self.model.to(self.device) + def load_model(self, model=None, tokenizer=None): + assert (model is None) == (tokenizer is None), "model and tokenizer must both be None or both be not None." + if model is not None: + print('Loading model and tokenizer from provided objects.') + self.model = model + self.tokenizer = tokenizer + else: + print('Loading model:', self.hf_model) + self.model = AutoModelForCausalLM.from_pretrained(self.hf_model, torch_dtype=self.precision) + self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model) + self.model = self.model.to(self.device) def get_character_ids(self, character_str): diff --git a/align_system/tests/test_chat_language_model.py b/align_system/tests/test_chat_language_model.py new file mode 100644 index 00000000..aa81e701 --- /dev/null +++ b/align_system/tests/test_chat_language_model.py @@ -0,0 +1,39 @@ +import pytest + +from align_system.algorithms.lib.chat.chat_language_model import ChatLanguageModel + +MODEL_TO_TEST = 'meta-llama/Llama-2-7b-chat-hf' + +@pytest.fixture(scope="module") +def chat_language_model(): + # Load the model once for all tests that use this fixture + return ChatLanguageModel.load_model(MODEL_TO_TEST) + + +def test_generate_responses(chat_language_model): + dialogs = [ + [ + {'role': 'system', 'content': 'speak like a pirate'}, + {'role': 'user', 'content': 'hello'}, + ], + [ + {'role': 'system', 'content': 'speak like a pirate'}, + {'role': 'user', 'content': 'hello'}, + {'role': 'assistant', 'content': 'What if you'}, + ], + [ + {'role': 'system', 'content': 'speak like a pirate'}, + {'role': 'user', 'content': 'hello'}, + {'role': 'assistant', 'content': 'What if you'}, + ] + ] + + responses = chat_language_model.generate_responses(dialogs, max_new_tokens=512, temperature=0.0001) + + assert type(responses) is list + assert len(responses) 
== len(dialogs)
+    assert type(responses[0]) is str
+    assert responses[1].startswith(dialogs[1][-1]['content'])
+
+
+
diff --git a/align_system/tests/test_language_model.py b/align_system/tests/test_language_model.py
new file mode 100644
index 00000000..df068b2b
--- /dev/null
+++ b/align_system/tests/test_language_model.py
@@ -0,0 +1,54 @@
+import pytest
+import torch
+
+from align_system.algorithms.lib.language_model import LanguageModel
+
+MODEL_TO_TEST = 'gpt2' # Use a smaller model for testing
+
+@pytest.fixture(scope="module")
+def language_model():
+    # Load the model once for all tests that use this fixture
+    return LanguageModel.load_model(MODEL_TO_TEST, device='cpu')
+
+def test_load_model(language_model):
+    assert language_model.model.dtype == torch.float32
+    assert language_model.model.device.type == 'cpu'
+
+
+def test_generate_from_tokens(language_model):
+    tokens = [
+        [9246, 9703, 9246, 9703],
+        [1681, 146, 1681, 146, 1681],
+    ]
+
+    generations = language_model.generate_from_tokens(tokens, max_new_tokens=1, temperature=0)
+
+    assert generations == [
+        'cat',
+        '\n'
+    ]
+
+def test_generate(language_model):
+    prompts = [
+        'catdogcatdog',
+        'ABCABCABCABCABC',
+    ]
+    generations = language_model.generate(prompts, max_new_tokens=1, temperature=0)
+    assert generations == [
+        'cat',
+        'ABC',
+    ]
+
+def test_generate_with_prefixes(language_model):
+    prompts = [
+        'catdogcatdog',
+        'ABCABCABCABCABC',
+    ]
+    prefixes = [
+        'cat',
+        'ABC',
+    ]
+    generations = language_model.generate_with_prefixes(prompts, prefixes=prefixes, max_new_tokens=1, temperature=0)
+
+    for generation, prefix in zip(generations, prefixes):
+        assert generation.startswith(prefix)
\ No newline at end of file