Skip to content
Merged
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dev = [
"setuptools",
"build",
"pytest",
"websocket-client",
"ruff",
"docstring_parser>=0.16",
"arduino_app_bricks[all]",
Expand Down
95 changes: 0 additions & 95 deletions src/arduino/app_bricks/web_ui/certs.py

This file was deleted.

115 changes: 77 additions & 38 deletions src/arduino/app_bricks/web_ui/web_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
#
# SPDX-License-Identifier: MPL-2.0

from collections.abc import Callable
import asyncio
import os
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import Any
from collections.abc import Callable

import uvicorn
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi_socketio import SocketManager

from arduino.app_utils import brick, Logger

logger = Logger("WebUI")
Expand All @@ -32,7 +36,8 @@ def __init__(
api_path_prefix: str = "",
assets_dir_path: str = "/app/assets",
certs_dir_path: str = "/app/certs",
use_ssl: bool = False,
use_tls: bool = False,
use_ssl: bool | None = None, # Deprecated alias for use_tls
):
"""Initialize the web server.

Expand All @@ -42,35 +47,72 @@ def __init__(
ui_path_prefix (str, optional): URL prefix for UI routes. Defaults to "" (root).
api_path_prefix (str, optional): URL prefix for API routes. Defaults to "" (root).
assets_dir_path (str, optional): Path to static assets directory. Defaults to "/app/assets".
certs_dir_path (str, optional): Path to SSL certificates directory. Defaults to "/app/certs".
use_ssl (bool, optional): Enable SSL/HTTPS. Defaults to False.
certs_dir_path (str, optional): Path to TLS certificates directory. Defaults to "/app/certs".
use_tls (bool, optional): Enable TLS/HTTPS. Defaults to False.
use_ssl (bool, optional): Deprecated. Use use_tls instead. Defaults to None.
"""
self.app = FastAPI(title=__name__, openapi_url=None, on_startup=[self._on_startup])
# Handle deprecated use_ssl parameter
if use_ssl is not None:
logger.warning("'use_ssl' parameter is deprecated. Use 'use_tls' instead.")
use_tls = use_ssl

@asynccontextmanager
async def lifespan(app):
await self._on_startup()
yield

self.app = FastAPI(title=__name__, openapi_url=None, lifespan=lifespan)
self.sio = SocketManager(app=self.app, mount_location="/socket.io", socketio_path="", max_http_buffer_size=10 * 1024 * 1024)

self._addr = addr
self._port = port

def pick_free_port():
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]

self._port = port if port != 0 else pick_free_port()
self._ui_path_prefix = ui_path_prefix
self._api_path_prefix = api_path_prefix
self._assets_dir_path = os.path.abspath(assets_dir_path)
self._certs_dir_path = os.path.abspath(certs_dir_path)
self._use_ssl = use_ssl
self._protocol = "https" if self._use_ssl else "http"
self._server: uvicorn.Server = None
self._use_tls = use_tls
self._protocol = "https" if self._use_tls else "http"
self._server: uvicorn.Server | None = None
self._server_loop: asyncio.AbstractEventLoop | None = None
self._on_connect_cb: Callable[[str], None] = None
self._on_disconnect_cb: Callable[[str], None] = None
self._on_connect_cb: Callable[[str], None] | None = None
self._on_disconnect_cb: Callable[[str], None] | None = None
self._on_message_cbs = {}
self._on_message_cbs_lock = threading.Lock()

@property
def local_url(self) -> str:
"""Get the locally addressable URL of the web server.

Returns:
str: The server's URL (including protocol, address, and port).
"""
return f"{self._protocol}://localhost:{self._port}"

@property
def url(self) -> str:
"""Get the externally addressable URL of the web server.

Returns:
str: The server's URL (including protocol, address, and port).
"""
return f"{self._protocol}://{os.getenv('HOST_IP') or self._addr}:{self._port}"

def start(self):
"""Start the web server asynchronously.

This sets up static file routing and WebSocket event handlers, configures SSL if enabled, and launches the server using Uvicorn.
This sets up static file routing and WebSocket event handlers, configures TLS if enabled, and launches the server using Uvicorn.

Raises:
RuntimeError: If 'index.html' is missing in the static assets directory.
RuntimeError: If SSL is enabled but certificates are missing or fail to generate.
RuntimeError: If TLS is enabled but certificates fail to generate.
RuntimeWarning: If the server is already running.
"""
# Setup static routes and SocketIO events
Expand All @@ -82,18 +124,16 @@ def start(self):
self._init_socketio()

