diff --git a/examples/image_classifier_ex.py b/examples/image_classifier_ex.py new file mode 100644 index 00000000..f94c0e5f --- /dev/null +++ b/examples/image_classifier_ex.py @@ -0,0 +1,91 @@ +# import the logicIntegratedClassifier class + +from pathlib import Path +import torch +import torch.nn as nn +import networkx as nx +import numpy as np +import random +from transformers import AutoImageProcessor, AutoModelForImageClassification +from PIL import Image +import torch.nn.functional as F +import cv2 +from ultralytics import YOLO + +from pyreason.scripts.learning.classification.hf_classifier import HuggingFaceLogicIntegratedClassifier +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions +from pyreason.scripts.rules.rule import Rule +from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace + + +# Step 1: Load a pre-trained model and image processor from Hugging Face +model_name = "google/vit-base-patch16-224" # Vision Transformer model +processor = AutoImageProcessor.from_pretrained(model_name) +model = AutoModelForImageClassification.from_pretrained(model_name) + +G = nx.DiGraph() +load_graph(G) + +# Step 2: Load and preprocess images from the directory +image_dir = Path(__file__).resolve().parent.parent / "examples" / "images" +image_paths = list(Path(image_dir).glob("*.jpeg")) # Get all .jpeg files in the directory +image_list = [] +allowed_labels = ['goldfish', 'tiger shark', 'hammerhead', 'great white shark', 'tench'] + +# Add Rules to the knowlege base +add_rule(Rule("is_fish(x) <-0 goldfish(x)", "is_fish_rule")) +add_rule(Rule("is_fish(x) <-0 tench(x)", "is_fish_rule")) +add_rule(Rule("is_shark(x) <-0 tigershark(x)", "is_shark_rule")) +add_rule(Rule("is_shark(x) <-0 hammerhead(x)", "is_shark_rule")) +add_rule(Rule("is_shark(x) <-0 greatwhiteshark(x)", "is_shark_rule")) +add_rule(Rule("is_scary(x) 
<-0 is_shark(x)", "is_scary_rule")) +add_rule(Rule("likes_to_eat(y,x) <-0 is_shark(y), is_fish(x)", "likes_to_eat_rule", infer_edges=True)) + +for image_path in image_paths: + print(f"Processing Image: {image_path.name}") + image = Image.open(image_path) + inputs = processor(images=image, return_tensors="pt") + + interface_options = ModelInterfaceOptions( + threshold=0.5, # Only process probabilities above 0.5 + set_lower_bound=True, # For high confidence, adjust the lower bound. + set_upper_bound=False, # Keep the upper bound unchanged. + snap_value=1.0 # Use 1.0 as the snap value. + ) + + classifier_name = image_path.name.split(".")[0] + fish_classifier = HuggingFaceLogicIntegratedClassifier( + model, + allowed_labels, + identifier=classifier_name, + interface_options=interface_options, + limit_classes=True + ) + + # print("Top Probs: ", filtered_probs) + logits, probabilities, classifier_facts = fish_classifier(inputs) + + print("=== Fish Classifier Output ===") + #print("Probabilities:", probabilities) + print("\nGenerated Classifier Facts:") + for fact in classifier_facts: + print(fact) + + for fact in classifier_facts: + add_fact(fact) + + print("Done processing image ", image_path.name) + +# --- Part 4: Run the Reasoning Engine --- + +# Reset settings before running reasoning +reset_settings() + +# Run the reasoning engine to allow the investigation flag to propagate hat through the network. 
+Settings.atom_trace = True +interpretation = reason() + +trace = get_rule_trace(interpretation) +print(f"NODE RULE TRACE: \n\n{trace[0]}\n") +print(f"EDGE RULE TRACE: \n\n{trace[1]}\n") diff --git a/examples/images/fish_1.jpeg b/examples/images/fish_1.jpeg new file mode 100644 index 00000000..6413569e Binary files /dev/null and b/examples/images/fish_1.jpeg differ diff --git a/examples/images/fish_2.jpeg b/examples/images/fish_2.jpeg new file mode 100644 index 00000000..581f1b17 Binary files /dev/null and b/examples/images/fish_2.jpeg differ diff --git a/examples/images/shark_1.jpeg b/examples/images/shark_1.jpeg new file mode 100644 index 00000000..e7ebb18f Binary files /dev/null and b/examples/images/shark_1.jpeg differ diff --git a/examples/images/shark_2.jpeg b/examples/images/shark_2.jpeg new file mode 100644 index 00000000..037c53dd Binary files /dev/null and b/examples/images/shark_2.jpeg differ diff --git a/examples/images/shark_3.jpeg b/examples/images/shark_3.jpeg new file mode 100644 index 00000000..640d01f2 Binary files /dev/null and b/examples/images/shark_3.jpeg differ diff --git a/examples/multiple_classifier_integration_ex.py b/examples/multiple_classifier_integration_ex.py new file mode 100644 index 00000000..9b3872b0 --- /dev/null +++ b/examples/multiple_classifier_integration_ex.py @@ -0,0 +1,135 @@ +import pyreason as pr +import torch +import torch.nn as nn +import networkx as nx +import numpy as np +import random + + +# seed_value = 41 # legitimate, high risk +# seed_value = 42 # fraud, low risk +seed_value = 44 # fraud, high risk +random.seed(seed_value) +np.random.seed(seed_value) +torch.manual_seed(seed_value) + + +# --- Part 1: Fraud Detector Model Integration --- +# Create a dummy PyTorch model for transaction fraud detection. +fraud_model = nn.Linear(5, 2) +fraud_class_names = ["fraud", "legitimate"] +transaction_features = torch.rand(1, 5) + +# Define integration options: only probabilities > 0.5 will trigger bounds adjustment. 
+fraud_interface_options = pr.ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=False, + snap_value=1.0 +) + +# Wrap the fraud detection model. +fraud_detector = pr.LogicIntegratedClassifier( + fraud_model, + fraud_class_names, + identifier="fraud_detector", + interface_options=fraud_interface_options +) + +# Run the fraud detector. +logits_fraud, probabilities_fraud, fraud_facts = fraud_detector(transaction_features) # Talk about time +print("=== Fraud Detector Output ===") +print("Logits:", logits_fraud) +print("Probabilities:", probabilities_fraud) +print("\nGenerated Fraud Detector Facts:") +for fact in fraud_facts: + print(fact) + +# Context and reasoning +for fact in fraud_facts: + pr.add_fact(fact) + +# Add additional contextual facts: +# 1. The transaction is from a suspicious location. +pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact")) +# 2. Link the transaction to AccountA. +pr.add_fact(pr.Fact("transaction(AccountA)", "transaction_link")) +# 3. Register AccountA as an account. +pr.add_fact(pr.Fact("account(AccountA)", "account_fact")) + +# Define reasoning rules: +# Rule A: If the fraud detector flags fraud and the transaction is suspicious, mark AccountA for investigation. +pr.add_rule(pr.Rule("requires_investigation(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud)", "investigation_rule")) + +# --- Set up Graph and Load --- +# Build a simple graph of accounts. +G = nx.DiGraph() +G.add_node("AccountA") +G.add_node("AccountB") +G.add_node("AccountC") +# Add edges with an attribute "relationship" set to "associated". +G.add_edge("AccountA", "AccountB", associated=1) +G.add_edge("AccountB", "AccountC", associated=1) +# Load the graph into PyReason. The edge attribute "relationship" is interpreted as the predicate 'associated'. +pr.load_graph(G) + +# Define propagation rules to spread investigation and critical action flags via the "associated" relationship. 
+pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "investigation_propagation_rule")) + +# --- Part 5: Run the Reasoning Engine --- +# Run the reasoning engine. +pr.settings.allow_ground_rules = True +pr.settings.atom_trace = True +interpretation = pr.reason() + +# Display reasoning results for 'requires_investigation'. +print("\n=== Reasoning Results for 'requires_investigation' ===") +trace = pr.get_rule_trace(interpretation) +print(f"RULE TRACE: \n\n{trace[0]}\n") + + +# --- Part 2: Risk Evaluator Model Integration --- +# Create another dummy PyTorch model for evaluating account risk. +risk_model = nn.Linear(5, 2) +risk_class_names = ["high_risk", "low_risk"] +risk_features = torch.rand(1, 5) + +# Define integration options for the risk evaluator. +risk_interface_options = pr.ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=True, + snap_value=1.0 +) + +# Wrap the risk evaluation model. +risk_evaluator = pr.LogicIntegratedClassifier( + risk_model, + risk_class_names, # document len + identifier="risk_evaluator", # binded constant + interface_options=risk_interface_options +) + +# Run the risk evaluator. +logits_risk, probabilities_risk, risk_facts = risk_evaluator(risk_features) +print("\n=== Risk Evaluator Output ===") +print("Logits:", logits_risk) +print("Probabilities:", probabilities_risk) +print("\nGenerated Risk Evaluator Facts:") +for fact in risk_facts: + print(fact) + +# --- Context and Reasoning again --- +for fact in risk_facts: + pr.add_fact(fact) + +# Rule B: If the fraud detector flags fraud and the risk evaluator flags high risk, mark AccountA for critical action. 
+pr.add_rule(pr.Rule("critical_action(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud), risk_evaluator(high_risk)", "critical_action_rule")) +pr.add_rule(pr.Rule("critical_action(y) <- critical_action(x), associated(x,y)", "critical_propagation_rule")) + +interpretation = pr.reason(again=True) + +# Display reasoning results for 'critical_action'. +print("\n=== Reasoning Results for 'critical_action' (Reasoning again) ===") +trace = pr.get_rule_trace(interpretation) +print(f"RULE TRACE: \n\n{trace[0]}\n") diff --git a/examples/temporal_classifier_ex.py b/examples/temporal_classifier_ex.py new file mode 100644 index 00000000..94053421 --- /dev/null +++ b/examples/temporal_classifier_ex.py @@ -0,0 +1,85 @@ +import time +import sys +import os + +import torch +import torch.nn as nn +import networkx as nx +import numpy as np +import random +from datetime import timedelta + +from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions +from pyreason.scripts.rules.rule import Rule +from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace, get_time, Query + +seed_value = 65 # Good Gap Gap +# seed_value = 47 # Good Gap Good +# seed_value = 43 # Good Good Good +random.seed(seed_value) +np.random.seed(seed_value) +torch.manual_seed(seed_value) + + +def input_fn(): + return torch.rand(1, 3) # Dummy input function for the model + + +weld_model = nn.Linear(3, 2) +class_names = ["good", "gap"] + +# Define integration options: +# Only consider probabilities above 0.5, adjust lower bound for high confidence, and use a snap value. 
+interface_options = ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=False, + snap_value=1.0 +) + +# Wrap the model using LogicIntegratedClassifier. +weld_quality_checker = TemporalLogicIntegratedClassifier( + weld_model, + class_names, + identifier="weld_object", + interface_options=interface_options, + poll_interval=timedelta(seconds=0.5), + # poll_interval=1, + poll_condition="gap", + input_fn=input_fn, +) + +add_rule(Rule("repairing(weld_object) <-1 gap(weld_object)", "repair attempted rule")) +add_rule(Rule("defective(weld_object) <-1 gap(weld_object), repairing(weld_object)", "defective rule")) + +max_iters = 5 +for weld_iter in range(max_iters): + # Time step 1: Initial inspection shows the weld is good. + features = torch.rand(1, 3) # Values chosen to indicate a good weld. + t = get_time() + logits, probs, classifier_facts = weld_quality_checker(features, t1=t, t2=t) + # print(f"=== Weld Inspection for Part: {weld_iter} ===") + # print("Logits:", logits) + # print("Probabilities:", probs) + for fact in classifier_facts: + add_fact(fact) + + settings = Settings + # Reasoning + settings.atom_trace = True + settings.verbose = False + again = False if weld_iter == 0 else True + interpretation = reason(timesteps=1, again=again, restart=False) + trace = get_rule_trace(interpretation) + print(f"\n=== Reasoning Rule Trace for Weld Part: {weld_iter} ===") + print(trace[0], "\n\n") + + time.sleep(5) + + # Check if part is defective + # if get_logic_program().interp.query(Query("defective(weld_object)")): + if interpretation.query(Query("defective(weld_object)")): + print("Defective weld detected! 
\n Replacing the part.\n\n") + break diff --git a/examples/temporal_classifier_integration_ex.py b/examples/temporal_classifier_integration_ex.py deleted file mode 100644 index 7113356e..00000000 --- a/examples/temporal_classifier_integration_ex.py +++ /dev/null @@ -1,80 +0,0 @@ -import pyreason as pr -import torch -import torch.nn as nn -import numpy as np -import random - -# Set a seed for reproducibility. -seed_value = 65 # Good Gap Gap -# seed_value = 47 # Good Gap Good -# seed_value = 43 # Good Good Good -random.seed(seed_value) -np.random.seed(seed_value) -torch.manual_seed(seed_value) - -# --- Part 1: Weld Quality Model Integration --- - -# Create a dummy PyTorch model for detecting weld quality. -# Each weld is represented by 3 features and is classified as "good" or "gap". -weld_model = nn.Linear(3, 2) -class_names = ["good", "gap"] - -# Define integration options: -# Only consider probabilities above 0.5, adjust lower bound for high confidence, and use a snap value. -interface_options = pr.ModelInterfaceOptions( - threshold=0.5, - set_lower_bound=True, - set_upper_bound=False, - snap_value=1.0 -) - -# Wrap the model using LogicIntegratedClassifier. -weld_quality_checker = pr.LogicIntegratedClassifier( - weld_model, - class_names, - identifier="weld_object", - interface_options=interface_options -) - -# --- Part 2: Simulate Weld Inspections Over Time --- -pr.add_rule(pr.Rule("repair_attempted(weld_object) <-1 gap(weld_object)", "repair attempted rule")) -pr.add_rule(pr.Rule("defective(weld_object) <-0 gap(weld_object), repair_attempted(weld_object)", "defective rule")) - -# Time step 1: Initial inspection shows the weld is good. -features_t0 = torch.rand(1, 3) # Values chosen to indicate a good weld. 
-logits_t0, probs_t0, classifier_facts_t0 = weld_quality_checker(features_t0, t1=0, t2=0) -print("=== Weld Inspection at Time 0 ===") -print("Logits:", logits_t0) -print("Probabilities:", probs_t0) -for fact in classifier_facts_t0: - pr.add_fact(fact) - -# Time step 2: Second inspection detects a gap. -features_t1 = torch.rand(1, 3) # Values chosen to simulate a gap. -logits_t1, probs_t1, classifier_facts_t1 = weld_quality_checker(features_t1, t1=1, t2=1) -print("\n=== Weld Inspection at Time 1 ===") -print("Logits:", logits_t1) -print("Probabilities:", probs_t1) -for fact in classifier_facts_t1: - pr.add_fact(fact) - - -# Time step 3: Third inspection, the gap still persists. -features_t2 = torch.rand(1, 3) # Values chosen to simulate persistent gap. -logits_t2, probs_t2, classifier_facts_t2 = weld_quality_checker(features_t2, t1=2, t2=2) -print("\n=== Weld Inspection at Time 2 ===") -print("Logits:", logits_t2) -print("Probabilities:", probs_t2) -for fact in classifier_facts_t2: - pr.add_fact(fact) - - -# --- Part 3: Run the Reasoning Engine --- - -# Enable atom tracing for debugging the rule application process. 
-pr.settings.atom_trace = True -interpretation = pr.reason(timesteps=2) -trace = pr.get_rule_trace(interpretation) - -print("\n=== Reasoning Rule Trace ===") -print(trace[0]) diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index ee78df92..e6966a9e 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -27,6 +27,7 @@ import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval +from pyreason.scripts.interpretation.interpretation_parallel import Interpretation from pyreason.scripts.utils.reorder_clauses import reorder_clauses if importlib.util.find_spec("torch") is not None: from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier @@ -521,6 +522,37 @@ def reset_rules(): if __program is not None: __program.reset_rules() +def get_logic_program() -> Optional[Program]: + """Get the logic program object + + :return: Logic program object + """ + global __program + return __program + + +def get_interpretation() -> Optional[Interpretation]: + """Get the current interpretation + + :return: Current interpretation + """ + global __program + if __program is None: + raise Exception('No interpretation found. Please run `pr.reason()` first') + return __program.interp + + +def get_time() -> int: + """Get the current time + + :return: Current time + """ + try: + i = get_interpretation() + except Exception: + return 0 + return i.time + 1 + def reset_settings(): """ @@ -707,7 +739,7 @@ def add_fact(pyreason_fact: Fact) -> None: pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' if pyreason_fact.name in __node_facts_name_set: - warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.") + warnings.warn(f"Fact {pyreason_fact.name} has already been added. 
Duplicate fact names will lead to an ambiguous node and atom traces.", stacklevel=2) f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) __node_facts_name_set.add(pyreason_fact.name) @@ -717,7 +749,7 @@ def add_fact(pyreason_fact: Fact) -> None: pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}' if pyreason_fact.name in __node_facts_name_set: - warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.") + warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.", stacklevel=2) f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static) __node_facts_name_set.add(pyreason_fact.name) diff --git a/pyreason/scripts/learning/classification/classifier.py b/pyreason/scripts/learning/classification/classifier.py index b0a3fe53..419894e9 100644 --- a/pyreason/scripts/learning/classification/classifier.py +++ b/pyreason/scripts/learning/classification/classifier.py @@ -1,90 +1,95 @@ -from typing import List, Tuple +from typing import List import torch.nn import torch.nn.functional as F from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions -class LogicIntegratedClassifier(torch.nn.Module): +class LogicIntegratedClassifier(LogicIntegrationBase): """ Class to integrate a PyTorch model with PyReason. The output of the model is returned to the user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them. 
+ Wraps any torch.nn.Module whose forward(x) returns [N, C] logits (multi-class). + Implements _infer, _postprocess, and _pred_to_facts to replace the original forward(). """ - def __init__(self, model, class_names: List[str], identifier: str = 'classifier', interface_options: ModelInterfaceOptions = None): - """ - :param model: PyTorch model to be integrated. - :param class_names: List of class names for the model output. - :param identifier: Identifier for the model, used as the constant in the facts. - :param interface_options: Options for the model interface, including threshold and snapping behavior. - """ - super(LogicIntegratedClassifier, self).__init__() - self.model = model - self.class_names = class_names - self.identifier = identifier - self.interface_options = interface_options - def get_class_facts(self, t1: int, t2: int) -> List[Fact]: + def __init__( + self, + model: torch.nn.Module, + class_names: List[str], + identifier: str = 'classifier', + interface_options: ModelInterfaceOptions = None + ): + super().__init__(model, class_names, interface_options, identifier) + + def _infer(self, x: torch.Tensor) -> torch.Tensor: + # Simply run the underlying model to get raw logits [N, C] + return self.model(x) + + def _postprocess(self, raw_output: torch.Tensor) -> torch.Tensor: """ - Return PyReason facts to create nodes for each class. Each class node will have bounds `[1,1]` with the - predicate corresponding to the model name. - :param t1: Start time for the facts - :param t2: End time for the facts - :return: List of PyReason facts + raw_output: a [N, C] logits tensor. + Apply softmax over dim=1 to get probabilities [N, C]. 
""" - facts = [] - for c in self.class_names: - fact = Fact(f'{c}({self.identifier})', name=f'{self.identifier}-{c}-fact', start_time=t1, end_time=t2) - facts.append(fact) - return facts + logits = raw_output + if logits.dim() != 2 or logits.size(1) != len(self.class_names): + raise ValueError( + f"Expected logits of shape [N, C] with C={len(self.class_names)}, " + f"got {tuple(logits.shape)}" + ) + return F.softmax(logits, dim=1) - def forward(self, x, t1: int = 0, t2: int = 0) -> Tuple[torch.Tensor, torch.Tensor, List[Fact]]: + def _pred_to_facts( + self, + raw_output: torch.Tensor, + probabilities: torch.Tensor, + t1: int, + t2: int + ) -> List[Fact]: """ - Forward pass of the model - :param x: Input tensor - :param t1: Start time for the facts - :param t2: End time for the facts - :return: Output tensor + Turn the [N, C] probability tensor into a flat List[Fact], + using threshold, snap_value, set_lower_bound, set_upper_bound. + Produces N * C facts. """ - output = self.model(x) - - # Convert logits to probabilities assuming a multi-class classification. - probabilities = F.softmax(output, dim=1).squeeze() opts = self.interface_options + prob = probabilities # [N, C] - # Prepare threshold tensor. 
- threshold = torch.tensor(opts.threshold, dtype=probabilities.dtype, device=probabilities.device) - condition = probabilities > threshold + # Build a threshold tensor + threshold = torch.tensor(opts.threshold, dtype=prob.dtype, device=prob.device) + condition = prob > threshold # [N, C] boolean + # Determine lower/upper for “true” entries if opts.snap_value is not None: - snap_value = torch.tensor(opts.snap_value, dtype=probabilities.dtype, device=probabilities.device) - # For values that pass the threshold: - lower_val = snap_value if opts.set_lower_bound else torch.tensor(0.0, dtype=probabilities.dtype, - device=probabilities.device) - upper_val = snap_value if opts.set_upper_bound else torch.tensor(1.0, dtype=probabilities.dtype, - device=probabilities.device) + snap_val = torch.tensor(opts.snap_value, dtype=prob.dtype, device=prob.device) + lower_if_true = ( + snap_val if opts.set_lower_bound else torch.tensor(0.0, dtype=prob.dtype, device=prob.device) + ) + upper_if_true = ( + snap_val if opts.set_upper_bound else torch.tensor(1.0, dtype=prob.dtype, device=prob.device) + ) else: - # If no snap_value is provided, keep original probabilities for those passing threshold. - lower_val = probabilities if opts.set_lower_bound else torch.zeros_like(probabilities) - upper_val = probabilities if opts.set_upper_bound else torch.ones_like(probabilities) + lower_if_true = prob if opts.set_lower_bound else torch.zeros_like(prob) + upper_if_true = prob if opts.set_upper_bound else torch.ones_like(prob) - # For probabilities that pass the threshold, apply the above; else, bounds are fixed to [0,1]. 
- lower_bounds = torch.where(condition, lower_val, torch.zeros_like(probabilities)) - upper_bounds = torch.where(condition, upper_val, torch.ones_like(probabilities)) + # Build full [N, C] lower_bounds and upper_bounds + zeros = torch.zeros_like(prob) + ones = torch.ones_like(prob) + lower_bounds = torch.where(condition, lower_if_true, zeros) # [N, C] + upper_bounds = torch.where(condition, upper_if_true, ones) # [N, C] - # Convert bounds to Python floats for fact creation. - bounds_list = [] - for i in range(len(self.class_names)): - lower = lower_bounds[i].item() - upper = upper_bounds[i].item() - bounds_list.append([lower, upper]) + N, C = prob.shape + facts: List[Fact] = [] - # Define time bounds for the facts. - facts = [] - for class_name, bounds in zip(self.class_names, bounds_list): - lower, upper = bounds - fact_str = f'{class_name}({self.identifier}) : [{lower:.3f}, {upper:.3f}]' - fact = Fact(fact_str, name=f'{self.identifier}-{class_name}-fact', start_time=t1, end_time=t2) - facts.append(fact) - return output, probabilities, facts + for i in range(N): + for j, class_name in enumerate(self.class_names): + lower = lower_bounds[i, j].item() + upper = upper_bounds[i, j].item() + fact_str = f"{class_name}({self.identifier}) : [{lower:.3f}, {upper:.3f}]" + fact_name = f"{self.identifier}-{class_name}-fact" + f = Fact(fact_str, name=fact_name, start_time=t1, end_time=t2) + facts.append(f) + + return facts diff --git a/pyreason/scripts/learning/classification/hf_classifier.py b/pyreason/scripts/learning/classification/hf_classifier.py new file mode 100644 index 00000000..01c8093c --- /dev/null +++ b/pyreason/scripts/learning/classification/hf_classifier.py @@ -0,0 +1,104 @@ +from typing import List, Any + +import torch +import torch.nn.functional as F + +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase +from pyreason.scripts.learning.utils.model_interface import 
ModelInterfaceOptions + + +class HuggingFaceLogicIntegratedClassifier(LogicIntegrationBase): + """ + Integrates a HuggingFace image classification model with PyReason. + Extends LogicIntegrationBase by implementing _infer, _postprocess, and _pred_to_facts. + """ + + def __init__( + self, + model, + class_names: List[str], + identifier: str = 'hf_classifier', + interface_options: ModelInterfaceOptions = None, + limit_classes: bool = False + ): + """ + :param model: A HuggingFace model (e.g. AutoModelForImageClassification). + :param class_names: List of class names for the model output. + :param identifier: Identifier for the model, used as the constant in the facts. + :param interface_options: Options for the model interface, including threshold and snapping behavior. + :param limit_classes: If True, filter output probabilities to only the classes in class_names + using the model's id2label config, renormalize, and reorder class_names by probability. + """ + super().__init__(model, class_names, interface_options, identifier) + self.limit_classes = limit_classes + + def _infer(self, x: Any) -> Any: + return self.model(**x).logits + + def _postprocess(self, raw_output: Any) -> Any: + probabilities = F.softmax(raw_output, dim=1).squeeze() + + if self.limit_classes: + probabilities, self._filtered_labels = self._filter_to_allowed_classes(probabilities) + else: + self._filtered_labels = None + + return probabilities + + def _filter_to_allowed_classes(self, probabilities: torch.Tensor): + """Filter probabilities to only the allowed class_names using model.config.id2label. 
+ Returns (top_probs, top_labels) without mutating self.class_names.""" + id2label = self.model.config.id2label + + allowed_indices = [ + i for i, label in id2label.items() + if label.split(",")[0].strip().lower() in [name.lower() for name in self.class_names] + ] + + filtered_probs = torch.zeros_like(probabilities) + filtered_probs[allowed_indices] = probabilities[allowed_indices] + filtered_probs = filtered_probs / filtered_probs.sum() + + top_labels = [] + top_probs, top_indices = filtered_probs.topk(len(self.class_names)) + for idx in top_indices: + label = id2label[idx.item()].split(",")[0] + top_labels.append(label) + + return top_probs, top_labels + + def _pred_to_facts( + self, + raw_output: Any, + probabilities: Any, + t1: int = 0, + t2: int = 0 + ) -> List[Fact]: + opts = self.interface_options + + threshold = torch.tensor(opts.threshold, dtype=probabilities.dtype, device=probabilities.device) + condition = probabilities > threshold + + if opts.snap_value is not None: + snap_value = torch.tensor(opts.snap_value, dtype=probabilities.dtype, device=probabilities.device) + lower_val = snap_value if opts.set_lower_bound else torch.tensor(0.0, dtype=probabilities.dtype, device=probabilities.device) + upper_val = snap_value if opts.set_upper_bound else torch.tensor(1.0, dtype=probabilities.dtype, device=probabilities.device) + else: + lower_val = probabilities if opts.set_lower_bound else torch.zeros_like(probabilities) + upper_val = probabilities if opts.set_upper_bound else torch.ones_like(probabilities) + + lower_bounds = torch.where(condition, lower_val, torch.zeros_like(probabilities)) + upper_bounds = torch.where(condition, upper_val, torch.ones_like(probabilities)) + + labels = self._filtered_labels if self._filtered_labels is not None else self.class_names + + facts = [] + for i in range(len(labels)): + lower = lower_bounds[i].item() + upper = upper_bounds[i].item() + fact_str = f'{labels[i]}({self.identifier}) : [{lower:.3f}, {upper:.3f}]' + fact = 
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Any

from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions


class LogicIntegrationBase(torch.nn.Module, ABC):
    """
    Abstract base for wrapping **any** model (classifier, detector, ...) whose
    predictions should be converted into PyReason Facts with lower/upper bounds.

    Concrete subclasses implement three hooks:
      1. ``_infer(x)``                          -> raw model output
      2. ``_postprocess(raw_output)``           -> probabilities / filtered detections
      3. ``_pred_to_facts(raw, post, t1, t2)``  -> List[Fact]

    The base class only orchestrates those hooks inside ``forward`` and returns
    the triple ``(raw_output, postproc, facts)``.
    """

    def __init__(
            self,
            model: torch.nn.Module,
            class_names: List[str],
            interface_options: ModelInterfaceOptions,
            identifier: str = "model"
    ):
        """
        :param model: Any PyTorch module; subclasses invoke it from ``_infer``.
        :param class_names: Predicate names. For a detector this is the full label list.
        :param interface_options: Threshold, snap_value and bound-construction settings.
        :param identifier: Constant injected into every generated Fact
            (e.g. "image1", "classifier", "detector").
        """
        super().__init__()
        self.model = model
        self.class_names = class_names
        self.interface_options = interface_options
        self.identifier = identifier

        # Give subclasses a chance to sanity-check class_names vs. the model.
        self._validate_init()

    def _validate_init(self):
        """
        Optional hook: subclasses may verify here that ``len(class_names)``
        matches whatever the underlying model expects. Default: no check.
        """
        pass

    def forward(
            self,
            x: Any,
            t1: int = 0,
            t2: int = 0
    ) -> Tuple[Any, Any, List[Fact]]:
        """
        Run the three-step pipeline and return ``(raw_output, postproc, facts)``.

        :param x: Model input, forwarded unchanged to ``_infer``.
        :param t1: Start timestep stamped on every generated Fact.
        :param t2: End timestep stamped on every generated Fact.
        :return: Triple of
            - raw_output: whatever ``_infer`` (i.e. the wrapped model) returned,
            - postproc: a probability tensor, filtered detections, etc.,
            - facts: a flat ``List[Fact]``.
        """
        # Step 1: raw predictions from the wrapped model.
        raw_output = self._infer(x)

        # Step 2: model-type-specific post-processing (softmax + threshold for
        # classifiers, confidence filtering for detectors, ...).
        postproc = self._postprocess(raw_output)

        # Step 3: turn the predictions into PyReason Facts.
        fact_list: List[Fact] = self._pred_to_facts(raw_output, postproc, t1, t2)

        return raw_output, postproc, fact_list

    @abstractmethod
    def _infer(self, x: Any) -> Any:
        """
        Run the underlying model (``self.model``) on ``x`` and return its raw
        output — a logit tensor for a classifier, a results object for a
        detector, etc.
        """
        ...

    @abstractmethod
    def _postprocess(self, raw_output: Any) -> Any:
        """
        Convert the raw model output into a more convenient post-processed form
        that is both returned to the caller and fed into ``_pred_to_facts``
        (e.g. sigmoid/softmax + threshold mask, or (class_idx, confidence)
        pairs above a confidence threshold).
        """
        ...

    @abstractmethod
    def _pred_to_facts(
            self,
            raw_output: Any,
            postproc: Any,
            t1: int,
            t2: int
    ) -> List[Fact]:
        """
        Build a list of PyReason ``Fact`` objects, each of the form
        ``f"{class_name}({self.identifier}) : [lower, upper]"``, stamped with
        the time window ``[t1, t2]``.
        """
        ...
import threading
import time
from datetime import timedelta
from typing import List, Optional, Union, Callable, Any

import torch.nn
import torch.nn.functional as F

import pyreason as pr
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions


class TemporalLogicIntegratedClassifier(LogicIntegrationBase):
    """
    Wraps any torch.nn.Module whose forward(x) returns [N, C] logits (multi-class),
    but additionally polls in the background (either every N timesteps or every
    N seconds of wall-clock time) and injects new Facts into a PyReason logic
    program.
    """

    def __init__(
            self,
            model,
            class_names: List[str],
            identifier: str = 'classifier',
            interface_options: ModelInterfaceOptions = None,
            logic_program=None,
            poll_interval: Optional[Union[int, timedelta]] = None,
            poll_condition: Optional[str] = None,
            input_fn: Optional[Callable[[], Any]] = None,
    ):
        """
        :param model: PyTorch model to be integrated.
        :param class_names: List of class names for the model output.
        :param identifier: Identifier for the model, used as the constant in the facts.
        :param interface_options: Options for the model interface, including threshold and snapping behavior.
        :param logic_program: PyReason logic program.
        :param poll_interval: How often to poll the model, either as:
            - an integer number of PyReason timesteps, or
            - a `datetime.timedelta` object representing wall-clock time.
            If `None`, polling is disabled.
        :param poll_condition: Name of a predicate that must hold for
            `identifier` to trigger a poll. If `None`, the model is polled
            unconditionally every `poll_interval` timesteps/seconds.
        :param input_fn: Zero-argument callable returning the next model input.
        """
        # The base class stores model, class_names, interface_options and
        # identifier — re-assigning them here (as the old code did) was redundant.
        super().__init__(model, class_names, interface_options, identifier)
        self.logic_program = logic_program
        self.poll_interval = poll_interval  # int timesteps, timedelta, or None
        self.poll_condition = poll_condition
        self.input_fn = input_fn

        # Start the background polling thread only when both an interval and an
        # input source are configured.
        if self.poll_interval is not None and self.input_fn is not None:
            t = threading.Thread(target=self._poll_loop, daemon=True)
            t.start()

    def _get_current_timestep(self):
        """
        Get the current timestep from the PyReason logic program.

        :return: Current timestep, or -1 when no logic program with an
            interpretation is available yet.
        """
        if self.logic_program is not None and self.logic_program.interp is not None:
            return self.logic_program.interp.time
        elif pr.get_logic_program() is not None and pr.get_logic_program().interp is not None:
            # Lazily pick up the globally registered logic program.
            self.logic_program = pr.get_logic_program()
            return self.logic_program.interp.time
        else:
            # No program yet — the poll loop treats -1 as "wait and retry".
            return -1

    def _poll_once(self, t1: int, t2: int) -> None:
        """
        Run a single poll cycle: check the optional poll condition, classify a
        fresh input, inject the resulting facts, and re-run the reasoner.
        """
        if self.poll_condition:
            # Skip this cycle unless the gating predicate currently holds.
            if not self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})")):
                return

        x = self.input_fn()
        _, _, facts = self.forward(x, t1, t2)
        for f in facts:
            pr.add_fact(f)

        # Continue reasoning with the newly injected facts.
        pr.reason(again=True, restart=False)

    def _poll_loop(self) -> None:
        """
        Background *thread* (not asyncio) that polls every ``self.poll_interval``.
        Runs until the process exits (it is started as a daemon thread).
        """
        # Wait until a logic program with an interpretation exists.
        while True:
            current_time = self._get_current_timestep()
            if current_time == -1:
                # BUGFIX: the old code spun at 100% CPU here; back off briefly
                # instead of busy-waiting for the logic program to appear.
                time.sleep(0.1)
                continue

            if isinstance(self.poll_interval, timedelta):
                # Wall-clock mode: sleep for the interval, then poll.
                interval_secs = self.poll_interval.total_seconds()
                while True:
                    time.sleep(interval_secs)
                    t1 = self._get_current_timestep() + 1
                    self._poll_once(t1, t1)
            else:
                # Timestep mode: wait until `poll_interval` timesteps elapsed.
                step_interval = self.poll_interval
                last_step = current_time + 1
                while True:
                    while self._get_current_timestep() - last_step < step_interval:
                        time.sleep(0.01)
                    current = self._get_current_timestep()
                    last_step = current
                    self._poll_once(current, current)

    def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
        """
        Return PyReason facts to create nodes for each class. Each class node
        gets a fact with the predicate corresponding to the class name.

        :param t1: Start time for the facts
        :param t2: End time for the facts
        :return: List of PyReason facts
        """
        facts = []
        for c in self.class_names:
            fact = Fact(f'{c}({self.identifier})', name=f'{self.identifier}-{c}-fact', start_time=t1, end_time=t2)
            facts.append(fact)
        return facts

    def _infer(self, x: torch.Tensor) -> torch.Tensor:
        """Run the underlying model to get raw logits of shape [N, C]."""
        return self.model(x)

    def _postprocess(self, raw_output: torch.Tensor) -> torch.Tensor:
        """
        Validate that raw_output is a [N, C] logits tensor with
        C == len(class_names), then apply softmax over dim=1.

        :raises ValueError: when the logits shape does not match class_names.
        """
        logits = raw_output
        if logits.dim() != 2 or logits.size(1) != len(self.class_names):
            raise ValueError(
                f"Expected logits of shape [N, C] with C={len(self.class_names)}, got {tuple(logits.shape)}"
            )
        return F.softmax(logits, dim=1)

    def _pred_to_facts(
            self,
            raw_output: torch.Tensor,
            probabilities: torch.Tensor,
            t1: int,
            t2: int
    ) -> List[Fact]:
        """
        Given a [N, C] probability tensor, build a flat List[Fact] using
        threshold, snap_value, set_lower_bound and set_upper_bound.
        Classes below the threshold get the uninformative bounds [0, 1].
        Returns N * C facts.
        """
        opts = self.interface_options
        prob = probabilities  # [N, C]

        # Boolean mask of entries that exceed the confidence threshold.
        threshold = torch.tensor(opts.threshold, dtype=prob.dtype, device=prob.device)
        condition = prob > threshold  # [N, C]

        # Determine lower/upper bounds for the "true" (above-threshold) entries.
        if opts.snap_value is not None:
            snap_val = torch.tensor(opts.snap_value, dtype=prob.dtype, device=prob.device)
            lower_if_true = (snap_val if opts.set_lower_bound
                             else torch.tensor(0.0, dtype=prob.dtype, device=prob.device))
            upper_if_true = (snap_val if opts.set_upper_bound
                             else torch.tensor(1.0, dtype=prob.dtype, device=prob.device))
        else:
            lower_if_true = prob if opts.set_lower_bound else torch.zeros_like(prob)
            upper_if_true = prob if opts.set_upper_bound else torch.ones_like(prob)

        zeros = torch.zeros_like(prob)
        ones = torch.ones_like(prob)
        lower_bounds = torch.where(condition, lower_if_true, zeros)  # [N, C]
        upper_bounds = torch.where(condition, upper_if_true, ones)   # [N, C]

        N, C = prob.shape
        all_facts: List[Fact] = []

        for i in range(N):
            for j, class_name in enumerate(self.class_names):
                lower_val = lower_bounds[i, j].item()
                upper_val = upper_bounds[i, j].item()
                fact_str = f"{class_name}({self.identifier}) : [{lower_val:.3f}, {upper_val:.3f}]"
                fact_name = f"{self.identifier}-{class_name}-fact"
                f = Fact(fact_str, name=fact_name, start_time=t1, end_time=t2)
                all_facts.append(f)

        return all_facts
from datetime import timedelta
import threading
import time

import pyreason as pr
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions

from typing import List, Optional, Union, Callable, Any


class YoloLogicIntegratedTemporalClassifier(LogicIntegrationBase):
    """
    Integrates a YOLO model with PyReason: the model's top detection is turned
    into a PyReason Fact that can be added to the logic program. Supports
    optional background polling (every N timesteps or every N seconds).
    """

    def __init__(
            self,
            model,
            class_names: List[str],
            identifier: str = 'yolo_classifier',
            interface_options: ModelInterfaceOptions = None,
            poll_interval: Optional[Union[int, timedelta]] = None,
            poll_condition: Optional[str] = None,
            input_fn: Optional[Callable[[], Any]] = None
    ):
        """
        :param model: YOLO model to be integrated (must expose ``.predict``).
        :param class_names: List of class names for the model output.
        :param identifier: Identifier for the model, used as the constant in the facts.
        :param interface_options: Options for the model interface, including threshold and snapping behavior.
        :param poll_interval: How often to poll the model, either as:
            - an integer number of PyReason timesteps, or
            - a `datetime.timedelta` object representing wall-clock time.
            If `None`, polling is disabled.
        :param poll_condition: Name of a predicate that must hold for
            `identifier` to trigger a poll. If `None`, polling is unconditional.
        :param input_fn: Zero-argument callable returning the next model input.
        """
        super().__init__(model, class_names, interface_options, identifier)
        self.poll_interval = poll_interval  # int timesteps, timedelta, or None
        self.poll_condition = poll_condition
        self.input_fn = input_fn
        # Resolved lazily from pr.get_logic_program() in _get_current_timestep.
        self.logic_program = None

        # Start the background polling thread only when both an interval and an
        # input source are configured.
        if self.poll_interval is not None and self.input_fn is not None:
            t = threading.Thread(target=self._poll_loop, daemon=True)
            t.start()

    def _infer(self, x: Any) -> Any:
        """
        Run YOLO prediction on ``x`` (image path, array, ...).
        Uses the model's 640 image size and a permissive conf=0.1 cut-off;
        the real confidence filtering happens in `_pred_to_facts` via
        `interface_options.threshold`.
        """
        return self.model.predict(source=x, imgsz=640, conf=0.1)

    def _postprocess(self, raw_output: Any) -> Any:
        """
        Extract the first (top) detection from the YOLO result as a
        ``[label_name, confidence]`` pair.

        BUGFIX: returns ``None`` when there are no detections instead of
        raising IndexError on ``result.boxes[0]``.
        """
        result = raw_output[0]  # first image's result
        if len(result.boxes) == 0:
            return None
        box = result.boxes[0]  # top bounding box
        label_id = int(box.cls)
        confidence = float(box.conf)
        label_name = result.names[label_id]  # id -> human-readable label
        return [label_name, confidence]

    def _pred_to_facts(
            self,
            raw_output: Any,
            result: List,
            t1: int,
            t2: int
    ) -> List[Fact]:
        """
        Convert a single ``[label, confidence]`` detection into one PyReason
        Fact, applying threshold, snap_value, set_lower_bound and
        set_upper_bound. Returns an empty list when there was no detection.
        (The previous docstring described a [N, C] tensor — copy-paste error.)
        """
        if result is None:
            # No detection above the model's own cut-off: nothing to assert.
            return []

        opts = self.interface_options
        label, confidence = result

        # BUGFIX: the old code ignored opts.threshold entirely, so even very
        # low-confidence detections received snapped bounds. Below-threshold
        # detections now get the uninformative bounds [0, 1], matching the
        # tensor-based classifiers.
        if confidence > opts.threshold:
            if opts.snap_value is not None:
                lower = opts.snap_value if opts.set_lower_bound else 0
                upper = opts.snap_value if opts.set_upper_bound else 1.0
            else:
                lower = confidence if opts.set_lower_bound else 0
                upper = confidence if opts.set_upper_bound else 1.0
        else:
            lower, upper = 0.0, 1.0

        # NOTE(review): the leading underscore on the predicate is preserved
        # from the original — presumably to avoid clashing with reserved or
        # user-defined predicate names; confirm before removing.
        fact_str = f"_{label}({self.identifier}) : [{lower:.3f}, {upper:.3f}]"
        fact_name = f"{self.identifier}-{label}-fact"
        return [Fact(fact_str, name=fact_name, start_time=t1, end_time=t2)]

    def _get_current_timestep(self):
        """
        Get the current timestep from the PyReason logic program.

        :return: Current timestep, or -1 when no logic program with an
            interpretation is available yet.
        """
        if self.logic_program is not None and self.logic_program.interp is not None:
            return self.logic_program.interp.time
        elif pr.get_logic_program() is not None and pr.get_logic_program().interp is not None:
            # Lazily pick up the globally registered logic program.
            self.logic_program = pr.get_logic_program()
            return self.logic_program.interp.time
        else:
            return -1

    def _poll_once(self, t1: int, t2: int) -> None:
        """
        Run a single poll cycle: check the optional poll condition, run the
        detector on a fresh input, inject the resulting facts, and re-reason.
        """
        if self.poll_condition:
            # Skip this cycle unless the gating predicate currently holds.
            if not self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})")):
                return

        x = self.input_fn()
        _, _, facts = self.forward(x, t1, t2)
        for f in facts:
            pr.add_fact(f)

        # BUGFIX: the timedelta branch previously used restart=True while the
        # timestep branch (and the temporal classifier) used restart=False;
        # unified on restart=False so polling always continues the same run.
        pr.reason(again=True, restart=False)

    def _poll_loop(self) -> None:
        """
        Background *thread* (not asyncio) that polls every ``self.poll_interval``.
        Runs until the process exits (it is started as a daemon thread).
        """
        # Wait until a logic program with an interpretation exists.
        while True:
            current_time = self._get_current_timestep()
            if current_time == -1:
                # BUGFIX: back off briefly instead of busy-waiting at 100% CPU.
                time.sleep(0.1)
                continue

            if isinstance(self.poll_interval, timedelta):
                # Wall-clock mode: sleep for the interval, then poll.
                interval_secs = self.poll_interval.total_seconds()
                while True:
                    time.sleep(interval_secs)
                    t1 = self._get_current_timestep() + 1
                    self._poll_once(t1, t1)
            else:
                # Timestep mode: wait until `poll_interval` timesteps elapsed.
                step_interval = self.poll_interval
                last_step = current_time + 1
                while True:
                    while self._get_current_timestep() - last_step < step_interval:
                        time.sleep(0.01)
                    current = self._get_current_timestep()
                    last_step = current
                    self._poll_once(current, current)
"""Unit tests for the logic-integrated classifier wrappers.

