diff --git a/backend/.env.example b/backend/.env.example index 5ce3d60e..aeb9606d 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -27,6 +27,11 @@ GOOGLE_EMBEDDINGS=gemini-embedding-001 HF_EMBEDDINGS=thenlper/gte-large HF_RERANKER=BAAI/bge-reranker-base +# Reranker type: 'HF' for HuggingFace CrossEncoder, 'VERTEX_AI' for Google Vertex AI Ranking API +RERANKER_TYPE=HF +VERTEX_AI_PROJECT_ID= +VERTEX_AI_LOCATION=global + # FAISS database path FAISS_DB_PATH=./.faissdb/faiss_index diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9f59fa37..49a23006 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "huggingface-hub[cli]==0.34.4", "langchain==0.3.27", "langchain-community==0.3.27", + "langchain-google-community[vertexaisearch]>=2.0.0", "langchain-google-genai==2.1.9", "langchain-google-vertexai==2.0.28", "langchain-huggingface==0.3.1", diff --git a/backend/src/chains/hybrid_retriever_chain.py b/backend/src/chains/hybrid_retriever_chain.py index 1b68c14f..bd51e2a9 100644 --- a/backend/src/chains/hybrid_retriever_chain.py +++ b/backend/src/chains/hybrid_retriever_chain.py @@ -1,4 +1,5 @@ import os +import logging from typing import Optional, Union, Any from langchain.retrievers import EnsembleRetriever @@ -103,6 +104,7 @@ def create_hybrid_retriever(self) -> None: mmr_retriever = mmr_retriever_chain.retriever bm25_retriever_chain = BM25RetrieverChain() + bm25_retriever = None if self.vector_db is not None and self.vector_db.processed_docs: bm25_retriever_chain.create_bm25_retriever( @@ -119,12 +121,42 @@ def create_hybrid_retriever(self) -> None: retrievers=[similarity_retriever, mmr_retriever, bm25_retriever], weights=self.weights, ) + else: + raise ValueError( + "Failed to create ensemble retriever: one or more sub-retrievers " + "could not be initialized. Ensure vector_db has processed documents." + ) if self.contextual_rerank: - compressor = CrossEncoderReranker( - model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name), - top_n=self.search_k, - ) + reranker_type = os.getenv("RERANKER_TYPE", "HF").upper() + + if reranker_type == "VERTEX_AI": + from langchain_google_community.vertex_rank import VertexAIRank + + project_id = os.getenv("VERTEX_AI_PROJECT_ID", "") + location_id = os.getenv("VERTEX_AI_LOCATION", "global") + + if not project_id: + raise ValueError( + "VERTEX_AI_PROJECT_ID must be set when using RERANKER_TYPE=VERTEX_AI" + ) + + compressor = VertexAIRank( + project_id=project_id, + location_id=location_id, + ranking_config="default_ranking_config", + top_n=self.search_k, + ) + logging.info("Using Vertex AI reranker") + else: + compressor = CrossEncoderReranker( + model=HuggingFaceCrossEncoder( + model_name=self.reranking_model_name + ), + top_n=self.search_k, + ) + logging.info("Using HuggingFace CrossEncoder reranker") + self.retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=ensemble_retriever ) diff --git a/backend/tests/test_hybrid_retriever_chain.py b/backend/tests/test_hybrid_retriever_chain.py index f2061ea8..99aab4f3 100644 --- a/backend/tests/test_hybrid_retriever_chain.py +++ b/backend/tests/test_hybrid_retriever_chain.py @@ -141,6 +141,7 @@ def test_create_hybrid_retriever_with_provided_vector_db( assert chain.retriever == mock_ensemble_instance + @patch.dict("os.environ", {"RERANKER_TYPE": "HF"}) @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") @patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain") @patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain") @@ -158,7 +159,7 @@ def test_create_hybrid_retriever_with_contextual_rerank( mock_mmr_chain, mock_sim_chain, ): - """Test creating hybrid retriever with contextual reranking enabled.""" + """Test creating hybrid retriever with HF contextual reranking enabled.""" mock_vector_db = Mock() mock_vector_db.processed_docs = [Mock(), Mock()] @@ -209,6 +210,118 @@ def test_create_hybrid_retriever_with_contextual_rerank( assert chain.retriever == mock_compression_instance + @patch.dict( + "os.environ", + { + "RERANKER_TYPE": "VERTEX_AI", + "VERTEX_AI_PROJECT_ID": "test-project", + "VERTEX_AI_LOCATION": "global", + }, + ) + @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain") + @patch("src.chains.hybrid_retriever_chain.EnsembleRetriever") + @patch("src.chains.hybrid_retriever_chain.ContextualCompressionRetriever") + def test_create_hybrid_retriever_with_vertex_ai_rerank( + self, + mock_compression, + mock_ensemble, + mock_bm25_chain, + mock_mmr_chain, + mock_sim_chain, + ): + """Test creating hybrid retriever with Vertex AI reranking enabled.""" + mock_vector_db = Mock() + mock_vector_db.processed_docs = [Mock(), Mock()] + + chain = HybridRetrieverChain( + vector_db=mock_vector_db, + contextual_rerank=True, + search_k=5, + ) + + # Setup mocks + mock_sim_instance = Mock() + mock_sim_instance.retriever = Mock() + mock_sim_chain.return_value = mock_sim_instance + + mock_mmr_instance = Mock() + mock_mmr_instance.retriever = Mock() + mock_mmr_chain.return_value = mock_mmr_instance + + mock_bm25_instance = Mock() + mock_bm25_instance.retriever = Mock() + mock_bm25_chain.return_value = mock_bm25_instance + + mock_ensemble_instance = Mock() + mock_ensemble.return_value = mock_ensemble_instance + + mock_compression_instance = Mock() + mock_compression.return_value = mock_compression_instance + + with patch( + "langchain_google_community.vertex_rank.VertexAIRank" + ) as mock_vertex_rank: + mock_vertex_rank_instance = Mock() + mock_vertex_rank.return_value = mock_vertex_rank_instance + + chain.create_hybrid_retriever() + + mock_vertex_rank.assert_called_once_with( + project_id="test-project", + location_id="global", + ranking_config="default_ranking_config", + top_n=5, + ) + mock_compression.assert_called_once_with( + base_compressor=mock_vertex_rank_instance, + base_retriever=mock_ensemble_instance, + ) + + assert chain.retriever == mock_compression_instance + + @patch.dict( + "os.environ", + {"RERANKER_TYPE": "VERTEX_AI", "VERTEX_AI_PROJECT_ID": ""}, + ) + @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.MMRRetrieverChain") + @patch("src.chains.hybrid_retriever_chain.BM25RetrieverChain") + @patch("src.chains.hybrid_retriever_chain.EnsembleRetriever") + def test_vertex_ai_rerank_raises_without_project_id( + self, + mock_ensemble, + mock_bm25_chain, + mock_mmr_chain, + mock_sim_chain, + ): + """Test that Vertex AI reranker raises error without project ID.""" + mock_vector_db = Mock() + mock_vector_db.processed_docs = [Mock(), Mock()] + + chain = HybridRetrieverChain( + vector_db=mock_vector_db, + contextual_rerank=True, + ) + + mock_sim_instance = Mock() + mock_sim_instance.retriever = Mock() + mock_sim_chain.return_value = mock_sim_instance + + mock_mmr_instance = Mock() + mock_mmr_instance.retriever = Mock() + mock_mmr_chain.return_value = mock_mmr_instance + + mock_bm25_instance = Mock() + mock_bm25_instance.retriever = Mock() + mock_bm25_chain.return_value = mock_bm25_instance + + mock_ensemble.return_value = Mock() + + with pytest.raises(ValueError, match="VERTEX_AI_PROJECT_ID must be set"): + chain.create_hybrid_retriever() + @patch("src.chains.hybrid_retriever_chain.os.path.isdir") @patch("src.chains.hybrid_retriever_chain.os.listdir") @patch("src.chains.hybrid_retriever_chain.SimilarityRetrieverChain")