diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c index 746d7cbb8a..ecaab21e13 100644 --- a/src/backend/libpq/auth.c +++ b/src/backend/libpq/auth.c @@ -873,6 +873,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) int inputlen; int result; bool initial; + List *channel_bindings = NIL; /* * SASL auth is not supported for protocol versions before 3, because it @@ -898,7 +899,17 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) strlen(SCRAM_SHA256_NAME) + 3); p = sasl_mechs; - if (port->ssl_in_use) +#ifdef USE_SSL + /* + * Get the list of channel binding types supported by this SSL + * implementation to determine if server should publish -PLUS + * mechanisms or not. + */ + channel_bindings = be_tls_list_channel_bindings(); +#endif + + if (port->ssl_in_use && + list_length(channel_bindings) > 0) { strcpy(p, SCRAM_SHA256_PLUS_NAME); p += strlen(SCRAM_SHA256_PLUS_NAME) + 1; diff --git a/src/backend/libpq/be-secure-openssl.c b/src/backend/libpq/be-secure-openssl.c index fc6e8a0a88..95511a61b3 100644 --- a/src/backend/libpq/be-secure-openssl.c +++ b/src/backend/libpq/be-secure-openssl.c @@ -58,6 +58,7 @@ #include #endif +#include "common/scram-common.h" #include "libpq/libpq.h" #include "miscadmin.h" #include "pgstat.h" @@ -1215,6 +1216,18 @@ be_tls_get_peerdn_name(Port *port, char *ptr, size_t len) ptr[0] = '\0'; } +/* + * Routine to get the list of channel binding types available in this SSL + * implementation. For OpenSSL, both tls-unique and tls-server-end-point + * are supported. + */ +List * +be_tls_list_channel_bindings(void) +{ + return list_make2(pstrdup(SCRAM_CHANNEL_BINDING_TLS_UNIQUE), + pstrdup(SCRAM_CHANNEL_BINDING_TLS_END_POINT)); +} + /* * Routine to get the expected TLS Finished message information from the * client, useful for authorization when doing channel binding. diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h index 49cb263110..3c37e800c1 100644 --- a/src/include/libpq/libpq-be.h +++ b/src/include/libpq/libpq-be.h @@ -209,6 +209,7 @@ extern bool be_tls_get_compression(Port *port); extern void be_tls_get_version(Port *port, char *ptr, size_t len); extern void be_tls_get_cipher(Port *port, char *ptr, size_t len); extern void be_tls_get_peerdn_name(Port *port, char *ptr, size_t len); +extern List *be_tls_list_channel_bindings(void); extern char *be_tls_get_peer_finished(Port *port, size_t *len); extern char *be_tls_get_certificate_hash(Port *port, size_t *len); #endif