Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,14 @@ using enable_if_valid_list_t = typename enable_if_valid_list<ListT, T>::type;
* `cuvs::neighbors::ivf_pq::helpers::resize_list` which handle type casting internally.
*/
template <typename ListT>
CUVS_EXPORT void resize_list(raft::resources const& res,
std::shared_ptr<ListT>& orig_list, // NOLINT
const typename ListT::spec_type& spec,
typename ListT::size_type new_used_size,
typename ListT::size_type old_logical_size,
typename ListT::size_type old_used_size);

template <typename ListT>
CUVS_EXPORT void resize_list(raft::resources const& res,
std::shared_ptr<ListT>& orig_list, // NOLINT
const typename ListT::spec_type& spec,
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -286,6 +286,7 @@ void extend(raft::resources const& handle,
lists[label],
list_device_spec,
new_list_sizes[label],
old_list_sizes[label],
raft::Pow2<kIndexGroupSize>::roundUp(old_list_sizes[label]));
}
}
Expand Down
20 changes: 17 additions & 3 deletions cpp/src/neighbors/ivf_list.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -66,13 +66,17 @@ CUVS_EXPORT void resize_list(raft::resources const& res,
std::shared_ptr<ListT>& orig_list, // NOLINT
const typename ListT::spec_type& spec,
typename ListT::size_type new_used_size,
typename ListT::size_type old_logical_size,
typename ListT::size_type old_used_size)
{
// old_logical_size is the previous visible size from this index's list_sizes().
// old_used_size is the old allocation copy extent and may include padded slots
// required by interleaved list layouts.
bool skip_resize = false;
if (orig_list) {
if (new_used_size <= orig_list->indices.extent(0)) {
auto shared_list_size = old_used_size;
if (new_used_size <= old_used_size ||
auto shared_list_size = old_logical_size;
if (new_used_size <= old_logical_size ||
orig_list->size.compare_exchange_strong(shared_list_size, new_used_size)) {
// We don't need to resize the list if:
// 1. The list exists
Expand Down Expand Up @@ -104,6 +108,16 @@ CUVS_EXPORT void resize_list(raft::resources const& res,
new_list.swap(orig_list);
}

template <typename ListT>
CUVS_EXPORT void resize_list(raft::resources const& res,
std::shared_ptr<ListT>& orig_list, // NOLINT
const typename ListT::spec_type& spec,
typename ListT::size_type new_used_size,
typename ListT::size_type old_used_size)
{
resize_list(res, orig_list, spec, new_used_size, old_used_size, old_used_size);
}

template <typename ListT>
enable_if_valid_list_t<ListT> serialize_list(const raft::resources& handle,
std::ostream& os,
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/neighbors/ivf_sq/ivf_sq_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -358,6 +358,7 @@ void extend_inplace(raft::resources const& handle,
lists[label],
list_device_spec,
new_list_sizes[label],
old_list_sizes[label],
raft::Pow2<kIndexGroupSize>::roundUp(old_list_sizes[label]));
}
}
Expand Down
74 changes: 73 additions & 1 deletion cpp/tests/neighbors/ann_ivf_flat/test_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#include <gtest/gtest.h>

#include "../ann_ivf_flat.cuh"

#include <numeric>
#include <vector>

namespace cuvs::neighbors::ivf_flat {

typedef AnnIVFFlatTest<float, float, int64_t> AnnIVFFlatTestF_float;
Expand All @@ -19,4 +22,73 @@ TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat)

INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs));

TEST(AnnIVFFlatTest, RepeatedExtendCopyPreservesSharedListWithinCapacity)
{
raft::resources handle;
auto stream = raft::resource::get_cuda_stream(handle);

constexpr int64_t base_rows = 100;
constexpr int64_t grow_rows = 20;
constexpr int64_t rows = base_rows + grow_rows;
constexpr int64_t dim = 4;

std::vector<float> host_data(rows * dim);
for (int64_t row = 0; row < rows; row++) {
for (int64_t col = 0; col < dim; col++) {
host_data[row * dim + col] = static_cast<float>(row + col);
}
}

auto data = raft::make_device_matrix<float, int64_t>(handle, rows, dim);
raft::copy(data.data_handle(), host_data.data(), host_data.size(), stream);

index_params params;
params.n_lists = 1;
params.metric = cuvs::distance::DistanceType::L2Expanded;
params.add_data_on_build = false;
params.kmeans_trainset_fraction = 1.0;
params.adaptive_centers = false;
params.conservative_memory_allocation = false;

auto all_data_view =
raft::make_device_matrix_view<const float, int64_t>(data.data_handle(), rows, dim);
auto empty_index = build(handle, params, all_data_view);

auto base_data_view =
raft::make_device_matrix_view<const float, int64_t>(data.data_handle(), base_rows, dim);
auto base_index = extend(handle, base_data_view, std::nullopt, empty_index);
raft::resource::sync_stream(handle);

ASSERT_EQ(base_index.lists()[0]->get_size(), base_rows);
ASSERT_GE(base_index.lists()[0]->indices_capacity(), rows);

std::vector<int64_t> host_indices(grow_rows);
std::iota(host_indices.begin(), host_indices.end(), base_rows);
auto indices = raft::make_device_vector<int64_t, int64_t>(handle, grow_rows);
raft::copy(indices.data_handle(), host_indices.data(), host_indices.size(), stream);

auto grow_data_view = raft::make_device_matrix_view<const float, int64_t>(
data.data_handle() + base_rows * dim, grow_rows, dim);
auto grow_indices_view =
raft::make_device_vector_view<const int64_t, int64_t>(indices.data_handle(), grow_rows);
auto first_grown_index =
extend(handle,
grow_data_view,
std::make_optional<raft::device_vector_view<const int64_t, int64_t>>(grow_indices_view),
base_index);
raft::resource::sync_stream(handle);

ASSERT_EQ(first_grown_index.lists()[0]->get_size(), rows);

auto second_grown_index =
extend(handle,
grow_data_view,
std::make_optional<raft::device_vector_view<const int64_t, int64_t>>(grow_indices_view),
base_index);
raft::resource::sync_stream(handle);

EXPECT_NE(first_grown_index.lists()[0].get(), second_grown_index.lists()[0].get());
EXPECT_EQ(second_grown_index.lists()[0]->get_size(), rows);
}

} // namespace cuvs::neighbors::ivf_flat
62 changes: 61 additions & 1 deletion cpp/tests/neighbors/ann_ivf_sq/test_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -1,17 +1,77 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#include <gtest/gtest.h>

#include "../ann_ivf_sq.cuh"

#include <numeric>
#include <vector>

namespace cuvs::neighbors::ivf_sq {

typedef AnnIVFSQTest<float, float, int64_t> AnnIVFSQTestF_float;
TEST_P(AnnIVFSQTestF_float, AnnIVFSQ) { this->testAll(); }

INSTANTIATE_TEST_CASE_P(AnnIVFSQTest, AnnIVFSQTestF_float, ::testing::ValuesIn(inputs));

TEST(AnnIVFSQTest, ExtendInPlaceUpdatesListSizeWithinCapacity)
{
raft::resources handle;
auto stream = raft::resource::get_cuda_stream(handle);

constexpr int64_t base_rows = 100;
constexpr int64_t grow_rows = 20;
constexpr int64_t rows = base_rows + grow_rows;
constexpr int64_t dim = 4;

std::vector<float> host_data(rows * dim);
for (int64_t row = 0; row < rows; row++) {
for (int64_t col = 0; col < dim; col++) {
host_data[row * dim + col] = static_cast<float>(row + col);
}
}

auto data = raft::make_device_matrix<float, int64_t>(handle, rows, dim);
raft::copy(data.data_handle(), host_data.data(), host_data.size(), stream);

index_params params;
params.n_lists = 1;
params.metric = cuvs::distance::DistanceType::L2Expanded;
params.add_data_on_build = false;
params.max_train_points_per_cluster = 256;
params.conservative_memory_allocation = false;

auto all_data_view =
raft::make_device_matrix_view<const float, int64_t>(data.data_handle(), rows, dim);
auto index = build(handle, params, all_data_view);

auto base_data_view =
raft::make_device_matrix_view<const float, int64_t>(data.data_handle(), base_rows, dim);
extend(handle, base_data_view, std::nullopt, &index);
raft::resource::sync_stream(handle);

ASSERT_EQ(index.lists()[0]->get_size(), base_rows);
ASSERT_GE(index.lists()[0]->indices_capacity(), rows);

std::vector<int64_t> host_indices(grow_rows);
std::iota(host_indices.begin(), host_indices.end(), base_rows);
auto indices = raft::make_device_vector<int64_t, int64_t>(handle, grow_rows);
raft::copy(indices.data_handle(), host_indices.data(), host_indices.size(), stream);

auto grow_data_view = raft::make_device_matrix_view<const float, int64_t>(
data.data_handle() + base_rows * dim, grow_rows, dim);
auto grow_indices_view =
raft::make_device_vector_view<const int64_t, int64_t>(indices.data_handle(), grow_rows);
extend(handle,
grow_data_view,
std::make_optional<raft::device_vector_view<const int64_t, int64_t>>(grow_indices_view),
&index);
raft::resource::sync_stream(handle);

EXPECT_EQ(index.lists()[0]->get_size(), rows);
}

} // namespace cuvs::neighbors::ivf_sq
Loading