Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions examples/image_classifier_ex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# import the logicIntegratedClassifier class

from pathlib import Path
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch.nn.functional as F
import cv2
from ultralytics import YOLO

from pyreason.scripts.learning.classification.hf_classifier import HuggingFaceLogicIntegratedClassifier
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
from pyreason.scripts.rules.rule import Rule
from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace


# Step 1: Load a pre-trained model and image processor from Hugging Face.
model_name = "google/vit-base-patch16-224"  # Vision Transformer checkpoint
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Register an (initially empty) directed graph with PyReason; classifier
# facts added below will drive the reasoning over it.
G = nx.DiGraph()
load_graph(G)

# Step 2: Collect the input images from examples/images.
image_dir = Path(__file__).resolve().parent.parent / "examples" / "images"
# image_dir is already a Path, so glob it directly (no need to re-wrap it).
image_paths = list(image_dir.glob("*.jpeg"))
image_list = []
# Restrict the classifier output space to these ImageNet labels.
allowed_labels = ['goldfish', 'tiger shark', 'hammerhead', 'great white shark', 'tench']

# Add rules to the knowledge base.
# NOTE(review): predicates such as tigershark(x) assume the classifier strips
# spaces from labels like 'tiger shark' when emitting facts — confirm against
# HuggingFaceLogicIntegratedClassifier's fact naming.
add_rule(Rule("is_fish(x) <-0 goldfish(x)", "is_fish_rule"))
add_rule(Rule("is_fish(x) <-0 tench(x)", "is_fish_rule"))
add_rule(Rule("is_shark(x) <-0 tigershark(x)", "is_shark_rule"))
add_rule(Rule("is_shark(x) <-0 hammerhead(x)", "is_shark_rule"))
add_rule(Rule("is_shark(x) <-0 greatwhiteshark(x)", "is_shark_rule"))
add_rule(Rule("is_scary(x) <-0 is_shark(x)", "is_scary_rule"))
add_rule(Rule("likes_to_eat(y,x) <-0 is_shark(y), is_fish(x)", "likes_to_eat_rule", infer_edges=True))

# Interface options shared by every per-image classifier: probabilities
# above 0.5 raise the fact's lower bound, snapped to 1.0; the upper bound
# is left unchanged. Loop-invariant, so build it once, not per image.
interface_options = ModelInterfaceOptions(
    threshold=0.5,         # only process probabilities above 0.5
    set_lower_bound=True,  # high confidence adjusts the lower bound
    set_upper_bound=False, # keep the upper bound unchanged
    snap_value=1.0         # snap qualifying bounds to 1.0
)

# Step 3: Classify each image and feed the resulting facts to PyReason.
for image_path in image_paths:
    print(f"Processing Image: {image_path.name}")
    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt")

    # One logic-integrated classifier per image, identified by the file stem
    # (e.g. "fish_1" for fish_1.jpeg) so facts are attributed per image.
    classifier_name = image_path.stem
    fish_classifier = HuggingFaceLogicIntegratedClassifier(
        model,
        allowed_labels,
        identifier=classifier_name,
        interface_options=interface_options,
        limit_classes=True
    )

    logits, probabilities, classifier_facts = fish_classifier(inputs)

    print("=== Fish Classifier Output ===")
    print("\nGenerated Classifier Facts:")
    # Print each generated fact and hand it to the reasoner.
    for fact in classifier_facts:
        print(fact)
        add_fact(fact)

    print("Done processing image ", image_path.name)

# --- Part 4: Run the Reasoning Engine ---

# Reset settings so configuration from earlier runs does not leak in, then
# enable atom tracing so the rule traces below are populated.
reset_settings()
Settings.atom_trace = True

# Run the reasoning engine so the classifier facts propagate through the rules.
interpretation = reason()

trace = get_rule_trace(interpretation)
print(f"NODE RULE TRACE: \n\n{trace[0]}\n")
print(f"EDGE RULE TRACE: \n\n{trace[1]}\n")
Binary file added examples/images/fish_1.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/fish_2.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/shark_1.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/shark_2.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/shark_3.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
135 changes: 135 additions & 0 deletions examples/multiple_classifier_integration_ex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import pyreason as pr
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random


