Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5a4c0a3
Added conversion/extra variable functionality to anemoidataset source…
evenmn Sep 9, 2025
c00695b
Added expression evaluation to anemoidatasets verif source
evenmn Sep 16, 2025
68d113c
Minor
evenmn Sep 16, 2025
c662f19
Added support for variable expressions in NetCDF as well
evenmn Sep 17, 2025
8707302
Add setup_logging, change prints to log.
ways Sep 17, 2025
c42edb1
Change all prints to log.
ways Sep 17, 2025
ecede56
Change all prints to log.
ways Sep 17, 2025
ce2e82e
Change all prints to log.
ways Sep 17, 2025
63b5bd2
Lint
ways Sep 17, 2025
ed31052
Small formatting changes.
ways Sep 17, 2025
0220e88
Added conversion/extra variable functionality to anemoidataset source…
evenmn Sep 9, 2025
266038d
Added expression evaluation to anemoidatasets verif source
evenmn Sep 16, 2025
5c45810
Minor
evenmn Sep 16, 2025
2d3da3a
Added support for variable expressions in NetCDF as well
evenmn Sep 17, 2025
dd4a735
Merge
evenmn Sep 18, 2025
a9d0003
Ruff
evenmn Sep 18, 2025
7da9b5c
Ruff
evenmn Sep 18, 2025
9c28122
Ruff
evenmn Sep 18, 2025
d7b57df
Unruff
evenmn Sep 18, 2025
bd5bff1
Adjust log levels.
ways Sep 18, 2025
34a0c40
Adjust log levels.
ways Sep 18, 2025
70f4f55
Merge pull request #182 from metno/add-logging-timing
ways Sep 18, 2025
0960087
Added conversion/extra variable functionality to anemoidataset source…
evenmn Sep 9, 2025
b1b596b
Added expression evaluation to anemoidatasets verif source
evenmn Sep 16, 2025
1f0fd91
Minor
evenmn Sep 16, 2025
c78ca49
MergE
evenmn Sep 19, 2025
c9c7c5e
Added expression evaluation to anemoidatasets verif source
evenmn Sep 16, 2025
4d44e35
Added support for variable expressions in NetCDF as well
evenmn Sep 17, 2025
9013aa9
Merge
evenmn Sep 19, 2025
ee61557
Ruff
evenmn Sep 18, 2025
b03e858
Ruff
evenmn Sep 18, 2025
1f45151
Unruff
evenmn Sep 18, 2025
9ce0d95
Merge
evenmn Sep 19, 2025
72a0134
Final (?) ruff fix
evenmn Sep 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions bris/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import time
from datetime import datetime, timedelta

from anemoi.utils.dates import frequency_to_seconds
Expand All @@ -11,20 +12,22 @@
from .checkpoint import Checkpoint
from .inference import Inference
from .utils import (
LOGGER,
create_config,
get_all_leadtimes,
parse_args,
set_base_seed,
set_encoder_decoder_num_chunks,
setup_logging,
)
from .writer import CustomWriter

LOGGER = logging.getLogger(__name__)


def main(arg_list: list[str] | None = None):
t0 = time.perf_counter()
args = parse_args(arg_list)
config = create_config(args["config"], args)
setup_logging(config)

models = list(config.checkpoints.keys())
checkpoints = {
Expand Down Expand Up @@ -88,8 +91,8 @@ def main(arg_list: list[str] | None = None):
),
"%Y-%m-%dT%H:%M:%S",
)
LOGGER.info(
"No start_date given, setting %s based on start_date and timestep.",
LOGGER.warning(
"No start_date given, setting %s based on end_date and timestep.",
config.start_date,
)
else:
Expand Down Expand Up @@ -143,13 +146,6 @@ def main(arg_list: list[str] | None = None):
)
writer = CustomWriter(decoder_outputs, write_interval="batch")

# Set hydra defaults
config.defaults = [
{"override hydra/job_logging": "none"}, # disable config parsing logs
{"override hydra/hydra_logging": "none"}, # disable config parsing logs
"_self_",
]

