diff --git a/docs/config/yaml.md b/docs/config/yaml.md index ace57e3b1c..941e938a6d 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th #### Fields - `storage` **StorageConfig** - - `type` **file|blob|cosmosdb** - The storage type to use. Default=`file` + - `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -115,7 +115,7 @@ This section controls the storage mechanism used by the pipeline used for export #### Fields -- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` +- `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. @@ -128,7 +128,7 @@ The section defines a secondary storage location for running incremental indexin #### Fields -- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file` +- `type` **FileStorage|AzureBlobStorage|AzureCosmosStorage** - The storage type to use. Default=`FileStorage` - `base_dir` **str** - The base directory to write output artifacts to, relative to the root. - `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string. - `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name. diff --git a/packages/graphrag-storage/README.md b/packages/graphrag-storage/README.md new file mode 100644 index 0000000000..cd31bb0658 --- /dev/null +++ b/packages/graphrag-storage/README.md @@ -0,0 +1,60 @@ +# GraphRAG Storage + +## Basic + +```python +import asyncio +from graphrag_storage import StorageConfig, create_storage +from graphrag_storage.file_storage import FileStorage + +async def run(): + storage = create_storage( + StorageConfig( + type="FileStorage", # or FileStorage.__name__ + base_dir="output" + ) + ) + + await storage.set("my_key", "value") + print(await storage.get("my_key")) + +if __name__ == "__main__": + asyncio.run(run()) +``` + +## Custom Storage + +```python +import asyncio +from typing import Any +from graphrag_storage import Storage, StorageConfig, create_storage, register_storage + +class MyStorage(Storage): + def __init__(self, some_setting: str, **kwargs: Any): + # Validate settings and initialize + ... + + # Implement rest of interface + ...
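+    # The Storage ABC also requires: find, get, set, has, delete, clear, child, keys, and get_creation_date.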
+ +register_storage("MyStorage", MyStorage) + +async def run(): + storage = create_storage( + StorageConfig( + type="MyStorage", + some_setting="My Setting" + ) + ) + # Or use the factory directly to instantiate with a dict instead of using + # StorageConfig + create_storage + # from graphrag_storage.storage_factory import storage_factory + # storage = storage_factory.create(strategy="MyStorage", init_args={"some_setting": "My Setting"}) + + await storage.set("my_key", "value") + print(await storage.get("my_key")) + +if __name__ == "__main__": + asyncio.run(run()) +``` diff --git a/packages/graphrag-storage/graphrag_storage/__init__.py b/packages/graphrag-storage/graphrag_storage/__init__.py new file mode 100644 index 0000000000..0684dfb889 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The GraphRAG Storage package.""" + +from graphrag_storage.storage import Storage +from graphrag_storage.storage_config import StorageConfig +from graphrag_storage.storage_factory import create_storage, register_storage + +__all__ = [ + "Storage", + "StorageConfig", + "create_storage", + "register_storage", +] diff --git a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py similarity index 81% rename from packages/graphrag/graphrag/storage/blob_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/azure_blob_storage.py index 1435cb387d..9028259bf8 100644 --- a/packages/graphrag/graphrag/storage/blob_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_blob_storage.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""Azure Blob Storage implementation of PipelineStorage.""" +"""Azure Blob Storage implementation of Storage.""" import logging import re @@ -12,15 +12,15 @@ from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient -from graphrag.storage.pipeline_storage import ( - PipelineStorage, +from graphrag_storage.storage import ( + Storage, get_timestamp_formatted_with_local_tz, ) logger = logging.getLogger(__name__) -class BlobPipelineStorage(PipelineStorage): +class AzureBlobStorage(Storage): """The Blob-Storage implementation.""" _connection_string: str | None @@ -28,20 +28,36 @@ class BlobPipelineStorage(PipelineStorage): _base_dir: str | None _encoding: str _storage_account_blob_url: str | None + _blob_service_client: BlobServiceClient + _storage_account_name: str | None - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, + base_dir: str | None = None, + connection_string: str | None = None, + storage_account_blob_url: str | None = None, + container_name: str | None = None, + encoding: str = "utf-8", + **kwargs: Any, + ) -> None: """Create a new BlobStorage instance.""" - connection_string = kwargs.get("connection_string") - storage_account_blob_url = kwargs.get("storage_account_blob_url") - base_dir = kwargs.get("base_dir") - container_name = kwargs["container_name"] - if container_name is None: - msg = "No container name provided for blob storage." - raise ValueError(msg) if connection_string is None and storage_account_blob_url is None: - msg = "No storage account blob url provided for blob storage." + msg = "AzureBlobStorage requires either a connection_string or storage_account_blob_url to be specified."
+ logger.error(msg) + raise ValueError(msg) + + if connection_string is not None and storage_account_blob_url is not None: + msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both." + logger.error(msg) raise ValueError(msg) + if container_name is None: + msg = "AzureBlobStorage requires a container_name to be specified." + logger.error(msg) + raise ValueError(msg) + + _validate_blob_container_name(container_name) + logger.info( "Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir ) @@ -49,16 +65,12 @@ def __init__(self, **kwargs: Any) -> None: self._blob_service_client = BlobServiceClient.from_connection_string( connection_string ) - else: - if storage_account_blob_url is None: - msg = "Either connection_string or storage_account_blob_url must be provided." - raise ValueError(msg) - + elif storage_account_blob_url: self._blob_service_client = BlobServiceClient( account_url=storage_account_blob_url, credential=DefaultAzureCredential(), ) - self._encoding = kwargs.get("encoding", "utf-8") + self._encoding = encoding self._container_name = container_name self._connection_string = connection_string self._base_dir = base_dir @@ -208,12 +220,12 @@ async def delete(self, key: str) -> None: async def clear(self) -> None: """Clear the cache.""" - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" if name is None: return self path = str(Path(self._base_dir) / name) if self._base_dir else name - return BlobPipelineStorage( + return AzureBlobStorage( connection_string=self._connection_string, container_name=self._container_name, encoding=self._encoding, @@ -245,7 +257,7 @@ async def get_creation_date(self, key: str) -> str: return "" -def validate_blob_container_name(container_name: str): +def _validate_blob_container_name(container_name: str) -> None: """ Check if the provided blob container name is valid based on Azure rules. @@ -267,32 +279,25 @@ def validate_blob_container_name(container_name: str): """ # Check the length of the name if len(container_name) < 3 or len(container_name) > 63: - return ValueError( - f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long." - ) + msg = f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long." + raise ValueError(msg) # Check if the name starts with a letter or number if not container_name[0].isalnum(): - return ValueError( - f"Container name must start with a letter or number. Starting character was {container_name[0]}." - ) + msg = f"Container name must start with a letter or number. Starting character was {container_name[0]}." + raise ValueError(msg) # Check for valid characters (letters, numbers, hyphen) and lowercase letters if not re.match(r"^[a-z0-9-]+$", container_name): - return ValueError( - f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}." - ) + msg = f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}." + raise ValueError(msg) # Check for consecutive hyphens if "--" in container_name: - return ValueError( - f"Container name cannot contain consecutive hyphens. Name provided was {container_name}." - ) + msg = f"Container name cannot contain consecutive hyphens. Name provided was {container_name}." 
+ raise ValueError(msg) # Check for hyphens at the end of the name if container_name[-1] == "-": - return ValueError( - f"Container name cannot end with a hyphen. Name provided was {container_name}." - ) - - return True + msg = f"Container name cannot end with a hyphen. Name provided was {container_name}." + raise ValueError(msg) diff --git a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py similarity index 91% rename from packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index a12da0ee5f..4e4e034eb7 100644 --- a/packages/graphrag/graphrag/storage/cosmosdb_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -16,17 +16,17 @@ from azure.cosmos.exceptions import CosmosResourceNotFoundError from azure.cosmos.partition_key import PartitionKey from azure.identity import DefaultAzureCredential - from graphrag.logger.progress import Progress -from graphrag.storage.pipeline_storage import ( - PipelineStorage, + +from graphrag_storage.storage import ( + Storage, get_timestamp_formatted_with_local_tz, ) logger = logging.getLogger(__name__) -class CosmosDBPipelineStorage(PipelineStorage): +class AzureCosmosStorage(Storage): """The CosmosDB-Storage Implementation.""" _cosmos_client: CosmosClient @@ -39,28 +39,40 @@ class CosmosDBPipelineStorage(PipelineStorage): _encoding: str _no_id_prefixes: list[str] - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, + base_dir: str | None = None, + container_name: str | None = None, + connection_string: str | None = None, + cosmosdb_account_url: str | None = None, + **kwargs: Any, + ) -> None: """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") - cosmosdb_account_url = kwargs.get("cosmosdb_account_url") - connection_string = kwargs.get("connection_string") - database_name = kwargs["base_dir"] - container_name = kwargs["container_name"] - if not database_name: - msg = "No base_dir provided for database name" + database_name = base_dir + if database_name is None: + msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name." + logger.error(msg) raise ValueError(msg) + if connection_string is None and cosmosdb_account_url is None: - msg = "connection_string or cosmosdb_account_url is required." + msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified." + logger.error(msg) + raise ValueError(msg) + + if connection_string is not None and cosmosdb_account_url is not None: + msg = "CosmosDB Storage requires either a connection_string or cosmosdb_account_url to be specified, not both." + logger.error(msg) + raise ValueError(msg) + + if container_name is None: + msg = "CosmosDB Storage requires a container_name to be specified." + logger.error(msg) raise ValueError(msg) if connection_string: self._cosmos_client = CosmosClient.from_connection_string(connection_string) - else: - if cosmosdb_account_url is None: - msg = ( - "Either connection_string or cosmosdb_account_url must be provided." - ) - raise ValueError(msg) + elif cosmosdb_account_url: self._cosmos_client = CosmosClient( url=cosmosdb_account_url, credential=DefaultAzureCredential(), @@ -307,7 +319,7 @@ def keys(self) -> list[str]: msg = "CosmosDB storage does yet not support listing keys." 
raise NotImplementedError(msg) - def child(self, name: str | None) -> PipelineStorage: + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" return self diff --git a/packages/graphrag/graphrag/storage/file_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py similarity index 72% rename from packages/graphrag/graphrag/storage/file_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/file_storage.py index 52402c8bd6..61cb922ec6 100644 --- a/packages/graphrag/graphrag/storage/file_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/file_storage.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""File-based Storage implementation of PipelineStorage.""" +"""File-based Storage implementation of Storage.""" import logging import os @@ -16,26 +16,33 @@ from aiofiles.os import remove from aiofiles.ospath import exists -from graphrag.storage.pipeline_storage import ( - PipelineStorage, +from graphrag_storage.storage import ( + Storage, get_timestamp_formatted_with_local_tz, ) logger = logging.getLogger(__name__) -class FilePipelineStorage(PipelineStorage): +class FileStorage(Storage): """File storage class definition.""" - _base_dir: str + _base_dir: Path _encoding: str - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, base_dir: str | None = "", encoding: str = "utf-8", **kwargs: Any + ) -> None: """Create a file based storage.""" - self._base_dir = kwargs.get("base_dir", "") - self._encoding = kwargs.get("encoding", "utf-8") + if base_dir is None: + msg = "FileStorage requires a base_dir to be specified." + logger.error(msg) + raise ValueError(msg) + + self._base_dir = Path(base_dir).resolve() + self._encoding = encoding logger.info("Creating file storage at [%s]", self._base_dir) - Path(self._base_dir).mkdir(parents=True, exist_ok=True) + self._base_dir.mkdir(parents=True, exist_ok=True) def find( self, @@ -45,7 +52,7 @@ def find( logger.info( "Search [%s] for files matching [%s]", self._base_dir, file_pattern.pattern ) - all_files = list(Path(self._base_dir).rglob("**/*")) + all_files = list(self._base_dir.rglob("**/*")) logger.debug("All files and folders: %s", [file.name for file in all_files]) num_loaded = 0 num_total = len(all_files) @@ -53,7 +60,7 @@ def find( for file in all_files: match = file_pattern.search(f"{file}") if match: - filename = f"{file}".replace(str(Path(self._base_dir)), "", 1) + filename = f"{file}".replace(str(self._base_dir), "", 1) if filename.startswith(os.sep): filename = filename[1:] yield filename @@ -71,7 +78,7 @@ async def get( self, key: str, as_bytes: bool | None = False, encoding: str | None = None ) -> Any: """Get method definition.""" - file_path = join_path(self._base_dir, key) + file_path = _join_path(self._base_dir, key) if await self.has(key): return await self._read_file(file_path, as_bytes, encoding) @@ -101,7 +108,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: write_type = "wb" if is_bytes else "w" encoding = None if is_bytes else encoding or self._encoding async with aiofiles.open( - join_path(self._base_dir, key), + _join_path(self._base_dir, key), cast("Any", write_type), encoding=encoding, ) as f: @@ -109,35 +116,35 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: async def has(self, key: str) -> bool: """Has method definition.""" - return await exists(join_path(self._base_dir, key)) + return await 
exists(_join_path(self._base_dir, key)) async def delete(self, key: str) -> None: """Delete method definition.""" if await self.has(key): - await remove(join_path(self._base_dir, key)) + await remove(_join_path(self._base_dir, key)) async def clear(self) -> None: """Clear method definition.""" - for file in Path(self._base_dir).glob("*"): + for file in self._base_dir.glob("*"): if file.is_dir(): shutil.rmtree(file) else: file.unlink() - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" if name is None: return self - child_path = str(Path(self._base_dir) / Path(name)) - return FilePipelineStorage(base_dir=child_path, encoding=self._encoding) + child_path = str(self._base_dir / name) + return FileStorage(base_dir=child_path, encoding=self._encoding) def keys(self) -> list[str]: """Return the keys in the storage.""" - return [item.name for item in Path(self._base_dir).iterdir() if item.is_file()] + return [item.name for item in self._base_dir.iterdir() if item.is_file()] async def get_creation_date(self, key: str) -> str: """Get the creation date of a file.""" - file_path = Path(join_path(self._base_dir, key)) + file_path = _join_path(self._base_dir, key) creation_timestamp = file_path.stat().st_ctime creation_time_utc = datetime.fromtimestamp(creation_timestamp, tz=timezone.utc) @@ -145,6 +152,6 @@ async def get_creation_date(self, key: str) -> str: return get_timestamp_formatted_with_local_tz(creation_time_utc) -def join_path(file_path: str, file_name: str) -> Path: +def _join_path(file_path: Path, file_name: str) -> Path: """Join a path and a file. Independent of the OS.""" - return Path(file_path) / Path(file_name).parent / Path(file_name).name + return (file_path / Path(file_name).parent / Path(file_name).name).resolve() diff --git a/packages/graphrag/graphrag/storage/memory_pipeline_storage.py b/packages/graphrag-storage/graphrag_storage/memory_storage.py similarity index 82% rename from packages/graphrag/graphrag/storage/memory_pipeline_storage.py rename to packages/graphrag-storage/graphrag_storage/memory_storage.py index 3567e3d1e3..7908d98a35 100644 --- a/packages/graphrag/graphrag/storage/memory_pipeline_storage.py +++ b/packages/graphrag-storage/graphrag_storage/memory_storage.py @@ -1,24 +1,24 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -"""A module containing 'InMemoryStorage' model.""" +"""In-memory storage implementation.""" from typing import TYPE_CHECKING, Any -from graphrag.storage.file_pipeline_storage import FilePipelineStorage +from graphrag_storage.file_storage import FileStorage if TYPE_CHECKING: - from graphrag.storage.pipeline_storage import PipelineStorage + from graphrag_storage.storage import Storage -class MemoryPipelineStorage(FilePipelineStorage): +class MemoryStorage(FileStorage): """In memory storage class definition.""" _storage: dict[str, Any] - def __init__(self): + def __init__(self, **kwargs: Any) -> None: """Init method definition.""" - super().__init__() + super().__init__(**kwargs) self._storage = {} async def get( @@ -69,9 +69,9 @@ async def clear(self) -> None: """Clear the storage.""" self._storage.clear() - def child(self, name: str | None) -> "PipelineStorage": + def child(self, name: str | None) -> "Storage": """Create a child storage instance.""" - return MemoryPipelineStorage() + return MemoryStorage() def keys(self) -> list[str]: """Return the keys in the storage.""" diff --git a/packages/graphrag-storage/graphrag_storage/py.typed b/packages/graphrag-storage/graphrag_storage/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/graphrag-storage/graphrag_storage/storage.py b/packages/graphrag-storage/graphrag_storage/storage.py new file mode 100644 index 0000000000..d8016d4ae7 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/storage.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Abstract base class for storage.""" + +import re +from abc import ABC, abstractmethod +from collections.abc import Iterator +from datetime import datetime +from typing import Any + + +class Storage(ABC): + """Provide a storage interface.""" + + @abstractmethod + def __init__(self, **kwargs: Any) -> None: + """Create a storage instance.""" + + @abstractmethod + def find( + self, + file_pattern: re.Pattern[str], + ) -> Iterator[str]: + """Find files in the storage using a file pattern. + + Args + ---- + - file_pattern: re.Pattern[str] + The file pattern to use for finding files. + + Returns + ------- + Iterator[str]: + An iterator over the found file keys. + + """ + + @abstractmethod + async def get( + self, key: str, as_bytes: bool | None = None, encoding: str | None = None + ) -> Any: + """Get the value for the given key. + + Args + ---- + - key: str + The key to get the value for. + - as_bytes: bool | None, optional (default=None) + Whether or not to return the value as bytes. + - encoding: str | None, optional (default=None) + The encoding to use when decoding the value. + + Returns + ------- + Any: + The value for the given key. + """ + + @abstractmethod + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + """Set the value for the given key. + + Args + ---- + - key: str + The key to set the value for. + - value: Any + The value to set. + """ + + @abstractmethod + async def has(self, key: str) -> bool: + """Return True if the given key exists in the storage. + + Args + ---- + - key: str + The key to check for. + + Returns + ------- + bool: + True if the key exists in the storage, False otherwise. + """ + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete the given key from the storage. + + Args + ---- + - key: str + The key to delete. 
+ """ + + @abstractmethod + async def clear(self) -> None: + """Clear the storage.""" + + @abstractmethod + def child(self, name: str | None) -> "Storage": + """Create a child storage instance. + + Args + ---- + - name: str | None + The name of the child storage. + + Returns + ------- + Storage + The child storage instance. + + """ + + @abstractmethod + def keys(self) -> list[str]: + """List all keys in the storage.""" + + @abstractmethod + async def get_creation_date(self, key: str) -> str: + """Get the creation date for the given key. + + Args + ---- + - key: str + The key to get the creation date for. + + Returns + ------- + str: + The creation date for the given key. + """ + + +def get_timestamp_formatted_with_local_tz(timestamp: datetime) -> str: + """Get the formatted timestamp with the local time zone.""" + creation_time_local = timestamp.astimezone() + + return creation_time_local.strftime("%Y-%m-%d %H:%M:%S %z") diff --git a/packages/graphrag-storage/graphrag_storage/storage_config.py b/packages/graphrag-storage/graphrag_storage/storage_config.py new file mode 100644 index 0000000000..0a8cf76893 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/storage_config.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Storage configuration model.""" + +from pydantic import BaseModel, ConfigDict, Field + + +class StorageConfig(BaseModel): + """The default configuration section for storage.""" + + model_config = ConfigDict(extra="allow") + """Allow extra fields to support custom storage implementations.""" + + type: str = Field( + description="The storage type to use.", + default="FileStorage", + ) + + base_dir: str | None = Field( + description="The base directory for the output.", + default=None, + ) + + connection_string: str | None = Field( + description="The storage connection string to use.", + default=None, + ) + + container_name: str | None = Field( + description="The storage container name to use.", + default=None, + ) + storage_account_blob_url: str | None = Field( + description="The storage account blob url to use.", + default=None, + ) + cosmosdb_account_url: str | None = Field( + description="The cosmosdb account url to use.", + default=None, + ) diff --git a/packages/graphrag-storage/graphrag_storage/storage_factory.py b/packages/graphrag-storage/graphrag_storage/storage_factory.py new file mode 100644 index 0000000000..d1ab2f4db0 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/storage_factory.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + + +"""Storage factory implementation.""" + +from collections.abc import Callable + +from graphrag_common.factory import Factory + +from graphrag_storage.azure_blob_storage import AzureBlobStorage +from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage +from graphrag_storage.file_storage import FileStorage +from graphrag_storage.memory_storage import MemoryStorage +from graphrag_storage.storage import Storage +from graphrag_storage.storage_config import StorageConfig + + +class _StorageFactory(Factory[Storage]): + """A factory class for storage implementations. + + Includes a method for users to register a custom storage implementation. + + Configuration arguments are passed to each storage implementation as kwargs + for individual enforcement of required/optional arguments. 
+ """ + + +storage_factory = _StorageFactory() +storage_factory.register(FileStorage.__name__, FileStorage) +storage_factory.register(MemoryStorage.__name__, MemoryStorage) +storage_factory.register(AzureBlobStorage.__name__, AzureBlobStorage) +storage_factory.register(AzureCosmosStorage.__name__, AzureCosmosStorage) + + +def register_storage(storage: str, storage_initializer: Callable[..., Storage]) -> None: + """Register a custom storage implementation. + + Args + ---- + - storage: str + The storage id to register. + - storage_initializer: Callable[..., Storage] + The storage initializer to register. + """ + storage_factory.register(storage, storage_initializer) + + +def create_storage(config: StorageConfig) -> Storage: + """Create a storage implementation based on the given configuration. + + Args + ---- + - config: StorageConfig + The storage configuration to use. + + Returns + ------- + Storage + The created storage implementation. + """ + storage_strategy = config.type + + if storage_strategy not in storage_factory: + msg = f"StorageConfig.type '{storage_strategy}' is not registered in the StorageFactory. Registered types: {', '.join(storage_factory.keys())}." + raise ValueError(msg) + + return storage_factory.create(config.type, config.model_dump()) diff --git a/packages/graphrag-storage/pyproject.toml b/packages/graphrag-storage/pyproject.toml new file mode 100644 index 0000000000..464189f950 --- /dev/null +++ b/packages/graphrag-storage/pyproject.toml @@ -0,0 +1,49 @@ +[project] +name = "graphrag-storage" +version = "2.7.0" +description = "GraphRAG storage package." +authors = [ + {name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"}, + {name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"}, + {name = "Chris Trevino", email = "chtrevin@microsoft.com"}, + {name = "David Tittsworth", email = "datittsw@microsoft.com"}, + {name = "Dayenne de Souza", email = "ddesouza@microsoft.com"}, + {name = "Derek Worthen", email = "deworthe@microsoft.com"}, + {name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"}, + {name = "Ha Trinh", email = "trinhha@microsoft.com"}, + {name = "Jonathan Larson", email = "jolarso@microsoft.com"}, + {name = "Josh Bradley", email = "joshbradley@microsoft.com"}, + {name = "Kate Lytvynets", email = "kalytv@microsoft.com"}, + {name = "Kenny Zhang", email = "zhangken@microsoft.com"}, + {name = "Mónica Carvajal"}, + {name = "Nathan Evans", email = "naevans@microsoft.com"}, + {name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"}, + {name = "Sarah Smith", email = "smithsarah@microsoft.com"}, +] +license = "MIT" +readme = "README.md" +license-files = ["LICENSE"] +requires-python = ">=3.10,<3.13" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "aiofiles>=24.1.0", + "azure-cosmos>=4.9.0", + "azure-identity>=1.19.0", + "azure-storage-blob>=12.24.0", + "graphrag-common==2.7.0", + "pandas>=2.2.3", + "pydantic>=2.10.3", +] + +[project.urls] +Source = "https://github.com/microsoft/graphrag" + +[build-system] +requires = ["hatchling>=1.27.0,<2.0.0"] +build-backend = "hatchling.build" + diff --git a/packages/graphrag/graphrag/cache/factory.py b/packages/graphrag/graphrag/cache/factory.py index 971c22c6d5..ccbf1e200f 100644 --- a/packages/graphrag/graphrag/cache/factory.py +++ b/packages/graphrag/graphrag/cache/factory.py @@ -6,15 +6,15 @@ from __future__ import 
annotations from graphrag_common.factory import Factory +from graphrag_storage.azure_blob_storage import AzureBlobStorage +from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage +from graphrag_storage.file_storage import FileStorage from graphrag.cache.json_pipeline_cache import JsonPipelineCache from graphrag.cache.memory_pipeline_cache import InMemoryCache from graphrag.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.enums import CacheType -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage -from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage -from graphrag.storage.file_pipeline_storage import FilePipelineStorage class CacheFactory(Factory[PipelineCache]): @@ -30,19 +30,19 @@ class CacheFactory(Factory[PipelineCache]): # --- register built-in cache implementations --- def create_file_cache(**kwargs) -> PipelineCache: """Create a file-based cache implementation.""" - storage = FilePipelineStorage(**kwargs) + storage = FileStorage(**kwargs) return JsonPipelineCache(storage) def create_blob_cache(**kwargs) -> PipelineCache: """Create a blob storage-based cache implementation.""" - storage = BlobPipelineStorage(**kwargs) + storage = AzureBlobStorage(**kwargs) return JsonPipelineCache(storage) def create_cosmosdb_cache(**kwargs) -> PipelineCache: """Create a CosmosDB-based cache implementation.""" - storage = CosmosDBPipelineStorage(**kwargs) + storage = AzureCosmosStorage(**kwargs) return JsonPipelineCache(storage) diff --git a/packages/graphrag/graphrag/cache/json_pipeline_cache.py b/packages/graphrag/graphrag/cache/json_pipeline_cache.py index 84cd180c52..22b438936e 100644 --- a/packages/graphrag/graphrag/cache/json_pipeline_cache.py +++ b/packages/graphrag/graphrag/cache/json_pipeline_cache.py @@ -6,17 +6,18 @@ import json from typing import Any +from graphrag_storage import Storage + from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.storage.pipeline_storage import PipelineStorage class JsonPipelineCache(PipelineCache): """File pipeline cache class definition.""" - _storage: PipelineStorage + _storage: Storage _encoding: str - def __init__(self, storage: PipelineStorage, encoding="utf-8"): + def __init__(self, storage: Storage, encoding="utf-8"): """Init method definition.""" self._storage = storage self._encoding = encoding diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py index 93163db19d..6ce049ddcc 100644 --- a/packages/graphrag/graphrag/cli/query.py +++ b/packages/graphrag/graphrag/cli/query.py @@ -8,11 +8,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from graphrag_storage import create_storage + import graphrag.api as api from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks from graphrag.config.load_config import load_config from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.utils.api import create_storage_from_config from graphrag.utils.storage import load_table_from_storage, storage_has_table if TYPE_CHECKING: @@ -376,7 +377,7 @@ def _resolve_output_files( ) -> dict[str, Any]: """Read indexing output files to a dataframe dict.""" dataframe_dict = {} - storage_obj = create_storage_from_config(config.output) + storage_obj = create_storage(config.output) for name in output_list: df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj)) dataframe_dict[name] = df_value diff --git 
a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index 88449a6050..9d988ad928 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -7,6 +7,8 @@ from pathlib import Path from typing import ClassVar +from graphrag_storage.file_storage import FileStorage + from graphrag.config.embeddings import default_embeddings from graphrag.config.enums import ( AsyncType, @@ -17,7 +19,6 @@ ModelType, NounPhraseExtractorType, ReportingType, - StorageType, VectorStoreType, ) from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( @@ -230,7 +231,7 @@ class GlobalSearchDefaults: class StorageDefaults: """Default values for storage.""" - type: ClassVar[StorageType] = StorageType.file + type: str = FileStorage.__name__ base_dir: str | None = None connection_string: None = None container_name: None = None diff --git a/packages/graphrag/graphrag/config/init_content.py b/packages/graphrag/graphrag/config/init_content.py index 1cbccf74df..1cb70ddd16 100644 --- a/packages/graphrag/graphrag/config/init_content.py +++ b/packages/graphrag/graphrag/config/init_content.py @@ -50,7 +50,7 @@ input: storage: - type: {graphrag_config_defaults.input.storage.type.value} # or blob + type: {graphrag_config_defaults.input.storage.type} # or AzureBlobStorage, AzureCosmosStorage base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] @@ -63,7 +63,7 @@ ## connection_string and container_name must be provided output: - type: {graphrag_config_defaults.output.type.value} # [file, blob, cosmosdb] + type: {graphrag_config_defaults.output.type} # or AzureBlobStorage, AzureCosmosStorage base_dir: "{graphrag_config_defaults.output.base_dir}" cache: diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index 15d02eaf3a..e2bdf81f72 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -6,6 +6,8 @@ from pathlib import Path from devtools import pformat +from graphrag_storage import StorageConfig +from graphrag_storage.file_storage import FileStorage from pydantic import BaseModel, Field, model_validator import graphrag.config.defaults as defs @@ -29,7 +31,6 @@ from graphrag.config.models.prune_graph_config import PruneGraphConfig from graphrag.config.models.reporting_config import ReportingConfig from graphrag.config.models.snapshots_config import SnapshotsConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) @@ -135,7 +136,7 @@ def _validate_input_pattern(self) -> None: def _validate_input_base_dir(self) -> None: """Validate the input base directory.""" - if self.input.storage.type == defs.StorageType.file: + if self.input.storage.type == FileStorage.__name__: if not self.input.storage.base_dir: msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration." 
raise ValueError(msg) @@ -159,7 +160,7 @@ def _validate_input_base_dir(self) -> None: def _validate_output_base_dir(self) -> None: """Validate the output base directory.""" - if self.output.type == defs.StorageType.file: + if self.output.type == FileStorage.__name__: if not self.output.base_dir: msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." raise ValueError(msg) @@ -175,7 +176,7 @@ def _validate_output_base_dir(self) -> None: def _validate_update_index_output_base_dir(self) -> None: """Validate the update index output base directory.""" - if self.update_index_output.type == defs.StorageType.file: + if self.update_index_output.type == FileStorage.__name__: if not self.update_index_output.base_dir: msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration." raise ValueError(msg) diff --git a/packages/graphrag/graphrag/config/models/input_config.py b/packages/graphrag/graphrag/config/models/input_config.py index bc34d9402d..c3c30f6302 100644 --- a/packages/graphrag/graphrag/config/models/input_config.py +++ b/packages/graphrag/graphrag/config/models/input_config.py @@ -3,12 +3,12 @@ """Parameterization settings for the default configuration.""" +from graphrag_storage import StorageConfig from pydantic import BaseModel, Field import graphrag.config.defaults as defs from graphrag.config.defaults import graphrag_config_defaults from graphrag.config.enums import InputFileType -from graphrag.config.models.storage_config import StorageConfig class InputConfig(BaseModel): diff --git a/packages/graphrag/graphrag/config/models/storage_config.py b/packages/graphrag/graphrag/config/models/storage_config.py deleted file mode 100644 index 7491454c0a..0000000000 --- a/packages/graphrag/graphrag/config/models/storage_config.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Parameterization settings for the default configuration.""" - -from pathlib import Path - -from pydantic import BaseModel, Field, field_validator - -from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.enums import StorageType - - -class StorageConfig(BaseModel): - """The default configuration section for storage.""" - - type: StorageType | str = Field( - description="The storage type to use.", - default=graphrag_config_defaults.storage.type, - ) - base_dir: str | None = Field( - description="The base directory for the output.", - default=graphrag_config_defaults.storage.base_dir, - ) - - # Validate the base dir for multiple OS (use Path) - # if not using a cloud storage type. 
- @field_validator("base_dir", mode="before") - @classmethod - def validate_base_dir(cls, value, info): - """Ensure that base_dir is a valid filesystem path when using local storage.""" - # info.data contains other field values, including 'type' - if info.data.get("type") != StorageType.file: - return value - return str(Path(value)) - - connection_string: str | None = Field( - description="The storage connection string to use.", - default=graphrag_config_defaults.storage.connection_string, - ) - container_name: str | None = Field( - description="The storage container name to use.", - default=graphrag_config_defaults.storage.container_name, - ) - storage_account_blob_url: str | None = Field( - description="The storage account blob url to use.", - default=graphrag_config_defaults.storage.storage_account_blob_url, - ) - cosmosdb_account_url: str | None = Field( - description="The cosmosdb account url to use.", - default=graphrag_config_defaults.storage.cosmosdb_account_url, - ) diff --git a/packages/graphrag/graphrag/index/input/input_reader.py b/packages/graphrag/graphrag/index/input/input_reader.py index ed0add9f97..98a713e509 100644 --- a/packages/graphrag/graphrag/index/input/input_reader.py +++ b/packages/graphrag/graphrag/index/input/input_reader.py @@ -13,8 +13,9 @@ import pandas as pd if TYPE_CHECKING: + from graphrag_storage import Storage + from graphrag.config.models.input_config import InputConfig - from graphrag.storage.pipeline_storage import PipelineStorage logger = logging.getLogger(__name__) @@ -22,7 +23,7 @@ class InputReader(metaclass=ABCMeta): """Provide a cache interface for the pipeline.""" - def __init__(self, storage: PipelineStorage, config: InputConfig, **kwargs): + def __init__(self, storage: Storage, config: InputConfig, **kwargs): self._storage = storage self._config = config diff --git a/packages/graphrag/graphrag/index/operations/snapshot_graphml.py b/packages/graphrag/graphrag/index/operations/snapshot_graphml.py index c1eb9b0688..9124038401 100644 --- a/packages/graphrag/graphrag/index/operations/snapshot_graphml.py +++ b/packages/graphrag/graphrag/index/operations/snapshot_graphml.py @@ -4,14 +4,13 @@ """A module containing snapshot_graphml method definition.""" import networkx as nx - -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage async def snapshot_graphml( input: str | nx.Graph, name: str, - storage: PipelineStorage, + storage: Storage, ) -> None: """Take a entire snapshot of a graph to standard graphml format.""" graphml = input if isinstance(input, str) else "\n".join(nx.generate_graphml(input)) diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index a373e43b76..4eef24f14c 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -12,6 +12,7 @@ from typing import Any import pandas as pd +from graphrag_storage import Storage, create_storage from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -19,8 +20,7 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.storage.pipeline_storage import PipelineStorage -from graphrag.utils.api import create_cache_from_config, create_storage_from_config +from graphrag.utils.api import 
create_cache_from_config from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -35,8 +35,8 @@ async def run_pipeline( input_documents: pd.DataFrame | None = None, ) -> AsyncIterable[PipelineRunResult]: """Run all workflows using a simplified pipeline.""" - input_storage = create_storage_from_config(config.input.storage) - output_storage = create_storage_from_config(config.output) + input_storage = create_storage(config.input.storage) + output_storage = create_storage(config.output) cache = create_cache_from_config(config.cache) # load existing state in case any workflows are stateful @@ -49,7 +49,7 @@ async def run_pipeline( if is_update_run: logger.info("Running incremental indexing.") - update_storage = create_storage_from_config(config.update_index_output) + update_storage = create_storage(config.update_index_output) # we use this to store the new subset index, and will merge its content with the previous index update_timestamp = time.strftime("%Y%m%d-%H%M%S") timestamped_storage = update_storage.child(update_timestamp) @@ -156,8 +156,8 @@ async def _dump_json(context: PipelineRunContext) -> None: async def _copy_previous_output( - storage: PipelineStorage, - copy_storage: PipelineStorage, + storage: Storage, + copy_storage: Storage, ): for file in storage.find(re.compile(r"\.parquet$")): base_name = file[0].replace(".parquet", "") diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index 52b1f0bd31..03e789746a 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ b/packages/graphrag/graphrag/index/run/utils.py @@ -3,6 +3,9 @@ """Utility functions for the GraphRAG run module.""" +from graphrag_storage import Storage, create_storage +from graphrag_storage.memory_storage import MemoryStorage + from graphrag.cache.memory_pipeline_cache import InMemoryCache from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks @@ -12,15 +15,12 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.storage.pipeline_storage import PipelineStorage -from graphrag.utils.api import create_storage_from_config def create_run_context( - input_storage: PipelineStorage | None = None, - output_storage: PipelineStorage | None = None, - previous_storage: PipelineStorage | None = None, + input_storage: Storage | None = None, + output_storage: Storage | None = None, + previous_storage: Storage | None = None, cache: PipelineCache | None = None, callbacks: WorkflowCallbacks | None = None, stats: PipelineRunStats | None = None, @@ -28,9 +28,9 @@ def create_run_context( ) -> PipelineRunContext: """Create the run context for the pipeline.""" return PipelineRunContext( - input_storage=input_storage or MemoryPipelineStorage(), - output_storage=output_storage or MemoryPipelineStorage(), - previous_storage=previous_storage or MemoryPipelineStorage(), + input_storage=input_storage or MemoryStorage(), + output_storage=output_storage or MemoryStorage(), + previous_storage=previous_storage or MemoryStorage(), cache=cache or InMemoryCache(), callbacks=callbacks or NoopWorkflowCallbacks(), stats=stats or PipelineRunStats(), @@ -50,10 +50,10 @@ def create_callback_chain( def get_update_storages( config: GraphRagConfig, 
timestamp: str -) -> tuple[PipelineStorage, PipelineStorage, PipelineStorage]: +) -> tuple[Storage, Storage, Storage]: """Get storage objects for the update index run.""" - output_storage = create_storage_from_config(config.output) - update_storage = create_storage_from_config(config.update_index_output) + output_storage = create_storage(config.output) + update_storage = create_storage(config.update_index_output) timestamped_storage = update_storage.child(timestamp) delta_storage = timestamped_storage.child("delta") previous_storage = timestamped_storage.child("previous") diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py index ef2e1f7ea5..465ec7214c 100644 --- a/packages/graphrag/graphrag/index/typing/context.py +++ b/packages/graphrag/graphrag/index/typing/context.py @@ -10,7 +10,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage @dataclass @@ -18,11 +18,11 @@ class PipelineRunContext: """Provides the context for the current pipeline run.""" stats: PipelineRunStats - input_storage: PipelineStorage + input_storage: Storage "Storage for input documents." - output_storage: PipelineStorage + output_storage: Storage "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." - previous_storage: PipelineStorage + previous_storage: Storage "Storage for previous pipeline run when running in update mode." cache: PipelineCache "Cache instance for reading previous LLM responses." diff --git a/packages/graphrag/graphrag/index/update/incremental_index.py b/packages/graphrag/graphrag/index/update/incremental_index.py index ac56e30df4..81f917e187 100644 --- a/packages/graphrag/graphrag/index/update/incremental_index.py +++ b/packages/graphrag/graphrag/index/update/incremental_index.py @@ -7,8 +7,8 @@ import numpy as np import pandas as pd +from graphrag_storage import Storage -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import ( load_table_from_storage, write_table_to_storage, @@ -31,16 +31,14 @@ class InputDelta: deleted_inputs: pd.DataFrame -async def get_delta_docs( - input_dataset: pd.DataFrame, storage: PipelineStorage -) -> InputDelta: +async def get_delta_docs(input_dataset: pd.DataFrame, storage: Storage) -> InputDelta: """Get the delta between the input dataset and the final documents. Parameters ---------- input_dataset : pd.DataFrame The input dataset. - storage : PipelineStorage + storage : Storage The Pipeline storage. 
Returns @@ -65,9 +63,9 @@ async def get_delta_docs( async def concat_dataframes( name: str, - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, ) -> pd.DataFrame: """Concatenate dataframes.""" old_df = await load_table_from_storage(name, previous_storage) diff --git a/packages/graphrag/graphrag/index/workflows/load_update_documents.py b/packages/graphrag/graphrag/index/workflows/load_update_documents.py index 7755091017..e68471eae7 100644 --- a/packages/graphrag/graphrag/index/workflows/load_update_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_update_documents.py @@ -6,6 +6,7 @@ import logging import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.input.factory import InputReaderFactory @@ -13,7 +14,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import get_delta_docs -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ async def run_workflow( async def load_update_documents( input_reader: InputReader, - previous_storage: PipelineStorage, + previous_storage: Storage, ) -> pd.DataFrame: """Load and parse update-only input documents into a standard format.""" input_documents = await input_reader.read_files() diff --git a/packages/graphrag/graphrag/index/workflows/update_communities.py b/packages/graphrag/graphrag/index/workflows/update_communities.py index b7e3e6a343..da4fdef147 100644 --- a/packages/graphrag/graphrag/index/workflows/update_communities.py +++ b/packages/graphrag/graphrag/index/workflows/update_communities.py @@ -5,12 +5,13 @@ import logging +from graphrag_storage import Storage + from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_communities -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -37,9 +38,9 @@ async def run_workflow( async def _update_communities( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, ) -> dict: """Update the communities output.""" old_communities = await load_table_from_storage("communities", previous_storage) diff --git a/packages/graphrag/graphrag/index/workflows/update_community_reports.py b/packages/graphrag/graphrag/index/workflows/update_community_reports.py index 42576aca27..790f9fc296 100644 --- a/packages/graphrag/graphrag/index/workflows/update_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/update_community_reports.py @@ -6,13 +6,13 @@ import logging import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from 
graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_community_reports -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -43,9 +43,9 @@ async def run_workflow( async def _update_community_reports( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, community_id_mapping: dict, ) -> pd.DataFrame: """Update the community reports output.""" diff --git a/packages/graphrag/graphrag/index/workflows/update_covariates.py b/packages/graphrag/graphrag/index/workflows/update_covariates.py index f0bf29a6ae..09f8b4053d 100644 --- a/packages/graphrag/graphrag/index/workflows/update_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/update_covariates.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import ( load_table_from_storage, storage_has_table, @@ -43,9 +43,9 @@ async def run_workflow( async def _update_covariates( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, ) -> None: """Update the covariates output.""" old_covariates = await load_table_from_storage("covariates", previous_storage) diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index 1245303559..2ddd171457 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ -6,6 +6,7 @@ import logging import pandas as pd +from graphrag_storage import Storage from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @@ -17,7 +18,6 @@ from graphrag.index.update.relationships import _update_and_merge_relationships from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships from graphrag.language_model.manager import ModelManager -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -55,9 +55,9 @@ async def run_workflow( async def _update_entities_and_relationships( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, config: GraphRagConfig, cache: PipelineCache, callbacks: WorkflowCallbacks, diff --git a/packages/graphrag/graphrag/index/workflows/update_text_units.py b/packages/graphrag/graphrag/index/workflows/update_text_units.py index 392533f16b..c97f89ce7a 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_units.py +++ 
b/packages/graphrag/graphrag/index/workflows/update_text_units.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd +from graphrag_storage import Storage from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -40,9 +40,9 @@ async def run_workflow( async def _update_text_units( - previous_storage: PipelineStorage, - delta_storage: PipelineStorage, - output_storage: PipelineStorage, + previous_storage: Storage, + delta_storage: Storage, + output_storage: Storage, entity_id_mapping: dict, ) -> pd.DataFrame: """Update the text units output.""" diff --git a/packages/graphrag/graphrag/prompt_tune/loader/input.py b/packages/graphrag/graphrag/prompt_tune/loader/input.py index 5e9fccb440..c810b0ce41 100644 --- a/packages/graphrag/graphrag/prompt_tune/loader/input.py +++ b/packages/graphrag/graphrag/prompt_tune/loader/input.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from graphrag_storage import create_storage from graphrag.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks @@ -25,7 +26,6 @@ ) from graphrag.prompt_tune.types import DocSelectionType from graphrag.tokenizer.get_tokenizer import get_tokenizer -from graphrag.utils.api import create_storage_from_config def _sample_chunks_from_embeddings( @@ -63,7 +63,7 @@ async def load_docs_in_chunks( cache=NoopPipelineCache(), ) tokenizer = get_tokenizer(embeddings_llm_settings) - input_storage = create_storage_from_config(config.input.storage) + input_storage = create_storage(config.input.storage) input_reader = InputReaderFactory().create( config.input.file_type, {"storage": input_storage, "config": config.input}, diff --git a/packages/graphrag/graphrag/storage/__init__.py b/packages/graphrag/graphrag/storage/__init__.py index b21f077cb1..94146bcd02 100644 --- a/packages/graphrag/graphrag/storage/__init__.py +++ b/packages/graphrag/graphrag/storage/__init__.py @@ -2,3 +2,7 @@ # Licensed under the MIT License """The storage package root.""" + +from graphrag_storage import create_storage, register_storage + +__all__ = ["create_storage", "register_storage"] diff --git a/packages/graphrag/graphrag/storage/factory.py b/packages/graphrag/graphrag/storage/factory.py deleted file mode 100644 index 738a46420b..0000000000 --- a/packages/graphrag/graphrag/storage/factory.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Factory functions for creating storage.""" - -from __future__ import annotations - -from graphrag_common.factory import Factory - -from graphrag.config.enums import StorageType -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage -from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage -from graphrag.storage.file_pipeline_storage import FilePipelineStorage -from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.storage.pipeline_storage import PipelineStorage - - -class StorageFactory(Factory[PipelineStorage]): - """A factory class for storage implementations. 
diff --git a/packages/graphrag/graphrag/storage/factory.py b/packages/graphrag/graphrag/storage/factory.py
deleted file mode 100644
index 738a46420b..0000000000
--- a/packages/graphrag/graphrag/storage/factory.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Factory functions for creating storage."""
-
-from __future__ import annotations
-
-from graphrag_common.factory import Factory
-
-from graphrag.config.enums import StorageType
-from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
-from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
-from graphrag.storage.file_pipeline_storage import FilePipelineStorage
-from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-class StorageFactory(Factory[PipelineStorage]):
-    """A factory class for storage implementations.
-
-    Includes a method for users to register a custom storage implementation.
-
-    Configuration arguments are passed to each storage implementation as kwargs
-    for individual enforcement of required/optional arguments.
-    """
-
-
-# --- register built-in storage implementations ---
-storage_factory = StorageFactory()
-storage_factory.register(StorageType.blob.value, BlobPipelineStorage)
-storage_factory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage)
-storage_factory.register(StorageType.file.value, FilePipelineStorage)
-storage_factory.register(StorageType.memory.value, MemoryPipelineStorage)
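With `factory.py` deleted, code that registered custom implementations on `StorageFactory` moves to the module-level helpers. A hedged before/after sketch; `MyStorage` is a hypothetical subclass used only for illustration:

```python
from graphrag_storage import StorageConfig, create_storage, register_storage
from graphrag_storage.file_storage import FileStorage


class MyStorage(FileStorage):  # hypothetical custom implementation
    """Behaves like FileStorage; stands in for a real custom backend."""


# Before: StorageFactory().register("my_storage", MyStorage)
register_storage("my_storage", MyStorage)

# Before: StorageFactory().create("my_storage", {"base_dir": "out"})
storage = create_storage(StorageConfig(type="my_storage", base_dir="out"))
```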
- """ - - -def get_timestamp_formatted_with_local_tz(timestamp: datetime) -> str: - """Get the formatted timestamp with the local time zone.""" - creation_time_local = timestamp.astimezone() - - return creation_time_local.strftime("%Y-%m-%d %H:%M:%S %z") diff --git a/packages/graphrag/graphrag/utils/api.py b/packages/graphrag/graphrag/utils/api.py index f264c1a9ed..2d83d692ff 100644 --- a/packages/graphrag/graphrag/utils/api.py +++ b/packages/graphrag/graphrag/utils/api.py @@ -10,10 +10,7 @@ from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.embeddings import create_index_name from graphrag.config.models.cache_config import CacheConfig -from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.storage.factory import StorageFactory -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.vector_stores.base import ( BaseVectorStore, ) @@ -101,15 +98,6 @@ def load_search_prompt(prompt_config: str | None) -> str | None: return None -def create_storage_from_config(output: StorageConfig) -> PipelineStorage: - """Create a storage object from the config.""" - storage_config = output.model_dump() - return StorageFactory().create( - storage_config["type"], - storage_config, - ) - - def create_cache_from_config(cache: CacheConfig) -> PipelineCache: """Create a cache object from the config.""" cache_config = cache.model_dump() diff --git a/packages/graphrag/graphrag/utils/storage.py b/packages/graphrag/graphrag/utils/storage.py index 8534330a15..852d066091 100644 --- a/packages/graphrag/graphrag/utils/storage.py +++ b/packages/graphrag/graphrag/utils/storage.py @@ -7,13 +7,12 @@ from io import BytesIO import pandas as pd - -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag_storage import Storage logger = logging.getLogger(__name__) -async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.DataFrame: +async def load_table_from_storage(name: str, storage: Storage) -> pd.DataFrame: """Load a parquet from the storage instance.""" filename = f"{name}.parquet" if not await storage.has(filename): @@ -28,17 +27,17 @@ async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.Dat async def write_table_to_storage( - table: pd.DataFrame, name: str, storage: PipelineStorage + table: pd.DataFrame, name: str, storage: Storage ) -> None: """Write a table to storage.""" await storage.set(f"{name}.parquet", table.to_parquet()) -async def delete_table_from_storage(name: str, storage: PipelineStorage) -> None: +async def delete_table_from_storage(name: str, storage: Storage) -> None: """Delete a table to storage.""" await storage.delete(f"{name}.parquet") -async def storage_has_table(name: str, storage: PipelineStorage) -> bool: +async def storage_has_table(name: str, storage: Storage) -> bool: """Check if a table exists in storage.""" return await storage.has(f"{name}.parquet") diff --git a/packages/graphrag/pyproject.toml b/packages/graphrag/pyproject.toml index a7b97f1f0f..7b7eec259d 100644 --- a/packages/graphrag/pyproject.toml +++ b/packages/graphrag/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "devtools>=0.12.2", "environs>=11.0.0", "graphrag-common==2.7.0", + "graphrag-storage==2.7.0", "graspologic-native>=1.2.5", "json-repair>=0.30.3", "lancedb>=0.17.0", diff --git a/pyproject.toml b/pyproject.toml index 3cb1ae1d67..1979df5cb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ 
diff --git a/packages/graphrag/pyproject.toml b/packages/graphrag/pyproject.toml
index a7b97f1f0f..7b7eec259d 100644
--- a/packages/graphrag/pyproject.toml
+++ b/packages/graphrag/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
     "devtools>=0.12.2",
    "environs>=11.0.0",
     "graphrag-common==2.7.0",
+    "graphrag-storage==2.7.0",
     "graspologic-native>=1.2.5",
     "json-repair>=0.30.3",
     "lancedb>=0.17.0",
diff --git a/pyproject.toml b/pyproject.toml
index 3cb1ae1d67..1979df5cb7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,7 @@
 members = ["packages/*"]

 [tool.uv.sources]
 graphrag-common = { workspace = true }
+graphrag-storage = { workspace = true }

 # Keep poethepoet for task management to minimize changes
 [tool.poe.tasks]
@@ -69,6 +70,7 @@ _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
 # Add more update toml tasks as packages are added
 _semversioner_update_graphrag_toml_version = "update-toml update --file packages/graphrag/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
 _semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
+_semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
 _semversioner_update_workspace_dependency_versions = "python -m scripts.update_workspace_dependency_versions"
 semversioner_add = "semversioner add-change"
 coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
@@ -103,6 +105,7 @@ sequence = [
     # Add more update toml tasks as packages are added
     '_semversioner_update_graphrag_toml_version',
     '_semversioner_update_graphrag_common_toml_version',
+    '_semversioner_update_graphrag_storage_toml_version',
     '_semversioner_update_workspace_dependency_versions',
     '_sync',
 ]
@@ -220,6 +223,7 @@ convention = "numpy"
 include = [
     "packages/graphrag/graphrag",
     "packages/graphrag-common/graphrag_common",
+    "packages/graphrag-storage/graphrag_storage",
     "tests"
 ]
 exclude = ["**/node_modules", "**/__pycache__"]
diff --git a/tests/integration/storage/test_blob_pipeline_storage.py b/tests/integration/storage/test_blob_storage.py
similarity index 94%
rename from tests/integration/storage/test_blob_pipeline_storage.py
rename to tests/integration/storage/test_blob_storage.py
index 818b588bd6..44216700da 100644
--- a/tests/integration/storage/test_blob_pipeline_storage.py
+++ b/tests/integration/storage/test_blob_storage.py
@@ -5,14 +5,14 @@
 import re
 from datetime import datetime

-from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
+from graphrag_storage.azure_blob_storage import AzureBlobStorage

 # cspell:disable-next-line well-known-key
 WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"


 async def test_find():
-    storage = BlobPipelineStorage(
+    storage = AzureBlobStorage(
         connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
         container_name="testfind",
     )
@@ -42,7 +42,7 @@ async def test_find():


 async def test_dotprefix():
-    storage = BlobPipelineStorage(
+    storage = AzureBlobStorage(
         connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
         container_name="testfind",
         path_prefix=".",
@@ -56,7 +56,7 @@ async def test_dotprefix():


 async def test_get_creation_date():
-    storage = BlobPipelineStorage(
+    storage = AzureBlobStorage(
         connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
         container_name="testfind",
         path_prefix=".",
@@ -74,7 +74,7 @@ async def test_get_creation_date():


 async def test_child():
-    parent = BlobPipelineStorage(
+    parent = AzureBlobStorage(
         connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
         container_name="testchild",
     )
diff --git a/tests/integration/storage/test_cosmosdb_storage.py b/tests/integration/storage/test_cosmosdb_storage.py
index 3d6128872f..9f85d93e0f 100644
--- a/tests/integration/storage/test_cosmosdb_storage.py
+++ b/tests/integration/storage/test_cosmosdb_storage.py
@@ -8,7 +8,7 @@
 from datetime import datetime

 import pytest
-from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
+from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage

 # cspell:disable-next-line well-known-key
 WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
@@ -21,7 +21,7 @@


 async def test_find():
-    storage = CosmosDBPipelineStorage(
+    storage = AzureCosmosStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
         base_dir="testfind",
         container_name="testfindcontainer",
@@ -64,20 +64,20 @@ async def test_find():


 async def test_child():
-    storage = CosmosDBPipelineStorage(
+    storage = AzureCosmosStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
         base_dir="testchild",
         container_name="testchildcontainer",
     )
     try:
         child_storage = storage.child("child")
-        assert type(child_storage) is CosmosDBPipelineStorage
+        assert type(child_storage) is AzureCosmosStorage
     finally:
         await storage.clear()


 async def test_clear():
-    storage = CosmosDBPipelineStorage(
+    storage = AzureCosmosStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
         base_dir="testclear",
         container_name="testclearcontainer",
@@ -107,7 +107,7 @@ async def test_clear():


 async def test_get_creation_date():
-    storage = CosmosDBPipelineStorage(
+    storage = AzureCosmosStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
         base_dir="testclear",
         container_name="testclearcontainer",
base_dir="testbasedir", + container_name="testcontainer", + ) + storage = create_storage(config) + assert isinstance(storage, AzureBlobStorage) @pytest.mark.skipif( @@ -39,63 +37,61 @@ def test_create_blob_storage(): reason="cosmosdb emulator is only available on windows runners at this time", ) def test_create_cosmosdb_storage(): - kwargs = { - "type": "cosmosdb", - "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, - "base_dir": "testdatabase", - "container_name": "testcontainer", - } - storage = StorageFactory().create(StorageType.cosmosdb.value, kwargs) - assert isinstance(storage, CosmosDBPipelineStorage) + config = StorageConfig( + type=AzureCosmosStorage.__name__, + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + base_dir="testdatabase", + container_name="testcontainer", + ) + storage = create_storage(config) + assert isinstance(storage, AzureCosmosStorage) def test_create_file(): - kwargs = {"type": "file", "base_dir": "/tmp/teststorage"} - storage = StorageFactory().create(StorageType.file.value, kwargs) - assert isinstance(storage, FilePipelineStorage) + config = StorageConfig( + type=FileStorage.__name__, + base_dir="/tmp/teststorage", + ) + storage = create_storage(config) + assert isinstance(storage, FileStorage) def test_create_memory_storage(): - kwargs = {} # MemoryPipelineStorage doesn't accept any constructor parameters - storage = StorageFactory().create(StorageType.memory.value, kwargs) - assert isinstance(storage, MemoryPipelineStorage) + config = StorageConfig( + base_dir="", + type=MemoryStorage.__name__, + ) + storage = create_storage(config) + assert isinstance(storage, MemoryStorage) def test_register_and_create_custom_storage(): """Test registering and creating a custom storage type.""" from unittest.mock import MagicMock - # Create a mock that satisfies the PipelineStorage interface - custom_storage_class = MagicMock(spec=PipelineStorage) + # Create a mock that satisfies the Storage interface + custom_storage_class = MagicMock(spec=Storage) # Make the mock return a mock instance when instantiated instance = MagicMock() - # We can set attributes on the mock instance, even if they don't exist on PipelineStorage + # We can set attributes on the mock instance, even if they don't exist on Storage instance.initialized = True custom_storage_class.return_value = instance - StorageFactory().register("custom", lambda **kwargs: custom_storage_class(**kwargs)) - storage = StorageFactory().create("custom", {}) + register_storage("custom", lambda **kwargs: custom_storage_class(**kwargs)) + storage = create_storage(StorageConfig(type="custom")) assert custom_storage_class.called assert storage is instance # Access the attribute we set on our mock assert storage.initialized is True # type: ignore # Attribute only exists on our mock - # Check if it's in the list of registered storage types - assert "custom" in StorageFactory() - - -def test_get_storage_types(): - # Check that built-in types are registered - assert StorageType.file.value in StorageFactory() - assert StorageType.memory.value in StorageFactory() - assert StorageType.blob.value in StorageFactory() - assert StorageType.cosmosdb.value in StorageFactory() - def test_create_unknown_storage(): - with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."): - StorageFactory().create("unknown") + with pytest.raises( + ValueError, + match="StorageConfig\\.type 'unknown' is not registered in the StorageFactory\\.", + ): + create_storage(StorageConfig(type="unknown")) def 
diff --git a/tests/integration/storage/test_file_pipeline_storage.py b/tests/integration/storage/test_file_storage.py
similarity index 86%
rename from tests/integration/storage/test_file_pipeline_storage.py
rename to tests/integration/storage/test_file_storage.py
index 95e329b6bf..b6edc77b03 100644
--- a/tests/integration/storage/test_file_pipeline_storage.py
+++ b/tests/integration/storage/test_file_storage.py
@@ -7,17 +7,15 @@
 from datetime import datetime
 from pathlib import Path

-from graphrag.storage.file_pipeline_storage import (
-    FilePipelineStorage,
+from graphrag_storage.file_storage import (
+    FileStorage,
 )

 __dirname__ = os.path.dirname(__file__)


 async def test_find():
-    storage = FilePipelineStorage(
-        base_dir="tests/fixtures/text/input",
-    )
+    storage = FileStorage(base_dir="tests/fixtures/text/input")
     items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
     assert items == [str(Path("dulce.txt"))]
     output = await storage.get("dulce.txt")
@@ -32,7 +30,7 @@ async def test_find():


 async def test_get_creation_date():
-    storage = FilePipelineStorage(
+    storage = FileStorage(
         base_dir="tests/fixtures/text/input",
     )
@@ -45,7 +43,7 @@ async def test_get_creation_date():


 async def test_child():
-    storage = FilePipelineStorage()
+    storage = FileStorage()
     storage = storage.child("tests/fixtures/text/input")
     items = list(storage.find(re.compile(r".*\.txt$")))
     assert items == [str(Path("dulce.txt"))]
diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py
index 9821bed551..53205c7c09 100644
--- a/tests/smoke/test_fixtures.py
+++ b/tests/smoke/test_fixtures.py
@@ -17,7 +17,7 @@
 from graphrag.query.context_builder.community_context import (
     NO_COMMUNITY_RECORDS_WARNING,
 )
-from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
+from graphrag_storage.azure_blob_storage import AzureBlobStorage

 logger = logging.getLogger(__name__)
@@ -94,7 +94,7 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non
     input_base_dir = azure.get("input_base_dir")

     root = Path(input_path)
-    input_storage = BlobPipelineStorage(
+    input_storage = AzureBlobStorage(
         connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING,
         container_name=input_container,
     )
diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py
index 001518f62a..0f2235a58b 100644
--- a/tests/unit/config/utils.py
+++ b/tests/unit/config/utils.py
@@ -25,11 +25,11 @@
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
 from graphrag.config.models.reporting_config import ReportingConfig
 from graphrag.config.models.snapshots_config import SnapshotsConfig
-from graphrag.config.models.storage_config import StorageConfig
 from graphrag.config.models.summarize_descriptions_config import (
     SummarizeDescriptionsConfig,
 )
 from graphrag.config.models.vector_store_config import VectorStoreConfig
+from graphrag_storage import StorageConfig
 from pydantic import BaseModel

 FAKE_API_KEY = "NOT_AN_API_KEY"
diff --git a/tests/unit/indexing/cache/test_file_pipeline_cache.py b/tests/unit/indexing/cache/test_file_pipeline_cache.py
index c392b4e08e..c672d3718f 100644
--- a/tests/unit/indexing/cache/test_file_pipeline_cache.py
+++ b/tests/unit/indexing/cache/test_file_pipeline_cache.py
@@ -5,15 +5,15 @@
 import unittest

 from graphrag.cache.json_pipeline_cache import JsonPipelineCache
-from graphrag.storage.file_pipeline_storage import (
-    FilePipelineStorage,
+from graphrag_storage.file_storage import (
+    FileStorage,
 )

 TEMP_DIR = "./.tmp"


 def create_cache():
-    storage = FilePipelineStorage(base_dir=os.path.join(os.getcwd(), ".tmp"))
+    storage = FileStorage(base_dir=os.path.join(os.getcwd(), ".tmp"))
     return JsonPipelineCache(storage)
diff --git a/tests/unit/indexing/input/test_csv_loader.py b/tests/unit/indexing/input/test_csv_loader.py
index 8a6b0e351d..b0dc645e1b 100644
--- a/tests/unit/indexing/input/test_csv_loader.py
+++ b/tests/unit/indexing/input/test_csv_loader.py
@@ -3,9 +3,8 @@
 from graphrag.config.enums import InputFileType
 from graphrag.config.models.input_config import InputConfig
-from graphrag.config.models.storage_config import StorageConfig
 from graphrag.index.input.factory import InputReaderFactory
-from graphrag.utils.api import create_storage_from_config
+from graphrag_storage import StorageConfig, create_storage


 async def test_csv_loader_one_file():
@@ -16,7 +15,7 @@
         file_type=InputFileType.csv,
         file_pattern=".*\\.csv$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -35,7 +34,7 @@ async def test_csv_loader_one_file_with_title():
         file_pattern=".*\\.csv$",
         title_column="title",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -55,7 +54,7 @@ async def test_csv_loader_one_file_with_metadata():
         title_column="title",
         metadata=["title"],
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -74,7 +73,7 @@ async def test_csv_loader_multiple_files():
         file_type=InputFileType.csv,
         file_pattern=".*\\.csv$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
diff --git a/tests/unit/indexing/input/test_json_loader.py b/tests/unit/indexing/input/test_json_loader.py
index 1ce7001aab..3959096b1c 100644
--- a/tests/unit/indexing/input/test_json_loader.py
+++ b/tests/unit/indexing/input/test_json_loader.py
@@ -3,9 +3,8 @@
 from graphrag.config.enums import InputFileType
 from graphrag.config.models.input_config import InputConfig
-from graphrag.config.models.storage_config import StorageConfig
 from graphrag.index.input.factory import InputReaderFactory
-from graphrag.utils.api import create_storage_from_config
+from graphrag_storage import StorageConfig, create_storage


 async def test_json_loader_one_file_one_object():
@@ -16,7 +15,7 @@
         file_type=InputFileType.json,
         file_pattern=".*\\.json$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -34,7 +33,7 @@ async def test_json_loader_one_file_multiple_objects():
         file_type=InputFileType.json,
         file_pattern=".*\\.json$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -54,7 +53,7 @@ async def test_json_loader_one_file_with_title():
         file_pattern=".*\\.json$",
         title_column="title",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -74,7 +73,7 @@ async def test_json_loader_one_file_with_metadata():
         title_column="title",
         metadata=["title"],
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -92,7 +91,7 @@ async def test_json_loader_multiple_files():
         file_type=InputFileType.json,
         file_pattern=".*\\.json$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
diff --git a/tests/unit/indexing/input/test_txt_loader.py b/tests/unit/indexing/input/test_txt_loader.py
index 239f622d72..57ded82507 100644
--- a/tests/unit/indexing/input/test_txt_loader.py
+++ b/tests/unit/indexing/input/test_txt_loader.py
@@ -3,9 +3,8 @@
 from graphrag.config.enums import InputFileType
 from graphrag.config.models.input_config import InputConfig
-from graphrag.config.models.storage_config import StorageConfig
 from graphrag.index.input.factory import InputReaderFactory
-from graphrag.utils.api import create_storage_from_config
+from graphrag_storage import StorageConfig, create_storage


 async def test_txt_loader_one_file():
@@ -16,7 +15,7 @@
         file_type=InputFileType.text,
         file_pattern=".*\\.txt$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -35,7 +34,7 @@ async def test_txt_loader_one_file_with_metadata():
         file_pattern=".*\\.txt$",
         metadata=["title"],
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
@@ -54,7 +53,7 @@ async def test_txt_loader_multiple_files():
         file_type=InputFileType.text,
         file_pattern=".*\\.txt$",
     )
-    storage = create_storage_from_config(config.storage)
+    storage = create_storage(config.storage)
     documents = (
         await InputReaderFactory()
         .create(config.file_type, {"storage": storage, "config": config})
diff --git a/tests/unit/load_config/fixtures/config.yaml b/tests/unit/load_config/fixtures/config.yaml
new file mode 100644
index 0000000000..a54919d1eb
--- /dev/null
+++ b/tests/unit/load_config/fixtures/config.yaml
@@ -0,0 +1,10 @@
+name: test_name
+value: 100
+nested:
+  nested_str: nested_value
+  nested_int: 42
+nested_list:
+  - nested_str: list_value_1
+    nested_int: 7
+  - nested_str: list_value_2
+    nested_int: 8
\ No newline at end of file
diff --git a/uv.lock b/uv.lock
index 6ef295df9d..8ec2e3830a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -12,6 +12,7 @@ members = [
    "graphrag",
    "graphrag-common",
    "graphrag-monorepo",
+    "graphrag-storage",
 ]

 [[package]]
@@ -1042,6 +1043,7 @@ dependencies = [
    { name = "devtools" },
    { name = "environs" },
    { name = "graphrag-common" },
+    { name = "graphrag-storage" },
    { name = "graspologic-native" },
    { name = "json-repair" },
    { name = "lancedb" },
@@ -1073,6 +1075,7 @@ requires-dist = [
    { name = "devtools", specifier = ">=0.12.2" },
    { name = "environs", specifier = ">=11.0.0" },
    { name = "graphrag-common", editable = "packages/graphrag-common" },
+    { name = "graphrag-storage", editable = "packages/graphrag-storage" },
    { name = "graspologic-native", specifier = ">=1.2.5" },
    { name = "json-repair", specifier = ">=0.30.3" },
    { name = "lancedb", specifier = ">=0.17.0" },
@@ -1160,6 +1163,31 @@ dev = [
    { name = "update-toml", specifier = ">=0.2.1" },
 ]

+[[package]]
+name = "graphrag-storage"
+version = "2.7.0"
+source = { editable = "packages/graphrag-storage" }
+dependencies = [
+    { name = "aiofiles" },
+    { name = "azure-cosmos" },
+    { name = "azure-identity" },
+    { name = "azure-storage-blob" },
+    { name = "graphrag-common" },
+    { name = "pandas" },
+    { name = "pydantic" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "aiofiles", specifier = ">=24.1.0" },
+    { name = "azure-cosmos", specifier = ">=4.9.0" },
+    { name = "azure-identity", specifier = ">=1.19.0" },
+    { name = "azure-storage-blob", specifier = ">=12.24.0" },
+    { name = "graphrag-common", editable = "packages/graphrag-common" },
+    { name = "pandas", specifier = ">=2.2.3" },
+    { name = "pydantic", specifier = ">=2.10.3" },
+]
+
 [[package]]
 name = "graspologic-native"
 version = "1.2.5"
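Taken together, the call-site migration in this change reduces to one recipe, sketched here using only the APIs shown above:

```python
# Before this change:
#   from graphrag.config.models.storage_config import StorageConfig
#   from graphrag.storage.factory import StorageFactory
#   storage = StorageFactory().create("file", {"base_dir": "output"})

# After this change:
from graphrag_storage import StorageConfig, create_storage

storage = create_storage(StorageConfig(type="FileStorage", base_dir="output"))
```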