# Different seeds drive the two dummy models to different classifications,
# exercising different reasoning paths in this example:
# seed_value = 41 # legitimate, high risk
# seed_value = 42 # fraud, low risk
seed_value = 44 # fraud, high risk
# Seed every RNG in play (python, numpy, torch) so runs are reproducible.
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)


# --- Part 1: Fraud Detector Model Integration ---
# A dummy (untrained) linear model stands in for a transaction fraud detector.
fraud_model = nn.Linear(5, 2)
fraud_class_names = ["fraud", "legitimate"]
transaction_features = torch.rand(1, 5)

# Integration options: only probabilities > 0.5 trigger bounds adjustment,
# snapping the fact's lower bound to 1.0; the upper bound is left alone.
fraud_interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=False,
    snap_value=1.0
)

# Wrap the fraud detection model so its outputs become PyReason facts.
fraud_detector = pr.LogicIntegratedClassifier(
    fraud_model,
    fraud_class_names,
    identifier="fraud_detector",
    interface_options=fraud_interface_options
)

# Run the fraud detector on the dummy transaction.
logits_fraud, probabilities_fraud, fraud_facts = fraud_detector(transaction_features)
print("=== Fraud Detector Output ===")
print("Logits:", logits_fraud)
print("Probabilities:", probabilities_fraud)
print("\nGenerated Fraud Detector Facts:")
for fact in fraud_facts:
    print(fact)

# Feed the classifier facts into the reasoner.
for fact in fraud_facts:
    pr.add_fact(fact)

# Add additional contextual facts:
# 1. The transaction is from a suspicious location.
pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact"))
# 2. Link the transaction to AccountA.
pr.add_fact(pr.Fact("transaction(AccountA)", "transaction_link"))
# 3. Register AccountA as an account.
pr.add_fact(pr.Fact("account(AccountA)", "account_fact"))

# Rule A: if the fraud detector flags fraud and the transaction is from a
# suspicious location, mark the account for investigation.
pr.add_rule(pr.Rule("requires_investigation(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud)", "investigation_rule"))

# --- Set up Graph and Load ---
# Build a simple chain of accounts: AccountA -> AccountB -> AccountC.
G = nx.DiGraph()
G.add_node("AccountA")
G.add_node("AccountB")
G.add_node("AccountC")
# Edges carry an attribute named "associated", which PyReason interprets as
# the predicate associated(x, y) between linked accounts.
G.add_edge("AccountA", "AccountB", associated=1)
G.add_edge("AccountB", "AccountC", associated=1)
# Load the graph into PyReason.
pr.load_graph(G)

# Propagate the investigation flag along the "associated" relationship.
pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "investigation_propagation_rule"))

# --- Part 5: Run the Reasoning Engine ---
# NOTE(review): allow_ground_rules appears to be needed so the constant-
# argument atom fraud_detector(fraud) in Rule A can ground — confirm against
# the PyReason settings documentation. atom_trace records provenance for the
# rule trace printed below.
pr.settings.allow_ground_rules = True
pr.settings.atom_trace = True
interpretation = pr.reason()

# Display reasoning results for 'requires_investigation'.
print("\n=== Reasoning Results for 'requires_investigation' ===")
trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE: \n\n{trace[0]}\n")


# --- Part 2: Risk Evaluator Model Integration ---
# A second dummy model evaluates account risk.
risk_model = nn.Linear(5, 2)
risk_class_names = ["high_risk", "low_risk"]
risk_features = torch.rand(1, 5)

# Integration options for the risk evaluator: unlike the fraud detector,
# both bounds are snapped here.
risk_interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=True,
    snap_value=1.0
)

# Wrap the risk evaluation model; its facts are attributed to the constant
# "risk_evaluator".
risk_evaluator = pr.LogicIntegratedClassifier(
    risk_model,
    risk_class_names,
    identifier="risk_evaluator",
    interface_options=risk_interface_options
)

