From 9ca67658d19e6c258eb4021a326ed7d38b3ab75f Mon Sep 17 00:00:00 2001
From: David Rowley <drowley@postgresql.org>
Date: Thu, 17 Oct 2024 14:25:08 +1300
Subject: [PATCH] Don't store intermediate hash values in ExprState->resvalue

adf97c156 made it so ExprStates could support hashing and changed Hash
Join to use that instead of manually extracting Datums from tuples and
hashing them one column at a time.

When hashing multiple columns or expressions, the code added in that
commit stored the intermediate hash value in the ExprState's resvalue
field.  That was a mistake as steps may be injected into the ExprState
between each hashing step that look at or overwrite the stored
intermediate hash value.  EEOP_PARAM_SET is an example of such a step.

Here we fix this by adding a new dedicated field for storing
intermediate hash values and adjust the code so that all apart from the
final hashing step store their result in the intermediate field.

In passing, rename a variable so that it's more aligned to the
surrounding code and also so a few lines stay within the 80 char margin.

Reported-by: Andres Freund
Reviewed-by: Alena Rybakina <a.rybakina@postgrespro.ru>
Discussion: https://postgr.es/m/CAApHDvqo9eenEFXND5zZ9JxO_k4eTA4jKMGxSyjdTrsmYvnmZw@mail.gmail.com
---
 src/backend/executor/execExpr.c       | 35 +++++++++++++++--
 src/backend/executor/execExprInterp.c | 16 ++++----
 src/backend/jit/llvm/llvmjit_expr.c   | 18 ++++++---
 src/include/executor/execExpr.h       |  1 +
 src/test/regress/expected/join.out    | 55 +++++++++++++++++++++++++++
 src/test/regress/sql/join.sql         | 21 ++++++++++
 6 files changed, 129 insertions(+), 17 deletions(-)

