diff --git a/Cargo.toml b/Cargo.toml index 058288a..80f939c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,6 @@ base64 = "0.22.1" [dev.dependencies] tokio-tungstenite = { version = "0.24.0", features = ["native-tls"] } +tokio-test = "0.4" +tempfile = "3.0" +mockall = "0.12" diff --git a/aura/tests/test_pair.py b/aura/tests/test_pair.py index 544319c..5696c75 100644 --- a/aura/tests/test_pair.py +++ b/aura/tests/test_pair.py @@ -2,13 +2,13 @@ from PIL import Image import subprocess from aura.dataset import PairsGenerator -import time + @pytest.fixture def random_image(): """Generate a random image for testing.""" width, height = 100, 100 - image = Image.new("RGB", (width, height), "blue") + image = Image.new("RGB", (width, height), "blue") return image @@ -16,17 +16,25 @@ def random_image(): def ensure_ollama_running(): """Ensure Ollama server is running, skip test if it's not available.""" try: - subprocess.run(["ollama", "list"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + subprocess.run( + ["ollama", "list"], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) except FileNotFoundError: pytest.skip("Ollama CLI is not installed.") except subprocess.CalledProcessError: pytest.skip("Ollama is not running or not connected to the app.") - + + def test_generate_text_description(random_image, ensure_ollama_running): """ Test _generate_text_description using Ollama. Skips if Ollama is unavailable. """ - description = PairsGenerator._generate_text_description(random_image, model="ollama/bakllava") + description = PairsGenerator._generate_text_description( + random_image, model="ollama/bakllava" + ) assert isinstance(description, str) assert len(description) > 0 diff --git a/aura/tests/test_processing_pipeline.py b/aura/tests/test_processing_pipeline.py index f7cfa5a..eeb9025 100644 --- a/aura/tests/test_processing_pipeline.py +++ b/aura/tests/test_processing_pipeline.py @@ -1,46 +1,49 @@ -import pytest -import numpy as np -import cv2 import os -import logging from pathlib import Path import tempfile + +import cv2 +import numpy as np +import pytest import requests -from io import BytesIO -from aura.camera import ProcessingPipeline, FaceNotFoundException + +from aura.camera import FaceNotFoundException, ProcessingPipeline def download_image(url): response = requests.get(url) - return cv2.imdecode( - np.frombuffer(response.content, np.uint8), - cv2.IMREAD_COLOR - ) + return cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR) + @pytest.fixture def pipeline(): with tempfile.TemporaryDirectory() as tmp_dir: yield ProcessingPipeline(log_path=tmp_dir, verbose=2) + @pytest.fixture def image_with_face(): url = "https://www.yourtango.com/sites/default/files/image_blog/habits-of-truly-nice-people.png" return download_image(url) + @pytest.fixture def image_without_face(): url = "https://www.bsr.org/images/heroes/bsr-focus-nature-hero.jpg" return download_image(url) + def test_initialization(pipeline): assert pipeline.verbose == 2 assert os.path.exists(pipeline.log_path) assert pipeline.face_cascade is not None + def test_invalid_verbosity(): with pytest.raises(ValueError): ProcessingPipeline(verbose=3) + def test_face_detection_with_face(pipeline, image_with_face): bbox = pipeline.get_bounding_box(image_with_face) assert bbox is not None @@ -49,10 +52,12 @@ def test_face_detection_with_face(pipeline, image_with_face): assert all(isinstance(val, (int, np.int32, np.int64)) for val in [x, y, w, h]) assert w > 0 and h > 0 + def test_face_detection_without_face(pipeline, image_without_face): bbox = pipeline.get_bounding_box(image_without_face) assert bbox is None + def test_annotation(pipeline, image_with_face): annotated = pipeline.annotate_face(image_with_face) assert annotated.shape == image_with_face.shape @@ -60,13 +65,16 @@ def test_annotation(pipeline, image_with_face): log_files = list(Path(pipeline.current_log_dir).glob("bbox.jpg")) assert len(log_files) == 1 + def test_image_processing_with_face(pipeline, image_with_face): processed = pipeline.process_image(image_with_face) - + assert processed.shape == (3, 224, 224), f"Unexpected shape: {processed.shape}" assert processed.dtype == np.float32, f"Unexpected dtype: {processed.dtype}" - assert 0 <= processed.min() <= processed.max() <= 1, "Pixel values not normalized to [0, 1]" - + assert 0 <= processed.min() <= processed.max() <= 1, ( + "Pixel values not normalized to [0, 1]" + ) + if pipeline.verbose > 0: log_files = list(Path(pipeline.current_log_dir).glob("processed.jpg")) assert len(log_files) == 1, "Processed image log not found" @@ -76,6 +84,7 @@ def test_image_processing_without_face(pipeline, image_without_face): with pytest.raises(FaceNotFoundException): pipeline.process_image(image_without_face) + def test_grayscale_conversion(pipeline, image_with_face): gray = pipeline._convert_to_gray(image_with_face) assert len(gray.shape) == 2 diff --git a/aura/tests/test_provider.py b/aura/tests/test_provider.py index 176ad52..6a159af 100644 --- a/aura/tests/test_provider.py +++ b/aura/tests/test_provider.py @@ -1,100 +1,115 @@ -import pytest +from unittest.mock import patch + import numpy as np -import cv2 -from unittest.mock import patch, MagicMock +import pytest + from aura.dataset import DatasetProvider + @pytest.fixture def dataset_provider(): - with patch('kagglehub.dataset_download') as mock_download: + with patch("kagglehub.dataset_download") as mock_download: mock_download.return_value = "mock/path" - with patch.object(DatasetProvider, '_collect_files') as mock_collect: + with patch.object(DatasetProvider, "_collect_files") as mock_collect: mock_data = [(np.zeros((64, 64, 3)), 0, "happy") for _ in range(10)] mock_collect.return_value = mock_data - provider = DatasetProvider(target_size=(224, 224), split = 0.7) + provider = DatasetProvider(target_size=(224, 224), split=0.7) return provider - + + def test_collect_files_with_augmentation(dataset_provider): mock_files = [("mock/path", [], ["image1.jpg", "image2.jpg"])] with patch("os.walk", return_value=mock_files): - with patch.object(dataset_provider, '_get_emotion', return_value="happy"): - with patch('cv2.imread', return_value=np.zeros((64, 64, 3))): + with patch.object(dataset_provider, "_get_emotion", return_value="happy"): + with patch("cv2.imread", return_value=np.zeros((64, 64, 3))): dataset = dataset_provider._collect_files("mock/path", augment=True) - assert len(dataset) == 8 + assert len(dataset) == 8 assert all(len(item) == 3 for item in dataset) first_img, first_label, first_emotion = dataset[0] assert isinstance(first_img, np.ndarray) assert isinstance(first_label, int) assert isinstance(first_emotion, str) + def test_collect_files_without_augmentation(dataset_provider): mock_files = [("mock/path", [], ["image1.jpg", "image2.jpg"])] with patch("os.walk", return_value=mock_files): - with patch.object(dataset_provider, '_get_emotion', return_value="happy"): - with patch('cv2.imread', return_value=np.zeros((64, 64, 3))): + with patch.object(dataset_provider, "_get_emotion", return_value="happy"): + with patch("cv2.imread", return_value=np.zeros((64, 64, 3))): dataset = dataset_provider._collect_files("mock/path", augment=False) - assert len(dataset) == 2 + assert len(dataset) == 2 assert all(len(item) == 3 for item in dataset) first_img, first_label, first_emotion = dataset[0] assert isinstance(first_img, np.ndarray) assert isinstance(first_label, int) assert isinstance(first_emotion, str) + def test_init(dataset_provider): assert dataset_provider.target_size == (224, 224) assert len(dataset_provider.emotion_labels) == 8 assert len(dataset_provider.train) == 7 assert len(dataset_provider.test) == 3 + def test_sample_valid_index(dataset_provider): - img, label, emotion = dataset_provider.sample(0, source = "test") + img, label, emotion = dataset_provider.sample(0, source="test") assert isinstance(img, np.ndarray) assert img.shape == (64, 64, 3) assert isinstance(label, int) assert 0 <= label <= 7 assert emotion in dataset_provider.emotion_labels + def test_sample_invalid_index(dataset_provider): with pytest.raises(ValueError): - dataset_provider.sample(100, source = "train") + dataset_provider.sample(100, source="train") + def test_get_next_image_batch(dataset_provider): batch_size = 2 - batch_generator = dataset_provider.get_next_image_batch(batch_size, source = "train") + batch_generator = dataset_provider.get_next_image_batch(batch_size, source="train") first_batch = next(batch_generator) - + assert len(first_batch) == batch_size assert isinstance(first_batch[0][0], np.ndarray) assert isinstance(first_batch[0][1], int) + def test_get_next_image_batch_invalid_size(dataset_provider): with pytest.raises(ValueError): - next(dataset_provider.get_next_image_batch(100, source = "train")) + next(dataset_provider.get_next_image_batch(100, source="train")) + def test_resize_image(dataset_provider): test_image = np.zeros((100, 150, 3), dtype=np.uint8) resized = dataset_provider._resize_image(test_image) assert resized.shape == (224, 224, 3) + def test_random_rotation(dataset_provider): test_image = np.zeros((64, 64, 3), dtype=np.uint8) rotated = dataset_provider._random_rotation(test_image) assert rotated.shape == test_image.shape + def test_flip(dataset_provider): test_image = np.zeros((64, 64, 3), dtype=np.uint8) flipped = dataset_provider._flip(test_image) assert flipped.shape == test_image.shape + def test_random_brightness(dataset_provider): test_image = np.zeros((64, 64, 3), dtype=np.uint8) brightened = dataset_provider._random_brightness(test_image) assert brightened.shape == test_image.shape + def test_get_emotion_invalid_path(dataset_provider): with pytest.raises(ValueError): dataset_provider._get_emotion("invalid_path") + def test_set_black_background(dataset_provider): test_image = np.ones((64, 64, 3), dtype=np.uint8) * 255 processed = dataset_provider._set_black_background(test_image, threshold=20) diff --git a/aura/tests/test_queue.py b/aura/tests/test_queue.py index 3f0fdb0..7e1ccf5 100644 --- a/aura/tests/test_queue.py +++ b/aura/tests/test_queue.py @@ -2,28 +2,33 @@ import time from aura.queue import QueueManager, Priority, UserState, QueueError + @pytest.fixture def queue_manager(): return QueueManager(5, 100, 3) + def test_queue_initialization(queue_manager): assert queue_manager.max_session_time == 300 assert queue_manager.max_queue_size == 100 assert queue_manager.max_reconnect_attempts == 3 + def test_add_to_queue(queue_manager): position = queue_manager.add_to_queue("user1", Priority.normal()) assert position == 1 + def test_priority_ordering(queue_manager): queue_manager.add_to_queue("user1", Priority.normal()) queue_manager.add_to_queue("user2", Priority.high()) queue_manager.add_to_queue("user3", Priority.low()) - + queue_list = list(queue_manager.queue) - assert queue_list[0].id == "user2" # High priority first - assert queue_list[1].id == "user1" # Normal priority second - assert queue_list[2].id == "user3" # Low priority last + assert queue_list[0].id == "user2" + assert queue_list[1].id == "user1" + assert queue_list[2].id == "user3" + def test_duplicate_user(queue_manager): queue_manager.add_to_queue("user1", None) @@ -31,6 +36,7 @@ def test_duplicate_user(queue_manager): queue_manager.add_to_queue("user1", None) assert "already in queue" in str(exc_info.value).lower() + def test_queue_full(): queue = QueueManager(5, 2, 3) queue.add_to_queue("user1", None) @@ -39,31 +45,36 @@ def test_queue_full(): queue.add_to_queue("user3", None) assert "queue is full" in str(exc_info.value).lower() + def test_state_transitions(queue_manager): queue_manager.add_to_queue("user1", None) queue_manager.update_user_state("user1", UserState.CONNECTING) queue_manager.update_user_state("user1", UserState.CONNECTED) queue_manager.update_user_state("user1", UserState.DISCONNECTED) + def test_cleanup_timeouts(queue_manager): queue_manager.add_to_queue("user1", None) # Force timeout by waiting - time.sleep(31) + time.sleep(31) timed_out = queue_manager.cleanup_timeouts() assert len(timed_out) == 1 assert timed_out[0] == "user1" + def test_remove_from_queue(queue_manager): queue_manager.add_to_queue("user1", None) assert queue_manager.remove_from_queue("user1") is True assert queue_manager.remove_from_queue("nonexistent") is False + def test_invalid_state_transition(queue_manager): queue_manager.add_to_queue("user1", None) with pytest.raises(QueueError) as exc_info: queue_manager.update_user_state("user1", UserState.DISCONNECTED) assert "invalid state transition" in str(exc_info.value).lower() + def test_user_not_found(queue_manager): with pytest.raises(QueueError) as exc_info: queue_manager.update_user_state("nonexistent", UserState.CONNECTED) diff --git a/aura/tests/test_signaling.py b/aura/tests/test_signaling.py index 72d0d6e..b810cf8 100644 --- a/aura/tests/test_signaling.py +++ b/aura/tests/test_signaling.py @@ -1,47 +1,50 @@ -import pytest import asyncio -import websockets import json + +import pytest +import websockets + from aura.webrtc import SignalingServer -import sys -import time @pytest.fixture(scope="function") def unused_tcp_port(): """Get an unused TCP port.""" import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] + @pytest.fixture(scope="function") def signaling_server(unused_tcp_port): server = SignalingServer(port=unused_tcp_port) server.start() async def is_server_ready(): - uri = f'ws://localhost:{unused_tcp_port}/signaling' - for _ in range(10): + uri = f"ws://localhost:{unused_tcp_port}/signaling" + for _ in range(10): try: async with websockets.connect(uri): - return True + return True except (ConnectionRefusedError, OSError): - await asyncio.sleep(0.1) - return False + await asyncio.sleep(0.1) + return False if not asyncio.run(is_server_ready()): raise RuntimeError("Server did not start in time") yield server + @pytest.fixture async def websocket_clients(signaling_server, unused_tcp_port): """Create two WebSocket clients for testing""" client1 = None client2 = None try: - uri = f'ws://localhost:{unused_tcp_port}/signaling' + uri = f"ws://localhost:{unused_tcp_port}/signaling" client1 = await websockets.connect(uri) client2 = await websockets.connect(uri) yield client1, client2 @@ -51,6 +54,7 @@ async def websocket_clients(signaling_server, unused_tcp_port): if client2: await client2.close() + async def send_and_receive(sender, receiver, message): """Helper function to send and receive WebSocket messages""" await sender.send(json.dumps(message)) @@ -63,119 +67,123 @@ def test_server_initialization(): server = SignalingServer() assert isinstance(server, SignalingServer) + @pytest.mark.asyncio async def test_client_connection(signaling_server, unused_tcp_port): """Test client connection and count""" initial_count = signaling_server.get_client_count() assert initial_count == 0 - uri = f'ws://localhost:{unused_tcp_port}/signaling' - async with websockets.connect(uri) as websocket: + uri = f"ws://localhost:{unused_tcp_port}/signaling" + async with websockets.connect(uri): await asyncio.sleep(0.1) assert signaling_server.get_client_count() == 1 + @pytest.mark.asyncio async def test_signaling_message_exchange(signaling_server, unused_tcp_port): """Test sending and receiving signaling messages between clients""" - uri = f'ws://localhost:{unused_tcp_port}/signaling' - + uri = f"ws://localhost:{unused_tcp_port}/signaling" + async with websockets.connect(uri) as client1, websockets.connect(uri) as client2: await asyncio.sleep(0.1) - - offer_message = { - "type": "offer", - "sdp": "test_sdp_offer" - } - + + offer_message = {"type": "offer", "sdp": "test_sdp_offer"} + await client1.send(json.dumps(offer_message)) response = await client2.recv() response_data = json.loads(response) - + assert response_data["type"] == "offer" assert response_data["sdp"] == "test_sdp_offer" + @pytest.mark.asyncio async def test_broadcast_message(signaling_server, unused_tcp_port): """Test broadcasting messages to all clients""" - uri = f'ws://localhost:{unused_tcp_port}/signaling' - + uri = f"ws://localhost:{unused_tcp_port}/signaling" + async with websockets.connect(uri) as client1, websockets.connect(uri) as client2: await asyncio.sleep(0.1) - + test_message = "broadcast test message" signaling_server.broadcast_message(test_message) - + msg1 = await client1.recv() msg2 = await client2.recv() - + assert msg1 == test_message assert msg2 == test_message + @pytest.mark.asyncio async def test_disconnect_client(signaling_server, unused_tcp_port): """Test disconnecting a client""" - uri = f'ws://localhost:{unused_tcp_port}/signaling' - - async with websockets.connect(uri) as client: + uri = f"ws://localhost:{unused_tcp_port}/signaling" + + async with websockets.connect(uri): await asyncio.sleep(0.1) - + connected_clients = signaling_server.get_connected_clients() assert len(connected_clients) == 1 client_id = connected_clients[0] - - assert signaling_server.disconnect_client(client_id) == True + + assert signaling_server.disconnect_client(client_id) await asyncio.sleep(0.1) - + assert signaling_server.get_client_count() == 0 - assert signaling_server.disconnect_client(client_id) == False + assert not signaling_server.disconnect_client(client_id) + @pytest.mark.asyncio async def test_send_to_client(signaling_server, unused_tcp_port): """Test sending messages to specific clients""" - uri = f'ws://localhost:{unused_tcp_port}/signaling' - - async with websockets.connect(uri) as client1, websockets.connect(uri) as client2: + uri = f"ws://localhost:{unused_tcp_port}/signaling" + + async with websockets.connect(uri) as client1, websockets.connect(uri): await asyncio.sleep(0.1) - + clients = signaling_server.get_connected_clients() client_id = clients[0] - + test_message = "test message" - assert signaling_server.send_to_client(client_id, test_message) == True - - assert signaling_server.send_to_client("invalid_id", test_message) == False - + assert signaling_server.send_to_client(client_id, test_message) + + assert not signaling_server.send_to_client("invalid_id", test_message) + received_msg = await client1.recv() assert received_msg == test_message + @pytest.mark.asyncio async def test_server_status(signaling_server, unused_tcp_port): """Test server status information""" - uri = f'ws://localhost:{unused_tcp_port}/signaling' - + uri = f"ws://localhost:{unused_tcp_port}/signaling" + status = signaling_server.get_server_status() assert isinstance(status, dict) assert "ip" in status assert "port" in status assert "connected_clients" in status assert status["port"] == str(unused_tcp_port) - - async with websockets.connect(uri) as client: + + async with websockets.connect(uri): await asyncio.sleep(0.1) updated_status = signaling_server.get_server_status() assert updated_status["connected_clients"] == "1" + @pytest.mark.asyncio async def test_client_capacity(signaling_server, unused_tcp_port): """Test server capacity handling""" - uri = f'ws://localhost:{unused_tcp_port}/signaling' - - assert signaling_server.is_at_capacity() == False - - async with websockets.connect(uri) as client1: + uri = f"ws://localhost:{unused_tcp_port}/signaling" + + assert not signaling_server.is_at_capacity() + + async with websockets.connect(uri): await asyncio.sleep(0.1) - assert signaling_server.is_at_capacity() == False - - async with websockets.connect(uri) as client2: + assert not signaling_server.is_at_capacity() + + async with websockets.connect(uri): await asyncio.sleep(0.1) - assert signaling_server.is_at_capacity() == True + assert signaling_server.is_at_capacity() diff --git a/aura/tests/test_streamer.py b/aura/tests/test_streamer.py index f916bf2..41321ef 100644 --- a/aura/tests/test_streamer.py +++ b/aura/tests/test_streamer.py @@ -4,76 +4,401 @@ import pytest import asyncio import websockets +import time +import threading from contextlib import closing -from aura.webrtc import VideoStreamer +from aura.webrtc import ( + VideoStreamer, + DEFAULT_WS_IP, + DEFAULT_TEST_IVF_DIR, + TEST_IP_ADDRESSES, +) + @pytest.fixture(autouse=True) def setup_video_dir(): - video_dir = "/tmp/video" + video_dir = DEFAULT_TEST_IVF_DIR os.makedirs(video_dir, exist_ok=True) yield shutil.rmtree(video_dir, ignore_errors=True) + @pytest.fixture def free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] + @pytest.fixture def streamer(free_port): streamer = VideoStreamer( - ws_ip="127.0.0.1", - ws_port=free_port, - ivf_dir="/tmp/video" + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=DEFAULT_TEST_IVF_DIR ) yield streamer try: streamer.close_connection() - except: + except RuntimeError: pass + @pytest.fixture async def mock_websocket_server(free_port): async def handler(websocket): async for message in websocket: await websocket.send(message) - + server = await websockets.serve( - handler, - "127.0.0.1", - free_port, - reuse_address=True, - reuse_port=True + handler, DEFAULT_WS_IP, free_port, reuse_address=True, reuse_port=True ) yield server server.close() await server.wait_closed() + @pytest.fixture def sample_ivf_file(tmp_path): ivf_path = tmp_path / "test.ivf" with open(ivf_path, "wb") as f: f.write(b"DKIF\x00\x00\x00\x00") - f.write(b"\x00" * 24) # Dummy header data + f.write(b"\x00" * 24) + return ivf_path + + +@pytest.fixture +def valid_ivf_file(tmp_path): + ivf_path = tmp_path / "valid.ivf" + with open(ivf_path, "wb") as f: + f.write(b"DKIF") + f.write((0).to_bytes(2, byteorder="little")) + f.write((32).to_bytes(2, byteorder="little")) + f.write(b"VP80") + f.write((640).to_bytes(2, byteorder="little")) + f.write((480).to_bytes(2, byteorder="little")) + f.write((1).to_bytes(4, byteorder="little")) + f.write((30).to_bytes(4, byteorder="little")) + f.write((100).to_bytes(4, byteorder="little")) + f.write((0).to_bytes(4, byteorder="little")) return ivf_path -def test_take_screenshot_no_connection(streamer): - with pytest.raises(RuntimeError, match="No active peer connection"): - streamer.take_screenshot() + +@pytest.fixture +def mock_signaling_server(free_port): + received_messages = [] + + async def handler(websocket, path): + try: + async for message in websocket: + received_messages.append(message) + await websocket.send(message) + except websockets.exceptions.ConnectionClosed: + pass + + server = None + + async def start_server(): + nonlocal server + server = await websockets.serve( + handler, DEFAULT_WS_IP, free_port, reuse_address=True + ) + return server + + def get_messages(): + return received_messages.copy() + + return start_server, get_messages + def test_get_stats_no_connection(streamer): with pytest.raises(RuntimeError, match="No active peer connection"): streamer.get_stats() + def test_close_connection_no_connection(streamer): with pytest.raises(RuntimeError, match="No active peer connection"): streamer.close_connection() + def test_video_directory_monitoring(streamer, sample_ivf_file): streamer.start_streaming() shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") import time + time.sleep(1) os.remove("/tmp/video/test.ivf") + + +class TestVideoStreamerInitialization: + def test_streamer_initialization(self, free_port): + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=DEFAULT_TEST_IVF_DIR + ) + assert streamer is not None + + def test_streamer_initialization_with_different_ips(self, free_port): + for ip in TEST_IP_ADDRESSES: + streamer = VideoStreamer( + ws_ip=ip, ws_port=free_port, ivf_dir=DEFAULT_TEST_IVF_DIR + ) + assert streamer is not None + + def test_streamer_initialization_with_different_ports(self): + for port in [8080, 9000, 3000]: + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=port, ivf_dir=DEFAULT_TEST_IVF_DIR + ) + assert streamer is not None + + def test_streamer_initialization_with_different_directories( + self, free_port, tmp_path + ): + test_dirs = [DEFAULT_TEST_IVF_DIR, str(tmp_path / "custom"), "/tmp/test_stream"] + for directory in test_dirs: + os.makedirs(directory, exist_ok=True) + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=directory + ) + assert streamer is not None + + +class TestVideoStreamerErrorHandling: + def test_get_stats_no_connection_detailed(self, streamer): + with pytest.raises(RuntimeError) as exc_info: + streamer.get_stats() + assert "No active peer connection" in str(exc_info.value) + + def test_close_connection_no_connection_detailed(self, streamer): + with pytest.raises(RuntimeError) as exc_info: + streamer.close_connection() + assert "No active peer connection" in str(exc_info.value) + + def test_multiple_close_connection_calls(self, streamer): + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.close_connection() + + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.close_connection() + + def test_get_stats_after_close_connection(self, streamer): + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.close_connection() + + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.get_stats() + + +class TestVideoStreamerStreaming: + def test_start_streaming_does_not_raise(self, streamer): + streamer.start_streaming() + + def test_start_streaming_multiple_times(self, streamer): + streamer.start_streaming() + streamer.start_streaming() + + def test_start_streaming_with_invalid_directory(self, free_port, tmp_path): + invalid_dir = str(tmp_path / "nonexistent") + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=invalid_dir + ) + + streamer.start_streaming() + + +class TestFileMonitoring: + def test_video_directory_monitoring_with_multiple_files( + self, streamer, sample_ivf_file + ): + streamer.start_streaming() + + for i in range(3): + shutil.copy(sample_ivf_file, f"/tmp/video/test_{i}.ivf") + time.sleep(0.1) + + time.sleep(1) + + for i in range(3): + os.remove(f"/tmp/video/test_{i}.ivf") + + def test_video_directory_monitoring_with_non_ivf_files(self, streamer, tmp_path): + streamer.start_streaming() + + non_ivf_files = ["test.txt", "test.mp4", "test.avi", "test.mov"] + for filename in non_ivf_files: + with open(f"/tmp/video/{filename}", "w") as f: + f.write("test content") + time.sleep(0.1) + + time.sleep(1) + + for filename in non_ivf_files: + os.remove(f"/tmp/video/{filename}") + + def test_video_directory_monitoring_with_valid_ivf(self, streamer, valid_ivf_file): + streamer.start_streaming() + shutil.copy(valid_ivf_file, "/tmp/video/valid_test.ivf") + time.sleep(1) + os.remove("/tmp/video/valid_test.ivf") + + def test_video_directory_monitoring_with_corrupted_ivf(self, streamer, tmp_path): + streamer.start_streaming() + + corrupted_ivf = tmp_path / "corrupted.ivf" + with open(corrupted_ivf, "wb") as f: + f.write(b"INVALID_HEADER") + + shutil.copy(corrupted_ivf, "/tmp/video/corrupted.ivf") + time.sleep(1) + os.remove("/tmp/video/corrupted.ivf") + + def test_video_directory_monitoring_file_removal(self, streamer, sample_ivf_file): + streamer.start_streaming() + + shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") + time.sleep(0.5) + + os.remove("/tmp/video/test.ivf") + time.sleep(0.5) + + shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") + time.sleep(0.5) + os.remove("/tmp/video/test.ivf") + + +class TestWebSocketIntegration: + @pytest.mark.asyncio + async def test_streamer_with_mock_websocket_server( + self, free_port, sample_ivf_file + ): + async def handler(websocket, path): + async for message in websocket: + await websocket.send(message) + + server = await websockets.serve( + handler, DEFAULT_WS_IP, free_port, reuse_address=True + ) + + try: + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=DEFAULT_TEST_IVF_DIR + ) + streamer.start_streaming() + + await asyncio.sleep(1) + + shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") + await asyncio.sleep(1) + + os.remove("/tmp/video/test.ivf") + + finally: + server.close() + await server.wait_closed() + + @pytest.mark.asyncio + async def test_streamer_connection_failure(self, free_port, sample_ivf_file): + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=DEFAULT_TEST_IVF_DIR + ) + streamer.start_streaming() + + await asyncio.sleep(1) + + shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") + await asyncio.sleep(1) + os.remove("/tmp/video/test.ivf") + + +class TestConcurrentOperations: + def test_concurrent_file_operations(self, streamer, sample_ivf_file): + streamer.start_streaming() + + def add_file(file_num): + shutil.copy(sample_ivf_file, f"/tmp/video/concurrent_{file_num}.ivf") + time.sleep(0.1) + os.remove(f"/tmp/video/concurrent_{file_num}.ivf") + + threads = [] + for i in range(5): + thread = threading.Thread(target=add_file, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + time.sleep(1) + + def test_start_streaming_while_processing_files(self, streamer, sample_ivf_file): + streamer.start_streaming() + + shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") + time.sleep(0.5) + + streamer.start_streaming() + + shutil.copy(sample_ivf_file, "/tmp/video/test2.ivf") + time.sleep(0.5) + + os.remove("/tmp/video/test.ivf") + os.remove("/tmp/video/test2.ivf") + + +class TestEdgeCases: + def test_empty_directory_monitoring(self, streamer): + streamer.start_streaming() + time.sleep(1) + + def test_very_large_ivf_file(self, streamer, tmp_path): + large_ivf = tmp_path / "large.ivf" + with open(large_ivf, "wb") as f: + f.write(b"DKIF\x00\x00\x00\x00") + f.write(b"\x00" * 24) + f.write(b"\x00" * 1024 * 1024) + + streamer.start_streaming() + shutil.copy(large_ivf, "/tmp/video/large.ivf") + time.sleep(1) + os.remove("/tmp/video/large.ivf") + + def test_rapid_file_changes(self, streamer, sample_ivf_file): + streamer.start_streaming() + + for i in range(10): + shutil.copy(sample_ivf_file, f"/tmp/video/rapid_{i}.ivf") + time.sleep(0.05) + os.remove(f"/tmp/video/rapid_{i}.ivf") + time.sleep(0.05) + + def test_streamer_with_special_characters_in_path(self, free_port, tmp_path): + special_dir = str(tmp_path / "test with spaces & symbols!") + os.makedirs(special_dir, exist_ok=True) + + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=special_dir + ) + streamer.start_streaming() + time.sleep(1) + + +class TestCleanupAndTeardown: + def test_streamer_cleanup_in_fixture(self, free_port): + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port, ivf_dir=DEFAULT_TEST_IVF_DIR + ) + streamer.start_streaming() + + def test_multiple_streamers_cleanup(self, free_port): + streamers = [] + for i in range(3): + streamer = VideoStreamer( + ws_ip=DEFAULT_WS_IP, ws_port=free_port + i, ivf_dir=f"/tmp/video_{i}" + ) + streamer.start_streaming() + streamers.append(streamer) + + for streamer in streamers: + try: + streamer.close_connection() + except RuntimeError: + pass diff --git a/aura/webrtc/__init__.py b/aura/webrtc/__init__.py index c9d97a5..2167173 100644 --- a/aura/webrtc/__init__.py +++ b/aura/webrtc/__init__.py @@ -2,4 +2,38 @@ from aura import VideoStreamer __version__ = "0.1.0" -__all__ = ["SignalingServer", "VideoStreamer"] + +# WebSocket and server configuration +DEFAULT_WS_IP = "127.0.0.1" +DEFAULT_WS_PORT = 3030 +DEFAULT_SIGNALING_PORT = 3030 + +# Directory paths +DEFAULT_IVF_DIR = "./ivf_files" +DEFAULT_LOGS_DIR = "../logs" +DEFAULT_TEST_IVF_DIR = "/tmp/video" + +# Processing configuration +DEFAULT_VERBOSE_LEVEL = 2 + +# Timing configuration +CAPTURE_DELAY_SECONDS = 2 +SERVER_SLEEP_SECONDS = 10 + +# Test configuration +TEST_IP_ADDRESSES = ["127.0.0.1", "localhost", "0.0.0.0"] + +__all__ = [ + "SignalingServer", + "VideoStreamer", + "DEFAULT_WS_IP", + "DEFAULT_WS_PORT", + "DEFAULT_SIGNALING_PORT", + "DEFAULT_IVF_DIR", + "DEFAULT_LOGS_DIR", + "DEFAULT_TEST_IVF_DIR", + "DEFAULT_VERBOSE_LEVEL", + "CAPTURE_DELAY_SECONDS", + "SERVER_SLEEP_SECONDS", + "TEST_IP_ADDRESSES", +] diff --git a/aura/webrtc/signaling.py b/aura/webrtc/signaling.py index 920f451..ba013f8 100644 --- a/aura/webrtc/signaling.py +++ b/aura/webrtc/signaling.py @@ -1,14 +1,19 @@ import socket -from aura import SignalingServer, VideoStreamer +from aura import SignalingServer import sys -import socket import signal import time from aura.camera import ProcessingPipeline, FaceNotFoundException import os import numpy as np import cv2 -from datetime import datetime +from aura.webrtc import ( + DEFAULT_SIGNALING_PORT, + DEFAULT_LOGS_DIR, + DEFAULT_VERBOSE_LEVEL, + CAPTURE_DELAY_SECONDS, + SERVER_SLEEP_SECONDS, +) def get_free_port(): @@ -24,16 +29,14 @@ def signal_handler(sig, frame): sys.exit(0) -def capture_images(server, output_dir="../logs", verbose=2): +def capture_images(server, output_dir=DEFAULT_LOGS_DIR, verbose=DEFAULT_VERBOSE_LEVEL): """Capture, process, and save images with face detection""" os.makedirs(output_dir, exist_ok=True) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - pipeline = ProcessingPipeline(log_path=output_dir, verbose=verbose) server.capture() - time.sleep(2) + time.sleep(CAPTURE_DELAY_SECONDS) image_bytes = server.get_capture() if image_bytes: @@ -41,13 +44,13 @@ def capture_images(server, output_dir="../logs", verbose=2): image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) try: - annotated_image = pipeline.annotate_face(image) - processed_face = pipeline.process_image(image) + pipeline.annotate_face(image) + pipeline.process_image(image) # TODO: Send processed face to the model except FaceNotFoundException: - print(f"No face detected in captured image.") + print("No face detected in captured image.") except Exception as e: print(f"Error processing image: {str(e)}") else: @@ -57,7 +60,7 @@ def capture_images(server, output_dir="../logs", verbose=2): def main(): signal.signal(signal.SIGINT, signal_handler) - port = 3030 + port = DEFAULT_SIGNALING_PORT server = SignalingServer(port=port) server.start() @@ -67,9 +70,9 @@ def main(): try: while True: - time.sleep(10) + time.sleep(SERVER_SLEEP_SECONDS) capture_images(server) - time.sleep(10) + time.sleep(SERVER_SLEEP_SECONDS) except KeyboardInterrupt: print("\nShutting down signaling server...") diff --git a/aura/webrtc/streamer.py b/aura/webrtc/streamer.py index 5edb703..12c3930 100644 --- a/aura/webrtc/streamer.py +++ b/aura/webrtc/streamer.py @@ -1,8 +1,9 @@ -import socket -import sys import signal +import sys import time + from aura import VideoStreamer +from aura.webrtc import DEFAULT_WS_IP, DEFAULT_WS_PORT, DEFAULT_IVF_DIR def signal_handler(sig, frame): @@ -14,9 +15,9 @@ def signal_handler(sig, frame): def main(): signal.signal(signal.SIGINT, signal_handler) - ws_ip = "127.0.0.1" # WebSocket IP address - ws_port = 3030 # WebSocket port - ivf_dir = "./ivf_files" # Match RustWebRTC directory structure + ws_ip = DEFAULT_WS_IP + ws_port = DEFAULT_WS_PORT + ivf_dir = DEFAULT_IVF_DIR streamer = VideoStreamer(ws_ip, ws_port, ivf_dir) streamer.start_streaming() diff --git a/src/lib.rs b/src/lib.rs index abe84e6..3d16f73 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,14 @@ +//! # Aura WebRTC Library +//! +//! A Rust library providing WebRTC signaling server and video streaming capabilities +//! for real-time communication applications. +//! +//! ## Features +//! +//! - **SignalingServer**: WebSocket-based signaling server for WebRTC peer connections +//! - **VideoStreamer**: Real-time video streaming with IVF file support +//! - **SignalingMessage**: Comprehensive message types for WebRTC signaling + use pyo3::prelude::*; mod server; @@ -6,41 +17,20 @@ pub use signaling_types::SignalingMessage; mod streamer; +#[cfg(test)] +mod tests; + use server::SignalingServer; use streamer::VideoStreamer; +/// Python module initialization for the Aura WebRTC library. +/// +/// This function registers the main classes with the Python interpreter: +/// - `SignalingServer`: For managing WebRTC signaling connections +/// - `VideoStreamer`: For streaming video content via WebRTC #[pymodule] fn aura(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_server_creation() { - let server = SignalingServer::new(Some(3031), Some("127.0.0.1".to_string())); - assert_eq!(server.port, 3031); - assert_eq!(server.ip, "127.0.0.1"); - } - - #[test] - fn test_server_default_values() { - let server = SignalingServer::new(None, None); - assert_eq!(server.port, 3030); - assert_eq!(server.ip, "127.0.0.1"); - } - - #[test] - fn test_message_serialization() { - let offer = SignalingMessage::Offer { - sdp: "test_sdp".to_string(), - }; - let serialized = serde_json::to_string(&offer).unwrap(); - assert!(serialized.contains("offer")); - assert!(serialized.contains("test_sdp")); - } -} diff --git a/src/server.rs b/src/server.rs index 044b494..d7e8723 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,8 @@ +//! # Signaling Server Module +//! +//! This module provides a WebSocket-based signaling server for WebRTC peer connections. +//! It handles client connections, message routing, and image capture functionality. + use crate::SignalingMessage; use anyhow::Result; @@ -17,18 +22,33 @@ use uuid::Uuid; use warp::ws::{Message, WebSocket}; use warp::Filter; +/// Type alias for managing connected WebSocket peers pub type Peers = Arc>>>>>; +/// WebSocket-based signaling server for WebRTC peer connections. +/// +/// The SignalingServer manages client connections, handles WebRTC signaling messages, +/// and provides image capture functionality for connected clients. #[pyclass] pub struct SignalingServer { + /// Map of connected client IDs to their WebSocket senders pub peers: Peers, + /// Server port number pub port: u16, + /// Server IP address pub ip: String, + /// Last captured image data from clients pub last_captured_image: Arc>>>, } #[pymethods] impl SignalingServer { + /// Create a new SignalingServer instance. + /// + /// # Arguments + /// + /// * `port` - Optional port number (defaults to 3030) + /// * `ip` - Optional IP address (defaults to "127.0.0.1") #[new] #[pyo3(signature = (port=None, ip=None))] #[pyo3( @@ -43,6 +63,9 @@ impl SignalingServer { } } + /// Start the signaling server. + /// + /// This method starts the WebSocket server and begins accepting client connections. #[pyo3(text_signature = "($self)")] pub fn start(&self, _py: Python<'_>) -> PyResult<()> { let peers = self.peers.clone(); @@ -70,6 +93,11 @@ impl SignalingServer { Ok(()) } + /// Trigger image capture from connected clients. + /// + /// # Arguments + /// + /// * `client_id` - Optional specific client ID to capture from. If None, captures from all clients. #[pyo3(signature = (client_id=None))] #[pyo3(text_signature = "(self, client_id: Optional[str] = None) -> None")] pub fn capture(&self, client_id: Option) -> PyResult<()> { @@ -101,6 +129,16 @@ impl SignalingServer { Ok(()) } + /// Send a message to a specific connected client. + /// + /// # Arguments + /// + /// * `client_id` - ID of the target client + /// * `message` - Message content to send + /// + /// # Returns + /// + /// True if the message was sent successfully, False otherwise. #[pyo3(text_signature = "(self, client_id: str, message: str) -> bool")] pub fn send_to_client(&self, client_id: String, message: String) -> PyResult { let peers = self.peers.clone(); @@ -118,6 +156,15 @@ impl SignalingServer { Ok(result) } + /// Check if a specific client is currently connected. + /// + /// # Arguments + /// + /// * `client_id` - ID of the client to check + /// + /// # Returns + /// + /// True if the client is connected, False otherwise. #[pyo3(text_signature = "(self, client_id: str) -> bool")] pub fn is_client_connected(&self, client_id: String) -> PyResult { let peers = self.peers.clone(); @@ -128,6 +175,11 @@ impl SignalingServer { Ok(is_connected) } + /// Get a list of all currently connected client IDs. + /// + /// # Returns + /// + /// A vector of client ID strings. #[pyo3(text_signature = "(self) -> List[str]")] pub fn get_connected_clients(&self) -> PyResult> { let peers = self.peers.clone(); @@ -138,6 +190,15 @@ impl SignalingServer { Ok(clients) } + /// Disconnect a specific client from the server. + /// + /// # Arguments + /// + /// * `client_id` - ID of the client to disconnect + /// + /// # Returns + /// + /// True if the client was disconnected, False if the client was not found. #[pyo3(text_signature = "(self, client_id: str) -> bool")] pub fn disconnect_client(&self, client_id: String) -> PyResult { let peers = self.peers.clone(); @@ -149,6 +210,11 @@ impl SignalingServer { Ok(was_removed) } + /// Check if the server is at capacity (2 or more clients). + /// + /// # Returns + /// + /// True if the server has reached its capacity limit. #[pyo3(text_signature = "(self) -> bool")] pub fn is_at_capacity(&self) -> PyResult { Python::with_gil(|py| { @@ -157,6 +223,11 @@ impl SignalingServer { }) } + /// Get the current server status information. + /// + /// # Returns + /// + /// A HashMap containing server status details including IP, port, and client count. #[pyo3(text_signature = "(self) -> int")] pub fn get_server_status(&self) -> PyResult> { let peers = self.peers.clone(); @@ -172,6 +243,11 @@ impl SignalingServer { Ok(status) } + /// Broadcast a message to all connected clients. + /// + /// # Arguments + /// + /// * `message` - Message content to broadcast #[pyo3(text_signature = "(self, message: str) -> None")] pub fn broadcast_message(&self, message: String, _py: Python<'_>) -> PyResult<()> { let peers = self.peers.clone(); @@ -190,6 +266,11 @@ impl SignalingServer { Ok(()) } + /// Get the current number of connected clients. + /// + /// # Returns + /// + /// The number of currently connected clients. #[pyo3(text_signature = "($self)")] pub fn get_client_count(&self, _py: Python<'_>) -> PyResult { let peers = self.peers.clone(); @@ -200,6 +281,11 @@ impl SignalingServer { Ok(count) } + /// Get the last captured image data from clients. + /// + /// # Returns + /// + /// The last captured image as bytes, or None if no image has been captured. #[pyo3(text_signature = "(self) -> Optional[bytes]")] pub fn get_capture<'py>(&self, py: Python<'py>) -> PyResult>> { let last_image = self.last_captured_image.clone(); @@ -221,12 +307,26 @@ impl SignalingServer { } } +/// Warp filter for injecting the peers map into request handlers. +/// +/// This function creates a Warp filter that provides access to the shared +/// peers map for WebSocket connection handling. fn with_peers( peers: Peers, ) -> impl Filter + Clone { warp::any().map(move || peers.clone()) } +/// Handle a new WebSocket connection from a client. +/// +/// This function manages the lifecycle of a client connection, including +/// message handling, image capture processing, and cleanup on disconnect. +/// +/// # Arguments +/// +/// * `ws` - The WebSocket connection +/// * `peers` - Shared map of connected peers +/// * `last_captured_image` - Shared storage for the last captured image async fn handle_connection( ws: WebSocket, peers: Peers, @@ -246,11 +346,9 @@ async fn handle_connection( if let Ok(text) = msg.to_str() { println!("Received message from {}: {}", client_id, text); - // Attempt to parse the message let signaling_message: Result = serde_json::from_str(text); match signaling_message { Ok(SignalingMessage::Image { data }) => { - // Handle image message println!("Handling image message from client {}", client_id); handle_image_message(data.clone()).await; @@ -264,7 +362,6 @@ async fn handle_connection( } } Ok(message) => { - // Handle other signaling messages println!("Parsed signaling message: {:?}", message); forward_message(&client_id, &message, &peers).await; } @@ -288,6 +385,13 @@ async fn handle_connection( println!("Client {} disconnected", client_id); } +/// Process and save an image message from a client. +/// +/// This function decodes base64 image data and saves it to a local file. +/// +/// # Arguments +/// +/// * `data` - Base64-encoded image data with data URL prefix async fn handle_image_message(data: String) { println!("Received image data of length: {}", data.len()); @@ -313,6 +417,15 @@ async fn handle_image_message(data: String) { } } +/// Send an image capture trigger message to a specific client. +/// +/// # Arguments +/// +/// * `sender` - WebSocket sender for the target client +/// +/// # Returns +/// +/// Result indicating success or failure of the operation. async fn trigger_image_capture( sender: Arc>>, ) -> Result<(), Box> { @@ -323,6 +436,13 @@ async fn trigger_image_capture( Ok(()) } +/// Forward a signaling message to all connected clients except the sender. +/// +/// # Arguments +/// +/// * `sender_id` - ID of the client who sent the message +/// * `message` - The signaling message to forward +/// * `peers` - Map of all connected peers async fn forward_message(sender_id: &str, message: &SignalingMessage, peers: &Peers) { let serialized_message = match serde_json::to_string(message) { Ok(json) => json, @@ -332,13 +452,13 @@ async fn forward_message(sender_id: &str, message: &SignalingMessage, peers: &Pe } }; - let peers = peers.lock().await; // Await the async Mutex lock + let peers = peers.lock().await; for (client_id, client) in peers.iter() { if client_id != sender_id { - let mut client = client.lock().await; // Await the async Mutex lock + let mut client = client.lock().await; if let Err(e) = client.send(Message::text(serialized_message.clone())).await { eprintln!("Error sending message to {}: {}", client_id, e); } } } -} \ No newline at end of file +} diff --git a/src/signaling_types.rs b/src/signaling_types.rs index fd6a4f2..df94f22 100644 --- a/src/signaling_types.rs +++ b/src/signaling_types.rs @@ -1,21 +1,57 @@ +//! # Signaling Message Types +//! +//! This module defines the message types used for WebRTC signaling communication +//! between clients and the signaling server. + use serde::{Deserialize, Serialize}; +/// WebRTC signaling messages for peer-to-peer communication setup. +/// +/// This enum represents all possible message types that can be exchanged +/// during WebRTC connection establishment and maintenance. #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "lowercase")] pub enum SignalingMessage { + /// WebRTC offer message containing SDP (Session Description Protocol) data. + /// + /// Sent by the initiating peer to propose connection parameters. Offer { + /// Session Description Protocol data describing the offer sdp: String, }, + + /// WebRTC answer message containing SDP data. + /// + /// Sent by the receiving peer in response to an offer. Answer { + /// Session Description Protocol data describing the answer sdp: String, }, + + /// ICE candidate information for NAT traversal. + /// + /// Contains network connectivity information to help establish + /// the optimal connection path between peers. Candidate { + /// ICE candidate string candidate: String, + /// SDP media identification sdp_mid: Option, + /// SDP media line index sdp_mline_index: Option, }, + + /// Image data message for screen capture or image sharing. + /// + /// Contains base64-encoded image data that can be shared between clients. Image { + /// Base64-encoded image data (typically with data URL prefix) data: String, }, + + /// Trigger message to request image capture from a client. + /// + /// Sent by the server to request that a client capture and send + /// an image of their current screen or application state. TriggerImageCapture, } diff --git a/src/streamer.rs b/src/streamer.rs index 60d3c7b..8723341 100644 --- a/src/streamer.rs +++ b/src/streamer.rs @@ -1,3 +1,8 @@ +//! # Video Streamer Module +//! +//! This module provides WebRTC-based video streaming capabilities with support for +//! IVF (Indeo Video Format) files and real-time file watching. + use crate::SignalingMessage; use anyhow::Result; @@ -21,21 +26,37 @@ use webrtc::{ media::{io::ivf_reader::IVFReader, Sample}, peer_connection::{ configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState, - sdp::session_description::RTCSessionDescription, + sdp::session_description::RTCSessionDescription, RTCPeerConnection, }, rtp_transceiver::rtp_codec::RTCRtpCodecCapability, track::track_local::{track_local_static_sample::TrackLocalStaticSample, TrackLocal}, }; +/// WebRTC video streamer for real-time video transmission. +/// +/// The VideoStreamer handles WebRTC peer connections and streams video content +/// from IVF files to connected clients through a signaling server. #[pyclass] pub struct VideoStreamer { + /// WebSocket server IP address for signaling ws_ip: String, + /// WebSocket server port for signaling ws_port: u16, + /// Directory containing IVF video files to stream ivf_dir: String, + /// Shared reference to the active WebRTC peer connection + peer_connection: Arc>>>, } #[pymethods] impl VideoStreamer { + /// Create a new VideoStreamer instance. + /// + /// # Arguments + /// + /// * `ws_ip` - IP address of the WebSocket signaling server + /// * `ws_port` - Port number of the WebSocket signaling server + /// * `ivf_dir` - Directory path containing IVF video files to stream #[pyo3(text_signature = "(self, ws_ip: str, ws_port: int, ivf_dir: str) -> VideoStreamer")] #[new] fn new(ws_ip: String, ws_port: u16, ivf_dir: String) -> Self { @@ -43,19 +64,26 @@ impl VideoStreamer { ws_ip, ws_port, ivf_dir, + peer_connection: Arc::new(Mutex::new(None)), } } + /// Start the video streaming process. + /// + /// This method initiates the WebRTC connection and begins streaming + /// video content from the specified IVF directory. #[pyo3(text_signature = "(self) -> None")] fn start_streaming(&self) -> PyResult<()> { let ws_ip = self.ws_ip.clone(); let ws_port = self.ws_port; let ivf_dir = self.ivf_dir.clone(); + let peer_connection_store = self.peer_connection.clone(); std::thread::spawn(move || { let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async move { - if let Err(e) = start_webrtc(&ws_ip, ws_port, &ivf_dir).await { + if let Err(e) = start_webrtc(&ws_ip, ws_port, &ivf_dir, peer_connection_store).await + { eprintln!("Error starting WebRTC: {}", e); } }); @@ -63,24 +91,90 @@ impl VideoStreamer { Ok(()) } + + /// Get WebRTC connection statistics. + /// + /// Returns a JSON string containing detailed statistics about the + /// current WebRTC peer connection, including bandwidth, latency, and packet loss. + #[pyo3(text_signature = "(self) -> str")] + fn get_stats(&self) -> PyResult { + let peer_connection = self.peer_connection.clone(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + if let Some(pc) = peer_connection.lock().await.as_ref() { + let stats = pc.get_stats().await; + match serde_json::to_string(&stats) { + Ok(stats_string) => Ok(stats_string), + Err(e) => Err(PyErr::new::(format!( + "Failed to serialize stats: {}", + e + ))), + } + } else { + Err(PyErr::new::( + "No active peer connection", + )) + } + }) + } + + /// Close the active WebRTC peer connection. + /// + /// This method gracefully closes the current peer connection and + /// cleans up associated resources. + #[pyo3(text_signature = "(self) -> None")] + fn close_connection(&self) -> PyResult<()> { + let peer_connection = self.peer_connection.clone(); + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + if let Some(pc) = peer_connection.lock().await.as_ref() { + if let Err(e) = pc.close().await { + return Err(PyErr::new::(format!( + "Failed to close connection: {}", + e + ))); + } + *peer_connection.lock().await = None; + Ok(()) + } else { + Err(PyErr::new::( + "No active peer connection", + )) + } + }) + } } -async fn start_webrtc(ws_ip: &str, ws_port: u16, ivf_dir: &str) -> Result<()> { - // Create MediaEngine +/// Initialize and start the WebRTC streaming session. +/// +/// This function sets up the WebRTC peer connection, establishes signaling +/// communication, and begins streaming video content from IVF files. +/// +/// # Arguments +/// +/// * `ws_ip` - WebSocket server IP address +/// * `ws_port` - WebSocket server port +/// * `ivf_dir` - Directory containing IVF video files +/// * `peer_connection_store` - Shared storage for the peer connection + +async fn start_webrtc( + ws_ip: &str, + ws_port: u16, + ivf_dir: &str, + peer_connection_store: Arc>>>, +) -> Result<()> { let mut m = MediaEngine::default(); m.register_default_codecs()?; - // Create a registry for interceptors let mut registry = Registry::new(); registry = register_default_interceptors(registry, &mut m)?; - // Create the API object let api = APIBuilder::new() .with_media_engine(m) .with_interceptor_registry(registry) .build(); - // Prepare the configuration let config = RTCConfiguration { ice_servers: vec![RTCIceServer { urls: vec!["stun:stun.l.google.com:19302".to_owned()], @@ -89,10 +183,9 @@ async fn start_webrtc(ws_ip: &str, ws_port: u16, ivf_dir: &str) -> Result<()> { ..Default::default() }; - // Create a new RTCPeerConnection let peer_connection = Arc::new(api.new_peer_connection(config).await?); + *peer_connection_store.lock().await = Some(Arc::clone(&peer_connection)); - // Create video track let video_track = Arc::new(TrackLocalStaticSample::new( RTCRtpCodecCapability { mime_type: MIME_TYPE_VP8.to_owned(), @@ -102,30 +195,25 @@ async fn start_webrtc(ws_ip: &str, ws_port: u16, ivf_dir: &str) -> Result<()> { "webcam".to_owned(), )); - // Add track to peer connection let rtp_sender = peer_connection .add_track(Arc::clone(&video_track) as Arc) .await?; - // Handle RTCP packets tokio::spawn(async move { let mut rtcp_buf = vec![0u8; 1500]; while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {} }); - // Connect to signaling server let (ws_stream, _) = connect_async(format!("ws://{}:{}/signaling", ws_ip, ws_port)).await?; - let (mut write, mut read) = ws_stream.split(); + let (write, mut read) = ws_stream.split(); let write = Arc::new(Mutex::new(write)); let pc = Arc::clone(&peer_connection); - // Handle connection state changes peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { println!("Connection State has changed: {s}"); Box::pin(async {}) })); - // Handle incoming messages let write_clone = Arc::clone(&write); tokio::spawn(async move { while let Some(msg) = read.next().await { @@ -170,7 +258,9 @@ async fn start_webrtc(ws_ip: &str, ws_port: u16, ivf_dir: &str) -> Result<()> { println!("Received image message - ignoring in WebRTC context"); } SignalingMessage::TriggerImageCapture => { - println!("Received trigger capture message - ignoring in WebRTC context"); + println!( + "Received trigger capture message - ignoring in WebRTC context" + ); } } } @@ -179,11 +269,20 @@ async fn start_webrtc(ws_ip: &str, ws_port: u16, ivf_dir: &str) -> Result<()> { }); println!("Starting video stream..."); - watchand_stream_video(ivf_dir, video_track).await?; + watch_and_stream_video(ivf_dir, video_track).await?; Ok(()) } +/// Write video frames from an IVF file to a WebRTC track. +/// +/// This function reads an IVF file frame by frame and writes the video data +/// to the specified WebRTC track at the appropriate frame rate. +/// +/// # Arguments +/// +/// * `path` - Path to the IVF video file +/// * `track` - WebRTC track to write video data to async fn write_video_to_track(path: &str, track: Arc) -> Result<()> { let file = File::open(path)?; let reader = BufReader::new(file); @@ -207,12 +306,18 @@ async fn write_video_to_track(path: &str, track: Arc) -> } } -//File watcher -async fn watchand_stream_video(directory: &str, track: Arc) -> Result<()> { - // Create a channel for file events +/// Watch a directory for IVF file changes and stream video content. +/// +/// This function monitors the specified directory for new or modified IVF files +/// and automatically streams them to the WebRTC track when detected. +/// +/// # Arguments +/// +/// * `directory` - Directory path to watch for IVF files +/// * `track` - WebRTC track to stream video data to +async fn watch_and_stream_video(directory: &str, track: Arc) -> Result<()> { let (tx, mut rx) = mpsc::channel(100); - // Create an async file watcher let mut watcher = RecommendedWatcher::new( move |res| { if let Ok(event) = res { @@ -222,14 +327,10 @@ async fn watchand_stream_video(directory: &str, track: Arc