# Forecaster must know about what leadtimes to output
model = instantiate(
config.model,
Expand Down Expand Up @@ -180,9 +176,13 @@ def main(arg_list: list[str] | None = None):
if is_main_thread:
for decoder_output in decoder_outputs:
for output in decoder_output["outputs"]:
t1 = time.perf_counter()
output.finalize()
LOGGER.debug(
f"finalizing decoder {decoder_output} output {output.filename_pattern} in {time.perf_counter() - t1:.1f}s"
)

print("Model run completed. 🤖")
LOGGER.info(f"Bris completed in {time.perf_counter() - t0:.1f}s. 🤖")


if __name__ == "__main__":
Expand Down
12 changes: 8 additions & 4 deletions bris/conventions/metno.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Additionally, the names of some dimension-variables do not use CF-names
"""

from bris.utils import LOGGER


class Metno:
cf_to_metno = {
Expand Down Expand Up @@ -46,12 +48,14 @@ def get_ncname(self, cfname: str, leveltype: str, level: int):
# This is likely a forcing variable
return cfname
else:
print(cfname, leveltype, level)
LOGGER.error(
f"get_ncname not implemented cfname {cfname}, leveltype {leveltype}, level {level}"
)
raise NotImplementedError()

return ncname

def is_single_level(self, cfname: str, leveltype: str) -> str:
def is_single_level(self, cfname: str, leveltype: str) -> bool:
"""Returns true if there should only be a single level in the level dimension for this
variable.

Expand All @@ -69,13 +73,13 @@ def is_single_level(self, cfname: str, leveltype: str) -> str:
"wind_speed",
] and leveltype in ["height"]

def get_name(self, cfname: str):
def get_name(self, cfname: str) -> str:
"""Get MetNorway's dimension name from cf standard name"""
if cfname in self.cf_to_metno:
return self.cf_to_metno[cfname]
return cfname

def get_cfname(self, ncname):
def get_cfname(self, ncname) -> str:
"""Get the CF-standard name from a given MetNo name"""
for k, v in self.cf_to_metno.items():
if v == ncname:
Expand Down
8 changes: 3 additions & 5 deletions bris/data/dataset/nativegrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
from torch.utils.data import IterableDataset

from bris.data.grid_indices import BaseGridIndices
from bris.utils import get_base_seed, get_usable_indices

LOGGER = logging.getLogger(__name__)
from bris.utils import LOGGER, get_base_seed, get_usable_indices


