21
21
from asyncpg import connection as pg_connection
22
22
from asyncpg import pool as pg_pool
23
23
24
+ from . import fuzzer
25
+
24
26
25
27
@contextlib .contextmanager
26
28
def silence_asyncio_long_exec_warning ():
@@ -36,7 +38,16 @@ def flt(log_record):
36
38
logger .removeFilter (flt )
37
39
38
40
41
+ def with_timeout (timeout ):
42
+ def wrap (func ):
43
+ func .__timeout__ = timeout
44
+ return func
45
+
46
+ return wrap
47
+
48
+
39
49
class TestCaseMeta (type (unittest .TestCase )):
50
+ TEST_TIMEOUT = None
40
51
41
52
@staticmethod
42
53
def _iter_methods (bases , ns ):
@@ -64,7 +75,18 @@ def __new__(mcls, name, bases, ns):
64
75
for methname , meth in mcls ._iter_methods (bases , ns ):
65
76
@functools .wraps (meth )
66
77
def wrapper (self , * args , __meth__ = meth , ** kwargs ):
67
- self .loop .run_until_complete (__meth__ (self , * args , ** kwargs ))
78
+ coro = __meth__ (self , * args , ** kwargs )
79
+ timeout = getattr (__meth__ , '__timeout__' , mcls .TEST_TIMEOUT )
80
+ if timeout :
81
+ coro = asyncio .wait_for (coro , timeout , loop = self .loop )
82
+ try :
83
+ self .loop .run_until_complete (coro )
84
+ except asyncio .TimeoutError :
85
+ raise self .failureException (
86
+ 'test timed out after {} seconds' .format (
87
+ timeout )) from None
88
+ else :
89
+ self .loop .run_until_complete (coro )
68
90
ns [methname ] = wrapper
69
91
70
92
return super ().__new__ (mcls , name , bases , ns )
@@ -169,7 +191,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
169
191
170
192
171
193
def _shutdown_cluster (cluster ):
172
- cluster .stop ()
194
+ if cluster .get_status () == 'running' :
195
+ cluster .stop ()
173
196
cluster .destroy ()
174
197
175
198
@@ -220,9 +243,11 @@ def get_connection_spec(cls, kwargs={}):
220
243
conn_spec ['user' ] = 'postgres'
221
244
return conn_spec
222
245
223
- def create_pool (self , pool_class = pg_pool .Pool , ** kwargs ):
246
+ def create_pool (self , pool_class = pg_pool .Pool ,
247
+ connection_class = pg_connection .Connection , ** kwargs ):
224
248
conn_spec = self .get_connection_spec (kwargs )
225
- return create_pool (loop = self .loop , pool_class = pool_class , ** conn_spec )
249
+ return create_pool (loop = self .loop , pool_class = pool_class ,
250
+ connection_class = connection_class , ** conn_spec )
226
251
227
252
@classmethod
228
253
def connect (cls , ** kwargs ):
@@ -238,6 +263,49 @@ def start_cluster(cls, ClusterCls, *,
238
263
server_settings , _get_initdb_options (initdb_options ))
239
264
240
265
266
+ class ProxiedClusterTestCase (ClusterTestCase ):
267
+ @classmethod
268
+ def get_server_settings (cls ):
269
+ settings = dict (super ().get_server_settings ())
270
+ settings ['listen_addresses' ] = '127.0.0.1'
271
+ return settings
272
+
273
+ @classmethod
274
+ def get_proxy_settings (cls ):
275
+ return {'fuzzing-mode' : None }
276
+
277
+ @classmethod
278
+ def setUpClass (cls ):
279
+ super ().setUpClass ()
280
+ conn_spec = cls .cluster .get_connection_spec ()
281
+ host = conn_spec .get ('host' )
282
+ if not host :
283
+ host = '127.0.0.1'
284
+ elif host .startswith ('/' ):
285
+ host = '127.0.0.1'
286
+ cls .proxy = fuzzer .TCPFuzzingProxy (
287
+ backend_host = host ,
288
+ backend_port = conn_spec ['port' ],
289
+ )
290
+ cls .proxy .start ()
291
+
292
+ @classmethod
293
+ def tearDownClass (cls ):
294
+ cls .proxy .stop ()
295
+ super ().tearDownClass ()
296
+
297
+ @classmethod
298
+ def get_connection_spec (cls , kwargs ):
299
+ conn_spec = super ().get_connection_spec (kwargs )
300
+ conn_spec ['host' ] = cls .proxy .listening_addr
301
+ conn_spec ['port' ] = cls .proxy .listening_port
302
+ return conn_spec
303
+
304
+ def tearDown (self ):
305
+ self .proxy .reset ()
306
+ super ().tearDown ()
307
+
308
+
241
309
def with_connection_options (** options ):
242
310
if not options :
243
311
raise ValueError ('no connection options were specified' )
0 commit comments