From 75e2461bef523129f0826d562a93757d994af5d6 Mon Sep 17 00:00:00 2001 From: Matheus Alcantara Date: Mon, 20 Jan 2025 15:25:51 -0300 Subject: [PATCH v5 1/2] dblink: refactor get connection routines Refactor dblink_get_conn and dblink_connect to move the logic of actually opening the connection to the new connect_pg_server function which them can be re-used on both functions. This is a pre-work for a next commit that will add support for scram pass-through authentication to dblink which will be able to implement most of the logic into the connect_pg_server function which now already have all necessary data information. --- contrib/dblink/dblink.c | 199 ++++++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 98 deletions(-) diff --git a/contrib/dblink/dblink.c b/contrib/dblink/dblink.c index bed2dee3d72..f7641f1b2ba 100644 --- a/contrib/dblink/dblink.c +++ b/contrib/dblink/dblink.c @@ -117,7 +117,7 @@ static bool dblink_connstr_has_pw(const char *connstr); static void dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr); static void dblink_res_error(PGconn *conn, const char *conname, PGresult *res, bool fail, const char *fmt,...) pg_attribute_printf(5, 6); -static char *get_connect_string(const char *servername); +static char *get_connect_string(ForeignServer *foreign_server, UserMapping *user_mapping); static char *escape_param_str(const char *str); static void validate_pkattnums(Relation rel, int2vector *pkattnums_arg, int32 pknumatts_arg, @@ -126,6 +126,7 @@ static bool is_valid_dblink_option(const PQconninfoOption *options, const char *option, Oid context); static int applyRemoteGucs(PGconn *conn); static void restoreLocalGucs(int nestlevel); +static PGconn *connect_pg_server(char *connstr_or_srvname, remoteConn *rconn, uint32 wait_event_info); /* Global */ static remoteConn *pconn = NULL; @@ -201,33 +202,11 @@ dblink_get_conn(char *conname_or_str, } else { - const char *connstr; - - connstr = get_connect_string(conname_or_str); - if (connstr == NULL) - connstr = conname_or_str; - dblink_connstr_check(connstr); - /* first time, allocate or get the custom wait event */ if (dblink_we_get_conn == 0) dblink_we_get_conn = WaitEventExtensionNew("DblinkGetConnect"); - /* OK to make connection */ - conn = libpqsrv_connect(connstr, dblink_we_get_conn); - - if (PQstatus(conn) == CONNECTION_BAD) - { - char *msg = pchomp(PQerrorMessage(conn)); - - libpqsrv_disconnect(conn); - ereport(ERROR, - (errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), - errmsg("could not establish connection"), - errdetail_internal("%s", msg))); - } - dblink_security_check(conn, rconn, connstr); - if (PQclientEncoding(conn) != GetDatabaseEncoding()) - PQsetClientEncoding(conn, GetDatabaseEncodingName()); + conn = connect_pg_server(conname_or_str, rconn, dblink_we_get_conn); freeconn = true; conname = NULL; } @@ -272,9 +251,7 @@ Datum dblink_connect(PG_FUNCTION_ARGS) { char *conname_or_str = NULL; - char *connstr = NULL; char *connname = NULL; - char *msg; PGconn *conn = NULL; remoteConn *rconn = NULL; @@ -297,40 +274,21 @@ dblink_connect(PG_FUNCTION_ARGS) rconn->newXactForCursor = false; } - /* first check for valid foreign data server */ - connstr = get_connect_string(conname_or_str); - if (connstr == NULL) - connstr = conname_or_str; - - /* check password in connection string if not superuser */ - dblink_connstr_check(connstr); - /* first time, allocate or get the custom wait event */ if (dblink_we_connect == 0) dblink_we_connect = WaitEventExtensionNew("DblinkConnect"); - /* OK to make connection */ - conn = libpqsrv_connect(connstr, dblink_we_connect); - - if (PQstatus(conn) == CONNECTION_BAD) + PG_TRY(); + { + conn = connect_pg_server(conname_or_str, rconn, dblink_we_connect); + } + PG_CATCH(); { - msg = pchomp(PQerrorMessage(conn)); - libpqsrv_disconnect(conn); if (rconn) pfree(rconn); - - ereport(ERROR, - (errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), - errmsg("could not establish connection"), - errdetail_internal("%s", msg))); + PG_RE_THROW(); } - - /* check password actually used if not superuser */ - dblink_security_check(conn, rconn, connstr); - - /* attempt to set client encoding to match server encoding, if needed */ - if (PQclientEncoding(conn) != GetDatabaseEncoding()) - PQsetClientEncoding(conn, GetDatabaseEncodingName()); + PG_END_TRY(); if (connname) { @@ -2784,15 +2742,17 @@ dblink_res_error(PGconn *conn, const char *conname, PGresult *res, * Obtain connection string for a foreign server */ static char * -get_connect_string(const char *servername) +get_connect_string(ForeignServer *foreign_server, UserMapping *user_mapping) { - ForeignServer *foreign_server = NULL; - UserMapping *user_mapping; ListCell *cell; StringInfoData buf; ForeignDataWrapper *fdw; AclResult aclresult; - char *srvname; + + /* first gather the server connstr options */ + Oid serverid = foreign_server->serverid; + Oid fdwid = foreign_server->fdwid; + Oid userid = GetUserId(); static const PQconninfoOption *options = NULL; @@ -2815,57 +2775,42 @@ get_connect_string(const char *servername) errdetail("Could not get libpq's default connection options."))); } - /* first gather the server connstr options */ - srvname = pstrdup(servername); - truncate_identifier(srvname, strlen(srvname), false); - foreign_server = GetForeignServerByName(srvname, true); - - if (foreign_server) - { - Oid serverid = foreign_server->serverid; - Oid fdwid = foreign_server->fdwid; - Oid userid = GetUserId(); - - user_mapping = GetUserMapping(userid, serverid); - fdw = GetForeignDataWrapper(fdwid); - - /* Check permissions, user must have usage on the server. */ - aclresult = object_aclcheck(ForeignServerRelationId, serverid, userid, ACL_USAGE); - if (aclresult != ACLCHECK_OK) - aclcheck_error(aclresult, OBJECT_FOREIGN_SERVER, foreign_server->servername); + fdw = GetForeignDataWrapper(fdwid); - foreach(cell, fdw->options) - { - DefElem *def = lfirst(cell); + /* Check permissions, user must have usage on the server. */ + aclresult = object_aclcheck(ForeignServerRelationId, serverid, userid, ACL_USAGE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, OBJECT_FOREIGN_SERVER, foreign_server->servername); - if (is_valid_dblink_option(options, def->defname, ForeignDataWrapperRelationId)) - appendStringInfo(&buf, "%s='%s' ", def->defname, - escape_param_str(strVal(def->arg))); - } + foreach(cell, fdw->options) + { + DefElem *def = lfirst(cell); - foreach(cell, foreign_server->options) - { - DefElem *def = lfirst(cell); + if (is_valid_dblink_option(options, def->defname, ForeignDataWrapperRelationId)) + appendStringInfo(&buf, "%s='%s' ", def->defname, + escape_param_str(strVal(def->arg))); + } - if (is_valid_dblink_option(options, def->defname, ForeignServerRelationId)) - appendStringInfo(&buf, "%s='%s' ", def->defname, - escape_param_str(strVal(def->arg))); - } + foreach(cell, foreign_server->options) + { + DefElem *def = lfirst(cell); - foreach(cell, user_mapping->options) - { + if (is_valid_dblink_option(options, def->defname, ForeignServerRelationId)) + appendStringInfo(&buf, "%s='%s' ", def->defname, + escape_param_str(strVal(def->arg))); + } - DefElem *def = lfirst(cell); + foreach(cell, user_mapping->options) + { - if (is_valid_dblink_option(options, def->defname, UserMappingRelationId)) - appendStringInfo(&buf, "%s='%s' ", def->defname, - escape_param_str(strVal(def->arg))); - } + DefElem *def = lfirst(cell); - return buf.data; + if (is_valid_dblink_option(options, def->defname, UserMappingRelationId)) + appendStringInfo(&buf, "%s='%s' ", def->defname, + escape_param_str(strVal(def->arg))); } - else - return NULL; + + return buf.data; } /* @@ -3087,3 +3032,61 @@ restoreLocalGucs(int nestlevel) if (nestlevel > 0) AtEOXact_GUC(true, nestlevel); } + +/* + * Connect to remote server. If connstr_or_srvname maps to a foreign server, + * the associated properties and user mapping properties is also used to open + * the connection. Otherwise a connection will be open using the raw + * connstr_or_srvname value. + */ +static PGconn * +connect_pg_server(char *connstr_or_srvname, remoteConn *rconn, uint32 wait_event_info) +{ + PGconn *conn; + ForeignServer *foreign_server = NULL; + const char *connstr; + char *srvname; + Oid serverid; + UserMapping *user_mapping; + Oid userid = GetUserId(); + + /* first gather the server connstr options */ + srvname = pstrdup(connstr_or_srvname); + truncate_identifier(srvname, strlen(srvname), false); + foreign_server = GetForeignServerByName(srvname, true); + + if (foreign_server) + { + serverid = foreign_server->serverid; + user_mapping = GetUserMapping(userid, serverid); + + connstr = get_connect_string(foreign_server, user_mapping); + } + else + connstr = connstr_or_srvname; + + dblink_connstr_check(connstr); + + /* OK to make connection */ + conn = libpqsrv_connect(connstr, wait_event_info); + + if (PQstatus(conn) == CONNECTION_BAD) + { + char *msg = pchomp(PQerrorMessage(conn)); + + libpqsrv_disconnect(conn); + + ereport(ERROR, + (errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), + errmsg("could not establish connection"), + errdetail_internal("%s", msg))); + } + + dblink_security_check(conn, rconn, connstr); + + /* attempt to set client encoding to match server encoding, if needed */ + if (PQclientEncoding(conn) != GetDatabaseEncoding()) + PQsetClientEncoding(conn, GetDatabaseEncodingName()); + + return conn; +} -- 2.39.5 (Apple Git-154)