Skip to content

Commit f70a0e3

Browse files
committed
Guard transaction methods against underlying connection release
Similarly to other connection-dependent objects, transaction methods should not be called once the underlying connection has been released to the pool. Also, add a special handling for the case of asynchronous generator finalization, in which case it's OK for `Transaction.__aexit__()` to be called _after_ `Pool.release()`, since we cannot control when the finalization task would execute. Fixes: #232.
1 parent 46f468c commit f70a0e3

File tree

5 files changed

+126
-13
lines changed

5 files changed

+126
-13
lines changed

asyncpg/connection.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,19 @@ async def reset(self, *, timeout=None):
998998
self._listeners.clear()
999999
self._log_listeners.clear()
10001000
reset_query = self._get_reset_query()
1001+
1002+
if self._protocol.is_in_transaction() or self._top_xact is not None:
1003+
if self._top_xact is None or not self._top_xact._managed:
1004+
# Managed transactions are guaranteed to __aexit__
1005+
# correctly.
1006+
self._loop.call_exception_handler({
1007+
'message': 'Resetting connection with an '
1008+
'active transaction {!r}'.format(self)
1009+
})
1010+
1011+
self._top_xact = None
1012+
reset_query = 'ROLLBACK;\n' + reset_query
1013+
10011014
if reset_query:
10021015
await self.execute(reset_query, timeout=timeout)
10031016

@@ -1152,13 +1165,6 @@ def _get_reset_query(self):
11521165
caps = self._server_caps
11531166

11541167
_reset_query = []
1155-
if self._protocol.is_in_transaction() or self._top_xact is not None:
1156-
self._loop.call_exception_handler({
1157-
'message': 'Resetting connection with an '
1158-
'active transaction {!r}'.format(self)
1159-
})
1160-
self._top_xact = None
1161-
_reset_query.append('ROLLBACK;')
11621168
if caps.advisory_locks:
11631169
_reset_query.append('SELECT pg_advisory_unlock_all();')
11641170
if caps.sql_close_all:

asyncpg/connresource.py

+6
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,9 @@ def _check_conn_validity(self, meth_name):
3636
'cannot call {}.{}(): '
3737
'the underlying connection has been released back '
3838
'to the pool'.format(self.__class__.__name__, meth_name))
39+
40+
if self._connection.is_closed():
41+
raise exceptions.InterfaceError(
42+
'cannot call {}.{}(): '
43+
'the underlying connection is closed'.format(
44+
self.__class__.__name__, meth_name))

asyncpg/pool.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ def __getattr__(self, attr):
7272
# Proxy all unresolved attributes to the wrapped Connection object.
7373
return getattr(self._con, attr)
7474

75-
def _detach(self):
75+
def _detach(self) -> connection.Connection:
7676
if self._con is None:
7777
raise exceptions.InterfaceError(
7878
'cannot detach PoolConnectionProxy: already detached')
7979

8080
con, self._con = self._con, None
8181
con._set_proxy(None)
82+
return con
8283

8384
def __repr__(self):
8485
if self._con is None:
@@ -179,8 +180,6 @@ async def release(self, timeout):
179180
self._in_use = False
180181
self._timeout = None
181182

182-
self._con._on_release()
183-
184183
if self._con.is_closed():
185184
self._con = None
186185

@@ -508,7 +507,8 @@ async def _release_impl(ch: PoolConnectionHolder, timeout: float):
508507
# Already released, do nothing.
509508
return
510509

511-
connection._detach()
510+
con = connection._detach()
511+
con._on_release()
512512

513513
if timeout is None:
514514
timeout = connection._holder._timeout

asyncpg/transaction.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import enum
99

10+
from . import connresource
1011
from . import exceptions as apg_errors
1112

1213

@@ -21,7 +22,7 @@ class TransactionState(enum.Enum):
2122
ISOLATION_LEVELS = {'read_committed', 'serializable', 'repeatable_read'}
2223

2324

