Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 59 additions & 15 deletions unstructured2graph/examples/graphrag.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,76 @@
import argparse
import asyncio
import logging
import os

from memgraph_toolbox.api.memgraph import Memgraph
from openai import OpenAI

from unstructured2graph import compute_embeddings, create_vector_search_index
from loading import from_unstructured_with_prep
import prompt_templates


if __name__ == "__main__":
async def full_graphrag(args):
#### INGESTION
# TODO(gitbuda): Add the import here.
memgraph = Memgraph()
compute_embeddings(memgraph, "Chunk")
create_vector_search_index(memgraph, "Chunk", "embedding")
if args.ingestion:
await from_unstructured_with_prep()
compute_embeddings(memgraph, "Chunk")
create_vector_search_index(memgraph, "Chunk", "embedding")

#### RETRIEVAL / GRAPHRAG
# The Native/One-query GraphRAG!
# TODO(gitbuda): In the current small graph, the Chunks are not connected via the entity graph.
#### RETRIEVAL / GRAPHRAG -> The Native/One-query GraphRAG!
prompt = "What is different under v3.7 compared to v3.6?"
retrieved_chunks = []
for row in memgraph.query(
f"""
CALL embeddings.text(['Hello world prompt']) YIELD embeddings, success
CALL vector_search.search('vs_name', 10, embeddings[0]) YIELD distance, node, similarity
MATCH (node)-[r*bfs]-(dst)
CALL embeddings.text(['{prompt}']) YIELD embeddings, success
CALL vector_search.search('vs_name', 5, embeddings[0]) YIELD distance, node, similarity
MATCH (node)-[r*bfs]-(dst:Chunk)
WITH DISTINCT dst, degree(dst) AS degree ORDER BY degree DESC
RETURN dst LIMIT 10;
RETURN dst LIMIT 5;
"""
):
if "description" in row["dst"]:
print(row["dst"]["description"])
retrieved_chunks.append(row["dst"]["description"])
if "text" in row["dst"]:
print(row["dst"]["text"])
print("----")
retrieved_chunks.append(row["dst"]["text"])

#### SUMMARIZATION
# TODO(gitbuda): Call LLM to generate the final answer.
if not retrieved_chunks:
print("No chunks retrieved. Cannot generate answer.")
else:
context = "\n\n".join(retrieved_chunks)
system_message = prompt_templates.system_message
user_message = prompt_templates.user_message(context, prompt)
if not os.environ.get("OPENAI_API_KEY"):
raise ValueError(
"OPENAI_API_KEY environment variable is not set. Please set your OpenAI API key."
)
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
completion = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
],
temperature=0.1,
)
answer = completion.choices[0].message.content
print(f"\nQuestion: {prompt}")
print(f"\nAnswer:\n{answer}")


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(
description="GraphRAG: Retrieve and answer questions using graph-based RAG"
)
parser.add_argument(
"--ingestion",
action="store_true",
help="Run data ingestion (load documents, compute embeddings, create vector index). By default, ingestion is skipped.",
)
args = parser.parse_args()

asyncio.run(full_graphrag(args))
8 changes: 6 additions & 2 deletions unstructured2graph/examples/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from memgraph_toolbox.api.memgraph import Memgraph
from unstructured2graph import from_unstructured, create_index
from sources import SOURCES
import sources as SOURCES

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
LIGHTRAG_DIR = os.path.join(SCRIPT_DIR, "..", "lightrag_storage.out")
Expand All @@ -34,7 +34,11 @@ async def from_unstructured_with_prep():
await lightrag_wrapper.initialize(working_dir=LIGHTRAG_DIR)

await from_unstructured(
SOURCES, memgraph, lightrag_wrapper, only_chunks=False, link_chunks=True
SOURCES.MEMGRAPH_DOCS_GITHUB_LATEST_RAW,
memgraph,
lightrag_wrapper,
only_chunks=False,
link_chunks=True,
)
await lightrag_wrapper.afinalize()

Expand Down
17 changes: 17 additions & 0 deletions unstructured2graph/examples/prompt_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
system_message = f"""
You are a helpful assistant that answers questions based on the provided context.
Use only the information from the context to answer the question.
If the context doesn't contain enough information to answer the question, say so.
"""

