diff --git a/fastchat/serve/monitor/classify/README.md b/fastchat/serve/monitor/classify/README.md index 259957618..bdd86f369 100644 --- a/fastchat/serve/monitor/classify/README.md +++ b/fastchat/serve/monitor/classify/README.md @@ -29,6 +29,8 @@ To test your new classifier for a new category, you would have to make sure you python label.py --config config.yaml --testing ``` +If you are labeling a vision category, add the `--vision` flag to the command. This will add a new column to the input data called `image_path` that contains the path to the image corresponding to each conversation. Ensure that you update your config with the correct `image_dir` where the images are stored. + Then, add your new category bench to `tag_names` in `display_score.py`. After making sure that you also have a correctly formatted ground truth json file, you can report the performance of your classifier by running ```console python display_score.py --bench diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py index d21181829..ac84dc626 100644 --- a/fastchat/serve/monitor/classify/category.py +++ b/fastchat/serve/monitor/classify/category.py @@ -26,6 +26,20 @@ def create_category(name): return CategoryMath() elif name == "creative_writing_v0.1": return CategoryCreativeWriting() + elif name == "captioning_v0.1": + return CategoryCaptioning() + elif name == "creative_writing_vision_v0.1": + return CategoryCreativeWritingVision() + elif name == "entity_recognition_v0.1": + return CategoryEntityRecognition() + elif name == "ocr_v0.1": + return CategoryOpticalCharacterRecognition() + elif name == "humor_v0.1": + return CategoryHumor() + elif name == "homework_v0.1": + return CategoryHomework() + elif name == "diagram_v0.1": + return CategoryDiagram() raise Exception(f"Category name is incorrect: {name}") @@ -65,7 +79,7 @@ def get_score(self, judgment): def pre_process(self, prompt): conv = [{"role": "system", "content": self.sys_prompt}] - conv.append({"role": "user", "content": prompt}) + conv.append({"role": "user", "content": prompt["prompt"]}) return conv def post_process(self, judgment): @@ -92,7 +106,7 @@ def get_score(self, judgment): return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt_template.format(**args)}, @@ -126,7 +140,7 @@ def get_score(self, judgment): return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt_template.format(**args)}, @@ -163,7 +177,7 @@ def get_score(self, judgment): return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt_template.format(**args)}, @@ -174,3 +188,392 @@ def post_process(self, judgment): score = self.get_score(judgment=judgment) bool_score = bool(score == "yes") if score else False return {"creative_writing": bool_score, "score": score} + + +##################### +# Vision Categories # +##################### +class CategoryCaptioning(Category): + def __init__(self): + super().__init__() + self.name_tag = "captioning_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a captioning 
question. A captioning question asks for a general, overall description of the entire image. It must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. Examples include 'What is happening in this image?', 'Describe this picture.', 'Explain', etc. An example of a non-captioning question is 'Describe what is funny in this picture.' because it asks for a specific interpretation of the image content. \n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"captioning": bool(score == "yes") if score else False} + + +class CategoryCreativeWritingVision(Category): + def __init__(self): + super().__init__() + self.name_tag = "creative_writing_vision_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = 'You are tasked with determining whether a given VQA user prompt is asking for creative writing. Creative writing is defined as any form of writing that goes beyond standard professional, journalistic, academic, or technical literature. It typically involves imagination, originality, and expression of thoughts and emotions. Prompts which only ask to caption the image without any other requests do NOT count as creative writing. Creative writing can include, but is not limited to, the following formats:\n- Fiction (e.g., short stories, novels)\n- Poetry (e.g., sonnets, free verse)\n- Dramatic writing (e.g., screenplays, monologues, scripts)\n- Personal essays (focusing on subjective experiences or narrative storytelling)\n- Songs and lyrics\n\nCarefully analyze the user prompt and consider whether it primarily requires creative writing. Think about the following aspects:\n1. Does the prompt ask for fictional content, speculative scenarios, or the use of imagination to construct narratives?\n2. Does it encourage the expression of thoughts, emotions, or personal experiences beyond mere factual reporting or analysis?\n3. Is it asking for writing in a specific creative format (e.g., story, poem, script, etc)?\n4. Is the primary purpose of the prompt to foster creative expression or originality rather than information delivery, technical documentation, or analytical reasoning?\n5. Does the prompt request stylistic or rhetorical elements often associated with creative writing, such as metaphor, imagery, dialogue, etc?\n6. Does the prompt expect a response in natural language (e.g., sentences, paragraphs) rather than visual, mathematical, or non-linguistic output?\n\nOutput your verdict as either "yes" or "no"in the following format:\n\n[yes/no]\n. Do NOT explain.' 
+ self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall( + judgment.replace("\n", "") + .replace("[", "") + .replace("]", "") + .replace(" ", "") + .lower() + ) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + bool_score = bool(score == "yes") if score else False + return {"creative_writing": bool_score, "score": score} + + +class CategoryEntityRecognition(Category): + def __init__(self): + super().__init__() + self.name_tag = "entity_recognition_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is an entity recognition question. An entity recognition question asks for the identification of specific objects or people in the image. This does NOT include questions that ask for a general description of the image, questions that only ask for object counts, or questions that only require reading text in the image.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"entity_recognition": bool(score == "yes") if score else False} + + +import base64 +import io +from PIL import Image + + +def pil_to_base64(image_path): + image = Image.open(image_path) + buffered = io.BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + + +class CategoryOpticalCharacterRecognition(Category): + def __init__(self): + super().__init__() + self.name_tag = "ocr_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is an optical character recognition (OCR) question. An OCR question requires reading and understanding text in the image to answer. If there is some amount of text in the image and the question requires reading the text in any capacity it should be classified as Optical Character Recognition.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." 
+ self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + args = {"PROMPT": prompt["prompt"]} + base64_image = pil_to_base64(prompt["image_path"]) + if api_type == "anthropic": + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": base64.b64encode( + prompt["image_path"].content + ).decode("utf-8"), + }, + }, + {"type": "text", "text": self.prompt_template.format(**args)}, + ], + }, + ] + else: + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt_template.format(**args)}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"ocr": bool(score == "yes") if score else False} + + +class CategoryHumor(Category): + def __init__(self): + super().__init__() + self.name_tag = "humor_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a humor question. A humor question asks for a humorous or funny response based on the image or asks to understand what is funny about an image. This includes questions that ask to explain an image which is humorous, such as memes.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + args = {"PROMPT": prompt["prompt"]} + base64_image = pil_to_base64(prompt["image_path"]) + if api_type == "anthropic": + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": base64_image, + }, + }, + {"type": "text", "text": self.prompt_template.format(**args)}, + ], + }, + ] + else: + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt_template.format(**args)}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"humor": bool(score == "yes") if score else False} + + +import os + + +class CategoryHomework(Category): + def __init__(self): + super().__init__() + self.name_tag = "homework_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = """You are tasked with determining if the given image contains a homework or exam question. A homework or exam question typically contains text with a well-defined question or task which asks for a solution. 
In addition, many homework and exam questions contain multiple choice, equations, and question numbers. You may also see text referring to showing your work or providing justification. Note that documents such as resumes, business cards, records, or personal notes are NOT considered homework or exam questions; homework and exam questions explicitly ask for a solution or explanation. + +Output your verdict in the following format: +[yes/no] +. Do NOT explain.""" + self.prompt_template = "" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + base64_image = pil_to_base64(prompt["image_path"]) + + # Open the local image file in binary mode and encode it as base64 + assert os.path.exists(prompt["image_path"]) + with open(prompt["image_path"], "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode("utf-8") + if api_type == "anthropic": + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_data, + }, + }, + {"type": "text", "text": ""}, + ], + }, + ] + else: + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": ""}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"homework": bool(score == "yes") if score else False} + + +class CategoryDiagram(Category): + def __init__(self): + super().__init__() + self.name_tag = "diagram_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = """You are tasked with determining whether the given image contains a chart, diagram, or figure. Carefully examine the user prompt and consider the following aspects: +1. Does the image contain visual elements such as graphs, flowcharts, method figures, chemical structures, or other visual representations of data or concepts? +2. Does the prompt require interpreting or analyzing the flow of information, relationships between elements, or the structure of the visual representation in the image? +3. Does the prompt require spatial reasoning and understanding the layout or structure of the visual elements? +4. Note that images containing only text, tables, handwriting, or photographs without any other visual graphics is NOT considered a chart or diagram. + +Output your verdict in the following format: +[yes/no] +. 
Do NOT explain.""" + self.prompt_template = "" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt, api_type="openai"): + base64_image = pil_to_base64(prompt["image_path"]) + + # Open the local image file in binary mode and encode it as base64 + assert os.path.exists(prompt["image_path"]) + with open(prompt["image_path"], "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode("utf-8") + if api_type == "anthropic": + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_data, + }, + }, + {"type": "text", "text": ""}, + ], + }, + ] + else: + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": ""}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"diagram": bool(score == "yes") if score else False} diff --git a/fastchat/serve/monitor/classify/config.yaml b/fastchat/serve/monitor/classify/config.yaml index 315f0dccc..cc26c3ce1 100644 --- a/fastchat/serve/monitor/classify/config.yaml +++ b/fastchat/serve/monitor/classify/config.yaml @@ -14,6 +14,7 @@ task_name: model_name: null name: llama-3-70b-instruct +api_type: openai endpoints: - api_base: null api_key: null @@ -21,6 +22,8 @@ parallel: 50 temperature: 0.0 max_token: 512 +image_dir: null # directory where vision arena images are stored + max_retry: 2 retry_sleep: 10 error_output: $ERROR$ \ No newline at end of file diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py index 2d0471a1f..fe928a9ba 100644 --- a/fastchat/serve/monitor/classify/label.py +++ b/fastchat/serve/monitor/classify/label.py @@ -89,6 +89,95 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No return output +def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None): + import anthropic + + if api_dict: + api_key = api_dict["api_key"] + else: + api_key = os.environ["ANTHROPIC_API_KEY"] + + sys_msg = "" + if messages[0]["role"] == "system": + sys_msg = messages[0]["content"] + messages = messages[1:] + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + c = anthropic.Anthropic(api_key=api_key) + response = c.messages.create( + model=model, + messages=messages, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens=max_tokens, + temperature=temperature, + system=sys_msg, + ) + output = response.content[0].text + break + except anthropic.APIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + return output + + +def chat_completion_gemini( + model, messages, temperature, max_tokens, api_dict=None, image_path=None +): + import google + import google.generativeai as genai + from google.generativeai.types import HarmCategory, HarmBlockThreshold + from PIL import Image + + if api_dict: + api_key = api_dict["api_key"] + genai.configure(api_key=api_key) + else: + genai.configure(api_key=os.environ["GENAI_API_KEY"]) + + sys_msg = "" + if messages[0]["role"] == "system": + sys_msg = 
messages[0]["content"] + messages = messages[1:] + + prompt = messages[0]["content"] + if type(prompt) == list: + prompt = [prompt[0]["text"], Image.open(image_path).convert("RGB")] + + safety_settings = { + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + gemini = genai.GenerativeModel(model, system_instruction=sys_msg) + gemini.max_output_tokens = max_tokens + gemini.temperature = temperature + response = gemini.generate_content(prompt, safety_settings=safety_settings) + if response.candidates[0].finish_reason != 1: + print( + f"Gemini did not finish generating content: {response.candidates[0].finish_reason}" + ) + output = "Gemini did not finish generating content" + else: + output = response.text + break + except google.api_core.exceptions.ResourceExhausted as e: + # THIS IS A TEMPORARY FIX + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except Exception as e: + # THIS IS A TEMPORARY FIX + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + return output + + def get_answer( question: dict, model_name: str, @@ -98,6 +187,7 @@ def get_answer( api_dict: dict, categories: list, testing: bool, + api_type: str, ): if "category_tag" in question: category_tag = question["category_tag"] @@ -107,14 +197,34 @@ def get_answer( output_log = {} for category in categories: - conv = category.pre_process(question["prompt"]) - output = chat_completion_openai( - model=model_name, - messages=conv, - temperature=temperature, - max_tokens=max_tokens, - api_dict=api_dict, - ) + conv = category.pre_process(question) + if api_type == "openai": + output = chat_completion_openai( + model=model_name, + messages=conv, + temperature=temperature, + max_tokens=max_tokens, + api_dict=api_dict, + ) + elif api_type == "anthropic": + output = chat_completion_anthropic( + model=model_name, + messages=conv, + temperature=temperature, + max_tokens=max_tokens, + api_dict=api_dict, + ) + elif api_type == "gemini": + output = chat_completion_gemini( + model=model_name, + messages=conv, + temperature=temperature, + max_tokens=max_tokens, + api_dict=api_dict, + image_path=question.get("image_path"), + ) + else: + raise ValueError(f"api_type {api_type} not supported") # Dump answers category_tag[category.name_tag] = category.post_process(output) @@ -169,6 +279,7 @@ def find_required_tasks(row): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--testing", action="store_true") + parser.add_argument("--vision", action="store_true") args = parser.parse_args() enter = input( @@ -199,6 +310,15 @@ def find_required_tasks(row): assert len(input_data) == len(input_data.uid.unique()) print(f"{len(input_data)}# of input data just loaded") + if args.vision: + old_len = len(input_data) + input_data["image_hash"] = input_data.conversation_a.map( + lambda convo: convo[0]["content"][1][0] + ) + input_data["image_path"] = input_data.image_hash.map( + lambda x: f"{config['image_dir']}/{x}.png" + ) + if config["cache_file"]: print("loading cache data") with open(config["cache_file"], "rb") as f: @@ -246,9 +366,18 @@ def find_required_tasks(row): f"{name}: 
{len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}" ) - not_labeled["prompt"] = not_labeled.conversation_a.map( - lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) - ) + if args.vision: + not_labeled["prompt"] = not_labeled.conversation_a.map( + lambda convo: "\n".join( + [convo[i]["content"][0] for i in range(0, len(convo), 2)] + ) + ) + else: + not_labeled["prompt"] = not_labeled.conversation_a.map( + lambda convo: "\n".join( + [convo[i]["content"] for i in range(0, len(convo), 2)] + ) + ) not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500]) with concurrent.futures.ThreadPoolExecutor( @@ -270,6 +399,7 @@ def find_required_tasks(row): if category.name_tag in row["required_tasks"] ], args.testing, + config["api_type"], ) futures.append(future) for future in tqdm.tqdm( diff --git a/fastchat/serve/monitor/classify/vision_config.yaml b/fastchat/serve/monitor/classify/vision_config.yaml new file mode 100644 index 000000000..0002ff654 --- /dev/null +++ b/fastchat/serve/monitor/classify/vision_config.yaml @@ -0,0 +1,34 @@ +# Yaml config file for category classification + +input_file: null # json +cache_file: null # json +output_file: null # json line + +convert_to_json: True + +task_name: + - captioning_v0.1 + - homework_v0.1 + - ocr_v0.1 + - humor_v0.1 + - entity_recognition_v0.1 + - creative_writing_vision_v0.1 + - diagram_v0.1 + + +model_name: null +name: gemini-1.5-flash +api_type: gemini +endpoints: + - api_base: null + api_key: null + +parallel: 50 +temperature: 0.0 +max_token: 512 + +image_dir: null # directory where vision arena images are stored + +max_retry: 2 +retry_sleep: 10 +error_output: $ERROR$ \ No newline at end of file diff --git a/fastchat/serve/monitor/code_tagger.py b/fastchat/serve/monitor/code_tagger.py deleted file mode 100644 index 12eeaed4b..000000000 --- a/fastchat/serve/monitor/code_tagger.py +++ /dev/null @@ -1,180 +0,0 @@ -import re -import json -import argparse -import multiprocessing as mp - -import nltk -from tqdm import tqdm -from nltk.tokenize import word_tokenize - - -def is_code_conversation(text: str) -> tuple[bool, list[str]]: - """Check if the text is a code conversation""" - - if "```plaintext" in text: - lines = text.split("\n") - line1_idx = [idx for idx, line in enumerate(lines) if "```plaintext" in line][0] - line2_idx = [ - line1_idx + 1 + idx - for idx, line in enumerate(lines) - if "```" in line[line1_idx + 1 :] - ] - if line2_idx: - line2_idx = line2_idx[0] - text = "\n".join(lines[:line1_idx]) + "\n".join(lines[line2_idx + 1 :]) - else: - text = "\n".join(lines[:line1_idx]) - return is_code_conversation(text) - - if "```markdown" in text: - otext = text - lines = text.split("\n") - line1_idx = [idx for idx, line in enumerate(lines) if "```markdown" in line][0] - line2_idx = [ - line1_idx + 1 + idx - for idx, line in enumerate(lines) - if "```" in line[line1_idx + 1 :] - ] - if line2_idx: - line2_idx = line2_idx[0] - text = "\n".join(lines[:line1_idx]) + "\n".join(lines[line2_idx + 1 :]) - else: - text = "\n".join(lines[:line1_idx]) - return is_code_conversation(text) - - if "ascii art" in text.lower(): - return False, [] - - # 1. Check for code formatting - if re.search(r"```", text): - return True, ["backticks"] - - # Tokenize the text - tokens = word_tokenize(text) - tokens = [token.lower() for token in tokens] - - # 2. 
Check for programming concepts - concepts = ["git", "github", "pull request", "dataframe", "nginx", "pip"] - if any(concept in tokens for concept in concepts): - matched_concepts = list(set(tokens).intersection(set(concepts))) - return True, matched_concepts - - # 3. Check for programming language name - languages = [ - "python", - "c++", - "cpp", - "java", - "javascript", - "typescript", - "html", - "css", - "sql", - "bash", - "powershell", - "matlab", - "golang", - "linux", - "ubuntu", - ] - if any(language in tokens for language in languages): - matched_languages = list(set(tokens).intersection(set(languages))) - return True, matched_languages - - # 4. Programming concept substrings - strings = [ - "import pandas", - "import numpy", - "import torch", - "jax", - "tensorflow", - "pytorch", - "keras", - "scikit-learn", - "sklearn", - " apt-get ", - ] - found_array = [string in text for string in strings] - if any(found_array): - matched_strings = [ - string for string, found in zip(strings, found_array) if found - ] - return True, matched_strings - - # 5. Programming concept regexes - regexes = [ - r"from \w+ import \w+", - r"conda install \w+", - r"pip install -r \w+", - r"conda install -c \w+ \w+", - r"#include <\w+>", - r"import \w+ as \w+", - r"#include \"\w+\.h\"", - ] - found_array = [re.search(regex, text) for regex in regexes] - if any(found_array): - matched_regexes = [regex for regex, found in zip(regexes, found_array) if found] - return True, matched_regexes - - return False, [] - - -def check_code_conv(conv) -> tuple[bool, list[str]]: - """Check if the conversation is a code conversation""" - for _, msg in enumerate(conv): - content = msg["content"] - if not isinstance(content, str): - continue - is_code_conv_res = is_code_conversation(content) - if is_code_conv_res[0]: - return is_code_conv_res - return False, [] - - -def check_conv_row(conv_row): - check_a, code_a = check_code_conv(conv_row["conversation_a"]) - check_b, code_b = check_code_conv(conv_row["conversation_b"]) - - return check_a or check_b, code_a + code_b - - -def process_battle_file(battle_file_path: str, n_cpus: int): - with open(battle_file_path, "r") as f: - data = json.load(f) - - with mp.Pool(n_cpus) as pool: - tagged_data = list(tqdm(pool.imap(check_conv_row, data), total=len(data))) - - output_data = [row for row, (is_code, _) in zip(data, tagged_data) if is_code] - - return output_data - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--clean-battle-file", type=str) - parser.add_argument("--output-clean-battle-file", type=str, default=None) - parser.add_argument("--n-cpus", type=int, default=-1) - - args = parser.parse_args() - - if args.output_clean_battle_file is None: - args.output_clean_battle_file = args.clean_battle_file - - if args.n_cpus == -1: - args.n_cpus = mp.cpu_count() - - print( - f"Processing {args.clean_battle_file} and saving to {args.output_clean_battle_file} with {args.n_cpus} cpus" - ) - - output_data = process_battle_file(args.clean_battle_file, args.n_cpus) - - with open(args.output_clean_battle_file, "w") as f: - json.dump(output_data, f, indent=4) - - print(f"Total code conversations: {len(output_data)}") - print("Done!") - - with open(args.output_clean_battle_file, "r") as f: - data = json.load(f) diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py index c07ee4669..df6315445 100644 --- a/fastchat/serve/monitor/monitor.py +++ b/fastchat/serve/monitor/monitor.py @@ -45,6 +45,7 @@ k2c[k + "_style_control"] = v 
+ "_style_control" key_to_category_name = k2c + notebook_url = ( "https://colab.research.google.com/drive/1KdwokPjirkTmpO_P1WByFNFiqxWQquwH" ) @@ -442,8 +443,11 @@ def build_arena_tab( for k in key_to_category_name.keys(): if k not in elo_results: continue - arena_dfs[key_to_category_name[k]] = elo_results[k]["leaderboard_table_df"] - category_elo_results[key_to_category_name[k]] = elo_results[k] + category_name = key_to_category_name[k.replace("_style_control", "")] + if "_style_control" in k: + category_name = f"{category_name} w/ Style Control" + arena_dfs[category_name] = elo_results[k]["leaderboard_table_df"] + category_elo_results[category_name] = elo_results[k] arena_df = arena_dfs["Overall"] @@ -791,7 +795,7 @@ def highlight_top_3(s): style = style.background_gradient( cmap="Blues", subset=category_names, - vmin=1150, + vmin=category_df[category_names].max().max() - 250, vmax=category_df[category_names].max().max(), ) @@ -814,10 +818,6 @@ def build_category_leaderboard_tab( combined_elo_df, categories, "rating" ) sort_ranking = lambda _: get_arena_category_table(combined_elo_df, categories) - with gr.Row(): - gr.Markdown( - f"""  Chatbot Arena Overview""" - ) overall_ranking_leaderboard = gr.Dataframe( headers=["Model"] + [key_to_category_name[k] for k in categories], @@ -852,6 +852,20 @@ def build_category_leaderboard_tab( ] selected_categories_width = [110, 110, 110, 110, 80, 80, 80, 110, 80, 80] +vision_categories = [ + "full", + "full_style_control", + "captioning", + "captioning_style_control", + "entity_recognition", + "ocr", + "creative_writing_vision", + "homework", + "diagram", + "no_refusal", +] +vision_categories_width = [110, 110, 100, 110, 110, 60, 80, 80, 80, 80] + language_categories = [ "english", "chinese", @@ -963,16 +977,26 @@ def build_leaderboard_tab( combined_table = get_combined_table(elo_results_text, model_table_df) build_category_leaderboard_tab( combined_table, - "Task", + "LLM Task", selected_categories, selected_categories_width, ) build_category_leaderboard_tab( combined_table, - "Language", + "LLM Language", language_categories, language_categories_width, ) + if elo_results_vision is not None: + vision_combined_table = get_combined_table( + elo_results_vision, model_table_df + ) + build_category_leaderboard_tab( + vision_combined_table, + "VLM Task", + vision_categories, + vision_categories_width, + ) gr.Markdown( f""" ***Rank (UB)**: model's ranking (upper-bound), defined by one + the number of models that are statistically better than the target model. 
@@ -1074,31 +1098,10 @@ def build_demo(elo_results_file, leaderboard_table_file, arena_hard_leaderboard) from fastchat.serve.gradio_web_server import block_css text_size = gr.themes.sizes.text_lg - # load theme from theme.json - theme = gr.themes.Default.load("theme.json") - # set text size to large - theme.text_size = text_size - theme.set( - button_large_text_size="20px", - button_small_text_size="20px", - button_large_text_weight="100", - button_small_text_weight="100", - button_shadow="*shadow_drop_lg", - button_shadow_hover="*shadow_drop_lg", - checkbox_label_shadow="*shadow_drop_lg", - button_shadow_active="*shadow_inset", - button_secondary_background_fill="*primary_300", - button_secondary_background_fill_dark="*primary_700", - button_secondary_background_fill_hover="*primary_200", - button_secondary_background_fill_hover_dark="*primary_500", - button_secondary_text_color="*primary_800", - button_secondary_text_color_dark="white", - ) with gr.Blocks( title="Chatbot Arena Leaderboard", - # theme=gr.themes.Default(text_size=text_size), - theme=theme, + theme=gr.themes.Default(text_size=text_size), css=block_css, ) as demo: with gr.Tabs() as tabs: diff --git a/fastchat/serve/monitor/monitor_md.py b/fastchat/serve/monitor/monitor_md.py index 8fadd0137..89f99f09c 100644 --- a/fastchat/serve/monitor/monitor_md.py +++ b/fastchat/serve/monitor/monitor_md.py @@ -14,17 +14,15 @@ key_to_category_name = { "full": "Overall", - "full_style_control": "Overall w/ Style Control", "dedup": "De-duplicate Top Redundant Queries (soon to be default)", "math": "Math", "if": "Instruction Following", "multiturn": "Multi-Turn", "creative_writing": "Creative Writing", + "creative_writing_vision": "Creative Writing", "coding": "Coding", - "coding_style_control": "Coding w/ Style Control", "hard_6": "Hard Prompts", "hard_english_6": "Hard Prompts (English)", - "hard_6_style_control": "Hard Prompts w/ Style Control", "long_user": "Longer Query", "english": "English", "chinese": "Chinese", @@ -39,18 +37,22 @@ "no_refusal": "Exclude Refusal", "overall_limit_5_user_vote": "overall_limit_5_user_vote", "full_old": "Overall (Deprecated)", + "captioning": "Captioning", + "entity_recognition": "Entity Recognition", + "ocr": "OCR", + "humor": "Humor", + "homework": "Homework", + "diagram": "Diagram", + "is_preset": "Exclude Preset Images", } cat_name_to_explanation = { "Overall": "Overall Questions", - "Overall w/ Style Control": "Overall Leaderboard with Style Control. See details in [blog post](https://lmsys.org/blog/2024-08-28-style-control/).", "De-duplicate Top Redundant Queries (soon to be default)": "De-duplicate top redundant queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", "Math": "Math", "Instruction Following": "Instruction Following", "Multi-Turn": "Multi-Turn Conversation (>= 2 turns)", "Coding": "Coding: whether conversation contains code snippets", - "Coding w/ Style Control": "Coding with Style Control", "Hard Prompts": "Hard Prompts: details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)", - "Hard Prompts w/ Style Control": "Hard Prompts with Style Control. See details in [blog post](https://lmsys.org/blog/2024-08-28-style-control/).", "Hard Prompts (English)": "Hard Prompts (English), note: the delta is to English Category. 
details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)", "Longer Query": "Longer Query (>= 500 tokens)", "English": "English Prompts", @@ -67,6 +69,13 @@ "overall_limit_5_user_vote": "overall_limit_5_user_vote", "Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", "Creative Writing": "Creative Writing", + "Exclude Preset Images": "Exclude Images from 'Random Example' Option", + "Captioning": "Open-Ended Captioning", + "Entity Recognition": "Entity Recognition (e.g. who is in the image)", + "OCR": "Optical Character Recognition", + "Humor": "Humor (e.g. writing jokes, meme understanding)", + "Homework": "Homework problems", + "Diagram": "Diagram (e.g. plots, flow charts, figures)", } cat_name_to_baseline = { "Hard Prompts (English)": "English", @@ -126,7 +135,14 @@ def make_category_arena_leaderboard_md(arena_df, arena_subset_df, name="Overall" space = "   " total_subset_votes = sum(arena_subset_df["num_battles"]) // 2 total_subset_models = len(arena_subset_df) - leaderboard_md = f"""### {cat_name_to_explanation[name]} + if "w/ Style Control" in name: + explanation = ( + cat_name_to_explanation[name.replace(" w/ Style Control", "")] + + " with Style Control. See details in [blog post](https://lmsys.org/blog/2024-08-28-style-control/)." + ) + else: + explanation = cat_name_to_explanation[name] + leaderboard_md = f"""### {explanation} #### {space} #models: **{total_subset_models} ({round(total_subset_models/total_models *100)}%)** {space} #votes: **{"{:,}".format(total_subset_votes)} ({round(total_subset_votes/total_votes * 100)}%)**{space} """ return leaderboard_md
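
The vision categories build two different message shapes in `pre_process`, selected by `api_type`. Below is a minimal, self-contained sketch of the two shapes, using a tiny generated image in place of `prompt["image_path"]`; the system prompt and question text are placeholders, and the `image/png` media type here simply matches the PNG encoding performed by `pil_to_base64`:

```python
import base64
import io

from PIL import Image

# Stand-in for pil_to_base64(prompt["image_path"]): encode a tiny in-memory PNG.
buf = io.BytesIO()
Image.new("RGB", (4, 4), "white").save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()

system_prompt = "You are a classifier ..."      # placeholder
question = "What text appears in this image?"   # placeholder

# OpenAI-style content: a text part plus an image_url data URI.
openai_conv = [
    {"role": "system", "content": system_prompt},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ],
    },
]

# Anthropic-style content: a base64 image source block followed by the text part.
anthropic_conv = [
    {"role": "system", "content": system_prompt},
    {
        "role": "user",
        "content": [
            {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": b64}},
            {"type": "text", "text": question},
        ],
    },
]
```

When routed through `chat_completion_anthropic` in `label.py`, the leading system message is popped off and passed separately via the `system=` parameter; `chat_completion_gemini` instead collapses a list-style user content into `[text, PIL image]` and supplies the system prompt as `system_instruction`.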
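
On the input side, the `--vision` branch of `label.py` derives `image_path` from the first user turn and builds `prompt` from the text part of every user turn. A small sketch of the record shape this assumes — inferred from the indexing in the patch, where a vision battle's user `content` is `[text, [image_hash, ...]]` — with an invented `image_dir`:

```python
# Assumed vision-battle record shape (inferred from label.py's indexing above);
# the conversation text, hash, and image_dir are invented for illustration.
conversation_a = [
    {"role": "user", "content": ["What does the sign say?", ["abc123"]]},
    {"role": "assistant", "content": "It says 'Open'."},
]
image_dir = "/path/to/vision_arena_images"  # corresponds to image_dir in the config

# The first user turn carries the image hash; images are expected at <image_dir>/<hash>.png.
image_hash = conversation_a[0]["content"][1][0]
image_path = f"{image_dir}/{image_hash}.png"

# User turns sit at even indices; only their text part feeds the prompt column,
# which label.py then truncates to 12,500 characters.
prompt = "\n".join(
    conversation_a[i]["content"][0] for i in range(0, len(conversation_a), 2)
)[:12500]

print(image_path)  # /path/to/vision_arena_images/abc123.png
print(prompt)      # What does the sign say?
```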
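
Tying the pieces together with the README workflow at the top of this patch: a new vision-arena classifier follows the same template as the classes added to `category.py`, plus a branch in `create_category()`, an entry under `task_name` in the YAML config, and its bench name in `tag_names` in `display_score.py`. The sketch below is purely hypothetical — the `counting_v0.1` tag, class name, and prompt wording are invented, the `<decision>` wrapper is assumed from the closing tag in the patch's regexes, and the import assumes `category.py` is importable as a module:

```python
import re

from fastchat.serve.monitor.classify.category import Category  # assumed import path


class CategoryCounting(Category):
    """Hypothetical example category; not part of this patch."""

    def __init__(self):
        super().__init__()
        self.name_tag = "counting_v0.1"
        self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
        self.system_prompt = (
            "You are tasked with determining if a given VQA question asks how many "
            "objects or people appear in the image. Output your verdict in the "
            "following format: <decision>[yes/no]</decision>. Do NOT explain."
        )
        self.prompt_template = "{PROMPT}"

    def get_score(self, judgment):
        # Accept the verdict only if every <decision> match agrees; otherwise None.
        matches = [m for m in self.pattern.findall(judgment.replace("\n", "").lower()) if m]
        return matches[0] if len(set(matches)) == 1 else None

    def pre_process(self, prompt, api_type="openai"):
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.prompt_template.format(PROMPT=prompt["prompt"])},
        ]

    def post_process(self, judgment):
        score = self.get_score(judgment=judgment)
        return {"counting": bool(score == "yes") if score else False}
```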