Covers the abstract base class, the standard multi-class classifier, the
HuggingFace wrapper (with a mocked HF model), the YOLO wrapper (with a mocked
YOLO model), and the temporal classifier with polling disabled.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn

from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier
from pyreason.scripts.learning.classification.hf_classifier import HuggingFaceLogicIntegratedClassifier
from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase
from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions


class TestLogicIntegrationBase:
    """Test that the abstract base class enforces the expected contract."""

    def test_cannot_instantiate_directly(self):
        model = nn.Linear(4, 2)
        opts = ModelInterfaceOptions()
        with pytest.raises(TypeError):
            LogicIntegrationBase(model, ["a", "b"], opts, "test")

    def test_subclass_must_implement_abstract_methods(self):
        class Incomplete(LogicIntegrationBase):
            pass

        model = nn.Linear(4, 2)
        opts = ModelInterfaceOptions()
        with pytest.raises(TypeError):
            Incomplete(model, ["a", "b"], opts, "test")


class TestLogicIntegratedClassifier:
    """Basic coverage for the standard multi-class classifier wrapper."""

    @pytest.fixture
    def default_opts(self):
        return ModelInterfaceOptions(
            threshold=0.5,
            set_lower_bound=True,
            set_upper_bound=False,
            snap_value=1.0
        )

    @pytest.fixture
    def classifier(self, default_opts):
        torch.manual_seed(0)
        model = nn.Linear(4, 3)
        return LogicIntegratedClassifier(
            model,
            class_names=["cat", "dog", "bird"],
            identifier="test_clf",
            interface_options=default_opts
        )

    def test_forward_returns_three_tuple(self, classifier):
        x = torch.rand(1, 4)
        result = classifier(x)
        assert len(result) == 3

    def test_forward_returns_logits_probs_facts(self, classifier):
        x = torch.rand(1, 4)
        logits, probs, facts = classifier(x)
        assert logits.shape == (1, 3)
        assert probs.shape == (1, 3)
        assert isinstance(facts, list)

    def test_probabilities_sum_to_one(self, classifier):
        x = torch.rand(1, 4)
        _, probs, _ = classifier(x)
        assert abs(probs.sum().item() - 1.0) < 1e-5

    def test_produces_one_fact_per_class(self, classifier):
        x = torch.rand(1, 4)
        _, _, facts = classifier(x)
        assert len(facts) == 3

    def test_fact_names_contain_identifier(self, classifier):
        x = torch.rand(1, 4)
        _, _, facts = classifier(x)
        for fact in facts:
            assert "test_clf" in fact.name

    def test_fact_predicates_match_class_names(self, classifier):
        x = torch.rand(1, 4)
        _, _, facts = classifier(x)
        pred_names = [str(f.pred) for f in facts]
        assert set(pred_names) == {"cat", "dog", "bird"}

    def test_batch_input_produces_n_times_c_facts(self, classifier):
        x = torch.rand(3, 4)
        _, _, facts = classifier(x)
        assert len(facts) == 9  # 3 samples * 3 classes

    def test_snap_value_bounds(self, default_opts):
        """When a probability exceeds threshold, lower bound should snap to snap_value."""
        torch.manual_seed(42)
        # Use a model that outputs a clear winner
        model = nn.Linear(2, 2)
        clf = LogicIntegratedClassifier(
            model, ["yes", "no"], identifier="snap_test",
            interface_options=default_opts
        )
        # Feed input that produces a high probability for one class
        x = torch.tensor([[10.0, -10.0]])
        _, probs, facts = clf(x)
        # The dominant class should have lower=1.0 (snapped), upper=1.0 (default)
        dominant_fact = [f for f in facts if f.bound.lower == 1.0][0]
        assert dominant_fact.bound.lower == 1.0
        assert dominant_fact.bound.upper == 1.0

    def test_no_snap_value_uses_raw_probabilities(self):
        """When snap_value is None, bounds should use the raw probability."""
        opts = ModelInterfaceOptions(
            threshold=0.0,
            set_lower_bound=True,
            set_upper_bound=True,
            snap_value=None
        )
        torch.manual_seed(0)
        model = nn.Linear(2, 2)
        clf = LogicIntegratedClassifier(
            model, ["a", "b"], identifier="raw_test",
            interface_options=opts
        )
        x = torch.rand(1, 2)
        _, probs, facts = clf(x)
        # With threshold=0 and snap_value=None, bounds should equal raw probabilities
        for i, fact in enumerate(facts):
            expected = probs[0, i].item()
            assert abs(fact.bound.lower - expected) < 1e-3
            assert abs(fact.bound.upper - expected) < 1e-3

    def test_below_threshold_gets_default_bounds(self, default_opts):
        """Classes below threshold should get [0, 1] bounds."""
        torch.manual_seed(0)
        model = nn.Linear(2, 2)
        clf = LogicIntegratedClassifier(
            model, ["a", "b"], identifier="thresh_test",
            interface_options=default_opts
        )
        x = torch.tensor([[10.0, -10.0]])
        _, probs, facts = clf(x)
        # The losing class (prob ≈ 0) should have [0, 1]
        losing = [f for f in facts if f.bound.lower == 0.0 and f.bound.upper == 1.0]
        assert len(losing) == 1

    def test_time_bounds_propagate(self, classifier):
        x = torch.rand(1, 4)
        _, _, facts = classifier(x, t1=5, t2=10)
        for fact in facts:
            assert fact.start_time == 5
            assert fact.end_time == 10

    def test_logits_shape_mismatch_raises(self):
        """Passing wrong input dimension should raise ValueError from _postprocess."""
        opts = ModelInterfaceOptions()
        model = nn.Linear(4, 3)
        clf = LogicIntegratedClassifier(
            model, ["a", "b"], identifier="bad",
            interface_options=opts
        )
        x = torch.rand(1, 4)  # model outputs 3 but class_names has 2
        with pytest.raises(ValueError, match="Expected logits of shape"):
            clf(x)

    def test_class_names_not_mutated(self, classifier):
        original = list(classifier.class_names)
        x = torch.rand(1, 4)
        classifier(x)
        classifier(x)
        assert classifier.class_names == original


