Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ PYTHON_VERSION_MAJOR:=$(shell $(PYTHON) -c "import sys;print(sys.version_info[0]
PLATFORM := $(shell uname)
VERSION :=$(shell poetry version | sed 's/stomp.py\s*//g' | sed 's/\./, /g')
SHELL=/bin/bash
ARTEMIS_VERSION=2.22.0
ARTEMIS_VERSION=2.23.1
TEST_CMD := $(shell podman network exists stomptest &> /dev/null && echo "podman unshare --rootless-netns poetry" || echo "poetry")

all: test install
Expand Down
6 changes: 5 additions & 1 deletion stomp/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,18 @@ def connect(self, *args, **kwargs):
self.transport.start()
Protocol11.connect(self, *args, **kwargs)

def disconnect(self, receipt=None, headers=None, **keyword_headers):
def disconnect(self, receipt=None, headers=None, wait=False, **keyword_headers):
"""
Call the protocol disconnection, and then stop the transport itself.

:param str receipt: the receipt to use with the disconnect
:param dict headers: a map of any additional headers to send with the disconnection
:param bool wait: wait for the started messages to finish ack/nack before disconnection
:param keyword_headers: any additional headers to send with the disconnection
"""
if wait:
self.transport.begin_stop()

Protocol11.disconnect(self, receipt, headers, **keyword_headers)
if receipt is not None:
self.transport.stop()
Expand Down
89 changes: 68 additions & 21 deletions stomp/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ class BaseTransport(stomp.listener.Publisher):
__content_length_re = re.compile(b"^content-length[:]\\s*(?P<value>[0-9]+)", re.MULTILINE)

def __init__(self, auto_decode=True, encoding="utf-8", is_eol_fc=is_eol_default):
self.__receiver_thread_sending_condition = threading.Condition()
self.__receiver_thread_sent = True
self.__recvbuf = b""
self.listeners = {}
self.running = False
self.receiving = True
self.blocking = None
self.connected = False
self.connection_error = False
Expand All @@ -79,7 +82,7 @@ def __init__(self, auto_decode=True, encoding="utf-8", is_eol_fc=is_eol_default)
self.__listeners_change_condition = threading.Condition()
self.__receiver_thread_exit_condition = threading.Condition()
self.__receiver_thread_exited = False
self.__send_wait_condition = threading.Condition()
self.__receipt_wait_condition = threading.Condition()
self.__connect_wait_condition = threading.Condition()
self.__auto_decode = auto_decode
self.__encoding = encoding
Expand Down Expand Up @@ -112,6 +115,13 @@ def start(self):
logging.info("Created thread %s using func %s", receiver_thread, self.create_thread_fc)
self.notify("connecting")

def begin_stop(self):
"""
Begin stop of the connection. Stops reading new messages but keep thread to finish ack/nack of messages.
"""
# emit stop reading new messages
self.receiving = False

def stop(self):
"""
Stop the connection. Performs a clean shutdown by waiting for the
Expand Down Expand Up @@ -206,9 +216,9 @@ def notify(self, frame_type, frame=None):
# logic for wait-on-receipt notification
receipt = frame.headers["receipt-id"]
receipt_value = self.__receipts.get(receipt)
with self.__send_wait_condition:
with self.__receipt_wait_condition:
self.set_receipt(receipt, None)
self.__send_wait_condition.notify()
self.__receipt_wait_condition.notifyAll()

if receipt_value == CMD_DISCONNECT:
self.set_connected(False)
Expand All @@ -232,7 +242,7 @@ def notify(self, frame_type, frame=None):
if not notify_func:
logging.debug("listener %s has no method on_%s", listener, frame_type)
continue
if frame_type in ("heartbeat", "disconnected"):
if frame_type in ("disconnecting", "heartbeat", "disconnected"):
notify_func()
continue
if frame_type == "connecting":
Expand All @@ -252,26 +262,36 @@ def transmit(self, frame):

:param Frame frame: the Frame object to transmit
"""
with self.__listeners_change_condition:
listeners = sorted(self.listeners.items())
with self.__receiver_thread_sending_condition:
self.__receiver_thread_sent = False
self.__receiver_thread_sending_condition.notify_all()

for (_, listener) in listeners:
try:
listener.on_send(frame)
except AttributeError:
continue
try:
with self.__listeners_change_condition:
listeners = sorted(self.listeners.items())

if frame.cmd == CMD_DISCONNECT and HDR_RECEIPT in frame.headers:
self.__disconnect_receipt = frame.headers[HDR_RECEIPT]
for (_, listener) in listeners:
try:
listener.on_send(frame)
except AttributeError:
continue

lines = convert_frame(frame)
packed_frame = pack(lines)
if frame.cmd == CMD_DISCONNECT and HDR_RECEIPT in frame.headers:
self.__disconnect_receipt = frame.headers[HDR_RECEIPT]

if logging.isEnabledFor(logging.DEBUG):
logging.debug("Sending frame: %s", clean_lines(lines))
else:
logging.info("Sending frame: %r", frame.cmd or "heartbeat")
self.send(packed_frame)
lines = convert_frame(frame)
packed_frame = pack(lines)

if logging.isEnabledFor(logging.DEBUG):
logging.debug("Sending frame: %s", clean_lines(lines))
else:
logging.info("Sending frame: %r", frame.cmd or "heartbeat")
self.send(packed_frame)

finally:
with self.__receiver_thread_sending_condition:
self.__receiver_thread_sent = True
self.__receiver_thread_sending_condition.notify_all()

def send(self, encoded_frame):
"""
Expand Down Expand Up @@ -323,6 +343,29 @@ def wait_for_connection(self, timeout=None):
if not self.running or not self.is_connected():
raise exception.ConnectFailedException()

def wait_for_receipt(self, receipt_id, timeout=None):
"""
Wait until we've received a receipt from the server.

:param str receipt_id: the receipt_id
:param float timeout: how long to wait, in seconds
"""
if timeout is not None:
wait_time = timeout / 10.0
else:
wait_time = None
with self.__receipt_wait_condition:
while self.__receipts.get(receipt_id):
self.__receipt_wait_condition.wait(wait_time)

def __wait_finish_processing_received_messages(self):
# wait to finish process messages in progress
with self.__receiver_thread_sending_condition:
while not self.__receiver_thread_sent:
self.__receiver_thread_sending_condition.wait()

self.stop()

def __receiver_loop(self):
"""
Main loop listening for incoming data.
Expand All @@ -332,7 +375,7 @@ def __receiver_loop(self):
try:
while self.running:
try:
while self.running:
while self.running and self.receiving:
frames = self.__read()

for frame in frames:
Expand All @@ -345,6 +388,10 @@ def __receiver_loop(self):
if self.__auto_decode:
f.body = decode(f.body)
self.process_frame(f, frame)

if self.running and not self.receiving:
self.__wait_finish_processing_received_messages()

except exception.ConnectionClosedException:
if self.running:
#
Expand Down
110 changes: 110 additions & 0 deletions tests/test_disconnect_wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging

import stomp
from stomp.listener import TestListener
from .testutils import *


class BrokenConnectionListener(TestListener):
def __init__(self, connection=None):
TestListener.__init__(self)
self.connection = connection
self.messages_started = 0
self.messages_completed = 0

def on_error(self, frame):
TestListener.on_error(self, frame)
assert frame.body.startswith("org.apache.activemq.transport.stomp.ProtocolException: Not connected")

def on_message(self, frame):
TestListener.on_message(self, frame)
self.messages_started += 1

if self.connection.is_connected():
try:
self.connection.ack(frame.headers["message-id"], frame.headers["subscription"])
self.messages_completed += 1
except BrokenPipeError:
logging.error("Expected BrokenPipeError")
self.errors += 1


def conn():
c = stomp.Connection11(get_default_host(), try_loopback_connect=False)
c.set_listener("testlistener", BrokenConnectionListener(c))
c.connect(get_default_user(), get_default_password(), wait=True)
return c


def run_race_condition_situation(conn, wait):
# happens when using ack mode "client-individual"
# some load, eg > 50 messages received at same time (simulated with transaction)
listener = conn.get_listener("testlistener") # type: BrokenConnectionListener

queuename = "/queue/disconnectmidack-%s" % listener.timestamp
conn.subscribe(destination=queuename, id=1, ack="client-individual")

trans_id = conn.begin()
for i in range(50):
conn.send(body="test message", destination=queuename, transaction=trans_id)
conn.commit(transaction=trans_id)

listener.wait_for_message()
conn.disconnect(wait=wait)

# wait for some messages to start between the time of disconnect start and finish (when the race condition happens)
# needed to check result of listener.errors
time.sleep(0.5)

# return listener for asserts
return listener


def assert_race_condition_disconnect_mid_ack(conn, wait=False):
listener = run_race_condition_situation(conn, wait)

started = listener.messages_started
logging.debug("messages started %d", started)

assert listener.connections == 1, "should have received 1 connection acknowledgement"
assert listener.messages == started, f"should have received {started} message"

# Causes either BrokenPipeError or ProtocolException: Not connected
assert listener.errors >= 1, "should have at least one error"
assert listener.messages_started > listener.messages_completed, f"should have not completed all started"


def assert_no_race_condition_disconnect_mid_ack(conn, wait=False):
listener = run_race_condition_situation(conn, wait)

started = listener.messages_started
logging.debug("T%s : messages started %d", started, threading.get_native_id())

assert listener.connections == 1, "should have received 1 connection acknowledgement"
assert listener.messages == started, f"should have received {started} message"

assert listener.errors == 0, "should not have errors"
assert listener.messages_started == listener.messages_completed, f"should have completed all started"


def test_assert_race_condition_in_disconnect_mid_ack():
found_race_condition = False
retries_until_race_condition = 0
while not found_race_condition:
try:
assert_race_condition_disconnect_mid_ack(conn())
found_race_condition = True
except AssertionError as e:
retries_until_race_condition += 1
continue

assert found_race_condition is True
# might occur at first try, might take 50 retries
logging.warning("Tries until race condition: %d", retries_until_race_condition)


def test_assert_fixed_race_condition_in_disconnect_mid_ack():
# same test case but asserts no error
# you can increase forever, it always passes
for n in range(100):
assert_no_race_condition_disconnect_mid_ack(conn(), wait=True)