diff --git a/src/python_bindings/bindings.cpp b/src/python_bindings/bindings.cpp index b68c14653..6addb7de9 100644 --- a/src/python_bindings/bindings.cpp +++ b/src/python_bindings/bindings.cpp @@ -558,6 +558,44 @@ class PyBFIndex : public PyVecSimIndex { #if HAVE_SVS class PySVSIndex : public PyVecSimIndex { +private: + template // size_t/double for KNN/range queries. + using QueryFunc = + std::function; + + template // size_t/double for KNN / range queries. + void runParallelQueries(const py::array &queries, size_t n_queries, search_param_t param, + VecSimQueryParams *query_params, int n_threads, + QueryFunc queryFunc, VecSimQueryReply **results) { + + // Use number of hardware cores as default number of threads, unless specified otherwise. + if (n_threads <= 0) { + n_threads = (int)std::thread::hardware_concurrency(); + } + std::atomic_int global_counter(0); + + auto parallel_search = [&](const py::array &items) { + while (true) { + int ind = global_counter.fetch_add(1); + if (ind >= n_queries) { + break; + } + results[ind] = queryFunc((const char *)items.data(ind), param, query_params); + } + }; + std::thread thread_objs[n_threads]; + { + // Release python GIL while threads are running. + py::gil_scoped_release py_gil; + for (size_t i = 0; i < n_threads; i++) { + thread_objs[i] = std::thread(parallel_search, queries); + } + for (size_t i = 0; i < n_threads; i++) { + thread_objs[i].join(); + } + } + } + public: explicit PySVSIndex(const SVSParams &svs_params) { VecSimParams params = {.algo = VecSimAlgo_SVS, .algoParams = {.svsParams = svs_params}}; @@ -567,6 +605,25 @@ class PySVSIndex : public PyVecSimIndex { } } + py::object searchKnnParallel(const py::object &input, size_t k, VecSimQueryParams *query_params, + int n_threads) { + + py::array queries(input); + if (queries.ndim() != 2) { + throw std::runtime_error("Input queries array must be 2D array"); + } + size_t n_queries = queries.shape(0); + QueryFunc searchKnnWrapper( + [this](const char *query_, size_t k_, + VecSimQueryParams *query_params_) -> VecSimQueryReply * { + return this->searchKnnInternal(query_, k_, query_params_); + }); + VecSimQueryReply *results[n_queries]; + runParallelQueries(queries, n_queries, k, query_params, n_threads, searchKnnWrapper, + results); + return wrap_results(results, k, n_queries); + } + explicit PySVSIndex(const std::string &location, const SVSParams &svs_params) { VecSimParams params = {.algo = VecSimAlgo_SVS, .algoParams = {.svsParams = svs_params}}; this->index = @@ -843,6 +900,8 @@ PYBIND11_MODULE(VecSim, m) { return new PySVSIndex(location, params); }), py::arg("location"), py::arg("params")) + .def("knn_parallel", &PySVSIndex::searchKnnParallel, py::arg("queries"), py::arg("k"), + py::arg("query_param") = nullptr, py::arg("num_threads") = -1) .def("add_vector_parallel", &PySVSIndex::addVectorsParallel, py::arg("vectors"), py::arg("labels")) .def("check_integrity", &PySVSIndex::checkIntegrity)