From 378107f06b562460e4cfefdbc7f1b3c58ad51ab2 Mon Sep 17 00:00:00 2001 From: David Rowley Date: Mon, 12 Jul 2021 20:24:08 +1200 Subject: [PATCH v3 2/2] WIP: Add planner support for DISTINCT aggregates --- src/backend/executor/execExpr.c | 24 +++++- src/backend/executor/execExprInterp.c | 82 +++++++++++++++++++ src/backend/executor/nodeAgg.c | 21 ++++- src/backend/optimizer/plan/planner.c | 10 +-- src/include/executor/execExpr.h | 13 +++ src/include/executor/nodeAgg.h | 6 +- src/include/nodes/primnodes.h | 4 +- src/test/regress/expected/aggregates.out | 2 +- .../regress/expected/partition_aggregate.out | 12 +-- src/test/regress/expected/tuplesort.out | 40 +++++---- 10 files changed, 176 insertions(+), 38 deletions(-) diff --git a/src/backend/executor/execExpr.c b/src/backend/executor/execExpr.c index a6e9c48f11..0c84e3757a 100644 --- a/src/backend/executor/execExpr.c +++ b/src/backend/executor/execExpr.c @@ -3426,7 +3426,8 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase, /* * Normal transition function without ORDER BY / DISTINCT or with - * ORDER BY but the planner has given us pre-sorted input. + * ORDER BY / DISTINCT but the planner has given us pre-sorted + * input. */ strictargs = trans_fcinfo->args + 1; @@ -3514,6 +3515,21 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase, state->steps_len - 1); } + /* Handle DISTINCT aggregates which have pre-sorted input */ + if (pertrans->numDistinctCols > 0 && !pertrans->aggsortrequired) + { + if (pertrans->numDistinctCols > 1) + scratch.opcode = EEOP_AGG_PRESORTED_DISTINCT_MULTI; + else + scratch.opcode = EEOP_AGG_PRESORTED_DISTINCT_SINGLE; + + scratch.d.agg_presorted_distinctcheck.pertrans = pertrans; + scratch.d.agg_presorted_distinctcheck.jumpdistinct = -1; /* adjust later */ + ExprEvalPushStep(state, &scratch); + adjust_bailout = lappend_int(adjust_bailout, + state->steps_len - 1); + } + /* * Call transition function (once for each concurrently evaluated * grouping set). Do so for both sort and hash based computations, as @@ -3574,6 +3590,12 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase, Assert(as->d.agg_deserialize.jumpnull == -1); as->d.agg_deserialize.jumpnull = state->steps_len; } + else if (as->opcode == EEOP_AGG_PRESORTED_DISTINCT_SINGLE || + as->opcode == EEOP_AGG_PRESORTED_DISTINCT_MULTI) + { + Assert(as->d.agg_presorted_distinctcheck.jumpdistinct == -1); + as->d.agg_presorted_distinctcheck.jumpdistinct = state->steps_len; + } else Assert(false); } diff --git a/src/backend/executor/execExprInterp.c b/src/backend/executor/execExprInterp.c index eb49817cee..17904bdb7f 100644 --- a/src/backend/executor/execExprInterp.c +++ b/src/backend/executor/execExprInterp.c @@ -488,6 +488,8 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull) &&CASE_EEOP_AGG_PLAIN_TRANS_INIT_STRICT_BYREF, &&CASE_EEOP_AGG_PLAIN_TRANS_STRICT_BYREF, &&CASE_EEOP_AGG_PLAIN_TRANS_BYREF, + &&CASE_EEOP_AGG_PRESORTED_DISTINCT_SINGLE, + &&CASE_EEOP_AGG_PRESORTED_DISTINCT_MULTI, &&CASE_EEOP_AGG_ORDERED_TRANS_DATUM, &&CASE_EEOP_AGG_ORDERED_TRANS_TUPLE, &&CASE_EEOP_LAST @@ -1772,6 +1774,86 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull) EEO_NEXT(); } + EEO_CASE(EEOP_AGG_PRESORTED_DISTINCT_SINGLE) + { + AggStatePerTrans pertrans = op->d.agg_presorted_distinctcheck.pertrans; + Datum value = pertrans->transfn_fcinfo->args[1].value; + bool isnull = pertrans->transfn_fcinfo->args[1].isnull; + + if (!pertrans->haslast || + pertrans->lastisnull != isnull || + !DatumGetBool(FunctionCall2Coll(&pertrans->equalfnOne, + pertrans->aggCollation, + pertrans->lastdatum, value))) + { + if (pertrans->haslast && !pertrans->inputtypeByVal) + pfree(DatumGetPointer(pertrans->lastdatum)); + + pertrans->haslast = true; + if (!isnull) + { + AggState *aggstate = castNode(AggState, state->parent); + + /* + * XXX is it worth having a dedicated ByVal version of this + * operation so that we can skip switching memory contexts + * and do a simple assign rather than datumCopy below? + */ + MemoryContext oldContext; + + oldContext = MemoryContextSwitchTo(aggstate->curaggcontext->ecxt_per_tuple_memory); + + pertrans->lastdatum = datumCopy(value, pertrans->inputtypeByVal, pertrans->inputtypeLen); + + MemoryContextSwitchTo(oldContext); + } + else + pertrans->lastdatum = (Datum) 0; + pertrans->lastisnull = isnull; + EEO_NEXT(); + } + EEO_JUMP(op->d.agg_presorted_distinctcheck.jumpdistinct); + } + + EEO_CASE(EEOP_AGG_PRESORTED_DISTINCT_MULTI) + { + AggState *aggstate = castNode(AggState, state->parent); + AggStatePerTrans pertrans = op->d.agg_presorted_distinctcheck.pertrans; + ExprContext *tmpcontext = aggstate->tmpcontext; + int i; + + /* + * XXX or should we have had these values copied directly into the + * sortslot? If we did then we'd still need to copy them into the + * transfn_fcinfo->args here if we detect the tuple is distinct + * from the previous tuple. + */ + for (i = 0; i < pertrans->numTransInputs; i++) + { + pertrans->sortslot->tts_values[i] = pertrans->transfn_fcinfo->args[i + 1].value; + pertrans->sortslot->tts_isnull[i] = pertrans->transfn_fcinfo->args[i + 1].isnull; + } + + ExecClearTuple(pertrans->sortslot); + pertrans->sortslot->tts_nvalid = pertrans->numInputs; + ExecStoreVirtualTuple(pertrans->sortslot); + + tmpcontext->ecxt_outertuple = pertrans->sortslot; + tmpcontext->ecxt_innertuple = pertrans->uniqslot; + + if (!pertrans->haslast || + !ExecQual(pertrans->equalfnMulti, tmpcontext)) + { + if (pertrans->haslast) + ExecClearTuple(pertrans->uniqslot); + + pertrans->haslast = true; + ExecCopySlot(pertrans->uniqslot, pertrans->sortslot); + EEO_NEXT(); + } + EEO_JUMP(op->d.agg_presorted_distinctcheck.jumpdistinct); + } + /* process single-column ordered aggregate datum */ EEO_CASE(EEOP_AGG_ORDERED_TRANS_DATUM) { diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index e28d53c17b..80689ea466 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -1342,6 +1342,21 @@ finalize_aggregates(AggState *aggstate, pertrans, pergroupstate); } + else if (pertrans->numDistinctCols > 0 && pertrans->haslast) + { + pertrans->haslast = false; + + if (pertrans->numDistinctCols == 1) + { + if (!pertrans->inputtypeByVal && !pertrans->lastisnull) + pfree(DatumGetPointer(pertrans->lastdatum)); + + pertrans->lastisnull = false; + pertrans->lastdatum = (Datum) 0; + } + else + ExecClearTuple(pertrans->uniqslot); + } } /* @@ -4234,10 +4249,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, numSortCols = numDistinctCols = 0; pertrans->aggsortrequired = false; } - else if (aggref->aggpresorted) + else if (aggref->aggpresorted && aggref->aggdistinct == NIL) { - /* DISTINCT not yet supported for aggpresorted */ - Assert(aggref->aggdistinct == NIL); sortlist = NIL; numSortCols = numDistinctCols = 0; pertrans->aggsortrequired = false; @@ -4247,7 +4260,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, sortlist = aggref->aggdistinct; numSortCols = numDistinctCols = list_length(sortlist); Assert(numSortCols >= list_length(aggref->aggorder)); - pertrans->aggsortrequired = true; + pertrans->aggsortrequired = !aggref->aggpresorted; } else { diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c index d5b184ab52..a6c0a639f9 100644 --- a/src/backend/optimizer/plan/planner.c +++ b/src/backend/optimizer/plan/planner.c @@ -3081,7 +3081,7 @@ standard_qp_callback(PlannerInfo *root, void *extra) else root->group_pathkeys = NIL; - /* Determine pathkeys for aggregate functions with an ORDER BY */ + /* Determine pathkeys for aggregate functions with DISTINCT/ORDER BY */ if (parse->groupingSets == NIL && root->numOrderedAggs > 0 && (qp_extra->groupClause == NIL || root->group_pathkeys)) { @@ -3097,15 +3097,15 @@ standard_qp_callback(PlannerInfo *root, void *extra) if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) continue; - /* DISTINCT aggregates not yet supported by the planner */ if (aggref->aggdistinct != NIL) - continue; - - if (aggref->aggorder != NIL) + sortlist = aggref->aggdistinct; + else if (aggref->aggorder != NIL) sortlist = aggref->aggorder; else continue; + Assert(sortlist != NIL); + /* * Find the pathkeys with the most sorted derivative of the first * Aggref. For example, if we determine the pathkeys for the first diff --git a/src/include/executor/execExpr.h b/src/include/executor/execExpr.h index 6a24341faa..633eb809ac 100644 --- a/src/include/executor/execExpr.h +++ b/src/include/executor/execExpr.h @@ -252,6 +252,8 @@ typedef enum ExprEvalOp EEOP_AGG_PLAIN_TRANS_INIT_STRICT_BYREF, EEOP_AGG_PLAIN_TRANS_STRICT_BYREF, EEOP_AGG_PLAIN_TRANS_BYREF, + EEOP_AGG_PRESORTED_DISTINCT_SINGLE, + EEOP_AGG_PRESORTED_DISTINCT_MULTI, EEOP_AGG_ORDERED_TRANS_DATUM, EEOP_AGG_ORDERED_TRANS_TUPLE, @@ -658,6 +660,17 @@ typedef struct ExprEvalStep int jumpnull; } agg_plain_pergroup_nullcheck; + /* for EEOP_AGG_PRESORTED_DISTINCT_{SINGLE,MULTI} */ + struct + { + AggStatePerTrans pertrans; + ExprContext *aggcontext; + int setno; + int transno; + int setoff; + int jumpdistinct; + } agg_presorted_distinctcheck; + /* for EEOP_AGG_PLAIN_TRANS_[INIT_][STRICT_]{BYVAL,BYREF} */ /* for EEOP_AGG_ORDERED_TRANS_{DATUM,TUPLE} */ struct diff --git a/src/include/executor/nodeAgg.h b/src/include/executor/nodeAgg.h index bcd0643699..22976655e1 100644 --- a/src/include/executor/nodeAgg.h +++ b/src/include/executor/nodeAgg.h @@ -49,7 +49,8 @@ typedef struct AggStatePerTransData bool aggshared; /* - * True for ORDER BY aggregates that are not Aggref->aggpresorted + * True for ORDER BY / DISTINCT aggregates that are not + * Aggref->aggpresorted */ bool aggsortrequired; @@ -141,6 +142,9 @@ typedef struct AggStatePerTransData TupleTableSlot *sortslot; /* current input tuple */ TupleTableSlot *uniqslot; /* used for multi-column DISTINCT */ TupleDesc sortdesc; /* descriptor of input tuples */ + Datum lastdatum; /* used for single-column DISTINCT */ + bool lastisnull; /* used for single-column DISTINCT */ + bool haslast; /* got a last value for DISTINCT check */ /* * These values are working state that is initialized at the start of an diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h index 2f3cc39d4b..ad8e6cc8d9 100644 --- a/src/include/nodes/primnodes.h +++ b/src/include/nodes/primnodes.h @@ -304,8 +304,8 @@ typedef struct Param * replaced with a single argument representing the partial-aggregate * transition values. * - * aggpresorted is set by the query planner for ORDER BY aggregates where the - * query plan chosen provides presorted input for the executor. + * aggpresorted is set by the query planner for ORDER BY / DISTINCT aggregates + * where the query plan chosen provides presorted input for the executor. * * aggsplit indicates the expected partial-aggregation mode for the Aggref's * parent plan node. It's always set to AGGSPLIT_SIMPLE in the parser, but diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index ca06d41dd0..db45ba0aba 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -2224,8 +2224,8 @@ NOTICE: avg_transfn called with 3 -- shouldn't share states due to the distinctness not matching. select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one); NOTICE: avg_transfn called with 1 -NOTICE: avg_transfn called with 3 NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 NOTICE: avg_transfn called with 3 my_avg | my_sum --------+-------- diff --git a/src/test/regress/expected/partition_aggregate.out b/src/test/regress/expected/partition_aggregate.out index 484c94e585..72c240c9f7 100644 --- a/src/test/regress/expected/partition_aggregate.out +++ b/src/test/regress/expected/partition_aggregate.out @@ -959,13 +959,13 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA Group Key: pagg_tab_ml.a Filter: (avg(pagg_tab_ml.b) < '3'::numeric) -> Sort - Sort Key: pagg_tab_ml.a + Sort Key: pagg_tab_ml.a, pagg_tab_ml.c -> Seq Scan on pagg_tab_ml_p1 pagg_tab_ml -> GroupAggregate Group Key: pagg_tab_ml_5.a Filter: (avg(pagg_tab_ml_5.b) < '3'::numeric) -> Sort - Sort Key: pagg_tab_ml_5.a + Sort Key: pagg_tab_ml_5.a, pagg_tab_ml_5.c -> Append -> Seq Scan on pagg_tab_ml_p3_s1 pagg_tab_ml_5 -> Seq Scan on pagg_tab_ml_p3_s2 pagg_tab_ml_6 @@ -973,7 +973,7 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA Group Key: pagg_tab_ml_2.a Filter: (avg(pagg_tab_ml_2.b) < '3'::numeric) -> Sort - Sort Key: pagg_tab_ml_2.a + Sort Key: pagg_tab_ml_2.a, pagg_tab_ml_2.c -> Append -> Seq Scan on pagg_tab_ml_p2_s1 pagg_tab_ml_2 -> Seq Scan on pagg_tab_ml_p2_s2 pagg_tab_ml_3 @@ -1005,13 +1005,13 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA Group Key: pagg_tab_ml.a Filter: (avg(pagg_tab_ml.b) < '3'::numeric) -> Sort - Sort Key: pagg_tab_ml.a + Sort Key: pagg_tab_ml.a, pagg_tab_ml.c -> Seq Scan on pagg_tab_ml_p1 pagg_tab_ml -> GroupAggregate Group Key: pagg_tab_ml_5.a Filter: (avg(pagg_tab_ml_5.b) < '3'::numeric) -> Sort - Sort Key: pagg_tab_ml_5.a + Sort Key: pagg_tab_ml_5.a, pagg_tab_ml_5.c -> Append -> Seq Scan on pagg_tab_ml_p3_s1 pagg_tab_ml_5 -> Seq Scan on pagg_tab_ml_p3_s2 pagg_tab_ml_6 @@ -1019,7 +1019,7 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA Group Key: pagg_tab_ml_2.a Filter: (avg(pagg_tab_ml_2.b) < '3'::numeric) -> Sort - Sort Key: pagg_tab_ml_2.a + Sort Key: pagg_tab_ml_2.a, pagg_tab_ml_2.c -> Append -> Seq Scan on pagg_tab_ml_p2_s1 pagg_tab_ml_2 -> Seq Scan on pagg_tab_ml_p2_s2 pagg_tab_ml_3 diff --git a/src/test/regress/expected/tuplesort.out b/src/test/regress/expected/tuplesort.out index 418f296a3f..ef79574ecf 100644 --- a/src/test/regress/expected/tuplesort.out +++ b/src/test/regress/expected/tuplesort.out @@ -622,15 +622,17 @@ EXPLAIN (COSTS OFF) :qry; -> GroupAggregate Group Key: a.col12 Filter: (count(*) > 1) - -> Merge Join - Merge Cond: (a.col12 = b.col12) - -> Sort - Sort Key: a.col12 DESC - -> Seq Scan on test_mark_restore a - -> Sort - Sort Key: b.col12 DESC - -> Seq Scan on test_mark_restore b -(14 rows) + -> Sort + Sort Key: a.col12 DESC, a.col1 + -> Merge Join + Merge Cond: (a.col12 = b.col12) + -> Sort + Sort Key: a.col12 + -> Seq Scan on test_mark_restore a + -> Sort + Sort Key: b.col12 + -> Seq Scan on test_mark_restore b +(16 rows) :qry; col12 | count | count | count | count | count @@ -658,15 +660,17 @@ EXPLAIN (COSTS OFF) :qry; -> GroupAggregate Group Key: a.col12 Filter: (count(*) > 1) - -> Merge Join - Merge Cond: (a.col12 = b.col12) - -> Sort - Sort Key: a.col12 DESC - -> Seq Scan on test_mark_restore a - -> Sort - Sort Key: b.col12 DESC - -> Seq Scan on test_mark_restore b -(14 rows) + -> Sort + Sort Key: a.col12 DESC, a.col1 + -> Merge Join + Merge Cond: (a.col12 = b.col12) + -> Sort + Sort Key: a.col12 + -> Seq Scan on test_mark_restore a + -> Sort + Sort Key: b.col12 + -> Seq Scan on test_mark_restore b +(16 rows) :qry; col12 | count | count | count | count | count -- 2.30.2