From 3461694c75776fe9c876e377f207e970e6a1a530 Mon Sep 17 00:00:00 2001 From: Ashutosh Sharma Date: Thu, 13 Jun 2024 07:52:51 +0000 Subject: [PATCH] Introduce a new control file flag called 'protected' for extensions This flag controls PostgreSQL's behavior in setting the implicit search_path within the proconfig for functions created by an extension that does not have the search_path explicitly set in proconfig. When enabled, the search_path is set to $extension_schema, pg_temp, function_schema, where function_schema is included only if it differs from the extension's schema. $extension_schema resolves to all schemas on which the extension depends, a process triggered during the recomputation of the namespace path. --- src/backend/catalog/namespace.c | 22 +++++++++ src/backend/catalog/pg_depend.c | 43 +++++++++++++++++ src/backend/commands/extension.c | 66 +++++++++++++++++++++----- src/backend/commands/functioncmds.c | 73 ++++++++++++++++++++++++++++- src/backend/parser/gram.y | 41 ++++++++++++++++ src/backend/utils/fmgr/fmgr.c | 16 +++++++ src/include/catalog/dependency.h | 1 + src/include/commands/extension.h | 5 ++ 8 files changed, 254 insertions(+), 13 deletions(-) diff --git a/src/backend/catalog/namespace.c b/src/backend/catalog/namespace.c index a2510cf80c..a61f3d0c90 100644 --- a/src/backend/catalog/namespace.c +++ b/src/backend/catalog/namespace.c @@ -42,6 +42,7 @@ #include "catalog/pg_ts_template.h" #include "catalog/pg_type.h" #include "commands/dbcommands.h" +#include "commands/extension.h" #include "common/hashfn_unstable.h" #include "funcapi.h" #include "mb/pg_wchar.h" @@ -4152,6 +4153,27 @@ preprocessNamespacePath(const char *searchPath, Oid roleid, *temp_missing = true; } } + else if (strcmp(curname, "$extension_schema") == 0) + { + Oid extOid = GetCurrentExtensionId(); + List *extList = getExtensionsOfExtension(extOid); + ListCell *lc; + + extList = lappend_oid(extList, extOid); + + foreach(lc, extList) + { + extOid = lfirst_oid(lc); + + namespaceId = get_extension_schema(extOid); + if (OidIsValid(namespaceId) && + object_aclcheck(NamespaceRelationId, namespaceId, roleid, + ACL_USAGE) == ACLCHECK_OK) + oidlist = lappend_oid(oidlist, namespaceId); + } + + list_free(extList); + } else { /* normal namespace reference */ diff --git a/src/backend/catalog/pg_depend.c b/src/backend/catalog/pg_depend.c index cfd7ef51df..8a7f071c00 100644 --- a/src/backend/catalog/pg_depend.c +++ b/src/backend/catalog/pg_depend.c @@ -814,6 +814,49 @@ getAutoExtensionsOfObject(Oid classId, Oid objectId) return result; } +/* + * Return (possibly NIL) list of extensions that the given extension depends on + * in DEPENDENCY_NORMAL mode. + */ +List * +getExtensionsOfExtension(Oid objectId) +{ + List *result = NIL; + Relation depRel; + ScanKeyData key[2]; + SysScanDesc scan; + HeapTuple tup; + + depRel = table_open(DependRelationId, AccessShareLock); + + ScanKeyInit(&key[0], + Anum_pg_depend_classid, + BTEqualStrategyNumber, F_OIDEQ, + ObjectIdGetDatum(ExtensionRelationId)); + ScanKeyInit(&key[1], + Anum_pg_depend_objid, + BTEqualStrategyNumber, F_OIDEQ, + ObjectIdGetDatum(objectId)); + + scan = systable_beginscan(depRel, DependDependerIndexId, true, + NULL, 2, key); + + while (HeapTupleIsValid((tup = systable_getnext(scan)))) + { + Form_pg_depend depform = (Form_pg_depend) GETSTRUCT(tup); + + if (depform->refclassid == ExtensionRelationId && + depform->deptype == DEPENDENCY_NORMAL) + result = lappend_oid(result, depform->refobjid); + } + + systable_endscan(scan); + + table_close(depRel, AccessShareLock); + + return result; +} + /* * Detect whether a sequence is marked as "owned" by a column * diff --git a/src/backend/commands/extension.c b/src/backend/commands/extension.c index 1643c8c69a..0d22a2156d 100644 --- a/src/backend/commands/extension.c +++ b/src/backend/commands/extension.c @@ -70,6 +70,8 @@ /* Globally visible state variables */ bool creating_extension = false; Oid CurrentExtensionObject = InvalidOid; +bool create_extension_set_search_path = false; +Oid CurrentExtensionId = InvalidOid; /* * Internal data structure to hold the results of parsing a control file @@ -86,6 +88,8 @@ typedef struct ExtensionControlFile bool relocatable; /* is ALTER EXTENSION SET SCHEMA supported? */ bool superuser; /* must be superuser to install? */ bool trusted; /* allow becoming superuser on the fly? */ + bool protected; /* should we protect extension by setting implicit + * search_path for functions and procedures? */ int encoding; /* encoding of the script file, or -1 */ List *requires; /* names of prerequisite extensions */ List *no_relocate; /* names of prerequisite extensions that @@ -117,7 +121,8 @@ static Oid get_required_extension(char *reqExtensionName, char *origSchemaName, bool cascade, List *parents, - bool is_create); + bool is_create, + bool set_search_path); static void get_available_versions_for_extension(ExtensionControlFile *pcontrol, Tuplestorestate *tupstore, TupleDesc tupdesc); @@ -128,12 +133,31 @@ static void ApplyExtensionUpdates(Oid extensionOid, List *updateVersions, char *origSchemaName, bool cascade, - bool is_create); + bool is_create, + bool set_search_path); static void ExecAlterExtensionContentsRecurse(AlterExtensionContentsStmt *stmt, ObjectAddress extension, ObjectAddress object); static char *read_whole_file(const char *filename, int *length); +/* + * SetCurrentExtensionId - Set the current extension Oid. + */ +void +SetCurrentExtensionId(Oid extensionOid) +{ + CurrentExtensionId = extensionOid; +} + +/* + * GetCurrentExtensionId - Get the current extension Oid. + */ +Oid +GetCurrentExtensionId() +{ + Assert(OidIsValid(CurrentExtensionId)); + return CurrentExtensionId; +} /* * get_extension_oid - given an extension name, look up the OID @@ -585,6 +609,14 @@ parse_extension_control_file(ExtensionControlFile *control, errmsg("parameter \"%s\" requires a Boolean value", item->name))); } + else if (strcmp(item->name, "protected") == 0) + { + if (!parse_bool(item->value, &control->protected)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("parameter \"%s\" requires a Boolean value", + item->name))); + } else if (strcmp(item->name, "encoding") == 0) { control->encoding = pg_valid_server_encoding(item->value); @@ -871,7 +903,8 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control, const char *from_version, const char *version, List *requiredSchemas, - const char *schemaName, Oid schemaOid) + const char *schemaName, Oid schemaOid, + bool set_search_path) { bool switch_to_superuser = false; char *filename; @@ -992,6 +1025,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control, */ creating_extension = true; CurrentExtensionObject = extensionOid; + create_extension_set_search_path = set_search_path; PG_TRY(); { char *c_sql = read_extension_script_file(control, filename); @@ -1116,6 +1150,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control, { creating_extension = false; CurrentExtensionObject = InvalidOid; + create_extension_set_search_path = false; } PG_END_TRY(); @@ -1475,6 +1510,7 @@ CreateExtensionInternal(char *extensionName, Oid extensionOid; ObjectAddress address; ListCell *lc; + bool set_search_path = false; /* * Read the primary control file. Note we assume that it does not contain @@ -1542,6 +1578,10 @@ CreateExtensionInternal(char *extensionName, */ control = read_extension_aux_control_file(pcontrol, versionName); + /* Check if this extension requires protection */ + if (control->protected) + set_search_path = true; + /* * Determine the target schema to install the extension into */ @@ -1648,7 +1688,8 @@ CreateExtensionInternal(char *extensionName, origSchemaName, cascade, parents, - is_create); + is_create, + set_search_path); reqschema = get_extension_schema(reqext); requiredExtensions = lappend_oid(requiredExtensions, reqext); requiredSchemas = lappend_oid(requiredSchemas, reqschema); @@ -1677,7 +1718,7 @@ CreateExtensionInternal(char *extensionName, execute_extension_script(extensionOid, control, NULL, versionName, requiredSchemas, - schemaName, schemaOid); + schemaName, schemaOid, set_search_path); /* * If additional update scripts have to be executed, apply the updates as @@ -1685,7 +1726,7 @@ CreateExtensionInternal(char *extensionName, */ ApplyExtensionUpdates(extensionOid, pcontrol, versionName, updateVersions, - origSchemaName, cascade, is_create); + origSchemaName, cascade, is_create, set_search_path); return address; } @@ -1699,7 +1740,8 @@ get_required_extension(char *reqExtensionName, char *origSchemaName, bool cascade, List *parents, - bool is_create) + bool is_create, + bool set_search_path) { Oid reqExtensionOid; @@ -3115,7 +3157,7 @@ ExecAlterExtensionStmt(ParseState *pstate, AlterExtensionStmt *stmt) */ ApplyExtensionUpdates(extensionOid, control, oldVersionName, updateVersions, - NULL, false, false); + NULL, false, false, false); ObjectAddressSet(address, ExtensionRelationId, extensionOid); @@ -3137,7 +3179,8 @@ ApplyExtensionUpdates(Oid extensionOid, List *updateVersions, char *origSchemaName, bool cascade, - bool is_create) + bool is_create, + bool set_search_path) { const char *oldVersionName = initialVersion; ListCell *lcv; @@ -3232,7 +3275,8 @@ ApplyExtensionUpdates(Oid extensionOid, origSchemaName, cascade, NIL, - is_create); + is_create, + set_search_path); reqschema = get_extension_schema(reqext); requiredExtensions = lappend_oid(requiredExtensions, reqext); requiredSchemas = lappend_oid(requiredSchemas, reqschema); @@ -3269,7 +3313,7 @@ ApplyExtensionUpdates(Oid extensionOid, execute_extension_script(extensionOid, control, oldVersionName, versionName, requiredSchemas, - schemaName, schemaOid); + schemaName, schemaOid, set_search_path); /* * Update prior-version name and loop around. Since diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c index 6593fd7d81..79764f2996 100644 --- a/src/backend/commands/functioncmds.c +++ b/src/backend/commands/functioncmds.c @@ -52,6 +52,7 @@ #include "executor/functions.h" #include "funcapi.h" #include "miscadmin.h" +#include "nodes/makefuncs.h" #include "nodes/nodeFuncs.h" #include "optimizer/optimizer.h" #include "parser/analyze.h" @@ -71,6 +72,7 @@ #include "utils/snapmgr.h" #include "utils/syscache.h" #include "utils/typcache.h" +#include "utils/varlena.h" /* * Examine the RETURNS clause of the CREATE FUNCTION statement @@ -705,6 +707,25 @@ interpret_func_support(DefElem *defel) return procOid; } +/* + * Returns true if search_path is set in set_items list. + */ +static bool +IsSearchPathSet(List *set_items) +{ + ListCell *l; + + foreach(l, set_items) + { + VariableSetStmt *sstmt = lfirst_node(VariableSetStmt, l); + + if (pg_strcasecmp(sstmt->name, "search_path") == 0 && + sstmt->kind == VAR_SET_VALUE) + return true; + } + + return false; +} /* * Dissect the list of options assembled in gram.y into function @@ -726,7 +747,8 @@ compute_function_attributes(ParseState *pstate, float4 *procost, float4 *prorows, Oid *prosupport, - char *parallel_p) + char *parallel_p, + Oid namespaceId) { ListCell *option; DefElem *as_item = NULL; @@ -813,6 +835,53 @@ compute_function_attributes(ParseState *pstate, *security_definer = boolVal(security_item->arg); if (leakproof_item) *leakproof_p = boolVal(leakproof_item->arg); + + /* + * If "create_extension_set_search_path" is enabled, it indicates that the + * user has set "protected" flag inside the extension control file. + * Therefore, we must ensure that the function(s) created by an extension + * have their search_path set to trusted schema(s), which includes the + * schema where the function is being created and the search_path set by the + * extension. See execute_extension_script() for details on search_path set + * by the extension. + */ + if (creating_extension && create_extension_set_search_path) + { + /* If the search_path is already set, there is nothing to do. */ + if (!set_items || !IsSearchPathSet(set_items)) + { + StringInfoData sp_string; + VariableSetStmt *sp_node = makeNode(VariableSetStmt); + List *schemaList; + ListCell *lc; + + sp_node->kind = VAR_SET_VALUE; + sp_node->name = "search_path"; + + initStringInfo(&sp_string); + + if (namespaceId != get_extension_schema(CurrentExtensionObject)) + { + appendStringInfoString(&sp_string, get_namespace_name(namespaceId)); + appendStringInfoString(&sp_string, ", "); + } + appendStringInfoString(&sp_string, "$extension_schema, pg_temp"); + + (void) SplitIdentifierString(sp_string.data, ',', &schemaList); + + foreach(lc, schemaList) + { + char *schema_name = lfirst(lc); + + sp_node->args = lappend(sp_node->args, + makeStringConst(pstrdup(schema_name), -1)); + } + + set_items = lappend(set_items, sp_node); + pfree(sp_string.data); + } + } + if (set_items) *proconfig = update_proconfig_value(NULL, set_items); if (cost_item) @@ -1079,7 +1148,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) &isWindowFunc, &volatility, &isStrict, &security, &isLeakProof, &proconfig, &procost, &prorows, - &prosupport, ¶llel); + &prosupport, ¶llel, namespaceId); if (!language) { diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y index 4d582950b7..98923b1858 100644 --- a/src/backend/parser/gram.y +++ b/src/backend/parser/gram.y @@ -1669,6 +1669,24 @@ generic_set: VariableSetStmt *n = makeNode(VariableSetStmt); n->kind = VAR_SET_VALUE; + + if (strcmp($1, "search_path") == 0) + { + ListCell *lc; + + foreach(lc, $3) + { + void *arg = lfirst(lc); + + if (IsA(arg, A_Const) && + castNode(A_Const, arg)->val.node.type == T_String && + strcmp(castNode(A_Const, arg)->val.sval.sval, "$extension_schema") == 0) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("search_path cannot be set to $extension_schema"), + parser_errposition(((A_Const *) arg)->location))); + } + } n->name = $1; n->args = $3; $$ = n; @@ -1678,6 +1696,24 @@ generic_set: VariableSetStmt *n = makeNode(VariableSetStmt); n->kind = VAR_SET_VALUE; + + if (strcmp($1, "search_path") == 0) + { + ListCell *lc; + + foreach(lc, $3) + { + void *arg = lfirst(lc); + + if (IsA(arg, A_Const) && + castNode(A_Const, arg)->val.node.type == T_String && + strcmp(castNode(A_Const, arg)->val.sval.sval, "$extension_schema") == 0) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("search_path cannot be set to $extension_schema"), + parser_errposition(((A_Const *) arg)->location))); + } + } n->name = $1; n->args = $3; $$ = n; @@ -1737,6 +1773,11 @@ set_rest_more: /* Generic SET syntaxes: */ n->kind = VAR_SET_VALUE; n->name = "search_path"; + if (strcmp($2, "$extension_schema") == 0) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("search_path cannot be set to $extension_schema"), + parser_errposition(@2))); n->args = list_make1(makeStringConst($2, @2)); $$ = n; } diff --git a/src/backend/utils/fmgr/fmgr.c b/src/backend/utils/fmgr/fmgr.c index e48a86be54..e2211c82f3 100644 --- a/src/backend/utils/fmgr/fmgr.c +++ b/src/backend/utils/fmgr/fmgr.c @@ -16,6 +16,8 @@ #include "postgres.h" #include "access/detoast.h" +#include "commands/extension.h" +#include "catalog/dependency.h" #include "catalog/pg_language.h" #include "catalog/pg_proc.h" #include "catalog/pg_type.h" @@ -641,6 +643,15 @@ fmgr_security_definer(PG_FUNCTION_ARGS) *lc3; volatile int save_nestlevel; PgStat_FunctionCallUsage fcusage; + Oid extensionOid = InvalidOid; + + /* + * Let's check if this is an extension created function. If it is, we'll set + * the CurrentExtensionId before calling it, so that preprocessNamespacePath + * can handle $extension_schema correctly. + */ + extensionOid = getExtensionOfObject(ProcedureRelationId, + fcinfo->flinfo->fn_oid); if (!fcinfo->flinfo->fn_extra) { @@ -737,6 +748,9 @@ fmgr_security_definer(PG_FUNCTION_ARGS) */ save_flinfo = fcinfo->flinfo; + if (OidIsValid(extensionOid)) + SetCurrentExtensionId(extensionOid); + PG_TRY(); { fcinfo->flinfo = &fcache->flinfo; @@ -758,6 +772,7 @@ fmgr_security_definer(PG_FUNCTION_ARGS) PG_CATCH(); { fcinfo->flinfo = save_flinfo; + SetCurrentExtensionId(InvalidOid); if (fmgr_hook) (*fmgr_hook) (FHET_ABORT, &fcache->flinfo, &fcache->arg); PG_RE_THROW(); @@ -765,6 +780,7 @@ fmgr_security_definer(PG_FUNCTION_ARGS) PG_END_TRY(); fcinfo->flinfo = save_flinfo; + SetCurrentExtensionId(InvalidOid); if (fcache->configNames != NIL) AtEOXact_GUC(true, save_nestlevel); diff --git a/src/include/catalog/dependency.h b/src/include/catalog/dependency.h index 7eee66f810..e2874c1b16 100644 --- a/src/include/catalog/dependency.h +++ b/src/include/catalog/dependency.h @@ -174,6 +174,7 @@ extern long changeDependenciesOn(Oid refClassId, Oid oldRefObjectId, extern Oid getExtensionOfObject(Oid classId, Oid objectId); extern List *getAutoExtensionsOfObject(Oid classId, Oid objectId); +extern List *getExtensionsOfExtension(Oid objectId); extern bool sequenceIsOwned(Oid seqId, char deptype, Oid *tableId, int32 *colId); extern List *getOwnedSequences(Oid relid); diff --git a/src/include/commands/extension.h b/src/include/commands/extension.h index c6f3f867eb..9512e8109c 100644 --- a/src/include/commands/extension.h +++ b/src/include/commands/extension.h @@ -29,6 +29,8 @@ */ extern PGDLLIMPORT bool creating_extension; extern PGDLLIMPORT Oid CurrentExtensionObject; +extern PGDLLIMPORT bool create_extension_set_search_path; +extern PGDLLIMPORT Oid CurrentExtensionId; extern ObjectAddress CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt); @@ -53,4 +55,7 @@ extern bool extension_file_exists(const char *extensionName); extern ObjectAddress AlterExtensionNamespace(const char *extensionName, const char *newschema, Oid *oldschema); +extern void SetCurrentExtensionId(Oid extensionOid); +extern Oid GetCurrentExtensionId(void); + #endif /* EXTENSION_H */ -- 2.17.1