diff --git a/src/backend/executor/execExpr.c b/src/backend/executor/execExpr.c
index c8077aa57b..a343d0bc6a 100644
--- a/src/backend/executor/execExpr.c
+++ b/src/backend/executor/execExpr.c
@@ -3996,6 +3996,7 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 {
 	ExprState  *state = makeNode(ExprState);
 	ExprEvalStep scratch = {0};
+	NullableDatum *iresult = NULL;
 	List	   *adjust_jumps = NIL;
 	ListCell   *lc;
 	ListCell   *lc2;
@@ -4009,6 +4010,14 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 	/* Insert setup steps as needed. */
 	ExecCreateExprSetupSteps(state, (Node *) hash_exprs);
 
+	/*
+	 * When hashing more than 1 expression or if we have an init value, we
+	 * need somewhere to store the intermediate hash value so that it's
+	 * available to be combined with the result of subsequent hashing.
+	 */
+	if (list_length(hash_exprs) > 1 || init_value != 0)
+		iresult = palloc(sizeof(NullableDatum));
+
 	if (init_value == 0)
 	{
 		/*
@@ -4024,8 +4033,8 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 		/* Set up operation to set the initial value. */
 		scratch.opcode = EEOP_HASHDATUM_SET_INITVAL;
 		scratch.d.hashdatum_initvalue.init_value = UInt32GetDatum(init_value);
-		scratch.resvalue = &state->resvalue;
-		scratch.resnull = &state->resnull;
+		scratch.resvalue = &iresult->value;
+		scratch.resnull = &iresult->isnull;
 
 		ExprEvalPushStep(state, &scratch);
 
@@ -4063,8 +4072,26 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 						&fcinfo->args[0].value,
 						&fcinfo->args[0].isnull);
 
-		scratch.resvalue = &state->resvalue;
-		scratch.resnull = &state->resnull;
+		if (i == list_length(hash_exprs) - 1)
+		{
+			/* the result for hashing the final expr is stored in the state */
+			scratch.resvalue = &state->resvalue;
+			scratch.resnull = &state->resnull;
+		}
+		else
+		{
+			Assert(iresult != NULL);
+
+			/* intermediate values are stored in an intermediate result */
+			scratch.resvalue = &iresult->value;
+			scratch.resnull = &iresult->isnull;
+		}
+
+		/*
+		 * NEXT32 opcodes need to look at the intermediate result.  We might
+		 * as well just set this for all ops.  FIRSTs won't look at it.
+		 */
+		scratch.d.hashdatum.iresult = iresult;
 
 		/* Initialize function call parameter structure too */
 		InitFunctionCallInfoData(*fcinfo, finfo, 1, inputcollid, NULL, NULL);
diff --git a/src/backend/executor/execExprInterp.c b/src/backend/executor/execExprInterp.c
index 9fd988cc99..6a7f18f6de 100644
--- a/src/backend/executor/execExprInterp.c
+++ b/src/backend/executor/execExprInterp.c
@@ -1600,10 +1600,11 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 		EEO_CASE(EEOP_HASHDATUM_NEXT32)
 		{
 			FunctionCallInfo fcinfo = op->d.hashdatum.fcinfo_data;
-			uint32		existing_hash = DatumGetUInt32(*op->resvalue);
+			uint32		existinghash;
 
+			existinghash = DatumGetUInt32(op->d.hashdatum.iresult->value);
 			/* combine successive hash values by rotating */
-			existing_hash = pg_rotate_left32(existing_hash, 1);
+			existinghash = pg_rotate_left32(existinghash, 1);
 
 			/* leave the hash value alone on NULL inputs */
 			if (!fcinfo->args[0].isnull)
@@ -1612,10 +1613,10 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 
 				/* execute hash func and combine with previous hash value */
 				hashvalue = DatumGetUInt32(op->d.hashdatum.fn_addr(fcinfo));
-				existing_hash = existing_hash ^ hashvalue;
+				existinghash = existinghash ^ hashvalue;
 			}
 
-			*op->resvalue = UInt32GetDatum(existing_hash);
+			*op->resvalue = UInt32GetDatum(existinghash);
 			*op->resnull = false;
 
 			EEO_NEXT();
@@ -1638,15 +1639,16 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 			}
 			else
 			{
-				uint32		existing_hash = DatumGetUInt32(*op->resvalue);
+				uint32		existinghash;
 				uint32		hashvalue;
 
+				existinghash = DatumGetUInt32(op->d.hashdatum.iresult->value);
 				/* combine successive hash values by rotating */
-				existing_hash = pg_rotate_left32(existing_hash, 1);
+				existinghash = pg_rotate_left32(existinghash, 1);
 
 				/* execute hash func and combine with previous hash value */
 				hashvalue = DatumGetUInt32(op->d.hashdatum.fn_addr(fcinfo));
-				*op->resvalue = UInt32GetDatum(existing_hash ^ hashvalue);
+				*op->resvalue = UInt32GetDatum(existinghash ^ hashvalue);
 				*op->resnull = false;
 			}
 
diff --git a/src/backend/jit/llvm/llvmjit_expr.c b/src/backend/jit/llvm/llvmjit_expr.c
index 48ccdb942a..0b3b5748ea 100644
--- a/src/backend/jit/llvm/llvmjit_expr.c
+++ b/src/backend/jit/llvm/llvmjit_expr.c
@@ -1940,13 +1940,16 @@ llvm_compile_expr(ExprState *state)
 					{
 						LLVMValueRef v_tmp1;
 						LLVMValueRef v_tmp2;
+						LLVMValueRef tmp;
+
+						tmp = l_ptr_const(&op->d.hashdatum.iresult->value,
+										  l_ptr(TypeSizeT));
 
 						/*
 						 * Fetch the previously hashed value from where the
-						 * EEOP_HASHDATUM_FIRST operation stored it.
+						 * previous hash operation stored it.
 						 */
-						v_prevhash = l_load(b, TypeSizeT, v_resvaluep,
-											"prevhash");
+						v_prevhash = l_load(b, TypeSizeT, tmp, "prevhash");
 
 						/*
 						 * Rotate bits left by 1 bit.  Be careful not to
@@ -2062,13 +2065,16 @@ llvm_compile_expr(ExprState *state)
 					{
 						LLVMValueRef v_tmp1;
 						LLVMValueRef v_tmp2;
+						LLVMValueRef tmp;
+
+						tmp = l_ptr_const(&op->d.hashdatum.iresult->value,
+										  l_ptr(TypeSizeT));
 
 						/*
 						 * Fetch the previously hashed value from where the
-						 * EEOP_HASHDATUM_FIRST_STRICT operation stored it.
+						 * previous hash operation stored it.
 						 */
-						v_prevhash = l_load(b, TypeSizeT, v_resvaluep,
-											"prevhash");
+						v_prevhash = l_load(b, TypeSizeT, tmp, "prevhash");
 
 						/*
 						 * Rotate bits left by 1 bit.  Be careful not to
diff --git a/src/include/executor/execExpr.h b/src/include/executor/execExpr.h
index eec0aa699e..cd97dfa062 100644
--- a/src/include/executor/execExpr.h
+++ b/src/include/executor/execExpr.h
@@ -580,6 +580,7 @@ typedef struct ExprEvalStep
 			/* faster to access without additional indirection: */
 			PGFunction	fn_addr;	/* actual call address */
 			int			jumpdone;	/* jump here on null */
+			NullableDatum *iresult; /* intermediate hash result */
 		}			hashdatum;
 
 		/* for EEOP_CONVERT_ROWTYPE */
