Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 172 additions & 13 deletions AnnService/inc/Core/Common/BKTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <string>
#include <vector>
#include <mutex>
#include <atomic>
#include <thread>
#include <shared_mutex>
#include "inc/Core/VectorIndex.h"

Expand Down Expand Up @@ -231,9 +233,7 @@ namespace SPTAG

std::vector<std::thread> mythreads;
mythreads.reserve(args._TH);
for (int tid = 0; tid < args._TH; tid++)
{
mythreads.emplace_back([tid, first, last, updateCenters, lambda, subsize, &data, &indices, &args, &currDist]() {
auto kmeansWorker = [&](int tid) {
SizeType istart = first + tid * subsize;
SizeType iend = min(first + (tid + 1) * subsize, last);
SizeType *inewCounts = args.newCounts + tid * args._K;
Expand Down Expand Up @@ -295,13 +295,14 @@ namespace SPTAG
}
}
COMMON::Utils::atomic_float_add(&currDist, idist);
});
}
for (auto &t : mythreads)
{
t.join();
};
if (args._TH <= 1) {
kmeansWorker(0); // serial fast-path: avoid spawning/joining a thread per node
} else {
for (int tid = 0; tid < args._TH; tid++) mythreads.emplace_back(kmeansWorker, tid);
for (auto &t : mythreads) t.join();
mythreads.clear();
}
mythreads.clear();
for (int i = 1; i < args._TH; i++) {
for (int k = 0; k < args._DK; k++) {
args.newCounts[k] += args.newCounts[i * args._K + k];
Expand Down Expand Up @@ -527,7 +528,7 @@ break;
class BKTree
{
public:
BKTree(): m_iTreeNumber(1), m_iBKTKmeansK(32), m_iBKTLeafSize(8), m_iSamples(1000), m_fBalanceFactor(-1.0f), m_bfs(0), m_lock(new std::shared_timed_mutex), m_pQuantizer(nullptr) {}
BKTree(): m_iTreeNumber(1), m_iBKTKmeansK(32), m_iBKTLeafSize(8), m_iSamples(1000), m_fBalanceFactor(-1.0f), m_bfs(0), m_lock(new std::shared_timed_mutex), m_pQuantizer(nullptr), m_parallelBuild(false), m_treeLock(new std::mutex), m_sampleMapLock(new std::mutex) {}

BKTree(const BKTree& other): m_iTreeNumber(other.m_iTreeNumber),
m_iBKTKmeansK(other.m_iBKTKmeansK),
Expand All @@ -536,7 +537,10 @@ break;
m_fBalanceFactor(other.m_fBalanceFactor),
m_lock(new std::shared_timed_mutex),
m_pQuantizer(other.m_pQuantizer),
m_bfs(0) {}
m_bfs(0),
m_parallelBuild(other.m_parallelBuild),
m_treeLock(new std::mutex),
m_sampleMapLock(new std::mutex) {}
~BKTree() {}

// Read-only access to the node stored at position `index` of m_pTreeRoots.
inline const BKTNode& operator[](SizeType index) const { return m_pTreeRoots[index]; }
Expand Down Expand Up @@ -570,11 +574,161 @@ break;
m_pSampleCenterMap.swap(newTrees.m_pSampleCenterMap);
}

struct BKTJob {
SizeType index, first, last;
bool debug;
BKTJob(SizeType index_ = 0, SizeType first_ = 0, SizeType last_ = 0, bool debug_ = false)
: index(index_), first(first_), last(last_), debug(debug_) {}
};

// Cluster ONE node into children. The heavy clustering runs lock-free; only the
// structural commit to m_pTreeRoots / m_pSampleCenterMap is done under locks.
// Sibling nodes own disjoint [first,last) ranges of localindices (incl. the centerid
// slot at `last`), so concurrent calls on different nodes never overlap.
template <typename T>
void ProcessOneNode(const Dataset<T>& data, std::vector<SizeType>& localindices,
KmeansArgs<T>& args, BKTJob item, std::vector<SizeType>* reverseIndices,
bool dynamicK, IAbortOperation* abort, std::vector<BKTJob>& outChildren)
{
outChildren.clear();
if (item.last - item.first <= m_iBKTLeafSize) {
std::lock_guard<std::mutex> lk(*m_treeLock);
m_pTreeRoots[item.index].childStart = (SizeType)m_pTreeRoots.size();
for (SizeType j = item.first; j < item.last; j++) {
SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]);
m_pTreeRoots.emplace_back(cid);
}
m_pTreeRoots[item.index].childEnd = (SizeType)m_pTreeRoots.size();
return;
}

if (dynamicK) {
args._DK = std::min<int>((item.last - item.first) / m_iBKTLeafSize + 1, m_iBKTKmeansK);
args._DK = std::max<int>(args._DK, 2);
}

int numClusters = KmeansClustering(data, localindices, item.first, item.last, args,
m_iSamples, m_fBalanceFactor, item.debug, abort);

std::lock_guard<std::mutex> lk(*m_treeLock);
m_pTreeRoots[item.index].childStart = (SizeType)m_pTreeRoots.size();
if (numClusters <= 1) {
SizeType end = min(item.last + 1, (SizeType)localindices.size());
std::sort(localindices.begin() + item.first, localindices.begin() + end);
SizeType center = (reverseIndices == nullptr) ? localindices[item.first] : reverseIndices->at(localindices[item.first]);
m_pTreeRoots[item.index].centerid = center;
m_pTreeRoots[item.index].childStart = -m_pTreeRoots[item.index].childStart;
std::lock_guard<std::mutex> sm(*m_sampleMapLock);
for (SizeType j = item.first + 1; j < end; j++) {
SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]);
m_pTreeRoots.emplace_back(cid);
m_pSampleCenterMap[cid] = center;
}
m_pSampleCenterMap[-1 - center] = item.index;
}
else {
SizeType maxCount = 0;
for (int k = 0; k < m_iBKTKmeansK; k++) if (args.counts[k] > maxCount) maxCount = args.counts[k];
SizeType first = item.first;
for (int k = 0; k < m_iBKTKmeansK; k++) {
if (args.counts[k] == 0) continue;
SizeType cid = (reverseIndices == nullptr) ? localindices[first + args.counts[k] - 1] : reverseIndices->at(localindices[first + args.counts[k] - 1]);
SizeType childIdx = (SizeType)m_pTreeRoots.size();
m_pTreeRoots.emplace_back(cid);
if (args.counts[k] > 1)
outChildren.emplace_back(childIdx, first, first + args.counts[k] - 1, item.debug && (args.counts[k] == maxCount));
first += args.counts[k];
}
}
m_pTreeRoots[item.index].childEnd = (SizeType)m_pTreeRoots.size();
}