24-
class Transaction:
25+
class Transaction(connresource.ConnectionResource):
2526
"""Represents a transaction or savepoint block.
2627
2728
Transactions are created by calling the
@@ -33,6 +34,8 @@ class Transaction:
3334
'_state', '_nested', '_id', '_managed')
3435

3536
def __init__(self, connection, isolation, readonly, deferrable):
37+
super().__init__(connection)
38+
3639
if isolation not in ISOLATION_LEVELS:
3740
raise ValueError(
3841
'isolation is expected to be either of {}, '
@@ -49,7 +52,6 @@ def __init__(self, connection, isolation, readonly, deferrable):
4952
'"deferrable" is only supported for '
5053
'serializable readonly transactions')
5154

52-
self._connection = connection
5355
self._isolation = isolation
5456
self._readonly = readonly
5557
self._deferrable = deferrable
@@ -66,6 +68,22 @@ async def __aenter__(self):
6668
await self.start()
6769

6870
async def __aexit__(self, extype, ex, tb):
71+
try:
72+
self._check_conn_validity('__aexit__')
73+
except apg_errors.InterfaceError:
74+
if extype is GeneratorExit:
75+
# When a PoolAcquireContext is being exited, and there
76+
# is an open transaction in an async generator that has
77+
# not been iterated fully, there is a possibility that
78+
# Pool.release() would race with this __aexit__(), since
79+
# both would be in concurrent tasks. In such case we
80+
# yield to Pool.release() to do the ROLLBACK for us.
81+
# See https://github.com/MagicStack/asyncpg/issues/232
82+
# for an example.
83+
return
84+
else:
85+
raise
86+
6987
try:
7088
if extype is not None:
7189
await self.__rollback()
@@ -74,6 +92,7 @@ async def __aexit__(self, extype, ex, tb):
7492
finally:
7593
self._managed = False
7694

95+
@connresource.guarded
7796
async def start(self):
7897
"""Enter the transaction or savepoint block."""
7998
self.__check_state_base('start')
@@ -183,13 +202,15 @@ async def __rollback(self):
183202
else:
184203
self._state = TransactionState.ROLLEDBACK
185204

205+
@connresource.guarded
186206
async def commit(self):
187207
"""Exit the transaction or savepoint block and commit changes."""
188208
if self._managed:
189209
raise apg_errors.InterfaceError(
190210
'cannot manually commit from within an `async with` block')
191211
await self.__commit()
192212

213+
@connresource.guarded
193214
async def rollback(self):
194215
"""Exit the transaction or savepoint block and rollback changes."""
195216
if self._managed:

tests/test_pool.py

+80
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import os
1212
import platform
1313
import random
14+
import sys
15+
import textwrap
1416
import time
1517
import unittest
1618

@@ -195,6 +197,7 @@ async def test_pool_11(self):
195197
self.assertIn(repr(con._con), repr(con)) # Test __repr__.
196198

197199
ps = await con.prepare('SELECT 1')
200+
txn = con.transaction()
198201
async with con.transaction():
199202
cur = await con.cursor('SELECT 1')
200203
ps_cur = await ps.cursor()
@@ -233,6 +236,14 @@ async def test_pool_11(self):
233236

234237
c.forward(1)
235238

239+
for meth in ('start', 'commit', 'rollback'):
240+
with self.assertRaisesRegex(
241+
asyncpg.InterfaceError,
242+
r'cannot call Transaction\.{meth}.*released '
243+
r'back to the pool'.format(meth=meth)):
244+
245+
getattr(txn, meth)()
246+
236247
await pool.close()
237248

238249
async def test_pool_12(self):
@@ -661,6 +672,75 @@ async def test_pool_handles_inactive_connection_errors(self):
661672
await con.close()
662673
await pool.close()
663674

675+
@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
676+
async def test_pool_handles_transaction_exit_in_asyncgen_1(self):
677+
pool = await self.create_pool(database='postgres',
678+
min_size=1, max_size=1)
679+
680+
locals_ = {}
681+
exec(textwrap.dedent('''\
682+
async def iterate(con):
683+
async with con.transaction():
684+
for record in await con.fetch("SELECT 1"):
685+
yield record
686+
'''), globals(), locals_)
687+
iterate = locals_['iterate']
688+
689+
class MyException(Exception):
690+
pass
691+
692+
with self.assertRaises(MyException):
693+
async with pool.acquire() as con:
694+
async for _ in iterate(con): # noqa
695+
raise MyException()
696+
697+
@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
698+
async def test_pool_handles_transaction_exit_in_asyncgen_2(self):
699+
pool = await self.create_pool(database='postgres',
700+
min_size=1, max_size=1)
701+
702+
locals_ = {}
703+
exec(textwrap.dedent('''\
704+
async def iterate(con):
705+
async with con.transaction():
706+
for record in await con.fetch("SELECT 1"):
707+
yield record
708+
'''), globals(), locals_)
709+
iterate = locals_['iterate']
710+
711+
class MyException(Exception):
712+
pass
713+
714+
with self.assertRaises(MyException):
715+
async with pool.acquire() as con:
716+
iterator = iterate(con)
717+
async for _ in iterator: # noqa
718+
raise MyException()
719+
720+
del iterator
721+
722+
@unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support')
723+
async def test_pool_handles_asyncgen_finalization(self):
724+
pool = await self.create_pool(database='postgres',
725+
min_size=1, max_size=1)
726+
727+
locals_ = {}
728+
exec(textwrap.dedent('''\
729+
async def iterate(con):
730+
for record in await con.fetch("SELECT 1"):
731+
yield record
732+
'''), globals(), locals_)
733+
iterate = locals_['iterate']
734+
735+
class MyException(Exception):
736+
pass
737+
738+
with self.assertRaises(MyException):
739+
async with pool.acquire() as con:
740+
async with con.transaction():
741+
async for _ in iterate(con): # noqa
742+
raise MyException()
743+
664744

665745
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
666746
class TestHotStandby(tb.ConnectedTestCase):

0 commit comments

Comments
 (0)