Skip to content

Commit 16183aa

Browse files
authored
Prefer SSL connections by default (#660)
Switch the default SSL mode from 'disabled' to 'prefer'. This matches libpq's behavior and is a sensible thing to do. Fixes: #654
1 parent ddadce9 commit 16183aa

File tree

3 files changed

+56
-37
lines changed

3 files changed

+56
-37
lines changed

asyncpg/connect_utils.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
380380
passfile=passfile)
381381

382382
addrs = []
383+
have_tcp_addrs = False
383384
for h, p in zip(host, port):
384385
if h.startswith('/'):
385386
# UNIX socket name
@@ -389,6 +390,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
389390
else:
390391
# TCP host/port
391392
addrs.append((h, p))
393+
have_tcp_addrs = True
392394

393395
if not addrs:
394396
raise ValueError(
@@ -397,6 +399,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
397399
if ssl is None:
398400
ssl = os.getenv('PGSSLMODE')
399401

402+
if ssl is None and have_tcp_addrs:
403+
ssl = 'prefer'
404+
400405
# ssl_is_advisory is only allowed to come from the sslmode parameter.
401406
ssl_is_advisory = None
402407
if isinstance(ssl, str):
@@ -435,14 +440,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
435440
if sslmode <= SSLMODES['require']:
436441
ssl.verify_mode = ssl_module.CERT_NONE
437442
ssl_is_advisory = sslmode <= SSLMODES['prefer']
438-
439-
if ssl:
440-
for addr in addrs:
441-
if isinstance(addr, str):
442-
# UNIX socket
443-
raise exceptions.InterfaceError(
444-
'`ssl` parameter can only be enabled for TCP addresses, '
445-
'got a UNIX socket path: {!r}'.format(addr))
443+
elif ssl is True:
444+
ssl = ssl_module.create_default_context()
446445

447446
if server_settings is not None and (
448447
not isinstance(server_settings, dict) or
@@ -542,9 +541,6 @@ def connection_lost(self, exc):
542541
async def _create_ssl_connection(protocol_factory, host, port, *,
543542
loop, ssl_context, ssl_is_advisory=False):
544543

545-
if ssl_context is True:
546-
ssl_context = ssl_module.create_default_context()
547-
548544
tr, pr = await loop.create_connection(
549545
lambda: TLSUpgradeProto(loop, host, port,
550546
ssl_context, ssl_is_advisory),
@@ -625,7 +621,6 @@ async def _connect_addr(
625621

626622
if isinstance(addr, str):
627623
# UNIX socket
628-
assert not params.ssl
629624
connector = loop.create_unix_connection(proto_factory, addr)
630625
elif params.ssl:
631626
connector = _create_ssl_connection(

asyncpg/connection.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,28 @@ async def connect(dsn=None, *,
18691869
Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
18701870
require an SSL connection. If ``True``, a default SSL context
18711871
returned by `ssl.create_default_context() <create_default_context_>`_
1872-
will be used.
1872+
will be used. The value can also be one of the following strings:
1873+
1874+
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
1875+
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
1876+
if SSL connection fails
1877+
- ``'allow'`` - currently equivalent to ``'prefer'``
1878+
- ``'require'`` - only try an SSL connection. Certificate
1879+
verifiction errors are ignored
1880+
- ``'verify-ca'`` - only try an SSL connection, and verify
1881+
that the server certificate is issued by a trusted certificate
1882+
authority (CA)
1883+
- ``'verify-full'`` - only try an SSL connection, verify
1884+
that the server certificate is issued by a trusted CA and
1885+
that the requested server host name matches that in the
1886+
certificate.
1887+
1888+
The default is ``'prefer'``: try an SSL connection and fallback to
1889+
non-SSL connection if that fails.
1890+
1891+
.. note::
1892+
1893+
*ssl* is ignored for Unix domain socket communication.
18731894
18741895
:param dict server_settings:
18751896
An optional dict of server runtime parameters. Refer to
@@ -1926,6 +1947,9 @@ async def connect(dsn=None, *,
19261947
.. versionchanged:: 0.22.0
19271948
Added the *record_class* parameter.
19281949
1950+
.. versionchanged:: 0.22.0
1951+
The *ssl* argument now defaults to ``'prefer'``.
1952+
19291953
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
19301954
.. _create_default_context:
19311955
https://docs.python.org/3/library/ssl.html#ssl.create_default_context

tests/test_connect.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
318318
'result': ([('host', 123)], {
319319
'user': 'user',
320320
'password': 'passw',
321-
'database': 'testdb'})
321+
'database': 'testdb',
322+
'ssl': True,
323+
'ssl_is_advisory': True})
322324
},
323325

324326
{
@@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
384386
'user': 'user3',
385387
'password': '123123',
386388
'database': 'abcdef',
387-
'ssl': ssl.SSLContext,
389+
'ssl': True,
388390
'ssl_is_advisory': True})
389391
},
390392

@@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
461463
'user': 'me',
462464
'password': 'ask',
463465
'database': 'db',
464-
'ssl': ssl.SSLContext,
466+
'ssl': True,
465467
'ssl_is_advisory': False})
466468
},
467469

