@@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
318
318
'result' : ([('host' , 123 )], {
319
319
'user' : 'user' ,
320
320
'password' : 'passw' ,
321
- 'database' : 'testdb' })
321
+ 'database' : 'testdb' ,
322
+ 'ssl' : True ,
323
+ 'ssl_is_advisory' : True })
322
324
},
323
325
324
326
{
@@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
384
386
'user' : 'user3' ,
385
387
'password' : '123123' ,
386
388
'database' : 'abcdef' ,
387
- 'ssl' : ssl . SSLContext ,
389
+ 'ssl' : True ,
388
390
'ssl_is_advisory' : True })
389
391
},
390
392
@@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
461
463
'user' : 'me' ,
462
464
'password' : 'ask' ,
463
465
'database' : 'db' ,
464
- 'ssl' : ssl . SSLContext ,
466
+ 'ssl' : True ,
465
467
'ssl_is_advisory' : False })
466
468
},
467
469
@@ -617,7 +619,7 @@ def run_testcase(self, testcase):
617
619
password = testcase .get ('password' )
618
620
passfile = testcase .get ('passfile' )
619
621
database = testcase .get ('database' )
620
- ssl = testcase .get ('ssl' )
622
+ sslmode = testcase .get ('ssl' )
621
623
server_settings = testcase .get ('server_settings' )
622
624
623
625
expected = testcase .get ('result' )
@@ -640,21 +642,25 @@ def run_testcase(self, testcase):
640
642
641
643
addrs , params = connect_utils ._parse_connect_dsn_and_args (
642
644
dsn = dsn , host = host , port = port , user = user , password = password ,
643
- passfile = passfile , database = database , ssl = ssl ,
645
+ passfile = passfile , database = database , ssl = sslmode ,
644
646
connect_timeout = None , server_settings = server_settings )
645
647
646
- params = {k : v for k , v in params ._asdict ().items ()
647
- if v is not None }
648
+ params = {
649
+ k : v for k , v in params ._asdict ().items () if v is not None
650
+ }
651
+
652
+ if isinstance (params .get ('ssl' ), ssl .SSLContext ):
653
+ params ['ssl' ] = True
648
654
649
655
result = (addrs , params )
650
656
651
657
if expected is not None :
652
- for k , v in expected [1 ]. items () :
653
- # If `expected` contains a type, allow that to "match" any
654
- # instance of that type tyat `result` may contain. We need
655
- # this because different SSLContexts don't compare equal.
656
- if isinstance ( v , type ) and isinstance ( result [ 1 ]. get ( k ), v ):
657
- result [ 1 ][ k ] = v
658
+ if 'ssl' not in expected [1 ]:
659
+ # Avoid the hassle of specifying the default SSL mode
660
+ # unless explicitly tested for.
661
+ params . pop ( 'ssl' , None )
662
+ params . pop ( 'ssl_is_advisory' , None )
663
+
658
664
self .assertEqual (expected , result , 'Testcase: {}' .format (testcase ))
659
665
660
666
def test_test_connect_params_environ (self ):
@@ -1063,16 +1069,6 @@ async def verify_fails(sslmode):
1063
1069
await verify_fails ('verify-ca' )
1064
1070
await verify_fails ('verify-full' )
1065
1071
1066
- async def test_connection_ssl_unix (self ):
1067
- ssl_context = ssl .SSLContext (ssl .PROTOCOL_SSLv23 )
1068
- ssl_context .load_verify_locations (SSL_CA_CERT_FILE )
1069
-
1070
- with self .assertRaisesRegex (asyncpg .InterfaceError ,
1071
- 'can only be enabled for TCP addresses' ):
1072
- await self .connect (
1073
- host = '/tmp' ,
1074
- ssl = ssl_context )
1075
-
1076
1072
async def test_connection_implicit_host (self ):
1077
1073
conn_spec = self .get_connection_spec ()
1078
1074
con = await asyncpg .connect (
0 commit comments