def _make_hf_mock_model(num_classes=5):
    """Create a mock HuggingFace model that returns logits from a Linear layer."""
    linear = nn.Linear(10, num_classes)
    id2label = {
        0: "goldfish, Carassius auratus",
        1: "tiger shark, Galeocerdo cuvieri",
        2: "hammerhead, hammerhead shark",
        3: "great white shark, white shark",
        4: "tench, Tinca tinca",
    }

    def mock_forward(**kwargs):
        # Accept any kwargs (like pixel_values), return logits from the linear layer
        x = torch.rand(1, 10)
        return SimpleNamespace(logits=linear(x))

    model = MagicMock()
    model.side_effect = mock_forward
    model.config = SimpleNamespace(id2label=id2label)
    # Make it pass isinstance checks for nn.Module by giving it the needed attrs
    model.training = False
    return model


class TestHuggingFaceClassifier:
    """Basic coverage for the HuggingFace classifier wrapper using mocked models."""

    @pytest.fixture
    def default_opts(self):
        return ModelInterfaceOptions(
            threshold=0.5,
            set_lower_bound=True,
            set_upper_bound=False,
            snap_value=1.0
        )

    @pytest.fixture
    def hf_classifier(self, default_opts):
        torch.manual_seed(0)
        model = _make_hf_mock_model()
        return HuggingFaceLogicIntegratedClassifier(
            model,
            class_names=["goldfish", "tiger shark", "hammerhead", "great white shark", "tench"],
            identifier="hf_test",
            interface_options=default_opts,
            limit_classes=False
        )

    @pytest.fixture
    def hf_classifier_limited(self, default_opts):
        torch.manual_seed(0)
        model = _make_hf_mock_model()
        return HuggingFaceLogicIntegratedClassifier(
            model,
            class_names=["goldfish", "tiger shark", "hammerhead", "great white shark", "tench"],
            identifier="hf_limited",
            interface_options=default_opts,
            limit_classes=True
        )

    def test_forward_returns_three_tuple(self, hf_classifier):
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        result = hf_classifier(inputs)
        assert len(result) == 3

    def test_produces_facts(self, hf_classifier):
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        _, _, facts = hf_classifier(inputs)
        assert len(facts) > 0

    def test_fact_names_contain_identifier(self, hf_classifier):
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        _, _, facts = hf_classifier(inputs)
        for fact in facts:
            assert "hf_test" in fact.name

    def test_time_bounds_propagate(self, hf_classifier):
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        _, _, facts = hf_classifier(inputs, t1=3, t2=7)
        for fact in facts:
            assert fact.start_time == 3
            assert fact.end_time == 7

    def test_limit_classes_produces_correct_count(self, hf_classifier_limited):
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        _, _, facts = hf_classifier_limited(inputs)
        assert len(facts) == 5

    def test_limit_classes_does_not_mutate_class_names(self, hf_classifier_limited):
        original = list(hf_classifier_limited.class_names)
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        hf_classifier_limited(inputs)
        hf_classifier_limited(inputs)
        assert hf_classifier_limited.class_names == original

    def test_limit_classes_facts_use_filtered_labels(self, hf_classifier_limited):
        inputs = {"pixel_values": torch.rand(1, 3, 224, 224)}
        _, _, facts = hf_classifier_limited(inputs)
        # All fact predicates should be real label names from id2label, not indices
        for fact in facts:
            pred = str(fact.pred)
            assert pred.replace(" ", "").isalpha() or " " in pred


def _make_yolo_mock_model(label_name="dog", confidence=0.85):
    """Create a mock YOLO model that returns a single detection."""
    mock_box = MagicMock()
    mock_box.cls = torch.tensor([1])
    mock_box.conf = torch.tensor([confidence])

    mock_result = MagicMock()
    mock_result.boxes = [mock_box]
    mock_result.names = {0: "cat", 1: label_name, 2: "bird"}

    model = MagicMock()
    model.predict.return_value = [mock_result]
    model.training = False
    return model


class TestYoloClassifier:
    """Basic coverage for the YOLO classifier wrapper using mocked models."""

    @pytest.fixture
    def default_opts(self):
        return ModelInterfaceOptions(
            threshold=0.5,
            set_lower_bound=True,
            set_upper_bound=False,
            snap_value=1.0
        )

    @pytest.fixture
    def yolo_classifier(self, default_opts):
        from pyreason.scripts.learning.classification.yolo_classifier import (
            YoloLogicIntegratedTemporalClassifier,
        )
        model = _make_yolo_mock_model("dog", 0.85)
        return YoloLogicIntegratedTemporalClassifier(
            model,
            class_names=["cat", "dog", "bird"],
            identifier="yolo_test",
            interface_options=default_opts,
            poll_interval=None,  # disable polling for unit tests
        )

    def test_forward_returns_three_tuple(self, yolo_classifier):
        # BUGFIX: the old assertion was `len((raw, postproc, facts)) == 3`,
        # which is a tautology (a 3-element tuple literal always has len 3).
        # Assert on the actual return value instead.
        result = yolo_classifier("fake_image.jpg")
        assert isinstance(result, tuple)
        assert len(result) == 3

    def test_produces_single_fact(self, yolo_classifier):
        _, _, facts = yolo_classifier("fake_image.jpg")
        assert len(facts) == 1

    def test_fact_contains_detected_label(self, yolo_classifier):
        _, _, facts = yolo_classifier("fake_image.jpg")
        assert "dog" in str(facts[0].pred)

    def test_fact_name_contains_identifier(self, yolo_classifier):
        _, _, facts = yolo_classifier("fake_image.jpg")
        assert "yolo_test" in facts[0].name

    def test_postprocess_returns_label_and_confidence(self, yolo_classifier):
        _, postproc, _ = yolo_classifier("fake_image.jpg")
        assert postproc[0] == "dog"
        assert abs(postproc[1] - 0.85) < 1e-2

    def test_snap_value_applied(self, yolo_classifier):
        _, _, facts = yolo_classifier("fake_image.jpg")
        # snap_value=1.0 with set_lower_bound=True → lower should be 1.0
        assert facts[0].bound.lower == 1.0

    def test_time_bounds_propagate(self, yolo_classifier):
        _, _, facts = yolo_classifier("fake_image.jpg", t1=2, t2=5)
        assert facts[0].start_time == 2
        assert facts[0].end_time == 5

    def test_no_snap_uses_confidence(self):
        from pyreason.scripts.learning.classification.yolo_classifier import (
            YoloLogicIntegratedTemporalClassifier,
        )
        opts = ModelInterfaceOptions(
            threshold=0.5,
            set_lower_bound=True,
            set_upper_bound=True,
            snap_value=None
        )
        model = _make_yolo_mock_model("cat", 0.92)
        clf = YoloLogicIntegratedTemporalClassifier(
            model, ["cat", "dog"], identifier="nosnap",
            interface_options=opts, poll_interval=None
        )
        _, _, facts = clf("img.jpg")
        assert abs(facts[0].bound.lower - 0.92) < 1e-2
        assert abs(facts[0].bound.upper - 0.92) < 1e-2

    def test_no_polling_thread_when_interval_none(self, yolo_classifier):
        # poll_interval=None means no background thread should be started
        # Verify the classifier works normally without hanging
        _, _, facts = yolo_classifier("test.jpg")
        assert len(facts) == 1


