From 2dca72bb66c086ef77a27f7d7ff0bb524b4b9108 Mon Sep 17 00:00:00 2001 From: Anthonin Bonnefoy Date: Thu, 23 May 2024 11:24:44 +0200 Subject: Fix row estimation in gather paths In parallel plans, the row count of a partial plan is estimated as (rows / parallel_divisor). The parallel_divisor is the number of parallel_workers plus a possible leader contribution. When creating a gather path, we currently estimate the sum of gathered rows as worker_rows*parallel_workers, which leads to a lower estimated row count. This patch changes the gather path row estimation to worker_rows*parallel_divisor to get a more accurate estimate. --- src/backend/optimizer/path/allpaths.c | 7 ++-- src/backend/optimizer/path/costsize.c | 19 +++++++++ src/backend/optimizer/plan/planner.c | 6 +-- src/include/optimizer/cost.h | 1 + src/test/regress/expected/join_hash.out | 19 +++++---- src/test/regress/expected/select_parallel.out | 39 +++++++++++++++++++ src/test/regress/expected/test_setup.out | 20 ++++++++++ src/test/regress/sql/select_parallel.sql | 11 ++++++ src/test/regress/sql/test_setup.sql | 21 ++++++++++ 9 files changed, 126 insertions(+), 17 deletions(-) diff --git a/src/backend/optimizer/path/allpaths.c b/src/backend/optimizer/path/allpaths.c index 4895cee994..fc72dfdeab 100644 --- a/src/backend/optimizer/path/allpaths.c +++ b/src/backend/optimizer/path/allpaths.c @@ -3071,8 +3071,7 @@ generate_gather_paths(PlannerInfo *root, RelOptInfo *rel, bool override_rows) * of partial_pathlist because of the way add_partial_path works. 
*/ cheapest_partial_path = linitial(rel->partial_pathlist); - rows = - cheapest_partial_path->rows * cheapest_partial_path->parallel_workers; + rows = compute_gather_rows(cheapest_partial_path); simple_gather_path = (Path *) create_gather_path(root, rel, cheapest_partial_path, rel->reltarget, NULL, rowsp); @@ -3090,7 +3089,7 @@ generate_gather_paths(PlannerInfo *root, RelOptInfo *rel, bool override_rows) if (subpath->pathkeys == NIL) continue; - rows = subpath->rows * subpath->parallel_workers; + rows = compute_gather_rows(subpath); path = create_gather_merge_path(root, rel, subpath, rel->reltarget, subpath->pathkeys, NULL, rowsp); add_path(rel, &path->path); @@ -3274,7 +3273,7 @@ generate_useful_gather_paths(PlannerInfo *root, RelOptInfo *rel, bool override_r subpath, useful_pathkeys, -1.0); - rows = subpath->rows * subpath->parallel_workers; + rows = compute_gather_rows(subpath); } else subpath = (Path *) create_incremental_sort_path(root, diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c index ee23ed7835..c197d3f9e4 100644 --- a/src/backend/optimizer/path/costsize.c +++ b/src/backend/optimizer/path/costsize.c @@ -217,6 +217,25 @@ clamp_row_est(double nrows) return nrows; } +/* + * compute_gather_rows + * Compute the number of rows for gather nodes. + * + * When creating a gather (merge) path, we need to estimate the sum of rows + * distributed to all workers. A worker will have an estimated row count of + * (rows / parallel_divisor). Since parallel_divisor may include the leader + * contribution, we can't simply multiply workers' rows by the number of + * parallel_workers and instead need to reuse the parallel_divisor to get a + * more accurate estimation. + */ +double +compute_gather_rows(Path *partial_path) +{ + double parallel_divisor = get_parallel_divisor(partial_path); + + return clamp_row_est(partial_path->rows * parallel_divisor); +} + /* * clamp_width_est * Force a tuple-width estimate to a sane value. 
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c index 4711f91239..c7aea3db9f 100644 --- a/src/backend/optimizer/plan/planner.c +++ b/src/backend/optimizer/plan/planner.c @@ -5370,8 +5370,8 @@ create_ordered_paths(PlannerInfo *root, root->sort_pathkeys, presorted_keys, limit_tuples); - total_groups = input_path->rows * - input_path->parallel_workers; + total_groups = compute_gather_rows(input_path); + sorted_path = (Path *) create_gather_merge_path(root, ordered_rel, sorted_path, @@ -7543,7 +7543,7 @@ gather_grouping_paths(PlannerInfo *root, RelOptInfo *rel) (presorted_keys == 0 || !enable_incremental_sort)) continue; - total_groups = path->rows * path->parallel_workers; + total_groups = compute_gather_rows(path); /* * We've no need to consider both a sort and incremental sort. We'll diff --git a/src/include/optimizer/cost.h b/src/include/optimizer/cost.h index b1c51a4e70..393fc8a9e5 100644 --- a/src/include/optimizer/cost.h +++ b/src/include/optimizer/cost.h @@ -212,5 +212,6 @@ extern PathTarget *set_pathtarget_cost_width(PlannerInfo *root, PathTarget *targ extern double compute_bitmap_pages(PlannerInfo *root, RelOptInfo *baserel, Path *bitmapqual, double loop_count, Cost *cost_p, double *tuples_p); +extern double compute_gather_rows(Path *partial_path); #endif /* COST_H */ diff --git a/src/test/regress/expected/join_hash.out b/src/test/regress/expected/join_hash.out index 262fa71ed8..4fc34a0e72 100644 --- a/src/test/regress/expected/join_hash.out +++ b/src/test/regress/expected/join_hash.out @@ -508,18 +508,17 @@ set local hash_mem_multiplier = 1.0; set local enable_parallel_hash = on; explain (costs off) select count(*) from simple r join extremely_skewed s using (id); - QUERY PLAN ------------------------------------------------------------------------ - Finalize Aggregate + QUERY PLAN +----------------------------------------------------------------- + Aggregate -> Gather Workers Planned: 1 - -> Partial Aggregate - -> 
Parallel Hash Join - Hash Cond: (r.id = s.id) - -> Parallel Seq Scan on simple r - -> Parallel Hash - -> Parallel Seq Scan on extremely_skewed s -(9 rows) + -> Parallel Hash Join + Hash Cond: (r.id = s.id) + -> Parallel Seq Scan on simple r + -> Parallel Hash + -> Parallel Seq Scan on extremely_skewed s +(8 rows) select count(*) from simple r join extremely_skewed s using (id); count diff --git a/src/test/regress/expected/select_parallel.out b/src/test/regress/expected/select_parallel.out index 5a603f86b7..f95f882704 100644 --- a/src/test/regress/expected/select_parallel.out +++ b/src/test/regress/expected/select_parallel.out @@ -1328,4 +1328,43 @@ SELECT 1 FROM tenk1_vw_sec Filter: (f1 < tenk1_vw_sec.unique1) (9 rows) +-- test estimated rows in gather nodes with different numbers of workers +EXPLAIN (COSTS OFF) +SELECT * FROM tenk1 ORDER BY twenty; + QUERY PLAN +---------------------------------------- + Gather Merge + Workers Planned: 4 + -> Sort + Sort Key: twenty + -> Parallel Seq Scan on tenk1 +(5 rows) + +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); + estimated +----------- + 10000 +(1 row) + +set max_parallel_workers_per_gather=3; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); + estimated +----------- + 10000 +(1 row) + +set max_parallel_workers_per_gather=2; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); + estimated +----------- + 10000 +(1 row) + +set max_parallel_workers_per_gather=1; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); + estimated +----------- + 9999 +(1 row) + rollback; diff --git a/src/test/regress/expected/test_setup.out b/src/test/regress/expected/test_setup.out index 3d0eeec996..8f2d863b9c 100644 --- a/src/test/regress/expected/test_setup.out +++ b/src/test/regress/expected/test_setup.out @@ -239,3 +239,23 @@ create function fipshash(text) returns text strict immutable parallel safe leakproof return substr(encode(sha256($1::bytea), 
'hex'), 1, 32); +-- get the number of estimated rows in the top node +create function get_estimated_rows(text) returns table (estimated int) +language plpgsql as +$$ +declare + ln text; + tmp text[]; + first_row bool := true; +begin + for ln in + execute format('explain %s', $1) + loop + if first_row then + first_row := false; + tmp := regexp_match(ln, 'rows=(\d*)'); + return query select tmp[1]::int; + end if; + end loop; +end; +$$; diff --git a/src/test/regress/sql/select_parallel.sql b/src/test/regress/sql/select_parallel.sql index c7df8f775c..b162cab7e9 100644 --- a/src/test/regress/sql/select_parallel.sql +++ b/src/test/regress/sql/select_parallel.sql @@ -510,4 +510,15 @@ EXPLAIN (COSTS OFF) SELECT 1 FROM tenk1_vw_sec WHERE (SELECT sum(f1) FROM int4_tbl WHERE f1 < unique1) < 100; +-- test estimated rows in gather nodes with different numbers of workers +EXPLAIN (COSTS OFF) +SELECT * FROM tenk1 ORDER BY twenty; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); +set max_parallel_workers_per_gather=3; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); +set max_parallel_workers_per_gather=2; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); +set max_parallel_workers_per_gather=1; +SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty'); + rollback; diff --git a/src/test/regress/sql/test_setup.sql b/src/test/regress/sql/test_setup.sql index 06b0e2121f..937d1619c8 100644 --- a/src/test/regress/sql/test_setup.sql +++ b/src/test/regress/sql/test_setup.sql @@ -294,3 +294,24 @@ create function fipshash(text) returns text strict immutable parallel safe leakproof return substr(encode(sha256($1::bytea), 'hex'), 1, 32); + +-- get the number of estimated rows in the top node +create function get_estimated_rows(text) returns table (estimated int) +language plpgsql as +$$ +declare + ln text; + tmp text[]; + first_row bool := true; +begin + for ln in + execute format('explain %s', $1) + loop 
+ if first_row then + first_row := false; + tmp := regexp_match(ln, 'rows=(\d*)'); + return query select tmp[1]::int; + end if; + end loop; +end; +$$; -- 2.39.3 (Apple Git-146)