class NativeGridDataset(IterableDataset):
Expand Down Expand Up @@ -200,8 +198,8 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
f"num_data_parallel = num_nodes * num_gpus_per_node / num_gpus_per_model"
)
if len(self.valid_date_indices) % self.ens_comm_num_groups != 0:
print(
f"Warning: Dataloader has {len(self.valid_date_indices)} samples, which is not divisible by "
LOGGER.warning(
f"Dataloader has {len(self.valid_date_indices)} samples, which is not divisible by "
f"{self.ens_comm_num_groups} data parallel workers. This will lead to "
f"{len(self.valid_date_indices) % self.ens_comm_num_groups} unprocessed samples.",
"num_data_parallel = num_nodes * num_gpus_per_node / num_gpus_per_model",
Expand Down
16 changes: 12 additions & 4 deletions bris/ddp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn

from bris.utils import get_base_seed

LOGGER = logging.getLogger(__name__)
from bris.utils import LOGGER, get_base_seed


class DDPGroupStrategy(DDPStrategy):
"""Distributed Data Parallel strategy with group communication."""

# Define type of model, set in DDPStrategy somewhere
model: pl.LightningModule

def __init__(
self,
num_gpus_per_model: int,
Expand Down Expand Up @@ -89,6 +90,13 @@ def setup(self, trainer: pl.Trainer) -> None:

# set up reader groups by further splitting model_comm_group_ranks with read_group_size:

LOGGER.debug(
"world_size %d, model_comm_group_size %d, read_group_size %d",
self.world_size,
self.model_comm_group_size,
self.read_group_size,
)

assert self.model_comm_group_size % self.read_group_size == 0, (
f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size "
f"({self.read_group_size})."
Expand Down Expand Up @@ -226,7 +234,7 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, in

def get_my_reader_group(
self, model_comm_group_rank: int, read_group_size: int
) -> tuple[int, int, int]:
) -> tuple[int, int, int, int]:
"""Determine tasks that work together and from a reader group.

Parameters
Expand Down
9 changes: 6 additions & 3 deletions bris/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
from functools import cached_property
from typing import Any, Optional

Expand All @@ -7,11 +8,10 @@
from anemoi.utils.config import DotDict

from bris.ddp_strategy import DDPGroupStrategy
from bris.utils import LOGGER

from .data.datamodule import DataModule

LOGGER = logging.getLogger(__name__)


class Inference:
def __init__(
Expand Down Expand Up @@ -41,7 +41,7 @@ def device(self) -> str:
LOGGER.info("Specified device not set. Found GPU")
return "cuda"

LOGGER.info("Specified device not set. Could not find gpu, using CPU")
LOGGER.warning("Specified device not set. Could not find gpu, using CPU")
return "cpu"

LOGGER.info("Using specified device: %s", self._device)
Expand Down Expand Up @@ -74,6 +74,9 @@ def trainer(self) -> pl.Trainer:
return trainer

def run(self):
t0 = time.perf_counter()
LOGGER.debug("Bris/Inference/run Predicting")
self.trainer.predict(
self.model, datamodule=self.datamodule, return_predictions=False
)
LOGGER.debug(f"bris/Inference.run: {time.perf_counter() - t0:.1f}s")
3 changes: 1 addition & 2 deletions bris/model/brispredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
get_dynamic_forcings,
)
from ..utils import (
LOGGER,
check_anemoi_training,
timedelta64_from_timestep,
)
from .basepredictor import BasePredictor
from .model_utils import get_model_static_forcings, get_variable_indices

LOGGER = logging.getLogger(__name__)


class BrisPredictor(BasePredictor):
"""
Expand Down
50 changes: 25 additions & 25 deletions bris/outputs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import copy
import time
from typing import Optional

import numpy as np

from bris import sources
from bris.predict_metadata import PredictMetadata
from bris.utils import LOGGER, datetime_to_unixtime, expr_to_var, safe_eval_expr


def instantiate(name: str, predict_metadata: PredictMetadata, workdir: str, init_args):
Expand Down Expand Up @@ -46,32 +48,20 @@ def get_required_variables(name, init_args):
provided
"""

if name == "netcdf":
if name in ["netcdf", "grib"]:
if "variables" in init_args:
variables = init_args["variables"]
if "extra_variables" in init_args:
for var_name in init_args["extra_variables"]:
if var_name == "ws":
variables += ["10u", "10v"]
extr_vars, _, _ = expr_to_var(var_name)
variables.extend(extr_vars)
variables = sorted(list(set(variables)))
return variables
return [None]

if name in ["verif", "powerspectrum_gridded", "powerspectrum_global"]:
if init_args["variable"] == "ws":
return ["10u", "10v"]
return [init_args["variable"]]

if name == "grib":
if "variables" in init_args:
variables = init_args["variables"]
if "extra_variables" in init_args:
for name in init_args["extra_variables"]:
if name == "ws":
variables += ["10u", "10v"]
variables = sorted(list(set(variables)))
return variables
return [None]
extr_vars, _, _ = expr_to_var(init_args["variable"])
return extr_vars

raise ValueError(f"Invalid output: {name}")

Expand Down Expand Up @@ -110,22 +100,28 @@ def add_forecast(self, times: list, ensemble_member: int, pred: np.ndarray):
# Append extra variables to prediction
for name in self.extra_variables:
if name not in self.pm.variables:
self.pm.variables.append(name)
clean_name = name.replace("[", "").replace("]", "").replace("*", "")
self.pm.variables.append(clean_name)

# only do this once. For multiple members, intermediate calls this several times
t0 = time.perf_counter()
if pred.shape[2] != len(self.pm.variables):
# Append extra variables to prediction
extra_pred = []
for name in self.extra_variables:
if name == "ws":
Ix = self.pm.variables.index("10u")
Iy = self.pm.variables.index("10v")
curr = np.sqrt(pred[..., [Ix]] ** 2 + pred[..., [Iy]] ** 2)
extra_pred += [curr]
else:
raise ValueError(f"No recipe to compute {name}")
extr_vars, name, success = expr_to_var(name)
assert success, "Variables could not be extracted from expression"

variables_dict = {}
for v in extr_vars:
idx = self.pm.variables.index(v)
variables_dict[v] = pred[..., idx]
extra_pred += [safe_eval_expr(name, variables_dict)[..., None]]

pred = np.concatenate([pred] + extra_pred, axis=2)
LOGGER.debug(
f"outputs.add_forecast Calculate ws in {time.perf_counter() - t0:.1f}s"
)

assert pred.shape[0] == self.pm.num_leadtimes
assert pred.shape[1] == len(self.pm.lats)
Expand All @@ -136,7 +132,11 @@ def add_forecast(self, times: list, ensemble_member: int, pred: np.ndarray):
assert ensemble_member >= 0
assert ensemble_member < self.pm.num_members

t1 = time.perf_counter()
self._add_forecast(times, ensemble_member, pred)
LOGGER.debug(
f"outputs.add_forecast called _add_forecast in {time.perf_counter() - t1:.1f}s"
)

def _add_forecast(self, times: list, ensemble_member: int, pred: np.ndarray):
"""Subclasses should implement this"""
Expand Down
Loading