@@ -156,11 +156,16 @@ cdef class BaseProtocol(CoreProtocol):
156
156
self ._check_state()
157
157
timeout = self ._get_timeout_impl(timeout)
158
158
159
- self ._prepare(stmt_name, query)
160
- self .last_query = query
161
- self .statement = PreparedStatementState(stmt_name, query, self )
162
-
163
- return await self ._new_waiter(timeout)
159
+ waiter = self ._new_waiter(timeout)
160
+ try :
161
+ self ._prepare(stmt_name, query) # network op
162
+ self .last_query = query
163
+ self .statement = PreparedStatementState(stmt_name, query, self )
164
+ except Exception as ex:
165
+ waiter.set_exception(ex)
166
+ self ._coreproto_error()
167
+ finally :
168
+ return await waiter
164
169
165
170
async def bind_execute(self , PreparedStatementState state, args,
166
171
str portal_name, int limit, return_extra,
@@ -174,19 +179,25 @@ cdef class BaseProtocol(CoreProtocol):
174
179
175
180
self ._check_state()
176
181
timeout = self ._get_timeout_impl(timeout)
182
+ args_buf = state._encode_bind_msg(args)
177
183
178
- self ._bind_execute(
179
- portal_name,
180
- state.name,
181
- state._encode_bind_msg(args),
182
- limit)
183
-
184
- self .last_query = state.query
185
- self .statement = state
186
- self .return_extra = return_extra
187
- self .queries_count += 1
188
-
189
- return await self ._new_waiter(timeout)
184
+ waiter = self ._new_waiter(timeout)
185
+ try :
186
+ self ._bind_execute(
187
+ portal_name,
188
+ state.name,
189
+ args_buf,
190
+ limit) # network op
191
+
192
+ self .last_query = state.query
193
+ self .statement = state
194
+ self .return_extra = return_extra
195
+ self .queries_count += 1
196
+ except Exception as ex:
197
+ waiter.set_exception(ex)
198
+ self ._coreproto_error()
199
+ finally :
200
+ return await waiter
190
201
191
202
async def bind_execute_many(self , PreparedStatementState state, args,
192
203
str portal_name, timeout):
@@ -207,18 +218,21 @@ cdef class BaseProtocol(CoreProtocol):
207
218
arg_bufs = iter (data_gen)
208
219
209
220
waiter = self ._new_waiter(timeout)
221
+ try :
222
+ self ._bind_execute_many(
223
+ portal_name,
224
+ state.name,
225
+ arg_bufs) # network op
210
226
211
- self ._bind_execute_many(
212
- portal_name,
213
- state.name,
214
- arg_bufs)
215
-
216
- self .last_query = state.query
217
- self .statement = state
218
- self .return_extra = False
219
- self .queries_count += 1
220
-
221
- return await waiter
227
+ self .last_query = state.query
228
+ self .statement = state
229
+ self .return_extra = False
230
+ self .queries_count += 1
231
+ except Exception as ex:
232
+ waiter.set_exception(ex)
233
+ self ._coreproto_error()
234
+ finally :
235
+ return await waiter
222
236
223
237
async def bind(self , PreparedStatementState state, args,
224
238
str portal_name, timeout):
@@ -231,16 +245,22 @@ cdef class BaseProtocol(CoreProtocol):
231
245
232
246
self ._check_state()
233
247
timeout = self ._get_timeout_impl(timeout)
248
+ args_buf = state._encode_bind_msg(args)
234
249
235
- self ._bind(
236
- portal_name,
237
- state.name,
238
- state._encode_bind_msg(args))
239
-
240
- self .last_query = state.query
241
- self .statement = state
242
-
243
- return await self ._new_waiter(timeout)
250
+ waiter = self ._new_waiter(timeout)
251
+ try :
252
+ self ._bind(
253
+ portal_name,
254
+ state.name,
255
+ args_buf) # network op
256
+
257
+ self .last_query = state.query
258
+ self .statement = state
259
+ except Exception as ex:
260
+ waiter.set_exception(ex)
261
+ self ._coreproto_error()
262
+ finally :
263
+ return await waiter
244
264
245
265
async def execute(self , PreparedStatementState state,
246
266
str portal_name, int limit, return_extra,
@@ -255,16 +275,21 @@ cdef class BaseProtocol(CoreProtocol):
255
275
self ._check_state()
256
276
timeout = self ._get_timeout_impl(timeout)
257
277
258
- self ._execute(
259
- portal_name,
260
- limit)
261
-
262
- self .last_query = state.query
263
- self .statement = state
264
- self .return_extra = return_extra
265
- self .queries_count += 1
266
-
267
- return await self ._new_waiter(timeout)
278
+ waiter = self ._new_waiter(timeout)
279
+ try :
280
+ self ._execute(
281
+ portal_name,
282
+ limit) # network op
283
+
284
+ self .last_query = state.query
285
+ self .statement = state
286
+ self .return_extra = return_extra
287
+ self .queries_count += 1
288
+ except Exception as ex:
289
+ waiter.set_exception(ex)
290
+ self ._coreproto_error()
291
+ finally :
292
+ return await waiter
268
293
269
294
async def query(self , query, timeout):
270
295
if self .cancel_waiter is not None :
@@ -279,11 +304,16 @@ cdef class BaseProtocol(CoreProtocol):
279
304
# prepare/bind/execute methods.
280
305
timeout = self ._get_timeout(timeout)
281
306
282
- self ._simple_query(query)
283
- self .last_query = query
284
- self .queries_count += 1
285
-
286
- return await self ._new_waiter(timeout)
307
+ waiter = self ._new_waiter(timeout)
308
+ try :
309
+ self ._simple_query(query) # network op
310
+ self .last_query = query
311
+ self .queries_count += 1
312
+ except Exception as ex:
313
+ waiter.set_exception(ex)
314
+ self ._coreproto_error()
315
+ finally :
316
+ return await waiter
287
317
288
318
async def copy_out(self , copy_stmt, sink, timeout):
289
319
if self .cancel_waiter is not None :
@@ -378,7 +408,7 @@ cdef class BaseProtocol(CoreProtocol):
378
408
for codec in codecs:
379
409
if (not codec.has_encoder() or
380
410
codec.format != PG_FORMAT_BINARY):
381
- raise RuntimeError (
411
+ raise apg_exc.InternalClientError (
382
412
' no binary format encoder for '
383
413
' type {} (OID {})' .format(codec.name, codec.oid))
384
414
@@ -439,7 +469,7 @@ cdef class BaseProtocol(CoreProtocol):
439
469
except TimeoutError:
440
470
raise
441
471
else :
442
- raise RuntimeError (' TimoutError was not raised' )
472
+ raise apg_exc.InternalClientError (' TimoutError was not raised' )
443
473
444
474
except Exception as e:
445
475
self ._write_copy_fail_msg(str (e))
@@ -460,16 +490,22 @@ cdef class BaseProtocol(CoreProtocol):
460
490
self .cancel_sent_waiter = None
461
491
462
492
self ._check_state()
463
- timeout = self ._get_timeout_impl(timeout)
464
493
465
494
if state.refs != 0 :
466
- raise RuntimeError (
495
+ raise apg_exc.InternalClientError (
467
496
' cannot close prepared statement; refs == {} != 0' .format(
468
497
state.refs))
469
498
470
- self ._close(state.name, False )
471
- state.closed = True
472
- return await self ._new_waiter(timeout)
499
+ timeout = self ._get_timeout_impl(timeout)
500
+ waiter = self ._new_waiter(timeout)
501
+ try :
502
+ self ._close(state.name, False ) # network op
503
+ state.closed = True
504
+ except Exception as ex:
505
+ waiter.set_exception(ex)
506
+ self ._coreproto_error()
507
+ finally :
508
+ return await waiter
473
509
474
510
def is_closed (self ):
475
511
return self .closing
@@ -579,6 +615,17 @@ cdef class BaseProtocol(CoreProtocol):
579
615
raise apg_exc.InterfaceError(
580
616
' cannot perform operation: another operation is in progress' )
581
617
618
+ cdef _coreproto_error(self ):
619
+ try :
620
+ if self .waiter is not None :
621
+ if not self .waiter.done():
622
+ raise apg_exc.InternalClientError(
623
+ ' waiter is not done while handling critical '
624
+ ' protocol error' )
625
+ self .waiter = None
626
+ finally :
627
+ self .abort()
628
+
582
629
cdef _new_waiter(self , timeout):
583
630
if self .waiter is not None :
584
631
raise apg_exc.InterfaceError(
@@ -596,7 +643,7 @@ cdef class BaseProtocol(CoreProtocol):
596
643
cdef _on_result__prepare(self , object waiter):
597
644
if ASYNCPG_DEBUG:
598
645
if self .statement is None :
599
- raise RuntimeError (
646
+ raise apg_exc.InternalClientError (
600
647
' _on_result__prepare: statement is None' )
601
648
602
649
if self .result_param_desc is not None :
@@ -643,7 +690,7 @@ cdef class BaseProtocol(CoreProtocol):
643
690
cdef _decode_row(self , const char * buf, ssize_t buf_len):
644
691
if ASYNCPG_DEBUG:
645
692
if self .statement is None :
646
- raise RuntimeError (
693
+ raise apg_exc.InternalClientError (
647
694
' _decode_row: statement is None' )
648
695
649
696
return self .statement._decode_row(buf, buf_len)
@@ -654,13 +701,13 @@ cdef class BaseProtocol(CoreProtocol):
654
701
655
702
if ASYNCPG_DEBUG:
656
703
if waiter is None :
657
- raise RuntimeError (' _on_result: waiter is None' )
704
+ raise apg_exc.InternalClientError (' _on_result: waiter is None' )
658
705
659
706
if waiter.cancelled():
660
707
return
661
708
662
709
if waiter.done():
663
- raise RuntimeError (' _on_result: waiter is done' )
710
+ raise apg_exc.InternalClientError (' _on_result: waiter is done' )
664
711
665
712
if self .result_type == RESULT_FAILED:
666
713
if isinstance (self .result, dict ):
@@ -704,7 +751,7 @@ cdef class BaseProtocol(CoreProtocol):
704
751
self ._on_result__copy_in(waiter)
705
752
706
753
else :
707
- raise RuntimeError (
754
+ raise apg_exc.InternalClientError (
708
755
' got result for unknown protocol state {}' .
709
756
format(self .state))
710
757
0 commit comments