|
18 | 18 | import ssl as ssl_module
|
19 | 19 | import stat
|
20 | 20 | import struct
|
| 21 | +import sys |
21 | 22 | import time
|
22 | 23 | import typing
|
23 | 24 | import urllib.parse
|
@@ -220,13 +221,35 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
|
220 | 221 | return hosts, port
|
221 | 222 |
|
222 | 223 |
|
| 224 | +def _parse_tls_version(tls_version): |
| 225 | + if not hasattr(ssl_module, 'TLSVersion'): |
| 226 | + raise ValueError( |
| 227 | + "TLSVersion is not supported in this version of Python" |
| 228 | + ) |
| 229 | + if tls_version.startswith('SSL'): |
| 230 | + raise ValueError( |
| 231 | + f"Unsupported TLS version: {tls_version}" |
| 232 | + ) |
| 233 | + try: |
| 234 | + return ssl_module.TLSVersion[tls_version.replace('.', '_')] |
| 235 | + except KeyError: |
| 236 | + raise ValueError( |
| 237 | + f"No such TLS version: {tls_version}" |
| 238 | + ) |
| 239 | + |
| 240 | + |
| 241 | +def _dot_postgresql_path(filename) -> pathlib.Path: |
| 242 | + return (pathlib.Path.home() / '.postgresql' / filename).resolve() |
| 243 | + |
| 244 | + |
223 | 245 | def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
224 | 246 | password, passfile, database, ssl,
|
225 | 247 | connect_timeout, server_settings):
|
226 | 248 | # `auth_hosts` is the version of host information for the purposes
|
227 | 249 | # of reading the pgpass file.
|
228 | 250 | auth_hosts = None
|
229 |
| - sslcert = sslkey = sslrootcert = sslcrl = None |
| 251 | + sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None |
| 252 | + ssl_min_protocol_version = ssl_max_protocol_version = None |
230 | 253 |
|
231 | 254 | if dsn:
|
232 | 255 | parsed = urllib.parse.urlparse(dsn)
|
@@ -312,24 +335,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
312 | 335 | ssl = val
|
313 | 336 |
|
314 | 337 | if 'sslcert' in query:
|
315 |
| - val = query.pop('sslcert') |
316 |
| - if sslcert is None: |
317 |
| - sslcert = val |
| 338 | + sslcert = query.pop('sslcert') |
318 | 339 |
|
319 | 340 | if 'sslkey' in query:
|
320 |
| - val = query.pop('sslkey') |
321 |
| - if sslkey is None: |
322 |
| - sslkey = val |
| 341 | + sslkey = query.pop('sslkey') |
323 | 342 |
|
324 | 343 | if 'sslrootcert' in query:
|
325 |
| - val = query.pop('sslrootcert') |
326 |
| - if sslrootcert is None: |
327 |
| - sslrootcert = val |
| 344 | + sslrootcert = query.pop('sslrootcert') |
328 | 345 |
|
329 | 346 | if 'sslcrl' in query:
|
330 |
| - val = query.pop('sslcrl') |
331 |
| - if sslcrl is None: |
332 |
| - sslcrl = val |
| 347 | + sslcrl = query.pop('sslcrl') |
| 348 | + |
| 349 | + if 'sslpassword' in query: |
| 350 | + sslpassword = query.pop('sslpassword') |
| 351 | + |
| 352 | + if 'ssl_min_protocol_version' in query: |
| 353 | + ssl_min_protocol_version = query.pop( |
| 354 | + 'ssl_min_protocol_version' |
| 355 | + ) |
| 356 | + |
| 357 | + if 'ssl_max_protocol_version' in query: |
| 358 | + ssl_max_protocol_version = query.pop( |
| 359 | + 'ssl_max_protocol_version' |
| 360 | + ) |
333 | 361 |
|
334 | 362 | if query:
|
335 | 363 | if server_settings is None:
|
@@ -451,34 +479,97 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
451 | 479 | if sslmode < SSLMode.allow:
|
452 | 480 | ssl = False
|
453 | 481 | else:
|
454 |
| - ssl = ssl_module.create_default_context( |
455 |
| - ssl_module.Purpose.SERVER_AUTH) |
| 482 | + ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) |
456 | 483 | ssl.check_hostname = sslmode >= SSLMode.verify_full
|
457 |
| - ssl.verify_mode = ssl_module.CERT_REQUIRED |
458 |
| - if sslmode <= SSLMode.require: |
| 484 | + if sslmode < SSLMode.require: |
459 | 485 | ssl.verify_mode = ssl_module.CERT_NONE
|
| 486 | + else: |
| 487 | + if sslrootcert is None: |
| 488 | + sslrootcert = os.getenv('PGSSLROOTCERT') |
| 489 | + if sslrootcert: |
| 490 | + ssl.load_verify_locations(cafile=sslrootcert) |
| 491 | + ssl.verify_mode = ssl_module.CERT_REQUIRED |
| 492 | + else: |
| 493 | + sslrootcert = _dot_postgresql_path('root.crt') |
| 494 | + try: |
| 495 | + ssl.load_verify_locations(cafile=sslrootcert) |
| 496 | + except FileNotFoundError: |
| 497 | + if sslmode > SSLMode.require: |
| 498 | + raise ValueError( |
| 499 | + f'root certificate file "{sslrootcert}" does ' |
| 500 | + f'not exist\nEither provide the file or ' |
| 501 | + f'change sslmode to disable server ' |
| 502 | + f'certificate verification.' |
| 503 | + ) |
| 504 | + elif sslmode == SSLMode.require: |
| 505 | + ssl.verify_mode = ssl_module.CERT_NONE |
| 506 | + else: |
| 507 | + assert False, 'unreachable' |
| 508 | + else: |
| 509 | + ssl.verify_mode = ssl_module.CERT_REQUIRED |
460 | 510 |
|
461 |
| - if sslcert is None: |
462 |
| - sslcert = os.getenv('PGSSLCERT') |
| 511 | + if sslcrl is None: |
| 512 | + sslcrl = os.getenv('PGSSLCRL') |
| 513 | + if sslcrl: |
| 514 | + ssl.load_verify_locations(cafile=sslcrl) |
| 515 | + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN |
| 516 | + else: |
| 517 | + sslcrl = _dot_postgresql_path('root.crl') |
| 518 | + try: |
| 519 | + ssl.load_verify_locations(cafile=sslcrl) |
| 520 | + except FileNotFoundError: |
| 521 | + pass |
| 522 | + else: |
| 523 | + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN |
463 | 524 |
|
464 | 525 | if sslkey is None:
|
465 | 526 | sslkey = os.getenv('PGSSLKEY')
|
466 |
| - |
467 |
| - if sslrootcert is None: |
468 |
| - sslrootcert = os.getenv('PGSSLROOTCERT') |
469 |
| - |
470 |
| - if sslcrl is None: |
471 |
| - sslcrl = os.getenv('PGSSLCRL') |
472 |
| - |
| 527 | + if not sslkey: |
| 528 | + sslkey = _dot_postgresql_path('postgresql.key') |
| 529 | + if not sslkey.exists(): |
| 530 | + sslkey = None |
| 531 | + if not sslpassword: |
| 532 | + sslpassword = '' |
| 533 | + if sslcert is None: |
| 534 | + sslcert = os.getenv('PGSSLCERT') |
473 | 535 | if sslcert:
|
474 |
| - ssl.load_cert_chain(sslcert, keyfile=sslkey) |
475 |
| - |
476 |
| - if sslrootcert: |
477 |
| - ssl.load_verify_locations(cafile=sslrootcert) |
478 |
| - |
479 |
| - if sslcrl: |
480 |
| - ssl.load_verify_locations(cafile=sslcrl) |
481 |
| - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN |
| 536 | + ssl.load_cert_chain( |
| 537 | + sslcert, keyfile=sslkey, password=lambda: sslpassword |
| 538 | + ) |
| 539 | + else: |
| 540 | + sslcert = _dot_postgresql_path('postgresql.crt') |
| 541 | + try: |
| 542 | + ssl.load_cert_chain( |
| 543 | + sslcert, keyfile=sslkey, password=lambda: sslpassword |
| 544 | + ) |
| 545 | + except FileNotFoundError: |
| 546 | + pass |
| 547 | + |
| 548 | + # OpenSSL 1.1.1 keylog file, copied from create_default_context() |
| 549 | + if hasattr(ssl, 'keylog_filename'): |
| 550 | + keylogfile = os.environ.get('SSLKEYLOGFILE') |
| 551 | + if keylogfile and not sys.flags.ignore_environment: |
| 552 | + ssl.keylog_filename = keylogfile |
| 553 | + |
| 554 | + if ssl_min_protocol_version is None: |
| 555 | + ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION') |
| 556 | + if ssl_min_protocol_version: |
| 557 | + ssl.minimum_version = _parse_tls_version( |
| 558 | + ssl_min_protocol_version |
| 559 | + ) |
| 560 | + else: |
| 561 | + try: |
| 562 | + ssl.minimum_version = _parse_tls_version('TLSv1.2') |
| 563 | + except ValueError: |
| 564 | + # Python 3.6 does not have ssl.TLSVersion |
| 565 | + pass |
| 566 | + |
| 567 | + if ssl_max_protocol_version is None: |
| 568 | + ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') |
| 569 | + if ssl_max_protocol_version: |
| 570 | + ssl.maximum_version = _parse_tls_version( |
| 571 | + ssl_max_protocol_version |
| 572 | + ) |
482 | 573 |
|
483 | 574 | elif ssl is True:
|
484 | 575 | ssl = ssl_module.create_default_context()
|
|
0 commit comments