[SPARK-47877][SS][CONNECT] Speed up test_parity_listener
### What changes were proposed in this pull request?
This PR makes test_parity_listener run faster.

The test was slow because `TestListenerSparkV1` and `TestListenerSparkV2` make server calls and have long wait times, and the test runs on both listeners. They were created to verify listener functionality with and without the new `onQueryIdle` callback.

This PR fixes the slowness by removing the V1 and V2 variants of that listener (there is now only a single `TestListenerSpark`) and creating lightweight `TestListenerLocalV1` and `TestListenerLocalV2` for the `onQueryIdle` verification, as sketched below.
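
For illustration, here is a minimal sketch (not part of the patch text itself) contrasting the two listener styles. The class names, the `listener_start_events` table, and the shown callback bodies follow the diff below; the `pass` stubs abbreviate callbacks that the real test implements with the same pattern.

```python
import pyspark.cloudpickle
from pyspark.sql.streaming.listener import StreamingQueryListener


# Heavyweight: each callback issues Spark commands (createDataFrame + saveAsTable),
# so every query event triggers extra server work and the test needs long waits.
class TestListenerSpark(StreamingQueryListener):
    def onQueryStarted(self, event):
        e = pyspark.cloudpickle.dumps(event)
        df = self.spark.createDataFrame(data=[(e,)])
        df.write.mode("append").saveAsTable("listener_start_events")

    def onQueryProgress(self, event):
        pass  # same pattern in the real test, writing to "listener_progress_events"

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        pass  # same pattern in the real test, writing to "listener_terminated_events"


# Lightweight: callbacks only append to in-memory lists, so the other tests can
# assert on received events without extra server calls or long fixed sleeps.
class TestListenerLocalV2(StreamingQueryListener):
    def __init__(self):
        self.start, self.progress, self.terminated = [], [], []

    def onQueryStarted(self, event):
        self.start.append(event)

    def onQueryProgress(self, event):
        self.progress.append(event)

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        self.terminated.append(event)
```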

### Why are the changes needed?

Faster and more stable CI.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

Test-only change.

### Was this patch authored or co-authored using generative AI tooling?

no

Closes apache#46072 from WweiL/speed-up-test.

Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
WweiL authored and HyukjinKwon committed Apr 16, 2024
1 parent 5321353 commit 86837d3
Showing 2 changed files with 71 additions and 69 deletions.
119 changes: 60 additions & 59 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -26,47 +26,45 @@


# Listeners that has spark commands in callback handler functions
# V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
# `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
class TestListenerSparkV1(StreamingQueryListener):
class TestListenerSpark(StreamingQueryListener):
def onQueryStarted(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_start_events_v1")
df.write.mode("append").saveAsTable("listener_start_events")

def onQueryProgress(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_progress_events_v1")
df.write.mode("append").saveAsTable("listener_progress_events")

def onQueryIdle(self, event):
pass

def onQueryTerminated(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_terminated_events_v1")
df.write.mode("append").saveAsTable("listener_terminated_events")


# V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
class TestListenerSparkV2(StreamingQueryListener):
# V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
# `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
class TestListenerLocalV1(StreamingQueryListener):
def __init__(self):
self.start = []
self.progress = []
self.terminated = []

def onQueryStarted(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_start_events_v2")
self.start.append(event)

def onQueryProgress(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_progress_events_v2")

def onQueryIdle(self, event):
pass
self.progress.append(event)

def onQueryTerminated(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
df.write.mode("append").saveAsTable("listener_terminated_events_v2")
self.terminated.append(event)


class TestListenerLocal(StreamingQueryListener):
class TestListenerLocalV2(StreamingQueryListener):
def __init__(self):
self.start = []
self.progress = []
@@ -87,19 +85,29 @@ def onQueryTerminated(self, event):

class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
def test_listener_management(self):
listener1 = TestListenerLocal()
listener2 = TestListenerLocal()
listener1 = TestListenerLocalV1()
listener2 = TestListenerLocalV2()

try:
self.spark.streams.addListener(listener1)
self.spark.streams.addListener(listener2)
q = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
q = (
self.spark.readStream.format("rate")
.load()
.writeStream.format("noop")
.queryName("test_local")
.start()
)

# Both listeners should have listener events already because onQueryStarted
# is always called before DataStreamWriter.start() returns
self.assertEqual(len(listener1.start), 1)
self.assertEqual(len(listener2.start), 1)
self.check_start_event(listener1.start[0])
self.check_start_event(listener2.start[0])

while q.lastProgress is None:
q.awaitTermination(0.5)
# removeListener is a blocking call, resources are cleaned up by the time it returns
self.spark.streams.removeListener(listener1)
self.spark.streams.removeListener(listener2)
@@ -109,12 +117,13 @@ def test_listener_management(self):
q.stop()

# need to wait a while before QueryTerminatedEvent reaches client
time.sleep(15)
while len(listener1.terminated) == 0:
time.sleep(1)

self.assertEqual(len(listener1.terminated), 1)

self.check_start_event(listener1.start[0])
for event in listener1.progress:
self.check_progress_event(event)
self.check_progress_event(event, is_stateful=False)
self.check_terminated_event(listener1.terminated[0])

finally:
@@ -125,7 +134,7 @@

def test_slow_query(self):
try:
listener = TestListenerLocal()
listener = TestListenerLocalV2()
self.spark.streams.addListener(listener)

slow_query = (
@@ -177,7 +186,7 @@ def onQueryTerminated(self, e):
raise Exception("I'm so sorry!")

try:
listener_good = TestListenerLocal()
listener_good = TestListenerLocalV2()
listener_bad = UselessListener()
self.spark.streams.addListener(listener_good)
self.spark.streams.addListener(listener_bad)
@@ -200,8 +209,14 @@ def onQueryTerminated(self, e):
q.stop()

def test_listener_events_spark_command(self):
def verify(test_listener, table_postfix):
try:
test_listener = TestListenerSpark()

try:
with self.table(
"listener_start_events",
"listener_progress_events",
"listener_terminated_events",
):
self.spark.streams.addListener(test_listener)

# This ensures the read socket on the server won't crash (i.e. because of timeout)
@@ -214,56 +229,42 @@ def verify(test_listener, table_postfix):
q = (
df_stateful.writeStream.format("noop")
.queryName("test")
.outputMode("complete")
.outputMode("update")
.trigger(processingTime="5 seconds")
.start()
)

self.assertTrue(q.isActive)
# ensure at least one batch is ran
while q.lastProgress is None or q.lastProgress["batchId"] == 0:
q.awaitTermination(5)
q.awaitTermination(0.5)
q.stop()
self.assertFalse(q.isActive)

# Sleep to make sure listener_terminated_events is written successfully
time.sleep(60)

start_table_name = "listener_start_events" + table_postfix
progress_tbl_name = "listener_progress_events" + table_postfix
terminated_tbl_name = "listener_terminated_events" + table_postfix
time.sleep(
60
) # Sleep to make sure listener_terminated_events is written successfully

start_event = pyspark.cloudpickle.loads(
self.spark.read.table(start_table_name).collect()[0][0]
self.spark.read.table("listener_start_events").collect()[0][0]
)

progress_event = pyspark.cloudpickle.loads(
self.spark.read.table(progress_tbl_name).collect()[0][0]
self.spark.read.table("listener_progress_events").collect()[0][0]
)

terminated_event = pyspark.cloudpickle.loads(
self.spark.read.table(terminated_tbl_name).collect()[0][0]
self.spark.read.table("listener_terminated_events").collect()[0][0]
)

self.check_start_event(start_event)
self.check_progress_event(progress_event)
self.check_progress_event(progress_event, is_stateful=True)
self.check_terminated_event(terminated_event)

finally:
self.spark.streams.removeListener(test_listener)

# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)

with self.table(
"listener_start_events_v1",
"listener_progress_events_v1",
"listener_terminated_events_v1",
"listener_start_events_v2",
"listener_progress_events_v2",
"listener_terminated_events_v2",
):
verify(TestListenerSparkV1(), "_v1")
verify(TestListenerSparkV2(), "_v2")
finally:
self.spark.streams.removeListener(test_listener)
# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)


if __name__ == "__main__":
21 changes: 11 additions & 10 deletions python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -39,16 +39,16 @@ def check_start_event(self, event):
self.assertTrue(isinstance(event, QueryStartedEvent))
self.assertTrue(isinstance(event.id, uuid.UUID))
self.assertTrue(isinstance(event.runId, uuid.UUID))
self.assertTrue(event.name is None or event.name == "test")
self.assertTrue(event.name is None or event.name.startswith("test"))
try:
datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
except ValueError:
self.fail("'%s' is not in ISO 8601 format.")

def check_progress_event(self, event):
def check_progress_event(self, event, is_stateful):
"""Check QueryProgressEvent"""
self.assertTrue(isinstance(event, QueryProgressEvent))
self.check_streaming_query_progress(event.progress)
self.check_streaming_query_progress(event.progress, is_stateful)

def check_terminated_event(self, event, exception=None, error_class=None):
"""Check QueryTerminatedEvent"""
@@ -65,12 +65,12 @@ def check_terminated_event(self, event, exception=None, error_class=None):
else:
self.assertEqual(event.errorClassOnException, None)

def check_streaming_query_progress(self, progress):
def check_streaming_query_progress(self, progress, is_stateful):
"""Check StreamingQueryProgress"""
self.assertTrue(isinstance(progress, StreamingQueryProgress))
self.assertTrue(isinstance(progress.id, uuid.UUID))
self.assertTrue(isinstance(progress.runId, uuid.UUID))
self.assertEqual(progress.name, "test")
self.assertTrue(progress.name.startswith("test"))
try:
json.loads(progress.json)
except Exception:
@@ -108,9 +108,10 @@ def check_streaming_query_progress(self, progress):
self.assertTrue(all(map(lambda v: isinstance(v, str), progress.eventTime.values())))

self.assertTrue(isinstance(progress.stateOperators, list))
self.assertTrue(len(progress.stateOperators) >= 1)
for so in progress.stateOperators:
self.check_state_operator_progress(so)
if is_stateful:
self.assertTrue(len(progress.stateOperators) >= 1)
for so in progress.stateOperators:
self.check_state_operator_progress(so)

self.assertTrue(isinstance(progress.sources, list))
self.assertTrue(len(progress.sources) >= 1)
@@ -313,7 +314,7 @@ def verify(test_listener):
self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()

self.check_start_event(start_event)
self.check_progress_event(progress_event)
self.check_progress_event(progress_event, True)
self.check_terminated_event(terminated_event)

# Check query terminated with exception
@@ -470,7 +471,7 @@ def test_streaming_query_progress_fromJson(self):
"""
progress = StreamingQueryProgress.fromJson(json.loads(progress_json))

self.check_streaming_query_progress(progress)
self.check_streaming_query_progress(progress, True)

# checks for progress
self.assertEqual(progress.id, uuid.UUID("00000000-0000-0001-0000-000000000001"))