// Two-phase parallel build:
// Phase 1 (top): serial node schedule, each node clustered with all threads
// (intra-node data parallelism handles the few large nodes well).
// Phase 2 (bottom): subtrees of size <= cutoff are built in parallel, one worker per
// subtree (per-worker args, _TH=1), sharing the single N-sized label.
template <typename T>
void BuildTrees(const Dataset<T>& data, DistCalcMethod distMethod, int numOfThreads,
std::vector<SizeType>* indices = nullptr, std::vector<SizeType>* reverseIndices = nullptr,
void BuildTreesParallel(const Dataset<T>& data, DistCalcMethod distMethod, int numOfThreads,
    std::vector<SizeType>* indices, std::vector<SizeType>* reverseIndices,
    bool dynamicK, IAbortOperation* abort)
{
    // Builds m_iTreeNumber trees over `indices` (or over all rows of `data` when
    // `indices` is null). Each tree is built in two phases:
    //   Phase 1 - nodes larger than `cutoff` are processed one at a time with the shared
    //             multi-threaded `args`.
    //   Phase 2 - the remaining subtrees are handed out to worker threads, each with its
    //             own single-threaded KmeansArgs that borrows the shared label buffer.
    // Returns early, leaving the current tree unfinished, once `abort` reports an abort.

    // A non-positive thread count would divide by zero in the cutoff computation below.
    const int threadCount = std::max<int>(numOfThreads, 1);

    std::vector<SizeType> localindices;
    if (indices == nullptr) {
        localindices.resize(data.R());
        for (SizeType i = 0; i < (SizeType)localindices.size(); i++) localindices[i] = i;
    }
    else {
        localindices.assign(indices->begin(), indices->end());
    }
    KmeansArgs<T> args(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), threadCount, distMethod, m_pQuantizer);
    if (m_fBalanceFactor < 0) m_fBalanceFactor = DynamicFactorSelect(data, localindices, 0, (SizeType)localindices.size(), args, m_iSamples);

    std::mt19937 rg;
    m_pSampleCenterMap.clear();
    // int counter: a char counter compared against the int m_iTreeNumber can overflow and
    // is a poor subscript type.
    for (int t = 0; t < m_iTreeNumber; t++) {
        std::shuffle(localindices.begin(), localindices.end(), rg);
        m_pTreeStart.push_back((SizeType)m_pTreeRoots.size());
        m_pTreeRoots.emplace_back((SizeType)localindices.size());
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d (parallel)\n", t + 1);

        const SizeType cutoff = std::max<SizeType>((SizeType)(m_iBKTLeafSize * 4),
            (SizeType)(localindices.size() / ((size_t)threadCount * 16) + 1));

        // ---- Phase 1: serial schedule + intra-node multithread (shared args) ----
        std::vector<BKTJob> deferred;
        std::stack<BKTJob> ss;
        ss.push(BKTJob(m_pTreeStart[t], 0, (SizeType)localindices.size(), true));
        while (!ss.empty()) {
            if (abort && abort->ShouldAbort()) return;
            BKTJob job = ss.top(); ss.pop();
            std::vector<BKTJob> kids;
            ProcessOneNode<T>(data, localindices, args, job, reverseIndices, dynamicK, abort, kids);
            for (size_t c = 0; c < kids.size(); c++) {
                if (kids[c].last - kids[c].first <= cutoff) deferred.push_back(kids[c]);
                else ss.push(kids[c]);
            }
        }

        // ---- Phase 2: node-level parallelism over small subtrees ----
        if (!deferred.empty()) {
            int* sharedLabel = args.label; // workers write disjoint [first,last) slices
            std::atomic<size_t> nextJob(0);
            // Never start more workers than there are subtrees to hand out.
            const size_t workerCount = std::min<size_t>((size_t)threadCount, deferred.size());
            std::vector<std::thread> pool;
            pool.reserve(workerCount);
            for (size_t w = 0; w < workerCount; w++) {
                pool.emplace_back([&]() {
                    KmeansArgs<T> wargs(m_iBKTKmeansK, data.C(), 1, 1, distMethod, m_pQuantizer);
                    delete[] wargs.label; wargs.label = sharedLabel; // share N-sized label
                    std::stack<BKTJob> wss;
                    for (size_t idx = nextJob.fetch_add(1); idx < deferred.size(); idx = nextJob.fetch_add(1)) {
                        if (abort && abort->ShouldAbort()) break;
                        wss.push(deferred[idx]);
                        while (!wss.empty()) {
                            // Per-node abort check, matching phase 1.
                            if (abort && abort->ShouldAbort()) break;
                            BKTJob job = wss.top(); wss.pop();
                            std::vector<BKTJob> kids;
                            ProcessOneNode<T>(data, localindices, wargs, job, reverseIndices, dynamicK, abort, kids);
                            for (size_t c = 0; c < kids.size(); c++) wss.push(kids[c]);
                        }
                    }
                    wargs.label = nullptr; // don't double-free the shared label in dtor
                });
            }
            for (auto& worker : pool) worker.join();

            // Same exit as phase 1: do not report an aborted tree as built.
            if (abort && abort->ShouldAbort()) return;
        }

        m_pTreeRoots.emplace_back(-1);
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d BKTree built (parallel), %zu %zu\n", t + 1, m_pTreeRoots.size() - m_pTreeStart[t], localindices.size());
    }
}

