11
11
import collections .abc
12
12
import functools
13
13
import itertools
14
+ import inspect
14
15
import os
15
16
import sys
16
17
import time
17
18
import traceback
19
+ import typing
18
20
import warnings
19
21
import weakref
20
22
@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133
135
:param str channel: Channel to listen on.
134
136
135
137
:param callable callback:
136
- A callable receiving the following arguments:
138
+ A callable or a coroutine function receiving the following
139
+ arguments:
137
140
**connection**: a Connection the callback is registered with;
138
141
**pid**: PID of the Postgres server that sent the notification;
139
142
**channel**: name of the channel the notification was sent to;
140
143
**payload**: the payload.
144
+
145
+ .. versionchanged:: 0.24.0
146
+ The ``callback`` argument may be a coroutine function.
141
147
"""
142
148
self ._check_open ()
143
149
if channel not in self ._listeners :
144
150
await self .fetch ('LISTEN {}' .format (utils ._quote_ident (channel )))
145
151
self ._listeners [channel ] = set ()
146
- self ._listeners [channel ].add (callback )
152
+ self ._listeners [channel ].add (_Callback . from_callable ( callback ) )
147
153
148
154
async def remove_listener (self , channel , callback ):
149
155
"""Remove a listening callback on the specified channel."""
150
156
if self .is_closed ():
151
157
return
152
158
if channel not in self ._listeners :
153
159
return
154
- if callback not in self ._listeners [channel ]:
160
+ cb = _Callback .from_callable (callback )
161
+ if cb not in self ._listeners [channel ]:
155
162
return
156
- self ._listeners [channel ].remove (callback )
163
+ self ._listeners [channel ].remove (cb )
157
164
if not self ._listeners [channel ]:
158
165
del self ._listeners [channel ]
159
166
await self .fetch ('UNLISTEN {}' .format (utils ._quote_ident (channel )))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166
173
DEBUG, INFO, or LOG.
167
174
168
175
:param callable callback:
169
- A callable receiving the following arguments:
176
+ A callable or a coroutine function receiving the following
177
+ arguments:
170
178
**connection**: a Connection the callback is registered with;
171
179
**message**: the `exceptions.PostgresLogMessage` message.
172
180
173
181
.. versionadded:: 0.12.0
182
+
183
+ .. versionchanged:: 0.24.0
184
+ The ``callback`` argument may be a coroutine function.
174
185
"""
175
186
if self .is_closed ():
176
187
raise exceptions .InterfaceError ('connection is closed' )
177
- self ._log_listeners .add (callback )
188
+ self ._log_listeners .add (_Callback . from_callable ( callback ) )
178
189
179
190
def remove_log_listener (self , callback ):
180
191
"""Remove a listening callback for log messages.
181
192
182
193
.. versionadded:: 0.12.0
183
194
"""
184
- self ._log_listeners .discard (callback )
195
+ self ._log_listeners .discard (_Callback . from_callable ( callback ) )
185
196
186
197
def add_termination_listener (self , callback ):
187
198
"""Add a listener that will be called when the connection is closed.
188
199
189
200
:param callable callback:
190
- A callable receiving one argument:
201
+ A callable or a coroutine function receiving one argument:
191
202
**connection**: a Connection the callback is registered with.
192
203
193
204
.. versionadded:: 0.21.0
205
+
206
+ .. versionchanged:: 0.24.0
207
+ The ``callback`` argument may be a coroutine function.
194
208
"""
195
- self ._termination_listeners .add (callback )
209
+ self ._termination_listeners .add (_Callback . from_callable ( callback ) )
196
210
197
211
def remove_termination_listener (self , callback ):
198
212
"""Remove a listening callback for connection termination.
199
213
200
214
:param callable callback:
201
- The callable that was passed to
215
+ The callable or coroutine function that was passed to
202
216
:meth:`Connection.add_termination_listener`.
203
217
204
218
.. versionadded:: 0.21.0
205
219
"""
206
- self ._termination_listeners .discard (callback )
220
+ self ._termination_listeners .discard (_Callback . from_callable ( callback ) )
207
221
208
222
def get_server_pid (self ):
209
223
"""Return the PID of the Postgres server the connection is bound to."""
@@ -1430,35 +1444,21 @@ def _process_log_message(self, fields, last_query):
1430
1444
1431
1445
con_ref = self ._unwrap ()
1432
1446
for cb in self ._log_listeners :
1433
- self ._loop .call_soon (
1434
- self ._call_log_listener , cb , con_ref , message )
1435
-
1436
- def _call_log_listener (self , cb , con_ref , message ):
1437
- try :
1438
- cb (con_ref , message )
1439
- except Exception as ex :
1440
- self ._loop .call_exception_handler ({
1441
- 'message' : 'Unhandled exception in asyncpg log message '
1442
- 'listener callback {!r}' .format (cb ),
1443
- 'exception' : ex
1444
- })
1447
+ if cb .is_async :
1448
+ self ._loop .create_task (cb .cb (con_ref , message ))
1449
+ else :
1450
+ self ._loop .call_soon (cb .cb , con_ref , message )
1445
1451
1446
1452
def _call_termination_listeners (self ):
1447
1453
if not self ._termination_listeners :
1448
1454
return
1449
1455
1450
1456
con_ref = self ._unwrap ()
1451
1457
for cb in self ._termination_listeners :
1452
- try :
1453
- cb (con_ref )
1454
- except Exception as ex :
1455
- self ._loop .call_exception_handler ({
1456
- 'message' : (
1457
- 'Unhandled exception in asyncpg connection '
1458
- 'termination listener callback {!r}' .format (cb )
1459
- ),
1460
- 'exception' : ex
1461
- })
1458
+ if cb .is_async :
1459
+ self ._loop .create_task (cb .cb (con_ref ))
1460
+ else :
1461
+ self ._loop .call_soon (cb .cb , con_ref )
1462
1462
1463
1463
self ._termination_listeners .clear ()
1464
1464
@@ -1468,18 +1468,10 @@ def _process_notification(self, pid, channel, payload):
1468
1468
1469
1469
con_ref = self ._unwrap ()
1470
1470
for cb in self ._listeners [channel ]:
1471
- self ._loop .call_soon (
1472
- self ._call_listener , cb , con_ref , pid , channel , payload )
1473
-
1474
- def _call_listener (self , cb , con_ref , pid , channel , payload ):
1475
- try :
1476
- cb (con_ref , pid , channel , payload )
1477
- except Exception as ex :
1478
- self ._loop .call_exception_handler ({
1479
- 'message' : 'Unhandled exception in asyncpg notification '
1480
- 'listener callback {!r}' .format (cb ),
1481
- 'exception' : ex
1482
- })
1471
+ if cb .is_async :
1472
+ self ._loop .create_task (cb .cb (con_ref , pid , channel , payload ))
1473
+ else :
1474
+ self ._loop .call_soon (cb .cb , con_ref , pid , channel , payload )
1483
1475
1484
1476
def _unwrap (self ):
1485
1477
if self ._proxy is None :
@@ -2154,6 +2146,26 @@ def _maybe_cleanup(self):
2154
2146
self ._on_remove (old_entry ._statement )
2155
2147
2156
2148
2149
+ class _Callback (typing .NamedTuple ):
2150
+
2151
+ cb : typing .Callable [..., None ]
2152
+ is_async : bool
2153
+
2154
+ @classmethod
2155
+ def from_callable (cls , cb : typing .Callable [..., None ]) -> '_Callback' :
2156
+ if inspect .iscoroutinefunction (cb ):
2157
+ is_async = True
2158
+ elif callable (cb ):
2159
+ is_async = False
2160
+ else :
2161
+ raise exceptions .InterfaceError (
2162
+ 'expected a callable or an `async def` function,'
2163
+ 'got {!r}' .format (cb )
2164
+ )
2165
+
2166
+ return cls (cb , is_async )
2167
+
2168
+
2157
2169
class _Atomic :
2158
2170
__slots__ = ('_acquired' ,)
2159
2171
0 commit comments