Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions tools/testing/publish_amqp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""AMQP message publisher for triggering GeoZarr conversion workflows.

Publishes JSON payloads to RabbitMQ exchanges with support for
dynamic routing key templates based on payload fields.
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Any

import pika
from tenacity import retry, stop_after_attempt, wait_exponential

logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def load_payload(payload_file: Path) -> dict[str, Any]:
"""Load JSON payload from file."""
try:
data: dict[str, Any] = json.loads(payload_file.read_text())
return data
except FileNotFoundError:
logger.exception("Payload file not found", extra={"file": str(payload_file)})
sys.exit(1)
except json.JSONDecodeError:
logger.exception("Invalid JSON in payload file", extra={"file": str(payload_file)})
sys.exit(1)


def format_routing_key(template: str, payload: dict[str, Any]) -> str:
"""Format routing key template using payload fields.

Example: "eopf.item.found.{collection}" → "eopf.item.found.sentinel-2-l2a"
"""
try:
return template.format(**payload)
except KeyError:
logger.exception(
"Missing required field in payload for routing key template",
extra={"template": template, "available_fields": list(payload.keys())},
)
sys.exit(1)


@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
def publish_message(
host: str,
port: int,
user: str,
password: str,
exchange: str,
routing_key: str,
payload: dict[str, Any],
virtual_host: str = "/",
) -> None:
"""Publish message to RabbitMQ exchange with automatic retry."""
credentials = pika.PlainCredentials(user, password)
parameters = pika.ConnectionParameters(
host=host,
port=port,
virtual_host=virtual_host,
credentials=credentials,
)

logger.info("Connecting to amqp://%s@%s:%s%s", user, host, port, virtual_host)
connection = pika.BlockingConnection(parameters)
try:
channel = connection.channel()
channel.basic_publish(
exchange=exchange,
routing_key=routing_key,
body=json.dumps(payload),
properties=pika.BasicProperties(
content_type="application/json",
delivery_mode=2,
),
)
logger.info("Published to exchange='%s' routing_key='%s'", exchange, routing_key)
logger.debug("Payload: %s", json.dumps(payload, indent=2))
finally:
connection.close()


def main() -> None:
"""CLI entry point for AMQP message publisher."""
parser = argparse.ArgumentParser(
description="Publish JSON payload to RabbitMQ exchange for workflow triggers"
)
parser.add_argument("--host", required=True, help="RabbitMQ host")
parser.add_argument("--port", type=int, default=5672, help="RabbitMQ port")
parser.add_argument("--user", required=True, help="RabbitMQ username")
parser.add_argument("--password", required=True, help="RabbitMQ password")
parser.add_argument("--virtual-host", default="/", help="RabbitMQ virtual host")
parser.add_argument("--exchange", required=True, help="RabbitMQ exchange name")
parser.add_argument("--routing-key", help="Static routing key")
parser.add_argument(
"--routing-key-template",
help="Template with {field} placeholders (e.g., 'eopf.item.found.{collection}')",
)
parser.add_argument("--payload-file", type=Path, required=True, help="JSON payload file path")

args = parser.parse_args()

if not args.routing_key and not args.routing_key_template:
parser.error("Must provide either --routing-key or --routing-key-template")
if args.routing_key and args.routing_key_template:
parser.error("Cannot use both --routing-key and --routing-key-template")

payload = load_payload(args.payload_file)
routing_key = args.routing_key or format_routing_key(args.routing_key_template, payload)

try:
publish_message(
host=args.host,
port=args.port,
user=args.user,
password=args.password,
exchange=args.exchange,
routing_key=routing_key,
payload=payload,
virtual_host=args.virtual_host,
)
except Exception:
logger.exception(
"Failed to publish AMQP message",
extra={
"exchange": args.exchange,
"routing_key": routing_key,
"host": args.host,
},
)
sys.exit(1)