# Run the risk evaluator.
logits_risk, probabilities_risk, risk_facts = risk_evaluator(risk_features)
print("\n=== Risk Evaluator Output ===")
print("Logits:", logits_risk)
print("Probabilities:", probabilities_risk)
print("\nGenerated Risk Evaluator Facts:")
for fact in risk_facts:
    print(fact)

# --- Context and Reasoning again ---
# Feed the new facts into the reasoner for a second pass.
for fact in risk_facts:
    pr.add_fact(fact)

# Rule B: fraud + high risk => critical action, which then propagates along
# the "associated" edges of the account graph.
pr.add_rule(pr.Rule("critical_action(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud), risk_evaluator(high_risk)", "critical_action_rule"))
pr.add_rule(pr.Rule("critical_action(y) <- critical_action(x), associated(x,y)", "critical_propagation_rule"))

# again=True continues from the previous interpretation rather than starting
# the reasoning from scratch.
interpretation = pr.reason(again=True)

# Display reasoning results for 'critical_action'.
print("\n=== Reasoning Results for 'critical_action' (Reasoning again) ===")
trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE: \n\n{trace[0]}\n")
85 changes: 85 additions & 0 deletions examples/temporal_classifier_ex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import time
import sys
import os

import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random
from datetime import timedelta

from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
from pyreason.scripts.rules.rule import Rule
from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace, get_time, Query

# Each seed yields a different sequence of weld classifications across the
# loop iterations (labels per inspection):
seed_value = 65 # Good Gap Gap
# seed_value = 47 # Good Gap Good
# seed_value = 43 # Good Good Good
# Seed every RNG in play (python, numpy, torch) so the demo is reproducible.
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)


def input_fn():
    """Return a dummy 1x3 feature tensor, standing in for a live sensor read."""
    return torch.rand(1, 3)


# A dummy linear model stands in for a weld-quality classifier that labels
# each inspection "good" or "gap".
weld_model = nn.Linear(3, 2)
class_names = ["good", "gap"]

# Integration options: only probabilities above the 0.5 threshold adjust
# bounds; confident predictions snap the lower bound to 1.0 while the upper
# bound is left untouched.
interface_options = ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=False,
    snap_value=1.0,
)

# Wrap the model in a temporal classifier that re-polls input_fn every half
# second while the "gap" condition holds.
weld_quality_checker = TemporalLogicIntegratedClassifier(
    weld_model,
    class_names,
    identifier="weld_object",
    interface_options=interface_options,
    poll_interval=timedelta(seconds=0.5),
    poll_condition="gap",
    input_fn=input_fn,
)

# A detected gap implies a repair attempt one timestep later; a gap while
# already repairing marks the part defective.
add_rule(Rule("repairing(weld_object) <-1 gap(weld_object)", "repair attempted rule"))
add_rule(Rule("defective(weld_object) <-1 gap(weld_object), repairing(weld_object)", "defective rule"))

# Inspect up to five weld parts; each iteration classifies a fresh reading
# and runs one reasoning step over the accumulated facts.
max_iters = 5
for weld_iter in range(max_iters):
    # Dummy per-part sensor reading.
    features = torch.rand(1, 3)
    # Stamp the classifier facts with the reasoner's current time.
    t = get_time()
    logits, probs, classifier_facts = weld_quality_checker(features, t1=t, t2=t)
    for fact in classifier_facts:
        add_fact(fact)

    # Reasoning: trace atoms, stay quiet, and continue from the previous
    # interpretation on every iteration after the first.
    Settings.atom_trace = True
    Settings.verbose = False
    again = weld_iter > 0
    interpretation = reason(timesteps=1, again=again, restart=False)
    trace = get_rule_trace(interpretation)
    print(f"\n=== Reasoning Rule Trace for Weld Part: {weld_iter} ===")
    print(trace[0], "\n\n")

    # Give the temporal classifier's polling a chance to run between parts.
    time.sleep(5)

    # Stop as soon as the part is flagged defective.
    if interpretation.query(Query("defective(weld_object)")):
        print("Defective weld detected! \n Replacing the part.\n\n")
        break
80 changes: 0 additions & 80 deletions examples/temporal_classifier_integration_ex.py

This file was deleted.

Loading