diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..6f94355 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +asyncio_mode = auto diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e346afd --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,256 @@ +"""Shared fixtures and mocks for manx plugin tests.""" + +import json +import os +import sys +from unittest.mock import AsyncMock, MagicMock + +import pytest + +# --------------------------------------------------------------------------- +# Set up sys.path so that `plugins.manx` resolves to this repo. +# The repo lives at /tmp/manx-pytest, and imports use `plugins.manx.app...` +# We create a virtual package path: +# /tmp/manx-pytest/.. --> contains plugins/manx (symlink or path trick) +# --------------------------------------------------------------------------- +_repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +_plugins_root = os.path.dirname(_repo_root) # parent of manx-pytest + +# We need "plugins.manx" to resolve. Create namespace package. +sys.path.insert(0, _plugins_root) + +# Ensure the 'plugins' package exists as a namespace +import importlib +if 'plugins' not in sys.modules: + # Create a namespace package pointing to the parent dir + import types + plugins_pkg = types.ModuleType('plugins') + plugins_pkg.__path__ = [os.path.join(_plugins_root)] + plugins_pkg.__package__ = 'plugins' + sys.modules['plugins'] = plugins_pkg + +# Ensure plugins.manx points to this repo +if 'plugins.manx' not in sys.modules: + import types + manx_pkg = types.ModuleType('plugins.manx') + manx_pkg.__path__ = [_repo_root] + manx_pkg.__package__ = 'plugins.manx' + sys.modules['plugins.manx'] = manx_pkg + +# --------------------------------------------------------------------------- +# Stub out heavy Caldera imports so the plugin modules can be imported +# without pulling in the full Caldera framework. +# --------------------------------------------------------------------------- + +# app.utility.base_world -------------------------------------------------- +base_world_mod = MagicMock() + + +class _FakeBaseWorld: + TIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ' + + class Access: + RED = 'red' + BLUE = 'blue' + + @staticmethod + def get_config(prop): + _configs = { + 'app.contact.websocket': 'ws://localhost:7012', + } + return _configs.get(prop, '') + + @staticmethod + def generate_name(size=10): + return 'a' * size + + +base_world_mod.BaseWorld = _FakeBaseWorld +sys.modules.setdefault('app', MagicMock()) +sys.modules['app.utility'] = MagicMock() +sys.modules['app.utility.base_world'] = base_world_mod + +# app.utility.base_service ------------------------------------------------- +base_service_mod = MagicMock() + + +class _FakeBaseService(_FakeBaseWorld): + @staticmethod + def add_service(name, svc): + return MagicMock() # logger + + +base_service_mod.BaseService = _FakeBaseService +sys.modules['app.utility.base_service'] = base_service_mod + +# aiohttp / aiohttp_jinja2 ------------------------------------------------ +if 'aiohttp' not in sys.modules: + aiohttp_mod = MagicMock() + + class _FakeWebResponse: + def __init__(self, data=None, **kwargs): + self.data = data + self.body = json.dumps(data).encode() if data else b'' + self.content_type = 'application/json' + + aiohttp_mod.web.json_response = lambda d: _FakeWebResponse(data=d) + aiohttp_mod.web.Response = _FakeWebResponse + sys.modules['aiohttp'] = aiohttp_mod + sys.modules['aiohttp.web'] = aiohttp_mod.web + +if 'aiohttp_jinja2' not in sys.modules: + jinja2_mod = MagicMock() + + def _fake_template(name): + """Decorator that just passes through the coroutine.""" + def decorator(fn): + return fn + return decorator + + jinja2_mod.template = _fake_template + sys.modules['aiohttp_jinja2'] = jinja2_mod + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class FakeSession: + """Mimics a TCP session object.""" + + def __init__(self, session_id=1, paw='abc123'): + self.id = session_id + self.paw = paw + + +class FakeAgent: + """Mimics a Caldera agent object.""" + + def __init__(self, paw='abc123', platform='linux', executors=None): + self.paw = paw + self.platform = platform + self.executors = executors or ['sh'] + self.display = dict(paw=paw, platform=platform, executors=self.executors) + + +class FakeAbility: + """Mimics a Caldera ability object.""" + + def __init__(self, ability_id='ability-1', name='test ability'): + self.ability_id = ability_id + self.name = name + self.display = dict(ability_id=ability_id, name=name) + + +class FakeTcpHandler: + """Mimics a TCP handler with sessions and send/refresh.""" + + def __init__(self, sessions=None): + self.sessions = sessions if sessions is not None else [] + self.send = AsyncMock(return_value=(0, '/home', 'output', '50ms')) + self.refresh = AsyncMock() + + +class FakeSocketConn: + """Mimics a contact with tcp_handler.""" + + def __init__(self, handler=None): + self.name = 'tcp' + self.tcp_handler = handler or FakeTcpHandler() + self.handler = MagicMock() + self.handler.handles = [] + + +class FakeWebsocketContact: + """Mimics a websocket contact.""" + + def __init__(self): + self.name = 'websocket' + self.handler = MagicMock() + self.handler.handles = [] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def fake_sessions(): + return [FakeSession(1, 'abc123'), FakeSession(2, 'def456')] + + +@pytest.fixture +def fake_agents(): + return [ + FakeAgent('abc123', 'linux', ['sh']), + FakeAgent('def456', 'windows', ['psh']), + ] + + +@pytest.fixture +def tcp_handler(fake_sessions): + return FakeTcpHandler(sessions=list(fake_sessions)) + + +@pytest.fixture +def socket_conn(tcp_handler): + return FakeSocketConn(handler=tcp_handler) + + +@pytest.fixture +def websocket_contact(): + return FakeWebsocketContact() + + +@pytest.fixture +def mock_services(socket_conn, websocket_contact, fake_agents): + """Build a dict of mock Caldera services.""" + data_svc = AsyncMock() + data_svc.apply = AsyncMock() + data_svc.locate = AsyncMock(side_effect=lambda table, match=None: [ + a for a in fake_agents if match and a.paw == match.get('paw') + ]) + + auth_svc = AsyncMock() + file_svc = AsyncMock() + file_svc.find_file_path = AsyncMock(return_value=('manx', '/tmp/manx-pytest/shells/manx.go')) + file_svc.compile_go = AsyncMock() + file_svc.add_special_payload = AsyncMock() + file_svc.sanitize_ldflag_value = MagicMock(side_effect=lambda param, val: val) + + contact_svc = MagicMock() + contact_svc.contacts = [socket_conn, websocket_contact] + contact_svc.report = {'websocket': []} + + app_svc = MagicMock() + app_svc.application = MagicMock() + app_svc.application.router = MagicMock() + app_svc.application.router.add_route = MagicMock() + app_svc.application.router.add_static = MagicMock() + app_svc.retrieve_compiled_file = AsyncMock(return_value=b'compiled_binary') + + rest_svc = AsyncMock() + rest_svc.find_abilities = AsyncMock(return_value=[]) + + services = { + 'data_svc': data_svc, + 'auth_svc': auth_svc, + 'file_svc': file_svc, + 'contact_svc': contact_svc, + 'app_svc': app_svc, + 'rest_svc': rest_svc, + 'term_svc': MagicMock(), + } + # Make term_svc.socket_conn point to our fake + services['term_svc'].socket_conn = socket_conn + return services + + +@pytest.fixture +def mock_request(): + """Factory for fake aiohttp requests.""" + def _make(json_data=None): + req = AsyncMock() + req.json = AsyncMock(return_value=json_data or {}) + return req + return _make diff --git a/tests/test_h_terminal.py b/tests/test_h_terminal.py new file mode 100644 index 0000000..cf793cd --- /dev/null +++ b/tests/test_h_terminal.py @@ -0,0 +1,183 @@ +"""Tests for app.h_terminal — Handle class (websocket terminal handler).""" + +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from plugins.manx.app.h_terminal import Handle + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +class TestHandleInit: + def test_tag_stored(self): + h = Handle(tag='manx') + assert h.tag == 'manx' + + def test_tag_empty_string(self): + h = Handle(tag='') + assert h.tag == '' + + def test_tag_custom_value(self): + h = Handle(tag='custom-tag-123') + assert h.tag == 'custom-tag-123' + + def test_tag_none(self): + h = Handle(tag=None) + assert h.tag is None + + +# --------------------------------------------------------------------------- +# run() — the static websocket handler +# --------------------------------------------------------------------------- + +class TestHandleRun: + + @pytest.fixture + def mock_socket(self): + sock = AsyncMock() + sock.recv = AsyncMock(return_value='whoami') + sock.send = AsyncMock() + return sock + + @pytest.fixture + def handler_services(self, tcp_handler, fake_sessions): + tcp_handler.sessions = list(fake_sessions) + svc = { + 'term_svc': MagicMock(), + 'contact_svc': MagicMock(), + } + svc['term_svc'].socket_conn.tcp_handler = tcp_handler + svc['contact_svc'].report = {'websocket': []} + return svc + + @pytest.mark.asyncio + async def test_run_sends_response(self, mock_socket, handler_services): + """run() should recv a command, send it to tcp_handler, and send JSON back.""" + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + mock_socket.recv.assert_awaited_once() + handler_services['term_svc'].socket_conn.tcp_handler.send.assert_awaited_once_with('1', 'whoami') + mock_socket.send.assert_awaited_once() + sent = json.loads(mock_socket.send.call_args[0][0]) + assert sent['response'] == 'output' + assert sent['pwd'] == '/home' + assert sent['status'] == 0 + assert sent['response_time'] == '50ms' + + @pytest.mark.asyncio + async def test_run_appends_to_report(self, mock_socket, handler_services): + """run() should append an entry to the websocket report.""" + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + report = handler_services['contact_svc'].report['websocket'] + assert len(report) == 1 + assert report[0]['paw'] == 'abc123' + assert report[0]['cmd'] == 'whoami' + assert 'date' in report[0] + + @pytest.mark.asyncio + async def test_run_with_second_session(self, mock_socket, handler_services): + """run() should correctly resolve paw for session id 2.""" + path = '/manx/2/ws' + await Handle.run(mock_socket, path, handler_services) + + report = handler_services['contact_svc'].report['websocket'] + assert report[0]['paw'] == 'def456' + + @pytest.mark.asyncio + async def test_run_strips_response_whitespace(self, mock_socket, handler_services): + """run() should strip whitespace from the response.""" + handler_services['term_svc'].socket_conn.tcp_handler.send = AsyncMock( + return_value=(0, '/tmp', ' padded output ', '10ms') + ) + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + sent = json.loads(mock_socket.send.call_args[0][0]) + assert sent['response'] == 'padded output' + + @pytest.mark.asyncio + async def test_run_empty_command(self, mock_socket, handler_services): + """run() should handle an empty command string.""" + mock_socket.recv = AsyncMock(return_value='') + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + handler_services['term_svc'].socket_conn.tcp_handler.send.assert_awaited_once_with('1', '') + + @pytest.mark.asyncio + async def test_run_nonzero_status(self, mock_socket, handler_services): + """run() should forward non-zero status codes.""" + handler_services['term_svc'].socket_conn.tcp_handler.send = AsyncMock( + return_value=(1, '/root', 'permission denied', '5ms') + ) + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + sent = json.loads(mock_socket.send.call_args[0][0]) + assert sent['status'] == 1 + assert sent['response'] == 'permission denied' + + @pytest.mark.asyncio + async def test_run_session_not_found_raises(self, mock_socket, handler_services): + """run() should raise RuntimeError (wrapping StopIteration) when session id doesn't exist.""" + path = '/manx/999/ws' + with pytest.raises(RuntimeError): + await Handle.run(mock_socket, path, handler_services) + + @pytest.mark.asyncio + async def test_run_multipart_path(self, mock_socket, handler_services): + """run() extracts session_id from path.split('/')[2] regardless of extra segments.""" + path = '/manx/1/ws/extra/stuff' + await Handle.run(mock_socket, path, handler_services) + + handler_services['term_svc'].socket_conn.tcp_handler.send.assert_awaited_once_with('1', 'whoami') + + @pytest.mark.asyncio + async def test_run_date_format(self, mock_socket, handler_services): + """The date written to the report should follow TIME_FORMAT.""" + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + report = handler_services['contact_svc'].report['websocket'] + date_str = report[0]['date'] + # Should be parseable as UTC time in the expected format + parsed = datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%SZ') + assert parsed is not None + + @pytest.mark.asyncio + async def test_run_response_is_valid_json(self, mock_socket, handler_services): + """The data sent back over the socket should be valid JSON.""" + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + raw = mock_socket.send.call_args[0][0] + data = json.loads(raw) + assert set(data.keys()) == {'response', 'pwd', 'status', 'response_time'} + + @pytest.mark.asyncio + async def test_run_special_characters_in_command(self, mock_socket, handler_services): + """run() should handle commands with special characters.""" + mock_socket.recv = AsyncMock(return_value='echo "hello; rm -rf /" | cat') + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + report = handler_services['contact_svc'].report['websocket'] + assert report[0]['cmd'] == 'echo "hello; rm -rf /" | cat' + + @pytest.mark.asyncio + async def test_run_unicode_command(self, mock_socket, handler_services): + """run() should handle unicode commands.""" + mock_socket.recv = AsyncMock(return_value='echo "\u00e9\u00e8\u00ea"') + path = '/manx/1/ws' + await Handle.run(mock_socket, path, handler_services) + + report = handler_services['contact_svc'].report['websocket'] + assert '\u00e9' in report[0]['cmd'] diff --git a/tests/test_hook.py b/tests/test_hook.py new file mode 100644 index 0000000..a391fe2 --- /dev/null +++ b/tests/test_hook.py @@ -0,0 +1,123 @@ +"""Tests for hook.py — plugin entry point.""" + +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestHookModuleAttributes: + + def test_name(self): + from plugins.manx import hook + assert hook.name == 'Terminal' + + def test_description(self): + from plugins.manx import hook + assert hook.description == 'A toolset which supports terminal access' + + def test_address(self): + from plugins.manx import hook + assert hook.address == '/plugin/manx/gui' + + def test_access(self): + from plugins.manx import hook + assert hook.access == 'red' + + +class TestHookEnable: + + @pytest.fixture + def services(self, mock_services): + return mock_services + + @pytest.mark.asyncio + async def test_enable_applies_sessions(self, services): + from plugins.manx.hook import enable + await enable(services) + services['data_svc'].apply.assert_awaited_once_with('sessions') + + @pytest.mark.asyncio + async def test_enable_adds_routes(self, services): + from plugins.manx.hook import enable + await enable(services) + router = services['app_svc'].application.router + # 5 add_route calls + 1 add_static + assert router.add_route.call_count == 5 + assert router.add_static.call_count == 1 + + @pytest.mark.asyncio + async def test_enable_registers_static_route(self, services): + from plugins.manx.hook import enable + await enable(services) + router = services['app_svc'].application.router + router.add_static.assert_called_once_with('/manx', 'plugins/manx/static/', append_version=True) + + @pytest.mark.asyncio + async def test_enable_registers_gui_route(self, services): + from plugins.manx.hook import enable + await enable(services) + router = services['app_svc'].application.router + calls = [c for c in router.add_route.call_args_list if c[0][1] == '/plugin/manx/gui'] + assert len(calls) == 1 + assert calls[0][0][0] == 'GET' + + @pytest.mark.asyncio + async def test_enable_registers_get_sessions_route(self, services): + from plugins.manx.hook import enable + await enable(services) + router = services['app_svc'].application.router + calls = [c for c in router.add_route.call_args_list if c[0][1] == '/plugin/manx/sessions'] + # One GET and one POST + assert len(calls) == 2 + + @pytest.mark.asyncio + async def test_enable_registers_history_route(self, services): + from plugins.manx.hook import enable + await enable(services) + router = services['app_svc'].application.router + calls = [c for c in router.add_route.call_args_list if c[0][1] == '/plugin/manx/history'] + assert len(calls) == 1 + assert calls[0][0][0] == 'POST' + + @pytest.mark.asyncio + async def test_enable_registers_ability_route(self, services): + from plugins.manx.hook import enable + await enable(services) + router = services['app_svc'].application.router + calls = [c for c in router.add_route.call_args_list if c[0][1] == '/plugin/manx/ability'] + assert len(calls) == 1 + assert calls[0][0][0] == 'POST' + + @pytest.mark.asyncio + async def test_enable_appends_handle_to_websocket(self, services): + from plugins.manx.hook import enable + ws_contact = services['contact_svc'].contacts[1] + initial_count = len(ws_contact.handler.handles) + await enable(services) + assert len(ws_contact.handler.handles) == initial_count + 1 + + @pytest.mark.asyncio + async def test_enable_handle_tag_is_manx(self, services): + from plugins.manx.hook import enable + ws_contact = services['contact_svc'].contacts[1] + await enable(services) + handle = ws_contact.handler.handles[-1] + assert handle.tag == 'manx' + + @pytest.mark.asyncio + async def test_enable_adds_special_payload(self, services): + from plugins.manx.hook import enable + await enable(services) + services['file_svc'].add_special_payload.assert_awaited_once() + call_args = services['file_svc'].add_special_payload.call_args[0] + assert call_args[0] == 'manx.go' + + @pytest.mark.asyncio + async def test_enable_no_websocket_contact_raises(self, mock_services): + """If no websocket contact exists, enable should raise.""" + from plugins.manx.hook import enable + mock_services['contact_svc'].contacts = [mock_services['contact_svc'].contacts[0]] + # Only TCP contact, no websocket — the list comprehension returns empty + with pytest.raises(IndexError): + await enable(mock_services) diff --git a/tests/test_term_api.py b/tests/test_term_api.py new file mode 100644 index 0000000..3d2d44f --- /dev/null +++ b/tests/test_term_api.py @@ -0,0 +1,419 @@ +"""Tests for app.term_api — TermApi class (HTTP endpoint handlers).""" + +import json +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from plugins.manx.app.term_api import TermApi +from tests.conftest import FakeAbility, FakeAgent, FakeSession + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +class TestTermApiInit: + def test_services_assigned(self, mock_services): + api = TermApi(mock_services) + assert api.auth_svc is mock_services['auth_svc'] + assert api.file_svc is mock_services['file_svc'] + assert api.data_svc is mock_services['data_svc'] + assert api.contact_svc is mock_services['contact_svc'] + assert api.app_svc is mock_services['app_svc'] + assert api.rest_svc is mock_services['rest_svc'] + + def test_term_svc_created(self, mock_services): + api = TermApi(mock_services) + assert api.term_svc is not None + + +# --------------------------------------------------------------------------- +# splash() +# --------------------------------------------------------------------------- + +class TestSplash: + + @pytest.fixture + def api(self, mock_services): + return TermApi(mock_services) + + @pytest.mark.asyncio + async def test_splash_returns_sessions_and_websocket(self, api, mock_request): + result = await api.splash(mock_request()) + assert 'sessions' in result + assert 'websocket' in result + + @pytest.mark.asyncio + async def test_splash_session_fields(self, api, mock_request): + result = await api.splash(mock_request()) + for s in result['sessions']: + assert 'id' in s + assert 'info' in s + assert 'platform' in s + assert 'executors' in s + + @pytest.mark.asyncio + async def test_splash_calls_refresh(self, api, mock_request): + await api.splash(mock_request()) + api.term_svc.socket_conn.tcp_handler.refresh.assert_awaited() + + @pytest.mark.asyncio + async def test_splash_no_sessions(self, mock_services, mock_request): + mock_services['contact_svc'].contacts[0].tcp_handler.sessions = [] + api = TermApi(mock_services) + result = await api.splash(mock_request()) + assert result['sessions'] == [] + + @pytest.mark.asyncio + async def test_splash_handles_exception(self, mock_services, mock_request, capsys): + """splash() catches exceptions and prints them.""" + mock_services['data_svc'].locate = AsyncMock(side_effect=RuntimeError('boom')) + api = TermApi(mock_services) + result = await api.splash(mock_request()) + captured = capsys.readouterr() + assert 'boom' in captured.out + assert result is None + + @pytest.mark.asyncio + async def test_splash_multiple_sessions(self, api, mock_request): + result = await api.splash(mock_request()) + assert len(result['sessions']) == 2 + + @pytest.mark.asyncio + async def test_splash_websocket_config(self, api, mock_request): + result = await api.splash(mock_request()) + assert result['websocket'] == 'ws://localhost:7012' + + +# --------------------------------------------------------------------------- +# get_sessions() +# --------------------------------------------------------------------------- + +class TestGetSessions: + + @pytest.fixture + def api(self, mock_services): + return TermApi(mock_services) + + @pytest.mark.asyncio + async def test_returns_json_response(self, api, mock_request): + resp = await api.get_sessions(mock_request()) + data = resp.data + assert 'sessions' in data + + @pytest.mark.asyncio + async def test_session_count(self, api, mock_request): + resp = await api.get_sessions(mock_request()) + assert len(resp.data['sessions']) == 2 + + @pytest.mark.asyncio + async def test_session_structure(self, api, mock_request): + resp = await api.get_sessions(mock_request()) + for s in resp.data['sessions']: + assert set(s.keys()) == {'id', 'info', 'platform', 'executors'} + + @pytest.mark.asyncio + async def test_calls_refresh(self, api, mock_request): + await api.get_sessions(mock_request()) + api.term_svc.socket_conn.tcp_handler.refresh.assert_awaited() + + @pytest.mark.asyncio + async def test_empty_sessions(self, mock_services, mock_request): + mock_services['contact_svc'].contacts[0].tcp_handler.sessions = [] + api = TermApi(mock_services) + resp = await api.get_sessions(mock_request()) + assert resp.data['sessions'] == [] + + @pytest.mark.asyncio + async def test_agent_not_found_for_session(self, mock_services, mock_request): + """If data_svc.locate returns nothing for a paw, the session is not included.""" + mock_services['data_svc'].locate = AsyncMock(return_value=[]) + api = TermApi(mock_services) + resp = await api.get_sessions(mock_request()) + assert resp.data['sessions'] == [] + + +# --------------------------------------------------------------------------- +# sessions() (POST) +# --------------------------------------------------------------------------- + +class TestSessionsPost: + + @pytest.fixture + def api(self, mock_services): + return TermApi(mock_services) + + @pytest.mark.asyncio + async def test_returns_list(self, api, mock_request): + resp = await api.sessions(mock_request()) + assert isinstance(resp.data, list) + + @pytest.mark.asyncio + async def test_session_fields(self, api, mock_request): + resp = await api.sessions(mock_request()) + for s in resp.data: + assert 'id' in s + assert 'info' in s + assert set(s.keys()) == {'id', 'info'} + + @pytest.mark.asyncio + async def test_session_count_matches(self, api, mock_request): + resp = await api.sessions(mock_request()) + assert len(resp.data) == 2 + + @pytest.mark.asyncio + async def test_calls_refresh(self, api, mock_request): + await api.sessions(mock_request()) + api.term_svc.socket_conn.tcp_handler.refresh.assert_awaited() + + @pytest.mark.asyncio + async def test_empty_sessions(self, mock_services, mock_request): + mock_services['contact_svc'].contacts[0].tcp_handler.sessions = [] + api = TermApi(mock_services) + resp = await api.sessions(mock_request()) + assert resp.data == [] + + @pytest.mark.asyncio + async def test_info_is_paw(self, api, mock_request): + resp = await api.sessions(mock_request()) + paws = [s['info'] for s in resp.data] + assert 'abc123' in paws + assert 'def456' in paws + + +# --------------------------------------------------------------------------- +# get_history() +# --------------------------------------------------------------------------- + +class TestGetHistory: + + @pytest.fixture + def api(self, mock_services): + mock_services['contact_svc'].report = { + 'websocket': [ + {'paw': 'abc123', 'cmd': 'ls', 'date': '2025-01-01T00:00:00Z'}, + {'paw': 'abc123', 'cmd': 'pwd', 'date': '2025-01-01T00:01:00Z'}, + {'paw': 'other', 'cmd': 'whoami', 'date': '2025-01-01T00:02:00Z'}, + ] + } + return TermApi(mock_services) + + @pytest.mark.asyncio + async def test_returns_matching_history(self, api, mock_request): + req = mock_request(json_data={'paw': 'abc123'}) + resp = await api.get_history(req) + assert len(resp.data) == 2 + + @pytest.mark.asyncio + async def test_filters_by_paw(self, api, mock_request): + req = mock_request(json_data={'paw': 'other'}) + resp = await api.get_history(req) + assert len(resp.data) == 1 + assert resp.data[0]['cmd'] == 'whoami' + + @pytest.mark.asyncio + async def test_no_matching_paw(self, api, mock_request): + req = mock_request(json_data={'paw': 'nonexistent'}) + resp = await api.get_history(req) + assert resp.data == [] + + @pytest.mark.asyncio + async def test_empty_report(self, mock_services, mock_request): + mock_services['contact_svc'].report = {'websocket': []} + api = TermApi(mock_services) + req = mock_request(json_data={'paw': 'abc123'}) + resp = await api.get_history(req) + assert resp.data == [] + + @pytest.mark.asyncio + async def test_missing_paw_in_request(self, api, mock_request): + """If paw is missing from request, data.get('paw') returns None, no match.""" + req = mock_request(json_data={}) + resp = await api.get_history(req) + assert resp.data == [] + + @pytest.mark.asyncio + async def test_history_preserves_entry_structure(self, api, mock_request): + req = mock_request(json_data={'paw': 'abc123'}) + resp = await api.get_history(req) + for entry in resp.data: + assert 'paw' in entry + assert 'cmd' in entry + assert 'date' in entry + + +# --------------------------------------------------------------------------- +# get_abilities() +# --------------------------------------------------------------------------- + +class TestGetAbilities: + + @pytest.fixture + def api(self, mock_services): + abilities = [FakeAbility('a1', 'recon'), FakeAbility('a2', 'exfil')] + mock_services['rest_svc'].find_abilities = AsyncMock(return_value=abilities) + return TermApi(mock_services) + + @pytest.mark.asyncio + async def test_returns_abilities(self, api, mock_request): + req = mock_request(json_data={'paw': 'abc123'}) + resp = await api.get_abilities(req) + assert 'abilities' in resp.data + assert len(resp.data['abilities']) == 2 + + @pytest.mark.asyncio + async def test_ability_display_format(self, api, mock_request): + req = mock_request(json_data={'paw': 'abc123'}) + resp = await api.get_abilities(req) + for a in resp.data['abilities']: + assert 'ability_id' in a + assert 'name' in a + + @pytest.mark.asyncio + async def test_no_abilities(self, mock_services, mock_request): + mock_services['rest_svc'].find_abilities = AsyncMock(return_value=[]) + api = TermApi(mock_services) + req = mock_request(json_data={'paw': 'abc123'}) + resp = await api.get_abilities(req) + assert resp.data['abilities'] == [] + + @pytest.mark.asyncio + async def test_calls_rest_svc_with_paw(self, api, mock_request): + req = mock_request(json_data={'paw': 'test-paw'}) + await api.get_abilities(req) + api.rest_svc.find_abilities.assert_awaited_once_with(paw='test-paw') + + @pytest.mark.asyncio + async def test_missing_paw_raises(self, api, mock_request): + """If 'paw' key is absent, KeyError should be raised.""" + req = mock_request(json_data={}) + with pytest.raises(KeyError): + await api.get_abilities(req) + + +# --------------------------------------------------------------------------- +# dynamically_compile() +# --------------------------------------------------------------------------- + +class TestDynamicallyCompile: + + @pytest.fixture + def api(self, mock_services): + return TermApi(mock_services) + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_with_go_available(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'linux'} + result = await api.dynamically_compile(headers) + api.file_svc.find_file_path.assert_awaited_once_with('manx.go') + api.file_svc.compile_go.assert_awaited_once() + assert result == b'compiled_binary' + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value=None) + async def test_compile_without_go(self, mock_which, api): + """When go is not installed, skip compilation but still retrieve.""" + headers = {'file': 'manx.go', 'platform': 'linux'} + result = await api.dynamically_compile(headers) + api.file_svc.compile_go.assert_not_awaited() + api.app_svc.retrieve_compiled_file.assert_awaited_once_with('manx.go', 'linux') + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_with_contact_header(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'windows', 'contact': 'tcp'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + ldflags = call_kwargs.kwargs.get('ldflags', '') or call_kwargs[1].get('ldflags', '') + assert 'main.contact=tcp' in ldflags + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_with_socket_header(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'linux', 'socket': '0.0.0.0:7010'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + ldflags = call_kwargs.kwargs.get('ldflags', '') or call_kwargs[1].get('ldflags', '') + assert 'main.socket=0.0.0.0:7010' in ldflags + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_with_http_header(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'darwin', 'http': 'http://localhost:8888'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + ldflags = call_kwargs.kwargs.get('ldflags', '') or call_kwargs[1].get('ldflags', '') + assert 'main.http=http://localhost:8888' in ldflags + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_all_optional_headers(self, mock_which, api): + headers = { + 'file': 'manx.go', + 'platform': 'linux', + 'contact': 'tcp', + 'socket': '0.0.0.0:7010', + 'http': 'http://localhost:8888', + } + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + ldflags = call_kwargs.kwargs.get('ldflags', '') or call_kwargs[1].get('ldflags', '') + assert 'main.contact=tcp' in ldflags + assert 'main.socket=0.0.0.0:7010' in ldflags + assert 'main.http=http://localhost:8888' in ldflags + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_default_architecture(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'linux'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + assert call_kwargs.kwargs.get('arch', '') == 'amd64' or call_kwargs[1].get('arch', '') == 'amd64' + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_custom_architecture(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'linux', 'architecture': 'arm64'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + assert call_kwargs.kwargs.get('arch', '') == 'arm64' or call_kwargs[1].get('arch', '') == 'arm64' + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_ldflags_contain_key(self, mock_which, api): + """ldflags should always include -s -w and a generated key.""" + headers = {'file': 'manx.go', 'platform': 'linux'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + ldflags = call_kwargs.kwargs.get('ldflags', '') or call_kwargs[1].get('ldflags', '') + assert '-s' in ldflags + assert '-w' in ldflags + assert 'main.key=' in ldflags + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_output_path(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'linux'} + await api.dynamically_compile(headers) + call_kwargs = api.file_svc.compile_go.call_args + output = call_kwargs[0][1] if len(call_kwargs[0]) > 1 else call_kwargs.kwargs.get('output', '') + # output is the second positional arg to compile_go(platform, output, ...) + assert 'manx.go-linux' in str(output) + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_sanitizes_ldflag_values(self, mock_which, api): + """file_svc.sanitize_ldflag_value should be called for each param header.""" + headers = {'file': 'manx.go', 'platform': 'linux', 'contact': 'tcp', 'socket': '0.0.0.0:7010'} + await api.dynamically_compile(headers) + assert api.file_svc.sanitize_ldflag_value.call_count == 2 + + @pytest.mark.asyncio + @patch('plugins.manx.app.term_api.which', return_value='/usr/local/go/bin/go') + async def test_compile_returns_binary(self, mock_which, api): + headers = {'file': 'manx.go', 'platform': 'windows'} + result = await api.dynamically_compile(headers) + assert result == b'compiled_binary' diff --git a/tests/test_term_svc.py b/tests/test_term_svc.py new file mode 100644 index 0000000..4aeca8d --- /dev/null +++ b/tests/test_term_svc.py @@ -0,0 +1,68 @@ +"""Tests for app.term_svc — TermService class.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from plugins.manx.app.term_svc import TermService +from tests.conftest import FakeSocketConn + + +class TestTermServiceInit: + + def test_socket_conn_assigned(self, mock_services): + svc = TermService(mock_services) + assert svc.socket_conn is not None + + def test_socket_conn_is_tcp_contact(self, mock_services): + svc = TermService(mock_services) + assert svc.socket_conn.name == 'tcp' + + def test_log_created(self, mock_services): + svc = TermService(mock_services) + assert svc.log is not None + + def test_no_tcp_contact_raises(self): + """If there is no TCP contact, the index will be out of range.""" + services = { + 'contact_svc': MagicMock(), + } + services['contact_svc'].contacts = [] + with pytest.raises(IndexError): + TermService(services) + + def test_multiple_contacts_picks_tcp(self): + """TermService should pick the contact whose name is 'tcp'.""" + ws = MagicMock() + ws.name = 'websocket' + tcp = FakeSocketConn() + services = { + 'contact_svc': MagicMock(), + } + services['contact_svc'].contacts = [ws, tcp] + svc = TermService(services) + assert svc.socket_conn is tcp + + def test_tcp_contact_first_in_list(self): + tcp = FakeSocketConn() + other = MagicMock() + other.name = 'udp' + services = { + 'contact_svc': MagicMock(), + } + services['contact_svc'].contacts = [tcp, other] + svc = TermService(services) + assert svc.socket_conn is tcp + + def test_only_websocket_contacts(self): + """If no TCP contact exists, should raise IndexError.""" + ws1 = MagicMock() + ws1.name = 'websocket' + ws2 = MagicMock() + ws2.name = 'websocket' + services = { + 'contact_svc': MagicMock(), + } + services['contact_svc'].contacts = [ws1, ws2] + with pytest.raises(IndexError): + TermService(services)