diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index 0df7ed31..16bf5047 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -15,4 +15,4 @@ __all__ = ('connect', 'create_pool', 'Record', 'Connection') + \ exceptions.__all__ # NOQA -__version__ = '0.16.0.dev0' +__version__ = '0.16.0.dev1' diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 2a4eba86..ea1a86fa 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -204,6 +204,14 @@ def transaction(self, *, isolation='read_committed', readonly=False, self._check_open() return transaction.Transaction(self, isolation, readonly, deferrable) + def is_in_transaction(self): + """Return True if Connection is currently inside a transaction. + + :return bool: True if inside transaction, False otherwise. + .. versionadded:: 0.16.0 + """ + return self._protocol.is_in_transaction() + async def execute(self, query: str, *args, timeout: float=None) -> str: """Execute an SQL command (or commands). diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 935d2abf..eb2d948e 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -14,12 +14,15 @@ class TestTransaction(tb.ConnectedTestCase): async def test_transaction_regular(self): self.assertIsNone(self.con._top_xact) + self.assertFalse(self.con.is_in_transaction()) tr = self.con.transaction() self.assertIsNone(self.con._top_xact) + self.assertFalse(self.con.is_in_transaction()) with self.assertRaises(ZeroDivisionError): async with tr as with_tr: self.assertIs(self.con._top_xact, tr) + self.assertTrue(self.con.is_in_transaction()) # We don't return the transaction object from __aenter__, # to make it harder for people to use '.rollback()' and @@ -33,6 +36,7 @@ async def test_transaction_regular(self): 1 / 0 self.assertIsNone(self.con._top_xact) + self.assertFalse(self.con.is_in_transaction()) with self.assertRaisesRegex(asyncpg.PostgresError, '"mytab" does not exist'): @@ -42,12 +46,17 @@ async def test_transaction_regular(self): async def test_transaction_nested(self): self.assertIsNone(self.con._top_xact) + self.assertFalse(self.con.is_in_transaction()) + tr = self.con.transaction() + self.assertIsNone(self.con._top_xact) + self.assertFalse(self.con.is_in_transaction()) with self.assertRaises(ZeroDivisionError): async with tr: self.assertIs(self.con._top_xact, tr) + self.assertTrue(self.con.is_in_transaction()) await self.con.execute(''' CREATE TABLE mytab (a int); @@ -55,18 +64,21 @@ async def test_transaction_nested(self): async with self.con.transaction(): self.assertIs(self.con._top_xact, tr) + self.assertTrue(self.con.is_in_transaction()) await self.con.execute(''' INSERT INTO mytab (a) VALUES (1), (2); ''') self.assertIs(self.con._top_xact, tr) + self.assertTrue(self.con.is_in_transaction()) with self.assertRaises(ZeroDivisionError): in_tr = self.con.transaction() async with in_tr: self.assertIs(self.con._top_xact, tr) + self.assertTrue(self.con.is_in_transaction()) await self.con.execute(''' INSERT INTO mytab (a) VALUES (3), (4); @@ -85,10 +97,12 @@ async def test_transaction_nested(self): self.assertEqual(recs[1][0], 2) self.assertIs(self.con._top_xact, tr) + self.assertTrue(self.con.is_in_transaction()) 1 / 0 self.assertIs(self.con._top_xact, None) + self.assertFalse(self.con.is_in_transaction()) with self.assertRaisesRegex(asyncpg.PostgresError, '"mytab" does not exist'): @@ -98,6 +112,7 @@ async def test_transaction_nested(self): async def test_transaction_interface_errors(self): self.assertIsNone(self.con._top_xact) + self.assertFalse(self.con.is_in_transaction()) tr = self.con.transaction(readonly=True, isolation='serializable') with self.assertRaisesRegex(asyncpg.InterfaceError, @@ -109,6 +124,7 @@ async def test_transaction_interface_errors(self): '