Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion .meta/mast/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
import asyncio
import os
import sys

from apps.grpo.main import main as grpo_main
Expand All @@ -31,6 +32,41 @@
DEFAULT_CHECKPOINT_FOLDER = "/mnt/wsfuse/teamforge/forge_runs/"


def setup_wandb_api_key() -> None:
# add wandb API key to the environment
secret_name = "TORCHFORGE_WANDB_API_KEY"
print(f"[wandb] Attempting to retrieve API key from keychain {secret_name=}")
try:
import base64

from cif import client

response = client.request(
"keychain.service",
"getSecretV2",
{
"request": {
"name": secret_name,
}
},
)
# decode base64 encoded string
wandb_api_key = base64.b64decode(
# pyrefly: ignore [bad-index]
response["result"]["secret"]["value"]
).decode("utf-8")
print("[wandb] Successfully retrieved API key from keychain.")
os.environ["WANDB_API_KEY"] = wandb_api_key
os.environ["WANDB_BASE_URL"] = "https://meta.wandb.io/"
except Exception as keychain_exception:
print(
f"[wandb] Failed to retrieve API key from keychain. {keychain_exception=}"
)
raise RuntimeError(
"Failed to retrieve wandb API key. Cannot launch job"
) from keychain_exception


async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None):
"""Main module for launching mast jobs for GRPO training.

Expand Down Expand Up @@ -64,7 +100,9 @@ async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None)
)
await launcher.launch_mast_job()
else:
# In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training
# In remote mode, we're already running inside MAST, so set up wandb api key, mount directory,
# init provisioner and run training
setup_wandb_api_key()
mount_mnt_directory("/mnt/wsfuse")
await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
await grpo_main(cfg)
Expand Down
18 changes: 8 additions & 10 deletions src/forge/controller/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,20 +271,19 @@ def add_additional_packages(self, packages: "Packages") -> "Packages":

def build_appdef(self) -> specs.AppDef:
# create the app definition for the worker
remote_end_python_path = ":".join(
[
f"{self.remote_work_dir}{workspace}"
for workspace in self.editable_workspace_paths
]
)
additional_python_paths = [
f"{self.remote_work_dir}{workspace}"
for workspace in self.editable_workspace_paths
]
additional_python_paths.append(self.remote_work_dir)

# needed for wandb api key extraction from secret
additional_python_paths.append("/packages/cif")
default_envs = {
**meta_hyperactor.DEFAULT_NVRT_ENVS,
**meta_hyperactor.DEFAULT_NCCL_ENVS,
**meta_hyperactor.DEFAULT_TORCH_ENVS,
**{
"TORCHX_RUN_PYTHONPATH": f"{remote_end_python_path}:{self.remote_work_dir}"
},
**{"TORCHX_RUN_PYTHONPATH": ":".join(additional_python_paths)},
**{
"HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS": "600",
"HYPERACTOR_CODE_MAX_FRAME_LENGTH": "1073741824",
Expand All @@ -293,7 +292,6 @@ def build_appdef(self) -> specs.AppDef:
"TORCHDYNAMO_VERBOSE": "1",
"VLLM_TORCH_COMPILE_LEVEL": "0",
"VLLM_USE_TRITON_FLASH_ATTN": "0",
"WANDB_MODE": "offline",
"HF_HUB_OFFLINE": "1",
"MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE": "1",
"TORCHSTORE_RDMA_ENABLED": "1",
Expand Down
Loading