Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions align_app/app/alerts_controller.py

This file was deleted.

4 changes: 0 additions & 4 deletions align_app/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .search import SearchController
from .runs_registry import RunsRegistry
from .runs_state_adapter import RunsStateAdapter
from .alerts_controller import AlertsController
from ..adm.decider_registry import create_decider_registry
from ..adm.probe_registry import create_probe_registry
from .import_experiments import import_experiments
Expand Down Expand Up @@ -69,15 +68,12 @@ def __init__(self, server=None):
if experiment_result:
self._runs_registry.add_experiment_items(experiment_result.items)

self._alerts_controller = AlertsController(self.server)

self._runsController = RunsStateAdapter(
self.server,
self._probe_registry,
self._decider_registry,
self._runs_registry,
self.add_system_adm,
self._alerts_controller,
)
self._search_controller = SearchController(
self.server,
Expand Down
60 changes: 31 additions & 29 deletions align_app/app/runs_state_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Optional, Callable
from typing import Optional, Callable
from trame.app import asynchronous
from trame.app.file_upload import ClientFile
from trame.decorators import TrameApp, controller, change, trigger
from trame_alerts.core.service import get_alerts_service
from ..adm.run_models import Run
from .runs_registry import RunsRegistry
from .runs_table_filter import RunsTableFilter
Expand All @@ -14,9 +15,6 @@
from .import_experiments import import_experiments_from_zip
from align_utils.models import AlignmentTarget

if TYPE_CHECKING:
from .alerts_controller import AlertsController


@TrameApp()
class RunsStateAdapter:
Expand All @@ -27,14 +25,13 @@ def __init__(
decider_registry,
runs_registry: RunsRegistry,
add_system_adm_callback: Callable[[str], None],
alerts_controller: "AlertsController",
):
self.server = server
self.runs_registry = runs_registry
self.probe_registry = probe_registry
self.decider_registry = decider_registry
self._add_system_adm_callback = add_system_adm_callback
self._alerts = alerts_controller
self._alerts = get_alerts_service(server)
self.server.state.pending_cache_keys = []
self.server.state.table_collapsed = False
self.server.state.comparison_collapsed = False
Expand Down Expand Up @@ -85,14 +82,6 @@ def _remove_run_from_comparison(self, run_id: str):
if run_id in self.state.runs:
self.state.runs = {k: v for k, v in self.state.runs.items() if k != run_id}

def _update_run_in_comparison(self, run: Run):
"""Update single run in state.runs if it's in comparison."""
if run.id in self.state.runs_to_compare:
run_dict = runs_presentation.run_to_state_dict(
run, self.probe_registry, self.decider_registry
)
self.state.runs = {**self.state.runs, run.id: run_dict}

def _rebuild_comparison_runs(self):
"""Rebuild state.runs from runs_to_compare (for imports/registry changes)."""
new_runs = {}
Expand Down Expand Up @@ -220,14 +209,22 @@ def toggle_comparison_collapsed(self):

@controller.set("toggle_run_in_comparison")
def toggle_run_in_comparison(self, cache_key):
run = self.runs_registry.get_run_by_cache_key(cache_key)
existing_rid = next(
(
rid
for rid in self.state.runs_to_compare
if self.state.runs.get(rid, {}).get("cache_key") == cache_key
),
None,
)

if run and run.id in self.state.runs_to_compare:
if existing_rid is not None:
self.state.runs_to_compare = [
rid for rid in self.state.runs_to_compare if rid != run.id
rid for rid in self.state.runs_to_compare if rid != existing_rid
]
return

run = self.runs_registry.get_run_by_cache_key(cache_key)
if not run:
run = self.runs_registry.materialize_experiment_item(cache_key)
if not run:
Expand Down Expand Up @@ -619,22 +616,23 @@ async def _execute_run_decision(self, run_id: str):

is_cached = self.runs_registry.has_cached_decision(run_id)
if not is_cached:
self._alerts.show("Loading model...")
await self.server.network_completion

self._alerts.show("Making decision...")
alert_id = self._alerts.create_info_alert(
title="Loading model and deciding...", timeout=0
)
else:
alert_id = self._alerts.create_info_alert(title="Deciding...", timeout=0)
await self.server.network_completion

try:
await self.runs_registry.execute_run_decision(run_id)
self._alerts.show("Decision complete", timeout=3000)
self._alerts.remove_alert(alert_id)
self._alerts.create_info_alert(title="Decision complete", timeout=3000)
except Exception as e:
self._alerts.show(f"Decision failed: {e}", timeout=5000)
self._alerts.remove_alert(alert_id)
self._alerts.create_info_alert(title=f"Decision failed: {e}", timeout=5000)

with self.state:
run = self.runs_registry.get_run(run_id)
if run:
self._update_run_in_comparison(run)
self._rebuild_comparison_runs()
self._update_table_rows()
self._remove_pending_cache_key(cache_key)

Expand Down Expand Up @@ -754,8 +752,9 @@ def on_import_experiment_file(self, import_experiment_file, **_):
asynchronous.create_task(self._import_zip_content(file.content))

async def _import_zip_content(self, content: bytes):
with self.state:
self._alerts.show("Loading experiments...")
alert_id = self._alerts.create_info_alert(
title="Loading experiments...", timeout=0
)
await self.server.network_completion

result = import_experiments_from_zip(content)
Expand All @@ -767,7 +766,10 @@ async def _import_zip_content(self, content: bytes):
self._update_table_rows()

self.state.import_experiment_file = None
self._alerts.show(f"Loaded {len(result.items)} experiments", timeout=3000)
self._alerts.remove_alert(alert_id)
self._alerts.create_info_alert(
title=f"Loaded {len(result.items)} experiments", timeout=3000
)

@trigger("import_directory_files")
def trigger_import_directory_files(self, files_data):
Expand Down
21 changes: 6 additions & 15 deletions align_app/app/ui.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from trame.ui.vuetify3 import SinglePageLayout
from trame.widgets import vuetify3, html
from trame.widgets import vuetify3, html, alerts, alerts_vuetify
from ..utils.utils import noop, readable, readable_sentence
from .unordered_object import (
UnorderedObject,
Expand Down Expand Up @@ -1541,20 +1541,9 @@ def __init__(
)

with layout.content:
with vuetify3.VSnackbar(
v_model=("alert_visible", False),
text=("alert_message", ""),
location="bottom left",
color="white",
timeout=("alert_timeout", -1),
content_class="text-h6 font-weight-medium",
):
with vuetify3.Template(v_slot_actions=""):
vuetify3.VBtn(
icon="mdi-close",
variant="text",
click="alert_visible = false",
)
with alerts.AlertsProvider() as alerts_provider:
alerts_provider.bind_controller()
alerts_vuetify.AlertsPopup()
html.Div(
v_html=(
"'<style>"
Expand All @@ -1570,6 +1559,8 @@ def __init__(
".runs-table-panel .v-data-table th { vertical-align: top; }"
".runs-table-panel .v-data-table th:first-child { padding-top: 8px; }"
".drop-zone-active { outline: 3px dashed #1976d2 !important; outline-offset: -3px; }"
".alert-popup-container { left: auto; right: 0; transform: none; width: fit-content; }"
".alert-popup-container .v-alert { --v-theme-info: 66, 66, 66; }"
"</style>'"
)
)
Expand Down