Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ jobs:
- name: Install dependencies
run: uv sync --group dev

- name: Copy to .env
run: cp .env.dev.example .env

- name: Run tests
run: uv run pytest

Expand Down
7 changes: 5 additions & 2 deletions api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ RUN apk add --no-cache wget
# Install uv.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Copy the application into the container.
COPY . /app
# Copy the dependency files.
COPY pyproject.toml uv.lock /app/

# Install the application dependencies.
WORKDIR /app
RUN uv sync --frozen --no-cache

# Copy the rest of the application into the container.
COPY . /app

# Run the application.
CMD ["/app/.venv/bin/fastapi", "run", "src/main.py", "--port", "80", "--host", "0.0.0.0"]
7 changes: 5 additions & 2 deletions api/src/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
from fastapi.security import OAuth2PasswordBearer


from src.core.settings import settings


oauth_2_scheme = OAuth2PasswordBearer(tokenUrl="token")

AUTH_SERVER_URL = os.getenv("KEYCLOAK_URL")
KEYCLOAK_ISSUER_URL = os.getenv("KEYCLOAK_ISSUER_URL", AUTH_SERVER_URL)
AUTH_SERVER_URL = settings.KEYCLOAK_URL
KEYCLOAK_ISSUER_URL = settings.KEYCLOAK_ISSUER_URL or AUTH_SERVER_URL
RESOURCE_SERVER_ID = "api"
_JWKS_CLIENTS: dict[str, PyJWKClient] = {}

Expand Down
10 changes: 10 additions & 0 deletions api/src/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ class Settings(BaseSettings):
POSTGRES_USER: str
POSTGRES_PASSWORD: str = ""
POSTGRES_DB: str = ""

# Keycloak
KEYCLOAK_URL: str = ""
KEYCLOAK_INTERNAL_URL: str = ""
KEYCLOAK_ISSUER_URL: str = ""
CLIENT_SECRET: str = ""

WEB_URL: str = "http://localhost:5173"
API_URL: str = "http://localhost:8000"

# MongoDB
MONGODB_URI: str = "mongodb://template_user:template_pass@mongo:27017/securelearning?authSource=securelearning"
Expand Down Expand Up @@ -48,6 +57,7 @@ class Settings(BaseSettings):
RABBITMQ_USER: str
RABBITMQ_PASS: str
RABBITMQ_QUEUE: str
RABBITMQ_TRACKING_QUEUE: str = "tracking_queue"

# Statistics
# Users who fell for phishing in more than this fraction of campaigns are
Expand Down
11 changes: 6 additions & 5 deletions api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from src.core.object_storage import ensure_bucket, garage_enabled
from src.core.settings import settings
from src.tasks import start_scheduler, shutdown_scheduler

from src.tasks.tracking_consumer import start_tracking_consumer, shutdown_tracking_consumer

@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -34,11 +34,12 @@ async def lifespan(app: FastAPI):
await ensure_bucket(settings.GARAGE_BUCKET_LOGOS)
try:
start_scheduler()
start_tracking_consumer()
yield
finally:
shutdown_scheduler()
shutdown_tracking_consumer()
await close_mongo_client()
shutdown_scheduler()


app = FastAPI(
title="Project Template API",
Expand All @@ -49,8 +50,8 @@ async def lifespan(app: FastAPI):
)

origins = [
os.getenv("WEB_URL"),
os.getenv("API_URL"),
settings.WEB_URL,
settings.API_URL,
]

app.add_middleware(
Expand Down
1 change: 1 addition & 0 deletions api/src/models/email_sending/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class EmailSendingStatus(StrEnum):
SCHEDULED = "scheduled"
QUEUED = "queued"
SENT = "sent"
OPENED = "opened"
CLICKED = "clicked"
Expand Down
12 changes: 8 additions & 4 deletions api/src/routers/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from src.core.dependencies import SessionDep, OAuth2Scheme
from src.services.tracking import TrackingService
from src.core.settings import settings

router = APIRouter()

Expand Down Expand Up @@ -68,7 +69,7 @@ def track_sent(si: str, session: SessionDep):
)


