Skip to content

Commit c993737

Browse files
committed
Add sslmode=allow support and fix =prefer retry
We didn't really retry the connection without SSL if the first SSL connection fails under sslmode=prefer, that led to an issue when the server has SSL support but explicitly denies SSL connection through pg_hba.conf. This commit adds a retry in a new connection, which makes it easy to implement the sslmode=allow retry. Fixes #716
1 parent 53bea98 commit c993737

File tree

5 files changed

+243
-49
lines changed

5 files changed

+243
-49
lines changed

asyncpg/connect_utils.py

+101-30
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'password',
3636
'database',
3737
'ssl',
38-
'ssl_is_advisory',
38+
'alt_retry_ssl_first',
3939
'connect_timeout',
4040
'server_settings',
4141
])
@@ -402,8 +402,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402402
if ssl is None and have_tcp_addrs:
403403
ssl = 'prefer'
404404

405-
# ssl_is_advisory is only allowed to come from the sslmode parameter.
406-
ssl_is_advisory = None
405+
# alt_retry_ssl_first is particularly for "allow" and "prefer"
406+
# to alternatively try SSL/non-SSL connections (once each if supported):
407+
# False - allow (try non-SSL first)
408+
# True - prefer (try SSL first)
409+
# None - other (don't retry, stick with the "ssl" parameter)
410+
alt_retry_ssl_first = None
411+
407412
if isinstance(ssl, str):
408413
SSLMODES = {
409414
'disable': 0,
@@ -420,26 +425,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
420425
raise exceptions.InterfaceError(
421426
'`sslmode` parameter must be one of: {}'.format(modes))
422427

423-
# sslmode 'allow' is currently handled as 'prefer' because we're
424-
# missing the "retry with SSL" behavior for 'allow', but do have the
425-
# "retry without SSL" behavior for 'prefer'.
426-
# Not changing 'allow' to 'prefer' here would be effectively the same
427-
# as changing 'allow' to 'disable'.
428428
if sslmode == SSLMODES['allow']:
429-
sslmode = SSLMODES['prefer']
429+
alt_retry_ssl_first = False
430+
elif sslmode == SSLMODES['prefer']:
431+
alt_retry_ssl_first = True
430432

431433
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432434
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433-
if sslmode <= SSLMODES['allow']:
435+
if sslmode < SSLMODES['allow']:
434436
ssl = False
435-
ssl_is_advisory = sslmode >= SSLMODES['allow']
436437
else:
437438
ssl = ssl_module.create_default_context()
438439
ssl.check_hostname = sslmode >= SSLMODES['verify-full']
439440
ssl.verify_mode = ssl_module.CERT_REQUIRED
440441
if sslmode <= SSLMODES['require']:
441442
ssl.verify_mode = ssl_module.CERT_NONE
442-
ssl_is_advisory = sslmode <= SSLMODES['prefer']
443443
elif ssl is True:
444444
ssl = ssl_module.create_default_context()
445445

@@ -453,7 +453,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453453

454454
params = _ConnectionParameters(
455455
user=user, password=password, database=database, ssl=ssl,
456-
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
456+
alt_retry_ssl_first=alt_retry_ssl_first,
457+
connect_timeout=connect_timeout,
457458
server_settings=server_settings)
458459

459460
return addrs, params
@@ -520,9 +521,8 @@ def data_received(self, data):
520521
data == b'N'):
521522
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522523
# since the only way to get ssl_is_advisory is from
523-
# sslmode=prefer (or sslmode=allow). But be extra sure to
524-
# disallow insecure connections when the ssl context asks for
525-
# real security.
524+
# sslmode=prefer. But be extra sure to disallow insecure
525+
# connections when the ssl context asks for real security.
526526
self.on_data.set_result(False)
527527
else:
528528
self.on_data.set_exception(
@@ -566,6 +566,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566566
new_tr = tr
567567

568568
pg_proto = protocol_factory()
569+
pg_proto.is_ssl = do_ssl_upgrade
569570
pg_proto.connection_made(new_tr)
570571
new_tr.set_protocol(pg_proto)
571572

@@ -584,7 +585,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584585
tr.close()
585586

586587
try:
587-
return await conn_factory(sock=sock)
588+
new_tr, pg_proto = await conn_factory(sock=sock)
589+
pg_proto.is_ssl = do_ssl_upgrade
590+
return new_tr, pg_proto
588591
except (Exception, asyncio.CancelledError):
589592
sock.close()
590593
raise
@@ -605,8 +608,6 @@ async def _connect_addr(
605608
if timeout <= 0:
606609
raise asyncio.TimeoutError
607610

608-
connected = _create_future(loop)
609-
610611
params_input = params
611612
if callable(params.password):
612613
if inspect.iscoroutinefunction(params.password):
@@ -615,6 +616,44 @@ async def _connect_addr(
615616
password = params.password()
616617

617618
params = params._replace(password=password)
619+
args = (addr, loop, config, connection_class, record_class, params_input)
620+
621+
# skip retry if alt_retry is not enabled
622+
if params.alt_retry_ssl_first is None:
623+
return await __connect_addr(params, timeout, *args)
624+
625+
# prepare the params (which attempt has ssl) for the 2 attempts
626+
params_retry = params._replace(ssl=None)
627+
if not params.alt_retry_ssl_first:
628+
params, params_retry = params_retry, params
629+
630+
# first attempt
631+
before = time.monotonic()
632+
try:
633+
return await __connect_addr(params, timeout, *args)
634+
except ConnectionError:
635+
pass
636+
637+
# the second attempt with alt_retry_ssl_first=None
638+
timeout -= time.monotonic() - before
639+
if timeout <= 0:
640+
raise asyncio.TimeoutError
641+
else:
642+
params_retry = params_retry._replace(alt_retry_ssl_first=None)
643+
return await __connect_addr(params_retry, timeout, *args)
644+
645+
646+
async def __connect_addr(
647+
params,
648+
timeout,
649+
addr,
650+
loop,
651+
config,
652+
connection_class,
653+
record_class,
654+
params_input,
655+
):
656+
connected = _create_future(loop)
618657

619658
proto_factory = lambda: protocol.Protocol(
620659
addr, connected, params, record_class, loop)
@@ -625,7 +664,7 @@ async def _connect_addr(
625664
elif params.ssl:
626665
connector = _create_ssl_connection(
627666
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
628-
ssl_is_advisory=params.ssl_is_advisory)
667+
ssl_is_advisory=params.alt_retry_ssl_first)
629668
else:
630669
connector = loop.create_connection(proto_factory, *addr)
631670

@@ -638,6 +677,23 @@ async def _connect_addr(
638677
if timeout <= 0:
639678
raise asyncio.TimeoutError
640679
await compat.wait_for(connected, timeout=timeout)
680+
except exceptions.InvalidAuthorizationSpecificationError:
681+
tr.close()
682+
683+
# pr.is_ssl is a bool, so this equal test implies
684+
# alt_retry_ssl_first is not None (should do alt_retry)
685+
if params.alt_retry_ssl_first == pr.is_ssl:
686+
# Elevate the error to ConnectionError to trigger retry
687+
raise ConnectionError("Connection rejected trying {} SSL".format(
688+
'with' if pr.is_ssl else 'without'))
689+
690+
else:
691+
# Don't retry if alt_retry_ssl_first is None, or we don't need to
692+
# (alt_retry_ssl_first=True and pr.is_ssl=False means the server
693+
# doesn't support SSL, and we've already tried to Startup without
694+
# SSL but failed; The opposite case doesn't exist).
695+
raise
696+
641697
except (Exception, asyncio.CancelledError):
642698
tr.close()
643699
raise
@@ -684,6 +740,7 @@ class CancelProto(asyncio.Protocol):
684740

685741
def __init__(self):
686742
self.on_disconnect = _create_future(loop)
743+
self.is_ssl = False
687744

688745
def connection_lost(self, exc):
689746
if not self.on_disconnect.done():
@@ -692,17 +749,31 @@ def connection_lost(self, exc):
692749
if isinstance(addr, str):
693750
tr, pr = await loop.create_unix_connection(CancelProto, addr)
694751
else:
695-
if params.ssl:
696-
tr, pr = await _create_ssl_connection(
697-
CancelProto,
698-
*addr,
699-
loop=loop,
700-
ssl_context=params.ssl,
701-
ssl_is_advisory=params.ssl_is_advisory)
752+
async def _connect(params_in, ssl_is_advisory):
753+
if params_in.ssl:
754+
return await _create_ssl_connection(
755+
CancelProto,
756+
*addr,
757+
loop=loop,
758+
ssl_context=params_in.ssl,
759+
ssl_is_advisory=ssl_is_advisory)
760+
else:
761+
rv = await loop.create_connection(
762+
CancelProto, *addr)
763+
_set_nodelay(_get_socket(rv[0]))
764+
return rv
765+
766+
if params.alt_retry_ssl_first is None:
767+
tr, pr = await _connect(params, False)
702768
else:
703-
tr, pr = await loop.create_connection(
704-
CancelProto, *addr)
705-
_set_nodelay(_get_socket(tr))
769+
params_retry = params._replace(ssl=None)
770+
if not params.alt_retry_ssl_first:
771+
params, params_retry = params_retry, params
772+
try:
773+
tr, pr = await _connect(params, True)
774+
except ConnectionError:
775+
tr, pr = await _connect(
776+
params._replace(alt_retry_ssl_first=None), False)
706777

707778
# Pack a CancelRequest message
708779
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

asyncpg/connection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,8 @@ async def connect(dsn=None, *,
18791879
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
18801880
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
18811881
if SSL connection fails
1882-
- ``'allow'`` - currently equivalent to ``'prefer'``
1882+
- ``'allow'`` - try without SSL first, then retry with SSL if the first
1883+
attempt fails.
18831884
- ``'require'`` - only try an SSL connection. Certificate
18841885
verification errors are ignored
18851886
- ``'verify-ca'`` - only try an SSL connection, and verify

asyncpg/protocol/protocol.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol):
5252

5353
readonly uint64_t queries_count
5454

55+
bint _is_ssl
56+
5557
PreparedStatementState statement
5658

5759
cdef get_connection(self)

asyncpg/protocol/protocol.pyx

+10
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol):
103103

104104
self.queries_count = 0
105105

106+
self._is_ssl = False
107+
106108
try:
107109
self.create_future = loop.create_future
108110
except AttributeError:
@@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol):
943945
def resume_writing(self):
944946
self.writing_allowed.set()
945947

948+
@property
949+
def is_ssl(self):
950+
return self._is_ssl
951+
952+
@is_ssl.setter
953+
def is_ssl(self, value):
954+
self._is_ssl = value
955+
946956

947957
class Timer:
948958
def __init__(self, budget):

0 commit comments

Comments
 (0)