Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ GOOGLE_EMBEDDINGS=gemini-embedding-001
HF_EMBEDDINGS=thenlper/gte-large
HF_RERANKER=BAAI/bge-reranker-base

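# Reranker type (case-insensitive): 'HF' for HuggingFace CrossEncoder, 'VERTEX_AI' for Google Vertex AI Ranking API.
# When RERANKER_TYPE=VERTEX_AI, VERTEX_AI_PROJECT_ID must be set; VERTEX_AI_LOCATION defaults to 'global'.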
RERANKER_TYPE=HF
VERTEX_AI_PROJECT_ID=
VERTEX_AI_LOCATION=global

# FAISS database path
FAISS_DB_PATH=./.faissdb/faiss_index

Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"huggingface-hub[cli]==0.34.4",
"langchain==0.3.27",
"langchain-community==0.3.27",
"langchain-google-community[vertexaisearch]>=2.0.0",
"langchain-google-genai==2.1.9",
"langchain-google-vertexai==2.0.28",
"langchain-huggingface==0.3.1",
Expand Down
40 changes: 36 additions & 4 deletions backend/src/chains/hybrid_retriever_chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import logging
from typing import Optional, Union, Any

from langchain.retrievers import EnsembleRetriever
Expand Down Expand Up @@ -103,6 +104,7 @@ def create_hybrid_retriever(self) -> None:
mmr_retriever = mmr_retriever_chain.retriever

bm25_retriever_chain = BM25RetrieverChain()
bm25_retriever = None

if self.vector_db is not None and self.vector_db.processed_docs:
bm25_retriever_chain.create_bm25_retriever(
Expand All @@ -119,12 +121,42 @@ def create_hybrid_retriever(self) -> None:
retrievers=[similarity_retriever, mmr_retriever, bm25_retriever],
weights=self.weights,
)
else:
raise ValueError(
"Failed to create ensemble retriever: one or more sub-retrievers "
"could not be initialized. Ensure vector_db has processed documents."
)

if self.contextual_rerank:
compressor = CrossEncoderReranker(
model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name),
top_n=self.search_k,
)
reranker_type = os.getenv("RERANKER_TYPE", "HF").upper()

if reranker_type == "VERTEX_AI":
from langchain_google_community.vertex_rank import VertexAIRank

project_id = os.getenv("VERTEX_AI_PROJECT_ID", "")
location_id = os.getenv("VERTEX_AI_LOCATION", "global")

if not project_id:
raise ValueError(
"VERTEX_AI_PROJECT_ID must be set when using RERANKER_TYPE=VERTEX_AI"
)

compressor = VertexAIRank(
project_id=project_id,
location_id=location_id,
ranking_config="default_ranking_config",
top_n=self.search_k,
)
logging.info("Using Vertex AI reranker")
else:
compressor = CrossEncoderReranker(
model=HuggingFaceCrossEncoder(
model_name=self.reranking_model_name
),
top_n=self.search_k,
)
logging.info("Using HuggingFace CrossEncoder reranker")

self.retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever
)
Comment on lines +158 to 162
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`ensemble_retriever` can be referenced here even when it was never assigned. If `self.vector_db` is `None` or `processed_docs` is empty, `bm25_retriever` is never set, which prevents the `ensemble_retriever = EnsembleRetriever(...)` block from running. This code still uses `ensemble_retriever` afterwards, leading to an `UnboundLocalError`. Initialize `bm25_retriever`/`ensemble_retriever` to `None`, and either raise a clear error when the ensemble cannot be constructed or provide a fallback retriever composition before reaching this point.

Copilot uses AI. Check for mistakes.
Expand Down
115 changes: 114 additions & 1 deletion backend/tests/test_hybrid_retriever_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_create_hybrid_retriever_with_provided_vector_db(

assert chain.retriever == mock_ensemble_instance

@patch.dict("os.environ", {"RERANKER_TYPE": "HF"})
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain")
Expand All @@ -158,7 +159,7 @@ def test_create_hybrid_retriever_with_contextual_rerank(
mock_mmr_chain,
mock_sim_chain,
):
"""Test creating hybrid retriever with contextual reranking enabled."""
"""Test creating hybrid retriever with HF contextual reranking enabled."""
mock_vector_db = Mock()
mock_vector_db.processed_docs = [Mock(), Mock()]

Expand Down Expand Up @@ -209,6 +210,118 @@ def test_create_hybrid_retriever_with_contextual_rerank(

assert chain.retriever == mock_compression_instance

@patch.dict(
    "os.environ",
    {
        "RERANKER_TYPE": "VERTEX_AI",
        "VERTEX_AI_PROJECT_ID": "test-project",
        "VERTEX_AI_LOCATION": "global",
    },
)
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain")
@patch("src.chains.hybrid_retriever_chain.EnsembleRetriever")
@patch("src.chains.hybrid_retriever_chain.ContextualCompressionRetriever")
def test_create_hybrid_retriever_with_vertex_ai_rerank(
    self,
    mock_compression,
    mock_ensemble,
    mock_bm25_chain,
    mock_mmr_chain,
    mock_sim_chain,
):
    """Verify the Vertex AI reranker is wired in when RERANKER_TYPE=VERTEX_AI.

    With the environment patched to select Vertex AI and supply a project
    id and location, ``create_hybrid_retriever`` is expected to:

    * construct ``VertexAIRank`` exactly once with the project id, the
      location, the ``"default_ranking_config"`` ranking config and
      ``top_n`` equal to the chain's ``search_k``;
    * wrap the ensemble retriever and that compressor in a
      ``ContextualCompressionRetriever`` and store it on ``chain.retriever``.
    """
    vector_db = Mock()
    vector_db.processed_docs = [Mock(), Mock()]

    chain = HybridRetrieverChain(
        vector_db=vector_db,
        contextual_rerank=True,
        search_k=5,
    )

    # Every patched sub-retriever chain class returns an instance that
    # exposes a `.retriever` attribute.
    for chain_cls in (mock_sim_chain, mock_mmr_chain, mock_bm25_chain):
        chain_cls.return_value = Mock(retriever=Mock())

    ensemble_instance = Mock()
    mock_ensemble.return_value = ensemble_instance

    compression_instance = Mock()
    mock_compression.return_value = compression_instance

    # VertexAIRank is patched on its defining module rather than on
    # hybrid_retriever_chain; presumably the chain imports it lazily at
    # call time, so this is where the lookup lands.
    with patch(
        "langchain_google_community.vertex_rank.VertexAIRank"
    ) as mock_vertex_rank:
        vertex_rank_instance = Mock()
        mock_vertex_rank.return_value = vertex_rank_instance

        chain.create_hybrid_retriever()

    # Mocks keep their call history after the patch context exits.
    mock_vertex_rank.assert_called_once_with(
        project_id="test-project",
        location_id="global",
        ranking_config="default_ranking_config",
        top_n=5,
    )
    mock_compression.assert_called_once_with(
        base_compressor=vertex_rank_instance,
        base_retriever=ensemble_instance,
    )

    assert chain.retriever == compression_instance

@patch.dict(
    "os.environ",
    {"RERANKER_TYPE": "VERTEX_AI", "VERTEX_AI_PROJECT_ID": ""},
)
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain")
@patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain")
@patch("src.chains.hybrid_retriever_chain.EnsembleRetriever")
def test_vertex_ai_rerank_raises_without_project_id(
    self,
    mock_ensemble,
    mock_bm25_chain,
    mock_mmr_chain,
    mock_sim_chain,
):
    """Verify an empty VERTEX_AI_PROJECT_ID is rejected for the Vertex AI reranker.

    The environment selects ``RERANKER_TYPE=VERTEX_AI`` with an empty
    project id. All sub-retrievers and the ensemble are mocked so that
    construction succeeds up to the reranker step, where
    ``create_hybrid_retriever`` must raise ``ValueError`` with a message
    matching ``"VERTEX_AI_PROJECT_ID must be set"``.
    """
    vector_db = Mock()
    vector_db.processed_docs = [Mock(), Mock()]

    chain = HybridRetrieverChain(
        vector_db=vector_db,
        contextual_rerank=True,
    )

    # Every patched sub-retriever chain class returns an instance that
    # exposes a `.retriever` attribute.
    for chain_cls in (mock_sim_chain, mock_mmr_chain, mock_bm25_chain):
        chain_cls.return_value = Mock(retriever=Mock())

    mock_ensemble.return_value = Mock()

    with pytest.raises(ValueError, match="VERTEX_AI_PROJECT_ID must be set"):
        chain.create_hybrid_retriever()

@patch("src.chains.hybrid_retriever_chain.os.path.isdir")
@patch("src.chains.hybrid_retriever_chain.os.listdir")
@patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")
Expand Down