35
35
'password' ,
36
36
'database' ,
37
37
'ssl' ,
38
- 'ssl_is_advisory ' ,
38
+ 'alt_retry_ssl_first ' ,
39
39
'connect_timeout' ,
40
40
'server_settings' ,
41
41
])
@@ -402,8 +402,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402
402
if ssl is None and have_tcp_addrs :
403
403
ssl = 'prefer'
404
404
405
- # ssl_is_advisory is only allowed to come from the sslmode parameter.
406
- ssl_is_advisory = None
405
+ # alt_retry_ssl_first is particularly for "allow" and "prefer"
406
+ # to alternatively try SSL/non-SSL connections (once each if supported):
407
+ # False - allow (try non-SSL first)
408
+ # True - prefer (try SSL first)
409
+ # None - other (don't retry, stick with the "ssl" parameter)
410
+ alt_retry_ssl_first = None
411
+
407
412
if isinstance (ssl , str ):
408
413
SSLMODES = {
409
414
'disable' : 0 ,
@@ -420,26 +425,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
420
425
raise exceptions .InterfaceError (
421
426
'`sslmode` parameter must be one of: {}' .format (modes ))
422
427
423
- # sslmode 'allow' is currently handled as 'prefer' because we're
424
- # missing the "retry with SSL" behavior for 'allow', but do have the
425
- # "retry without SSL" behavior for 'prefer'.
426
- # Not changing 'allow' to 'prefer' here would be effectively the same
427
- # as changing 'allow' to 'disable'.
428
428
if sslmode == SSLMODES ['allow' ]:
429
- sslmode = SSLMODES ['prefer' ]
429
+ alt_retry_ssl_first = False
430
+ elif sslmode == SSLMODES ['prefer' ]:
431
+ alt_retry_ssl_first = True
430
432
431
433
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432
434
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433
- if sslmode <= SSLMODES ['allow' ]:
435
+ if sslmode < SSLMODES ['allow' ]:
434
436
ssl = False
435
- ssl_is_advisory = sslmode >= SSLMODES ['allow' ]
436
437
else :
437
438
ssl = ssl_module .create_default_context ()
438
439
ssl .check_hostname = sslmode >= SSLMODES ['verify-full' ]
439
440
ssl .verify_mode = ssl_module .CERT_REQUIRED
440
441
if sslmode <= SSLMODES ['require' ]:
441
442
ssl .verify_mode = ssl_module .CERT_NONE
442
- ssl_is_advisory = sslmode <= SSLMODES ['prefer' ]
443
443
elif ssl is True :
444
444
ssl = ssl_module .create_default_context ()
445
445
@@ -453,7 +453,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453
453
454
454
params = _ConnectionParameters (
455
455
user = user , password = password , database = database , ssl = ssl ,
456
- ssl_is_advisory = ssl_is_advisory , connect_timeout = connect_timeout ,
456
+ alt_retry_ssl_first = alt_retry_ssl_first ,
457
+ connect_timeout = connect_timeout ,
457
458
server_settings = server_settings )
458
459
459
460
return addrs , params
@@ -520,9 +521,8 @@ def data_received(self, data):
520
521
data == b'N' ):
521
522
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522
523
# since the only way to get ssl_is_advisory is from
523
- # sslmode=prefer (or sslmode=allow). But be extra sure to
524
- # disallow insecure connections when the ssl context asks for
525
- # real security.
524
+ # sslmode=prefer. But be extra sure to disallow insecure
525
+ # connections when the ssl context asks for real security.
526
526
self .on_data .set_result (False )
527
527
else :
528
528
self .on_data .set_exception (
@@ -566,6 +566,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566
566
new_tr = tr
567
567
568
568
pg_proto = protocol_factory ()
569
+ pg_proto .is_ssl = do_ssl_upgrade
569
570
pg_proto .connection_made (new_tr )
570
571
new_tr .set_protocol (pg_proto )
571
572
@@ -584,7 +585,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584
585
tr .close ()
585
586
586
587
try :
587
- return await conn_factory (sock = sock )
588
+ new_tr , pg_proto = await conn_factory (sock = sock )
589
+ pg_proto .is_ssl = do_ssl_upgrade
590
+ return new_tr , pg_proto
588
591
except (Exception , asyncio .CancelledError ):
589
592
sock .close ()
590
593
raise
@@ -605,8 +608,6 @@ async def _connect_addr(
605
608
if timeout <= 0 :
606
609
raise asyncio .TimeoutError
607
610
608
- connected = _create_future (loop )
609
-
610
611
params_input = params
611
612
if callable (params .password ):
612
613
if inspect .iscoroutinefunction (params .password ):
@@ -615,6 +616,44 @@ async def _connect_addr(
615
616
password = params .password ()
616
617
617
618
params = params ._replace (password = password )
619
+ args = (addr , loop , config , connection_class , record_class , params_input )
620
+
621
+ # skip retry if alt_retry is not enabled
622
+ if params .alt_retry_ssl_first is None :
623
+ return await __connect_addr (params , timeout , * args )
624
+
625
+ # prepare the params (which attempt has ssl) for the 2 attempts
626
+ params_retry = params ._replace (ssl = None )
627
+ if not params .alt_retry_ssl_first :
628
+ params , params_retry = params_retry , params
629
+
630
+ # first attempt
631
+ before = time .monotonic ()
632
+ try :
633
+ return await __connect_addr (params , timeout , * args )
634
+ except ConnectionError :
635
+ pass
636
+
637
+ # the second attempt with alt_retry_ssl_first=None
638
+ timeout -= time .monotonic () - before
639
+ if timeout <= 0 :
640
+ raise asyncio .TimeoutError
641
+ else :
642
+ params_retry = params_retry ._replace (alt_retry_ssl_first = None )
643
+ return await __connect_addr (params_retry , timeout , * args )
644
+
645
+
646
+ async def __connect_addr (
647
+ params ,
648
+ timeout ,
649
+ addr ,
650
+ loop ,
651
+ config ,
652
+ connection_class ,
653
+ record_class ,
654
+ params_input ,
655
+ ):
656
+ connected = _create_future (loop )
618
657
619
658
proto_factory = lambda : protocol .Protocol (
620
659
addr , connected , params , record_class , loop )
@@ -625,7 +664,7 @@ async def _connect_addr(
625
664
elif params .ssl :
626
665
connector = _create_ssl_connection (
627
666
proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
628
- ssl_is_advisory = params .ssl_is_advisory )
667
+ ssl_is_advisory = params .alt_retry_ssl_first )
629
668
else :
630
669
connector = loop .create_connection (proto_factory , * addr )
631
670
@@ -638,6 +677,23 @@ async def _connect_addr(
638
677
if timeout <= 0 :
639
678
raise asyncio .TimeoutError
640
679
await compat .wait_for (connected , timeout = timeout )
680
+ except exceptions .InvalidAuthorizationSpecificationError :
681
+ tr .close ()
682
+
683
+ # pr.is_ssl is a bool, so this equal test implies
684
+ # alt_retry_ssl_first is not None (should do alt_retry)
685
+ if params .alt_retry_ssl_first == pr .is_ssl :
686
+ # Elevate the error to ConnectionError to trigger retry
687
+ raise ConnectionError ("Connection rejected trying {} SSL" .format (
688
+ 'with' if pr .is_ssl else 'without' ))
689
+
690
+ else :
691
+ # Don't retry if alt_retry_ssl_first is None, or we don't need to
692
+ # (alt_retry_ssl_first=True and pr.is_ssl=False means the server
693
+ # doesn't support SSL, and we've already tried to Startup without
694
+ # SSL but failed; The opposite case doesn't exist).
695
+ raise
696
+
641
697
except (Exception , asyncio .CancelledError ):
642
698
tr .close ()
643
699
raise
@@ -684,6 +740,7 @@ class CancelProto(asyncio.Protocol):
684
740
685
741
def __init__ (self ):
686
742
self .on_disconnect = _create_future (loop )
743
+ self .is_ssl = False
687
744
688
745
def connection_lost (self , exc ):
689
746
if not self .on_disconnect .done ():
@@ -692,17 +749,31 @@ def connection_lost(self, exc):
692
749
if isinstance (addr , str ):
693
750
tr , pr = await loop .create_unix_connection (CancelProto , addr )
694
751
else :
695
- if params .ssl :
696
- tr , pr = await _create_ssl_connection (
697
- CancelProto ,
698
- * addr ,
699
- loop = loop ,
700
- ssl_context = params .ssl ,
701
- ssl_is_advisory = params .ssl_is_advisory )
752
+ async def _connect (params_in , ssl_is_advisory ):
753
+ if params_in .ssl :
754
+ return await _create_ssl_connection (
755
+ CancelProto ,
756
+ * addr ,
757
+ loop = loop ,
758
+ ssl_context = params_in .ssl ,
759
+ ssl_is_advisory = ssl_is_advisory )
760
+ else :
761
+ rv = await loop .create_connection (
762
+ CancelProto , * addr )
763
+ _set_nodelay (_get_socket (rv [0 ]))
764
+ return rv
765
+
766
+ if params .alt_retry_ssl_first is None :
767
+ tr , pr = await _connect (params , False )
702
768
else :
703
- tr , pr = await loop .create_connection (
704
- CancelProto , * addr )
705
- _set_nodelay (_get_socket (tr ))
769
+ params_retry = params ._replace (ssl = None )
770
+ if not params .alt_retry_ssl_first :
771
+ params , params_retry = params_retry , params
772
+ try :
773
+ tr , pr = await _connect (params , True )
774
+ except ConnectionError :
775
+ tr , pr = await _connect (
776
+ params ._replace (alt_retry_ssl_first = None ), False )
706
777
707
778
# Pack a CancelRequest message
708
779
msg = struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
0 commit comments