99
1010if TYPE_CHECKING :
1111 import logging
12+ from collections .abc import Callable , Coroutine
1213
13- from crawlee .storages . _key_value_store import KeyValueStore
14+ from crawlee .storages import KeyValueStore
1415
1516TStateModel = TypeVar ('TStateModel' , bound = BaseModel )
1617
@@ -38,7 +39,7 @@ def __init__(
3839 persistence_enabled : Literal [True , False , 'explicit_only' ] = False ,
3940 persist_state_kvs_name : str | None = None ,
4041 persist_state_kvs_id : str | None = None ,
41- persist_state_kvs : KeyValueStore | None = None ,
42+ persist_state_kvs_factory : Callable [[], Coroutine [ None , None , KeyValueStore ]] | None = None ,
4243 logger : logging .Logger ,
4344 ) -> None :
4445 """Initialize a new recoverable state object.
@@ -53,28 +54,40 @@ def __init__(
5354 If neither a name nor and id are supplied, the default store will be used.
5455 persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
5556 If neither a name nor and id are supplied, the default store will be used.
56- persist_state_kvs: KeyValueStore to use for persistence. If not provided, a system-wide KeyValueStore will
57- be used, based on service locator configuration.
57+ persist_state_kvs_factory: Factory that can be awaited to create KeyValueStore to use for persistence. If
58+ not provided, a system-wide KeyValueStore will be used, based on service locator configuration.
5859 logger: A logger instance for logging operations related to state persistence
5960 """
60- raise_if_too_many_kwargs (persist_state_kvs_name = persist_state_kvs_name ,
61- persist_state_kvs_id = persist_state_kvs_id ,
62- key_value_store = persist_state_kvs )
63- if not persist_state_kvs :
61+ raise_if_too_many_kwargs (
62+ persist_state_kvs_name = persist_state_kvs_name ,
63+ persist_state_kvs_id = persist_state_kvs_id ,
64+ persist_state_kvs_factory = persist_state_kvs_factory ,
65+ )
66+ if not persist_state_kvs_factory :
6467 logger .debug (
6568 'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
6669 'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
6770 'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
68- 'global side effects.' )
71+ 'global side effects.'
72+ )
6973
7074 self ._default_state = default_state
7175 self ._state_type : type [TStateModel ] = self ._default_state .__class__
7276 self ._state : TStateModel | None = None
7377 self ._persistence_enabled = persistence_enabled
7478 self ._persist_state_key = persist_state_key
75- self ._persist_state_kvs_name = persist_state_kvs_name
76- self ._persist_state_kvs_id = persist_state_kvs_id
77- self ._key_value_store : KeyValueStore | None = persist_state_kvs
79+ if persist_state_kvs_factory is None :
80+
81+ async def kvs_factory () -> KeyValueStore :
82+ from crawlee .storages import KeyValueStore # noqa: PLC0415 avoid circular import
83+
84+ return await KeyValueStore .open (name = persist_state_kvs_name , id = persist_state_kvs_id )
85+
86+ self ._persist_state_kvs_factory = kvs_factory
87+ else :
88+ self ._persist_state_kvs_factory = persist_state_kvs_factory
89+
90+ self ._key_value_store : KeyValueStore | None = None
7891 self ._log = logger
7992
8093 async def initialize (self ) -> TStateModel :
@@ -91,12 +104,8 @@ async def initialize(self) -> TStateModel:
91104 return self .current_value
92105
93106 # Import here to avoid circular imports.
94- from crawlee .storages ._key_value_store import KeyValueStore # noqa: PLC0415
95107
96- if not self ._key_value_store :
97- self ._key_value_store = await KeyValueStore .open (
98- name = self ._persist_state_kvs_name , id = self ._persist_state_kvs_id
99- )
108+ self ._key_value_store = await self ._persist_state_kvs_factory ()
100109
101110 await self ._load_saved_state ()
102111
0 commit comments