Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/api/src/grpc/proto/points_internal_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,13 @@ message MmrInternal {
uint32 candidates_limit = 3;
}

message Bm25Internal {
string field = 1;
string query = 2;
float k1 = 3;
float b = 4;
}

message QueryShardPoints {
message Query {
oneof score {
Expand All @@ -380,6 +387,8 @@ message QueryShardPoints {
MmrInternal mmr = 6;
// Parameterized RRF fusion
Rrf rrf = 7;
// Full-text BM25 scoring
Bm25Internal bm25 = 8;
}
}

Expand Down
18 changes: 17 additions & 1 deletion lib/api/src/grpc/qdrant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10900,6 +10900,19 @@ pub struct MmrInternal {
#[derive(serde::Serialize)]
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bm25Internal {
#[prost(string, tag = "1")]
pub field: ::prost::alloc::string::String,
#[prost(string, tag = "2")]
pub query: ::prost::alloc::string::String,
#[prost(float, tag = "3")]
pub k1: f32,
#[prost(float, tag = "4")]
pub b: f32,
}
#[derive(serde::Serialize)]
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct QueryShardPoints {
#[prost(message, repeated, tag = "1")]
pub prefetch: ::prost::alloc::vec::Vec<query_shard_points::Prefetch>,
Expand Down Expand Up @@ -10929,7 +10942,7 @@ pub mod query_shard_points {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Query {
#[prost(oneof = "query::Score", tags = "1, 2, 3, 4, 5, 6, 7")]
#[prost(oneof = "query::Score", tags = "1, 2, 3, 4, 5, 6, 7, 8")]
pub score: ::core::option::Option<query::Score>,
}
/// Nested message and enum types in `Query`.
Expand Down Expand Up @@ -10959,6 +10972,9 @@ pub mod query_shard_points {
/// Parameterized RRF fusion
#[prost(message, tag = "7")]
Rrf(super::super::Rrf),
/// Full-text BM25 scoring
#[prost(message, tag = "8")]
Bm25(super::super::Bm25Internal),
}
}
#[derive(serde::Serialize)]
Expand Down
30 changes: 30 additions & 0 deletions lib/api/src/rest/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,9 @@ pub enum Query {
/// Score boosting via an arbitrary formula
Formula(FormulaQuery),

/// Score points with BM25 over a full-text payload index.
Bm25(Bm25Query),

/// Sample points from the collection, non-deterministically.
Sample(SampleQuery),

Expand Down Expand Up @@ -716,6 +719,33 @@ pub struct FormulaQuery {
pub defaults: HashMap<String, Value>,
}

#[derive(Debug, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(rename_all = "snake_case")]
pub struct Bm25Query {
/// Payload field that has a text index.
pub field: JsonPath,

/// Query text to score against the indexed field.
#[validate(length(min = 1))]
pub query: String,

/// BM25 hyperparameters. Solr/Lucene-compatible defaults are used if omitted.
#[validate(nested)]
pub params: Option<Bm25Params>,
}

#[derive(Debug, Serialize, Deserialize, JsonSchema, Validate, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct Bm25Params {
/// BM25 `k1` parameter. Default is 1.2.
#[validate(range(min = 0.0))]
pub k1: Option<f32>,

/// BM25 `b` parameter. Default is 0.75.
#[validate(range(min = 0.0, max = 1.0))]
pub b: Option<f32>,
}

#[derive(Debug, Serialize, Deserialize, JsonSchema, Validate)]
#[serde(rename_all = "snake_case")]
pub struct SampleQuery {
Expand Down
1 change: 1 addition & 0 deletions lib/api/src/rest/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl Validate for Query {
Query::Fusion(fusion) => fusion.validate(),
Query::Rrf(rrf) => rrf.validate(),
Query::Formula(formula) => formula.validate(),
Query::Bm25(bm25) => bm25.validate(),
Query::OrderBy(order_by) => order_by.validate(),
Query::Sample(sample) => sample.validate(),
Query::RelevanceFeedback(feedback) => feedback.validate(),
Expand Down
2 changes: 2 additions & 0 deletions lib/collection/src/collection/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ impl Collection {
| Some(ScoringQuery::Vector(_))
| Some(ScoringQuery::OrderBy(_))
| Some(ScoringQuery::Formula(_))
| Some(ScoringQuery::Bm25(_))
| Some(ScoringQuery::Sample(_)) => {
// Otherwise, it will be a list with a single list of scored points.
debug_assert_eq!(intermediates.len(), 1);
Expand Down Expand Up @@ -737,6 +738,7 @@ fn intermediate_query_infos(request: &ShardQueryRequest) -> Vec<IntermediateQuer
| Some(ScoringQuery::Vector(_))
| Some(ScoringQuery::OrderBy(_))
| Some(ScoringQuery::Formula(_))
| Some(ScoringQuery::Bm25(_))
| Some(ScoringQuery::Sample(_)) => {
// Otherwise, we expect the root result
vec![IntermediateQueryInfo {
Expand Down
73 changes: 73 additions & 0 deletions lib/collection/src/collection_manager/segments_searcher.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::{Duration, Instant};

use ahash::AHashMap;
use common::counter::hardware_accumulator::HwMeasurementAcc;
use common::counter::hardware_counter::HardwareCounterCell;
use common::types::{DeferredBehavior, ScoreType};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, TryStreamExt};
Expand All @@ -19,6 +21,7 @@ use segment::types::{
};
use shard::common::stopping_guard::StoppingGuard;
use shard::optimizers::config::DEFAULT_INDEXING_THRESHOLD_KB;
use shard::query::Bm25Internal;
use shard::query::query_context::{fill_query_context, init_query_context};
use shard::query::query_enum::QueryEnum;
use shard::retrieve::record_internal::RecordInternal;
Expand Down Expand Up @@ -51,6 +54,76 @@ type SegmentSearchExecutedResult = CollectionResult<(SegmentBatchSearchResult, V
pub struct SegmentsSearcher;

impl SegmentsSearcher {
pub async fn search_bm25(
segments: LockedSegmentHolder,
query: &Bm25Internal,
filter: Option<&Filter>,
limit: usize,
score_threshold: Option<f32>,
with_payload: &WithPayload,
with_vector: &WithVector,
search_runtime_handle: &Handle,
timeout: Duration,
hw_measurement_acc: HwMeasurementAcc,
) -> CollectionResult<Vec<ScoredPoint>> {
let segments: Vec<_> = {
let Some(segments_lock) = segments.try_read_for(timeout) else {
return Err(CollectionError::timeout(timeout, "bm25 search"));
};
segments_lock
.non_appendable_then_appendable_segments()
.collect()
};

let searches = segments
.into_iter()
.map(|segment| {
let query = query.clone();
let filter = filter.cloned();
let with_payload = with_payload.clone();
let with_vector = with_vector.clone();
let hw_measurement_acc = hw_measurement_acc.clone();
search_runtime_handle.spawn_blocking(move || {
let locked_segment = segment.get();
let Some(read_segment) = locked_segment.try_read_for(timeout) else {
return Err(CollectionError::timeout(timeout, "bm25 search"));
};
let hw_counter = HardwareCounterCell::new_with_accumulator(hw_measurement_acc);
read_segment
.search_bm25(
&query.field,
&query.query,
filter.as_ref(),
limit,
score_threshold,
&with_payload,
&with_vector,
&hw_counter,
&AtomicBool::new(false),
query.k1.into_inner(),
query.b.into_inner(),
)
.map_err(CollectionError::from)
})
})
.collect_vec();

let mut results = Vec::new();
for search in searches {
results.push(search.await.map_err(CollectionError::from)??);
}

let mut aggregator = BatchResultAggregator::new(std::iter::once(limit));
aggregator.update_point_versions(results.iter().flatten());
aggregator.update_batch_results(0, results.into_iter().flatten());

aggregator
.into_topk()
.into_iter()
.next()
.ok_or_else(|| CollectionError::service_error("expected BM25 search result"))
}

/// Execute searches in parallel and return results in the same order as the searches were provided
async fn execute_searches(
searches: Vec<AbortOnDropHandle<SegmentSearchExecutedResult>>,
Expand Down
1 change: 1 addition & 0 deletions lib/collection/src/operations/generalizer/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl Generalizer for ScoringQuery {
ScoringQuery::Fusion(_) => self.clone(),
ScoringQuery::OrderBy(_) => self.clone(),
ScoringQuery::Formula(_) => self.clone(),
ScoringQuery::Bm25(_) => self.clone(),
ScoringQuery::Sample(_) => self.clone(),
ScoringQuery::Mmr(mmr) => ScoringQuery::Mmr(mmr.remove_details()),
}
Expand Down
31 changes: 29 additions & 2 deletions lib/collection/src/operations/universal_query/collection_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use shard::query::query_enum::QueryEnum;

use super::formula::FormulaInternal;
use super::shard_query::{
FusionInternal, SampleInternal, ScoringQuery, ShardPrefetch, ShardQueryRequest,
Bm25Internal, FusionInternal, SampleInternal, ScoringQuery, ShardPrefetch, ShardQueryRequest,
};
use crate::common::fetch_vectors::ReferencedVectors;
use crate::lookup::WithLookup;
Expand Down Expand Up @@ -100,6 +100,9 @@ pub enum Query {
/// Score boosting via an arbitrary formula
Formula(FormulaInternal),

/// Score points with BM25 over a full-text payload index.
Bm25(Bm25QueryInternal),

/// Sample points
Sample(SampleInternal),
}
Expand All @@ -125,6 +128,7 @@ impl Query {
Query::Fusion(fusion) => ScoringQuery::Fusion(fusion),
Query::OrderBy(order_by) => ScoringQuery::OrderBy(order_by),
Query::Formula(formula) => ScoringQuery::Formula(ParsedFormula::try_from(formula)?),
Query::Bm25(bm25) => ScoringQuery::Bm25(Bm25Internal::from(bm25)),
Query::Sample(sample) => ScoringQuery::Sample(sample),
};

Expand All @@ -138,7 +142,30 @@ impl Query {
.into_iter()
.copied()
.collect(),
Self::Fusion(_) | Self::OrderBy(_) | Self::Formula(_) | Self::Sample(_) => Vec::new(),
Self::Fusion(_)
| Self::OrderBy(_)
| Self::Formula(_)
| Self::Bm25(_)
| Self::Sample(_) => Vec::new(),
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub struct Bm25QueryInternal {
pub field: JsonPath,
pub query: String,
pub k1: Option<f32>,
pub b: Option<f32>,
}

impl From<Bm25QueryInternal> for Bm25Internal {
fn from(value: Bm25QueryInternal) -> Self {
Self {
field: value.field,
query: value.query,
k1: OrderedFloat(value.k1.unwrap_or(1.2)),
b: OrderedFloat(value.b.unwrap_or(0.75)),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn query_result_order(
},
// Score boosting formulas are always have descending order,
// Euclidean scores can be negated within the formula
ScoringQuery::Formula(_formula) => Some(Order::LargeBetter),
ScoringQuery::Formula(_) | ScoringQuery::Bm25(_) => Some(Order::LargeBetter),
ScoringQuery::OrderBy(order_by) => Some(Order::from(order_by.direction())),
// Random sample does not require ordering
ScoringQuery::Sample(SampleInternal::Random) => None,
Expand Down
6 changes: 5 additions & 1 deletion lib/collection/src/operations/verification/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ impl Query {
// Check only applies on `search_allow_exact`
if strict_mode_config.search_allow_exact == Some(false) {
match &self {
Query::Fusion(_) | Query::OrderBy(_) | Query::Formula(_) | Query::Sample(_) => (),
Query::Fusion(_)
| Query::OrderBy(_)
| Query::Formula(_)
| Query::Bm25(_)
| Query::Sample(_) => (),
Query::Vector(_) => {
let config = collection.collection_config.read().await;

Expand Down
45 changes: 44 additions & 1 deletion lib/collection/src/shards/local_shard/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::operations::universal_query::planned_query::{
MergePlan, PlannedQuery, RescoreParams, RootPlan, Source,
};
use crate::operations::universal_query::shard_query::{
FusionInternal, MmrInternal, SampleInternal, ScoringQuery, ShardQueryResponse,
Bm25Internal, FusionInternal, MmrInternal, SampleInternal, ScoringQuery, ShardQueryResponse,
};

pub enum FetchedSource {
Expand Down Expand Up @@ -378,6 +378,18 @@ impl LocalShard {
)
.await
}
ScoringQuery::Bm25(bm25) => {
self.bm25_rescore(
sources,
bm25,
limit,
score_threshold.map(OrderedFloat::into_inner),
timeout,
hw_counter_acc,
search_runtime_handle,
)
.await
}
ScoringQuery::Sample(sample) => match sample {
SampleInternal::Random => {
// create single scroll request for rescoring query
Expand Down Expand Up @@ -421,6 +433,37 @@ impl LocalShard {
}
}

async fn bm25_rescore(
&self,
sources: Vec<Vec<ScoredPoint>>,
bm25: Bm25Internal,
limit: usize,
score_threshold: Option<f32>,
timeout: Duration,
hw_counter_acc: HwMeasurementAcc,
search_runtime_handle: &Handle,
) -> CollectionResult<Vec<ScoredPoint>> {
let filter = if sources.is_empty() {
None
} else {
Some(filter_with_sources_ids(sources.into_iter()))
};

SegmentsSearcher::search_bm25(
self.segments.clone(),
&bm25,
filter.as_ref(),
limit,
score_threshold,
&false.into(),
&false.into(),
search_runtime_handle,
timeout,
hw_counter_acc,
)
.await
}

fn fusion_rescore(
sources: Vec<Vec<ScoredPoint>>,
fusion: FusionInternal,
Expand Down
Loading
Loading