diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index c03c6423b..336bc636d 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -23,6 +23,7 @@ import socket import ssl import typing +import uuid from . import abstract from . import base_client @@ -322,6 +323,10 @@ def _exclusive(self): finally: self._locked = False + async def savepoint(self) -> transaction.Savepoint: + name = "s" + uuid.uuid4().hex + return await self._declare_savepoint(name) + class AsyncIORetry(transaction.BaseRetry): diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 7eb761b98..53719a6c3 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -25,6 +25,7 @@ import threading import time import typing +import uuid from . import abstract from . import base_client @@ -270,6 +271,14 @@ async def close(self, timeout=None): self._closing = False +class Savepoint(transaction.Savepoint): + def release(self): + self._tx._client._iter_coroutine(super().release()) + + def rollback(self): + self._tx._client._iter_coroutine(super().rollback()) + + class Iteration(transaction.BaseTransaction, abstract.Executor): __slots__ = ("_managed", "_lock") @@ -320,6 +329,12 @@ def _exclusive(self): finally: self._lock.release() + def savepoint(self) -> Savepoint: + name = "s" + uuid.uuid4().hex + return self._client._iter_coroutine( + self._declare_savepoint(name, cls=Savepoint) + ) + class Retry(transaction.BaseRetry): diff --git a/edgedb/transaction.py b/edgedb/transaction.py index 511b8f42e..ba7aceafd 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -17,6 +17,8 @@ # +from __future__ import annotations + import enum from . import abstract @@ -32,12 +34,47 @@ class TransactionState(enum.Enum): FAILED = 4 +class Savepoint: + __slots__ = ('_name', '_tx', '_active') + + def __init__(self, name: str, transaction: BaseTransaction): + self._name = name + self._tx = transaction + self._active = True + + @property + def active(self): + return self._active + + def _ensure_active(self): + if not self._active: + raise errors.InterfaceError( + f"savepoint {self._name!r} is no longer active" + ) + + async def release(self): + self._ensure_active() + await self._tx._privileged_execute(f"release savepoint {self._name}") + del self._tx._savepoints[self._name] + self._active = False + + async def rollback(self): + self._ensure_active() + await self._tx._privileged_execute( + f"rollback to savepoint {self._name}" + ) + names = list(self._tx._savepoints) + for name in names[names.index(self._name):]: + self._tx._savepoints.pop(name)._active = False + + class BaseTransaction: __slots__ = ( '_client', '_connection', '_options', + '_savepoints', '_state', '__retry', '__iteration', @@ -48,6 +85,7 @@ def __init__(self, retry, client, iteration): self._client = client self._connection = None self._options = retry._options.transaction_options + self._savepoints = {} self._state = TransactionState.NEW self.__retry = retry self.__iteration = iteration @@ -128,6 +166,9 @@ async def _exit(self, extype, ex): if not self.__started: return False + for sp in self._savepoints.values(): + sp._active = False + try: if extype is None: query = self._make_commit_query() @@ -200,6 +241,16 @@ async def _privileged_execute(self, query: str) -> None: state=self._get_state(), )) + async def _declare_savepoint(self, savepoint: str, cls=Savepoint): + if savepoint in self._savepoints: + raise errors.InterfaceError( + f"savepoint {savepoint!r} already exists" + ) + await self._ensure_transaction() + await self._privileged_execute(f"declare savepoint {savepoint}") + self._savepoints[savepoint] = rv = cls(savepoint, self) + return rv + class BaseRetry: diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index 8ceeb239a..b174c18f8 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -34,6 +34,10 @@ class TestAsyncTx(tb.AsyncQueryTestCase): }; ''' + TEARDOWN_METHOD = ''' + DELETE test::TransactionTest; + ''' + TEARDOWN = ''' DROP TYPE test::TransactionTest; ''' @@ -104,3 +108,50 @@ async def test_async_transaction_exclusive(self): ): await asyncio.wait_for(f1, timeout=5) await asyncio.wait_for(f2, timeout=5) + + async def test_async_transaction_savepoint_1(self): + async for tx in self.client.transaction(): + async with tx: + sp1 = await tx.savepoint() + sp2 = await tx.savepoint() + await tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + await sp2.release() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + await sp2.release() + await sp1.release() + + result = await self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) + + async def test_async_transaction_savepoint_2(self): + async for tx in self.client.transaction(): + async with tx: + await tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + sp1 = await tx.savepoint() + await tx.execute(''' + INSERT test::TransactionTest { name := '2' }; + ''') + sp2 = await tx.savepoint() + await tx.execute(''' + INSERT test::TransactionTest { name := '3' }; + ''') + await sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + await sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + await sp2.rollback() + + result = await self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 3ed2fc55f..d8735049d 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -33,6 +33,10 @@ class TestSyncTx(tb.SyncQueryTestCase): }; ''' + TEARDOWN_METHOD = ''' + DELETE test::TransactionTest; + ''' + TEARDOWN = ''' DROP TYPE test::TransactionTest; ''' @@ -113,3 +117,50 @@ def test_sync_transaction_exclusive(self): ): f1.result(timeout=5) f2.result(timeout=5) + + def test_sync_transaction_savepoint_1(self): + for tx in self.client.transaction(): + with tx: + sp1 = tx.savepoint() + sp2 = tx.savepoint() + tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + sp2.release() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + sp2.release() + sp1.release() + + result = self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) + + def test_sync_transaction_savepoint_2(self): + for tx in self.client.transaction(): + with tx: + tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + sp1 = tx.savepoint() + tx.execute(''' + INSERT test::TransactionTest { name := '2' }; + ''') + sp2 = tx.savepoint() + tx.execute(''' + INSERT test::TransactionTest { name := '3' }; + ''') + sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + sp2.rollback() + + result = self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"])