diff --git a/LICENSE b/LICENSE old mode 100755 new mode 100644 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9d2487e --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ + +dependencies: + pip install -r requirements.txt + +dependencies-dev: + pip install -r requirements-dev.txt + +mypy: + python -m mypy --strict --exclude build . + +test: + python -m pytest . + +.PHONY: dependencies dependencies-dev mypy test diff --git a/README.md b/README.md old mode 100755 new mode 100644 diff --git a/challtools/__init__.py b/challtools/__init__.py old mode 100755 new mode 100644 diff --git a/challtools/challenge.schema.json b/challtools/challenge.schema.json old mode 100755 new mode 100644 diff --git a/challtools/cli.py b/challtools/cli.py old mode 100755 new mode 100644 index c2320e5..4182c17 --- a/challtools/cli.py +++ b/challtools/cli.py @@ -1,43 +1,51 @@ # PYTHON_ARGCOMPLETE_OK -import sys -import time import argparse -import os -import uuid import hashlib +import json +import os +import pkg_resources import shutil +import sys +import time import urllib.parse -import json +import uuid from pathlib import Path -import pkg_resources +from typing import List, Optional, Callable, Dict, Any + +import argcomplete +import docker import requests import yaml -import docker -import argcomplete -from .validator import ConfigValidator, is_url + +from .constants import * from .utils import ( + _copytree, + build_chall, + build_docker_images, + create_docker_name, CriticalException, - process_messages, - load_ctf_config, - load_config, - get_ctf_config_path, - get_valid_config, discover_challenges, + format_user_service, + generate_compose, + get_ctf_config_path, get_docker_client, - create_docker_name, - build_docker_images, - build_chall, + get_valid_config, + load_config, + load_ctf_config, + process_messages, start_chall, start_solution, validate_solution_output, - format_user_service, - generate_compose, - _copytree, ) -from .constants import * +from .validator import ConfigValidator, is_url +class CliArguments(argparse.Namespace): + def __init__(self) -> None: + self.somearg: str + self.func: Callable -def main(passed_args=None): + +def main(passed_args: Optional[List[str]] = None) -> int: parser = argparse.ArgumentParser( prog="challtools", description="A tool for managing CTF challenges and challenge repositories using the OpenChallSpec", @@ -165,19 +173,20 @@ def main(passed_args=None): argcomplete.autocomplete(parser, always_complete_options=False) - args = parser.parse_args(passed_args) + args = parser.parse_args(passed_args, namespace=CliArguments) if not getattr(args, "func", None): parser.print_usage() + return 1 else: try: - exit(args.func(args)) + return args.func(args) except CriticalException as e: print(CRITICAL + e.args[0] + CLEAR) - exit(1) + return 1 -def allchalls(args): +def allchalls(args: CliArguments) -> int: parser = args.subparsers.choices.get(args.command[0]) if not parser: @@ -210,7 +219,7 @@ def allchalls(args): return int(failed) -def validate(args): +def validate(args: CliArguments) -> int: config = load_config() @@ -260,7 +269,7 @@ def validate(args): return 0 -def build(args): +def build(args: CliArguments) -> int: config = get_valid_config() if build_chall(config): @@ -271,7 +280,7 @@ def build(args): return 0 -def start(args): +def start(args: CliArguments) -> int: config = get_valid_config() if args.build and build_chall(config): @@ -322,7 +331,7 @@ def start(args): return 1 -def solve(args): # TODO add support for solve script +def solve(args: CliArguments) -> int: # TODO add support for solve script config = get_valid_config() if not config["solution_image"]: @@ -370,7 +379,7 @@ def solve(args): # TODO add support for solve script return 0 -def compose(args): +def compose(args: CliArguments) -> int: if args.all: configs = [ (path, get_valid_config(path, cd=False)) for path in discover_challenges() @@ -390,7 +399,7 @@ def compose(args): return 0 -def ensureid(args): +def ensureid(args: CliArguments) -> int: path = Path(".") if (path / "challenge.yml").exists(): path = path / "challenge.yml" @@ -413,7 +422,7 @@ def ensureid(args): if highest_level == 5: print( "\n".join( - process_messages([m for m in messages if m["level"] == 5])[ + process_messages([m for m in messages if m.level == 5])[ "message_strings" ] ) @@ -457,7 +466,7 @@ def ensureid(args): return 0 -def push(args): +def push(args: CliArguments) -> int: config = get_valid_config() ctf_config = load_ctf_config() @@ -629,7 +638,7 @@ def push(args): return 0 -def init(args): +def init(args: CliArguments) -> int: if args.list: for template_path in Path( @@ -687,7 +696,7 @@ def init(args): return 0 -def templateCompleter(**kwargs): +def templateCompleter(**kwargs: Dict[str, Any]) -> List[str]: return [ path.name for path in Path( @@ -696,7 +705,7 @@ def templateCompleter(**kwargs): ] -def spoilerfree(args): +def spoilerfree(args: CliArguments) -> int: config = get_valid_config() print(f"\033[1;97m{config['title']}{CLEAR}") diff --git a/challtools/codes.yml b/challtools/codes.yml old mode 100755 new mode 100644 diff --git a/challtools/constants.py b/challtools/constants.py old mode 100755 new mode 100644 diff --git a/challtools/templates/flask/container/server.py b/challtools/templates/flask/container/server.py index 548c378..3a8aa86 100644 --- a/challtools/templates/flask/container/server.py +++ b/challtools/templates/flask/container/server.py @@ -1,10 +1,11 @@ from flask import Flask +from flask.typing import ResponseReturnValue app = Flask(__name__) @app.route("/") -def index(): +def index() -> ResponseReturnValue: return "Template challenge running!" diff --git a/challtools/utils.py b/challtools/utils.py old mode 100755 new mode 100644 index 6fcc38c..80fb51c --- a/challtools/utils.py +++ b/challtools/utils.py @@ -1,23 +1,25 @@ +import hashlib +import json import os import re +import shutil import subprocess import sys -import hashlib -import json -import shutil from pathlib import Path -import yaml +from typing import Dict, Any, Optional, Union, List, Tuple + import docker import requests -from .validator import ConfigValidator -from .constants import * +import yaml +from .constants import * +from .validator import ConfigValidator, Message class CriticalException(Exception): pass -def process_messages(messages, verbose=False): +def process_messages(messages: List[Message], verbose: bool = False) -> Dict[str, Any]: """Processes a list of messages from validator.ConfigValidator.validate for printing. Args: @@ -37,17 +39,17 @@ def process_messages(messages, verbose=False): highest_level = 0 message_strings = [] for message in messages: - level_counts[message["level"] - 1] += 1 - highest_level = max(highest_level, message["level"]) + level_counts[message.level - 1] += 1 + highest_level = max(highest_level, message.level) message_string = ( - f"[{STYLED_LEVELS[message['level']-1]}] [{BOLD}{message['code']}{CLEAR}] " + f"[{STYLED_LEVELS[message.level-1]}] [{BOLD}{message.code}{CLEAR}] " ) - if message["field"]: - message_string += f"{message['field']}: " - message_string += message["name"] + if message.field: + message_string += f"{message.field}: " + message_string += message.name if verbose: - message_string += "\n" + message["message"] + message_string += "\n" + message.message message_strings.append(message_string) level_name_counts = {i: count for i, count in enumerate(level_counts) if count} @@ -75,7 +77,7 @@ def process_messages(messages, verbose=False): } -def get_ctf_config_path(search_start=Path(".")): +def get_ctf_config_path(search_start: Path = Path(".")) -> Optional[Path]: """Locates the global CTF configuration file (ctf.yml) and returns a path to it. Returns: @@ -93,7 +95,7 @@ def get_ctf_config_path(search_start=Path(".")): return None -def get_config_path(search_start=Path(".")): +def get_config_path(search_start: Path = Path(".")) -> Optional[Path]: """Locates the challenge configuration file (challenge.yml) and returns a path to it. Returns: @@ -111,7 +113,7 @@ def get_config_path(search_start=Path(".")): return None -def load_ctf_config(): +def load_ctf_config() -> Dict[str, Any]: """Loads the global CTF configuration file (ctf.yml) from the current or a parent directory. Returns: @@ -123,13 +125,13 @@ def load_ctf_config(): if not ctfpath: return None - raw_config = ctfpath.read_text() - config = yaml.safe_load(raw_config) + with open(ctfpath, 'r') as config_file: + config = yaml.safe_load(config_file) return config if config else {} -def load_config(workdir=".", search=True, cd=True): +def load_config(workdir: str = ".", search: bool = True, cd: bool = True) -> Dict[str, Any]: """Loads the challenge configuration file from the current directory, a specified directory, or optionally one of their parent directories. Optionally changes the working directory to the directory of the configuration file. Args: @@ -144,10 +146,10 @@ def load_config(workdir=".", search=True, cd=True): CriticalException: If the challenge configuration cannot be found """ - path = Path(workdir).absolute() + workdir_path = Path(workdir).absolute() if search: - path = get_config_path(path) + path = get_config_path(workdir_path) else: if (path / "challenge.yml").exists(): path = path / "challenge.yml" @@ -161,8 +163,10 @@ def load_config(workdir=".", search=True, cd=True): f"Could not find a challenge.yml file in this{' or a parent' if search else ''} directory." ) - raw_config = path.read_text() - config = yaml.safe_load(raw_config) + with open(path, 'r') as config_file: + config = yaml.safe_load(config_file) + if not config: + raise RuntimeError(f'Failed to load config from path "{path}"') if cd: os.chdir(path.parent) @@ -170,7 +174,7 @@ def load_config(workdir=".", search=True, cd=True): return config -def get_valid_config(workdir=None, search=True, cd=True): +def get_valid_config(workdir: Optional[Union[str, Path]] = None, search: bool = True, cd: bool = True) -> Any: """Loads the challenge configuration file from the current directory and makes sure its valid. Args: @@ -184,9 +188,10 @@ def get_valid_config(workdir=None, search=True, cd=True): Raises: CriticalException: If there are critical validation errors """ - config = load_config( - search=search, cd=cd, **{"workdir": workdir} if workdir else {} - ) + if workdir: + config = load_config(search=search, cd=cd, workdir=workdir) + else: + config = load_config(search=search, cd=cd) validator = ConfigValidator(config) messages = validator.validate()[1] @@ -195,7 +200,7 @@ def get_valid_config(workdir=None, search=True, cd=True): if highest_level == 5: print( "\n".join( - process_messages([m for m in messages if m["level"] == 5])[ + process_messages([m for m in messages if m.level == 5])[ "message_strings" ] ) @@ -207,7 +212,7 @@ def get_valid_config(workdir=None, search=True, cd=True): elif highest_level == 4: print( "\n".join( - process_messages([m for m in messages if m["level"] == 4])[ + process_messages([m for m in messages if m.level == 4])[ "message_strings" ] ) @@ -219,21 +224,19 @@ def get_valid_config(workdir=None, search=True, cd=True): return validator.normalized_config -def discover_challenges(search_start=None): +def discover_challenges(search_start: Optional[str] = None) -> Optional[List[Path]]: """Discovers all challenges at the same level as or in a subdirectory below the CTF configuration file. Returns: list: A list of pathlib.Path objects to all found challenge configurations None: If there was no CTF config """ - root = get_ctf_config_path( - **{"search_start": search_start} if search_start else {} - ).parent - + root = (get_ctf_config_path(search_start) if search_start else get_ctf_config_path()) if not root: return None + root = root.parent - def checkdir(d): + def checkdir(d: Path) -> List[Path]: if (d / "challenge.yml").exists(): return [d / "challenge.yml"] if (d / "challenge.yaml").exists(): @@ -246,7 +249,7 @@ def checkdir(d): return checkdir(root) -def get_docker_client(): +def get_docker_client() -> docker.api.client.ContainerApiMixin: """Gets an authenticated docker client. Returns: @@ -272,7 +275,7 @@ def get_docker_client(): return client -def get_first_text_flag(config): +def get_first_text_flag(config: Dict[str, Any]) -> Optional[str]: """Creates a valid flag with the flag format using the flag format and the first text flag, if it exists. Args: @@ -297,7 +300,7 @@ def get_first_text_flag(config): return config["flag_format_prefix"] + text_flag + config["flag_format_suffix"] -def dockerize_string(string): +def dockerize_string(string: str) -> str: """Converts a string into a valid docker tag name. Args: @@ -315,7 +318,7 @@ def dockerize_string(string): return string[:128] -def create_docker_name(title, container_name=None, chall_id=None): +def create_docker_name(title: str, container_name: Optional[str] = None, chall_id: Optional[str] = None) -> str: """Converts challenge information into a most likely unique and valid docker tag name. Args: @@ -339,7 +342,7 @@ def create_docker_name(title, container_name=None, chall_id=None): return "_".join([title[:32], digest[:16]]) -def format_user_service(config, service_type, **kwargs): +def format_user_service(config: Dict[str, Any], service_type: str, **kwargs) -> str: """Formats a string displayed to the user based on the service type and a substitution context (``display`` in the OpenChallSpec). Args: @@ -366,7 +369,7 @@ def format_user_service(config, service_type, **kwargs): return string -def validate_solution_output(config, output): +def validate_solution_output(config: Dict[str, Any], output: str) -> bool: """validates a flag outputted by a solver by stripping the whitespace and validating the flag. Args: @@ -379,7 +382,7 @@ def validate_solution_output(config, output): return validate_flag(config, output.strip()) -def validate_flag(config, submitted_flag): +def validate_flag(config: Dict[str, Any], submitted_flag: str) -> bool: """validates a flag against the flags in the challenge config. Args: @@ -410,7 +413,7 @@ def validate_flag(config, submitted_flag): return False -def build_image(image, tag, client): +def build_image(image: str, tag: str, client: docker.api.client.ContainerApiMixin) -> None: """Build a docker image given the image (as a path to a folder, if archive it will load it), the tag and the docker client. Args: @@ -457,7 +460,7 @@ def build_image(image, tag, client): ) -def run_build_script(config): +def run_build_script(config: Dict[str, Any]) -> None: if "build_script" not in config["custom"]: raise CriticalException(f"Build script has not been defined!") @@ -476,7 +479,7 @@ def run_build_script(config): raise CriticalException(f"Build script exited with code {p.returncode}") -def build_docker_images(config, client): +def build_docker_images(config: Dict[str, Any], client: docker.api.client.ContainerApiMixin) -> bool: if not config["deployment"]: return False @@ -511,7 +514,7 @@ def build_docker_images(config, client): return True -def build_chall(config): +def build_chall(config: Dict[str, Any]) -> bool: """Builds a challenge including running the build script and building service and solution docker images. Expects to be run from the root directory of the challenge. Args: @@ -554,7 +557,7 @@ def build_chall(config): return did_something -def start_chall(config): +def start_chall(config: Dict[str, Any]) -> Tuple[List[docker.api.client.ContainerApiMixin], List[str]]: """Starts all docker containers for this challenge. Args: @@ -661,7 +664,7 @@ def start_chall(config): return containers, service_strings -def start_solution(config): +def start_solution(config: Dict[str, Any]) -> docker.api.client.ContainerApiMixin: """Starts a solution container for this challenge. Args: @@ -757,7 +760,7 @@ def generate_compose(configs, is_global=False): # TODO handle services with set external ports first so the auto assigned ports dont potentially conflict with them for name, container in config["deployment"]["containers"].items(): - compose_service = {"ports": []} + compose_service: Dict[str, Any] = {"ports": []} volumes = [] networks = [] @@ -831,7 +834,7 @@ def generate_compose(configs, is_global=False): # https://stackoverflow.com/a/12514470 # needs to exist to support python 3.6 & 3.7, otherwise shutil.copytree should be used with dirs_exist_ok=True -def _copytree(src, dst, ignore=lambda dir, content: list()): +def _copytree(src: Union[str, Path], dst: Union[str, Path], ignore=lambda dir, content: list()) -> None: if not os.path.exists(dst): os.makedirs(dst) dirlist = os.listdir(src) diff --git a/challtools/validator.py b/challtools/validator.py old mode 100755 new mode 100644 index 188dc7c..3e63b5d --- a/challtools/validator.py +++ b/challtools/validator.py @@ -1,11 +1,12 @@ -import os -from copy import deepcopy import json -from pathlib import Path +import os import pkg_resources -import yaml import re +import yaml +from copy import deepcopy from jsonschema import validate, ValidationError, Draft7Validator, validators +from pathlib import Path +from typing import Dict, Any, List, Tuple with pkg_resources.resource_stream("challtools", "codes.yml") as f: codes = yaml.safe_load(f) @@ -14,7 +15,7 @@ schema = json.load(f) -def is_url(s): +def is_url(s: str) -> bool: return s.startswith("http://") or s.startswith("https://") @@ -44,11 +45,20 @@ def set_defaults(validator, properties, instance, schema2): DefaultValidatingDraft7Validator = extend_with_default(Draft7Validator) +class Message(object): + def __init__(self, code: str, field, name: str, level: int, message: str): + self.code = code + self.field = field + self.name = name + self.level = level + self.message = message + + class ConfigValidator: def __init__(self, config, ctf_config=None, challdir=None): if not isinstance(config, dict): raise ValueError("Config parameter needs to be a dict") - self.messages = [] + self.messages: List[Message] = [] self.config = {} self.normalized_config = {} self.config = config @@ -64,7 +74,7 @@ def __init__(self, config, ctf_config=None, challdir=None): # return DefaultValidatingDraft7Validator(schema).validate(self.normalized_config) - def validate(self): + def validate(self) -> Tuple[bool, List[Message]]: """Validates the challenge config and returns a list of messages. Returns: @@ -252,7 +262,7 @@ def validate(self): return True, self.messages - def _raise_code(self, code, field=None, **formatting): + def _raise_code(self, code: str, field: str = None, **formatting) -> None: """Adds a formatted message entry into the messages array. Args: @@ -267,13 +277,13 @@ def _raise_code(self, code, field=None, **formatting): # no valid field check because of A002 self.messages.append( - { - "code": code, - "field": field, - "name": codes[code]["name"], - "level": codes[code]["level"], - "message": codes[code]["formatted_message"].format( + Message( + code=code, + field=field, + name=codes[code]["name"], + level=codes[code]["level"], + message=codes[code]["formatted_message"].format( field_name=field, **formatting ), - } + ) ) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..23fa822 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +-r requirements.txt +pytest +pytest-cov + +types-flask +types-jsonschema +types-pkg_resources +types-PyYAML +types-requests +types-setuptools diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8193e15 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +argcomplete +docker +jsonschema +pyyaml +requests \ No newline at end of file diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 diff --git a/tests/conftest.py b/tests/conftest.py index 3f1883c..3969cd2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,17 @@ -import pytest +from typing import Generator + import docker +import pytest + from challtools.utils import get_docker_client -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption("--docker-fails", action="store_true") parser.addoption("--docker-strict", action="store_true") -def pytest_collection_modifyitems(session, config, items): +def pytest_collection_modifyitems(session, config, items) -> None: for item in items: if config.option.docker_fails and "fails_without_docker" in set( marker.name for marker in item.own_markers @@ -17,19 +20,19 @@ def pytest_collection_modifyitems(session, config, items): @pytest.fixture(scope="session") -def docker_client(): +def docker_client() -> docker.api.client.ContainerApiMixin: return get_docker_client() @pytest.fixture() -def clean_container_state(docker_client): +def clean_container_state(docker_client: docker.api.client.ContainerApiMixin) -> Generator[None, None, None]: relevant_tags = [ "challtools_test", "challtools_test_challenge_f9629917705648c9", "sol_challtools_test_9461485faadf529f", ] - def remove_tags(): + def remove_tags() -> None: for image in docker_client.images.list(): for tag in image.tags: if tag.split(":")[0] in relevant_tags: diff --git a/tests/test_cli.py b/tests/test_cli.py index 485f5d4..ef39804 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,39 +1,41 @@ import os from pathlib import Path -import yaml + +import docker import pytest +import yaml + from challtools.utils import build_chall, get_valid_config from utils import populate_dir, main_wrapper, inittemplatepath - class Test_allchalls: - def test_validate(self, tmp_path, capsys): + def test_validate(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "simple_ctf") assert main_wrapper(["allchalls", "validate"]) == 0 assert capsys.readouterr().out.count("Validation succeeded.") == 3 - def test_no_ctf_config(self, tmp_path): + def test_no_ctf_config(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") Path("ctf.yml").unlink() assert main_wrapper(["allchalls", "validate"]) == 1 class Test_validate: - def test_ok(self, tmp_path): + def test_ok(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") assert main_wrapper(["validate"]) == 0 - def test_ok_subdir(self, tmp_path): + def test_ok_subdir(self, tmp_path: Path) -> None: populate_dir(tmp_path, "subdir") os.chdir("subdir") assert main_wrapper(["validate"]) == 0 - def test_schema_violation(self, tmp_path, capsys): + def test_schema_violation(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "schema_violation") assert main_wrapper(["validate"]) == 1 assert "A002" in capsys.readouterr().out - def test_schema_violation_list(self, tmp_path, capsys): + def test_schema_violation_list(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "schema_violation_list") assert main_wrapper(["validate"]) == 1 assert "A002" in capsys.readouterr().out @@ -41,13 +43,13 @@ def test_schema_violation_list(self, tmp_path, capsys): class Test_build: # TODO build scripts - def test_no_service(self, tmp_path, capsys): + def test_no_service(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "minimal_valid") assert main_wrapper(["build"]) == 0 assert "nothing to do" in capsys.readouterr().out.lower() @pytest.mark.fails_without_docker - def test_single(self, tmp_path, docker_client, clean_container_state): + def test_single(self, tmp_path: Path, docker_client: docker.api.client.ContainerApiMixin, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp") assert main_wrapper(["build"]) == 0 assert "challtools_test_challenge_f9629917705648c9:latest" in [ @@ -55,7 +57,7 @@ def test_single(self, tmp_path, docker_client, clean_container_state): ] @pytest.mark.fails_without_docker - def test_subdir(self, tmp_path, docker_client, clean_container_state): + def test_subdir(self, tmp_path: Path, docker_client: docker.api.client.ContainerApiMixin, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp") os.chdir("container") assert main_wrapper(["build"]) == 0 @@ -64,7 +66,7 @@ def test_subdir(self, tmp_path, docker_client, clean_container_state): ] @pytest.mark.fails_without_docker - def test_solution(self, tmp_path, docker_client, clean_container_state): + def test_solution(self, tmp_path: Path, docker_client: docker.api.client.ContainerApiMixin, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp_solution") assert main_wrapper(["build"]) == 0 import time @@ -75,13 +77,13 @@ def test_solution(self, tmp_path, docker_client, clean_container_state): assert "sol_challtools_test_9461485faadf529f:latest" in tags @pytest.mark.fails_without_docker - def test_build_error(self, tmp_path, capsys, clean_container_state): + def test_build_error(self, tmp_path: Path, capsys, clean_container_state) -> None: populate_dir(tmp_path, "build_error") assert main_wrapper(["build"]) == 1 assert "copy failed:" in capsys.readouterr().out.lower() @pytest.mark.fails_without_docker - def test_parse_error(self, tmp_path, capsys, clean_container_state): + def test_parse_error(self, tmp_path: Path, capsys, clean_container_state) -> None: populate_dir(tmp_path, "dockerfile_parse_error") assert main_wrapper(["build"]) == 1 assert "dockerfile parse error" in capsys.readouterr().out.lower() @@ -93,21 +95,21 @@ def test_parse_error(self, tmp_path, capsys, clean_container_state): class Test_solve: - def test_no_service(self, tmp_path, capsys): + def test_no_service(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "minimal_valid") build_chall(get_valid_config()) assert main_wrapper(["solve"]) == 0 assert "no solution defined" in capsys.readouterr().out.lower() @pytest.mark.fails_without_docker - def test_ok(self, tmp_path, capsys, clean_container_state): + def test_ok(self, tmp_path: Path, capsys, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp_solution") build_chall(get_valid_config()) assert main_wrapper(["solve"]) == 0 assert "solved" in capsys.readouterr().out.lower() @pytest.mark.fails_without_docker - def test_fail(self, tmp_path, capsys, clean_container_state): + def test_fail(self, tmp_path: Path, capsys, clean_container_state) -> None: populate_dir(tmp_path, "broken_solution") build_chall(get_valid_config()) assert main_wrapper(["solve"]) == 1 @@ -116,12 +118,12 @@ def test_fail(self, tmp_path, capsys, clean_container_state): class Test_compose: # TODO challenges with muliple containers - def test_no_service(self, tmp_path): + def test_no_service(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") assert main_wrapper(["compose"]) == 0 assert not Path("docker-compose.yml").exists() - def test_single(self, tmp_path): + def test_single(self, tmp_path: Path) -> None: populate_dir(tmp_path, "trivial_tcp") assert main_wrapper(["compose"]) == 0 assert Path("docker-compose.yml").exists() @@ -133,13 +135,13 @@ def test_single(self, tmp_path): class Test_ensureid: - def test_ok(self, tmp_path, capsys): + def test_ok(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "minimal_valid") assert main_wrapper(["ensureid"]) == 0 assert get_valid_config()["challenge_id"] assert "written" in capsys.readouterr().out.lower() - def test_has_id(self, tmp_path, capsys): + def test_has_id(self, tmp_path: Path, capsys) -> None: populate_dir(tmp_path, "has_id") assert main_wrapper(["ensureid"]) == 0 assert get_valid_config()["challenge_id"] @@ -147,7 +149,7 @@ def test_has_id(self, tmp_path, capsys): class Test_init: - def check_identical(self, tmp_path, template): + def check_identical(self, tmp_path: Path, template) -> bool: if not ( len(list(tmp_path.rglob("*"))) == len(list((inittemplatepath / template).rglob("*"))) - 1 @@ -166,19 +168,19 @@ def check_identical(self, tmp_path, template): return True - def test_empty(self, tmp_path, capsys): + def test_empty(self, tmp_path: Path, capsys) -> None: os.chdir(tmp_path) assert main_wrapper(["init"]) == 0 assert "initialized" in capsys.readouterr().out.lower() assert self.check_identical(tmp_path, "default") - def test_nonempty(self, tmp_path): + def test_nonempty(self, tmp_path: Path) -> None: os.chdir(tmp_path) Path("existing_file").touch() assert main_wrapper(["init", "default"]) == 1 assert not self.check_identical(tmp_path, "default") - def test_force(self, tmp_path): + def test_force(self, tmp_path: Path) -> None: os.chdir(tmp_path) Path("existing_file").touch() assert main_wrapper(["init", "default", "-f"]) == 0 @@ -186,7 +188,7 @@ def test_force(self, tmp_path): Path("existing_file").unlink() assert self.check_identical(tmp_path, "default") - def test_list(self, capsys): + def test_list(self, capsys) -> None: assert main_wrapper(["init", "--list"]) == 0 assert ( "default - a generic template suitable for any type of challenge" diff --git a/tests/test_utils.py b/tests/test_utils.py index c8e808a..931b5eb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,10 @@ import os import re from pathlib import Path + import pytest import yaml + from challtools.utils import ( CriticalException, process_messages, @@ -23,87 +25,86 @@ ) from utils import populate_dir - # TODO # class Test_process_messages: # pass class Test_get_ctf_config_path: - def test_root(self, tmp_path): + def test_root(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") assert get_ctf_config_path() == tmp_path / "ctf.yml" - def test_subdir(self, tmp_path): + def test_subdir(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") os.chdir("chall1") assert get_ctf_config_path() == tmp_path / "ctf.yml" - def test_yaml(self, tmp_path): + def test_yaml(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") Path("ctf.yml").rename("ctf.yaml") assert get_ctf_config_path() == tmp_path / "ctf.yaml" - def test_missing(self, tmp_path): + def test_missing(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") assert get_ctf_config_path() is None class Test_load_ctf_config: - def test_empty(self, tmp_path): + def test_empty(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") assert load_ctf_config() == {} - def test_populated(self, tmp_path): + def test_populated(self, tmp_path: Path) -> None: populate_dir(tmp_path, "ctf_authors") assert load_ctf_config() == yaml.safe_load((tmp_path / "ctf.yml").read_text()) - def test_missing(self, tmp_path): + def test_missing(self, tmp_path: Path) -> None: os.chdir(tmp_path) assert load_ctf_config() == None class Test_load_config: - def test_root(self, tmp_path): + def test_root(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") assert load_config() == yaml.safe_load((tmp_path / "challenge.yml").read_text()) - def test_subdir(self, tmp_path): + def test_subdir(self, tmp_path: Path) -> None: populate_dir(tmp_path, "subdir") os.chdir("subdir") assert load_config() == yaml.safe_load((tmp_path / "challenge.yml").read_text()) - def test_yaml(self, tmp_path): + def test_yaml(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") Path("challenge.yml").rename("challenge.yaml") assert load_config() == yaml.safe_load( (tmp_path / "challenge.yaml").read_text() ) - def test_missing(self, tmp_path): + def test_missing(self, tmp_path: Path) -> None: os.chdir(tmp_path) with pytest.raises(CriticalException): load_config() class Test_get_valid_config: - def test_valid(self, tmp_path): + def test_valid(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") assert get_valid_config() - def test_invalid(self, tmp_path): + def test_invalid(self, tmp_path: Path) -> None: populate_dir(tmp_path, "schema_violation") with pytest.raises(CriticalException): get_valid_config() - def test_invalid_list(self, tmp_path): + def test_invalid_list(self, tmp_path: Path) -> None: populate_dir(tmp_path, "schema_violation_list") with pytest.raises(CriticalException): get_valid_config() class Test_discover_challenges: - def test_root(self, tmp_path): + def test_root(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") assert set(discover_challenges()) == { tmp_path / "chall1" / "challenge.yml", @@ -111,7 +112,7 @@ def test_root(self, tmp_path): tmp_path / "chall3" / "challenge.yml", } - def test_subdir(self, tmp_path): + def test_subdir(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") os.chdir(tmp_path / "chall1") assert set(discover_challenges()) == { @@ -120,7 +121,7 @@ def test_subdir(self, tmp_path): tmp_path / "chall3" / "challenge.yml", } - def test_yaml(self, tmp_path): + def test_yaml(self, tmp_path: Path) -> None: populate_dir(tmp_path, "simple_ctf") (tmp_path / "chall2" / "challenge.yml").rename( tmp_path / "chall2" / "challenge.yaml" @@ -133,42 +134,42 @@ def test_yaml(self, tmp_path): class Test_get_first_text_flag: - def test_exists(self, tmp_path): + def test_exists(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") assert get_first_text_flag(get_valid_config()) == "CTF{d3f4ul7_fl46}" - def test_missing(self, tmp_path): + def test_missing(self, tmp_path: Path) -> None: populate_dir(tmp_path, "regex_flag") assert get_first_text_flag(get_valid_config()) is None class Test_create_docker_name: - def check_valid(self, name): + def check_valid(self, name: str) -> None: assert all(ord(c) < 128 for c in name) # docker tags can typically be 128 long, but here we check for 124 since challtools prefixes solution cointainers with "sol_" assert re.match(r"[\w][\w.-]{,123}", name) - def test_basic(self): + def test_basic(self) -> None: self.check_valid(create_docker_name("challenge")) - def test_long_title(self): + def test_long_title(self) -> None: self.check_valid(create_docker_name("challenge" * 128)) - def test_container(self): + def test_container(self) -> None: self.check_valid(create_docker_name("challenge", container_name="container")) - def test_container_long(self): + def test_container_long(self) -> None: self.check_valid( create_docker_name("challenge", container_name="container" * 128) ) - def test_chall_id(self): + def test_chall_id(self) -> None: self.check_valid(create_docker_name("challenge", chall_id="ididididid")) - def test_chall_id_long(self): + def test_chall_id_long(self) -> None: self.check_valid(create_docker_name("challenge", chall_id="ididididid" * 128)) - def test_all_long(self): + def test_all_long(self) -> None: self.check_valid( create_docker_name( "challenge" * 128, @@ -179,7 +180,7 @@ def test_all_long(self): class Test_format_user_service: - def test_tcp(self): + def test_tcp(self) -> None: assert ( format_user_service( {"custom_service_types": []}, "tcp", host="127.0.0.1", port="1337" @@ -187,7 +188,7 @@ def test_tcp(self): == "nc 127.0.0.1 1337" ) - def test_website(self): + def test_website(self) -> None: assert ( format_user_service( {"custom_service_types": []}, "website", url="http://127.0.0.1:1337" @@ -195,7 +196,7 @@ def test_website(self): == "http://127.0.0.1:1337" ) - def test_custom(self): + def test_custom(self) -> None: assert ( format_user_service( { @@ -215,28 +216,28 @@ def test_custom(self): class Test_validate_flag: - def test_default(self, tmp_path): + def test_default(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") config = get_valid_config() assert validate_flag(config, "CTF{d3f4ul7_fl46}") assert not validate_flag(config, "CTF{invalid}") assert not validate_flag(config, "d3f4ul7_fl46") - def test_no_format(self, tmp_path): + def test_no_format(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") config = get_valid_config() config["flag_format_prefix"] = None assert validate_flag(config, "d3f4ul7_fl46") assert not validate_flag(config, "CTF{d3f4ul7_fl46}") - def test_multiple(self, tmp_path): + def test_multiple(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") config = get_valid_config() config["flags"].append({"type": "text", "flag": "second_valid"}) assert validate_flag(config, "CTF{d3f4ul7_fl46}") assert validate_flag(config, "CTF{second_valid}") - def test_regex(self, tmp_path): + def test_regex(self, tmp_path: Path) -> None: populate_dir(tmp_path, "minimal_valid") config = get_valid_config() config["flags"] = [{"type": "regex", "flag": r"^\d{8}$"}] @@ -248,7 +249,7 @@ def test_regex(self, tmp_path): class Test_build_image: @pytest.mark.fails_without_docker - def test_simple(self, tmp_path, docker_client, clean_container_state): + def test_simple(self, tmp_path: Path, docker_client, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp") build_image("container", "challtools_test", docker_client) assert "challtools_test:latest" in [ @@ -260,7 +261,7 @@ class Test_build_chall: # TODO challenges with muliple containers # TODO build scripts @pytest.mark.fails_without_docker - def test_trivial_tcp(self, tmp_path, docker_client, clean_container_state): + def test_trivial_tcp(self, tmp_path, docker_client, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp") assert build_chall(get_valid_config()) assert "challtools_test_challenge_f9629917705648c9:latest" in [ @@ -268,7 +269,7 @@ def test_trivial_tcp(self, tmp_path, docker_client, clean_container_state): ] @pytest.mark.fails_without_docker - def test_solution(self, tmp_path, docker_client, clean_container_state): + def test_solution(self, tmp_path, docker_client, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp_solution") assert build_chall(get_valid_config()) tags = [tag for image in docker_client.images.list() for tag in image.tags] @@ -279,7 +280,7 @@ def test_solution(self, tmp_path, docker_client, clean_container_state): class Test_start_chall: # TODO challenges with muliple containers @pytest.mark.fails_without_docker - def test_single(self, tmp_path, clean_container_state): + def test_single(self, tmp_path, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp") config = get_valid_config() build_chall(config) @@ -288,7 +289,7 @@ def test_single(self, tmp_path, clean_container_state): assert re.match(r"nc 127.0.0.1 \d+", services[0]) @pytest.mark.fails_without_docker - def test_missing(self, tmp_path, clean_container_state): + def test_missing(self, tmp_path, clean_container_state) -> None: populate_dir(tmp_path, "minimal_valid") config = get_valid_config() build_chall(config) @@ -298,7 +299,7 @@ def test_missing(self, tmp_path, clean_container_state): class Test_start_solution: @pytest.mark.fails_without_docker - def test_simple(self, tmp_path, clean_container_state): + def test_simple(self, tmp_path, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp_solution") config = get_valid_config() build_chall(config) @@ -306,7 +307,7 @@ def test_simple(self, tmp_path, clean_container_state): assert container.image.tags[0] == "sol_challtools_test_9461485faadf529f:latest" @pytest.mark.fails_without_docker - def test_missing(self, tmp_path, clean_container_state): + def test_missing(self, tmp_path, clean_container_state) -> None: populate_dir(tmp_path, "trivial_tcp") config = get_valid_config() build_chall(config) diff --git a/tests/test_validator.py b/tests/test_validator.py index 04246c9..022ee57 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -19,7 +19,7 @@ def test_invalid(self): success, errors = validator.validate() assert not success - assert any([error["code"] == "A002" for error in errors]) + assert any([error.code == "A002" for error in errors]) class Test_A005: @@ -30,7 +30,7 @@ def test_valid(self): success, errors = validator.validate() assert success - assert not any([error["code"] == "A005" for error in errors]) + assert not any([error.code == "A005" for error in errors]) def test_warn(self): config = get_min_valid_config() @@ -39,7 +39,7 @@ def test_warn(self): success, errors = validator.validate() assert success - assert any([error["code"] == "A005" for error in errors]) + assert any([error.code == "A005" for error in errors]) class Test_A006: @@ -54,7 +54,7 @@ def test_valid(self): success, errors = validator.validate() assert success - assert not any([error["code"] == "A006" for error in errors]) + assert not any([error.code == "A006" for error in errors]) def test_invalid(self): config = get_min_valid_config() @@ -67,7 +67,7 @@ def test_invalid(self): success, errors = validator.validate() assert success - assert any([error["code"] == "A006" for error in errors]) + assert any([error.code == "A006" for error in errors]) class Test_A007: @@ -84,7 +84,7 @@ def test_valid(self): success, errors = validator.validate() assert success - assert not any([error["code"] == "A007" for error in errors]) + assert not any([error.code == "A007" for error in errors]) def test_invalid(self): config = get_min_valid_config() @@ -99,7 +99,7 @@ def test_invalid(self): success, errors = validator.validate() assert success - assert any([error["code"] == "A007" for error in errors]) + assert any([error.code == "A007" for error in errors]) class Test_A008: @@ -112,7 +112,7 @@ def test_valid(self): success, errors = validator.validate() assert success - assert not any([error["code"] == "A008" for error in errors]) + assert not any([error.code == "A008" for error in errors]) def test_invalid(self): config = get_min_valid_config() @@ -121,5 +121,7 @@ def test_invalid(self): success, errors = validator.validate() + print(errors) + assert success - assert any([error["code"] == "A008" for error in errors]) + assert any([error.code == "A008" for error in errors]) diff --git a/tests/utils.py b/tests/utils.py index 42c652f..d7dc44f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,7 @@ import os from pathlib import Path +from typing import List, Union + from challtools.cli import main from challtools.utils import _copytree @@ -8,7 +10,7 @@ inittemplatepath = testpath / ".." / "challtools" / "templates" -def populate_dir(path, template): +def populate_dir(path: Union[str, Path], template: str) -> None: os.chdir(path) if not template or not isinstance(template, str): @@ -20,7 +22,7 @@ def populate_dir(path, template): _copytree(templatepath / template, path) -def main_wrapper(args): +def main_wrapper(args: List[str]) -> int: try: exit_code = main(args) except SystemExit as e: