diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c index 229b408..6e1a6ae 100644 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -259,6 +259,13 @@ typedef struct NumericVar NumericDigit *digits; /* base-NBASE digits */ } NumericVar; +/* Transition state for numeric average aggregate. */ +typedef struct AvgAggState +{ + Numeric sumX; + uint64 N; + size_t sumX_size; +} AvgAggState; /* ---------- * Some preinitialized constants @@ -432,7 +439,7 @@ static void trunc_var(NumericVar *var, int rscale); static void strip_var(NumericVar *var); static void compute_bucket(Numeric operand, Numeric bound1, Numeric bound2, NumericVar *count_var, NumericVar *result_var); - +static AvgAggState * makeAvgAggState(FunctionCallInfo fcinfo); /* ---------------------------------------------------------------------- * @@ -2511,38 +2518,28 @@ do_numeric_accum(ArrayType *transarray, Numeric newval) return result; } -/* - * Improve avg performance by not caclulating sum(X*X). - */ -static ArrayType * -do_numeric_avg_accum(ArrayType *transarray, Numeric newval) +static void +do_numeric_avg_accum(AvgAggState *state, Numeric newval) { - Datum *transdatums; - int ndatums; - Datum N, - sumX; - ArrayType *result; + Numeric newsumX; + size_t newsumX_size; - /* We assume the input is array of numeric */ - deconstruct_array(transarray, - NUMERICOID, -1, false, 'i', - &transdatums, NULL, &ndatums); - if (ndatums != 2) - elog(ERROR, "expected 2-element numeric array"); - N = transdatums[0]; - sumX = transdatums[1]; + /* Calculate the new value for sumX. */ + newsumX = DatumGetNumeric(DirectFunctionCall2(numeric_add, + NumericGetDatum(state->sumX), + NumericGetDatum(newval))); - N = DirectFunctionCall1(numeric_inc, N); - sumX = DirectFunctionCall2(numeric_add, sumX, - NumericGetDatum(newval)); - - transdatums[0] = N; - transdatums[1] = sumX; - - result = construct_array(transdatums, 2, - NUMERICOID, -1, false, 'i'); + /* Enlarge state->sumX to have enough space for the new sumX. */ + newsumX_size = VARSIZE(newsumX); + if (newsumX_size > state->sumX_size) + { + state->sumX = repalloc(state->sumX, newsumX_size); + state->sumX_size = newsumX_size; + } - return result; + /* Update state. */ + memcpy(state->sumX, newsumX, newsumX_size); + state->N++; } Datum @@ -2560,10 +2557,20 @@ numeric_accum(PG_FUNCTION_ARGS) Datum numeric_avg_accum(PG_FUNCTION_ARGS) { - ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); - Numeric newval = PG_GETARG_NUMERIC(1); + AvgAggState *state; + + state = PG_ARGISNULL(0) ? NULL : (AvgAggState *) PG_GETARG_POINTER(0); + + if (!PG_ARGISNULL(1)) + { + /* On the first time through, create the state variable. */ + if (state == NULL) + state = makeAvgAggState(fcinfo); + + do_numeric_avg_accum(state, PG_GETARG_NUMERIC(1)); + } - PG_RETURN_ARRAYTYPE_P(do_numeric_avg_accum(transarray, newval)); + PG_RETURN_POINTER(state); } /* @@ -2617,42 +2624,43 @@ int8_accum(PG_FUNCTION_ARGS) Datum int8_avg_accum(PG_FUNCTION_ARGS) { - ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); - Datum newval8 = PG_GETARG_DATUM(1); - Numeric newval; + AvgAggState *state; - newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric, newval8)); + state = PG_ARGISNULL(0) ? NULL : (AvgAggState *) PG_GETARG_POINTER(0); - PG_RETURN_ARRAYTYPE_P(do_numeric_avg_accum(transarray, newval)); -} + if (!PG_ARGISNULL(1)) + { + Datum newval8; + Numeric newval; + /* On the first time through, create the state variable. */ + if (state == NULL) + state = makeAvgAggState(fcinfo); + + newval8 = PG_GETARG_DATUM(1); + newval = DatumGetNumeric(DirectFunctionCall1(int8_numeric, newval8)); + + do_numeric_avg_accum(state, newval); + } + + PG_RETURN_POINTER(state); +} Datum numeric_avg(PG_FUNCTION_ARGS) { - ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); - Datum *transdatums; - int ndatums; - Numeric N, - sumX; + Datum countd; + AvgAggState *state; - /* We assume the input is array of numeric */ - deconstruct_array(transarray, - NUMERICOID, -1, false, 'i', - &transdatums, NULL, &ndatums); - if (ndatums != 2) - elog(ERROR, "expected 2-element numeric array"); - N = DatumGetNumeric(transdatums[0]); - sumX = DatumGetNumeric(transdatums[1]); - - /* SQL92 defines AVG of no values to be NULL */ - /* N is zero iff no digits (cf. numeric_uminus) */ - if (NUMERIC_NDIGITS(N) == 0) + if (PG_ARGISNULL(0)) PG_RETURN_NULL(); + state = (AvgAggState *) PG_GETARG_POINTER(0); + countd = DirectFunctionCall1(int8_numeric, Int64GetDatum(state->N)); + PG_RETURN_DATUM(DirectFunctionCall2(numeric_div, - NumericGetDatum(sumX), - NumericGetDatum(N))); + NumericGetDatum(state->sumX), + countd)); } /* @@ -6170,3 +6178,39 @@ strip_var(NumericVar *var) var->digits = digits; var->ndigits = ndigits; } + +/* + * makeAvgAggState + * + * Initialize state for numeric avg aggregate in the aggregate context + */ +static AvgAggState * +makeAvgAggState(FunctionCallInfo fcinfo) +{ + AvgAggState *state; + NumericVar *sumX_var; + MemoryContext agg_context; + MemoryContext old_context; + + if (!AggCheckCallContext(fcinfo, &agg_context)) + { + /* cannot be called directly because of internal-type argument */ + elog(ERROR, "numeric_avg_accum called in non-aggregate context"); + } + + sumX_var = palloc0(sizeof(NumericVar)); + zero_var(sumX_var); + + /* + * Create state in aggregate context. It'll stay there across subsequent + * calls. + */ + old_context = MemoryContextSwitchTo(agg_context); + state = palloc0(sizeof(AvgAggState)); + state->sumX = make_result(sumX_var); + state->sumX_size = VARSIZE(state->sumX); + state->N = 0; + MemoryContextSwitchTo(old_context); + + return state; +} diff --git a/src/include/catalog/pg_aggregate.h b/src/include/catalog/pg_aggregate.h index 6fb10a9..b34fedf 100644 --- a/src/include/catalog/pg_aggregate.h +++ b/src/include/catalog/pg_aggregate.h @@ -77,10 +77,10 @@ typedef FormData_pg_aggregate *Form_pg_aggregate; */ /* avg */ -DATA(insert ( 2100 int8_avg_accum numeric_avg 0 1231 "{0,0}" )); +DATA(insert ( 2100 int8_avg_accum numeric_avg 0 2281 _null_ )); DATA(insert ( 2101 int4_avg_accum int8_avg 0 1016 "{0,0}" )); DATA(insert ( 2102 int2_avg_accum int8_avg 0 1016 "{0,0}" )); -DATA(insert ( 2103 numeric_avg_accum numeric_avg 0 1231 "{0,0}" )); +DATA(insert ( 2103 numeric_avg_accum numeric_avg 0 2281 _null_ )); DATA(insert ( 2104 float4_accum float8_avg 0 1022 "{0,0,0}" )); DATA(insert ( 2105 float8_accum float8_avg 0 1022 "{0,0,0}" )); DATA(insert ( 2106 interval_accum interval_avg 0 1187 "{0 second,0 second}" )); diff --git a/src/include/catalog/pg_proc.h b/src/include/catalog/pg_proc.h index c97056e..f30a7a3 100644 --- a/src/include/catalog/pg_proc.h +++ b/src/include/catalog/pg_proc.h @@ -2385,7 +2385,7 @@ DATA(insert OID = 1832 ( float8_stddev_samp PGNSP PGUID 12 1 0 0 0 f f f f t f DESCR("aggregate final function"); DATA(insert OID = 1833 ( numeric_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i 2 0 1231 "1231 1700" _null_ _null_ _null_ _null_ numeric_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); -DATA(insert OID = 2858 ( numeric_avg_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i 2 0 1231 "1231 1700" _null_ _null_ _null_ _null_ numeric_avg_accum _null_ _null_ _null_ )); +DATA(insert OID = 2858 ( numeric_avg_accum PGNSP PGUID 12 1 0 0 0 f f f f f f i 2 0 2281 "2281 1700" _null_ _null_ _null_ _null_ numeric_avg_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); DATA(insert OID = 1834 ( int2_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i 2 0 1231 "1231 21" _null_ _null_ _null_ _null_ int2_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); @@ -2393,9 +2393,9 @@ DATA(insert OID = 1835 ( int4_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i 2 0 DESCR("aggregate transition function"); DATA(insert OID = 1836 ( int8_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i 2 0 1231 "1231 20" _null_ _null_ _null_ _null_ int8_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); -DATA(insert OID = 2746 ( int8_avg_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i 2 0 1231 "1231 20" _null_ _null_ _null_ _null_ int8_avg_accum _null_ _null_ _null_ )); +DATA(insert OID = 2746 ( int8_avg_accum PGNSP PGUID 12 1 0 0 0 f f f f f f i 2 0 2281 "2281 20" _null_ _null_ _null_ _null_ int8_avg_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); -DATA(insert OID = 1837 ( numeric_avg PGNSP PGUID 12 1 0 0 0 f f f f t f i 1 0 1700 "1231" _null_ _null_ _null_ _null_ numeric_avg _null_ _null_ _null_ )); +DATA(insert OID = 1837 ( numeric_avg PGNSP PGUID 12 1 0 0 0 f f f f f f i 1 0 1700 "2281" _null_ _null_ _null_ _null_ numeric_avg _null_ _null_ _null_ )); DESCR("aggregate final function"); DATA(insert OID = 2514 ( numeric_var_pop PGNSP PGUID 12 1 0 0 0 f f f f t f i 1 0 1700 "1231" _null_ _null_ _null_ _null_ numeric_var_pop _null_ _null_ _null_ )); DESCR("aggregate final function");