Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions vectordb_bench/backend/clients/envector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ class EnVectorTypedDict(TypedDict):
str,
click.option("--eval-mode", help="Evaluation mode", type=click.Choice(["mm", "rmp"]), default="mm"),
]


class EnVectorFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): ...


@cli.command(name="envectorflat")
@click_parameter_decorators_from_typed_dict(EnVectorFlatIndexTypedDict)
def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
from .config import FlatIndexConfig, EnVectorConfig
from .config import EnVectorConfig, FlatIndexConfig

run(
db=DBTYPE,
Expand All @@ -46,7 +46,7 @@ def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]):
)


class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
nlist: Annotated[
int,
click.option("--nlist", type=int, help="nlist for IVF index", default=250),
Expand Down Expand Up @@ -76,7 +76,7 @@ class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict):
@cli.command(name="envectorivfflat")
@click_parameter_decorators_from_typed_dict(EnVectorIVFFlatIndexTypedDict)
def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
from .config import IVFFlatIndexConfig, EnVectorConfig
from .config import EnVectorConfig, IVFFlatIndexConfig

run(
db=DBTYPE,
Expand All @@ -87,7 +87,7 @@ def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]):
index_params={"nlist": parameters["nlist"], "nprobe": parameters["nprobe"]},
),
db_case_config=IVFFlatIndexConfig(
nlist=parameters["nlist"],
nlist=parameters["nlist"],
nprobe=parameters["nprobe"],
train_centroids=parameters["train_centroids"],
centroids_path=parameters["centroids_path"],
Expand Down
8 changes: 4 additions & 4 deletions vectordb_bench/backend/clients/envector/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class EnVectorConfig(DBConfig):
Expand Down Expand Up @@ -67,10 +67,10 @@ class IVFFlatIndexConfig(EnVectorIndexConfig, DBCaseConfig):
nlist: int = 0
nprobe: int = 0
eval_mode: str = "mm"
train_centroids: bool = False # whether to train centroids before inserting data
train_centroids: bool = False # whether to train centroids before inserting data
centroids_path: str | None = None # path to centroids file
is_vct: bool = False # whether use VCT index
vct_path: str | None = None # path to VCT index file
is_vct: bool = False # whether use VCT index
vct_path: str | None = None # path to VCT index file

def index_param(self) -> dict:
return {
Expand Down
156 changes: 79 additions & 77 deletions vectordb_bench/backend/clients/envector/envector.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
"""Wrapper around the EnVector vector database over VectorDB"""

from typing import Any, Dict

import logging
import os
from collections.abc import Iterable
from contextlib import contextmanager
import pickle

import numpy as np
from pathlib import Path
from typing import Any

import es2
import numpy as np

from vectordb_bench.backend.filter import Filter, FilterOp

from ..api import VectorDB
from .config import EnVectorIndexConfig


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -45,8 +42,8 @@ def __init__(
self.case_config = db_case_config
self.collection_name = collection_name

self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT
self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT

self._primary_field = "pk"
self._scalar_id_field = "id"
self._scalar_label_field = "label"
Expand All @@ -57,83 +54,89 @@ def __init__(
self.col: es2.Index | None = None

self.is_vct: bool = False
self.vct_params: Dict[str, Any] = {}
kwargs: Dict[str, Any] = {}

self.vct_params: dict[str, Any] = {}

es2.init(
address=self.db_config.get("uri"),
key_path=self.db_config.get("key_path"),
address=self.db_config.get("uri"),
key_path=self.db_config.get("key_path"),
key_id=self.db_config.get("key_id"),
eval_mode=self.case_config.eval_mode,
)
if drop_old:
log.info(f"{self.name} client drop_old index: {self.collection_name}")
if self.collection_name in es2.get_index_list():
log.info(f"{self.name} client drop_old index: {self.collection_name}")
if self.collection_name in es2.get_index_list():
es2.drop_index(self.collection_name)

# Create the collection
log.info(f"{self.name} create index: {self.collection_name}")


index_kwargs = dict(kwargs)
self._ensure_index(dim, index_kwargs)

es2.disconnect()

def _ensure_index(self, dim: int, index_kwargs: dict[str, Any]):
if self.collection_name in es2.get_index_list():
log.info(f"{self.name} index {self.collection_name} already exists, skip creating")
self.is_vct = self.case_config.index_param().get("is_vct", False)
log.debug(f"IS_VCT: {self.is_vct}")
return
self._create_index(dim, index_kwargs)

else:
index_param = self.case_config.index_param().get("params", {})
index_type = index_param.get("index_type", "FLAT")
train_centroids = self.case_config.index_param().get("train_centroids", False)

if index_type == "IVF_FLAT" and train_centroids:

centroid_path = self.case_config.index_param().get("centroids_path", None)
self.is_vct = self.case_config.index_param().get("is_vct", False)
log.debug(f"IS_VCT: {self.is_vct}")

if centroid_path is not None:
if not os.path.exists(centroid_path):
raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.")

# load trained centroids from file
log.debug(f"Centroids: {centroid_path}")
centroids = np.load(centroid_path)
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")

# set centroids for index creation
index_param["centroids"] = centroids.tolist()

if self.is_vct:
# set VCT parameters if applicable
vct_path = self.case_config.index_param().get("vct_path", None)
log.debug(f"VCT: {vct_path}")
index_param["virtual_cluster"] = True
kwargs["tree_description"] = vct_path
self.is_vct = True
log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")
def _create_index(self, dim: int, index_kwargs: dict[str, Any]):
index_param = self.case_config.index_param().get("params", {})
index_type = index_param.get("index_type", "FLAT")
train_centroids = self.case_config.index_param().get("train_centroids", False)

else:
raise ValueError("Centroids path must be provided for IVF_FLAT index training.")

# set larger batch size for IVF_FLAT insertions
if index_type == "IVF_FLAT":
self.batch_size = int(os.environ.get("NUM_PER_BATCH", 500_000))
log.debug(
f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
f"This should be the size of dataset for better performance when IVF_FLAT."
)
if index_type == "IVF_FLAT" and train_centroids:
self._configure_centroids(index_param, index_kwargs)

# create index after training centroids
es2.create_index(
index_name=self.collection_name,
dim=dim,
key_path=self.db_config.get("key_path"),
key_id=self.db_config.get("key_id"),
index_params=index_param,
eval_mode=self.case_config.eval_mode,
**kwargs,
)
if index_type == "IVF_FLAT":
self._adjust_batch_size()

es2.disconnect()
es2.create_index(
index_name=self.collection_name,
dim=dim,
key_path=self.db_config.get("key_path"),
key_id=self.db_config.get("key_id"),
index_params=index_param,
eval_mode=self.case_config.eval_mode,
**index_kwargs,
)

def _configure_centroids(self, index_param: dict[str, Any], index_kwargs: dict[str, Any]):
centroid_path = self.case_config.index_param().get("centroids_path", None)
self.is_vct = self.case_config.index_param().get("is_vct", False)
log.debug(f"IS_VCT: {self.is_vct}")

if centroid_path is None:
raise ValueError("Centroids path must be provided for IVF_FLAT index training.")

centroid_file = Path(centroid_path)
if not centroid_file.exists():
msg = f"Centroid file {centroid_path} not found for IVF_FLAT index training."
raise FileNotFoundError(msg)

log.debug(f"Centroids: {centroid_path}")
centroids = np.load(centroid_file)
log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.")

index_param["centroids"] = centroids.tolist()

if self.is_vct:
vct_path = self.case_config.index_param().get("vct_path", None)
log.debug(f"VCT: {vct_path}")
index_param["virtual_cluster"] = True
index_kwargs["tree_description"] = vct_path
self.is_vct = True
log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.")

def _adjust_batch_size(self):
self.batch_size = int(os.environ.get("NUM_PER_BATCH", "500000"))
log.debug(
f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. "
f"This should be the size of dataset for better performance when IVF_FLAT."
)

@contextmanager
def init(self):
Expand All @@ -152,7 +155,7 @@ def init(self):
try:
self.col = es2.Index(self.collection_name)
if self.is_vct:
log.debug(f"VCT: {self.col.index_config.index_param.index_params["virtual_cluster"]}")
log.debug(f"VCT: {self.col.index_config.index_param.index_params['virtual_cluster']}")
is_vct = self.case_config.index_param().get("is_vct", False)
assert self.is_vct == is_vct, "is_vct mismatch"
vct_path = self.case_config.index_param().get("vct_path", None)
Expand Down Expand Up @@ -190,7 +193,7 @@ def insert_embeddings(
# use the first insert_embeddings to init collection
assert self.col is not None
assert len(embeddings) == len(metadata)

log.debug(f"IS_VCT: {self.is_vct}")

insert_count = 0
Expand Down Expand Up @@ -229,7 +232,7 @@ def search_embedding(
output_fields=["metadata"],
search_params=self.case_config.search_param().get("search_params", {}),
)

else:
# Perform the search.
res = self.col.search(
Expand All @@ -247,12 +250,11 @@ def search_embedding(
# Extract metadata from results
# res structure: [[{id: X, score: Y, metadata: Z}, ...]]
log.debug(f"Search results: {res[0][:1]}") # Log first 1 results for debugging
if len(res) > 0 and len(res[0]) > 0:
return [int(result["metadata"]) for result in res[0] if "metadata" in result]
else:
if not (res and len(res[0]) > 0):
log.warning(f"Unexpected result structure: {res}")
return []
return [int(result["metadata"]) for result in res[0] if "metadata" in result]

except Exception as e:
log.error(f"Search failed: {e}")
except Exception:
log.exception("Search failed")
return []
4 changes: 2 additions & 2 deletions vectordb_bench/log_util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import os
from logging import config
from pathlib import Path
import os


def init(log_level: str):
os.environ["TQDM_DISABLE"] = "1"

# Create logs directory if it doesn't exist
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
Expand Down