
Commit ea04238

jiateoh authored and HeartSaVioR committed
[SPARK-53870][PYTHON][SS][4.0] Fix partial read bug for large proto messages in TransformWithStateInPySparkStateServer
### What changes were proposed in this pull request?

This is a branch-4.0 PR for #52539. The description is copied and updated below (4.0 has a slightly different test setup and only provides pandas tests).

Fix the TransformWithState StateServer's `parseProtoMessage` method to fully read the desired message using the correct [`readFully` DataInputStream API](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/io/DataInput.html#readFully(byte%5B%5D)) rather than `read` (InputStream/FilterInputStream), which may return after reading only the currently available bytes rather than the full message. [`readFully` (DataInputStream)](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/io/DataInput.html#readFully(byte%5B%5D)) keeps fetching until it fills the provided buffer. In addition to the linked API docs, this StackOverflow post also illustrates the difference between the two APIs: https://stackoverflow.com/a/25900095

### Why are the changes needed?

For large state values used in the TransformWithState API, `inputStream.read` is not guaranteed to read `messageLen` bytes of data, per the InputStream API contract. For large values, `read` can return prematurely, leaving `messageBytes` only partially filled and yielding an incorrect, likely unparseable proto message. This is not a common scenario; testing indicated that the proto messages had to be fairly large to consistently trigger this error. The added test case uses 512KB strings in the state value updates.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

(Note: compared to the original PR, this 4.0 branch organizes tests differently and only supports the pandas tests.)

Added a new test case using 512KB strings:
- Value state update
- List state update with 3 (different) values (note: list state provides a multi-value update API, so this message is even larger than the other two)
- Map state update with a single key/value

```
build/sbt -Phive -Phive-thriftserver -DskipTests package
python/run-tests --testnames 'pyspark.sql.tests.pandas.test_pandas_transform_with_state TransformWithStateInPandasTests'
```

The configured data size (512KB) triggers an incomplete read while still completing in a reasonable time (within 30s on my laptop). I separately tested a larger input size of 4MB, which took 30 minutes and was too expensive to include in the test.

Below are sample test results from using `read` only (i.e., without the fix) and adding a check on message length vs. bytes read ([test code is included in this commit](b68cfd7) but reverted later for the PR); a sketch of that check appears after the Scala diff below. The check is no longer required after the `readFully` fix, since the provided API handles it.

```
TransformWithStateInPandasTests
pyspark.errors.exceptions.base.PySparkRuntimeError: Error updating map state value: TESTING: Failed to read message bytes: expected 524369 bytes, but only read 261312 bytes
```

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code (claude-sonnet-4-5-20250929)

Closes #52596 from jiateoh/tws_readFully_fix-4.0.

Authored-by: Jason Teoh <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
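As an illustration of the difference (a minimal sketch, not Spark code; `ReadFullySketch`, `readLoop`, and `readMessage` are hypothetical names): a single `InputStream.read` call may fill only part of the buffer, while `DataInputStream.readFully` loops internally until the buffer is full or the stream ends.

```scala
import java.io.{DataInputStream, EOFException, InputStream}

object ReadFullySketch {
  // What readFully does internally: keep calling read() until the buffer is full.
  def readLoop(in: InputStream, buf: Array[Byte]): Unit = {
    var offset = 0
    while (offset < buf.length) {
      // read() may return after fetching only the bytes currently available
      val n = in.read(buf, offset, buf.length - offset)
      if (n < 0) throw new EOFException(s"stream ended after $offset of ${buf.length} bytes")
      offset += n
    }
  }

  // Length-prefixed framing as in parseProtoMessage: an int header, then the payload.
  def readMessage(in: DataInputStream): Array[Byte] = {
    val messageLen = in.readInt()
    val buf = new Array[Byte](messageLen)
    in.readFully(buf) // blocks until all messageLen bytes are read, or throws EOFException
    buf
  }
}
```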
1 parent 1a9051a · commit ea04238

File tree

2 files changed (+84, -1 lines)


python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py

Lines changed: 83 additions & 0 deletions
```diff
@@ -1601,6 +1601,49 @@ def check_exception(error):
             check_exception=check_exception,
         )
 
+    def test_transform_with_state_in_pandas_large_values(self):
+        """Test large state values (512KB) to validate readFully fix for SPARK-53870"""
+
+        def check_results(batch_df, batch_id):
+            batch_df.collect()
+            target_size_bytes = 512 * 1024
+            large_string = "a" * target_size_bytes
+            expected_list_elements = ",".join(
+                [large_string, large_string + "b", large_string + "c"]
+            )
+            expected_map_result = f"large_string_key:{large_string}"
+
+            assert set(batch_df.sort("id").collect()) == {
+                Row(
+                    id="0",
+                    valueStateResult=large_string,
+                    listStateResult=expected_list_elements,
+                    mapStateResult=expected_map_result,
+                ),
+                Row(
+                    id="1",
+                    valueStateResult=large_string,
+                    listStateResult=expected_list_elements,
+                    mapStateResult=expected_map_result,
+                ),
+            }
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("valueStateResult", StringType(), True),
+                StructField("listStateResult", StringType(), True),
+                StructField("mapStateResult", StringType(), True),
+            ]
+        )
+
+        self._test_transform_with_state_in_pandas_basic(
+            PandasLargeValueStatefulProcessor(),
+            check_results,
+            single_batch=True,
+            output_schema=output_schema,
+        )
+
 
 class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
     # this dict is the same as input initial state dataframe
@@ -2374,6 +2417,46 @@ def close(self) -> None:
         pass
 
 
+class PandasLargeValueStatefulProcessor(StatefulProcessor):
+    """Test processor for large state values (512KB) to validate readFully fix"""
+
+    def init(self, handle: StatefulProcessorHandle):
+        value_state_schema = StructType([StructField("value", StringType(), True)])
+        self.value_state = handle.getValueState("valueState", value_state_schema)
+
+        list_state_schema = StructType([StructField("value", StringType(), True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+
+        self.map_state = handle.getMapState("mapState", "key string", "value string")
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+        target_size_bytes = 512 * 1024
+        large_string = "a" * target_size_bytes
+
+        self.value_state.update((large_string,))
+        value_retrieved = self.value_state.get()[0]
+
+        self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
+        list_retrieved = list(self.list_state.get())
+        list_elements = ",".join([elem[0] for elem in list_retrieved])
+
+        map_key = ("large_string_key",)
+        self.map_state.updateValue(map_key, (large_string,))
+        map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"
+
+        yield pd.DataFrame(
+            {
+                "id": key,
+                "valueStateResult": [value_retrieved],
+                "listStateResult": [list_elements],
+                "mapStateResult": [map_retrieved],
+            }
+        )
+
+    def close(self) -> None:
+        pass
+
+
 class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase):
     pass
 
```
sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -190,7 +190,7 @@ class TransformWithStateInPandasStateServer(
   private def parseProtoMessage(): StateRequest = {
     val messageLen = inputStream.readInt()
     val messageBytes = new Array[Byte](messageLen)
-    inputStream.read(messageBytes)
+    inputStream.readFully(messageBytes)
     StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
   }
 
```
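For reference, the temporary diagnostic mentioned in the test notes above (included in commit b68cfd7 but reverted for the PR) can be approximated as below; this is a sketch, and the exact reverted code may differ:

```scala
// Sketch of the reverted diagnostic: with the plain read() call, compare the
// number of bytes actually read against the expected message length.
val bytesRead = inputStream.read(messageBytes)
if (bytesRead != messageLen) {
  throw new RuntimeException(
    s"TESTING: Failed to read message bytes: expected $messageLen bytes, " +
      s"but only read $bytesRead bytes")
}
```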