1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@ __pycache__/
*$py.class
*.pth
*.pt
*.nc
# C extensions
*.so
*.DS_Store
17 changes: 17 additions & 0 deletions README.md
@@ -66,6 +66,23 @@ uv sync
source .venv/bin/activate
```

4. **Zero-Shot Inference**
```bash
python easy_inference/run_easy_inference.py --config-path easy_inference/config_easy.yaml
```

#### Note
The above command is the fastest path to running Surya foundation-model inference over a date range without any additional setup. Data is downloaded from the public S3 bucket `nasa-surya-bench`. The script prompts for start/end UTC datetimes, downloads the required `.nc` files, and writes:
- `easy_inference/outputs_.../prediction.nc`
- `easy_inference/outputs_.../metrics/*`

**Device selection behavior**
- See [config_easy.yaml](easy_inference/config_easy.yaml) for the default inference configuration.
- `advanced.device: auto` resolves in priority order `cuda -> mps -> cpu`.
- Works across CUDA GPUs, Apple Silicon (MPS), and CPU-only systems.
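In case it helps to picture the fallback, here is a minimal, framework-free sketch of the `auto` resolution order. The `cuda_ok`/`mps_ok` flags stand in for the real availability checks (e.g. `torch.cuda.is_available()` / `torch.backends.mps.is_available()`); the function name and signature are illustrative, not the script's actual API:

```python
def resolve_device(requested: str = "auto",
                   cuda_ok: bool = False,
                   mps_ok: bool = False) -> str:
    """Resolve 'auto' using the cuda -> mps -> cpu fallback order."""
    if requested != "auto":
        return requested  # an explicit choice always wins
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"

print(resolve_device("auto", cuda_ok=False, mps_ok=True))  # mps
```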

---

### 🧪 Verify Installation

Run the end-to-end test to ensure everything is working:
57 changes: 57 additions & 0 deletions easy_inference/README.md
@@ -0,0 +1,57 @@
# Easy Inference

Use this folder for the simplest Surya flow:
1. choose a date window,
2. download only the required hourly files,
3. run rollout inference,
4. save a single `prediction.nc`.
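Steps 1-2 amount to enumerating the hourly timestamps in the inclusive window and fetching one file per timestamp. A minimal sketch of that enumeration, assuming the default 60-minute cadence (the function name is illustrative, not the script's actual API):

```python
from datetime import datetime, timedelta

def hourly_timestamps(start: datetime, end: datetime,
                      cadence_minutes: int = 60) -> list[datetime]:
    """Inclusive list of timestamps from start to end at the given cadence."""
    step = timedelta(minutes=cadence_minutes)
    out = []
    t = start
    while t <= end:
        out.append(t)
        t += step
    return out

# The default window from config_easy.yaml: 2014-10-23 10:00 to 17:00 UTC.
stamps = hourly_timestamps(datetime(2014, 10, 23, 10), datetime(2014, 10, 23, 17))
print(len(stamps))  # 8 hourly files for the inclusive window
```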

## Quick start

```bash
source .venv/bin/activate
bash easy_inference/run_easy_inference.sh
```

Non-interactive defaults:

```bash
bash easy_inference/run_easy_inference.sh --no-prompt
```

## Config

Edit `easy_inference/config_easy.yaml`.

- Normal users: edit only the top `user:` section.
- Advanced users: optional changes in `advanced:`.

Default and override behavior:

```bash
# Uses easy_inference/config_easy.yaml by default.
python easy_inference/run_easy_inference.py

# Optional: use a different YAML file.
python easy_inference/run_easy_inference.py --config-path /path/to/custom_easy.yaml
```

### Debug mode

Set in `advanced:`:
- `debug_mode: true`
- optional `debug_log_path: "path/to/inference_debug.txt"` (default is `<user.output_dir>/inference_debug.txt`)

When enabled, the text log records stage timings and per-step diagnostics, each tagged with a line number and UTC timestamp:
- input file read / transform timing
- GT file read timing
- per-step forward / CPU-copy / inverse-transform / write timing
- per-step memory stats (`CUDA` peak/allocated/reserved when available)
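As a rough illustration of the log format described above (not the actual logger implementation), each entry might be produced like this, using `inspect` to recover the caller's line number:

```python
import inspect
from datetime import datetime, timezone

def debug_log(path: str, message: str) -> None:
    """Append one log line tagged with UTC time and the caller's line number."""
    line_no = inspect.currentframe().f_back.f_lineno
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    with open(path, "a") as f:
        f.write(f"[{stamp} UTC] L{line_no}: {message}\n")

debug_log("inference_debug.txt", "input file read: 1.42 s")
```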

## Metrics Notebook

Use `easy_inference/compare_prediction_groundtruth.ipynb` to compare `prediction.nc` against ground truth and compute:
- overall metrics (`MSE`, `RMSE`, `MAE`, `bias`, `max_abs_error`)
- per-channel metrics
- per-step metrics
- visual prediction vs ground-truth plots
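The overall metrics listed above reduce to simple array operations; here is a minimal numpy sketch (an assumption about the notebook's computation, not a copy of it — the per-channel and per-step variants would apply the same reductions over different axes):

```python
import numpy as np

def overall_metrics(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Overall error metrics between prediction and ground-truth arrays."""
    err = pred - gt
    return {
        "MSE": float(np.mean(err ** 2)),
        "RMSE": float(np.sqrt(np.mean(err ** 2))),
        "MAE": float(np.mean(np.abs(err))),
        "bias": float(np.mean(err)),
        "max_abs_error": float(np.max(np.abs(err))),
    }

m = overall_metrics(np.array([1.0, 2.0, 4.0]), np.array([1.0, 2.0, 3.0]))
print(m)
```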
170 changes: 170 additions & 0 deletions easy_inference/compare_prediction_groundtruth.ipynb

Large diffs are not rendered by default.

97 changes: 97 additions & 0 deletions easy_inference/config_easy.yaml
@@ -0,0 +1,97 @@
# --------------------------------------------------------------------
# Edit only this USER section for normal use.
# --------------------------------------------------------------------
user:
# Inclusive UTC start datetime for the download/inference window.
# Format: YYYY-MM-DD HH:MM[:SS]
start_datetime: "2014-10-23 10:00:00"
# Inclusive UTC end datetime for the download/inference window.
# Format: YYYY-MM-DD HH:MM[:SS]
end_datetime: "2014-10-23 17:00:00"
# If true, prompt in terminal to confirm/override start/end datetime and rollout steps.
prompt_for_dates: true
# Directory where run artifacts are written (prediction.nc, metrics CSV/JSON).
output_dir: easy_inference/outputs_24h
# Number of autoregressive prediction steps to generate.
rollout_steps: 5

# --------------------------------------------------------------------
# Advanced section (optional). Leave as-is unless needed.
# --------------------------------------------------------------------
advanced:
# Local path to foundation model architecture/config YAML.
# Downloaded automatically from model_repo_id if missing.
foundation_config_path: data/Surya-1.0/config.yaml
# Local path to scaler definitions used for inverse transform.
# Downloaded automatically from model_repo_id if missing.
scalers_path: data/Surya-1.0/scalers.yaml
# Local path to model weights checkpoint (.pt).
# Downloaded automatically from model_repo_id if missing.
weights_path: data/Surya-1.0/surya.366m.v1.pt
# Hugging Face repository used to fetch missing model assets.
model_repo_id: nasa-ibm-ai4science/Surya-1.0
# Files to pull from model_repo_id when assets are missing locally.
model_allow_patterns:
# Foundation model config.
- config.yaml
# Data scaler config.
- scalers.yaml
# Foundation model weights.
- surya.366m.v1.pt

# Local folder for downloaded/available validation .nc files.
validation_data_dir: data/Surya-1.0_validation_data_20141023_60min
# CSV index generated for the requested date window (used by dataset loader).
index_path: easy_inference/index_20141023_60min.csv

# Expected cadence (minutes) between consecutive source files/timestamps.
cadence_minutes: 60
# Relative input frame offsets (in minutes) used to build model input sequence.
# Example [-60, 0] means: previous hour + current time as input.
time_delta_input_minutes: [-60, 0]
# Target offset (minutes) for the first prediction horizon.
time_delta_target_minutes: 60

# Download settings
# Public S3 bucket containing benchmark .nc files.
s3_bucket: nasa-surya-bench
# If true, do not re-download files that already exist locally.
download_skip_existing: true
# If true, compare local file size with remote and re-download on mismatch.
download_verify_size: false
# Allowed timestamp matching tolerance (minutes) when mapping expected times to files.
# 0 means exact timestamp match only.
download_match_tolerance_minutes: 0
# If true, remove local validation files outside the requested window before download.
prune_validation_data_to_window: false

# Runtime
# Device selection. Values: auto | cuda | mps | cpu
# auto resolves in this order: cuda -> mps -> cpu.
device: auto
# Compute dtype for inference. Values: auto | float32 | float16 | bfloat16
dtype: auto
# Batch is fixed to 1 in easy mode (first valid sample only).
# Number of DataLoader worker processes (0 = main process).
num_workers: 0
# DataLoader prefetch batches per worker (used only when num_workers > 0).
prefetch_factor: 2
# Number of background workers for GT timestep prefetch during rollout (>=1).
gt_prefetch_workers: 4
# If true, disable autocast mixed precision even when supported.
disable_autocast: false
# If true, allow TF32 fast matmul/cudnn paths on CUDA.
enable_tf32: true
# If true, enable cuDNN benchmark autotuning (CUDA only).
enable_cudnn_benchmark: true
# CPU thread count for torch. 0 leaves PyTorch default behavior.
cpu_threads: 0
# If true, print progress logs for download and inference stages.
show_progress: true
# If true, write detailed debug profiling logs (plain text).
debug_mode: false
# Optional path for debug log file. Empty -> <user.output_dir>/inference_debug.txt
debug_log_path: "easy_inference/inference_debug.txt"

# Output dtype for saved prediction.nc values. Values: float16 | float32
prediction_dtype: float32