Skip to content

Commit c21d5a4

Browse files
cty123LuciferYang
authored andcommitted
[SPARK-54270][CONNECT] SparkConnectResultSet get* methods should call checkOpen and check index boundary
### What changes were proposed in this pull request? This PR aims to do a minor correction on the current get* functions from `SparkConnectResultSet` class. As previously discussed in the PR #52947, >1. For every getter function, if the statement is closed, the ResultSet should be unusable. I have verified this with MySQL driver and Postgresql driver. > >2. Right now when index goes out of bound, it throws `java.lang.ArrayIndexOutOfBoundsException`, but based on the specification on `java.sql.ResultSet` which is implemented by `SparkConnectResultSet` class, it should throw `java.sql.SQLException` > >``` > * throws SQLException if the columnIndex is not valid; >``` This PR proposes a unified wrapper function called `getColumnValue(columnIndex: Int)` that aims to wrap the `checkOpen` as well as the index boundary check. ### Why are the changes needed? Currently the get* functions don't follow the expected behaviors of `java.sql.ResultSet`. It's technically not a big problem, but since the `SparkConnectResultSet` aims to implement the `java.sql.ResultSet` class, it should strictly follow the specification documented on the interface definition. ### Does this PR introduce _any_ user-facing change? This PR is a small fix related to a new feature introduced recently. ### How was this patch tested? I added 2 tests each covering a bullet point I named above. These 2 functions calls all the get* functions inside `SparkConnectResultSet` class to make sure the correct exception(java.sql.SQLException) is thrown. ### Was this patch authored or co-authored using generative AI tooling? No Closes #52988 from cty123/cty123/address-spark-connect-getters. Lead-authored-by: cty123 <[email protected]> Co-authored-by: cty <[email protected]> Signed-off-by: yangjie01 <[email protected]>
1 parent 03a0e05 commit c21d5a4

File tree

2 files changed

+84
-59
lines changed

2 files changed

+84
-59
lines changed

sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala

Lines changed: 30 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ class SparkConnectResultSet(
7676
}
7777
}
7878

