Skip to content

Commit 78d1d52

Browse files
kelvinjian-dbcloud-fan
authored andcommitted
[SPARK-54339][SQL] Fix AttributeMap non-determinism
### What changes were proposed in this pull request? This PR fixes the `+`, `updated`, and `removed` methods of `AttributeMap` to correctly hash with `Attribute.ExprId` instead of `Attribute` as a whole. ### Why are the changes needed? This change fixes non-determinism with the `AttributeMap` when an entry is being added to the `AttributeMap` with `+` such that `attr1 != attr2` but `attr1.exprId = attr2.exprId`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a new test suite. ### Was this patch authored or co-authored using generative AI tooling? Tests were generated by Claude Code on Sonnet 4.5. Closes #53044 from kelvinjian-db/fix-attributemap. Authored-by: Kelvin Jiang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent a9c1fba commit 78d1d52

File tree

2 files changed

+281
-3
lines changed

2 files changed

+281
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
4848
override def contains(k: Attribute): Boolean = get(k).isDefined
4949

5050
override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
51-
AttributeMap(baseMap.values.toMap + kv)
51+
new AttributeMap(baseMap + (kv._1.exprId -> kv))
5252

5353
override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =
54-
baseMap.values.toMap + (key -> value)
54+
this + (key -> value)
5555

5656
override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
5757

