Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from warnings import warn

from astrapy import DataAPIClient as AstraDBClient
Expand Down Expand Up @@ -320,31 +320,36 @@ def delete(
self,
*,
ids: Optional[List[str]] = None,
delete_all: Optional[bool] = None,
filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
) -> int:
"""Delete documents from the Astra index.

:param ids: the ids of the documents to delete
:param delete_all: if `True`, delete all documents from the index
:param filters: additional filters to apply when deleting documents
:returns: the number of documents deleted
"""
if delete_all:
query = {"deleteMany": {}} # type: dict
query: Dict[str, Dict[str, Any]] = {}

if ids is not None:
query = {"deleteMany": {"filter": {"_id": {"$in": ids}}}}
if filters is not None:
query = {"deleteMany": {"filter": filters}}

filter_dict = {}
if "filter" in query["deleteMany"]:
filter_dict = query["deleteMany"]["filter"]

filter_dict = query.get("deleteMany", {}).get("filter", {})
delete_result = self._astra_db_collection.delete_many(filter=filter_dict)

return delete_result.deleted_count

def delete_all_documents(self) -> int:
    """
    Remove every document held in the Astra index.

    Issues a single `delete_many` call with an empty filter against the
    underlying collection.

    :returns: the `deleted_count` reported by the delete result
    """
    # An empty filter places no restriction on which documents are deleted.
    return self._astra_db_collection.delete_many(filter={}).deleted_count

def count_documents(self, upper_bound: int = 10000) -> int:
"""
Count the number of documents in the Astra index.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret, deserialize_secrets_inplace

Expand Down Expand Up @@ -395,12 +395,7 @@ def search(

return result

def delete_documents(
self,
document_ids: Optional[List[str]] = None,
*,
delete_all: Optional[bool] = None,
) -> None:
def delete_documents(self, document_ids: List[str]) -> None:
"""
Deletes documents from the document store.

Expand All @@ -413,12 +408,27 @@ def delete_documents(
if document_ids is not None:
for batch in _batches(document_ids, MAX_BATCH_SIZE):
deletion_counter += self.index.delete(ids=batch)
else:
deletion_counter = self.index.delete(delete_all=delete_all)
logger.info(f"{deletion_counter} documents deleted")

if document_ids is not None and deletion_counter == 0:
msg = f"Document {document_ids} does not exist"
raise MissingDocumentError(msg)
else:
logger.info("No documents in document store")

def delete_all_documents(self) -> None:
    """
    Deletes all documents from the document store.

    Delegates to the index's `delete_all_documents()` and logs the outcome.

    :raises DocumentStoreError: if the underlying index call raises any exception.
    """
    try:
        deleted = self.index.delete_all_documents()
    except Exception as e:
        # Wrap any backend failure in the store's own error type, keeping the cause chained.
        msg = f"Failed to delete all documents from Astra: {e!s}"
        raise DocumentStoreError(msg) from e

    # NOTE(review): -1 appears to be the value the index reports for an
    # unfiltered "delete everything" call -- confirm against the astrapy docs.
    if deleted != -1:
        logger.error("Could not delete all documents")
        return

    logger.info("All documents deleted")
16 changes: 11 additions & 5 deletions integrations/astra/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TestDocumentStore(DocumentStoreBaseTests):
you can add more to this class.
"""

@pytest.fixture
@pytest.fixture(scope="class")
def document_store(self) -> AstraDocumentStore:
return AstraDocumentStore(
collection_name="haystack_integration",
Expand All @@ -63,11 +63,11 @@ def document_store(self) -> AstraDocumentStore:
)

@pytest.fixture(autouse=True)
def run_before_and_after_tests(self, document_store: AstraDocumentStore):
def run_before_tests(self, document_store: AstraDocumentStore):
"""
Cleaning up document store
"""
document_store.delete_documents(delete_all=True)
document_store.delete_all_documents()
assert document_store.count_documents() == 0

def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
Expand Down Expand Up @@ -136,8 +136,7 @@ def test_delete_documents_more_than_twenty_delete_all(self, document_store: Astr
document_store.write_documents(docs)
assert document_store.count_documents() == 25

document_store.delete_documents(delete_all=True)

document_store.delete_all_documents()
assert document_store.count_documents() == 0

def test_delete_documents_more_than_twenty_delete_ids(self, document_store: AstraDocumentStore):
Expand Down Expand Up @@ -205,6 +204,13 @@ def test_filter_documents_by_in_operator(self, document_store):
self.assert_documents_are_equal([result[0]], [docs[0]])
self.assert_documents_are_equal([result[1]], [docs[1]])

def test_delete_all_documents(self, document_store: AstraDocumentStore):
    """
    Test that delete_all_documents() empties a populated Astra document store.
    """
    # Populate the store first: the autouse cleanup fixture already leaves it
    # empty, so deleting without writing would pass even if the method did nothing.
    docs = [Document(content=f"Document {i}") for i in range(3)]
    document_store.write_documents(docs)
    assert document_store.count_documents() == 3

    document_store.delete_all_documents()
    assert document_store.count_documents() == 0

@pytest.mark.skip(reason="Unsupported filter operator not.")
def test_not_operator(self, document_store, filterable_docs):
pass
Expand Down
9 changes: 6 additions & 3 deletions integrations/astra/tests/test_embedding_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set")
class TestEmbeddingRetrieval:
@pytest.fixture
@pytest.fixture(scope="class")
def document_store(self) -> AstraDocumentStore:
return AstraDocumentStore(
collection_name="haystack_integration",
Expand All @@ -22,11 +22,11 @@ def document_store(self) -> AstraDocumentStore:
)

@pytest.fixture(autouse=True)
def run_before_and_after_tests(self, document_store: AstraDocumentStore):
def run_before_tests(self, document_store: AstraDocumentStore):
"""
Cleaning up document store
"""
document_store.delete_documents(delete_all=True)
document_store.delete_all_documents()
assert document_store.count_documents() == 0

def test_search_with_top_k(self, document_store):
Expand All @@ -45,3 +45,6 @@ def test_search_with_top_k(self, document_store):

for document in result:
assert document.score is not None

document_store.delete_all_documents()
assert document_store.count_documents() == 0