template <typename T>
void BuildTrees(const Dataset<T>& data, DistCalcMethod distMethod, int numOfThreads,
std::vector<SizeType>* indices = nullptr, std::vector<SizeType>* reverseIndices = nullptr,
bool dynamicK = false, IAbortOperation* abort = nullptr)
{
if (m_parallelBuild && numOfThreads > 1) {
BuildTreesParallel<T>(data, distMethod, numOfThreads, indices, reverseIndices, dynamicK, abort);
return;
}
struct BKTStackItem {
SizeType index, first, last;
bool debug;
Expand Down Expand Up @@ -865,6 +1019,11 @@ break;
int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples, m_bfs;
float m_fBalanceFactor;
std::shared_ptr<SPTAG::COMMON::IQuantizer> m_pQuantizer;

// Parallel BKT build (node-level parallelism for SelectHead). Default off.
bool m_parallelBuild;
std::unique_ptr<std::mutex> m_treeLock; // guards m_pTreeRoots structural writes
std::unique_ptr<std::mutex> m_sampleMapLock; // guards m_pSampleCenterMap writes
};
}
}
Expand Down
7 changes: 6 additions & 1 deletion AnnService/inc/Core/Common/CommonUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <string.h>
#include <vector>
#include <set>
#include <random>

#define PREFETCH

Expand All @@ -33,7 +34,11 @@ namespace SPTAG
public:
static SizeType rand(SizeType high = MaxSize, SizeType low = 0) // Generates a random int value.
{
    // Returns a pseudo-random value in the half-open range [low, high).
    //
    // One generator per thread: std::rand() shares hidden global state and races under
    // concurrent BKT building. The generator is seeded from std::random_device, so
    // sequences differ from the legacy std::rand() version and between runs.
    static thread_local std::mt19937 g(std::random_device{}());

    // Fraction in [0, 1): the largest draw is max(), divided by max() + 1.
    const double fraction = g() / (g.max() + 1.0);

    // Scale in double, not float: float rounds a large range upward (e.g. 2^31 - 1
    // becomes 2^31), which lets the result reach `high`.
    return low + (SizeType)(double(high - low) * fraction);
}

static inline float atomic_float_add(volatile float* ptr, const float operand)
Expand Down
1 change: 1 addition & 0 deletions AnnService/inc/Core/SPANN/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ namespace SPTAG {
bool m_recursiveCheckSmallCluster;
bool m_printSizeCount;
std::string m_selectType;
bool m_selectHeadParallel;

// Section 3: for build head
bool m_buildHead;
Expand Down
1 change: 1 addition & 0 deletions AnnService/inc/Core/SPANN/ParameterDefinitionList.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ DefineSelectHeadParameter(m_headVectorCount, int, 0, "Count")
DefineSelectHeadParameter(m_recursiveCheckSmallCluster, bool, true, "RecursiveCheckSmallCluster")
DefineSelectHeadParameter(m_printSizeCount, bool, true, "PrintSizeCount")
DefineSelectHeadParameter(m_selectType, std::string, "BKT", "SelectHeadType")
DefineSelectHeadParameter(m_selectHeadParallel, bool, false, "ParallelBuildBKT")
#endif

#ifdef DefineBuildHeadParameter
Expand Down
Loading