diff --git a/vectordb_bench/backend/clients/envector/cli.py b/vectordb_bench/backend/clients/envector/cli.py index b14d509b4..3197b8e3b 100644 --- a/vectordb_bench/backend/clients/envector/cli.py +++ b/vectordb_bench/backend/clients/envector/cli.py @@ -23,7 +23,7 @@ class EnVectorTypedDict(TypedDict): str, click.option("--eval-mode", help="Evaluation mode", type=click.Choice(["mm", "rmp"]), default="mm"), ] - + class EnVectorFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): ... @@ -31,7 +31,7 @@ class EnVectorFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): ... @cli.command(name="envectorflat") @click_parameter_decorators_from_typed_dict(EnVectorFlatIndexTypedDict) def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]): - from .config import FlatIndexConfig, EnVectorConfig + from .config import EnVectorConfig, FlatIndexConfig run( db=DBTYPE, @@ -46,7 +46,7 @@ def EnVectorFlat(**parameters: Unpack[EnVectorFlatIndexTypedDict]): ) -class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): +class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): nlist: Annotated[ int, click.option("--nlist", type=int, help="nlist for IVF index", default=250), @@ -76,7 +76,7 @@ class EnVectorIVFFlatIndexTypedDict(CommonTypedDict, EnVectorTypedDict): @cli.command(name="envectorivfflat") @click_parameter_decorators_from_typed_dict(EnVectorIVFFlatIndexTypedDict) def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]): - from .config import IVFFlatIndexConfig, EnVectorConfig + from .config import EnVectorConfig, IVFFlatIndexConfig run( db=DBTYPE, @@ -87,7 +87,7 @@ def EnVectorIVFFlat(**parameters: Unpack[EnVectorIVFFlatIndexTypedDict]): index_params={"nlist": parameters["nlist"], "nprobe": parameters["nprobe"]}, ), db_case_config=IVFFlatIndexConfig( - nlist=parameters["nlist"], + nlist=parameters["nlist"], nprobe=parameters["nprobe"], train_centroids=parameters["train_centroids"], centroids_path=parameters["centroids_path"], diff --git a/vectordb_bench/backend/clients/envector/config.py b/vectordb_bench/backend/clients/envector/config.py index 08db61bbd..3bc58f862 100644 --- a/vectordb_bench/backend/clients/envector/config.py +++ b/vectordb_bench/backend/clients/envector/config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, SecretStr -from ..api import DBCaseConfig, DBConfig, IndexType, MetricType, SQType +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType class EnVectorConfig(DBConfig): @@ -67,10 +67,10 @@ class IVFFlatIndexConfig(EnVectorIndexConfig, DBCaseConfig): nlist: int = 0 nprobe: int = 0 eval_mode: str = "mm" - train_centroids: bool = False # whether to train centroids before inserting data + train_centroids: bool = False # whether to train centroids before inserting data centroids_path: str | None = None # path to centroids file - is_vct: bool = False # whether use VCT index - vct_path: str | None = None # path to VCT index file + is_vct: bool = False # whether use VCT index + vct_path: str | None = None # path to VCT index file def index_param(self) -> dict: return { diff --git a/vectordb_bench/backend/clients/envector/envector.py b/vectordb_bench/backend/clients/envector/envector.py index 50b274289..0e34f44bb 100644 --- a/vectordb_bench/backend/clients/envector/envector.py +++ b/vectordb_bench/backend/clients/envector/envector.py @@ -1,23 +1,20 @@ """Wrapper around the EnVector vector database over VectorDB""" -from typing import Any, Dict - import logging import os from collections.abc import Iterable from contextlib import contextmanager -import pickle - -import numpy as np +from pathlib import Path +from typing import Any import es2 +import numpy as np from vectordb_bench.backend.filter import Filter, FilterOp from ..api import VectorDB from .config import EnVectorIndexConfig - log = logging.getLogger(__name__) @@ -45,8 +42,8 @@ def __init__( self.case_config = db_case_config self.collection_name = collection_name - self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT - + self.batch_size = 128 * 32 # default batch size for insertions, can be modified for IVF_FLAT + self._primary_field = "pk" self._scalar_id_field = "id" self._scalar_label_field = "label" @@ -57,83 +54,89 @@ def __init__( self.col: es2.Index | None = None self.is_vct: bool = False - self.vct_params: Dict[str, Any] = {} - kwargs: Dict[str, Any] = {} - + self.vct_params: dict[str, Any] = {} + es2.init( - address=self.db_config.get("uri"), - key_path=self.db_config.get("key_path"), + address=self.db_config.get("uri"), + key_path=self.db_config.get("key_path"), key_id=self.db_config.get("key_id"), eval_mode=self.case_config.eval_mode, ) if drop_old: - log.info(f"{self.name} client drop_old index: {self.collection_name}") - if self.collection_name in es2.get_index_list(): + log.info(f"{self.name} client drop_old index: {self.collection_name}") + if self.collection_name in es2.get_index_list(): es2.drop_index(self.collection_name) - + # Create the collection log.info(f"{self.name} create index: {self.collection_name}") - + + index_kwargs = dict(kwargs) + self._ensure_index(dim, index_kwargs) + + es2.disconnect() + + def _ensure_index(self, dim: int, index_kwargs: dict[str, Any]): if self.collection_name in es2.get_index_list(): log.info(f"{self.name} index {self.collection_name} already exists, skip creating") self.is_vct = self.case_config.index_param().get("is_vct", False) log.debug(f"IS_VCT: {self.is_vct}") + return + self._create_index(dim, index_kwargs) - else: - index_param = self.case_config.index_param().get("params", {}) - index_type = index_param.get("index_type", "FLAT") - train_centroids = self.case_config.index_param().get("train_centroids", False) - - if index_type == "IVF_FLAT" and train_centroids: - - centroid_path = self.case_config.index_param().get("centroids_path", None) - self.is_vct = self.case_config.index_param().get("is_vct", False) - log.debug(f"IS_VCT: {self.is_vct}") - - if centroid_path is not None: - if not os.path.exists(centroid_path): - raise FileNotFoundError(f"Centroid file {centroid_path} not found for IVF_FLAT index training.") - - # load trained centroids from file - log.debug(f"Centroids: {centroid_path}") - centroids = np.load(centroid_path) - log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.") - - # set centroids for index creation - index_param["centroids"] = centroids.tolist() - - if self.is_vct: - # set VCT parameters if applicable - vct_path = self.case_config.index_param().get("vct_path", None) - log.debug(f"VCT: {vct_path}") - index_param["virtual_cluster"] = True - kwargs["tree_description"] = vct_path - self.is_vct = True - log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.") + def _create_index(self, dim: int, index_kwargs: dict[str, Any]): + index_param = self.case_config.index_param().get("params", {}) + index_type = index_param.get("index_type", "FLAT") + train_centroids = self.case_config.index_param().get("train_centroids", False) - else: - raise ValueError("Centroids path must be provided for IVF_FLAT index training.") - - # set larger batch size for IVF_FLAT insertions - if index_type == "IVF_FLAT": - self.batch_size = int(os.environ.get("NUM_PER_BATCH", 500_000)) - log.debug( - f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. " - f"This should be the size of dataset for better performance when IVF_FLAT." - ) + if index_type == "IVF_FLAT" and train_centroids: + self._configure_centroids(index_param, index_kwargs) - # create index after training centroids - es2.create_index( - index_name=self.collection_name, - dim=dim, - key_path=self.db_config.get("key_path"), - key_id=self.db_config.get("key_id"), - index_params=index_param, - eval_mode=self.case_config.eval_mode, - **kwargs, - ) + if index_type == "IVF_FLAT": + self._adjust_batch_size() - es2.disconnect() + es2.create_index( + index_name=self.collection_name, + dim=dim, + key_path=self.db_config.get("key_path"), + key_id=self.db_config.get("key_id"), + index_params=index_param, + eval_mode=self.case_config.eval_mode, + **index_kwargs, + ) + + def _configure_centroids(self, index_param: dict[str, Any], index_kwargs: dict[str, Any]): + centroid_path = self.case_config.index_param().get("centroids_path", None) + self.is_vct = self.case_config.index_param().get("is_vct", False) + log.debug(f"IS_VCT: {self.is_vct}") + + if centroid_path is None: + raise ValueError("Centroids path must be provided for IVF_FLAT index training.") + + centroid_file = Path(centroid_path) + if not centroid_file.exists(): + msg = f"Centroid file {centroid_path} not found for IVF_FLAT index training." + raise FileNotFoundError(msg) + + log.debug(f"Centroids: {centroid_path}") + centroids = np.load(centroid_file) + log.info(f"{self.name} loaded centroids from {centroid_path} for IVF_FLAT index training.") + + index_param["centroids"] = centroids.tolist() + + if self.is_vct: + vct_path = self.case_config.index_param().get("vct_path", None) + log.debug(f"VCT: {vct_path}") + index_param["virtual_cluster"] = True + index_kwargs["tree_description"] = vct_path + self.is_vct = True + log.info(f"{self.name} VCT parameters set for IVF_FLAT index creation.") + + def _adjust_batch_size(self): + self.batch_size = int(os.environ.get("NUM_PER_BATCH", "500000")) + log.debug( + f"Set EnVector IVF_FLAT insert batch size to {self.batch_size}. " + f"This should be the size of dataset for better performance when IVF_FLAT." + ) @contextmanager def init(self): @@ -152,7 +155,7 @@ def init(self): try: self.col = es2.Index(self.collection_name) if self.is_vct: - log.debug(f"VCT: {self.col.index_config.index_param.index_params["virtual_cluster"]}") + log.debug(f"VCT: {self.col.index_config.index_param.index_params['virtual_cluster']}") is_vct = self.case_config.index_param().get("is_vct", False) assert self.is_vct == is_vct, "is_vct mismatch" vct_path = self.case_config.index_param().get("vct_path", None) @@ -190,7 +193,7 @@ def insert_embeddings( # use the first insert_embeddings to init collection assert self.col is not None assert len(embeddings) == len(metadata) - + log.debug(f"IS_VCT: {self.is_vct}") insert_count = 0 @@ -229,7 +232,7 @@ def search_embedding( output_fields=["metadata"], search_params=self.case_config.search_param().get("search_params", {}), ) - + else: # Perform the search. res = self.col.search( @@ -247,12 +250,11 @@ def search_embedding( # Extract metadata from results # res structure: [[{id: X, score: Y, metadata: Z}, ...]] log.debug(f"Search results: {res[0][:1]}") # Log first 1 results for debugging - if len(res) > 0 and len(res[0]) > 0: - return [int(result["metadata"]) for result in res[0] if "metadata" in result] - else: + if not (res and len(res[0]) > 0): log.warning(f"Unexpected result structure: {res}") return [] + return [int(result["metadata"]) for result in res[0] if "metadata" in result] - except Exception as e: - log.error(f"Search failed: {e}") + except Exception: + log.exception("Search failed") return [] diff --git a/vectordb_bench/log_util.py b/vectordb_bench/log_util.py index 3bf7e4725..5c94424a9 100644 --- a/vectordb_bench/log_util.py +++ b/vectordb_bench/log_util.py @@ -1,12 +1,12 @@ import logging +import os from logging import config from pathlib import Path -import os def init(log_level: str): os.environ["TQDM_DISABLE"] = "1" - + # Create logs directory if it doesn't exist log_dir = Path("logs") log_dir.mkdir(exist_ok=True)