From fdbad1976a78d179b104138c31ac106e20338b0f Mon Sep 17 00:00:00 2001 From: Daniel Gustafsson Date: Wed, 21 Feb 2024 17:04:26 +0100 Subject: [PATCH v20 6/9] Introduce OAuth validator libraries This replaces the serverside validation code with an module API for loading in extensions for validating bearer tokens. A lot of code is left to be written. Co-authored-by: Jacob Champion --- src/backend/libpq/auth-oauth.c | 431 +++++------------- src/backend/utils/misc/guc_tables.c | 6 +- src/bin/pg_combinebackup/Makefile | 2 +- src/common/Makefile | 2 +- src/include/libpq/oauth.h | 29 +- src/test/modules/meson.build | 1 + src/test/modules/oauth_validator/.gitignore | 4 + src/test/modules/oauth_validator/Makefile | 19 + .../oauth_validator/expected/validator.out | 6 + src/test/modules/oauth_validator/meson.build | 33 ++ .../modules/oauth_validator/sql/validator.sql | 1 + .../modules/oauth_validator/t/001_server.pl | 78 ++++ src/test/modules/oauth_validator/validator.c | 82 ++++ src/test/perl/PostgreSQL/Test/Cluster.pm | 14 +- src/test/perl/PostgreSQL/Test/OAuthServer.pm | 183 ++++++++ src/tools/pgindent/typedefs.list | 2 + 16 files changed, 561 insertions(+), 332 deletions(-) create mode 100644 src/test/modules/oauth_validator/.gitignore create mode 100644 src/test/modules/oauth_validator/Makefile create mode 100644 src/test/modules/oauth_validator/expected/validator.out create mode 100644 src/test/modules/oauth_validator/meson.build create mode 100644 src/test/modules/oauth_validator/sql/validator.sql create mode 100644 src/test/modules/oauth_validator/t/001_server.pl create mode 100644 src/test/modules/oauth_validator/validator.c create mode 100644 src/test/perl/PostgreSQL/Test/OAuthServer.pm diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c index 16596c089a..024f304e4d 100644 --- a/src/backend/libpq/auth-oauth.c +++ b/src/backend/libpq/auth-oauth.c @@ -6,7 +6,7 @@ * See the following RFC for more details: * - RFC 7628: https://tools.ietf.org/html/rfc7628 * - * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group * Portions Copyright (c) 1994, Regents of the University of California * * src/backend/libpq/auth-oauth.c @@ -19,22 +19,30 @@ #include #include "common/oauth-common.h" +#include "fmgr.h" #include "lib/stringinfo.h" #include "libpq/auth.h" #include "libpq/hba.h" #include "libpq/oauth.h" #include "libpq/sasl.h" #include "storage/fd.h" +#include "storage/ipc.h" #include "utils/json.h" /* GUC */ -char *oauth_validator_command; +char *OAuthValidatorLibrary = ""; static void oauth_get_mechanisms(Port *port, StringInfo buf); static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass); static int oauth_exchange(void *opaq, const char *input, int inputlen, char **output, int *outputlen, const char **logdetail); +static void load_validator_library(void); +static void shutdown_validator_library(int code, Datum arg); + +static ValidatorModuleState *validator_module_state; +static const OAuthValidatorCallbacks *ValidatorCallbacks; + /* Mechanism declaration */ const pg_be_sasl_mech pg_be_oauth_mech = { oauth_get_mechanisms, @@ -63,11 +71,7 @@ struct oauth_ctx static char *sanitize_char(char c); static char *parse_kvpairs_for_auth(char **input); static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen); -static bool validate(Port *port, const char *auth, const char **logdetail); -static bool run_validator_command(Port *port, const char *token); -static bool check_exit(FILE **fh, const char *command); -static bool set_cloexec(int fd); -static bool username_ok_for_shell(const char *username); +static bool validate(Port *port, const char *auth); #define KVSEP 0x01 #define AUTH_KEY "auth" @@ -100,6 +104,8 @@ oauth_init(Port *port, const char *selected_mech, const char *shadow_pass) ctx->issuer = port->hba->oauth_issuer; ctx->scope = port->hba->oauth_scope; + load_validator_library(); + return ctx; } @@ -250,7 +256,7 @@ oauth_exchange(void *opaq, const char *input, int inputlen, errmsg("malformed OAUTHBEARER message"), errdetail("Message contains additional data after the final terminator."))); - if (!validate(ctx->port, auth, logdetail)) + if (!validate(ctx->port, auth)) { generate_error_response(ctx, output, outputlen); @@ -489,70 +495,73 @@ generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen) *outputlen = buf.len; } -static bool -validate(Port *port, const char *auth, const char **logdetail) +/*----- + * Validates the provided Authorization header and returns the token from + * within it. NULL is returned on validation failure. + * + * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec. + * 2.1: + * + * b64token = 1*( ALPHA / DIGIT / + * "-" / "." / "_" / "~" / "+" / "/" ) *"=" + * credentials = "Bearer" 1*SP b64token + * + * The "credentials" construction is what we receive in our auth value. + * + * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization + * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be + * compared case-insensitively. (This is not mentioned in RFC 6750, but + * it's pointed out in RFC 7628 Sec. 4.) + * + * Invalid formats are technically a protocol violation, but we shouldn't + * reflect any information about the sensitive Bearer token back to the + * client; log at COMMERROR instead. + * + * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1. + */ +static const char * +validate_token_format(const char *header) { - static const char *const b64_set = + size_t span; + const char *token; + static const char *const b64token_allowed_set = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "0123456789-._~+/"; - const char *token; - size_t span; - int ret; + /* If the token is empty or simply too short to be correct */ + if (!header || strlen(header) <= 7) + { + ereport(COMMERROR, + (errmsg("malformed OAuth bearer token 1"))); + return NULL; + } - /* TODO: handle logdetail when the test framework can check it */ - - /*----- - * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec. - * 2.1: - * - * b64token = 1*( ALPHA / DIGIT / - * "-" / "." / "_" / "~" / "+" / "/" ) *"=" - * credentials = "Bearer" 1*SP b64token - * - * The "credentials" construction is what we receive in our auth value. - * - * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization - * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be - * compared case-insensitively. (This is not mentioned in RFC 6750, but - * it's pointed out in RFC 7628 Sec. 4.) - * - * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1. - */ - if (pg_strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME))) - return false; + if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME))) + return NULL; /* Pull the bearer token out of the auth value. */ - token = auth + strlen(BEARER_SCHEME); + token = header + strlen(BEARER_SCHEME); /* Swallow any additional spaces. */ while (*token == ' ') token++; - /* - * Before invoking the validator command, sanity-check the token format to - * avoid any injection attacks later in the chain. Invalid formats are - * technically a protocol violation, but don't reflect any information - * about the sensitive Bearer token back to the client; log at COMMERROR - * instead. - */ - /* Tokens must not be empty. */ if (!*token) { ereport(COMMERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), - errmsg("malformed OAUTHBEARER message"), + errmsg("malformed OAuth bearer token 2"), errdetail("Bearer token is empty."))); - return false; + return NULL; } /* * Make sure the token contains only allowed characters. Tokens may end * with any number of '=' characters. */ - span = strspn(token, b64_set); + span = strspn(token, b64token_allowed_set); while (token[span] == '=') span++; @@ -565,15 +574,35 @@ validate(Port *port, const char *auth, const char **logdetail) */ ereport(COMMERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), - errmsg("malformed OAUTHBEARER message"), + errmsg("malformed OAuth bearer token 3"), errdetail("Bearer token is not in the correct format."))); - return false; + return NULL; } - /* Have the validator check the token. */ - if (!run_validator_command(port, token)) + return token; +} + +static bool +validate(Port *port, const char *auth) +{ + int map_status; + ValidatorModuleResult *ret; + const char *token; + + /* Ensure that we have a correct token to validate */ + if (!(token = validate_token_format(auth))) + return false; + + /* Call the validation function from the validator module */ + ret = ValidatorCallbacks->validate_cb(validator_module_state, + token, port->user_name); + + if (!ret->authorized) return false; + if (ret->authn_id) + set_authn_id(port, ret->authn_id); + if (port->hba->oauth_skip_usermap) { /* @@ -586,7 +615,7 @@ validate(Port *port, const char *auth, const char **logdetail) } /* Make sure the validator authenticated the user. */ - if (!MyClientConnectionInfo.authn_id) + if (ret->authn_id == NULL || ret->authn_id[0] == '\0') { /* TODO: use logdetail; reduce message duplication */ ereport(LOG, @@ -596,288 +625,42 @@ validate(Port *port, const char *auth, const char **logdetail) } /* Finally, check the user map. */ - ret = check_usermap(port->hba->usermap, port->user_name, - MyClientConnectionInfo.authn_id, false); - return (ret == STATUS_OK); -} - -static bool -run_validator_command(Port *port, const char *token) -{ - bool success = false; - int rc; - int pipefd[2]; - int rfd = -1; - int wfd = -1; - - StringInfoData command = {0}; - char *p; - FILE *fh = NULL; - - ssize_t written; - char *line = NULL; - size_t size = 0; - ssize_t len; - - Assert(oauth_validator_command); - - if (!oauth_validator_command[0]) - { - ereport(COMMERROR, - (errmsg("oauth_validator_command is not set"), - errhint("To allow OAuth authenticated connections, set " - "oauth_validator_command in postgresql.conf."))); - return false; - } - - /*------ - * Since popen() is unidirectional, open up a pipe for the other - * direction. Use CLOEXEC to ensure that our write end doesn't - * accidentally get copied into child processes, which would prevent us - * from closing it cleanly. - * - * XXX this is ugly. We should just read from the child process's stdout, - * but that's a lot more code. - * XXX by bypassing the popen API, we open the potential of process - * deadlock. Clearly document child process requirements (i.e. the child - * MUST read all data off of the pipe before writing anything). - * TODO: port to Windows using _pipe(). - */ - rc = pipe(pipefd); - if (rc < 0) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not create child pipe: %m"))); - return false; - } - - rfd = pipefd[0]; - wfd = pipefd[1]; - - if (!set_cloexec(wfd)) - { - /* error message was already logged */ - goto cleanup; - } - - /*---------- - * Construct the command, substituting any recognized %-specifiers: - * - * %f: the file descriptor of the input pipe - * %r: the role that the client wants to assume (port->user_name) - * %%: a literal '%' - */ - initStringInfo(&command); - - for (p = oauth_validator_command; *p; p++) - { - if (p[0] == '%') - { - switch (p[1]) - { - case 'f': - appendStringInfo(&command, "%d", rfd); - p++; - break; - case 'r': - - /* - * TODO: decide how this string should be escaped. The - * role is controlled by the client, so if we don't escape - * it, command injections are inevitable. - * - * This is probably an indication that the role name needs - * to be communicated to the validator process in some - * other way. For this proof of concept, just be - * incredibly strict about the characters that are allowed - * in user names. - */ - if (!username_ok_for_shell(port->user_name)) - goto cleanup; - - appendStringInfoString(&command, port->user_name); - p++; - break; - case '%': - appendStringInfoChar(&command, '%'); - p++; - break; - default: - appendStringInfoChar(&command, p[0]); - } - } - else - appendStringInfoChar(&command, p[0]); - } - - /* Execute the command. */ - fh = OpenPipeStream(command.data, "r"); - if (!fh) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("opening pipe to OAuth validator: %m"))); - goto cleanup; - } - - /* We don't need the read end of the pipe anymore. */ - close(rfd); - rfd = -1; - - /* Give the command the token to validate. */ - written = write(wfd, token, strlen(token)); - if (written != strlen(token)) - { - /* TODO must loop for short writes, EINTR et al */ - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not write token to child pipe: %m"))); - goto cleanup; - } - - close(wfd); - wfd = -1; - - /*----- - * Read the command's response. - * - * TODO: getline() is probably too new to use, unfortunately. - * TODO: loop over all lines - */ - if ((len = getline(&line, &size, fh)) >= 0) - { - /* TODO: fail if the authn_id doesn't end with a newline */ - if (len > 0) - line[len - 1] = '\0'; - - set_authn_id(port, line); - } - else if (ferror(fh)) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not read from command \"%s\": %m", - command.data))); - goto cleanup; - } - - /* Make sure the command exits cleanly. */ - if (!check_exit(&fh, command.data)) - { - /* error message already logged */ - goto cleanup; - } - - /* Done. */ - success = true; - -cleanup: - if (line) - free(line); - - /* - * In the successful case, the pipe fds are already closed. For the error - * case, always close out the pipe before waiting for the command, to - * prevent deadlock. - */ - if (rfd >= 0) - close(rfd); - if (wfd >= 0) - close(wfd); - - if (fh) - { - Assert(!success); - check_exit(&fh, command.data); - } - - if (command.data) - pfree(command.data); - - return success; + map_status = check_usermap(port->hba->usermap, port->user_name, + MyClientConnectionInfo.authn_id, false); + return (map_status == STATUS_OK); } -static bool -check_exit(FILE **fh, const char *command) +static void +load_validator_library(void) { - int rc; - - rc = ClosePipeStream(*fh); - *fh = NULL; - - if (rc == -1) - { - /* pclose() itself failed. */ - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not close pipe to command \"%s\": %m", - command))); - } - else if (rc != 0) - { - char *reason = wait_result_to_str(rc); + OAuthValidatorModuleInit validator_init; - ereport(COMMERROR, - (errmsg("failed to execute command \"%s\": %s", - command, reason))); + if (OAuthValidatorLibrary[0] == '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("oauth_validator_library is not set"))); - pfree(reason); - } + validator_init = (OAuthValidatorModuleInit) + load_external_function(OAuthValidatorLibrary, + "_PG_oauth_validator_module_init", false, NULL); - return (rc == 0); -} + if (validator_init == NULL) + ereport(ERROR, + (errmsg("%s module \"%s\" have to define the symbol %s", + "OAuth validator", OAuthValidatorLibrary, "_PG_oauth_validator_module_init"))); -static bool -set_cloexec(int fd) -{ - int flags; - int rc; + ValidatorCallbacks = (*validator_init) (); - flags = fcntl(fd, F_GETFD); - if (flags == -1) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not get fd flags for child pipe: %m"))); - return false; - } + validator_module_state = (ValidatorModuleState *) palloc0(sizeof(ValidatorModuleState)); + if (ValidatorCallbacks->startup_cb != NULL) + ValidatorCallbacks->startup_cb(validator_module_state); - rc = fcntl(fd, F_SETFD, flags | FD_CLOEXEC); - if (rc < 0) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not set FD_CLOEXEC for child pipe: %m"))); - return false; - } - - return true; + before_shmem_exit(shutdown_validator_library, 0); } -/* - * XXX This should go away eventually and be replaced with either a proper - * escape or a different strategy for communication with the validator command. - */ -static bool -username_ok_for_shell(const char *username) +static void +shutdown_validator_library(int code, Datum arg) { - /* This set is borrowed from fe_utils' appendShellStringNoError(). */ - static const char *const allowed = - "abcdefghijklmnopqrstuvwxyz" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "0123456789-_./:"; - size_t span; - - Assert(username && username[0]); /* should have already been checked */ - - span = strspn(username, allowed); - if (username[span] != '\0') - { - ereport(COMMERROR, - (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator"))); - return false; - } - - return true; + if (ValidatorCallbacks->shutdown_cb != NULL) + ValidatorCallbacks->shutdown_cb(validator_module_state); } diff --git a/src/backend/utils/misc/guc_tables.c b/src/backend/utils/misc/guc_tables.c index d28209901f..99bb89d54e 100644 --- a/src/backend/utils/misc/guc_tables.c +++ b/src/backend/utils/misc/guc_tables.c @@ -4672,12 +4672,12 @@ struct config_string ConfigureNamesString[] = }, { - {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH, - gettext_noop("Command to validate OAuth v2 bearer tokens."), + {"oauth_validator_library", PGC_SIGHUP, CONN_AUTH_AUTH, + gettext_noop("Sets the library that will be called to validate OAuth v2 bearer tokens."), NULL, GUC_SUPERUSER_ONLY | GUC_NOT_IN_SAMPLE }, - &oauth_validator_command, + &OAuthValidatorLibrary, "", NULL, NULL, NULL }, diff --git a/src/bin/pg_combinebackup/Makefile b/src/bin/pg_combinebackup/Makefile index c3729755ba..4f24b1aff6 100644 --- a/src/bin/pg_combinebackup/Makefile +++ b/src/bin/pg_combinebackup/Makefile @@ -31,7 +31,7 @@ OBJS = \ all: pg_combinebackup pg_combinebackup: $(OBJS) | submake-libpgport submake-libpgfeutils - $(CC) $(CFLAGS) $^ $(LDFLAGS) $(LDFLAGS_EX) $(LIBS) -o $@$(X) + $(CC) $(CFLAGS) $^ $(LDFLAGS) $(LDFLAGS_EX) $(libpq_pgport) $(LIBS) -o $@$(X) install: all installdirs $(INSTALL_PROGRAM) pg_combinebackup$(X) '$(DESTDIR)$(bindir)/pg_combinebackup$(X)' diff --git a/src/common/Makefile b/src/common/Makefile index bbb5c3ab11..00e30e6bfe 100644 --- a/src/common/Makefile +++ b/src/common/Makefile @@ -41,7 +41,7 @@ override CPPFLAGS += -DVAL_LDFLAGS_SL="\"$(LDFLAGS_SL)\"" override CPPFLAGS += -DVAL_LIBS="\"$(LIBS)\"" override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common -I$(libpq_srcdir) $(CPPFLAGS) -LIBS += $(PTHREAD_LIBS) +LIBS += $(PTHREAD_LIBS) $(libpq_pgport) OBJS_COMMON = \ archive.o \ diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h index 5edab3b25a..6f98e84cc9 100644 --- a/src/include/libpq/oauth.h +++ b/src/include/libpq/oauth.h @@ -3,7 +3,7 @@ * oauth.h * Interface to libpq/auth-oauth.c * - * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group * Portions Copyright (c) 1994, Regents of the University of California * * src/include/libpq/oauth.h @@ -16,7 +16,32 @@ #include "libpq/libpq-be.h" #include "libpq/sasl.h" -extern char *oauth_validator_command; +extern PGDLLIMPORT char *OAuthValidatorLibrary; + +typedef struct ValidatorModuleState +{ + void *private_data; +} ValidatorModuleState; + +typedef struct ValidatorModuleResult +{ + bool authorized; + char *authn_id; +} ValidatorModuleResult; + +typedef void (*ValidatorStartupCB) (ValidatorModuleState *state); +typedef void (*ValidatorShutdownCB) (ValidatorModuleState *state); +typedef ValidatorModuleResult *(*ValidatorValidateCB) (ValidatorModuleState *state, const char *token, const char *role); + +typedef struct OAuthValidatorCallbacks +{ + ValidatorStartupCB startup_cb; + ValidatorShutdownCB shutdown_cb; + ValidatorValidateCB validate_cb; +} OAuthValidatorCallbacks; + +typedef const OAuthValidatorCallbacks *(*OAuthValidatorModuleInit) (void); +extern PGDLLEXPORT const OAuthValidatorCallbacks *_PG_oauth_validator_module_init(void); /* Implementation */ extern const pg_be_sasl_mech pg_be_oauth_mech; diff --git a/src/test/modules/meson.build b/src/test/modules/meson.build index 8fbe742d38..dc54ce7189 100644 --- a/src/test/modules/meson.build +++ b/src/test/modules/meson.build @@ -9,6 +9,7 @@ subdir('gin') subdir('injection_points') subdir('ldap_password_func') subdir('libpq_pipeline') +subdir('oauth_validator') subdir('plsample') subdir('spgist_name_ops') subdir('ssl_passphrase_callback') diff --git a/src/test/modules/oauth_validator/.gitignore b/src/test/modules/oauth_validator/.gitignore new file mode 100644 index 0000000000..5dcb3ff972 --- /dev/null +++ b/src/test/modules/oauth_validator/.gitignore @@ -0,0 +1,4 @@ +# Generated subdirectories +/log/ +/results/ +/tmp_check/ diff --git a/src/test/modules/oauth_validator/Makefile b/src/test/modules/oauth_validator/Makefile new file mode 100644 index 0000000000..1f874cd7f2 --- /dev/null +++ b/src/test/modules/oauth_validator/Makefile @@ -0,0 +1,19 @@ +MODULES = validator +PGFILEDESC = "validator - test OAuth validator module" + +NO_INSTALLCHECK = 1 + +TAP_TESTS = 1 + +REGRESS = validator + +ifdef USE_PGXS +PG_CONFIG = pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) +else +subdir = src/test/modules/oauth_validator +top_builddir = ../../../.. +include $(top_builddir)/src/Makefile.global +include $(top_srcdir)/contrib/contrib-global.mk +endif diff --git a/src/test/modules/oauth_validator/expected/validator.out b/src/test/modules/oauth_validator/expected/validator.out new file mode 100644 index 0000000000..360caa2cb3 --- /dev/null +++ b/src/test/modules/oauth_validator/expected/validator.out @@ -0,0 +1,6 @@ +SELECT 1; + ?column? +---------- + 1 +(1 row) + diff --git a/src/test/modules/oauth_validator/meson.build b/src/test/modules/oauth_validator/meson.build new file mode 100644 index 0000000000..d9c1d1d577 --- /dev/null +++ b/src/test/modules/oauth_validator/meson.build @@ -0,0 +1,33 @@ +# Copyright (c) 2024, PostgreSQL Global Development Group + +validator_sources = files( + 'validator.c', +) + +if host_system == 'windows' + validator_sources += rc_lib_gen.process(win32ver_rc, extra_args: [ + '--NAME', 'validator', + '--FILEDESC', 'validator - test OAuth validator module',]) +endif + +validator = shared_module('validator', + validator_sources, + kwargs: pg_test_mod_args, +) +test_install_libs += validator + +tests += { + 'name': 'oauth_validator', + 'sd': meson.current_source_dir(), + 'bd': meson.current_build_dir(), + 'regress': { + 'sql': [ + 'validator', + ], + }, + 'tap': { + 'tests': [ + 't/001_server.pl', + ], + }, +} diff --git a/src/test/modules/oauth_validator/sql/validator.sql b/src/test/modules/oauth_validator/sql/validator.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/src/test/modules/oauth_validator/sql/validator.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/src/test/modules/oauth_validator/t/001_server.pl b/src/test/modules/oauth_validator/t/001_server.pl new file mode 100644 index 0000000000..14c7778298 --- /dev/null +++ b/src/test/modules/oauth_validator/t/001_server.pl @@ -0,0 +1,78 @@ + +# Copyright (c) 2021-2024, PostgreSQL Global Development Group + +use strict; +use warnings FATAL => 'all'; + +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use PostgreSQL::Test::OAuthServer; +use Test::More; + +my $node = PostgreSQL::Test::Cluster->new('primary'); +$node->init; +$node->append_conf('postgresql.conf', "log_connections = on\n"); +$node->append_conf('postgresql.conf', "shared_preload_libraries = 'validator'\n"); +$node->append_conf('postgresql.conf', "oauth_validator_library = 'validator'\n"); +$node->start; + +$node->safe_psql('postgres', 'CREATE USER test;'); +$node->safe_psql('postgres', 'CREATE USER testalt;'); + +my $issuer = "127.0.0.1:18080"; + +unlink($node->data_dir . '/pg_hba.conf'); +$node->append_conf('pg_hba.conf', qq{ +local all test oauth issuer="$issuer" scope="openid postgres" +local all testalt oauth issuer="$issuer/alternate" scope="openid postgres alt" +}); +$node->reload; + +my $webserver = PostgreSQL::Test::OAuthServer->new(18080); + +my $port = $webserver->port(); + +is($port, 18080, "Port is 18080"); + +$webserver->setup(); +$webserver->run(); + +my ($log_start, $log_end); +$log_start = $node->wait_for_log(qr/reloading configuration files/); + +my $user = "test"; +$node->connect_ok("user=$user dbname=postgres oauth_client_id=f02c6361-0635", "connect", + expected_stderr => qr@Visit https://example\.com/ and enter the code: postgresuser@); + +$log_end = $node->wait_for_log(qr/connection authorized/, $log_start); +$node->log_check("user $user: validator receives correct parameters", $log_start, + log_like => [ + qr/oauth_validator: token="9243959234", role="$user"/, + qr/oauth_validator: issuer="\Q$issuer\E", scope="openid postgres"/, + ]); +$node->log_check("user $user: validator sets authenticated identity", $log_start, + log_like => [ + qr/connection authenticated: identity="test" method=oauth/, + ]); +$log_start = $log_end; + +# The /alternate issuer uses slightly different parameters. +$user = "testalt"; +$node->connect_ok("user=$user dbname=postgres oauth_client_id=f02c6361-0636", "connect", + expected_stderr => qr@Visit https://example\.org/ and enter the code: postgresuser@); + +$log_end = $node->wait_for_log(qr/connection authorized/, $log_start); +$node->log_check("user $user: validator receives correct parameters", $log_start, + log_like => [ + qr/oauth_validator: token="9243959234-alt", role="$user"/, + qr|oauth_validator: issuer="\Q$issuer/alternate\E", scope="openid postgres alt"|, + ]); +$node->log_check("user $user: validator sets authenticated identity", $log_start, + log_like => [ + qr/connection authenticated: identity="testalt" method=oauth/, + ]); +$log_start = $log_end; + +$node->stop; + +done_testing(); diff --git a/src/test/modules/oauth_validator/validator.c b/src/test/modules/oauth_validator/validator.c new file mode 100644 index 0000000000..09a4bf61d2 --- /dev/null +++ b/src/test/modules/oauth_validator/validator.c @@ -0,0 +1,82 @@ +/*------------------------------------------------------------------------- + * + * validator.c + * Test module for serverside OAuth token validation callbacks + * + * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/test/modules/oauth_validator/validator.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include "fmgr.h" +#include "libpq/oauth.h" +#include "miscadmin.h" +#include "utils/memutils.h" + +PG_MODULE_MAGIC; + +static void validator_startup(ValidatorModuleState *state); +static void validator_shutdown(ValidatorModuleState *state); +static ValidatorModuleResult * validate_token(ValidatorModuleState *state, + const char *token, + const char *role); + +static const OAuthValidatorCallbacks validator_callbacks = { + .startup_cb = validator_startup, + .shutdown_cb = validator_shutdown, + .validate_cb = validate_token +}; + +void +_PG_init(void) +{ + /* no-op */ +} + +const OAuthValidatorCallbacks * +_PG_oauth_validator_module_init(void) +{ + return &validator_callbacks; +} + +#define PRIVATE_COOKIE ((void *) 13579) + +static void +validator_startup(ValidatorModuleState *state) +{ + state->private_data = PRIVATE_COOKIE; +} + +static void +validator_shutdown(ValidatorModuleState *state) +{ + /* do nothing */ +} + +static ValidatorModuleResult * +validate_token(ValidatorModuleState *state, const char *token, const char *role) +{ + ValidatorModuleResult *res; + + /* Check to make sure our private state still exists. */ + if (state->private_data != PRIVATE_COOKIE) + elog(ERROR, "oauth_validator: private state cookie changed to %p", + state->private_data); + + res = palloc(sizeof(ValidatorModuleResult)); + + elog(LOG, "oauth_validator: token=\"%s\", role=\"%s\"", token, role); + elog(LOG, "oauth_validator: issuer=\"%s\", scope=\"%s\"", + MyProcPort->hba->oauth_issuer, + MyProcPort->hba->oauth_scope); + + res->authorized = true; + res->authn_id = pstrdup(role); + + return res; +} diff --git a/src/test/perl/PostgreSQL/Test/Cluster.pm b/src/test/perl/PostgreSQL/Test/Cluster.pm index 4fec417f6f..b291bbf8ee 100644 --- a/src/test/perl/PostgreSQL/Test/Cluster.pm +++ b/src/test/perl/PostgreSQL/Test/Cluster.pm @@ -2302,6 +2302,11 @@ instead of the default. If this regular expression is set, matches it with the output generated. +=item expected_stderr => B + +If this regular expression is set, matches it against the standard error +stream; otherwise the stderr must be empty. + =item log_like => [ qr/required message/ ] =item log_unlike => [ qr/prohibited message/ ] @@ -2345,7 +2350,14 @@ sub connect_ok like($stdout, $params{expected_stdout}, "$test_name: stdout matches"); } - is($stderr, "", "$test_name: no stderr"); + if (defined($params{expected_stderr})) + { + like($stderr, $params{expected_stderr}, "$test_name: stderr matches"); + } + else + { + is($stderr, "", "$test_name: no stderr"); + } $self->log_check($test_name, $log_location, %params); } diff --git a/src/test/perl/PostgreSQL/Test/OAuthServer.pm b/src/test/perl/PostgreSQL/Test/OAuthServer.pm new file mode 100644 index 0000000000..5c195efb79 --- /dev/null +++ b/src/test/perl/PostgreSQL/Test/OAuthServer.pm @@ -0,0 +1,183 @@ +#!/usr/bin/perl + +package PostgreSQL::Test::OAuthServer; + +use warnings; +use strict; +use threads; +use Socket; +use IO::Select; + +local *server_socket; + +sub new +{ + my $class = shift; + my $port = shift; + + my $self = {}; + bless($self, $class); + + $self->{'port'} = $port; + + return $self; +} + +sub setup +{ + my $self = shift; + my $tcp = getprotobyname('tcp'); + + socket($self->{'socket'}, PF_INET, SOCK_STREAM, $tcp) + or die "no socket"; + setsockopt($self->{'socket'}, SOL_SOCKET, SO_REUSEADDR, pack("l", 1)); + bind($self->{'socket'}, sockaddr_in($self->{'port'}, INADDR_ANY)); +} + +sub port +{ + my $self = shift; + + return $self->{'port'}; +} + +sub run +{ + my $self = shift; + + my $server_thread = threads->create(\&_listen, $self); + $server_thread->detach(); +} + +sub _listen +{ + my $self = shift; + + listen($self->{'socket'}, SOMAXCONN) or die "fail to listen: $!"; + + while (1) + { + my $fh; + my %request; + my $remote = accept($fh, $self->{'socket'}); + binmode $fh; + + my ($method, $object, $prot) = split(/ /, <$fh>); + $request{'method'} = $method; + $request{'object'} = $object; + chomp($request{'object'}); + + local $/ = Socket::CRLF; + my $c = 0; + while(<$fh>) + { + chomp; + # Headers + if (/:/) + { + my ($field, $value) = split(/:/, $_, 2); + $value =~ s/^\s+//; + $request{'headers'}{lc $field} = $value; + } + # POST data + elsif (/^$/) + { + read($fh, $request{'content'}, $request{'headers'}{'content-length'}) + if defined $request{'headers'}{'content-length'}; + last; + } + } + + # Debug printing + # print ": read ".$request{'method'} . ";" . $request{'object'}.";\n"; + # foreach my $h (keys(%{$request{'headers'}})) + #{ + # printf ": headers: " . $request{'headers'}{$h} . "\n"; + #} + #printf ": POST: " . $request{'content'} . "\n" if defined $request{'content'}; + + my $alternate = 0; + if ($request{'object'} =~ qr|^/alternate(/.*)$|) + { + $alternate = 1; + $request{'object'} = $1; + } + + if ($request{'object'} eq '/.well-known/openid-configuration') + { + my $issuer = "http://localhost:$self->{'port'}"; + if ($alternate) + { + $issuer .= "/alternate"; + } + + print $fh "HTTP/1.0 200 OK\r\nServer: Postgres Regress\r\n"; + print $fh "Content-Type: application/json\r\n"; + print $fh "\r\n"; + print $fh <