Skip to content

Commit 50edd8c

Browse files
committed
protocol: Use try-finally explicitly every time we create a waiter
1 parent f29de23 commit 50edd8c

File tree

5 files changed

+129
-75
lines changed

5 files changed

+129
-75
lines changed

asyncpg/exceptions/_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111

1212
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
13-
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage')
13+
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
14+
'InternalClientError')
1415

1516

1617
def _is_asyncpg_class(cls):
@@ -190,6 +191,10 @@ def __init__(self, msg, *, detail=None, hint=None):
190191
Warning.__init__(self, msg)
191192

192193

194+
class InternalClientError(Exception):
195+
pass
196+
197+
193198
class PostgresLogMessage(PostgresMessage):
194199
"""A base class for non-error server messages."""
195200

asyncpg/protocol/coreproto.pyx

+9-9
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ cdef class CoreProtocol:
125125
self.buffer.consume_message()
126126

127127
else:
128-
raise RuntimeError(
128+
raise apg_exc.InternalClientError(
129129
'protocol is in an unknown state {}'.format(state))
130130

131131
except Exception as ex:
@@ -472,7 +472,7 @@ cdef class CoreProtocol:
472472

473473
if ASYNCPG_DEBUG:
474474
if buf.get_message_type() != b'D':
475-
raise RuntimeError(
475+
raise apg_exc.InternalClientError(
476476
'_parse_data_msgs: first message is not "D"')
477477

478478
if self._discard_data:
@@ -484,7 +484,7 @@ cdef class CoreProtocol:
484484

485485
if ASYNCPG_DEBUG:
486486
if type(self.result) is not list:
487-
raise RuntimeError(
487+
raise apg_exc.InternalClientError(
488488
'_parse_data_msgs: result is not a list, but {!r}'.
489489
format(self.result))
490490

@@ -639,11 +639,11 @@ cdef class CoreProtocol:
639639
cdef _set_state(self, ProtocolState new_state):
640640
if new_state == PROTOCOL_IDLE:
641641
if self.state == PROTOCOL_FAILED:
642-
raise RuntimeError(
642+
raise apg_exc.InternalClientError(
643643
'cannot switch to "idle" state; '
644644
'protocol is in the "failed" state')
645645
elif self.state == PROTOCOL_IDLE:
646-
raise RuntimeError(
646+
raise apg_exc.InternalClientError(
647647
'protocol is already in the "idle" state')
648648
else:
649649
self.state = new_state
@@ -671,18 +671,18 @@ cdef class CoreProtocol:
671671
self.state = new_state
672672

673673
elif self.state == PROTOCOL_FAILED:
674-
raise RuntimeError(
674+
raise apg_exc.InternalClientError(
675675
'cannot switch to state {}; '
676676
'protocol is in the "failed" state'.format(new_state))
677677
else:
678-
raise RuntimeError(
678+
raise apg_exc.InternalClientError(
679679
'cannot switch to state {}; '
680680
'another operation ({}) is in progress'.format(
681681
new_state, self.state))
682682

683683
cdef _ensure_connected(self):
684684
if self.con_status != CONNECTION_OK:
685-
raise RuntimeError('not connected')
685+
raise apg_exc.InternalClientError('not connected')
686686

687687
cdef WriteBuffer _build_bind_message(self, str portal_name,
688688
str stmt_name,
@@ -707,7 +707,7 @@ cdef class CoreProtocol:
707707
WriteBuffer outbuf
708708

709709
if self.con_status != CONNECTION_BAD:
710-
raise RuntimeError('already connected')
710+
raise apg_exc.InternalClientError('already connected')
711711

712712
self._set_state(PROTOCOL_AUTH)
713713
self.con_status = CONNECTION_STARTED

asyncpg/protocol/protocol.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ cdef class BaseProtocol(CoreProtocol):
5151
cdef _get_timeout_impl(self, timeout)
5252
cdef _check_state(self)
5353
cdef _new_waiter(self, timeout)
54+
cdef _coreproto_error(self)
5455

5556
cdef _on_result__connect(self, object waiter)
5657
cdef _on_result__prepare(self, object waiter)

asyncpg/protocol/protocol.pyx

+111-64
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,16 @@ cdef class BaseProtocol(CoreProtocol):
156156
self._check_state()
157157
timeout = self._get_timeout_impl(timeout)
158158

159-
self._prepare(stmt_name, query)
160-
self.last_query = query
161-
self.statement = PreparedStatementState(stmt_name, query, self)
162-
163-
return await self._new_waiter(timeout)
159+
waiter = self._new_waiter(timeout)
160+
try:
161+
self._prepare(stmt_name, query) # network op
162+
self.last_query = query
163+
self.statement = PreparedStatementState(stmt_name, query, self)
164+
except Exception as ex:
165+
waiter.set_exception(ex)
166+
self._coreproto_error()
167+
finally:
168+
return await waiter
164169

165170
async def bind_execute(self, PreparedStatementState state, args,
166171
str portal_name, int limit, return_extra,
@@ -174,19 +179,25 @@ cdef class BaseProtocol(CoreProtocol):
174179

175180
self._check_state()
176181
timeout = self._get_timeout_impl(timeout)
182+
args_buf = state._encode_bind_msg(args)
177183

178-
self._bind_execute(
179-
portal_name,
180-
state.name,
181-
state._encode_bind_msg(args),
182-
limit)
183-
184-
self.last_query = state.query
185-
self.statement = state
186-
self.return_extra = return_extra
187-
self.queries_count += 1
188-
189-
return await self._new_waiter(timeout)
184+
waiter = self._new_waiter(timeout)
185+
try:
186+
self._bind_execute(
187+
portal_name,
188+
state.name,
189+
args_buf,
190+
limit) # network op
191+
192+
self.last_query = state.query
193+
self.statement = state
194+
self.return_extra = return_extra
195+
self.queries_count += 1
196+
except Exception as ex:
197+
waiter.set_exception(ex)
198+
self._coreproto_error()
199+
finally:
200+
return await waiter
190201

191202
async def bind_execute_many(self, PreparedStatementState state, args,
192203
str portal_name, timeout):
@@ -207,18 +218,21 @@ cdef class BaseProtocol(CoreProtocol):
207218
arg_bufs = iter(data_gen)
208219

209220
waiter = self._new_waiter(timeout)
221+
try:
222+
self._bind_execute_many(
223+
portal_name,
224+
state.name,
225+
arg_bufs) # network op
210226

211-
self._bind_execute_many(
212-
portal_name,
213-
state.name,
214-
arg_bufs)
215-
216-
self.last_query = state.query
217-
self.statement = state
218-
self.return_extra = False
219-
self.queries_count += 1
220-
221-
return await waiter
227+
self.last_query = state.query
228+
self.statement = state
229+
self.return_extra = False
230+
self.queries_count += 1
231+
except Exception as ex:
232+
waiter.set_exception(ex)
233+
self._coreproto_error()
234+
finally:
235+
return await waiter
222236

223237
async def bind(self, PreparedStatementState state, args,
224238
str portal_name, timeout):
@@ -231,16 +245,22 @@ cdef class BaseProtocol(CoreProtocol):
231245

232246
self._check_state()
233247
timeout = self._get_timeout_impl(timeout)
248+
args_buf = state._encode_bind_msg(args)
234249

235-
self._bind(
236-
portal_name,
237-
state.name,
238-
state._encode_bind_msg(args))
239-
240-
self.last_query = state.query
241-
self.statement = state
242-
243-
return await self._new_waiter(timeout)
250+
waiter = self._new_waiter(timeout)
251+
try:
252+
self._bind(
253+
portal_name,
254+
state.name,
255+
args_buf) # network op
256+
257+
self.last_query = state.query
258+
self.statement = state
259+
except Exception as ex:
260+
waiter.set_exception(ex)
261+
self._coreproto_error()
262+
finally:
263+
return await waiter
244264

245265
async def execute(self, PreparedStatementState state,
246266
str portal_name, int limit, return_extra,
@@ -255,16 +275,21 @@ cdef class BaseProtocol(CoreProtocol):
255275
self._check_state()
256276
timeout = self._get_timeout_impl(timeout)
257277

258-
self._execute(
259-
portal_name,
260-
limit)
261-
262-
self.last_query = state.query
263-
self.statement = state
264-
self.return_extra = return_extra
265-
self.queries_count += 1
266-
267-
return await self._new_waiter(timeout)
278+
waiter = self._new_waiter(timeout)
279+
try:
280+
self._execute(
281+
portal_name,
282+
limit) # network op
283+
284+
self.last_query = state.query
285+
self.statement = state
286+
self.return_extra = return_extra
287+
self.queries_count += 1
288+
except Exception as ex:
289+
waiter.set_exception(ex)
290+
self._coreproto_error()
291+
finally:
292+
return await waiter
268293

269294
async def query(self, query, timeout):
270295
if self.cancel_waiter is not None:
@@ -279,11 +304,16 @@ cdef class BaseProtocol(CoreProtocol):
279304
# prepare/bind/execute methods.
280305
timeout = self._get_timeout(timeout)
281306

282-
self._simple_query(query)
283-
self.last_query = query
284-
self.queries_count += 1
285-
286-
return await self._new_waiter(timeout)
307+
waiter = self._new_waiter(timeout)
308+
try:
309+
self._simple_query(query) # network op
310+
self.last_query = query
311+
self.queries_count += 1
312+
except Exception as ex:
313+
waiter.set_exception(ex)
314+
self._coreproto_error()
315+
finally:
316+
return await waiter
287317

288318
async def copy_out(self, copy_stmt, sink, timeout):
289319
if self.cancel_waiter is not None:
@@ -378,7 +408,7 @@ cdef class BaseProtocol(CoreProtocol):
378408
for codec in codecs:
379409
if (not codec.has_encoder() or
380410
codec.format != PG_FORMAT_BINARY):
381-
raise RuntimeError(
411+
raise apg_exc.InternalClientError(
382412
'no binary format encoder for '
383413
'type {} (OID {})'.format(codec.name, codec.oid))
384414

@@ -439,7 +469,7 @@ cdef class BaseProtocol(CoreProtocol):
439469
except TimeoutError:
440470
raise
441471
else:
442-
raise RuntimeError('TimoutError was not raised')
472+
raise apg_exc.InternalClientError('TimoutError was not raised')
443473

444474
except Exception as e:
445475
self._write_copy_fail_msg(str(e))
@@ -460,16 +490,22 @@ cdef class BaseProtocol(CoreProtocol):
460490
self.cancel_sent_waiter = None
461491

462492
self._check_state()
463-
timeout = self._get_timeout_impl(timeout)
464493

465494
if state.refs != 0:
466-
raise RuntimeError(
495+
raise apg_exc.InternalClientError(
467496
'cannot close prepared statement; refs == {} != 0'.format(
468497
state.refs))
469498

470-
self._close(state.name, False)
471-
state.closed = True
472-
return await self._new_waiter(timeout)
499+
timeout = self._get_timeout_impl(timeout)
500+
waiter = self._new_waiter(timeout)
501+
try:
502+
self._close(state.name, False) # network op
503+
state.closed = True
504+
except Exception as ex:
505+
waiter.set_exception(ex)
506+
self._coreproto_error()
507+
finally:
508+
return await waiter
473509

474510
def is_closed(self):
475511
return self.closing
@@ -579,6 +615,17 @@ cdef class BaseProtocol(CoreProtocol):
579615
raise apg_exc.InterfaceError(
580616
'cannot perform operation: another operation is in progress')
581617

618+
cdef _coreproto_error(self):
619+
try:
620+
if self.waiter is not None:
621+
if not self.waiter.done():
622+
raise apg_exc.InternalClientError(
623+
'waiter is not done while handling critical '
624+
'protocol error')
625+
self.waiter = None
626+
finally:
627+
self.abort()
628+
582629
cdef _new_waiter(self, timeout):
583630
if self.waiter is not None:
584631
raise apg_exc.InterfaceError(
@@ -596,7 +643,7 @@ cdef class BaseProtocol(CoreProtocol):
596643
cdef _on_result__prepare(self, object waiter):
597644
if ASYNCPG_DEBUG:
598645
if self.statement is None:
599-
raise RuntimeError(
646+
raise apg_exc.InternalClientError(
600647
'_on_result__prepare: statement is None')
601648

602649
if self.result_param_desc is not None:
@@ -643,7 +690,7 @@ cdef class BaseProtocol(CoreProtocol):
643690
cdef _decode_row(self, const char* buf, ssize_t buf_len):
644691
if ASYNCPG_DEBUG:
645692
if self.statement is None:
646-
raise RuntimeError(
693+
raise apg_exc.InternalClientError(
647694
'_decode_row: statement is None')
648695

649696
return self.statement._decode_row(buf, buf_len)
@@ -654,13 +701,13 @@ cdef class BaseProtocol(CoreProtocol):
654701

655702
if ASYNCPG_DEBUG:
656703
if waiter is None:
657-
raise RuntimeError('_on_result: waiter is None')
704+
raise apg_exc.InternalClientError('_on_result: waiter is None')
658705

659706
if waiter.cancelled():
660707
return
661708

662709
if waiter.done():
663-
raise RuntimeError('_on_result: waiter is done')
710+
raise apg_exc.InternalClientError('_on_result: waiter is done')
664711

665712
if self.result_type == RESULT_FAILED:
666713
if isinstance(self.result, dict):
@@ -704,7 +751,7 @@ cdef class BaseProtocol(CoreProtocol):
704751
self._on_result__copy_in(waiter)
705752

706753
else:
707-
raise RuntimeError(
754+
raise apg_exc.InternalClientError(
708755
'got result for unknown protocol state {}'.
709756
format(self.state))
710757

0 commit comments

Comments
 (0)