58-
override def removed(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
58+
override def removed(key: Attribute): Map[Attribute, A] = new AttributeMap(baseMap - key.exprId)
5959

6060
def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap ++ other.baseMap)
6161
}
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType}
22+
23+
class AttributeMapSuite extends SparkFunSuite {
24+
25+
val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
26+
val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
27+
val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
28+
29+
val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
30+
val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
31+
32+
val cAttr = AttributeReference("c", StringType)(exprId = ExprId(4))
33+
34+
test("basic map operations - get") {
35+
val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
36+
37+
// Should find by exprId, not by attribute equality
38+
assert(map.get(aLower) === Some("value1"))
39+
assert(map.get(aUpper) === Some("value1"))
40+
assert(map.get(bLower) === Some("value2"))
41+
assert(map.get(bUpper) === Some("value2"))
42+
43+
// Different exprId should not be found
44+
assert(map.get(fakeA) === None)
45+
}
46+
47+
test("basic map operations - contains") {
48+
val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
49+
50+
// Should find by exprId, not by attribute equality
51+
assert(map.contains(aLower))
52+
assert(map.contains(aUpper))
53+
assert(map.contains(bUpper))
54+
assert(!map.contains(fakeA))
55+
}
56+
57+
test("basic map operations - getOrElse") {
58+
val map = AttributeMap(Seq((aUpper, "value1")))
59+
60+
assert(map.getOrElse(aLower, "default") === "value1")
61+
assert(map.getOrElse(fakeA, "default") === "default")
62+
}
63+
64+
test("+ operator preserves ExprId-based hashing") {
65+
val map1 = AttributeMap(Seq((aUpper, "value1")))
66+
val map2 = map1 + (bUpper -> "value2")
67+
68+
// The resulting map should still be an AttributeMap
69+
assert(map2.isInstanceOf[AttributeMap[_]])
70+
71+
// Should look up by exprId, not by attribute equality
72+
assert(map2.get(aLower) === Some("value1"))
73+
assert(map2.get(bLower) === Some("value2"))
74+
}
75+
76+
test("+ operator with attribute having different metadata") {
77+
val metadata1 = new MetadataBuilder().putString("key", "value1").build()
78+
val metadata2 = new MetadataBuilder().putString("key", "value2").build()
79+
80+
// Two attributes with same exprId but different metadata
81+
val attrWithMetadata1 = AttributeReference("col", IntegerType, metadata = metadata1)(
82+
exprId = ExprId(100))
83+
val attrWithMetadata2 = AttributeReference("col", IntegerType, metadata = metadata2)(
84+
exprId = ExprId(100))
85+
86+
// These should have different hashCodes but same exprId
87+
assert(attrWithMetadata1.hashCode() != attrWithMetadata2.hashCode(),
88+
"Attributes with different metadata should have different hashCodes")
89+
assert(attrWithMetadata1.exprId == attrWithMetadata2.exprId,
90+
"Attributes should have the same exprId")
91+
92+
// Create a map with the first attribute
93+
val map1 = AttributeMap(Seq((attrWithMetadata1, "original")))
94+
95+
// Add an entry using the + operator
96+
val map2 = map1 + (cAttr -> "new")
97+
98+
// CRITICAL: The map should still find the original entry using an attribute
99+
// with the same exprId but different metadata
100+
assert(map2.get(attrWithMetadata2) === Some("original"),
101+
"AttributeMap should look up by exprId, not by attribute hashCode")
102+
103+
// And the new entry should also be present
104+
assert(map2.get(cAttr) === Some("new"))
105+
}
106+
107+
test("+ operator updates existing key") {
108+
val map1 = AttributeMap(Seq((aUpper, "value1")))
109+
val map2 = map1 + (aLower -> "updated")
110+
111+
// Since aLower has the same exprId as aUpper, it should update the value
112+
assert(map2.get(aUpper) === Some("updated"))
113+
assert(map2.get(aLower) === Some("updated"))
114+
assert(map2.size === 1)
115+
}
116+
117+
test("+ operator with type widening") {
118+
val map1: AttributeMap[String] = AttributeMap(Seq((aUpper, "value1")))
119+
val map2: AttributeMap[Any] = map1 + (bUpper -> 42)
120+
121+
assert(map2.get(aUpper) === Some("value1"))
122+
assert(map2.get(bUpper) === Some(42))
123+
}
124+
125+
test("++ operator preserves AttributeMap semantics") {
126+
val map1 = AttributeMap(Seq((aUpper, "value1")))
127+
val map2 = AttributeMap(Seq((bUpper, "value2")))
128+
val combined = map1 ++ map2
129+
130+
assert(combined.isInstanceOf[AttributeMap[_]])
131+
assert(combined.get(aLower) === Some("value1"))
132+
assert(combined.get(bLower) === Some("value2"))
133+
}
134+
135+
test("updated method") {
136+
val map1 = AttributeMap(Seq((aUpper, "value1")))
137+
val map2 = map1.updated(bUpper, "value2")
138+
139+
// Note: updated returns a Map[Attribute, B1], not AttributeMap
140+
assert(map2.get(aUpper) === Some("value1"))
141+
assert(map2.get(bUpper) === Some("value2"))
142+
}
143+
144+
test("- operator (removal)") {
145+
val map1 = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
146+
val map2 = map1 - aLower
147+
148+
// Note: - returns a Map[Attribute, A], not AttributeMap
149+
assert(map2.get(aUpper) === None)
150+
assert(map2.get(bUpper) === Some("value2"))
151+
}
152+
153+
test("iterator") {
154+
val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
155+
val entries = map.iterator.toSeq
156+
157+
assert(entries.size === 2)
158+
assert(entries.contains((aUpper, "value1")))
159+
assert(entries.contains((bUpper, "value2")))
160+
}
161+
162+
test("size") {
163+
val emptyMap = AttributeMap.empty[String]
164+
assert(emptyMap.size === 0)
165+
166+
val map1 = AttributeMap(Seq((aUpper, "value1")))
167+
assert(map1.size === 1)
168+
169+
val map2 = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
170+
assert(map2.size === 2)
171+
}
172+
173+
test("empty map") {
174+
val emptyMap = AttributeMap.empty[String]
175+
assert(emptyMap.get(aUpper) === None)
176+
assert(emptyMap.size === 0)
177+
assert(!emptyMap.contains(aUpper))
178+
}
179+
180+
test("duplicate keys in construction") {
181+
// When constructing with duplicate exprIds, the last one should win
182+
val map = AttributeMap(Seq((aUpper, "value1"), (aLower, "value2")))
183+
assert(map.size === 1)
184+
assert(map.get(aUpper) === Some("value2"))
185+
}
186+
187+
test("map construction from Map") {
188+
val regularMap = Map(aUpper -> "value1", bUpper -> "value2")
189+
val attrMap = AttributeMap(regularMap)
190+
191+
assert(attrMap.get(aLower) === Some("value1"))
192+
assert(attrMap.get(bLower) === Some("value2"))
193+
}
194+
195+
test("chained + operations") {
196+
val map = AttributeMap.empty[String] + (aUpper -> "value1") + (bUpper -> "value2") +
197+
(cAttr -> "value3")
198+
199+
assert(map.size === 3)
200+
assert(map.get(aLower) === Some("value1"))
201+
assert(map.get(bLower) === Some("value2"))
202+
assert(map.get(cAttr) === Some("value3"))
203+
}
204+
205+
test("+ should be deterministic with Attributes with diff hashcodes and same exprId") {
206+
// The HashMap needs to be of a certain size before the hashing starts to collide, set up
207+
// these AttributeMaps to start with size 5.
208+
var map1 = AttributeMap(
209+
Seq(
210+
AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
211+
AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
212+
AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
213+
AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
214+
AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
215+
)
216+
)
217+
var map2 = AttributeMap(
218+
Seq(
219+
AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
220+
AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
221+
AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
222+
AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
223+
AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
224+
)
225+
)
226+
val qualifier1 = Seq("d")
227+
val qualifier2 = Seq()
228+
val exprId = ExprId(42)
229+
val attr1 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier1)
230+
val attr2 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier2)
231+
assert(attr1.hashCode != attr2.hashCode)
232+
233+
map1 = map1 + (attr1 -> 100)
234+
map1 = map1 + (attr2 -> 200)
235+
assert(map1.get(attr2) === Some(200))
236+
237+
map2 = map2 + (attr2 -> 200)
238+
map2 = map2 + (attr1 -> 100)
239+
assert(map2.get(attr2) === Some(100))
240+
}
241+
242+
test("updated should be deterministic with Attributes with diff hashcodes and same exprId") {
243+
// The HashMap needs to be of a certain size before the hashing starts to collide, set up
244+
// these AttributeMaps to start with size 5.
245+
var map1: Map[Attribute, Int] = AttributeMap(
246+
Seq(
247+
AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
248+
AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
249+
AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
250+
AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
251+
AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
252+
)
253+
)
254+
var map2: Map[Attribute, Int] = AttributeMap(
255+
Seq(
256+
AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
257+
AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
258+
AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
259+
AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
260+
AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
261+
)
262+
)
263+
val qualifier1 = Seq("d")
264+
val qualifier2 = Seq()
265+
val exprId = ExprId(42)
266+
val attr1 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier1)
267+
val attr2 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier2)
268+
assert(attr1.hashCode != attr2.hashCode)
269+
270+
map1 = map1.updated(attr1, 100)
271+
map1 = map1.updated(attr2, 200)
272+
assert(map1.get(attr2) === Some(200))
273+
274+
map2 = map2.updated(attr2, 200)
275+
map2 = map2.updated(attr1, 100)
276+
assert(map2.get(attr2) === Some(100))
277+
}
278+
}

0 commit comments

Comments
 (0)