diff --git a/.gitignore b/.gitignore index dcd7997..57c58be 100644 --- a/.gitignore +++ b/.gitignore @@ -166,6 +166,7 @@ cython_debug/ */__pycache__/* .chainlit/translations/ code/.chainlit/translations/ +**/.chainlit/translations/ storage/logs/* vectorstores/* diff --git a/apps/ai_tutor/app.py b/apps/ai_tutor/app.py index 9d0a2a0..26f666c 100644 --- a/apps/ai_tutor/app.py +++ b/apps/ai_tutor/app.py @@ -26,7 +26,7 @@ import hashlib # set config -config = config_manager.get_config().dict() +config = config_manager.get_config() # set constants GITHUB_REPO = config["misc"]["github_repo"] diff --git a/apps/ai_tutor/chainlit_app.py b/apps/ai_tutor/chainlit_app.py index ba8768a..ac847ee 100644 --- a/apps/ai_tutor/chainlit_app.py +++ b/apps/ai_tutor/chainlit_app.py @@ -31,6 +31,9 @@ from langchain_community.callbacks import get_openai_callback from datetime import datetime, timezone from config.config_manager import config_manager +from edubotics_core.chat.agentic.agent import Agent + +from config.prompts import prompts USER_TIMEOUT = 60_000 SYSTEM = "System" @@ -40,7 +43,8 @@ ERROR = "Error" # set config -config = config_manager.get_config().dict() +config = config_manager.get_config() +print(config) async def setup_data_layer(): @@ -81,6 +85,8 @@ def __init__(self, config): Initialize the Chatbot class. """ self.config = config + # Initialize Agent instance + self.agent = Agent(config=config, prompts=prompts) @no_type_check async def setup_llm(self): @@ -172,10 +178,14 @@ async def make_llm_settings_widgets(self, config=None): cl.input_widget.Select( id="retriever_method", label="Retriever (Default FAISS)", - values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"], - initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index( - config["vectorstore"]["db_option"] - ), + values=["FAISS", "Chroma", "RAGatouille", "RAPTOR", "MVS"], + initial_index=[ + "FAISS", + "Chroma", + "RAGatouille", + "RAPTOR", + "MVS", + ].index(config["vectorstore"]["db_option"]), ), cl.input_widget.Slider( id="memory_window", @@ -297,14 +307,15 @@ async def start(self): } memory = cl.user_session.get("memory", []) - self.llm_tutor = LLMTutor(self.config, user=self.user) + # Replace LLMTutor usage with Agent if needed + # self.llm_tutor = LLMTutor(self.config, user=self.user) + # self.chain = self.llm_tutor.qa_bot(memory=memory) - self.chain = self.llm_tutor.qa_bot( - memory=memory, - ) - self.question_generator = self.llm_tutor.question_generator - cl.user_session.set("llm_tutor", self.llm_tutor) - cl.user_session.set("chain", self.chain) + # cl.user_session.set("llm_tutor", self.llm_tutor) + # cl.user_session.set("chain", self.chain) + + self.agent.set_thread_id(cl.context.session.thread_id) + cl.user_session.set("agent", self.agent) async def stream_response(self, response): """ @@ -344,6 +355,7 @@ async def main(self, message): # update user info with last message time user = cl.user_session.get("user") + await reset_tokens_for_user( user, self.config["token_config"]["tokens_left"], @@ -355,7 +367,11 @@ async def main(self, message): # see if user has token credits left # if not, return message saying they have run out of tokens - if user.metadata["tokens_left"] <= 0 and "admin" not in user.metadata["role"]: + if ( + user.metadata["tokens_left"] <= 0 + and "admin" not in user.metadata["role"] + and config["chat_logging"]["log_chat"] + ): current_datetime = get_time() cooldown, cooldown_end_time = await check_user_cooldown( user, @@ -425,84 +441,83 @@ async def main(self, message): ), } + response = cl.Message(content="") with get_openai_callback() as token_count_cb: - if stream: - res = chain.stream(user_query=user_query_dict, config=chain_config) - res = await self.stream_response(res) - else: - res = await chain.invoke( - user_query=user_query_dict, - config=chain_config, - ) - token_count += token_count_cb.total_tokens - - answer = res.get("answer", res.get("result")) - - answer_with_sources, source_elements, sources_dict = get_sources( - res, answer, stream=stream, view_sources=view_sources - ) - answer_with_sources = answer_with_sources.replace("$$", "$") + async for chunk in self.agent.stream(message.content): + content = chunk["content"] + tokens = chunk["total_tokens"] + await response.stream_token(content) - actions = [] + token_count = tokens - if self.config["llm_params"]["generate_follow_up"]: - cb_follow_up = cl.AsyncLangchainCallbackHandler() - config = { - "callbacks": ( - [cb_follow_up] - if cl_data._data_layer and self.config["chat_logging"]["callbacks"] - else None - ) - } - with get_openai_callback() as token_count_cb: - list_of_questions = await self.question_generator.generate_questions( - query=user_query_dict["input"], - response=answer, - chat_history=res.get("chat_history"), - context=res.get("context"), - config=config, - ) - - token_count += token_count_cb.total_tokens - - for question in list_of_questions: - actions.append( - cl.Action( - name="follow up question", - value="example_value", - description=question, - label=question, - ) - ) + answer = response.content + sources = self.agent.get_sources() - # # update user info with token count - tokens_left = await update_user_from_chainlit(user, token_count) - - answer_with_sources += ( - '\n\n\n" - ) - - await cl.Message( - content=answer_with_sources, - elements=source_elements, - author=LLM, - actions=actions, - ).send() + if len(sources) > 0: + sources_text = "\n\nSources: \n" + "\n".join( + [f"- {source}" for source in sources] + ) + response.content += sources_text + + await response.send() + + # answer_with_sources, source_elements, sources_dict = get_sources( + # res, answer, stream=stream, view_sources=view_sources + # ) + # answer_with_sources = answer_with_sources.replace("$$", "$") + + # actions = [] + + # if self.config["llm_params"]["generate_follow_up"]: + # cb_follow_up = cl.AsyncLangchainCallbackHandler() + # config = { + # "callbacks": ( + # [cb_follow_up] + # if cl_data._data_layer and self.config["chat_logging"]["callbacks"] + # else None + # ) + # } + # with get_openai_callback() as token_count_cb: + # list_of_questions = await self.question_generator.generate_questions( + # query=user_query_dict["input"], + # response=answer, + # chat_history=res.get("chat_history"), + # context=res.get("context"), + # config=config, + # ) + + # token_count += token_count_cb.total_tokens + + # for question in list_of_questions: + # actions.append( + # cl.Action( + # name="follow up question", + # value="example_value", + # description=question, + # label=question, + # ) + # ) + + # # # update user info with token count + # tokens_left = await update_user_from_chainlit(user, token_count) + + # answer_with_sources += ( + # '\n\n\n" + # ) + + # await cl.Message( + # content=answer_with_sources, + # elements=source_elements, + # author=LLM, + # actions=actions, + # ).send() async def on_chat_resume(self, thread: ThreadDict): - # thread_config = None - steps = thread["steps"] - k = self.config["llm_params"][ - "memory_window" - ] # on resume, alwyas use the default memory window - conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM) - # thread_config = get_last_config( - # steps - # ) # TODO: Returns None for now - which causes config to be reloaded with default values - cl.user_session.set("memory", conversation_list) - await self.start() + thread_id = thread["id"] + self.agent.set_thread_id(thread_id) + self.agent.populate_conversation_history(thread) @cl.header_auth_callback def header_auth_callback(headers: dict) -> Optional[cl.User]: diff --git a/apps/ai_tutor/config/constants.py b/apps/ai_tutor/config/constants.py index 506d0af..5ceac00 100644 --- a/apps/ai_tutor/config/constants.py +++ b/apps/ai_tutor/config/constants.py @@ -7,9 +7,14 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY") -HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") +COHERE_API_KEY = os.getenv("COHERE_API_KEY") + LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING") LITERAL_API_URL = os.getenv("LITERAL_API_URL") + +HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") + + CHAINLIT_URL = os.getenv("CHAINLIT_URL") EMAIL_ENCRYPTION_KEY = os.getenv("EMAIL_ENCRYPTION_KEY") diff --git a/apps/ai_tutor/config/prompts.py b/apps/ai_tutor/config/prompts.py index bdd6611..413f788 100644 --- a/apps/ai_tutor/config/prompts.py +++ b/apps/ai_tutor/config/prompts.py @@ -13,8 +13,8 @@ ), "prompt_with_history": { "normal": ( - "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. " - "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. " + "You are an AI Tutor for the course DS542, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. " + "If you don't know the answer, do not make things up, just say you don't know and ask the user to rephrase. Keep the conversation flowing naturally. " "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. " "Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context. Be sure to explain the parameters and variables in the equations." "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n" @@ -22,6 +22,7 @@ "Chat History:\n{chat_history}\n\n" "Context:\n{context}\n\n" "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n" + "If the provided context is not relevant, just say you don't know and ask the user to attach the relevant documents. Do not make things up." "Student: {input}\n" "AI Tutor:" ), diff --git a/edubotics_core/chat/agentic/agent.py b/edubotics_core/chat/agentic/agent.py new file mode 100644 index 0000000..1ae8384 --- /dev/null +++ b/edubotics_core/chat/agentic/agent.py @@ -0,0 +1,340 @@ +import asyncio +import sys +import os +from typing import Literal, TypedDict, List +from PIL import Image + +from pprint import pprint +from dotenv import load_dotenv +from langchain_core.output_parsers import StrOutputParser +from langchain_core.messages import HumanMessage, AIMessage, AIMessageChunk +from langchain_openai import ChatOpenAI +from langgraph.graph import END, StateGraph, START +from langchain_core.prompts import ChatPromptTemplate +from chainlit.types import ThreadDict +from langgraph.checkpoint.memory import MemorySaver + +from .utils import RouteQuery, GraphState, rag_template, router_template +from edubotics_core.vectorstore.mvs import MultiVectorStore + + +load_dotenv() + + +class Agent: + def __init__( + self, thread_id: str = None, config: dict = None, prompts: dict = None + ): + self.config = config + self.retrievers = {} + self.content_types = self.config["metadata"]["content_types"] + + mvs = MultiVectorStore(self.config) + self.retrievers = mvs.as_retriever() + + self.workflow = self.create_workflow() + self.graph = self.workflow.compile(checkpointer=MemorySaver()) + + self.updates = [] + self.thread_id = thread_id + + self.conversation_history = [] + self.memory_window = self.config["llm_params"]["memory_window"] + + self.prompts = prompts + + api_key = os.environ["OPENAI_API_KEY"] + + self.rag_prompt = ChatPromptTemplate.from_template( + prompts["openai"]["prompt_with_history"]["normal"] + ) + self.route_prompt = ChatPromptTemplate.from_template(router_template) + self.model = ChatOpenAI( + temperature=0.5, + model_name="gpt-4o-mini", + openai_api_key=api_key, + # stream_options={"include_usage": True}, + # stream=True, + ) + + self.rag_chain = self.rag_prompt | self.model | StrOutputParser() + + structured_model_router = self.model.with_structured_output(RouteQuery) + self.question_router = self.route_prompt | structured_model_router + + def set_thread_id(self, thread_id: str): + self.thread_id = thread_id + + def create_workflow(self): + + workflow = StateGraph(GraphState) + + workflow.add_node("assignments_retrieve", self.assignments_retrieve) + workflow.add_node("lectures_retrieve", self.lectures_retrieve) + workflow.add_node("discussions_retrieve", self.discussions_retrieve) + workflow.add_node("other_retrieve", self.other_retrieve) + workflow.add_node("not_needed", self.no_retrieve) + workflow.add_node("generate", self.generate) + + # Build graph + workflow.add_conditional_edges( + START, + self.route_question, + { + "assignments_retrieve": "assignments_retrieve", + "lectures_retrieve": "lectures_retrieve", + "discussions_retrieve": "discussions_retrieve", + "other_retrieve": "other_retrieve", + "not_needed": "not_needed", + }, + ) + + workflow.add_edge("assignments_retrieve", "generate") + workflow.add_edge("lectures_retrieve", "generate") + workflow.add_edge("discussions_retrieve", "generate") + workflow.add_edge("other_retrieve", "generate") + workflow.add_edge("not_needed", "generate") + workflow.add_edge("generate", END) + + return workflow + + async def stream(self, question: str): + config = {"configurable": {"thread_id": self.thread_id}} + messages_with_history = self.conversation_history + [ + HumanMessage(content=question) + ] + token_count = 0 + for event in self.graph.stream( + {"messages": messages_with_history}, + config, + stream_mode=["messages", "updates"], + ): + if event[0] == "messages" and event[1][1]["langgraph_node"] == "generate": + ai_message = event[1][0] + + if isinstance(ai_message, AIMessageChunk): + if ai_message.usage_metadata: + tokens = ai_message.usage_metadata.get("output_tokens", 0) + token_count += tokens + + yield {"content": ai_message.content, "total_tokens": token_count} + else: + update = event[1] + self.updates.append(update) + + def run(self, question: str) -> dict: + config = {"configurable": {"thread_id": self.thread_id}} + last_state = self.graph.invoke( + {"messages": [HumanMessage(content=question)]}, config + ) + response = last_state["messages"][-1].content + documents = last_state["documents"] + self.updates.append(last_state) + return {"response": response, "documents": documents} + + def get_sources(self): + """ + Get the sources of the documents retrieved from the vector stores for the last question. + """ + if len(self.updates) == 0: + return [] + else: + last_update = self.updates[-1] + if "generate" in last_update: + return last_update["generate"]["documents_sources"] + elif "documents" in last_update: + return last_update["documents"] + else: + return [] + + def assignments_retrieve(self, state): + """ + Retrieve documents from the assignments vector store + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print("---RETRIEVE ASSIGNMENTS---") + messages = state["messages"] + question = messages[-1].content + + # Retrieval + documents = self.retrievers["assignments"].invoke(question) + return {"documents": documents, "type": "retrieve"} + + def lectures_retrieve(self, state): + """ + Retrieve documents from the lectures vector store + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print("---RETRIEVE LECTURES---") + messages = state["messages"] + question = messages[-1].content + + # Retrieval + documents = self.retrievers["lecture"].invoke(question) + return {"documents": documents, "type": "retrieve"} + + def discussions_retrieve(self, state): + """ + Retrieve documents from the discussions vector store + """ + print("---RETRIEVE DISCUSSIONS---") + messages = state["messages"] + question = messages[-1].content + + documents = self.retrievers["discussion"].invoke(question) + return {"documents": documents, "type": "retrieve"} + + def other_retrieve(self, state): + """ + Retrieve documents from the other vector store + """ + print("---RETRIEVE OTHER---") + messages = state["messages"] + question = messages[-1].content + + documents = self.retrievers["other"].invoke(question) + return {"documents": documents, "type": "retrieve"} + + def no_retrieve(self, state): + """ + Return empty documents + """ + print("---RETRIEVE NOT NEEDED---") + return {"documents": [], "type": "retrieve"} + + def generate(self, state): + """ + Generate answer + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, generation, that contains LLM generation + """ + # print("---GENERATE---") + messages = state["messages"] + question = messages[-1].content + + conversation_history = messages[-6:-1] + + documents = state["documents"] + documents_sources = [doc.metadata["source"] for doc in documents] + + # RAG generation + response = self.rag_chain.invoke( + { + "context": documents, + "input": question, + "chat_history": conversation_history, + } + ) + ai_message = AIMessage(content=response) + + return { + "documents_sources": documents_sources, + "messages": [ai_message], + "type": "generate", + } + + def route_question(self, state): + """ + Route question to corresponding RAG. + + Args: + state (dict): The current graph state + + Returns: + str: Next node to call + """ + + # print("---ROUTE QUESTION---") + messages = state["messages"] + question = messages[-1].content + + source = self.question_router.invoke({"input": question}) + if source.datasource == "assignment": + # print("---ROUTE QUESTION TO ASSIGNMENTS---") + return "assignments_retrieve" + elif source.datasource == "lecture": + # print("---ROUTE QUESTION TO LECTURES---") + return "lectures_retrieve" + elif source.datasource == "discussion": + # print("---ROUTE QUESTION TO DISCUSSIONS---") + return "discussions_retrieve" + elif source.datasource == "other": + # print("---ROUTE QUESTION TO OTHER---") + return "other_retrieve" + else: + # print("---ROUTE QUESTION TO RETRIEVAL NOT NEEDED---") + return "not_needed" + + def reset_history(self): + """ + Reset conversation history - before resuming a chat or starting a new one. + """ + self.conversation_history = [] + + def get_history(self) -> List[HumanMessage | AIMessage]: + return self.conversation_history + + def get_state(self, config): + return self.graph.get_state(config) + + def populate_conversation_history(self, thread: ThreadDict): + """ + Populate conversation history from a thread + """ + steps = thread["steps"] + thread_id = thread["id"] + + self.set_thread_id(thread_id) + + self.conversation_history = [] + + for step in steps: + message_type = step["type"] + if message_type == "user_message": + content = step["output"] + self.conversation_history.append(HumanMessage(content=content)) + elif message_type in [ + "assistant_message", + "assistant_message_chunk", + "ai_message", + "ai_message_chunk", + ]: + content = step["output"] + self.conversation_history.append(AIMessage(content=content)) + + def update_config(self, config): + self.config.update(config) + + +if __name__ == "__main__": + import yaml + + from .prompts import prompts + + with open( + "/Users/faridkarimli/Desktop/Programming/AI/edubot-core/edubotics_core/chat/agentic/config.yml", + "r", + ) as f: + config = yaml.safe_load(f) + + agent = Agent(thread_id="123", config=config, prompts=prompts) + + async def stream_responses(): + async for response in agent.stream("What do we do in discussion 4?"): + print(response["content"]) + + asyncio.run(stream_responses()) diff --git a/edubotics_core/chat/agentic/prompts.py b/edubotics_core/chat/agentic/prompts.py new file mode 100644 index 0000000..413f788 --- /dev/null +++ b/edubotics_core/chat/agentic/prompts.py @@ -0,0 +1,98 @@ +prompts = { + "openai": { + "rephrase_prompt": ( + "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. " + "Incorporate relevant details from the chat history to make the question clearer and more specific. " + "Do not change the meaning of the original statement, and maintain the student's tone and perspective. " + "If the question is conversational and doesn't require context, do not rephrase it. " + "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation.'. " + "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project' " + "Chat history: \n{chat_history}\n" + "Rephrase the following question only if necessary: '{input}'" + "Rephrased Question:'" + ), + "prompt_with_history": { + "normal": ( + "You are an AI Tutor for the course DS542, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. " + "If you don't know the answer, do not make things up, just say you don't know and ask the user to rephrase. Keep the conversation flowing naturally. " + "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. " + "Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context. Be sure to explain the parameters and variables in the equations." + "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n" + "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here." + "Chat History:\n{chat_history}\n\n" + "Context:\n{context}\n\n" + "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n" + "If the provided context is not relevant, just say you don't know and ask the user to attach the relevant documents. Do not make things up." + "Student: {input}\n" + "AI Tutor:" + ), + "eli5": ( + "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Your job is to explain things in the simplest and most engaging way possible, just like the 'Explain Like I'm 5' (ELI5) concept." + "If you don't know the answer, do your best without making things up. Keep your explanations straightforward and very easy to understand." + "Use the chat history and context to help you, but avoid repeating past responses. Provide links from the source_file metadata when they're helpful." + "Use very simple language and examples to explain any math equations, and put the equations in LaTeX format between $ or $$ signs." + "Be friendly and engaging, like you're chatting with a young child who's curious and eager to learn. Avoid complex terms and jargon." + "Include simple and clear examples wherever you can to make things easier to understand." + "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here." + "Chat History:\n{chat_history}\n\n" + "Context:\n{context}\n\n" + "Answer the student's question below in a friendly, simple, and engaging way, just like the ELI5 concept. Use the context and history only if they're relevant, otherwise, just have a natural conversation." + "Give a clear and detailed explanation with simple examples to make it easier to understand. Remember, your goal is to break down complex topics into very simple terms, just like ELI5." + "Student: {input}\n" + "AI Tutor:" + ), + "socratic": ( + "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Engage the student in a Socratic dialogue to help them discover answers on their own. Use the provided context to guide your questioning." + "If you don't know the answer, do your best without making things up. Keep the conversation engaging and inquisitive." + "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata when relevant. Use the source context that is most relevant." + "Speak in a friendly and engaging manner, encouraging critical thinking and self-discovery." + "Use questions to lead the student to explore the topic and uncover answers." + "Chat History:\n{chat_history}\n\n" + "Context:\n{context}\n\n" + "Answer the student's question below by guiding them through a series of questions and insights that lead to deeper understanding. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation." + "Foster an inquisitive mindset and help the student discover answers through dialogue." + "Student: {input}\n" + "AI Tutor:" + ), + }, + "prompt_no_history": ( + "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. " + "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. " + "Provide links from the source_file metadata. Use the source context that is most relevant. " + "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n" + "Context:\n{context}\n\n" + "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n" + "Student: {input}\n" + "AI Tutor:" + ), + }, + "tiny_llama": { + "prompt_no_history": ( + "system\n" + "Assistant is an intelligent chatbot designed to help students with questions regarding the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance.\n" + "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally.\n" + "Provide links from the source_file metadata. Use the source context that is most relevant.\n" + "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n" + "\n\n" + "user\n" + "Context:\n{context}\n\n" + "Question: {input}\n" + "\n\n" + "assistant" + ), + "prompt_with_history": ( + "system\n" + "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. " + "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. " + "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. " + "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n" + "\n\n" + "user\n" + "Chat History:\n{chat_history}\n\n" + "Context:\n{context}\n\n" + "Question: {input}\n" + "\n\n" + "assistant" + ), + }, +} diff --git a/edubotics_core/chat/agentic/utils.py b/edubotics_core/chat/agentic/utils.py new file mode 100644 index 0000000..339ecea --- /dev/null +++ b/edubotics_core/chat/agentic/utils.py @@ -0,0 +1,78 @@ +import sys +import os +from typing import Literal, TypedDict, List +from PIL import Image +from typing import Annotated +from langgraph.graph.message import add_messages + +from pprint import pprint +from dotenv import load_dotenv +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_openai import ChatOpenAI, OpenAIEmbeddings + +from pydantic import BaseModel, Field + +# Load the variables from .env +load_dotenv() + +content_types = ["assignment", "lecture", "discussion", "other"] +NUM_VECTORSTORES = len(content_types) +VS_PATH = "vectorstores" + +# Data model + + +class RouteQuery(BaseModel): + """Route a user query to the most relevant vector store.""" + + datasource: Literal[ + "assignment", "lecture", "discussion", "other", "not_needed" + ] = Field( + ..., + description="Given a user question choose to route it to the relevant vector store or none.", + ) + + +class GraphState(TypedDict): + """ + Represents the state of our graph. + + Attributes: + messages: list of messages + generation: LLM generation + documents: list of documents + """ + + messages: Annotated[list, add_messages] + documents: List[str] + documents_sources: List[str] + type: Literal["retrieve", "generate"] + next: str + + +system_prompt = """ +You are an AI Assistant for a university course. +""" + +rag_template = """ +Answer the question with the help of the following context and conversation history: +Context: {context} +Conversation history: {conversation_history} + +You may not need the provided context to answer the question. If that is the case, just answer the question based on your knowledge. +You could also not be provided with any context at all, in that case, just answer the question based on your knowledge and/or conversation history. + +Input: {input} +""" + +router_template = """You are an expert at routing a user question to different vector stores. +There are 4 vector stores: +- assignment: chunks from assignment notebooks containing code exercises and maybe free-form responses. Also contains the midterm challenge. +- lecture: lecture content on machine learning, classification, regression and clustering. +- discussion: discussion content that mirrors content from lecture on a smaller scale, containing shorter exercises meant for classroom discussion +- other: anything else about the class - office hours, syllabus, project and professor info, or answer the question using the conversation history. +Return the corresponding vector store depending of the topics of the question or just not_needed because it does't match with the vector stores. + +Input: {input} +""" diff --git a/edubotics_core/chat_processor/literal_ai.py b/edubotics_core/chat_processor/literal_ai.py index ca7fb13..435ba0c 100644 --- a/edubotics_core/chat_processor/literal_ai.py +++ b/edubotics_core/chat_processor/literal_ai.py @@ -1,7 +1,8 @@ -from chainlit.data import LiteralDataLayer - +from chainlit.data.literalai import LiteralDataLayer # update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py) + + class CustomLiteralDataLayer(LiteralDataLayer): def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/edubotics_core/config/config.yml b/edubotics_core/config/config.yml new file mode 100644 index 0000000..5dc6c43 --- /dev/null +++ b/edubotics_core/config/config.yml @@ -0,0 +1,60 @@ +log_dir: "storage/logs" # str +log_chunk_dir: "storage/logs/chunks" # str +device: "cpu" # str [cuda, cpu] + +vectorstore: + load_from_HF: False # bool + reparse_files: True # bool + data_path: "storage/data" # str + url_file_path: "storage/data/urls.txt" # str + expand_urls: False # bool + db_option: "FAISS" # str [FAISS, Chroma, RAGatouille, RAPTOR, Fusion] + db_path: "vectorstores" # str + model: "sentence-transformers/all-MiniLM-L6-v2" # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002'] + search_top_k: 10 # int + score_threshold: 0.2 # float + + faiss_params: # Not used as of now + index_path: "vectorstores/faiss.index" # str + index_type: "Flat" # str [Flat, HNSW, IVF] + index_dimension: 384 # int + index_nlist: 100 # int + index_nprobe: 10 # int + + colbert_params: + index_name: "new_idx" # str + +llm_params: + llm_arch: "langchain" # [langchain] + use_history: True # bool + generate_follow_up: False # bool + memory_window: 3 # int + llm_style: "Normal" # str [Normal, ELI5] + llm_loader: "gpt-4o-mini" # str [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini] + openai_params: + temperature: 0.7 # float + local_llm_params: + temperature: 0.7 # float + repo_id: "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" # HuggingFace repo id + filename: "tinyllama-1.1b-chat-v1.0.Q5_0.gguf" # Specific name of gguf file in the repo + model_path: "storage/models/tinyllama-1.1b-chat-v1.0.Q5_0.gguf" # Path to the model file + stream: False # bool + pdf_reader: "pymupdf" # str [llama, pymupdf, gpt] + +chat_logging: + log_chat: False # bool + platform: "literalai" + callbacks: True # bool + +splitter_options: + use_splitter: True # bool + split_by_token: True # bool + remove_leftover_delimiters: True # bool + remove_chunks: False # bool + chunking_mode: "fixed" # str [fixed, semantic] + chunk_size: 1000 # int + chunk_overlap: 100 # int + chunk_separators: ["\n\n", "\n", " ", ""] # list of strings + front_chunks_to_remove: null # int or None + last_chunks_to_remove: null # int or None + delimiters_to_remove: ['\t', '\n', " ", " "] # list of strings diff --git a/edubotics_core/config/config_manager.py b/edubotics_core/config/config_manager.py new file mode 100644 index 0000000..6cc5edf --- /dev/null +++ b/edubotics_core/config/config_manager.py @@ -0,0 +1,189 @@ +from pydantic import BaseModel, conint, confloat, HttpUrl +from typing import Optional, List +import yaml + + +class FaissParams(BaseModel): + index_path: str = "vectorstores/faiss.index" + index_type: str = "Flat" # Options: [Flat, HNSW, IVF] + index_dimension: conint(gt=0) = 384 + index_nlist: conint(gt=0) = 100 + index_nprobe: conint(gt=0) = 10 + + +class ColbertParams(BaseModel): + index_name: str = "new_idx" + + +class VectorStoreConfig(BaseModel): + load_from_HF: bool = True + reparse_files: bool = True + data_path: str = "storage/data" + url_file_path: str = "storage/data/urls.txt" + expand_urls: bool = True + db_option: str = "RAGatouille" # Options: [FAISS, Chroma, RAGatouille, RAPTOR] + db_path: str = "vectorstores" + model: str = ( + # Options: [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002] + "sentence-transformers/all-MiniLM-L6-v2" + ) + search_top_k: conint(gt=0) = 3 + score_threshold: confloat(ge=0.0, le=1.0) = 0.2 + + faiss_params: Optional[FaissParams] = None + colbert_params: Optional[ColbertParams] = None + + +class OpenAIParams(BaseModel): + temperature: confloat(ge=0.0, le=1.0) = 0.7 + + +class LocalLLMParams(BaseModel): + temperature: confloat(ge=0.0, le=1.0) = 0.7 + repo_id: str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" # HuggingFace repo id + filename: str = ( + "tinyllama-1.1b-chat-v1.0.Q5_0.gguf" # Specific name of gguf file in the repo + ) + model_path: str = ( + "storage/models/tinyllama-1.1b-chat-v1.0.Q5_0.gguf" # Path to the model file + ) + + +class LLMParams(BaseModel): + llm_arch: str = "langchain" # Options: [langchain] + use_history: bool = True + generate_follow_up: bool = False + memory_window: conint(ge=1) = 3 + llm_style: str = "Normal" # Options: [Normal, ELI5] + llm_loader: str = ( + "gpt-4o-mini" # Options: [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini] + ) + openai_params: Optional[OpenAIParams] = None + local_llm_params: Optional[LocalLLMParams] = None + stream: bool = False + pdf_reader: str = "gpt" # Options: [llama, pymupdf, gpt] + + +class ChatLoggingConfig(BaseModel): + log_chat: bool = True + platform: str = "literalai" + callbacks: bool = True + + +class SplitterOptions(BaseModel): + use_splitter: bool = True + split_by_token: bool = True + remove_leftover_delimiters: bool = True + remove_chunks: bool = False + chunking_mode: str = "semantic" # Options: [fixed, semantic] + chunk_size: conint(gt=0) = 300 + chunk_overlap: conint(ge=0) = 30 + chunk_separators: List[str] = ["\n\n", "\n", " ", ""] + front_chunks_to_remove: Optional[conint(ge=0)] = None + last_chunks_to_remove: Optional[conint(ge=0)] = None + delimiters_to_remove: List[str] = ["\t", "\n", " ", " "] + + +class RetrieverConfig(BaseModel): + retriever_hf_paths: dict[str, str] = {"RAGatouille": "XThomasBU/Colbert_Index"} + + +class MetadataConfig(BaseModel): + metadata_links: List[HttpUrl] = [ + "https://dl4ds.github.io/sp2024/lectures/", + "https://dl4ds.github.io/sp2024/schedule/", + ] + slide_base_link: HttpUrl = "https://dl4ds.github.io" + + +class TokenConfig(BaseModel): + cooldown_time: conint(gt=0) = 60 + regen_time: conint(gt=0) = 180 + tokens_left: conint(gt=0) = 2000 + all_time_tokens_allocated: conint(gt=0) = 1000000 + + +class MiscConfig(BaseModel): + github_repo: HttpUrl = "https://github.com/edubotics-ai/edubot-core" + docs_website: HttpUrl = "https://dl4ds.github.io/dl4ds_tutor/" + + +class APIConfig(BaseModel): + timeout: conint(gt=0) = 60 + + +class Config(BaseModel): + log_dir: str = "storage/logs" + log_chunk_dir: str = "storage/logs/chunks" + device: str = "cpu" # Options: ['cuda', 'cpu'] + + vectorstore: VectorStoreConfig + llm_params: LLMParams + chat_logging: ChatLoggingConfig + splitter_options: SplitterOptions + retriever: RetrieverConfig + metadata: MetadataConfig + token_config: TokenConfig + misc: MiscConfig + api_config: APIConfig + + +class ConfigManager: + def __init__(self, config_path: str, project_config_path: str): + self.config_path = config_path + self.project_config_path = project_config_path + self.config = self.load_config() + self.validate_config() + + def load_config(self) -> Config: + with open(self.config_path, "r") as f: + config_data = yaml.safe_load(f) + + with open(self.project_config_path, "r") as f: + project_config_data = yaml.safe_load(f) + + # Merge the two configurations + merged_config = {**config_data, **project_config_data} + + return Config(**merged_config) + + def get_config(self) -> Config: + return ConfigWrapper(self.config) + + def validate_config(self): + # If any required fields are missing, raise an error + # required_fields = [ + # "vectorstore", "llm_params", "chat_logging", "splitter_options", + # "retriever", "metadata", "token_config", "misc", "api_config" + # ] + # for field in required_fields: + # if not hasattr(self.config, field): + # raise ValueError(f"Missing required configuration field: {field}") + + # # Validate types of specific fields + # if not isinstance(self.config.vectorstore, VectorStoreConfig): + # raise TypeError("vectorstore must be an instance of VectorStoreConfig") + # if not isinstance(self.config.llm_params, LLMParams): + # raise TypeError("llm_params must be an instance of LLMParams") + pass + + +class ConfigWrapper: + def __init__(self, config: Config): + self._config = config + + def __getitem__(self, key): + return getattr(self._config, key) + + def __getattr__(self, name): + return getattr(self._config, name) + + def dict(self): + return self._config.dict() + + +# Usage +config_manager = ConfigManager( + config_path="config/config.yml", project_config_path="config/project_config.yml" +) +# config = config_manager.get_config().dict() diff --git a/edubotics_core/config/constants.py b/edubotics_core/config/constants.py index ef2b51d..a0b0a7f 100644 --- a/edubotics_core/config/constants.py +++ b/edubotics_core/config/constants.py @@ -6,9 +6,10 @@ load_dotenv(".env") # Centralized definition of required constants for easy management and access -TIMEOUT = os.getenv("TIMEOUT", 60) +TIMEOUT = os.getenv("TIMEOUT", 30) OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY", "") +COHERE_API_KEY = os.getenv("COHERE_API_KEY", "") HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "") GITHUB_PERSONAL_ACCESS_TOKEN = os.getenv("GITHUB_PERSONAL_ACCESS_TOKEN", "") GITHUB_USERNAME = os.getenv("GITHUB_USERNAME", "") diff --git a/edubotics_core/config/project_config.yml b/edubotics_core/config/project_config.yml new file mode 100644 index 0000000..5f8cd17 --- /dev/null +++ b/edubotics_core/config/project_config.yml @@ -0,0 +1,61 @@ +retriever: + retriever_hf_paths: + RAGatouille: "XThomasBU/Colbert_Index" + +metadata: + base_link: "https://tools4ds.github.io/fa2024/" + metadata_links: + [ + "https://tools4ds.github.io/fa2024/lectures/", + "https://tools4ds.github.io/fa2024/schedule_a1/", + ] + slide_base_link: "https://tools4ds.github.io/fa2024/lectures/" + + # Assignment base link is used to find the webpage where the assignment is described/posted + assignment_base_link: "https://tools4ds.github.io/fa2024/assignments/" + + # Define content types - assignments, lectures, etc. + content_types: + - "lecture" + - "assignment" + - "discussion" + - "project" + - "other" + + # These need to be patterns from URLs that can be used to identify the type of content uniquely + lectures_pattern: "/lectures/" + assignments_pattern: "/assignments/" + discussion_pattern: "/discussion/" + project_pattern: "/project/" + + # These are fields that can be extracted from the webpages of the course content + lecture_metadata_fields: + - "title" + - "tldr" + - "date" + - "lecture_recording" + - "suggested_readings" + + assignment_metadata_fields: + - "title" + - "release_date" + - "due_date" + - "source_file" + +token_config: + cooldown_time: 60 + regen_time: 180 + tokens_left: 50000 + all_time_tokens_allocated: 1000000 + +content: + notebookheaders_to_split_on: + - ["##", "Section"] + - ["#", "Title"] + +misc: + github_repo: "https://github.com/edubotics-ai/edubot-core" + docs_website: "https://dl4ds.github.io/dl4ds_tutor/" + +api_config: + timeout: 30 diff --git a/edubotics_core/dataloader/data_loader.py b/edubotics_core/dataloader/data_loader.py index 653d652..002aa20 100644 --- a/edubotics_core/dataloader/data_loader.py +++ b/edubotics_core/dataloader/data_loader.py @@ -1,5 +1,7 @@ import os +from pprint import pprint import re +import traceback import requests import pysrt from langchain_community.document_loaders import ( @@ -55,11 +57,11 @@ def check_links(self, base_url, html_content): absolute_url = urljoin(base_url, href) link["href"] = absolute_url - resp = requests.head(absolute_url, timeout=TIMEOUT) - if resp.status_code != 200: - # logger.warning( - # f"Link {absolute_url} is broken. Status code: {resp.status_code}" - # ) + try: + resp = requests.head(absolute_url, timeout=TIMEOUT) + if resp.status_code != 200: + pass + except Exception as e: pass return str(soup) @@ -101,8 +103,18 @@ def read_pdf(self, temp_file_path: str): return documents def read_txt(self, temp_file_path: str): - loader = TextLoader(temp_file_path, autodetect_encoding=True) - return loader.load() + if temp_file_path.startswith("http"): + return self.read_txt_from_url(temp_file_path) + else: + loader = TextLoader(temp_file_path, autodetect_encoding=True) + return loader.load() + + def read_txt_from_url(self, url: str): + response = requests.get(url, timeout=TIMEOUT) + if response.status_code == 200: + return [Document(page_content=response.text)] + else: + return None def read_docx(self, temp_file_path: str): loader = Docx2txtLoader(temp_file_path) @@ -122,7 +134,11 @@ def read_youtube_transcript(self, url: str): return loader.load() def read_html(self, url: str): - return [Document(page_content=self.web_reader.read_html(url))] + return [ + Document( + page_content=self.web_reader.read_html(url), metadata={"source": url} + ) + ] def read_tex_from_url(self, tex_url): response = requests.get(tex_url, timeout=TIMEOUT) @@ -152,19 +168,21 @@ def read_notebook(self, notebook_path): notebook_path = notebook_path.replace("/blob/", "/") self.logger.info(f"Changed notebook path to {notebook_path}") - return read_notebook_from_file( + notebook_content = read_notebook_from_file( notebook_path, headers_to_split_on=self.config["content"]["notebookheaders_to_split_on"], ) + return [Document(page_content=notebook_content)] + class ChunkProcessor: def __init__(self, config, logger): self.config = config self.logger = logger - self.document_data = {} - self.document_metadata = {} + self.document_data = [] + self.document_metadata = [] self.document_chunks_full = [] # TODO: Fix when reparse_files is False @@ -217,7 +235,7 @@ def process_chunks( self, documents, file_type="txt", source="", page=0, metadata={} ): # TODO: Clear up this pipeline of re-adding metadata - documents = [Document(page_content=documents, source=source, page=page)] + # documents = [Document(page_content=documents, source=source, page=page)] if ( file_type == "pdf" and self.config["splitter_options"]["chunking_mode"] == "fixed" @@ -227,14 +245,24 @@ def process_chunks( document_chunks = self.splitter.split_documents(documents) # add the source and page number back to the metadata - for chunk in document_chunks: - chunk.metadata["source"] = source - chunk.metadata["page"] = page - + for i, chunk in enumerate(document_chunks): + # print( + # f"Chunk {i}: {chunk.page_content[:100] + '...' +chunk.page_content[-100:]}" + # ) # add the metadata extracted from the document for key, value in metadata.items(): chunk.metadata[key] = value + chunk_name = ( + os.path.basename(source) if os.path.basename(source) != "" else source + ) + if "README" in chunk_name: + chunk_name = "/".join(source.split("/")[-3:]) + + chunk.metadata["source"] = source + chunk.metadata["page"] = i + chunk.metadata["chunk_id"] = f"{chunk_name}_{i}" + if self.config["splitter_options"]["remove_leftover_delimiters"]: document_chunks = self.remove_delimiters(document_chunks) if self.config["splitter_options"]["remove_chunks"]: @@ -275,73 +303,89 @@ def chunk_docs(self, file_reader, uploaded_files, weblinks): [file_reader] * len(weblinks), [addl_metadata] * len(weblinks), ) - - document_names = [ - f"{file_name}_{page_num}" - for file_name, pages in self.document_data.items() - for page_num in pages.keys() - ] documents = [ - page for doc in self.document_data.values() for page in doc.values() + "".join([chunk.page_content for chunk in doc["chunks"]]) + for doc in self.document_data ] - document_metadata = [ - page for doc in self.document_metadata.values() for page in doc.values() + chunks = [chunk for doc in self.document_data for chunk in doc["chunks"]] + document_names = [doc["document_name"] for doc in self.document_data] + document_metadata = [doc["metadata"] for doc in self.document_data] + + self.document_data_dict = [ + { + "document_name": doc["document_name"], + "metadata": doc["metadata"], + "chunks": [ + { + "content": chunk.page_content, + "source": chunk.metadata["source"], + "page": chunk.metadata["page"], + "chunk_id": chunk.metadata["chunk_id"], + } + for chunk in doc["chunks"] + ], + } + for doc in self.document_data ] self.save_document_data() - self.logger.info( - f"Total document chunks extracted: {len(self.document_chunks_full)}" - ) + total_chunks = sum([len(doc["chunks"]) for doc in self.document_data]) + self.logger.info(f"Total document chunks extracted: {(total_chunks)}") - return self.document_chunks_full, document_names, documents, document_metadata + return chunks, document_names, documents, document_metadata def process_documents( self, documents, file_path, file_type, metadata_source, addl_metadata ): - file_data = {} - file_metadata = {} - - for i, doc in enumerate(documents): - page_num = doc.metadata.get("page", i) - file_data[page_num] = doc.page_content - - # Create a new dictionary for metadata in each iteration - metadata = doc.metadata - metadata["source"] = file_path - metadata["page"] = page_num - - if self.config["metadata"]["lectures_pattern"] in file_path: - addl_metadata_copy = addl_metadata.copy() - metadata.update(addl_metadata_copy) - metadata["content_type"] = "lecture" - elif self.config["metadata"]["assignments_pattern"] in file_path: - addl_metadata = LLMMetadataExtractor( - fields=self.config["metadata"]["assignment_metadata_fields"] - ).extract_metadata(file_path) - - metadata.update(addl_metadata) - metadata["content_type"] = "assignment" - else: - metadata["content_type"] = "other" - file_metadata[page_num] = metadata + doc_name = os.path.basename(file_path) + if "README" in doc_name: + doc_name = file_path - if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]: - document_chunks = self.process_chunks( - doc.page_content, - file_type, - source=file_path, - page=page_num, - metadata=metadata, - ) - self.document_chunks_full.extend(document_chunks) + if doc_name == "": + doc_name = file_path + + doc_metadata = { + "source": file_path, + } - self.document_data[file_path] = file_data - self.document_metadata[file_path] = file_metadata + # Processing metadata for documents + if self.config["metadata"]["lectures_pattern"] in file_path: + addl_metadata_copy = addl_metadata.copy() + doc_metadata.update(addl_metadata_copy) + doc_metadata["content_type"] = "lecture" + elif self.config["metadata"]["assignments_pattern"] in file_path: + addl_metadata = LLMMetadataExtractor( + fields=self.config["metadata"]["assignment_metadata_fields"] + ).extract_metadata(file_path) + + doc_metadata.update(addl_metadata) + doc_metadata["content_type"] = "assignment" + else: + doc_metadata["content_type"] = "other" + + # Chunking + if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]: + document_chunks = self.process_chunks( + documents, + file_type, + source=file_path, + metadata=doc_metadata, + ) + self.document_chunks_full.extend(document_chunks) + + file_data = { + "document_name": doc_name, + "metadata": doc_metadata, + "chunks": document_chunks, + } + + self.document_data.append(file_data) + # self.document_metadata[file_path] = doc_metadata def process_file(self, file_path, file_index, file_reader, addl_metadata): - print(f"Processing file {file_index + 1} : {file_path}") + self.logger.info(f"Processing file {file_index + 1} : {file_path}") file_name = os.path.basename(file_path) file_type = file_name.split(".")[-1] @@ -353,6 +397,7 @@ def process_file(self, file_path, file_index, file_reader, addl_metadata): "srt": file_reader.read_srt, "tex": file_reader.read_tex_from_url, "ipynb": file_reader.read_notebook, + "md": file_reader.read_txt, } if file_type not in read_methods: self.logger.warning(f"Unsupported file type: {file_type}") @@ -388,55 +433,44 @@ def process_weblink(self, link, link_index, file_reader, addl_metadata): else: documents = file_reader.read_html(link) - self.process_documents(documents, link, "txt", "link", addl_metadata) + if len(set([doc.metadata["source"] for doc in documents])) > 1: + self.logger.warning( + f"Documents from link {link_index + 1} : {link} have multiple sources" + ) + for doc in documents: + self.process_documents( + [doc], doc.metadata["source"], "txt", "link", addl_metadata + ) + else: + self.process_documents(documents, link, "txt", "link", addl_metadata) except Exception as e: self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}") + self.logger.error(f"Error traceback: {traceback.format_exc()}") def save_document_data(self): - if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"): - os.makedirs(f"{self.config['log_chunk_dir']}/docs") - self.logger.info( - f"Creating directory {self.config['log_chunk_dir']}/docs for document data" - ) - self.logger.info( - f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json" - ) - if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"): - os.makedirs(f"{self.config['log_chunk_dir']}/metadata") - self.logger.info( - f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata" - ) + if not os.path.exists(f"{self.config['log_chunk_dir']}"): + os.makedirs(f"{self.config['log_chunk_dir']}") + self.logger.info(f"Creating directory {self.config['log_chunk_dir']}") self.logger.info( - f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json" + f"Saving document content to {self.config['log_chunk_dir']}/doc_content.json" ) - with open( - f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w" - ) as json_file: - json.dump(self.document_data, json_file, indent=4) - with open( - f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w" - ) as json_file: - json.dump(self.document_metadata, json_file, indent=4) + with open(f"{self.config['log_chunk_dir']}/doc_content.json", "w") as json_file: + json.dump(self.document_data_dict, json_file, indent=4) def load_document_data(self): try: with open( - f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r" + f"{self.config['log_chunk_dir']}/doc_content.json", "r" ) as json_file: - self.document_data = json.load(json_file) - with open( - f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r" - ) as json_file: - self.document_metadata = json.load(json_file) + self.document_data_dict = json.load(json_file) self.logger.info( - f"Loaded document content from {self.config['log_chunk_dir']}/docs/doc_content.json. Total documents: {len(self.document_data)}" + f"Loaded document content from {self.config['log_chunk_dir']}/doc_content.json. Total documents: {len(self.document_data)}" ) except FileNotFoundError: self.logger.warning( - f"Document content not found in {self.config['log_chunk_dir']}/docs/doc_content.json" + f"Document content not found in {self.config['log_chunk_dir']}/doc_content.json" ) - self.document_data = {} - self.document_metadata = {} + self.document_data_dict = {} class DataLoader: @@ -458,13 +492,18 @@ def get_chunks(self, uploaded_files, weblinks): parser = argparse.ArgumentParser(description="Data Loader") parser.add_argument( - "--config_file", type=str, help="Path to the main config file", required=True + "--config_file", + type=str, + help="Path to the main config file", + # required=True, + default="config/config.yml", ) parser.add_argument( "--project_config_file", type=str, help="Path to the project config file", - required=True, + # required=True, + default="config/project_config.yml", ) args = parser.parse_args() @@ -511,3 +550,4 @@ def get_chunks(self, uploaded_files, weblinks): print(document_names[:5]) print(len(document_chunks)) + print(document_chunks[2].page_content[:100]) diff --git a/edubotics_core/dataloader/pdf_readers/gpt.py b/edubotics_core/dataloader/pdf_readers/gpt.py index d495ca7..f87351f 100644 --- a/edubotics_core/dataloader/pdf_readers/gpt.py +++ b/edubotics_core/dataloader/pdf_readers/gpt.py @@ -80,6 +80,7 @@ def parse(self, pdf_path): Document(page_content=page, metadata={"source": pdf_path, "page": i}) for i, page in enumerate(output) ] + return documents def encode_image(self, image): diff --git a/edubotics_core/dataloader/repo_readers/github.py b/edubotics_core/dataloader/repo_readers/github.py index dd13953..0c318e0 100644 --- a/edubotics_core/dataloader/repo_readers/github.py +++ b/edubotics_core/dataloader/repo_readers/github.py @@ -1,3 +1,5 @@ +import os +from cohere import Document import requests import base64 from edubotics_core.dataloader.repo_readers.helpers import extract_notebook_content @@ -65,10 +67,10 @@ def get_repo_contents(self, url): branch (str, optional): The branch to fetch the contents from. Defaults to 'main'. path (str, optional): The path to the repository. Defaults to ''. """ - repo_owner, repo_name, branch = self.parse_github_url(url) + repo_owner, repo_name, branch, path = self.parse_github_url(url) # top level path is '' - return self.read_github_repo_contents(repo_owner, repo_name, branch) + return self.read_github_repo_contents(repo_owner, repo_name, branch, path) def read_github_repo_contents(self, repo_owner, repo_name, branch="main", path=""): """ @@ -100,16 +102,14 @@ def read_github_repo_contents(self, repo_owner, repo_name, branch="main", path=" file_path = item["path"] extension = file_path.split(".")[-1] - if self.repo_allow_list: - if not any( - pattern in file_path for pattern in self.repo_allow_list - ): - continue if file_path in self.ignore_files or extension in self.ignore_ext: continue file_content = self.get_github_file_content( - repo_owner, repo_name, file_path, branch + repo_owner, + repo_name, + file_path, + branch, ) full_path = f"https://github.com/{repo_owner}/{repo_name}/blob/{branch}/{file_path}" @@ -156,7 +156,6 @@ def get_github_file_content( decoded_content = base64.b64decode(content).decode("utf-8") if not decoded_content.strip(): - print(f"File {file_path} is empty.") return None if file_path.endswith(".ipynb"): @@ -164,7 +163,6 @@ def get_github_file_content( return decoded_content else: - print(f"Failed to fetch file: {response.status_code}") return None @staticmethod @@ -190,26 +188,22 @@ def parse_github_url(url): repo_owner = path_parts[0] repo_name = path_parts[1] branch = "main" # Default branch + path = os.path.join(*path_parts[4:]) if len(path_parts) > 3 and path_parts[2] == "tree": branch = path_parts[3] - return repo_owner, repo_name, branch + return repo_owner, repo_name, branch, path # Usage example if __name__ == "__main__": # Set up argparse to get username and github_url as arguments - parser = argparse.ArgumentParser(description="Read GitHub repository contents.") - parser.add_argument( - "--github_url", type=str, help="GitHub repository URL", required=True - ) - args = parser.parse_args() - github_url = args.github_url + github_url = "https://github.com/DL4DS/sp2024_notebooks/tree/main/discussion/disc4" reader = GithubReader() # Initialize the GithubReader - owner, name, branch = GithubReader.parse_github_url(github_url) + owner, name, branch, path = GithubReader.parse_github_url(github_url) print(f"Owner: {owner}, Repo: {name}, Branch: {branch}") repo_contents = reader.get_repo_contents(github_url) diff --git a/edubotics_core/dataloader/repo_readers/helpers.py b/edubotics_core/dataloader/repo_readers/helpers.py index 0fa4a93..6cd3fa7 100644 --- a/edubotics_core/dataloader/repo_readers/helpers.py +++ b/edubotics_core/dataloader/repo_readers/helpers.py @@ -3,6 +3,7 @@ import requests import argparse from langchain_text_splitters import MarkdownHeaderTextSplitter +from langchain.docstore.document import Document def read_notebook_from_url(notebook_url): @@ -71,8 +72,8 @@ def extract_notebook_content( headers_to_split_on=headers_to_split_on, strip_headers=False ) - chunks = markdown_splitter.split_text(content) - return chunks + # chunks = markdown_splitter.split_text(content) + return content if __name__ == "__main__": diff --git a/edubotics_core/dataloader/webpage_crawler.py b/edubotics_core/dataloader/webpage_crawler.py index a4f3246..e98eca1 100644 --- a/edubotics_core/dataloader/webpage_crawler.py +++ b/edubotics_core/dataloader/webpage_crawler.py @@ -42,10 +42,12 @@ async def get_links(self, session: ClientSession, website_link: str, base_url: s href = link["href"].strip() full_url = urljoin(base_url, href) normalized_url = self.normalize_url(full_url) # sections removed + url_without_extension = normalized_url.rsplit(".", 1)[0] if ( normalized_url not in self.dict_href_links - # and self.is_child_url(normalized_url, base_url) and self.url_exists(normalized_url) + and url_without_extension not in self.dict_href_links + and self.is_relevant_link(full_url, base_url) ): self.dict_href_links[normalized_url] = None list_links.append(normalized_url) @@ -164,3 +166,60 @@ async def _search_links( return found_url return None + + def is_relevant_link(self, url: str, base_url: str) -> bool: + """ + Determines if a link is relevant to the tutor based on multiple criteria. + + Args: + url: The URL to check + base_url: The base URL of the course/platform + """ + # Skip if it's not a valid URL format + if not url or url.startswith("mailto:") or url.startswith("tel:"): + return False + + # Skip common irrelevant file types + irrelevant_extensions = [ + ".zip", + ".exe", + ".dmg", + ".pkg", + ".mp3", + ".mp4", + ".avi", + ".mov", + ".jpg", + ".jpeg", + ".png", + ".gif", + ] + if any(url.lower().endswith(ext) for ext in irrelevant_extensions): + return False + + # Skip social media and common external platforms + irrelevant_domains = ["facebook.com", "twitter.com", "instagram.com"] + domain = urlparse(url).netloc.lower() + if any(site in domain for site in irrelevant_domains): + return False + + # Check for relevant URL patterns + relevant_patterns = [ + "lecture", + "assignment", + "course", + "material", + "resource", + "syllabus", + "schedule", + "homework", + "quiz", + "lab", + "project" "discussion", + "schedule", + "notebook", + "slides", + ] + if any(pattern in url for pattern in relevant_patterns): + return True + return False diff --git a/edubotics_core/retriever/fusion_retriever.py b/edubotics_core/retriever/fusion_retriever.py new file mode 100644 index 0000000..e22df59 --- /dev/null +++ b/edubotics_core/retriever/fusion_retriever.py @@ -0,0 +1,115 @@ +import numpy as np +import bm25s +import os + +from langchain_core.vectorstores.base import VectorStore +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever as BaseRetrieverLangchain +from .base import BaseRetriever as BaseRetrieverEdubotics + +from vectorstore.embedding_model_loader import EmbeddingModelLoader +from langchain_community.vectorstores import FAISS +from config.config_manager import config_manager + + +class FusionRetrieverBase(BaseRetrieverLangchain): + vectorstore: VectorStore = None + config: dict = None + alpha: float = 0.5 + k: int = 10 + bm25: bm25s.BM25 = None + + def __init__(self, vectorstore: VectorStore, config): + super().__init__() + self.vectorstore = vectorstore # FAISS + self.config = config + self.alpha = 0.5 + self.k = config["vectorstore"]["search_top_k"] + self.bm25 = bm25s.BM25.load( + os.path.join( + self.config["vectorstore"]["db_path"], "db_fusion", "bm25_index" + ), + load_corpus=True, + ) + + print("FusionRetrieverBase initialized") + + def _get_relevant_documents(self, query: str, **kwargs) -> list[Document]: + all_docs = self.vectorstore.similarity_search_with_relevance_scores( + "", k=self.vectorstore.index.ntotal + ) + + bm25_scores = self.bm25.get_scores(query.split()) + + vector_results = self.vectorstore.similarity_search_with_relevance_scores( + query, k=len(all_docs) + ) + + vector_scores = np.array([score for _, score in vector_results]) + vector_scores = 1 - (vector_scores - np.min(vector_scores)) / ( + np.max(vector_scores) - np.min(vector_scores) + ) + + bm25_scores = bm25_scores - np.min(bm25_scores) + + # Check if the max and min are the same to avoid division by zero + bm25_range = np.max(bm25_scores) - np.min(bm25_scores) + if bm25_range > 0: + bm25_scores = bm25_scores / bm25_range + else: + bm25_scores = np.zeros_like(bm25_scores) # or handle it as needed + + combined_scores = self.alpha * vector_scores + (1 - self.alpha) * bm25_scores + + sorted_indices = np.argsort(combined_scores)[::-1] + + return [ + Document( + page_content=all_docs[i][0].page_content, + metadata={"score": combined_scores[i], **all_docs[i][0].metadata}, + ) + for i in sorted_indices[:10] + ] + + async def _aget_relevant_documents(self, query: str, **kwargs) -> list[Document]: + return self._get_relevant_documents(query, **kwargs) + + +class FusionRetriever(BaseRetrieverEdubotics): + + def __init__(self): + return + + def return_retriever(self, vectorstore: VectorStore, config): + retriever = FusionRetrieverBase(vectorstore, config) + return retriever + + +if __name__ == "__main__": + config = config_manager.get_config().dict() + + embedding_model_loader = EmbeddingModelLoader(config) + embedding_model = embedding_model_loader.load_embedding_model() + + cwd = os.getcwd() + + faiss_path = os.path.join( + cwd, + config["vectorstore"]["db_path"], + "db_FAISS_sentence-transformers", + "all-MiniLM-L6-v2_semantic", + ) + + vectorstore = FAISS.load_local( + faiss_path, embeddings=embedding_model, allow_dangerous_deserialization=True + ) + retriever = FusionRetrieverBase(vectorstore, config) + + docs = retriever._get_relevant_documents("**Question 1**") + print(f"Number of documents retrieved: {len(docs)}") + + for doc in docs: + print(doc.page_content[:50]) + print(doc.metadata["score"]) + print(doc.metadata["source"]) + print("----") diff --git a/edubotics_core/retriever/helpers.py b/edubotics_core/retriever/helpers.py index 5c21e50..3e828c7 100644 --- a/edubotics_core/retriever/helpers.py +++ b/edubotics_core/retriever/helpers.py @@ -3,6 +3,8 @@ from langchain.schema.document import Document from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun from typing import List +from edubotics_core.config.constants import COHERE_API_KEY +import cohere class VectorStoreRetrieverScore(VectorStoreRetriever): @@ -30,9 +32,18 @@ async def _aget_relevant_documents( query, **self.search_kwargs ) ) - # Make the score part of the document metadata - for doc, similarity in docs_and_similarities: - doc.metadata["score"] = similarity - docs = [doc for doc, _ in docs_and_similarities] - return docs + + cohere_client = cohere.Client(COHERE_API_KEY) + + docs_content = [doc.page_content for doc in docs if doc.page_content != ""] + response = cohere_client.rerank( + query=query, documents=docs_content, top_n=5, model="rerank-english-v3.0" + ) + + final_docs = [] + for result in response.results: + doc = docs[result.index] + doc.metadata["score"] = result.relevance_score + final_docs.append(doc) + return final_docs diff --git a/edubotics_core/retriever/mvs_retriever.py b/edubotics_core/retriever/mvs_retriever.py new file mode 100644 index 0000000..e1ecf51 --- /dev/null +++ b/edubotics_core/retriever/mvs_retriever.py @@ -0,0 +1,21 @@ +from edubotics_core.retriever.faiss_retriever import FaissRetriever +import os + + +class MvsRetriever: + def __init__(self, config): + self.config = config + self.return_top_k = config["vectorstore"]["search_top_k"] + + def load_retrievers(self): + self.retrievers = {} + for content_type in self.config["vectorstore"]["content_types"]: + path = os.path.join( + self.config["vectorstore"]["db_path"], "mvs", f"FAISS_{content_type}" + ) + self.retrievers[content_type] = FaissRetriever().return_retriever( + path, self.config + ) + + def return_retriever(self): + return self.retrievers diff --git a/edubotics_core/vectorstore/embedding_model_loader.py b/edubotics_core/vectorstore/embedding_model_loader.py index bb59910..07c9b78 100644 --- a/edubotics_core/vectorstore/embedding_model_loader.py +++ b/edubotics_core/vectorstore/embedding_model_loader.py @@ -1,5 +1,5 @@ from langchain_community.embeddings import OpenAIEmbeddings -from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_huggingface import HuggingFaceEmbeddings from edubotics_core.config.constants import OPENAI_API_KEY, HUGGINGFACE_TOKEN @@ -11,7 +11,7 @@ def load_embedding_model(self): if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]: embedding_model = OpenAIEmbeddings( deployment="SL-document_embedder", - model=self.config["vectorestore"]["model"], + model=self.config["vectorstore"]["model"], show_progress_bar=True, openai_api_key=OPENAI_API_KEY, disallowed_special=(), diff --git a/edubotics_core/vectorstore/helpers.py b/edubotics_core/vectorstore/helpers.py index e69de29..86abea8 100644 --- a/edubotics_core/vectorstore/helpers.py +++ b/edubotics_core/vectorstore/helpers.py @@ -0,0 +1,30 @@ +from typing import List, Dict +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.documents import Document + + +def determine_content_type(document: Document) -> str: + """ + Determine the content type of a document based on its source. + TODO: Need to un-hardcode the if statements, use the content type patterns listed in the config instead. + """ + source = document.metadata["source"] + if "/assignments/" in source or "midterm" in source: + return "assignment" + elif ( + "/lectures/" in source + or "Course-Notes" in source + or "slides" in source + or "lecture" in source + ): + return "lecture" + elif ( + "/discussions/" in source + or "discussion_slides" in source + or "discussion" in source + ): + return "discussion" + elif "/project/" in source: + return "project" + else: + return "other" diff --git a/edubotics_core/vectorstore/mvs.py b/edubotics_core/vectorstore/mvs.py new file mode 100644 index 0000000..630f2e5 --- /dev/null +++ b/edubotics_core/vectorstore/mvs.py @@ -0,0 +1,108 @@ +from langchain_community.vectorstores import FAISS +from edubotics_core.vectorstore.base import VectorStoreBase +from edubotics_core.retriever.faiss_retriever import FaissRetriever +from edubotics_core.vectorstore.embedding_model_loader import EmbeddingModelLoader +from edubotics_core.vectorstore.helpers import determine_content_type + +import os + + +class MultiVectorStore(VectorStoreBase): + """ + Implementation of the multi-vector approach, where each vector store corresponds to a different content type. + A parent folder in the db_path called "mvs" is created, and within it, a folder for each vector store (content type) is created. + """ + + def __init__(self, config): + self.config = config + self.vectorstores = {} + self.content_types = config["metadata"]["content_types"] + self.local_path = os.path.join( + self.config["vectorstore"]["db_path"], + "mvs", + self.config["vectorstore"]["db_option"] + "_", + ) + + def _init_vector_db(self): + """ + Initializes the vector stores for each content type. Simply creates an empty FAISS vector store for each content type. + """ + for content_type in self.content_types: + self.vectorstores[content_type] = FAISS( + embedding_function=None, index=0, index_to_docstore_id={}, docstore={} + ) + + def create_database(self, document_chunks, embedding_model): + """ + Creates and saves the vector stores for each content type. + """ + content_map = {} + + for content_type in self.content_types: + content_map[content_type] = list( + filter( + lambda x: determine_content_type(x) == content_type, document_chunks + ) + ) + + for content_type in content_map: + content_chunks = content_map[content_type] + if len(content_chunks) > 0: + for chunk in content_chunks: + chunk.metadata["content_type"] = content_type + + self.vectorstores[content_type] = FAISS.from_documents( + documents=content_chunks, embedding=embedding_model + ) + self.vectorstores[content_type].save_local( + self.local_path + (content_type or "") + ) + else: + print(f"No content chunks found for {content_type}") + + def load_database(self, embedding_model) -> dict: + """ + Loads the vector stores for each content type. + """ + for content_type in self.content_types: + try: + path = self.local_path + (content_type or "") + self.vectorstores[content_type] = FAISS.load_local( + path, embedding_model, allow_dangerous_deserialization=True + ) + except Exception as e: + print(f"Error loading vector store for {content_type}: {e}") + continue + return self.vectorstores + + def as_retriever(self): + """ + Returns retrievers for each content type as a dictionary. + """ + retrievers = {} + embedding_model = EmbeddingModelLoader(self.config).load_embedding_model() + vectorstores = self.load_database(embedding_model) + for content_type in self.content_types: + if content_type in vectorstores: + retriever = FaissRetriever().return_retriever( + vectorstores[content_type], self.config + ) + retrievers[content_type] = retriever + else: + print(f"No vector store found for {content_type}") + + return retrievers + + def __len__(self): + """ + Returns the total number of documents in all vector stores. + """ + return sum( + len(self.vectorstores[content_type]) for content_type in self.content_types + ) + + def __str__(self): + """ + Returns the string representation of the MultiVectorStore. + """ + return f"MultiVectorStore with {len(self)} documents" diff --git a/edubotics_core/vectorstore/store_manager.py b/edubotics_core/vectorstore/store_manager.py index 178a581..c3aa959 100644 --- a/edubotics_core/vectorstore/store_manager.py +++ b/edubotics_core/vectorstore/store_manager.py @@ -3,6 +3,9 @@ from edubotics_core.dataloader.webpage_crawler import WebpageCrawler from edubotics_core.dataloader.data_loader import DataLoader from edubotics_core.vectorstore.embedding_model_loader import EmbeddingModelLoader +from edubotics_core.vectorstore.helpers import determine_content_type +from langchain_core.documents import Document + import logging import os import time @@ -21,7 +24,7 @@ def __init__(self, config, logger=None): self.webpage_crawler = WebpageCrawler() self.vector_db = VectorStore(self.config) - self.logger.info("VectorDB instance instantiated") + self.logger.info("Vector database instantiated") def _setup_logging(self): logger = logging.getLogger(__name__) @@ -84,15 +87,19 @@ def initialize_database( documents: list, document_metadata: list, ): - if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]: + if self.config["vectorstore"]["db_option"] in [ + "FAISS", + "Chroma", + "RAPTOR", + "MVS", + ]: self.embedding_model = self.create_embedding_model() else: self.embedding_model = None - self.logger.info("Initializing vector_db") - self.logger.info( - "\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"]) - ) + self.logger.info("Initializing vector database...") + self.logger.info(f"There are {len(document_chunks)} chunks. ") + self.vector_db._create_database( document_chunks, document_names, @@ -106,8 +113,6 @@ def create_database(self): data_loader = DataLoader(self.config, self.logger) self.logger.info("Loading data") local_files, urls = self.load_files() - # print(f"Local files: {local_files}") - # print(f"URLs: {urls}") files, webpages = self.webpage_crawler.clean_url_list(urls) files.extend(local_files) self.logger.info(f"Number of files: {len(files)}") @@ -146,9 +151,6 @@ def load_database(self): raise ValueError( f"Error loading database, check if it exists. if not run python -m edubotics_core.vectorstore.store_manager / Resteart the HF Space: {e}" ) - # print(f"Creating database") - # self.create_database() - # self.loaded_vector_db = self.vector_db._load_database(self.embedding_model) end_time = time.time() # End time for loading database self.logger.info( f"Time taken to load database {self.config['vectorstore']['db_option']}: {end_time - start_time} seconds" @@ -194,7 +196,6 @@ def main(): # combine the two configs config.update(project_config) - print(config) print(f"Trying to create database with config: {config}") vector_db = VectorStoreManager(config) if config["vectorstore"]["load_from_HF"]: @@ -208,9 +209,6 @@ def main(): ] ) else: - # print(f"HF_PATH not available for {config['vectorstore']['db_option']}") - # print("Creating database") - # vector_db.create_database() raise ValueError( f"HF_PATH not available for {config['vectorstore']['db_option']}" ) @@ -218,11 +216,6 @@ def main(): vector_db.create_database() print("Created database") - print("Trying to load the database") - vector_db = VectorStoreManager(config) - vector_db.load_database() - print("Loaded database") - print(f"View the logs at {config['log_dir']}/vector_db.log") diff --git a/edubotics_core/vectorstore/vectorstore.py b/edubotics_core/vectorstore/vectorstore.py index 81d801e..e39abba 100644 --- a/edubotics_core/vectorstore/vectorstore.py +++ b/edubotics_core/vectorstore/vectorstore.py @@ -1,6 +1,7 @@ from edubotics_core.vectorstore.faiss import FaissVectorStore from edubotics_core.vectorstore.chroma import ChromaVectorStore from edubotics_core.vectorstore.colbert import ColbertVectorStore +from edubotics_core.vectorstore.mvs import MultiVectorStore from edubotics_core.vectorstore.raptor import RAPTORVectoreStore from huggingface_hub import snapshot_download import os @@ -16,6 +17,7 @@ def __init__(self, config): "Chroma": ChromaVectorStore, "RAGatouille": ColbertVectorStore, "RAPTOR": RAPTORVectoreStore, + "MVS": MultiVectorStore, } def _create_database( @@ -29,7 +31,9 @@ def _create_database( db_option = self.config["vectorstore"]["db_option"] vectorstore_class = self.vectorstore_classes.get(db_option) if not vectorstore_class: - raise ValueError(f"Invalid db_option: {db_option}") + raise ValueError( + f"Invalid vector database option: {db_option}. Please pick from: {self.vectorstore_classes.keys()}" + ) self.vectorstore = vectorstore_class(self.config)