class TestTemporalLogicIntegratedClassifier:
    """Basic coverage for the temporal classifier wrapper (no polling)."""

    @pytest.fixture
    def default_opts(self):
        return ModelInterfaceOptions(
            threshold=0.5,
            set_lower_bound=True,
            set_upper_bound=False,
            snap_value=1.0
        )

    @pytest.fixture
    def temporal_classifier(self, default_opts):
        torch.manual_seed(0)
        model = nn.Linear(4, 3)
        return TemporalLogicIntegratedClassifier(
            model,
            class_names=["cat", "dog", "bird"],
            identifier="temporal_test",
            interface_options=default_opts,
            poll_interval=None,  # disable polling for unit tests
        )

    def test_forward_returns_three_tuple(self, temporal_classifier):
        x = torch.rand(1, 4)
        result = temporal_classifier(x)
        assert len(result) == 3

    def test_forward_returns_logits_probs_facts(self, temporal_classifier):
        x = torch.rand(1, 4)
        logits, probs, facts = temporal_classifier(x)
        assert logits.shape == (1, 3)
        assert probs.shape == (1, 3)
        assert isinstance(facts, list)

    def test_probabilities_sum_to_one(self, temporal_classifier):
        x = torch.rand(1, 4)
        _, probs, _ = temporal_classifier(x)
        assert abs(probs.sum().item() - 1.0) < 1e-5

    def test_produces_one_fact_per_class(self, temporal_classifier):
        x = torch.rand(1, 4)
        _, _, facts = temporal_classifier(x)
        assert len(facts) == 3

    def test_fact_names_contain_identifier(self, temporal_classifier):
        x = torch.rand(1, 4)
        _, _, facts = temporal_classifier(x)
        for fact in facts:
            assert "temporal_test" in fact.name

    def test_fact_predicates_match_class_names(self, temporal_classifier):
        x = torch.rand(1, 4)
        _, _, facts = temporal_classifier(x)
        pred_names = [str(f.pred) for f in facts]
        assert set(pred_names) == {"cat", "dog", "bird"}

    def test_batch_input_produces_n_times_c_facts(self, temporal_classifier):
        x = torch.rand(3, 4)
        _, _, facts = temporal_classifier(x)
        assert len(facts) == 9  # 3 samples * 3 classes

    def test_time_bounds_propagate(self, temporal_classifier):
        x = torch.rand(1, 4)
        _, _, facts = temporal_classifier(x, t1=5, t2=10)
        for fact in facts:
            assert fact.start_time == 5
            assert fact.end_time == 10

    def test_logits_shape_mismatch_raises(self):
        opts = ModelInterfaceOptions()
        model = nn.Linear(4, 3)
        clf = TemporalLogicIntegratedClassifier(
            model, ["a", "b"], identifier="bad",
            interface_options=opts, poll_interval=None
        )
        x = torch.rand(1, 4)  # model outputs 3 but class_names has 2
        with pytest.raises(ValueError, match="Expected logits of shape"):
            clf(x)

    def test_get_class_facts(self, temporal_classifier):
        facts = temporal_classifier.get_class_facts(t1=0, t2=5)
        assert len(facts) == 3
        for fact in facts:
            assert fact.start_time == 0
            assert fact.end_time == 5
            assert fact.bound.lower == 1.0
            assert fact.bound.upper == 1.0

    def test_no_polling_thread_when_interval_none(self, temporal_classifier):
        # Just verify it works without hanging
        x = torch.rand(1, 4)
        _, _, facts = temporal_classifier(x)
        assert len(facts) == 3