diff --git a/README.md b/README.md index 48966b3..09157b2 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,30 @@ if __name__ == '__main__': * running it outputs the previous PlantUML diagram in the terminal and writes it in a file. +### Additionally you can also pass filters to skip specific blocks and relations +```python +from py2puml.domain.umlrelation import UmlRelation +from py2puml.domain.umlclass import UmlMethod +from py2puml.domain.umlitem import UmlItem +from py2puml.export.puml import Filters +from py2puml.py2puml import py2puml + +def skip_block(item: UmlItem) -> bool: + return item.fqn.endswith('') + +def skip_relation(relation: UmlRelation) -> bool: + return relation.source_fqn.endswith('') and relation.target_fqn.endswith('') + +filters = Filters(skip_block, skip_relation) + +puml_content = "".join( + py2puml( + 'py2puml/domain', + 'py2puml.domain', + filters + ) + ) +``` # Tests diff --git a/py2puml/export/__init__.py b/py2puml/export/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py2puml/export/puml.py b/py2puml/export/puml.py index 638cbd5..f2567f7 100644 --- a/py2puml/export/puml.py +++ b/py2puml/export/puml.py @@ -1,4 +1,5 @@ -from typing import Iterable, List +from dataclasses import dataclass +from typing import Callable, Iterable, List, Optional from py2puml.domain.umlclass import UmlClass from py2puml.domain.umlenum import UmlEnum @@ -26,11 +27,40 @@ FEATURE_INSTANCE = '' -def to_puml_content(diagram_name: str, uml_items: List[UmlItem], uml_relations: List[UmlRelation]) -> Iterable[str]: +@dataclass +class Filters: + skip_block: Optional[Callable[[UmlItem], bool]] = None + skip_relation: Optional[Callable[[UmlRelation], bool]] = None + + +def should_skip(filter: Callable | None, item: UmlItem | UmlRelation) -> bool: + if filter is None: + return False + + if not callable(filter): + raise ValueError('Filter must be a callable') + + try: + _should_skip = filter(item) + if not isinstance(_should_skip, bool): + raise ValueError('Filter must return a boolean value') + return _should_skip + except Exception as e: + raise ValueError('Error while applying filter') from e + + +def to_puml_content( + diagram_name: str, uml_items: List[UmlItem], uml_relations: List[UmlRelation], filters: Optional[Filters] = None +) -> Iterable[str]: + if filters is None: + filters = Filters() + yield PUML_FILE_START.format(diagram_name=diagram_name) # exports the domain classes and enums for uml_item in uml_items: + if should_skip(filters.skip_block, uml_item): + continue if isinstance(uml_item, UmlEnum): uml_enum: UmlEnum = uml_item yield PUML_ITEM_START_TPL.format(item_type='enum', item_fqn=uml_enum.fqn) @@ -48,12 +78,15 @@ def to_puml_content(diagram_name: str, uml_items: List[UmlItem], uml_relations: attr_type=uml_attr.type, staticity=FEATURE_STATIC if uml_attr.static else FEATURE_INSTANCE, ) + # TODO: Add skip_method filter here once PR #43 is merged yield PUML_ITEM_END else: raise TypeError(f'cannot process uml_item of type {uml_item.__class__}') # exports the domain relationships between classes and enums for uml_relation in uml_relations: + if should_skip(filters.skip_relation, uml_relation): + continue yield PUML_RELATION_TPL.format( source_fqn=uml_relation.source_fqn, rel_type=uml_relation.type.value, target_fqn=uml_relation.target_fqn ) diff --git a/py2puml/py2puml.py b/py2puml/py2puml.py index 8300beb..b5518c1 100644 --- a/py2puml/py2puml.py +++ b/py2puml/py2puml.py @@ -1,14 +1,14 @@ -from typing import Dict, Iterable, List +from typing import Dict, Iterable, List, Optional from py2puml.domain.umlitem import UmlItem from py2puml.domain.umlrelation import UmlRelation -from py2puml.export.puml import to_puml_content +from py2puml.export.puml import Filters, to_puml_content from py2puml.inspection.inspectpackage import inspect_package -def py2puml(domain_path: str, domain_module: str) -> Iterable[str]: +def py2puml(domain_path: str, domain_module: str, filters: Optional[Filters] = None) -> Iterable[str]: domain_items_by_fqn: Dict[str, UmlItem] = {} domain_relations: List[UmlRelation] = [] inspect_package(domain_path, domain_module, domain_items_by_fqn, domain_relations) - return to_puml_content(domain_module, domain_items_by_fqn.values(), domain_relations) + return to_puml_content(domain_module, domain_items_by_fqn.values(), domain_relations, filters) diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/py2puml/export/test_filters.py b/tests/py2puml/export/test_filters.py new file mode 100644 index 0000000..39d34f4 --- /dev/null +++ b/tests/py2puml/export/test_filters.py @@ -0,0 +1,147 @@ +import pytest + +from py2puml.domain.umlitem import UmlItem +from py2puml.domain.umlrelation import UmlRelation +from py2puml.export.puml import Filters +from py2puml.py2puml import py2puml + +un_modified_puml = [ + '@startuml tests.modules.withinheritedconstructor\n!pragma useIntermediatePackages false\n\n', + 'class tests.modules.withinheritedconstructor.metricorigin.MetricOrigin {\n', + ' unit: str {static}\n', + '}\n', + 'class tests.modules.withinheritedconstructor.point.Origin {\n', + ' is_origin: bool {static}\n', + '}\n', + 'class tests.modules.withinheritedconstructor.point.Point {\n', + ' x: float\n', + ' y: float\n', + '}\n', + 'tests.modules.withinheritedconstructor.point.Origin <|-- tests.modules.withinheritedconstructor.metricorigin.MetricOrigin\n', + 'tests.modules.withinheritedconstructor.point.Point <|-- tests.modules.withinheritedconstructor.point.Origin\n', + 'footer Generated by //py2puml//\n', + '@enduml\n', +] + +puml_with_origin_class_skipped = [ + '@startuml tests.modules.withinheritedconstructor\n!pragma useIntermediatePackages false\n\n', + 'class tests.modules.withinheritedconstructor.metricorigin.MetricOrigin {\n', + ' unit: str {static}\n', + '}\n', + 'class tests.modules.withinheritedconstructor.point.Point {\n', + ' x: float\n', + ' y: float\n', + '}\n', + 'tests.modules.withinheritedconstructor.point.Origin <|-- tests.modules.withinheritedconstructor.metricorigin.MetricOrigin\n', + 'tests.modules.withinheritedconstructor.point.Point <|-- tests.modules.withinheritedconstructor.point.Origin\n', + 'footer Generated by //py2puml//\n', + '@enduml\n', +] + +puml_with_point_origin_relation_skipped = [ + '@startuml tests.modules.withinheritedconstructor\n!pragma useIntermediatePackages false\n\n', + 'class tests.modules.withinheritedconstructor.metricorigin.MetricOrigin {\n', + ' unit: str {static}\n', + '}\n', + 'class tests.modules.withinheritedconstructor.point.Origin {\n', + ' is_origin: bool {static}\n', + '}\n', + 'class tests.modules.withinheritedconstructor.point.Point {\n', + ' x: float\n', + ' y: float\n', + '}\n', + 'tests.modules.withinheritedconstructor.point.Origin <|-- tests.modules.withinheritedconstructor.metricorigin.MetricOrigin\n', + 'footer Generated by //py2puml//\n', + '@enduml\n', +] + +puml_with_point_class_and_point_origin_relation_skipped = [ + '@startuml tests.modules.withinheritedconstructor\n!pragma useIntermediatePackages false\n\n', + 'class tests.modules.withinheritedconstructor.metricorigin.MetricOrigin {\n', + ' unit: str {static}\n', + '}\n', + 'class tests.modules.withinheritedconstructor.point.Point {\n', + ' x: float\n', + ' y: float\n', + '}\n', + 'tests.modules.withinheritedconstructor.point.Origin <|-- tests.modules.withinheritedconstructor.metricorigin.MetricOrigin\n', + 'footer Generated by //py2puml//\n', + '@enduml\n', +] + + +def skip_origin_block(item: UmlItem) -> bool: + return item.fqn.endswith('.Origin') + + +def skip_point_origin_relation(relation: UmlRelation) -> bool: + return relation.source_fqn.endswith('.Point') and relation.target_fqn.endswith('.Origin') + + +def get_puml_content(filters: Filters) -> list[str]: + return list(py2puml('tests/modules/withinheritedconstructor', 'tests.modules.withinheritedconstructor', filters)) + + +def invalid_filter_without_filter_argument(): + return True + + +def invalid_filter_with_wrong_return_type(item: UmlItem) -> str: + return 'True' + + +def invalid_filter_with_exception(item: UmlItem) -> bool: + raise Exception('An error occurred') + + +non_callable_filter = 'not a function' + + +def test_without_giving_filters(): + generated_puml = list(py2puml('tests/modules/withinheritedconstructor', 'tests.modules.withinheritedconstructor')) + assert generated_puml == un_modified_puml + + +def test_default_filters(): + filters = Filters() + generated_puml = get_puml_content(filters) + assert generated_puml == un_modified_puml + + +def test_skip_origin_class(): + filters = Filters(skip_block=skip_origin_block) + generated_puml = get_puml_content(filters) + assert generated_puml == puml_with_origin_class_skipped + + +def test_skip_point_origin_relation(): + filters = Filters(skip_relation=skip_point_origin_relation) + generated_puml = get_puml_content(filters) + assert generated_puml == puml_with_point_origin_relation_skipped + + +def test_skip_point_class_and_point_origin_relation(): + filters = Filters(skip_block=skip_origin_block, skip_relation=skip_point_origin_relation) + generated_puml = get_puml_content(filters) + print(''.join(generated_puml)) + print(len(generated_puml), len(puml_with_point_class_and_point_origin_relation_skipped)) + assert generated_puml == puml_with_point_class_and_point_origin_relation_skipped + + +@pytest.mark.parametrize( + 'invalid_filter', + [ + invalid_filter_without_filter_argument, + invalid_filter_with_wrong_return_type, + invalid_filter_with_exception, + non_callable_filter, + ], +) +def test_invalid_filters(invalid_filter): + with pytest.raises(ValueError): + filters = Filters(skip_block=invalid_filter) + get_puml_content(filters) + + with pytest.raises(ValueError): + filters = Filters(skip_relation=invalid_filter) + get_puml_content(filters)