if __name__ == "__main__":
main()
131 changes: 131 additions & 0 deletions tools/testing/test_publish_amqp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Unit tests for publish_amqp.py script."""

from __future__ import annotations

import json
import sys
from pathlib import Path

import pika.exceptions
import pytest

sys.path.insert(0, str(Path(__file__).parent.parent.parent / "scripts"))
from publish_amqp import format_routing_key, load_payload


@pytest.fixture
def sample_payload() -> dict[str, str]:
"""Sample payload for tests."""
return {"collection": "sentinel-2-l2a", "item_id": "test-123"}


@pytest.fixture
def payload_file(tmp_path: Path, sample_payload: dict[str, str]) -> Path:
"""Create a temporary payload file."""
file = tmp_path / "payload.json"
file.write_text(json.dumps(sample_payload))
return file


class TestLoadPayload:
"""Tests for payload loading."""

def test_valid_payload(self, payload_file: Path, sample_payload: dict[str, str]) -> None:
"""Load valid JSON payload."""
assert load_payload(payload_file) == sample_payload

def test_missing_file(self, tmp_path: Path) -> None:
"""Handle missing file with exit code 1."""
with pytest.raises(SystemExit, match="1"):
load_payload(tmp_path / "missing.json")

def test_invalid_json(self, tmp_path: Path) -> None:
"""Handle invalid JSON with exit code 1."""
invalid = tmp_path / "invalid.json"
invalid.write_text("{not valid json")
with pytest.raises(SystemExit, match="1"):
load_payload(invalid)


class TestFormatRoutingKey:
"""Tests for routing key formatting."""

@pytest.mark.parametrize(
("template", "payload", "expected"),
[
(
"eopf.item.found.{collection}",
{"collection": "sentinel-2-l2a"},
"eopf.item.found.sentinel-2-l2a",
),
(
"{env}.{service}.{collection}",
{"env": "prod", "service": "ingest", "collection": "s1"},
"prod.ingest.s1",
),
("static.key", {"collection": "sentinel-2"}, "static.key"),
],
)
def test_format_templates(self, template: str, payload: dict[str, str], expected: str) -> None:
"""Format various routing key templates."""
assert format_routing_key(template, payload) == expected

def test_missing_field(self) -> None:
"""Handle missing field with exit code 1."""
with pytest.raises(SystemExit, match="1"):
format_routing_key("eopf.item.found.{collection}", {"item_id": "test"})


class TestPublishMessage:
"""Tests for message publishing (mocked)."""

def test_publish_success(self, mocker: pytest.MonkeyPatch) -> None:
"""Publish message successfully."""
from publish_amqp import publish_message

mock_conn = mocker.patch("publish_amqp.pika.BlockingConnection")
mock_channel = mocker.MagicMock()
mock_conn.return_value.channel.return_value = mock_channel

publish_message(
host="rabbitmq.test",
port=5672,
user="testuser",
password="testpass",
exchange="test_exchange",
routing_key="test.key",
payload={"test": "data"},
)

mock_conn.assert_called_once()
mock_channel.basic_publish.assert_called_once()
call = mock_channel.basic_publish.call_args.kwargs
assert call["exchange"] == "test_exchange"
assert call["routing_key"] == "test.key"
assert json.loads(call["body"]) == {"test": "data"}

def test_connection_retry(self, mocker: pytest.MonkeyPatch) -> None:
"""Verify tenacity retry on transient failures."""
from publish_amqp import publish_message

mock_conn = mocker.patch("publish_amqp.pika.BlockingConnection")
mock_channel = mocker.MagicMock()

# Fail twice, succeed on third attempt
mock_conn.side_effect = [
pika.exceptions.AMQPConnectionError("Transient error"),
pika.exceptions.AMQPConnectionError("Transient error"),
mocker.MagicMock(channel=mocker.MagicMock(return_value=mock_channel)),
]

publish_message(
host="rabbitmq.test",
port=5672,
user="testuser",
password="testpass",
exchange="test_exchange",
routing_key="test.key",
payload={"test": "data"},
)

assert mock_conn.call_count == 3
Loading