forked from techwithtim/ProductionGradeRAGPythonApp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvector_db.py
More file actions
37 lines (31 loc) · 1.35 KB
/
vector_db.py
File metadata and controls
37 lines (31 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
class QdrantStorage:
    """Thin wrapper around a single Qdrant collection used for RAG retrieval.

    Points are stored with cosine distance. Each point's payload is expected to
    carry a ``"text"`` chunk and optionally a ``"source"`` label; ``search``
    returns those fields for the nearest neighbours of a query vector.
    """

    def __init__(self, url="http://localhost:6333", collection="docs", dim=3072, timeout=30):
        """Connect to Qdrant and create the collection if it does not exist.

        Args:
            url: Base URL of the Qdrant server.
            collection: Name of the collection to read from and write to.
            dim: Vector size used when *creating* the collection. Presumably
                this must match the embedding model used by callers — confirm.
            timeout: Client request timeout passed to ``QdrantClient``.
        """
        self.client = QdrantClient(url=url, timeout=timeout)
        self.collection = collection
        # NOTE(review): an already-existing collection is reused as-is; its
        # vector size is not checked against ``dim``.
        if not self.client.collection_exists(self.collection):
            self.client.create_collection(
                collection_name=self.collection,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )

    def upsert(self, ids, vectors, payloads):
        """Insert or overwrite points in the collection.

        Args:
            ids: Point ids; ``ids[i]`` pairs with ``vectors[i]`` and ``payloads[i]``.
            vectors: One embedding per id.
            payloads: One payload dict per id.

        Raises:
            ValueError: If the three sequences differ in length. Without this
                check a mismatch would silently drop items or fail with
                ``IndexError`` partway through building the points.
        """
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError(
                "ids, vectors and payloads must have the same length "
                f"(got {len(ids)}, {len(vectors)}, {len(payloads)})"
            )
        if not ids:
            # Nothing to write; avoid a pointless round-trip to the server.
            return
        points = [
            PointStruct(id=point_id, vector=vector, payload=payload)
            for point_id, vector, payload in zip(ids, vectors, payloads)
        ]
        self.client.upsert(collection_name=self.collection, points=points)

    def search(self, query_vector, top_k: int = 5) -> dict:
        """Return the texts and sources of the ``top_k`` nearest points.

        Args:
            query_vector: Query embedding (same dimensionality as the collection).
            top_k: Maximum number of hits to retrieve.

        Returns:
            ``{"contexts": [...], "sources": [...]}``. See ``_collect_hits``
            for how hits are filtered and de-duplicated.
        """
        if hasattr(self.client, "query_points"):
            # Newer qdrant-client releases expose the Query API and deprecate
            # ``search``; prefer it when available.
            hits = self.client.query_points(
                collection_name=self.collection,
                query=query_vector,
                with_payload=True,
                limit=top_k,
            ).points
        else:
            # Fallback for older clients that predate ``query_points``.
            hits = self.client.search(
                collection_name=self.collection,
                query_vector=query_vector,
                with_payload=True,
                limit=top_k,
            )
        return self._collect_hits(hits)

    @staticmethod
    def _collect_hits(hits) -> dict:
        """Aggregate search hits into context texts and unique source labels.

        Hits with a missing or empty ``"text"`` are skipped entirely. Every
        remaining text is kept in hit order. Sources are de-duplicated while
        preserving first-seen (i.e. relevance) order, and empty source labels
        are dropped.
        """
        contexts = []
        sources = {}  # dict used as an insertion-ordered set
        for hit in hits:
            payload = getattr(hit, "payload", None) or {}
            text = payload.get("text", "")
            if not text:
                continue
            contexts.append(text)
            source = payload.get("source", "")
            if source:
                sources[source] = None
        return {"contexts": contexts, "sources": list(sources)}