diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index a0521270e2..dc551a85bc 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from warnings import warn from astrapy import DataAPIClient as AstraDBClient @@ -320,31 +320,36 @@ def delete( self, *, ids: Optional[List[str]] = None, - delete_all: Optional[bool] = None, filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, ) -> int: """Delete documents from the Astra index. :param ids: the ids of the documents to delete - :param delete_all: if `True`, delete all documents from the index :param filters: additional filters to apply when deleting documents :returns: the number of documents deleted """ - if delete_all: - query = {"deleteMany": {}} # type: dict + query: Dict[str, Dict[str, Any]] = {} + if ids is not None: query = {"deleteMany": {"filter": {"_id": {"$in": ids}}}} if filters is not None: query = {"deleteMany": {"filter": filters}} filter_dict = {} - if "filter" in query["deleteMany"]: - filter_dict = query["deleteMany"]["filter"] - + filter_dict = query.get("deleteMany", {}).get("filter", {}) delete_result = self._astra_db_collection.delete_many(filter=filter_dict) return delete_result.deleted_count + def delete_all_documents(self) -> int: + """ + Delete all documents from the Astra index. + :returns: the number of documents deleted + """ + delete_result = self._astra_db_collection.delete_many(filter={}) + + return delete_result.deleted_count + def count_documents(self, upper_bound: int = 10000) -> int: """ Count the number of documents in the Astra index. diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 8950c83556..90224319f7 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -5,7 +5,7 @@ from haystack import default_from_dict, default_to_dict, logging from haystack.dataclasses import Document -from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace @@ -395,12 +395,7 @@ def search( return result - def delete_documents( - self, - document_ids: Optional[List[str]] = None, - *, - delete_all: Optional[bool] = None, - ) -> None: + def delete_documents(self, document_ids: List[str]) -> None: """ Deletes documents from the document store. @@ -413,8 +408,6 @@ def delete_documents( if document_ids is not None: for batch in _batches(document_ids, MAX_BATCH_SIZE): deletion_counter += self.index.delete(ids=batch) - else: - deletion_counter = self.index.delete(delete_all=delete_all) logger.info(f"{deletion_counter} documents deleted") if document_ids is not None and deletion_counter == 0: @@ -422,3 +415,20 @@ def delete_documents( raise MissingDocumentError(msg) else: logger.info("No documents in document store") + + def delete_all_documents(self) -> None: + """ + Deletes all documents from the document store. + """ + deletion_counter = 0 + + try: + deletion_counter = self.index.delete_all_documents() + except Exception as e: + msg = f"Failed to delete all documents from Astra: {e!s}" + raise DocumentStoreError(msg) from e + + if deletion_counter == -1: + logger.info("All documents deleted") + else: + logger.error("Could not delete all documents") diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index 70bd92ce3e..77507344fe 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -54,7 +54,7 @@ class TestDocumentStore(DocumentStoreBaseTests): you can add more to this class. """ - @pytest.fixture + @pytest.fixture(scope="class") def document_store(self) -> AstraDocumentStore: return AstraDocumentStore( collection_name="haystack_integration", @@ -63,11 +63,11 @@ def document_store(self) -> AstraDocumentStore: ) @pytest.fixture(autouse=True) - def run_before_and_after_tests(self, document_store: AstraDocumentStore): + def run_before_tests(self, document_store: AstraDocumentStore): """ Cleaning up document store """ - document_store.delete_documents(delete_all=True) + document_store.delete_all_documents() assert document_store.count_documents() == 0 def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): @@ -136,8 +136,7 @@ def test_delete_documents_more_than_twenty_delete_all(self, document_store: Astr document_store.write_documents(docs) assert document_store.count_documents() == 25 - document_store.delete_documents(delete_all=True) - + document_store.delete_all_documents() assert document_store.count_documents() == 0 def test_delete_documents_more_than_twenty_delete_ids(self, document_store: AstraDocumentStore): @@ -205,6 +204,13 @@ def test_filter_documents_by_in_operator(self, document_store): self.assert_documents_are_equal([result[0]], [docs[0]]) self.assert_documents_are_equal([result[1]], [docs[1]]) + def test_delete_all_documents(self, document_store: AstraDocumentStore): + """ + Test delete_all_documents() on an Astra. + """ + document_store.delete_all_documents() + assert document_store.count_documents() == 0 + @pytest.mark.skip(reason="Unsupported filter operator not.") def test_not_operator(self, document_store, filterable_docs): pass diff --git a/integrations/astra/tests/test_embedding_retrieval.py b/integrations/astra/tests/test_embedding_retrieval.py index 24814c9724..c2248b6a3f 100644 --- a/integrations/astra/tests/test_embedding_retrieval.py +++ b/integrations/astra/tests/test_embedding_retrieval.py @@ -13,7 +13,7 @@ ) @pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set") class TestEmbeddingRetrieval: - @pytest.fixture + @pytest.fixture(scope="class") def document_store(self) -> AstraDocumentStore: return AstraDocumentStore( collection_name="haystack_integration", @@ -22,11 +22,11 @@ def document_store(self) -> AstraDocumentStore: ) @pytest.fixture(autouse=True) - def run_before_and_after_tests(self, document_store: AstraDocumentStore): + def run_before_tests(self, document_store: AstraDocumentStore): """ Cleaning up document store """ - document_store.delete_documents(delete_all=True) + document_store.delete_all_documents() assert document_store.count_documents() == 0 def test_search_with_top_k(self, document_store): @@ -45,3 +45,6 @@ def test_search_with_top_k(self, document_store): for document in result: assert document.score is not None + + document_store.delete_all_documents() + assert document_store.count_documents() == 0