diff --git a/lib/api/src/grpc/proto/points_internal_service.proto b/lib/api/src/grpc/proto/points_internal_service.proto index 54cc42d26ca..82e41738807 100644 --- a/lib/api/src/grpc/proto/points_internal_service.proto +++ b/lib/api/src/grpc/proto/points_internal_service.proto @@ -363,6 +363,13 @@ message MmrInternal { uint32 candidates_limit = 3; } +message Bm25Internal { + string field = 1; + string query = 2; + float k1 = 3; + float b = 4; +} + message QueryShardPoints { message Query { oneof score { @@ -380,6 +387,8 @@ message QueryShardPoints { MmrInternal mmr = 6; // Parameterized RRF fusion Rrf rrf = 7; + // Full-text BM25 scoring + Bm25Internal bm25 = 8; } } diff --git a/lib/api/src/grpc/qdrant.rs b/lib/api/src/grpc/qdrant.rs index d2c1af2f26c..b2c3fd4fdd1 100644 --- a/lib/api/src/grpc/qdrant.rs +++ b/lib/api/src/grpc/qdrant.rs @@ -10900,6 +10900,19 @@ pub struct MmrInternal { #[derive(serde::Serialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Bm25Internal { + #[prost(string, tag = "1")] + pub field: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub query: ::prost::alloc::string::String, + #[prost(float, tag = "3")] + pub k1: f32, + #[prost(float, tag = "4")] + pub b: f32, +} +#[derive(serde::Serialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct QueryShardPoints { #[prost(message, repeated, tag = "1")] pub prefetch: ::prost::alloc::vec::Vec, @@ -10929,7 +10942,7 @@ pub mod query_shard_points { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Query { - #[prost(oneof = "query::Score", tags = "1, 2, 3, 4, 5, 6, 7")] + #[prost(oneof = "query::Score", tags = "1, 2, 3, 4, 5, 6, 7, 8")] pub score: ::core::option::Option, } /// Nested message and enum types in `Query`. @@ -10959,6 +10972,9 @@ pub mod query_shard_points { /// Parameterized RRF fusion #[prost(message, tag = "7")] Rrf(super::super::Rrf), + /// Full-text BM25 scoring + #[prost(message, tag = "8")] + Bm25(super::super::Bm25Internal), } } #[derive(serde::Serialize)] diff --git a/lib/api/src/rest/schema.rs b/lib/api/src/rest/schema.rs index ae4d1b754bc..9079bdba1b1 100644 --- a/lib/api/src/rest/schema.rs +++ b/lib/api/src/rest/schema.rs @@ -646,6 +646,9 @@ pub enum Query { /// Score boosting via an arbitrary formula Formula(FormulaQuery), + /// Score points with BM25 over a full-text payload index. + Bm25(Bm25Query), + /// Sample points from the collection, non-deterministically. Sample(SampleQuery), @@ -716,6 +719,33 @@ pub struct FormulaQuery { pub defaults: HashMap, } +#[derive(Debug, Serialize, Deserialize, JsonSchema, Validate)] +#[serde(rename_all = "snake_case")] +pub struct Bm25Query { + /// Payload field that has a text index. + pub field: JsonPath, + + /// Query text to score against the indexed field. + #[validate(length(min = 1))] + pub query: String, + + /// BM25 hyperparameters. Solr/Lucene-compatible defaults are used if omitted. + #[validate(nested)] + pub params: Option, +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema, Validate, Clone, PartialEq)] +#[serde(rename_all = "snake_case")] +pub struct Bm25Params { + /// BM25 `k1` parameter. Default is 1.2. + #[validate(range(min = 0.0))] + pub k1: Option, + + /// BM25 `b` parameter. Default is 0.75. + #[validate(range(min = 0.0, max = 1.0))] + pub b: Option, +} + #[derive(Debug, Serialize, Deserialize, JsonSchema, Validate)] #[serde(rename_all = "snake_case")] pub struct SampleQuery { diff --git a/lib/api/src/rest/validate.rs b/lib/api/src/rest/validate.rs index a13722cffbf..a2c3df51fdc 100644 --- a/lib/api/src/rest/validate.rs +++ b/lib/api/src/rest/validate.rs @@ -40,6 +40,7 @@ impl Validate for Query { Query::Fusion(fusion) => fusion.validate(), Query::Rrf(rrf) => rrf.validate(), Query::Formula(formula) => formula.validate(), + Query::Bm25(bm25) => bm25.validate(), Query::OrderBy(order_by) => order_by.validate(), Query::Sample(sample) => sample.validate(), Query::RelevanceFeedback(feedback) => feedback.validate(), diff --git a/lib/collection/src/collection/query.rs b/lib/collection/src/collection/query.rs index 0086d17df89..8107e50ef63 100644 --- a/lib/collection/src/collection/query.rs +++ b/lib/collection/src/collection/query.rs @@ -424,6 +424,7 @@ impl Collection { | Some(ScoringQuery::Vector(_)) | Some(ScoringQuery::OrderBy(_)) | Some(ScoringQuery::Formula(_)) + | Some(ScoringQuery::Bm25(_)) | Some(ScoringQuery::Sample(_)) => { // Otherwise, it will be a list with a single list of scored points. debug_assert_eq!(intermediates.len(), 1); @@ -737,6 +738,7 @@ fn intermediate_query_infos(request: &ShardQueryRequest) -> Vec { // Otherwise, we expect the root result vec![IntermediateQueryInfo { diff --git a/lib/collection/src/collection_manager/segments_searcher.rs b/lib/collection/src/collection_manager/segments_searcher.rs index f9e3c372f17..65f010d7230 100644 --- a/lib/collection/src/collection_manager/segments_searcher.rs +++ b/lib/collection/src/collection_manager/segments_searcher.rs @@ -1,9 +1,11 @@ use std::collections::BTreeSet; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::time::{Duration, Instant}; use ahash::AHashMap; use common::counter::hardware_accumulator::HwMeasurementAcc; +use common::counter::hardware_counter::HardwareCounterCell; use common::types::{DeferredBehavior, ScoreType}; use futures::stream::FuturesUnordered; use futures::{FutureExt, TryStreamExt}; @@ -19,6 +21,7 @@ use segment::types::{ }; use shard::common::stopping_guard::StoppingGuard; use shard::optimizers::config::DEFAULT_INDEXING_THRESHOLD_KB; +use shard::query::Bm25Internal; use shard::query::query_context::{fill_query_context, init_query_context}; use shard::query::query_enum::QueryEnum; use shard::retrieve::record_internal::RecordInternal; @@ -51,6 +54,76 @@ type SegmentSearchExecutedResult = CollectionResult<(SegmentBatchSearchResult, V pub struct SegmentsSearcher; impl SegmentsSearcher { + pub async fn search_bm25( + segments: LockedSegmentHolder, + query: &Bm25Internal, + filter: Option<&Filter>, + limit: usize, + score_threshold: Option, + with_payload: &WithPayload, + with_vector: &WithVector, + search_runtime_handle: &Handle, + timeout: Duration, + hw_measurement_acc: HwMeasurementAcc, + ) -> CollectionResult> { + let segments: Vec<_> = { + let Some(segments_lock) = segments.try_read_for(timeout) else { + return Err(CollectionError::timeout(timeout, "bm25 search")); + }; + segments_lock + .non_appendable_then_appendable_segments() + .collect() + }; + + let searches = segments + .into_iter() + .map(|segment| { + let query = query.clone(); + let filter = filter.cloned(); + let with_payload = with_payload.clone(); + let with_vector = with_vector.clone(); + let hw_measurement_acc = hw_measurement_acc.clone(); + search_runtime_handle.spawn_blocking(move || { + let locked_segment = segment.get(); + let Some(read_segment) = locked_segment.try_read_for(timeout) else { + return Err(CollectionError::timeout(timeout, "bm25 search")); + }; + let hw_counter = HardwareCounterCell::new_with_accumulator(hw_measurement_acc); + read_segment + .search_bm25( + &query.field, + &query.query, + filter.as_ref(), + limit, + score_threshold, + &with_payload, + &with_vector, + &hw_counter, + &AtomicBool::new(false), + query.k1.into_inner(), + query.b.into_inner(), + ) + .map_err(CollectionError::from) + }) + }) + .collect_vec(); + + let mut results = Vec::new(); + for search in searches { + results.push(search.await.map_err(CollectionError::from)??); + } + + let mut aggregator = BatchResultAggregator::new(std::iter::once(limit)); + aggregator.update_point_versions(results.iter().flatten()); + aggregator.update_batch_results(0, results.into_iter().flatten()); + + aggregator + .into_topk() + .into_iter() + .next() + .ok_or_else(|| CollectionError::service_error("expected BM25 search result")) + } + /// Execute searches in parallel and return results in the same order as the searches were provided async fn execute_searches( searches: Vec>, diff --git a/lib/collection/src/operations/generalizer/query.rs b/lib/collection/src/operations/generalizer/query.rs index 82bd74fb6f2..b3b3c62821f 100644 --- a/lib/collection/src/operations/generalizer/query.rs +++ b/lib/collection/src/operations/generalizer/query.rs @@ -76,6 +76,7 @@ impl Generalizer for ScoringQuery { ScoringQuery::Fusion(_) => self.clone(), ScoringQuery::OrderBy(_) => self.clone(), ScoringQuery::Formula(_) => self.clone(), + ScoringQuery::Bm25(_) => self.clone(), ScoringQuery::Sample(_) => self.clone(), ScoringQuery::Mmr(mmr) => ScoringQuery::Mmr(mmr.remove_details()), } diff --git a/lib/collection/src/operations/universal_query/collection_query.rs b/lib/collection/src/operations/universal_query/collection_query.rs index c7a540707eb..15c2d2c8f3d 100644 --- a/lib/collection/src/operations/universal_query/collection_query.rs +++ b/lib/collection/src/operations/universal_query/collection_query.rs @@ -19,7 +19,7 @@ use shard::query::query_enum::QueryEnum; use super::formula::FormulaInternal; use super::shard_query::{ - FusionInternal, SampleInternal, ScoringQuery, ShardPrefetch, ShardQueryRequest, + Bm25Internal, FusionInternal, SampleInternal, ScoringQuery, ShardPrefetch, ShardQueryRequest, }; use crate::common::fetch_vectors::ReferencedVectors; use crate::lookup::WithLookup; @@ -100,6 +100,9 @@ pub enum Query { /// Score boosting via an arbitrary formula Formula(FormulaInternal), + /// Score points with BM25 over a full-text payload index. + Bm25(Bm25QueryInternal), + /// Sample points Sample(SampleInternal), } @@ -125,6 +128,7 @@ impl Query { Query::Fusion(fusion) => ScoringQuery::Fusion(fusion), Query::OrderBy(order_by) => ScoringQuery::OrderBy(order_by), Query::Formula(formula) => ScoringQuery::Formula(ParsedFormula::try_from(formula)?), + Query::Bm25(bm25) => ScoringQuery::Bm25(Bm25Internal::from(bm25)), Query::Sample(sample) => ScoringQuery::Sample(sample), }; @@ -138,7 +142,30 @@ impl Query { .into_iter() .copied() .collect(), - Self::Fusion(_) | Self::OrderBy(_) | Self::Formula(_) | Self::Sample(_) => Vec::new(), + Self::Fusion(_) + | Self::OrderBy(_) + | Self::Formula(_) + | Self::Bm25(_) + | Self::Sample(_) => Vec::new(), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Bm25QueryInternal { + pub field: JsonPath, + pub query: String, + pub k1: Option, + pub b: Option, +} + +impl From for Bm25Internal { + fn from(value: Bm25QueryInternal) -> Self { + Self { + field: value.field, + query: value.query, + k1: OrderedFloat(value.k1.unwrap_or(1.2)), + b: OrderedFloat(value.b.unwrap_or(0.75)), } } } diff --git a/lib/collection/src/operations/universal_query/shard_query.rs b/lib/collection/src/operations/universal_query/shard_query.rs index 7a30be2a54e..9b901c1fb01 100644 --- a/lib/collection/src/operations/universal_query/shard_query.rs +++ b/lib/collection/src/operations/universal_query/shard_query.rs @@ -29,7 +29,7 @@ pub fn query_result_order( }, // Score boosting formulas are always have descending order, // Euclidean scores can be negated within the formula - ScoringQuery::Formula(_formula) => Some(Order::LargeBetter), + ScoringQuery::Formula(_) | ScoringQuery::Bm25(_) => Some(Order::LargeBetter), ScoringQuery::OrderBy(order_by) => Some(Order::from(order_by.direction())), // Random sample does not require ordering ScoringQuery::Sample(SampleInternal::Random) => None, diff --git a/lib/collection/src/operations/verification/query.rs b/lib/collection/src/operations/verification/query.rs index a30e666928e..adc0a8bc52b 100644 --- a/lib/collection/src/operations/verification/query.rs +++ b/lib/collection/src/operations/verification/query.rs @@ -48,7 +48,11 @@ impl Query { // Check only applies on `search_allow_exact` if strict_mode_config.search_allow_exact == Some(false) { match &self { - Query::Fusion(_) | Query::OrderBy(_) | Query::Formula(_) | Query::Sample(_) => (), + Query::Fusion(_) + | Query::OrderBy(_) + | Query::Formula(_) + | Query::Bm25(_) + | Query::Sample(_) => (), Query::Vector(_) => { let config = collection.collection_config.read().await; diff --git a/lib/collection/src/shards/local_shard/query.rs b/lib/collection/src/shards/local_shard/query.rs index b797fa817f4..caacb1045d6 100644 --- a/lib/collection/src/shards/local_shard/query.rs +++ b/lib/collection/src/shards/local_shard/query.rs @@ -26,7 +26,7 @@ use crate::operations::universal_query::planned_query::{ MergePlan, PlannedQuery, RescoreParams, RootPlan, Source, }; use crate::operations::universal_query::shard_query::{ - FusionInternal, MmrInternal, SampleInternal, ScoringQuery, ShardQueryResponse, + Bm25Internal, FusionInternal, MmrInternal, SampleInternal, ScoringQuery, ShardQueryResponse, }; pub enum FetchedSource { @@ -378,6 +378,18 @@ impl LocalShard { ) .await } + ScoringQuery::Bm25(bm25) => { + self.bm25_rescore( + sources, + bm25, + limit, + score_threshold.map(OrderedFloat::into_inner), + timeout, + hw_counter_acc, + search_runtime_handle, + ) + .await + } ScoringQuery::Sample(sample) => match sample { SampleInternal::Random => { // create single scroll request for rescoring query @@ -421,6 +433,37 @@ impl LocalShard { } } + async fn bm25_rescore( + &self, + sources: Vec>, + bm25: Bm25Internal, + limit: usize, + score_threshold: Option, + timeout: Duration, + hw_counter_acc: HwMeasurementAcc, + search_runtime_handle: &Handle, + ) -> CollectionResult> { + let filter = if sources.is_empty() { + None + } else { + Some(filter_with_sources_ids(sources.into_iter())) + }; + + SegmentsSearcher::search_bm25( + self.segments.clone(), + &bm25, + filter.as_ref(), + limit, + score_threshold, + &false.into(), + &false.into(), + search_runtime_handle, + timeout, + hw_counter_acc, + ) + .await + } + fn fusion_rescore( sources: Vec>, fusion: FusionInternal, diff --git a/lib/collection/src/tests/shard_query.rs b/lib/collection/src/tests/shard_query.rs index ff0af74b44a..875590e0ff6 100644 --- a/lib/collection/src/tests/shard_query.rs +++ b/lib/collection/src/tests/shard_query.rs @@ -4,8 +4,18 @@ use common::budget::ResourceBudget; use common::counter::hardware_accumulator::HwMeasurementAcc; use common::save_on_disk::SaveOnDisk; use segment::common::reciprocal_rank_fusion::DEFAULT_RRF_K; -use segment::data_types::vectors::{DEFAULT_VECTOR_NAME, NamedQuery, VectorInternal}; -use segment::types::{PointIdType, WithPayloadInterface, WithVector}; +use segment::data_types::vectors::{ + DEFAULT_VECTOR_NAME, NamedQuery, VectorInternal, VectorStructInternal, +}; +use segment::json_path::JsonPath; +use segment::types::{ + PayloadFieldSchema, PayloadSchemaType, PointIdType, WithPayloadInterface, WithVector, +}; +use shard::operations::point_ops::{ + PointInsertOperationsInternal, PointOperations, PointStructPersisted, +}; +use shard::operations::{CollectionUpdateOperations, CreateIndex, FieldIndexOperations}; +use shard::query::Bm25Internal; use shard::query::query_enum::QueryEnum; use tempfile::Builder; use tokio::runtime::Handle; @@ -19,6 +29,64 @@ use crate::shards::local_shard::LocalShard; use crate::shards::shard_trait::{ShardOperation, WaitUntil}; use crate::tests::fixtures::*; +async fn create_text_index(shard: &LocalShard, field_name: &str) { + let create_index = CollectionUpdateOperations::FieldIndexOperation( + FieldIndexOperations::CreateIndex(CreateIndex { + field_name: field_name.parse().unwrap(), + field_schema: Some(PayloadFieldSchema::FieldType(PayloadSchemaType::Text)), + }), + ); + + shard + .update( + create_index.into(), + WaitUntil::Visible, + None, + HwMeasurementAcc::new(), + ) + .await + .unwrap(); +} + +fn bm25_upsert_operation() -> CollectionUpdateOperations { + let points = vec![ + PointStructPersisted { + id: 1.into(), + vector: VectorStructInternal::from(vec![1.0, 0.0, 0.0, 0.0]).into(), + payload: Some( + serde_json::from_value(serde_json::json!({ + "content": "apple apple banana" + })) + .unwrap(), + ), + }, + PointStructPersisted { + id: 2.into(), + vector: VectorStructInternal::from(vec![0.0, 1.0, 0.0, 0.0]).into(), + payload: Some( + serde_json::from_value(serde_json::json!({ + "content": "apple banana" + })) + .unwrap(), + ), + }, + PointStructPersisted { + id: 3.into(), + vector: VectorStructInternal::from(vec![0.0, 0.0, 1.0, 0.0]).into(), + payload: Some( + serde_json::from_value(serde_json::json!({ + "content": "banana banana" + })) + .unwrap(), + ), + }, + ]; + + CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints( + PointInsertOperationsInternal::PointsList(points), + )) +} + #[tokio::test(flavor = "multi_thread")] async fn test_shard_query_rrf_rescoring() { let collection_dir = Builder::new().prefix("test_collection").tempdir().unwrap(); @@ -372,6 +440,79 @@ async fn test_shard_query_vector_rescoring() { assert_eq!(sources_scores[0].len(), outer_limit); } +#[tokio::test(flavor = "multi_thread")] +async fn test_shard_query_bm25_rescoring() { + let collection_dir = Builder::new().prefix("test_collection").tempdir().unwrap(); + + let config = create_collection_config(); + let collection_name = "test".to_string(); + let current_runtime: Handle = Handle::current(); + + let payload_index_schema_dir = Builder::new().prefix("qdrant-test").tempdir().unwrap(); + let payload_index_schema_file = payload_index_schema_dir.path().join("payload-schema.json"); + let payload_index_schema = + Arc::new(SaveOnDisk::load_or_init_default(payload_index_schema_file).unwrap()); + + let shard = LocalShard::build( + 0, + collection_name, + collection_dir.path(), + Arc::new(RwLock::new(config.clone())), + Arc::new(Default::default()), + payload_index_schema, + current_runtime.clone(), + current_runtime.clone(), + ResourceBudget::default(), + config.optimizer_config.clone(), + ) + .await + .unwrap(); + + shard + .update( + bm25_upsert_operation().into(), + WaitUntil::Visible, + None, + HwMeasurementAcc::new(), + ) + .await + .unwrap(); + + create_text_index(&shard, "content").await; + + let query = ShardQueryRequest { + prefetches: vec![], + query: Some(ScoringQuery::Bm25(Bm25Internal { + field: JsonPath::new("content"), + query: "apple".to_string(), + k1: 1.2.into(), + b: 0.75.into(), + })), + filter: None, + score_threshold: None, + limit: 3, + offset: 0, + params: None, + with_vector: WithVector::Bool(false), + with_payload: WithPayloadInterface::Bool(false), + }; + + let hw_acc = HwMeasurementAcc::new(); + let response = shard + .query_batch(Arc::new(vec![query]), ¤t_runtime, None, hw_acc) + .await + .unwrap() + .pop() + .unwrap(); + + assert_eq!(response.len(), 1); + let points = &response[0]; + assert_eq!(points.len(), 2); + assert_eq!(points[0].id, PointIdType::NumId(1)); + assert_eq!(points[1].id, PointIdType::NumId(2)); + assert!(points[0].score > points[1].score); +} + #[tokio::test(flavor = "multi_thread")] async fn test_shard_query_payload_vector() { let collection_dir = Builder::new().prefix("test_collection").tempdir().unwrap(); diff --git a/lib/edge/python/src/query.rs b/lib/edge/python/src/query.rs index 4c509035926..768250331f7 100644 --- a/lib/edge/python/src/query.rs +++ b/lib/edge/python/src/query.rs @@ -251,6 +251,7 @@ impl FromPyObject<'_, '_> for PyScoringQuery { ScoringQuery::Fusion(_) => {} ScoringQuery::OrderBy(_) => {} ScoringQuery::Formula(_) => {} + ScoringQuery::Bm25(_) => {} ScoringQuery::Sample(_) => {} ScoringQuery::Mmr(_) => {} } @@ -280,6 +281,9 @@ impl<'py> IntoPyObject<'py> for PyScoringQuery { ScoringQuery::Fusion(fusion) => PyFusion::from(fusion).into_bound_py_any(py), ScoringQuery::OrderBy(order_by) => PyOrderBy(order_by).into_bound_py_any(py), ScoringQuery::Formula(formula) => PyFormula(formula).into_bound_py_any(py), + ScoringQuery::Bm25(_bm25) => Err(pyo3::exceptions::PyNotImplementedError::new_err( + "BM25 queries are not yet exposed in the Python edge bindings", + )), ScoringQuery::Sample(sample) => PySample::from(sample).into_bound_py_any(py), ScoringQuery::Mmr(mmr) => PyMmr(mmr).into_bound_py_any(py), } @@ -303,6 +307,7 @@ impl Repr for PyScoringQuery { ScoringQuery::Fusion(fusion) => PyFusion::from(fusion.clone()).fmt(f), ScoringQuery::OrderBy(order_by) => PyOrderBy::wrap_ref(order_by).fmt(f), ScoringQuery::Formula(_formula) => f.unimplemented(), // TODO! + ScoringQuery::Bm25(_bm25) => f.unimplemented(), ScoringQuery::Sample(sample) => PySample::from(*sample).fmt(f), ScoringQuery::Mmr(mmr) => PyMmr::wrap_ref(mmr).fmt(f), } diff --git a/lib/edge/src/query.rs b/lib/edge/src/query.rs index 3bbb40ed020..94031b3b00e 100644 --- a/lib/edge/src/query.rs +++ b/lib/edge/src/query.rs @@ -252,6 +252,10 @@ impl EdgeShard { hw_counter_acc, ), + ScoringQuery::Bm25(_bm25) => Err(OperationError::service_error( + "BM25 queries are not supported in the edge query runtime", + )), + ScoringQuery::Sample(sample) => match sample { SampleInternal::Random => { // create single scroll request for rescoring query diff --git a/lib/segment/src/entry/entry_point.rs b/lib/segment/src/entry/entry_point.rs index 90866713647..32b1a14fc56 100644 --- a/lib/segment/src/entry/entry_point.rs +++ b/lib/segment/src/entry/entry_point.rs @@ -64,6 +64,22 @@ pub trait ReadSegmentEntry: SnapshotEntry { hw_counter: &HardwareCounterCell, ) -> OperationResult>; + #[allow(clippy::too_many_arguments)] + fn search_bm25( + &self, + field: &JsonPath, + query: &str, + filter: Option<&Filter>, + top: usize, + score_threshold: Option, + with_payload: &WithPayload, + with_vector: &WithVector, + hw_counter: &HardwareCounterCell, + is_stopped: &AtomicBool, + k1: f32, + b: f32, + ) -> OperationResult>; + fn vector( &self, vector_name: &VectorName, diff --git a/lib/segment/src/index/field_index/field_index_base.rs b/lib/segment/src/index/field_index/field_index_base.rs index d7d03986183..2a83e4837b3 100644 --- a/lib/segment/src/index/field_index/field_index_base.rs +++ b/lib/segment/src/index/field_index/field_index_base.rs @@ -157,6 +157,13 @@ impl std::fmt::Debug for FieldIndex { } impl FieldIndex { + pub fn as_full_text(&self) -> Option<&FullTextIndex> { + match self { + FieldIndex::FullTextIndex(index) => Some(index), + _ => None, + } + } + /// Try to check condition for a payload given a field index. /// Required because some index parameters may influence the condition checking logic. /// For example, full text index may have different tokenizers. diff --git a/lib/segment/src/index/field_index/full_text_index/immutable_text_index.rs b/lib/segment/src/index/field_index/full_text_index/immutable_text_index.rs index df02d2f55b5..5dc081077d9 100644 --- a/lib/segment/src/index/field_index/full_text_index/immutable_text_index.rs +++ b/lib/segment/src/index/field_index/full_text_index/immutable_text_index.rs @@ -88,4 +88,8 @@ impl ImmutableFullTextIndex { }, } } + + pub fn total_tokens_count(&self) -> usize { + self.inverted_index.total_tokens_count + } } diff --git a/lib/segment/src/index/field_index/full_text_index/inverted_index/immutable_inverted_index.rs b/lib/segment/src/index/field_index/full_text_index/inverted_index/immutable_inverted_index.rs index 5edd8b90bdf..3f713a7604c 100644 --- a/lib/segment/src/index/field_index/full_text_index/inverted_index/immutable_inverted_index.rs +++ b/lib/segment/src/index/field_index/full_text_index/inverted_index/immutable_inverted_index.rs @@ -28,6 +28,7 @@ pub struct ImmutableInvertedIndex { pub(in crate::index::field_index::full_text_index) vocab: HashMap, pub(in crate::index::field_index::full_text_index) point_to_tokens_count: Vec, pub(in crate::index::field_index::full_text_index) points_count: usize, + pub(in crate::index::field_index::full_text_index) total_tokens_count: usize, } impl ImmutableInvertedIndex { @@ -318,6 +319,7 @@ impl From for ImmutableInvertedIndex { postings, vocab, point_to_tokens, + point_to_tokens_count, point_to_doc, points_count, } = index; @@ -338,12 +340,14 @@ impl From for ImmutableInvertedIndex { ImmutableInvertedIndex { postings, vocab, + total_tokens_count: point_to_tokens_count.iter().sum(), point_to_tokens_count: point_to_tokens .iter() - .map(|tokenset| { + .enumerate() + .map(|(idx, tokenset)| { tokenset .as_ref() - .map(|tokenset| tokenset.len()) + .map(|_| point_to_tokens_count.get(idx).copied().unwrap_or(0)) .unwrap_or(0) }) .collect(), @@ -478,6 +482,7 @@ impl From<&MmapInvertedIndex> for ImmutableInvertedIndex { ImmutableInvertedIndex { postings, vocab, + total_tokens_count: index.storage.point_to_tokens_count.iter().sum(), point_to_tokens_count: index.storage.point_to_tokens_count.to_vec(), points_count: index.points_count(), } diff --git a/lib/segment/src/index/field_index/full_text_index/inverted_index/mmap_inverted_index/mod.rs b/lib/segment/src/index/field_index/full_text_index/inverted_index/mmap_inverted_index/mod.rs index 9a25dd13c29..12ef6696a5d 100644 --- a/lib/segment/src/index/field_index/full_text_index/inverted_index/mmap_inverted_index/mod.rs +++ b/lib/segment/src/index/field_index/full_text_index/inverted_index/mmap_inverted_index/mod.rs @@ -41,6 +41,7 @@ pub struct MmapInvertedIndex { pub(in crate::index::field_index::full_text_index) storage: Storage, /// Number of points which are not deleted pub(in crate::index::field_index::full_text_index) active_points_count: usize, + pub(in crate::index::field_index::full_text_index) total_tokens_count: usize, is_on_disk: bool, } @@ -59,6 +60,7 @@ impl MmapInvertedIndex { vocab, point_to_tokens_count, points_count: _, + total_tokens_count: _, } = inverted_index; debug_assert_eq!(vocab.len(), postings.len()); @@ -153,6 +155,7 @@ impl MmapInvertedIndex { )?; let num_deleted_points = deleted.count_ones()?; let deleted_points = MmapBitSliceBufferedUpdateWrapper::new(deleted); + let total_tokens_count: usize = point_to_tokens_count.iter().sum(); let points_count = point_to_tokens_count.len() - num_deleted_points; Ok(Some(Self { @@ -164,6 +167,7 @@ impl MmapInvertedIndex { deleted_points, }, active_points_count: points_count, + total_tokens_count, is_on_disk: !populate, })) } @@ -442,6 +446,7 @@ impl InvertedIndex for MmapInvertedIndex { self.storage.deleted_points.set(idx as usize, true); if let Some(count) = self.storage.point_to_tokens_count.get_mut(idx as usize) { + self.total_tokens_count = self.total_tokens_count.saturating_sub(*count); *count = 0; // `deleted_points`'s length can be larger than `point_to_tokens_count`'s length. diff --git a/lib/segment/src/index/field_index/full_text_index/inverted_index/mutable_inverted_index.rs b/lib/segment/src/index/field_index/full_text_index/inverted_index/mutable_inverted_index.rs index 6bc5b3741d5..67c2ee115bc 100644 --- a/lib/segment/src/index/field_index/full_text_index/inverted_index/mutable_inverted_index.rs +++ b/lib/segment/src/index/field_index/full_text_index/inverted_index/mutable_inverted_index.rs @@ -14,6 +14,7 @@ pub struct MutableInvertedIndex { pub(super) postings: Vec, pub vocab: HashMap, pub(super) point_to_tokens: Vec>, + pub(super) point_to_tokens_count: Vec, /// Optional additional structure to store positional information of tokens in the documents. /// @@ -29,6 +30,7 @@ impl MutableInvertedIndex { postings: Vec::new(), vocab: HashMap::new(), point_to_tokens: Vec::new(), + point_to_tokens_count: Vec::new(), point_to_doc: with_positions.then_some(Vec::new()), points_count: 0, } @@ -42,6 +44,18 @@ impl MutableInvertedIndex { self.point_to_doc.as_ref()?.get(idx as usize)?.as_ref() } + pub fn set_point_tokens_count(&mut self, point_id: PointOffsetType, tokens_count: usize) { + let point_id = point_id as usize; + if self.point_to_tokens_count.len() <= point_id { + self.point_to_tokens_count.resize(point_id + 1, 0); + } + self.point_to_tokens_count[point_id] = tokens_count; + } + + pub fn total_tokens_count(&self) -> usize { + self.point_to_tokens_count.iter().sum() + } + /// Iterate over point ids whose documents contain all given tokens fn filter_has_all(&self, tokens: TokenSet) -> impl Iterator + '_ { let postings_opt: Option> = tokens @@ -136,6 +150,10 @@ impl InvertedIndex for MutableInvertedIndex { self.point_to_tokens.resize_with(new_len, Default::default); } + if self.point_to_tokens_count.len() <= point_id as usize { + self.point_to_tokens_count.resize(point_id as usize + 1, 0); + } + for token_id in tokens.tokens() { let token_idx_usize = *token_id as usize; @@ -200,6 +218,7 @@ impl InvertedIndex for MutableInvertedIndex { } self.points_count -= 1; + self.point_to_tokens_count[point_id as usize] = 0; for removed_token in removed_token_set.tokens() { // unwrap safety: posting list exists and contains the point idx diff --git a/lib/segment/src/index/field_index/full_text_index/mmap_text_index.rs b/lib/segment/src/index/field_index/full_text_index/mmap_text_index.rs index d38a6131969..f0d7e470d9c 100644 --- a/lib/segment/src/index/field_index/full_text_index/mmap_text_index.rs +++ b/lib/segment/src/index/field_index/full_text_index/mmap_text_index.rs @@ -78,6 +78,10 @@ impl MmapFullTextIndex { self.inverted_index.is_on_disk() } + pub fn total_tokens_count(&self) -> usize { + self.inverted_index.total_tokens_count + } + /// Populate all pages in the mmap. /// Block until all pages are populated. pub fn populate(&self) -> OperationResult<()> { @@ -152,6 +156,8 @@ impl ValueIndexer for FullTextMmapIndexBuilder { let token_set = TokenSet::from_iter(tokens); self.mutable_index.index_tokens(id, token_set, hw_counter)?; + self.mutable_index + .set_point_tokens_count(id, str_tokens.len()); Ok(()) } diff --git a/lib/segment/src/index/field_index/full_text_index/mutable_text_index.rs b/lib/segment/src/index/field_index/full_text_index/mutable_text_index.rs index 88ea51542e5..3bee0428f7e 100644 --- a/lib/segment/src/index/field_index/full_text_index/mutable_text_index.rs +++ b/lib/segment/src/index/field_index/full_text_index/mutable_text_index.rs @@ -184,6 +184,8 @@ impl MutableFullTextIndex { let token_set = TokenSet::from_iter(tokens); self.inverted_index .index_tokens(idx, token_set, hw_counter)?; + self.inverted_index + .set_point_tokens_count(idx, str_tokens.len()); let tokens_to_store = if phrase_matching { // store ordered tokens diff --git a/lib/segment/src/index/field_index/full_text_index/text_index.rs b/lib/segment/src/index/field_index/full_text_index/text_index.rs index 4e4ecbfb6e7..81942a72cce 100644 --- a/lib/segment/src/index/field_index/full_text_index/text_index.rs +++ b/lib/segment/src/index/field_index/full_text_index/text_index.rs @@ -98,6 +98,10 @@ impl FullTextIndex { } } + pub fn indexed_points_count(&self) -> usize { + self.points_count() + } + pub(super) fn get_token( &self, token: &str, @@ -110,6 +114,22 @@ impl FullTextIndex { } } + pub fn token_id(&self, token: &str, hw_counter: &HardwareCounterCell) -> Option { + self.get_token(token, hw_counter) + } + + pub fn get_posting_len( + &self, + token_id: TokenId, + hw_counter: &HardwareCounterCell, + ) -> Option { + match self { + Self::Mutable(index) => index.inverted_index.get_posting_len(token_id, hw_counter), + Self::Immutable(index) => index.inverted_index.get_posting_len(token_id, hw_counter), + Self::Mmap(index) => index.inverted_index.get_posting_len(token_id, hw_counter), + } + } + pub(super) fn filter_query<'a>( &'a self, query: ParsedQuery, @@ -132,6 +152,14 @@ impl FullTextIndex { } } + pub fn values_total_count(&self) -> usize { + match self { + Self::Mutable(index) => index.inverted_index.total_tokens_count(), + Self::Immutable(index) => index.total_tokens_count(), + Self::Mmap(index) => index.total_tokens_count(), + } + } + fn payload_blocks( &self, threshold: usize, @@ -179,6 +207,14 @@ impl FullTextIndex { } } + pub fn average_len(&self) -> f32 { + let points_count = self.points_count(); + if points_count == 0 { + return 0.0; + } + self.values_total_count() as f32 / points_count as f32 + } + pub fn values_is_empty(&self, point_id: PointOffsetType) -> bool { match self { Self::Mutable(index) => index.inverted_index.values_is_empty(point_id), @@ -267,6 +303,32 @@ impl FullTextIndex { Some(ParsedQuery::AnyTokens(tokens)) } + pub fn parse_bm25_query( + &self, + text: &str, + hw_counter: &HardwareCounterCell, + ) -> Option> { + let ParsedQuery::AnyTokens(tokens) = self.parse_text_any_query(text, hw_counter)? else { + return None; + }; + Some(tokens.inner()) + } + + pub fn bm25_candidates<'a>( + &'a self, + text: &str, + hw_counter: &'a HardwareCounterCell, + ) -> Box + 'a> { + let Some(parsed_query) = self.parse_text_any_query(text, hw_counter) else { + return Box::new(std::iter::empty()); + }; + self.filter_query(parsed_query, hw_counter) + } + + pub fn tokenize_document_text<'a>(&'a self, value: &'a str, mut f: impl FnMut(Cow<'a, str>)) { + self.get_tokenizer().tokenize_doc(value, |token| f(token)); + } + pub fn parse_tokenset(&self, text: &str, hw_counter: &HardwareCounterCell) -> TokenSet { let mut tokenset = AHashSet::new(); self.get_tokenizer().tokenize_doc(text, |token| { diff --git a/lib/segment/src/index/struct_payload_index.rs b/lib/segment/src/index/struct_payload_index.rs index 580c553bca4..21e5f041675 100644 --- a/lib/segment/src/index/struct_payload_index.rs +++ b/lib/segment/src/index/struct_payload_index.rs @@ -583,6 +583,19 @@ impl StructPayloadIndex { }) } + pub fn get_full_text_index( + &self, + key: &JsonPath, + ) -> OperationResult<&crate::index::field_index::full_text_index::text_index::FullTextIndex> + { + self.field_indexes + .get(key) + .and_then(|indexes| indexes.iter().find_map(FieldIndex::as_full_text)) + .ok_or_else(|| { + OperationError::service_error(format!("Missing full-text index for field `{key}`")) + }) + } + pub fn populate(&self) -> OperationResult<()> { for (_, field_indexes) in self.field_indexes.iter() { for index in field_indexes { diff --git a/lib/segment/src/segment/bm25_search.rs b/lib/segment/src/segment/bm25_search.rs new file mode 100644 index 00000000000..81ee1678042 --- /dev/null +++ b/lib/segment/src/segment/bm25_search.rs @@ -0,0 +1,166 @@ +use std::collections::HashMap; +use std::sync::atomic::AtomicBool; + +use common::counter::hardware_counter::HardwareCounterCell; +use common::top_k::TopK; +use common::types::{PointOffsetType, ScoredPointOffset}; + +use super::Segment; +use crate::common::check_stopped; +use crate::common::operation_error::OperationResult; +use crate::json_path::JsonPath; +use crate::payload_storage::FilterContext; +use crate::types::{Filter, PayloadContainer}; + +impl Segment { + pub(super) fn do_search_bm25( + &self, + field: &JsonPath, + query: &str, + filter: Option<&Filter>, + top: usize, + score_threshold: Option, + is_stopped: &AtomicBool, + hw_counter: &HardwareCounterCell, + k1: f32, + b: f32, + ) -> OperationResult> { + if top == 0 { + return Ok(Vec::new()); + } + + let payload_index = self.payload_index.borrow(); + let text_index = payload_index.get_full_text_index(field)?; + let Some(query_tokens) = text_index.parse_bm25_query(query, hw_counter) else { + return Ok(Vec::new()); + }; + + let avgdl = text_index.average_len(); + if avgdl <= 0.0 { + return Ok(Vec::new()); + } + + let filter_context = filter + .map(|filter| payload_index.struct_filtered_context(filter, hw_counter)) + .transpose()?; + + let mut top_k = TopK::new(top); + + for point_id in text_index.bm25_candidates(query, hw_counter) { + check_stopped(is_stopped)?; + + if filter_context + .as_ref() + .is_some_and(|context| !context.check(point_id)) + { + continue; + } + + let score = self.score_bm25_point( + field, + point_id, + text_index, + &query_tokens, + avgdl, + hw_counter, + k1, + b, + )?; + + if score > 0.0 && score_threshold.is_none_or(|threshold| score >= threshold) { + top_k.push(ScoredPointOffset { + idx: point_id, + score, + }); + } + } + + Ok(top_k.into_vec()) + } + + #[allow(clippy::too_many_arguments)] + fn score_bm25_point( + &self, + field: &JsonPath, + point_id: PointOffsetType, + text_index: &crate::index::field_index::full_text_index::text_index::FullTextIndex, + query_tokens: &[u32], + avgdl: f32, + hw_counter: &HardwareCounterCell, + k1: f32, + b: f32, + ) -> OperationResult { + let payload = self.payload_by_offset(point_id, hw_counter)?; + let values = payload.get_value(field); + if values.is_empty() { + return Ok(0.0); + } + + let mut tf: HashMap = query_tokens + .iter() + .copied() + .map(|token| (token, 0)) + .collect(); + let mut doc_len = 0usize; + + for value in values { + match value { + serde_json::Value::String(string) => { + text_index.tokenize_document_text(&string, |token| { + doc_len += 1; + if let Some(token_id) = text_index.token_id(token.as_ref(), hw_counter) + && let Some(count) = tf.get_mut(&token_id) + { + *count += 1; + } + }); + } + serde_json::Value::Array(values) => { + for value in values { + if let serde_json::Value::String(string) = value { + text_index.tokenize_document_text(&string, |token| { + doc_len += 1; + if let Some(token_id) = + text_index.token_id(token.as_ref(), hw_counter) + && let Some(count) = tf.get_mut(&token_id) + { + *count += 1; + } + }); + } + } + } + _ => {} + } + } + + if doc_len == 0 { + return Ok(0.0); + } + + let points_count = text_index.indexed_points_count() as f32; + let doc_len = doc_len as f32; + + let mut score = 0.0; + for token_id in query_tokens { + let term_freq = tf.get(token_id).copied().unwrap_or(0) as f32; + if term_freq == 0.0 { + continue; + } + + let doc_freq = text_index + .get_posting_len(*token_id, hw_counter) + .unwrap_or(0) as f32; + if doc_freq <= 0.0 { + continue; + } + + let idf = ((points_count - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln(); + let norm = k1 * (1.0 - b + b * doc_len / avgdl); + let tf_norm = term_freq * (k1 + 1.0) / (term_freq + norm); + score += idf * tf_norm; + } + + Ok(score) + } +} diff --git a/lib/segment/src/segment/entry.rs b/lib/segment/src/segment/entry.rs index 6b751364161..9d16eb1ca3c 100644 --- a/lib/segment/src/segment/entry.rs +++ b/lib/segment/src/segment/entry.rs @@ -133,6 +133,40 @@ impl ReadSegmentEntry for Segment { ) } + fn search_bm25( + &self, + field: &JsonPath, + query: &str, + filter: Option<&Filter>, + top: usize, + score_threshold: Option, + with_payload: &WithPayload, + with_vector: &WithVector, + hw_counter: &HardwareCounterCell, + is_stopped: &AtomicBool, + k1: f32, + b: f32, + ) -> OperationResult> { + let internal_results = self.do_search_bm25( + field, + query, + filter, + top, + score_threshold, + is_stopped, + hw_counter, + k1, + b, + )?; + self.process_search_result( + internal_results, + with_payload, + with_vector, + hw_counter, + is_stopped, + ) + } + fn vector( &self, vector_name: &VectorName, diff --git a/lib/segment/src/segment/mod.rs b/lib/segment/src/segment/mod.rs index d0f8e28f6e5..c5cc646821e 100644 --- a/lib/segment/src/segment/mod.rs +++ b/lib/segment/src/segment/mod.rs @@ -1,3 +1,4 @@ +mod bm25_search; mod entry; mod facet; mod formula_rescore; diff --git a/lib/shard/src/proxy_segment/segment_entry.rs b/lib/shard/src/proxy_segment/segment_entry.rs index 192ab20d2df..90e152a48c7 100644 --- a/lib/shard/src/proxy_segment/segment_entry.rs +++ b/lib/shard/src/proxy_segment/segment_entry.rs @@ -149,6 +149,44 @@ impl ReadSegmentEntry for ProxySegment { Ok(result) } + fn search_bm25( + &self, + field: &JsonPath, + query: &str, + filter: Option<&Filter>, + top: usize, + score_threshold: Option, + with_payload: &WithPayload, + with_vector: &WithVector, + hw_counter: &HardwareCounterCell, + is_stopped: &AtomicBool, + k1: f32, + b: f32, + ) -> OperationResult> { + let wrapped_filter = if self.deleted_points.is_empty() { + filter.cloned() + } else { + Some(Self::add_deleted_points_condition_to_filter( + filter, + self.deleted_points.keys().copied(), + )) + }; + + self.wrapped_segment.get().read().search_bm25( + field, + query, + wrapped_filter.as_ref(), + top, + score_threshold, + with_payload, + with_vector, + hw_counter, + is_stopped, + k1, + b, + ) + } + fn vector( &self, vector_name: &VectorName, diff --git a/lib/shard/src/query/conversions.rs b/lib/shard/src/query/conversions.rs index 3cb30397277..efc2e4dd8f9 100644 --- a/lib/shard/src/query/conversions.rs +++ b/lib/shard/src/query/conversions.rs @@ -377,6 +377,19 @@ impl ScoringQuery { candidates_limit: candidates_limit as usize, }) } + grpc::query_shard_points::query::Score::Bm25(grpc::Bm25Internal { + field, + query, + k1, + b, + }) => ScoringQuery::Bm25(crate::query::Bm25Internal { + field: field + .parse() + .map_err(|_| tonic::Status::invalid_argument("invalid BM25 field path"))?, + query, + k1: OrderedFloat(k1), + b: OrderedFloat(b), + }), }; Ok(scoring_query) @@ -413,6 +426,19 @@ impl From for grpc::query_shard_points::Query { candidates_limit: candidates_limit as u32, })), }, + ScoringQuery::Bm25(crate::query::Bm25Internal { + field, + query, + k1, + b, + }) => Self { + score: Some(Score::Bm25(grpc::Bm25Internal { + field: field.to_string(), + query, + k1: k1.into_inner(), + b: b.into_inner(), + })), + }, } } } diff --git a/lib/shard/src/query/mod.rs b/lib/shard/src/query/mod.rs index c575c607db6..ed1117278c0 100644 --- a/lib/shard/src/query/mod.rs +++ b/lib/shard/src/query/mod.rs @@ -16,6 +16,7 @@ use ordered_float::OrderedFloat; use segment::data_types::order_by::OrderBy; use segment::data_types::vectors::VectorInternal; use segment::index::query_optimization::rescore_formula::parsed_formula::ParsedFormula; +use segment::json_path::JsonPath; use segment::types::*; use serde::Serialize; @@ -122,6 +123,9 @@ pub enum ScoringQuery { /// Score boosting via an arbitrary formula Formula(ParsedFormula), + /// Score points with BM25 over a full-text payload index. + Bm25(Bm25Internal), + /// Sample points Sample(SampleInternal), @@ -150,7 +154,11 @@ impl ScoringQuery { }, // MMR is a nearest neighbors search before computing diversity at collection level Self::Mmr(_) => false, - Self::Vector(_) | Self::OrderBy(_) | Self::Formula(_) | Self::Sample(_) => false, + Self::Vector(_) + | Self::OrderBy(_) + | Self::Formula(_) + | Self::Bm25(_) + | Self::Sample(_) => false, } } @@ -159,11 +167,23 @@ impl ScoringQuery { match self { Self::Vector(query) => Some(query.get_vector_name()), Self::Mmr(mmr) => Some(&mmr.using), - _ => None, + Self::Fusion(_) + | Self::OrderBy(_) + | Self::Formula(_) + | Self::Bm25(_) + | Self::Sample(_) => None, } } } +#[derive(Clone, Debug, PartialEq, Hash, Serialize)] +pub struct Bm25Internal { + pub field: JsonPath, + pub query: String, + pub k1: OrderedFloat, + pub b: OrderedFloat, +} + #[derive(Clone, Debug, PartialEq, Hash, Serialize)] pub enum FusionInternal { /// Reciprocal Rank Fusion with optional weights per prefetch diff --git a/lib/shard/src/query/planned_query.rs b/lib/shard/src/query/planned_query.rs index 61dfc03b149..271f9358160 100644 --- a/lib/shard/src/query/planned_query.rs +++ b/lib/shard/src/query/planned_query.rs @@ -137,6 +137,7 @@ impl PlannedQuery { | Some(ScoringQuery::Fusion(_)) | Some(ScoringQuery::OrderBy(_)) | Some(ScoringQuery::Formula(_)) + | Some(ScoringQuery::Bm25(_)) | Some(ScoringQuery::Sample(_)) => with_vector, Some(ScoringQuery::Mmr(mmr)) => with_vector.merge(&WithVector::from(mmr.using.clone())), }; @@ -186,6 +187,12 @@ impl PlannedQuery { Some(ScoringQuery::Fusion(_)) => None, // Expect fusion to have prefetches Some(ScoringQuery::OrderBy(_)) => None, Some(ScoringQuery::Formula(_)) => None, + Some(ScoringQuery::Bm25(_)) => Some(RescoreStages::shard_level(RescoreParams { + rescore: query.clone().unwrap(), + limit, + score_threshold: score_threshold.map(OrderedFloat), + params, + })), Some(ScoringQuery::Sample(_)) => None, Some(ScoringQuery::Mmr(_)) => Some(RescoreStages::collection_level(RescoreParams { rescore: query.clone().unwrap(), @@ -196,15 +203,19 @@ impl PlannedQuery { }; // Everything must come from a single source. - let sources = vec![leaf_source_from_scoring_query( - &mut self.searches, - &mut self.scrolls, - query, - limit, - params, - score_threshold, - filter, - )?]; + let sources = if matches!(query, Some(ScoringQuery::Bm25(_))) { + Vec::new() + } else { + vec![leaf_source_from_scoring_query( + &mut self.searches, + &mut self.scrolls, + query, + limit, + params, + score_threshold, + filter, + )?] + }; // Root-level query without prefetches means we won't do any extra rescoring let merge_plan = MergePlan::new(sources, rescore_stages)?; @@ -269,6 +280,7 @@ impl PlannedQuery { rescore @ (ScoringQuery::Vector(_) | ScoringQuery::OrderBy(_) | ScoringQuery::Formula(_) + | ScoringQuery::Bm25(_) | ScoringQuery::Sample(_)) => Some(RescoreStages::shard_level(RescoreParams { rescore, limit, @@ -418,6 +430,11 @@ fn leaf_source_from_scoring_query( "cannot apply Formula without prefetches".to_string(), )); } + Some(ScoringQuery::Bm25(_)) => { + return Err(OperationError::validation_error( + "cannot apply BM25 as a leaf query without rescoring".to_string(), + )); + } Some(ScoringQuery::Sample(SampleInternal::Random)) => { let scroll = QueryScrollRequestInternal { scroll_order: ScrollOrder::Random, diff --git a/lib/shard/src/query/validation.rs b/lib/shard/src/query/validation.rs index 4839b92b247..7ab13ad79de 100644 --- a/lib/shard/src/query/validation.rs +++ b/lib/shard/src/query/validation.rs @@ -66,6 +66,7 @@ fn validate_query(query: &ScoringQuery, sources: &[Source]) -> OperationResult<( ScoringQuery::Fusion(fusion) => validate_fusion(fusion, sources.len()), ScoringQuery::OrderBy(_) => Ok(()), ScoringQuery::Formula(_) => Ok(()), + ScoringQuery::Bm25(_) => Ok(()), ScoringQuery::Sample(_) => Ok(()), ScoringQuery::Mmr(_) => Ok(()), } diff --git a/src/common/inference/batch_processing.rs b/src/common/inference/batch_processing.rs index a1c22770619..b7767996cc5 100644 --- a/src/common/inference/batch_processing.rs +++ b/src/common/inference/batch_processing.rs @@ -112,6 +112,7 @@ fn collect_query(query: &Query, batch: &mut BatchAccum) { | Query::Fusion(_) | Query::Rrf(_) | Query::Formula(_) + | Query::Bm25(_) | Query::Sample(_) => {} } } diff --git a/src/common/inference/query_requests_rest.rs b/src/common/inference/query_requests_rest.rs index 16fd0c866f5..3759b5bfdfd 100644 --- a/src/common/inference/query_requests_rest.rs +++ b/src/common/inference/query_requests_rest.rs @@ -2,8 +2,9 @@ use api::rest::models::InferenceUsage; use api::rest::schema as rest; use collection::lookup::WithLookup; use collection::operations::universal_query::collection_query::{ - CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, FeedbackInternal, - FeedbackStrategy, Mmr, NearestWithMmr, Query, VectorInputInternal, VectorQuery, + Bm25QueryInternal, CollectionPrefetch, CollectionQueryGroupsRequest, CollectionQueryRequest, + FeedbackInternal, FeedbackStrategy, Mmr, NearestWithMmr, Query, VectorInputInternal, + VectorQuery, }; use collection::operations::universal_query::formula::FormulaInternal; use collection::operations::universal_query::shard_query::{FusionInternal, SampleInternal}; @@ -272,6 +273,12 @@ fn convert_query_with_inferred( rest::Query::Fusion(fusion) => Ok(Query::Fusion(FusionInternal::from(fusion.fusion))), rest::Query::Rrf(rrf) => Ok(Query::Fusion(FusionInternal::from(rrf.rrf))), rest::Query::Formula(formula) => Ok(Query::Formula(FormulaInternal::from(formula))), + rest::Query::Bm25(bm25) => Ok(Query::Bm25(Bm25QueryInternal { + field: bm25.field, + query: bm25.query, + k1: bm25.params.as_ref().and_then(|p| p.k1), + b: bm25.params.as_ref().and_then(|p| p.b), + })), rest::Query::Sample(sample) => Ok(Query::Sample(SampleInternal::from(sample.sample))), rest::Query::RelevanceFeedback(relevance_feedback) => { let rest::RelevanceFeedbackInput { @@ -357,7 +364,9 @@ fn context_pair_from_rest_with_inferred( mod tests { use std::collections::HashMap; - use api::rest::schema::{Document, Image, InferenceObject, NearestQuery}; + use api::rest::schema::{ + Bm25Params, Bm25Query, Document, Image, InferenceObject, NearestQuery, + }; use collection::operations::point_ops::VectorPersisted; use serde_json::json; @@ -490,4 +499,28 @@ mod tests { _ => panic!("Expected nearest query"), } } + + #[test] + fn test_convert_query_with_inferred_bm25() { + let inferred = create_test_inferred_batch(); + let query = rest::QueryInterface::Query(rest::Query::Bm25(Bm25Query { + field: "content".parse().unwrap(), + query: "quick brown fox".to_string(), + params: Some(Bm25Params { + k1: Some(1.5), + b: Some(0.6), + }), + })); + + let result = convert_query_with_inferred(query, &inferred).unwrap(); + match result { + Query::Bm25(bm25) => { + assert_eq!(bm25.field, "content".parse().unwrap()); + assert_eq!(bm25.query, "quick brown fox"); + assert_eq!(bm25.k1, Some(1.5)); + assert_eq!(bm25.b, Some(0.6)); + } + _ => panic!("Expected BM25 query"), + } + } }