From b7243e7da3b1d5574f8a82b55f26164d65124b3c Mon Sep 17 00:00:00 2001 From: Jacob Champion Date: Tue, 19 Aug 2025 12:56:45 -0700 Subject: [PATCH v2 3/6] WIP: pytest: Add some SSL client tests This is a sample client-only test suite. It tests some handshake failures against a mock server, as well as a full SSL handshake + empty query + response. pyca/cryptography is added as a new package dependency. Certificates for testing are generated on the fly. The `pg` test package contains some helpers and fixtures (as well as some self-tests for more complicated behavior). Of note: - pg.require_test_extra() lets you mark a test/class/module as skippable if PG_TEST_EXTRA does not contain the necessary strings. - pg.remaining_timeout() is a function which can be repeatedly called to determine how much of the PG_TEST_TIMEOUT_DEFAULT remains for the current test item. - pg.libpq is a fixture that wraps libpq.so in a more friendly, but still low-level, ctypes FFI. Allocated resources are unwound and released during test teardown. The mock design is threaded: the server socket is listening on a background thread, and the test provides the server logic via a callback. There is some additional work still needed to make this production-ready; see the notes for _TCPServer.background(). (Currently, an exception in the wrong place could result in a hang-until-timeout rather than an immediate failure.) TODOs: - local_server and tcp_server_class are nearly identical and should share code. - fix exception-related timeouts for .background() - figure out the proper use of "session" vs "module" scope - ensure that pq.libpq unwinds (to close connections) before tcp_server; see comment in test_server_with_ssl_disabled() --- .cirrus.tasks.yml | 18 +- config/pytest-requirements.txt | 10 ++ pytest.ini | 3 + src/test/pytest/meson.build | 1 + src/test/pytest/pg/__init__.py | 3 + src/test/pytest/pg/_env.py | 55 ++++++ src/test/pytest/pg/fixtures.py | 212 +++++++++++++++++++++++ src/test/pytest/pyt/conftest.py | 3 + src/test/pytest/pyt/test_libpq.py | 171 ++++++++++++++++++ src/test/ssl/Makefile | 2 + src/test/ssl/meson.build | 6 + src/test/ssl/pyt/conftest.py | 129 ++++++++++++++ src/test/ssl/pyt/test_client.py | 278 ++++++++++++++++++++++++++++++ 13 files changed, 885 insertions(+), 6 deletions(-) create mode 100644 src/test/pytest/pg/__init__.py create mode 100644 src/test/pytest/pg/_env.py create mode 100644 src/test/pytest/pg/fixtures.py create mode 100644 src/test/pytest/pyt/conftest.py create mode 100644 src/test/pytest/pyt/test_libpq.py create mode 100644 src/test/ssl/pyt/conftest.py create mode 100644 src/test/ssl/pyt/test_client.py diff --git a/.cirrus.tasks.yml b/.cirrus.tasks.yml index 80f9b394bd2..4e744f1c105 100644 --- a/.cirrus.tasks.yml +++ b/.cirrus.tasks.yml @@ -225,6 +225,7 @@ task: sysctl kern.corefile='/tmp/cores/%N.%P.core' setup_additional_packages_script: | pkg install -y \ + py311-cryptography \ py311-packaging \ py311-pytest @@ -316,6 +317,7 @@ task: setup_additional_packages_script: | pkgin -y install \ + py312-cryptography \ py312-packaging \ py312-test ln -s /usr/pkg/bin/pytest-3.12 /usr/pkg/bin/pytest @@ -339,8 +341,9 @@ task: setup_additional_packages_script: | pkg_add -I \ - py3-test \ - py3-packaging + py3-cryptography \ + py3-packaging \ + py3-test # Always core dump to ${CORE_DUMP_DIR} set_core_dump_script: sysctl -w kern.nosuidcoredump=2 <<: *openbsd_task_template @@ -501,8 +504,9 @@ task: setup_additional_packages_script: | apt-get update DEBIAN_FRONTEND=noninteractive apt-get -y install \ - python3-pytest \ - python3-packaging + python3-cryptography \ + python3-packaging \ + python3-pytest matrix: # SPECIAL: @@ -643,6 +647,7 @@ task: CIRRUS_WORKING_DIR: ${HOME}/pgsql/ CCACHE_DIR: ${HOME}/ccache MACPORTS_CACHE: ${HOME}/macports-cache + PYTEST_DEBUG_TEMPROOT: /tmp # default is too long for UNIX sockets on Mac MESON_FEATURES: >- -Dbonjour=enabled @@ -663,6 +668,7 @@ task: p5.34-io-tty p5.34-ipc-run python312 + py312-cryptography py312-packaging py312-pytest tcl @@ -801,7 +807,7 @@ task: # XXX Does Chocolatey really not have any Python package installers? setup_additional_packages_script: | REM choco install -y --no-progress ... - pip3 install --user packaging pytest + pip3 install --user cryptography packaging pytest setup_hosts_file_script: | echo 127.0.0.1 pg-loadbalancetest >> c:\Windows\System32\Drivers\etc\hosts @@ -864,7 +870,7 @@ task: folder: ${CCACHE_DIR} setup_additional_packages_script: | - C:\msys64\usr\bin\pacman.exe -S --noconfirm mingw-w64-ucrt-x86_64-python-packaging mingw-w64-ucrt-x86_64-python-pytest + C:\msys64\usr\bin\pacman.exe -S --noconfirm mingw-w64-ucrt-x86_64-python-cryptography mingw-w64-ucrt-x86_64-python-packaging mingw-w64-ucrt-x86_64-python-pytest mingw_info_script: | %BASH% -c "where gcc" diff --git a/config/pytest-requirements.txt b/config/pytest-requirements.txt index b941624b2f3..0bd6cadf608 100644 --- a/config/pytest-requirements.txt +++ b/config/pytest-requirements.txt @@ -19,3 +19,13 @@ pytest >= 7.0, < 9 # packaging is used by check_pytest.py at configure time. packaging + +# Notes on the cryptography package: +# - 3.3.2 is shipped on Debian bullseye. +# - 3.4.x drops support for Python 2, making it a version of note for older LTS +# distros. +# - 35.x switched versioning schemes and moved to Rust parsing. +# - 40.x is the last version supporting Python 3.6. +# XXX Is it appropriate to require cryptography, or should we simply skip +# dependent tests? +cryptography >= 3.3.2 diff --git a/pytest.ini b/pytest.ini index 8e8388f3afc..e7aa84f3a84 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,3 +4,6 @@ minversion = 7.0 # Ignore ./config (which contains the configure-time check_pytest.py tests) by # default. addopts = --ignore ./config + +# Common test code can be found here. +pythonpath = src/test/pytest diff --git a/src/test/pytest/meson.build b/src/test/pytest/meson.build index abd128dfa24..f53193e8686 100644 --- a/src/test/pytest/meson.build +++ b/src/test/pytest/meson.build @@ -11,6 +11,7 @@ tests += { 'pytest': { 'tests': [ 'pyt/test_something.py', + 'pyt/test_libpq.py', ], }, } diff --git a/src/test/pytest/pg/__init__.py b/src/test/pytest/pg/__init__.py new file mode 100644 index 00000000000..ef8faf54ca4 --- /dev/null +++ b/src/test/pytest/pg/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +from ._env import has_test_extra, require_test_extra diff --git a/src/test/pytest/pg/_env.py b/src/test/pytest/pg/_env.py new file mode 100644 index 00000000000..6f18af07844 --- /dev/null +++ b/src/test/pytest/pg/_env.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import logging +import os +from typing import List, Optional + +import pytest + +logger = logging.getLogger(__name__) + + +def has_test_extra(key: str) -> bool: + """ + Returns True if the PG_TEST_EXTRA environment variable contains the given + key. + """ + extra = os.getenv("PG_TEST_EXTRA", "") + return key in extra.split() + + +def require_test_extra(*keys: str) -> bool: + """ + A convenience annotation which will skip tests if all of the required keys + are not present in PG_TEST_EXTRA. + + To skip a particular test function or class: + + @pg.require_test_extra("ldap") + def test_some_ldap_feature(): + ... + + To skip an entire module: + + pytestmark = pg.require_test_extra("ssl", "kerberos") + """ + return pytest.mark.skipif( + not all([has_test_extra(k) for k in keys]), + reason="requires {} to be set in PG_TEST_EXTRA".format(", ".join(keys)), + ) + + +def test_timeout_default() -> int: + """ + Returns the value of the PG_TEST_TIMEOUT_DEFAULT environment variable, in + seconds, or 180 if one was not provided. + """ + default = os.getenv("PG_TEST_TIMEOUT_DEFAULT", "") + if not default: + return 180 + + try: + return int(default) + except ValueError as v: + logger.warning("PG_TEST_TIMEOUT_DEFAULT could not be parsed: " + str(v)) + return 180 diff --git a/src/test/pytest/pg/fixtures.py b/src/test/pytest/pg/fixtures.py new file mode 100644 index 00000000000..b5d3bff69a8 --- /dev/null +++ b/src/test/pytest/pg/fixtures.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import ctypes +import platform +import time +from typing import Any, Callable, Dict + +import pytest + +from ._env import test_timeout_default + + +@pytest.fixture +def remaining_timeout(): + """ + This fixture provides a function that returns how much of the + PG_TEST_TIMEOUT_DEFAULT remains for the current test, in fractional seconds. + This value is never less than zero. + + This fixture is per-test, so the deadline is also reset on a per-test basis. + """ + now = time.monotonic() + deadline = now + test_timeout_default() + + return lambda: max(deadline - time.monotonic(), 0) + + +class _PGconn(ctypes.Structure): + pass + + +class _PGresult(ctypes.Structure): + pass + + +_PGconn_p = ctypes.POINTER(_PGconn) +_PGresult_p = ctypes.POINTER(_PGresult) + + +@pytest.fixture(scope="session") +def libpq_handle(): + """ + Loads a ctypes handle for libpq. Some common function prototypes are + initialized for general use. + """ + system = platform.system() + + if system in ("Linux", "FreeBSD", "NetBSD", "OpenBSD"): + name = "libpq.so.5" + elif system == "Darwin": + name = "libpq.5.dylib" + elif system == "Windows": + name = "libpq.dll" + else: + assert False, f"the libpq fixture must be updated for {system}" + + # XXX ctypes.CDLL() is a little stricter with load paths on Windows. The + # preferred way around that is to know the absolute path to libpq.dll, but + # that doesn't seem to mesh well with the current test infrastructure. For + # now, enable "standard" LoadLibrary behavior. + loadopts = {} + if system == "Windows": + loadopts["winmode"] = 0 + + lib = ctypes.CDLL(name, **loadopts) + + # + # Function Prototypes + # + + lib.PQconnectdb.restype = _PGconn_p + lib.PQconnectdb.argtypes = [ctypes.c_char_p] + + lib.PQstatus.restype = ctypes.c_int + lib.PQstatus.argtypes = [_PGconn_p] + + lib.PQexec.restype = _PGresult_p + lib.PQexec.argtypes = [_PGconn_p, ctypes.c_char_p] + + lib.PQresultStatus.restype = ctypes.c_int + lib.PQresultStatus.argtypes = [_PGresult_p] + + lib.PQclear.restype = None + lib.PQclear.argtypes = [_PGresult_p] + + lib.PQerrorMessage.restype = ctypes.c_char_p + lib.PQerrorMessage.argtypes = [_PGconn_p] + + lib.PQfinish.restype = None + lib.PQfinish.argtypes = [_PGconn_p] + + return lib + + +class PGresult(contextlib.AbstractContextManager): + """Wraps a raw _PGresult_p with a more friendly interface.""" + + def __init__(self, lib: ctypes.CDLL, res: _PGresult_p): + self._lib = lib + self._res = res + + def __exit__(self, *exc): + self._lib.PQclear(self._res) + self._res = None + + def status(self): + return self._lib.PQresultStatus(self._res) + + +class PGconn(contextlib.AbstractContextManager): + """ + Wraps a raw _PGconn_p with a more friendly interface. This is just a + stub; it's expected to grow. + """ + + def __init__( + self, + lib: ctypes.CDLL, + handle: _PGconn_p, + stack: contextlib.ExitStack, + ): + self._lib = lib + self._handle = handle + self._stack = stack + + def __exit__(self, *exc): + self._lib.PQfinish(self._handle) + self._handle = None + + def exec(self, query: str) -> PGresult: + """ + Executes a query via PQexec() and returns a PGresult. + """ + res = self._lib.PQexec(self._handle, query.encode()) + return self._stack.enter_context(PGresult(self._lib, res)) + + +@pytest.fixture +def libpq(libpq_handle, remaining_timeout): + """ + Provides a ctypes-based API wrapped around libpq.so. This fixture keeps + track of allocated resources and cleans them up during teardown. See + _Libpq's public API for details. + """ + + class _Libpq(contextlib.ExitStack): + CONNECTION_OK = 0 + + PGRES_EMPTY_QUERY = 0 + + class Error(RuntimeError): + """ + libpq.Error is the exception class for application-level errors that + are encountered during libpq operations. + """ + + pass + + def __init__(self): + super().__init__() + self.lib = libpq_handle + + def _connstr(self, opts: Dict[str, Any]) -> str: + """ + Flattens the provided options into a libpq connection string. Values + are converted to str and quoted/escaped as necessary. + """ + settings = [] + + for k, v in opts.items(): + v = str(v) + if not v: + v = "''" + else: + v = v.replace("\\", "\\\\") + v = v.replace("'", "\\'") + + if " " in v: + v = f"'{v}'" + + settings.append(f"{k}={v}") + + return " ".join(settings) + + def must_connect(self, **opts) -> PGconn: + """ + Connects to a server, using the given connection options, and + returns a libpq.PGconn object wrapping the connection handle. A + failure will raise libpq.Error. + + Connections honor PG_TEST_TIMEOUT_DEFAULT unless connect_timeout is + explicitly overridden in opts. + """ + + if "connect_timeout" not in opts: + t = int(remaining_timeout()) + opts["connect_timeout"] = max(t, 1) + + conn_p = self.lib.PQconnectdb(self._connstr(opts).encode()) + + # Ensure the connection handle is always closed at the end of the + # test. + conn = self.enter_context(PGconn(self.lib, conn_p, stack=self)) + + if self.lib.PQstatus(conn_p) != self.CONNECTION_OK: + raise self.Error(self.lib.PQerrorMessage(conn_p).decode()) + + return conn + + with _Libpq() as lib: + yield lib diff --git a/src/test/pytest/pyt/conftest.py b/src/test/pytest/pyt/conftest.py new file mode 100644 index 00000000000..ecb72be26d7 --- /dev/null +++ b/src/test/pytest/pyt/conftest.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +from pg.fixtures import * diff --git a/src/test/pytest/pyt/test_libpq.py b/src/test/pytest/pyt/test_libpq.py new file mode 100644 index 00000000000..9f0857cc612 --- /dev/null +++ b/src/test/pytest/pyt/test_libpq.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import os +import socket +import struct +import threading +from typing import Callable + +import pytest + + +@pytest.mark.parametrize( + "opts, expected", + [ + (dict(), ""), + (dict(port=5432), "port=5432"), + (dict(port=5432, dbname="postgres"), "port=5432 dbname=postgres"), + (dict(host=""), "host=''"), + (dict(host=" "), r"host=' '"), + (dict(keyword="'"), r"keyword=\'"), + (dict(keyword=" \\' "), r"keyword=' \\\' '"), + ], +) +def test_connstr(libpq, opts, expected): + """Tests the escape behavior for libpq._connstr().""" + assert libpq._connstr(opts) == expected + + +def test_must_connect_errors(libpq): + """Tests that must_connect() raises libpq.Error.""" + with pytest.raises(libpq.Error, match="invalid connection option"): + libpq.must_connect(some_unknown_keyword="whatever") + + +@pytest.fixture +def local_server(tmp_path, remaining_timeout): + """ + Opens up a local UNIX socket for mocking a Postgres server on a background + thread. See the _Server API for usage. + + This fixture requires AF_UNIX support; dependent tests will be skipped on + platforms that don't provide it. + """ + + try: + from socket import AF_UNIX + except ImportError: + pytest.skip("AF_UNIX not supported on this platform") + + class _Server(contextlib.ExitStack): + """ + Implementation class for local_server. See .background() for the primary + entry point for tests. Postgres clients may connect to this server via + local_server.host/local_server.port. + + _Server derives from contextlib.ExitStack to provide easy cleanup of + associated resources; see the documentation for that class for a full + explanation. + """ + + def __init__(self): + super().__init__() + + self.host = tmp_path + self.port = 5432 + + self._thread = None + self._thread_exc = None + self._listener = self.enter_context( + socket.socket(AF_UNIX, socket.SOCK_STREAM), + ) + + def bind_and_listen(self): + """ + Does the actual work of binding the UNIX socket using the Postgres + server conventions and listening for connections. + + The listen backlog is currently hardcoded to one. + """ + sockfile = self.host / ".s.PGSQL.{}".format(self.port) + + # Lock down the permissions on the new socket. + prev_mask = os.umask(0o077) + + # Bind (creating the socket file), and immediately register it for + # deletion from disk when the stack is cleaned up. + self._listener.bind(bytes(sockfile)) + self.callback(os.unlink, sockfile) + + os.umask(prev_mask) + + self._listener.listen(1) + + def background(self, fn: Callable[[socket.socket], None]) -> None: + """ + Accepts a client connection on a background thread and passes it to + the provided callback. Any exceptions raised from the callback will + be re-raised on the main thread during fixture teardown. + + Blocking operations on the connected socket default to using the + remaining_timeout(), though this can be changed by the test via the + socket's .settimeout(). + """ + + def _bg(): + try: + self._listener.settimeout(remaining_timeout()) + sock, _ = self._listener.accept() + + with sock: + sock.settimeout(remaining_timeout()) + fn(sock) + + except Exception as e: + # Save the exception for re-raising on the main thread. + self._thread_exc = e + + # TODO: rather than using callback(), consider explicitly signaling + # the fn() implementation to stop early if we get an exception. + # Otherwise we'll hang until the end of the timeout. + self._thread = threading.Thread(target=_bg) + self.callback(self._join) + + self._thread.start() + + def _join(self): + """ + Waits for the background thread to finish and raises any thrown + exception. This is called during fixture teardown. + """ + # Give a little bit of wiggle room on the join timeout, since we're + # racing against the test's own use of remaining_timeout(). (It's + # preferable to let tests report timeouts; the stack traces will + # help with debugging.) + self._thread.join(remaining_timeout() + 1) + if self._thread.is_alive(): + raise TimeoutError("background thread is still running after timeout") + + if self._thread_exc is not None: + raise self._thread_exc + + with _Server() as s: + s.bind_and_listen() + yield s + + +def test_connection_is_finished_on_error(libpq, local_server, remaining_timeout): + """Tests that PQfinish() gets called at the end of testing.""" + expected_error = "something is wrong" + + def serve_error(s: socket.socket) -> None: + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Quick check for the startup packet version. + version = struct.unpack("!HH", s.recv(4)) + assert version == (3, 0) + + # Discard the remainder of the startup packet and send a v2 error. + s.recv(pktlen - 8) + s.send(b"E" + expected_error.encode() + b"\0") + + # And now the socket should be closed. + assert not s.recv(1), "client sent unexpected data" + + local_server.background(serve_error) + + with pytest.raises(libpq.Error, match=expected_error): + # Exiting this context should result in PQfinish(). + with libpq: + libpq.must_connect(host=local_server.host, port=local_server.port) diff --git a/src/test/ssl/Makefile b/src/test/ssl/Makefile index e8a1639db2d..895ea5ea41c 100644 --- a/src/test/ssl/Makefile +++ b/src/test/ssl/Makefile @@ -30,6 +30,8 @@ clean distclean: # Doesn't depend on sslfiles because we don't rebuild them by default check: $(prove_check) + # XXX these suites should run independently, not serially + $(pytest_check) installcheck: $(prove_installcheck) diff --git a/src/test/ssl/meson.build b/src/test/ssl/meson.build index d8e0fb518e0..a0ee2af0899 100644 --- a/src/test/ssl/meson.build +++ b/src/test/ssl/meson.build @@ -15,4 +15,10 @@ tests += { 't/003_sslinfo.pl', ], }, + 'pytest': { + 'tests': [ + 'pyt/test_client.py', + 'pyt/test_server.py', + ], + }, } diff --git a/src/test/ssl/pyt/conftest.py b/src/test/ssl/pyt/conftest.py new file mode 100644 index 00000000000..fb4db372f03 --- /dev/null +++ b/src/test/ssl/pyt/conftest.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import datetime +import tempfile +from collections import namedtuple + +import pytest + +import pg +from pg.fixtures import * + + +@pytest.fixture(scope="session") +def cryptography(): + return pytest.importorskip("cryptography", "3.3.2") + + +Cert = namedtuple("Cert", "cert, certpath, key, keypath") + + +@pytest.fixture(scope="session") +def certs(cryptography, tmp_path_factory): + """ + Caches commonly used certificates at the session level, and provides a way + to create new ones. + + - certs.ca: the root CA certificate + + - certs.server: the "standard" server certficate, signed by certs.ca + + - certs.server_host: the hostname of the certs.server certificate + + - certs.new(): creates a custom certificate, signed by certs.ca + """ + + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + + tmpdir = tmp_path_factory.mktemp("test-certs") + + class _Certs: + def __init__(self): + self.ca = self.new( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "PG pytest CA")], + ), + ca=True, + ) + + self.server_host = "example.org" + self.server = self.new( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, self.server_host)], + ) + ) + + def new(self, subject: x509.Name, *, ca=False) -> Cert: + """ + Creates and signs a new Cert with the given subject name. If ca is + True, the certificate will be self-signed; otherwise the certificate + is signed by self.ca. + """ + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + builder = x509.CertificateBuilder() + now = datetime.datetime.now(datetime.timezone.utc) + + builder = ( + builder.subject_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(hours=1)) + ) + + if ca: + builder = builder.issuer_name(subject) + else: + builder = builder.issuer_name(self.ca.cert.subject) + + builder = builder.add_extension( + x509.BasicConstraints(ca=ca, path_length=None), + critical=True, + ) + + cert = builder.sign( + private_key=key if ca else self.ca.key, + algorithm=hashes.SHA256(), + ) + + # Dump the certificate and key to file. + keypath = self._tofile( + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ), + suffix=".key", + ) + certpath = self._tofile( + cert.public_bytes(serialization.Encoding.PEM), + suffix="-ca.crt" if ca else ".crt", + ) + + return Cert( + cert=cert, + certpath=certpath, + key=key, + keypath=keypath, + ) + + def _tofile(self, data: bytes, *, suffix) -> str: + """ + Dumps data to a file on disk with the requested suffix and returns + the path. The file is located somewhere in pytest's temporary + directory root. + """ + f = tempfile.NamedTemporaryFile(suffix=suffix, dir=tmpdir, delete=False) + with f: + f.write(data) + + return f.name + + return _Certs() diff --git a/src/test/ssl/pyt/test_client.py b/src/test/ssl/pyt/test_client.py new file mode 100644 index 00000000000..28110ae0717 --- /dev/null +++ b/src/test/ssl/pyt/test_client.py @@ -0,0 +1,278 @@ +# Copyright (c) 2025, PostgreSQL Global Development Group + +import contextlib +import ctypes +import socket +import ssl +import struct +import threading +from typing import Callable + +import pytest + +import pg + +# This suite opens up local TCP ports and is hidden behind PG_TEST_EXTRA=ssl. +pytestmark = pg.require_test_extra("ssl") + + +@pytest.fixture(scope="session", autouse=True) +def skip_if_no_ssl_support(libpq_handle): + """Skips tests if SSL support is not configured.""" + + # Declare PQsslAttribute(). + PQsslAttribute = libpq_handle.PQsslAttribute + PQsslAttribute.restype = ctypes.c_char_p + PQsslAttribute.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + + if not PQsslAttribute(None, b"library"): + pytest.skip("requires SSL support to be configured") + + +# +# Test Fixtures +# + + +@pytest.fixture +def tcp_server_class(remaining_timeout): + """ + Metafixture to combine related logic for tcp_server and ssl_server. + + TODO: combine with test_libpq.local_server + """ + + class _TCPServer(contextlib.ExitStack): + """ + Implementation class for tcp_server. See .background() for the primary + entry point for tests. Postgres clients may connect to this server via + **tcp_server.conninfo. + + _TCPServer derives from contextlib.ExitStack to provide easy cleanup of + associated resources; see the documentation for that class for a full + explanation. + """ + + def __init__(self): + super().__init__() + + self._thread = None + self._thread_exc = None + self._listener = self.enter_context( + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ) + + self._bind_and_listen() + sockname = self._listener.getsockname() + self.conninfo = dict( + hostaddr=sockname[0], + port=sockname[1], + ) + + def _bind_and_listen(self): + """ + Does the actual work of binding the socket and listening for + connections. + + The listen backlog is currently hardcoded to one. + """ + self._listener.bind(("127.0.0.1", 0)) + self._listener.listen(1) + + def background(self, fn: Callable[[socket.socket], None]) -> None: + """ + Accepts a client connection on a background thread and passes it to + the provided callback. Any exceptions raised from the callback will + be re-raised on the main thread during fixture teardown. + + Blocking operations on the connected socket default to using the + remaining_timeout(), though this can be changed by the test via the + socket's .settimeout(). + """ + + def _bg(): + try: + self._listener.settimeout(remaining_timeout()) + sock, _ = self._listener.accept() + + with sock: + sock.settimeout(remaining_timeout()) + fn(sock) + + except Exception as e: + # Save the exception for re-raising on the main thread. + self._thread_exc = e + + # TODO: rather than using callback(), consider explicitly signaling + # the fn() implementation to stop early if we get an exception. + # Otherwise we'll hang until the end of the timeout. + self._thread = threading.Thread(target=_bg) + self.callback(self._join) + + self._thread.start() + + def _join(self): + """ + Waits for the background thread to finish and raises any thrown + exception. This is called during fixture teardown. + """ + # Give a little bit of wiggle room on the join timeout, since we're + # racing against the test's own use of remaining_timeout(). (It's + # preferable to let tests report timeouts; the stack traces will + # help with debugging.) + self._thread.join(remaining_timeout() + 1) + if self._thread.is_alive(): + raise TimeoutError("background thread is still running after timeout") + + if self._thread_exc is not None: + raise self._thread_exc + + return _TCPServer + + +@pytest.fixture +def tcp_server(tcp_server_class): + """ + Opens up a local TCP socket for mocking a Postgres server on a background + thread. See the _TCPServer API for usage. + """ + with tcp_server_class() as s: + yield s + + +@pytest.fixture +def ssl_server(tcp_server_class, certs): + """ + Like tcp_server, but with an additional .background_ssl() method which will + perform a SSLRequest handshake on the socket before handing the connection + to the test callback. + + This server uses certs.server as its identity. + """ + + class _SSLServer(tcp_server_class): + def __init__(self): + super().__init__() + + self.conninfo["host"] = certs.server_host + + self._ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self._ctx.load_cert_chain(certs.server.certpath, certs.server.keypath) + + def background_ssl(self, fn: Callable[[ssl.SSLSocket], None]) -> None: + """ + Invokes a server callback as with .background(), but an SSLRequest + handshake is performed first, and the socket provided to the + callback has been wrapped in an OpenSSL layer. + """ + + def handshake(s: socket.socket): + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Make sure we get an SSLRequest. + version = struct.unpack("!HH", s.recv(4)) + assert version == (1234, 5679) + assert pktlen == 8 + + # Accept the SSLRequest. + s.send(b"S") + + with self._ctx.wrap_socket(s, server_side=True) as wrapped: + fn(wrapped) + + self.background(handshake) + + with _SSLServer() as s: + yield s + + +# +# Tests +# + + +@pytest.mark.parametrize("sslmode", ("require", "verify-ca", "verify-full")) +def test_server_with_ssl_disabled(libpq, tcp_server, certs, sslmode): + """ + Make sure client refuses to talk to non-SSL servers with stricter + sslmodes. + """ + + def refuse_ssl(s: socket.socket): + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Make sure we get an SSLRequest. + version = struct.unpack("!HH", s.recv(4)) + assert version == (1234, 5679) + assert pktlen == 8 + + # Refuse the SSLRequest. + s.send(b"N") + + # Wait for the client to close the connection. + assert not s.recv(1), "client sent unexpected data" + + tcp_server.background(refuse_ssl) + + with pytest.raises(libpq.Error, match="server does not support SSL"): + with libpq: # XXX tests shouldn't need to do this + libpq.must_connect( + **tcp_server.conninfo, + sslrootcert=certs.ca.certpath, + sslmode=sslmode, + ) + + +def test_verify_full_connection(libpq, ssl_server, certs): + """Completes a verify-full connection and empty query.""" + + def handle_empty_query(s: ssl.SSLSocket): + pktlen = struct.unpack("!I", s.recv(4))[0] + + # Check the startup packet version, then discard the remainder. + version = struct.unpack("!HH", s.recv(4)) + assert version == (3, 0) + s.recv(pktlen - 8) + + # Send the required litany of server messages. + s.send(struct.pack("!cII", b"R", 8, 0)) # AuthenticationOK + + # ParameterStatus: client_encoding + key = b"client_encoding\0" + val = b"UTF-8\0" + s.send(struct.pack("!cI", b"S", 4 + len(key) + len(val)) + key + val) + + # ParameterStatus: DateStyle + key = b"DateStyle\0" + val = b"ISO, MDY\0" + s.send(struct.pack("!cI", b"S", 4 + len(key) + len(val)) + key + val) + + s.send(struct.pack("!cIII", b"K", 12, 1234, 1234)) # BackendKeyData + s.send(struct.pack("!cIc", b"Z", 5, b"I")) # ReadyForQuery + + # Expect an empty query. + pkttype = s.recv(1) + assert pkttype == b"Q" + pktlen = struct.unpack("!I", s.recv(4))[0] + assert s.recv(pktlen - 4) == b"\0" + + # Send an EmptyQueryResponse+ReadyForQuery. + s.send(struct.pack("!cI", b"I", 4)) + s.send(struct.pack("!cIc", b"Z", 5, b"I")) + + # libpq should terminate and close the connection. + assert s.recv(1) == b"X" + pktlen = struct.unpack("!I", s.recv(4))[0] + assert pktlen == 4 + + assert not s.recv(1), "client sent unexpected data" + + ssl_server.background_ssl(handle_empty_query) + + conn = libpq.must_connect( + **ssl_server.conninfo, + sslrootcert=certs.ca.certpath, + sslmode="verify-full", + ) + with conn: + assert conn.exec("").status() == libpq.PGRES_EMPTY_QUERY -- 2.34.1