# Create the prompt with context
user_message = (
lambda context, prompt: f"""
Based on the following context, please answer the question.

Context: {context}

Question: {prompt}

Answer:"""
)
16 changes: 15 additions & 1 deletion unstructured2graph/examples/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,25 @@
pypdf_samples_dir = os.path.join(SCRIPT_DIR, "..", "sample-data", "pdf", "sample-files")
docx_samples_dir = os.path.join(SCRIPT_DIR, "..", "sample-data", "doc")
xls_samples_dir = os.path.join(SCRIPT_DIR, "..", "sample-data", "xls")
SOURCES = [
RANDOM = [
os.path.join(
pypdf_samples_dir, "011-google-doc-document", "google-doc-document.pdf"
),
os.path.join(docx_samples_dir, "sample3.docx"),
# os.path.join(xls_samples_dir, "financial-sample.xlsx"),
# "https://memgraph.com/docs/ai-ecosystem/graph-rag",
]
MEMGRAPH_DOCS = [
"https://memgraph.com/docs/querying/clauses",
"https://memgraph.com/docs/clustering/high-availability",
]

MEMGRAPH_DOCS_GITHUB_LATEST = [
"https://github.com/memgraph/documentation/pull/1452/files"
]

MEMGRAPH_DOCS_GITHUB_LATEST_RAW = [
"https://raw.githubusercontent.com/memgraph/documentation/f6f165649b89efc51fa4153fffc08ff5304ca0c9/pages/database-management/authentication-and-authorization/mlbac-migration-guide.mdx",
# "https://raw.githubusercontent.com/memgraph/documentation/f6f165649b89efc51fa4153fffc08ff5304ca0c9/pages/database-management/authentication-and-authorization/role-based-access-control.mdx",
# "https://raw.githubusercontent.com/memgraph/documentation/40ab6644f7113aa5cb86faa48961d2cb2c34f2cc/pages/data-migration/parquet.mdx",
]
17 changes: 11 additions & 6 deletions unstructured2graph/src/unstructured2graph/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,10 @@ async def from_unstructured(
)
memgraph_node_props = []
for chunk in document.chunks:
if not only_chunks:
await lightrag_wrapper.ainsert(
input=chunk.text, file_paths=[chunk.hash]
)
logger.info(f"Chunk: {chunk.hash} - {chunk.text}")
memgraph_node_props.append({"hash": chunk.hash, "text": chunk.text})
create_nodes_from_list(memgraph, memgraph_node_props, "Chunk", 100)

if link_chunks:
hash_pairs = [
(document.chunks[i].hash, document.chunks[i + 1].hash)
Expand All @@ -149,6 +147,15 @@ async def from_unstructured(
for from_hash, to_hash in hash_pairs
]
link_nodes_in_order(memgraph, "Chunk", "hash", relationships, "NEXT")

for chunk in document.chunks:
if not only_chunks:
await lightrag_wrapper.ainsert(
input=chunk.text, file_paths=[chunk.hash]
)
if not only_chunks:
connect_chunks_to_entities(memgraph, "Chunk", "base")

processed_chunks += len(document.chunks)
elapsed_time = time.time() - start_time
estimated_time_remaining = (
Expand All @@ -168,5 +175,3 @@ async def from_unstructured(
logger.info(
f"Processed {processed_chunks} chunks out of {total_chunks}. Estimated time remaining: {time_str}"
)
if not only_chunks:
connect_chunks_to_entities(memgraph, "Chunk", "base")
6 changes: 5 additions & 1 deletion unstructured2graph/src/unstructured2graph/memgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def create_nodes_from_list(

def connect_chunks_to_entities(memgraph: Memgraph, chunk_label: str, entity_label: str):
memgraph.query(
f"MATCH (n:{entity_label}), (m:{chunk_label}) WHERE n.file_path = m.hash CREATE (n)-[:MENTIONED_IN]->(m);"
f"""
MATCH (n:{entity_label}), (m:{chunk_label})
WHERE n.file_path = m.hash
MERGE (n)-[:MENTIONED_IN]->(m);
"""
)


Expand Down
12 changes: 11 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading