Skip to content

Commit a19ce50

Browse files
vangheemelprans
authored andcommitted
Connection.prepare should not use statement cache
1 parent 458cf05 commit a19ce50

File tree

3 files changed

+51
-35
lines changed

3 files changed

+51
-35
lines changed

asyncpg/connection.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -264,19 +264,21 @@ async def executemany(self, command: str, args, *, timeout: float=None):
264264
self._check_open()
265265
return await self._executemany(command, args, timeout)
266266

267-
async def _get_statement(self, query, timeout, *, named: bool=False):
268-
statement = self._stmt_cache.get(query)
269-
if statement is not None:
270-
return statement
271-
272-
# Only use the cache when:
273-
# * `statement_cache_size` is greater than 0;
274-
# * query size is less than `max_cacheable_statement_size`.
275-
use_cache = self._stmt_cache.get_max_size() > 0
276-
if (use_cache and
277-
self._config.max_cacheable_statement_size and
278-
len(query) > self._config.max_cacheable_statement_size):
279-
use_cache = False
267+
async def _get_statement(self, query, timeout, *, named: bool=False,
268+
use_cache: bool=True):
269+
if use_cache:
270+
statement = self._stmt_cache.get(query)
271+
if statement is not None:
272+
return statement
273+
274+
# Only use the cache when:
275+
# * `statement_cache_size` is greater than 0;
276+
# * query size is less than `max_cacheable_statement_size`.
277+
use_cache = self._stmt_cache.get_max_size() > 0
278+
if (use_cache and
279+
self._config.max_cacheable_statement_size and
280+
len(query) > self._config.max_cacheable_statement_size):
281+
use_cache = False
280282

281283
if use_cache or named:
282284
stmt_name = self._get_unique_id('stmt')
@@ -328,8 +330,12 @@ async def prepare(self, query, *, timeout=None):
328330
329331
:return: A :class:`~prepared_stmt.PreparedStatement` instance.
330332
"""
333+
return await self._prepare(query, timeout=timeout, use_cache=False)
334+
335+
async def _prepare(self, query, *, timeout=None, use_cache: bool=False):
331336
self._check_open()
332-
stmt = await self._get_statement(query, timeout, named=True)
337+
stmt = await self._get_statement(query, timeout, named=True,
338+
use_cache=use_cache)
333339
return prepared_stmt.PreparedStatement(self, query, stmt)
334340

335341
async def fetch(self, query, *args, timeout=None) -> list:
@@ -645,7 +651,7 @@ async def copy_records_to_table(self, table_name, *, records,
645651
intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format(
646652
tab=tabname, cols=col_list)
647653

648-
intro_ps = await self.prepare(intro_query)
654+
intro_ps = await self._prepare(intro_query, use_cache=True)
649655

650656
opts = '(FORMAT binary)'
651657

tests/test_cache_invalidation.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ async def test_prepared_type_cache_invalidation(self):
151151

152152
try:
153153
await self.con.execute('INSERT INTO tab1 VALUES (1, (2, 3))')
154-
prep = await self.con.prepare('SELECT * FROM tab1')
154+
prep = await self.con._prepare('SELECT * FROM tab1',
155+
use_cache=True)
155156
result = await prep.fetchrow()
156157
self.assertEqual(result, (1, (2, 3)))
157158

@@ -170,7 +171,8 @@ async def test_prepared_type_cache_invalidation(self):
170171
'the prepared statement is closed'):
171172
await prep.fetchrow()
172173

173-
prep = await self.con.prepare('SELECT * FROM tab1')
174+
prep = await self.con._prepare('SELECT * FROM tab1',
175+
use_cache=True)
174176
# The second PS must be correct (cache was dropped):
175177
result = await prep.fetchrow()
176178
self.assertEqual(result, (1, (2, 3, None)))
@@ -183,7 +185,8 @@ async def test_prepared_type_cache_invalidation(self):
183185
await prep.fetchrow()
184186

185187
# Reprepare it again after dropping cache.
186-
prep = await self.con.prepare('SELECT * FROM tab1')
188+
prep = await self.con._prepare('SELECT * FROM tab1',
189+
use_cache=True)
187190
# This is now OK, the cache is filled after being dropped.
188191
result = await prep.fetchrow()
189192
self.assertEqual(result, (1, (2, 3)))

tests/test_prepare.py

+24-17
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def test_prepare_10_stmt_lru(self):
153153

154154
stmts = []
155155
for i in range(iter_max):
156-
s = await self.con.prepare(query.format(i))
156+
s = await self.con._prepare(query.format(i), use_cache=True)
157157
self.assertEqual(await s.fetchval(), i)
158158
stmts.append(s)
159159

@@ -207,7 +207,7 @@ async def test_prepare_11_stmt_gc(self):
207207
# The prepared statement that we'll create will be GCed
208208
# right await. However, its state should be still in
209209
# in the statements LRU cache.
210-
await self.con.prepare('select 1')
210+
await self.con._prepare('select 1', use_cache=True)
211211
gc.collect()
212212

213213
self.assertEqual(len(cache), 1)
@@ -224,12 +224,12 @@ async def test_prepare_12_stmt_gc(self):
224224
self.assertEqual(len(cache), 0)
225225
self.assertEqual(len(self.con._stmts_to_close), 0)
226226

227-
stmt = await self.con.prepare('select 100000000')
227+
stmt = await self.con._prepare('select 100000000', use_cache=True)
228228
self.assertEqual(len(cache), 1)
229229
self.assertEqual(len(self.con._stmts_to_close), 0)
230230

231231
for i in range(cache_max):
232-
await self.con.prepare('select {}'.format(i))
232+
await self.con._prepare('select {}'.format(i), use_cache=True)
233233

234234
self.assertEqual(len(cache), cache_max)
235235
self.assertEqual(len(self.con._stmts_to_close), 0)
@@ -293,7 +293,7 @@ async def test_prepare_15_stmt_gc_cache_disabled(self):
293293
# Disable cache
294294
cache.set_max_size(0)
295295

296-
stmt = await self.con.prepare('select 100000000')
296+
stmt = await self.con._prepare('select 100000000', use_cache=True)
297297
self.assertEqual(len(cache), 0)
298298
self.assertEqual(len(self.con._stmts_to_close), 0)
299299

@@ -305,7 +305,7 @@ async def test_prepare_15_stmt_gc_cache_disabled(self):
305305
self.assertEqual(len(self.con._stmts_to_close), 1)
306306

307307
# Next "prepare" call will trigger a cleanup
308-
stmt = await self.con.prepare('select 1')
308+
stmt = await self.con._prepare('select 1', use_cache=True)
309309
self.assertEqual(len(cache), 0)
310310
self.assertEqual(len(self.con._stmts_to_close), 0)
311311

@@ -468,25 +468,25 @@ async def test_prepare_24_max_lifetime(self):
468468
self.assertEqual(cache.get_max_lifetime(), 142)
469469
cache.set_max_lifetime(1)
470470

471-
s = await self.con.prepare('SELECT 1')
471+
s = await self.con._prepare('SELECT 1', use_cache=True)
472472
state = s._state
473473

474-
s = await self.con.prepare('SELECT 1')
474+
s = await self.con._prepare('SELECT 1', use_cache=True)
475475
self.assertIs(s._state, state)
476476

477-
s = await self.con.prepare('SELECT 1')
477+
s = await self.con._prepare('SELECT 1', use_cache=True)
478478
self.assertIs(s._state, state)
479479

480480
await asyncio.sleep(1, loop=self.loop)
481481

482-
s = await self.con.prepare('SELECT 1')
482+
s = await self.con._prepare('SELECT 1', use_cache=True)
483483
self.assertIsNot(s._state, state)
484484

485485
@tb.with_connection_options(max_cached_statement_lifetime=0.5)
486486
async def test_prepare_25_max_lifetime_reset(self):
487487
cache = self.con._stmt_cache
488488

489-
s = await self.con.prepare('SELECT 1')
489+
s = await self.con._prepare('SELECT 1', use_cache=True)
490490
state = s._state
491491

492492
# Disable max_lifetime
@@ -495,20 +495,20 @@ async def test_prepare_25_max_lifetime_reset(self):
495495
await asyncio.sleep(1, loop=self.loop)
496496

497497
# The statement should still be cached (as we disabled the timeout).
498-
s = await self.con.prepare('SELECT 1')
498+
s = await self.con._prepare('SELECT 1', use_cache=True)
499499
self.assertIs(s._state, state)
500500

501501
@tb.with_connection_options(max_cached_statement_lifetime=0.5)
502502
async def test_prepare_26_max_lifetime_max_size(self):
503503
cache = self.con._stmt_cache
504504

505-
s = await self.con.prepare('SELECT 1')
505+
s = await self.con._prepare('SELECT 1', use_cache=True)
506506
state = s._state
507507

508508
# Disable max_lifetime
509509
cache.set_max_size(0)
510510

511-
s = await self.con.prepare('SELECT 1')
511+
s = await self.con._prepare('SELECT 1', use_cache=True)
512512
self.assertIsNot(s._state, state)
513513

514514
# Check that nothing crashes after the initial timeout
@@ -518,12 +518,12 @@ async def test_prepare_26_max_lifetime_max_size(self):
518518
async def test_prepare_27_max_cacheable_statement_size(self):
519519
cache = self.con._stmt_cache
520520

521-
await self.con.prepare('SELECT 1')
521+
await self.con._prepare('SELECT 1', use_cache=True)
522522
self.assertEqual(len(cache), 1)
523523

524524
# Test that long and explicitly created prepared statements
525525
# are not cached.
526-
await self.con.prepare("SELECT \'" + "a" * 50 + "\'")
526+
await self.con._prepare("SELECT \'" + "a" * 50 + "\'", use_cache=True)
527527
self.assertEqual(len(cache), 1)
528528

529529
# Test that implicitly created long prepared statements
@@ -532,7 +532,7 @@ async def test_prepare_27_max_cacheable_statement_size(self):
532532
self.assertEqual(len(cache), 1)
533533

534534
# Test that short prepared statements can still be cached.
535-
await self.con.prepare('SELECT 2')
535+
await self.con._prepare('SELECT 2', use_cache=True)
536536
self.assertEqual(len(cache), 2)
537537

538538
async def test_prepare_28_max_args(self):
@@ -593,3 +593,10 @@ async def test_prepare_31_pgbouncer_note(self):
593593
self.assertTrue('pgbouncer' in e.hint)
594594
else:
595595
self.fail('InvalidSQLStatementNameError not raised')
596+
597+
async def test_prepare_does_not_use_cache(self):
598+
cache = self.con._stmt_cache
599+
600+
# prepare with disabled cache
601+
await self.con.prepare('select 1')
602+
self.assertEqual(len(cache), 0)

0 commit comments

Comments
 (0)