diff --git a/.meta/mast/README.md b/.meta/mast/README.md
index 29d723b39..4f4ccb2b3 100644
--- a/.meta/mast/README.md
+++ b/.meta/mast/README.md
@@ -119,3 +119,7 @@ This ensures that when MAST runs with `HF_HUB_OFFLINE=1`, the transformers libra
 Both cache and model files are stored under:
 - **Cache**: `/mnt/wsfuse/teamforge/hf` (set via `HF_HOME`)
 - **Model weights**: `/mnt/wsfuse/teamforge/hf/`
+
+
+#### Weights & Biases
+If you are part of the torchforge team on WandB, then WandB will work out of the box; the link can be found in the MAST logs. If you are not part of the torchforge team on WandB, then you will need to set the "WANDB_API_KEY" environment variable to your WandB API key.
diff --git a/.meta/mast/main.py b/.meta/mast/main.py
index d867db1c2..d901d1ab5 100644
--- a/.meta/mast/main.py
+++ b/.meta/mast/main.py
@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import os
 import sys
 
 from apps.grpo.main import main as grpo_main
@@ -31,6 +32,55 @@
 DEFAULT_CHECKPOINT_FOLDER = "/mnt/wsfuse/teamforge/forge_runs/"
 
 
+def setup_wandb_api_key() -> None:
+    """Ensure a WandB API key is available in ``WANDB_API_KEY``.
+
+    A non-empty ``WANDB_API_KEY`` already in the environment is kept as is.
+    Otherwise the key is requested from the keychain service, base64-decoded,
+    and exported together with ``WANDB_BASE_URL``.
+
+    Raises:
+        RuntimeError: If the key cannot be retrieved or decoded; the original
+            exception is chained as the cause.
+    """
+    # An empty value is not a usable key, so treat it the same as unset.
+    if os.environ.get("WANDB_API_KEY"):
+        print("[wandb] WANDB_API_KEY already set in environment.")
+        return
+    secret_name = "TORCHFORGE_WANDB_API_KEY"
+    print(f"[wandb] Attempting to retrieve API key from keychain {secret_name=}")
+    try:
+        # Imported here so an import failure also surfaces as the RuntimeError below.
+        import base64
+
+        from cif import client
+
+        response = client.request(
+            "keychain.service",
+            "getSecretV2",
+            {
+                "request": {
+                    "name": secret_name,
+                }
+            },
+        )
+        # decode base64 encoded string
+        wandb_api_key = base64.b64decode(
+            # pyrefly: ignore [bad-index]
+            response["result"]["secret"]["value"]
+        ).decode("utf-8")
+        print("[wandb] Successfully retrieved API key from keychain.")
+        os.environ["WANDB_API_KEY"] = wandb_api_key
+        os.environ["WANDB_BASE_URL"] = "https://meta.wandb.io/"
+    except Exception as keychain_exception:
+        print(
+            f"[wandb] Failed to retrieve API key from keychain. {keychain_exception=}"
+        )
+        raise RuntimeError(
+            "Failed to retrieve wandb API key. Cannot launch job"
+        ) from keychain_exception
+
+
 async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None):
     """Main module for launching mast jobs for GRPO training.
 
@@ -64,7 +114,9 @@ async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None)
         )
         await launcher.launch_mast_job()
     else:
-        # In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training
+        # In remote mode, we're already running inside MAST, so set up wandb api key, mount directory,
+        # init provisioner and run training
+        setup_wandb_api_key()
         mount_mnt_directory("/mnt/wsfuse")
         await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
         await grpo_main(cfg)
diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py
index 11db5086a..dbe96d94b 100644
--- a/src/forge/controller/launcher.py
+++ b/src/forge/controller/launcher.py
@@ -271,20 +271,19 @@ def add_additional_packages(self, packages: "Packages") -> "Packages":
 
     def build_appdef(self) -> specs.AppDef:
         # create the app definition for the worker
-        remote_end_python_path = ":".join(
-            [
-                f"{self.remote_work_dir}{workspace}"
-                for workspace in self.editable_workspace_paths
-            ]
-        )
+        additional_python_paths = [
+            f"{self.remote_work_dir}{workspace}"
+            for workspace in self.editable_workspace_paths
+        ]
+        additional_python_paths.append(self.remote_work_dir)
+        # needed for wandb api key extraction from secret
+        additional_python_paths.append("/packages/cif")
 
         default_envs = {
             **meta_hyperactor.DEFAULT_NVRT_ENVS,
             **meta_hyperactor.DEFAULT_NCCL_ENVS,
             **meta_hyperactor.DEFAULT_TORCH_ENVS,
-            **{
-                "TORCHX_RUN_PYTHONPATH": f"{remote_end_python_path}:{self.remote_work_dir}"
-            },
+            **{"TORCHX_RUN_PYTHONPATH": ":".join(additional_python_paths)},
             **{
                 "HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS": "600",
                 "HYPERACTOR_CODE_MAX_FRAME_LENGTH": "1073741824",
@@ -293,7 +292,6 @@ def build_appdef(self) -> specs.AppDef:
                 "TORCHDYNAMO_VERBOSE": "1",
                 "VLLM_TORCH_COMPILE_LEVEL": "0",
                 "VLLM_USE_TRITON_FLASH_ATTN": "0",
-                "WANDB_MODE": "offline",
                 "HF_HUB_OFFLINE": "1",
                 "MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE": "1",
                 "TORCHSTORE_RDMA_ENABLED": "1",