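"""MariaDB KB vector store generator.

Reads Knowledge Base article URLs from a CSV file, caches each page as a local
HTML file, extracts the text, splits it into overlapping chunks, embeds the
chunks with OpenAI embeddings, and pickles the resulting FAISS index to disk.

Requires OPENAI_API_KEY to be set, e.g. via a .env file. Typical invocation
(the script filename is illustrative):

    python generate_vectorstore.py --csv-file kb_urls.csv --vectorstore-path vectorstore.pkl
"""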
import argparse
import pickle
import os
import csv
import openai
import re
import requests

from langchain.document_loaders import BSHTMLLoader
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from dotenv import load_dotenv

# Read OPENAI_API_KEY from a local .env file (if present) into the environment
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

def parse_args():
    parser = argparse.ArgumentParser(description='MariaDB KB Vector Store Generator')
    parser.add_argument('--csv-file', type=str, default='kb_urls.csv', help='Path to the input CSV file containing the URLs')
    parser.add_argument('--tmp-dir', type=str, default='tmp', help='Directory where the temporary HTML files will be stored')
    parser.add_argument('--vectorstore-path', type=str, default='vectorstore.pkl', help='Path to save the generated FAISS vector store pickle file')
    parser.add_argument('--chunk-size', type=int, default=4000, help='Chunk size for splitting the documents')
    parser.add_argument('--chunk-overlap', type=int, default=200, help='Overlap size between chunks when splitting documents')
    return parser.parse_args()

def download_web_page(url, tmp_dir):
    response = requests.get(url)

    if response.status_code == 200:
        # Derive a filesystem-safe cache filename from the URL
        filename = url.replace('://', '_').replace('/', '_') + '.html'

        with open(os.path.join(tmp_dir, filename), 'w', encoding='utf-8') as file:
            file.write(response.text)
    else:
        print(f"Error: Unable to fetch the web page. Status code: {response.status_code}")

def read_csv(csv_file):
    urls = []

    with open(csv_file, newline='', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(csvfile)
        for row in csv_reader:
            # Skip empty rows; strip whitespace so cached filenames stay consistent
            if row and row[0].strip():
                urls.append(row[0].strip())

    # Drop the CSV header row
    return urls[1:]

def main():
    args = parse_args()

    urls = read_csv(args.csv_file)
    os.makedirs(args.tmp_dir, exist_ok=True)

    all_docs = []
    for url in urls:
        filename = url.replace('://', '_').replace('/', '_') + '.html'
        doc_path = os.path.join(args.tmp_dir, filename)
        if not os.path.exists(doc_path):
            download_web_page(url, args.tmp_dir)
            if not os.path.exists(doc_path):
                # Download failed; skip this URL rather than crashing the loader
                continue
        loader = BSHTMLLoader(doc_path)
        doc = loader.load()[0]

        # Collapse runs of whitespace left behind by the HTML markup
        doc.page_content = re.sub(r'\s+', ' ', doc.page_content)
        doc.metadata["source"] = url

        all_docs.append(doc)

    text_splitter = CharacterTextSplitter(
        separator=" ",
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        length_function=len,
    )
    print("Loaded {} documents".format(len(all_docs)))
    all_docs = text_splitter.split_documents(all_docs)
    print("After split: {} documents".format(len(all_docs)))

    faiss_index = FAISS.from_documents(all_docs, OpenAIEmbeddings())

    with open(args.vectorstore_path, "wb") as f:
        pickle.dump(faiss_index, f)


if __name__ == "__main__":
    main()
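
# A minimal sketch of loading and querying the saved index later, assuming the
# same langchain/OpenAI environment that produced the pickle (illustrative only):
#
#   import pickle
#   with open("vectorstore.pkl", "rb") as f:
#       faiss_index = pickle.load(f)
#   for doc in faiss_index.similarity_search("How do I create an index?", k=4):
#       print(doc.metadata["source"])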