@@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
318
318
'result' : ([('host' , 123 )], {
319
319
'user' : 'user' ,
320
320
'password' : 'passw' ,
321
- 'database' : 'testdb' })
321
+ 'database' : 'testdb' ,
322
+ 'ssl' : True ,
323
+ 'ssl_is_advisory' : True })
322
324
},
323
325
324
326
{
@@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
384
386
'user' : 'user3' ,
385
387
'password' : '123123' ,
386
388
'database' : 'abcdef' ,
387
- 'ssl' : ssl . SSLContext ,
389
+ 'ssl' : True ,
388
390
'ssl_is_advisory' : True })
389
391
},
390
392
@@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
461
463
'user' : 'me' ,
462
464
'password' : 'ask' ,
463
465
'database' : 'db' ,
464
- 'ssl' : ssl . SSLContext ,
466
+ 'ssl' : True ,
465
467
'ssl_is_advisory' : False })
466
468
},
467
469
@@ -545,6 +547,7 @@ class TestConnectParams(tb.TestCase):
545
547
{
546
548
'user' : 'user' ,
547
549
'database' : 'user' ,
550
+ 'ssl' : None
548
551
}
549
552
)
550
553
},
@@ -574,7 +577,9 @@ class TestConnectParams(tb.TestCase):
574
577
('localhost' , 5433 )
575
578
], {
576
579
'user' : 'spam' ,
577
- 'database' : 'db'
580
+ 'database' : 'db' ,
581
+ 'ssl' : True ,
582
+ 'ssl_is_advisory' : True
578
583
}
579
584
)
580
585
},
@@ -617,7 +622,7 @@ def run_testcase(self, testcase):
617
622
password = testcase .get ('password' )
618
623
passfile = testcase .get ('passfile' )
619
624
database = testcase .get ('database' )
620
- ssl = testcase .get ('ssl' )
625
+ sslmode = testcase .get ('ssl' )
621
626
server_settings = testcase .get ('server_settings' )
622
627
623
628
expected = testcase .get ('result' )
@@ -640,21 +645,26 @@ def run_testcase(self, testcase):
640
645
641
646
addrs , params = connect_utils ._parse_connect_dsn_and_args (
642
647
dsn = dsn , host = host , port = port , user = user , password = password ,
643
- passfile = passfile , database = database , ssl = ssl ,
648
+ passfile = passfile , database = database , ssl = sslmode ,
644
649
connect_timeout = None , server_settings = server_settings )
645
650
646
- params = {k : v for k , v in params ._asdict ().items ()
647
- if v is not None }
651
+ params = {
652
+ k : v for k , v in params ._asdict ().items ()
653
+ if v is not None or (expected is not None and k in expected [1 ])
654
+ }
655
+
656
+ if isinstance (params .get ('ssl' ), ssl .SSLContext ):
657
+ params ['ssl' ] = True
648
658
649
659
result = (addrs , params )
650
660
651
661
if expected is not None :
652
- for k , v in expected [1 ]. items () :
653
- # If `expected` contains a type, allow that to "match" any
654
- # instance of that type tyat `result` may contain. We need
655
- # this because different SSLContexts don't compare equal.
656
- if isinstance ( v , type ) and isinstance ( result [ 1 ]. get ( k ), v ):
657
- result [ 1 ][ k ] = v
662
+ if 'ssl' not in expected [1 ]:
663
+ # Avoid the hassle of specifying the default SSL mode
664
+ # unless explicitly tested for.
665
+ params . pop ( 'ssl' , None )
666
+ params . pop ( 'ssl_is_advisory' , None )
667
+
658
668
self .assertEqual (expected , result , 'Testcase: {}' .format (testcase ))
659
669
660
670
def test_test_connect_params_environ (self ):
@@ -1063,16 +1073,6 @@ async def verify_fails(sslmode):
1063
1073
await verify_fails ('verify-ca' )
1064
1074
await verify_fails ('verify-full' )
1065
1075
1066
- async def test_connection_ssl_unix (self ):
1067
- ssl_context = ssl .SSLContext (ssl .PROTOCOL_SSLv23 )
1068
- ssl_context .load_verify_locations (SSL_CA_CERT_FILE )
1069
-
1070
- with self .assertRaisesRegex (asyncpg .InterfaceError ,
1071
- 'can only be enabled for TCP addresses' ):
1072
- await self .connect (
1073
- host = '/tmp' ,
1074
- ssl = ssl_context )
1075
-
1076
1076
async def test_connection_implicit_host (self ):
1077
1077
conn_spec = self .get_connection_spec ()
1078
1078
con = await asyncpg .connect (
0 commit comments