@router.post(
@router.get(
"/track/open",
status_code=200,
description="Tracking pixel endpoint - records email opens",
Expand Down Expand Up @@ -100,13 +101,16 @@ async def track_click(si: str, session: SessionDep):

@router.post(
"/track/phish",
status_code=200,
status_code=303,
description="Phishing event endpoint - records when user submits credentials on landing page",
)
def track_phish(si: str, session: SessionDep):
"""
Called when user submits credentials on the landing page.
Records the phishing event.
Records the phishing event and redirects to the simulation oops page.
"""
service.record_phish(si, session)
return {"message": "Event recorded"}
return RedirectResponse(
url=f"{settings.WEB_URL}/simulation-oops.html",
status_code=303
)
3 changes: 2 additions & 1 deletion api/src/services/compliance/token_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
import jwt
from jwt import PyJWKClient
from fastapi import HTTPException, status
from src.core.settings import settings


REALM_PATH = "/realms/"
AUTH_SERVER_URL = os.getenv("KEYCLOAK_INTERNAL_URL") or os.getenv("KEYCLOAK_URL")
AUTH_SERVER_URL = settings.KEYCLOAK_INTERNAL_URL or settings.KEYCLOAK_URL
SYSTEM_REALMS = {"platform", "master"}
PRIVILEGED_COMPLIANCE_ROLES = {"admin", "org_manager", "content_manager"}

Expand Down
13 changes: 5 additions & 8 deletions api/src/services/keycloak_admin/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,17 @@
import json
from pathlib import Path
from fastapi import HTTPException
from dotenv import load_dotenv
from src.core.settings import settings
from src.services.keycloak_client import get_keycloak_client


load_dotenv()


class base_handler:
def __init__(self):

self.keycloak_url = os.getenv("KEYCLOAK_URL")
self.admin_secret = os.getenv("CLIENT_SECRET")
self.web_url = os.getenv("WEB_URL", "http://localhost:3000")
self.api_url = os.getenv("API_URL", "http://localhost:8080")
self.keycloak_url = settings.KEYCLOAK_URL
self.admin_secret = settings.CLIENT_SECRET
self.web_url = settings.WEB_URL
self.api_url = settings.API_URL
self.keycloak_client = get_keycloak_client()

if not self.keycloak_url:
Expand Down
8 changes: 3 additions & 5 deletions api/src/services/keycloak_client/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
import json
import requests
from fastapi import HTTPException
from dotenv import load_dotenv

load_dotenv()
from src.core.settings import settings


class base_handler:
"""Base handler with shared configuration and HTTP helpers."""

def __init__(self):
self.keycloak_url = os.getenv("KEYCLOAK_URL")
self.admin_secret = os.getenv("CLIENT_SECRET")
self.keycloak_url = settings.KEYCLOAK_URL
self.admin_secret = settings.CLIENT_SECRET

if not self.keycloak_url:
raise HTTPException(
Expand Down
100 changes: 59 additions & 41 deletions api/src/services/tracking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from fastapi import HTTPException
from sqlmodel import Session, select, text
from sqlmodel import Session, select, text, update

from src.models import Campaign, EmailSending, EmailSendingStatus

Expand All @@ -18,14 +18,16 @@ def record_sent(self, tracking_token: str, session: Session) -> EmailSending:
if sending.sent_at is None:
sending.sent_at = datetime.now()
sending.status = EmailSendingStatus.SENT

# Increment campaign counter
campaign = session.get(Campaign, sending.campaign_id)
if campaign:
campaign.total_sent += 1
session.add(campaign)
session.commit()
session.refresh(sending)
session.add(sending)

# Increment campaign counter atomically
session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_sent=Campaign.total_sent + 1)
)
session.commit()
session.refresh(sending)

return sending

Expand All @@ -37,14 +39,16 @@ def record_open(self, tracking_token: str, session: Session) -> EmailSending:
if sending.opened_at is None:
sending.opened_at = datetime.now()
sending.status = EmailSendingStatus.OPENED

# Increment campaign counter using ORM
campaign = session.get(Campaign, sending.campaign_id)
if campaign:
campaign.total_opened += 1
session.add(campaign)
session.commit()
session.refresh(sending)
session.add(sending)

# Increment campaign counter atomically
session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_opened=Campaign.total_opened + 1)
)
session.commit()
session.refresh(sending)

return sending

Expand All @@ -55,24 +59,28 @@ def record_click(self, tracking_token: str, session: Session) -> EmailSending:
# Record open if not already recorded (click implies open)
if sending.opened_at is None:
sending.opened_at = datetime.now()
# Increment campaign counter using ORM
campaign = session.get(Campaign, sending.campaign_id)
if campaign:
campaign.total_opened += 1
session.add(campaign)
session.commit()
# Increment campaign counter atomically
session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_opened=Campaign.total_opened + 1)
)

# Only count first click
if sending.clicked_at is None:
sending.clicked_at = datetime.now()
sending.status = EmailSendingStatus.CLICKED
# Increment campaign counter using ORM
campaign = session.get(Campaign, sending.campaign_id)
if campaign:
campaign.total_clicked += 1
session.add(campaign)
session.commit()
session.refresh(sending)
session.add(sending)

# Increment campaign counter atomically
session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_clicked=Campaign.total_clicked + 1)
)

session.commit()
session.refresh(sending)

return sending

Expand All @@ -99,26 +107,36 @@ def record_phish(self, tracking_token: str, session: Session) -> EmailSending:
sending = self._get_sending_by_token(tracking_token, session)

# Record open and click if not already recorded (phish implies both)
campaign = session.get(Campaign, sending.campaign_id)
if sending.opened_at is None:
sending.opened_at = datetime.now()
if campaign:
campaign.total_opened += 1
session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_opened=Campaign.total_opened + 1)
)

if sending.clicked_at is None:
sending.clicked_at = datetime.now()
if campaign:
campaign.total_clicked += 1
session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_clicked=Campaign.total_clicked + 1)
)

# Only count first phish
if sending.phished_at is None:
sending.phished_at = datetime.now()
sending.status = EmailSendingStatus.PHISHED
if campaign:
campaign.total_phished += 1
if campaign:
session.add(campaign)
session.commit()
session.refresh(sending)
session.add(sending)

session.exec(
update(Campaign)
.where(Campaign.id == sending.campaign_id)
.values(total_phished=Campaign.total_phished + 1)
)

session.commit()
session.refresh(sending)

return sending

Expand Down
10 changes: 5 additions & 5 deletions api/src/tasks/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,11 @@ def process_pending_emails() -> None:
try:
# Send email to RabbitMQ
campaign_service._send_email_to_rabbitmq(email, campaign)

# Update status and timestamp for sent emails
email.status = EmailSendingStatus.SENT
email.sent_at = datetime.now()


# Update status to queued
email.status = EmailSendingStatus.QUEUED
session.commit()

except (ValueError, ValidationError) as e:
# Irrecoverable payload/configuration issue for this email.
email.status = EmailSendingStatus.FAILED
Expand All @@ -188,6 +187,7 @@ def process_pending_emails() -> None:
logger.error(
f"Failed email {email.id} for campaign {email.campaign_id} marked FAILED: {e}"
)

except Exception as e:
logger.error(
f"Failed to process email {email.id} for campaign {email.campaign_id}: {e}"
Expand Down
Loading
Loading