diff --git a/src/test/regress/expected/join.out b/src/test/regress/expected/join.out
index 756c2e2496..5669ed929a 100644
--- a/src/test/regress/expected/join.out
+++ b/src/test/regress/expected/join.out
@@ -2358,6 +2358,61 @@ where b.f1 = t.thousand and a.f1 = b.f1 and (a.f1+b.f1+999) = t.tenthous;
 ----+----+----------+----------
 (0 rows)
 
+--
+-- Test hash joins with multiple hash keys and subplans.
+--
+-- First ensure we get a hash join with multiple hash keys.
+explain (costs off)
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+                                             QUERY PLAN                                             
+----------------------------------------------------------------------------------------------------
+ Sort
+   Sort Key: t1.unique1
+   ->  Hash Join
+         Hash Cond: ((t1.two = t2.two) AND (t1.unique1 = (SubPlan 2)))
+         ->  Bitmap Heap Scan on tenk1 t1
+               Recheck Cond: (unique1 < 10)
+               ->  Bitmap Index Scan on tenk1_unique1
+                     Index Cond: (unique1 < 10)
+         ->  Hash
+               ->  Bitmap Heap Scan on tenk1 t2
+                     Recheck Cond: (unique1 < 10)
+                     ->  Bitmap Index Scan on tenk1_unique1
+                           Index Cond: (unique1 < 10)
+               SubPlan 2
+                 ->  Result
+                       InitPlan 1
+                         ->  Limit
+                               ->  Index Only Scan using tenk1_unique1 on tenk1
+                                     Index Cond: ((unique1 IS NOT NULL) AND (unique1 = t2.unique1))
+(19 rows)
+
+-- Ensure we get the expected result
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+ unique1 | unique1 
+---------+---------
+       0 |       0
+       1 |       1
+       2 |       2
+       3 |       3
+       4 |       4
+       5 |       5
+       6 |       6
+       7 |       7
+       8 |       8
+       9 |       9
+(10 rows)
+
 --
 -- checks for correct handling of quals in multiway outer joins
 --
diff --git a/src/test/regress/sql/join.sql b/src/test/regress/sql/join.sql
index 0c65e5af4b..73474bb64f 100644
--- a/src/test/regress/sql/join.sql
+++ b/src/test/regress/sql/join.sql
@@ -441,6 +441,27 @@ select a.f1, b.f1, t.thousand, t.tenthous from
   (select sum(f1) as f1 from int4_tbl i4b) b
 where b.f1 = t.thousand and a.f1 = b.f1 and (a.f1+b.f1+999) = t.tenthous;
 
+--
+-- Test hash joins with multiple hash keys and subplans.
+--
+
+-- First ensure we get a hash join with multiple hash keys.
+explain (costs off)
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+
+-- Ensure we get the expected result
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+
 --
 -- checks for correct handling of quals in multiway outer joins
 --