79+
private def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
80+
checkOpen()
81+
// the passed index value is 1-indexed, but the underlying array is 0-indexed
82+
val index = columnIndex - 1
83+
if (index < 0 || index >= currentRow.length) {
84+
throw new SQLException(s"The column index is out of range: $columnIndex, " +
85+
s"number of columns: ${currentRow.length}.")
86+
}
87+
88+
if (currentRow.isNullAt(index)) {
89+
_wasNull = true
90+
defaultVal
91+
} else {
92+
_wasNull = false
93+
getter(index)
94+
}
95+
}
96+
7997
override def findColumn(columnLabel: String): Int = {
8098
sparkResult.schema.getFieldIndex(columnLabel) match {
8199
case Some(i) => i + 1
@@ -85,75 +103,35 @@ class SparkConnectResultSet(
85103
}
86104

87105
override def getString(columnIndex: Int): String = {
88-
if (currentRow.isNullAt(columnIndex - 1)) {
89-
_wasNull = true
90-
return null
91-
}
92-
_wasNull = false
93-
String.valueOf(currentRow.get(columnIndex - 1))
106+
getColumnValue(columnIndex, null: String) { idx => String.valueOf(currentRow.get(idx)) }
94107
}
95108

96109
override def getBoolean(columnIndex: Int): Boolean = {
97-
if (currentRow.isNullAt(columnIndex - 1)) {
98-
_wasNull = true
99-
return false
100-
}
101-
_wasNull = false
102-
currentRow.getBoolean(columnIndex - 1)
110+
getColumnValue(columnIndex, false) { idx => currentRow.getBoolean(idx) }
103111
}
104112

105113
override def getByte(columnIndex: Int): Byte = {
106-
if (currentRow.isNullAt(columnIndex - 1)) {
107-
_wasNull = true
108-
return 0.toByte
109-
}
110-
_wasNull = false
111-
currentRow.getByte(columnIndex - 1)
114+
getColumnValue(columnIndex, 0.toByte) { idx => currentRow.getByte(idx) }
112115
}
113116

114117
override def getShort(columnIndex: Int): Short = {
115-
if (currentRow.isNullAt(columnIndex - 1)) {
116-
_wasNull = true
117-
return 0.toShort
118-
}
119-
_wasNull = false
120-
currentRow.getShort(columnIndex - 1)
118+
getColumnValue(columnIndex, 0.toShort) { idx => currentRow.getShort(idx) }
121119
}
122120

123121
override def getInt(columnIndex: Int): Int = {
124-
if (currentRow.isNullAt(columnIndex - 1)) {
125-
_wasNull = true
126-
return 0
127-
}
128-
_wasNull = false
129-
currentRow.getInt(columnIndex - 1)
122+
getColumnValue(columnIndex, 0) { idx => currentRow.getInt(idx) }
130123
}
131124

132125
override def getLong(columnIndex: Int): Long = {
133-
if (currentRow.isNullAt(columnIndex - 1)) {
134-
_wasNull = true
135-
return 0L
136-
}
137-
_wasNull = false
138-
currentRow.getLong(columnIndex - 1)
126+
getColumnValue(columnIndex, 0.toLong) { idx => currentRow.getLong(idx) }
139127
}
140128

141129
override def getFloat(columnIndex: Int): Float = {
142-
if (currentRow.isNullAt(columnIndex - 1)) {
143-
_wasNull = true
144-
return 0.toFloat
145-
}
146-
_wasNull = false
147-
currentRow.getFloat(columnIndex - 1)
130+
getColumnValue(columnIndex, 0.toFloat) { idx => currentRow.getFloat(idx) }
148131
}
149132

150133
override def getDouble(columnIndex: Int): Double = {
151-
if (currentRow.isNullAt(columnIndex - 1)) {
152-
_wasNull = true
153-
return 0.toDouble
154-
}
155-
_wasNull = false
156-
currentRow.getDouble(columnIndex - 1)
134+
getColumnValue(columnIndex, 0.toDouble) { idx => currentRow.getDouble(idx) }
157135
}
158136

159137
override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal =
@@ -240,12 +218,9 @@ class SparkConnectResultSet(
240218
}
241219

242220
override def getObject(columnIndex: Int): AnyRef = {
243-
if (currentRow.isNullAt(columnIndex - 1)) {
244-
_wasNull = true
245-
return null
221+
getColumnValue(columnIndex, null: AnyRef) { idx =>
222+
currentRow.get(idx).asInstanceOf[AnyRef]
246223
}
247-
_wasNull = false
248-
currentRow.get(columnIndex - 1).asInstanceOf[AnyRef]
249224
}
250225

251226
override def getObject(columnLabel: String): AnyRef =
@@ -258,12 +233,9 @@ class SparkConnectResultSet(
258233
throw new SQLFeatureNotSupportedException
259234

260235
override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
261-
if (currentRow.isNullAt(columnIndex - 1)) {
262-
_wasNull = true
263-
return null
236+
getColumnValue(columnIndex, null: java.math.BigDecimal) { idx =>
237+
currentRow.getDecimal(idx)
264238
}
265-
_wasNull = false
266-
currentRow.getDecimal(columnIndex - 1)
267239
}
268240

269241
override def getBigDecimal(columnLabel: String): java.math.BigDecimal =

sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.connect.client.jdbc
1919

20-
import java.sql.Types
20+
import java.sql.{ResultSet, SQLException, Types}
2121

2222
import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper
2323
import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
@@ -248,4 +248,57 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
248248
}
249249
}
250250
}
251+
252+
test("getter functions column index out of bound") {
253+
Seq(
254+
("'foo'", (rs: ResultSet) => rs.getString(999)),
255+
("true", (rs: ResultSet) => rs.getBoolean(999)),
256+
("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)),
257+
("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)),
258+
("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)),
259+
("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)),
260+
("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)),
261+
("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)),
262+
("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999))
263+
).foreach {
264+
case (query, getter) =>
265+
withExecuteQuery(s"SELECT $query") { rs =>
266+
assert(rs.next())
267+
val exception = intercept[SQLException] {
268+
getter(rs)
269+
}
270+
assert(exception.getMessage() ===
271+
"The column index is out of range: 999, number of columns: 1.")
272+
}
273+
}
274+
}
275+
276+
test("getter functions called after statement closed") {
277+
Seq(
278+
("'foo'", (rs: ResultSet) => rs.getString(1), "foo"),
279+
("true", (rs: ResultSet) => rs.getBoolean(1), true),
280+
("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
281+
("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
282+
("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
283+
("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
284+
("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
285+
("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
286+
("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
287+
new java.math.BigDecimal("1.00000"))
288+
).foreach {
289+
case (query, getter, expectedValue) =>
290+
var resultSet: Option[ResultSet] = None
291+
withExecuteQuery(s"SELECT $query") { rs =>
292+
assert(rs.next())
293+
assert(getter(rs) === expectedValue)
294+
assert(!rs.wasNull)
295+
resultSet = Some(rs)
296+
}
297+
assert(resultSet.isDefined)
298+
val exception = intercept[SQLException] {
299+
getter(resultSet.get)
300+
}
301+
assert(exception.getMessage() === "JDBC Statement is closed.")
302+
}
303+
}
251304
}

0 commit comments

Comments
 (0)