
Commit 81d9d1f

jerrypeng authored and viirya committed
[SPARK-53847] Add ContinuousMemorySink for Real-time Mode testing
### What changes were proposed in this pull request?

Add a new in-memory sink called "ContinuousMemorySink" to facilitate Real-time Mode (RTM) testing. This sink differs from the existing MemorySink in that it sends output back to the driver as soon as the output is generated, rather than only at the end of the batch, which is what the current MemorySink does.

### Why are the changes needed?

To facilitate RTM testing.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a simple test. Many more RTM-related tests will be added in future PRs.

### Was this patch authored or co-authored using generative AI tooling?

Closes #52550 from jerrypeng/SPARK-53847.

Authored-by: Jerry Peng <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 9aab260 commit 81d9d1f
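
For orientation, here is a minimal test-style sketch of the behavioral difference this commit targets. The query wiring is hypothetical and not part of this commit; only `allData` and `clear()` come from the new sink in the diff below.

```scala
// Hypothetical sketch, not code from this commit: with ContinuousMemorySink, results
// become visible on the driver while the query is still running, instead of only
// after a batch commits as with the existing MemorySink.
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink

val sink = new ContinuousMemorySink()
// ... wire `sink` into a real-time streaming query here (wiring not shown) ...

// Poll until the first rows arrive; no epoch/batch needs to complete first.
while (sink.allData.isEmpty) {
  Thread.sleep(100)
}
println(sink.allData)
sink.clear()
```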

File tree

3 files changed: +260 -0 lines changed


core/src/main/scala/org/apache/spark/util/RpcUtils.scala

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,15 @@ private[spark] object RpcUtils {
     rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name)
   }
 
+  def makeDriverRef(
+      name: String,
+      driverHost: String,
+      driverPort: Int,
+      rpcEnv: RpcEnv): RpcEndpointRef = {
+    Utils.checkHost(driverHost)
+    rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name)
+  }
+
   /** Returns the default Spark timeout to use for RPC ask operations. */
   def askRpcTimeout(conf: SparkConf): RpcTimeout = {
     RpcTimeout(conf, Seq(RPC_ASK_TIMEOUT.key, NETWORK_TIMEOUT.key), "120s")
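
The existing `makeDriverRef(name, conf, rpcEnv)` overload derives the driver host and port from `SparkConf`; the new overload takes them explicitly, which is what lets a writer connect to a per-query endpoint address (see `RealTimeRowWriter` further down). A minimal sketch of a call site, with all names illustrative:

```scala
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
import org.apache.spark.util.RpcUtils

// Illustrative helper (not part of this commit): resolve a named driver endpoint from
// an explicit address instead of reading spark.driver.host / spark.driver.port.
def resolveDriverEndpoint(name: String, addr: RpcAddress, rpcEnv: RpcEnv): RpcEndpointRef =
  RpcUtils.makeDriverRef(name, addr.host, addr.port, rpcEnv)
```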
New file (adds ContinuousMemorySink and its streaming write path)

Lines changed: 176 additions & 0 deletions

@@ -0,0 +1,176 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.sources

import java.util

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, SparkUnsupportedOperationException}
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.write.{
  LogicalWriteInfo,
  PhysicalWriteInfo,
  Write,
  WriteBuilder,
  WriterCommitMessage
}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.types.StructType

/**
 * A sink that stores the results in memory. This [[org.apache.spark.sql.execution.streaming.Sink]]
 * is primarily intended for use in unit tests and does not provide durability.
 * This is mostly copied from MemorySink, except that the data needs to be available not at
 * commit() time but after each write.
 */
class ContinuousMemorySink
  extends MemorySink
  with SupportsWrite {

  private val batches = new ArrayBuffer[Row]()

  override def name(): String = "ContinuousMemorySink"

  override def schema(): StructType = StructType(Nil)

  override def capabilities(): util.Set[TableCapability] = {
    util.EnumSet.of(TableCapability.STREAMING_WRITE)
  }

  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
    new WriteBuilder with SupportsStreamingUpdateAsAppend {
      private val inputSchema: StructType = info.schema()

      override def build(): Write = {
        new ContinuousMemoryWrite(batches, inputSchema)
      }
    }
  }

  /** Returns all rows that are stored in this [[Sink]]. */
  override def allData: Seq[Row] = {
    val batches = getBatches()
    batches.synchronized {
      batches.toSeq
    }
  }

  override def latestBatchId: Option[Long] = {
    None
  }

  override def latestBatchData: Seq[Row] = {
    throw new SparkUnsupportedOperationException(
      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
      messageParameters = Map("operation" -> "latestBatchData")
    )
  }

  override def dataSinceBatch(sinceBatchId: Long): Seq[Row] = {
    throw new SparkUnsupportedOperationException(
      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
      messageParameters = Map("operation" -> "dataSinceBatch")
    )
  }

  override def toDebugString: String = {
    s"${allData}"
  }

  override def write(batchId: Long, needTruncate: Boolean, newRows: Array[Row]): Unit = {
    throw new SparkUnsupportedOperationException(
      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
      messageParameters = Map("operation" -> "write")
    )
  }

  override def clear(): Unit = synchronized {
    batches.clear()
  }

  private def getBatches(): ArrayBuffer[Row] = {
    batches
  }

  override def toString(): String = "ContinuousMemorySink"
}

class ContinuousMemoryWrite(batches: ArrayBuffer[Row], schema: StructType) extends Write {
  override def toStreaming: StreamingWrite = {
    new ContinuousMemoryStreamingWrite(batches, schema)
  }
}

/**
 * An RPC endpoint that receives rows and stores them to the ArrayBuffer in real-time.
 */
class MemoryRealTimeRpcEndpoint(
    override val rpcEnv: RpcEnv,
    schema: StructType,
    batches: ArrayBuffer[Row]
  ) extends ThreadSafeRpcEndpoint {
  private val encoder = ExpressionEncoder(schema).resolveAndBind().createDeserializer()

  override def receive: PartialFunction[Any, Unit] = {
    case rows: Array[InternalRow] =>
      // The synchronized block is technically optional since this is already a
      // ThreadSafeRpcEndpoint; it is kept just to be safe.
      batches.synchronized {
        rows.foreach { row =>
          batches += encoder(row)
        }
      }
  }
}

class ContinuousMemoryStreamingWrite(val batches: ArrayBuffer[Row], schema: StructType)
  extends StreamingWrite {

  private val memoryEndpoint =
    new MemoryRealTimeRpcEndpoint(
      SparkEnv.get.rpcEnv,
      schema,
      batches
    )
  @volatile private var endpointRef: RpcEndpointRef = _

  override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
    val endpointName = s"MemoryRealTimeRpcEndpoint-${java.util.UUID.randomUUID()}"
    endpointRef = memoryEndpoint.rpcEnv.setupEndpoint(endpointName, memoryEndpoint)
    RealTimeRowWriterFactory(endpointName, endpointRef.address)
  }

  override def useCommitCoordinator(): Boolean = false

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // Nothing to commit: the rows have already been delivered to the driver.
    // Just tear down the per-query RPC endpoint.
    if (endpointRef != null) {
      memoryEndpoint.rpcEnv.stop(endpointRef)
    }
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    if (endpointRef != null) {
      memoryEndpoint.rpcEnv.stop(endpointRef)
    }
  }
}
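
The only non-obvious piece in `MemoryRealTimeRpcEndpoint` is the encoder line: executors ship `InternalRow`s, and the endpoint turns them back into external `Row`s with an `ExpressionEncoder` deserializer before buffering. A standalone sketch of that conversion, with the schema and value purely illustrative:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Same call chain as the endpoint uses: build a row encoder for the write schema,
// resolve and bind it, then deserialize InternalRow -> Row.
val schema = StructType(Seq(StructField("value", IntegerType)))
val deserialize = ExpressionEncoder(schema).resolveAndBind().createDeserializer()

val external: Row = deserialize(InternalRow(42))
assert(external.getInt(0) == 42)
```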
New file (adds RealTimeRowWriterFactory and RealTimeRowWriter)

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.util.RpcUtils

/**
 * A [[StreamingDataWriterFactory]] that creates [[RealTimeRowWriter]], which sends rows to
 * the driver in real time through RPC.
 *
 * Note that, because it sends all rows to the driver, this factory will generally be unsuitable
 * for production-quality sinks. It is intended for use in tests.
 */
case class RealTimeRowWriterFactory(
    driverEndpointName: String,
    driverEndpointAddr: RpcAddress
  ) extends StreamingDataWriterFactory {
  override def createWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = {
    new RealTimeRowWriter(
      driverEndpointName,
      driverEndpointAddr
    )
  }
}

/**
 * A [[DataWriter]] that sends arrays of rows to the driver in real time through RPC.
 */
class RealTimeRowWriter(
    driverEndpointName: String,
    driverEndpointAddr: RpcAddress
  ) extends DataWriter[InternalRow] {

  private val endpointRef = RpcUtils.makeDriverRef(
    driverEndpointName,
    driverEndpointAddr.host,
    driverEndpointAddr.port,
    SparkEnv.get.rpcEnv
  )

  // Spark reuses the same `InternalRow` instance, so copy each row before sending it.
  override def write(row: InternalRow): Unit = {
    endpointRef.send(Array(row.copy()))
  }

  override def commit(): WriterCommitMessage = { null }

  override def abort(): Unit = {}

  override def close(): Unit = {}
}
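
The `row.copy()` in `RealTimeRowWriter.write` matters because Spark typically reuses one mutable `InternalRow` instance across calls to `write`; without the copy, anything buffered or passed on by reference could change underneath the reader. A standalone sketch of the hazard, using a `GenericInternalRow` for illustration:

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

val reused = new GenericInternalRow(Array[Any](1))
val buffered = ArrayBuffer[InternalRow]()

buffered += reused                      // keep the live reference (the bug copy() avoids)
reused.update(0, 2)                     // Spark "reuses" the row for the next record
assert(buffered.head.getInt(0) == 2)    // the earlier entry has silently changed

buffered.clear()
buffered += reused.copy()               // what RealTimeRowWriter does before send()
reused.update(0, 3)
assert(buffered.head.getInt(0) == 2)    // the copy is unaffected
```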