@@ -545,6 +547,7 @@ class TestConnectParams(tb.TestCase):
545547
{
546548
'user': 'user',
547549
'database': 'user',
550+
'ssl': None
548551
}
549552
)
550553
},
@@ -574,7 +577,9 @@ class TestConnectParams(tb.TestCase):
574577
('localhost', 5433)
575578
], {
576579
'user': 'spam',
577-
'database': 'db'
580+
'database': 'db',
581+
'ssl': True,
582+
'ssl_is_advisory': True
578583
}
579584
)
580585
},
@@ -617,7 +622,7 @@ def run_testcase(self, testcase):
617622
password = testcase.get('password')
618623
passfile = testcase.get('passfile')
619624
database = testcase.get('database')
620-
ssl = testcase.get('ssl')
625+
sslmode = testcase.get('ssl')
621626
server_settings = testcase.get('server_settings')
622627

623628
expected = testcase.get('result')
@@ -640,21 +645,26 @@ def run_testcase(self, testcase):
640645

641646
addrs, params = connect_utils._parse_connect_dsn_and_args(
642647
dsn=dsn, host=host, port=port, user=user, password=password,
643-
passfile=passfile, database=database, ssl=ssl,
648+
passfile=passfile, database=database, ssl=sslmode,
644649
connect_timeout=None, server_settings=server_settings)
645650

646-
params = {k: v for k, v in params._asdict().items()
647-
if v is not None}
651+
params = {
652+
k: v for k, v in params._asdict().items()
653+
if v is not None or (expected is not None and k in expected[1])
654+
}
655+
656+
if isinstance(params.get('ssl'), ssl.SSLContext):
657+
params['ssl'] = True
648658

649659
result = (addrs, params)
650660

651661
if expected is not None:
652-
for k, v in expected[1].items():
653-
# If `expected` contains a type, allow that to "match" any
654-
# instance of that type tyat `result` may contain. We need
655-
# this because different SSLContexts don't compare equal.
656-
if isinstance(v, type) and isinstance(result[1].get(k), v):
657-
result[1][k] = v
662+
if 'ssl' not in expected[1]:
663+
# Avoid the hassle of specifying the default SSL mode
664+
# unless explicitly tested for.
665+
params.pop('ssl', None)
666+
params.pop('ssl_is_advisory', None)
667+
658668
self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))
659669

660670
def test_test_connect_params_environ(self):
@@ -1063,16 +1073,6 @@ async def verify_fails(sslmode):
10631073
await verify_fails('verify-ca')
10641074
await verify_fails('verify-full')
10651075

1066-
async def test_connection_ssl_unix(self):
1067-
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
1068-
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)
1069-
1070-
with self.assertRaisesRegex(asyncpg.InterfaceError,
1071-
'can only be enabled for TCP addresses'):
1072-
await self.connect(
1073-
host='/tmp',
1074-
ssl=ssl_context)
1075-
10761076
async def test_connection_implicit_host(self):
10771077
conn_spec = self.get_connection_spec()
10781078
con = await asyncpg.connect(

0 commit comments

Comments
 (0)