2 changes: 1 addition & 1 deletion package.json
@@ -2,7 +2,7 @@
"name": "biome",
"productName": "Biome",
"private": true,
"version": "0.1.0",
"version": "0.2.0",
"main": ".vite/build/main.js",
"scripts": {
"dev": "electron-forge start",
123 changes: 85 additions & 38 deletions server-components/engine_manager.py
@@ -45,6 +45,22 @@
DEVICE = "cuda"
JPEG_QUALITY = 85

# Model-specific runtime configuration
MODEL_CFG = {
"legacy": {
"label": "legacy (single-frame)",
"is_multiframe": False,
"seed_target_size": (360, 640),
"has_prompt_conditioning": False,
},
"waypoint-1.5": {
"label": "waypoint-1.5 (multi-frame)",
"is_multiframe": True,
"seed_target_size": (512, 1024),
"has_prompt_conditioning": False,
},
}

BUTTON_CODES = {}
# A-Z keys
for i in range(65, 91):
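The entries above are always `.copy()`-ed before being stored on the manager. A minimal sketch (illustration only, not part of the diff) of why that is safe here:

```python
# dict.copy() is shallow, but every value in a MODEL_CFG entry is an
# immutable str/bool/tuple, so per-engine overrides can never leak
# back into the shared module-level defaults.
cfg = MODEL_CFG["waypoint-1.5"].copy()
cfg["has_prompt_conditioning"] = True  # per-engine override
assert MODEL_CFG["waypoint-1.5"]["has_prompt_conditioning"] is False
```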
@@ -108,6 +124,10 @@ def __init__(self):
         self.model_uri = DEFAULT_MODEL_URI
         self.current_prompt = DEFAULT_PROMPT
         self.engine_warmed_up = False
+        self.cfg = MODEL_CFG["legacy"].copy()
+        self.is_multiframe = self.cfg["is_multiframe"]
+        self.seed_target_size = self.cfg["seed_target_size"]
+        self.has_prompt_conditioning = self.cfg["has_prompt_conditioning"]
         self._progress_callback = None
         self._progress_loop = None
         # Prevent concurrent model loads from overlapping across websocket sessions.
@@ -152,6 +172,22 @@ def _normalize_model_uri(self, model_uri: str | None) -> str:
             or DEFAULT_MODEL_URI
         )
 
+    def _resolve_runtime_cfg(self, model_cfg) -> dict:
+        """Resolve runtime config from defaults and override certain values from model_cfg (just prompt_conditioning for now)."""
+        model_type = getattr(model_cfg, "model_type", None)
+        if model_type is not None and model_type not in MODEL_CFG:
+            raise RuntimeError(
+                f"Unsupported model_type '{model_type}'. Only 'waypoint-1.5' and legacy (no model_type) are supported."
+            )
+
+        cfg_key = model_type or "legacy"
+        cfg = MODEL_CFG[cfg_key].copy()
+        cfg["has_prompt_conditioning"] = (
+            getattr(model_cfg, "prompt_conditioning", None) is not None
+        )
+
+        return cfg
+
     async def _run_on_cuda_thread(self, fn):
         """Run callable on the dedicated CUDA thread."""
         loop = asyncio.get_running_loop()
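A sketch of how `_resolve_runtime_cfg` behaves, assuming only the two attributes the code reads via `getattr` (`model_type` and `prompt_conditioning`); the `SimpleNamespace` stand-ins and the `{"dim": 768}` value are hypothetical, since the real object is the engine's `model_cfg` from world_engine:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for engine.model_cfg.
legacy = SimpleNamespace(model_type=None, prompt_conditioning=None)
waypoint = SimpleNamespace(
    model_type="waypoint-1.5",
    prompt_conditioning={"dim": 768},  # any non-None value enables prompting
)

# _resolve_runtime_cfg(legacy)   -> copy of MODEL_CFG["legacy"], prompting off
# _resolve_runtime_cfg(waypoint) -> copy of MODEL_CFG["waypoint-1.5"], prompting on
# An unrecognized model_type raises RuntimeError rather than guessing.
```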
@@ -189,11 +225,13 @@ def _unload_engine_sync(self):
         self.engine = None
         self.seed_frame = None
         self.engine_warmed_up = False
+        self.cfg = MODEL_CFG["legacy"].copy()
+        self.is_multiframe = self.cfg["is_multiframe"]
+        self.seed_target_size = self.cfg["seed_target_size"]
+        self.has_prompt_conditioning = self.cfg["has_prompt_conditioning"]
         self._free_cuda_memory_sync()
 
-    def _load_seed_from_file_sync(
-        self, file_path: str, target_size: tuple[int, int] = (360, 640)
-    ) -> torch.Tensor:
+    def _load_seed_from_file_sync(self, file_path: str) -> torch.Tensor:
         """Synchronous helper to load a seed frame from a file path."""
         try:
             img = Image.open(file_path).convert("RGB")
@@ -203,28 +241,27 @@ def _load_seed_from_file_sync(
                 torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float()
             )
             frame = F.interpolate(
-                img_tensor, size=target_size, mode="bilinear", align_corners=False
+                img_tensor, size=self.seed_target_size, mode="bilinear", align_corners=False
             )[0]
-            return (
+            frame = (
                 frame.to(dtype=torch.uint8, device=DEVICE)
                 .permute(1, 2, 0)
                 .contiguous()
             )
+            if self.is_multiframe:
+                frame = frame.unsqueeze(0).expand(4, -1, -1, -1).contiguous()
+            return frame
         except Exception as e:
             logger.error(f"Failed to load seed from file {file_path}: {e}")
             return None
 
-    async def load_seed_from_file(
-        self, file_path: str, target_size: tuple[int, int] = (360, 640)
-    ) -> torch.Tensor:
+    async def load_seed_from_file(self, file_path: str) -> torch.Tensor:
         """Load a seed frame from a file path (async wrapper)."""
         return await self._run_on_cuda_thread(
-            lambda: self._load_seed_from_file_sync(file_path, target_size)
+            lambda: self._load_seed_from_file_sync(file_path)
         )
 
-    def _load_seed_from_base64_sync(
-        self, base64_data: str, target_size: tuple[int, int] = (360, 640)
-    ) -> torch.Tensor:
+    def _load_seed_from_base64_sync(self, base64_data: str) -> torch.Tensor:
         """Synchronous helper to load a seed frame from base64 encoded data."""
         try:
             img_data = base64.b64decode(base64_data)
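The `is_multiframe` branch above is the heart of the seed-loading change. A standalone sketch of the shape arithmetic (CPU tensor for illustration; the real code runs on the CUDA thread):

```python
import torch

# A decoded seed frame after interpolate/permute: (H, W, C) uint8.
frame = torch.zeros(512, 1024, 3, dtype=torch.uint8)

# Multi-frame models are seeded with the same image repeated 4 times:
# unsqueeze(0) adds a leading dim, expand() broadcasts it as a zero-copy
# view, and contiguous() materializes the final (4, H, W, C) tensor.
batched = frame.unsqueeze(0).expand(4, -1, -1, -1).contiguous()
assert batched.shape == (4, 512, 1024, 3)
```

(512, 1024) matches the waypoint-1.5 `seed_target_size`; legacy models keep a single (360, 640) frame.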
@@ -235,23 +272,24 @@ def _load_seed_from_base64_sync(
                 torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0).float()
             )
             frame = F.interpolate(
-                img_tensor, size=target_size, mode="bilinear", align_corners=False
+                img_tensor, size=self.seed_target_size, mode="bilinear", align_corners=False
             )[0]
-            return (
+            frame = (
                 frame.to(dtype=torch.uint8, device=DEVICE)
                 .permute(1, 2, 0)
                 .contiguous()
             )
+            if self.is_multiframe:
+                frame = frame.unsqueeze(0).expand(4, -1, -1, -1).contiguous()
+            return frame
         except Exception as e:
             logger.error(f"Failed to load seed from base64: {e}")
             return None
 
-    async def load_seed_from_base64(
-        self, base64_data: str, target_size: tuple[int, int] = (360, 640)
-    ) -> torch.Tensor:
+    async def load_seed_from_base64(self, base64_data: str) -> torch.Tensor:
         """Load a seed frame from base64 encoded data (async wrapper)."""
         return await self._run_on_cuda_thread(
-            lambda: self._load_seed_from_base64_sync(base64_data, target_size)
+            lambda: self._load_seed_from_base64_sync(base64_data)
         )
 
     async def load_engine(self, model_uri: str | None = None):
@@ -297,9 +335,6 @@ async def load_engine(self, model_uri: str | None = None):
logger.info(f" N_FRAMES: {N_FRAMES}")
logger.info(f" Prompt: {self.current_prompt[:60]}...")

# Model config overrides
# scheduler_sigmas: diffusion denoising schedule (MUST end with 0.0)
# ae_uri: VAE model for encoding/decoding frames
model_start = time.perf_counter()
dtype_attempts = [torch.bfloat16, torch.float16]
new_engine = None
@@ -313,11 +348,6 @@ def _create_engine():
             return WorldEngine(
                 requested_model,
                 device=DEVICE,
-                model_config_overrides={
-                    "n_frames": N_FRAMES,
-                    "ae_uri": "OpenWorldLabs/owl_vae_f16_c16_distill_v0_nogan",
-                    "scheduler_sigmas": [1.0, 0.8, 0.2, 0.0],
-                },
                 quant=QUANT,
                 dtype=dtype,
             )
@@ -349,6 +379,16 @@ def _create_engine():
         )
         logger.info(f"[2/4] Loaded with dtype={selected_dtype}")
         self._log_cuda_memory("after load")
+
+        # Resolve runtime config from defaults overridden by model config.
+        self.cfg = self._resolve_runtime_cfg(self.engine.model_cfg)
+        self.is_multiframe = self.cfg["is_multiframe"]
+        self.seed_target_size = self.cfg["seed_target_size"]
+        self.has_prompt_conditioning = self.cfg["has_prompt_conditioning"]
+        logger.info(f"[2/4] Model type: {self.cfg['label']}")
+        logger.info(f"[2/4] Seed target size: {self.seed_target_size}")
+        logger.info(f"[2/4] Prompt conditioning: {self.has_prompt_conditioning}")
+
         self._report_progress(SESSION_LOADING_DONE)
         self.model_uri = requested_model
 
@@ -430,12 +470,15 @@ async def init_session(self):
logger.info(f"[INIT] engine.append_frame() took {time.perf_counter() - t0:.2f}s")

self._report_progress(SESSION_INIT_FRAME)
t0 = time.perf_counter()
logger.info("[INIT] Starting engine.set_prompt()...")
await self._run_on_cuda_thread(
lambda: self.engine.set_prompt(self.current_prompt)
)
logger.info(f"[INIT] engine.set_prompt() took {time.perf_counter() - t0:.2f}s")
if self.has_prompt_conditioning:
t0 = time.perf_counter()
logger.info("[INIT] Starting engine.set_prompt()...")
await self._run_on_cuda_thread(
lambda: self.engine.set_prompt(self.current_prompt)
)
logger.info(f"[INIT] engine.set_prompt() took {time.perf_counter() - t0:.2f}s")
else:
logger.info(f"[INIT] No prompt conditioning enabled, skipping engine.set_prompt()")

async def recover_from_cuda_error(self):
"""
@@ -497,12 +540,16 @@ def do_warmup():
             )
 
             self._report_progress(SESSION_WARMUP_PROMPT)
-            logger.info("[5/5] Step 3: Setting prompt...")
-            prompt_start = time.perf_counter()
-            self.engine.set_prompt(self.current_prompt)
-            logger.info(
-                f"[5/5] Step 3: Prompt set in {time.perf_counter() - prompt_start:.2f}s"
-            )
+
+            if self.has_prompt_conditioning:
+                logger.info("[5/5] Step 3: Setting prompt...")
+                prompt_start = time.perf_counter()
+                self.engine.set_prompt(self.current_prompt)
+                logger.info(
+                    f"[5/5] Step 3: Prompt set in {time.perf_counter() - prompt_start:.2f}s"
+                )
+            else:
+                logger.info("[5/5] Step 3: Skipping prompt conditioning...")
 
             self._report_progress(SESSION_WARMUP_COMPILE)
             logger.info(
2 changes: 1 addition & 1 deletion server-components/pyproject.toml
@@ -3,7 +3,7 @@ name = "biome-server"
version = "0.1.0"
requires-python = "==3.13.*"
dependencies = [
"world-engine @ https://github.com/Overworldai/world_engine/archive/a30a00c302380c0f657347e8456bb6837ff37c22.zip",
"world-engine @ https://github.com/Overworldai/world_engine/archive/refs/heads/orthorope.zip",
"torch",
"torchvision",
"pillow",