config = uvicorn.Config(self.app, host=self._addr, port=self._port, log_level="warning")
if self._use_ssl:
from . import certs
if self._use_tls:
from arduino.app_utils.tls_cert_manager import TLSCertificateManager

if not certs.cert_exists(self._certs_dir_path):
try:
certs.generate_self_signed_cert(self._certs_dir_path)
except Exception as e:
logger.exception(f"Failed to generate SSL certificate: {e}")
raise RuntimeError("Failed to generate SSL certificate. Please check the certs directory.") from e

config.ssl_keyfile = certs.get_pkey(self._certs_dir_path)
config.ssl_certfile = certs.get_cert(self._certs_dir_path)
try:
cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=self._certs_dir_path, common_name=self._addr)
config.ssl_certfile = cert_path
config.ssl_keyfile = key_path
except Exception as e:
logger.exception(f"Failed to configure SSL certificate: {e}")
raise RuntimeError("Failed to configure TLS certificate. Please check the certs directory.") from e

self._server = uvicorn.Server(config)

Expand All @@ -108,33 +148,31 @@ def stop(self):

def execute(self):
logger.debug(f"Serving static web files from {self._assets_dir_path}")
if self._use_ssl:
logger.debug(f"Serving certificates from {self._certs_dir_path}")
if self._use_tls:
logger.debug(f"Using TLS certificates from {self._certs_dir_path}")

logger.debug("Starting server...")

startup_log = "The application interface is available here:\n"
startup_log += f" - Local URL: {self._protocol}://localhost:{self._port}"
host_ip = os.getenv("HOST_IP")
if host_ip:
network_url = f"{self._protocol}://{host_ip}:{self._port}"
startup_log += f"\n - Network URL: {network_url}"
startup_log += f" - Local URL: {self.local_url}"
if os.getenv("HOST_IP"):
startup_log += f"\n - Network URL: {self.url}"
logger.info(startup_log)

try:
self._server.run()
except Exception as e:
logger.exception(f"Error running server: {e}")

def expose_api(self, method: str, path: str, function: callable):
def expose_api(self, method: str, path: str, function: Callable):
"""Register a route with the specified HTTP method and path.

The path will be prefixed with the api_path_prefix configured during initialization.

Args:
method (str): HTTP method to use (e.g., "GET", "POST").
path (str): URL path for the API endpoint (without the prefix).
function (callable): Function to execute when the route is accessed.
function (Callable): Function to execute when the route is accessed.
"""
self.app.add_api_route(self._api_path_prefix + path, function, methods=[method])

Expand All @@ -160,7 +198,7 @@ def on_disconnect(self, callback: Callable[[str], None]):
"""
self._on_disconnect_cb = callback

def on_message(self, message_type: str, callback: Callable[[str, any], any]):
def on_message(self, message_type: str, callback: Callable[[str, Any], Any]):
"""Register a callback function for a specific WebSocket message type received by clients.

The client should send messages named as message_type for this callback to be triggered.
Expand All @@ -170,7 +208,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]):

Args:
message_type (str): The message type name to listen for.
callback (Callable[[str, any], any]): Function to handle the message. Receives two arguments:
callback (Callable[[str, Any], Any]): Function to handle the message. Receives two arguments:
the session ID (sid) and the incoming message data.

"""
Expand All @@ -180,7 +218,7 @@ def on_message(self, message_type: str, callback: Callable[[str, any], any]):
self._on_message_cbs[message_type] = callback
logger.debug(f"Registered listener for message '{message_type}'")

def send_message(self, message_type: str, message: dict | str, room: str = None):
def send_message(self, message_type: str, message: dict | str, room: str | None = None):
"""Send a message to connected WebSocket clients.

Args:
Expand All @@ -200,7 +238,8 @@ def send_message(self, message_type: str, message: dict | str, room: str = None)
logger.exception(f"Failed to send WebSocket message '{message_type}': {e}")

async def _on_startup(self):
"""This function is called by uvicorn when the server starts up, it is necessary to capture the running
"""
This function is called by uvicorn when the server starts up, it is necessary to capture the running
asyncio event loop and reuse it later for emitting socket.io events as it requires an asyncio context.
"""
self._server_loop = asyncio.get_running_loop()
Expand Down
Loading
Loading