Skip to content

Prefer SSL connections by default #660

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
passfile=passfile)

addrs = []
have_tcp_addrs = False
for h, p in zip(host, port):
if h.startswith('/'):
# UNIX socket name
Expand All @@ -389,6 +390,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
# TCP host/port
addrs.append((h, p))
have_tcp_addrs = True

if not addrs:
raise ValueError(
Expand All @@ -397,6 +399,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None:
ssl = os.getenv('PGSSLMODE')

if ssl is None and have_tcp_addrs:
ssl = 'prefer'

# ssl_is_advisory is only allowed to come from the sslmode parameter.
ssl_is_advisory = None
if isinstance(ssl, str):
Expand Down Expand Up @@ -435,14 +440,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if sslmode <= SSLMODES['require']:
ssl.verify_mode = ssl_module.CERT_NONE
ssl_is_advisory = sslmode <= SSLMODES['prefer']

if ssl:
for addr in addrs:
if isinstance(addr, str):
# UNIX socket
raise exceptions.InterfaceError(
'`ssl` parameter can only be enabled for TCP addresses, '
'got a UNIX socket path: {!r}'.format(addr))
elif ssl is True:
ssl = ssl_module.create_default_context()

if server_settings is not None and (
not isinstance(server_settings, dict) or
Expand Down Expand Up @@ -542,9 +541,6 @@ def connection_lost(self, exc):
async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):

if ssl_context is True:
ssl_context = ssl_module.create_default_context()

tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
ssl_context, ssl_is_advisory),
Expand Down Expand Up @@ -625,7 +621,6 @@ async def _connect_addr(

if isinstance(addr, str):
# UNIX socket
assert not params.ssl
connector = loop.create_unix_connection(proto_factory, addr)
elif params.ssl:
connector = _create_ssl_connection(
Expand Down
26 changes: 25 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,28 @@ async def connect(dsn=None, *,
Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
require an SSL connection. If ``True``, a default SSL context
returned by `ssl.create_default_context() <create_default_context_>`_
will be used.
will be used. The value can also be one of the following strings:

- ``'disable'`` - SSL is disabled (equivalent to ``False``)
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
if SSL connection fails
- ``'allow'`` - currently equivalent to ``'prefer'``
- ``'require'`` - only try an SSL connection. Certificate
verifiction errors are ignored
- ``'verify-ca'`` - only try an SSL connection, and verify
that the server certificate is issued by a trusted certificate
authority (CA)
- ``'verify-full'`` - only try an SSL connection, verify
that the server certificate is issued by a trusted CA and
that the requested server host name matches that in the
certificate.

The default is ``'prefer'``: try an SSL connection and fallback to
non-SSL connection if that fails.

.. note::

*ssl* is ignored for Unix domain socket communication.

:param dict server_settings:
An optional dict of server runtime parameters. Refer to
Expand Down Expand Up @@ -1926,6 +1947,9 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.

.. versionchanged:: 0.22.0
The *ssl* argument now defaults to ``'prefer'``.

.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
Expand Down
48 changes: 24 additions & 24 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
'result': ([('host', 123)], {
'user': 'user',
'password': 'passw',
'database': 'testdb'})
'database': 'testdb',
'ssl': True,
'ssl_is_advisory': True})
},

{
Expand Down Expand Up @@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
'user': 'user3',
'password': '123123',
'database': 'abcdef',
'ssl': ssl.SSLContext,
'ssl': True,
'ssl_is_advisory': True})
},

Expand Down Expand Up @@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
'user': 'me',
'password': 'ask',
'database': 'db',
'ssl': ssl.SSLContext,
'ssl': True,
'ssl_is_advisory': False})
},

Expand Down Expand Up @@ -545,6 +547,7 @@ class TestConnectParams(tb.TestCase):
{
'user': 'user',
'database': 'user',
'ssl': None
}
)
},
Expand Down Expand Up @@ -574,7 +577,9 @@ class TestConnectParams(tb.TestCase):
('localhost', 5433)
], {
'user': 'spam',
'database': 'db'
'database': 'db',
'ssl': True,
'ssl_is_advisory': True
}
)
},
Expand Down Expand Up @@ -617,7 +622,7 @@ def run_testcase(self, testcase):
password = testcase.get('password')
passfile = testcase.get('passfile')
database = testcase.get('database')
ssl = testcase.get('ssl')
sslmode = testcase.get('ssl')
server_settings = testcase.get('server_settings')

expected = testcase.get('result')
Expand All @@ -640,21 +645,26 @@ def run_testcase(self, testcase):

addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=ssl,
passfile=passfile, database=database, ssl=sslmode,
connect_timeout=None, server_settings=server_settings)

params = {k: v for k, v in params._asdict().items()
if v is not None}
params = {
k: v for k, v in params._asdict().items()
if v is not None or (expected is not None and k in expected[1])
}

if isinstance(params.get('ssl'), ssl.SSLContext):
params['ssl'] = True

result = (addrs, params)

if expected is not None:
for k, v in expected[1].items():
# If `expected` contains a type, allow that to "match" any
# instance of that type tyat `result` may contain. We need
# this because different SSLContexts don't compare equal.
if isinstance(v, type) and isinstance(result[1].get(k), v):
result[1][k] = v
if 'ssl' not in expected[1]:
# Avoid the hassle of specifying the default SSL mode
# unless explicitly tested for.
params.pop('ssl', None)
params.pop('ssl_is_advisory', None)

self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))

def test_test_connect_params_environ(self):
Expand Down Expand Up @@ -1063,16 +1073,6 @@ async def verify_fails(sslmode):
await verify_fails('verify-ca')
await verify_fails('verify-full')

async def test_connection_ssl_unix(self):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)

with self.assertRaisesRegex(asyncpg.InterfaceError,
'can only be enabled for TCP addresses'):
await self.connect(
host='/tmp',
ssl=ssl_context)

async def test_connection_implicit_host(self):
conn_spec = self.get_connection_spec()
con = await asyncpg.connect(
Expand Down